Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions simplexity/generative_processes/data_prefetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Background data prefetcher for overlapping data generation with training."""

from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from types import TracebackType


class DataPrefetcher[T]:
"""Prefetches training data in background threads to overlap with GPU training.

Uses a thread pool to generate upcoming batches while the current batch is being trained on.
This works because JAX JIT-compiled functions release the GIL, enabling genuine parallelism
between JAX data generation and PyTorch training.

Not thread-safe: ``prefetch``, ``get``, and ``shutdown`` should all be called from a single
thread. The background thread pool is managed internally.

Intended to be used as a context manager::

with DataPrefetcher(generate_fn, lookahead=1) as prefetcher:
for step in range(num_steps):
batch = prefetcher.get(step)
train(batch)

Args:
generate_fn: A function that takes a step number (int) and returns batch data.
lookahead: Number of future steps to prefetch. Defaults to 1.
"""

def __init__(self, generate_fn: Callable[[int], T], lookahead: int = 1) -> None:
self._generate_fn = generate_fn
self._lookahead = lookahead
self._executor = ThreadPoolExecutor(max_workers=lookahead + 1)
self._futures: dict[int, Future[T]] = {}

def __enter__(self) -> "DataPrefetcher[T]":
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.shutdown()

def prefetch(self, step: int) -> None:
"""Submit a background task to generate data for the given step."""
if step not in self._futures:
self._futures[step] = self._executor.submit(self._generate_fn, step)

def get(self, step: int) -> T:
"""Return the generated data for the given step, blocking until ready.

Also triggers prefetch for the next `lookahead` steps and cleans up old futures.
"""
self.prefetch(step)
for s in range(step + 1, step + 1 + self._lookahead):
self.prefetch(s)
result = self._futures.pop(step).result()
old_keys = [k for k in self._futures if k < step]
for k in old_keys:
self._futures.pop(k).cancel()
return result

def shutdown(self) -> None:
"""Cancel pending futures and shut down the thread pool."""
for future in self._futures.values():
future.cancel()
self._futures.clear()
self._executor.shutdown(wait=False, cancel_futures=True)
139 changes: 139 additions & 0 deletions tests/generative_processes/test_data_prefetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Test the data prefetcher module."""

import threading
import time

import pytest

from simplexity.generative_processes.data_prefetcher import DataPrefetcher


def test_get_returns_correct_result():
    """Each ``get(step)`` call should hand back what ``generate_fn(step)`` produced."""

    def times_ten(step: int) -> int:
        return step * 10

    # Steps are deliberately non-contiguous: get() must work for arbitrary step numbers.
    expectations = ((0, 0), (5, 50), (100, 1000))
    with DataPrefetcher(times_ten, lookahead=1) as prefetcher:
        for step, expected in expectations:
            assert prefetcher.get(step) == expected


def test_prefetch_submits_future():
    """Prefetch should submit a future that can be retrieved by get without re-computation."""
    generated_steps: list[int] = []

    def recording_fn(step: int) -> int:
        generated_steps.append(step)
        return step

    with DataPrefetcher(recording_fn, lookahead=1) as prefetcher:
        prefetcher.prefetch(0)
        result = prefetcher.get(0)
    assert result == 0
    # Count step 0 only: get(0) also schedules step 1 in the background, so a global call
    # counter would race with that look-ahead task.
    assert generated_steps.count(0) == 1


def test_lookahead_prefetches_future_steps():
    """Get should trigger prefetch for the next lookahead steps."""
    # The context manager shuts the pool down even when one of the assertions fails.
    with DataPrefetcher(lambda step: step, lookahead=2) as prefetcher:
        prefetcher.get(0)
        assert 1 in prefetcher._futures  # noqa: SLF001  # pylint: disable=protected-access
        assert 2 in prefetcher._futures  # noqa: SLF001  # pylint: disable=protected-access


def test_get_cleans_up_old_futures():
    """Get should not hold references to old step results after advancing."""
    generated_steps: list[int] = []

    def recording_fn(step: int) -> int:
        generated_steps.append(step)
        return step

    with DataPrefetcher(recording_fn, lookahead=1) as prefetcher:
        prefetcher.prefetch(0)
        prefetcher.prefetch(1)
        prefetcher.get(2)
        # Direct check: entries for steps behind the requested one are no longer tracked.
        assert 0 not in prefetcher._futures  # noqa: SLF001  # pylint: disable=protected-access
        assert 1 not in prefetcher._futures  # noqa: SLF001  # pylint: disable=protected-access
        # Indirect check: asking for step 0 again has to generate it anew. Only step 0 is
        # counted so unrelated look-ahead tasks cannot satisfy the assertion.
        runs_before = generated_steps.count(0)
        assert prefetcher.get(0) == 0
        assert generated_steps.count(0) > runs_before


def test_error_propagation():
    """Exceptions in generate_fn should be re-raised by get."""

    def failing_fn(step: int) -> int:
        raise ValueError(f"step {step} failed")

    # One ``with`` statement holding both contexts (Ruff SIM117). ``pytest.raises`` is listed
    # last so it exits first, exactly as it did when it was the inner block.
    with (
        DataPrefetcher(failing_fn, lookahead=1) as prefetcher,
        pytest.raises(ValueError, match="step 3 failed"),
    ):
        prefetcher.get(3)


def test_shutdown_does_not_hang():
    """Shutdown via context manager should return promptly even with pending futures."""
    release = threading.Event()

    def slow_fn(step: int) -> int:
        # Stands in for a slow generation (up to 10s) but stops early once the test is done,
        # so the worker thread does not linger after the assertion has been made.
        deadline = time.monotonic() + 10
        while not release.is_set() and time.monotonic() < deadline:
            time.sleep(0.01)
        return step

    completed = threading.Event()

    def run_shutdown() -> None:
        with DataPrefetcher(slow_fn, lookahead=1) as prefetcher:
            prefetcher.prefetch(0)
        # Reached only after __exit__ (and therefore shutdown) has returned.
        completed.set()

    t = threading.Thread(target=run_shutdown)
    t.start()
    try:
        assert completed.wait(timeout=2), "Shutdown blocked for over 2s"
    finally:
        # Let the background task finish and reap the helper thread whatever the outcome.
        release.set()
        t.join(timeout=2)


def test_context_manager_cleans_up_on_exception():
    """Context manager should shut down the executor even if an exception occurs."""

    def identity(step: int) -> int:
        return step

    # Filled in by the helper so the prefetcher can be inspected after the block has exited.
    entered: list[DataPrefetcher[int]] = []

    def use_then_fail() -> None:
        with DataPrefetcher(identity, lookahead=1) as p:
            entered.append(p)
            p.get(0)
            raise RuntimeError("boom")

    # Keeping the ``with`` inside a helper leaves a single simple statement under
    # ``pytest.raises`` (Ruff PT012) and removes the nested ``with`` (Ruff SIM117).
    with pytest.raises(RuntimeError, match="boom"):
        use_then_fail()

    assert len(entered) == 1
    prefetcher = entered[0]
    # Submitting new work must now fail because the executor has been shut down.
    with pytest.raises(RuntimeError):
        prefetcher.prefetch(99)


def test_generate_fn_runs_in_background_thread():
    """The generate function must not be executed on the thread that calls ``get``."""
    gen_thread_id: int | None = None

    def capture_thread(step: int) -> int:
        # Remember which thread performed the generation.
        nonlocal gen_thread_id
        gen_thread_id = threading.current_thread().ident
        return step

    caller_thread = threading.current_thread().ident
    with DataPrefetcher(capture_thread, lookahead=1) as prefetcher:
        prefetcher.get(0)

    assert gen_thread_id is not None
    assert gen_thread_id != caller_thread


def test_duplicate_prefetch_is_noop():
    """Calling prefetch twice for the same step should not submit a second task."""
    generated_steps: list[int] = []

    def recording_fn(step: int) -> int:
        generated_steps.append(step)
        return step

    with DataPrefetcher(recording_fn, lookahead=1) as prefetcher:
        prefetcher.prefetch(0)
        prefetcher.prefetch(0)
        prefetcher.get(0)
    # Count step 0 only: get(0) also schedules step 1 in the background, so a global call
    # counter would race with that look-ahead task.
    assert generated_steps.count(0) == 1
Loading