diff --git a/mart/attack/adversary_in_art.py b/mart/attack/adversary_in_art.py index 2a993349..1170340a 100644 --- a/mart/attack/adversary_in_art.py +++ b/mart/attack/adversary_in_art.py @@ -82,17 +82,17 @@ def convert_input_art_to_mart(self, x: numpy.ndarray): x (np.ndarray): NHWC, [0, 1] Returns: - tuple: a tuple of tensors in CHW, [0, 255]. + list[torch.Tensor]: a list of tensors in CHW, [0, 255]. """ input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255 - input = tuple(inp_ for inp_ in input) + input = [inp_ for inp_ in input] return input - def convert_input_mart_to_art(self, input: tuple): + def convert_input_mart_to_art(self, input: list[torch.Tensor]): """Convert MART input to the ART's format. Args: - input (tuple): a tuple of tensors in CHW, [0, 255]. + input (list[torch.Tensor]): a list of tensors in CHW, [0, 255]. Returns: np.ndarray: NHWC, [0, 1] @@ -112,7 +112,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List): y_patch_metadata (_type_): _description_ Returns: - tuple: a tuple of target dictionaies. + list: a list of target dictionaies. """ # Copy y to target, and convert ndarray to pytorch tensors accordingly. target = [] @@ -132,6 +132,4 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List): target_i["file_name"] = f"{yi['image_id'][0]}.jpg" target.append(target_i) - target = tuple(target) - return target diff --git a/mart/attack/adversary_wrapper.py b/mart/attack/adversary_wrapper.py index c4b02953..b6f20e56 100644 --- a/mart/attack/adversary_wrapper.py +++ b/mart/attack/adversary_wrapper.py @@ -37,8 +37,8 @@ def __init__( def forward( self, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], model: torch.nn.Module | None = None, **kwargs, ): diff --git a/mart/attack/callbacks/base.py b/mart/attack/callbacks/base.py index 97541ecb..d9a701d0 100644 --- a/mart/attack/callbacks/base.py +++ b/mart/attack/callbacks/base.py @@ -24,8 +24,8 @@ def on_run_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], model: torch.nn.Module, **kwargs, ): @@ -35,8 +35,8 @@ def on_examine_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], model: torch.nn.Module, **kwargs, ): @@ -46,8 +46,8 @@ def on_examine_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], model: torch.nn.Module, **kwargs, ): @@ -57,8 +57,8 @@ def on_advance_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], model: torch.nn.Module, **kwargs, ): @@ -68,8 +68,8 @@ def on_advance_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], model: torch.nn.Module, **kwargs, ): @@ -79,8 +79,8 @@ def on_run_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], model: torch.nn.Module, **kwargs, ): diff --git a/mart/attack/composer.py b/mart/attack/composer.py index ddfdc45b..af8decff 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -15,17 +15,17 @@ class Composer(abc.ABC): def __call__( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor | list[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], **kwargs, - ) -> torch.Tensor | tuple: - if isinstance(perturbation, tuple): - input_adv = tuple( + ) -> torch.Tensor | list[torch.Tensor]: + if isinstance(perturbation, list): + input_adv = [ self.compose(perturbation_i, input=input_i, target=target_i) for perturbation_i, input_i, target_i in zip(perturbation, input, target) - ) + ] else: input_adv = self.compose(perturbation, input=input, target=target) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 4d4a1364..07b101ef 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -19,12 +19,12 @@ class ConstraintViolated(Exception): class Constraint(abc.ABC): def __call__( self, - input_adv: torch.Tensor | tuple, + input_adv: torch.Tensor | list[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], ) -> None: - if isinstance(input_adv, tuple): + if isinstance(input_adv, list): for input_adv_i, input_i, target_i in zip(input_adv, input, target): self.verify(input_adv_i, input=input_i, target=target_i) else: @@ -103,10 +103,10 @@ def __init__(self, constraints: dict[str, Constraint] | None = None) -> None: @torch.no_grad() def __call__( self, - input_adv: torch.Tensor | tuple, + input_adv: torch.Tensor | list[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], **kwargs, ) -> None: for constraint in self.constraints.values(): diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py index 42ab7e01..682867be 100644 --- a/mart/attack/perturber.py +++ b/mart/attack/perturber.py @@ -61,14 +61,14 @@ def __init__( self.perturbation = None - def configure_perturbation(self, input: torch.Tensor | tuple): + def configure_perturbation(self, input: torch.Tensor | list[torch.Tensor]): def create_and_initialize(inp): pert = torch.empty_like(inp, dtype=torch.float, requires_grad=True) self.initializer(pert) return pert - if isinstance(input, tuple): - self.perturbation = tuple(create_and_initialize(inp) for inp in input) + if isinstance(input, list): + self.perturbation = [create_and_initialize(inp) for inp in input] elif isinstance(input, dict): raise NotImplementedError else: @@ -81,9 +81,9 @@ def configure_optimizers(self): ) params = self.perturbation - if not isinstance(params, tuple): + if not isinstance(params, list): # FIXME: Should we treat the batch dimension as independent parameters? - params = (params,) + params = [params] return self.optimizer_fn(params) diff --git a/mart/attack/projector.py b/mart/attack/projector.py index 92391c67..c77fda3e 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -17,13 +17,13 @@ class Projector: @torch.no_grad() def __call__( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor | list[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], **kwargs, ) -> None: - if isinstance(perturbation, tuple): + if isinstance(perturbation, list): for perturbation_i, input_i, target_i in zip(perturbation, input, target): self.project(perturbation_i, input=input_i, target=target_i) else: @@ -31,10 +31,10 @@ def __call__( def project( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor | list[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], ) -> None: pass @@ -48,10 +48,10 @@ def __init__(self, projectors: list[Projector]): @torch.no_grad() def __call__( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor | list[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | list[torch.Tensor], + target: torch.Tensor | dict[str, Any] | list[Any], **kwargs, ) -> None: for projector in self.projectors: diff --git a/mart/datamodules/coco.py b/mart/datamodules/coco.py index 42ddcebb..23f7c4d5 100644 --- a/mart/datamodules/coco.py +++ b/mart/datamodules/coco.py @@ -8,6 +8,12 @@ from typing import Any, Callable, List, Optional import numpy as np +import torch +from torch.utils.data._utils.collate import ( # WHY ARE THESE PRIVATE?! + collate, + collate_tensor_fn, + default_collate_fn_map, +) from torchvision.datasets.coco import CocoDetection as CocoDetection_ from torchvision.datasets.folder import default_loader @@ -86,6 +92,27 @@ def __getitem__(self, index: int): return image, target_dict -# Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203 +def _collate_tensor_fn(batch, *, collate_fn_map=None): + """Handle the case when all elements in list are not the same shape. + + Instead of throwing an exception, we just leave them as a list of Tensors. + """ + + if not all([x.shape == batch[0].shape for x in batch]): + return list(batch) + + return collate_tensor_fn(batch, collate_fn_map=collate_fn_map) + + def collate_fn(batch): - return tuple(zip(*batch)) + collate_fn_map = default_collate_fn_map.copy() + collate_fn_map[torch.Tensor] = _collate_tensor_fn + + images, targets = collate(batch, collate_fn_map=collate_fn_map) + + # dict of lists to list of dicts for backwards compatibility + if isinstance(targets, dict): + targets = [dict(zip(targets.keys(), values)) for values in zip(*targets.values())] + + # FIXME: Ideally we would just return a dict with {"input": images, **targets} + return images, targets