-
Notifications
You must be signed in to change notification settings - Fork 3
samplers: create HeadTableSampler and RandomStartTableSampler table samplers #151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
d553ff3
feat: implement HeadTableSampler and RandomStartTableSampler for tabl…
yotam319-sparkbeyond f9f968f
feat: update sampling queries to use rowid range
yotam319-sparkbeyond a81f555
feat: optimize table size retrieval in RandomStartTableSampler
yotam319-sparkbeyond f77cd2e
apply comments
yotam319-sparkbeyond 674d3bb
update TableSampler interface to accept table names
yotam319-sparkbeyond a1a9ec2
refactor: simplify create_test_table function and update test cases
yotam319-sparkbeyond File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
|
|
||
| import random | ||
| from typing import override | ||
|
|
||
| import attrs | ||
| from duckdb import DuckDBPyConnection | ||
|
|
||
| from agentune.core.database import DuckdbName | ||
| from agentune.core.dataset import Dataset, duckdb_to_dataset | ||
| from agentune.core.sampler.base import TableSampler | ||
|
|
||
|
|
||
@attrs.define
class HeadTableSampler(TableSampler):
    """Sampler that takes rows from the top of a table.

    The query filters on DuckDB's ``rowid`` pseudocolumn, keeping rowids in the
    inclusive range ``[0, sample_size - 1]``, and orders the result by ``rowid``.

    NOTE(review): this presumes rowids are contiguous and start at 0; a table
    with gaps in its rowids would yield fewer rows than requested -- confirm
    against the tables this sampler is used on.
    """

    @override
    def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset:
        """Return the leading ``sample_size`` rows of ``table_name``.

        Args:
            table_name: Table to read; a plain string is first qualified with ``DuckdbName.qualify``
            conn: The DuckDB connection the query runs on
            sample_size: Upper bound on the number of rows returned
            random_seed: Ignored by this sampler; accepted only for interface compatibility

        Returns:
            Dataset built from the rows whose rowid lies in ``[0, sample_size - 1]``
        """
        if isinstance(table_name, str):
            table_name = DuckdbName.qualify(table_name, conn)

        # BETWEEN is inclusive on both ends, so the last wanted rowid is sample_size - 1.
        end_rowid = sample_size - 1
        head_query = f'SELECT * FROM {table_name!s} WHERE rowid BETWEEN 0 AND {end_rowid} ORDER BY rowid'
        return duckdb_to_dataset(conn.sql(head_query))
|
|
||
|
|
||
@attrs.define
class RandomStartTableSampler(TableSampler):
    """Table sampler that returns consecutive rows starting from a random position.

    This sampler first determines the table size, then selects a random starting
    rowid (using the provided seed for reproducibility), and returns the rows whose
    DuckDB ``rowid`` falls in the inclusive range starting at that point, ordered
    by ``rowid``.

    NOTE(review): the start is drawn from ``[0, table_size - sample_size]``, which
    presumes rowids are contiguous and start at 0 -- confirm for tables that have
    had rows deleted.
    """

    @override
    def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset:
        """Sample consecutive rows from a random starting point in the table.

        Args:
            table_name: The name of the table to sample from; a plain string is
                qualified with ``DuckdbName.qualify`` first
            conn: The DuckDB connection
            sample_size: Number of consecutive rows to sample; clamped to the
                range ``[0, table size]``
            random_seed: Random seed for selecting the starting point (for
                reproducibility); ``None`` seeds the generator from system entropy

        Returns:
            Dataset containing up to sample_size consecutive rows starting from a random position
        """
        if isinstance(table_name, str):
            table_name = DuckdbName.qualify(table_name, conn)

        # Get table size
        table_size = len(conn.table(str(table_name)))

        # Clamp the request into [0, table_size]: an oversized request becomes the
        # whole table and a negative one becomes an empty sample, so the start
        # position below is always drawn from inside the table.
        sample_size = max(0, min(sample_size, table_size))

        # Select random starting point. table_size - sample_size is never negative
        # after the clamp above, so no extra guard is needed on the upper bound.
        rng = random.Random(random_seed)
        start_rowid = rng.randint(0, table_size - sample_size)
        # BETWEEN is inclusive; for an empty sample end_rowid < start_rowid and
        # the query matches no rows while still carrying the table's schema.
        end_rowid = start_rowid + sample_size - 1

        # Select consecutive rows starting from the random rowid
        # Using DuckDB's built-in rowid pseudocolumn for deterministic and efficient filtering
        sql_query = f'SELECT * FROM {table_name!s} WHERE rowid BETWEEN {start_rowid} AND {end_rowid} ORDER BY rowid'

        relation = conn.sql(sql_query)
        return duckdb_to_dataset(relation)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,191 @@ | ||
| """Tests for table sampling utilities.""" | ||
|
|
||
| import polars as pl | ||
| from duckdb import DuckDBPyConnection | ||
|
|
||
| from agentune.core.dataset import Dataset, DatasetSink | ||
| from agentune.core.sampler.table_samples import HeadTableSampler, RandomStartTableSampler | ||
|
|
||
|
|
||
def create_test_table(conn: DuckDBPyConnection, table_name: str, num_rows: int) -> None:
    """Create a DuckDB table named ``table_name`` holding ``num_rows`` rows of test data.

    The table has an ``id`` column (0 .. num_rows - 1) and a ``value`` column
    (``item_<id>``).
    """
    ids = list(range(num_rows))
    # The schema is spelled out so a zero-row frame still gets Int64/String
    # columns instead of Null-typed ones.
    frame = pl.DataFrame(
        {'id': ids, 'value': [f'item_{i}' for i in ids]},
        schema={'id': pl.Int64, 'value': pl.String},
    )

    # Write a real table into DuckDB rather than merely registering the frame.
    sink = DatasetSink.into_unqualified_duckdb_table(table_name, conn)
    sink.write(Dataset.from_polars(frame).as_source(), conn)
|
|
||
|
|
||
class TestHeadTableSampler:
    """Tests covering HeadTableSampler."""

    def test_basic_head_sampling(self, conn: DuckDBPyConnection) -> None:
        """The first 20 rows of a 100-row table are returned."""
        table_name = 'test_table'
        create_test_table(conn, table_name, 100)

        sampled = HeadTableSampler().sample(table_name, conn, sample_size=20)

        # The sample must match the first 20 rows of the stored table.
        first_rows = conn.table(table_name).pl().head(20)
        assert sampled.data.equals(first_rows)

    def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None:
        """Asking for exactly as many rows as the table holds returns the whole table."""
        table_name = 'test_table_full'
        create_test_table(conn, table_name, 50)

        sampled = HeadTableSampler().sample(table_name, conn, sample_size=50)

        whole_table = conn.table(table_name).pl()
        assert sampled.data.equals(whole_table)

    def test_head_sampling_ignores_random_seed(self, conn: DuckDBPyConnection) -> None:
        """Two different seeds give the same head sample."""
        table_name = 'test_table_seed'
        create_test_table(conn, table_name, 50)
        head_sampler = HeadTableSampler()

        with_seed_42 = head_sampler.sample(table_name, conn, sample_size=10, random_seed=42)
        with_seed_999 = head_sampler.sample(table_name, conn, sample_size=10, random_seed=999)

        assert with_seed_42.data.equals(with_seed_999.data)

    def test_head_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None:
        """A sample size above the row count returns every row."""
        table_name = 'test_head_oversample'
        create_test_table(conn, table_name, 30)

        sampled = HeadTableSampler().sample(table_name, conn, sample_size=50)

        # All 30 rows of the table are expected back.
        whole_table = conn.table(table_name).pl()
        assert sampled.data.equals(whole_table)

    def test_head_sampling_empty_table(self, conn: DuckDBPyConnection) -> None:
        """Sampling an empty table gives an empty result."""
        table_name = 'test_head_empty'
        create_test_table(conn, table_name, 0)

        sampled = HeadTableSampler().sample(table_name, conn, sample_size=10)

        # The result must equal the (empty) stored table, schema included.
        whole_table = conn.table(table_name).pl()
        assert sampled.data.equals(whole_table)
|
|
||
|
|
||
class TestRandomStartTableSampler:
    """Test RandomStartTableSampler functionality."""

    def test_basic_random_start_sampling(self, conn: DuckDBPyConnection) -> None:
        """Test basic random start sampling functionality."""
        sampler = RandomStartTableSampler()
        table_name = 'test_random_table'
        create_test_table(conn, table_name, 100)

        result = sampler.sample(table_name, conn, sample_size=20, random_seed=42)

        # Validate result
        assert result.data.height == 20
        assert result.schema.names == ['id', 'value']
        # Check that rows are consecutive
        ids = result.data['id'].to_list()
        assert ids == list(range(ids[0], ids[0] + 20))

    def test_random_start_sampling_reproducibility(self, conn: DuckDBPyConnection) -> None:
        """Test that random start sampling is reproducible with same seed."""
        sampler = RandomStartTableSampler()
        table_name = 'test_reproducible'
        create_test_table(conn, table_name, 100)

        result1 = sampler.sample(table_name, conn, sample_size=15, random_seed=123)
        result2 = sampler.sample(table_name, conn, sample_size=15, random_seed=123)

        assert result1.data.equals(result2.data)

    def test_random_start_sampling_different_seeds(self, conn: DuckDBPyConnection) -> None:
        """Test that different seeds produce different starting points."""
        sampler = RandomStartTableSampler()
        table_name = 'test_different_seeds'
        create_test_table(conn, table_name, 100)

        result1 = sampler.sample(table_name, conn, sample_size=20, random_seed=42)
        result2 = sampler.sample(table_name, conn, sample_size=20, random_seed=999)

        # Different seeds should (very likely) produce different starting points
        assert result1.data['id'][0] != result2.data['id'][0]

    def test_random_start_sampling_consecutive_rows(self, conn: DuckDBPyConnection) -> None:
        """Test that sampled rows are consecutive."""
        sampler = RandomStartTableSampler()
        table_name = 'test_consecutive_rows'
        create_test_table(conn, table_name, 100)

        result = sampler.sample(table_name, conn, sample_size=30, random_seed=42)

        ids = result.data['id'].to_list()
        assert len(ids) == 30
        # Verify consecutiveness: every id is exactly one more than its predecessor
        for previous_id, next_id in zip(ids, ids[1:], strict=False):
            assert next_id == previous_id + 1

    def test_random_start_sampling_full_table(self, conn: DuckDBPyConnection) -> None:
        """Test random start sampling when sample size equals table size."""
        sampler = RandomStartTableSampler()
        table_name = 'test_full_random'
        create_test_table(conn, table_name, 50)

        result = sampler.sample(table_name, conn, sample_size=50, random_seed=42)

        # When sample size equals table size, starting point should be 0
        expected = conn.table(table_name).pl()
        assert result.data.equals(expected)

    def test_random_start_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None:
        """Test random start sampling when sample size exceeds table size."""
        sampler = RandomStartTableSampler()
        table_name = 'test_oversize_random'
        create_test_table(conn, table_name, 50)

        # Request strictly more rows than the table holds so the clamping path is exercised.
        result = sampler.sample(table_name, conn, sample_size=100, random_seed=42)

        # Should return entire table
        expected = conn.table(table_name).pl()
        assert result.data.height == 50
        assert result.data.equals(expected)

    def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> None:
        """Test random start sampling with empty table."""
        sampler = RandomStartTableSampler()
        table_name = 'test_empty_random'
        create_test_table(conn, table_name, 0)

        result = sampler.sample(table_name, conn, sample_size=10, random_seed=42)

        # Should return empty dataset with correct schema
        expected = conn.table(table_name).pl()
        assert result.data.equals(expected)

    def test_random_start_sampling_near_end(self, conn: DuckDBPyConnection) -> None:
        """Test that samples stay within the table bounds wherever they start."""
        sampler = RandomStartTableSampler()
        table_name = 'test_near_end'
        create_test_table(conn, table_name, 100)

        # A single seed cannot be relied on to land near the end of the table, so
        # sweep several seeds and check the range handling for every start drawn.
        for seed in (42, *range(20)):
            result = sampler.sample(table_name, conn, sample_size=10, random_seed=seed)

            ids = result.data['id'].to_list()
            # Should still get 10 consecutive rows
            assert len(ids) == 10
            assert ids == list(range(ids[0], ids[0] + 10))
            # And they should be within the table bounds
            assert ids[0] >= 0
            assert ids[-1] < 100
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.