Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
# ============================================================================ #
from __future__ import annotations

from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, ClassVar
from typing import Any, ClassVar, Optional, Union

import numpy.typing as npt
import opt_einsum as oe
import torch
from quimb.tensor import TensorNetwork


Expand All @@ -33,6 +36,67 @@ def contractor(subscripts: str,
return oe.contract(subscripts, *tensors, optimize=optimize)


def oe_torch_contractor(subscripts: str,
tensors: list[torch.Tensor],
optimize: str = "auto",
**_: Any) -> Any:
"""Perform einsum contraction using opt_einsum with the torch backend.

Combines opt_einsum's contraction-path optimisation with torch's
execution engine, giving autograd support and GPU acceleration in a
single call. Execution device follows the input tensors.

Args:
subscripts: The einsum subscripts.
tensors: list of torch tensors to contract.
optimize: Optimization strategy passed to ``opt_einsum.contract``.
Defaults to ``"auto"``.

Returns:
The contracted tensor.
"""
return oe.contract(subscripts, *tensors, optimize=optimize, backend="torch")


# TODO: move to per-decoder instance; module-global cache means unrelated decoders share and evict each other.
_OE_EXPR_CACHE_MAXSIZE = 32
_oe_expr_cache: OrderedDict[tuple, Any] = OrderedDict()


def oe_torch_compiled_contractor(subscripts: str,
tensors: list[torch.Tensor],
optimize: str = "auto",
**_: Any) -> Any:
"""Perform einsum contraction using a cached ``opt_einsum.contract_expression``
with the torch backend.

On the first call for a given ``(subscripts, shapes, optimize)``
combination, builds and caches a :class:`opt_einsum.ContractExpression`.
Subsequent calls with the same key skip path search entirely and only
execute the pairwise tensor contractions via torch.

Args:
subscripts: The einsum subscripts.
tensors: list of torch tensors to contract.
optimize: Optimization strategy passed to
``opt_einsum.contract_expression``. Defaults to ``"auto"``.

Returns:
The contracted tensor.
"""
shapes = tuple(t.shape for t in tensors)
key = (subscripts, shapes, str(optimize))
if key in _oe_expr_cache:
_oe_expr_cache.move_to_end(key)
else:
if len(_oe_expr_cache) >= _OE_EXPR_CACHE_MAXSIZE:
_oe_expr_cache.popitem(last=False)
_oe_expr_cache[key] = oe.contract_expression(subscripts,
*shapes,
optimize=optimize)
return _oe_expr_cache[key](*tensors, backend="torch")


def cutn_contractor(subscripts: str,
tensors: list[Any],
optimize: Any | None = None,
Expand Down Expand Up @@ -61,31 +125,43 @@ def cutn_contractor(subscripts: str,
)


def optimize_path(optimize: Any, output_inds: tuple[str, ...],
tn: TensorNetwork) -> tuple[Any, Any]:
def optimize_path(optimize: Any,
output_inds: tuple[str, ...],
tn: TensorNetwork,
network_options: Any = None) -> tuple[Any, Any]:
"""
Optimize the contraction path for a tensor network.

Args:
optimize (Any): The optimization options to use.
If None or cuquantum.tensornet.OptimizerOptions, we use cuquantum.tensornet.
Else, Quimb interface at
https://quimb.readthedocs.io/en/latest/autoapi/quimb/tensor/tensor_core/index.html#quimb.tensor.tensor_core.TensorNetwork.contraction_info
output_inds (tuple[str, ...]): Output indices for the contraction.
tn (TensorNetwork): The tensor network.
optimize: The optimization options to use.
If ``None`` or a ``cuquantum.tensornet.OptimizerOptions``
instance, dispatches to ``cuquantum.tensornet.contract_path``.
Otherwise routes through Quimb's
:meth:`TensorNetwork.contraction_info` (which accepts
opt_einsum string presets, :class:`PathOptimizer`
instances, :class:`cotengra.HyperOptimizer`, etc.).
output_inds: Output indices for the contraction.
tn: The tensor network.
network_options: Optional cuTensorNet ``NetworkOptions`` (or
equivalent dict). Forwarded as ``options=`` to
``cuquantum.tensornet.contract_path``. Ignored for
non-cuTensorNet optimizers.

Returns:
tuple[Any, Any]: The contraction path and optimizer info.
A ``(path, info)`` tuple.
"""
use_cutn = optimize is None or (
type(optimize).__module__.startswith("cuquantum") and
type(optimize).__name__ == "OptimizerOptions")
if use_cutn:
from cuquantum import tensornet as cutn
kwargs: dict[str, Any] = {"optimize": optimize}
if network_options is not None:
kwargs["options"] = network_options
path, info = cutn.contract_path(
tn.get_equation(output_inds=output_inds),
*tn.arrays,
optimize=optimize,
**kwargs,
)
return path, info

Expand All @@ -109,13 +185,19 @@ class ContractorConfig:
_allowed_configs: ClassVar[tuple[tuple[str, str, str], ...]] = (
("numpy", "numpy", "cpu"),
("torch", "torch", "cpu"),
("oe_torch", "torch", "cpu"),
("oe_torch", "torch", "cuda"),
("oe_torch_compiled", "torch", "cpu"),
("oe_torch_compiled", "torch", "cuda"),
("cutensornet", "numpy", "cuda"),
("cutensornet", "torch", "cuda"),
)
_allowed_backends: ClassVar[list[str]] = ["numpy", "torch"]
_contractors: ClassVar[dict[str, Callable]] = {
"numpy": contractor,
"torch": contractor,
"oe_torch": oe_torch_contractor,
"oe_torch_compiled": oe_torch_compiled_contractor,
"cutensornet": cutn_contractor,
}

Expand Down
Loading
Loading