From b32da2f97464d1c6208c1772b568ff6a51759e8c Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Fri, 24 Mar 2023 11:03:31 -0700
Subject: [PATCH 1/9] Convert from tuple of tensors to list of tensors

---
 mart/attack/adversary.py | 17 +++++++++--------
 mart/attack/perturber.py | 18 +++++++++---------
 mart/datamodules/coco.py |  4 ++--
 3 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py
index 50c78fe5..047ad870 100644
--- a/mart/attack/adversary.py
+++ b/mart/attack/adversary.py
@@ -68,7 +68,8 @@ def __init__(
         self.enforcer = enforcer
 
     @silent()
-    def forward(self, *, input: torch.Tensor | tuple, **batch):
+    def forward(self, *, input: torch.Tensor | list[torch.Tensor], **batch):
+        print("input =", input.__class__)
         # Adversary lives within a sequence of model. To signal the adversary should attack, one
         # must pass a model to attack when calling the adversary. Since we do not know where the
         # Adversary lives inside the model, we also need the remaining sequence to be able to
@@ -85,7 +86,7 @@ def forward(self, *, input: torch.Tensor | tuple, **batch):
 
         return input_adv
 
-    def _attack(self, input: torch.Tensor | tuple, **kwargs):
+    def _attack(self, input: torch.Tensor | list[torch.Tensor], **kwargs):
         batch = {"input": input, **kwargs}
 
         # Attack, aka fit a perturbation, for one epoch by cycling over the same input batch.
@@ -94,7 +95,7 @@ def _attack(self, input: torch.Tensor | tuple, **kwargs):
             attacker.fit_loop.max_epochs += 1
 
         attacker.fit(self.perturber, train_dataloaders=cycle([batch]))
 
-    def _initialize_attack(self, input: torch.Tensor | tuple):
+    def _initialize_attack(self, input: torch.Tensor | list[torch.Tensor]):
         # Configure perturber to use batch inputs
         self.perturber.configure_perturbation(input)
@@ -102,7 +103,7 @@ def _initialize_attack(self, input: torch.Tensor | tuple):
             return self.attacker
 
         # Convert torch.device to PL accelerator
-        device = input[0].device if isinstance(input, tuple) else input.device
+        device = input[0].device if isinstance(input, list) else input.device
 
         if device.type == "cuda":
             accelerator = "gpu"
@@ -120,13 +121,13 @@ def _initialize_attack(self, input: torch.Tensor | tuple):
 
     def _enforce(
         self,
-        input_adv: torch.Tensor | tuple,
+        input_adv: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         **kwargs,
     ):
-        if not isinstance(input_adv, tuple):
+        if not isinstance(input_adv, list):
             self.enforcer(input_adv, input=input, target=target)
         else:
             for inp_adv, inp, tar in zip(input_adv, input, target):
diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py
index a95e4932..a155bae7 100644
--- a/mart/attack/perturber.py
+++ b/mart/attack/perturber.py
@@ -60,16 +60,16 @@ def __init__(
 
         self.perturbation = None
 
-    def configure_perturbation(self, input: torch.Tensor | tuple):
+    def configure_perturbation(self, input: torch.Tensor | list[torch.Tensor]):
         def create_and_initialize(inp):
             pert = torch.empty_like(inp, dtype=torch.float, requires_grad=True)
             self.initializer(pert)
             return pert
 
-        if not isinstance(input, tuple):
+        if not isinstance(input, list):
             self.perturbation = create_and_initialize(input)
         else:
-            self.perturbation = tuple(create_and_initialize(inp) for inp in input)
+            self.perturbation = [create_and_initialize(inp) for inp in input]
 
     def configure_optimizers(self):
         if self.perturbation is None:
@@ -78,7 +78,7 @@ def configure_optimizers(self):
             )
 
         params = self.perturbation
-        if not isinstance(params, tuple):
+        if not isinstance(params, list):
             # FIXME: Should we treat the batch dimension as independent parameters?
             params = (params,)
 
@@ -123,8 +123,8 @@ def configure_gradient_clipping(
     def forward(
         self,
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         **kwargs,
     ):
         if self.perturbation is None:
@@ -136,12 +136,12 @@ def project_and_compose(pert, inp, tar):
             self.projector(pert, inp, tar)
             return self.composer(pert, input=inp, target=tar)
 
-        if not isinstance(self.perturbation, tuple):
+        if not isinstance(self.perturbation, list):
             input_adv = project_and_compose(self.perturbation, input, target)
         else:
-            input_adv = tuple(
+            input_adv = [
                 project_and_compose(pert, inp, tar)
                 for pert, inp, tar in zip(self.perturbation, input, target)
-            )
+            ]
 
         return input_adv
diff --git a/mart/datamodules/coco.py b/mart/datamodules/coco.py
index 42ddcebb..e7ed4770 100644
--- a/mart/datamodules/coco.py
+++ b/mart/datamodules/coco.py
@@ -86,6 +86,6 @@ def __getitem__(self, index: int):
         return image, target_dict
 
 
-# Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203
 def collate_fn(batch):
-    return tuple(zip(*batch))
+    # [(x0, y0), ..., (xN, yN)] -> ([x0, ..., xN], [y0, ..., yN])
+    return tuple(map(list, zip(*batch)))

From 1965591b2550fea58d71e2ab612033192b1ab016 Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Fri, 24 Mar 2023 11:09:42 -0700
Subject: [PATCH 2/9] More convert tuple of tensors to list of tensors

---
 mart/attack/adversary_in_art.py  | 12 +++++-------
 mart/attack/adversary_wrapper.py |  4 ++--
 mart/attack/callbacks/base.py    | 24 ++++++++++++------------
 mart/attack/composer.py          | 18 ++++++++----------
 mart/attack/enforcer.py          |  8 ++++----
 5 files changed, 31 insertions(+), 35 deletions(-)

diff --git a/mart/attack/adversary_in_art.py b/mart/attack/adversary_in_art.py
index 2a993349..1170340a 100644
--- a/mart/attack/adversary_in_art.py
+++ b/mart/attack/adversary_in_art.py
@@ -82,17 +82,17 @@ def convert_input_art_to_mart(self, x: numpy.ndarray):
             x (np.ndarray): NHWC, [0, 1]
 
         Returns:
-            tuple: a tuple of tensors in CHW, [0, 255].
+            list[torch.Tensor]: a list of tensors in CHW, [0, 255].
         """
         input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255
-        input = tuple(inp_ for inp_ in input)
+        input = [inp_ for inp_ in input]
         return input
 
-    def convert_input_mart_to_art(self, input: tuple):
+    def convert_input_mart_to_art(self, input: list[torch.Tensor]):
         """Convert MART input to the ART's format.
 
         Args:
-            input (tuple): a tuple of tensors in CHW, [0, 255].
+            input (list[torch.Tensor]): a list of tensors in CHW, [0, 255].
 
         Returns:
             np.ndarray: NHWC, [0, 1]
@@ -112,7 +112,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             y_patch_metadata (_type_): _description_
 
         Returns:
-            tuple: a tuple of target dictionaies.
+            list: a list of target dictionaries.
         """
         # Copy y to target, and convert ndarray to pytorch tensors accordingly.
         target = []
@@ -132,6 +132,4 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             target_i["file_name"] = f"{yi['image_id'][0]}.jpg"
             target.append(target_i)
 
-        target = tuple(target)
-
         return target
diff --git a/mart/attack/adversary_wrapper.py b/mart/attack/adversary_wrapper.py
index c4b02953..b6f20e56 100644
--- a/mart/attack/adversary_wrapper.py
+++ b/mart/attack/adversary_wrapper.py
@@ -37,8 +37,8 @@ def __init__(
 
     def forward(
         self,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         model: torch.nn.Module | None = None,
         **kwargs,
     ):
diff --git a/mart/attack/callbacks/base.py b/mart/attack/callbacks/base.py
index 97541ecb..d9a701d0 100644
--- a/mart/attack/callbacks/base.py
+++ b/mart/attack/callbacks/base.py
@@ -24,8 +24,8 @@ def on_run_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -35,8 +35,8 @@ def on_examine_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -46,8 +46,8 @@ def on_examine_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -57,8 +57,8 @@ def on_advance_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -68,8 +68,8 @@ def on_advance_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -79,8 +79,8 @@ def on_run_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         model: torch.nn.Module,
         **kwargs,
     ):
diff --git a/mart/attack/composer.py b/mart/attack/composer.py
index c3f1c738..dae5e222 100644
--- a/mart/attack/composer.py
+++ b/mart/attack/composer.py
@@ -18,11 +18,11 @@ class Composer(torch.nn.Module, abc.ABC):
 
     @abc.abstractclassmethod
     def forward(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
-    ) -> torch.Tensor | tuple:
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
+    ) -> torch.Tensor | list[torch.Tensor]:
         raise NotImplementedError
 
@@ -34,12 +34,12 @@ def __init__(self, composer: Composer):
 
     def forward(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         **kwargs,
-    ) -> torch.Tensor | tuple:
+    ) -> torch.Tensor | list[torch.Tensor]:
         output = []
         for input_i, target_i, perturbation_i in zip(input, target, perturbation):
@@ -48,8 +48,6 @@ def forward(
 
         if isinstance(input, torch.Tensor):
             output = torch.stack(output)
-        else:
-            output = tuple(output)
 
         return output
 
diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py
index 2a1a9c15..99537c5e 100644
--- a/mart/attack/enforcer.py
+++ b/mart/attack/enforcer.py
@@ -92,10 +92,10 @@ class BatchEnforcer(Enforcer):
     @torch.no_grad()
     def __call__(
         self,
-        input_adv: torch.Tensor | tuple,
+        input_adv: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
-    ) -> torch.Tensor | tuple:
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
+    ) -> torch.Tensor | list[torch.Tensor]:
         for input_adv_i, input_i, target_i in zip(input_adv, input, target):
             self._check_constraints(input_adv_i, input=input_i, target=target_i)

From aacda23b616267d18208e6e0fb989d4b1c953939 Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Mon, 27 Mar 2023 08:59:48 -0700
Subject: [PATCH 3/9] cleanup

---
 mart/attack/adversary.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py
index 798c43a9..cbf1ea7c 100644
--- a/mart/attack/adversary.py
+++ b/mart/attack/adversary.py
@@ -71,7 +71,6 @@ def __init__(
 
     @silent()
     def forward(self, *, input: torch.Tensor | list[torch.Tensor], **batch):
-        print("input =", input.__class__)
         # Adversary lives within a sequence of model. To signal the adversary should attack, one
         # must pass a model to attack when calling the adversary. Since we do not know where the
         # Adversary lives inside the model, we also need the remaining sequence to be able to

From b8fed3dbe5b21b847739619186710fffdd6c98a3 Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Mon, 27 Mar 2023 13:43:12 -0700
Subject: [PATCH 4/9] Override default collate function

---
 mart/datamodules/coco.py | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/mart/datamodules/coco.py b/mart/datamodules/coco.py
index e7ed4770..d8d2a0e6 100644
--- a/mart/datamodules/coco.py
+++ b/mart/datamodules/coco.py
@@ -8,6 +8,12 @@
 from typing import Any, Callable, List, Optional
 
 import numpy as np
+import torch
+from torch.utils.data._utils.collate import (  # WHY ARE THESE PRIVATE?!
+    collate,
+    collate_tensor_fn,
+    default_collate_fn_map,
+)
 from torchvision.datasets.coco import CocoDetection as CocoDetection_
 from torchvision.datasets.folder import default_loader
 
@@ -86,6 +86,20 @@ def __getitem__(self, index: int):
         return image, target_dict
 
 
+def _collate_tensor_fn(batch, *, collate_fn_map=None):
+    """Handle the case when not all elements in the list have the same shape.
+
+    Instead of throwing an exception, we just leave them as a list of Tensors.
+    """
+
+    if not all([x.shape == batch[0].shape for x in batch]):
+        return batch
+
+    return collate_tensor_fn(batch, collate_fn_map=collate_fn_map)
+
+
 def collate_fn(batch):
-    # [(x0, y0), ..., (xN, yN)] -> ([x0, ..., xN], [y0, ..., yN])
-    return tuple(map(list, zip(*batch)))
+    collate_fn_map = default_collate_fn_map.copy()
+    collate_fn_map[torch.Tensor] = _collate_tensor_fn
+
+    return collate(batch, collate_fn_map=collate_fn_map)

From c5b617f8993cee8c77491551d8afdd4cec66870e Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Mon, 27 Mar 2023 14:10:33 -0700
Subject: [PATCH 5/9] Force batch to be a list

---
 mart/datamodules/coco.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mart/datamodules/coco.py b/mart/datamodules/coco.py
index d8d2a0e6..c367159f 100644
--- a/mart/datamodules/coco.py
+++ b/mart/datamodules/coco.py
@@ -99,7 +99,7 @@ def _collate_tensor_fn(batch, *, collate_fn_map=None):
     """
 
     if not all([x.shape == batch[0].shape for x in batch]):
-        return batch
+        return list(batch)
 
     return collate_tensor_fn(batch, collate_fn_map=collate_fn_map)
 

From d1ab4b886b2f72c882a46b825d4338e07931a839 Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Mon, 27 Mar 2023 14:46:16 -0700
Subject: [PATCH 6/9] bugfix collate_fn

---
 mart/datamodules/coco.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/mart/datamodules/coco.py b/mart/datamodules/coco.py
index c367159f..23f7c4d5 100644
--- a/mart/datamodules/coco.py
+++ b/mart/datamodules/coco.py
@@ -108,4 +108,11 @@ def collate_fn(batch):
     collate_fn_map = default_collate_fn_map.copy()
     collate_fn_map[torch.Tensor] = _collate_tensor_fn
 
-    return collate(batch, collate_fn_map=collate_fn_map)
+    images, targets = collate(batch, collate_fn_map=collate_fn_map)
+
+    # dict of lists to list of dicts for backwards compatibility
+    if isinstance(targets, dict):
+        targets = [dict(zip(targets.keys(), values)) for values in zip(*targets.values())]
+
+    # FIXME: Ideally we would just return a dict with {"input": images, **targets}
+    return images, targets

From ae2e89c7a590ec0f3de560325397db8524597681 Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Mon, 27 Mar 2023 14:51:10 -0700
Subject: [PATCH 7/9] Raise proper errors for input dicts

---
 mart/attack/perturber.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py
index 2c367588..86424bd5 100644
--- a/mart/attack/perturber.py
+++ b/mart/attack/perturber.py
@@ -67,10 +67,12 @@ def create_and_initialize(inp):
             self.initializer(pert)
             return pert
 
-        if not isinstance(input, list):
-            self.perturbation = create_and_initialize(input)
-        else:
+        if isinstance(input, list):
             self.perturbation = [create_and_initialize(inp) for inp in input]
+        elif isinstance(input, dict):
+            raise NotImplementedError
+        else:
+            self.perturbation = create_and_initialize(input)
 
     def configure_optimizers(self):
         if self.perturbation is None:
@@ -79,6 +81,7 @@ def configure_optimizers(self):
         )
 
         params = self.perturbation
+        # FIXME: Figure out how to handle perturbation for dict inputs
         if not isinstance(params, list):
             # FIXME: Should we treat the batch dimension as independent parameters?
             params = (params,)
@@ -138,12 +141,14 @@ def project_and_compose(pert, inp, tar):
             self.projector(pert, inp, tar)
             return self.composer(pert, input=inp, target=tar)
 
-        if not isinstance(self.perturbation, list):
-            input_adv = project_and_compose(self.perturbation, input, target)
-        else:
+        if isinstance(self.perturbation, list):
             input_adv = [
                 project_and_compose(pert, inp, tar)
                 for pert, inp, tar in zip(self.perturbation, input, target)
             ]
+        elif isinstance(self.perturbation, dict):
+            raise NotImplementedError
+        else:
+            input_adv = project_and_compose(self.perturbation, input, target)
 
         return input_adv

From e552a4a426d551599817cb526a9cc5bfdcaf2b8f Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Tue, 28 Mar 2023 11:47:38 -0700
Subject: [PATCH 8/9] Convert enforcer and projector to use lists

---
 mart/attack/enforcer.py  | 14 +++++++-------
 mart/attack/projector.py | 20 ++++++++++----------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py
index 4d4a1364..07b101ef 100644
--- a/mart/attack/enforcer.py
+++ b/mart/attack/enforcer.py
@@ -19,12 +19,12 @@ class ConstraintViolated(Exception):
 class Constraint(abc.ABC):
     def __call__(
         self,
-        input_adv: torch.Tensor | tuple,
+        input_adv: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
     ) -> None:
-        if isinstance(input_adv, tuple):
+        if isinstance(input_adv, list):
             for input_adv_i, input_i, target_i in zip(input_adv, input, target):
                 self.verify(input_adv_i, input=input_i, target=target_i)
         else:
@@ -103,10 +103,10 @@ def __init__(self, constraints: dict[str, Constraint] | None = None) -> None:
     @torch.no_grad()
     def __call__(
         self,
-        input_adv: torch.Tensor | tuple,
+        input_adv: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         **kwargs,
     ) -> None:
         for constraint in self.constraints.values():
diff --git a/mart/attack/projector.py b/mart/attack/projector.py
index 92391c67..c77fda3e 100644
--- a/mart/attack/projector.py
+++ b/mart/attack/projector.py
@@ -17,13 +17,13 @@ class Projector:
     @torch.no_grad()
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         **kwargs,
     ) -> None:
-        if isinstance(perturbation, tuple):
+        if isinstance(perturbation, list):
             for perturbation_i, input_i, target_i in zip(perturbation, input, target):
                 self.project(perturbation_i, input=input_i, target=target_i)
         else:
@@ -31,10 +31,10 @@ def __call__(
 
     def project(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
     ) -> None:
         pass
 
@@ -48,10 +48,10 @@ def __init__(self, projectors: list[Projector]):
     @torch.no_grad()
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | list[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | list[torch.Tensor],
+        target: torch.Tensor | dict[str, Any] | list[Any],
         **kwargs,
     ) -> None:
         for projector in self.projectors:

From 9298300de8d3e0b3add19dc157b2feca20b551ac Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Tue, 28 Mar 2023 11:49:55 -0700
Subject: [PATCH 9/9] Gracefully fail when input is a dict

---
 mart/attack/perturber.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py
index a8606819..682867be 100644
--- a/mart/attack/perturber.py
+++ b/mart/attack/perturber.py
@@ -67,10 +67,12 @@ def create_and_initialize(inp):
             self.initializer(pert)
             return pert
 
-        if not isinstance(input, list):
-            self.perturbation = create_and_initialize(input)
-        else:
+        if isinstance(input, list):
             self.perturbation = [create_and_initialize(inp) for inp in input]
+        elif isinstance(input, dict):
+            raise NotImplementedError
+        else:
+            self.perturbation = create_and_initialize(input)
 
     def configure_optimizers(self):
         if self.perturbation is None:
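
The series' end state for data loading is worth spelling out: collate_fn now stacks a batch into a single tensor only when every sample has the same shape; otherwise the model receives a plain list of per-sample tensors, and detection targets are regrouped from a dict of lists into a list of dicts. Below is a minimal sketch of that behavior, assuming the patched mart.datamodules.coco is importable and a PyTorch release (1.13 or later) whose private torch.utils.data._utils.collate module exposes collate, collate_tensor_fn, and default_collate_fn_map; the shapes and target keys are illustrative, not taken from the patches.

import torch

from mart.datamodules.coco import collate_fn

# Two COCO-style samples whose images (and box counts) differ in shape.
batch = [
    (torch.rand(3, 480, 640), {"boxes": torch.rand(2, 4), "labels": torch.tensor([1, 2])}),
    (torch.rand(3, 600, 800), {"boxes": torch.rand(1, 4), "labels": torch.tensor([3])}),
]

images, targets = collate_fn(batch)

# Shapes differ, so _collate_tensor_fn leaves the images as a list of CHW
# tensors instead of stacking them into one NCHW tensor.
assert isinstance(images, list) and images[0].shape == (3, 480, 640)

# The dict of lists produced by collate is regrouped into one dict per sample.
assert isinstance(targets, list) and set(targets[0]) == {"boxes", "labels"}

Keeping ragged batches as lists is what lets Perturber.configure_perturbation allocate one perturbation per sample with torch.empty_like, while dict inputs now fail fast with NotImplementedError instead of being mishandled silently.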