diff --git a/pyproject.toml b/pyproject.toml index a38a4a9..44cbfc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "pydantic>=2.0", "mat3ra-esse", "mat3ra-code", - "mat3ra-utils[extra]" + "mat3ra-utils[extra] @ git+https://github.com/Exabyte-io/utils.git@5f17ac3e7ab242f2c387b5691c8d8cacc6d59f3f" ] [project.optional-dependencies] diff --git a/src/py/mat3ra/ade/template.py b/src/py/mat3ra/ade/template.py index 76e5547..3be7f5b 100644 --- a/src/py/mat3ra/ade/template.py +++ b/src/py/mat3ra/ade/template.py @@ -3,7 +3,13 @@ from mat3ra.code.entity import InMemoryEntitySnakeCase from mat3ra.esse.models.software.template import TemplateSchema -from mat3ra.utils.extra.jinja import render_jinja_with_error_handling +from mat3ra.utils.extra.jinja import ( + JINJA_EXPRESSION_PATTERN, + NUMERIC_VALUE_PATTERN, + render_jinja_with_error_handling, + replace_in_template_content, + wrap_in_raw_block, +) from pydantic import Field from .context.context_provider import ContextProvider @@ -35,6 +41,17 @@ def get_rendered(self) -> str: def set_content(self, text: str) -> None: self.content = text + def replace_in_content(self, pattern: str, replacement: str) -> None: + self.content = replace_in_template_content(self.content, pattern, replacement) + + def replace_variable_value(self, variable_name: str, new_value: str) -> None: + pattern = rf"{variable_name}\s*=\s*(?:{NUMERIC_VALUE_PATTERN}|{JINJA_EXPRESSION_PATTERN})" + self.replace_in_content(pattern, f"{variable_name} = {new_value}") + + @staticmethod + def format_as_scope_reference(variable_name: str) -> str: + return wrap_in_raw_block("{{ " + variable_name + " }}") + def set_rendered(self, text: str) -> None: self.rendered = text diff --git a/tests/fixtures/application_hash.json b/tests/fixtures/application_hash.json index b9273da..9b90230 100644 --- a/tests/fixtures/application_hash.json +++ b/tests/fixtures/application_hash.json @@ -1,8 +1,8 @@ { - "standata": { - "name": "espresso", - "version": "6.3", - "build": "GNU" - }, - "hash": "e4f4762afefb659c36b6ac2d08d860cc" -} + "standata": { + "name": "espresso", + "version": "6.3", + "build": "GNU" + }, + "hash": "bc4593c715c85ee9f1ac5f9e790b490d" +} \ No newline at end of file diff --git a/tests/py/test_template.py b/tests/py/test_template.py index a8bc768..be8b746 100644 --- a/tests/py/test_template.py +++ b/tests/py/test_template.py @@ -157,6 +157,25 @@ **TEMPLATE_DEFAULT_FIELDS, } +CONTENT_DEGAUSS_NUMERIC = "degauss = 0.005\n" +CONTENT_ECUTWFC_JINJA = "ecutwfc = {{ cutoffs.wavefunction }}\n" +CONTENT_MIXING_BETA_SCIENTIFIC = "mixing_beta = 1e-3\n" +CONTENT_NO_MATCH = "no_match\n" + +PATTERN_DEGAUSS_NUMERIC = r"degauss\s*=\s*[\d.e+\-]+" +PATTERN_OTHER_NUMERIC = r"other\s*=\s*[\d.e+\-]+" + +REPLACEMENT_DEGAUSS = "degauss = NEW" +REPLACEMENT_OTHER = "other = x" +NEW_VALUE = "NEW" + +EXPECTED_DEGAUSS_REPLACED = "degauss = NEW\n" +EXPECTED_ECUTWFC_REPLACED = "ecutwfc = NEW\n" +EXPECTED_MIXING_BETA_REPLACED = "mixing_beta = NEW\n" + +EXPECTED_RAW_SCOPE_DEGAUSS = "{% raw %}{{ degauss }}{% endraw %}" +EXPECTED_RAW_SCOPE_CUTOFFS_WF = "{% raw %}{{ cutoffs_wf }}{% endraw %}" + @pytest.mark.parametrize( "config,expected_fields", @@ -287,3 +306,41 @@ def test_render_with_external_context_and_provider(): template.add_context_provider(PROVIDER_KPATH) template.render(EXTERNAL_CONTEXT_KPATH) assert template.get_rendered() == EXPECTED_EXTERNAL_CONTEXT_RENDER + + +@pytest.mark.parametrize( + "content,pattern,replacement,expected", + [ + (CONTENT_DEGAUSS_NUMERIC, PATTERN_DEGAUSS_NUMERIC, REPLACEMENT_DEGAUSS, EXPECTED_DEGAUSS_REPLACED), + (CONTENT_NO_MATCH, PATTERN_OTHER_NUMERIC, REPLACEMENT_OTHER, CONTENT_NO_MATCH), + ], +) +def test_replace_in_content(content, pattern, replacement, expected): + template = Template(name="test.in", content=content) + template.replace_in_content(pattern, replacement) + assert template.content == expected + + +@pytest.mark.parametrize( + "content,variable_name,new_value,expected", + [ + (CONTENT_DEGAUSS_NUMERIC, "degauss", NEW_VALUE, EXPECTED_DEGAUSS_REPLACED), + (CONTENT_ECUTWFC_JINJA, "ecutwfc", NEW_VALUE, EXPECTED_ECUTWFC_REPLACED), + (CONTENT_MIXING_BETA_SCIENTIFIC, "mixing_beta", NEW_VALUE, EXPECTED_MIXING_BETA_REPLACED), + ], +) +def test_replace_variable_value(content, variable_name, new_value, expected): + template = Template(name="test.in", content=content) + template.replace_variable_value(variable_name, new_value) + assert template.content == expected + + +@pytest.mark.parametrize( + "variable_name,expected", + [ + ("degauss", EXPECTED_RAW_SCOPE_DEGAUSS), + ("cutoffs_wf", EXPECTED_RAW_SCOPE_CUTOFFS_WF), + ], +) +def test_format_as_scope_reference(variable_name, expected): + assert Template.format_as_scope_reference(variable_name) == expected