Skip to content

Improve PairICL device handling and archive legacy runner#8

Open
yananlong wants to merge 7 commits into
mainfrom
codex/integrate-pair_icl_tpu-functions-into-pairicl
Open

Improve PairICL device handling and archive legacy runner#8
yananlong wants to merge 7 commits into
mainfrom
codex/integrate-pair_icl_tpu-functions-into-pairicl

Conversation

@yananlong
Copy link
Copy Markdown
Owner

Summary

  • cast tabular features to float32 during preprocessing so GPU and TPU runs consume identical tensors
  • move embeddings and support tokens to accelerator memory with CUDA-aware transfers inside the predictor
  • preserve the original pair_icl_tpu.py implementation under archive/ for future reference

Testing

  • pytest

Codex Task

Copilot AI review requested due to automatic review settings February 7, 2026 05:52
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR modularizes the legacy pair_icl_tpu.py runner into a new PairICL package, adds a CLI and TPU/XLA helpers, and preserves the original monolithic script under archive/ while keeping the old entry point working.

Changes:

  • Introduce the PairICL package (data/model/predictor/CLI utilities) and include it in packaging.
  • Replace pair_icl_tpu.py with a legacy forwarder to PairICL.cli, and archive the original implementation.
  • Add unit tests covering pair embedding invariance, predictor behavior, and GPU-fingerprint-dataset conversion.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
PairICL/data.py Implements dataset/IO utilities, including GPU fingerprint dataset conversion.
PairICL/models.py Adds model components and the order-invariant pair embedding builder.
PairICL/predictor.py Adds a high-level predictor that supports optional support tokens and XLA helpers.
PairICL/xla_utils.py Provides optional torch_xla wrappers (device selection, spawn, broadcast, reduce/gather).
PairICL/cli.py Adds a new CLI for preprocessing and prediction, including GPU dataset ingestion.
PairICL/__init__.py Exposes the package public API for imports/tests.
PairICL/__main__.py Enables python -m PairICL execution.
pair_icl_tpu.py Keeps backwards compatibility by forwarding to the new CLI entry point.
archive/pair_icl_tpu.py Preserves the legacy monolithic TPU script for reference.
tests/test_pairicl.py Adds new tests for PairICL core behaviors and dataset conversion.
pyproject.toml Adds PairICL to the packaged modules list.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread PairICL/data.py
Comment thread PairICL/xla_utils.py Outdated
Comment thread PairICL/predictor.py
Comment on lines +58 to +60
def _lookup_embeddings(self, indices: torch.Tensor) -> torch.Tensor:
vectors = [self.row_embeddings[int(idx)] for idx in indices.tolist()]
return torch.stack(vectors)
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_lookup_embeddings() uses indices.tolist() to drive Python dict lookups. When indices is on an accelerator (CUDA/XLA), tolist() forces a device sync + host transfer, which can severely degrade throughput (and negates the benefit of ParallelLoader). Consider keeping indices on CPU, or materialising row_embeddings into a single tensor and using tensor indexing instead of Python iteration.

Copilot uses AI. Check for mistakes.
Comment thread PairICL/predictor.py
Comment on lines +37 to +43
self.device = device or xla_utils.get_default_device()
self.icl_model = icl_model.to(self.device)
non_blocking = self.device.type == "cuda"
self.row_embeddings = {
int(k): v.to(self.device, non_blocking=non_blocking)
for k, v in row_embeddings.items()
}
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constructor eagerly moves all row_embeddings tensors to self.device. For realistic datasets this can easily exhaust TPU/GPU memory and also duplicates memory on every replica. Consider making this behaviour optional (e.g., a flag to keep embeddings on CPU and transfer per-batch), or storing embeddings in a sharded/packed representation.

Copilot uses AI. Check for mistakes.
Comment thread PairICL/xla_utils.py Outdated
Comment thread archive/pair_icl_tpu.py Outdated
Comment thread archive/pair_icl_tpu.py Outdated
Comment thread archive/pair_icl_tpu.py Outdated
yananlong and others added 6 commits February 8, 2026 18:53
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread PairICL/xla_utils.py
Comment on lines +35 to +39
if xla_available():
return xm.xla_device()
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_default_device() currently prefers xm.xla_device() whenever torch_xla imports successfully, even if the runtime is CPU XLA. That contradicts the documented priority order (TPU > CUDA > CPU) and can unexpectedly route GPU-capable environments to XLA/CPU, hurting performance or breaking expected device selection. Consider detecting actual TPU availability (e.g., checking xm.xla_device_hw() / PJRT_DEVICE) before selecting XLA, otherwise fall back to CUDA when available.

Copilot uses AI. Check for mistakes.
Comment thread PairICL/predictor.py
Comment on lines +87 to +88
idx_a = idx_a.to(self.device)
idx_b = idx_b.to(self.device)
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In predict_loader(), idx_a/idx_b are moved to the accelerator and then _lookup_embeddings() calls indices.tolist(). On CUDA this forces a device sync; on XLA it can trigger expensive host transfers/compilation barriers. Since the indices are only used for Python dict lookups, keep them on CPU (don’t .to(self.device) before lookup), or switch to a tensor-based embedding table so indexing stays on-device.

Suggested change
idx_a = idx_a.to(self.device)
idx_b = idx_b.to(self.device)
# Keep idx_a and idx_b on CPU for Python dict lookups in _lookup_embeddings

Copilot uses AI. Check for mistakes.
Comment thread pair_icl_tpu.py
Comment on lines +3 to +7
The original repository shipped a monolithic script that combined data loading,
model definitions and TPU orchestration. Those components now live inside the
:mod:`PairICL` package, but several external references still import this file.
To keep those references working we expose the same ``main`` function while
reusing the new implementation.
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This legacy entry point now forwards to PairICL.cli.main, but the new CLI uses subcommands (preprocess/predict) instead of the legacy --mode/--row_embeds/--pairs_csv flags. That means existing invocations like python pair_icl_tpu.py --mode predict ... will fail even though this file claims to keep legacy references working. Consider adding an argv-translation shim here (or in PairICL.cli) to accept the old flag-style interface and map it to the new subcommands.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants