344 lines
12 KiB
C++
344 lines
12 KiB
C++
#include "usql.h"
|
|
#include "exception.h"
|
|
#include "ml_date.h"
|
|
#include "ml_string.h"
|
|
|
|
#include <algorithm>
|
|
#include <fstream>
|
|
|
|
namespace usql {
|
|
|
|
|
|
/// Executes a SELECT statement against a single source table.
///
/// Pipeline: resolve the source table, expand a lone `*` select list, build
/// the result column definitions, scan the source rows filtered by WHERE,
/// then apply DISTINCT, ORDER BY and OFFSET/LIMIT to the result table.
///
/// When every selected column is an aggregate (see
/// check_for_aggregate_only_functions), a new result row is created only while
/// the result table is still empty; the last created row is committed once
/// after the scan. Otherwise each matching source row yields one committed row.
///
/// @param node parsed SELECT; its column list (asterisk expansion) and ORDER BY
///             entries (index resolution) are modified in place.
/// @return newly created table named "result".
std::unique_ptr<Table> USql::execute_select(SelectFromTableNode &node) {
    Table *source = find_table(node.table_name);

    // replace a single "*" with the full list of source columns
    expand_asterix_char(node, source);

    // For each output column remember where its value comes from: a source
    // column index, or FUNCTION_CALL for computed expressions.
    std::vector<ColDefNode> out_col_defs{};
    std::vector<int> src_col_indexes{};
    for (int col = 0; col < static_cast<int>(node.cols_names->size()); ++col) {
        SelectColNode &select_col = (*node.cols_names)[col];
        auto [src_index, out_def] = get_column_definition(source, &select_col, col);
        src_col_indexes.push_back(src_index);
        out_col_defs.push_back(out_def);
    }

    // throws when aggregates are mixed with plain columns
    const bool aggregate_only = check_for_aggregate_only_functions(node, out_col_defs.size());

    auto result = std::make_unique<Table>("result", out_col_defs);

    // turn ORDER BY names / 1-based positions into result column indexes
    setup_order_columns(node.order_by, result.get());

    Row *target = nullptr;
    for (auto &src_row : source->m_rows) {
        if (!eval_where(node.where.get(), source, src_row))
            continue;

        // aggregate queries keep folding into one row; others get a fresh row
        if (!aggregate_only || result->rows_count() == 0)
            target = &result->create_empty_row();

        for (auto col = 0; col < result->columns_count(); ++col) {
            const auto src_index = src_col_indexes[col];
            if (src_index == FUNCTION_CALL) {
                // the target's current value is handed over so aggregates can accumulate
                auto value = eval_value_node(source, src_row, (*node.cols_names)[col].value.get(),
                                             &out_col_defs[col], &(*target)[col]);
                target->setColumnValue(&out_col_defs[col], value.get());
            } else {
                ColValue &src_value = src_row[src_index];
                target->setColumnValue(&out_col_defs[col], src_value);
            }
        }

        if (!aggregate_only)
            result->commit_row(*target);
    }

    // aggregate query: commit the single accumulated row (if any row matched)
    if (aggregate_only && target != nullptr)
        result->commit_row(*target);

    execute_distinct(node, result.get());
    execute_order_by(node, source, result.get());
    execute_offset_limit(node.offset_limit, result.get());

    return result;
}
|
|
|
|
/// Decides whether the select list is made up solely of aggregate calls
/// (count, min, max).
///
/// @param node            parsed SELECT statement
/// @param result_cols_cnt number of result columns
/// @return true when at least one column is an aggregate (and then all are)
/// @throws Exception when some, but not all, result columns are aggregates
bool USql::check_for_aggregate_only_functions(SelectFromTableNode &node, int result_cols_cnt) const {
    const auto is_aggregate = [](const SelectColNode &col) {
        if (col.value->node_type != NodeType::function)
            return false;
        auto *fn = static_cast<FunctionNode *>(col.value.get());
        return fn->function == "count" || fn->function == "min" || fn->function == "max";
    };

    const auto aggregates_cnt = std::count_if(node.cols_names->begin(), node.cols_names->end(), is_aggregate);

    // all-or-nothing: either no aggregates, or every result column is one
    if (aggregates_cnt > 0 && aggregates_cnt != result_cols_cnt)
        throw Exception("aggregate functions with no aggregates");

    return aggregates_cnt > 0;
}
|
|
|
|
/// Replaces a select list consisting of exactly one "*" with one column
/// reference per column of `table`, in table order. Any other select list is
/// left untouched.
void USql::expand_asterix_char(SelectFromTableNode &node, Table *table) const {
    auto &cols = *node.cols_names;
    const bool only_asterisk = cols.size() == 1 && cols[0].name == "*";
    if (!only_asterisk)
        return;

    cols.clear();
    cols.reserve(table->columns_count());
    for (const auto &def : table->m_col_defs)
        cols.emplace_back(SelectColNode{std::make_unique<DatabaseValueNode>(def.name), def.name});
}
|
|
|
|
void USql::setup_order_columns(std::vector<ColOrderNode> &node, Table *table) const {
|
|
for (auto& order_node : node) {
|
|
if (!order_node.col_name.empty()) {
|
|
ColDefNode col_def = table->get_column_def(order_node.col_name);
|
|
order_node.col_index = col_def.order;
|
|
} else {
|
|
order_node.col_index = order_node.col_index - 1; // user counts from 1
|
|
}
|
|
|
|
if (order_node.col_index < 0 || order_node.col_index >= table->columns_count())
|
|
throw Exception("unknown column in order by clause (" + order_node.col_name + ")");
|
|
}
|
|
}
|
|
|
|
/// Removes duplicate rows from `result` when the statement is SELECT DISTINCT.
///
/// Rows are first sorted with Row::compare so that duplicates become adjacent,
/// then collapsed with std::unique (which uses Row's equality operator). The
/// surviving rows keep that sort order; a following ORDER BY re-sorts them.
void USql::execute_distinct(SelectFromTableNode &node, Table *result) {
    if (!node.distinct) return;

    // std::sort requires a strict weak ordering: the comparator must return
    // false for equal elements. Using `>= 0` here would return true for equal
    // rows, which is undefined behaviour and can make std::sort run past the
    // range. `> 0` keeps the same (descending, assuming strcmp-like sign
    // semantics of Row::compare) direction while being a valid ordering.
    auto compare_rows = [](const Row &a, const Row &b) { return a.compare(b) > 0; };
    std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows);

    result->m_rows.erase(std::unique(result->m_rows.begin(), result->m_rows.end()), result->m_rows.end());
}
|
|
|
|
/// Sorts the rows of `result` by the statement's ORDER BY list.
///
/// Expects each ORDER BY entry's col_index to already be resolved against
/// `result` (see setup_order_columns). Earlier entries take precedence; a
/// later entry only decides between rows equal on all earlier ones.
///
/// @param node   parsed SELECT holding the ORDER BY list
/// @param table  source table (not used by this function)
/// @param result table whose rows are reordered in place
void USql::execute_order_by(SelectFromTableNode &node, Table *table, Table *result) {
    if (node.order_by.empty()) return;

    // Resolve every ORDER BY entry to a row position once, up front, instead
    // of copying a ColDefNode for every key of every comparison std::sort makes.
    struct SortKey {
        int position;
        bool ascending;
    };
    std::vector<SortKey> keys;
    keys.reserve(node.order_by.size());
    for (const auto &order_col : node.order_by) {
        SortKey key{};
        key.position = result->get_column_def(order_col.col_index).order;
        key.ascending = order_col.ascending;
        keys.push_back(key);
    }

    auto compare_rows = [&keys](const Row &a, const Row &b) {
        for (const auto &key : keys) {
            ColValue &a_val = a[key.position];
            ColValue &b_val = b[key.position];

            int compare = a_val.compare(b_val);

            if (compare < 0) return key.ascending;
            if (compare > 0) return !key.ascending;
        }
        return false; // equal on every key
    };

    std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows);
}
|
|
|
|
/// Applies OFFSET and then LIMIT to `result` in place.
///
/// A positive offset drops that many leading rows (every row when the offset
/// is not smaller than the row count). A positive limit then keeps at most that
/// many rows. Non-positive values are ignored.
void USql::execute_offset_limit(OffsetLimitNode &node, Table *result) {
    auto &rows = result->m_rows;

    if (node.offset > 0) {
        const bool offset_in_range = result->rows_count() > node.offset;
        auto first_kept = offset_in_range ? rows.begin() + node.offset : rows.end();
        rows.erase(rows.begin(), first_kept);
    }

    if (node.limit > 0 && node.limit < result->rows_count())
        rows.erase(rows.begin() + node.limit, rows.end());
}
|
|
|
|
/// Describes one select-list entry by delegating to get_node_definition with
/// the entry's expression and output column name.
///
/// @return (source column index, result column definition)
std::tuple<int, ColDefNode> USql::get_column_definition(Table *table, SelectColNode *select_col_node, int col_order ) {
    auto *expr = select_col_node->value.get();
    const auto &out_name = select_col_node->name;
    return get_node_definition(table, expr, out_name, col_order);
}
|
|
|
|
/// Returns the definition of the source-table column referenced by `node`.
///
/// @throws Exception unless `node` is a DatabaseValueNode (column reference)
ColDefNode USql::get_db_column_definition(Table *table, Node *node) {
    if (node->node_type != NodeType::database_value)
        throw Exception("Undefined table node - get_db_column_definition");

    auto *column_ref = static_cast<DatabaseValueNode *>(node);
    return table->get_column_def(column_ref->col_name);
}
|
|
|
|
/// Derives the result column definition for an expression in a select list.
///
/// @param table     source table used to resolve column references
/// @param node      expression to describe
/// @param col_name  name given to the result column
/// @param col_order position of the result column
/// @return (source column index, result column definition). The index is the
///         referenced source column's `order` for a plain column reference and
///         -1 for every computed expression.
/// @throws Exception for unsupported functions or node types, and when
///         min/max is called without an argument.
std::tuple<int, ColDefNode> USql::get_node_definition(Table *table, Node * node, const std::string & col_name, int col_order ) {
    if (node->node_type == NodeType::database_value) {
        // plain column reference: copy type, length and nullability from the source column
        ColDefNode src_col_def = get_db_column_definition(table, node);
        ColDefNode col_def = ColDefNode{col_name, src_col_def.type, col_order, src_col_def.length, src_col_def.null};
        return std::make_tuple(src_col_def.order, col_def);

    } else if (node->node_type == NodeType::function) {
        auto func_node = static_cast<FunctionNode *>(node);

        if (func_node->function == "to_string") {
            ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 32, true};
            return std::make_tuple(-1, col_def);
        } else if (func_node->function == "to_date") {
            ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true};
            return std::make_tuple(-1, col_def);
        } else if (func_node->function == "pp") {
            ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 10, true};
            return std::make_tuple(-1, col_def);
        } else if (func_node->function == "lower" || func_node->function == "upper") {
            // TODO get length, use get_db_column_definition
            ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 256, true};
            return std::make_tuple(-1, col_def);
        } else if (func_node->function == "min" || func_node->function == "max") {
            // params[0] is read below: reject a call without arguments instead
            // of indexing past the end of the parameter list
            if (func_node->params.empty())
                throw Exception("Missing argument for function min/max");

            // a column argument supplies type and length; anything else is typed as float
            auto col_type = ColumnType::float_type;
            int col_len = 1;
            auto & v = func_node->params[0];
            if (v->node_type == NodeType::database_value) {
                ColDefNode src_col_def = get_db_column_definition(table, v.get());
                col_type = src_col_def.type;
                col_len = src_col_def.length;
            }
            ColDefNode col_def = ColDefNode{col_name, col_type, col_order, col_len, true};
            return std::make_tuple(-1, col_def);
        } else if (func_node->function == "count") {
            ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true};
            return std::make_tuple(-1, col_def);
        }

        throw Exception("Unsupported function");

    } else if (node->node_type == NodeType::arithmetical_operator) {
        auto ari_node = static_cast<ArithmeticalOperatorNode *>(node);

        auto [left_col_index, left_tbl_col_def] = get_node_definition(table, ari_node->left.get(), col_name, col_order );
        auto [right_col_index, right_tbl_col_def] = get_node_definition(table, ari_node->right.get(), col_name, col_order );

        // float if either operand is float, otherwise integer
        ColumnType col_type; // TODO handle varchar and it len
        if (left_tbl_col_def.type==ColumnType::float_type || right_tbl_col_def.type==ColumnType::float_type)
            col_type = ColumnType::float_type;
        else
            col_type = ColumnType::integer_type;

        ColDefNode col_def = ColDefNode{col_name, col_type, col_order, 1, true};
        return std::make_tuple(-1, col_def);

    } else if (node->node_type == NodeType::logical_operator) {
        ColDefNode col_def = ColDefNode{col_name, ColumnType::bool_type, col_order, 1, true};
        return std::make_tuple(-1, col_def);

    } else if (node->node_type == NodeType::int_value) {
        ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true};
        return std::make_tuple(-1, col_def);

    } else if (node->node_type == NodeType::float_value) {
        ColDefNode col_def = ColDefNode{col_name, ColumnType::float_type, col_order, 1, true};
        return std::make_tuple(-1, col_def);

    } else if (node->node_type == NodeType::string_value) {
        // TODO right len
        ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 64, true};
        return std::make_tuple(-1, col_def);
    }

    throw Exception("Unsupported node type");
}
|
|
|
|
|
|
|
|
|
|
/// Inserts one row built from the statement's column and value lists.
///
/// @throws Exception("Incorrect number of values") when the two lists differ
///         in length; errors from column lookup, evaluation or commit propagate.
/// @return status table reporting one inserted row
std::unique_ptr<Table> USql::execute_insert_into_table(InsertIntoTableNode &node) {
    Table *target = find_table(node.table_name);

    const auto value_cnt = node.cols_values.size();
    if (node.cols_names.size() != value_cnt)
        throw Exception("Incorrect number of values");

    Row &row = target->create_empty_row();

    for (size_t pos = 0; pos < node.cols_names.size(); ++pos) {
        ColDefNode def = target->get_column_def(node.cols_names[pos].col_name);
        // nullptr for both the result-column-definition and current-value arguments
        auto value = eval_value_node(target, row, node.cols_values[pos].get(), nullptr, nullptr);
        row.setColumnValue(&def, value.get());
    }

    target->commit_row(row);

    return create_stmt_result_table(0, "insert succeeded", 1);
}
|
|
|
|
|
|
|
|
/// Deletes every row of the target table that satisfies the WHERE clause.
///
/// @return status table carrying the number of removed rows
std::unique_ptr<Table> USql::execute_delete(DeleteFromTableNode &node) {
    Table *table = find_table(node.table_name);
    auto &rows = table->m_rows;
    auto *where = node.where.get();

    const auto rows_before = table->rows_count();

    // erase-remove: compact the surviving rows, then drop the tail
    auto matches = [table, where](Row &row) { return eval_where(where, table, row); };
    rows.erase(std::remove_if(rows.begin(), rows.end(), matches), rows.end());

    const auto removed = rows_before - table->rows_count();
    return create_stmt_result_table(0, "delete succeeded", removed);
}
|
|
|
|
|
|
/// Applies the SET assignments of an UPDATE to every row matching WHERE.
///
/// Target column definitions are resolved once, before the scan. Within a
/// row, assignments run left to right, so a later expression sees columns
/// already updated by earlier ones. Every SET value is treated as an
/// ArithmeticalOperatorNode (NOTE(review): assumed to be guaranteed by the
/// parser — confirm).
///
/// TODO: the update is not atomic. When a value fails validation on some row,
///       rows (and columns) processed before it stay modified.
///
/// @throws Exception("Incorrect number of values") when the column and value
///         lists differ in length; lookup/validation errors propagate.
/// @return status table carrying the number of updated rows
std::unique_ptr<Table> USql::execute_update(UpdateTableNode &node) {
    Table *table = find_table(node.table_name);

    // same guard as INSERT: values[i] is indexed for every column below
    if (node.cols_names.size() != node.values.size())
        throw Exception("Incorrect number of values");

    // resolve target columns once instead of once per matching row
    std::vector<ColDefNode> col_defs;
    col_defs.reserve(node.cols_names.size());
    for (const auto &col : node.cols_names)
        col_defs.push_back(table->get_column_def(col.col_name));

    int affected_rows = 0;
    for (auto &row : table->m_rows) {
        if (!eval_where(node.where.get(), table, row))
            continue;

        for (size_t i = 0; i < col_defs.size(); ++i) {
            ColDefNode &col_def = col_defs[i];
            std::unique_ptr<ValueNode> new_val = eval_arithmetic_operator(col_def.type,
                                                                          static_cast<ArithmeticalOperatorNode &>(*node.values[i]),
                                                                          table, row);

            usql::Table::validate_column(&col_def, new_val.get());
            row.setColumnValue(&col_def, new_val.get());
        }
        affected_rows++;
    }

    return create_stmt_result_table(0, "update succeeded", affected_rows);
}
|
|
|
|
|
|
/// Evaluates a WHERE condition tree for one row of `table`.
///
/// @param where root of the condition: a true_node (presumably an absent
///              WHERE clause), a single relational operator, or a logical
///              operator combining sub-conditions
/// @return whether `row` satisfies the condition
/// @throws Exception for any other node type
bool USql::eval_where(Node *where, Table *table, Row &row) {
    switch (where->node_type) {
        case NodeType::true_node:
            return true;
        case NodeType::relational_operator: // just one condition
            return eval_relational_operator(*static_cast<RelationalOperatorNode *>(where), table, row);
        case NodeType::logical_operator:
            return eval_logical_operator(*static_cast<LogicalOperatorNode *>(where), table, row);
        default:
            throw Exception("Wrong node type");
    }

    return false; // unreachable: every case returns or throws
}
|
|
|
|
|
|
} // namespace
|