-
Notifications
You must be signed in to change notification settings - Fork 2
thread-based data fetching #176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| """Background data prefetcher for overlapping data generation with training.""" | ||
|
|
||
| from collections.abc import Callable | ||
| from concurrent.futures import Future, ThreadPoolExecutor | ||
| from types import TracebackType | ||
|
|
||
|
|
||
| class DataPrefetcher[T]: | ||
| """Prefetches training data in background threads to overlap with GPU training. | ||
|
|
||
| Uses a thread pool to generate upcoming batches while the current batch is being trained on. | ||
casperlchristensen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| This works because JAX JIT-compiled functions release the GIL, enabling genuine parallelism | ||
| between JAX data generation and PyTorch training. | ||
|
|
||
| Not thread-safe: ``prefetch``, ``get``, and ``shutdown`` should all be called from a single | ||
| thread. The background thread pool is managed internally. | ||
|
|
||
| Intended to be used as a context manager:: | ||
|
|
||
| with DataPrefetcher(generate_fn, lookahead=1) as prefetcher: | ||
| for step in range(num_steps): | ||
| batch = prefetcher.get(step) | ||
| train(batch) | ||
|
|
||
| Args: | ||
| generate_fn: A function that takes a step number (int) and returns batch data. | ||
| lookahead: Number of future steps to prefetch. Defaults to 1. | ||
| """ | ||
|
|
||
| def __init__(self, generate_fn: Callable[[int], T], lookahead: int = 1) -> None: | ||
| self._generate_fn = generate_fn | ||
| self._lookahead = lookahead | ||
| self._executor = ThreadPoolExecutor(max_workers=lookahead + 1) | ||
| self._futures: dict[int, Future[T]] = {} | ||
|
|
||
| def __enter__(self) -> "DataPrefetcher[T]": | ||
| return self | ||
|
|
||
| def __exit__( | ||
| self, | ||
| exc_type: type[BaseException] | None, | ||
| exc_val: BaseException | None, | ||
| exc_tb: TracebackType | None, | ||
| ) -> None: | ||
| self.shutdown() | ||
|
|
||
| def prefetch(self, step: int) -> None: | ||
| """Submit a background task to generate data for the given step.""" | ||
| if step not in self._futures: | ||
| self._futures[step] = self._executor.submit(self._generate_fn, step) | ||
|
|
||
| def get(self, step: int) -> T: | ||
| """Return the generated data for the given step, blocking until ready. | ||
|
|
||
| Also triggers prefetch for the next `lookahead` steps and cleans up old futures. | ||
| """ | ||
| self.prefetch(step) | ||
| for s in range(step + 1, step + 1 + self._lookahead): | ||
| self.prefetch(s) | ||
| result = self._futures.pop(step).result() | ||
| old_keys = [k for k in self._futures if k < step] | ||
| for k in old_keys: | ||
| self._futures.pop(k).cancel() | ||
| return result | ||
|
|
||
| def shutdown(self) -> None: | ||
| """Cancel pending futures and shut down the thread pool.""" | ||
| for future in self._futures.values(): | ||
| future.cancel() | ||
| self._futures.clear() | ||
| self._executor.shutdown(wait=False, cancel_futures=True) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,139 @@ | ||
| """Test the data prefetcher module.""" | ||
|
|
||
| import threading | ||
| import time | ||
|
|
||
| import pytest | ||
|
|
||
| from simplexity.generative_processes.data_prefetcher import DataPrefetcher | ||
|
|
||
|
|
||
def test_get_returns_correct_result():
    """Get should return the result of generate_fn called with the step number."""
    with DataPrefetcher(lambda step: step * 10, lookahead=1) as prefetcher:
        for requested, expected in [(0, 0), (5, 50), (100, 1000)]:
            assert prefetcher.get(requested) == expected
def test_prefetch_submits_future():
    """Prefetch should submit a future that can be retrieved by get without re-computation.

    Counts only calls for step 0: ``get`` also submits the lookahead window to
    background threads, so asserting on a global call count would race with
    those submissions and make the test flaky.
    """
    calls: list[int] = []

    def counting_fn(step: int) -> int:
        # list.append is atomic under the GIL, safe across worker threads.
        calls.append(step)
        return step

    with DataPrefetcher(counting_fn, lookahead=1) as prefetcher:
        prefetcher.prefetch(0)
        assert prefetcher.get(0) == 0
        assert calls.count(0) == 1
def test_lookahead_prefetches_future_steps():
    """Get should trigger prefetch for the next lookahead steps.

    Uses the context manager so the executor is shut down even if an
    assertion fails (a bare ``shutdown()`` at the end would be skipped).
    """
    with DataPrefetcher(lambda step: step, lookahead=2) as prefetcher:
        prefetcher.get(0)
        assert 1 in prefetcher._futures  # noqa: SLF001 # pylint: disable=protected-access
        assert 2 in prefetcher._futures  # noqa: SLF001 # pylint: disable=protected-access
def test_get_cleans_up_old_futures():
    """Get should not hold references to old step results after advancing."""
    calls: list[int] = []

    def tracking_fn(step: int) -> int:
        calls.append(step)
        return step

    with DataPrefetcher(tracking_fn, lookahead=1) as prefetcher:
        prefetcher.prefetch(0)
        prefetcher.prefetch(1)
        # Advancing to step 2 evicts the futures for steps 0 and 1.
        prefetcher.get(2)
        before = len(calls)
        # Step 0 was evicted, so fetching it again forces re-generation.
        prefetcher.get(0)
        assert len(calls) > before
def test_error_propagation():
    """Exceptions in generate_fn should be re-raised by get."""

    def exploding_fn(step: int) -> int:
        raise ValueError(f"step {step} failed")

    prefetcher = DataPrefetcher(exploding_fn, lookahead=1)
    with prefetcher, pytest.raises(ValueError, match="step 3 failed"):
        prefetcher.get(3)
def test_shutdown_does_not_hang():
    """Shutdown via context manager should return promptly even with pending futures."""

    def slow_fn(step: int) -> int:
        time.sleep(10)
        return step

    finished = threading.Event()

    def exercise_shutdown():
        # Exiting the context triggers shutdown while slow_fn is still running.
        with DataPrefetcher(slow_fn, lookahead=1) as prefetcher:
            prefetcher.prefetch(0)
        finished.set()

    worker = threading.Thread(target=exercise_shutdown)
    worker.start()
    assert finished.wait(timeout=2), "Shutdown blocked for over 2s"
def test_context_manager_cleans_up_on_exception():
    """Context manager should shut down the executor even if an exception occurs."""
    escaped: DataPrefetcher[int] | None = None
    with pytest.raises(RuntimeError, match="boom"):
        with DataPrefetcher(lambda step: step, lookahead=1) as prefetcher:
            escaped = prefetcher
            prefetcher.get(0)
            raise RuntimeError("boom")

    assert escaped is not None
    # Submitting to a shut-down executor raises RuntimeError.
    with pytest.raises(RuntimeError):
        escaped.prefetch(99)
def test_generate_fn_runs_in_background_thread():
    """Generate function should execute in a different thread than the caller."""
    main_thread_id = threading.current_thread().ident
    observed: dict[str, int | None] = {"worker": None}

    def capture_thread(step: int) -> int:
        observed["worker"] = threading.current_thread().ident
        return step

    with DataPrefetcher(capture_thread, lookahead=1) as prefetcher:
        prefetcher.get(0)
        assert observed["worker"] is not None
        assert observed["worker"] != main_thread_id
def test_duplicate_prefetch_is_noop():
    """Calling prefetch twice for the same step should not submit a second task.

    Counts only calls for step 0: ``get`` also prefetches step 1 in a
    background thread, so a global call count would race with that task and
    intermittently read 2.
    """
    calls: list[int] = []

    def counting_fn(step: int) -> int:
        # list.append is atomic under the GIL, safe across worker threads.
        calls.append(step)
        return step

    with DataPrefetcher(counting_fn, lookahead=1) as prefetcher:
        prefetcher.prefetch(0)
        prefetcher.prefetch(0)
        prefetcher.get(0)
        assert calls.count(0) == 1
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.