From 370ad715fae99d7036a88b1aa44c781162b72a85 Mon Sep 17 00:00:00 2001 From: Paolo Fraccaro Date: Tue, 4 Nov 2025 13:42:12 +0000 Subject: [PATCH 1/3] fix types and imports --- terratorch_iterate/iterate_types.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/terratorch_iterate/iterate_types.py b/terratorch_iterate/iterate_types.py index 1ff0201..dd9bca9 100644 --- a/terratorch_iterate/iterate_types.py +++ b/terratorch_iterate/iterate_types.py @@ -6,7 +6,7 @@ import copy import enum from dataclasses import dataclass, field, replace -from typing import Any, Optional, Union +from typing import Any, Optional, Union, TYPE_CHECKING from terratorch.tasks import ( ClassificationTask, MultiLabelClassificationTask, @@ -16,6 +16,20 @@ ) from torchgeo.datamodules import BaseDataModule +import logging + +try: + from geobench_v2.datamodules import GeoBenchDataModule + GEOBENCH_AVAILABLE = True +except ImportError: + GeoBenchDataModule = None # type: ignore + GEOBENCH_AVAILABLE = False + logging.getLogger("terratorch").debug("geobench_v2 not installed") + + +if TYPE_CHECKING: + from geobench_v2.datamodules import GeoBenchDataModule + valid_task_types = type[ SemanticSegmentationTask | ClassificationTask @@ -129,7 +143,7 @@ class Task: name: str type: TaskTypeEnum = field(repr=False) - datamodule: BaseDataModule = field(repr=False) + datamodule: Union[BaseDataModule, "GeoBenchDataModule"] = field(repr=False) direction: str terratorch_task: Optional[dict[str, Any]] = None metric: str = "val/loss" From 74a943e9c3f0d8beea3b50658306653106a47b65 Mon Sep 17 00:00:00 2001 From: Leonardo P Tizzei Date: Tue, 4 Nov 2025 11:52:50 -0300 Subject: [PATCH 2/3] remove torch upper bound Signed-off-by: Leonardo P Tizzei --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e1f4527..fff020c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ dependencies = [ "optuna-integration", "psutil", "tabulate>=0.9.0", -"torch<=2.2.2", +"torch", "seaborn" ] From 7dd301b285a550d5ae3564f7d338884f26a93c62 Mon Sep 17 00:00:00 2001 From: Leonardo P Tizzei Date: Tue, 4 Nov 2025 12:11:00 -0300 Subject: [PATCH 3/3] include deepdiff Signed-off-by: Leonardo P Tizzei --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fff020c..5c4fd83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,8 @@ dev = [ test = [ "coverage", "pytest", - "pytest-cov" + "pytest-cov", + "deepdiff" ] utility = [