diff --git a/agentune/core/sampler/base.py b/agentune/core/sampler/base.py index 36a75871..1812dc8b 100644 --- a/agentune/core/sampler/base.py +++ b/agentune/core/sampler/base.py @@ -7,7 +7,9 @@ from typing import override import attrs +from duckdb import DuckDBPyConnection +from agentune.core.database import DuckdbName from agentune.core.dataset import Dataset @@ -27,6 +29,14 @@ def _validate_inputs(self, dataset: Dataset, sample_size: int) -> None: raise ValueError('Sample size must be positive') +@attrs.define +class TableSampler(ABC): + """Abstract base class for data sampling from tables.""" + @abstractmethod + def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: + """Sample data from a given table.""" + + @attrs.define class RandomSampler(DataSampler): """Simple random sampling without any stratification. diff --git a/agentune/core/sampler/table_samples.py b/agentune/core/sampler/table_samples.py new file mode 100644 index 00000000..ebdbb46d --- /dev/null +++ b/agentune/core/sampler/table_samples.py @@ -0,0 +1,84 @@ + +import random +from typing import override + +import attrs +from duckdb import DuckDBPyConnection + +from agentune.core.database import DuckdbName +from agentune.core.dataset import Dataset, duckdb_to_dataset +from agentune.core.sampler.base import TableSampler + + +@attrs.define +class HeadTableSampler(TableSampler): + """Table sampler that returns the first N rows from a table. + + This sampler executes a SELECT * query with a WHERE clause filtering by rowid + to retrieve the head of the table. Rows are ordered by DuckDB's rowid pseudocolumn. + """ + + @override + def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: + """Sample the first N rows from the table. 
+ + Args: + table_name: The name of the table to sample from + conn: The DuckDB connection + sample_size: Number of rows to sample (limit) + random_seed: Not used in this sampler (kept for interface compatibility) + + Returns: + Dataset containing the first sample_size rows + """ + if isinstance(table_name, str): + table_name = DuckdbName.qualify(table_name, conn) + + end_rowid = sample_size - 1 + sql_query = f'SELECT * FROM {table_name!s} WHERE rowid BETWEEN 0 AND {end_rowid} ORDER BY rowid' + + relation = conn.sql(sql_query) + return duckdb_to_dataset(relation) + + +@attrs.define +class RandomStartTableSampler(TableSampler): + """Table sampler that returns consecutive rows starting from a random position. + + This sampler first determines the table size, then selects a random starting + point (using the provided seed for reproducibility), and returns consecutive + rows from that point. Uses the rowid pseudocolumn to select the offset range. + """ + + @override + def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset: + """Sample consecutive rows from a random starting point in the table. 
+ + Args: + table_name: The name of the table to sample from + conn: The DuckDB connection + sample_size: Number of consecutive rows to sample + random_seed: Random seed for selecting the starting point (for reproducibility) + + Returns: + Dataset containing sample_size consecutive rows starting from a random position + """ + if isinstance(table_name, str): + table_name = DuckdbName.qualify(table_name, conn) + # Get table size + table_size = len(conn.table(str(table_name))) + + # Adjust sample size if it exceeds table size + sample_size = min(sample_size, table_size) + + # Select random starting point + rng = random.Random(random_seed) + start_rowid = rng.randint(0, max(0, table_size - sample_size)) + end_rowid = start_rowid + sample_size - 1 + + # Select consecutive rows starting from the random rowid + # Using DuckDB's built-in rowid pseudocolumn for deterministic and efficient filtering + sql_query = f'SELECT * FROM {table_name!s} WHERE rowid BETWEEN {start_rowid} AND {end_rowid} ORDER BY rowid' + + relation = conn.sql(sql_query) + return duckdb_to_dataset(relation) diff --git a/tests/agentune/core/sampler/test_table_samplers.py b/tests/agentune/core/sampler/test_table_samplers.py new file mode 100644 index 00000000..cb895c3a --- /dev/null +++ b/tests/agentune/core/sampler/test_table_samplers.py @@ -0,0 +1,191 @@ +"""Tests for table sampling utilities.""" + +import polars as pl +from duckdb import DuckDBPyConnection + +from agentune.core.dataset import Dataset, DatasetSink +from agentune.core.sampler.table_samples import HeadTableSampler, RandomStartTableSampler + + +def create_test_table(conn: DuckDBPyConnection, table_name: str, num_rows: int) -> None: + """Helper to create test tables in DuckDB.""" + # Create test data with explicit schema to avoid Null type issues with empty DataFrames + data = pl.DataFrame({ + 'id': list(range(num_rows)), + 'value': [f'item_{i}' for i in range(num_rows)],}, + schema={'id': pl.Int64, 'value': pl.String} + ) + + # Create 
actual table in DuckDB (not just register) + DatasetSink.into_unqualified_duckdb_table(table_name, conn).write(Dataset.from_polars(data).as_source(), conn) + + +class TestHeadTableSampler: + """Test HeadTableSampler functionality.""" + + def test_basic_head_sampling(self, conn: DuckDBPyConnection) -> None: + """Test basic head sampling functionality.""" + sampler = HeadTableSampler() + table_name = 'test_table' + create_test_table(conn, table_name, 100) + + # Sample first 20 rows + result = sampler.sample(table_name, conn, sample_size=20) + + # Validate result + expected = conn.table(str(table_name)).pl().head(20) + assert result.data.equals(expected) + + def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None: + """Test head sampling when sample size equals table size.""" + sampler = HeadTableSampler() + table_name = 'test_table_full' + create_test_table(conn, table_name, 50) + + result = sampler.sample(table_name, conn, sample_size=50) + + expected = conn.table(str(table_name)).pl() + assert result.data.equals(expected) + + def test_head_sampling_ignores_random_seed(self, conn: DuckDBPyConnection) -> None: + """Test that head sampler produces same results regardless of seed.""" + sampler = HeadTableSampler() + table_name = 'test_table_seed' + create_test_table(conn, table_name, 50) + + result1 = sampler.sample(table_name, conn, sample_size=10, random_seed=42) + result2 = sampler.sample(table_name, conn, sample_size=10, random_seed=999) + + assert result1.data.equals(result2.data) + + def test_head_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: + """Test head sampling when sample size exceeds table size.""" + sampler = HeadTableSampler() + table_name = 'test_head_oversample' + create_test_table(conn, table_name, 30) + + result = sampler.sample(table_name, conn, sample_size=50) + + # Should return entire table (all 30 rows) + expected = conn.table(str(table_name)).pl() + assert result.data.equals(expected) + + def 
test_head_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: + """Test head sampling from an empty table.""" + sampler = HeadTableSampler() + table_name = 'test_head_empty' + create_test_table(conn, table_name, 0) + + result = sampler.sample(table_name, conn, sample_size=10) + + # Should return empty dataset with correct schema + expected = conn.table(str(table_name)).pl() + assert result.data.equals(expected) + + +class TestRandomStartTableSampler: + """Test RandomStartTableSampler functionality.""" + + def test_basic_random_start_sampling(self, conn: DuckDBPyConnection) -> None: + """Test basic random start sampling functionality.""" + sampler = RandomStartTableSampler() + table_name = 'test_random_table' + create_test_table(conn, table_name, 100) + + result = sampler.sample(table_name, conn, sample_size=20, random_seed=42) + + # Validate result + assert result.data.height == 20 + assert result.schema.names == ['id', 'value'] + # Check that rows are consecutive + ids = result.data['id'].to_list() + assert ids == list(range(ids[0], ids[0] + 20)) + + def test_random_start_sampling_reproducibility(self, conn: DuckDBPyConnection) -> None: + """Test that random start sampling is reproducible with same seed.""" + sampler = RandomStartTableSampler() + table_name = 'test_reproducible' + create_test_table(conn, table_name, 100) + + result1 = sampler.sample(table_name, conn, sample_size=15, random_seed=123) + result2 = sampler.sample(table_name, conn, sample_size=15, random_seed=123) + + assert result1.data.equals(result2.data) + + def test_random_start_sampling_different_seeds(self, conn: DuckDBPyConnection) -> None: + """Test that different seeds produce different starting points.""" + sampler = RandomStartTableSampler() + table_name = 'test_different_seeds' + create_test_table(conn, table_name, 100) + + result1 = sampler.sample(table_name, conn, sample_size=20, random_seed=42) + result2 = sampler.sample(table_name, conn, sample_size=20, random_seed=999) + + # 
Different seeds should (very likely) produce different starting points + assert result1.data['id'][0] != result2.data['id'][0] + + def test_random_start_sampling_consecutive_rows(self, conn: DuckDBPyConnection) -> None: + """Test that sampled rows are consecutive.""" + sampler = RandomStartTableSampler() + table_name = 'test_consecutive_rows' + create_test_table(conn, table_name, 100) + + result = sampler.sample(table_name, conn, sample_size=30, random_seed=42) + + ids = result.data['id'].to_list() + # Verify consecutiveness + for i in range(len(ids) - 1): + assert ids[i + 1] == ids[i] + 1 + + def test_random_start_sampling_full_table(self, conn: DuckDBPyConnection) -> None: + """Test random start sampling when sample size equals table size.""" + sampler = RandomStartTableSampler() + table_name = 'test_full_random' + create_test_table(conn, table_name, 50) + + result = sampler.sample(table_name, conn, sample_size=50, random_seed=42) + + # When sample size equals table size, starting point should be 0 + expected = conn.table(str(table_name)).pl() + assert result.data.equals(expected) + + def test_random_start_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None: + """Test random start sampling when sample size exceeds table size.""" + sampler = RandomStartTableSampler() + table_name = 'test_oversize_random' + create_test_table(conn, table_name, 50) + + result = sampler.sample(table_name, conn, sample_size=80, random_seed=42) + + # Should return entire table + expected = conn.table(str(table_name)).pl() + assert result.data.equals(expected) + + def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> None: + """Test random start sampling with empty table.""" + sampler = RandomStartTableSampler() + table_name = 'test_empty_random' + create_test_table(conn, table_name, 0) + + result = sampler.sample(table_name, conn, sample_size=10, random_seed=42) + + # Should return empty dataset with correct schema + expected = 
conn.table(str(table_name)).pl() + assert result.data.equals(expected) + + def test_random_start_sampling_near_end(self, conn: DuckDBPyConnection) -> None: + """Test that samples near the end of the table work correctly.""" + sampler = RandomStartTableSampler() + table_name = 'test_near_end' + create_test_table(conn, table_name, 100) + + # Force a start near the end by setting specific seed + # This tests that we correctly handle the range + result = sampler.sample(table_name, conn, sample_size=10, random_seed=42) + + ids = result.data['id'].to_list() + # Should still get 10 consecutive rows + assert len(ids) == 10 + assert ids == list(range(ids[0], ids[0] + 10)) + # And they should be within the table bounds + assert ids[-1] < 100