Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"pydantic>=2.0",
"mat3ra-esse",
"mat3ra-code",
"mat3ra-utils[extra]"
"mat3ra-utils[extra] @ git+https://github.com/Exabyte-io/utils.git@5f17ac3e7ab242f2c387b5691c8d8cacc6d59f3f"
]

[project.optional-dependencies]
Expand Down
19 changes: 18 additions & 1 deletion src/py/mat3ra/ade/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

from mat3ra.code.entity import InMemoryEntitySnakeCase
from mat3ra.esse.models.software.template import TemplateSchema
from mat3ra.utils.extra.jinja import render_jinja_with_error_handling
from mat3ra.utils.extra.jinja import (
JINJA_EXPRESSION_PATTERN,
NUMERIC_VALUE_PATTERN,
render_jinja_with_error_handling,
replace_in_template_content,
wrap_in_raw_block,
)
from pydantic import Field

from .context.context_provider import ContextProvider
Expand Down Expand Up @@ -35,6 +41,17 @@ def get_rendered(self) -> str:
def set_content(self, text: str) -> None:
self.content = text

def replace_in_content(self, pattern: str, replacement: str) -> None:
self.content = replace_in_template_content(self.content, pattern, replacement)

def replace_variable_value(self, variable_name: str, new_value: str) -> None:
pattern = rf"{variable_name}\s*=\s*(?:{NUMERIC_VALUE_PATTERN}|{JINJA_EXPRESSION_PATTERN})"
self.replace_in_content(pattern, f"{variable_name} = {new_value}")

@staticmethod
def format_as_scope_reference(variable_name: str) -> str:
return wrap_in_raw_block("{{ " + variable_name + " }}")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All three above can go to utils


def set_rendered(self, text: str) -> None:
self.rendered = text

Expand Down
14 changes: 7 additions & 7 deletions tests/fixtures/application_hash.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"standata": {
"name": "espresso",
"version": "6.3",
"build": "GNU"
},
"hash": "e4f4762afefb659c36b6ac2d08d860cc"
}
"standata": {
"name": "espresso",
"version": "6.3",
"build": "GNU"
},
"hash": "bc4593c715c85ee9f1ac5f9e790b490d"
}
57 changes: 57 additions & 0 deletions tests/py/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,25 @@
**TEMPLATE_DEFAULT_FIELDS,
}

CONTENT_DEGAUSS_NUMERIC = "degauss = 0.005\n"
CONTENT_ECUTWFC_JINJA = "ecutwfc = {{ cutoffs.wavefunction }}\n"
CONTENT_MIXING_BETA_SCIENTIFIC = "mixing_beta = 1e-3\n"
CONTENT_NO_MATCH = "no_match\n"

PATTERN_DEGAUSS_NUMERIC = r"degauss\s*=\s*[\d.e+\-]+"
PATTERN_OTHER_NUMERIC = r"other\s*=\s*[\d.e+\-]+"

REPLACEMENT_DEGAUSS = "degauss = NEW"
REPLACEMENT_OTHER = "other = x"
NEW_VALUE = "NEW"

EXPECTED_DEGAUSS_REPLACED = "degauss = NEW\n"
EXPECTED_ECUTWFC_REPLACED = "ecutwfc = NEW\n"
EXPECTED_MIXING_BETA_REPLACED = "mixing_beta = NEW\n"

EXPECTED_RAW_SCOPE_DEGAUSS = "{% raw %}{{ degauss }}{% endraw %}"
EXPECTED_RAW_SCOPE_CUTOFFS_WF = "{% raw %}{{ cutoffs_wf }}{% endraw %}"


@pytest.mark.parametrize(
"config,expected_fields",
Expand Down Expand Up @@ -287,3 +306,41 @@ def test_render_with_external_context_and_provider():
template.add_context_provider(PROVIDER_KPATH)
template.render(EXTERNAL_CONTEXT_KPATH)
assert template.get_rendered() == EXPECTED_EXTERNAL_CONTEXT_RENDER


@pytest.mark.parametrize(
"content,pattern,replacement,expected",
[
(CONTENT_DEGAUSS_NUMERIC, PATTERN_DEGAUSS_NUMERIC, REPLACEMENT_DEGAUSS, EXPECTED_DEGAUSS_REPLACED),
(CONTENT_NO_MATCH, PATTERN_OTHER_NUMERIC, REPLACEMENT_OTHER, CONTENT_NO_MATCH),
],
)
def test_replace_in_content(content, pattern, replacement, expected):
template = Template(name="test.in", content=content)
template.replace_in_content(pattern, replacement)
assert template.content == expected


@pytest.mark.parametrize(
"content,variable_name,new_value,expected",
[
(CONTENT_DEGAUSS_NUMERIC, "degauss", NEW_VALUE, EXPECTED_DEGAUSS_REPLACED),
(CONTENT_ECUTWFC_JINJA, "ecutwfc", NEW_VALUE, EXPECTED_ECUTWFC_REPLACED),
(CONTENT_MIXING_BETA_SCIENTIFIC, "mixing_beta", NEW_VALUE, EXPECTED_MIXING_BETA_REPLACED),
],
)
def test_replace_variable_value(content, variable_name, new_value, expected):
template = Template(name="test.in", content=content)
template.replace_variable_value(variable_name, new_value)
assert template.content == expected


@pytest.mark.parametrize(
"variable_name,expected",
[
("degauss", EXPECTED_RAW_SCOPE_DEGAUSS),
("cutoffs_wf", EXPECTED_RAW_SCOPE_CUTOFFS_WF),
],
)
def test_format_as_scope_reference(variable_name, expected):
assert Template.format_as_scope_reference(variable_name) == expected
Loading