diff --git a/usql/usql.h b/usql/usql.h index 554bb78..0f84de0 100644 --- a/usql/usql.h +++ b/usql/usql.h @@ -41,14 +41,15 @@ private: static std::unique_ptr eval_function_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value); - static bool eval_relational_operator(const RelationalOperatorNode &filter, Table *table, Row &row) ; - static bool eval_logical_operator(LogicalOperatorNode &node, Table *pTable, Row &row) ; - static std::unique_ptr eval_arithmetic_operator(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, Row &row) ; + static bool eval_relational_operator(const RelationalOperatorNode &filter, Table *table, Row &row); + static bool eval_logical_operator(LogicalOperatorNode &node, Table *pTable, Row &row); + static std::unique_ptr eval_arithmetic_operator(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, Row &row); static std::unique_ptr create_stmt_result_table(long code, const std::string &text, size_t affected_rows); - static std::tuple get_column_definition(Table *table, SelectColNode *select_col_node, int col_order) ; - static std::tuple get_node_definition(Table *table, Node *select_col_node, const std::string & col_name, int col_order) ; + static std::tuple get_column_definition(Table *table, SelectColNode *select_col_node, int col_order); + static ColDefNode get_db_column_definition(Table *table, Node *node); + static std::tuple get_node_definition(Table *table, Node *select_col_node, const std::string & col_name, int col_order); Table *find_table(const std::string &name); void check_table_not_exists(const std::string &name); diff --git a/usql/usql_dml.cpp b/usql/usql_dml.cpp index acf2163..56cd32c 100644 --- a/usql/usql_dml.cpp +++ b/usql/usql_dml.cpp @@ -148,11 +148,18 @@ std::tuple USql::get_column_definition(Table *table, SelectColN return get_node_definition(table, select_col_node->value.get(), select_col_node->name, col_order ); } +ColDefNode USql::get_db_column_definition(Table *table, Node *node) { + if (node->node_type == NodeType::database_value) { + auto db_node = static_cast(node); + return table->get_column_def(db_node->col_name); + } + + throw Exception("Undefined table node - get_db_column_definition"); +} + std::tuple USql::get_node_definition(Table *table, Node * node, const std::string & col_name, int col_order ) { if (node->node_type == NodeType::database_value) { - auto dbval_node = static_cast(node); - - ColDefNode src_col_def = table->get_column_def(dbval_node->col_name); + ColDefNode src_col_def = get_db_column_definition(table, node); ColDefNode col_def = ColDefNode{col_name, src_col_def.type, col_order, src_col_def.length, src_col_def.null}; return std::make_tuple(src_col_def.order, col_def); @@ -169,12 +176,19 @@ std::tuple USql::get_node_definition(Table *table, Node * node, ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 10, true}; return std::make_tuple(-1, col_def); } else if (func_node->function == "lower" || func_node->function == "upper") { - // TODO get length + // TODO get length, use get_db_column_definition ColDefNode col_def = ColDefNode{col_name, ColumnType::varchar_type, col_order, 256, true}; return std::make_tuple(-1, col_def); } else if (func_node->function == "min" || func_node->function == "max") { - // TODO get correct type and length - ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; + auto col_type= ColumnType::float_type; + int col_len = 1; + auto & v = func_node->params[0]; + if (v->node_type == NodeType::database_value) { + ColDefNode src_col_def = get_db_column_definition(table, v.get()); + col_type = src_col_def.type; + col_len = src_col_def.length; + } + ColDefNode col_def = ColDefNode{col_name, col_type, col_order, col_len, true}; return std::make_tuple(-1, col_def); } else if (func_node->function == "count") { ColDefNode col_def = ColDefNode{col_name, ColumnType::integer_type, col_order, 1, true}; @@ -278,8 +292,8 @@ std::unique_ptr
USql::execute_update(UpdateTableNode &node) { // TODO cache it like in select ColDefNode col_def = table->get_column_def(col.col_name); std::unique_ptr new_val = eval_arithmetic_operator(col_def.type, - static_cast(*node.values[i]), - table, *row); + static_cast(*node.values[i]), + table, *row); usql::Table::validate_column(&col_def, new_val.get()); row->setColumnValue(&col_def, new_val.get());