diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..fe58f73 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,17 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v5 + - run: uv python install 3.10 + - run: uv pip install ruff + - run: uv run ruff check . + - run: uv run ruff format --check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..700cbc4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/baselines/da_v2.py b/baselines/da_v2.py index bca560a..105cc85 100644 --- a/baselines/da_v2.py +++ b/baselines/da_v2.py @@ -1,8 +1,8 @@ # Reference: https://github.com/DepthAnything/Depth-Anything-V2 import os import sys -from typing import * from pathlib import Path +from typing import * import click import torch diff --git a/baselines/da_v2_metric.py b/baselines/da_v2_metric.py index ee4c70d..4b83f86 100644 --- a/baselines/da_v2_metric.py +++ b/baselines/da_v2_metric.py @@ -1,15 +1,14 @@ # Reference https://github.com/DepthAnything/Depth-Anything-V2/metric_depth import os import sys -from typing import * from pathlib import Path +from typing import * import click import torch import torch.nn.functional as F import torchvision.transforms as T import torchvision.transforms.functional as TF -import cv2 from moge.test.baseline import MGEBaselineInterface diff --git a/baselines/metric3d_v2.py b/baselines/metric3d_v2.py index 661ed5d..d5365db 100644 --- a/baselines/metric3d_v2.py +++ b/baselines/metric3d_v2.py @@ -1,12 +1,10 @@ # Reference: https://github.com/YvanYin/Metric3D -import os -import sys from typing import * import click +import cv2 import torch import torch.nn.functional as F -import cv2 from moge.test.baseline import MGEBaselineInterface diff --git a/baselines/moge.py b/baselines/moge.py index fd66d69..07cd99e 100644 --- a/baselines/moge.py +++ b/baselines/moge.py @@ -1,7 +1,4 @@ -import os -import sys from typing import * -import importlib import click import torch diff --git a/moge/model/dinov2/hub/utils.py b/moge/model/dinov2/hub/utils.py index 9c66414..a7482f9 100644 --- a/moge/model/dinov2/hub/utils.py +++ b/moge/model/dinov2/hub/utils.py @@ -10,7 +10,6 @@ import torch.nn as nn import torch.nn.functional as F - _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" diff --git a/moge/model/dinov2/layers/__init__.py b/moge/model/dinov2/layers/__init__.py index 05a0b61..26f3f38 100644 --- a/moge/model/dinov2/layers/__init__.py +++ b/moge/model/dinov2/layers/__init__.py @@ -3,9 +3,9 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. +from .attention import MemEffAttention +from .block import NestedTensorBlock from .dino_head import DINOHead from .mlp import Mlp from .patch_embed import PatchEmbed from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused -from .block import NestedTensorBlock -from .attention import MemEffAttention diff --git a/moge/model/dinov2/layers/attention.py b/moge/model/dinov2/layers/attention.py index c9f79d4..6d5b02c 100644 --- a/moge/model/dinov2/layers/attention.py +++ b/moge/model/dinov2/layers/attention.py @@ -9,12 +9,9 @@ import logging import os -import warnings import torch.nn.functional as F -from torch import Tensor -from torch import nn - +from torch import Tensor, nn logger = logging.getLogger("dinov2") diff --git a/moge/model/dinov2/layers/block.py b/moge/model/dinov2/layers/block.py index fd5b8a7..8b852d8 100644 --- a/moge/model/dinov2/layers/block.py +++ b/moge/model/dinov2/layers/block.py @@ -9,25 +9,23 @@ import logging import os -from typing import Callable, List, Any, Tuple, Dict -import warnings +from typing import Any, Callable, Dict, List, Tuple import torch -from torch import nn, Tensor +from torch import Tensor, nn from .attention import Attention, MemEffAttention from .drop_path import DropPath from .layer_scale import LayerScale from .mlp import Mlp - logger = logging.getLogger("dinov2") XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None try: if XFORMERS_ENABLED: - from xformers.ops import fmha, scaled_index_add, index_select_cat + from xformers.ops import fmha, index_select_cat, scaled_index_add XFORMERS_AVAILABLE = True # warnings.warn("xFormers is available (Block)") diff --git a/moge/model/dinov2/layers/layer_scale.py b/moge/model/dinov2/layers/layer_scale.py index 51df0d7..405dd84 100644 --- a/moge/model/dinov2/layers/layer_scale.py +++ b/moge/model/dinov2/layers/layer_scale.py @@ -8,8 +8,7 @@ from typing import Union import torch -from torch import Tensor -from torch import nn +from torch import Tensor, nn class LayerScale(nn.Module): diff --git a/moge/model/dinov2/layers/patch_embed.py b/moge/model/dinov2/layers/patch_embed.py index 8b7c080..d170d10 100644 --- a/moge/model/dinov2/layers/patch_embed.py +++ b/moge/model/dinov2/layers/patch_embed.py @@ -9,8 +9,8 @@ from typing import Callable, Optional, Tuple, Union -from torch import Tensor import torch.nn as nn +from torch import Tensor def make_2tuple(x): diff --git a/moge/model/dinov2/layers/swiglu_ffn.py b/moge/model/dinov2/layers/swiglu_ffn.py index 5ce2115..22b8326 100644 --- a/moge/model/dinov2/layers/swiglu_ffn.py +++ b/moge/model/dinov2/layers/swiglu_ffn.py @@ -5,10 +5,9 @@ import os from typing import Callable, Optional -import warnings -from torch import Tensor, nn import torch.nn.functional as F +from torch import Tensor, nn class SwiGLUFFN(nn.Module): diff --git a/moge/model/dinov2/models/__init__.py b/moge/model/dinov2/models/__init__.py index 3fdff20..50f937a 100644 --- a/moge/model/dinov2/models/__init__.py +++ b/moge/model/dinov2/models/__init__.py @@ -7,7 +7,6 @@ from . import vision_transformer as vits - logger = logging.getLogger("dinov2") diff --git a/moge/model/dinov2/models/vision_transformer.py b/moge/model/dinov2/models/vision_transformer.py index f0bed9d..432b181 100644 --- a/moge/model/dinov2/models/vision_transformer.py +++ b/moge/model/dinov2/models/vision_transformer.py @@ -7,18 +7,18 @@ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py -from functools import partial -import math import logging -from typing import Sequence, Tuple, Union, Callable, Optional, List +import math +from functools import partial +from typing import Callable, Sequence, Tuple, Union import torch import torch.nn as nn import torch.utils.checkpoint from torch.nn.init import trunc_normal_ -from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block - +from ..layers import MemEffAttention, Mlp, PatchEmbed, SwiGLUFFNFused +from ..layers import NestedTensorBlock as Block logger = logging.getLogger("dinov2") diff --git a/moge/model/dinov2/utils/cluster.py b/moge/model/dinov2/utils/cluster.py index 3df87dc..36f38ba 100644 --- a/moge/model/dinov2/utils/cluster.py +++ b/moge/model/dinov2/utils/cluster.py @@ -3,8 +3,8 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -from enum import Enum import os +from enum import Enum from pathlib import Path from typing import Any, Dict, Optional diff --git a/moge/model/dinov2/utils/config.py b/moge/model/dinov2/utils/config.py index c9de578..377f282 100644 --- a/moge/model/dinov2/utils/config.py +++ b/moge/model/dinov2/utils/config.py @@ -3,17 +3,15 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -import math import logging +import math import os -from omegaconf import OmegaConf - import dinov2.distributed as distributed +from dinov2.configs import dinov2_default_config from dinov2.logging import setup_logging from dinov2.utils import utils -from dinov2.configs import dinov2_default_config - +from omegaconf import OmegaConf logger = logging.getLogger("dinov2") diff --git a/moge/model/dinov2/utils/dtype.py b/moge/model/dinov2/utils/dtype.py index 80f4cd7..7ebdedf 100644 --- a/moge/model/dinov2/utils/dtype.py +++ b/moge/model/dinov2/utils/dtype.py @@ -9,7 +9,6 @@ import numpy as np import torch - TypeSpec = Union[str, np.dtype, torch.dtype] diff --git a/moge/model/dinov2/utils/param_groups.py b/moge/model/dinov2/utils/param_groups.py index 9a5d2ff..996775c 100644 --- a/moge/model/dinov2/utils/param_groups.py +++ b/moge/model/dinov2/utils/param_groups.py @@ -3,9 +3,8 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -from collections import defaultdict import logging - +from collections import defaultdict logger = logging.getLogger("dinov2") diff --git a/moge/model/dinov2/utils/utils.py b/moge/model/dinov2/utils/utils.py index 68f8e2c..a26b42e 100644 --- a/moge/model/dinov2/utils/utils.py +++ b/moge/model/dinov2/utils/utils.py @@ -13,7 +13,6 @@ import torch from torch import nn - logger = logging.getLogger("dinov2") diff --git a/moge/model/modules.py b/moge/model/modules.py index b36ad48..1d94cf7 100644 --- a/moge/model/modules.py +++ b/moge/model/modules.py @@ -1,18 +1,17 @@ -from typing import * -from numbers import Number +import functools import importlib import itertools -import functools -import sys +from typing import * import torch -from torch import Tensor import torch.nn as nn import torch.nn.functional as F from .dinov2.models.vision_transformer import DinoVisionTransformer -from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing -from ..utils.geometry_torch import normalized_view_plane_uv +from .utils import ( + wrap_dinov2_attention_with_sdpa, + wrap_module_with_gradient_checkpointing, +) class ResidualConvBlock(nn.Module): diff --git a/moge/model/utils.py b/moge/model/utils.py index c50761d..42493d5 100644 --- a/moge/model/utils.py +++ b/moge/model/utils.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F + def wrap_module_with_gradient_checkpointing(module: nn.Module): from torch.utils.checkpoint import checkpoint class _CheckpointingWrapper(module.__class__): diff --git a/moge/model/v1.py b/moge/model/v1.py index 2513b86..8c964c0 100644 --- a/moge/model/v1.py +++ b/moge/model/v1.py @@ -1,10 +1,8 @@ -from typing import * -from numbers import Number -from functools import partial -from pathlib import Path import importlib import warnings -import json +from numbers import Number +from pathlib import Path +from typing import * import torch import torch.nn as nn @@ -15,10 +13,10 @@ import utils3d from huggingface_hub import hf_hub_download - -from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask -from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing -from ..utils.tools import timeit +from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift +from .utils import ( + wrap_module_with_gradient_checkpointing, +) class ResidualConvBlock(nn.Module): diff --git a/moge/model/v2.py b/moge/model/v2.py index 5cf8028..2553e66 100644 --- a/moge/model/v2.py +++ b/moge/model/v2.py @@ -1,24 +1,22 @@ -from typing import * +import warnings from numbers import Number -from functools import partial from pathlib import Path -import warnings +from typing import * import torch +import torch.amp import torch.nn as nn import torch.nn.functional as F import torch.utils import torch.utils.checkpoint -import torch.amp import torch.version import utils3d from huggingface_hub import hf_hub_download -from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3 -from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing -from .modules import DINOv2Encoder, MLP, ConvStack +from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift +from .modules import MLP, ConvStack, DINOv2Encoder + - class MoGeModel(nn.Module): encoder: DINOv2Encoder neck: ConvStack diff --git a/moge/scripts/app.py b/moge/scripts/app.py index 9a63e62..e56f468 100644 --- a/moge/scripts/app.py +++ b/moge/scripts/app.py @@ -1,17 +1,18 @@ import os + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' import sys from pathlib import Path + if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: sys.path.insert(0, _package_root) -import time -import uuid -import tempfile -import itertools -from typing import * import atexit -from concurrent.futures import ThreadPoolExecutor +import itertools import shutil +import tempfile +import time +from concurrent.futures import ThreadPoolExecutor +from typing import * import click @@ -25,24 +26,22 @@ def main(share: bool, pretrained_model_name_or_path: str, model_version: str, us print("Import modules...") # Lazy import import cv2 - import torch + import gradio as gr import numpy as np + import torch import trimesh import trimesh.visual from PIL import Image - import gradio as gr try: - import spaces # This is for deployment at huggingface.co/spaces + import spaces # This is for deployment at huggingface.co/spaces HUGGINFACE_SPACES_INSTALLED = True except ImportError: HUGGINFACE_SPACES_INSTALLED = False import utils3d - from moge.utils.io import write_normal - from moge.utils.vis import colorize_depth, colorize_normal + from moge.model import import_model_class_by_version - from moge.utils.geometry_numpy import depth_occlusion_edge_numpy - from moge.utils.tools import timeit + from moge.utils.vis import colorize_depth, colorize_normal print("Load model...") if pretrained_model_name_or_path is None: @@ -180,9 +179,9 @@ def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', fov_x, fov_y = np.rad2deg([fov_x, fov_y]) # messages - viewer_message = f'**Note:** Inference has been completed. It may take a few seconds to download the 3D model.' + viewer_message = '**Note:** Inference has been completed. It may take a few seconds to download the 3D model.' if resolution_level != 'Ultra': - depth_message = f'**Note:** Want sharper depth map? Try increasing the `maximum image size` and setting the `inference resolution level` to `Ultra` in the settings.' + depth_message = '**Note:** Want sharper depth map? Try increasing the `maximum image size` and setting the `inference resolution level` to `Ultra` in the settings.' else: depth_message = "" @@ -230,7 +229,7 @@ def measure(results: Dict[str, np.ndarray], measure_points: List[Tuple[int, int] print("Create Gradio app...") with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( -f''' +'''
diff --git a/moge/scripts/cli.py b/moge/scripts/cli.py index 45c3b90..c868ba2 100644 --- a/moge/scripts/cli.py +++ b/moge/scripts/cli.py @@ -1,7 +1,9 @@ import os + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' -from pathlib import Path import sys +from pathlib import Path + if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: sys.path.insert(0, _package_root) @@ -13,7 +15,7 @@ def cli(): pass def main(): - from moge.scripts import app, infer, infer_baseline, infer_panorama, eval_baseline, vis_data + from moge.scripts import app, eval_baseline, infer, infer_baseline, infer_panorama, vis_data cli.add_command(app.main, name='app') cli.add_command(infer.main, name='infer') cli.add_command(infer_baseline.main, name='infer_baseline') diff --git a/moge/scripts/eval_baseline.py b/moge/scripts/eval_baseline.py index 8217d9e..d7db8c4 100644 --- a/moge/scripts/eval_baseline.py +++ b/moge/scripts/eval_baseline.py @@ -1,12 +1,10 @@ -import os import sys from pathlib import Path + if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: sys.path.insert(0, _package_root) import json from typing import * -import importlib -import importlib.util import click @@ -22,19 +20,17 @@ @click.pass_context def main(ctx: click.Context, baseline_code_path: str, config_path: str, oracle_mode: bool, output_path: Union[str, Path], dump_pred: bool, dump_gt: bool): # Lazy import - import cv2 + import cv2 import numpy as np - from tqdm import tqdm import torch - import torch.nn.functional as F - import utils3d + from tqdm import tqdm from moge.test.baseline import MGEBaselineInterface from moge.test.dataloader import EvalDataLoaderPipeline from moge.test.metrics import compute_metrics from moge.utils.geometry_torch import intrinsics_to_fov + from moge.utils.tools import import_file_as_module, key_average, timeit from moge.utils.vis import colorize_depth, colorize_normal - from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module # Load the baseline model module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem) @@ -76,7 +72,7 @@ def main(ctx: click.Context, baseline_code_path: str, config_path: str, oracle_m metrics_list.append(metrics) # Dump results - dump_path = Path(output_path.replace(".json", f"_dump"), f'{benchmark_name}', sample['filename'].replace('.zip', '')) + dump_path = Path(output_path.replace(".json", "_dump"), f'{benchmark_name}', sample['filename'].replace('.zip', '')) if dump_pred: dump_path.joinpath('pred').mkdir(parents=True, exist_ok=True) cv2.imwrite(str(dump_path / 'pred' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) diff --git a/moge/scripts/infer.py b/moge/scripts/infer.py index 09990f3..ce82a9e 100644 --- a/moge/scripts/infer.py +++ b/moge/scripts/infer.py @@ -1,13 +1,15 @@ import os + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' -from pathlib import Path import sys +from pathlib import Path + if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: sys.path.insert(0, _package_root) -from typing import * import itertools import json import warnings +from typing import * import click @@ -52,15 +54,12 @@ def main( import cv2 import numpy as np import torch - from PIL import Image + import utils3d from tqdm import tqdm - import click from moge.model import import_model_class_by_version from moge.utils.io import save_glb, save_ply from moge.utils.vis import colorize_depth, colorize_normal - from moge.utils.geometry_numpy import depth_occlusion_edge_numpy - import utils3d device = torch.device(device_name) diff --git a/moge/scripts/infer_baseline.py b/moge/scripts/infer_baseline.py index ef81bc4..1e3caef 100644 --- a/moge/scripts/infer_baseline.py +++ b/moge/scripts/infer_baseline.py @@ -1,14 +1,16 @@ import os + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' -from pathlib import Path import sys +from pathlib import Path + if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: sys.path.insert(0, _package_root) +import itertools import json +import warnings from pathlib import Path from typing import * -import itertools -import warnings import click @@ -26,17 +28,17 @@ @click.pass_context def main(ctx: click.Context, baseline_code_path: str, input_path: str, output_path: str, image_size: int, skip: bool, save_maps_, save_ply_: bool, save_glb_: bool, threshold: float): # Lazy import - import cv2 + import cv2 import numpy as np - from tqdm import tqdm import torch import utils3d + from tqdm import tqdm - from moge.utils.io import save_ply, save_glb + from moge.test.baseline import MGEBaselineInterface from moge.utils.geometry_numpy import intrinsics_to_fov_numpy + from moge.utils.io import save_glb, save_ply + from moge.utils.tools import import_file_as_module, timeit from moge.utils.vis import colorize_depth, colorize_depth_affine, colorize_disparity - from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module - from moge.test.baseline import MGEBaselineInterface # Load the baseline model module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem) diff --git a/moge/scripts/infer_panorama.py b/moge/scripts/infer_panorama.py index 525a8ad..a99cb35 100644 --- a/moge/scripts/infer_panorama.py +++ b/moge/scripts/infer_panorama.py @@ -1,17 +1,18 @@ import os + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' -from pathlib import Path import sys +from pathlib import Path + if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: sys.path.insert(0, _package_root) -from typing import * import itertools -import json import warnings +from typing import * import click - + @click.command(help='Inference script for panorama images') @click.option('--input', '-i', 'input_path', type=click.Path(exists=True), required=True, help='Input image or folder path. "jpg" and "png" are supported.') @click.option('--output', '-o', 'output_path', type=click.Path(), default='./output', help='Output folder path') @@ -44,21 +45,21 @@ def main( # Lazy import import cv2 import numpy as np - from numpy import ndarray import torch - from PIL import Image - from tqdm import tqdm, trange import trimesh import trimesh.visual - from scipy.sparse import csr_array, hstack, vstack - from scipy.ndimage import convolve - from scipy.sparse.linalg import lsmr - import utils3d + from tqdm import tqdm, trange + from moge.model.v1 import MoGeModel from moge.utils.io import save_glb, save_ply + from moge.utils.panorama import ( + get_panorama_cameras, + merge_panorama_depth, + spherical_uv_to_directions, + split_panorama_image, + ) from moge.utils.vis import colorize_depth - from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth device = torch.device(device_name) @@ -91,7 +92,7 @@ def main( splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution) # Infer each view - print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring') + print('Inferring...') if pbar.disable else pbar.set_postfix_str('Inferring') splitted_distance_maps, splitted_masks = [], [] for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False): @@ -112,7 +113,7 @@ def main( cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR)) # Merge - print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging') + print('Merging...') if pbar.disable else pbar.set_postfix_str('Merging') merging_width, merging_height = min(1920, width), min(960, height) panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs) @@ -122,7 +123,7 @@ def main( points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.np.uv_map(height, width)) # Write outputs - print('Writing outputs...') if pbar.disable else pbar.set_postfix_str(f'Inferring') + print('Writing outputs...') if pbar.disable else pbar.set_postfix_str('Inferring') save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) save_path.mkdir(exist_ok=True, parents=True) if save_maps_: diff --git a/moge/scripts/train.py b/moge/scripts/train.py index 6d810cd..61e4742 100644 --- a/moge/scripts/train.py +++ b/moge/scripts/train.py @@ -1,50 +1,44 @@ -import os -from pathlib import Path import sys +from pathlib import Path + if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: sys.path.insert(0, _package_root) +import io import json -import time import random -from typing import * -import itertools -from contextlib import nullcontext +import time from concurrent.futures import ThreadPoolExecutor -import io +from typing import * -import numpy as np +import accelerate +import click import cv2 -from PIL import Image +import mlflow +import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F import torch.version -import accelerate +import utils3d from accelerate import Accelerator, DistributedDataParallelKwargs from accelerate.utils import set_seed -import utils3d -import click -from tqdm import tqdm, trange -import mlflow +from tqdm import tqdm + torch.backends.cudnn.benchmark = False # Varying input size, make sure cudnn benchmark is disabled from moge.train.dataloader import TrainDataLoaderPipeline from moge.train.losses import ( affine_invariant_global_loss, - affine_invariant_local_loss, + affine_invariant_local_loss, edge_loss, - normal_loss, - mask_l2_loss, mask_bce_loss, + mask_l2_loss, metric_scale_loss, + monitoring, + normal_loss, normal_map_loss, - monitoring, ) -from moge.train.utils import build_optimizer, build_lr_scheduler -from moge.utils.geometry_torch import intrinsics_to_fov +from moge.train.utils import build_lr_scheduler, build_optimizer +from moge.utils.tools import flatten_nested_dict, key_average from moge.utils.vis import colorize_depth, colorize_normal -from moge.utils.tools import key_average, recursive_replace, CallbackOnException, flatten_nested_dict -from moge.test.metrics import compute_metrics @click.command() @@ -168,7 +162,7 @@ def main( print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}') checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model'] else: - print(f'No latest checkpoint found. Start from scratch.') + print('No latest checkpoint found. Start from scratch.') checkpoint = None else: # - Load by step number @@ -340,7 +334,7 @@ def _write_bytes_retry_loop(save_path: Path, data: bytes): if accelerator.sync_gradients: if not enable_mixed_precision and any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None): if accelerator.is_main_process: - pbar.write(f'NaN gradients, skip update') + pbar.write('NaN gradients, skip update') optimizer.zero_grad() continue accelerator.clip_grad_norm_(model.parameters(), 1.0) diff --git a/moge/scripts/vis_data.py b/moge/scripts/vis_data.py index fcca724..7a7cf44 100644 --- a/moge/scripts/vis_data.py +++ b/moge/scripts/vis_data.py @@ -1,7 +1,9 @@ import os + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' import sys from pathlib import Path + if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: sys.path.insert(0, _package_root) @@ -34,12 +36,12 @@ def main( # Lazy import import cv2 import numpy as np + import trimesh import utils3d from tqdm import tqdm - import trimesh - from moge.utils.io import read_image, read_depth, read_json - from moge.utils.vis import colorize_depth, colorize_normal + from moge.utils.io import read_depth, read_image, read_json + from moge.utils.vis import colorize_depth filepaths = sorted(p.parent for p in Path(folder_or_path).rglob('meta.json')) diff --git a/moge/test/dataloader.py b/moge/test/dataloader.py index 97a9298..4696e0f 100644 --- a/moge/test/dataloader.py +++ b/moge/test/dataloader.py @@ -1,18 +1,16 @@ -import os -from typing import * -from pathlib import Path import math +from pathlib import Path +from typing import * +import cv2 import numpy as np +import pipeline import torch -from PIL import Image -import cv2 import utils3d -import pipeline +from PIL import Image -from ..utils.geometry_numpy import focal_to_fov_numpy, norm3d +from ..utils.geometry_numpy import norm3d from ..utils.io import * -from ..utils.tools import timeit class EvalDataLoaderPipeline: diff --git a/moge/test/metrics.py b/moge/test/metrics.py index 4c79c33..190ba91 100644 --- a/moge/test/metrics.py +++ b/moge/test/metrics.py @@ -1,25 +1,19 @@ -from typing import * from numbers import Number +from typing import * import torch -import torch.nn.functional as F -import numpy as np import utils3d -from ..utils.geometry_torch import ( - weighted_mean, - intrinsics_to_fov -) from ..utils.alignment import ( - align_points_scale_z_shift, - align_points_scale_xyz_shift, - align_points_xyz_shift, - align_affine_lstsq, - align_depth_scale, - align_depth_affine, + align_affine_lstsq, + align_depth_affine, + align_depth_scale, align_points_scale, + align_points_scale_xyz_shift, + align_points_xyz_shift, ) -from ..utils.tools import key_average, timeit +from ..utils.geometry_torch import intrinsics_to_fov +from ..utils.tools import key_average def rel_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): diff --git a/moge/train/dataloader.py b/moge/train/dataloader.py index b08846d..63aaa03 100644 --- a/moge/train/dataloader.py +++ b/moge/train/dataloader.py @@ -1,26 +1,15 @@ -import os -from pathlib import Path -import json -import time import random +from pathlib import Path from typing import * -import traceback -import itertools -from numbers import Number -import io import numpy as np -import cv2 -from PIL import Image +import pipeline import torch -import torchvision.transforms.v2.functional as TF import utils3d -import pipeline from tqdm import tqdm +from ..utils.data_augmentation import image_color_augmentation, sample_perspective, warp_perspective from ..utils.io import * -from ..utils.geometry_numpy import harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy -from ..utils.data_augmentation import sample_perspective, warp_perspective, image_color_augmentation class TrainDataLoaderPipeline: diff --git a/moge/train/losses.py b/moge/train/losses.py index a568adf..6b9e61b 100644 --- a/moge/train/losses.py +++ b/moge/train/losses.py @@ -1,22 +1,18 @@ -from typing import * import math +from typing import * import torch import torch.nn.functional as F import utils3d -from ..utils.geometry_torch import ( - weighted_mean, - harmonic_mean, - geometric_mean, - normalized_view_plane_uv, - angle_diff_vec3 -) from ..utils.alignment import ( - align_points_scale_z_shift, - align_points_scale, align_points_scale_xyz_shift, - align_points_z_shift, + align_points_scale_z_shift, +) +from ..utils.geometry_torch import ( + angle_diff_vec3, + harmonic_mean, + weighted_mean, ) diff --git a/moge/train/utils.py b/moge/train/utils.py index 5f21e00..3adc851 100644 --- a/moge/train/utils.py +++ b/moge/train/utils.py @@ -1,5 +1,5 @@ -from typing import * import fnmatch +from typing import * import sympy import torch diff --git a/moge/utils/alignment.py b/moge/utils/alignment.py index 3d6bb78..fde51e3 100644 --- a/moge/utils/alignment.py +++ b/moge/utils/alignment.py @@ -1,13 +1,8 @@ -from typing import * import math -from collections import namedtuple +from typing import * -import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F import torch.types -import utils3d def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min: diff --git a/moge/utils/data_augmentation.py b/moge/utils/data_augmentation.py index 9fc4c9d..d3f23a9 100644 --- a/moge/utils/data_augmentation.py +++ b/moge/utils/data_augmentation.py @@ -1,22 +1,13 @@ -import os -import json -import time -import random from typing import * -import itertools -from numbers import Number -import io -import numpy as np import cv2 -from PIL import Image +import numpy as np import torch import torchvision.transforms.v2.functional as TF import utils3d +from PIL import Image from scipy.signal import fftconvolve -from ..utils.geometry_numpy import harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy - def sample_perspective( src_intrinsics: np.ndarray, diff --git a/moge/utils/download.py b/moge/utils/download.py index 886edbc..ea54ba2 100644 --- a/moge/utils/download.py +++ b/moge/utils/download.py @@ -1,10 +1,9 @@ from pathlib import Path from typing import * -import requests +import requests from tqdm import tqdm - __all__ = ["download_file", "download_bytes"] diff --git a/moge/utils/geometry_numpy.py b/moge/utils/geometry_numpy.py index 99de45c..622db36 100644 --- a/moge/utils/geometry_numpy.py +++ b/moge/utils/geometry_numpy.py @@ -1,14 +1,10 @@ -from typing import * from functools import partial -import math +from typing import * import cv2 import numpy as np -from scipy.signal import fftconvolve -import numpy as np import utils3d - -from .tools import timeit +from scipy.signal import fftconvolve def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: diff --git a/moge/utils/geometry_torch.py b/moge/utils/geometry_torch.py index 20b5632..855fd34 100644 --- a/moge/utils/geometry_torch.py +++ b/moge/utils/geometry_torch.py @@ -1,15 +1,10 @@ from typing import * -import math -from collections import namedtuple -import numpy as np import torch -import torch.nn as nn import torch.nn.functional as F import torch.types import utils3d -from .tools import timeit from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift diff --git a/moge/utils/io.py b/moge/utils/io.py index 47b1641..6c1431c 100644 --- a/moge/utils/io.py +++ b/moge/utils/io.py @@ -1,18 +1,15 @@ import os + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' -from typing import IO -import zipfile -import json import io -from typing import * +import json from pathlib import Path -import re -from PIL import Image, PngImagePlugin +from typing import * +from typing import IO +import cv2 import numpy as np -import cv2 - -from .tools import timeit +from PIL import Image, PngImagePlugin def save_glb( @@ -52,7 +49,6 @@ def save_ply( ): import trimesh import trimesh.visual - from PIL import Image trimesh.Trimesh( vertices=vertices, diff --git a/moge/utils/panorama.py b/moge/utils/panorama.py index 42d915a..c1b31aa 100644 --- a/moge/utils/panorama.py +++ b/moge/utils/panorama.py @@ -1,21 +1,16 @@ import os + os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' -from pathlib import Path from typing import * -import itertools -import json -import warnings import cv2 import numpy as np +import utils3d from numpy import ndarray -from tqdm import tqdm, trange -from scipy.sparse import csr_array, hstack, vstack from scipy.ndimage import convolve +from scipy.sparse import csr_array, vstack from scipy.sparse.linalg import lsmr -import utils3d - def get_panorama_cameras(): vertices, _ = utils3d.np.create_icosahedron_mesh() diff --git a/moge/utils/tools.py b/moge/utils/tools.py index 3687f69..f209bed 100644 --- a/moge/utils/tools.py +++ b/moge/utils/tools.py @@ -1,14 +1,12 @@ -from typing import * -import time -from pathlib import Path -from numbers import Number -from functools import wraps -import warnings -import math -import json -import os import importlib import importlib.util +import json +import math +import os +import time +import warnings +from functools import wraps +from typing import * def catch_exception(fn): @@ -114,14 +112,12 @@ def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]: def read_jsonl(file): - import json with open(file, 'r') as f: data = f.readlines() return [json.loads(line) for line in data] def write_jsonl(data: List[dict], file): - import json with open(file, 'w') as f: for item in data: f.write(json.dumps(item) + '\n') @@ -223,7 +219,7 @@ def strip_common_prefix_suffix(strings: List[str]) -> List[str]: def multithead_execute(inputs: List[Any], num_workers: int, pbar = None): from concurrent.futures import ThreadPoolExecutor - from contextlib import nullcontext + from tqdm import tqdm if pbar is not None: diff --git a/moge/utils/vis.py b/moge/utils/vis.py index cb9c237..f8d1d36 100644 --- a/moge/utils/vis.py +++ b/moge/utils/vis.py @@ -1,7 +1,7 @@ from typing import * -import numpy as np import matplotlib +import numpy as np def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: diff --git a/moge/utils/webfile.py b/moge/utils/webfile.py index 1e98abf..7ae5397 100644 --- a/moge/utils/webfile.py +++ b/moge/utils/webfile.py @@ -1,6 +1,7 @@ -import requests -from typing import * - +from typing import * + +import requests + __all__ = ["WebFile"] diff --git a/moge/utils/webzipfile.py b/moge/utils/webzipfile.py index 25ed1d3..8e3b3d9 100644 --- a/moge/utils/webzipfile.py +++ b/moge/utils/webzipfile.py @@ -1,13 +1,23 @@ +import struct from typing import * -import io -import os from zipfile import ( - ZipInfo, BadZipFile, ZipFile, ZipExtFile, - sizeFileHeader, structFileHeader, stringFileHeader, - _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS, - _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED + _FH_EXTRA_FIELD_LENGTH, + _FH_FILENAME_LENGTH, + _FH_GENERAL_PURPOSE_FLAG_BITS, + _FH_SIGNATURE, + _MASK_COMPRESSED_PATCH, + _MASK_ENCRYPTED, + _MASK_STRONG_ENCRYPTION, + _MASK_UTF_FILENAME, + BadZipFile, + ZipExtFile, + ZipFile, + ZipInfo, + sizeFileHeader, + stringFileHeader, + structFileHeader, ) -import struct + from requests import Session from .webfile import WebFile diff --git a/pyproject.toml b/pyproject.toml index 27a7613..6a612aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,4 +33,16 @@ where = ["."] include = ["moge*"] [project.scripts] -moge = "moge.scripts.cli:main" \ No newline at end of file +moge = "moge.scripts.cli:main" +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "F841", "F403", "F405"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] + +[tool.uv] +# Install with: uv sync