Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ include = [
[tool.poetry.dependencies]
python = "^3.10"

dask-core = {version = "2025.3.*"} # also in simpeg[dask]
dask-core = "2025.3.*" # also in simpeg[dask]
discretize = "0.11.*" # also in simpeg, octree-creation-app
distributed = "2025.3.*" # because conda-lock doesn't take dask extras into account
distributed = "2025.3.*" # conda needs explicit dask-core etc for equivalent dask[distributed]
numpy = "~1.26.0" # also in geoh5py, simpeg
pydantic = "^2.5.2" # also in geoh5py, curve-apps, geoapps-utils
scikit-learn = "~1.4.0"
Expand Down
6 changes: 5 additions & 1 deletion recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ tests:
- geoh5py
- dask
- distributed
pip_check: false # pip checks fails on missing dask-core because it only sees name 'dask'

# `pip check` fails on missing dask-core because it only sees name 'dask'
# Possibly, use custom mapping for dask => dask-core
# See `conda-lock --pypi_to_conda_lookup_file`, or the `pixi` option "conda-pypi-map"
pip_check: false

- script:
- pytest --ignore=tests/version_test.py
Expand Down
37 changes: 31 additions & 6 deletions simpeg_drivers/uijson.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging

from geoh5py.ui_json.ui_json import BaseUIJson
from packaging.version import Version
from pydantic import field_validator

import simpeg_drivers
Expand All @@ -29,17 +30,41 @@ class SimPEGDriversUIJson(BaseUIJson):
@field_validator("version", mode="before")
@classmethod
def verify_and_update_version(cls, value: str) -> str:
version = simpeg_drivers.__version__
if value != version:
package_version = cls.comparable_version(simpeg_drivers.__version__)
input_version = cls.comparable_version(value)
if input_version != package_version:
logger.warning(
"Provided ui.json file version %s does not match the the current"
"simpeg-drivers version %s. This may lead to unpredictable"
"behavior.",
"Provided ui.json file version '%s' does not match the current "
"simpeg-drivers version '%s'. This may lead to unpredictable behavior.",
value,
version,
simpeg_drivers.__version__,
)
return value

@staticmethod
def comparable_version(value: str) -> str:
"""Normalize the version string for comparison.

Remove the post-release information, or the pre-release information if it is an rc version.
For example, if the version is "0.2.0.post1", it will return "0.2.0".
If the version is "0.2.0rc1", it will return "0.2.0".

Then, it will return the public version of the version object.
For example, if the version is "0.2.0+local", it will return "0.2.0".
"""
version = Version(value)

# Extract the base version (major.minor.patch)
base_version = version.base_version

# If it's not an RC, keep any pre-release info (alpha/beta)
if version.pre is not None and version.pre[0] != "rc": # pylint: disable=unsubscriptable-object
# Recreate version with pre-release but no post or local
return f"{base_version}{version.pre[0]}{version.pre[1]}"

# No pre-release info or it's an RC, return just the base version
return base_version

@classmethod
def write_default(cls):
"""Write the default UIJson file to disk with updated version."""
Expand Down
165 changes: 113 additions & 52 deletions tests/uijson_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from typing import ClassVar

import numpy as np
import pytest
from geoh5py import Workspace
from geoh5py.ui_json.annotations import Deprecated
from packaging.version import Version
from pydantic import AliasChoices, Field

import simpeg_drivers
Expand All @@ -29,12 +31,28 @@
logger = logging.getLogger(__name__)


def test_version_warning(tmp_path, caplog):
workspace = Workspace.create(tmp_path / "test.geoh5")
def _current_version() -> Version:
"""Get the package version."""
return Version(simpeg_drivers.__version__)

with caplog.at_level(logging.WARNING):
_ = SimPEGDriversUIJson(
version="0.2.0",

@pytest.fixture(name="workspace")
def workspace_fixture(tmp_path):
"""Create a workspace for testing."""
return Workspace.create(tmp_path / "test.geoh5")


@pytest.fixture(name="simpeg_uijson_factory")
def simpeg_uijson_factory_fixture(workspace):
"""Create a SimPEGDriversUIJson object with configurable version."""

def _create_uijson(version: str | None = None, **kwargs):
"""Create a SimPEGDriversUIJson with the given version and custom fields."""
if version is None:
version = _current_version().public

return SimPEGDriversUIJson(
version=version,
title="My app",
icon="",
documentation="",
Expand All @@ -43,8 +61,87 @@ def test_version_warning(tmp_path, caplog):
monitoring_directory="",
conda_environment="my-app",
workspace_geoh5="",
**kwargs,
)

return _create_uijson


@pytest.mark.parametrize(
"version_input,expected",
[
# Normal version
("1.2.3", "1.2.3"),
# Post-release version
("1.2.3.post1", "1.2.3"),
# RC pre-release version
("1.2.3rc1", "1.2.3"),
# Alpha pre-release version (should not normalize)
("1.2.3a1", "1.2.3a1"),
# Beta pre-release version (should not normalize)
("1.2.3b1", "1.2.3b1"),
# Local version
("1.2.3+local", "1.2.3"),
# Combined cases
("1.2.3rc1.post2+local", "1.2.3"),
],
)
def test_comparable_version(version_input, expected):
"""Test the comparable_version method of SimPEGDriversUIJson."""
assert SimPEGDriversUIJson.comparable_version(version_input) == expected


@pytest.mark.parametrize(
"version_input,package_version,should_warn",
[
# Different version (should warn)
("1.0.0", "2.0.0", True),
# Same version (should not warn)
("2.0.0", "2.0.0", False),
# Post-release variant (should not warn)
("2.0.0.post1", "2.0.0", False),
("2.0.0", "2.0.0.post1", False),
# RC variant (should not warn)
("2.0.0rc1", "2.0.0", False),
("2.0.0", "2.0.0rc1", False),
("2.0.0rc1", "2.0.0rc2", False),
# differ by the pre-release number, non RC (should warn)
("2.0.0a1", "2.0.0a2", True),
("2.0.0b1", "2.0.0b2", True),
("2.0.0a1", "2.0.0", True),
("2.0.0", "2.0.0a1", True),
("2.0.0a1", "2.0.0b1", True),
("2.0.0b1", "2.0.0a1", True),
("2.0.0rc1", "2.0.0b1", True),
("2.0.0b1", "2.0.0rc1", True),
# same normalized versions (should not warn)
("2.0.0-beta.1", "2.0.0b1", False),
("2.0.0b1", "2.0.0-beta.1", False),
],
)
def test_version_warning(
monkeypatch,
caplog,
simpeg_uijson_factory,
version_input,
package_version,
should_warn,
):
"""Test version warning behavior with mocked package version."""
# Mock the package version
monkeypatch.setattr(simpeg_drivers, "__version__", package_version)

with caplog.at_level(logging.WARNING):
caplog.clear()
_ = simpeg_uijson_factory(version=version_input)

warning_message = f"version '{version_input}' does not match the current simpeg-drivers version"
warning_found = any(
warning_message in record.message for record in caplog.records
)

assert warning_found == should_warn


def test_write_default(tmp_path):
default_path = tmp_path / "default.ui.json"
Expand All @@ -69,70 +166,34 @@ class MyUIJson(SimPEGDriversUIJson):
with open(default_path, encoding="utf-8") as f:
data = json.load(f)

assert data["version"] == "0.3.0-alpha.1"

# Use comparable_version for comparison to handle pre/post-release versions
assert SimPEGDriversUIJson.comparable_version(
data["version"]
) == SimPEGDriversUIJson.comparable_version(simpeg_drivers.__version__)

def test_deprecations(tmp_path, caplog):
workspace = Workspace.create(tmp_path / "test.geoh5")

def test_deprecations(caplog, simpeg_uijson_factory):
class MyUIJson(SimPEGDriversUIJson):
my_param: Deprecated

with caplog.at_level(logging.WARNING):
_ = MyUIJson(
version="0.3.0-alpha.1",
title="My app",
icon="",
documentation="",
geoh5=str(workspace.h5file),
run_command="myapp.driver",
monitoring_directory="",
conda_environment="my-app",
workspace_geoh5="",
my_param="whoopsie",
)
_ = MyUIJson(**simpeg_uijson_factory().model_dump(), my_param="whoopsie")
assert "Skipping deprecated field: my_param." in caplog.text


def test_pydantic_deprecation(tmp_path):
workspace = Workspace.create(tmp_path / "test.geoh5")

def test_pydantic_deprecation(simpeg_uijson_factory):
class MyUIJson(SimPEGDriversUIJson):
my_param: str = Field(deprecated="Use my_param2 instead.", exclude=True)

uijson = MyUIJson(
version="0.3.0-alpha.1",
title="My app",
icon="",
documentation="",
geoh5=str(workspace.h5file),
run_command="myapp.driver",
monitoring_directory="",
conda_environment="my-app",
workspace_geoh5="",
my_param="whoopsie",
)
uijson = MyUIJson(**simpeg_uijson_factory(my_param="whoopsie").model_dump())
assert "my_param" not in uijson.model_dump()


def test_alias(tmp_path):
workspace = Workspace.create(tmp_path / "test.geoh5")

def test_alias(simpeg_uijson_factory):
class MyUIJson(SimPEGDriversUIJson):
my_param: str = Field(validation_alias=AliasChoices("my_param", "myParam"))

uijson = MyUIJson(
version="0.3.0-alpha.1",
title="My app",
icon="",
documentation="",
geoh5=str(workspace.h5file),
run_command="myapp.driver",
monitoring_directory="",
conda_environment="my-app",
workspace_geoh5="",
myParam="hello",
)
uijson = MyUIJson(**simpeg_uijson_factory(myParam="hello").model_dump())
assert uijson.my_param == "hello"
assert "myParam" not in uijson.model_fields_set
assert "my_param" in uijson.model_dump()
Expand Down Expand Up @@ -170,7 +231,7 @@ def test_gravity_uijson(tmp_path):
uijson.write(uijson_path)
with open(params_uijson_path, encoding="utf-8") as f:
params_data = json.load(f)
assert params_data["version"] == simpeg_drivers.__version__
assert Version(params_data["version"]) == Version(_current_version().public)
with open(uijson_path, encoding="utf-8") as f:
uijson_data = json.load(f)

Expand Down
Loading