From 0ef36fe06c791e4abd926c0bc3468e7d611cacb4 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:04:18 +0100 Subject: [PATCH 1/6] DOC: disable full-screen button on website --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index d152e241..35e92bc0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -188,6 +188,7 @@ def get_tensorflow_url() -> str: "show_toc_level": 2, "use_download_button": False, "use_edit_page_button": True, + "use_fullscreen_button": False, "use_issues_button": True, "use_repository_button": True, "use_source_button": True, From 56d9cf9a3f8f0bb92f423a55996d23e9706d233a Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:04:19 +0100 Subject: [PATCH 2/6] DX: add `MatrixSymbol` test --- tests/function/test_sympy.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/function/test_sympy.py b/tests/function/test_sympy.py index a777c3f0..fc565991 100644 --- a/tests/function/test_sympy.py +++ b/tests/function/test_sympy.py @@ -98,6 +98,14 @@ def test_create_function_indexed_symbol(backend: str): assert func.argument_order == ("A[0]", "A[1]") +@pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"]) +def test_create_function_matrix_symbol(backend: str): + M = sp.MatrixSymbol("M", 2, 2) # noqa: N806 + expr = M[0, 0] ** 2 + M[1, 1] ** 2 + func = create_function(expr, backend=backend) + assert func.argument_order == ("M[0, 0]", "M[1, 1]") + + @pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"]) @pytest.mark.parametrize("max_complexity", [0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize("use_cse", [False, True]) From ad0d0604c2b19246a67a0bb7fb7b321c8641aa76 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:04:20 +0100 Subject: [PATCH 3/6] FIX: support lambdifying `Expr`s with `MatrixSymbol`s --- src/tensorwaves/function/sympy/__init__.py | 34 +++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/tensorwaves/function/sympy/__init__.py b/src/tensorwaves/function/sympy/__init__.py index b7d9ba9a..7290fab7 100644 --- a/src/tensorwaves/function/sympy/__init__.py +++ b/src/tensorwaves/function/sympy/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar from tqdm.auto import tqdm @@ -22,6 +22,8 @@ from tensorwaves.interface import ParameterValue + Symbolic = TypeVar("Symbolic", bound=sp.Basic) + _LOGGER = logging.getLogger(__name__) @@ -60,6 +62,7 @@ def create_function( >>> function(data).tolist() [0.0, 2.0, 8.0, 18.0] """ + expression = _substitute_matrix_elements(expression) free_symbols = _get_free_symbols(expression) sorted_symbols = sorted(free_symbols, key=lambda s: s.name) lambdified_function = _lambdify_normal_or_fast( @@ -119,6 +122,7 @@ def create_parametrized_function( # noqa: PLR0913 >>> function(data).tolist() [0.0, 0.0, 0.0, 0.0, 0.0] """ + expression = _substitute_matrix_elements(expression) free_symbols = _get_free_symbols(expression) parameter_set = set(parameters) parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name) @@ -139,6 +143,34 @@ def create_parametrized_function( # noqa: PLR0913 ) +def _substitute_matrix_elements(expression: Symbolic) -> Symbolic: + """Substitute elements of matrix symbols with actual symbol objects. + + This is a workaround for the fact that the `~sympy.core.basic.Basic.free_symbols` of + an expression containing `~sympy.matrices.expressions.MatrixSymbol`s does not + contain the individual elements of these matrix symbols, but only the matrix symbols + themselves. + + >>> import sympy as sp + >>> M = sp.MatrixSymbol("M", 2, 2) + >>> expr = M[0, 0] ** 2 + M[1, 1] ** 2 + >>> expr.free_symbols + {M} + >>> new_expr = _substitute_matrix_elements(expr) + >>> sorted(new_expr.free_symbols, key=str) + [M[0, 0], M[1, 1]] + """ + import sympy as sp + + return expression.xreplace({ + element: sp.Symbol(str(element), **element.assumptions0) + for matrix in expression.free_symbols + if isinstance(matrix, sp.MatrixSymbol) + for row in matrix.as_explicit().tolist() + for element in row + }) + + def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]: """Get free symbols in an expression, excluding IndexedBase. From a3bab677f525f0d21d2c211feeb5ef7ec6bb314d Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:04:21 +0100 Subject: [PATCH 4/6] ENH: accept any `sp.Basic` in sorting --- docs/amplitude-analysis.ipynb | 6 +++--- docs/usage/caching.ipynb | 2 +- docs/usage/faster-lambdify.ipynb | 8 ++++---- src/tensorwaves/data/transform.py | 4 ++-- src/tensorwaves/function/sympy/__init__.py | 10 +++++----- tests/function/test_sympy.py | 2 +- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/amplitude-analysis.ipynb b/docs/amplitude-analysis.ipynb index f2f39658..2539596e 100644 --- a/docs/amplitude-analysis.ipynb +++ b/docs/amplitude-analysis.ipynb @@ -930,7 +930,7 @@ }, "outputs": [], "source": [ - "parameter_names = {symbol.name for symbol in model.parameter_defaults}\n", + "parameter_names = {str(s) for s in model.parameter_defaults}\n", "not_in_model = {\n", " par_name for par_name in initial_parameters if par_name not in parameter_names\n", "}\n", @@ -988,7 +988,7 @@ "metadata": {}, "outputs": [], "source": [ - "free_parameters = {p for p in model.parameter_defaults if p.name in initial_parameters}\n", + "free_parameters = {p for p in model.parameter_defaults if str(p) in initial_parameters}\n", "fixed_parameters = {\n", " p: v for p, v in model.parameter_defaults.items() if p not in free_parameters\n", "}\n", @@ -1296,7 +1296,7 @@ "free_parameter_symbols = [\n", " symbol\n", " for symbol in model.parameter_defaults\n", - " if symbol.name in set(initial_parameters)\n", + " if str(symbol) in set(initial_parameters)\n", "]" ] }, diff --git a/docs/usage/caching.ipynb b/docs/usage/caching.ipynb index e292dc11..adb43c9b 100644 --- a/docs/usage/caching.ipynb +++ b/docs/usage/caching.ipynb @@ -165,7 +165,7 @@ " # SymbolIdentifiable because of alphabetical sorting in dotprint\n", " @classmethod\n", " def from_symbol(cls, symbol):\n", - " return SymbolIdentifiable(symbol.name, **symbol.assumptions0)\n", + " return SymbolIdentifiable(str(symbol), **symbol.assumptions0)\n", "\n", "\n", "dot_style = (\n", diff --git a/docs/usage/faster-lambdify.ipynb b/docs/usage/faster-lambdify.ipynb index 8279472f..1ab2af3e 100644 --- a/docs/usage/faster-lambdify.ipynb +++ b/docs/usage/faster-lambdify.ipynb @@ -230,18 +230,18 @@ "for symbol, definition in sub_expressions.items():\n", " dot = sp.dotprint(definition, bgcolor=\"none\")\n", " graph = graphviz.Source(dot)\n", - " graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n", + " graph.render(filename=f\"sub_expr_{symbol}\", format=\"svg\")\n", "\n", "html = \"
| {symbol.name} | \\n'\n", + " f'{symbol} | \\n'\n", " for symbol in sub_expressions\n", ")\n", "html += \"
|---|---|
| {svg} | \\n'\n", "html += \"