From 3e913263fc85f957f54c570c8aec3cd9347f5a45 Mon Sep 17 00:00:00 2001 From: VaclavT Date: Fri, 16 Jul 2021 10:07:16 +0200 Subject: [PATCH] functions very basic functionality added --- main.cpp | 4 +- parser.cpp | 24 +++--- usql.cpp | 219 +++++++++++++++++++++++++++++------------------------ usql.h | 33 +++----- 4 files changed, 145 insertions(+), 135 deletions(-) diff --git a/main.cpp b/main.cpp index 42e3cf5..5c7b7e0 100644 --- a/main.cpp +++ b/main.cpp @@ -10,10 +10,10 @@ int main(int argc, char *argv[]) { std::vector sql_commands{ "create table a (i integer not null, s varchar(64), f float null)", - "insert into a (i, s) values(1, 'one')", + "insert into a (i, s) values(1, upper('one'))", "insert into a (i, s) values(2, 'two')", "insert into a (i, s) values(3, 'two')", - "insert into a (i, s) values(4, 'four')", + "insert into a (i, s) values(4, lower('FOUR'))", "insert into a (i, s) values(5, 'five')", "select i, s from a where i > 2", "select i, s from a where i = 1", diff --git a/parser.cpp b/parser.cpp index d7ff0db..85dd894 100644 --- a/parser.cpp +++ b/parser.cpp @@ -141,21 +141,19 @@ std::unique_ptr Parser::parse_value() { return std::make_unique(std::stof(lexer.consumeCurrentToken().token_string)); } if (lexer.tokenType() == TokenType::string_literal) { - if (lexer.nextTokenType() != TokenType::open_paren) { - return std::make_unique(lexer.consumeCurrentToken().token_string); - } else { - // function - std::string func_name = lexer.consumeCurrentToken().token_string; - std::vector> pars; + return std::make_unique(lexer.consumeCurrentToken().token_string); + } + if (lexer.tokenType() == TokenType::identifier) { + std::string func_name = lexer.consumeCurrentToken().token_string; + std::vector> pars; - lexer.skipToken(TokenType::open_paren); - while (lexer.tokenType() != TokenType::close_paren) { // TODO handle errors - auto par = parse_value(); - lexer.skipTokenOptional(TokenType::comma); - } - lexer.skipToken(TokenType::close_paren); - return std::make_unique(func_name, std::move(pars)); + lexer.skipToken(TokenType::open_paren); + while (lexer.tokenType() != TokenType::close_paren) { // TODO handle errors + pars.push_back(parse_value()); + lexer.skipTokenOptional(TokenType::comma); } + lexer.skipToken(TokenType::close_paren); + return std::make_unique(func_name, std::move(pars)); } throw Exception("Syntax error"); } diff --git a/usql.cpp b/usql.cpp index 5c6f139..065a80e 100644 --- a/usql.cpp +++ b/usql.cpp @@ -57,7 +57,7 @@ std::unique_ptr USql::execute_insert_into_table(InsertIntoTableNode &node ColDefNode col_def = table_def->get_column_def(node.cols_names[i].name); // TODO validate value - auto value = evalValueNode(node.cols_values[i].get()); + auto value = evalValueNode(table_def, new_row, node.cols_values[i].get()); if (col_def.type == ColumnType::integer_type) { new_row.setColumnValue(col_def.order, value->getIntValue()); @@ -99,7 +99,7 @@ std::unique_ptr
USql::execute_select(SelectFromTableNode &node) { // execute access plan for (auto row = begin(table->m_rows); row != end(table->m_rows); ++row) { // eval where for row - if (evalWhere(node.where.get(), table, row)) { + if (evalWhere(node.where.get(), table, *row)) { // prepare empty row Row new_row = result->createEmptyRow(); @@ -134,7 +134,7 @@ std::unique_ptr
USql::execute_delete(DeleteFromTableNode &node) { // execute access plan auto it = table->m_rows.begin(); for (; it != table->m_rows.end();) { - if (evalWhere(node.where.get(), table, it)) { + if (evalWhere(node.where.get(), table, *it)) { // TODO this can be really expensive operation it = table->m_rows.erase(it); } else { @@ -155,15 +155,15 @@ std::unique_ptr
USql::execute_update(UpdateTableNode &node) { // execute access plan for (auto row = begin(table->m_rows); row != end(table->m_rows); ++row) { // eval where for row - if (evalWhere(node.where.get(), table, row)) { + if (evalWhere(node.where.get(), table, *row)) { int i = 0; for (auto col : node.cols_names) { // TODO cache it like in select ColDefNode cdef = table->get_column_def(col.name); - std::unique_ptr new_val = evalArithmetic(cdef.type, - static_cast(*node.values[i]), - table, row); + std::unique_ptr new_val = evalArithmeticOperator(cdef.type, + static_cast(*node.values[i]), + table, *row); if (cdef.type == ColumnType::integer_type) { row->setColumnValue(cdef.order, new_val->getIntValue()); @@ -225,91 +225,92 @@ std::unique_ptr
USql::execute_load(LoadIntoTableNode &node) { } -bool USql::evalWhere(Node *where, Table *table, - std::vector>::iterator &row) const { +bool USql::evalWhere(Node *where, Table *table, Row &row) const { switch (where->node_type) { // no where clause case NodeType::true_node: return true; - case NodeType::relational_operator: // just one condition + case NodeType::relational_operator: // just one condition return evalRelationalOperator(*((RelationalOperatorNode *) where), table, row); - case NodeType::logical_operator: - return evalLogicalOperator(*((LogicalOperatorNode *) where), table, row); - default: - throw Exception("Wrong node type"); + case NodeType::logical_operator: + return evalLogicalOperator(*((LogicalOperatorNode *) where), table, row); + default: + throw Exception("Wrong node type"); } return false; } -bool USql::evalRelationalOperator(const RelationalOperatorNode &filter, Table *table, - std::vector>::iterator &row) const { - std::unique_ptr left_value = evalNode(table, row, filter.left.get()); - std::unique_ptr right_value = evalNode(table, row, filter.right.get()); +bool USql::evalRelationalOperator(const RelationalOperatorNode &filter, Table *table, Row &row) const { + std::unique_ptr left_value = evalValueNode(table, row, filter.left.get()); + std::unique_ptr right_value = evalValueNode(table, row, filter.right.get()); double comparator; if (left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::int_value) { comparator = left_value->getIntValue() - right_value->getIntValue(); - } else if ((left_value->node_type == NodeType::int_value && - right_value->node_type == NodeType::float_value) || - (left_value->node_type == NodeType::float_value && - right_value->node_type == NodeType::int_value) || - (left_value->node_type == NodeType::float_value && - right_value->node_type == NodeType::float_value)) { + } else if ((left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::float_value) || + (left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::int_value) || + (left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::float_value)) { comparator = left_value->getDoubleValue() - right_value->getDoubleValue(); - } else if (left_value->node_type == NodeType::string_value || - right_value->node_type == NodeType::string_value) { + } else if (left_value->node_type == NodeType::string_value || right_value->node_type == NodeType::string_value) { comparator = left_value->getStringValue().compare(right_value->getStringValue()); } else { // TODO throw exception } - switch (filter.op) { case RelationalOperatorType::equal: return comparator == 0.0; - case RelationalOperatorType::not_equal: - return comparator != 0.0; - case RelationalOperatorType::greater: - return comparator > 0.0; - case RelationalOperatorType::greater_equal: - return comparator >= 0.0; - case RelationalOperatorType::lesser: - return comparator < 0.0; - case RelationalOperatorType::lesser_equal: - return comparator <= 0.0; + case RelationalOperatorType::not_equal: + return comparator != 0.0; + case RelationalOperatorType::greater: + return comparator > 0.0; + case RelationalOperatorType::greater_equal: + return comparator >= 0.0; + case RelationalOperatorType::lesser: + return comparator < 0.0; + case RelationalOperatorType::lesser_equal: + return comparator <= 0.0; } throw Exception("invalid relational operator"); - } -std::unique_ptr -USql::evalNode(Table *table, std::vector>::iterator &row, Node *node) const { +std::unique_ptr USql::evalValueNode(Table *table, Row &row, Node *node) const { if (node->node_type == NodeType::database_value) { - DatabaseValueNode *dvl = static_cast(node); - ColDefNode col_def = table->get_column_def( - dvl->col_name); // TODO optimize it to just get this def once - auto db_value = row->ithColumn(col_def.order); + return evalDatabaseValueNode(table, row, node); - if (col_def.type == ColumnType::integer_type) { - return std::make_unique(db_value->integerValue()); - } - if (col_def.type == ColumnType::float_type) { - return std::make_unique(db_value->floatValue()); - } - if (col_def.type == ColumnType::varchar_type) { - return std::make_unique(db_value->stringValue()); - } - } else { - return evalValueNode(node); + } else if (node->node_type == NodeType::int_value || node->node_type == NodeType::float_value || node->node_type == NodeType::string_value) { + return evalLiteralValueNode(table, row, node); + + } else if (node->node_type == NodeType::function) { + return evalFunctionValueNode(table, row, node); } + throw Exception("unsupported node type"); } -std::unique_ptr USql::evalValueNode(Node *node) const { +std::unique_ptr USql::evalDatabaseValueNode(Table *table, Row &row, Node *node) const { + DatabaseValueNode *dvl = static_cast(node); + ColDefNode col_def = table->get_column_def( dvl->col_name); // TODO optimize it to just get this def once + auto db_value = row.ithColumn(col_def.order); + + if (col_def.type == ColumnType::integer_type) { + return std::__1::make_unique(db_value->integerValue()); + } + if (col_def.type == ColumnType::float_type) { + return std::__1::make_unique(db_value->floatValue()); + } + if (col_def.type == ColumnType::varchar_type) { + return std::__1::make_unique(db_value->stringValue()); + } + throw Exception("unknown database value type"); +} + + +std::unique_ptr USql::evalLiteralValueNode(Table *table, Row &row, Node *node) const { if (node->node_type == NodeType::int_value) { IntValueNode *ivl = static_cast(node); return std::make_unique(ivl->value); @@ -321,35 +322,56 @@ std::unique_ptr USql::evalValueNode(Node *node) const { } else if (node->node_type == NodeType::string_value) { StringValueNode *ivl = static_cast(node); return std::make_unique(ivl->value); - } else if ("function eval" == "xxx") { } throw Exception("invalid type"); } -bool USql::evalLogicalOperator(LogicalOperatorNode &node, Table *pTable, - std::vector>::iterator &iter) const { - bool left = evalRelationalOperator(static_cast(*node.left), pTable, iter); + +std::unique_ptr USql::evalFunctionValueNode(Table *table, Row &row, Node *node) const { + FunctionNode *fnc = static_cast(node); + + std::vector> evaluatedPars; + for(int i = 0; i < fnc->params.size(); i++) { + evaluatedPars.push_back(evalValueNode(table, row, fnc->params[i].get())); + } + + // TODO use some enum + if (fnc->function == "lower") { + std::string str = evaluatedPars[0]->getStringValue(); + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return std::tolower(c); }); + return std::make_unique(str); + } + if (fnc->function == "upper") { + std::string str = evaluatedPars[0]->getStringValue(); + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return std::toupper(c); }); + return std::make_unique(str); + } + + throw Exception("invalid function"); +} + + +bool USql::evalLogicalOperator(LogicalOperatorNode &node, Table *pTable, Row &row) const { + bool left = evalRelationalOperator(static_cast(*node.left), pTable, row); if ((node.op == LogicalOperatorType::and_operator && !left) || (node.op == LogicalOperatorType::or_operator && left)) return left; - bool right = evalRelationalOperator(static_cast(*node.right), pTable, iter); + bool right = evalRelationalOperator(static_cast(*node.right), pTable, row); return right; } -std::unique_ptr -USql::evalArithmetic(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, - std::vector>::iterator &row) const { +std::unique_ptr USql::evalArithmeticOperator(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, Row &row) const { if (node.op == ArithmeticalOperatorType::copy_value) { - return evalNode(table, row, node.left.get()); + return evalValueNode(table, row, node.left.get()); } - std::unique_ptr left = evalNode(table, row, node.left.get()); - std::unique_ptr right = evalNode(table, row, node.right.get()); + std::unique_ptr left = evalValueNode(table, row, node.left.get()); + std::unique_ptr right = evalValueNode(table, row, node.right.get()); if (outType == ColumnType::float_type) { double l = ((ValueNode *) left.get())->getDoubleValue(); @@ -357,14 +379,14 @@ USql::evalArithmetic(ColumnType outType, ArithmeticalOperatorNode &node, Table * switch (node.op) { case ArithmeticalOperatorType::plus_operator: return std::make_unique(l + r); - case ArithmeticalOperatorType::minus_operator: - return std::make_unique(l - r); - case ArithmeticalOperatorType::multiply_operator: - return std::make_unique(l * r); - case ArithmeticalOperatorType::divide_operator: - return std::make_unique(l / r); - default: - throw Exception("implement me!!"); + case ArithmeticalOperatorType::minus_operator: + return std::make_unique(l - r); + case ArithmeticalOperatorType::multiply_operator: + return std::make_unique(l * r); + case ArithmeticalOperatorType::divide_operator: + return std::make_unique(l / r); + default: + throw Exception("implement me!!"); } } else if (outType == ColumnType::integer_type) { int l = ((ValueNode *) left.get())->getIntValue(); @@ -372,14 +394,14 @@ USql::evalArithmetic(ColumnType outType, ArithmeticalOperatorNode &node, Table * switch (node.op) { case ArithmeticalOperatorType::plus_operator: return std::make_unique(l + r); - case ArithmeticalOperatorType::minus_operator: - return std::make_unique(l - r); - case ArithmeticalOperatorType::multiply_operator: - return std::make_unique(l * r); - case ArithmeticalOperatorType::divide_operator: - return std::make_unique(l / r); - default: - throw Exception("implement me!!"); + case ArithmeticalOperatorType::minus_operator: + return std::make_unique(l - r); + case ArithmeticalOperatorType::multiply_operator: + return std::make_unique(l * r); + case ArithmeticalOperatorType::divide_operator: + return std::make_unique(l / r); + default: + throw Exception("implement me!!"); } } else if (outType == ColumnType::varchar_type) { @@ -388,9 +410,8 @@ USql::evalArithmetic(ColumnType outType, ArithmeticalOperatorNode &node, Table * switch (node.op) { case ArithmeticalOperatorType::plus_operator: return std::make_unique(l + r); - - default: - throw Exception("implement me!!"); + default: + throw Exception("implement me!!"); } } @@ -398,18 +419,6 @@ USql::evalArithmetic(ColumnType outType, ArithmeticalOperatorNode &node, Table * } - -Table *USql::find_table(const std::string name) { - auto name_cmp = [name](const Table& t) { return t.m_name == name; }; - auto table_def = std::find_if(begin(m_tables), end(m_tables), name_cmp); - if (table_def != std::end(m_tables)) { - return table_def.operator->(); - } else { - throw Exception("table not found (" + name + ")"); - } -} - - std::unique_ptr
USql::create_stmt_result_table(int code, std::string text) { std::vector result_tbl_col_defs{}; result_tbl_col_defs.push_back(ColDefNode("code", ColumnType::integer_type, 0, 1, false)); @@ -425,4 +434,16 @@ std::unique_ptr
USql::create_stmt_result_table(int code, std::string text return std::move(table_def); } -} \ No newline at end of file + + +Table *USql::find_table(const std::string name) { + auto name_cmp = [name](const Table& t) { return t.m_name == name; }; + auto table_def = std::find_if(begin(m_tables), end(m_tables), name_cmp); + if (table_def != std::end(m_tables)) { + return table_def.operator->(); + } else { + throw Exception("table not found (" + name + ")"); + } +} + +} // namespace \ No newline at end of file diff --git a/usql.h b/usql.h index 9be3bd9..eeefeae 100644 --- a/usql.h +++ b/usql.h @@ -18,43 +18,34 @@ private: std::unique_ptr
execute(Node &node); std::unique_ptr
execute_create_table(CreateTableNode &node); - std::unique_ptr
execute_insert_into_table(InsertIntoTableNode &node); - std::unique_ptr
execute_select(SelectFromTableNode &node); - std::unique_ptr
execute_delete(DeleteFromTableNode &node); - std::unique_ptr
execute_update(UpdateTableNode &node); - std::unique_ptr
execute_load(LoadIntoTableNode &node); - Table *find_table(const std::string name); - - std::unique_ptr
create_stmt_result_table(int code, std::string text); - private: - bool evalWhere(Node *where, Table *table, - std::vector>::iterator &row) const; + bool evalWhere(Node *where, Table *table, Row &row) const; - std::unique_ptr evalNode(Table *table, std::vector>::iterator &row, - Node *node) const; + std::unique_ptr evalValueNode(Table *table, Row &row, Node *node) const; + std::unique_ptr evalDatabaseValueNode(Table *table, Row &row, Node *node) const; + std::unique_ptr evalLiteralValueNode(Table *table, Row &row, Node *node) const; + std::unique_ptr evalFunctionValueNode(Table *table, Row &row, Node *node) const; - std::unique_ptr evalValueNode(Node *node) const; - bool evalRelationalOperator(const RelationalOperatorNode &filter, Table *table, - std::vector>::iterator &row) const; + bool evalRelationalOperator(const RelationalOperatorNode &filter, Table *table, Row &row) const; + bool evalLogicalOperator(LogicalOperatorNode &node, Table *pTable, Row &row) const; + std::unique_ptr evalArithmeticOperator(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, Row &row) const; - bool evalLogicalOperator(LogicalOperatorNode &node, Table *pTable, - std::vector>::iterator &iter) const; - std::unique_ptr evalArithmetic(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, - std::vector>::iterator &row) const; + std::unique_ptr
create_stmt_result_table(int code, std::string text); + Table *find_table(const std::string name); + private: Parser m_parser; std::vector
m_tables; }; -} \ No newline at end of file +} // namespace \ No newline at end of file