diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 31ca6a7..3da48e6 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v5 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ec30804..08e218f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,3 +19,10 @@ repos: # when "--baseline" with "--use-all-plugins", pre-commit scan with all available plugins # add "--fail-on-unaudited" to fail pre-commit for unaudited potential secrets args: [--baseline, .secrets.baseline, --use-all-plugins] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.3 + hooks: + - id: ruff-format + types_or: + - python + - jupyter diff --git a/.secrets.baseline b/.secrets.baseline new file mode 100644 index 0000000..e3ac1a7 --- /dev/null +++ b/.secrets.baseline @@ -0,0 +1,146 @@ +{ + "exclude": { + "files": "^.secrets.baseline$", + "lines": null + }, + "generated_at": "2025-10-01T20:02:29Z", + "plugins_used": [ + { + "name": "AWSKeyDetector" + }, + { + "name": "ArtifactoryDetector" + }, + { + "name": "AzureStorageKeyDetector" + }, + { + "base64_limit": 4.5, + "name": "Base64HighEntropyString" + }, + { + "name": "BasicAuthDetector" + }, + { + "name": "BoxDetector" + }, + { + "name": "CloudantDetector" + }, + { + "ghe_instance": "github.ibm.com", + "name": "GheDetector" + }, + { + "name": "GitHubTokenDetector" + }, + { + "hex_limit": 3, + "name": "HexHighEntropyString" + }, + { + "name": "IbmCloudIamDetector" + }, + { + "name": "IbmCosHmacDetector" + }, + { + "name": "JwtTokenDetector" + }, + { + "keyword_exclude": null, + "name": "KeywordDetector" + }, + { + "name": "MailchimpDetector" + }, + { + "name": "NpmDetector" + }, + { + "name": "PrivateKeyDetector" + }, + { + "name": "SlackDetector" + }, + { + "name": "SoftlayerDetector" + }, + { + "name": "SquareOAuthDetector" + }, + { + "name": "StripeDetector" + }, + { + "name": "TwilioKeyDetector" + } + ], + "results": { + "plotting/plot_results_mlflow.ipynb": [ + { + "hashed_secret": "5810b71c07271f259208c5790992170ac1e13b37", + "is_verified": false, + "line_number": 437, + "type": "Base64 High Entropy String", + "verified_result": null + }, + { + "hashed_secret": "1c1dc227208cec78bbdb8d9247164879f908a9ad", + "is_verified": false, + "line_number": 482, + "type": "Base64 High Entropy String", + "verified_result": null + }, + { + "hashed_secret": "e57967bc8f018a30bb192717673876f0ebdbe5d9", + "is_verified": false, + "line_number": 558, + "type": "Base64 High Entropy String", + "verified_result": null + } + ], + "plotting/plot_results_repeated_runs.ipynb": [ + { + "hashed_secret": "e52b18568a4fa073b958134ea5ec0f9407b6ebc3", + "is_verified": false, + "line_number": 352, + "type": "Base64 High Entropy String", + "verified_result": null + }, + { + "hashed_secret": "43cf2641021e5833120affd5a2bcdf35089eaf75", + "is_verified": false, + "line_number": 417, + "type": "Base64 High Entropy String", + "verified_result": null + }, + { + "hashed_secret": "78f9a422a3afb6ff5aff30094699c2b299dfd614", + "is_verified": false, + "line_number": 949, + "type": "Base64 High Entropy String", + "verified_result": null + }, + { + "hashed_secret": "2525429c7a93512ed0c4b799b867a83a6b19f7ff", + "is_verified": false, + "line_number": 1014, + "type": "Base64 High Entropy String", + "verified_result": null + }, + { + "hashed_secret": "8915fab07d3bf85d3755089a7fc82e911405d40a", + "is_verified": false, + "line_number": 1080, + "type": "Base64 High Entropy String", + "verified_result": null + } + ] + }, + "version": "0.13.1+ibm.61.dss", + "word_list": { + "file": null, + "hash": null + } +} diff --git a/benchmark/config_util/geobenchv2_template.yaml b/benchmark/config_util/geobenchv2_template.yaml deleted file mode 100644 index 38351a3..0000000 --- a/benchmark/config_util/geobenchv2_template.yaml +++ /dev/null @@ -1,28 +0,0 @@ -experiment_name: my_experiment -defaults: - terratorch_task: - model_args: - backbone: terramind_v1_large - backbone_pretrained: true - model_factory: ObjectDetectionModelFactory - optimizer: AdamW - trainer_args: - log_every_n_steps: 1 - max_epochs: 1 -tasks: - - name: X - type: object_detection - direction: max - metric: X - terratorch_task: - datamodule: -n_trials: 1 -save_models: False -storage_uri: /opt/app-root/src/fm-geospatial/pf/logs/geobench/mlflow -run_repetitions: 1 -optimization_space: - lr: - max: 1e-3 - min: 1e-6 - type: real - log: true diff --git a/benchmark/main.py b/benchmark/main.py deleted file mode 100644 index b72b160..0000000 --- a/benchmark/main.py +++ /dev/null @@ -1,324 +0,0 @@ -from jsonargparse import Namespace -import logging -from pathlib import Path -from typing import Any, List -from jsonargparse import ArgumentParser -import pandas as pd -from benchmark.backbone_benchmark import benchmark_backbone -from benchmark.benchmark_types import Defaults, Task -from benchmark.repeat_best_experiment import rerun_best_from_backbone -from benchmark.utils import ( - get_logger, - import_custom_modules, - get_results_and_parameters, -) - - -def _summarize( - config_init: Namespace, - hpo: bool, - repeat: bool, - storage_uri: str, - logger: logging.RootLogger, -) -> pd.DataFrame: - """only summarize results from multiple experiments - - Args: - config_init (Namespace): _description_ - hpo (bool): flag that indicates whether to run hpo - repeat (bool): flag that indicates whether to repeat best experiment - storage_uri (str): path to directory in which results will be stored - logger (logging.RootLogger): logger variable - - Returns: - _type_: _description_ - """ - assert ( - hpo is False and repeat is False - ), f"Error! both {repeat=} and {hpo=} must be False when summarizing results from multiple experiments." - - list_of_experiment_names = config_init.list_of_experiment_names - assert isinstance( - list_of_experiment_names, list - ), f"Error! {list_of_experiment_names=} is not a list" - for exp in list_of_experiment_names: - assert isinstance(exp, str), f"Error! {exp=} is not a str" - - task_names = config_init.task_names - assert isinstance(task_names, list), f"Error! {task_names=} is not a list" - for t in task_names: - assert isinstance(t, str), f"Error! {t=} is not a str" - - task_metrics = config_init.task_metrics - assert isinstance(task_metrics, list), f"Error! {task_metrics=} is not a list" - for t in task_metrics: - assert isinstance(t, str), f"Error! {t=} is not a str" - - benchmark_name = config_init.benchmark_name - assert isinstance(benchmark_name, str), f"Error! {benchmark_name=} is not a str" - - run_repetitions = config_init.run_repetitions - assert ( - isinstance(run_repetitions, int) and run_repetitions > 0 - ), f"Error! {run_repetitions=} is invalid" - # get results and parameters from mlflow logs - results_and_parameters = get_results_and_parameters( - benchmark_name=benchmark_name, - storage_uri=storage_uri, - logger=logger, - experiments=list_of_experiment_names, - task_names=task_names, - num_repetitions=run_repetitions, - task_metrics=task_metrics, - ) - return results_and_parameters - - -def _repeat_experiment( - config_init: Namespace, - storage_uri: str, - experiment_name: str, - parent_run_id: str, - defaults: Defaults, - tasks: list[Task], - optimization_space: dict, - run_repetitions: int, - save_models: bool, - report_on_best_val: bool, - logger: logging.RootLogger, -): - """repeat best experiments - - Args: - config_init (Namespace): _description_ - storage_uri (str): _description_ - experiment_name (str): _description_ - parent_run_id (str): _description_ - defaults (Defaults): _description_ - tasks (list[Task]): _description_ - optimization_space (dict): _description_ - run_repetitions (int): _description_ - save_models (bool): _description_ - report_on_best_val (bool): _description_ - logger (logging.RootLogger): _description_ - - Returns: - _type_: _description_ - """ - output: str | None = config_init.output_path - if output is None: - storage_uri_path = Path(storage_uri) - assert ( - storage_uri_path.exists() and storage_uri_path.is_dir() - ), f"Error! Unable to create new output_path based on storage_uri_path because the latter does not exist: {storage_uri_path}" - output_path = storage_uri_path.parents[0] / "repeated_exp_output_csv" - output_path.mkdir(parents=True, exist_ok=True) - output_path = output_path / f"{experiment_name}_repeated_exp_mlflow.csv" - output = str(output_path) - - logger.info("Rerun best experiments...") - rerun_best_from_backbone( - logger=logger, - parent_run_id=parent_run_id, - output_path=output_path, - defaults=defaults, - tasks=tasks, - experiment_name=experiment_name, - storage_uri=storage_uri, - optimization_space=optimization_space, - run_repetitions=run_repetitions, - save_models=save_models, - report_on_best_val=report_on_best_val, - ) - - -def main(): - - parser = ArgumentParser() - - parser.add_argument('--defaults', type=Defaults) # to ignore model - parser.add_argument('--optimization_space', type=dict) # to ignore model - parser.add_argument('--experiment_name', type=str) # to ignore model - parser.add_argument('--run_name', type=str) # to ignore model - parser.add_argument('--save_models', type=bool) # to ignore model - parser.add_argument('--storage_uri', type=str) # to ignore model - parser.add_argument('--ray_storage_path', type=str) # to ignore model - parser.add_argument('--n_trials', type=int) # to ignore model - parser.add_argument('--run_repetitions', type=int) # to ignore model - parser.add_argument('--tasks', type=list[Task]) - parser.add_argument("--parent_run_id", type=str) - parser.add_argument("--output_path", type=str) - parser.add_argument("--logger", type=str) - parser.add_argument("--config", type=str) - parser.add_argument('--custom_modules_path', type=str) - parser.add_argument('--report_on_best_val', type=bool, default=True) - parser.add_argument('--test_models', type=bool, default=False) - parser.add_argument('--bayesian_search', type=bool, default=True) - parser.add_argument("--hpo", help="optimize hyperparameters", action="store_true") - parser.add_argument("--repeat", help="repeat best experiments", action="store_true") - parser.add_argument( - "--summarize", - help="summarize results from repeated experiments", - action="store_true", - ) - parser.add_argument('--list_of_experiment_names', type=list[str]) - parser.add_argument('--task_names', type=list[str]) - parser.add_argument('--task_metrics', type=list[str]) - parser.add_argument( - '--benchmark_name', - type=str, - help="name of summarized results file", - ) - - args = parser.parse_args() - config_path: str | None = args.config - if config_path is None: - msg = """ - Error: config argument has not been passed - usage: terratorch [-h] [--hpo] [--repeat] [--summarize] [--config CONFIG] - """ - print(msg) - else: - assert isinstance( - config_path, str - ), f"Error! Unexpected config type: {config_path}" - config = parser.parse_path(config_path) - config_init: Namespace = parser.instantiate_classes(config) - - summarize: bool = args.summarize - assert isinstance(summarize, bool), f"Error! {summarize=} is not a bool" - repeat = args.repeat - assert isinstance(repeat, bool), f"Error! {repeat=} is not a bool" - hpo = args.hpo - assert isinstance(hpo, bool), f"Error! {hpo=} is not a bool" - - storage_uri = config_init.storage_uri - assert isinstance(storage_uri, str), f"Error! {storage_uri=} is not a str" - logger_path = config_init.logger - if logger_path is None: - storage_uri_path = Path(storage_uri) - logger = get_logger( - log_folder=f"{str(storage_uri_path.parents[0])}/job_logs" - ) - else: - logging.config.fileConfig(fname=logger_path, disable_existing_loggers=False) - logger = logging.getLogger("terratorch-iterate") - - # only summarize results from multiple experiments - if summarize: - return _summarize( - config_init=config_init, - ) - - # optimize hyperparameters and/or do repeated runs for single experiments - assert ( - hpo is True or repeat is True - ), f"Error! either {repeat=} or {hpo=} must be True" - parent_run_id = args.parent_run_id - if parent_run_id is not None: - assert isinstance( - parent_run_id, str - ), f"Error! {parent_run_id=} is not a str" - - # validate the objects - experiment_name = config_init.experiment_name - assert isinstance( - experiment_name, str - ), f"Error! {experiment_name=} is not a str" - run_name = config_init.run_name - if run_name is not None: - assert isinstance(run_name, str), f"Error! {run_name=} is not a str" - # validate defaults - defaults = config_init.defaults - assert isinstance(defaults, Defaults), f"Error! {defaults=} is not a Defaults" - - tasks = config_init.tasks - assert isinstance(tasks, list), f"Error! {tasks=} is not a list" - for t in tasks: - assert isinstance(t, Task), f"Error! {t=} is not a Task" - # if there is not specific terratorch_task specified, then use default terratorch_task - if t.terratorch_task is None: - t.terratorch_task = defaults.terratorch_task - # defaults.trainer_args["max_epochs"] = 5 - - optimization_space = config_init.optimization_space - assert isinstance( - optimization_space, dict - ), f"Error! {optimization_space=} is not a dict" - - # ray_storage_path is optional - ray_storage_path = config_init.ray_storage_path - if ray_storage_path is not None: - assert isinstance( - ray_storage_path, str - ), f"Error! {ray_storage_path=} is not a str" - - n_trials = config_init.n_trials - assert ( - isinstance(n_trials, int) and n_trials > 0 - ), f"Error! {n_trials=} is invalid" - run_repetitions = config_init.run_repetitions - - report_on_best_val = config_init.report_on_best_val - assert isinstance( - report_on_best_val, bool - ), f"Error! {ray_storage_path=} is not a bool" - - save_models = config_init.save_models - assert isinstance(save_models, bool), f"Error! {save_models=} is not a bool" - - test_models = config_init.test_models - assert isinstance(test_models, bool), f"Error! {test_models=} is not a bool" - - bayesian_search = config_init.bayesian_search - assert isinstance( - bayesian_search, bool - ), f"Error! {bayesian_search=} is not a bool" - - # custom_modules_path is optional - custom_modules_path = config_init.custom_modules_path - if custom_modules_path is not None: - assert isinstance( - custom_modules_path, str - ), f"Error! {custom_modules_path=} is not a str" - import_custom_modules( - logger=logger, custom_modules_path=custom_modules_path - ) - - if repeat and not hpo: - _repeat_experiment( - config_init=config_init, - storage_uri=storage_uri, - experiment_name=experiment_name, - defaults=defaults, - tasks=tasks, - optimization_space=optimization_space, - run_repetitions=run_repetitions, - save_models=save_models, - logger=logger, - ) - else: - if not repeat and hpo: - run_repetitions = 0 - - # run_repetitions is an optional parameter - benchmark_backbone( - defaults=defaults, - tasks=tasks, - experiment_name=experiment_name, - storage_uri=storage_uri, - ray_storage_path=ray_storage_path, - run_name=run_name, - optimization_space=optimization_space, - n_trials=n_trials, - run_repetitions=run_repetitions, - save_models=save_models, - report_on_best_val=report_on_best_val, - test_models=test_models, - bayesian_search=bayesian_search, - logger=logger, - ) - - -if __name__ == "__main__": - main() diff --git a/benchmark/resources/dataset_specifications/agb.yaml b/benchmark/resources/dataset_specifications/agb.yaml deleted file mode 100644 index 33e9c95..0000000 --- a/benchmark/resources/dataset_specifications/agb.yaml +++ /dev/null @@ -1,64 +0,0 @@ -class_path: terratorch.datamodules.GenericNonGeoPixelwiseRegressionDataModule -init_args: - batch_size: 16 - num_workers: 4 - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: albumentations.augmentations.geometric.rotate.Rotate - init_args: - limit: 30 - border_mode: 0 # cv2.BORDER_CONSTANT - # value: 0 - # mask_value: 1 - p: 0.5 - dict_kwargs: - value: 0 - mask_value: 1 - - class_path: ToTensorV2 - dataset_bands: - - 0 - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - - 1 - - 2 - - 3 - - 4 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_data_root: /dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/train_images - train_label_data_root: /dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/train_labels - val_data_root: /dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/val_images - val_label_data_root: /dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/val_labels - test_data_root: /dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/test_images - test_label_data_root: /dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/test_labels - # img_grep: "*.tif" - # label_grep: "*.tif" - means: - - 385.88501817 - - 714.60615207 - - 658.96267376 - - 3314.57774238 - - 2238.71812558 - - 1250.00982518 - stds: - - 264.62872 - - 355.62848 - - 504.54855 - - 898.4953 - - 947.22894 - - 828.1297 diff --git a/benchmark/resources/dataset_specifications/eurosat.yaml b/benchmark/resources/dataset_specifications/eurosat.yaml deleted file mode 100644 index 029ee51..0000000 --- a/benchmark/resources/dataset_specifications/eurosat.yaml +++ /dev/null @@ -1,28 +0,0 @@ -class_path: terratorch.datamodules.TorchNonGeoDataModule -init_args: - transforms: - # a possible way to select bands: - # - class_path: SelectBands - # init_args: - # band_indices: - # - 2 - # - 1 - # - 0 - - class_path: albumentations.augmentations.geometric.resize.Resize - dict_kwargs: - height: 224 - width: 224 - - class_path: ToTensorV2 - cls: torchgeo.datamodules.EuroSATDataModule - batch_size: 16 - num_workers: 4 -dict_kwargs: - root: /dccstor/geofm-pre/EuroSat - download: True - bands: - - B02 - - B03 - - B04 - - B08A - - B11 - - B12 diff --git a/benchmark/resources/dataset_specifications/fire_scars.yaml b/benchmark/resources/dataset_specifications/fire_scars.yaml deleted file mode 100644 index a2f50a1..0000000 --- a/benchmark/resources/dataset_specifications/fire_scars.yaml +++ /dev/null @@ -1,56 +0,0 @@ -class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule -init_args: - batch_size: 4 - num_workers: 8 - dataset_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_transform: - - class_path: albumentations.RandomCrop - init_args: - height: 224 - width: 224 - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: ToTensorV2 - no_data_replace: 0 - no_label_replace: -1 - train_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training - train_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training - val_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation - val_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation - test_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation - test_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation - img_grep: "*_merged.tif" - label_grep: "*.mask.tif" - means: - - 0.033349706741586264 - - 0.05701185520536176 - - 0.05889748132001316 - - 0.2323245113436119 - - 0.1972854853760658 - - 0.11944914225186566 - stds: - - 0.02269135568823774 - - 0.026807560223070237 - - 0.04004109844362779 - - 0.07791732423672691 - - 0.08708738838140137 - - 0.07241979477437814 - num_classes: 2 diff --git a/benchmark/resources/dataset_specifications/multi_temporal_crop.yaml b/benchmark/resources/dataset_specifications/multi_temporal_crop.yaml deleted file mode 100644 index bc30877..0000000 --- a/benchmark/resources/dataset_specifications/multi_temporal_crop.yaml +++ /dev/null @@ -1,57 +0,0 @@ -class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule -init_args: - batch_size: 8 - num_workers: 12 - train_transform: - - class_path: FlattenTemporalIntoChannels - - class_path: albumentations.Flip - - class_path: ToTensorV2 - - class_path: UnflattenTemporalFromChannels - init_args: - n_timesteps: 3 - dataset_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - reduce_zero_label: True - expand_temporal_dimension: True - train_data_root: /dccstor/geofm-finetuning/hls_cdl_reclassed/training_chips - train_label_data_root: /dccstor/geofm-finetuning/hls_cdl_reclassed/training_chips - val_data_root: /dccstor/geofm-finetuning/hls_cdl_reclassed/validation_chips - val_label_data_root: /dccstor/geofm-finetuning/hls_cdl_reclassed/validation_chips - test_data_root: /dccstor/geofm-finetuning/hls_cdl_reclassed/validation_chips - test_label_data_root: /dccstor/geofm-finetuning/hls_cdl_reclassed/validation_chips - train_split: /dccstor/geofm-finetuning/hls_cdl_reclassed/training_chips/training_data.txt - test_split: /dccstor/geofm-finetuning/hls_cdl_reclassed/validation_chips/validation_data.txt - val_split: /dccstor/geofm-finetuning/hls_cdl_reclassed/validation_chips/validation_data.txt - img_grep: "*_merged.tif" - label_grep: "*.mask.tif" - means: - - 494.905781 - - 815.239594 - - 924.335066 - - 2968.881459 - - 2634.621962 - - 1739.579917 - stds: - - 284.925432 - - 357.84876 - - 575.566823 - - 896.601013 - - 951.900334 - - 921.407808 - num_classes: 13 diff --git a/benchmark/resources/dataset_specifications/sen1floods11.yaml b/benchmark/resources/dataset_specifications/sen1floods11.yaml deleted file mode 100644 index d3201e1..0000000 --- a/benchmark/resources/dataset_specifications/sen1floods11.yaml +++ /dev/null @@ -1,59 +0,0 @@ -class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule -init_args: - batch_size: 8 - num_workers: 4 - constant_scale: 0.0001 - dataset_bands: - - COASTAL_AEROSOL - - BLUE - - GREEN - - RED - - RED_EDGE_1 - - RED_EDGE_2 - - RED_EDGE_3 - - NIR_BROAD - - NIR_NARROW - - WATER_VAPOR - - CIRRUS - - SWIR_1 - - SWIR_2 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/S2Hand/ - train_label_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand - val_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/S2Hand/ - val_label_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand - test_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/S2Hand/ - test_label_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand - # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files - train_split: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/splits/flood_handlabeled/flood_train_data.txt - test_split: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/splits/flood_handlabeled/flood_test_data.txt - val_split: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/splits/flood_handlabeled/flood_valid_data.txt - img_grep: "*_S2Hand.tif" - label_grep: "*_LabelHand.tif" - no_label_replace: -1 - no_data_replace: 0 -means: - - 0.1412956 - - 0.13795798 - - 0.12353792 - - 0.30902815 - - 0.2044958 - - 0.11912015 -stds: - - 0.07406382 - - 0.07370365 - - 0.08692279 - - 0.11798815 - - 0.09772074 - - 0.07659938 -num_classes: 2 diff --git a/benchmark/resources/dataset_specifications/sen1floods11_transforms.yaml b/benchmark/resources/dataset_specifications/sen1floods11_transforms.yaml deleted file mode 100644 index ffea683..0000000 --- a/benchmark/resources/dataset_specifications/sen1floods11_transforms.yaml +++ /dev/null @@ -1,67 +0,0 @@ -class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule -init_args: - batch_size: 8 - num_workers: 4 - constant_scale: 0.0001 - dataset_bands: - - COASTAL_AEROSOL - - BLUE - - GREEN - - RED - - RED_EDGE_1 - - RED_EDGE_2 - - RED_EDGE_3 - - NIR_BROAD - - NIR_NARROW - - WATER_VAPOR - - CIRRUS - - SWIR_1 - - SWIR_2 - output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - rgb_indices: - - 2 - - 1 - - 0 - train_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/S2Hand/ - train_label_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand - val_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/S2Hand/ - val_label_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand - test_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/S2Hand/ - test_label_data_root: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand - # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files - train_split: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/splits/flood_handlabeled/flood_train_data.txt - test_split: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/splits/flood_handlabeled/flood_test_data.txt - val_split: /dccstor/geofm-finetuning/datasets/sen1floods11/v1.1/splits/flood_handlabeled/flood_valid_data.txt - img_grep: "*_S2Hand.tif" - label_grep: "*_LabelHand.tif" - no_label_replace: -1 - no_data_replace: 0 - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - - class_path: ToTensorV2 - means: - - 0.1412956 - - 0.13795798 - - 0.12353792 - - 0.30902815 - - 0.2044958 - - 0.11912015 - stds: - - 0.07406382 - - 0.07370365 - - 0.08692279 - - 0.11798815 - - 0.09772074 - - 0.07659938 - num_classes: 2 diff --git a/configs/templates/template.yaml b/configs/templates/template.yaml new file mode 100644 index 0000000..0e31607 --- /dev/null +++ b/configs/templates/template.yaml @@ -0,0 +1,41 @@ +experiment_name: X +defaults: + trainer_args: + max_epochs: 5 + log_every_n_steps: 1 + terratorch_task: + model_factory: EncoderDecoderFactory + model_args: + backbone: X + backbone_pretrained: true + optimizer: AdamW + scheduler: ReduceLROnPlateau + scheduler_hparams: + mode: min + factor: 0.5 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.0 + eps: 1.0e-08 + + +tasks: + - name: X + type: segmentation + direction: max + metric: X + terratorch_task: + datamodule: + +n_trials: 1 +save_models: False +storage_uri: ./mlflow +run_repetitions: 5 +optimization_space: + lr: + max: 1e-3 + min: 1e-6 + type: real + log: true diff --git a/configs/tests/geobench_v1_prithvi_big_earth_net.yaml b/configs/tests/geobench_v1_prithvi_big_earth_net.yaml deleted file mode 100644 index 6c19c15..0000000 --- a/configs/tests/geobench_v1_prithvi_big_earth_net.yaml +++ /dev/null @@ -1,111 +0,0 @@ -experiment_name: geobench_v2_test -run_name: test_models_saved_multiple_epochs_no_ray -defaults: - trainer_args: - precision: bf16-mixed # for these new models pretrained with bf16-mixed we should probably finetune with bf16-mixed - max_epochs: 5 - terratorch_task: - model_args: - pretrained: True - backbone: prithvi_eo_v1_100 - backbone_out_indices: - - 2 - - 5 - - 8 - - 11 - backbone_pretrained_cfg_overlay: - file: /dccstor/geofm-finetuning/pretrain_ckpts/v9_no_sea/vit_b/epoch-395-loss-0.0339_clean.pt - model_factory: PrithviModelFactory - optimizer: AdamW - -tasks: - # class - - name: big_earth_net - type: multilabel_classification - direction: max - terratorch_task: - loss: balanced_bce - model_args: - bands: - - RED - - GREEN - - BLUE - - NIR_NARROW - - SWIR_1 - - SWIR_2 - num_classes: 43 - decoder: IdentityDecoder - head_linear_after_pool: True - datamodule: - class_path: terratorch.datamodules.MBigEarthNonGeoDataModule - init_args: - partition: 0.10x_train - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomRotate90 - # init_args: - # p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomBrightnessContrast - # init_args: - # p: 0.8 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 6 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/classification_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - - "NIR_NARROW" - - "SWIR_1" - - "SWIR_2" - optimization_except: - - decoder_channels - - head_dropout - metric: val/Multilabel_F1_Score - early_stop_patience: 5 -n_trials: 2 -save_models: False -storage_uri: /dccstor/geofm-finetuning/terratorch-iterate-test/benchmark -ray_storage_path: /dccstor/geofm-finetuning/terratorch-iterate-test/ray_storage -optimization_space: - batch_size: - - 8 - - 32 - - 64 - lr: - max: 1e-3 - min: 1e-6 - type: real - log: true - optimizer_hparams: - weight_decay: - min: 0 - max: 0.4 - type: real - model_args: - decoder_channels: - - 64 - - 128 - - 256 \ No newline at end of file diff --git a/configs/tests/geobench_v1_prithvi_cashew.yaml b/configs/tests/geobench_v1_prithvi_cashew.yaml deleted file mode 100644 index 2bb7a1c..0000000 --- a/configs/tests/geobench_v1_prithvi_cashew.yaml +++ /dev/null @@ -1,108 +0,0 @@ -experiment_name: geobench_v2_test -run_name: test_models_saved_multiple_epochs_no_ray -defaults: - trainer_args: - precision: bf16-mixed # for these new models pretrained with bf16-mixed we should probably finetune with bf16-mixed - max_epochs: 300 - terratorch_task: - model_args: - pretrained: True - backbone: prithvi_eo_v1_100 - backbone_out_indices: - - 2 - - 5 - - 8 - - 11 - backbone_pretrained_cfg_overlay: - file: /dccstor/geofm-finetuning/pretrain_ckpts/v9_no_sea/vit_b/epoch-395-loss-0.0339_clean.pt - model_factory: PrithviModelFactory - optimizer: AdamW - -tasks: - - name: cashew - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 50 - terratorch_task: - loss: ce - model_args: - num_classes: 7 - bands: - - RED - - GREEN - - BLUE - - NIR_NARROW - - SWIR_1 - - SWIR_2 - decoder: UperNetDecoder - decoder_channels: 128 - decoder_scale_modules: true - datamodule: - class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule - init_args: - partition: 0.10x_train - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomRotate90 - # init_args: - # p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomBrightnessContrast - # init_args: - # p: 0.8 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 6 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - - "NIR_NARROW" - - "SWIR_1" - - "SWIR_2" -n_trials: 16 -save_models: False -storage_uri: /dccstor/geofm-finetuning/carlosgomes/benchmark -ray_storage_path: /dccstor/geofm-finetuning/carlosgomes/ray_storage -optimization_space: - batch_size: - - 8 - - 32 - - 64 - lr: - max: 1e-3 - min: 1e-6 - type: real - log: true - optimizer_hparams: - weight_decay: - min: 0 - max: 0.4 - type: real - model_args: - decoder_channels: - - 64 - - 128 - - 256 \ No newline at end of file diff --git a/configs/tests/geobench_v1_prithvi_chesapeake.yaml b/configs/tests/geobench_v1_prithvi_chesapeake.yaml deleted file mode 100644 index 52be24a..0000000 --- a/configs/tests/geobench_v1_prithvi_chesapeake.yaml +++ /dev/null @@ -1,104 +0,0 @@ -experiment_name: geobench_v2_test -run_name: test_models_saved_multiple_epochs_no_ray -defaults: - trainer_args: - precision: bf16-mixed # for these new models pretrained with bf16-mixed we should probably finetune with bf16-mixed - max_epochs: 300 - terratorch_task: - model_args: - pretrained: True - backbone: prithvi_eo_v1_100 - backbone_out_indices: - - 2 - - 5 - - 8 - - 11 - backbone_pretrained_cfg_overlay: - file: /dccstor/geofm-finetuning/pretrain_ckpts/v9_no_sea/vit_b/epoch-395-loss-0.0339_clean.pt - model_factory: PrithviModelFactory - optimizer: AdamW - -tasks: - - name: chesapeake - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 50 - terratorch_task: - loss: ce - model_args: - decoder: UperNetDecoder - decoder_channels: 128 - decoder_scale_modules: true - bands: - - RED - - GREEN - - BLUE - - NIR_NARROW - num_classes: 7 - datamodule: - class_path: terratorch.datamodules.MChesapeakeLandcoverNonGeoDataModule - init_args: - partition: 0.10x_train - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomRotate90 - # init_args: - # p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomBrightnessContrast - # init_args: - # p: 0.8 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 6 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - - "NIR" -n_trials: 16 -save_models: False -storage_uri: /dccstor/geofm-finetuning/carlosgomes/benchmark -ray_storage_path: /dccstor/geofm-finetuning/carlosgomes/ray_storage -optimization_space: - batch_size: - - 8 - - 32 - - 64 - lr: - max: 1e-3 - min: 1e-6 - type: real - log: true - optimizer_hparams: - weight_decay: - min: 0 - max: 0.4 - type: real - model_args: - decoder_channels: - - 64 - - 128 - - 256 \ No newline at end of file diff --git a/configs/tests/geobench_v1_resnet_cashew.yaml b/configs/tests/geobench_v1_resnet_cashew.yaml deleted file mode 100644 index f07937e..0000000 --- a/configs/tests/geobench_v1_resnet_cashew.yaml +++ /dev/null @@ -1,89 +0,0 @@ -experiment_name: geobench_resnet -run_name: resnet_50_rgb_only_16_trials -bayesian_search: False -defaults: - trainer_args: - precision: bf16-mixed # for these new models pretrained with bf16-mixed we should probably finetune with bf16-mixed - max_epochs: 2 - terratorch_task: - model_args: - pretrained: True - backbone: resnet50 - optimizer: AdamW - -tasks: - - name: cashew - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 50 - terratorch_task: - loss: ce - model_factory: SMPModelFactory - model_args: - num_classes: 7 - bands: - - RED - - GREEN - - BLUE - model: Unet - datamodule: - class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule - init_args: - partition: 0.10x_train - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomRotate90 - # init_args: - # p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomBrightnessContrast - # init_args: - # p: 0.8 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 6 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" -n_trials: 16 -save_models: False -storage_uri: /dccstor/geofm-finetuning/carlosgomes/benchmark -ray_storage_path: /dccstor/geofm-finetuning/carlosgomes/ray_storage -optimization_space: - batch_size: - - 8 - - 32 - - 64 - lr: - max: 1e-3 - min: 1e-6 - type: real - log: true - optimizer_hparams: - weight_decay: - min: 0 - max: 0.4 - type: real \ No newline at end of file diff --git a/configs/tests/geobench_v1_resnet_chesapeake.yaml b/configs/tests/geobench_v1_resnet_chesapeake.yaml deleted file mode 100644 index bd842ea..0000000 --- a/configs/tests/geobench_v1_resnet_chesapeake.yaml +++ /dev/null @@ -1,90 +0,0 @@ -experiment_name: geobench_resnet -run_name: resnet_50_rgb_only_16_trials -bayesian_search: False -defaults: - trainer_args: - precision: bf16-mixed # for these new models pretrained with bf16-mixed we should probably finetune with bf16-mixed - max_epochs: 2 - terratorch_task: - model_args: - pretrained: True - backbone: resnet50 - optimizer: AdamW - -tasks: - - name: chesapeake - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 50 - terratorch_task: - loss: ce - model_factory: SMPModelFactory - model_args: - model: Unet - bands: - - RED - - GREEN - - BLUE - num_classes: 7 - datamodule: - class_path: terratorch.datamodules.MChesapeakeLandcoverNonGeoDataModule - init_args: - partition: 0.10x_train - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomRotate90 - # init_args: - # p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - # - class_path: albumentations.RandomBrightnessContrast - # init_args: - # p: 0.8 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 6 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - -n_trials: 16 -save_models: False -storage_uri: /dccstor/geofm-finetuning/carlosgomes/benchmark -ray_storage_path: /dccstor/geofm-finetuning/carlosgomes/ray_storage -optimization_space: - batch_size: - - 8 - - 32 - - 64 - lr: - max: 1e-3 - min: 1e-6 - type: real - log: true - optimizer_hparams: - weight_decay: - min: 0 - max: 0.4 - type: real \ No newline at end of file diff --git a/configs/tests/geobench_v1_ssl4eos12_resnet50_sentinel2_all_moco_smp_unet_true.yaml b/configs/tests/geobench_v1_ssl4eos12_resnet50_sentinel2_all_moco_smp_unet_true.yaml deleted file mode 100644 index 1076979..0000000 --- a/configs/tests/geobench_v1_ssl4eos12_resnet50_sentinel2_all_moco_smp_unet_true.yaml +++ /dev/null @@ -1,402 +0,0 @@ -experiment_name: ssl4eos12_resnet50_sentinel2_all_moco_smp_unet -defaults: - trainer_args: - max_epochs: 1 - log_every_n_steps: 1 - terratorch_task: - model_args: - backbone_pretrained: True - backbone: ssl4eos12_resnet50_sentinel2_all_moco - backbone_out_indices: - - 0 - - 1 - - 2 - - 3 - - 4 - model_factory: EncoderDecoderFactory - optimizer: AdamW - -tasks: - - name: chesapeake - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 5 - terratorch_task: - loss: ce - model_args: - decoder: smp_Unet - decoder_decoder_channels: - - 512 - - 256 - - 128 - - 64 - backbone_model_bands: - - RED - - GREEN - - BLUE - - NIR_NARROW - num_classes: 7 - datamodule: - class_path: terratorch.datamodules.MChesapeakeLandcoverNonGeoDataModule - init_args: - partition: "0.01x_train" - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 4 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - - "NIR" - - name: cashew - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 50 - terratorch_task: - loss: ce - model_args: - num_classes: 7 - backbone_model_bands: - - "COASTAL_AEROSOL" - - "BLUE" - - "GREEN" - - "RED" - - "RED_EDGE_1" - - "RED_EDGE_2" - - "RED_EDGE_3" - - "NIR_BROAD" - - "NIR_NARROW" - - "WATER_VAPOR" - - "SWIR_1" - - "SWIR_2" - decoder: smp_Unet - decoder_decoder_channels: - - 512 - - 256 - - 128 - - 64 - datamodule: - class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule - init_args: - partition: "0.01x_train" - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 4 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "COASTAL_AEROSOL" - - "BLUE" - - "GREEN" - - "RED" - - "RED_EDGE_1" - - "RED_EDGE_2" - - "RED_EDGE_3" - - "NIR_BROAD" - - "NIR_NARROW" - - "WATER_VAPOR" - - "SWIR_1" - - "SWIR_2" - - name: neontree - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 5 - terratorch_task: - loss: ce - model_args: - num_classes: 2 - backbone_model_bands: - - RED - - GREEN - - BLUE - decoder: smp_Unet - decoder_decoder_channels: - - 512 - - 256 - - 128 - - 64 - datamodule: - class_path: terratorch.datamodules.MNeonTreeNonGeoDataModule - init_args: - partition: "0.01x_train" - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 8 - num_workers: 4 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - - name: nz_cattle - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 5 - terratorch_task: - loss: ce - model_args: - backbone_model_bands: - - RED - - GREEN - - BLUE - num_classes: 2 - decoder: smp_Unet - decoder_decoder_channels: - - 512 - - 256 - - 128 - - 64 - datamodule: - class_path: terratorch.datamodules.MNzCattleNonGeoDataModule - init_args: - partition: "0.01x_train" - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 4 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - - name: pv4ger_seg - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 5 - terratorch_task: - loss: ce - model_args: - decoder: smp_Unet - decoder_decoder_channels: - - 512 - - 256 - - 128 - - 64 - - backbone_model_bands: - - RED - - GREEN - - BLUE - num_classes: 2 - datamodule: - class_path: terratorch.datamodules.MPv4gerSegNonGeoDataModule - init_args: - partition: "0.01x_train" - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 4 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "RED" - - "GREEN" - - "BLUE" - - name: sa_crop_type - type: segmentation - direction: max - metric: val/Multiclass_Jaccard_Index - early_stop_patience: 5 - terratorch_task: - loss: ce - model_args: - decoder: smp_Unet - decoder_decoder_channels: - - 512 - - 256 - - 128 - - 64 - backbone_model_bands: - - "COASTAL_AEROSOL" - - "BLUE" - - "GREEN" - - "RED" - - "RED_EDGE_1" - - "RED_EDGE_2" - - "RED_EDGE_3" - - "NIR_BROAD" - - "NIR_NARROW" - - "WATER_VAPOR" - - "SWIR_1" - - "SWIR_2" - num_classes: 10 - datamodule: - class_path: terratorch.datamodules.m_SA_crop_type.MSACropTypeNonGeoDataModule - init_args: - partition: "0.01x_train" - train_transform: - - class_path: albumentations.HorizontalFlip - init_args: - p: 0.5 - - class_path: albumentations.VerticalFlip - init_args: - p: 0.5 - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - val_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - test_transform: - - class_path: albumentations.Resize - init_args: - height: 224 - width: 224 - - class_path: ToTensorV2 - batch_size: 16 - num_workers: 4 - data_root: "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" - bands: - - "COASTAL_AEROSOL" - - "BLUE" - - "GREEN" - - "RED" - - "RED_EDGE_1" - - "RED_EDGE_2" - - "RED_EDGE_3" - - "NIR_BROAD" - - "NIR_NARROW" - - "WATER_VAPOR" - - "SWIR_1" - - "SWIR_2" - -n_trials: 16 -save_models: False -storage_uri: /dccstor/geofm-finetuning/terratorch-iterate-test/benchmark -ray_storage_path: /dccstor/geofm-finetuning/terratorch-iterate-test/benchmark/ray_storage_results -optimization_space: - batch_size: - - 8 - - 16 - - 32 - lr: - min: 6e-5 - max: 1e-3 - type: real - log: true - optimizer_hparams: - weight_decay: - min: 0 - max: 0.4 - type: real diff --git a/configs/tests/terratorch-iterate-configs/test_case_01/oracle/convnext_LM_iterate.yaml b/configs/tests/terratorch-iterate-configs/test_case_01/oracle/convnext_LM_iterate.yaml new file mode 100644 index 0000000..975390b --- /dev/null +++ b/configs/tests/terratorch-iterate-configs/test_case_01/oracle/convnext_LM_iterate.yaml @@ -0,0 +1,105 @@ +defaults: + terratorch_task: + model_args: + backbone: timm_convnext_large.fb_in22k + backbone_pretrained: true + model_factory: EncoderDecoderFactory + optimizer: AdamW + scheduler: ReduceLROnPlateau + scheduler_hparams: + cooldown: 0 + eps: 1.0e-08 + factor: 0.5 + min_lr: 0.0 + mode: min + patience: 5 + threshold: 0.0001 + threshold_mode: rel + trainer_args: + log_every_n_steps: 1 + max_epochs: 5 +experiment_name: convnext_LM +n_trials: 1 +optimization_space: + lr: + log: true + max: 1e-3 + min: 1e-6 + type: real +run_repetitions: 5 +save_models: false +storage_uri: ./mlflow +tasks: +- datamodule: + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule + init_args: + batch_size: 16 + check_stackability: false + constant_scale: 1.0 + dataset_bands: + - RED + - GREEN + - BLUE + img_grep: '*train.tif' + label_grep: '*label.tif' + means: + - 104.24203383423682 + - 109.92963788132441 + - 100.98120642006803 + no_data_replace: 0 + no_label_replace: -1 + num_classes: 2 + num_workers: 16 + output_bands: + - RED + - GREEN + - BLUE + rgb_indices: + - 0 + - 1 + - 2 + stds: + - 51.593745217159935 + - 47.218880227273814 + - 45.45813303733705 + test_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + test_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + test_split: test.txt + test_transform: + - class_path: ToTensorV2 + train_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + train_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + train_split: train.txt + train_transform: + - class_path: albumentations.D4 + - class_path: ToTensorV2 + val_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + val_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + val_split: val.txt + val_transform: + - class_path: ToTensorV2 + direction: max + metric: val/loss + name: LM + terratorch_task: + freeze_backbone: false + freeze_decoder: false + ignore_index: -1 + loss: dice + model_args: + backbone: timm_convnext_large.fb_in22k + backbone_pretrained: true + decoder: UNetDecoder + decoder_channels: + - 512 + - 256 + - 128 + - 64 + head_channel_list: + - 256 + head_dropout: 0.1 + necks: null + num_classes: 2 + model_factory: EncoderDecoderFactory + plot_on_val: 2 + type: segmentation diff --git a/configs/tests/terratorch-iterate-configs/test_case_01/test_config_util__convnext.yaml b/configs/tests/terratorch-iterate-configs/test_case_01/test_config_util__convnext.yaml new file mode 100644 index 0000000..b465f60 --- /dev/null +++ b/configs/tests/terratorch-iterate-configs/test_case_01/test_config_util__convnext.yaml @@ -0,0 +1,105 @@ +defaults: + terratorch_task: + model_args: + backbone: X + backbone_pretrained: true + model_factory: EncoderDecoderFactory + optimizer: AdamW + scheduler: ReduceLROnPlateau + scheduler_hparams: + cooldown: 0 + eps: 1.0e-08 + factor: 0.5 + min_lr: 0.0 + mode: min + patience: 5 + threshold: 0.0001 + threshold_mode: rel + trainer_args: + log_every_n_steps: 1 + max_epochs: 5 +experiment_name: test_config_util__convnext +n_trials: 1 +optimization_space: + lr: + log: true + max: 1e-3 + min: 1e-6 + type: real +run_repetitions: 5 +save_models: false +storage_uri: ./mlflow +tasks: +- datamodule: + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule + init_args: + batch_size: 16 + check_stackability: false + constant_scale: 1.0 + dataset_bands: + - RED + - GREEN + - BLUE + img_grep: '*train.tif' + label_grep: '*label.tif' + means: + - 104.24203383423682 + - 109.92963788132441 + - 100.98120642006803 + no_data_replace: 0 + no_label_replace: -1 + num_classes: 2 + num_workers: 16 + output_bands: + - RED + - GREEN + - BLUE + rgb_indices: + - 0 + - 1 + - 2 + stds: + - 51.593745217159935 + - 47.218880227273814 + - 45.45813303733705 + test_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + test_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + test_split: test.txt + test_transform: + - class_path: ToTensorV2 + train_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + train_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + train_split: train.txt + train_transform: + - class_path: albumentations.D4 + - class_path: ToTensorV2 + val_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + val_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + val_split: val.txt + val_transform: + - class_path: ToTensorV2 + direction: max + metric: val/loss + name: convnext.yaml + terratorch_task: + freeze_backbone: false + freeze_decoder: false + ignore_index: -1 + loss: dice + model_args: + backbone: timm_convnext_large.fb_in22k + backbone_pretrained: true + decoder: UNetDecoder + decoder_channels: + - 512 + - 256 + - 128 + - 64 + head_channel_list: + - 256 + head_dropout: 0.1 + necks: null + num_classes: 2 + model_factory: EncoderDecoderFactory + plot_on_val: 2 + type: segmentation diff --git a/configs/tests/terratorch-iterate-configs/test_case_02/oracle/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml b/configs/tests/terratorch-iterate-configs/test_case_02/oracle/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml new file mode 100644 index 0000000..c5d97b8 --- /dev/null +++ b/configs/tests/terratorch-iterate-configs/test_case_02/oracle/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml @@ -0,0 +1,130 @@ +defaults: + terratorch_task: + model_args: + backbone: X + backbone_pretrained: true + model_factory: EncoderDecoderFactory + optimizer: AdamW + scheduler: ReduceLROnPlateau + scheduler_hparams: + cooldown: 0 + eps: 1.0e-08 + factor: 0.5 + min_lr: 0.0 + mode: min + patience: 5 + threshold: 0.0001 + threshold_mode: rel + trainer_args: + log_every_n_steps: 1 + max_epochs: 5 +experiment_name: test_config_util__encoderdecoder_eo_v2_300_model_factory +n_trials: 1 +optimization_space: + lr: + log: true + max: 1e-3 + min: 1e-6 + type: real +run_repetitions: 5 +save_models: false +storage_uri: /u/ltizzei/test_terratorch_iterate +tasks: +- datamodule: + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule + init_args: + allow_substring_split_file: true + batch_size: 4 + constant_scale: 1.0 + dataset_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + ignore_split_file_extensions: true + img_grep: '*_merged.tif' + label_grep: '*.mask.tif' + means: + - 0.052829564761523104 + - 0.07822514779700994 + - 0.09545302348640401 + - 0.2128596444116123 + - 0.2363016737011897 + - 0.17234100022878698 + no_data_replace: 0 + no_label_replace: -1 + num_classes: 2 + num_workers: 2 + output_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + rgb_indices: + - 0 + - 1 + - 2 + stds: + - 0.028757146620143812 + - 0.03540772770593507 + - 0.05291947163682527 + - 0.06949186937256507 + - 0.08958868240264736 + - 0.08198354165348874 + test_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_transform: + - class_path: ToTensorV2 + train_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + train_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + val_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + val_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + direction: max + metric: val/loss + name: test + terratorch_task: + freeze_backbone: false + freeze_decoder: false + ignore_index: -1 + loss: ce + model_args: + backbone: prithvi_eo_v2_300 + backbone_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + backbone_drop_path: 0.1 + backbone_pretrained: true + decoder: UNetDecoder + decoder_channels: + - 512 + - 256 + - 128 + - 64 + head_dropout: 0.1 + necks: + - indices: + - 5 + - 11 + - 17 + - 23 + name: SelectIndices + - name: ReshapeTokensToImage + - name: LearnedInterpolateToPyramidal + num_classes: 2 + model_factory: EncoderDecoderFactory + plot_on_val: 2 + tiled_inference_parameters: + average_patches: true + h_crop: 512 + h_stride: 448 + w_crop: 512 + w_stride: 448 + type: segmentation diff --git a/configs/tests/terratorch-iterate-configs/test_case_02/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml b/configs/tests/terratorch-iterate-configs/test_case_02/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml new file mode 100644 index 0000000..dac536c --- /dev/null +++ b/configs/tests/terratorch-iterate-configs/test_case_02/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml @@ -0,0 +1,130 @@ +defaults: + terratorch_task: + model_args: + backbone: X + backbone_pretrained: true + model_factory: EncoderDecoderFactory + optimizer: AdamW + scheduler: ReduceLROnPlateau + scheduler_hparams: + cooldown: 0 + eps: 1.0e-08 + factor: 0.5 + min_lr: 0.0 + mode: min + patience: 5 + threshold: 0.0001 + threshold_mode: rel + trainer_args: + log_every_n_steps: 1 + max_epochs: 5 +experiment_name: test_config_util__encoderdecoder_eo_v2_300_model_factory +n_trials: 1 +optimization_space: + lr: + log: true + max: 1e-3 + min: 1e-6 + type: real +run_repetitions: 5 +save_models: false +storage_uri: ./mlflow +tasks: +- datamodule: + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule + init_args: + allow_substring_split_file: true + batch_size: 4 + constant_scale: 1.0 + dataset_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + ignore_split_file_extensions: true + img_grep: '*_merged.tif' + label_grep: '*.mask.tif' + means: + - 0.052829564761523104 + - 0.07822514779700994 + - 0.09545302348640401 + - 0.2128596444116123 + - 0.2363016737011897 + - 0.17234100022878698 + no_data_replace: 0 + no_label_replace: -1 + num_classes: 2 + num_workers: 2 + output_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + rgb_indices: + - 0 + - 1 + - 2 + stds: + - 0.028757146620143812 + - 0.03540772770593507 + - 0.05291947163682527 + - 0.06949186937256507 + - 0.08958868240264736 + - 0.08198354165348874 + test_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_transform: + - class_path: ToTensorV2 + train_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + train_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + val_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + val_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + direction: max + metric: val/loss + name: test + terratorch_task: + freeze_backbone: false + freeze_decoder: false + ignore_index: -1 + loss: ce + model_args: + backbone: prithvi_eo_v2_300 + backbone_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + backbone_drop_path: 0.1 + backbone_pretrained: true + decoder: UNetDecoder + decoder_channels: + - 512 + - 256 + - 128 + - 64 + head_dropout: 0.1 + necks: + - indices: + - 5 + - 11 + - 17 + - 23 + name: SelectIndices + - name: ReshapeTokensToImage + - name: LearnedInterpolateToPyramidal + num_classes: 2 + model_factory: EncoderDecoderFactory + plot_on_val: 2 + tiled_inference_parameters: + average_patches: true + h_crop: 512 + h_stride: 448 + w_crop: 512 + w_stride: 448 + type: segmentation diff --git a/configs/tests/terratorch-iterate-configs/test_case_03/test_config_util__encoder_decoder_timm_resnet101_model_factory.yaml b/configs/tests/terratorch-iterate-configs/test_case_03/test_config_util__encoder_decoder_timm_resnet101_model_factory.yaml new file mode 100644 index 0000000..7f63285 --- /dev/null +++ b/configs/tests/terratorch-iterate-configs/test_case_03/test_config_util__encoder_decoder_timm_resnet101_model_factory.yaml @@ -0,0 +1,123 @@ +defaults: + terratorch_task: + model_args: + backbone: X + backbone_pretrained: true + model_factory: EncoderDecoderFactory + optimizer: AdamW + scheduler: ReduceLROnPlateau + scheduler_hparams: + cooldown: 0 + eps: 1.0e-08 + factor: 0.5 + min_lr: 0.0 + mode: min + patience: 5 + threshold: 0.0001 + threshold_mode: rel + trainer_args: + log_every_n_steps: 1 + max_epochs: 5 +experiment_name: test_config_util__encoder_decoder_timm_resnet101_model_factory +n_trials: 1 +optimization_space: + lr: + log: true + max: 1e-3 + min: 1e-6 + type: real +run_repetitions: 5 +save_models: false +storage_uri: ./mlflow +tasks: +- datamodule: + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule + init_args: + allow_substring_split_file: true + batch_size: 4 + constant_scale: 1.0 + dataset_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + ignore_split_file_extensions: true + img_grep: '*_merged.tif' + label_grep: '*.mask.tif' + means: + - 0.052829564761523104 + - 0.07822514779700994 + - 0.09545302348640401 + - 0.2128596444116123 + - 0.2363016737011897 + - 0.17234100022878698 + no_data_replace: 0 + no_label_replace: -1 + num_classes: 2 + num_workers: 2 + output_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + rgb_indices: + - 0 + - 1 + - 2 + stds: + - 0.028757146620143812 + - 0.03540772770593507 + - 0.05291947163682527 + - 0.06949186937256507 + - 0.08958868240264736 + - 0.08198354165348874 + test_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_transform: + - class_path: ToTensorV2 + train_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + train_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + val_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + val_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + direction: max + metric: val/loss + name: test + terratorch_task: + freeze_backbone: false + freeze_decoder: false + ignore_index: -1 + loss: ce + model_args: + backbone: timm_resnet101 + backbone_in_chans: 6 + backbone_pretrained: true + decoder: UNetDecoder + decoder_channels: + - 512 + - 256 + - 128 + - 64 + head_channel_list: + - 256 + head_dropout: 0.1 + necks: + - indices: + - 0 + - 1 + - 2 + - 3 + name: SelectIndices + num_classes: 3 + model_factory: EncoderDecoderFactory + plot_on_val: 2 + tiled_inference_parameters: + average_patches: true + h_crop: 224 + h_stride: 196 + w_crop: 224 + w_stride: 196 + type: segmentation diff --git a/configs/tests/terratorch_configs/test_case_01/convnext.yaml b/configs/tests/terratorch_configs/test_case_01/convnext.yaml new file mode 100644 index 0000000..c983aeb --- /dev/null +++ b/configs/tests/terratorch_configs/test_case_01/convnext.yaml @@ -0,0 +1,156 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + # - class_path: ModelCheckpoint + # init_args: + # mode: min + # monitor: val/loss + # filename: best-{epoch:02d} + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 20 + # ---- Early stop if ---- + # ---- Early stop endif ---- + max_epochs: 50 + check_val_every_n_epoch: 1 + log_every_n_steps: 5 + enable_checkpointing: false + default_root_dir: logs/ + +data: + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule + init_args: + batch_size: 16 + num_workers: 16 + no_label_replace: -1 + no_data_replace: 0 + constant_scale: 1.0 + dataset_bands: + - 'RED' + - 'GREEN' + - 'BLUE' + + output_bands: + - 'RED' + - 'GREEN' + - 'BLUE' + + rgb_indices: + - 0 + - 1 + - 2 + + train_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + train_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + val_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + val_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + test_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + test_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/ + img_grep: "*train.tif" + label_grep: "*label.tif" + train_split: train.txt + val_split: val.txt + test_split: test.txt + # constant_scale: 0.0039 + # means: [0.485, 0.456, 0.406] + # stds: [0.229, 0.224, 0.225] + means: + - 104.24203383423682 + - 109.92963788132441 + - 100.98120642006803 + + stds: + - 51.593745217159935 + - 47.218880227273814 + - 45.45813303733705 + + check_stackability: false + + num_classes: 2 + + train_transform: + - class_path: albumentations.D4 + - class_path: ToTensorV2 + val_transform: + - class_path: ToTensorV2 + test_transform: + - class_path: ToTensorV2 +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_factory: EncoderDecoderFactory + model_args: + backbone: timm_convnext_large.fb_in22k + num_classes: 2 + backbone_pretrained: true + necks: + # - name: SelectIndices + # indices: [1,2,3,4] + decoder: UNetDecoder + decoder_channels: [512, 256, 128, 64] + head_channel_list: [256] + head_dropout: 0.1 + loss: dice + # loss: ce + plot_on_val: 2 + ignore_index: -1 + freeze_backbone: false + freeze_decoder: false + + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 198 + # w_crop: 224 + # w_stride: 198 + # average_patches: True + +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 5.0e-05 + # betas: + # - 0.9 + # - 0.999 + # eps: 1.0e-08 + # weight_decay: 0.05 + # amsgrad: false + # maximize: false + # capturable: false + # differentiable: false + # ---- Optimizer stop if ---- +# lr_scheduler: +# class_path: CosineAnnealingLR +# init_args: +# T_max: 20 + +# lr_scheduler_interval: step +# lr_scheduler: +# class_path: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts +# init_args: +# T_0: 1000 # first cycle: 1000 steps +# T_mult: 2 # cycles: 1000, 2000, 4000, ... (fits well in 10k) +# eta_min: 1.0e-6 +lr_scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + monitor: val/loss + mode: min + factor: 0.5 + patience: 5 + threshold: 0.0001 + threshold_mode: rel + cooldown: 0 + min_lr: 0.0 + eps: 1.0e-08 diff --git a/configs/tests/terratorch_configs/test_case_02/test_encoderdecoder_eo_v2_300_model_factory.yaml b/configs/tests/terratorch_configs/test_case_02/test_encoderdecoder_eo_v2_300_model_factory.yaml new file mode 100644 index 0000000..c12eee4 --- /dev/null +++ b/configs/tests/terratorch_configs/test_case_02/test_encoderdecoder_eo_v2_300_model_factory.yaml @@ -0,0 +1,156 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: true + max_epochs: 2 + + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + # ---- Early stop if ---- + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 20 + # ---- Early stop endif ---- + - class_path: ModelCheckpoint + init_args: + dirpath: /dccstor/terratorch/tmp/eo_v2_300/ + mode: min + monitor: val/loss + filename: best-state_dict-{epoch:02d} + save_weights_only: True + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: /dccstor/terratorch/tmp/eo_v2_300/ + +data: + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule + init_args: + batch_size: 4 + num_workers: 2 + no_label_replace: -1 + no_data_replace: 0 + constant_scale: 1.0 + dataset_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + + output_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + + rgb_indices: + - 0 + - 1 + - 2 + + train_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + train_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + val_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + val_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + # Splits not available in ccc for burnscars data + # train_split: /data//geodata-060bbc44822a11efb3260a580a830dad/split_files/train_data.txt + # test_split: /data//geodata-060bbc44822a11efb3260a580a830dad/split_files/test_data.txt + # val_split: /data//geodata-060bbc44822a11efb3260a580a830dad/split_files/val_data.txt + ignore_split_file_extensions: true + allow_substring_split_file: true + img_grep: "*_merged.tif" + label_grep: "*.mask.tif" + means: + - 0.052829564761523104 + - 0.07822514779700994 + - 0.09545302348640401 + - 0.2128596444116123 + - 0.2363016737011897 + - 0.17234100022878698 + + stds: + - 0.028757146620143812 + - 0.03540772770593507 + - 0.05291947163682527 + - 0.06949186937256507 + - 0.08958868240264736 + - 0.08198354165348874 + + num_classes: 2 + # ---- train_transform if ---- + # ---- train_transform endif ---- + + # if backbone is prithvi-EO-v2 + test_transform: + - class_path: ToTensorV2 +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_args: + backbone_pretrained: true + backbone: prithvi_eo_v2_300 + # backbone_ckpt_path: /terratorch/gfm_models/prithvi_eo_v2_300/Prithvi_EO_V2_300M.pt + backbone_drop_path: 0.1 + backbone_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + + + necks: + - name: SelectIndices + indices: [5, 11, 17, 23] # 300M models + - name: ReshapeTokensToImage # required + - name: LearnedInterpolateToPyramidal + decoder: UNetDecoder + #TODO user provided channels + decoder_channels: [512, 256, 128, 64] + num_classes: 2 + head_dropout: 0.1 + model_factory: EncoderDecoderFactory + loss: ce + plot_on_val: 2 + ignore_index: -1 + freeze_backbone: false + freeze_decoder: false + + # ---- optimizer start ---- + # ---- optimizer end ---- + + tiled_inference_parameters: + h_crop: 512 + h_stride: 448 + w_crop: 512 + w_stride: 448 + average_patches: True + +optimizer: + class_path: torch.optim.Adam + init_args: + # ---- Optimizer start if ---- + lr: 6e-05 + + weight_decay: 0.05 + # ---- Optimizer stop if ---- +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss \ No newline at end of file diff --git a/configs/tests/terratorch_configs/test_case_03/test_encoder_decoder_timm_resnet101_model_factory.yaml b/configs/tests/terratorch_configs/test_case_03/test_encoder_decoder_timm_resnet101_model_factory.yaml new file mode 100644 index 0000000..253de10 --- /dev/null +++ b/configs/tests/terratorch_configs/test_case_03/test_encoder_decoder_timm_resnet101_model_factory.yaml @@ -0,0 +1,154 @@ +################################################################ +# Licensed Materials - Property of IBM +# "Restricted Materials of IBM" +# Copyright IBM Corp. 2025 ALL RIGHTS RESERVED +################################################################ + + +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: true + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + # ---- Early stop if ---- + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 20 + # ---- Early stop endif ---- + - class_path: ModelCheckpoint + init_args: + dirpath: /dccstor/terratorch/tmp/timm_resnet101/ + mode: min + monitor: val/loss + filename: best-state_dict-{epoch:02d} + save_weights_only: True + + max_epochs: 2 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: /dccstor/terratorch/tmp/timm_resnet101/ +data: + class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule + init_args: + batch_size: 4 + num_workers: 2 + no_label_replace: -1 + no_data_replace: 0 + constant_scale: 1.0 + dataset_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + + output_bands: + - '0' + - '1' + - '2' + - '3' + - '4' + - '5' + + rgb_indices: + - 0 + - 1 + - 2 + + train_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + train_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training + val_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + val_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + test_label_data_root: /dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation + # Splits not available in ccc for burnscars data + # train_split: /data//geodata-060bbc44822a11efb3260a580a830dad/split_files/train_data.txt + # test_split: /data//geodata-060bbc44822a11efb3260a580a830dad/split_files/test_data.txt + # val_split: /data//geodata-060bbc44822a11efb3260a580a830dad/split_files/val_data.txt + ignore_split_file_extensions: true + allow_substring_split_file: true + img_grep: "*_merged.tif" + label_grep: "*.mask.tif" + means: + - 0.052829564761523104 + - 0.07822514779700994 + - 0.09545302348640401 + - 0.2128596444116123 + - 0.2363016737011897 + - 0.17234100022878698 + + stds: + - 0.028757146620143812 + - 0.03540772770593507 + - 0.05291947163682527 + - 0.06949186937256507 + - 0.08958868240264736 + - 0.08198354165348874 + + num_classes: 2 + # ---- train_transform if ---- + # ---- train_transform endif ---- + + # if backbone is prithvi-EO-v2 + test_transform: + - class_path: ToTensorV2 +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_args: + backbone: timm_resnet101 # timm_resnet34 , timm_resnet18 , timm_resnet50 , timm_resnet101 , timm_resnet152 + backbone_pretrained: true + num_classes: 3 + backbone_in_chans: 6 # To be used with RGB when pretrained, can be more if not retrained + necks: + - name: SelectIndices + indices: [0, 1, 2, 3] + decoder: UNetDecoder + #TODO user provided channels + decoder_channels: [512, 256, 128, 64] + head_channel_list: + - 256 + + head_dropout: 0.1 + + model_factory: EncoderDecoderFactory + loss: ce + plot_on_val: 2 + ignore_index: -1 + freeze_backbone: false + freeze_decoder: false + + # ---- optimizer start ---- + # ---- optimizer end ---- + + tiled_inference_parameters: + h_crop: 224 + h_stride: 196 + w_crop: 224 + w_stride: 196 + average_patches: True + +optimizer: + class_path: torch.optim.Adam + init_args: + # ---- Optimizer start if ---- + lr: 6e-05 + + weight_decay: 0.05 + # ---- Optimizer stop if ---- +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss \ No newline at end of file diff --git a/plotting/plot_results_mlflow.ipynb b/plotting/plot_results_mlflow.ipynb index 407e568..5d3a752 100644 --- a/plotting/plot_results_mlflow.ipynb +++ b/plotting/plot_results_mlflow.ipynb @@ -19,16 +19,24 @@ "outputs": [], "source": [ "def add_means_to_df(df, classification_datasets, segmentation_datasets):\n", - " class_means = df[df[\"Task\"].isin(classification_datasets)].groupby('Model', as_index=False).agg({'Best Score': 'mean'})\n", - " seg_means = df[df[\"Task\"].isin(segmentation_datasets)].groupby('Model', as_index=False).agg({'Best Score': 'mean'})\n", + " class_means = (\n", + " df[df[\"Task\"].isin(classification_datasets)]\n", + " .groupby(\"Model\", as_index=False)\n", + " .agg({\"Best Score\": \"mean\"})\n", + " )\n", + " seg_means = (\n", + " df[df[\"Task\"].isin(segmentation_datasets)]\n", + " .groupby(\"Model\", as_index=False)\n", + " .agg({\"Best Score\": \"mean\"})\n", + " )\n", "\n", - " class_means['Task'] = 'Classification Mean'\n", - " class_means['Metric'] = 'Mean' # You can adjust this as needed\n", - " class_means['Hyperparameters'] = None # Or fill with appropriate value\n", + " class_means[\"Task\"] = \"Classification Mean\"\n", + " class_means[\"Metric\"] = \"Mean\" # You can adjust this as needed\n", + " class_means[\"Hyperparameters\"] = None # Or fill with appropriate value\n", "\n", - " seg_means['Task'] = 'Segmentation Mean'\n", - " seg_means['Metric'] = 'Mean' # You can adjust this as needed\n", - " seg_means['Hyperparameters'] = None # Or fill with appropriate value\n", + " seg_means[\"Task\"] = \"Segmentation Mean\"\n", + " seg_means[\"Metric\"] = \"Mean\" # You can adjust this as needed\n", + " seg_means[\"Hyperparameters\"] = None # Or fill with appropriate value\n", "\n", " df = pd.concat([df, class_means, seg_means], ignore_index=True)\n", " return df" @@ -40,24 +48,60 @@ "metadata": {}, "outputs": [], "source": [ - "results_prithvi_subset_10 = pd.read_json(\"results_table_prithvi_subset_10.json\", orient=\"split\")\n", - "results_prithvi_swin_b_new_subset_10 = pd.read_json(\"results_table_prithvi_b_subset_new.json\", orient=\"split\")\n", - "results_scratch_subset_10 = pd.read_json(\"results_table_scratch_subset_10.json\", orient=\"split\")\n", - "results_imagenet_subset_10 = pd.read_json(\"results_table_imagenet_subset_10.json\", orient=\"split\")\n", - "results_imagenet_resnet_subset_10 = pd.read_json(\"results_table_resnet_rgb.json\", orient=\"split\")\n", - "results_prithvi_l_subset_10_old = pd.read_json(\"results_table_prithvi_l_subset_fixed.json\", orient=\"split\")\n", - "results_prithvi_l_subset_10 = pd.read_json(\"results_table_prithvi_l_subset_fixed_new.json\", orient=\"split\")\n", - "results_prithvi_h_subset_10 = pd.read_json(\"results_table_prithvi_h_subset.json\", orient=\"split\")\n", - "results_prithvi_l_subset_10_mask = pd.read_json(\"results_table_prithvi_l_subset_mask.json\", orient=\"split\")\n", - "results_prithvi_l_subset_10_fp32 = pd.read_json(\"results_table_prithvi_l_subset_fp32.json\", orient=\"split\")\n", - "prithvi_l_subset_coords_10 = pd.read_json(\"results_table_prithvi_l_subset_coords_pre_no_ft.json\", orient=\"split\")\n", - "prithvi_swin_l_subset_10 = pd.read_json(\"results_table_prithvi_swin_l_subset.json\", orient=\"split\")\n", - "results_prithvi_l_full_pretrain_subset_10 = pd.read_json(\"results_table_prithvi_l_full_pretrain_subset.json\", orient=\"split\")\n", - "results_prithvi_b_subset_10 = pd.read_json(\"results_table_prithvi_vit_b_subset.json\", orient=\"split\")\n", - "results_prithvi_b_subset_10_new = pd.read_json(\"results_table_prithvi_vit_b_subset_new.json\", orient=\"split\")\n", - "results_prithvi_b_os_subset_10 = pd.read_json(\"results_table_vit_b_os_subset.json\", orient=\"split\")\n", - "results_prithvi_3d_subset_10 = pd.read_json(\"results_table_swin_3d_subset_10.json\", orient=\"split\")\n", - "results_scalemae_subset_10 = pd.read_json(\"results_table_scalemae_subset.json\", orient=\"split\")\n", + "results_prithvi_subset_10 = pd.read_json(\n", + " \"results_table_prithvi_subset_10.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_swin_b_new_subset_10 = pd.read_json(\n", + " \"results_table_prithvi_b_subset_new.json\", orient=\"split\"\n", + ")\n", + "results_scratch_subset_10 = pd.read_json(\n", + " \"results_table_scratch_subset_10.json\", orient=\"split\"\n", + ")\n", + "results_imagenet_subset_10 = pd.read_json(\n", + " \"results_table_imagenet_subset_10.json\", orient=\"split\"\n", + ")\n", + "results_imagenet_resnet_subset_10 = pd.read_json(\n", + " \"results_table_resnet_rgb.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_l_subset_10_old = pd.read_json(\n", + " \"results_table_prithvi_l_subset_fixed.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_l_subset_10 = pd.read_json(\n", + " \"results_table_prithvi_l_subset_fixed_new.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_h_subset_10 = pd.read_json(\n", + " \"results_table_prithvi_h_subset.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_l_subset_10_mask = pd.read_json(\n", + " \"results_table_prithvi_l_subset_mask.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_l_subset_10_fp32 = pd.read_json(\n", + " \"results_table_prithvi_l_subset_fp32.json\", orient=\"split\"\n", + ")\n", + "prithvi_l_subset_coords_10 = pd.read_json(\n", + " \"results_table_prithvi_l_subset_coords_pre_no_ft.json\", orient=\"split\"\n", + ")\n", + "prithvi_swin_l_subset_10 = pd.read_json(\n", + " \"results_table_prithvi_swin_l_subset.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_l_full_pretrain_subset_10 = pd.read_json(\n", + " \"results_table_prithvi_l_full_pretrain_subset.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_b_subset_10 = pd.read_json(\n", + " \"results_table_prithvi_vit_b_subset.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_b_subset_10_new = pd.read_json(\n", + " \"results_table_prithvi_vit_b_subset_new.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_b_os_subset_10 = pd.read_json(\n", + " \"results_table_vit_b_os_subset.json\", orient=\"split\"\n", + ")\n", + "results_prithvi_3d_subset_10 = pd.read_json(\n", + " \"results_table_swin_3d_subset_10.json\", orient=\"split\"\n", + ")\n", + "results_scalemae_subset_10 = pd.read_json(\n", + " \"results_table_scalemae_subset.json\", orient=\"split\"\n", + ")\n", "results_satlas_subset_10 = pd.read_json(\"results_satlas_subset_10.json\", orient=\"split\")" ] }, @@ -67,7 +111,9 @@ "metadata": {}, "outputs": [], "source": [ - "results_prithvi_subset_10[\"Model\"] = \"Prithvi Swin B (Old Training Dataset - New Dataset model training currently)\"" + "results_prithvi_subset_10[\"Model\"] = (\n", + " \"Prithvi Swin B (Old Training Dataset - New Dataset model training currently)\"\n", + ")" ] }, { @@ -103,7 +149,9 @@ "metadata": {}, "outputs": [], "source": [ - "results_prithvi_l_full_pretrain_subset_10[\"Model\"] = \"Prithvi ViT L (New Training Dataset)\"" + "results_prithvi_l_full_pretrain_subset_10[\"Model\"] = (\n", + " \"Prithvi ViT L (New Training Dataset)\"\n", + ")" ] }, { @@ -112,7 +160,9 @@ "metadata": {}, "outputs": [], "source": [ - "results_prithvi_l_subset_10[\"Model\"] = \"Prithvi ViT L (1/3 pretraining, Mask ratio 0.75)\"" + "results_prithvi_l_subset_10[\"Model\"] = (\n", + " \"Prithvi ViT L (1/3 pretraining, Mask ratio 0.75)\"\n", + ")" ] }, { @@ -121,7 +171,9 @@ "metadata": {}, "outputs": [], "source": [ - "results_prithvi_h_subset_10[\"Model\"] = \"Prithvi ViT H (1/3 pretraining, Mask ratio 0.75)\"" + "results_prithvi_h_subset_10[\"Model\"] = (\n", + " \"Prithvi ViT H (1/3 pretraining, Mask ratio 0.75)\"\n", + ")" ] }, { @@ -130,7 +182,9 @@ "metadata": {}, "outputs": [], "source": [ - "results_prithvi_l_subset_10_mask[\"Model\"] = \"Prithvi ViT L (1/3 pretraining, Mask Ratio 0.9)\"" + "results_prithvi_l_subset_10_mask[\"Model\"] = (\n", + " \"Prithvi ViT L (1/3 pretraining, Mask Ratio 0.9)\"\n", + ")" ] }, { @@ -238,8 +292,25 @@ "metadata": {}, "outputs": [], "source": [ - "df = pd.concat([results_scratch_subset_10, results_imagenet_subset_10, results_imagenet_resnet_subset_10, results_scalemae_subset_10, results_prithvi_subset_10, results_prithvi_swin_b_new_subset_10, prithvi_swin_l_subset_10, results_prithvi_3d_subset_10, results_prithvi_h_subset_10, results_prithvi_l_full_pretrain_subset_10, results_prithvi_l_subset_10_fp32, prithvi_l_subset_coords_10, results_prithvi_b_subset_10_new, results_prithvi_b_os_subset_10, results_satlas_subset_10\n", - "])" + "df = pd.concat(\n", + " [\n", + " results_scratch_subset_10,\n", + " results_imagenet_subset_10,\n", + " results_imagenet_resnet_subset_10,\n", + " results_scalemae_subset_10,\n", + " results_prithvi_subset_10,\n", + " results_prithvi_swin_b_new_subset_10,\n", + " prithvi_swin_l_subset_10,\n", + " results_prithvi_3d_subset_10,\n", + " results_prithvi_h_subset_10,\n", + " results_prithvi_l_full_pretrain_subset_10,\n", + " results_prithvi_l_subset_10_fp32,\n", + " prithvi_l_subset_coords_10,\n", + " results_prithvi_b_subset_10_new,\n", + " results_prithvi_b_os_subset_10,\n", + " results_satlas_subset_10,\n", + " ]\n", + ")" ] }, { @@ -248,13 +319,35 @@ "metadata": {}, "outputs": [], "source": [ - "name_mapping = {\"big_earth_net\": \"m-bigearthnet\", \"brick_kiln\": \"m-brick-kiln\", \"eurosat\": \"m-eurosat\", \"forestnet\": \"m-forestnet\", \"pv4ger\": \"m-pv4ger\", \"so2sat\": \"m-so2sat\", \"neontree\": \"m-NeonTree\", \"sa_crop_type\": \"m-SA-crop-type\", \"cashew\": \"m-cashew-plant\", \"chesapeake\": \"m-chesapeake\", \"nz_cattle\": \"m-nz-cattle\", \"pv4ger_seg\": \"m-pv4ger-seg\"}\n", + "name_mapping = {\n", + " \"big_earth_net\": \"m-bigearthnet\",\n", + " \"brick_kiln\": \"m-brick-kiln\",\n", + " \"eurosat\": \"m-eurosat\",\n", + " \"forestnet\": \"m-forestnet\",\n", + " \"pv4ger\": \"m-pv4ger\",\n", + " \"so2sat\": \"m-so2sat\",\n", + " \"neontree\": \"m-NeonTree\",\n", + " \"sa_crop_type\": \"m-SA-crop-type\",\n", + " \"cashew\": \"m-cashew-plant\",\n", + " \"chesapeake\": \"m-chesapeake\",\n", + " \"nz_cattle\": \"m-nz-cattle\",\n", + " \"pv4ger_seg\": \"m-pv4ger-seg\",\n", + "}\n", "df[\"Task\"] = df[\"Task\"].map(name_mapping)\n", "\n", - "classification_datasets = [\"m-bigearthnet\", \"m-brick-kiln\", \"m-eurosat\", \"m-forestnet\", \"m-pv4ger\", \"m-so2sat\"]\n", + "classification_datasets = [\n", + " \"m-bigearthnet\",\n", + " \"m-brick-kiln\",\n", + " \"m-eurosat\",\n", + " \"m-forestnet\",\n", + " \"m-pv4ger\",\n", + " \"m-so2sat\",\n", + "]\n", "# exclude bigearthnet for now\n", "# classification_datasets = [\"m-brick-kiln\", \"m-eurosat\", \"m-forestnet\", \"m-pv4ger\", \"m-so2sat\"]\n", - "segmentation_datasets = list(set(df[\"Task\"].unique().tolist()) - set(classification_datasets))\n" + "segmentation_datasets = list(\n", + " set(df[\"Task\"].unique().tolist()) - set(classification_datasets)\n", + ")" ] }, { @@ -272,7 +365,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "dataset_res = {\n", " \"m-bigearthnet\": \"10m\",\n", " \"m-so2sat\": \"10m\",\n", @@ -285,7 +377,7 @@ " \"m-cashew-plant\": \"10m\",\n", " \"m-SA-crop-type\": \"10m\",\n", " \"m-nz-cattle\": \"0.1m\",\n", - " \"m-NeonTree\": \"0.1m\"\n", + " \"m-NeonTree\": \"0.1m\",\n", "}\n", "\n", "dataset_instrument = {\n", @@ -300,7 +392,7 @@ " \"m-cashew-plant\": \"S2\",\n", " \"m-SA-crop-type\": \"S2\",\n", " \"m-nz-cattle\": \"RGB\",\n", - " \"m-NeonTree\": \"RGB + Hyper\"\n", + " \"m-NeonTree\": \"RGB + Hyper\",\n", "}\n", "\n", "img_size = {\n", @@ -315,14 +407,16 @@ " \"m-cashew-plant\": 256,\n", " \"m-SA-crop-type\": 256,\n", " \"m-nz-cattle\": 500,\n", - " \"m-NeonTree\": 400\n", + " \"m-NeonTree\": 400,\n", "}\n", "\n", - "dataset_name_map = {name: f\"{name}\\n {dataset_instrument[name]} @ {dataset_res[name]}\" for name in img_size.keys()}\n", + "dataset_name_map = {\n", + " name: f\"{name}\\n {dataset_instrument[name]} @ {dataset_res[name]}\"\n", + " for name in img_size.keys()\n", + "}\n", "dataset_name_map[\"Segmentation Mean\"] = \"Segmentation Mean\"\n", "dataset_name_map[\"Classification Mean\"] = \"Classification Mean\"\n", - "df[\"Task\"] = df[\"Task\"].map(dataset_name_map)\n", - "\n" + "df[\"Task\"] = df[\"Task\"].map(dataset_name_map)" ] }, { @@ -396,9 +490,15 @@ ], "source": [ "g = sns.catplot(\n", - " data=df, kind=\"bar\",\n", - " x=\"Task\", y=\"Best Score\", hue=\"Model\",\n", - " errorbar=\"sd\", palette=\"dark\", alpha=.6, height=10\n", + " data=df,\n", + " kind=\"bar\",\n", + " x=\"Task\",\n", + " y=\"Best Score\",\n", + " hue=\"Model\",\n", + " errorbar=\"sd\",\n", + " palette=\"dark\",\n", + " alpha=0.6,\n", + " height=10,\n", ")\n", "g.despine(left=True)\n", "g.set_axis_labels(\"Dataset\", \"Metric\")\n", @@ -413,8 +513,18 @@ "metadata": {}, "outputs": [], "source": [ - "df_subset = pd.concat([results_imagenet_subset_10, results_imagenet_resnet_subset_10, results_scalemae_subset_10, results_prithvi_subset_10, results_prithvi_swin_b_new_subset_10, results_prithvi_l_full_pretrain_subset_10,results_prithvi_b_subset_10_new, results_satlas_subset_10\n", - "])\n", + "df_subset = pd.concat(\n", + " [\n", + " results_imagenet_subset_10,\n", + " results_imagenet_resnet_subset_10,\n", + " results_scalemae_subset_10,\n", + " results_prithvi_subset_10,\n", + " results_prithvi_swin_b_new_subset_10,\n", + " results_prithvi_l_full_pretrain_subset_10,\n", + " results_prithvi_b_subset_10_new,\n", + " results_satlas_subset_10,\n", + " ]\n", + ")\n", "df_subset[\"Task\"] = df_subset[\"Task\"].map(name_mapping)\n", "df_subset = add_means_to_df(df_subset, classification_datasets, segmentation_datasets)\n", "df_subset[\"Task\"] = df_subset[\"Task\"].map(dataset_name_map)" @@ -456,9 +566,15 @@ ], "source": [ "g = sns.catplot(\n", - " data=df_subset, kind=\"bar\",\n", - " x=\"Task\", y=\"Best Score\", hue=\"Model\",\n", - " errorbar=\"sd\", palette=\"dark\", alpha=.6, height=10\n", + " data=df_subset,\n", + " kind=\"bar\",\n", + " x=\"Task\",\n", + " y=\"Best Score\",\n", + " hue=\"Model\",\n", + " errorbar=\"sd\",\n", + " palette=\"dark\",\n", + " alpha=0.6,\n", + " height=10,\n", ")\n", "g.despine(left=True)\n", "g.set_axis_labels(\"Dataset\", \"Metric\")\n", diff --git a/plotting/plot_results_repeated_runs.ipynb b/plotting/plot_results_repeated_runs.ipynb index 5e5080d..bcb8f85 100644 --- a/plotting/plot_results_repeated_runs.ipynb +++ b/plotting/plot_results_repeated_runs.ipynb @@ -38,8 +38,7 @@ "from matplotlib.ticker import FormatStrFormatter\n", "import json\n", "from scipy.stats import trim_mean\n", - "import plot_tools\n", - "\n" + "import plot_tools" ] }, { @@ -48,7 +47,20 @@ "metadata": {}, "outputs": [], "source": [ - "name_mapping = {\"big_earth_net\": \"m-bigearthnet\", \"brick_kiln\": \"m-brick-kiln\", \"eurosat\": \"m-eurosat\", \"forestnet\": \"m-forestnet\", \"pv4ger\": \"m-pv4ger\", \"so2sat\": \"m-so2sat\", \"neontree\": \"m-NeonTree\", \"sa_crop_type\": \"m-SA-crop-type\", \"cashew\": \"m-cashew-plant\", \"chesapeake\": \"m-chesapeake\", \"nz_cattle\": \"m-nz-cattle\", \"pv4ger_seg\": \"m-pv4ger-seg\"}" + "name_mapping = {\n", + " \"big_earth_net\": \"m-bigearthnet\",\n", + " \"brick_kiln\": \"m-brick-kiln\",\n", + " \"eurosat\": \"m-eurosat\",\n", + " \"forestnet\": \"m-forestnet\",\n", + " \"pv4ger\": \"m-pv4ger\",\n", + " \"so2sat\": \"m-so2sat\",\n", + " \"neontree\": \"m-NeonTree\",\n", + " \"sa_crop_type\": \"m-SA-crop-type\",\n", + " \"cashew\": \"m-cashew-plant\",\n", + " \"chesapeake\": \"m-chesapeake\",\n", + " \"nz_cattle\": \"m-nz-cattle\",\n", + " \"pv4ger_seg\": \"m-pv4ger-seg\",\n", + "}" ] }, { @@ -73,7 +85,9 @@ "prithvi_os_results = pd.read_csv(\"prithvi_vit_os.csv\", index_col=\"Unnamed: 0\")\n", "prithvi_os_results[\"Backbone\"] = \"prithvi-eo-hls-100m-vit-os\"\n", "\n", - "prithvi_results = pd.concat([prithvi_results, prithvi_os_results, prithvi_global_results], ignore_index=True)" + "prithvi_results = pd.concat(\n", + " [prithvi_results, prithvi_os_results, prithvi_global_results], ignore_index=True\n", + ")" ] }, { @@ -82,7 +96,9 @@ "metadata": {}, "outputs": [], "source": [ - "prithvi_results = prithvi_results.rename(columns={\"Task\": \"dataset\", \"Backbone\": \"model\", \"Score\": \"test metric\"})\n", + "prithvi_results = prithvi_results.rename(\n", + " columns={\"Task\": \"dataset\", \"Backbone\": \"model\", \"Score\": \"test metric\"}\n", + ")\n", "prithvi_results[\"partition name\"] = \"1.00x train\"\n", "prithvi_results[\"dataset\"] = prithvi_results[\"dataset\"].map(name_mapping)" ] @@ -93,7 +109,9 @@ "metadata": {}, "outputs": [], "source": [ - "prithvi_results_2 = prithvi_results_2.rename(columns={\"Task\": \"dataset\", \"Backbone\": \"model\", \"Score\": \"test metric\"})\n", + "prithvi_results_2 = prithvi_results_2.rename(\n", + " columns={\"Task\": \"dataset\", \"Backbone\": \"model\", \"Score\": \"test metric\"}\n", + ")\n", "prithvi_results_2[\"partition name\"] = \"1.00x train\"\n", "prithvi_results_2[\"dataset\"] = prithvi_results_2[\"dataset\"].map(name_mapping)" ] @@ -106,8 +124,12 @@ "source": [ "geobench_results_class = pd.read_csv(\"baseline_classification_results.csv\")\n", "df_1x = plot_tools.extract_1x_data(geobench_results_class)\n", - "model_order = \"prithvi-eo-hls-90m-swin-B,prithvi-eo-hls-100m-vit,prithvi-eo-hls-100m-vit-os,ResNet18-Rnd,ResNet18-timm,ResNet18-MoCo-S2,ResNet50-SECO-S2,ResNet50-MoCo-S2,ResNet50-timm,ConvNeXt-B-timm,ViT-T-timm,ViT-S-timm,SwinV2-T-timm\".split(\",\")\n", - "model_colors = dict( zip(model_order, sns.color_palette(\"tab20\", n_colors=len(model_order))))" + "model_order = \"prithvi-eo-hls-90m-swin-B,prithvi-eo-hls-100m-vit,prithvi-eo-hls-100m-vit-os,ResNet18-Rnd,ResNet18-timm,ResNet18-MoCo-S2,ResNet50-SECO-S2,ResNet50-MoCo-S2,ResNet50-timm,ConvNeXt-B-timm,ViT-T-timm,ViT-S-timm,SwinV2-T-timm\".split(\n", + " \",\"\n", + ")\n", + "model_colors = dict(\n", + " zip(model_order, sns.color_palette(\"tab20\", n_colors=len(model_order)))\n", + ")" ] }, { @@ -160,10 +182,19 @@ "metadata": {}, "outputs": [], "source": [ - "classification_datasets = [\"m-bigearthnet\", \"m-brick-kiln\", \"m-eurosat\", \"m-forestnet\", \"m-pv4ger\", \"m-so2sat\"]\n", + "classification_datasets = [\n", + " \"m-bigearthnet\",\n", + " \"m-brick-kiln\",\n", + " \"m-eurosat\",\n", + " \"m-forestnet\",\n", + " \"m-pv4ger\",\n", + " \"m-so2sat\",\n", + "]\n", "# exclude bigearthnet for now\n", "# classification_datasets = [\"m-brick-kiln\", \"m-eurosat\", \"m-forestnet\", \"m-pv4ger\", \"m-so2sat\"]\n", - "segmentation_datasets = list(set(prithvi_results[\"dataset\"].unique().tolist()) - set(classification_datasets))\n", + "segmentation_datasets = list(\n", + " set(prithvi_results[\"dataset\"].unique().tolist()) - set(classification_datasets)\n", + ")\n", "# segmentation_datasets.remove(\"m-bigearthnet\")\n", "# segmentation_datasets.remove(\"m-cashew-plant\")" ] @@ -174,11 +205,17 @@ "metadata": {}, "outputs": [], "source": [ - "prithvi_class = prithvi_results[prithvi_results[\"dataset\"].isin(classification_datasets)]\n", + "prithvi_class = prithvi_results[\n", + " prithvi_results[\"dataset\"].isin(classification_datasets)\n", + "]\n", "prithvi_seg = prithvi_results[prithvi_results[\"dataset\"].isin(segmentation_datasets)]\n", "\n", - "prithvi_class_2 = prithvi_results_2[prithvi_results_2[\"dataset\"].isin(classification_datasets)]\n", - "prithvi_seg_2 = prithvi_results_2[prithvi_results_2[\"dataset\"].isin(segmentation_datasets)]" + "prithvi_class_2 = prithvi_results_2[\n", + " prithvi_results_2[\"dataset\"].isin(classification_datasets)\n", + "]\n", + "prithvi_seg_2 = prithvi_results_2[\n", + " prithvi_results_2[\"dataset\"].isin(segmentation_datasets)\n", + "]" ] }, { @@ -199,7 +236,7 @@ " \"m-cashew-plant\": \"10m\",\n", " \"m-SA-crop-type\": \"10m\",\n", " \"m-nz-cattle\": \"0.1m\",\n", - " \"m-NeonTree\": \"0.1m\"\n", + " \"m-NeonTree\": \"0.1m\",\n", "}\n", "\n", "dataset_instrument = {\n", @@ -214,7 +251,7 @@ " \"m-cashew-plant\": \"S2\",\n", " \"m-SA-crop-type\": \"S2\",\n", " \"m-nz-cattle\": \"RGB\",\n", - " \"m-NeonTree\": \"RGB + Hyper\"\n", + " \"m-NeonTree\": \"RGB + Hyper\",\n", "}\n", "\n", "img_size = {\n", @@ -229,10 +266,13 @@ " \"m-cashew-plant\": 256,\n", " \"m-SA-crop-type\": 256,\n", " \"m-nz-cattle\": 500,\n", - " \"m-NeonTree\": 400\n", + " \"m-NeonTree\": 400,\n", "}\n", "\n", - "dataset_name_map = {name: f\"{name}\\n {dataset_instrument[name]} @ {dataset_res[name]}\" for name in img_size.keys()}" + "dataset_name_map = {\n", + " name: f\"{name}\\n {dataset_instrument[name]} @ {dataset_res[name]}\"\n", + " for name in img_size.keys()\n", + "}" ] }, { @@ -241,7 +281,13 @@ "metadata": {}, "outputs": [], "source": [ - "class_df = pd.concat([df_1x[[\"model\", \"dataset\", \"test metric\", \"partition name\"]], prithvi_class.drop(columns=[\"Metric\"])], ignore_index=True)\n", + "class_df = pd.concat(\n", + " [\n", + " df_1x[[\"model\", \"dataset\", \"test metric\", \"partition name\"]],\n", + " prithvi_class.drop(columns=[\"Metric\"]),\n", + " ],\n", + " ignore_index=True,\n", + ")\n", "# class_df = pd.concat([class_df[[\"model\", \"dataset\", \"test metric\", \"partition name\"]], prithvi_class_2.drop(columns=[\"Metric\"])], ignore_index=True)\n", "# class_df[\"dataset\"] = class_df[\"dataset\"].map(lambda x: f'{x} ({dataset_instrument[x]} [{dataset_res[x]}])\\n{img_size[x]} x {img_size[x]}').astype(str)" ] @@ -261,7 +307,9 @@ } ], "source": [ - "class_df.groupby([\"model\", \"dataset\"]).agg([\"mean\", \"std\"]).to_csv(\"table_classification.csv\")" + "class_df.groupby([\"model\", \"dataset\"]).agg([\"mean\", \"std\"]).to_csv(\n", + " \"table_classification.csv\"\n", + ")" ] }, { @@ -312,7 +360,16 @@ ], "source": [ "class_df[\"dataset\"] = class_df[\"dataset\"].map(dataset_name_map)\n", - "plot_tools.plot_per_dataset(class_df, model_order, model_colors=model_colors, metric=\"test metric\", sharey=False, inner=\"points\", fig_size=(14, 3), n_legend_rows=2)\n", + "plot_tools.plot_per_dataset(\n", + " class_df,\n", + " model_order,\n", + " model_colors=model_colors,\n", + " metric=\"test metric\",\n", + " sharey=False,\n", + " inner=\"points\",\n", + " fig_size=(14, 3),\n", + " n_legend_rows=2,\n", + ")\n", "plt.savefig(\"classification_raw.png\", bbox_inches=\"tight\")" ] }, @@ -367,7 +424,15 @@ } ], "source": [ - "agg_class = plot_tools.normalize_bootstrap_and_plot(class_df, metric=\"test metric\",benchmark_name=\"classification_v1.0\", model_order=model_order, model_colors=model_colors, fig_size=(12,2.3), dataset_name_map=dataset_name_map)\n", + "agg_class = plot_tools.normalize_bootstrap_and_plot(\n", + " class_df,\n", + " metric=\"test metric\",\n", + " benchmark_name=\"classification_v1.0\",\n", + " model_order=model_order,\n", + " model_colors=model_colors,\n", + " fig_size=(12, 2.3),\n", + " dataset_name_map=dataset_name_map,\n", + ")\n", "plt.savefig(\"classification_normalized.png\", bbox_inches=\"tight\")" ] }, @@ -386,7 +451,9 @@ } ], "source": [ - "agg_class.groupby([\"model\", \"dataset\"]).agg([\"mean\", \"std\"]).to_csv(\"class_with_aggregated.csv\")" + "agg_class.groupby([\"model\", \"dataset\"]).agg([\"mean\", \"std\"]).to_csv(\n", + " \"class_with_aggregated.csv\"\n", + ")" ] }, { @@ -402,8 +469,12 @@ "metadata": {}, "outputs": [], "source": [ - "model_order = 'prithvi-eo-hls-90m-swin-B,prithvi-eo-hls-100m-vit,prithvi-eo-hls-100m-vit-os,ResNet18-U-Net-timm,ResNet50-U-Net-timm,ResNet101-U-Net-timm,ResNet18 DeepLabV3-timm,ResNet50 DeepLabV3-timm,ResNet101 DeepLabV3-timm'.split(',')\n", - "model_colors = dict( zip(model_order, sns.color_palette(\"tab20\", n_colors=len(model_order))))" + "model_order = \"prithvi-eo-hls-90m-swin-B,prithvi-eo-hls-100m-vit,prithvi-eo-hls-100m-vit-os,ResNet18-U-Net-timm,ResNet50-U-Net-timm,ResNet101-U-Net-timm,ResNet18 DeepLabV3-timm,ResNet50 DeepLabV3-timm,ResNet101 DeepLabV3-timm\".split(\n", + " \",\"\n", + ")\n", + "model_colors = dict(\n", + " zip(model_order, sns.color_palette(\"tab20\", n_colors=len(model_order)))\n", + ")" ] }, { @@ -422,7 +493,13 @@ "metadata": {}, "outputs": [], "source": [ - "seg_df = pd.concat([df_1x[[\"model\", \"dataset\", \"test metric\", \"partition name\"]], prithvi_seg.drop(columns=[\"Metric\"])], ignore_index=True)" + "seg_df = pd.concat(\n", + " [\n", + " df_1x[[\"model\", \"dataset\", \"test metric\", \"partition name\"]],\n", + " prithvi_seg.drop(columns=[\"Metric\"]),\n", + " ],\n", + " ignore_index=True,\n", + ")" ] }, { @@ -880,7 +957,16 @@ ], "source": [ "seg_df[\"dataset\"] = seg_df[\"dataset\"].map(dataset_name_map)\n", - "plot_tools.plot_per_dataset(seg_df, model_order, model_colors=model_colors, metric=\"test metric\", sharey=False, inner=\"points\", fig_size=(14, 3), n_legend_rows=2)\n", + "plot_tools.plot_per_dataset(\n", + " seg_df,\n", + " model_order,\n", + " model_colors=model_colors,\n", + " metric=\"test metric\",\n", + " sharey=False,\n", + " inner=\"points\",\n", + " fig_size=(14, 3),\n", + " n_legend_rows=2,\n", + ")\n", "plt.savefig(\"segmentation_raw.png\", bbox_inches=\"tight\")" ] }, @@ -935,7 +1021,15 @@ } ], "source": [ - "agg_seg = plot_tools.normalize_bootstrap_and_plot(seg_df, metric=\"test metric\",benchmark_name=\"segmentation_v1.0\", model_order=model_order, model_colors=model_colors, fig_size=(12,2.3), dataset_name_map=dataset_name_map)\n", + "agg_seg = plot_tools.normalize_bootstrap_and_plot(\n", + " seg_df,\n", + " metric=\"test metric\",\n", + " benchmark_name=\"segmentation_v1.0\",\n", + " model_order=model_order,\n", + " model_colors=model_colors,\n", + " fig_size=(12, 2.3),\n", + " dataset_name_map=dataset_name_map,\n", + ")\n", "\n", "plt.savefig(\"segmentation_normalized.png\", bbox_inches=\"tight\")" ] @@ -993,7 +1087,15 @@ } ], "source": [ - "agg_seg = plot_tools.normalize_bootstrap_and_plot(seg_df[seg_df[\"dataset\"] != dataset_name_map[\"m-cashew-plant\"]], metric=\"test metric\",benchmark_name=\"segmentation_v1.0\", model_order=model_order, model_colors=model_colors, fig_size=(12,2.3), dataset_name_map=dataset_name_map)\n", + "agg_seg = plot_tools.normalize_bootstrap_and_plot(\n", + " seg_df[seg_df[\"dataset\"] != dataset_name_map[\"m-cashew-plant\"]],\n", + " metric=\"test metric\",\n", + " benchmark_name=\"segmentation_v1.0\",\n", + " model_order=model_order,\n", + " model_colors=model_colors,\n", + " fig_size=(12, 2.3),\n", + " dataset_name_map=dataset_name_map,\n", + ")\n", "plt.savefig(\"segmentation_normalized_no_cashew.png\", bbox_inches=\"tight\")" ] }, @@ -1012,7 +1114,9 @@ } ], "source": [ - "agg_seg.groupby([\"model\", \"dataset\"]).agg([\"mean\", \"std\"]).to_csv(\"seg_with_aggregated.csv\")" + "agg_seg.groupby([\"model\", \"dataset\"]).agg([\"mean\", \"std\"]).to_csv(\n", + " \"seg_with_aggregated.csv\"\n", + ")" ] }, { diff --git a/pyproject.toml b/pyproject.toml index e226ff1..e36edb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,8 +13,8 @@ include = ["benchmark*"] [project] name = "terratorch-iterate" -version = "0.1.5" -requires-python = ">= 3.9" +version = "0.2.0" +requires-python = ">= 3.10" description = "A terratorch's plugin for benchmarking and hyperparameter optimization" authors = [ { name = "Carlos Gomes"}, @@ -37,9 +37,7 @@ classifiers = [ readme = "README.md" dependencies = [ -# ObjectDetection is not supported on terratorch==1.0.2, so iterate relies on main branch -"terratorch", -# "terratorch", +"terratorch>=1.1.0", # requests>=2.32.0 because of this vulnerability https://github.com/psf/requests/security/advisories/GHSA-9wx4-h78v-vm56 "requests>=2.32.0", # Jinja2 vulnerability issue https://github.com/pallets/jinja/security/advisories/GHSA-h75v-3vvj-5mfj @@ -64,7 +62,6 @@ dependencies = [ "importlib-metadata", "numpy", "optuna", -"tabulate", "types-tabulate", "ray", "gputil", @@ -73,8 +70,8 @@ dependencies = [ "configspace", "optuna-integration", "seaborn", -"torchgeo", "psutil", +"tabulate>=0.9.0", ] [project.urls] @@ -83,7 +80,7 @@ Issues = "https://github.com/IBM/terratorch-iterate/issues" [project.optional-dependencies] dev = [ - "black", + "ruff", "flake8", "mkdocs-material", "mkdocstrings[python]", @@ -108,12 +105,12 @@ nvidia = ["pynvml"] amd = ["pyrsmi"] [tool.black] -target-version = ["py310"] +target-version = ["py312"] line-length = 88 skip-string-normalization = true [project.scripts] -iterate = "benchmark.main:main" +iterate = "terratorch_iterate.main:main" # ray_benchmark = "benchmark.benchmark_ray:main" # repeat_experiments = "benchmark.main:main" diff --git a/run_tests.py b/run_tests.py index 7427fb7..60fa0b1 100644 --- a/run_tests.py +++ b/run_tests.py @@ -1,48 +1,110 @@ import subprocess from pathlib import Path from typing import Optional -from tests.test_benchmark import TEST_CASE_IDS +from tests.integration.test_main import get_test_ids import click +import logging +import sys + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) # Set appropriate level + +# Create a StreamHandler that writes to stdout +ch = logging.StreamHandler(sys.stdout) +ch.setLevel(logging.DEBUG) # Set appropriate level for the handler + +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +ch.setFormatter(formatter) + +logger.addHandler(ch) # rm geobench_v1_prithvi* && bsub -e ~/geobench_v1_prithvi.err -o ~/geobench_v1_prithvi.out -M 40G -gpu "num=1/task:mode=exclusive_process:gmodel=NVIDIAA100_SXM4_80GB" terratorch iterate --hpo --config configs/geobench_v1_prithvi.yaml +REPO_HOME_DIR = Path(__file__).parent +LOGS_DIR = REPO_HOME_DIR / "logs" + +if not LOGS_DIR.exists(): + LOGS_DIR.mkdir() + +# Delete all files in logs dir +for item in LOGS_DIR.iterdir(): + if item.is_file(): + item.unlink() + + +@click.group() +def cli(): + pass + + +def submit_job( + stderr_file: str, + stdout_file: str, + tc_id: str | None = None, + config: str | None = None, +): + err_file = LOGS_DIR / stderr_file + # delete file if it exists + if err_file.exists(): + logger.info(f"Delete file {err_file}") + err_file.unlink(missing_ok=True) + assert not err_file.exists() + + out_file = LOGS_DIR / stdout_file + # delete file if it exists + if out_file.exists(): + logger.info(f"Delete file {out_file}") + out_file.unlink(missing_ok=True) + assert not out_file.exists() + if tc_id is not None: + jbsub = f'bsub -e {err_file} -o {out_file} -M 40G -gpu "num=1/task:mode=exclusive_process:gmodel=NVIDIAA100_SXM4_80GB" pytest -vv tests/integration/test_main.py::test_main[{tc_id}]' + elif config is not None: + jbsub = f'bsub -e {err_file} -o {out_file} -M 40G -gpu "num=1/task:mode=exclusive_process:gmodel=NVIDIAA100_SXM4_80GB" terratorch iterate --hpo --config {config}' + else: + raise ValueError("Error! Either tc_id or config must be not None") + cmd = jbsub.split() + result = subprocess.run(cmd, capture_output=True) + if result.returncode == 0: + logger.info(f"Command executed successfully: {jbsub}") + + else: + logger.info(f"Command failed: {jbsub}") + logger.info("Command failed with error code:", result.returncode) + logger.info("stderr:", result.stderr) + @click.command() -@click.option('--test_id', default=None, help='test ID') +@click.option("--test_id", default=None, help="test ID") def run_tests(test_id: Optional[str] = None): if test_id is None: - test_ids = TEST_CASE_IDS + test_ids = get_test_ids() else: test_ids = [test_id] for tc_id in test_ids: - print(f"Running test case: tests/test_benchmark.py::test_run_benchmark {tc_id}") - stderr_file = f"test-iterate-test_benchmark-{tc_id}.err" - stdout_file = f"test-iterate-test_benchmark-{tc_id}.out" - - err_file = Path.home() / stderr_file - # delete file if it exists - if err_file.exists(): - print(f"Delete file {err_file}") - err_file.unlink(missing_ok=True) - assert not err_file.exists() - out_file = Path.home() / stdout_file - - # delete file if it exists - if out_file.exists(): - print(f"Delete file {out_file}") - out_file.unlink(missing_ok=True) - assert not out_file.exists() - jbsub = f"bsub -e {err_file} -o {out_file} -M 40G -gpu \"num=1/task:mode=exclusive_process:gmodel=NVIDIAA100_SXM4_80GB\" pytest -vv tests/test_benchmark.py::test_run_benchmark[{tc_id}]" - cmd = jbsub.split() - result = subprocess.run(cmd, capture_output=True) - if result.returncode == 0: - print(f"Command executed successfully: {jbsub}") - - else: - print(f"Command failed: {jbsub}") - print("Command failed with error code:", result.returncode) - print("stderr:", result.stderr) + logger.info( + f"Running test case: tests/test_benchmark.py::test_run_benchmark {tc_id}" + ) + stderr_file = f"{tc_id}.err" + stdout_file = f"{tc_id}.out" + + submit_job(stderr_file=stderr_file, stdout_file=stdout_file, tc_id=tc_id) + + +@click.command() +@click.option("--config", default=None, help="path to config file") +def run_job(config: str): + home_dir = Path(__file__).parent + config_path = home_dir / config + assert config_path.exists() + stem = config_path.stem + err_file = f"{stem}.err" + out_file = f"{stem}.out" + logger.info(f"Running job with config: {config}") + submit_job(stdout_file=out_file, stderr_file=err_file, config=config) + +cli.add_command(run_job) +cli.add_command(run_tests) if __name__ == "__main__": - run_tests() + cli() diff --git a/benchmark/__init__.py b/terratorch_iterate/__init__.py similarity index 100% rename from benchmark/__init__.py rename to terratorch_iterate/__init__.py diff --git a/benchmark/backbone_benchmark.py b/terratorch_iterate/backbone_benchmark.py similarity index 70% rename from benchmark/backbone_benchmark.py rename to terratorch_iterate/backbone_benchmark.py index cf161a8..412d245 100644 --- a/benchmark/backbone_benchmark.py +++ b/terratorch_iterate/backbone_benchmark.py @@ -17,16 +17,16 @@ from optuna.samplers import BaseSampler, RandomSampler from tabulate import tabulate import pickle -from benchmark.benchmark_types import ( +from terratorch_iterate.iterate_types import ( Defaults, ParameterBounds, Task, combine_with_defaults, optimization_space_type, ) -from benchmark.model_fitting import fit_model, fit_model_with_hparams -from benchmark.repeat_best_experiment import rerun_best_from_backbone -from benchmark.utils import ( +from terratorch_iterate.model_fitting import fit_model, fit_model_with_hparams +from terratorch_iterate.repeat_best_experiment import rerun_best_from_backbone +from terratorch_iterate.utils import ( check_existing_task_parent_runs, check_existing_experiments, unflatten, @@ -39,7 +39,7 @@ def benchmark_backbone_on_task( - logger, + logger: logging.RootLogger, defaults: Defaults, task: Task, storage_uri: str, @@ -52,8 +52,14 @@ def benchmark_backbone_on_task( sampler: BaseSampler | None = None, test_models: bool = False, ) -> tuple[float, str | list[str] | None, dict[str, Any]]: + logger.info( + f"starting backbone benchmark on task {task.name} {task_run_id=} {experiment_name=}" + ) + if storage_uri.startswith("http"): + optuna_db_path = Path(".") / "optuna_db" + else: + optuna_db_path = Path(storage_uri).parents[0] / "optuna_db" - optuna_db_path = Path(storage_uri).parents[0] / "optuna_db" if not os.path.exists(optuna_db_path): os.makedirs(optuna_db_path) optuna_db_path = optuna_db_path / f"{experiment_name}_{experiment_run_id}" @@ -68,8 +74,13 @@ def benchmark_backbone_on_task( n_trials=n_trials, logger=logger, ) - - with mlflow.start_run(run_name=task.name, nested=True, run_id=task_run_id) as run: + if task_run_id is not None: + # run_name is used only when run_id is unspecified. + run_name = None + else: + run_name = task.name + logger.info(f"start run: {run_name=} {task_run_id=}") + with mlflow.start_run(run_name=run_name, nested=True, run_id=task_run_id) as run: logger.info(f"starting task run with id: {run.info.run_id}") training_spec = combine_with_defaults(task, defaults) if "max_epochs" not in training_spec.trainer_args: @@ -136,15 +147,15 @@ def benchmark_backbone_on_task( "early_stop_patience": str(training_spec.task.early_stop_patience), "partition_name": ( str(training_spec.task.datamodule.partition) - if hasattr(training_spec.task.datamodule, 'partition') - else 'default' + if hasattr(training_spec.task.datamodule, "partition") + else "default" ), "decoder": ( str(training_spec.task.terratorch_task["model_args"]["decoder"]) if "decoder" in training_spec.task.terratorch_task["model_args"] - else training_spec.task.terratorch_task['model_args']['framework'] + else training_spec.task.terratorch_task["model_args"]["framework"] ), - "task": str(training_spec.task.type).split('.')[-1], + "task": str(training_spec.task.type).split(".")[-1], "backbone": str( training_spec.task.terratorch_task["model_args"]["backbone"] ), @@ -179,6 +190,104 @@ def parse_optimization_space(space: dict | None) -> optimization_space_type | No return parsed_space +def _run_hpo( + run_name: str | None, + run_id: str | None, + description: str, + tasks: list, + completed_task_run_names: list, + task_run_to_id_match: dict, + defaults, + storage_uri: str, + experiment_name: str, + optimization_space, + n_trials, + save_models, + sampler, + test_models, + table_entries, + table_columns, + backbone, + task_names, + PATH_TO_JOB_TRACKING, + logger, +) -> tuple[str, str]: + logger.info( + f"Running hyperparameter optimization: {run_name=} {run_id=} {description=}" + ) + if run_id is not None: + run_name = None + + with mlflow.start_run( + run_name=run_name, run_id=run_id, description=description + ) as run: + for task in tasks: + # only run task if it was not completed before + task_run_name = task.name + if task_run_name in completed_task_run_names: + logger.info(f"{task_run_name} already completed") + continue + else: + logger.info(f"{task_run_name} not completed. starting now") + + task_run_id = ( + task_run_to_id_match[task_run_name] + if task_run_name in task_run_to_id_match + else None + ) + best_value, metric_name, hparams = benchmark_backbone_on_task( + logger, + defaults, + task, + storage_uri, + experiment_name, + experiment_run_id=run.info.run_id, + task_run_id=task_run_id, + optimization_space=optimization_space, + n_trials=n_trials, + save_models=save_models, + sampler=sampler, + test_models=test_models, + ) + table_entries.append([task.name, metric_name, best_value, hparams]) + table_entries_filename = str( + PATH_TO_JOB_TRACKING + / f"{experiment_name}-{run.info.run_id}_table_entries.pkl" + ) + with open(table_entries_filename, "wb") as handle: + pickle.dump(table_entries, handle, protocol=pickle.HIGHEST_PROTOCOL) + + table = tabulate(table_entries, headers=table_columns) + logger.info(table) + df = pd.DataFrame(data=table_entries, columns=table_columns) + df.set_index("Task") + logger.info("Starting to save results") + mlflow.log_table( + df, + "results_table.json", + run.info.run_id, + ) + experiment_id = run.info.experiment_id + + # check completion of HPO for all tasks before proceeding to next stage + existing_experiments = check_existing_experiments( + logger=logger, + storage_uri=storage_uri, + experiment_name=experiment_name, + exp_parent_run_name=run_name, + task_names=task_names, + n_trials=n_trials, + backbone=backbone, + ) + if existing_experiments["finished_run"] is not None: + finished_run_id = existing_experiments["finished_run"] + else: + logger.info("HPO is not complete. Please re-run this experiment") + raise RuntimeError + + return experiment_id, finished_run_id + + def benchmark_backbone( defaults: Defaults, tasks: list[Task], @@ -236,25 +345,19 @@ def benchmark_backbone( if backbone_import: importlib.import_module(backbone_import) - + logger.info(f"Setting tracking URI: {storage_uri}") mlflow.set_tracking_uri(storage_uri) + logger.info(f"Setting experiment name: {experiment_name}") mlflow.set_experiment(experiment_name) - if bayesian_search: - sampler: BaseSampler | None = None # take the default - else: - sampler = RandomSampler() - optimization_space = parse_optimization_space(optimization_space) - table_columns = ["Task", "Metric", "Best Score", "Hyperparameters"] - table_entries = [] backbone: str = defaults.terratorch_task["model_args"]["backbone"] task_names = [task.name for task in tasks] run_name = f"top_run_{experiment_name}" if run_name is None else run_name completed_task_run_names = [] - run_hpo = True + optimize_hyperparams = True task_run_to_id_match = {} if continue_existing_experiment: # find status of existing runs, and delete incomplete runs except one with the most complete tasks @@ -275,12 +378,14 @@ def benchmark_backbone( ): logger.info("Continuing previous experiment parent run") run_id = existing_experiments["incomplete_run_to_finish"] + logger.debug(f"incomplete_run_to_finish: {run_id=}") experiment_id = existing_experiments["experiment_id"] - run_hpo = True + optimize_hyperparams = True if existing_experiments["finished_run"] is not None: - run_hpo = False + optimize_hyperparams = False finished_run_id = existing_experiments["finished_run"] + logger.debug(f"finished_run: {run_id=}") run_id = existing_experiments["finished_run"] # get previously completed tasks @@ -294,85 +399,45 @@ def benchmark_backbone( PATH_TO_JOB_TRACKING / f"{experiment_name}-{run_id}_table_entries.pkl" ) if os.path.exists(table_entries_filename): - with open(table_entries_filename, 'rb') as handle: + with open(table_entries_filename, "rb") as handle: table_entries = pickle.load(handle) else: logger.info("Starting new experiment from scratch") # only run hyperparameter optimization (HPO) if there are no experiments with finished HPO - if run_hpo: - logger.info("Running hyperparameter optimization") - with mlflow.start_run( - run_name=run_name, run_id=run_id, description=description - ) as run: - for task in tasks: - # only run task if it was not completed before - task_run_name = task.name - if task_run_name in completed_task_run_names: - logger.info(f"{task_run_name} already completed") - continue - else: - logger.info(f"{task_run_name} not completed. starting now") - - task_run_id = ( - task_run_to_id_match[task_run_name] - if task_run_name in task_run_to_id_match - else None - ) - best_value, metric_name, hparams = benchmark_backbone_on_task( - logger, - defaults, - task, - storage_uri, - experiment_name, - experiment_run_id=run.info.run_id, - task_run_id=task_run_id, - optimization_space=optimization_space, - n_trials=n_trials, - save_models=save_models, - sampler=sampler, - test_models=test_models, - ) - table_entries.append([task.name, metric_name, best_value, hparams]) - table_entries_filename = str( - PATH_TO_JOB_TRACKING - / f"{experiment_name}-{run.info.run_id}_table_entries.pkl" - ) - with open(table_entries_filename, 'wb') as handle: - pickle.dump(table_entries, handle, protocol=pickle.HIGHEST_PROTOCOL) - - table = tabulate(table_entries, headers=table_columns) - logger.info(table) - df = pd.DataFrame(data=table_entries, columns=table_columns) - df.set_index("Task") - logger.info("Starting to save results") - mlflow.log_table( - df, - "results_table.json", - run.info.run_id, - ) - experiment_id = run.info.experiment_id - - # check completion of HPO for all tasks before proceeding to next stage - existing_experiments = check_existing_experiments( - logger=logger, + if optimize_hyperparams: + if bayesian_search: + sampler: BaseSampler | None = None # take the default + else: + sampler = RandomSampler() + table_columns = ["Task", "Metric", "Best Score", "Hyperparameters"] + table_entries = [] + experiment_id, finished_run_id = _run_hpo( + run_name=run_name, + run_id=run_id, + description=description, + tasks=tasks, + task_names=task_names, + completed_task_run_names=completed_task_run_names, + task_run_to_id_match=task_run_to_id_match, + defaults=defaults, storage_uri=storage_uri, experiment_name=experiment_name, - exp_parent_run_name=run_name, - task_names=task_names, n_trials=n_trials, + save_models=save_models, + sampler=sampler, + test_models=test_models, + table_entries=table_entries, + table_columns=table_columns, backbone=backbone, + PATH_TO_JOB_TRACKING=PATH_TO_JOB_TRACKING, + optimization_space=optimization_space, + logger=logger, ) - if existing_experiments["finished_run"] is not None: - finished_run_id = existing_experiments["finished_run"] - else: - logger.info("HPO is not complete. Please re-run this experiment") - raise RuntimeError - logger.info("HPO complete") - - logger.info(f"run_repetitions: {run_repetitions}") + logger.info("HPO complete") if run_repetitions >= 1: + logger.info(f"run_repetitions: {run_repetitions}") # run repeated experiments logger.info( f"Now running {run_repetitions} repeats per experiment \n\ diff --git a/benchmark/benchmark_ray.py b/terratorch_iterate/benchmark_ray.py similarity index 97% rename from benchmark/benchmark_ray.py rename to terratorch_iterate/benchmark_ray.py index 81eed60..0c7ea5c 100644 --- a/benchmark/benchmark_ray.py +++ b/terratorch_iterate/benchmark_ray.py @@ -14,15 +14,15 @@ from ray.tune.search.optuna import OptunaSearch from tabulate import tabulate -from benchmark.backbone_benchmark import parse_optimization_space -from benchmark.benchmark_types import ( +from terratorch_iterate.backbone_benchmark import parse_optimization_space +from terratorch_iterate.iterate_types import ( Defaults, Task, TrainingSpec, combine_with_defaults, optimization_space_type, ) -from benchmark.model_fitting import fit_model, ray_tune_model, valid_task_types +from terratorch_iterate.model_fitting import fit_model, ray_tune_model, valid_task_types def benchmark_backbone_on_task( @@ -166,7 +166,6 @@ def benchmark_backbone( with mlflow.start_run( run_name=run_name, run_id=run_id, description=description ) as run: - if optimization_space is None: # no hparams, parallelize over tasks ray_tasks = [] diff --git a/benchmark/tests/__init__.py b/terratorch_iterate/config_util/__init__.py similarity index 100% rename from benchmark/tests/__init__.py rename to terratorch_iterate/config_util/__init__.py diff --git a/benchmark/config_util/build_geobench_configs.py b/terratorch_iterate/config_util/build_iterate_config.py similarity index 59% rename from benchmark/config_util/build_geobench_configs.py rename to terratorch_iterate/config_util/build_iterate_config.py index d704883..e45d8c7 100644 --- a/benchmark/config_util/build_geobench_configs.py +++ b/terratorch_iterate/config_util/build_iterate_config.py @@ -1,14 +1,15 @@ from pathlib import Path -from typing import Any import yaml import pandas as pd import click -from benchmark.benchmark_types import ( +from terratorch_iterate.iterate_types import ( TaskTypeEnum, ) from copy import deepcopy -PRITHVI_600M = 'prithvi_600M' +DEFAULT_TEMPLATE = ( + Path(__file__).parent.parent.parent / "configs/templates/template.yaml" +) def _build_dataframe(config_files) -> pd.DataFrame: @@ -21,7 +22,7 @@ def _build_dataframe(config_files) -> pd.DataFrame: for config_file in config_files: try: # extract dataset name from filename - ds = str(config_file).split('/')[-1].split('_')[0] + ds = str(config_file).split("/")[-1].split("_")[0] dataset.append(ds) # append file path files.append(str(config_file)) @@ -32,36 +33,13 @@ def _build_dataframe(config_files) -> pd.DataFrame: df = pd.DataFrame(data={"file": files, "dataset": dataset}) models = [ - x.split('/')[-1].replace(y + '_', '').replace('.yaml', '') - for x, y in zip(df['file'].values, df['dataset'].values) + x.split("/")[-1].replace(y + "_", "").replace(".yaml", "") + for x, y in zip(df["file"].values, df["dataset"].values) ] df["model"] = models return df -def _create_basemodule(data: dict[str, Any], model_filter: str) -> dict: - """create a dict based on the "data" field of the terratorch config - - Args: - data (dict[str, Any]): _description_ - model_filter (str): model name is used to specify batch_size and eval_batch_size - - Returns: - dict: returns a dict that represents the datamodule field of iterate config file - """ - base_module = dict() - base_module["class_path"] = data["class_path"] - if "dict_kwargs" in data.keys(): - dict_kwargs = data["dict_kwargs"] - batch_size = 8 if model_filter != PRITHVI_600M else 4 - dict_kwargs["batch_size"] = batch_size - dict_kwargs['eval_batch_size'] = 8 if model_filter != PRITHVI_600M else 4 - - base_module["dict_kwargs"] = dict_kwargs - base_module["init_args"] = data["init_args"] - return base_module - - def _create_task( name: str, datamodule: dict, @@ -71,7 +49,7 @@ def _create_task( direction: str, max_run_duration: str | None = None, early_stop_patience: int | None = None, - early_prune: bool = False, + early_prune: bool | None = None, ) -> dict: """instantiate Task dataclass and convert it to dict @@ -97,10 +75,15 @@ def _create_task( "direction": direction, "metric": metric, "terratorch_task": terratorch_task, - "max_run_duration": max_run_duration, - "early_stop_patience": early_stop_patience, - "early_prune": early_prune, } + # set optional fields if they are not None + for k, v in [ + ("max_run_duration", max_run_duration), + ("early_stop_patience", early_stop_patience), + ("early_prune", early_prune), + ]: + if v is not None: + task_dict[k] = v return task_dict @@ -132,7 +115,10 @@ def _get_task_direction(template: dict) -> str: def generate_iterate_config( - input_dir: Path, template: Path, output_dir: Path, prefix: str = "test_" + input: Path, + output: Path, + template: Path = DEFAULT_TEMPLATE, + prefix: str = "tt-iterate-", ): """generate the tt-iterate based on yaml files located within the specified directory, based on previously defined template and save the result using specified output filename @@ -141,53 +127,61 @@ def generate_iterate_config( input_dir (Path): contains all terratorch yaml files output_dir (Path): filename of the result template (Path): template file that contains pre-defined values + prefix (str): prefix for creating new config files """ - - config_files = input_dir.glob('**/*.yaml') + assert input.exists() + if input.is_dir(): + config_files = input.glob("**/*.yaml") + elif input.is_file(): + config_files = [input] + else: + ValueError(f"Error! {input=} is neither a file nor a directory") files_df = _build_dataframe(config_files=config_files) - files_df = files_df[files_df['dataset'].values != 'M4SAR'] - files_df = files_df[files_df['model'].values != 'resnet50_torchgeo'] - - files_df = files_df.sort_values(['model', 'dataset']) + # set default values if necessary + if template is None: + template = DEFAULT_TEMPLATE + if prefix is None: + prefix = "tt-iterate-" - models = files_df['model'].unique() + models = files_df["model"].unique() - with open(template, 'r') as file: - template = yaml.safe_load(file) + with open(template, "r") as file: + template_dict: dict = yaml.safe_load(file) # generate one config per model for model in models: - model_specific_template = deepcopy(template) + model_specific_template = deepcopy(template_dict) + # create unique name for experiment model_specific_template["experiment_name"] = f"{prefix}_{model}" tasks = list() - single_model_df = files_df[files_df['model'].values == model] + # filter dataframe by model + single_model_df = files_df[files_df["model"].values == model] for i in range(single_model_df.shape[0]): - - with open(single_model_df['file'].values[i], 'r') as file: + # open terratorch config file + with open(single_model_df["file"].values[i], "r") as file: data = yaml.safe_load(file) - name = single_model_df['dataset'].values[i] + name = single_model_df["dataset"].values[i] - model_args: dict = data['model']['init_args']['model_args'] + model_args: dict = data["model"]["init_args"]["model_args"] # framework is an optional field of terratorch config if ( model_args.get("framework") is not None and model_args.get("framework") == "faster-rcnn" ): - metric = 'val_map' + metric = "val_map" else: - metric = 'val_segm_map' + metric = "val/loss" - # terratorchtask is the data.model.init_args of terratorch config file - terratorch_task = data['model']['init_args'] + # terratorchtask is extracted from the data.model.init_args of terratorch config file + terratorch_task = data["model"]["init_args"] # create datamodule based on data field - data = data['data'] - datamodule = _create_basemodule(data=data, model_filter=model) - task_type = _get_task_type(template=template) - task_direction = _get_task_direction(template=template) + datamodule = data["data"] + task_type = _get_task_type(template=template_dict) + task_direction = _get_task_direction(template=template_dict) task = _create_task( name=name, datamodule=datamodule, @@ -198,35 +192,38 @@ def generate_iterate_config( ) tasks.append(task) - model_specific_template['tasks'] = tasks - path = output_dir / f"{prefix}_{model}.yaml" + model_specific_template["tasks"] = tasks + if output.is_dir(): + path = output / f"{prefix}_{model}.yaml" + else: + path = output if path.exists(): path.unlink() - with open(path, 'w') as file: + with open(path, "w") as file: yaml.dump(model_specific_template, file) print(f"{path} file has been created") @click.command() @click.option( - '--input_dir', - prompt='Full path to the directory that contains all terratorch config yaml files', - help='Full path to the directory that contains all terratorch config yaml files', + "--input_dir", + prompt="Full path to the directory that contains all terratorch config yaml files", + help="Full path to the directory that contains all terratorch config yaml files", ) @click.option( - '--output_dir', - prompt='Full path to the directory in which the new config files will be stored', - help='Full path to the directory in which the new config files will be stored', + "--output_dir", + prompt="Full path to the directory in which the new config files will be stored", + help="Full path to the directory in which the new config files will be stored", ) @click.option( - '--template', - prompt='Full path to the template file', - help='Full path to the template file', + "--template", + prompt="Full path to the template file", + help="Full path to the template file", ) @click.option( - '--prefix', - prompt='Prefix of the config filename, e.g., my-config-', - help='Prefix of the config filename', + "--prefix", + prompt="Prefix of the config filename, e.g., my-config-", + help="Prefix of the config filename", ) def generate_tt_iterate_config( input_dir: str, output_dir: str, template: str, prefix: str @@ -245,12 +242,12 @@ def generate_tt_iterate_config( assert isinstance(prefix, str), f"Error! {type(prefix)} is not a str" generate_iterate_config( - input_dir=directory_path, - output_dir=output_path, + input=directory_path, + output=output_path, template=template_path, prefix=prefix, ) -if __name__ == '__main__': +if __name__ == "__main__": generate_tt_iterate_config() diff --git a/benchmark/benchmark_types.py b/terratorch_iterate/iterate_types.py similarity index 98% rename from benchmark/benchmark_types.py rename to terratorch_iterate/iterate_types.py index f8f8d31..b9e6082 100644 --- a/benchmark/benchmark_types.py +++ b/terratorch_iterate/iterate_types.py @@ -86,7 +86,7 @@ def __post_init__(self): optimization_space_type = dict[ - str, Union[list, ParameterBounds, 'optimization_space_type'] + str, Union[list, ParameterBounds, "optimization_space_type"] ] @@ -146,7 +146,6 @@ class TrainingSpec: def recursive_merge(first_dict: dict[str, Any], second_dict: dict[str, Any]): - # consider using deepmerge instead of this for key, val in second_dict.items(): if key not in first_dict: diff --git a/terratorch_iterate/main.py b/terratorch_iterate/main.py new file mode 100644 index 0000000..8b23c4b --- /dev/null +++ b/terratorch_iterate/main.py @@ -0,0 +1,415 @@ +import os +from jsonargparse import Namespace +import logging +from pathlib import Path +from jsonargparse import ArgumentParser +import pandas as pd +from terratorch_iterate.backbone_benchmark import benchmark_backbone +from terratorch_iterate.iterate_types import Defaults, Task +from terratorch_iterate.repeat_best_experiment import rerun_best_from_backbone +from terratorch_iterate.utils import ( + get_logger, + import_custom_modules, + get_results_and_parameters, +) +from terratorch_iterate.config_util import build_iterate_config + + +def _summarize( + config_init: Namespace, + hpo: bool, + repeat: bool, + storage_uri: str, + logger: logging.RootLogger, +) -> pd.DataFrame: + """only summarize results from multiple experiments + + Args: + config_init (Namespace): _description_ + hpo (bool): flag that indicates whether to run hpo + repeat (bool): flag that indicates whether to repeat best experiment + storage_uri (str): path to directory in which results will be stored + logger (logging.RootLogger): logger variable + + Returns: + _type_: _description_ + """ + assert hpo is False and repeat is False, ( + f"Error! both {repeat=} and {hpo=} must be False when summarizing results from multiple experiments." + ) + + list_of_experiment_names = config_init.list_of_experiment_names + assert isinstance(list_of_experiment_names, list), ( + f"Error! {list_of_experiment_names=} is not a list" + ) + for exp in list_of_experiment_names: + assert isinstance(exp, str), f"Error! {exp=} is not a str" + + task_names = config_init.task_names + assert isinstance(task_names, list), f"Error! {task_names=} is not a list" + for t in task_names: + assert isinstance(t, str), f"Error! {t=} is not a str" + + task_metrics = config_init.task_metrics + assert isinstance(task_metrics, list), f"Error! {task_metrics=} is not a list" + for t in task_metrics: + assert isinstance(t, str), f"Error! {t=} is not a str" + + benchmark_name = config_init.benchmark_name + assert isinstance(benchmark_name, str), f"Error! {benchmark_name=} is not a str" + + run_repetitions = config_init.run_repetitions + assert isinstance(run_repetitions, int) and run_repetitions > 0, ( + f"Error! {run_repetitions=} is invalid" + ) + # get results and parameters from mlflow logs + results_and_parameters = get_results_and_parameters( + benchmark_name=benchmark_name, + storage_uri=storage_uri, + logger=logger, + experiments=list_of_experiment_names, + task_names=task_names, + num_repetitions=run_repetitions, + task_metrics=task_metrics, + ) + return results_and_parameters + + +def _repeat_experiment( + config_init: Namespace, + storage_uri: str, + experiment_name: str, + parent_run_id: str, + defaults: Defaults, + tasks: list[Task], + optimization_space: dict, + run_repetitions: int, + save_models: bool, + report_on_best_val: bool, + logger: logging.RootLogger, +): + """repeat best experiments + + Args: + config_init (Namespace): _description_ + storage_uri (str): _description_ + experiment_name (str): _description_ + parent_run_id (str): _description_ + defaults (Defaults): _description_ + tasks (list[Task]): _description_ + optimization_space (dict): _description_ + run_repetitions (int): _description_ + save_models (bool): _description_ + report_on_best_val (bool): _description_ + logger (logging.RootLogger): _description_ + + Returns: + _type_: _description_ + """ + output: str | None = config_init.output_path + if output is None: + storage_uri_path = Path(storage_uri) + assert storage_uri_path.exists() and storage_uri_path.is_dir(), ( + f"Error! Unable to create new output_path based on storage_uri_path because the latter does not exist: {storage_uri_path}" + ) + output_path = storage_uri_path.parents[0] / "repeated_exp_output_csv" + output_path.mkdir(parents=True, exist_ok=True) + output_path = output_path / f"{experiment_name}_repeated_exp_mlflow.csv" + output = str(output_path) + + logger.info("Rerun best experiments...") + rerun_best_from_backbone( + logger=logger, + parent_run_id=parent_run_id, + output_path=output_path, + defaults=defaults, + tasks=tasks, + experiment_name=experiment_name, + storage_uri=storage_uri, + optimization_space=optimization_space, + run_repetitions=run_repetitions, + save_models=save_models, + report_on_best_val=report_on_best_val, + ) + + +def _convert_config(args: Namespace): + """ + This function processes command-line arguments to convert configuration files. + + Parameters: + args (argparse.Namespace): Namespace object containing command-line arguments. + + Raises: + AssertionError: If input or output paths are invalid or missing. + + This function performs the following steps: + 1. Asserts that the 'input' argument is a non-empty string and checks if the file exists. + 2. Asserts that the 'output' argument is a non-empty string. + 3. Calls the `generate_iterate_config` function from the `build_iterate_config` module, passing the input path, output path, prefix (if provided), and template (if provided). + """ + input: str = args.input + assert input is not None and isinstance(input, str), ( + f"Error! Invalid value: {input=}" + ) + input_path = Path(input) + assert input_path.exists() + + output: str = args.output + assert output is not None and isinstance(output, str), ( + f"Error! Invalid value: {output=}" + ) + output_path = Path(output) + template: str | None = args.template + + prefix: str | None = args.prefix + + template: str | None = args.template + build_iterate_config.generate_iterate_config( + input=input_path, output=output_path, prefix=prefix, template=template + ) + + +def main(): + parser = ArgumentParser() + + parser.add_argument("--defaults", type=Defaults) # to ignore model + parser.add_argument("--optimization_space", type=dict) # to ignore model + parser.add_argument("--experiment_name", type=str) # to ignore model + parser.add_argument("--run_name", type=str) # to ignore model + parser.add_argument("--save_models", type=bool) # to ignore model + parser.add_argument("--storage_uri", type=str) # to ignore model + parser.add_argument("--ray_storage_path", type=str) # to ignore model + parser.add_argument("--n_trials", type=int) # to ignore model + parser.add_argument("--run_repetitions", type=int) # to ignore model + parser.add_argument("--tasks", type=list[Task]) + parser.add_argument("--parent_run_id", type=str) + parser.add_argument("--output_path", type=str) + parser.add_argument("--logger", type=str) + parser.add_argument("--config", type=str) + parser.add_argument("--custom_modules_path", type=str) + parser.add_argument("--report_on_best_val", type=bool, default=True) + parser.add_argument("--test_models", type=bool, default=False) + parser.add_argument("--bayesian_search", type=bool, default=True) + parser.add_argument("--hpo", help="optimize hyperparameters", action="store_true") + parser.add_argument("--repeat", help="repeat best experiments", action="store_true") + parser.add_argument( + "--continue_existing_experiments", + help="continue existing experiments", + action="store_true", + ) + parser.add_argument( + "--summarize", + help="summarize results from repeated experiments", + action="store_true", + ) + parser.add_argument("--list_of_experiment_names", type=list[str]) + parser.add_argument("--task_names", type=list[str]) + parser.add_argument("--task_metrics", type=list[str]) + parser.add_argument( + "--benchmark_name", + type=str, + help="name of summarized results file", + ) + # arguments to convert terratorch's config into iterate's config + parser.add_argument( + "--build_iterate_config", + help="convert terratorch's config into terratorch-iterate's config", + action="store_true", + ) + parser.add_argument( + "--input", + help="input file or directory", + type=str, + ) + parser.add_argument( + "--output", + help="output file or directory", + type=str, + ) + parser.add_argument( + "--template", + help="template for creating config files", + type=str, + ) + parser.add_argument( + "--prefix", + help="prefix of new config files", + type=str, + ) + + args = parser.parse_args() + if args.build_iterate_config is not None and args.build_iterate_config is True: + _convert_config(args) + else: + config_path: str | None = args.config + if config_path is None: + msg = """ + Error: config argument has not been passed + usage: terratorch iterate [-h] [--hpo] [--repeat] [--summarize] [--config CONFIG] + """ + print(msg) + else: + assert isinstance(config_path, str), ( + f"Error! Unexpected config type: {config_path}" + ) + config = parser.parse_path(config_path) + + config_init: Namespace = parser.instantiate_classes(config) + + summarize: bool = args.summarize + assert isinstance(summarize, bool), f"Error! {summarize=} is not a bool" + repeat = args.repeat + assert isinstance(repeat, bool), f"Error! {repeat=} is not a bool" + hpo = args.hpo + assert isinstance(hpo, bool), f"Error! {hpo=} is not a bool" + + continue_existing_experiments: bool = args.continue_existing_experiments + assert isinstance(continue_existing_experiments, bool), ( + f"Error! {continue_existing_experiments=} is not a bool" + ) + + storage_uri = config_init.storage_uri + assert isinstance(storage_uri, str), f"Error! {storage_uri=} is not a str" + os.environ["MLFLOW_TRACKING_URI"] = storage_uri + # handling relative paths + if storage_uri.startswith(".") or storage_uri.startswith(".."): + repo_home_dir = Path(__file__).parent.parent + abs_path = repo_home_dir / storage_uri + storage_uri = str(abs_path.resolve()) + + logger_path = config_init.logger + if logger_path is None: + storage_uri_path = Path(storage_uri) + logger = get_logger( + log_folder=f"{str(storage_uri_path.parents[0])}/job_logs" + ) + else: + logging.config.fileConfig( + fname=logger_path, disable_existing_loggers=False + ) + logger = logging.getLogger("terratorch-iterate") + + # only summarize results from multiple experiments + if summarize: + return _summarize( + config_init=config_init, + ) + + # optimize hyperparameters and/or do repeated runs for single experiments + assert hpo is True or repeat is True, ( + f"Error! either {repeat=} or {hpo=} must be True" + ) + parent_run_id = args.parent_run_id + if parent_run_id is not None: + assert isinstance(parent_run_id, str), ( + f"Error! {parent_run_id=} is not a str" + ) + + # validate the objects + experiment_name = config_init.experiment_name + assert isinstance(experiment_name, str), ( + f"Error! {experiment_name=} is not a str" + ) + run_name = config_init.run_name + if run_name is not None: + assert isinstance(run_name, str), f"Error! {run_name=} is not a str" + # validate defaults + defaults = config_init.defaults + assert isinstance(defaults, Defaults), ( + f"Error! {defaults=} is not a Defaults" + ) + + tasks = config_init.tasks + assert isinstance(tasks, list), f"Error! {tasks=} is not a list" + for t in tasks: + assert isinstance(t, Task), f"Error! {t=} is not a Task" + # if there is not specific terratorch_task specified, then use default terratorch_task + if t.terratorch_task is None: + t.terratorch_task = defaults.terratorch_task + # defaults.trainer_args["max_epochs"] = 5 + + optimization_space = config_init.optimization_space + assert isinstance(optimization_space, dict), ( + f"Error! {optimization_space=} is not a dict" + ) + + # ray_storage_path is optional + ray_storage_path = config_init.ray_storage_path + if ray_storage_path is not None: + assert isinstance(ray_storage_path, str), ( + f"Error! {ray_storage_path=} is not a str" + ) + + n_trials = config_init.n_trials + assert isinstance(n_trials, int) and n_trials > 0, ( + f"Error! {n_trials=} is invalid" + ) + run_repetitions = config_init.run_repetitions + + report_on_best_val = config_init.report_on_best_val + assert isinstance(report_on_best_val, bool), ( + f"Error! {ray_storage_path=} is not a bool" + ) + + save_models = config_init.save_models + assert isinstance(save_models, bool), f"Error! {save_models=} is not a bool" + + test_models = config_init.test_models + assert isinstance(test_models, bool), f"Error! {test_models=} is not a bool" + + bayesian_search = config_init.bayesian_search + assert isinstance(bayesian_search, bool), ( + f"Error! {bayesian_search=} is not a bool" + ) + + # custom_modules_path is optional + custom_modules_path = config_init.custom_modules_path + if custom_modules_path is not None: + assert isinstance(custom_modules_path, str), ( + f"Error! {custom_modules_path=} is not a str" + ) + import_custom_modules( + logger=logger, custom_modules_path=custom_modules_path + ) + + if repeat and not hpo: + _repeat_experiment( + config_init=config_init, + storage_uri=storage_uri, + experiment_name=experiment_name, + defaults=defaults, + tasks=tasks, + optimization_space=optimization_space, + run_repetitions=run_repetitions, + save_models=save_models, + logger=logger, + ) + else: + if not repeat and hpo: + run_repetitions = 0 + + # run_repetitions is an optional parameter + experiment_info: dict = benchmark_backbone( + defaults=defaults, + tasks=tasks, + experiment_name=experiment_name, + storage_uri=storage_uri, + ray_storage_path=ray_storage_path, + run_name=run_name, + run_id=None, + optimization_space=optimization_space, + n_trials=n_trials, + run_repetitions=run_repetitions, + save_models=save_models, + report_on_best_val=report_on_best_val, + test_models=test_models, + bayesian_search=bayesian_search, + continue_existing_experiment=continue_existing_experiments, + logger=logger, + ) + return experiment_info + + +if __name__ == "__main__": + main() diff --git a/benchmark/model_fitting.py b/terratorch_iterate/model_fitting.py similarity index 96% rename from benchmark/model_fitting.py rename to terratorch_iterate/model_fitting.py index f2b7544..8ef1819 100644 --- a/benchmark/model_fitting.py +++ b/terratorch_iterate/model_fitting.py @@ -44,7 +44,7 @@ from torchgeo.datamodules import BaseDataModule from torchgeo.trainers import BaseTask -from benchmark.benchmark_types import ( +from terratorch_iterate.iterate_types import ( ParameterBounds, ParameterTypeEnum, TrainingSpec, @@ -54,7 +54,7 @@ ) -from benchmark.utils import get_logger +from terratorch_iterate.utils import get_logger LOGGER = get_logger() @@ -119,9 +119,9 @@ def __init__(self, *args, **kwargs): def inject_hparams(training_spec: TrainingSpec, config: dict): # treat batch size specially config_without_batch_size = copy.deepcopy(config) - assert isinstance( - config_without_batch_size, dict - ), f"Error! Unexpected config type: {config_without_batch_size}" + assert isinstance(config_without_batch_size, dict), ( + f"Error! Unexpected config type: {config_without_batch_size}" + ) batch_size: int | None = config_without_batch_size.pop("batch_size", None) # type: ignore datamodule_with_generated_hparams = copy.deepcopy(training_spec.task.datamodule) if batch_size: @@ -310,9 +310,9 @@ def launch_training( ["metric_name", "step"], verify_integrity=True ) series_val_metrics = df_val_metrics["value"] - assert ( - metric in series_val_metrics - ), f"Error! {metric} is not in {series_val_metrics}" + assert metric in series_val_metrics, ( + f"Error! {metric} is not in {series_val_metrics}" + ) if direction == "max": best_step = series_val_metrics[metric].idxmax() elif direction == "min": @@ -351,9 +351,9 @@ def fit_model( PixelwiseRegressionTask, ]: task.terratorch_task["plot_on_val"] = False - assert isinstance( - task.terratorch_task, dict - ), f"Error! Invalid type: {task.terratorch_task}" + assert isinstance(task.terratorch_task, dict), ( + f"Error! Invalid type: {task.terratorch_task}" + ) lightning_task = lightning_task_class(**task.terratorch_task) @@ -445,18 +445,16 @@ def fit_model_with_hparams( ) run_name = f"{run_name}_{trial.number}" return fit_model( - training_spec_with_generated_hparams, - lightning_task_class, - run_name, - experiment_name, - storage_uri, - parent_run_id, - trial, + training_spec=training_spec_with_generated_hparams, + lightning_task_class=lightning_task_class, + run_name=run_name, + experiment_name=experiment_name, + storage_uri=storage_uri, + parent_run_id=parent_run_id, + trial=trial, save_models=save_models, test_models=test_models, - )[ - 0 - ] # return only the metric value for optuna + )[0] # return only the metric value for optuna """ @@ -476,7 +474,6 @@ def ray_tune_model( backbone_import: str | None = None, searcher: Searcher | SearchAlgorithm | None = None, ) -> tune.ResultGrid: - if not searcher: raise ValueError("searcher must be specified") trainable = tune.with_parameters( diff --git a/benchmark/module.py b/terratorch_iterate/module.py similarity index 100% rename from benchmark/module.py rename to terratorch_iterate/module.py diff --git a/benchmark/plot_tools.py b/terratorch_iterate/plot_tools.py similarity index 100% rename from benchmark/plot_tools.py rename to terratorch_iterate/plot_tools.py diff --git a/benchmark/repeat_best_experiment.py b/terratorch_iterate/repeat_best_experiment.py similarity index 96% rename from benchmark/repeat_best_experiment.py rename to terratorch_iterate/repeat_best_experiment.py index 273132f..3a168dd 100644 --- a/benchmark/repeat_best_experiment.py +++ b/terratorch_iterate/repeat_best_experiment.py @@ -24,13 +24,13 @@ from lightning.pytorch.loggers.mlflow import MLFlowLogger import time -from benchmark.benchmark_types import ( +from terratorch_iterate.iterate_types import ( Defaults, Task, TrainingSpec, combine_with_defaults, ) -from benchmark.model_fitting import ( +from terratorch_iterate.model_fitting import ( get_default_callbacks, inject_hparams, valid_task_types, @@ -53,7 +53,6 @@ def remote_fit( run_name=f"{lightning_task_class.name}_{seed}", nested=True, ): - training_spec_copy = copy.deepcopy(training_spec) training_spec_with_generated_hparams = inject_hparams( training_spec_copy, best_params @@ -78,9 +77,7 @@ def remote_fit( # get callbacks (set to empty list if none defined) and extend with default ones training_spec_with_generated_hparams.trainer_args.setdefault( "callbacks", [] - ).extend( - default_callbacks - ) # type: ignore + ).extend(default_callbacks) # type: ignore if "enable_checkpointing" in training_spec_with_generated_hparams.trainer_args: warnings.warn( "enable_checkpointing found. Will be overwritten to False as ray will be responsible for saving models." @@ -105,8 +102,8 @@ def remote_fit( test_metric = ( "test/" + task.metric.split("/")[1] - if '/' in task.metric - else 'test_' + task.metric.replace(task.metric.split('_')[0] + "_", '') + if "/" in task.metric + else "test_" + task.metric.replace(task.metric.split("_")[0] + "_", "") ) mlflow.log_metric(f"test_{test_metric}", metrics[test_metric]) return metrics[test_metric] @@ -179,9 +176,7 @@ def non_remote_fit( # get callbacks (set to empty list if none defined) and extend with default ones training_spec_with_generated_hparams.trainer_args.setdefault( "callbacks", [] - ).extend( - default_callbacks - ) # type: ignore + ).extend(default_callbacks) # type: ignore trainer = Trainer(**training_spec_with_generated_hparams.trainer_args) trainer.logger = MLFlowLogger( @@ -217,8 +212,8 @@ def non_remote_fit( # return None test_metric = ( "test/" + task.metric.split("/")[1] - if '/' in task.metric - else 'test_' + task.metric.replace(task.metric.split('_')[0] + "_", '') + if "/" in task.metric + else "test_" + task.metric.replace(task.metric.split("_")[0] + "_", "") ) mlflow.log_metric(f"test_{test_metric}", metrics[test_metric]) return metrics[test_metric] @@ -303,7 +298,9 @@ def rerun_best_from_backbone( with mlflow.start_run(run_name=experiment_name, run_id=None) as run: for task in tasks: logger.info(f"\n\ntask: {task.name}") - matching_runs = [run for run in runs if run.info.run_name.endswith(task.name)] # type: ignore + matching_runs = [ + run for run in runs if run.info.run_name.endswith(task.name) + ] # type: ignore if len(matching_runs) == 0: msg = f"No runs found for task {task.name}. Skipping." warnings.warn(msg) @@ -442,7 +439,7 @@ def rerun_best_from_backbone( ) existing_output.reset_index(inplace=True) existing_output = existing_output.drop( - columns=['index', 'level_0'] + columns=["index", "level_0"] ) existing_output.to_csv(output_path, index=False) else: diff --git a/benchmark/utils.py b/terratorch_iterate/utils.py similarity index 95% rename from benchmark/utils.py rename to terratorch_iterate/utils.py index 8c8a7f8..1cf0e38 100644 --- a/benchmark/utils.py +++ b/terratorch_iterate/utils.py @@ -10,12 +10,11 @@ from matplotlib import pyplot as plt from ast import literal_eval import optuna -from benchmark.benchmark_types import Task -from benchmark import plot_tools +from terratorch_iterate.iterate_types import Task +from terratorch_iterate import plot_tools import sys from mlflow.entities.experiment import Experiment import importlib -import logging N_TRIALS_DEFAULT = 16 REPEATED_SEEDS_DEFAULT = 10 @@ -63,17 +62,20 @@ def sync_mlflow_optuna( Returns: task_run_id: run id of the task to be continued (if one exists) or None """ + logger.info( + f"sync_mlflow_optuna - {optuna_db_path=} {storage_uri=} {task_run_id=} {experiment_name=} {task_run_id=}" + ) # check number of successful mlflow runs in task client = mlflow.tracking.MlflowClient(tracking_uri=storage_uri) completed_in_mlflow_for_task = [] all_mlflow_runs_for_task = [] if task_run_id is not None: all_mlflow_runs_for_task.append(task_run_id) - logger.info(f"task_run_id : {task_run_id}") + logger.info(f"sync_mlflow_optuna - {task_run_id=}") experiment_info = client.get_experiment_by_name(experiment_name) - assert isinstance( - experiment_info, Experiment - ), f"Error! Unexpected type of {experiment_info=}" + assert isinstance(experiment_info, Experiment), ( + f"Error! Unexpected type of {experiment_info=}" + ) individual_run_data = client.search_runs( experiment_ids=[experiment_info.experiment_id], filter_string=f'tags."mlflow.parentRunId" LIKE "{task_run_id}"', @@ -124,9 +126,9 @@ def sync_mlflow_optuna( for item in all_mlflow_runs_for_task: logger.info(f"deleting {item}") client.delete_run(item) - assert isinstance( - experiment_info, Experiment - ), f"Error! Unexpected type of {experiment_info=}" + assert isinstance(experiment_info, Experiment), ( + f"Error! Unexpected type of {experiment_info=}" + ) os.system(f"rm -r {experiment_info.artifact_location}/{item}") task_run_id = None else: @@ -135,11 +137,12 @@ def sync_mlflow_optuna( for item in all_mlflow_runs_for_task: logger.info(f"deleting {item}") client.delete_run(item) - assert isinstance( - experiment_info, Experiment - ), f"Error! Unexpected type of {experiment_info=}" + assert isinstance(experiment_info, Experiment), ( + f"Error! Unexpected type of {experiment_info=}" + ) os.system(f"rm -r {experiment_info.artifact_location}/{item}") task_run_id = None + logging.info(f"sync_mlflow_optuna returns {task_run_id=}") return task_run_id @@ -211,7 +214,7 @@ def extract_repeated_experiment_results( seed = int(run.info.run_name.split("_")[-1]) if task in task_info: metric_name = task_info[task] - metric_name = 'test_test/' + metric_name.split("/")[-1] + metric_name = "test_test/" + metric_name.split("/")[-1] else: continue @@ -350,19 +353,19 @@ def extract_parameters( best_params["data_percentages"] = DATA_PARTITIONS[ best_params["partition_name"] ] - if 'optimizer_hparams' in best_params: + if "optimizer_hparams" in best_params: logger.info( f"optimizer_hparams: {best_params['optimizer_hparams'].items()}" ) optimizer_hparams = { - k: v for k, v in best_params['optimizer_hparams'].items() + k: v for k, v in best_params["optimizer_hparams"].items() } best_params.update(optimizer_hparams) - del best_params['optimizer_hparams'] - if 'model_args' in best_params: - model_args = {k: v for k, v in best_params['model_args'].items()} + del best_params["optimizer_hparams"] + if "model_args" in best_params: + model_args = {k: v for k, v in best_params["model_args"].items()} best_params.update(model_args) - del best_params['model_args'] + del best_params["model_args"] best_params = pd.DataFrame(best_params, index=[0]) all_params.append(best_params) @@ -421,11 +424,11 @@ def get_results_and_parameters( task_metrics=task_metrics, ) - with open(f"{results_dir}/incomplete_experiments.txt", 'w') as f: + with open(f"{results_dir}/incomplete_experiments.txt", "w") as f: for line in incomplete_experiments: f.write(f"{line}\n") results_and_parameters = results.merge( - parameters, on=['experiment_name', 'dataset'] + parameters, on=["experiment_name", "dataset"] ) results_and_parameters.to_csv( f"{str(results_dir)}/results_and_parameters.csv", index=False @@ -790,7 +793,7 @@ def get_logger(log_level="INFO", log_folder="./experiment_logs") -> logging.Root handler = logging.FileHandler(log_file) handler.setLevel(log_level) formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) handler.setFormatter(formatter) logger.addHandler(handler) @@ -802,13 +805,10 @@ def import_custom_modules( logger: logging.RootLogger, custom_modules_path: str | Path | None = None, ) -> None: - if custom_modules_path: - custom_modules_path = Path(custom_modules_path) if custom_modules_path.is_dir(): - # Add 'custom_modules' folder to sys.path workdir = custom_modules_path.parents[0] module_dir = custom_modules_path.name diff --git a/benchmark/py.typed b/tests/integration/__init__.py similarity index 100% rename from benchmark/py.typed rename to tests/integration/__init__.py diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py new file mode 100644 index 0000000..acc342d --- /dev/null +++ b/tests/integration/test_main.py @@ -0,0 +1,92 @@ +import itertools +from pathlib import Path + +import yaml +from terratorch_iterate.main import main +import pytest +import sys + +CONFIG_FILES = [ + # "configs/tests/benchmark_v2_simple.yaml", + "configs/tests/dofa_large_patch16_224_upernetdecoder_true_modified.yaml", + "configs/tests/terratorch-iterate-configs/test_case_02/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml", + "configs/tests/terratorch-iterate-configs/test_case_03/test_config_util__encoder_decoder_timm_resnet101_model_factory.yaml", +] +HPO = [True] +INPUT_TEST_MAIN = list(itertools.product(HPO, CONFIG_FILES)) + + +def get_test_ids() -> list[str]: + test_case_ids = list() + for hpo, config in INPUT_TEST_MAIN: + # get the filename + filename = config.split("/")[-1].replace(".yaml", "") + # set test id + tid = f"{filename}_hpo_{hpo}" + # append to list of test ids + test_case_ids.append(tid) + return test_case_ids + + +def validate_results(experiment_name: str, storage_uri: str, finished_run_id: str): + # get the most recent modified directory + dir_path = Path(storage_uri) / finished_run_id + assert dir_path.exists(), f"Error! Directory does not exist: {dir_path}" + # find mlflow.runName files within the result dir + meta_yaml = "meta.yaml" + + meta_yaml_path = dir_path / meta_yaml + assert meta_yaml_path.exists(), ( + f"Error! meta.yaml file {meta_yaml_path} does not exist" + ) + # open file and check that the experiment name is the same + with open(meta_yaml_path, mode="r") as f: + # read all the lines + lines = f.readlines() + # try to find experiment id and name in these lines + experiment_name_found: bool = False + experiment_id_found: bool = False + for line in lines: + if experiment_name in line: + experiment_name_found = True + if finished_run_id in line: + experiment_id_found = True + assert experiment_name_found and experiment_id_found, ( + f"Error! Both experiment name ({experiment_name=}) and finished run id ({finished_run_id=}) must be in the {meta_yaml_path=}: {experiment_id_found=} {experiment_name_found=}" + ) + # TODO delete the directories that were created by this test case + + +@pytest.mark.parametrize( + "hpo, config", + INPUT_TEST_MAIN, + ids=get_test_ids(), +) +def test_main( + hpo: bool, + config: str, +): + home_dir = Path(__file__).parent.parent.parent + config_file: Path = home_dir / config + assert config_file.exists() + with open(config_file, "r") as file: + config_data = yaml.safe_load(file) + storage_uri: str = config_data["storage_uri"] + # handling relative paths + if storage_uri.startswith(".") or storage_uri.startswith(".."): + repo_home_dir = Path(__file__).parent.parent.parent + abs_path = repo_home_dir / storage_uri + storage_uri = str(abs_path.resolve()) + experiment_name = config_data["experiment_name"] + arguments = ["terratorch", "--config", str(config_file.resolve())] + if hpo: + arguments.insert(1, "--hpo") + sys.argv = arguments + # main only returns a dict when hpo is True + mlflow_info = main() + assert isinstance(mlflow_info, dict), f"Error! {mlflow_info=} is not a dict" + validate_results( + experiment_name=experiment_name, + storage_uri=storage_uri, + finished_run_id=mlflow_info["experiment_id"], + ) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py deleted file mode 100644 index 54cea49..0000000 --- a/tests/test_benchmark.py +++ /dev/null @@ -1,344 +0,0 @@ -import itertools -from benchmark.benchmark_types import Defaults, Task, TaskTypeEnum -import pytest -from benchmark.backbone_benchmark import benchmark_backbone -from terratorch.datamodules import MChesapeakeLandcoverNonGeoDataModule -from albumentations import HorizontalFlip, VerticalFlip, Resize -from albumentations.pytorch.transforms import ToTensorV2 -import os -from pathlib import Path -import uuid -from jsonargparse import ArgumentParser - - -BACKBONE_PRETRAINED_FILE = os.getenv( - "BACKBONE_PRETRAINED_FILE", - "/dccstor/geofm-finetuning/pretrain_ckpts/v9_no_sea/vit_b/epoch-395-loss-0.0339_clean.pt", -) - -SEGMENTATION_V1 = os.getenv( - "SEGMENTATION_V1", "/dccstor/geofm-finetuning/datasets/geobench/segmentation_v1.0" -) - -# OUTPUT_DIR = os.getenv( -# "OUTPUT_DIR", "/dccstor/geofm-finetuning/terratorch-iterate-test/" -# ) - -RAY_STORAGE = os.getenv( - "RAY_STORAGE", "/dccstor/geofm-finetuning/terratorch-iterate-test/ray_storage" -) - - -@pytest.fixture(scope="module") -def defaults() -> Defaults: - file = BACKBONE_PRETRAINED_FILE - assert Path(file).exists(), f"Error! {file=} does not exist" - trainer_args = { - "precision": "bf16-mixed", - "max_epochs": 10, - } - terratorch_task = { - "model_args": { - "pretrained": True, - "backbone": "prithvi_vit_100", - "backbone_out_indices": [2, 5, 8, 11], - "backbone_pretrained_cfg_overlay": {"file": file}, - }, - "model_factory": "PrithviModelFactory", - "optimizer": "AdamW", - } - return Defaults(trainer_args=trainer_args, terratorch_task=terratorch_task) - - -@pytest.fixture(scope="module") -def mchesapeakelandcovernongeodatamodule() -> MChesapeakeLandcoverNonGeoDataModule: - data_root = SEGMENTATION_V1 - assert Path(data_root).exists(), f"Error! Directory {data_root} does not exist" - train_transform = [Resize(height=224, width=224), ToTensorV2()] - test_transform = [ - HorizontalFlip(p=0.5), - VerticalFlip(p=0.5), - Resize(height=224, width=224), - ToTensorV2(), - ] - return MChesapeakeLandcoverNonGeoDataModule( - num_workers=6, - batch_size=16, - partition="0.10x_train", - train_transform=train_transform, - test_transform=test_transform, - data_root=data_root, - bands=["RED", "GREEN", "BLUE", "NIR"], - ) - - -@pytest.fixture(scope="module") -def tasks(mchesapeakelandcovernongeodatamodule): - - t = Task( - name="chesapeake", - type=TaskTypeEnum.segmentation, - direction="max", - metric="val/Multiclass_Jaccard_Index", - early_stop_patience=10, - terratorch_task={ - "loss": "ce", - "model_args": { - "decoder": "UperNetDecoder", - "decoder_channels": 128, - "decoder_scale_modules": True, - "bands": ["RED", "GREEN", "BLUE", "NIR"], - "num_classes": 7, - }, - }, - datamodule=mchesapeakelandcovernongeodatamodule, - ) - return [t] - - -def find_file(directory: str, filename: str): - for root, _, files in os.walk(directory): - if filename in files: - return os.path.join(root, filename) - return None - - -CONFIG_FILES = [ - # "configs/tests/geobench_v1_resnet_cashew.yaml", - # "configs/tests/geobench_v1_prithvi_cashew.yaml", - "configs/tests/benchmark_v2_simple.yaml", - "configs/tests/dofa_large_patch16_224_upernetdecoder_true_modified.yaml", - "configs/tests/test_config_util__prithvi_eo_v1_100.yaml", - # "configs/tests/geobench_v1_ssl4eos12_resnet50_sentinel2_all_moco_smp_unet_true.yaml", - # "configs/nasabench_vit_b_os.yaml", -] -CONTINUE_EXISTING_EXPERIMENT = [True, False] -TEST_MODELS = [True, False] -INPUT_TEST_RUN_BENCHMARK = list( - itertools.product(CONFIG_FILES, CONTINUE_EXISTING_EXPERIMENT, TEST_MODELS) -) -TEST_CASE_IDS = [str(i) for i in range(0, len(INPUT_TEST_RUN_BENCHMARK))] - - -@pytest.mark.parametrize( - "config, continue_existing_experiment, test_models", - INPUT_TEST_RUN_BENCHMARK, - ids=TEST_CASE_IDS, -) -def test_run_benchmark( - config: str, continue_existing_experiment: bool, test_models: bool -): - path = os.path.join(os.getcwd(), config) - config_path = Path(path) - # instantiate objects from yaml - parser = ArgumentParser() - parser.add_argument('--defaults', type=Defaults) # to ignore model - parser.add_argument('--optimization_space', type=dict) # to ignore model - parser.add_argument('--experiment_name', type=str) # to ignore model - parser.add_argument('--run_name', type=str) # to ignore model - parser.add_argument('--save_models', type=bool) # to ignore model - parser.add_argument('--storage_uri', type=str) # to ignore model - parser.add_argument('--ray_storage_path', type=str) # to ignore model - parser.add_argument('--n_trials', type=int) # to ignore model - parser.add_argument('--run_repetitions', type=int) # to ignore model - parser.add_argument('--tasks', type=list[Task]) - config = parser.parse_path(str(config_path)) - config_init = parser.instantiate_classes(config) - # validate the objects - experiment_name = config_init.experiment_name - experiment_name = f"{experiment_name}_continue_{continue_existing_experiment}_test_models_{test_models}" - assert isinstance(experiment_name, str), f"Error! {experiment_name=} is not a str" - run_name = config_init.run_name - if run_name is not None: - assert isinstance(run_name, str), f"Error! {run_name=} is not a str" - tasks = config_init.tasks - assert isinstance(tasks, list), f"Error! {tasks=} is not a list" - for t in tasks: - assert isinstance(t, Task), f"Error! {t=} is not a Task" - defaults = config_init.defaults - assert isinstance(defaults, Defaults), f"Error! {defaults=} is not a Defaults" - # defaults.trainer_args["max_epochs"] = 5 - storage_uri = config_init.storage_uri - assert isinstance(storage_uri, str), f"Error! {storage_uri=} is not a str" - storage_uri_path = Path(storage_uri) / uuid.uuid4().hex / "hpo" - if not storage_uri_path.exists(): - try: - storage_uri_path.mkdir(parents=True, exist_ok=True) - print(f"Directory created at: {path}") - except FileNotFoundError as e: - print(f"Error creating directory: {e}") - - optimization_space = config_init.optimization_space - assert isinstance( - optimization_space, dict - ), f"Error! {optimization_space=} is not a dict" - ray_storage = RAY_STORAGE - assert isinstance(ray_storage, str), f"Error! {ray_storage=} is not a str" - ray_storage_path = Path(ray_storage) / uuid.uuid4().hex - if not ray_storage_path.exists(): - try: - ray_storage_path.mkdir(parents=True, exist_ok=True) - print(f"Directory created at: {path}") - except FileNotFoundError as e: - print(f"Error creating directory: {e}") - n_trials = config_init.n_trials - assert isinstance(n_trials, int) and n_trials > 0, f"Error! {n_trials=} is invalid" - # run_repetions is an optional parameter - run_repetitions = config_init.run_repetitions - if run_repetitions is not None: - assert ( - isinstance(run_repetitions, int) and run_repetitions >= 0 - ), f"Error! {run_repetitions=} is invalid" - else: - run_repetitions = 0 - mlflow_info = benchmark_backbone( - experiment_name=experiment_name, - run_name=run_name, - run_id=None, - defaults=defaults, - tasks=tasks, - n_trials=n_trials, - save_models=False, - storage_uri=str(storage_uri_path), - ray_storage_path=str(ray_storage_path), - optimization_space=optimization_space, - continue_existing_experiment=continue_existing_experiment, - test_models=test_models, - run_repetitions=run_repetitions, - logger=None, - ) - assert isinstance(mlflow_info, dict), f"Error! {mlflow_info=} is not a dict" - validate_results( - experiment_name=experiment_name, - storage_uri=str(storage_uri_path), - finished_run_id=mlflow_info["experiment_id"], - ) - - -@pytest.mark.parametrize( - "config, continue_existing_experiment, test_models", - [ - ("configs/tests/benchmark_marida_l2a_terramind_base.yaml", False, False), - ], -) -def test_run_benchmark_no_specific_terratorch_task( - config: str, continue_existing_experiment: bool, test_models: bool -): - - path = os.path.join(os.getcwd(), config) - config_path = Path(path) - assert ( - config_path.exists() - ), f"Error! config does not exist: {config_path.resolve()}" - # instantiate objects from yaml - parser = ArgumentParser() - parser.add_argument('--defaults', type=Defaults) # to ignore model - parser.add_argument('--optimization_space', type=dict) # to ignore model - parser.add_argument('--experiment_name', type=str) # to ignore model - parser.add_argument('--run_name', type=str) # to ignore model - parser.add_argument('--save_models', type=bool) # to ignore model - parser.add_argument('--storage_uri', type=str) # to ignore model - parser.add_argument('--ray_storage_path', type=str) # to ignore model - parser.add_argument('--n_trials', type=int) # to ignore model - parser.add_argument('--run_repetitions', type=int) # to ignore model - parser.add_argument('--tasks', type=list[Task]) - config = parser.parse_path(str(config_path)) - config_init = parser.instantiate_classes(config) - # validate the objects - experiment_name = config_init.experiment_name - experiment_name = f"{experiment_name}_continue_{continue_existing_experiment}_test_models_{test_models}" - assert isinstance(experiment_name, str), f"Error! {experiment_name=} is not a str" - run_name = config_init.run_name - if run_name is not None: - assert isinstance(run_name, str), f"Error! {run_name=} is not a str" - tasks = config_init.tasks - assert isinstance(tasks, list), f"Error! {tasks=} is not a list" - for t in tasks: - assert isinstance(t, Task), f"Error! {t=} is not a Task" - if t.terratorch_task is not None: - t.terratorch_task = None - - defaults = config_init.defaults - assert isinstance(defaults, Defaults), f"Error! {defaults=} is not a Defaults" - # defaults.trainer_args["max_epochs"] = 5 - storage_uri = config_init.storage_uri - assert isinstance(storage_uri, str), f"Error! {storage_uri=} is not a str" - storage_uri_path = Path(storage_uri) / uuid.uuid4().hex / "hpo" - if not storage_uri_path.exists(): - try: - storage_uri_path.mkdir(parents=True, exist_ok=True) - print(f"Directory created at: {path}") - except FileNotFoundError as e: - print(f"Error creating directory: {e}") - optimization_space = config_init.optimization_space - assert isinstance( - optimization_space, dict - ), f"Error! {optimization_space=} is not a dict" - ray_storage = RAY_STORAGE - assert isinstance(ray_storage, str), f"Error! {ray_storage=} is not a str" - ray_storage_path = Path(ray_storage) / uuid.uuid4().hex - if not ray_storage_path.exists(): - try: - ray_storage_path.mkdir(parents=True, exist_ok=True) - print(f"Directory created at: {path}") - except FileNotFoundError as e: - print(f"Error creating directory: {e}") - n_trials = config_init.n_trials - assert isinstance(n_trials, int) and n_trials > 0, f"Error! {n_trials=} is invalid" - # run_repetions is an optional parameter - run_repetitions = config_init.run_repetitions - if run_repetitions is not None: - assert ( - isinstance(run_repetitions, int) and run_repetitions >= 0 - ), f"Error! {run_repetitions=} is invalid" - else: - run_repetitions = 0 - finished_run_id = benchmark_backbone( - experiment_name=experiment_name, - run_name=run_name, - run_id=None, - defaults=defaults, - tasks=tasks, - n_trials=n_trials, - save_models=False, - storage_uri=str(storage_uri_path), - ray_storage_path=str(ray_storage_path), - optimization_space=optimization_space, - continue_existing_experiment=continue_existing_experiment, - test_models=test_models, - run_repetitions=run_repetitions, - ) - validate_results( - experiment_name=experiment_name, - storage_uri=str(storage_uri_path), - finished_run_id=finished_run_id, - ) - - -def validate_results(experiment_name: str, storage_uri: str, finished_run_id: str): - # get the most recent modified directory - dir_path = Path(storage_uri) / finished_run_id - assert dir_path.exists(), f"Error! Directory does not exist: {dir_path}" - # find mlflow.runName files within the result dir - meta_yaml = "meta.yaml" - - meta_yaml_path = dir_path / meta_yaml - assert ( - meta_yaml_path.exists() - ), f"Error! meta.yaml file {meta_yaml_path} does not exist" - # open file and check that the experiment name is the same - with open(meta_yaml_path, mode="r") as f: - # read all the lines - lines = f.readlines() - # try to find experiment id and name in these lines - experiment_name_found: bool = False - experiment_id_found: bool = False - for line in lines: - if experiment_name in line: - experiment_name_found = True - if finished_run_id in line: - experiment_id_found = True - assert ( - experiment_name_found and experiment_id_found - ), f"Error! Both experiment name ({experiment_name=}) and finished run id ({finished_run_id=}) must be in the {meta_yaml_path=}: {experiment_id_found=} {experiment_name_found=}" - # TODO delete the directories that were created by this test case diff --git a/tests/test_build_geobench_configs.py b/tests/test_build_geobench_configs.py deleted file mode 100644 index 62987d6..0000000 --- a/tests/test_build_geobench_configs.py +++ /dev/null @@ -1,86 +0,0 @@ -from pathlib import Path -import pytest -import yaml -from benchmark.config_util.build_geobench_configs import generate_iterate_config -from deepdiff import DeepDiff - - -@pytest.mark.parametrize( - "input_dir, output_dir, template, prefix", - [ - ( - # terratorch branch geobench_v2_od - "/Users/ltizzei/Projects/Orgs/IBM/terratorch/examples/confs/geobenchv2_detection", - "/Users/ltizzei/Projects/Orgs/IBM/terratorch-iterate/tests/test_config_util", - "/Users/ltizzei/Projects/Orgs/IBM/terratorch-iterate/benchmark/config_util/geobenchv2_template.yaml", - "test_examples_confs_geobenchv2_detection", - ), - ( - "/Users/ltizzei/Projects/Orgs/IBM/terratorch/tests/resources/configs", - "/Users/ltizzei/Projects/Orgs/IBM/terratorch-iterate/tests/test_config_util", - "/Users/ltizzei/Projects/Orgs/IBM/terratorch-iterate/benchmark/config_util/geobenchv2_template.yaml", - "test_config_util_", - ), - ], -) -def test__generate_iterate_config(input_dir, output_dir, template, prefix): - input_dir_path = Path(input_dir) - assert input_dir_path.exists() - assert input_dir_path.is_dir() - output_path = Path(output_dir) - assert output_path.exists() - assert output_path.is_dir() - # warning! delete all files of the output dir - for item in output_path.iterdir(): - if item.is_file(): - item.unlink() - - generate_iterate_config( - input_dir=input_dir_path, - output_dir=output_path, - template=template, - prefix=prefix, - ) - generated_config_files = list(output_path.glob(f'**/{prefix}*.yaml')) - assert len(generated_config_files) > 0 - - oracle_config_files = [ - f for f in input_dir_path.glob(f'**/geobench*.yaml') if "template" not in str(f) - ] - for gen_config_file in generated_config_files: - end_gen_config_filename = gen_config_file.name.replace(prefix, "") - for oracle_config_file in oracle_config_files: - end_oracle_config_filename = oracle_config_file.name.replace( - "geobenchv2", "" - ) - if end_gen_config_filename == end_oracle_config_filename: - with open(gen_config_file, "r") as gen_file: - new_config = yaml.safe_load(gen_file) - with open(oracle_config_file, "r") as gt_file: - oracle_config = yaml.safe_load(gt_file) - - oracle_tasks = oracle_config["tasks"] - new_config_tasks = new_config["tasks"] - # comparing the tasks - for oracle_task in oracle_tasks: - oracle_task_name = oracle_task["name"] - found = False - for new_config_task in new_config_tasks: - new_config_task_name = new_config_task["name"] - if new_config_task_name == oracle_task_name: - - diff = DeepDiff(new_config_task, oracle_task) - if len(diff) == 0: - found = True - else: - for k in [ - "datamodule", - "direction", - "metric", - "terratorch_task", - "type", - ]: - diff = DeepDiff(new_config_task[k], oracle_task[k]) - assert len(diff) == 0, f"Error! {diff}" - found = True - assert found diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index 71ed890..0000000 --- a/tests/test_cli.py +++ /dev/null @@ -1,6 +0,0 @@ -import os - - -def test_cli(): - exit_status = os.system('terratorch iterate --help') - assert exit_status == 0 diff --git a/tests/unit/test_build_geobench_configs.py b/tests/unit/test_build_geobench_configs.py new file mode 100644 index 0000000..a75ae55 --- /dev/null +++ b/tests/unit/test_build_geobench_configs.py @@ -0,0 +1,110 @@ +from pathlib import Path +import pytest +import yaml +from terratorch_iterate.config_util.build_iterate_config import generate_iterate_config +from deepdiff import DeepDiff +import logging + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +@pytest.mark.parametrize( + "input, output, template, prefix, oracle_config_file", + [ + ( + "./configs/tests/terratorch_configs/test_case_01", + "./configs/tests/terratorch-iterate-configs/test_case_01", + "./configs/templates/template.yaml", + "test_config_util_", + "./configs/tests/terratorch-iterate-configs/test_case_01/oracle/convnext_LM_iterate.yaml", + ), + ( + "./configs/tests/terratorch_configs/test_case_02", + "./configs/tests/terratorch-iterate-configs/test_case_02", + "./configs/templates/template.yaml", + "test_config_util_", + "./configs/tests/terratorch-iterate-configs/test_case_02/oracle/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml", + ), + ( + "./configs/tests/terratorch_configs/test_case_02/test_encoderdecoder_eo_v2_300_model_factory.yaml", + "./configs/tests/terratorch-iterate-configs/test_case_02/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml", + "./configs/templates/template.yaml", + "test_config_util_", + "./configs/tests/terratorch-iterate-configs/test_case_02/oracle/test_config_util__encoderdecoder_eo_v2_300_model_factory.yaml", + ), + ( + "./configs/tests/terratorch_configs/test_case_03", + "./configs/tests/terratorch-iterate-configs/test_case_03", + "./configs/templates/template.yaml", + "test_config_util_", + None, + ), + ], +) +def test__generate_iterate_config(input, output, template, prefix, oracle_config_file): + # Get the absolute path of the current script file + script_path = Path(__file__).resolve() + + # Get the home directory + repo_home_dir = script_path.parent.parent.parent + input_path: Path = repo_home_dir / input + assert input_path.exists() + output_path: Path = repo_home_dir / output + assert output_path.exists() + # warning! delete all files of the output dir + if output_path.is_dir(): + for item in output_path.iterdir(): + if item.is_file(): + logging.debug(f"Cleaning up directory: {item} deleted") + item.unlink() + else: + output_path.unlink() + + generate_iterate_config( + input=input_path, + output=output_path, + template=repo_home_dir / template, + prefix=prefix, + ) + if output_path.is_dir(): + generated_config_files = list(output_path.glob(f"**/{prefix}*.yaml")) + else: + generated_config_files = [output_path] + + assert len(generated_config_files) > 0 + + if oracle_config_file is not None: + oracle_path: Path = repo_home_dir / oracle_config_file + with open(oracle_path, "r") as gt_file: + oracle_config = yaml.safe_load(gt_file) + + for gen_config_file in generated_config_files: + with open(gen_config_file, "r") as gen_file: + new_config = yaml.safe_load(gen_file) + + oracle_tasks = oracle_config["tasks"] + new_config_tasks = new_config["tasks"] + # comparing the tasks + for oracle_task in oracle_tasks: + found = False + if oracle_task.get("name") is not None: + del oracle_task["name"] + for new_config_task in new_config_tasks: + if new_config_task.get("name") is not None: + del new_config_task["name"] + + diff = DeepDiff(new_config_task, oracle_task) + if len(diff) == 0: + found = True + else: + for k in [ + "datamodule", + "direction", + "metric", + "terratorch_task", + "type", + ]: + diff = DeepDiff(new_config_task[k], oracle_task[k]) + assert len(diff) == 0, f"Error! {diff}" + found = True + assert found, f"Error! task not found: {oracle_task}" diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py new file mode 100644 index 0000000..953ac53 --- /dev/null +++ b/tests/unit/test_cli.py @@ -0,0 +1,6 @@ +import os + + +def test_cli(): + exit_status = os.system("terratorch iterate --help") + assert exit_status == 0 diff --git a/tests/unit/test_tasktypeenum.py b/tests/unit/test_tasktypeenum.py index 24ce069..1dfb669 100644 --- a/tests/unit/test_tasktypeenum.py +++ b/tests/unit/test_tasktypeenum.py @@ -1,4 +1,4 @@ -from benchmark.benchmark_types import TaskTypeEnum +from terratorch_iterate.iterate_types import TaskTypeEnum import pytest from terratorch.tasks.base_task import TerraTorchTask from terratorch.tasks.classification_tasks import ClassificationTask diff --git a/tox.ini b/tox.ini index 68ae8fd..21e49b6 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] requires = tox>=4.23.0 -env_list = 3.1{2,1,0}, lint, style +env_list = 3.1{3,2,1}, lint, style isolated_build = true skip_missing_interpreters = false @@ -9,8 +9,8 @@ skip_missing_interpreters = false description = run code style skip_install = true deps = - black -commands = black {posargs:.} + ruff +commands = ruff format {posargs:.} [testenv:lint] description = run linters