diff --git a/.gitignore b/.gitignore index ea396a7..a20f1e1 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ distribute*tar.gz .pylintrc.yml .run-pylint.py + +.codex diff --git a/doc/conf.py b/doc/conf.py index a27710b..b9fa5f5 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,31 +1,37 @@ from __future__ import annotations +import sys +from importlib import metadata +from pathlib import Path from urllib.request import urlopen -_conf_url = \ - "https://raw.githubusercontent.com/inducer/sphinxconfig/main/sphinxconfig.py" +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +_conf_url = "https://tiker.net/sphinxconfig-v0.py" with urlopen(_conf_url) as _inf: exec(compile(_inf.read(), _conf_url, "exec"), globals()) -copyright = "2025- University of Illiois Board of Trustees" -author = "Andreas Kloeckner" +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", +] -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -ver_dic = {} -with open("../namedisl/__init__.py") as vfile: - exec(compile(vfile.read(), "../namedisl/__init__.py", "exec"), ver_dic) +autodoc_member_order = "bysource" +autodoc_typehints = "none" + +project = "namedisl" +copyright = "2025- University of Illinois Board of Trustees" +author = "Andreas Kloeckner" -version = ".".join(str(x) for x in ver_dic["__version__"]) -release = ver_dic["__version__"] +release = metadata.version("namedisl") +version = ".".join(release.split(".")[:2]) intersphinx_mapping = { "islpy": ("https://documen.tician.de/islpy", None), "constantdict": ("https://matthiasdiener.github.io/constantdict/", None), + "python": ("https://docs.python.org/3", None), } nitpicky = True diff --git a/doc/index.rst b/doc/index.rst index 0f429a7..28d47f6 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,11 +1,17 @@ Welcome to namedisl's documentation! ==================================== +namedisl is a small wrapper around :mod:`islpy` for code that needs dimension +names to carry semantic meaning. It keeps isl's integer-set operations while +offering a name-oriented interface for sets, maps, affine expressions, and +piecewise affine expressions. + .. toctree:: :maxdepth: 2 :caption: Contents: ref + internals 🚀 Github 💾 Download Releases diff --git a/doc/internals.rst b/doc/internals.rst new file mode 100644 index 0000000..267bdcf --- /dev/null +++ b/doc/internals.rst @@ -0,0 +1,18 @@ +Internal Reference +================== + +These interfaces support the public wrappers and are mainly useful when +maintaining namedisl itself. + +.. automodule:: namedisl.core + +.. autoclass:: namedisl.core.NamedIslObject + :members: + +.. autofunction:: namedisl.core._deconstruct_object + +.. autofunction:: namedisl.core._make_named_object_pieces + +.. autofunction:: namedisl.core._align_two + +.. autofunction:: namedisl.core._align_and_apply_binary_op diff --git a/doc/ref.rst b/doc/ref.rst index 56fd5a9..9d60a32 100644 --- a/doc/ref.rst +++ b/doc/ref.rst @@ -2,3 +2,89 @@ Reference ========= .. automodule:: namedisl + +Set and Map Objects +------------------- + +.. autofunction:: make_basic_set + +.. autofunction:: make_set + +.. autofunction:: make_basic_map + +.. autofunction:: make_map + +.. autofunction:: make_map_from_domain_and_range + +.. autoclass:: BasicSet + :members: + :inherited-members: + :special-members: __and__, __or__, __sub__, __lt__, __le__, __gt__, __ge__ + :exclude-members: __init__, __new__ + +.. autoclass:: Set + :members: + :inherited-members: + :special-members: __and__, __or__, __sub__, __lt__, __le__, __gt__, __ge__ + :exclude-members: __init__, __new__ + +.. autoclass:: BasicMap + :members: + :inherited-members: + :special-members: __and__, __or__, __sub__, __lt__, __le__, __gt__, __ge__ + :exclude-members: __init__, __new__ + +.. autoclass:: Map + :members: + :inherited-members: + :special-members: __and__, __or__, __sub__, __lt__, __le__, __gt__, __ge__ + :exclude-members: __init__, __new__ + +Expression Objects +------------------ + +.. autofunction:: make_aff + +.. autofunction:: make_pw_aff + +.. autofunction:: make_qpolynomial + +.. autofunction:: make_pw_qpolynomial + +.. autofunction:: make_multi_aff + +.. autofunction:: make_pw_multi_aff + +.. autoclass:: Aff + :members: + :inherited-members: + :special-members: __add__, __sub__, __mul__ + :exclude-members: __init__, __new__ + +.. autoclass:: PwAff + :members: + :inherited-members: + :special-members: __add__, __sub__, __mul__ + :exclude-members: __init__, __new__ + +.. autoclass:: QPolynomial + :members: + :inherited-members: + :special-members: __add__, __sub__, __mul__ + :exclude-members: __init__, __new__ + +.. autoclass:: PwQPolynomial + :members: + :inherited-members: + :special-members: __add__, __sub__, __mul__ + :exclude-members: __init__, __new__ + +.. autoclass:: MultiAff + :members: + :inherited-members: + :exclude-members: __init__, __new__ + +.. autoclass:: PwMultiAff + :members: + :inherited-members: + :exclude-members: __init__, __new__ diff --git a/namedisl/__init__.py b/namedisl/__init__.py index a2edcaf..bce7f38 100644 --- a/namedisl/__init__.py +++ b/namedisl/__init__.py @@ -1,9 +1,15 @@ """ -.. autoclass:: BasicSet +Name-aware wrappers for :mod:`islpy` objects. -.. autofunction:: make_basic_set -""" +namedisl offers small Python wrappers around isl sets, maps, and expression +objects. The wrappers keep a separate mapping from dimension names to isl +dimension positions, align operands by name before applying binary operations, +and reconstruct ordinary :mod:`islpy` objects when callers need to interoperate +with islpy or downstream libraries. +Most users should construct objects through the ``make_*`` functions exported +from this module. +""" from __future__ import annotations @@ -31,74 +37,54 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import re -from collections.abc import Mapping -from dataclasses import dataclass -from importlib import metadata -from typing import TypeAlias, TypeVar, overload - -from constantdict import constantdict -from typing_extensions import override - -import islpy as isl - - -__version__ = metadata.version("namedisl") -_match = re.match(r"^([0-9.]+)([a-z0-9]*?)$", __version__) -assert _match -VERSION = tuple(int(nr) for nr in _match.group(1).split(".")) - - -IslObject = TypeVar("IslObject", isl.BasicSet, isl.Set) -NameToDim: TypeAlias = Mapping[str, tuple[isl.dim_type, int]] - - -def _strip_names(obj: IslObject) -> tuple[IslObject, NameToDim]: - name_to_dim: dict[str, tuple[isl.dim_type, int]] = {} - for tp in isl._CHECK_DIM_TYPES: - for i in range(obj.dim(tp)): - name = obj.get_dim_name(tp, i) - if name is None: - raise ValueError("unnamed dimension found") - if name in name_to_dim: - raise ValueError(f"non-unique dim name: {name}") - name_to_dim[name] = (tp, i) - - # FIXME: Enable, to avoid misunderstandings - # obj = obj.set_dim_id(tp, i, None) - - return obj, constantdict(name_to_dim) - - -def _restore_names(obj: IslObject, name_to_dim: NameToDim) -> IslObject: - for name, (dt, i) in name_to_dim.items(): - obj = obj.set_dim_name(dt, i, name) - - return obj - - -@dataclass(frozen=True) -class BasicSet: - _obj: isl.BasicSet - _name_to_dim: NameToDim - - @override - def __str__(self) -> str: - return str(_restore_names(self._obj, self._name_to_dim)) - - -@overload -def make_basic_set(src: str, ctx: isl.Context | None = None) -> BasicSet: - ... - - -@overload -def make_basic_set(src: isl.BasicSet) -> BasicSet: - ... - - -def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> BasicSet: - obj = isl.BasicSet(src, ctx) if isinstance(src, str) else src - obj, name_to_dim = _strip_names(obj) - return BasicSet(obj, name_to_dim) +from .expression_like import ( + Aff, + MultiAff, + PwAff, + PwMultiAff, + PwQPolynomial, + QPolynomial, + make_aff, + make_multi_aff, + make_pw_aff, + make_pw_multi_aff, + make_pw_qpolynomial, + make_qpolynomial, +) +from .set_like import ( + BasicMap, + BasicSet, + Map, + Set, + make_basic_map, + make_basic_set, + make_map, + make_map_from_domain_and_range, + make_set, +) + + +__all__ = [ + "Aff", + "BasicMap", + "BasicSet", + "Map", + "MultiAff", + "PwAff", + "PwMultiAff", + "PwQPolynomial", + "QPolynomial", + "Set", + "make_aff", + "make_basic_map", + "make_basic_set", + "make_map", + "make_map_from_domain_and_range", + "make_multi_aff", + "make_pw_aff", + "make_pw_multi_aff", + "make_pw_qpolynomial", + "make_qpolynomial", + "make_set", +] diff --git a/namedisl/core.py b/namedisl/core.py new file mode 100644 index 0000000..a372247 --- /dev/null +++ b/namedisl/core.py @@ -0,0 +1,1038 @@ +""" +Core metadata machinery for name-aware isl wrappers. + +The public set-like and expression-like classes store an isl object together +with metadata that records which semantic name belongs to each internal +dimension. This module +contains the alignment, reconstruction, and metadata-manipulation helpers used +to keep that invariant intact. +""" + +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import re +from abc import ABC +from collections.abc import Callable, Collection, Mapping, Sequence +from dataclasses import dataclass +from importlib import metadata +from typing import Generic, TypeAlias, TypeVar, cast, overload + +from constantdict import constantdict +from typing_extensions import Self, override + +import islpy as isl + + +IslBaseExpressionLike = isl.Aff | isl.QPolynomial +IslPwExpressionLike = isl.PwAff | isl.PwQPolynomial +IslScalarExpressionLike = IslBaseExpressionLike | IslPwExpressionLike +IslMultiExpressionLike = isl.MultiAff | isl.PwMultiAff + +IslExpressionLike = IslScalarExpressionLike +IslSetLike = isl.BasicSet | isl.BasicMap | isl.Set | isl.Map +IslObject = IslSetLike | IslExpressionLike +RawInternalIslObject = IslSetLike | IslScalarExpressionLike +InternalIslObject = RawInternalIslObject + +IslExpressionLikeT = TypeVar( + "IslExpressionLikeT", + bound=IslScalarExpressionLike, +) +RawInternalIslObjectT_co = TypeVar( + "RawInternalIslObjectT_co", + bound=RawInternalIslObject, + covariant=True, +) +InternalIslObjectT_co = TypeVar( + "InternalIslObjectT_co", + bound=InternalIslObject, + covariant=True, +) +PublicIslObjectT_co = TypeVar( + "PublicIslObjectT_co", + bound=IslObject, + covariant=True, +) + +NamedIslObjectT = TypeVar( + "NamedIslObjectT", + bound="NamedIslObject[InternalIslObject, IslObject]", +) + +NameToDim: TypeAlias = Mapping[str, int] + +# NOTE: without tracking what dimension type a particular name belongs to, it is +# not possible to reconstruct the ISL object after dimension operations, e.g. +# alignment +DimTypeToNames: TypeAlias = Mapping[isl.dim_type, frozenset[str]] + +IslObjectPieces: TypeAlias = tuple[RawInternalIslObject, NameToDim, DimTypeToNames] + + +__version__ = metadata.version("namedisl") +_match = re.match(r"^([0-9.]+)([a-z0-9]*?)$", __version__) +assert _match +VERSION = tuple(int(nr) for nr in _match.group(1).split(".")) + +__all__ = [ + "_align_and_apply_binary_op", + "_align_two", + "_deconstruct_object", + "_find_contiguous_dim_chunks", + "_make_named_object_pieces", + "_normalize_dimtype_to_names", + "_restore_names", + "_strip_names", +] + + +def _normalize_public_dim_type(dim_type: isl.dim_type) -> isl.dim_type: + if dim_type == isl.dim_type.out: + return isl.dim_type.set + return dim_type + + +def _uses_explicit_input_metadata(obj: object) -> bool: + return isinstance(obj, IslSetLike) + + +def _ensure_unique_public_names(obj: IslObject) -> None: + if isinstance(obj, IslSetLike): + dim_types = (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param) + else: + dim_types = (isl.dim_type.in_, isl.dim_type.param) + + seen_names: set[str] = set() + for dim_type in dim_types: + for dim in range(obj.dim(dim_type)): + if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): + name = obj.space.get_dim_name(dim_type, dim) + else: + name = obj.get_dim_name(dim_type, dim) + if name is None: + raise ValueError("duplicate or unnamed dimension found") + if name in seen_names: + raise ValueError(f"duplicate dimension name found: {name}") + seen_names.add(name) + + +def _strip_names( + obj: RawInternalIslObjectT_co, +) -> tuple[RawInternalIslObjectT_co, NameToDim]: + name_to_dim: dict[str, int] = {} + + dt_to_strip = isl.dim_type.set if isinstance(obj, IslSetLike) else isl.dim_type.in_ + + stripped_obj = obj.copy() + + for i in range(stripped_obj.dim(dt_to_strip)): + if isinstance(stripped_obj, isl.QPolynomial | isl.PwQPolynomial): + name = stripped_obj.space.get_dim_name(dt_to_strip, i) + else: + name = stripped_obj.get_dim_name(dt_to_strip, i) + + if name is None: + raise ValueError("unnamed dimension found") + + if name in name_to_dim: + raise ValueError(f"duplicate dimension name found: {name}") + + name_to_dim[name] = i + + return cast("RawInternalIslObjectT_co", stripped_obj), constantdict(name_to_dim) + + +def _get_obj_dim_name(obj: IslObject, dt: isl.dim_type, dim: int) -> str: + if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): + name = obj.space.get_dim_name(dt, dim) + else: + name = obj.get_dim_name(dt, dim) + + if name is None: + raise ValueError("unnamed dimension found") + + return name + + +def _normalize_dimtype_to_names( + obj: IslObject, dimtype_to_names: DimTypeToNames +) -> DimTypeToNames: + if isinstance(obj, IslSetLike): + dim_type = isl.dim_type.set + total_dims = obj.dim(dim_type) + n_in = len(dimtype_to_names.get(isl.dim_type.in_, frozenset())) + n_param = len(dimtype_to_names.get(isl.dim_type.param, frozenset())) + + new_dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} + + if n_in: + start = total_dims - n_param - n_in + new_dimtype_to_names[isl.dim_type.in_] = frozenset( + _get_obj_dim_name(obj, dim_type, dim) + for dim in range(start, start + n_in) + ) + + if n_param: + start = total_dims - n_param + new_dimtype_to_names[isl.dim_type.param] = frozenset( + _get_obj_dim_name(obj, dim_type, dim) + for dim in range(start, start + n_param) + ) + + return constantdict(new_dimtype_to_names) + + total_dims = obj.dim(isl.dim_type.in_) + n_param = len(dimtype_to_names.get(isl.dim_type.param, frozenset())) + if not n_param: + return dimtype_to_names + + start = total_dims - n_param + return constantdict({ + isl.dim_type.param: frozenset( + _get_obj_dim_name(obj, isl.dim_type.in_, dim) + for dim in range(start, start + n_param) + ) + }) + + +def _make_named_object_pieces(obj: IslObject) -> IslObjectPieces: + """ + Extract the normalized isl object and name metadata for *obj*. + + The returned object uses namedisl's internal flattened dimension layout. + The accompanying mappings record how to reconstruct the original public isl + dimension kinds and names. + """ + _ensure_unique_public_names(obj) + decon_obj, dimtype_to_names = _deconstruct_object(obj) + decon_obj, name_to_dim = _strip_names(decon_obj) + dimtype_to_names = _normalize_dimtype_to_names(decon_obj, dimtype_to_names) + return decon_obj, name_to_dim, dimtype_to_names + + +def _restore_names( + obj: RawInternalIslObjectT_co, name_to_dim: NameToDim +) -> RawInternalIslObjectT_co: + """ + Return a copy of *obj* with dimension names restored from *name_to_dim*. + + This is intentionally used at reconstruction boundaries. Internal + metadata-only operations may leave the private isl object's own dimension + names stale because namedisl does not rely on those names for correctness. + """ + restored_obj = obj.copy() + if isinstance(restored_obj, isl.PwAff): + # input dimensions cannot be renamed for isl.PwAff, so we first move + # input dims to the parameter space, rename then move back + restored_obj = restored_obj.move_dims( + isl.dim_type.param, + 0, + isl.dim_type.in_, + 0, + restored_obj.dim(isl.dim_type.in_), + ) + + restored_union_pw_aff = restored_obj.to_union_pw_aff() + for name, dim in name_to_dim.items(): + restored_union_pw_aff = restored_union_pw_aff.set_dim_name( + isl.dim_type.param, dim, name + ) + + restored_obj = restored_union_pw_aff.get_pw_aff_list().get_at(0) + return cast( + "RawInternalIslObjectT_co", + restored_obj.move_dims( + isl.dim_type.in_, + 0, + isl.dim_type.param, + 0, + restored_obj.dim(isl.dim_type.param), + ), + ) + + if isinstance(restored_obj, IslSetLike): + dt_to_restore = isl.dim_type.set + else: + dt_to_restore = isl.dim_type.in_ + + for name, dim in name_to_dim.items(): + restored_obj = restored_obj.set_dim_name(dt_to_restore, dim, name) + + if isinstance(restored_obj, isl.UnionPwAff | isl.UnionPwMultiAff): + raise NotImplementedError + + return cast("RawInternalIslObjectT_co", restored_obj) + + +def _get_dim_names(obj: IslObject, dt: isl.dim_type) -> frozenset[str]: + all_dt_names: list[str] = [] + for dim in range(obj.dim(dt)): + if isinstance(obj, isl.QPolynomial | isl.PwQPolynomial): + name = obj.space.get_dim_name(dt, dim) + else: + name = obj.get_dim_name(dt, dim) + + if name is None: + raise ValueError("unnamed dimension found") + + all_dt_names.append(name) + + return frozenset(all_dt_names) + + +@overload +def _deconstruct_object(obj: isl.Map) -> tuple[isl.Set, DimTypeToNames]: ... + + +@overload +def _deconstruct_object( + obj: RawInternalIslObject, +) -> tuple[RawInternalIslObject, DimTypeToNames]: ... + + +def _deconstruct_object(obj: IslObject) -> tuple[RawInternalIslObject, DimTypeToNames]: + """ + Convert a public isl object into namedisl's internal representation. + + Set-like objects are represented as flattened sets whose dimensions are + grouped as set/output names, input names, and parameter names. Expression + objects use input dimensions followed by parameters. + """ + dt_to_names: dict[isl.dim_type, frozenset[str]] = {} + + if isinstance(obj, IslSetLike): + decon_obj = obj + dt_to_names = dict.fromkeys([isl.dim_type.in_, isl.dim_type.param], frozenset()) + + for dt in dt_to_names: + dt_to_names[dt] = _get_dim_names(decon_obj, dt) + if dt_to_names[dt]: + decon_obj = decon_obj.move_dims( + isl.dim_type.set, + decon_obj.dim(isl.dim_type.set), + dt, + 0, + decon_obj.dim(dt), + ) + + decon_obj = ( + decon_obj.range() + if isinstance(decon_obj, isl.Map | isl.BasicMap) + else decon_obj + ) + + decon_obj = ( + isl.Set.from_basic_set(decon_obj) + if isinstance(decon_obj, isl.BasicSet) + else decon_obj + ) + + else: + decon_obj = obj + + dt_to_names = dict.fromkeys([isl.dim_type.param], frozenset()) + dt_to_names[isl.dim_type.param] = _get_dim_names(decon_obj, isl.dim_type.param) + + decon_obj = decon_obj.move_dims( + isl.dim_type.in_, + decon_obj.dim(isl.dim_type.in_), + isl.dim_type.param, + 0, + decon_obj.dim(isl.dim_type.param), + ) + + return decon_obj, constantdict(dt_to_names) + + +def _find_contiguous_dim_chunks(dims: Sequence[int]) -> Mapping[int, int]: + """ + Determines contiguous chunks of dimensions within a sequence of dimensions. + Returns a mapping of the first dimension in the chunk to the length of the + chunk. + """ + if not dims: + return {} + + chunks: dict[int, int] = {} + + start = dims[0] + count = 1 + + from itertools import pairwise + + for prev, curr in pairwise(dims): + if curr == prev + 1: + count += 1 + else: + chunks[start] = count + start = curr + count = 1 + + chunks[start] = count + + return constantdict(chunks) + + +# FIXME: think through whether or not alphabetical ordering will require more +# work on average than using one of the objects as a template in alignment +def _find_joint_name_to_dim( + obj1: NamedIslObject[InternalIslObject, IslObject], + obj2: NamedIslObject[InternalIslObject, IslObject], +) -> tuple[NameToDim, DimTypeToNames]: + """ + Enforces alphabetical ordering of all dimensions found in :arg:`obj1` and + :arg:`obj2` within each dimension-type chunk. This ordering is used in + alignment before performing operations between two set-like objects. + """ + all_set_names = obj1._names_for_dim_type( + isl.dim_type.set + ) | obj2._names_for_dim_type(isl.dim_type.set) + all_inp_names = obj1.input_names | obj2.input_names + all_param_names = obj1.parameter_names | obj2.parameter_names + + duplicate_names = ( + (all_set_names & all_inp_names) + | (all_set_names & all_param_names) + | (all_inp_names & all_param_names) + ) + if duplicate_names: + raise ValueError( + "duplicate dimension names across dimension types: " + + ", ".join(sorted(duplicate_names)) + ) + + dt_to_names: DimTypeToNames = {} + dt_to_names[isl.dim_type.param] = all_param_names + if _uses_explicit_input_metadata(obj1._obj) or _uses_explicit_input_metadata( + obj2._obj + ): + dt_to_names[isl.dim_type.in_] = all_inp_names + + # enforces contiguous ordering of [ (set), (input), (param) ] in set + # representation + all_names = [ + *sorted(all_set_names), + *sorted(all_inp_names), + *sorted(all_param_names), + ] + + name_to_dim: NameToDim = constantdict({ + name: pos for pos, name in enumerate(all_names) + }) + + return name_to_dim, constantdict(dt_to_names) + + +def _align_obj( + named_obj: NamedIslObjectT, ordering: NameToDim, dimtype_to_names: DimTypeToNames +) -> NamedIslObjectT: + """ + Return *named_obj* with internal dimensions arranged according to *ordering*. + + The isl object is moved or expanded as needed, but public names are carried + in metadata and are restored only when a raw isl object is reconstructed. + """ + new_isl_obj = named_obj._obj + running_name_to_dim = dict(named_obj._name_to_dim) + + target_dt = ( + isl.dim_type.set if isinstance(new_isl_obj, IslSetLike) else isl.dim_type.in_ + ) + + for name, target_dim in sorted(ordering.items(), key=lambda x: x[1]): + if name in running_name_to_dim: + old_dim = running_name_to_dim[name] + + if old_dim == target_dim: + continue + + # temporarily move to parameter dimension since destination and + # source dim types cannot match in ISL + new_isl_obj = new_isl_obj.move_dims( + isl.dim_type.param, 0, target_dt, old_dim, 1 + ) + + new_isl_obj = new_isl_obj.move_dims( + target_dt, target_dim, isl.dim_type.param, 0, 1 + ) + + else: + old_dim = new_isl_obj.dim(target_dt) + new_isl_obj = new_isl_obj.insert_dims(target_dt, target_dim, 1) + + # track side effects of inserting/swapping dimensions + for n, d in list(running_name_to_dim.items()): + if (target_dim > old_dim) and (d > old_dim): + running_name_to_dim[n] = d - 1 + elif (target_dim < old_dim) and (d < old_dim): + running_name_to_dim[n] = d + 1 + + running_name_to_dim[name] = target_dim + + return type(named_obj)( + new_isl_obj, + ordering, + dimtype_to_names, + ) + + +def _align_two( + named_obj1: NamedIslObjectT, named_obj2: NamedIslObjectT +) -> tuple[NamedIslObjectT, NamedIslObjectT]: + """ + Align two named isl objects to a common name-to-dimension mapping. + """ + + name_to_dim, dimtype_to_names = _find_joint_name_to_dim(named_obj1, named_obj2) + + named_obj1 = _align_obj(named_obj1, name_to_dim, dimtype_to_names) + named_obj2 = _align_obj(named_obj2, name_to_dim, dimtype_to_names) + + return named_obj1, named_obj2 + + +def _align_and_apply_binary_op( + lhs: NamedIslObject[InternalIslObjectT_co, IslObject], + rhs: NamedIslObject[InternalIslObjectT_co, IslObject], + op: Callable[ + [InternalIslObjectT_co, InternalIslObjectT_co], InternalIslObjectT_co + ], +) -> NamedIslObject[InternalIslObjectT_co, IslObject]: + """ + Align *lhs* and *rhs*, apply *op* to their isl objects, and wrap the result. + """ + + lhs, rhs = _align_two(lhs, rhs) + result = op(lhs._obj, rhs._obj) + + # NOTE: since lhs and rhs were aligned, they both agree on what name-to-dim + # and dimtype-to-name is, can just take information from lhs + return type(lhs)(result, lhs._name_to_dim, lhs._dimtype_to_names) + + +@dataclass(frozen=True, eq=False) +class NamedIslObject(ABC, Generic[InternalIslObjectT_co, PublicIslObjectT_co]): + """ + Base class for named isl wrappers. + + Instances pair a private isl object with metadata that records the semantic + name and public dimension kind of every internal dimension. Subclasses use + this metadata to implement operations in terms of names while still + delegating the underlying integer-set algebra to isl. + """ + + _obj: InternalIslObjectT_co + _name_to_dim: NameToDim + + # used to reconstruct ISL object + _dimtype_to_names: DimTypeToNames + + @property + def _metadata_input_names(self) -> frozenset[str]: + return self._dimtype_to_names.get(isl.dim_type.in_, frozenset()) + + @property + def _metadata_parameter_names(self) -> frozenset[str]: + return self._dimtype_to_names.get(isl.dim_type.param, frozenset()) + + def _names_for_dim_type(self, dim_type: isl.dim_type) -> frozenset[str]: + dim_type = _normalize_public_dim_type(dim_type) + if dim_type == isl.dim_type.param: + return self.parameter_names + + if _uses_explicit_input_metadata(self._obj): + if dim_type == isl.dim_type.in_: + return self._metadata_input_names + if dim_type == isl.dim_type.set: + return self.names - self._metadata_input_names - self.parameter_names + else: + if dim_type == isl.dim_type.in_: + return self.names - self.parameter_names + if dim_type == isl.dim_type.set: + return frozenset() + + raise ValueError(f"unsupported dim type: {dim_type}") + + def _ordered_names_for_dim_type(self, dim_type: isl.dim_type) -> tuple[str, ...]: + names = self._names_for_dim_type(dim_type) + return tuple(sorted(names, key=self._name_to_dim.__getitem__)) + + def _ordered_name_chunks(self) -> dict[isl.dim_type, tuple[str, ...]]: + return { + isl.dim_type.set: self._ordered_names_for_dim_type(isl.dim_type.set), + isl.dim_type.in_: self._ordered_names_for_dim_type(isl.dim_type.in_), + isl.dim_type.param: self._ordered_names_for_dim_type(isl.dim_type.param), + } + + def _empty_grouped_names(self) -> dict[isl.dim_type, list[str]]: + return { + isl.dim_type.set: [], + isl.dim_type.in_: [], + isl.dim_type.param: [], + } + + def _metadata_from_chunk_names( + self, chunk_names: Mapping[isl.dim_type, Collection[str]], *, has_inputs: bool + ) -> tuple[NameToDim, DimTypeToNames]: + ordered_names = [ + *chunk_names[isl.dim_type.set], + *chunk_names[isl.dim_type.in_], + *chunk_names[isl.dim_type.param], + ] + new_name_to_dim: NameToDim = constantdict({ + name: dim for dim, name in enumerate(ordered_names) + }) + + new_dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} + if has_inputs and chunk_names[isl.dim_type.in_]: + new_dimtype_to_names[isl.dim_type.in_] = frozenset( + chunk_names[isl.dim_type.in_] + ) + if chunk_names[isl.dim_type.param]: + new_dimtype_to_names[isl.dim_type.param] = frozenset( + chunk_names[isl.dim_type.param] + ) + + return new_name_to_dim, constantdict(new_dimtype_to_names) + + def _add_names_by_dim_type( + self, names_to_add: Collection[str], dim_type: isl.dim_type + ) -> Self: + dim_type = _normalize_public_dim_type(dim_type) + if dim_type not in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): + raise ValueError(f"unsupported dim type: {dim_type}") + if ( + not _uses_explicit_input_metadata(self._obj) + and dim_type == isl.dim_type.set + ): + raise ValueError(f"unsupported dim type: {dim_type}") + + if len(set(names_to_add)) != len(tuple(names_to_add)): + raise ValueError("duplicate names to add") + + for name in names_to_add: + if name in self.names: + raise ValueError(f"name already exists: {name}") + + if not names_to_add: + return self + + grouped_names = self._empty_grouped_names() + grouped_names[dim_type] = list(names_to_add) + + return self._add_grouped_names(grouped_names) + + def _add_grouped_names( + self, grouped_names: Mapping[isl.dim_type, Collection[str]] + ) -> Self: + seen_names: set[str] = set() + for names in grouped_names.values(): + for name in names: + if name in seen_names: + raise ValueError("duplicate names to add") + if name in self.names: + raise ValueError(f"name already exists: {name}") + seen_names.add(name) + + new_obj = self._obj + chunk_names = { + dt: list(names) for dt, names in self._ordered_name_chunks().items() + } + internal_dim_type = ( + isl.dim_type.set + if _uses_explicit_input_metadata(new_obj) + else isl.dim_type.in_ + ) + + insertion_starts = { + isl.dim_type.set: 0, + isl.dim_type.in_: len(chunk_names[isl.dim_type.set]), + isl.dim_type.param: ( + len(chunk_names[isl.dim_type.set]) + len(chunk_names[isl.dim_type.in_]) + ), + } + + for dim_type in (isl.dim_type.param, isl.dim_type.in_, isl.dim_type.set): + names_to_add = grouped_names[dim_type] + if not names_to_add: + continue + new_obj = new_obj.insert_dims( + internal_dim_type, insertion_starts[dim_type], len(names_to_add) + ) + chunk_names[dim_type] = [*names_to_add, *chunk_names[dim_type]] + + new_name_to_dim, new_dimtype_to_names = self._metadata_from_chunk_names( + chunk_names, + has_inputs=_uses_explicit_input_metadata(new_obj), + ) + + return type(self)( + cast("InternalIslObjectT_co", new_obj), + new_name_to_dim, + new_dimtype_to_names, + ) + + def add_set_names(self, names_to_add: Collection[str]) -> Self: + """ + Return a copy with new unconstrained set/output dimensions. + + :arg names_to_add: Names to insert. Existing names and duplicates are + rejected. + """ + return self._add_names_by_dim_type(names_to_add, isl.dim_type.set) + + def add_output_names(self, names_to_add: Collection[str]) -> Self: + """ + Return a copy with new unconstrained output dimensions. + + This is equivalent to :meth:`add_set_names` for map-like objects. + """ + return self._add_names_by_dim_type(names_to_add, isl.dim_type.out) + + def add_input_names(self, names_to_add: Collection[str]) -> Self: + """ + Return a copy with new unconstrained input dimensions. + """ + return self._add_names_by_dim_type(names_to_add, isl.dim_type.in_) + + def add_parameter_names(self, names_to_add: Collection[str]) -> Self: + """ + Return a copy with new parameter dimensions. + """ + return self._add_names_by_dim_type(names_to_add, isl.dim_type.param) + + def add_dim_names( + self, names_to_add: Collection[str], dim_type: isl.dim_type + ) -> Self: + """ + Return a copy with new dimensions of *dim_type*. + + :arg dim_type: One of ``isl.dim_type.set``, ``isl.dim_type.out``, + ``isl.dim_type.in_``, or ``isl.dim_type.param``. + """ + return self._add_names_by_dim_type(names_to_add, dim_type) + + @property + def names(self) -> frozenset[str]: + """ + All dimension names known to this object. + """ + return frozenset(self._name_to_dim.keys()) + + def dim_names(self, dim_type: isl.dim_type) -> frozenset[str]: + """ + Return the names belonging to *dim_type*. + """ + return self._names_for_dim_type(dim_type) + + def ordered_dim_names(self, dim_type: isl.dim_type) -> tuple[str, ...]: + """ + Return names for *dim_type* in their current dimension order. + """ + return self._ordered_names_for_dim_type(dim_type) + + @property + def set_names(self) -> frozenset[str]: + """ + Names of set dimensions. + """ + return self._names_for_dim_type(isl.dim_type.set) + + @property + def output_names(self) -> frozenset[str]: + """ + Names of output dimensions. + """ + return self._names_for_dim_type(isl.dim_type.out) + + def get_space(self) -> isl.Space: + """ + Reconstruct and return the object's public isl space. + """ + return self._reconstruct_isl_object().get_space() + + def dim(self, dim_type: isl.dim_type) -> int: + """ + Return the number of dimensions of *dim_type*. + """ + dim_type = _normalize_public_dim_type(dim_type) + if dim_type in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): + return len(self._names_for_dim_type(dim_type)) + return self._reconstruct_isl_object().dim(dim_type) + + def move_dims( + self, + names_to_move: str | Collection[str], + dst_type: isl.dim_type, + ) -> Self: + """ + Return a copy with named dimensions moved to *dst_type*. + + The relative order of moved names is preserved. Moving a name to its + current dimension kind is a no-op. + """ + if isinstance(names_to_move, str): + names_to_move = [names_to_move] + + if not names_to_move: + return self + + dst_type = _normalize_public_dim_type(dst_type) + if dst_type not in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): + raise ValueError(f"unsupported destination dim type: {dst_type}") + + missing_names = [name for name in names_to_move if name not in self.names] + if missing_names: + raise ValueError(f"unknown names: {', '.join(missing_names)}") + + if len(set(names_to_move)) != len(tuple(names_to_move)): + raise ValueError("duplicate names in move_dims") + + names_to_move = [ + name + for name in names_to_move + if name not in self._names_for_dim_type(dst_type) + ] + if not names_to_move: + return self + + moved_name_set = set(names_to_move) + chunk_names = { + dt: [name for name in names if name not in moved_name_set] + for dt, names in self._ordered_name_chunks().items() + } + moved_names = sorted(names_to_move, key=self._name_to_dim.__getitem__) + chunk_names[dst_type].extend(moved_names) + + new_name_to_dim, new_dimtype_to_names = self._metadata_from_chunk_names( + chunk_names, + has_inputs=_uses_explicit_input_metadata(self._obj), + ) + + return _align_obj(self, new_name_to_dim, new_dimtype_to_names) + + def rename_dims(self, renaming: Mapping[str, str]) -> Self: + """ + Return a copy with dimension names changed according to *renaming*. + + Renaming updates namedisl metadata only. The private isl object's own + names are restored when :meth:`get_isl_object` is called. + """ + if not renaming: + return self + + missing_names = [name for name in renaming if name not in self.names] + if missing_names: + raise ValueError(f"unknown names: {', '.join(missing_names)}") + + if len(set(renaming.values())) != len(renaming): + raise ValueError("duplicate destination names in rename_dims") + + unchanged_names = { + old_name for old_name, new_name in renaming.items() if old_name == new_name + } + renaming = { + old_name: new_name + for old_name, new_name in renaming.items() + if old_name not in unchanged_names + } + if not renaming: + return self + + existing_names = self.names - frozenset(renaming) + conflicting_names = existing_names & frozenset(renaming.values()) + if conflicting_names: + raise ValueError( + "cannot rename to existing names: " + + ", ".join(sorted(conflicting_names)) + ) + + new_name_to_dim: NameToDim = constantdict({ + renaming.get(name, name): dim for name, dim in self._name_to_dim.items() + }) + new_dimtype_to_names: DimTypeToNames = constantdict({ + dim_type: frozenset(renaming.get(name, name) for name in names) + for dim_type, names in self._dimtype_to_names.items() + }) + + return type(self)( + self._obj, + new_name_to_dim, + new_dimtype_to_names, + ) + + @overload + def equate_dims(self, name1: Mapping[str, str]) -> Self: ... + + @overload + def equate_dims(self, name1: str, name2: str) -> Self: ... + + def equate_dims( + self, + name1: str | Mapping[str, str], + name2: str | None = None, + ) -> Self: + """ + Return a copy constrained so paired dimensions are equal. + + Either pass two names, or pass a mapping whose keys and values are + equated pairwise. + """ + if isinstance(name1, str): + if name2 is None: + raise TypeError("name2 must be provided when name1 is a string") + equated_names = ((name1, name2),) + else: + if name2 is not None: + raise TypeError("name2 cannot be provided when name1 is a mapping") + equated_names = tuple(name1.items()) + + for lhs_name, rhs_name in equated_names: + if lhs_name not in self.names: + raise ValueError(f"unknown name: {lhs_name}") + if rhs_name not in self.names: + raise ValueError(f"unknown name: {rhs_name}") + + if all(lhs_name == rhs_name for lhs_name, rhs_name in equated_names): + return self + + if not isinstance(self._obj, IslSetLike): + raise NotImplementedError( + "equate_dims is only implemented for set-like objects" + ) + + obj = self._obj + for lhs_name, rhs_name in equated_names: + if lhs_name != rhs_name: + obj = obj.equate( + isl.dim_type.set, + self._name_to_dim[lhs_name], + isl.dim_type.set, + self._name_to_dim[rhs_name], + ) + + return type(self)( + cast("InternalIslObjectT_co", obj), + self._name_to_dim, + self._dimtype_to_names, + ) + + @property + def _has_inputs(self) -> bool: + return bool(self._metadata_input_names) + + @property + def input_names(self) -> frozenset[str]: + """ + Names of input dimensions. + """ + return self._names_for_dim_type(isl.dim_type.in_) + + @property + def _input_dim_start(self) -> int | None: + if self._has_inputs: + return min(self._name_to_dim[name] for name in self._metadata_input_names) + return None + + @property + def _has_params(self) -> bool: + return bool(self._metadata_parameter_names) + + @property + def parameter_names(self) -> frozenset[str]: + """ + Names of parameter dimensions. + """ + return self._metadata_parameter_names + + @property + def _parameter_dim_start(self) -> int | None: + if self._has_params: + return min( + self._name_to_dim[name] for name in self._metadata_parameter_names + ) + return None + + def _reconstruct_isl_object(self) -> PublicIslObjectT_co: + """ + Relies on the dimension type ordering in + :func:`_deconstruct_set_like_object`. + """ + obj = _restore_names(self._obj, self._name_to_dim) + + internal_dim = ( + isl.dim_type.set if isinstance(obj, isl.Set) else isl.dim_type.in_ + ) + + if self._has_params: + if self._parameter_dim_start is None: + raise ValueError( + "Object has parameter dimensions, but a starting index for " + "parameter names is not given. Reconstruction is not " + "possible" + ) + + param_start = self._parameter_dim_start + obj = obj.move_dims( + isl.dim_type.param, + 0, + internal_dim, + param_start, + len(self.parameter_names), + ) + + if self._has_inputs: + if self._input_dim_start is None: + raise ValueError( + "Object has input dimensions, but a starting index for " + "input names is not given. Reconstruction is not " + "possible" + ) + + obj_domain = isl.Set("{ [] }") + obj_range = obj + assert isinstance(obj_range, isl.BasicSet | isl.Set) + + obj = isl.Map.from_domain_and_range(obj_domain, obj_range) + + inp_start = self._input_dim_start + obj = obj.move_dims( + isl.dim_type.in_, 0, internal_dim, inp_start, len(self.input_names) + ) + + return cast("PublicIslObjectT_co", obj) + + def get_isl_object(self) -> PublicIslObjectT_co: + """ + Reconstruct and return the wrapped public :mod:`islpy` object. + """ + return self._reconstruct_isl_object() + + @override + def __str__(self) -> str: + return str(self._reconstruct_isl_object()) diff --git a/namedisl/expression_like.py b/namedisl/expression_like.py new file mode 100644 index 0000000..5d75e52 --- /dev/null +++ b/namedisl/expression_like.py @@ -0,0 +1,924 @@ +""" +Name-aware affine and polynomial expression wrappers. + +The wrappers in this module provide a small arithmetic interface around isl +expression objects while preserving named dimension metadata across alignment +and reconstruction. +""" + +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import operator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, final, overload + +from constantdict import constantdict +from typing_extensions import override + +import islpy as isl + +from .core import ( + DimTypeToNames, + IslExpressionLikeT, + IslScalarExpressionLike, + NamedIslObject, + NameToDim, + _align_two, + _make_named_object_pieces, + _normalize_public_dim_type, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + +PublicMultiExpressionLikeT = TypeVar( + "PublicMultiExpressionLikeT", + isl.MultiAff, + isl.PwMultiAff, +) +NamedScalarExpressionLikeT_co = TypeVar( + "NamedScalarExpressionLikeT_co", + bound=IslScalarExpressionLike, + covariant=True, +) + + +def _add_isl_expression( + lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int +) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.add)(lhs, rhs)) + + +def _sub_isl_expression( + lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int +) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.sub)(lhs, rhs)) + + +def _mul_isl_expression( + lhs: IslExpressionLikeT, rhs: IslExpressionLikeT | int +) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.mul)(lhs, rhs)) + + +def _radd_isl_expression(lhs: int, rhs: IslExpressionLikeT) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.add)(lhs, rhs)) + + +def _rsub_isl_expression(lhs: int, rhs: IslExpressionLikeT) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.sub)(lhs, rhs)) + + +def _rmul_isl_expression(lhs: int, rhs: IslExpressionLikeT) -> IslExpressionLikeT: + return cast("IslExpressionLikeT", cast("Any", operator.mul)(lhs, rhs)) + + +def _explicitly_promote_isl_expressions( + lhs: IslExpressionLikeT, rhs: IslExpressionLikeT +) -> tuple[IslExpressionLikeT, IslExpressionLikeT]: + if isinstance(lhs, isl.Aff) and isinstance(rhs, isl.PwAff): + return cast("IslExpressionLikeT", lhs.to_pw_aff()), rhs + if isinstance(lhs, isl.PwAff) and isinstance(rhs, isl.Aff): + return lhs, cast("IslExpressionLikeT", rhs.to_pw_aff()) + return lhs, rhs + + +# {{{ "base" named expression-likes (affs, pwaffs, qpolynomials, pwqpolynomials) + +@dataclass(frozen=True, eq=False) +class _NamedExpressionLike( + NamedIslObject[NamedScalarExpressionLikeT_co, NamedScalarExpressionLikeT_co] +): + @overload + def __add__(self: Aff, other: Aff | int) -> Aff: ... + + @overload + def __add__(self: Aff, other: PwAff) -> PwAff: ... + + @overload + def __add__(self: PwAff, other: Aff | PwAff | int) -> PwAff: ... + + @overload + def __add__(self: QPolynomial, other: QPolynomial | int) -> QPolynomial: ... + + @overload + def __add__( + self: PwQPolynomial, other: PwQPolynomial | int + ) -> PwQPolynomial: ... + + def __add__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: _NamedExpressionLike[IslScalarExpressionLike] | int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: + """ + Add another compatible named expression or an integer. + """ + return _apply_expression_binary_op(self, other, _add_isl_expression) + + @overload + def __radd__(self: Aff, other: int) -> Aff: ... + + @overload + def __radd__(self: PwAff, other: int) -> PwAff: ... + + @overload + def __radd__(self: QPolynomial, other: int) -> QPolynomial: ... + + @overload + def __radd__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... + + def __radd__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: + """ + Add this expression to an integer. + """ + return _apply_reflected_int_expression_op(self, other, _radd_isl_expression) + + @overload + def __sub__(self: Aff, other: Aff | int) -> Aff: ... + + @overload + def __sub__(self: Aff, other: PwAff) -> PwAff: ... + + @overload + def __sub__(self: PwAff, other: Aff | PwAff | int) -> PwAff: ... + + @overload + def __sub__(self: QPolynomial, other: QPolynomial | int) -> QPolynomial: ... + + @overload + def __sub__( + self: PwQPolynomial, other: PwQPolynomial | int + ) -> PwQPolynomial: ... + + def __sub__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: _NamedExpressionLike[IslScalarExpressionLike] | int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: + """ + Subtract another compatible named expression or an integer. + """ + return _apply_expression_binary_op(self, other, _sub_isl_expression) + + @overload + def __rsub__(self: Aff, other: int) -> Aff: ... + + @overload + def __rsub__(self: PwAff, other: int) -> PwAff: ... + + @overload + def __rsub__(self: QPolynomial, other: int) -> QPolynomial: ... + + @overload + def __rsub__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... + + def __rsub__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: + """ + Subtract this expression from an integer. + """ + return _apply_reflected_int_expression_op(self, other, _rsub_isl_expression) + + @overload + def __mul__(self: Aff, other: Aff | int) -> Aff: ... + + @overload + def __mul__(self: Aff, other: PwAff) -> PwAff: ... + + @overload + def __mul__(self: PwAff, other: Aff | PwAff | int) -> PwAff: ... + + @overload + def __mul__(self: QPolynomial, other: QPolynomial | int) -> QPolynomial: ... + + @overload + def __mul__( + self: PwQPolynomial, other: PwQPolynomial | int + ) -> PwQPolynomial: ... + + def __mul__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: _NamedExpressionLike[IslScalarExpressionLike] | int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: + """ + Multiply by another compatible named expression or an integer. + """ + return _apply_expression_binary_op(self, other, _mul_isl_expression) + + @overload + def __rmul__(self: Aff, other: int) -> Aff: ... + + @overload + def __rmul__(self: PwAff, other: int) -> PwAff: ... + + @overload + def __rmul__(self: QPolynomial, other: int) -> QPolynomial: ... + + @overload + def __rmul__(self: PwQPolynomial, other: int) -> PwQPolynomial: ... + + def __rmul__( + self: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + ) -> _NamedExpressionLike[IslScalarExpressionLike]: + """ + Multiply this expression by an integer. + """ + return _apply_reflected_int_expression_op(self, other, _rmul_isl_expression) + + def is_zero(self) -> bool: + """ + Return whether this expression is identically zero. + """ + return bool(self._obj.is_zero()) # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType] + + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, type(self)): + raise NotImplementedError("Objects are not of the same type") + raise NotImplementedError + + +@dataclass(frozen=True, eq=False) +class _NamedPwExpressionLike(_NamedExpressionLike[NamedScalarExpressionLikeT_co]): + ... + + +@final +@dataclass(frozen=True, eq=False) +class Aff(_NamedExpressionLike[isl.Aff]): + """ + Name-aware wrapper around :class:`islpy.Aff`. + + Construct instances with :func:`make_aff`. + """ + + _obj: isl.Aff + + @override + def _reconstruct_isl_object(self) -> isl.Aff: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.Aff) + return obj + + +@overload +def make_aff(src: str, ctx: isl.Context | None = None) -> Aff: + ... + + +@overload +def make_aff(src: isl.Aff) -> Aff: + ... + + +def make_aff(src: str | isl.Aff, ctx: isl.Context | None = None) -> Aff: + """ + Create an :class:`Aff` from isl syntax or an :class:`islpy.Aff`. + """ + obj = isl.Aff(src, ctx) if isinstance(src, str) else src + aff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(aff_obj, isl.Aff) + return Aff(aff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class QPolynomial(_NamedExpressionLike[isl.QPolynomial]): + """ + Name-aware wrapper around :class:`islpy.QPolynomial`. + + Construct instances with :func:`make_qpolynomial`. + """ + + _obj: isl.QPolynomial + + @override + def _reconstruct_isl_object(self) -> isl.QPolynomial: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.QPolynomial) + return obj + + +@overload +def make_qpolynomial(src: str, ctx: isl.Context | None = None) -> QPolynomial: + ... + + +@overload +def make_qpolynomial(src: isl.QPolynomial) -> QPolynomial: + ... + + +def make_qpolynomial( + src: str | isl.QPolynomial, ctx: isl.Context | None = None) -> QPolynomial: + """ + Create a :class:`QPolynomial` from isl syntax or an isl qpolynomial. + """ + # NOTE: ISL does not have a QPolynomial constructor, but we can make one + # here by first creating a PwQPolynomial, then taking the only QPolynomial + # that comes out of it :shrug: + obj = ( + isl.PwQPolynomial(src, ctx).get_pieces()[0][1] if isinstance(src, str) + else src + ) + + qp_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(qp_obj, isl.QPolynomial) + return QPolynomial(qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class PwAff(_NamedPwExpressionLike[isl.PwAff]): + """ + Name-aware wrapper around :class:`islpy.PwAff`. + + Construct instances with :func:`make_pw_aff`. + """ + + _obj: isl.PwAff + + @override + def _reconstruct_isl_object(self) -> isl.PwAff: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.PwAff) + return obj + + +@overload +def make_pw_aff(src: str, ctx: isl.Context | None = None) -> PwAff: + ... + + +@overload +def make_pw_aff(src: isl.PwAff) -> PwAff: + ... + + +def make_pw_aff(src: str | isl.PwAff, ctx: isl.Context | None = None) -> PwAff: + """ + Create a :class:`PwAff` from isl syntax or an :class:`islpy.PwAff`. + """ + obj = isl.PwAff(src, ctx) if isinstance(src, str) else src + pwaff_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(pwaff_obj, isl.PwAff) + return PwAff(pwaff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class PwQPolynomial(_NamedPwExpressionLike[isl.PwQPolynomial]): + """ + Name-aware wrapper around :class:`islpy.PwQPolynomial`. + + Construct instances with :func:`make_pw_qpolynomial`. + """ + + _obj: isl.PwQPolynomial + + @override + def _reconstruct_isl_object(self) -> isl.PwQPolynomial: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.PwQPolynomial) + return obj + + +@overload +def make_pw_qpolynomial( + src: str, ctx: isl.Context | None = None) -> PwQPolynomial: + ... + + +@overload +def make_pw_qpolynomial(src: isl.PwQPolynomial) -> PwQPolynomial: + ... + + +def make_pw_qpolynomial( + src: str | isl.PwQPolynomial, + ctx: isl.Context | None = None + ) -> PwQPolynomial: + """ + Create a :class:`PwQPolynomial` from isl syntax or an isl object. + """ + obj = isl.PwQPolynomial(src, ctx) if isinstance(src, str) else src + pw_qp_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(pw_qp_obj, isl.PwQPolynomial) + return PwQPolynomial(pw_qp_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +def _wrap_expression_result( + result: IslExpressionLikeT, + name_to_dim: NameToDim, + dimtype_to_names: DimTypeToNames, +) -> _NamedExpressionLike[IslExpressionLikeT]: + if isinstance(result, isl.Aff): + return cast( + "_NamedExpressionLike[IslExpressionLikeT]", + Aff(result, name_to_dim, dimtype_to_names), # pylint: disable=too-many-function-args + ) + if isinstance(result, isl.PwAff): + return cast( + "_NamedExpressionLike[IslExpressionLikeT]", + PwAff(result, name_to_dim, dimtype_to_names), # pylint: disable=too-many-function-args + ) + if isinstance(result, isl.QPolynomial): + return cast( + "_NamedExpressionLike[IslExpressionLikeT]", + QPolynomial(result, name_to_dim, dimtype_to_names), # pylint: disable=too-many-function-args + ) + if isinstance(result, isl.PwQPolynomial): + return cast( + "_NamedExpressionLike[IslExpressionLikeT]", + PwQPolynomial(result, name_to_dim, dimtype_to_names), # pylint: disable=too-many-function-args + ) + raise TypeError(f"unsupported expression result type: {type(result).__name__}") + + +@overload +def _apply_expression_binary_op( + lhs: Aff, + rhs: Aff | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> Aff: ... + + +@overload +def _apply_expression_binary_op( + lhs: Aff, + rhs: PwAff, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> PwAff: ... + + +@overload +def _apply_expression_binary_op( + lhs: PwAff, + rhs: Aff | PwAff | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> PwAff: ... + + +@overload +def _apply_expression_binary_op( + lhs: QPolynomial, + rhs: QPolynomial | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> QPolynomial: ... + + +@overload +def _apply_expression_binary_op( + lhs: PwQPolynomial, + rhs: PwQPolynomial | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> PwQPolynomial: ... + + +@overload +def _apply_expression_binary_op( + lhs: _NamedExpressionLike[IslScalarExpressionLike], + rhs: _NamedExpressionLike[IslScalarExpressionLike] | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> _NamedExpressionLike[IslScalarExpressionLike]: ... + + +def _apply_expression_binary_op( + lhs: _NamedExpressionLike[IslScalarExpressionLike], + rhs: _NamedExpressionLike[IslScalarExpressionLike] | int, + op: Callable[ + [IslScalarExpressionLike, IslScalarExpressionLike | int], + IslScalarExpressionLike, + ], +) -> _NamedExpressionLike[IslScalarExpressionLike]: + if isinstance(rhs, int): + return _wrap_expression_result( + op(lhs._obj, rhs), + lhs._name_to_dim, + lhs._dimtype_to_names, + ) + + lhs, rhs = _align_two(lhs, rhs) + lhs_obj, rhs_obj = _explicitly_promote_isl_expressions(lhs._obj, rhs._obj) + result = op(lhs_obj, rhs_obj) + return _wrap_expression_result( + result, + lhs._name_to_dim, + lhs._dimtype_to_names, + ) + + +@overload +def _apply_reflected_int_expression_op( + expr: Aff, + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> Aff: ... + + +@overload +def _apply_reflected_int_expression_op( + expr: PwAff, + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> PwAff: ... + + +@overload +def _apply_reflected_int_expression_op( + expr: QPolynomial, + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> QPolynomial: ... + + +@overload +def _apply_reflected_int_expression_op( + expr: PwQPolynomial, + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> PwQPolynomial: ... + + +@overload +def _apply_reflected_int_expression_op( + expr: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> _NamedExpressionLike[IslScalarExpressionLike]: ... + + +def _apply_reflected_int_expression_op( + expr: _NamedExpressionLike[IslScalarExpressionLike], + other: int, + op: Callable[[int, IslScalarExpressionLike], IslScalarExpressionLike], +) -> _NamedExpressionLike[IslScalarExpressionLike]: + return _wrap_expression_result( + op(other, expr._obj), + expr._name_to_dim, + expr._dimtype_to_names, + ) + +# }}} + + +# {{{ multi expression-likes (multiaff, pwmultiaff) + +def _ordered_multi_dim_names( + obj: isl.MultiAff | isl.PwMultiAff, dim_type: isl.dim_type +) -> tuple[str, ...]: + space = obj.get_space() + names: list[str] = [] + for dim in range(obj.dim(dim_type)): + name = space.get_dim_name(dim_type, dim) + if name is None: + raise ValueError("duplicate or unnamed dimension found") + names.append(name) + return tuple(names) + + +def _make_multi_expression_parts( + obj: isl.MultiAff | isl.PwMultiAff, +) -> tuple[Mapping[str, PwAff], NameToDim, DimTypeToNames]: + output_names = _ordered_multi_dim_names(obj, isl.dim_type.out) + + parts: Mapping[str, PwAff] = constantdict({ + name: make_pw_aff( + obj.get_at(dim).to_pw_aff() + if isinstance(obj, isl.MultiAff) + else obj.get_at(dim) + ) + for dim, name in enumerate(output_names) + }) + + if parts: + input_names = _ordered_part_dim_names(parts, isl.dim_type.in_) + parameter_names = _ordered_part_dim_names(parts, isl.dim_type.param) + else: + input_names = _ordered_multi_dim_names(obj, isl.dim_type.in_) + parameter_names = _ordered_multi_dim_names(obj, isl.dim_type.param) + + seen_names: set[str] = set() + for name in (*output_names, *input_names, *parameter_names): + if name in seen_names: + raise ValueError(f"duplicate dimension name found: {name}") + seen_names.add(name) + + all_names = [*output_names, *input_names, *parameter_names] + name_to_dim: NameToDim = constantdict({ + name: dim for dim, name in enumerate(all_names) + }) + + dimtype_to_names: dict[isl.dim_type, frozenset[str]] = {} + if input_names: + dimtype_to_names[isl.dim_type.in_] = frozenset(input_names) + if parameter_names: + dimtype_to_names[isl.dim_type.param] = frozenset(parameter_names) + + return parts, name_to_dim, constantdict(dimtype_to_names) + + +def _ordered_part_dim_names( + parts: Mapping[str, PwAff], + dim_type: isl.dim_type, +) -> tuple[str, ...]: + part_iter = iter(parts.items()) + _, first_part = next(part_iter) + ordered_names = first_part.ordered_dim_names(dim_type) + + for output_name, part in part_iter: + part_ordered_names = part.ordered_dim_names(dim_type) + if part_ordered_names != ordered_names: + raise ValueError( + f"multi expression part '{output_name}' has inconsistent " + f"{dim_type.name} dimension names" + ) + + return ordered_names + + +@dataclass(frozen=True, eq=False) +class _NamedMultiExpressionLike(Generic[PublicMultiExpressionLikeT]): + """ + Multi-expression components are stored directly as named :class:`PwAff` + parts, keyed by output name. + """ + + _obj: Mapping[str, PwAff] + _name_to_dim: NameToDim + _dimtype_to_names: DimTypeToNames + + @property + def _metadata_input_names(self) -> frozenset[str]: + return self._dimtype_to_names.get(isl.dim_type.in_, frozenset()) + + @property + def _metadata_parameter_names(self) -> frozenset[str]: + return self._dimtype_to_names.get(isl.dim_type.param, frozenset()) + + def _names_for_dim_type(self, dim_type: isl.dim_type) -> frozenset[str]: + dim_type = _normalize_public_dim_type(dim_type) + if dim_type == isl.dim_type.param: + return self.parameter_names + if dim_type == isl.dim_type.in_: + return self._metadata_input_names + if dim_type == isl.dim_type.set: + return frozenset(self._obj) + raise ValueError(f"unsupported dim type: {dim_type}") + + def _ordered_names_for_dim_type(self, dim_type: isl.dim_type) -> tuple[str, ...]: + names = self._names_for_dim_type(dim_type) + return tuple(sorted(names, key=self._name_to_dim.__getitem__)) + + @property + def names(self) -> frozenset[str]: + """ + All dimension names known to this object. + """ + return self.output_names | self.input_names | self.parameter_names + + def dim_names(self, dim_type: isl.dim_type) -> frozenset[str]: + """ + Return the names belonging to *dim_type*. + """ + return self._names_for_dim_type(dim_type) + + def ordered_dim_names(self, dim_type: isl.dim_type) -> tuple[str, ...]: + """ + Return names for *dim_type* in their current dimension order. + """ + return self._ordered_names_for_dim_type(dim_type) + + @property + def set_names(self) -> frozenset[str]: + """ + Names of set dimensions. + """ + return self._names_for_dim_type(isl.dim_type.set) + + @property + def output_names(self) -> frozenset[str]: + """ + Names of output dimensions. + """ + return self._names_for_dim_type(isl.dim_type.out) + + @property + def input_names(self) -> frozenset[str]: + """ + Names of input dimensions. + """ + return self._names_for_dim_type(isl.dim_type.in_) + + @property + def parameter_names(self) -> frozenset[str]: + """ + Names of parameter dimensions. + """ + return self._metadata_parameter_names + + def dim(self, dim_type: isl.dim_type) -> int: + """ + Return the number of dimensions of *dim_type*. + """ + dim_type = _normalize_public_dim_type(dim_type) + if dim_type in (isl.dim_type.set, isl.dim_type.in_, isl.dim_type.param): + return len(self._names_for_dim_type(dim_type)) + return self._reconstruct_isl_object().dim(dim_type) + + def get_space(self) -> isl.Space: + """ + Reconstruct and return the object's public isl space. + """ + return self._reconstruct_isl_object().get_space() + + def get_isl_object(self) -> PublicMultiExpressionLikeT: + """ + Reconstruct and return the wrapped public :mod:`islpy` object. + """ + return self._reconstruct_isl_object() + + def _reconstruct_isl_object(self) -> PublicMultiExpressionLikeT: + raise NotImplementedError + + def _multi_expression_context(self) -> isl.Context: + if self._obj: + return next(iter(self._obj.values()))._obj.get_ctx() + return isl.DEFAULT_CONTEXT + + def _multi_expression_space(self) -> isl.Space: + return isl.Space.create_from_names( + self._multi_expression_context(), + params=list(self.ordered_dim_names(isl.dim_type.param)), + in_=list(self.ordered_dim_names(isl.dim_type.in_)), + out=list(self.ordered_dim_names(isl.dim_type.out)), + ) + + def _ordered_pw_aff_parts(self) -> tuple[isl.PwAff, ...]: + return tuple( + self._obj[name]._reconstruct_isl_object() + for name in self.ordered_dim_names(isl.dim_type.out) + ) + + +@final +@dataclass(frozen=True, eq=False) +class PwMultiAff(_NamedMultiExpressionLike[isl.PwMultiAff]): + """ + Name-aware wrapper around :class:`islpy.PwMultiAff`. + + Construct instances with :func:`make_pw_multi_aff`. + """ + + def get_at(self, name: str) -> PwAff: + """ + Return the output component named *name*. + """ + if name not in self._names_for_dim_type(isl.dim_type.set): + raise ValueError(f"unknown output name: {name}") + return self._obj[name] + + @override + def _reconstruct_isl_object(self) -> isl.PwMultiAff: + space = self._multi_expression_space() + if not self._obj: + return isl.PwMultiAff.zero(space) + + pw_aff_list = isl.PwAffList.alloc( + self._multi_expression_context(), + len(self._obj), + ) + for part in self._ordered_pw_aff_parts(): + pw_aff_list = pw_aff_list.add(part) + + return isl.PwMultiAff.from_multi_pw_aff( + isl.MultiPwAff.from_pw_aff_list(space, pw_aff_list) + ) + + +@overload +def make_pw_multi_aff(src: str, ctx: isl.Context | None = None) -> PwMultiAff: + ... + + +@overload +def make_pw_multi_aff(src: isl.PwMultiAff) -> PwMultiAff: + ... + + +def make_pw_multi_aff( + src: str | isl.PwMultiAff, + ctx: isl.Context | None = None + ) -> PwMultiAff: + """ + Create a :class:`PwMultiAff` from isl syntax or an :class:`islpy.PwMultiAff`. + """ + + obj = isl.PwMultiAff(src, ctx) if isinstance(src, str) else src + pw_maff_obj, name_to_dim, dimtype_to_names = _make_multi_expression_parts(obj) + return PwMultiAff(pw_maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class MultiAff(_NamedMultiExpressionLike[isl.MultiAff]): + """ + Name-aware wrapper around :class:`islpy.MultiAff`. + + Construct instances with :func:`make_multi_aff`. + """ + + def get_at(self, name: str) -> PwAff: + """ + Return the output component named *name*. + """ + if name not in self._names_for_dim_type(isl.dim_type.set): + raise ValueError(f"unknown output name: {name}") + return self._obj[name] + + @override + def _reconstruct_isl_object(self) -> isl.MultiAff: + space = self._multi_expression_space() + if not self._obj: + return isl.MultiAff.zero(space) + + aff_list = isl.AffList.alloc( + self._multi_expression_context(), + len(self._obj), + ) + for part in self._ordered_pw_aff_parts(): + aff_list = aff_list.add(part.as_aff()) + + return isl.MultiAff.from_aff_list(space, aff_list) + + +@overload +def make_multi_aff(src: str, ctx: isl.Context | None = None) -> MultiAff: + ... + + +@overload +def make_multi_aff(src: isl.MultiAff) -> MultiAff: + ... + + +def make_multi_aff( + src: str | isl.MultiAff, ctx: isl.Context | None = None) -> MultiAff: + """ + Create a :class:`MultiAff` from isl syntax or an :class:`islpy.MultiAff`. + """ + obj = isl.MultiAff(src, ctx) if isinstance(src, str) else src + maff_obj, name_to_dim, dimtype_to_names = _make_multi_expression_parts(obj) + return MultiAff(maff_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + +# }}} diff --git a/namedisl/set_like.py b/namedisl/set_like.py new file mode 100644 index 0000000..a467a76 --- /dev/null +++ b/namedisl/set_like.py @@ -0,0 +1,990 @@ +""" +Name-aware set and map wrappers. + +The classes in this module wrap :mod:`islpy` sets and maps while making +dimension names the primary way to address axes. Internally, maps and sets are +stored as set-like isl objects with metadata that distinguishes output, input, +and parameter dimensions. +""" + +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import operator +from abc import ABC +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Any, TypeVar, cast, final, overload + +from constantdict import constantdict +from typing_extensions import Self, override + +import islpy as isl + +from .core import ( + IslSetLike, + NamedIslObject, + NameToDim, + _align_obj, + _align_two, + _find_contiguous_dim_chunks, + _make_named_object_pieces, +) + + +PublicSetLikeT_co = TypeVar("PublicSetLikeT_co", bound=IslSetLike, covariant=True) +PublicMapLikeT_co = TypeVar( + "PublicMapLikeT_co", isl.BasicMap, isl.Map, covariant=True +) + + +def _set_like_and(lhs: isl.Set, rhs: isl.Set) -> isl.Set: + return cast("isl.Set", cast("Any", operator.and_)(lhs, rhs)) + + +def _set_like_or(lhs: isl.Set, rhs: isl.Set) -> isl.Set: + return cast("isl.Set", cast("Any", operator.or_)(lhs, rhs)) + + +if TYPE_CHECKING: + from collections.abc import Callable, Collection, Sequence + + +@dataclass(frozen=True, eq=False) +class _NamedIslSetLike(NamedIslObject[isl.Set, PublicSetLikeT_co], ABC): + """ + Represents set-like objects with parameter dimensions as a non-parameterized + set. Names are organized as contiguous chunks of dimension types, i.e. + [ (set names), (input names), (parameter names) ] + """ + + def complement(self: Self) -> Self: + """ + Return the complement of this set-like object. + """ + return replace( + self, + _obj=self._obj.complement(), + _name_to_dim=self._name_to_dim, + _dimtype_to_names=self._dimtype_to_names, + ) + + @overload + def convex_hull(self: BasicMap | Map) -> BasicMap: ... + + @overload + def convex_hull(self: BasicSet | Set) -> BasicSet: ... + + def convex_hull(self) -> BasicMap | BasicSet: + """ + Return the convex hull as a basic set or basic map. + """ + result = isl.Set.from_basic_set(self._obj.convex_hull()) + if isinstance(self, _NamedIslMapLike): + return BasicMap( # pylint: disable=too-many-function-args + result, + self._name_to_dim, + self._dimtype_to_names, + ) + + return BasicSet( # pylint: disable=too-many-function-args + result, + self._name_to_dim, + self._dimtype_to_names, + ) + + def eliminate(self: Self, names_to_eliminate: str | Collection[str]) -> Self: + """ + Eliminate constraints involving the named dimensions without removing them. + """ + if isinstance(names_to_eliminate, str): + names_to_eliminate = [names_to_eliminate] + + missing_names = [ + name for name in names_to_eliminate if name not in self.names + ] + if missing_names: + raise ValueError(f"unknown names: {', '.join(missing_names)}") + + dims_to_eliminate = sorted( + self._name_to_dim[name] for name in names_to_eliminate + ) + + contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_eliminate) + + new_isl_obj = self._obj + for start in sorted(contiguous_dim_chunks): + new_isl_obj = new_isl_obj.eliminate( + isl.dim_type.set, start, contiguous_dim_chunks[start] + ) + + return replace( + self, + _obj=new_isl_obj, + _name_to_dim=self._name_to_dim, # NOTE: no dims removed by eliminate + _dimtype_to_names=self._dimtype_to_names, + ) + + def add_constraint( + self: Self, + constraints: str | Collection[str], + ) -> Self: + """ + Return a copy intersected with additional named constraints. + + :arg constraints: One constraint string or a collection of constraint + strings in isl syntax, written using this object's dimension names. + """ + if isinstance(constraints, str): + constraints = [constraints] + else: + constraints = list(constraints) + + if not constraints: + return self + + ordered_names = tuple( + sorted(self._name_to_dim, key=self._name_to_dim.__getitem__) + ) + constraint_text = " and ".join(f"({constraint})" for constraint in constraints) + constraint_src = f"{{ [{', '.join(ordered_names)}] : {constraint_text} }}" + + try: + constraint_obj = isl.Set(constraint_src) + except isl.Error as exc: + raise ValueError( + f"invalid constraint for names {ordered_names}: {constraint_text}" + ) from exc + + constraint_obj = constraint_obj.remove_redundancies() + constraint_set, constraint_name_to_dim, _ = _make_named_object_pieces( + constraint_obj + ) + assert isinstance(constraint_set, isl.Set) + + if constraint_name_to_dim != self._name_to_dim: + constraint_set = _align_obj( + Set( # pylint: disable=too-many-function-args + constraint_set, + constraint_name_to_dim, + self._dimtype_to_names, + ), + self._name_to_dim, + self._dimtype_to_names, + )._obj + assert isinstance(constraint_set, isl.Set) + + return replace( + self, + _obj=self._obj.intersect(constraint_set), + _name_to_dim=self._name_to_dim, + _dimtype_to_names=self._dimtype_to_names, + ) + + @overload + def gist( + self: BasicMap, context: _NamedIslSetLike[IslSetLike] + ) -> BasicMap | Map: ... + + @overload + def gist(self: Map, context: _NamedIslSetLike[IslSetLike]) -> Map: ... + + @overload + def gist( + self: BasicSet, context: _NamedIslSetLike[IslSetLike] + ) -> BasicSet | Set: ... + + @overload + def gist(self: Set, context: _NamedIslSetLike[IslSetLike]) -> Set: ... + + def gist( + self, context: _NamedIslSetLike[IslSetLike] + ) -> _NamedIslSetLike[IslSetLike]: + """ + Simplify this object under the assumptions described by *context*. + """ + self_aligned, context_aligned = _align_two(self, context) + result = self_aligned._obj.gist(context_aligned._obj) + + if isinstance(self, BasicMap): + result_type = BasicMap if result.n_basic_set() == 1 else Map + elif isinstance(self, Map): + result_type = Map + elif isinstance(self, BasicSet): + result_type = BasicSet if result.n_basic_set() == 1 else Set + else: + result_type = Set + + return result_type( # pylint: disable=too-many-function-args + result, + self_aligned._name_to_dim, + self_aligned._dimtype_to_names, + ) + + def project_out(self: Self, names_to_project_out: str | Collection[str]) -> Self: + """ + Return a copy with the named dimensions projected out. + """ + + if isinstance(names_to_project_out, str): + names_to_project_out = [names_to_project_out] + + missing_names = [ + name for name in names_to_project_out if name not in self.names + ] + if missing_names: + raise ValueError(f"unknown names: {', '.join(missing_names)}") + + names_to_remove = set(names_to_project_out) + + dims_to_remove = sorted(self._name_to_dim[name] for name in names_to_remove) + + new_isl_obj = self._obj + contiguous_dim_chunks = _find_contiguous_dim_chunks(dims_to_remove) + for start in sorted(contiguous_dim_chunks, reverse=True): + new_isl_obj = new_isl_obj.project_out( + isl.dim_type.set, start, contiguous_dim_chunks[start] + ) + + new_name_to_dim: NameToDim = {} + for name, dim in self._name_to_dim.items(): + if name in names_to_remove: + continue + + shift = 0 + for removed_dim in dims_to_remove: + if removed_dim < dim: + shift += 1 + else: + break + + new_name_to_dim[name] = dim - shift + + new_type_to_names = constantdict({ + dt: self._dimtype_to_names[dt] - frozenset(names_to_remove) + for dt in self._dimtype_to_names + }) + + return replace( + self, + _obj=new_isl_obj, + _name_to_dim=constantdict(new_name_to_dim), + _dimtype_to_names=new_type_to_names, + ) + + def project_out_except( + self: Self, + names_to_keep: str | Collection[str], + ) -> Self: + """ + Project out every dimension except those named in *names_to_keep*. + """ + + if isinstance(names_to_keep, str): + names_to_keep = [names_to_keep] if names_to_keep else [] + + names_to_project_out = [ + name for name in self._name_to_dim if name not in names_to_keep + ] + + return self.project_out(names_to_project_out) + + def dim_max(self, name: str) -> isl.PwAff: + """ + Return the parametric maximum of the named dimension. + """ + obj, dim = self._dim_bound_object_and_dim(name) + return obj.dim_max(dim) + + def dim_min(self, name: str) -> isl.PwAff: + """ + Return the parametric minimum of the named dimension. + """ + obj, dim = self._dim_bound_object_and_dim(name) + return obj.dim_min(dim) + + def _dim_bound_object_and_dim( + self, name: str + ) -> tuple[isl.BasicSet | isl.Set, int]: + if name not in self.names: + raise ValueError(f"unknown name: {name}") + if name in self.parameter_names: + raise ValueError(f"cannot compute a bound for parameter: {name}") + + if isinstance(self, _NamedIslMapLike): + bound_set = self.domain() if name in self.input_names else self.range() + obj = bound_set._reconstruct_isl_object() + assert isinstance(obj, isl.BasicSet | isl.Set) + return obj, bound_set._name_to_dim[name] + + obj = self._reconstruct_isl_object() + assert isinstance(obj, isl.BasicSet | isl.Set) + return obj, self._name_to_dim[name] + + def is_empty(self) -> bool: + """ + Return whether this object contains no integer points. + """ + return bool(self._obj.is_empty()) + + def as_pw_multi_aff(self) -> isl.PwMultiAff: + """ + Reconstruct and convert this object to :class:`islpy.PwMultiAff`. + """ + obj = self._reconstruct_isl_object() + assert isinstance(obj, isl.Set | isl.Map) + return obj.as_pw_multi_aff() + + @override + def dim(self, dim_type: isl.dim_type) -> int: + if dim_type == isl.dim_type.out: + dim_type = isl.dim_type.set + return super().dim(dim_type) + + @overload + def __and__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: ... + + @overload + def __and__(self: Map, other: BasicMap | Map) -> Map: ... + + @overload + def __and__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... + + @overload + def __and__(self: Set, other: BasicSet | Set) -> Set: ... + + def __and__( + self, other: _NamedIslSetLike[IslSetLike] + ) -> _NamedIslSetLike[IslSetLike]: + """ + Return the intersection of two compatible named set-like objects. + """ + return _apply_set_like_binary_op(self, other, _set_like_and) + + @overload + def __or__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: ... + + @overload + def __or__(self: Map, other: BasicMap | Map) -> Map: ... + + @overload + def __or__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... + + @overload + def __or__(self: Set, other: BasicSet | Set) -> Set: ... + + def __or__( + self, other: _NamedIslSetLike[IslSetLike] + ) -> _NamedIslSetLike[IslSetLike]: + """ + Return the union of two compatible named set-like objects. + """ + return _apply_set_like_binary_op(self, other, _set_like_or) + + @overload + def __sub__(self: BasicMap, other: BasicMap | Map) -> BasicMap | Map: ... + + @overload + def __sub__(self: Map, other: BasicMap | Map) -> Map: ... + + @overload + def __sub__(self: BasicSet, other: BasicSet | Set) -> BasicSet | Set: ... + + @overload + def __sub__(self: Set, other: BasicSet | Set) -> Set: ... + + def __sub__( + self, other: _NamedIslSetLike[IslSetLike] + ) -> _NamedIslSetLike[IslSetLike]: + """ + Return the set difference with *other* removed. + """ + return _apply_set_like_binary_op(self, other, operator.sub) + + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, type(self)): + raise NotImplementedError("Objects are not of the same type") + + aligned_self, aligned_other = _align_two(self, other) + + # FIXME: type checker complains because it's not clear whether the + # underlying object after alignment is an isl.Set + assert isinstance(aligned_self._obj, isl.Set) + assert isinstance(aligned_other._obj, isl.Set) + return aligned_self._obj.plain_is_equal(aligned_other._obj) + + def __lt__(self, other: _NamedIslSetLike[IslSetLike]) -> bool: + """ + Return whether this object is a strict subset of *other*. + """ + return _compare_set_like(self, other, isl.Set.is_strict_subset) + + def __le__(self, other: _NamedIslSetLike[IslSetLike]) -> bool: + """ + Return whether this object is a subset of *other*. + """ + return _compare_set_like(self, other, isl.Set.is_subset) + + def __gt__(self, other: _NamedIslSetLike[IslSetLike]) -> bool: + """ + Return whether this object is a strict superset of *other*. + """ + return _compare_set_like(other, self, isl.Set.is_strict_subset) + + def __ge__(self, other: _NamedIslSetLike[IslSetLike]) -> bool: + """ + Return whether this object is a superset of *other*. + """ + return _compare_set_like(other, self, isl.Set.is_subset) + + +@final +@dataclass(frozen=True, eq=False) +class BasicSet(_NamedIslSetLike[isl.BasicSet]): + """ + Name-aware wrapper around :class:`islpy.BasicSet`. + + Construct instances with :func:`make_basic_set`. + """ + + @override + def add_input_names(self, names_to_add: Collection[str]) -> BasicSet: + raise NotImplementedError + + @override + def _reconstruct_isl_object(self) -> isl.BasicSet: + obj = super()._reconstruct_isl_object() + + if not isinstance(obj, isl.Set) or obj.n_basic_set() != 1: + raise ValueError( + "Cannot reconstruct an isl.BasicSet from anything other than " + "an isl.Set containing only a single isl.BasicSet." + ) + + return obj.get_basic_sets()[0] + + +@overload +def make_basic_set(src: str, ctx: isl.Context | None = None) -> BasicSet: ... + + +@overload +def make_basic_set(src: isl.BasicSet) -> BasicSet: ... + + +def make_basic_set(src: str | isl.BasicSet, ctx: isl.Context | None = None) -> BasicSet: + """ + Create a :class:`BasicSet` from isl syntax or an :class:`islpy.BasicSet`. + """ + obj = isl.BasicSet(src, ctx) if isinstance(src, str) else src + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(set_obj, isl.Set) + return BasicSet(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class Set(_NamedIslSetLike[isl.Set]): + """ + Name-aware wrapper around :class:`islpy.Set`. + + Construct instances with :func:`make_set`. + """ + + @override + def add_input_names(self, names_to_add: Collection[str]) -> Set: + raise NotImplementedError + + @override + def _reconstruct_isl_object(self) -> isl.Set: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.Set) + return obj + + def get_basic_sets(self) -> Sequence[BasicSet]: + """ + Return the basic-set pieces of this set. + """ + isl_obj = self._reconstruct_isl_object() + + bsets = isl_obj.get_basic_sets() + return [make_basic_set(bset) for bset in bsets] + + +@overload +def _apply_set_like_binary_op( + lhs: BasicMap, rhs: BasicMap | Map, op: Callable[[isl.Set, isl.Set], isl.Set] +) -> BasicMap | Map: ... + + +@overload +def _apply_set_like_binary_op( + lhs: Map, rhs: BasicMap | Map, op: Callable[[isl.Set, isl.Set], isl.Set] +) -> Map: ... + + +@overload +def _apply_set_like_binary_op( + lhs: BasicSet, rhs: BasicSet | Set, op: Callable[[isl.Set, isl.Set], isl.Set] +) -> BasicSet | Set: ... + + +@overload +def _apply_set_like_binary_op( + lhs: Set, rhs: BasicSet | Set, op: Callable[[isl.Set, isl.Set], isl.Set] +) -> Set: ... + + +@overload +def _apply_set_like_binary_op( + lhs: _NamedIslSetLike[IslSetLike], + rhs: _NamedIslSetLike[IslSetLike], + op: Callable[[isl.Set, isl.Set], isl.Set], +) -> _NamedIslSetLike[IslSetLike]: ... + + +def _apply_set_like_binary_op( + lhs: _NamedIslSetLike[IslSetLike], + rhs: _NamedIslSetLike[IslSetLike], + op: Callable[[isl.Set, isl.Set], isl.Set], +) -> _NamedIslSetLike[IslSetLike]: + lhs, rhs = _align_two(lhs, rhs) + result = op(lhs._obj, rhs._obj) + + if isinstance(lhs, BasicMap) and isinstance(rhs, BasicMap): + result_type = BasicMap if result.n_basic_set() == 1 else Map + elif isinstance(lhs, BasicSet) and isinstance(rhs, BasicSet): + result_type = BasicSet if result.n_basic_set() == 1 else Set + elif isinstance(lhs, (BasicMap, Map)) and isinstance(rhs, (BasicMap, Map)): + result_type = Map + else: + result_type = Set + + return result_type( # pylint: disable=too-many-function-args + result, + lhs._name_to_dim, + lhs._dimtype_to_names, + ) + + +def _compare_set_like( + lhs: _NamedIslSetLike[IslSetLike], + rhs: _NamedIslSetLike[IslSetLike], + op: Callable[[isl.Set, isl.Set], bool], +) -> bool: + lhs_is_map = isinstance(lhs, _NamedIslMapLike) + rhs_is_map = isinstance(rhs, _NamedIslMapLike) + if lhs_is_map != rhs_is_map: + raise TypeError("Cannot compare set-like and map-like objects") + + aligned_lhs, aligned_rhs = _align_two(lhs, rhs) + + assert isinstance(aligned_lhs._obj, isl.Set) + assert isinstance(aligned_rhs._obj, isl.Set) + return op(aligned_lhs._obj, aligned_rhs._obj) + + +class _NamedIslMapLike(_NamedIslSetLike[PublicMapLikeT_co]): + @override + def _reconstruct_isl_object(self) -> PublicMapLikeT_co: + obj = super()._reconstruct_isl_object() + if isinstance(obj, isl.Set): + return cast( + "PublicMapLikeT_co", + isl.Map.from_domain_and_range(isl.Set("{ [] }"), obj), + ) + assert isinstance(obj, isl.BasicMap | isl.Map) + return cast("PublicMapLikeT_co", obj) + + def _output_names(self) -> frozenset[str]: + return frozenset(self._name_to_dim) - self.input_names - self.parameter_names + + def _map_obj(self) -> isl.BasicMap | isl.Map: + return self._reconstruct_isl_object() + + @staticmethod + def _wrap_map_result(result: isl.BasicMap | isl.Map) -> BasicMap | Map: + if isinstance(result, isl.BasicMap): + return make_basic_map(result) + return make_map(result) + + def _map_with_universe( + self, dim_type: isl.dim_type, set_obj: isl.BasicSet | isl.Set + ) -> BasicMap | Map: + map_obj = self._map_obj() + if dim_type == isl.dim_type.in_: + universe = isl.Set.universe(map_obj.range().get_space()) + return make_map(isl.Map.from_domain_and_range(set_obj, universe)) + if dim_type == isl.dim_type.out: + universe = isl.Set.universe(map_obj.domain().get_space()) + return make_map(isl.Map.from_domain_and_range(universe, set_obj)) + raise ValueError(f"unsupported dim type: {dim_type}") + + def _ordered_names(self, names: frozenset[str]) -> tuple[str, ...]: + return tuple(sorted(names, key=self._name_to_dim.__getitem__)) + + def _reject_surviving_name_collisions( + self, + collisions: frozenset[str], + ) -> None: + if collisions: + raise ValueError( + "composition would create duplicate surviving names: " + + ", ".join(sorted(collisions)) + ) + + def _reorder_interface( + self, dim_type: isl.dim_type, ordered_names: tuple[str, ...] + ) -> _NamedIslMapLike[PublicMapLikeT_co]: + interface_names = ( + self.input_names if dim_type == isl.dim_type.in_ else self._output_names() + ) + current_names = self._ordered_names(interface_names) + if current_names == ordered_names: + return self + + out_names = ( + ordered_names + if dim_type == isl.dim_type.out + else self._ordered_names(self._output_names()) + ) + in_names = ( + ordered_names + if dim_type == isl.dim_type.in_ + else self._ordered_names(self.input_names) + ) + param_names = self._ordered_names(self.parameter_names) + + ordering: NameToDim = constantdict({ + name: dim for dim, name in enumerate((*out_names, *in_names, *param_names)) + }) + + return _align_obj(self, ordering, self._dimtype_to_names) + + def _validate_composable( + self, + lhs_dim_type: isl.dim_type, + other: BasicMap | Map, + rhs_dim_type: isl.dim_type, + ) -> tuple[str, ...]: + lhs_names = ( + self.input_names + if lhs_dim_type == isl.dim_type.in_ + else self._output_names() + ) + rhs_names = ( + other.input_names + if rhs_dim_type == isl.dim_type.in_ + else other._output_names() + ) + if lhs_names != rhs_names: + raise ValueError("maps are not composable: interface names differ") + return self._ordered_names(lhs_names) + + def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: + """ + Return this map restricted to *domain*. + """ + domain_obj = domain._reconstruct_isl_object() + assert isinstance(domain_obj, isl.BasicSet | isl.Set) + result = _apply_set_like_binary_op( + self, + self._map_with_universe( + isl.dim_type.in_, + domain_obj, + ), + _set_like_and, + ) + assert isinstance(result, BasicMap | Map) + return result + + def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: + """ + Return this map restricted to *range_*. + """ + range_obj = range_._reconstruct_isl_object() + assert isinstance(range_obj, isl.BasicSet | isl.Set) + result = _apply_set_like_binary_op( + self, + self._map_with_universe( + isl.dim_type.out, + range_obj, + ), + _set_like_and, + ) + assert isinstance(result, BasicMap | Map) + return result + + def apply_range(self, other: BasicMap | Map) -> BasicMap | Map: + """ + Compose this map with *other* on this map's range. + + The output names of this map must match the input names of *other*. + """ + ordered_names = self._validate_composable( + isl.dim_type.out, other, isl.dim_type.in_ + ) + reordered_other = other._reorder_interface(isl.dim_type.in_, ordered_names) + assert isinstance(reordered_other, BasicMap | Map) + self._reject_surviving_name_collisions( + self.input_names & reordered_other._output_names() + ) + result = self._map_obj().apply_range(reordered_other._map_obj()) + return self._wrap_map_result(result) + + def apply_domain(self, other: BasicMap | Map) -> BasicMap | Map: + """ + Compose *other* with this map on this map's domain. + + The input names of this map must match the output names of *other*. + """ + ordered_names = self._validate_composable( + isl.dim_type.in_, other, isl.dim_type.out + ) + reordered_other = other._reorder_interface(isl.dim_type.out, ordered_names) + assert isinstance(reordered_other, BasicMap | Map) + self._reject_surviving_name_collisions( + reordered_other.input_names & self._output_names() + ) + result = reordered_other._map_obj().apply_range(self._map_obj()) + return self._wrap_map_result(result) + + def reverse(self) -> BasicMap | Map: + """ + Return the map with domain and range exchanged. + """ + return self._wrap_map_result(self._map_obj().reverse()) + + def domain(self) -> BasicSet | Set: + """ + Return the domain as a named set. + """ + domain = self._map_obj().domain() + if isinstance(domain, isl.BasicSet): + return make_basic_set(domain) + return make_set(domain) + + def range(self) -> BasicSet | Set: + """ + Return the range as a named set. + """ + range_ = self._map_obj().range() + if isinstance(range_, isl.BasicSet): + return make_basic_set(range_) + return make_set(range_) + + +@overload +def make_set(src: str, ctx: isl.Context | None = None) -> Set: ... + + +@overload +def make_set(src: isl.Set) -> Set: ... + + +def make_set(src: isl.Set | str, ctx: isl.Context | None = None) -> Set: + """ + Create a :class:`Set` from isl syntax or an :class:`islpy.Set`. + """ + obj = isl.Set(src, ctx) if isinstance(src, str) else src + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(set_obj, isl.Set) + return Set(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +@final +@dataclass(frozen=True, eq=False) +class BasicMap(_NamedIslMapLike[isl.BasicMap]): + """ + Name-aware wrapper around :class:`islpy.BasicMap`. + + Construct instances with :func:`make_basic_map`. + """ + + @classmethod + def empty(cls, space: isl.Space) -> BasicMap: + """ + Return an empty :class:`BasicMap` in *space*. + """ + obj = isl.BasicMap.empty(space) + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(set_obj, isl.Set) + return cls( # pylint: disable=too-many-function-args + set_obj, + name_to_dim, + dimtype_to_names, + ) + + @override + def _map_obj(self) -> isl.BasicMap: + obj = self._reconstruct_isl_object() + assert isinstance(obj, isl.BasicMap) + return obj + + @override + def domain(self) -> BasicSet: + return make_basic_set(self._map_obj().domain()) + + @override + def range(self) -> BasicSet: + return make_basic_set(self._map_obj().range()) + + @override + def intersect_domain(self, domain: BasicSet | Set) -> BasicMap | Map: + if isinstance(domain, BasicSet): + range_space = self._map_obj().range().get_space() + filter_map = make_basic_map( + isl.BasicMap.from_domain_and_range( + domain._reconstruct_isl_object(), isl.BasicSet.universe(range_space) + ) + ) + result = self & filter_map + assert isinstance(result, BasicMap | Map) + return result + return super().intersect_domain(domain) + + @override + def intersect_range(self, range_: BasicSet | Set) -> BasicMap | Map: + if isinstance(range_, BasicSet): + domain_space = self._map_obj().domain().get_space() + filter_map = make_basic_map( + isl.BasicMap.from_domain_and_range( + isl.BasicSet.universe(domain_space), + range_._reconstruct_isl_object(), + ) + ) + result = self & filter_map + assert isinstance(result, BasicMap | Map) + return result + return super().intersect_range(range_) + + @override + def _reconstruct_isl_object(self) -> isl.BasicMap: + obj = super()._reconstruct_isl_object() + + if isinstance(obj, isl.Map) and obj.is_empty(): + return isl.BasicMap.empty(obj.get_space()) + + if not isinstance(obj, isl.Map) or obj.n_basic_map() != 1: + raise ValueError( + "Cannot reconstruct an isl.BasicMap from anything other than " + "an isl.Map containing only a single isl.BasicMap." + ) + + return obj.get_basic_maps()[0] + + +@overload +def make_basic_map(src: str, ctx: isl.Context | None = None) -> BasicMap: ... + + +@overload +def make_basic_map(src: isl.BasicMap) -> BasicMap: ... + + +def make_basic_map(src: str | isl.BasicMap, ctx: isl.Context | None = None) -> BasicMap: + """ + Create a :class:`BasicMap` from isl syntax or an :class:`islpy.BasicMap`. + """ + obj = isl.BasicMap(src, ctx) if isinstance(src, str) else src + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(set_obj, isl.Set) + return BasicMap(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args + + +def make_map_from_domain_and_range( + domain: BasicSet | Set, range_: BasicSet | Set +) -> BasicMap | Map: + """ + Create a named map from a named *domain* and named *range_*. + + A :class:`BasicMap` is returned when both inputs are basic sets; otherwise a + :class:`Map` is returned. + """ + if isinstance(domain, BasicSet) and isinstance(range_, BasicSet): + domain_obj = domain._reconstruct_isl_object() + range_obj = range_._reconstruct_isl_object() + assert isinstance(domain_obj, isl.BasicSet) + assert isinstance(range_obj, isl.BasicSet) + return make_basic_map( + isl.BasicMap.from_domain_and_range( + domain_obj, + range_obj, + ) + ) + + domain_obj = domain._reconstruct_isl_object() + range_obj = range_._reconstruct_isl_object() + assert isinstance(domain_obj, isl.BasicSet | isl.Set) + assert isinstance(range_obj, isl.BasicSet | isl.Set) + return make_map( + isl.Map.from_domain_and_range( + domain_obj, + range_obj, + ) + ) + + +@final +@dataclass(frozen=True, eq=False) +class Map(_NamedIslMapLike[isl.Map]): + """ + Name-aware wrapper around :class:`islpy.Map`. + + Construct instances with :func:`make_map`. + """ + + @classmethod + def empty(cls, space: isl.Space) -> Map: + """ + Return an empty :class:`Map` in *space*. + """ + return make_map(isl.Map.empty(space)) + + @override + def _reconstruct_isl_object(self) -> isl.Map: + obj = super()._reconstruct_isl_object() + assert isinstance(obj, isl.Map) + return obj + + +@overload +def make_map(src: str, ctx: isl.Context | None = None) -> Map: ... + + +@overload +def make_map(src: isl.Map) -> Map: ... + + +def make_map(src: str | isl.Map, ctx: isl.Context | None = None) -> Map: + """ + Create a :class:`Map` from isl syntax or an :class:`islpy.Map`. + """ + obj = isl.Map(src, ctx) if isinstance(src, str) else src + set_obj, name_to_dim, dimtype_to_names = _make_named_object_pieces(obj) + assert isinstance(set_obj, isl.Set) + return Map(set_obj, name_to_dim, dimtype_to_names) # pylint: disable=too-many-function-args diff --git a/namedisl/test/test_expression_like.py b/namedisl/test/test_expression_like.py new file mode 100644 index 0000000..45b5c08 --- /dev/null +++ b/namedisl/test/test_expression_like.py @@ -0,0 +1,546 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import pytest +from constantdict import constantdict + +import islpy as isl + +import namedisl as nisl + + +ScalarIslExpression = isl.Aff | isl.PwAff | isl.QPolynomial | isl.PwQPolynomial + + +def _is_zero_expression(expr: ScalarIslExpression) -> bool: + if isinstance(expr, isl.Aff): + return bool(expr.plain_is_zero()) + if isinstance(expr, isl.PwAff): + return bool(expr.plain_is_equal(expr * 0)) + return bool(expr.is_zero()) + + +# {{{ affs + +def test_aff_from_str(): + spec = "[n] -> { [i] -> [2 * i + n] }" + named_aff = nisl.make_aff(spec) + aff = isl.Aff(spec) + + print(named_aff) + print(aff) + + assert aff == named_aff._reconstruct_isl_object() + + +def test_aff_from_aff(): + aff = isl.Aff("[n] -> { [i] -> [2 * i + n] }") + named_aff = nisl.make_aff(aff) + + print(named_aff) + print(aff) + + assert aff == named_aff._reconstruct_isl_object() + + +def test_aff_binary_ops(): + spec = "[n] -> { [i] -> [2 * i + n] }" + named_aff = nisl.make_aff(spec) + aff = isl.Aff(spec) + + aff_p_aff = aff + aff + aff_p_1 = aff + 1 + + naff_p_naff = named_aff + named_aff + naff_p_1 = named_aff + 1 + + assert naff_p_naff._reconstruct_isl_object() == aff_p_aff + assert naff_p_1._reconstruct_isl_object() == aff_p_1 + + aff_s_aff = aff - aff + aff_s_1 = aff - 1 + + naff_s_naff = named_aff - named_aff + naff_s_1 = named_aff - 1 + + assert naff_s_naff._reconstruct_isl_object() == aff_s_aff + assert naff_s_1._reconstruct_isl_object() == aff_s_1 + + aff_m_3 = aff * 3 + naff_m_3 = named_aff * 3 + + assert naff_m_3._reconstruct_isl_object() == aff_m_3 + +# }}} + + +# {{{ pwaffs + +def test_pwaff_from_str(): + spec = "[n] -> { [i] -> [2 * i + n] }" + named_pwaff = nisl.make_pw_aff(spec) + pwaff = isl.PwAff(spec) + + print(named_pwaff) + print(pwaff) + + assert pwaff == named_pwaff._reconstruct_isl_object() + + +def test_pwaff_from_pwaff(): + pwaff = isl.PwAff("[n] -> { [i] -> [2 * i + n] }") + named_pwaff = nisl.make_pw_aff(pwaff) + + print(named_pwaff) + print(pwaff) + + assert pwaff == named_pwaff._reconstruct_isl_object() + + +def test_pwaff_binary_ops(): + spec = "[n] -> { [i] -> [2 * i + n] }" + named_pwaff = nisl.make_pw_aff(spec) + pwaff = isl.PwAff(spec) + + pwaff_p_pwaff = pwaff + pwaff + pwaff_p_1 = pwaff + 1 + + npwaff_p_npwaff = named_pwaff + named_pwaff + npwaff_p_1 = named_pwaff + 1 + + assert npwaff_p_npwaff._reconstruct_isl_object() == pwaff_p_pwaff + assert npwaff_p_1._reconstruct_isl_object() == pwaff_p_1 + + pwaff_s_pwaff = pwaff - pwaff + pwaff_s_1 = pwaff - 1 + + npwaff_s_npwaff = named_pwaff - named_pwaff + npwaff_s_1 = named_pwaff - 1 + + assert npwaff_s_npwaff._reconstruct_isl_object() == pwaff_s_pwaff + assert npwaff_s_1._reconstruct_isl_object() == pwaff_s_1 + + pwaff_m_3 = pwaff * 3 + npwaff_m_3 = named_pwaff * 3 + + assert npwaff_m_3._reconstruct_isl_object() == pwaff_m_3 + + +def test_mixed_aff_and_pwaff_binary_op_promotes_to_pwaff() -> None: + aff = nisl.make_aff("{ [i] -> [i] }") + pwaff = nisl.make_pw_aff("{ [i] -> [i] }") + + result = aff + pwaff + + assert isinstance(result, nisl.PwAff) + assert result._reconstruct_isl_object() == ( + aff._reconstruct_isl_object().to_pw_aff() + pwaff._reconstruct_isl_object() + ) + + +def test_expression_equality_type_mismatch_raises_not_implemented_error() -> None: + aff = nisl.make_aff("{ [i] -> [i] }") + pwaff = nisl.make_pw_aff("{ [i] -> [i] }") + + with pytest.raises(NotImplementedError, match="Objects are not of the same type"): + _ = aff == pwaff + + +def test_reflected_integer_expression_ops() -> None: + aff_expr = nisl.make_aff("{ [i] -> [i] }") + aff_obj = aff_expr._reconstruct_isl_object() + assert _is_zero_expression((1 + aff_expr)._reconstruct_isl_object() - (1 + aff_obj)) + assert _is_zero_expression((1 - aff_expr)._reconstruct_isl_object() - (1 - aff_obj)) + assert _is_zero_expression((2 * aff_expr)._reconstruct_isl_object() - (2 * aff_obj)) + + pw_aff_expr = nisl.make_pw_aff("{ [i] -> [i] }") + pw_aff_obj = pw_aff_expr._reconstruct_isl_object() + assert _is_zero_expression( + (1 + pw_aff_expr)._reconstruct_isl_object() - (1 + pw_aff_obj) + ) + assert _is_zero_expression( + (1 - pw_aff_expr)._reconstruct_isl_object() - (1 - pw_aff_obj) + ) + assert _is_zero_expression( + (2 * pw_aff_expr)._reconstruct_isl_object() - (2 * pw_aff_obj) + ) + + qpoly_expr = nisl.make_qpolynomial("{ [i] -> i }") + qpoly_obj = qpoly_expr._reconstruct_isl_object() + assert _is_zero_expression( + (1 + qpoly_expr)._reconstruct_isl_object() - (1 + qpoly_obj) + ) + assert _is_zero_expression( + (1 - qpoly_expr)._reconstruct_isl_object() - (1 - qpoly_obj) + ) + assert _is_zero_expression( + (2 * qpoly_expr)._reconstruct_isl_object() - (2 * qpoly_obj) + ) + + pw_qpoly_expr = nisl.make_pw_qpolynomial("{ [i] -> i }") + pw_qpoly_obj = pw_qpoly_expr._reconstruct_isl_object() + assert _is_zero_expression( + (1 + pw_qpoly_expr)._reconstruct_isl_object() - (1 + pw_qpoly_obj) + ) + assert _is_zero_expression( + (1 - pw_qpoly_expr)._reconstruct_isl_object() - (1 - pw_qpoly_obj) + ) + assert _is_zero_expression( + (2 * pw_qpoly_expr)._reconstruct_isl_object() - (2 * pw_qpoly_obj) + ) + + +def _qpolynomial(spec: str) -> isl.QPolynomial: + return isl.PwQPolynomial(spec).get_pieces()[0][1] + + +def _assert_expression_equal( + actual: ScalarIslExpression, expected: ScalarIslExpression +) -> None: + if isinstance(actual, isl.Aff | isl.PwAff): + assert actual == expected + return + + if isinstance(actual, isl.QPolynomial) and isinstance(expected, isl.QPolynomial): + assert (actual - expected).is_zero() + return + + if isinstance(actual, isl.PwQPolynomial) and isinstance( + expected, isl.PwQPolynomial + ): + assert (actual - expected).is_zero() + return + + raise TypeError( + f"cannot compare {type(actual).__name__} and {type(expected).__name__}" + ) + + +def test_move_dims_expression_param_to_input_reconstructs_like_isl() -> None: + cases = ( + ( + nisl.make_aff("[n] -> { [i] -> [i + n] }"), + isl.Aff("[n] -> { [i] -> [i + n] }"), + ), + ( + nisl.make_pw_aff("[n] -> { [i] -> [i + n] }"), + isl.PwAff("[n] -> { [i] -> [i + n] }"), + ), + ( + nisl.make_qpolynomial("[n] -> { [i] -> i + n }"), + _qpolynomial("[n] -> { [i] -> i + n }"), + ), + ( + nisl.make_pw_qpolynomial("[n] -> { [i] -> i + n }"), + isl.PwQPolynomial("[n] -> { [i] -> i + n }"), + ), + ) + + for named_expr, isl_expr in cases: + moved = named_expr.move_dims("n", isl.dim_type.in_) + expected = isl_expr.move_dims( + isl.dim_type.in_, + 1, + isl.dim_type.param, + 0, + 1, + ) + + _assert_expression_equal(moved._reconstruct_isl_object(), expected) + assert moved.input_names == frozenset({"i", "n"}) + assert moved.parameter_names == frozenset() + + +def test_move_dims_expression_input_to_param_reconstructs_like_isl() -> None: + cases = ( + ( + nisl.make_aff("[n] -> { [i] -> [i + n] }"), + isl.Aff("[n] -> { [i] -> [i + n] }"), + ), + ( + nisl.make_pw_aff("[n] -> { [i] -> [i + n] }"), + isl.PwAff("[n] -> { [i] -> [i + n] }"), + ), + ( + nisl.make_qpolynomial("[n] -> { [i] -> i + n }"), + _qpolynomial("[n] -> { [i] -> i + n }"), + ), + ( + nisl.make_pw_qpolynomial("[n] -> { [i] -> i + n }"), + isl.PwQPolynomial("[n] -> { [i] -> i + n }"), + ), + ) + + for named_expr, isl_expr in cases: + moved = named_expr.move_dims("i", isl.dim_type.param) + expected = isl_expr.move_dims( + isl.dim_type.param, + 1, + isl.dim_type.in_, + 0, + 1, + ) + + _assert_expression_equal(moved._reconstruct_isl_object(), expected) + assert moved.input_names == frozenset() + assert moved.parameter_names == frozenset({"n", "i"}) + +# }}} + + +def test_multi_aff_get_at_uses_name() -> None: + map_ = nisl.make_map("{ [i] -> [x = i, y = 2i] }") + maff = nisl.make_multi_aff( + map_._reconstruct_isl_object().as_pw_multi_aff().as_multi_aff() + ) + assert maff.get_at("x")._reconstruct_isl_object() == isl.PwAff("{ [i] -> [(i)] }") + + +def test_pw_multi_aff_get_at_uses_name() -> None: + map_ = nisl.make_map("{ [i] -> [x = i, y = 2i] }") + pmaff = nisl.make_pw_multi_aff(map_.as_pw_multi_aff()) + assert pmaff.get_at("y")._reconstruct_isl_object() == isl.PwAff("{ [i] -> [(2i)] }") + + +def test_multi_aff_stores_pw_aff_parts() -> None: + raw_maff = isl.MultiAff("{ [i] -> [x = i, y = 2i] }") + + maff = nisl.make_multi_aff(raw_maff) + + assert isinstance(maff._obj, constantdict) + assert not isinstance(maff._obj, isl.Set) + assert all(isinstance(part, nisl.PwAff) for part in maff._obj.values()) + assert maff.get_at("x") is maff._obj["x"] + assert maff.get_at("y") is maff._obj["y"] + assert maff.output_names == frozenset(maff._obj) + assert maff.input_names == maff.get_at("x").input_names + assert maff.parameter_names == maff.get_at("x").parameter_names + assert maff._reconstruct_isl_object() == raw_maff + + +def test_pw_multi_aff_stores_pw_aff_parts() -> None: + raw_pmaff = isl.PwMultiAff("[n] -> { [i] -> [x = i + n, y = 2i] }") + + pmaff = nisl.make_pw_multi_aff(raw_pmaff) + + assert isinstance(pmaff._obj, constantdict) + assert not isinstance(pmaff._obj, isl.Set) + assert all(isinstance(part, nisl.PwAff) for part in pmaff._obj.values()) + assert pmaff.get_at("x") is pmaff._obj["x"] + assert pmaff.get_at("y") is pmaff._obj["y"] + assert pmaff.output_names == frozenset(pmaff._obj) + assert pmaff.input_names == pmaff.get_at("x").input_names + assert pmaff.parameter_names == pmaff.get_at("x").parameter_names + assert pmaff._reconstruct_isl_object() == raw_pmaff + + +# {{{ qpolynomials + +def test_qpolynomial_from_str(): + spec = "[n] -> { [i] -> 2 * i + n }" + named_qpolynomial = nisl.make_qpolynomial(spec) + qpolynomial = isl.PwQPolynomial(spec).get_pieces()[0][1] + + print(named_qpolynomial) + print(qpolynomial) + + assert (named_qpolynomial._reconstruct_isl_object() - qpolynomial).is_zero() + + +def test_qpolynomial_from_qpolynomial(): + qpolynomial = isl.PwQPolynomial( + "[n] -> { [i] -> 2 * i + n }").get_pieces()[0][1] + named_qpolynomial = nisl.make_qpolynomial(qpolynomial) + + print(named_qpolynomial) + print(qpolynomial) + + assert (named_qpolynomial._reconstruct_isl_object() - qpolynomial).is_zero() + + +def test_qpolynomial_binary_ops(): + spec = "[n] -> { [i] -> 2 * i + n }" + named_qpolynomial = nisl.make_qpolynomial(spec) + qpolynomial = isl.PwQPolynomial(spec).get_pieces()[0][1] + + qpolynomial_p_qpolynomial = qpolynomial + qpolynomial + qpolynomial_p_1 = qpolynomial + 1 + + nqpolynomial_p_nqpolynomial = named_qpolynomial + named_qpolynomial + nqpolynomial_p_1 = named_qpolynomial + 1 + + assert ( + nqpolynomial_p_nqpolynomial._reconstruct_isl_object() + - + qpolynomial_p_qpolynomial + ).is_zero() + assert ( + nqpolynomial_p_1._reconstruct_isl_object() + - + qpolynomial_p_1 + ).is_zero() + + qpolynomial_s_qpolynomial = qpolynomial - qpolynomial + qpolynomial_s_1 = qpolynomial - 1 + + nqpolynomial_s_nqpolynomial = named_qpolynomial - named_qpolynomial + nqpolynomial_s_1 = named_qpolynomial - 1 + + assert ( + nqpolynomial_s_nqpolynomial._reconstruct_isl_object() + - + qpolynomial_s_qpolynomial + ).is_zero() + assert ( + nqpolynomial_s_1._reconstruct_isl_object() + - + qpolynomial_s_1 + ).is_zero() + + qpolynomial_m_3 = qpolynomial * 3 + nqpolynomial_m_3 = named_qpolynomial * 3 + + assert ( + nqpolynomial_m_3._reconstruct_isl_object() + - + qpolynomial_m_3 + ).is_zero() + + +def test_qpolynomial_permuted_is_zero(): + # forces the objects to be aligned before binary op is applied, whereas + # the test above does not necessarily touch alignment code + + named_qp = nisl.make_qpolynomial( + "[n, m] -> { [a, b] -> a*a + 2*b + n - m }") + named_qp_perm = nisl.make_qpolynomial( + "[m, n] -> { [b, a] -> a*a + 2*b + n - m }") + + assert ( + named_qp - named_qp_perm + )._reconstruct_isl_object().is_zero() + +# }}} + + +# {{{ pwqpolynomials + +def test_pw_qp_from_str(): + spec = "[n] -> { [i] -> 2 * i + n }" + named_pw_qp = nisl.make_pw_qpolynomial(spec) + pw_qp = isl.PwQPolynomial(spec) + + print(named_pw_qp) + print(pw_qp) + + assert (named_pw_qp._reconstruct_isl_object() - pw_qp).is_zero() + + +def test_pw_qp_from_pw_qp(): + pw_qp = isl.PwQPolynomial( + "[n] -> { [i] -> 2 * i + n }") + named_pw_qp = nisl.make_pw_qpolynomial(pw_qp) + + print(named_pw_qp) + print(pw_qp) + + assert (named_pw_qp._reconstruct_isl_object() - pw_qp).is_zero() + + +def test_pw_qp_binary_ops(): + spec = "[n] -> { [i] -> 2 * i + n }" + named_pw_qp = nisl.make_pw_qpolynomial(spec) + pw_qp = isl.PwQPolynomial(spec) + + pw_qp_p_pw_qp = pw_qp + pw_qp + pw_qp_p_1 = pw_qp + 1 + + npw_qp_p_npw_qp = named_pw_qp + named_pw_qp + npw_qp_p_1 = named_pw_qp + 1 + + assert ( + npw_qp_p_npw_qp._reconstruct_isl_object() + - + pw_qp_p_pw_qp + ).is_zero() + assert ( + npw_qp_p_1._reconstruct_isl_object() + - + pw_qp_p_1 + ).is_zero() + + pw_qp_s_pw_qp = pw_qp - pw_qp + pw_qp_s_1 = pw_qp - 1 + + npw_qp_s_npw_qp = named_pw_qp - named_pw_qp + npw_qp_s_1 = named_pw_qp - 1 + + assert ( + npw_qp_s_npw_qp._reconstruct_isl_object() + - + pw_qp_s_pw_qp + ).is_zero() + assert ( + npw_qp_s_1._reconstruct_isl_object() + - + pw_qp_s_1 + ).is_zero() + + pw_qp_m_3 = pw_qp * 3 + npw_qp_m_3 = named_pw_qp * 3 + + assert ( + npw_qp_m_3._reconstruct_isl_object() + - + pw_qp_m_3 + ).is_zero() + + +def test_pw_qpolynomial_permuted_is_zero(): + # forces the objects to be aligned before binary op is applied, whereas + # the test above does not necessarily touch alignment code + + named_qp = nisl.make_pw_qpolynomial( + "[n, m] -> { [a, b] -> a*a + 2*b + n - m }") + named_qp_perm = nisl.make_pw_qpolynomial( + "[m, n] -> { [b, a] -> a*a + 2*b + n - m }") + + assert ( + named_qp - named_qp_perm + ).is_zero() + +# }}} + + +# {{{ multiaffs + +# }}} + + +# {{{ pwmultiaffs + +# }}} diff --git a/namedisl/test/test_namedisl.py b/namedisl/test/test_namedisl.py index 4eb4f07..f09e415 100644 --- a/namedisl/test/test_namedisl.py +++ b/namedisl/test/test_namedisl.py @@ -25,10 +25,275 @@ THE SOFTWARE. """ +import pytest + +import islpy as isl + import namedisl as nisl +from .utils_for_tests import generate_random_named_set + + +@pytest.mark.parametrize("ndims", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_names(ndims: int, has_params: bool): + s_param = "n" if has_params else None + s, s_dims, _ = generate_random_named_set(ndims, "s", s_param) + names = frozenset(s_dims.split(",")) + + if s_param: + names = names | frozenset({s_param}) + + assert s.names == names + + +def test_public_dim_type_name_accessors() -> None: + named_map = nisl.make_map("[n] -> { [i] -> [o] }") + + assert named_map.dim_names(isl.dim_type.in_) == frozenset({"i"}) + assert named_map.input_names == frozenset({"i"}) + assert named_map.dim_names(isl.dim_type.out) == frozenset({"o"}) + assert named_map.output_names == frozenset({"o"}) + assert named_map.dim_names(isl.dim_type.param) == frozenset({"n"}) + assert named_map.parameter_names == frozenset({"n"}) + + +def test_public_dim_type_name_accessors_for_aff() -> None: + named_aff = nisl.make_aff("[n] -> { [i] -> [i + n] }") + + assert named_aff.input_names == frozenset({"i"}) + assert named_aff.parameter_names == frozenset({"n"}) + assert named_aff.output_names == frozenset() + assert named_aff.set_names == frozenset() + + +def test_add_set_and_parameter_names_reconstructs_expected_set() -> None: + named_set = ( + nisl.make_set("{ [x] }") + .add_set_names(["y"]) + .add_parameter_names(["p"]) + ) + expected = isl.Set("[p] -> { [y, x] }") + reconstructed = named_set._reconstruct_isl_object() + + assert isinstance(reconstructed, isl.Set) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.set) == _dim_names( + expected, isl.dim_type.set + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_add_output_input_and_parameter_names_reconstructs_expected_map() -> None: + named_map = ( + nisl.make_map("{ [i] -> [o] }") + .add_output_names(["o2"]) + .add_input_names(["i2"]) + .add_parameter_names(["n"]) + ) + expected = isl.Map("[n] -> { [i2, i] -> [o2, o] }") + reconstructed = named_map._reconstruct_isl_object() + + assert isinstance(reconstructed, isl.Map) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.in_) == _dim_names( + expected, isl.dim_type.in_ + ) + assert _dim_names(reconstructed, isl.dim_type.out) == _dim_names( + expected, isl.dim_type.out + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_add_input_and_parameter_names_reconstructs_expected_aff() -> None: + named_aff = ( + nisl.make_aff("[n] -> { [i] -> [i + n] }") + .add_input_names(["j"]) + .add_parameter_names(["m"]) + ) + expected = isl.Aff("[m, n] -> { [j, i] -> [i + n] }") + reconstructed = named_aff._reconstruct_isl_object() + + assert reconstructed == expected + + +def test_add_dim_names_uses_dim_type() -> None: + named_map = ( + nisl.make_map("{ [i] -> [o] }") + .add_dim_names(["j"], isl.dim_type.in_) + .add_dim_names(["p"], isl.dim_type.param) + .add_dim_names(["x"], isl.dim_type.out) + ) + + assert named_map.input_names == frozenset({"j", "i"}) + assert named_map.parameter_names == frozenset({"p"}) + assert named_map.output_names == frozenset({"x", "o"}) + + +def _dim_names( + obj: isl.Set | isl.Map, + dim_type: isl.dim_type + ) -> tuple[str | None, ...]: + return tuple(obj.get_dim_name(dim_type, i) for i in range(obj.dim(dim_type))) + + +def test_move_dims_set_reconstructs_like_isl() -> None: + named_set = nisl.make_set("[p] -> { [x, y, z] : x + y = z and 0 <= x, y, z < p }") + + moved = named_set.move_dims("z", isl.dim_type.param) + + expected = isl.Set( + "[p] -> { [x, y, z] : " + "x + y = z and 0 <= x and 0 <= y and 0 <= z " + "and x < p and y < p and z < p }" + ).move_dims( + isl.dim_type.param, 1, + isl.dim_type.set, 2, 1 + ) + + reconstructed = moved._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Set) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.set) == _dim_names( + expected, isl.dim_type.set + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_move_dims_map_reconstructs_like_isl() -> None: + named_map = nisl.make_map( + "[p] -> { [i0, i1] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = p }" + ) + + moved = named_map.move_dims("o2", isl.dim_type.in_) + + expected = isl.Map( + "[p] -> { [i0, i1] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = p }" + ).move_dims(isl.dim_type.in_, 2, isl.dim_type.out, 2, 1) + + reconstructed = moved._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Map) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.in_) == _dim_names( + expected, isl.dim_type.in_ + ) + assert _dim_names(reconstructed, isl.dim_type.out) == _dim_names( + expected, isl.dim_type.out + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_move_dims_multiple_names_preserves_relative_order() -> None: + named_map = nisl.make_map( + "[p] -> { [i0, i1] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = p }" + ) + + moved = named_map.move_dims(["o1", "o2"], isl.dim_type.in_) + + expected = isl.Map( + "[p] -> { [i0, i1] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = p }" + ) + expected = expected.move_dims(isl.dim_type.in_, 2, isl.dim_type.out, 1, 1) + expected = expected.move_dims(isl.dim_type.in_, 3, isl.dim_type.out, 1, 1) + + reconstructed = moved._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Map) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.in_) == _dim_names( + expected, isl.dim_type.in_ + ) + assert _dim_names(reconstructed, isl.dim_type.out) == _dim_names( + expected, isl.dim_type.out + ) + + +def test_rename_dims_set_reconstructs_like_isl() -> None: + named_set = nisl.make_set("[p] -> { [x, y] : x < p and y < p }") + + renamed = named_set.rename_dims({"x": "x_new", "p": "n"}) + + expected = isl.Set("[p] -> { [x, y] : x < p and y < p }") + expected = expected.set_dim_name(isl.dim_type.set, 0, "x_new") + expected = expected.set_dim_name(isl.dim_type.param, 0, "n") + + reconstructed = renamed._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Set) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.set) == _dim_names( + expected, isl.dim_type.set + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_rename_dims_map_reconstructs_like_isl() -> None: + named_map = nisl.make_map( + "[p] -> { [i0, i1] -> [o0, o1] : o0 = i0 and o1 = p + i1 }" + ) + + renamed = named_map.rename_dims({"i1": "j", "o1": "x", "p": "n"}) + + expected = isl.Map( + "[p] -> { [i0, i1] -> [o0, o1] : o0 = i0 and o1 = p + i1 }" + ) + expected = expected.set_dim_name(isl.dim_type.in_, 1, "j") + expected = expected.set_dim_name(isl.dim_type.out, 1, "x") + expected = expected.set_dim_name(isl.dim_type.param, 0, "n") + + reconstructed = renamed._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Map) + assert reconstructed.plain_is_equal(expected) + assert _dim_names(reconstructed, isl.dim_type.in_) == _dim_names( + expected, isl.dim_type.in_ + ) + assert _dim_names(reconstructed, isl.dim_type.out) == _dim_names( + expected, isl.dim_type.out + ) + assert _dim_names(reconstructed, isl.dim_type.param) == _dim_names( + expected, isl.dim_type.param + ) + + +def test_rename_dims_rejects_renaming_to_existing_name() -> None: + named_map = nisl.make_map("{ [i] -> [o] }") + + with pytest.raises(ValueError, match="existing names"): + _ = named_map.rename_dims({"i": "o"}) + + +def test_rename_dims_rejects_unknown_name() -> None: + named_set = nisl.make_set("{ [x] }") + + with pytest.raises(ValueError, match="unknown names"): + _ = named_set.rename_dims({"y": "z"}) + + +def test_positional_fallback_methods_are_not_exposed() -> None: + named_set = nisl.make_set("{ [x, y] }") + + with pytest.raises(AttributeError): + named_set.__getattribute__("get_dim_name") + + +def test_duplicate_set_names_are_rejected() -> None: + with pytest.raises(ValueError, match=r"duplicate|unnamed"): + _ = nisl.make_set("{ [x, x] }") + + +def test_ticked_names_are_distinct_names() -> None: + space = isl.Space.create_from_names( + isl.DEFAULT_CONTEXT, + set=["x", "x'"] + ) + named_set = nisl.make_set(isl.Set.universe(space)) -def test_basic_set() -> None: - bs = nisl.make_basic_set("[n] -> {[i]}: 0<=i None: + s = nisl.make_set("[n] -> { [i]: 0 <= i < n }") + + print(s._obj) + print(s) + + +def test_set_from_set() -> None: + s = isl.Set("[n] -> { [i, j] : 0 <= i, j < n }") + named_set = nisl.make_set(s) + + print(named_set._obj) + print(named_set) + + +@pytest.mark.parametrize("ndims", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_equality(ndims: int, has_params: bool): + a_param = "n" if has_params else None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + + from itertools import permutations + + for perm in list(permutations(a_dims.split(","))): + perm_dims = ",".join(p for p in perm) + set_str = f"{{ [{perm_dims}] : {a_cond} }}" + if has_params: + set_str = f"[{a_param}] ->" + set_str + perm_set = nisl.make_set(set_str) + + assert a == perm_set + + +def test_set_like_equality_type_mismatch_raises_not_implemented_error() -> None: + set_ = nisl.make_set("{ [i] }") + map_ = nisl.make_map("{ [i] -> [j] }") + + with pytest.raises(NotImplementedError, match="Objects are not of the same type"): + _ = set_ == map_ + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_union(ndims: int, has_params: bool): + + if has_params: + a_param = "n" + b_param = "m" + else: + a_param = None + b_param = None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + b, b_dims, b_cond = generate_random_named_set(ndims, "b", b_param) + + set_str = f"{{ [{a_dims}, {b_dims}] : ({a_cond}) or ({b_cond})}}" + if has_params: + set_str = "[n, m] -> " + set_str + + result = nisl.make_set(set_str) + + assert (a | b) == result + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_set_intersection(ndims: int, has_params: bool): + + if has_params: + a_param = "n" + b_param = "m" + else: + a_param = None + b_param = None + + a, a_dims, a_cond = generate_random_named_set(ndims, "a", a_param) + b, b_dims, b_cond = generate_random_named_set(ndims, "b", b_param) + + set_str = f"{{ [{a_dims}, {b_dims}] : ({a_cond}) and ({b_cond})}}" + if has_params: + set_str = "[n, m] -> " + set_str + + result = nisl.make_set(set_str) + + assert (a & b) == result + + +def test_set_intersection_rejects_name_collision_across_dim_types() -> None: + set_with_n = nisl.make_set("{ [n] }") + param_with_n = nisl.make_set("[n] -> { [i] }") + + with pytest.raises(ValueError, match=r"duplicate|collision"): + _find_joint_name_to_dim(set_with_n, param_with_n) + + +def test_set_add_constraint_uses_named_dimensions() -> None: + set_ = nisl.make_set("{ [j, i] }") + + constrained = set_.add_constraint("i = j - 1") + + assert constrained == nisl.make_set("{ [j, i] : i = j - 1 }") + + +def test_set_add_constraint_accepts_multiple_constraints() -> None: + set_ = nisl.make_set("{ [i, j, k] }") + + constrained = set_.add_constraint(["0 <= i", "j = i + 1", "k <= j"]) + + assert constrained == nisl.make_set( + "{ [i, j, k] : 0 <= i and j = i + 1 and k <= j }" + ) + + +def test_set_add_constraint_rejects_unknown_name() -> None: + set_ = nisl.make_set("{ [i] }") + + with pytest.raises(ValueError, match="invalid constraint"): + _ = set_.add_constraint("j = i") + + +def test_set_gist_simplifies_against_named_context() -> None: + set_ = nisl.make_set( + "{ [i, j, kb] : 0 <= i <= 13 and 0 <= j <= 13 and 0 <= kb <= 4 and kb <= 3 }" + ) + context = nisl.make_set("{ [j, i, kb] : 0 <= i <= 13 and 0 <= j <= 13 }") + + assert set_.gist(context) == nisl.make_set("{ [i, j, kb] : 0 <= kb <= 3 }") + + +def test_basic_set_gist_preserves_basic_set_when_result_is_basic() -> None: + set_ = nisl.make_basic_set("{ [i] : 0 <= i <= 10 }") + context = nisl.make_set("{ [i] : i <= 5 or i >= 8 }") + + result = set_.gist(context) + + assert isinstance(result, nisl.BasicSet) + assert result == nisl.make_basic_set("{ [i] : 0 <= i <= 10 }") + + +def test_set_subset_comparisons_align_by_name() -> None: + smaller = nisl.make_set("{ [i] : 0 <= i < 5 }") + larger = nisl.make_set("{ [j, i] : 0 <= i < 10 }") + equal_reordered = nisl.make_set("{ [i, j] : 0 <= i < 5 }") + + assert smaller < larger + assert smaller <= larger + assert larger > smaller + assert larger >= smaller + assert smaller <= equal_reordered + assert equal_reordered <= smaller + assert not smaller < equal_reordered + assert not larger <= smaller + + +def test_basic_set_subset_comparisons_allow_set_promotion() -> None: + smaller = nisl.make_basic_set("{ [i] : 0 <= i < 5 }") + larger = nisl.make_set("{ [j, i] : 0 <= i < 10 }") + + assert smaller < larger + assert smaller <= larger + assert larger > smaller + assert larger >= smaller + + +def test_basic_set_intersection_promotes_to_set() -> None: + basic = nisl.make_basic_set( + "{ [ii_s, ji_s, k_s] : 0 <= ii_s <= 4 and 0 <= ji_s <= 4 and k_s = 0 }" + ) + footprint = nisl.make_set( + "{ [ii_s, ji_s, k_s] : " + "(0 <= ii_s <= 4 and ji_s = 2 and k_s = 0) or " + "(ii_s = 2 and 0 <= ji_s <= 4 and k_s = 0) }" + ) + + result = basic & footprint + + assert isinstance(result, nisl.Set) + reconstructed = result._reconstruct_isl_object() + assert isinstance(reconstructed, isl.Set) + assert reconstructed.n_basic_set() > 1 + + +def test_set_convex_hull_returns_basic_set() -> None: + set_ = nisl.make_set( + "{ [j, i] : (j = 0 and 0 <= i <= 2) or (j = 2 and 0 <= i <= 2) }" + ) + + result = set_.convex_hull() + + assert isinstance(result, nisl.BasicSet) + assert result == nisl.make_basic_set("{ [j, i] : 0 <= j <= 2 and 0 <= i <= 2 }") + + +@pytest.mark.parametrize("ndims", [1, 2, 4, 8]) +def test_set_eliminate(ndims: int): + a, a_dims, _ = generate_random_named_set(ndims, "a", None) + a = a.eliminate(a_dims.split(",")) + + assert a == nisl.make_set(f"{{[{a_dims}]}}") + + +def test_set_eliminate_rejects_unknown_name() -> None: + set_ = nisl.make_set("{ [i] }") + + with pytest.raises(ValueError, match="unknown names: missing"): + _ = set_.eliminate("missing") + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_project_out(ndims: int): + a, a_dims, _ = generate_random_named_set(ndims, "a", None) + a = a.project_out(a_dims.split(",")) + + assert a == nisl.make_set("{[]}") + + +def test_set_project_out_rejects_unknown_name() -> None: + set_ = nisl.make_set("{ [i] }") + + with pytest.raises(ValueError, match="unknown names: missing"): + _ = set_.project_out("missing") + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_dim_max(ndims: int): + a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) + + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. + cond_pw_affs = [ + isl.PwAff(f"{{ [{cond.split('<')[2].strip(' ')}] }}") + for cond in a_cond.split("and") + ] + + for i, name in enumerate(a_dims.split(",")): + assert a.dim_max(name) == (cond_pw_affs[i] - 1) + + +@pytest.mark.parametrize("ndims", [2, 4, 8]) +def test_set_dim_min(ndims: int): + a, a_dims, a_cond = generate_random_named_set(ndims, "a", None) + + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. + cond_pw_affs = [ + isl.PwAff(f"{{ [{cond.split('<')[0].strip(' ')}] }}") + for cond in a_cond.split("and") + ] + + for i, name in enumerate(a_dims.split(",")): + assert a.dim_min(name) == cond_pw_affs[i] + + +def test_set_dim_bounds_reconstruct_parameter_metadata() -> None: + set_ = nisl.make_set("[n] -> { [i] : 0 <= i < n }").rename_dims({ + "i": "j", + "n": "m", + }) + + assert set_.dim_min("j") == isl.PwAff("[m] -> { [(0)] : m > 0 }") + assert set_.dim_max("j") == isl.PwAff("[m] -> { [(-1 + m)] : m > 0 }") + + +# }}} + + +# {{{ maps + + +def test_map_from_str() -> None: + m = nisl.make_map("[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + + print(m._obj) + print(m) + + +def test_map_from_map() -> None: + m = isl.Map("[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + named_map = nisl.make_map(m) + + print(named_map._obj) + print(named_map) + + +@pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) +@pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_map_equality(ndims_domain: int, ndims_range: int, has_params: bool): + if has_params: + d_param = "n" + r_param = "m" + else: + d_param = None + r_param = None + + og_map, domain_info, range_info = generate_random_named_map( + ndims_domain, "d", d_param, ndims_range, "r", r_param + ) + + _, d_dims, d_cond = domain_info + _, r_dims, r_cond = range_info + + from itertools import permutations + + d_perms = list(permutations(d_dims.split(","))) + r_perms = list(permutations(r_dims.split(","))) + + for d_perm, r_perm in zip(d_perms, r_perms, strict=False): + d_perm_dims = ",".join(p for p in d_perm) + r_perm_dims = ",".join(p for p in r_perm) + + domain_str = f"{{ [{d_perm_dims}] : {d_cond} }}" + range_str = f"{{ [{r_perm_dims}] : {r_cond} }}" + + if has_params: + domain_str = f"[{d_param}] ->" + domain_str + range_str = f"[{r_param}] ->" + range_str + + perm_map = nisl.make_map( + isl.Map.from_domain_and_range(isl.Set(domain_str), isl.Set(range_str)) + ) + + assert perm_map == og_map + + +@pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) +@pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_map_union(ndims_domain: int, ndims_range: int, has_params: bool): + if has_params: + d_param = "n" + r_param = "m" + else: + d_param = None + r_param = None + + x, x_domain_info, x_range_info = generate_random_named_map( + ndims_domain, "x_in", d_param, ndims_range, "x_out", r_param + ) + + y, y_domain_info, y_range_info = generate_random_named_map( + ndims_domain, "y_in", d_param, ndims_domain, "y_out", r_param + ) + + _, x_in_dims, x_in_cond = x_domain_info + _, x_out_dims, x_out_cond = x_range_info + + _, y_in_dims, y_in_cond = y_domain_info + _, y_out_dims, y_out_cond = y_range_info + + result_dims = f"[{x_in_dims}, {y_in_dims}] -> [{x_out_dims}, {y_out_dims}]" + result_conds = f"({x_in_cond} and {x_out_cond}) or ({y_in_cond} and {y_out_cond})" + + result_str = "{" + result_dims + " : " + result_conds + "}" + + if has_params: + result_str = f"[{d_param}, {r_param}] ->" + result_str + + result_map = nisl.make_map(result_str) + + assert (x | y) == result_map + + +@pytest.mark.parametrize("ndims_domain", [2, 3, 4, 5]) +@pytest.mark.parametrize("ndims_range", [2, 3, 4, 5]) +@pytest.mark.parametrize("has_params", [True, False]) +def test_map_intersection(ndims_domain: int, ndims_range: int, has_params: bool): + if has_params: + d_param = "n" + r_param = "m" + else: + d_param = None + r_param = None + + x, x_domain_info, x_range_info = generate_random_named_map( + ndims_domain, "x_in", d_param, ndims_range, "x_out", r_param + ) + + y, y_domain_info, y_range_info = generate_random_named_map( + ndims_domain, "y_in", d_param, ndims_domain, "y_out", r_param + ) + + _, x_in_dims, x_in_cond = x_domain_info + _, x_out_dims, x_out_cond = x_range_info + + _, y_in_dims, y_in_cond = y_domain_info + _, y_out_dims, y_out_cond = y_range_info + + result_dims = f"[{x_in_dims}, {y_in_dims}] -> [{x_out_dims}, {y_out_dims}]" + result_conds = f"({x_in_cond} and {x_out_cond}) and ({y_in_cond} and {y_out_cond})" + + result_str = "{" + result_dims + " : " + result_conds + "}" + + if has_params: + result_str = f"[{d_param}, {r_param}] ->" + result_str + + result_map = nisl.make_map(result_str) + + assert (x & y) == result_map + + +def test_map_add_constraint_uses_input_output_and_parameter_names() -> None: + map_ = nisl.make_map("[n] -> { [i] -> [j] }") + + constrained = map_.add_constraint("j = i + n") + + assert constrained == nisl.make_map("[n] -> { [i] -> [j] : j = i + n }") + + +def test_map_add_constraint_preserves_basic_map_type() -> None: + map_ = nisl.make_basic_map("{ [i] -> [j] }") + + constrained = map_.add_constraint("j = i + 1") + + assert isinstance(constrained, nisl.BasicMap) + assert constrained == nisl.make_basic_map("{ [i] -> [j] : j = i + 1 }") + + +def test_map_add_constraint_supports_previous_context_relation() -> None: + relation = nisl.make_map( + "{ [ki_prev, kb_prev] -> [ki_cur, kb_cur] : kb_cur = kb_prev + 1 }" + ) + + constrained = relation.add_constraint("ki_prev = ki_cur - 1") + + assert constrained == nisl.make_map( + "{ [ki_prev, kb_prev] -> [ki_cur, kb_cur] : " + "ki_prev = ki_cur - 1 and kb_cur = kb_prev + 1 }" + ) + + +def test_map_gist_simplifies_against_named_context() -> None: + map_ = nisl.make_map( + "{ [i, kb] -> [j] : 0 <= i <= 13 and 0 <= kb <= 4 and kb <= 3 and j = i }" + ) + context = nisl.make_map("{ [kb, i] -> [j] : 0 <= i <= 13 and j = i }") + + assert map_.gist(context) == nisl.make_map("{ [i, kb] -> [j] : 0 <= kb <= 3 }") + + +def test_map_subset_comparisons_align_by_name() -> None: + smaller = nisl.make_map("{ [i] -> [x] : x = i and 0 <= i < 5 }") + larger = nisl.make_map("{ [j, i] -> [y, x] : x = i and 0 <= i < 10 }") + equal_reordered = nisl.make_map("{ [i, j] -> [x, y] : x = i and 0 <= i < 5 }") + + assert smaller < larger + assert smaller <= larger + assert larger > smaller + assert larger >= smaller + assert smaller <= equal_reordered + assert equal_reordered <= smaller + assert not smaller < equal_reordered + assert not larger <= smaller + + +def test_basic_map_subset_comparisons_allow_map_promotion() -> None: + smaller = nisl.make_basic_map("{ [i] -> [x] : x = i and 0 <= i < 5 }") + larger = nisl.make_map("{ [j, i] -> [y, x] : x = i and 0 <= i < 10 }") + + assert smaller < larger + assert smaller <= larger + assert larger > smaller + assert larger >= smaller + + +def test_map_convex_hull_returns_basic_map() -> None: + map_ = nisl.make_map("{ [i] -> [j] : (i = 0 and j = 0) or (i = 2 and j = 2) }") + + result = map_.convex_hull() + + assert isinstance(result, nisl.BasicMap) + assert result == nisl.make_basic_map("{ [i] -> [j] : j = i and 0 <= i <= 2 }") + + +def test_subset_comparison_rejects_set_map_mismatch() -> None: + set_ = nisl.make_set("{ [i] : 0 <= i < 5 }") + map_ = nisl.make_map("{ [i] -> [x] : x = i and 0 <= i < 5 }") + + with pytest.raises(TypeError): + _ = set_ <= map_ + + +def test_map_alignment_syncs_output_metadata() -> None: + lhs = nisl.make_map("{ [i] -> [x] }") + rhs = nisl.make_map("{ [i] -> [y, x] }") + + aligned_lhs, aligned_rhs = _align_two(lhs, rhs) + + assert aligned_lhs.ordered_dim_names(isl.dim_type.out) == ("x", "y") + assert aligned_lhs.ordered_dim_names(isl.dim_type.in_) == ("i",) + assert aligned_rhs.ordered_dim_names(isl.dim_type.out) == ("x", "y") + assert aligned_rhs.ordered_dim_names(isl.dim_type.in_) == ("i",) + + +def test_map_alignment_syncs_input_and_parameter_metadata() -> None: + lhs = nisl.make_map("[n] -> { [i] -> [x] }") + rhs = nisl.make_map("[m, n] -> { [j, i] -> [x] }") + + aligned_lhs, aligned_rhs = _align_two(lhs, rhs) + + assert aligned_lhs.ordered_dim_names(isl.dim_type.out) == ("x",) + assert aligned_lhs.ordered_dim_names(isl.dim_type.in_) == ("i", "j") + assert aligned_lhs.ordered_dim_names(isl.dim_type.param) == ("m", "n") + assert aligned_rhs.ordered_dim_names(isl.dim_type.out) == ("x",) + assert aligned_rhs.ordered_dim_names(isl.dim_type.in_) == ("i", "j") + assert aligned_rhs.ordered_dim_names(isl.dim_type.param) == ("m", "n") + + +def test_map_apply_range_rejects_surviving_name_collisions() -> None: + lhs = nisl.make_map("{ [x] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x] }") + + with pytest.raises(ValueError, match="duplicate surviving names"): + _ = lhs.apply_range(rhs) + + +def test_map_apply_range_can_explicitly_rename_and_equate_collision() -> None: + lhs = nisl.make_map("{ [x] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x] }").rename_dims({"x": "x_out"}) + + result = lhs.apply_range(rhs).equate_dims("x", "x_out") + + assert result.input_names == frozenset({"x"}) + assert result.range().names == frozenset({"x_out"}) + assert ( + result + .intersect_domain(nisl.make_set("{ [x] : x = 3 }")) + .range() + .dim_min("x_out") + .plain_is_equal(isl.PwAff("{ [(3)] }")) + ) + + +def test_map_apply_range_can_equate_renamed_collisions_from_mapping() -> None: + lhs = nisl.make_map("{ [x, z] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x, z] }").rename_dims({ + "x": "x_out", + "z": "z_out", + }) + + result = lhs.apply_range(rhs).equate_dims({ + "x": "x_out", + "z": "z_out", + }) + + assert result == nisl.make_map( + "{ [x, z] -> [x_out, z_out] : x = x_out and z = z_out }" + ) + + +def test_equate_dims_mapping_rejects_unknown_name() -> None: + map_ = nisl.make_map("{ [x] -> [x_out] }") + + with pytest.raises(ValueError, match="unknown name: missing"): + _ = map_.equate_dims({"x": "missing"}) + + +def test_map_apply_domain_rejects_surviving_name_collisions() -> None: + lhs = nisl.make_map("{ [x] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x] }") + + with pytest.raises(ValueError, match="duplicate surviving names"): + _ = rhs.apply_domain(lhs) + + +def test_map_apply_domain_can_explicitly_rename_and_equate_collision() -> None: + lhs = nisl.make_map("{ [x] -> [y] }") + rhs = nisl.make_map("{ [y] -> [x] }").rename_dims({"x": "x_out"}) + + result = rhs.apply_domain(lhs).equate_dims("x", "x_out") + + assert result.input_names == frozenset({"x"}) + assert result.range().names == frozenset({"x_out"}) + assert ( + result + .intersect_domain(nisl.make_set("{ [x] : x = 4 }")) + .range() + .dim_min("x_out") + .plain_is_equal(isl.PwAff("{ [(4)] }")) + ) + + +def test_duplicate_map_names_are_rejected() -> None: + with pytest.raises(ValueError, match=r"duplicate|unnamed"): + _ = nisl.make_map("{ [x] -> [x] }") + + +def test_map_empty_from_space_preserves_names_and_is_empty() -> None: + space = isl.Space.create_from_names( + isl.DEFAULT_CONTEXT, params=["n"], in_=["i", "j"], out=["x", "y"] + ) + + m = nisl.Map.empty(space) + + assert m._reconstruct_isl_object().is_empty() + assert m.input_names == frozenset({"i", "j"}) + assert m.range() == nisl.make_set("[n] -> { [x, y] : false }") + + +def test_basic_map_empty_from_space_preserves_names_and_is_empty() -> None: + space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, in_=["i"], out=["x"]) + + m = nisl.BasicMap.empty(space) + + assert m._reconstruct_isl_object().is_empty() + assert m.input_names == frozenset({"i"}) + assert m.range() == nisl.make_basic_set("{ [x] : 1 = 0 }") + + +def test_map_empty_matches_existing_named_space() -> None: + template = nisl.make_map("[n] -> { [i, k] -> [ii_s, io, ki_s, ko] }") + + empty_map = nisl.Map.empty(template.get_space()) + + assert empty_map._reconstruct_isl_object().is_empty() + assert empty_map.input_names == template.input_names + assert empty_map.range()._reconstruct_isl_object().is_empty() + assert empty_map.range().names == template.range().names + + +def test_empty_map_is_identity_for_union() -> None: + space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, in_=["i"], out=["x"]) + empty_map = nisl.Map.empty(space) + nonempty_map = nisl.make_map("{ [i] -> [x] }") + + assert (empty_map | nonempty_map) == nonempty_map + assert (nonempty_map | empty_map) == nonempty_map + + +@pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) +@pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) +def test_map_eliminate(ndims_domain: int, ndims_range: int): + x, x_domain_info, x_range_info = generate_random_named_map( + ndims_domain, "x_in", None, ndims_range, "x_out", None + ) + + _, x_in_dims, _ = x_domain_info + _, x_out_dims, _ = x_range_info + + dims_to_remove = (x_in_dims + "," + x_out_dims).split(",") + x = x.eliminate(dims_to_remove) + + assert x == nisl.make_map(f"{{[{x_in_dims}] -> [{x_out_dims}]}}") + + +def test_map_eliminate_rejects_unknown_name() -> None: + map_ = nisl.make_map("{ [i] -> [j] }") + + with pytest.raises(ValueError, match="unknown names: missing"): + _ = map_.eliminate("missing") + + +@pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) +@pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) +def test_map_project_out(ndims_domain: int, ndims_range: int): + x, x_domain_info, x_range_info = generate_random_named_map( + ndims_domain, "x_in", None, ndims_range, "x_out", None + ) + + _, x_in_dims, _ = x_domain_info + _, x_out_dims, _ = x_range_info + + dims_to_remove = (x_in_dims + "," + x_out_dims).split(",") + x = x.project_out(dims_to_remove) + + assert x == nisl.make_map("{[] -> []}") + + +def test_map_project_out_rejects_unknown_name() -> None: + map_ = nisl.make_map("{ [i] -> [j] }") + + with pytest.raises(ValueError, match="unknown names: missing"): + _ = map_.project_out("missing") + + +def test_map_as_pw_multi_aff(): + spec = "{ [i] -> [io, ii] : i = 32 * io + ii and 0 <= ii < 32 }" + m = nisl.make_map(spec) + m_isl = isl.Map(spec) + + assert m.as_pw_multi_aff() == m_isl.as_pw_multi_aff() + + +@pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) +@pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) +def test_map_dim_max(ndims_domain: int, ndims_range: int): + m, (_, in_names, in_conds), (_, out_names, out_conds) = generate_random_named_map( + ndims_domain, "x_in", None, ndims_range, "x_out", None + ) + + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. + in_upper_bound_pw_maffs = [ + isl.PwAff(f"{{ [{int(cond.split('<')[2].strip(' '))}] }}") + for cond in in_conds.split("and") + ] + + for i, name in enumerate(in_names.split(",")): + # NOTE: constructing PwAffs assumes starting index of 0, so subtract 1 + assert m.dim_max(name) == (in_upper_bound_pw_maffs[i] - 1) + + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. + out_upper_bound_pw_maffs = [ + isl.PwAff(f"{{ [{int(cond.split('<')[2].strip(' '))}] }}") + for cond in out_conds.split("and") + ] + + for i, name in enumerate(out_names.split(",")): + # NOTE: constructing PwAffs assumes starting index of 0, so subtract 1 + assert m.dim_max(name) == (out_upper_bound_pw_maffs[i] - 1) + + +@pytest.mark.parametrize("ndims_domain", [1, 2, 4, 8]) +@pytest.mark.parametrize("ndims_range", [1, 2, 4, 8]) +def test_map_dim_min(ndims_domain: int, ndims_range: int): + m, (_, in_names, in_conds), (_, out_names, out_conds) = generate_random_named_map( + ndims_domain, "x_in", None, ndims_range, "x_out", None + ) + + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. + in_lower_bound_pw_maffs = [ + isl.PwAff(f"{{ [{int(cond.split('<')[0].strip(' '))}] }}") + for cond in in_conds.split("and") + ] + + for i, name in enumerate(in_names.split(",")): + assert m.dim_min(name) == in_lower_bound_pw_maffs[i] + + # dim_{min,max} return raw isl.PwAff objects on a zero-dimensional set space. + out_lower_bound_pw_maffs = [ + isl.PwAff(f"{{ [{int(cond.split('<')[0].strip(' '))}] }}") + for cond in out_conds.split("and") + ] + + for i, name in enumerate(out_names.split(",")): + assert m.dim_min(name) == out_lower_bound_pw_maffs[i] + + +def test_map_dim_bounds_reconstruct_parameter_metadata() -> None: + map_ = nisl.make_map( + "[n] -> { [i] -> [j] : 0 <= i < n and j = i + 1 }" + ).rename_dims({ + "i": "k", + "j": "l", + "n": "m", + }) + + assert map_.dim_min("k") == isl.PwAff("[m] -> { [(0)] : m > 0 }") + assert map_.dim_max("l") == isl.PwAff("[m] -> { [(m)] : m > 0 }") + + +# }}} + + +# {{{ basic{map, set} + + +def test_basic_map_from_str() -> None: + m = nisl.make_basic_map( + "[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }" + ) + + print(m._obj) + print(m) + + +def test_basic_map_from_map() -> None: + m = isl.BasicMap("[n] -> { [i,j] -> [a,b] : 0 <= i, j < 10 and 0 <= a, b < 20 }") + named_map = nisl.make_basic_map(m) + + print(named_map._obj) + print(named_map) + + +# }}} diff --git a/namedisl/test/utils_for_tests.py b/namedisl/test/utils_for_tests.py new file mode 100644 index 0000000..de64bbd --- /dev/null +++ b/namedisl/test/utils_for_tests.py @@ -0,0 +1,95 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2025- University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from random import randint +from typing import TYPE_CHECKING + +import islpy as isl + +import namedisl as nisl + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +NamedSetReturnT = tuple[nisl.Set, str, str] +NamedMapReturnT = tuple[nisl.Map, NamedSetReturnT, NamedSetReturnT] + + +def get_name_sequence(n: int, dim_prefix: str) -> tuple[Sequence[str], str]: + dims = [f"{dim_prefix}_{i}" for i in range(n)] + dim_str = ",".join(d for d in dims) + + return dims, dim_str + + +def generate_random_named_set( + ndims: int, + dim_prefix: str, + param: str | None + ) -> NamedSetReturnT: + dims, dim_str = get_name_sequence(ndims, dim_prefix) + + if param is not None: + conditions = f"0 <= {dim_str} < {param}" + set_str = f"[{param}] -> {{ [{dim_str}] : {conditions} }}" + else: + upper_bounds = [randint(1, 100) for _ in range(ndims)] + lower_bounds = [ + randint(0, upper_bound - 1) for upper_bound in upper_bounds] + + conditions = " and ".join( + f"{lower_bound} <= {d} < {upper_bound}" + for d, lower_bound, upper_bound in zip( + dims, lower_bounds, upper_bounds, strict=True) + ) + set_str = f"{{ [{dim_str}] : {conditions} }}" + + return nisl.make_set(set_str), dim_str, conditions + + +def generate_random_named_map( + ndims_domain: int, + domain_prefix: str, + domain_param: str | None, + ndims_range: int, + range_prefix: str, + range_param: str | None + ) -> NamedMapReturnT: + + d = generate_random_named_set(ndims_domain, domain_prefix, domain_param) + r = generate_random_named_set(ndims_range, range_prefix, range_param) + + dom = d[0]._reconstruct_isl_object() + ran = r[0]._reconstruct_isl_object() + + return ( + nisl.make_map(isl.Map.from_domain_and_range(dom, ran)), + d, + r + ) diff --git a/pyproject.toml b/pyproject.toml index 5305241..cb46763 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ dependencies = [ "constantdict", "islpy", - "typing-extensions>=4.5", + "typing-extensions>=4.10", ] [project.urls] @@ -58,6 +58,10 @@ extend-ignore = [ "RUF067", # __init__ should contain no code ] +[tool.ruff.lint.per-file-ignores] +"pytools/test/*.py" = ["S102"] +"doc/conf.py" = ["S102"] + [tool.ruff.lint.flake8-quotes] docstring-quotes = "double" inline-quotes = "double" @@ -107,4 +111,3 @@ exclude = [ ".conda-root", ".env", ] -