Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/amsa/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"Install with: uv pip install amsa-ga[jax]"
) from err

from amsa.fusion import optimize_sequence_ir
from amsa.ir import (
ProductIR,
SequenceIR,
Expand Down Expand Up @@ -163,6 +164,7 @@ def execute_sequence_ir(
ir: SequenceIR,
) -> Any:
"""Execute a ``SequenceIR`` step-by-step using JAX operations."""
ir = optimize_sequence_ir(ir)
env: dict[str, Any] = dict(inputs)

for step in ir.steps:
Expand Down
57 changes: 29 additions & 28 deletions src/amsa/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np

from amsa.fusion import optimize_sequence_ir
from amsa.ir import (
ProductIR,
SequenceIR,
Expand All @@ -30,6 +31,7 @@
from amsa.mv import MVArray
from amsa.storage import (
CSRStorage,
_broadcast_row_indices,
add_storage,
coefficient_magnitude_squared_storage,
gather_storage_columns,
Expand All @@ -41,20 +43,6 @@
)


def _broadcast_csr_rows(storage: CSRStorage, batch_shape: tuple[int, ...]) -> np.ndarray:
    """Return the flat source-row index for every element of ``batch_shape``.

    Row ids ``0 .. storage.row_count - 1`` are arranged in the storage's own
    batch shape, broadcast to ``batch_shape``, and flattened, so entry ``k``
    of the result names the storage row that feeds flattened output row ``k``.
    """
    row_ids = np.arange(storage.row_count, dtype=np.intp)
    shaped_ids = row_ids.reshape(storage.batch_shape)
    expanded_ids = np.broadcast_to(shaped_ids, batch_shape)
    return expanded_ids.reshape(-1)


def _csr_row_values(storage: CSRStorage, row: int) -> dict[int, Any]:
    """Collect the stored entries of one CSR row as ``{column: value}``.

    The row's extent is read from ``indptr[row]`` and ``indptr[row + 1]`` of
    the storage payload; each offset in that half-open range contributes the
    column taken from ``indices`` and the value taken from ``data``.
    """
    payload = storage._payload
    begin = int(payload.indptr[row])
    end = int(payload.indptr[row + 1])

    values: dict[int, Any] = {}
    for position in range(begin, end):
        values[int(payload.indices[position])] = payload.data[position]
    return values


def _execute_csr_product_ir(
lhs: MVArray,
rhs: MVArray,
Expand All @@ -66,29 +54,41 @@ def _execute_csr_product_ir(
layout = output_layout_from_product_ir(ir, lhs.algebra)
dtype = np.dtype(np.result_type(lhs.dtype, rhs.dtype))
row_count = int(np.prod(batch_shape, dtype=np.intp))
lhs_rows = _broadcast_csr_rows(lhs.storage, batch_shape)
rhs_rows = _broadcast_csr_rows(rhs.storage, batch_shape)
lhs_rows = _broadcast_row_indices(lhs.storage.batch_shape, batch_shape)
rhs_rows = _broadcast_row_indices(rhs.storage.batch_shape, batch_shape)

terms_by_pair: dict[tuple[int, int], list[tuple[int, int]]] = {}
for term in ir.terms:
terms_by_pair.setdefault((term.lhs_col, term.rhs_col), []).append(
(term.out_col, term.coefficient)
)

data_values: list[Any] = []
index_values: list[int] = []
indptr = np.zeros(row_count + 1, dtype=np.intp)

nnz = 0
for out_row, (lhs_row, rhs_row) in enumerate(zip(lhs_rows, rhs_rows, strict=True)):
lhs_values = _csr_row_values(lhs.storage, int(lhs_row))
rhs_values = _csr_row_values(rhs.storage, int(rhs_row))
lhs_start = int(lhs.storage._payload.indptr[lhs_row])
lhs_stop = int(lhs.storage._payload.indptr[lhs_row + 1])
rhs_start = int(rhs.storage._payload.indptr[rhs_row])
rhs_stop = int(rhs.storage._payload.indptr[rhs_row + 1])
out_values: dict[int, Any] = {}

for term in ir.terms:
lhs_value = lhs_values.get(term.lhs_col)
if lhs_value is None:
continue
rhs_value = rhs_values.get(term.rhs_col)
if rhs_value is None:
continue
out_values[term.out_col] = out_values.get(
term.out_col, dtype.type(0)
) + term.coefficient * lhs_value * rhs_value
for lhs_offset in range(lhs_start, lhs_stop):
lhs_col = int(lhs.storage._payload.indices[lhs_offset])
lhs_value = lhs.storage._payload.data[lhs_offset]
for rhs_offset in range(rhs_start, rhs_stop):
pair_terms = terms_by_pair.get(
(lhs_col, int(rhs.storage._payload.indices[rhs_offset]))
)
if pair_terms is None:
continue
product = lhs_value * rhs.storage._payload.data[rhs_offset]
for out_col, coefficient in pair_terms:
out_values[out_col] = out_values.get(
out_col, dtype.type(0)
) + coefficient * product

for column in sorted(out_values):
value = out_values[column]
Expand Down Expand Up @@ -189,6 +189,7 @@ def execute_sequence_ir(
sequence into a single kernel; the NumPy backend executes faithfully
step-by-step, with optional fusion support for common patterns.
"""
ir = optimize_sequence_ir(ir)
env: dict[str, Any] = dict(inputs)

i = 0
Expand Down
74 changes: 70 additions & 4 deletions src/amsa/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal
from typing import Any, Literal

from amsa.ir import SequenceIR, SequenceStepKind
from amsa.ir import IRStep, SequenceIR, SequenceStepKind

FusionKind = Literal[
"scale_product", # scale followed by binary product
Expand Down Expand Up @@ -58,6 +58,74 @@ class FusionPattern:
)


def _freeze_metadata(value: Any) -> Any:
    """Return a hashable stand-in for ``value`` for use inside a CSE key.

    Containers are converted recursively and tagged with their container
    kind, so metadata that differs only in container type (a dict versus a
    list of pairs, a list versus a tuple) never yields the same key. Dict
    items and set members are ordered by ``repr`` rather than by natural
    comparison, so mixed or non-orderable contents cannot raise ``TypeError``.

    Args:
        value: Arbitrary step metadata, possibly nested.

    Returns:
        ``value`` itself when it is a hashable non-container; otherwise a
        tagged tuple built from hashable parts.
    """
    if isinstance(value, dict):
        # Dict keys are hashable by construction; only the values need freezing.
        frozen_items = [(key, _freeze_metadata(item)) for key, item in value.items()]
        return ("dict", tuple(sorted(frozen_items, key=repr)))
    if isinstance(value, (tuple, list)):
        tag = "list" if isinstance(value, list) else "tuple"
        return (tag, tuple(_freeze_metadata(item) for item in value))
    if isinstance(value, set):
        frozen_members = (_freeze_metadata(item) for item in value)
        return ("set", tuple(sorted(frozen_members, key=repr)))
    try:
        hash(value)
    except TypeError:
        # Last resort for unhashable leaves. The type name keeps the result
        # distinct from a plain string that happens to equal the repr.
        # NOTE(review): objects whose repr is lossy (e.g. summarized arrays)
        # could still compare equal here -- confirm no such metadata exists.
        return ("repr", type(value).__qualname__, repr(value))
    return value


def _cse_step_key(step: IRStep) -> tuple[Any, ...]:
    """Build the lookup key that identifies a step for duplicate detection.

    The key combines the step's kind, operands and lowered IR with a frozen
    form of its metadata. The step's output name and fusion annotation are
    deliberately left out, so two steps that compute the same thing under
    different output names share a key.
    """
    identity = (step.kind, step.operands, step.ir)
    return (*identity, _freeze_metadata(step.metadata))


def eliminate_common_subexpressions(ir: SequenceIR) -> SequenceIR:
    """Drop repeated pure steps from a ``SequenceIR``.

    A step counts as a repeat only when its kind, its operands (after earlier
    repeats have been redirected), its lowered IR, and its metadata all match
    a step that was already kept. Sequence steps are pure coefficient
    operations, so a repeat's output can simply refer to the first step's
    output instead.

    Every kept step is rebuilt with ``fusion=None``. The returned sequence
    keeps the original name and inputs, and its result is redirected if the
    step that produced it was removed.

    NOTE(review): the aliasing presumably relies on each output name being
    assigned only once in the sequence -- confirm against ``SequenceIR``.
    """
    canonical_name: dict[str, str] = {}
    first_output_by_key: dict[tuple[Any, ...], str] = {}
    kept_steps: list[IRStep] = []

    for original in ir.steps:
        # Point operands at the surviving producer before comparing steps.
        candidate = IRStep(
            kind=original.kind,
            operands=tuple(canonical_name.get(name, name) for name in original.operands),
            ir=original.ir,
            output=original.output,
            metadata=original.metadata,
            fusion=None,
        )
        signature = _cse_step_key(candidate)
        earlier_output = first_output_by_key.get(signature)

        if earlier_output is None:
            # First occurrence: keep it and let it stand for itself.
            first_output_by_key[signature] = original.output
            canonical_name[original.output] = original.output
            kept_steps.append(candidate)
        else:
            # Repeat: later references to this output use the earlier one.
            canonical_name[original.output] = earlier_output

    final_result = canonical_name.get(ir.result, ir.result)
    return SequenceIR(
        name=ir.name,
        inputs=ir.inputs,
        steps=tuple(kept_steps),
        result=final_result,
    )


def optimize_sequence_ir(ir: SequenceIR) -> SequenceIR:
    """Run the conservative sequence optimizations used by eager backends.

    Duplicate steps are removed first; fusion metadata is then attached to
    the deduplicated sequence.
    """
    deduplicated = eliminate_common_subexpressions(ir)
    return apply_fusion_metadata(deduplicated)


def analyze_fusion(ir: SequenceIR) -> dict[int, FusionKind]:
"""Analyze a SequenceIR and identify fusion opportunities.

Expand Down Expand Up @@ -103,8 +171,6 @@ def apply_fusion_metadata(ir: SequenceIR) -> SequenceIR:
Returns:
A new SequenceIR with fusion metadata applied to fusible steps.
"""
from amsa.ir import IRStep

fusion_opportunities = analyze_fusion(ir)

# Rebuild steps with fusion metadata
Expand Down
Loading