From 1757cf944de4d279640d088af806e66da88b2a5b Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 6 Apr 2026 19:07:16 -0700 Subject: [PATCH 01/11] update: add generic convergence --- .../wode/subworkflows/convergence_mixin.py | 218 +++++++++++++----- 1 file changed, 161 insertions(+), 57 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index 000a05de..216adc9f 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -1,3 +1,4 @@ +import re from typing import Any, Dict, List, Optional, Protocol, cast from ..context.providers import PointsGridDataProvider @@ -76,66 +77,32 @@ def _merge_convergence_context(unit_context: Dict[str, Any], convergence_context merged_context["kgrid"] = merged_kgrid_context return merged_context - def add_convergence( + def _build_convergence_units( self, - parameter: str, - parameter_initial: Any, - parameter_increment: Any, - reciprocal_vector_ratios: Optional[List[float]] = None, - result: str = ENERGY_CONVERGENCE_RESULT, + param_name: str, + param_initial_value: str, + param_increment_expr: str, + param_final_value: str, + param_input: List[Dict[str, str]], + result_name: str, + result_unit_flowchart_id: str, + execution_unit_flowchart_id: str, result_initial: Any = 0, condition: Optional[str] = None, operator: str = "<", tolerance: Any = 1e-5, max_occurrences: int = 10, ) -> None: - # Used for type checking correctness host = cast(ConvergenceHost, self) - parameter_name = ConvergenceParameterNameEnum(parameter) - - if result != ENERGY_CONVERGENCE_RESULT: - raise ValueError(f"Unsupported convergence result: {result}") - - unit_for_convergence = self._find_unit_for_convergence(result) - if unit_for_convergence is None: - raise ValueError(f"Subworkflow does not contain a unit with '{result}' as an extracted property.") - - if ( - parameter_name - in ( - ConvergenceParameterNameEnum.N_k_nonuniform, - ConvergenceParameterNameEnum.N_k_nonuniform_2D, - ) - and reciprocal_vector_ratios is None - ): - reciprocal_vector_ratios = PointsGridDataProvider( - context=unit_for_convergence.context - ).get_reciprocal_vector_ratios() - if reciprocal_vector_ratios is None: - raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") - - param = create_convergence_parameter( - name=parameter_name.value, - initial_value=parameter_initial, - increment=parameter_increment, - reciprocal_vector_ratios=reciprocal_vector_ratios, - ) - - merged_context = self._merge_convergence_context( - unit_for_convergence.context, - param.unit_context, - ) - unit_for_convergence.set_context(merged_context) - prev_result = "prev_result" iteration = "iteration" - condition_expression = condition or f"abs(({prev_result}-{result})/{result})" + condition_expression = condition or f"abs(({prev_result}-{result_name})/{result_name})" param_init = Unit( name="init parameter", type="assignment", - operand=param.name, - value=param.initial_value, + operand=param_name, + value=param_initial_value, tags=[CONVERGENCE_PARAMETER_TAG], ) prev_result_init = Unit(name="init result", type="assignment", operand=prev_result, value=result_initial) @@ -143,28 +110,28 @@ def add_convergence( store_result = Unit( name="update result", type="assignment", - input=[{"scope": unit_for_convergence.flowchartId, "name": result}], - operand=result, - value=result, + input=[{"scope": result_unit_flowchart_id, "name": result_name}], + operand=result_name, + value=result_name, tags=[CONVERGENCE_RESULT_TAG], ) store_prev_result = Unit( name="store result", type="assignment", - input=[{"scope": unit_for_convergence.flowchartId, "name": result}], + input=[{"scope": result_unit_flowchart_id, "name": result_name}], operand=prev_result, - value=result, + value=result_name, ) next_iter = Unit(name="update counter", type="assignment", operand=iteration, value=f"{iteration} + 1") next_step = Unit( name="update parameter", type="assignment", - input=param.use_variables_from_unit_context(unit_for_convergence.flowchartId), - operand=param.name, - value=param.increment, - next=unit_for_convergence.flowchartId, + input=param_input, + operand=param_name, + value=param_increment_expr, + next=execution_unit_flowchart_id, ) - exit_unit = Unit(name="exit", type="assignment", operand=param.name, value=param.final_value) + exit_unit = Unit(name="exit", type="assignment", operand=param_name, value=param_final_value) condition_unit = Unit( name="check convergence", type="condition", @@ -186,4 +153,141 @@ def add_convergence( host.add_unit(next_step) host.add_unit(exit_unit) - next_step.next = unit_for_convergence.flowchartId + next_step.next = execution_unit_flowchart_id + + def add_convergence( + self, + parameter: str, + parameter_initial: Any, + parameter_increment: Any, + reciprocal_vector_ratios: Optional[List[float]] = None, + result: str = ENERGY_CONVERGENCE_RESULT, + result_initial: Any = 0, + condition: Optional[str] = None, + operator: str = "<", + tolerance: Any = 1e-5, + max_occurrences: int = 10, + ) -> None: + parameter_name = ConvergenceParameterNameEnum(parameter) + + if result != ENERGY_CONVERGENCE_RESULT: + raise ValueError(f"Unsupported convergence result: {result}") + + unit_for_convergence = self._find_unit_for_convergence(result) + if unit_for_convergence is None: + raise ValueError(f"Subworkflow does not contain a unit with '{result}' as an extracted property.") + + if ( + parameter_name + in ( + ConvergenceParameterNameEnum.N_k_nonuniform, + ConvergenceParameterNameEnum.N_k_nonuniform_2D, + ) + and reciprocal_vector_ratios is None + ): + reciprocal_vector_ratios = PointsGridDataProvider( + context=unit_for_convergence.context + ).get_reciprocal_vector_ratios() + if reciprocal_vector_ratios is None: + raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") + + param = create_convergence_parameter( + name=parameter_name.value, + initial_value=parameter_initial, + increment=parameter_increment, + reciprocal_vector_ratios=reciprocal_vector_ratios, + ) + + merged_context = self._merge_convergence_context( + unit_for_convergence.context, + param.unit_context, + ) + unit_for_convergence.set_context(merged_context) + + self._build_convergence_units( + param_name=param.name, + param_initial_value=param.initial_value, + param_increment_expr=param.increment, + param_final_value=param.final_value, + param_input=param.use_variables_from_unit_context(unit_for_convergence.flowchartId), + result_name=result, + result_unit_flowchart_id=unit_for_convergence.flowchartId, + execution_unit_flowchart_id=unit_for_convergence.flowchartId, + result_initial=result_initial, + condition=condition, + operator=operator, + tolerance=tolerance, + max_occurrences=max_occurrences, + ) + + def add_template_param_convergence( + self, + param_name: str, + param_initial: Any, + param_increment: Any, + result_name: str, + result_initial: Any = 0, + condition: Optional[str] = None, + operator: str = "<", + tolerance: Any = 1e-3, + max_occurrences: int = 10, + ) -> None: + """ + Add a convergence loop for an arbitrary template parameter. + + Uses regex substitution to inject the parameter as a runtime scope variable into the + execution unit's input template, then delegates to _build_convergence_units. + + Args: + param_name: Parameter name as it appears in the input template (e.g. "degauss"). + param_initial: Starting value of the parameter. + param_increment: Scalar step added each iteration. + result_name: Name of the result property to monitor (must exist in a unit's results). + result_initial: Seed value for the result before the first iteration. + condition: Optional custom convergence condition expression. + operator: Comparison operator for the convergence condition (default "<"). + tolerance: Convergence threshold. + max_occurrences: Maximum number of loop iterations. + """ + host = cast(ConvergenceHost, self) + execution_unit = next((u for u in host.units if u.type == "execution"), None) + if execution_unit is None: + raise ValueError("No execution unit found in subworkflow.") + + result_unit = self._find_unit_for_convergence(result_name) + if result_unit is None: + raise ValueError(f"No unit with result '{result_name}' found in subworkflow.") + + self._inject_template_variable(execution_unit, param_name) + execution_unit.set_context({**execution_unit.context, param_name: param_initial}) + + self._build_convergence_units( + param_name=param_name, + param_initial_value=param_initial, + param_increment_expr=f"{param_name} + {param_increment}", + param_final_value=param_name, + param_input=[], + result_name=result_name, + result_unit_flowchart_id=result_unit.flowchartId, + execution_unit_flowchart_id=execution_unit.flowchartId, + result_initial=result_initial, + condition=condition, + operator=operator, + tolerance=tolerance, + max_occurrences=max_occurrences, + ) + + @staticmethod + def _inject_template_variable(unit, param_name: str) -> None: + """Replace a value assignment for param_name in the unit's input template. + + Auto-generates a regex matching either a bare numeric value or an existing Jinja2 + expression (e.g. `ecutwfc = {{ cutoffs.wavefunction }}`), replacing it with a runtime + scope variable wrapped in {%raw%}...{%endraw%} so Jinja2 pre-rendering leaves it intact. + """ + numeric = r"[\d.e+\-]+" + jinja_var = r"\{\{[^}]+\}\}" + pattern = rf"{param_name}\s*=\s*(?:{numeric}|{jinja_var})" + replacement = f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" + for input_item in unit.input: + input_item["content"] = re.sub(pattern, replacement, input_item["content"]) From 7a52ce452645c86658958c95af27a73704df807f Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 6 Apr 2026 19:17:44 -0700 Subject: [PATCH 02/11] update: add test --- tests/py/test_convergence.py | 72 +++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index 900972b0..58f3ef48 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -1,7 +1,9 @@ +import pytest +from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum from mat3ra.made.lattice import Lattice from mat3ra.standata.workflows import WorkflowStandata + from mat3ra.wode import Workflow -from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum def _build_total_energy_subworkflow(): @@ -167,3 +169,71 @@ def test_convergence_series_uses_scope_track(): {"x": 1, "param": 1, "y": -10.0}, {"x": 2, "param": 2, "y": -10.5}, ] + + +TEMPLATE_PARAM_TEST_CASES = [ + pytest.param( + "degauss", + 0.001, + 0.002, + "total_energy", + "degauss = 0.005", + id="degauss_bare_numeric", + ), + pytest.param( + "ecutwfc", + 20, + 10, + "total_energy", + "ecutwfc = {{ cutoffs.wavefunction }}", + id="ecutwfc_jinja_expression", + ), +] + + +@pytest.mark.parametrize( + "param_name,param_initial,param_increment,result_name,original_pattern", TEMPLATE_PARAM_TEST_CASES +) +def test_add_template_param_convergence(param_name, param_initial, param_increment, result_name, original_pattern): + subworkflow = _build_total_energy_subworkflow() + + subworkflow.add_template_param_convergence( + param_name=param_name, + param_initial=param_initial, + param_increment=param_increment, + result_name=result_name, + tolerance=1e-3, + ) + + assert [unit.name for unit in subworkflow.units] == [ + "init parameter", + "init result", + "init counter", + "pw_scf", + "update result", + "check convergence", + "store result", + "update counter", + "update parameter", + "exit", + ] + + pw_scf = subworkflow.get_unit_by_name(name="pw_scf") + assert pw_scf.context[param_name] == param_initial + template_content = pw_scf.input[0]["content"] + assert f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" in template_content + assert original_pattern not in template_content + + assert subworkflow.convergence_param == param_name + assert subworkflow.convergence_result == result_name + assert subworkflow.has_convergence is True + + update_parameter = subworkflow.get_unit_by_name(name="update parameter") + assert update_parameter.operand == param_name + assert update_parameter.value == f"{param_name} + {param_increment}" + assert update_parameter.input == [] + assert update_parameter.next == pw_scf.flowchartId + + exit_unit = subworkflow.get_unit_by_name(name="exit") + assert exit_unit.operand == param_name + assert exit_unit.value == param_name From a0940f2a3d74e6a3fdd043902d6ff5fbdb36703d Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 6 Apr 2026 19:32:45 -0700 Subject: [PATCH 03/11] update: multi unit --- .../wode/subworkflows/convergence_mixin.py | 11 ++++---- tests/py/test_convergence.py | 28 +++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index 216adc9f..999d5f48 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -250,16 +250,17 @@ def add_template_param_convergence( max_occurrences: Maximum number of loop iterations. """ host = cast(ConvergenceHost, self) - execution_unit = next((u for u in host.units if u.type == "execution"), None) - if execution_unit is None: + execution_units = [u for u in host.units if u.type == "execution"] + if not execution_units: raise ValueError("No execution unit found in subworkflow.") result_unit = self._find_unit_for_convergence(result_name) if result_unit is None: raise ValueError(f"No unit with result '{result_name}' found in subworkflow.") - self._inject_template_variable(execution_unit, param_name) - execution_unit.set_context({**execution_unit.context, param_name: param_initial}) + for execution_unit in execution_units: + self._inject_template_variable(execution_unit, param_name) + execution_unit.set_context({**execution_unit.context, param_name: param_initial}) self._build_convergence_units( param_name=param_name, @@ -269,7 +270,7 @@ def add_template_param_convergence( param_input=[], result_name=result_name, result_unit_flowchart_id=result_unit.flowchartId, - execution_unit_flowchart_id=execution_unit.flowchartId, + execution_unit_flowchart_id=execution_units[0].flowchartId, result_initial=result_initial, condition=condition, operator=operator, diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index 58f3ef48..ce35014b 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -237,3 +237,31 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme exit_unit = subworkflow.get_unit_by_name(name="exit") assert exit_unit.operand == param_name assert exit_unit.value == param_name + + +def test_add_template_param_convergence_multi_unit(): + workflow_config = WorkflowStandata.filter_by_application("espresso").get_by_name_first_match("band_structure.json") + workflow = Workflow.create(workflow_config) + subworkflow = workflow.subworkflows[0] + + subworkflow.add_template_param_convergence( + param_name="ecutwfc", + param_initial=20, + param_increment=10, + result_name="total_energy", + ) + + execution_units = [u for u in subworkflow.units if u.type == "execution"] + assert len(execution_units) == 3 + + pw_scf = subworkflow.get_unit_by_name("pw_scf") + pw_bands = subworkflow.get_unit_by_name("pw_bands") + + for unit in [pw_scf, pw_bands]: + assert unit.context["ecutwfc"] == 20 + template_content = unit.input[0]["content"] + assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content + assert "ecutwfc = {{ cutoffs.wavefunction }}" not in template_content + + assert subworkflow.convergence_param == "ecutwfc" + assert subworkflow.convergence_result == "total_energy" From ba005c2e0e7aebb05c14d8268cc4f66ad655e748 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 7 Apr 2026 11:26:36 -0700 Subject: [PATCH 04/11] update: input item --- src/py/mat3ra/wode/units/execution.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 4d0f2688..357fe43d 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional -from mat3ra.ade import Application, Executable, Flavor +from mat3ra.ade import Application, Executable, Flavor, Template +from mat3ra.code.entity import InMemoryEntitySnakeCase from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchemaBase from mat3ra.utils import ( calculate_hash_from_object, @@ -13,12 +14,19 @@ from .unit import Unit +# TODO: use from ESSE when epic/SOF-7756 merged +class ExecutionUnitInputItem(InMemoryEntitySnakeCase): + template: Template = Field(default_factory=Template) + rendered: str + isManuallyChanged: bool = False + + class ExecutionUnit(Unit, ExecutionUnitSchemaBase): type: Literal["execution"] = "execution" executable: Executable = None flavor: Flavor = None application: Application = None - input: List = Field(default_factory=list) + input: List[ExecutionUnitInputItem] = Field(default_factory=List[ExecutionUnitInputItem]) def get_hash_object(self) -> Dict[str, Any]: app = self.application.to_dict() if self.application else {} From 23406e1f3265eed989555dfc32942cc5bd31e897 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 7 Apr 2026 14:02:22 -0700 Subject: [PATCH 05/11] update: use input item --- pyproject.toml | 4 +-- src/py/mat3ra/wode/units/execution.py | 44 +++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index daa5a951..59fc6ca0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,10 @@ classifiers = [ dependencies = [ "numpy", "mat3ra-code", - "mat3ra-utils", + "mat3ra-utils @ git+https://github.com/Exabyte-io/utils.git@d7f6c5cc020283f6bd8eb6434d4632bf57561bcf", "mat3ra-esse", "mat3ra-mode", - "mat3ra-ade", + "mat3ra-ade @ git+https://github.com/Exabyte-io/ade.git@d66b4cb1363d292fc3918767eb0eeeb538d3dabd", "mat3ra-made", "mat3ra-standata" ] diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 357fe43d..196dbcce 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -9,7 +9,7 @@ remove_empty_lines_from_string, remove_timestampable_keys, ) -from pydantic import Field +from pydantic import Field, model_validator from .unit import Unit @@ -20,6 +20,34 @@ class ExecutionUnitInputItem(InMemoryEntitySnakeCase): rendered: str isManuallyChanged: bool = False + @model_validator(mode="before") + @classmethod + def handle_legacy_flat_structure(cls, data: Any) -> Any: + if isinstance(data, dict) and "name" in data and "content" in data and "template" not in data: + template_fields = { + k: v + for k, v in data.items() + if k + in [ + "_id", + "slug", + "systemName", + "schemaVersion", + "name", + "applicationName", + "applicationVersion", + "executableName", + "contextProviders", + "content", + ] + } + return { + "template": template_fields, + "rendered": data.get("rendered", ""), + "isManuallyChanged": data.get("isManuallyChanged", False), + } + return data + class ExecutionUnit(Unit, ExecutionUnitSchemaBase): type: Literal["execution"] = "execution" @@ -28,17 +56,21 @@ class ExecutionUnit(Unit, ExecutionUnitSchemaBase): application: Application = None input: List[ExecutionUnitInputItem] = Field(default_factory=List[ExecutionUnitInputItem]) + def replace_in_input_content(self, pattern: str, replacement: str) -> None: + for input_item in self.input: + input_item.template.replace_in_content(pattern, replacement) + + def replace_variable_value_in_inputs(self, variable_name: str, new_value: str) -> None: + for input_item in self.input: + input_item.template.replace_variable_value(variable_name, new_value) + def get_hash_object(self) -> Dict[str, Any]: app = self.application.to_dict() if self.application else {} exe = self.executable.to_dict() if self.executable else {} flv = self.flavor.to_dict() if self.flavor else {} input_items = self.input if isinstance(self.input, list) else [] input_hash = calculate_hash_from_object( - [ - remove_empty_lines_from_string(remove_comments_from_source_code(i.get("content", ""))) - for i in input_items - if isinstance(i, dict) - ] + [remove_empty_lines_from_string(remove_comments_from_source_code(i.template.content)) for i in input_items] ) return { **super().get_hash_object(), From 0e8d86f4c0bc162e7fe0efdaf63dd3cf4a139785 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 7 Apr 2026 14:03:03 -0700 Subject: [PATCH 06/11] update: adjust convergence mixin --- .../wode/subworkflows/convergence_mixin.py | 24 ++++--------------- tests/py/test_convergence.py | 6 +++-- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index 999d5f48..08336ebb 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -1,10 +1,10 @@ -import re from typing import Any, Dict, List, Optional, Protocol, cast +from mat3ra.ade import Template +from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum + from ..context.providers import PointsGridDataProvider from ..units import Unit - -from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum from .convergence.factory import create_convergence_parameter CONVERGENCE_PARAMETER_TAG = "hasConvergenceParam" @@ -258,8 +258,9 @@ def add_template_param_convergence( if result_unit is None: raise ValueError(f"No unit with result '{result_name}' found in subworkflow.") + scope_reference = Template.make_raw_scope_reference(param_name) for execution_unit in execution_units: - self._inject_template_variable(execution_unit, param_name) + execution_unit.replace_variable_value_in_inputs(param_name, scope_reference) execution_unit.set_context({**execution_unit.context, param_name: param_initial}) self._build_convergence_units( @@ -277,18 +278,3 @@ def add_template_param_convergence( tolerance=tolerance, max_occurrences=max_occurrences, ) - - @staticmethod - def _inject_template_variable(unit, param_name: str) -> None: - """Replace a value assignment for param_name in the unit's input template. - - Auto-generates a regex matching either a bare numeric value or an existing Jinja2 - expression (e.g. `ecutwfc = {{ cutoffs.wavefunction }}`), replacing it with a runtime - scope variable wrapped in {%raw%}...{%endraw%} so Jinja2 pre-rendering leaves it intact. - """ - numeric = r"[\d.e+\-]+" - jinja_var = r"\{\{[^}]+\}\}" - pattern = rf"{param_name}\s*=\s*(?:{numeric}|{jinja_var})" - replacement = f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" - for input_item in unit.input: - input_item["content"] = re.sub(pattern, replacement, input_item["content"]) diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index ce35014b..d3b770fc 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -220,7 +220,8 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme pw_scf = subworkflow.get_unit_by_name(name="pw_scf") assert pw_scf.context[param_name] == param_initial - template_content = pw_scf.input[0]["content"] + input_item = pw_scf.input[0] + template_content = input_item.template.content assert f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" in template_content assert original_pattern not in template_content @@ -259,7 +260,8 @@ def test_add_template_param_convergence_multi_unit(): for unit in [pw_scf, pw_bands]: assert unit.context["ecutwfc"] == 20 - template_content = unit.input[0]["content"] + input_item = unit.input[0] + template_content = input_item.template.content assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content assert "ecutwfc = {{ cutoffs.wavefunction }}" not in template_content From c75cbe0ffae4eefa2b0877c6fea1d324afb78370 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 7 Apr 2026 14:08:51 -0700 Subject: [PATCH 07/11] chore: lint --- src/py/mat3ra/wode/subworkflows/convergence_mixin.py | 2 +- src/py/mat3ra/wode/units/execution.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index 08336ebb..b5c7743c 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -3,9 +3,9 @@ from mat3ra.ade import Template from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum +from .convergence.factory import create_convergence_parameter from ..context.providers import PointsGridDataProvider from ..units import Unit -from .convergence.factory import create_convergence_parameter CONVERGENCE_PARAMETER_TAG = "hasConvergenceParam" CONVERGENCE_RESULT_TAG = "hasConvergenceResult" diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 196dbcce..3a1b7d00 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal from mat3ra.ade import Application, Executable, Flavor, Template from mat3ra.code.entity import InMemoryEntitySnakeCase From 859d81bcba428b67190752379d773b744f447ec5 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 7 Apr 2026 16:37:44 -0700 Subject: [PATCH 08/11] update: add test for inputs --- tests/py/units/test_execution_unit.py | 74 +++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/py/units/test_execution_unit.py diff --git a/tests/py/units/test_execution_unit.py b/tests/py/units/test_execution_unit.py new file mode 100644 index 00000000..a3b00b9d --- /dev/null +++ b/tests/py/units/test_execution_unit.py @@ -0,0 +1,74 @@ +import pytest +from mat3ra.ade import Template +from mat3ra.wode.units.execution import ExecutionUnit, ExecutionUnitInputItem + +CONTENT_DEGAUSS = "degauss = 0.005\n" +CONTENT_ECUTWFC_JINJA = "ecutwfc = {{ cutoffs.wavefunction }}\n" +CONTENT_MULTI = "degauss = 0.005\necutwfc = 40\n" + +RAW_SCOPE_DEGAUSS = "{% raw %}{{ degauss }}{% endraw %}" +RAW_SCOPE_ECUTWFC = "{% raw %}{{ ecutwfc }}{% endraw %}" + + +def _make_unit(*contents: str) -> ExecutionUnit: + inputs = [ + ExecutionUnitInputItem(template=Template(name="pw.in", content=c), rendered=c) + for c in contents + ] + return ExecutionUnit(name="pw_scf", input=inputs) + + +REPLACE_IN_INPUT_CONTENT_CASES = [ + pytest.param( + [CONTENT_DEGAUSS], + r"degauss\s*=\s*[\d.e+\-]+", + f"degauss = {RAW_SCOPE_DEGAUSS}", + [f"degauss = {RAW_SCOPE_DEGAUSS}\n"], + id="single_input_numeric", + ), + pytest.param( + [CONTENT_DEGAUSS, CONTENT_ECUTWFC_JINJA], + r"degauss\s*=\s*[\d.e+\-]+", + f"degauss = {RAW_SCOPE_DEGAUSS}", + [f"degauss = {RAW_SCOPE_DEGAUSS}\n", CONTENT_ECUTWFC_JINJA], + id="multiple_inputs_only_first_matches", + ), +] + +REPLACE_VARIABLE_VALUE_IN_INPUTS_CASES = [ + pytest.param( + [CONTENT_DEGAUSS], + "degauss", + RAW_SCOPE_DEGAUSS, + [f"degauss = {RAW_SCOPE_DEGAUSS}\n"], + id="bare_numeric", + ), + pytest.param( + [CONTENT_ECUTWFC_JINJA], + "ecutwfc", + RAW_SCOPE_ECUTWFC, + [f"ecutwfc = {RAW_SCOPE_ECUTWFC}\n"], + id="jinja_expression", + ), + pytest.param( + [CONTENT_DEGAUSS, CONTENT_ECUTWFC_JINJA], + "degauss", + RAW_SCOPE_DEGAUSS, + [f"degauss = {RAW_SCOPE_DEGAUSS}\n", CONTENT_ECUTWFC_JINJA], + id="multi_input_partial_match", + ), +] + + +@pytest.mark.parametrize("contents,pattern,replacement,expected_contents", REPLACE_IN_INPUT_CONTENT_CASES) +def test_replace_in_input_content(contents, pattern, replacement, expected_contents): + unit = _make_unit(*contents) + unit.replace_in_input_content(pattern, replacement) + assert [item.template.content for item in unit.input] == expected_contents + + +@pytest.mark.parametrize("contents,variable_name,new_value,expected_contents", REPLACE_VARIABLE_VALUE_IN_INPUTS_CASES) +def test_replace_variable_value_in_inputs(contents, variable_name, new_value, expected_contents): + unit = _make_unit(*contents) + unit.replace_variable_value_in_inputs(variable_name, new_value) + assert [item.template.content for item in unit.input] == expected_contents From 6bafefccbc9a83a1c73ed5216f8ca63b1dd5e702 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 7 Apr 2026 16:47:01 -0700 Subject: [PATCH 09/11] chore: renanme param -> parameter --- pyproject.toml | 2 +- .../wode/subworkflows/convergence_mixin.py | 74 +++++++++---------- tests/py/test_convergence.py | 26 +++---- 3 files changed, 51 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 59fc6ca0..520c5169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ dependencies = [ "numpy", "mat3ra-code", - "mat3ra-utils @ git+https://github.com/Exabyte-io/utils.git@d7f6c5cc020283f6bd8eb6434d4632bf57561bcf", + "mat3ra-utils @ git+https://github.com/Exabyte-io/utils.git@c256d6a774a45ced9cb4a051bd5126f1ec2df520", "mat3ra-esse", "mat3ra-mode", "mat3ra-ade @ git+https://github.com/Exabyte-io/ade.git@d66b4cb1363d292fc3918767eb0eeeb538d3dabd", diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index b5c7743c..f604c0a1 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -30,7 +30,7 @@ def scalar_results(self) -> List[str]: return [ENERGY_CONVERGENCE_RESULT] @property - def convergence_param(self) -> Optional[str]: + def convergence_parameter(self) -> Optional[str]: unit = cast(ConvergenceHost, self).find_unit_with_tag(CONVERGENCE_PARAMETER_TAG) return getattr(unit, "operand", None) if unit else None @@ -41,7 +41,7 @@ def convergence_result(self) -> Optional[str]: @property def has_convergence(self) -> bool: - return bool(self.convergence_param and self.convergence_result) + return bool(self.convergence_parameter and self.convergence_result) def convergence_series(self, scope_track: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]: if not self.has_convergence or not scope_track: @@ -51,12 +51,12 @@ def convergence_series(self, scope_track: Optional[List[Dict[str, Any]]]) -> Lis series: List[Dict[str, Any]] = [] for scope_item in scope_track: global_scope = ((scope_item or {}).get("scope") or {}).get("global") or {} - param = global_scope.get(self.convergence_param) + parameter = global_scope.get(self.convergence_parameter) result = global_scope.get(self.convergence_result) is_new_result = result is not None and result != last_result last_result = result if is_new_result: - series.append({"x": len(series) + 1, "param": param, "y": result}) + series.append({"x": len(series) + 1, "parameter": parameter, "y": result}) return series def _find_unit_for_convergence(self, result: str): @@ -79,11 +79,11 @@ def _merge_convergence_context(unit_context: Dict[str, Any], convergence_context def _build_convergence_units( self, - param_name: str, - param_initial_value: str, - param_increment_expr: str, - param_final_value: str, - param_input: List[Dict[str, str]], + parameter_name: str, + parameter_initial_value: str, + parameter_increment_expr: str, + parameter_final_value: str, + parameter_input: List[Dict[str, str]], result_name: str, result_unit_flowchart_id: str, execution_unit_flowchart_id: str, @@ -101,8 +101,8 @@ def _build_convergence_units( param_init = Unit( name="init parameter", type="assignment", - operand=param_name, - value=param_initial_value, + operand=parameter_name, + value=parameter_initial_value, tags=[CONVERGENCE_PARAMETER_TAG], ) prev_result_init = Unit(name="init result", type="assignment", operand=prev_result, value=result_initial) @@ -126,12 +126,12 @@ def _build_convergence_units( next_step = Unit( name="update parameter", type="assignment", - input=param_input, - operand=param_name, - value=param_increment_expr, + input=parameter_input, + operand=parameter_name, + value=parameter_increment_expr, next=execution_unit_flowchart_id, ) - exit_unit = Unit(name="exit", type="assignment", operand=param_name, value=param_final_value) + exit_unit = Unit(name="exit", type="assignment", operand=parameter_name, value=parameter_final_value) condition_unit = Unit( name="check convergence", type="condition", @@ -191,7 +191,7 @@ def add_convergence( if reciprocal_vector_ratios is None: raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") - param = create_convergence_parameter( + parameter = create_convergence_parameter( name=parameter_name.value, initial_value=parameter_initial, increment=parameter_increment, @@ -200,16 +200,16 @@ def add_convergence( merged_context = self._merge_convergence_context( unit_for_convergence.context, - param.unit_context, + parameter.unit_context, ) unit_for_convergence.set_context(merged_context) self._build_convergence_units( - param_name=param.name, - param_initial_value=param.initial_value, - param_increment_expr=param.increment, - param_final_value=param.final_value, - param_input=param.use_variables_from_unit_context(unit_for_convergence.flowchartId), + parameter_name=parameter.name, + parameter_initial_value=parameter.initial_value, + parameter_increment_expr=parameter.increment, + parameter_final_value=parameter.final_value, + parameter_input=parameter.use_variables_from_unit_context(unit_for_convergence.flowchartId), result_name=result, result_unit_flowchart_id=unit_for_convergence.flowchartId, execution_unit_flowchart_id=unit_for_convergence.flowchartId, @@ -220,11 +220,11 @@ def add_convergence( max_occurrences=max_occurrences, ) - def add_template_param_convergence( + def add_template_parameter_convergence( self, - param_name: str, - param_initial: Any, - param_increment: Any, + parameter_name: str, + parameter_initial: Any, + parameter_increment: Any, result_name: str, result_initial: Any = 0, condition: Optional[str] = None, @@ -239,9 +239,9 @@ def add_template_param_convergence( execution unit's input template, then delegates to _build_convergence_units. Args: - param_name: Parameter name as it appears in the input template (e.g. "degauss"). - param_initial: Starting value of the parameter. - param_increment: Scalar step added each iteration. + parameter_name: Parameter name as it appears in the input template (e.g. "degauss"). + parameter_initial: Starting value of the parameter. + parameter_increment: Scalar step added each iteration. result_name: Name of the result property to monitor (must exist in a unit's results). result_initial: Seed value for the result before the first iteration. condition: Optional custom convergence condition expression. @@ -258,17 +258,17 @@ def add_template_param_convergence( if result_unit is None: raise ValueError(f"No unit with result '{result_name}' found in subworkflow.") - scope_reference = Template.make_raw_scope_reference(param_name) + scope_reference = Template.make_raw_scope_reference(parameter_name) for execution_unit in execution_units: - execution_unit.replace_variable_value_in_inputs(param_name, scope_reference) - execution_unit.set_context({**execution_unit.context, param_name: param_initial}) + execution_unit.replace_variable_value_in_inputs(parameter_name, scope_reference) + execution_unit.set_context({**execution_unit.context, parameter_name: parameter_initial}) self._build_convergence_units( - param_name=param_name, - param_initial_value=param_initial, - param_increment_expr=f"{param_name} + {param_increment}", - param_final_value=param_name, - param_input=[], + parameter_name=parameter_name, + parameter_initial_value=parameter_initial, + parameter_increment_expr=f"{parameter_name} + {parameter_increment}", + parameter_final_value=parameter_name, + parameter_input=[], result_name=result_name, result_unit_flowchart_id=result_unit.flowchartId, execution_unit_flowchart_id=execution_units[0].flowchartId, diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index d3b770fc..3c8625dc 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -46,7 +46,7 @@ def test_add_uniform_energy_convergence(): assert pw_scf.context["isKgridEdited"] is True assert pw_scf.context["isUsingJinjaVariables"] is True - assert subworkflow.convergence_param == ConvergenceParameterNameEnum.N_k.value + assert subworkflow.convergence_parameter == ConvergenceParameterNameEnum.N_k.value assert subworkflow.convergence_result == "total_energy" assert subworkflow.has_convergence is True @@ -166,8 +166,8 @@ def test_convergence_series_uses_scope_track(): ] assert subworkflow.convergence_series(scope_track) == [ - {"x": 1, "param": 1, "y": -10.0}, - {"x": 2, "param": 2, "y": -10.5}, + {"x": 1, "parameter": 1, "y": -10.0}, + {"x": 2, "parameter": 2, "y": -10.5}, ] @@ -197,10 +197,10 @@ def test_convergence_series_uses_scope_track(): def test_add_template_param_convergence(param_name, param_initial, param_increment, result_name, original_pattern): subworkflow = _build_total_energy_subworkflow() - subworkflow.add_template_param_convergence( - param_name=param_name, - param_initial=param_initial, - param_increment=param_increment, + subworkflow.add_template_parameter_convergence( + parameter_name=param_name, + parameter_initial=param_initial, + parameter_increment=param_increment, result_name=result_name, tolerance=1e-3, ) @@ -225,7 +225,7 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme assert f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" in template_content assert original_pattern not in template_content - assert subworkflow.convergence_param == param_name + assert subworkflow.convergence_parameter == param_name assert subworkflow.convergence_result == result_name assert subworkflow.has_convergence is True @@ -245,10 +245,10 @@ def test_add_template_param_convergence_multi_unit(): workflow = Workflow.create(workflow_config) subworkflow = workflow.subworkflows[0] - subworkflow.add_template_param_convergence( - param_name="ecutwfc", - param_initial=20, - param_increment=10, + subworkflow.add_template_parameter_convergence( + parameter_name="ecutwfc", + parameter_initial=20, + parameter_increment=10, result_name="total_energy", ) @@ -265,5 +265,5 @@ def test_add_template_param_convergence_multi_unit(): assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content assert "ecutwfc = {{ cutoffs.wavefunction }}" not in template_content - assert subworkflow.convergence_param == "ecutwfc" + assert subworkflow.convergence_parameter == "ecutwfc" assert subworkflow.convergence_result == "total_energy" From e03be6ed632d3e3dc8a888f1de79664e6b5e63c5 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 7 Apr 2026 17:20:32 -0700 Subject: [PATCH 10/11] chore: use existing approach --- pyproject.toml | 4 +- .../wode/subworkflows/convergence_mixin.py | 2 +- tests/py/units/test_execution_unit.py | 79 ++++++++----------- 3 files changed, 36 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 520c5169..c39a141b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,10 @@ classifiers = [ dependencies = [ "numpy", "mat3ra-code", - "mat3ra-utils @ git+https://github.com/Exabyte-io/utils.git@c256d6a774a45ced9cb4a051bd5126f1ec2df520", + "mat3ra-utils @ git+https://github.com/Exabyte-io/utils.git@5f17ac3e7ab242f2c387b5691c8d8cacc6d59f3f", "mat3ra-esse", "mat3ra-mode", - "mat3ra-ade @ git+https://github.com/Exabyte-io/ade.git@d66b4cb1363d292fc3918767eb0eeeb538d3dabd", + "mat3ra-ade @ git+https://github.com/Exabyte-io/ade.git@9ef73f59d2099cda4426011ed1f31891ee19b68c", "mat3ra-made", "mat3ra-standata" ] diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index f604c0a1..b1e923b7 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -258,7 +258,7 @@ def add_template_parameter_convergence( if result_unit is None: raise ValueError(f"No unit with result '{result_name}' found in subworkflow.") - scope_reference = Template.make_raw_scope_reference(parameter_name) + scope_reference = Template.format_as_scope_reference(parameter_name) for execution_unit in execution_units: execution_unit.replace_variable_value_in_inputs(parameter_name, scope_reference) execution_unit.set_context({**execution_unit.context, parameter_name: parameter_initial}) diff --git a/tests/py/units/test_execution_unit.py b/tests/py/units/test_execution_unit.py index a3b00b9d..99210af5 100644 --- a/tests/py/units/test_execution_unit.py +++ b/tests/py/units/test_execution_unit.py @@ -2,13 +2,19 @@ from mat3ra.ade import Template from mat3ra.wode.units.execution import ExecutionUnit, ExecutionUnitInputItem -CONTENT_DEGAUSS = "degauss = 0.005\n" +CONTENT_DEGAUSS_NUMERIC = "degauss = 0.005\n" CONTENT_ECUTWFC_JINJA = "ecutwfc = {{ cutoffs.wavefunction }}\n" -CONTENT_MULTI = "degauss = 0.005\necutwfc = 40\n" + +PATTERN_DEGAUSS_NUMERIC = r"degauss\s*=\s*[\d.e+\-]+" RAW_SCOPE_DEGAUSS = "{% raw %}{{ degauss }}{% endraw %}" RAW_SCOPE_ECUTWFC = "{% raw %}{{ ecutwfc }}{% endraw %}" +REPLACEMENT_DEGAUSS_RAW = f"degauss = {RAW_SCOPE_DEGAUSS}" + +EXPECTED_DEGAUSS_REPLACED = f"degauss = {RAW_SCOPE_DEGAUSS}\n" +EXPECTED_ECUTWFC_REPLACED = f"ecutwfc = {RAW_SCOPE_ECUTWFC}\n" + def _make_unit(*contents: str) -> ExecutionUnit: inputs = [ @@ -18,56 +24,37 @@ def _make_unit(*contents: str) -> ExecutionUnit: return ExecutionUnit(name="pw_scf", input=inputs) -REPLACE_IN_INPUT_CONTENT_CASES = [ - pytest.param( - [CONTENT_DEGAUSS], - r"degauss\s*=\s*[\d.e+\-]+", - f"degauss = {RAW_SCOPE_DEGAUSS}", - [f"degauss = {RAW_SCOPE_DEGAUSS}\n"], - id="single_input_numeric", - ), - pytest.param( - [CONTENT_DEGAUSS, CONTENT_ECUTWFC_JINJA], - r"degauss\s*=\s*[\d.e+\-]+", - f"degauss = {RAW_SCOPE_DEGAUSS}", - [f"degauss = {RAW_SCOPE_DEGAUSS}\n", CONTENT_ECUTWFC_JINJA], - id="multiple_inputs_only_first_matches", - ), -] - -REPLACE_VARIABLE_VALUE_IN_INPUTS_CASES = [ - pytest.param( - [CONTENT_DEGAUSS], - "degauss", - RAW_SCOPE_DEGAUSS, - [f"degauss = {RAW_SCOPE_DEGAUSS}\n"], - id="bare_numeric", - ), - pytest.param( - [CONTENT_ECUTWFC_JINJA], - "ecutwfc", - RAW_SCOPE_ECUTWFC, - [f"ecutwfc = {RAW_SCOPE_ECUTWFC}\n"], - id="jinja_expression", - ), - pytest.param( - [CONTENT_DEGAUSS, CONTENT_ECUTWFC_JINJA], - "degauss", - RAW_SCOPE_DEGAUSS, - [f"degauss = {RAW_SCOPE_DEGAUSS}\n", CONTENT_ECUTWFC_JINJA], - id="multi_input_partial_match", - ), -] - - -@pytest.mark.parametrize("contents,pattern,replacement,expected_contents", REPLACE_IN_INPUT_CONTENT_CASES) +@pytest.mark.parametrize( + "contents,pattern,replacement,expected_contents", + [ + ([CONTENT_DEGAUSS_NUMERIC], PATTERN_DEGAUSS_NUMERIC, REPLACEMENT_DEGAUSS_RAW, [EXPECTED_DEGAUSS_REPLACED]), + ( + [CONTENT_DEGAUSS_NUMERIC, CONTENT_ECUTWFC_JINJA], + PATTERN_DEGAUSS_NUMERIC, + REPLACEMENT_DEGAUSS_RAW, + [EXPECTED_DEGAUSS_REPLACED, CONTENT_ECUTWFC_JINJA], + ), + ], +) def test_replace_in_input_content(contents, pattern, replacement, expected_contents): unit = _make_unit(*contents) unit.replace_in_input_content(pattern, replacement) assert [item.template.content for item in unit.input] == expected_contents -@pytest.mark.parametrize("contents,variable_name,new_value,expected_contents", REPLACE_VARIABLE_VALUE_IN_INPUTS_CASES) +@pytest.mark.parametrize( + "contents,variable_name,new_value,expected_contents", + [ + ([CONTENT_DEGAUSS_NUMERIC], "degauss", RAW_SCOPE_DEGAUSS, [EXPECTED_DEGAUSS_REPLACED]), + ([CONTENT_ECUTWFC_JINJA], "ecutwfc", RAW_SCOPE_ECUTWFC, [EXPECTED_ECUTWFC_REPLACED]), + ( + [CONTENT_DEGAUSS_NUMERIC, CONTENT_ECUTWFC_JINJA], + "degauss", + RAW_SCOPE_DEGAUSS, + [EXPECTED_DEGAUSS_REPLACED, CONTENT_ECUTWFC_JINJA], + ), + ], +) def test_replace_variable_value_in_inputs(contents, variable_name, new_value, expected_contents): unit = _make_unit(*contents) unit.replace_variable_value_in_inputs(variable_name, new_value) From 73dc01c4d7ab97f59fa4643c703415c2bccfe10d Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 7 Apr 2026 17:59:33 -0700 Subject: [PATCH 11/11] chore: simplify --- src/py/mat3ra/wode/units/execution.py | 33 +++++++++------------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 3a1b7d00..f9b638ac 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -9,45 +9,34 @@ remove_empty_lines_from_string, remove_timestampable_keys, ) -from pydantic import Field, model_validator +from pydantic import Field, model_serializer, model_validator from .unit import Unit +_ITEM_KEYS = {"rendered", "isManuallyChanged"} + # TODO: use from ESSE when epic/SOF-7756 merged class ExecutionUnitInputItem(InMemoryEntitySnakeCase): template: Template = Field(default_factory=Template) - rendered: str + rendered: str = "" isManuallyChanged: bool = False @model_validator(mode="before") @classmethod - def handle_legacy_flat_structure(cls, data: Any) -> Any: - if isinstance(data, dict) and "name" in data and "content" in data and "template" not in data: - template_fields = { - k: v - for k, v in data.items() - if k - in [ - "_id", - "slug", - "systemName", - "schemaVersion", - "name", - "applicationName", - "applicationVersion", - "executableName", - "contextProviders", - "content", - ] - } + def from_flat(cls, data: Any) -> Any: + if isinstance(data, dict) and "template" not in data: return { - "template": template_fields, + "template": {k: v for k, v in data.items() if k not in _ITEM_KEYS}, "rendered": data.get("rendered", ""), "isManuallyChanged": data.get("isManuallyChanged", False), } return data + @model_serializer(mode="plain") + def to_flat(self) -> Dict[str, Any]: + return {**self.template.to_dict(), "rendered": self.rendered, "isManuallyChanged": self.isManuallyChanged} + class ExecutionUnit(Unit, ExecutionUnitSchemaBase): type: Literal["execution"] = "execution"