diff --git a/pyproject.toml b/pyproject.toml index daa5a951..c39a141b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,10 @@ classifiers = [ dependencies = [ "numpy", "mat3ra-code", - "mat3ra-utils", + "mat3ra-utils @ git+https://github.com/Exabyte-io/utils.git@5f17ac3e7ab242f2c387b5691c8d8cacc6d59f3f", "mat3ra-esse", "mat3ra-mode", - "mat3ra-ade", + "mat3ra-ade @ git+https://github.com/Exabyte-io/ade.git@9ef73f59d2099cda4426011ed1f31891ee19b68c", "mat3ra-made", "mat3ra-standata" ] diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index 000a05de..b1e923b7 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -1,10 +1,11 @@ from typing import Any, Dict, List, Optional, Protocol, cast -from ..context.providers import PointsGridDataProvider -from ..units import Unit - +from mat3ra.ade import Template from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum + from .convergence.factory import create_convergence_parameter +from ..context.providers import PointsGridDataProvider +from ..units import Unit CONVERGENCE_PARAMETER_TAG = "hasConvergenceParam" CONVERGENCE_RESULT_TAG = "hasConvergenceResult" @@ -29,7 +30,7 @@ def scalar_results(self) -> List[str]: return [ENERGY_CONVERGENCE_RESULT] @property - def convergence_param(self) -> Optional[str]: + def convergence_parameter(self) -> Optional[str]: unit = cast(ConvergenceHost, self).find_unit_with_tag(CONVERGENCE_PARAMETER_TAG) return getattr(unit, "operand", None) if unit else None @@ -40,7 +41,7 @@ def convergence_result(self) -> Optional[str]: @property def has_convergence(self) -> bool: - return bool(self.convergence_param and self.convergence_result) + return bool(self.convergence_parameter and self.convergence_result) def convergence_series(self, scope_track: Optional[List[Dict[str, Any]]]) -> List[Dict[str, 
Any]]: if not self.has_convergence or not scope_track: @@ -50,12 +51,12 @@ def convergence_series(self, scope_track: Optional[List[Dict[str, Any]]]) -> Lis series: List[Dict[str, Any]] = [] for scope_item in scope_track: global_scope = ((scope_item or {}).get("scope") or {}).get("global") or {} - param = global_scope.get(self.convergence_param) + parameter = global_scope.get(self.convergence_parameter) result = global_scope.get(self.convergence_result) is_new_result = result is not None and result != last_result last_result = result if is_new_result: - series.append({"x": len(series) + 1, "param": param, "y": result}) + series.append({"x": len(series) + 1, "parameter": parameter, "y": result}) return series def _find_unit_for_convergence(self, result: str): @@ -76,66 +77,32 @@ def _merge_convergence_context(unit_context: Dict[str, Any], convergence_context merged_context["kgrid"] = merged_kgrid_context return merged_context - def add_convergence( + def _build_convergence_units( self, - parameter: str, - parameter_initial: Any, - parameter_increment: Any, - reciprocal_vector_ratios: Optional[List[float]] = None, - result: str = ENERGY_CONVERGENCE_RESULT, + parameter_name: str, + parameter_initial_value: str, + parameter_increment_expr: str, + parameter_final_value: str, + parameter_input: List[Dict[str, str]], + result_name: str, + result_unit_flowchart_id: str, + execution_unit_flowchart_id: str, result_initial: Any = 0, condition: Optional[str] = None, operator: str = "<", tolerance: Any = 1e-5, max_occurrences: int = 10, ) -> None: - # Used for type checking correctness host = cast(ConvergenceHost, self) - parameter_name = ConvergenceParameterNameEnum(parameter) - - if result != ENERGY_CONVERGENCE_RESULT: - raise ValueError(f"Unsupported convergence result: {result}") - - unit_for_convergence = self._find_unit_for_convergence(result) - if unit_for_convergence is None: - raise ValueError(f"Subworkflow does not contain a unit with '{result}' as an extracted 
property.") - - if ( - parameter_name - in ( - ConvergenceParameterNameEnum.N_k_nonuniform, - ConvergenceParameterNameEnum.N_k_nonuniform_2D, - ) - and reciprocal_vector_ratios is None - ): - reciprocal_vector_ratios = PointsGridDataProvider( - context=unit_for_convergence.context - ).get_reciprocal_vector_ratios() - if reciprocal_vector_ratios is None: - raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") - - param = create_convergence_parameter( - name=parameter_name.value, - initial_value=parameter_initial, - increment=parameter_increment, - reciprocal_vector_ratios=reciprocal_vector_ratios, - ) - - merged_context = self._merge_convergence_context( - unit_for_convergence.context, - param.unit_context, - ) - unit_for_convergence.set_context(merged_context) - prev_result = "prev_result" iteration = "iteration" - condition_expression = condition or f"abs(({prev_result}-{result})/{result})" + condition_expression = condition or f"abs(({prev_result}-{result_name})/{result_name})" param_init = Unit( name="init parameter", type="assignment", - operand=param.name, - value=param.initial_value, + operand=parameter_name, + value=parameter_initial_value, tags=[CONVERGENCE_PARAMETER_TAG], ) prev_result_init = Unit(name="init result", type="assignment", operand=prev_result, value=result_initial) @@ -143,28 +110,28 @@ def add_convergence( store_result = Unit( name="update result", type="assignment", - input=[{"scope": unit_for_convergence.flowchartId, "name": result}], - operand=result, - value=result, + input=[{"scope": result_unit_flowchart_id, "name": result_name}], + operand=result_name, + value=result_name, tags=[CONVERGENCE_RESULT_TAG], ) store_prev_result = Unit( name="store result", type="assignment", - input=[{"scope": unit_for_convergence.flowchartId, "name": result}], + input=[{"scope": result_unit_flowchart_id, "name": result_name}], operand=prev_result, - value=result, + value=result_name, ) next_iter = 
def add_convergence(
    self,
    parameter: str,
    parameter_initial: Any,
    parameter_increment: Any,
    reciprocal_vector_ratios: Optional[List[float]] = None,
    result: str = ENERGY_CONVERGENCE_RESULT,
    result_initial: Any = 0,
    condition: Optional[str] = None,
    operator: str = "<",
    tolerance: Any = 1e-5,
    max_occurrences: int = 10,
) -> None:
    """
    Add a convergence loop for one of the predefined convergence parameters.

    The parameter object is built by ``create_convergence_parameter``; its unit context is merged
    into the unit that produces ``result``, and the loop units are then created by
    ``_build_convergence_units`` with that same unit as both the result source and the re-entry point.

    Args:
        parameter: Name convertible to ``ConvergenceParameterNameEnum``.
        parameter_initial: Initial value handed to the parameter factory.
        parameter_increment: Increment handed to the parameter factory.
        reciprocal_vector_ratios: Ratios for the non-uniform k-grid parameters. When omitted for
            those parameters, they are read from the unit context via ``PointsGridDataProvider``.
        result: Result to monitor; only ``ENERGY_CONVERGENCE_RESULT`` is accepted.
        result_initial: Seed value for the previous-result variable.
        condition: Optional custom condition expression, forwarded unchanged.
        operator: Comparison operator, forwarded unchanged.
        tolerance: Convergence threshold, forwarded unchanged.
        max_occurrences: Iteration limit, forwarded unchanged.

    Raises:
        ValueError: If ``result`` is unsupported, if no unit provides ``result``, or if a
            non-uniform k-grid parameter has no reciprocal vector ratios available.
    """
    parameter_enum = ConvergenceParameterNameEnum(parameter)

    if result != ENERGY_CONVERGENCE_RESULT:
        raise ValueError(f"Unsupported convergence result: {result}")

    target_unit = self._find_unit_for_convergence(result)
    if target_unit is None:
        raise ValueError(f"Subworkflow does not contain a unit with '{result}' as an extracted property.")

    nonuniform_names = (
        ConvergenceParameterNameEnum.N_k_nonuniform,
        ConvergenceParameterNameEnum.N_k_nonuniform_2D,
    )
    ratios = reciprocal_vector_ratios
    if parameter_enum in nonuniform_names and ratios is None:
        # Fall back to whatever the unit context already knows about the lattice.
        ratios = PointsGridDataProvider(context=target_unit.context).get_reciprocal_vector_ratios()
        if ratios is None:
            raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.")

    # Kept under its own name so the ``parameter`` argument (a string) is never rebound to an object.
    loop_parameter = create_convergence_parameter(
        name=parameter_enum.value,
        initial_value=parameter_initial,
        increment=parameter_increment,
        reciprocal_vector_ratios=ratios,
    )

    target_unit.set_context(self._merge_convergence_context(target_unit.context, loop_parameter.unit_context))

    target_id = target_unit.flowchartId
    self._build_convergence_units(
        parameter_name=loop_parameter.name,
        parameter_initial_value=loop_parameter.initial_value,
        parameter_increment_expr=loop_parameter.increment,
        parameter_final_value=loop_parameter.final_value,
        parameter_input=loop_parameter.use_variables_from_unit_context(target_id),
        result_name=result,
        result_unit_flowchart_id=target_id,
        execution_unit_flowchart_id=target_id,
        result_initial=result_initial,
        condition=condition,
        operator=operator,
        tolerance=tolerance,
        max_occurrences=max_occurrences,
    )


def add_template_parameter_convergence(
    self,
    parameter_name: str,
    parameter_initial: Any,
    parameter_increment: Any,
    result_name: str,
    result_initial: Any = 0,
    condition: Optional[str] = None,
    operator: str = "<",
    tolerance: Any = 1e-3,
    max_occurrences: int = 10,
) -> None:
    """
    Add a convergence loop for an arbitrary parameter of the input templates.

    For every execution unit, the value of ``parameter_name`` in the input templates is replaced
    by ``Template.format_as_scope_reference(parameter_name)`` and ``parameter_initial`` is stored
    in the unit context under ``parameter_name``. The loop units are then created by
    ``_build_convergence_units``; the loop re-enters at the first execution unit.

    Args:
        parameter_name: Parameter name as it appears in the input template (e.g. "degauss").
        parameter_initial: Starting value of the parameter.
        parameter_increment: Step; the update expression is "<parameter_name> + <parameter_increment>".
        result_name: Result to monitor; a unit providing it must exist.
        result_initial: Seed value for the previous-result variable.
        condition: Optional custom condition expression, forwarded unchanged.
        operator: Comparison operator, forwarded unchanged.
        tolerance: Convergence threshold, forwarded unchanged.
        max_occurrences: Iteration limit, forwarded unchanged.

    Raises:
        ValueError: If the subworkflow has no execution unit or no unit providing ``result_name``.
    """
    host = cast(ConvergenceHost, self)
    execution_units = [candidate for candidate in host.units if candidate.type == "execution"]
    if not execution_units:
        raise ValueError("No execution unit found in subworkflow.")

    monitored_unit = self._find_unit_for_convergence(result_name)
    if monitored_unit is None:
        raise ValueError(f"No unit with result '{result_name}' found in subworkflow.")

    placeholder = Template.format_as_scope_reference(parameter_name)
    for unit in execution_units:
        unit.replace_variable_value_in_inputs(parameter_name, placeholder)
        seeded_context = {**unit.context, parameter_name: parameter_initial}
        unit.set_context(seeded_context)

    increment_expression = f"{parameter_name} + {parameter_increment}"
    self._build_convergence_units(
        parameter_name=parameter_name,
        parameter_initial_value=parameter_initial,
        parameter_increment_expr=increment_expression,
        # The exit unit assigns the parameter to itself, i.e. keeps the last value reached.
        parameter_final_value=parameter_name,
        parameter_input=[],
        result_name=result_name,
        result_unit_flowchart_id=monitored_unit.flowchartId,
        execution_unit_flowchart_id=execution_units[0].flowchartId,
        result_initial=result_initial,
        condition=condition,
        operator=operator,
        tolerance=tolerance,
        max_occurrences=max_occurrences,
    )
# Keys of a flat input payload that belong to the input item itself;
# every other key is treated as part of the template.
_ITEM_KEYS = {"rendered", "isManuallyChanged"}


# TODO: use from ESSE when epic/SOF-7756 merged
# A single input of an execution unit: the template plus its rendered text and a manual-edit flag.
# Accepts a flat dict on validation and serializes back to a flat dict.
class ExecutionUnitInputItem(InMemoryEntitySnakeCase):
    template: Template = Field(default_factory=Template)
    rendered: str = ""
    isManuallyChanged: bool = False

    @model_validator(mode="before")
    @classmethod
    def from_flat(cls, data: Any) -> Any:
        """
        Convert a flat dict into the nested ``{"template": ..., "rendered": ..., ...}`` layout.

        Only dicts without a "template" key are converted: keys other than "rendered" and
        "isManuallyChanged" are moved under "template", and the two item keys default to
        ``""`` and ``False``. Any other input is returned unchanged.
        """
        if isinstance(data, dict) and "template" not in data:
            return {
                "template": {k: v for k, v in data.items() if k not in _ITEM_KEYS},
                "rendered": data.get("rendered", ""),
                "isManuallyChanged": data.get("isManuallyChanged", False),
            }
        return data

    @model_serializer(mode="plain")
    def to_flat(self) -> Dict[str, Any]:
        """
        Serialize to a flat dict: the template fields plus "rendered" and "isManuallyChanged".

        The two item keys are written last, so they override same-named keys of the template dict.
        """
        return {**self.template.to_dict(), "rendered": self.rendered, "isManuallyChanged": self.isManuallyChanged}


class ExecutionUnit(Unit, ExecutionUnitSchemaBase):
    type: Literal["execution"] = "execution"
    executable: Executable = None
    flavor: Flavor = None
    application: Application = None
    # The factory must be the builtin ``list``: a subscripted ``typing.List[...]`` alias is
    # annotation-only and raises TypeError when called, which would break every
    # ExecutionUnit created without an explicit ``input``.
    input: List[ExecutionUnitInputItem] = Field(default_factory=list)

    def replace_in_input_content(self, pattern: str, replacement: str) -> None:
        """
        Call ``template.replace_in_content(pattern, replacement)`` on every input item.

        Args:
            pattern: Passed through unchanged (the tests use a regular expression).
            replacement: Passed through unchanged.
        """
        for input_item in self.input:
            input_item.template.replace_in_content(pattern, replacement)

    def replace_variable_value_in_inputs(self, variable_name: str, new_value: str) -> None:
        """
        Call ``template.replace_variable_value(variable_name, new_value)`` on every input item.

        Args:
            variable_name: Name of the variable in the template content (e.g. "degauss").
            new_value: Text that replaces the variable's current value.
        """
        for input_item in self.input:
            input_item.template.replace_variable_value(variable_name, new_value)

    # NOTE(review): get_hash_object() follows here in the class. The patch hunk ends before the
    # method does, so it is not reproduced in this block and must be kept as it is.
# Each case: (param_name, param_initial, param_increment, result_name, original_pattern),
# where original_pattern is the template text expected to be gone after the replacement.
TEMPLATE_PARAM_TEST_CASES = [
    pytest.param("degauss", 0.001, 0.002, "total_energy", "degauss = 0.005", id="degauss_bare_numeric"),
    pytest.param(
        "ecutwfc",
        20,
        10,
        "total_energy",
        "ecutwfc = {{ cutoffs.wavefunction }}",
        id="ecutwfc_jinja_expression",
    ),
]


@pytest.mark.parametrize(
    "param_name,param_initial,param_increment,result_name,original_pattern", TEMPLATE_PARAM_TEST_CASES
)
def test_add_template_param_convergence(param_name, param_initial, param_increment, result_name, original_pattern):
    """Check the loop layout, the template rewrite and the loop units for a single execution unit."""
    subworkflow = _build_total_energy_subworkflow()
    subworkflow.add_template_parameter_convergence(
        parameter_name=param_name,
        parameter_initial=param_initial,
        parameter_increment=param_increment,
        result_name=result_name,
        tolerance=1e-3,
    )

    # Loop layout: initialisers, the calculation, then the bookkeeping units.
    expected_unit_names = [
        "init parameter",
        "init result",
        "init counter",
        "pw_scf",
        "update result",
        "check convergence",
        "store result",
        "update counter",
        "update parameter",
        "exit",
    ]
    actual_unit_names = [unit.name for unit in subworkflow.units]
    assert actual_unit_names == expected_unit_names

    # The execution unit is seeded with the initial value and its template references the scope variable.
    pw_scf = subworkflow.get_unit_by_name(name="pw_scf")
    assert pw_scf.context[param_name] == param_initial
    template_content = pw_scf.input[0].template.content
    expected_assignment = f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}"
    assert expected_assignment in template_content
    assert original_pattern not in template_content

    assert subworkflow.convergence_parameter == param_name
    assert subworkflow.convergence_result == result_name
    assert subworkflow.has_convergence is True

    update_parameter = subworkflow.get_unit_by_name(name="update parameter")
    expected_increment_expression = f"{param_name} + {param_increment}"
    assert update_parameter.operand == param_name
    assert update_parameter.value == expected_increment_expression
    assert update_parameter.input == []
    assert update_parameter.next == pw_scf.flowchartId

    exit_unit = subworkflow.get_unit_by_name(name="exit")
    assert exit_unit.operand == param_name
    assert exit_unit.value == param_name


def test_add_template_param_convergence_multi_unit():
    """Check that the template rewrite reaches both pw_scf and pw_bands in a multi-unit subworkflow."""
    espresso_workflows = WorkflowStandata.filter_by_application("espresso")
    workflow_config = espresso_workflows.get_by_name_first_match("band_structure.json")
    subworkflow = Workflow.create(workflow_config).subworkflows[0]

    subworkflow.add_template_parameter_convergence(
        parameter_name="ecutwfc",
        parameter_initial=20,
        parameter_increment=10,
        result_name="total_energy",
    )

    execution_unit_count = sum(1 for unit in subworkflow.units if unit.type == "execution")
    assert execution_unit_count == 3

    for unit_name in ("pw_scf", "pw_bands"):
        unit = subworkflow.get_unit_by_name(unit_name)
        assert unit.context["ecutwfc"] == 20
        template_content = unit.input[0].template.content
        assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content
        assert "ecutwfc = {{ cutoffs.wavefunction }}" not in template_content

    assert subworkflow.convergence_parameter == "ecutwfc"
    assert subworkflow.convergence_result == "total_energy"
# Template snippets used as input content.
CONTENT_DEGAUSS_NUMERIC = "degauss = 0.005\n"
CONTENT_ECUTWFC_JINJA = "ecutwfc = {{ cutoffs.wavefunction }}\n"

# Matches a "degauss = <number>" assignment.
PATTERN_DEGAUSS_NUMERIC = r"degauss\s*=\s*[\d.e+\-]+"

# Scope references wrapped in raw markers.
RAW_SCOPE_DEGAUSS = "{% raw %}{{ degauss }}{% endraw %}"
RAW_SCOPE_ECUTWFC = "{% raw %}{{ ecutwfc }}{% endraw %}"

REPLACEMENT_DEGAUSS_RAW = f"degauss = {RAW_SCOPE_DEGAUSS}"

EXPECTED_DEGAUSS_REPLACED = f"degauss = {RAW_SCOPE_DEGAUSS}\n"
EXPECTED_ECUTWFC_REPLACED = f"ecutwfc = {RAW_SCOPE_ECUTWFC}\n"

# Each case: (contents, pattern, replacement, expected_contents).
REPLACE_IN_CONTENT_CASES = [
    ([CONTENT_DEGAUSS_NUMERIC], PATTERN_DEGAUSS_NUMERIC, REPLACEMENT_DEGAUSS_RAW, [EXPECTED_DEGAUSS_REPLACED]),
    (
        [CONTENT_DEGAUSS_NUMERIC, CONTENT_ECUTWFC_JINJA],
        PATTERN_DEGAUSS_NUMERIC,
        REPLACEMENT_DEGAUSS_RAW,
        [EXPECTED_DEGAUSS_REPLACED, CONTENT_ECUTWFC_JINJA],
    ),
]

# Each case: (contents, variable_name, new_value, expected_contents).
REPLACE_VARIABLE_VALUE_CASES = [
    ([CONTENT_DEGAUSS_NUMERIC], "degauss", RAW_SCOPE_DEGAUSS, [EXPECTED_DEGAUSS_REPLACED]),
    ([CONTENT_ECUTWFC_JINJA], "ecutwfc", RAW_SCOPE_ECUTWFC, [EXPECTED_ECUTWFC_REPLACED]),
    (
        [CONTENT_DEGAUSS_NUMERIC, CONTENT_ECUTWFC_JINJA],
        "degauss",
        RAW_SCOPE_DEGAUSS,
        [EXPECTED_DEGAUSS_REPLACED, CONTENT_ECUTWFC_JINJA],
    ),
]


def _make_unit(*contents: str) -> ExecutionUnit:
    """Build a "pw_scf" execution unit with one "pw.in" input item per content string."""
    input_items = []
    for content in contents:
        template = Template(name="pw.in", content=content)
        input_items.append(ExecutionUnitInputItem(template=template, rendered=content))
    return ExecutionUnit(name="pw_scf", input=input_items)


def _template_contents(unit: ExecutionUnit):
    """Return the template content of every input item, in order."""
    return [item.template.content for item in unit.input]


@pytest.mark.parametrize("contents,pattern,replacement,expected_contents", REPLACE_IN_CONTENT_CASES)
def test_replace_in_input_content(contents, pattern, replacement, expected_contents):
    """The pattern replacement is applied to every input; non-matching inputs stay unchanged."""
    unit = _make_unit(*contents)
    unit.replace_in_input_content(pattern, replacement)
    assert _template_contents(unit) == expected_contents


@pytest.mark.parametrize("contents,variable_name,new_value,expected_contents", REPLACE_VARIABLE_VALUE_CASES)
def test_replace_variable_value_in_inputs(contents, variable_name, new_value, expected_contents):
    """The variable value is replaced in every input that defines it; other inputs stay unchanged."""
    unit = _make_unit(*contents)
    unit.replace_variable_value_in_inputs(variable_name, new_value)
    assert _template_contents(unit) == expected_contents