diff --git a/include/network/rpc_message.hpp b/include/network/rpc_message.hpp index 1bffe34f..16cfd43a 100644 --- a/include/network/rpc_message.hpp +++ b/include/network/rpc_message.hpp @@ -39,8 +39,8 @@ enum class RpcType : uint8_t { UnmatchedRowsPush = 14, // Coordinator sends unmatched rows for NULL-padding FetchUnmatchedRows = 15, // Coordinator fetches stored unmatched rows from data node // LEFT-side counterparts for FULL join - UnmatchedLeftRowsReport = 16, // Data node reports unmatched LEFT rows for FULL join - FetchUnmatchedLeftRows = 17, // Coordinator fetches stored unmatched LEFT rows + UnmatchedLeftRowsReport = 16, // Data node reports unmatched LEFT rows for FULL join + FetchUnmatchedLeftRows = 17, // Coordinator fetches stored unmatched LEFT rows Error = 255 }; @@ -710,9 +710,9 @@ struct FetchUnmatchedRowsArgs { struct UnmatchedLeftRowsReportArgs { std::string context_id; std::string left_table; - std::string join_key_col; // Which column was the join key - std::vector unmatched_keys; // LEFT key values that had no match - uint32_t right_column_count = 0; // Number of right columns for NULL-padding + std::string join_key_col; // Which column was the join key + std::vector unmatched_keys; // LEFT key values that had no match (NOTE(review): the coordinator assigns its matched LEFT keys to this field - confirm intended meaning) + uint32_t right_column_count = 0; // Number of right columns for NULL-padding [[nodiscard]] std::vector serialize() const { std::vector out; diff --git a/src/distributed/distributed_executor.cpp b/src/distributed/distributed_executor.cpp index 36f8a190..76db1c2a 100644 --- a/src/distributed/distributed_executor.cpp +++ b/src/distributed/distributed_executor.cpp @@ -184,6 +184,7 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt, bool is_outer_join_join_query = false; std::string outer_join_left_table; std::string outer_join_right_table; + std::string outer_join_left_key; std::string outer_join_right_key; parser::SelectStatement::JoinType outer_join_type = parser::SelectStatement::JoinType::Inner; @@ 
-234,6 +235,7 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt, is_outer_join_join_query = true; outer_join_left_table = left_table; outer_join_right_table = right_table; + outer_join_left_key = left_key; outer_join_right_key = right_key; outer_join_type = join.type; } @@ -611,15 +613,15 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt, } } - // Phase 3-5: Currently disabled for all outer joins due to issues with column indexing - // when SELECT doesn't use SELECT * (causes duplicate rows instead of correct results). + // Phase 3-5: For FULL JOIN, collect unmatched RIGHT and LEFT rows from data nodes. + // LEFT rows are emitted during probe phase when no match found, but we need to + // COLLECT them from all data nodes for the coordinator's final result. // - // For RIGHT JOIN: Local executor on each data node correctly handles unmatched right rows. - // For FULL JOIN: Unmatched LEFT rows are not collected (to be implemented in separate PR). + // For RIGHT JOIN: Local executor on each data node correctly handles unmatched right rows + // (no collection needed - each node emits them locally). // - // TODO: Re-enable Phase 3-5 for FULL JOIN once column indexing is fixed to properly - // identify which rows were unmatched during the distributed join. - if (false && is_outer_join_join_query && all_success) { + // This block is only enabled for FULL JOIN. 
+ if (outer_join_type == parser::SelectStatement::JoinType::Full && all_success) { // Extract matched right keys from aggregated results // The right key column is at a known position in the result schema std::vector matched_keys; @@ -643,7 +645,7 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt, } } - // Phase 3: Ask each node to scan local table and store unmatched rows + // Phase 3: Ask each node to scan local right table and store unmatched rows // First, compute the left column count for NULL-padding uint32_t left_column_count = 0; if (!outer_join_left_table.empty()) { @@ -713,7 +715,7 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt, })); } - // Aggregate all unmatched rows from all nodes + // Aggregate all unmatched RIGHT rows from all nodes for (auto& f : fetch_futures) { auto result = f.get(); if (result.first) { @@ -722,6 +724,102 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt, } } } + + // === LEFT-side Phase 3-4 for FULL JOIN === + // Extract matched LEFT keys from aggregated results + std::vector matched_left_keys; + size_t left_key_idx = static_cast(-1); + for (size_t i = 0; i < result_schema.columns().size(); ++i) { + const auto& col = result_schema.columns()[i]; + if (col.name() == outer_join_left_key) { + left_key_idx = i; + break; + } + } + if (left_key_idx != static_cast(-1)) { + for (const auto& row : aggregated_rows) { + if (row.size() > left_key_idx) { + matched_left_keys.push_back(row.get(left_key_idx).to_string()); + } + } + } + + // LEFT-side Phase 3: Ask each node to scan local left table and store unmatched rows + uint32_t right_column_count = 0; + if (!outer_join_right_table.empty()) { + auto right_table_info = catalog_.get_table_by_name(outer_join_right_table); + if (right_table_info.has_value()) { + right_column_count = static_cast((*right_table_info)->columns.size()); + } + } + + std::vector>> + left_report_futures; + + for (const auto& node : data_nodes) { + 
left_report_futures.push_back(std::async( + std::launch::async, [node, context_id, outer_join_left_table, outer_join_left_key, + matched_left_keys, right_column_count]() { + network::RpcClient client(node.address, node.cluster_port); + network::UnmatchedLeftRowsReportArgs reply; + if (client.connect()) { + network::UnmatchedLeftRowsReportArgs report_args; + report_args.context_id = context_id; + report_args.left_table = outer_join_left_table; + report_args.join_key_col = outer_join_left_key; + report_args.unmatched_keys = matched_left_keys; + report_args.right_column_count = right_column_count; + + std::vector resp; + if (client.call(network::RpcType::UnmatchedLeftRowsReport, + report_args.serialize(), resp)) { + reply = network::UnmatchedLeftRowsReportArgs::deserialize(resp); + return std::make_pair(true, reply); + } + } + return std::make_pair(false, reply); + })); + } + + // Wait for all LEFT report futures to complete + for (auto& f : left_report_futures) { + f.get(); + } + + // LEFT-side Phase 4: Fetch stored unmatched LEFT rows from each node + std::vector>>> left_fetch_futures; + + for (const auto& node : data_nodes) { + left_fetch_futures.push_back( + std::async(std::launch::async, [node, context_id, outer_join_left_table]() { + network::RpcClient client(node.address, node.cluster_port); + std::vector rows; + if (client.connect()) { + network::FetchUnmatchedLeftRowsArgs fetch_args; + fetch_args.context_id = context_id; + fetch_args.table_name = outer_join_left_table; + + std::vector resp; + if (client.call(network::RpcType::FetchUnmatchedLeftRows, + fetch_args.serialize(), resp)) { + auto reply = network::UnmatchedRowsPushArgs::deserialize(resp); + rows = std::move(reply.unmatched_rows); + return std::make_pair(true, std::move(rows)); + } + } + return std::make_pair(false, std::move(rows)); + })); + } + + // Aggregate all unmatched LEFT rows from all nodes + for (auto& f : left_fetch_futures) { + auto result = f.get(); + if (result.first) { + for (auto& row 
: result.second) { + aggregated_rows.push_back(std::move(row)); + } + } + } } if (all_success) {