From 0ef36fe06c791e4abd926c0bc3468e7d611cacb4 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:04:18 +0100 Subject: [PATCH 1/6] DOC: disable full-screen button on website --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index d152e241..35e92bc0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -188,6 +188,7 @@ def get_tensorflow_url() -> str: "show_toc_level": 2, "use_download_button": False, "use_edit_page_button": True, + "use_fullscreen_button": False, "use_issues_button": True, "use_repository_button": True, "use_source_button": True, From 56d9cf9a3f8f0bb92f423a55996d23e9706d233a Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:04:19 +0100 Subject: [PATCH 2/6] DX: add `MatrixSymbol` test --- tests/function/test_sympy.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/function/test_sympy.py b/tests/function/test_sympy.py index a777c3f0..fc565991 100644 --- a/tests/function/test_sympy.py +++ b/tests/function/test_sympy.py @@ -98,6 +98,14 @@ def test_create_function_indexed_symbol(backend: str): assert func.argument_order == ("A[0]", "A[1]") +@pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"]) +def test_create_function_matrix_symbol(backend: str): + M = sp.MatrixSymbol("M", 2, 2) # noqa: N806 + expr = M[0, 0] ** 2 + M[1, 1] ** 2 + func = create_function(expr, backend=backend) + assert func.argument_order == ("M[0, 0]", "M[1, 1]") + + @pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"]) @pytest.mark.parametrize("max_complexity", [0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize("use_cse", [False, True]) From ad0d0604c2b19246a67a0bb7fb7b321c8641aa76 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:04:20 +0100 Subject: [PATCH 3/6] FIX: support lambdifying `Expr`s with `MatrixSymbol`s --- src/tensorwaves/function/sympy/__init__.py | 34 +++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/tensorwaves/function/sympy/__init__.py b/src/tensorwaves/function/sympy/__init__.py index b7d9ba9a..7290fab7 100644 --- a/src/tensorwaves/function/sympy/__init__.py +++ b/src/tensorwaves/function/sympy/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar from tqdm.auto import tqdm @@ -22,6 +22,8 @@ from tensorwaves.interface import ParameterValue + Symbolic = TypeVar("Symbolic", bound=sp.Basic) + _LOGGER = logging.getLogger(__name__) @@ -60,6 +62,7 @@ def create_function( >>> function(data).tolist() [0.0, 2.0, 8.0, 18.0] """ + expression = _substitute_matrix_elements(expression) free_symbols = _get_free_symbols(expression) sorted_symbols = sorted(free_symbols, key=lambda s: s.name) lambdified_function = _lambdify_normal_or_fast( @@ -119,6 +122,7 @@ def create_parametrized_function( # noqa: PLR0913 >>> function(data).tolist() [0.0, 0.0, 0.0, 0.0, 0.0] """ + expression = _substitute_matrix_elements(expression) free_symbols = _get_free_symbols(expression) parameter_set = set(parameters) parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name) @@ -139,6 +143,34 @@ def create_parametrized_function( # noqa: PLR0913 ) +def _substitute_matrix_elements(expression: Symbolic) -> Symbolic: + """Substitute elements of matrix symbols with actual symbol objects. + + This is a workaround for the fact that the `~sympy.core.basic.Basic.free_symbols` of + an expression containing `~sympy.matrices.expressions.MatrixSymbol`s does not + contain the individual elements of these matrix symbols, but only the matrix symbols + themselves. + + >>> import sympy as sp + >>> M = sp.MatrixSymbol("M", 2, 2) + >>> expr = M[0, 0] ** 2 + M[1, 1] ** 2 + >>> expr.free_symbols + {M} + >>> new_expr = _substitute_matrix_elements(expr) + >>> sorted(new_expr.free_symbols, key=str) + [M[0, 0], M[1, 1]] + """ + import sympy as sp + + return expression.xreplace({ + element: sp.Symbol(str(element), **element.assumptions0) + for matrix in expression.free_symbols + if isinstance(matrix, sp.MatrixSymbol) + for row in matrix.as_explicit().tolist() + for element in row + }) + + def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]: """Get free symbols in an expression, excluding IndexedBase. From a3bab677f525f0d21d2c211feeb5ef7ec6bb314d Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:04:21 +0100 Subject: [PATCH 4/6] ENH: accept any `sp.Basic` in sorting --- docs/amplitude-analysis.ipynb | 6 +++--- docs/usage/caching.ipynb | 2 +- docs/usage/faster-lambdify.ipynb | 8 ++++---- src/tensorwaves/data/transform.py | 4 ++-- src/tensorwaves/function/sympy/__init__.py | 10 +++++----- tests/function/test_sympy.py | 2 +- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/amplitude-analysis.ipynb b/docs/amplitude-analysis.ipynb index f2f39658..2539596e 100644 --- a/docs/amplitude-analysis.ipynb +++ b/docs/amplitude-analysis.ipynb @@ -930,7 +930,7 @@ }, "outputs": [], "source": [ - "parameter_names = {symbol.name for symbol in model.parameter_defaults}\n", + "parameter_names = {str(s) for s in model.parameter_defaults}\n", "not_in_model = {\n", " par_name for par_name in initial_parameters if par_name not in parameter_names\n", "}\n", @@ -988,7 +988,7 @@ "metadata": {}, "outputs": [], "source": [ - "free_parameters = {p for p in model.parameter_defaults if p.name in initial_parameters}\n", + "free_parameters = {p for p in model.parameter_defaults if str(p) in initial_parameters}\n", "fixed_parameters = {\n", " p: v for p, v in model.parameter_defaults.items() if p not in free_parameters\n", "}\n", @@ -1296,7 +1296,7 @@ "free_parameter_symbols = [\n", " symbol\n", " for symbol in model.parameter_defaults\n", - " if symbol.name in set(initial_parameters)\n", + " if str(symbol) in set(initial_parameters)\n", "]" ] }, diff --git a/docs/usage/caching.ipynb b/docs/usage/caching.ipynb index e292dc11..adb43c9b 100644 --- a/docs/usage/caching.ipynb +++ b/docs/usage/caching.ipynb @@ -165,7 +165,7 @@ " # SymbolIdentifiable because of alphabetical sorting in dotprint\n", " @classmethod\n", " def from_symbol(cls, symbol):\n", - " return SymbolIdentifiable(symbol.name, **symbol.assumptions0)\n", + " return SymbolIdentifiable(str(symbol), **symbol.assumptions0)\n", "\n", "\n", "dot_style = (\n", diff --git a/docs/usage/faster-lambdify.ipynb b/docs/usage/faster-lambdify.ipynb index 8279472f..1ab2af3e 100644 --- a/docs/usage/faster-lambdify.ipynb +++ b/docs/usage/faster-lambdify.ipynb @@ -230,18 +230,18 @@ "for symbol, definition in sub_expressions.items():\n", " dot = sp.dotprint(definition, bgcolor=\"none\")\n", " graph = graphviz.Source(dot)\n", - " graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n", + " graph.render(filename=f\"sub_expr_{symbol}\", format=\"svg\")\n", "\n", "html = \"\\n\"\n", "html += \" \\n\"\n", "html += \"\".join(\n", - " f' \\n'\n", + " f' \\n'\n", " for symbol in sub_expressions\n", ")\n", "html += \" \\n\"\n", "html += \" \\n\"\n", "for symbol in sub_expressions:\n", - " svg = SVG(f\"sub_expr_{symbol.name}.svg\").data\n", + " svg = SVG(f\"sub_expr_{symbol}.svg\").data\n", " html += f' \\n'\n", "html += \" \\n\"\n", "html += \"
{symbol.name}{symbol}
{svg}
\"\n", @@ -286,7 +286,7 @@ "outputs": [], "source": [ "expression = model.expression.doit()\n", - "sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)" + "sorted_symbols = sorted(expression.free_symbols, key=str)" ] }, { diff --git a/src/tensorwaves/data/transform.py b/src/tensorwaves/data/transform.py index 7db6b5c0..5446753d 100644 --- a/src/tensorwaves/data/transform.py +++ b/src/tensorwaves/data/transform.py @@ -88,12 +88,12 @@ def from_sympy( max_complexity: int | None = None, ) -> SympyDataTransformer: expanded_expressions: dict[str, sp.Expr] = { - k.name: expr.doit() for k, expr in expressions.items() + str(k): expr.doit() for k, expr in expressions.items() } free_symbols: set[sp.Symbol] = set() for expr in expanded_expressions.values(): free_symbols |= _get_free_symbols(expr) - ordered_symbols = tuple(sorted(free_symbols, key=lambda s: s.name)) + ordered_symbols = tuple(sorted(free_symbols, key=str)) argument_order = tuple(map(str, ordered_symbols)) functions = {} for variable_name, expr in expanded_expressions.items(): diff --git a/src/tensorwaves/function/sympy/__init__.py b/src/tensorwaves/function/sympy/__init__.py index 7290fab7..f53f44af 100644 --- a/src/tensorwaves/function/sympy/__init__.py +++ b/src/tensorwaves/function/sympy/__init__.py @@ -64,7 +64,7 @@ def create_function( """ expression = _substitute_matrix_elements(expression) free_symbols = _get_free_symbols(expression) - sorted_symbols = sorted(free_symbols, key=lambda s: s.name) + sorted_symbols = sorted(free_symbols, key=str) lambdified_function = _lambdify_normal_or_fast( expression=expression, symbols=sorted_symbols, @@ -125,8 +125,8 @@ def create_parametrized_function( # noqa: PLR0913 expression = _substitute_matrix_elements(expression) free_symbols = _get_free_symbols(expression) parameter_set = set(parameters) - parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name) - data_symbols = sorted(free_symbols - parameter_set, key=lambda s: s.name) + parameter_symbols = sorted(free_symbols & parameter_set, key=str) + data_symbols = sorted(free_symbols - parameter_set, key=str) sorted_symbols = tuple(data_symbols + parameter_symbols) # for partial+gradient lambdified_function = _lambdify_normal_or_fast( expression=expression, @@ -186,7 +186,7 @@ def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]: free_symbols: set[sp.Symbol] = expression.free_symbols # type: ignore[assignment] index_bases = { - sp.Symbol(s.base.name, **s.assumptions0) + sp.Symbol(str(s.base), **s.assumptions0) for s in free_symbols if isinstance(s, sp.Indexed) } @@ -356,7 +356,7 @@ def fast_lambdify( # noqa: PLR0913 top_expression, symbols, backend, use_cse=use_cse, use_jit=use_jit ) - sorted_top_symbols = sorted(sub_expressions, key=lambda s: s.name) + sorted_top_symbols = sorted(sub_expressions, key=str) top_function = lambdify( top_expression, sorted_top_symbols, backend, use_cse=use_cse, use_jit=use_jit ) diff --git a/tests/function/test_sympy.py b/tests/function/test_sympy.py index fc565991..84bad847 100644 --- a/tests/function/test_sympy.py +++ b/tests/function/test_sympy.py @@ -184,7 +184,7 @@ def test_split_expression(): assert expression == top_expr.xreplace(sub_expressions) free_symbols: set[sp.Symbol] = top_expr.free_symbols # type: ignore[assignment] - sub_symbols = sorted(free_symbols, key=lambda s: s.name) + sub_symbols = sorted(free_symbols, key=str) assert len(sub_symbols) == 3 f0, f1, f2 = tuple(sub_symbols) assert f0 is a From cccd73c89226577e95ae7c9cf52405fd10ffcd93 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:35:38 +0100 Subject: [PATCH 5/6] FIX: remove outdated `pwa` Sphinx inventory --- docs/conf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 35e92bc0..94797372 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -204,7 +204,6 @@ def get_tensorflow_url() -> str: "matplotlib": (f"https://matplotlib.org/{pin('matplotlib')}", None), "numpy": (f"https://numpy.org/doc/{pin_minor('numpy')}", None), "pandas": (f"https://pandas.pydata.org/pandas-docs/version/{pin('pandas')}", None), - "pwa": ("https://pwa.readthedocs.io", None), "python": ("https://docs.python.org/3", None), "qrules": (f"https://qrules.readthedocs.io/{pin('qrules')}", None), "scipy": (get_scipy_url(), None), From e11721337986b9d256378cb4388e9d2362a54067 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:36:27 +0100 Subject: [PATCH 6/6] DX: suppress `myst.directive_unknown` --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index 94797372..6302ee23 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -251,6 +251,7 @@ def get_tensorflow_url() -> str: project = REPO_TITLE pygments_style = "sphinx" release = get_package_version("tensorwaves") +suppress_warnings = ["myst.directive_unknown"] thebe_config = { "repository_url": html_theme_options["repository_url"], "repository_branch": html_theme_options["repository_branch"],