From 34041c08e3d4b614e7c220b6ec711d6989f8d838 Mon Sep 17 00:00:00 2001 From: Addison Date: Fri, 19 Dec 2025 02:46:36 -0600 Subject: [PATCH 01/43] add named Set and Map creation --- namedisl/__init__.py | 203 +++++++++++++++++++++++++++------ namedisl/test/test_namedisl.py | 38 +++++- 2 files changed, 204 insertions(+), 37 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index a2edcaf..dc02871 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -32,10 +32,10 @@ THE SOFTWARE. """ import re +from abc import ABC, abstractmethod from collections.abc import Mapping -from dataclasses import dataclass from importlib import metadata -from typing import TypeAlias, TypeVar, overload +from typing import Generic, TypeAlias, TypeVar, final from constantdict import constantdict from typing_extensions import override @@ -49,56 +49,193 @@ VERSION = tuple(int(nr) for nr in _match.group(1).split(".")) -IslObject = TypeVar("IslObject", isl.BasicSet, isl.Set) -NameToDim: TypeAlias = Mapping[str, tuple[isl.dim_type, int]] +IslExpressionLikeT = TypeVar("IslExpressionLikeT", isl.Aff, isl.QPolynomial) +IslSetLikeT = TypeVar("IslSetLikeT", isl.Set, isl.Map) +IslObjectT = TypeVar("IslObjectT", isl.Set, isl.Map) +IslSetLike = isl.Set | isl.Map +IslExpressionLike = isl.Aff | isl.QPolynomial -def _strip_names(obj: IslObject) -> tuple[IslObject, NameToDim]: - name_to_dim: dict[str, tuple[isl.dim_type, int]] = {} - for tp in isl._CHECK_DIM_TYPES: - for i in range(obj.dim(tp)): - name = obj.get_dim_name(tp, i) - if name is None: - raise ValueError("unnamed dimension found") - if name in name_to_dim: - raise ValueError(f"non-unique dim name: {name}") - name_to_dim[name] = (tp, i) +IslObjectPieces: TypeAlias = tuple[IslObjectT, tuple[frozenset[str], ...]] - # FIXME: Enable, to avoid misunderstandings - # obj = obj.set_dim_id(tp, i, None) +NameToDim: TypeAlias = Mapping[str, int] + + +def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: + name_to_dim: dict[str, int] = {} + for i in range(obj.dim(isl.dim_type.set)): + name = obj.get_dim_name(isl.dim_type.set, i) + + if name is None: + raise ValueError("unnamed dimension found") + + if name in name_to_dim: + raise ValueError(f"non-unique dim name: {name}") + + name_to_dim[name] = i return obj, constantdict(name_to_dim) -def _restore_names(obj: IslObject, name_to_dim: NameToDim) -> IslObject: - for name, (dt, i) in name_to_dim.items(): - obj = obj.set_dim_name(dt, i, name) +def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: + all_dt_names: list[str] = [] + for dim in range(obj.dim(dt)): + name = obj.get_dim_name(dt, dim) + + if name is None: + raise ValueError("unnamed dimension found") + + all_dt_names.append(name) + + return frozenset(all_dt_names) + + +def _move_dims_to_set_dim(obj: IslObjectT, dt: isl.dim_type) -> IslObjectT: + obj = obj.move_dims( + isl.dim_type.set, obj.dim(isl.dim_type.set), + dt, 0, obj.dim(dt) + ) return obj -@dataclass(frozen=True) -class BasicSet: - _obj: isl.BasicSet +class NamedIslObject(Generic[IslObjectT], ABC): + _obj: IslObjectT _name_to_dim: NameToDim + _parameter_names: frozenset[str] + _parameter_dim_start: int + + @abstractmethod + def __init__(self, src: IslObjectT | str, ctx: isl.Context | None = None): + ... + + # FIXME: needs a type on obj relaxed enough to be specialized in subclasses + @abstractmethod + def _deconstruct_isl_object(self, obj) -> IslObjectPieces[IslObjectT]: + ... + + @abstractmethod + def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: + ... @override def __str__(self) -> str: - return str(_restore_names(self._obj, self._name_to_dim)) + return str(self._reconstruct_isl_object()) + + +class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): + _obj: isl.Set -@overload -def make_basic_set(src: str, ctx: isl.Context | None = None) -> BasicSet: - ... +@final +class Set(_NamedIslSetLike): + def __init__(self, src: isl.Set | str, ctx: isl.Context | None = None): + obj = isl.Set(src, ctx) if isinstance(src, str) else src + + obj, (parameter_names,) = self._deconstruct_isl_object(obj) + obj, name_to_dim = _strip_names(obj) + + self._obj = obj + self._name_to_dim = name_to_dim + self._parameter_names = parameter_names + + self._parameter_dim_start = min( + self._name_to_dim[name] + for name in self._parameter_names + ) + + @override + def _deconstruct_isl_object(self, obj: isl.Set) -> IslObjectPieces[isl.Set]: + """ + Internal set dimensions ordered in two contiguous chunks: + [ (set dimensions), (parameter dimensions) ] + """ + parameter_names = _get_dim_names(obj, isl.dim_type.param) + obj = _move_dims_to_set_dim(obj, isl.dim_type.param) + return obj, (frozenset(parameter_names),) + + @override + def _reconstruct_isl_object(self) -> isl.Set: + return self._obj.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, self._parameter_dim_start, + len(self._parameter_names) + ) + + +@final +class Map(_NamedIslSetLike): + _input_names: frozenset[str] + _input_dim_start: int + + def __init__(self, src: isl.Map | str, ctx: isl.Context | None = None): + obj = isl.Map(src, ctx) if isinstance(src, str) else src + + obj, (parameter_names, input_names) = self._deconstruct_isl_object(obj) + obj, name_to_dim = _strip_names(obj) + + self._parameter_names = parameter_names + self._input_names = input_names + self._name_to_dim = name_to_dim + self._obj = obj + + self._parameter_dim_start = min( + self._name_to_dim[name] + for name in self._parameter_names + ) + + self._input_dim_start = min( + self._name_to_dim[name] + for name in self._input_names + ) + + # NOTE: hard requirement for object reconstruction is to have each type + # of dimension contiguous in the underlying set. each type of dimension + # can be shuffled around arbitrarily within each contiguous chunk. + # impose chunk ordering as [ (set), (parameter), (input) ] + if self._input_dim_start < self._parameter_dim_start: + raise ValueError( + "Expected input dimensions to be ordered after parameter " + "dimensions in set representation" + ) + + @override + def _deconstruct_isl_object(self, obj: isl.Map) -> IslObjectPieces[isl.Set]: + """ + Internal set dimensions ordered in three contiguous chunks as: + [ (set dimensions), (parameter dimensions), (input dimensions) ] + """ + parameter_names = _get_dim_names(obj, isl.dim_type.param) + input_names = _get_dim_names(obj, isl.dim_type.in_) + + obj = _move_dims_to_set_dim(obj, isl.dim_type.param) + obj = _move_dims_to_set_dim(obj, isl.dim_type.in_) + + return obj.range(), (parameter_names, input_names) + + @override + def _reconstruct_isl_object(self) -> isl.Map: + """ + Relies on the dimension type ordering in + :func:`_deconstruct_isl_object`. + """ + domain = isl.Set("{ [] }") + range = self._obj + + map = isl.Map.from_domain_and_range(domain, range) + param_start = self._parameter_dim_start -@overload -def make_basic_set(src: isl.BasicSet) -> BasicSet: - ... + map = map.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, param_start, len(self._parameter_names) + ) + inp_start = self._input_dim_start - len(self._parameter_names) -def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> BasicSet: - obj = isl.BasicSet(src, ctx) if isinstance(src, str) else src + map = map.move_dims( + isl.dim_type.in_, 0, + isl.dim_type.set, inp_start, len(self._input_names) + ) - obj, name_to_dim = _strip_names(obj) - return BasicSet(obj, name_to_dim) + return map diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py index 4eb4f07..23d0442 100644 --- a/namedisl/test/test_namedisl.py +++ b/namedisl/test/test_namedisl.py @@ -25,10 +25,40 @@ THE SOFTWARE. """ +import islpy as isl + import namedisl as nisl -def test_basic_set() -> None: - bs = nisl.make_basic_set("[n] -> {[i]}: 0<=i None: + spec = "[n] -> { [i] : 0 <= i < n }" + s_isl = isl.Set(spec) + s = nisl.Set(spec) + print(s) + + assert s._reconstruct_isl_object() == s_isl + + +def test_set_from_set() -> None: + s_isl = isl.Set("[n] -> { [i] : 0 <= i < n }") + s = nisl.Set(s_isl) + print(s) + + assert s._reconstruct_isl_object() == s_isl + + +def test_map_from_str() -> None: + spec = "[n] -> { [i] -> [j] : 0 <= i < n and j = 2 * i }" + m = nisl.Map(spec) + m_isl = isl.Map(spec) + print(m) + + assert m._reconstruct_isl_object() == m_isl + + +def test_map_from_map() -> None: + m_isl = isl.Map("[n] -> { [i] -> [j] : 0 <= i < n and j = 2 * i }") + m = nisl.Map(m_isl) + print(m) + + assert m._reconstruct_isl_object() == m_isl From 7d57de29e24b8aa3f065530fefea6bf1048ad6eb Mon Sep 17 00:00:00 2001 From: Addison Date: Sun, 11 Jan 2026 14:39:12 -0600 Subject: [PATCH 02/43] implement joint name to dim finding based on alphabetical ordering within dimension types --- namedisl/__init__.py | 45 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index dc02871..7e47720 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -102,9 +102,15 @@ def _move_dims_to_set_dim(obj: IslObjectT, dt: isl.dim_type) -> IslObjectT: class NamedIslObject(Generic[IslObjectT], ABC): _obj: IslObjectT _name_to_dim: NameToDim + _parameter_names: frozenset[str] _parameter_dim_start: int + # NOTE: defaulting these for all subclasses reduces the amount of + # specialization when aligning spaces of objects + _input_names: frozenset[str] = frozenset() + _input_dim_start: int = -1 + @abstractmethod def __init__(self, src: IslObjectT | str, ctx: isl.Context | None = None): ... @@ -123,6 +129,42 @@ def __str__(self) -> str: return str(self._reconstruct_isl_object()) +# FIXME: enforcing alphabetical ordering within each contiguous chunk of +# dimension types solves the problem +def _find_joint_name_to_dim( + obj: NamedIslObject[IslObjectT], + other: NamedIslObject[IslObjectT] + ) -> tuple[NameToDim, tuple[frozenset[str], frozenset[str]]]: + """ + Constructs a mapping from names to dimensions such that names within each + "type chunk" are sorted alphabetically. Specifically, the internal + :class:`isl.Set` representation of each :class:`NamedIslObject` will have + the form + + [ (set dimensions), (parameter dimensions), (input_dimensions) ] + + where the names in each dimension appear in alphabetical order. + """ + obj_all_names = frozenset(obj._name_to_dim.keys()) + obj_inp_names = obj._input_names + obj_param_names = obj._parameter_names + obj_set_names = (obj_all_names - obj_param_names) - obj_inp_names + + other_all_names = frozenset(other._name_to_dim.keys()) + other_inp_names = other._input_names + other_param_names = other._parameter_names + other_set_names = (other_all_names - other_param_names) - other_inp_names + + all_inp_names = sorted(list(obj_inp_names | other_inp_names)) + all_param_names = sorted(list(obj_param_names | other_param_names)) + all_set_names = sorted(list(obj_set_names | other_set_names)) + all_names = all_set_names + all_param_names + all_inp_names + + name_to_dim = { name : dim for dim, name in enumerate(all_names) } + + return name_to_dim, (frozenset(all_param_names), frozenset(all_inp_names)) + + class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): _obj: isl.Set @@ -165,9 +207,6 @@ def _reconstruct_isl_object(self) -> isl.Set: @final class Map(_NamedIslSetLike): - _input_names: frozenset[str] - _input_dim_start: int - def __init__(self, src: isl.Map | str, ctx: isl.Context | None = None): obj = isl.Map(src, ctx) if isinstance(src, str) else src From 6b792d0ca7960ae911accc1d357e0c960c9921cb Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 12 Jan 2026 14:58:30 -0600 Subject: [PATCH 03/43] use dataclasses for namedisl objects --- namedisl/__init__.py | 175 +++++++++++++++++++++++-------------------- 1 file changed, 92 insertions(+), 83 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 7e47720..8d8d2ed 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -31,11 +31,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from dataclasses import dataclass import re from abc import ABC, abstractmethod from collections.abc import Mapping from importlib import metadata -from typing import Generic, TypeAlias, TypeVar, final +from typing import Generic, TypeAlias, TypeVar, final, overload from constantdict import constantdict from typing_extensions import override @@ -56,7 +57,7 @@ IslSetLike = isl.Set | isl.Map IslExpressionLike = isl.Aff | isl.QPolynomial -IslObjectPieces: TypeAlias = tuple[IslObjectT, tuple[frozenset[str], ...]] +SetLikePieces: TypeAlias = tuple[isl.Set, tuple[frozenset[str], ...]] NameToDim: TypeAlias = Mapping[str, int] @@ -90,15 +91,42 @@ def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: return frozenset(all_dt_names) -def _move_dims_to_set_dim(obj: IslObjectT, dt: isl.dim_type) -> IslObjectT: - obj = obj.move_dims( - isl.dim_type.set, obj.dim(isl.dim_type.set), - dt, 0, obj.dim(dt) - ) +def _deconstruct_set_like_object(obj: IslSetLikeT) -> SetLikePieces: + from islpy import dim_type + + dt_to_names: dict[dim_type, frozenset[str]] = {} + for dt in dt_to_names.keys(): + dt_to_names[dt] = _get_dim_names(obj, dt) + obj = obj.move_dims( + dim_type.set, + obj.dim(dim_type.set), + dt, + 0, + obj.dim(dt) + ) + + if isinstance(obj, isl.Map): + set_obj = obj.range() + else: + set_obj = obj + + input_names = dt_to_names[dim_type.in_] + param_names = dt_to_names[dim_type.param] + + if input_names: + input_names = frozenset(input_names) + else: + input_names = frozenset() - return obj + if param_names: + param_names = frozenset(param_names) + else: + param_names = frozenset() + return set_obj, (input_names, param_names) + +@dataclass(frozen=True) class NamedIslObject(Generic[IslObjectT], ABC): _obj: IslObjectT _name_to_dim: NameToDim @@ -111,15 +139,6 @@ class NamedIslObject(Generic[IslObjectT], ABC): _input_names: frozenset[str] = frozenset() _input_dim_start: int = -1 - @abstractmethod - def __init__(self, src: IslObjectT | str, ctx: isl.Context | None = None): - ... - - # FIXME: needs a type on obj relaxed enough to be specialized in subclasses - @abstractmethod - def _deconstruct_isl_object(self, obj) -> IslObjectPieces[IslObjectT]: - ... - @abstractmethod def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: ... @@ -165,37 +184,14 @@ def _find_joint_name_to_dim( return name_to_dim, (frozenset(all_param_names), frozenset(all_inp_names)) +@dataclass(frozen=True) class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): _obj: isl.Set @final +@dataclass(frozen=True, eq=False) class Set(_NamedIslSetLike): - def __init__(self, src: isl.Set | str, ctx: isl.Context | None = None): - obj = isl.Set(src, ctx) if isinstance(src, str) else src - - obj, (parameter_names,) = self._deconstruct_isl_object(obj) - obj, name_to_dim = _strip_names(obj) - - self._obj = obj - self._name_to_dim = name_to_dim - self._parameter_names = parameter_names - - self._parameter_dim_start = min( - self._name_to_dim[name] - for name in self._parameter_names - ) - - @override - def _deconstruct_isl_object(self, obj: isl.Set) -> IslObjectPieces[isl.Set]: - """ - Internal set dimensions ordered in two contiguous chunks: - [ (set dimensions), (parameter dimensions) ] - """ - parameter_names = _get_dim_names(obj, isl.dim_type.param) - obj = _move_dims_to_set_dim(obj, isl.dim_type.param) - return obj, (frozenset(parameter_names),) - @override def _reconstruct_isl_object(self) -> isl.Set: return self._obj.move_dims( @@ -205,53 +201,32 @@ def _reconstruct_isl_object(self) -> isl.Set: ) -@final -class Map(_NamedIslSetLike): - def __init__(self, src: isl.Map | str, ctx: isl.Context | None = None): - obj = isl.Map(src, ctx) if isinstance(src, str) else src - - obj, (parameter_names, input_names) = self._deconstruct_isl_object(obj) - obj, name_to_dim = _strip_names(obj) +@overload +def make_set(src: str, ctx: isl.Context | None = None) -> Set: + ... - self._parameter_names = parameter_names - self._input_names = input_names - self._name_to_dim = name_to_dim - self._obj = obj - self._parameter_dim_start = min( - self._name_to_dim[name] - for name in self._parameter_names - ) +@overload +def make_set(src: isl.Set) -> Set: + ... - self._input_dim_start = min( - self._name_to_dim[name] - for name in self._input_names - ) - # NOTE: hard requirement for object reconstruction is to have each type - # of dimension contiguous in the underlying set. each type of dimension - # can be shuffled around arbitrarily within each contiguous chunk. - # impose chunk ordering as [ (set), (parameter), (input) ] - if self._input_dim_start < self._parameter_dim_start: - raise ValueError( - "Expected input dimensions to be ordered after parameter " - "dimensions in set representation" - ) +def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: + obj = isl.Set(src, ctx) if isinstance(src, str) else src - @override - def _deconstruct_isl_object(self, obj: isl.Map) -> IslObjectPieces[isl.Set]: - """ - Internal set dimensions ordered in three contiguous chunks as: - [ (set dimensions), (parameter dimensions), (input dimensions) ] - """ - parameter_names = _get_dim_names(obj, isl.dim_type.param) - input_names = _get_dim_names(obj, isl.dim_type.in_) + set_obj, (param_names, _) = _deconstruct_set_like_object(obj) + set_obj, name_to_dim = _strip_names(set_obj) + parameter_dim_start = min( + name_to_dim[name] + for name in param_names + ) - obj = _move_dims_to_set_dim(obj, isl.dim_type.param) - obj = _move_dims_to_set_dim(obj, isl.dim_type.in_) + return Set(set_obj, name_to_dim, param_names, parameter_dim_start) - return obj.range(), (parameter_names, input_names) +@final +@dataclass(frozen=True, eq=False) +class Map(_NamedIslSetLike): @override def _reconstruct_isl_object(self) -> isl.Map: """ @@ -264,17 +239,51 @@ def _reconstruct_isl_object(self) -> isl.Map: map = isl.Map.from_domain_and_range(domain, range) param_start = self._parameter_dim_start - map = map.move_dims( isl.dim_type.param, 0, isl.dim_type.set, param_start, len(self._parameter_names) ) inp_start = self._input_dim_start - len(self._parameter_names) - map = map.move_dims( isl.dim_type.in_, 0, isl.dim_type.set, inp_start, len(self._input_names) ) return map + + +@overload +def make_map(src: str, ctx: isl.Context | None = None) -> Map: + ... + + +@overload +def make_map(src: isl.Map) -> Map: + ... + + +def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: + obj = isl.Map(src, ctx) if isinstance(src, str) else src + + set_obj, (param_names, inp_names) = _deconstruct_set_like_object(obj) + set_obj, name_to_dim = _strip_names(set_obj) + + parameter_dim_start = min( + name_to_dim[name] + for name in name_to_dim + ) + + input_dim_start = min( + name_to_dim[name] + for name in name_to_dim + ) + + return Map( + set_obj, + name_to_dim, + param_names, + parameter_dim_start, + _input_names=inp_names, + _input_dim_start=input_dim_start + ) From 4771dada357f1c07f15edf72e3e900a80eb4d5ac Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 12 Jan 2026 15:01:00 -0600 Subject: [PATCH 04/43] remove unclear FIXME --- namedisl/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 8d8d2ed..743d966 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -148,8 +148,6 @@ def __str__(self) -> str: return str(self._reconstruct_isl_object()) -# FIXME: enforcing alphabetical ordering within each contiguous chunk of -# dimension types solves the problem def _find_joint_name_to_dim( obj: NamedIslObject[IslObjectT], other: NamedIslObject[IslObjectT] From d9849ec22c76925c043cd4e2a629f75d6f13f6c8 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 13 Jan 2026 15:46:30 -0600 Subject: [PATCH 05/43] reorder dimension type-chunk movement --- namedisl/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 743d966..d044235 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -236,16 +236,16 @@ def _reconstruct_isl_object(self) -> isl.Map: map = isl.Map.from_domain_and_range(domain, range) - param_start = self._parameter_dim_start + inp_start = self._input_dim_start map = map.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, param_start, len(self._parameter_names) + isl.dim_type.in_, 0, + isl.dim_type.set, inp_start, len(self._input_names) ) - inp_start = self._input_dim_start - len(self._parameter_names) + param_start = self._parameter_dim_start map = map.move_dims( - isl.dim_type.in_, 0, - isl.dim_type.set, inp_start, len(self._input_names) + isl.dim_type.param, 0, + isl.dim_type.set, param_start, len(self._parameter_names) ) return map From 65301c1a6d1ecf18f2595ac97d0275e88dc0af68 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 15 Jan 2026 11:11:21 -0600 Subject: [PATCH 06/43] refactor dimension type tracking for reconstruction; undo joint name to dim finding --- namedisl/__init__.py | 222 +++++++++++++++++---------------- namedisl/test/test_namedisl.py | 8 +- 2 files changed, 119 insertions(+), 111 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index d044235..2323482 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -57,10 +57,15 @@ IslSetLike = isl.Set | isl.Map IslExpressionLike = isl.Aff | isl.QPolynomial -SetLikePieces: TypeAlias = tuple[isl.Set, tuple[frozenset[str], ...]] - NameToDim: TypeAlias = Mapping[str, int] +# NOTE: without tracking what dimension type a particular name belongs to, it is +# not possible to reconstruct the ISL object after dimension operations, e.g. +# alignment +DimTypeToNames: TypeAlias = Mapping[isl.dim_type, frozenset[str]] + +SetLikePieces: TypeAlias = tuple[isl.Set, DimTypeToNames] + def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: name_to_dim: dict[str, int] = {} @@ -78,6 +83,12 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: return obj, constantdict(name_to_dim) +def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: + for name, dim in name_to_dim.items(): + obj = obj.set_dim_name(isl.dim_type.set, dim, name) + return obj + + def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: all_dt_names: list[str] = [] for dim in range(obj.dim(dt)): @@ -94,36 +105,28 @@ def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: def _deconstruct_set_like_object(obj: IslSetLikeT) -> SetLikePieces: from islpy import dim_type - dt_to_names: dict[dim_type, frozenset[str]] = {} + dt_to_names: dict[dim_type, frozenset[str]] = dict.fromkeys( + [isl.dim_type.in_, isl.dim_type.param], frozenset() + ) for dt in dt_to_names.keys(): dt_to_names[dt] = _get_dim_names(obj, dt) - obj = obj.move_dims( - dim_type.set, - obj.dim(dim_type.set), - dt, - 0, - obj.dim(dt) - ) + if dt_to_names[dt]: + obj = obj.move_dims( + dim_type.set, + obj.dim(dim_type.set), + dt, + 0, + obj.dim(dt) + ) + + dt_to_names = {dt: names for dt, names in dt_to_names.items() if names} if isinstance(obj, isl.Map): set_obj = obj.range() else: set_obj = obj - input_names = dt_to_names[dim_type.in_] - param_names = dt_to_names[dim_type.param] - - if input_names: - input_names = frozenset(input_names) - else: - input_names = frozenset() - - if param_names: - param_names = frozenset(param_names) - else: - param_names = frozenset() - - return set_obj, (input_names, param_names) + return set_obj, constantdict(dt_to_names) @dataclass(frozen=True) @@ -131,13 +134,46 @@ class NamedIslObject(Generic[IslObjectT], ABC): _obj: IslObjectT _name_to_dim: NameToDim - _parameter_names: frozenset[str] - _parameter_dim_start: int - - # NOTE: defaulting these for all subclasses reduces the amount of - # specialization when aligning spaces of objects - _input_names: frozenset[str] = frozenset() - _input_dim_start: int = -1 + # used to reconstruct ISL object + _dimtype_to_names: DimTypeToNames + + @property + def _has_inputs(self) -> bool: + return isl.dim_type.in_ in self._dimtype_to_names + + @property + def _input_names(self) -> frozenset[str]: + if self._has_inputs: + return self._dimtype_to_names[isl.dim_type.in_] + return frozenset() + + @property + def _input_dim_start(self) -> int | None: + if self._has_inputs: + return min( + self._name_to_dim[name] + for name in self._dimtype_to_names[isl.dim_type.in_] + ) + return None + + @property + def _has_params(self) -> bool: + return isl.dim_type.param in self._dimtype_to_names + + @property + def _parameter_names(self) -> frozenset[str]: + if self._has_params: + return self._dimtype_to_names[isl.dim_type.param] + return frozenset() + + @property + def _parameter_dim_start(self) -> int | None: + if self._has_params: + return min( + self._name_to_dim[name] + for name in self._dimtype_to_names[isl.dim_type.param] + ) + return None @abstractmethod def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: @@ -148,42 +184,13 @@ def __str__(self) -> str: return str(self._reconstruct_isl_object()) -def _find_joint_name_to_dim( - obj: NamedIslObject[IslObjectT], - other: NamedIslObject[IslObjectT] - ) -> tuple[NameToDim, tuple[frozenset[str], frozenset[str]]]: - """ - Constructs a mapping from names to dimensions such that names within each - "type chunk" are sorted alphabetically. Specifically, the internal - :class:`isl.Set` representation of each :class:`NamedIslObject` will have - the form - - [ (set dimensions), (parameter dimensions), (input_dimensions) ] - - where the names in each dimension appear in alphabetical order. - """ - obj_all_names = frozenset(obj._name_to_dim.keys()) - obj_inp_names = obj._input_names - obj_param_names = obj._parameter_names - obj_set_names = (obj_all_names - obj_param_names) - obj_inp_names - - other_all_names = frozenset(other._name_to_dim.keys()) - other_inp_names = other._input_names - other_param_names = other._parameter_names - other_set_names = (other_all_names - other_param_names) - other_inp_names - - all_inp_names = sorted(list(obj_inp_names | other_inp_names)) - all_param_names = sorted(list(obj_param_names | other_param_names)) - all_set_names = sorted(list(obj_set_names | other_set_names)) - all_names = all_set_names + all_param_names + all_inp_names - - name_to_dim = { name : dim for dim, name in enumerate(all_names) } - - return name_to_dim, (frozenset(all_param_names), frozenset(all_inp_names)) - - @dataclass(frozen=True) class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): + """ + Represents set-like objects with parameter dimensions as a non-parameterized + set. Names are organized as contiguous chunks of dimension types, i.e. + [ (set names), (input names), (parameter names) ] + """ _obj: isl.Set @@ -192,11 +199,20 @@ class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): class Set(_NamedIslSetLike): @override def _reconstruct_isl_object(self) -> isl.Set: - return self._obj.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, self._parameter_dim_start, - len(self._parameter_names) - ) + if self._has_params: + if self._parameter_dim_start is None: + raise ValueError( + "Object has parameter dimensions, but a starting index for " + "parameter names is not given. Reconstruction is not " + "possible") + + return self._obj.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, self._parameter_dim_start, + len(self._parameter_names) + ) + + return self._obj @overload @@ -212,14 +228,10 @@ def make_set(src: isl.Set) -> Set: def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: obj = isl.Set(src, ctx) if isinstance(src, str) else src - set_obj, (param_names, _) = _deconstruct_set_like_object(obj) + set_obj, dimtype_to_names = _deconstruct_set_like_object(obj) set_obj, name_to_dim = _strip_names(set_obj) - parameter_dim_start = min( - name_to_dim[name] - for name in param_names - ) - return Set(set_obj, name_to_dim, param_names, parameter_dim_start) + return Set(set_obj, name_to_dim, dimtype_to_names) @final @@ -229,26 +241,39 @@ class Map(_NamedIslSetLike): def _reconstruct_isl_object(self) -> isl.Map: """ Relies on the dimension type ordering in - :func:`_deconstruct_isl_object`. + :func:`_deconstruct_set_like_object`. """ - domain = isl.Set("{ [] }") - range = self._obj + if self._input_dim_start is None: + raise ValueError("Cannot reconstruct a map object without knowledge " + "of the starting position of input dimensions") + + obj = _restore_names(self._obj, self._name_to_dim) - map = isl.Map.from_domain_and_range(domain, range) + obj_domain = isl.Set("{ [] }") + obj_range = obj + + map_obj = isl.Map.from_domain_and_range(obj_domain, obj_range) + + if self._has_params: + if self._parameter_dim_start is None: + raise ValueError( + "Object has parameter dimensions, but a starting index for " + "parameter names is not given. Reconstruction is not " + "possible") + + param_start = self._parameter_dim_start + map_obj = map_obj.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, param_start, len(self._parameter_names) + ) inp_start = self._input_dim_start - map = map.move_dims( + map_obj = map_obj.move_dims( isl.dim_type.in_, 0, isl.dim_type.set, inp_start, len(self._input_names) ) - param_start = self._parameter_dim_start - map = map.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, param_start, len(self._parameter_names) - ) - - return map + return map_obj @overload @@ -264,24 +289,7 @@ def make_map(src: isl.Map) -> Map: def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: obj = isl.Map(src, ctx) if isinstance(src, str) else src - set_obj, (param_names, inp_names) = _deconstruct_set_like_object(obj) + set_obj, dimtype_to_names = _deconstruct_set_like_object(obj) set_obj, name_to_dim = _strip_names(set_obj) - parameter_dim_start = min( - name_to_dim[name] - for name in name_to_dim - ) - - input_dim_start = min( - name_to_dim[name] - for name in name_to_dim - ) - - return Map( - set_obj, - name_to_dim, - param_names, - parameter_dim_start, - _input_names=inp_names, - _input_dim_start=input_dim_start - ) + return Map(set_obj, name_to_dim, dimtype_to_names) diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py index 23d0442..eaa2d85 100644 --- a/namedisl/test/test_namedisl.py +++ b/namedisl/test/test_namedisl.py @@ -33,7 +33,7 @@ def test_set_from_str() -> None: spec = "[n] -> { [i] : 0 <= i < n }" s_isl = isl.Set(spec) - s = nisl.Set(spec) + s = nisl.make_set(spec) print(s) assert s._reconstruct_isl_object() == s_isl @@ -41,7 +41,7 @@ def test_set_from_str() -> None: def test_set_from_set() -> None: s_isl = isl.Set("[n] -> { [i] : 0 <= i < n }") - s = nisl.Set(s_isl) + s = nisl.make_set(s_isl) print(s) assert s._reconstruct_isl_object() == s_isl @@ -49,7 +49,7 @@ def test_set_from_set() -> None: def test_map_from_str() -> None: spec = "[n] -> { [i] -> [j] : 0 <= i < n and j = 2 * i }" - m = nisl.Map(spec) + m = nisl.make_map(spec) m_isl = isl.Map(spec) print(m) @@ -58,7 +58,7 @@ def test_map_from_str() -> None: def test_map_from_map() -> None: m_isl = isl.Map("[n] -> { [i] -> [j] : 0 <= i < n and j = 2 * i }") - m = nisl.Map(m_isl) + m = nisl.make_map(m_isl) print(m) assert m._reconstruct_isl_object() == m_isl From 5ad094d02710a3747d841c1d0eb85449e7eee9d0 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 15 Jan 2026 11:25:16 -0600 Subject: [PATCH 07/43] address ruff + basedpyright complaints --- namedisl/__init__.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 2323482..ca2f01c 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -31,10 +31,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from dataclasses import dataclass import re from abc import ABC, abstractmethod from collections.abc import Mapping +from dataclasses import dataclass from importlib import metadata from typing import Generic, TypeAlias, TypeVar, final, overload @@ -108,7 +108,7 @@ def _deconstruct_set_like_object(obj: IslSetLikeT) -> SetLikePieces: dt_to_names: dict[dim_type, frozenset[str]] = dict.fromkeys( [isl.dim_type.in_, isl.dim_type.param], frozenset() ) - for dt in dt_to_names.keys(): + for dt in dt_to_names: dt_to_names[dt] = _get_dim_names(obj, dt) if dt_to_names[dt]: obj = obj.move_dims( @@ -121,10 +121,7 @@ def _deconstruct_set_like_object(obj: IslSetLikeT) -> SetLikePieces: dt_to_names = {dt: names for dt, names in dt_to_names.items() if names} - if isinstance(obj, isl.Map): - set_obj = obj.range() - else: - set_obj = obj + set_obj = obj.range() if isinstance(obj, isl.Map) else obj return set_obj, constantdict(dt_to_names) From 727f18ca48ee6c2ad0c8bddb5a1683cbcd45e773 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 15 Jan 2026 13:16:42 -0600 Subject: [PATCH 08/43] joint ordering and alignment v0.1 implementation --- namedisl/__init__.py | 123 +++++++++++++++++++++++++++++++-- namedisl/test/test_namedisl.py | 15 ++++ 2 files changed, 134 insertions(+), 4 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index ca2f01c..f44f4e1 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -119,8 +119,6 @@ def _deconstruct_set_like_object(obj: IslSetLikeT) -> SetLikePieces: obj.dim(dt) ) - dt_to_names = {dt: names for dt, names in dt_to_names.items() if names} - set_obj = obj.range() if isinstance(obj, isl.Map) else obj return set_obj, constantdict(dt_to_names) @@ -136,7 +134,11 @@ class NamedIslObject(Generic[IslObjectT], ABC): @property def _has_inputs(self) -> bool: - return isl.dim_type.in_ in self._dimtype_to_names + return ( + isl.dim_type.in_ in self._dimtype_to_names + and + len(self._dimtype_to_names[isl.dim_type.in_]) > 0 + ) @property def _input_names(self) -> frozenset[str]: @@ -155,7 +157,11 @@ def _input_dim_start(self) -> int | None: @property def _has_params(self) -> bool: - return isl.dim_type.param in self._dimtype_to_names + return ( + isl.dim_type.param in self._dimtype_to_names + and + len(self._dimtype_to_names[isl.dim_type.param]) > 0 + ) @property def _parameter_names(self) -> frozenset[str]: @@ -191,6 +197,115 @@ class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): _obj: isl.Set +# FIXME: think through whether or not alphabetical ordering will require more +# work on average than using one of the objects as a template in alignment +def _find_joint_name_to_dim( + obj1: _NamedIslSetLike, + obj2: _NamedIslSetLike + ) -> tuple[NameToDim, DimTypeToNames]: + """ + Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and + :arg:`obj2` within each dimension-type chunk. This ordering is used in + alignment before performing operations between two set-like objects. + """ + obj1_inp_names = obj1._input_names + obj1_param_names = obj1._parameter_names + obj1_set_names = ( + frozenset(obj1._name_to_dim.keys()) - (obj1_inp_names | obj1_param_names) + ) + + obj2_inp_names = obj2._input_names + obj2_param_names = obj2._parameter_names + obj2_set_names = ( + frozenset(obj2._name_to_dim.keys()) - (obj2_inp_names | obj2_param_names) + ) + + all_inp_names = obj1_inp_names | obj2_inp_names + all_param_names = obj1_param_names | obj2_param_names + all_set_names = obj1_set_names | obj2_set_names + + dt_to_names: DimTypeToNames = {} + dt_to_names[isl.dim_type.param] = all_param_names + dt_to_names[isl.dim_type.in_] = all_inp_names + + # enforces contiguous ordering of [ (set), (input), (param) ] in set + # representation + all_names = sorted(list(all_set_names)) + all_names += sorted(list(all_inp_names)) + all_names += sorted(list(all_param_names)) + + name_to_dim: NameToDim = {} + for pos, name in enumerate(all_names): + name_to_dim[name] = pos + + return constantdict(name_to_dim), constantdict(dt_to_names) + + +def _align_obj( + named_obj: _NamedIslSetLike, + ordering: NameToDim, + dimtype_to_names: DimTypeToNames + ) -> _NamedIslSetLike: + new_isl_obj = named_obj._obj + running_name_to_dim = dict(named_obj._name_to_dim) + + # three cases for alignment + # - if name does not currently exist, then add a dimension + # - if name exists: + # - if new dim and old dim match, then do nothing + # - if new dim and old dim do not match, then swap and update old dim + + for name, dim in sorted(ordering.items(), key=lambda x: x[1]): + if name in running_name_to_dim: + old_dim = running_name_to_dim[name] + + if old_dim == dim: + continue + + # temporarily move to parameter dimension since destination and + # source dim types cannot match in ISL + new_isl_obj = new_isl_obj.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, dim, 1 + ) + + new_isl_obj = new_isl_obj.move_dims( + isl.dim_type.set, dim, + isl.dim_type.param, 0, 1 + ) + + else: + old_dim = new_isl_obj.dim(isl.dim_type.set) + new_isl_obj = new_isl_obj.insert_dims(isl.dim_type.set, dim, 1) + + # track side effects of inserting/swapping dimensions + temp_name_to_dim = running_name_to_dim.copy() + for cur_name, cur_dim in sorted( + running_name_to_dim.items(), key=lambda x: x[1]): + if (dim > old_dim) and (cur_dim > old_dim): + temp_name_to_dim[cur_name] = cur_dim - 1 + elif (dim < old_dim) and (cur_dim < old_dim): + temp_name_to_dim[cur_name] = cur_dim + 1 + + running_name_to_dim = temp_name_to_dim + running_name_to_dim[name] = dim + + return type(named_obj)(new_isl_obj, ordering, dimtype_to_names) + + +def _align_two(named_obj1: _NamedIslSetLike, + named_obj2: _NamedIslSetLike) -> tuple[_NamedIslSetLike, ...]: + + name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, + named_obj2) + + named_obj1 = _align_obj(named_obj1, name_to_dim, dimtype_to_names) + named_obj2 = _align_obj(named_obj2, name_to_dim, dimtype_to_names) + + return named_obj1, named_obj2 + + + @final @dataclass(frozen=True, eq=False) class Set(_NamedIslSetLike): diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py index eaa2d85..086fd9c 100644 --- a/namedisl/test/test_namedisl.py +++ b/namedisl/test/test_namedisl.py @@ -62,3 +62,18 @@ def test_map_from_map() -> None: print(m) assert m._reconstruct_isl_object() == m_isl + + +def test_align_two() -> None: + m1 = nisl.make_map( + "{ [l, m, n, o] -> [i, j, k] : 0 <= i, j, k, l, m, n, o < 10 }") + m2 = nisl.make_map( + "{ [a, b, c] -> [x, y, z] : 0 <= a, b, c, x, y, z < 5 }") + + print(m1) + print(m2) + + m1, m2 = nisl._align_two(m1, m2) + + print(m1) + print(m2) From dcdcb207e7ca434cd840a398426d6bc2018708d6 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 15 Jan 2026 13:18:09 -0600 Subject: [PATCH 09/43] address ruff + basedpyright complaints --- namedisl/__init__.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index f44f4e1..52f7154 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -230,9 +230,9 @@ def _find_joint_name_to_dim( # enforces contiguous ordering of [ (set), (input), (param) ] in set # representation - all_names = sorted(list(all_set_names)) - all_names += sorted(list(all_inp_names)) - all_names += sorted(list(all_param_names)) + all_names = sorted(all_set_names) + all_names += sorted(all_inp_names) + all_names += sorted(all_param_names) name_to_dim: NameToDim = {} for pos, name in enumerate(all_names): @@ -249,12 +249,6 @@ def _align_obj( new_isl_obj = named_obj._obj running_name_to_dim = dict(named_obj._name_to_dim) - # three cases for alignment - # - if name does not currently exist, then add a dimension - # - if name exists: - # - if new dim and old dim match, then do nothing - # - if new dim and old dim do not match, then swap and update old dim - for name, dim in sorted(ordering.items(), key=lambda x: x[1]): if name in running_name_to_dim: old_dim = running_name_to_dim[name] @@ -305,7 +299,6 @@ def _align_two(named_obj1: _NamedIslSetLike, return named_obj1, named_obj2 - @final @dataclass(frozen=True, eq=False) class Set(_NamedIslSetLike): From 491e5fb4f16fad5af0628cf0081be32fcc8def36 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 19 Jan 2026 15:25:05 -0600 Subject: [PATCH 10/43] implement alignment, setlike operations, setlike tests --- namedisl/__init__.py | 353 ++++++++++++++++++++++++------- namedisl/test/test_map.py | 206 ++++++++++++++++++ namedisl/test/test_namedisl.py | 79 ------- namedisl/test/test_set.py | 137 ++++++++++++ namedisl/test/utils_for_tests.py | 81 +++++++ 5 files changed, 697 insertions(+), 159 deletions(-) create mode 100644 namedisl/test/test_map.py delete mode 100644 namedisl/test/test_namedisl.py create mode 100644 namedisl/test/test_set.py create mode 100644 namedisl/test/utils_for_tests.py diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 52f7154..516ab23 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -31,9 +31,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import operator import re from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from importlib import metadata from typing import Generic, TypeAlias, TypeVar, final, overload @@ -50,13 +51,13 @@ VERSION = tuple(int(nr) for nr in _match.group(1).split(".")) -IslExpressionLikeT = TypeVar("IslExpressionLikeT", isl.Aff, isl.QPolynomial) -IslSetLikeT = TypeVar("IslSetLikeT", isl.Set, isl.Map) -IslObjectT = TypeVar("IslObjectT", isl.Set, isl.Map) - IslSetLike = isl.Set | isl.Map IslExpressionLike = isl.Aff | isl.QPolynomial +IslExpressionLikeT = TypeVar("IslExpressionLikeT", isl.Aff, isl.QPolynomial) +IslSetLikeT = TypeVar("IslSetLikeT", isl.Set, isl.Map) +IslObjectT = TypeVar("IslObjectT", IslSetLike, IslExpressionLike) + NameToDim: TypeAlias = Mapping[str, int] # NOTE: without tracking what dimension type a particular name belongs to, it is @@ -70,7 +71,11 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: name_to_dim: dict[str, int] = {} for i in range(obj.dim(isl.dim_type.set)): - name = obj.get_dim_name(isl.dim_type.set, i) + + if isinstance(obj, isl.QPolynomial): + name = obj.space.get_dim_name(isl.dim_type.set, i) + else: + name = obj.get_dim_name(isl.dim_type.set, i) if name is None: raise ValueError("unnamed dimension found") @@ -92,7 +97,11 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: all_dt_names: list[str] = [] for dim in range(obj.dim(dt)): - name = obj.get_dim_name(dt, dim) + + if isinstance(obj, isl.QPolynomial): + name = obj.space.get_dim_name(dt, dim) + else: + name = obj.get_dim_name(dt, dim) if name is None: raise ValueError("unnamed dimension found") @@ -124,84 +133,37 @@ def _deconstruct_set_like_object(obj: IslSetLikeT) -> SetLikePieces: return set_obj, constantdict(dt_to_names) -@dataclass(frozen=True) -class NamedIslObject(Generic[IslObjectT], ABC): - _obj: IslObjectT - _name_to_dim: NameToDim - - # used to reconstruct ISL object - _dimtype_to_names: DimTypeToNames - - @property - def _has_inputs(self) -> bool: - return ( - isl.dim_type.in_ in self._dimtype_to_names - and - len(self._dimtype_to_names[isl.dim_type.in_]) > 0 - ) - - @property - def _input_names(self) -> frozenset[str]: - if self._has_inputs: - return self._dimtype_to_names[isl.dim_type.in_] - return frozenset() - - @property - def _input_dim_start(self) -> int | None: - if self._has_inputs: - return min( - self._name_to_dim[name] - for name in self._dimtype_to_names[isl.dim_type.in_] - ) - return None - - @property - def _has_params(self) -> bool: - return ( - isl.dim_type.param in self._dimtype_to_names - and - len(self._dimtype_to_names[isl.dim_type.param]) > 0 - ) - - @property - def _parameter_names(self) -> frozenset[str]: - if self._has_params: - return self._dimtype_to_names[isl.dim_type.param] - return frozenset() - - @property - def _parameter_dim_start(self) -> int | None: - if self._has_params: - return min( - self._name_to_dim[name] - for name in self._dimtype_to_names[isl.dim_type.param] - ) - return None +def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: + """ + Determines contiguous chunks of dimensions within a sequence of dimensions. + Returns a mapping of the first dimension in the chunk to the length of the + chunk. + """ + if not dims: + return {} - @abstractmethod - def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: - ... + chunks: dict[int, int] = {} - @override - def __str__(self) -> str: - return str(self._reconstruct_isl_object()) + start = dims[0] + count = 1 + from itertools import pairwise + for prev, curr in pairwise(dims): + if curr == prev + 1: + count += 1 + else: + chunks[start] = count + start = curr + count = 1 -@dataclass(frozen=True) -class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): - """ - Represents set-like objects with parameter dimensions as a non-parameterized - set. Names are organized as contiguous chunks of dimension types, i.e. - [ (set names), (input names), (parameter names) ] - """ - _obj: isl.Set + return constantdict(chunks) # FIXME: think through whether or not alphabetical ordering will require more # work on average than using one of the objects as a template in alignment def _find_joint_name_to_dim( - obj1: _NamedIslSetLike, - obj2: _NamedIslSetLike + obj1: NamedIslObject[IslObjectT], + obj2: NamedIslObject[IslObjectT] ) -> tuple[NameToDim, DimTypeToNames]: """ Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and @@ -242,10 +204,10 @@ def _find_joint_name_to_dim( def _align_obj( - named_obj: _NamedIslSetLike, + named_obj: NamedIslObject[IslObjectT], ordering: NameToDim, dimtype_to_names: DimTypeToNames - ) -> _NamedIslSetLike: + ) -> NamedIslObject[IslObjectT]: new_isl_obj = named_obj._obj running_name_to_dim = dict(named_obj._name_to_dim) @@ -287,8 +249,10 @@ def _align_obj( return type(named_obj)(new_isl_obj, ordering, dimtype_to_names) -def _align_two(named_obj1: _NamedIslSetLike, - named_obj2: _NamedIslSetLike) -> tuple[_NamedIslSetLike, ...]: +def _align_two( + named_obj1: NamedIslObject[IslObjectT], + named_obj2: NamedIslObject[IslObjectT] + ) -> tuple[NamedIslObject[IslObjectT], ...]: name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, named_obj2) @@ -299,6 +263,206 @@ def _align_two(named_obj1: _NamedIslSetLike, return named_obj1, named_obj2 +def _align_and_apply_binary_op( + lhs: NamedIslObject[IslObjectT], + rhs: NamedIslObject[IslObjectT], + op: Callable[[IslObjectT, IslObjectT], IslObjectT] + ) -> NamedIslObject[IslObjectT]: + + lhs, rhs = _align_two(lhs, rhs) + result = op(lhs._obj, lhs._obj) + + # NOTE: since lhs and rhs were aligned, they both agree on what name-to-dim + # and dimtype-to-name is, can just take information from lhs + return type(lhs)(result, lhs._name_to_dim, lhs._dimtype_to_names) + + +@dataclass(frozen=True, eq=False) +class NamedIslObject(Generic[IslObjectT], ABC): + _obj: IslObjectT + _name_to_dim: NameToDim + + # used to reconstruct ISL object + _dimtype_to_names: DimTypeToNames + + @property + def _has_inputs(self) -> bool: + return ( + isl.dim_type.in_ in self._dimtype_to_names + and + len(self._dimtype_to_names[isl.dim_type.in_]) > 0 + ) + + @property + def _input_names(self) -> frozenset[str]: + if self._has_inputs: + return self._dimtype_to_names[isl.dim_type.in_] + return frozenset() + + @property + def _input_dim_start(self) -> int | None: + if self._has_inputs: + return min( + self._name_to_dim[name] + for name in self._dimtype_to_names[isl.dim_type.in_] + ) + return None + + @property + def _has_params(self) -> bool: + return ( + isl.dim_type.param in self._dimtype_to_names + and + len(self._dimtype_to_names[isl.dim_type.param]) > 0 + ) + + @property + def _parameter_names(self) -> frozenset[str]: + if self._has_params: + return self._dimtype_to_names[isl.dim_type.param] + return frozenset() + + @property + def _parameter_dim_start(self) -> int | None: + if self._has_params: + return min( + self._name_to_dim[name] + for name in self._dimtype_to_names[isl.dim_type.param] + ) + return None + + def __and__( + self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + return _align_and_apply_binary_op(self, other, operator.and_) + + def __or__( + self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + return _align_and_apply_binary_op(self, other, operator.or_) + + def __sub__( + self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + return _align_and_apply_binary_op(self, other, operator.sub) + + @abstractmethod + def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: + ... + + @override + def __str__(self) -> str: + return str(self._reconstruct_isl_object()) + + +@dataclass(frozen=True, eq=False) +class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): + """ + Represents set-like objects with parameter dimensions as a non-parameterized + set. Names are organized as contiguous chunks of dimension types, i.e. + [ (set names), (input names), (parameter names) ] + """ + _obj: isl.Set + + def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: + return type(self)(self._obj.complement(), + self._name_to_dim, + self._dimtype_to_names) + + def eliminate(self, names_to_eliminate: str | Sequence[str]) -> _NamedIslSetLike: + if isinstance(names_to_eliminate, str): + names_to_eliminate = [names_to_eliminate] + + dims_to_eliminate = sorted( + self._name_to_dim[name] + for name in names_to_eliminate + ) + + contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_eliminate) + + new_isl_obj = self._obj + for start in sorted(contiguous_dim_chunks): + new_isl_obj = new_isl_obj.eliminate( + isl.dim_type.set, start, contiguous_dim_chunks[start] + ) + + return type(self)( + new_isl_obj, + self._name_to_dim, # NOTE: no dimensions are removed by elimination + self._dimtype_to_names + ) + + def project_out(self: _NamedIslSetLike, + names_to_project_out: str | Sequence[str]) -> _NamedIslSetLike: + + if isinstance(names_to_project_out, str): + names_to_project_out = [names_to_project_out] + + names_to_remove = set(names_to_project_out) + + dims_to_remove = sorted( + self._name_to_dim[name] + for name in names_to_remove + ) + + new_isl_obj = self._obj + contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_remove) + for start in sorted(contiguous_dim_chunks, reverse=True): + new_isl_obj = new_isl_obj.project_out( + isl.dim_type.set, start, contiguous_dim_chunks[start] + ) + + new_name_to_dim: NameToDim = {} + for name, dim in self._name_to_dim.items(): + if name in names_to_remove: + continue + + shift = 0 + for removed_dim in dims_to_remove: + if removed_dim < dim: + shift += 1 + else: + break + + new_name_to_dim[name] = dim - shift + + new_type_to_names = constantdict({ + dt: self._dimtype_to_names[dt] - frozenset(names_to_remove) + for dt in self._dimtype_to_names + }) + + return type(self)( + new_isl_obj, + constantdict(new_name_to_dim), + new_type_to_names + ) + + def project_out_except( + self: _NamedIslSetLike, + names_to_keep: str | Sequence[str] + ) -> _NamedIslSetLike: + + if isinstance(names_to_keep, str): + names_to_keep = [names_to_keep] + + names_to_project_out = [ + name for name in self._name_to_dim + if name not in names_to_keep + ] + + return self.project_out(names_to_project_out) + + # {{{ TODO: funtions that return ExpressionLike objects + + def dim_max(self, name: str): + ... + + def dim_min(self, name: str): + ... + + def as_pw_multi_aff(self): + ... + + # }}} + + @final @dataclass(frozen=True, eq=False) class Set(_NamedIslSetLike): @@ -319,6 +483,19 @@ def _reconstruct_isl_object(self) -> isl.Set: return self._obj + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, Set): + raise NotImplementedError + + aligned_self, aligned_other = _align_two(self, other) + + # FIXME: type checker complains because it's not clear whether the + # underlying object after alignment is an isl.Set + assert isinstance(aligned_other._obj, isl.Set) + assert isinstance(aligned_self._obj, isl.Set) + return aligned_self._obj.plain_is_equal(aligned_other._obj) + @overload def make_set(src: str, ctx: isl.Context | None = None) -> Set: @@ -336,6 +513,7 @@ def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: set_obj, dimtype_to_names = _deconstruct_set_like_object(obj) set_obj, name_to_dim = _strip_names(set_obj) + assert isinstance(set_obj, isl.Set) return Set(set_obj, name_to_dim, dimtype_to_names) @@ -353,6 +531,7 @@ def _reconstruct_isl_object(self) -> isl.Map: "of the starting position of input dimensions") obj = _restore_names(self._obj, self._name_to_dim) + assert isinstance(obj, isl.Set) obj_domain = isl.Set("{ [] }") obj_range = obj @@ -380,6 +559,19 @@ def _reconstruct_isl_object(self) -> isl.Map: return map_obj + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, Map): + raise NotImplementedError + + aligned_self, aligned_other = _align_two(self, other) + + # FIXME: type checker complains because it's not clear whether the + # underlying object after alignment is an isl.Set + assert isinstance(aligned_self._obj, isl.Set) + assert isinstance(aligned_other._obj, isl.Set) + return aligned_self._obj.plain_is_equal(aligned_other._obj) + @overload def make_map(src: str, ctx: isl.Context | None = None) -> Map: @@ -397,4 +589,5 @@ def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: set_obj, dimtype_to_names = _deconstruct_set_like_object(obj) set_obj, name_to_dim = _strip_names(set_obj) + assert isinstance(set_obj, isl.Set) return Map(set_obj, name_to_dim, dimtype_to_names) diff --git a/namedisl/test/test_map.py b/namedisl/test/test_map.py new file mode 100644 index 0000000..41665ad --- /dev/null +++ b/namedisl/test/test_map.py @@ -0,0 +1,206 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import pytest + +import islpy as isl + +import namedisl as nisl +from .utils_for_tests import generate_random_named_map + + +def test_map_from_str() -> None: + m = nisl.make_map( + "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + + print(m._obj) + print(m) + + +def test_map_from_map() -> None: + m = isl.Map( + "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + named_map = nisl.make_map(m) + + print(named_map._obj) + print(named_map) + + +@pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) +@pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_map_equality(ndims_domain, ndims_range, has_params): + if has_params: + d_param = "n" + r_param = "m" + else: + d_param = None + r_param = None + + og_map, domain_info, range_info = generate_random_named_map( + ndims_domain, "d", d_param, + ndims_range, "r", r_param) + + _, d_dims, d_cond = domain_info + _, r_dims, r_cond = range_info + + from itertools import permutations + d_perms = list(permutations(d_dims.split(","))) + r_perms = list(permutations(r_dims.split(","))) + + for d_perm, r_perm in zip(d_perms, r_perms, strict=False): + d_perm_dims = ",".join(p for p in d_perm) + r_perm_dims = ",".join(p for p in r_perm) + + domain_str = f"{{ [{d_perm_dims}] : {d_cond} }}" + range_str = f"{{ [{r_perm_dims}] : {r_cond} }}" + + if has_params: + domain_str = f"[{d_param}] ->" + domain_str + range_str = f"[{r_param}] ->" + range_str + + perm_map = nisl.make_map( + isl.Map.from_domain_and_range( + isl.Set(domain_str), isl.Set(range_str) + ) + ) + + assert perm_map == og_map + + +@pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) +@pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_map_union(ndims_domain, ndims_range, has_params): + if has_params: + d_param = "n" + r_param = "m" + else: + d_param = None + r_param = None + + x, x_domain_info, x_range_info = generate_random_named_map( + ndims_domain, "x_in", d_param, + ndims_range, "x_out", r_param + ) + + y, y_domain_info, y_range_info = generate_random_named_map( + ndims_domain, "y_in", d_param, + ndims_domain, "y_out", r_param + ) + + _, x_in_dims, x_in_cond = x_domain_info + _, x_out_dims, x_out_cond = x_range_info + + _, y_in_dims, y_in_cond = y_domain_info + _, y_out_dims, y_out_cond = y_range_info + + result_dims = f"[{x_in_dims}, {y_in_dims}] -> [{x_out_dims}, {y_out_dims}]" + result_conds = f"({x_in_cond} and {x_out_cond}) or ({y_in_cond} and {y_out_cond})" + + result_str = "{" + result_dims + " : " + result_conds + "}" + + if has_params: + result_str = f"[{d_param}, {r_param}] ->" + result_str + + result_map = nisl.make_map(result_str) + + assert (x | y) == result_map + + +@pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) +@pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_map_intersection(ndims_domain, ndims_range, has_params): + if has_params: + d_param = "n" + r_param = "m" + else: + d_param = None + r_param = None + + x, x_domain_info, x_range_info = generate_random_named_map( + ndims_domain, "x_in", d_param, + ndims_range, "x_out", r_param + ) + + y, y_domain_info, y_range_info = generate_random_named_map( + ndims_domain, "y_in", d_param, + ndims_domain, "y_out", r_param + ) + + _, x_in_dims, x_in_cond = x_domain_info + _, x_out_dims, x_out_cond = x_range_info + + _, y_in_dims, y_in_cond = y_domain_info + _, y_out_dims, y_out_cond = y_range_info + + result_dims = f"[{x_in_dims}, {y_in_dims}] -> [{x_out_dims}, {y_out_dims}]" + result_conds = f"({x_in_cond} and {x_out_cond}) and ({y_in_cond} and {y_out_cond})" + + result_str = "{" + result_dims + " : " + result_conds + "}" + + if has_params: + result_str = f"[{d_param}, {r_param}] ->" + result_str + + result_map = nisl.make_map(result_str) + + assert (x & y) == result_map + + +@pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) +@pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) +def test_map_eliminate(ndims_domain, ndims_range): + x, x_domain_info, x_range_info = generate_random_named_map( + ndims_domain, "x_in", None, + ndims_range, "x_out", None + ) + + _, x_in_dims, _ = x_domain_info + _, x_out_dims, _ = x_range_info + + dims_to_remove = (x_in_dims + "," + x_out_dims).split(",") + x = x.eliminate(dims_to_remove) + + assert x == nisl.make_map(f"{{[{x_in_dims}] -> [{x_out_dims}]}}") + + +@pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) +@pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) +def test_map_project_out(ndims_domain, ndims_range): + x, x_domain_info, x_range_info = generate_random_named_map( + ndims_domain, "x_in", None, + ndims_range, "x_out", None + ) + + _, x_in_dims, _ = x_domain_info + _, x_out_dims, _ = x_range_info + + dims_to_remove = (x_in_dims + "," + x_out_dims).split(",") + x = x.project_out(dims_to_remove) + + assert x == nisl.make_map("{[] -> []}") diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py deleted file mode 100644 index 086fd9c..0000000 --- a/namedisl/test/test_namedisl.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - - -__copyright__ = """ -Copyright (C) 2025- University of Illinois Board of Trustees -""" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -import islpy as isl - -import namedisl as nisl - - -def test_set_from_str() -> None: - spec = "[n] -> { [i] : 0 <= i < n }" - s_isl = isl.Set(spec) - s = nisl.make_set(spec) - print(s) - - assert s._reconstruct_isl_object() == s_isl - - -def test_set_from_set() -> None: - s_isl = isl.Set("[n] -> { [i] : 0 <= i < n }") - s = nisl.make_set(s_isl) - print(s) - - assert s._reconstruct_isl_object() == s_isl - - -def test_map_from_str() -> None: - spec = "[n] -> { [i] -> [j] : 0 <= i < n and j = 2 * i }" - m = nisl.make_map(spec) - m_isl = isl.Map(spec) - print(m) - - assert m._reconstruct_isl_object() == m_isl - - -def test_map_from_map() -> None: - m_isl = isl.Map("[n] -> { [i] -> [j] : 0 <= i < n and j = 2 * i }") - m = nisl.make_map(m_isl) - print(m) - - assert m._reconstruct_isl_object() == m_isl - - -def test_align_two() -> None: - m1 = nisl.make_map( - "{ [l, m, n, o] -> [i, j, k] : 0 <= i, j, k, l, m, n, o < 10 }") - m2 = nisl.make_map( - "{ [a, b, c] -> [x, y, z] : 0 <= a, b, c, x, y, z < 5 }") - - print(m1) - print(m2) - - m1, m2 = nisl._align_two(m1, m2) - - print(m1) - print(m2) diff --git a/namedisl/test/test_set.py b/namedisl/test/test_set.py new file mode 100644 index 0000000..9e4f2ae --- /dev/null +++ b/namedisl/test/test_set.py @@ -0,0 +1,137 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import pytest + +import islpy as isl + +import namedisl as nisl +from .utils_for_tests import generate_random_named_set + + +def test_set_from_str() -> None: + s = nisl.make_set("[n] -> { [i]: 0 <= i < n }") + + print(s._obj) + print(s) + + +def test_set_from_set() -> None: + s = isl.Set("[n] -> { [i, j] : 0 <= i, j < n }") + named_set = nisl.make_set(s) + + print(named_set._obj) + print(named_set) + + +@pytest.mark.parametrize("ndims", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_equality(ndims, has_params): + a_param = "n" if has_params else None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + + from itertools import permutations + for perm in list(permutations(a_dims.split(","))): + perm_dims = ",".join(p for p in perm) + set_str = f"{{ [{perm_dims}] : {a_cond} }}" + if has_params: + set_str = f"[{a_param}] ->" + set_str + perm_set = nisl.make_set(set_str) + + assert a == perm_set + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_union(ndims, has_params): + + if has_params: + a_param = "n" + b_param = "m" + else: + a_param = None + b_param = None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + b, b_dims, b_cond = generate_random_named_set(ndims, "b", b_param) + + set_str = f"{{ [{a_dims}, {b_dims}] : ({a_cond}) or ({b_cond})}}" + if has_params: + set_str = "[n, m] -> " + set_str + + result = nisl.make_set(set_str) + + assert (a | b) == result + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_intersection(ndims, has_params): + + if has_params: + a_param = "n" + b_param = "m" + else: + a_param = None + b_param = None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + b, b_dims, b_cond = generate_random_named_set(ndims, "b", b_param) + + set_str = f"{{ [{a_dims}, {b_dims}] : ({a_cond}) and ({b_cond})}}" + if has_params: + set_str = "[n, m] -> " + set_str + + result = nisl.make_set(set_str) + + assert (a & b) == result + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +def test_set_eliminate(ndims): + a, a_dims, _ = generate_random_named_set(ndims, "a", None) + a = a.eliminate(a_dims.split(",")) + + assert a == nisl.make_set(f"{{[{a_dims}]}}") + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_project_out(ndims): + a, a_dims, _ = generate_random_named_set(ndims, "a", None) + a = a.project_out(a_dims.split(",")) + + assert a == nisl.make_set("{[]}") + + +if __name__ == "__main__": + import sys + if len(sys.argv) > 1: + exec(sys.argv[0]) + else: + from pytest import main + main([__file__]) diff --git a/namedisl/test/utils_for_tests.py b/namedisl/test/utils_for_tests.py new file mode 100644 index 0000000..f2f42ee --- /dev/null +++ b/namedisl/test/utils_for_tests.py @@ -0,0 +1,81 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from random import randint + +import islpy as isl + +import namedisl as nisl + + +NamedSetReturnT = tuple[nisl.Set, str, str] +NamedMapReturnT = tuple[nisl.Map, NamedSetReturnT, NamedSetReturnT] + + +def generate_random_named_set( + ndims: int, + dim_prefix: str, + param: str | None + ) -> NamedSetReturnT: + dims = [f"{dim_prefix}_{i}" for i in range(ndims)] + dim_str = ",".join(d for d in dims) + + if param is not None: + conditions = f"0 <= {dim_str} < {param}" + set_str = f"[{param}] -> {{ [{dim_str}] : {conditions} }}" + else: + upper_bounds = [randint(1, 100) for _ in range(ndims)] + lower_bounds = [ + randint(0, upper_bound - 1) for upper_bound in upper_bounds] + + conditions = " and ".join( + f"{lower_bound} <= {d} < {upper_bound}" + for d, lower_bound, upper_bound in zip( + dims, lower_bounds, upper_bounds, strict=True) + ) + set_str = f"{{ [{dim_str}] : {conditions} }}" + + return nisl.make_set(set_str), dim_str, conditions + + +def generate_random_named_map( + ndims_domain: int, + domain_prefix: str, + domain_param: str | None, + ndims_range: int, + range_prefix: str, + range_param: str | None + ) -> NamedMapReturnT: + + d = generate_random_named_set(ndims_domain, domain_prefix, domain_param) + r = generate_random_named_set(ndims_range, range_prefix, range_param) + + return ( + nisl.make_map(isl.Map.from_domain_and_range(d[0]._obj, r[0]._obj)), + d, + r + ) From 40526fb35f773964597f454dddebf5e061b86c82 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 19 Jan 2026 22:28:49 -0600 Subject: [PATCH 11/43] fix test utils bug; get tests for setlikes passing --- namedisl/__init__.py | 29 ++++++++++++++--------------- namedisl/test/test_map.py | 10 +++++----- namedisl/test/test_set.py | 10 +++++----- namedisl/test/utils_for_tests.py | 5 ++++- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 516ab23..926e582 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -156,6 +156,8 @@ def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: start = curr count = 1 + chunks[start] = count + return constantdict(chunks) @@ -211,40 +213,37 @@ def _align_obj( new_isl_obj = named_obj._obj running_name_to_dim = dict(named_obj._name_to_dim) - for name, dim in sorted(ordering.items(), key=lambda x: x[1]): + for name, target_dim in sorted(ordering.items(), key=lambda x: x[1]): if name in running_name_to_dim: old_dim = running_name_to_dim[name] - if old_dim == dim: + if old_dim == target_dim: continue # temporarily move to parameter dimension since destination and # source dim types cannot match in ISL new_isl_obj = new_isl_obj.move_dims( isl.dim_type.param, 0, - isl.dim_type.set, dim, 1 + isl.dim_type.set, old_dim, 1 ) new_isl_obj = new_isl_obj.move_dims( - isl.dim_type.set, dim, + isl.dim_type.set, target_dim, isl.dim_type.param, 0, 1 ) else: old_dim = new_isl_obj.dim(isl.dim_type.set) - new_isl_obj = new_isl_obj.insert_dims(isl.dim_type.set, dim, 1) + new_isl_obj = new_isl_obj.insert_dims(isl.dim_type.set, target_dim, 1) # track side effects of inserting/swapping dimensions - temp_name_to_dim = running_name_to_dim.copy() - for cur_name, cur_dim in sorted( - running_name_to_dim.items(), key=lambda x: x[1]): - if (dim > old_dim) and (cur_dim > old_dim): - temp_name_to_dim[cur_name] = cur_dim - 1 - elif (dim < old_dim) and (cur_dim < old_dim): - temp_name_to_dim[cur_name] = cur_dim + 1 + for n, d in list(running_name_to_dim.items()): + if (target_dim > old_dim) and (d > old_dim): + running_name_to_dim[n] = d - 1 + elif (target_dim < old_dim) and (d < old_dim): + running_name_to_dim[n] = d + 1 - running_name_to_dim = temp_name_to_dim - running_name_to_dim[name] = dim + running_name_to_dim[name] = target_dim return type(named_obj)(new_isl_obj, ordering, dimtype_to_names) @@ -270,7 +269,7 @@ def _align_and_apply_binary_op( ) -> NamedIslObject[IslObjectT]: lhs, rhs = _align_two(lhs, rhs) - result = op(lhs._obj, lhs._obj) + result = op(lhs._obj, rhs._obj) # NOTE: since lhs and rhs were aligned, they both agree on what name-to-dim # and dimtype-to-name is, can just take information from lhs diff --git a/namedisl/test/test_map.py b/namedisl/test/test_map.py index 41665ad..b8c6d01 100644 --- a/namedisl/test/test_map.py +++ b/namedisl/test/test_map.py @@ -53,7 +53,7 @@ def test_map_from_map() -> None: @pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) @pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) @pytest.mark.parametrize("has_params", [True, False]) -def test_map_equality(ndims_domain, ndims_range, has_params): +def test_map_equality(ndims_domain: int, ndims_range: int, has_params: bool): if has_params: d_param = "n" r_param = "m" @@ -95,7 +95,7 @@ def test_map_equality(ndims_domain, ndims_range, has_params): @pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) @pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) @pytest.mark.parametrize("has_params", [True, False]) -def test_map_union(ndims_domain, ndims_range, has_params): +def test_map_union(ndims_domain: int, ndims_range: int, has_params: bool): if has_params: d_param = "n" r_param = "m" @@ -135,7 +135,7 @@ def test_map_union(ndims_domain, ndims_range, has_params): @pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) @pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) @pytest.mark.parametrize("has_params", [True, False]) -def test_map_intersection(ndims_domain, ndims_range, has_params): +def test_map_intersection(ndims_domain: int, ndims_range: int, has_params: bool): if has_params: d_param = "n" r_param = "m" @@ -174,7 +174,7 @@ def test_map_intersection(ndims_domain, ndims_range, has_params): @pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) @pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) -def test_map_eliminate(ndims_domain, ndims_range): +def test_map_eliminate(ndims_domain: int, ndims_range: int): x, x_domain_info, x_range_info = generate_random_named_map( ndims_domain, "x_in", None, ndims_range, "x_out", None @@ -191,7 +191,7 @@ def test_map_eliminate(ndims_domain, ndims_range): @pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) @pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) -def test_map_project_out(ndims_domain, ndims_range): +def test_map_project_out(ndims_domain: int, ndims_range: int): x, x_domain_info, x_range_info = generate_random_named_map( ndims_domain, "x_in", None, ndims_range, "x_out", None diff --git a/namedisl/test/test_set.py b/namedisl/test/test_set.py index 9e4f2ae..7697195 100644 --- a/namedisl/test/test_set.py +++ b/namedisl/test/test_set.py @@ -50,7 +50,7 @@ def test_set_from_set() -> None: @pytest.mark.parametrize("ndims", [2, 3, 4, 5]) @pytest.mark.parametrize("has_params", [True, False]) -def test_set_equality(ndims, has_params): +def test_set_equality(ndims: int, has_params: bool): a_param = "n" if has_params else None a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) @@ -68,7 +68,7 @@ def test_set_equality(ndims, has_params): @pytest.mark.parametrize("ndims", [1, 2, 4, 8]) @pytest.mark.parametrize("has_params", [True, False]) -def test_set_union(ndims, has_params): +def test_set_union(ndims: int, has_params: bool): if has_params: a_param = "n" @@ -91,7 +91,7 @@ def test_set_union(ndims, has_params): @pytest.mark.parametrize("ndims", [1, 2, 4, 8]) @pytest.mark.parametrize("has_params", [True, False]) -def test_set_intersection(ndims, has_params): +def test_set_intersection(ndims: int, has_params: bool): if has_params: a_param = "n" @@ -113,7 +113,7 @@ def test_set_intersection(ndims, has_params): @pytest.mark.parametrize("ndims", [1, 2, 4, 8]) -def test_set_eliminate(ndims): +def test_set_eliminate(ndims: int): a, a_dims, _ = generate_random_named_set(ndims, "a", None) a = a.eliminate(a_dims.split(",")) @@ -121,7 +121,7 @@ def test_set_eliminate(ndims): @pytest.mark.parametrize("ndims", [2, 4, 8]) -def test_set_project_out(ndims): +def test_set_project_out(ndims: int): a, a_dims, _ = generate_random_named_set(ndims, "a", None) a = a.project_out(a_dims.split(",")) diff --git a/namedisl/test/utils_for_tests.py b/namedisl/test/utils_for_tests.py index f2f42ee..a615e72 100644 --- a/namedisl/test/utils_for_tests.py +++ b/namedisl/test/utils_for_tests.py @@ -74,8 +74,11 @@ def generate_random_named_map( d = generate_random_named_set(ndims_domain, domain_prefix, domain_param) r = generate_random_named_set(ndims_range, range_prefix, range_param) + dom = d[0]._reconstruct_isl_object() + ran = r[0]._reconstruct_isl_object() + return ( - nisl.make_map(isl.Map.from_domain_and_range(d[0]._obj, r[0]._obj)), + nisl.make_map(isl.Map.from_domain_and_range(dom, ran)), d, r ) From 8788935001ae8949020643afc4ebd5328855d4fc Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 20 Jan 2026 12:55:57 -0600 Subject: [PATCH 12/43] typing changes to be consistent with other packages + minor refactoring --- namedisl/__init__.py | 64 ++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 926e582..1c41301 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -54,9 +54,9 @@ IslSetLike = isl.Set | isl.Map IslExpressionLike = isl.Aff | isl.QPolynomial -IslExpressionLikeT = TypeVar("IslExpressionLikeT", isl.Aff, isl.QPolynomial) -IslSetLikeT = TypeVar("IslSetLikeT", isl.Set, isl.Map) -IslObjectT = TypeVar("IslObjectT", IslSetLike, IslExpressionLike) +IslExpressionLikeTc = TypeVar("IslExpressionLikeT", isl.Aff, isl.QPolynomial) +IslSetLikeTc = TypeVar("IslSetLikeT", isl.Set, isl.Map) +IslObjectTc = TypeVar("IslObjectT", IslSetLike, IslExpressionLike) NameToDim: TypeAlias = Mapping[str, int] @@ -68,7 +68,7 @@ SetLikePieces: TypeAlias = tuple[isl.Set, DimTypeToNames] -def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: +def _strip_names(obj: IslObjectTc) -> tuple[IslObjectTc, NameToDim]: name_to_dim: dict[str, int] = {} for i in range(obj.dim(isl.dim_type.set)): @@ -88,13 +88,13 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: return obj, constantdict(name_to_dim) -def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: +def _restore_names(obj: IslObjectTc, name_to_dim: NameToDim) -> IslObjectTc: for name, dim in name_to_dim.items(): obj = obj.set_dim_name(isl.dim_type.set, dim, name) return obj -def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: +def _get_dim_names(obj: IslObjectTc, dt: isl.dim_type) -> frozenset[str]: all_dt_names: list[str] = [] for dim in range(obj.dim(dt)): @@ -111,7 +111,7 @@ def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: return frozenset(all_dt_names) -def _deconstruct_set_like_object(obj: IslSetLikeT) -> SetLikePieces: +def _deconstruct_set_like_object(obj: IslSetLikeTc) -> SetLikePieces: from islpy import dim_type dt_to_names: dict[dim_type, frozenset[str]] = dict.fromkeys( @@ -164,8 +164,8 @@ def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: # FIXME: think through whether or not alphabetical ordering will require more # work on average than using one of the objects as a template in alignment def _find_joint_name_to_dim( - obj1: NamedIslObject[IslObjectT], - obj2: NamedIslObject[IslObjectT] + obj1: NamedIslObject[IslObjectTc], + obj2: NamedIslObject[IslObjectTc] ) -> tuple[NameToDim, DimTypeToNames]: """ Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and @@ -194,22 +194,24 @@ def _find_joint_name_to_dim( # enforces contiguous ordering of [ (set), (input), (param) ] in set # representation - all_names = sorted(all_set_names) - all_names += sorted(all_inp_names) - all_names += sorted(all_param_names) + all_names = [ + *sorted(all_set_names), + *sorted(all_inp_names), + *sorted(all_param_names), + ] - name_to_dim: NameToDim = {} - for pos, name in enumerate(all_names): - name_to_dim[name] = pos + name_to_dim: NameToDim = constantdict({ + name: pos for pos, name in enumerate(all_names) + }) - return constantdict(name_to_dim), constantdict(dt_to_names) + return name_to_dim, constantdict(dt_to_names) def _align_obj( - named_obj: NamedIslObject[IslObjectT], + named_obj: NamedIslObject[IslObjectTc], ordering: NameToDim, dimtype_to_names: DimTypeToNames - ) -> NamedIslObject[IslObjectT]: + ) -> NamedIslObject[IslObjectTc]: new_isl_obj = named_obj._obj running_name_to_dim = dict(named_obj._name_to_dim) @@ -249,9 +251,9 @@ def _align_obj( def _align_two( - named_obj1: NamedIslObject[IslObjectT], - named_obj2: NamedIslObject[IslObjectT] - ) -> tuple[NamedIslObject[IslObjectT], ...]: + named_obj1: NamedIslObject[IslObjectTc], + named_obj2: NamedIslObject[IslObjectTc] + ) -> tuple[NamedIslObject[IslObjectTc], ...]: name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, named_obj2) @@ -263,10 +265,10 @@ def _align_two( def _align_and_apply_binary_op( - lhs: NamedIslObject[IslObjectT], - rhs: NamedIslObject[IslObjectT], - op: Callable[[IslObjectT, IslObjectT], IslObjectT] - ) -> NamedIslObject[IslObjectT]: + lhs: NamedIslObject[IslObjectTc], + rhs: NamedIslObject[IslObjectTc], + op: Callable[[IslObjectTc, IslObjectTc], IslObjectTc] + ) -> NamedIslObject[IslObjectTc]: lhs, rhs = _align_two(lhs, rhs) result = op(lhs._obj, rhs._obj) @@ -277,8 +279,8 @@ def _align_and_apply_binary_op( @dataclass(frozen=True, eq=False) -class NamedIslObject(Generic[IslObjectT], ABC): - _obj: IslObjectT +class NamedIslObject(Generic[IslObjectTc], ABC): + _obj: IslObjectTc _name_to_dim: NameToDim # used to reconstruct ISL object @@ -331,15 +333,15 @@ def _parameter_dim_start(self) -> int | None: return None def __and__( - self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + self, other: NamedIslObject[IslObjectTc]) -> NamedIslObject[IslObjectTc]: return _align_and_apply_binary_op(self, other, operator.and_) def __or__( - self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + self, other: NamedIslObject[IslObjectTc]) -> NamedIslObject[IslObjectTc]: return _align_and_apply_binary_op(self, other, operator.or_) def __sub__( - self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + self, other: NamedIslObject[IslObjectTc]) -> NamedIslObject[IslObjectTc]: return _align_and_apply_binary_op(self, other, operator.sub) @abstractmethod @@ -358,8 +360,6 @@ class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): set. Names are organized as contiguous chunks of dimension types, i.e. [ (set names), (input names), (parameter names) ] """ - _obj: isl.Set - def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: return type(self)(self._obj.complement(), self._name_to_dim, From 981a176b7f25b3765dac175e0ef536e44fb9bb68 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 20 Jan 2026 14:31:37 -0600 Subject: [PATCH 13/43] major refactor of __init__ into different files to improve structure; fix github basedpyright CI --- .github/workflows/ci.yml | 1 + namedisl/__init__.py | 565 +------------------------------------- namedisl/core.py | 356 ++++++++++++++++++++++++ namedisl/set_like.py | 298 ++++++++++++++++++++ namedisl/test/test_set.py | 9 - 5 files changed, 657 insertions(+), 572 deletions(-) create mode 100644 namedisl/core.py create mode 100644 namedisl/set_like.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 27f0008..eea93bd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,6 +70,7 @@ jobs: python-version: '3.x' - name: "Main Script" run: | + EXTRA_INSTALL="pytest" curl -L -O https://tiker.net/ci-support-v0 . ci-support-v0 build_py_project_in_venv diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 1c41301..0795cea 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -1,10 +1,3 @@ -""" -.. autoclass:: BasicSet - -.. autofunction:: make_basic_set -""" - - from __future__ import annotations @@ -31,562 +24,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import operator -import re -from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping, Sequence -from dataclasses import dataclass -from importlib import metadata -from typing import Generic, TypeAlias, TypeVar, final, overload - -from constantdict import constantdict -from typing_extensions import override - -import islpy as isl - - -__version__ = metadata.version("namedisl") -_match = re.match(r"^([0-9.]+)([a-z0-9]*?)$", __version__) -assert _match -VERSION = tuple(int(nr) for nr in _match.group(1).split(".")) - - -IslSetLike = isl.Set | isl.Map -IslExpressionLike = isl.Aff | isl.QPolynomial - -IslExpressionLikeTc = TypeVar("IslExpressionLikeT", isl.Aff, isl.QPolynomial) -IslSetLikeTc = TypeVar("IslSetLikeT", isl.Set, isl.Map) -IslObjectTc = TypeVar("IslObjectT", IslSetLike, IslExpressionLike) - -NameToDim: TypeAlias = Mapping[str, int] - -# NOTE: without tracking what dimension type a particular name belongs to, it is -# not possible to reconstruct the ISL object after dimension operations, e.g. -# alignment -DimTypeToNames: TypeAlias = Mapping[isl.dim_type, frozenset[str]] - -SetLikePieces: TypeAlias = tuple[isl.Set, DimTypeToNames] - - -def _strip_names(obj: IslObjectTc) -> tuple[IslObjectTc, NameToDim]: - name_to_dim: dict[str, int] = {} - for i in range(obj.dim(isl.dim_type.set)): - - if isinstance(obj, isl.QPolynomial): - name = obj.space.get_dim_name(isl.dim_type.set, i) - else: - name = obj.get_dim_name(isl.dim_type.set, i) - - if name is None: - raise ValueError("unnamed dimension found") - - if name in name_to_dim: - raise ValueError(f"non-unique dim name: {name}") - - name_to_dim[name] = i - - return obj, constantdict(name_to_dim) - - -def _restore_names(obj: IslObjectTc, name_to_dim: NameToDim) -> IslObjectTc: - for name, dim in name_to_dim.items(): - obj = obj.set_dim_name(isl.dim_type.set, dim, name) - return obj - - -def _get_dim_names(obj: IslObjectTc, dt: isl.dim_type) -> frozenset[str]: - all_dt_names: list[str] = [] - for dim in range(obj.dim(dt)): - - if isinstance(obj, isl.QPolynomial): - name = obj.space.get_dim_name(dt, dim) - else: - name = obj.get_dim_name(dt, dim) - - if name is None: - raise ValueError("unnamed dimension found") - - all_dt_names.append(name) - - return frozenset(all_dt_names) - - -def _deconstruct_set_like_object(obj: IslSetLikeTc) -> SetLikePieces: - from islpy import dim_type - - dt_to_names: dict[dim_type, frozenset[str]] = dict.fromkeys( - [isl.dim_type.in_, isl.dim_type.param], frozenset() - ) - for dt in dt_to_names: - dt_to_names[dt] = _get_dim_names(obj, dt) - if dt_to_names[dt]: - obj = obj.move_dims( - dim_type.set, - obj.dim(dim_type.set), - dt, - 0, - obj.dim(dt) - ) - - set_obj = obj.range() if isinstance(obj, isl.Map) else obj - - return set_obj, constantdict(dt_to_names) - - -def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: - """ - Determines contiguous chunks of dimensions within a sequence of dimensions. - Returns a mapping of the first dimension in the chunk to the length of the - chunk. - """ - if not dims: - return {} - - chunks: dict[int, int] = {} - - start = dims[0] - count = 1 - - from itertools import pairwise - for prev, curr in pairwise(dims): - if curr == prev + 1: - count += 1 - else: - chunks[start] = count - start = curr - count = 1 - - chunks[start] = count - - return constantdict(chunks) - - -# FIXME: think through whether or not alphabetical ordering will require more -# work on average than using one of the objects as a template in alignment -def _find_joint_name_to_dim( - obj1: NamedIslObject[IslObjectTc], - obj2: NamedIslObject[IslObjectTc] - ) -> tuple[NameToDim, DimTypeToNames]: - """ - Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and - :arg:`obj2` within each dimension-type chunk. This ordering is used in - alignment before performing operations between two set-like objects. - """ - obj1_inp_names = obj1._input_names - obj1_param_names = obj1._parameter_names - obj1_set_names = ( - frozenset(obj1._name_to_dim.keys()) - (obj1_inp_names | obj1_param_names) - ) - - obj2_inp_names = obj2._input_names - obj2_param_names = obj2._parameter_names - obj2_set_names = ( - frozenset(obj2._name_to_dim.keys()) - (obj2_inp_names | obj2_param_names) - ) - - all_inp_names = obj1_inp_names | obj2_inp_names - all_param_names = obj1_param_names | obj2_param_names - all_set_names = obj1_set_names | obj2_set_names - - dt_to_names: DimTypeToNames = {} - dt_to_names[isl.dim_type.param] = all_param_names - dt_to_names[isl.dim_type.in_] = all_inp_names - - # enforces contiguous ordering of [ (set), (input), (param) ] in set - # representation - all_names = [ - *sorted(all_set_names), - *sorted(all_inp_names), - *sorted(all_param_names), - ] - - name_to_dim: NameToDim = constantdict({ - name: pos for pos, name in enumerate(all_names) - }) - - return name_to_dim, constantdict(dt_to_names) - - -def _align_obj( - named_obj: NamedIslObject[IslObjectTc], - ordering: NameToDim, - dimtype_to_names: DimTypeToNames - ) -> NamedIslObject[IslObjectTc]: - new_isl_obj = named_obj._obj - running_name_to_dim = dict(named_obj._name_to_dim) - - for name, target_dim in sorted(ordering.items(), key=lambda x: x[1]): - if name in running_name_to_dim: - old_dim = running_name_to_dim[name] - - if old_dim == target_dim: - continue - - # temporarily move to parameter dimension since destination and - # source dim types cannot match in ISL - new_isl_obj = new_isl_obj.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, old_dim, 1 - ) - - new_isl_obj = new_isl_obj.move_dims( - isl.dim_type.set, target_dim, - isl.dim_type.param, 0, 1 - ) - - else: - old_dim = new_isl_obj.dim(isl.dim_type.set) - new_isl_obj = new_isl_obj.insert_dims(isl.dim_type.set, target_dim, 1) - - # track side effects of inserting/swapping dimensions - for n, d in list(running_name_to_dim.items()): - if (target_dim > old_dim) and (d > old_dim): - running_name_to_dim[n] = d - 1 - elif (target_dim < old_dim) and (d < old_dim): - running_name_to_dim[n] = d + 1 - - running_name_to_dim[name] = target_dim - - return type(named_obj)(new_isl_obj, ordering, dimtype_to_names) - - -def _align_two( - named_obj1: NamedIslObject[IslObjectTc], - named_obj2: NamedIslObject[IslObjectTc] - ) -> tuple[NamedIslObject[IslObjectTc], ...]: - - name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, - named_obj2) - - named_obj1 = _align_obj(named_obj1, name_to_dim, dimtype_to_names) - named_obj2 = _align_obj(named_obj2, name_to_dim, dimtype_to_names) - - return named_obj1, named_obj2 - - -def _align_and_apply_binary_op( - lhs: NamedIslObject[IslObjectTc], - rhs: NamedIslObject[IslObjectTc], - op: Callable[[IslObjectTc, IslObjectTc], IslObjectTc] - ) -> NamedIslObject[IslObjectTc]: - - lhs, rhs = _align_two(lhs, rhs) - result = op(lhs._obj, rhs._obj) - - # NOTE: since lhs and rhs were aligned, they both agree on what name-to-dim - # and dimtype-to-name is, can just take information from lhs - return type(lhs)(result, lhs._name_to_dim, lhs._dimtype_to_names) - - -@dataclass(frozen=True, eq=False) -class NamedIslObject(Generic[IslObjectTc], ABC): - _obj: IslObjectTc - _name_to_dim: NameToDim - - # used to reconstruct ISL object - _dimtype_to_names: DimTypeToNames - - @property - def _has_inputs(self) -> bool: - return ( - isl.dim_type.in_ in self._dimtype_to_names - and - len(self._dimtype_to_names[isl.dim_type.in_]) > 0 - ) - - @property - def _input_names(self) -> frozenset[str]: - if self._has_inputs: - return self._dimtype_to_names[isl.dim_type.in_] - return frozenset() - - @property - def _input_dim_start(self) -> int | None: - if self._has_inputs: - return min( - self._name_to_dim[name] - for name in self._dimtype_to_names[isl.dim_type.in_] - ) - return None - - @property - def _has_params(self) -> bool: - return ( - isl.dim_type.param in self._dimtype_to_names - and - len(self._dimtype_to_names[isl.dim_type.param]) > 0 - ) - - @property - def _parameter_names(self) -> frozenset[str]: - if self._has_params: - return self._dimtype_to_names[isl.dim_type.param] - return frozenset() - - @property - def _parameter_dim_start(self) -> int | None: - if self._has_params: - return min( - self._name_to_dim[name] - for name in self._dimtype_to_names[isl.dim_type.param] - ) - return None - - def __and__( - self, other: NamedIslObject[IslObjectTc]) -> NamedIslObject[IslObjectTc]: - return _align_and_apply_binary_op(self, other, operator.and_) - - def __or__( - self, other: NamedIslObject[IslObjectTc]) -> NamedIslObject[IslObjectTc]: - return _align_and_apply_binary_op(self, other, operator.or_) - - def __sub__( - self, other: NamedIslObject[IslObjectTc]) -> NamedIslObject[IslObjectTc]: - return _align_and_apply_binary_op(self, other, operator.sub) - - @abstractmethod - def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: - ... - - @override - def __str__(self) -> str: - return str(self._reconstruct_isl_object()) - - -@dataclass(frozen=True, eq=False) -class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): - """ - Represents set-like objects with parameter dimensions as a non-parameterized - set. Names are organized as contiguous chunks of dimension types, i.e. - [ (set names), (input names), (parameter names) ] - """ - def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: - return type(self)(self._obj.complement(), - self._name_to_dim, - self._dimtype_to_names) - - def eliminate(self, names_to_eliminate: str | Sequence[str]) -> _NamedIslSetLike: - if isinstance(names_to_eliminate, str): - names_to_eliminate = [names_to_eliminate] - - dims_to_eliminate = sorted( - self._name_to_dim[name] - for name in names_to_eliminate - ) - - contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_eliminate) - - new_isl_obj = self._obj - for start in sorted(contiguous_dim_chunks): - new_isl_obj = new_isl_obj.eliminate( - isl.dim_type.set, start, contiguous_dim_chunks[start] - ) - - return type(self)( - new_isl_obj, - self._name_to_dim, # NOTE: no dimensions are removed by elimination - self._dimtype_to_names - ) - - def project_out(self: _NamedIslSetLike, - names_to_project_out: str | Sequence[str]) -> _NamedIslSetLike: - - if isinstance(names_to_project_out, str): - names_to_project_out = [names_to_project_out] - - names_to_remove = set(names_to_project_out) - - dims_to_remove = sorted( - self._name_to_dim[name] - for name in names_to_remove - ) - - new_isl_obj = self._obj - contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_remove) - for start in sorted(contiguous_dim_chunks, reverse=True): - new_isl_obj = new_isl_obj.project_out( - isl.dim_type.set, start, contiguous_dim_chunks[start] - ) - - new_name_to_dim: NameToDim = {} - for name, dim in self._name_to_dim.items(): - if name in names_to_remove: - continue - - shift = 0 - for removed_dim in dims_to_remove: - if removed_dim < dim: - shift += 1 - else: - break - - new_name_to_dim[name] = dim - shift - - new_type_to_names = constantdict({ - dt: self._dimtype_to_names[dt] - frozenset(names_to_remove) - for dt in self._dimtype_to_names - }) - - return type(self)( - new_isl_obj, - constantdict(new_name_to_dim), - new_type_to_names - ) - - def project_out_except( - self: _NamedIslSetLike, - names_to_keep: str | Sequence[str] - ) -> _NamedIslSetLike: - - if isinstance(names_to_keep, str): - names_to_keep = [names_to_keep] - - names_to_project_out = [ - name for name in self._name_to_dim - if name not in names_to_keep - ] - - return self.project_out(names_to_project_out) - - # {{{ TODO: funtions that return ExpressionLike objects - - def dim_max(self, name: str): - ... - - def dim_min(self, name: str): - ... - - def as_pw_multi_aff(self): - ... - - # }}} - - -@final -@dataclass(frozen=True, eq=False) -class Set(_NamedIslSetLike): - @override - def _reconstruct_isl_object(self) -> isl.Set: - if self._has_params: - if self._parameter_dim_start is None: - raise ValueError( - "Object has parameter dimensions, but a starting index for " - "parameter names is not given. Reconstruction is not " - "possible") - - return self._obj.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, self._parameter_dim_start, - len(self._parameter_names) - ) - - return self._obj - - @override - def __eq__(self, other: object) -> bool: - if not isinstance(other, Set): - raise NotImplementedError - - aligned_self, aligned_other = _align_two(self, other) - - # FIXME: type checker complains because it's not clear whether the - # underlying object after alignment is an isl.Set - assert isinstance(aligned_other._obj, isl.Set) - assert isinstance(aligned_self._obj, isl.Set) - return aligned_self._obj.plain_is_equal(aligned_other._obj) - - -@overload -def make_set(src: str, ctx: isl.Context | None = None) -> Set: - ... - - -@overload -def make_set(src: isl.Set) -> Set: - ... - - -def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: - obj = isl.Set(src, ctx) if isinstance(src, str) else src - - set_obj, dimtype_to_names = _deconstruct_set_like_object(obj) - set_obj, name_to_dim = _strip_names(set_obj) - - assert isinstance(set_obj, isl.Set) - return Set(set_obj, name_to_dim, dimtype_to_names) - - -@final -@dataclass(frozen=True, eq=False) -class Map(_NamedIslSetLike): - @override - def _reconstruct_isl_object(self) -> isl.Map: - """ - Relies on the dimension type ordering in - :func:`_deconstruct_set_like_object`. - """ - if self._input_dim_start is None: - raise ValueError("Cannot reconstruct a map object without knowledge " - "of the starting position of input dimensions") - - obj = _restore_names(self._obj, self._name_to_dim) - assert isinstance(obj, isl.Set) - - obj_domain = isl.Set("{ [] }") - obj_range = obj - - map_obj = isl.Map.from_domain_and_range(obj_domain, obj_range) - - if self._has_params: - if self._parameter_dim_start is None: - raise ValueError( - "Object has parameter dimensions, but a starting index for " - "parameter names is not given. Reconstruction is not " - "possible") - - param_start = self._parameter_dim_start - map_obj = map_obj.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, param_start, len(self._parameter_names) - ) - - inp_start = self._input_dim_start - map_obj = map_obj.move_dims( - isl.dim_type.in_, 0, - isl.dim_type.set, inp_start, len(self._input_names) - ) - - return map_obj - - @override - def __eq__(self, other: object) -> bool: - if not isinstance(other, Map): - raise NotImplementedError - - aligned_self, aligned_other = _align_two(self, other) - - # FIXME: type checker complains because it's not clear whether the - # underlying object after alignment is an isl.Set - assert isinstance(aligned_self._obj, isl.Set) - assert isinstance(aligned_other._obj, isl.Set) - return aligned_self._obj.plain_is_equal(aligned_other._obj) - - -@overload -def make_map(src: str, ctx: isl.Context | None = None) -> Map: - ... - - -@overload -def make_map(src: isl.Map) -> Map: - ... - -def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: - obj = isl.Map(src, ctx) if isinstance(src, str) else src +from .set_like import Map, Set, make_map, make_set - set_obj, dimtype_to_names = _deconstruct_set_like_object(obj) - set_obj, name_to_dim = _strip_names(set_obj) - assert isinstance(set_obj, isl.Set) - return Map(set_obj, name_to_dim, dimtype_to_names) +__all__ = ["Map", "Set", "make_map", "make_set"] diff --git a/namedisl/core.py b/namedisl/core.py new file mode 100644 index 0000000..bcd5c88 --- /dev/null +++ b/namedisl/core.py @@ -0,0 +1,356 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import operator +import re +from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass +from importlib import metadata +from typing import Generic, TypeAlias, TypeVar + +from constantdict import constantdict +from typing_extensions import override + +import islpy as isl + + +IslSetLike = isl.Set | isl.Map +IslExpressionLike = isl.Aff | isl.QPolynomial + +IslExpressionLikeT = TypeVar("IslExpressionLikeT", isl.Aff, isl.QPolynomial) +IslSetLikeT = TypeVar("IslSetLikeT", isl.Set, isl.Map) +IslObjectT = TypeVar("IslObjectT", IslSetLike, IslExpressionLike) + +NameToDim: TypeAlias = Mapping[str, int] + +# NOTE: without tracking what dimension type a particular name belongs to, it is +# not possible to reconstruct the ISL object after dimension operations, e.g. +# alignment +DimTypeToNames: TypeAlias = Mapping[isl.dim_type, frozenset[str]] + +SetLikePieces: TypeAlias = tuple[isl.Set, DimTypeToNames] + + +__version__ = metadata.version("namedisl") +_match = re.match(r"^([0-9.]+)([a-z0-9]*?)$", __version__) +assert _match +VERSION = tuple(int(nr) for nr in _match.group(1).split(".")) + +__all__ = [ + "_align_and_apply_binary_op", + "_align_two", + "_deconstruct_object", + "_find_contiguous_dim_chunks", + "_restore_names", + "_strip_names", +] + + +def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: + name_to_dim: dict[str, int] = {} + for i in range(obj.dim(isl.dim_type.set)): + + if isinstance(obj, isl.QPolynomial): + name = obj.space.get_dim_name(isl.dim_type.set, i) + else: + name = obj.get_dim_name(isl.dim_type.set, i) + + if name is None: + raise ValueError("unnamed dimension found") + + if name in name_to_dim: + raise ValueError(f"non-unique dim name: {name}") + + name_to_dim[name] = i + + return obj, constantdict(name_to_dim) + + +def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: + for name, dim in name_to_dim.items(): + obj = obj.set_dim_name(isl.dim_type.set, dim, name) + return obj + + +def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: + all_dt_names: list[str] = [] + for dim in range(obj.dim(dt)): + + if isinstance(obj, isl.QPolynomial): + name = obj.space.get_dim_name(dt, dim) + else: + name = obj.get_dim_name(dt, dim) + + if name is None: + raise ValueError("unnamed dimension found") + + all_dt_names.append(name) + + return frozenset(all_dt_names) + + +def _deconstruct_object(obj: IslSetLikeT) -> SetLikePieces: + from islpy import dim_type + + dt_to_names: dict[dim_type, frozenset[str]] = dict.fromkeys( + [isl.dim_type.in_, isl.dim_type.param], frozenset() + ) + for dt in dt_to_names: + dt_to_names[dt] = _get_dim_names(obj, dt) + if dt_to_names[dt]: + obj = obj.move_dims( + dim_type.set, + obj.dim(dim_type.set), + dt, + 0, + obj.dim(dt) + ) + + set_obj = obj.range() if isinstance(obj, isl.Map) else obj + + return set_obj, constantdict(dt_to_names) + + +def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: + """ + Determines contiguous chunks of dimensions within a sequence of dimensions. + Returns a mapping of the first dimension in the chunk to the length of the + chunk. + """ + if not dims: + return {} + + chunks: dict[int, int] = {} + + start = dims[0] + count = 1 + + from itertools import pairwise + for prev, curr in pairwise(dims): + if curr == prev + 1: + count += 1 + else: + chunks[start] = count + start = curr + count = 1 + + chunks[start] = count + + return constantdict(chunks) + + +# FIXME: think through whether or not alphabetical ordering will require more +# work on average than using one of the objects as a template in alignment +def _find_joint_name_to_dim( + obj1: NamedIslObject[IslObjectT], + obj2: NamedIslObject[IslObjectT] + ) -> tuple[NameToDim, DimTypeToNames]: + """ + Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and + :arg:`obj2` within each dimension-type chunk. This ordering is used in + alignment before performing operations between two set-like objects. + """ + obj1_inp_names = obj1._input_names + obj1_param_names = obj1._parameter_names + obj1_set_names = ( + frozenset(obj1._name_to_dim.keys()) - (obj1_inp_names | obj1_param_names) + ) + + obj2_inp_names = obj2._input_names + obj2_param_names = obj2._parameter_names + obj2_set_names = ( + frozenset(obj2._name_to_dim.keys()) - (obj2_inp_names | obj2_param_names) + ) + + all_inp_names = obj1_inp_names | obj2_inp_names + all_param_names = obj1_param_names | obj2_param_names + all_set_names = obj1_set_names | obj2_set_names + + dt_to_names: DimTypeToNames = {} + dt_to_names[isl.dim_type.param] = all_param_names + dt_to_names[isl.dim_type.in_] = all_inp_names + + # enforces contiguous ordering of [ (set), (input), (param) ] in set + # representation + all_names = [ + *sorted(all_set_names), + *sorted(all_inp_names), + *sorted(all_param_names), + ] + + name_to_dim: NameToDim = constantdict({ + name: pos for pos, name in enumerate(all_names) + }) + + return name_to_dim, constantdict(dt_to_names) + + +def _align_obj( + named_obj: NamedIslObject[IslObjectT], + ordering: NameToDim, + dimtype_to_names: DimTypeToNames + ) -> NamedIslObject[IslObjectT]: + new_isl_obj = named_obj._obj + running_name_to_dim = dict(named_obj._name_to_dim) + + for name, target_dim in sorted(ordering.items(), key=lambda x: x[1]): + if name in running_name_to_dim: + old_dim = running_name_to_dim[name] + + if old_dim == target_dim: + continue + + # temporarily move to parameter dimension since destination and + # source dim types cannot match in ISL + new_isl_obj = new_isl_obj.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, old_dim, 1 + ) + + new_isl_obj = new_isl_obj.move_dims( + isl.dim_type.set, target_dim, + isl.dim_type.param, 0, 1 + ) + + else: + old_dim = new_isl_obj.dim(isl.dim_type.set) + new_isl_obj = new_isl_obj.insert_dims(isl.dim_type.set, target_dim, 1) + + # track side effects of inserting/swapping dimensions + for n, d in list(running_name_to_dim.items()): + if (target_dim > old_dim) and (d > old_dim): + running_name_to_dim[n] = d - 1 + elif (target_dim < old_dim) and (d < old_dim): + running_name_to_dim[n] = d + 1 + + running_name_to_dim[name] = target_dim + + return type(named_obj)(new_isl_obj, ordering, dimtype_to_names) + + +def _align_two( + named_obj1: NamedIslObject[IslObjectT], + named_obj2: NamedIslObject[IslObjectT] + ) -> tuple[NamedIslObject[IslObjectT], ...]: + + name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, + named_obj2) + + named_obj1 = _align_obj(named_obj1, name_to_dim, dimtype_to_names) + named_obj2 = _align_obj(named_obj2, name_to_dim, dimtype_to_names) + + return named_obj1, named_obj2 + + +def _align_and_apply_binary_op( + lhs: NamedIslObject[IslObjectT], + rhs: NamedIslObject[IslObjectT], + op: Callable[[IslObjectT, IslObjectT], IslObjectT] + ) -> NamedIslObject[IslObjectT]: + + lhs, rhs = _align_two(lhs, rhs) + result = op(lhs._obj, rhs._obj) + + # NOTE: since lhs and rhs were aligned, they both agree on what name-to-dim + # and dimtype-to-name is, can just take information from lhs + return type(lhs)(result, lhs._name_to_dim, lhs._dimtype_to_names) + + +@dataclass(frozen=True, eq=False) +class NamedIslObject(Generic[IslObjectT], ABC): + _obj: IslObjectT + _name_to_dim: NameToDim + + # used to reconstruct ISL object + _dimtype_to_names: DimTypeToNames + + @property + def _has_inputs(self) -> bool: + return ( + isl.dim_type.in_ in self._dimtype_to_names + and + len(self._dimtype_to_names[isl.dim_type.in_]) > 0 + ) + + @property + def _input_names(self) -> frozenset[str]: + if self._has_inputs: + return self._dimtype_to_names[isl.dim_type.in_] + return frozenset() + + @property + def _input_dim_start(self) -> int | None: + if self._has_inputs: + return min( + self._name_to_dim[name] + for name in self._dimtype_to_names[isl.dim_type.in_] + ) + return None + + @property + def _has_params(self) -> bool: + return ( + isl.dim_type.param in self._dimtype_to_names + and + len(self._dimtype_to_names[isl.dim_type.param]) > 0 + ) + + @property + def _parameter_names(self) -> frozenset[str]: + if self._has_params: + return self._dimtype_to_names[isl.dim_type.param] + return frozenset() + + @property + def _parameter_dim_start(self) -> int | None: + if self._has_params: + return min( + self._name_to_dim[name] + for name in self._dimtype_to_names[isl.dim_type.param] + ) + return None + + def __and__( + self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + return _align_and_apply_binary_op(self, other, operator.and_) + + def __or__( + self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + return _align_and_apply_binary_op(self, other, operator.or_) + + def __sub__( + self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: + return _align_and_apply_binary_op(self, other, operator.sub) + + @abstractmethod + def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: + ... + + @override + def __str__(self) -> str: + return str(self._reconstruct_isl_object()) diff --git a/namedisl/set_like.py b/namedisl/set_like.py new file mode 100644 index 0000000..67d946c --- /dev/null +++ b/namedisl/set_like.py @@ -0,0 +1,298 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from abc import ABC +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, final, overload + +from constantdict import constantdict +from typing_extensions import override + +import islpy as isl + +from .core import ( + NamedIslObject, + NameToDim, + _align_two, + _deconstruct_object, + _find_contiguous_dim_chunks, + _restore_names, + _strip_names, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +@dataclass(frozen=True, eq=False) +class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): + """ + Represents set-like objects with parameter dimensions as a non-parameterized + set. Names are organized as contiguous chunks of dimension types, i.e. + [ (set names), (input names), (parameter names) ] + """ + def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: + return replace( + self, + _obj=self._obj.complement(), + _name_to_dim=self._name_to_dim, + _dimtype_to_names=self._dimtype_to_names + ) + + def eliminate(self, names_to_eliminate: str | Sequence[str]) -> _NamedIslSetLike: + if isinstance(names_to_eliminate, str): + names_to_eliminate = [names_to_eliminate] + + dims_to_eliminate = sorted( + self._name_to_dim[name] + for name in names_to_eliminate + ) + + contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_eliminate) + + new_isl_obj = self._obj + for start in sorted(contiguous_dim_chunks): + new_isl_obj = new_isl_obj.eliminate( + isl.dim_type.set, start, contiguous_dim_chunks[start] + ) + + return replace( + self, + _obj=new_isl_obj, + _name_to_dim=self._name_to_dim, # NOTE: no dims removed by eliminate + _dimtype_to_names=self._dimtype_to_names + ) + + def project_out(self: _NamedIslSetLike, + names_to_project_out: str | Sequence[str]) -> _NamedIslSetLike: + + if isinstance(names_to_project_out, str): + names_to_project_out = [names_to_project_out] + + names_to_remove = set(names_to_project_out) + + dims_to_remove = sorted( + self._name_to_dim[name] + for name in names_to_remove + ) + + new_isl_obj = self._obj + contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_remove) + for start in sorted(contiguous_dim_chunks, reverse=True): + new_isl_obj = new_isl_obj.project_out( + isl.dim_type.set, start, contiguous_dim_chunks[start] + ) + + new_name_to_dim: NameToDim = {} + for name, dim in self._name_to_dim.items(): + if name in names_to_remove: + continue + + shift = 0 + for removed_dim in dims_to_remove: + if removed_dim < dim: + shift += 1 + else: + break + + new_name_to_dim[name] = dim - shift + + new_type_to_names = constantdict({ + dt: self._dimtype_to_names[dt] - frozenset(names_to_remove) + for dt in self._dimtype_to_names + }) + + return replace( + self, + _obj=new_isl_obj, + _name_to_dim=constantdict(new_name_to_dim), + _dimtype_to_names=new_type_to_names + ) + + def project_out_except( + self: _NamedIslSetLike, + names_to_keep: str | Sequence[str] + ) -> _NamedIslSetLike: + + if isinstance(names_to_keep, str): + names_to_keep = [names_to_keep] + + names_to_project_out = [ + name for name in self._name_to_dim + if name not in names_to_keep + ] + + return self.project_out(names_to_project_out) + + # {{{ TODO: functions that return ExpressionLike objects + + def dim_max(self, name: str): + ... + + def dim_min(self, name: str): + ... + + def as_pw_multi_aff(self): + ... + + # }}} + + +@final +@dataclass(frozen=True, eq=False) +class Set(_NamedIslSetLike): + @override + def _reconstruct_isl_object(self) -> isl.Set: + # FIXME: typechecker complains that self._obj is not an isl.Set even + # though _NamedIslObject is instantiated with isl.Set. + # using reveal_type(self._obj) below shows self._obj is + # isl.Set | isl.Map? + assert isinstance(self._obj, isl.Set) + + if self._has_params: + if self._parameter_dim_start is None: + raise ValueError( + "Object has parameter dimensions, but a starting index for " + "parameter names is not given. Reconstruction is not " + "possible") + + return self._obj.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, self._parameter_dim_start, + len(self._parameter_names) + ) + + return self._obj + + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, Set): + raise NotImplementedError + + aligned_self, aligned_other = _align_two(self, other) + + # FIXME: type checker complains because it's not clear whether the + # underlying object after alignment is an isl.Set + assert isinstance(aligned_other._obj, isl.Set) + assert isinstance(aligned_self._obj, isl.Set) + return aligned_self._obj.plain_is_equal(aligned_other._obj) + + +@overload +def make_set(src: str, ctx: isl.Context | None = None) -> Set: + ... + + +@overload +def make_set(src: isl.Set) -> Set: + ... + + +def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: + obj = isl.Set(src, ctx) if isinstance(src, str) else src + + set_obj, dimtype_to_names = _deconstruct_object(obj) + set_obj, name_to_dim = _strip_names(set_obj) + + assert isinstance(set_obj, isl.Set) + return Set(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class Map(_NamedIslSetLike): + @override + def _reconstruct_isl_object(self) -> isl.Map: + """ + Relies on the dimension type ordering in + :func:`_deconstruct_set_like_object`. + """ + if self._input_dim_start is None: + raise ValueError("Cannot reconstruct a map object without knowledge " + "of the starting position of input dimensions") + + obj = _restore_names(self._obj, self._name_to_dim) + assert isinstance(obj, isl.Set) + + obj_domain = isl.Set("{ [] }") + obj_range = obj + + map_obj = isl.Map.from_domain_and_range(obj_domain, obj_range) + + if self._has_params: + if self._parameter_dim_start is None: + raise ValueError( + "Object has parameter dimensions, but a starting index for " + "parameter names is not given. Reconstruction is not " + "possible") + + param_start = self._parameter_dim_start + map_obj = map_obj.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, param_start, len(self._parameter_names) + ) + + inp_start = self._input_dim_start + map_obj = map_obj.move_dims( + isl.dim_type.in_, 0, + isl.dim_type.set, inp_start, len(self._input_names) + ) + + return map_obj + + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, Map): + raise NotImplementedError + + aligned_self, aligned_other = _align_two(self, other) + + # FIXME: type checker complains because it's not clear whether the + # underlying object after alignment is an isl.Set + assert isinstance(aligned_self._obj, isl.Set) + assert isinstance(aligned_other._obj, isl.Set) + return aligned_self._obj.plain_is_equal(aligned_other._obj) + + +@overload +def make_map(src: str, ctx: isl.Context | None = None) -> Map: + ... + + +@overload +def make_map(src: isl.Map) -> Map: + ... + + +def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: + obj = isl.Map(src, ctx) if isinstance(src, str) else src + + set_obj, dimtype_to_names = _deconstruct_object(obj) + set_obj, name_to_dim = _strip_names(set_obj) + + assert isinstance(set_obj, isl.Set) + return Map(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args diff --git a/namedisl/test/test_set.py b/namedisl/test/test_set.py index 7697195..89bcc49 100644 --- a/namedisl/test/test_set.py +++ b/namedisl/test/test_set.py @@ -126,12 +126,3 @@ def test_set_project_out(ndims: int): a = a.project_out(a_dims.split(",")) assert a == nisl.make_set("{[]}") - - -if __name__ == "__main__": - import sys - if len(sys.argv) > 1: - exec(sys.argv[0]) - else: - from pytest import main - main([__file__]) From dfc5adefdc24aaffa662794814878d5ce52f6e0a Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 26 Jan 2026 14:31:41 -0600 Subject: [PATCH 14/43] add expression-likes necessary for set-like operations; basedpyright not happy yet --- namedisl/core.py | 91 +++++++++------ namedisl/expression_like.py | 160 ++++++++++++++++++++++++++ namedisl/set_like.py | 221 +++++++++++++++++++++++------------- namedisl/test/test_map.py | 60 ++++++++++ namedisl/test/test_set.py | 26 +++++ 5 files changed, 443 insertions(+), 115 deletions(-) create mode 100644 namedisl/expression_like.py diff --git a/namedisl/core.py b/namedisl/core.py index bcd5c88..2b068ec 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -25,7 +25,6 @@ THE SOFTWARE. """ -import operator import re from abc import ABC, abstractmethod from collections.abc import Callable, Mapping, Sequence @@ -39,11 +38,28 @@ import islpy as isl -IslSetLike = isl.Set | isl.Map -IslExpressionLike = isl.Aff | isl.QPolynomial - -IslExpressionLikeT = TypeVar("IslExpressionLikeT", isl.Aff, isl.QPolynomial) -IslSetLikeT = TypeVar("IslSetLikeT", isl.Set, isl.Map) +IslSetLike = isl.BasicSet | isl.BasicMap | isl.Set | isl.Map +IslBaseExpressionLike = isl.Aff | isl.QPolynomial +IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial +IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff +IslExpressionLike = IslBaseExpressionLike | IslPwExpressionLike | IslMultiExpressionLike + +IslExpressionLikeT = TypeVar( + "IslExpressionLikeT", + isl.Aff, + isl.MultiAff, + isl.PwAff, + isl.PwMultiAff, + isl.QPolynomial, + isl.PwQPolynomial +) +IslSetLikeT = TypeVar( + "IslSetLikeT", + isl.BasicSet, + isl.BasicMap, + isl.Set, + isl.Map +) IslObjectT = TypeVar("IslObjectT", IslSetLike, IslExpressionLike) NameToDim: TypeAlias = Mapping[str, int] @@ -53,7 +69,7 @@ # alignment DimTypeToNames: TypeAlias = Mapping[isl.dim_type, frozenset[str]] -SetLikePieces: TypeAlias = tuple[isl.Set, DimTypeToNames] +IslObjectPieces: TypeAlias = tuple[IslSetLike | IslExpressionLike, DimTypeToNames] __version__ = metadata.version("namedisl") @@ -75,7 +91,7 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: name_to_dim: dict[str, int] = {} for i in range(obj.dim(isl.dim_type.set)): - if isinstance(obj, isl.QPolynomial): + if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): name = obj.space.get_dim_name(isl.dim_type.set, i) else: name = obj.get_dim_name(isl.dim_type.set, i) @@ -114,26 +130,41 @@ def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: return frozenset(all_dt_names) -def _deconstruct_object(obj: IslSetLikeT) -> SetLikePieces: +def _deconstruct_object(obj: IslObjectT) -> IslObjectPieces: from islpy import dim_type - dt_to_names: dict[dim_type, frozenset[str]] = dict.fromkeys( - [isl.dim_type.in_, isl.dim_type.param], frozenset() - ) - for dt in dt_to_names: - dt_to_names[dt] = _get_dim_names(obj, dt) - if dt_to_names[dt]: - obj = obj.move_dims( - dim_type.set, - obj.dim(dim_type.set), - dt, - 0, - obj.dim(dt) - ) + dt_to_names: dict[dim_type, frozenset[str]] = {} + + if isinstance(obj, IslSetLike): + setlike_obj = obj + dt_to_names = dict.fromkeys( + [isl.dim_type.in_, isl.dim_type.param], frozenset() + ) + for dt in dt_to_names: + dt_to_names[dt] = _get_dim_names(setlike_obj, dt) + if dt_to_names[dt]: + setlike_obj = setlike_obj.move_dims( + dim_type.set, + setlike_obj.dim(dim_type.set), + dt, + 0, + setlike_obj.dim(dt) + ) + + setlike_obj = ( + setlike_obj.range() + if isinstance(setlike_obj, isl.Map) + else setlike_obj + ) + + return setlike_obj, constantdict(dt_to_names) + + elif isinstance(obj, IslExpressionLike): + expr_obj = obj - set_obj = obj.range() if isinstance(obj, isl.Map) else obj + dt_to_names = dict.fromkeys([isl.dim_type.param], frozenset()) - return set_obj, constantdict(dt_to_names) + return expr_obj, constantdict(dt_to_names) def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: @@ -335,18 +366,6 @@ def _parameter_dim_start(self) -> int | None: ) return None - def __and__( - self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: - return _align_and_apply_binary_op(self, other, operator.and_) - - def __or__( - self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: - return _align_and_apply_binary_op(self, other, operator.or_) - - def __sub__( - self, other: NamedIslObject[IslObjectT]) -> NamedIslObject[IslObjectT]: - return _align_and_apply_binary_op(self, other, operator.sub) - @abstractmethod def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: ... diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py new file mode 100644 index 0000000..586a029 --- /dev/null +++ b/namedisl/expression_like.py @@ -0,0 +1,160 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import operator +from dataclasses import dataclass +from typing import final, overload + +from typing_extensions import Self, override + +import islpy as isl + +from .core import ( + IslExpressionLike, + NamedIslObject, + _align_and_apply_binary_op, + _deconstruct_object, + _strip_names, +) + + +@dataclass(frozen=True, eq=False) +class _NamedExpressionLike(NamedIslObject[IslExpressionLike]): + @override + def _reconstruct_isl_object(self) -> IslExpressionLike: + return self._obj + + # FIXME: Self is used here is because _NamedExpressionLike is generic, + # leading to complaints from basedpyright + def __add__(self, other: Self) -> Self: + return _align_and_apply_binary_op(self, other, operator.add) + + def __sub__(self, other: Self) -> Self: + return _align_and_apply_binary_op(self, other, operator.sub) + + def __mul__(self, other: Self) -> Self: + return _align_and_apply_binary_op(self, other, operator.mul) + + @override + def __eq__(self, other: object) -> bool: + assert type(other) is type(self) + return self._obj == other._obj + + +@dataclass(frozen=True, eq=False) +class _NamedPwExpressionLike(_NamedExpressionLike): + ... + + +@dataclass(frozen=True, eq=False) +class _NamedMultiExpressionLike(_NamedExpressionLike): + ... + + +@final +@dataclass(frozen=True, eq=False) +class Aff(_NamedExpressionLike): + _obj: isl.Aff + + +@overload +def make_aff(src: str, ctx: isl.Context | None = None) -> Aff: + ... + + +@overload +def make_aff(src: isl.Aff) -> Aff: + ... + + +def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: + obj = isl.Aff(src, ctx) if isinstance(src, str) else src + + aff_obj, dimtype_to_names = _deconstruct_object(obj) + + assert isinstance(aff_obj, isl.Aff) + aff_obj, name_to_dim = _strip_names(aff_obj) + + return Aff(aff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class PwAff(_NamedPwExpressionLike): + _obj: isl.PwAff + + +@overload +def make_pwaff(src: str, ctx: isl.Context | None = None) -> PwAff: + ... + + +@overload +def make_pwaff(src: isl.PwAff) -> PwAff: + ... + + +def make_pwaff(src: str | isl.PwAff, ctx: isl.Context | None = None) -> PwAff: + obj = isl.PwAff(src, ctx) if isinstance(src, str) else src + + pwaff_obj, dimtype_to_names = _deconstruct_object(obj) + + assert isinstance(obj, isl.PwAff) + pwaff_obj, name_to_dim = _strip_names(obj) + + return PwAff(pwaff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class PwMultiAff(_NamedMultiExpressionLike): + _obj: isl.PwMultiAff + + +@overload +def make_pw_multi_aff(src: str, ctx: isl.Context | None = None) -> PwMultiAff: + ... + + +@overload +def make_pw_multi_aff(src: isl.PwMultiAff) -> PwMultiAff: + ... + + +def make_pw_multi_aff( + src: str | isl.PwMultiAff, + ctx: isl.Context | None = None + ) -> PwMultiAff: + + obj = isl.PwMultiAff(src, ctx) if isinstance(src, str) else src + + pw_maff_obj, dimtype_to_names = _deconstruct_object(obj) + + assert isinstance(pw_maff_obj, isl.PwMultiAff) + pw_maff_obj, name_to_dim = _strip_names(pw_maff_obj) + + return PwMultiAff(pw_maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 67d946c..12f70b2 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -24,6 +24,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + +import operator from abc import ABC from dataclasses import dataclass, replace from typing import TYPE_CHECKING, final, overload @@ -34,14 +36,17 @@ import islpy as isl from .core import ( + IslSetLike, NamedIslObject, NameToDim, + _align_and_apply_binary_op, _align_two, _deconstruct_object, _find_contiguous_dim_chunks, _restore_names, _strip_names, ) +from .expression_like import PwAff, PwMultiAff, make_pw_multi_aff, make_pwaff if TYPE_CHECKING: @@ -55,6 +60,50 @@ class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): set. Names are organized as contiguous chunks of dimension types, i.e. [ (set names), (input names), (parameter names) ] """ + + @override + def _reconstruct_isl_object(self) -> IslSetLike: + """ + Relies on the dimension type ordering in + :func:`_deconstruct_set_like_object`. + """ + + obj = _restore_names(self._obj, self._name_to_dim) + assert isinstance(obj, isl.Set) + + if self._has_params: + if self._parameter_dim_start is None: + raise ValueError( + "Object has parameter dimensions, but a starting index for " + "parameter names is not given. Reconstruction is not " + "possible") + + param_start = self._parameter_dim_start + obj = obj.move_dims( + isl.dim_type.param, 0, + isl.dim_type.set, param_start, len(self._parameter_names) + ) + + if self._has_inputs: + if self._input_dim_start is None: + raise ValueError( + "Object has input dimensions, but a starting index for " + "input names is not given. Reconstruction is not " + "possible") + + obj_domain = isl.Set("{ [] }") + obj_range = obj + + obj = isl.Map.from_domain_and_range(obj_domain, obj_range) + + inp_start = self._input_dim_start + obj = obj.move_dims( + isl.dim_type.in_, 0, + isl.dim_type.set, inp_start, len(self._input_names) + ) + + return obj + def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: return replace( self, @@ -148,60 +197,85 @@ def project_out_except( return self.project_out(names_to_project_out) - # {{{ TODO: functions that return ExpressionLike objects - - def dim_max(self, name: str): - ... + def dim_max(self, name: str) -> PwAff: + return make_pwaff(self._obj.dim_max(self._name_to_dim[name])) - def dim_min(self, name: str): - ... + def dim_min(self, name: str) -> PwAff: + return make_pwaff(self._obj.dim_min(self._name_to_dim[name])) - def as_pw_multi_aff(self): - ... + def as_pw_multi_aff(self) -> PwMultiAff: + return make_pw_multi_aff(self._reconstruct_isl_object().as_pw_multi_aff()) - # }}} + # FIXME: basedpyright is not happy with these function signatures + def __and__( + self, other: _NamedIslSetLike) -> _NamedIslSetLike: + return _align_and_apply_binary_op(self, other, operator.and_) + def __or__( + self, other: _NamedIslSetLike) -> _NamedIslSetLike: + return _align_and_apply_binary_op(self, other, operator.or_) -@final -@dataclass(frozen=True, eq=False) -class Set(_NamedIslSetLike): - @override - def _reconstruct_isl_object(self) -> isl.Set: - # FIXME: typechecker complains that self._obj is not an isl.Set even - # though _NamedIslObject is instantiated with isl.Set. - # using reveal_type(self._obj) below shows self._obj is - # isl.Set | isl.Map? - assert isinstance(self._obj, isl.Set) - - if self._has_params: - if self._parameter_dim_start is None: - raise ValueError( - "Object has parameter dimensions, but a starting index for " - "parameter names is not given. Reconstruction is not " - "possible") - - return self._obj.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, self._parameter_dim_start, - len(self._parameter_names) - ) - - return self._obj + def __sub__( + self, other: _NamedIslSetLike) -> _NamedIslSetLike: + return _align_and_apply_binary_op(self, other, operator.sub) @override def __eq__(self, other: object) -> bool: - if not isinstance(other, Set): + if not isinstance(other, type(self)): raise NotImplementedError aligned_self, aligned_other = _align_two(self, other) # FIXME: type checker complains because it's not clear whether the # underlying object after alignment is an isl.Set - assert isinstance(aligned_other._obj, isl.Set) assert isinstance(aligned_self._obj, isl.Set) + assert isinstance(aligned_other._obj, isl.Set) return aligned_self._obj.plain_is_equal(aligned_other._obj) +@final +@dataclass(frozen=True, eq=False) +class BasicSet(_NamedIslSetLike): + + @override + def _reconstruct_isl_object(self) -> isl.BasicSet: + obj = super()._reconstruct_isl_object() + + if not isinstance(obj, isl.Set) or obj.n_basic_set() != 1: + raise ValueError( + "Cannot reconstruct an isl.BasicSet from anything other than " + "an isl.Set containing only a single isl.BasicSet.") + + return obj.get_basic_sets()[0] + + +@overload +def make_basic_set(src: str, ctx: isl.Context | None = None) -> BasicSet: + ... + + +@overload +def make_basic_set(src: isl.BasicSet) -> BasicSet: + ... + + +def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> BasicSet: + obj = isl.BasicSet(src, ctx) if isinstance(src, str) else src + + set_obj, dimtype_to_names = _deconstruct_object(obj) + + assert isinstance(set_obj, isl.Set) + set_obj, name_to_dim = _strip_names(set_obj) + + return BasicSet(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class Set(_NamedIslSetLike): + ... + + @overload def make_set(src: str, ctx: isl.Context | None = None) -> Set: ... @@ -216,66 +290,54 @@ def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: obj = isl.Set(src, ctx) if isinstance(src, str) else src set_obj, dimtype_to_names = _deconstruct_object(obj) - set_obj, name_to_dim = _strip_names(set_obj) assert isinstance(set_obj, isl.Set) + set_obj, name_to_dim = _strip_names(set_obj) + return Set(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @final @dataclass(frozen=True, eq=False) -class Map(_NamedIslSetLike): +class BasicMap(_NamedIslSetLike): + @override - def _reconstruct_isl_object(self) -> isl.Map: - """ - Relies on the dimension type ordering in - :func:`_deconstruct_set_like_object`. - """ - if self._input_dim_start is None: - raise ValueError("Cannot reconstruct a map object without knowledge " - "of the starting position of input dimensions") + def _reconstruct_isl_object(self) -> isl.BasicMap: + obj = super()._reconstruct_isl_object() - obj = _restore_names(self._obj, self._name_to_dim) - assert isinstance(obj, isl.Set) + if not isinstance(obj, isl.Map) or obj.n_basic_map() != 1: + raise ValueError( + "Cannot reconstruct an isl.BasicMap from anything other than " + "an isl.Map containing only a single isl.BasicMap.") - obj_domain = isl.Set("{ [] }") - obj_range = obj + return obj.get_basic_maps()[0] - map_obj = isl.Map.from_domain_and_range(obj_domain, obj_range) - if self._has_params: - if self._parameter_dim_start is None: - raise ValueError( - "Object has parameter dimensions, but a starting index for " - "parameter names is not given. Reconstruction is not " - "possible") +@overload +def make_basic_map(src: str, ctx: isl.Context | None = None) -> BasicMap: + ... - param_start = self._parameter_dim_start - map_obj = map_obj.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, param_start, len(self._parameter_names) - ) - inp_start = self._input_dim_start - map_obj = map_obj.move_dims( - isl.dim_type.in_, 0, - isl.dim_type.set, inp_start, len(self._input_names) - ) +@overload +def make_basic_map(src: isl.BasicMap) -> BasicMap: + ... - return map_obj - @override - def __eq__(self, other: object) -> bool: - if not isinstance(other, Map): - raise NotImplementedError +def make_basic_map(src: str | isl.BasicMap, ctx: isl.Context | None = None) -> BasicMap: + obj = isl.BasicMap(src, ctx) if isinstance(src, str) else src - aligned_self, aligned_other = _align_two(self, other) + set_obj, dimtype_to_names = _deconstruct_object(obj) - # FIXME: type checker complains because it's not clear whether the - # underlying object after alignment is an isl.Set - assert isinstance(aligned_self._obj, isl.Set) - assert isinstance(aligned_other._obj, isl.Set) - return aligned_self._obj.plain_is_equal(aligned_other._obj) + assert isinstance(set_obj, isl.Set) + set_obj, name_to_dim = _strip_names(set_obj) + + return BasicMap(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class Map(_NamedIslSetLike): + ... @overload @@ -292,7 +354,8 @@ def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: obj = isl.Map(src, ctx) if isinstance(src, str) else src set_obj, dimtype_to_names = _deconstruct_object(obj) - set_obj, name_to_dim = _strip_names(set_obj) assert isinstance(set_obj, isl.Set) + set_obj, name_to_dim = _strip_names(set_obj) + return Map(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args diff --git a/namedisl/test/test_map.py b/namedisl/test/test_map.py index b8c6d01..5795569 100644 --- a/namedisl/test/test_map.py +++ b/namedisl/test/test_map.py @@ -204,3 +204,63 @@ def test_map_project_out(ndims_domain: int, ndims_range: int): x = x.project_out(dims_to_remove) assert x == nisl.make_map("{[] -> []}") + + +def test_map_as_pw_multi_aff(): + spec = "{ [i] -> [io, ii] : i = 32 * io + ii and 0 <= ii < 32 }" + m = nisl.make_map(spec) + m_isl = isl.Map(spec) + + assert m.as_pw_multi_aff()._obj == m_isl.as_pw_multi_aff() + + +@pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) +@pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) +def test_map_dim_max(ndims_domain: int, ndims_range: int): + m, (_, in_names, in_conds), (_, out_names, out_conds) = generate_random_named_map( + ndims_domain, "x_in", None, + ndims_range, "x_out", None + ) + + in_upper_bound_pw_maffs = [ + isl.PwAff(f"{{ [{int(cond.split('<')[2].strip(' '))}] }}") + for cond in in_conds.split("and") + ] + + for i, name in enumerate(in_names.split(",")): + # NOTE: constructing PwAffs assumes starting index of 0, so subtract 1 + assert m.dim_max(name)._obj == (in_upper_bound_pw_maffs[i] - 1) + + out_upper_bound_pw_maffs = [ + isl.PwAff(f"{{ [{int(cond.split('<')[2].strip(' '))}] }}") + for cond in out_conds.split("and") + ] + + for i, name in enumerate(out_names.split(",")): + # NOTE: constructing PwAffs assumes starting index of 0, so subtract 1 + assert m.dim_max(name)._obj == (out_upper_bound_pw_maffs[i] - 1) + + +@pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) +@pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) +def test_map_dim_min(ndims_domain: int, ndims_range: int): + m, (_, in_names, in_conds), (_, out_names, out_conds) = generate_random_named_map( + ndims_domain, "x_in", None, + ndims_range, "x_out", None + ) + + in_lower_bound_pw_maffs = [ + isl.PwAff(f"{{ [{int(cond.split('<')[0].strip(' '))}] }}") + for cond in in_conds.split("and") + ] + + for i, name in enumerate(in_names.split(",")): + assert m.dim_min(name)._obj == in_lower_bound_pw_maffs[i] + + out_lower_bound_pw_maffs = [ + isl.PwAff(f"{{ [{int(cond.split('<')[0].strip(' '))}] }}") + for cond in out_conds.split("and") + ] + + for i, name in enumerate(out_names.split(",")): + assert m.dim_min(name)._obj == out_lower_bound_pw_maffs[i] diff --git a/namedisl/test/test_set.py b/namedisl/test/test_set.py index 89bcc49..9cb12a1 100644 --- a/namedisl/test/test_set.py +++ b/namedisl/test/test_set.py @@ -126,3 +126,29 @@ def test_set_project_out(ndims: int): a = a.project_out(a_dims.split(",")) assert a == nisl.make_set("{[]}") + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_dim_max(ndims: int): + a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) + + cond_pw_affs = [ + isl.PwAff(f"{{ [{cond.split('<')[2].strip(' ')}] }}") + for cond in a_cond.split("and") + ] + + for i, name in enumerate(a_dims.split(",")): + assert a.dim_max(name)._obj == (cond_pw_affs[i] - 1) + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_dim_min(ndims: int): + a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) + + cond_pw_affs = [ + isl.PwAff(f"{{ [{cond.split('<')[0].strip(' ')}] }}") + for cond in a_cond.split("and") + ] + + for i, name in enumerate(a_dims.split(",")): + assert a.dim_min(name)._obj == cond_pw_affs[i] From fe404311cf968a8120176dedf4a77e3421de013e Mon Sep 17 00:00:00 2001 From: Addison Date: Fri, 30 Jan 2026 14:06:41 -0600 Subject: [PATCH 15/43] add tests for expression-likes; add qpolynomials; refactor --- namedisl/__init__.py | 48 ++- namedisl/core.py | 151 ++++++-- namedisl/expression_like.py | 183 ++++++++-- namedisl/set_like.py | 53 +-- namedisl/test/test_expression_like.py | 327 ++++++++++++++++++ namedisl/test/test_set.py | 154 --------- .../test/{test_map.py => test_set_like.py} | 168 ++++++++- 7 files changed, 824 insertions(+), 260 deletions(-) create mode 100644 namedisl/test/test_expression_like.py delete mode 100644 namedisl/test/test_set.py rename namedisl/test/{test_map.py => test_set_like.py} (63%) diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 0795cea..dc863d4 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -25,7 +25,51 @@ THE SOFTWARE. """ -from .set_like import Map, Set, make_map, make_set +from .expression_like import ( + Aff, + MultiAff, + PwAff, + PwMultiAff, + PwQPolynomial, + QPolynomial, + make_aff, + make_multi_aff, + make_pw_aff, + make_pw_multi_aff, + make_pw_qpolynomial, + make_qpolynomial, +) +from .set_like import ( + BasicMap, + BasicSet, + Map, + Set, + make_basic_map, + make_basic_set, + make_map, + make_set, +) -__all__ = ["Map", "Set", "make_map", "make_set"] +__all__ = [ + "Aff", + "BasicMap", + "BasicSet", + "Map", + "MultiAff", + "PwAff", + "PwMultiAff", + "PwQPolynomial", + "QPolynomial", + "Set", + "make_aff", + "make_basic_map", + "make_basic_set", + "make_map", + "make_multi_aff", + "make_pw_aff", + "make_pw_multi_aff", + "make_pw_qpolynomial", + "make_qpolynomial", + "make_set" +] diff --git a/namedisl/core.py b/namedisl/core.py index 2b068ec..c1dc3e5 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -26,7 +26,7 @@ """ import re -from abc import ABC, abstractmethod +from abc import ABC from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from importlib import metadata @@ -43,13 +43,13 @@ IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff IslExpressionLike = IslBaseExpressionLike | IslPwExpressionLike | IslMultiExpressionLike +IslObject = IslSetLike | IslExpressionLike | IslMultiExpressionLike IslExpressionLikeT = TypeVar( "IslExpressionLikeT", isl.Aff, isl.MultiAff, isl.PwAff, - isl.PwMultiAff, isl.QPolynomial, isl.PwQPolynomial ) @@ -58,7 +58,12 @@ isl.BasicSet, isl.BasicMap, isl.Set, - isl.Map + isl.Map, +) +IslMultiExpressionLikeT = TypeVar( + "IslMultiExpressionLikeT", + isl.PwMultiAff, + isl.MultiAff ) IslObjectT = TypeVar("IslObjectT", IslSetLike, IslExpressionLike) @@ -89,12 +94,16 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: name_to_dim: dict[str, int] = {} - for i in range(obj.dim(isl.dim_type.set)): + dt_to_strip = ( + isl.dim_type.set if isinstance(obj, IslSetLike) else isl.dim_type.in_ + ) + + for i in range(obj.dim(dt_to_strip)): if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): - name = obj.space.get_dim_name(isl.dim_type.set, i) + name = obj.space.get_dim_name(dt_to_strip, i) else: - name = obj.get_dim_name(isl.dim_type.set, i) + name = obj.get_dim_name(dt_to_strip, i) if name is None: raise ValueError("unnamed dimension found") @@ -108,8 +117,42 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: + if isinstance(obj, isl.PwAff): + pwaff_obj = obj.move_dims( + isl.dim_type.param, + 0, + isl.dim_type.in_, + 0, + obj.dim(isl.dim_type.in_) + ) + + for name, dim in name_to_dim.items(): + pwaff_obj = pwaff_obj.set_dim_name( + isl.dim_type.param, + dim, + name + ) + + pwaff_obj = pwaff_obj.get_pw_aff_list().get_at(0) + + pwaff_obj = pwaff_obj.move_dims( + isl.dim_type.in_, + 0, + isl.dim_type.param, + 0, + pwaff_obj.dim(isl.dim_type.param) + ) + + return pwaff_obj + + if isinstance(obj, IslSetLike): + dt_to_restore = isl.dim_type.set + else: + dt_to_restore = isl.dim_type.in_ + for name, dim in name_to_dim.items(): - obj = obj.set_dim_name(isl.dim_type.set, dim, name) + obj = obj.set_dim_name(dt_to_restore, dim, name) + return obj @@ -117,7 +160,7 @@ def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: all_dt_names: list[str] = [] for dim in range(obj.dim(dt)): - if isinstance(obj, isl.QPolynomial): + if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): name = obj.space.get_dim_name(dt, dim) else: name = obj.get_dim_name(dt, dim) @@ -131,21 +174,25 @@ def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: def _deconstruct_object(obj: IslObjectT) -> IslObjectPieces: - from islpy import dim_type - - dt_to_names: dict[dim_type, frozenset[str]] = {} + dt_to_names: dict[isl.dim_type, frozenset[str]] = {} - if isinstance(obj, IslSetLike): + if isinstance(obj, IslSetLike | IslMultiExpressionLike): setlike_obj = obj dt_to_names = dict.fromkeys( [isl.dim_type.in_, isl.dim_type.param], frozenset() ) + + # NOTE: isl.PwMultiAff.move_dims does not exist, represent as map + # internally + if isinstance(setlike_obj, IslMultiExpressionLike): + setlike_obj = setlike_obj.as_map() + for dt in dt_to_names: dt_to_names[dt] = _get_dim_names(setlike_obj, dt) if dt_to_names[dt]: setlike_obj = setlike_obj.move_dims( - dim_type.set, - setlike_obj.dim(dim_type.set), + isl.dim_type.set, + setlike_obj.dim(isl.dim_type.set), dt, 0, setlike_obj.dim(dt) @@ -153,16 +200,32 @@ def _deconstruct_object(obj: IslObjectT) -> IslObjectPieces: setlike_obj = ( setlike_obj.range() - if isinstance(setlike_obj, isl.Map) + if isinstance(setlike_obj, isl.Map | isl.BasicMap) + else setlike_obj + ) + + setlike_obj = ( + isl.Set.from_basic_set(setlike_obj) + if isinstance(setlike_obj, isl.BasicSet) else setlike_obj ) return setlike_obj, constantdict(dt_to_names) - elif isinstance(obj, IslExpressionLike): + else: expr_obj = obj dt_to_names = dict.fromkeys([isl.dim_type.param], frozenset()) + dt_to_names[isl.dim_type.param] = _get_dim_names(expr_obj, + isl.dim_type.param) + + expr_obj = expr_obj.move_dims( + isl.dim_type.in_, + expr_obj.dim(isl.dim_type.in_), + isl.dim_type.param, + 0, + expr_obj.dim(isl.dim_type.param) + ) return expr_obj, constantdict(dt_to_names) @@ -249,6 +312,12 @@ def _align_obj( new_isl_obj = named_obj._obj running_name_to_dim = dict(named_obj._name_to_dim) + target_dt = ( + isl.dim_type.set + if isinstance(new_isl_obj, IslSetLike) + else isl.dim_type.in_ + ) + for name, target_dim in sorted(ordering.items(), key=lambda x: x[1]): if name in running_name_to_dim: old_dim = running_name_to_dim[name] @@ -260,11 +329,11 @@ def _align_obj( # source dim types cannot match in ISL new_isl_obj = new_isl_obj.move_dims( isl.dim_type.param, 0, - isl.dim_type.set, old_dim, 1 + target_dt, old_dim, 1 ) new_isl_obj = new_isl_obj.move_dims( - isl.dim_type.set, target_dim, + target_dt, target_dim, isl.dim_type.param, 0, 1 ) @@ -366,9 +435,49 @@ def _parameter_dim_start(self) -> int | None: ) return None - @abstractmethod - def _reconstruct_isl_object(self) -> IslExpressionLike | IslSetLike: - ... + def _reconstruct_isl_object(self) -> IslObject: + """ + Relies on the dimension type ordering in + :func:`_deconstruct_set_like_object`. + """ + obj = _restore_names(self._obj, self._name_to_dim) + + internal_dim = ( + isl.dim_type.set if isinstance(obj, isl.Set) else isl.dim_type.in_ + ) + + if self._has_params: + if self._parameter_dim_start is None: + raise ValueError( + "Object has parameter dimensions, but a starting index for " + "parameter names is not given. Reconstruction is not " + "possible") + + param_start = self._parameter_dim_start + obj = obj.move_dims( + isl.dim_type.param, 0, + internal_dim, param_start, len(self._parameter_names) + ) + + if self._has_inputs: + if self._input_dim_start is None: + raise ValueError( + "Object has input dimensions, but a starting index for " + "input names is not given. Reconstruction is not " + "possible") + + obj_domain = isl.Set("{ [] }") + obj_range = obj + + obj = isl.Map.from_domain_and_range(obj_domain, obj_range) + + inp_start = self._input_dim_start + obj = obj.move_dims( + isl.dim_type.in_, 0, + internal_dim, inp_start, len(self._input_names) + ) + + return obj @override def __str__(self) -> str: diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index 586a029..e6668e4 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -26,10 +26,10 @@ """ import operator -from dataclasses import dataclass -from typing import final, overload +from dataclasses import dataclass, replace +from typing import final, overload, override -from typing_extensions import Self, override +from typing_extensions import Self import islpy as isl @@ -42,27 +42,51 @@ ) +# {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) + @dataclass(frozen=True, eq=False) class _NamedExpressionLike(NamedIslObject[IslExpressionLike]): - @override - def _reconstruct_isl_object(self) -> IslExpressionLike: - return self._obj - # FIXME: Self is used here is because _NamedExpressionLike is generic, # leading to complaints from basedpyright - def __add__(self, other: Self) -> Self: + def __add__(self, other: Self | int) -> Self: + if isinstance(other, int): + return replace( + self, + _obj=operator.add(self._obj, other), + _name_to_dim=self._name_to_dim, + _dimtype_to_names=self._dimtype_to_names + ) + return _align_and_apply_binary_op(self, other, operator.add) - def __sub__(self, other: Self) -> Self: + def __sub__(self, other: Self | int) -> Self: + if isinstance(other, int): + return replace( + self, + _obj=operator.sub(self._obj, other), + _name_to_dim=self._name_to_dim, + _dimtype_to_names=self._dimtype_to_names + ) + return _align_and_apply_binary_op(self, other, operator.sub) - def __mul__(self, other: Self) -> Self: + def __mul__(self, other: Self | int) -> Self: + if isinstance(other, int): + return replace( + self, + _obj=operator.mul(self._obj, other), + _name_to_dim=self._name_to_dim, + _dimtype_to_names=self._dimtype_to_names + ) + return _align_and_apply_binary_op(self, other, operator.mul) + def is_zero(self) -> bool: + return self._reconstruct_isl_object().is_zero() + @override def __eq__(self, other: object) -> bool: - assert type(other) is type(self) - return self._obj == other._obj + raise NotImplementedError @dataclass(frozen=True, eq=False) @@ -70,11 +94,6 @@ class _NamedPwExpressionLike(_NamedExpressionLike): ... -@dataclass(frozen=True, eq=False) -class _NamedMultiExpressionLike(_NamedExpressionLike): - ... - - @final @dataclass(frozen=True, eq=False) class Aff(_NamedExpressionLike): @@ -95,13 +114,43 @@ def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: obj = isl.Aff(src, ctx) if isinstance(src, str) else src aff_obj, dimtype_to_names = _deconstruct_object(obj) - - assert isinstance(aff_obj, isl.Aff) aff_obj, name_to_dim = _strip_names(aff_obj) return Aff(aff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args +@final +@dataclass(frozen=True, eq=False) +class QPolynomial(_NamedExpressionLike): + _obj: isl.QPolynomial + + +@overload +def make_qpolynomial(src: str, ctx: isl.Context | None = None) -> QPolynomial: + ... + + +@overload +def make_qpolynomial(src: isl.QPolynomial) -> QPolynomial: + ... + + +def make_qpolynomial( + src: str | isl.QPolynomial, ctx: isl.Context | None = None) -> QPolynomial: + # NOTE: ISL does not have a QPolynomial constructor, but we can make one + # here by first creating a PwQPolynomial, then taking the only QPolynomial + # that comes out of it :shrug: + obj = ( + isl.PwQPolynomial(src, ctx).get_pieces()[0][1] if isinstance(src, str) + else src + ) + + qp_obj, dimtype_to_names = _deconstruct_object(obj) + qp_obj, name_to_dim = _strip_names(qp_obj) + + return QPolynomial(qp_obj, name_to_dim, dimtype_to_names) + + @final @dataclass(frozen=True, eq=False) class PwAff(_NamedPwExpressionLike): @@ -109,30 +158,78 @@ class PwAff(_NamedPwExpressionLike): @overload -def make_pwaff(src: str, ctx: isl.Context | None = None) -> PwAff: +def make_pw_aff(src: str, ctx: isl.Context | None = None) -> PwAff: ... @overload -def make_pwaff(src: isl.PwAff) -> PwAff: +def make_pw_aff(src: isl.PwAff) -> PwAff: ... -def make_pwaff(src: str | isl.PwAff, ctx: isl.Context | None = None) -> PwAff: +def make_pw_aff(src: str | isl.PwAff, ctx: isl.Context | None = None) -> PwAff: obj = isl.PwAff(src, ctx) if isinstance(src, str) else src pwaff_obj, dimtype_to_names = _deconstruct_object(obj) - - assert isinstance(obj, isl.PwAff) - pwaff_obj, name_to_dim = _strip_names(obj) + pwaff_obj, name_to_dim = _strip_names(pwaff_obj) return PwAff(pwaff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args +@final +@dataclass(frozen=True, eq=False) +class PwQPolynomial(_NamedPwExpressionLike): + _obj: isl.PwQPolynomial + + +@overload +def make_pw_qpolynomial( + src: str, ctx: isl.Context | None = None) -> PwQPolynomial: + ... + + +@overload +def make_pw_qpolynomial(src: isl.PwQPolynomial) -> PwQPolynomial: + ... + + +def make_pw_qpolynomial( + src: str | isl.PwQPolynomial, + ctx: isl.Context | None = None + ) -> PwQPolynomial: + obj = isl.PwQPolynomial(src, ctx) if isinstance(src, str) else src + + pw_qp_obj, dimtype_to_names = _deconstruct_object(obj) + pw_qp_obj, name_to_dim = _strip_names(pw_qp_obj) + + return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) + +# }}} + + +# {{{ multi expression-likes (multiaff, pwmultiaff) + +@dataclass(frozen=True, eq=False) +class _NamedMultiExpressionLike(NamedIslObject[isl.Set]): + """ + Multi-expressions in ISL cannot have dimensions moved. As a workaround, we + represent multi-expressions as sets internally. This is done during + deconstruction by converting a multi-expression to a map, then converting + the resulting map to a set. During reconstruction, we simply follow the + deconstruction steps backwards (set -> map -> multi-expression). As such, + reconstruction is special-cased for each subclass. + """ + ... + + @final @dataclass(frozen=True, eq=False) class PwMultiAff(_NamedMultiExpressionLike): - _obj: isl.PwMultiAff + @override + def _reconstruct_isl_object(self) -> isl.PwMultiAff: + # deconstruction: isl.PwMultiAff -> isl.Map -> isl.Set + # reconstruction: isl.Set -> isl.Map -> isl.PwMultiAff + return super()._reconstruct_isl_object().as_pw_multi_aff() @overload @@ -153,8 +250,38 @@ def make_pw_multi_aff( obj = isl.PwMultiAff(src, ctx) if isinstance(src, str) else src pw_maff_obj, dimtype_to_names = _deconstruct_object(obj) - - assert isinstance(pw_maff_obj, isl.PwMultiAff) pw_maff_obj, name_to_dim = _strip_names(pw_maff_obj) return PwMultiAff(pw_maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class MultiAff(NamedIslObject[isl.Set]): + @override + def _reconstruct_isl_object(self) -> isl.MultiAff: + # deconstruction: isl.MultiAff -> isl.Map -> isl.Set + # reconstruction: isl.Set -> isl.Map -> isl.PwMultiAff -> isl.MultiAff + return super()._reconstruct_isl_object().as_pw_multi_aff().as_multi_aff() + + +@overload +def make_multi_aff(src: str, ctx: isl.Context | None = None) -> MultiAff: + ... + + +@overload +def make_multi_aff(src: isl.MultiAff) -> MultiAff: + ... + + +def make_multi_aff( + src: str | isl.MultiAff, ctx: isl.Context | None = None) -> MultiAff: + obj = isl.MultiAff(src, ctx) if isinstance(src, str) else src + + maff_obj, dimtype_to_names = _deconstruct_object(obj) + maff_obj, name_to_dim = _strip_names(maff_obj) + + return MultiAff(maff_obj, name_to_dim, dimtype_to_names) + +# }}} diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 12f70b2..29fcb1d 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -36,17 +36,15 @@ import islpy as isl from .core import ( - IslSetLike, NamedIslObject, NameToDim, _align_and_apply_binary_op, _align_two, _deconstruct_object, _find_contiguous_dim_chunks, - _restore_names, _strip_names, ) -from .expression_like import PwAff, PwMultiAff, make_pw_multi_aff, make_pwaff +from .expression_like import PwAff, PwMultiAff, make_pw_aff, make_pw_multi_aff if TYPE_CHECKING: @@ -61,49 +59,6 @@ class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): [ (set names), (input names), (parameter names) ] """ - @override - def _reconstruct_isl_object(self) -> IslSetLike: - """ - Relies on the dimension type ordering in - :func:`_deconstruct_set_like_object`. - """ - - obj = _restore_names(self._obj, self._name_to_dim) - assert isinstance(obj, isl.Set) - - if self._has_params: - if self._parameter_dim_start is None: - raise ValueError( - "Object has parameter dimensions, but a starting index for " - "parameter names is not given. Reconstruction is not " - "possible") - - param_start = self._parameter_dim_start - obj = obj.move_dims( - isl.dim_type.param, 0, - isl.dim_type.set, param_start, len(self._parameter_names) - ) - - if self._has_inputs: - if self._input_dim_start is None: - raise ValueError( - "Object has input dimensions, but a starting index for " - "input names is not given. Reconstruction is not " - "possible") - - obj_domain = isl.Set("{ [] }") - obj_range = obj - - obj = isl.Map.from_domain_and_range(obj_domain, obj_range) - - inp_start = self._input_dim_start - obj = obj.move_dims( - isl.dim_type.in_, 0, - isl.dim_type.set, inp_start, len(self._input_names) - ) - - return obj - def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: return replace( self, @@ -198,10 +153,10 @@ def project_out_except( return self.project_out(names_to_project_out) def dim_max(self, name: str) -> PwAff: - return make_pwaff(self._obj.dim_max(self._name_to_dim[name])) + return make_pw_aff(self._obj.dim_max(self._name_to_dim[name])) def dim_min(self, name: str) -> PwAff: - return make_pwaff(self._obj.dim_min(self._name_to_dim[name])) + return make_pw_aff(self._obj.dim_min(self._name_to_dim[name])) def as_pw_multi_aff(self) -> PwMultiAff: return make_pw_multi_aff(self._reconstruct_isl_object().as_pw_multi_aff()) @@ -222,7 +177,7 @@ def __sub__( @override def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): - raise NotImplementedError + raise ValueError("Objects are not of the same type") aligned_self, aligned_other = _align_two(self, other) diff --git a/namedisl/test/test_expression_like.py b/namedisl/test/test_expression_like.py new file mode 100644 index 0000000..9d5454a --- /dev/null +++ b/namedisl/test/test_expression_like.py @@ -0,0 +1,327 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +import islpy as isl + +import namedisl as nisl + + +# {{{ affs + +def test_aff_from_str(): + spec = "[n] -> { [i] -> [2 * i + n] }" + named_aff = nisl.make_aff(spec) + aff = isl.Aff(spec) + + print(named_aff) + print(aff) + + assert aff == named_aff._reconstruct_isl_object() + + +def test_aff_from_aff(): + aff = isl.Aff("[n] -> { [i] -> [2 * i + n] }") + named_aff = nisl.make_aff(aff) + + print(named_aff) + print(aff) + + assert aff == named_aff._reconstruct_isl_object() + + +def test_aff_binary_ops(): + spec = "[n] -> { [i] -> [2 * i + n] }" + named_aff = nisl.make_aff(spec) + aff = isl.Aff(spec) + + aff_p_aff = aff + aff + aff_p_1 = aff + 1 + + naff_p_naff = named_aff + named_aff + naff_p_1 = named_aff + 1 + + assert naff_p_naff._reconstruct_isl_object() == aff_p_aff + assert naff_p_1._reconstruct_isl_object() == aff_p_1 + + aff_s_aff = aff - aff + aff_s_1 = aff - 1 + + naff_s_naff = named_aff - named_aff + naff_s_1 = named_aff - 1 + + assert naff_s_naff._reconstruct_isl_object() == aff_s_aff + assert naff_s_1._reconstruct_isl_object() == aff_s_1 + + aff_m_3 = aff * 3 + naff_m_3 = named_aff * 3 + + assert naff_m_3._reconstruct_isl_object() == aff_m_3 + +# }}} + + +# {{{ pwaffs + +def test_pwaff_from_str(): + spec = "[n] -> { [i] -> [2 * i + n] }" + named_pwaff = nisl.make_pw_aff(spec) + pwaff = isl.PwAff(spec) + + print(named_pwaff) + print(pwaff) + + assert pwaff == named_pwaff._reconstruct_isl_object() + + +def test_pwaff_from_pwaff(): + pwaff = isl.PwAff("[n] -> { [i] -> [2 * i + n] }") + named_pwaff = nisl.make_pw_aff(pwaff) + + print(named_pwaff) + print(pwaff) + + assert pwaff == named_pwaff._reconstruct_isl_object() + + +def test_pwaff_binary_ops(): + spec = "[n] -> { [i] -> [2 * i + n] }" + named_pwaff = nisl.make_pw_aff(spec) + pwaff = isl.PwAff(spec) + + pwaff_p_pwaff = pwaff + pwaff + pwaff_p_1 = pwaff + 1 + + npwaff_p_npwaff = named_pwaff + named_pwaff + npwaff_p_1 = named_pwaff + 1 + + assert npwaff_p_npwaff._reconstruct_isl_object() == pwaff_p_pwaff + assert npwaff_p_1._reconstruct_isl_object() == pwaff_p_1 + + pwaff_s_pwaff = pwaff - pwaff + pwaff_s_1 = pwaff - 1 + + npwaff_s_npwaff = named_pwaff - named_pwaff + npwaff_s_1 = named_pwaff - 1 + + assert npwaff_s_npwaff._reconstruct_isl_object() == pwaff_s_pwaff + assert npwaff_s_1._reconstruct_isl_object() == pwaff_s_1 + + pwaff_m_3 = pwaff * 3 + npwaff_m_3 = named_pwaff * 3 + + assert npwaff_m_3._reconstruct_isl_object() == pwaff_m_3 + +# }}} + + +# {{{ qpolynomials + +def test_qpolynomial_from_str(): + spec = "[n] -> { [i] -> 2 * i + n }" + named_qpolynomial = nisl.make_qpolynomial(spec) + qpolynomial = isl.PwQPolynomial(spec).get_pieces()[0][1] + + print(named_qpolynomial) + print(qpolynomial) + + assert (named_qpolynomial._reconstruct_isl_object() - qpolynomial).is_zero() + + +def test_qpolynomial_from_qpolynomial(): + qpolynomial = isl.PwQPolynomial( + "[n] -> { [i] -> 2 * i + n }").get_pieces()[0][1] + named_qpolynomial = nisl.make_qpolynomial(qpolynomial) + + print(named_qpolynomial) + print(qpolynomial) + + assert (named_qpolynomial._reconstruct_isl_object() - qpolynomial).is_zero() + + +def test_qpolynomial_binary_ops(): + spec = "[n] -> { [i] -> 2 * i + n }" + named_qpolynomial = nisl.make_qpolynomial(spec) + qpolynomial = isl.PwQPolynomial(spec).get_pieces()[0][1] + + qpolynomial_p_qpolynomial = qpolynomial + qpolynomial + qpolynomial_p_1 = qpolynomial + 1 + + nqpolynomial_p_nqpolynomial = named_qpolynomial + named_qpolynomial + nqpolynomial_p_1 = named_qpolynomial + 1 + + assert ( + nqpolynomial_p_nqpolynomial._reconstruct_isl_object() + - + qpolynomial_p_qpolynomial + ).is_zero() + assert ( + nqpolynomial_p_1._reconstruct_isl_object() + - + qpolynomial_p_1 + ).is_zero() + + qpolynomial_s_qpolynomial = qpolynomial - qpolynomial + qpolynomial_s_1 = qpolynomial - 1 + + nqpolynomial_s_nqpolynomial = named_qpolynomial - named_qpolynomial + nqpolynomial_s_1 = named_qpolynomial - 1 + + assert ( + nqpolynomial_s_nqpolynomial._reconstruct_isl_object() + - + qpolynomial_s_qpolynomial + ).is_zero() + assert ( + nqpolynomial_s_1._reconstruct_isl_object() + - + qpolynomial_s_1 + ).is_zero() + + qpolynomial_m_3 = qpolynomial * 3 + nqpolynomial_m_3 = named_qpolynomial * 3 + + assert ( + nqpolynomial_m_3._reconstruct_isl_object() + - + qpolynomial_m_3 + ).is_zero() + + +def test_qpolynomial_permuted_is_zero(): + # forces the objects to be aligned before binary op is applied, whereas + # the test above does not necessarily touch alignment code + + named_qp = nisl.make_qpolynomial( + "[n, m] -> { [a, b] -> a*a + 2*b + n - m }") + named_qp_perm = nisl.make_qpolynomial( + "[m, n] -> { [b, a] -> a*a + 2*b + n - m }") + + assert ( + named_qp - named_qp_perm + )._reconstruct_isl_object().is_zero() + +# }}} + + +# {{{ pwqpolynomials + +def test_pw_qp_from_str(): + spec = "[n] -> { [i] -> 2 * i + n }" + named_pw_qp = nisl.make_pw_qpolynomial(spec) + pw_qp = isl.PwQPolynomial(spec) + + print(named_pw_qp) + print(pw_qp) + + assert (named_pw_qp._reconstruct_isl_object() - pw_qp).is_zero() + + +def test_pw_qp_from_pw_qp(): + pw_qp = isl.PwQPolynomial( + "[n] -> { [i] -> 2 * i + n }") + named_pw_qp = nisl.make_pw_qpolynomial(pw_qp) + + print(named_pw_qp) + print(pw_qp) + + assert (named_pw_qp._reconstruct_isl_object() - pw_qp).is_zero() + + +def test_pw_qp_binary_ops(): + spec = "[n] -> { [i] -> 2 * i + n }" + named_pw_qp = nisl.make_pw_qpolynomial(spec) + pw_qp = isl.PwQPolynomial(spec) + + pw_qp_p_pw_qp = pw_qp + pw_qp + pw_qp_p_1 = pw_qp + 1 + + npw_qp_p_npw_qp = named_pw_qp + named_pw_qp + npw_qp_p_1 = named_pw_qp + 1 + + assert ( + npw_qp_p_npw_qp._reconstruct_isl_object() + - + pw_qp_p_pw_qp + ).is_zero() + assert ( + npw_qp_p_1._reconstruct_isl_object() + - + pw_qp_p_1 + ).is_zero() + + pw_qp_s_pw_qp = pw_qp - pw_qp + pw_qp_s_1 = pw_qp - 1 + + npw_qp_s_npw_qp = named_pw_qp - named_pw_qp + npw_qp_s_1 = named_pw_qp - 1 + + assert ( + npw_qp_s_npw_qp._reconstruct_isl_object() + - + pw_qp_s_pw_qp + ).is_zero() + assert ( + npw_qp_s_1._reconstruct_isl_object() + - + pw_qp_s_1 + ).is_zero() + + pw_qp_m_3 = pw_qp * 3 + npw_qp_m_3 = named_pw_qp * 3 + + assert ( + npw_qp_m_3._reconstruct_isl_object() + - + pw_qp_m_3 + ).is_zero() + + +def test_pw_qpolynomial_permuted_is_zero(): + # forces the objects to be aligned before binary op is applied, whereas + # the test above does not necessarily touch alignment code + + named_qp = nisl.make_pw_qpolynomial( + "[n, m] -> { [a, b] -> a*a + 2*b + n - m }") + named_qp_perm = nisl.make_pw_qpolynomial( + "[m, n] -> { [b, a] -> a*a + 2*b + n - m }") + + assert ( + named_qp - named_qp_perm + ).is_zero() + +# }}} + + +# {{{ multiaffs + +# }}} + + +# {{{ pwmultiaffs + +# }}} diff --git a/namedisl/test/test_set.py b/namedisl/test/test_set.py deleted file mode 100644 index 9cb12a1..0000000 --- a/namedisl/test/test_set.py +++ /dev/null @@ -1,154 +0,0 @@ -from __future__ import annotations - - -__copyright__ = """ -Copyright (C) 2025- University of Illinois Board of Trustees -""" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -import pytest - -import islpy as isl - -import namedisl as nisl -from .utils_for_tests import generate_random_named_set - - -def test_set_from_str() -> None: - s = nisl.make_set("[n] -> { [i]: 0 <= i < n }") - - print(s._obj) - print(s) - - -def test_set_from_set() -> None: - s = isl.Set("[n] -> { [i, j] : 0 <= i, j < n }") - named_set = nisl.make_set(s) - - print(named_set._obj) - print(named_set) - - -@pytest.mark.parametrize("ndims", [2, 3, 4, 5]) -@pytest.mark.parametrize("has_params", [True, False]) -def test_set_equality(ndims: int, has_params: bool): - a_param = "n" if has_params else None - - a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) - - from itertools import permutations - for perm in list(permutations(a_dims.split(","))): - perm_dims = ",".join(p for p in perm) - set_str = f"{{ [{perm_dims}] : {a_cond} }}" - if has_params: - set_str = f"[{a_param}] ->" + set_str - perm_set = nisl.make_set(set_str) - - assert a == perm_set - - -@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) -@pytest.mark.parametrize("has_params", [True, False]) -def test_set_union(ndims: int, has_params: bool): - - if has_params: - a_param = "n" - b_param = "m" - else: - a_param = None - b_param = None - - a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) - b, b_dims, b_cond = generate_random_named_set(ndims, "b", b_param) - - set_str = f"{{ [{a_dims}, {b_dims}] : ({a_cond}) or ({b_cond})}}" - if has_params: - set_str = "[n, m] -> " + set_str - - result = nisl.make_set(set_str) - - assert (a | b) == result - - -@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) -@pytest.mark.parametrize("has_params", [True, False]) -def test_set_intersection(ndims: int, has_params: bool): - - if has_params: - a_param = "n" - b_param = "m" - else: - a_param = None - b_param = None - - a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) - b, b_dims, b_cond = generate_random_named_set(ndims, "b", b_param) - - set_str = f"{{ [{a_dims}, {b_dims}] : ({a_cond}) and ({b_cond})}}" - if has_params: - set_str = "[n, m] -> " + set_str - - result = nisl.make_set(set_str) - - assert (a & b) == result - - -@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) -def test_set_eliminate(ndims: int): - a, a_dims, _ = generate_random_named_set(ndims, "a", None) - a = a.eliminate(a_dims.split(",")) - - assert a == nisl.make_set(f"{{[{a_dims}]}}") - - -@pytest.mark.parametrize("ndims", [2, 4, 8]) -def test_set_project_out(ndims: int): - a, a_dims, _ = generate_random_named_set(ndims, "a", None) - a = a.project_out(a_dims.split(",")) - - assert a == nisl.make_set("{[]}") - - -@pytest.mark.parametrize("ndims", [2, 4, 8]) -def test_set_dim_max(ndims: int): - a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) - - cond_pw_affs = [ - isl.PwAff(f"{{ [{cond.split('<')[2].strip(' ')}] }}") - for cond in a_cond.split("and") - ] - - for i, name in enumerate(a_dims.split(",")): - assert a.dim_max(name)._obj == (cond_pw_affs[i] - 1) - - -@pytest.mark.parametrize("ndims", [2, 4, 8]) -def test_set_dim_min(ndims: int): - a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) - - cond_pw_affs = [ - isl.PwAff(f"{{ [{cond.split('<')[0].strip(' ')}] }}") - for cond in a_cond.split("and") - ] - - for i, name in enumerate(a_dims.split(",")): - assert a.dim_min(name)._obj == cond_pw_affs[i] diff --git a/namedisl/test/test_map.py b/namedisl/test/test_set_like.py similarity index 63% rename from namedisl/test/test_map.py rename to namedisl/test/test_set_like.py index 5795569..8d9d19c 100644 --- a/namedisl/test/test_map.py +++ b/namedisl/test/test_set_like.py @@ -30,9 +30,138 @@ import islpy as isl import namedisl as nisl -from .utils_for_tests import generate_random_named_map +from .utils_for_tests import generate_random_named_map, generate_random_named_set +# {{{ sets + +def test_set_from_str() -> None: + s = nisl.make_set("[n] -> { [i]: 0 <= i < n }") + + print(s._obj) + print(s) + + +def test_set_from_set() -> None: + s = isl.Set("[n] -> { [i, j] : 0 <= i, j < n }") + named_set = nisl.make_set(s) + + print(named_set._obj) + print(named_set) + + +@pytest.mark.parametrize("ndims", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_equality(ndims: int, has_params: bool): + a_param = "n" if has_params else None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + + from itertools import permutations + for perm in list(permutations(a_dims.split(","))): + perm_dims = ",".join(p for p in perm) + set_str = f"{{ [{perm_dims}] : {a_cond} }}" + if has_params: + set_str = f"[{a_param}] ->" + set_str + perm_set = nisl.make_set(set_str) + + assert a == perm_set + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_union(ndims: int, has_params: bool): + + if has_params: + a_param = "n" + b_param = "m" + else: + a_param = None + b_param = None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + b, b_dims, b_cond = generate_random_named_set(ndims, "b", b_param) + + set_str = f"{{ [{a_dims}, {b_dims}] : ({a_cond}) or ({b_cond})}}" + if has_params: + set_str = "[n, m] -> " + set_str + + result = nisl.make_set(set_str) + + assert (a | b) == result + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_intersection(ndims: int, has_params: bool): + + if has_params: + a_param = "n" + b_param = "m" + else: + a_param = None + b_param = None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + b, b_dims, b_cond = generate_random_named_set(ndims, "b", b_param) + + set_str = f"{{ [{a_dims}, {b_dims}] : ({a_cond}) and ({b_cond})}}" + if has_params: + set_str = "[n, m] -> " + set_str + + result = nisl.make_set(set_str) + + assert (a & b) == result + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +def test_set_eliminate(ndims: int): + a, a_dims, _ = generate_random_named_set(ndims, "a", None) + a = a.eliminate(a_dims.split(",")) + + assert a == nisl.make_set(f"{{[{a_dims}]}}") + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_project_out(ndims: int): + a, a_dims, _ = generate_random_named_set(ndims, "a", None) + a = a.project_out(a_dims.split(",")) + + assert a == nisl.make_set("{[]}") + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_dim_max(ndims: int): + a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) + + # unnamed, so use isl.PwAff instead of nisl.make_pw_aff + cond_pw_affs = [ + isl.PwAff(f"{{ [] -> [{cond.split('<')[2].strip(' ')}] }}") + for cond in a_cond.split("and") + ] + + for i, name in enumerate(a_dims.split(",")): + assert a.dim_max(name)._reconstruct_isl_object() == (cond_pw_affs[i] - 1) + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_dim_min(ndims: int): + a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) + + # unnamed, so use isl.PwAff instead of nisl.make_pw_aff + cond_pw_affs = [ + isl.PwAff(f"{{ [] -> [{cond.split('<')[0].strip(' ')}] }}") + for cond in a_cond.split("and") + ] + + for i, name in enumerate(a_dims.split(",")): + assert a.dim_min(name)._reconstruct_isl_object() == cond_pw_affs[i] + +# }}} + + +# {{{ maps + def test_map_from_str() -> None: m = nisl.make_map( "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") @@ -211,7 +340,7 @@ def test_map_as_pw_multi_aff(): m = nisl.make_map(spec) m_isl = isl.Map(spec) - assert m.as_pw_multi_aff()._obj == m_isl.as_pw_multi_aff() + assert m.as_pw_multi_aff()._reconstruct_isl_object() == m_isl.as_pw_multi_aff() @pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) @@ -222,8 +351,9 @@ def test_map_dim_max(ndims_domain: int, ndims_range: int): ndims_range, "x_out", None ) + # unnamed, so use isl.PwAff instead of nisl.make_pw_aff in_upper_bound_pw_maffs = [ - isl.PwAff(f"{{ [{int(cond.split('<')[2].strip(' '))}] }}") + isl.PwAff(f"{{ [] -> [{int(cond.split('<')[2].strip(' '))}] }}") for cond in in_conds.split("and") ] @@ -231,8 +361,9 @@ def test_map_dim_max(ndims_domain: int, ndims_range: int): # NOTE: constructing PwAffs assumes starting index of 0, so subtract 1 assert m.dim_max(name)._obj == (in_upper_bound_pw_maffs[i] - 1) + # unnamed, so use isl.PwAff instead of nisl.make_pw_aff out_upper_bound_pw_maffs = [ - isl.PwAff(f"{{ [{int(cond.split('<')[2].strip(' '))}] }}") + isl.PwAff(f"{{ [] -> [{int(cond.split('<')[2].strip(' '))}] }}") for cond in out_conds.split("and") ] @@ -249,18 +380,43 @@ def test_map_dim_min(ndims_domain: int, ndims_range: int): ndims_range, "x_out", None ) + # unnamed, so use isl.PwAff instead of nisl.make_pw_aff in_lower_bound_pw_maffs = [ - isl.PwAff(f"{{ [{int(cond.split('<')[0].strip(' '))}] }}") + isl.PwAff(f"{{ [] -> [{int(cond.split('<')[0].strip(' '))}] }}") for cond in in_conds.split("and") ] for i, name in enumerate(in_names.split(",")): assert m.dim_min(name)._obj == in_lower_bound_pw_maffs[i] + # unnamed, so use isl.PwAff instead of nisl.make_pw_aff out_lower_bound_pw_maffs = [ - isl.PwAff(f"{{ [{int(cond.split('<')[0].strip(' '))}] }}") + isl.PwAff(f"{{ [] -> [{int(cond.split('<')[0].strip(' '))}] }}") for cond in out_conds.split("and") ] for i, name in enumerate(out_names.split(",")): assert m.dim_min(name)._obj == out_lower_bound_pw_maffs[i] + +# }}} + + +# {{{ basic{map, set} + +def test_basic_map_from_str() -> None: + m = nisl.make_basic_map( + "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + + print(m._obj) + print(m) + + +def test_basic_map_from_map() -> None: + m = isl.BasicMap( + "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + named_map = nisl.make_basic_map(m) + + print(named_map._obj) + print(named_map) + +# }}} From d8484541fefa1e9e0b1336ca9a43a7f56b80fd5e Mon Sep 17 00:00:00 2001 From: Addison Date: Fri, 30 Jan 2026 14:09:23 -0600 Subject: [PATCH 16/43] address pylint concerns --- namedisl/expression_like.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index e6668e4..851809b 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -148,7 +148,7 @@ def make_qpolynomial( qp_obj, dimtype_to_names = _deconstruct_object(obj) qp_obj, name_to_dim = _strip_names(qp_obj) - return QPolynomial(qp_obj, name_to_dim, dimtype_to_names) + return QPolynomial(qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @final @@ -202,7 +202,7 @@ def make_pw_qpolynomial( pw_qp_obj, dimtype_to_names = _deconstruct_object(obj) pw_qp_obj, name_to_dim = _strip_names(pw_qp_obj) - return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) + return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args # }}} @@ -282,6 +282,6 @@ def make_multi_aff( maff_obj, dimtype_to_names = _deconstruct_object(obj) maff_obj, name_to_dim = _strip_names(maff_obj) - return MultiAff(maff_obj, name_to_dim, dimtype_to_names) + return MultiAff(maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args # }}} From 157e1197c9373721e48892a8eb909df803649c4f Mon Sep 17 00:00:00 2001 From: Addison Date: Fri, 30 Jan 2026 14:10:48 -0600 Subject: [PATCH 17/43] address ruff concerns; leaving basedpyright + docs failing for now --- namedisl/expression_like.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index 851809b..dcb3333 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -148,7 +148,7 @@ def make_qpolynomial( qp_obj, dimtype_to_names = _deconstruct_object(obj) qp_obj, name_to_dim = _strip_names(qp_obj) - return QPolynomial(qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + return QPolynomial(qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @final @@ -202,7 +202,7 @@ def make_pw_qpolynomial( pw_qp_obj, dimtype_to_names = _deconstruct_object(obj) pw_qp_obj, name_to_dim = _strip_names(pw_qp_obj) - return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args # }}} @@ -282,6 +282,6 @@ def make_multi_aff( maff_obj, dimtype_to_names = _deconstruct_object(obj) maff_obj, name_to_dim = _strip_names(maff_obj) - return MultiAff(maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + return MultiAff(maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args # }}} From 2f1eff05f754a7cfa80a25aba2f4aeaa8461291c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 2 Feb 2026 10:40:05 -0600 Subject: [PATCH 18/43] Toward type happiness? --- namedisl/core.py | 58 ++++++++++++++++++++++--------------- namedisl/expression_like.py | 15 +++++----- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index c1dc3e5..d9272c0 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -30,7 +30,7 @@ from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from importlib import metadata -from typing import Generic, TypeAlias, TypeVar +from typing import Generic, TypeAlias, TypeVar, overload from constantdict import constantdict from typing_extensions import override @@ -47,25 +47,19 @@ IslExpressionLikeT = TypeVar( "IslExpressionLikeT", - isl.Aff, - isl.MultiAff, - isl.PwAff, - isl.QPolynomial, - isl.PwQPolynomial + bound=IslExpressionLike, ) IslSetLikeT = TypeVar( "IslSetLikeT", - isl.BasicSet, - isl.BasicMap, - isl.Set, - isl.Map, + bound=IslSetLike ) IslMultiExpressionLikeT = TypeVar( "IslMultiExpressionLikeT", - isl.PwMultiAff, - isl.MultiAff + bound=IslMultiExpressionLike ) -IslObjectT = TypeVar("IslObjectT", IslSetLike, IslExpressionLike) +IslObjectT = TypeVar("IslObjectT", bound=IslSetLike | IslExpressionLike) + +NamedIslObjectT = TypeVar("NamedIslObjectT", bound="NamedIslObject[IslObject]") NameToDim: TypeAlias = Mapping[str, int] @@ -74,7 +68,7 @@ # alignment DimTypeToNames: TypeAlias = Mapping[isl.dim_type, frozenset[str]] -IslObjectPieces: TypeAlias = tuple[IslSetLike | IslExpressionLike, DimTypeToNames] +IslObjectPieces: TypeAlias = tuple[IslObjectT, DimTypeToNames] __version__ = metadata.version("namedisl") @@ -173,7 +167,23 @@ def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: return frozenset(all_dt_names) -def _deconstruct_object(obj: IslObjectT) -> IslObjectPieces: +@overload +def _deconstruct_object(obj: isl.Map) -> tuple[isl.Set, DimTypeToNames]: + ... + + +@overload +# PwMultiAff doesn't have move_dims, so we're being a bit crooked here. +def _deconstruct_object(obj: isl.PwMultiAff) -> tuple[isl.Set, DimTypeToNames]: + ... + + +@overload +def _deconstruct_object(obj: IslObjectT) -> tuple[IslObjectT, DimTypeToNames]: + ... + + +def _deconstruct_object(obj: IslObjectT) -> tuple[IslObject, DimTypeToNames]: dt_to_names: dict[isl.dim_type, frozenset[str]] = {} if isinstance(obj, IslSetLike | IslMultiExpressionLike): @@ -305,10 +315,10 @@ def _find_joint_name_to_dim( def _align_obj( - named_obj: NamedIslObject[IslObjectT], + named_obj: NamedIslObjectT, ordering: NameToDim, dimtype_to_names: DimTypeToNames - ) -> NamedIslObject[IslObjectT]: + ) -> NamedIslObjectT: new_isl_obj = named_obj._obj running_name_to_dim = dict(named_obj._name_to_dim) @@ -354,9 +364,9 @@ def _align_obj( def _align_two( - named_obj1: NamedIslObject[IslObjectT], - named_obj2: NamedIslObject[IslObjectT] - ) -> tuple[NamedIslObject[IslObjectT], ...]: + named_obj1: NamedIslObjectT, + named_obj2: NamedIslObjectT + ) -> tuple[NamedIslObjectT, ...]: name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, named_obj2) @@ -368,10 +378,10 @@ def _align_two( def _align_and_apply_binary_op( - lhs: NamedIslObject[IslObjectT], - rhs: NamedIslObject[IslObjectT], + lhs: NamedIslObjectT, + rhs: NamedIslObjectT, op: Callable[[IslObjectT, IslObjectT], IslObjectT] - ) -> NamedIslObject[IslObjectT]: + ) -> NamedIslObjectT: lhs, rhs = _align_two(lhs, rhs) result = op(lhs._obj, rhs._obj) @@ -435,7 +445,7 @@ def _parameter_dim_start(self) -> int | None: ) return None - def _reconstruct_isl_object(self) -> IslObject: + def _reconstruct_isl_object(self) -> IslObjectT: """ Relies on the dimension type ordering in :func:`_deconstruct_set_like_object`. diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index dcb3333..d851809 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -27,14 +27,14 @@ import operator from dataclasses import dataclass, replace -from typing import final, overload, override +from typing import final, overload -from typing_extensions import Self +from typing_extensions import Self, override import islpy as isl from .core import ( - IslExpressionLike, + IslExpressionLikeT, NamedIslObject, _align_and_apply_binary_op, _deconstruct_object, @@ -45,7 +45,7 @@ # {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) @dataclass(frozen=True, eq=False) -class _NamedExpressionLike(NamedIslObject[IslExpressionLike]): +class _NamedExpressionLike(NamedIslObject[IslExpressionLikeT]): # FIXME: Self is used here is because _NamedExpressionLike is generic, # leading to complaints from basedpyright def __add__(self, other: Self | int) -> Self: @@ -90,16 +90,15 @@ def __eq__(self, other: object) -> bool: @dataclass(frozen=True, eq=False) -class _NamedPwExpressionLike(_NamedExpressionLike): +class _NamedPwExpressionLike(_NamedExpressionLike[IslExpressionLikeT]): ... @final @dataclass(frozen=True, eq=False) -class Aff(_NamedExpressionLike): +class Aff(_NamedExpressionLike[isl.Aff]): _obj: isl.Aff - @overload def make_aff(src: str, ctx: isl.Context | None = None) -> Aff: ... @@ -121,7 +120,7 @@ def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: @final @dataclass(frozen=True, eq=False) -class QPolynomial(_NamedExpressionLike): +class QPolynomial(_NamedExpressionLike[isl.QPolynomial]): _obj: isl.QPolynomial From 4c397ee3ff149d08c8fe3178ccb0f9dcc7d7e3c9 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 16 Feb 2026 10:20:48 -0600 Subject: [PATCH 19/43] toward making basedpyright happy --- namedisl/core.py | 99 +++++++++++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 38 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index d9272c0..a610fa5 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -38,11 +38,12 @@ import islpy as isl -IslSetLike = isl.BasicSet | isl.BasicMap | isl.Set | isl.Map IslBaseExpressionLike = isl.Aff | isl.QPolynomial IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff + IslExpressionLike = IslBaseExpressionLike | IslPwExpressionLike | IslMultiExpressionLike +IslSetLike = isl.BasicSet | isl.BasicMap | isl.Set | isl.Map IslObject = IslSetLike | IslExpressionLike | IslMultiExpressionLike IslExpressionLikeT = TypeVar( @@ -57,7 +58,11 @@ "IslMultiExpressionLikeT", bound=IslMultiExpressionLike ) -IslObjectT = TypeVar("IslObjectT", bound=IslSetLike | IslExpressionLike) +IslPwExpressionLikeT = TypeVar( + "IslPwExpressionLikeT", + bound=IslPwExpressionLike +) +IslObjectT = TypeVar("IslObjectT", bound=IslObject) NamedIslObjectT = TypeVar("NamedIslObjectT", bound="NamedIslObject[IslObject]") @@ -110,47 +115,67 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: return obj, constantdict(name_to_dim) +@overload +def _restore_names(obj: isl.PwAff, name_to_dim: NameToDim) -> isl.PwAff: + ... + + +@overload +def _restore_names(obj: IslSetLikeT, name_to_dim: NameToDim) -> IslSetLikeT: + ... + + +@overload +def _restore_names(obj: IslPwExpressionLikeT, name_to_dim: NameToDim) -> IslPwExpressionLikeT: + ... + + def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: - if isinstance(obj, isl.PwAff): - pwaff_obj = obj.move_dims( + restored_obj = obj.copy() + if isinstance(restored_obj, isl.PwAff): + # input dimensions cannot be renamed for isl.PwAff, so we first move + # input dims to the parameter space, rename then move back + restored_obj = restored_obj.move_dims( isl.dim_type.param, 0, isl.dim_type.in_, 0, - obj.dim(isl.dim_type.in_) + restored_obj.dim(isl.dim_type.in_) ) for name, dim in name_to_dim.items(): - pwaff_obj = pwaff_obj.set_dim_name( + restored_obj = restored_obj.set_dim_name( isl.dim_type.param, dim, name ) - pwaff_obj = pwaff_obj.get_pw_aff_list().get_at(0) - - pwaff_obj = pwaff_obj.move_dims( + restored_obj = restored_obj.get_pw_aff_list().get_at(0) + restored_obj = restored_obj.move_dims( isl.dim_type.in_, 0, isl.dim_type.param, 0, - pwaff_obj.dim(isl.dim_type.param) + restored_obj.dim(isl.dim_type.param) ) - return pwaff_obj + return restored_obj - if isinstance(obj, IslSetLike): + if isinstance(restored_obj, IslSetLike): dt_to_restore = isl.dim_type.set else: dt_to_restore = isl.dim_type.in_ for name, dim in name_to_dim.items(): - obj = obj.set_dim_name(dt_to_restore, dim, name) + restored_obj = restored_obj.set_dim_name(dt_to_restore, dim, name) - return obj + if isinstance(restored_obj, isl.UnionPwAff | isl.UnionPwMultiAff): + raise NotImplementedError + return restored_obj -def _get_dim_names(obj: IslObjectT, dt: isl.dim_type) -> frozenset[str]: + +def _get_dim_names(obj: IslObject, dt: isl.dim_type) -> frozenset[str]: all_dt_names: list[str] = [] for dim in range(obj.dim(dt)): @@ -187,57 +212,55 @@ def _deconstruct_object(obj: IslObjectT) -> tuple[IslObject, DimTypeToNames]: dt_to_names: dict[isl.dim_type, frozenset[str]] = {} if isinstance(obj, IslSetLike | IslMultiExpressionLike): - setlike_obj = obj + decon_obj = obj dt_to_names = dict.fromkeys( [isl.dim_type.in_, isl.dim_type.param], frozenset() ) # NOTE: isl.PwMultiAff.move_dims does not exist, represent as map # internally - if isinstance(setlike_obj, IslMultiExpressionLike): - setlike_obj = setlike_obj.as_map() + if isinstance(decon_obj, IslMultiExpressionLike): + decon_obj = decon_obj.as_map() for dt in dt_to_names: - dt_to_names[dt] = _get_dim_names(setlike_obj, dt) + dt_to_names[dt] = _get_dim_names(decon_obj, dt) if dt_to_names[dt]: - setlike_obj = setlike_obj.move_dims( + decon_obj = decon_obj.move_dims( isl.dim_type.set, - setlike_obj.dim(isl.dim_type.set), + decon_obj.dim(isl.dim_type.set), dt, 0, - setlike_obj.dim(dt) + decon_obj.dim(dt) ) - setlike_obj = ( - setlike_obj.range() - if isinstance(setlike_obj, isl.Map | isl.BasicMap) - else setlike_obj + decon_obj = ( + decon_obj.range() + if isinstance(decon_obj, isl.Map | isl.BasicMap) + else decon_obj ) - setlike_obj = ( - isl.Set.from_basic_set(setlike_obj) - if isinstance(setlike_obj, isl.BasicSet) - else setlike_obj + decon_obj = ( + isl.Set.from_basic_set(decon_obj) + if isinstance(decon_obj, isl.BasicSet) + else decon_obj ) - return setlike_obj, constantdict(dt_to_names) - else: - expr_obj = obj + decon_obj = obj dt_to_names = dict.fromkeys([isl.dim_type.param], frozenset()) - dt_to_names[isl.dim_type.param] = _get_dim_names(expr_obj, + dt_to_names[isl.dim_type.param] = _get_dim_names(decon_obj, isl.dim_type.param) - expr_obj = expr_obj.move_dims( + decon_obj = decon_obj.move_dims( isl.dim_type.in_, - expr_obj.dim(isl.dim_type.in_), + decon_obj.dim(isl.dim_type.in_), isl.dim_type.param, 0, - expr_obj.dim(isl.dim_type.param) + decon_obj.dim(isl.dim_type.param) ) - return expr_obj, constantdict(dt_to_names) + return decon_obj, constantdict(dt_to_names) def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: From 7c97d5c5fbd8a77891531aef56fc317c404cd4d8 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 16 Feb 2026 10:27:28 -0600 Subject: [PATCH 20/43] add property for getting names --- namedisl/core.py | 4 ++++ namedisl/test/test_namedisl.py | 43 ++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 namedisl/test/test_namedisl.py diff --git a/namedisl/core.py b/namedisl/core.py index a610fa5..75cba24 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -422,6 +422,10 @@ class NamedIslObject(Generic[IslObjectT], ABC): # used to reconstruct ISL object _dimtype_to_names: DimTypeToNames + @property + def names(self) -> frozenset[str]: + return frozenset(self._name_to_dim.keys()) + @property def _has_inputs(self) -> bool: return ( diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py new file mode 100644 index 0000000..021c8d7 --- /dev/null +++ b/namedisl/test/test_namedisl.py @@ -0,0 +1,43 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import pytest + +from .utils_for_tests import generate_random_named_set + + +@pytest.mark.parametrize("ndims", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_names(ndims: int, has_params: bool): + s_param = "n" if has_params else None + s, s_dims, _ = generate_random_named_set(ndims, "s", s_param) + names = frozenset(s_dims.split(",")) + + if s_param: + names = names | frozenset({s_param}) + + assert s.names == names From 4d5555c747e71b19cda01844af08b8f038ac0398 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 16 Feb 2026 15:12:29 -0600 Subject: [PATCH 21/43] formatting fix --- namedisl/expression_like.py | 1 + 1 file changed, 1 insertion(+) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index d851809..b176695 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -99,6 +99,7 @@ class _NamedPwExpressionLike(_NamedExpressionLike[IslExpressionLikeT]): class Aff(_NamedExpressionLike[isl.Aff]): _obj: isl.Aff + @overload def make_aff(src: str, ctx: isl.Context | None = None) -> Aff: ... From 1f33281c84ce3635a18e7d83d56def5f505c7658 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 16 Feb 2026 15:24:46 -0600 Subject: [PATCH 22/43] make {_input, _parameter}_names public --- namedisl/core.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index 75cba24..9a0efe3 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -302,14 +302,14 @@ def _find_joint_name_to_dim( :arg:`obj2` within each dimension-type chunk. This ordering is used in alignment before performing operations between two set-like objects. """ - obj1_inp_names = obj1._input_names - obj1_param_names = obj1._parameter_names + obj1_inp_names = obj1.input_names + obj1_param_names = obj1.parameter_names obj1_set_names = ( frozenset(obj1._name_to_dim.keys()) - (obj1_inp_names | obj1_param_names) ) - obj2_inp_names = obj2._input_names - obj2_param_names = obj2._parameter_names + obj2_inp_names = obj2.input_names + obj2_param_names = obj2.parameter_names obj2_set_names = ( frozenset(obj2._name_to_dim.keys()) - (obj2_inp_names | obj2_param_names) ) @@ -435,7 +435,7 @@ def _has_inputs(self) -> bool: ) @property - def _input_names(self) -> frozenset[str]: + def input_names(self) -> frozenset[str]: if self._has_inputs: return self._dimtype_to_names[isl.dim_type.in_] return frozenset() @@ -458,7 +458,7 @@ def _has_params(self) -> bool: ) @property - def _parameter_names(self) -> frozenset[str]: + def parameter_names(self) -> frozenset[str]: if self._has_params: return self._dimtype_to_names[isl.dim_type.param] return frozenset() @@ -493,7 +493,7 @@ def _reconstruct_isl_object(self) -> IslObjectT: param_start = self._parameter_dim_start obj = obj.move_dims( isl.dim_type.param, 0, - internal_dim, param_start, len(self._parameter_names) + internal_dim, param_start, len(self.parameter_names) ) if self._has_inputs: @@ -511,7 +511,7 @@ def _reconstruct_isl_object(self) -> IslObjectT: inp_start = self._input_dim_start obj = obj.move_dims( isl.dim_type.in_, 0, - internal_dim, inp_start, len(self._input_names) + internal_dim, inp_start, len(self.input_names) ) return obj From 12230ae7bc325bfe3b385100b6be33fbe19c19ef Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 10 Mar 2026 13:02:14 -0500 Subject: [PATCH 23/43] fix repo corruption --- namedisl/core.py | 50 ++++++++++++++++++++++++++-- namedisl/set_like.py | 7 ++++ namedisl/tags.py | 57 ++++++++++++++++++++++++++++++++ namedisl/test/test_namedisl.py | 17 +++++++++- namedisl/test/utils_for_tests.py | 11 ++++-- 5 files changed, 137 insertions(+), 5 deletions(-) create mode 100644 namedisl/tags.py diff --git a/namedisl/core.py b/namedisl/core.py index 9a0efe3..fc2ee8d 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -26,16 +26,24 @@ """ import re -from abc import ABC +from abc import ABC, abstractmethod from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from importlib import metadata from typing import Generic, TypeAlias, TypeVar, overload from constantdict import constantdict -from typing_extensions import override +from typing_extensions import Self, override import islpy as isl +ISL_DIM_TYPES = [ + isl.dim_type.out, + isl.dim_type.in_, + isl.dim_type.set, + isl.dim_type.param +] + +from namedisl.tags import _TaggedName IslBaseExpressionLike = isl.Aff | isl.QPolynomial @@ -422,6 +430,44 @@ class NamedIslObject(Generic[IslObjectT], ABC): # used to reconstruct ISL object _dimtype_to_names: DimTypeToNames + def add_names(self, tagged_names_to_add: Sequence[_TaggedName]) -> Self: + + if isinstance(self._obj, isl.PwMultiAff): + raise NotImplementedError + + new_obj = self._obj + new_name_to_dim = dict(self._name_to_dim) + new_dt_to_names: Mapping[isl.dim_type, frozenset[str]] = dict.fromkeys( + ISL_DIM_TYPES, frozenset() + ) + + for tagged_name in tagged_names_to_add: + name = tagged_name.name + dt = tagged_name._isl_dim_type + + new_dt_to_names[dt] |= frozenset({name}) + + # get rid of unused keys + new_dt_to_names = { + dt : new_dt_to_names[dt] + for dt in new_dt_to_names if new_dt_to_names[dt] + } + + for dt in new_dt_to_names: + if dt == isl.dim_type.out or dt == isl.dim_type.set: + start = 0 + elif dt == isl.dim_type.in_: + start = self._input_dim_start + else: + start = self._parameter_dim_start + + new_obj = new_obj.insert_dims(dt, start, len(new_dt_to_names[dt])) + + return type(self)( + new_obj, + constantdict(new_name_to_dim), + constantdict(new_dt_to_names)) + @property def names(self) -> frozenset[str]: return frozenset(self._name_to_dim.keys()) diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 29fcb1d..fe7902a 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -191,6 +191,13 @@ def __eq__(self, other: object) -> bool: @final @dataclass(frozen=True, eq=False) class BasicSet(_NamedIslSetLike): + @override + def add_input_names(self, names_to_add: Sequence[str]) -> BasicSet: + raise NotImplementedError + + @override + def add_output_names(self, names_to_add: Sequence[str]) -> BasicSet: + raise NotImplementedError @override def _reconstruct_isl_object(self) -> isl.BasicSet: diff --git a/namedisl/tags.py b/namedisl/tags.py new file mode 100644 index 0000000..c110ee5 --- /dev/null +++ b/namedisl/tags.py @@ -0,0 +1,57 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from dataclasses import dataclass +from pytools.tag import Tag + +from islpy import dim_type + + +@dataclass(frozen=True) +class _TaggedName(Tag): + name: str + _isl_dim_type: dim_type + + +@dataclass(frozen=True) +class InputName(_TaggedName): + _isl_dim_type: dim_type = dim_type.in_ + + +@dataclass(frozen=True) +class OutputName(_TaggedName): + _isl_dim_type: dim_type = dim_type.out + + +@dataclass(frozen=True) +class ParameterName(_TaggedName): + _isl_dim_type: dim_type = dim_type.param + + +@dataclass(frozen=True) +class SetName(_TaggedName): + _isl_dim_type: dim_type = dim_type.set diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py index 021c8d7..d150fa2 100644 --- a/namedisl/test/test_namedisl.py +++ b/namedisl/test/test_namedisl.py @@ -27,7 +27,7 @@ import pytest -from .utils_for_tests import generate_random_named_set +from .utils_for_tests import generate_random_named_set, get_name_sequence @pytest.mark.parametrize("ndims", [2, 3, 4, 5]) @@ -41,3 +41,18 @@ def test_names(ndims: int, has_params: bool): names = names | frozenset({s_param}) assert s.names == names + +@pytest.mark.parametrize("ndims", [2, 3, 4, 5]) +@pytest.mark.parametrize("n_names_to_add", [2, 3, 4, 5]) +def test_add_names( + ndims: int, + n_names_to_add: int + ): + + s, s_dims, _ = generate_random_named_set(ndims, "s", None) + new_set_names, _ = get_name_sequence(n_names_to_add, "set") + + from namedisl.tags import SetName + s = s.add_names([SetName(name) for name in new_set_names]) + + print(s) diff --git a/namedisl/test/utils_for_tests.py b/namedisl/test/utils_for_tests.py index a615e72..6688cb8 100644 --- a/namedisl/test/utils_for_tests.py +++ b/namedisl/test/utils_for_tests.py @@ -25,6 +25,7 @@ THE SOFTWARE. """ +from collections.abc import Sequence from random import randint import islpy as isl @@ -36,13 +37,19 @@ NamedMapReturnT = tuple[nisl.Map, NamedSetReturnT, NamedSetReturnT] +def get_name_sequence(n: int, dim_prefix: str) -> tuple[Sequence[str], str]: + dims = [f"{dim_prefix}_{i}" for i in range(n)] + dim_str = ",".join(d for d in dims) + + return dims, dim_str + + def generate_random_named_set( ndims: int, dim_prefix: str, param: str | None ) -> NamedSetReturnT: - dims = [f"{dim_prefix}_{i}" for i in range(ndims)] - dim_str = ",".join(d for d in dims) + dims, dim_str = get_name_sequence(ndims, dim_prefix) if param is not None: conditions = f"0 <= {dim_str} < {param}" From 6841aba63b8c2f13232f8f5589c2aaf63374dd1d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 10 Mar 2026 13:57:24 -0500 Subject: [PATCH 24/43] Fix ruff --- namedisl/core.py | 25 +++++++++++++++---------- namedisl/expression_like.py | 1 - namedisl/set_like.py | 2 +- namedisl/tags.py | 1 + namedisl/test/test_namedisl.py | 3 ++- namedisl/test/utils_for_tests.py | 6 +++++- pyproject.toml | 4 ++++ 7 files changed, 28 insertions(+), 14 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index fc2ee8d..c8588d1 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -26,16 +26,18 @@ """ import re -from abc import ABC, abstractmethod +from abc import ABC from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from importlib import metadata -from typing import Generic, TypeAlias, TypeVar, overload +from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, overload from constantdict import constantdict from typing_extensions import Self, override import islpy as isl + + ISL_DIM_TYPES = [ isl.dim_type.out, isl.dim_type.in_, @@ -43,7 +45,9 @@ isl.dim_type.param ] -from namedisl.tags import _TaggedName + +if TYPE_CHECKING: + from namedisl.tags import _TaggedName IslBaseExpressionLike = isl.Aff | isl.QPolynomial @@ -134,7 +138,10 @@ def _restore_names(obj: IslSetLikeT, name_to_dim: NameToDim) -> IslSetLikeT: @overload -def _restore_names(obj: IslPwExpressionLikeT, name_to_dim: NameToDim) -> IslPwExpressionLikeT: +def _restore_names( + obj: IslPwExpressionLikeT, + name_to_dim: NameToDim + ) -> IslPwExpressionLikeT: ... @@ -159,7 +166,7 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: ) restored_obj = restored_obj.get_pw_aff_list().get_at(0) - restored_obj = restored_obj.move_dims( + return restored_obj.move_dims( isl.dim_type.in_, 0, isl.dim_type.param, @@ -167,8 +174,6 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: restored_obj.dim(isl.dim_type.param) ) - return restored_obj - if isinstance(restored_obj, IslSetLike): dt_to_restore = isl.dim_type.set else: @@ -423,7 +428,7 @@ def _align_and_apply_binary_op( @dataclass(frozen=True, eq=False) -class NamedIslObject(Generic[IslObjectT], ABC): +class NamedIslObject(ABC, Generic[IslObjectT]): _obj: IslObjectT _name_to_dim: NameToDim @@ -449,12 +454,12 @@ def add_names(self, tagged_names_to_add: Sequence[_TaggedName]) -> Self: # get rid of unused keys new_dt_to_names = { - dt : new_dt_to_names[dt] + dt: new_dt_to_names[dt] for dt in new_dt_to_names if new_dt_to_names[dt] } for dt in new_dt_to_names: - if dt == isl.dim_type.out or dt == isl.dim_type.set: + if dt in (isl.dim_type.out, isl.dim_type.set): start = 0 elif dt == isl.dim_type.in_: start = self._input_dim_start diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index b176695..cac103b 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -219,7 +219,6 @@ class _NamedMultiExpressionLike(NamedIslObject[isl.Set]): deconstruction steps backwards (set -> map -> multi-expression). As such, reconstruction is special-cased for each subclass. """ - ... @final diff --git a/namedisl/set_like.py b/namedisl/set_like.py index fe7902a..5813f94 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -177,7 +177,7 @@ def __sub__( @override def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): - raise ValueError("Objects are not of the same type") + raise TypeError("Objects are not of the same type") aligned_self, aligned_other = _align_two(self, other) diff --git a/namedisl/tags.py b/namedisl/tags.py index c110ee5..65bf176 100644 --- a/namedisl/tags.py +++ b/namedisl/tags.py @@ -26,6 +26,7 @@ """ from dataclasses import dataclass + from pytools.tag import Tag from islpy import dim_type diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py index d150fa2..f6136d8 100644 --- a/namedisl/test/test_namedisl.py +++ b/namedisl/test/test_namedisl.py @@ -42,6 +42,7 @@ def test_names(ndims: int, has_params: bool): assert s.names == names + @pytest.mark.parametrize("ndims", [2, 3, 4, 5]) @pytest.mark.parametrize("n_names_to_add", [2, 3, 4, 5]) def test_add_names( @@ -49,7 +50,7 @@ def test_add_names( n_names_to_add: int ): - s, s_dims, _ = generate_random_named_set(ndims, "s", None) + s, _s_dims, _ = generate_random_named_set(ndims, "s", None) new_set_names, _ = get_name_sequence(n_names_to_add, "set") from namedisl.tags import SetName diff --git a/namedisl/test/utils_for_tests.py b/namedisl/test/utils_for_tests.py index 6688cb8..de64bbd 100644 --- a/namedisl/test/utils_for_tests.py +++ b/namedisl/test/utils_for_tests.py @@ -25,14 +25,18 @@ THE SOFTWARE. """ -from collections.abc import Sequence from random import randint +from typing import TYPE_CHECKING import islpy as isl import namedisl as nisl +if TYPE_CHECKING: + from collections.abc import Sequence + + NamedSetReturnT = tuple[nisl.Set, str, str] NamedMapReturnT = tuple[nisl.Map, NamedSetReturnT, NamedSetReturnT] diff --git a/pyproject.toml b/pyproject.toml index 5305241..07f33c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,10 @@ extend-ignore = [ "RUF067", # __init__ should contain no code ] +[tool.ruff.lint.per-file-ignores] +"pytools/test/*.py" = ["S102"] +"doc/conf.py" = ["S102"] + [tool.ruff.lint.flake8-quotes] docstring-quotes = "double" inline-quotes = "double" From fffd7751038475eb5c11d24b275383e80e5c08d0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 10 Mar 2026 14:15:57 -0500 Subject: [PATCH 25/43] Fix typing of _align_and_apply_binary_op --- namedisl/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index c8588d1..4835e94 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -414,10 +414,10 @@ def _align_two( def _align_and_apply_binary_op( - lhs: NamedIslObjectT, - rhs: NamedIslObjectT, + lhs: NamedIslObject[IslObjectT], + rhs: NamedIslObject[IslObjectT], op: Callable[[IslObjectT, IslObjectT], IslObjectT] - ) -> NamedIslObjectT: + ) -> NamedIslObject[IslObjectT]: lhs, rhs = _align_two(lhs, rhs) result = op(lhs._obj, rhs._obj) From 80f3818ff6375fe0974802beb33e2fd9e84a8877 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 10 Mar 2026 14:16:10 -0500 Subject: [PATCH 26/43] Don't use len in emptiness checks --- namedisl/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index 4835e94..f54d601 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -482,7 +482,7 @@ def _has_inputs(self) -> bool: return ( isl.dim_type.in_ in self._dimtype_to_names and - len(self._dimtype_to_names[isl.dim_type.in_]) > 0 + bool(self._dimtype_to_names[isl.dim_type.in_]) ) @property @@ -505,7 +505,7 @@ def _has_params(self) -> bool: return ( isl.dim_type.param in self._dimtype_to_names and - len(self._dimtype_to_names[isl.dim_type.param]) > 0 + bool(self._dimtype_to_names[isl.dim_type.param]) ) @property From bc9d2af0c11be68f06653a748626420ecf69f977 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 6 Apr 2026 09:20:20 -0500 Subject: [PATCH 27/43] use codex to debug, add new features --- .codex | 0 namedisl/__init__.py | 2 + namedisl/core.py | 123 ++++++++++++-- namedisl/expression_like.py | 13 ++ namedisl/set_like.py | 293 +++++++++++++++++++++++++++++++-- namedisl/test/test_set_like.py | 179 ++++++++++++++++++++ 6 files changed, 585 insertions(+), 25 deletions(-) create mode 100644 .codex diff --git a/.codex b/.codex new file mode 100644 index 0000000..e69de29 diff --git a/namedisl/__init__.py b/namedisl/__init__.py index dc863d4..2152775 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -47,6 +47,7 @@ make_basic_map, make_basic_set, make_map, + make_map_from_domain_and_range, make_set, ) @@ -66,6 +67,7 @@ "make_basic_map", "make_basic_set", "make_map", + "make_map_from_domain_and_range", "make_multi_aff", "make_pw_aff", "make_pw_multi_aff", diff --git a/namedisl/core.py b/namedisl/core.py index f54d601..26395c5 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -98,6 +98,7 @@ "_align_two", "_deconstruct_object", "_find_contiguous_dim_chunks", + "_normalize_dimtype_to_names", "_restore_names", "_strip_names", ] @@ -105,26 +106,114 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: name_to_dim: dict[str, int] = {} + first_occurrence_dim_by_base_name: dict[str, int] = {} + seen_occurrences_by_base_name: dict[str, int] = {} dt_to_strip = ( isl.dim_type.set if isinstance(obj, IslSetLike) else isl.dim_type.in_ ) - for i in range(obj.dim(dt_to_strip)): - if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): - name = obj.space.get_dim_name(dt_to_strip, i) + stripped_obj = obj.copy() + raw_names: list[str] = [] + + for i in range(stripped_obj.dim(dt_to_strip)): + if isinstance(stripped_obj, isl.QPolynomial | isl.PwQPolynomial): + name = stripped_obj.space.get_dim_name(dt_to_strip, i) else: - name = obj.get_dim_name(dt_to_strip, i) + name = stripped_obj.get_dim_name(dt_to_strip, i) if name is None: raise ValueError("unnamed dimension found") - if name in name_to_dim: - raise ValueError(f"non-unique dim name: {name}") + raw_names.append(name) + + logical_name_counts: dict[str, int] = {} + for name in raw_names: + logical_name = name.rstrip("'") + logical_name_counts[logical_name] = logical_name_counts.get(logical_name, 0) + 1 + + for i, raw_name in enumerate(raw_names): + logical_name = raw_name.rstrip("'") + occurrence_index = seen_occurrences_by_base_name.get(logical_name, 0) + canonical_name = logical_name + "'" * occurrence_index + + if raw_name != canonical_name: + stripped_obj = stripped_obj.set_dim_name( + dt_to_strip, + i, + canonical_name + ) + + if occurrence_index > 0: + first_dim = first_occurrence_dim_by_base_name[logical_name] + stripped_obj = stripped_obj.equate( + dt_to_strip, + first_dim, + dt_to_strip, + i + ) + + if occurrence_index == 0: + first_occurrence_dim_by_base_name[logical_name] = i + + seen_occurrences_by_base_name[logical_name] = occurrence_index + 1 + name_to_dim[canonical_name] = i + + return stripped_obj, constantdict(name_to_dim) + - name_to_dim[name] = i +def _get_obj_dim_name(obj: IslObject, dt: isl.dim_type, dim: int) -> str: + if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): + name = obj.space.get_dim_name(dt, dim) + else: + name = obj.get_dim_name(dt, dim) + + if name is None: + raise ValueError("unnamed dimension found") - return obj, constantdict(name_to_dim) + return name + + +def _normalize_dimtype_to_names( + obj: IslObject, + dimtype_to_names: DimTypeToNames + ) -> DimTypeToNames: + if isinstance(obj, IslSetLike | IslMultiExpressionLike): + dim_type = isl.dim_type.set + total_dims = obj.dim(dim_type) + n_in = len(dimtype_to_names.get(isl.dim_type.in_, frozenset())) + n_param = len(dimtype_to_names.get(isl.dim_type.param, frozenset())) + + new_dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} + + if n_in: + start = total_dims - n_param - n_in + new_dimtype_to_names[isl.dim_type.in_] = frozenset( + _get_obj_dim_name(obj, dim_type, dim) + for dim in range(start, start + n_in) + ) + + if n_param: + start = total_dims - n_param + new_dimtype_to_names[isl.dim_type.param] = frozenset( + _get_obj_dim_name(obj, dim_type, dim) + for dim in range(start, start + n_param) + ) + + return constantdict(new_dimtype_to_names) + + total_dims = obj.dim(isl.dim_type.in_) + n_param = len(dimtype_to_names.get(isl.dim_type.param, frozenset())) + if not n_param: + return dimtype_to_names + + start = total_dims - n_param + return constantdict({ + isl.dim_type.param: frozenset( + _get_obj_dim_name(obj, isl.dim_type.in_, dim) + for dim in range(start, start + n_param) + ) + }) @overload @@ -384,8 +473,8 @@ def _align_obj( ) else: - old_dim = new_isl_obj.dim(isl.dim_type.set) - new_isl_obj = new_isl_obj.insert_dims(isl.dim_type.set, target_dim, 1) + old_dim = new_isl_obj.dim(target_dt) + new_isl_obj = new_isl_obj.insert_dims(target_dt, target_dim, 1) # track side effects of inserting/swapping dimensions for n, d in list(running_name_to_dim.items()): @@ -396,6 +485,8 @@ def _align_obj( running_name_to_dim[name] = target_dim + new_isl_obj = _restore_names(new_isl_obj, ordering) + return type(named_obj)(new_isl_obj, ordering, dimtype_to_names) @@ -477,6 +568,15 @@ def add_names(self, tagged_names_to_add: Sequence[_TaggedName]) -> Self: def names(self) -> frozenset[str]: return frozenset(self._name_to_dim.keys()) + def get_space(self) -> isl.Space: + return self._reconstruct_isl_object().get_space() + + def dim(self, dim_type: isl.dim_type) -> int: + return self._reconstruct_isl_object().dim(dim_type) + + def get_dim_name(self, dim_type: isl.dim_type, dim: int) -> str | None: + return self._reconstruct_isl_object().get_dim_name(dim_type, dim) + @property def _has_inputs(self) -> bool: return ( @@ -567,6 +667,9 @@ def _reconstruct_isl_object(self) -> IslObjectT: return obj + def __getattr__(self, name: str): + return getattr(self._reconstruct_isl_object(), name) + @override def __str__(self) -> str: return str(self._reconstruct_isl_object()) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index cac103b..ac74cd6 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -38,6 +38,7 @@ NamedIslObject, _align_and_apply_binary_op, _deconstruct_object, + _normalize_dimtype_to_names, _strip_names, ) @@ -115,6 +116,7 @@ def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: aff_obj, dimtype_to_names = _deconstruct_object(obj) aff_obj, name_to_dim = _strip_names(aff_obj) + dimtype_to_names = _normalize_dimtype_to_names(aff_obj, dimtype_to_names) return Aff(aff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -147,6 +149,7 @@ def make_qpolynomial( qp_obj, dimtype_to_names = _deconstruct_object(obj) qp_obj, name_to_dim = _strip_names(qp_obj) + dimtype_to_names = _normalize_dimtype_to_names(qp_obj, dimtype_to_names) return QPolynomial(qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -172,6 +175,7 @@ def make_pw_aff(src: str | isl.PwAff, ctx: isl.Context | None = None) -> PwAff: pwaff_obj, dimtype_to_names = _deconstruct_object(obj) pwaff_obj, name_to_dim = _strip_names(pwaff_obj) + dimtype_to_names = _normalize_dimtype_to_names(pwaff_obj, dimtype_to_names) return PwAff(pwaff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -201,6 +205,7 @@ def make_pw_qpolynomial( pw_qp_obj, dimtype_to_names = _deconstruct_object(obj) pw_qp_obj, name_to_dim = _strip_names(pw_qp_obj) + dimtype_to_names = _normalize_dimtype_to_names(pw_qp_obj, dimtype_to_names) return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -224,6 +229,9 @@ class _NamedMultiExpressionLike(NamedIslObject[isl.Set]): @final @dataclass(frozen=True, eq=False) class PwMultiAff(_NamedMultiExpressionLike): + def get_at(self, dim: int) -> PwAff: + return make_pw_aff(self._reconstruct_isl_object().get_at(dim)) + @override def _reconstruct_isl_object(self) -> isl.PwMultiAff: # deconstruction: isl.PwMultiAff -> isl.Map -> isl.Set @@ -250,6 +258,7 @@ def make_pw_multi_aff( pw_maff_obj, dimtype_to_names = _deconstruct_object(obj) pw_maff_obj, name_to_dim = _strip_names(pw_maff_obj) + dimtype_to_names = _normalize_dimtype_to_names(pw_maff_obj, dimtype_to_names) return PwMultiAff(pw_maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -257,6 +266,9 @@ def make_pw_multi_aff( @final @dataclass(frozen=True, eq=False) class MultiAff(NamedIslObject[isl.Set]): + def get_at(self, dim: int) -> Aff: + return make_aff(self._reconstruct_isl_object().get_at(dim)) + @override def _reconstruct_isl_object(self) -> isl.MultiAff: # deconstruction: isl.MultiAff -> isl.Map -> isl.Set @@ -280,6 +292,7 @@ def make_multi_aff( maff_obj, dimtype_to_names = _deconstruct_object(obj) maff_obj, name_to_dim = _strip_names(maff_obj) + dimtype_to_names = _normalize_dimtype_to_names(maff_obj, dimtype_to_names) return MultiAff(maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 5813f94..dcec262 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -38,17 +38,18 @@ from .core import ( NamedIslObject, NameToDim, - _align_and_apply_binary_op, + _align_obj, _align_two, _deconstruct_object, _find_contiguous_dim_chunks, + _normalize_dimtype_to_names, _strip_names, ) from .expression_like import PwAff, PwMultiAff, make_pw_aff, make_pw_multi_aff if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence @dataclass(frozen=True, eq=False) @@ -139,40 +140,72 @@ def project_out(self: _NamedIslSetLike, def project_out_except( self: _NamedIslSetLike, - names_to_keep: str | Sequence[str] + names_to_keep: str | Sequence[str], + dim_types: Sequence[isl.dim_type] | None = None ) -> _NamedIslSetLike: if isinstance(names_to_keep, str): - names_to_keep = [names_to_keep] + names_to_keep = [names_to_keep] if names_to_keep else [] + + considered_names = set(self._name_to_dim) + if dim_types is not None: + considered_names = set() + for dim_type in dim_types: + if dim_type == isl.dim_type.param: + considered_names |= set(self.parameter_names) + elif dim_type in ( + isl.dim_type.set, + isl.dim_type.out, + isl.dim_type.in_ + ): + considered_names |= ( + set(self._name_to_dim) + - set(self.parameter_names) + - set(self.input_names) + ) names_to_project_out = [ - name for name in self._name_to_dim + name for name in considered_names if name not in names_to_keep ] return self.project_out(names_to_project_out) - def dim_max(self, name: str) -> PwAff: - return make_pw_aff(self._obj.dim_max(self._name_to_dim[name])) + def dim_max(self, name: str | int) -> PwAff: + dim = name if isinstance(name, int) else self._name_to_dim[name] + return make_pw_aff(self._obj.dim_max(dim)) - def dim_min(self, name: str) -> PwAff: - return make_pw_aff(self._obj.dim_min(self._name_to_dim[name])) + def dim_min(self, name: str | int) -> PwAff: + dim = name if isinstance(name, int) else self._name_to_dim[name] + return make_pw_aff(self._obj.dim_min(dim)) def as_pw_multi_aff(self) -> PwMultiAff: return make_pw_multi_aff(self._reconstruct_isl_object().as_pw_multi_aff()) + @override + def dim(self, dim_type: isl.dim_type) -> int: + if dim_type == isl.dim_type.out: + dim_type = isl.dim_type.set + return self._reconstruct_isl_object().dim(dim_type) + + @override + def get_dim_name(self, dim_type: isl.dim_type, dim: int) -> str | None: + if dim_type == isl.dim_type.out: + dim_type = isl.dim_type.set + return self._reconstruct_isl_object().get_dim_name(dim_type, dim) + # FIXME: basedpyright is not happy with these function signatures def __and__( self, other: _NamedIslSetLike) -> _NamedIslSetLike: - return _align_and_apply_binary_op(self, other, operator.and_) + return _apply_set_like_binary_op(self, other, operator.and_) def __or__( self, other: _NamedIslSetLike) -> _NamedIslSetLike: - return _align_and_apply_binary_op(self, other, operator.or_) + return _apply_set_like_binary_op(self, other, operator.or_) def __sub__( self, other: _NamedIslSetLike) -> _NamedIslSetLike: - return _align_and_apply_binary_op(self, other, operator.sub) + return _apply_set_like_binary_op(self, other, operator.sub) @override def __eq__(self, other: object) -> bool: @@ -228,6 +261,7 @@ def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> B assert isinstance(set_obj, isl.Set) set_obj, name_to_dim = _strip_names(set_obj) + dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) return BasicSet(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -238,6 +272,152 @@ class Set(_NamedIslSetLike): ... +def _apply_set_like_binary_op( + lhs: _NamedIslSetLike, + rhs: _NamedIslSetLike, + op: Callable[[isl.Set, isl.Set], isl.Set] + ) -> _NamedIslSetLike: + lhs, rhs = _align_two(lhs, rhs) + result = op(lhs._obj, rhs._obj) + + if isinstance(lhs, BasicMap) and isinstance(rhs, BasicMap): + result_type = BasicMap if result.n_basic_set() == 1 else Map + elif isinstance(lhs, BasicSet) and isinstance(rhs, BasicSet): + result_type = BasicSet if result.n_basic_set() == 1 else Set + elif isinstance(lhs, (BasicMap, Map)) and isinstance(rhs, (BasicMap, Map)): + result_type = Map + else: + result_type = Set + + return result_type(result, lhs._name_to_dim, lhs._dimtype_to_names) + + +class _NamedIslMapLike(_NamedIslSetLike): + @override + def _reconstruct_isl_object(self) -> isl.Map: + obj = super()._reconstruct_isl_object() + if isinstance(obj, isl.Set): + return isl.Map.from_domain_and_range(isl.Set("{ [] }"), obj) + return obj + + def _output_names(self) -> frozenset[str]: + return frozenset(self._name_to_dim) - self.input_names - self.parameter_names + + @staticmethod + def _logical_name(name: str) -> str: + return name.rstrip("'") + + def _ordered_names(self, names: frozenset[str]) -> tuple[str, ...]: + return tuple(sorted(names, key=self._name_to_dim.__getitem__)) + + def _ordered_logical_names(self, names: frozenset[str]) -> tuple[str, ...]: + return tuple(self._logical_name(name) for name in self._ordered_names(names)) + + def _actual_names_for_logical_order( + self, + names: frozenset[str], + logical_order: tuple[str, ...] + ) -> tuple[str, ...]: + name_by_logical: dict[str, str] = {} + for name in self._ordered_names(names): + logical_name = self._logical_name(name) + if logical_name in name_by_logical: + raise ValueError( + "multiple dimensions in one interface share the same " + f"logical name: {logical_name}" + ) + name_by_logical[logical_name] = name + + try: + return tuple(name_by_logical[logical_name] for logical_name in logical_order) + except KeyError as exc: + raise ValueError("maps are not composable: interface names differ") from exc + + def _reorder_interface( + self, + dim_type: isl.dim_type, + logical_order: tuple[str, ...] + ) -> _NamedIslMapLike: + interface_names = ( + self.input_names if dim_type == isl.dim_type.in_ else self._output_names() + ) + ordered_names = self._actual_names_for_logical_order( + interface_names, + logical_order + ) + current_names = self._ordered_names(interface_names) + if current_names == ordered_names: + return self + + out_names = ( + ordered_names if dim_type == isl.dim_type.out + else self._ordered_names(self._output_names()) + ) + in_names = ( + ordered_names if dim_type == isl.dim_type.in_ + else self._ordered_names(self.input_names) + ) + param_names = self._ordered_names(self.parameter_names) + + ordering: NameToDim = constantdict({ + name: dim + for dim, name in enumerate((*out_names, *in_names, *param_names)) + }) + + return _align_obj(self, ordering, self._dimtype_to_names) + + def _validate_composable( + self, + lhs_dim_type: isl.dim_type, + other: BasicMap | Map, + rhs_dim_type: isl.dim_type + ) -> tuple[str, ...]: + lhs_names = self.input_names if lhs_dim_type == isl.dim_type.in_ else self._output_names() + rhs_names = other.input_names if rhs_dim_type == isl.dim_type.in_ else other._output_names() + if ( + frozenset(self._logical_name(name) for name in lhs_names) + != + frozenset(self._logical_name(name) for name in rhs_names) + ): + raise ValueError("maps are not composable: interface names differ") + return self._ordered_logical_names(lhs_names) + + def intersect_domain(self, domain: BasicSet | Set) -> Map: + return self & make_map(isl.Map.from_domain_and_range( + domain._reconstruct_isl_object(), + isl.Set.universe(self._reconstruct_isl_object().range().get_space()) + )) + + def intersect_range(self, range_: BasicSet | Set) -> Map: + return self & make_map(isl.Map.from_domain_and_range( + isl.Set.universe(self._reconstruct_isl_object().domain().get_space()), + range_._reconstruct_isl_object() + )) + + def apply_range(self, other: BasicMap | Map) -> Map: + ordered_names = self._validate_composable(isl.dim_type.out, other, isl.dim_type.in_) + other = other._reorder_interface(isl.dim_type.in_, ordered_names) + return make_map( + self._reconstruct_isl_object().apply_range(other._reconstruct_isl_object()) + ) + + def apply_domain(self, other: BasicMap | Map) -> Map: + ordered_names = self._validate_composable(isl.dim_type.in_, other, isl.dim_type.out) + other = other._reorder_interface(isl.dim_type.out, ordered_names) + return make_map( + other._reconstruct_isl_object().apply_range(self._reconstruct_isl_object()) + ) + + def reverse(self) -> Map: + return make_map(self._reconstruct_isl_object().reverse()) + + def domain(self) -> Set: + return make_set(self._reconstruct_isl_object().domain()) + + def range(self) -> Set: + return make_set(self._reconstruct_isl_object().range()) + + @overload def make_set(src: str, ctx: isl.Context | None = None) -> Set: ... @@ -255,18 +435,77 @@ def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: assert isinstance(set_obj, isl.Set) set_obj, name_to_dim = _strip_names(set_obj) + dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) return Set(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @final @dataclass(frozen=True, eq=False) -class BasicMap(_NamedIslSetLike): +class BasicMap(_NamedIslMapLike): + @classmethod + def empty(cls, space: isl.Space) -> BasicMap: + obj = isl.BasicMap.empty(space) + set_obj, dimtype_to_names = _deconstruct_object(obj) + assert isinstance(set_obj, isl.Set) + set_obj, name_to_dim = _strip_names(set_obj) + dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) + return cls(set_obj, name_to_dim, dimtype_to_names) + + def reverse(self) -> BasicMap: + return make_basic_map(self._reconstruct_isl_object().reverse()) + + def domain(self) -> BasicSet: + return make_basic_set(self._reconstruct_isl_object().domain()) + + def range(self) -> BasicSet: + return make_basic_set(self._reconstruct_isl_object().range()) + + def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: + if isinstance(domain, BasicSet): + return self & make_basic_map(isl.BasicMap.from_domain_and_range( + domain._reconstruct_isl_object(), + isl.BasicSet.universe(self._reconstruct_isl_object().range().get_space()) + )) + return super().intersect_domain(domain) + + def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: + if isinstance(range_, BasicSet): + return self & make_basic_map(isl.BasicMap.from_domain_and_range( + isl.BasicSet.universe(self._reconstruct_isl_object().domain().get_space()), + range_._reconstruct_isl_object() + )) + return super().intersect_range(range_) + + def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: + if isinstance(other, BasicMap): + ordered_names = self._validate_composable( + isl.dim_type.out, other, isl.dim_type.in_) + other = other._reorder_interface(isl.dim_type.in_, ordered_names) + return make_basic_map( + self._reconstruct_isl_object().apply_range( + other._reconstruct_isl_object()) + ) + return super().apply_range(other) + + def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: + if isinstance(other, BasicMap): + ordered_names = self._validate_composable( + isl.dim_type.in_, other, isl.dim_type.out) + other = other._reorder_interface(isl.dim_type.out, ordered_names) + return make_basic_map( + other._reconstruct_isl_object().apply_range( + self._reconstruct_isl_object()) + ) + return super().apply_domain(other) @override def _reconstruct_isl_object(self) -> isl.BasicMap: obj = super()._reconstruct_isl_object() + if isinstance(obj, isl.Map) and obj.is_empty(): + return isl.BasicMap.empty(obj.get_space()) + if not isinstance(obj, isl.Map) or obj.n_basic_map() != 1: raise ValueError( "Cannot reconstruct an isl.BasicMap from anything other than " @@ -292,14 +531,37 @@ def make_basic_map(src: str | isl.BasicMap, ctx: isl.Context | None = None) -> B assert isinstance(set_obj, isl.Set) set_obj, name_to_dim = _strip_names(set_obj) + dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) return BasicMap(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args +def make_map_from_domain_and_range( + domain: BasicSet | Set, + range_: BasicSet | Set + ) -> BasicMap | Map: + if isinstance(domain, BasicSet) and isinstance(range_, BasicSet): + return make_basic_map( + isl.BasicMap.from_domain_and_range( + domain._reconstruct_isl_object(), + range_._reconstruct_isl_object() + ) + ) + + return make_map( + isl.Map.from_domain_and_range( + domain._reconstruct_isl_object(), + range_._reconstruct_isl_object() + ) + ) + + @final @dataclass(frozen=True, eq=False) -class Map(_NamedIslSetLike): - ... +class Map(_NamedIslMapLike): + @classmethod + def empty(cls, space: isl.Space) -> Map: + return make_map(isl.Map.empty(space)) @overload @@ -319,5 +581,6 @@ def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: assert isinstance(set_obj, isl.Set) set_obj, name_to_dim = _strip_names(set_obj) + dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) return Map(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args diff --git a/namedisl/test/test_set_like.py b/namedisl/test/test_set_like.py index 8d9d19c..e18297d 100644 --- a/namedisl/test/test_set_like.py +++ b/namedisl/test/test_set_like.py @@ -30,6 +30,9 @@ import islpy as isl import namedisl as nisl +from namedisl.core import _align_two +from loopy.symbolic import pwaff_from_expr +from pymbolic import var from .utils_for_tests import generate_random_named_map, generate_random_named_set @@ -114,6 +117,24 @@ def test_set_intersection(ndims: int, has_params: bool): assert (a & b) == result +def test_basic_set_intersection_promotes_to_set() -> None: + basic = nisl.make_basic_set( + "{ [ii_s, ji_s, k_s] : 0 <= ii_s <= 4 and 0 <= ji_s <= 4 and k_s = 0 }" + ) + footprint = nisl.make_set( + "{ [ii_s, ji_s, k_s] : " + "(0 <= ii_s <= 4 and ji_s = 2 and k_s = 0) or " + "(ii_s = 2 and 0 <= ji_s <= 4 and k_s = 0) }" + ) + + result = basic & footprint + + assert isinstance(result, nisl.Set) + reconstructed = result._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Set) + assert reconstructed.n_basic_set() > 1 + + @pytest.mark.parametrize("ndims", [1, 2, 4, 8]) def test_set_eliminate(ndims: int): a, a_dims, _ = generate_random_named_set(ndims, "a", None) @@ -301,6 +322,164 @@ def test_map_intersection(ndims_domain: int, ndims_range: int, has_params: bool) assert (x & y) == result_map +def test_map_alignment_syncs_internal_output_positions_and_names() -> None: + lhs = nisl.make_map("{ [i] -> [x] }") + rhs = nisl.make_map("{ [i] -> [y, x] }") + + aligned_lhs, aligned_rhs = _align_two(lhs, rhs) + + lhs_names = [ + aligned_lhs._obj.get_dim_name(isl.dim_type.set, dim) + for dim in range(aligned_lhs._obj.dim(isl.dim_type.set)) + ] + rhs_names = [ + aligned_rhs._obj.get_dim_name(isl.dim_type.set, dim) + for dim in range(aligned_rhs._obj.dim(isl.dim_type.set)) + ] + + assert lhs_names == ["x", "y", "i"] + assert rhs_names == ["x", "y", "i"] + + +def test_map_alignment_syncs_internal_input_and_parameter_positions_and_names() -> None: + lhs = nisl.make_map("[n] -> { [i] -> [x] }") + rhs = nisl.make_map("[m, n] -> { [j, i] -> [x] }") + + aligned_lhs, aligned_rhs = _align_two(lhs, rhs) + + lhs_names = [ + aligned_lhs._obj.get_dim_name(isl.dim_type.set, dim) + for dim in range(aligned_lhs._obj.dim(isl.dim_type.set)) + ] + rhs_names = [ + aligned_rhs._obj.get_dim_name(isl.dim_type.set, dim) + for dim in range(aligned_rhs._obj.dim(isl.dim_type.set)) + ] + + assert lhs_names == ["x", "i", "j", "m", "n"] + assert rhs_names == ["x", "i", "j", "m", "n"] + + +def test_map_apply_range_for_compute_a_tile_usage_map() -> None: + bm = 32 + bk = 16 + + compute_map = nisl.make_map(f"""{{ + [is, ks] -> [ii_s, io, ki_s, ko] : + is = io * {bm} + ii_s and + ks = ko * {bk} + ki_s + }}""") + + usage_domain = nisl.make_set( + "{ [i, j, k, io, jo, ko, ii, ji, ki, ii_s, ji_s, ki_s] }" + ) + global_usage_map = nisl.make_map_from_domain_and_range( + usage_domain, + nisl.make_set("{ [is, ks] }") + ) + + local_usage_mpwaff = isl.MultiPwAff.zero(global_usage_map.get_space()) + for idx, expr in enumerate([var("i"), var("k")]): + local_space = local_usage_mpwaff.get_at(idx).get_space().domain() + local_usage_mpwaff = local_usage_mpwaff.set_pw_aff( + idx, + pwaff_from_expr(local_space, expr) + ) + + local_usage_map = nisl.make_map(local_usage_mpwaff.as_map()) + local_usage_map = local_usage_map.intersect_domain( + nisl.make_basic_set( + "{ [i, j, k, io, jo, ko, ii, ji, ki, ii_s, ji_s, ki_s] }" + ) + ) + + global_usage_map = global_usage_map | local_usage_map + composed = global_usage_map.apply_range(compute_map) + + assert frozenset( + name.rstrip("'") for name in composed.input_names + ) == frozenset( + {"i", "ii", "ii_s", "io", "j", "ji", "ji_s", "jo", + "k", "ki", "ki_s", "ko"} + ) + assert composed.range() == nisl.make_set("{ [ii_s, io, ki_s, ko] }") + + +def test_map_apply_domain_accepts_logically_equal_ticked_interface_names() -> None: + lhs = nisl.make_map("{ [x] -> [y] }").apply_range( + nisl.make_map("{ [y] -> [x] }") + ) + rhs = nisl.make_map("{ [p] -> [x] }") + + result = lhs.apply_domain(rhs) + + assert result.range() == nisl.make_set("{ [x] }") + + +def test_map_domain_canonicalizes_single_remaining_ticked_name() -> None: + m = nisl.make_map("{ [x] -> [y] }").apply_range( + nisl.make_map("{ [y] -> [x] }") + ) + + domain = m.domain() + + assert domain == nisl.make_set("{ [x] }") + assert domain.names == frozenset({"x"}) + + +def test_map_empty_from_space_preserves_names_and_is_empty() -> None: + space = isl.Space.create_from_names( + isl.DEFAULT_CONTEXT, + params=["n"], + in_=["i", "j"], + out=["x", "y"] + ) + + m = nisl.Map.empty(space) + + assert m._reconstruct_isl_object().is_empty() + assert m.input_names == frozenset({"i", "j"}) + assert m.range() == nisl.make_set("[n] -> { [x, y] : false }") + + +def test_basic_map_empty_from_space_preserves_names_and_is_empty() -> None: + space = isl.Space.create_from_names( + isl.DEFAULT_CONTEXT, + in_=["i"], + out=["x"] + ) + + m = nisl.BasicMap.empty(space) + + assert m._reconstruct_isl_object().is_empty() + assert m.input_names == frozenset({"i"}) + assert m.range() == nisl.make_basic_set("{ [x] : 1 = 0 }") + + +def test_map_empty_matches_existing_named_space() -> None: + template = nisl.make_map("[n] -> { [i, k] -> [ii_s, io, ki_s, ko] }") + + empty_map = nisl.Map.empty(template.get_space()) + + assert empty_map._reconstruct_isl_object().is_empty() + assert empty_map.input_names == template.input_names + assert empty_map.range()._reconstruct_isl_object().is_empty() + assert empty_map.range().names == template.range().names + + +def test_empty_map_is_identity_for_union() -> None: + space = isl.Space.create_from_names( + isl.DEFAULT_CONTEXT, + in_=["i"], + out=["x"] + ) + empty_map = nisl.Map.empty(space) + nonempty_map = nisl.make_map("{ [i] -> [x] }") + + assert (empty_map | nonempty_map) == nonempty_map + assert (nonempty_map | empty_map) == nonempty_map + + @pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) @pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) def test_map_eliminate(ndims_domain: int, ndims_range: int): From 11480972a7b0f7b613b85ece87109fea7a3e6330 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 16 Apr 2026 09:57:55 -0500 Subject: [PATCH 28/43] tighten abstraction, add more features to set-likes --- .codex | 0 .gitignore | 2 + namedisl/core.py | 543 ++++++++++++++++++++------ namedisl/expression_like.py | 99 +++-- namedisl/set_like.py | 461 ++++++++++++++-------- namedisl/test/test_expression_like.py | 14 + namedisl/test/test_namedisl.py | 256 ++++++++++++ namedisl/test/test_set_like.py | 286 ++++++++++++-- 8 files changed, 1284 insertions(+), 377 deletions(-) delete mode 100644 .codex diff --git a/.codex b/.codex deleted file mode 100644 index e69de29..0000000 diff --git a/.gitignore b/.gitignore index ea396a7..a20f1e1 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ distribute*tar.gz .pylintrc.yml .run-pylint.py + +.codex diff --git a/namedisl/core.py b/namedisl/core.py index 26395c5..f9ebb98 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -27,10 +27,10 @@ import re from abc import ABC -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Collection, Mapping, Sequence from dataclasses import dataclass from importlib import metadata -from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, overload +from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast, overload from constantdict import constantdict from typing_extensions import Self, override @@ -74,7 +74,7 @@ "IslPwExpressionLikeT", bound=IslPwExpressionLike ) -IslObjectT = TypeVar("IslObjectT", bound=IslObject) +IslObjectT = TypeVar("IslObjectT", bound=IslObject, covariant=True) NamedIslObjectT = TypeVar("NamedIslObjectT", bound="NamedIslObject[IslObject]") @@ -85,7 +85,7 @@ # alignment DimTypeToNames: TypeAlias = Mapping[isl.dim_type, frozenset[str]] -IslObjectPieces: TypeAlias = tuple[IslObjectT, DimTypeToNames] +IslObjectPieces: TypeAlias = tuple[IslObject, NameToDim, DimTypeToNames] __version__ = metadata.version("namedisl") @@ -98,23 +98,47 @@ "_align_two", "_deconstruct_object", "_find_contiguous_dim_chunks", + "_make_named_object_pieces", "_normalize_dimtype_to_names", "_restore_names", "_strip_names", ] +def _normalize_public_dim_type(dim_type: isl.dim_type) -> isl.dim_type: + if dim_type == isl.dim_type.out: + return isl.dim_type.set + return dim_type + + +def _ensure_unique_public_names(obj: IslObject) -> None: + if isinstance(obj, IslSetLike | IslMultiExpressionLike): + dim_types = (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param) + else: + dim_types = (isl.dim_type.in_, isl.dim_type.param) + + seen_names: set[str] = set() + for dim_type in dim_types: + for dim in range(obj.dim(dim_type)): + if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): + name = obj.space.get_dim_name(dim_type, dim) + else: + name = obj.get_dim_name(dim_type, dim) + if name is None: + raise ValueError("duplicate or unnamed dimension found") + if name in seen_names: + raise ValueError(f"duplicate dimension name found: {name}") + seen_names.add(name) + + def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: name_to_dim: dict[str, int] = {} - first_occurrence_dim_by_base_name: dict[str, int] = {} - seen_occurrences_by_base_name: dict[str, int] = {} dt_to_strip = ( isl.dim_type.set if isinstance(obj, IslSetLike) else isl.dim_type.in_ ) stripped_obj = obj.copy() - raw_names: list[str] = [] for i in range(stripped_obj.dim(dt_to_strip)): if isinstance(stripped_obj, isl.QPolynomial | isl.PwQPolynomial): @@ -125,41 +149,12 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: if name is None: raise ValueError("unnamed dimension found") - raw_names.append(name) - - logical_name_counts: dict[str, int] = {} - for name in raw_names: - logical_name = name.rstrip("'") - logical_name_counts[logical_name] = logical_name_counts.get(logical_name, 0) + 1 - - for i, raw_name in enumerate(raw_names): - logical_name = raw_name.rstrip("'") - occurrence_index = seen_occurrences_by_base_name.get(logical_name, 0) - canonical_name = logical_name + "'" * occurrence_index + if name in name_to_dim: + raise ValueError(f"duplicate dimension name found: {name}") - if raw_name != canonical_name: - stripped_obj = stripped_obj.set_dim_name( - dt_to_strip, - i, - canonical_name - ) - - if occurrence_index > 0: - first_dim = first_occurrence_dim_by_base_name[logical_name] - stripped_obj = stripped_obj.equate( - dt_to_strip, - first_dim, - dt_to_strip, - i - ) + name_to_dim[name] = i - if occurrence_index == 0: - first_occurrence_dim_by_base_name[logical_name] = i - - seen_occurrences_by_base_name[logical_name] = occurrence_index + 1 - name_to_dim[canonical_name] = i - - return stripped_obj, constantdict(name_to_dim) + return cast("IslObjectT", stripped_obj), constantdict(name_to_dim) def _get_obj_dim_name(obj: IslObject, dt: isl.dim_type, dim: int) -> str: @@ -216,22 +211,12 @@ def _normalize_dimtype_to_names( }) -@overload -def _restore_names(obj: isl.PwAff, name_to_dim: NameToDim) -> isl.PwAff: - ... - - -@overload -def _restore_names(obj: IslSetLikeT, name_to_dim: NameToDim) -> IslSetLikeT: - ... - - -@overload -def _restore_names( - obj: IslPwExpressionLikeT, - name_to_dim: NameToDim - ) -> IslPwExpressionLikeT: - ... +def _make_named_object_pieces(obj: IslObject) -> IslObjectPieces: + _ensure_unique_public_names(obj) + decon_obj, dimtype_to_names = _deconstruct_object(obj) + decon_obj, name_to_dim = _strip_names(decon_obj) + dimtype_to_names = _normalize_dimtype_to_names(decon_obj, dimtype_to_names) + return decon_obj, name_to_dim, dimtype_to_names def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: @@ -255,13 +240,13 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: ) restored_obj = restored_obj.get_pw_aff_list().get_at(0) - return restored_obj.move_dims( + return cast("IslObjectT", restored_obj.move_dims( isl.dim_type.in_, 0, isl.dim_type.param, 0, restored_obj.dim(isl.dim_type.param) - ) + )) if isinstance(restored_obj, IslSetLike): dt_to_restore = isl.dim_type.set @@ -274,7 +259,7 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: if isinstance(restored_obj, isl.UnionPwAff | isl.UnionPwMultiAff): raise NotImplementedError - return restored_obj + return cast("IslObjectT", restored_obj) def _get_dim_names(obj: IslObject, dt: isl.dim_type) -> frozenset[str]: @@ -306,11 +291,11 @@ def _deconstruct_object(obj: isl.PwMultiAff) -> tuple[isl.Set, DimTypeToNames]: @overload -def _deconstruct_object(obj: IslObjectT) -> tuple[IslObjectT, DimTypeToNames]: +def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: ... -def _deconstruct_object(obj: IslObjectT) -> tuple[IslObject, DimTypeToNames]: +def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: dt_to_names: dict[isl.dim_type, frozenset[str]] = {} if isinstance(obj, IslSetLike | IslMultiExpressionLike): @@ -404,25 +389,20 @@ def _find_joint_name_to_dim( :arg:`obj2` within each dimension-type chunk. This ordering is used in alignment before performing operations between two set-like objects. """ - obj1_inp_names = obj1.input_names - obj1_param_names = obj1.parameter_names - obj1_set_names = ( - frozenset(obj1._name_to_dim.keys()) - (obj1_inp_names | obj1_param_names) - ) - - obj2_inp_names = obj2.input_names - obj2_param_names = obj2.parameter_names - obj2_set_names = ( - frozenset(obj2._name_to_dim.keys()) - (obj2_inp_names | obj2_param_names) + all_set_names = ( + obj1._names_for_dim_type(isl.dim_type.set) + | obj2._names_for_dim_type(isl.dim_type.set) ) - - all_inp_names = obj1_inp_names | obj2_inp_names - all_param_names = obj1_param_names | obj2_param_names - all_set_names = obj1_set_names | obj2_set_names + all_inp_names = obj1.input_names | obj2.input_names + all_param_names = obj1.parameter_names | obj2.parameter_names dt_to_names: DimTypeToNames = {} dt_to_names[isl.dim_type.param] = all_param_names - dt_to_names[isl.dim_type.in_] = all_inp_names + if ( + isinstance(obj1._obj, IslSetLike | IslMultiExpressionLike) + or isinstance(obj2._obj, IslSetLike | IslMultiExpressionLike) + ): + dt_to_names[isl.dim_type.in_] = all_inp_names # enforces contiguous ordering of [ (set), (input), (param) ] in set # representation @@ -444,7 +424,10 @@ def _align_obj( ordering: NameToDim, dimtype_to_names: DimTypeToNames ) -> NamedIslObjectT: - new_isl_obj = named_obj._obj + new_isl_obj = cast( + "IslSetLike | IslBaseExpressionLike | IslPwExpressionLike | isl.MultiAff", + named_obj._obj, + ) running_name_to_dim = dict(named_obj._name_to_dim) target_dt = ( @@ -487,7 +470,11 @@ def _align_obj( new_isl_obj = _restore_names(new_isl_obj, ordering) - return type(named_obj)(new_isl_obj, ordering, dimtype_to_names) + return type(named_obj)( + new_isl_obj, + ordering, + dimtype_to_names, + ) def _align_two( @@ -526,109 +513,413 @@ class NamedIslObject(ABC, Generic[IslObjectT]): # used to reconstruct ISL object _dimtype_to_names: DimTypeToNames - def add_names(self, tagged_names_to_add: Sequence[_TaggedName]) -> Self: + @property + def _metadata_input_names(self) -> frozenset[str]: + return self._dimtype_to_names.get(isl.dim_type.in_, frozenset()) + + @property + def _metadata_parameter_names(self) -> frozenset[str]: + return self._dimtype_to_names.get(isl.dim_type.param, frozenset()) + + def _names_for_dim_type(self, dim_type: isl.dim_type) -> frozenset[str]: + dim_type = _normalize_public_dim_type(dim_type) + if dim_type == isl.dim_type.param: + return self.parameter_names + + if isinstance(self._obj, IslSetLike | IslMultiExpressionLike): + if dim_type == isl.dim_type.in_: + return self._metadata_input_names + if dim_type == isl.dim_type.set: + return self.names - self._metadata_input_names - self.parameter_names + else: + if dim_type == isl.dim_type.in_: + return self.names - self.parameter_names + if dim_type == isl.dim_type.set: + return frozenset() + + raise ValueError(f"unsupported dim type: {dim_type}") + + def _ordered_names_for_dim_type(self, dim_type: isl.dim_type) -> tuple[str, ...]: + names = self._names_for_dim_type(dim_type) + return tuple(sorted(names, key=self._name_to_dim.__getitem__)) + + def _ordered_name_chunks(self) -> dict[isl.dim_type, tuple[str, ...]]: + return { + isl.dim_type.set: self._ordered_names_for_dim_type(isl.dim_type.set), + isl.dim_type.in_: self._ordered_names_for_dim_type(isl.dim_type.in_), + isl.dim_type.param: self._ordered_names_for_dim_type(isl.dim_type.param), + } + def _add_names_by_dim_type( + self, + names_to_add: Collection[str], + dim_type: isl.dim_type + ) -> Self: if isinstance(self._obj, isl.PwMultiAff): raise NotImplementedError + dim_type = _normalize_public_dim_type(dim_type) + if dim_type not in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): + raise ValueError(f"unsupported dim type: {dim_type}") + if ( + not isinstance(self._obj, IslSetLike | IslMultiExpressionLike) + and dim_type == isl.dim_type.set + ): + raise ValueError(f"unsupported dim type: {dim_type}") + + if len(set(names_to_add)) != len(tuple(names_to_add)): + raise ValueError("duplicate names to add") + + for name in names_to_add: + if name in self.names: + raise ValueError(f"name already exists: {name}") + + if not names_to_add: + return self + + grouped_names: dict[isl.dim_type, list[str]] = { + isl.dim_type.set: [], + isl.dim_type.in_: [], + isl.dim_type.param: [], + } + grouped_names[dim_type] = list(names_to_add) + + return self._add_grouped_names(grouped_names) + + def _add_grouped_names( + self, + grouped_names: Mapping[isl.dim_type, Collection[str]] + ) -> Self: + if isinstance(self._obj, isl.PwMultiAff): + raise NotImplementedError + + seen_names: set[str] = set() + for names in grouped_names.values(): + for name in names: + if name in seen_names: + raise ValueError("duplicate names to add") + if name in self.names: + raise ValueError(f"name already exists: {name}") + seen_names.add(name) + new_obj = self._obj - new_name_to_dim = dict(self._name_to_dim) - new_dt_to_names: Mapping[isl.dim_type, frozenset[str]] = dict.fromkeys( - ISL_DIM_TYPES, frozenset() + chunk_names = { + dt: list(names) + for dt, names in self._ordered_name_chunks().items() + } + internal_dim_type = ( + isl.dim_type.set + if isinstance(new_obj, IslSetLike | IslMultiExpressionLike) + else isl.dim_type.in_ ) - for tagged_name in tagged_names_to_add: - name = tagged_name.name - dt = tagged_name._isl_dim_type + insertion_starts = { + isl.dim_type.set: 0, + isl.dim_type.in_: len(chunk_names[isl.dim_type.set]), + isl.dim_type.param: ( + len(chunk_names[isl.dim_type.set]) + + len(chunk_names[isl.dim_type.in_]) + ), + } + + for dim_type in (isl.dim_type.param, isl.dim_type.in_, isl.dim_type.set): + names_to_add = grouped_names[dim_type] + if not names_to_add: + continue + new_obj = new_obj.insert_dims( + internal_dim_type, + insertion_starts[dim_type], + len(names_to_add) + ) + chunk_names[dim_type] = [*names_to_add, *chunk_names[dim_type]] + + ordered_names = [ + *chunk_names[isl.dim_type.set], + *chunk_names[isl.dim_type.in_], + *chunk_names[isl.dim_type.param], + ] + new_name_to_dim: NameToDim = constantdict({ + name: dim for dim, name in enumerate(ordered_names) + }) + new_dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} + if ( + isinstance(new_obj, IslSetLike | IslMultiExpressionLike) + and chunk_names[isl.dim_type.in_] + ): + new_dimtype_to_names[isl.dim_type.in_] = frozenset( + chunk_names[isl.dim_type.in_] + ) + if chunk_names[isl.dim_type.param]: + new_dimtype_to_names[isl.dim_type.param] = frozenset( + chunk_names[isl.dim_type.param] + ) - new_dt_to_names[dt] |= frozenset({name}) + return type(self)( + cast("IslObjectT", _restore_names(new_obj, new_name_to_dim)), + new_name_to_dim, + constantdict(new_dimtype_to_names), + ) - # get rid of unused keys - new_dt_to_names = { - dt: new_dt_to_names[dt] - for dt in new_dt_to_names if new_dt_to_names[dt] + def add_names(self, tagged_names_to_add: Sequence[_TaggedName]) -> Self: + grouped_names: dict[isl.dim_type, list[str]] = { + isl.dim_type.set: [], + isl.dim_type.in_: [], + isl.dim_type.param: [], } + for tagged_name in tagged_names_to_add: + dim_type = _normalize_public_dim_type(tagged_name._isl_dim_type) + if dim_type not in grouped_names: + raise ValueError(f"unsupported dim type: {tagged_name._isl_dim_type}") + grouped_names[dim_type].append(tagged_name.name) - for dt in new_dt_to_names: - if dt in (isl.dim_type.out, isl.dim_type.set): - start = 0 - elif dt == isl.dim_type.in_: - start = self._input_dim_start - else: - start = self._parameter_dim_start + return self._add_grouped_names(grouped_names) - new_obj = new_obj.insert_dims(dt, start, len(new_dt_to_names[dt])) + def add_set_names(self, names_to_add: Collection[str]) -> Self: + return self._add_names_by_dim_type(names_to_add, isl.dim_type.set) - return type(self)( - new_obj, - constantdict(new_name_to_dim), - constantdict(new_dt_to_names)) + def add_output_names(self, names_to_add: Collection[str]) -> Self: + return self._add_names_by_dim_type(names_to_add, isl.dim_type.out) + + def add_input_names(self, names_to_add: Collection[str]) -> Self: + return self._add_names_by_dim_type(names_to_add, isl.dim_type.in_) + + def add_parameter_names(self, names_to_add: Collection[str]) -> Self: + return self._add_names_by_dim_type(names_to_add, isl.dim_type.param) + + def add_dim_names( + self, + names_to_add: Collection[str], + dim_type: isl.dim_type + ) -> Self: + return self._add_names_by_dim_type(names_to_add, dim_type) @property def names(self) -> frozenset[str]: return frozenset(self._name_to_dim.keys()) + def dim_names(self, dim_type: isl.dim_type) -> frozenset[str]: + return self._names_for_dim_type(dim_type) + + @property + def set_names(self) -> frozenset[str]: + return self._names_for_dim_type(isl.dim_type.set) + + @property + def output_names(self) -> frozenset[str]: + return self._names_for_dim_type(isl.dim_type.out) + def get_space(self) -> isl.Space: return self._reconstruct_isl_object().get_space() def dim(self, dim_type: isl.dim_type) -> int: + dim_type = _normalize_public_dim_type(dim_type) + if dim_type in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): + return len(self._names_for_dim_type(dim_type)) return self._reconstruct_isl_object().dim(dim_type) - def get_dim_name(self, dim_type: isl.dim_type, dim: int) -> str | None: - return self._reconstruct_isl_object().get_dim_name(dim_type, dim) + def move_dims( + self, + names_to_move: str | Collection[str], + dst_type: isl.dim_type, + ) -> Self: + if isinstance(names_to_move, str): + names_to_move = [names_to_move] + + if not names_to_move: + return self + + dst_type = _normalize_public_dim_type(dst_type) + if dst_type not in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): + raise ValueError(f"unsupported destination dim type: {dst_type}") + + missing_names = [name for name in names_to_move if name not in self.names] + if missing_names: + raise ValueError(f"unknown names: {', '.join(missing_names)}") + + if len(set(names_to_move)) != len(tuple(names_to_move)): + raise ValueError("duplicate names in move_dims") + + names_to_move = [ + name for name in names_to_move + if name not in self._names_for_dim_type(dst_type) + ] + if not names_to_move: + return self + + moved_name_set = set(names_to_move) + chunk_names = { + dt: [ + name for name in names + if name not in moved_name_set + ] + for dt, names in self._ordered_name_chunks().items() + } + moved_names = sorted(names_to_move, key=self._name_to_dim.__getitem__) + chunk_names[dst_type].extend(moved_names) + + new_order = [ + *chunk_names[isl.dim_type.set], + *chunk_names[isl.dim_type.in_], + *chunk_names[isl.dim_type.param], + ] + new_name_to_dim: NameToDim = constantdict({ + name: dim for dim, name in enumerate(new_order) + }) + + new_dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} + if chunk_names[isl.dim_type.in_]: + new_dimtype_to_names[isl.dim_type.in_] = frozenset( + chunk_names[isl.dim_type.in_] + ) + if chunk_names[isl.dim_type.param]: + new_dimtype_to_names[isl.dim_type.param] = frozenset( + chunk_names[isl.dim_type.param] + ) + + return _align_obj( + self, + new_name_to_dim, + constantdict(new_dimtype_to_names) + ) + + def rename_dims(self, renaming: Mapping[str, str]) -> Self: + if not renaming: + return self + + missing_names = [name for name in renaming if name not in self.names] + if missing_names: + raise ValueError(f"unknown names: {', '.join(missing_names)}") + + if len(set(renaming.values())) != len(renaming): + raise ValueError("duplicate destination names in rename_dims") + + unchanged_names = { + old_name for old_name, new_name in renaming.items() + if old_name == new_name + } + renaming = { + old_name: new_name + for old_name, new_name in renaming.items() + if old_name not in unchanged_names + } + if not renaming: + return self + + existing_names = self.names - frozenset(renaming) + conflicting_names = existing_names & frozenset(renaming.values()) + if conflicting_names: + raise ValueError( + "cannot rename to existing names: " + + ", ".join(sorted(conflicting_names)) + ) + + new_name_to_dim: NameToDim = constantdict({ + renaming.get(name, name): dim + for name, dim in self._name_to_dim.items() + }) + new_dimtype_to_names: DimTypeToNames = constantdict({ + dim_type: frozenset( + renaming.get(name, name) + for name in names + ) + for dim_type, names in self._dimtype_to_names.items() + }) + + return type(self)(self._obj, new_name_to_dim, new_dimtype_to_names) + + @overload + def equate_dims(self, name1: Mapping[str, str]) -> Self: + ... + + @overload + def equate_dims(self, name1: str, name2: str) -> Self: + ... + + def equate_dims( + self, + name1: str | Mapping[str, str], + name2: str | None = None, + ) -> Self: + if isinstance(name1, str): + if name2 is None: + raise TypeError("name2 must be provided when name1 is a string") + equated_names = ((name1, name2),) + else: + if name2 is not None: + raise TypeError("name2 cannot be provided when name1 is a mapping") + equated_names = tuple(name1.items()) + + for lhs_name, rhs_name in equated_names: + if lhs_name not in self.names: + raise ValueError(f"unknown name: {lhs_name}") + if rhs_name not in self.names: + raise ValueError(f"unknown name: {rhs_name}") + + if all(lhs_name == rhs_name for lhs_name, rhs_name in equated_names): + return self + + if not isinstance(self._obj, IslSetLike): + raise NotImplementedError( + "equate_dims is only implemented for set-like objects" + ) + + obj = self._obj + for lhs_name, rhs_name in equated_names: + if lhs_name != rhs_name: + obj = obj.equate( + isl.dim_type.set, self._name_to_dim[lhs_name], + isl.dim_type.set, self._name_to_dim[rhs_name], + ) + + return type(self)( + cast("IslObjectT", obj), + self._name_to_dim, + self._dimtype_to_names, + ) @property def _has_inputs(self) -> bool: - return ( - isl.dim_type.in_ in self._dimtype_to_names - and - bool(self._dimtype_to_names[isl.dim_type.in_]) - ) + return bool(self._metadata_input_names) @property def input_names(self) -> frozenset[str]: - if self._has_inputs: - return self._dimtype_to_names[isl.dim_type.in_] - return frozenset() + return self._names_for_dim_type(isl.dim_type.in_) @property def _input_dim_start(self) -> int | None: if self._has_inputs: return min( self._name_to_dim[name] - for name in self._dimtype_to_names[isl.dim_type.in_] + for name in self._metadata_input_names ) return None @property def _has_params(self) -> bool: - return ( - isl.dim_type.param in self._dimtype_to_names - and - bool(self._dimtype_to_names[isl.dim_type.param]) - ) + return bool(self._metadata_parameter_names) @property def parameter_names(self) -> frozenset[str]: - if self._has_params: - return self._dimtype_to_names[isl.dim_type.param] - return frozenset() + return self._metadata_parameter_names @property def _parameter_dim_start(self) -> int | None: if self._has_params: return min( self._name_to_dim[name] - for name in self._dimtype_to_names[isl.dim_type.param] + for name in self._metadata_parameter_names ) return None - def _reconstruct_isl_object(self) -> IslObjectT: + def _reconstruct_isl_object(self) -> IslObject: """ Relies on the dimension type ordering in :func:`_deconstruct_set_like_object`. """ - obj = _restore_names(self._obj, self._name_to_dim) + obj = cast( + "IslSetLike | IslBaseExpressionLike | IslPwExpressionLike | isl.MultiAff", + _restore_names(self._obj, self._name_to_dim), + ) internal_dim = ( isl.dim_type.set if isinstance(obj, isl.Set) else isl.dim_type.in_ @@ -656,6 +947,7 @@ def _reconstruct_isl_object(self) -> IslObjectT: obj_domain = isl.Set("{ [] }") obj_range = obj + assert isinstance(obj_range, isl.BasicSet | isl.Set) obj = isl.Map.from_domain_and_range(obj_domain, obj_range) @@ -667,9 +959,6 @@ def _reconstruct_isl_object(self) -> IslObjectT: return obj - def __getattr__(self, name: str): - return getattr(self._reconstruct_isl_object(), name) - @override def __str__(self) -> str: return str(self._reconstruct_isl_object()) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index ac74cd6..00d922a 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -37,9 +37,7 @@ IslExpressionLikeT, NamedIslObject, _align_and_apply_binary_op, - _deconstruct_object, - _normalize_dimtype_to_names, - _strip_names, + _make_named_object_pieces, ) @@ -83,7 +81,7 @@ def __mul__(self, other: Self | int) -> Self: return _align_and_apply_binary_op(self, other, operator.mul) def is_zero(self) -> bool: - return self._reconstruct_isl_object().is_zero() + return bool(self._obj.is_zero()) # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType] @override def __eq__(self, other: object) -> bool: @@ -100,6 +98,12 @@ class _NamedPwExpressionLike(_NamedExpressionLike[IslExpressionLikeT]): class Aff(_NamedExpressionLike[isl.Aff]): _obj: isl.Aff + @override + def _reconstruct_isl_object(self) -> isl.Aff: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.Aff) + return obj + @overload def make_aff(src: str, ctx: isl.Context | None = None) -> Aff: @@ -113,11 +117,8 @@ def make_aff(src: isl.Aff) -> Aff: def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: obj = isl.Aff(src, ctx) if isinstance(src, str) else src - - aff_obj, dimtype_to_names = _deconstruct_object(obj) - aff_obj, name_to_dim = _strip_names(aff_obj) - dimtype_to_names = _normalize_dimtype_to_names(aff_obj, dimtype_to_names) - + aff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(aff_obj, isl.Aff) return Aff(aff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -126,6 +127,12 @@ def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: class QPolynomial(_NamedExpressionLike[isl.QPolynomial]): _obj: isl.QPolynomial + @override + def _reconstruct_isl_object(self) -> isl.QPolynomial: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.QPolynomial) + return obj + @overload def make_qpolynomial(src: str, ctx: isl.Context | None = None) -> QPolynomial: @@ -147,18 +154,22 @@ def make_qpolynomial( else src ) - qp_obj, dimtype_to_names = _deconstruct_object(obj) - qp_obj, name_to_dim = _strip_names(qp_obj) - dimtype_to_names = _normalize_dimtype_to_names(qp_obj, dimtype_to_names) - + qp_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(qp_obj, isl.QPolynomial) return QPolynomial(qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @final @dataclass(frozen=True, eq=False) -class PwAff(_NamedPwExpressionLike): +class PwAff(_NamedPwExpressionLike[isl.PwAff]): _obj: isl.PwAff + @override + def _reconstruct_isl_object(self) -> isl.PwAff: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.PwAff) + return obj + @overload def make_pw_aff(src: str, ctx: isl.Context | None = None) -> PwAff: @@ -172,19 +183,22 @@ def make_pw_aff(src: isl.PwAff) -> PwAff: def make_pw_aff(src: str | isl.PwAff, ctx: isl.Context | None = None) -> PwAff: obj = isl.PwAff(src, ctx) if isinstance(src, str) else src - - pwaff_obj, dimtype_to_names = _deconstruct_object(obj) - pwaff_obj, name_to_dim = _strip_names(pwaff_obj) - dimtype_to_names = _normalize_dimtype_to_names(pwaff_obj, dimtype_to_names) - + pwaff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(pwaff_obj, isl.PwAff) return PwAff(pwaff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @final @dataclass(frozen=True, eq=False) -class PwQPolynomial(_NamedPwExpressionLike): +class PwQPolynomial(_NamedPwExpressionLike[isl.PwQPolynomial]): _obj: isl.PwQPolynomial + @override + def _reconstruct_isl_object(self) -> isl.PwQPolynomial: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.PwQPolynomial) + return obj + @overload def make_pw_qpolynomial( @@ -202,11 +216,8 @@ def make_pw_qpolynomial( ctx: isl.Context | None = None ) -> PwQPolynomial: obj = isl.PwQPolynomial(src, ctx) if isinstance(src, str) else src - - pw_qp_obj, dimtype_to_names = _deconstruct_object(obj) - pw_qp_obj, name_to_dim = _strip_names(pw_qp_obj) - dimtype_to_names = _normalize_dimtype_to_names(pw_qp_obj, dimtype_to_names) - + pw_qp_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(pw_qp_obj, isl.PwQPolynomial) return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args # }}} @@ -229,14 +240,20 @@ class _NamedMultiExpressionLike(NamedIslObject[isl.Set]): @final @dataclass(frozen=True, eq=False) class PwMultiAff(_NamedMultiExpressionLike): - def get_at(self, dim: int) -> PwAff: - return make_pw_aff(self._reconstruct_isl_object().get_at(dim)) + def get_at(self, name: str) -> PwAff: + if name not in self._names_for_dim_type(isl.dim_type.set): + raise ValueError(f"unknown output name: {name}") + return make_pw_aff( + self._reconstruct_isl_object().get_at(self._name_to_dim[name]) + ) @override def _reconstruct_isl_object(self) -> isl.PwMultiAff: # deconstruction: isl.PwMultiAff -> isl.Map -> isl.Set # reconstruction: isl.Set -> isl.Map -> isl.PwMultiAff - return super()._reconstruct_isl_object().as_pw_multi_aff() + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.Set | isl.Map) + return obj.as_pw_multi_aff() @overload @@ -255,25 +272,26 @@ def make_pw_multi_aff( ) -> PwMultiAff: obj = isl.PwMultiAff(src, ctx) if isinstance(src, str) else src - - pw_maff_obj, dimtype_to_names = _deconstruct_object(obj) - pw_maff_obj, name_to_dim = _strip_names(pw_maff_obj) - dimtype_to_names = _normalize_dimtype_to_names(pw_maff_obj, dimtype_to_names) - + pw_maff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(pw_maff_obj, isl.Set) return PwMultiAff(pw_maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @final @dataclass(frozen=True, eq=False) -class MultiAff(NamedIslObject[isl.Set]): - def get_at(self, dim: int) -> Aff: - return make_aff(self._reconstruct_isl_object().get_at(dim)) +class MultiAff(_NamedMultiExpressionLike): + def get_at(self, name: str) -> Aff: + if name not in self._names_for_dim_type(isl.dim_type.set): + raise ValueError(f"unknown output name: {name}") + return make_aff(self._reconstruct_isl_object().get_at(self._name_to_dim[name])) @override def _reconstruct_isl_object(self) -> isl.MultiAff: # deconstruction: isl.MultiAff -> isl.Map -> isl.Set # reconstruction: isl.Set -> isl.Map -> isl.PwMultiAff -> isl.MultiAff - return super()._reconstruct_isl_object().as_pw_multi_aff().as_multi_aff() + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.Set | isl.Map) + return obj.as_pw_multi_aff().as_multi_aff() @overload @@ -289,11 +307,8 @@ def make_multi_aff(src: isl.MultiAff) -> MultiAff: def make_multi_aff( src: str | isl.MultiAff, ctx: isl.Context | None = None) -> MultiAff: obj = isl.MultiAff(src, ctx) if isinstance(src, str) else src - - maff_obj, dimtype_to_names = _deconstruct_object(obj) - maff_obj, name_to_dim = _strip_names(maff_obj) - dimtype_to_names = _normalize_dimtype_to_names(maff_obj, dimtype_to_names) - + maff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(maff_obj, isl.Set) return MultiAff(maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args # }}} diff --git a/namedisl/set_like.py b/namedisl/set_like.py index dcec262..a8fa4e1 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -28,10 +28,10 @@ import operator from abc import ABC from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, final, overload +from typing import TYPE_CHECKING, cast, final, overload from constantdict import constantdict -from typing_extensions import override +from typing_extensions import Self, override import islpy as isl @@ -40,16 +40,13 @@ NameToDim, _align_obj, _align_two, - _deconstruct_object, _find_contiguous_dim_chunks, - _normalize_dimtype_to_names, - _strip_names, + _make_named_object_pieces, ) -from .expression_like import PwAff, PwMultiAff, make_pw_aff, make_pw_multi_aff if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable, Collection, Sequence @dataclass(frozen=True, eq=False) @@ -68,7 +65,7 @@ def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: _dimtype_to_names=self._dimtype_to_names ) - def eliminate(self, names_to_eliminate: str | Sequence[str]) -> _NamedIslSetLike: + def eliminate(self, names_to_eliminate: str | Collection[str]) -> _NamedIslSetLike: if isinstance(names_to_eliminate, str): names_to_eliminate = [names_to_eliminate] @@ -92,8 +89,84 @@ def eliminate(self, names_to_eliminate: str | Sequence[str]) -> _NamedIslSetLike _dimtype_to_names=self._dimtype_to_names ) + def add_constraint( + self: Self, + constraints: str | Collection[str], + ) -> Self: + if isinstance(constraints, str): + constraints = [constraints] + else: + constraints = list(constraints) + + if not constraints: + return self + + ordered_names = tuple( + sorted(self._name_to_dim, key=self._name_to_dim.__getitem__) + ) + constraint_text = " and ".join(f"({constraint})" + for constraint in constraints) + constraint_src = ( + "{ " + f"[{', '.join(ordered_names)}] : " + f"{constraint_text}" + " }" + ) + + try: + constraint_obj = isl.Set(constraint_src) + except isl.Error as exc: + raise ValueError( + f"invalid constraint for names {ordered_names}: " + f"{constraint_text}" + ) from exc + + constraint_obj = constraint_obj.remove_redundancies() + constraint_set, constraint_name_to_dim, _ = _make_named_object_pieces( + constraint_obj + ) + assert isinstance(constraint_set, isl.Set) + + if constraint_name_to_dim != self._name_to_dim: + constraint_set = _align_obj( + Set( + constraint_set, + constraint_name_to_dim, + self._dimtype_to_names + ), + self._name_to_dim, + self._dimtype_to_names + )._obj + assert isinstance(constraint_set, isl.Set) + + return replace( + self, + _obj=self._obj.intersect(constraint_set), + _name_to_dim=self._name_to_dim, + _dimtype_to_names=self._dimtype_to_names + ) + + def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: + self_aligned, context_aligned = _align_two(self, context) + result = self_aligned._obj.gist(context_aligned._obj) + + if isinstance(self, BasicMap): + result_type = BasicMap if result.n_basic_set() == 1 else Map + elif isinstance(self, Map): + result_type = Map + elif isinstance(self, BasicSet): + result_type = BasicSet if result.n_basic_set() == 1 else Set + else: + result_type = Set + + return result_type( + result, + self_aligned._name_to_dim, + self_aligned._dimtype_to_names + ) + def project_out(self: _NamedIslSetLike, - names_to_project_out: str | Sequence[str]) -> _NamedIslSetLike: + names_to_project_out: str | Collection[str]) -> _NamedIslSetLike: if isinstance(names_to_project_out, str): names_to_project_out = [names_to_project_out] @@ -140,59 +213,35 @@ def project_out(self: _NamedIslSetLike, def project_out_except( self: _NamedIslSetLike, - names_to_keep: str | Sequence[str], - dim_types: Sequence[isl.dim_type] | None = None + names_to_keep: str | Collection[str], ) -> _NamedIslSetLike: if isinstance(names_to_keep, str): names_to_keep = [names_to_keep] if names_to_keep else [] - considered_names = set(self._name_to_dim) - if dim_types is not None: - considered_names = set() - for dim_type in dim_types: - if dim_type == isl.dim_type.param: - considered_names |= set(self.parameter_names) - elif dim_type in ( - isl.dim_type.set, - isl.dim_type.out, - isl.dim_type.in_ - ): - considered_names |= ( - set(self._name_to_dim) - - set(self.parameter_names) - - set(self.input_names) - ) - names_to_project_out = [ - name for name in considered_names + name for name in self._name_to_dim if name not in names_to_keep ] return self.project_out(names_to_project_out) - def dim_max(self, name: str | int) -> PwAff: - dim = name if isinstance(name, int) else self._name_to_dim[name] - return make_pw_aff(self._obj.dim_max(dim)) + def dim_max(self, name: str) -> isl.PwAff: + return self._obj.dim_max(self._name_to_dim[name]) - def dim_min(self, name: str | int) -> PwAff: - dim = name if isinstance(name, int) else self._name_to_dim[name] - return make_pw_aff(self._obj.dim_min(dim)) + def dim_min(self, name: str) -> isl.PwAff: + return self._obj.dim_min(self._name_to_dim[name]) - def as_pw_multi_aff(self) -> PwMultiAff: - return make_pw_multi_aff(self._reconstruct_isl_object().as_pw_multi_aff()) + def as_pw_multi_aff(self) -> isl.PwMultiAff: + obj = self._reconstruct_isl_object() + assert isinstance(obj, isl.Set | isl.Map) + return obj.as_pw_multi_aff() @override def dim(self, dim_type: isl.dim_type) -> int: if dim_type == isl.dim_type.out: dim_type = isl.dim_type.set - return self._reconstruct_isl_object().dim(dim_type) - - @override - def get_dim_name(self, dim_type: isl.dim_type, dim: int) -> str | None: - if dim_type == isl.dim_type.out: - dim_type = isl.dim_type.set - return self._reconstruct_isl_object().get_dim_name(dim_type, dim) + return super().dim(dim_type) # FIXME: basedpyright is not happy with these function signatures def __and__( @@ -220,16 +269,24 @@ def __eq__(self, other: object) -> bool: assert isinstance(aligned_other._obj, isl.Set) return aligned_self._obj.plain_is_equal(aligned_other._obj) + def __lt__(self, other: _NamedIslSetLike) -> bool: + return _compare_set_like(self, other, isl.Set.is_strict_subset) + + def __le__(self, other: _NamedIslSetLike) -> bool: + return _compare_set_like(self, other, isl.Set.is_subset) + + def __gt__(self, other: _NamedIslSetLike) -> bool: + return _compare_set_like(other, self, isl.Set.is_strict_subset) + + def __ge__(self, other: _NamedIslSetLike) -> bool: + return _compare_set_like(other, self, isl.Set.is_subset) + @final @dataclass(frozen=True, eq=False) class BasicSet(_NamedIslSetLike): @override - def add_input_names(self, names_to_add: Sequence[str]) -> BasicSet: - raise NotImplementedError - - @override - def add_output_names(self, names_to_add: Sequence[str]) -> BasicSet: + def add_input_names(self, names_to_add: Collection[str]) -> BasicSet: raise NotImplementedError @override @@ -256,20 +313,29 @@ def make_basic_set(src: isl.BasicSet) -> BasicSet: def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> BasicSet: obj = isl.BasicSet(src, ctx) if isinstance(src, str) else src - - set_obj, dimtype_to_names = _deconstruct_object(obj) - + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) - set_obj, name_to_dim = _strip_names(set_obj) - dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) - return BasicSet(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @final @dataclass(frozen=True, eq=False) class Set(_NamedIslSetLike): - ... + @override + def add_input_names(self, names_to_add: Collection[str]) -> Set: + raise NotImplementedError + + @override + def _reconstruct_isl_object(self) -> isl.Set: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.Set) + return obj + + def get_basic_sets(self) -> Sequence[BasicSet]: + isl_obj = self._reconstruct_isl_object() + + bsets = isl_obj.get_basic_sets() + return [make_basic_set(bset) for bset in bsets] def _apply_set_like_binary_op( @@ -292,59 +358,79 @@ def _apply_set_like_binary_op( return result_type(result, lhs._name_to_dim, lhs._dimtype_to_names) +def _compare_set_like( + lhs: _NamedIslSetLike, + rhs: _NamedIslSetLike, + op: Callable[[isl.Set, isl.Set], bool] + ) -> bool: + lhs_is_map = isinstance(lhs, _NamedIslMapLike) + rhs_is_map = isinstance(rhs, _NamedIslMapLike) + if lhs_is_map != rhs_is_map: + raise TypeError("Cannot compare set-like and map-like objects") + + aligned_lhs, aligned_rhs = _align_two(lhs, rhs) + + assert isinstance(aligned_lhs._obj, isl.Set) + assert isinstance(aligned_rhs._obj, isl.Set) + return op(aligned_lhs._obj, aligned_rhs._obj) + + class _NamedIslMapLike(_NamedIslSetLike): @override - def _reconstruct_isl_object(self) -> isl.Map: + def _reconstruct_isl_object(self) -> isl.BasicMap | isl.Map: obj = super()._reconstruct_isl_object() if isinstance(obj, isl.Set): return isl.Map.from_domain_and_range(isl.Set("{ [] }"), obj) + assert isinstance(obj, isl.BasicMap | isl.Map) return obj def _output_names(self) -> frozenset[str]: return frozenset(self._name_to_dim) - self.input_names - self.parameter_names + def _map_obj(self) -> isl.BasicMap | isl.Map: + return self._reconstruct_isl_object() + @staticmethod - def _logical_name(name: str) -> str: - return name.rstrip("'") + def _wrap_map_result(result: isl.BasicMap | isl.Map) -> BasicMap | Map: + if isinstance(result, isl.BasicMap): + return make_basic_map(result) + return make_map(result) + + def _map_with_universe( + self, + dim_type: isl.dim_type, + set_obj: isl.BasicSet | isl.Set + ) -> BasicMap | Map: + map_obj = self._map_obj() + if dim_type == isl.dim_type.in_: + universe = isl.Set.universe(map_obj.range().get_space()) + return make_map(isl.Map.from_domain_and_range(set_obj, universe)) + if dim_type == isl.dim_type.out: + universe = isl.Set.universe(map_obj.domain().get_space()) + return make_map(isl.Map.from_domain_and_range(universe, set_obj)) + raise ValueError(f"unsupported dim type: {dim_type}") def _ordered_names(self, names: frozenset[str]) -> tuple[str, ...]: return tuple(sorted(names, key=self._name_to_dim.__getitem__)) - def _ordered_logical_names(self, names: frozenset[str]) -> tuple[str, ...]: - return tuple(self._logical_name(name) for name in self._ordered_names(names)) - - def _actual_names_for_logical_order( + def _reject_surviving_name_collisions( self, - names: frozenset[str], - logical_order: tuple[str, ...] - ) -> tuple[str, ...]: - name_by_logical: dict[str, str] = {} - for name in self._ordered_names(names): - logical_name = self._logical_name(name) - if logical_name in name_by_logical: - raise ValueError( - "multiple dimensions in one interface share the same " - f"logical name: {logical_name}" - ) - name_by_logical[logical_name] = name - - try: - return tuple(name_by_logical[logical_name] for logical_name in logical_order) - except KeyError as exc: - raise ValueError("maps are not composable: interface names differ") from exc + collisions: frozenset[str], + ) -> None: + if collisions: + raise ValueError( + "composition would create duplicate surviving names: " + + ", ".join(sorted(collisions)) + ) def _reorder_interface( self, dim_type: isl.dim_type, - logical_order: tuple[str, ...] + ordered_names: tuple[str, ...] ) -> _NamedIslMapLike: interface_names = ( self.input_names if dim_type == isl.dim_type.in_ else self._output_names() ) - ordered_names = self._actual_names_for_logical_order( - interface_names, - logical_order - ) current_names = self._ordered_names(interface_names) if current_names == ordered_names: return self @@ -372,50 +458,76 @@ def _validate_composable( other: BasicMap | Map, rhs_dim_type: isl.dim_type ) -> tuple[str, ...]: - lhs_names = self.input_names if lhs_dim_type == isl.dim_type.in_ else self._output_names() - rhs_names = other.input_names if rhs_dim_type == isl.dim_type.in_ else other._output_names() - if ( - frozenset(self._logical_name(name) for name in lhs_names) - != - frozenset(self._logical_name(name) for name in rhs_names) - ): + lhs_names = ( + self.input_names if lhs_dim_type == isl.dim_type.in_ + else self._output_names() + ) + rhs_names = ( + other.input_names if rhs_dim_type == isl.dim_type.in_ + else other._output_names() + ) + if lhs_names != rhs_names: raise ValueError("maps are not composable: interface names differ") - return self._ordered_logical_names(lhs_names) + return self._ordered_names(lhs_names) - def intersect_domain(self, domain: BasicSet | Set) -> Map: - return self & make_map(isl.Map.from_domain_and_range( - domain._reconstruct_isl_object(), - isl.Set.universe(self._reconstruct_isl_object().range().get_space()) + def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: + domain_obj = domain._reconstruct_isl_object() + assert isinstance(domain_obj, isl.BasicSet | isl.Set) + return cast("BasicMap | Map", self & self._map_with_universe( + isl.dim_type.in_, + domain_obj, )) - def intersect_range(self, range_: BasicSet | Set) -> Map: - return self & make_map(isl.Map.from_domain_and_range( - isl.Set.universe(self._reconstruct_isl_object().domain().get_space()), - range_._reconstruct_isl_object() + def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: + range_obj = range_._reconstruct_isl_object() + assert isinstance(range_obj, isl.BasicSet | isl.Set) + return cast("BasicMap | Map", self & self._map_with_universe( + isl.dim_type.out, + range_obj, )) - def apply_range(self, other: BasicMap | Map) -> Map: - ordered_names = self._validate_composable(isl.dim_type.out, other, isl.dim_type.in_) - other = other._reorder_interface(isl.dim_type.in_, ordered_names) - return make_map( - self._reconstruct_isl_object().apply_range(other._reconstruct_isl_object()) + def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: + ordered_names = self._validate_composable( + isl.dim_type.out, + other, + isl.dim_type.in_ + ) + other = cast("BasicMap | Map", other._reorder_interface( + isl.dim_type.in_, ordered_names)) + self._reject_surviving_name_collisions( + self.input_names & other._output_names() ) + result = self._map_obj().apply_range(other._map_obj()) + return self._wrap_map_result(result) - def apply_domain(self, other: BasicMap | Map) -> Map: - ordered_names = self._validate_composable(isl.dim_type.in_, other, isl.dim_type.out) - other = other._reorder_interface(isl.dim_type.out, ordered_names) - return make_map( - other._reconstruct_isl_object().apply_range(self._reconstruct_isl_object()) + def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: + ordered_names = self._validate_composable( + isl.dim_type.in_, + other, + isl.dim_type.out + ) + other = cast("BasicMap | Map", other._reorder_interface( + isl.dim_type.out, ordered_names)) + self._reject_surviving_name_collisions( + other.input_names & self._output_names() ) + result = other._map_obj().apply_range(self._map_obj()) + return self._wrap_map_result(result) - def reverse(self) -> Map: - return make_map(self._reconstruct_isl_object().reverse()) + def reverse(self) -> BasicMap | Map: + return self._wrap_map_result(self._map_obj().reverse()) - def domain(self) -> Set: - return make_set(self._reconstruct_isl_object().domain()) + def domain(self) -> BasicSet | Set: + domain = self._map_obj().domain() + if isinstance(domain, isl.BasicSet): + return make_basic_set(domain) + return make_set(domain) - def range(self) -> Set: - return make_set(self._reconstruct_isl_object().range()) + def range(self) -> BasicSet | Set: + range_ = self._map_obj().range() + if isinstance(range_, isl.BasicSet): + return make_basic_set(range_) + return make_set(range_) @overload @@ -430,13 +542,8 @@ def make_set(src: isl.Set) -> Set: def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: obj = isl.Set(src, ctx) if isinstance(src, str) else src - - set_obj, dimtype_to_names = _deconstruct_object(obj) - + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) - set_obj, name_to_dim = _strip_names(set_obj) - dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) - return Set(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -446,57 +553,67 @@ class BasicMap(_NamedIslMapLike): @classmethod def empty(cls, space: isl.Space) -> BasicMap: obj = isl.BasicMap.empty(space) - set_obj, dimtype_to_names = _deconstruct_object(obj) + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) - set_obj, name_to_dim = _strip_names(set_obj) - dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) return cls(set_obj, name_to_dim, dimtype_to_names) - def reverse(self) -> BasicMap: - return make_basic_map(self._reconstruct_isl_object().reverse()) + @override + def _map_obj(self) -> isl.BasicMap: + obj = self._reconstruct_isl_object() + assert isinstance(obj, isl.BasicMap) + return obj + + @staticmethod + @override + def _wrap_map_result(result: isl.BasicMap | isl.Map) -> BasicMap | Map: + if isinstance(result, isl.BasicMap): + return make_basic_map(result) + return make_map(result) + + @override + def reverse(self) -> BasicMap | Map: + return self._wrap_map_result(self._map_obj().reverse()) + @override def domain(self) -> BasicSet: - return make_basic_set(self._reconstruct_isl_object().domain()) + return make_basic_set(self._map_obj().domain()) + @override def range(self) -> BasicSet: - return make_basic_set(self._reconstruct_isl_object().range()) + return make_basic_set(self._map_obj().range()) + @override def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: if isinstance(domain, BasicSet): - return self & make_basic_map(isl.BasicMap.from_domain_and_range( - domain._reconstruct_isl_object(), - isl.BasicSet.universe(self._reconstruct_isl_object().range().get_space()) - )) + range_space = self._map_obj().range().get_space() + filter_map = make_basic_map( + isl.BasicMap.from_domain_and_range( + domain._reconstruct_isl_object(), + isl.BasicSet.universe(range_space) + ) + ) + return cast("BasicMap | Map", self & filter_map) return super().intersect_domain(domain) + @override def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: if isinstance(range_, BasicSet): - return self & make_basic_map(isl.BasicMap.from_domain_and_range( - isl.BasicSet.universe(self._reconstruct_isl_object().domain().get_space()), - range_._reconstruct_isl_object() - )) + domain_space = self._map_obj().domain().get_space() + filter_map = make_basic_map( + isl.BasicMap.from_domain_and_range( + isl.BasicSet.universe(domain_space), + range_._reconstruct_isl_object() + ) + ) + return cast("BasicMap | Map", self & filter_map) return super().intersect_range(range_) + @override def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: - if isinstance(other, BasicMap): - ordered_names = self._validate_composable( - isl.dim_type.out, other, isl.dim_type.in_) - other = other._reorder_interface(isl.dim_type.in_, ordered_names) - return make_basic_map( - self._reconstruct_isl_object().apply_range( - other._reconstruct_isl_object()) - ) return super().apply_range(other) + @override def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: - if isinstance(other, BasicMap): - ordered_names = self._validate_composable( - isl.dim_type.in_, other, isl.dim_type.out) - other = other._reorder_interface(isl.dim_type.out, ordered_names) - return make_basic_map( - other._reconstruct_isl_object().apply_range( - self._reconstruct_isl_object()) - ) return super().apply_domain(other) @override @@ -526,13 +643,8 @@ def make_basic_map(src: isl.BasicMap) -> BasicMap: def make_basic_map(src: str | isl.BasicMap, ctx: isl.Context | None = None) -> BasicMap: obj = isl.BasicMap(src, ctx) if isinstance(src, str) else src - - set_obj, dimtype_to_names = _deconstruct_object(obj) - + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) - set_obj, name_to_dim = _strip_names(set_obj) - dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) - return BasicMap(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -541,17 +653,25 @@ def make_map_from_domain_and_range( range_: BasicSet | Set ) -> BasicMap | Map: if isinstance(domain, BasicSet) and isinstance(range_, BasicSet): + domain_obj = domain._reconstruct_isl_object() + range_obj = range_._reconstruct_isl_object() + assert isinstance(domain_obj, isl.BasicSet) + assert isinstance(range_obj, isl.BasicSet) return make_basic_map( isl.BasicMap.from_domain_and_range( - domain._reconstruct_isl_object(), - range_._reconstruct_isl_object() + domain_obj, + range_obj, ) ) + domain_obj = domain._reconstruct_isl_object() + range_obj = range_._reconstruct_isl_object() + assert isinstance(domain_obj, isl.BasicSet | isl.Set) + assert isinstance(range_obj, isl.BasicSet | isl.Set) return make_map( isl.Map.from_domain_and_range( - domain._reconstruct_isl_object(), - range_._reconstruct_isl_object() + domain_obj, + range_obj, ) ) @@ -563,6 +683,12 @@ class Map(_NamedIslMapLike): def empty(cls, space: isl.Space) -> Map: return make_map(isl.Map.empty(space)) + @override + def _reconstruct_isl_object(self) -> isl.Map: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.Map) + return obj + @overload def make_map(src: str, ctx: isl.Context | None = None) -> Map: @@ -576,11 +702,6 @@ def make_map(src: isl.Map) -> Map: def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: obj = isl.Map(src, ctx) if isinstance(src, str) else src - - set_obj, dimtype_to_names = _deconstruct_object(obj) - + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) - set_obj, name_to_dim = _strip_names(set_obj) - dimtype_to_names = _normalize_dimtype_to_names(set_obj, dimtype_to_names) - return Map(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args diff --git a/namedisl/test/test_expression_like.py b/namedisl/test/test_expression_like.py index 9d5454a..fa1975b 100644 --- a/namedisl/test/test_expression_like.py +++ b/namedisl/test/test_expression_like.py @@ -139,6 +139,20 @@ def test_pwaff_binary_ops(): # }}} +def test_multi_aff_get_at_uses_name() -> None: + map_ = nisl.make_map("{ [i] -> [x = i, y = 2i] }") + maff = nisl.make_multi_aff( + map_._reconstruct_isl_object().as_pw_multi_aff().as_multi_aff() + ) + assert maff.get_at("x")._reconstruct_isl_object() == isl.Aff("{ [i] -> [(i)] }") + + +def test_pw_multi_aff_get_at_uses_name() -> None: + map_ = nisl.make_map("{ [i] -> [x = i, y = 2i] }") + pmaff = nisl.make_pw_multi_aff(map_.as_pw_multi_aff()) + assert pmaff.get_at("y")._reconstruct_isl_object() == isl.PwAff("{ [i] -> [(2i)] }") + + # {{{ qpolynomials def test_qpolynomial_from_str(): diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py index f6136d8..6c91d17 100644 --- a/namedisl/test/test_namedisl.py +++ b/namedisl/test/test_namedisl.py @@ -27,6 +27,9 @@ import pytest +import islpy as isl + +import namedisl as nisl from .utils_for_tests import generate_random_named_set, get_name_sequence @@ -57,3 +60,256 @@ def test_add_names( s = s.add_names([SetName(name) for name in new_set_names]) print(s) + + +def test_public_dim_type_name_accessors() -> None: + named_map = nisl.make_map("[n] -> { [i] -> [o] }") + + assert named_map.dim_names(isl.dim_type.in_) == frozenset({"i"}) + assert named_map.input_names == frozenset({"i"}) + assert named_map.dim_names(isl.dim_type.out) == frozenset({"o"}) + assert named_map.output_names == frozenset({"o"}) + assert named_map.dim_names(isl.dim_type.param) == frozenset({"n"}) + assert named_map.parameter_names == frozenset({"n"}) + + +def test_public_dim_type_name_accessors_for_aff() -> None: + named_aff = nisl.make_aff("[n] -> { [i] -> [i + n] }") + + assert named_aff.input_names == frozenset({"i"}) + assert named_aff.parameter_names == frozenset({"n"}) + assert named_aff.output_names == frozenset() + assert named_aff.set_names == frozenset() + + +def test_add_set_and_parameter_names_reconstructs_expected_set() -> None: + named_set = ( + nisl.make_set("{ [x] }") + .add_set_names(["y"]) + .add_parameter_names(["p"]) + ) + expected = isl.Set("[p] -> { [y, x] }") + reconstructed = named_set._reconstruct_isl_object() + + assert isinstance(reconstructed, isl.Set) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.set) == _dim_names( + expected, isl.dim_type.set + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_add_output_input_and_parameter_names_reconstructs_expected_map() -> None: + named_map = ( + nisl.make_map("{ [i] -> [o] }") + .add_output_names(["o2"]) + .add_input_names(["i2"]) + .add_parameter_names(["n"]) + ) + expected = isl.Map("[n] -> { [i2, i] -> [o2, o] }") + reconstructed = named_map._reconstruct_isl_object() + + assert isinstance(reconstructed, isl.Map) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.in_) == _dim_names( + expected, isl.dim_type.in_ + ) + assert _dim_names(reconstructed, isl.dim_type.out) == _dim_names( + expected, isl.dim_type.out + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_add_input_and_parameter_names_reconstructs_expected_aff() -> None: + named_aff = ( + nisl.make_aff("[n] -> { [i] -> [i + n] }") + .add_input_names(["j"]) + .add_parameter_names(["m"]) + ) + expected = isl.Aff("[m, n] -> { [j, i] -> [i + n] }") + reconstructed = named_aff._reconstruct_isl_object() + + assert reconstructed == expected + + +def test_add_dim_names_uses_dim_type() -> None: + named_map = ( + nisl.make_map("{ [i] -> [o] }") + .add_dim_names(["j"], isl.dim_type.in_) + .add_dim_names(["p"], isl.dim_type.param) + .add_dim_names(["x"], isl.dim_type.out) + ) + + assert named_map.input_names == frozenset({"j", "i"}) + assert named_map.parameter_names == frozenset({"p"}) + assert named_map.output_names == frozenset({"x", "o"}) + + +def _dim_names( + obj: isl.Set | isl.Map, + dim_type: isl.dim_type + ) -> tuple[str | None, ...]: + return tuple(obj.get_dim_name(dim_type, i) for i in range(obj.dim(dim_type))) + + +def test_move_dims_set_reconstructs_like_isl() -> None: + named_set = nisl.make_set("[p] -> { [x, y, z] : x + y = z and 0 <= x, y, z < p }") + + moved = named_set.move_dims("z", isl.dim_type.param) + + expected = isl.Set( + "[p] -> { [x, y, z] : " + "x + y = z and 0 <= x and 0 <= y and 0 <= z " + "and x < p and y < p and z < p }" + ).move_dims( + isl.dim_type.param, 1, + isl.dim_type.set, 2, 1 + ) + + reconstructed = moved._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Set) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.set) == _dim_names( + expected, isl.dim_type.set + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_move_dims_map_reconstructs_like_isl() -> None: + named_map = nisl.make_map( + "[p] -> { [i0, i1] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = p }" + ) + + moved = named_map.move_dims("o2", isl.dim_type.in_) + + expected = isl.Map( + "[p] -> { [i0, i1] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = p }" + ).move_dims(isl.dim_type.in_, 2, isl.dim_type.out, 2, 1) + + reconstructed = moved._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Map) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.in_) == _dim_names( + expected, isl.dim_type.in_ + ) + assert _dim_names(reconstructed, isl.dim_type.out) == _dim_names( + expected, isl.dim_type.out + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_move_dims_multiple_names_preserves_relative_order() -> None: + named_map = nisl.make_map( + "[p] -> { [i0, i1] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = p }" + ) + + moved = named_map.move_dims(["o1", "o2"], isl.dim_type.in_) + + expected = isl.Map( + "[p] -> { [i0, i1] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = p }" + ) + expected = expected.move_dims(isl.dim_type.in_, 2, isl.dim_type.out, 1, 1) + expected = expected.move_dims(isl.dim_type.in_, 3, isl.dim_type.out, 1, 1) + + reconstructed = moved._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Map) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.in_) == _dim_names( + expected, isl.dim_type.in_ + ) + assert _dim_names(reconstructed, isl.dim_type.out) == _dim_names( + expected, isl.dim_type.out + ) + + +def test_rename_dims_set_reconstructs_like_isl() -> None: + named_set = nisl.make_set("[p] -> { [x, y] : x < p and y < p }") + + renamed = named_set.rename_dims({"x": "x_new", "p": "n"}) + + expected = isl.Set("[p] -> { [x, y] : x < p and y < p }") + expected = expected.set_dim_name(isl.dim_type.set, 0, "x_new") + expected = expected.set_dim_name(isl.dim_type.param, 0, "n") + + reconstructed = renamed._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Set) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.set) == _dim_names( + expected, isl.dim_type.set + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_rename_dims_map_reconstructs_like_isl() -> None: + named_map = nisl.make_map( + "[p] -> { [i0, i1] -> [o0, o1] : o0 = i0 and o1 = p + i1 }" + ) + + renamed = named_map.rename_dims({"i1": "j", "o1": "x", "p": "n"}) + + expected = isl.Map( + "[p] -> { [i0, i1] -> [o0, o1] : o0 = i0 and o1 = p + i1 }" + ) + expected = expected.set_dim_name(isl.dim_type.in_, 1, "j") + expected = expected.set_dim_name(isl.dim_type.out, 1, "x") + expected = expected.set_dim_name(isl.dim_type.param, 0, "n") + + reconstructed = renamed._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Map) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.in_) == _dim_names( + expected, isl.dim_type.in_ + ) + assert _dim_names(reconstructed, isl.dim_type.out) == _dim_names( + expected, isl.dim_type.out + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_rename_dims_rejects_renaming_to_existing_name() -> None: + named_map = nisl.make_map("{ [i] -> [o] }") + + with pytest.raises(ValueError, match="existing names"): + _ = named_map.rename_dims({"i": "o"}) + + +def test_rename_dims_rejects_unknown_name() -> None: + named_set = nisl.make_set("{ [x] }") + + with pytest.raises(ValueError, match="unknown names"): + _ = named_set.rename_dims({"y": "z"}) + + +def test_positional_fallback_methods_are_not_exposed() -> None: + named_set = nisl.make_set("{ [x, y] }") + + with pytest.raises(AttributeError): + named_set.__getattribute__("get_dim_name") + + +def test_duplicate_set_names_are_rejected() -> None: + with pytest.raises(ValueError, match=r"duplicate|unnamed"): + _ = nisl.make_set("{ [x, x] }") + + +def test_ticked_names_are_distinct_names() -> None: + space = isl.Space.create_from_names( + isl.DEFAULT_CONTEXT, + set=["x", "x'"] + ) + + named_set = nisl.make_set(isl.Set.universe(space)) + + assert named_set.names == frozenset({"x", "x'"}) diff --git a/namedisl/test/test_set_like.py b/namedisl/test/test_set_like.py index e18297d..8bd26b8 100644 --- a/namedisl/test/test_set_like.py +++ b/namedisl/test/test_set_like.py @@ -26,14 +26,15 @@ """ import pytest +from loopy.symbolic import pwaff_from_expr +from pymbolic import var +from typing_extensions import assert_type import islpy as isl import namedisl as nisl -from namedisl.core import _align_two -from loopy.symbolic import pwaff_from_expr -from pymbolic import var from .utils_for_tests import generate_random_named_map, generate_random_named_set +from namedisl.core import _align_two # {{{ sets @@ -117,6 +118,76 @@ def test_set_intersection(ndims: int, has_params: bool): assert (a & b) == result +def test_set_add_constraint_uses_named_dimensions() -> None: + set_ = nisl.make_set("{ [j, i] }") + + constrained = set_.add_constraint("i = j - 1") + + assert constrained == nisl.make_set("{ [j, i] : i = j - 1 }") + + +def test_set_add_constraint_accepts_multiple_constraints() -> None: + set_ = nisl.make_set("{ [i, j, k] }") + + constrained = set_.add_constraint(["0 <= i", "j = i + 1", "k <= j"]) + + assert constrained == nisl.make_set( + "{ [i, j, k] : 0 <= i and j = i + 1 and k <= j }" + ) + + +def test_set_add_constraint_rejects_unknown_name() -> None: + set_ = nisl.make_set("{ [i] }") + + with pytest.raises(ValueError, match="invalid constraint"): + _ = set_.add_constraint("j = i") + + +def test_set_gist_simplifies_against_named_context() -> None: + set_ = nisl.make_set( + "{ [i, j, kb] : 0 <= i <= 13 and 0 <= j <= 13 " + "and 0 <= kb <= 4 and kb <= 3 }" + ) + context = nisl.make_set("{ [j, i, kb] : 0 <= i <= 13 and 0 <= j <= 13 }") + + assert set_.gist(context) == nisl.make_set("{ [i, j, kb] : 0 <= kb <= 3 }") + + +def test_basic_set_gist_preserves_basic_set_when_result_is_basic() -> None: + set_ = nisl.make_basic_set("{ [i] : 0 <= i <= 10 }") + context = nisl.make_set("{ [i] : i <= 5 or i >= 8 }") + + result = set_.gist(context) + + assert isinstance(result, nisl.BasicSet) + assert result == nisl.make_basic_set("{ [i] : 0 <= i <= 10 }") + + +def test_set_subset_comparisons_align_by_name() -> None: + smaller = nisl.make_set("{ [i] : 0 <= i < 5 }") + larger = nisl.make_set("{ [j, i] : 0 <= i < 10 }") + equal_reordered = nisl.make_set("{ [i, j] : 0 <= i < 5 }") + + assert smaller < larger + assert smaller <= larger + assert larger > smaller + assert larger >= smaller + assert smaller <= equal_reordered + assert equal_reordered <= smaller + assert not smaller < equal_reordered + assert not larger <= smaller + + +def test_basic_set_subset_comparisons_allow_set_promotion() -> None: + smaller = nisl.make_basic_set("{ [i] : 0 <= i < 5 }") + larger = nisl.make_set("{ [j, i] : 0 <= i < 10 }") + + assert smaller < larger + assert smaller <= larger + assert larger > smaller + assert larger >= smaller + + def test_basic_set_intersection_promotes_to_set() -> None: basic = nisl.make_basic_set( "{ [ii_s, ji_s, k_s] : 0 <= ii_s <= 4 and 0 <= ji_s <= 4 and k_s = 0 }" @@ -155,28 +226,28 @@ def test_set_project_out(ndims: int): def test_set_dim_max(ndims: int): a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) - # unnamed, so use isl.PwAff instead of nisl.make_pw_aff + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. cond_pw_affs = [ - isl.PwAff(f"{{ [] -> [{cond.split('<')[2].strip(' ')}] }}") + isl.PwAff(f"{{ [{cond.split('<')[2].strip(' ')}] }}") for cond in a_cond.split("and") ] for i, name in enumerate(a_dims.split(",")): - assert a.dim_max(name)._reconstruct_isl_object() == (cond_pw_affs[i] - 1) + assert a.dim_max(name) == (cond_pw_affs[i] - 1) @pytest.mark.parametrize("ndims", [2, 4, 8]) def test_set_dim_min(ndims: int): a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) - # unnamed, so use isl.PwAff instead of nisl.make_pw_aff + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. cond_pw_affs = [ - isl.PwAff(f"{{ [] -> [{cond.split('<')[0].strip(' ')}] }}") + isl.PwAff(f"{{ [{cond.split('<')[0].strip(' ')}] }}") for cond in a_cond.split("and") ] for i, name in enumerate(a_dims.split(",")): - assert a.dim_min(name)._reconstruct_isl_object() == cond_pw_affs[i] + assert a.dim_min(name) == cond_pw_affs[i] # }}} @@ -322,6 +393,82 @@ def test_map_intersection(ndims_domain: int, ndims_range: int, has_params: bool) assert (x & y) == result_map +def test_map_add_constraint_uses_input_output_and_parameter_names() -> None: + map_ = nisl.make_map("[n] -> { [i] -> [j] }") + + constrained = map_.add_constraint("j = i + n") + + assert constrained == nisl.make_map("[n] -> { [i] -> [j] : j = i + n }") + + +def test_map_add_constraint_preserves_basic_map_type() -> None: + map_ = nisl.make_basic_map("{ [i] -> [j] }") + + constrained = map_.add_constraint("j = i + 1") + + assert isinstance(constrained, nisl.BasicMap) + assert constrained == nisl.make_basic_map("{ [i] -> [j] : j = i + 1 }") + + +def test_map_add_constraint_supports_previous_context_relation() -> None: + relation = nisl.make_map( + "{ [ki_prev, kb_prev] -> [ki_cur, kb_cur] : " + "kb_cur = kb_prev + 1 }" + ) + + constrained = relation.add_constraint("ki_prev = ki_cur - 1") + + assert constrained == nisl.make_map( + "{ [ki_prev, kb_prev] -> [ki_cur, kb_cur] : " + "ki_prev = ki_cur - 1 and kb_cur = kb_prev + 1 }" + ) + + +def test_map_gist_simplifies_against_named_context() -> None: + map_ = nisl.make_map( + "{ [i, kb] -> [j] : 0 <= i <= 13 and 0 <= kb <= 4 " + "and kb <= 3 and j = i }" + ) + context = nisl.make_map("{ [kb, i] -> [j] : 0 <= i <= 13 and j = i }") + + assert map_.gist(context) == nisl.make_map("{ [i, kb] -> [j] : 0 <= kb <= 3 }") + + +def test_map_subset_comparisons_align_by_name() -> None: + smaller = nisl.make_map("{ [i] -> [x] : x = i and 0 <= i < 5 }") + larger = nisl.make_map("{ [j, i] -> [y, x] : x = i and 0 <= i < 10 }") + equal_reordered = nisl.make_map( + "{ [i, j] -> [x, y] : x = i and 0 <= i < 5 }" + ) + + assert smaller < larger + assert smaller <= larger + assert larger > smaller + assert larger >= smaller + assert smaller <= equal_reordered + assert equal_reordered <= smaller + assert not smaller < equal_reordered + assert not larger <= smaller + + +def test_basic_map_subset_comparisons_allow_map_promotion() -> None: + smaller = nisl.make_basic_map("{ [i] -> [x] : x = i and 0 <= i < 5 }") + larger = nisl.make_map("{ [j, i] -> [y, x] : x = i and 0 <= i < 10 }") + + assert smaller < larger + assert smaller <= larger + assert larger > smaller + assert larger >= smaller + + +def test_subset_comparison_rejects_set_map_mismatch() -> None: + set_ = nisl.make_set("{ [i] : 0 <= i < 5 }") + map_ = nisl.make_map("{ [i] -> [x] : x = i and 0 <= i < 5 }") + + with pytest.raises(TypeError): + _ = set_ <= map_ + + def test_map_alignment_syncs_internal_output_positions_and_names() -> None: lhs = nisl.make_map("{ [i] -> [x] }") rhs = nisl.make_map("{ [i] -> [y, x] }") @@ -394,37 +541,100 @@ def test_map_apply_range_for_compute_a_tile_usage_map() -> None: ) global_usage_map = global_usage_map | local_usage_map + assert isinstance(global_usage_map, nisl.Map) + assert_type(global_usage_map, nisl.Map) + compute_map = compute_map.rename_dims({ + "ii_s": "ii_s_out", + "io": "io_out", + "ki_s": "ki_s_out", + "ko": "ko_out", + }) composed = global_usage_map.apply_range(compute_map) - assert frozenset( - name.rstrip("'") for name in composed.input_names - ) == frozenset( + assert composed.input_names == frozenset( {"i", "ii", "ii_s", "io", "j", "ji", "ji_s", "jo", "k", "ki", "ki_s", "ko"} ) - assert composed.range() == nisl.make_set("{ [ii_s, io, ki_s, ko] }") + assert composed.range().names == frozenset( + {"ii_s_out", "io_out", "ki_s_out", "ko_out"} + ) + + +def test_map_apply_range_rejects_surviving_name_collisions() -> None: + lhs = nisl.make_map("{ [x] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x] }") + with pytest.raises(ValueError, match="duplicate surviving names"): + _ = lhs.apply_range(rhs) -def test_map_apply_domain_accepts_logically_equal_ticked_interface_names() -> None: - lhs = nisl.make_map("{ [x] -> [y] }").apply_range( - nisl.make_map("{ [y] -> [x] }") + +def test_map_apply_range_can_explicitly_rename_and_equate_collision() -> None: + lhs = nisl.make_map("{ [x] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x] }").rename_dims({"x": "x_out"}) + + result = lhs.apply_range(rhs).equate_dims("x", "x_out") + + assert result.input_names == frozenset({"x"}) + assert result.range().names == frozenset({"x_out"}) + assert ( + result.intersect_domain(nisl.make_set("{ [x] : x = 3 }")) + .range() + .dim_min("x_out") + .plain_is_equal(isl.PwAff("{ [(3)] }")) ) - rhs = nisl.make_map("{ [p] -> [x] }") - result = lhs.apply_domain(rhs) - assert result.range() == nisl.make_set("{ [x] }") +def test_map_apply_range_can_equate_renamed_collisions_from_mapping() -> None: + lhs = nisl.make_map("{ [x, z] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x, z] }").rename_dims({ + "x": "x_out", + "z": "z_out", + }) + + result = lhs.apply_range(rhs).equate_dims({ + "x": "x_out", + "z": "z_out", + }) + + assert result == nisl.make_map( + "{ [x, z] -> [x_out, z_out] : x = x_out and z = z_out }" + ) + + +def test_equate_dims_mapping_rejects_unknown_name() -> None: + map_ = nisl.make_map("{ [x] -> [x_out] }") + + with pytest.raises(ValueError, match="unknown name: missing"): + _ = map_.equate_dims({"x": "missing"}) + + +def test_map_apply_domain_rejects_surviving_name_collisions() -> None: + lhs = nisl.make_map("{ [x] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x] }") + + with pytest.raises(ValueError, match="duplicate surviving names"): + _ = rhs.apply_domain(lhs) + + +def test_map_apply_domain_can_explicitly_rename_and_equate_collision() -> None: + lhs = nisl.make_map("{ [x] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x] }").rename_dims({"x": "x_out"}) + result = rhs.apply_domain(lhs).equate_dims("x", "x_out") -def test_map_domain_canonicalizes_single_remaining_ticked_name() -> None: - m = nisl.make_map("{ [x] -> [y] }").apply_range( - nisl.make_map("{ [y] -> [x] }") + assert result.input_names == frozenset({"x"}) + assert result.range().names == frozenset({"x_out"}) + assert ( + result.intersect_domain(nisl.make_set("{ [x] : x = 4 }")) + .range() + .dim_min("x_out") + .plain_is_equal(isl.PwAff("{ [(4)] }")) ) - domain = m.domain() - assert domain == nisl.make_set("{ [x] }") - assert domain.names == frozenset({"x"}) +def test_duplicate_map_names_are_rejected() -> None: + with pytest.raises(ValueError, match=r"duplicate|unnamed"): + _ = nisl.make_map("{ [x] -> [x] }") def test_map_empty_from_space_preserves_names_and_is_empty() -> None: @@ -519,7 +729,7 @@ def test_map_as_pw_multi_aff(): m = nisl.make_map(spec) m_isl = isl.Map(spec) - assert m.as_pw_multi_aff()._reconstruct_isl_object() == m_isl.as_pw_multi_aff() + assert m.as_pw_multi_aff() == m_isl.as_pw_multi_aff() @pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) @@ -530,25 +740,25 @@ def test_map_dim_max(ndims_domain: int, ndims_range: int): ndims_range, "x_out", None ) - # unnamed, so use isl.PwAff instead of nisl.make_pw_aff + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. in_upper_bound_pw_maffs = [ - isl.PwAff(f"{{ [] -> [{int(cond.split('<')[2].strip(' '))}] }}") + isl.PwAff(f"{{ [{int(cond.split('<')[2].strip(' '))}] }}") for cond in in_conds.split("and") ] for i, name in enumerate(in_names.split(",")): # NOTE: constructing PwAffs assumes starting index of 0, so subtract 1 - assert m.dim_max(name)._obj == (in_upper_bound_pw_maffs[i] - 1) + assert m.dim_max(name) == (in_upper_bound_pw_maffs[i] - 1) - # unnamed, so use isl.PwAff instead of nisl.make_pw_aff + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. out_upper_bound_pw_maffs = [ - isl.PwAff(f"{{ [] -> [{int(cond.split('<')[2].strip(' '))}] }}") + isl.PwAff(f"{{ [{int(cond.split('<')[2].strip(' '))}] }}") for cond in out_conds.split("and") ] for i, name in enumerate(out_names.split(",")): # NOTE: constructing PwAffs assumes starting index of 0, so subtract 1 - assert m.dim_max(name)._obj == (out_upper_bound_pw_maffs[i] - 1) + assert m.dim_max(name) == (out_upper_bound_pw_maffs[i] - 1) @pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) @@ -559,23 +769,23 @@ def test_map_dim_min(ndims_domain: int, ndims_range: int): ndims_range, "x_out", None ) - # unnamed, so use isl.PwAff instead of nisl.make_pw_aff + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. in_lower_bound_pw_maffs = [ - isl.PwAff(f"{{ [] -> [{int(cond.split('<')[0].strip(' '))}] }}") + isl.PwAff(f"{{ [{int(cond.split('<')[0].strip(' '))}] }}") for cond in in_conds.split("and") ] for i, name in enumerate(in_names.split(",")): - assert m.dim_min(name)._obj == in_lower_bound_pw_maffs[i] + assert m.dim_min(name) == in_lower_bound_pw_maffs[i] - # unnamed, so use isl.PwAff instead of nisl.make_pw_aff + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. out_lower_bound_pw_maffs = [ - isl.PwAff(f"{{ [] -> [{int(cond.split('<')[0].strip(' '))}] }}") + isl.PwAff(f"{{ [{int(cond.split('<')[0].strip(' '))}] }}") for cond in out_conds.split("and") ] for i, name in enumerate(out_names.split(",")): - assert m.dim_min(name)._obj == out_lower_bound_pw_maffs[i] + assert m.dim_min(name) == out_lower_bound_pw_maffs[i] # }}} From 27d46eb3ace8995d4d0309631cfc45ee6ad7a17a Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 22 Apr 2026 15:43:29 -0500 Subject: [PATCH 29/43] clean up typing --- namedisl/set_like.py | 135 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 109 insertions(+), 26 deletions(-) diff --git a/namedisl/set_like.py b/namedisl/set_like.py index a8fa4e1..a627716 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -57,7 +57,7 @@ class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): [ (set names), (input names), (parameter names) ] """ - def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: + def complement(self: Self) -> Self: return replace( self, _obj=self._obj.complement(), @@ -65,7 +65,7 @@ def complement(self: _NamedIslSetLike) -> _NamedIslSetLike: _dimtype_to_names=self._dimtype_to_names ) - def eliminate(self, names_to_eliminate: str | Collection[str]) -> _NamedIslSetLike: + def eliminate(self: Self, names_to_eliminate: str | Collection[str]) -> Self: if isinstance(names_to_eliminate, str): names_to_eliminate = [names_to_eliminate] @@ -146,6 +146,22 @@ def add_constraint( _dimtype_to_names=self._dimtype_to_names ) + @overload + def gist(self: BasicMap, context: _NamedIslSetLike) -> BasicMap | Map: + ... + + @overload + def gist(self: Map, context: _NamedIslSetLike) -> Map: + ... + + @overload + def gist(self: BasicSet, context: _NamedIslSetLike) -> BasicSet | Set: + ... + + @overload + def gist(self: Set, context: _NamedIslSetLike) -> Set: + ... + def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: self_aligned, context_aligned = _align_two(self, context) result = self_aligned._obj.gist(context_aligned._obj) @@ -165,8 +181,8 @@ def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: self_aligned._dimtype_to_names ) - def project_out(self: _NamedIslSetLike, - names_to_project_out: str | Collection[str]) -> _NamedIslSetLike: + def project_out(self: Self, + names_to_project_out: str | Collection[str]) -> Self: if isinstance(names_to_project_out, str): names_to_project_out = [names_to_project_out] @@ -212,9 +228,9 @@ def project_out(self: _NamedIslSetLike, ) def project_out_except( - self: _NamedIslSetLike, + self: Self, names_to_keep: str | Collection[str], - ) -> _NamedIslSetLike: + ) -> Self: if isinstance(names_to_keep, str): names_to_keep = [names_to_keep] if names_to_keep else [] @@ -243,15 +259,62 @@ def dim(self, dim_type: isl.dim_type) -> int: dim_type = isl.dim_type.set return super().dim(dim_type) - # FIXME: basedpyright is not happy with these function signatures + @overload + def __and__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: + ... + + @overload + def __and__(self: Map, other: BasicMap | Map) -> Map: + ... + + @overload + def __and__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: + ... + + @overload + def __and__(self: Set, other: BasicSet | Set) -> Set: + ... + def __and__( self, other: _NamedIslSetLike) -> _NamedIslSetLike: return _apply_set_like_binary_op(self, other, operator.and_) + @overload + def __or__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: + ... + + @overload + def __or__(self: Map, other: BasicMap | Map) -> Map: + ... + + @overload + def __or__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: + ... + + @overload + def __or__(self: Set, other: BasicSet | Set) -> Set: + ... + def __or__( self, other: _NamedIslSetLike) -> _NamedIslSetLike: return _apply_set_like_binary_op(self, other, operator.or_) + @overload + def __sub__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: + ... + + @overload + def __sub__(self: Map, other: BasicMap | Map) -> Map: + ... + + @overload + def __sub__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: + ... + + @overload + def __sub__(self: Set, other: BasicSet | Set) -> Set: + ... + def __sub__( self, other: _NamedIslSetLike) -> _NamedIslSetLike: return _apply_set_like_binary_op(self, other, operator.sub) @@ -300,6 +363,9 @@ def _reconstruct_isl_object(self) -> isl.BasicSet: return obj.get_basic_sets()[0] + def get_basic_sets(self) -> Sequence[BasicSet]: + return [self] + @overload def make_basic_set(src: str, ctx: isl.Context | None = None) -> BasicSet: @@ -338,6 +404,42 @@ def get_basic_sets(self) -> Sequence[BasicSet]: return [make_basic_set(bset) for bset in bsets] +@overload +def _apply_set_like_binary_op( + lhs: BasicMap, + rhs: BasicMap | Map, + op: Callable[[isl.Set, isl.Set], isl.Set] + ) -> BasicMap | Map: + ... + + +@overload +def _apply_set_like_binary_op( + lhs: Map, + rhs: BasicMap | Map, + op: Callable[[isl.Set, isl.Set], isl.Set] + ) -> Map: + ... + + +@overload +def _apply_set_like_binary_op( + lhs: BasicSet, + rhs: BasicSet | Set, + op: Callable[[isl.Set, isl.Set], isl.Set] + ) -> BasicSet | Set: + ... + + +@overload +def _apply_set_like_binary_op( + lhs: Set, + rhs: BasicSet | Set, + op: Callable[[isl.Set, isl.Set], isl.Set] + ) -> Set: + ... + + def _apply_set_like_binary_op( lhs: _NamedIslSetLike, rhs: _NamedIslSetLike, @@ -563,17 +665,6 @@ def _map_obj(self) -> isl.BasicMap: assert isinstance(obj, isl.BasicMap) return obj - @staticmethod - @override - def _wrap_map_result(result: isl.BasicMap | isl.Map) -> BasicMap | Map: - if isinstance(result, isl.BasicMap): - return make_basic_map(result) - return make_map(result) - - @override - def reverse(self) -> BasicMap | Map: - return self._wrap_map_result(self._map_obj().reverse()) - @override def domain(self) -> BasicSet: return make_basic_set(self._map_obj().domain()) @@ -608,14 +699,6 @@ def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: return cast("BasicMap | Map", self & filter_map) return super().intersect_range(range_) - @override - def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: - return super().apply_range(other) - - @override - def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: - return super().apply_domain(other) - @override def _reconstruct_isl_object(self) -> isl.BasicMap: obj = super()._reconstruct_isl_object() From 90ab98c18e2698896f055b255745eaac7551ab6d Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 23 Apr 2026 14:42:55 -0500 Subject: [PATCH 30/43] some chunking + typing fixes --- namedisl/core.py | 96 +++++++++++++++++++++----------------------- namedisl/set_like.py | 55 ++++++++++++++++++------- 2 files changed, 86 insertions(+), 65 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index f9ebb98..ead445e 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -550,6 +550,40 @@ def _ordered_name_chunks(self) -> dict[isl.dim_type, tuple[str, ...]]: isl.dim_type.param: self._ordered_names_for_dim_type(isl.dim_type.param), } + def _empty_grouped_names(self) -> dict[isl.dim_type, list[str]]: + return { + isl.dim_type.set: [], + isl.dim_type.in_: [], + isl.dim_type.param: [], + } + + def _metadata_from_chunk_names( + self, + chunk_names: Mapping[isl.dim_type, Collection[str]], + *, + has_inputs: bool + ) -> tuple[NameToDim, DimTypeToNames]: + ordered_names = [ + *chunk_names[isl.dim_type.set], + *chunk_names[isl.dim_type.in_], + *chunk_names[isl.dim_type.param], + ] + new_name_to_dim: NameToDim = constantdict({ + name: dim for dim, name in enumerate(ordered_names) + }) + + new_dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} + if has_inputs and chunk_names[isl.dim_type.in_]: + new_dimtype_to_names[isl.dim_type.in_] = frozenset( + chunk_names[isl.dim_type.in_] + ) + if chunk_names[isl.dim_type.param]: + new_dimtype_to_names[isl.dim_type.param] = frozenset( + chunk_names[isl.dim_type.param] + ) + + return new_name_to_dim, constantdict(new_dimtype_to_names) + def _add_names_by_dim_type( self, names_to_add: Collection[str], @@ -577,11 +611,7 @@ def _add_names_by_dim_type( if not names_to_add: return self - grouped_names: dict[isl.dim_type, list[str]] = { - isl.dim_type.set: [], - isl.dim_type.in_: [], - isl.dim_type.param: [], - } + grouped_names = self._empty_grouped_names() grouped_names[dim_type] = list(names_to_add) return self._add_grouped_names(grouped_names) @@ -633,39 +663,19 @@ def _add_grouped_names( ) chunk_names[dim_type] = [*names_to_add, *chunk_names[dim_type]] - ordered_names = [ - *chunk_names[isl.dim_type.set], - *chunk_names[isl.dim_type.in_], - *chunk_names[isl.dim_type.param], - ] - new_name_to_dim: NameToDim = constantdict({ - name: dim for dim, name in enumerate(ordered_names) - }) - new_dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} - if ( - isinstance(new_obj, IslSetLike | IslMultiExpressionLike) - and chunk_names[isl.dim_type.in_] - ): - new_dimtype_to_names[isl.dim_type.in_] = frozenset( - chunk_names[isl.dim_type.in_] - ) - if chunk_names[isl.dim_type.param]: - new_dimtype_to_names[isl.dim_type.param] = frozenset( - chunk_names[isl.dim_type.param] - ) + new_name_to_dim, new_dimtype_to_names = self._metadata_from_chunk_names( + chunk_names, + has_inputs=isinstance(new_obj, IslSetLike | IslMultiExpressionLike), + ) return type(self)( cast("IslObjectT", _restore_names(new_obj, new_name_to_dim)), new_name_to_dim, - constantdict(new_dimtype_to_names), + new_dimtype_to_names, ) def add_names(self, tagged_names_to_add: Sequence[_TaggedName]) -> Self: - grouped_names: dict[isl.dim_type, list[str]] = { - isl.dim_type.set: [], - isl.dim_type.in_: [], - isl.dim_type.param: [], - } + grouped_names = self._empty_grouped_names() for tagged_name in tagged_names_to_add: dim_type = _normalize_public_dim_type(tagged_name._isl_dim_type) if dim_type not in grouped_names: @@ -757,29 +767,15 @@ def move_dims( moved_names = sorted(names_to_move, key=self._name_to_dim.__getitem__) chunk_names[dst_type].extend(moved_names) - new_order = [ - *chunk_names[isl.dim_type.set], - *chunk_names[isl.dim_type.in_], - *chunk_names[isl.dim_type.param], - ] - new_name_to_dim: NameToDim = constantdict({ - name: dim for dim, name in enumerate(new_order) - }) - - new_dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} - if chunk_names[isl.dim_type.in_]: - new_dimtype_to_names[isl.dim_type.in_] = frozenset( - chunk_names[isl.dim_type.in_] - ) - if chunk_names[isl.dim_type.param]: - new_dimtype_to_names[isl.dim_type.param] = frozenset( - chunk_names[isl.dim_type.param] - ) + new_name_to_dim, new_dimtype_to_names = self._metadata_from_chunk_names( + chunk_names, + has_inputs=True, + ) return _align_obj( self, new_name_to_dim, - constantdict(new_dimtype_to_names) + new_dimtype_to_names ) def rename_dims(self, renaming: Mapping[str, str]) -> Self: diff --git a/namedisl/set_like.py b/namedisl/set_like.py index a627716..58dd730 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -28,7 +28,7 @@ import operator from abc import ABC from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, cast, final, overload +from typing import TYPE_CHECKING, final, overload from constantdict import constantdict from typing_extensions import Self, override @@ -440,6 +440,15 @@ def _apply_set_like_binary_op( ... +@overload +def _apply_set_like_binary_op( + lhs: _NamedIslSetLike, + rhs: _NamedIslSetLike, + op: Callable[[isl.Set, isl.Set], isl.Set] + ) -> _NamedIslSetLike: + ... + + def _apply_set_like_binary_op( lhs: _NamedIslSetLike, rhs: _NamedIslSetLike, @@ -575,18 +584,30 @@ def _validate_composable( def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: domain_obj = domain._reconstruct_isl_object() assert isinstance(domain_obj, isl.BasicSet | isl.Set) - return cast("BasicMap | Map", self & self._map_with_universe( + result = _apply_set_like_binary_op( + self, + self._map_with_universe( isl.dim_type.in_, domain_obj, - )) + ), + operator.and_, + ) + assert isinstance(result, BasicMap | Map) + return result def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: range_obj = range_._reconstruct_isl_object() assert isinstance(range_obj, isl.BasicSet | isl.Set) - return cast("BasicMap | Map", self & self._map_with_universe( + result = _apply_set_like_binary_op( + self, + self._map_with_universe( isl.dim_type.out, range_obj, - )) + ), + operator.and_, + ) + assert isinstance(result, BasicMap | Map) + return result def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: ordered_names = self._validate_composable( @@ -594,12 +615,12 @@ def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: other, isl.dim_type.in_ ) - other = cast("BasicMap | Map", other._reorder_interface( - isl.dim_type.in_, ordered_names)) + reordered_other = other._reorder_interface(isl.dim_type.in_, ordered_names) + assert isinstance(reordered_other, BasicMap | Map) self._reject_surviving_name_collisions( - self.input_names & other._output_names() + self.input_names & reordered_other._output_names() ) - result = self._map_obj().apply_range(other._map_obj()) + result = self._map_obj().apply_range(reordered_other._map_obj()) return self._wrap_map_result(result) def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: @@ -608,12 +629,12 @@ def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: other, isl.dim_type.out ) - other = cast("BasicMap | Map", other._reorder_interface( - isl.dim_type.out, ordered_names)) + reordered_other = other._reorder_interface(isl.dim_type.out, ordered_names) + assert isinstance(reordered_other, BasicMap | Map) self._reject_surviving_name_collisions( - other.input_names & self._output_names() + reordered_other.input_names & self._output_names() ) - result = other._map_obj().apply_range(self._map_obj()) + result = reordered_other._map_obj().apply_range(self._map_obj()) return self._wrap_map_result(result) def reverse(self) -> BasicMap | Map: @@ -683,7 +704,9 @@ def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: isl.BasicSet.universe(range_space) ) ) - return cast("BasicMap | Map", self & filter_map) + result = self & filter_map + assert isinstance(result, BasicMap | Map) + return result return super().intersect_domain(domain) @override @@ -696,7 +719,9 @@ def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: range_._reconstruct_isl_object() ) ) - return cast("BasicMap | Map", self & filter_map) + result = self & filter_map + assert isinstance(result, BasicMap | Map) + return result return super().intersect_range(range_) @override From 26a4f8a4b77244f383b60cfaec36af65fa999b13 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 19 May 2026 15:33:58 -0500 Subject: [PATCH 31/43] more setlike operations, tests --- namedisl/core.py | 245 ++++++++++++---------------- namedisl/set_like.py | 289 ++++++++++++++------------------- namedisl/test/test_set_like.py | 28 ++++ 3 files changed, 251 insertions(+), 311 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index ead445e..b284060 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -42,7 +42,7 @@ isl.dim_type.out, isl.dim_type.in_, isl.dim_type.set, - isl.dim_type.param + isl.dim_type.param, ] @@ -62,18 +62,11 @@ "IslExpressionLikeT", bound=IslExpressionLike, ) -IslSetLikeT = TypeVar( - "IslSetLikeT", - bound=IslSetLike -) +IslSetLikeT = TypeVar("IslSetLikeT", bound=IslSetLike) IslMultiExpressionLikeT = TypeVar( - "IslMultiExpressionLikeT", - bound=IslMultiExpressionLike -) -IslPwExpressionLikeT = TypeVar( - "IslPwExpressionLikeT", - bound=IslPwExpressionLike + "IslMultiExpressionLikeT", bound=IslMultiExpressionLike ) +IslPwExpressionLikeT = TypeVar("IslPwExpressionLikeT", bound=IslPwExpressionLike) IslObjectT = TypeVar("IslObjectT", bound=IslObject, covariant=True) NamedIslObjectT = TypeVar("NamedIslObjectT", bound="NamedIslObject[IslObject]") @@ -134,9 +127,7 @@ def _ensure_unique_public_names(obj: IslObject) -> None: def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: name_to_dim: dict[str, int] = {} - dt_to_strip = ( - isl.dim_type.set if isinstance(obj, IslSetLike) else isl.dim_type.in_ - ) + dt_to_strip = isl.dim_type.set if isinstance(obj, IslSetLike) else isl.dim_type.in_ stripped_obj = obj.copy() @@ -170,9 +161,8 @@ def _get_obj_dim_name(obj: IslObject, dt: isl.dim_type, dim: int) -> str: def _normalize_dimtype_to_names( - obj: IslObject, - dimtype_to_names: DimTypeToNames - ) -> DimTypeToNames: + obj: IslObject, dimtype_to_names: DimTypeToNames +) -> DimTypeToNames: if isinstance(obj, IslSetLike | IslMultiExpressionLike): dim_type = isl.dim_type.set total_dims = obj.dim(dim_type) @@ -229,24 +219,23 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: 0, isl.dim_type.in_, 0, - restored_obj.dim(isl.dim_type.in_) + restored_obj.dim(isl.dim_type.in_), ) for name, dim in name_to_dim.items(): - restored_obj = restored_obj.set_dim_name( - isl.dim_type.param, - dim, - name - ) + restored_obj = restored_obj.set_dim_name(isl.dim_type.param, dim, name) restored_obj = restored_obj.get_pw_aff_list().get_at(0) - return cast("IslObjectT", restored_obj.move_dims( - isl.dim_type.in_, - 0, - isl.dim_type.param, - 0, - restored_obj.dim(isl.dim_type.param) - )) + return cast( + "IslObjectT", + restored_obj.move_dims( + isl.dim_type.in_, + 0, + isl.dim_type.param, + 0, + restored_obj.dim(isl.dim_type.param), + ), + ) if isinstance(restored_obj, IslSetLike): dt_to_restore = isl.dim_type.set @@ -265,7 +254,6 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: def _get_dim_names(obj: IslObject, dt: isl.dim_type) -> frozenset[str]: all_dt_names: list[str] = [] for dim in range(obj.dim(dt)): - if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): name = obj.space.get_dim_name(dt, dim) else: @@ -280,19 +268,16 @@ def _get_dim_names(obj: IslObject, dt: isl.dim_type) -> frozenset[str]: @overload -def _deconstruct_object(obj: isl.Map) -> tuple[isl.Set, DimTypeToNames]: - ... +def _deconstruct_object(obj: isl.Map) -> tuple[isl.Set, DimTypeToNames]: ... @overload # PwMultiAff doesn't have move_dims, so we're being a bit crooked here. -def _deconstruct_object(obj: isl.PwMultiAff) -> tuple[isl.Set, DimTypeToNames]: - ... +def _deconstruct_object(obj: isl.PwMultiAff) -> tuple[isl.Set, DimTypeToNames]: ... @overload -def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: - ... +def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: ... def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: @@ -300,9 +285,7 @@ def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: if isinstance(obj, IslSetLike | IslMultiExpressionLike): decon_obj = obj - dt_to_names = dict.fromkeys( - [isl.dim_type.in_, isl.dim_type.param], frozenset() - ) + dt_to_names = dict.fromkeys([isl.dim_type.in_, isl.dim_type.param], frozenset()) # NOTE: isl.PwMultiAff.move_dims does not exist, represent as map # internally @@ -317,7 +300,7 @@ def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: decon_obj.dim(isl.dim_type.set), dt, 0, - decon_obj.dim(dt) + decon_obj.dim(dt), ) decon_obj = ( @@ -336,15 +319,14 @@ def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: decon_obj = obj dt_to_names = dict.fromkeys([isl.dim_type.param], frozenset()) - dt_to_names[isl.dim_type.param] = _get_dim_names(decon_obj, - isl.dim_type.param) + dt_to_names[isl.dim_type.param] = _get_dim_names(decon_obj, isl.dim_type.param) decon_obj = decon_obj.move_dims( isl.dim_type.in_, decon_obj.dim(isl.dim_type.in_), isl.dim_type.param, 0, - decon_obj.dim(isl.dim_type.param) + decon_obj.dim(isl.dim_type.param), ) return decon_obj, constantdict(dt_to_names) @@ -365,6 +347,7 @@ def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: count = 1 from itertools import pairwise + for prev, curr in pairwise(dims): if curr == prev + 1: count += 1 @@ -381,27 +364,24 @@ def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: # FIXME: think through whether or not alphabetical ordering will require more # work on average than using one of the objects as a template in alignment def _find_joint_name_to_dim( - obj1: NamedIslObject[IslObjectT], - obj2: NamedIslObject[IslObjectT] - ) -> tuple[NameToDim, DimTypeToNames]: + obj1: NamedIslObject[IslObjectT], obj2: NamedIslObject[IslObjectT] +) -> tuple[NameToDim, DimTypeToNames]: """ Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and :arg:`obj2` within each dimension-type chunk. This ordering is used in alignment before performing operations between two set-like objects. """ - all_set_names = ( - obj1._names_for_dim_type(isl.dim_type.set) - | obj2._names_for_dim_type(isl.dim_type.set) - ) + all_set_names = obj1._names_for_dim_type( + isl.dim_type.set + ) | obj2._names_for_dim_type(isl.dim_type.set) all_inp_names = obj1.input_names | obj2.input_names all_param_names = obj1.parameter_names | obj2.parameter_names dt_to_names: DimTypeToNames = {} dt_to_names[isl.dim_type.param] = all_param_names - if ( - isinstance(obj1._obj, IslSetLike | IslMultiExpressionLike) - or isinstance(obj2._obj, IslSetLike | IslMultiExpressionLike) - ): + if isinstance(obj1._obj, IslSetLike | IslMultiExpressionLike) or isinstance( + obj2._obj, IslSetLike | IslMultiExpressionLike + ): dt_to_names[isl.dim_type.in_] = all_inp_names # enforces contiguous ordering of [ (set), (input), (param) ] in set @@ -420,10 +400,8 @@ def _find_joint_name_to_dim( def _align_obj( - named_obj: NamedIslObjectT, - ordering: NameToDim, - dimtype_to_names: DimTypeToNames - ) -> NamedIslObjectT: + named_obj: NamedIslObjectT, ordering: NameToDim, dimtype_to_names: DimTypeToNames +) -> NamedIslObjectT: new_isl_obj = cast( "IslSetLike | IslBaseExpressionLike | IslPwExpressionLike | isl.MultiAff", named_obj._obj, @@ -431,9 +409,7 @@ def _align_obj( running_name_to_dim = dict(named_obj._name_to_dim) target_dt = ( - isl.dim_type.set - if isinstance(new_isl_obj, IslSetLike) - else isl.dim_type.in_ + isl.dim_type.set if isinstance(new_isl_obj, IslSetLike) else isl.dim_type.in_ ) for name, target_dim in sorted(ordering.items(), key=lambda x: x[1]): @@ -446,13 +422,11 @@ def _align_obj( # temporarily move to parameter dimension since destination and # source dim types cannot match in ISL new_isl_obj = new_isl_obj.move_dims( - isl.dim_type.param, 0, - target_dt, old_dim, 1 + isl.dim_type.param, 0, target_dt, old_dim, 1 ) new_isl_obj = new_isl_obj.move_dims( - target_dt, target_dim, - isl.dim_type.param, 0, 1 + target_dt, target_dim, isl.dim_type.param, 0, 1 ) else: @@ -478,12 +452,10 @@ def _align_obj( def _align_two( - named_obj1: NamedIslObjectT, - named_obj2: NamedIslObjectT - ) -> tuple[NamedIslObjectT, ...]: + named_obj1: NamedIslObjectT, named_obj2: NamedIslObjectT +) -> tuple[NamedIslObjectT, ...]: - name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, - named_obj2) + name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, named_obj2) named_obj1 = _align_obj(named_obj1, name_to_dim, dimtype_to_names) named_obj2 = _align_obj(named_obj2, name_to_dim, dimtype_to_names) @@ -492,10 +464,10 @@ def _align_two( def _align_and_apply_binary_op( - lhs: NamedIslObject[IslObjectT], - rhs: NamedIslObject[IslObjectT], - op: Callable[[IslObjectT, IslObjectT], IslObjectT] - ) -> NamedIslObject[IslObjectT]: + lhs: NamedIslObject[IslObjectT], + rhs: NamedIslObject[IslObjectT], + op: Callable[[IslObjectT, IslObjectT], IslObjectT], +) -> NamedIslObject[IslObjectT]: lhs, rhs = _align_two(lhs, rhs) result = op(lhs._obj, rhs._obj) @@ -558,11 +530,8 @@ def _empty_grouped_names(self) -> dict[isl.dim_type, list[str]]: } def _metadata_from_chunk_names( - self, - chunk_names: Mapping[isl.dim_type, Collection[str]], - *, - has_inputs: bool - ) -> tuple[NameToDim, DimTypeToNames]: + self, chunk_names: Mapping[isl.dim_type, Collection[str]], *, has_inputs: bool + ) -> tuple[NameToDim, DimTypeToNames]: ordered_names = [ *chunk_names[isl.dim_type.set], *chunk_names[isl.dim_type.in_], @@ -585,10 +554,8 @@ def _metadata_from_chunk_names( return new_name_to_dim, constantdict(new_dimtype_to_names) def _add_names_by_dim_type( - self, - names_to_add: Collection[str], - dim_type: isl.dim_type - ) -> Self: + self, names_to_add: Collection[str], dim_type: isl.dim_type + ) -> Self: if isinstance(self._obj, isl.PwMultiAff): raise NotImplementedError @@ -596,9 +563,9 @@ def _add_names_by_dim_type( if dim_type not in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): raise ValueError(f"unsupported dim type: {dim_type}") if ( - not isinstance(self._obj, IslSetLike | IslMultiExpressionLike) - and dim_type == isl.dim_type.set - ): + not isinstance(self._obj, IslSetLike | IslMultiExpressionLike) + and dim_type == isl.dim_type.set + ): raise ValueError(f"unsupported dim type: {dim_type}") if len(set(names_to_add)) != len(tuple(names_to_add)): @@ -617,9 +584,8 @@ def _add_names_by_dim_type( return self._add_grouped_names(grouped_names) def _add_grouped_names( - self, - grouped_names: Mapping[isl.dim_type, Collection[str]] - ) -> Self: + self, grouped_names: Mapping[isl.dim_type, Collection[str]] + ) -> Self: if isinstance(self._obj, isl.PwMultiAff): raise NotImplementedError @@ -634,8 +600,7 @@ def _add_grouped_names( new_obj = self._obj chunk_names = { - dt: list(names) - for dt, names in self._ordered_name_chunks().items() + dt: list(names) for dt, names in self._ordered_name_chunks().items() } internal_dim_type = ( isl.dim_type.set @@ -647,8 +612,7 @@ def _add_grouped_names( isl.dim_type.set: 0, isl.dim_type.in_: len(chunk_names[isl.dim_type.set]), isl.dim_type.param: ( - len(chunk_names[isl.dim_type.set]) - + len(chunk_names[isl.dim_type.in_]) + len(chunk_names[isl.dim_type.set]) + len(chunk_names[isl.dim_type.in_]) ), } @@ -657,9 +621,7 @@ def _add_grouped_names( if not names_to_add: continue new_obj = new_obj.insert_dims( - internal_dim_type, - insertion_starts[dim_type], - len(names_to_add) + internal_dim_type, insertion_starts[dim_type], len(names_to_add) ) chunk_names[dim_type] = [*names_to_add, *chunk_names[dim_type]] @@ -697,10 +659,8 @@ def add_parameter_names(self, names_to_add: Collection[str]) -> Self: return self._add_names_by_dim_type(names_to_add, isl.dim_type.param) def add_dim_names( - self, - names_to_add: Collection[str], - dim_type: isl.dim_type - ) -> Self: + self, names_to_add: Collection[str], dim_type: isl.dim_type + ) -> Self: return self._add_names_by_dim_type(names_to_add, dim_type) @property @@ -710,6 +670,9 @@ def names(self) -> frozenset[str]: def dim_names(self, dim_type: isl.dim_type) -> frozenset[str]: return self._names_for_dim_type(dim_type) + def ordered_dim_names(self, dim_type: isl.dim_type) -> tuple[str, ...]: + return self._ordered_names_for_dim_type(dim_type) + @property def set_names(self) -> frozenset[str]: return self._names_for_dim_type(isl.dim_type.set) @@ -728,10 +691,10 @@ def dim(self, dim_type: isl.dim_type) -> int: return self._reconstruct_isl_object().dim(dim_type) def move_dims( - self, - names_to_move: str | Collection[str], - dst_type: isl.dim_type, - ) -> Self: + self, + names_to_move: str | Collection[str], + dst_type: isl.dim_type, + ) -> Self: if isinstance(names_to_move, str): names_to_move = [names_to_move] @@ -750,7 +713,8 @@ def move_dims( raise ValueError("duplicate names in move_dims") names_to_move = [ - name for name in names_to_move + name + for name in names_to_move if name not in self._names_for_dim_type(dst_type) ] if not names_to_move: @@ -758,10 +722,7 @@ def move_dims( moved_name_set = set(names_to_move) chunk_names = { - dt: [ - name for name in names - if name not in moved_name_set - ] + dt: [name for name in names if name not in moved_name_set] for dt, names in self._ordered_name_chunks().items() } moved_names = sorted(names_to_move, key=self._name_to_dim.__getitem__) @@ -772,11 +733,7 @@ def move_dims( has_inputs=True, ) - return _align_obj( - self, - new_name_to_dim, - new_dimtype_to_names - ) + return _align_obj(self, new_name_to_dim, new_dimtype_to_names) def rename_dims(self, renaming: Mapping[str, str]) -> Self: if not renaming: @@ -790,8 +747,7 @@ def rename_dims(self, renaming: Mapping[str, str]) -> Self: raise ValueError("duplicate destination names in rename_dims") unchanged_names = { - old_name for old_name, new_name in renaming.items() - if old_name == new_name + old_name for old_name, new_name in renaming.items() if old_name == new_name } renaming = { old_name: new_name @@ -810,32 +766,26 @@ def rename_dims(self, renaming: Mapping[str, str]) -> Self: ) new_name_to_dim: NameToDim = constantdict({ - renaming.get(name, name): dim - for name, dim in self._name_to_dim.items() + renaming.get(name, name): dim for name, dim in self._name_to_dim.items() }) new_dimtype_to_names: DimTypeToNames = constantdict({ - dim_type: frozenset( - renaming.get(name, name) - for name in names - ) + dim_type: frozenset(renaming.get(name, name) for name in names) for dim_type, names in self._dimtype_to_names.items() }) return type(self)(self._obj, new_name_to_dim, new_dimtype_to_names) @overload - def equate_dims(self, name1: Mapping[str, str]) -> Self: - ... + def equate_dims(self, name1: Mapping[str, str]) -> Self: ... @overload - def equate_dims(self, name1: str, name2: str) -> Self: - ... + def equate_dims(self, name1: str, name2: str) -> Self: ... def equate_dims( - self, - name1: str | Mapping[str, str], - name2: str | None = None, - ) -> Self: + self, + name1: str | Mapping[str, str], + name2: str | None = None, + ) -> Self: if isinstance(name1, str): if name2 is None: raise TypeError("name2 must be provided when name1 is a string") @@ -863,8 +813,10 @@ def equate_dims( for lhs_name, rhs_name in equated_names: if lhs_name != rhs_name: obj = obj.equate( - isl.dim_type.set, self._name_to_dim[lhs_name], - isl.dim_type.set, self._name_to_dim[rhs_name], + isl.dim_type.set, + self._name_to_dim[lhs_name], + isl.dim_type.set, + self._name_to_dim[rhs_name], ) return type(self)( @@ -884,10 +836,7 @@ def input_names(self) -> frozenset[str]: @property def _input_dim_start(self) -> int | None: if self._has_inputs: - return min( - self._name_to_dim[name] - for name in self._metadata_input_names - ) + return min(self._name_to_dim[name] for name in self._metadata_input_names) return None @property @@ -902,8 +851,7 @@ def parameter_names(self) -> frozenset[str]: def _parameter_dim_start(self) -> int | None: if self._has_params: return min( - self._name_to_dim[name] - for name in self._metadata_parameter_names + self._name_to_dim[name] for name in self._metadata_parameter_names ) return None @@ -926,12 +874,16 @@ def _reconstruct_isl_object(self) -> IslObject: raise ValueError( "Object has parameter dimensions, but a starting index for " "parameter names is not given. Reconstruction is not " - "possible") + "possible" + ) param_start = self._parameter_dim_start obj = obj.move_dims( - isl.dim_type.param, 0, - internal_dim, param_start, len(self.parameter_names) + isl.dim_type.param, + 0, + internal_dim, + param_start, + len(self.parameter_names), ) if self._has_inputs: @@ -939,7 +891,8 @@ def _reconstruct_isl_object(self) -> IslObject: raise ValueError( "Object has input dimensions, but a starting index for " "input names is not given. Reconstruction is not " - "possible") + "possible" + ) obj_domain = isl.Set("{ [] }") obj_range = obj @@ -949,12 +902,14 @@ def _reconstruct_isl_object(self) -> IslObject: inp_start = self._input_dim_start obj = obj.move_dims( - isl.dim_type.in_, 0, - internal_dim, inp_start, len(self.input_names) + isl.dim_type.in_, 0, internal_dim, inp_start, len(self.input_names) ) return obj + def get_isl_object(self) -> IslObject: + return self._reconstruct_isl_object() + @override def __str__(self) -> str: return str(self._reconstruct_isl_object()) diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 58dd730..836e563 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -62,16 +62,36 @@ def complement(self: Self) -> Self: self, _obj=self._obj.complement(), _name_to_dim=self._name_to_dim, - _dimtype_to_names=self._dimtype_to_names + _dimtype_to_names=self._dimtype_to_names, ) + @overload + def convex_hull(self: BasicMap | Map) -> BasicMap: ... + + @overload + def convex_hull(self: BasicSet | Set) -> BasicSet: ... + + def convex_hull(self) -> BasicMap | BasicSet: + obj = self._reconstruct_isl_object() + assert isinstance(obj, isl.BasicMap | isl.Map | isl.BasicSet | isl.Set) + + if isinstance(obj, isl.BasicMap): + return make_basic_map(obj.to_map().convex_hull()) + + if isinstance(obj, isl.Map): + return make_basic_map(obj.convex_hull()) + + if isinstance(obj, isl.BasicSet): + return make_basic_set(obj.to_set().convex_hull()) + + return make_basic_set(obj.convex_hull()) + def eliminate(self: Self, names_to_eliminate: str | Collection[str]) -> Self: if isinstance(names_to_eliminate, str): names_to_eliminate = [names_to_eliminate] dims_to_eliminate = sorted( - self._name_to_dim[name] - for name in names_to_eliminate + self._name_to_dim[name] for name in names_to_eliminate ) contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_eliminate) @@ -86,13 +106,13 @@ def eliminate(self: Self, names_to_eliminate: str | Collection[str]) -> Self: self, _obj=new_isl_obj, _name_to_dim=self._name_to_dim, # NOTE: no dims removed by eliminate - _dimtype_to_names=self._dimtype_to_names + _dimtype_to_names=self._dimtype_to_names, ) def add_constraint( - self: Self, - constraints: str | Collection[str], - ) -> Self: + self: Self, + constraints: str | Collection[str], + ) -> Self: if isinstance(constraints, str): constraints = [constraints] else: @@ -104,21 +124,14 @@ def add_constraint( ordered_names = tuple( sorted(self._name_to_dim, key=self._name_to_dim.__getitem__) ) - constraint_text = " and ".join(f"({constraint})" - for constraint in constraints) - constraint_src = ( - "{ " - f"[{', '.join(ordered_names)}] : " - f"{constraint_text}" - " }" - ) + constraint_text = " and ".join(f"({constraint})" for constraint in constraints) + constraint_src = f"{{ [{', '.join(ordered_names)}] : {constraint_text} }}" try: constraint_obj = isl.Set(constraint_src) except isl.Error as exc: raise ValueError( - f"invalid constraint for names {ordered_names}: " - f"{constraint_text}" + f"invalid constraint for names {ordered_names}: {constraint_text}" ) from exc constraint_obj = constraint_obj.remove_redundancies() @@ -129,13 +142,9 @@ def add_constraint( if constraint_name_to_dim != self._name_to_dim: constraint_set = _align_obj( - Set( - constraint_set, - constraint_name_to_dim, - self._dimtype_to_names - ), + Set(constraint_set, constraint_name_to_dim, self._dimtype_to_names), self._name_to_dim, - self._dimtype_to_names + self._dimtype_to_names, )._obj assert isinstance(constraint_set, isl.Set) @@ -143,24 +152,20 @@ def add_constraint( self, _obj=self._obj.intersect(constraint_set), _name_to_dim=self._name_to_dim, - _dimtype_to_names=self._dimtype_to_names + _dimtype_to_names=self._dimtype_to_names, ) @overload - def gist(self: BasicMap, context: _NamedIslSetLike) -> BasicMap | Map: - ... + def gist(self: BasicMap, context: _NamedIslSetLike) -> BasicMap | Map: ... @overload - def gist(self: Map, context: _NamedIslSetLike) -> Map: - ... + def gist(self: Map, context: _NamedIslSetLike) -> Map: ... @overload - def gist(self: BasicSet, context: _NamedIslSetLike) -> BasicSet | Set: - ... + def gist(self: BasicSet, context: _NamedIslSetLike) -> BasicSet | Set: ... @overload - def gist(self: Set, context: _NamedIslSetLike) -> Set: - ... + def gist(self: Set, context: _NamedIslSetLike) -> Set: ... def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: self_aligned, context_aligned = _align_two(self, context) @@ -176,23 +181,17 @@ def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: result_type = Set return result_type( - result, - self_aligned._name_to_dim, - self_aligned._dimtype_to_names + result, self_aligned._name_to_dim, self_aligned._dimtype_to_names ) - def project_out(self: Self, - names_to_project_out: str | Collection[str]) -> Self: + def project_out(self: Self, names_to_project_out: str | Collection[str]) -> Self: if isinstance(names_to_project_out, str): names_to_project_out = [names_to_project_out] names_to_remove = set(names_to_project_out) - dims_to_remove = sorted( - self._name_to_dim[name] - for name in names_to_remove - ) + dims_to_remove = sorted(self._name_to_dim[name] for name in names_to_remove) new_isl_obj = self._obj contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_remove) @@ -224,7 +223,7 @@ def project_out(self: Self, self, _obj=new_isl_obj, _name_to_dim=constantdict(new_name_to_dim), - _dimtype_to_names=new_type_to_names + _dimtype_to_names=new_type_to_names, ) def project_out_except( @@ -236,8 +235,7 @@ def project_out_except( names_to_keep = [names_to_keep] if names_to_keep else [] names_to_project_out = [ - name for name in self._name_to_dim - if name not in names_to_keep + name for name in self._name_to_dim if name not in names_to_keep ] return self.project_out(names_to_project_out) @@ -248,6 +246,11 @@ def dim_max(self, name: str) -> isl.PwAff: def dim_min(self, name: str) -> isl.PwAff: return self._obj.dim_min(self._name_to_dim[name]) + def is_empty(self) -> bool: + obj = self._reconstruct_isl_object() + assert isinstance(obj, isl.Set | isl.Map) + return bool(obj.is_empty()) + def as_pw_multi_aff(self) -> isl.PwMultiAff: obj = self._reconstruct_isl_object() assert isinstance(obj, isl.Set | isl.Map) @@ -260,63 +263,48 @@ def dim(self, dim_type: isl.dim_type) -> int: return super().dim(dim_type) @overload - def __and__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: - ... + def __and__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: ... @overload - def __and__(self: Map, other: BasicMap | Map) -> Map: - ... + def __and__(self: Map, other: BasicMap | Map) -> Map: ... @overload - def __and__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: - ... + def __and__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... @overload - def __and__(self: Set, other: BasicSet | Set) -> Set: - ... + def __and__(self: Set, other: BasicSet | Set) -> Set: ... - def __and__( - self, other: _NamedIslSetLike) -> _NamedIslSetLike: + def __and__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: return _apply_set_like_binary_op(self, other, operator.and_) @overload - def __or__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: - ... + def __or__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: ... @overload - def __or__(self: Map, other: BasicMap | Map) -> Map: - ... + def __or__(self: Map, other: BasicMap | Map) -> Map: ... @overload - def __or__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: - ... + def __or__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... @overload - def __or__(self: Set, other: BasicSet | Set) -> Set: - ... + def __or__(self: Set, other: BasicSet | Set) -> Set: ... - def __or__( - self, other: _NamedIslSetLike) -> _NamedIslSetLike: + def __or__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: return _apply_set_like_binary_op(self, other, operator.or_) @overload - def __sub__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: - ... + def __sub__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: ... @overload - def __sub__(self: Map, other: BasicMap | Map) -> Map: - ... + def __sub__(self: Map, other: BasicMap | Map) -> Map: ... @overload - def __sub__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: - ... + def __sub__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... @overload - def __sub__(self: Set, other: BasicSet | Set) -> Set: - ... + def __sub__(self: Set, other: BasicSet | Set) -> Set: ... - def __sub__( - self, other: _NamedIslSetLike) -> _NamedIslSetLike: + def __sub__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: return _apply_set_like_binary_op(self, other, operator.sub) @override @@ -359,22 +347,18 @@ def _reconstruct_isl_object(self) -> isl.BasicSet: if not isinstance(obj, isl.Set) or obj.n_basic_set() != 1: raise ValueError( "Cannot reconstruct an isl.BasicSet from anything other than " - "an isl.Set containing only a single isl.BasicSet.") + "an isl.Set containing only a single isl.BasicSet." + ) return obj.get_basic_sets()[0] - def get_basic_sets(self) -> Sequence[BasicSet]: - return [self] - @overload -def make_basic_set(src: str, ctx: isl.Context | None = None) -> BasicSet: - ... +def make_basic_set(src: str, ctx: isl.Context | None = None) -> BasicSet: ... @overload -def make_basic_set(src: isl.BasicSet) -> BasicSet: - ... +def make_basic_set(src: isl.BasicSet) -> BasicSet: ... def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> BasicSet: @@ -406,54 +390,41 @@ def get_basic_sets(self) -> Sequence[BasicSet]: @overload def _apply_set_like_binary_op( - lhs: BasicMap, - rhs: BasicMap | Map, - op: Callable[[isl.Set, isl.Set], isl.Set] - ) -> BasicMap | Map: - ... + lhs: BasicMap, rhs: BasicMap | Map, op: Callable[[isl.Set, isl.Set], isl.Set] +) -> BasicMap | Map: ... @overload def _apply_set_like_binary_op( - lhs: Map, - rhs: BasicMap | Map, - op: Callable[[isl.Set, isl.Set], isl.Set] - ) -> Map: - ... + lhs: Map, rhs: BasicMap | Map, op: Callable[[isl.Set, isl.Set], isl.Set] +) -> Map: ... @overload def _apply_set_like_binary_op( - lhs: BasicSet, - rhs: BasicSet | Set, - op: Callable[[isl.Set, isl.Set], isl.Set] - ) -> BasicSet | Set: - ... + lhs: BasicSet, rhs: BasicSet | Set, op: Callable[[isl.Set, isl.Set], isl.Set] +) -> BasicSet | Set: ... @overload def _apply_set_like_binary_op( - lhs: Set, - rhs: BasicSet | Set, - op: Callable[[isl.Set, isl.Set], isl.Set] - ) -> Set: - ... + lhs: Set, rhs: BasicSet | Set, op: Callable[[isl.Set, isl.Set], isl.Set] +) -> Set: ... @overload def _apply_set_like_binary_op( - lhs: _NamedIslSetLike, - rhs: _NamedIslSetLike, - op: Callable[[isl.Set, isl.Set], isl.Set] - ) -> _NamedIslSetLike: - ... + lhs: _NamedIslSetLike, + rhs: _NamedIslSetLike, + op: Callable[[isl.Set, isl.Set], isl.Set], +) -> _NamedIslSetLike: ... def _apply_set_like_binary_op( - lhs: _NamedIslSetLike, - rhs: _NamedIslSetLike, - op: Callable[[isl.Set, isl.Set], isl.Set] - ) -> _NamedIslSetLike: + lhs: _NamedIslSetLike, + rhs: _NamedIslSetLike, + op: Callable[[isl.Set, isl.Set], isl.Set], +) -> _NamedIslSetLike: lhs, rhs = _align_two(lhs, rhs) result = op(lhs._obj, rhs._obj) @@ -470,10 +441,8 @@ def _apply_set_like_binary_op( def _compare_set_like( - lhs: _NamedIslSetLike, - rhs: _NamedIslSetLike, - op: Callable[[isl.Set, isl.Set], bool] - ) -> bool: + lhs: _NamedIslSetLike, rhs: _NamedIslSetLike, op: Callable[[isl.Set, isl.Set], bool] +) -> bool: lhs_is_map = isinstance(lhs, _NamedIslMapLike) rhs_is_map = isinstance(rhs, _NamedIslMapLike) if lhs_is_map != rhs_is_map: @@ -508,10 +477,8 @@ def _wrap_map_result(result: isl.BasicMap | isl.Map) -> BasicMap | Map: return make_map(result) def _map_with_universe( - self, - dim_type: isl.dim_type, - set_obj: isl.BasicSet | isl.Set - ) -> BasicMap | Map: + self, dim_type: isl.dim_type, set_obj: isl.BasicSet | isl.Set + ) -> BasicMap | Map: map_obj = self._map_obj() if dim_type == isl.dim_type.in_: universe = isl.Set.universe(map_obj.range().get_space()) @@ -525,9 +492,9 @@ def _ordered_names(self, names: frozenset[str]) -> tuple[str, ...]: return tuple(sorted(names, key=self._name_to_dim.__getitem__)) def _reject_surviving_name_collisions( - self, - collisions: frozenset[str], - ) -> None: + self, + collisions: frozenset[str], + ) -> None: if collisions: raise ValueError( "composition would create duplicate surviving names: " @@ -535,10 +502,8 @@ def _reject_surviving_name_collisions( ) def _reorder_interface( - self, - dim_type: isl.dim_type, - ordered_names: tuple[str, ...] - ) -> _NamedIslMapLike: + self, dim_type: isl.dim_type, ordered_names: tuple[str, ...] + ) -> _NamedIslMapLike: interface_names = ( self.input_names if dim_type == isl.dim_type.in_ else self._output_names() ) @@ -547,34 +512,37 @@ def _reorder_interface( return self out_names = ( - ordered_names if dim_type == isl.dim_type.out + ordered_names + if dim_type == isl.dim_type.out else self._ordered_names(self._output_names()) ) in_names = ( - ordered_names if dim_type == isl.dim_type.in_ + ordered_names + if dim_type == isl.dim_type.in_ else self._ordered_names(self.input_names) ) param_names = self._ordered_names(self.parameter_names) ordering: NameToDim = constantdict({ - name: dim - for dim, name in enumerate((*out_names, *in_names, *param_names)) + name: dim for dim, name in enumerate((*out_names, *in_names, *param_names)) }) return _align_obj(self, ordering, self._dimtype_to_names) def _validate_composable( - self, - lhs_dim_type: isl.dim_type, - other: BasicMap | Map, - rhs_dim_type: isl.dim_type - ) -> tuple[str, ...]: + self, + lhs_dim_type: isl.dim_type, + other: BasicMap | Map, + rhs_dim_type: isl.dim_type, + ) -> tuple[str, ...]: lhs_names = ( - self.input_names if lhs_dim_type == isl.dim_type.in_ + self.input_names + if lhs_dim_type == isl.dim_type.in_ else self._output_names() ) rhs_names = ( - other.input_names if rhs_dim_type == isl.dim_type.in_ + other.input_names + if rhs_dim_type == isl.dim_type.in_ else other._output_names() ) if lhs_names != rhs_names: @@ -587,8 +555,8 @@ def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: result = _apply_set_like_binary_op( self, self._map_with_universe( - isl.dim_type.in_, - domain_obj, + isl.dim_type.in_, + domain_obj, ), operator.and_, ) @@ -601,8 +569,8 @@ def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: result = _apply_set_like_binary_op( self, self._map_with_universe( - isl.dim_type.out, - range_obj, + isl.dim_type.out, + range_obj, ), operator.and_, ) @@ -611,9 +579,7 @@ def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: ordered_names = self._validate_composable( - isl.dim_type.out, - other, - isl.dim_type.in_ + isl.dim_type.out, other, isl.dim_type.in_ ) reordered_other = other._reorder_interface(isl.dim_type.in_, ordered_names) assert isinstance(reordered_other, BasicMap | Map) @@ -625,9 +591,7 @@ def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: ordered_names = self._validate_composable( - isl.dim_type.in_, - other, - isl.dim_type.out + isl.dim_type.in_, other, isl.dim_type.out ) reordered_other = other._reorder_interface(isl.dim_type.out, ordered_names) assert isinstance(reordered_other, BasicMap | Map) @@ -654,13 +618,11 @@ def range(self) -> BasicSet | Set: @overload -def make_set(src: str, ctx: isl.Context | None = None) -> Set: - ... +def make_set(src: str, ctx: isl.Context | None = None) -> Set: ... @overload -def make_set(src: isl.Set) -> Set: - ... +def make_set(src: isl.Set) -> Set: ... def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: @@ -700,8 +662,7 @@ def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: range_space = self._map_obj().range().get_space() filter_map = make_basic_map( isl.BasicMap.from_domain_and_range( - domain._reconstruct_isl_object(), - isl.BasicSet.universe(range_space) + domain._reconstruct_isl_object(), isl.BasicSet.universe(range_space) ) ) result = self & filter_map @@ -716,7 +677,7 @@ def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: filter_map = make_basic_map( isl.BasicMap.from_domain_and_range( isl.BasicSet.universe(domain_space), - range_._reconstruct_isl_object() + range_._reconstruct_isl_object(), ) ) result = self & filter_map @@ -734,19 +695,18 @@ def _reconstruct_isl_object(self) -> isl.BasicMap: if not isinstance(obj, isl.Map) or obj.n_basic_map() != 1: raise ValueError( "Cannot reconstruct an isl.BasicMap from anything other than " - "an isl.Map containing only a single isl.BasicMap.") + "an isl.Map containing only a single isl.BasicMap." + ) return obj.get_basic_maps()[0] @overload -def make_basic_map(src: str, ctx: isl.Context | None = None) -> BasicMap: - ... +def make_basic_map(src: str, ctx: isl.Context | None = None) -> BasicMap: ... @overload -def make_basic_map(src: isl.BasicMap) -> BasicMap: - ... +def make_basic_map(src: isl.BasicMap) -> BasicMap: ... def make_basic_map(src: str | isl.BasicMap, ctx: isl.Context | None = None) -> BasicMap: @@ -757,9 +717,8 @@ def make_basic_map(src: str | isl.BasicMap, ctx: isl.Context | None = None) -> B def make_map_from_domain_and_range( - domain: BasicSet | Set, - range_: BasicSet | Set - ) -> BasicMap | Map: + domain: BasicSet | Set, range_: BasicSet | Set +) -> BasicMap | Map: if isinstance(domain, BasicSet) and isinstance(range_, BasicSet): domain_obj = domain._reconstruct_isl_object() range_obj = range_._reconstruct_isl_object() @@ -799,13 +758,11 @@ def _reconstruct_isl_object(self) -> isl.Map: @overload -def make_map(src: str, ctx: isl.Context | None = None) -> Map: - ... +def make_map(src: str, ctx: isl.Context | None = None) -> Map: ... @overload -def make_map(src: isl.Map) -> Map: - ... +def make_map(src: isl.Map) -> Map: ... def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: diff --git a/namedisl/test/test_set_like.py b/namedisl/test/test_set_like.py index 8bd26b8..0ebdf7d 100644 --- a/namedisl/test/test_set_like.py +++ b/namedisl/test/test_set_like.py @@ -206,6 +206,21 @@ def test_basic_set_intersection_promotes_to_set() -> None: assert reconstructed.n_basic_set() > 1 +def test_set_convex_hull_returns_basic_set() -> None: + set_ = nisl.make_set( + "{ [j, i] : " + "(j = 0 and 0 <= i <= 2) or " + "(j = 2 and 0 <= i <= 2) }" + ) + + result = set_.convex_hull() + + assert isinstance(result, nisl.BasicSet) + assert result == nisl.make_basic_set( + "{ [j, i] : 0 <= j <= 2 and 0 <= i <= 2 }" + ) + + @pytest.mark.parametrize("ndims", [1, 2, 4, 8]) def test_set_eliminate(ndims: int): a, a_dims, _ = generate_random_named_set(ndims, "a", None) @@ -461,6 +476,19 @@ def test_basic_map_subset_comparisons_allow_map_promotion() -> None: assert larger >= smaller +def test_map_convex_hull_returns_basic_map() -> None: + map_ = nisl.make_map( + "{ [i] -> [j] : " + "(i = 0 and j = 0) or " + "(i = 2 and j = 2) }" + ) + + result = map_.convex_hull() + + assert isinstance(result, nisl.BasicMap) + assert result == nisl.make_basic_map("{ [i] -> [j] : j = i and 0 <= i <= 2 }") + + def test_subset_comparison_rejects_set_map_mismatch() -> None: set_ = nisl.make_set("{ [i] : 0 <= i < 5 }") map_ = nisl.make_map("{ [i] -> [x] : x = i and 0 <= i < 5 }") From dd2885e3c25ddf8a341c9b6f0428af0ddded4d55 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 21 May 2026 12:55:54 -0500 Subject: [PATCH 32/43] make basedpyright happy --- namedisl/expression_like.py | 32 +++++++++++++++++++++++++------- namedisl/set_like.py | 18 +++++++++++++----- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index 00d922a..1ed9728 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -27,7 +27,7 @@ import operator from dataclasses import dataclass, replace -from typing import final, overload +from typing import Any, cast, final, overload from typing_extensions import Self, override @@ -41,6 +41,24 @@ ) +def _add_isl_expression( + lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int +) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast(Any, operator.add)(lhs, rhs)) + + +def _sub_isl_expression( + lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int +) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast(Any, operator.sub)(lhs, rhs)) + + +def _mul_isl_expression( + lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int +) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast(Any, operator.mul)(lhs, rhs)) + + # {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) @dataclass(frozen=True, eq=False) @@ -51,34 +69,34 @@ def __add__(self, other: Self | int) -> Self: if isinstance(other, int): return replace( self, - _obj=operator.add(self._obj, other), + _obj=_add_isl_expression(self._obj, other), _name_to_dim=self._name_to_dim, _dimtype_to_names=self._dimtype_to_names ) - return _align_and_apply_binary_op(self, other, operator.add) + return _align_and_apply_binary_op(self, other, _add_isl_expression) def __sub__(self, other: Self | int) -> Self: if isinstance(other, int): return replace( self, - _obj=operator.sub(self._obj, other), + _obj=_sub_isl_expression(self._obj, other), _name_to_dim=self._name_to_dim, _dimtype_to_names=self._dimtype_to_names ) - return _align_and_apply_binary_op(self, other, operator.sub) + return _align_and_apply_binary_op(self, other, _sub_isl_expression) def __mul__(self, other: Self | int) -> Self: if isinstance(other, int): return replace( self, - _obj=operator.mul(self._obj, other), + _obj=_mul_isl_expression(self._obj, other), _name_to_dim=self._name_to_dim, _dimtype_to_names=self._dimtype_to_names ) - return _align_and_apply_binary_op(self, other, operator.mul) + return _align_and_apply_binary_op(self, other, _mul_isl_expression) def is_zero(self) -> bool: return bool(self._obj.is_zero()) # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType] diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 836e563..93d1d78 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -28,7 +28,7 @@ import operator from abc import ABC from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, final, overload +from typing import TYPE_CHECKING, Any, cast, final, overload from constantdict import constantdict from typing_extensions import Self, override @@ -45,6 +45,14 @@ ) +def _set_like_and(lhs: isl.Set, rhs: isl.Set) -> isl.Set: + return cast("isl.Set", cast(Any, operator.and_)(lhs, rhs)) + + +def _set_like_or(lhs: isl.Set, rhs: isl.Set) -> isl.Set: + return cast("isl.Set", cast(Any, operator.or_)(lhs, rhs)) + + if TYPE_CHECKING: from collections.abc import Callable, Collection, Sequence @@ -275,7 +283,7 @@ def __and__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... def __and__(self: Set, other: BasicSet | Set) -> Set: ... def __and__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: - return _apply_set_like_binary_op(self, other, operator.and_) + return _apply_set_like_binary_op(self, other, _set_like_and) @overload def __or__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: ... @@ -290,7 +298,7 @@ def __or__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... def __or__(self: Set, other: BasicSet | Set) -> Set: ... def __or__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: - return _apply_set_like_binary_op(self, other, operator.or_) + return _apply_set_like_binary_op(self, other, _set_like_or) @overload def __sub__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: ... @@ -558,7 +566,7 @@ def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: isl.dim_type.in_, domain_obj, ), - operator.and_, + _set_like_and, ) assert isinstance(result, BasicMap | Map) return result @@ -572,7 +580,7 @@ def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: isl.dim_type.out, range_obj, ), - operator.and_, + _set_like_and, ) assert isinstance(result, BasicMap | Map) return result From db4c31793a71d8c1e512a45fd42bf7683b6b5c88 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 21 May 2026 13:01:49 -0500 Subject: [PATCH 33/43] make ruff happy --- namedisl/core.py | 2 +- namedisl/expression_like.py | 6 +++--- namedisl/set_like.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index b284060..120b0e5 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -67,7 +67,7 @@ "IslMultiExpressionLikeT", bound=IslMultiExpressionLike ) IslPwExpressionLikeT = TypeVar("IslPwExpressionLikeT", bound=IslPwExpressionLike) -IslObjectT = TypeVar("IslObjectT", bound=IslObject, covariant=True) +IslObjectT = TypeVar("IslObjectT", bound=IslObject, covariant=True) # noqa: PLC0105 NamedIslObjectT = TypeVar("NamedIslObjectT", bound="NamedIslObject[IslObject]") diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index 1ed9728..6b0c5b9 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -44,19 +44,19 @@ def _add_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast("IslExpressionLikeT", cast(Any, operator.add)(lhs, rhs)) + return cast("IslExpressionLikeT", cast("Any", operator.add)(lhs, rhs)) def _sub_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast("IslExpressionLikeT", cast(Any, operator.sub)(lhs, rhs)) + return cast("IslExpressionLikeT", cast("Any", operator.sub)(lhs, rhs)) def _mul_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast("IslExpressionLikeT", cast(Any, operator.mul)(lhs, rhs)) + return cast("IslExpressionLikeT", cast("Any", operator.mul)(lhs, rhs)) # {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 93d1d78..f865664 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -46,11 +46,11 @@ def _set_like_and(lhs: isl.Set, rhs: isl.Set) -> isl.Set: - return cast("isl.Set", cast(Any, operator.and_)(lhs, rhs)) + return cast("isl.Set", cast("Any", operator.and_)(lhs, rhs)) def _set_like_or(lhs: isl.Set, rhs: isl.Set) -> isl.Set: - return cast("isl.Set", cast(Any, operator.or_)(lhs, rhs)) + return cast("isl.Set", cast("Any", operator.or_)(lhs, rhs)) if TYPE_CHECKING: From 6360261b614fd4544c3bdaf980fd9063fc45420a Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 21 May 2026 13:12:34 -0500 Subject: [PATCH 34/43] remove some dead code; remove temporary tests that imported loopy and pymbolic --- namedisl/core.py | 21 +---- namedisl/tags.py | 58 ------------ namedisl/test/test_namedisl.py | 18 +--- namedisl/test/test_set_like.py | 158 ++++++++------------------------- 4 files changed, 40 insertions(+), 215 deletions(-) delete mode 100644 namedisl/tags.py diff --git a/namedisl/core.py b/namedisl/core.py index 120b0e5..e47387b 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -30,7 +30,7 @@ from collections.abc import Callable, Collection, Mapping, Sequence from dataclasses import dataclass from importlib import metadata -from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast, overload +from typing import Generic, TypeAlias, TypeVar, cast, overload from constantdict import constantdict from typing_extensions import Self, override @@ -46,10 +46,6 @@ ] -if TYPE_CHECKING: - from namedisl.tags import _TaggedName - - IslBaseExpressionLike = isl.Aff | isl.QPolynomial IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff @@ -62,11 +58,6 @@ "IslExpressionLikeT", bound=IslExpressionLike, ) -IslSetLikeT = TypeVar("IslSetLikeT", bound=IslSetLike) -IslMultiExpressionLikeT = TypeVar( - "IslMultiExpressionLikeT", bound=IslMultiExpressionLike -) -IslPwExpressionLikeT = TypeVar("IslPwExpressionLikeT", bound=IslPwExpressionLike) IslObjectT = TypeVar("IslObjectT", bound=IslObject, covariant=True) # noqa: PLC0105 NamedIslObjectT = TypeVar("NamedIslObjectT", bound="NamedIslObject[IslObject]") @@ -636,16 +627,6 @@ def _add_grouped_names( new_dimtype_to_names, ) - def add_names(self, tagged_names_to_add: Sequence[_TaggedName]) -> Self: - grouped_names = self._empty_grouped_names() - for tagged_name in tagged_names_to_add: - dim_type = _normalize_public_dim_type(tagged_name._isl_dim_type) - if dim_type not in grouped_names: - raise ValueError(f"unsupported dim type: {tagged_name._isl_dim_type}") - grouped_names[dim_type].append(tagged_name.name) - - return self._add_grouped_names(grouped_names) - def add_set_names(self, names_to_add: Collection[str]) -> Self: return self._add_names_by_dim_type(names_to_add, isl.dim_type.set) diff --git a/namedisl/tags.py b/namedisl/tags.py deleted file mode 100644 index 65bf176..0000000 --- a/namedisl/tags.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - - -__copyright__ = """ -Copyright (C) 2025- University of Illinois Board of Trustees -""" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -from dataclasses import dataclass - -from pytools.tag import Tag - -from islpy import dim_type - - -@dataclass(frozen=True) -class _TaggedName(Tag): - name: str - _isl_dim_type: dim_type - - -@dataclass(frozen=True) -class InputName(_TaggedName): - _isl_dim_type: dim_type = dim_type.in_ - - -@dataclass(frozen=True) -class OutputName(_TaggedName): - _isl_dim_type: dim_type = dim_type.out - - -@dataclass(frozen=True) -class ParameterName(_TaggedName): - _isl_dim_type: dim_type = dim_type.param - - -@dataclass(frozen=True) -class SetName(_TaggedName): - _isl_dim_type: dim_type = dim_type.set diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py index 6c91d17..f09e415 100644 --- a/namedisl/test/test_namedisl.py +++ b/namedisl/test/test_namedisl.py @@ -30,7 +30,7 @@ import islpy as isl import namedisl as nisl -from .utils_for_tests import generate_random_named_set, get_name_sequence +from .utils_for_tests import generate_random_named_set @pytest.mark.parametrize("ndims", [2, 3, 4, 5]) @@ -46,22 +46,6 @@ def test_names(ndims: int, has_params: bool): assert s.names == names -@pytest.mark.parametrize("ndims", [2, 3, 4, 5]) -@pytest.mark.parametrize("n_names_to_add", [2, 3, 4, 5]) -def test_add_names( - ndims: int, - n_names_to_add: int - ): - - s, _s_dims, _ = generate_random_named_set(ndims, "s", None) - new_set_names, _ = get_name_sequence(n_names_to_add, "set") - - from namedisl.tags import SetName - s = s.add_names([SetName(name) for name in new_set_names]) - - print(s) - - def test_public_dim_type_name_accessors() -> None: named_map = nisl.make_map("[n] -> { [i] -> [o] }") diff --git a/namedisl/test/test_set_like.py b/namedisl/test/test_set_like.py index 0ebdf7d..a46f1a1 100644 --- a/namedisl/test/test_set_like.py +++ b/namedisl/test/test_set_like.py @@ -26,9 +26,6 @@ """ import pytest -from loopy.symbolic import pwaff_from_expr -from pymbolic import var -from typing_extensions import assert_type import islpy as isl @@ -39,6 +36,7 @@ # {{{ sets + def test_set_from_str() -> None: s = nisl.make_set("[n] -> { [i]: 0 <= i < n }") @@ -62,6 +60,7 @@ def test_set_equality(ndims: int, has_params: bool): a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) from itertools import permutations + for perm in list(permutations(a_dims.split(","))): perm_dims = ",".join(p for p in perm) set_str = f"{{ [{perm_dims}] : {a_cond} }}" @@ -145,8 +144,7 @@ def test_set_add_constraint_rejects_unknown_name() -> None: def test_set_gist_simplifies_against_named_context() -> None: set_ = nisl.make_set( - "{ [i, j, kb] : 0 <= i <= 13 and 0 <= j <= 13 " - "and 0 <= kb <= 4 and kb <= 3 }" + "{ [i, j, kb] : 0 <= i <= 13 and 0 <= j <= 13 and 0 <= kb <= 4 and kb <= 3 }" ) context = nisl.make_set("{ [j, i, kb] : 0 <= i <= 13 and 0 <= j <= 13 }") @@ -208,17 +206,13 @@ def test_basic_set_intersection_promotes_to_set() -> None: def test_set_convex_hull_returns_basic_set() -> None: set_ = nisl.make_set( - "{ [j, i] : " - "(j = 0 and 0 <= i <= 2) or " - "(j = 2 and 0 <= i <= 2) }" + "{ [j, i] : (j = 0 and 0 <= i <= 2) or (j = 2 and 0 <= i <= 2) }" ) result = set_.convex_hull() assert isinstance(result, nisl.BasicSet) - assert result == nisl.make_basic_set( - "{ [j, i] : 0 <= j <= 2 and 0 <= i <= 2 }" - ) + assert result == nisl.make_basic_set("{ [j, i] : 0 <= j <= 2 and 0 <= i <= 2 }") @pytest.mark.parametrize("ndims", [1, 2, 4, 8]) @@ -264,22 +258,22 @@ def test_set_dim_min(ndims: int): for i, name in enumerate(a_dims.split(",")): assert a.dim_min(name) == cond_pw_affs[i] + # }}} # {{{ maps + def test_map_from_str() -> None: - m = nisl.make_map( - "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + m = nisl.make_map("[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") print(m._obj) print(m) def test_map_from_map() -> None: - m = isl.Map( - "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + m = isl.Map("[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") named_map = nisl.make_map(m) print(named_map._obj) @@ -298,13 +292,14 @@ def test_map_equality(ndims_domain: int, ndims_range: int, has_params: bool): r_param = None og_map, domain_info, range_info = generate_random_named_map( - ndims_domain, "d", d_param, - ndims_range, "r", r_param) + ndims_domain, "d", d_param, ndims_range, "r", r_param + ) _, d_dims, d_cond = domain_info _, r_dims, r_cond = range_info from itertools import permutations + d_perms = list(permutations(d_dims.split(","))) r_perms = list(permutations(r_dims.split(","))) @@ -320,9 +315,7 @@ def test_map_equality(ndims_domain: int, ndims_range: int, has_params: bool): range_str = f"[{r_param}] ->" + range_str perm_map = nisl.make_map( - isl.Map.from_domain_and_range( - isl.Set(domain_str), isl.Set(range_str) - ) + isl.Map.from_domain_and_range(isl.Set(domain_str), isl.Set(range_str)) ) assert perm_map == og_map @@ -340,13 +333,11 @@ def test_map_union(ndims_domain: int, ndims_range: int, has_params: bool): r_param = None x, x_domain_info, x_range_info = generate_random_named_map( - ndims_domain, "x_in", d_param, - ndims_range, "x_out", r_param + ndims_domain, "x_in", d_param, ndims_range, "x_out", r_param ) y, y_domain_info, y_range_info = generate_random_named_map( - ndims_domain, "y_in", d_param, - ndims_domain, "y_out", r_param + ndims_domain, "y_in", d_param, ndims_domain, "y_out", r_param ) _, x_in_dims, x_in_cond = x_domain_info @@ -380,13 +371,11 @@ def test_map_intersection(ndims_domain: int, ndims_range: int, has_params: bool) r_param = None x, x_domain_info, x_range_info = generate_random_named_map( - ndims_domain, "x_in", d_param, - ndims_range, "x_out", r_param + ndims_domain, "x_in", d_param, ndims_range, "x_out", r_param ) y, y_domain_info, y_range_info = generate_random_named_map( - ndims_domain, "y_in", d_param, - ndims_domain, "y_out", r_param + ndims_domain, "y_in", d_param, ndims_domain, "y_out", r_param ) _, x_in_dims, x_in_cond = x_domain_info @@ -427,8 +416,7 @@ def test_map_add_constraint_preserves_basic_map_type() -> None: def test_map_add_constraint_supports_previous_context_relation() -> None: relation = nisl.make_map( - "{ [ki_prev, kb_prev] -> [ki_cur, kb_cur] : " - "kb_cur = kb_prev + 1 }" + "{ [ki_prev, kb_prev] -> [ki_cur, kb_cur] : kb_cur = kb_prev + 1 }" ) constrained = relation.add_constraint("ki_prev = ki_cur - 1") @@ -441,8 +429,7 @@ def test_map_add_constraint_supports_previous_context_relation() -> None: def test_map_gist_simplifies_against_named_context() -> None: map_ = nisl.make_map( - "{ [i, kb] -> [j] : 0 <= i <= 13 and 0 <= kb <= 4 " - "and kb <= 3 and j = i }" + "{ [i, kb] -> [j] : 0 <= i <= 13 and 0 <= kb <= 4 and kb <= 3 and j = i }" ) context = nisl.make_map("{ [kb, i] -> [j] : 0 <= i <= 13 and j = i }") @@ -452,9 +439,7 @@ def test_map_gist_simplifies_against_named_context() -> None: def test_map_subset_comparisons_align_by_name() -> None: smaller = nisl.make_map("{ [i] -> [x] : x = i and 0 <= i < 5 }") larger = nisl.make_map("{ [j, i] -> [y, x] : x = i and 0 <= i < 10 }") - equal_reordered = nisl.make_map( - "{ [i, j] -> [x, y] : x = i and 0 <= i < 5 }" - ) + equal_reordered = nisl.make_map("{ [i, j] -> [x, y] : x = i and 0 <= i < 5 }") assert smaller < larger assert smaller <= larger @@ -477,11 +462,7 @@ def test_basic_map_subset_comparisons_allow_map_promotion() -> None: def test_map_convex_hull_returns_basic_map() -> None: - map_ = nisl.make_map( - "{ [i] -> [j] : " - "(i = 0 and j = 0) or " - "(i = 2 and j = 2) }" - ) + map_ = nisl.make_map("{ [i] -> [j] : (i = 0 and j = 0) or (i = 2 and j = 2) }") result = map_.convex_hull() @@ -535,59 +516,6 @@ def test_map_alignment_syncs_internal_input_and_parameter_positions_and_names() assert rhs_names == ["x", "i", "j", "m", "n"] -def test_map_apply_range_for_compute_a_tile_usage_map() -> None: - bm = 32 - bk = 16 - - compute_map = nisl.make_map(f"""{{ - [is, ks] -> [ii_s, io, ki_s, ko] : - is = io * {bm} + ii_s and - ks = ko * {bk} + ki_s - }}""") - - usage_domain = nisl.make_set( - "{ [i, j, k, io, jo, ko, ii, ji, ki, ii_s, ji_s, ki_s] }" - ) - global_usage_map = nisl.make_map_from_domain_and_range( - usage_domain, - nisl.make_set("{ [is, ks] }") - ) - - local_usage_mpwaff = isl.MultiPwAff.zero(global_usage_map.get_space()) - for idx, expr in enumerate([var("i"), var("k")]): - local_space = local_usage_mpwaff.get_at(idx).get_space().domain() - local_usage_mpwaff = local_usage_mpwaff.set_pw_aff( - idx, - pwaff_from_expr(local_space, expr) - ) - - local_usage_map = nisl.make_map(local_usage_mpwaff.as_map()) - local_usage_map = local_usage_map.intersect_domain( - nisl.make_basic_set( - "{ [i, j, k, io, jo, ko, ii, ji, ki, ii_s, ji_s, ki_s] }" - ) - ) - - global_usage_map = global_usage_map | local_usage_map - assert isinstance(global_usage_map, nisl.Map) - assert_type(global_usage_map, nisl.Map) - compute_map = compute_map.rename_dims({ - "ii_s": "ii_s_out", - "io": "io_out", - "ki_s": "ki_s_out", - "ko": "ko_out", - }) - composed = global_usage_map.apply_range(compute_map) - - assert composed.input_names == frozenset( - {"i", "ii", "ii_s", "io", "j", "ji", "ji_s", "jo", - "k", "ki", "ki_s", "ko"} - ) - assert composed.range().names == frozenset( - {"ii_s_out", "io_out", "ki_s_out", "ko_out"} - ) - - def test_map_apply_range_rejects_surviving_name_collisions() -> None: lhs = nisl.make_map("{ [x] -> [y] }") rhs = nisl.make_map("{ [y] -> [x] }") @@ -605,7 +533,8 @@ def test_map_apply_range_can_explicitly_rename_and_equate_collision() -> None: assert result.input_names == frozenset({"x"}) assert result.range().names == frozenset({"x_out"}) assert ( - result.intersect_domain(nisl.make_set("{ [x] : x = 3 }")) + result + .intersect_domain(nisl.make_set("{ [x] : x = 3 }")) .range() .dim_min("x_out") .plain_is_equal(isl.PwAff("{ [(3)] }")) @@ -653,7 +582,8 @@ def test_map_apply_domain_can_explicitly_rename_and_equate_collision() -> None: assert result.input_names == frozenset({"x"}) assert result.range().names == frozenset({"x_out"}) assert ( - result.intersect_domain(nisl.make_set("{ [x] : x = 4 }")) + result + .intersect_domain(nisl.make_set("{ [x] : x = 4 }")) .range() .dim_min("x_out") .plain_is_equal(isl.PwAff("{ [(4)] }")) @@ -667,10 +597,7 @@ def test_duplicate_map_names_are_rejected() -> None: def test_map_empty_from_space_preserves_names_and_is_empty() -> None: space = isl.Space.create_from_names( - isl.DEFAULT_CONTEXT, - params=["n"], - in_=["i", "j"], - out=["x", "y"] + isl.DEFAULT_CONTEXT, params=["n"], in_=["i", "j"], out=["x", "y"] ) m = nisl.Map.empty(space) @@ -681,11 +608,7 @@ def test_map_empty_from_space_preserves_names_and_is_empty() -> None: def test_basic_map_empty_from_space_preserves_names_and_is_empty() -> None: - space = isl.Space.create_from_names( - isl.DEFAULT_CONTEXT, - in_=["i"], - out=["x"] - ) + space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, in_=["i"], out=["x"]) m = nisl.BasicMap.empty(space) @@ -706,11 +629,7 @@ def test_map_empty_matches_existing_named_space() -> None: def test_empty_map_is_identity_for_union() -> None: - space = isl.Space.create_from_names( - isl.DEFAULT_CONTEXT, - in_=["i"], - out=["x"] - ) + space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, in_=["i"], out=["x"]) empty_map = nisl.Map.empty(space) nonempty_map = nisl.make_map("{ [i] -> [x] }") @@ -722,8 +641,7 @@ def test_empty_map_is_identity_for_union() -> None: @pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) def test_map_eliminate(ndims_domain: int, ndims_range: int): x, x_domain_info, x_range_info = generate_random_named_map( - ndims_domain, "x_in", None, - ndims_range, "x_out", None + ndims_domain, "x_in", None, ndims_range, "x_out", None ) _, x_in_dims, _ = x_domain_info @@ -739,8 +657,7 @@ def test_map_eliminate(ndims_domain: int, ndims_range: int): @pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) def test_map_project_out(ndims_domain: int, ndims_range: int): x, x_domain_info, x_range_info = generate_random_named_map( - ndims_domain, "x_in", None, - ndims_range, "x_out", None + ndims_domain, "x_in", None, ndims_range, "x_out", None ) _, x_in_dims, _ = x_domain_info @@ -764,8 +681,7 @@ def test_map_as_pw_multi_aff(): @pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) def test_map_dim_max(ndims_domain: int, ndims_range: int): m, (_, in_names, in_conds), (_, out_names, out_conds) = generate_random_named_map( - ndims_domain, "x_in", None, - ndims_range, "x_out", None + ndims_domain, "x_in", None, ndims_range, "x_out", None ) # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. @@ -793,8 +709,7 @@ def test_map_dim_max(ndims_domain: int, ndims_range: int): @pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) def test_map_dim_min(ndims_domain: int, ndims_range: int): m, (_, in_names, in_conds), (_, out_names, out_conds) = generate_random_named_map( - ndims_domain, "x_in", None, - ndims_range, "x_out", None + ndims_domain, "x_in", None, ndims_range, "x_out", None ) # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. @@ -815,25 +730,28 @@ def test_map_dim_min(ndims_domain: int, ndims_range: int): for i, name in enumerate(out_names.split(",")): assert m.dim_min(name) == out_lower_bound_pw_maffs[i] + # }}} # {{{ basic{map, set} + def test_basic_map_from_str() -> None: m = nisl.make_basic_map( - "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }" + ) print(m._obj) print(m) def test_basic_map_from_map() -> None: - m = isl.BasicMap( - "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + m = isl.BasicMap("[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") named_map = nisl.make_basic_map(m) print(named_map._obj) print(named_map) + # }}} From 006b129e7b5fabc0042279035da26395059a461f Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 21 May 2026 13:15:01 -0500 Subject: [PATCH 35/43] pylint related fixes --- namedisl/expression_like.py | 6 +++--- namedisl/set_like.py | 28 +++++++++++++++++++++------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index 6b0c5b9..d20b9dd 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -44,19 +44,19 @@ def _add_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast("IslExpressionLikeT", cast("Any", operator.add)(lhs, rhs)) + return cast(IslExpressionLikeT, cast(Any, operator.add)(lhs, rhs)) def _sub_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast("IslExpressionLikeT", cast("Any", operator.sub)(lhs, rhs)) + return cast(IslExpressionLikeT, cast(Any, operator.sub)(lhs, rhs)) def _mul_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast("IslExpressionLikeT", cast("Any", operator.mul)(lhs, rhs)) + return cast(IslExpressionLikeT, cast(Any, operator.mul)(lhs, rhs)) # {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) diff --git a/namedisl/set_like.py b/namedisl/set_like.py index f865664..7aa160a 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -46,11 +46,11 @@ def _set_like_and(lhs: isl.Set, rhs: isl.Set) -> isl.Set: - return cast("isl.Set", cast("Any", operator.and_)(lhs, rhs)) + return cast(isl.Set, cast(Any, operator.and_)(lhs, rhs)) def _set_like_or(lhs: isl.Set, rhs: isl.Set) -> isl.Set: - return cast("isl.Set", cast("Any", operator.or_)(lhs, rhs)) + return cast(isl.Set, cast(Any, operator.or_)(lhs, rhs)) if TYPE_CHECKING: @@ -150,7 +150,11 @@ def add_constraint( if constraint_name_to_dim != self._name_to_dim: constraint_set = _align_obj( - Set(constraint_set, constraint_name_to_dim, self._dimtype_to_names), + Set( # pylint: disable=too-many-function-args + constraint_set, + constraint_name_to_dim, + self._dimtype_to_names, + ), self._name_to_dim, self._dimtype_to_names, )._obj @@ -188,8 +192,10 @@ def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: else: result_type = Set - return result_type( - result, self_aligned._name_to_dim, self_aligned._dimtype_to_names + return result_type( # pylint: disable=too-many-function-args + result, + self_aligned._name_to_dim, + self_aligned._dimtype_to_names, ) def project_out(self: Self, names_to_project_out: str | Collection[str]) -> Self: @@ -445,7 +451,11 @@ def _apply_set_like_binary_op( else: result_type = Set - return result_type(result, lhs._name_to_dim, lhs._dimtype_to_names) + return result_type( # pylint: disable=too-many-function-args + result, + lhs._name_to_dim, + lhs._dimtype_to_names, + ) def _compare_set_like( @@ -648,7 +658,11 @@ def empty(cls, space: isl.Space) -> BasicMap: obj = isl.BasicMap.empty(space) set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) - return cls(set_obj, name_to_dim, dimtype_to_names) + return cls( # pylint: disable=too-many-function-args + set_obj, + name_to_dim, + dimtype_to_names, + ) @override def _map_obj(self) -> isl.BasicMap: From 97a0a87720fc9bd43c5bdfc6aab9ccacfb73392c Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 21 May 2026 14:20:28 -0500 Subject: [PATCH 36/43] add docs; more typechecker/linter fixes --- doc/conf.py | 33 +++--- doc/index.rst | 6 + doc/internals.rst | 18 +++ doc/ref.rst | 86 ++++++++++++++ namedisl/__init__.py | 15 ++- namedisl/core.py | 137 ++++++++++++++++++++-- namedisl/expression_like.py | 86 +++++++++++++- namedisl/set_like.py | 208 +++++++++++++++++++++++++++++---- namedisl/test/test_set_like.py | 59 ++++++---- 9 files changed, 567 insertions(+), 81 deletions(-) create mode 100644 doc/internals.rst diff --git a/doc/conf.py b/doc/conf.py index a27710b..1dbaeb8 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,31 +1,30 @@ from __future__ import annotations -from urllib.request import urlopen +import sys +from importlib import metadata +from pathlib import Path -_conf_url = \ - "https://raw.githubusercontent.com/inducer/sphinxconfig/main/sphinxconfig.py" -with urlopen(_conf_url) as _inf: - exec(compile(_inf.read(), _conf_url, "exec"), globals()) +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -copyright = "2025- University of Illiois Board of Trustees" -author = "Andreas Kloeckner" +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", +] + +autodoc_member_order = "bysource" +autodoc_typehints = "none" -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -ver_dic = {} -with open("../namedisl/__init__.py") as vfile: - exec(compile(vfile.read(), "../namedisl/__init__.py", "exec"), ver_dic) +copyright = "2025- University of Illinois Board of Trustees" +author = "Andreas Kloeckner" -version = ".".join(str(x) for x in ver_dic["__version__"]) -release = ver_dic["__version__"] +release = metadata.version("namedisl") +version = ".".join(release.split(".")[:2]) intersphinx_mapping = { "islpy": ("https://documen.tician.de/islpy", None), "constantdict": ("https://matthiasdiener.github.io/constantdict/", None), + "python": ("https://docs.python.org/3", None), } nitpicky = True diff --git a/doc/index.rst b/doc/index.rst index 0f429a7..28d47f6 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,11 +1,17 @@ Welcome to namedisl's documentation! ==================================== +namedisl is a small wrapper around :mod:`islpy` for code that needs dimension +names to carry semantic meaning. It keeps isl's integer-set operations while +offering a name-oriented interface for sets, maps, affine expressions, and +piecewise affine expressions. + .. toctree:: :maxdepth: 2 :caption: Contents: ref + internals 🚀 Github 💾 Download Releases diff --git a/doc/internals.rst b/doc/internals.rst new file mode 100644 index 0000000..267bdcf --- /dev/null +++ b/doc/internals.rst @@ -0,0 +1,18 @@ +Internal Reference +================== + +These interfaces support the public wrappers and are mainly useful when +maintaining namedisl itself. + +.. automodule:: namedisl.core + +.. autoclass:: namedisl.core.NamedIslObject + :members: + +.. autofunction:: namedisl.core._deconstruct_object + +.. autofunction:: namedisl.core._make_named_object_pieces + +.. autofunction:: namedisl.core._align_two + +.. autofunction:: namedisl.core._align_and_apply_binary_op diff --git a/doc/ref.rst b/doc/ref.rst index 56fd5a9..9d60a32 100644 --- a/doc/ref.rst +++ b/doc/ref.rst @@ -2,3 +2,89 @@ Reference ========= .. automodule:: namedisl + +Set and Map Objects +------------------- + +.. autofunction:: make_basic_set + +.. autofunction:: make_set + +.. autofunction:: make_basic_map + +.. autofunction:: make_map + +.. autofunction:: make_map_from_domain_and_range + +.. autoclass:: BasicSet + :members: + :inherited-members: + :special-members: __and__, __or__, __sub__, __lt__, __le__, __gt__, __ge__ + :exclude-members: __init__, __new__ + +.. autoclass:: Set + :members: + :inherited-members: + :special-members: __and__, __or__, __sub__, __lt__, __le__, __gt__, __ge__ + :exclude-members: __init__, __new__ + +.. autoclass:: BasicMap + :members: + :inherited-members: + :special-members: __and__, __or__, __sub__, __lt__, __le__, __gt__, __ge__ + :exclude-members: __init__, __new__ + +.. autoclass:: Map + :members: + :inherited-members: + :special-members: __and__, __or__, __sub__, __lt__, __le__, __gt__, __ge__ + :exclude-members: __init__, __new__ + +Expression Objects +------------------ + +.. autofunction:: make_aff + +.. autofunction:: make_pw_aff + +.. autofunction:: make_qpolynomial + +.. autofunction:: make_pw_qpolynomial + +.. autofunction:: make_multi_aff + +.. autofunction:: make_pw_multi_aff + +.. autoclass:: Aff + :members: + :inherited-members: + :special-members: __add__, __sub__, __mul__ + :exclude-members: __init__, __new__ + +.. autoclass:: PwAff + :members: + :inherited-members: + :special-members: __add__, __sub__, __mul__ + :exclude-members: __init__, __new__ + +.. autoclass:: QPolynomial + :members: + :inherited-members: + :special-members: __add__, __sub__, __mul__ + :exclude-members: __init__, __new__ + +.. autoclass:: PwQPolynomial + :members: + :inherited-members: + :special-members: __add__, __sub__, __mul__ + :exclude-members: __init__, __new__ + +.. autoclass:: MultiAff + :members: + :inherited-members: + :exclude-members: __init__, __new__ + +.. autoclass:: PwMultiAff + :members: + :inherited-members: + :exclude-members: __init__, __new__ diff --git a/namedisl/__init__.py b/namedisl/__init__.py index 2152775..bce7f38 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -1,3 +1,16 @@ +""" +Name-aware wrappers for :mod:`islpy` objects. + +namedisl offers small Python wrappers around isl sets, maps, and expression +objects. The wrappers keep a separate mapping from dimension names to isl +dimension positions, align operands by name before applying binary operations, +and reconstruct ordinary :mod:`islpy` objects when callers need to interoperate +with islpy or downstream libraries. + +Most users should construct objects through the ``make_*`` functions exported +from this module. +""" + from __future__ import annotations @@ -73,5 +86,5 @@ "make_pw_multi_aff", "make_pw_qpolynomial", "make_qpolynomial", - "make_set" + "make_set", ] diff --git a/namedisl/core.py b/namedisl/core.py index e47387b..fb0cdb4 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -1,3 +1,13 @@ +""" +Core metadata machinery for name-aware isl wrappers. + +The public set-like and expression-like classes store an isl object together +with metadata that records which semantic name belongs to each internal +dimension. This module +contains the alignment, reconstruction, and metadata-manipulation helpers used +to keep that invariant intact. +""" + from __future__ import annotations @@ -38,14 +48,6 @@ import islpy as isl -ISL_DIM_TYPES = [ - isl.dim_type.out, - isl.dim_type.in_, - isl.dim_type.set, - isl.dim_type.param, -] - - IslBaseExpressionLike = isl.Aff | isl.QPolynomial IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff @@ -193,6 +195,13 @@ def _normalize_dimtype_to_names( def _make_named_object_pieces(obj: IslObject) -> IslObjectPieces: + """ + Extract the normalized isl object and name metadata for *obj*. + + The returned object uses namedisl's internal flattened dimension layout. + The accompanying mappings record how to reconstruct the original public isl + dimension kinds and names. + """ _ensure_unique_public_names(obj) decon_obj, dimtype_to_names = _deconstruct_object(obj) decon_obj, name_to_dim = _strip_names(decon_obj) @@ -201,6 +210,13 @@ def _make_named_object_pieces(obj: IslObject) -> IslObjectPieces: def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: + """ + Return a copy of *obj* with dimension names restored from *name_to_dim*. + + This is intentionally used at reconstruction boundaries. Internal + metadata-only operations may leave the private isl object's own dimension + names stale because namedisl does not rely on those names for correctness. + """ restored_obj = obj.copy() if isinstance(restored_obj, isl.PwAff): # input dimensions cannot be renamed for isl.PwAff, so we first move @@ -272,6 +288,13 @@ def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: ... def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: + """ + Convert a public isl object into namedisl's internal representation. + + Set-like objects are represented as flattened sets whose dimensions are + grouped as set/output names, input names, and parameter names. Expression + objects use input dimensions followed by parameters. + """ dt_to_names: dict[isl.dim_type, frozenset[str]] = {} if isinstance(obj, IslSetLike | IslMultiExpressionLike): @@ -393,6 +416,12 @@ def _find_joint_name_to_dim( def _align_obj( named_obj: NamedIslObjectT, ordering: NameToDim, dimtype_to_names: DimTypeToNames ) -> NamedIslObjectT: + """ + Return *named_obj* with internal dimensions arranged according to *ordering*. + + The isl object is moved or expanded as needed, but public names are carried + in metadata and are restored only when a raw isl object is reconstructed. + """ new_isl_obj = cast( "IslSetLike | IslBaseExpressionLike | IslPwExpressionLike | isl.MultiAff", named_obj._obj, @@ -433,8 +462,6 @@ def _align_obj( running_name_to_dim[name] = target_dim - new_isl_obj = _restore_names(new_isl_obj, ordering) - return type(named_obj)( new_isl_obj, ordering, @@ -444,7 +471,10 @@ def _align_obj( def _align_two( named_obj1: NamedIslObjectT, named_obj2: NamedIslObjectT -) -> tuple[NamedIslObjectT, ...]: +) -> tuple[NamedIslObjectT, NamedIslObjectT]: + """ + Align two named isl objects to a common name-to-dimension mapping. + """ name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, named_obj2) @@ -459,6 +489,9 @@ def _align_and_apply_binary_op( rhs: NamedIslObject[IslObjectT], op: Callable[[IslObjectT, IslObjectT], IslObjectT], ) -> NamedIslObject[IslObjectT]: + """ + Align *lhs* and *rhs*, apply *op* to their isl objects, and wrap the result. + """ lhs, rhs = _align_two(lhs, rhs) result = op(lhs._obj, rhs._obj) @@ -470,6 +503,15 @@ def _align_and_apply_binary_op( @dataclass(frozen=True, eq=False) class NamedIslObject(ABC, Generic[IslObjectT]): + """ + Base class for named isl wrappers. + + Instances pair a private isl object with metadata that records the semantic + name and public dimension kind of every internal dimension. Subclasses use + this metadata to implement operations in terms of names while still + delegating the underlying integer-set algebra to isl. + """ + _obj: IslObjectT _name_to_dim: NameToDim @@ -622,50 +664,94 @@ def _add_grouped_names( ) return type(self)( - cast("IslObjectT", _restore_names(new_obj, new_name_to_dim)), + cast("IslObjectT", new_obj), new_name_to_dim, new_dimtype_to_names, ) def add_set_names(self, names_to_add: Collection[str]) -> Self: + """ + Return a copy with new unconstrained set/output dimensions. + + :arg names_to_add: Names to insert. Existing names and duplicates are + rejected. + """ return self._add_names_by_dim_type(names_to_add, isl.dim_type.set) def add_output_names(self, names_to_add: Collection[str]) -> Self: + """ + Return a copy with new unconstrained output dimensions. + + This is equivalent to :meth:`add_set_names` for map-like objects. + """ return self._add_names_by_dim_type(names_to_add, isl.dim_type.out) def add_input_names(self, names_to_add: Collection[str]) -> Self: + """ + Return a copy with new unconstrained input dimensions. + """ return self._add_names_by_dim_type(names_to_add, isl.dim_type.in_) def add_parameter_names(self, names_to_add: Collection[str]) -> Self: + """ + Return a copy with new parameter dimensions. + """ return self._add_names_by_dim_type(names_to_add, isl.dim_type.param) def add_dim_names( self, names_to_add: Collection[str], dim_type: isl.dim_type ) -> Self: + """ + Return a copy with new dimensions of *dim_type*. + + :arg dim_type: One of ``isl.dim_type.set``, ``isl.dim_type.out``, + ``isl.dim_type.in_``, or ``isl.dim_type.param``. + """ return self._add_names_by_dim_type(names_to_add, dim_type) @property def names(self) -> frozenset[str]: + """ + All dimension names known to this object. + """ return frozenset(self._name_to_dim.keys()) def dim_names(self, dim_type: isl.dim_type) -> frozenset[str]: + """ + Return the names belonging to *dim_type*. + """ return self._names_for_dim_type(dim_type) def ordered_dim_names(self, dim_type: isl.dim_type) -> tuple[str, ...]: + """ + Return names for *dim_type* in their current dimension order. + """ return self._ordered_names_for_dim_type(dim_type) @property def set_names(self) -> frozenset[str]: + """ + Names of set dimensions. + """ return self._names_for_dim_type(isl.dim_type.set) @property def output_names(self) -> frozenset[str]: + """ + Names of output dimensions. + """ return self._names_for_dim_type(isl.dim_type.out) def get_space(self) -> isl.Space: + """ + Reconstruct and return the object's public isl space. + """ return self._reconstruct_isl_object().get_space() def dim(self, dim_type: isl.dim_type) -> int: + """ + Return the number of dimensions of *dim_type*. + """ dim_type = _normalize_public_dim_type(dim_type) if dim_type in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): return len(self._names_for_dim_type(dim_type)) @@ -676,6 +762,12 @@ def move_dims( names_to_move: str | Collection[str], dst_type: isl.dim_type, ) -> Self: + """ + Return a copy with named dimensions moved to *dst_type*. + + The relative order of moved names is preserved. Moving a name to its + current dimension kind is a no-op. + """ if isinstance(names_to_move, str): names_to_move = [names_to_move] @@ -717,6 +809,12 @@ def move_dims( return _align_obj(self, new_name_to_dim, new_dimtype_to_names) def rename_dims(self, renaming: Mapping[str, str]) -> Self: + """ + Return a copy with dimension names changed according to *renaming*. + + Renaming updates namedisl metadata only. The private isl object's own + names are restored when :meth:`get_isl_object` is called. + """ if not renaming: return self @@ -767,6 +865,12 @@ def equate_dims( name1: str | Mapping[str, str], name2: str | None = None, ) -> Self: + """ + Return a copy constrained so paired dimensions are equal. + + Either pass two names, or pass a mapping whose keys and values are + equated pairwise. + """ if isinstance(name1, str): if name2 is None: raise TypeError("name2 must be provided when name1 is a string") @@ -812,6 +916,9 @@ def _has_inputs(self) -> bool: @property def input_names(self) -> frozenset[str]: + """ + Names of input dimensions. + """ return self._names_for_dim_type(isl.dim_type.in_) @property @@ -826,6 +933,9 @@ def _has_params(self) -> bool: @property def parameter_names(self) -> frozenset[str]: + """ + Names of parameter dimensions. + """ return self._metadata_parameter_names @property @@ -889,6 +999,9 @@ def _reconstruct_isl_object(self) -> IslObject: return obj def get_isl_object(self) -> IslObject: + """ + Reconstruct and return the wrapped public :mod:`islpy` object. + """ return self._reconstruct_isl_object() @override diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index d20b9dd..a308329 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -1,3 +1,11 @@ +""" +Name-aware affine and polynomial expression wrappers. + +The wrappers in this module provide a small arithmetic interface around isl +expression objects while preserving named dimension metadata across alignment +and reconstruction. +""" + from __future__ import annotations @@ -44,19 +52,19 @@ def _add_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast(IslExpressionLikeT, cast(Any, operator.add)(lhs, rhs)) + return cast("IslExpressionLikeT", cast("Any", operator.add)(lhs, rhs)) def _sub_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast(IslExpressionLikeT, cast(Any, operator.sub)(lhs, rhs)) + return cast("IslExpressionLikeT", cast("Any", operator.sub)(lhs, rhs)) def _mul_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: - return cast(IslExpressionLikeT, cast(Any, operator.mul)(lhs, rhs)) + return cast("IslExpressionLikeT", cast("Any", operator.mul)(lhs, rhs)) # {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) @@ -66,6 +74,9 @@ class _NamedExpressionLike(NamedIslObject[IslExpressionLikeT]): # FIXME: Self is used here is because _NamedExpressionLike is generic, # leading to complaints from basedpyright def __add__(self, other: Self | int) -> Self: + """ + Add another compatible named expression or an integer. + """ if isinstance(other, int): return replace( self, @@ -77,6 +88,9 @@ def __add__(self, other: Self | int) -> Self: return _align_and_apply_binary_op(self, other, _add_isl_expression) def __sub__(self, other: Self | int) -> Self: + """ + Subtract another compatible named expression or an integer. + """ if isinstance(other, int): return replace( self, @@ -88,6 +102,9 @@ def __sub__(self, other: Self | int) -> Self: return _align_and_apply_binary_op(self, other, _sub_isl_expression) def __mul__(self, other: Self | int) -> Self: + """ + Multiply by another compatible named expression or an integer. + """ if isinstance(other, int): return replace( self, @@ -99,6 +116,9 @@ def __mul__(self, other: Self | int) -> Self: return _align_and_apply_binary_op(self, other, _mul_isl_expression) def is_zero(self) -> bool: + """ + Return whether this expression is identically zero. + """ return bool(self._obj.is_zero()) # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType] @override @@ -114,6 +134,12 @@ class _NamedPwExpressionLike(_NamedExpressionLike[IslExpressionLikeT]): @final @dataclass(frozen=True, eq=False) class Aff(_NamedExpressionLike[isl.Aff]): + """ + Name-aware wrapper around :class:`islpy.Aff`. + + Construct instances with :func:`make_aff`. + """ + _obj: isl.Aff @override @@ -134,6 +160,9 @@ def make_aff(src: isl.Aff) -> Aff: def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: + """ + Create an :class:`Aff` from isl syntax or an :class:`islpy.Aff`. + """ obj = isl.Aff(src, ctx) if isinstance(src, str) else src aff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(aff_obj, isl.Aff) @@ -143,6 +172,12 @@ def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: @final @dataclass(frozen=True, eq=False) class QPolynomial(_NamedExpressionLike[isl.QPolynomial]): + """ + Name-aware wrapper around :class:`islpy.QPolynomial`. + + Construct instances with :func:`make_qpolynomial`. + """ + _obj: isl.QPolynomial @override @@ -164,6 +199,9 @@ def make_qpolynomial(src: isl.QPolynomial) -> QPolynomial: def make_qpolynomial( src: str | isl.QPolynomial, ctx: isl.Context | None = None) -> QPolynomial: + """ + Create a :class:`QPolynomial` from isl syntax or an isl qpolynomial. + """ # NOTE: ISL does not have a QPolynomial constructor, but we can make one # here by first creating a PwQPolynomial, then taking the only QPolynomial # that comes out of it :shrug: @@ -180,6 +218,12 @@ def make_qpolynomial( @final @dataclass(frozen=True, eq=False) class PwAff(_NamedPwExpressionLike[isl.PwAff]): + """ + Name-aware wrapper around :class:`islpy.PwAff`. + + Construct instances with :func:`make_pw_aff`. + """ + _obj: isl.PwAff @override @@ -200,6 +244,9 @@ def make_pw_aff(src: isl.PwAff) -> PwAff: def make_pw_aff(src: str | isl.PwAff, ctx: isl.Context | None = None) -> PwAff: + """ + Create a :class:`PwAff` from isl syntax or an :class:`islpy.PwAff`. + """ obj = isl.PwAff(src, ctx) if isinstance(src, str) else src pwaff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(pwaff_obj, isl.PwAff) @@ -209,6 +256,12 @@ def make_pw_aff(src: str | isl.PwAff, ctx: isl.Context | None = None) -> PwAff: @final @dataclass(frozen=True, eq=False) class PwQPolynomial(_NamedPwExpressionLike[isl.PwQPolynomial]): + """ + Name-aware wrapper around :class:`islpy.PwQPolynomial`. + + Construct instances with :func:`make_pw_qpolynomial`. + """ + _obj: isl.PwQPolynomial @override @@ -233,6 +286,9 @@ def make_pw_qpolynomial( src: str | isl.PwQPolynomial, ctx: isl.Context | None = None ) -> PwQPolynomial: + """ + Create a :class:`PwQPolynomial` from isl syntax or an isl object. + """ obj = isl.PwQPolynomial(src, ctx) if isinstance(src, str) else src pw_qp_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(pw_qp_obj, isl.PwQPolynomial) @@ -258,7 +314,16 @@ class _NamedMultiExpressionLike(NamedIslObject[isl.Set]): @final @dataclass(frozen=True, eq=False) class PwMultiAff(_NamedMultiExpressionLike): + """ + Name-aware wrapper around :class:`islpy.PwMultiAff`. + + Construct instances with :func:`make_pw_multi_aff`. + """ + def get_at(self, name: str) -> PwAff: + """ + Return the output component named *name*. + """ if name not in self._names_for_dim_type(isl.dim_type.set): raise ValueError(f"unknown output name: {name}") return make_pw_aff( @@ -288,6 +353,9 @@ def make_pw_multi_aff( src: str | isl.PwMultiAff, ctx: isl.Context | None = None ) -> PwMultiAff: + """ + Create a :class:`PwMultiAff` from isl syntax or an :class:`islpy.PwMultiAff`. + """ obj = isl.PwMultiAff(src, ctx) if isinstance(src, str) else src pw_maff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) @@ -298,7 +366,16 @@ def make_pw_multi_aff( @final @dataclass(frozen=True, eq=False) class MultiAff(_NamedMultiExpressionLike): + """ + Name-aware wrapper around :class:`islpy.MultiAff`. + + Construct instances with :func:`make_multi_aff`. + """ + def get_at(self, name: str) -> Aff: + """ + Return the output component named *name*. + """ if name not in self._names_for_dim_type(isl.dim_type.set): raise ValueError(f"unknown output name: {name}") return make_aff(self._reconstruct_isl_object().get_at(self._name_to_dim[name])) @@ -324,6 +401,9 @@ def make_multi_aff(src: isl.MultiAff) -> MultiAff: def make_multi_aff( src: str | isl.MultiAff, ctx: isl.Context | None = None) -> MultiAff: + """ + Create a :class:`MultiAff` from isl syntax or an :class:`islpy.MultiAff`. + """ obj = isl.MultiAff(src, ctx) if isinstance(src, str) else src maff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(maff_obj, isl.Set) diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 7aa160a..796fb4b 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -1,3 +1,12 @@ +""" +Name-aware set and map wrappers. + +The classes in this module wrap :mod:`islpy` sets and maps while making +dimension names the primary way to address axes. Internally, maps and sets are +stored as set-like isl objects with metadata that distinguishes output, input, +and parameter dimensions. +""" + from __future__ import annotations @@ -46,11 +55,11 @@ def _set_like_and(lhs: isl.Set, rhs: isl.Set) -> isl.Set: - return cast(isl.Set, cast(Any, operator.and_)(lhs, rhs)) + return cast("isl.Set", cast("Any", operator.and_)(lhs, rhs)) def _set_like_or(lhs: isl.Set, rhs: isl.Set) -> isl.Set: - return cast(isl.Set, cast(Any, operator.or_)(lhs, rhs)) + return cast("isl.Set", cast("Any", operator.or_)(lhs, rhs)) if TYPE_CHECKING: @@ -66,6 +75,9 @@ class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): """ def complement(self: Self) -> Self: + """ + Return the complement of this set-like object. + """ return replace( self, _obj=self._obj.complement(), @@ -80,21 +92,27 @@ def convex_hull(self: BasicMap | Map) -> BasicMap: ... def convex_hull(self: BasicSet | Set) -> BasicSet: ... def convex_hull(self) -> BasicMap | BasicSet: - obj = self._reconstruct_isl_object() - assert isinstance(obj, isl.BasicMap | isl.Map | isl.BasicSet | isl.Set) - - if isinstance(obj, isl.BasicMap): - return make_basic_map(obj.to_map().convex_hull()) - - if isinstance(obj, isl.Map): - return make_basic_map(obj.convex_hull()) - - if isinstance(obj, isl.BasicSet): - return make_basic_set(obj.to_set().convex_hull()) + """ + Return the convex hull as a basic set or basic map. + """ + result = isl.Set.from_basic_set(self._obj.convex_hull()) + if isinstance(self, _NamedIslMapLike): + return BasicMap( # pylint: disable=too-many-function-args + result, + self._name_to_dim, + self._dimtype_to_names, + ) - return make_basic_set(obj.convex_hull()) + return BasicSet( # pylint: disable=too-many-function-args + result, + self._name_to_dim, + self._dimtype_to_names, + ) def eliminate(self: Self, names_to_eliminate: str | Collection[str]) -> Self: + """ + Eliminate constraints involving the named dimensions without removing them. + """ if isinstance(names_to_eliminate, str): names_to_eliminate = [names_to_eliminate] @@ -121,6 +139,12 @@ def add_constraint( self: Self, constraints: str | Collection[str], ) -> Self: + """ + Return a copy intersected with additional named constraints. + + :arg constraints: One constraint string or a collection of constraint + strings in isl syntax, written using this object's dimension names. + """ if isinstance(constraints, str): constraints = [constraints] else: @@ -180,6 +204,9 @@ def gist(self: BasicSet, context: _NamedIslSetLike) -> BasicSet | Set: ... def gist(self: Set, context: _NamedIslSetLike) -> Set: ... def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: + """ + Simplify this object under the assumptions described by *context*. + """ self_aligned, context_aligned = _align_two(self, context) result = self_aligned._obj.gist(context_aligned._obj) @@ -199,6 +226,9 @@ def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: ) def project_out(self: Self, names_to_project_out: str | Collection[str]) -> Self: + """ + Return a copy with the named dimensions projected out. + """ if isinstance(names_to_project_out, str): names_to_project_out = [names_to_project_out] @@ -244,6 +274,9 @@ def project_out_except( self: Self, names_to_keep: str | Collection[str], ) -> Self: + """ + Project out every dimension except those named in *names_to_keep*. + """ if isinstance(names_to_keep, str): names_to_keep = [names_to_keep] if names_to_keep else [] @@ -255,17 +288,47 @@ def project_out_except( return self.project_out(names_to_project_out) def dim_max(self, name: str) -> isl.PwAff: - return self._obj.dim_max(self._name_to_dim[name]) + """ + Return the parametric maximum of the named dimension. + """ + obj, dim = self._dim_bound_object_and_dim(name) + return obj.dim_max(dim) def dim_min(self, name: str) -> isl.PwAff: - return self._obj.dim_min(self._name_to_dim[name]) + """ + Return the parametric minimum of the named dimension. + """ + obj, dim = self._dim_bound_object_and_dim(name) + return obj.dim_min(dim) + + def _dim_bound_object_and_dim( + self, name: str + ) -> tuple[isl.BasicSet | isl.Set, int]: + if name not in self.names: + raise ValueError(f"unknown name: {name}") + if name in self.parameter_names: + raise ValueError(f"cannot compute a bound for parameter: {name}") + + if isinstance(self, _NamedIslMapLike): + bound_set = self.domain() if name in self.input_names else self.range() + obj = bound_set._reconstruct_isl_object() + assert isinstance(obj, isl.BasicSet | isl.Set) + return obj, bound_set._name_to_dim[name] - def is_empty(self) -> bool: obj = self._reconstruct_isl_object() - assert isinstance(obj, isl.Set | isl.Map) - return bool(obj.is_empty()) + assert isinstance(obj, isl.BasicSet | isl.Set) + return obj, self._name_to_dim[name] + + def is_empty(self) -> bool: + """ + Return whether this object contains no integer points. + """ + return bool(self._obj.is_empty()) def as_pw_multi_aff(self) -> isl.PwMultiAff: + """ + Reconstruct and convert this object to :class:`islpy.PwMultiAff`. + """ obj = self._reconstruct_isl_object() assert isinstance(obj, isl.Set | isl.Map) return obj.as_pw_multi_aff() @@ -289,6 +352,9 @@ def __and__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... def __and__(self: Set, other: BasicSet | Set) -> Set: ... def __and__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: + """ + Return the intersection of two compatible named set-like objects. + """ return _apply_set_like_binary_op(self, other, _set_like_and) @overload @@ -304,6 +370,9 @@ def __or__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... def __or__(self: Set, other: BasicSet | Set) -> Set: ... def __or__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: + """ + Return the union of two compatible named set-like objects. + """ return _apply_set_like_binary_op(self, other, _set_like_or) @overload @@ -319,6 +388,9 @@ def __sub__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... def __sub__(self: Set, other: BasicSet | Set) -> Set: ... def __sub__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: + """ + Return the set difference with *other* removed. + """ return _apply_set_like_binary_op(self, other, operator.sub) @override @@ -335,21 +407,39 @@ def __eq__(self, other: object) -> bool: return aligned_self._obj.plain_is_equal(aligned_other._obj) def __lt__(self, other: _NamedIslSetLike) -> bool: + """ + Return whether this object is a strict subset of *other*. + """ return _compare_set_like(self, other, isl.Set.is_strict_subset) def __le__(self, other: _NamedIslSetLike) -> bool: + """ + Return whether this object is a subset of *other*. + """ return _compare_set_like(self, other, isl.Set.is_subset) def __gt__(self, other: _NamedIslSetLike) -> bool: + """ + Return whether this object is a strict superset of *other*. + """ return _compare_set_like(other, self, isl.Set.is_strict_subset) def __ge__(self, other: _NamedIslSetLike) -> bool: + """ + Return whether this object is a superset of *other*. + """ return _compare_set_like(other, self, isl.Set.is_subset) @final @dataclass(frozen=True, eq=False) class BasicSet(_NamedIslSetLike): + """ + Name-aware wrapper around :class:`islpy.BasicSet`. + + Construct instances with :func:`make_basic_set`. + """ + @override def add_input_names(self, names_to_add: Collection[str]) -> BasicSet: raise NotImplementedError @@ -376,6 +466,9 @@ def make_basic_set(src: isl.BasicSet) -> BasicSet: ... def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> BasicSet: + """ + Create a :class:`BasicSet` from isl syntax or an :class:`islpy.BasicSet`. + """ obj = isl.BasicSet(src, ctx) if isinstance(src, str) else src set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) @@ -385,6 +478,12 @@ def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> B @final @dataclass(frozen=True, eq=False) class Set(_NamedIslSetLike): + """ + Name-aware wrapper around :class:`islpy.Set`. + + Construct instances with :func:`make_set`. + """ + @override def add_input_names(self, names_to_add: Collection[str]) -> Set: raise NotImplementedError @@ -396,6 +495,9 @@ def _reconstruct_isl_object(self) -> isl.Set: return obj def get_basic_sets(self) -> Sequence[BasicSet]: + """ + Return the basic-set pieces of this set. + """ isl_obj = self._reconstruct_isl_object() bsets = isl_obj.get_basic_sets() @@ -451,11 +553,11 @@ def _apply_set_like_binary_op( else: result_type = Set - return result_type( # pylint: disable=too-many-function-args - result, - lhs._name_to_dim, - lhs._dimtype_to_names, - ) + return result_type( # pylint: disable=too-many-function-args + result, + lhs._name_to_dim, + lhs._dimtype_to_names, + ) def _compare_set_like( @@ -568,6 +670,9 @@ def _validate_composable( return self._ordered_names(lhs_names) def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: + """ + Return this map restricted to *domain*. + """ domain_obj = domain._reconstruct_isl_object() assert isinstance(domain_obj, isl.BasicSet | isl.Set) result = _apply_set_like_binary_op( @@ -582,6 +687,9 @@ def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: return result def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: + """ + Return this map restricted to *range_*. + """ range_obj = range_._reconstruct_isl_object() assert isinstance(range_obj, isl.BasicSet | isl.Set) result = _apply_set_like_binary_op( @@ -596,6 +704,11 @@ def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: return result def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: + """ + Compose this map with *other* on this map's range. + + The output names of this map must match the input names of *other*. + """ ordered_names = self._validate_composable( isl.dim_type.out, other, isl.dim_type.in_ ) @@ -608,6 +721,11 @@ def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: return self._wrap_map_result(result) def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: + """ + Compose *other* with this map on this map's domain. + + The input names of this map must match the output names of *other*. + """ ordered_names = self._validate_composable( isl.dim_type.in_, other, isl.dim_type.out ) @@ -620,15 +738,24 @@ def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: return self._wrap_map_result(result) def reverse(self) -> BasicMap | Map: + """ + Return the map with domain and range exchanged. + """ return self._wrap_map_result(self._map_obj().reverse()) def domain(self) -> BasicSet | Set: + """ + Return the domain as a named set. + """ domain = self._map_obj().domain() if isinstance(domain, isl.BasicSet): return make_basic_set(domain) return make_set(domain) def range(self) -> BasicSet | Set: + """ + Return the range as a named set. + """ range_ = self._map_obj().range() if isinstance(range_, isl.BasicSet): return make_basic_set(range_) @@ -644,6 +771,9 @@ def make_set(src: isl.Set) -> Set: ... def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: + """ + Create a :class:`Set` from isl syntax or an :class:`islpy.Set`. + """ obj = isl.Set(src, ctx) if isinstance(src, str) else src set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) @@ -653,8 +783,17 @@ def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: @final @dataclass(frozen=True, eq=False) class BasicMap(_NamedIslMapLike): + """ + Name-aware wrapper around :class:`islpy.BasicMap`. + + Construct instances with :func:`make_basic_map`. + """ + @classmethod def empty(cls, space: isl.Space) -> BasicMap: + """ + Return an empty :class:`BasicMap` in *space*. + """ obj = isl.BasicMap.empty(space) set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) @@ -732,6 +871,9 @@ def make_basic_map(src: isl.BasicMap) -> BasicMap: ... def make_basic_map(src: str | isl.BasicMap, ctx: isl.Context | None = None) -> BasicMap: + """ + Create a :class:`BasicMap` from isl syntax or an :class:`islpy.BasicMap`. + """ obj = isl.BasicMap(src, ctx) if isinstance(src, str) else src set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) @@ -741,6 +883,12 @@ def make_basic_map(src: str | isl.BasicMap, ctx: isl.Context | None = None) -> B def make_map_from_domain_and_range( domain: BasicSet | Set, range_: BasicSet | Set ) -> BasicMap | Map: + """ + Create a named map from a named *domain* and named *range_*. + + A :class:`BasicMap` is returned when both inputs are basic sets; otherwise a + :class:`Map` is returned. + """ if isinstance(domain, BasicSet) and isinstance(range_, BasicSet): domain_obj = domain._reconstruct_isl_object() range_obj = range_._reconstruct_isl_object() @@ -768,8 +916,17 @@ def make_map_from_domain_and_range( @final @dataclass(frozen=True, eq=False) class Map(_NamedIslMapLike): + """ + Name-aware wrapper around :class:`islpy.Map`. + + Construct instances with :func:`make_map`. + """ + @classmethod def empty(cls, space: isl.Space) -> Map: + """ + Return an empty :class:`Map` in *space*. + """ return make_map(isl.Map.empty(space)) @override @@ -788,6 +945,9 @@ def make_map(src: isl.Map) -> Map: ... def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: + """ + Create a :class:`Map` from isl syntax or an :class:`islpy.Map`. + """ obj = isl.Map(src, ctx) if isinstance(src, str) else src set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) assert isinstance(set_obj, isl.Set) diff --git a/namedisl/test/test_set_like.py b/namedisl/test/test_set_like.py index a46f1a1..6212c7e 100644 --- a/namedisl/test/test_set_like.py +++ b/namedisl/test/test_set_like.py @@ -259,6 +259,16 @@ def test_set_dim_min(ndims: int): assert a.dim_min(name) == cond_pw_affs[i] +def test_set_dim_bounds_reconstruct_parameter_metadata() -> None: + set_ = nisl.make_set("[n] -> { [i] : 0 <= i < n }").rename_dims({ + "i": "j", + "n": "m", + }) + + assert set_.dim_min("j") == isl.PwAff("[m] -> { [(0)] : m > 0 }") + assert set_.dim_max("j") == isl.PwAff("[m] -> { [(-1 + m)] : m > 0 }") + + # }}} @@ -478,42 +488,30 @@ def test_subset_comparison_rejects_set_map_mismatch() -> None: _ = set_ <= map_ -def test_map_alignment_syncs_internal_output_positions_and_names() -> None: +def test_map_alignment_syncs_output_metadata() -> None: lhs = nisl.make_map("{ [i] -> [x] }") rhs = nisl.make_map("{ [i] -> [y, x] }") aligned_lhs, aligned_rhs = _align_two(lhs, rhs) - lhs_names = [ - aligned_lhs._obj.get_dim_name(isl.dim_type.set, dim) - for dim in range(aligned_lhs._obj.dim(isl.dim_type.set)) - ] - rhs_names = [ - aligned_rhs._obj.get_dim_name(isl.dim_type.set, dim) - for dim in range(aligned_rhs._obj.dim(isl.dim_type.set)) - ] + assert aligned_lhs.ordered_dim_names(isl.dim_type.out) == ("x", "y") + assert aligned_lhs.ordered_dim_names(isl.dim_type.in_) == ("i",) + assert aligned_rhs.ordered_dim_names(isl.dim_type.out) == ("x", "y") + assert aligned_rhs.ordered_dim_names(isl.dim_type.in_) == ("i",) - assert lhs_names == ["x", "y", "i"] - assert rhs_names == ["x", "y", "i"] - -def test_map_alignment_syncs_internal_input_and_parameter_positions_and_names() -> None: +def test_map_alignment_syncs_input_and_parameter_metadata() -> None: lhs = nisl.make_map("[n] -> { [i] -> [x] }") rhs = nisl.make_map("[m, n] -> { [j, i] -> [x] }") aligned_lhs, aligned_rhs = _align_two(lhs, rhs) - lhs_names = [ - aligned_lhs._obj.get_dim_name(isl.dim_type.set, dim) - for dim in range(aligned_lhs._obj.dim(isl.dim_type.set)) - ] - rhs_names = [ - aligned_rhs._obj.get_dim_name(isl.dim_type.set, dim) - for dim in range(aligned_rhs._obj.dim(isl.dim_type.set)) - ] - - assert lhs_names == ["x", "i", "j", "m", "n"] - assert rhs_names == ["x", "i", "j", "m", "n"] + assert aligned_lhs.ordered_dim_names(isl.dim_type.out) == ("x",) + assert aligned_lhs.ordered_dim_names(isl.dim_type.in_) == ("i", "j") + assert aligned_lhs.ordered_dim_names(isl.dim_type.param) == ("m", "n") + assert aligned_rhs.ordered_dim_names(isl.dim_type.out) == ("x",) + assert aligned_rhs.ordered_dim_names(isl.dim_type.in_) == ("i", "j") + assert aligned_rhs.ordered_dim_names(isl.dim_type.param) == ("m", "n") def test_map_apply_range_rejects_surviving_name_collisions() -> None: @@ -731,6 +729,19 @@ def test_map_dim_min(ndims_domain: int, ndims_range: int): assert m.dim_min(name) == out_lower_bound_pw_maffs[i] +def test_map_dim_bounds_reconstruct_parameter_metadata() -> None: + map_ = nisl.make_map( + "[n] -> { [i] -> [j] : 0 <= i < n and j = i + 1 }" + ).rename_dims({ + "i": "k", + "j": "l", + "n": "m", + }) + + assert map_.dim_min("k") == isl.PwAff("[m] -> { [(0)] : m > 0 }") + assert map_.dim_max("l") == isl.PwAff("[m] -> { [(m)] : m > 0 }") + + # }}} From c1c8424c6b920794017487a1fae331e7a944189a Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 21 May 2026 14:29:17 -0500 Subject: [PATCH 37/43] configure docs --- doc/conf.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/conf.py b/doc/conf.py index 1dbaeb8..b9fa5f5 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -3,18 +3,25 @@ import sys from importlib import metadata from pathlib import Path +from urllib.request import urlopen sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +_conf_url = "https://tiker.net/sphinxconfig-v0.py" +with urlopen(_conf_url) as _inf: + exec(compile(_inf.read(), _conf_url, "exec"), globals()) + extensions = [ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", ] autodoc_member_order = "bysource" autodoc_typehints = "none" +project = "namedisl" copyright = "2025- University of Illinois Board of Trustees" author = "Andreas Kloeckner" From 6d695e14bef197f415c8b3812c157dcdcf07cf7a Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 26 May 2026 08:22:50 -0500 Subject: [PATCH 38/43] make NamedIslObject (and subclasses) vary over internal and public ISL types --- namedisl/core.py | 57 ++++++++++++++++--------- namedisl/expression_like.py | 21 +++++++--- namedisl/set_like.py | 84 ++++++++++++++++++++++++------------- 3 files changed, 108 insertions(+), 54 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index fb0cdb4..905505b 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -60,9 +60,21 @@ "IslExpressionLikeT", bound=IslExpressionLike, ) -IslObjectT = TypeVar("IslObjectT", bound=IslObject, covariant=True) # noqa: PLC0105 +InternalIslObjectT_co = TypeVar( + "InternalIslObjectT_co", + bound=IslObject, + covariant=True, +) +PublicIslObjectT_co = TypeVar( + "PublicIslObjectT_co", + bound=IslObject, + covariant=True, +) -NamedIslObjectT = TypeVar("NamedIslObjectT", bound="NamedIslObject[IslObject]") +NamedIslObjectT = TypeVar( + "NamedIslObjectT", + bound="NamedIslObject[IslObject, IslObject]", +) NameToDim: TypeAlias = Mapping[str, int] @@ -117,7 +129,9 @@ def _ensure_unique_public_names(obj: IslObject) -> None: seen_names.add(name) -def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: +def _strip_names( + obj: InternalIslObjectT_co, +) -> tuple[InternalIslObjectT_co, NameToDim]: name_to_dim: dict[str, int] = {} dt_to_strip = isl.dim_type.set if isinstance(obj, IslSetLike) else isl.dim_type.in_ @@ -138,7 +152,7 @@ def _strip_names(obj: IslObjectT) -> tuple[IslObjectT, NameToDim]: name_to_dim[name] = i - return cast("IslObjectT", stripped_obj), constantdict(name_to_dim) + return cast("InternalIslObjectT_co", stripped_obj), constantdict(name_to_dim) def _get_obj_dim_name(obj: IslObject, dt: isl.dim_type, dim: int) -> str: @@ -209,7 +223,9 @@ def _make_named_object_pieces(obj: IslObject) -> IslObjectPieces: return decon_obj, name_to_dim, dimtype_to_names -def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: +def _restore_names( + obj: InternalIslObjectT_co, name_to_dim: NameToDim +) -> InternalIslObjectT_co: """ Return a copy of *obj* with dimension names restored from *name_to_dim*. @@ -234,7 +250,7 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: restored_obj = restored_obj.get_pw_aff_list().get_at(0) return cast( - "IslObjectT", + "InternalIslObjectT_co", restored_obj.move_dims( isl.dim_type.in_, 0, @@ -255,7 +271,7 @@ def _restore_names(obj: IslObjectT, name_to_dim: NameToDim) -> IslObjectT: if isinstance(restored_obj, isl.UnionPwAff | isl.UnionPwMultiAff): raise NotImplementedError - return cast("IslObjectT", restored_obj) + return cast("InternalIslObjectT_co", restored_obj) def _get_dim_names(obj: IslObject, dt: isl.dim_type) -> frozenset[str]: @@ -378,7 +394,8 @@ def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: # FIXME: think through whether or not alphabetical ordering will require more # work on average than using one of the objects as a template in alignment def _find_joint_name_to_dim( - obj1: NamedIslObject[IslObjectT], obj2: NamedIslObject[IslObjectT] + obj1: NamedIslObject[IslObject, IslObject], + obj2: NamedIslObject[IslObject, IslObject], ) -> tuple[NameToDim, DimTypeToNames]: """ Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and @@ -485,10 +502,12 @@ def _align_two( def _align_and_apply_binary_op( - lhs: NamedIslObject[IslObjectT], - rhs: NamedIslObject[IslObjectT], - op: Callable[[IslObjectT, IslObjectT], IslObjectT], -) -> NamedIslObject[IslObjectT]: + lhs: NamedIslObject[InternalIslObjectT_co, IslObject], + rhs: NamedIslObject[InternalIslObjectT_co, IslObject], + op: Callable[ + [InternalIslObjectT_co, InternalIslObjectT_co], InternalIslObjectT_co + ], +) -> NamedIslObject[InternalIslObjectT_co, IslObject]: """ Align *lhs* and *rhs*, apply *op* to their isl objects, and wrap the result. """ @@ -502,7 +521,7 @@ def _align_and_apply_binary_op( @dataclass(frozen=True, eq=False) -class NamedIslObject(ABC, Generic[IslObjectT]): +class NamedIslObject(ABC, Generic[InternalIslObjectT_co, PublicIslObjectT_co]): """ Base class for named isl wrappers. @@ -512,7 +531,7 @@ class NamedIslObject(ABC, Generic[IslObjectT]): delegating the underlying integer-set algebra to isl. """ - _obj: IslObjectT + _obj: InternalIslObjectT_co _name_to_dim: NameToDim # used to reconstruct ISL object @@ -664,7 +683,7 @@ def _add_grouped_names( ) return type(self)( - cast("IslObjectT", new_obj), + cast("InternalIslObjectT_co", new_obj), new_name_to_dim, new_dimtype_to_names, ) @@ -905,7 +924,7 @@ def equate_dims( ) return type(self)( - cast("IslObjectT", obj), + cast("InternalIslObjectT_co", obj), self._name_to_dim, self._dimtype_to_names, ) @@ -946,7 +965,7 @@ def _parameter_dim_start(self) -> int | None: ) return None - def _reconstruct_isl_object(self) -> IslObject: + def _reconstruct_isl_object(self) -> PublicIslObjectT_co: """ Relies on the dimension type ordering in :func:`_deconstruct_set_like_object`. @@ -996,9 +1015,9 @@ def _reconstruct_isl_object(self) -> IslObject: isl.dim_type.in_, 0, internal_dim, inp_start, len(self.input_names) ) - return obj + return cast("PublicIslObjectT_co", obj) - def get_isl_object(self) -> IslObject: + def get_isl_object(self) -> PublicIslObjectT_co: """ Reconstruct and return the wrapped public :mod:`islpy` object. """ diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index a308329..a56ba49 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -35,7 +35,7 @@ import operator from dataclasses import dataclass, replace -from typing import Any, cast, final, overload +from typing import Any, TypeVar, cast, final, overload from typing_extensions import Self, override @@ -49,6 +49,13 @@ ) +PublicMultiExpressionLikeT = TypeVar( + "PublicMultiExpressionLikeT", + isl.MultiAff, + isl.PwMultiAff, +) + + def _add_isl_expression( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int ) -> IslExpressionLikeT: @@ -70,7 +77,9 @@ def _mul_isl_expression( # {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) @dataclass(frozen=True, eq=False) -class _NamedExpressionLike(NamedIslObject[IslExpressionLikeT]): +class _NamedExpressionLike( + NamedIslObject[IslExpressionLikeT, IslExpressionLikeT] +): # FIXME: Self is used here is because _NamedExpressionLike is generic, # leading to complaints from basedpyright def __add__(self, other: Self | int) -> Self: @@ -300,7 +309,9 @@ def make_pw_qpolynomial( # {{{ multi expression-likes (multiaff, pwmultiaff) @dataclass(frozen=True, eq=False) -class _NamedMultiExpressionLike(NamedIslObject[isl.Set]): +class _NamedMultiExpressionLike( + NamedIslObject[isl.Set, PublicMultiExpressionLikeT] +): """ Multi-expressions in ISL cannot have dimensions moved. As a workaround, we represent multi-expressions as sets internally. This is done during @@ -313,7 +324,7 @@ class _NamedMultiExpressionLike(NamedIslObject[isl.Set]): @final @dataclass(frozen=True, eq=False) -class PwMultiAff(_NamedMultiExpressionLike): +class PwMultiAff(_NamedMultiExpressionLike[isl.PwMultiAff]): """ Name-aware wrapper around :class:`islpy.PwMultiAff`. @@ -365,7 +376,7 @@ def make_pw_multi_aff( @final @dataclass(frozen=True, eq=False) -class MultiAff(_NamedMultiExpressionLike): +class MultiAff(_NamedMultiExpressionLike[isl.MultiAff]): """ Name-aware wrapper around :class:`islpy.MultiAff`. diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 796fb4b..56d935d 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -37,7 +37,7 @@ import operator from abc import ABC from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, Any, cast, final, overload +from typing import TYPE_CHECKING, Any, TypeVar, cast, final, overload from constantdict import constantdict from typing_extensions import Self, override @@ -45,6 +45,7 @@ import islpy as isl from .core import ( + IslSetLike, NamedIslObject, NameToDim, _align_obj, @@ -54,6 +55,12 @@ ) +PublicSetLikeT_co = TypeVar("PublicSetLikeT_co", bound=IslSetLike, covariant=True) +PublicMapLikeT_co = TypeVar( + "PublicMapLikeT_co", isl.BasicMap, isl.Map, covariant=True +) + + def _set_like_and(lhs: isl.Set, rhs: isl.Set) -> isl.Set: return cast("isl.Set", cast("Any", operator.and_)(lhs, rhs)) @@ -67,7 +74,7 @@ def _set_like_or(lhs: isl.Set, rhs: isl.Set) -> isl.Set: @dataclass(frozen=True, eq=False) -class _NamedIslSetLike(NamedIslObject[isl.Set], ABC): +class _NamedIslSetLike(NamedIslObject[isl.Set, PublicSetLikeT_co], ABC): """ Represents set-like objects with parameter dimensions as a non-parameterized set. Names are organized as contiguous chunks of dimension types, i.e. @@ -192,18 +199,24 @@ def add_constraint( ) @overload - def gist(self: BasicMap, context: _NamedIslSetLike) -> BasicMap | Map: ... + def gist( + self: BasicMap, context: _NamedIslSetLike[IslSetLike] + ) -> BasicMap | Map: ... @overload - def gist(self: Map, context: _NamedIslSetLike) -> Map: ... + def gist(self: Map, context: _NamedIslSetLike[IslSetLike]) -> Map: ... @overload - def gist(self: BasicSet, context: _NamedIslSetLike) -> BasicSet | Set: ... + def gist( + self: BasicSet, context: _NamedIslSetLike[IslSetLike] + ) -> BasicSet | Set: ... @overload - def gist(self: Set, context: _NamedIslSetLike) -> Set: ... + def gist(self: Set, context: _NamedIslSetLike[IslSetLike]) -> Set: ... - def gist(self, context: _NamedIslSetLike) -> _NamedIslSetLike: + def gist( + self, context: _NamedIslSetLike[IslSetLike] + ) -> _NamedIslSetLike[IslSetLike]: """ Simplify this object under the assumptions described by *context*. """ @@ -351,7 +364,9 @@ def __and__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... @overload def __and__(self: Set, other: BasicSet | Set) -> Set: ... - def __and__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: + def __and__( + self, other: _NamedIslSetLike[IslSetLike] + ) -> _NamedIslSetLike[IslSetLike]: """ Return the intersection of two compatible named set-like objects. """ @@ -369,7 +384,9 @@ def __or__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... @overload def __or__(self: Set, other: BasicSet | Set) -> Set: ... - def __or__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: + def __or__( + self, other: _NamedIslSetLike[IslSetLike] + ) -> _NamedIslSetLike[IslSetLike]: """ Return the union of two compatible named set-like objects. """ @@ -387,7 +404,9 @@ def __sub__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... @overload def __sub__(self: Set, other: BasicSet | Set) -> Set: ... - def __sub__(self, other: _NamedIslSetLike) -> _NamedIslSetLike: + def __sub__( + self, other: _NamedIslSetLike[IslSetLike] + ) -> _NamedIslSetLike[IslSetLike]: """ Return the set difference with *other* removed. """ @@ -406,25 +425,25 @@ def __eq__(self, other: object) -> bool: assert isinstance(aligned_other._obj, isl.Set) return aligned_self._obj.plain_is_equal(aligned_other._obj) - def __lt__(self, other: _NamedIslSetLike) -> bool: + def __lt__(self, other: _NamedIslSetLike[IslSetLike]) -> bool: """ Return whether this object is a strict subset of *other*. """ return _compare_set_like(self, other, isl.Set.is_strict_subset) - def __le__(self, other: _NamedIslSetLike) -> bool: + def __le__(self, other: _NamedIslSetLike[IslSetLike]) -> bool: """ Return whether this object is a subset of *other*. """ return _compare_set_like(self, other, isl.Set.is_subset) - def __gt__(self, other: _NamedIslSetLike) -> bool: + def __gt__(self, other: _NamedIslSetLike[IslSetLike]) -> bool: """ Return whether this object is a strict superset of *other*. """ return _compare_set_like(other, self, isl.Set.is_strict_subset) - def __ge__(self, other: _NamedIslSetLike) -> bool: + def __ge__(self, other: _NamedIslSetLike[IslSetLike]) -> bool: """ Return whether this object is a superset of *other*. """ @@ -433,7 +452,7 @@ def __ge__(self, other: _NamedIslSetLike) -> bool: @final @dataclass(frozen=True, eq=False) -class BasicSet(_NamedIslSetLike): +class BasicSet(_NamedIslSetLike[isl.BasicSet]): """ Name-aware wrapper around :class:`islpy.BasicSet`. @@ -477,7 +496,7 @@ def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> B @final @dataclass(frozen=True, eq=False) -class Set(_NamedIslSetLike): +class Set(_NamedIslSetLike[isl.Set]): """ Name-aware wrapper around :class:`islpy.Set`. @@ -530,17 +549,17 @@ def _apply_set_like_binary_op( @overload def _apply_set_like_binary_op( - lhs: _NamedIslSetLike, - rhs: _NamedIslSetLike, + lhs: _NamedIslSetLike[IslSetLike], + rhs: _NamedIslSetLike[IslSetLike], op: Callable[[isl.Set, isl.Set], isl.Set], -) -> _NamedIslSetLike: ... +) -> _NamedIslSetLike[IslSetLike]: ... def _apply_set_like_binary_op( - lhs: _NamedIslSetLike, - rhs: _NamedIslSetLike, + lhs: _NamedIslSetLike[IslSetLike], + rhs: _NamedIslSetLike[IslSetLike], op: Callable[[isl.Set, isl.Set], isl.Set], -) -> _NamedIslSetLike: +) -> _NamedIslSetLike[IslSetLike]: lhs, rhs = _align_two(lhs, rhs) result = op(lhs._obj, rhs._obj) @@ -561,7 +580,9 @@ def _apply_set_like_binary_op( def _compare_set_like( - lhs: _NamedIslSetLike, rhs: _NamedIslSetLike, op: Callable[[isl.Set, isl.Set], bool] + lhs: _NamedIslSetLike[IslSetLike], + rhs: _NamedIslSetLike[IslSetLike], + op: Callable[[isl.Set, isl.Set], bool], ) -> bool: lhs_is_map = isinstance(lhs, _NamedIslMapLike) rhs_is_map = isinstance(rhs, _NamedIslMapLike) @@ -575,14 +596,17 @@ def _compare_set_like( return op(aligned_lhs._obj, aligned_rhs._obj) -class _NamedIslMapLike(_NamedIslSetLike): +class _NamedIslMapLike(_NamedIslSetLike[PublicMapLikeT_co]): @override - def _reconstruct_isl_object(self) -> isl.BasicMap | isl.Map: + def _reconstruct_isl_object(self) -> PublicMapLikeT_co: obj = super()._reconstruct_isl_object() if isinstance(obj, isl.Set): - return isl.Map.from_domain_and_range(isl.Set("{ [] }"), obj) + return cast( + "PublicMapLikeT_co", + isl.Map.from_domain_and_range(isl.Set("{ [] }"), obj), + ) assert isinstance(obj, isl.BasicMap | isl.Map) - return obj + return cast("PublicMapLikeT_co", obj) def _output_names(self) -> frozenset[str]: return frozenset(self._name_to_dim) - self.input_names - self.parameter_names @@ -623,7 +647,7 @@ def _reject_surviving_name_collisions( def _reorder_interface( self, dim_type: isl.dim_type, ordered_names: tuple[str, ...] - ) -> _NamedIslMapLike: + ) -> _NamedIslMapLike[PublicMapLikeT_co]: interface_names = ( self.input_names if dim_type == isl.dim_type.in_ else self._output_names() ) @@ -782,7 +806,7 @@ def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: @final @dataclass(frozen=True, eq=False) -class BasicMap(_NamedIslMapLike): +class BasicMap(_NamedIslMapLike[isl.BasicMap]): """ Name-aware wrapper around :class:`islpy.BasicMap`. @@ -915,7 +939,7 @@ def make_map_from_domain_and_range( @final @dataclass(frozen=True, eq=False) -class Map(_NamedIslMapLike): +class Map(_NamedIslMapLike[isl.Map]): """ Name-aware wrapper around :class:`islpy.Map`. From d980b5fd656264a2a869014a2498b6e417c32245 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 26 May 2026 14:46:35 -0500 Subject: [PATCH 39/43] fix conflicting, cross-dimension name handling; other minor fixes --- namedisl/core.py | 18 ++++- namedisl/expression_like.py | 105 ++++++++++++++++++++------ namedisl/test/test_expression_like.py | 13 +++- namedisl/test/test_set_like.py | 10 ++- 4 files changed, 119 insertions(+), 27 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index 905505b..6decae2 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -245,10 +245,13 @@ def _restore_names( restored_obj.dim(isl.dim_type.in_), ) + restored_union_pw_aff = restored_obj.to_union_pw_aff() for name, dim in name_to_dim.items(): - restored_obj = restored_obj.set_dim_name(isl.dim_type.param, dim, name) + restored_union_pw_aff = restored_union_pw_aff.set_dim_name( + isl.dim_type.param, dim, name + ) - restored_obj = restored_obj.get_pw_aff_list().get_at(0) + restored_obj = restored_union_pw_aff.get_pw_aff_list().get_at(0) return cast( "InternalIslObjectT_co", restored_obj.move_dims( @@ -408,6 +411,17 @@ def _find_joint_name_to_dim( all_inp_names = obj1.input_names | obj2.input_names all_param_names = obj1.parameter_names | obj2.parameter_names + duplicate_names = ( + (all_set_names & all_inp_names) + | (all_set_names & all_param_names) + | (all_inp_names & all_param_names) + ) + if duplicate_names: + raise ValueError( + "duplicate dimension names across dimension types: " + + ", ".join(sorted(duplicate_names)) + ) + dt_to_names: DimTypeToNames = {} dt_to_names[isl.dim_type.param] = all_param_names if isinstance(obj1._obj, IslSetLike | IslMultiExpressionLike) or isinstance( diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index a56ba49..d6800e0 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -34,7 +34,8 @@ """ import operator -from dataclasses import dataclass, replace +from collections.abc import Callable +from dataclasses import dataclass from typing import Any, TypeVar, cast, final, overload from typing_extensions import Self, override @@ -42,9 +43,11 @@ import islpy as isl from .core import ( + DimTypeToNames, IslExpressionLikeT, NamedIslObject, - _align_and_apply_binary_op, + NameToDim, + _align_two, _make_named_object_pieces, ) @@ -74,6 +77,16 @@ def _mul_isl_expression( return cast("IslExpressionLikeT", cast("Any", operator.mul)(lhs, rhs)) +def _explicitly_promote_isl_expressions( + lhs: IslExpressionLikeT, rhs: IslExpressionLikeT +) -> tuple[IslExpressionLikeT, IslExpressionLikeT]: + if isinstance(lhs, isl.Aff) and isinstance(rhs, isl.PwAff): + return cast("IslExpressionLikeT", lhs.to_pw_aff()), rhs + if isinstance(lhs, isl.PwAff) and isinstance(rhs, isl.Aff): + return lhs, cast("IslExpressionLikeT", rhs.to_pw_aff()) + return lhs, rhs + + # {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) @dataclass(frozen=True, eq=False) @@ -82,47 +95,50 @@ class _NamedExpressionLike( ): # FIXME: Self is used here is because _NamedExpressionLike is generic, # leading to complaints from basedpyright - def __add__(self, other: Self | int) -> Self: + def __add__( + self, other: _NamedExpressionLike[IslExpressionLikeT] | int + ) -> _NamedExpressionLike[IslExpressionLikeT]: """ Add another compatible named expression or an integer. """ if isinstance(other, int): - return replace( - self, - _obj=_add_isl_expression(self._obj, other), - _name_to_dim=self._name_to_dim, - _dimtype_to_names=self._dimtype_to_names + return _wrap_expression_result( + _add_isl_expression(self._obj, other), + self._name_to_dim, + self._dimtype_to_names, ) - return _align_and_apply_binary_op(self, other, _add_isl_expression) + return _align_and_apply_expression_op(self, other, _add_isl_expression) - def __sub__(self, other: Self | int) -> Self: + def __sub__( + self, other: _NamedExpressionLike[IslExpressionLikeT] | int + ) -> _NamedExpressionLike[IslExpressionLikeT]: """ Subtract another compatible named expression or an integer. """ if isinstance(other, int): - return replace( - self, - _obj=_sub_isl_expression(self._obj, other), - _name_to_dim=self._name_to_dim, - _dimtype_to_names=self._dimtype_to_names + return _wrap_expression_result( + _sub_isl_expression(self._obj, other), + self._name_to_dim, + self._dimtype_to_names, ) - return _align_and_apply_binary_op(self, other, _sub_isl_expression) + return _align_and_apply_expression_op(self, other, _sub_isl_expression) - def __mul__(self, other: Self | int) -> Self: + def __mul__( + self, other: _NamedExpressionLike[IslExpressionLikeT] | int + ) -> _NamedExpressionLike[IslExpressionLikeT]: """ Multiply by another compatible named expression or an integer. """ if isinstance(other, int): - return replace( - self, - _obj=_mul_isl_expression(self._obj, other), - _name_to_dim=self._name_to_dim, - _dimtype_to_names=self._dimtype_to_names + return _wrap_expression_result( + _mul_isl_expression(self._obj, other), + self._name_to_dim, + self._dimtype_to_names, ) - return _align_and_apply_binary_op(self, other, _mul_isl_expression) + return _align_and_apply_expression_op(self, other, _mul_isl_expression) def is_zero(self) -> bool: """ @@ -303,6 +319,49 @@ def make_pw_qpolynomial( assert isinstance(pw_qp_obj, isl.PwQPolynomial) return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + +def _wrap_expression_result( + result: IslExpressionLikeT, + name_to_dim: NameToDim, + dimtype_to_names: DimTypeToNames, +) -> _NamedExpressionLike[IslExpressionLikeT]: + if isinstance(result, isl.Aff): + return cast( + "_NamedExpressionLike[IslExpressionLikeT]", + Aff(result, name_to_dim, dimtype_to_names), # pylint: disable=too-many-function-args + ) + if isinstance(result, isl.PwAff): + return cast( + "_NamedExpressionLike[IslExpressionLikeT]", + PwAff(result, name_to_dim, dimtype_to_names), # pylint: disable=too-many-function-args + ) + if isinstance(result, isl.QPolynomial): + return cast( + "_NamedExpressionLike[IslExpressionLikeT]", + QPolynomial(result, name_to_dim, dimtype_to_names), # pylint: disable=too-many-function-args + ) + if isinstance(result, isl.PwQPolynomial): + return cast( + "_NamedExpressionLike[IslExpressionLikeT]", + PwQPolynomial(result, name_to_dim, dimtype_to_names), # pylint: disable=too-many-function-args + ) + raise TypeError(f"unsupported expression result type: {type(result).__name__}") + + +def _align_and_apply_expression_op( + lhs: _NamedExpressionLike[IslExpressionLikeT], + rhs: _NamedExpressionLike[IslExpressionLikeT], + op: Callable[[IslExpressionLikeT, IslExpressionLikeT], IslExpressionLikeT], +) -> _NamedExpressionLike[IslExpressionLikeT]: + lhs, rhs = _align_two(lhs, rhs) + lhs_obj, rhs_obj = _explicitly_promote_isl_expressions(lhs._obj, rhs._obj) + result = op(lhs_obj, rhs_obj) + return _wrap_expression_result( + result, + lhs._name_to_dim, + lhs._dimtype_to_names, + ) + # }}} diff --git a/namedisl/test/test_expression_like.py b/namedisl/test/test_expression_like.py index fa1975b..ae3b59c 100644 --- a/namedisl/test/test_expression_like.py +++ b/namedisl/test/test_expression_like.py @@ -25,7 +25,6 @@ THE SOFTWARE. """ - import islpy as isl import namedisl as nisl @@ -136,6 +135,18 @@ def test_pwaff_binary_ops(): assert npwaff_m_3._reconstruct_isl_object() == pwaff_m_3 + +def test_mixed_aff_and_pwaff_binary_op_promotes_to_pwaff() -> None: + aff = nisl.make_aff("{ [i] -> [i] }") + pwaff = nisl.make_pw_aff("{ [i] -> [i] }") + + result = aff + pwaff + + assert isinstance(result, nisl.PwAff) + assert result._reconstruct_isl_object() == ( + aff._reconstruct_isl_object().to_pw_aff() + pwaff._reconstruct_isl_object() + ) + # }}} diff --git a/namedisl/test/test_set_like.py b/namedisl/test/test_set_like.py index 6212c7e..836644f 100644 --- a/namedisl/test/test_set_like.py +++ b/namedisl/test/test_set_like.py @@ -31,7 +31,7 @@ import namedisl as nisl from .utils_for_tests import generate_random_named_map, generate_random_named_set -from namedisl.core import _align_two +from namedisl.core import _align_two, _find_joint_name_to_dim # {{{ sets @@ -117,6 +117,14 @@ def test_set_intersection(ndims: int, has_params: bool): assert (a & b) == result +def test_set_intersection_rejects_name_collision_across_dim_types() -> None: + set_with_n = nisl.make_set("{ [n] }") + param_with_n = nisl.make_set("[n] -> { [i] }") + + with pytest.raises(ValueError, match="duplicate|collision"): + _find_joint_name_to_dim(set_with_n, param_with_n) + + def test_set_add_constraint_uses_named_dimensions() -> None: set_ = nisl.make_set("{ [j, i] }") From 7ee2e654a8583b5c98067da9c62d409c80970a5a Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 9 Jun 2026 09:45:29 -0500 Subject: [PATCH 40/43] check that names exist before operating; Multi expression overhaul --- namedisl/core.py | 135 +++++++++++-- namedisl/expression_like.py | 275 +++++++++++++++++++++++--- namedisl/set_like.py | 14 +- namedisl/test/test_expression_like.py | 184 ++++++++++++++++- namedisl/test/test_set_like.py | 38 +++- 5 files changed, 596 insertions(+), 50 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index 6decae2..3f9ad36 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -51,10 +51,12 @@ IslBaseExpressionLike = isl.Aff | isl.QPolynomial IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff +MultiExpressionParts: TypeAlias = Mapping[int, object] IslExpressionLike = IslBaseExpressionLike | IslPwExpressionLike | IslMultiExpressionLike IslSetLike = isl.BasicSet | isl.BasicMap | isl.Set | isl.Map IslObject = IslSetLike | IslExpressionLike | IslMultiExpressionLike +InternalIslObject = IslObject | MultiExpressionParts IslExpressionLikeT = TypeVar( "IslExpressionLikeT", @@ -62,7 +64,7 @@ ) InternalIslObjectT_co = TypeVar( "InternalIslObjectT_co", - bound=IslObject, + bound=InternalIslObject, covariant=True, ) PublicIslObjectT_co = TypeVar( @@ -73,7 +75,7 @@ NamedIslObjectT = TypeVar( "NamedIslObjectT", - bound="NamedIslObject[IslObject, IslObject]", + bound="NamedIslObject[InternalIslObject, IslObject]", ) NameToDim: TypeAlias = Mapping[str, int] @@ -109,6 +111,17 @@ def _normalize_public_dim_type(dim_type: isl.dim_type) -> isl.dim_type: return dim_type +def _is_multi_expression_parts(obj: object) -> bool: + return isinstance(obj, Mapping) + + +def _uses_explicit_input_metadata(obj: object) -> bool: + return ( + isinstance(obj, IslSetLike | IslMultiExpressionLike) + or _is_multi_expression_parts(obj) + ) + + def _ensure_unique_public_names(obj: IslObject) -> None: if isinstance(obj, IslSetLike | IslMultiExpressionLike): dim_types = (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param) @@ -424,8 +437,8 @@ def _find_joint_name_to_dim( dt_to_names: DimTypeToNames = {} dt_to_names[isl.dim_type.param] = all_param_names - if isinstance(obj1._obj, IslSetLike | IslMultiExpressionLike) or isinstance( - obj2._obj, IslSetLike | IslMultiExpressionLike + if _uses_explicit_input_metadata(obj1._obj) or _uses_explicit_input_metadata( + obj2._obj ): dt_to_names[isl.dim_type.in_] = all_inp_names @@ -453,6 +466,48 @@ def _align_obj( The isl object is moved or expanded as needed, but public names are carried in metadata and are restored only when a raw isl object is reconstructed. """ + if _is_multi_expression_parts(named_obj._obj): + old_output_names = named_obj._ordered_names_for_dim_type(isl.dim_type.set) + output_names = tuple( + name + for name, _ in sorted(ordering.items(), key=lambda x: x[1]) + if name not in dimtype_to_names.get(isl.dim_type.in_, frozenset()) + and name not in dimtype_to_names.get(isl.dim_type.param, frozenset()) + ) + + if set(output_names) != set(old_output_names): + raise NotImplementedError( + "moving dimensions between output and non-output dimensions " + "is not implemented for multi expressions" + ) + + part_ordering: NameToDim = constantdict({ + name: dim - len(output_names) + for name, dim in ordering.items() + if name not in output_names + }) + part_dimtype_to_names: DimTypeToNames = constantdict({ + isl.dim_type.param: dimtype_to_names.get( + isl.dim_type.param, frozenset() + ) + }) + + new_parts = constantdict({ + new_dim: _align_obj( + cast("NamedIslObject[InternalIslObject, IslObject]", + named_obj._obj[named_obj._name_to_dim[name]]), + part_ordering, + part_dimtype_to_names, + ) + for new_dim, name in enumerate(output_names) + }) + + return type(named_obj)( + new_parts, + ordering, + dimtype_to_names, + ) + new_isl_obj = cast( "IslSetLike | IslBaseExpressionLike | IslPwExpressionLike | isl.MultiAff", named_obj._obj, @@ -564,7 +619,7 @@ def _names_for_dim_type(self, dim_type: isl.dim_type) -> frozenset[str]: if dim_type == isl.dim_type.param: return self.parameter_names - if isinstance(self._obj, IslSetLike | IslMultiExpressionLike): + if _uses_explicit_input_metadata(self._obj): if dim_type == isl.dim_type.in_: return self._metadata_input_names if dim_type == isl.dim_type.set: @@ -622,14 +677,11 @@ def _metadata_from_chunk_names( def _add_names_by_dim_type( self, names_to_add: Collection[str], dim_type: isl.dim_type ) -> Self: - if isinstance(self._obj, isl.PwMultiAff): - raise NotImplementedError - dim_type = _normalize_public_dim_type(dim_type) if dim_type not in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): raise ValueError(f"unsupported dim type: {dim_type}") if ( - not isinstance(self._obj, IslSetLike | IslMultiExpressionLike) + not _uses_explicit_input_metadata(self._obj) and dim_type == isl.dim_type.set ): raise ValueError(f"unsupported dim type: {dim_type}") @@ -652,9 +704,6 @@ def _add_names_by_dim_type( def _add_grouped_names( self, grouped_names: Mapping[isl.dim_type, Collection[str]] ) -> Self: - if isinstance(self._obj, isl.PwMultiAff): - raise NotImplementedError - seen_names: set[str] = set() for names in grouped_names.values(): for name in names: @@ -664,13 +713,53 @@ def _add_grouped_names( raise ValueError(f"name already exists: {name}") seen_names.add(name) + if _is_multi_expression_parts(self._obj): + if grouped_names[isl.dim_type.set]: + raise NotImplementedError( + "adding output dimensions to multi expressions is not " + "implemented" + ) + + new_obj = self._obj + for dim_type in (isl.dim_type.param, isl.dim_type.in_): + names_to_add = grouped_names[dim_type] + if not names_to_add: + continue + new_obj = constantdict({ + dim: cast("NamedIslObject[InternalIslObject, IslObject]", part) + .add_dim_names(names_to_add, dim_type) + for dim, part in new_obj.items() + }) + + chunk_names = { + dt: list(names) for dt, names in self._ordered_name_chunks().items() + } + for dim_type in (isl.dim_type.param, isl.dim_type.in_): + names_to_add = grouped_names[dim_type] + if names_to_add: + chunk_names[dim_type] = [ + *names_to_add, + *chunk_names[dim_type], + ] + + new_name_to_dim, new_dimtype_to_names = self._metadata_from_chunk_names( + chunk_names, + has_inputs=True, + ) + + return type(self)( + new_obj, + new_name_to_dim, + new_dimtype_to_names, + ) + new_obj = self._obj chunk_names = { dt: list(names) for dt, names in self._ordered_name_chunks().items() } internal_dim_type = ( isl.dim_type.set - if isinstance(new_obj, IslSetLike | IslMultiExpressionLike) + if _uses_explicit_input_metadata(new_obj) else isl.dim_type.in_ ) @@ -693,7 +782,7 @@ def _add_grouped_names( new_name_to_dim, new_dimtype_to_names = self._metadata_from_chunk_names( chunk_names, - has_inputs=isinstance(new_obj, IslSetLike | IslMultiExpressionLike), + has_inputs=_uses_explicit_input_metadata(new_obj), ) return type(self)( @@ -836,7 +925,7 @@ def move_dims( new_name_to_dim, new_dimtype_to_names = self._metadata_from_chunk_names( chunk_names, - has_inputs=True, + has_inputs=_uses_explicit_input_metadata(self._obj), ) return _align_obj(self, new_name_to_dim, new_dimtype_to_names) @@ -885,7 +974,21 @@ def rename_dims(self, renaming: Mapping[str, str]) -> Self: for dim_type, names in self._dimtype_to_names.items() }) - return type(self)(self._obj, new_name_to_dim, new_dimtype_to_names) + new_obj = self._obj + if _is_multi_expression_parts(new_obj): + new_obj = constantdict({ + dim: cast("NamedIslObject[InternalIslObject, IslObject]", part) + .rename_dims({ + old_name: new_name + for old_name, new_name in renaming.items() + if old_name in cast( + "NamedIslObject[InternalIslObject, IslObject]", part + ).names + }) + for dim, part in new_obj.items() + }) + + return type(self)(new_obj, new_name_to_dim, new_dimtype_to_names) @overload def equate_dims(self, name1: Mapping[str, str]) -> Self: ... diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index d6800e0..7a6ad8b 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -34,17 +34,18 @@ """ import operator -from collections.abc import Callable from dataclasses import dataclass -from typing import Any, TypeVar, cast, final, overload +from typing import TYPE_CHECKING, Any, TypeVar, cast, final, overload -from typing_extensions import Self, override +from constantdict import constantdict +from typing_extensions import override import islpy as isl from .core import ( DimTypeToNames, IslExpressionLikeT, + MultiExpressionParts, NamedIslObject, NameToDim, _align_two, @@ -52,6 +53,10 @@ ) +if TYPE_CHECKING: + from collections.abc import Callable + + PublicMultiExpressionLikeT = TypeVar( "PublicMultiExpressionLikeT", isl.MultiAff, @@ -77,6 +82,18 @@ def _mul_isl_expression( return cast("IslExpressionLikeT", cast("Any", operator.mul)(lhs, rhs)) +def _radd_isl_expression(lhs: int, rhs: IslExpressionLikeT) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.add)(lhs, rhs)) + + +def _rsub_isl_expression(lhs: int, rhs: IslExpressionLikeT) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.sub)(lhs, rhs)) + + +def _rmul_isl_expression(lhs: int, rhs: IslExpressionLikeT) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.mul)(lhs, rhs)) + + def _explicitly_promote_isl_expressions( lhs: IslExpressionLikeT, rhs: IslExpressionLikeT ) -> tuple[IslExpressionLikeT, IslExpressionLikeT]: @@ -93,8 +110,23 @@ def _explicitly_promote_isl_expressions( class _NamedExpressionLike( NamedIslObject[IslExpressionLikeT, IslExpressionLikeT] ): - # FIXME: Self is used here is because _NamedExpressionLike is generic, - # leading to complaints from basedpyright + @overload + def __add__(self: Aff, other: Aff | int) -> Aff: ... + + @overload + def __add__(self: Aff, other: PwAff) -> PwAff: ... + + @overload + def __add__(self: PwAff, other: Aff | PwAff | int) -> PwAff: ... + + @overload + def __add__(self: QPolynomial, other: QPolynomial | int) -> QPolynomial: ... + + @overload + def __add__( + self: PwQPolynomial, other: PwQPolynomial | int + ) -> PwQPolynomial: ... + def __add__( self, other: _NamedExpressionLike[IslExpressionLikeT] | int ) -> _NamedExpressionLike[IslExpressionLikeT]: @@ -110,6 +142,45 @@ def __add__( return _align_and_apply_expression_op(self, other, _add_isl_expression) + @overload + def __radd__(self: Aff, other: int) -> Aff: ... + + @overload + def __radd__(self: PwAff, other: int) -> PwAff: ... + + @overload + def __radd__(self: QPolynomial, other: int) -> QPolynomial: ... + + @overload + def __radd__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... + + def __radd__(self, other: int) -> _NamedExpressionLike[IslExpressionLikeT]: + """ + Add this expression to an integer. + """ + return _wrap_expression_result( + _radd_isl_expression(other, self._obj), + self._name_to_dim, + self._dimtype_to_names, + ) + + @overload + def __sub__(self: Aff, other: Aff | int) -> Aff: ... + + @overload + def __sub__(self: Aff, other: PwAff) -> PwAff: ... + + @overload + def __sub__(self: PwAff, other: Aff | PwAff | int) -> PwAff: ... + + @overload + def __sub__(self: QPolynomial, other: QPolynomial | int) -> QPolynomial: ... + + @overload + def __sub__( + self: PwQPolynomial, other: PwQPolynomial | int + ) -> PwQPolynomial: ... + def __sub__( self, other: _NamedExpressionLike[IslExpressionLikeT] | int ) -> _NamedExpressionLike[IslExpressionLikeT]: @@ -125,6 +196,45 @@ def __sub__( return _align_and_apply_expression_op(self, other, _sub_isl_expression) + @overload + def __rsub__(self: Aff, other: int) -> Aff: ... + + @overload + def __rsub__(self: PwAff, other: int) -> PwAff: ... + + @overload + def __rsub__(self: QPolynomial, other: int) -> QPolynomial: ... + + @overload + def __rsub__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... + + def __rsub__(self, other: int) -> _NamedExpressionLike[IslExpressionLikeT]: + """ + Subtract this expression from an integer. + """ + return _wrap_expression_result( + _rsub_isl_expression(other, self._obj), + self._name_to_dim, + self._dimtype_to_names, + ) + + @overload + def __mul__(self: Aff, other: Aff | int) -> Aff: ... + + @overload + def __mul__(self: Aff, other: PwAff) -> PwAff: ... + + @overload + def __mul__(self: PwAff, other: Aff | PwAff | int) -> PwAff: ... + + @overload + def __mul__(self: QPolynomial, other: QPolynomial | int) -> QPolynomial: ... + + @overload + def __mul__( + self: PwQPolynomial, other: PwQPolynomial | int + ) -> PwQPolynomial: ... + def __mul__( self, other: _NamedExpressionLike[IslExpressionLikeT] | int ) -> _NamedExpressionLike[IslExpressionLikeT]: @@ -140,6 +250,28 @@ def __mul__( return _align_and_apply_expression_op(self, other, _mul_isl_expression) + @overload + def __rmul__(self: Aff, other: int) -> Aff: ... + + @overload + def __rmul__(self: PwAff, other: int) -> PwAff: ... + + @overload + def __rmul__(self: QPolynomial, other: int) -> QPolynomial: ... + + @overload + def __rmul__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... + + def __rmul__(self, other: int) -> _NamedExpressionLike[IslExpressionLikeT]: + """ + Multiply this expression by an integer. + """ + return _wrap_expression_result( + _rmul_isl_expression(other, self._obj), + self._name_to_dim, + self._dimtype_to_names, + ) + def is_zero(self) -> bool: """ Return whether this expression is identically zero. @@ -148,6 +280,8 @@ def is_zero(self) -> bool: @override def __eq__(self, other: object) -> bool: + if not isinstance(other, type(self)): + raise NotImplementedError("Objects are not of the same type") raise NotImplementedError @@ -367,19 +501,86 @@ def _align_and_apply_expression_op( # {{{ multi expression-likes (multiaff, pwmultiaff) +def _ordered_multi_dim_names( + obj: isl.MultiAff | isl.PwMultiAff, dim_type: isl.dim_type +) -> tuple[str, ...]: + space = obj.get_space() + names: list[str] = [] + for dim in range(obj.dim(dim_type)): + name = space.get_dim_name(dim_type, dim) + if name is None: + raise ValueError("duplicate or unnamed dimension found") + names.append(name) + return tuple(names) + + +def _make_multi_expression_parts( + obj: isl.MultiAff | isl.PwMultiAff, +) -> tuple[MultiExpressionParts, NameToDim, DimTypeToNames]: + output_names = _ordered_multi_dim_names(obj, isl.dim_type.out) + input_names = _ordered_multi_dim_names(obj, isl.dim_type.in_) + parameter_names = _ordered_multi_dim_names(obj, isl.dim_type.param) + + seen_names: set[str] = set() + for name in (*output_names, *input_names, *parameter_names): + if name in seen_names: + raise ValueError(f"duplicate dimension name found: {name}") + seen_names.add(name) + + parts: MultiExpressionParts = constantdict({ + dim: make_pw_aff( + obj.get_at(dim).to_pw_aff() + if isinstance(obj, isl.MultiAff) + else obj.get_at(dim) + ) + for dim in range(obj.dim(isl.dim_type.out)) + }) + + all_names = [*output_names, *input_names, *parameter_names] + name_to_dim: NameToDim = constantdict({ + name: dim for dim, name in enumerate(all_names) + }) + + dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} + if input_names: + dimtype_to_names[isl.dim_type.in_] = frozenset(input_names) + if parameter_names: + dimtype_to_names[isl.dim_type.param] = frozenset(parameter_names) + + return parts, name_to_dim, constantdict(dimtype_to_names) + + @dataclass(frozen=True, eq=False) class _NamedMultiExpressionLike( - NamedIslObject[isl.Set, PublicMultiExpressionLikeT] + NamedIslObject[MultiExpressionParts, PublicMultiExpressionLikeT] ): """ - Multi-expressions in ISL cannot have dimensions moved. As a workaround, we - represent multi-expressions as sets internally. This is done during - deconstruction by converting a multi-expression to a map, then converting - the resulting map to a set. During reconstruction, we simply follow the - deconstruction steps backwards (set -> map -> multi-expression). As such, - reconstruction is special-cased for each subclass. + Multi-expression components are stored directly as named :class:`PwAff` + parts, keyed by output dimension. """ + _obj: MultiExpressionParts + + def _multi_expression_context(self) -> isl.Context: + if self._obj: + first_part = cast("PwAff", next(iter(self._obj.values()))) + return first_part._obj.get_ctx() + return isl.DEFAULT_CONTEXT + + def _multi_expression_space(self) -> isl.Space: + return isl.Space.create_from_names( + self._multi_expression_context(), + params=list(self.ordered_dim_names(isl.dim_type.param)), + in_=list(self.ordered_dim_names(isl.dim_type.in_)), + out=list(self.ordered_dim_names(isl.dim_type.out)), + ) + + def _ordered_pw_aff_parts(self) -> tuple[isl.PwAff, ...]: + return tuple( + cast("PwAff", self._obj[dim])._reconstruct_isl_object() + for dim in range(self.dim(isl.dim_type.out)) + ) + @final @dataclass(frozen=True, eq=False) @@ -396,17 +597,24 @@ def get_at(self, name: str) -> PwAff: """ if name not in self._names_for_dim_type(isl.dim_type.set): raise ValueError(f"unknown output name: {name}") - return make_pw_aff( - self._reconstruct_isl_object().get_at(self._name_to_dim[name]) - ) + return cast("PwAff", self._obj[self._name_to_dim[name]]) @override def _reconstruct_isl_object(self) -> isl.PwMultiAff: - # deconstruction: isl.PwMultiAff -> isl.Map -> isl.Set - # reconstruction: isl.Set -> isl.Map -> isl.PwMultiAff - obj = super()._reconstruct_isl_object() - assert isinstance(obj, isl.Set | isl.Map) - return obj.as_pw_multi_aff() + space = self._multi_expression_space() + if not self._obj: + return isl.PwMultiAff.zero(space) + + pw_aff_list = isl.PwAffList.alloc( + self._multi_expression_context(), + len(self._obj), + ) + for part in self._ordered_pw_aff_parts(): + pw_aff_list = pw_aff_list.add(part) + + return isl.PwMultiAff.from_multi_pw_aff( + isl.MultiPwAff.from_pw_aff_list(space, pw_aff_list) + ) @overload @@ -428,8 +636,7 @@ def make_pw_multi_aff( """ obj = isl.PwMultiAff(src, ctx) if isinstance(src, str) else src - pw_maff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) - assert isinstance(pw_maff_obj, isl.Set) + pw_maff_obj, name_to_dim, dimtype_to_names = _make_multi_expression_parts(obj) return PwMultiAff(pw_maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args @@ -442,21 +649,28 @@ class MultiAff(_NamedMultiExpressionLike[isl.MultiAff]): Construct instances with :func:`make_multi_aff`. """ - def get_at(self, name: str) -> Aff: + def get_at(self, name: str) -> PwAff: """ Return the output component named *name*. """ if name not in self._names_for_dim_type(isl.dim_type.set): raise ValueError(f"unknown output name: {name}") - return make_aff(self._reconstruct_isl_object().get_at(self._name_to_dim[name])) + return cast("PwAff", self._obj[self._name_to_dim[name]]) @override def _reconstruct_isl_object(self) -> isl.MultiAff: - # deconstruction: isl.MultiAff -> isl.Map -> isl.Set - # reconstruction: isl.Set -> isl.Map -> isl.PwMultiAff -> isl.MultiAff - obj = super()._reconstruct_isl_object() - assert isinstance(obj, isl.Set | isl.Map) - return obj.as_pw_multi_aff().as_multi_aff() + space = self._multi_expression_space() + if not self._obj: + return isl.MultiAff.zero(space) + + aff_list = isl.AffList.alloc( + self._multi_expression_context(), + len(self._obj), + ) + for part in self._ordered_pw_aff_parts(): + aff_list = aff_list.add(part.as_aff()) + + return isl.MultiAff.from_aff_list(space, aff_list) @overload @@ -475,8 +689,7 @@ def make_multi_aff( Create a :class:`MultiAff` from isl syntax or an :class:`islpy.MultiAff`. """ obj = isl.MultiAff(src, ctx) if isinstance(src, str) else src - maff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) - assert isinstance(maff_obj, isl.Set) + maff_obj, name_to_dim, dimtype_to_names = _make_multi_expression_parts(obj) return MultiAff(maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args # }}} diff --git a/namedisl/set_like.py b/namedisl/set_like.py index 56d935d..a467a76 100644 --- a/namedisl/set_like.py +++ b/namedisl/set_like.py @@ -123,6 +123,12 @@ def eliminate(self: Self, names_to_eliminate: str | Collection[str]) -> Self: if isinstance(names_to_eliminate, str): names_to_eliminate = [names_to_eliminate] + missing_names = [ + name for name in names_to_eliminate if name not in self.names + ] + if missing_names: + raise ValueError(f"unknown names: {', '.join(missing_names)}") + dims_to_eliminate = sorted( self._name_to_dim[name] for name in names_to_eliminate ) @@ -246,6 +252,12 @@ def project_out(self: Self, names_to_project_out: str | Collection[str]) -> Self if isinstance(names_to_project_out, str): names_to_project_out = [names_to_project_out] + missing_names = [ + name for name in names_to_project_out if name not in self.names + ] + if missing_names: + raise ValueError(f"unknown names: {', '.join(missing_names)}") + names_to_remove = set(names_to_project_out) dims_to_remove = sorted(self._name_to_dim[name] for name in names_to_remove) @@ -415,7 +427,7 @@ def __sub__( @override def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): - raise TypeError("Objects are not of the same type") + raise NotImplementedError("Objects are not of the same type") aligned_self, aligned_other = _align_two(self, other) diff --git a/namedisl/test/test_expression_like.py b/namedisl/test/test_expression_like.py index ae3b59c..ed06449 100644 --- a/namedisl/test/test_expression_like.py +++ b/namedisl/test/test_expression_like.py @@ -25,6 +25,9 @@ THE SOFTWARE. """ +import pytest +from constantdict import constantdict + import islpy as isl import namedisl as nisl @@ -147,6 +150,121 @@ def test_mixed_aff_and_pwaff_binary_op_promotes_to_pwaff() -> None: aff._reconstruct_isl_object().to_pw_aff() + pwaff._reconstruct_isl_object() ) + +def test_expression_equality_type_mismatch_raises_not_implemented_error() -> None: + aff = nisl.make_aff("{ [i] -> [i] }") + pwaff = nisl.make_pw_aff("{ [i] -> [i] }") + + with pytest.raises(NotImplementedError, match="Objects are not of the same type"): + _ = aff == pwaff + + +def test_reflected_integer_expression_ops() -> None: + def is_zero( + expr: isl.Aff | isl.PwAff | isl.QPolynomial | isl.PwQPolynomial, + ) -> bool: + if isinstance(expr, isl.Aff): + return bool(expr.plain_is_zero()) + if isinstance(expr, isl.PwAff): + return bool(expr.plain_is_equal(expr * 0)) + return bool(expr.is_zero()) + + expressions = [ + nisl.make_aff("{ [i] -> [i] }"), + nisl.make_pw_aff("{ [i] -> [i] }"), + nisl.make_qpolynomial("{ [i] -> i }"), + nisl.make_pw_qpolynomial("{ [i] -> i }"), + ] + + for expr in expressions: + obj = expr._reconstruct_isl_object() + + assert is_zero((1 + expr)._reconstruct_isl_object() - (1 + obj)) + assert is_zero((1 - expr)._reconstruct_isl_object() - (1 - obj)) + assert is_zero((2 * expr)._reconstruct_isl_object() - (2 * obj)) + + +def _qpolynomial(spec: str) -> isl.QPolynomial: + return isl.PwQPolynomial(spec).get_pieces()[0][1] + + +def _assert_expression_equal(actual, expected) -> None: + if isinstance(actual, isl.Aff | isl.PwAff): + assert actual == expected + return + + assert (actual - expected).is_zero() + + +def test_move_dims_expression_param_to_input_reconstructs_like_isl() -> None: + cases = ( + ( + nisl.make_aff("[n] -> { [i] -> [i + n] }"), + isl.Aff("[n] -> { [i] -> [i + n] }"), + ), + ( + nisl.make_pw_aff("[n] -> { [i] -> [i + n] }"), + isl.PwAff("[n] -> { [i] -> [i + n] }"), + ), + ( + nisl.make_qpolynomial("[n] -> { [i] -> i + n }"), + _qpolynomial("[n] -> { [i] -> i + n }"), + ), + ( + nisl.make_pw_qpolynomial("[n] -> { [i] -> i + n }"), + isl.PwQPolynomial("[n] -> { [i] -> i + n }"), + ), + ) + + for named_expr, isl_expr in cases: + moved = named_expr.move_dims("n", isl.dim_type.in_) + expected = isl_expr.move_dims( + isl.dim_type.in_, + 1, + isl.dim_type.param, + 0, + 1, + ) + + _assert_expression_equal(moved._reconstruct_isl_object(), expected) + assert moved.input_names == frozenset({"i", "n"}) + assert moved.parameter_names == frozenset() + + +def test_move_dims_expression_input_to_param_reconstructs_like_isl() -> None: + cases = ( + ( + nisl.make_aff("[n] -> { [i] -> [i + n] }"), + isl.Aff("[n] -> { [i] -> [i + n] }"), + ), + ( + nisl.make_pw_aff("[n] -> { [i] -> [i + n] }"), + isl.PwAff("[n] -> { [i] -> [i + n] }"), + ), + ( + nisl.make_qpolynomial("[n] -> { [i] -> i + n }"), + _qpolynomial("[n] -> { [i] -> i + n }"), + ), + ( + nisl.make_pw_qpolynomial("[n] -> { [i] -> i + n }"), + isl.PwQPolynomial("[n] -> { [i] -> i + n }"), + ), + ) + + for named_expr, isl_expr in cases: + moved = named_expr.move_dims("i", isl.dim_type.param) + expected = isl_expr.move_dims( + isl.dim_type.param, + 1, + isl.dim_type.in_, + 0, + 1, + ) + + _assert_expression_equal(moved._reconstruct_isl_object(), expected) + assert moved.input_names == frozenset() + assert moved.parameter_names == frozenset({"n", "i"}) + # }}} @@ -155,7 +273,7 @@ def test_multi_aff_get_at_uses_name() -> None: maff = nisl.make_multi_aff( map_._reconstruct_isl_object().as_pw_multi_aff().as_multi_aff() ) - assert maff.get_at("x")._reconstruct_isl_object() == isl.Aff("{ [i] -> [(i)] }") + assert maff.get_at("x")._reconstruct_isl_object() == isl.PwAff("{ [i] -> [(i)] }") def test_pw_multi_aff_get_at_uses_name() -> None: @@ -164,6 +282,70 @@ def test_pw_multi_aff_get_at_uses_name() -> None: assert pmaff.get_at("y")._reconstruct_isl_object() == isl.PwAff("{ [i] -> [(2i)] }") +def test_multi_aff_stores_pw_aff_parts() -> None: + raw_maff = isl.MultiAff("{ [i] -> [x = i, y = 2i] }") + + maff = nisl.make_multi_aff(raw_maff) + + assert isinstance(maff._obj, constantdict) + assert not isinstance(maff._obj, isl.Set) + assert all(isinstance(part, nisl.PwAff) for part in maff._obj.values()) + assert maff.get_at("x") is maff._obj[0] + assert maff.get_at("y") is maff._obj[1] + assert maff._reconstruct_isl_object() == raw_maff + + +def test_pw_multi_aff_stores_pw_aff_parts() -> None: + raw_pmaff = isl.PwMultiAff("[n] -> { [i] -> [x = i + n, y = 2i] }") + + pmaff = nisl.make_pw_multi_aff(raw_pmaff) + + assert isinstance(pmaff._obj, constantdict) + assert not isinstance(pmaff._obj, isl.Set) + assert all(isinstance(part, nisl.PwAff) for part in pmaff._obj.values()) + assert pmaff.get_at("x") is pmaff._obj[0] + assert pmaff.get_at("y") is pmaff._obj[1] + assert pmaff._reconstruct_isl_object() == raw_pmaff + + +def test_multi_aff_rename_dims_updates_stored_parts() -> None: + maff = nisl.make_multi_aff("[n] -> { [i] -> [x = i + n] }") + + renamed = maff.rename_dims({"x": "z", "i": "j", "n": "m"}) + + part = renamed.get_at("z") + assert part.input_names == frozenset({"j"}) + assert part.parameter_names == frozenset({"m"}) + assert part._reconstruct_isl_object() == isl.PwAff("[m] -> { [j] -> [(j + m)] }") + + +def test_multi_aff_move_dims_updates_stored_parts() -> None: + maff = nisl.make_multi_aff("[n] -> { [i] -> [x = i + n] }") + + moved = maff.move_dims("n", isl.dim_type.in_) + + part = moved.get_at("x") + assert moved.input_names == frozenset({"i", "n"}) + assert moved.parameter_names == frozenset() + assert part.input_names == frozenset({"i", "n"}) + assert part.parameter_names == frozenset() + assert part._reconstruct_isl_object() == isl.PwAff("{ [i, n] -> [(i + n)] }") + + +def test_pw_multi_aff_named_operations_update_stored_parts() -> None: + pmaff = nisl.make_pw_multi_aff("[n] -> { [i] -> [x = i + n] }") + + renamed = pmaff.rename_dims({"x": "z", "i": "j", "n": "m"}) + moved = renamed.move_dims("m", isl.dim_type.in_) + + part = moved.get_at("z") + assert moved.input_names == frozenset({"j", "m"}) + assert moved.parameter_names == frozenset() + assert part.input_names == frozenset({"j", "m"}) + assert part.parameter_names == frozenset() + assert part._reconstruct_isl_object() == isl.PwAff("{ [j, m] -> [(j + m)] }") + + # {{{ qpolynomials def test_qpolynomial_from_str(): diff --git a/namedisl/test/test_set_like.py b/namedisl/test/test_set_like.py index 836644f..297ea8c 100644 --- a/namedisl/test/test_set_like.py +++ b/namedisl/test/test_set_like.py @@ -71,6 +71,14 @@ def test_set_equality(ndims: int, has_params: bool): assert a == perm_set +def test_set_like_equality_type_mismatch_raises_not_implemented_error() -> None: + set_ = nisl.make_set("{ [i] }") + map_ = nisl.make_map("{ [i] -> [j] }") + + with pytest.raises(NotImplementedError, match="Objects are not of the same type"): + _ = set_ == map_ + + @pytest.mark.parametrize("ndims", [1, 2, 4, 8]) @pytest.mark.parametrize("has_params", [True, False]) def test_set_union(ndims: int, has_params: bool): @@ -121,7 +129,7 @@ def test_set_intersection_rejects_name_collision_across_dim_types() -> None: set_with_n = nisl.make_set("{ [n] }") param_with_n = nisl.make_set("[n] -> { [i] }") - with pytest.raises(ValueError, match="duplicate|collision"): + with pytest.raises(ValueError, match=r"duplicate|collision"): _find_joint_name_to_dim(set_with_n, param_with_n) @@ -231,6 +239,13 @@ def test_set_eliminate(ndims: int): assert a == nisl.make_set(f"{{[{a_dims}]}}") +def test_set_eliminate_rejects_unknown_name() -> None: + set_ = nisl.make_set("{ [i] }") + + with pytest.raises(ValueError, match="unknown names: missing"): + _ = set_.eliminate("missing") + + @pytest.mark.parametrize("ndims", [2, 4, 8]) def test_set_project_out(ndims: int): a, a_dims, _ = generate_random_named_set(ndims, "a", None) @@ -239,6 +254,13 @@ def test_set_project_out(ndims: int): assert a == nisl.make_set("{[]}") +def test_set_project_out_rejects_unknown_name() -> None: + set_ = nisl.make_set("{ [i] }") + + with pytest.raises(ValueError, match="unknown names: missing"): + _ = set_.project_out("missing") + + @pytest.mark.parametrize("ndims", [2, 4, 8]) def test_set_dim_max(ndims: int): a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) @@ -659,6 +681,13 @@ def test_map_eliminate(ndims_domain: int, ndims_range: int): assert x == nisl.make_map(f"{{[{x_in_dims}] -> [{x_out_dims}]}}") +def test_map_eliminate_rejects_unknown_name() -> None: + map_ = nisl.make_map("{ [i] -> [j] }") + + with pytest.raises(ValueError, match="unknown names: missing"): + _ = map_.eliminate("missing") + + @pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) @pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) def test_map_project_out(ndims_domain: int, ndims_range: int): @@ -675,6 +704,13 @@ def test_map_project_out(ndims_domain: int, ndims_range: int): assert x == nisl.make_map("{[] -> []}") +def test_map_project_out_rejects_unknown_name() -> None: + map_ = nisl.make_map("{ [i] -> [j] }") + + with pytest.raises(ValueError, match="unknown names: missing"): + _ = map_.project_out("missing") + + def test_map_as_pw_multi_aff(): spec = "{ [i] -> [io, ii] : i = 32 * io + ii and 0 <= ii < 32 }" m = nisl.make_map(spec) From 95328a1d07f1bcace6a83b419fec15e6b0ead531 Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Jun 2026 10:57:51 -0500 Subject: [PATCH 41/43] Make MultiExpressionLikes store named parts --- namedisl/core.py | 117 ++++++++----- namedisl/expression_like.py | 242 +++++++++++++++++++------- namedisl/test/test_expression_like.py | 92 +++++++--- pyproject.toml | 3 +- 4 files changed, 316 insertions(+), 138 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index 3f9ad36..31b2d66 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -38,29 +38,39 @@ import re from abc import ABC from collections.abc import Callable, Collection, Mapping, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, replace from importlib import metadata from typing import Generic, TypeAlias, TypeVar, cast, overload from constantdict import constantdict -from typing_extensions import Self, override +from typing_extensions import Self, TypeIs, override import islpy as isl IslBaseExpressionLike = isl.Aff | isl.QPolynomial IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial +IslScalarExpressionLike = IslBaseExpressionLike | IslPwExpressionLike IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff -MultiExpressionParts: TypeAlias = Mapping[int, object] +MultiExpressionPart: TypeAlias = ( + "NamedIslObject[isl.Aff, isl.Aff] | NamedIslObject[isl.PwAff, isl.PwAff]" +) +MultiExpressionParts: TypeAlias = Mapping[int, MultiExpressionPart] -IslExpressionLike = IslBaseExpressionLike | IslPwExpressionLike | IslMultiExpressionLike +IslExpressionLike = IslScalarExpressionLike | IslMultiExpressionLike IslSetLike = isl.BasicSet | isl.BasicMap | isl.Set | isl.Map -IslObject = IslSetLike | IslExpressionLike | IslMultiExpressionLike -InternalIslObject = IslObject | MultiExpressionParts +IslObject = IslSetLike | IslExpressionLike +RawInternalIslObject = IslSetLike | IslScalarExpressionLike +InternalIslObject = RawInternalIslObject | MultiExpressionParts IslExpressionLikeT = TypeVar( "IslExpressionLikeT", - bound=IslExpressionLike, + bound=IslScalarExpressionLike, +) +RawInternalIslObjectT_co = TypeVar( + "RawInternalIslObjectT_co", + bound=RawInternalIslObject, + covariant=True, ) InternalIslObjectT_co = TypeVar( "InternalIslObjectT_co", @@ -85,7 +95,7 @@ # alignment DimTypeToNames: TypeAlias = Mapping[isl.dim_type, frozenset[str]] -IslObjectPieces: TypeAlias = tuple[IslObject, NameToDim, DimTypeToNames] +IslObjectPieces: TypeAlias = tuple[RawInternalIslObject, NameToDim, DimTypeToNames] __version__ = metadata.version("namedisl") @@ -111,15 +121,21 @@ def _normalize_public_dim_type(dim_type: isl.dim_type) -> isl.dim_type: return dim_type -def _is_multi_expression_parts(obj: object) -> bool: - return isinstance(obj, Mapping) +def _is_multi_expression_parts(obj: object) -> TypeIs[MultiExpressionParts]: + if not isinstance(obj, Mapping): + return False + + mapping: Mapping[object, object] = obj + return all( + isinstance(dim, int) and isinstance(part, NamedIslObject) + for dim, part in mapping.items() + ) def _uses_explicit_input_metadata(obj: object) -> bool: - return ( - isinstance(obj, IslSetLike | IslMultiExpressionLike) - or _is_multi_expression_parts(obj) - ) + return isinstance( + obj, IslSetLike | IslMultiExpressionLike + ) or _is_multi_expression_parts(obj) def _ensure_unique_public_names(obj: IslObject) -> None: @@ -143,8 +159,8 @@ def _ensure_unique_public_names(obj: IslObject) -> None: def _strip_names( - obj: InternalIslObjectT_co, -) -> tuple[InternalIslObjectT_co, NameToDim]: + obj: RawInternalIslObjectT_co, +) -> tuple[RawInternalIslObjectT_co, NameToDim]: name_to_dim: dict[str, int] = {} dt_to_strip = isl.dim_type.set if isinstance(obj, IslSetLike) else isl.dim_type.in_ @@ -165,7 +181,7 @@ def _strip_names( name_to_dim[name] = i - return cast("InternalIslObjectT_co", stripped_obj), constantdict(name_to_dim) + return cast("RawInternalIslObjectT_co", stripped_obj), constantdict(name_to_dim) def _get_obj_dim_name(obj: IslObject, dt: isl.dim_type, dim: int) -> str: @@ -237,8 +253,8 @@ def _make_named_object_pieces(obj: IslObject) -> IslObjectPieces: def _restore_names( - obj: InternalIslObjectT_co, name_to_dim: NameToDim -) -> InternalIslObjectT_co: + obj: RawInternalIslObjectT_co, name_to_dim: NameToDim +) -> RawInternalIslObjectT_co: """ Return a copy of *obj* with dimension names restored from *name_to_dim*. @@ -266,7 +282,7 @@ def _restore_names( restored_obj = restored_union_pw_aff.get_pw_aff_list().get_at(0) return cast( - "InternalIslObjectT_co", + "RawInternalIslObjectT_co", restored_obj.move_dims( isl.dim_type.in_, 0, @@ -287,7 +303,7 @@ def _restore_names( if isinstance(restored_obj, isl.UnionPwAff | isl.UnionPwMultiAff): raise NotImplementedError - return cast("InternalIslObjectT_co", restored_obj) + return cast("RawInternalIslObjectT_co", restored_obj) def _get_dim_names(obj: IslObject, dt: isl.dim_type) -> frozenset[str]: @@ -316,10 +332,16 @@ def _deconstruct_object(obj: isl.PwMultiAff) -> tuple[isl.Set, DimTypeToNames]: @overload -def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: ... +def _deconstruct_object(obj: isl.MultiAff) -> tuple[isl.Set, DimTypeToNames]: ... + + +@overload +def _deconstruct_object( + obj: RawInternalIslObject, +) -> tuple[RawInternalIslObject, DimTypeToNames]: ... -def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: +def _deconstruct_object(obj: IslObject) -> tuple[RawInternalIslObject, DimTypeToNames]: """ Convert a public isl object into namedisl's internal representation. @@ -375,6 +397,7 @@ def _deconstruct_object(obj: IslObject) -> tuple[IslObject, DimTypeToNames]: decon_obj.dim(isl.dim_type.param), ) + assert not isinstance(decon_obj, IslMultiExpressionLike) return decon_obj, constantdict(dt_to_names) @@ -410,8 +433,8 @@ def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: # FIXME: think through whether or not alphabetical ordering will require more # work on average than using one of the objects as a template in alignment def _find_joint_name_to_dim( - obj1: NamedIslObject[IslObject, IslObject], - obj2: NamedIslObject[IslObject, IslObject], + obj1: NamedIslObject[InternalIslObject, IslObject], + obj2: NamedIslObject[InternalIslObject, IslObject], ) -> tuple[NameToDim, DimTypeToNames]: """ Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and @@ -494,8 +517,7 @@ def _align_obj( new_parts = constantdict({ new_dim: _align_obj( - cast("NamedIslObject[InternalIslObject, IslObject]", - named_obj._obj[named_obj._name_to_dim[name]]), + named_obj._obj[named_obj._name_to_dim[name]], part_ordering, part_dimtype_to_names, ) @@ -508,10 +530,7 @@ def _align_obj( dimtype_to_names, ) - new_isl_obj = cast( - "IslSetLike | IslBaseExpressionLike | IslPwExpressionLike | isl.MultiAff", - named_obj._obj, - ) + new_isl_obj = named_obj._obj running_name_to_dim = dict(named_obj._name_to_dim) target_dt = ( @@ -726,8 +745,7 @@ def _add_grouped_names( if not names_to_add: continue new_obj = constantdict({ - dim: cast("NamedIslObject[InternalIslObject, IslObject]", part) - .add_dim_names(names_to_add, dim_type) + dim: part.add_dim_names(names_to_add, dim_type) for dim, part in new_obj.items() }) @@ -747,10 +765,11 @@ def _add_grouped_names( has_inputs=True, ) - return type(self)( - new_obj, - new_name_to_dim, - new_dimtype_to_names, + return replace( + self, + _obj=new_obj, + _name_to_dim=new_name_to_dim, + _dimtype_to_names=new_dimtype_to_names, ) new_obj = self._obj @@ -977,18 +996,20 @@ def rename_dims(self, renaming: Mapping[str, str]) -> Self: new_obj = self._obj if _is_multi_expression_parts(new_obj): new_obj = constantdict({ - dim: cast("NamedIslObject[InternalIslObject, IslObject]", part) - .rename_dims({ + dim: part.rename_dims({ old_name: new_name for old_name, new_name in renaming.items() - if old_name in cast( - "NamedIslObject[InternalIslObject, IslObject]", part - ).names + if old_name in part.names }) for dim, part in new_obj.items() }) - return type(self)(new_obj, new_name_to_dim, new_dimtype_to_names) + return replace( + self, + _obj=new_obj, + _name_to_dim=new_name_to_dim, + _dimtype_to_names=new_dimtype_to_names, + ) @overload def equate_dims(self, name1: Mapping[str, str]) -> Self: ... @@ -1087,10 +1108,12 @@ def _reconstruct_isl_object(self) -> PublicIslObjectT_co: Relies on the dimension type ordering in :func:`_deconstruct_set_like_object`. """ - obj = cast( - "IslSetLike | IslBaseExpressionLike | IslPwExpressionLike | isl.MultiAff", - _restore_names(self._obj, self._name_to_dim), - ) + if _is_multi_expression_parts(self._obj): + raise NotImplementedError( + "multi-expression parts require subclass reconstruction" + ) + + obj = _restore_names(self._obj, self._name_to_dim) internal_dim = ( isl.dim_type.set if isinstance(obj, isl.Set) else isl.dim_type.in_ diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index 7a6ad8b..183b788 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -34,6 +34,7 @@ """ import operator +from collections.abc import Mapping from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TypeVar, cast, final, overload @@ -45,7 +46,7 @@ from .core import ( DimTypeToNames, IslExpressionLikeT, - MultiExpressionParts, + IslScalarExpressionLike, NamedIslObject, NameToDim, _align_two, @@ -62,6 +63,11 @@ isl.MultiAff, isl.PwMultiAff, ) +NamedScalarExpressionLikeT_co = TypeVar( + "NamedScalarExpressionLikeT_co", + bound=IslScalarExpressionLike, + covariant=True, +) def _add_isl_expression( @@ -108,7 +114,7 @@ def _explicitly_promote_isl_expressions( @dataclass(frozen=True, eq=False) class _NamedExpressionLike( - NamedIslObject[IslExpressionLikeT, IslExpressionLikeT] + NamedIslObject[NamedScalarExpressionLikeT_co, NamedScalarExpressionLikeT_co] ): @overload def __add__(self: Aff, other: Aff | int) -> Aff: ... @@ -128,19 +134,13 @@ def __add__( ) -> PwQPolynomial: ... def __add__( - self, other: _NamedExpressionLike[IslExpressionLikeT] | int - ) -> _NamedExpressionLike[IslExpressionLikeT]: + self: _NamedExpressionLike[IslScalarExpressionLike], + other: _NamedExpressionLike[IslScalarExpressionLike] | int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: """ Add another compatible named expression or an integer. """ - if isinstance(other, int): - return _wrap_expression_result( - _add_isl_expression(self._obj, other), - self._name_to_dim, - self._dimtype_to_names, - ) - - return _align_and_apply_expression_op(self, other, _add_isl_expression) + return _apply_expression_binary_op(self, other, _add_isl_expression) @overload def __radd__(self: Aff, other: int) -> Aff: ... @@ -154,15 +154,14 @@ def __radd__(self: QPolynomial, other: int) -> QPolynomial: ... @overload def __radd__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... - def __radd__(self, other: int) -> _NamedExpressionLike[IslExpressionLikeT]: + def __radd__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: """ Add this expression to an integer. """ - return _wrap_expression_result( - _radd_isl_expression(other, self._obj), - self._name_to_dim, - self._dimtype_to_names, - ) + return _apply_reflected_int_expression_op(self, other, _radd_isl_expression) @overload def __sub__(self: Aff, other: Aff | int) -> Aff: ... @@ -182,19 +181,13 @@ def __sub__( ) -> PwQPolynomial: ... def __sub__( - self, other: _NamedExpressionLike[IslExpressionLikeT] | int - ) -> _NamedExpressionLike[IslExpressionLikeT]: + self: _NamedExpressionLike[IslScalarExpressionLike], + other: _NamedExpressionLike[IslScalarExpressionLike] | int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: """ Subtract another compatible named expression or an integer. """ - if isinstance(other, int): - return _wrap_expression_result( - _sub_isl_expression(self._obj, other), - self._name_to_dim, - self._dimtype_to_names, - ) - - return _align_and_apply_expression_op(self, other, _sub_isl_expression) + return _apply_expression_binary_op(self, other, _sub_isl_expression) @overload def __rsub__(self: Aff, other: int) -> Aff: ... @@ -208,15 +201,14 @@ def __rsub__(self: QPolynomial, other: int) -> QPolynomial: ... @overload def __rsub__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... - def __rsub__(self, other: int) -> _NamedExpressionLike[IslExpressionLikeT]: + def __rsub__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: """ Subtract this expression from an integer. """ - return _wrap_expression_result( - _rsub_isl_expression(other, self._obj), - self._name_to_dim, - self._dimtype_to_names, - ) + return _apply_reflected_int_expression_op(self, other, _rsub_isl_expression) @overload def __mul__(self: Aff, other: Aff | int) -> Aff: ... @@ -236,19 +228,13 @@ def __mul__( ) -> PwQPolynomial: ... def __mul__( - self, other: _NamedExpressionLike[IslExpressionLikeT] | int - ) -> _NamedExpressionLike[IslExpressionLikeT]: + self: _NamedExpressionLike[IslScalarExpressionLike], + other: _NamedExpressionLike[IslScalarExpressionLike] | int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: """ Multiply by another compatible named expression or an integer. """ - if isinstance(other, int): - return _wrap_expression_result( - _mul_isl_expression(self._obj, other), - self._name_to_dim, - self._dimtype_to_names, - ) - - return _align_and_apply_expression_op(self, other, _mul_isl_expression) + return _apply_expression_binary_op(self, other, _mul_isl_expression) @overload def __rmul__(self: Aff, other: int) -> Aff: ... @@ -262,15 +248,14 @@ def __rmul__(self: QPolynomial, other: int) -> QPolynomial: ... @overload def __rmul__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... - def __rmul__(self, other: int) -> _NamedExpressionLike[IslExpressionLikeT]: + def __rmul__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: """ Multiply this expression by an integer. """ - return _wrap_expression_result( - _rmul_isl_expression(other, self._obj), - self._name_to_dim, - self._dimtype_to_names, - ) + return _apply_reflected_int_expression_op(self, other, _rmul_isl_expression) def is_zero(self) -> bool: """ @@ -286,7 +271,7 @@ def __eq__(self, other: object) -> bool: @dataclass(frozen=True, eq=False) -class _NamedPwExpressionLike(_NamedExpressionLike[IslExpressionLikeT]): +class _NamedPwExpressionLike(_NamedExpressionLike[NamedScalarExpressionLikeT_co]): ... @@ -482,11 +467,87 @@ def _wrap_expression_result( raise TypeError(f"unsupported expression result type: {type(result).__name__}") -def _align_and_apply_expression_op( - lhs: _NamedExpressionLike[IslExpressionLikeT], - rhs: _NamedExpressionLike[IslExpressionLikeT], - op: Callable[[IslExpressionLikeT, IslExpressionLikeT], IslExpressionLikeT], -) -> _NamedExpressionLike[IslExpressionLikeT]: +@overload +def _apply_expression_binary_op( + lhs: Aff, + rhs: Aff | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> Aff: ... + + +@overload +def _apply_expression_binary_op( + lhs: Aff, + rhs: PwAff, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> PwAff: ... + + +@overload +def _apply_expression_binary_op( + lhs: PwAff, + rhs: Aff | PwAff | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> PwAff: ... + + +@overload +def _apply_expression_binary_op( + lhs: QPolynomial, + rhs: QPolynomial | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> QPolynomial: ... + + +@overload +def _apply_expression_binary_op( + lhs: PwQPolynomial, + rhs: PwQPolynomial | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> PwQPolynomial: ... + + +@overload +def _apply_expression_binary_op( + lhs: _NamedExpressionLike[IslScalarExpressionLike], + rhs: _NamedExpressionLike[IslScalarExpressionLike] | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> _NamedExpressionLike[IslScalarExpressionLike]: ... + + +def _apply_expression_binary_op( + lhs: _NamedExpressionLike[IslScalarExpressionLike], + rhs: _NamedExpressionLike[IslScalarExpressionLike] | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> _NamedExpressionLike[IslScalarExpressionLike]: + if isinstance(rhs, int): + return _wrap_expression_result( + op(lhs._obj, rhs), + lhs._name_to_dim, + lhs._dimtype_to_names, + ) + lhs, rhs = _align_two(lhs, rhs) lhs_obj, rhs_obj = _explicitly_promote_isl_expressions(lhs._obj, rhs._obj) result = op(lhs_obj, rhs_obj) @@ -496,6 +557,58 @@ def _align_and_apply_expression_op( lhs._dimtype_to_names, ) + +@overload +def _apply_reflected_int_expression_op( + expr: Aff, + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> Aff: ... + + +@overload +def _apply_reflected_int_expression_op( + expr: PwAff, + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> PwAff: ... + + +@overload +def _apply_reflected_int_expression_op( + expr: QPolynomial, + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> QPolynomial: ... + + +@overload +def _apply_reflected_int_expression_op( + expr: PwQPolynomial, + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> PwQPolynomial: ... + + +@overload +def _apply_reflected_int_expression_op( + expr: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> _NamedExpressionLike[IslScalarExpressionLike]: ... + + +def _apply_reflected_int_expression_op( + expr: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> _NamedExpressionLike[IslScalarExpressionLike]: + return _wrap_expression_result( + op(other, expr._obj), + expr._name_to_dim, + expr._dimtype_to_names, + ) + # }}} @@ -516,7 +629,7 @@ def _ordered_multi_dim_names( def _make_multi_expression_parts( obj: isl.MultiAff | isl.PwMultiAff, -) -> tuple[MultiExpressionParts, NameToDim, DimTypeToNames]: +) -> tuple[Mapping[int, PwAff], NameToDim, DimTypeToNames]: output_names = _ordered_multi_dim_names(obj, isl.dim_type.out) input_names = _ordered_multi_dim_names(obj, isl.dim_type.in_) parameter_names = _ordered_multi_dim_names(obj, isl.dim_type.param) @@ -527,7 +640,7 @@ def _make_multi_expression_parts( raise ValueError(f"duplicate dimension name found: {name}") seen_names.add(name) - parts: MultiExpressionParts = constantdict({ + parts: Mapping[int, PwAff] = constantdict({ dim: make_pw_aff( obj.get_at(dim).to_pw_aff() if isinstance(obj, isl.MultiAff) @@ -552,19 +665,18 @@ def _make_multi_expression_parts( @dataclass(frozen=True, eq=False) class _NamedMultiExpressionLike( - NamedIslObject[MultiExpressionParts, PublicMultiExpressionLikeT] + NamedIslObject[Mapping[int, PwAff], PublicMultiExpressionLikeT] ): """ Multi-expression components are stored directly as named :class:`PwAff` parts, keyed by output dimension. """ - _obj: MultiExpressionParts + _obj: Mapping[int, PwAff] def _multi_expression_context(self) -> isl.Context: if self._obj: - first_part = cast("PwAff", next(iter(self._obj.values()))) - return first_part._obj.get_ctx() + return next(iter(self._obj.values()))._obj.get_ctx() return isl.DEFAULT_CONTEXT def _multi_expression_space(self) -> isl.Space: @@ -577,7 +689,7 @@ def _multi_expression_space(self) -> isl.Space: def _ordered_pw_aff_parts(self) -> tuple[isl.PwAff, ...]: return tuple( - cast("PwAff", self._obj[dim])._reconstruct_isl_object() + self._obj[dim]._reconstruct_isl_object() for dim in range(self.dim(isl.dim_type.out)) ) @@ -597,7 +709,7 @@ def get_at(self, name: str) -> PwAff: """ if name not in self._names_for_dim_type(isl.dim_type.set): raise ValueError(f"unknown output name: {name}") - return cast("PwAff", self._obj[self._name_to_dim[name]]) + return self._obj[self._name_to_dim[name]] @override def _reconstruct_isl_object(self) -> isl.PwMultiAff: @@ -655,7 +767,7 @@ def get_at(self, name: str) -> PwAff: """ if name not in self._names_for_dim_type(isl.dim_type.set): raise ValueError(f"unknown output name: {name}") - return cast("PwAff", self._obj[self._name_to_dim[name]]) + return self._obj[self._name_to_dim[name]] @override def _reconstruct_isl_object(self) -> isl.MultiAff: diff --git a/namedisl/test/test_expression_like.py b/namedisl/test/test_expression_like.py index ed06449..e3a3178 100644 --- a/namedisl/test/test_expression_like.py +++ b/namedisl/test/test_expression_like.py @@ -33,6 +33,17 @@ import namedisl as nisl +ScalarIslExpression = isl.Aff | isl.PwAff | isl.QPolynomial | isl.PwQPolynomial + + +def _is_zero_expression(expr: ScalarIslExpression) -> bool: + if isinstance(expr, isl.Aff): + return bool(expr.plain_is_zero()) + if isinstance(expr, isl.PwAff): + return bool(expr.plain_is_equal(expr * 0)) + return bool(expr.is_zero()) + + # {{{ affs def test_aff_from_str(): @@ -160,40 +171,73 @@ def test_expression_equality_type_mismatch_raises_not_implemented_error() -> Non def test_reflected_integer_expression_ops() -> None: - def is_zero( - expr: isl.Aff | isl.PwAff | isl.QPolynomial | isl.PwQPolynomial, - ) -> bool: - if isinstance(expr, isl.Aff): - return bool(expr.plain_is_zero()) - if isinstance(expr, isl.PwAff): - return bool(expr.plain_is_equal(expr * 0)) - return bool(expr.is_zero()) - - expressions = [ - nisl.make_aff("{ [i] -> [i] }"), - nisl.make_pw_aff("{ [i] -> [i] }"), - nisl.make_qpolynomial("{ [i] -> i }"), - nisl.make_pw_qpolynomial("{ [i] -> i }"), - ] - - for expr in expressions: - obj = expr._reconstruct_isl_object() - - assert is_zero((1 + expr)._reconstruct_isl_object() - (1 + obj)) - assert is_zero((1 - expr)._reconstruct_isl_object() - (1 - obj)) - assert is_zero((2 * expr)._reconstruct_isl_object() - (2 * obj)) + aff_expr = nisl.make_aff("{ [i] -> [i] }") + aff_obj = aff_expr._reconstruct_isl_object() + assert _is_zero_expression((1 + aff_expr)._reconstruct_isl_object() - (1 + aff_obj)) + assert _is_zero_expression((1 - aff_expr)._reconstruct_isl_object() - (1 - aff_obj)) + assert _is_zero_expression((2 * aff_expr)._reconstruct_isl_object() - (2 * aff_obj)) + + pw_aff_expr = nisl.make_pw_aff("{ [i] -> [i] }") + pw_aff_obj = pw_aff_expr._reconstruct_isl_object() + assert _is_zero_expression( + (1 + pw_aff_expr)._reconstruct_isl_object() - (1 + pw_aff_obj) + ) + assert _is_zero_expression( + (1 - pw_aff_expr)._reconstruct_isl_object() - (1 - pw_aff_obj) + ) + assert _is_zero_expression( + (2 * pw_aff_expr)._reconstruct_isl_object() - (2 * pw_aff_obj) + ) + + qpoly_expr = nisl.make_qpolynomial("{ [i] -> i }") + qpoly_obj = qpoly_expr._reconstruct_isl_object() + assert _is_zero_expression( + (1 + qpoly_expr)._reconstruct_isl_object() - (1 + qpoly_obj) + ) + assert _is_zero_expression( + (1 - qpoly_expr)._reconstruct_isl_object() - (1 - qpoly_obj) + ) + assert _is_zero_expression( + (2 * qpoly_expr)._reconstruct_isl_object() - (2 * qpoly_obj) + ) + + pw_qpoly_expr = nisl.make_pw_qpolynomial("{ [i] -> i }") + pw_qpoly_obj = pw_qpoly_expr._reconstruct_isl_object() + assert _is_zero_expression( + (1 + pw_qpoly_expr)._reconstruct_isl_object() - (1 + pw_qpoly_obj) + ) + assert _is_zero_expression( + (1 - pw_qpoly_expr)._reconstruct_isl_object() - (1 - pw_qpoly_obj) + ) + assert _is_zero_expression( + (2 * pw_qpoly_expr)._reconstruct_isl_object() - (2 * pw_qpoly_obj) + ) def _qpolynomial(spec: str) -> isl.QPolynomial: return isl.PwQPolynomial(spec).get_pieces()[0][1] -def _assert_expression_equal(actual, expected) -> None: +def _assert_expression_equal( + actual: ScalarIslExpression, expected: ScalarIslExpression +) -> None: if isinstance(actual, isl.Aff | isl.PwAff): assert actual == expected return - assert (actual - expected).is_zero() + if isinstance(actual, isl.QPolynomial) and isinstance(expected, isl.QPolynomial): + assert (actual - expected).is_zero() + return + + if isinstance(actual, isl.PwQPolynomial) and isinstance( + expected, isl.PwQPolynomial + ): + assert (actual - expected).is_zero() + return + + raise TypeError( + f"cannot compare {type(actual).__name__} and {type(expected).__name__}" + ) def test_move_dims_expression_param_to_input_reconstructs_like_isl() -> None: diff --git a/pyproject.toml b/pyproject.toml index 07f33c4..cb46763 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ dependencies = [ "constantdict", "islpy", - "typing-extensions>=4.5", + "typing-extensions>=4.10", ] [project.urls] @@ -111,4 +111,3 @@ exclude = [ ".conda-root", ".env", ] - From aa393de655cfe3f4064981a54a52a8520f3be2ad Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Jun 2026 13:43:26 -0500 Subject: [PATCH 42/43] Make multi objects container-like data structures separate from NamedIslObjects --- namedisl/core.py | 154 ++---------------------- namedisl/expression_like.py | 162 ++++++++++++++++++++++---- namedisl/test/test_expression_like.py | 52 ++------- 3 files changed, 162 insertions(+), 206 deletions(-) diff --git a/namedisl/core.py b/namedisl/core.py index 31b2d66..a372247 100644 --- a/namedisl/core.py +++ b/namedisl/core.py @@ -38,12 +38,12 @@ import re from abc import ABC from collections.abc import Callable, Collection, Mapping, Sequence -from dataclasses import dataclass, replace +from dataclasses import dataclass from importlib import metadata from typing import Generic, TypeAlias, TypeVar, cast, overload from constantdict import constantdict -from typing_extensions import Self, TypeIs, override +from typing_extensions import Self, override import islpy as isl @@ -52,16 +52,12 @@ IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial IslScalarExpressionLike = IslBaseExpressionLike | IslPwExpressionLike IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff -MultiExpressionPart: TypeAlias = ( - "NamedIslObject[isl.Aff, isl.Aff] | NamedIslObject[isl.PwAff, isl.PwAff]" -) -MultiExpressionParts: TypeAlias = Mapping[int, MultiExpressionPart] -IslExpressionLike = IslScalarExpressionLike | IslMultiExpressionLike +IslExpressionLike = IslScalarExpressionLike IslSetLike = isl.BasicSet | isl.BasicMap | isl.Set | isl.Map IslObject = IslSetLike | IslExpressionLike RawInternalIslObject = IslSetLike | IslScalarExpressionLike -InternalIslObject = RawInternalIslObject | MultiExpressionParts +InternalIslObject = RawInternalIslObject IslExpressionLikeT = TypeVar( "IslExpressionLikeT", @@ -121,25 +117,12 @@ def _normalize_public_dim_type(dim_type: isl.dim_type) -> isl.dim_type: return dim_type -def _is_multi_expression_parts(obj: object) -> TypeIs[MultiExpressionParts]: - if not isinstance(obj, Mapping): - return False - - mapping: Mapping[object, object] = obj - return all( - isinstance(dim, int) and isinstance(part, NamedIslObject) - for dim, part in mapping.items() - ) - - def _uses_explicit_input_metadata(obj: object) -> bool: - return isinstance( - obj, IslSetLike | IslMultiExpressionLike - ) or _is_multi_expression_parts(obj) + return isinstance(obj, IslSetLike) def _ensure_unique_public_names(obj: IslObject) -> None: - if isinstance(obj, IslSetLike | IslMultiExpressionLike): + if isinstance(obj, IslSetLike): dim_types = (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param) else: dim_types = (isl.dim_type.in_, isl.dim_type.param) @@ -199,7 +182,7 @@ def _get_obj_dim_name(obj: IslObject, dt: isl.dim_type, dim: int) -> str: def _normalize_dimtype_to_names( obj: IslObject, dimtype_to_names: DimTypeToNames ) -> DimTypeToNames: - if isinstance(obj, IslSetLike | IslMultiExpressionLike): + if isinstance(obj, IslSetLike): dim_type = isl.dim_type.set total_dims = obj.dim(dim_type) n_in = len(dimtype_to_names.get(isl.dim_type.in_, frozenset())) @@ -326,15 +309,6 @@ def _get_dim_names(obj: IslObject, dt: isl.dim_type) -> frozenset[str]: def _deconstruct_object(obj: isl.Map) -> tuple[isl.Set, DimTypeToNames]: ... -@overload -# PwMultiAff doesn't have move_dims, so we're being a bit crooked here. -def _deconstruct_object(obj: isl.PwMultiAff) -> tuple[isl.Set, DimTypeToNames]: ... - - -@overload -def _deconstruct_object(obj: isl.MultiAff) -> tuple[isl.Set, DimTypeToNames]: ... - - @overload def _deconstruct_object( obj: RawInternalIslObject, @@ -351,15 +325,10 @@ def _deconstruct_object(obj: IslObject) -> tuple[RawInternalIslObject, DimTypeTo """ dt_to_names: dict[isl.dim_type, frozenset[str]] = {} - if isinstance(obj, IslSetLike | IslMultiExpressionLike): + if isinstance(obj, IslSetLike): decon_obj = obj dt_to_names = dict.fromkeys([isl.dim_type.in_, isl.dim_type.param], frozenset()) - # NOTE: isl.PwMultiAff.move_dims does not exist, represent as map - # internally - if isinstance(decon_obj, IslMultiExpressionLike): - decon_obj = decon_obj.as_map() - for dt in dt_to_names: dt_to_names[dt] = _get_dim_names(decon_obj, dt) if dt_to_names[dt]: @@ -397,7 +366,6 @@ def _deconstruct_object(obj: IslObject) -> tuple[RawInternalIslObject, DimTypeTo decon_obj.dim(isl.dim_type.param), ) - assert not isinstance(decon_obj, IslMultiExpressionLike) return decon_obj, constantdict(dt_to_names) @@ -489,47 +457,6 @@ def _align_obj( The isl object is moved or expanded as needed, but public names are carried in metadata and are restored only when a raw isl object is reconstructed. """ - if _is_multi_expression_parts(named_obj._obj): - old_output_names = named_obj._ordered_names_for_dim_type(isl.dim_type.set) - output_names = tuple( - name - for name, _ in sorted(ordering.items(), key=lambda x: x[1]) - if name not in dimtype_to_names.get(isl.dim_type.in_, frozenset()) - and name not in dimtype_to_names.get(isl.dim_type.param, frozenset()) - ) - - if set(output_names) != set(old_output_names): - raise NotImplementedError( - "moving dimensions between output and non-output dimensions " - "is not implemented for multi expressions" - ) - - part_ordering: NameToDim = constantdict({ - name: dim - len(output_names) - for name, dim in ordering.items() - if name not in output_names - }) - part_dimtype_to_names: DimTypeToNames = constantdict({ - isl.dim_type.param: dimtype_to_names.get( - isl.dim_type.param, frozenset() - ) - }) - - new_parts = constantdict({ - new_dim: _align_obj( - named_obj._obj[named_obj._name_to_dim[name]], - part_ordering, - part_dimtype_to_names, - ) - for new_dim, name in enumerate(output_names) - }) - - return type(named_obj)( - new_parts, - ordering, - dimtype_to_names, - ) - new_isl_obj = named_obj._obj running_name_to_dim = dict(named_obj._name_to_dim) @@ -732,46 +659,6 @@ def _add_grouped_names( raise ValueError(f"name already exists: {name}") seen_names.add(name) - if _is_multi_expression_parts(self._obj): - if grouped_names[isl.dim_type.set]: - raise NotImplementedError( - "adding output dimensions to multi expressions is not " - "implemented" - ) - - new_obj = self._obj - for dim_type in (isl.dim_type.param, isl.dim_type.in_): - names_to_add = grouped_names[dim_type] - if not names_to_add: - continue - new_obj = constantdict({ - dim: part.add_dim_names(names_to_add, dim_type) - for dim, part in new_obj.items() - }) - - chunk_names = { - dt: list(names) for dt, names in self._ordered_name_chunks().items() - } - for dim_type in (isl.dim_type.param, isl.dim_type.in_): - names_to_add = grouped_names[dim_type] - if names_to_add: - chunk_names[dim_type] = [ - *names_to_add, - *chunk_names[dim_type], - ] - - new_name_to_dim, new_dimtype_to_names = self._metadata_from_chunk_names( - chunk_names, - has_inputs=True, - ) - - return replace( - self, - _obj=new_obj, - _name_to_dim=new_name_to_dim, - _dimtype_to_names=new_dimtype_to_names, - ) - new_obj = self._obj chunk_names = { dt: list(names) for dt, names in self._ordered_name_chunks().items() @@ -993,22 +880,10 @@ def rename_dims(self, renaming: Mapping[str, str]) -> Self: for dim_type, names in self._dimtype_to_names.items() }) - new_obj = self._obj - if _is_multi_expression_parts(new_obj): - new_obj = constantdict({ - dim: part.rename_dims({ - old_name: new_name - for old_name, new_name in renaming.items() - if old_name in part.names - }) - for dim, part in new_obj.items() - }) - - return replace( - self, - _obj=new_obj, - _name_to_dim=new_name_to_dim, - _dimtype_to_names=new_dimtype_to_names, + return type(self)( + self._obj, + new_name_to_dim, + new_dimtype_to_names, ) @overload @@ -1108,11 +983,6 @@ def _reconstruct_isl_object(self) -> PublicIslObjectT_co: Relies on the dimension type ordering in :func:`_deconstruct_set_like_object`. """ - if _is_multi_expression_parts(self._obj): - raise NotImplementedError( - "multi-expression parts require subclass reconstruction" - ) - obj = _restore_names(self._obj, self._name_to_dim) internal_dim = ( diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index 183b788..16e5e4a 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -36,7 +36,7 @@ import operator from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar, cast, final, overload +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, final, overload from constantdict import constantdict from typing_extensions import override @@ -51,6 +51,7 @@ NameToDim, _align_two, _make_named_object_pieces, + _normalize_public_dim_type, ) @@ -629,26 +630,31 @@ def _ordered_multi_dim_names( def _make_multi_expression_parts( obj: isl.MultiAff | isl.PwMultiAff, -) -> tuple[Mapping[int, PwAff], NameToDim, DimTypeToNames]: +) -> tuple[Mapping[str, PwAff], NameToDim, DimTypeToNames]: output_names = _ordered_multi_dim_names(obj, isl.dim_type.out) - input_names = _ordered_multi_dim_names(obj, isl.dim_type.in_) - parameter_names = _ordered_multi_dim_names(obj, isl.dim_type.param) - seen_names: set[str] = set() - for name in (*output_names, *input_names, *parameter_names): - if name in seen_names: - raise ValueError(f"duplicate dimension name found: {name}") - seen_names.add(name) - - parts: Mapping[int, PwAff] = constantdict({ - dim: make_pw_aff( + parts: Mapping[str, PwAff] = constantdict({ + name: make_pw_aff( obj.get_at(dim).to_pw_aff() if isinstance(obj, isl.MultiAff) else obj.get_at(dim) ) - for dim in range(obj.dim(isl.dim_type.out)) + for dim, name in enumerate(output_names) }) + if parts: + input_names = _ordered_part_dim_names(parts, isl.dim_type.in_) + parameter_names = _ordered_part_dim_names(parts, isl.dim_type.param) + else: + input_names = _ordered_multi_dim_names(obj, isl.dim_type.in_) + parameter_names = _ordered_multi_dim_names(obj, isl.dim_type.param) + + seen_names: set[str] = set() + for name in (*output_names, *input_names, *parameter_names): + if name in seen_names: + raise ValueError(f"duplicate dimension name found: {name}") + seen_names.add(name) + all_names = [*output_names, *input_names, *parameter_names] name_to_dim: NameToDim = constantdict({ name: dim for dim, name in enumerate(all_names) @@ -663,16 +669,128 @@ def _make_multi_expression_parts( return parts, name_to_dim, constantdict(dimtype_to_names) +def _ordered_part_dim_names( + parts: Mapping[str, PwAff], + dim_type: isl.dim_type, +) -> tuple[str, ...]: + part_iter = iter(parts.items()) + _, first_part = next(part_iter) + ordered_names = first_part.ordered_dim_names(dim_type) + + for output_name, part in part_iter: + part_ordered_names = part.ordered_dim_names(dim_type) + if part_ordered_names != ordered_names: + raise ValueError( + f"multi expression part '{output_name}' has inconsistent " + f"{dim_type.name} dimension names" + ) + + return ordered_names + + @dataclass(frozen=True, eq=False) -class _NamedMultiExpressionLike( - NamedIslObject[Mapping[int, PwAff], PublicMultiExpressionLikeT] -): +class _NamedMultiExpressionLike(Generic[PublicMultiExpressionLikeT]): """ Multi-expression components are stored directly as named :class:`PwAff` - parts, keyed by output dimension. + parts, keyed by output name. """ - _obj: Mapping[int, PwAff] + _obj: Mapping[str, PwAff] + _name_to_dim: NameToDim + _dimtype_to_names: DimTypeToNames + + @property + def _metadata_input_names(self) -> frozenset[str]: + return self._dimtype_to_names.get(isl.dim_type.in_, frozenset()) + + @property + def _metadata_parameter_names(self) -> frozenset[str]: + return self._dimtype_to_names.get(isl.dim_type.param, frozenset()) + + def _names_for_dim_type(self, dim_type: isl.dim_type) -> frozenset[str]: + dim_type = _normalize_public_dim_type(dim_type) + if dim_type == isl.dim_type.param: + return self.parameter_names + if dim_type == isl.dim_type.in_: + return self._metadata_input_names + if dim_type == isl.dim_type.set: + return frozenset(self._obj) + raise ValueError(f"unsupported dim type: {dim_type}") + + def _ordered_names_for_dim_type(self, dim_type: isl.dim_type) -> tuple[str, ...]: + names = self._names_for_dim_type(dim_type) + return tuple(sorted(names, key=self._name_to_dim.__getitem__)) + + @property + def names(self) -> frozenset[str]: + """ + All dimension names known to this object. + """ + return self.output_names | self.input_names | self.parameter_names + + def dim_names(self, dim_type: isl.dim_type) -> frozenset[str]: + """ + Return the names belonging to *dim_type*. + """ + return self._names_for_dim_type(dim_type) + + def ordered_dim_names(self, dim_type: isl.dim_type) -> tuple[str, ...]: + """ + Return names for *dim_type* in their current dimension order. + """ + return self._ordered_names_for_dim_type(dim_type) + + @property + def set_names(self) -> frozenset[str]: + """ + Names of set dimensions. + """ + return self._names_for_dim_type(isl.dim_type.set) + + @property + def output_names(self) -> frozenset[str]: + """ + Names of output dimensions. + """ + return self._names_for_dim_type(isl.dim_type.out) + + @property + def input_names(self) -> frozenset[str]: + """ + Names of input dimensions. + """ + return self._names_for_dim_type(isl.dim_type.in_) + + @property + def parameter_names(self) -> frozenset[str]: + """ + Names of parameter dimensions. + """ + return self._metadata_parameter_names + + def dim(self, dim_type: isl.dim_type) -> int: + """ + Return the number of dimensions of *dim_type*. + """ + dim_type = _normalize_public_dim_type(dim_type) + if dim_type in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): + return len(self._names_for_dim_type(dim_type)) + return self._reconstruct_isl_object().dim(dim_type) + + def get_space(self) -> isl.Space: + """ + Reconstruct and return the object's public isl space. + """ + return self._reconstruct_isl_object().get_space() + + def get_isl_object(self) -> PublicMultiExpressionLikeT: + """ + Reconstruct and return the wrapped public :mod:`islpy` object. + """ + return self._reconstruct_isl_object() + + def _reconstruct_isl_object(self) -> PublicMultiExpressionLikeT: + raise NotImplementedError def _multi_expression_context(self) -> isl.Context: if self._obj: @@ -689,8 +807,8 @@ def _multi_expression_space(self) -> isl.Space: def _ordered_pw_aff_parts(self) -> tuple[isl.PwAff, ...]: return tuple( - self._obj[dim]._reconstruct_isl_object() - for dim in range(self.dim(isl.dim_type.out)) + self._obj[name]._reconstruct_isl_object() + for name in self.ordered_dim_names(isl.dim_type.out) ) @@ -709,7 +827,7 @@ def get_at(self, name: str) -> PwAff: """ if name not in self._names_for_dim_type(isl.dim_type.set): raise ValueError(f"unknown output name: {name}") - return self._obj[self._name_to_dim[name]] + return self._obj[name] @override def _reconstruct_isl_object(self) -> isl.PwMultiAff: @@ -767,7 +885,7 @@ def get_at(self, name: str) -> PwAff: """ if name not in self._names_for_dim_type(isl.dim_type.set): raise ValueError(f"unknown output name: {name}") - return self._obj[self._name_to_dim[name]] + return self._obj[name] @override def _reconstruct_isl_object(self) -> isl.MultiAff: diff --git a/namedisl/test/test_expression_like.py b/namedisl/test/test_expression_like.py index e3a3178..45b5c08 100644 --- a/namedisl/test/test_expression_like.py +++ b/namedisl/test/test_expression_like.py @@ -334,8 +334,11 @@ def test_multi_aff_stores_pw_aff_parts() -> None: assert isinstance(maff._obj, constantdict) assert not isinstance(maff._obj, isl.Set) assert all(isinstance(part, nisl.PwAff) for part in maff._obj.values()) - assert maff.get_at("x") is maff._obj[0] - assert maff.get_at("y") is maff._obj[1] + assert maff.get_at("x") is maff._obj["x"] + assert maff.get_at("y") is maff._obj["y"] + assert maff.output_names == frozenset(maff._obj) + assert maff.input_names == maff.get_at("x").input_names + assert maff.parameter_names == maff.get_at("x").parameter_names assert maff._reconstruct_isl_object() == raw_maff @@ -347,49 +350,14 @@ def test_pw_multi_aff_stores_pw_aff_parts() -> None: assert isinstance(pmaff._obj, constantdict) assert not isinstance(pmaff._obj, isl.Set) assert all(isinstance(part, nisl.PwAff) for part in pmaff._obj.values()) - assert pmaff.get_at("x") is pmaff._obj[0] - assert pmaff.get_at("y") is pmaff._obj[1] + assert pmaff.get_at("x") is pmaff._obj["x"] + assert pmaff.get_at("y") is pmaff._obj["y"] + assert pmaff.output_names == frozenset(pmaff._obj) + assert pmaff.input_names == pmaff.get_at("x").input_names + assert pmaff.parameter_names == pmaff.get_at("x").parameter_names assert pmaff._reconstruct_isl_object() == raw_pmaff -def test_multi_aff_rename_dims_updates_stored_parts() -> None: - maff = nisl.make_multi_aff("[n] -> { [i] -> [x = i + n] }") - - renamed = maff.rename_dims({"x": "z", "i": "j", "n": "m"}) - - part = renamed.get_at("z") - assert part.input_names == frozenset({"j"}) - assert part.parameter_names == frozenset({"m"}) - assert part._reconstruct_isl_object() == isl.PwAff("[m] -> { [j] -> [(j + m)] }") - - -def test_multi_aff_move_dims_updates_stored_parts() -> None: - maff = nisl.make_multi_aff("[n] -> { [i] -> [x = i + n] }") - - moved = maff.move_dims("n", isl.dim_type.in_) - - part = moved.get_at("x") - assert moved.input_names == frozenset({"i", "n"}) - assert moved.parameter_names == frozenset() - assert part.input_names == frozenset({"i", "n"}) - assert part.parameter_names == frozenset() - assert part._reconstruct_isl_object() == isl.PwAff("{ [i, n] -> [(i + n)] }") - - -def test_pw_multi_aff_named_operations_update_stored_parts() -> None: - pmaff = nisl.make_pw_multi_aff("[n] -> { [i] -> [x = i + n] }") - - renamed = pmaff.rename_dims({"x": "z", "i": "j", "n": "m"}) - moved = renamed.move_dims("m", isl.dim_type.in_) - - part = moved.get_at("z") - assert moved.input_names == frozenset({"j", "m"}) - assert moved.parameter_names == frozenset() - assert part.input_names == frozenset({"j", "m"}) - assert part.parameter_names == frozenset() - assert part._reconstruct_isl_object() == isl.PwAff("{ [j, m] -> [(j + m)] }") - - # {{{ qpolynomials def test_qpolynomial_from_str(): From 89b020b1a8b0a2a38d3a91cf9bae9219eefe661a Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Jun 2026 14:29:40 -0500 Subject: [PATCH 43/43] fix ruff complaint --- namedisl/expression_like.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py index 16e5e4a..5d75e52 100644 --- a/namedisl/expression_like.py +++ b/namedisl/expression_like.py @@ -34,7 +34,6 @@ """ import operator -from collections.abc import Mapping from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, final, overload @@ -56,7 +55,7 @@ if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Mapping PublicMultiExpressionLikeT = TypeVar(