usql/usql.cpp

330 lines
14 KiB
C++

#include "usql.h"
#include "exception.h"
#include "ml_string.h"
#include <algorithm>
namespace usql {
/// Creates the engine and bootstraps the system catalogue tables
/// (usql_tables, usql_columns, usql_indexes).
///
/// Each catalogue DDL statement is parsed exactly once; the resulting nodes are
/// then walked twice:
///   1. every catalogue table is created and appended to m_tables;
///   2. execute_create_table_sys_catalogue() inserts the description of each
///      catalogue table into the catalogue tables themselves.
/// Pass 2 runs only after pass 1 has finished, so the catalogue tables exist
/// before rows are inserted into them.
USql::USql() {
    const std::vector<std::string> catalogue_ddl {
        "create table usql_tables(name varchar(32) not null, modified boolean not null)",
        "create table usql_columns(table_name varchar(32) not null, column_name varchar(32) not null, column_type varchar(16) not null, column_length integer not null, nullable boolean not null, column_order integer not null)",
        "create table usql_indexes(index_name varchar(32) not null, table_name varchar(32), column_name varchar(32) not null)"
    };

    // parse once, reuse the nodes in both passes below
    std::vector<std::unique_ptr<Node>> parsed_ddl;
    parsed_ddl.reserve(catalogue_ddl.size());
    for (const auto &command : catalogue_ddl)
        parsed_ddl.push_back(m_parser.parse(command));

    // pass 1: create the catalogue tables
    for (const auto &parsed : parsed_ddl) {
        const auto &node = static_cast<const CreateTableNode &>(*parsed);
        Table table{node.table_name, node.cols_defs};
        m_tables.push_back(table);
    }

    // pass 2: insert data describing the catalogue tables into the catalogue
    for (const auto &parsed : parsed_ddl) {
        const auto &node = static_cast<const CreateTableNode &>(*parsed);
        execute_create_table_sys_catalogue(node);
    }
}
/// Parses a single SQL statement and executes it.
///
/// Any std::exception raised while parsing or executing is not propagated;
/// instead a statement-result table is returned, built from status -1 and the
/// exception's what() message.
std::unique_ptr<Table> USql::execute(const std::string &command) {
    try {
        auto parsed_statement = m_parser.parse(command);
        return execute(*parsed_statement);
    } catch (const std::exception &err) {
        return create_stmt_result_table(-1, err.what(), 0);
    }
}
/// Dispatches an already-parsed statement to the handler for its node_type.
///
/// Each handler receives the node downcast to its concrete statement type.
/// A node type without a handler yields a statement-result table with status
/// -1 and the message "unknown statement".
std::unique_ptr<Table> USql::execute(Node &node) {
    switch (node.node_type) {
        case NodeType::create_table:            return execute_create_table(static_cast<CreateTableNode &>(node));
        case NodeType::create_index:            return execute_create_index(static_cast<CreateIndexNode &>(node));
        case NodeType::create_table_as_select:  return execute_create_table_as_table(static_cast<CreateTableAsSelectNode &>(node));
        case NodeType::drop_table:              return execute_drop(static_cast<DropTableNode &>(node));
        case NodeType::insert_into:             return execute_insert_into_table(static_cast<InsertIntoTableNode &>(node));
        case NodeType::select_from:             return execute_select(static_cast<SelectFromTableNode &>(node));
        case NodeType::delete_from:             return execute_delete(static_cast<DeleteFromTableNode &>(node));
        case NodeType::update_table:            return execute_update(static_cast<UpdateTableNode &>(node));
        case NodeType::load_table:              return execute_load(static_cast<LoadIntoTableNode &>(node));
        case NodeType::save_table:              return execute_save(static_cast<SaveTableNode &>(node));
        case NodeType::set:                     return execute_set(static_cast<SetNode &>(node));
        case NodeType::show:                    return execute_show(static_cast<ShowNode &>(node));
        default:                                break;
    }
    // not a statement this engine knows how to run
    return create_stmt_result_table(-1, "unknown statement", 0);
}
/// Evaluates a relational predicate (=, !=, >, >=, <, <=, IS, IS NOT) for one row.
///
/// Both operands are evaluated against `row`.
/// - If either operand is NULL, only IS / IS NOT can match; every other
///   operator yields false.
/// - Otherwise the operands are reduced to a three-way `comparator` (<0, 0, >0),
///   which the final switch maps onto the requested operator.
/// - Booleans order false < true.
///
/// @throws Exception for operand types that cannot be compared.
bool USql::eval_relational_operator(const RelationalOperatorNode &filter, Table *table, Row &row) {
    std::unique_ptr<ValueNode> left_value = eval_value_node(table, row, filter.left.get(), nullptr, nullptr);
    std::unique_ptr<ValueNode> right_value = eval_value_node(table, row, filter.right.get(), nullptr, nullptr);

    // Three-way comparison that never subtracts ordered operands:
    // - integer subtraction can overflow (UB for signed types);
    // - inf - inf is NaN, which would make `inf = inf` false.
    // The trailing `l - r` is reachable only for unordered floating-point
    // operands (NaN). It propagates the NaN, so only `!=` matches there.
    auto compare = [](auto l, auto r) -> double {
        if (l < r) return -1.0;
        if (r < l) return 1.0;
        if (l == r) return 0.0;
        return static_cast<double>(l - r);
    };

    double comparator = 0.0;
    if (left_value->node_type == NodeType::null_value || right_value->node_type == NodeType::null_value) {
        bool all_null = (left_value->isNull() && right_value->node_type == NodeType::null_value) ||
                        (right_value->isNull() && left_value->node_type == NodeType::null_value);
        if (filter.op == RelationalOperatorType::is)
            return all_null;
        if (filter.op == RelationalOperatorType::is_not)
            return !all_null;
        return false;
    } else if (left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::int_value) {
        comparator = compare(left_value->getIntegerValue(), right_value->getIntegerValue());
    } else if ((left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::float_value) ||
               (left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::int_value) ||
               (left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::float_value)) {
        comparator = compare(left_value->getDoubleValue(), right_value->getDoubleValue());
    } else if (left_value->node_type == NodeType::bool_value || right_value->node_type == NodeType::bool_value) {
        // false < true, so that < and > are consistent with each other
        comparator = compare(left_value->getBooleanValue(), right_value->getBooleanValue());
    } else if (left_value->node_type == NodeType::string_value || right_value->node_type == NodeType::string_value) {
        comparator = left_value->getStringValue().compare(right_value->getStringValue());
    } else {
        // dates are evaluated to int values, so they are handled by the int branches above
        throw Exception("Undefined combination of types");
    }

    switch (filter.op) {
        case RelationalOperatorType::equal:
            return comparator == 0.0;
        case RelationalOperatorType::not_equal:
            return comparator != 0.0;
        case RelationalOperatorType::greater:
            return comparator > 0.0;
        case RelationalOperatorType::greater_equal:
            return comparator >= 0.0;
        case RelationalOperatorType::lesser:
            return comparator < 0.0;
        case RelationalOperatorType::lesser_equal:
            return comparator <= 0.0;
        case RelationalOperatorType::is:
        case RelationalOperatorType::is_not:
            // IS / IS NOT are resolved in the NULL branch above
            throw Exception("error in is or is not handling");
    }
    throw Exception("invalid relational operator");
}
/// Evaluates an expression node against `row` and returns a newly allocated value node.
///
/// Dispatch by node type:
/// - column references → eval_database_value_node;
/// - literals → eval_literal_value_node;
/// - function calls → eval_function_value_node, which also receives
///   `col_def_node` and `agg_func_value`;
/// - NULL → NullValueNode;
/// - arithmetic expressions → evaluated with a float result type.
///
/// @throws Exception("unsupported node type") for any other node type.
std::unique_ptr<ValueNode> USql::eval_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value) {
    switch (node->node_type) {
        case NodeType::database_value:
            return eval_database_value_node(table, row, node);

        case NodeType::int_value:
        case NodeType::float_value:
        case NodeType::string_value:
        case NodeType::bool_value:
            return eval_literal_value_node(row, node);

        case NodeType::function:
            return eval_function_value_node(table, row, node, col_def_node, agg_func_value);

        case NodeType::null_value:
            return std::make_unique<NullValueNode>();

        case NodeType::arithmetical_operator:
            return eval_arithmetic_operator(ColumnType::float_type, static_cast<ArithmeticalOperatorNode &>(*node), table, row);

        default:
            break;
    }
    throw Exception("unsupported node type");
}
/// Reads the column referenced by a DatabaseValueNode from `row` and wraps it
/// in a value node matching the column's declared type.
///
/// - A NULL cell yields a NullValueNode.
/// - Date columns are returned as IntValueNode, built from the cell's integer value.
///
/// @throws Exception("unknown database value type") for any other column type.
std::unique_ptr<ValueNode> USql::eval_database_value_node(Table *table, Row &row, Node *node) {
    auto *dvl = static_cast<DatabaseValueNode *>(node);
    // Bind by const reference instead of copying the column definition for
    // every evaluated row. If get_column_def returns by value, the temporary's
    // lifetime is extended to this scope, so this is safe either way.
    // TODO: look the definition up once per statement instead of once per row.
    const ColDefNode &col_def = table->get_column_def(dvl->col_name);
    ColValue &db_value = row[col_def.order];
    if (db_value.isNull())
        return std::make_unique<NullValueNode>();

    switch (col_def.type) {
        case ColumnType::integer_type:
            return std::make_unique<IntValueNode>(db_value.getIntegerValue());
        case ColumnType::float_type:
            return std::make_unique<DoubleValueNode>(db_value.getDoubleValue());
        case ColumnType::varchar_type:
            return std::make_unique<StringValueNode>(db_value.getStringValue());
        case ColumnType::bool_type:
            return std::make_unique<BooleanValueNode>(db_value.getBoolValue());
        case ColumnType::date_type:
            return std::make_unique<IntValueNode>(db_value.getIntegerValue());
        default:
            break;
    }
    throw Exception("unknown database value type");
}
/// Copies a literal node (int, float, string or bool) into a fresh value node.
///
/// `row` is not read. Dates have no value node of their own (they are passed
/// around as strings), so they never reach this function as literals.
///
/// @throws Exception("invalid type") for any non-literal node type.
std::unique_ptr<ValueNode> USql::eval_literal_value_node(Row &row, Node *node) {
    switch (node->node_type) {
        case NodeType::int_value: {
            const auto *literal = static_cast<IntValueNode *>(node);
            return std::make_unique<IntValueNode>(literal->value);
        }
        case NodeType::float_value: {
            const auto *literal = static_cast<DoubleValueNode *>(node);
            return std::make_unique<DoubleValueNode>(literal->value);
        }
        case NodeType::string_value: {
            const auto *literal = static_cast<StringValueNode *>(node);
            return std::make_unique<StringValueNode>(literal->value);
        }
        case NodeType::bool_value: {
            const auto *literal = static_cast<BooleanValueNode *>(node);
            return std::make_unique<BooleanValueNode>(literal->value);
        }
        default:
            break;
    }
    throw Exception("invalid type");
}
/// Evaluates a function call node for one row.
///
/// - All arguments are evaluated first, left to right, without column
///   definition or aggregate state.
/// - coalesce is dispatched before the NULL check because it must accept a
///   NULL first argument.
/// - For every other function, an empty argument list or a NULL first argument
///   yields NULL.
/// - count/max/min additionally receive `agg_func_value` (and max/min also
///   `col_def_node`).
///
/// @throws Exception("invalid function: <name>") for unhandled function types.
std::unique_ptr<ValueNode> USql::eval_function_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value) {
    auto *call = static_cast<FunctionNode *>(node);

    std::vector<std::unique_ptr<ValueNode>> args;
    for (auto &param : call->params)
        args.push_back(eval_value_node(table, row, param.get(), nullptr, nullptr));

    if (call->function == FunctionNode::Type::coalesce)
        return coalesce_function(args);

    if (args.empty() || args[0]->isNull())
        return std::make_unique<NullValueNode>();

    switch (call->function) {
        case FunctionNode::Type::lower:    return lower_function(args);
        case FunctionNode::Type::upper:    return upper_function(args);
        case FunctionNode::Type::to_date:  return to_date_function(args);
        case FunctionNode::Type::to_char:  return to_char_function(args);
        case FunctionNode::Type::to_int:   return to_int_function(args);
        case FunctionNode::Type::to_float: return to_float_function(args);
        case FunctionNode::Type::date_add: return date_add_function(args);
        case FunctionNode::Type::pp:       return pp_function(args);
        case FunctionNode::Type::count:    return count_function(agg_func_value, args);
        case FunctionNode::Type::max:      return max_function(args, col_def_node, agg_func_value);
        case FunctionNode::Type::min:      return min_function(args, col_def_node, agg_func_value);
        default:                           break;
    }
    throw Exception("invalid function: " + FunctionNode::function_name(call->function));
}
/// Evaluates an AND / OR node for one row with short-circuiting.
///
/// The right operand is evaluated only when the left one does not already
/// decide the result: AND with a false left side, or OR with a true left side,
/// returns the left value directly.
bool USql::eval_logical_operator(LogicalOperatorNode &node, Table *pTable, Row &row) {
    const bool lhs = eval_where(&*node.left, pTable, row);

    const bool decided_by_lhs = node.op == LogicalOperatorType::and_operator
                                    ? !lhs
                                    : (node.op == LogicalOperatorType::or_operator && lhs);
    if (decided_by_lhs)
        return lhs;

    return eval_where(&*node.right, pTable, row);
}
/// Evaluates a binary arithmetic node for one row, producing a value of `outType`.
///
/// - copy_value simply evaluates and returns the left operand.
/// - If either operand is NULL the result is NULL.
/// - Supported operators per output type:
///   - float: + - * /
///   - integer: + - * /
///   - varchar: + (concatenation)
///   - date: + and -, with an IntValueNode result.
///
/// @throws Exception on integer division by zero (undefined behaviour in C++,
///         usually a SIGFPE crash that no catch block can intercept).
/// @throws Exception for unsupported operator/type combinations.
std::unique_ptr<ValueNode> USql::eval_arithmetic_operator(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, Row &row) {
    if (node.op == ArithmeticalOperatorType::copy_value) {
        return eval_value_node(table, row, node.left.get(), nullptr, nullptr);
    }

    std::unique_ptr<ValueNode> left = eval_value_node(table, row, node.left.get(), nullptr, nullptr);
    std::unique_ptr<ValueNode> right = eval_value_node(table, row, node.right.get(), nullptr, nullptr);
    if (left->isNull() || right->isNull())
        return std::make_unique<NullValueNode>();

    if (outType == ColumnType::float_type) {
        auto l = left->getDoubleValue();
        auto r = right->getDoubleValue();
        switch (node.op) {
            case ArithmeticalOperatorType::plus_operator:
                return std::make_unique<DoubleValueNode>(l + r);
            case ArithmeticalOperatorType::minus_operator:
                return std::make_unique<DoubleValueNode>(l - r);
            case ArithmeticalOperatorType::multiply_operator:
                return std::make_unique<DoubleValueNode>(l * r);
            case ArithmeticalOperatorType::divide_operator:
                // floating-point division by zero is well defined (inf / NaN)
                return std::make_unique<DoubleValueNode>(l / r);
            default:
                throw Exception("eval_arithmetic_operator, float type implement me!!");
        }
    } else if (outType == ColumnType::integer_type) {
        auto l = left->getIntegerValue();
        auto r = right->getIntegerValue();
        switch (node.op) {
            case ArithmeticalOperatorType::plus_operator:
                return std::make_unique<IntValueNode>(l + r);
            case ArithmeticalOperatorType::minus_operator:
                return std::make_unique<IntValueNode>(l - r);
            case ArithmeticalOperatorType::multiply_operator:
                return std::make_unique<IntValueNode>(l * r);
            case ArithmeticalOperatorType::divide_operator:
                // integer division by zero is UB; report it as a statement error instead
                if (r == 0)
                    throw Exception("division by zero");
                return std::make_unique<IntValueNode>(l / r);
            default:
                throw Exception("eval_arithmetic_operator, integer type implement me!!");
        }
    } else if (outType == ColumnType::varchar_type) {
        auto l = left->getStringValue();
        auto r = right->getStringValue();
        switch (node.op) {
            case ArithmeticalOperatorType::plus_operator:
                return std::make_unique<StringValueNode>(l + r);
            default:
                throw Exception("eval_arithmetic_operator, varchar type implement me!!");
        }
    } else if (outType == ColumnType::date_type) {
        auto l = left->getDateValue();
        auto r = right->getDateValue();
        switch (node.op) {
            case ArithmeticalOperatorType::plus_operator:
                return std::make_unique<IntValueNode>(l + r);
            case ArithmeticalOperatorType::minus_operator:
                return std::make_unique<IntValueNode>(l - r);
            default:
                throw Exception("eval_arithmetic_operator, date_type type implement me!!");
        }
    }
    throw Exception("eval_arithmetic_operator, implement me!!");
}
/// Returns a pointer to the first table in m_tables whose name equals `name`.
///
/// The method is const but hands out a mutable pointer, hence the const_cast.
///
/// @throws Exception("table not found (<name>)") when no table matches.
Table *USql::find_table(const std::string &name) const {
    for (const Table &candidate : m_tables) {
        if (candidate.m_name == name)
            return const_cast<Table *>(&candidate);
    }
    throw Exception("table not found (" + name + ")");
}
/// Guards table creation.
///
/// @throws Exception("table already exists") if any table in m_tables is named `name`.
void USql::check_table_not_exists(const std::string &name) const {
    const bool already_present = std::any_of(m_tables.begin(), m_tables.end(),
                                             [&name](const Table &t) { return t.m_name == name; });
    if (already_present)
        throw Exception("table already exists");
}
/// Guards index creation.
///
/// Index names are checked across every table, not just one.
///
/// @throws Exception("index already exists") if any table reports a non-null
///         get_index(index_name).
void USql::check_index_not_exists(const std::string &index_name) {
    auto owns_index = [&index_name](auto &t) { return t.get_index(index_name) != nullptr; };
    if (std::any_of(m_tables.begin(), m_tables.end(), owns_index))
        throw Exception("index already exists");
}
} // namespace