Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ classifiers = [
dependencies = [
"numpy",
"mat3ra-code",
"mat3ra-utils",
"mat3ra-utils @ git+https://github.com/Exabyte-io/utils.git@5f17ac3e7ab242f2c387b5691c8d8cacc6d59f3f",
"mat3ra-esse",
"mat3ra-mode",
"mat3ra-ade",
"mat3ra-ade @ git+https://github.com/Exabyte-io/ade.git@9ef73f59d2099cda4426011ed1f31891ee19b68c",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

"mat3ra-made",
"mat3ra-standata"
]
Expand Down
219 changes: 155 additions & 64 deletions src/py/mat3ra/wode/subworkflows/convergence_mixin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, Dict, List, Optional, Protocol, cast

from ..context.providers import PointsGridDataProvider
from ..units import Unit

from mat3ra.ade import Template
from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum

from .convergence.factory import create_convergence_parameter
from ..context.providers import PointsGridDataProvider
from ..units import Unit

CONVERGENCE_PARAMETER_TAG = "hasConvergenceParam"
CONVERGENCE_RESULT_TAG = "hasConvergenceResult"
Expand All @@ -29,7 +30,7 @@ def scalar_results(self) -> List[str]:
return [ENERGY_CONVERGENCE_RESULT]

@property
def convergence_param(self) -> Optional[str]:
def convergence_parameter(self) -> Optional[str]:
unit = cast(ConvergenceHost, self).find_unit_with_tag(CONVERGENCE_PARAMETER_TAG)
return getattr(unit, "operand", None) if unit else None

Expand All @@ -40,7 +41,7 @@ def convergence_result(self) -> Optional[str]:

@property
def has_convergence(self) -> bool:
return bool(self.convergence_param and self.convergence_result)
return bool(self.convergence_parameter and self.convergence_result)

def convergence_series(self, scope_track: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
if not self.has_convergence or not scope_track:
Expand All @@ -50,12 +51,12 @@ def convergence_series(self, scope_track: Optional[List[Dict[str, Any]]]) -> Lis
series: List[Dict[str, Any]] = []
for scope_item in scope_track:
global_scope = ((scope_item or {}).get("scope") or {}).get("global") or {}
param = global_scope.get(self.convergence_param)
parameter = global_scope.get(self.convergence_parameter)
result = global_scope.get(self.convergence_result)
is_new_result = result is not None and result != last_result
last_result = result
if is_new_result:
series.append({"x": len(series) + 1, "param": param, "y": result})
series.append({"x": len(series) + 1, "parameter": parameter, "y": result})
return series

def _find_unit_for_convergence(self, result: str):
Expand All @@ -76,95 +77,61 @@ def _merge_convergence_context(unit_context: Dict[str, Any], convergence_context
merged_context["kgrid"] = merged_kgrid_context
return merged_context

def add_convergence(
def _build_convergence_units(
self,
parameter: str,
parameter_initial: Any,
parameter_increment: Any,
reciprocal_vector_ratios: Optional[List[float]] = None,
result: str = ENERGY_CONVERGENCE_RESULT,
parameter_name: str,
parameter_initial_value: str,
parameter_increment_expr: str,
parameter_final_value: str,
parameter_input: List[Dict[str, str]],
result_name: str,
result_unit_flowchart_id: str,
execution_unit_flowchart_id: str,
result_initial: Any = 0,
condition: Optional[str] = None,
operator: str = "<",
tolerance: Any = 1e-5,
max_occurrences: int = 10,
) -> None:
# Used for type checking correctness
host = cast(ConvergenceHost, self)
parameter_name = ConvergenceParameterNameEnum(parameter)

if result != ENERGY_CONVERGENCE_RESULT:
raise ValueError(f"Unsupported convergence result: {result}")

unit_for_convergence = self._find_unit_for_convergence(result)
if unit_for_convergence is None:
raise ValueError(f"Subworkflow does not contain a unit with '{result}' as an extracted property.")

if (
parameter_name
in (
ConvergenceParameterNameEnum.N_k_nonuniform,
ConvergenceParameterNameEnum.N_k_nonuniform_2D,
)
and reciprocal_vector_ratios is None
):
reciprocal_vector_ratios = PointsGridDataProvider(
context=unit_for_convergence.context
).get_reciprocal_vector_ratios()
if reciprocal_vector_ratios is None:
raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.")

param = create_convergence_parameter(
name=parameter_name.value,
initial_value=parameter_initial,
increment=parameter_increment,
reciprocal_vector_ratios=reciprocal_vector_ratios,
)

merged_context = self._merge_convergence_context(
unit_for_convergence.context,
param.unit_context,
)
unit_for_convergence.set_context(merged_context)

prev_result = "prev_result"
iteration = "iteration"
condition_expression = condition or f"abs(({prev_result}-{result})/{result})"
condition_expression = condition or f"abs(({prev_result}-{result_name})/{result_name})"

param_init = Unit(
name="init parameter",
type="assignment",
operand=param.name,
value=param.initial_value,
operand=parameter_name,
value=parameter_initial_value,
tags=[CONVERGENCE_PARAMETER_TAG],
)
prev_result_init = Unit(name="init result", type="assignment", operand=prev_result, value=result_initial)
iter_init = Unit(name="init counter", type="assignment", operand=iteration, value=1)
store_result = Unit(
name="update result",
type="assignment",
input=[{"scope": unit_for_convergence.flowchartId, "name": result}],
operand=result,
value=result,
input=[{"scope": result_unit_flowchart_id, "name": result_name}],
operand=result_name,
value=result_name,
tags=[CONVERGENCE_RESULT_TAG],
)
store_prev_result = Unit(
name="store result",
type="assignment",
input=[{"scope": unit_for_convergence.flowchartId, "name": result}],
input=[{"scope": result_unit_flowchart_id, "name": result_name}],
operand=prev_result,
value=result,
value=result_name,
)
next_iter = Unit(name="update counter", type="assignment", operand=iteration, value=f"{iteration} + 1")
next_step = Unit(
name="update parameter",
type="assignment",
input=param.use_variables_from_unit_context(unit_for_convergence.flowchartId),
operand=param.name,
value=param.increment,
next=unit_for_convergence.flowchartId,
input=parameter_input,
operand=parameter_name,
value=parameter_increment_expr,
next=execution_unit_flowchart_id,
)
exit_unit = Unit(name="exit", type="assignment", operand=param.name, value=param.final_value)
exit_unit = Unit(name="exit", type="assignment", operand=parameter_name, value=parameter_final_value)
condition_unit = Unit(
name="check convergence",
type="condition",
Expand All @@ -186,4 +153,128 @@ def add_convergence(
host.add_unit(next_step)
host.add_unit(exit_unit)

next_step.next = unit_for_convergence.flowchartId
next_step.next = execution_unit_flowchart_id

def add_convergence(
self,
parameter: str,
parameter_initial: Any,
parameter_increment: Any,
reciprocal_vector_ratios: Optional[List[float]] = None,
result: str = ENERGY_CONVERGENCE_RESULT,
result_initial: Any = 0,
condition: Optional[str] = None,
operator: str = "<",
tolerance: Any = 1e-5,
max_occurrences: int = 10,
) -> None:
parameter_name = ConvergenceParameterNameEnum(parameter)

if result != ENERGY_CONVERGENCE_RESULT:
raise ValueError(f"Unsupported convergence result: {result}")

unit_for_convergence = self._find_unit_for_convergence(result)
if unit_for_convergence is None:
raise ValueError(f"Subworkflow does not contain a unit with '{result}' as an extracted property.")

if (
parameter_name
in (
ConvergenceParameterNameEnum.N_k_nonuniform,
ConvergenceParameterNameEnum.N_k_nonuniform_2D,
)
and reciprocal_vector_ratios is None
):
reciprocal_vector_ratios = PointsGridDataProvider(
context=unit_for_convergence.context
).get_reciprocal_vector_ratios()
if reciprocal_vector_ratios is None:
raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.")

parameter = create_convergence_parameter(
name=parameter_name.value,
initial_value=parameter_initial,
increment=parameter_increment,
reciprocal_vector_ratios=reciprocal_vector_ratios,
)

merged_context = self._merge_convergence_context(
unit_for_convergence.context,
parameter.unit_context,
)
unit_for_convergence.set_context(merged_context)

self._build_convergence_units(
parameter_name=parameter.name,
parameter_initial_value=parameter.initial_value,
parameter_increment_expr=parameter.increment,
parameter_final_value=parameter.final_value,
parameter_input=parameter.use_variables_from_unit_context(unit_for_convergence.flowchartId),
result_name=result,
result_unit_flowchart_id=unit_for_convergence.flowchartId,
execution_unit_flowchart_id=unit_for_convergence.flowchartId,
result_initial=result_initial,
condition=condition,
operator=operator,
tolerance=tolerance,
max_occurrences=max_occurrences,
)

def add_template_parameter_convergence(
self,
parameter_name: str,
parameter_initial: Any,
parameter_increment: Any,
result_name: str,
result_initial: Any = 0,
condition: Optional[str] = None,
operator: str = "<",
tolerance: Any = 1e-3,
max_occurrences: int = 10,
) -> None:
"""
Add a convergence loop for an arbitrary template parameter.

Uses regex substitution to inject the parameter as a runtime scope variable into the
execution unit's input template, then delegates to _build_convergence_units.

Args:
parameter_name: Parameter name as it appears in the input template (e.g. "degauss").
parameter_initial: Starting value of the parameter.
parameter_increment: Scalar step added each iteration.
result_name: Name of the result property to monitor (must exist in a unit's results).
result_initial: Seed value for the result before the first iteration.
condition: Optional custom convergence condition expression.
operator: Comparison operator for the convergence condition (default "<").
tolerance: Convergence threshold.
max_occurrences: Maximum number of loop iterations.
"""
host = cast(ConvergenceHost, self)
execution_units = [u for u in host.units if u.type == "execution"]
if not execution_units:
raise ValueError("No execution unit found in subworkflow.")

result_unit = self._find_unit_for_convergence(result_name)
if result_unit is None:
raise ValueError(f"No unit with result '{result_name}' found in subworkflow.")

scope_reference = Template.format_as_scope_reference(parameter_name)
for execution_unit in execution_units:
execution_unit.replace_variable_value_in_inputs(parameter_name, scope_reference)
execution_unit.set_context({**execution_unit.context, parameter_name: parameter_initial})

self._build_convergence_units(
parameter_name=parameter_name,
parameter_initial_value=parameter_initial,
parameter_increment_expr=f"{parameter_name} + {parameter_increment}",
parameter_final_value=parameter_name,
parameter_input=[],
result_name=result_name,
result_unit_flowchart_id=result_unit.flowchartId,
execution_unit_flowchart_id=execution_units[0].flowchartId,
result_initial=result_initial,
condition=condition,
operator=operator,
tolerance=tolerance,
max_occurrences=max_occurrences,
)
45 changes: 37 additions & 8 deletions src/py/mat3ra/wode/units/execution.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,65 @@
from typing import Any, Dict, List, Literal

from mat3ra.ade import Application, Executable, Flavor
from mat3ra.ade import Application, Executable, Flavor, Template
from mat3ra.code.entity import InMemoryEntitySnakeCase
from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchemaBase
from mat3ra.utils import (
calculate_hash_from_object,
remove_comments_from_source_code,
remove_empty_lines_from_string,
remove_timestampable_keys,
)
from pydantic import Field
from pydantic import Field, model_serializer, model_validator

from .unit import Unit

_ITEM_KEYS = {"rendered", "isManuallyChanged"}


# TODO: use from ESSE when epic/SOF-7756 merged
class ExecutionUnitInputItem(InMemoryEntitySnakeCase):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

template: Template = Field(default_factory=Template)
rendered: str = ""
isManuallyChanged: bool = False

@model_validator(mode="before")
@classmethod
def from_flat(cls, data: Any) -> Any:
if isinstance(data, dict) and "template" not in data:
return {
"template": {k: v for k, v in data.items() if k not in _ITEM_KEYS},
"rendered": data.get("rendered", ""),
"isManuallyChanged": data.get("isManuallyChanged", False),
}
return data

@model_serializer(mode="plain")
def to_flat(self) -> Dict[str, Any]:
return {**self.template.to_dict(), "rendered": self.rendered, "isManuallyChanged": self.isManuallyChanged}


class ExecutionUnit(Unit, ExecutionUnitSchemaBase):
type: Literal["execution"] = "execution"
executable: Executable = None
flavor: Flavor = None
application: Application = None
input: List = Field(default_factory=list)
input: List[ExecutionUnitInputItem] = Field(default_factory=List[ExecutionUnitInputItem])

def replace_in_input_content(self, pattern: str, replacement: str) -> None:
for input_item in self.input:
input_item.template.replace_in_content(pattern, replacement)

def replace_variable_value_in_inputs(self, variable_name: str, new_value: str) -> None:
for input_item in self.input:
input_item.template.replace_variable_value(variable_name, new_value)

def get_hash_object(self) -> Dict[str, Any]:
app = self.application.to_dict() if self.application else {}
exe = self.executable.to_dict() if self.executable else {}
flv = self.flavor.to_dict() if self.flavor else {}
input_items = self.input if isinstance(self.input, list) else []
input_hash = calculate_hash_from_object(
[
remove_empty_lines_from_string(remove_comments_from_source_code(i.get("content", "")))
for i in input_items
if isinstance(i, dict)
]
[remove_empty_lines_from_string(remove_comments_from_source_code(i.template.content)) for i in input_items]
)
return {
**super().get_hash_object(),
Expand Down
Loading
Loading