Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 115 additions & 3 deletions tests/operator_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ AggregateInfo make_agg(AggregateType type, const std::string& name,
return info;
}

// Helper: builds an AggregateInfo from the given type, name and (optional)
// argument expression, with is_distinct set to true.
AggregateInfo make_agg_distinct(AggregateType type, const std::string& name,
                                std::unique_ptr<Expression> expr = nullptr) {
    AggregateInfo distinct_agg;
    distinct_agg.is_distinct = true;
    distinct_agg.expr = std::move(expr);  // takes ownership; may be nullptr
    distinct_agg.name = name;
    distinct_agg.type = type;
    return distinct_agg;
}

// Helper: create a BufferScanOperator with test data
std::unique_ptr<BufferScanOperator> make_buffer_scan(const std::string& table_name,
const std::vector<Tuple>& data,
Expand Down Expand Up @@ -769,6 +780,106 @@ TEST_F(OperatorTests, AggregateAvgFractional) {
agg->close();
}

TEST_F(OperatorTests, AggregateMultipleAggregates) {
    // Evaluates three aggregates (SUM, COUNT, AVG) in a single ungrouped
    // aggregation over one INT64 column holding 10, 20, 30 and checks that
    // exactly one output row carries all three results.
    Schema schema = make_schema({{"val", common::ValueType::TYPE_INT64}});
    std::vector<Tuple> data;
    data.push_back(make_tuple({common::Value::make_int64(10)}));
    data.push_back(make_tuple({common::Value::make_int64(20)}));
    data.push_back(make_tuple({common::Value::make_int64(30)}));

    auto scan = make_buffer_scan("test_table", data, schema);
    std::vector<AggregateInfo> aggs;
    aggs.push_back(make_agg(AggregateType::Sum, "total", col_expr("val")));
    aggs.push_back(make_agg(AggregateType::Count, "cnt"));
    aggs.push_back(make_agg(AggregateType::Avg, "avg_val", col_expr("val")));
    auto agg = make_agg_op(std::move(scan), {}, std::move(aggs));

    ASSERT_TRUE(agg->init());
    ASSERT_TRUE(agg->open());

    Tuple tuple;
    // Fatal assertion: the column reads below are only meaningful if a row
    // was actually produced, so stop here instead of indexing an empty tuple.
    ASSERT_TRUE(agg->next(tuple));
    EXPECT_EQ(tuple.get(0).to_int64(), 60);  // SUM
    EXPECT_EQ(tuple.get(1).to_int64(), 3);   // COUNT
    // Floating-point result: compare with ULP tolerance rather than operator==.
    EXPECT_DOUBLE_EQ(tuple.get(2).to_float64(), 20.0);  // AVG
    // An ungrouped aggregation must yield exactly one row.
    EXPECT_FALSE(agg->next(tuple));
    agg->close();
}

TEST_F(OperatorTests, AggregateGroupByMultipleCols) {
    // Grouped SUM(salary) over a three-column input (dept, name, salary).
    // NOTE(review): despite the test name, `dept` is the only grouping key
    // used here; the "multiple columns" are in the input schema, not in the
    // GROUP BY list. Confirm whether a second key was intended.
    Schema schema = make_schema({{"dept", common::ValueType::TYPE_INT64},
                                 {"name", common::ValueType::TYPE_TEXT},
                                 {"salary", common::ValueType::TYPE_INT64}});

    // Two rows for dept 1 (1000 + 2000) and one row for dept 2 (1500).
    std::vector<Tuple> data;
    data.push_back(make_tuple({common::Value::make_int64(1),
                               common::Value::make_text("alice"),
                               common::Value::make_int64(1000)}));
    data.push_back(make_tuple({common::Value::make_int64(1),
                               common::Value::make_text("bob"),
                               common::Value::make_int64(2000)}));
    data.push_back(make_tuple({common::Value::make_int64(2),
                               common::Value::make_text("charlie"),
                               common::Value::make_int64(1500)}));

    auto scan = make_buffer_scan("test_table", data, schema);

    std::vector<std::unique_ptr<Expression>> group_by;
    group_by.push_back(col_expr("dept"));

    std::vector<AggregateInfo> aggs;
    aggs.push_back(make_agg(AggregateType::Sum, "total_salary", col_expr("salary")));

    auto agg = make_agg_op(std::move(scan), std::move(group_by), std::move(aggs));

    ASSERT_TRUE(agg->init());
    ASSERT_TRUE(agg->open());

    // Drain the operator, recording (column 0, column 1) of every output row.
    // Expected groups: dept 1 -> 3000, dept 2 -> 1500.
    std::vector<std::pair<int64_t, int64_t>> results;
    Tuple out;
    while (agg->next(out)) {
        const int64_t dept = out.get(0).to_int64();
        const int64_t total = out.get(1).to_int64();
        results.emplace_back(dept, total);
    }

    EXPECT_EQ(results.size(), 2U);

    // Output order of the groups is not assumed: look for each expected
    // (dept, total) pair anywhere in the result set.
    bool found_dept1 = false;
    bool found_dept2 = false;
    for (const auto& row : results) {
        found_dept1 = found_dept1 || (row.first == 1 && row.second == 3000);
        found_dept2 = found_dept2 || (row.first == 2 && row.second == 1500);
    }
    EXPECT_TRUE(found_dept1);
    EXPECT_TRUE(found_dept2);
    agg->close();
}

TEST_F(OperatorTests, AggregateWithNulls) {
    // Checks that SUM(val) and COUNT(val) ignore NULL inputs: the column
    // holds 10, NULL, 20, so the expected results are 30 and 2.
    Schema schema = make_schema({{"val", common::ValueType::TYPE_INT64}});
    std::vector<Tuple> data;
    data.push_back(make_tuple({common::Value::make_int64(10)}));
    data.push_back(make_tuple({common::Value()}));  // NULL
    data.push_back(make_tuple({common::Value::make_int64(20)}));

    auto scan = make_buffer_scan("test_table", data, schema);
    std::vector<AggregateInfo> aggs;
    aggs.push_back(make_agg(AggregateType::Sum, "total", col_expr("val")));
    aggs.push_back(make_agg(AggregateType::Count, "cnt", col_expr("val")));
    auto agg = make_agg_op(std::move(scan), {}, std::move(aggs));

    ASSERT_TRUE(agg->init());
    ASSERT_TRUE(agg->open());

    Tuple tuple;
    // Fatal assertion: without an output row the column reads below would
    // operate on an unpopulated tuple.
    ASSERT_TRUE(agg->next(tuple));
    // SUM should skip NULL: 10 + 20 = 30
    EXPECT_EQ(tuple.get(0).to_int64(), 30);
    // COUNT should skip NULL: only 2 non-null values
    EXPECT_EQ(tuple.get(1).to_int64(), 2);
    // An ungrouped aggregation must yield exactly one row.
    EXPECT_FALSE(agg->next(tuple));
    agg->close();
}

TEST_F(OperatorTests, HashJoinRightOuter) {
// Right table: values 2, 3, 4 (only 2 matches)
Schema left_schema = make_schema({{"id", common::ValueType::TYPE_INT64}});
Expand All @@ -794,7 +905,8 @@ TEST_F(OperatorTests, HashJoinRightOuter) {
// RIGHT join output: matched rows + unmatched right rows with NULLs
// Matched: (2, 2)
// Unmatched right: (NULL, 3), (NULL, 4)
std::vector<std::pair<int64_t, int64_t>> results; // (left_value, right_value); use INT64_MIN as sentinel for NULL
std::vector<std::pair<int64_t, int64_t>>
results; // (left_value, right_value); use INT64_MIN as sentinel for NULL
Tuple tuple;
while (join->next(tuple)) {
int64_t left_val = tuple.get(0).is_null() ? INT64_MIN : tuple.get(0).to_int64();
Expand Down Expand Up @@ -880,11 +992,11 @@ TEST_F(OperatorTests, HashJoinNullKeys) {
Schema left_schema = make_schema({{"id", common::ValueType::TYPE_INT64}});
std::vector<Tuple> left_data;
left_data.push_back(make_tuple({common::Value::make_int64(1)})); // matches 1
left_data.push_back(make_tuple({common::Value()})); // NULL - currently matches NULL
left_data.push_back(make_tuple({common::Value()})); // NULL - currently matches NULL

Schema right_schema = make_schema({{"id", common::ValueType::TYPE_INT64}});
std::vector<Tuple> right_data;
right_data.push_back(make_tuple({common::Value()})); // NULL - currently matches
right_data.push_back(make_tuple({common::Value()})); // NULL - currently matches
right_data.push_back(make_tuple({common::Value::make_int64(1)})); // matches 1

auto left_scan = make_buffer_scan("left_table", left_data, left_schema);
Expand Down