From 2f4fbb260503aaaf982aa38579093e02339a2c46 Mon Sep 17 00:00:00 2001 From: Surya Sunkara Date: Sat, 23 May 2026 17:31:24 +0530 Subject: [PATCH] CSR optimisations --- src/amsa/backends/jax.py | 2 + src/amsa/backends/numpy.py | 57 ++++----- src/amsa/fusion.py | 74 +++++++++++- src/amsa/storage.py | 233 ++++++++++++++++++++++++++++++++----- tests/test_fusion.py | 104 ++++++++++++++++- 5 files changed, 411 insertions(+), 59 deletions(-) diff --git a/src/amsa/backends/jax.py b/src/amsa/backends/jax.py index 16d0a13..9b8c649 100644 --- a/src/amsa/backends/jax.py +++ b/src/amsa/backends/jax.py @@ -26,6 +26,7 @@ "Install with: uv pip install amsa-ga[jax]" ) from err +from amsa.fusion import optimize_sequence_ir from amsa.ir import ( ProductIR, SequenceIR, @@ -163,6 +164,7 @@ def execute_sequence_ir( ir: SequenceIR, ) -> Any: """Execute a ``SequenceIR`` step-by-step using JAX operations.""" + ir = optimize_sequence_ir(ir) env: dict[str, Any] = dict(inputs) for step in ir.steps: diff --git a/src/amsa/backends/numpy.py b/src/amsa/backends/numpy.py index 39a98fa..97d90e2 100644 --- a/src/amsa/backends/numpy.py +++ b/src/amsa/backends/numpy.py @@ -19,6 +19,7 @@ import numpy as np +from amsa.fusion import optimize_sequence_ir from amsa.ir import ( ProductIR, SequenceIR, @@ -30,6 +31,7 @@ from amsa.mv import MVArray from amsa.storage import ( CSRStorage, + _broadcast_row_indices, add_storage, coefficient_magnitude_squared_storage, gather_storage_columns, @@ -41,20 +43,6 @@ ) -def _broadcast_csr_rows(storage: CSRStorage, batch_shape: tuple[int, ...]) -> np.ndarray: - rows = np.arange(storage.row_count, dtype=np.intp).reshape(storage.batch_shape) - return np.broadcast_to(rows, batch_shape).reshape(-1) - - -def _csr_row_values(storage: CSRStorage, row: int) -> dict[int, Any]: - start = int(storage._payload.indptr[row]) - stop = int(storage._payload.indptr[row + 1]) - return { - int(storage._payload.indices[offset]): storage._payload.data[offset] - for offset in range(start, stop) - } - - def _execute_csr_product_ir( lhs: MVArray, rhs: MVArray, @@ -66,8 +54,14 @@ def _execute_csr_product_ir( layout = output_layout_from_product_ir(ir, lhs.algebra) dtype = np.dtype(np.result_type(lhs.dtype, rhs.dtype)) row_count = int(np.prod(batch_shape, dtype=np.intp)) - lhs_rows = _broadcast_csr_rows(lhs.storage, batch_shape) - rhs_rows = _broadcast_csr_rows(rhs.storage, batch_shape) + lhs_rows = _broadcast_row_indices(lhs.storage.batch_shape, batch_shape) + rhs_rows = _broadcast_row_indices(rhs.storage.batch_shape, batch_shape) + + terms_by_pair: dict[tuple[int, int], list[tuple[int, int]]] = {} + for term in ir.terms: + terms_by_pair.setdefault((term.lhs_col, term.rhs_col), []).append( + (term.out_col, term.coefficient) + ) data_values: list[Any] = [] index_values: list[int] = [] @@ -75,20 +69,26 @@ def _execute_csr_product_ir( nnz = 0 for out_row, (lhs_row, rhs_row) in enumerate(zip(lhs_rows, rhs_rows, strict=True)): - lhs_values = _csr_row_values(lhs.storage, int(lhs_row)) - rhs_values = _csr_row_values(rhs.storage, int(rhs_row)) + lhs_start = int(lhs.storage._payload.indptr[lhs_row]) + lhs_stop = int(lhs.storage._payload.indptr[lhs_row + 1]) + rhs_start = int(rhs.storage._payload.indptr[rhs_row]) + rhs_stop = int(rhs.storage._payload.indptr[rhs_row + 1]) out_values: dict[int, Any] = {} - for term in ir.terms: - lhs_value = lhs_values.get(term.lhs_col) - if lhs_value is None: - continue - rhs_value = rhs_values.get(term.rhs_col) - if rhs_value is None: - continue - out_values[term.out_col] = out_values.get( - term.out_col, dtype.type(0) - ) + term.coefficient * lhs_value * rhs_value + for lhs_offset in range(lhs_start, lhs_stop): + lhs_col = int(lhs.storage._payload.indices[lhs_offset]) + lhs_value = lhs.storage._payload.data[lhs_offset] + for rhs_offset in range(rhs_start, rhs_stop): + pair_terms = terms_by_pair.get( + (lhs_col, int(rhs.storage._payload.indices[rhs_offset])) + ) + if pair_terms is None: + continue + product = lhs_value * rhs.storage._payload.data[rhs_offset] + for out_col, coefficient in pair_terms: + out_values[out_col] = out_values.get( + out_col, dtype.type(0) + ) + coefficient * product for column in sorted(out_values): value = out_values[column] @@ -189,6 +189,7 @@ def execute_sequence_ir( sequence into a single kernel; the NumPy backend executes faithfully step-by-step, with optional fusion support for common patterns. """ + ir = optimize_sequence_ir(ir) env: dict[str, Any] = dict(inputs) i = 0 diff --git a/src/amsa/fusion.py b/src/amsa/fusion.py index 6dd0498..4feb994 100644 --- a/src/amsa/fusion.py +++ b/src/amsa/fusion.py @@ -21,9 +21,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal +from typing import Any, Literal -from amsa.ir import SequenceIR, SequenceStepKind +from amsa.ir import IRStep, SequenceIR, SequenceStepKind FusionKind = Literal[ "scale_product", # scale followed by binary product @@ -58,6 +58,74 @@ class FusionPattern: ) +def _freeze_metadata(value: Any) -> Any: + if isinstance(value, dict): + return tuple(sorted((key, _freeze_metadata(item)) for key, item in value.items())) + if isinstance(value, (tuple, list)): + return tuple(_freeze_metadata(item) for item in value) + if isinstance(value, set): + return tuple(sorted(_freeze_metadata(item) for item in value)) + try: + hash(value) + except TypeError: + return repr(value) + return value + + +def _cse_step_key(step: IRStep) -> tuple[Any, ...]: + return ( + step.kind, + step.operands, + step.ir, + _freeze_metadata(step.metadata), + ) + + +def eliminate_common_subexpressions(ir: SequenceIR) -> SequenceIR: + """Remove duplicate pure steps from a ``SequenceIR``. + + The pass is intentionally conservative: a step is common only when its kind, + remapped operands, lowered IR, and metadata are identical. Sequence steps are + pure coefficient operations, so later duplicate outputs can safely alias the + first output. + """ + output_aliases: dict[str, str] = {} + seen: dict[tuple[Any, ...], str] = {} + new_steps: list[IRStep] = [] + + for step in ir.steps: + operands = tuple(output_aliases.get(operand, operand) for operand in step.operands) + remapped = IRStep( + kind=step.kind, + operands=operands, + ir=step.ir, + output=step.output, + metadata=step.metadata, + fusion=None, + ) + key = _cse_step_key(remapped) + existing_output = seen.get(key) + if existing_output is not None: + output_aliases[step.output] = existing_output + continue + + seen[key] = step.output + output_aliases[step.output] = step.output + new_steps.append(remapped) + + return SequenceIR( + name=ir.name, + inputs=ir.inputs, + steps=tuple(new_steps), + result=output_aliases.get(ir.result, ir.result), + ) + + +def optimize_sequence_ir(ir: SequenceIR) -> SequenceIR: + """Apply conservative sequence optimizations used by eager backends.""" + return apply_fusion_metadata(eliminate_common_subexpressions(ir)) + + def analyze_fusion(ir: SequenceIR) -> dict[int, FusionKind]: """Analyze a SequenceIR and identify fusion opportunities. @@ -103,8 +171,6 @@ def apply_fusion_metadata(ir: SequenceIR) -> SequenceIR: Returns: A new SequenceIR with fusion metadata applied to fusible steps. """ - from amsa.ir import IRStep - fusion_opportunities = analyze_fusion(ir) # Rebuild steps with fusion metadata diff --git a/src/amsa/storage.py b/src/amsa/storage.py index 847e8c6..b3bc456 100644 --- a/src/amsa/storage.py +++ b/src/amsa/storage.py @@ -734,9 +734,140 @@ def _add_dense_csr_storage( return DenseStorage.from_array(values) +def _append_merged_csr_rows( + data_values: list[Any], + index_values: list[int], + left_data: NDArray[Any], + left_indices: NDArray[Any], + right_data: NDArray[Any], + right_indices: NDArray[Any], + sign: Literal[1, -1], +) -> int: + left_pos = 0 + right_pos = 0 + nnz = 0 + while left_pos < left_indices.size and right_pos < right_indices.size: + left_column = int(left_indices[left_pos]) + right_column = int(right_indices[right_pos]) + if left_column == right_column: + value = left_data[left_pos] + sign * right_data[right_pos] + if value != 0: + index_values.append(left_column) + data_values.append(value) + nnz += 1 + left_pos += 1 + right_pos += 1 + elif left_column < right_column: + index_values.append(left_column) + data_values.append(left_data[left_pos]) + nnz += 1 + left_pos += 1 + else: + index_values.append(right_column) + data_values.append(sign * right_data[right_pos]) + nnz += 1 + right_pos += 1 + + while left_pos < left_indices.size: + index_values.append(int(left_indices[left_pos])) + data_values.append(left_data[left_pos]) + nnz += 1 + left_pos += 1 + while right_pos < right_indices.size: + index_values.append(int(right_indices[right_pos])) + data_values.append(sign * right_data[right_pos]) + nnz += 1 + right_pos += 1 + return nnz + + +def _add_batched_csr_to_scalar_csr( + lhs: CSRStorage, + rhs: CSRStorage, + sign: Literal[1, -1], + result_dtype: np.dtype[Any], +) -> CSRStorage: + rhs_start = int(rhs._payload.indptr[0]) + rhs_stop = int(rhs._payload.indptr[1]) + rhs_data = rhs._payload.data[rhs_start:rhs_stop] + rhs_indices = rhs._payload.indices[rhs_start:rhs_stop] + data_values: list[Any] = [] + index_values: list[int] = [] + indptr = np.zeros(lhs.row_count + 1, dtype=np.intp) + + nnz = 0 + for row in range(lhs.row_count): + lhs_start = int(lhs._payload.indptr[row]) + lhs_stop = int(lhs._payload.indptr[row + 1]) + nnz += _append_merged_csr_rows( + data_values, + index_values, + lhs._payload.data[lhs_start:lhs_stop], + lhs._payload.indices[lhs_start:lhs_stop], + rhs_data, + rhs_indices, + sign, + ) + indptr[row + 1] = nnz + + return CSRStorage( + np.asarray(data_values, dtype=result_dtype), + np.asarray(index_values, dtype=np.intp), + indptr, + batch_shape=lhs.batch_shape, + width=lhs.width, + dtype=result_dtype, + ) + + +def _add_scalar_csr_to_batched_csr( + lhs: CSRStorage, + rhs: CSRStorage, + sign: Literal[1, -1], + result_dtype: np.dtype[Any], +) -> CSRStorage: + lhs_start = int(lhs._payload.indptr[0]) + lhs_stop = int(lhs._payload.indptr[1]) + lhs_data = lhs._payload.data[lhs_start:lhs_stop] + lhs_indices = lhs._payload.indices[lhs_start:lhs_stop] + data_values: list[Any] = [] + index_values: list[int] = [] + indptr = np.zeros(rhs.row_count + 1, dtype=np.intp) + + nnz = 0 + for row in range(rhs.row_count): + rhs_start = int(rhs._payload.indptr[row]) + rhs_stop = int(rhs._payload.indptr[row + 1]) + nnz += _append_merged_csr_rows( + data_values, + index_values, + lhs_data, + lhs_indices, + rhs._payload.data[rhs_start:rhs_stop], + rhs._payload.indices[rhs_start:rhs_stop], + sign, + ) + indptr[row + 1] = nnz + + return CSRStorage( + np.asarray(data_values, dtype=result_dtype), + np.asarray(index_values, dtype=np.intp), + indptr, + batch_shape=rhs.batch_shape, + width=lhs.width, + dtype=result_dtype, + ) + + def _add_csr_storage(lhs: CSRStorage, rhs: CSRStorage, sign: Literal[1, -1]) -> CSRStorage: _, batch_shape = _check_binary_storage_compatible(lhs, rhs) result_dtype = np.dtype(np.result_type(lhs.dtype, rhs.dtype)) + + if lhs.batch_shape == () and rhs.batch_shape != (): + return _add_scalar_csr_to_batched_csr(lhs, rhs, sign, result_dtype) + if rhs.batch_shape == () and lhs.batch_shape != (): + return _add_batched_csr_to_scalar_csr(lhs, rhs, sign, result_dtype) + data_values: list[Any] = [] index_values: list[int] = [] indptr = np.zeros(int(prod(batch_shape)) + 1, dtype=np.intp) @@ -745,26 +876,48 @@ def _add_csr_storage(lhs: CSRStorage, rhs: CSRStorage, sign: Literal[1, -1]) -> nnz = 0 for row, (lhs_row, rhs_row) in enumerate(zip(lhs_rows, rhs_rows, strict=True)): - row_values: dict[int, Any] = {} lhs_start = int(lhs._payload.indptr[lhs_row]) lhs_stop = int(lhs._payload.indptr[lhs_row + 1]) rhs_start = int(rhs._payload.indptr[rhs_row]) rhs_stop = int(rhs._payload.indptr[rhs_row + 1]) - - for offset in range(lhs_start, lhs_stop): - row_values[int(lhs._payload.indices[offset])] = lhs._payload.data[offset] - for offset in range(rhs_start, rhs_stop): - column = int(rhs._payload.indices[offset]) - rhs_value = rhs._payload.data[offset] - row_values[column] = row_values.get(column, result_dtype.type(0)) + sign * rhs_value - - for column in sorted(row_values): - value = row_values[column] - if value == 0: - continue - index_values.append(column) - data_values.append(value) + lhs_pos = lhs_start + rhs_pos = rhs_start + + while lhs_pos < lhs_stop and rhs_pos < rhs_stop: + lhs_column = int(lhs._payload.indices[lhs_pos]) + rhs_column = int(rhs._payload.indices[rhs_pos]) + if lhs_column == rhs_column: + value = lhs._payload.data[lhs_pos] + sign * rhs._payload.data[rhs_pos] + if value != 0: + index_values.append(lhs_column) + data_values.append(value) + nnz += 1 + lhs_pos += 1 + rhs_pos += 1 + elif lhs_column < rhs_column: + index_values.append(lhs_column) + data_values.append(lhs._payload.data[lhs_pos]) + nnz += 1 + lhs_pos += 1 + else: + index_values.append(rhs_column) + data_values.append( + sign * rhs._payload.data[rhs_pos] + ) + nnz += 1 + rhs_pos += 1 + + while lhs_pos < lhs_stop: + index_values.append(int(lhs._payload.indices[lhs_pos])) + data_values.append(lhs._payload.data[lhs_pos]) nnz += 1 + lhs_pos += 1 + while rhs_pos < rhs_stop: + index_values.append(int(rhs._payload.indices[rhs_pos])) + data_values.append(sign * rhs._payload.data[rhs_pos]) + nnz += 1 + rhs_pos += 1 + indptr[row + 1] = nnz return CSRStorage( @@ -847,6 +1000,27 @@ def index_csr_storage(storage: MVStorage, key: Any) -> MVStorage: if not isinstance(storage, CSRStorage): raise TypeError("index_csr_storage only supports CSRStorage.") + if isinstance(key, slice) and len(storage.batch_shape) == 1: + row_indices = np.arange(storage.row_count, dtype=np.intp)[key] + if row_indices.ndim == 0: + row_indices = row_indices.reshape(1) + if row_indices.size == 0: + return CSRStorage.zeros(storage.width, batch_shape=(0,), dtype=storage.dtype) + if row_indices.size == 1 or np.all(row_indices[1:] == row_indices[:-1] + 1): + first_row = int(row_indices[0]) + last_row = int(row_indices[-1]) + data_start = int(storage._payload.indptr[first_row]) + data_stop = int(storage._payload.indptr[last_row + 1]) + indptr = storage._payload.indptr[first_row : last_row + 2] - data_start + return CSRStorage( + storage._payload.data[data_start:data_stop].copy(), + storage._payload.indices[data_start:data_stop].copy(), + indptr.copy(), + batch_shape=(int(row_indices.size),), + width=storage.width, + dtype=storage.dtype, + ) + row_grid = np.arange(storage.row_count, dtype=np.intp).reshape(storage.batch_shape) selected_rows = row_grid[key] selected_array = np.asarray(selected_rows) @@ -857,23 +1031,30 @@ def index_csr_storage(storage: MVStorage, key: Any) -> MVStorage: result_batch_shape = selected_array.shape flat_selected_rows = selected_array.reshape(-1) - data_values: list[Any] = [] - index_values: list[int] = [] - indptr = np.zeros(flat_selected_rows.size + 1, dtype=np.intp) + lengths = ( + storage._payload.indptr[flat_selected_rows + 1] + - storage._payload.indptr[flat_selected_rows] + ) + indptr = np.empty(flat_selected_rows.size + 1, dtype=np.intp) + indptr[0] = 0 + np.cumsum(lengths, out=indptr[1:]) - nnz = 0 + nnz = int(indptr[-1]) + data = np.empty(nnz, dtype=storage.dtype) + indices = np.empty(nnz, dtype=np.intp) for out_row, source_row in enumerate(flat_selected_rows): start = int(storage._payload.indptr[int(source_row)]) stop = int(storage._payload.indptr[int(source_row) + 1]) - if start != stop: - data_values.extend(storage._payload.data[start:stop]) - index_values.extend(int(column) for column in storage._payload.indices[start:stop]) - nnz += stop - start - indptr[out_row + 1] = nnz + out_start = int(indptr[out_row]) + out_stop = int(indptr[out_row + 1]) + if out_start == out_stop: + continue + data[out_start:out_stop] = storage._payload.data[start:stop] + indices[out_start:out_stop] = storage._payload.indices[start:stop] return CSRStorage( - np.asarray(data_values, dtype=storage.dtype), - np.asarray(index_values, dtype=np.intp), + data, + indices, indptr, batch_shape=result_batch_shape, width=storage.width, diff --git a/tests/test_fusion.py b/tests/test_fusion.py index 85a757f..0733a21 100644 --- a/tests/test_fusion.py +++ b/tests/test_fusion.py @@ -1,6 +1,12 @@ """Tests for IR fusion analysis.""" -from amsa.fusion import FUSION_PATTERNS, analyze_fusion, apply_fusion_metadata +from amsa.fusion import ( + FUSION_PATTERNS, + analyze_fusion, + apply_fusion_metadata, + eliminate_common_subexpressions, + optimize_sequence_ir, +) from amsa.ir import IRStep, SequenceIR from tests._utils import assert_allclose @@ -305,3 +311,99 @@ def test_fusion_no_opportunity_unchanged(): # Verify no fusion metadata for step in fused_ir.steps: assert step.fusion is None + + +def test_eliminate_common_subexpressions_removes_duplicate_steps(): + ir = SequenceIR( + name="duplicate", + inputs=("input",), + steps=( + IRStep( + kind="scale", + operands=("input",), + ir=None, + output="scaled_a", + metadata={"factor": 2.0}, + ), + IRStep( + kind="scale", + operands=("input",), + ir=None, + output="scaled_b", + metadata={"factor": 2.0}, + ), + IRStep(kind="add", operands=("scaled_a", "scaled_b"), ir=None, output="result"), + ), + result="result", + ) + + optimized = eliminate_common_subexpressions(ir) + + assert len(optimized.steps) == 2 + assert optimized.steps[1].operands == ("scaled_a", "scaled_a") + assert optimized.result == "result" + + +def test_eliminate_common_subexpressions_can_alias_result(): + ir = SequenceIR( + name="duplicate_result", + inputs=("input",), + steps=( + IRStep( + kind="scale", + operands=("input",), + ir=None, + output="scaled_a", + metadata={"factor": 2.0}, + ), + IRStep( + kind="scale", + operands=("input",), + ir=None, + output="scaled_b", + metadata={"factor": 2.0}, + ), + ), + result="scaled_b", + ) + + optimized = eliminate_common_subexpressions(ir) + + assert len(optimized.steps) == 1 + assert optimized.result == "scaled_a" + + +def test_optimize_sequence_ir_runs_cse_before_fusion(): + ir = SequenceIR( + name="cse_then_fuse", + inputs=("input", "other"), + steps=( + IRStep( + kind="scale", + operands=("input",), + ir=None, + output="scaled_a", + metadata={"factor": 2.0}, + ), + IRStep( + kind="scale", + operands=("input",), + ir=None, + output="scaled_b", + metadata={"factor": 2.0}, + ), + IRStep( + kind="binary_product", + operands=("scaled_b", "other"), + ir=None, + output="result", + ), + ), + result="result", + ) + + optimized = optimize_sequence_ir(ir) + + assert len(optimized.steps) == 2 + assert optimized.steps[1].operands == ("scaled_a", "other") + assert optimized.steps[0].fusion == "scale_product"