From d553ff37c8f433d897546cdce974c212f2deaf2f Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Mon, 12 Jan 2026 11:59:23 +0200 Subject: [PATCH 1/6] feat: implement HeadTableSampler and RandomStartTableSampler for table sampling --- agentune/core/sampler/base.py | 10 + agentune/core/sampler/table_samples.py | 83 ++++++++ .../core/sampler/test_table_samplers.py | 201 ++++++++++++++++++ 3 files changed, 294 insertions(+) create mode 100644 agentune/core/sampler/table_samples.py create mode 100644 tests/agentune/core/sampler/test_table_samplers.py diff --git a/agentune/core/sampler/base.py b/agentune/core/sampler/base.py index 36a75871..671a985c 100644 --- a/agentune/core/sampler/base.py +++ b/agentune/core/sampler/base.py @@ -7,7 +7,9 @@ from typing import override import attrs +from duckdb import DuckDBPyConnection +from agentune.analyze.join.base import TableWithJoinStrategies from agentune.core.dataset import Dataset @@ -27,6 +29,14 @@ def _validate_inputs(self, dataset: Dataset, sample_size: int) -> None: raise ValueError('Sample size must be positive') +@attrs.define +class TableSampler(ABC): + """Abstract base class for data sampling from TableWithJoinStrategies.""" + @abstractmethod + def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = None) -> Dataset: + """Sample data from a given table with join strategies.""" + + @attrs.define class RandomSampler(DataSampler): """Simple random sampling without any stratification. 
diff --git a/agentune/core/sampler/table_samples.py b/agentune/core/sampler/table_samples.py new file mode 100644 index 00000000..694a58c2 --- /dev/null +++ b/agentune/core/sampler/table_samples.py @@ -0,0 +1,83 @@ + +import random +from typing import override + +import attrs +from duckdb import DuckDBPyConnection + +from agentune.analyze.join.base import TableWithJoinStrategies +from agentune.core.dataset import Dataset, duckdb_to_dataset +from agentune.core.sampler.base import TableSampler + + +@attrs.define +class HeadTableSampler(TableSampler): + """Table sampler that returns the first N rows from a table. + + This sampler executes a simple SELECT * query with a LIMIT clause + to retrieve the head of the table. The order of rows is determined + by the table's natural order (or index if present). + """ + + @override + def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = None) -> Dataset: + """Sample the first N rows from the table. + + Args: + table: The table with join strategies to sample from + conn: The DuckDB connection + sample_size: Number of rows to sample (limit) + random_seed: Not used in this sampler (kept for interface compatibility) + + Returns: + Dataset containing the first sample_size rows + """ + table_name = str(table.table.name) + sql_query = f'SELECT * FROM {table_name} LIMIT {sample_size}' + + relation = conn.sql(sql_query) + return duckdb_to_dataset(relation) + + +@attrs.define +class RandomStartTableSampler(TableSampler): + """Table sampler that returns consecutive rows starting from a random position. + + This sampler first determines the table size, then selects a random starting + point (using the provided seed for reproducibility), and returns consecutive + rows from that point. Uses DuckDB's rowid pseudocolumn to handle the offset. 
+ """ + + @override + def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = None) -> Dataset: + """Sample consecutive rows from a random starting point in the table. + + Args: + table: The table with join strategies to sample from + conn: The DuckDB connection + sample_size: Number of consecutive rows to sample + random_seed: Random seed for selecting the starting point (for reproducibility) + + Returns: + Dataset containing sample_size consecutive rows starting from a random position + """ + table_name = str(table.table.name) + + # Get table size + count_query = f'SELECT COUNT(*) as count FROM {table_name}' + result = conn.sql(count_query).fetchone() + table_size = result[0] if result else 0 + + # Adjust sample size if it exceeds table size + sample_size = min(sample_size, table_size) + + # Select random starting point + rng = random.Random(random_seed) + start_rowid = rng.randint(0, max(0, table_size - sample_size)) + + # Select consecutive rows starting from the random rowid + # Using DuckDB's built-in rowid pseudocolumn for deterministic and efficient filtering + sql_query = f'SELECT * FROM {table_name} WHERE rowid >= {start_rowid} LIMIT {sample_size}' + + relation = conn.sql(sql_query) + return duckdb_to_dataset(relation) diff --git a/tests/agentune/core/sampler/test_table_samplers.py b/tests/agentune/core/sampler/test_table_samplers.py new file mode 100644 index 00000000..4621f211 --- /dev/null +++ b/tests/agentune/core/sampler/test_table_samplers.py @@ -0,0 +1,201 @@ +"""Tests for table sampling utilities.""" + +import polars as pl +from duckdb import DuckDBPyConnection + +from agentune.analyze.join.base import TableWithJoinStrategies +from agentune.core.database import DuckdbName, DuckdbTable +from agentune.core.sampler.table_samples import HeadTableSampler, RandomStartTableSampler +from agentune.core.schema import Field, Schema +from agentune.core.types import int32, string + + +def 
create_test_table(conn: DuckDBPyConnection, table_name: str, num_rows: int) -> TableWithJoinStrategies: + """Helper to create test tables in DuckDB.""" + # Create test data + data = pl.DataFrame({ + 'id': list(range(num_rows)), + 'value': [f'item_{i}' for i in range(num_rows)], + }) + + # Create actual table in DuckDB (not just register) + qualified_name = DuckdbName.qualify(table_name, conn) + schema = Schema((Field('id', int32), Field('value', string))) + duckdb_table = DuckdbTable(name=qualified_name, schema=schema) + + # Create the table + duckdb_table.create(conn, if_not_exists=True) + + # Insert data + if num_rows > 0: + conn.register('__temp_data', data) + conn.execute(f'INSERT INTO {qualified_name} SELECT * FROM __temp_data') + conn.unregister('__temp_data') + + # Return TableWithJoinStrategies + return TableWithJoinStrategies(table=duckdb_table, join_strategies={}) + + +class TestHeadTableSampler: + """Test HeadTableSampler functionality.""" + + def test_basic_head_sampling(self, conn: DuckDBPyConnection) -> None: + """Test basic head sampling functionality.""" + sampler = HeadTableSampler() + table = create_test_table(conn, 'test_table', 100) + + # Sample first 20 rows + result = sampler.sample(table, conn, sample_size=20) + + # Validate result + assert result.data.height == 20 + assert result.data['id'].to_list() == list(range(20)) + assert result.schema.names == ['id', 'value'] + + def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None: + """Test head sampling when sample size equals table size.""" + sampler = HeadTableSampler() + table = create_test_table(conn, 'test_table_full', 50) + + result = sampler.sample(table, conn, sample_size=50) + + assert result.data.height == 50 + assert result.data['id'].to_list() == list(range(50)) + + def test_head_sampling_ignores_random_seed(self, conn: DuckDBPyConnection) -> None: + """Test that head sampler produces same results regardless of seed.""" + sampler = HeadTableSampler() + table = 
create_test_table(conn, 'test_table_seed', 50) + + result1 = sampler.sample(table, conn, sample_size=10, random_seed=42) + result2 = sampler.sample(table, conn, sample_size=10, random_seed=999) + + assert result1.data.equals(result2.data) + + def test_head_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: + """Test head sampling when sample size exceeds table size.""" + sampler = HeadTableSampler() + table = create_test_table(conn, 'test_head_oversample', 30) + + result = sampler.sample(table, conn, sample_size=50) + + # Should return entire table (all 30 rows) + assert result.data.height == 30 + assert result.data['id'].to_list() == list(range(30)) + + def test_head_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: + """Test head sampling from an empty table.""" + sampler = HeadTableSampler() + table = create_test_table(conn, 'test_head_empty', 0) + + result = sampler.sample(table, conn, sample_size=10) + + # Should return empty dataset + assert result.data.height == 0 + # But schema should still be correct + assert result.schema.names == ['id', 'value'] + + +class TestRandomStartTableSampler: + """Test RandomStartTableSampler functionality.""" + + def test_basic_random_start_sampling(self, conn: DuckDBPyConnection) -> None: + """Test basic random start sampling functionality.""" + sampler = RandomStartTableSampler() + table = create_test_table(conn, 'test_random_table', 100) + + result = sampler.sample(table, conn, sample_size=20, random_seed=42) + + # Validate result + assert result.data.height == 20 + assert result.schema.names == ['id', 'value'] + # Check that rows are consecutive + ids = result.data['id'].to_list() + assert ids == list(range(ids[0], ids[0] + 20)) + + def test_random_start_sampling_reproducibility(self, conn: DuckDBPyConnection) -> None: + """Test that random start sampling is reproducible with same seed.""" + sampler = RandomStartTableSampler() + table = create_test_table(conn, 'test_reproducible', 100) + + 
result1 = sampler.sample(table, conn, sample_size=15, random_seed=123) + result2 = sampler.sample(table, conn, sample_size=15, random_seed=123) + + assert result1.data.equals(result2.data) + + def test_random_start_sampling_different_seeds(self, conn: DuckDBPyConnection) -> None: + """Test that different seeds produce different starting points.""" + sampler = RandomStartTableSampler() + table = create_test_table(conn, 'test_different_seeds', 100) + + result1 = sampler.sample(table, conn, sample_size=20, random_seed=42) + result2 = sampler.sample(table, conn, sample_size=20, random_seed=999) + + # Different seeds should (very likely) produce different starting points + ids1 = result1.data['id'].to_list() + ids2 = result2.data['id'].to_list() + assert ids1[0] != ids2[0] or ids1 != ids2 + + def test_random_start_sampling_consecutive_rows(self, conn: DuckDBPyConnection) -> None: + """Test that sampled rows are consecutive.""" + sampler = RandomStartTableSampler() + table = create_test_table(conn, 'test_consecutive', 200) + + result = sampler.sample(table, conn, sample_size=30, random_seed=42) + + ids = result.data['id'].to_list() + # Verify consecutiveness + for i in range(len(ids) - 1): + assert ids[i + 1] == ids[i] + 1 + + def test_random_start_sampling_full_table(self, conn: DuckDBPyConnection) -> None: + """Test random start sampling when sample size equals table size.""" + sampler = RandomStartTableSampler() + table = create_test_table(conn, 'test_full_random', 50) + + result = sampler.sample(table, conn, sample_size=50, random_seed=42) + + # When sample size equals table size, starting point should be 0 + assert result.data.height == 50 + assert result.data['id'].to_list() == list(range(50)) + + def test_random_start_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: + """Test random start sampling when sample size exceeds table size.""" + sampler = RandomStartTableSampler() + table = create_test_table(conn, 'test_oversample', 30) + + result = 
sampler.sample(table, conn, sample_size=50, random_seed=42) + + # Should return entire table + assert result.data.height == 30 + assert result.data['id'].to_list() == list(range(30)) + # Schema should be correct + assert result.schema.names == ['id', 'value'] + + def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: + """Test random start sampling with empty table.""" + sampler = RandomStartTableSampler() + table = create_test_table(conn, 'test_empty', 0) + + result = sampler.sample(table, conn, sample_size=10, random_seed=42) + + # Should return empty dataset + assert result.data.height == 0 + # But schema should still be correct + assert result.schema.names == ['id', 'value'] + + def test_random_start_sampling_near_end(self, conn: DuckDBPyConnection) -> None: + """Test that samples near the end of the table work correctly.""" + sampler = RandomStartTableSampler() + table = create_test_table(conn, 'test_near_end', 100) + + # Force a start near the end by setting specific seed + # This tests that we correctly handle the range + result = sampler.sample(table, conn, sample_size=10, random_seed=42) + + ids = result.data['id'].to_list() + # Should still get 10 consecutive rows + assert len(ids) == 10 + assert ids == list(range(ids[0], ids[0] + 10)) + # And they should be within the table bounds + assert ids[-1] < 100 From f9f968fd6b5715c3d452718d157c4e1799d99770 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Mon, 12 Jan 2026 12:50:14 +0200 Subject: [PATCH 2/6] feat: update sampling queries to use rowid range --- agentune/core/sampler/table_samples.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/agentune/core/sampler/table_samples.py b/agentune/core/sampler/table_samples.py index 694a58c2..b27c2250 100644 --- a/agentune/core/sampler/table_samples.py +++ b/agentune/core/sampler/table_samples.py @@ -33,7 +33,8 @@ def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sampl Dataset containing 
the first sample_size rows """ table_name = str(table.table.name) - sql_query = f'SELECT * FROM {table_name} LIMIT {sample_size}' + end_rowid = sample_size - 1 + sql_query = f'SELECT * FROM {table_name} WHERE rowid BETWEEN 0 AND {end_rowid} ORDER BY rowid' relation = conn.sql(sql_query) return duckdb_to_dataset(relation) @@ -74,10 +75,11 @@ def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sampl # Select random starting point rng = random.Random(random_seed) start_rowid = rng.randint(0, max(0, table_size - sample_size)) + end_rowid = start_rowid + sample_size - 1 # Select consecutive rows starting from the random rowid # Using DuckDB's built-in rowid pseudocolumn for deterministic and efficient filtering - sql_query = f'SELECT * FROM {table_name} WHERE rowid >= {start_rowid} LIMIT {sample_size}' + sql_query = f'SELECT * FROM {table_name} WHERE rowid BETWEEN {start_rowid} AND {end_rowid} ORDER BY rowid' relation = conn.sql(sql_query) return duckdb_to_dataset(relation) From a81f5552c61c303e9c8d092dd0177d09aa5180f0 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Mon, 12 Jan 2026 16:00:02 +0200 Subject: [PATCH 3/6] feat: optimize table size retrieval in RandomStartTableSampler --- agentune/core/sampler/table_samples.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/agentune/core/sampler/table_samples.py b/agentune/core/sampler/table_samples.py index b27c2250..99484947 100644 --- a/agentune/core/sampler/table_samples.py +++ b/agentune/core/sampler/table_samples.py @@ -65,9 +65,7 @@ def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sampl table_name = str(table.table.name) # Get table size - count_query = f'SELECT COUNT(*) as count FROM {table_name}' - result = conn.sql(count_query).fetchone() - table_size = result[0] if result else 0 + table_size = len(conn.table(table_name)) # Adjust sample size if it exceeds table size sample_size = min(sample_size, table_size) From 
f77cd2e89172f901ebfd7f64c272fe29782283b4 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Tue, 13 Jan 2026 17:57:57 +0200 Subject: [PATCH 4/6] apply comments --- agentune/core/sampler/base.py | 2 +- agentune/core/sampler/table_samples.py | 9 ++- .../core/sampler/test_table_samplers.py | 69 ++++++++----------- 3 files changed, 32 insertions(+), 48 deletions(-) diff --git a/agentune/core/sampler/base.py b/agentune/core/sampler/base.py index 671a985c..02606faf 100644 --- a/agentune/core/sampler/base.py +++ b/agentune/core/sampler/base.py @@ -33,7 +33,7 @@ def _validate_inputs(self, dataset: Dataset, sample_size: int) -> None: class TableSampler(ABC): """Abstract base class for data sampling from TableWithJoinStrategies.""" @abstractmethod - def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = None) -> Dataset: + def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: """Sample data from a given table with join strategies.""" diff --git a/agentune/core/sampler/table_samples.py b/agentune/core/sampler/table_samples.py index 99484947..2f0f88ae 100644 --- a/agentune/core/sampler/table_samples.py +++ b/agentune/core/sampler/table_samples.py @@ -14,13 +14,12 @@ class HeadTableSampler(TableSampler): """Table sampler that returns the first N rows from a table. - This sampler executes a simple SELECT * query with a LIMIT clause - to retrieve the head of the table. The order of rows is determined - by the table's natural order (or index if present). + This sampler executes a SELECT * query with a WHERE clause filtering by rowid + to retrieve the head of the table. Rows are ordered by DuckDB's rowid pseudocolumn. 
""" @override - def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = None) -> Dataset: + def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: """Sample the first N rows from the table. Args: @@ -50,7 +49,7 @@ class RandomStartTableSampler(TableSampler): """ @override - def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = None) -> Dataset: + def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: """Sample consecutive rows from a random starting point in the table. Args: diff --git a/tests/agentune/core/sampler/test_table_samplers.py b/tests/agentune/core/sampler/test_table_samplers.py index 4621f211..178c03ed 100644 --- a/tests/agentune/core/sampler/test_table_samplers.py +++ b/tests/agentune/core/sampler/test_table_samplers.py @@ -4,34 +4,26 @@ from duckdb import DuckDBPyConnection from agentune.analyze.join.base import TableWithJoinStrategies -from agentune.core.database import DuckdbName, DuckdbTable +from agentune.core.database import DuckdbTable +from agentune.core.dataset import Dataset, DatasetSink from agentune.core.sampler.table_samples import HeadTableSampler, RandomStartTableSampler -from agentune.core.schema import Field, Schema -from agentune.core.types import int32, string def create_test_table(conn: DuckDBPyConnection, table_name: str, num_rows: int) -> TableWithJoinStrategies: """Helper to create test tables in DuckDB.""" - # Create test data + # Create test data with explicit schema to avoid Null type issues with empty DataFrames data = pl.DataFrame({ 'id': list(range(num_rows)), - 'value': [f'item_{i}' for i in range(num_rows)], - }) + 'value': [f'item_{i}' for i in range(num_rows)],}, + schema={'id': pl.Int64, 'value': pl.String} + ) # Create actual table in 
DuckDB (not just register) - qualified_name = DuckdbName.qualify(table_name, conn) - schema = Schema((Field('id', int32), Field('value', string))) - duckdb_table = DuckdbTable(name=qualified_name, schema=schema) - - # Create the table - duckdb_table.create(conn, if_not_exists=True) - - # Insert data - if num_rows > 0: - conn.register('__temp_data', data) - conn.execute(f'INSERT INTO {qualified_name} SELECT * FROM __temp_data') - conn.unregister('__temp_data') + DatasetSink.into_unqualified_duckdb_table(table_name, conn).write(Dataset.from_polars(data).as_source(), conn) + # get DuckdbTable representation from conn + duckdb_table = DuckdbTable.from_duckdb(table_name, conn) + # Return TableWithJoinStrategies return TableWithJoinStrategies(table=duckdb_table, join_strategies={}) @@ -48,9 +40,8 @@ def test_basic_head_sampling(self, conn: DuckDBPyConnection) -> None: result = sampler.sample(table, conn, sample_size=20) # Validate result - assert result.data.height == 20 - assert result.data['id'].to_list() == list(range(20)) - assert result.schema.names == ['id', 'value'] + expected = conn.table(str(table.table.name)).pl().head(20) + assert result.data.equals(expected) def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None: """Test head sampling when sample size equals table size.""" @@ -59,8 +50,8 @@ def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None: result = sampler.sample(table, conn, sample_size=50) - assert result.data.height == 50 - assert result.data['id'].to_list() == list(range(50)) + expected = conn.table(str(table.table.name)).pl() + assert result.data.equals(expected) def test_head_sampling_ignores_random_seed(self, conn: DuckDBPyConnection) -> None: """Test that head sampler produces same results regardless of seed.""" @@ -80,8 +71,8 @@ def test_head_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None result = sampler.sample(table, conn, sample_size=50) # Should return entire table (all 30 rows) - 
assert result.data.height == 30 - assert result.data['id'].to_list() == list(range(30)) + expected = conn.table(str(table.table.name)).pl() + assert result.data.equals(expected) def test_head_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: """Test head sampling from an empty table.""" @@ -90,10 +81,9 @@ def test_head_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: result = sampler.sample(table, conn, sample_size=10) - # Should return empty dataset - assert result.data.height == 0 - # But schema should still be correct - assert result.schema.names == ['id', 'value'] + # Should return empty dataset with correct schema + expected = conn.table(str(table.table.name)).pl() + assert result.data.equals(expected) class TestRandomStartTableSampler: @@ -132,9 +122,7 @@ def test_random_start_sampling_different_seeds(self, conn: DuckDBPyConnection) - result2 = sampler.sample(table, conn, sample_size=20, random_seed=999) # Different seeds should (very likely) produce different starting points - ids1 = result1.data['id'].to_list() - ids2 = result2.data['id'].to_list() - assert ids1[0] != ids2[0] or ids1 != ids2 + assert result1.data['id'][0] != result2.data['id'][0] def test_random_start_sampling_consecutive_rows(self, conn: DuckDBPyConnection) -> None: """Test that sampled rows are consecutive.""" @@ -156,8 +144,8 @@ def test_random_start_sampling_full_table(self, conn: DuckDBPyConnection) -> Non result = sampler.sample(table, conn, sample_size=50, random_seed=42) # When sample size equals table size, starting point should be 0 - assert result.data.height == 50 - assert result.data['id'].to_list() == list(range(50)) + expected = conn.table(str(table.table.name)).pl() + assert result.data.equals(expected) def test_random_start_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: """Test random start sampling when sample size exceeds table size.""" @@ -167,10 +155,8 @@ def test_random_start_sampling_larger_than_table(self, conn: 
DuckDBPyConnection) result = sampler.sample(table, conn, sample_size=50, random_seed=42) # Should return entire table - assert result.data.height == 30 - assert result.data['id'].to_list() == list(range(30)) - # Schema should be correct - assert result.schema.names == ['id', 'value'] + expected = conn.table(str(table.table.name)).pl() + assert result.data.equals(expected) def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: """Test random start sampling with empty table.""" @@ -179,10 +165,9 @@ def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> No result = sampler.sample(table, conn, sample_size=10, random_seed=42) - # Should return empty dataset - assert result.data.height == 0 - # But schema should still be correct - assert result.schema.names == ['id', 'value'] + # Should return empty dataset with correct schema + expected = conn.table(str(table.table.name)).pl() + assert result.data.equals(expected) def test_random_start_sampling_near_end(self, conn: DuckDBPyConnection) -> None: """Test that samples near the end of the table work correctly.""" From 674d3bb2822519b27f59a6edd5d711859e634141 Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Wed, 14 Jan 2026 11:34:42 +0200 Subject: [PATCH 5/6] update TableSampler interface to accept table names --- agentune/core/sampler/base.py | 8 +-- agentune/core/sampler/table_samples.py | 24 ++++--- .../core/sampler/test_table_samplers.py | 72 +++++++++---------- 3 files changed, 53 insertions(+), 51 deletions(-) diff --git a/agentune/core/sampler/base.py b/agentune/core/sampler/base.py index 02606faf..1812dc8b 100644 --- a/agentune/core/sampler/base.py +++ b/agentune/core/sampler/base.py @@ -9,7 +9,7 @@ import attrs from duckdb import DuckDBPyConnection -from agentune.analyze.join.base import TableWithJoinStrategies +from agentune.core.database import DuckdbName from agentune.core.dataset import Dataset @@ -31,10 +31,10 @@ def _validate_inputs(self, dataset: Dataset, 
sample_size: int) -> None: @attrs.define class TableSampler(ABC): - """Abstract base class for data sampling from TableWithJoinStrategies.""" + """Abstract base class for data sampling from tables.""" @abstractmethod - def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: - """Sample data from a given table with join strategies.""" + def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: + """Sample data from a given table.""" @attrs.define diff --git a/agentune/core/sampler/table_samples.py b/agentune/core/sampler/table_samples.py index 2f0f88ae..ebdbb46d 100644 --- a/agentune/core/sampler/table_samples.py +++ b/agentune/core/sampler/table_samples.py @@ -5,7 +5,7 @@ import attrs from duckdb import DuckDBPyConnection -from agentune.analyze.join.base import TableWithJoinStrategies +from agentune.core.database import DuckdbName from agentune.core.dataset import Dataset, duckdb_to_dataset from agentune.core.sampler.base import TableSampler @@ -19,11 +19,11 @@ class HeadTableSampler(TableSampler): """ @override - def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: + def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: """Sample the first N rows from the table. 
Args: - table: The table with join strategies to sample from + table_name: The name of the table to sample from conn: The DuckDB connection sample_size: Number of rows to sample (limit) random_seed: Not used in this sampler (kept for interface compatibility) @@ -31,9 +31,11 @@ def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sampl Returns: Dataset containing the first sample_size rows """ - table_name = str(table.table.name) + if isinstance(table_name, str): + table_name = DuckdbName.qualify(table_name, conn) + end_rowid = sample_size - 1 - sql_query = f'SELECT * FROM {table_name} WHERE rowid BETWEEN 0 AND {end_rowid} ORDER BY rowid' + sql_query = f'SELECT * FROM {table_name!s} WHERE rowid BETWEEN 0 AND {end_rowid} ORDER BY rowid' relation = conn.sql(sql_query) return duckdb_to_dataset(relation) @@ -49,11 +51,11 @@ class RandomStartTableSampler(TableSampler): """ @override - def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: + def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: """Sample consecutive rows from a random starting point in the table. 
Args: - table: The table with join strategies to sample from + table_name: The name of the table to sample from conn: The DuckDB connection sample_size: Number of consecutive rows to sample random_seed: Random seed for selecting the starting point (for reproducibility) @@ -61,10 +63,10 @@ def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sampl Returns: Dataset containing sample_size consecutive rows starting from a random position """ - table_name = str(table.table.name) - + if isinstance(table_name, str): + table_name = DuckdbName.qualify(table_name, conn) # Get table size - table_size = len(conn.table(table_name)) + table_size = len(conn.table(str(table_name))) # Adjust sample size if it exceeds table size sample_size = min(sample_size, table_size) @@ -76,7 +78,7 @@ def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sampl # Select consecutive rows starting from the random rowid # Using DuckDB's built-in rowid pseudocolumn for deterministic and efficient filtering - sql_query = f'SELECT * FROM {table_name} WHERE rowid BETWEEN {start_rowid} AND {end_rowid} ORDER BY rowid' + sql_query = f'SELECT * FROM {table_name!s} WHERE rowid BETWEEN {start_rowid} AND {end_rowid} ORDER BY rowid' relation = conn.sql(sql_query) return duckdb_to_dataset(relation) diff --git a/tests/agentune/core/sampler/test_table_samplers.py b/tests/agentune/core/sampler/test_table_samplers.py index 178c03ed..3218b078 100644 --- a/tests/agentune/core/sampler/test_table_samplers.py +++ b/tests/agentune/core/sampler/test_table_samplers.py @@ -34,55 +34,55 @@ class TestHeadTableSampler: def test_basic_head_sampling(self, conn: DuckDBPyConnection) -> None: """Test basic head sampling functionality.""" sampler = HeadTableSampler() - table = create_test_table(conn, 'test_table', 100) + table_name = create_test_table(conn, 'test_table', 100).table.name # Sample first 20 rows - result = sampler.sample(table, conn, sample_size=20) + result = 
sampler.sample(table_name, conn, sample_size=20) # Validate result - expected = conn.table(str(table.table.name)).pl().head(20) + expected = conn.table(str(table_name)).pl().head(20) assert result.data.equals(expected) def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None: """Test head sampling when sample size equals table size.""" sampler = HeadTableSampler() - table = create_test_table(conn, 'test_table_full', 50) + table_name = create_test_table(conn, 'test_table_full', 50).table.name - result = sampler.sample(table, conn, sample_size=50) + result = sampler.sample(table_name, conn, sample_size=50) - expected = conn.table(str(table.table.name)).pl() + expected = conn.table(str(table_name)).pl() assert result.data.equals(expected) def test_head_sampling_ignores_random_seed(self, conn: DuckDBPyConnection) -> None: """Test that head sampler produces same results regardless of seed.""" sampler = HeadTableSampler() - table = create_test_table(conn, 'test_table_seed', 50) + table_name = create_test_table(conn, 'test_table_seed', 50).table.name - result1 = sampler.sample(table, conn, sample_size=10, random_seed=42) - result2 = sampler.sample(table, conn, sample_size=10, random_seed=999) + result1 = sampler.sample(table_name, conn, sample_size=10, random_seed=42) + result2 = sampler.sample(table_name, conn, sample_size=10, random_seed=999) assert result1.data.equals(result2.data) def test_head_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: """Test head sampling when sample size exceeds table size.""" sampler = HeadTableSampler() - table = create_test_table(conn, 'test_head_oversample', 30) + table_name = create_test_table(conn, 'test_head_oversample', 30).table.name - result = sampler.sample(table, conn, sample_size=50) + result = sampler.sample(table_name, conn, sample_size=50) # Should return entire table (all 30 rows) - expected = conn.table(str(table.table.name)).pl() + expected = conn.table(str(table_name)).pl() assert 
result.data.equals(expected) def test_head_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: """Test head sampling from an empty table.""" sampler = HeadTableSampler() - table = create_test_table(conn, 'test_head_empty', 0) + table_name = create_test_table(conn, 'test_head_empty', 0).table.name - result = sampler.sample(table, conn, sample_size=10) + result = sampler.sample(table_name, conn, sample_size=10) # Should return empty dataset with correct schema - expected = conn.table(str(table.table.name)).pl() + expected = conn.table(str(table_name)).pl() assert result.data.equals(expected) @@ -92,9 +92,9 @@ class TestRandomStartTableSampler: def test_basic_random_start_sampling(self, conn: DuckDBPyConnection) -> None: """Test basic random start sampling functionality.""" sampler = RandomStartTableSampler() - table = create_test_table(conn, 'test_random_table', 100) + table_name = create_test_table(conn, 'test_random_table', 100).table.name - result = sampler.sample(table, conn, sample_size=20, random_seed=42) + result = sampler.sample(table_name, conn, sample_size=20, random_seed=42) # Validate result assert result.data.height == 20 @@ -106,20 +106,20 @@ def test_basic_random_start_sampling(self, conn: DuckDBPyConnection) -> None: def test_random_start_sampling_reproducibility(self, conn: DuckDBPyConnection) -> None: """Test that random start sampling is reproducible with same seed.""" sampler = RandomStartTableSampler() - table = create_test_table(conn, 'test_reproducible', 100) + table_name = create_test_table(conn, 'test_reproducible', 100).table.name - result1 = sampler.sample(table, conn, sample_size=15, random_seed=123) - result2 = sampler.sample(table, conn, sample_size=15, random_seed=123) + result1 = sampler.sample(table_name, conn, sample_size=15, random_seed=123) + result2 = sampler.sample(table_name, conn, sample_size=15, random_seed=123) assert result1.data.equals(result2.data) def test_random_start_sampling_different_seeds(self, conn: 
DuckDBPyConnection) -> None: """Test that different seeds produce different starting points.""" sampler = RandomStartTableSampler() - table = create_test_table(conn, 'test_different_seeds', 100) + table_name = create_test_table(conn, 'test_different_seeds', 100).table.name - result1 = sampler.sample(table, conn, sample_size=20, random_seed=42) - result2 = sampler.sample(table, conn, sample_size=20, random_seed=999) + result1 = sampler.sample(table_name, conn, sample_size=20, random_seed=42) + result2 = sampler.sample(table_name, conn, sample_size=20, random_seed=999) # Different seeds should (very likely) produce different starting points assert result1.data['id'][0] != result2.data['id'][0] @@ -127,9 +127,9 @@ def test_random_start_sampling_different_seeds(self, conn: DuckDBPyConnection) - def test_random_start_sampling_consecutive_rows(self, conn: DuckDBPyConnection) -> None: """Test that sampled rows are consecutive.""" sampler = RandomStartTableSampler() - table = create_test_table(conn, 'test_consecutive', 200) + table_name = create_test_table(conn, 'test_consecutive', 200).table.name - result = sampler.sample(table, conn, sample_size=30, random_seed=42) + result = sampler.sample(table_name, conn, sample_size=30, random_seed=42) ids = result.data['id'].to_list() # Verify consecutiveness @@ -139,44 +139,44 @@ def test_random_start_sampling_consecutive_rows(self, conn: DuckDBPyConnection) def test_random_start_sampling_full_table(self, conn: DuckDBPyConnection) -> None: """Test random start sampling when sample size equals table size.""" sampler = RandomStartTableSampler() - table = create_test_table(conn, 'test_full_random', 50) + table_name = create_test_table(conn, 'test_full_random', 50).table.name - result = sampler.sample(table, conn, sample_size=50, random_seed=42) + result = sampler.sample(table_name, conn, sample_size=50, random_seed=42) # When sample size equals table size, starting point should be 0 - expected = conn.table(str(table.table.name)).pl() 
+ expected = conn.table(str(table_name)).pl() assert result.data.equals(expected) def test_random_start_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: """Test random start sampling when sample size exceeds table size.""" sampler = RandomStartTableSampler() - table = create_test_table(conn, 'test_oversample', 30) + table_name = create_test_table(conn, 'test_oversample', 30).table.name - result = sampler.sample(table, conn, sample_size=50, random_seed=42) + result = sampler.sample(table_name, conn, sample_size=50, random_seed=42) # Should return entire table - expected = conn.table(str(table.table.name)).pl() + expected = conn.table(str(table_name)).pl() assert result.data.equals(expected) def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: """Test random start sampling with empty table.""" sampler = RandomStartTableSampler() - table = create_test_table(conn, 'test_empty', 0) + table_name = create_test_table(conn, 'test_empty', 0).table.name - result = sampler.sample(table, conn, sample_size=10, random_seed=42) + result = sampler.sample(table_name, conn, sample_size=10, random_seed=42) # Should return empty dataset with correct schema - expected = conn.table(str(table.table.name)).pl() + expected = conn.table(str(table_name)).pl() assert result.data.equals(expected) def test_random_start_sampling_near_end(self, conn: DuckDBPyConnection) -> None: """Test that samples near the end of the table work correctly.""" sampler = RandomStartTableSampler() - table = create_test_table(conn, 'test_near_end', 100) + table_name = create_test_table(conn, 'test_near_end', 100).table.name # Force a start near the end by setting specific seed # This tests that we correctly handle the range - result = sampler.sample(table, conn, sample_size=10, random_seed=42) + result = sampler.sample(table_name, conn, sample_size=10, random_seed=42) ids = result.data['id'].to_list() # Should still get 10 consecutive rows From 
a1a9ec2eee51c2e788a25103f20b4013b50bb8ab Mon Sep 17 00:00:00 2001 From: yotam amar~ Date: Wed, 14 Jan 2026 13:16:38 +0200 Subject: [PATCH 6/6] refactor: simplify create_test_table function and update test cases --- .../core/sampler/test_table_samplers.py | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/tests/agentune/core/sampler/test_table_samplers.py b/tests/agentune/core/sampler/test_table_samplers.py index 3218b078..cb895c3a 100644 --- a/tests/agentune/core/sampler/test_table_samplers.py +++ b/tests/agentune/core/sampler/test_table_samplers.py @@ -3,13 +3,11 @@ import polars as pl from duckdb import DuckDBPyConnection -from agentune.analyze.join.base import TableWithJoinStrategies -from agentune.core.database import DuckdbTable from agentune.core.dataset import Dataset, DatasetSink from agentune.core.sampler.table_samples import HeadTableSampler, RandomStartTableSampler -def create_test_table(conn: DuckDBPyConnection, table_name: str, num_rows: int) -> TableWithJoinStrategies: +def create_test_table(conn: DuckDBPyConnection, table_name: str, num_rows: int) -> None: """Helper to create test tables in DuckDB.""" # Create test data with explicit schema to avoid Null type issues with empty DataFrames data = pl.DataFrame({ @@ -20,12 +18,6 @@ def create_test_table(conn: DuckDBPyConnection, table_name: str, num_rows: int) # Create actual table in DuckDB (not just register) DatasetSink.into_unqualified_duckdb_table(table_name, conn).write(Dataset.from_polars(data).as_source(), conn) - - # get DuckdbTable representation from conn - duckdb_table = DuckdbTable.from_duckdb(table_name, conn) - - # Return TableWithJoinStrategies - return TableWithJoinStrategies(table=duckdb_table, join_strategies={}) class TestHeadTableSampler: @@ -34,7 +26,8 @@ class TestHeadTableSampler: def test_basic_head_sampling(self, conn: DuckDBPyConnection) -> None: """Test basic head sampling functionality.""" sampler = HeadTableSampler() - table_name = 
create_test_table(conn, 'test_table', 100).table.name + table_name = 'test_table' + create_test_table(conn, table_name, 100) # Sample first 20 rows result = sampler.sample(table_name, conn, sample_size=20) @@ -46,7 +39,8 @@ def test_basic_head_sampling(self, conn: DuckDBPyConnection) -> None: def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None: """Test head sampling when sample size equals table size.""" sampler = HeadTableSampler() - table_name = create_test_table(conn, 'test_table_full', 50).table.name + table_name = 'test_table_full' + create_test_table(conn, table_name, 50) result = sampler.sample(table_name, conn, sample_size=50) @@ -56,7 +50,8 @@ def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None: def test_head_sampling_ignores_random_seed(self, conn: DuckDBPyConnection) -> None: """Test that head sampler produces same results regardless of seed.""" sampler = HeadTableSampler() - table_name = create_test_table(conn, 'test_table_seed', 50).table.name + table_name = 'test_table_seed' + create_test_table(conn, table_name, 50) result1 = sampler.sample(table_name, conn, sample_size=10, random_seed=42) result2 = sampler.sample(table_name, conn, sample_size=10, random_seed=999) @@ -66,7 +61,8 @@ def test_head_sampling_ignores_random_seed(self, conn: DuckDBPyConnection) -> No def test_head_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: """Test head sampling when sample size exceeds table size.""" sampler = HeadTableSampler() - table_name = create_test_table(conn, 'test_head_oversample', 30).table.name + table_name = 'test_head_oversample' + create_test_table(conn, table_name, 30) result = sampler.sample(table_name, conn, sample_size=50) @@ -77,7 +73,8 @@ def test_head_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None def test_head_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: """Test head sampling from an empty table.""" sampler = HeadTableSampler() - table_name = 
create_test_table(conn, 'test_head_empty', 0).table.name + table_name = 'test_head_empty' + create_test_table(conn, table_name, 0) result = sampler.sample(table_name, conn, sample_size=10) @@ -92,7 +89,8 @@ class TestRandomStartTableSampler: def test_basic_random_start_sampling(self, conn: DuckDBPyConnection) -> None: """Test basic random start sampling functionality.""" sampler = RandomStartTableSampler() - table_name = create_test_table(conn, 'test_random_table', 100).table.name + table_name = 'test_random_table' + create_test_table(conn, table_name, 100) result = sampler.sample(table_name, conn, sample_size=20, random_seed=42) @@ -106,7 +104,8 @@ def test_basic_random_start_sampling(self, conn: DuckDBPyConnection) -> None: def test_random_start_sampling_reproducibility(self, conn: DuckDBPyConnection) -> None: """Test that random start sampling is reproducible with same seed.""" sampler = RandomStartTableSampler() - table_name = create_test_table(conn, 'test_reproducible', 100).table.name + table_name = 'test_reproducible' + create_test_table(conn, table_name, 100) result1 = sampler.sample(table_name, conn, sample_size=15, random_seed=123) result2 = sampler.sample(table_name, conn, sample_size=15, random_seed=123) @@ -116,7 +115,8 @@ def test_random_start_sampling_reproducibility(self, conn: DuckDBPyConnection) - def test_random_start_sampling_different_seeds(self, conn: DuckDBPyConnection) -> None: """Test that different seeds produce different starting points.""" sampler = RandomStartTableSampler() - table_name = create_test_table(conn, 'test_different_seeds', 100).table.name + table_name = 'test_different_seeds' + create_test_table(conn, table_name, 100) result1 = sampler.sample(table_name, conn, sample_size=20, random_seed=42) result2 = sampler.sample(table_name, conn, sample_size=20, random_seed=999) @@ -127,7 +127,8 @@ def test_random_start_sampling_different_seeds(self, conn: DuckDBPyConnection) - def test_random_start_sampling_consecutive_rows(self, conn: 
DuckDBPyConnection) -> None: """Test that sampled rows are consecutive.""" sampler = RandomStartTableSampler() - table_name = create_test_table(conn, 'test_consecutive', 200).table.name + table_name = 'test_consecutive_rows' + create_test_table(conn, table_name, 100) result = sampler.sample(table_name, conn, sample_size=30, random_seed=42) @@ -139,7 +140,8 @@ def test_random_start_sampling_consecutive_rows(self, conn: DuckDBPyConnection) def test_random_start_sampling_full_table(self, conn: DuckDBPyConnection) -> None: """Test random start sampling when sample size equals table size.""" sampler = RandomStartTableSampler() - table_name = create_test_table(conn, 'test_full_random', 50).table.name + table_name = 'test_full_random' + create_test_table(conn, table_name, 50) result = sampler.sample(table_name, conn, sample_size=50, random_seed=42) @@ -150,7 +152,8 @@ def test_random_start_sampling_full_table(self, conn: DuckDBPyConnection) -> Non def test_random_start_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: """Test random start sampling when sample size exceeds table size.""" sampler = RandomStartTableSampler() - table_name = create_test_table(conn, 'test_oversample', 30).table.name + table_name = 'test_oversize_random' + create_test_table(conn, table_name, 50) result = sampler.sample(table_name, conn, sample_size=50, random_seed=42) @@ -161,7 +164,8 @@ def test_random_start_sampling_larger_than_table(self, conn: DuckDBPyConnection) def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: """Test random start sampling with empty table.""" sampler = RandomStartTableSampler() - table_name = create_test_table(conn, 'test_empty', 0).table.name + table_name = 'test_empty_random' + create_test_table(conn, table_name, 0) result = sampler.sample(table_name, conn, sample_size=10, random_seed=42) @@ -172,7 +176,8 @@ def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> No def 
test_random_start_sampling_near_end(self, conn: DuckDBPyConnection) -> None: """Test that samples near the end of the table work correctly.""" sampler = RandomStartTableSampler() - table_name = create_test_table(conn, 'test_near_end', 100).table.name + table_name = 'test_near_end' + create_test_table(conn, table_name, 100) # Force a start near the end by setting specific seed # This tests that we correctly handle the range