diff --git a/docs/amplitude-analysis.ipynb b/docs/amplitude-analysis.ipynb index f2f39658a..2539596e2 100644 --- a/docs/amplitude-analysis.ipynb +++ b/docs/amplitude-analysis.ipynb @@ -930,7 +930,7 @@ }, "outputs": [], "source": [ - "parameter_names = {symbol.name for symbol in model.parameter_defaults}\n", + "parameter_names = {str(s) for s in model.parameter_defaults}\n", "not_in_model = {\n", " par_name for par_name in initial_parameters if par_name not in parameter_names\n", "}\n", @@ -988,7 +988,7 @@ "metadata": {}, "outputs": [], "source": [ - "free_parameters = {p for p in model.parameter_defaults if p.name in initial_parameters}\n", + "free_parameters = {p for p in model.parameter_defaults if str(p) in initial_parameters}\n", "fixed_parameters = {\n", " p: v for p, v in model.parameter_defaults.items() if p not in free_parameters\n", "}\n", @@ -1296,7 +1296,7 @@ "free_parameter_symbols = [\n", " symbol\n", " for symbol in model.parameter_defaults\n", - " if symbol.name in set(initial_parameters)\n", + " if str(symbol) in set(initial_parameters)\n", "]" ] }, diff --git a/docs/conf.py b/docs/conf.py index d152e2417..6302ee23a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -188,6 +188,7 @@ def get_tensorflow_url() -> str: "show_toc_level": 2, "use_download_button": False, "use_edit_page_button": True, + "use_fullscreen_button": False, "use_issues_button": True, "use_repository_button": True, "use_source_button": True, @@ -203,7 +204,6 @@ def get_tensorflow_url() -> str: "matplotlib": (f"https://matplotlib.org/{pin('matplotlib')}", None), "numpy": (f"https://numpy.org/doc/{pin_minor('numpy')}", None), "pandas": (f"https://pandas.pydata.org/pandas-docs/version/{pin('pandas')}", None), - "pwa": ("https://pwa.readthedocs.io", None), "python": ("https://docs.python.org/3", None), "qrules": (f"https://qrules.readthedocs.io/{pin('qrules')}", None), "scipy": (get_scipy_url(), None), @@ -251,6 +251,7 @@ def get_tensorflow_url() -> str: project = REPO_TITLE pygments_style = "sphinx" release = get_package_version("tensorwaves") +suppress_warnings = ["myst.directive_unknown"] thebe_config = { "repository_url": html_theme_options["repository_url"], "repository_branch": html_theme_options["repository_branch"], diff --git a/docs/usage/caching.ipynb b/docs/usage/caching.ipynb index e292dc11b..adb43c9be 100644 --- a/docs/usage/caching.ipynb +++ b/docs/usage/caching.ipynb @@ -165,7 +165,7 @@ " # SymbolIdentifiable because of alphabetical sorting in dotprint\n", " @classmethod\n", " def from_symbol(cls, symbol):\n", - " return SymbolIdentifiable(symbol.name, **symbol.assumptions0)\n", + " return SymbolIdentifiable(str(symbol), **symbol.assumptions0)\n", "\n", "\n", "dot_style = (\n", diff --git a/docs/usage/faster-lambdify.ipynb b/docs/usage/faster-lambdify.ipynb index 8279472f6..1ab2af3e6 100644 --- a/docs/usage/faster-lambdify.ipynb +++ b/docs/usage/faster-lambdify.ipynb @@ -230,18 +230,18 @@ "for symbol, definition in sub_expressions.items():\n", " dot = sp.dotprint(definition, bgcolor=\"none\")\n", " graph = graphviz.Source(dot)\n", - " graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n", + " graph.render(filename=f\"sub_expr_{symbol}\", format=\"svg\")\n", "\n", "html = \"\\n\"\n", "html += \" \\n\"\n", "html += \"\".join(\n", - " f' \\n'\n", + " f' \\n'\n", " for symbol in sub_expressions\n", ")\n", "html += \" \\n\"\n", "html += \" \\n\"\n", "for symbol in sub_expressions:\n", - " svg = SVG(f\"sub_expr_{symbol.name}.svg\").data\n", + " svg = SVG(f\"sub_expr_{symbol}.svg\").data\n", " html += f' \\n'\n", "html += \" \\n\"\n", "html += \"
{symbol.name}{symbol}
{svg}
\"\n", @@ -286,7 +286,7 @@ "outputs": [], "source": [ "expression = model.expression.doit()\n", - "sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)" + "sorted_symbols = sorted(expression.free_symbols, key=str)" ] }, { diff --git a/src/tensorwaves/data/transform.py b/src/tensorwaves/data/transform.py index 7db6b5c04..5446753d1 100644 --- a/src/tensorwaves/data/transform.py +++ b/src/tensorwaves/data/transform.py @@ -88,12 +88,12 @@ def from_sympy( max_complexity: int | None = None, ) -> SympyDataTransformer: expanded_expressions: dict[str, sp.Expr] = { - k.name: expr.doit() for k, expr in expressions.items() + str(k): expr.doit() for k, expr in expressions.items() } free_symbols: set[sp.Symbol] = set() for expr in expanded_expressions.values(): free_symbols |= _get_free_symbols(expr) - ordered_symbols = tuple(sorted(free_symbols, key=lambda s: s.name)) + ordered_symbols = tuple(sorted(free_symbols, key=str)) argument_order = tuple(map(str, ordered_symbols)) functions = {} for variable_name, expr in expanded_expressions.items(): diff --git a/src/tensorwaves/function/sympy/__init__.py b/src/tensorwaves/function/sympy/__init__.py index b7d9ba9a6..f53f44af2 100644 --- a/src/tensorwaves/function/sympy/__init__.py +++ b/src/tensorwaves/function/sympy/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar from tqdm.auto import tqdm @@ -22,6 +22,8 @@ from tensorwaves.interface import ParameterValue + Symbolic = TypeVar("Symbolic", bound=sp.Basic) + _LOGGER = logging.getLogger(__name__) @@ -60,8 +62,9 @@ def create_function( >>> function(data).tolist() [0.0, 2.0, 8.0, 18.0] """ + expression = _substitute_matrix_elements(expression) free_symbols = _get_free_symbols(expression) - sorted_symbols = sorted(free_symbols, key=lambda s: s.name) + sorted_symbols = sorted(free_symbols, key=str) lambdified_function = _lambdify_normal_or_fast( expression=expression, symbols=sorted_symbols, @@ -119,10 +122,11 @@ def create_parametrized_function( # noqa: PLR0913 >>> function(data).tolist() [0.0, 0.0, 0.0, 0.0, 0.0] """ + expression = _substitute_matrix_elements(expression) free_symbols = _get_free_symbols(expression) parameter_set = set(parameters) - parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name) - data_symbols = sorted(free_symbols - parameter_set, key=lambda s: s.name) + parameter_symbols = sorted(free_symbols & parameter_set, key=str) + data_symbols = sorted(free_symbols - parameter_set, key=str) sorted_symbols = tuple(data_symbols + parameter_symbols) # for partial+gradient lambdified_function = _lambdify_normal_or_fast( expression=expression, @@ -139,6 +143,34 @@ def create_parametrized_function( # noqa: PLR0913 ) +def _substitute_matrix_elements(expression: Symbolic) -> Symbolic: + """Substitute elements of matrix symbols with actual symbol objects. + + This is a workaround for the fact that the `~sympy.core.basic.Basic.free_symbols` of + an expression containing `~sympy.matrices.expressions.MatrixSymbol`s does not + contain the individual elements of these matrix symbols, but only the matrix symbols + themselves. + + >>> import sympy as sp + >>> M = sp.MatrixSymbol("M", 2, 2) + >>> expr = M[0, 0] ** 2 + M[1, 1] ** 2 + >>> expr.free_symbols + {M} + >>> new_expr = _substitute_matrix_elements(expr) + >>> sorted(new_expr.free_symbols, key=str) + [M[0, 0], M[1, 1]] + """ + import sympy as sp + + return expression.xreplace({ + element: sp.Symbol(str(element), **element.assumptions0) + for matrix in expression.free_symbols + if isinstance(matrix, sp.MatrixSymbol) + for row in matrix.as_explicit().tolist() + for element in row + }) + + def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]: """Get free symbols in an expression, excluding IndexedBase. @@ -154,7 +186,7 @@ def _get_free_symbols(expression: sp.Basic) -> set[sp.Symbol]: free_symbols: set[sp.Symbol] = expression.free_symbols # type: ignore[assignment] index_bases = { - sp.Symbol(s.base.name, **s.assumptions0) + sp.Symbol(str(s.base), **s.assumptions0) for s in free_symbols if isinstance(s, sp.Indexed) } @@ -324,7 +356,7 @@ def fast_lambdify( # noqa: PLR0913 top_expression, symbols, backend, use_cse=use_cse, use_jit=use_jit ) - sorted_top_symbols = sorted(sub_expressions, key=lambda s: s.name) + sorted_top_symbols = sorted(sub_expressions, key=str) top_function = lambdify( top_expression, sorted_top_symbols, backend, use_cse=use_cse, use_jit=use_jit ) diff --git a/tests/function/test_sympy.py b/tests/function/test_sympy.py index a777c3f06..84bad8477 100644 --- a/tests/function/test_sympy.py +++ b/tests/function/test_sympy.py @@ -98,6 +98,14 @@ def test_create_function_indexed_symbol(backend: str): assert func.argument_order == ("A[0]", "A[1]") +@pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"]) +def test_create_function_matrix_symbol(backend: str): + M = sp.MatrixSymbol("M", 2, 2) # noqa: N806 + expr = M[0, 0] ** 2 + M[1, 1] ** 2 + func = create_function(expr, backend=backend) + assert func.argument_order == ("M[0, 0]", "M[1, 1]") + + @pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"]) @pytest.mark.parametrize("max_complexity", [0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize("use_cse", [False, True]) @@ -176,7 +184,7 @@ def test_split_expression(): assert expression == top_expr.xreplace(sub_expressions) free_symbols: set[sp.Symbol] = top_expr.free_symbols # type: ignore[assignment] - sub_symbols = sorted(free_symbols, key=lambda s: s.name) + sub_symbols = sorted(free_symbols, key=str) assert len(sub_symbols) == 3 f0, f1, f2 = tuple(sub_symbols) assert f0 is a