Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/amplitude-analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@
},
"outputs": [],
"source": [
"parameter_names = {symbol.name for symbol in model.parameter_defaults}\n",
"parameter_names = {str(s) for s in model.parameter_defaults}\n",
"not_in_model = {\n",
" par_name for par_name in initial_parameters if par_name not in parameter_names\n",
"}\n",
Expand Down Expand Up @@ -988,7 +988,7 @@
"metadata": {},
"outputs": [],
"source": [
"free_parameters = {p for p in model.parameter_defaults if p.name in initial_parameters}\n",
"free_parameters = {p for p in model.parameter_defaults if str(p) in initial_parameters}\n",
"fixed_parameters = {\n",
" p: v for p, v in model.parameter_defaults.items() if p not in free_parameters\n",
"}\n",
Expand Down Expand Up @@ -1296,7 +1296,7 @@
"free_parameter_symbols = [\n",
" symbol\n",
" for symbol in model.parameter_defaults\n",
" if symbol.name in set(initial_parameters)\n",
" if str(symbol) in set(initial_parameters)\n",
"]"
]
},
Expand Down
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def get_tensorflow_url() -> str:
"show_toc_level": 2,
"use_download_button": False,
"use_edit_page_button": True,
"use_fullscreen_button": False,
"use_issues_button": True,
"use_repository_button": True,
"use_source_button": True,
Expand All @@ -203,7 +204,6 @@ def get_tensorflow_url() -> str:
"matplotlib": (f"https://matplotlib.org/{pin('matplotlib')}", None),
"numpy": (f"https://numpy.org/doc/{pin_minor('numpy')}", None),
"pandas": (f"https://pandas.pydata.org/pandas-docs/version/{pin('pandas')}", None),
"pwa": ("https://pwa.readthedocs.io", None),
"python": ("https://docs.python.org/3", None),
"qrules": (f"https://qrules.readthedocs.io/{pin('qrules')}", None),
"scipy": (get_scipy_url(), None),
Expand Down Expand Up @@ -251,6 +251,7 @@ def get_tensorflow_url() -> str:
project = REPO_TITLE
pygments_style = "sphinx"
release = get_package_version("tensorwaves")
suppress_warnings = ["myst.directive_unknown"]
thebe_config = {
"repository_url": html_theme_options["repository_url"],
"repository_branch": html_theme_options["repository_branch"],
Expand Down
2 changes: 1 addition & 1 deletion docs/usage/caching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@
" # SymbolIdentifiable because of alphabetical sorting in dotprint\n",
" @classmethod\n",
" def from_symbol(cls, symbol):\n",
" return SymbolIdentifiable(symbol.name, **symbol.assumptions0)\n",
" return SymbolIdentifiable(str(symbol), **symbol.assumptions0)\n",
"\n",
"\n",
"dot_style = (\n",
Expand Down
8 changes: 4 additions & 4 deletions docs/usage/faster-lambdify.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,18 @@
"for symbol, definition in sub_expressions.items():\n",
" dot = sp.dotprint(definition, bgcolor=\"none\")\n",
" graph = graphviz.Source(dot)\n",
" graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n",
" graph.render(filename=f\"sub_expr_{symbol}\", format=\"svg\")\n",
"\n",
"html = \"<table>\\n\"\n",
"html += \" <tr>\\n\"\n",
"html += \"\".join(\n",
" f' <th style=\"text-align:center; background-color:white\">{symbol.name}</th>\\n'\n",
" f' <th style=\"text-align:center; background-color:white\">{symbol}</th>\\n'\n",
" for symbol in sub_expressions\n",
")\n",
"html += \" </tr>\\n\"\n",
"html += \" <tr>\\n\"\n",
"for symbol in sub_expressions:\n",
" svg = SVG(f\"sub_expr_{symbol.name}.svg\").data\n",
" svg = SVG(f\"sub_expr_{symbol}.svg\").data\n",
" html += f' <td style=\"background-color:white\">{svg}</td>\\n'\n",
"html += \" </tr>\\n\"\n",
"html += \"</table>\"\n",
Expand Down Expand Up @@ -286,7 +286,7 @@
"outputs": [],
"source": [
"expression = model.expression.doit()\n",
"sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)"
"sorted_symbols = sorted(expression.free_symbols, key=str)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions src/tensorwaves/data/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ def from_sympy(
max_complexity: int | None = None,
) -> SympyDataTransformer:
expanded_expressions: dict[str, sp.Expr] = {
k.name: expr.doit() for k, expr in expressions.items()
str(k): expr.doit() for k, expr in expressions.items()
}
free_symbols: set[sp.Symbol] = set()
for expr in expanded_expressions.values():
free_symbols |= _get_free_symbols(expr)
ordered_symbols = tuple(sorted(free_symbols, key=lambda s: s.name))
ordered_symbols = tuple(sorted(free_symbols, key=str))
argument_order = tuple(map(str, ordered_symbols))
functions = {}
for variable_name, expr in expanded_expressions.items():
Expand Down
44 changes: 38 additions & 6 deletions src/tensorwaves/function/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar

from tqdm.auto import tqdm

Expand All @@ -22,6 +22,8 @@

from tensorwaves.interface import ParameterValue

Symbolic = TypeVar("Symbolic", bound=sp.Basic)

_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -60,8 +62,9 @@ def create_function(
>>> function(data).tolist()
[0.0, 2.0, 8.0, 18.0]
"""
expression = _substitute_matrix_elements(expression)
free_symbols = _get_free_symbols(expression)
sorted_symbols = sorted(free_symbols, key=lambda s: s.name)
sorted_symbols = sorted(free_symbols, key=str)
lambdified_function = _lambdify_normal_or_fast(
expression=expression,
symbols=sorted_symbols,
Expand Down Expand Up @@ -119,10 +122,11 @@ def create_parametrized_function( # noqa: PLR0913
>>> function(data).tolist()
[0.0, 0.0, 0.0, 0.0, 0.0]
"""
expression = _substitute_matrix_elements(expression)
free_symbols = _get_free_symbols(expression)
parameter_set = set(parameters)
parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name)
data_symbols = sorted(free_symbols - parameter_set, key=lambda s: s.name)
parameter_symbols = sorted(free_symbols & parameter_set, key=str)
data_symbols = sorted(free_symbols - parameter_set, key=str)
sorted_symbols = tuple(data_symbols + parameter_symbols) # for partial+gradient
lambdified_function = _lambdify_normal_or_fast(
expression=expression,
Expand All @@ -139,6 +143,34 @@ def create_parametrized_function( # noqa: PLR0913
)


def _substitute_matrix_elements(expression: Symbolic) -> Symbolic:
"""Substitute elements of matrix symbols with actual symbol objects.

This is a workaround for the fact that the `~sympy.core.basic.Basic.free_symbols` of
an expression containing `~sympy.matrices.expressions.MatrixSymbol`s does not
contain the individual elements of these matrix symbols, but only the matrix symbols
themselves.

>>> import sympy as sp
>>> M = sp.MatrixSymbol("M", 2, 2)
>>> expr = M[0, 0] ** 2 + M[1, 1] ** 2
>>> expr.free_symbols
{M}
>>> new_expr = _substitute_matrix_elements(expr)
>>> sorted(new_expr.free_symbols, key=str)
[M[0, 0], M[1, 1]]
"""
import sympy as sp

return expression.xreplace({
element: sp.Symbol(str(element), **element.assumptions0)
for matrix in expression.free_symbols
if isinstance(matrix, sp.MatrixSymbol)
for row in matrix.as_explicit().tolist()
for element in row
})


def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]:
"""Get free symbols in an expression, excluding IndexedBase.

Expand All @@ -154,7 +186,7 @@ def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]:

free_symbols: set[sp.Symbol] = expression.free_symbols # type: ignore[assignment]
index_bases = {
sp.Symbol(s.base.name, **s.assumptions0)
sp.Symbol(str(s.base), **s.assumptions0)
for s in free_symbols
if isinstance(s, sp.Indexed)
}
Expand Down Expand Up @@ -324,7 +356,7 @@ def fast_lambdify( # noqa: PLR0913
top_expression, symbols, backend, use_cse=use_cse, use_jit=use_jit
)

sorted_top_symbols = sorted(sub_expressions, key=lambda s: s.name)
sorted_top_symbols = sorted(sub_expressions, key=str)
top_function = lambdify(
top_expression, sorted_top_symbols, backend, use_cse=use_cse, use_jit=use_jit
)
Expand Down
10 changes: 9 additions & 1 deletion tests/function/test_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def test_create_function_indexed_symbol(backend: str):
assert func.argument_order == ("A[0]", "A[1]")


@pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"])
def test_create_function_matrix_symbol(backend: str):
M = sp.MatrixSymbol("M", 2, 2) # noqa: N806
expr = M[0, 0] ** 2 + M[1, 1] ** 2
func = create_function(expr, backend=backend)
assert func.argument_order == ("M[0, 0]", "M[1, 1]")


@pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"])
@pytest.mark.parametrize("max_complexity", [0, 1, 2, 3, 4, 5])
@pytest.mark.parametrize("use_cse", [False, True])
Expand Down Expand Up @@ -176,7 +184,7 @@ def test_split_expression():
assert expression == top_expr.xreplace(sub_expressions)

free_symbols: set[sp.Symbol] = top_expr.free_symbols # type: ignore[assignment]
sub_symbols = sorted(free_symbols, key=lambda s: s.name)
sub_symbols = sorted(free_symbols, key=str)
assert len(sub_symbols) == 3
f0, f1, f2 = tuple(sub_symbols)
assert f0 is a
Expand Down