diff --git a/row.cpp b/row.cpp index 66001e5..71561c2 100644 --- a/row.cpp +++ b/row.cpp @@ -3,6 +3,26 @@ namespace usql { + int ColNullValue::compare(ColValue * other) { + return other->isNull() ? 0 : -1; // null goes to end + } + + int ColIntegerValue::compare(ColValue * other) { + return other->isNull() ? 1 : m_integer - other->getIntValue(); // null goes to end + } + + int ColDoubleValue::compare(ColValue * other) { + if (other->isNull()) { // null goes to end + return 1; + } + double c = m_double - other->getDoubleValue(); + return c < 0 ? -1 : c == 0.0 ? 0 : 1; + } + + int ColStringValue::compare(ColValue * other) { + return other->isNull() ? 1 : m_string.compare(other->getStringValue()); // null goes to end + } + Row::Row(int cols_count) { m_columns.reserve(cols_count); for (int i = 0; i < cols_count; i++) { diff --git a/row.h b/row.h index 3bd92ae..edbc641 100644 --- a/row.h +++ b/row.h @@ -11,16 +11,24 @@ namespace usql { struct ColValue { virtual bool isNull() { return false; }; - virtual long getIntValue() { throw Exception("Not supported"); }; - virtual double getDoubleValue() { throw Exception("Not supported"); }; - virtual std::string getStringValue() { throw Exception("Not supported"); }; + virtual long getIntValue() = 0; + virtual double getDoubleValue() = 0; + virtual std::string getStringValue() = 0; + + virtual int compare(ColValue * other) = 0; + + virtual ~ColValue() = default; }; struct ColNullValue : ColValue { - virtual bool isNull() { return true; }; - virtual std::string getStringValue() { return "null"; }; + bool isNull() override { return true; }; + long getIntValue() override { throw Exception("Not supported"); }; + double getDoubleValue() override { throw Exception("Not supported"); }; + std::string getStringValue() override { return "null"; }; + + int compare(ColValue * other) override; }; @@ -29,11 +37,13 @@ namespace usql { ColIntegerValue(long value) : m_integer(value) {}; ColIntegerValue(const ColIntegerValue &other) : m_integer(other.m_integer) {}; - virtual long getIntValue() { return m_integer; }; - virtual double getDoubleValue() { return (double) m_integer; }; - virtual std::string getStringValue() { return std::to_string(m_integer); }; + long getIntValue() override { return m_integer; }; + double getDoubleValue() override { return (double) m_integer; }; + std::string getStringValue() override { return std::to_string(m_integer); }; - int m_integer; + int compare(ColValue * other) override; + + long m_integer; }; @@ -42,9 +52,11 @@ namespace usql { ColDoubleValue(double value) : m_double(value) {}; ColDoubleValue(const ColDoubleValue &other) : m_double(other.m_double) {} - virtual long getIntValue() { return (long) m_double; }; - virtual double getDoubleValue() { return m_double; }; - virtual std::string getStringValue() { return std::to_string(m_double); }; + long getIntValue() override { return (long) m_double; }; + double getDoubleValue() override { return m_double; }; + std::string getStringValue() override { return std::to_string(m_double); }; + + int compare(ColValue * other) override; double m_double; }; @@ -52,12 +64,14 @@ namespace usql { struct ColStringValue : ColValue { - ColStringValue(const std::string value) : m_string(value) {}; + ColStringValue(const std::string &value) : m_string(value) {}; ColStringValue(const ColStringValue &other) : m_string(other.m_string) {}; - virtual long getIntValue() { return std::stoi(m_string); }; - virtual double getDoubleValue() { return std::stod(m_string); }; - virtual std::string getStringValue() { return m_string; }; + long getIntValue() override { return std::stoi(m_string); }; + double getDoubleValue() override { return std::stod(m_string); }; + std::string getStringValue() override { return m_string; }; + + int compare(ColValue * other) override; std::string m_string; }; diff --git a/usql.cpp b/usql.cpp index d54e447..c9b6ac3 100644 --- a/usql.cpp +++ b/usql.cpp @@ -208,11 +208,8 @@ void USql::execute_order_by(SelectFromTableNode &node, Table *table, Table *resu ColValue *a_val = a.ith_column(col_def.order); ColValue *b_val = b.ith_column(col_def.order); - if (a_val->isNull() && b_val->isNull()) return true; // both is null so a goes to end - if (!a_val->isNull() && b_val->isNull()) return true; // b is null so goes to end - if (a_val->isNull() && !b_val->isNull()) return false; // a is null so goes to end + int compare = a_val->compare(b_val); - int compare = compare_col_values(col_def, a_val, b_val); if (compare < 0) return order_by_col_def.ascending ? true : false; if (compare > 0) return order_by_col_def.ascending ? false : true; } @@ -231,21 +228,6 @@ void USql::execute_offset_limit(OffsetLimitNode &node, Table *result) const { result->m_rows.erase(result->m_rows.begin() + node.limit, result->m_rows.end()); } -int USql::compare_col_values(const ColDefNode &col_def, ColValue *a_val, ColValue *b_val) const { - double c; - switch (col_def.type) { - case (ColumnType::integer_type): - return a_val->getIntValue() - b_val->getIntValue(); - case (ColumnType::float_type): - c = a_val->getDoubleValue() - b_val->getDoubleValue(); - return c < 0 ? -1 : c==0.0 ? 0 : 1; - case (ColumnType::varchar_type): - return a_val->getStringValue().compare(b_val->getStringValue()); - default: - throw Exception("Unsupported data type"); - } -} - std::tuple USql::get_column_definition(Table *table, SelectColNode *select_col_node, int col_order ) { std::string new_col_name = select_col_node->name; diff --git a/usql.h b/usql.h index 0178b79..2969532 100644 --- a/usql.h +++ b/usql.h @@ -55,8 +55,6 @@ private: Parser m_parser; std::list m_tables; - int compare_col_values(const ColDefNode &col_def, ColValue *a_val, ColValue *b_val) const; - void execute_order_by(SelectFromTableNode &node, Table *table, Table *result) const; void execute_offset_limit(OffsetLimitNode &node, Table *result) const;