Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
default_language_version:
python: python3

default_stages: [commit, push]
default_stages: [pre-commit, pre-push]

repos:
- repo: https://github.com/PyCQA/isort
Expand Down
3 changes: 2 additions & 1 deletion src/bartiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Routine,
routine_to_qref,
)
from .compilation import compile_routine, evaluate
from .compilation import DerivedResource, compile_routine, evaluate
from .symbolics import sympy_backend

__all__ = [
Expand All @@ -32,6 +32,7 @@
"Resource",
"CompiledRoutine",
"routine_to_qref",
"DerivedResource",
"compile_routine",
"evaluate",
"sympy_backend",
Expand Down
1 change: 1 addition & 0 deletions src/bartiq/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The bartiq analysis module provides methods to manipulate symbolic expressions."""

from bartiq.analysis.rewriters import rewrite_routine_resources, sympy_rewriter

__all__ = ["sympy_rewriter", "rewrite_routine_resources"]
1 change: 1 addition & 0 deletions src/bartiq/analysis/rewriters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rewriters can be used to modify, or simplify, the form of symbolic expressions."""

from bartiq.analysis.rewriters.routine_rewriter import rewrite_routine_resources
from bartiq.analysis.rewriters.sympy_expression import sympy_rewriter

Expand Down
1 change: 1 addition & 0 deletions src/bartiq/analysis/rewriters/routine_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Here we provide functionality to allow you to apply rewriters to CompiledRoutine resource expressions."""

from __future__ import annotations

from collections.abc import Iterable
Expand Down
1 change: 1 addition & 0 deletions src/bartiq/analysis/rewriters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides some helper functions for describing instructions in rewriters."""

from __future__ import annotations

import re
Expand Down
11 changes: 10 additions & 1 deletion src/bartiq/compilation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The compilation submodule contains routine compilation functionality."""

from ._common import DerivedResource
from ._compile import CompilationFlags, CompilationResult, compile_routine
from ._evaluate import EvaluationResult, evaluate

__all__ = ["compile_routine", "CompilationResult", "evaluate", "EvaluationResult", "CompilationFlags"]
__all__ = [
"compile_routine",
"CompilationResult",
"DerivedResource",
"evaluate",
"EvaluationResult",
"CompilationFlags",
]
68 changes: 65 additions & 3 deletions src/bartiq/compilation/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,21 @@
# limitations under the License.
from __future__ import annotations

import inspect
from collections.abc import Iterable
from dataclasses import dataclass, replace
from typing import Callable

from .._routine import CompiledRoutine, Constraint, ConstraintStatus, Port, Resource
from typing import Callable, Generic, Protocol

from typing_extensions import TypedDict, TypeIs

from .._routine import (
CompiledRoutine,
Constraint,
ConstraintStatus,
Port,
Resource,
ResourceType,
)
from ..repetitions import Repetition
from ..symbolics.backend import ComparisonResult, SymbolicBackend, T, TExpr

Expand All @@ -39,6 +49,14 @@ def __init__(self, original_constraint: Constraint[T], compiled_constraint: Cons
super().__init__(original_constraint, compiled_constraint)


class DerivedResource(TypedDict, Generic[T]):
"""Contains information needed to calculate derived resources."""

name: str
type: str
calculate: Calculate[T] | CalculateWithName[T]


def evaluate_ports(
ports: dict[str, Port[T]],
inputs: dict[str, TExpr[T]],
Expand Down Expand Up @@ -130,3 +148,47 @@ def _collect_first_pass_resource_variables(children: dict[str, CompiledRoutine[T

def collect_children_variables(children: dict[str, CompiledRoutine[T]]) -> dict[str, TExpr[T]]:
return _collect_resource_variables(children) | _collect_first_pass_resource_variables(children)


def _accepts_resource_name(func: Calculate[T] | CalculateWithName[T]) -> TypeIs[CalculateWithName[T]]:
return "resource_name" in inspect.signature(func).parameters


class Calculate(Protocol[T]):

def __call__(self, routine: CompiledRoutine[T], backend: SymbolicBackend[T]) -> TExpr[T] | None:
pass


class CalculateWithName(Protocol[T]):

def __call__(self, routine: CompiledRoutine[T], backend: SymbolicBackend[T], resource_name: str) -> TExpr[T] | None:
pass


def add_derived_resources(
routine: CompiledRoutine[T],
backend: SymbolicBackend[T],
derived_resources: Iterable[DerivedResource[T]] = (),
) -> CompiledRoutine[T]:
for specs in derived_resources:
name = specs["name"]
type = specs["type"]
calculate = specs["calculate"]

value = (
calculate(routine, backend, resource_name=name)
if _accepts_resource_name(calculate)
else calculate(routine, backend)
)

if value is not None:
resource = Resource(name, type=ResourceType(type), value=value)
routine = replace(
routine,
resources={
**routine.resources,
name: resource,
},
)
return routine
64 changes: 6 additions & 58 deletions src/bartiq/compilation/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from __future__ import annotations

import ast
import inspect
import operator
import os
import warnings
Expand All @@ -24,13 +23,12 @@
from enum import Flag, auto
from functools import reduce
from graphlib import TopologicalSorter
from typing import Generic, Protocol
from typing import Generic

from qref import SchemaV1
from qref.functools import ensure_routine
from qref.schema_v1 import RoutineV1
from qref.verification import verify_topology
from typing_extensions import TypedDict, TypeIs

from bartiq._routine import (
CompiledRoutine,
Expand All @@ -44,6 +42,8 @@
from bartiq.compilation._common import (
ConstraintValidationError,
Context,
DerivedResource,
add_derived_resources,
collect_children_variables,
evaluate_constraints,
evaluate_ports,
Expand Down Expand Up @@ -92,26 +92,6 @@ class CompilationFlags(Flag):
"""Skip the verification step on the routine."""


class Calculate(Protocol[T]):

def __call__(self, routine: CompiledRoutine[T], backend: SymbolicBackend[T]) -> TExpr[T] | None:
pass


class CalculateWithName(Protocol[T]):

def __call__(self, routine: CompiledRoutine[T], backend: SymbolicBackend[T], resource_name: str) -> TExpr[T] | None:
pass


class DerivedResources(TypedDict, Generic[T]):
"""Contains information needed to calculate derived resources."""

name: str
type: str
calculate: Calculate[T] | CalculateWithName[T]


@dataclass
class CompilationResult(Generic[T]):
"""
Expand All @@ -137,7 +117,7 @@ def compile_routine(
backend: SymbolicBackend[T] = sympy_backend,
preprocessing_stages: Iterable[PreprocessingStage[T]] = DEFAULT_PREPROCESSING_STAGES,
postprocessing_stages: Iterable[PostprocessingStage[T]] = DEFAULT_POSTPROCESSING_STAGES,
derived_resources: Iterable[DerivedResources] = (),
derived_resources: Iterable[DerivedResource] = (),
compilation_flags: CompilationFlags | None = None,
) -> CompilationResult[T]:
"""Performs symbolic compilation of a given routine.
Expand Down Expand Up @@ -329,7 +309,7 @@ def _compile(
backend: SymbolicBackend[T],
inputs: dict[str, TExpr[T]],
context: Context,
derived_resources: Iterable[DerivedResources] = (),
derived_resources: Iterable[DerivedResource] = (),
compilation_flags: CompilationFlags = CompilationFlags(0), # CompilationsFlags(0) corresponds to no flags
) -> CompiledRoutine[T]:
try:
Expand Down Expand Up @@ -431,15 +411,11 @@ def _compile(
if CompilationFlags.EXPAND_RESOURCES in compilation_flags
else _introduce_placeholder_child_resources(compiled_routine, backend)
)
tmp_routine = _add_derived_resources(tmp_routine, backend, derived_resources)
tmp_routine = add_derived_resources(tmp_routine, backend, derived_resources)

return replace(compiled_routine, resources=tmp_routine.resources)


def _accepts_resource_name(func: Calculate[T] | CalculateWithName[T]) -> TypeIs[CalculateWithName[T]]:
return "resource_name" in inspect.signature(func).parameters


def _introduce_placeholder_resources(
compiled_routine: CompiledRoutine[T], backend: SymbolicBackend[T]
) -> CompiledRoutine[T]:
Expand Down Expand Up @@ -468,34 +444,6 @@ def _introduce_placeholder_child_resources(
)


def _add_derived_resources(
routine: CompiledRoutine[T],
backend: SymbolicBackend[T],
derived_resources: Iterable[DerivedResources[T]] = (),
) -> CompiledRoutine[T]:
for specs in derived_resources:
name = specs["name"]
type = specs["type"]
calculate = specs["calculate"]

value = (
calculate(routine, backend, resource_name=name)
if _accepts_resource_name(calculate)
else calculate(routine, backend)
)

if value is not None:
resource = Resource(name, type=ResourceType(type), value=value)
routine = replace(
routine,
resources={
**routine.resources,
name: resource,
},
)
return routine


def _generate_arithmetic_resources(
resources: dict[str, Resource[T]], compiled_children: dict[str, CompiledRoutine[T]], backend: SymbolicBackend[T]
) -> dict[str, Resource[T]]:
Expand Down
44 changes: 31 additions & 13 deletions src/bartiq/compilation/_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

Expand All @@ -22,6 +22,8 @@
from bartiq.compilation._common import (
ConstraintValidationError,
Context,
DerivedResource,
add_derived_resources,
collect_children_variables,
evaluate_constraints,
evaluate_ports,
Expand Down Expand Up @@ -63,6 +65,7 @@ def evaluate(
*,
backend: SymbolicBackend[T] = sympy_backend,
functions_map: FunctionsMap[T] | None = None,
derived_resources: Iterable[DerivedResource[T]] = (),
) -> EvaluationResult[T]:
"""Substitutes variables into compiled routine.

Expand All @@ -73,6 +76,9 @@ def evaluate(
expressions understood by backend, or via strings, e.g. `{"N": 2, "M": "k+3"}.
backend: a backend used for manipulating symbolic expressions.
functions_map: a dictionary mapping function names to their concrete implementations.
derived_resources: iterable with dictionaries describing how to calculate derived resources.
Each dictionary should contain the derived resource's name, type
and the function mapping a routine to the value of resource.

Returns:
A new instance of CompiledRoutine with appropriate substitutions made.
Expand All @@ -83,7 +89,7 @@ def evaluate(
assignment: backend.parse_constant(backend.as_expression(value)) for assignment, value in assignments.items()
}
evaluated_routine = _evaluate_internal(
compiled_routine, parsed_assignments, backend, functions_map, Context(compiled_routine.name)
compiled_routine, parsed_assignments, backend, functions_map, derived_resources, Context(compiled_routine.name)
)
return EvaluationResult(routine=evaluated_routine, _backend=backend)

Expand All @@ -93,6 +99,7 @@ def _evaluate_internal(
inputs: dict[str, TExpr[T]],
backend: SymbolicBackend[T],
functions_map: FunctionsMap[T],
derived_resources: Iterable[DerivedResource[T]],
context: Context,
) -> CompiledRoutine[T]:
try:
Expand All @@ -106,22 +113,33 @@ def _evaluate_internal(

updated_children: dict[str, CompiledRoutine[T]] = {
name: _evaluate_internal(
child, inputs, backend=backend, functions_map=functions_map, context=context.descend(name)
child,
inputs,
backend=backend,
functions_map=functions_map,
derived_resources=derived_resources,
context=context.descend(name),
)
for name, child in compiled_routine.children.items()
}

children_variables: dict[str, TExpr] = collect_children_variables(updated_children)

return replace(
compiled_routine,
input_params=sorted(set(compiled_routine.input_params).difference(inputs)),
ports=evaluate_ports(compiled_routine.ports, inputs, backend, functions_map),
resources=evaluate_resources(compiled_routine.resources, inputs | children_variables, backend, functions_map),
constraints=new_constraints,
repetition=evaluate_repetition(compiled_routine.repetition, inputs, backend, functions_map),
children=updated_children,
first_pass_resources=evaluate_resources(
compiled_routine.first_pass_resources, inputs | children_variables, backend, functions_map
return add_derived_resources(
replace(
compiled_routine,
input_params=sorted(set(compiled_routine.input_params).difference(inputs)),
ports=evaluate_ports(compiled_routine.ports, inputs, backend, functions_map),
resources=evaluate_resources(
compiled_routine.resources, inputs | children_variables, backend, functions_map
),
constraints=new_constraints,
repetition=evaluate_repetition(compiled_routine.repetition, inputs, backend, functions_map),
children=updated_children,
first_pass_resources=evaluate_resources(
compiled_routine.first_pass_resources, inputs | children_variables, backend, functions_map
),
),
backend,
derived_resources,
)
1 change: 1 addition & 0 deletions src/bartiq/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The integrations submodule permits LaTeX pretty-printing, and implements Jupyter notebook widgets."""

from .latex import routine_to_latex

__all__ = ["routine_to_latex"]
Expand Down
Loading
Loading