From ebf5b7e842b45868adff854a2ecb952450658591 Mon Sep 17 00:00:00 2001 From: Casper Lutzhoft Christensen Date: Thu, 5 Mar 2026 12:04:34 -0800 Subject: [PATCH 1/3] thread-based data fetching --- .../generative_processes/data_prefetcher.py | 49 +++++++ .../test_data_prefetcher.py | 130 ++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 simplexity/generative_processes/data_prefetcher.py create mode 100644 tests/generative_processes/test_data_prefetcher.py diff --git a/simplexity/generative_processes/data_prefetcher.py b/simplexity/generative_processes/data_prefetcher.py new file mode 100644 index 00000000..8ab2cd79 --- /dev/null +++ b/simplexity/generative_processes/data_prefetcher.py @@ -0,0 +1,49 @@ +"""Background data prefetcher for overlapping data generation with training.""" + +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor + + +class DataPrefetcher[T]: + """Prefetches training data in background threads to overlap with GPU training. + + Uses a thread pool to generate upcoming batches while the current batch is being trained on. + This works because JAX JIT-compiled functions release the GIL, enabling genuine parallelism + between JAX data generation and PyTorch training. + + Args: + generate_fn: A function that takes a step number (int) and returns batch data. + lookahead: Number of future steps to prefetch. Defaults to 1. + """ + + def __init__(self, generate_fn: Callable[[int], T], lookahead: int = 1) -> None: + self._generate_fn = generate_fn + self._lookahead = lookahead + self._executor = ThreadPoolExecutor(max_workers=lookahead) + self._futures: dict[int, Future[T]] = {} + + def prefetch(self, step: int) -> None: + """Submit a background task to generate data for the given step.""" + if step not in self._futures: + self._futures[step] = self._executor.submit(self._generate_fn, step) + + def get(self, step: int) -> T: + """Return the generated data for the given step, blocking until ready. + + Also triggers prefetch for the next `lookahead` steps and cleans up old futures. + """ + self.prefetch(step) + for s in range(step + 1, step + 1 + self._lookahead): + self.prefetch(s) + result = self._futures.pop(step).result() + old_keys = [k for k in self._futures if k < step] + for k in old_keys: + self._futures.pop(k) + return result + + def shutdown(self) -> None: + """Shut down the thread pool. + + Does not wait for pending futures to complete. + """ + self._executor.shutdown(wait=False) diff --git a/tests/generative_processes/test_data_prefetcher.py b/tests/generative_processes/test_data_prefetcher.py new file mode 100644 index 00000000..3e733413 --- /dev/null +++ b/tests/generative_processes/test_data_prefetcher.py @@ -0,0 +1,130 @@ +"""Test the data prefetcher module.""" + +import threading +import time + +import pytest + +from simplexity.generative_processes.data_prefetcher import DataPrefetcher + + +def test_get_returns_correct_result(): + """Get should return the result of generate_fn called with the step number.""" + prefetcher = DataPrefetcher(lambda step: step * 10, lookahead=1) + assert prefetcher.get(0) == 0 + assert prefetcher.get(5) == 50 + assert prefetcher.get(100) == 1000 + prefetcher.shutdown() + + +def test_prefetch_submits_future(): + """Prefetch should submit a future that can be retrieved by get without re-computation.""" + call_count = 0 + + def counting_fn(step: int) -> int: + nonlocal call_count + call_count += 1 + return step + + prefetcher = DataPrefetcher(counting_fn, lookahead=1) + prefetcher.prefetch(0) + result = prefetcher.get(0) + assert result == 0 + assert call_count == 1 + prefetcher.shutdown() + + +def test_lookahead_prefetches_future_steps(): + """Get should trigger prefetch for the next lookahead steps.""" + called_steps: list[int] = [] + + def tracking_fn(step: int) -> int: + called_steps.append(step) + return step + + prefetcher = DataPrefetcher(tracking_fn, lookahead=2) + prefetcher.get(0) + prefetcher.get(1) + prefetcher.get(2) + assert sorted(set(called_steps)) == [0, 1, 2] + assert called_steps.count(0) == 1 + assert called_steps.count(1) == 1 + assert called_steps.count(2) == 1 + prefetcher.shutdown() + + +def test_get_cleans_up_old_futures(): + """Get should not hold references to old step results after advancing.""" + call_count = 0 + + def counting_fn(step: int) -> int: + nonlocal call_count + call_count += 1 + return step + + prefetcher = DataPrefetcher(counting_fn, lookahead=1) + prefetcher.prefetch(0) + prefetcher.prefetch(1) + prefetcher.get(2) + old_count = call_count + prefetcher.get(0) + assert call_count > old_count + prefetcher.shutdown() + + +def test_error_propagation(): + """Exceptions in generate_fn should be re-raised by get.""" + + def failing_fn(step: int) -> int: + raise ValueError(f"step {step} failed") + + prefetcher = DataPrefetcher(failing_fn, lookahead=1) + with pytest.raises(ValueError, match="step 3 failed"): + prefetcher.get(3) + prefetcher.shutdown() + + +def test_shutdown_does_not_hang(): + """Shutdown should return promptly even with pending futures.""" + + def slow_fn(step: int) -> int: + time.sleep(10) + return step + + prefetcher = DataPrefetcher(slow_fn, lookahead=1) + prefetcher.prefetch(0) + prefetcher.shutdown() + + +def test_generate_fn_runs_in_background_thread(): + """Generate function should execute in a different thread than the caller.""" + caller_thread = threading.current_thread().ident + gen_thread_id: int | None = None + + def capture_thread(step: int) -> int: + nonlocal gen_thread_id + gen_thread_id = threading.current_thread().ident + return step + + prefetcher = DataPrefetcher(capture_thread, lookahead=1) + prefetcher.get(0) + assert gen_thread_id is not None + assert gen_thread_id != caller_thread + prefetcher.shutdown() + + +def test_duplicate_prefetch_is_noop(): + """Calling prefetch twice for the same step should not submit a second task.""" + call_count = 0 + + def counting_fn(step: int) -> int: + nonlocal call_count + call_count += 1 + return step + + prefetcher = DataPrefetcher(counting_fn, lookahead=1) + prefetcher.prefetch(0) + prefetcher.prefetch(0) + prefetcher.get(0) + assert call_count == 1 + prefetcher.shutdown() From 191cc2eeeb449478741876b9b79f0a718b0021f6 Mon Sep 17 00:00:00 2001 From: Casper Lutzhoft Christensen Date: Mon, 9 Mar 2026 10:05:19 -0700 Subject: [PATCH 2/3] feedback --- .../generative_processes/data_prefetcher.py | 36 ++++-- .../test_data_prefetcher.py | 110 ++++++++++-------- 2 files changed, 88 insertions(+), 58 deletions(-) diff --git a/simplexity/generative_processes/data_prefetcher.py b/simplexity/generative_processes/data_prefetcher.py index 8ab2cd79..5f6b5a47 100644 --- a/simplexity/generative_processes/data_prefetcher.py +++ b/simplexity/generative_processes/data_prefetcher.py @@ -2,6 +2,7 @@ from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor +from types import TracebackType class DataPrefetcher[T]: @@ -11,6 +12,16 @@ class DataPrefetcher[T]: This works because JAX JIT-compiled functions release the GIL, enabling genuine parallelism between JAX data generation and PyTorch training. + Not thread-safe: ``prefetch``, ``get``, and ``shutdown`` should all be called from a single + thread. The background thread pool is managed internally. + + Intended to be used as a context manager:: + + with DataPrefetcher(generate_fn, lookahead=1) as prefetcher: + for step in range(num_steps): + batch = prefetcher.get(step) + train(batch) + Args: generate_fn: A function that takes a step number (int) and returns batch data. lookahead: Number of future steps to prefetch. Defaults to 1. @@ -19,9 +30,20 @@ class DataPrefetcher[T]: def __init__(self, generate_fn: Callable[[int], T], lookahead: int = 1) -> None: self._generate_fn = generate_fn self._lookahead = lookahead - self._executor = ThreadPoolExecutor(max_workers=lookahead) + self._executor = ThreadPoolExecutor(max_workers=lookahead + 1) self._futures: dict[int, Future[T]] = {} + def __enter__(self) -> "DataPrefetcher[T]": + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.shutdown() + def prefetch(self, step: int) -> None: """Submit a background task to generate data for the given step.""" if step not in self._futures: @@ -38,12 +60,12 @@ def get(self, step: int) -> T: result = self._futures.pop(step).result() old_keys = [k for k in self._futures if k < step] for k in old_keys: - self._futures.pop(k) + self._futures.pop(k).cancel() return result def shutdown(self) -> None: - """Shut down the thread pool. - - Does not wait for pending futures to complete. - """ - self._executor.shutdown(wait=False) + """Cancel pending futures and shut down the thread pool.""" + for future in self._futures.values(): + future.cancel() + self._futures.clear() + self._executor.shutdown(wait=False, cancel_futures=True) diff --git a/tests/generative_processes/test_data_prefetcher.py b/tests/generative_processes/test_data_prefetcher.py index 3e733413..b38aa195 100644 --- a/tests/generative_processes/test_data_prefetcher.py +++ b/tests/generative_processes/test_data_prefetcher.py @@ -10,11 +10,10 @@ def test_get_returns_correct_result(): """Get should return the result of generate_fn called with the step number.""" - prefetcher = DataPrefetcher(lambda step: step * 10, lookahead=1) - assert prefetcher.get(0) == 0 - assert prefetcher.get(5) == 50 - assert prefetcher.get(100) == 1000 - prefetcher.shutdown() + with DataPrefetcher(lambda step: step * 10, lookahead=1) as prefetcher: + assert prefetcher.get(0) == 0 + assert prefetcher.get(5) == 50 + assert prefetcher.get(100) == 1000 def test_prefetch_submits_future(): @@ -26,30 +25,19 @@ def counting_fn(step: int) -> int: call_count += 1 return step - prefetcher = DataPrefetcher(counting_fn, lookahead=1) - prefetcher.prefetch(0) - result = prefetcher.get(0) - assert result == 0 - assert call_count == 1 - prefetcher.shutdown() + with DataPrefetcher(counting_fn, lookahead=1) as prefetcher: + prefetcher.prefetch(0) + result = prefetcher.get(0) + assert result == 0 + assert call_count == 1 def test_lookahead_prefetches_future_steps(): """Get should trigger prefetch for the next lookahead steps.""" - called_steps: list[int] = [] - - def tracking_fn(step: int) -> int: - called_steps.append(step) - return step - - prefetcher = DataPrefetcher(tracking_fn, lookahead=2) + prefetcher = DataPrefetcher(lambda step: step, lookahead=2) prefetcher.get(0) - prefetcher.get(1) - prefetcher.get(2) - assert sorted(set(called_steps)) == [0, 1, 2] - assert called_steps.count(0) == 1 - assert called_steps.count(1) == 1 - assert called_steps.count(2) == 1 + assert 1 in prefetcher._futures # noqa: SLF001 + assert 2 in prefetcher._futures # noqa: SLF001 prefetcher.shutdown() @@ -62,14 +50,13 @@ def counting_fn(step: int) -> int: call_count += 1 return step - prefetcher = DataPrefetcher(counting_fn, lookahead=1) - prefetcher.prefetch(0) - prefetcher.prefetch(1) - prefetcher.get(2) - old_count = call_count - prefetcher.get(0) - assert call_count > old_count - prefetcher.shutdown() + with DataPrefetcher(counting_fn, lookahead=1) as prefetcher: + prefetcher.prefetch(0) + prefetcher.prefetch(1) + prefetcher.get(2) + old_count = call_count + prefetcher.get(0) + assert call_count > old_count def test_error_propagation(): @@ -78,22 +65,45 @@ def test_error_propagation(): def failing_fn(step: int) -> int: raise ValueError(f"step {step} failed") - prefetcher = DataPrefetcher(failing_fn, lookahead=1) - with pytest.raises(ValueError, match="step 3 failed"): - prefetcher.get(3) - prefetcher.shutdown() + with DataPrefetcher(failing_fn, lookahead=1) as prefetcher: + with pytest.raises(ValueError, match="step 3 failed"): + prefetcher.get(3) def test_shutdown_does_not_hang(): - """Shutdown should return promptly even with pending futures.""" + """Shutdown via context manager should return promptly even with pending futures.""" def slow_fn(step: int) -> int: time.sleep(10) return step - prefetcher = DataPrefetcher(slow_fn, lookahead=1) - prefetcher.prefetch(0) - prefetcher.shutdown() + completed = threading.Event() + + def run_shutdown(): + with DataPrefetcher(slow_fn, lookahead=1) as prefetcher: + prefetcher.prefetch(0) + completed.set() + + t = threading.Thread(target=run_shutdown) + t.start() + assert completed.wait(timeout=2), "Shutdown blocked for over 2s" + + +def test_context_manager_cleans_up_on_exception(): + """Context manager should shut down the executor even if an exception occurs.""" + + def identity(step: int) -> int: + return step + + prefetcher: DataPrefetcher[int] | None = None + with pytest.raises(RuntimeError, match="boom"): + with DataPrefetcher(identity, lookahead=1) as p: + prefetcher = p + p.get(0) + raise RuntimeError("boom") + + assert prefetcher is not None + assert prefetcher._executor._shutdown def test_generate_fn_runs_in_background_thread(): @@ -106,11 +116,10 @@ def capture_thread(step: int) -> int: gen_thread_id = threading.current_thread().ident return step - prefetcher = DataPrefetcher(capture_thread, lookahead=1) - prefetcher.get(0) - assert gen_thread_id is not None - assert gen_thread_id != caller_thread - prefetcher.shutdown() + with DataPrefetcher(capture_thread, lookahead=1) as prefetcher: + prefetcher.get(0) + assert gen_thread_id is not None + assert gen_thread_id != caller_thread def test_duplicate_prefetch_is_noop(): @@ -122,9 +131,8 @@ def counting_fn(step: int) -> int: call_count += 1 return step - prefetcher = DataPrefetcher(counting_fn, lookahead=1) - prefetcher.prefetch(0) - prefetcher.prefetch(0) - prefetcher.get(0) - assert call_count == 1 - prefetcher.shutdown() + with DataPrefetcher(counting_fn, lookahead=1) as prefetcher: + prefetcher.prefetch(0) + prefetcher.prefetch(0) + prefetcher.get(0) + assert call_count == 1 From 73aff6145b7867840d9da00e658fb0cb0f851b1c Mon Sep 17 00:00:00 2001 From: Casper Lutzhoft Christensen Date: Mon, 9 Mar 2026 10:47:53 -0700 Subject: [PATCH 3/3] linting --- tests/generative_processes/test_data_prefetcher.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/generative_processes/test_data_prefetcher.py b/tests/generative_processes/test_data_prefetcher.py index b38aa195..2de88c52 100644 --- a/tests/generative_processes/test_data_prefetcher.py +++ b/tests/generative_processes/test_data_prefetcher.py @@ -36,8 +36,8 @@ def test_lookahead_prefetches_future_steps(): """Get should trigger prefetch for the next lookahead steps.""" prefetcher = DataPrefetcher(lambda step: step, lookahead=2) prefetcher.get(0) - assert 1 in prefetcher._futures # noqa: SLF001 - assert 2 in prefetcher._futures # noqa: SLF001 + assert 1 in prefetcher._futures # noqa: SLF001 # pylint: disable=protected-access + assert 2 in prefetcher._futures # noqa: SLF001 # pylint: disable=protected-access prefetcher.shutdown() @@ -103,7 +103,8 @@ def identity(step: int) -> int: raise RuntimeError("boom") assert prefetcher is not None - assert prefetcher._executor._shutdown + with pytest.raises(RuntimeError): + prefetcher.prefetch(99) def test_generate_fn_runs_in_background_thread():