iris.x: Device-side communication + .ops APIs.
#296
base: main
Conversation
Pull request overview
This PR introduces iris.x, a new module providing device-side tile-level primitives for fine-grained collective operations. Unlike iris.ccl, which handles full tensors with internal tiling, iris.x provides composable functions that users can call from their own kernels, managing tile iteration themselves.
Key Changes:
- New iris.x module with tile-level communication primitives (all-reduce, all-gather, all-to-all, reduce-scatter)
- Fused GEMM+communication operations requiring tritonBLAS (gemm_all_reduce, gemm_all_gather, etc.)
- Comprehensive test suite for the new primitives in tests/x/
- CI/CD modernization with a unified workflow replacing 3 separate workflows
- Documentation updates and benchmark enhancements
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| iris/x/__init__.py | Module initialization exposing all tile-level primitives with optional GEMM operations |
| iris/x/all_reduce.py | Five all-reduce variants (atomic, one-shot, two-shot, spinlock, ring) for different use cases |
| iris/x/all_gather.py | Tile-level all-gather primitive for gathering data from all ranks |
| iris/x/all_to_all.py | Tile-level all-to-all primitive for bidirectional data exchange |
| iris/x/reduce_scatter.py | Tile-level reduce-scatter that reduces and scatters to assigned ranks |
| iris/x/gemm_all_reduce.py | Fused GEMM + all-reduce using tritonBLAS stages |
| iris/x/gemm_all_gather.py | Fused GEMM + all-gather combining computation and communication |
| iris/x/gemm_reduce_scatter.py | Fused GEMM + reduce-scatter for column-parallel workloads |
| iris/x/all_gather_gemm.py | Fused all-gather + GEMM for tensor-parallel workloads |
| iris/x/common.py | Shared utilities for tile indexing and offset computation |
| tests/x/test_*.py | Comprehensive test suite validating all primitives against PyTorch references |
| .github/workflows/iris-tests.yml | New unified test workflow supporting multiple test directories and install methods |
| .github/scripts/run_tests.sh | Updated test runner with tritonBLAS installation for iris.x tests |
| tests/ccl/test_all_reduce.py | Modified to add explicit preamble calls for better test isolation |
| pyproject.toml | Added optional gemm dependency group for tritonBLAS |
| docs/reference/examples.md | Updated documentation with new example references |
| benchmark/ccl/all_to_all/benchmark.py | Added RCCL comparison benchmarking option |
@neoblizz we should be able to use
Resolved conflicts by accepting main's changes for:
- .gitignore
- benchmark/ccl/*.py files
- docker/Dockerfile
- iris/ccl/*.py files
…eContext

Refactor all tile-based collective operations and fused GEMM operators to use the new object-oriented API, dramatically simplifying function signatures and improving code readability.

Changes:
- Collectives: all_gather, all_reduce (4 variants), reduce_scatter, all_to_all
- Fused ops: all_gather_gemm, gemm_all_gather, gemm_all_reduce, gemm_reduce_scatter
- Replace verbose parameter lists with OOP objects (Tile, TensorView, DeviceContext)
- Add tl.constexpr annotations to all GEMM kernel parameters
- Fix iris.load/atomic_add call signatures for correct argument ordering
- Net reduction: -50 lines of code across 8 files
Update all test kernels to use the new OOP API (Tile, TensorView, DeviceContext) and fix a critical tile iteration bug causing test failures at scale.

Changes:
- Rename all test kernels to the test_x_*_kernel pattern (avoids pytest warnings)
- Update kernel calls to use OOP objects instead of verbose parameter lists
- Fix tile iteration stride: use tl.num_programs(0) instead of 1 so multiple CUs never process the same tiles (fixes race conditions); see the sketch below
- Fix the all_to_all PyTorch reference to use .contiguous() chunks
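To make the stride fix concrete, here is a minimal persistent-loop sketch (a hypothetical kernel, not the PR's code): each program starts at its own `tl.program_id(0)` and advances by `tl.num_programs(0)`, so the tile sets of different programs never overlap; with a stride of 1, every program would walk the same tile sequence and race on the results.

```python
import triton
import triton.language as tl


@triton.jit
def doubling_kernel(x_ptr, num_tiles, BLOCK: tl.constexpr):
    # Assumes the buffer holds exactly num_tiles * BLOCK elements.
    pid = tl.program_id(0)
    # Stride by the grid size so each program owns a disjoint set of tiles.
    for tile_id in range(pid, num_tiles, tl.num_programs(0)):
        offs = tile_id * BLOCK + tl.arange(0, BLOCK)
        vals = tl.load(x_ptr + offs)
        tl.store(x_ptr + offs, vals * 2.0)
```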
Pull request overview
Copilot reviewed 24 out of 24 changed files in this pull request and generated 10 comments.
mawad-amd left a comment
Looks good. We need to move DeviceContext to iris.py but that can happen in a different PR. APIs looking good to me. Thanks!
tile: Tile object with position and dimensions.
src_view: TensorView for input tensor.
dst_view: TensorView for output tensor.
locks_ptr: Pointer to locks array (one lock per tile).
Future work: it would be great if we could get rid of this locks_ptr argument.
TL;DR
Introduces iris.x APIs.

iris.x: Device-side Tile-Level Primitives for Fused Patterns
iris.x provides composable device-side tile-level primitives for fine-grained compute and collective operations within Triton kernels. Unlike iris.ccl, which operates on full tensors with internal tiling, iris.x gives you direct control over tile iteration, enabling custom fusion patterns and fine-grained overlap of computation and communication.

Key Differences from iris.ccl
Overview
(Comparison table contrasting iris.ccl and iris.x.)

Core Abstractions
Tile
Represents a tile with position and dimensions:
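As a rough sketch of what this object might carry (all field names below are assumptions, not the PR's actual definition):

```python
from dataclasses import dataclass


@dataclass
class Tile:
    # Hypothetical fields: block coordinates within the tensor's tile grid
    # plus the tile dimensions.
    pid_m: int      # tile row index
    pid_n: int      # tile column index
    BLOCK_M: int    # tile height in elements
    BLOCK_N: int    # tile width in elements
```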
TensorView
Describes tensor memory layout for device-side access:
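A hedged sketch of the kind of layout information such a view would need (field names assumed):

```python
from dataclasses import dataclass


@dataclass
class TensorView:
    # Hypothetical fields: base pointer plus shape/stride metadata used to
    # compute device-side offsets.
    ptr: int        # device pointer to the first element
    shape_m: int    # number of rows
    shape_n: int    # number of columns
    stride_m: int   # elements to step one row
    stride_n: int   # elements to step one column
```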
DeviceContext
Encapsulates distributed context (rank, world size, heap bases):
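A hedged sketch (field names assumed) of the distributed state it would bundle:

```python
from dataclasses import dataclass


@dataclass
class DeviceContext:
    # Hypothetical fields: everything a device-side primitive needs to
    # address remote ranks through the Iris symmetric heap.
    cur_rank: int     # this rank's index
    world_size: int   # total number of ranks
    heap_bases: int   # pointer to the table of per-rank heap base addresses
```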
AllReduceConfig
Configures all-reduce algorithm selection:
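A hedged sketch (field name and default assumed); the supported algorithm strings are the variants listed under All-Reduce below:

```python
from dataclasses import dataclass


@dataclass
class AllReduceConfig:
    # Hypothetical field: which all-reduce variant to run.
    algorithm: str = "atomic"  # "atomic", "one_shot", "two_shot", "spinlock", or "ring"
```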
Collective Operations
All-Reduce
Reduce data across all ranks with support for multiple algorithms:
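The original snippet is summarized here as an assumed call shape (argument names and order are assumptions, not the PR's verified signature), invoked once per tile from inside a user kernel:

```python
# Assumed per-tile call from inside a @triton.jit kernel: reduce this
# tile's elements across all ranks and write the result into dst_view.
all_reduce(tile, src_view, dst_view, ctx, config)
```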
Algorithms:
- atomic (default): Fine-grained atomic operations
- ring: Ring-based reduction
- two_shot: Two-shot algorithm for larger messages
- one_shot: One-shot algorithm for smaller messages
- spinlock: Lock-based synchronization for specific patterns

All-Gather
Gather data from all ranks along a specified dimension:
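An assumed call shape (names, signature, and the `dim` keyword are guesses, not the PR's verified API):

```python
# Assumed per-tile call: gather this tile from every rank and concatenate
# along `dim`, so dst is world_size times larger than src along that axis.
all_gather(tile, src_view, dst_view, ctx, dim=0)
```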
All-to-All
Personalized all-to-all exchange:
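An assumed call shape (not the PR's verified signature):

```python
# Assumed per-tile call: each rank sends a distinct tile to every peer and
# receives one from each, exchanging blocks of src across ranks into dst.
all_to_all(tile, src_view, dst_view, ctx)
```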
Reduce-Scatter
Reduce and scatter results to assigned ranks:
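An assumed call shape (not the PR's verified signature):

```python
# Assumed per-tile call: reduce this tile across ranks, keeping the result
# only on the rank that owns the tile, so each rank ends up with
# 1/world_size of the reduced output.
reduce_scatter(tile, src_view, dst_view, ctx)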
Usage Example
Here's a complete example showing custom tile iteration with all-reduce:
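The full code block from the original description is not reproduced here. The sketch below illustrates the intended pattern under explicit assumptions: every iris.x name and signature, the host-side Iris helpers (get_rank, get_num_ranks, get_heap_bases, barrier), and the assumption that the view/context objects can be passed straight into a Triton kernel are all guesses rather than the PR's verified API.

```python
import triton
import triton.language as tl

import iris
# Assumed import path and names; the PR's actual exports may differ.
from iris.x import DeviceContext, TensorView, AllReduceConfig, all_reduce


@triton.jit
def fused_all_reduce_kernel(ctx, src_view, dst_view, config, num_tiles):
    # Custom tile iteration: program `pid` handles tiles pid, pid + grid,
    # pid + 2 * grid, ..., so no two programs touch the same tile.
    pid = tl.program_id(0)
    for tile_id in range(pid, num_tiles, tl.num_programs(0)):
        # Whether the primitive takes a linear tile id or a Tile object is
        # an assumption; a linear id is used here for brevity.
        all_reduce(tile_id, src_view, dst_view, ctx, config)


def run(shmem, src, dst, block_m=64, block_n=64, grid_size=304):
    # Host side (helper names assumed): wrap the tensors and distributed
    # state, launch a persistent 1-D grid, then synchronize all ranks.
    m, n = src.shape
    num_tiles = triton.cdiv(m, block_m) * triton.cdiv(n, block_n)
    ctx = DeviceContext(shmem.get_rank(), shmem.get_num_ranks(), shmem.get_heap_bases())
    src_view = TensorView(src.data_ptr(), m, n, src.stride(0), src.stride(1))
    dst_view = TensorView(dst.data_ptr(), m, n, dst.stride(0), dst.stride(1))
    cfg = AllReduceConfig(algorithm="atomic")
    fused_all_reduce_kernel[(grid_size,)](ctx, src_view, dst_view, cfg, num_tiles)
    shmem.barrier()
```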
Submission Checklist