From 7ad26ba427b153fee635a4bf06302098479a8699 Mon Sep 17 00:00:00 2001 From: vaclavt Date: Thu, 17 Feb 2022 20:42:30 +0100 Subject: [PATCH] usql updated --- usql/parser.cpp | 42 ++++++++----- usql/parser.h | 140 +++++++++++++++++++++++++++++++++++------ usql/row.h | 2 +- usql/table.h | 1 + usql/usql.cpp | 23 ++++--- usql/usql_dml.cpp | 41 ++++++------ usql/usql_function.cpp | 37 ++++++----- 7 files changed, 201 insertions(+), 85 deletions(-) diff --git a/usql/parser.cpp b/usql/parser.cpp index 7d63d4b..32d55ef 100644 --- a/usql/parser.cpp +++ b/usql/parser.cpp @@ -5,6 +5,17 @@ namespace usql { // TOOD handle premature eof + std::string column_type_name(const ColumnType type) { + if (type == ColumnType::integer_type) return "integer_type"; + if (type == ColumnType::float_type) return "float_type"; + if (type == ColumnType::varchar_type) return "varchar_type"; + if (type == ColumnType::date_type) return "date_type"; + if (type == ColumnType::bool_type) return "bool_type"; + + throw Exception("invalid column type: " + (int)type); + }; + + Parser::Parser() { m_lexer = Lexer{}; } @@ -433,39 +444,39 @@ namespace usql { // function call if (token_typcol == TokenType::identifier && m_lexer.nextTokenType() == TokenType::open_paren) { - std::string function_name = m_lexer.consumeToken(TokenType::identifier).token_string; - std::vector> pars; + std::string function_name = m_lexer.consumeToken(TokenType::identifier).token_string; + std::vector> pars; - m_lexer.skipToken(TokenType::open_paren); - while (m_lexer.tokenType() != TokenType::close_paren && m_lexer.tokenType() != TokenType::eof) { - pars.push_back(parse_expression()); - m_lexer.skipTokenOptional(TokenType::comma); - } - m_lexer.skipToken(TokenType::close_paren); - return std::make_unique(function_name, std::move(pars)); + m_lexer.skipToken(TokenType::open_paren); + while (m_lexer.tokenType() != TokenType::close_paren && m_lexer.tokenType() != TokenType::eof) { + pars.push_back(parse_expression()); + m_lexer.skipTokenOptional(TokenType::comma); + } + m_lexer.skipToken(TokenType::close_paren); + return std::make_unique(function_name, std::move(pars)); } // numbers and strings std::string tokenString = m_lexer.consumeToken().token_string; if (token_typcol == TokenType::int_number) - return std::make_unique(std::stoi(tokenString)); + return std::make_unique(std::stoi(tokenString)); if (token_typcol == TokenType::double_number) - return std::make_unique(std::stod(tokenString)); + return std::make_unique(std::stod(tokenString)); if (token_typcol == TokenType::string_literal) - return std::make_unique(tokenString); + return std::make_unique(tokenString); // db column if (token_typcol == TokenType::identifier) - return std::make_unique(tokenString); + return std::make_unique(tokenString); // null if (token_typcol == TokenType::keyword_null) - return std::make_unique(); + return std::make_unique(); // true / false if (token_typcol == TokenType::keyword_true || token_typcol == TokenType::keyword_false) - return std::make_unique(token_typcol == TokenType::keyword_true); + return std::make_unique(token_typcol == TokenType::keyword_true); // token * for count(*) if (token_typcol == TokenType::multiply) @@ -529,4 +540,3 @@ namespace usql { } } // namespace - diff --git a/usql/parser.h b/usql/parser.h index a444c85..f100c32 100644 --- a/usql/parser.h +++ b/usql/parser.h @@ -21,6 +21,9 @@ enum class ColumnType { bool_type }; +std::string column_type_name(const ColumnType type); + + enum class NodeType { true_node, null_value, @@ -51,6 +54,7 @@ enum class NodeType { error }; + struct Node { NodeType node_type; @@ -58,7 +62,7 @@ struct Node { virtual ~Node() = default; virtual void dump() const { - std::cout << "type: Node" << std::endl; + std::cout << "type: Node" << (int)node_type << std::endl; } }; @@ -115,19 +119,76 @@ struct ColDefNode : Node { null(nullable) {} void dump() const override { - std::cout << "type: ColDefNode, name: " << name << ", type: " << (int)type << " TODO add more" << std::endl; + std::cout << "type: ColDefNode, name: " << name << ", type: " << column_type_name(type) << ", order: " << order << ", length: " << length << ", null: " << null << std::endl; } }; struct FunctionNode : Node { - std::string function; // TODO use enum + + enum class Type { + to_string, + to_date, + date_add, + pp, + lower, + upper, + min, + max, + count + }; + + static Type get_function(const std::string &str) { + if (str=="to_string") return Type::to_string; + if (str=="to_date") return Type::to_date; + if (str=="date_add") return Type::date_add; + if (str=="pp") return Type::pp; + if (str=="lower") return Type::lower; + if (str=="upper") return Type::upper; + if (str=="min") return Type::min; + if (str=="max") return Type::max; + if (str=="count") return Type::count; + + throw Exception("invalid function: " + str); + }; + + static std::string function_name(const Type type) { + if (type == Type::to_string) return "to_string"; + if (type == Type::to_date) return "to_date"; + if (type == Type::date_add) return "date_add"; + if (type == Type::pp) return "pp"; + if (type == Type::lower) return "lower"; + if (type == Type::upper) return "upper"; + if (type == Type::min) return "min"; + if (type == Type::max) return "max"; + if (type == Type::count) return "count"; + + throw Exception("invalid function: " + (int)type); + }; + + + Type function; std::vector> params; FunctionNode(std::string func_name, std::vector> pars) : - Node(NodeType::function), function(std::move(func_name)), params(std::move(pars)) {} + Node(NodeType::function), function(get_function(func_name)), params(std::move(pars)) {} + + bool is_agg_function() { + return (function == Type::count || function == Type::min || function == Type::max); + } + + + friend std::ostream &operator<<(std::ostream &output, const Type &t ) { + output << function_name(t); + return output; + } void dump() const override { - std::cout << "type: FunctionNode, function: " << function << " TODO add more" << std::endl; + std::cout << "type: FunctionNode, function: " << function_name(function) << "("; + for(int i = 0; i < params.size(); i++){ + if (i > 0) std::cout << ","; + params[i]->dump(); + } + std::cout << ")" << std::endl; } }; @@ -325,12 +386,17 @@ struct CreateTableNode : Node { Node(NodeType::create_table), table_name(std::move(name)), cols_defs(std::move(defs)) {} void dump() const override { - std::cout << "type: CreateTableNode, table_name: " << table_name << "TODO complete me" << std::endl; + std::cout << "type: CreateTableNode, table_name: " << table_name << "("; + for(int i = 0; i < cols_defs.size(); i++) { + if (i > 0) std::cout << ","; + cols_defs[i].dump(); + } + std::cout << ")" << std::endl; } }; struct InsertIntoTableNode : Node { - std::string table_name; + std::string table_name; std::vector cols_names; std::vector> cols_values; @@ -338,29 +404,53 @@ struct InsertIntoTableNode : Node { Node(NodeType::insert_into), table_name(std::move(name)), cols_names(std::move(names)), cols_values(std::move(values)) {} void dump() const override { - std::cout << "type: InsertIntoTableNode, table_name: " << table_name << "TODO complete me" << std::endl; + std::cout << "type: InsertIntoTableNode, table_name: " << table_name << "("; + for(int i = 0; i < cols_names.size(); i++) { + if (i > 0) std::cout << ","; + cols_names[i].dump(); + } + std::cout << ") values ("; + for(int i = 0; i < cols_values.size(); i++) { + if (i > 0) std::cout << ","; + cols_values[i]->dump(); + } + std::cout << ")" << std::endl; } }; struct SelectFromTableNode : Node { - std::string table_name; + std::string table_name; std::unique_ptr> cols_names; - std::unique_ptr where; - std::vector order_by; - OffsetLimitNode offset_limit; + std::unique_ptr where; + std::vector order_by; + OffsetLimitNode offset_limit; bool distinct; SelectFromTableNode(std::string name, std::unique_ptr> names, std::unique_ptr where_clause, std::vector orderby, OffsetLimitNode offlim, bool distinct_): Node(NodeType::select_from), table_name(std::move(name)), cols_names(std::move(names)), where(std::move(where_clause)), order_by(std::move(orderby)), offset_limit(std::move(offlim)), distinct(distinct_) {} void dump() const override { - std::cout << "type: SelectFromTableNode, table_name: " << table_name << "TODO complete me" << std::endl; + std::cout << "type: SelectFromTableNode, table_name: " << table_name; + std::cout << "colums: "; + for(int i = 0; i < cols_names->size(); i++) { + if (i > 0) std::cout << ","; + cols_names->operator[](i).dump(); + } + std::cout << "where: "; where->dump(); + std::cout << "offset,limit: "; + for(int i = 0; i < order_by.size(); i++) { + if (i > 0) std::cout << ","; + order_by[i].dump(); + } + std::cout << "offset,limit: "; + offset_limit.dump(); + std::cout << std::endl; } }; struct CreateTableAsSelectNode : Node { - std::string table_name; + std::string table_name; std::unique_ptr select_table; CreateTableAsSelectNode(std::string name, std::unique_ptr table) : @@ -373,10 +463,10 @@ struct CreateTableAsSelectNode : Node { }; struct UpdateTableNode : Node { - std::string table_name; - std::vector cols_names; - std::vector> values; - std::unique_ptr where; + std::string table_name; + std::vector cols_names; + std::vector> values; + std::unique_ptr where; UpdateTableNode(std::string name, std::vector names, std::vector> vals, std::unique_ptr where_clause) : @@ -384,8 +474,16 @@ struct UpdateTableNode : Node { where(std::move(where_clause)) {} void dump() const override { - std::cout << "type: UpdateTableNode, table_name: " << table_name << "TODO complete me" << std::endl; + std::cout << "type: UpdateTableNode, table_name: " << table_name << " set "; + for(int i = 0; i < cols_names.size(); i++) { + if (i > 0) std::cout << ","; + cols_names[i].dump(); + std::cout << " = "; + values[i]->dump(); + } + std::cout << " where: "; where->dump(); + std::cout << std::endl; } }; @@ -431,8 +529,10 @@ struct DeleteFromTableNode : Node { Node(NodeType::delete_from), table_name(std::move(name)), where(std::move(where_clause)) {} void dump() const override { - std::cout << "type: DeleteFromTableNode, table_name: " << table_name << std::endl; + std::cout << "type: DeleteFromTableNode, table_name: " << table_name; + std::cout << "where: "; where->dump(); + std::cout << std::endl; } }; diff --git a/usql/row.h b/usql/row.h index dbaa0ea..b4edf66 100644 --- a/usql/row.h +++ b/usql/row.h @@ -182,7 +182,7 @@ public: [[nodiscard]] bool is_visible() const { return m_visible; }; void set_visible() { m_visible = true; }; - void set_deleted() { m_visible = true; }; + void set_deleted() { m_visible = false; }; private: bool m_visible; diff --git a/usql/table.h b/usql/table.h index d39499e..b2d88cb 100644 --- a/usql/table.h +++ b/usql/table.h @@ -8,6 +8,7 @@ #include #include + namespace usql { struct Table { diff --git a/usql/usql.cpp b/usql/usql.cpp index ca18fe1..1c634a9 100644 --- a/usql/usql.cpp +++ b/usql/usql.cpp @@ -121,7 +121,7 @@ std::unique_ptr USql::eval_value_node(Table *table, Row &row, Node *n std::unique_ptr USql::eval_database_value_node(Table *table, Row &row, Node *node) { auto *dvl = static_cast(node); - ColDefNode col_def = table->get_column_def( dvl->col_name); // TODO optimize it to just get this def once + ColDefNode col_def = table->get_column_def(dvl->col_name); // TODO optimize it to just get this def once ColValue &db_value = row[col_def.order]; if (db_value.isNull()) @@ -176,18 +176,17 @@ std::unique_ptr USql::eval_function_value_node(Table *table, Row &row if (evaluatedPars.empty() || evaluatedPars[0]->isNull()) return std::make_unique(); - // TODO use some enum - if (fnc->function == "lower") return lower_function(evaluatedPars); - if (fnc->function == "upper") return upper_function(evaluatedPars); - if (fnc->function == "to_date") return to_date_function(evaluatedPars); - if (fnc->function == "to_string") return to_string_function(evaluatedPars); - if (fnc->function == "date_add") return date_add_function(evaluatedPars); - if (fnc->function == "pp") return pp_function(evaluatedPars); - if (fnc->function == "count") return count_function(agg_func_value, evaluatedPars); - if (fnc->function == "max") return max_function(evaluatedPars, col_def_node, agg_func_value); - if (fnc->function == "min") return min_function(evaluatedPars, col_def_node, agg_func_value); + if (fnc->function == FunctionNode::Type::lower) return lower_function(evaluatedPars); + if (fnc->function == FunctionNode::Type::upper) return upper_function(evaluatedPars); + if (fnc->function == FunctionNode::Type::to_date) return to_date_function(evaluatedPars); + if (fnc->function == FunctionNode::Type::to_string) return to_string_function(evaluatedPars); + if (fnc->function == FunctionNode::Type::date_add) return date_add_function(evaluatedPars); + if (fnc->function == FunctionNode::Type::pp) return pp_function(evaluatedPars); + if (fnc->function == FunctionNode::Type::count) return count_function(agg_func_value, evaluatedPars); + if (fnc->function == FunctionNode::Type::max) return max_function(evaluatedPars, col_def_node, agg_func_value); + if (fnc->function == FunctionNode::Type::min) return min_function(evaluatedPars, col_def_node, agg_func_value); - throw Exception("invalid function: " + fnc->function); + throw Exception("invalid function: " + FunctionNode::function_name(fnc->function)); } diff --git a/usql/usql_dml.cpp b/usql/usql_dml.cpp index ddfc34b..cb79699 100644 --- a/usql/usql_dml.cpp +++ b/usql/usql_dml.cpp @@ -88,7 +88,7 @@ void USql::select_row(SelectFromTableNode &where_node, Row *rslt_row = nullptr; - // when aggregate functions in rslt_table only one row exists + // when aggregate functions in rslt_table only one row exists if (is_aggregated && !rslt_table->empty()) rslt_row = &rslt_table->m_rows[0]; else @@ -108,24 +108,21 @@ void USql::select_row(SelectFromTableNode &where_node, rslt_row->setColumnValue((ColDefNode *) &rslt_tbl_col_defs[idx], col_value); } } - - // for aggregate is validated more than needed - rslt_table->commit_row(*rslt_row); + // for aggregate is validated more than needed + rslt_table->commit_row(*rslt_row); } bool USql::check_for_aggregate_only_functions(SelectFromTableNode &node, size_t result_cols_cnt) { size_t aggregate_funcs = 0; + for (size_t i = 0; i < node.cols_names->size(); i++) { SelectColNode * col_node = &node.cols_names->operator[](i); - if (col_node->value->node_type == NodeType::function) { - auto func_node = static_cast(col_node->value.get()); - if (func_node->function == "count" || func_node->function == "min" || func_node->function == "max") - aggregate_funcs++; - } + if (col_node->value->node_type == NodeType::function && ((FunctionNode *)col_node->value.get())->is_agg_function()) + aggregate_funcs++; } // check whether aggregates are not present or all columns are aggregates if (aggregate_funcs > 0 && aggregate_funcs != result_cols_cnt) { - throw Exception("aggregate functions with no aggregates"); + throw Exception("aggregate functions mixed with no aggregate functions in select clause"); } return aggregate_funcs > 0; @@ -229,20 +226,20 @@ std::tuple USql::get_node_definition(Table *table, Node * node, } else if (node->node_type == NodeType::function) { auto func_node = static_cast(node); - if (func_node->function == "to_string") { + if (func_node->function == FunctionNode::Type::to_string) { ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 32, true}; - return std::make_tuple(-1, col_def); - } else if (func_node->function == "to_date") { + return std::make_tuple(FUNCTION_CALL, col_def); + } else if (func_node->function == FunctionNode::Type::to_date) { ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; - return std::make_tuple(-1, col_def); - } else if (func_node->function == "pp") { + return std::make_tuple(FUNCTION_CALL, col_def); + } else if (func_node->function == FunctionNode::Type::pp) { ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 10, true}; - return std::make_tuple(-1, col_def); - } else if (func_node->function == "lower" || func_node->function == "upper") { + return std::make_tuple(FUNCTION_CALL, col_def); + } else if (func_node->function == FunctionNode::Type::lower || func_node->function == FunctionNode::Type::upper) { // TODO get length, use get_db_column_definition ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 256, true}; - return std::make_tuple(-1, col_def); - } else if (func_node->function == "min" || func_node->function == "max") { + return std::make_tuple(FUNCTION_CALL, col_def); + } else if (func_node->function == FunctionNode::Type::min || func_node->function == FunctionNode::Type::max) { auto col_type= ColumnType::float_type; size_t col_len = 1; auto & v = func_node->params[0]; @@ -252,10 +249,10 @@ std::tuple USql::get_node_definition(Table *table, Node * node, col_len = src_col_def.length; } ColDefNode col_def = ColDefNode{col_name, col_type, col_order, col_len, true}; - return std::make_tuple(-1, col_def); - } else if (func_node->function == "count") { + return std::make_tuple(FUNCTION_CALL, col_def); + } else if (func_node->function == FunctionNode::Type::count) { ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; - return std::make_tuple(-1, col_def); + return std::make_tuple(FUNCTION_CALL, col_def); } throw Exception("Unsupported function"); diff --git a/usql/usql_function.cpp b/usql/usql_function.cpp index 206c212..1d7b3e6 100644 --- a/usql/usql_function.cpp +++ b/usql/usql_function.cpp @@ -11,6 +11,7 @@ std::unique_ptr USql::to_string_function(const std::vectorgetDateValue(); std::string format = evaluatedPars[1]->getStringValue(); std::string formatted_date = date_to_string(date, format); + return std::make_unique(formatted_date); } @@ -18,6 +19,7 @@ std::unique_ptr USql::to_date_function(const std::vectorgetStringValue(); std::string format = evaluatedPars[1]->getStringValue(); long epoch_time = string_to_date(date, format); + return std::make_unique(epoch_time); // No DateValueNode for now } @@ -27,6 +29,7 @@ std::unique_ptr USql::date_add_function(const std::vectorgetStringValue(); long new_date = add_to_date(datetime, quantity, part); + return std::make_unique(new_date); // No DateValueNode for now } @@ -34,51 +37,57 @@ std::unique_ptr USql::date_add_function(const std::vector USql::upper_function(const std::vector> &evaluatedPars) { std::string str = evaluatedPars[0]->getStringValue(); std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return toupper(c); }); + return std::make_unique(str); } std::unique_ptr USql::lower_function(const std::vector> &evaluatedPars) { std::string str = evaluatedPars[0]->getStringValue(); std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return tolower(c); }); + return std::make_unique(str); } std::unique_ptr USql::pp_function(const std::vector> &evaluatedPars) { + constexpr auto k_num_format_rpad = 10; + constexpr auto k_num_format_maxlen = 20; + auto &parsed_value = evaluatedPars[0]; if (parsed_value->node_type == NodeType::int_value || parsed_value->node_type == NodeType::float_value) { std::string format = evaluatedPars.size() > 1 ? evaluatedPars[1]->getStringValue() : ""; - char buf[20] {0}; // TODO constant here + char buf[k_num_format_maxlen] {0}; double value = parsed_value->getDoubleValue(); if (format == "100%") - std::snprintf(buf, 20, "%.2f%%", value); + std::snprintf(buf, k_num_format_maxlen, "%.2f%%", value); else if (format == "%.2f") - std::snprintf(buf, 20, "%.2f", value); + std::snprintf(buf, k_num_format_maxlen, "%.2f", value); else if (value >= 1000000000000) - std::snprintf(buf, 20, "%7.2fT", value/1000000000000); + std::snprintf(buf, k_num_format_maxlen, "%7.2fT", value/1000000000000); else if (value >= 1000000000) - std::sprintf(buf, "%7.2fB", value/1000000000); + std::snprintf(buf, k_num_format_maxlen, "%7.2fB", value/1000000000); else if (value >= 1000000) - std::snprintf(buf, 20, "%7.2fM", value/1000000); + std::snprintf(buf, k_num_format_maxlen, "%7.2fM", value/1000000); else if (value >= 100000) - std::snprintf(buf, 20, "%7.2fM", value/100000); // 0.12M + std::snprintf(buf, k_num_format_maxlen, "%7.2fM", value/100000); // 0.12M else if (value <= -1000000000000) - std::snprintf(buf, 20, "%7.2fT", value/1000000000000); + std::snprintf(buf, k_num_format_maxlen, "%7.2fT", value/1000000000000); else if (value <= -1000000000) - std::snprintf(buf, 20, "%7.2fB", value/1000000000); + std::snprintf(buf, k_num_format_maxlen, "%7.2fB", value/1000000000); else if (value <= -1000000) - std::snprintf(buf, 20, "%7.2fM", value/1000000); + std::snprintf(buf, k_num_format_maxlen, "%7.2fM", value/1000000); else if (value <= -100000) - std::snprintf(buf, 20, "%7.2fM", value/100000); // 0.12M + std::snprintf(buf, k_num_format_maxlen, "%7.2fM", value/100000); // 0.12M else if (value == 0) buf[0]='0'; else - return std::make_unique(parsed_value->getStringValue().substr(0, 10)); - // TODO introduce constant for 10 + return std::make_unique(parsed_value->getStringValue().substr(0, k_num_format_rpad)); + std::string s {buf}; - return std::make_unique(string_padd(s.erase(s.find_last_not_of(' ')+1), 10, ' ', false)); + return std::make_unique(string_padd(s.erase(s.find_last_not_of(' ') + 1), k_num_format_rpad, ' ', false)); } + return std::make_unique(parsed_value->getStringValue()); }