Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion src/synthetic/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
from llm_client import create_llm_client


# Maximum rows to request per individual LLM call.
# Requests above this threshold are automatically split into sub-batches.
# Production testing shows ≤100 rows generates reliably; larger counts cause
# the model to return fewer rows than requested (typically only ~70-75%).
# NOTE: generate_batch's sub-batching branch keys off this value — keep the
# tests that assert the 100-row threshold in sync if it ever changes.
_LLM_MAX_ROWS_PER_CALL = 100


class SyntheticDataGenerator:
"""Generates synthetic data using LLM."""

Expand Down Expand Up @@ -344,10 +351,37 @@ def generate_batch(
logs.append("ERROR: No LLM configured, using mock data")
return self._generate_mock_data(schema, num_rows), logs

# For large requests, split into smaller LLM calls for reliability.
# Production testing shows ≤100 rows generates reliably; larger counts
# cause the model to return only ~70-75% of requested rows, wasting
# API quota on retries.
if num_rows > _LLM_MAX_ROWS_PER_CALL:
all_records = []
offset = 0
while len(all_records) < num_rows:
remaining = num_rows - len(all_records)
sub_size = min(_LLM_MAX_ROWS_PER_CALL, remaining)
sub_seed = seed + offset if seed is not None else None
sub_records, sub_logs = self.generate_batch(
schema=schema,
num_rows=sub_size,
locale=locale,
seed=sub_seed,
max_retries=max_retries,
)
logs.extend(sub_logs)
all_records.extend(sub_records)
offset += sub_size
return all_records[:num_rows], logs

# Build prompt
prompt = self._build_prompt(schema, num_rows, locale, seed)
logs.append(f"Generated prompt for {num_rows} rows")

# Track the best partial result across attempts so we can fall back
# gracefully when all retries are exhausted.
best_rows = []

# Try to generate data with retries
for attempt in range(max_retries):
try:
Expand All @@ -367,6 +401,10 @@ def generate_batch(
rows = self._parse_csv_response(response_text, num_columns)
logs.append(f"Parsed {len(rows)} rows from CSV")

# Keep the best partial result in case all retries fail
if len(rows) > len(best_rows):
best_rows = rows

if len(rows) < num_rows * 0.8: # At least 80% of requested rows
logs.append(f"WARNING: Only got {len(rows)}/{num_rows} rows, retrying...")
continue
Expand All @@ -388,7 +426,21 @@ def generate_batch(
return self._generate_mock_data(schema, num_rows), logs
time.sleep(1) # Brief delay before retry

return [], logs
# All retries exhausted via the insufficient-rows path (no exception).
# Use the best partial LLM result and fill any gap with mock data so
# callers always receive the requested number of rows.
logs.append(
f"Max retries reached with only {len(best_rows)}/{num_rows} rows, "
"filling remainder with mock data"
)
if best_rows:
records = self._coerce_types(best_rows, schema)
records = self._enforce_uniqueness(records, schema)
if len(records) < num_rows:
mock_fill = self._generate_mock_data(schema, num_rows - len(records))
records.extend(mock_fill)
return records[:num_rows], logs
return self._generate_mock_data(schema, num_rows), logs

def _generate_random_date(
self, start_str: str, end_str: str, include_time: bool = False
Expand Down
133 changes: 132 additions & 1 deletion tests/backend/api/test_synthetic_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import re
import pytest
from unittest.mock import MagicMock, patch
from src.synthetic.validators import (
validate_schema,
validate_generate_request,
validate_preview_request,
)
from src.synthetic.generator import SyntheticDataGenerator
from src.synthetic.generator import SyntheticDataGenerator, _LLM_MAX_ROWS_PER_CALL


class TestValidators:
Expand Down Expand Up @@ -312,5 +313,135 @@ def test_generate_mock_data_with_swapped_dates(self):
assert datetime_pattern.match(record["created_at"]), f"Invalid datetime format: {record['created_at']}"


class TestGenerateBatchLargeRowHandling:
    """Tests for generate_batch behaviour with large row counts and LLM failures."""

    # Minimal two-column schema shared by all tests in this class.
    SCHEMA = {
        "columns": [
            {"name": "id", "type": "integer", "options": {"min": 1, "max": 1000000}},
            {"name": "name", "type": "string", "options": {}},
        ]
    }

    def _make_generator_with_mock_llm(self, llm_side_effect=None, llm_return_value=None):
        """Return a SyntheticDataGenerator with a mocked LLM client.

        Uses ``__new__`` to bypass ``__init__`` so no real API key, client
        construction, or network access is required.
        """
        generator = SyntheticDataGenerator.__new__(SyntheticDataGenerator)
        generator.api_key = "fake-key"
        generator.model = "fake-model"
        generator._llm_available = True
        mock_client = MagicMock()
        if llm_side_effect is not None:
            mock_client.generate.side_effect = llm_side_effect
        elif llm_return_value is not None:
            mock_client.generate.return_value = llm_return_value
        generator.llm_client = mock_client
        return generator

    def _csv_for_n_rows(self, n, start=1):
        """Generate a valid CSV string for n rows (id, name)."""
        return "\n".join(f"{start + i},name_{start + i}" for i in range(n))

    # ------------------------------------------------------------------
    # Sub-batching tests
    # ------------------------------------------------------------------

    def test_llm_max_rows_per_call_constant_is_100(self):
        """_LLM_MAX_ROWS_PER_CALL should be 100 so ≤100-row requests stay
        in a single LLM call while larger requests are split."""
        assert _LLM_MAX_ROWS_PER_CALL == 100

    def test_large_request_is_split_into_sub_batches(self):
        """Requesting >100 rows must trigger sub-batching; each sub-call
        should receive at most _LLM_MAX_ROWS_PER_CALL rows."""
        call_sizes = []

        def capture_prompt(messages, **kwargs):
            content = messages[0]["content"]
            # Extract the requested row count from "Generate X rows…" using the
            # module-level `re` import (no need to re-import it locally), and
            # parse the group exactly once.
            m = re.search(r"Generate (\d+) rows", content)
            n = int(m.group(1)) if m else 1
            if m:
                call_sizes.append(n)
            # Return exactly as many rows as the prompt requested.
            return self._csv_for_n_rows(n)

        generator = self._make_generator_with_mock_llm(llm_side_effect=capture_prompt)
        records, logs = generator.generate_batch(self.SCHEMA, num_rows=200)

        assert len(records) == 200
        # Each individual LLM call must be ≤ _LLM_MAX_ROWS_PER_CALL
        assert all(n <= _LLM_MAX_ROWS_PER_CALL for n in call_sizes), (
            f"Some LLM calls requested more than {_LLM_MAX_ROWS_PER_CALL} rows: {call_sizes}"
        )

    def test_exact_threshold_uses_single_call(self):
        """Requesting exactly _LLM_MAX_ROWS_PER_CALL rows must NOT trigger
        sub-batching (one LLM call)."""
        call_count = [0]  # list cell so the closure can mutate it

        def counting_side_effect(messages, **kwargs):
            call_count[0] += 1
            return self._csv_for_n_rows(_LLM_MAX_ROWS_PER_CALL)

        generator = self._make_generator_with_mock_llm(llm_side_effect=counting_side_effect)
        records, _ = generator.generate_batch(self.SCHEMA, num_rows=_LLM_MAX_ROWS_PER_CALL)

        assert call_count[0] == 1
        assert len(records) == _LLM_MAX_ROWS_PER_CALL

    # ------------------------------------------------------------------
    # Bug-fix: return [], logs → fallback when all retries return too few rows
    # ------------------------------------------------------------------

    def test_insufficient_rows_all_retries_falls_back_to_mock_data(self):
        """When all retries return fewer than 80% of requested rows (no
        exception), generate_batch must return the requested count via
        mock-data fill instead of an empty list."""
        # LLM always returns only 70 rows when 100 are requested
        generator = self._make_generator_with_mock_llm(
            llm_return_value=self._csv_for_n_rows(70)
        )
        records, logs = generator.generate_batch(self.SCHEMA, num_rows=100, max_retries=2)

        assert len(records) == 100, (
            "Expected 100 records but got an empty list – the return [], logs bug may have reappeared"
        )
        assert any("filling remainder with mock data" in log for log in logs), (
            "Expected a log message about filling with mock data"
        )

    def test_insufficient_rows_result_is_never_empty(self):
        """generate_batch must never return an empty list regardless of how
        many retries fail the 80% threshold."""
        generator = self._make_generator_with_mock_llm(
            llm_return_value=self._csv_for_n_rows(10)  # 10/100 = 10%, well below 80%
        )
        records, logs = generator.generate_batch(self.SCHEMA, num_rows=100, max_retries=3)

        assert len(records) > 0, "generate_batch must never return an empty list"
        assert len(records) == 100

    def test_best_partial_rows_are_kept_when_retries_exhausted(self):
        """The best partial LLM result should be used (not discarded) when
        retries are exhausted via the insufficient-rows path."""
        attempt_num = [0]  # list cell so the closure can mutate it

        def improving_side_effect(messages, **kwargs):
            attempt_num[0] += 1
            # Return progressively more rows but always < 80 (threshold for 100)
            rows = 50 + attempt_num[0] * 5  # 55, 60, 65 – all < 80
            return self._csv_for_n_rows(rows)

        generator = self._make_generator_with_mock_llm(llm_side_effect=improving_side_effect)
        records, logs = generator.generate_batch(self.SCHEMA, num_rows=100, max_retries=3)

        # Should have 100 rows (best partial 65 + 35 mock fill)
        assert len(records) == 100
        # The mock-fill log should be present
        assert any("filling remainder with mock data" in log for log in logs)


# Allow running this test module directly (e.g. `python test_synthetic_backend.py`)
# instead of only through a pytest invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Loading