
Conversation

@neoblizz (Member) commented Dec 9, 2025

TL;DR

Introduces iris.x APIs.

iris.x: Device-side Tile-Level Primitives for Fused Patterns

iris.x provides composable device-side tile-level primitives for fine-grained compute and collective operations within Triton kernels. Unlike iris.ccl which operates on full tensors with internal tiling, iris.x gives you direct control over tile iteration, enabling custom fusion patterns and fine-grained overlap of computation and communication.

Key Differences from iris.ccl

```python
# iris.ccl: high-level, operates on full tensors
shmem.all_reduce(input_tensor, output_tensor)

# iris.x: low-level, operates on tiles within your kernel
@triton.jit
def my_kernel(...):
    tile = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N)
    ctx.all_reduce(tile, src_view, dst_view)
```

Overview

| Feature | iris.ccl | iris.x |
| --- | --- | --- |
| Level | Host-side, operates on full tensors | Device-side, operates on tiles |
| Tiling | Automatic, internal | Manual, user-controlled |
| Control | Simple, high-level | Fine-grained, low-level |
| Use case | General collectives | Custom fusion, overlap patterns |

Core Abstractions

Tile

Represents a tile with position and dimensions:

```python
tile = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N)
```
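For intuition, here is a plain-Python sketch of which element indices a tile at position `(pid_m, pid_n)` covers. The helper is illustrative only (not an iris.x API); inside a kernel the equivalent index arithmetic is done with `tl.arange` over the tile's position:

```python
# Hypothetical helper (not part of iris.x) showing the index range a tile spans.
def tile_offsets(pid_m, pid_n, BLOCK_M, BLOCK_N):
    rows = list(range(pid_m * BLOCK_M, (pid_m + 1) * BLOCK_M))
    cols = list(range(pid_n * BLOCK_N, (pid_n + 1) * BLOCK_N))
    return rows, cols

# Tile (1, 2) with a 64x32 block covers rows 64..127 and cols 64..95.
rows, cols = tile_offsets(1, 2, 64, 32)
```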

TensorView

Describes tensor memory layout for device-side access:

```python
view = iris.x.TensorView(ptr, M, N, stride_m, stride_n)
```

DeviceContext

Encapsulates distributed context (rank, world size, heap bases):

```python
ctx = iris.x.DeviceContext(rank, world_size, heap_bases)
```

AllReduceConfig

Configures all-reduce algorithm selection:

```python
# Use default (atomic)
config = iris.x.AllReduceConfig()

# Use ring algorithm
config = iris.x.AllReduceConfig("ring")

# Use spinlock with locks
config = iris.x.AllReduceConfig("spinlock", locks_ptr)
```

Collective Operations

All-Reduce

Reduce data across all ranks with support for multiple algorithms:

```python
# Context API (recommended)
ctx.all_reduce(tile, src_view, dst_view, config=config)

# Standalone API with specific algorithms
iris.x.all_reduce_atomic(tile, src_view, dst_view, ctx)     # Default: atomic
iris.x.all_reduce_ring(tile, src_view, dst_view, ctx)       # Ring algorithm
iris.x.all_reduce_two_shot(tile, src_view, dst_view, ctx)   # Two-shot
iris.x.all_reduce_one_shot(tile, src_view, dst_view, ctx)   # One-shot
iris.x.all_reduce_spinlock(tile, src_view, dst_view, ctx, locks_ptr, tile_id)  # Spinlock
```

Algorithms:

- `atomic` (default): Fine-grained atomic operations
- `ring`: Ring-based reduction
- `two_shot`: Two-shot algorithm for larger messages
- `one_shot`: One-shot algorithm for smaller messages
- `spinlock`: Lock-based synchronization for specific patterns
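To make the one-shot vs. two-shot distinction concrete, here is a host-side plain-Python simulation of the two reduction schedules. This is a sketch of the semantics only, not the device implementation; the function names are hypothetical:

```python
# Illustrative simulation of all-reduce schedules (not iris.x device code).
# bufs[r] is rank r's local input vector.

def all_reduce_one_shot(bufs):
    # One shot: every rank reads all peers and reduces the full vector locally.
    total = [sum(vals) for vals in zip(*bufs)]
    return [list(total) for _ in bufs]

def all_reduce_two_shot(bufs):
    # Shot 1: each rank reduces only its assigned shard (reduce-scatter).
    # Shot 2: the reduced shards are gathered back to every rank (all-gather).
    world = len(bufs)
    shard = len(bufs[0]) // world
    reduced = []
    for r in range(world):
        lo, hi = r * shard, (r + 1) * shard
        reduced.extend(sum(vals) for vals in zip(*(b[lo:hi] for b in bufs)))
    return [list(reduced) for _ in bufs]

bufs = [[1, 2, 3, 4], [10, 20, 30, 40]]
out_one = all_reduce_one_shot(bufs)
out_two = all_reduce_two_shot(bufs)
```

The schedules produce identical results; they differ in traffic pattern. One-shot pulls every peer's full vector into every rank, which suits small messages, while two-shot splits the reduction work across ranks at the cost of an extra step, which suits larger messages.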

All-Gather

Gather data from all ranks along a specified dimension:

```python
# Context API
ctx.all_gather(tile, src_view, dst_view, dim=0)  # Gather along dimension 0

# Standalone API
iris.x.all_gather(tile, src_view, dst_view, dim, ctx)
```
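The effect of the `dim` argument can be sketched host-side in plain Python (illustrative only, not iris.x code): each rank contributes one shard, and the output concatenates shards along the chosen dimension.

```python
# Host-side sketch of all-gather semantics (hypothetical helper).
def all_gather(shards, dim=0):
    if dim == 0:
        # Stack shards as extra rows.
        return [row for shard in shards for row in shard]
    # dim == 1: concatenate shards column-wise, row by row.
    return [sum((shard[i] for shard in shards), []) for i in range(len(shards[0]))]

shards = [[[1, 2]], [[3, 4]]]   # two ranks, each holding a 1x2 shard
rows_out = all_gather(shards, dim=0)   # shards become extra rows
cols_out = all_gather(shards, dim=1)   # shards become extra columns
```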

All-to-All

Personalized all-to-all exchange:

```python
# Context API
ctx.all_to_all(tile, src_view, dst_view, N_per_rank)

# Standalone API
iris.x.all_to_all(tile, src_view, dst_view, N_per_rank, ctx)
```
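"Personalized" means each rank sends a distinct chunk to each peer. The resulting data movement is a rank-level transpose, sketched here in plain Python (illustrative only, not iris.x code):

```python
# Host-side sketch of personalized all-to-all (hypothetical helper):
# rank r's c-th output chunk is rank c's r-th input chunk.
def all_to_all(chunks):
    world = len(chunks)
    return [[chunks[src][dst] for src in range(world)] for dst in range(world)]

# chunks[r][c] = the N_per_rank-sized chunk rank r intends for rank c.
inp = [["a0", "a1"], ["b0", "b1"]]
out = all_to_all(inp)
```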

Reduce-Scatter

Reduce and scatter results to assigned ranks:

```python
# Context API
ctx.reduce_scatter(tile, src_view, dst_view)

# Standalone API
iris.x.reduce_scatter(tile, src_view, dst_view, ctx)
```
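Semantically this is the first half of a two-shot all-reduce: an element-wise reduction across ranks, after which each rank keeps only its assigned shard. A plain-Python sketch (illustrative only, not iris.x code):

```python
# Host-side sketch of reduce-scatter semantics (hypothetical helper).
def reduce_scatter(bufs):
    world = len(bufs)
    shard = len(bufs[0]) // world
    # Element-wise sum across all ranks' buffers...
    total = [sum(vals) for vals in zip(*bufs)]
    # ...then rank r keeps only shard r of the reduced result.
    return [total[r * shard:(r + 1) * shard] for r in range(world)]

bufs = [[1, 2, 3, 4], [10, 20, 30, 40]]
shards = reduce_scatter(bufs)
```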

Usage Example

Here's a complete example showing custom tile iteration with all-reduce:

```python
import triton
import triton.language as tl
import iris.x

@triton.jit
def custom_kernel(
    input_ptr, output_ptr,
    M, N,
    stride_m, stride_n,
    heap_bases: tl.tensor,
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Setup device context
    ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)

    # Create tensor views
    src_view = iris.x.TensorView(input_ptr, M, N, stride_m, stride_n)
    dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n)

    # Get program ID
    pid = tl.program_id(0)
    num_tiles_m = tl.cdiv(M, BLOCK_M)
    num_tiles_n = tl.cdiv(N, BLOCK_N)
    total_tiles = num_tiles_m * num_tiles_n

    # Persistent tile iteration
    for tile_id in range(pid, total_tiles, tl.num_programs(0)):
        # Compute tile coordinates
        pid_m = tile_id // num_tiles_n
        pid_n = tile_id % num_tiles_n

        # Create tile
        tile = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N)

        # *** Your custom computation here ***
        # Load tile, do computation, etc.

        # Perform all-reduce on this tile
        ctx.all_reduce(tile, src_view, dst_view)
```
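The persistent loop's scheduling can be checked host-side: program `pid` processes tile ids `pid, pid + P, pid + 2P, ...` where `P = tl.num_programs(0)`, so every tile is covered exactly once with no overlap between programs. A plain-Python sketch (the helper is hypothetical, for illustration):

```python
# Host-side illustration of the persistent tile partitioning used above.
def tiles_for_program(pid, num_programs, M, N, BLOCK_M, BLOCK_N):
    num_tiles_n = -(-N // BLOCK_N)                 # ceil-div, like tl.cdiv
    total_tiles = -(-M // BLOCK_M) * num_tiles_n
    return [(t // num_tiles_n, t % num_tiles_n)
            for t in range(pid, total_tiles, num_programs)]

# 256x256 problem, 64x64 tiles (16 tiles total), 3 programs.
all_tiles = [tiles_for_program(p, 3, 256, 256, 64, 64) for p in range(3)]
```

A stride of 1 instead of `num_programs` would make every program walk the same tiles, which is exactly the race condition fixed later in this PR.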


@github-actions github-actions bot added in-progress We are working on it iris Iris project issue labels Dec 9, 2025
@neoblizz neoblizz changed the base branch from main to muhosama/ccl-more December 9, 2025 20:32
@neoblizz neoblizz requested a review from Copilot December 9, 2025 20:33

Copilot AI left a comment


Pull request overview

This PR introduces iris.x, a new module providing device-side tile-level primitives for fine-grained collective operations. Unlike iris.ccl which handles full tensors with internal tiling, iris.x provides composable functions that users can call from their own kernels to manage tile iteration themselves.

Key Changes:

- New `iris.x` module with tile-level communication primitives (all-reduce, all-gather, all-to-all, reduce-scatter)
- Fused GEMM + communication operations requiring tritonBLAS (`gemm_all_reduce`, `gemm_all_gather`, etc.)
- Comprehensive test suite for the new primitives in `tests/x/`
- CI/CD modernization with a unified workflow replacing 3 separate workflows
- Documentation updates and benchmark enhancements

Reviewed changes

Copilot reviewed 33 out of 33 changed files in this pull request and generated 14 comments.

| File | Description |
| --- | --- |
| `iris/x/__init__.py` | Module initialization exposing all tile-level primitives with optional GEMM operations |
| `iris/x/all_reduce.py` | Five all-reduce variants (atomic, one-shot, two-shot, spinlock, ring) for different use cases |
| `iris/x/all_gather.py` | Tile-level all-gather primitive for gathering data from all ranks |
| `iris/x/all_to_all.py` | Tile-level all-to-all primitive for bidirectional data exchange |
| `iris/x/reduce_scatter.py` | Tile-level reduce-scatter that reduces and scatters to assigned ranks |
| `iris/x/gemm_all_reduce.py` | Fused GEMM + all-reduce using tritonBLAS stages |
| `iris/x/gemm_all_gather.py` | Fused GEMM + all-gather combining computation and communication |
| `iris/x/gemm_reduce_scatter.py` | Fused GEMM + reduce-scatter for column-parallel workloads |
| `iris/x/all_gather_gemm.py` | Fused all-gather + GEMM for tensor-parallel workloads |
| `iris/x/common.py` | Shared utilities for tile indexing and offset computation |
| `tests/x/test_*.py` | Comprehensive test suite validating all primitives against PyTorch references |
| `.github/workflows/iris-tests.yml` | New unified test workflow supporting multiple test directories and install methods |
| `.github/scripts/run_tests.sh` | Updated test runner with tritonBLAS installation for iris.x tests |
| `tests/ccl/test_all_reduce.py` | Modified to add explicit preamble calls for better test isolation |
| `pyproject.toml` | Added optional `gemm` dependency group for tritonBLAS |
| `docs/reference/examples.md` | Updated documentation with new example references |
| `benchmark/ccl/all_to_all/benchmark.py` | Added RCCL comparison benchmarking option |

@mawad-amd (Collaborator) commented:

@neoblizz we should be able to use an aggregate to clean up the device-side APIs. See https://godbolt.org/z/hY3oWfW1x

Resolved conflicts by accepting main's changes for:
- .gitignore
- benchmark/ccl/*.py files
- docker/Dockerfile
- iris/ccl/*.py files
…eContext

Refactor all tile-based collective operations and fused GEMM operators to use
new object-oriented API, dramatically simplifying function signatures and
improving code readability.

Changes:
- Collectives: all_gather, all_reduce (4 variants), reduce_scatter, all_to_all
- Fused ops: all_gather_gemm, gemm_all_gather, gemm_all_reduce, gemm_reduce_scatter
- Replace verbose parameter lists with OOP objects (Tile, TensorView, DeviceContext)
- Add tl.constexpr annotations to all GEMM kernel parameters
- Fix iris.load/atomic_add call signatures for correct argument ordering
- Net reduction: -50 lines of code across 8 files
Update all test kernels to use new OOP API (Tile, TensorView, DeviceContext)
and fix critical tile iteration bug causing test failures at scale.

Changes:
- Rename all test kernels to test_x_*_kernel pattern (avoids pytest warnings)
- Update kernel calls to use OOP objects instead of verbose parameters
- Fix tile iteration stride: use tl.num_programs(0) instead of 1 to prevent
  multiple CUs from processing the same tiles (fixes race conditions)
- Fix all_to_all PyTorch reference to use .contiguous() chunks
@neoblizz neoblizz marked this pull request as ready for review January 28, 2026 20:57

Copilot AI left a comment


Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 10 comments.


@mawad-amd mawad-amd left a comment


Looks good. We need to move DeviceContext to iris.py but that can happen in a different PR. APIs looking good to me. Thanks!

```python
tile: Tile object with position and dimensions.
src_view: TensorView for input tensor.
dst_view: TensorView for output tensor.
locks_ptr: Pointer to locks array (one lock per tile).
```

Future work: it would be great if we could get rid of this locks_ptr argument.

@neoblizz neoblizz changed the title from "iris.x: Device-side communication + .x APIs." to "iris.x: Device-side communication + .ops APIs." Jan 29, 2026