#include "usql.h"
#include "exception.h"
#include "ml_string.h"

#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace usql {

// Decides whether the WHERE tree can be answered through an index.
// The tree is first normalized in place (see normalize_where); when that reports
// an index-friendly tree and the "USE_INDEXSCAN" setting is enabled, the tree is
// searched for an indexed equality predicate (see look_for_usable_index).
// Returns {true, rowids} when an index was used, {false, {}} otherwise.
std::pair<bool, std::vector<rowid_t>> USql::probe_index_scan(const Node *where, Table *table) const {
    const bool indexscan_possible = normalize_where(where);

    if (indexscan_possible && Settings::get_bool_setting("USE_INDEXSCAN")) {
        // where->dump();
        return look_for_usable_index(where, table);
    }

    // no index scan
    return {false, std::vector<rowid_t>{}};
}

// Recursively looks for a predicate `column = int/string literal` whose column
// has an index on `table`; returns {true, rowids found by the index} on success.
// For an AND node the left operand is tried first, then the right one; any other
// node type (including OR) yields no index.
// Expects a normalized tree (column on the left side), see normalize_where.
std::pair<bool, std::vector<rowid_t>> USql::look_for_usable_index(const Node *where, Table *table) const {
    if (where->node_type == NodeType::relational_operator) {
        const auto *ron = static_cast<const RelationalOperatorNode *>(where);
        // TODO implement >, >=, <=, <
        // https://en.cppreference.com/w/cpp/container/map/upper_bound
        if (ron->op == RelationalOperatorType::equal &&
            ron->left->node_type == NodeType::database_value &&
            (ron->right->node_type == NodeType::int_value || ron->right->node_type == NodeType::string_value)) {

            auto col_name = static_cast<DatabaseValueNode *>(ron->left.get())->col_name;
            Index *used_index = table->get_index_for_column(col_name);
            if (used_index != nullptr) {
                auto rowids = used_index->search(static_cast<ValueNode *>(ron->right.get()));
#ifndef NDEBUG
                std::cout << "using index " << table->m_name << "(" << used_index->get_column_name() << "), "
                          << rowids.size() << "/" << table->rows_count() << std::endl;
#endif
                return {true, std::move(rowids)};
            }
        }
    } else if (where->node_type == NodeType::logical_operator) {
        const auto *operator_node = static_cast<const LogicalOperatorNode *>(where);
        if (operator_node->op == LogicalOperatorType::and_operator) {
            // every row must satisfy both operands, so an index on either side may be used
            auto left_result = look_for_usable_index(operator_node->left.get(), table);
            if (left_result.first) {
                return left_result;
            }
            return look_for_usable_index(operator_node->right.get(), table);
        }
    }

    // no index available
    return {false, std::vector<rowid_t>{}};
}

// Normalizes the layout of relational operators in place and reports whether an
// index scan is worth attempting.
// An equality written as `literal = column` is rewritten to `column = literal`
// so that look_for_usable_index only has to inspect the left operand.
// Returns false when an OR operator is reached through the logical operators,
// true otherwise.
// Note: although `node` is passed as const, the tree IS modified.
bool USql::normalize_where(const Node *node) const {
    if (node->node_type == NodeType::relational_operator) {
        // TODO more optimizations here, for example node 1 = 2 etc
        auto *ron = static_cast<RelationalOperatorNode *>(const_cast<Node *>(node));
        // Only equality is symmetric. The same tree is later evaluated row by row
        // (eval_where), so swapping the operands of <, >, ... would invert the condition.
        if (ron->op == RelationalOperatorType::equal &&
            ron->right->node_type == NodeType::database_value &&
            (ron->left->node_type == NodeType::int_value || ron->left->node_type == NodeType::string_value)) {
            std::swap(ron->left, ron->right);
        }
        return true;
    }

    if (node->node_type == NodeType::logical_operator) {
        const auto *operator_node = static_cast<const LogicalOperatorNode *>(node);
        if (operator_node->op == LogicalOperatorType::or_operator) {
            return false;
        }
        // normalize both subtrees (no short-circuit) before combining the results
        const bool left_subnode = normalize_where(operator_node->left.get());
        const bool right_subnode = normalize_where(operator_node->right.get());
        return left_subnode && right_subnode;
    }

    return true;
}

// Projects one source row into `rslt_table`.
// Result columns whose source index is FUNCTION_CALL are computed with
// eval_value_node (which also receives the current result value, used by
// aggregates); all other columns are copied from `src_row`.
// When `is_aggregated` is set and a result row already exists, that first row is
// updated instead of creating a new one.
void USql::select_row(SelectFromTableNode &where_node, Table *src_table, Row *src_row, Table *rslt_table,
                      const std::vector<ColDefNode> &rslt_tbl_col_defs,
                      const std::vector<int> &src_table_col_index, bool is_aggregated) {
    Row *rslt_row = nullptr;

    // when aggregate functions in rslt_table only one row exists
    // TODO add function to get rows count
    if (is_aggregated && !rslt_table->m_rows.empty())
        rslt_row = &rslt_table->m_rows[0];
    else
        rslt_row = &rslt_table->create_empty_row();

    for (int idx = 0; idx < rslt_table->columns_count(); idx++) {
        // setColumnValue / eval_value_node take a non-const column definition
        auto *col_def = const_cast<ColDefNode *>(&rslt_tbl_col_defs[idx]);
        const auto src_table_col_idx = src_table_col_index[idx];

        if (src_table_col_idx == FUNCTION_CALL) {
            auto evaluated_value = eval_value_node(src_table, *src_row,
                                                   where_node.cols_names->operator[](idx).value.get(),
                                                   col_def, &rslt_row->operator[](idx));
            ValueNode *col_value = evaluated_value.get();
            rslt_row->setColumnValue(col_def, col_value);
        } else {
            ColValue &col_value = src_row->operator[](src_table_col_idx);
            rslt_row->setColumnValue(col_def, col_value);
        }
    }

    // for aggregate is validated more than needed
    rslt_table->commit_row(*rslt_row);
}

// Counts the count/min/max function columns of the select list.
// Throws Exception when aggregates are mixed with non-aggregate columns.
// Returns true when the select list consists solely of aggregates.
bool USql::check_for_aggregate_only_functions(SelectFromTableNode &node, size_t result_cols_cnt) {
    size_t aggregate_funcs = 0;
    for (const auto &col_node : *node.cols_names) {
        if (col_node.value->node_type == NodeType::function) {
            const auto *func_node = static_cast<const FunctionNode *>(col_node.value.get());
            if (func_node->function == "count" || func_node->function == "min" || func_node->function == "max")
                aggregate_funcs++;
        }
    }

    // check whether aggregates are not present or all columns are aggregates
    if (aggregate_funcs > 0 && aggregate_funcs != result_cols_cnt) {
        throw Exception("aggregate functions with no aggregates");
    }

    return aggregate_funcs > 0;
}

// Replaces a lone `*` in the select list with one column reference per column
// of `table`, in table definition order.
void USql::expand_asterix_char(SelectFromTableNode &node, Table *table) {
    if (node.cols_names->size() == 1 && node.cols_names->operator[](0).name == "*") {
        node.cols_names->clear();
        node.cols_names->reserve(table->columns_count());
        for (const auto &col : table->m_col_defs) {
            node.cols_names->emplace_back(SelectColNode{std::make_unique<DatabaseValueNode>(col.name), col.name});
        }
    }
}

// Resolves ORDER BY items to 0-based column indexes of `table`: named items take
// the column's order, numeric items are converted from the user's 1-based index.
// Throws Exception when the resulting index is outside the table's columns.
void USql::setup_order_columns(std::vector<ColOrderNode> &node, Table *table) {
    for (auto &order_node : node) {
        if (!order_node.col_name.empty()) {
            ColDefNode col_def = table->get_column_def(order_node.col_name);
            order_node.col_index = col_def.order;
        } else {
            order_node.col_index = order_node.col_index - 1; // user counts from 1
        }
        if (order_node.col_index < 0 || order_node.col_index >= table->columns_count())
            throw Exception("unknown column in order by clause (" + order_node.col_name + ")");
    }
}

// For SELECT DISTINCT: sorts the result rows and removes adjacent duplicates.
void USql::execute_distinct(SelectFromTableNode &node, Table *result) {
    if (!node.distinct) return;

    // std::sort requires a strict weak ordering, so equal rows must compare false (`> 0`, not `>= 0`)
    auto compare_rows = [](const Row &a, const Row &b) { return a.compare(b) > 0; };
    std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows);
    result->m_rows.erase(std::unique(result->m_rows.begin(), result->m_rows.end()), result->m_rows.end());
}

// Sorts the result rows by the ORDER BY items: the first item whose values differ
// decides, honouring each item's ascending/descending flag.
void USql::execute_order_by(SelectFromTableNode &node, Table *result) {
    if (node.order_by.empty()) return;

    auto compare_rows = [&node, result](const Row &a, const Row &b) {
        for (const auto &order_by_col_def : node.order_by) {
            const auto &col_def = result->get_column_def(order_by_col_def.col_index);
            ColValue &a_val = a[col_def.order];
            ColValue &b_val = b[col_def.order];

            const int compare = a_val.compare(b_val);
            if (compare < 0) return order_by_col_def.ascending;
            if (compare > 0) return !order_by_col_def.ascending;
        }
        // all keys equal: not "less", which keeps the ordering strict
        return false;
    };

    std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows);
}

// Applies OFFSET (drops the first `offset` rows, or all rows if there are fewer)
// and then LIMIT (keeps at most `limit` rows). A value of 0 disables either step.
void USql::execute_offset_limit(OffsetLimitNode &node, Table *result) {
    if (node.offset > 0)
        result->m_rows.erase(result->m_rows.begin(),
                             result->rows_count() > node.offset ? result->m_rows.begin() + node.offset
                                                                : result->m_rows.end());

    if (node.limit > 0 && node.limit < result->rows_count())
        result->m_rows.erase(result->m_rows.begin() + node.limit, result->m_rows.end());
}

// Evaluates a WHERE tree against one row; throws Exception for unsupported node types.
bool USql::eval_where(Node *where, Table *table, Row &row) {
    switch (where->node_type) {
        case NodeType::true_node:
            return true;
        case NodeType::relational_operator: // just one condition
            return eval_relational_operator(*static_cast<RelationalOperatorNode *>(where), table, row);
        case NodeType::logical_operator:
            return eval_logical_operator(*static_cast<LogicalOperatorNode *>(where), table, row);
        default:
            throw Exception("Wrong node type");
    }
}

// Result-column definition for one select-list entry, named after the entry.
std::tuple<int, ColDefNode> USql::get_column_definition(Table *table, SelectColNode *select_col_node, int col_order) {
    return get_node_definition(table, select_col_node->value.get(), select_col_node->name, col_order);
}

// Column definition of a database column reference; throws Exception for other nodes.
ColDefNode USql::get_db_column_definition(Table *table, Node *node) {
    if (node->node_type == NodeType::database_value) {
        auto db_node = static_cast<DatabaseValueNode *>(node);
        return table->get_column_def(db_node->col_name);
    }

    throw Exception("Undefined table node - get_db_column_definition");
}

// Derives the result-column definition for an expression node.
// Returns {source column index, definition}: plain column references keep their
// source index and attributes; computed values (functions, arithmetic, logical
// operators, literals) get index -1 and a nullable definition.
// Throws Exception for unsupported functions/node types or min/max without an argument.
std::tuple<int, ColDefNode> USql::get_node_definition(Table *table, Node *node, const std::string &col_name, int col_order) {
    // definition of a computed (non-table) column
    auto computed_column = [&col_name, col_order](ColumnType type, int length) {
        return std::make_tuple(-1, ColDefNode{col_name, type, col_order, length, true});
    };

    if (node->node_type == NodeType::database_value) {
        ColDefNode src_col_def = get_db_column_definition(table, node);
        ColDefNode col_def = ColDefNode{col_name, src_col_def.type, col_order, src_col_def.length, src_col_def.null};
        return std::make_tuple(src_col_def.order, col_def);
    }

    if (node->node_type == NodeType::function) {
        auto func_node = static_cast<FunctionNode *>(node);

        if (func_node->function == "to_string") {
            return computed_column(ColumnType::varchar_type, 32);
        }
        if (func_node->function == "to_date") {
            return computed_column(ColumnType::integer_type, 1);
        }
        if (func_node->function == "pp") {
            return computed_column(ColumnType::varchar_type, 10);
        }
        if (func_node->function == "lower" || func_node->function == "upper") {
            // TODO get length, use get_db_column_definition
            return computed_column(ColumnType::varchar_type, 256);
        }
        if (func_node->function == "min" || func_node->function == "max") {
            // params[0] is dereferenced below
            if (func_node->params.empty())
                throw Exception("min/max function requires a parameter");

            auto col_type = ColumnType::float_type;
            int col_len = 1;
            const auto &param = func_node->params[0];
            if (param->node_type == NodeType::database_value) {
                ColDefNode src_col_def = get_db_column_definition(table, param.get());
                col_type = src_col_def.type;
                col_len = src_col_def.length;
            }
            return computed_column(col_type, col_len);
        }
        if (func_node->function == "count") {
            return computed_column(ColumnType::integer_type, 1);
        }
        throw Exception("Unsupported function");
    }

    if (node->node_type == NodeType::arithmetical_operator) {
        auto ari_node = static_cast<ArithmeticalOperatorNode *>(node);
        const ColDefNode left_tbl_col_def = std::get<1>(get_node_definition(table, ari_node->left.get(), col_name, col_order));
        const ColDefNode right_tbl_col_def = std::get<1>(get_node_definition(table, ari_node->right.get(), col_name, col_order));

        // TODO handle varchar and it len
        const ColumnType col_type =
            (left_tbl_col_def.type == ColumnType::float_type || right_tbl_col_def.type == ColumnType::float_type)
                ? ColumnType::float_type
                : ColumnType::integer_type;
        return computed_column(col_type, 1);
    }

    if (node->node_type == NodeType::logical_operator) {
        return computed_column(ColumnType::bool_type, 1);
    }
    if (node->node_type == NodeType::int_value) {
        return computed_column(ColumnType::integer_type, 1);
    }
    if (node->node_type == NodeType::float_value) {
        return computed_column(ColumnType::float_type, 1);
    }
    if (node->node_type == NodeType::string_value) {
        // TODO right len
        return computed_column(ColumnType::varchar_type, 64);
    }

    throw Exception("Unsupported node type");
}

// INSERT: evaluates each value into a new row of the target table and commits it.
// Throws Exception when the column and value counts differ.
std::unique_ptr<Table> USql::execute_insert_into_table(const InsertIntoTableNode &node) {
    // find table
    Table *table_def = find_table(node.table_name);

    if (node.cols_names.size() != node.cols_values.size())
        throw Exception("Incorrect number of values");

    // prepare empty new_row
    Row &new_row = table_def->create_empty_row();

    // copy values
    for (size_t i = 0; i < node.cols_names.size(); i++) {
        ColDefNode col_def = table_def->get_column_def(node.cols_names[i].col_name);
        auto col_value = eval_value_node(table_def, new_row, node.cols_values[i].get(), nullptr, nullptr);

        new_row.setColumnValue(&col_def, col_value.get());
    }

    // append new_row
    table_def->commit_row(new_row);

    return create_stmt_result_table(0, "insert succeeded", 1);
}

// DELETE: marks every row matching WHERE as deleted and removes it from the
// table's indexes; reports the number of affected rows.
std::unique_ptr<Table> USql::execute_delete(const DeleteFromTableNode &node) {
    size_t affected_rows = 0;

    // find source table
    Table *table = find_table(node.table_name);

    // execute access plan
    Table::rows_scanner i = get_iterator(table, node.where.get());
    while (Row *row = i.next()) {
        if (eval_where(node.where.get(), table, *row)) {
            row->set_deleted();
            table->unindex_row(*row);

            affected_rows++;
        }
    }

    return create_stmt_result_table(0, "delete succeeded", affected_rows);
}

// UPDATE: for every row matching WHERE, evaluates, validates and assigns the SET
// columns one after another (a later SET expression sees the values assigned
// earlier in the same row), then re-indexes the row; reports affected rows.
std::unique_ptr<Table> USql::execute_update(const UpdateTableNode &node) {
    size_t affected_rows = 0;

    // find source table
    Table *table = find_table(node.table_name);

    // execute access plan
    Table::rows_scanner i = get_iterator(table, node.where.get());
    while (Row *row = i.next()) {
        if (eval_where(node.where.get(), table, *row)) {
            // snapshot of the old values for reindex_row
            Row old_row = *row;

            int col_idx = 0;
            for (const auto &col : node.cols_names) {
                // TODO cache it like in select
                ColDefNode col_def = table->get_column_def(col.col_name);
                std::unique_ptr<ValueNode> new_val = eval_arithmetic_operator(
                    col_def.type, static_cast<ArithmeticalOperatorNode &>(*node.values[col_idx]), table, *row);

                // NOTE: if validation throws for a later column, the earlier columns of
                // this row are already assigned and reindex_row below is not reached
                usql::Table::validate_column(&col_def, new_val.get());
                row->setColumnValue(&col_def, new_val.get());
                col_idx++;
            }
            table->reindex_row(old_row, *row);

            affected_rows++;
            // TODO: problem - when the update fails on one row, the other rows should not
            // be updated either (rows updated before the failure stay updated)
        }
    }

    return create_stmt_result_table(0, "update succeeded", affected_rows);
}

// SELECT pipeline: expand `*`, build the result table definition, scan the source
// (index or full scan) projecting rows that match WHERE, then apply DISTINCT,
// ORDER BY and OFFSET/LIMIT in that order.
std::unique_ptr<Table> USql::execute_select(SelectFromTableNode &node) const {
    // find source table
    Table *table = find_table(node.table_name);

    // expand *
    expand_asterix_char(node, table);

    // create result table
    const size_t cols_cnt = node.cols_names->size();
    std::vector<ColDefNode> result_tbl_col_defs{};
    std::vector<int> source_table_col_index{};
    result_tbl_col_defs.reserve(cols_cnt);
    source_table_col_index.reserve(cols_cnt);

    for (size_t i = 0; i < cols_cnt; i++) {
        SelectColNode *col_node = &node.cols_names->operator[](i);
        auto [src_tbl_col_index, rst_tbl_col_def] = get_column_definition(table, col_node, static_cast<int>(i));

        source_table_col_index.push_back(src_tbl_col_index);
        result_tbl_col_defs.push_back(rst_tbl_col_def);
    }

    // check for aggregate function
    bool is_aggregated = check_for_aggregate_only_functions(node, result_tbl_col_defs.size());

    // prepare result table structure
    auto result = std::make_unique<Table>("result", result_tbl_col_defs);

    // replace possible order by col names to col indexes and validate
    setup_order_columns(node.order_by, result.get());

    // execute access plan
    Table::rows_scanner i = get_iterator(table, node.where.get());
    while (Row *row = i.next()) {
        if (eval_where(node.where.get(), table, *row)) { // put it into row_scanner.next
            select_row(node, table, row, result.get(), result_tbl_col_defs, source_table_col_index, is_aggregated);
        }
    }

    execute_distinct(node, result.get());

    execute_order_by(node, result.get());

    execute_offset_limit(node.offset_limit, result.get());

    return result;
}

// Chooses the access path: an index-driven scanner over the rowids returned by
// probe_index_scan, or a full table scan.
Table::rows_scanner USql::get_iterator(Table *table, const Node *where) const {
    auto [use_index, rowids] = probe_index_scan(where, table);

    if (use_index)
        return Table::rows_scanner(table, rowids);
    else
        return Table::rows_scanner(table);
}

} // namespace usql