From 44d04055f6cd2ef8017d674c1f25a2e4f0748f9e Mon Sep 17 00:00:00 2001 From: benjamink Date: Tue, 11 Mar 2025 12:07:37 -0700 Subject: [PATCH] Framework for the UIJson/Options/Driver relationships. --- tests/uijson_test.py | 39 +++++++++++ uml/classes_uijson.puml | 145 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 uml/classes_uijson.puml diff --git a/tests/uijson_test.py b/tests/uijson_test.py index 1f4ccbcb..3fab4811 100644 --- a/tests/uijson_test.py +++ b/tests/uijson_test.py @@ -113,3 +113,42 @@ def test_gravity_uijson(tmp_path): params_data_nobraces[param] = field_data_nobraces assert uijson_data == params_data_nobraces + + +def test_field_handling(): + # TODO: This is was for prototyping and should be removed once the + # behaviours tested here are incorporated into the UIJson classes. + + import warnings + from typing import Annotated, Any + + from pydantic import AliasChoices, BaseModel, BeforeValidator, Field + + def deprecate(value, info): + warnings.warn( # will be a logging.warn in production + f"Field {info.field_name} is deprecated." + ) + return value + + Deprecated = Annotated[ + Any, + Field(exclude=True), + BeforeValidator(deprecate), + ] + + class MyClass(BaseModel): + a: int = 1 # Represents a newly added field with a default value. + b: int = Field( # Represents a field with a name change. + validation_alias=AliasChoices("b", "bb") + ) + c: Deprecated # Represents a deprecated field. + + test = MyClass(bb=2, c=3) + assert test.a == 1 + assert test.b == 2 + + dump = test.model_dump() + assert "c" not in dump + assert "b" in dump + assert "bb" not in dump + assert dump["a"] == 1 diff --git a/uml/classes_uijson.puml b/uml/classes_uijson.puml new file mode 100644 index 00000000..14303937 --- /dev/null +++ b/uml/classes_uijson.puml @@ -0,0 +1,145 @@ +@startuml classes_uijson +set namespaceSeparator none + +class GravityInversionUIJson { + active_model : DataForm + alpha_s : DataForm + auto_scale_misfits : BoolForm + beta_tol : FloatForm + chi_factor : FloatForm + chunk_by_rows : BoolForm + coolingFactor : FloatForm + coolingRate : IntegerForm + data_object : ObjectForm + default_ui_json : ClassVar[Path] + distributed_workers : str + every_iteration_bool : BoolForm + f_min_change : FloatForm + forward_only : bool + ga_group : str + generate_sweep : BoolForm + gradient_type : ChoiceForm + guv_channel : DataForm + guv_uncertainty : DataForm + gx_channel : DataForm + gx_uncertainty : DataForm + gxx_channel : DataForm + gxx_uncertainty : DataForm + gxy_channel : DataForm + gxy_uncertainty : DataForm + gxz_channel : DataForm + gxz_uncertainty : DataForm + gy_channel : DataForm + gy_uncertainty : DataForm + gyy_channel : DataForm + gyy_uncertainty : DataForm + gyz_channel : DataForm + gyz_uncertainty : DataForm + gz_channel : DataForm + gz_uncertainty : DataForm + gzz_channel : DataForm + gzz_uncertainty : DataForm + initial_beta : FloatForm + initial_beta_ratio : FloatForm + inversion_style : Deprecated + inversion_type : str + length_scale_x : DataForm + length_scale_y : DataForm + length_scale_z : DataForm + lower_bound : DataForm + max_cg_iterations : IntegerForm + max_chunk_size : IntegerForm + max_global_iterations : IntegerForm + max_irls_iterations : IntegerForm + max_line_search_iterations : IntegerForm + max_ram : str + mesh : ObjectForm + n_cpu : IntegerForm + out_group : GroupForm + output_tile_files : bool + parallelized : BoolForm + {field} percentile : Field(IntegerForm, validation_alias=AliasChoices("percentile", "prctile")) + reference_model : DataForm + s_norm : DataForm + save_sensitivities : BoolForm + sens_wts_threshold : FloatForm + starting_chi_factor : FloatForm + starting_model : DataForm + store_sensitivities : ChoiceForm + tile_spatial : DataForm + tol_cg : FloatForm + topography : DataForm + topography_object : ObjectForm + upper_bound : DataForm + x_norm : DataForm + y_norm : DataForm + z_norm : DataForm + validate_version() +} +note left of GravityInversionUIJson::inversion_style + - inversion_style is an example of a deprecated field. + - We can handle these with a special annotation. + - The annotation accepts Any type, contains a Field with + exclude=True set so that it is not written to the + ui.json file, and it contains a BeforeValidator that + logs a warning about the deprecation. + - The annotation definition can be stored in geoh5py and + imported and used wherever needed. +end note + + +note left of GravityInversionUIJson::prctile + - percentile is an example of a parameter whose name has + been updated. + - We can handle these with a validation_alias. + - This will accept old naming conventions, while writing + the new naming convention to file at the end of the run. +end note + +note left of GravityInversionUIJson::validate_version + - logs a warning about provided uijson version not + matching the current simpeg_drivers.__version__ + - returns the current version since the uijson will + be written without any deprecated fields and will + add any newer parameters that have default values +end note + + +class GravityDriver { + ui_json_class : ClassVar[GravityInversionUIJson] + options_class : ClassVar[GravityInversionOptions] + uijson: GravityInversionUIJson + options : GravityInversionOptions + GravityDriver from_uijson(GravityInversionUIJson) + GravityDriver from_options(GravityInversionOptions) + @classmethod start(Path) +} +note left of GravityDriver::options + - options is a property that constructs only once. + - returns self.options_class(**self.ui_json.to_params()). + - options is then passes around the factories to assemble + inversion components. +end note + +note left of GravityDriver::uijson + - uijson is a stored on creation, and is frozen to modification. + - if data is modified during the run, such as what happens when + we copy objects into the out_group, the driver can model dump, + update the data, then construct a new UIJson instance to write + the ui.json file that stores the data used in the run. +end note + +note left of GravityDriver::start + - reads plain json + - infers forward_only and inversion_type to select driver import + - use inversion_type and forward_only to import driver using + cls.driver_class_from_name() + - retrieve ui_json_class from driver and instantiate from json.load output + - instantiate driver with UIJson instance (init stores UIJson instance) + and return the driver instance +end note + +GravityDriver o-- GravityInversionUIJson + + +@enduml