diff --git a/simplexity/logging/logger.py b/simplexity/logging/logger.py deleted file mode 100644 index 6638bce0..00000000 --- a/simplexity/logging/logger.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Logger interface for logging to a variety of backends.""" - -from abc import ABC, abstractmethod -from collections.abc import Mapping -from typing import Any - -import matplotlib.figure -import mlflow -import numpy -import PIL.Image -import plotly.graph_objects -from omegaconf import DictConfig - - -class Logger(ABC): - """Logs to a variety of backends.""" - - @abstractmethod - def log_config(self, config: DictConfig, resolve: bool = False) -> None: - """Log config to the logger.""" - - @abstractmethod - def log_metrics(self, step: int, metric_dict: Mapping[str, Any]) -> None: - """Log metrics to the logger.""" - - @abstractmethod - def log_params(self, param_dict: Mapping[str, Any]) -> None: - """Log params to the logger.""" - - @abstractmethod - def log_tags(self, tag_dict: Mapping[str, Any]) -> None: - """Log tags to the logger.""" - - @abstractmethod - def log_figure( - self, - figure: matplotlib.figure.Figure | plotly.graph_objects.Figure, - artifact_file: str, - **kwargs, - ) -> None: - """Log a figure to the logger.""" - - @abstractmethod - def log_image( - self, - image: numpy.ndarray | PIL.Image.Image | mlflow.Image, - artifact_file: str | None = None, - key: str | None = None, - step: int | None = None, - **kwargs, - ) -> None: - """Log an image to the logger. - - Args: - image: Image to log (numpy array, PIL Image, or mlflow Image) - artifact_file: File path for artifact mode (e.g., "image.png") - key: Key name for time-stepped mode (requires step parameter) - step: Step number for time-stepped mode (requires key parameter) - **kwargs: Additional arguments passed to the underlying save method - - Note: - Must provide either artifact_file OR both key and step parameters. - Providing neither or only one of key/step will result in an error. - """ - - @abstractmethod - def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None: - """Log an artifact (file or directory) to the logger. - - Args: - local_path: Path to the local file or directory to log - artifact_path: Optional artifact path within the experiment run. - If None, uses the filename from local_path. - """ - - @abstractmethod - def log_json_artifact(self, data: dict | list, artifact_name: str) -> None: - """Log a JSON object as an artifact to the logger. - - Args: - data: Dictionary or list to serialize as JSON - artifact_name: Name for the artifact (e.g., "results.json") - """ - - @abstractmethod - def close(self) -> None: - """Close the logger.""" diff --git a/simplexity/logging/mlflow_logger.py b/simplexity/logging/mlflow_logger.py deleted file mode 100644 index 900b8a55..00000000 --- a/simplexity/logging/mlflow_logger.py +++ /dev/null @@ -1,215 +0,0 @@ -"""MLFlowLogger class for logging to MLflow.""" - -# pylint: disable-all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -import json -import os -import tempfile -import time -from collections.abc import Mapping -from typing import Any - -import dotenv -import matplotlib.figure -import mlflow -import numpy -import PIL.Image -import plotly.graph_objects -from mlflow.entities import Metric, Param, RunTag -from omegaconf import DictConfig, OmegaConf - -from simplexity.logging.logger import Logger -from simplexity.structured_configs.logging import MLFlowLoggerInstanceConfig -from simplexity.utils.mlflow_utils import ( - get_experiment, - get_run, - maybe_terminate_run, - resolve_registry_uri, -) - -dotenv.load_dotenv() -_DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") - - -class MLFlowLogger(Logger): - """Logs to MLflow Tracking.""" - - def __init__( - self, - experiment_id: str | None = None, - experiment_name: str | None = None, - run_id: str | None = None, - run_name: str | None = None, - tracking_uri: str | None = None, - registry_uri: str | None = None, - downgrade_unity_catalog: bool | None = None, - ): - """Initialize MLflow logger.""" - self._downgrade_unity_catalog = downgrade_unity_catalog if downgrade_unity_catalog is not None else True - resolved_registry_uri = resolve_registry_uri( - registry_uri=registry_uri, - tracking_uri=tracking_uri, - downgrade_unity_catalog=downgrade_unity_catalog, - ) - self._client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=resolved_registry_uri) - experiment = get_experiment(experiment_id=experiment_id, experiment_name=experiment_name, client=self.client) - assert experiment is not None - self._experiment_id = experiment.experiment_id - self._experiment_name = experiment.name - run = get_run(run_id=run_id, run_name=run_name, experiment_id=self.experiment_id, client=self.client) - assert run is not None - self._run_id = run.info.run_id - self._run_name = run.info.run_name - - @property - def client(self) -> mlflow.MlflowClient: - """Expose underlying MLflow client for integrations.""" - return self._client - - @property - def experiment_name(self) -> str: - """Expose active MLflow experiment name.""" - return self._experiment_name - - @property - def experiment_id(self) -> str: - """Expose active MLflow experiment identifier.""" - return self._experiment_id - - @property - def run_name(self) -> str | None: - """Expose active MLflow run name.""" - return self._run_name - - @property - def run_id(self) -> str: - """Expose active MLflow run identifier.""" - return self._run_id - - @property - def tracking_uri(self) -> str | None: - """Return the tracking URI associated with this logger.""" - return self.client.tracking_uri - - @property - def registry_uri(self) -> str | None: - """Return the model registry URI associated with this logger.""" - return self.client._registry_uri # pylint: disable=protected-access - - @property - def cfg(self) -> MLFlowLoggerInstanceConfig: - """Return the configuration of this logger.""" - return MLFlowLoggerInstanceConfig( - _target_=self.__class__.__qualname__, - experiment_id=self.experiment_id, - experiment_name=self.experiment_name, - run_id=self.run_id, - run_name=self.run_name, - tracking_uri=self.tracking_uri, - registry_uri=self.registry_uri, - downgrade_unity_catalog=self._downgrade_unity_catalog, - ) - - def log_config(self, config: DictConfig, resolve: bool = False) -> None: - """Log config to MLflow.""" - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.yaml") - OmegaConf.save(config, config_path, resolve=resolve) - self.client.log_artifact(self.run_id, config_path) - - def log_metrics(self, step: int, metric_dict: Mapping[str, Any]) -> None: - """Log metrics to MLflow.""" - timestamp = int(time.time() * 1000) - metrics = self._flatten_metric_dict(metric_dict, timestamp, step) - self._log_batch(metrics=metrics) - - def _flatten_metric_dict( - self, metric_dict: Mapping[str, Any], timestamp: int, step: int, key_prefix: str = "" - ) -> list[Metric]: - """Flatten a dictionary of metrics into a list of Metric entities.""" - metrics = [] - for key, value in metric_dict.items(): - key = f"{key_prefix}/{key}" if key_prefix else key - if isinstance(value, Mapping): - nested_metrics = self._flatten_metric_dict(value, timestamp, step, key_prefix=key) - metrics.extend(nested_metrics) - else: - value = float(value) - metric = Metric(key, value, timestamp, step) - metrics.append(metric) - return metrics - - def log_params(self, param_dict: Mapping[str, Any]) -> None: - """Log params to MLflow.""" - params = self._flatten_param_dict(param_dict) - self._log_batch(params=params) - - def _flatten_param_dict(self, param_dict: Mapping[str, Any], key_prefix: str = "") -> list[Param]: - """Flatten a dictionary of params into a list of Param entities.""" - params = [] - for key, value in param_dict.items(): - key = f"{key_prefix}.{key}" if key_prefix else key - if isinstance(value, Mapping): - nested_params = self._flatten_param_dict(value, key_prefix=key) - params.extend(nested_params) - else: - value = str(value) - param = Param(key, value) - params.append(param) - return params - - def log_tags(self, tag_dict: Mapping[str, Any]) -> None: - """Set tags on the MLFlow.""" - tags = [RunTag(k, str(v)) for k, v in tag_dict.items()] - self._log_batch(tags=tags) - - def log_figure( - self, - figure: matplotlib.figure.Figure | plotly.graph_objects.Figure, - artifact_file: str, - **kwargs, - ) -> None: - """Log a figure to MLflow using MLflowClient.log_figure.""" - self.client.log_figure(self.run_id, figure, artifact_file, **kwargs) - - def log_image( - self, - image: numpy.ndarray | PIL.Image.Image | mlflow.Image, - artifact_file: str | None = None, - key: str | None = None, - step: int | None = None, - **kwargs, - ) -> None: - """Log an image to MLflow using MLflowClient.log_image.""" - # Parameter validation - ensure we have either artifact_file or (key + step) - if not artifact_file and not (key and step is not None): - raise ValueError("Must provide either artifact_file or both key and step parameters") - - self.client.log_image(self.run_id, image, artifact_file=artifact_file, key=key, step=step, **kwargs) - - def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None: - """Log an artifact (file or directory) to MLflow.""" - self.client.log_artifact(self.run_id, local_path, artifact_path) - - def log_json_artifact(self, data: dict | list, artifact_name: str) -> None: - """Log a JSON object as an artifact to MLflow.""" - with tempfile.TemporaryDirectory() as temp_dir: - json_path = os.path.join(temp_dir, artifact_name) - with open(json_path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2) - self.client.log_artifact(self.run_id, json_path) - - def close(self) -> None: - """End the MLflow run.""" - maybe_terminate_run(run_id=self.run_id, client=self.client) - - def _log_batch(self, **kwargs: Any) -> None: - """Log arbitrary data to MLflow.""" - self.client.log_batch(self.run_id, **kwargs, synchronous=False) diff --git a/simplexity/logging/print_logger.py b/simplexity/logging/print_logger.py deleted file mode 100644 index 33629e95..00000000 --- a/simplexity/logging/print_logger.py +++ /dev/null @@ -1,87 +0,0 @@ -"""PrintLogger class for logging to the console.""" - -# pylint: disable-all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -from collections.abc import Mapping -from pprint import pprint -from typing import Any - -import matplotlib.figure -import mlflow -import numpy -import PIL.Image -import plotly.graph_objects -from omegaconf import DictConfig, OmegaConf - -from simplexity.logging.logger import Logger - - -class PrintLogger(Logger): - """Logs to the console.""" - - def log_config(self, config: DictConfig, resolve: bool = False) -> None: - """Log config to the console.""" - _config = OmegaConf.to_container(config, resolve=resolve) - pprint(f"Config: {_config}") - - def log_metrics(self, step: int, metric_dict: Mapping[str, Any]) -> None: - """Log metrics to the console.""" - pprint(f"Metrics at step {step}: {metric_dict}") - - def log_params(self, param_dict: Mapping[str, Any]) -> None: - """Log params to the console.""" - pprint(f"Params: {param_dict}") - - def log_tags(self, tag_dict: Mapping[str, Any]) -> None: - """Log tags to the console.""" - pprint(f"Tags: {tag_dict}") - - def log_figure( - self, - figure: matplotlib.figure.Figure | plotly.graph_objects.Figure, - artifact_file: str, - **kwargs, - ) -> None: - """Log figure info to the console (no actual figure saved).""" - print(f"[PrintLogger] Figure NOT saved - would be: {artifact_file} (type: {type(figure).__name__})") - - def log_image( - self, - image: numpy.ndarray | PIL.Image.Image | mlflow.Image, - artifact_file: str | None = None, - key: str | None = None, - step: int | None = None, - **kwargs, - ) -> None: - """Log image info to the console (no actual image saved).""" - # Parameter validation - ensure we have either artifact_file or (key + step) - if not artifact_file and not (key and step is not None): - print("[PrintLogger] Image logging failed - need either artifact_file or (key + step)") - return - - if artifact_file: - print(f"[PrintLogger] Image NOT saved - would be artifact: {artifact_file} (type: {type(image).__name__})") - else: - print(f"[PrintLogger] Image NOT saved - would be key: {key}, step: {step} (type: {type(image).__name__})") - - def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None: - """Print artifact info to the console (no actual artifact logged).""" - dest_name = artifact_path if artifact_path else f"" - print(f"[PrintLogger] Artifact NOT logged - would copy: {local_path} -> {dest_name}") - - def log_json_artifact(self, data: dict | list, artifact_name: str) -> None: - """Print JSON artifact info to the console (no actual artifact saved).""" - data_type = "dict" if isinstance(data, dict) else "list" - data_size = len(data) - print(f"[PrintLogger] JSON artifact NOT saved - would be: {artifact_name} ({data_type} with {data_size} items)") - - def close(self) -> None: - """Close the logger.""" - pass diff --git a/simplexity/persistence/local_persister.py b/simplexity/persistence/local_persister.py deleted file mode 100644 index 7c5f4fdd..00000000 --- a/simplexity/persistence/local_persister.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Local persister.""" - -# pylint: disable-all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any - - -class LocalPersister(ABC): - """Persists a model to the local filesystem.""" - - directory: Path - - def cleanup(self) -> None: # noqa: B027 - """Cleans up the persister.""" - - @abstractmethod - def save_weights(self, model: Any, step: int = 0) -> None: - """Saves a model.""" - - @abstractmethod - def load_weights(self, model: Any, step: int = 0) -> Any: - """Load weights into an existing model instance.""" diff --git a/simplexity/persistence/model_persister.py b/simplexity/persistence/model_persister.py deleted file mode 100644 index 0707d87e..00000000 --- a/simplexity/persistence/model_persister.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Model persister protocol.""" - -from typing import Any, Protocol - - -class ModelPersister(Protocol): - """Persists a model to a file.""" - - def cleanup(self) -> None: - """Cleans up the persister.""" - - def save_weights(self, model: Any, step: int = 0) -> None: - """Saves a model.""" - - def load_weights(self, model: Any, step: int = 0) -> Any: - """Load weights into an existing model instance.""" - ... # pylint: disable=unnecessary-ellipsis diff --git a/simplexity/persistence/s3_persister.py b/simplexity/persistence/s3_persister.py deleted file mode 100644 index dcf05d10..00000000 --- a/simplexity/persistence/s3_persister.py +++ /dev/null @@ -1,169 +0,0 @@ -"""S3 persister for predictive models.""" - -import configparser -import tempfile -from collections.abc import Iterable, Mapping -from pathlib import Path -from typing import Any, Protocol - -import boto3.session -from botocore.exceptions import ClientError - -from simplexity.persistence.local_equinox_persister import LocalEquinoxPersister -from simplexity.persistence.local_penzai_persister import LocalPenzaiPersister -from simplexity.persistence.local_persister import LocalPersister -from simplexity.persistence.local_pytorch_persister import LocalPytorchPersister -from simplexity.predictive_models.types import ModelFramework - - -class S3Paginator(Protocol): - """Protocol for an S3 paginator. - - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#paginators - Since boto3 does not currently support type checking: https://github.com/boto/boto3/issues/1055 - """ - - def paginate(self, Bucket: str, Prefix: str) -> Iterable[Mapping[str, Any]]: # pylint: disable=invalid-name - """Paginate over the objects in an S3 bucket.""" - ... # pylint: disable=unnecessary-ellipsis - - -class S3Client(Protocol): - """Protocol for S3 client. - - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client - Since boto3 does not currently support type checking: https://github.com/boto/boto3/issues/1055 - """ - - def upload_file(self, file_name: str, bucket: str, object_name: str) -> None: - """Upload a file to S3.""" - - def download_file(self, bucket: str, object_name: str, file_name: str) -> None: - """Download a file from S3.""" - - def get_paginator(self, operation_name: str) -> S3Paginator: - """Get a paginator for the given operation.""" - ... # pylint: disable=unnecessary-ellipsis - - -class S3Persister: - """Persists a model to an S3 bucket.""" - - def __init__( - self, - bucket: str, - prefix: str, - s3_client: S3Client, - temp_dir: tempfile.TemporaryDirectory, - local_persister: LocalPersister, - ): - self.bucket = bucket - self.prefix = prefix - self.s3_client = s3_client - self.temp_dir = temp_dir - self.local_persister = local_persister - - @classmethod - def from_config( - cls, - prefix: str, - model_framework: ModelFramework = ModelFramework.EQUINOX, - config_filename: str = "config.ini", - ) -> "S3Persister": - """Creates a new S3Persister from configuration parameters. - - Args: - prefix: S3 prefix for model storage (from YAML config) - model_framework: Framework for local persistence - config_filename: Path to config.ini file containing AWS settings - """ - config = configparser.ConfigParser() - config.read(config_filename) - - bucket = config.get("s3", "bucket") - profile_name = config.get("aws", "profile_name", fallback="default") - session = boto3.session.Session(profile_name=profile_name) - s3_client = session.client("s3") - temp_dir = tempfile.TemporaryDirectory() - if model_framework == ModelFramework.EQUINOX: - local_persister = LocalEquinoxPersister(directory=temp_dir.name) - elif model_framework == ModelFramework.PENZAI: - local_persister = LocalPenzaiPersister(directory=temp_dir.name) - elif model_framework == ModelFramework.PYTORCH: - local_persister = LocalPytorchPersister(directory=temp_dir.name) - else: - raise ValueError(f"Unsupported model framework: {model_framework}") - - return cls( - bucket=bucket, - prefix=prefix, - s3_client=s3_client, # type: ignore - temp_dir=temp_dir, - local_persister=local_persister, - ) - - def cleanup(self) -> None: - """Cleans up the temporary directory.""" - self.temp_dir.cleanup() - - def save_weights(self, model: Any, step: int = 0) -> None: - """Saves a model to S3.""" - self.local_persister.save_weights(model, step) - directory = self.local_persister.directory / str(step) - self._upload_local_directory(directory) - - def load_weights(self, model: Any, step: int = 0) -> Any: - """Loads a model from S3.""" - self._download_s3_objects(step) - return self.local_persister.load_weights(model, step) - - def _upload_local_directory(self, directory: Path) -> None: - for root, _, files in directory.walk(): - for file in files: - file_path = root / file - relative_path = file_path.relative_to(directory.parent) - object_name = f"{self.prefix}/{relative_path}" - file_name = str(file_path) - self._upload_local_file(file_name, object_name) - - def _upload_local_file(self, file_name: str, object_name: str) -> None: - try: - self.s3_client.upload_file(file_name, self.bucket, object_name) - except ClientError as e: - error_code = e.response.get("Error", {}).get("Code", "Unknown") - if error_code == "NoSuchBucket": - raise RuntimeError(f"Bucket {self.bucket} does not exist") from e - elif error_code == "AccessDenied": - raise RuntimeError(f"Access denied to bucket {self.bucket}") from e - else: - raise RuntimeError(f"Failed to save {file_name} to S3: {e}") from e - except Exception as e: - raise RuntimeError(f"Unexpected error saving {file_name} to S3: {e}") from e - - def _download_s3_objects(self, step: int) -> None: - prefix = f"{self.prefix}/{step}" - paginator = self.s3_client.get_paginator("list_objects_v2") - for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix): - for obj in page.get("Contents", []): - object_name = obj["Key"] - relative_path = Path(object_name).relative_to(self.prefix) - file_name = str(self.local_persister.directory / relative_path) - self._download_s3_object(object_name, file_name) - - def _download_s3_object(self, object_name: str, file_name: str) -> None: - try: - local_path = Path(file_name) - local_path.parent.mkdir(parents=True, exist_ok=True) - self.s3_client.download_file(self.bucket, object_name, file_name) - except ClientError as e: - error_code = e.response.get("Error", {}).get("Code", "Unknown") - if error_code == "NoSuchKey": - raise RuntimeError(f"{file_name} not found in bucket {self.bucket}") from e - elif error_code == "NoSuchBucket": - raise RuntimeError(f"Bucket {self.bucket} does not exist") from e - elif error_code == "AccessDenied": - raise RuntimeError(f"Access denied to bucket {self.bucket}") from e - else: - raise RuntimeError(f"Failed to load {file_name} from S3: {e}") from e - except Exception as e: - raise RuntimeError(f"Unexpected error loading {file_name} from S3: {e}") from e diff --git a/simplexity/run_management/components.py b/simplexity/run_management/components.py index 4fb34192..f4ae1df5 100644 --- a/simplexity/run_management/components.py +++ b/simplexity/run_management/components.py @@ -14,35 +14,29 @@ from simplexity.activations.activation_tracker import ActivationTracker from simplexity.generative_processes.generative_process import GenerativeProcess -from simplexity.logging.logger import Logger from simplexity.metrics.metric_tracker import MetricTracker -from simplexity.persistence.model_persister import ModelPersister +from simplexity.tracking.tracker import RunTracker @dataclass class Components: """Components for the run.""" - loggers: dict[str, Logger] | None = None + run_trackers: dict[str, RunTracker] | None = None generative_processes: dict[str, GenerativeProcess] | None = None - persisters: dict[str, ModelPersister] | None = None predictive_models: dict[str, Any] | None = None # TODO: improve typing optimizers: dict[str, Any] | None = None # TODO: improve typing metric_trackers: dict[str, MetricTracker] | None = None activation_trackers: dict[str, ActivationTracker] | None = None - def get_logger(self, key: str | None = None) -> Logger | None: - """Get the logger.""" - return self._get_instance_by_key(self.loggers, key, "logger") + def get_run_tracker(self, key: str | None = None) -> RunTracker | None: + """Get the run tracker.""" + return self._get_instance_by_key(self.run_trackers, key, "run tracker") def get_generative_process(self, key: str | None = None) -> GenerativeProcess | None: """Get the generative process.""" return self._get_instance_by_key(self.generative_processes, key, "generative process") - def get_persister(self, key: str | None = None) -> ModelPersister | None: - """Get the persister.""" - return self._get_instance_by_key(self.persisters, key, "persister") - def get_predictive_model(self, key: str | None = None) -> Any | None: """Get the predictive model.""" return self._get_instance_by_key(self.predictive_models, key, "predictive model") diff --git a/simplexity/run_management/run_logging.py b/simplexity/run_management/run_logging.py index 5a5ea3ce..cafe182e 100644 --- a/simplexity/run_management/run_logging.py +++ b/simplexity/run_management/run_logging.py @@ -18,7 +18,7 @@ from hydra.core.hydra_config import HydraConfig from simplexity.logger import SIMPLEXITY_LOGGER -from simplexity.logging.logger import Logger +from simplexity.tracking.tracker import RunTracker from simplexity.utils.git_utils import get_git_info @@ -73,17 +73,17 @@ def _get_calling_file_path() -> str | None: return None -def log_git_info(logger: Logger) -> None: +def log_git_info(tracker: RunTracker) -> None: """Log git information for reproducibility. Logs git information for the main repository where training is running. """ tags = {f"git.main.{k}": v for k, v in get_git_info().items()} if tags: - logger.log_tags(tags) + tracker.log_tags(tags) -def log_environment_artifacts(logger: Logger) -> None: +def log_environment_artifacts(tracker: RunTracker) -> None: """Log environment configuration files as MLflow artifacts for reproducibility. Logs dependency lockfile, project configuration, and system information @@ -92,10 +92,10 @@ def log_environment_artifacts(logger: Logger) -> None: environment_objects = ["uv.lock", "pyproject.toml"] for obj in environment_objects: if Path(obj).exists(): - logger.log_artifact(str(obj), "environment") + tracker.log_artifact(str(obj), "environment") -def log_system_info(logger: Logger) -> None: +def log_system_info(tracker: RunTracker) -> None: """Generate and log system information as an artifact.""" with tempfile.TemporaryDirectory() as temp_dir: info_path = Path(temp_dir) / "system_info.txt" @@ -106,10 +106,10 @@ def log_system_info(logger: Logger) -> None: f.write(f"Machine: {platform.machine()}\n") f.write(f"Processor: {platform.processor()}\n") - logger.log_artifact(str(info_path), "environment") + tracker.log_artifact(str(info_path), "environment") -def log_hydra_artifacts(logger: Logger) -> None: +def log_hydra_artifacts(tracker: RunTracker) -> None: """Log Hydra artifacts for reproducibility.""" try: hydra_dir = Path(HydraConfig.get().runtime.output_dir) / ".hydra" @@ -120,15 +120,15 @@ def log_hydra_artifacts(logger: Logger) -> None: path = hydra_dir / artifact if path.exists(): try: - logger.log_artifact(str(path), artifact_path=".hydra") + tracker.log_artifact(str(path), artifact_path=".hydra") except Exception as e: SIMPLEXITY_LOGGER.warning("Failed to log Hydra artifact %s: %s", path, e) -def log_source_script(logger: Logger) -> None: +def log_source_script(tracker: RunTracker) -> None: """Log the source script for reproducibility.""" calling_file_path = _get_calling_file_path() if calling_file_path: - logger.log_artifact(calling_file_path, artifact_path="source") + tracker.log_artifact(calling_file_path, artifact_path="source") else: SIMPLEXITY_LOGGER.warning("Failed to log source script") diff --git a/simplexity/run_management/run_management.py b/simplexity/run_management/run_management.py index ed75a730..2536803f 100644 --- a/simplexity/run_management/run_management.py +++ b/simplexity/run_management/run_management.py @@ -1,7 +1,7 @@ """Run management utilities for orchestrating experiment setup and teardown. This module centralizes environment setup, configuration resolution, component -instantiation (logging, generative processes, models, optimizers), MLflow run +instantiation (tracking, generative processes, models, optimizers), MLflow run management, and cleanup via the `managed_run` decorator. """ @@ -32,10 +32,6 @@ from simplexity.generative_processes.generative_process import GenerativeProcess from simplexity.logger import SIMPLEXITY_LOGGER -from simplexity.logging.logger import Logger -from simplexity.logging.mlflow_logger import MLFlowLogger -from simplexity.persistence.mlflow_persister import MLFlowPersister -from simplexity.persistence.model_persister import ModelPersister from simplexity.run_management.components import Components from simplexity.run_management.run_logging import ( log_environment_artifacts, @@ -54,11 +50,6 @@ resolve_generative_process_config, validate_generative_process_config, ) -from simplexity.structured_configs.logging import ( - is_logger_target, - update_logging_instance_config, - validate_logging_config, -) from simplexity.structured_configs.metric_tracker import ( is_metric_tracker_target, validate_metric_tracker_config, @@ -69,16 +60,18 @@ is_pytorch_optimizer_config, validate_optimizer_config, ) -from simplexity.structured_configs.persistence import ( - is_model_persister_target, - update_persister_instance_config, - validate_persistence_config, -) from simplexity.structured_configs.predictive_model import ( is_hooked_transformer_config, is_predictive_model_target, resolve_hooked_transformer_config, ) +from simplexity.structured_configs.tracking import ( + is_run_tracker_target, + update_tracking_instance_config, + validate_tracking_config, +) +from simplexity.tracking.mlflow_tracker import MlflowTracker +from simplexity.tracking.tracker import RunTracker from simplexity.utils.config_utils import ( filter_instance_keys, get_config, @@ -257,39 +250,49 @@ def _setup_mlflow(cfg: DictConfig) -> mlflow.ActiveRun | nullcontext[None]: ) -def _instantiate_logger(cfg: DictConfig, instance_key: str) -> Logger: - """Setup the logging.""" +def _instantiate_tracker(cfg: DictConfig, instance_key: str) -> RunTracker: + """Setup the tracker.""" instance_config = OmegaConf.select(cfg, instance_key, throw_on_missing=True) if instance_config: - logger = typed_instantiate(instance_config, Logger) - SIMPLEXITY_LOGGER.info("[logging] instantiated logger: %s", logger.__class__.__name__) - if isinstance(logger, MLFlowLogger): - updated_cfg = OmegaConf.structured(logger.cfg) - update_logging_instance_config(instance_config, updated_cfg=updated_cfg) - return logger + tracker = typed_instantiate(instance_config, RunTracker) + SIMPLEXITY_LOGGER.info("[tracking] instantiated tracker: %s", tracker.__class__.__name__) + if isinstance(tracker, MlflowTracker): + updated_cfg = OmegaConf.structured(tracker.cfg) + update_tracking_instance_config(instance_config, updated_cfg=updated_cfg) + return tracker raise KeyError -def _setup_logging(cfg: DictConfig, instance_keys: list[str], *, strict: bool) -> dict[str, Logger] | None: +def _setup_tracking(cfg: DictConfig, instance_keys: list[str], *, strict: bool) -> dict[str, RunTracker] | None: instance_keys = filter_instance_keys( cfg, instance_keys, - is_logger_target, - validate_fn=validate_logging_config, - component_name="logging", + is_run_tracker_target, + validate_fn=validate_tracking_config, + component_name="tracking", ) if instance_keys: - loggers = {instance_key: _instantiate_logger(cfg, instance_key) for instance_key in instance_keys} + trackers = {instance_key: _instantiate_tracker(cfg, instance_key) for instance_key in instance_keys} if strict: - mlflow_loggers = [logger for logger in loggers.values() if isinstance(logger, MLFlowLogger)] - assert mlflow_loggers, "Logger must be an instance of MLFlowLogger" + mlflow_trackers = [tracker for tracker in trackers.values() if isinstance(tracker, MlflowTracker)] + assert mlflow_trackers, "No MLFlow trackers found" assert any( - logger.tracking_uri and logger.tracking_uri.startswith("databricks") for logger in mlflow_loggers + tracker.tracking_uri and tracker.tracking_uri.startswith("databricks") for tracker in mlflow_trackers ), "Tracking URI must start with 'databricks'" - return loggers - SIMPLEXITY_LOGGER.info("[logging] no logging configs found") + return trackers + SIMPLEXITY_LOGGER.info("[tracking] no tracking configs found") if strict: - raise ValueError(f"Config must contain 1 logger, {len(instance_keys)} found") + raise ValueError("No tracking configs found (strict mode requires at least one tracker)") + return None + + +def _get_tracker(trackers: dict[str, RunTracker] | None) -> RunTracker | None: + if trackers: + if len(trackers) == 1: + return next(iter(trackers.values())) + SIMPLEXITY_LOGGER.warning("[tracking] multiple trackers found, any model loading will be skipped") + return None + SIMPLEXITY_LOGGER.warning("[tracking] no trackers found, any model loading will be skipped") return None @@ -335,43 +338,6 @@ def _setup_generative_processes(cfg: DictConfig, instance_keys: list[str]) -> di return None -def _instantiate_persister(cfg: DictConfig, instance_key: str) -> ModelPersister: - """Setup the persister.""" - instance_config = OmegaConf.select(cfg, instance_key, throw_on_missing=True) - if instance_config: - persister: ModelPersister = hydra.utils.instantiate(instance_config) - SIMPLEXITY_LOGGER.info("[persister] instantiated persister: %s", persister.__class__.__name__) - if isinstance(persister, MLFlowPersister): - updated_cfg = OmegaConf.structured(persister.cfg) - update_persister_instance_config(instance_config, updated_cfg=updated_cfg) - return persister - raise KeyError - - -def _setup_persisters(cfg: DictConfig, instance_keys: list[str]) -> dict[str, ModelPersister] | None: - instance_keys = filter_instance_keys( - cfg, - instance_keys, - is_model_persister_target, - validate_fn=validate_persistence_config, - component_name="persistence", - ) - if instance_keys: - return {instance_key: _instantiate_persister(cfg, instance_key) for instance_key in instance_keys} - SIMPLEXITY_LOGGER.info("[persister] no persister configs found") - return None - - -def _get_persister(persisters: dict[str, ModelPersister] | None) -> ModelPersister | None: - if persisters: - if len(persisters) == 1: - return next(iter(persisters.values())) - SIMPLEXITY_LOGGER.warning("Multiple persisters found, any model model checkpoint loading will be skipped") - return None - SIMPLEXITY_LOGGER.warning("No persister found, any model checkpoint loading will be skipped") - return None - - def _get_attribute_value(cfg: DictConfig, instance_keys: list[str], attribute_name: str) -> int | None: """Get the vocab size.""" instance_keys = filter_instance_keys( @@ -413,18 +379,17 @@ def _instantiate_predictive_model(cfg: DictConfig, instance_key: str) -> Any: raise KeyError -def _load_checkpoint(model: Any, persisters: dict[str, ModelPersister] | None, load_checkpoint_step: int) -> None: +def _load_checkpoint(model: Any, trackers: dict[str, RunTracker] | None, load_checkpoint_step: int) -> None: """Load the checkpoint.""" - persister = _get_persister(persisters) - if persister: - persister.load_weights(model, load_checkpoint_step) - SIMPLEXITY_LOGGER.info("[predictive model] loaded checkpoint step: %s", load_checkpoint_step) - else: - raise RuntimeError("Unable to load model checkpoint") + tracker = _get_tracker(trackers) + if tracker is None: + raise RuntimeError("No trackers found to load model checkpoint") + tracker.load_model(model, load_checkpoint_step) + SIMPLEXITY_LOGGER.info("[predictive model] loaded checkpoint step: %s", load_checkpoint_step) def _setup_predictive_models( - cfg: DictConfig, instance_keys: list[str], persisters: dict[str, ModelPersister] | None + cfg: DictConfig, instance_keys: list[str], trackers: dict[str, RunTracker] | None ) -> dict[str, Any] | None: """Setup the predictive model.""" models = {} @@ -441,7 +406,7 @@ def _setup_predictive_models( step_key = instance_key.rsplit(".", 1)[0] + ".load_checkpoint_step" load_checkpoint_step: int | None = OmegaConf.select(cfg, step_key, throw_on_missing=True) if load_checkpoint_step is not None: - _load_checkpoint(model, persisters, load_checkpoint_step) + _load_checkpoint(model, trackers, load_checkpoint_step) models[instance_key] = model if models: return models @@ -576,21 +541,21 @@ def _setup_activation_trackers(cfg: DictConfig, instance_keys: list[str]) -> dic return None -def _do_logging(cfg: DictConfig, loggers: dict[str, Logger] | None, *, verbose: bool) -> None: - if loggers is None: +def _do_logging(cfg: DictConfig, trackers: dict[str, RunTracker] | None, *, verbose: bool) -> None: + if trackers is None: return - for logger in loggers.values(): - logger.log_config(cfg, resolve=True) - logger.log_params(cfg) - log_git_info(logger) - log_system_info(logger) + for tracker in trackers.values(): + tracker.log_config(cfg, resolve=True) + tracker.log_params(cfg) + log_git_info(tracker) + log_system_info(tracker) tags = cfg.get("tags", {}) if tags: - logger.log_tags(tags) + tracker.log_tags(tags) if verbose: - log_hydra_artifacts(logger) - log_environment_artifacts(logger) - log_source_script(logger) + log_hydra_artifacts(tracker) + log_environment_artifacts(tracker) + log_source_script(tracker) def _setup(cfg: DictConfig, strict: bool, verbose: bool) -> Components: @@ -603,27 +568,23 @@ def _setup(cfg: DictConfig, strict: bool, verbose: bool) -> Components: _set_random_seeds(cfg.get("seed", None)) components = Components() instance_keys = get_instance_keys(cfg) - components.loggers = _setup_logging(cfg, instance_keys, strict=strict) + components.run_trackers = _setup_tracking(cfg, instance_keys, strict=strict) components.generative_processes = _setup_generative_processes(cfg, instance_keys) - components.persisters = _setup_persisters(cfg, instance_keys) - components.predictive_models = _setup_predictive_models(cfg, instance_keys, components.persisters) + components.predictive_models = _setup_predictive_models(cfg, instance_keys, components.run_trackers) components.optimizers = _setup_optimizers(cfg, instance_keys, components.predictive_models) components.metric_trackers = _setup_metric_trackers( cfg, instance_keys, components.predictive_models, components.optimizers ) components.activation_trackers = _setup_activation_trackers(cfg, instance_keys) - _do_logging(cfg, components.loggers, verbose=verbose) + _do_logging(cfg, components.run_trackers, verbose=verbose) return components def _cleanup(components: Components) -> None: """Cleanup the run.""" - if components.loggers: - for logger in components.loggers.values(): - logger.close() - if components.persisters: - for persister in components.persisters.values(): - persister.cleanup() + if components.run_trackers: + for tracker in components.run_trackers.values(): + tracker.cleanup() def managed_run(strict: bool = True, verbose: bool = False) -> Callable[[Callable[..., Any]], Callable[..., Any]]: diff --git a/simplexity/structured_configs/base.py b/simplexity/structured_configs/base.py index 946b1043..0cc2c273 100644 --- a/simplexity/structured_configs/base.py +++ b/simplexity/structured_configs/base.py @@ -15,7 +15,7 @@ from simplexity.exceptions import ConfigValidationError from simplexity.logger import SIMPLEXITY_LOGGER -from simplexity.structured_configs.mlflow import MLFlowConfig, validate_mlflow_config +from simplexity.structured_configs.mlflow import MlflowConfig, validate_mlflow_config from simplexity.structured_configs.validation import validate_mapping, validate_non_negative_int, validate_nonempty_str from simplexity.utils.config_utils import dynamic_resolve @@ -27,7 +27,7 @@ class BaseConfig: device: str | None = None seed: int | None = None tags: dict[str, str] | None = None - mlflow: MLFlowConfig | None = None + mlflow: MlflowConfig | None = None def validate_base_config(cfg: DictConfig) -> None: @@ -49,7 +49,7 @@ def validate_base_config(cfg: DictConfig) -> None: validate_mapping(tags, "BaseConfig.tags", key_type=str, value_type=str, is_none_allowed=True) if mlflow is not None: if not isinstance(mlflow, DictConfig): - raise ConfigValidationError("BaseConfig.mlflow must be a MLFlowConfig") + raise ConfigValidationError("BaseConfig.mlflow must be a MlflowConfig") validate_mlflow_config(mlflow) diff --git a/simplexity/structured_configs/logging.py b/simplexity/structured_configs/logging.py deleted file mode 100644 index d7f1e4c2..00000000 --- a/simplexity/structured_configs/logging.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Logging configuration dataclasses.""" - -# pylint: disable=all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -from dataclasses import dataclass - -from omegaconf import DictConfig - -from simplexity.exceptions import ConfigValidationError -from simplexity.structured_configs.instance import InstanceConfig, validate_instance_config -from simplexity.structured_configs.validation import validate_bool, validate_nonempty_str, validate_uri -from simplexity.utils.config_utils import dynamic_resolve - - -@dataclass -class FileLoggerInstanceConfig(InstanceConfig): - """Configuration for FileLogger.""" - - file_path: str - - def __init__(self, file_path: str, _target_: str = "simplexity.logging.file_logger.FileLogger") -> None: - super().__init__(_target_=_target_) - self.file_path = file_path - - -def is_file_logger_target(target: str) -> bool: - """Check if the target is a file logger target.""" - return target == "simplexity.logging.file_logger.FileLogger" - - -def is_file_logger_config(cfg: DictConfig) -> bool: - """Check if the configuration is a FileLoggerInstanceConfig.""" - target = cfg.get("_target_", None) - if isinstance(target, str): - return is_file_logger_target(target) - return False - - -def validate_file_logger_instance_config(cfg: DictConfig) -> None: - """Validate a FileLoggerInstanceConfig. - - Args: - cfg: A DictConfig with FileLoggerInstanceConfig fields (from Hydra). - """ - file_path = cfg.get("file_path") - - validate_instance_config(cfg, expected_target="simplexity.logging.file_logger.FileLogger") - validate_nonempty_str(file_path, "FileLoggerInstanceConfig.file_path") - - -@dataclass -class MLFlowLoggerInstanceConfig(InstanceConfig): - """Configuration for MLFlowLogger.""" - - experiment_id: str | None = None - experiment_name: str | None = None - run_id: str | None = None - run_name: str | None = None - tracking_uri: str | None = None - registry_uri: str | None = None - downgrade_unity_catalog: bool = True - - -def is_mlflow_logger_target(target: str) -> bool: - """Check if the target is a mlflow logger target.""" - return target == "simplexity.logging.mlflow_logger.MLFlowLogger" - - -def is_mlflow_logger_config(cfg: DictConfig) -> bool: - """Check if the configuration is a MLFlowLoggerInstanceConfig.""" - target = cfg.get("_target_", None) - if isinstance(target, str): - return is_mlflow_logger_target(target) - return False - - -def validate_mlflow_logger_instance_config(cfg: DictConfig) -> None: - """Validate a MLFlowLoggerInstanceConfig. - - Args: - cfg: A DictConfig with MLFlowLoggerInstanceConfig fields (from Hydra). - """ - experiment_id = cfg.get("experiment_id") - experiment_name = cfg.get("experiment_name") - run_id = cfg.get("run_id") - run_name = cfg.get("run_name") - tracking_uri = cfg.get("tracking_uri") - registry_uri = cfg.get("registry_uri") - downgrade_unity_catalog = cfg.get("downgrade_unity_catalog") - - validate_instance_config(cfg, expected_target="simplexity.logging.mlflow_logger.MLFlowLogger") - validate_nonempty_str(experiment_id, "MLFlowLoggerInstanceConfig.experiment_id", is_none_allowed=True) - validate_nonempty_str(experiment_name, "MLFlowLoggerInstanceConfig.experiment_name", is_none_allowed=True) - validate_nonempty_str(run_id, "MLFlowLoggerInstanceConfig.run_id", is_none_allowed=True) - validate_nonempty_str(run_name, "MLFlowLoggerInstanceConfig.run_name", is_none_allowed=True) - validate_uri(tracking_uri, "MLFlowLoggerInstanceConfig.tracking_uri", is_none_allowed=True) - validate_uri(registry_uri, "MLFlowLoggerInstanceConfig.registry_uri", is_none_allowed=True) - validate_bool(downgrade_unity_catalog, "MLFlowLoggerInstanceConfig.downgrade_unity_catalog", is_none_allowed=True) - - -@dynamic_resolve -def update_logging_instance_config(cfg: DictConfig, updated_cfg: DictConfig) -> None: - """Update a LoggingInstanceConfig with the updated configuration.""" - cfg.merge_with(updated_cfg) - - -@dataclass -class LoggingConfig: - """Base configuration for logging.""" - - instance: InstanceConfig - name: str | None = None - - -def is_logger_target(target: str) -> bool: - """Check if the target is a logger target.""" - return target.startswith("simplexity.logging.") - - -def is_logger_config(cfg: DictConfig) -> bool: - """Check if the configuration is a LoggingInstanceConfig.""" - target = cfg.get("_target_", None) - if isinstance(target, str): - return is_logger_target(target) - return False - - -def validate_logging_config(cfg: DictConfig) -> None: - """Validate a LoggingConfig. - - Args: - cfg: A DictConfig with instance and optional name fields (from Hydra). - """ - instance = cfg.get("instance") - name = cfg.get("name") - - if not isinstance(instance, DictConfig): - raise ConfigValidationError("LoggingConfig.instance must be a DictConfig") - - if is_file_logger_config(instance): - validate_file_logger_instance_config(instance) - elif is_mlflow_logger_config(instance): - validate_mlflow_logger_instance_config(instance) - else: - validate_instance_config(instance) - if not is_logger_config(instance): - raise ConfigValidationError("LoggingConfig.instance must be a logger target") - validate_nonempty_str(name, "LoggingConfig.name", is_none_allowed=True) diff --git a/simplexity/structured_configs/mlflow.py b/simplexity/structured_configs/mlflow.py index f3b88403..ca532ee6 100644 --- a/simplexity/structured_configs/mlflow.py +++ b/simplexity/structured_configs/mlflow.py @@ -18,7 +18,7 @@ @dataclass -class MLFlowConfig: +class MlflowConfig: """Configuration for MLflow.""" experiment_id: str | None = None @@ -31,10 +31,10 @@ class MLFlowConfig: def validate_mlflow_config(cfg: DictConfig) -> None: - """Validate an MLFlowConfig. + """Validate an MlflowConfig. Args: - cfg: A DictConfig with MLFlowConfig fields (from Hydra). + cfg: A DictConfig with MlflowConfig fields (from Hydra). """ experiment_id = cfg.get("experiment_id") experiment_name = cfg.get("experiment_name") @@ -44,16 +44,16 @@ def validate_mlflow_config(cfg: DictConfig) -> None: registry_uri = cfg.get("registry_uri") downgrade_unity_catalog = cfg.get("downgrade_unity_catalog") - validate_nonempty_str(experiment_id, "MLFlowConfig.experiment_id", is_none_allowed=True) - validate_nonempty_str(experiment_name, "MLFlowConfig.experiment_name", is_none_allowed=True) - validate_nonempty_str(run_id, "MLFlowConfig.run_id", is_none_allowed=True) - validate_nonempty_str(run_name, "MLFlowConfig.run_name", is_none_allowed=True) - validate_bool(downgrade_unity_catalog, "MLFlowConfig.downgrade_unity_catalog", is_none_allowed=True) - validate_uri(tracking_uri, "MLFlowConfig.tracking_uri", is_none_allowed=True) - validate_uri(registry_uri, "MLFlowConfig.registry_uri", is_none_allowed=True) + validate_nonempty_str(experiment_id, "MlflowConfig.experiment_id", is_none_allowed=True) + validate_nonempty_str(experiment_name, "MlflowConfig.experiment_name", is_none_allowed=True) + validate_nonempty_str(run_id, "MlflowConfig.run_id", is_none_allowed=True) + validate_nonempty_str(run_name, "MlflowConfig.run_name", is_none_allowed=True) + validate_bool(downgrade_unity_catalog, "MlflowConfig.downgrade_unity_catalog", is_none_allowed=True) + validate_uri(tracking_uri, "MlflowConfig.tracking_uri", is_none_allowed=True) + validate_uri(registry_uri, "MlflowConfig.registry_uri", is_none_allowed=True) @dynamic_resolve def update_mlflow_config(cfg: DictConfig, updated_cfg: DictConfig) -> None: - """Update a MLFlowConfig with the updated configuration.""" + """Update a MlflowConfig with the updated configuration.""" cfg.merge_with(updated_cfg) diff --git a/simplexity/structured_configs/persistence.py b/simplexity/structured_configs/persistence.py deleted file mode 100644 index b9386338..00000000 --- a/simplexity/structured_configs/persistence.py +++ /dev/null @@ -1,283 +0,0 @@ -"""Persistence configuration dataclasses.""" - -# pylint: disable=all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -import re -from dataclasses import dataclass - -from omegaconf import DictConfig, OmegaConf - -from simplexity.exceptions import ConfigValidationError -from simplexity.structured_configs.instance import InstanceConfig, validate_instance_config -from simplexity.structured_configs.validation import validate_bool, validate_nonempty_str, validate_uri -from simplexity.utils.config_utils import dynamic_resolve - - -@dataclass -class LocalPersisterInstanceConfig(InstanceConfig): - """Configuration for the local persister.""" - - directory: str - - def __init__(self, directory: str, _target_: str = "simplexity.persistence.local_persister.LocalPersister"): - super().__init__(_target_=_target_) - self.directory = directory - - -def is_local_persister_config(cfg: DictConfig, framework: str | None = None) -> bool: - """Check if the configuration is a LocalPersisterInstanceConfig.""" - if framework is None: - file_pattern = "local_[a-z]+_persister" - class_pattern = "Local[A-Z][a-z]+Persister" - else: - file_pattern = f"local_{framework.lower()}_persister" - class_pattern = f"Local{framework.capitalize()}Persister" - target = OmegaConf.select(cfg, "_target_") - if not isinstance(target, str): - return False - return re.match(f"simplexity.persistence.{file_pattern}.{class_pattern}", target) is not None - - -def validate_local_persister_instance_config(cfg: DictConfig, framework: str | None = None) -> None: - """Validate a LocalPersisterInstanceConfig. - - Args: - cfg: A DictConfig with LocalPersisterInstanceConfig fields (from Hydra). - framework: The framework of the local persister. If None, the framework will be inferred from the target. - """ - target = cfg.get("_target_") - directory = cfg.get("directory") - - validate_instance_config(cfg) - if not is_local_persister_config(cfg, framework=framework): - class_name = f"Local{framework.capitalize()}Persister" if framework is not None else "LocalPersister" - raise ConfigValidationError(f"{class_name}InstanceConfig must be a local persister, got {target}") - validate_nonempty_str(directory, "LocalPersisterInstanceConfig.directory") - - -@dataclass -class LocalEquinoxPersisterInstanceConfig(LocalPersisterInstanceConfig): - """Configuration for the local equinox persister.""" - - filename: str = "model.eqx" - - def __init__( - self, - directory: str, - filename: str = "model.eqx", - _target_: str = "simplexity.persistence.local_equinox_persister.LocalEquinoxPersister", - ): - super().__init__(_target_=_target_, directory=directory) - self.filename = filename - - -def is_local_equinox_persister_config(cfg: DictConfig) -> bool: - """Check if the configuration is a LocalEquinoxPersisterInstanceConfig.""" - return is_local_persister_config(cfg, framework="equinox") - - -def validate_local_equinox_persister_instance_config(cfg: DictConfig) -> None: - """Validate a LocalEquinoxPersisterInstanceConfig. - - Args: - cfg: A DictConfig with LocalEquinoxPersisterInstanceConfig fields (from Hydra). - """ - filename = cfg.get("filename") - - validate_local_persister_instance_config(cfg, framework="equinox") - validate_nonempty_str(filename, "LocalEquinoxPersisterInstanceConfig.filename") - assert isinstance(filename, str) - if not filename.endswith(".eqx"): - raise ConfigValidationError("LocalEquinoxPersisterInstanceConfig.filename must end with .eqx, got {filename}") - - -@dataclass -class LocalPenzaiPersisterInstanceConfig(LocalPersisterInstanceConfig): - """Configuration for the local penzai persister.""" - - def __init__( - self, directory: str, _target_: str = "simplexity.persistence.local_penzai_persister.LocalPenzaiPersister" - ): - super().__init__(_target_=_target_, directory=directory) - - -def is_local_penzai_persister_config(cfg: DictConfig) -> bool: - """Check if the configuration is a LocalPenzaiPersisterInstanceConfig.""" - return is_local_persister_config(cfg, framework="penzai") - - -def validate_local_penzai_persister_instance_config(cfg: DictConfig) -> None: - """Validate a LocalPenzaiPersisterInstanceConfig. - - Args: - cfg: A DictConfig with LocalPenzaiPersisterInstanceConfig fields (from Hydra). - """ - validate_local_persister_instance_config(cfg, framework="penzai") - - -@dataclass -class LocalPytorchPersisterInstanceConfig(LocalPersisterInstanceConfig): - """Configuration for the local pytorch persister.""" - - filename: str = "model.pt" - - def __init__( - self, - directory: str, - filename: str = "model.pt", - _target_: str = "simplexity.persistence.local_pytorch_persister.LocalPytorchPersister", - ): - super().__init__(_target_=_target_, directory=directory) - self.filename = filename - - -def is_local_pytorch_persister_config(cfg: DictConfig) -> bool: - """Check if the configuration is a LocalPytorchPersisterInstanceConfig.""" - return is_local_persister_config(cfg, framework="pytorch") - - -def validate_local_pytorch_persister_instance_config(cfg: DictConfig) -> None: - """Validate a LocalPytorchPersisterInstanceConfig. - - Args: - cfg: A DictConfig with LocalPytorchPersisterInstanceConfig fields (from Hydra). - """ - filename = cfg.get("filename") - - validate_local_persister_instance_config(cfg, framework="pytorch") - validate_nonempty_str(filename, "LocalPytorchPersisterInstanceConfig.filename") - assert isinstance(filename, str) - if not filename.endswith(".pt"): - raise ConfigValidationError("LocalPytorchPersisterInstanceConfig.filename must end with .pt, got {filename}") - - -@dataclass -class MLFlowPersisterInstanceConfig(InstanceConfig): - """Configuration for the MLflow persister.""" - - experiment_id: str | None = None - experiment_name: str | None = None - run_id: str | None = None - run_name: str | None = None - tracking_uri: str | None = None - registry_uri: str | None = None - downgrade_unity_catalog: bool = True - artifact_path: str | None = "models" - config_path: str | None = "config.yaml" - - def __init__( - self, - experiment_id: str | None = None, - experiment_name: str | None = None, - run_id: str | None = None, - run_name: str | None = None, - tracking_uri: str | None = None, - registry_uri: str | None = None, - downgrade_unity_catalog: bool = True, - artifact_path: str | None = "models", - config_path: str | None = "config.yaml", - _target_: str = "simplexity.persistence.mlflow_persister.MLFlowPersister", - ): - super().__init__(_target_=_target_) - self.experiment_id = experiment_id - self.experiment_name = experiment_name - self.run_id = run_id - self.run_name = run_name - self.tracking_uri = tracking_uri - self.registry_uri = registry_uri - self.downgrade_unity_catalog = downgrade_unity_catalog - self.artifact_path = artifact_path - self.config_path = config_path - - -def is_mlflow_persister_config(cfg: DictConfig) -> bool: - """Check if the configuration is a MLFlowPersisterInstanceConfig.""" - return OmegaConf.select(cfg, "_target_") == "simplexity.persistence.mlflow_persister.MLFlowPersister" - - -def validate_mlflow_persister_instance_config(cfg: DictConfig) -> None: - """Validate a MLFlowPersisterInstanceConfig. - - Args: - cfg: A DictConfig with MLFlowPersisterInstanceConfig fields (from Hydra). - """ - validate_instance_config(cfg, expected_target="simplexity.persistence.mlflow_persister.MLFlowPersister") - experiment_id = cfg.get("experiment_id") - experiment_name = cfg.get("experiment_name") - run_id = cfg.get("run_id") - run_name = cfg.get("run_name") - tracking_uri = cfg.get("tracking_uri") - registry_uri = cfg.get("registry_uri") - downgrade_unity_catalog = cfg.get("downgrade_unity_catalog") - artifact_path = cfg.get("artifact_path") - config_path = cfg.get("config_path") - - validate_nonempty_str(experiment_id, "MLFlowPersisterInstanceConfig.experiment_id", is_none_allowed=True) - validate_nonempty_str(experiment_name, "MLFlowPersisterInstanceConfig.experiment_name", is_none_allowed=True) - validate_nonempty_str(run_id, "MLFlowPersisterInstanceConfig.run_id", is_none_allowed=True) - validate_nonempty_str(run_name, "MLFlowPersisterInstanceConfig.run_name", is_none_allowed=True) - validate_uri(tracking_uri, "MLFlowPersisterInstanceConfig.tracking_uri", is_none_allowed=True) - validate_uri(registry_uri, "MLFlowPersisterInstanceConfig.registry_uri", is_none_allowed=True) - validate_bool( - downgrade_unity_catalog, "MLFlowPersisterInstanceConfig.downgrade_unity_catalog", is_none_allowed=True - ) - validate_nonempty_str(artifact_path, "MLFlowPersisterInstanceConfig.artifact_path", is_none_allowed=True) - validate_nonempty_str(config_path, "MLFlowPersisterInstanceConfig.config_path", is_none_allowed=True) - - -@dynamic_resolve -def update_persister_instance_config(cfg: DictConfig, updated_cfg: DictConfig) -> None: - """Update a PersistenceConfig with the updated configuration.""" - cfg.merge_with(updated_cfg) - - -@dataclass -class PersistenceConfig: - """Base configuration for persistence.""" - - instance: InstanceConfig - name: str | None = None - - -def is_model_persister_target(target: str) -> bool: - """Check if the target is a model persister target.""" - return target.startswith("simplexity.persistence.") - - -def is_persister_config(cfg: DictConfig) -> bool: - """Check if the configuration is a PersistenceInstanceConfig.""" - target = cfg.get("_target_", None) - if isinstance(target, str): - return is_model_persister_target(target) - return False - - -def validate_persistence_config(cfg: DictConfig) -> None: - """Validate a PersistenceConfig. - - Args: - cfg: A DictConfig with instance and optional name fields (from Hydra). - """ - instance = cfg.get("instance") - if not isinstance(instance, DictConfig): - raise ConfigValidationError("PersistenceConfig.instance is required") - if is_local_equinox_persister_config(instance): - validate_local_equinox_persister_instance_config(instance) - elif is_local_penzai_persister_config(instance): - validate_local_penzai_persister_instance_config(instance) - elif is_local_pytorch_persister_config(instance): - validate_local_pytorch_persister_instance_config(instance) - elif is_mlflow_persister_config(instance): - validate_mlflow_persister_instance_config(instance) - else: - validate_instance_config(instance) - if not is_persister_config(instance): - raise ConfigValidationError("PersistenceConfig.instance must be a persister target") - validate_nonempty_str(cfg.get("name"), "PersistenceConfig.name", is_none_allowed=True) diff --git a/simplexity/structured_configs/tracking.py b/simplexity/structured_configs/tracking.py new file mode 100644 index 00000000..2293ba74 --- /dev/null +++ b/simplexity/structured_configs/tracking.py @@ -0,0 +1,225 @@ +"""Tracking configuration dataclasses.""" + +# pylint: disable=all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +from dataclasses import dataclass + +from omegaconf import DictConfig + +from simplexity.exceptions import ConfigValidationError +from simplexity.structured_configs.instance import InstanceConfig, validate_instance_config +from simplexity.structured_configs.validation import validate_bool, validate_nonempty_str, validate_uri +from simplexity.utils.config_utils import dynamic_resolve + + +@dataclass +class FileTrackerInstanceConfig(InstanceConfig): + """Configuration for FileTracker.""" + + file_path: str + model_dir_name: str = "models" + + def __init__( + self, + file_path: str, + model_dir_name: str = "models", + _target_: str = "simplexity.tracking.file_tracker.FileTracker", + ) -> None: + super().__init__(_target_=_target_) + self.file_path = file_path + self.model_dir_name = model_dir_name + + +def is_file_tracker_target(target: str) -> bool: + """Check if the target is a file tracker target.""" + return target == "simplexity.tracking.file_tracker.FileTracker" + + +def is_file_tracker_config(cfg: DictConfig) -> bool: + """Check if the configuration is a FileTrackerInstanceConfig.""" + target = cfg.get("_target_", None) + if isinstance(target, str): + return is_file_tracker_target(target) + return False + + +def validate_file_tracker_instance_config(cfg: DictConfig) -> None: + """Validate a FileTrackerInstanceConfig.""" + file_path = cfg.get("file_path") + model_dir_name = cfg.get("model_dir_name") + + validate_instance_config(cfg, expected_target="simplexity.tracking.file_tracker.FileTracker") + validate_nonempty_str(file_path, "FileTrackerInstanceConfig.file_path") + validate_nonempty_str(model_dir_name, "FileTrackerInstanceConfig.model_dir_name", is_none_allowed=True) + + +@dataclass +class MlflowTrackerInstanceConfig(InstanceConfig): + """Configuration for MlflowTracker.""" + + experiment_id: str | None = None + experiment_name: str | None = None + run_id: str | None = None + run_name: str | None = None + tracking_uri: str | None = None + registry_uri: str | None = None + downgrade_unity_catalog: bool = True + model_dir: str = "models" + config_path: str = "config.yaml" + + def __init__( + self, + experiment_id: str | None = None, + experiment_name: str | None = None, + run_id: str | None = None, + run_name: str | None = None, + tracking_uri: str | None = None, + registry_uri: str | None = None, + downgrade_unity_catalog: bool = True, + model_dir: str = "models", + config_path: str = "config.yaml", + _target_: str = "simplexity.tracking.mlflow_tracker.MlflowTracker", + ) -> None: + super().__init__(_target_=_target_) + self.experiment_id = experiment_id + self.experiment_name = experiment_name + self.run_id = run_id + self.run_name = run_name + self.tracking_uri = tracking_uri + self.registry_uri = registry_uri + self.downgrade_unity_catalog = downgrade_unity_catalog + self.model_dir = model_dir + self.config_path = config_path + + +def is_mlflow_tracker_target(target: str) -> bool: + """Check if the target is a mlflow tracker target.""" + return target == "simplexity.tracking.mlflow_tracker.MlflowTracker" + + +def is_mlflow_tracker_config(cfg: DictConfig) -> bool: + """Check if the configuration is a MlflowTrackerInstanceConfig.""" + target = cfg.get("_target_", None) + if isinstance(target, str): + return is_mlflow_tracker_target(target) + return False + + +def validate_mlflow_tracker_instance_config(cfg: DictConfig) -> None: + """Validate a MlflowTrackerInstanceConfig.""" + experiment_id = cfg.get("experiment_id") + experiment_name = cfg.get("experiment_name") + run_id = cfg.get("run_id") + run_name = cfg.get("run_name") + tracking_uri = cfg.get("tracking_uri") + registry_uri = cfg.get("registry_uri") + downgrade_unity_catalog = cfg.get("downgrade_unity_catalog") + model_dir = cfg.get("model_dir") + config_path = cfg.get("config_path") + + validate_instance_config(cfg, expected_target="simplexity.tracking.mlflow_tracker.MlflowTracker") + validate_nonempty_str(experiment_id, "MlflowTrackerInstanceConfig.experiment_id", is_none_allowed=True) + validate_nonempty_str(experiment_name, "MlflowTrackerInstanceConfig.experiment_name", is_none_allowed=True) + validate_nonempty_str(run_id, "MlflowTrackerInstanceConfig.run_id", is_none_allowed=True) + validate_nonempty_str(run_name, "MlflowTrackerInstanceConfig.run_name", is_none_allowed=True) + validate_uri(tracking_uri, "MlflowTrackerInstanceConfig.tracking_uri", is_none_allowed=True) + validate_uri(registry_uri, "MlflowTrackerInstanceConfig.registry_uri", is_none_allowed=True) + validate_bool(downgrade_unity_catalog, "MlflowTrackerInstanceConfig.downgrade_unity_catalog", is_none_allowed=True) + validate_nonempty_str(model_dir, "MlflowTrackerInstanceConfig.model_dir", is_none_allowed=True) + validate_nonempty_str(config_path, "MlflowTrackerInstanceConfig.config_path", is_none_allowed=True) + + +@dataclass +class S3TrackerInstanceConfig(InstanceConfig): + """Configuration for S3Tracker (from_config factory).""" + + prefix: str + config_filename: str = "config.ini" + + def __init__( + self, + prefix: str, + config_filename: str = "config.ini", + _target_: str = "simplexity.tracking.s3_tracker.S3Tracker.from_config", + ) -> None: + super().__init__(_target_=_target_) + self.prefix = prefix + self.config_filename = config_filename + + +def is_s3_tracker_target(target: str) -> bool: + """Check if the target is a s3 tracker target.""" + return target == "simplexity.tracking.s3_tracker.S3Tracker.from_config" + + +def is_s3_tracker_config(cfg: DictConfig) -> bool: + """Check if the configuration is a S3TrackerInstanceConfig.""" + target = cfg.get("_target_", None) + if isinstance(target, str): + return is_s3_tracker_target(target) + return False + + +def validate_s3_tracker_instance_config(cfg: DictConfig) -> None: + """Validate a S3TrackerInstanceConfig.""" + prefix = cfg.get("prefix") + config_filename = cfg.get("config_filename") + + validate_instance_config(cfg, expected_target="simplexity.tracking.s3_tracker.S3Tracker.from_config") + validate_nonempty_str(prefix, "S3TrackerInstanceConfig.prefix") + validate_nonempty_str(config_filename, "S3TrackerInstanceConfig.config_filename") + + +@dynamic_resolve +def update_tracking_instance_config(cfg: DictConfig, updated_cfg: DictConfig) -> None: + """Update a TrackingInstanceConfig with the updated configuration.""" + cfg.merge_with(updated_cfg) + + +@dataclass +class TrackingConfig: + """Base configuration for tracking.""" + + instance: InstanceConfig + name: str | None = None + + +def is_run_tracker_target(target: str) -> bool: + """Check if the target is a run tracker target.""" + return target.startswith("simplexity.tracking.") + + +def is_run_tracker_config(cfg: DictConfig) -> bool: + """Check if the configuration is a TrackingInstanceConfig.""" + target = cfg.get("_target_", None) + if isinstance(target, str): + return is_run_tracker_target(target) + return False + + +def validate_tracking_config(cfg: DictConfig) -> None: + """Validate a TrackingConfig.""" + instance = cfg.get("instance") + name = cfg.get("name") + + if not isinstance(instance, DictConfig): + raise ConfigValidationError("TrackingConfig.instance must be a DictConfig") + + if is_file_tracker_config(instance): + validate_file_tracker_instance_config(instance) + elif is_mlflow_tracker_config(instance): + validate_mlflow_tracker_instance_config(instance) + elif is_s3_tracker_config(instance): + validate_s3_tracker_instance_config(instance) + else: + validate_instance_config(instance) + if not is_run_tracker_config(instance): + raise ConfigValidationError("TrackingConfig.instance must be a tracker target") + validate_nonempty_str(name, "TrackingConfig.name", is_none_allowed=True) diff --git a/simplexity/logging/file_logger.py b/simplexity/tracking/file_tracker.py similarity index 65% rename from simplexity/logging/file_logger.py rename to simplexity/tracking/file_tracker.py index b11791fc..0a49bf5d 100644 --- a/simplexity/logging/file_logger.py +++ b/simplexity/tracking/file_tracker.py @@ -1,4 +1,4 @@ -"""FileLogger class for logging to a file.""" +"""File tracker.""" # pylint: disable-all # Temporarily disable all pylint checkers during AST traversal to prevent crash. @@ -22,19 +22,35 @@ import plotly.graph_objects from omegaconf import DictConfig, OmegaConf -from simplexity.logging.logger import Logger +from simplexity.predictive_models.types import ModelFramework, get_model_framework +from simplexity.tracking.model_persistence.local_model_persister import ( + LocalModelPersister, +) +from simplexity.tracking.tracker import RunTracker +from simplexity.tracking.utils import build_local_persister -class FileLogger(Logger): - """Logs to a file.""" +def _clear_subdirectory(subdirectory: Path) -> None: + if subdirectory.exists(): + shutil.rmtree(subdirectory) + subdirectory.parent.mkdir(parents=True, exist_ok=True) - def __init__(self, file_path: str): - self.file_path = file_path + +class FileTracker(RunTracker): + """Tracks runs to a file/directory.""" + + def __init__(self, file_path: str, model_dir_name: str = "models"): + self.file_path = Path(file_path) try: - Path(self.file_path).parent.mkdir(parents=True, exist_ok=True) + self.file_path.parent.mkdir(parents=True, exist_ok=True) except PermissionError as e: raise RuntimeError(f"Failed to create directory for logging: {e}") from e + # Model persistence + self._model_dir = self.file_path.parent / model_dir_name + self._model_dir.mkdir(parents=True, exist_ok=True) + self._local_persisters: dict[ModelFramework, LocalModelPersister] = {} + def log_config(self, config: DictConfig, resolve: bool = False) -> None: """Log config to the file.""" with open(self.file_path, "a") as f: @@ -63,7 +79,7 @@ def log_figure( **kwargs, ) -> None: """Save figure to file system.""" - figure_path = Path(self.file_path).parent / artifact_file + figure_path = self.file_path.parent / artifact_file figure_path.parent.mkdir(parents=True, exist_ok=True) # Handle different figure types @@ -122,16 +138,16 @@ def log_image( if artifact_file: # Artifact mode - image_path = Path(self.file_path).parent / artifact_file + image_path = self.file_path.parent / artifact_file image_path.parent.mkdir(parents=True, exist_ok=True) if self._save_image_to_path(image, image_path, **kwargs): with open(self.file_path, "a") as f: print(f"Image saved: {image_path}", file=f) else: - # Time-stepped mode (we know key and step are valid due to validation above) + # Time-stepped mode filename = f"{key}_step_{step}.png" - image_path = Path(self.file_path).parent / filename + image_path = self.file_path.parent / filename image_path.parent.mkdir(parents=True, exist_ok=True) if self._save_image_to_path(image, image_path, **kwargs): @@ -141,7 +157,7 @@ def log_image( def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None: """Copy artifact to the log directory.""" source_path = Path(local_path) - dest_path = Path(self.file_path).parent / (artifact_path or source_path.name) + dest_path = self.file_path.parent / (artifact_path or source_path.name) dest_path.parent.mkdir(parents=True, exist_ok=True) if source_path.is_file(): @@ -154,7 +170,7 @@ def log_artifact(self, local_path: str, artifact_path: str | None = None) -> Non def log_json_artifact(self, data: dict | list, artifact_name: str) -> None: """Save JSON data as an artifact to the log directory.""" - json_path = Path(self.file_path).parent / artifact_name + json_path = self.file_path.parent / artifact_name json_path.parent.mkdir(parents=True, exist_ok=True) with open(json_path, "w") as f: @@ -163,6 +179,39 @@ def log_json_artifact(self, data: dict | list, artifact_name: str) -> None: with open(self.file_path, "a") as f: print(f"JSON artifact saved: {json_path}", file=f) - def close(self) -> None: - """Close the logger.""" - pass + def cleanup(self) -> None: + """Cleanup resources.""" + for persister in self._local_persisters.values(): + persister.cleanup() + + # Persistence + + def save_model(self, model: Any, step: int = 0) -> None: + """Save a model to the file system.""" + local_persister = self.get_local_persister(model) + step_dir = local_persister.directory / str(step) + _clear_subdirectory(step_dir) + local_persister.save_weights(model, step) + # Note: Local persisters already save to the model_dir which is under file_path.parent + # So we just need to ensure the local persister is built with the right root. + + def load_model(self, model: Any, step: int = 0) -> Any: + """Load a model from the file system.""" + local_persister = self.get_local_persister(model) + return local_persister.load_weights(model, step) + + def get_local_persister(self, model: Any) -> LocalModelPersister: + """Get the local persister for the given model.""" + model_framework = get_model_framework(model) + if model_framework not in self._local_persisters: + self._local_persisters[model_framework] = build_local_persister(model_framework, self._model_dir) + return self._local_persisters[model_framework] + + # Model Registry (Not supported) + def save_model_to_registry(self, model: Any, registered_model_name: str, **kwargs) -> Any: + """Save a model to the registry (Not Supported).""" + raise NotImplementedError("FileTracker does not support model registry.") + + def load_model_from_registry(self, registered_model_name: str, **kwargs) -> Any: + """Load a model from the registry (Not Supported).""" + raise NotImplementedError("FileTracker does not support model registry.") diff --git a/simplexity/persistence/mlflow_persister.py b/simplexity/tracking/mlflow_tracker.py similarity index 65% rename from simplexity/persistence/mlflow_persister.py rename to simplexity/tracking/mlflow_tracker.py index 02a363d3..db0327a8 100644 --- a/simplexity/persistence/mlflow_persister.py +++ b/simplexity/tracking/mlflow_tracker.py @@ -1,4 +1,4 @@ -"""MLflow-backed model persistence utilities.""" +"""MLFlow tracker.""" # pylint: disable-all # Temporarily disable all pylint checkers during AST traversal to prevent crash. @@ -9,26 +9,37 @@ # (code quality, style, undefined names, etc.) to run normally while bypassing # the problematic imports checker that would crash during AST traversal. -from __future__ import annotations - +import json +import os import shutil import tempfile +import time +from collections.abc import Mapping from contextlib import nullcontext from pathlib import Path from typing import Any +import dotenv +import matplotlib.figure import mlflow import mlflow.pytorch as mlflow_pytorch +import numpy +import PIL.Image +import plotly.graph_objects import torch +from mlflow.entities import Metric, Param, RunTag from mlflow.models.model import ModelInfo from mlflow.models.signature import infer_signature from omegaconf import DictConfig, OmegaConf from simplexity.logger import SIMPLEXITY_LOGGER -from simplexity.persistence.local_persister import LocalPersister from simplexity.predictive_models.types import ModelFramework, get_model_framework -from simplexity.structured_configs.persistence import MLFlowPersisterInstanceConfig -from simplexity.utils.config_utils import typed_instantiate +from simplexity.structured_configs.tracking import MlflowTrackerInstanceConfig +from simplexity.tracking.model_persistence.local_model_persister import ( + LocalModelPersister, +) +from simplexity.tracking.tracker import RunTracker +from simplexity.tracking.utils import build_local_persister from simplexity.utils.mlflow_utils import ( get_experiment, get_run, @@ -39,40 +50,17 @@ from simplexity.utils.pip_utils import create_requirements_file -def _build_local_persister(model_framework: ModelFramework, artifact_dir: Path) -> LocalPersister: - if model_framework == ModelFramework.EQUINOX: - from simplexity.persistence.local_equinox_persister import ( # pylint: disable=import-outside-toplevel - LocalEquinoxPersister, - ) - - directory = artifact_dir / "equinox" - return LocalEquinoxPersister(directory=directory) - if model_framework == ModelFramework.PENZAI: - from simplexity.persistence.local_penzai_persister import ( # pylint: disable=import-outside-toplevel - LocalPenzaiPersister, - ) - - directory = artifact_dir / "penzai" - return LocalPenzaiPersister(directory=directory) - if model_framework == ModelFramework.PYTORCH: - from simplexity.persistence.local_pytorch_persister import ( # pylint: disable=import-outside-toplevel - LocalPytorchPersister, - ) - - directory = artifact_dir / "pytorch" - return LocalPytorchPersister(directory=directory) - - raise ValueError(f"Unsupported model framework: {model_framework}") - - def _clear_subdirectory(subdirectory: Path) -> None: if subdirectory.exists(): shutil.rmtree(subdirectory) subdirectory.parent.mkdir(parents=True, exist_ok=True) -class MLFlowPersister: # pylint: disable=too-many-instance-attributes - """Persist model checkpoints as MLflow artifacts, optionally reusing an existing run.""" +dotenv.load_dotenv() + + +class MlflowTracker(RunTracker): # pylint: disable=too-many-instance-attributes + """Tracks runs to MLflow.""" def __init__( self, @@ -86,7 +74,7 @@ def __init__( model_dir: str = "models", config_path: str = "config.yaml", ): - """Create a persister from an MLflow experiment.""" + """Initialize MLflow tracker.""" self._downgrade_unity_catalog = downgrade_unity_catalog if downgrade_unity_catalog is not None else True resolved_registry_uri = resolve_registry_uri( registry_uri=registry_uri, @@ -102,12 +90,14 @@ def __init__( assert run is not None self._run_id = run.info.run_id self._run_name = run.info.run_name + + # Model persistence setup self._model_dir = model_dir.strip().strip("/") self._temp_dir = tempfile.TemporaryDirectory() self._model_path = Path(self._temp_dir.name) / self._model_dir if self._model_dir else Path(self._temp_dir.name) self._model_path.mkdir(parents=True, exist_ok=True) self._config_path = config_path - self._local_persisters = {} + self._local_persisters: dict[ModelFramework, LocalModelPersister] = {} @property def client(self) -> mlflow.MlflowClient: @@ -124,36 +114,36 @@ def experiment_id(self) -> str: """Expose active MLflow experiment identifier.""" return self._experiment_id - @property - def run_id(self) -> str: - """Expose active MLflow run identifier.""" - return self._run_id - @property def run_name(self) -> str | None: """Expose active MLflow run name.""" return self._run_name + @property + def run_id(self) -> str: + """Expose active MLflow run identifier.""" + return self._run_id + @property def tracking_uri(self) -> str | None: - """Return the tracking URI associated with this persister.""" + """Return the tracking URI associated with this tracker.""" return self.client.tracking_uri @property def registry_uri(self) -> str | None: - """Return the model registry URI associated with this persister.""" + """Return the model registry URI associated with this tracker.""" return self.client._registry_uri # pylint: disable=protected-access @property def model_dir(self) -> str: - """Return the artifact path associated with this persister.""" + """Return the artifact path associated with this tracker.""" return self._model_dir @property - def cfg(self) -> MLFlowPersisterInstanceConfig: - """Return the configuration of this persister.""" - return MLFlowPersisterInstanceConfig( - _target_=self.__class__.__qualname__, + def cfg(self) -> MlflowTrackerInstanceConfig: + """Return the configuration of this tracker.""" + return MlflowTrackerInstanceConfig( + _target_=f"simpexity.tracking.{self.__class__.__module__}.{self.__class__.__qualname__}", experiment_id=self.experiment_id, experiment_name=self.experiment_name, run_id=self.run_id, @@ -161,11 +151,116 @@ def cfg(self) -> MLFlowPersisterInstanceConfig: tracking_uri=self.tracking_uri, registry_uri=self.registry_uri, downgrade_unity_catalog=self._downgrade_unity_catalog, - artifact_path=self.model_dir, + model_dir=self.model_dir, config_path=self._config_path, ) - def save_weights(self, model: Any, step: int = 0) -> None: + # Lifecycle + + def cleanup(self) -> None: + """Remove temporary resources and optionally end the MLflow run.""" + for persister in self._local_persisters.values(): + persister.cleanup() + self._temp_dir.cleanup() + maybe_terminate_run(run_id=self.run_id, client=self.client) + + # Logging + + def log_config(self, config: DictConfig, resolve: bool = False) -> None: + """Log config to MLflow.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_path = os.path.join(temp_dir, "config.yaml") + OmegaConf.save(config, config_path, resolve=resolve) + self.client.log_artifact(self.run_id, config_path) + + def log_metrics(self, step: int, metric_dict: Mapping[str, Any]) -> None: + """Log metrics to MLflow.""" + timestamp = int(time.time() * 1000) + metrics = self._flatten_metric_dict(metric_dict, timestamp, step) + self._log_batch(metrics=metrics) + + def _flatten_metric_dict( + self, metric_dict: Mapping[str, Any], timestamp: int, step: int, key_prefix: str = "" + ) -> list[Metric]: + """Flatten a dictionary of metrics into a list of Metric entities.""" + metrics = [] + for key, value in metric_dict.items(): + key = f"{key_prefix}/{key}" if key_prefix else key + if isinstance(value, Mapping): + nested_metrics = self._flatten_metric_dict(value, timestamp, step, key_prefix=key) + metrics.extend(nested_metrics) + else: + value = float(value) + metric = Metric(key, value, timestamp, step) + metrics.append(metric) + return metrics + + def log_params(self, param_dict: Mapping[str, Any]) -> None: + """Log params to MLflow.""" + params = self._flatten_param_dict(param_dict) + self._log_batch(params=params) + + def _flatten_param_dict(self, param_dict: Mapping[str, Any], key_prefix: str = "") -> list[Param]: + """Flatten a dictionary of params into a list of Param entities.""" + params = [] + for key, value in param_dict.items(): + key = f"{key_prefix}.{key}" if key_prefix else key + if isinstance(value, Mapping): + nested_params = self._flatten_param_dict(value, key_prefix=key) + params.extend(nested_params) + else: + value = str(value) + param = Param(key, value) + params.append(param) + return params + + def log_tags(self, tag_dict: Mapping[str, Any]) -> None: + """Set tags on the MLFlow.""" + tags = [RunTag(k, str(v)) for k, v in tag_dict.items()] + self._log_batch(tags=tags) + + def log_figure( + self, + figure: matplotlib.figure.Figure | plotly.graph_objects.Figure, + artifact_file: str, + **kwargs, + ) -> None: + """Log a figure to MLflow using MLflowClient.log_figure.""" + self.client.log_figure(self.run_id, figure, artifact_file, **kwargs) + + def log_image( + self, + image: numpy.ndarray | PIL.Image.Image | mlflow.Image, + artifact_file: str | None = None, + key: str | None = None, + step: int | None = None, + **kwargs, + ) -> None: + """Log an image to MLflow using MLflowClient.log_image.""" + if not artifact_file and not (key and step is not None): + raise ValueError("Must provide either artifact_file or both key and step parameters") + + self.client.log_image(self.run_id, image, artifact_file=artifact_file, key=key, step=step, **kwargs) + + def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None: + """Log an artifact (file or directory) to MLflow.""" + self.client.log_artifact(self.run_id, local_path, artifact_path) + + def log_json_artifact(self, data: dict | list, artifact_name: str) -> None: + """Log a JSON object as an artifact to MLflow.""" + with tempfile.TemporaryDirectory() as temp_dir: + json_path = os.path.join(temp_dir, artifact_name) + with open(json_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + self.client.log_artifact(self.run_id, json_path) + + def _log_batch(self, **kwargs: Any) -> None: + """Log arbitrary data to MLflow.""" + self.client.log_batch(self.run_id, **kwargs, synchronous=False) + + # Persistence + + def save_model(self, model: Any, step: int = 0) -> None: """Serialize weights locally and upload them as MLflow artifacts.""" local_persister = self.get_local_persister(model) step_dir = local_persister.directory / str(step) @@ -174,7 +269,7 @@ def save_weights(self, model: Any, step: int = 0) -> None: framework_dir = step_dir.parent self.client.log_artifacts(self.run_id, str(framework_dir), artifact_path=self._model_dir) - def load_weights(self, model: Any, step: int = 0) -> Any: + def load_model(self, model: Any, step: int = 0) -> Any: """Download MLflow artifacts and restore them into the provided model.""" local_persister = self.get_local_persister(model) step_dir = local_persister.directory / str(step) @@ -189,38 +284,15 @@ def load_weights(self, model: Any, step: int = 0) -> Any: raise RuntimeError(f"MLflow artifact for step {step} was not found after download") return local_persister.load_weights(model, step) - def load_model(self, step: int = 0) -> Any: - """Load a model from a specified MLflow run and step.""" - config_path = self._config_path - - with tempfile.TemporaryDirectory() as temp_dir: - downloaded_config_path = self.client.download_artifacts( - self.run_id, - config_path, - dst_path=str(temp_dir), - ) - run_config = OmegaConf.load(downloaded_config_path) - - instance: DictConfig = OmegaConf.select(run_config, "predictive_model.instance", throw_on_missing=True) - target: str = OmegaConf.select(run_config, "predictive_model.instance._target_", throw_on_missing=True) - model = typed_instantiate(instance, target) - - return self.load_weights(model, step) - - def cleanup(self) -> None: - """Remove temporary resources and optionally end the MLflow run.""" - for persister in self._local_persisters.values(): - persister.cleanup() - self._temp_dir.cleanup() - maybe_terminate_run(run_id=self.run_id, client=self.client) - - def get_local_persister(self, model: Any) -> LocalPersister: + def get_local_persister(self, model: Any) -> LocalModelPersister: """Get the local persister for the given model.""" model_framework = get_model_framework(model) if model_framework not in self._local_persisters: - self._local_persisters[model_framework] = _build_local_persister(model_framework, self._model_path) + self._local_persisters[model_framework] = build_local_persister(model_framework, self._model_path) return self._local_persisters[model_framework] + # Model Registry + def save_model_to_registry( self, model: Any, @@ -228,19 +300,7 @@ def save_model_to_registry( model_inputs: torch.Tensor | None = None, **kwargs: Any, ) -> ModelInfo: - """Save a PyTorch model to the MLflow model registry. - - Args: - model: The PyTorch model to save. Must be a torch.nn.Module instance. - registered_model_name: The name to register the model under in the registry. - model_inputs: Optional model inputs (torch.Tensor) to use for inferring the model signature. - If provided, the signature will be automatically inferred. - **kwargs: Additional keyword arguments passed to mlflow.pytorch.log_model. - Can include 'signature' or 'pip_requirements' to override defaults. - - Raises: - ValueError: If the model is not a PyTorch model. - """ + """Save a PyTorch model to the MLflow model registry.""" if not isinstance(model, torch.nn.Module): raise ValueError(f"Model must be a PyTorch model (torch.nn.Module), got {type(model)}") @@ -293,14 +353,7 @@ def save_model_to_registry( def registered_model_uri( self, registered_model_name: str, version: str | None = None, stage: str | None = None ) -> str: - """Get the URI for a registered model. - - Args: - registered_model_name: The name of the registered model. - version: Optional specific version to load (e.g., "1", "2"). If None, loads the latest version. - stage: Optional stage to load from (e.g., "Production", "Staging", "Archived"). - If provided, takes precedence over version. - """ + """Get the URI for a registered model.""" prefix = "models:" if version is not None and stage is not None: raise ValueError("Cannot specify both version and stage. Use one or the other.") @@ -317,31 +370,6 @@ def registered_model_uri( latest_version = model_versions[0].version return f"{prefix}/{registered_model_name}/{latest_version}" - def load_model_from_registry( - self, - registered_model_name: str, - version: str | None = None, - stage: str | None = None, - ) -> Any: - """Load a PyTorch model from the MLflow model registry. - - Args: - registered_model_name: The name of the registered model. - version: Optional specific version to load (e.g., "1", "2"). If None, loads the latest version. - stage: Optional stage to load from (e.g., "Production", "Staging", "Archived"). - If provided, takes precedence over version. - - Returns: - The loaded PyTorch model. - - Raises: - ValueError: If both version and stage are provided. - RuntimeError: If the model cannot be found or loaded. - """ - model_uri = self.registered_model_uri(registered_model_name, version, stage) - with set_mlflow_uris(tracking_uri=self.tracking_uri, registry_uri=self.registry_uri): - return mlflow_pytorch.load_model(model_uri) - def list_model_versions( self, registered_model_name: str, @@ -375,3 +403,15 @@ def list_model_versions( } for mv in model_versions ] + + def load_model_from_registry( + self, + registered_model_name: str, + version: str | None = None, + stage: str | None = None, + **kwargs: Any, # pylint: disable=unused-argument + ) -> Any: + """Load a PyTorch model from the MLflow model registry.""" + model_uri = self.registered_model_uri(registered_model_name, version, stage) + with set_mlflow_uris(tracking_uri=self.tracking_uri, registry_uri=self.registry_uri): + return mlflow_pytorch.load_model(model_uri) diff --git a/simplexity/persistence/local_equinox_persister.py b/simplexity/tracking/model_persistence/local_equinox_persister.py similarity index 90% rename from simplexity/persistence/local_equinox_persister.py rename to simplexity/tracking/model_persistence/local_equinox_persister.py index 039cb36e..778ea2eb 100644 --- a/simplexity/persistence/local_equinox_persister.py +++ b/simplexity/tracking/model_persistence/local_equinox_persister.py @@ -13,10 +13,12 @@ import equinox as eqx -from simplexity.persistence.local_persister import LocalPersister +from simplexity.tracking.model_persistence.local_model_persister import ( + LocalModelPersister, +) -class LocalEquinoxPersister(LocalPersister): +class LocalEquinoxPersister(LocalModelPersister): """Persists a model to the local filesystem.""" filename: str = "model.eqx" diff --git a/simplexity/tracking/model_persistence/local_model_persister.py b/simplexity/tracking/model_persistence/local_model_persister.py new file mode 100644 index 00000000..10d397ef --- /dev/null +++ b/simplexity/tracking/model_persistence/local_model_persister.py @@ -0,0 +1,23 @@ +"""Local model persister protocol.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + + +class LocalModelPersister(ABC): + """Abstract base class for local model persisters.""" + + directory: Path + """Return the directory where the model is persisted.""" + + def cleanup(self) -> None: # noqa: B027 + """Cleans up the persister.""" + + @abstractmethod + def save_weights(self, model: Any, step: int = 0) -> None: + """Saves a model.""" + + @abstractmethod + def load_weights(self, model: Any, step: int = 0) -> Any: + """Load weights into an existing model instance.""" diff --git a/simplexity/persistence/local_penzai_persister.py b/simplexity/tracking/model_persistence/local_penzai_persister.py similarity index 93% rename from simplexity/persistence/local_penzai_persister.py rename to simplexity/tracking/model_persistence/local_penzai_persister.py index af94eebf..eb46268a 100644 --- a/simplexity/persistence/local_penzai_persister.py +++ b/simplexity/tracking/model_persistence/local_penzai_persister.py @@ -16,11 +16,13 @@ from penzai import pz from penzai.nn.layer import Layer as PenzaiModel -from simplexity.persistence.local_persister import LocalPersister +from simplexity.tracking.model_persistence.local_model_persister import ( + LocalModelPersister, +) from simplexity.utils.penzai_utils import deconstruct_variables, reconstruct_variables -class LocalPenzaiPersister(LocalPersister): +class LocalPenzaiPersister(LocalModelPersister): """Persists a model to the local filesystem.""" registry: DefaultCheckpointHandlerRegistry diff --git a/simplexity/persistence/local_pytorch_persister.py b/simplexity/tracking/model_persistence/local_pytorch_persister.py similarity index 91% rename from simplexity/persistence/local_pytorch_persister.py rename to simplexity/tracking/model_persistence/local_pytorch_persister.py index 778f899d..106907a0 100644 --- a/simplexity/persistence/local_pytorch_persister.py +++ b/simplexity/tracking/model_persistence/local_pytorch_persister.py @@ -13,10 +13,12 @@ import torch -from simplexity.persistence.local_persister import LocalPersister +from simplexity.tracking.model_persistence.local_model_persister import ( + LocalModelPersister, +) -class LocalPytorchPersister(LocalPersister): +class LocalPytorchPersister(LocalModelPersister): """Persists a PyTorch model to the local filesystem.""" filename: str = "model.pt" diff --git a/simplexity/tracking/s3_tracker.py b/simplexity/tracking/s3_tracker.py new file mode 100644 index 00000000..240c3cfa --- /dev/null +++ b/simplexity/tracking/s3_tracker.py @@ -0,0 +1,225 @@ +"""S3 tracker.""" + +# pylint: disable=all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +import configparser +import tempfile +from collections.abc import Iterable, Mapping +from pathlib import Path +from typing import Any, Protocol + +from omegaconf import DictConfig + +from simplexity.predictive_models.types import ModelFramework +from simplexity.tracking.model_persistence.local_model_persister import ( + LocalModelPersister, +) +from simplexity.tracking.tracker import RunTracker +from simplexity.tracking.utils import build_local_persister + + +class S3Paginator(Protocol): + """Protocol for an S3 paginator.""" + + def paginate(self, Bucket: str, Prefix: str) -> Iterable[Mapping[str, Any]]: # pylint: disable=invalid-name + """Paginate over the objects in an S3 bucket.""" + ... + + +class S3Client(Protocol): + """Protocol for S3 client.""" + + def upload_file(self, file_name: str, bucket: str, object_name: str) -> None: + """Upload a file to S3.""" + + def download_file(self, bucket: str, object_name: str, file_name: str) -> None: + """Download a file from S3.""" + + def get_paginator(self, operation_name: str) -> S3Paginator: + """Get a paginator for the given operation.""" + ... + + +class S3Tracker(RunTracker): + """Tracks runs to S3 (persistence only).""" + + def __init__( + self, + bucket: str, + prefix: str, + s3_client: S3Client, + temp_dir: tempfile.TemporaryDirectory, + local_persisters: dict[ModelFramework, LocalModelPersister] | None = None, + ): + self.bucket = bucket + self.prefix = prefix + self.s3_client = s3_client + self.temp_dir = temp_dir + self.local_persisters = local_persisters or {} + + @classmethod + def from_config( + cls, + prefix: str, + config_filename: str = "config.ini", + ) -> "S3Tracker": + """Creates a new S3Tracker from configuration parameters.""" + import boto3.session # pylint: disable=import-outside-toplevel + + config = configparser.ConfigParser() + config.read(config_filename) + + bucket = config.get("s3", "bucket") + profile_name = config.get("aws", "profile_name", fallback="default") + session = boto3.session.Session(profile_name=profile_name) + s3_client = session.client("s3") + temp_dir = tempfile.TemporaryDirectory() + + return cls( + bucket=bucket, + prefix=prefix, + s3_client=s3_client, # type: ignore + temp_dir=temp_dir, + ) + + # Lifecycle + + def cleanup(self) -> None: + """Cleans up the temporary directory.""" + self.temp_dir.cleanup() + + # Logging (Not Implemented / No-ops) + + def log_config(self, config: DictConfig, resolve: bool = False) -> None: + """Log config (Not Supported).""" + raise NotImplementedError("S3Tracker does not support logging config.") + + def log_metrics(self, step: int, metric_dict: Mapping[str, Any]) -> None: + """Log metrics (Not Supported).""" + raise NotImplementedError("S3Tracker does not support logging metrics.") + + def log_params(self, param_dict: Mapping[str, Any]) -> None: + """Log params (Not Supported).""" + raise NotImplementedError("S3Tracker does not support logging params.") + + def log_tags(self, tag_dict: Mapping[str, Any]) -> None: + """Log tags (Not Supported).""" + raise NotImplementedError("S3Tracker does not support logging tags.") + + def log_figure(self, figure: Any, artifact_file: str, **kwargs) -> None: + """Log figure (Not Supported).""" + raise NotImplementedError("S3Tracker does not support logging figures.") + + def log_image( + self, + image: Any, + artifact_file: str | None = None, + key: str | None = None, + step: int | None = None, + **kwargs, + ) -> None: + """Log image (Not Supported).""" + raise NotImplementedError("S3Tracker does not support logging images.") + + def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None: + """Log artifact (Not Supported).""" + raise NotImplementedError("S3Tracker does not support logging artifacts.") + + def log_json_artifact(self, data: dict | list, artifact_name: str) -> None: + """Log JSON artifact (Not Supported).""" + raise NotImplementedError("S3Tracker does not support logging JSON artifacts.") + + # Persistence + + def save_model(self, model: Any, step: int = 0) -> None: + """Saves a model to S3.""" + local_persister = self.get_local_persister(model) + local_persister.save_weights(model, step) + directory = local_persister.directory / str(step) + self._upload_local_directory(directory) + + def load_model(self, model: Any, step: int = 0) -> Any: + """Loads a model from S3.""" + local_persister = self.get_local_persister(model) + self._download_s3_objects(step, local_persister) + return local_persister.load_weights(model, step) + + def get_local_persister(self, model: Any) -> LocalModelPersister: + """Get the local persister for the given model.""" + from simplexity.predictive_models.types import get_model_framework + + model_framework = get_model_framework(model) + if model_framework not in self.local_persisters: + self.local_persisters[model_framework] = build_local_persister(model_framework, Path(self.temp_dir.name)) + return self.local_persisters[model_framework] + + def _upload_local_directory(self, directory: Path) -> None: + for root, _, files in directory.walk(): + for file in files: + file_path = root / file + relative_path = file_path.relative_to(directory.parent) + object_name = f"{self.prefix}/{relative_path}" + file_name = str(file_path) + self._upload_local_file(file_name, object_name) + + def _upload_local_file(self, file_name: str, object_name: str) -> None: + from botocore.exceptions import ClientError + + try: + self.s3_client.upload_file(file_name, self.bucket, object_name) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + if error_code == "NoSuchBucket": + raise RuntimeError(f"Bucket {self.bucket} does not exist") from e + elif error_code == "AccessDenied": + raise RuntimeError(f"Access denied to bucket {self.bucket}") from e + else: + raise RuntimeError(f"Failed to save {file_name} to S3: {e}") from e + except Exception as e: + raise RuntimeError(f"Unexpected error saving {file_name} to S3: {e}") from e + + def _download_s3_objects(self, step: int, local_persister: LocalModelPersister) -> None: + prefix = f"{self.prefix}/{step}" + paginator = self.s3_client.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix): + for obj in page.get("Contents", []): + object_name = obj["Key"] + relative_path = Path(object_name).relative_to(self.prefix) + file_name = str(local_persister.directory / relative_path) + self._download_s3_object(object_name, file_name) + + def _download_s3_object(self, object_name: str, file_name: str) -> None: + from botocore.exceptions import ClientError + + try: + local_path = Path(file_name) + local_path.parent.mkdir(parents=True, exist_ok=True) + self.s3_client.download_file(self.bucket, object_name, file_name) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + if error_code == "NoSuchKey": + raise RuntimeError(f"{file_name} not found in bucket {self.bucket}") from e + elif error_code == "NoSuchBucket": + raise RuntimeError(f"Bucket {self.bucket} does not exist") from e + elif error_code == "AccessDenied": + raise RuntimeError(f"Access denied to bucket {self.bucket}") from e + else: + raise RuntimeError(f"Failed to load {file_name} from S3: {e}") from e + except Exception as e: + raise RuntimeError(f"Unexpected error loading {file_name} from S3: {e}") from e + + # Model Registry (Not supported) + def save_model_to_registry(self, model: Any, registered_model_name: str, **kwargs) -> Any: + """Save a model to the registry (Not Supported).""" + raise NotImplementedError("S3Tracker does not support model registry.") + + def load_model_from_registry(self, registered_model_name: str, **kwargs) -> Any: + """Load a model from the registry (Not Supported).""" + raise NotImplementedError("S3Tracker does not support model registry.") diff --git a/simplexity/tracking/tracker.py b/simplexity/tracking/tracker.py new file mode 100644 index 00000000..8cee55eb --- /dev/null +++ b/simplexity/tracking/tracker.py @@ -0,0 +1,90 @@ +"""RunTracker protocol.""" + +# pylint: disable=unnecessary-ellipsis + +from collections.abc import Mapping +from typing import Any, Protocol, runtime_checkable + +import matplotlib.figure +import mlflow +import numpy +import PIL.Image +import plotly.graph_objects +from omegaconf import DictConfig + + +@runtime_checkable +class RunTracker(Protocol): + """Tracks run data (metrics, params, artifacts, models).""" + + # Lifecycle + def cleanup(self) -> None: + """Cleanup resources.""" + ... + + # Logging + def log_config(self, config: DictConfig, resolve: bool = False) -> None: + """Log config.""" + ... + + def log_metrics(self, step: int, metric_dict: Mapping[str, Any]) -> None: + """Log metrics.""" + ... + + def log_params(self, param_dict: Mapping[str, Any]) -> None: + """Log params.""" + ... + + def log_tags(self, tag_dict: Mapping[str, Any]) -> None: + """Log tags.""" + ... + + def log_figure( + self, + figure: matplotlib.figure.Figure | plotly.graph_objects.Figure, + artifact_file: str, + **kwargs, + ) -> None: + """Log a figure.""" + ... + + def log_image( + self, + image: numpy.ndarray | PIL.Image.Image | mlflow.Image, + artifact_file: str | None = None, + key: str | None = None, + step: int | None = None, + **kwargs, + ) -> None: + """Log an image.""" + ... + + def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None: + """Log an artifact (file or directory).""" + ... + + def log_json_artifact(self, data: dict | list, artifact_name: str) -> None: + """Log a JSON object as an artifact.""" + ... + + # Model Persistence + def save_model(self, model: Any, step: int = 0) -> None: + """Save a model.""" + ... + + def load_model(self, model: Any, step: int = 0) -> Any: + """Load a model.""" + ... + + # Model Registry (Optional) + def save_model_to_registry(self, model: Any, registered_model_name: str, **kwargs) -> Any: + """Save a model to the registry.""" + ... + + def load_model_from_registry(self, registered_model_name: str, **kwargs) -> Any: + """Load a model from the registry.""" + ... + + # Data Retrieval & Listing (Future Scope) + # def list_run_data(self) -> dict[str, Any]: ... + # def download_run_data(self, ...): ... diff --git a/simplexity/tracking/utils.py b/simplexity/tracking/utils.py new file mode 100644 index 00000000..ac737f9b --- /dev/null +++ b/simplexity/tracking/utils.py @@ -0,0 +1,35 @@ +"""Tracking utilities.""" + +from pathlib import Path + +from simplexity.predictive_models.types import ModelFramework +from simplexity.tracking.model_persistence.local_model_persister import ( + LocalModelPersister, +) + + +def build_local_persister(model_framework: ModelFramework, artifact_dir: Path) -> LocalModelPersister: + """Build a local persister.""" + if model_framework == ModelFramework.EQUINOX: + from simplexity.tracking.model_persistence.local_equinox_persister import ( # pylint: disable=import-outside-toplevel + LocalEquinoxPersister, + ) + + directory = artifact_dir / "equinox" + return LocalEquinoxPersister(directory=directory) + if model_framework == ModelFramework.PENZAI: + from simplexity.tracking.model_persistence.local_penzai_persister import ( # pylint: disable=import-outside-toplevel + LocalPenzaiPersister, + ) + + directory = artifact_dir / "penzai" + return LocalPenzaiPersister(directory=directory) + if model_framework == ModelFramework.PYTORCH: + from simplexity.tracking.model_persistence.local_pytorch_persister import ( # pylint: disable=import-outside-toplevel + LocalPytorchPersister, + ) + + directory = artifact_dir / "pytorch" + return LocalPytorchPersister(directory=directory) + + raise ValueError(f"Unsupported model framework: {model_framework}") diff --git a/tests/end_to_end/configs/persistence/mlflow_persister.yaml b/tests/end_to_end/configs/persistence/mlflow_persister.yaml deleted file mode 100644 index fd3755fb..00000000 --- a/tests/end_to_end/configs/persistence/mlflow_persister.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: mlflow_persister - -instance: - _target_: simplexity.persistence.mlflow_persister.MLFlowPersister - experiment_name: - run_name: - tracking_uri: - registry_uri: - model_dir: models - downgrade_unity_catalog: true diff --git a/tests/end_to_end/configs/logging/mlflow_logger.yaml b/tests/end_to_end/configs/tracking/mlflow_tracker.yaml similarity index 54% rename from tests/end_to_end/configs/logging/mlflow_logger.yaml rename to tests/end_to_end/configs/tracking/mlflow_tracker.yaml index 8ad567fb..4c4f3d5d 100644 --- a/tests/end_to_end/configs/logging/mlflow_logger.yaml +++ b/tests/end_to_end/configs/tracking/mlflow_tracker.yaml @@ -1,6 +1,6 @@ -name: mlflow_logger +name: mlflow_tracker instance: - _target_: simplexity.logging.mlflow_logger.MLFlowLogger + _target_: simplexity.tracking.mlflow_tracker.MlflowTracker experiment_name: run_name: tracking_uri: diff --git a/tests/end_to_end/configs/training.yaml b/tests/end_to_end/configs/training.yaml index de164936..693b4ce8 100644 --- a/tests/end_to_end/configs/training.yaml +++ b/tests/end_to_end/configs/training.yaml @@ -1,9 +1,8 @@ defaults: - _self_ - mlflow: databricks - - logging: mlflow_logger + - tracking: mlflow_tracker - generative_process: mess3 - - persistence: mlflow_persister - predictive_model: transformer - optimizer: pytorch_adam - metric_tracker@training_metric_tracker: default diff --git a/tests/end_to_end/configs/training_test.yaml b/tests/end_to_end/configs/training_test.yaml index 6a47cba9..7701a937 100644 --- a/tests/end_to_end/configs/training_test.yaml +++ b/tests/end_to_end/configs/training_test.yaml @@ -1,9 +1,8 @@ defaults: - _self_ - mlflow: databricks - - logging: mlflow_logger + - tracking: mlflow_tracker - generative_process: mess3 - - persistence: mlflow_persister - predictive_model: tiny_transformer - optimizer: pytorch_adam - metric_tracker@training_metric_tracker: default diff --git a/tests/end_to_end/test_training.py b/tests/end_to_end/test_training.py index 2f92e5e4..ab0df6c2 100644 --- a/tests/end_to_end/test_training.py +++ b/tests/end_to_end/test_training.py @@ -76,7 +76,7 @@ def get_metric_values(metric_name: str) -> np.ndarray: assert np.all(param_norm > 0) # Checkpoints - model_dir = cfg.persistence.instance.model_dir or "models" # type: ignore[attr-defined] + model_dir = cfg.tracking.instance.model_dir or "models" # type: ignore[attr-defined] checkpoints = client.list_artifacts(run.info.run_id, model_dir) assert len(checkpoints) == cfg.training.num_steps // cfg.training.checkpoint_every + 1 diff --git a/tests/end_to_end/training.py b/tests/end_to_end/training.py index ea0b52ed..0bdeced2 100644 --- a/tests/end_to_end/training.py +++ b/tests/end_to_end/training.py @@ -25,16 +25,14 @@ import simplexity from simplexity.generative_processes.hidden_markov_model import HiddenMarkovModel from simplexity.generative_processes.torch_generator import generate_data_batch -from simplexity.logging.mlflow_logger import MLFlowLogger from simplexity.metrics.metric_tracker import MetricTracker -from simplexity.persistence.mlflow_persister import MLFlowPersister from simplexity.structured_configs.generative_process import GenerativeProcessConfig -from simplexity.structured_configs.logging import LoggingConfig from simplexity.structured_configs.metric_tracker import MetricTrackerConfig -from simplexity.structured_configs.mlflow import MLFlowConfig +from simplexity.structured_configs.mlflow import MlflowConfig from simplexity.structured_configs.optimizer import OptimizerConfig -from simplexity.structured_configs.persistence import PersistenceConfig from simplexity.structured_configs.predictive_model import PredictiveModelConfig +from simplexity.structured_configs.tracking import MlflowTrackerInstanceConfig +from simplexity.tracking.mlflow_tracker import MlflowTracker CONFIG_DIR = str(Path(__file__).parent / "configs") CONFIG_NAME = "training_test.yaml" @@ -59,10 +57,9 @@ class TrainingConfig: class TrainingRunConfig: """Configuration for the managed run demo.""" - mlflow: MLFlowConfig - logging: LoggingConfig + mlflow: MlflowConfig + tracking: MlflowTrackerInstanceConfig generative_process: GenerativeProcessConfig - persistence: PersistenceConfig predictive_model: PredictiveModelConfig optimizer: OptimizerConfig training_metric_tracker: MetricTrackerConfig @@ -81,12 +78,10 @@ def train(cfg: TrainingRunConfig, components: simplexity.Components) -> None: """Test the managed run decorator.""" active_run = mlflow.active_run() assert active_run is not None - logger = components.get_logger() - assert isinstance(logger, MLFlowLogger) + tracker = components.get_run_tracker() + assert isinstance(tracker, MlflowTracker) generative_process = components.get_generative_process() assert isinstance(generative_process, HiddenMarkovModel) - persister = components.get_persister() - assert isinstance(persister, MLFlowPersister) predictive_model = components.get_predictive_model() assert isinstance(predictive_model, HookedTransformer) optimizer = components.get_optimizer() @@ -127,7 +122,7 @@ def train_step(step: int): def log_step(step: int, group: str) -> None: metrics = training_metric_tracker.get_metrics(group) - logger.log_metrics(step, metrics) + tracker.log_metrics(step, metrics) eval_inputs, eval_labels = generate(cfg.training.num_steps) @@ -145,10 +140,10 @@ def eval_step(step: int) -> None: eval_metric_tracker.step(loss=loss) metrics = eval_metric_tracker.get_metrics() metrics = add_key_prefix(metrics, "eval") - logger.log_metrics(step, metrics) + tracker.log_metrics(step, metrics) def checkpoint_step(step: int) -> None: - persister.save_weights(predictive_model, step) + tracker.save_model(predictive_model, step) for step in range(cfg.training.num_steps + 1): if step == 0: @@ -170,7 +165,7 @@ def checkpoint_step(step: int) -> None: sample_inputs = generate(0)[0] # TODO(https://github.com/Astera-org/simplexity/issues/125): This is a hack step += 1 # pyright: ignore[reportPossiblyUnboundVariable] - persister.save_model_to_registry(predictive_model, registered_model_name, model_inputs=sample_inputs, step=step) + tracker.save_model_to_registry(predictive_model, registered_model_name, model_inputs=sample_inputs, step=step) if __name__ == "__main__": diff --git a/tests/logging/test_artifact_logging.py b/tests/logging/test_artifact_logging.py deleted file mode 100644 index 836162fd..00000000 --- a/tests/logging/test_artifact_logging.py +++ /dev/null @@ -1,273 +0,0 @@ -"""Tests for artifact logging functionality across all logger implementations.""" - -# pylint: disable-all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -import json -import os -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest - -from simplexity.logging.file_logger import FileLogger -from simplexity.logging.mlflow_logger import MLFlowLogger -from simplexity.logging.print_logger import PrintLogger - - -@pytest.fixture -def sample_json_data(): - """Sample JSON data for testing.""" - return {"key": "value", "number": 42, "list": [1, 2, 3]} - - -@pytest.fixture -def sample_list_data(): - """Sample list data for testing.""" - return [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}] - - -@pytest.fixture -def test_artifact_file(tmp_path: Path): - """Create a test file to use as an artifact.""" - test_file = tmp_path / "source" / "test_artifact.txt" - test_file.parent.mkdir() - test_file.write_text("test content") - return test_file - - -@pytest.fixture -def test_artifact_directory(tmp_path: Path): - """Create a test directory to use as an artifact.""" - source_dir = tmp_path / "source_dir" - source_dir.mkdir() - (source_dir / "file1.txt").write_text("content1") - (source_dir / "file2.txt").write_text("content2") - return source_dir - - -@pytest.fixture -def file_logger(tmp_path: Path): - """Create FileLogger with temporary path.""" - return FileLogger(str(tmp_path / "test.log")) - - -@pytest.fixture -def print_logger(): - """Create PrintLogger.""" - return PrintLogger() - - -class TestFileLoggerArtifacts: - """Tests for FileLogger artifact logging.""" - - def test_log_artifact_copies_file(self, file_logger, test_artifact_file, tmp_path: Path): - """Test that log_artifact copies a file to the log directory.""" - # Act - file_logger.log_artifact(str(test_artifact_file)) - file_logger.close() - - # Assert - copied_file = tmp_path / "test_artifact.txt" - assert copied_file.exists() - assert copied_file.read_text() == "test content" - - # Verify log content - with open(tmp_path / "test.log") as f: - log_content = f.read() - assert "Artifact copied:" in log_content - assert "test_artifact.txt" in log_content - - def test_log_artifact_with_custom_path(self, file_logger, tmp_path: Path): - """Test log_artifact with custom artifact path.""" - # Arrange - test_file = tmp_path / "source.txt" - test_file.write_text("content") - - # Act - file_logger.log_artifact(str(test_file), "custom/path/dest.txt") - file_logger.close() - - # Assert - copied_file = tmp_path / "custom" / "path" / "dest.txt" - assert copied_file.exists() - assert copied_file.read_text() == "content" - - def test_log_artifact_directory(self, file_logger, test_artifact_directory, tmp_path: Path): - """Test that log_artifact can copy entire directories.""" - # Act - file_logger.log_artifact(str(test_artifact_directory), "copied_dir") - file_logger.close() - - # Assert - copied_dir = tmp_path / "copied_dir" - assert copied_dir.is_dir() - assert (copied_dir / "file1.txt").read_text() == "content1" - assert (copied_dir / "file2.txt").read_text() == "content2" - - def test_log_json_artifact_saves_json(self, file_logger, sample_json_data, tmp_path: Path): - """Test that log_json_artifact saves JSON data.""" - # Act - file_logger.log_json_artifact(sample_json_data, "results.json") - file_logger.close() - - # Assert - json_file = tmp_path / "results.json" - assert json_file.exists() - - with open(json_file) as f: - loaded_data = json.load(f) - assert loaded_data == sample_json_data - - # Verify log content - with open(tmp_path / "test.log") as f: - log_content = f.read() - assert "JSON artifact saved:" in log_content - assert "results.json" in log_content - - def test_log_json_artifact_with_list(self, file_logger, sample_list_data, tmp_path: Path): - """Test log_json_artifact with list data.""" - # Act - file_logger.log_json_artifact(sample_list_data, "data_list.json") - file_logger.close() - - # Assert - json_file = tmp_path / "data_list.json" - assert json_file.exists() - - with open(json_file) as f: - loaded_data = json.load(f) - assert loaded_data == sample_list_data - - -class TestPrintLoggerArtifacts: - """Tests for PrintLogger artifact logging.""" - - def test_log_artifact_prints_info(self, print_logger, capsys): - """Test that log_artifact prints appropriate message.""" - # Act - print_logger.log_artifact("/path/to/file.txt") - print_logger.close() - - # Assert - captured = capsys.readouterr() - expected = ( - "[PrintLogger] Artifact NOT logged - would copy: /path/to/file.txt -> " - ) - assert expected in captured.out - - def test_log_artifact_with_custom_path_prints_info(self, print_logger, capsys): - """Test log_artifact with custom path prints correct message.""" - # Act - print_logger.log_artifact("/source.txt", "dest/path.txt") - print_logger.close() - - # Assert - captured = capsys.readouterr() - expected = "[PrintLogger] Artifact NOT logged - would copy: /source.txt -> dest/path.txt" - assert expected in captured.out - - def test_log_json_artifact_prints_info(self, print_logger, sample_json_data, capsys): - """Test log_json_artifact prints correct message.""" - # Act - print_logger.log_json_artifact(sample_json_data, "test.json") - print_logger.close() - - # Assert - captured = capsys.readouterr() - expected = "[PrintLogger] JSON artifact NOT saved - would be: test.json (dict with 3 items)" - assert expected in captured.out - - def test_log_json_artifact_list_prints_info(self, print_logger, sample_list_data, capsys): - """Test log_json_artifact with list prints correct message.""" - # Act - print_logger.log_json_artifact(sample_list_data, "list.json") - print_logger.close() - - # Assert - captured = capsys.readouterr() - expected = "[PrintLogger] JSON artifact NOT saved - would be: list.json (list with 2 items)" - assert expected in captured.out - - -class TestMLFlowLoggerArtifacts: - """Tests for MLFlowLogger artifact logging.""" - - @pytest.fixture(autouse=True) - def setup_mlflow_temp_dir(self): - """Set up temporary directory for MLflow tracking during tests.""" - with tempfile.TemporaryDirectory() as tmp_dir: - # Set MLflow tracking URI to temp directory to avoid creating mlruns/ in project - original_uri = os.environ.get("MLFLOW_TRACKING_URI") - os.environ["MLFLOW_TRACKING_URI"] = f"file://{tmp_dir}" - try: - yield - finally: - # Restore original URI - if original_uri is not None: - os.environ["MLFLOW_TRACKING_URI"] = original_uri - else: - os.environ.pop("MLFLOW_TRACKING_URI", None) - - @pytest.fixture - def mock_mlflow_logger(self): - """Create a mocked MLFlowLogger for testing.""" - with patch("mlflow.MlflowClient") as mock_client_class: - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.get_experiment_by_name.return_value = None - mock_client.create_experiment.return_value = "exp_123" - mock_run = MagicMock() - mock_run.info.run_id = "run_456" - mock_client.create_run.return_value = mock_run - - logger = MLFlowLogger(experiment_name="test_experiment") - yield logger, mock_client - - def test_log_artifact_calls_client(self, mock_mlflow_logger): - """Test that log_artifact calls the MLflow client correctly.""" - # Arrange - logger, mock_client = mock_mlflow_logger - - # Act - logger.log_artifact("/path/to/file.txt", "artifacts/file.txt") - logger.close() - - # Assert - mock_client.log_artifact.assert_called_once_with("run_456", "/path/to/file.txt", "artifacts/file.txt") - - def test_log_artifact_without_artifact_path(self, mock_mlflow_logger): - """Test log_artifact without custom artifact path.""" - # Arrange - logger, mock_client = mock_mlflow_logger - - # Act - logger.log_artifact("/path/to/model.pkl") - logger.close() - - # Assert - mock_client.log_artifact.assert_called_once_with("run_456", "/path/to/model.pkl", None) - - def test_log_json_artifact_calls_client(self, mock_mlflow_logger, sample_json_data): - """Test that log_json_artifact creates temp file and calls client.""" - # Arrange - logger, mock_client = mock_mlflow_logger - - # Act - logger.log_json_artifact(sample_json_data, "metrics.json") - logger.close() - - # Assert - mock_client.log_artifact.assert_called_once() - call_args = mock_client.log_artifact.call_args - assert call_args[0][0] == "run_456" # run_id - assert call_args[0][1].endswith("metrics.json") # temp file path - # log_json_artifact calls with only 2 args (no artifact_path) - assert len(call_args[0]) == 2 diff --git a/tests/logging/test_file_logger.py b/tests/logging/test_file_logger.py deleted file mode 100644 index 84b71773..00000000 --- a/tests/logging/test_file_logger.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Test the file logger.""" - -# pylint: disable-all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -from pathlib import Path - -import jax.numpy as jnp -from omegaconf import DictConfig - -from simplexity.logging.file_logger import FileLogger - -EXPECTED_LOG = """Config: {'str_param': 'str_value', 'int_param': 1, 'float_param': 1.0, 'bool_param': True} -Config: {'str_param': 'str_value', 'int_param': 1, 'float_param': 1.0, 'bool_param': True} -Params: {'str_param': 'str_value', 'int_param': 1, 'float_param': 1.0, 'bool_param': True} -Tags: {'str_tag': 'str_value', 'int_tag': 1, 'float_tag': 1.0, 'bool_tag': True} -Metrics at step 1: {'int_metric': 1, 'float_metric': 1.0, 'jnp_metric': Array(0.1, dtype=float32, weak_type=True)} -""" - -EXPECTED_LOG_WITH_INTERPOLATION = ( - "Config: {'base_value': 'hello', 'interpolated_value': 'hello_world', 'nested': {'value': 'hello_nested'}}\n" -) - - -def test_file_logger(tmp_path: Path): - logger = FileLogger(str(tmp_path / "test.log")) - params = { - "str_param": "str_value", - "int_param": 1, - "float_param": 1.0, - "bool_param": True, - } - logger.log_config(DictConfig(params)) - logger.log_config(DictConfig(params), resolve=True) - logger.log_params(params) - tags = { - "str_tag": "str_value", - "int_tag": 1, - "float_tag": 1.0, - "bool_tag": True, - } - logger.log_tags(tags) - metrics = { - "int_metric": 1, - "float_metric": 1.0, - "jnp_metric": jnp.array(0.1), - } - logger.log_metrics(1, metrics) - logger.close() - - with open(tmp_path / "test.log") as f: - assert f.read() == EXPECTED_LOG - - -def test_file_logger_with_interpolation(tmp_path: Path): - """Test that resolved config properly resolves interpolations.""" - logger = FileLogger(str(tmp_path / "test.log")) - - # Create a config with interpolation - config_dict = { - "base_value": "hello", - "interpolated_value": "${base_value}_world", - "nested": { - "value": "${base_value}_nested", - }, - } - - config = DictConfig(config_dict) - logger.log_config(config, resolve=True) - logger.close() - - with open(tmp_path / "test.log") as f: - assert f.read() == EXPECTED_LOG_WITH_INTERPOLATION diff --git a/tests/logging/test_plot_logging.py b/tests/logging/test_plot_logging.py deleted file mode 100644 index 3145b036..00000000 --- a/tests/logging/test_plot_logging.py +++ /dev/null @@ -1,344 +0,0 @@ -"""Tests for plot and image logging functionality across all logger implementations.""" - -# pylint: disable-all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -import os -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch - -import matplotlib.pyplot as plt -import numpy as np -import pytest -from PIL import Image - -from simplexity.logging.file_logger import FileLogger -from simplexity.logging.mlflow_logger import MLFlowLogger -from simplexity.logging.print_logger import PrintLogger - - -@pytest.fixture -def matplotlib_figure(): - """Create a reusable matplotlib figure for testing.""" - fig, ax = plt.subplots(figsize=(4, 3)) - ax.plot([1, 2, 3, 4], [1, 4, 2, 3]) - ax.set_title("Test Plot") - yield fig - plt.close(fig) - - -@pytest.fixture -def simple_matplotlib_figure(): - """Create a simple matplotlib figure for basic tests.""" - fig, ax = plt.subplots() - ax.plot([1, 2, 3]) - yield fig - plt.close(fig) - - -@pytest.fixture -def numpy_image(): - """Create a reusable numpy image array for testing.""" - return np.random.randint(0, 255, (80, 120, 3), dtype=np.uint8) - - -@pytest.fixture -def small_numpy_image(): - """Create a small numpy image array for testing.""" - return np.ones((50, 50, 3), dtype=np.uint8) * 100 - - -@pytest.fixture -def tiny_numpy_image(): - """Create a tiny numpy image array for testing.""" - return np.zeros((10, 10, 3), dtype=np.uint8) - - -@pytest.fixture -def pil_image(): - """Create a reusable PIL image for testing.""" - return Image.new("RGB", (100, 50), color="red") - - -@pytest.fixture -def small_pil_image(): - """Create a small PIL image for testing.""" - return Image.new("RGB", (10, 10)) - - -@pytest.fixture -def larger_pil_image(): - """Create a larger PIL image for testing.""" - return Image.new("RGB", (20, 20)) - - -class TestFileLoggerPlotting: - """Tests for FileLogger figure and image logging.""" - - def test_log_figure_saves_matplotlib_plot(self, matplotlib_figure, tmp_path: Path): - """Test that log_figure saves a matplotlib figure to disk.""" - # Arrange - logger = FileLogger(str(tmp_path / "test.log")) - - # Act - logger.log_figure(matplotlib_figure, "test_plot.png") - logger.close() - - # Assert - # Verify it's a valid PNG image with expected size - with Image.open(tmp_path / "test_plot.png") as img: - assert img.size == (400, 300) # 4x3 inches * 100 DPI default - - # Verify log content - with open(tmp_path / "test.log") as f: - log_content = f.read() - assert "Figure saved:" in log_content - assert "test_plot.png" in log_content - - def test_log_figure_with_kwargs(self, simple_matplotlib_figure, tmp_path: Path): - """Test that log_figure passes kwargs to matplotlib savefig.""" - # Arrange - logger = FileLogger(str(tmp_path / "test.log")) - - # Act - logger.log_figure(simple_matplotlib_figure, "high_dpi.png", dpi=200, bbox_inches="tight") - logger.close() - - # Assert - with Image.open(tmp_path / "high_dpi.png") as img: - # Higher DPI should result in larger image - assert img.size[0] >= 800 # 4 inches * 200 DPI - - def test_log_image_pil_artifact_mode(self, pil_image, tmp_path: Path): - """Test logging PIL Image in artifact mode.""" - # Arrange - logger = FileLogger(str(tmp_path / "test.log")) - - # Act - logger.log_image(pil_image, artifact_file="pil_test.png") - logger.close() - - # Assert - with Image.open(tmp_path / "pil_test.png") as img: - assert img.size == (100, 50) - - with open(tmp_path / "test.log") as f: - log_content = f.read() - assert "Image saved:" in log_content - assert "pil_test.png" in log_content - - def test_log_image_numpy_artifact_mode(self, numpy_image, tmp_path: Path): - """Test logging numpy array in artifact mode.""" - # Arrange - logger = FileLogger(str(tmp_path / "test.log")) - - # Act - logger.log_image(numpy_image, artifact_file="numpy_test.png") - logger.close() - - # Assert - with Image.open(tmp_path / "numpy_test.png") as img: - assert img.size == (120, 80) # PIL uses (width, height) - - def test_log_image_time_stepped_mode(self, small_numpy_image, tmp_path: Path): - """Test logging image in time-stepped mode with key and step.""" - # Arrange - logger = FileLogger(str(tmp_path / "test.log")) - - # Act - logger.log_image(small_numpy_image, key="training_viz", step=42) - logger.close() - - # Assert - with Image.open(tmp_path / "training_viz_step_42.png") as img: - assert img.size == (50, 50) - - with open(tmp_path / "test.log") as f: - log_content = f.read() - assert "Time-stepped image saved:" in log_content - assert "training_viz_step_42.png" in log_content - - def test_log_image_unsupported_type(self, tmp_path: Path): - """Test logging unsupported image type logs error.""" - # Arrange - logger = FileLogger(str(tmp_path / "test.log")) - unsupported_image = "not an image" - - # Act - logger.log_image(unsupported_image, artifact_file="bad.png") # type: ignore[arg-type] # Intentionally testing unsupported type - logger.close() - - # Assert - assert not (tmp_path / "bad.png").exists() - - with open(tmp_path / "test.log") as f: - log_content = f.read() - assert "not supported for file saving" in log_content - - def test_log_image_missing_parameters_fails(self, tiny_numpy_image, tmp_path: Path): - """Test that log_image without proper parameters logs error.""" - # Arrange - logger = FileLogger(str(tmp_path / "test.log")) - - # Act - no artifact_file and incomplete key+step - logger.log_image(tiny_numpy_image, key="incomplete") # missing step - logger.close() - - # Assert - with open(tmp_path / "test.log") as f: - log_content = f.read() - assert "Image logging failed" in log_content - - def test_log_image_no_parameters_fails(self, tiny_numpy_image, tmp_path: Path): - """Test that log_image with no parameters logs error.""" - # Arrange - logger = FileLogger(str(tmp_path / "test.log")) - - # Act - no parameters provided - logger.log_image(tiny_numpy_image) # Neither artifact_file nor key+step - logger.close() - - # Assert - with open(tmp_path / "test.log") as f: - log_content = f.read() - assert "Image logging failed - need either artifact_file or (key + step)" in log_content - - -class TestPrintLoggerPlotting: - """Tests for PrintLogger figure and image logging.""" - - def test_log_figure_prints_info(self, simple_matplotlib_figure, capsys): - """Test that log_figure prints appropriate message.""" - # Arrange - logger = PrintLogger() - - # Act - logger.log_figure(simple_matplotlib_figure, "test.png") - logger.close() - - # Assert - captured = capsys.readouterr() - assert "[PrintLogger] Figure NOT saved - would be: test.png (type: Figure)" in captured.out - - def test_log_image_artifact_mode_prints_info(self, tiny_numpy_image, capsys): - """Test log_image in artifact mode prints correct message.""" - # Arrange - logger = PrintLogger() - - # Act - logger.log_image(tiny_numpy_image, artifact_file="test.png") - logger.close() - - # Assert - captured = capsys.readouterr() - assert "[PrintLogger] Image NOT saved - would be artifact: test.png (type: ndarray)" in captured.out - - def test_log_image_time_stepped_mode_prints_info(self, small_pil_image, capsys): - """Test log_image in time-stepped mode prints correct message.""" - # Arrange - logger = PrintLogger() - - # Act - logger.log_image(small_pil_image, key="loss_viz", step=100) - logger.close() - - # Assert - captured = capsys.readouterr() - assert "[PrintLogger] Image NOT saved - would be key: loss_viz, step: 100 (type: Image)" in captured.out - - -class TestMLFlowLoggerPlotting: - """Tests for MLFlowLogger figure and image logging.""" - - @pytest.fixture(autouse=True) - def setup_mlflow_temp_dir(self): - """Set up temporary directory for MLflow tracking during tests.""" - with tempfile.TemporaryDirectory() as tmp_dir: - # Set MLflow tracking URI to temp directory to avoid creating mlruns/ in project - original_uri = os.environ.get("MLFLOW_TRACKING_URI") - os.environ["MLFLOW_TRACKING_URI"] = f"file://{tmp_dir}" - try: - yield - finally: - # Restore original URI - if original_uri is not None: - os.environ["MLFLOW_TRACKING_URI"] = original_uri - else: - os.environ.pop("MLFLOW_TRACKING_URI", None) - - @patch("mlflow.MlflowClient") - def test_log_figure_calls_client_method(self, mock_client_class, simple_matplotlib_figure): - """Test that log_figure calls the MLflow client correctly.""" - # Arrange - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.get_experiment_by_name.return_value = None - mock_client.create_experiment.return_value = "exp_123" - mock_client.search_runs.return_value = [] - mock_run = MagicMock() - mock_run.info.run_id = "run_456" - mock_client.create_run.return_value = mock_run - - logger = MLFlowLogger(experiment_name="test_experiment", run_name="test_run") - - # Act - logger.log_figure(simple_matplotlib_figure, "test.png", dpi=150) - logger.close() - - # Assert - mock_client.log_figure.assert_called_once_with("run_456", simple_matplotlib_figure, "test.png", dpi=150) - - @patch("mlflow.MlflowClient") - def test_log_image_artifact_mode_calls_client(self, mock_client_class, tiny_numpy_image): - """Test log_image in artifact mode calls MLflow client.""" - # Arrange - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.get_experiment_by_name.return_value = None - mock_client.create_experiment.return_value = "exp_123" - mock_client.search_runs.return_value = [] - mock_run = MagicMock() - mock_run.info.run_id = "run_456" - mock_client.create_run.return_value = mock_run - - logger = MLFlowLogger(experiment_name="test_experiment") - - # Act - logger.log_image(tiny_numpy_image, artifact_file="image.png") - logger.close() - - # Assert - mock_client.log_image.assert_called_once_with( - "run_456", tiny_numpy_image, artifact_file="image.png", key=None, step=None - ) - - @patch("mlflow.MlflowClient") - def test_log_image_time_stepped_mode_calls_client(self, mock_client_class, larger_pil_image): - """Test log_image in time-stepped mode calls MLflow client.""" - # Arrange - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.get_experiment_by_name.return_value = None - mock_client.create_experiment.return_value = "exp_123" - mock_client.search_runs.return_value = [] - mock_run = MagicMock() - mock_run.info.run_id = "run_456" - mock_client.create_run.return_value = mock_run - - logger = MLFlowLogger(experiment_name="test_experiment") - - # Act - logger.log_image(larger_pil_image, key="training", step=50, timestamp=1234567890) - logger.close() - - # Assert - mock_client.log_image.assert_called_once_with( - "run_456", larger_pil_image, artifact_file=None, key="training", step=50, timestamp=1234567890 - ) diff --git a/tests/run_management/test_components.py b/tests/run_management/test_components.py index 6feacee9..7191d24c 100644 --- a/tests/run_management/test_components.py +++ b/tests/run_management/test_components.py @@ -16,10 +16,9 @@ from simplexity.activations.activation_tracker import ActivationTracker from simplexity.generative_processes.generative_process import GenerativeProcess -from simplexity.logging.logger import Logger from simplexity.metrics.metric_tracker import MetricTracker -from simplexity.persistence.model_persister import ModelPersister from simplexity.run_management.components import Components +from simplexity.tracking.tracker import RunTracker def test_get_none(): @@ -34,25 +33,25 @@ def test_get_none_with_key_raises_error(): def test_get_unique_instance(): - logger = Mock(spec=Logger) - components = Components(loggers={"mock": logger}) - assert components.get_logger() == logger + tracker = Mock(spec=RunTracker) + components = Components(run_trackers={"mock": tracker}) + assert components.get_run_tracker() == tracker def test_get_multiple_instances_without_key_raises_error(): - persister_1 = Mock(spec=ModelPersister) - persister_2 = Mock(spec=ModelPersister) - components = Components(persisters={"mock_1": persister_1, "mock_2": persister_2}) - with pytest.raises(KeyError, match="No key provided and multiple persisters found"): - components.get_persister() + tracker_1 = Mock(spec=RunTracker) + tracker_2 = Mock(spec=RunTracker) + components = Components(run_trackers={"mock_1": tracker_1, "mock_2": tracker_2}) + with pytest.raises(KeyError, match="No key provided and multiple run trackers found"): + components.get_run_tracker() def test_get_instance_with_key(): - persister_1 = Mock(spec=ModelPersister) - persister_2 = Mock(spec=ModelPersister) - components = Components(persisters={"mock_1": persister_1, "mock_2": persister_2}) - assert components.get_persister("mock_1") == persister_1 - assert components.get_persister("mock_2") == persister_2 + tracker_1 = Mock(spec=RunTracker) + tracker_2 = Mock(spec=RunTracker) + components = Components(run_trackers={"mock_1": tracker_1, "mock_2": tracker_2}) + assert components.get_run_tracker("mock_1") == tracker_1 + assert components.get_run_tracker("mock_2") == tracker_2 def test_get_instance_with_ending_key(): diff --git a/tests/structured_configs/test_base_config.py b/tests/structured_configs/test_base_config.py index 648c209f..abd7ecbb 100644 --- a/tests/structured_configs/test_base_config.py +++ b/tests/structured_configs/test_base_config.py @@ -91,26 +91,26 @@ def test_validate_base_config_invalid_tags(self) -> None: def test_validate_base_config_invalid_mlflow(self) -> None: """Test validate_base_config with invalid mlflow.""" - # Non-MLFlowConfig mlflow - cfg = DictConfig({"mlflow": "not an MLFlowConfig"}) - with pytest.raises(ConfigValidationError, match="BaseConfig.mlflow must be a MLFlowConfig"): + # Non-MlflowConfig mlflow + cfg = DictConfig({"mlflow": "not an MlflowConfig"}) + with pytest.raises(ConfigValidationError, match="BaseConfig.mlflow must be a MlflowConfig"): validate_base_config(cfg) - # MLFlowConfig with empty experiment_name (whitespace) + # MlflowConfig with empty experiment_name (whitespace) cfg = DictConfig({"mlflow": DictConfig({"experiment_name": " "})}) - with pytest.raises(ConfigValidationError, match="MLFlowConfig.experiment_name must be a non-empty string"): + with pytest.raises(ConfigValidationError, match="MlflowConfig.experiment_name must be a non-empty string"): validate_base_config(cfg) def test_validate_base_config_propagates_mlflow_errors(self) -> None: """Test that MLflow validation errors propagate correctly.""" # Invalid tracking_uri scheme cfg = DictConfig({"mlflow": DictConfig({"tracking_uri": "relative/path"})}) - with pytest.raises(ConfigValidationError, match="MLFlowConfig.tracking_uri must have a valid URI scheme"): + with pytest.raises(ConfigValidationError, match="MlflowConfig.tracking_uri must have a valid URI scheme"): validate_base_config(cfg) # Empty experiment_name cfg = DictConfig({"mlflow": DictConfig({"experiment_name": " "})}) - with pytest.raises(ConfigValidationError, match="MLFlowConfig.experiment_name must be a non-empty string"): + with pytest.raises(ConfigValidationError, match="MlflowConfig.experiment_name must be a non-empty string"): validate_base_config(cfg) def test_resolve_base_config(self) -> None: diff --git a/tests/structured_configs/test_logging_config.py b/tests/structured_configs/test_logging_config.py deleted file mode 100644 index a39352eb..00000000 --- a/tests/structured_configs/test_logging_config.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Tests for LoggingConfig validation. - -This module contains tests for logging configuration validation, including -validation of logger targets, logger configs, and logging configuration instances. -""" - -# pylint: disable-all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -import pytest -from omegaconf import DictConfig, OmegaConf - -from simplexity.exceptions import ConfigValidationError -from simplexity.structured_configs.logging import ( - FileLoggerInstanceConfig, - InstanceConfig, - LoggingConfig, - is_logger_config, - is_logger_target, - update_logging_instance_config, - validate_logging_config, -) - - -class TestLoggingConfig: - """Test LoggingConfig.""" - - def test_logging_config(self) -> None: - """Test creating logger config from dataclass.""" - cfg: DictConfig = OmegaConf.structured(LoggingConfig(instance=InstanceConfig(_target_="some_target"))) - assert OmegaConf.select(cfg, "instance._target_") == "some_target" - assert cfg.get("name") is None - - def test_is_logger_target_valid(self) -> None: - """Test is_logger_target with valid logger targets.""" - assert is_logger_target("simplexity.logging.file_logger.FileLogger") - assert is_logger_target("simplexity.logging.mlflow_logger.MLFlowLogger") - assert is_logger_target("simplexity.logging.print_logger.PrintLogger") - - def test_is_logger_target_invalid(self) -> None: - """Test is_logger_target with invalid targets.""" - assert not is_logger_target("simplexity.persistence.mlflow_persister.MLFlowPersister") - assert not is_logger_target("logging.Logger") - assert not is_logger_target("") - - def test_is_logger_config_valid(self) -> None: - """Test is_logger_config with valid logger configs.""" - cfg = DictConfig({"_target_": "simplexity.logging.mlflow_logger.MLFlowLogger"}) - assert is_logger_config(cfg) - - cfg = DictConfig( - { - "_target_": "simplexity.logging.mlflow_logger.MLFlowLogger", - "experiment_name": "my_experiment", - "run_name": "my_run", - "tracking_uri": "databricks", - } - ) - assert is_logger_config(cfg) - - def test_is_logger_config_invalid(self) -> None: - """Test is_logger_config with invalid configs.""" - # Non-logger target - cfg = DictConfig({"_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister"}) - assert not is_logger_config(cfg) - - # Missing _target_ - cfg = DictConfig({"experiment_name": "my_experiment", "run_name": "my_run", "tracking_uri": "databricks"}) - assert not is_logger_config(cfg) - - # _target_ is not a omegaconf target - cfg = DictConfig({"target": "simplexity.logging.mlflow_logger.MLFlowLogger"}) - assert not is_logger_config(cfg) - - # _target_ is None - cfg = DictConfig({"_target_": None}) - assert not is_logger_config(cfg) - - # _target_ is not a string - cfg = DictConfig({"_target_": 123}) - assert not is_logger_config(cfg) - - # Empty config - cfg = DictConfig({}) - assert not is_logger_config(cfg) - - def test_validate_logging_config_valid(self) -> None: - """Test validate_logging_config with valid configs.""" - # Valid config without name - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.logging.mlflow_logger.MLFlowLogger", - "experiment_name": "my_experiment", - "run_name": "my_run", - "tracking_uri": "databricks", - } - ), - } - ) - validate_logging_config(cfg) # Should not raise - - # Valid config with name - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.logging.mlflow_logger.MLFlowLogger", - "experiment_name": "my_experiment", - "run_name": "my_run", - "tracking_uri": "databricks", - } - ), - "name": "my_logger", - } - ) - validate_logging_config(cfg) # Should not raise - - # Valid config with None name - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.logging.mlflow_logger.MLFlowLogger", - "experiment_name": "my_experiment", - "run_name": "my_run", - "tracking_uri": "databricks", - } - ), - "name": None, - } - ) - validate_logging_config(cfg) # Should not raise - - def test_validate_logging_config_missing_instance(self) -> None: - """Test validate_logging_config raises when instance is missing.""" - cfg = DictConfig({}) - with pytest.raises(ConfigValidationError, match="LoggingConfig.instance must be a DictConfig"): - validate_logging_config(cfg) - - cfg = DictConfig({"name": "my_logger"}) - with pytest.raises(ConfigValidationError, match="LoggingConfig.instance must be a DictConfig"): - validate_logging_config(cfg) - - def test_validate_logging_config_invalid_instance(self) -> None: - """Test validate_logging_config raises when instance is invalid.""" - # Instance without _target_ - cfg = DictConfig({"instance": DictConfig({"other_field": "value"})}) - with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a string"): - validate_logging_config(cfg) - - # Instance with empty _target_ - cfg = DictConfig({"instance": DictConfig({"_target_": ""})}) - with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a non-empty string"): - validate_logging_config(cfg) - - # Instance with non-string _target_ - cfg = DictConfig({"instance": DictConfig({"_target_": 123})}) - with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a string"): - validate_logging_config(cfg) - - def test_validate_logging_config_non_logger_target(self) -> None: - """Test validate_logging_config raises when instance target is not a logger target.""" - cfg = DictConfig( - {"instance": DictConfig({"_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister"})} - ) - with pytest.raises(ConfigValidationError, match="LoggingConfig.instance must be a logger target"): - validate_logging_config(cfg) - - cfg = DictConfig({"instance": DictConfig({"_target_": "torch.optim.Adam"})}) - with pytest.raises(ConfigValidationError, match="LoggingConfig.instance must be a logger target"): - validate_logging_config(cfg) - - def test_validate_logging_config_invalid_name(self) -> None: - """Test validate_logging_config raises when name is invalid.""" - # Empty string name - cfg = DictConfig( - { - "instance": DictConfig({"_target_": "simplexity.logging.mlflow_logger.MLFlowLogger"}), - "name": "", - } - ) - with pytest.raises(ConfigValidationError, match="LoggingConfig.name must be a non-empty string"): - validate_logging_config(cfg) - - # Whitespace-only name - cfg = DictConfig( - { - "instance": DictConfig({"_target_": "simplexity.logging.mlflow_logger.MLFlowLogger"}), - "name": " ", - } - ) - with pytest.raises(ConfigValidationError, match="LoggingConfig.name must be a non-empty string"): - validate_logging_config(cfg) - - # Non-string name - cfg = DictConfig( - { - "instance": DictConfig({"_target_": "simplexity.logging.mlflow_logger.MLFlowLogger"}), - "name": 123, - } - ) - with pytest.raises(ConfigValidationError, match="LoggingConfig.name must be a string or None"): - validate_logging_config(cfg) - - def test_validate_file_logger_config(self) -> None: - """Test validation of FileLogger configuration.""" - # Valid file logger config - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.logging.file_logger.FileLogger", - "file_path": "/tmp/test.log", - } - ) - } - ) - validate_logging_config(cfg) - - # Missing file_path - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.logging.file_logger.FileLogger", - } - ) - } - ) - with pytest.raises(ConfigValidationError, match="FileLoggerInstanceConfig.file_path must be a string"): - validate_logging_config(cfg) - - # Empty file_path - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.logging.file_logger.FileLogger", - "file_path": "", - } - ) - } - ) - with pytest.raises( - ConfigValidationError, match="FileLoggerInstanceConfig.file_path must be a non-empty string" - ): - validate_logging_config(cfg) - - def test_update_logging_instance_config(self) -> None: - """Test update_logging_instance_config function.""" - # Initial config - cfg = DictConfig( - { - "_target_": "simplexity.logging.mlflow_logger.MLFlowLogger", - "experiment_name": "exp1", - "run_name": "run1", - } - ) - - # Update config - updated_cfg = DictConfig( - { - "_target_": "simplexity.logging.mlflow_logger.MLFlowLogger", - "experiment_name": "exp2", - "tracking_uri": "file:///tmp/mlruns", - } - ) - - update_logging_instance_config(cfg, updated_cfg) - - assert cfg.experiment_name == "exp2" - assert cfg.run_name == "run1" # Should remain unchanged - assert cfg.tracking_uri == "file:///tmp/mlruns" - - def test_file_logger_instance_config_init(self) -> None: - """Test FileLoggerInstanceConfig instantiation.""" - config = FileLoggerInstanceConfig(file_path="test.log") - assert config.file_path == "test.log" - assert config._target_ == "simplexity.logging.file_logger.FileLogger" diff --git a/tests/structured_configs/test_mlflow_config.py b/tests/structured_configs/test_mlflow_config.py index 1475d8b8..58992b9b 100644 --- a/tests/structured_configs/test_mlflow_config.py +++ b/tests/structured_configs/test_mlflow_config.py @@ -1,4 +1,4 @@ -"""Tests for MLFlowConfig validation. +"""Tests for MlflowConfig validation. This module contains tests for MLFlow configuration validation, including validation of experiment_name, run_name, tracking_uri, registry_uri, and @@ -18,15 +18,15 @@ from omegaconf import DictConfig, OmegaConf from simplexity.exceptions import ConfigValidationError -from simplexity.structured_configs.mlflow import MLFlowConfig, validate_mlflow_config +from simplexity.structured_configs.mlflow import MlflowConfig, validate_mlflow_config -class TestMLFlowConfig: - """Test MLFlowConfig.""" +class TestMlflowConfig: + """Test MlflowConfig.""" def test_mlflow_config(self) -> None: """Test creating mlflow config from dataclass.""" - cfg: DictConfig = OmegaConf.structured(MLFlowConfig(experiment_name="some_experiment", run_name="some_run")) + cfg: DictConfig = OmegaConf.structured(MlflowConfig(experiment_name="some_experiment", run_name="some_run")) assert cfg.get("experiment_name") == "some_experiment" assert cfg.get("run_name") == "some_run" assert cfg.get("tracking_uri") is None @@ -63,7 +63,7 @@ def test_validate_mlflow_config_invalid_downgrade_unity_catalog(self) -> None: "downgrade_unity_catalog": "not_a_bool", } ) - with pytest.raises(ConfigValidationError, match="MLFlowConfig.downgrade_unity_catalog must be a bool"): + with pytest.raises(ConfigValidationError, match="MlflowConfig.downgrade_unity_catalog must be a bool"): validate_mlflow_config(cfg) @pytest.mark.parametrize("uri_type", ["tracking_uri", "registry_uri"]) @@ -77,7 +77,7 @@ def test_validate_mlflow_config_invalid_uri(self, uri_type: str) -> None: uri_type: " ", } ) - with pytest.raises(ConfigValidationError, match=f"MLFlowConfig.{uri_type} cannot be empty"): + with pytest.raises(ConfigValidationError, match=f"MlflowConfig.{uri_type} cannot be empty"): validate_mlflow_config(cfg) # parse error (urlparse raises an exception) @@ -88,7 +88,7 @@ def test_validate_mlflow_config_invalid_uri(self, uri_type: str) -> None: uri_type: "%parse_error%", } ) - with pytest.raises(ConfigValidationError, match=f"MLFlowConfig.{uri_type} is not a valid URI"): + with pytest.raises(ConfigValidationError, match=f"MlflowConfig.{uri_type} is not a valid URI"): validate_mlflow_config(cfg) # missing scheme @@ -99,5 +99,5 @@ def test_validate_mlflow_config_invalid_uri(self, uri_type: str) -> None: uri_type: "no_scheme", } ) - with pytest.raises(ConfigValidationError, match=f"MLFlowConfig.{uri_type} must have a valid URI scheme"): + with pytest.raises(ConfigValidationError, match=f"MlflowConfig.{uri_type} must have a valid URI scheme"): validate_mlflow_config(cfg) diff --git a/tests/structured_configs/test_persistence_config.py b/tests/structured_configs/test_persistence_config.py deleted file mode 100644 index 23cab046..00000000 --- a/tests/structured_configs/test_persistence_config.py +++ /dev/null @@ -1,295 +0,0 @@ -"""Tests for PersistenceConfig validation. - -This module contains tests for persistence configuration validation, including -validation of model persister targets, persister configs, and persistence -configuration instances. -""" - -# pylint: disable-all -# Temporarily disable all pylint checkers during AST traversal to prevent crash. -# The imports checker crashes when resolving simplexity package imports due to a bug -# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 -# pylint: enable=all -# Re-enable all pylint checkers for the checking phase. This allows other checks -# (code quality, style, undefined names, etc.) to run normally while bypassing -# the problematic imports checker that would crash during AST traversal. - -import pytest -from omegaconf import DictConfig, OmegaConf - -from simplexity.exceptions import ConfigValidationError -from simplexity.structured_configs.persistence import ( - InstanceConfig, - LocalEquinoxPersisterInstanceConfig, - LocalPenzaiPersisterInstanceConfig, - LocalPytorchPersisterInstanceConfig, - MLFlowPersisterInstanceConfig, - PersistenceConfig, - is_model_persister_target, - is_persister_config, - update_persister_instance_config, - validate_local_equinox_persister_instance_config, - validate_local_penzai_persister_instance_config, - validate_local_pytorch_persister_instance_config, - validate_persistence_config, -) - - -class TestPersistenceConfig: - """Test PersistenceConfig.""" - - def test_persistence_config(self) -> None: - """Test creating persistence config from dataclass.""" - cfg: DictConfig = OmegaConf.structured(PersistenceConfig(instance=InstanceConfig(_target_="some_target"))) - assert OmegaConf.select(cfg, "instance._target_") == "some_target" - assert cfg.get("name") is None - - def test_is_model_persister_target_valid(self) -> None: - """Test is_model_persister_target with valid persister targets.""" - assert is_model_persister_target("simplexity.persistence.mlflow_persister.MLFlowPersister") - assert is_model_persister_target("simplexity.persistence.local_pytorch_persister.LocalPytorchPersister") - - def test_is_model_persister_target_invalid(self) -> None: - """Test is_model_persister_target with invalid targets.""" - assert not is_model_persister_target("simplexity.logging.mlflow_logger.MLFlowLogger") - assert not is_model_persister_target("torch.optim.Adam") - assert not is_model_persister_target("") - - def test_is_persister_config_valid(self) -> None: - """Test is_persister_config with valid persister configs.""" - cfg = DictConfig({"_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister"}) - assert is_persister_config(cfg) - - cfg = DictConfig( - { - "_target_": "simplexity.persistence.local_pytorch_persister.LocalPytorchPersister", - "path": "/tmp/model", - } - ) - assert is_persister_config(cfg) - - def test_is_persister_config_invalid(self) -> None: - """Test is_persister_config with invalid configs.""" - # Non-persister target - cfg = DictConfig({"_target_": "simplexity.logging.mlflow_logger.MLFlowLogger"}) - assert not is_persister_config(cfg) - - # Missing _target_ - cfg = DictConfig({"path": "/tmp/model"}) - assert not is_persister_config(cfg) - - # _target_ is None - cfg = DictConfig({"_target_": None}) - assert not is_persister_config(cfg) - - # _target_ is not a string - cfg = DictConfig({"_target_": 123}) - assert not is_persister_config(cfg) - - # Empty config - cfg = DictConfig({}) - assert not is_persister_config(cfg) - - def test_validate_persistence_config_valid(self) -> None: - """Test validate_persistence_config with valid configs.""" - # Valid config without name - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister", - "experiment_name": "my_experiment", - "run_name": "my_run", - } - ), - } - ) - validate_persistence_config(cfg) # Should not raise - - # Valid config with name - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister", - "experiment_name": "my_experiment", - "run_name": "my_run", - } - ), - "name": "my_persister", - } - ) - validate_persistence_config(cfg) # Should not raise - - # Valid config with None name - cfg = DictConfig( - { - "instance": DictConfig( - { - "_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister", - "experiment_name": "my_experiment", - "run_name": "my_run", - } - ), - "name": None, - } - ) - validate_persistence_config(cfg) # Should not raise - - def test_validate_persistence_config_missing_instance(self) -> None: - """Test validate_persistence_config raises when instance is missing.""" - cfg = DictConfig({}) - with pytest.raises(ConfigValidationError, match="PersistenceConfig.instance is required"): - validate_persistence_config(cfg) - - cfg = DictConfig({"name": "my_persister"}) - with pytest.raises(ConfigValidationError, match="PersistenceConfig.instance is required"): - validate_persistence_config(cfg) - - def test_validate_persistence_config_invalid_instance(self) -> None: - """Test validate_persistence_config raises when instance is invalid.""" - # Instance without _target_ - cfg = DictConfig({"instance": DictConfig({"other_field": "value"})}) - with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a string"): - validate_persistence_config(cfg) - - # Instance with empty _target_ - cfg = DictConfig({"instance": DictConfig({"_target_": ""})}) - with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a non-empty string"): - validate_persistence_config(cfg) - - # Instance with non-string _target_ - cfg = DictConfig({"instance": DictConfig({"_target_": 123})}) - with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a string"): - validate_persistence_config(cfg) - - def test_validate_persistence_config_non_persister_target(self) -> None: - """Test validate_persistence_config raises when instance target is not a persister target.""" - cfg = DictConfig({"instance": DictConfig({"_target_": "simplexity.logging.mlflow_logger.MLFlowLogger"})}) - with pytest.raises(ConfigValidationError, match="PersistenceConfig.instance must be a persister target"): - validate_persistence_config(cfg) - - cfg = DictConfig({"instance": DictConfig({"_target_": "torch.optim.Adam"})}) - with pytest.raises(ConfigValidationError, match="PersistenceConfig.instance must be a persister target"): - validate_persistence_config(cfg) - - def test_validate_persistence_config_invalid_name(self) -> None: - """Test validate_persistence_config raises when name is invalid.""" - # Empty string name - cfg = DictConfig( - { - "instance": DictConfig({"_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister"}), - "name": "", - } - ) - with pytest.raises(ConfigValidationError, match="PersistenceConfig.name must be a non-empty string"): - validate_persistence_config(cfg) - - # Whitespace-only name - cfg = DictConfig( - { - "instance": DictConfig({"_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister"}), - "name": " ", - } - ) - with pytest.raises(ConfigValidationError, match="PersistenceConfig.name must be a non-empty string"): - validate_persistence_config(cfg) - - # Non-string name - cfg = DictConfig( - { - "instance": DictConfig({"_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister"}), - "name": 123, - } - ) - with pytest.raises(ConfigValidationError, match="PersistenceConfig.name must be a string or None"): - validate_persistence_config(cfg) - - def test_local_persister_configs(self) -> None: - """Test validation of local persister configs.""" - # Equinox - eqx_cfg = DictConfig( - { - "_target_": "simplexity.persistence.local_equinox_persister.LocalEquinoxPersister", - "directory": "/tmp", - "filename": "model.eqx", - } - ) - validate_local_equinox_persister_instance_config(eqx_cfg) - # Test __init__ - eqx_instance = LocalEquinoxPersisterInstanceConfig(directory="/tmp") - assert eqx_instance.filename == "model.eqx" - assert eqx_instance._target_ == "simplexity.persistence.local_equinox_persister.LocalEquinoxPersister" - - # Invalid Equinox filename - eqx_cfg_invalid = DictConfig( - { - "_target_": "simplexity.persistence.local_equinox_persister.LocalEquinoxPersister", - "directory": "/tmp", - "filename": "model.pt", - } - ) - with pytest.raises( - ConfigValidationError, match="LocalEquinoxPersisterInstanceConfig.filename must end with .eqx" - ): - validate_local_equinox_persister_instance_config(eqx_cfg_invalid) - - # Penzai - penzai_cfg = DictConfig( - { - "_target_": "simplexity.persistence.local_penzai_persister.LocalPenzaiPersister", - "directory": "/tmp", - } - ) - validate_local_penzai_persister_instance_config(penzai_cfg) - # Test __init__ - penzai_instance = LocalPenzaiPersisterInstanceConfig(directory="/tmp") - assert penzai_instance._target_ == "simplexity.persistence.local_penzai_persister.LocalPenzaiPersister" - - # Pytorch - pt_cfg = DictConfig( - { - "_target_": "simplexity.persistence.local_pytorch_persister.LocalPytorchPersister", - "directory": "/tmp", - "filename": "model.pt", - } - ) - validate_local_pytorch_persister_instance_config(pt_cfg) - # Test __init__ - pt_instance = LocalPytorchPersisterInstanceConfig(directory="/tmp") - assert pt_instance.filename == "model.pt" - assert pt_instance._target_ == "simplexity.persistence.local_pytorch_persister.LocalPytorchPersister" - - # Invalid Pytorch filename - pt_cfg_invalid = DictConfig( - { - "_target_": "simplexity.persistence.local_pytorch_persister.LocalPytorchPersister", - "directory": "/tmp", - "filename": "model.eqx", - } - ) - with pytest.raises( - ConfigValidationError, match="LocalPytorchPersisterInstanceConfig.filename must end with .pt" - ): - validate_local_pytorch_persister_instance_config(pt_cfg_invalid) - - def test_mlflow_persister_config_init(self) -> None: - """Test MLFlowPersisterInstanceConfig initialization.""" - config = MLFlowPersisterInstanceConfig( - experiment_name="test_exp", run_name="test_run", tracking_uri="file:///tmp/mlruns" - ) - assert config.experiment_name == "test_exp" - assert config.run_name == "test_run" - assert config.tracking_uri == "file:///tmp/mlruns" - assert config._target_ == "simplexity.persistence.mlflow_persister.MLFlowPersister" - - def test_update_persister_instance_config(self) -> None: - """Test update_persister_instance_config.""" - cfg = OmegaConf.structured(MLFlowPersisterInstanceConfig(experiment_name="old")) - updated_cfg = DictConfig({"experiment_name": "new", "run_name": "new_run"}) - - update_persister_instance_config(cfg, updated_cfg) - - assert cfg.experiment_name == "new" - assert cfg.run_name == "new_run" - assert cfg.experiment_id is None diff --git a/tests/structured_configs/test_tracking_config.py b/tests/structured_configs/test_tracking_config.py new file mode 100644 index 00000000..c9c22aa8 --- /dev/null +++ b/tests/structured_configs/test_tracking_config.py @@ -0,0 +1,865 @@ +"""Tests for TrackingConfig validation. + +This module contains tests for tracking configuration validation, including +validation of tracker targets, tracker configs, and tracking configuration instances. +""" + +# pylint: disable-all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +import pytest +from omegaconf import DictConfig, OmegaConf + +from simplexity.exceptions import ConfigValidationError +from simplexity.structured_configs.tracking import ( + FileTrackerInstanceConfig, + InstanceConfig, + MlflowTrackerInstanceConfig, + S3TrackerInstanceConfig, + TrackingConfig, + is_file_tracker_config, + is_file_tracker_target, + is_mlflow_tracker_config, + is_mlflow_tracker_target, + is_run_tracker_config, + is_run_tracker_target, + is_s3_tracker_config, + is_s3_tracker_target, + update_tracking_instance_config, + validate_file_tracker_instance_config, + validate_mlflow_tracker_instance_config, + validate_s3_tracker_instance_config, + validate_tracking_config, +) + + +class TestTrackingConfig: + """Test TrackingConfig.""" + + def test_structured_config(self) -> None: + """Test creating tracking config from dataclass.""" + cfg: DictConfig = OmegaConf.structured(TrackingConfig(instance=InstanceConfig(_target_="some_target"))) + assert OmegaConf.select(cfg, "instance._target_") == "some_target" + assert cfg.get("name") is None + + def test_structured_config_with_name(self) -> None: + """Test creating tracking config with name.""" + cfg: DictConfig = OmegaConf.structured( + TrackingConfig(instance=InstanceConfig(_target_="some_target"), name="my_tracker") + ) + assert OmegaConf.select(cfg, "instance._target_") == "some_target" + assert cfg.get("name") == "my_tracker" + + def test_is_tracker_target(self) -> None: + """Test is_tracker_target with valid targets.""" + assert is_run_tracker_target("simplexity.tracking.file_tracker.FileTracker") + assert is_run_tracker_target("simplexity.tracking.mlflow_tracker.MlflowTracker") + assert is_run_tracker_target("simplexity.tracking.s3_tracker.S3Tracker.from_config") + + def test_is_tracker_target_invalid(self) -> None: + """Test is_tracker_target with invalid targets.""" + assert not is_run_tracker_target("simplexity.logging.mlflow_logger.MLFlowLogger") + assert not is_run_tracker_target("some.other.tracker.Tracker") + assert not is_run_tracker_target("") + + def test_validate_tracking_config_valid(self) -> None: + """Test validate_tracking_config with valid configs.""" + # Valid config without name + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + } + ), + } + ) + validate_tracking_config(cfg) + + # Valid config with name + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + } + ), + "name": "my_tracker", + } + ) + validate_tracking_config(cfg) + + # Valid config with None name + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + } + ), + "name": None, + } + ) + validate_tracking_config(cfg) + + def test_validate_tracking_config_missing_instance(self) -> None: + """Test validate_tracking_config raises when instance is missing.""" + cfg = DictConfig({}) + with pytest.raises(ConfigValidationError, match="TrackingConfig.instance must be a DictConfig"): + validate_tracking_config(cfg) + + cfg = DictConfig({"name": "my_tracker"}) + with pytest.raises(ConfigValidationError, match="TrackingConfig.instance must be a DictConfig"): + validate_tracking_config(cfg) + + def test_validate_tracking_config_invalid_instance(self) -> None: + """Test validate_tracking_config raises when instance is invalid.""" + # Instance without _target_ + cfg = DictConfig({"instance": DictConfig({"other_field": "value"})}) + with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a string"): + validate_tracking_config(cfg) + + # Instance with empty _target_ + cfg = DictConfig({"instance": DictConfig({"_target_": ""})}) + with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a non-empty string"): + validate_tracking_config(cfg) + + # Instance with non-string _target_ + cfg = DictConfig({"instance": DictConfig({"_target_": 123})}) + with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be a string"): + validate_tracking_config(cfg) + + def test_validate_tracking_config_non_tracker_target(self) -> None: + """Test validate_tracking_config raises when instance target is not a tracker target.""" + cfg = DictConfig( + {"instance": DictConfig({"_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister"})} + ) + with pytest.raises(ConfigValidationError, match="TrackingConfig.instance must be a tracker target"): + validate_tracking_config(cfg) + + cfg = DictConfig({"instance": DictConfig({"_target_": "torch.optim.Adam"})}) + with pytest.raises(ConfigValidationError, match="TrackingConfig.instance must be a tracker target"): + validate_tracking_config(cfg) + + def test_validate_tracking_config_invalid_name(self) -> None: + """Test validate_tracking_config raises when name is invalid.""" + # Empty string name + cfg = DictConfig( + { + "instance": DictConfig({"_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker"}), + "name": "", + } + ) + with pytest.raises(ConfigValidationError, match="TrackingConfig.name must be a non-empty string"): + validate_tracking_config(cfg) + + # Whitespace-only name + cfg = DictConfig( + { + "instance": DictConfig({"_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker"}), + "name": " ", + } + ) + with pytest.raises(ConfigValidationError, match="TrackingConfig.name must be a non-empty string"): + validate_tracking_config(cfg) + + # Non-string name + cfg = DictConfig( + { + "instance": DictConfig({"_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker"}), + "name": 123, + } + ) + with pytest.raises(ConfigValidationError, match="TrackingConfig.name must be a string or None"): + validate_tracking_config(cfg) + + +class TestFileTrackerConfig: + """Test FileTracker configuration functions.""" + + def test_is_file_tracker_target_valid(self) -> None: + """Test is_file_tracker_target with valid target.""" + assert is_file_tracker_target("simplexity.tracking.file_tracker.FileTracker") + + def test_is_file_tracker_target_invalid(self) -> None: + """Test is_file_tracker_target with invalid targets.""" + assert not is_file_tracker_target("simplexity.tracking.file_tracker.FileTracker.from_config") + assert not is_file_tracker_target("simplexity.tracking.mlflow_tracker.MlflowTracker") + + def test_is_file_tracker_config_valid(self) -> None: + """Test is_file_tracker_config with valid configs.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "my_file.log", + } + ) + assert is_file_tracker_config(cfg) + + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "my_file.log", + "model_dir_name": "custom_models", + } + ) + assert is_file_tracker_config(cfg) + + def test_is_file_tracker_config_invalid(self) -> None: + """Test is_file_tracker_config with invalid configs.""" + # Non-file tracker target + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "file_path": "my_file.log", + } + ) + assert not is_file_tracker_config(cfg) + + # Missing _target_ + cfg = DictConfig({"file_path": "my_file.log"}) + assert not is_file_tracker_config(cfg) + + # _target_ is None + cfg = DictConfig({"_target_": None}) + assert not is_file_tracker_config(cfg) + + # _target_ is not a string + cfg = DictConfig({"_target_": 123}) + assert not is_file_tracker_config(cfg) + + # Empty config + cfg = DictConfig({}) + assert not is_file_tracker_config(cfg) + + def test_validate_file_tracker_instance_config_valid(self) -> None: + """Test validate_file_tracker_instance_config with valid configs.""" + # Valid config with required fields + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "/tmp/test.log", + } + ) + validate_file_tracker_instance_config(cfg) # Should not raise + + # Valid config with optional model_dir_name + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "/tmp/test.log", + "model_dir_name": "custom_models", + } + ) + validate_file_tracker_instance_config(cfg) # Should not raise + + def test_validate_file_tracker_instance_config_invalid_target(self) -> None: + """Test validate_file_tracker_instance_config raises with invalid target.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "file_path": "/tmp/test.log", + } + ) + with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be"): + validate_file_tracker_instance_config(cfg) + + def test_validate_file_tracker_instance_config_missing_file_path(self) -> None: + """Test validate_file_tracker_instance_config raises when file_path is missing.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + } + ) + with pytest.raises(ConfigValidationError, match="FileTrackerInstanceConfig.file_path must be a string"): + validate_file_tracker_instance_config(cfg) + + def test_validate_file_tracker_instance_config_empty_file_path(self) -> None: + """Test validate_file_tracker_instance_config raises when file_path is empty.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "", + } + ) + with pytest.raises( + ConfigValidationError, match="FileTrackerInstanceConfig.file_path must be a non-empty string" + ): + validate_file_tracker_instance_config(cfg) + + def test_validate_file_tracker_instance_config_whitespace_file_path(self) -> None: + """Test validate_file_tracker_instance_config raises when file_path is whitespace only.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": " ", + } + ) + with pytest.raises( + ConfigValidationError, match="FileTrackerInstanceConfig.file_path must be a non-empty string" + ): + validate_file_tracker_instance_config(cfg) + + def test_validate_file_tracker_instance_config_empty_model_dir_name(self) -> None: + """Test validate_file_tracker_instance_config raises when model_dir_name is empty.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "/tmp/test.log", + "model_dir_name": "", + } + ) + with pytest.raises( + ConfigValidationError, + match="FileTrackerInstanceConfig.model_dir_name must be a non-empty string", + ): + validate_file_tracker_instance_config(cfg) + + def test_file_tracker_instance_config_init(self) -> None: + """Test FileTrackerInstanceConfig instantiation.""" + config = FileTrackerInstanceConfig(file_path="test.log") + assert config.file_path == "test.log" + assert config._target_ == "simplexity.tracking.file_tracker.FileTracker" + assert config.model_dir_name == "models" # Default value + + def test_file_tracker_instance_config_init_with_custom_model_dir(self) -> None: + """Test FileTrackerInstanceConfig instantiation with custom model_dir_name.""" + config = FileTrackerInstanceConfig(file_path="test.log", model_dir_name="custom_models") + assert config.file_path == "test.log" + assert config.model_dir_name == "custom_models" + assert config._target_ == "simplexity.tracking.file_tracker.FileTracker" + + def test_validate_tracking_config_with_file_tracker(self) -> None: + """Test validate_tracking_config with FileTracker instance.""" + # Valid file tracker config + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "/tmp/test.log", + } + ) + } + ) + validate_tracking_config(cfg) + + # Missing file_path + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + } + ) + } + ) + with pytest.raises(ConfigValidationError, match="FileTrackerInstanceConfig.file_path must be a string"): + validate_tracking_config(cfg) + + # Empty file_path + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "", + } + ) + } + ) + with pytest.raises( + ConfigValidationError, match="FileTrackerInstanceConfig.file_path must be a non-empty string" + ): + validate_tracking_config(cfg) + + +class TestMlflowTrackerConfig: + """Test MlflowTracker configuration functions.""" + + def test_is_mlflow_tracker_target_valid(self) -> None: + """Test is_mlflow_tracker_target with valid target.""" + assert is_mlflow_tracker_target("simplexity.tracking.mlflow_tracker.MlflowTracker") + + def test_is_mlflow_tracker_target_invalid(self) -> None: + """Test is_mlflow_tracker_target with invalid targets.""" + assert not is_mlflow_tracker_target("simplexity.tracking.mlflow_tracker.MlflowTracker.from_config") + assert not is_mlflow_tracker_target("simplexity.tracking.file_tracker.FileTracker") + assert not is_mlflow_tracker_target("simplexity.tracking.s3_tracker.S3Tracker.from_config") + assert not is_mlflow_tracker_target("") + assert not is_mlflow_tracker_target("some.other.target") + + def test_is_mlflow_tracker_config_valid(self) -> None: + """Test is_mlflow_tracker_config with valid configs.""" + cfg = DictConfig({"_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker"}) + assert is_mlflow_tracker_config(cfg) + + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "experiment_name": "my_experiment", + "run_name": "my_run", + } + ) + assert is_mlflow_tracker_config(cfg) + + def test_is_mlflow_tracker_config_invalid(self) -> None: + """Test is_mlflow_tracker_config with invalid configs.""" + # Non-mlflow tracker target + cfg = DictConfig({"_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config"}) + assert not is_mlflow_tracker_config(cfg) + + # Missing _target_ + cfg = DictConfig({"experiment_name": "my_experiment", "run_name": "my_run"}) + assert not is_mlflow_tracker_config(cfg) + + # _target_ is None + cfg = DictConfig({"_target_": None}) + assert not is_mlflow_tracker_config(cfg) + + # _target_ is not a string + cfg = DictConfig({"_target_": 123}) + assert not is_mlflow_tracker_config(cfg) + + # Empty config + cfg = DictConfig({}) + assert not is_mlflow_tracker_config(cfg) + + def test_validate_mlflow_tracker_instance_config_valid(self) -> None: + """Test validate_mlflow_tracker_instance_config with valid configs.""" + # Valid config with minimal fields + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + } + ) + validate_mlflow_tracker_instance_config(cfg) # Should not raise + + # Valid config with all optional fields + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "experiment_id": "exp123", + "experiment_name": "my_experiment", + "run_id": "run456", + "run_name": "my_run", + "tracking_uri": "databricks", + "registry_uri": "databricks", + "downgrade_unity_catalog": True, + "model_dir": "models", + "config_path": "config.yaml", + } + ) + validate_mlflow_tracker_instance_config(cfg) # Should not raise + + # Valid config with None optional fields + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "experiment_id": None, + "experiment_name": None, + "run_id": None, + "run_name": None, + "tracking_uri": None, + "registry_uri": None, + } + ) + validate_mlflow_tracker_instance_config(cfg) # Should not raise + + def test_validate_mlflow_tracker_instance_config_invalid_target(self) -> None: + """Test validate_mlflow_tracker_instance_config raises with invalid target.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "experiment_name": "my_experiment", + } + ) + with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be"): + validate_mlflow_tracker_instance_config(cfg) + + def test_validate_mlflow_tracker_instance_config_empty_string_fields(self) -> None: + """Test validate_mlflow_tracker_instance_config raises with empty string fields.""" + # Empty experiment_id + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "experiment_id": "", + } + ) + with pytest.raises( + ConfigValidationError, + match="MlflowTrackerInstanceConfig.experiment_id must be a non-empty string", + ): + validate_mlflow_tracker_instance_config(cfg) + + # Empty experiment_name + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "experiment_name": "", + } + ) + with pytest.raises( + ConfigValidationError, + match="MlflowTrackerInstanceConfig.experiment_name must be a non-empty string", + ): + validate_mlflow_tracker_instance_config(cfg) + + # Empty model_dir + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "model_dir": "", + } + ) + with pytest.raises( + ConfigValidationError, match="MlflowTrackerInstanceConfig.model_dir must be a non-empty string" + ): + validate_mlflow_tracker_instance_config(cfg) + + def test_validate_mlflow_tracker_instance_config_invalid_uri(self) -> None: + """Test validate_mlflow_tracker_instance_config raises with invalid URIs.""" + # Invalid tracking_uri + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "tracking_uri": " ", + } + ) + with pytest.raises(ConfigValidationError, match="MlflowTrackerInstanceConfig.tracking_uri"): + validate_mlflow_tracker_instance_config(cfg) + + # Invalid registry_uri + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "registry_uri": "%parse_error%", + } + ) + with pytest.raises(ConfigValidationError, match="MlflowTrackerInstanceConfig.registry_uri"): + validate_mlflow_tracker_instance_config(cfg) + + def test_mlflow_tracker_instance_config_init(self) -> None: + """Test MlflowTrackerInstanceConfig instantiation.""" + config = MlflowTrackerInstanceConfig() + assert config._target_ == "simplexity.tracking.mlflow_tracker.MlflowTracker" + assert config.experiment_id is None + assert config.experiment_name is None + assert config.downgrade_unity_catalog is True + assert config.model_dir == "models" + assert config.config_path == "config.yaml" + + def test_mlflow_tracker_instance_config_init_with_fields(self) -> None: + """Test MlflowTrackerInstanceConfig instantiation with fields.""" + config = MlflowTrackerInstanceConfig( + experiment_name="my_experiment", + run_name="my_run", + tracking_uri="databricks", + ) + assert config.experiment_name == "my_experiment" + assert config.run_name == "my_run" + assert config.tracking_uri == "databricks" + assert config._target_ == "simplexity.tracking.mlflow_tracker.MlflowTracker" + + def test_validate_tracking_config_with_mlflow_tracker(self) -> None: + """Test validate_tracking_config with MlflowTracker instance.""" + # Valid mlflow tracker config + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "experiment_name": "my_experiment", + "run_name": "my_run", + } + ) + } + ) + validate_tracking_config(cfg) + + +class TestS3TrackerConfig: + """Test S3Tracker configuration functions.""" + + def test_is_s3_tracker_target_valid(self) -> None: + """Test is_s3_tracker_target with valid target.""" + assert is_s3_tracker_target("simplexity.tracking.s3_tracker.S3Tracker.from_config") + + def test_is_s3_tracker_target_invalid(self) -> None: + """Test is_s3_tracker_target with invalid targets.""" + assert not is_s3_tracker_target("simplexity.tracking.s3_tracker.S3Tracker") + assert not is_s3_tracker_target("simplexity.tracking.file_tracker.FileTracker") + assert not is_s3_tracker_target("simplexity.tracking.mlflow_tracker.MlflowTracker") + assert not is_s3_tracker_target("") + assert not is_s3_tracker_target("some.other.target") + + def test_is_s3_tracker_config_valid(self) -> None: + """Test is_s3_tracker_config with valid configs.""" + cfg = DictConfig({"_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config"}) + assert is_s3_tracker_config(cfg) + + cfg = DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + "prefix": "s3://bucket/prefix", + } + ) + assert is_s3_tracker_config(cfg) + + def test_is_s3_tracker_config_invalid(self) -> None: + """Test is_s3_tracker_config with invalid configs.""" + # Non-s3 tracker target + cfg = DictConfig({"_target_": "simplexity.tracking.file_tracker.FileTracker"}) + assert not is_s3_tracker_config(cfg) + + # Missing _target_ + cfg = DictConfig({"prefix": "s3://bucket/prefix"}) + assert not is_s3_tracker_config(cfg) + + # _target_ is None + cfg = DictConfig({"_target_": None}) + assert not is_s3_tracker_config(cfg) + + # _target_ is not a string + cfg = DictConfig({"_target_": 123}) + assert not is_s3_tracker_config(cfg) + + # Empty config + cfg = DictConfig({}) + assert not is_s3_tracker_config(cfg) + + def test_validate_s3_tracker_instance_config_valid(self) -> None: + """Test validate_s3_tracker_instance_config with valid configs.""" + # Valid config with required fields and config_filename + cfg = DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + "prefix": "s3://bucket/prefix", + "config_filename": "config.ini", + } + ) + validate_s3_tracker_instance_config(cfg) # Should not raise + + # Valid config with optional config_filename + cfg = DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + "prefix": "s3://bucket/prefix", + "config_filename": "custom.ini", + } + ) + validate_s3_tracker_instance_config(cfg) # Should not raise + + def test_validate_s3_tracker_instance_config_invalid_target(self) -> None: + """Test validate_s3_tracker_instance_config raises with invalid target.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "prefix": "s3://bucket/prefix", + } + ) + with pytest.raises(ConfigValidationError, match="InstanceConfig._target_ must be"): + validate_s3_tracker_instance_config(cfg) + + def test_validate_s3_tracker_instance_config_missing_prefix(self) -> None: + """Test validate_s3_tracker_instance_config raises when prefix is missing.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + } + ) + with pytest.raises(ConfigValidationError, match="S3TrackerInstanceConfig.prefix must be a string"): + validate_s3_tracker_instance_config(cfg) + + def test_validate_s3_tracker_instance_config_empty_prefix(self) -> None: + """Test validate_s3_tracker_instance_config raises when prefix is empty.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + "prefix": "", + } + ) + with pytest.raises(ConfigValidationError, match="S3TrackerInstanceConfig.prefix must be a non-empty string"): + validate_s3_tracker_instance_config(cfg) + + def test_validate_s3_tracker_instance_config_empty_config_filename(self) -> None: + """Test validate_s3_tracker_instance_config raises when config_filename is empty.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + "prefix": "s3://bucket/prefix", + "config_filename": "", + } + ) + with pytest.raises( + ConfigValidationError, + match="S3TrackerInstanceConfig.config_filename must be a non-empty string", + ): + validate_s3_tracker_instance_config(cfg) + + def test_s3_tracker_instance_config_init(self) -> None: + """Test S3TrackerInstanceConfig instantiation.""" + config = S3TrackerInstanceConfig(prefix="s3://bucket/prefix") + assert config.prefix == "s3://bucket/prefix" + assert config._target_ == "simplexity.tracking.s3_tracker.S3Tracker.from_config" + assert config.config_filename == "config.ini" # Default value + + def test_s3_tracker_instance_config_init_with_custom_config_filename(self) -> None: + """Test S3TrackerInstanceConfig instantiation with custom config_filename.""" + config = S3TrackerInstanceConfig(prefix="s3://bucket/prefix", config_filename="custom.ini") + assert config.prefix == "s3://bucket/prefix" + assert config.config_filename == "custom.ini" + assert config._target_ == "simplexity.tracking.s3_tracker.S3Tracker.from_config" + + def test_validate_tracking_config_with_s3_tracker(self) -> None: + """Test validate_tracking_config with S3Tracker instance.""" + # Valid s3 tracker config + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + "prefix": "s3://bucket/prefix", + "config_filename": "config.ini", + } + ) + } + ) + validate_tracking_config(cfg) + + # Missing prefix + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + } + ) + } + ) + with pytest.raises(ConfigValidationError, match="S3TrackerInstanceConfig.prefix must be a string"): + validate_tracking_config(cfg) + + # Empty prefix + cfg = DictConfig( + { + "instance": DictConfig( + { + "_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config", + "prefix": "", + } + ) + } + ) + with pytest.raises(ConfigValidationError, match="S3TrackerInstanceConfig.prefix must be a non-empty string"): + validate_tracking_config(cfg) + + +class TestRunTrackerConfig: + """Test run tracker configuration functions.""" + + def test_is_run_tracker_target_valid(self) -> None: + """Test is_run_tracker_target with valid targets.""" + assert is_run_tracker_target("simplexity.tracking.file_tracker.FileTracker") + assert is_run_tracker_target("simplexity.tracking.mlflow_tracker.MlflowTracker") + assert is_run_tracker_target("simplexity.tracking.s3_tracker.S3Tracker.from_config") + assert is_run_tracker_target("simplexity.tracking.any_tracker.AnyTracker") + + def test_is_run_tracker_target_invalid(self) -> None: + """Test is_run_tracker_target with invalid targets.""" + assert not is_run_tracker_target("simplexity.persistence.mlflow_persister.MLFlowPersister") + assert not is_run_tracker_target("torch.optim.Adam") + assert not is_run_tracker_target("") + assert not is_run_tracker_target("some.other.target") + assert not is_run_tracker_target("simplexity.logging.mlflow_logger.MLFlowLogger") + + def test_is_run_tracker_config_valid(self) -> None: + """Test is_run_tracker_config with valid configs.""" + cfg = DictConfig({"_target_": "simplexity.tracking.file_tracker.FileTracker"}) + assert is_run_tracker_config(cfg) + + cfg = DictConfig({"_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker"}) + assert is_run_tracker_config(cfg) + + cfg = DictConfig({"_target_": "simplexity.tracking.s3_tracker.S3Tracker.from_config"}) + assert is_run_tracker_config(cfg) + + def test_is_run_tracker_config_invalid(self) -> None: + """Test is_run_tracker_config with invalid configs.""" + # Non-tracker target + cfg = DictConfig({"_target_": "simplexity.persistence.mlflow_persister.MLFlowPersister"}) + assert not is_run_tracker_config(cfg) + + # Missing _target_ + cfg = DictConfig({"experiment_name": "my_experiment"}) + assert not is_run_tracker_config(cfg) + + # _target_ is None + cfg = DictConfig({"_target_": None}) + assert not is_run_tracker_config(cfg) + + # _target_ is not a string + cfg = DictConfig({"_target_": 123}) + assert not is_run_tracker_config(cfg) + + # Empty config + cfg = DictConfig({}) + assert not is_run_tracker_config(cfg) + + +class TestUpdateTrackingInstanceConfig: + """Test update_tracking_instance_config function.""" + + def test_update_tracking_instance_config(self) -> None: + """Test update_tracking_instance_config function.""" + # Initial config + cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "experiment_name": "exp1", + "run_name": "run1", + } + ) + + # Update config + updated_cfg = DictConfig( + { + "_target_": "simplexity.tracking.mlflow_tracker.MlflowTracker", + "experiment_name": "exp2", + "tracking_uri": "file:///tmp/mlruns", + } + ) + + update_tracking_instance_config(cfg, updated_cfg) + + assert cfg.experiment_name == "exp2" + assert cfg.run_name == "run1" # Should remain unchanged + assert cfg.tracking_uri == "file:///tmp/mlruns" + + def test_update_tracking_instance_config_overwrites(self) -> None: + """Test update_tracking_instance_config overwrites existing values.""" + cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "/tmp/old.log", + "model_dir_name": "old_models", + } + ) + + updated_cfg = DictConfig( + { + "_target_": "simplexity.tracking.file_tracker.FileTracker", + "file_path": "/tmp/new.log", + } + ) + + update_tracking_instance_config(cfg, updated_cfg) + + assert cfg.file_path == "/tmp/new.log" + assert cfg.model_dir_name == "old_models" # Should remain unchanged diff --git a/tests/persistence/test_local_equinox_persister.py b/tests/tracking/model_persistence/test_local_equinox_persister.py similarity index 90% rename from tests/persistence/test_local_equinox_persister.py rename to tests/tracking/model_persistence/test_local_equinox_persister.py index 04de3aeb..64b28fad 100644 --- a/tests/persistence/test_local_equinox_persister.py +++ b/tests/tracking/model_persistence/test_local_equinox_persister.py @@ -7,7 +7,9 @@ import jax import pytest -from simplexity.persistence.local_equinox_persister import LocalEquinoxPersister +from simplexity.tracking.model_persistence.local_equinox_persister import ( + LocalEquinoxPersister, +) def get_model(seed: int) -> eqx.Module: diff --git a/tests/persistence/test_local_penzai_persister.py b/tests/tracking/model_persistence/test_local_penzai_persister.py similarity index 87% rename from tests/persistence/test_local_penzai_persister.py rename to tests/tracking/model_persistence/test_local_penzai_persister.py index 068ee7ad..0b12d7bc 100644 --- a/tests/persistence/test_local_penzai_persister.py +++ b/tests/tracking/model_persistence/test_local_penzai_persister.py @@ -7,10 +7,18 @@ import pytest from penzai import pz from penzai.core.variables import UnboundVariableError -from penzai.models.transformer.variants.llamalike_common import LlamalikeTransformerConfig, build_llamalike_transformer +from penzai.models.transformer.variants.llamalike_common import ( + LlamalikeTransformerConfig, + build_llamalike_transformer, +) from penzai.nn.layer import Layer as PenzaiModel -from simplexity.persistence.local_penzai_persister import LocalPenzaiPersister +from simplexity.tracking.model_persistence.local_penzai_persister import ( + LocalPenzaiPersister, +) + +# Skip if penzai is not installed +pytest.importorskip("penzai") def test_local_penzai_persister(tmp_path: Path): diff --git a/tests/persistence/test_local_pytorch_persister.py b/tests/tracking/model_persistence/test_local_pytorch_persister.py similarity index 95% rename from tests/persistence/test_local_pytorch_persister.py rename to tests/tracking/model_persistence/test_local_pytorch_persister.py index 2d2dafbe..e2a4250e 100644 --- a/tests/persistence/test_local_pytorch_persister.py +++ b/tests/tracking/model_persistence/test_local_pytorch_persister.py @@ -5,7 +5,9 @@ import torch from torch.nn import GRU, Embedding, Linear, Module -from simplexity.persistence.local_pytorch_persister import LocalPytorchPersister +from simplexity.tracking.model_persistence.local_pytorch_persister import ( + LocalPytorchPersister, +) class SimpleLM(Module): diff --git a/tests/persistence/s3_mocks.py b/tests/tracking/s3_mocks.py similarity index 100% rename from tests/persistence/s3_mocks.py rename to tests/tracking/s3_mocks.py diff --git a/tests/tracking/test_file_tracker.py b/tests/tracking/test_file_tracker.py new file mode 100644 index 00000000..9a1a5fce --- /dev/null +++ b/tests/tracking/test_file_tracker.py @@ -0,0 +1,336 @@ +"""Test the file tracker.""" + +# pylint: disable=all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +import json +from pathlib import Path + +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import pytest +from omegaconf import DictConfig +from PIL import Image + +from simplexity.tracking.file_tracker import FileTracker + +EXPECTED_LOG = """Config: {'str_param': 'str_value', 'int_param': 1, 'float_param': 1.0, 'bool_param': True} +Config: {'str_param': 'str_value', 'int_param': 1, 'float_param': 1.0, 'bool_param': True} +Params: {'str_param': 'str_value', 'int_param': 1, 'float_param': 1.0, 'bool_param': True} +Tags: {'str_tag': 'str_value', 'int_tag': 1, 'float_tag': 1.0, 'bool_tag': True} +Metrics at step 1: {'int_metric': 1, 'float_metric': 1.0, 'jnp_metric': Array(0.1, dtype=float32, weak_type=True)} +""" + +EXPECTED_LOG_WITH_INTERPOLATION = ( + "Config: {'base_value': 'hello', 'interpolated_value': 'hello_world', 'nested': {'value': 'hello_nested'}}\n" +) + + +@pytest.fixture +def matplotlib_figure(): + """Create a reusable matplotlib figure for testing.""" + fig, ax = plt.subplots(figsize=(4, 3)) + ax.plot([1, 2, 3, 4], [1, 4, 2, 3]) + ax.set_title("Test Plot") + yield fig + plt.close(fig) + + +@pytest.fixture +def simple_matplotlib_figure(): + """Create a simple matplotlib figure for basic tests.""" + fig, ax = plt.subplots() + ax.plot([1, 2, 3]) + yield fig + plt.close(fig) + + +@pytest.fixture +def numpy_image(): + """Create a reusable numpy image array for testing.""" + return np.random.randint(0, 255, (80, 120, 3), dtype=np.uint8) + + +@pytest.fixture +def small_numpy_image(): + """Create a small numpy image array for testing.""" + return np.ones((50, 50, 3), dtype=np.uint8) * 100 + + +@pytest.fixture +def tiny_numpy_image(): + """Create a tiny numpy image array for testing.""" + return np.zeros((10, 10, 3), dtype=np.uint8) + + +@pytest.fixture +def pil_image(): + """Create a reusable PIL image for testing.""" + return Image.new("RGB", (100, 50), color="red") + + +@pytest.fixture +def test_artifact_file(tmp_path: Path): + """Create a test file to use as an artifact.""" + test_file = tmp_path / "source" / "test_artifact.txt" + test_file.parent.mkdir() + test_file.write_text("test content") + return test_file + + +@pytest.fixture +def test_artifact_directory(tmp_path: Path): + """Create a test directory to use as an artifact.""" + source_dir = tmp_path / "source_dir" + source_dir.mkdir() + (source_dir / "file1.txt").write_text("content1") + (source_dir / "file2.txt").write_text("content2") + return source_dir + + +@pytest.fixture +def sample_json_data(): + """Sample JSON data for testing.""" + return {"key": "value", "number": 42, "list": [1, 2, 3]} + + +@pytest.fixture +def sample_list_data(): + """Sample list data for testing.""" + return [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}] + + +def test_file_tracker(tmp_path: Path): + """Test FileTracker initialization.""" + tracker = FileTracker(str(tmp_path / "test.log")) + params = { + "str_param": "str_value", + "int_param": 1, + "float_param": 1.0, + "bool_param": True, + } + tracker.log_config(DictConfig(params)) + tracker.log_config(DictConfig(params), resolve=True) + tracker.log_params(params) + tags = { + "str_tag": "str_value", + "int_tag": 1, + "float_tag": 1.0, + "bool_tag": True, + } + tracker.log_tags(tags) + metrics = { + "int_metric": 1, + "float_metric": 1.0, + "jnp_metric": jnp.array(0.1), + } + tracker.log_metrics(1, metrics) + tracker.cleanup() + + with open(tmp_path / "test.log", encoding="utf-8") as f: + assert f.read() == EXPECTED_LOG + + +def test_file_tracker_with_interpolation(tmp_path: Path): + """Test that resolved config properly resolves interpolations.""" + tracker = FileTracker(str(tmp_path / "test.log")) + + # Create a config with interpolation + config_dict = { + "base_value": "hello", + "interpolated_value": "${base_value}_world", + "nested": { + "value": "${base_value}_nested", + }, + } + + config = DictConfig(config_dict) + tracker.log_config(config, resolve=True) + tracker.cleanup() + + with open(tmp_path / "test.log", encoding="utf-8") as f: + assert f.read() == EXPECTED_LOG_WITH_INTERPOLATION + + +def test_log_artifact_copies_file(test_artifact_file, tmp_path: Path): + """Test that log_artifact copies a file to the log directory.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_artifact(str(test_artifact_file)) + tracker.cleanup() + + copied_file = tmp_path / "test_artifact.txt" + assert copied_file.exists() + assert copied_file.read_text() == "test content" + + with open(tmp_path / "test.log", encoding="utf-8") as f: + log_content = f.read() + assert "Artifact copied:" in log_content + assert "test_artifact.txt" in log_content + + +def test_log_artifact_with_custom_path(tmp_path: Path): + """Test log_artifact with custom artifact path.""" + tracker = FileTracker(str(tmp_path / "test.log")) + test_file = tmp_path / "source.txt" + test_file.write_text("content") + + tracker.log_artifact(str(test_file), "custom/path/dest.txt") + tracker.cleanup() + + copied_file = tmp_path / "custom" / "path" / "dest.txt" + assert copied_file.exists() + assert copied_file.read_text() == "content" + + +def test_log_artifact_directory(test_artifact_directory, tmp_path: Path): + """Test that log_artifact can copy entire directories.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_artifact(str(test_artifact_directory), "copied_dir") + tracker.cleanup() + + copied_dir = tmp_path / "copied_dir" + assert copied_dir.is_dir() + assert (copied_dir / "file1.txt").read_text() == "content1" + assert (copied_dir / "file2.txt").read_text() == "content2" + + +def test_log_json_artifact_saves_json(sample_json_data, tmp_path: Path): + """Test that log_json_artifact saves JSON data.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_json_artifact(sample_json_data, "results.json") + tracker.cleanup() + + json_file = tmp_path / "results.json" + assert json_file.exists() + + with open(json_file, encoding="utf-8") as f: + loaded_data = json.load(f) + assert loaded_data == sample_json_data + + with open(tmp_path / "test.log", encoding="utf-8") as f: + log_content = f.read() + assert "JSON artifact saved:" in log_content + assert "results.json" in log_content + + +def test_log_json_artifact_with_list(sample_list_data, tmp_path: Path): + """Test log_json_artifact with list data.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_json_artifact(sample_list_data, "data_list.json") + tracker.cleanup() + + json_file = tmp_path / "data_list.json" + assert json_file.exists() + + with open(json_file, encoding="utf-8") as f: + loaded_data = json.load(f) + assert loaded_data == sample_list_data + + +def test_log_figure_saves_matplotlib_plot(matplotlib_figure, tmp_path: Path): + """Test that log_figure saves a matplotlib figure to disk.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_figure(matplotlib_figure, "test_plot.png") + tracker.cleanup() + + with Image.open(tmp_path / "test_plot.png") as img: + assert img.size == (400, 300) + + with open(tmp_path / "test.log", encoding="utf-8") as f: + log_content = f.read() + assert "Figure saved:" in log_content + assert "test_plot.png" in log_content + + +def test_log_figure_with_kwargs(simple_matplotlib_figure, tmp_path: Path): + """Test that log_figure passes kwargs to matplotlib savefig.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_figure(simple_matplotlib_figure, "high_dpi.png", dpi=200, bbox_inches="tight") + tracker.cleanup() + + with Image.open(tmp_path / "high_dpi.png") as img: + assert img.size[0] >= 800 + + +def test_log_image_pil_artifact_mode(pil_image, tmp_path: Path): + """Test logging PIL Image in artifact mode.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_image(pil_image, artifact_file="pil_test.png") + tracker.cleanup() + + with Image.open(tmp_path / "pil_test.png") as img: + assert img.size == (100, 50) + + with open(tmp_path / "test.log", encoding="utf-8") as f: + log_content = f.read() + assert "Image saved:" in log_content + assert "pil_test.png" in log_content + + +def test_log_image_numpy_artifact_mode(numpy_image, tmp_path: Path): + """Test logging numpy array in artifact mode.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_image(numpy_image, artifact_file="numpy_test.png") + tracker.cleanup() + + with Image.open(tmp_path / "numpy_test.png") as img: + assert img.size == (120, 80) + + +def test_log_image_time_stepped_mode(small_numpy_image, tmp_path: Path): + """Test logging image in time-stepped mode with key and step.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_image(small_numpy_image, key="training_viz", step=42) + tracker.cleanup() + + with Image.open(tmp_path / "training_viz_step_42.png") as img: + assert img.size == (50, 50) + + with open(tmp_path / "test.log", encoding="utf-8") as f: + log_content = f.read() + assert "Time-stepped image saved:" in log_content + assert "training_viz_step_42.png" in log_content + + +def test_log_image_unsupported_type(tmp_path: Path): + """Test logging unsupported image type logs error.""" + tracker = FileTracker(str(tmp_path / "test.log")) + unsupported_image = "not an image" + + tracker.log_image(unsupported_image, artifact_file="bad.png") # type: ignore[arg-type] + tracker.cleanup() + + assert not (tmp_path / "bad.png").exists() + + with open(tmp_path / "test.log", encoding="utf-8") as f: + log_content = f.read() + assert "not supported for file saving" in log_content + + +def test_log_image_missing_parameters_fails(tiny_numpy_image, tmp_path: Path): + """Test that log_image without proper parameters logs error.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_image(tiny_numpy_image, key="incomplete") # missing step + tracker.cleanup() + + with open(tmp_path / "test.log", encoding="utf-8") as f: + log_content = f.read() + assert "Image logging failed" in log_content + + +def test_log_image_no_parameters_fails(tiny_numpy_image, tmp_path: Path): + """Test that log_image with no parameters logs error.""" + tracker = FileTracker(str(tmp_path / "test.log")) + tracker.log_image(tiny_numpy_image) # Neither artifact_file nor key+step + tracker.cleanup() + + with open(tmp_path / "test.log", encoding="utf-8") as f: + log_content = f.read() + assert "Image logging failed - need either artifact_file or (key + step)" in log_content diff --git a/tests/persistence/test_mlflow_persister.py b/tests/tracking/test_mlflow_tracker.py similarity index 59% rename from tests/persistence/test_mlflow_persister.py rename to tests/tracking/test_mlflow_tracker.py index f44d1a2a..1798554f 100644 --- a/tests/persistence/test_mlflow_persister.py +++ b/tests/tracking/test_mlflow_tracker.py @@ -1,4 +1,13 @@ -"""Integration-style tests for MLFlowPersister with a local MLflow backend.""" +"""Integration-style tests for MlflowTracker with a local MLflow backend.""" + +# pylint: disable=all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. from __future__ import annotations @@ -16,20 +25,26 @@ from mlflow.models import infer_signature from torch.nn import Linear, Module -from simplexity.persistence.mlflow_persister import MLFlowPersister from simplexity.predictive_models.types import ModelFramework +from simplexity.tracking.mlflow_tracker import MlflowTracker from simplexity.utils.mlflow_utils import set_mlflow_uris -def _get_artifacts_root(persister: MLFlowPersister) -> Path: - """Get the artifacts root directory for the given persister.""" - assert persister.tracking_uri is not None - tracking_dir = Path(persister.tracking_uri.replace("file://", "")) - experiment_id = persister.experiment_id - run_id = persister.run_id +def _get_artifacts_root(tracker: MlflowTracker) -> Path: + """Get the artifacts root directory for the given tracker.""" + assert tracker.tracking_uri is not None + tracking_dir = Path(tracker.tracking_uri.replace("file://", "")) + experiment_id = tracker.experiment_id + run_id = tracker.run_id return tracking_dir / experiment_id / run_id / "artifacts" +def _get_pytorch_model(seed: int) -> Linear: + """Build a small deterministic PyTorch model for serialization tests.""" + torch.manual_seed(seed) + return Linear(in_features=4, out_features=2) + + def _pytorch_models_equal(model1: Module, model2: Module) -> bool: """Check if two PyTorch models have identical parameters.""" params1 = dict(model1.named_parameters()) @@ -55,23 +70,23 @@ def _models_equal(model1: Module | eqx.Module, model2: Module | eqx.Module) -> b @pytest.fixture -def persister(tmp_path: Path) -> Generator[MLFlowPersister, None, None]: - """Get a MLFlowPersister instance.""" +def tracker(tmp_path: Path) -> Generator[MlflowTracker, None, None]: + """Get a MlflowTracker instance.""" artifact_dir = tmp_path / "mlruns" artifact_dir.mkdir() - persister = MLFlowPersister( + tracker = MlflowTracker( experiment_name="test-experiment", run_name="test-run", tracking_uri=artifact_dir.as_uri(), registry_uri=artifact_dir.as_uri(), model_dir="models", ) - yield persister - persister.cleanup() + yield tracker + tracker.cleanup() @pytest.mark.parametrize("framework", [ModelFramework.PYTORCH, ModelFramework.EQUINOX]) -def test_round_trip(persister: MLFlowPersister, framework: ModelFramework) -> None: +def test_round_trip(tracker: MlflowTracker, framework: ModelFramework) -> None: """PyTorch model weights saved via MLflow can be restored back into a new instance.""" if framework == ModelFramework.PYTORCH: @@ -87,58 +102,39 @@ def test_round_trip(persister: MLFlowPersister, framework: ModelFramework) -> No else: raise ValueError(f"Unsupported model framework: {framework}") - persister.save_weights(original, step=0) + tracker.save_model(original, step=0) - remote_model_path = _get_artifacts_root(persister) / persister.model_dir / "0" / model_filename + remote_model_path = _get_artifacts_root(tracker) / tracker.model_dir / "0" / model_filename assert remote_model_path.exists() assert not _models_equal(original, updated) - loaded = persister.load_weights(updated, step=0) + loaded = tracker.load_model(updated, step=0) assert _models_equal(loaded, original) @pytest.mark.parametrize("framework", [ModelFramework.PYTORCH, ModelFramework.EQUINOX]) -def test_round_trip_from_config(persister: MLFlowPersister, framework: ModelFramework) -> None: +def test_round_trip_from_config(tracker: MlflowTracker, framework: ModelFramework) -> None: """PyTorch model weights saved via MLflow can be restored back into a new instance via the config.""" - if framework == ModelFramework.PYTORCH: + torch.manual_seed(0) original = Linear(in_features=4, out_features=2) - config = { - "predictive_model": { - "instance": { - "_target_": "torch.nn.Linear", - "in_features": 4, - "out_features": 2, - } - } - } + torch.manual_seed(1) + updated = Linear(in_features=4, out_features=2) model_filename = "model.pt" elif framework == ModelFramework.EQUINOX: original = eqx.nn.Linear(in_features=4, out_features=2, key=jax.random.key(0)) - config = { - "predictive_model": { - "instance": { - "_target_": "equinox.nn.Linear", - "in_features": 4, - "out_features": 2, - "key": {"_target_": "jax.random.key", "seed": 0}, - } - } - } + updated = eqx.nn.Linear(in_features=4, out_features=2, key=jax.random.key(1)) model_filename = "model.eqx" else: raise ValueError(f"Unsupported model framework: {framework}") - persister.save_weights(original, step=0) + tracker.save_model(original, step=0) - remote_model_path = _get_artifacts_root(persister) / persister.model_dir / "0" / model_filename + remote_model_path = _get_artifacts_root(tracker) / tracker.model_dir / "0" / model_filename assert remote_model_path.exists() - config_path = _get_artifacts_root(persister) / "config.yaml" - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(config, f) - - loaded = persister.load_model(step=0) + assert not _models_equal(original, updated) + loaded = tracker.load_model(updated, step=0) assert _models_equal(loaded, original) @@ -147,16 +143,16 @@ def test_round_trip_from_config(persister: MLFlowPersister, framework: ModelFram [ModelFramework.PYTORCH, ModelFramework.EQUINOX], ) def test_cleanup(tmp_path: Path, framework: ModelFramework) -> None: - """Test PyTorch model cleanup with MLflow persister.""" + """Test PyTorch model cleanup with MLflow tracker.""" artifact_dir = tmp_path / "mlruns" artifact_dir.mkdir() - persister = MLFlowPersister(experiment_name="pytorch-cleanup", tracking_uri=artifact_dir.as_uri()) + tracker = MlflowTracker(experiment_name="pytorch-cleanup", tracking_uri=artifact_dir.as_uri()) def run_status() -> str: """Get the status of the run.""" - client = persister.client - run_id = persister.run_id + client = tracker.client + run_id = tracker.run_id run = client.get_run(run_id) return run.info.status @@ -169,11 +165,11 @@ def run_status() -> str: else: raise ValueError(f"Unsupported model framework: {framework}") - persister.save_weights(model, step=0) - local_persister = persister.get_local_persister(model) + tracker.save_model(model, step=0) + local_persister = tracker.get_local_persister(model) assert local_persister.directory.exists() - persister.cleanup() + tracker.cleanup() assert run_status() == "FINISHED" assert not local_persister.directory.exists() @@ -192,30 +188,30 @@ def mock_create_requirements_file(tmp_path: Path) -> Generator[str, None, None]: requirements_path = tmp_path / "requirements.txt" requirements_path.write_text(REQUIREMENTS_CONTENT, encoding="utf-8") - with patch("simplexity.persistence.mlflow_persister.create_requirements_file") as mock_create: + with patch("simplexity.tracking.mlflow_tracker.create_requirements_file") as mock_create: mock_create.return_value = str(requirements_path) yield mock_create @pytest.mark.usefixtures("mock_create_requirements_file") -def test_save_model_to_registry(persister: MLFlowPersister) -> None: +def test_save_model_to_registry(tracker: MlflowTracker) -> None: """Test saving a PyTorch model to the MLflow model registry.""" model = Linear(in_features=4, out_features=2) registered_model_name = "test_model" - model_info = persister.save_model_to_registry(model, registered_model_name) + model_info = tracker.save_model_to_registry(model, registered_model_name) - model_versions = persister.client.search_model_versions( + model_versions = tracker.client.search_model_versions( filter_string=f"name='{registered_model_name}'", max_results=10 ) assert len(model_versions) == 1 assert model_versions[0].run_id == model_info.run_id assert model_versions[0].version == model_info.registered_model_version - assert persister.registry_uri is not None - registry_dir = Path(persister.registry_uri.replace("file://", "")) - models_meta_path = registry_dir / persister.model_dir / registered_model_name / "version-1" / "meta.yaml" + assert tracker.registry_uri is not None + registry_dir = Path(tracker.registry_uri.replace("file://", "")) + models_meta_path = registry_dir / tracker.model_dir / registered_model_name / "version-1" / "meta.yaml" assert models_meta_path.exists() with open(models_meta_path, encoding="utf-8") as f: models_meta = yaml.load(f, Loader=yaml.FullLoader) @@ -232,58 +228,58 @@ def test_save_model_to_registry(persister: MLFlowPersister) -> None: requirements_content = f.read() assert requirements_content == REQUIREMENTS_CONTENT.rstrip() - persister.cleanup() + tracker.cleanup() @pytest.mark.usefixtures("mock_create_requirements_file") -def test_save_model_to_registry_with_matching_active_run(persister: MLFlowPersister) -> None: +def test_save_model_to_registry_with_matching_active_run(tracker: MlflowTracker) -> None: """save_model_to_registry should reuse an already active run with the same id.""" model = Linear(in_features=4, out_features=2) model_info = None with ( - set_mlflow_uris(tracking_uri=persister.tracking_uri, registry_uri=persister.registry_uri), - mlflow.start_run(run_id=persister.run_id), + set_mlflow_uris(tracking_uri=tracker.tracking_uri, registry_uri=tracker.registry_uri), + mlflow.start_run(run_id=tracker.run_id), ): - model_info = persister.save_model_to_registry(model, "test_model_active_run") + model_info = tracker.save_model_to_registry(model, "test_model_active_run") assert model_info is not None - assert model_info.run_id == persister.run_id + assert model_info.run_id == tracker.run_id @pytest.mark.usefixtures("mock_create_requirements_file") -def test_save_model_to_registry_with_mismatched_active_run(persister: MLFlowPersister) -> None: +def test_save_model_to_registry_with_mismatched_active_run(tracker: MlflowTracker) -> None: """save_model_to_registry should fail when another run is active.""" model = Linear(in_features=4, out_features=2) with ( - set_mlflow_uris(tracking_uri=persister.tracking_uri, registry_uri=persister.registry_uri), - mlflow.start_run(experiment_id=persister.experiment_id) as active_run, + set_mlflow_uris(tracking_uri=tracker.tracking_uri, registry_uri=tracker.registry_uri), + mlflow.start_run(experiment_id=tracker.experiment_id) as active_run, ): - assert active_run.info.run_id != persister.run_id + assert active_run.info.run_id != tracker.run_id with pytest.raises(RuntimeError, match="Cannot save model to registry"): - persister.save_model_to_registry(model, "test_model_mismatched_run") + tracker.save_model_to_registry(model, "test_model_mismatched_run") -def test_save_model_to_registry_with_no_requirements(persister: MLFlowPersister) -> None: +def test_save_model_to_registry_with_no_requirements(tracker: MlflowTracker) -> None: """Test saving a PyTorch model to the MLflow model registry.""" model = Linear(in_features=4, out_features=2) registered_model_name = "test_model" - assert persister.tracking_uri is not None - pyproject_path = Path(persister.tracking_uri.replace("file://", "")).parent / "pyproject.toml" + assert tracker.tracking_uri is not None + pyproject_path = Path(tracker.tracking_uri.replace("file://", "")).parent / "pyproject.toml" - with patch("simplexity.persistence.mlflow_persister.create_requirements_file") as mock_create: + with patch("simplexity.tracking.mlflow_tracker.create_requirements_file") as mock_create: mock_create.side_effect = FileNotFoundError(f"pyproject.toml not found at {pyproject_path}") - model_info = persister.save_model_to_registry(model, registered_model_name) + model_info = tracker.save_model_to_registry(model, registered_model_name) assert model_info is not None - persister.cleanup() + tracker.cleanup() @pytest.mark.usefixtures("mock_create_requirements_file") -def test_save_model_to_registry_with_model_inputs(persister: MLFlowPersister) -> None: +def test_save_model_to_registry_with_model_inputs(tracker: MlflowTracker) -> None: """Test saving a PyTorch model to registry with model inputs for automatic signature inference.""" model = Linear(in_features=4, out_features=2) @@ -291,13 +287,13 @@ def test_save_model_to_registry_with_model_inputs(persister: MLFlowPersister) -> sample_input = torch.randn(2, 4) - model_info = persister.save_model_to_registry(model, registered_model_name, model_inputs=sample_input) + model_info = tracker.save_model_to_registry(model, registered_model_name, model_inputs=sample_input) assert model_info.signature is not None, "Registered model should have a signature when model_inputs is provided" - persister.cleanup() + tracker.cleanup() -def test_save_model_to_registry_non_pytorch(persister: MLFlowPersister) -> None: +def test_save_model_to_registry_non_pytorch(tracker: MlflowTracker) -> None: """Test saving a non-PyTorch model to the MLflow model registry.""" registered_model_name = "test_non_pytorch_model" @@ -308,11 +304,11 @@ def test_save_model_to_registry_non_pytorch(persister: MLFlowPersister) -> None: ValueError, match=r"Model must be a PyTorch model \(torch\.nn\.Module\), got ", ): - persister.save_model_to_registry(equinox_model, registered_model_name) + tracker.save_model_to_registry(equinox_model, registered_model_name) @pytest.mark.usefixtures("mock_create_requirements_file") -def test_save_model_to_registry_with_signature(persister: MLFlowPersister) -> None: +def test_save_model_to_registry_with_signature(tracker: MlflowTracker) -> None: """Test saving a PyTorch model to the MLflow model registry with a signature.""" model = Linear(in_features=4, out_features=2) @@ -321,69 +317,69 @@ def test_save_model_to_registry_with_signature(persister: MLFlowPersister) -> No signature_data = {"some_key": "some_value", "some_other_key": "some_other_value"} signature = infer_signature(signature_data) - with patch("simplexity.persistence.mlflow_persister.SIMPLEXITY_LOGGER.warning") as mock_warning: - model_info = persister.save_model_to_registry( + with patch("simplexity.tracking.mlflow_tracker.SIMPLEXITY_LOGGER.warning") as mock_warning: + model_info = tracker.save_model_to_registry( model, registered_model_name, model_inputs=sample_input, signature=signature ) mock_warning.assert_called_once_with("Signature provided in kwargs, ignoring inferred signature") assert model_info.signature == signature - persister.cleanup() + tracker.cleanup() -def test_model_registry_round_trip(persister: MLFlowPersister) -> None: +def test_model_registry_round_trip(tracker: MlflowTracker) -> None: """Test loading a PyTorch model from the MLflow model registry.""" original = Linear(in_features=4, out_features=2) registered_model_name = "test_load_model" - persister.save_model_to_registry(original, registered_model_name) + tracker.save_model_to_registry(original, registered_model_name) - loaded = persister.load_model_from_registry(registered_model_name) + loaded = tracker.load_model_from_registry(registered_model_name) assert _pytorch_models_equal(loaded, original) -def test_load_model_from_registry_multiple_versions(persister: MLFlowPersister) -> None: +def test_load_model_from_registry_multiple_versions(tracker: MlflowTracker) -> None: """Test loading different versions of a model from the registry.""" registered_model_name = "test_model" torch.manual_seed(0) model_v1 = Linear(in_features=4, out_features=2) - model_v1_info = persister.save_model_to_registry(model_v1, registered_model_name) + model_v1_info = tracker.save_model_to_registry(model_v1, registered_model_name) torch.manual_seed(1) model_v2 = Linear(in_features=4, out_features=2) - model_v2_info = persister.save_model_to_registry(model_v2, registered_model_name) + model_v2_info = tracker.save_model_to_registry(model_v2, registered_model_name) assert not _pytorch_models_equal(model_v1, model_v2) # Load version 1 - loaded_v1 = persister.load_model_from_registry( + loaded_v1 = tracker.load_model_from_registry( registered_model_name, version=str(model_v1_info.registered_model_version) ) assert _pytorch_models_equal(loaded_v1, model_v1) # Load version 2 - loaded_v2 = persister.load_model_from_registry( + loaded_v2 = tracker.load_model_from_registry( registered_model_name, version=str(model_v2_info.registered_model_version) ) assert _pytorch_models_equal(loaded_v2, model_v2) # Load latest (should be version 2) - loaded_latest = persister.load_model_from_registry(registered_model_name) + loaded_latest = tracker.load_model_from_registry(registered_model_name) assert _pytorch_models_equal(loaded_latest, model_v2) -def test_load_model_from_registry_with_stage(persister: MLFlowPersister) -> None: +def test_load_model_from_registry_with_stage(tracker: MlflowTracker) -> None: """Test loading a model from the registry with a stage.""" registered_model_name = "test_model" torch.manual_seed(0) model_prod = Linear(in_features=4, out_features=2) - model_prod_info = persister.save_model_to_registry(model_prod, registered_model_name) - persister.client.transition_model_version_stage( + model_prod_info = tracker.save_model_to_registry(model_prod, registered_model_name) + tracker.client.transition_model_version_stage( name=registered_model_name, version=str(model_prod_info.registered_model_version), stage="Production", @@ -391,8 +387,8 @@ def test_load_model_from_registry_with_stage(persister: MLFlowPersister) -> None torch.manual_seed(1) model_stage = Linear(in_features=4, out_features=2) - model_stage_info = persister.save_model_to_registry(model_stage, registered_model_name) - persister.client.transition_model_version_stage( + model_stage_info = tracker.save_model_to_registry(model_stage, registered_model_name) + tracker.client.transition_model_version_stage( name=registered_model_name, version=str(model_stage_info.registered_model_version), stage="Staging", @@ -400,41 +396,59 @@ def test_load_model_from_registry_with_stage(persister: MLFlowPersister) -> None assert not _pytorch_models_equal(model_prod, model_stage) - loaded_prod = persister.load_model_from_registry(registered_model_name, stage="Production") + loaded_prod = tracker.load_model_from_registry(registered_model_name, stage="Production") assert _pytorch_models_equal(loaded_prod, model_prod) - loaded_stage = persister.load_model_from_registry(registered_model_name, stage="Staging") + loaded_stage = tracker.load_model_from_registry(registered_model_name, stage="Staging") assert _pytorch_models_equal(loaded_stage, model_stage) -def test_load_model_from_registry_no_registered_model(persister: MLFlowPersister) -> None: +def test_load_model_from_registry_no_registered_model(tracker: MlflowTracker) -> None: """Test that loading a non-existent version raises an error.""" with pytest.raises(RuntimeError, match="No versions found for registered model 'model_name'"): - persister.load_model_from_registry(registered_model_name="model_name") + tracker.load_model_from_registry(registered_model_name="model_name") -def test_load_model_from_registry_both_version_and_stage(persister: MLFlowPersister) -> None: +def test_load_model_from_registry_both_version_and_stage(tracker: MlflowTracker) -> None: """Test that specifying both version and stage raises an error.""" with pytest.raises(ValueError, match="Cannot specify both version and stage. Use one or the other."): - persister.load_model_from_registry(registered_model_name="model_name", version="1", stage="Production") + tracker.load_model_from_registry(registered_model_name="model_name", version="1", stage="Production") -def test_list_model_versions(persister: MLFlowPersister) -> None: +@pytest.mark.usefixtures("mock_create_requirements_file") +def test_list_model_versions(tmp_path: Path) -> None: """Test listing model versions from the registry.""" + artifact_dir = tmp_path / "mlruns" + artifact_dir.mkdir() registered_model_name = "test_list_model" - versions = persister.list_model_versions(registered_model_name) + tracker = MlflowTracker( + experiment_name="registry-list", + run_name="registry-list-run", + tracking_uri=artifact_dir.as_uri(), + registry_uri=artifact_dir.as_uri(), + ) + + versions = tracker.list_model_versions(registered_model_name) assert len(versions) == 0 for version_number in range(1, 4): - torch.manual_seed(version_number) - model = Linear(in_features=4, out_features=2) - persister.save_model_to_registry(model, registered_model_name) + tracker = MlflowTracker( + experiment_name="registry-list", + run_name=f"registry-list-run-{version_number}", + tracking_uri=artifact_dir.as_uri(), + registry_uri=artifact_dir.as_uri(), + ) - versions = persister.list_model_versions(registered_model_name) + model = _get_pytorch_model(version_number) + tracker.save_model_to_registry(model, registered_model_name) + + versions = tracker.list_model_versions(registered_model_name) assert len(versions) == version_number version_numbers = {v["version"] for v in versions} - assert version_numbers == {v + 1 for v in range(version_number)} + assert version_numbers == {v for v in range(1, version_number + 1)} + + tracker.cleanup() diff --git a/tests/tracking/test_mlflow_tracker_artifacts.py b/tests/tracking/test_mlflow_tracker_artifacts.py new file mode 100644 index 00000000..bac5e2c0 --- /dev/null +++ b/tests/tracking/test_mlflow_tracker_artifacts.py @@ -0,0 +1,96 @@ +"""Tests for artifact logging functionality for MlflowTracker.""" + +# pylint: disable=all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +import os +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +from simplexity.tracking.mlflow_tracker import MlflowTracker + + +@pytest.fixture +def sample_json_data(): + """Sample JSON data for testing.""" + return {"key": "value", "number": 42, "list": [1, 2, 3]} + + +class TestMlflowTrackerArtifacts: + """Tests for MlflowTracker artifact logging.""" + + @pytest.fixture(autouse=True) + def setup_mlflow_temp_dir(self): + """Set up temporary directory for MLflow tracking during tests.""" + with tempfile.TemporaryDirectory() as tmp_dir: + original_uri = os.environ.get("MLFLOW_TRACKING_URI") + os.environ["MLFLOW_TRACKING_URI"] = f"file://{tmp_dir}" + try: + yield + finally: + if original_uri is not None: + os.environ["MLFLOW_TRACKING_URI"] = original_uri + else: + os.environ.pop("MLFLOW_TRACKING_URI", None) + + @pytest.fixture + def mock_mlflow_tracker(self): + """Create a mocked MlflowTracker for testing.""" + with patch("mlflow.MlflowClient") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + # Mock experiment and run retrieval/creation + mock_experiment = MagicMock() + mock_experiment.experiment_id = "exp_123" + mock_experiment.name = "test_experiment" + mock_client.get_experiment_by_name.return_value = None + mock_client.create_experiment.return_value = "exp_123" + mock_client.get_experiment.return_value = mock_experiment # For get_experiment helper + + mock_run = MagicMock() + mock_run.info.run_id = "run_456" + mock_run.info.run_name = "test_run" + mock_client.create_run.return_value = mock_run + mock_client.get_run.return_value = mock_run + + tracker = MlflowTracker(experiment_name="test_experiment") + yield tracker, mock_client + + def test_log_artifact_calls_client(self, mock_mlflow_tracker): + """Test that log_artifact calls the MLflow client correctly.""" + tracker, mock_client = mock_mlflow_tracker + + tracker.log_artifact("/path/to/file.txt", "artifacts/file.txt") + tracker.cleanup() + + mock_client.log_artifact.assert_called_once_with("run_456", "/path/to/file.txt", "artifacts/file.txt") + + def test_log_artifact_without_artifact_path(self, mock_mlflow_tracker): + """Test log_artifact without custom artifact path.""" + tracker, mock_client = mock_mlflow_tracker + + tracker.log_artifact("/path/to/model.pkl") + tracker.cleanup() + + mock_client.log_artifact.assert_called_once_with("run_456", "/path/to/model.pkl", None) + + def test_log_json_artifact_calls_client(self, mock_mlflow_tracker, sample_json_data): + """Test that log_json_artifact creates temp file and calls client.""" + tracker, mock_client = mock_mlflow_tracker + + tracker.log_json_artifact(sample_json_data, "metrics.json") + tracker.cleanup() + + mock_client.log_artifact.assert_called_once() + call_args = mock_client.log_artifact.call_args + assert call_args[0][0] == "run_456" # run_id + assert call_args[0][1].endswith("metrics.json") # temp file path + assert len(call_args[0]) == 2 diff --git a/tests/tracking/test_mlflow_tracker_plots.py b/tests/tracking/test_mlflow_tracker_plots.py new file mode 100644 index 00000000..d05269ec --- /dev/null +++ b/tests/tracking/test_mlflow_tracker_plots.py @@ -0,0 +1,144 @@ +"""Tests for plot and image logging functionality for MlflowTracker.""" + +# pylint: disable=all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +import os +import tempfile +from unittest.mock import MagicMock, patch + +import matplotlib.pyplot as plt +import numpy as np +import pytest +from PIL import Image + +from simplexity.tracking.mlflow_tracker import MlflowTracker + + +@pytest.fixture +def simple_matplotlib_figure(): + """Create a simple matplotlib figure for basic tests.""" + fig, ax = plt.subplots() + ax.plot([1, 2, 3]) + yield fig + plt.close(fig) + + +@pytest.fixture +def tiny_numpy_image(): + """Create a tiny numpy image array for testing.""" + return np.zeros((10, 10, 3), dtype=np.uint8) + + +@pytest.fixture +def larger_pil_image(): + """Create a larger PIL image for testing.""" + return Image.new("RGB", (20, 20)) + + +class TestMlflowTrackerPlotting: + """Tests for MlflowTracker figure and image logging.""" + + @pytest.fixture(autouse=True) + def setup_mlflow_temp_dir(self): + """Set up temporary directory for MLflow tracking during tests.""" + with tempfile.TemporaryDirectory() as tmp_dir: + original_uri = os.environ.get("MLFLOW_TRACKING_URI") + os.environ["MLFLOW_TRACKING_URI"] = f"file://{tmp_dir}" + try: + yield + finally: + if original_uri is not None: + os.environ["MLFLOW_TRACKING_URI"] = original_uri + else: + os.environ.pop("MLFLOW_TRACKING_URI", None) + + @patch("mlflow.MlflowClient") + def test_log_figure_calls_client_method(self, mock_client_class, simple_matplotlib_figure): + """Test that log_figure calls the MLflow client correctly.""" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Mock experiment retrieval/creation + mock_experiment = MagicMock() + mock_experiment.experiment_id = "exp_123" + mock_experiment.name = "test_experiment" + mock_client.get_experiment_by_name.return_value = None + mock_client.create_experiment.return_value = "exp_123" + mock_client.get_experiment.return_value = mock_experiment + + mock_client.search_runs.return_value = [] + mock_run = MagicMock() + mock_run.info.run_id = "run_456" + mock_run.info.run_name = "test_run" + mock_client.create_run.return_value = mock_run + mock_client.get_run.return_value = mock_run + + tracker = MlflowTracker(experiment_name="test_experiment", run_name="test_run") + + tracker.log_figure(simple_matplotlib_figure, "test.png", dpi=150) + tracker.cleanup() + + mock_client.log_figure.assert_called_once_with("run_456", simple_matplotlib_figure, "test.png", dpi=150) + + @patch("mlflow.MlflowClient") + def test_log_image_artifact_mode_calls_client(self, mock_client_class, tiny_numpy_image): + """Test log_image in artifact mode calls MLflow client.""" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_experiment = MagicMock() + mock_experiment.experiment_id = "exp_123" + mock_experiment.name = "test_experiment" + mock_client.get_experiment_by_name.return_value = None + mock_client.create_experiment.return_value = "exp_123" + mock_client.get_experiment.return_value = mock_experiment + + mock_client.search_runs.return_value = [] + mock_run = MagicMock() + mock_run.info.run_id = "run_456" + mock_run.info.run_name = "test_run" + mock_client.create_run.return_value = mock_run + mock_client.get_run.return_value = mock_run + + tracker = MlflowTracker(experiment_name="test_experiment") + + tracker.log_image(tiny_numpy_image, artifact_file="image.png") + tracker.cleanup() + + mock_client.log_image.assert_called_once_with( + "run_456", tiny_numpy_image, artifact_file="image.png", key=None, step=None + ) + + @patch("mlflow.MlflowClient") + def test_log_image_time_stepped_mode_calls_client(self, mock_client_class, larger_pil_image): + """Test log_image in time-stepped mode calls MLflow client.""" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_experiment = MagicMock() + mock_experiment.experiment_id = "exp_123" + mock_experiment.name = "test_experiment" + mock_client.get_experiment_by_name.return_value = None + mock_client.create_experiment.return_value = "exp_123" + mock_client.get_experiment.return_value = mock_experiment + + mock_client.search_runs.return_value = [] + mock_run = MagicMock() + mock_run.info.run_id = "run_456" + mock_run.info.run_name = "test_run" + mock_client.create_run.return_value = mock_run + mock_client.get_run.return_value = mock_run + + tracker = MlflowTracker(experiment_name="test_experiment") + + tracker.log_image(larger_pil_image, key="training", step=50, timestamp=1234567890) + tracker.cleanup() + + mock_client.log_image.assert_called_once_with( + "run_456", larger_pil_image, artifact_file=None, key="training", step=50, timestamp=1234567890 + ) diff --git a/tests/persistence/test_s3_persister.py b/tests/tracking/test_s3_tracker.py similarity index 50% rename from tests/persistence/test_s3_persister.py rename to tests/tracking/test_s3_tracker.py index 72f10927..46f16dd4 100644 --- a/tests/persistence/test_s3_persister.py +++ b/tests/tracking/test_s3_tracker.py @@ -1,16 +1,34 @@ -"""Test the S3 persister.""" +"""Test the S3 tracker.""" +import sys import tempfile from pathlib import Path +from unittest.mock import MagicMock import chex import equinox as eqx import jax import pytest -from simplexity.persistence.local_equinox_persister import LocalEquinoxPersister -from simplexity.persistence.s3_persister import S3Persister -from tests.persistence.s3_mocks import MockBoto3Session, MockS3Client +from simplexity.predictive_models.types import ModelFramework +from simplexity.tracking.model_persistence.local_equinox_persister import ( + LocalEquinoxPersister, +) +from simplexity.tracking.s3_tracker import S3Tracker +from tests.tracking.s3_mocks import MockBoto3Session, MockS3Client + + +@pytest.fixture(autouse=True) +def mock_boto(): + """Mock boto3 and botocore if missing.""" + if "boto3" not in sys.modules: + sys.modules["boto3"] = MagicMock() + sys.modules["boto3.session"] = MagicMock() + if "botocore" not in sys.modules: + sys.modules["botocore"] = MagicMock() + sys.modules["botocore.exceptions"] = MagicMock() + # Mock ClientError + sys.modules["botocore.exceptions"].ClientError = Exception def get_model(seed: int = 0) -> eqx.Module: @@ -18,42 +36,43 @@ def get_model(seed: int = 0) -> eqx.Module: return eqx.nn.Linear(in_features=4, out_features=2, key=jax.random.key(seed)) -def test_s3_persister(tmp_path: Path): - """Test S3Persister initialization.""" +def test_s3_tracker(tmp_path: Path): + """Test S3Tracker initialization.""" s3_client_mock = MockS3Client(tmp_path) temp_dir = tempfile.TemporaryDirectory() with temp_dir: local_persister = LocalEquinoxPersister(temp_dir.name) - persister = S3Persister( + tracker = S3Tracker( bucket="test_bucket", prefix="test_prefix", s3_client=s3_client_mock, temp_dir=temp_dir, - local_persister=local_persister, + local_persisters={ModelFramework.EQUINOX: local_persister}, ) - assert persister.bucket == "test_bucket" - assert persister.prefix == "test_prefix" + assert tracker.bucket == "test_bucket" + assert tracker.prefix == "test_prefix" model = get_model(0) assert not (tmp_path / "test_bucket" / "test_prefix" / "0" / "model.eqx").exists() - persister.save_weights(model, 0) + tracker.save_model(model, 0) assert (tmp_path / "test_bucket" / "test_prefix" / "0" / "model.eqx").exists() new_model = get_model(1) with pytest.raises(AssertionError): chex.assert_trees_all_close(new_model, model) - loaded_model = persister.load_weights(new_model, 0) + loaded_model = tracker.load_model(new_model, 0) chex.assert_trees_all_equal(loaded_model, model) -def test_s3_persister_from_config(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): - """Test S3Persister.from_config with mocked Boto3 session.""" +def test_s3_tracker_from_config(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Test S3Tracker.from_config with mocked Boto3 session.""" def mock_session_init(profile_name=None, **kwargs): # pylint: disable=unused-argument """Mock session initialization.""" return MockBoto3Session.create(tmp_path) - monkeypatch.setattr("simplexity.persistence.s3_persister.boto3.session.Session", mock_session_init) + # Patch where it's imported (or sys.modules) + monkeypatch.setattr("boto3.session.Session", mock_session_init) # Create config.ini file for testing config_file = tmp_path / "test_config.ini" @@ -65,18 +84,18 @@ def mock_session_init(profile_name=None, **kwargs): # pylint: disable=unused-ar """ config_file.write_text(config_content) - persister = S3Persister.from_config(prefix="test_prefix", config_filename=str(config_file)) + tracker = S3Tracker.from_config(prefix="test_prefix", config_filename=str(config_file)) - assert persister.bucket == "test_bucket" - assert persister.prefix == "test_prefix" + assert tracker.bucket == "test_bucket" + assert tracker.prefix == "test_prefix" model = get_model(0) assert not (tmp_path / "test_bucket" / "test_prefix" / "0" / "model.eqx").exists() - persister.save_weights(model, 0) + tracker.save_model(model, 0) assert (tmp_path / "test_bucket" / "test_prefix" / "0" / "model.eqx").exists() new_model = get_model(1) with pytest.raises(AssertionError): chex.assert_trees_all_close(new_model, model) - loaded_model = persister.load_weights(new_model, 0) + loaded_model = tracker.load_model(new_model, 0) chex.assert_trees_all_equal(loaded_model, model)