Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions agentune/core/sampler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from typing import override

import attrs
from duckdb import DuckDBPyConnection

from agentune.core.database import DuckdbName
from agentune.core.dataset import Dataset


Expand All @@ -27,6 +29,14 @@ def _validate_inputs(self, dataset: Dataset, sample_size: int) -> None:
raise ValueError('Sample size must be positive')


@attrs.define
class TableSampler(ABC):
    """Abstract base class for data sampling from tables.

    Implementations sample directly from a DuckDB table via a live
    connection, rather than from an in-memory ``Dataset``.
    """

    @abstractmethod
    def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset:
        """Sample data from a given table.

        Args:
            table_name: Table to sample from; implementations qualify a plain
                string name against the connection.
            conn: Open DuckDB connection that can see the table.
            sample_size: Maximum number of rows to return.
            random_seed: Seed for reproducibility; deterministic samplers may
                ignore it (kept for interface compatibility).

        Returns:
            Dataset containing the sampled rows.
        """


@attrs.define
class RandomSampler(DataSampler):
"""Simple random sampling without any stratification.
Expand Down
84 changes: 84 additions & 0 deletions agentune/core/sampler/table_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

import random
from typing import override

import attrs
from duckdb import DuckDBPyConnection

from agentune.core.database import DuckdbName
from agentune.core.dataset import Dataset, duckdb_to_dataset
from agentune.core.sampler.base import TableSampler


@attrs.define
class HeadTableSampler(TableSampler):
    """Table sampler that returns the first N rows from a table.

    Rows are ordered by DuckDB's rowid pseudocolumn and fetched with LIMIT.
    Unlike filtering on absolute rowid values (``rowid BETWEEN 0 AND n-1``),
    LIMIT still returns exactly N rows when rowids are non-contiguous
    (e.g. after deletes have left gaps).
    """

    @override
    def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset:
        """Sample the first N rows from the table.

        Args:
            table_name: The name of the table to sample from
            conn: The DuckDB connection
            sample_size: Number of rows to sample; values <= 0 yield an
                empty result
            random_seed: Not used in this sampler (kept for interface compatibility)

        Returns:
            Dataset containing the first sample_size rows
        """
        if isinstance(table_name, str):
            table_name = DuckdbName.qualify(table_name, conn)

        # Clamp to zero so a non-positive sample_size produces an empty
        # result instead of an invalid negative LIMIT.
        limit = max(sample_size, 0)
        sql_query = f'SELECT * FROM {table_name!s} ORDER BY rowid LIMIT {limit}'

        relation = conn.sql(sql_query)
        return duckdb_to_dataset(relation)


@attrs.define
class RandomStartTableSampler(TableSampler):
    """Table sampler that returns consecutive rows starting from a random position.

    This sampler first determines the table size, then selects a random
    starting offset (using the provided seed for reproducibility), and returns
    consecutive rows from that point via LIMIT/OFFSET over a rowid ordering.
    Unlike filtering on absolute rowid values, LIMIT/OFFSET returns exactly
    sample_size rows even when rowids are non-contiguous (e.g. after deletes).
    """

    @override
    def sample(self, table_name: DuckdbName | str, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = 42) -> Dataset:
        """Sample consecutive rows from a random starting point in the table.

        Args:
            table_name: The name of the table to sample from
            conn: The DuckDB connection
            sample_size: Number of consecutive rows to sample; clamped to the
                table size, and values <= 0 yield an empty result
            random_seed: Random seed for selecting the starting point (for reproducibility)

        Returns:
            Dataset containing sample_size consecutive rows starting from a random position
        """
        if isinstance(table_name, str):
            table_name = DuckdbName.qualify(table_name, conn)

        # Get table size
        table_size = len(conn.table(str(table_name)))

        # Clamp the sample size into [0, table_size]; the lower bound guards
        # against an invalid negative LIMIT.
        sample_size = min(max(sample_size, 0), table_size)

        # Select random starting offset; same seed -> same offset, preserving
        # reproducibility across calls.
        rng = random.Random(random_seed)
        start_offset = rng.randint(0, max(0, table_size - sample_size))

        # LIMIT/OFFSET over a deterministic rowid ordering selects the
        # consecutive window regardless of gaps in rowid values.
        sql_query = (
            f'SELECT * FROM {table_name!s} ORDER BY rowid '
            f'LIMIT {sample_size} OFFSET {start_offset}'
        )

        relation = conn.sql(sql_query)
        return duckdb_to_dataset(relation)
191 changes: 191 additions & 0 deletions tests/agentune/core/sampler/test_table_samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Tests for table sampling utilities."""

import polars as pl
from duckdb import DuckDBPyConnection

from agentune.core.dataset import Dataset, DatasetSink
from agentune.core.sampler.table_samples import HeadTableSampler, RandomStartTableSampler


def create_test_table(conn: DuckDBPyConnection, table_name: str, num_rows: int) -> None:
    """Create a DuckDB test table with `id` (Int64) and `value` (String) columns."""
    # An explicit schema keeps dtypes stable even for zero-row frames,
    # where polars would otherwise infer the Null dtype.
    ids = list(range(num_rows))
    frame = pl.DataFrame(
        {'id': ids, 'value': [f'item_{i}' for i in ids]},
        schema={'id': pl.Int64, 'value': pl.String},
    )

    # Materialize a real DuckDB table (not merely a registered view).
    source = Dataset.from_polars(frame).as_source()
    DatasetSink.into_unqualified_duckdb_table(table_name, conn).write(source, conn)


class TestHeadTableSampler:
    """Tests for HeadTableSampler."""

    def test_basic_head_sampling(self, conn: DuckDBPyConnection) -> None:
        """Sampling 20 rows returns exactly the table's first 20 rows."""
        table_name = 'test_table'
        create_test_table(conn, table_name, 100)

        sampled = HeadTableSampler().sample(table_name, conn, sample_size=20)

        head_rows = conn.table(str(table_name)).pl().head(20)
        assert sampled.data.equals(head_rows)

    def test_head_sampling_full_table(self, conn: DuckDBPyConnection) -> None:
        """Sampling exactly table-size rows returns the whole table."""
        table_name = 'test_table_full'
        create_test_table(conn, table_name, 50)

        sampled = HeadTableSampler().sample(table_name, conn, sample_size=50)

        assert sampled.data.equals(conn.table(str(table_name)).pl())

    def test_head_sampling_ignores_random_seed(self, conn: DuckDBPyConnection) -> None:
        """The head sampler is deterministic regardless of the seed value."""
        table_name = 'test_table_seed'
        create_test_table(conn, table_name, 50)

        sampler = HeadTableSampler()
        with_seed_42 = sampler.sample(table_name, conn, sample_size=10, random_seed=42)
        with_seed_999 = sampler.sample(table_name, conn, sample_size=10, random_seed=999)

        assert with_seed_42.data.equals(with_seed_999.data)

    def test_head_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None:
        """An oversized sample size is clamped: the entire table is returned."""
        table_name = 'test_head_oversample'
        create_test_table(conn, table_name, 30)

        sampled = HeadTableSampler().sample(table_name, conn, sample_size=50)

        # All 30 existing rows come back, no more.
        assert sampled.data.equals(conn.table(str(table_name)).pl())

    def test_head_sampling_empty_table(self, conn: DuckDBPyConnection) -> None:
        """Sampling an empty table yields an empty dataset with the table's schema."""
        table_name = 'test_head_empty'
        create_test_table(conn, table_name, 0)

        sampled = HeadTableSampler().sample(table_name, conn, sample_size=10)

        assert sampled.data.equals(conn.table(str(table_name)).pl())


class TestRandomStartTableSampler:
    """Test RandomStartTableSampler functionality."""

    def test_basic_random_start_sampling(self, conn: DuckDBPyConnection) -> None:
        """Test basic random start sampling functionality."""
        sampler = RandomStartTableSampler()
        table_name = 'test_random_table'
        create_test_table(conn, table_name, 100)

        result = sampler.sample(table_name, conn, sample_size=20, random_seed=42)

        # Validate result
        assert result.data.height == 20
        assert result.schema.names == ['id', 'value']
        # Check that rows are consecutive
        ids = result.data['id'].to_list()
        assert ids == list(range(ids[0], ids[0] + 20))

    def test_random_start_sampling_reproducibility(self, conn: DuckDBPyConnection) -> None:
        """Test that random start sampling is reproducible with same seed."""
        sampler = RandomStartTableSampler()
        table_name = 'test_reproducible'
        create_test_table(conn, table_name, 100)

        result1 = sampler.sample(table_name, conn, sample_size=15, random_seed=123)
        result2 = sampler.sample(table_name, conn, sample_size=15, random_seed=123)

        assert result1.data.equals(result2.data)

    def test_random_start_sampling_different_seeds(self, conn: DuckDBPyConnection) -> None:
        """Test that different seeds produce different starting points."""
        sampler = RandomStartTableSampler()
        table_name = 'test_different_seeds'
        create_test_table(conn, table_name, 100)

        result1 = sampler.sample(table_name, conn, sample_size=20, random_seed=42)
        result2 = sampler.sample(table_name, conn, sample_size=20, random_seed=999)

        # Different seeds should (very likely) produce different starting points;
        # deterministic here because both seeds are fixed.
        assert result1.data['id'][0] != result2.data['id'][0]

    def test_random_start_sampling_consecutive_rows(self, conn: DuckDBPyConnection) -> None:
        """Test that sampled rows are consecutive."""
        sampler = RandomStartTableSampler()
        table_name = 'test_consecutive_rows'
        create_test_table(conn, table_name, 100)

        result = sampler.sample(table_name, conn, sample_size=30, random_seed=42)

        ids = result.data['id'].to_list()
        # Verify consecutiveness
        for i in range(len(ids) - 1):
            assert ids[i + 1] == ids[i] + 1

    def test_random_start_sampling_full_table(self, conn: DuckDBPyConnection) -> None:
        """Test random start sampling when sample size equals table size."""
        sampler = RandomStartTableSampler()
        table_name = 'test_full_random'
        create_test_table(conn, table_name, 50)

        result = sampler.sample(table_name, conn, sample_size=50, random_seed=42)

        # When sample size equals table size, starting point should be 0
        expected = conn.table(str(table_name)).pl()
        assert result.data.equals(expected)

    def test_random_start_sampling_larger_than_table(self, conn: DuckDBPyConnection) -> None:
        """Test random start sampling when sample size exceeds table size."""
        sampler = RandomStartTableSampler()
        table_name = 'test_oversize_random'
        create_test_table(conn, table_name, 50)

        # sample_size must actually exceed the 50-row table to exercise
        # the clamping path (50 would duplicate the full-table test).
        result = sampler.sample(table_name, conn, sample_size=80, random_seed=42)

        # Should return entire table
        expected = conn.table(str(table_name)).pl()
        assert result.data.equals(expected)

    def test_random_start_sampling_empty_table(self, conn: DuckDBPyConnection) -> None:
        """Test random start sampling with empty table."""
        sampler = RandomStartTableSampler()
        table_name = 'test_empty_random'
        create_test_table(conn, table_name, 0)

        result = sampler.sample(table_name, conn, sample_size=10, random_seed=42)

        # Should return empty dataset with correct schema
        expected = conn.table(str(table_name)).pl()
        assert result.data.equals(expected)

    def test_random_start_sampling_near_end(self, conn: DuckDBPyConnection) -> None:
        """Test that samples near the end of the table work correctly."""
        sampler = RandomStartTableSampler()
        table_name = 'test_near_end'
        # A 12-row table with sample_size=10 constrains the start offset to
        # [0, 2], so the sampled window necessarily reaches near the table end
        # regardless of which offset the seed produces.
        create_test_table(conn, table_name, 12)

        result = sampler.sample(table_name, conn, sample_size=10, random_seed=42)

        ids = result.data['id'].to_list()
        # Should still get 10 consecutive rows
        assert len(ids) == 10
        assert ids == list(range(ids[0], ids[0] + 10))
        # And they should be within the table bounds
        assert ids[-1] < 12