Improve PairICL device handling and archive legacy runner#8
Conversation
There was a problem hiding this comment.
Pull request overview
This PR modularizes the legacy pair_icl_tpu.py runner into a new PairICL package, adds a CLI and TPU/XLA helpers, and preserves the original monolithic script under archive/ while keeping the old entry point working.
Changes:
- Introduce the
PairICLpackage (data/model/predictor/CLI utilities) and include it in packaging. - Replace
pair_icl_tpu.pywith a legacy forwarder toPairICL.cli, and archive the original implementation. - Add unit tests covering pair embedding invariance, predictor behavior, and GPU-fingerprint-dataset conversion.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
PairICL/data.py |
Implements dataset/IO utilities, including GPU fingerprint dataset conversion. |
PairICL/models.py |
Adds model components and the order-invariant pair embedding builder. |
PairICL/predictor.py |
Adds a high-level predictor that supports optional support tokens and XLA helpers. |
PairICL/xla_utils.py |
Provides optional torch_xla wrappers (device selection, spawn, broadcast, reduce/gather). |
PairICL/cli.py |
Adds a new CLI for preprocessing and prediction, including GPU dataset ingestion. |
PairICL/__init__.py |
Exposes the package public API for imports/tests. |
PairICL/__main__.py |
Enables python -m PairICL execution. |
pair_icl_tpu.py |
Keeps backwards compatibility by forwarding to the new CLI entry point. |
archive/pair_icl_tpu.py |
Preserves the legacy monolithic TPU script for reference. |
tests/test_pairicl.py |
Adds new tests for PairICL core behaviors and dataset conversion. |
pyproject.toml |
Adds PairICL to the packaged modules list. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def _lookup_embeddings(self, indices: torch.Tensor) -> torch.Tensor: | ||
| vectors = [self.row_embeddings[int(idx)] for idx in indices.tolist()] | ||
| return torch.stack(vectors) |
There was a problem hiding this comment.
_lookup_embeddings() uses indices.tolist() to drive Python dict lookups. When indices is on an accelerator (CUDA/XLA), tolist() forces a device sync + host transfer, which can severely degrade throughput (and negates the benefit of ParallelLoader). Consider keeping indices on CPU, or materialising row_embeddings into a single tensor and using tensor indexing instead of Python iteration.
| self.device = device or xla_utils.get_default_device() | ||
| self.icl_model = icl_model.to(self.device) | ||
| non_blocking = self.device.type == "cuda" | ||
| self.row_embeddings = { | ||
| int(k): v.to(self.device, non_blocking=non_blocking) | ||
| for k, v in row_embeddings.items() | ||
| } |
There was a problem hiding this comment.
The constructor eagerly moves all row_embeddings tensors to self.device. For realistic datasets this can easily exhaust TPU/GPU memory and also duplicates memory on every replica. Consider making this behaviour optional (e.g., a flag to keep embeddings on CPU and transfer per-batch), or storing embeddings in a sharded/packed representation.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if xla_available(): | ||
| return xm.xla_device() | ||
| if torch.cuda.is_available(): | ||
| return torch.device("cuda") | ||
| return torch.device("cpu") |
There was a problem hiding this comment.
get_default_device() currently prefers xm.xla_device() whenever torch_xla imports successfully, even if the runtime is CPU XLA. That contradicts the documented priority order (TPU > CUDA > CPU) and can unexpectedly route GPU-capable environments to XLA/CPU, hurting performance or breaking expected device selection. Consider detecting actual TPU availability (e.g., checking xm.xla_device_hw() / PJRT_DEVICE) before selecting XLA, otherwise fall back to CUDA when available.
| idx_a = idx_a.to(self.device) | ||
| idx_b = idx_b.to(self.device) |
There was a problem hiding this comment.
In predict_loader(), idx_a/idx_b are moved to the accelerator and then _lookup_embeddings() calls indices.tolist(). On CUDA this forces a device sync; on XLA it can trigger expensive host transfers/compilation barriers. Since the indices are only used for Python dict lookups, keep them on CPU (don’t .to(self.device) before lookup), or switch to a tensor-based embedding table so indexing stays on-device.
| idx_a = idx_a.to(self.device) | |
| idx_b = idx_b.to(self.device) | |
| # Keep idx_a and idx_b on CPU for Python dict lookups in _lookup_embeddings |
| The original repository shipped a monolithic script that combined data loading, | ||
| model definitions and TPU orchestration. Those components now live inside the | ||
| :mod:`PairICL` package, but several external references still import this file. | ||
| To keep those references working we expose the same ``main`` function while | ||
| reusing the new implementation. |
There was a problem hiding this comment.
This legacy entry point now forwards to PairICL.cli.main, but the new CLI uses subcommands (preprocess/predict) instead of the legacy --mode/--row_embeds/--pairs_csv flags. That means existing invocations like python pair_icl_tpu.py --mode predict ... will fail even though this file claims to keep legacy references working. Consider adding an argv-translation shim here (or in PairICL.cli) to accept the old flag-style interface and map it to the new subcommands.
Summary
Testing
Codex Task