diff --git a/docs/amplitude-analysis.ipynb b/docs/amplitude-analysis.ipynb
index f2f39658a..2539596e2 100644
--- a/docs/amplitude-analysis.ipynb
+++ b/docs/amplitude-analysis.ipynb
@@ -930,7 +930,7 @@
},
"outputs": [],
"source": [
- "parameter_names = {symbol.name for symbol in model.parameter_defaults}\n",
+ "parameter_names = {str(s) for s in model.parameter_defaults}\n",
"not_in_model = {\n",
" par_name for par_name in initial_parameters if par_name not in parameter_names\n",
"}\n",
@@ -988,7 +988,7 @@
"metadata": {},
"outputs": [],
"source": [
- "free_parameters = {p for p in model.parameter_defaults if p.name in initial_parameters}\n",
+ "free_parameters = {p for p in model.parameter_defaults if str(p) in initial_parameters}\n",
"fixed_parameters = {\n",
" p: v for p, v in model.parameter_defaults.items() if p not in free_parameters\n",
"}\n",
@@ -1296,7 +1296,7 @@
"free_parameter_symbols = [\n",
" symbol\n",
" for symbol in model.parameter_defaults\n",
- " if symbol.name in set(initial_parameters)\n",
+ " if str(symbol) in set(initial_parameters)\n",
"]"
]
},
diff --git a/docs/conf.py b/docs/conf.py
index d152e2417..6302ee23a 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -188,6 +188,7 @@ def get_tensorflow_url() -> str:
"show_toc_level": 2,
"use_download_button": False,
"use_edit_page_button": True,
+ "use_fullscreen_button": False,
"use_issues_button": True,
"use_repository_button": True,
"use_source_button": True,
@@ -203,7 +204,6 @@ def get_tensorflow_url() -> str:
"matplotlib": (f"https://matplotlib.org/{pin('matplotlib')}", None),
"numpy": (f"https://numpy.org/doc/{pin_minor('numpy')}", None),
"pandas": (f"https://pandas.pydata.org/pandas-docs/version/{pin('pandas')}", None),
- "pwa": ("https://pwa.readthedocs.io", None),
"python": ("https://docs.python.org/3", None),
"qrules": (f"https://qrules.readthedocs.io/{pin('qrules')}", None),
"scipy": (get_scipy_url(), None),
@@ -251,6 +251,7 @@ def get_tensorflow_url() -> str:
project = REPO_TITLE
pygments_style = "sphinx"
release = get_package_version("tensorwaves")
+suppress_warnings = ["myst.directive_unknown"]
thebe_config = {
"repository_url": html_theme_options["repository_url"],
"repository_branch": html_theme_options["repository_branch"],
diff --git a/docs/usage/caching.ipynb b/docs/usage/caching.ipynb
index e292dc11b..adb43c9be 100644
--- a/docs/usage/caching.ipynb
+++ b/docs/usage/caching.ipynb
@@ -165,7 +165,7 @@
" # SymbolIdentifiable because of alphabetical sorting in dotprint\n",
" @classmethod\n",
" def from_symbol(cls, symbol):\n",
- " return SymbolIdentifiable(symbol.name, **symbol.assumptions0)\n",
+ " return SymbolIdentifiable(str(symbol), **symbol.assumptions0)\n",
"\n",
"\n",
"dot_style = (\n",
diff --git a/docs/usage/faster-lambdify.ipynb b/docs/usage/faster-lambdify.ipynb
index 8279472f6..1ab2af3e6 100644
--- a/docs/usage/faster-lambdify.ipynb
+++ b/docs/usage/faster-lambdify.ipynb
@@ -230,18 +230,18 @@
"for symbol, definition in sub_expressions.items():\n",
" dot = sp.dotprint(definition, bgcolor=\"none\")\n",
" graph = graphviz.Source(dot)\n",
- " graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n",
+ " graph.render(filename=f\"sub_expr_{symbol}\", format=\"svg\")\n",
"\n",
"html = \"
\\n\"\n",
"html += \" \\n\"\n",
"html += \"\".join(\n",
- " f' | {symbol.name} | \\n'\n",
+ " f' {symbol} | \\n'\n",
" for symbol in sub_expressions\n",
")\n",
"html += \"
\\n\"\n",
"html += \" \\n\"\n",
"for symbol in sub_expressions:\n",
- " svg = SVG(f\"sub_expr_{symbol.name}.svg\").data\n",
+ " svg = SVG(f\"sub_expr_{symbol}.svg\").data\n",
" html += f' | {svg} | \\n'\n",
"html += \"
\\n\"\n",
"html += \"
\"\n",
@@ -286,7 +286,7 @@
"outputs": [],
"source": [
"expression = model.expression.doit()\n",
- "sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)"
+ "sorted_symbols = sorted(expression.free_symbols, key=str)"
]
},
{
diff --git a/src/tensorwaves/data/transform.py b/src/tensorwaves/data/transform.py
index 7db6b5c04..5446753d1 100644
--- a/src/tensorwaves/data/transform.py
+++ b/src/tensorwaves/data/transform.py
@@ -88,12 +88,12 @@ def from_sympy(
max_complexity: int | None = None,
) -> SympyDataTransformer:
expanded_expressions: dict[str, sp.Expr] = {
- k.name: expr.doit() for k, expr in expressions.items()
+ str(k): expr.doit() for k, expr in expressions.items()
}
free_symbols: set[sp.Symbol] = set()
for expr in expanded_expressions.values():
free_symbols |= _get_free_symbols(expr)
- ordered_symbols = tuple(sorted(free_symbols, key=lambda s: s.name))
+ ordered_symbols = tuple(sorted(free_symbols, key=str))
argument_order = tuple(map(str, ordered_symbols))
functions = {}
for variable_name, expr in expanded_expressions.items():
diff --git a/src/tensorwaves/function/sympy/__init__.py b/src/tensorwaves/function/sympy/__init__.py
index b7d9ba9a6..f53f44af2 100644
--- a/src/tensorwaves/function/sympy/__init__.py
+++ b/src/tensorwaves/function/sympy/__init__.py
@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeVar
from tqdm.auto import tqdm
@@ -22,6 +22,8 @@
from tensorwaves.interface import ParameterValue
+ Symbolic = TypeVar("Symbolic", bound=sp.Basic)
+
_LOGGER = logging.getLogger(__name__)
@@ -60,8 +62,9 @@ def create_function(
>>> function(data).tolist()
[0.0, 2.0, 8.0, 18.0]
"""
+ expression = _substitute_matrix_elements(expression)
free_symbols = _get_free_symbols(expression)
- sorted_symbols = sorted(free_symbols, key=lambda s: s.name)
+ sorted_symbols = sorted(free_symbols, key=str)
lambdified_function = _lambdify_normal_or_fast(
expression=expression,
symbols=sorted_symbols,
@@ -119,10 +122,11 @@ def create_parametrized_function( # noqa: PLR0913
>>> function(data).tolist()
[0.0, 0.0, 0.0, 0.0, 0.0]
"""
+ expression = _substitute_matrix_elements(expression)
free_symbols = _get_free_symbols(expression)
parameter_set = set(parameters)
- parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name)
- data_symbols = sorted(free_symbols - parameter_set, key=lambda s: s.name)
+ parameter_symbols = sorted(free_symbols & parameter_set, key=str)
+ data_symbols = sorted(free_symbols - parameter_set, key=str)
sorted_symbols = tuple(data_symbols + parameter_symbols) # for partial+gradient
lambdified_function = _lambdify_normal_or_fast(
expression=expression,
@@ -139,6 +143,34 @@ def create_parametrized_function( # noqa: PLR0913
)
+def _substitute_matrix_elements(expression: Symbolic) -> Symbolic:
+ """Substitute elements of matrix symbols with actual symbol objects.
+
+ This is a workaround for the fact that the `~sympy.core.basic.Basic.free_symbols` of
+ an expression containing `~sympy.matrices.expressions.MatrixSymbol`s does not
+ contain the individual elements of these matrix symbols, but only the matrix symbols
+ themselves.
+
+ >>> import sympy as sp
+ >>> M = sp.MatrixSymbol("M", 2, 2)
+ >>> expr = M[0, 0] ** 2 + M[1, 1] ** 2
+ >>> expr.free_symbols
+ {M}
+ >>> new_expr = _substitute_matrix_elements(expr)
+ >>> sorted(new_expr.free_symbols, key=str)
+ [M[0, 0], M[1, 1]]
+ """
+ import sympy as sp
+
+ return expression.xreplace({
+ element: sp.Symbol(str(element), **element.assumptions0)
+ for matrix in expression.free_symbols
+ if isinstance(matrix, sp.MatrixSymbol)
+ for row in matrix.as_explicit().tolist()
+ for element in row
+ })
+
+
def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]:
"""Get free symbols in an expression, excluding IndexedBase.
@@ -154,7 +186,7 @@ def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]:
free_symbols: set[sp.Symbol] = expression.free_symbols # type: ignore[assignment]
index_bases = {
- sp.Symbol(s.base.name, **s.assumptions0)
+ sp.Symbol(str(s.base), **s.assumptions0)
for s in free_symbols
if isinstance(s, sp.Indexed)
}
@@ -324,7 +356,7 @@ def fast_lambdify( # noqa: PLR0913
top_expression, symbols, backend, use_cse=use_cse, use_jit=use_jit
)
- sorted_top_symbols = sorted(sub_expressions, key=lambda s: s.name)
+ sorted_top_symbols = sorted(sub_expressions, key=str)
top_function = lambdify(
top_expression, sorted_top_symbols, backend, use_cse=use_cse, use_jit=use_jit
)
diff --git a/tests/function/test_sympy.py b/tests/function/test_sympy.py
index a777c3f06..84bad8477 100644
--- a/tests/function/test_sympy.py
+++ b/tests/function/test_sympy.py
@@ -98,6 +98,14 @@ def test_create_function_indexed_symbol(backend: str):
assert func.argument_order == ("A[0]", "A[1]")
+@pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"])
+def test_create_function_matrix_symbol(backend: str):
+ M = sp.MatrixSymbol("M", 2, 2) # noqa: N806
+ expr = M[0, 0] ** 2 + M[1, 1] ** 2
+ func = create_function(expr, backend=backend)
+ assert func.argument_order == ("M[0, 0]", "M[1, 1]")
+
+
@pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"])
@pytest.mark.parametrize("max_complexity", [0, 1, 2, 3, 4, 5])
@pytest.mark.parametrize("use_cse", [False, True])
@@ -176,7 +184,7 @@ def test_split_expression():
assert expression == top_expr.xreplace(sub_expressions)
free_symbols: set[sp.Symbol] = top_expr.free_symbols # type: ignore[assignment]
- sub_symbols = sorted(free_symbols, key=lambda s: s.name)
+ sub_symbols = sorted(free_symbols, key=str)
assert len(sub_symbols) == 3
f0, f1, f2 = tuple(sub_symbols)
assert f0 is a