From ddb9441e23cbb0f211ebc447025b6f993319ca2c Mon Sep 17 00:00:00 2001 From: VaclavT Date: Thu, 8 Jul 2021 20:51:03 +0200 Subject: [PATCH] basic skeletons of update and delete added --- Readme.md | 3 +- executor.cpp | 153 +++++++++++++++++++++++++++++++++++++-------------- executor.h | 12 +++- lexer.cpp | 27 +++++++-- lexer.h | 5 +- main.cpp | 9 ++- parser.cpp | 112 ++++++++++++++++++++++++++++++++----- parser.h | 57 +++++++++++++++++-- row.h | 2 +- 9 files changed, 308 insertions(+), 72 deletions(-) diff --git a/Readme.md b/Readme.md index 09d4834..b6ff5a7 100644 --- a/Readme.md +++ b/Readme.md @@ -2,9 +2,10 @@ ### TODO - rename it to usql - rename Exception to UException, Table to UTable, Row to URow etc +- remove newlines from lexed string tokens - unify using of float and double keywords - add constructors - add exceptions -- class members should have prefix m_O +- class members should have prefix m_ - add pipe | token - add logging \ No newline at end of file diff --git a/executor.cpp b/executor.cpp index b7a474c..02fcee0 100644 --- a/executor.cpp +++ b/executor.cpp @@ -28,6 +28,10 @@ bool Executor::execute(Node& node) { return execute_insert_into_table(static_cast(node)); case NodeType::select_from: return execute_select(static_cast(node)); + case NodeType::delete_from: + return execute_delete(static_cast(node)); + case NodeType::update_table: + return execute_update(static_cast(node)); default: // TODO error message return false; @@ -97,20 +101,17 @@ bool Executor::execute_select(SelectFromTableNode& node) { } Table result {"result", result_tbl_col_defs}; - // execute access plan - - for (auto row = begin (table->m_rows); row != end (table->m_rows); ++row) { // eval where for row - if (evalWhere(node, table, row)) { + if (evalWhere(node.where.get(), table, row)) { // prepare empty row Row new_row = result.createEmptyRow(); // copy column values for(auto idx=0; idxithColum(row_col_index); + ColValue *col_value = row->ithColumn(row_col_index); if 
(result_tbl_col_defs[idx].type == ColumnType::integer_type) new_row.setColumnValue(idx, ((ColIntegerValue*)col_value)->integerValue()); if (result_tbl_col_defs[idx].type == ColumnType::float_type) @@ -129,22 +130,73 @@ bool Executor::execute_select(SelectFromTableNode& node) { return true; } -bool Executor::evalWhere(const SelectFromTableNode &node, Table *table, +bool Executor::execute_delete(DeleteFromTableNode& node) { + // TODO create plan for accessing rows + + // find source table + Table* table = find_table(node.table_name); + + // execute access plan + auto it = table->m_rows.begin(); + for ( ; it != table->m_rows.end(); ) { + if (evalWhere(node.where.get(), table, it)) { + std::cout << "delete here" << std::endl; + ++it; // TODO this does not work : it = table->m_rows.erase(it); + } else { + ++it; + } + } + + return true; +} + +bool Executor::execute_update(UpdateTableNode &node) { + // TODO create plan for accessing rows + + // find source table + Table* table = find_table(node.table_name); + + // execute access plan + for (auto row = begin (table->m_rows); row != end (table->m_rows); ++row) { + // eval where for row + if (evalWhere(node.where.get(), table, row)) { + // TODO do update + int i = 0; + for(auto col : node.cols_names) { + // TODO cache it like in select + ColDefNode cdef = table->get_column_def(col.name); + + std::unique_ptr new_val = evalArithmetic(static_cast(*node.values[i]), table, row); + + if (cdef.type == ColumnType::integer_type) { + row->setColumnValue(cdef.order, ((IntValueNode*)new_val.get())->value); + } else if (cdef.type == ColumnType::float_type) { + row->setColumnValue(cdef.order, ((FloatValueNode*)new_val.get())->value); + } else { + throw Exception("Implement me!"); + } + i++; + } + } + } + + return true; +} + + +bool Executor::evalWhere(Node *where, Table *table, std::vector>::iterator &row) const { - if (node.where->node_type == NodeType::true_node) { // no where clause - return true; + switch (where->node_type) { // no 
where clause + case NodeType::true_node: + return true; + case NodeType::relational_operator: // just one condition + return evalRelationalOperator(*((RelationalOperatorNode *)where), table, row); + case NodeType::logical_operator: + return evalLogicalOperator(*((LogicalOperatorNode *)where), table, row); + default: + throw Exception("Wrong node type"); } - if (node.where.get()->node_type == NodeType::relational_operator) { - RelationalOperatorNode &filter = static_cast(*node.where); - return evalRelationalOperator(filter, table, row); - } - -// if (node.where.get()->node_type == NodeType::logical_operator) { -// LogicalOperatorNode &filter = static_cast(*node.where); -// return evalLogicalOperator(filter, table, row); -// } - return false; } @@ -155,35 +207,35 @@ bool Executor::evalRelationalOperator(const RelationalOperatorNode &filter, Tabl double comparator; if (left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::int_value) { - IntValueNode *lvalue = static_cast(left_value.get()); - IntValueNode *rvalue = static_cast(right_value.get()); + auto lvalue = static_cast(left_value.get()); + auto rvalue = static_cast(right_value.get()); comparator = lvalue->value - rvalue->value; } if (left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::float_value) { - IntValueNode *lvalue = static_cast(left_value.get()); - FloatValueNode *rvalue = static_cast(right_value.get()); + auto *lvalue = static_cast(left_value.get()); + auto *rvalue = static_cast(right_value.get()); comparator = (double)lvalue->value - rvalue->value; } if (left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::string_value) { - IntValueNode *lvalue = static_cast(left_value.get()); - StringValueNode *rvalue = static_cast(right_value.get()); + auto *lvalue = static_cast(left_value.get()); + auto *rvalue = static_cast(right_value.get()); comparator = std::to_string(lvalue->value).compare(rvalue->value); } if 
(left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::int_value) { - FloatValueNode *lvalue = static_cast(left_value.get()); - IntValueNode *rvalue = static_cast(right_value.get()); + auto *lvalue = static_cast(left_value.get()); + auto *rvalue = static_cast(right_value.get()); comparator = lvalue->value - (double)rvalue->value; } if (left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::float_value) { - FloatValueNode *lvalue = static_cast(left_value.get()); - FloatValueNode *rvalue = static_cast(right_value.get()); + auto *lvalue = static_cast(left_value.get()); + auto *rvalue = static_cast(right_value.get()); comparator = lvalue->value - rvalue->value; } if (left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::string_value) { - FloatValueNode *lvalue = static_cast(left_value.get()); - StringValueNode *rvalue = static_cast(right_value.get()); + auto *lvalue = static_cast(left_value.get()); + auto *rvalue = static_cast(right_value.get()); comparator = std::to_string(lvalue->value).compare(rvalue->value); } @@ -223,11 +275,11 @@ bool Executor::evalRelationalOperator(const RelationalOperatorNode &filter, Tabl throw Exception("invalid relational operator"); } -std::unique_ptr Executor::evalNode(Table *table, std::vector>::iterator &row, Node *filter) const { - if (filter->node_type == NodeType::database_value) { - DatabaseValueNode *dvl = static_cast(filter); +std::unique_ptr Executor::evalNode(Table *table, std::vector>::iterator &row, Node *node) const { + if (node->node_type == NodeType::database_value) { + DatabaseValueNode *dvl = static_cast(node); ColDefNode col_def = table->get_column_def(dvl->col_name); // TODO optimize it to just get this def once - auto db_value = row->ithColum(col_def.order); + auto db_value = row->ithColumn(col_def.order); if (col_def.type == ColumnType::integer_type) { return std::make_unique(db_value->integerValue()); @@ -239,19 +291,40 @@ 
std::unique_ptr Executor::evalNode(Table *table, std::vector(db_value->stringValue()); } - } else if (filter->node_type == NodeType::int_value) { - IntValueNode *ivl = static_cast(filter); + } else if (node->node_type == NodeType::int_value) { + IntValueNode *ivl = static_cast(node); return std::make_unique(ivl->value); - } else if (filter->node_type == NodeType::float_value) { - FloatValueNode *ivl = static_cast(filter); + } else if (node->node_type == NodeType::float_value) { + FloatValueNode *ivl = static_cast(node); return std::make_unique(ivl->value); - } else if (filter->node_type == NodeType::string_value) { - StringValueNode *ivl = static_cast(filter); + } else if (node->node_type == NodeType::string_value) { + StringValueNode *ivl = static_cast(node); return std::make_unique(ivl->value); } throw Exception("invalid type"); } +bool Executor::evalLogicalOperator(LogicalOperatorNode &node, Table *pTable, + std::vector>::iterator &iter) const { + bool left = evalRelationalOperator(static_cast(*node.left), pTable, iter); + + if ((node.op == LogicalOperatorType::and_operator && !left) || (node.op == LogicalOperatorType::or_operator && left)) + return left; + + bool right = evalRelationalOperator(static_cast(*node.right), pTable, iter); + return right; +} + +std::unique_ptr Executor::evalArithmetic(ArithmeticalOperatorNode &node, Table *table, + std::vector>::iterator &row) const { + + switch (node.op) { + case ArithmeticalOperatorType::copy_value: + return evalNode(table, row, node.left.get()); + default: + throw Exception("implement me!!"); + } +} \ No newline at end of file diff --git a/executor.h b/executor.h index 0ef4689..53e491e 100644 --- a/executor.h +++ b/executor.h @@ -17,18 +17,26 @@ private: bool execute_create_table(CreateTableNode& node); bool execute_insert_into_table(InsertIntoTableNode& node); bool execute_select(SelectFromTableNode& node); + bool execute_delete(DeleteFromTableNode& node); + bool execute_update(UpdateTableNode& node); Table* 
find_table(const std::string name); private: std::vector m_tables; - bool evalWhere(const SelectFromTableNode &node, Table *table, + bool evalWhere(Node *where, Table *table, std::vector>::iterator &row) const; std::unique_ptr evalNode(Table *table, std::vector>::iterator &row, - Node *filter) const; + Node *node) const; bool evalRelationalOperator(const RelationalOperatorNode &filter, Table *table, std::vector>::iterator &row) const; + + bool evalLogicalOperator(LogicalOperatorNode &node, Table *pTable, + std::vector>::iterator &iter) const; + + std::unique_ptr evalArithmetic(ArithmeticalOperatorNode &node, Table *table, + std::vector>::iterator &row) const; }; diff --git a/lexer.cpp b/lexer.cpp index e8de4e2..fb2eddd 100644 --- a/lexer.cpp +++ b/lexer.cpp @@ -13,7 +13,7 @@ void Lexer::parse(const std::string &code) { // TODO handle empty code m_tokens.clear(); - // PERF something like this to prealocate ?? + // PERF something like this to preallocate ?? if (code.size() > 100) { m_tokens.reserve(code.size() / 10); } @@ -36,7 +36,8 @@ void Lexer::parse(const std::string &code) { if (token_type == TokenType::string_literal) match_str = stringLiteral(match_str); - m_tokens.push_back(Token{match_str, token_type}); + if (token_type != TokenType::newline) + m_tokens.push_back(Token{match_str, token_type}); } // DEBUG IT @@ -92,6 +93,14 @@ bool Lexer::isRelationalOperator(TokenType token_type) { token_type == TokenType::lesser || token_type == TokenType::lesser_equal); } +bool Lexer::isLogicalOperator(TokenType token_type) { + return (token_type == TokenType::logical_and || token_type == TokenType::logical_or); +} + +bool Lexer::isArithmeticalOperator(TokenType token_type) { + return (token_type == TokenType::plus || token_type == TokenType::minus || token_type == TokenType::multiply || token_type == TokenType::divide); +} + TokenType Lexer::type(const std::string &token) { // TODO move it to class level not to reinit it again and again std::regex int_regex("[0-9]+"); @@ 
-145,8 +154,11 @@ TokenType Lexer::type(const std::string &token) { if (token == "where") return TokenType::keyword_where; - if (token == "from") - return TokenType::keyword_from; + if (token == "from") + return TokenType::keyword_from; + + if (token == "delete") + return TokenType::keyword_delete; if (token == "table") return TokenType::keyword_table; @@ -166,8 +178,11 @@ TokenType Lexer::type(const std::string &token) { if (token == "set") return TokenType::keyword_set; - if (token == "copy") - return TokenType::keyword_copy; + if (token == "copy") + return TokenType::keyword_copy; + + if (token == "update") + return TokenType::keyword_update; if (token == "not") return TokenType::keyword_not; diff --git a/lexer.h b/lexer.h index 4098266..0a71707 100644 --- a/lexer.h +++ b/lexer.h @@ -21,6 +21,8 @@ enum class TokenType { keyword_create, keyword_table, keyword_where, + keyword_delete, + keyword_update, keyword_from, keyword_insert, keyword_into, @@ -75,7 +77,8 @@ public: TokenType prevTokenType(); static bool isRelationalOperator(TokenType token_type); - static bool isLogicalOperator(TokenType token_type); + static bool isLogicalOperator(TokenType token_type); + static bool isArithmeticalOperator(TokenType token_type); private: TokenType type(const std::string &token); diff --git a/main.cpp b/main.cpp index cc9bf4b..070e0dd 100644 --- a/main.cpp +++ b/main.cpp @@ -17,12 +17,15 @@ int main(int argc, char *argv[]) { "insert into a (i, s) values(2, 'two')", "insert into a (i, s) values(3, 'two')", "insert into a (i, s) values(4, 'four')", + "insert into a (i, s) values(5, 'five')", "select i, s from a where i > 2", "select i, s from a where i = 1", "select i, s from a where s = 'two'", - "select i, s from a where i <= 3" -// "update a set s = 'three' where i = 3" -// "delete from a where i = 3" + "select i, s from a where i <= 3 and s = 'one'", + "update a set f = 9.99 where i = 3", +// "update a set s = 'three', f = 1.0 + 2.0 where i = 3", + "select i, s, f from a 
where i = 3" +// "delete from a where i = 4", // "select i, s from a where i > 0" }; diff --git a/parser.cpp b/parser.cpp index 50f6b48..b4a5e25 100644 --- a/parser.cpp +++ b/parser.cpp @@ -15,8 +15,12 @@ std::unique_ptr Parser::parse(const std::string &code) { return parse_create_table(); } if (lexer.tokenType() == TokenType::keyword_insert) { return parse_insert_into_table(); - } if (lexer.tokenType() == TokenType::keyword_select) { - return parse_select_from_table(); + } if (lexer.tokenType() == TokenType::keyword_select) { + return parse_select_from_table(); + } if (lexer.tokenType() == TokenType::keyword_delete) { + return parse_delete_from_table(); + } if (lexer.tokenType() == TokenType::keyword_update) { + return parse_update_table(); } std::cout << "ERROR, token:" << lexer.currentToken().token_string << std::endl; @@ -120,24 +124,18 @@ std::unique_ptr Parser::parse_insert_into_table() { std::unique_ptr Parser::parse_select_from_table() { std::vector cols_names {}; - std::unique_ptr where_node; lexer.skipToken(TokenType::keyword_select); - // TODO support also numbers and expressions while (lexer.tokenType() != TokenType::keyword_from) { - // TODO add consumeToken() which returns token and advances to next token cols_names.push_back(lexer.consumeCurrentToken().token_string); lexer.skipTokenOptional(TokenType::comma); } + lexer.skipToken(TokenType::keyword_from); std::string table_name = lexer.consumeCurrentToken().token_string; - if (lexer.tokenType() == TokenType::keyword_where) { - lexer.skipToken(TokenType::keyword_where); - where_node = parse_where_clause(); - } else { - where_node = std::make_unique(); - } + std::unique_ptr where_node = parse_where_clause(); + // if (lexer.tokenType() == TokenType::keyword_order_by) {} // if (lexer.tokenType() == TokenType::keyword_offset) {} // if (lexer.tokenType() == TokenType::keyword_limit) {} @@ -145,12 +143,77 @@ std::unique_ptr Parser::parse_select_from_table() { return std::make_unique(table_name, cols_names, 
std::move(where_node)); } +std::unique_ptr Parser::parse_delete_from_table() { + lexer.skipToken(TokenType::keyword_delete); + lexer.skipToken(TokenType::keyword_from); + + std::string table_name = lexer.consumeCurrentToken().token_string; + + std::unique_ptr where_node = parse_where_clause(); + + return std::make_unique(table_name, std::move(where_node)); +} + +std::unique_ptr Parser::parse_update_table() { + lexer.skipToken(TokenType::keyword_update); + lexer.skipTokenOptional(TokenType::keyword_table); + + std::string table_name = lexer.consumeCurrentToken().token_string; + + lexer.skipToken(TokenType::keyword_set); + + std::vector cols_names; + std::vector> values; + + do { + cols_names.push_back(lexer.consumeCurrentToken().token_string); + lexer.skipToken(TokenType::equal); + + std::unique_ptr left = Parser::parse_operand_node(); + if (Lexer::isArithmeticalOperator(lexer.tokenType())) { + ArithmeticalOperatorType op = parse_arithmetical_operator(); + std::unique_ptr right = Parser::parse_operand_node(); + + values.push_back(std::make_unique(op, std::move(left), std::move(right))); + } else { + std::unique_ptr right = std::make_unique(0); + values.push_back(std::make_unique(ArithmeticalOperatorType::copy_value, std::move(left), std::move(right))); + } + lexer.skipTokenOptional(TokenType::comma); + + } while (lexer.tokenType() != TokenType::keyword_where && lexer.tokenType() != TokenType::eof); + + std::unique_ptr where_node = parse_where_clause(); + + return std::make_unique(table_name, cols_names, std::move(values), std::move(where_node)); +} + std::unique_ptr Parser::parse_where_clause() { // TODO add support for multiple filters // TODO add support for parenthesis + if (lexer.tokenType() != TokenType::keyword_where) { + return std::make_unique(); + } + + std::unique_ptr node; + lexer.skipToken(TokenType::keyword_where); + do { + node = parse_relational_expression(); + + if (Lexer::isLogicalOperator(lexer.tokenType())) { + auto operation = 
parse_logical_operator(); + std::unique_ptr node2 = parse_relational_expression(); + node = std::make_unique(operation, std::move(node), std::move(node2)); + } + } while (lexer.tokenType() != TokenType::eof); // until whole where clause parsed + + return node; +} + +std::unique_ptr Parser::parse_relational_expression() { auto left = parse_operand_node(); - auto operation = parse_operator(); + auto operation = parse_relational_operator(); auto right = parse_operand_node(); return std::make_unique(operation, std::move(left), std::move(right)); @@ -174,7 +237,7 @@ std::unique_ptr Parser::parse_operand_node() { } } -RelationalOperatorType Parser::parse_operator() { +RelationalOperatorType Parser::parse_relational_operator() { auto op = lexer.consumeCurrentToken(); switch (op.type) { case TokenType::equal: @@ -189,7 +252,28 @@ RelationalOperatorType Parser::parse_operator() { return RelationalOperatorType::lesser; case TokenType::lesser_equal: return RelationalOperatorType::lesser_equal; - default: ; + default: throw Exception("Unknown relational operator"); } +} +LogicalOperatorType Parser::parse_logical_operator() { + auto op = lexer.consumeCurrentToken(); + switch (op.type) { + case TokenType::logical_and: + return LogicalOperatorType::and_operator; + case TokenType::logical_or: + return LogicalOperatorType::or_operator; + default: + throw Exception("Unknown logical operator"); + } +} + +ArithmeticalOperatorType Parser::parse_arithmetical_operator() { + auto op = lexer.consumeCurrentToken(); + switch (op.type) { + case TokenType::plus: + return ArithmeticalOperatorType::plus_operator; + default: + throw Exception("Unknown arithmetical operator"); + } } \ No newline at end of file diff --git a/parser.h b/parser.h index 4b71028..0a755cc 100644 --- a/parser.h +++ b/parser.h @@ -21,9 +21,12 @@ enum class NodeType { database_value, logical_operator, relational_operator, + arithmetical_operator, create_table, insert_into, select_from, + delete_from, + update_table, 
column_name, column_value, column_def, @@ -101,6 +104,9 @@ struct LogicalOperatorNode : Node { LogicalOperatorType op; std::unique_ptr left; std::unique_ptr right; + + LogicalOperatorNode(LogicalOperatorType op, std::unique_ptr left, std::unique_ptr right) : + Node(NodeType::logical_operator), op(op), left(std::move(left)), right(std::move(right)) {}; }; enum class RelationalOperatorType { @@ -123,6 +129,25 @@ struct RelationalOperatorNode : Node { Node(NodeType::relational_operator), op(op), left(std::move(left)), right(std::move(right)) {}; }; +enum class ArithmeticalOperatorType { + copy_value, // just copy left value and do nothing with it + plus_operator, + minus_operator, + multiply_operator, + divide_operator +}; + +struct ArithmeticalOperatorNode : Node { + ArithmeticalOperatorType op; + + std::unique_ptr left; + std::unique_ptr right; + + ArithmeticalOperatorNode(ArithmeticalOperatorType op, std::unique_ptr left, std::unique_ptr right) : + Node(NodeType::arithmetical_operator), op(op), left(std::move(left)), right(std::move(right)) {}; +}; + + struct CreateTableNode : Node { std::string table_name; std::vector cols_defs; @@ -145,12 +170,29 @@ struct SelectFromTableNode : Node { std::vector cols_names; std::unique_ptr where; - SelectFromTableNode(const std::string name, std::vector names, std::unique_ptr where_clause) : + SelectFromTableNode(std::string name, std::vector names, std::unique_ptr where_clause) : Node(NodeType::select_from), table_name(name), cols_names(names), where(std::move(where_clause)) {} }; -struct UpdateTableNode : Node { }; -struct DeleteFromTableNode : Node { }; +struct UpdateTableNode : Node { + std::string table_name; + std::vector cols_names; + std::vector> values; + std::unique_ptr where; + + UpdateTableNode(std::string name, std::vector names, std::vector> vals, + std::unique_ptr where_clause) : + Node(NodeType::update_table), table_name(name), cols_names(names), values(std::move(vals)), where(std::move(where_clause)) {} +}; +
+struct DeleteFromTableNode : Node { + std::string table_name; + std::unique_ptr where; + + DeleteFromTableNode(const std::string name, std::unique_ptr where_clause) : + Node(NodeType::delete_from), table_name(name), where(std::move(where_clause)) {} + +}; @@ -167,11 +209,18 @@ private: std::unique_ptr parse_create_table(); std::unique_ptr parse_insert_into_table(); std::unique_ptr parse_select_from_table(); + std::unique_ptr parse_delete_from_table(); + std::unique_ptr parse_update_table(); std::unique_ptr parse_where_clause(); std::unique_ptr parse_operand_node(); - RelationalOperatorType parse_operator(); + RelationalOperatorType parse_relational_operator(); + LogicalOperatorType parse_logical_operator(); + ArithmeticalOperatorType parse_arithmetical_operator(); private: Lexer lexer; + + std::unique_ptr parse_relational_expression(); + }; diff --git a/row.h b/row.h index 474cb6e..41fc34e 100644 --- a/row.h +++ b/row.h @@ -108,7 +108,7 @@ public: return *m_columns[i]; } - ColValue* ithColum(int i) { + ColValue* ithColumn(int i) { return m_columns[i].get(); }