423 lines
17 KiB
C++
423 lines
17 KiB
C++
#include "usql.h"
|
|
#include "exception.h"
|
|
#include "ml_date.h"
|
|
#include "ml_string.h"
|
|
|
|
#include <algorithm>
|
|
#include <fstream>
|
|
|
|
namespace usql {
|
|
|
|
std::unique_ptr<Table> USql::execute(const std::string &command) {
|
|
try {
|
|
std::unique_ptr<Node> node = m_parser.parse(command);
|
|
return execute(*node);
|
|
|
|
} catch (std::exception &e) {
|
|
return create_stmt_result_table(-1, e.what(), 0);
|
|
}
|
|
|
|
}
|
|
|
|
// Executes an already-parsed statement by dispatching on its node type to the
// matching execute_* handler. Node types without a handler produce a result
// table with status -1 and the message "unknown statement".
std::unique_ptr<Table> USql::execute(Node &node) {
    // TODO optimize execution nodes here
    const NodeType statement_type = node.node_type;

    switch (statement_type) {
        case NodeType::create_table:
            return execute_create_table(static_cast<CreateTableNode &>(node));
        case NodeType::create_table_as_select:
            return execute_create_table_as_table(static_cast<CreateTableAsSelectNode &>(node));
        case NodeType::drop_table:
            return execute_drop(static_cast<DropTableNode &>(node));
        case NodeType::insert_into:
            return execute_insert_into_table(static_cast<InsertIntoTableNode &>(node));
        case NodeType::select_from:
            return execute_select(static_cast<SelectFromTableNode &>(node));
        case NodeType::delete_from:
            return execute_delete(static_cast<DeleteFromTableNode &>(node));
        case NodeType::update_table:
            return execute_update(static_cast<UpdateTableNode &>(node));
        case NodeType::load_table:
            return execute_load(static_cast<LoadIntoTableNode &>(node));
        case NodeType::save_table:
            return execute_save(static_cast<SaveTableNode &>(node));
        case NodeType::set:
            return execute_set(static_cast<SetNode &>(node));
        case NodeType::show:
            return execute_show(static_cast<ShowNode &>(node));
        default:
            break;
    }

    // Reached only for node types that are not executable statements.
    return create_stmt_result_table(-1, "unknown statement", 0);
}
|
|
|
|
|
|
bool USql::eval_relational_operator(const RelationalOperatorNode &filter, Table *table, Row &row) {
|
|
std::unique_ptr<ValueNode> left_value = eval_value_node(table, row, filter.left.get(), nullptr, nullptr);
|
|
std::unique_ptr<ValueNode> right_value = eval_value_node(table, row, filter.right.get(), nullptr, nullptr);
|
|
|
|
double comparator;
|
|
|
|
if (left_value->node_type == NodeType::null_value || right_value->node_type == NodeType::null_value) {
|
|
bool all_null = (left_value->isNull() && right_value->node_type == NodeType::null_value) ||
|
|
(right_value->isNull() && left_value->node_type == NodeType::null_value);
|
|
if (filter.op == RelationalOperatorType::is)
|
|
return all_null;
|
|
if (filter.op == RelationalOperatorType::is_not)
|
|
return !all_null;
|
|
return false;
|
|
} else if (left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::int_value) {
|
|
comparator = left_value->getIntegerValue() - right_value->getIntegerValue();
|
|
} else if ((left_value->node_type == NodeType::int_value && right_value->node_type == NodeType::float_value) ||
|
|
(left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::int_value) ||
|
|
(left_value->node_type == NodeType::float_value && right_value->node_type == NodeType::float_value)) {
|
|
comparator = left_value->getDoubleValue() - right_value->getDoubleValue();
|
|
} else if (left_value->node_type == NodeType::bool_value || right_value->node_type == NodeType::bool_value) {
|
|
bool bl = left_value->getBooleanValue();
|
|
bool br = right_value->getBooleanValue();
|
|
comparator = bl == br ? 0 : 1;
|
|
} else if (left_value->node_type == NodeType::string_value || right_value->node_type == NodeType::string_value) {
|
|
comparator = left_value->getStringValue().compare(right_value->getStringValue());
|
|
// date values are essentially int values so handled above
|
|
} else {
|
|
throw Exception("Undefined combination of types");
|
|
}
|
|
|
|
switch (filter.op) {
|
|
case RelationalOperatorType::equal:
|
|
return comparator == 0.0;
|
|
case RelationalOperatorType::not_equal:
|
|
return comparator != 0.0;
|
|
case RelationalOperatorType::greater:
|
|
return comparator > 0.0;
|
|
case RelationalOperatorType::greater_equal:
|
|
return comparator >= 0.0;
|
|
case RelationalOperatorType::lesser:
|
|
return comparator < 0.0;
|
|
case RelationalOperatorType::lesser_equal:
|
|
return comparator <= 0.0;
|
|
case RelationalOperatorType::is:
|
|
case RelationalOperatorType::is_not:
|
|
// already handled
|
|
throw Exception("error in is or is not handling");
|
|
}
|
|
|
|
throw Exception("invalid relational operator");
|
|
}
|
|
|
|
|
|
// Evaluates an expression node for one row and returns a freshly allocated value node.
//
// - database_value: reads the referenced column from `row`.
// - int/float/string/bool literals: returns a copy of the literal.
// - function: evaluates the call; `col_def_node` and `agg_func_value` are only
//   forwarded here (they carry aggregate context, and may be nullptr).
// - null_value: returns a NullValueNode.
// - arithmetical_operator: evaluates it with a float output type.
//
// Throws Exception("unsupported node type") for any other node type.
std::unique_ptr<ValueNode> USql::eval_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value) {
    switch (node->node_type) {
        case NodeType::database_value:
            return eval_database_value_node(table, row, node);

        case NodeType::int_value:
        case NodeType::float_value:
        case NodeType::string_value:
        case NodeType::bool_value:
            return eval_literal_value_node(table, row, node);

        case NodeType::function:
            return eval_function_value_node(table, row, node, col_def_node, agg_func_value);

        case NodeType::null_value:
            return std::make_unique<NullValueNode>();

        case NodeType::arithmetical_operator:
            return eval_arithmetic_operator(ColumnType::float_type, static_cast<ArithmeticalOperatorNode &>(*node), table, row);

        default:
            break;
    }

    throw Exception("unsupported node type");
}
|
|
|
|
|
|
// Reads the column referenced by a DatabaseValueNode from `row` and wraps it
// in a value node matching the column's declared type.
//
// NULL cells yield a NullValueNode. Integer and date columns are both read via
// getIntValue() into an IntValueNode; float, varchar and bool columns map to
// DoubleValueNode, StringValueNode and BooleanValueNode respectively.
//
// Throws Exception("unknown database value type") for any other column type.
std::unique_ptr<ValueNode> USql::eval_database_value_node(Table *table, Row &row, Node *node) {
    auto *column_ref = static_cast<DatabaseValueNode *>(node);
    ColDefNode column = table->get_column_def(column_ref->col_name); // TODO optimize it to just get this def once
    ColValue &cell = row[column.order];

    if (cell.isNull())
        return std::make_unique<NullValueNode>();

    switch (column.type) {
        case ColumnType::integer_type:
        case ColumnType::date_type: // dates are read through the integer accessor
            return std::make_unique<IntValueNode>(cell.getIntValue());
        case ColumnType::float_type:
            return std::make_unique<DoubleValueNode>(cell.getDoubleValue());
        case ColumnType::varchar_type:
            return std::make_unique<StringValueNode>(cell.getStringValue());
        case ColumnType::bool_type:
            return std::make_unique<BooleanValueNode>(cell.getBoolValue());
        default:
            break;
    }

    throw Exception("unknown database value type");
}
|
|
|
|
|
|
// Returns a fresh copy of a literal value node (int, float, string or bool).
// `table` and `row` are unused; they keep the signature parallel with the
// other eval_* helpers.
//
// Throws Exception("invalid type") for any non-literal node type.
std::unique_ptr<ValueNode> USql::eval_literal_value_node(Table *table, Row &row, Node *node) {
    switch (node->node_type) {
        case NodeType::int_value:
            return std::make_unique<IntValueNode>(static_cast<IntValueNode *>(node)->value);
        case NodeType::float_value:
            return std::make_unique<DoubleValueNode>(static_cast<DoubleValueNode *>(node)->value);
        case NodeType::string_value:
            return std::make_unique<StringValueNode>(static_cast<StringValueNode *>(node)->value);
        case NodeType::bool_value:
            return std::make_unique<BooleanValueNode>(static_cast<BooleanValueNode *>(node)->value);
        default:
            break;
    }

    // There is no dedicated date literal node: a date literal arrives as a string
    // literal, which to_date() converts into an int value.
    throw Exception("invalid type");
}
|
|
|
|
|
|
std::unique_ptr<ValueNode>
|
|
USql::eval_function_value_node(Table *table, Row &row, Node *node, ColDefNode *col_def_node, ColValue *agg_func_value) {
|
|
auto *fnc = static_cast<FunctionNode *>(node);
|
|
|
|
std::vector<std::unique_ptr<ValueNode>> evaluatedPars;
|
|
for(auto & param : fnc->params) {
|
|
evaluatedPars.push_back(eval_value_node(table, row, param.get(), nullptr, nullptr));
|
|
}
|
|
|
|
// at this moment no functions without parameter(s) or first param can be null
|
|
if (evaluatedPars.empty() || evaluatedPars[0]->isNull())
|
|
return std::make_unique<NullValueNode>();
|
|
|
|
// TODO use some enum
|
|
if (fnc->function == "lower") return lower_function(evaluatedPars);
|
|
if (fnc->function == "upper") return upper_function(evaluatedPars);
|
|
if (fnc->function == "to_date") return to_date_function(evaluatedPars);
|
|
if (fnc->function == "to_string") return to_string_function(evaluatedPars);
|
|
if (fnc->function == "pp") return pp_function(evaluatedPars);
|
|
if (fnc->function == "count") return count_function(agg_func_value, evaluatedPars);
|
|
if (fnc->function == "max") return max_function(evaluatedPars, col_def_node, agg_func_value);
|
|
if (fnc->function == "min") return min_function(evaluatedPars, col_def_node, agg_func_value);
|
|
|
|
throw Exception("invalid function: " + fnc->function);
|
|
}
|
|
|
|
// Aggregate step for count(): returns the running count incremented by one.
//
// `agg_func_value` holds the count accumulated so far; when it is NULL this is
// the first counted row and the result is 1. `evaluatedPars` is not inspected
// here (a NULL first argument is already turned into NULL by the caller).
//
// Throws Exception when there is no accumulator, e.g. when count() is evaluated
// outside an aggregate context such as a WHERE clause (eval_relational_operator
// evaluates operands with a nullptr accumulator). Previously this dereferenced
// the null pointer.
std::unique_ptr<ValueNode> USql::count_function(ColValue *agg_func_value, const std::vector<std::unique_ptr<ValueNode>> &evaluatedPars) {
    if (agg_func_value == nullptr)
        throw Exception("count function can only be used as an aggregate");

    long c = 1;
    if (!agg_func_value->isNull()) {
        c = agg_func_value->getIntValue() + 1;
    }

    return std::make_unique<IntValueNode>(c);
}
|
|
|
|
|
|
// Evaluates a logical AND / OR node for one row with short-circuiting.
//
// The left operand is always evaluated. For AND a false left side, and for OR a
// true left side, decides the result without evaluating the right operand.
// Otherwise the result is the right operand's value.
bool USql::eval_logical_operator(LogicalOperatorNode &node, Table *pTable, Row &row) {
    const bool lhs = eval_where(&(*node.left), pTable, row);

    switch (node.op) {
        case LogicalOperatorType::and_operator:
            if (!lhs) return lhs;
            break;
        case LogicalOperatorType::or_operator:
            if (lhs) return lhs;
            break;
        default:
            break;
    }

    return eval_where(&(*node.right), pTable, row);
}
|
|
|
|
|
|
// Evaluates an arithmetical operator node for one row, producing a value of
// type `outType`.
//
// copy_value simply returns the evaluated left operand. Otherwise both operands
// are evaluated; if either is NULL the result is NULL.
// - float_type:   +, -, *, / on doubles (IEEE semantics, x / 0.0 gives inf/NaN).
// - integer_type: +, -, *, / on longs; integer division by zero throws.
// - varchar_type: + concatenates the string forms of the operands.
//
// Throws Exception for unsupported operator/type combinations and for integer
// division by zero.
std::unique_ptr<ValueNode> USql::eval_arithmetic_operator(ColumnType outType, ArithmeticalOperatorNode &node, Table *table, Row &row) {
    if (node.op == ArithmeticalOperatorType::copy_value) {
        return eval_value_node(table, row, node.left.get(), nullptr, nullptr);
    }

    std::unique_ptr<ValueNode> left = eval_value_node(table, row, node.left.get(), nullptr, nullptr);
    std::unique_ptr<ValueNode> right = eval_value_node(table, row, node.right.get(), nullptr, nullptr);

    if (left->isNull() || right->isNull())
        return std::make_unique<NullValueNode>();

    if (outType == ColumnType::float_type) {
        const double l = left->getDoubleValue();
        const double r = right->getDoubleValue();
        switch (node.op) {
            case ArithmeticalOperatorType::plus_operator:
                return std::make_unique<DoubleValueNode>(l + r);
            case ArithmeticalOperatorType::minus_operator:
                return std::make_unique<DoubleValueNode>(l - r);
            case ArithmeticalOperatorType::multiply_operator:
                return std::make_unique<DoubleValueNode>(l * r);
            case ArithmeticalOperatorType::divide_operator:
                return std::make_unique<DoubleValueNode>(l / r);
            default:
                throw Exception("implement me!!");
        }

    } else if (outType == ColumnType::integer_type) {
        const long l = left->getIntegerValue();
        const long r = right->getIntegerValue();
        switch (node.op) {
            case ArithmeticalOperatorType::plus_operator:
                return std::make_unique<IntValueNode>(l + r);
            case ArithmeticalOperatorType::minus_operator:
                return std::make_unique<IntValueNode>(l - r);
            case ArithmeticalOperatorType::multiply_operator:
                return std::make_unique<IntValueNode>(l * r);
            case ArithmeticalOperatorType::divide_operator:
                // Integer division by zero is undefined behaviour (typically SIGFPE);
                // report it as a query error instead of crashing the process.
                if (r == 0)
                    throw Exception("division by zero");
                return std::make_unique<IntValueNode>(l / r);
            default:
                throw Exception("implement me!!");
        }

    } else if (outType == ColumnType::varchar_type) {
        const std::string l = left->getStringValue();
        const std::string r = right->getStringValue();
        switch (node.op) {
            case ArithmeticalOperatorType::plus_operator:
                return std::make_unique<StringValueNode>(l + r);
            default:
                throw Exception("implement me!!");
        }
    }

    // TODO date node should support addition and subtraction

    throw Exception("implement me!!");
}
|
|
|
|
|
|
std::unique_ptr<ValueNode> USql::to_string_function(const std::vector<std::unique_ptr<ValueNode>> &evaluatedPars) {
|
|
long date = evaluatedPars[0]->getDateValue();
|
|
std::string format = evaluatedPars[1]->getStringValue();
|
|
std::string formatted_date = date_to_string(date, format);
|
|
return std::make_unique<StringValueNode>(formatted_date);
|
|
}
|
|
|
|
std::unique_ptr<ValueNode> USql::to_date_function(const std::vector<std::unique_ptr<ValueNode>> &evaluatedPars) {
|
|
std::string date = evaluatedPars[0]->getStringValue();
|
|
std::string format = evaluatedPars[1]->getStringValue();
|
|
long epoch_time = string_to_date(date, format);
|
|
return std::make_unique<IntValueNode>(epoch_time); // No DateValueNode for now
|
|
}
|
|
|
|
std::unique_ptr<ValueNode> USql::upper_function(const std::vector<std::unique_ptr<ValueNode>> &evaluatedPars) {
|
|
std::string str = evaluatedPars[0]->getStringValue();
|
|
std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return toupper(c); });
|
|
return std::make_unique<StringValueNode>(str);
|
|
}
|
|
|
|
std::unique_ptr<ValueNode> USql::lower_function(const std::vector<std::unique_ptr<ValueNode>> &evaluatedPars) {
|
|
std::string str = evaluatedPars[0]->getStringValue();
|
|
std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) -> unsigned char { return tolower(c); });
|
|
return std::make_unique<StringValueNode>(str);
|
|
}
|
|
|
|
std::unique_ptr<ValueNode> USql::pp_function(const std::vector<std::unique_ptr<ValueNode>> &evaluatedPars) {
|
|
auto &parsed_value = evaluatedPars[0];
|
|
|
|
if (parsed_value->node_type == NodeType::int_value || parsed_value->node_type == NodeType::float_value) {
|
|
std::string format = evaluatedPars.size() > 1 ? evaluatedPars[1]->getStringValue() : "";
|
|
char buf[20] {0}; // TODO constant here
|
|
double value = parsed_value->getDoubleValue();
|
|
|
|
if (format == "100%")
|
|
std::snprintf(buf, 20, "%.2f%%", value);
|
|
else if (format == "%.2f")
|
|
std::snprintf(buf, 20, "%.2f", value);
|
|
else if (value >= 1000000000000)
|
|
std::snprintf(buf, 20, "%7.2fT", value/1000000000000);
|
|
else if (value >= 1000000000)
|
|
std::sprintf(buf, "%7.2fB", value/1000000000);
|
|
else if (value >= 1000000)
|
|
std::snprintf(buf, 20, "%7.2fM", value/1000000);
|
|
else if (value >= 100000)
|
|
std::snprintf(buf, 20, "%7.2fM", value/100000); // 0.12M
|
|
else if (value <= -1000000000000)
|
|
std::snprintf(buf, 20, "%7.2fT", value/1000000000000);
|
|
else if (value <= -1000000000)
|
|
std::snprintf(buf, 20, "%7.2fB", value/1000000000);
|
|
else if (value <= -1000000)
|
|
std::snprintf(buf, 20, "%7.2fM", value/1000000);
|
|
else if (value <= -100000)
|
|
std::snprintf(buf, 20, "%7.2fM", value/100000); // 0.12M
|
|
else if (value == 0)
|
|
buf[0]='0';
|
|
else
|
|
return std::make_unique<StringValueNode>(parsed_value->getStringValue().substr(0, 10));
|
|
// TODO introduce constant for 10
|
|
std::string s {buf};
|
|
return std::make_unique<StringValueNode>(string_padd(s.erase(s.find_last_not_of(" ")+1), 10, ' ', false));
|
|
}
|
|
return std::make_unique<StringValueNode>(parsed_value->getStringValue());
|
|
}
|
|
|
|
std::unique_ptr<ValueNode>
|
|
USql::max_function(const std::vector<std::unique_ptr<ValueNode>> &evaluatedPars, const ColDefNode *col_def_node,
|
|
ColValue *agg_func_value) {
|
|
if (col_def_node->type == ColumnType::integer_type || col_def_node->type == ColumnType::date_type) {
|
|
if (!evaluatedPars[0]->isNull()) {
|
|
long val = evaluatedPars[0]->getIntegerValue();
|
|
if (agg_func_value->isNull()) {
|
|
return std::make_unique<IntValueNode>(val);
|
|
} else {
|
|
return std::make_unique<IntValueNode>(std::max(val, agg_func_value->getIntValue()));
|
|
}
|
|
} else {
|
|
return std::make_unique<IntValueNode>(agg_func_value->getIntValue());
|
|
}
|
|
} else if (col_def_node->type == ColumnType::float_type) {
|
|
if (!evaluatedPars[0]->isNull()) {
|
|
double val = evaluatedPars[0]->getDoubleValue();
|
|
if (agg_func_value->isNull()) {
|
|
return std::make_unique<DoubleValueNode>(val);
|
|
} else {
|
|
return std::make_unique<DoubleValueNode>(std::max(val, agg_func_value->getDoubleValue()));
|
|
}
|
|
} else {
|
|
return std::make_unique<DoubleValueNode>(agg_func_value->getDoubleValue());
|
|
}
|
|
}
|
|
|
|
// TODO string and boolean
|
|
throw Exception("unsupported data type for max function");
|
|
}
|
|
|
|
std::unique_ptr<ValueNode>
|
|
USql::min_function(const std::vector<std::unique_ptr<ValueNode>> &evaluatedPars, const ColDefNode *col_def_node,
|
|
ColValue *agg_func_value) {
|
|
if (col_def_node->type == ColumnType::integer_type || col_def_node->type == ColumnType::date_type) {
|
|
if (!evaluatedPars[0]->isNull()) {
|
|
long val = evaluatedPars[0]->getIntegerValue();
|
|
if (agg_func_value->isNull()) {
|
|
return std::make_unique<IntValueNode>(val);
|
|
} else {
|
|
return std::make_unique<IntValueNode>(std::min(val, agg_func_value->getIntValue()));
|
|
}
|
|
} else {
|
|
return std::make_unique<IntValueNode>(agg_func_value->getIntValue());
|
|
}
|
|
} else if (col_def_node->type == ColumnType::float_type) {
|
|
if (!evaluatedPars[0]->isNull()) {
|
|
double val = evaluatedPars[0]->getDoubleValue();
|
|
if (agg_func_value->isNull()) {
|
|
return std::make_unique<DoubleValueNode>(val);
|
|
} else {
|
|
return std::make_unique<DoubleValueNode>(std::min(val, agg_func_value->getDoubleValue()));
|
|
}
|
|
} else {
|
|
return std::make_unique<DoubleValueNode>(agg_func_value->getDoubleValue());
|
|
}
|
|
}
|
|
|
|
// TODO string and boolean
|
|
throw Exception("unsupported data type for min function");
|
|
}
|
|
|
|
// Returns a pointer to the first table in m_tables whose name equals `name`.
// The pointer refers into m_tables and is not owned by the caller.
//
// Throws Exception("table not found (<name>)") when no such table exists.
Table *USql::find_table(const std::string &name) {
    for (auto &candidate : m_tables) {
        if (candidate.m_name == name)
            return &candidate;
    }
    throw Exception("table not found (" + name + ")");
}
|
|
|
|
void USql::check_table_not_exists(const std::string &name) {
|
|
auto name_cmp = [name](const Table& t) { return t.m_name == name; };
|
|
auto table_def = std::find_if(begin(m_tables), end(m_tables), name_cmp);
|
|
if (table_def != std::end(m_tables)) {
|
|
throw Exception("table already exists");
|
|
}
|
|
}
|
|
|
|
|
|
} // namespace
|