diff --git a/CMakeLists.txt b/CMakeLists.txt index a494e03..98266ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,6 +212,7 @@ include_directories( add_library(core src/core.cpp src/query.cpp + src/query_execution.cpp src/storage.cpp src/metadata.cpp src/file_utils.cpp diff --git a/include/arrow_utils.hpp b/include/arrow_utils.hpp index a876fc2..f5d305f 100644 --- a/include/arrow_utils.hpp +++ b/include/arrow_utils.hpp @@ -2,11 +2,14 @@ #define ARROW_UTILS_HPP #include -#include -#include +#include +#include namespace tundradb { +arrow::Result> get_ids_from_table( + const std::shared_ptr& table); + // Initialize Arrow Compute module - should be called once at startup bool initialize_arrow_compute(); diff --git a/include/core.hpp b/include/core.hpp index 07160ed..dbf1bc3 100644 --- a/include/core.hpp +++ b/include/core.hpp @@ -11,7 +11,6 @@ #include #include -#include #include #include #include @@ -28,6 +27,7 @@ #include "metadata.hpp" #include "node.hpp" #include "query.hpp" +#include "query_execution.hpp" #include "schema.hpp" #include "storage.hpp" #include "utils.hpp" diff --git a/include/query.hpp b/include/query.hpp index 709c932..4bf4adb 100644 --- a/include/query.hpp +++ b/include/query.hpp @@ -4,12 +4,17 @@ #include #include #include +#include +#include +#include +#include #include #include #include #include #include +#include #include #include "node.hpp" @@ -41,30 +46,9 @@ struct SchemaRef { // Parse a schema reference from a string format "alias:schema" // If the string does not contain a colon, the value is assigned to the alias // and schema - static SchemaRef parse(const std::string& s) { - SchemaRef r; - size_t pos = s.find(':'); - if (pos == std::string::npos) { - r.schema_ = s; - r.value_ = s; - r.declaration_ = false; - } else { - r.value_ = s.substr(0, pos); - r.schema_ = s.substr(pos + 1); - r.declaration_ = true; - } - return r; - } + static SchemaRef parse(const std::string& s); - [[nodiscard]] std::string toString() const { - std::stringstream ss; - if (declaration_) { - ss << value_; - ss << ":"; - } - ss << schema_; - return ss.str(); - } + [[nodiscard]] std::string toString() const; friend std::ostream& operator<<(std::ostream& os, const SchemaRef& obj) { os << obj.toString(); @@ -72,27 +56,6 @@ struct SchemaRef { } }; -struct GraphConnection { - SchemaRef source; - int64_t source_id; - std::string edge_type; - std::string label; - SchemaRef target; - int64_t target_id; - - [[nodiscard]] std::string toString() const { - std::stringstream ss; - ss << "{(" << source << ":id=" << source_id << "->[:" << edge_type << "]->" - << "(" << label << ":" << target << ":id=" << target_id << ")}"; - return ss.str(); - } - - friend std::ostream& operator<<(std::ostream& os, const GraphConnection& c) { - os << c.toString(); - return os; - } -}; - enum class CompareOp { Eq, NotEq, @@ -243,135 +206,22 @@ class ComparisonExpr : public Clause, public WhereExpr { bool inlined_ = false; static arrow::Result compare_values(const Value& value, CompareOp op, - const Value& where_value) { - if (value.type() == ValueType::NA || where_value.type() == ValueType::NA) { - switch (op) { - case CompareOp::Eq: - return value.type() == ValueType::NA && - where_value.type() == ValueType::NA; - case CompareOp::NotEq: - return value.type() != ValueType::NA || - where_value.type() != ValueType::NA; - default: - return arrow::Status::Invalid( - "Null values can only be compared with == or !="); - } - } - - if (op == CompareOp::Contains || op == CompareOp::StartsWith || - op == CompareOp::EndsWith) { - if (value.type() != ValueType::STRING || - where_value.type() != ValueType::STRING) { - return arrow::Status::Invalid( - "String operations (CONTAINS, STARTS_WITH, ENDS_WITH) can only be " - "applied to string values"); - } - } + const Value& where_value); - if (value.type() == ValueType::BOOL || - where_value.type() == ValueType::BOOL) { - if (value.type() != ValueType::BOOL || - where_value.type() != ValueType::BOOL) { - return arrow::Status::Invalid( - "Boolean values can only be compared with other boolean values"); - } - if (op != CompareOp::Eq && op != CompareOp::NotEq) { - return arrow::Status::Invalid( - "Boolean values can only be compared with == or !="); - } - } - - if (value.type() != where_value.type()) { - return arrow::Status::Invalid("Type mismatch: field is ", value.type(), - " but WHERE value is ", where_value.type()); - } - - switch (value.type()) { - case ValueType::INT32: { - int32_t field_val = value.get(); - int32_t where_val = where_value.get(); - return apply_comparison(field_val, op, where_val); - } - case ValueType::INT64: { - int64_t field_val = value.get(); - int64_t where_val = where_value.get(); - return apply_comparison(field_val, op, where_val); - } - case ValueType::FLOAT: { - float field_val = value.get(); - float where_val = where_value.get(); - return apply_comparison(field_val, op, where_val); - } - case ValueType::DOUBLE: { - double field_val = value.get(); - double where_val = where_value.get(); - return apply_comparison(field_val, op, where_val); - } - case ValueType::STRING: { - const std::string& field_val = value.as_string(); - const std::string& where_val = where_value.as_string(); - return apply_comparison(field_val, op, where_val); - } - case ValueType::BOOL: { - bool field_val = value.get(); - bool where_val = where_value.get(); - return apply_comparison(field_val, op, where_val); - } - case ValueType::NA: - return arrow::Status::Invalid("Unexpected null value in comparison"); - default: - return arrow::Status::NotImplemented( - "Unsupported value type for comparison: ", value.type()); - } - } + template + static bool apply_comparison(const T& field_val, CompareOp op, + const T& where_val); template - static bool apply_comparison(const T& field_val, const CompareOp op, - const T& where_val) { - switch (op) { - case CompareOp::Eq: - return field_val == where_val; - case CompareOp::NotEq: - return field_val != where_val; - case CompareOp::Gt: - return field_val > where_val; - case CompareOp::Lt: - return field_val < where_val; - case CompareOp::Gte: - return field_val >= where_val; - case CompareOp::Lte: - return field_val <= where_val; - case CompareOp::Contains: - if constexpr (std::is_same_v) { - return field_val.contains(where_val); - } else { - return false; - } - case CompareOp::StartsWith: - if constexpr (std::is_same_v) { - return field_val.starts_with(where_val); - } else { - return false; - } - case CompareOp::EndsWith: - if constexpr (std::is_same_v) { - return field_val.ends_with(where_val); - } else { - return false; - } - } - return false; - } + static bool apply_comparison(const T& field_val, const T& where_val, + CompareOp op); public: ComparisonExpr(FieldRef field_ref, CompareOp op, Value value) : field_ref_(std::move(field_ref)), op_(op), value_(std::move(value)) {} // Backward compatibility constructor - ComparisonExpr(const std::string& field, CompareOp op, Value value) - : field_ref_(FieldRef::from_string(field)), - op_(op), - value_(std::move(value)) {} + ComparisonExpr(const std::string& field, CompareOp op, Value value); [[nodiscard]] const FieldRef& field_ref() const { return field_ref_; } [[nodiscard]] const std::string& field() const { return field_ref_.value(); } @@ -382,200 +232,27 @@ class ComparisonExpr : public Clause, public WhereExpr { [[nodiscard]] bool inlined() const override { return inlined_; } void set_inlined(bool inlined) override { inlined_ = inlined; } - [[nodiscard]] std::string toString() const override { - std::stringstream ss; - ss << "WHERE " << field_ref_.to_string(); - - switch (op_) { - case CompareOp::Eq: - ss << " = "; - break; - case CompareOp::NotEq: - ss << " != "; - break; - case CompareOp::Gt: - ss << " > "; - break; - case CompareOp::Lt: - ss << " < "; - break; - case CompareOp::Gte: - ss << " >= "; - break; - case CompareOp::Lte: - ss << " <= "; - break; - case CompareOp::Contains: - ss << " CONTAINS "; - break; - case CompareOp::StartsWith: - ss << " STARTS_WITH "; - break; - case CompareOp::EndsWith: - ss << " ENDS_WITH "; - break; - } - - switch (value_.type()) { - case ValueType::NA: - ss << "NULL"; - break; - case ValueType::INT32: - ss << value_.get(); - break; - case ValueType::INT64: - ss << value_.get(); - break; - case ValueType::FLOAT: - ss << value_.get(); - break; - case ValueType::DOUBLE: - ss << value_.get(); - break; - case ValueType::BOOL: - ss << (value_.get() ? "true" : "false"); - break; - case ValueType::FIXED_STRING16: - case ValueType::FIXED_STRING32: - case ValueType::FIXED_STRING64: - case ValueType::STRING: - ss << "'" << value_.to_string() << "'"; - break; - } + [[nodiscard]] std::string toString() const override; - if (inlined_) { - ss << " (inlined)"; - } + friend std::ostream& operator<<(std::ostream& os, const ComparisonExpr& expr); - return ss.str(); - } - - friend std::ostream& operator<<(std::ostream& os, - const ComparisonExpr& expr) { - os << expr.toString(); - return os; - } - - arrow::Result matches( - const std::shared_ptr& node) const override { - if (!node) { - return arrow::Status::Invalid("Node is null"); - } - assert(field_ref_.field() != nullptr); - ARROW_ASSIGN_OR_RAISE(auto field_value, - node->get_value(field_ref_.field())); - return compare_values(field_value, op_, value_); - } + arrow::Result matches(const std::shared_ptr& node) const override; [[nodiscard]] arrow::compute::Expression to_arrow_expression( - bool strip_var) const override { - std::string field_name = - strip_var ? field_ref_.field_name() : field_ref_.value(); - const auto field_expr = arrow::compute::field_ref(field_name); - const auto value_expr = value_to_expression(value_); - - return apply_comparison_op(field_expr, value_expr, op_); - } + bool strip_var) const override; std::vector> get_conditions_for_variable( - const std::string& variable) const override { - if (field_ref_.variable() == variable) { - return {std::make_shared(*this)}; - } - return {}; - } + const std::string& variable) const override; - bool can_inline(const std::string& variable) const override { - return field_ref_.variable() == variable; - } + bool can_inline(const std::string& variable) const override; - std::string extract_first_variable() const override { - return field_ref_.variable(); - } + std::string extract_first_variable() const override; - std::set get_all_variables() const override { - std::set variables; - variables.insert(field_ref_.variable()); - return variables; - } + std::set get_all_variables() const override; arrow::Result resolve_field_ref( const std::unordered_map& aliases, - const SchemaRegistry* schema_registry) override { - if (field_ref_.is_resolved()) { - return true; - } - - const std::string& variable = field_ref_.variable(); - const std::string& field_name = field_ref_.field_name(); - - // Find the actual schema for this variable - auto it = aliases.find(variable); - if (it == aliases.end()) { - return arrow::Status::KeyError("Unknown variable '", variable, - "' in field '", field_ref_.to_string(), - "'"); - } - - const std::string& schema_name = it->second; - - auto schema_result = schema_registry->get(schema_name); - if (!schema_result.ok()) { - return arrow::Status::KeyError( - "Schema '", schema_name, "' not found for variable '", variable, "'"); - } - - auto schema = schema_result.ValueOrDie(); - auto field = schema->get_field(field_name); - if (!field) { - return arrow::Status::KeyError( - "Field '", field_name, "' not found in schema '", schema_name, "'"); - } - field_ref_.resolve(field); - - return true; - } - - private: - template - static bool apply_comparison(const T& field_val, const T& where_val, - CompareOp op) { - switch (op) { - case CompareOp::Eq: - return field_val == where_val; - case CompareOp::NotEq: - return field_val != where_val; - case CompareOp::Gt: - return field_val > where_val; - case CompareOp::Lt: - return field_val < where_val; - case CompareOp::Gte: - return field_val >= where_val; - case CompareOp::Lte: - return field_val <= where_val; - case CompareOp::Contains: - if constexpr (std::is_same_v) { - return field_val.find(where_val) != std::string::npos; - } else { - return false; - } - case CompareOp::StartsWith: - if constexpr (std::is_same_v) { - return field_val.find(where_val) == 0; - } else { - return false; - } - case CompareOp::EndsWith: - if constexpr (std::is_same_v) { - return field_val.size() >= where_val.size() && - field_val.substr(field_val.size() - where_val.size()) == - where_val; - } else { - return false; - } - } - return false; - } + const SchemaRegistry* schema_registry) override; }; class LogicalExpr : public Clause, public WhereExpr { @@ -592,41 +269,17 @@ class LogicalExpr : public Clause, public WhereExpr { [[nodiscard]] Type type() const override { return Type::WHERE; } [[nodiscard]] bool inlined() const override { return inlined_; } - void set_inlined(bool inlined) override { - inlined_ = inlined; - if (left_) left_->set_inlined(inlined); - if (right_) right_->set_inlined(inlined); - } + void set_inlined(bool inlined) override; arrow::Result resolve_field_ref( const std::unordered_map& aliases, - const SchemaRegistry* schema_registry) override { - if (left_) { - if (const auto res = left_->resolve_field_ref(aliases, schema_registry); - !res.ok()) { - return res.status(); - } - } - if (right_) { - if (const auto res = right_->resolve_field_ref(aliases, schema_registry); - !res.ok()) { - return res.status(); - } - } - return true; - } + const SchemaRegistry* schema_registry) override; static std::shared_ptr and_expr( - std::shared_ptr left, std::shared_ptr right) { - return std::make_shared(std::move(left), LogicalOp::AND, - std::move(right)); - } + std::shared_ptr left, std::shared_ptr right); - static std::shared_ptr or_expr( - std::shared_ptr left, std::shared_ptr right) { - return std::make_shared(std::move(left), LogicalOp::OR, - std::move(right)); - } + static std::shared_ptr or_expr(std::shared_ptr left, + std::shared_ptr right); // Public accessors [[nodiscard]] const std::shared_ptr& left() const { return left_; } @@ -635,140 +288,23 @@ class LogicalExpr : public Clause, public WhereExpr { } [[nodiscard]] LogicalOp op() const { return op_; } - arrow::Result matches( - const std::shared_ptr& node) const override { - if (!left_ || !right_) { - return arrow::Status::Invalid( - "LogicalExpr missing left or right operand"); - } - - auto left_result = left_->matches(node); - if (!left_result.ok()) { - return left_result.status(); - } - - auto right_result = right_->matches(node); - if (!right_result.ok()) { - return right_result.status(); - } - - bool left_val = left_result.ValueOrDie(); - bool right_val = right_result.ValueOrDie(); - - switch (op_) { - case LogicalOp::AND: - return left_val && right_val; - case LogicalOp::OR: - return left_val || right_val; - } - - return arrow::Status::Invalid("Unknown logical operator"); - } + arrow::Result matches(const std::shared_ptr& node) const override; [[nodiscard]] arrow::compute::Expression to_arrow_expression( - bool strip_var) const override { - if (!left_ || !right_) { - throw std::runtime_error("LogicalExpr missing left or right operand"); - } - - auto left_expr = left_->to_arrow_expression(strip_var); - auto right_expr = right_->to_arrow_expression(strip_var); - - switch (op_) { - case LogicalOp::AND: - return arrow::compute::and_(left_expr, right_expr); - case LogicalOp::OR: - return arrow::compute::or_(left_expr, right_expr); - default: - throw std::runtime_error("Unknown logical operator in LogicalExpr"); - } - } + bool strip_var) const override; std::vector> get_conditions_for_variable( - const std::string& variable) const override { - auto all_variables = get_all_variables(); - for (const auto& var : all_variables) { - if (var != variable) { - return {}; - } - } + const std::string& variable) const override; - std::vector> result; - if (left_) { - auto left_conditions = left_->get_conditions_for_variable(variable); - result.insert(result.end(), left_conditions.begin(), - left_conditions.end()); - } - if (right_) { - auto right_conditions = right_->get_conditions_for_variable(variable); - result.insert(result.end(), right_conditions.begin(), - right_conditions.end()); - } - return result; - } + std::string extract_first_variable() const override; - std::string extract_first_variable() const override { - if (left_) { - auto var = left_->extract_first_variable(); - if (!var.empty()) return var; - } - if (right_) { - auto var = right_->extract_first_variable(); - if (!var.empty()) return var; - } - return ""; - } + std::string toString() const override; - std::string toString() const override { - if (!left_ || !right_) { - return "WHERE (incomplete logical expression)"; - } + friend std::ostream& operator<<(std::ostream& os, const LogicalExpr& expr); - std::string left_str = left_->toString(); - std::string right_str = right_->toString(); - - if (left_str.substr(0, 6) == "WHERE ") { - left_str = left_str.substr(6); - } - if (right_str.substr(0, 6) == "WHERE ") { - right_str = right_str.substr(6); - } + std::set get_all_variables() const override; - std::string op_str = (op_ == LogicalOp::AND) ? " AND " : " OR "; - - std::string result = - "WHERE (" + left_str + ")" + op_str + "(" + right_str + ")"; - - if (inlined_) { - result += " (inlined)"; - } - - return result; - } - - friend std::ostream& operator<<(std::ostream& os, const LogicalExpr& expr) { - os << expr.toString(); - return os; - } - - std::set get_all_variables() const override { - std::set variables; - if (left_) { - auto left_variables = left_->get_all_variables(); - variables.insert(left_variables.begin(), left_variables.end()); - } - if (right_) { - auto right_variables = right_->get_all_variables(); - variables.insert(right_variables.begin(), right_variables.end()); - } - return variables; - } - - bool can_inline(const std::string& variable) const override { - if (left_ && !left_->can_inline(variable)) return false; - if (right_ && !right_->can_inline(variable)) return false; - return true; - } + bool can_inline(const std::string& variable) const override; }; struct ExecutionConfig { diff --git a/include/query_execution.hpp b/include/query_execution.hpp new file mode 100644 index 0000000..c711c5f --- /dev/null +++ b/include/query_execution.hpp @@ -0,0 +1,494 @@ +#ifndef QUERY_EXECUTION_HPP +#define QUERY_EXECUTION_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "query.hpp" +#include "schema.hpp" + +namespace tundradb { + +// Forward declarations +class SchemaRegistry; +class NodeManager; + +/** + * @brief Runtime connection between two nodes discovered during traversal + * + * Represents an edge in the graph that was found during query execution. + * Different from Traverse (which is part of query AST) - this is actual data. + */ +struct GraphConnection { + SchemaRef source; + int64_t source_id; + std::string edge_type; + std::string label; + SchemaRef target; + int64_t target_id; + + [[nodiscard]] std::string toString() const { + std::stringstream ss; + ss << "{(" << source << ":id=" << source_id << "->[:" << edge_type << "]->" + << "(" << label << ":" << target << ":id=" << target_id << ")}"; + return ss.str(); + } + + friend std::ostream& operator<<(std::ostream& os, const GraphConnection& c) { + os << c.toString(); + return os; + } +}; + +/** + * @brief Connection pool for reusing GraphConnection objects + */ +class ConnectionPool { + private: + std::vector pool_; + size_t next_index_ = 0; + + public: + explicit ConnectionPool(size_t initial_size = 1000) : pool_(initial_size) {} + + GraphConnection& get() { + if (next_index_ >= pool_.size()) { + pool_.resize(pool_.size() * 2); + } + return pool_[next_index_++]; + } + + void reset() { next_index_ = 0; } + size_t size() const { return next_index_; } +}; + +/** + * @brief Manages schema resolution and aliases for query execution + * + * Responsibilities: + * - Map aliases (e.g., "u") to schema names (e.g., "User") + * - Resolve SchemaRef objects to concrete schema names + * - Validate schema references + */ +class SchemaContext { + private: + std::unordered_map aliases_; + std::shared_ptr schema_registry_; + + public: + explicit SchemaContext(std::shared_ptr registry) + : schema_registry_(std::move(registry)) {} + + /** + * Register a schema alias (e.g., "u" -> "User") + */ + arrow::Result register_schema(const SchemaRef& schema_ref); + + /** + * Resolve schema reference to concrete schema name + */ + arrow::Result resolve(const SchemaRef& schema_ref) const; + + /** + * Get schema registry + */ + std::shared_ptr registry() const { return schema_registry_; } + + /** + * Get all registered aliases + */ + const std::unordered_map& get_aliases() const { + return aliases_; + } +}; + +/** + * @brief Manages graph topology during query execution + * + * Responsibilities: + * - Track active node IDs per schema + * - Store connections (edges) between nodes + * - Query connection information + */ +class GraphState { + private: + // Node IDs per schema alias + llvm::StringMap> node_ids_; + + // Outgoing connections: schema -> source_id -> [connections] + llvm::StringMap< + llvm::DenseMap>> + outgoing_; + + // Incoming connections: target_id -> [connections] + llvm::DenseMap> incoming_; + + // Connection object pool for performance + mutable ConnectionPool connection_pool_; + + public: + /** + * Get node IDs for a schema (mutable) + */ + llvm::DenseSet& ids(const std::string& schema_alias) { + return node_ids_[schema_alias]; + } + + /** + * Get node IDs for a schema (const) + */ + const llvm::DenseSet& ids(const std::string& schema_alias) const { + auto it = node_ids_.find(schema_alias); + if (it != node_ids_.end()) { + return it->second; + } + static const llvm::DenseSet empty; + return empty; + } + + /** + * Add a connection between nodes + */ + void add_connection(const GraphConnection& conn) { + auto& pool_conn = connection_pool_.get(); + pool_conn = conn; + + outgoing_[conn.source.value()][conn.source_id].push_back(pool_conn); + incoming_[conn.target_id].push_back(pool_conn); + } + + /** + * Check if node has outgoing edges + */ + bool has_outgoing(const SchemaRef& schema_ref, int64_t node_id) const { + return outgoing_.contains(schema_ref.value()) && + outgoing_.at(schema_ref.value()).contains(node_id) && + !outgoing_.at(schema_ref.value()).at(node_id).empty(); + } + + /** + * Get outgoing connections for a node + */ + const llvm::SmallVector* get_outgoing( + const std::string& schema_alias, int64_t node_id) const { + auto schema_it = outgoing_.find(schema_alias); + if (schema_it == outgoing_.end()) { + return nullptr; + } + + auto node_it = schema_it->second.find(node_id); + if (node_it == schema_it->second.end()) { + return nullptr; + } + + return &node_it->second; + } + + /** + * Remove a node from the graph + */ + void remove_node(int64_t node_id, const SchemaRef& schema_ref) { + node_ids_[schema_ref.value()].erase(node_id); + } + + /** + * Access to outgoing connections map + */ + const llvm::StringMap< + llvm::DenseMap>>& + outgoing() const { + return outgoing_; + } + + /** + * Access to incoming connections map + */ + const llvm::DenseMap>& + incoming() const { + return incoming_; + } + + /** + * Access to incoming connections map (mutable) + */ + llvm::DenseMap>& incoming() { + return incoming_; + } + + /** + * Access to connection pool + */ + ConnectionPool& connection_pool() const { return connection_pool_; } + + /** + * Get all node IDs (for backward compatibility) + */ + llvm::StringMap>& get_ids() { return node_ids_; } + + const llvm::StringMap>& get_ids() const { + return node_ids_; + } + + /** + * Get outgoing connections map (direct access to internal structure) + * Returns: schema -> node_id -> connections + */ + llvm::StringMap< + llvm::DenseMap>>& + get_outgoing_map() { + return outgoing_; + } + + const llvm::StringMap< + llvm::DenseMap>>& + get_outgoing_map() const { + return outgoing_; + } +}; + +/** + * @brief Manages field indexing for efficient row operations + * + * Responsibilities: + * - Assign unique integer IDs to fully-qualified field names + * - Map field IDs to names and vice versa + * - Compute field indices per schema + */ +class FieldIndexer { + private: + // Fully-qualified field names per schema alias + llvm::StringMap> fq_field_names_; + + // Field indices per schema + llvm::StringMap> schema_field_indices_; + + // Bidirectional mapping between field IDs and names + llvm::SmallDenseMap field_id_to_name_; + llvm::StringMap field_name_to_index_; + + // Global field ID counter + std::atomic next_field_id_{0}; + + public: + /** + * Compute fully-qualified field names for a schema + */ + arrow::Result compute_fq_names(const SchemaRef& schema_ref, + const std::string& resolved_schema, + SchemaRegistry* registry); + + /** + * Get field indices for a schema + */ + const std::vector* get_field_indices( + const std::string& schema_alias) const { + auto it = schema_field_indices_.find(schema_alias); + return it != schema_field_indices_.end() ? &it->second : nullptr; + } + + /** + * Get field name by ID + */ + const std::string& get_field_name(int field_id) const { + return field_id_to_name_.at(field_id); + } + + /** + * Get field ID by name + */ + int get_field_id(const std::string& field_name) const { + auto it = field_name_to_index_.find(field_name); + return it != field_name_to_index_.end() ? it->second : -1; + } + + /** + * Get all field ID to name mappings (for row operations) + */ + const llvm::SmallDenseMap& field_id_to_name() const { + return field_id_to_name_; + } + + /** + * Check if schema field names are already computed + */ + bool has_computed(const std::string& schema_alias) const { + return fq_field_names_.contains(schema_alias); + } + + /** + * Get schema_field_indices (for backward compatibility) + * Returns the actual internal map of schema -> field_indices + */ + llvm::StringMap>& get_schema_field_indices() { + return schema_field_indices_; + } + + const llvm::StringMap>& get_schema_field_indices() const { + return schema_field_indices_; + } + + /** + * Get field_id_to_name map (direct access to internal map) + */ + llvm::SmallDenseMap& get_field_id_to_name() { + return field_id_to_name_; + } + + const llvm::SmallDenseMap& get_field_id_to_name() + const { + return field_id_to_name_; + } + + private: + // No cached views needed - we expose the internal maps directly +}; + +/** + * @brief Query execution state container + * + * Composed of focused components: + * - SchemaContext: Schema resolution and aliases + * - GraphState: Graph topology (node IDs, connections) + * - FieldIndexer: Field indexing for efficient row operations + * - Tables: Arrow table storage + */ +struct QueryState { + // Core components + SchemaContext schemas; + GraphState graph; + FieldIndexer fields; + + // Table storage + std::unordered_map> tables; + + // Source schema for FROM clause + SchemaRef from; + + // Traversals in query + std::vector traversals; + + // Node manager for fetching nodes + std::shared_ptr node_manager; + + // Temporal context (nullptr = current version) + std::unique_ptr temporal_context; + + // Constructor + explicit QueryState(std::shared_ptr registry); + + // Convenience accessors (delegate to components) + + arrow::Result register_schema(const SchemaRef& ref) { + return schemas.register_schema(ref); + } + + arrow::Result resolve_schema(const SchemaRef& ref) const { + return schemas.resolve(ref); + } + + llvm::DenseSet& get_ids(const SchemaRef& schema_ref) { + return graph.ids(schema_ref.value()); + } + + const llvm::DenseSet& get_ids(const SchemaRef& schema_ref) const { + return graph.ids(schema_ref.value()); + } + + bool has_outgoing(const SchemaRef& ref, int64_t node_id) const { + return graph.has_outgoing(ref, node_id); + } + + arrow::Result compute_fully_qualified_names( + const SchemaRef& ref, const std::string& resolved_schema) { + return fields.compute_fq_names(ref, resolved_schema, + schemas.registry().get()); + } + + arrow::Result compute_fully_qualified_names(const SchemaRef& ref); + + void remove_node(int64_t node_id, const SchemaRef& ref) { + graph.remove_node(node_id, ref); + } + + // Backward compatibility accessors for core.cpp migration + std::shared_ptr schema_registry() const { + return schemas.registry(); + } + + const std::unordered_map& aliases() const { + return schemas.get_aliases(); + } + + llvm::StringMap>& ids() { return graph.get_ids(); } + + const llvm::StringMap>& ids() const { + return graph.get_ids(); + } + + llvm::StringMap< + llvm::DenseMap>>& + connections() { + return graph.get_outgoing_map(); + } + + const llvm::StringMap< + llvm::DenseMap>>& + connections() const { + return graph.get_outgoing_map(); + } + + // Direct access to incoming connections (by node ID) + llvm::DenseMap>& incoming() { + return graph.incoming(); + } + + const llvm::DenseMap>& + incoming() const { + return graph.incoming(); + } + + llvm::StringMap>& schema_field_indices() { + return fields.get_schema_field_indices(); + } + + const llvm::StringMap>& schema_field_indices() const { + return fields.get_schema_field_indices(); + } + + llvm::SmallDenseMap& field_id_to_name() { + return fields.get_field_id_to_name(); + } + + const llvm::SmallDenseMap& field_id_to_name() const { + return fields.get_field_id_to_name(); + } + + // Connection pool accessor (for core.cpp) + ConnectionPool& connection_pool() { return graph.connection_pool(); } + + const ConnectionPool& connection_pool() const { + return graph.connection_pool(); + } + + // Complex methods - implemented in query_execution.cpp + void reserve_capacity(const Query& query); + + arrow::Result update_table(const std::shared_ptr& table, + const SchemaRef& schema_ref); + + std::string ToString() const; +}; + +} // namespace tundradb + +#endif // QUERY_EXECUTION_HPP diff --git a/include/utils.hpp b/include/utils.hpp index e245c25..8a13592 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -85,33 +85,6 @@ static arrow::Result> filter_table_by_id( return filtered_table.table(); } -static arrow::Result> get_ids_from_table( - std::shared_ptr table) { - log_debug("Extracting IDs from table with {} rows", table->num_rows()); - - auto id_idx = table->schema()->GetFieldIndex("id"); - if (id_idx == -1) { - log_error("Table does not have an 'id' column"); - return arrow::Status::Invalid("table does not have an 'id' column"); - } - - auto id_column = table->column(id_idx); - llvm::DenseSet result_ids; - result_ids.reserve(table->num_rows()); - - for (int chunk_idx = 0; chunk_idx < id_column->num_chunks(); chunk_idx++) { - auto chunk = std::static_pointer_cast( - id_column->chunk(chunk_idx)); - log_debug("Processing chunk {} with {} rows", chunk_idx, chunk->length()); - for (int i = 0; i < chunk->length(); i++) { - result_ids.insert(chunk->Value(i)); - } - } - - log_debug("Extracted {} unique IDs from table", result_ids.size()); - return result_ids; -} - static arrow::Result> create_table( const std::shared_ptr& schema, const std::vector>& nodes, size_t chunk_size, diff --git a/src/arrow_utils.cpp b/src/arrow_utils.cpp index 46c7a77..0f01397 100644 --- a/src/arrow_utils.cpp +++ b/src/arrow_utils.cpp @@ -1,11 +1,47 @@ #include "../include/arrow_utils.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + #include #include "../include/logger.hpp" namespace tundradb { +arrow::Result> get_ids_from_table( + const std::shared_ptr& table) { + log_debug("Extracting IDs from table with {} rows", table->num_rows()); + + const auto id_idx = table->schema()->GetFieldIndex("id"); + if (id_idx == -1) { + log_error("Table does not have an 'id' column"); + return arrow::Status::Invalid("table does not have an 'id' column"); + } + + const auto id_column = table->column(id_idx); + llvm::DenseSet result_ids; + result_ids.reserve(table->num_rows()); + + for (int chunk_idx = 0; chunk_idx < id_column->num_chunks(); chunk_idx++) { + const auto chunk = std::static_pointer_cast( + id_column->chunk(chunk_idx)); + log_debug("Processing chunk {} with {} rows", chunk_idx, chunk->length()); + for (int i = 0; i < chunk->length(); i++) { + result_ids.insert(chunk->Value(i)); + } + } + + log_debug("Extracted {} unique IDs from table", result_ids.size()); + return result_ids; +} + // Initialize Arrow Compute module - should be called once at startup bool initialize_arrow_compute() { static bool initialized = false; diff --git a/src/core.cpp b/src/core.cpp index 2c1d9a3..f4bb20b 100644 --- a/src/core.cpp +++ b/src/core.cpp @@ -420,243 +420,6 @@ std::set get_roots( return roots; } -struct QueryState { - SchemaRef from; - std::unordered_map> tables; - llvm::StringMap> ids; - std::unordered_map aliases; - // Precomputed fully-qualified field names per alias (SchemaRef::value()) - llvm::StringMap> fq_field_names; - - // Field index optimization: replace string-based field lookups with integer - // indices - llvm::StringMap> - schema_field_indices; // "User" -> [0, 1, 2], "Company -> [3,4,5]" - llvm::SmallDenseMap - field_id_to_name; // 0 -> "user.name" - llvm::StringMap field_name_to_index; // "user.name" -> 0 - std::atomic next_field_id{0}; // Global field ID counter - - llvm::StringMap< - llvm::DenseMap>> - connections; // outgoing - llvm::DenseMap> incoming; - - std::shared_ptr node_manager; - std::shared_ptr schema_registry; - std::vector traversals; - - // Temporal context for time-travel queries (nullptr = current version) - std::unique_ptr temporal_context; - - // Connection object pooling to avoid repeated allocations - class ConnectionPool { - private: - std::vector pool_; - size_t next_index_ = 0; - - public: - explicit ConnectionPool(size_t initial_size = 1000) : pool_(initial_size) {} - - GraphConnection& get() { - if (next_index_ >= pool_.size()) { - pool_.resize(pool_.size() * 2); // Grow pool if needed - } - return pool_[next_index_++]; - } - - void reset() { next_index_ = 0; } // Reset for reuse - size_t size() const { return next_index_; } - }; - - mutable ConnectionPool connection_pool_; // Mutable for const methods - - // Pre-size hash maps to avoid expensive resizing during query execution - void reserve_capacity(const Query& query) { - // Estimate schema count from FROM + TRAVERSE clauses - size_t estimated_schemas = 1; // FROM clause - for (const auto& clause : query.clauses()) { - if (clause->type() == Clause::Type::TRAVERSE) { - estimated_schemas += 2; // source + target schemas - } - } - - // Pre-size standard containers (LLVM containers don't support reserve) - tables.reserve(estimated_schemas); - aliases.reserve(estimated_schemas); - - // Estimate nodes per schema (conservative estimate) - size_t estimated_nodes_per_schema = 1000; - incoming.reserve(estimated_nodes_per_schema); - - // Pre-size field mappings - field_id_to_name.reserve(estimated_schemas * 8); // ~8 fields per schema - } - - arrow::Result resolve_schema(const SchemaRef& schema_ref) { - // todo we need to separate functions: assign alias , resolve - if (aliases.contains(schema_ref.value()) && schema_ref.is_declaration()) { - IF_DEBUG_ENABLED { - log_debug("duplicated schema alias '" + schema_ref.value() + - "' already assigned to '" + aliases[schema_ref.value()] + - "'"); - } - return aliases[schema_ref.value()]; - } - if (schema_ref.is_declaration()) { - aliases[schema_ref.value()] = schema_ref.schema(); - return schema_ref.schema(); - } - return aliases[schema_ref.value()]; - } - - // Precompute fully-qualified field names for source and target aliases - arrow::Result compute_fully_qualified_names( - const SchemaRef& schema_ref) { - const auto it = aliases.find(schema_ref.value()); - if (it == aliases.end()) { - return arrow::Status::KeyError("keyset does not contain alias '{}'", - schema_ref.value()); - } - return compute_fully_qualified_names(schema_ref, it->second); - } - - // Precompute fully-qualified field names for source and target aliases - arrow::Result compute_fully_qualified_names( - const SchemaRef& schema_ref, const std::string& resolved_schema) { - const std::string& alias = schema_ref.value(); - if (fq_field_names.contains(alias)) { - return false; - } - auto schema_res = schema_registry->get(resolved_schema); - if (!schema_res.ok()) { - return schema_res.status(); - } - const auto& schema = schema_res.ValueOrDie(); - std::vector names; - std::vector indices; - names.reserve(schema->num_fields()); - indices.reserve(schema->num_fields()); - - for (const auto& f : schema->fields()) { - std::string fq_name = alias + "." + f->name(); - int field_id = next_field_id.fetch_add(1); - names.emplace_back(fq_name); - indices.emplace_back(field_id); - field_id_to_name[field_id] = fq_name; - field_name_to_index[fq_name] = field_id; - } - - fq_field_names[alias] = std::move(names); - schema_field_indices[alias] = std::move(indices); - return true; - } - - const llvm::DenseSet& get_ids(const SchemaRef& schema_ref) { - return ids[schema_ref.value()]; - } - - // removes node_id and updates all connections and ids - void remove_node(int64_t node_id, const SchemaRef& schema_ref) { - ids[schema_ref.value()].erase(node_id); - } - - arrow::Result update_table(const std::shared_ptr& table, - const SchemaRef& schema_ref) { - this->tables[schema_ref.value()] = table; - auto ids_result = get_ids_from_table(table); - if (!ids_result.ok()) { - log_error("Failed to get IDs from table: {}", schema_ref.value()); - return ids_result.status(); - } - ids[schema_ref.value()] = ids_result.ValueOrDie(); - return true; - } - - bool has_outgoing(const SchemaRef& schema_ref, int64_t node_id) const { - return connections.contains(schema_ref.value()) && - connections.at(schema_ref.value()).contains(node_id) && - !connections.at(schema_ref.value()).at(node_id).empty(); - } - - std::string ToString() const { - std::stringstream ss; - ss << "QueryState {\n"; - ss << " From: " << from.toString() << "\n"; - - ss << " Tables (" << tables.size() << "):\n"; - for (const auto& [alias, table_ptr] : tables) { - if (table_ptr) { - ss << " - " << alias << ": " << table_ptr->num_rows() << " rows, " - << table_ptr->num_columns() << " columns\n"; - } else { - ss << " - " << alias << ": (nullptr)\n"; - } - } - - ss << " IDs (" << ids.size() << "):\n"; - for (const auto& [alias, id_set] : ids) { - ss << " - " << alias.str() << ": " << id_set.size() << " IDs\n"; - } - - ss << " Aliases (" << aliases.size() << "):\n"; - for (const auto& [alias, schema_name] : aliases) { - ss << " - " << alias << " -> " << schema_name << "\n"; - } - - ss << " Connections (Outgoing) (" << connections.size() - << " source nodes):"; - for (const auto& [from, conns] : connections) { - for (const auto& [from_id, conn_vec] : conns) { - ss << "from " << from.str() << ":" << from_id << ":\n"; - for (const auto& conn : conn_vec) { - ss << " - " << conn.target.value() << ":" << conn.target_id - << "\n"; - } - } - } - - ss << " Connections (Incoming) (" << incoming.size() << " target nodes):"; - int target_nodes_printed = 0; - for (const auto& [target_id, conns_vec] : incoming) { - if (target_nodes_printed >= 3 && - incoming.size() > 5) { // Limit nodes printed - ss << " ... and " << (incoming.size() - target_nodes_printed) - << " more target nodes ...\n"; - break; - } - ss << " - Target ID " << target_id << " (" << conns_vec.size() - << " incoming):"; - int conns_printed_for_target = 0; - for (const auto& conn : conns_vec) { - if (conns_printed_for_target >= 3 && - conns_vec.size() > 5) { // Limit connections per node - ss << " ... and " - << (conns_vec.size() - conns_printed_for_target) - << " more connections ...\n"; - break; - } - ss << " <- " << conn.source.value() << ":" << conn.source_id - << " (via '" << conn.edge_type << "')\n"; - conns_printed_for_target++; - } - target_nodes_printed++; - } - - ss << " Traversals (" << traversals.size() << "):\n"; - for (size_t i = 0; i < traversals.size(); ++i) { - const auto& trav = traversals[i]; - ss << " - [" << i << "]: " << trav.source().value() << " -[" - << trav.edge_type() << "]-> " << trav.target().value() << " (Type: " - << (trav.traverse_type() == TraverseType::Inner ? "Inner" : "Other") - << ")\n"; - } - - ss << "}"; - return ss.str(); - } -}; - arrow::Result> build_denormalized_schema( const QueryState& query_state) { IF_DEBUG_ENABLED { log_debug("Building schema for denormalized table"); } @@ -673,7 +436,7 @@ arrow::Result> build_denormalized_schema( } auto schema_result = - query_state.schema_registry->get(query_state.aliases.at(from_schema)); + query_state.schema_registry()->get(query_state.aliases().at(from_schema)); if (!schema_result.ok()) { return schema_result.status(); } @@ -702,8 +465,8 @@ arrow::Result> build_denormalized_schema( log_debug("Adding fields from schema '{}'", schema_ref.value()); } - schema_result = query_state.schema_registry->get( - query_state.aliases.at(schema_ref.value())); + schema_result = query_state.schema_registry()->get( + query_state.aliases().at(schema_ref.value())); if (!schema_result.ok()) { return schema_result.status(); } @@ -1311,14 +1074,14 @@ populate_rows_bfs(int64_t node_id, const SchemaRef& start_schema, while (size-- > 0) { auto item = queue.front(); queue.pop(); - auto item_schema = item.schema_ref.is_declaration() - ? item.schema_ref.schema() - : query_state.aliases.at(item.schema_ref.value()); + ARROW_ASSIGN_OR_RAISE(const auto item_schema, + query_state.resolve_schema(item.schema_ref)); + auto node = query_state.node_manager->get_node(item_schema, item.node_id) .ValueOrDie(); const auto& it_fq = - query_state.schema_field_indices.find(item.schema_ref.value()); - if (it_fq == query_state.schema_field_indices.end()) { + query_state.schema_field_indices().find(item.schema_ref.value()); + if (it_fq == query_state.schema_field_indices().end()) { log_error("No fully-qualified field names for schema '{}'", item.schema_ref.value()); return arrow::Status::KeyError( @@ -1338,12 +1101,13 @@ populate_rows_bfs(int64_t node_id, const SchemaRef& start_schema, bool skip = false; if (query_state.has_outgoing(item.schema_ref, item.node_id)) { - for (const auto& conn : - query_state.connections.at(item.schema_ref.value()) - .at(item.node_id)) { + for (const auto& conn : query_state.connections() + .at(item.schema_ref.value()) + .at(item.node_id)) { const uint64_t tgt_packed = hash_code_(conn.target, conn.target_id); if (!item.path_visited_nodes.contains(tgt_packed)) { - if (query_state.ids.at(conn.target.value()) + if (query_state.ids() + .at(conn.target.value()) .contains(conn.target_id)) { grouped_connections[conn.target.value()].push_back(conn); } else { @@ -1411,7 +1175,7 @@ populate_rows_bfs(int64_t node_id, const SchemaRef& start_schema, tree.insert_row(r_copy); } IF_DEBUG_ENABLED { tree.print(); } - auto merged = tree.merge_rows(query_state.field_id_to_name); + auto merged = tree.merge_rows(query_state.field_id_to_name()); IF_DEBUG_ENABLED { for (const auto& row : merged) { log_debug("merge result: {}", row->ToString()); @@ -1546,13 +1310,13 @@ arrow::Result>>> populate_rows( static_cast(join_type)); } - if (!query_state.ids.contains(schema_ref.value())) { + if (!query_state.ids().contains(schema_ref.value())) { log_warn("Schema '{}' not found in query state IDs", schema_ref.value()); continue; } // Get all nodes for this schema - const auto& schema_nodes = query_state.ids.at(schema_ref.value()); + const auto& schema_nodes = query_state.ids().at(schema_ref.value()); std::vector> batch_ids; if (execution_config.parallel_enabled) { size_t batch_size = 0; @@ -1888,7 +1652,7 @@ arrow::Status prepare_query(Query& query, QueryState& query_state) { // Phase 1: Process FROM clause to populate aliases { ARROW_ASSIGN_OR_RAISE(auto from_schema, - query_state.resolve_schema(query.from())); + query_state.register_schema(query.from())); // FROM clause already processed in main query() function } @@ -1899,9 +1663,9 @@ arrow::Status prepare_query(Query& query, QueryState& query_state) { // Resolve schemas and populate aliases ARROW_ASSIGN_OR_RAISE(auto source_schema, - query_state.resolve_schema(traverse->source())); + query_state.register_schema(traverse->source())); ARROW_ASSIGN_OR_RAISE(auto target_schema, - query_state.resolve_schema(traverse->target())); + query_state.register_schema(traverse->target())); if (!traverse->source().is_declaration()) { traverse->mutable_source().set_schema(source_schema); @@ -1926,7 +1690,7 @@ arrow::Status prepare_query(Query& query, QueryState& query_state) { if (clause->type() == Clause::Type::WHERE) { auto where_expr = std::dynamic_pointer_cast(clause); auto res = where_expr->resolve_field_ref( - query_state.aliases, query_state.schema_registry.get()); + query_state.aliases(), query_state.schema_registry().get()); if (!res.ok()) { return res.status(); } @@ -1962,7 +1726,7 @@ void dense_difference(const SetA& a, const SetB& b, OutSet& out) { arrow::Result> Database::query( const Query& query) const { - QueryState query_state; + QueryState query_state(this->schema_registry_); auto result = std::make_shared(); // Initialize temporal context if AS OF clause is present @@ -1984,7 +1748,6 @@ arrow::Result> Database::query( query.from().toString()); } query_state.node_manager = this->node_manager_; - query_state.schema_registry = this->schema_registry_; query_state.from = query.from(); { @@ -1995,7 +1758,7 @@ arrow::Result> Database::query( query_state.from = query.from(); query_state.from.set_tag(compute_tag(query_state.from)); ARROW_ASSIGN_OR_RAISE(auto source_schema, - query_state.resolve_schema(query.from())); + query_state.register_schema(query.from())); if (!this->schema_registry_->exists(source_schema)) { log_error("schema '{}' doesn't exist", source_schema); return arrow::Status::KeyError("schema doesn't exit: {}", source_schema); @@ -2100,27 +1863,16 @@ arrow::Result> Database::query( // Tags and schemas are already set during preparation phase // Get resolved schemas using const resolve_schema (read-only) - auto source_schema = - traverse->source().is_declaration() - ? traverse->source().schema() - : query_state.aliases.at(traverse->source().value()); - auto target_schema = - traverse->target().is_declaration() - ? traverse->target().schema() - : query_state.aliases.at(traverse->target().value()); - + ARROW_ASSIGN_OR_RAISE(const auto source_schema, + query_state.resolve_schema(traverse->source())); + ARROW_ASSIGN_OR_RAISE(const auto target_schema, + query_state.resolve_schema(traverse->target())); // Fully-qualified field names should also be precomputed during // preparation - if (auto res = query_state.compute_fully_qualified_names( - traverse->source(), source_schema); - !res.ok()) { - return res.status(); - } - if (auto res = query_state.compute_fully_qualified_names( - traverse->target(), target_schema); - !res.ok()) { - return res.status(); - } + ARROW_RETURN_NOT_OK(query_state.compute_fully_qualified_names( + traverse->source(), source_schema)); + ARROW_RETURN_NOT_OK(query_state.compute_fully_qualified_names( + traverse->target(), target_schema)); std::vector> where_clauses; if (query.inline_where()) { @@ -2151,12 +1903,12 @@ arrow::Result> Database::query( IF_DEBUG_ENABLED { log_debug("Traversing from {} source nodes", - query_state.ids[source.value()].size()); + query_state.ids()[source.value()].size()); } llvm::DenseSet matched_source_ids; llvm::DenseSet matched_target_ids; llvm::DenseSet unmatched_source_ids; - for (auto source_id : query_state.ids[source.value()]) { + for (auto source_id : query_state.ids()[source.value()]) { auto outgoing_edges = edge_store_->get_outgoing_edges(source_id, traverse->edge_type()) .ValueOrDie(); // todo check result @@ -2168,8 +1920,9 @@ arrow::Result> Database::query( bool source_had_match = false; for (const auto& edge : outgoing_edges) { auto target_id = edge->get_target_id(); - if (query_state.ids.contains(traverse->target().value()) && - !query_state.ids.at(traverse->target().value()) + if (query_state.ids().contains(traverse->target().value()) && + !query_state.ids() + .at(traverse->target().value()) .contains(target_id)) { continue; } @@ -2203,7 +1956,7 @@ arrow::Result> Database::query( } matched_target_ids.insert(target_node->id); // Use connection pool to avoid allocation - auto& conn = query_state.connection_pool_.get(); + auto& conn = query_state.connection_pool().get(); conn.source = traverse->source(); conn.source_id = source_id; conn.edge_type = traverse->edge_type(); @@ -2211,9 +1964,10 @@ arrow::Result> Database::query( conn.target = traverse->target(); conn.target_id = target_node->id; - query_state.connections[traverse->source().value()][source_id] + query_state + .connections()[traverse->source().value()][source_id] .push_back(conn); - query_state.incoming[target_node->id].push_back(conn); + query_state.incoming()[target_node->id].push_back(conn); } } } else { @@ -2238,12 +1992,15 @@ arrow::Result> Database::query( query_state.remove_node(id, source); } IF_DEBUG_ENABLED { - log_debug("rebuild table for schema {}:{}", source.value(), - query_state.aliases[source.value()]); + auto resolved = query_state.resolve_schema(source); + if (resolved.ok()) { + log_debug("rebuild table for schema {}:{}", source.value(), + resolved.ValueOrDie()); + } } auto table_result = filter_table_by_id(query_state.tables[source.value()], - query_state.ids[source.value()]); + query_state.ids()[source.value()]); if (!table_result.ok()) { return table_result.status(); } @@ -2269,7 +2026,7 @@ arrow::Result> Database::query( dense_intersection(target_ids, matched_target_ids, intersect_ids); } - query_state.ids[traverse->target().value()] = intersect_ids; + query_state.ids()[traverse->target().value()] = intersect_ids; IF_DEBUG_ENABLED { log_debug("intersect_ids count: {}", intersect_ids.size()); log_debug("{} intersect_ids: {}", traverse->target().toString(), @@ -2277,7 +2034,7 @@ arrow::Result> Database::query( } } else if (traverse->traverse_type() == TraverseType::Left) { - query_state.ids[traverse->target().value()].insert( + query_state.ids()[traverse->target().value()].insert( matched_target_ids.begin(), matched_target_ids.end()); } else { // Right, Full: matched targets + unmatched targets auto target_ids = @@ -2319,11 +2076,11 @@ arrow::Result> Database::query( } } - query_state.ids[traverse->target().value()] = result; + query_state.ids()[traverse->target().value()] = result; } std::vector> neighbors; - for (auto id : query_state.ids[traverse->target().value()]) { + for (auto id : query_state.ids()[traverse->target().value()]) { auto node_res = node_manager_->get_node(target_schema, id); if (node_res.ok()) { neighbors.push_back(node_res.ValueOrDie()); diff --git a/src/query.cpp b/src/query.cpp index 43d68d6..90be46d 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -1,8 +1,43 @@ #include "query.hpp" +#include +#include +#include +#include + +#include "logger.hpp" + namespace tundradb { -// FieldRef implementation +// ================== SchemaRef Implementation ================== + +SchemaRef SchemaRef::parse(const std::string& s) { + SchemaRef r; + size_t pos = s.find(':'); + if (pos == std::string::npos) { + r.schema_ = s; + r.value_ = s; + r.declaration_ = false; + } else { + r.value_ = s.substr(0, pos); + r.schema_ = s.substr(pos + 1); + r.declaration_ = true; + } + return r; +} + +std::string SchemaRef::toString() const { + std::stringstream ss; + if (declaration_) { + ss << value_; + ss << ":"; + } + ss << schema_; + return ss.str(); +} + +// ================== FieldRef Implementation ================== + FieldRef FieldRef::from_string(const std::string& field_str) { const size_t dot_pos = field_str.find('.'); if (dot_pos != std::string::npos) { @@ -11,13 +46,539 @@ FieldRef FieldRef::from_string(const std::string& field_str) { // Return unresolved FieldRef - will be resolved later in query processing return {variable, field_name}; - } else { - // No variable prefix, treat entire string as field name - return {"", field_str}; } + // No variable prefix, treat entire string as field name + return {"", field_str}; +} + +// ================== ComparisonExpr Implementation ================== + +ComparisonExpr::ComparisonExpr(const std::string& field, CompareOp op, + Value value) + : field_ref_(FieldRef::from_string(field)), + op_(op), + value_(std::move(value)) {} + +arrow::Result ComparisonExpr::compare_values(const Value& value, + CompareOp op, + const Value& where_value) { + if (value.type() == ValueType::NA || where_value.type() == ValueType::NA) { + switch (op) { + case CompareOp::Eq: + return value.type() == ValueType::NA && + where_value.type() == ValueType::NA; + case CompareOp::NotEq: + return value.type() != ValueType::NA || + where_value.type() != ValueType::NA; + default: + return arrow::Status::Invalid( + "Null values can only be compared with == or !="); + } + } + + if (op == CompareOp::Contains || op == CompareOp::StartsWith || + op == CompareOp::EndsWith) { + if (value.type() != ValueType::STRING || + where_value.type() != ValueType::STRING) { + return arrow::Status::Invalid( + "String operations (CONTAINS, STARTS_WITH, ENDS_WITH) can only be " + "applied to string values"); + } + } + + if (value.type() == ValueType::BOOL || + where_value.type() == ValueType::BOOL) { + if (value.type() != ValueType::BOOL || + where_value.type() != ValueType::BOOL) { + return arrow::Status::Invalid( + "Boolean values can only be compared with other boolean values"); + } + if (op != CompareOp::Eq && op != CompareOp::NotEq) { + return arrow::Status::Invalid( + "Boolean values can only be compared with == or !="); + } + } + + if (value.type() != where_value.type()) { + return arrow::Status::Invalid("Type mismatch: field is ", value.type(), + " but WHERE value is ", where_value.type()); + } + + switch (value.type()) { + case ValueType::INT32: { + int32_t field_val = value.get(); + int32_t where_val = where_value.get(); + return apply_comparison(field_val, op, where_val); + } + case ValueType::INT64: { + int64_t field_val = value.get(); + int64_t where_val = where_value.get(); + return apply_comparison(field_val, op, where_val); + } + case ValueType::FLOAT: { + float field_val = value.get(); + float where_val = where_value.get(); + return apply_comparison(field_val, op, where_val); + } + case ValueType::DOUBLE: { + double field_val = value.get(); + double where_val = where_value.get(); + return apply_comparison(field_val, op, where_val); + } + case ValueType::STRING: { + const std::string& field_val = value.as_string(); + const std::string& where_val = where_value.as_string(); + return apply_comparison(field_val, op, where_val); + } + case ValueType::BOOL: { + bool field_val = value.get(); + bool where_val = where_value.get(); + return apply_comparison(field_val, op, where_val); + } + case ValueType::NA: + return arrow::Status::Invalid("Unexpected null value in comparison"); + default: + return arrow::Status::NotImplemented( + "Unsupported value type for comparison: ", value.type()); + } +} + +template +bool ComparisonExpr::apply_comparison(const T& field_val, CompareOp op, + const T& where_val) { + switch (op) { + case CompareOp::Eq: + return field_val == where_val; + case CompareOp::NotEq: + return field_val != where_val; + case CompareOp::Gt: + return field_val > where_val; + case CompareOp::Lt: + return field_val < where_val; + case CompareOp::Gte: + return field_val >= where_val; + case CompareOp::Lte: + return field_val <= where_val; + case CompareOp::Contains: + if constexpr (std::is_same_v) { + return field_val.contains(where_val); + } else { + return false; + } + case CompareOp::StartsWith: + if constexpr (std::is_same_v) { + return field_val.starts_with(where_val); + } else { + return false; + } + case CompareOp::EndsWith: + if constexpr (std::is_same_v) { + return field_val.ends_with(where_val); + } else { + return false; + } + } + return false; +} + +// Explicit template instantiations +template bool ComparisonExpr::apply_comparison(const int32_t&, + CompareOp, + const int32_t&); +template bool ComparisonExpr::apply_comparison(const int64_t&, + CompareOp, + const int64_t&); +template bool ComparisonExpr::apply_comparison(const float&, CompareOp, + const float&); +template bool ComparisonExpr::apply_comparison(const double&, CompareOp, + const double&); +template bool ComparisonExpr::apply_comparison(const std::string&, + CompareOp, + const std::string&); +template bool ComparisonExpr::apply_comparison(const bool&, CompareOp, + const bool&); + +template +bool ComparisonExpr::apply_comparison(const T& field_val, const T& where_val, + CompareOp op) { + switch (op) { + case CompareOp::Eq: + return field_val == where_val; + case CompareOp::NotEq: + return field_val != where_val; + case CompareOp::Gt: + return field_val > where_val; + case CompareOp::Lt: + return field_val < where_val; + case CompareOp::Gte: + return field_val >= where_val; + case CompareOp::Lte: + return field_val <= where_val; + case CompareOp::Contains: + if constexpr (std::is_same_v) { + return field_val.find(where_val) != std::string::npos; + } else { + return false; + } + case CompareOp::StartsWith: + if constexpr (std::is_same_v) { + return field_val.find(where_val) == 0; + } else { + return false; + } + case CompareOp::EndsWith: + if constexpr (std::is_same_v) { + return field_val.size() >= where_val.size() && + field_val.substr(field_val.size() - where_val.size()) == + where_val; + } else { + return false; + } + } + return false; +} + +// Explicit template instantiations for the second overload +template bool ComparisonExpr::apply_comparison(const int32_t&, + const int32_t&, + CompareOp); +template bool ComparisonExpr::apply_comparison(const int64_t&, + const int64_t&, + CompareOp); +template bool ComparisonExpr::apply_comparison(const float&, + const float&, CompareOp); +template bool ComparisonExpr::apply_comparison(const double&, + const double&, + CompareOp); +template bool ComparisonExpr::apply_comparison(const std::string&, + const std::string&, + CompareOp); +template bool ComparisonExpr::apply_comparison(const bool&, const bool&, + CompareOp); + +std::string ComparisonExpr::toString() const { + std::stringstream ss; + ss << "WHERE " << field_ref_.to_string(); + + switch (op_) { + case CompareOp::Eq: + ss << " = "; + break; + case CompareOp::NotEq: + ss << " != "; + break; + case CompareOp::Gt: + ss << " > "; + break; + case CompareOp::Lt: + ss << " < "; + break; + case CompareOp::Gte: + ss << " >= "; + break; + case CompareOp::Lte: + ss << " <= "; + break; + case CompareOp::Contains: + ss << " CONTAINS "; + break; + case CompareOp::StartsWith: + ss << " STARTS_WITH "; + break; + case CompareOp::EndsWith: + ss << " ENDS_WITH "; + break; + } + + switch (value_.type()) { + case ValueType::NA: + ss << "NULL"; + break; + case ValueType::INT32: + ss << value_.get(); + break; + case ValueType::INT64: + ss << value_.get(); + break; + case ValueType::FLOAT: + ss << value_.get(); + break; + case ValueType::DOUBLE: + ss << value_.get(); + break; + case ValueType::BOOL: + ss << (value_.get() ? "true" : "false"); + break; + case ValueType::FIXED_STRING16: + case ValueType::FIXED_STRING32: + case ValueType::FIXED_STRING64: + case ValueType::STRING: + ss << "'" << value_.to_string() << "'"; + break; + } + + if (inlined_) { + ss << " (inlined)"; + } + + return ss.str(); +} + +std::ostream& operator<<(std::ostream& os, const ComparisonExpr& expr) { + os << expr.toString(); + return os; +} + +arrow::Result ComparisonExpr::matches( + const std::shared_ptr& node) const { + if (!node) { + return arrow::Status::Invalid("Node is null"); + } + assert(field_ref_.field() != nullptr); + ARROW_ASSIGN_OR_RAISE(auto field_value, node->get_value(field_ref_.field())); + return compare_values(field_value, op_, value_); +} + +arrow::compute::Expression ComparisonExpr::to_arrow_expression( + bool strip_var) const { + std::string field_name = + strip_var ? field_ref_.field_name() : field_ref_.value(); + const auto field_expr = arrow::compute::field_ref(field_name); + const auto value_expr = value_to_expression(value_); + + return apply_comparison_op(field_expr, value_expr, op_); +} + +std::vector> +ComparisonExpr::get_conditions_for_variable(const std::string& variable) const { + if (field_ref_.variable() == variable) { + return {std::make_shared(*this)}; + } + return {}; +} + +bool ComparisonExpr::can_inline(const std::string& variable) const { + return field_ref_.variable() == variable; +} + +std::string ComparisonExpr::extract_first_variable() const { + return field_ref_.variable(); +} + +std::set ComparisonExpr::get_all_variables() const { + std::set variables; + variables.insert(field_ref_.variable()); + return variables; +} + +arrow::Result ComparisonExpr::resolve_field_ref( + const std::unordered_map& aliases, + const SchemaRegistry* schema_registry) { + if (field_ref_.is_resolved()) { + return true; + } + + const std::string& variable = field_ref_.variable(); + const std::string& field_name = field_ref_.field_name(); + + // Find the actual schema for this variable + auto it = aliases.find(variable); + if (it == aliases.end()) { + return arrow::Status::KeyError("Unknown variable '", variable, + "' in field '", field_ref_.to_string(), "'"); + } + + const std::string& schema_name = it->second; + + auto schema_result = schema_registry->get(schema_name); + if (!schema_result.ok()) { + return arrow::Status::KeyError("Schema '", schema_name, + "' not found for variable '", variable, "'"); + } + + auto schema = schema_result.ValueOrDie(); + auto field = schema->get_field(field_name); + if (!field) { + return arrow::Status::KeyError("Field '", field_name, + "' not found in schema '", schema_name, "'"); + } + field_ref_.resolve(field); + + return true; +} + +// ================== LogicalExpr Implementation ================== + +void LogicalExpr::set_inlined(bool inlined) { + inlined_ = inlined; + if (left_) left_->set_inlined(inlined); + if (right_) right_->set_inlined(inlined); +} + +arrow::Result LogicalExpr::resolve_field_ref( + const std::unordered_map& aliases, + const SchemaRegistry* schema_registry) { + if (left_) { + if (const auto res = left_->resolve_field_ref(aliases, schema_registry); + !res.ok()) { + return res.status(); + } + } + if (right_) { + if (const auto res = right_->resolve_field_ref(aliases, schema_registry); + !res.ok()) { + return res.status(); + } + } + return true; +} + +std::shared_ptr LogicalExpr::and_expr( + std::shared_ptr left, std::shared_ptr right) { + return std::make_shared(std::move(left), LogicalOp::AND, + std::move(right)); +} + +std::shared_ptr LogicalExpr::or_expr( + std::shared_ptr left, std::shared_ptr right) { + return std::make_shared(std::move(left), LogicalOp::OR, + std::move(right)); +} + +arrow::Result LogicalExpr::matches( + const std::shared_ptr& node) const { + if (!left_ || !right_) { + return arrow::Status::Invalid("LogicalExpr missing left or right operand"); + } + + auto left_result = left_->matches(node); + if (!left_result.ok()) { + return left_result.status(); + } + + auto right_result = right_->matches(node); + if (!right_result.ok()) { + return right_result.status(); + } + + bool left_val = left_result.ValueOrDie(); + bool right_val = right_result.ValueOrDie(); + + switch (op_) { + case LogicalOp::AND: + return left_val && right_val; + case LogicalOp::OR: + return left_val || right_val; + } + + return arrow::Status::Invalid("Unknown logical operator"); +} + +arrow::compute::Expression LogicalExpr::to_arrow_expression( + bool strip_var) const { + if (!left_ || !right_) { + throw std::runtime_error("LogicalExpr missing left or right operand"); + } + + auto left_expr = left_->to_arrow_expression(strip_var); + auto right_expr = right_->to_arrow_expression(strip_var); + + switch (op_) { + case LogicalOp::AND: + return arrow::compute::and_(left_expr, right_expr); + case LogicalOp::OR: + return arrow::compute::or_(left_expr, right_expr); + default: + throw std::runtime_error("Unknown logical operator in LogicalExpr"); + } +} + +std::vector> +LogicalExpr::get_conditions_for_variable(const std::string& variable) const { + auto all_variables = get_all_variables(); + for (const auto& var : all_variables) { + if (var != variable) { + return {}; + } + } + + std::vector> result; + if (left_) { + auto left_conditions = left_->get_conditions_for_variable(variable); + result.insert(result.end(), left_conditions.begin(), left_conditions.end()); + } + if (right_) { + auto right_conditions = right_->get_conditions_for_variable(variable); + result.insert(result.end(), right_conditions.begin(), + right_conditions.end()); + } + return result; +} + +std::string LogicalExpr::extract_first_variable() const { + if (left_) { + auto var = left_->extract_first_variable(); + if (!var.empty()) return var; + } + if (right_) { + auto var = right_->extract_first_variable(); + if (!var.empty()) return var; + } + return ""; +} + +std::string LogicalExpr::toString() const { + if (!left_ || !right_) { + return "WHERE (incomplete logical expression)"; + } + + std::string left_str = left_->toString(); + std::string right_str = right_->toString(); + + if (left_str.substr(0, 6) == "WHERE ") { + left_str = left_str.substr(6); + } + if (right_str.substr(0, 6) == "WHERE ") { + right_str = right_str.substr(6); + } + + std::string op_str = (op_ == LogicalOp::AND) ? " AND " : " OR "; + + std::string result = + "WHERE (" + left_str + ")" + op_str + "(" + right_str + ")"; + + if (inlined_) { + result += " (inlined)"; + } + + return result; +} + +std::ostream& operator<<(std::ostream& os, const LogicalExpr& expr) { + os << expr.toString(); + return os; +} + +std::set LogicalExpr::get_all_variables() const { + std::set variables; + if (left_) { + auto left_variables = left_->get_all_variables(); + variables.insert(left_variables.begin(), left_variables.end()); + } + if (right_) { + auto right_variables = right_->get_all_variables(); + variables.insert(right_variables.begin(), right_variables.end()); + } + return variables; +} + +bool LogicalExpr::can_inline(const std::string& variable) const { + if (left_ && !left_->can_inline(variable)) return false; + if (right_ && !right_->can_inline(variable)) return false; + return true; } -// Convert CompareOp to appropriate Arrow compute function +// ================== Helper Functions ================== arrow::compute::Expression apply_comparison_op( const arrow::compute::Expression& field, const arrow::compute::Expression& value, CompareOp op) { @@ -35,9 +596,6 @@ arrow::compute::Expression apply_comparison_op( case CompareOp::Lte: return arrow::compute::less_equal(field, value); case CompareOp::Contains: - // For string operations, we'd need to use match_substring_regex or - // similar For now, fall back to equal (this would need more sophisticated - // handling) log_warn( "CONTAINS operator not fully implemented for Arrow expressions, " "using equality"); @@ -81,4 +639,4 @@ arrow::compute::Expression value_to_expression(const Value& value) { } } -} // namespace tundradb \ No newline at end of file +} // namespace tundradb diff --git a/src/query_execution.cpp b/src/query_execution.cpp new file mode 100644 index 0000000..dfddb8a --- /dev/null +++ b/src/query_execution.cpp @@ -0,0 +1,191 @@ +#include "query_execution.hpp" + +#include "arrow_utils.hpp" +#include "logger.hpp" + +namespace tundradb { + +// SchemaContext implementation + +arrow::Result SchemaContext::register_schema( + const SchemaRef& schema_ref) { + if (aliases_.contains(schema_ref.value()) && schema_ref.is_declaration()) { + IF_DEBUG_ENABLED { + log_debug("Schema alias '{}' already assigned to '{}'", + schema_ref.value(), aliases_.at(schema_ref.value())); + } + return aliases_[schema_ref.value()]; + } + + if (schema_ref.is_declaration()) { + aliases_[schema_ref.value()] = schema_ref.schema(); + return schema_ref.schema(); + } + + return aliases_[schema_ref.value()]; +} + +arrow::Result SchemaContext::resolve( + const SchemaRef& schema_ref) const { + if (schema_ref.is_declaration()) { + return schema_ref.schema(); + } + + auto it = aliases_.find(schema_ref.value()); + if (it == aliases_.end()) { + return arrow::Status::KeyError("No alias for '{}'", schema_ref.value()); + } + + return it->second; +} + +// FieldIndexer implementation + +arrow::Result FieldIndexer::compute_fq_names( + const SchemaRef& schema_ref, const std::string& resolved_schema, + SchemaRegistry* registry) { + const std::string& alias = schema_ref.value(); + if (fq_field_names_.contains(alias)) { + return false; // Already computed + } + + auto schema_res = registry->get(resolved_schema); + if (!schema_res.ok()) { + return schema_res.status(); + } + + const auto& schema = schema_res.ValueOrDie(); + std::vector names; + std::vector indices; + names.reserve(schema->num_fields()); + indices.reserve(schema->num_fields()); + + for (const auto& field : schema->fields()) { + std::string fq_name = alias + "." + field->name(); + int field_id = next_field_id_.fetch_add(1); + + names.emplace_back(fq_name); + indices.emplace_back(field_id); + + field_id_to_name_[field_id] = fq_name; + field_name_to_index_[fq_name] = field_id; + } + + fq_field_names_[alias] = std::move(names); + schema_field_indices_[alias] = std::move(indices); + + return true; +} + +// QueryState implementation + +QueryState::QueryState(std::shared_ptr registry) + : schemas(std::move(registry)) {} + +void QueryState::reserve_capacity(const Query& query) { + // Estimate schema count from FROM + TRAVERSE clauses + size_t estimated_schemas = 1; // FROM clause + for (const auto& clause : query.clauses()) { + if (clause->type() == Clause::Type::TRAVERSE) { + estimated_schemas += 2; // source + target schemas + } + } + + // Pre-size standard containers + tables.reserve(estimated_schemas); +} + +arrow::Result QueryState::compute_fully_qualified_names( + const SchemaRef& schema_ref) { + const auto& aliases_map = schemas.get_aliases(); + const auto it = aliases_map.find(schema_ref.value()); + if (it == aliases_map.end()) { + return arrow::Status::KeyError("keyset does not contain alias '{}'", + schema_ref.value()); + } + return compute_fully_qualified_names(schema_ref, it->second); +} + +arrow::Result QueryState::update_table( + const std::shared_ptr& table, const SchemaRef& schema_ref) { + this->tables[schema_ref.value()] = table; + auto ids_result = get_ids_from_table(table); + if (!ids_result.ok()) { + log_error("Failed to get IDs from table: {}", schema_ref.value()); + return ids_result.status(); + } + graph.ids(schema_ref.value()) = ids_result.ValueOrDie(); + return true; +} + +std::string QueryState::ToString() const { + std::stringstream ss; + ss << "QueryState {\n"; + ss << " From: " << from.toString() << "\n"; + + ss << " Tables (" << tables.size() << "):\n"; + for (const auto& [alias, table_ptr] : tables) { + if (table_ptr) { + ss << " - " << alias << ": " << table_ptr->num_rows() << " rows, " + << table_ptr->num_columns() << " columns\n"; + } else { + ss << " - " << alias << ": (nullptr)\n"; + } + } + + ss << " Aliases (" << schemas.get_aliases().size() << "):\n"; + for (const auto& [alias, schema_name] : schemas.get_aliases()) { + ss << " - " << alias << " -> " << schema_name << "\n"; + } + + ss << " Connections (Outgoing) (" << graph.outgoing().size() + << " source nodes):"; + for (const auto& [from_schema, conns] : graph.outgoing()) { + for (const auto& [from_id, conn_vec] : conns) { + ss << "from " << from_schema.str() << ":" << from_id << ":\n"; + for (const auto& conn : conn_vec) { + ss << " - " << conn.target.value() << ":" << conn.target_id << "\n"; + } + } + } + + ss << " Connections (Incoming) (" << graph.incoming().size() + << " target nodes):"; + int target_nodes_printed = 0; + for (const auto& [target_id, conns_vec] : graph.incoming()) { + if (target_nodes_printed >= 3 && graph.incoming().size() > 5) { + ss << " ... and " << (graph.incoming().size() - target_nodes_printed) + << " more target nodes ...\n"; + break; + } + ss << " - Target ID " << target_id << " (" << conns_vec.size() + << " incoming):"; + int conns_printed_for_target = 0; + for (const auto& conn : conns_vec) { + if (conns_printed_for_target >= 3 && conns_vec.size() > 5) { + ss << " ... and " + << (conns_vec.size() - conns_printed_for_target) + << " more connections ...\n"; + break; + } + ss << " <- " << conn.source.value() << ":" << conn.source_id + << " (via '" << conn.edge_type << "')\n"; + conns_printed_for_target++; + } + target_nodes_printed++; + } + + ss << " Traversals (" << traversals.size() << "):\n"; + for (size_t i = 0; i < traversals.size(); ++i) { + const auto& trav = traversals[i]; + ss << " - [" << i << "]: " << trav.source().value() << " -[" + << trav.edge_type() << "]-> " << trav.target().value() << " (Type: " + << (trav.traverse_type() == TraverseType::Inner ? "Inner" : "Other") + << ")\n"; + } + + ss << "}"; + return ss.str(); +} + +} // namespace tundradb