diff --git a/CMakeLists.txt b/CMakeLists.txt index 0ce8327..ebf7f0e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,7 @@ set(PROJECT_NAME usql) include_directories(${CMAKE_SOURCE_DIR}/clib ${CMAKE_SOURCE_DIR}) set(SOURCE - exception.cpp lexer.cpp parser.cpp usql.cpp main.cpp table.cpp table.h row.cpp row.h csvreader.cpp csvreader.h ml_date.cpp settings.cpp clib/ml_string.cpp clib/linenoise.c) + exception.cpp lexer.cpp parser.cpp usql.cpp usql_ddl.cpp usql_dml.cpp main.cpp table.cpp row.cpp row.h csvreader.cpp ml_date.cpp settings.cpp clib/ml_string.cpp clib/linenoise.c) add_executable(${PROJECT_NAME} ${SOURCE}) diff --git a/lexer.cpp b/lexer.cpp index 4015bcf..57a0c55 100644 --- a/lexer.cpp +++ b/lexer.cpp @@ -120,111 +120,59 @@ namespace usql { TokenType Lexer::type(const std::string &token) { // FIXME 'one is evaluated as identifier - if (token == ";") - return TokenType::semicolon; - if (token == "+") - return TokenType::plus; - if (token == "-") - return TokenType::minus; - if (token == "*") - return TokenType::multiply; - if (token == "/") - return TokenType::divide; - if (token == "(") - return TokenType::open_paren; - if (token == ")") - return TokenType::close_paren; - if (token == "=") - return TokenType::equal; - if (token == "!=" || token == "<>") - return TokenType::not_equal; - if (token == ">") - return TokenType::greater; - if (token == ">=") - return TokenType::greater_equal; - if (token == "<") - return TokenType::lesser; - if (token == "<=") - return TokenType::lesser_equal; - if (token == "is") - return TokenType::is; - if (token == "as") - return TokenType::keyword_as; - if (token == "create") - return TokenType::keyword_create; - if (token == "drop") - return TokenType::keyword_drop; - if (token == "where") - return TokenType::keyword_where; - if (token == "order") - return TokenType::keyword_order; - if (token == "by") - return TokenType::keyword_by; - if (token == "offset") - return TokenType::keyword_offset; - if (token == "limit") - return TokenType::keyword_limit; - if (token == "asc") - return TokenType::keyword_asc; - if (token == "desc") - return TokenType::keyword_desc; - if (token == "from") - return TokenType::keyword_from; - if (token == "delete") - return TokenType::keyword_delete; - if (token == "table") - return TokenType::keyword_table; - if (token == "insert") - return TokenType::keyword_insert; - if (token == "into") - return TokenType::keyword_into; - if (token == "values") - return TokenType::keyword_values; - if (token == "select") - return TokenType::keyword_select; - if (token == "set") - return TokenType::keyword_set; - if (token == "copy") - return TokenType::keyword_copy; - if (token == "update") - return TokenType::keyword_update; - if (token == "load") - return TokenType::keyword_load; - if (token == "save") - return TokenType::keyword_save; - if (token == "not") - return TokenType::keyword_not; - if (token == "null") - return TokenType::keyword_null; - if (token == "integer") - return TokenType::keyword_integer; - if (token == "float") - return TokenType::keyword_float; - if (token == "varchar") - return TokenType::keyword_varchar; - if (token == "date") - return TokenType::keyword_date; - if (token == "boolean") - return TokenType::keyword_bool; - if (token == "true") - return TokenType::keyword_true; - if (token == "false") - return TokenType::keyword_false; - if (token == "distinct") - return TokenType::keyword_distinct; - if (token == "show") - return TokenType::keyword_show; - if (token == "or") - return TokenType::logical_or; - if (token == "and") - return TokenType::logical_and; - if (token == ",") - return TokenType::comma; - if (token == "\n" || token == "\r\n" || token == "\r") - return TokenType::newline; + if (token == ";") return TokenType::semicolon; + if (token == "+") return TokenType::plus; + if (token == "-") return TokenType::minus; + if (token == "*") return TokenType::multiply; + if (token == "/") return TokenType::divide; + if (token == "(") return TokenType::open_paren; + if (token == ")") return TokenType::close_paren; + if (token == "=") return TokenType::equal; + if (token == "!=" || token == "<>") return TokenType::not_equal; + if (token == ">") return TokenType::greater; + if (token == ">=") return TokenType::greater_equal; + if (token == "<") return TokenType::lesser; + if (token == "<=") return TokenType::lesser_equal; + if (token == "is") return TokenType::is; + if (token == "as") return TokenType::keyword_as; + if (token == "create") return TokenType::keyword_create; + if (token == "drop") return TokenType::keyword_drop; + if (token == "where") return TokenType::keyword_where; + if (token == "order") return TokenType::keyword_order; + if (token == "by") return TokenType::keyword_by; + if (token == "offset") return TokenType::keyword_offset; + if (token == "limit") return TokenType::keyword_limit; + if (token == "asc") return TokenType::keyword_asc; + if (token == "desc") return TokenType::keyword_desc; + if (token == "from") return TokenType::keyword_from; + if (token == "delete") return TokenType::keyword_delete; + if (token == "table") return TokenType::keyword_table; + if (token == "insert") return TokenType::keyword_insert; + if (token == "into") return TokenType::keyword_into; + if (token == "values") return TokenType::keyword_values; + if (token == "select") return TokenType::keyword_select; + if (token == "set") return TokenType::keyword_set; + if (token == "copy") return TokenType::keyword_copy; + if (token == "update") return TokenType::keyword_update; + if (token == "load") return TokenType::keyword_load; + if (token == "save") return TokenType::keyword_save; + if (token == "not") return TokenType::keyword_not; + if (token == "null") return TokenType::keyword_null; + if (token == "integer") return TokenType::keyword_integer; + if (token == "float") return TokenType::keyword_float; + if (token == "varchar") return TokenType::keyword_varchar; + if (token == "date") return TokenType::keyword_date; + if (token == "boolean") return TokenType::keyword_bool; + if (token == "true") return TokenType::keyword_true; + if (token == "false") return TokenType::keyword_false; + if (token == "distinct") return TokenType::keyword_distinct; + if (token == "show") return TokenType::keyword_show; + if (token == "or") return TokenType::logical_or; + if (token == "and") return TokenType::logical_and; + if (token == ",") return TokenType::comma; + if (token == "\n" || token == "\r\n" || token == "\r") return TokenType::newline; - if (token.length() > 1 && token.at(0) == '%' && - (token.at(token.length() - 1) == '\n' || token.at(token.length() - 1) == '\r')) + if (token.length() > 1 && token.at(0) == '%' && (token.at(token.length() - 1) == '\n' || token.at(token.length() - 1) == '\r')) return TokenType::comment; if (token.length() >= 2 && token.at(0) == '"' && token.at(token.length() - 1) == '"') @@ -233,17 +181,10 @@ namespace usql { if (token.length() >= 2 && token.at(0) == '\'' && token.at(token.length() - 1) == '\'') return TokenType::string_literal; - if (std::regex_match(token, k_int_regex)) - return TokenType::int_number; - - if (std::regex_match(token, k_int_underscored_regex)) - return TokenType::int_number; - - if (std::regex_match(token, k_double_regex)) - return TokenType::double_number; - - if (std::regex_match(token, k_identifier_regex)) - return TokenType::identifier; + if (std::regex_match(token, k_int_regex)) return TokenType::int_number; + if (std::regex_match(token, k_int_underscored_regex)) return TokenType::int_number; + if (std::regex_match(token, k_double_regex)) return TokenType::double_number; + if (std::regex_match(token, k_identifier_regex)) return TokenType::identifier; return TokenType::undef; } @@ -286,178 +227,65 @@ namespace usql { } std::string Lexer::typeToString(TokenType token_type) { - std::string txt; switch (token_type) { - case TokenType::undef: - txt = "undef"; - break; - case TokenType::identifier: - txt = "identifier"; - break; - case TokenType::plus: - txt = "+"; - break; - case TokenType::minus: - txt = "-"; - break; - case TokenType::multiply: - txt = "*"; - break; - case TokenType::divide: - txt = "/"; - break; - case TokenType::equal: - txt = "=="; - break; - case TokenType::not_equal: - txt = "!="; - break; - case TokenType::greater: - txt = ">"; - break; - case TokenType::greater_equal: - txt = ">="; - break; - case TokenType::lesser: - txt = "<"; - break; - case TokenType::lesser_equal: - txt = "<="; - break; - case TokenType::is: - txt = "is"; - break; - case TokenType::keyword_as: - txt = "as"; - break; - case TokenType::keyword_create: - txt = "create"; - break; - case TokenType::keyword_drop: - txt = "drop"; - break; - case TokenType::keyword_where: - txt = "where"; - break; - case TokenType::keyword_order: - txt = "order"; - break; - case TokenType::keyword_by: - txt = "by"; - break; - case TokenType::keyword_offset: - txt = "offset"; - break; - case TokenType::keyword_limit: - txt = "limit"; - break; - case TokenType::keyword_asc: - txt = "asc"; - break; - case TokenType::keyword_desc: - txt = "desc"; - break; - case TokenType::keyword_table: - txt = "table"; - break; - case TokenType::keyword_into: - txt = "into"; - break; - case TokenType::keyword_values: - txt = "values"; - break; - case TokenType::keyword_select: - txt = "select"; - break; - case TokenType::keyword_set: - txt = "set"; - break; - case TokenType::keyword_copy: - txt = "copy"; - break; - case TokenType::keyword_update: - txt = "update"; - break; - case TokenType::keyword_load: - txt = "load"; - break; - case TokenType::keyword_save: - txt = "save"; - break; - case TokenType::keyword_not: - txt = "not"; - break; - case TokenType::keyword_null: - txt = "null"; - break; - case TokenType::keyword_integer: - txt = "integer"; - break; - case TokenType::keyword_float: - txt = "float"; - break; - case TokenType::keyword_varchar: - txt = "varchar"; - break; - case TokenType::keyword_date: - txt = "date"; - break; - case TokenType::keyword_bool: - txt = "boolean"; - break; - case TokenType::keyword_true: - txt = "true"; - break; - case TokenType::keyword_false: - txt = "false"; - break; - case TokenType::keyword_distinct: - txt = "distinct"; - break; - case TokenType::keyword_show: - txt = "show"; - break; - case TokenType::int_number: - txt = "int number"; - break; - case TokenType::double_number: - txt = "double number"; - break; - case TokenType::string_literal: - txt = "string literal"; - break; - case TokenType::open_paren: - txt = "("; - break; - case TokenType::close_paren: - txt = ")"; - break; - case TokenType::logical_and: - txt = "and"; - break; - case TokenType::logical_or: - txt = "or"; - break; - case TokenType::semicolon: - txt = ";"; - break; - case TokenType::comma: - txt = ","; - break; - case TokenType::newline: - txt = "newline"; - break; - case TokenType::comment: - txt = "comment"; - break; - case TokenType::eof: - txt = "eof"; - break; + case TokenType::undef: return "undef"; + case TokenType::identifier: return "identifier"; + case TokenType::plus: return "+"; + case TokenType::minus: return "-"; + case TokenType::multiply: return "*"; + case TokenType::divide: return "/"; + case TokenType::equal: return "=="; + case TokenType::not_equal: return "!="; + case TokenType::greater: return ">"; + case TokenType::greater_equal: return ">="; + case TokenType::lesser: return "<"; + case TokenType::lesser_equal: return "<="; + case TokenType::is: return "is"; + case TokenType::keyword_as: return "as"; + case TokenType::keyword_create: return "create"; + case TokenType::keyword_drop: return "drop"; + case TokenType::keyword_where: return "where"; + case TokenType::keyword_order: return "order"; + case TokenType::keyword_by: return "by"; + case TokenType::keyword_offset: return "offset"; + case TokenType::keyword_limit: return "limit"; + case TokenType::keyword_asc: return "asc"; + case TokenType::keyword_desc: return "desc"; + case TokenType::keyword_table: return "table"; + case TokenType::keyword_into: return "into"; + case TokenType::keyword_values: return "values"; + case TokenType::keyword_select: return "select"; + case TokenType::keyword_set: return "set"; + case TokenType::keyword_copy: return "copy"; + case TokenType::keyword_update: return "update"; + case TokenType::keyword_load: return "load"; + case TokenType::keyword_save: return "save"; + case TokenType::keyword_not: return "not"; + case TokenType::keyword_null: return "null"; + case TokenType::keyword_integer: return "integer"; + case TokenType::keyword_float: return "float"; + case TokenType::keyword_varchar: return "varchar"; + case TokenType::keyword_date: return "date"; + case TokenType::keyword_bool: return "boolean"; + case TokenType::keyword_true: return "true"; + case TokenType::keyword_false: return "false"; + case TokenType::keyword_distinct: return "distinct"; + case TokenType::keyword_show: return "show"; + case TokenType::int_number: return "int number"; + case TokenType::double_number: return "double number"; + case TokenType::string_literal: return "string literal"; + case TokenType::open_paren: return "("; + case TokenType::close_paren: return ")"; + case TokenType::logical_and: return "and"; + case TokenType::logical_or: return "or"; + case TokenType::semicolon: return ";"; + case TokenType::comma: return ","; + case TokenType::newline: return "newline"; + case TokenType::comment: return "comment"; + case TokenType::eof: return "eof"; default: - txt = "FIXME, unknown token type"; - break; + return "FIXME, unknown token type"; } - return txt; } } \ No newline at end of file diff --git a/main.cpp b/main.cpp index efc63d7..49c4257 100644 --- a/main.cpp +++ b/main.cpp @@ -148,24 +148,27 @@ void debug() { "insert into a (i, s, b) values(1, upper('zero'), 'Y')", "insert into a (i, s, b, f) values(1 + 10000, upper('one'), 'N', 3.1415)", "insert into a (i, s, f) values(2 + 10000, upper('two'), 3.1415)", + "select min(i), max(i), count(*) from a where b is not null", "select * from a where b is null", "select * from a where b is not null", -// "select * from a where b='N'", -// "update a set i = i * 100, f = f + 0.01 where i > 1", -// "select to_string(i, '%d.%m.%Y %H:%M:%S'), i, s from a where i < to_date('20.12.2019', '%d.%m.%Y')", -// "select i + 2 as first, i, s, b, f from a where i >=1 order by 1 desc offset 0 limit 1", -// "update table a set s = 'null string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'", -// "update table a set i = null", -// "insert into a (i, s) values(2, 'two')", -// "insert into a (i, s) values(3, 'two')", -// "insert into a (i, s) values(4, lower('FOUR'))", -// "insert into a (i, s) values(5, 'five')", -// "insert into a (i, s) values(to_date('20.12.1973', '%d.%m.%Y'), 'six')", + "select * from a where b='N'", + "update a set i = i * 100, f = f + 0.01 where i > 1", + "select to_string(i, '%d.%m.%Y %H:%M:%S'), i, s from a where i < to_date('20.12.2019', '%d.%m.%Y')", + "select i + 2 as first, i, s, b, f from a where i >=1 order by 1 desc offset 0 limit 1", + + + "update table a set s = 'null string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'", + "update table a set i = null", + "insert into a (i, s) values(2, 'two')", + "insert into a (i, s) values(3, 'two')", + "insert into a (i, s) values(4, lower('FOUR'))", + "insert into a (i, s) values(5, 'five')", + "insert into a (i, s) values(to_date('20.12.1973', '%d.%m.%Y'), 'six')", // tohle zpusobi kresh "insert into a (i, d) values(6', '2006-10-04')", -// "insert into a (i, d) values(6, '2006-10-04')", -// "save table a into '/tmp/a.csv'", -// "select i, s from a where i > 2 order by 1 desc offset 1 limit 1", -// "select distinct s, d from a", + "insert into a (i, d) values(6, '2006-10-04')", + "save table a into '/tmp/a.csv'", + "select i, s from a where i > 2 order by 1 desc offset 1 limit 1", + "select distinct s, d from a", // "select i, s from a where i = 1", // "select i, s from a where s = 'two'", // "select i, s from a where i <= 3 and s = 'one'", diff --git a/parser.cpp b/parser.cpp index 8bc7fdc..e7f5c36 100644 --- a/parser.cpp +++ b/parser.cpp @@ -191,8 +191,10 @@ namespace usql { // column values m_lexer.skipToken(TokenType::open_paren); do { - auto col_value = parse_expression(); - column_values.push_back(std::move(col_value)); + // TODO here it is problem when exception from parse_expression<-parse_value is thrown + // it makes double free + auto value = parse_expression(); + column_values.emplace_back(std::move(value)); m_lexer.skipTokenOptional(TokenType::comma); } while (m_lexer.tokenType() != TokenType::close_paren); @@ -408,7 +410,6 @@ namespace usql { // parenthesised expression if (token_type == TokenType::open_paren) { m_lexer.skipToken(TokenType::open_paren); - auto left = parse_expression(); do { left = parse_expression(std::move(left)); @@ -454,6 +455,10 @@ namespace usql { if (token_type == TokenType::keyword_true || token_type == TokenType::keyword_false) return std::make_unique(token_type == TokenType::keyword_true); + // token * for count(*) + if (token_type == TokenType::multiply) + return std::make_unique(tokenString); + throw Exception("Unknown operand node " + tokenString); } diff --git a/usql.cpp b/usql.cpp index ca05bf1..518c622 100644 --- a/usql.cpp +++ b/usql.cpp @@ -50,357 +50,15 @@ std::unique_ptr USql::execute(Node &node) { } -std::unique_ptr
USql::execute_create_table(CreateTableNode &node) { - check_table_not_exists(node.table_name); - - Table table{node.table_name, node.cols_defs}; - m_tables.push_back(table); - - return create_stmt_result_table(0, "table created", 0); -} - - -std::unique_ptr
USql::execute_create_table_as_table(CreateTableAsSelectNode &node) { - check_table_not_exists(node.table_name); - - auto select = execute_select((SelectFromTableNode &) *node.select_table); - - // create table - Table new_table{node.table_name, select->m_col_defs}; - m_tables.push_back(new_table); - - // copy rows - // must be here, if rows are put into new_table, they are lost during m_tables.push_table - Table *table = find_table(node.table_name); - for( Row& orig_row : select->m_rows) { - table->commit_copy_of_row(orig_row); - } - - select.release(); // is it correct? hoping not to release select table here and then when releasing CreateTableAsSelectNode - - return create_stmt_result_table(0, "table created", table->m_rows.size()); -} - - - -std::unique_ptr
USql::execute_load(LoadIntoTableNode &node) { - // find source table - Table *table_def = find_table(node.table_name); - - // read data - // std::ifstream ifs(node.filename); - // std::string content((std::istreambuf_iterator(ifs)), - // (std::istreambuf_iterator())); - // load rows - // auto rows_cnt = table_def->load_csv_string(content); - - - auto rows_cnt = table_def->load_csv_file(node.filename); - - return create_stmt_result_table(0, "load succeeded", rows_cnt); -} - - -std::unique_ptr
USql::execute_save(SaveTableNode &node) { - // find source table - Table *table_def = find_table(node.table_name); - - // make csv string - std::string csv_string = table_def->csv_string(); - - // save data - std::ofstream file(node.filename); - file << csv_string; - file.close(); - - return create_stmt_result_table(0, "save succeeded", table_def->rows_count()); -} - -std::unique_ptr
USql::execute_drop(DropTableNode &node) { - auto name_cmp = [node](const Table& t) { return t.m_name == node.table_name; }; - - auto table_def = std::find_if(begin(m_tables), end(m_tables), name_cmp); - if (table_def != std::end(m_tables)) { - m_tables.erase(table_def); - return create_stmt_result_table(0, "drop succeeded", 0); - } - - throw Exception("table not found (" + node.table_name + ")"); -} - -std::unique_ptr
USql::execute_set(SetNode &node) { - Settings::set_setting(node.name, node.value); - return create_stmt_result_table(0, "set succeeded", 1); -} - -std::unique_ptr
USql::execute_show(ShowNode &node) { - std::string value = Settings::get_setting(node.name); - return create_stmt_result_table(0, "show succeeded: " + value, 1); -} - -std::unique_ptr
USql::execute_insert_into_table(InsertIntoTableNode &node) { - // find table - Table *table_def = find_table(node.table_name); - - if (node.cols_names.size() != node.cols_values.size()) - throw Exception("Incorrect number of values"); - - // prepare empty new_row - Row& new_row = table_def->create_empty_row(); - - // copy values - for (size_t i = 0; i < node.cols_names.size(); i++) { - ColDefNode col_def = table_def->get_column_def(node.cols_names[i].col_name); - auto col_value = eval_value_node(table_def, new_row, node.cols_values[i].get()); - - new_row.setColumnValue(&col_def, col_value.get()); - } - - // append new_row - table_def->commit_row(new_row); - - return create_stmt_result_table(0, "insert succeeded", 1); -} - - -std::unique_ptr
USql::execute_select(SelectFromTableNode &node) { - // find source table - Table *table = find_table(node.table_name); - - // expand * - if (node.cols_names->size()==1 && node.cols_names->operator[](0).name == "*") { - node.cols_names->clear(); - node.cols_names->reserve(table->columns_count()); - for(const auto& col : table->m_col_defs) { - node.cols_names->emplace_back(SelectColNode{std::make_unique(col.name), col.name}); - } - } - - - // create result table - std::vector result_tbl_col_defs{}; - std::vector source_table_col_index{}; - - for (int i = 0; i < node.cols_names->size(); i++) { - auto [src_tbl_col_index, rst_tbl_col_def] = get_column_definition(table, &node.cols_names->operator[](i), i); - - source_table_col_index.push_back(src_tbl_col_index); - result_tbl_col_defs.push_back(rst_tbl_col_def); - } - auto result = std::make_unique
("result", result_tbl_col_defs); - - - // execute access plan - for (auto row = begin(table->m_rows); row != end(table->m_rows); ++row) { - // eval where for row - if (eval_where(node.where.get(), table, *row)) { - // prepare empty row and copy column values - Row& new_row = result->create_empty_row(); - - for (auto idx = 0; idx < result->columns_count(); idx++) { - auto row_col_index = source_table_col_index[idx]; - - if (row_col_index == FUNCTION_CALL) { - auto evaluated_value = eval_value_node(table, *row, node.cols_names->operator[](idx).value.get()); - ValueNode *col_value = evaluated_value.get(); - - new_row.setColumnValue(&result_tbl_col_defs[idx], col_value); - } else { - ColValue &col_value = row->operator[](row_col_index); - new_row.setColumnValue(&result_tbl_col_defs[idx], col_value); - } - } - - // add row to result - result->commit_row(new_row); - } - } - - execute_distinct(node, result.get()); - - execute_order_by(node, table, result.get()); - - execute_offset_limit(node.offset_limit, result.get()); - - return result; -} - -void USql::execute_distinct(SelectFromTableNode &node, Table *result) { - if (!node.distinct) return; - - auto compare_rows = [](const Row &a, const Row &b) { return a.compare(b) >= 0; }; - std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows); - - result->m_rows.erase(std::unique(result->m_rows.begin(), result->m_rows.end()), result->m_rows.end()); -} - -void USql::execute_order_by(SelectFromTableNode &node, Table *table, Table *result) { - if (node.order_by.empty()) return; - - auto compare_rows = [&node, &result](const Row &a, const Row &b) { - for(const auto& order_by_col_def : node.order_by) { - // TODO validate index - ColDefNode col_def = result->get_column_def(order_by_col_def.col_index - 1); - ColValue &a_val = a[col_def.order]; - ColValue &b_val = b[col_def.order]; - - int compare = a_val.compare(b_val); - - if (compare < 0) return order_by_col_def.ascending; - if (compare > 0) return !order_by_col_def.ascending; - } - return false; - }; - - std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows); -} - -void USql::execute_offset_limit(OffsetLimitNode &node, Table *result) { - if (node.offset > 0) - result->m_rows.erase(result->m_rows.begin(), - result->rows_count() > node.offset ? result->m_rows.begin() + node.offset : result->m_rows.end()); - - if (node.limit > 0 && node.limit < result->rows_count()) - result->m_rows.erase(result->m_rows.begin() + node.limit, result->m_rows.end()); -} - -std::tuple USql::get_column_definition(Table *table, SelectColNode *select_col_node, int col_order ) { - return get_node_definition(table, select_col_node->value.get(), select_col_node->name, col_order ); -} - -std::tuple USql::get_node_definition(Table *table, Node * node, const std::string & col_name, int col_order ) { - if (node->node_type == NodeType::database_value) { - auto dbval_node = static_cast(node); - - ColDefNode src_col_def = table->get_column_def(dbval_node->col_name); - ColDefNode col_def = ColDefNode{col_name, src_col_def.type, col_order, src_col_def.length, src_col_def.null}; - return std::make_tuple(src_col_def.order, col_def); - - } else if (node->node_type == NodeType::function) { - auto func_node = static_cast(node); - - if (func_node->function == "to_string") { - ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 32, true}; - return std::make_tuple(-1, col_def); - } else if (func_node->function == "to_date") { - ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; - return std::make_tuple(-1, col_def); - } else if (func_node->function == "pp") { - ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 10, true}; - return std::make_tuple(-1, col_def); - } - throw Exception("Unsupported function"); - - } else if (node->node_type == NodeType::arithmetical_operator) { - auto ari_node = static_cast(node); - - auto [left_col_index, left_tbl_col_def] = get_node_definition(table, ari_node->left.get(), col_name, col_order ); - auto [right_col_index, right_tbl_col_def] = get_node_definition(table, ari_node->right.get(), col_name, col_order ); - - ColumnType col_type; // TODO handle varchar and it len - if (left_tbl_col_def.type==ColumnType::float_type || right_tbl_col_def.type==ColumnType::float_type) - col_type = ColumnType::float_type; - else - col_type = ColumnType::integer_type; - - ColDefNode col_def = ColDefNode{col_name, col_type, col_order, 1, true}; - return std::make_tuple(-1, col_def); - - } else if (node->node_type == NodeType::logical_operator) { - ColDefNode col_def = ColDefNode{col_name, ColumnType::bool_type, col_order, 1, true}; - return std::make_tuple(-1, col_def); - - } else if (node->node_type == NodeType::int_value) { - ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; - return std::make_tuple(-1, col_def); - - } else if (node->node_type == NodeType::float_value) { - ColDefNode col_def = ColDefNode{col_name, ColumnType::float_type, col_order, 1, true}; - return std::make_tuple(-1, col_def); - - } else if (node->node_type == NodeType::string_value) { - // TODO right len - ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 64, true}; - return std::make_tuple(-1, col_def); - } - throw Exception("Unsupported node type"); -} - - - -std::unique_ptr
USql::execute_delete(DeleteFromTableNode &node) { - // find source table - Table *table = find_table(node.table_name); - - // execute access plan - auto affected_rows = table->rows_count(); - - table->m_rows.erase( - std::remove_if(table->m_rows.begin(), table->m_rows.end(), - [&node, table](Row &row){return eval_where(node.where.get(), table, row);}), - table->m_rows.end()); - - affected_rows -= table->rows_count(); - - return create_stmt_result_table(0, "delete succeeded", affected_rows); -} - - -std::unique_ptr
USql::execute_update(UpdateTableNode &node) { - // find source table - Table *table = find_table(node.table_name); - - // execute access plan - int affected_rows = 0; - for (auto row = begin(table->m_rows); row != end(table->m_rows); ++row) { - // eval where for row - if (eval_where(node.where.get(), table, *row)) { - int i = 0; - for (const auto& col : node.cols_names) { - // TODO cache it like in select - ColDefNode col_def = table->get_column_def(col.col_name); - std::unique_ptr new_val = eval_arithmetic_operator(col_def.type, - static_cast(*node.values[i]), - table, *row); - - usql::Table::validate_column(&col_def, new_val.get()); - row->setColumnValue(&col_def, new_val.get()); - i++; - } - affected_rows++; - // TODO tady je problem, ze kdyz to zfajluje na jednom radku ostatni by se nemely provest - } - } - - return create_stmt_result_table(0, "update succeeded", affected_rows); -} - - -bool USql::eval_where(Node *where, Table *table, Row &row) { - switch (where->node_type) { - case NodeType::true_node: - return true; - case NodeType::relational_operator: // just one condition - return eval_relational_operator(*((RelationalOperatorNode *) where), table, row); - case NodeType::logical_operator: - return eval_logical_operator(*((LogicalOperatorNode *) where), table, row); - default: - throw Exception("Wrong node type"); - } - - return false; -} - - bool USql::eval_relational_operator(const RelationalOperatorNode &filter, Table *table, Row &row) { - std::unique_ptr left_value = eval_value_node(table, row, filter.left.get()); - std::unique_ptr right_value = eval_value_node(table, row, filter.right.get()); + std::unique_ptr left_value = eval_value_node(table, row, filter.left.get(), nullptr, nullptr); + std::unique_ptr right_value = eval_value_node(table, row, filter.right.get(), nullptr, nullptr); double comparator; if (left_value->node_type == NodeType::null_value || right_value->node_type == NodeType::null_value) { - bool all_null = left_value->isNull() && right_value->node_type == NodeType::null_value || - right_value->isNull() && left_value->node_type == NodeType::null_value; + bool all_null = (left_value->isNull() && right_value->node_type == NodeType::null_value) || + (right_value->isNull() && left_value->node_type == NodeType::null_value); if (filter.op == RelationalOperatorType::is) return all_null; if (filter.op == RelationalOperatorType::is_not) @@ -436,19 +94,23 @@ bool USql::eval_relational_operator(const RelationalOperatorNode &filter, Table return comparator < 0.0; case RelationalOperatorType::lesser_equal: return comparator <= 0.0; + case RelationalOperatorType::is: + case RelationalOperatorType::is_not: + // already handled + throw Exception("error in is or is not handling"); } throw Exception("invalid relational operator"); } -std::unique_ptr USql::eval_value_node(Table *table, Row &row, Node *node) { +std::unique_ptr USql::eval_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value) { if (node->node_type == NodeType::database_value) { return eval_database_value_node(table, row, node); } else if (node->node_type == NodeType::int_value || node->node_type == NodeType::float_value || node->node_type == NodeType::string_value || node->node_type == NodeType::bool_value) { return eval_literal_value_node(table, row, node); } else if (node->node_type == NodeType::function) { - return eval_function_value_node(table, row, node); + return eval_function_value_node(table, row, node, col_def_node, agg_func_value); } else if (node->node_type == NodeType::null_value) { return std::make_unique(); } else if (node->node_type == NodeType::arithmetical_operator) { @@ -504,12 +166,13 @@ std::unique_ptr USql::eval_literal_value_node(Table *table, Row &row, } -std::unique_ptr USql::eval_function_value_node(Table *table, Row &row, Node *node) { +std::unique_ptr +USql::eval_function_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value) { auto *fnc = static_cast(node); std::vector> evaluatedPars; for(auto & param : fnc->params) { - evaluatedPars.push_back(eval_value_node(table, row, param.get())); + evaluatedPars.push_back(eval_value_node(table, row, param.get(), nullptr, nullptr)); } // at this moment no functions without parameter(s) or first param can be null @@ -517,69 +180,24 @@ std::unique_ptr USql::eval_function_value_node(Table *table, Row &row return std::make_unique(); // TODO use some enum - if (fnc->function == "lower") { - std::string str = evaluatedPars[0]->getStringValue(); - std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return std::tolower(c); }); - return std::make_unique(str); + if (fnc->function == "lower") return lower_function(evaluatedPars); + if (fnc->function == "upper") return upper_function(evaluatedPars); + if (fnc->function == "to_date") return to_date_function(evaluatedPars); + if (fnc->function == "to_string") return to_string_function(evaluatedPars); + if (fnc->function == "pp") return pp_function(evaluatedPars); + if (fnc->function == "count") return count_function(agg_func_value, evaluatedPars); + if (fnc->function == "max") return max_function(evaluatedPars, col_def_node, agg_func_value); + if (fnc->function == "min") return min_function(evaluatedPars, col_def_node, agg_func_value); + + throw Exception("invalid function: " + fnc->function); +} + +std::unique_ptr USql::count_function(ColValue *agg_func_value, const std::vector> &evaluatedPars) { + long c = 1; + if (!agg_func_value->isNull()) { + c = agg_func_value->getIntValue() + 1; } - if (fnc->function == "upper") { - std::string str = evaluatedPars[0]->getStringValue(); - std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return std::toupper(c); }); - return std::make_unique(str); - } - - if (fnc->function == "to_date") { - std::string date = evaluatedPars[0]->getStringValue(); - std::string format = evaluatedPars[1]->getStringValue(); - long epoch_time = string_to_date(date, format); - return std::make_unique(epoch_time); // No DateValueNode for now - } - if (fnc->function == "to_string") { - long date = evaluatedPars[0]->getDateValue(); - std::string format = evaluatedPars[1]->getStringValue(); - std::string formatted_date = date_to_string(date, format); - return std::make_unique(formatted_date); - } - if (fnc->function == "pp") { - auto &parsed_value = evaluatedPars[0]; - - if (parsed_value->node_type == NodeType::int_value || parsed_value->node_type == NodeType::float_value) { - std::string format = evaluatedPars.size() > 1 ? evaluatedPars[1]->getStringValue() : ""; - char buf[16] {0}; - double value = parsed_value->getDoubleValue(); - - if (format == "100%") - std::snprintf(buf, 20, "%.2f%%", value); - else if (value >= 1000000000000) - std::snprintf(buf, 20, "%7.2fT", value/1000000000000); - else if (value >= 1000000000) - std::sprintf(buf, "%7.2fB", value/1000000000); - else if (value >= 1000000) - std::snprintf(buf, 20, "%7.2fM", value/1000000); - else if (value >= 100000) - std::snprintf(buf, 20, "%7.2fM", value/100000); // 0.12M - else if (value <= -1000000000000) - std::snprintf(buf, 20, "%7.2fT", value/1000000000000); - else if (value <= -1000000000) - std::snprintf(buf, 20, "%7.2fB", value/1000000000); - else if (value <= -1000000) - std::snprintf(buf, 20, "%7.2fM", value/1000000); - else if (value <= -100000) - std::snprintf(buf, 20, "%7.2fM", value/100000); // 0.12M - else if (value == 0) - buf[0]='0'; - else - return std::make_unique(parsed_value->getStringValue().substr(0, 10)); - // TODO introduce constant for 10 - std::string s {buf}; - return std::make_unique(string_padd(s.erase(s.find_last_not_of(" ")+1), 10, ' ', false)); - } - - return std::make_unique(parsed_value->getStringValue()); - } - - - throw Exception("invalid function"); + return std::make_unique(c); } @@ -598,11 +216,11 @@ bool USql::eval_logical_operator(LogicalOperatorNode &node, Table *pTable, Row & std::unique_ptr USql::eval_arithmetic_operator(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, Row &row) { if (node.op == ArithmeticalOperatorType::copy_value) { - return eval_value_node(table, row, node.left.get()); + return eval_value_node(table, row, node.left.get(), nullptr, nullptr); } - std::unique_ptr left = eval_value_node(table, row, node.left.get()); - std::unique_ptr right = eval_value_node(table, row, node.right.get()); + std::unique_ptr left = eval_value_node(table, row, node.left.get(), nullptr, nullptr); + std::unique_ptr right = eval_value_node(table, row, node.right.get(), nullptr, nullptr); if (left->isNull() || right->isNull()) return std::make_unique(); @@ -655,24 +273,130 @@ std::unique_ptr USql::eval_arithmetic_operator(ColumnType outType, Ar } -std::unique_ptr
USql::create_stmt_result_table(long code, const std::string &text, size_t affected_rows) { - std::vector result_tbl_col_defs{}; - result_tbl_col_defs.emplace_back("code", ColumnType::integer_type, 0, 1, false); - result_tbl_col_defs.emplace_back("desc", ColumnType::varchar_type, 1, 48, false); - result_tbl_col_defs.emplace_back("affected_rows", ColumnType::integer_type, 0, 1, true); - - auto table_def = std::make_unique
("result", result_tbl_col_defs); - - Row& new_row = table_def->create_empty_row(); - new_row.setIntColumnValue(0, code); - new_row.setStringColumnValue(1, text.size() <= 48 ? text : text.substr(0,48)); - new_row.setIntColumnValue(2, (long)affected_rows); - table_def->commit_row(new_row); - - return table_def; +std::unique_ptr USql::to_string_function(const std::vector> &evaluatedPars) { + long date = evaluatedPars[0]->getDateValue(); + std::string format = evaluatedPars[1]->getStringValue(); + std::string formatted_date = date_to_string(date, format); + return std::make_unique(formatted_date); } +std::unique_ptr USql::to_date_function(const std::vector> &evaluatedPars) { + std::string date = evaluatedPars[0]->getStringValue(); + std::string format = evaluatedPars[1]->getStringValue(); + long epoch_time = string_to_date(date, format); + return std::make_unique(epoch_time); // No DateValueNode for now +} +std::unique_ptr USql::upper_function(const std::vector> &evaluatedPars) { + std::string str = evaluatedPars[0]->getStringValue(); + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return toupper(c); }); + return std::make_unique(str); +} + +std::unique_ptr USql::lower_function(const std::vector> &evaluatedPars) { + std::string str = evaluatedPars[0]->getStringValue(); + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return tolower(c); }); + return std::make_unique(str); +} + +std::unique_ptr USql::pp_function(const std::vector> &evaluatedPars) { + auto &parsed_value = evaluatedPars[0]; + + if (parsed_value->node_type == NodeType::int_value || parsed_value->node_type == NodeType::float_value) { + std::string format = evaluatedPars.size() > 1 ? evaluatedPars[1]->getStringValue() : ""; + char buf[16] {0}; + double value = parsed_value->getDoubleValue(); + + if (format == "100%") + std::snprintf(buf, 20, "%.2f%%", value); + else if (value >= 1000000000000) + std::snprintf(buf, 20, "%7.2fT", value/1000000000000); + else if (value >= 1000000000) + std::sprintf(buf, "%7.2fB", value/1000000000); + else if (value >= 1000000) + std::snprintf(buf, 20, "%7.2fM", value/1000000); + else if (value >= 100000) + std::snprintf(buf, 20, "%7.2fM", value/100000); // 0.12M + else if (value <= -1000000000000) + std::snprintf(buf, 20, "%7.2fT", value/1000000000000); + else if (value <= -1000000000) + std::snprintf(buf, 20, "%7.2fB", value/1000000000); + else if (value <= -1000000) + std::snprintf(buf, 20, "%7.2fM", value/1000000); + else if (value <= -100000) + std::snprintf(buf, 20, "%7.2fM", value/100000); // 0.12M + else if (value == 0) + buf[0]='0'; + else + return std::make_unique(parsed_value->getStringValue().substr(0, 10)); + // TODO introduce constant for 10 + std::string s {buf}; + return std::make_unique(string_padd(s.erase(s.find_last_not_of(" ")+1), 10, ' ', false)); + } + return std::make_unique(parsed_value->getStringValue()); +} + +std::unique_ptr +USql::max_function(const std::vector> &evaluatedPars, const ColDefNode *col_def_node, + ColValue *agg_func_value) { + if (col_def_node->type == ColumnType::integer_type || col_def_node->type == ColumnType::date_type) { + if (!evaluatedPars[0]->isNull()) { + long val = evaluatedPars[0]->getIntegerValue(); + if (agg_func_value->isNull()) { + return std::make_unique(val); + } else { + return std::make_unique(std::max(val, agg_func_value->getIntValue())); + } + } else { + return std::make_unique(agg_func_value->getIntValue()); + } + } else if (col_def_node->type == ColumnType::float_type) { + if (!evaluatedPars[0]->isNull()) { + double val = evaluatedPars[0]->getDoubleValue(); + if (agg_func_value->isNull()) { + return std::make_unique(val); + } else { + return std::make_unique(std::max(val, agg_func_value->getDoubleValue())); + } + } else { + return std::make_unique(agg_func_value->getDoubleValue()); + } + } + + // TODO string and boolean + throw Exception("unsupported data type for max function"); +} + +std::unique_ptr +USql::min_function(const std::vector> &evaluatedPars, const ColDefNode *col_def_node, + ColValue *agg_func_value) { + if (col_def_node->type == ColumnType::integer_type || col_def_node->type == ColumnType::date_type) { + if (!evaluatedPars[0]->isNull()) { + long val = evaluatedPars[0]->getIntegerValue(); + if (agg_func_value->isNull()) { + return std::make_unique(val); + } else { + return std::make_unique(std::min(val, agg_func_value->getIntValue())); + } + } else { + return std::make_unique(agg_func_value->getIntValue()); + } + } else if (col_def_node->type == ColumnType::float_type) { + if (!evaluatedPars[0]->isNull()) { + double val = evaluatedPars[0]->getDoubleValue(); + if (agg_func_value->isNull()) { + return std::make_unique(val); + } else { + return std::make_unique(std::min(val, agg_func_value->getDoubleValue())); + } + } else { + return std::make_unique(agg_func_value->getDoubleValue()); + } + } + + // TODO string and boolean + throw Exception("unsupported data type for min function"); +} Table *USql::find_table(const std::string &name) { auto name_cmp = [name](const Table& t) { return t.m_name == name; }; diff --git a/usql.h b/usql.h index cfda794..554bb78 100644 --- a/usql.h +++ b/usql.h @@ -35,10 +35,10 @@ private: private: static bool eval_where(Node *where, Table *table, Row &row) ; - static std::unique_ptr eval_value_node(Table *table, Row &row, Node *node); + static std::unique_ptr eval_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value); static std::unique_ptr eval_database_value_node(Table *table, Row &row, Node *node); static std::unique_ptr eval_literal_value_node(Table *table, Row &row, Node *node); - static std::unique_ptr eval_function_value_node(Table *table, Row &row, Node *node); + static std::unique_ptr eval_function_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value); static bool eval_relational_operator(const RelationalOperatorNode &filter, Table *table, Row &row) ; @@ -60,6 +60,22 @@ private: static void execute_distinct(SelectFromTableNode &node, Table *result) ; static void execute_order_by(SelectFromTableNode &node, Table *table, Table *result) ; static void execute_offset_limit(OffsetLimitNode &node, Table *result) ; + + void expand_asterix_char(SelectFromTableNode &node, Table *table) const; + + bool check_for_aggregate_only_functions(SelectFromTableNode &node, int result_cols_cnt) const; + + static std::unique_ptr lower_function(const std::vector> &evaluatedPars); + static std::unique_ptr upper_function(const std::vector> &evaluatedPars); + static std::unique_ptr to_date_function(const std::vector> &evaluatedPars); + static std::unique_ptr to_string_function(const std::vector> &evaluatedPars); + static std::unique_ptr pp_function(const std::vector> &evaluatedPars); + + static std::unique_ptr max_function(const std::vector> &evaluatedPars, const ColDefNode *col_def_node, ColValue *agg_func_value); + static std::unique_ptr min_function(const std::vector> &evaluatedPars, const ColDefNode *col_def_node, ColValue *agg_func_value); + + static std::unique_ptr + count_function(ColValue *agg_func_value, const std::vector> &evaluatedPars); }; } // namespace \ No newline at end of file diff --git a/usql_ddl.cpp b/usql_ddl.cpp new file mode 100644 index 0000000..38a84d9 --- /dev/null +++ b/usql_ddl.cpp @@ -0,0 +1,120 @@ +#include "usql.h" +#include "exception.h" +#include "ml_date.h" +#include "ml_string.h" + +#include +#include + +namespace usql { + + + +std::unique_ptr
USql::execute_create_table(CreateTableNode &node) { + check_table_not_exists(node.table_name); + + Table table{node.table_name, node.cols_defs}; + m_tables.push_back(table); + + return create_stmt_result_table(0, "table created", 0); +} + + +std::unique_ptr
USql::execute_create_table_as_table(CreateTableAsSelectNode &node) { + check_table_not_exists(node.table_name); + + auto select = execute_select((SelectFromTableNode &) *node.select_table); + + // create table + Table new_table{node.table_name, select->m_col_defs}; + m_tables.push_back(new_table); + + // copy rows + // must be here, if rows are put into new_table, they are lost during m_tables.push_table + Table *table = find_table(node.table_name); + for( Row& orig_row : select->m_rows) { + table->commit_copy_of_row(orig_row); + } + + select.release(); // is it correct? hoping not to release select table here and then when releasing CreateTableAsSelectNode + + return create_stmt_result_table(0, "table created", table->m_rows.size()); +} + + + +std::unique_ptr
USql::execute_drop(DropTableNode &node) { + auto name_cmp = [node](const Table& t) { return t.m_name == node.table_name; }; + + auto table_def = std::find_if(begin(m_tables), end(m_tables), name_cmp); + if (table_def != std::end(m_tables)) { + m_tables.erase(table_def); + return create_stmt_result_table(0, "drop succeeded", 0); + } + + throw Exception("table not found (" + node.table_name + ")"); +} + +std::unique_ptr
USql::execute_set(SetNode &node) { + Settings::set_setting(node.name, node.value); + return create_stmt_result_table(0, "set succeeded", 1); +} + +std::unique_ptr
USql::execute_show(ShowNode &node) { + std::string value = Settings::get_setting(node.name); + return create_stmt_result_table(0, "show succeeded: " + value, 1); +} + + +std::unique_ptr
USql::create_stmt_result_table(long code, const std::string &text, size_t affected_rows) { + std::vector result_tbl_col_defs{}; + result_tbl_col_defs.emplace_back("code", ColumnType::integer_type, 0, 1, false); + result_tbl_col_defs.emplace_back("desc", ColumnType::varchar_type, 1, 48, false); + result_tbl_col_defs.emplace_back("affected_rows", ColumnType::integer_type, 0, 1, true); + + auto table_def = std::make_unique
("result", result_tbl_col_defs); + + Row& new_row = table_def->create_empty_row(); + new_row.setIntColumnValue(0, code); + new_row.setStringColumnValue(1, text.size() <= 48 ? text : text.substr(0,48)); + new_row.setIntColumnValue(2, (long)affected_rows); + table_def->commit_row(new_row); + + return table_def; +} + + + +std::unique_ptr
USql::execute_load(LoadIntoTableNode &node) { + // find source table + Table *table_def = find_table(node.table_name); + + // read data + // std::ifstream ifs(node.filename); + // std::string content((std::istreambuf_iterator(ifs)), (std::istreambuf_iterator())); + // load rows + // auto rows_cnt = table_def->load_csv_string(content); + + auto rows_cnt = table_def->load_csv_file(node.filename); + + return create_stmt_result_table(0, "load succeeded", rows_cnt); +} + + +std::unique_ptr
USql::execute_save(SaveTableNode &node) { + // find source table + Table *table_def = find_table(node.table_name); + + // make csv string + std::string csv_string = table_def->csv_string(); + + // save data + std::ofstream file(node.filename); + file << csv_string; + file.close(); + + return create_stmt_result_table(0, "save succeeded", table_def->rows_count()); +} + + +} // namespace diff --git a/usql_dml.cpp b/usql_dml.cpp new file mode 100644 index 0000000..acf2163 --- /dev/null +++ b/usql_dml.cpp @@ -0,0 +1,313 @@ +#include "usql.h" +#include "exception.h" +#include "ml_date.h" +#include "ml_string.h" + +#include +#include + +namespace usql { + + +std::unique_ptr
USql::execute_select(SelectFromTableNode &node) { + // find source table + Table *table = find_table(node.table_name); + + // expand * + expand_asterix_char(node, table); + + // create result table + std::vector result_tbl_col_defs{}; + std::vector source_table_col_index{}; + for (int i = 0; i < node.cols_names->size(); i++) { + SelectColNode * col_node = &node.cols_names->operator[](i); + auto [src_tbl_col_index, rst_tbl_col_def] = get_column_definition(table, col_node, i); + + source_table_col_index.push_back(src_tbl_col_index); + result_tbl_col_defs.push_back(rst_tbl_col_def); + } + + // check for aggregate function + bool aggregate_funcs = check_for_aggregate_only_functions(node, result_tbl_col_defs.size()); + + auto result = std::make_unique
("result", result_tbl_col_defs); + + + // execute access plan + Row* new_row = nullptr; + for (auto row = begin(table->m_rows); row != end(table->m_rows); ++row) { + // eval where for row + if (eval_where(node.where.get(), table, *row)) { + // prepare empty row and copy column values + // when agregate functions in result only one row for table + if (!aggregate_funcs || result->rows_count()==0) { + new_row = &result->create_empty_row(); + } + + for (auto idx = 0; idx < result->columns_count(); idx++) { + auto src_table_col_idx = source_table_col_index[idx]; + + if (src_table_col_idx == FUNCTION_CALL) { + auto evaluated_value = eval_value_node(table, *row, node.cols_names->operator[](idx).value.get(), &result_tbl_col_defs[idx], &new_row->operator[](idx)); + ValueNode *col_value = evaluated_value.get(); + + new_row->setColumnValue(&result_tbl_col_defs[idx], col_value); + } else { + ColValue &col_value = row->operator[](src_table_col_idx); + new_row->setColumnValue(&result_tbl_col_defs[idx], col_value); + } + } + + // add row to result + if (aggregate_funcs == 0) { + result->commit_row(*new_row); + } + } + } + // when aggregates commit this one row + if (aggregate_funcs && new_row != nullptr) { + result->commit_row(*new_row); + } + + execute_distinct(node, result.get()); + + execute_order_by(node, table, result.get()); + + execute_offset_limit(node.offset_limit, result.get()); + + return result; +} + +bool USql::check_for_aggregate_only_functions(SelectFromTableNode &node, int result_cols_cnt) const { + int aggregate_funcs = 0; + for (int i = 0; i < node.cols_names->size(); i++) { + SelectColNode * col_node = &node.cols_names->operator[](i); + if (col_node->value->node_type == NodeType::function) { + auto func_node = static_cast(col_node->value.get()); + if (func_node->function == "count" || func_node->function == "min" || func_node->function == "max") + aggregate_funcs++; + } + } + // check whether aggregates are not present or all columns are aggregates + if (aggregate_funcs > 0 && aggregate_funcs != result_cols_cnt) { + throw Exception("aggregate functions with no aggregates"); + } + + return aggregate_funcs > 0; +} + +void USql::expand_asterix_char(SelectFromTableNode &node, Table *table) const { + if (node.cols_names->size() == 1 && node.cols_names->operator[](0).name == "*") { + node.cols_names->clear(); + node.cols_names->reserve(table->columns_count()); + for(const auto& col : table->m_col_defs) { + node.cols_names->emplace_back(SelectColNode{std::__1::make_unique(col.name), col.name}); + } + } +} + +void USql::execute_distinct(SelectFromTableNode &node, Table *result) { + if (!node.distinct) return; + + auto compare_rows = [](const Row &a, const Row &b) { return a.compare(b) >= 0; }; + std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows); + + result->m_rows.erase(std::unique(result->m_rows.begin(), result->m_rows.end()), result->m_rows.end()); +} + +void USql::execute_order_by(SelectFromTableNode &node, Table *table, Table *result) { + if (node.order_by.empty()) return; + + auto compare_rows = [&node, &result](const Row &a, const Row &b) { + for(const auto& order_by_col_def : node.order_by) { + // TODO validate index + ColDefNode col_def = result->get_column_def(order_by_col_def.col_index - 1); + ColValue &a_val = a[col_def.order]; + ColValue &b_val = b[col_def.order]; + + int compare = a_val.compare(b_val); + + if (compare < 0) return order_by_col_def.ascending; + if (compare > 0) return !order_by_col_def.ascending; + } + return false; + }; + + std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows); +} + +void USql::execute_offset_limit(OffsetLimitNode &node, Table *result) { + if (node.offset > 0) + result->m_rows.erase(result->m_rows.begin(), result->rows_count() > node.offset ? result->m_rows.begin() + node.offset : result->m_rows.end()); + + if (node.limit > 0 && node.limit < result->rows_count()) + result->m_rows.erase(result->m_rows.begin() + node.limit, result->m_rows.end()); +} + +std::tuple USql::get_column_definition(Table *table, SelectColNode *select_col_node, int col_order ) { + return get_node_definition(table, select_col_node->value.get(), select_col_node->name, col_order ); +} + +std::tuple USql::get_node_definition(Table *table, Node * node, const std::string & col_name, int col_order ) { + if (node->node_type == NodeType::database_value) { + auto dbval_node = static_cast(node); + + ColDefNode src_col_def = table->get_column_def(dbval_node->col_name); + ColDefNode col_def = ColDefNode{col_name, src_col_def.type, col_order, src_col_def.length, src_col_def.null}; + return std::make_tuple(src_col_def.order, col_def); + + } else if (node->node_type == NodeType::function) { + auto func_node = static_cast(node); + + if (func_node->function == "to_string") { + ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 32, true}; + return std::make_tuple(-1, col_def); + } else if (func_node->function == "to_date") { + ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; + return std::make_tuple(-1, col_def); + } else if (func_node->function == "pp") { + ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 10, true}; + return std::make_tuple(-1, col_def); + } else if (func_node->function == "lower" || func_node->function == "upper") { + // TODO get length + ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 256, true}; + return std::make_tuple(-1, col_def); + } else if (func_node->function == "min" || func_node->function == "max") { + // TODO get correct type and length + ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; + return std::make_tuple(-1, col_def); + } else if (func_node->function == "count") { + ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; + return std::make_tuple(-1, col_def); + } + throw Exception("Unsupported function"); + + } else if (node->node_type == NodeType::arithmetical_operator) { + auto ari_node = static_cast(node); + + auto [left_col_index, left_tbl_col_def] = get_node_definition(table, ari_node->left.get(), col_name, col_order ); + auto [right_col_index, right_tbl_col_def] = get_node_definition(table, ari_node->right.get(), col_name, col_order ); + + ColumnType col_type; // TODO handle varchar and it len + if (left_tbl_col_def.type==ColumnType::float_type || right_tbl_col_def.type==ColumnType::float_type) + col_type = ColumnType::float_type; + else + col_type = ColumnType::integer_type; + + ColDefNode col_def = ColDefNode{col_name, col_type, col_order, 1, true}; + return std::make_tuple(-1, col_def); + + } else if (node->node_type == NodeType::logical_operator) { + ColDefNode col_def = ColDefNode{col_name, ColumnType::bool_type, col_order, 1, true}; + return std::make_tuple(-1, col_def); + + } else if (node->node_type == NodeType::int_value) { + ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; + return std::make_tuple(-1, col_def); + + } else if (node->node_type == NodeType::float_value) { + ColDefNode col_def = ColDefNode{col_name, ColumnType::float_type, col_order, 1, true}; + return std::make_tuple(-1, col_def); + + } else if (node->node_type == NodeType::string_value) { + // TODO right len + ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 64, true}; + return std::make_tuple(-1, col_def); + } + throw Exception("Unsupported node type"); +} + + + + +std::unique_ptr
USql::execute_insert_into_table(InsertIntoTableNode &node) { + // find table + Table *table_def = find_table(node.table_name); + + if (node.cols_names.size() != node.cols_values.size()) + throw Exception("Incorrect number of values"); + + // prepare empty new_row + Row& new_row = table_def->create_empty_row(); + + // copy values + for (size_t i = 0; i < node.cols_names.size(); i++) { + ColDefNode col_def = table_def->get_column_def(node.cols_names[i].col_name); + auto col_value = eval_value_node(table_def, new_row, node.cols_values[i].get(), nullptr, nullptr); + + new_row.setColumnValue(&col_def, col_value.get()); + } + + // append new_row + table_def->commit_row(new_row); + + return create_stmt_result_table(0, "insert succeeded", 1); +} + + + +std::unique_ptr
USql::execute_delete(DeleteFromTableNode &node) { + // find source table + Table *table = find_table(node.table_name); + + // execute access plan + auto affected_rows = table->rows_count(); + + table->m_rows.erase( + std::remove_if(table->m_rows.begin(), table->m_rows.end(), + [&node, table](Row &row){return eval_where(node.where.get(), table, row);}), + table->m_rows.end()); + + affected_rows -= table->rows_count(); + + return create_stmt_result_table(0, "delete succeeded", affected_rows); +} + + +std::unique_ptr
USql::execute_update(UpdateTableNode &node) { + // find source table + Table *table = find_table(node.table_name); + + // execute access plan + int affected_rows = 0; + for (auto row = begin(table->m_rows); row != end(table->m_rows); ++row) { + // eval where for row + if (eval_where(node.where.get(), table, *row)) { + int i = 0; + for (const auto& col : node.cols_names) { + // TODO cache it like in select + ColDefNode col_def = table->get_column_def(col.col_name); + std::unique_ptr new_val = eval_arithmetic_operator(col_def.type, + static_cast(*node.values[i]), + table, *row); + + usql::Table::validate_column(&col_def, new_val.get()); + row->setColumnValue(&col_def, new_val.get()); + i++; + } + affected_rows++; + // TODO tady je problem, ze kdyz to zfajluje na jednom radku ostatni by se nemely provest + } + } + + return create_stmt_result_table(0, "update succeeded", affected_rows); +} + + +bool USql::eval_where(Node *where, Table *table, Row &row) { + switch (where->node_type) { + case NodeType::true_node: + return true; + case NodeType::relational_operator: // just one condition + return eval_relational_operator(*((RelationalOperatorNode *) where), table, row); + case NodeType::logical_operator: + return eval_logical_operator(*((LogicalOperatorNode *) where), table, row); + default: + throw Exception("Wrong node type"); + } + + return false; +} + + +} // namespace