usql update

This commit is contained in:
2021-12-19 13:33:47 +01:00
parent 37d0d9b3f5
commit 5c925f2608
23 changed files with 1570 additions and 1124 deletions

View File

@@ -1,87 +1,118 @@
#include "usql.h"
#include "exception.h"
#include "ml_date.h"
#include "ml_string.h"
#include <algorithm>
#include <fstream>
namespace usql {
std::unique_ptr<Table> USql::execute_select(SelectFromTableNode &node) {
// find source table
Table *table = find_table(node.table_name);
std::pair<bool, std::vector<rowid_t>> USql::probe_index_scan(const Node *where, Table *table) const {
bool indexscan_possible = normalize_where(where);
// expand *
expand_asterix_char(node, table);
// create result table
std::vector<ColDefNode> result_tbl_col_defs{};
std::vector<int> source_table_col_index{};
for (int i = 0; i < node.cols_names->size(); i++) {
SelectColNode * col_node = &node.cols_names->operator[](i);
auto [src_tbl_col_index, rst_tbl_col_def] = get_column_definition(table, col_node, i);
source_table_col_index.push_back(src_tbl_col_index);
result_tbl_col_defs.push_back(rst_tbl_col_def);
if (indexscan_possible && Settings::get_bool_setting("USE_INDEXSCAN")) {
// where->dump();
return look_for_usable_index(where, table);
}
// check for aggregate function
bool aggregate_funcs = check_for_aggregate_only_functions(node, result_tbl_col_defs.size());
// prepare result table structure
auto result = std::make_unique<Table>("result", result_tbl_col_defs);
// replace possible order by col names to col indexes and validate
setup_order_columns(node.order_by, result.get());
// execute access plan
Row* new_row = nullptr;
for (auto row = begin(table->m_rows); row != end(table->m_rows); ++row) {
// eval where for row
if (eval_where(node.where.get(), table, *row)) {
// prepare empty row and copy column values
// when agregate functions in result only one row for table
if (!aggregate_funcs || result->rows_count()==0) {
new_row = &result->create_empty_row();
}
for (auto idx = 0; idx < result->columns_count(); idx++) {
auto src_table_col_idx = source_table_col_index[idx];
if (src_table_col_idx == FUNCTION_CALL) {
auto evaluated_value = eval_value_node(table, *row, node.cols_names->operator[](idx).value.get(), &result_tbl_col_defs[idx], &new_row->operator[](idx));
ValueNode *col_value = evaluated_value.get();
new_row->setColumnValue(&result_tbl_col_defs[idx], col_value);
} else {
ColValue &col_value = row->operator[](src_table_col_idx);
new_row->setColumnValue(&result_tbl_col_defs[idx], col_value);
}
}
// add row to result
if (aggregate_funcs == 0) {
result->commit_row(*new_row);
}
}
}
// when aggregates commit this one row
if (aggregate_funcs && new_row != nullptr) {
result->commit_row(*new_row);
}
execute_distinct(node, result.get());
execute_order_by(node, table, result.get());
execute_offset_limit(node.offset_limit, result.get());
return result;
// no index scan
return std::make_pair(false, std::vector<rowid_t>{});
}
bool USql::check_for_aggregate_only_functions(SelectFromTableNode &node, int result_cols_cnt) const {
std::pair<bool, std::vector<rowid_t>> USql::look_for_usable_index(const Node *where, Table *table) const {
if (where->node_type == NodeType::relational_operator) {
auto * ron = (RelationalOperatorNode *)where;
// TODO implement >, >=, <=, <
// https://en.cppreference.com/w/cpp/container/map/upper_bound
if (ron->op == RelationalOperatorType::equal) {
if (ron->left->node_type == NodeType::database_value &&
((ron->right->node_type == NodeType::int_value) || (ron->right->node_type == NodeType::string_value))
) {
auto col_name = ((DatabaseValueNode *)ron->left.get())->col_name;
Index * used_index = table->get_index_for_column(col_name);
if (used_index != nullptr) {
std::vector<rowid_t> rowids = used_index->search((ValueNode *)ron->right.get());
#ifndef NDEBUG
std::cout << "using index " << table->m_name << "(" << used_index->get_column_name() << "), " << rowids.size() << "/" << table->rows_count() << std::endl;
#endif
return std::make_pair(true, rowids);
}
}
}
} else if (where->node_type == NodeType::logical_operator) {
auto * operatorNode = (LogicalOperatorNode *)where;
if (operatorNode->op == LogicalOperatorType::and_operator) {
auto [use_index, rowids] = look_for_usable_index(operatorNode->left.get(), table);
if (use_index) {
return std::make_pair(true, rowids);
}
return look_for_usable_index(operatorNode->right.get(), table);
}
}
// no index available
return std::make_pair(false, std::vector<rowid_t>{});
}
bool USql::normalize_where(const Node *node) const {
// normalize relational operators "layout" and check whether index scan even possible
// unify relational operators tha left node is always database value
if (node->node_type == NodeType::relational_operator) {
// TODO more optimizations here, for example node 1 = 2 etc
auto * ron = (RelationalOperatorNode *)node;
if (ron->right->node_type == NodeType::database_value && ((ron->left->node_type == NodeType::int_value) || (ron->left->node_type == NodeType::string_value)) ) {
std::swap(ron->left, ron->right);
}
return true;
} else if (node->node_type == NodeType::logical_operator) {
auto * operatorNode = (LogicalOperatorNode *)node;
if (operatorNode->op == LogicalOperatorType::or_operator) {
return false;
}
bool left_subnode = normalize_where(operatorNode->left.get());
bool right_subnode = normalize_where(operatorNode->left.get());
return left_subnode && right_subnode;
}
return true;
}
void USql::select_row(SelectFromTableNode &where_node,
Table *src_table, Row *src_row,
Table *rslt_table,
const std::vector<ColDefNode> &rslt_tbl_col_defs,
const std::vector<int> &src_table_col_index,
bool is_aggregated) {
Row *rslt_row = nullptr;
// when aggregate functions in rslt_table only one row exists
if (is_aggregated && !rslt_table->empty())
rslt_row = &rslt_table->m_rows[0];
else
rslt_row = &rslt_table->create_empty_row();
for (auto idx = 0; idx < rslt_table->columns_count(); idx++) {
auto src_table_col_idx = src_table_col_index[idx];
if (src_table_col_idx == FUNCTION_CALL) {
auto evaluated_value = eval_value_node(src_table, *src_row, where_node.cols_names->operator[](idx).value.get(),
const_cast<ColDefNode *>(&rslt_tbl_col_defs[idx]), &rslt_row->operator[](idx));
ValueNode *col_value = evaluated_value.get();
rslt_row->setColumnValue((ColDefNode *) &rslt_tbl_col_defs[idx], col_value);
} else {
ColValue &col_value = src_row->operator[](src_table_col_idx);
rslt_row->setColumnValue((ColDefNode *) &rslt_tbl_col_defs[idx], col_value);
}
}
// for aggregate is validated more than needed
rslt_table->commit_row(*rslt_row);
}
bool USql::check_for_aggregate_only_functions(SelectFromTableNode &node, size_t result_cols_cnt) {
int aggregate_funcs = 0;
for (int i = 0; i < node.cols_names->size(); i++) {
SelectColNode * col_node = &node.cols_names->operator[](i);
@@ -99,7 +130,7 @@ bool USql::check_for_aggregate_only_functions(SelectFromTableNode &node, int res
return aggregate_funcs > 0;
}
void USql::expand_asterix_char(SelectFromTableNode &node, Table *table) const {
void USql::expand_asterix_char(SelectFromTableNode &node, Table *table) {
if (node.cols_names->size() == 1 && node.cols_names->operator[](0).name == "*") {
node.cols_names->clear();
node.cols_names->reserve(table->columns_count());
@@ -109,7 +140,7 @@ void USql::expand_asterix_char(SelectFromTableNode &node, Table *table) const {
}
}
void USql::setup_order_columns(std::vector<ColOrderNode> &node, Table *table) const {
void USql::setup_order_columns(std::vector<ColOrderNode> &node, Table *table) {
for (auto& order_node : node) {
if (!order_node.col_name.empty()) {
ColDefNode col_def = table->get_column_def(order_node.col_name);
@@ -120,19 +151,19 @@ void USql::setup_order_columns(std::vector<ColOrderNode> &node, Table *table) co
if (order_node.col_index < 0 || order_node.col_index >= table->columns_count())
throw Exception("unknown column in order by clause (" + order_node.col_name + ")");
}
}
}
void USql::execute_distinct(SelectFromTableNode &node, Table *result) {
if (!node.distinct) return;
auto compare_rows = [](const Row &a, const Row &b) { return a.compare(b) >= 0; };
std::sort(result->m_rows.begin(), result->m_rows.end(), compare_rows);
result->m_rows.erase(std::unique(result->m_rows.begin(), result->m_rows.end()), result->m_rows.end());
}
void USql::execute_order_by(SelectFromTableNode &node, Table *table, Table *result) {
void USql::execute_order_by(SelectFromTableNode &node, Table *result) {
if (node.order_by.empty()) return;
auto compare_rows = [&node, &result](const Row &a, const Row &b) {
@@ -160,6 +191,21 @@ void USql::execute_offset_limit(OffsetLimitNode &node, Table *result) {
result->m_rows.erase(result->m_rows.begin() + node.limit, result->m_rows.end());
}
bool USql::eval_where(Node *where, Table *table, Row &row)
{
switch (where->node_type)
{
case NodeType::true_node:
return true;
case NodeType::relational_operator: // just one condition
return eval_relational_operator(*((RelationalOperatorNode *)where), table, row);
case NodeType::logical_operator:
return eval_logical_operator(*((LogicalOperatorNode *)where), table, row);
default:
throw Exception("Wrong node type");
}
}
std::tuple<int, ColDefNode> USql::get_column_definition(Table *table, SelectColNode *select_col_node, int col_order ) {
return get_node_definition(table, select_col_node->value.get(), select_col_node->name, col_order );
}
@@ -218,7 +264,7 @@ std::tuple<int, ColDefNode> USql::get_node_definition(Table *table, Node * node,
auto [left_col_index, left_tbl_col_def] = get_node_definition(table, ari_node->left.get(), col_name, col_order );
auto [right_col_index, right_tbl_col_def] = get_node_definition(table, ari_node->right.get(), col_name, col_order );
ColumnType col_type; // TODO handle varchar and it len
ColumnType col_type; // TODO handle varchar and its len
if (left_tbl_col_def.type==ColumnType::float_type || right_tbl_col_def.type==ColumnType::float_type)
col_type = ColumnType::float_type;
else
@@ -249,8 +295,7 @@ std::tuple<int, ColDefNode> USql::get_node_definition(Table *table, Node * node,
std::unique_ptr<Table> USql::execute_insert_into_table(InsertIntoTableNode &node) {
std::unique_ptr<Table> USql::execute_insert_into_table(const InsertIntoTableNode &node) {
// find table
Table *table_def = find_table(node.table_name);
@@ -276,45 +321,52 @@ std::unique_ptr<Table> USql::execute_insert_into_table(InsertIntoTableNode &node
std::unique_ptr<Table> USql::execute_delete(DeleteFromTableNode &node) {
std::unique_ptr<Table> USql::execute_delete(const DeleteFromTableNode &node) {
size_t affected_rows = 0;
// find source table
Table *table = find_table(node.table_name);
// execute access plan
auto affected_rows = table->rows_count();
Table::rows_scanner i = get_iterator(table, node.where.get());
while(Row *row = i.next()) {
if (eval_where(node.where.get(), table, *row)) {
row->set_deleted();
table->unindex_row(*row);
table->m_rows.erase(
std::remove_if(table->m_rows.begin(), table->m_rows.end(),
[&node, table](Row &row){return eval_where(node.where.get(), table, row);}),
table->m_rows.end());
affected_rows -= table->rows_count();
affected_rows++;
}
}
return create_stmt_result_table(0, "delete succeeded", affected_rows);
}
std::unique_ptr<Table> USql::execute_update(UpdateTableNode &node) {
std::unique_ptr<Table> USql::execute_update(const UpdateTableNode &node) {
size_t affected_rows = 0;
// find source table
Table *table = find_table(node.table_name);
// execute access plan
int affected_rows = 0;
for (auto row = begin(table->m_rows); row != end(table->m_rows); ++row) {
// eval where for row
Table::rows_scanner i = get_iterator(table, node.where.get());
while(Row *row = i.next()) {
if (eval_where(node.where.get(), table, *row)) {
int i = 0;
Row old_row = * row;
int col_idx = 0;
for (const auto& col : node.cols_names) {
// TODO cache it like in select
// PERF cache it like in select
ColDefNode col_def = table->get_column_def(col.col_name);
std::unique_ptr<ValueNode> new_val = eval_arithmetic_operator(col_def.type,
static_cast<ArithmeticalOperatorNode &>(*node.values[i]),
table, *row);
static_cast<ArithmeticalOperatorNode &>(*node.values[col_idx]), table, *row);
usql::Table::validate_column(&col_def, new_val.get());
row->setColumnValue(&col_def, new_val.get());
i++;
col_idx++;
}
table->reindex_row(old_row, *row);
affected_rows++;
// TODO tady je problem, ze kdyz to zfajluje na jednom radku ostatni by se nemely provest
}
@@ -324,20 +376,58 @@ std::unique_ptr<Table> USql::execute_update(UpdateTableNode &node) {
}
bool USql::eval_where(Node *where, Table *table, Row &row) {
switch (where->node_type) {
case NodeType::true_node:
return true;
case NodeType::relational_operator: // just one condition
return eval_relational_operator(*((RelationalOperatorNode *) where), table, row);
case NodeType::logical_operator:
return eval_logical_operator(*((LogicalOperatorNode *) where), table, row);
default:
throw Exception("Wrong node type");
std::unique_ptr<Table> USql::execute_select(SelectFromTableNode &node) const {
// find source table
Table *table = find_table(node.table_name);
// expand *
expand_asterix_char(node, table);
// create result table
std::vector<ColDefNode> result_tbl_col_defs{};
std::vector<int> source_table_col_index{};
for (int i = 0; i < node.cols_names->size(); i++) {
SelectColNode *col_node = &node.cols_names->operator[](i);
auto [src_tbl_col_index, rst_tbl_col_def] = get_column_definition(table, col_node, i);
source_table_col_index.push_back(src_tbl_col_index);
result_tbl_col_defs.push_back(rst_tbl_col_def);
}
return false;
// check for aggregate function
bool is_aggregated = check_for_aggregate_only_functions(node, result_tbl_col_defs.size());
// prepare result table structure
auto result = std::make_unique<Table>("result", result_tbl_col_defs);
// replace possible order by col names to col indexes and validate
setup_order_columns(node.order_by, result.get());
// execute access plan
Table::rows_scanner i = get_iterator(table, node.where.get());
while(Row *row = i.next()) {
if (eval_where(node.where.get(), table, *row)) { // put it into row_scanner.next
select_row(node, table, row, result.get(), result_tbl_col_defs, source_table_col_index, is_aggregated);
}
}
execute_distinct(node, result.get());
execute_order_by(node, result.get());
execute_offset_limit(node.offset_limit, result.get());
return result;
}
Table::rows_scanner USql::get_iterator(Table *table, const Node *where) const {
auto[use_index, rowids] = probe_index_scan(where, table);
if (use_index)
return Table::rows_scanner(table, rowids);
else
return Table::rows_scanner(table);
}
} // namespace