From 56014e65d85ca88550ed90e8b7e6f2fa39609df7 Mon Sep 17 00:00:00 2001
From: Cory Cornelius
Date: Fri, 14 Apr 2023 14:25:36 -0700
Subject: [PATCH 01/36] Replace tuple with Iterable[torch.Tensor]

---
 mart/attack/adversary_in_art.py | 12 ++++---
 mart/attack/adversary_wrapper.py | 11 +++---
 mart/attack/callbacks/base.py | 26 +++++++-------
 mart/attack/composer.py | 27 +++++++++------
 mart/attack/enforcer.py | 59 ++++++++++++++------------------
 mart/attack/initializer.py | 41 ++++++++++++----------
 mart/attack/projector.py | 56 +++++++++++++++++++-----------
 7 files changed, 129 insertions(+), 103 deletions(-)

diff --git a/mart/attack/adversary_in_art.py b/mart/attack/adversary_in_art.py
index 2a993349..d48f669c 100644
--- a/mart/attack/adversary_in_art.py
+++ b/mart/attack/adversary_in_art.py
@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #

-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional

 import hydra
 import numpy
@@ -82,17 +82,18 @@ def convert_input_art_to_mart(self, x: numpy.ndarray):
             x (np.ndarray): NHWC, [0, 1]

         Returns:
-            tuple: a tuple of tensors in CHW, [0, 255].
+            Iterable[torch.Tensor]: an Iterable of tensors in CHW, [0, 255].
         """
         input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255
+        # FIXME: replace tuple with whatever input's type is
         input = tuple(inp_ for inp_ in input)
         return input

-    def convert_input_mart_to_art(self, input: tuple):
+    def convert_input_mart_to_art(self, input: Iterable[torch.Tensor]):
         """Convert MART input to the ART's format.

         Args:
-            input (tuple): a tuple of tensors in CHW, [0, 255].
+            input (Iterable[torch.Tensor]): an Iterable of tensors in CHW, [0, 255].

         Returns:
             np.ndarray: NHWC, [0, 1]
@@ -112,7 +113,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             y_patch_metadata (_type_): _description_

         Returns:
-            tuple: a tuple of target dictionaies.
+            Iterable[dict[str, Any]]: an Iterable of target dictionaries.
         """
         # Copy y to target, and convert ndarray to pytorch tensors accordingly.
         target = []
@@ -132,6 +133,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             target_i["file_name"] = f"{yi['image_id'][0]}.jpg"
             target.append(target_i)

+        # FIXME: replace tuple with input type?
target = tuple(target) return target diff --git a/mart/attack/adversary_wrapper.py b/mart/attack/adversary_wrapper.py index c4b02953..a40ee644 100644 --- a/mart/attack/adversary_wrapper.py +++ b/mart/attack/adversary_wrapper.py @@ -6,10 +6,13 @@ from __future__ import annotations -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Iterable import torch +if TYPE_CHECKING: + from .enforcer import Enforcer + __all__ = ["NormalizedAdversaryAdapter"] @@ -22,7 +25,7 @@ class NormalizedAdversaryAdapter(torch.nn.Module): def __init__( self, adversary: Callable[[Callable], Callable], - enforcer: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None], + enforcer: Enforcer, ): """ @@ -37,8 +40,8 @@ def __init__( def forward( self, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module | None = None, **kwargs, ): diff --git a/mart/attack/callbacks/base.py b/mart/attack/callbacks/base.py index 97541ecb..a982aa8e 100644 --- a/mart/attack/callbacks/base.py +++ b/mart/attack/callbacks/base.py @@ -7,7 +7,7 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterable import torch @@ -24,8 +24,8 @@ def on_run_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -35,8 +35,8 @@ def on_examine_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -46,8 +46,8 @@ def on_examine_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -57,8 +57,8 @@ def on_advance_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -68,8 +68,8 @@ def on_advance_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -79,8 +79,8 @@ def on_run_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 5bc4edb7..ef8f3417 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -7,7 +7,7 @@ from __future__ import annotations import abc -from typing import Any +from typing import Any, Iterable import torch @@ -15,21 +15,28 @@ class Composer(abc.ABC): def __call__( self, - perturbation: torch.Tensor | tuple, + 
perturbation: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], **kwargs, - ) -> torch.Tensor | tuple: - if isinstance(perturbation, tuple): - input_adv = tuple( + ) -> torch.Tensor | Iterable[torch.Tensor]: + if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor): + return self.compose(perturbation, input=input, target=target) + + elif ( + isinstance(perturbation, Iterable) + and isinstance(input, Iterable) # noqa: W503 + and isinstance(target, Iterable) # noqa: W503 + ): + # FIXME: replace tuple with whatever input's type is + return tuple( self.compose(perturbation_i, input=input_i, target=target_i) for perturbation_i, input_i, target_i in zip(perturbation, input, target) ) - else: - input_adv = self.compose(perturbation, input=input, target=target) - return input_adv + else: + raise NotImplementedError @abc.abstractmethod def compose( diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index babc44e6..95e6716b 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -7,7 +7,7 @@ from __future__ import annotations import abc -from typing import Any +from typing import Any, Iterable import torch @@ -95,45 +95,38 @@ def verify(self, input_adv, *, input, target): class Enforcer: - def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None: - self.modality_constraints = modality_constraints + def __init__(self, constraints: dict[str, Constraint]) -> None: + self.constraints = list(constraints.values()) # intentionally ignore keys @torch.no_grad() - def _enforce( + def __call__( self, - input_adv: torch.Tensor, + input_adv: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], - modality: str, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + **kwargs, ): - for constraint in self.modality_constraints[modality].values(): - constraint(input_adv, input=input, target=target) + if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor): + self.enforce(input_adv, input=input, target=target) + + elif ( + isinstance(input_adv, Iterable) + and isinstance(input, Iterable) # noqa: W503 + and isinstance(target, Iterable) # noqa: W503 + ): + [ + self.enforce(input_adv_i, input=input_i, target=target_i) + for input_adv_i, input_i, target_i in zip(input_adv, input, target) + ] - def __call__( + @torch.no_grad() + def enforce( self, - input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor], + input_adv: torch.Tensor, *, - input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor], + input: torch.Tensor, target: torch.Tensor | dict[str, Any], - modality: str = "constraints", - **kwargs, ): - assert type(input_adv) == type(input) - - if isinstance(input_adv, torch.Tensor): - # Finally we can verify constraints on tensor, per its modality. - # Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities. - self._enforce(input_adv, input=input, target=target, modality=modality) - elif isinstance(input_adv, dict): - # The dict input has modalities specified in keys, passing them recursively. 
- for modality in input_adv: - self(input_adv[modality], input=input[modality], target=target, modality=modality) - elif isinstance(input_adv, (list, tuple)): - # We assume a modality-dictionary only contains tensors, but not list/tuple. - assert modality == "constraints" - # The list or tuple input is a collection of sub-input and sub-target. - for input_adv_i, input_i, target_i in zip(input_adv, input, target): - self(input_adv_i, input=input_i, target=target_i, modality=modality) - else: - raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.") + for constraint in self.constraints: + constraint(input_adv, input=input, target=target) diff --git a/mart/attack/initializer.py b/mart/attack/initializer.py index cd05c6c6..d66bcf9f 100644 --- a/mart/attack/initializer.py +++ b/mart/attack/initializer.py @@ -4,52 +4,57 @@ # SPDX-License-Identifier: BSD-3-Clause # -import abc -from typing import Optional, Union +from __future__ import annotations -import torch +from typing import Iterable -__all__ = ["Initializer"] +import torch -class Initializer(abc.ABC): +class Initializer: """Initializer base class.""" @torch.no_grad() - @abc.abstractmethod - def __call__(self, perturbation: torch.Tensor) -> None: + def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + [self.initialize_(parameter) for parameter in parameters] + + @torch.no_grad() + def initialize_(self, parameter: torch.Tensor) -> None: pass class Constant(Initializer): - def __init__(self, constant: Optional[Union[int, float]] = 0): + def __init__(self, constant: int | float = 0): self.constant = constant @torch.no_grad() - def __call__(self, perturbation: torch.Tensor) -> None: - torch.nn.init.constant_(perturbation, self.constant) + def initialize_(self, parameter: torch.Tensor) -> None: + torch.nn.init.constant_(parameter, self.constant) class Uniform(Initializer): - def __init__(self, min: Union[int, float], max: Union[int, float]): + def __init__(self, min: int | float, max: int | float): self.min = min self.max = max @torch.no_grad() - def __call__(self, perturbation: torch.Tensor) -> None: - torch.nn.init.uniform_(perturbation, self.min, self.max) + def initialize_(self, parameter: torch.Tensor) -> None: + torch.nn.init.uniform_(parameter, self.min, self.max) class UniformLp(Initializer): - def __init__(self, eps: Union[int, float], p: Optional[Union[int, float]] = torch.inf): + def __init__(self, eps: int | float, p: int | float = torch.inf): self.eps = eps self.p = p @torch.no_grad() - def __call__(self, perturbation: torch.Tensor) -> None: - torch.nn.init.uniform_(perturbation, -self.eps, self.eps) + def initialize_(self, parameter: torch.Tensor) -> None: + torch.nn.init.uniform_(parameter, -self.eps, self.eps) # TODO: make sure the first dim is the batch dim. if self.p is not torch.inf: # We don't do tensor.renorm_() because the first dim is not the batch dim. 
-            pert_norm = perturbation.norm(p=self.p)
-            perturbation.mul_(self.eps / pert_norm)
+            pert_norm = parameter.norm(p=self.p)
+            parameter.mul_(self.eps / pert_norm)

diff --git a/mart/attack/projector.py b/mart/attack/projector.py
index 92391c67..095d2601 100644
--- a/mart/attack/projector.py
+++ b/mart/attack/projector.py
@@ -6,7 +6,7 @@

 from __future__ import annotations

-from typing import Any
+from typing import Any, Iterable

 import torch

@@ -17,24 +17,35 @@ class Projector:
     @torch.no_grad()
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         **kwargs,
     ) -> None:
-        if isinstance(perturbation, tuple):
-            for perturbation_i, input_i, target_i in zip(perturbation, input, target):
-                self.project(perturbation_i, input=input_i, target=target_i)
+        if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
+            self.project_(perturbation, input=input, target=target)
+
+        elif (
+            isinstance(perturbation, Iterable)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
+            [
+                self.project_(perturbation_i, input=input_i, target=target_i)
+                for perturbation_i, input_i, target_i in zip(perturbation, input, target)
+            ]
+
         else:
-            self.project(perturbation, input=input, target=target)
+            raise NotImplementedError

-    def project(
+    @torch.no_grad()
+    def project_(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
     ) -> None:
         pass

@@ -48,10 +59,10 @@ def __init__(self, projectors: list[Projector]):
     @torch.no_grad()
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
         **kwargs,
     ) -> None:
         for projector in self.projectors:
@@ -70,7 +81,8 @@ def __init__(self, quantize: bool = False, min: int | float = 0, max: int | floa
         self.min = min
         self.max = max

-    def project(self, perturbation, *, input, target):
+    @torch.no_grad()
+    def project_(self, perturbation, *, input, target):
         if self.quantize:
             perturbation.round_()
         perturbation.clamp_(self.min, self.max)
@@ -92,7 +104,8 @@ def __init__(self, quantize: bool = False, min: int | float = 0, max: int | floa
         self.min = min
         self.max = max

-    def project(self, perturbation, *, input, target):
+    @torch.no_grad()
+    def project_(self, perturbation, *, input, target):
         if self.quantize:
             perturbation.round_()
         perturbation.clamp_(self.min - input, self.max - input)
@@ -117,7 +130,8 @@ def __init__(self, eps: int | float, p: int | float = torch.inf):
         self.p = p
         self.eps = eps

-    def project(self, perturbation, *, input, target):
+    @torch.no_grad()
+    def project_(self, perturbation, *, input, target):
         pert_norm = perturbation.norm(p=self.p)
         if pert_norm > self.eps:
             # We only upper-bound the norm.
@@ -133,7 +147,8 @@ def __init__(self, eps: int | float, min: int | float = 0, max: int | float = 25 self.min = min self.max = max - def project(self, perturbation, *, input, target): + @torch.no_grad() + def project_(self, perturbation, *, input, target): eps_min = (input - self.eps).clamp(self.min, self.max) - input eps_max = (input + self.eps).clamp(self.min, self.max) - input @@ -141,7 +156,8 @@ def project(self, perturbation, *, input, target): class Mask(Projector): - def project(self, perturbation, *, input, target): + @torch.no_grad() + def project_(self, perturbation, *, input, target): perturbation.mul_(target["perturbable_mask"]) def __repr__(self): From 1c47cc049a7130802521080ee800f4c7ead55dc2 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 14 Apr 2023 14:41:15 -0700 Subject: [PATCH 02/36] Fix tests --- tests/test_enforcer.py | 54 ++++++++++++++++++++--------------------- tests/test_projector.py | 2 +- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/test_enforcer.py b/tests/test_enforcer.py index 2c56b3ad..e67b1034 100644 --- a/tests/test_enforcer.py +++ b/tests/test_enforcer.py @@ -97,30 +97,30 @@ def test_enforcer_non_modality(): enforcer((input_adv,), input=(input,), target=(target,)) -def test_enforcer_modality(): - # Assume a rgb modality. - enforcer = Enforcer(rgb={"range": Range(min=0, max=255)}) - - input = torch.tensor([0, 0, 0]) - perturbation = torch.tensor([0, 128, 255]) - input_adv = input + perturbation - target = None - - # Dictionary input. - enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) - # List of dictionary input. - enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) - # Tuple of dictionary input. - enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) - - perturbation = torch.tensor([0, -1, 255]) - input_adv = input + perturbation - - with pytest.raises(ConstraintViolated): - enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) - - with pytest.raises(ConstraintViolated): - enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) - - with pytest.raises(ConstraintViolated): - enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) +# def test_enforcer_modality(): +# # Assume a rgb modality. +# enforcer = Enforcer(rgb={"range": Range(min=0, max=255)}) +# +# input = torch.tensor([0, 0, 0]) +# perturbation = torch.tensor([0, 128, 255]) +# input_adv = input + perturbation +# target = None +# +# # Dictionary input. +# enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) +# # List of dictionary input. +# enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) +# # Tuple of dictionary input. 
+# enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) +# +# perturbation = torch.tensor([0, -1, 255]) +# input_adv = input + perturbation +# +# with pytest.raises(ConstraintViolated): +# enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) +# +# with pytest.raises(ConstraintViolated): +# enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) +# +# with pytest.raises(ConstraintViolated): +# enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) diff --git a/tests/test_projector.py b/tests/test_projector.py index a397a98c..19cb5c44 100644 --- a/tests/test_projector.py +++ b/tests/test_projector.py @@ -154,7 +154,7 @@ def test_compose(input_data, target_data): ] compose = Compose(projectors) - tensor = Mock() + tensor = Mock(spec=torch.Tensor) tensor.norm.return_value = 10 compose(tensor, input=input_data, target=target_data) From 70cc36ac2c3719c56b8510bc711c9c0766f4fdd7 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 14 Apr 2023 14:49:45 -0700 Subject: [PATCH 03/36] Cleanup --- mart/attack/enforcer.py | 4 +--- mart/attack/projector.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 95e6716b..1c2347c2 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -115,10 +115,8 @@ def __call__( and isinstance(input, Iterable) # noqa: W503 and isinstance(target, Iterable) # noqa: W503 ): - [ + for input_adv_i, input_i, target_i in zip(input_adv, input, target): self.enforce(input_adv_i, input=input_i, target=target_i) - for input_adv_i, input_i, target_i in zip(input_adv, input, target) - ] @torch.no_grad() def enforce( diff --git a/mart/attack/projector.py b/mart/attack/projector.py index 095d2601..9f7c77ac 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -31,10 +31,8 @@ def __call__( and isinstance(input, Iterable) # noqa: W503 and isinstance(target, Iterable) # noqa: W503 ): - [ + for perturbation_i, input_i, target_i in zip(perturbation, input, target): self.project_(perturbation_i, input=input_i, target=target_i) - for perturbation_i, input_i, target_i in zip(perturbation, input, target) - ] else: raise NotImplementedError From 53ee7f4f7a9b9dc1ffc2a17b599af4b53f5813f3 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 14 Apr 2023 14:57:58 -0700 Subject: [PATCH 04/36] Make GradientModifier accept Iterable[torch.Tensor] --- mart/attack/gradient_modifier.py | 36 ++++++++++++++------------------ 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/mart/attack/gradient_modifier.py b/mart/attack/gradient_modifier.py index dd680a95..b2882574 100644 --- a/mart/attack/gradient_modifier.py +++ b/mart/attack/gradient_modifier.py @@ -6,7 +6,6 @@ from __future__ import annotations -import abc from typing import Iterable import torch @@ -14,36 +13,33 @@ __all__ = ["GradientModifier"] -class GradientModifier(abc.ABC): +class GradientModifier: """Gradient modifier base class.""" - def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: - pass - - -class Sign(GradientModifier): def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: if isinstance(parameters, torch.Tensor): parameters = [parameters] - parameters = [p for p in parameters if p.grad is not None] + [self.modify_(parameter) for parameter in parameters] + + @torch.no_grad() + def modify_(self, parameter: torch.Tensor) -> None: + pass + - for p in parameters: - p.grad.detach().sign_() +class 
Sign(GradientModifier):
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        parameter.grad.sign_()


 class LpNormalizer(GradientModifier):
     """Scale gradients by a certain L-p norm."""

     def __init__(self, p: int | float):
-        self.p = p
-
-    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-
-        parameters = [p for p in parameters if p.grad is not None]
+        self.p = float(p)

-        for p in parameters:
-            p_norm = torch.norm(p.grad.detach(), p=self.p)
-            p.grad.detach().div_(p_norm)
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        p_norm = torch.norm(parameter.grad.detach(), p=self.p)
+        parameter.grad.detach().div_(p_norm)

From 3f399fa74f8e6f89e7e829afab28aa83d637786c Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 18 Apr 2023 09:46:55 -0700
Subject: [PATCH 05/36] Pull the modality_dispatch code from PR 115.

---
 mart/utils/modality_dispatch.py | 67 +++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 mart/utils/modality_dispatch.py

diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py
new file mode 100644
index 00000000..6ea6206e
--- /dev/null
+++ b/mart/utils/modality_dispatch.py
@@ -0,0 +1,67 @@
+#
+# Copyright (C) 2022 Intel Corporation
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+
+from __future__ import annotations
+
+from itertools import cycle
+from typing import Any, Callable
+
+import torch
+from torch import Tensor
+
+__all__ = ["modality_dispatch"]
+
+
+def modality_dispatch(
+    modality_func: Callable | dict[str, Callable],
+    data: Tensor | tuple | list[Tensor] | dict[str, Tensor],
+    *,
+    input: Tensor | tuple | list[Tensor] | dict[str, Tensor],
+    target: torch.Tensor | dict[str, Any] | list[dict[str, Any]] | None,
+    modality: str = "default",
+):
+    """Recursively dispatch data and input/target to functions of the same modality.
+
+    The function returns an object that is homomorphic to input and data.
+    """
+
+    assert type(data) == type(input)
+    if target is None:
+        # Make target zip well with input.
+        target = cycle([None])
+
+    if isinstance(input, torch.Tensor):
+        if isinstance(modality_func, dict):
+            # A dictionary of Callable indexed by modality.
+            return modality_func[modality](data, input=input, target=target)
+        else:
+            # A Callable with modality=? as a keyword argument.
+            return modality_func(data, input=input, target=target, modality=modality)
+    elif isinstance(input, dict):
+        # The dict input has modalities specified in keys, passing them recursively.
+        output = {}
+        for modality in input.keys():
+            output[modality] = modality_dispatch(
+                modality_func,
+                data[modality],
+                input=input[modality],
+                target=target,
+                modality=modality,
+            )
+        return output
+    elif isinstance(input, (list, tuple)):
+        # The list or tuple input is a collection of sub-input and sub-target.
+        output = []
+        for data_i, input_i, target_i in zip(data, input, target):
+            output_i = modality_dispatch(
+                modality_func, data_i, input=input_i, target=target_i, modality=modality
+            )
+            output.append(output_i)
+        if isinstance(input, tuple):
+            output = tuple(output)
+        return output
+    else:
+        raise ValueError(f"Unsupported data type of input: {type(input)}.")

From fe8786434c5983dd09f12d730cda3f81f3e86eae Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 18 Apr 2023 09:49:41 -0700
Subject: [PATCH 06/36] Add a constant DEFAULT_MODALITY.
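
Callers can now import the constant instead of hard-coding the "default" string.
A minimal sketch of the intended call pattern with a plain tensor input, using
the modality_dispatch() signature from the previous commit (the compose function
below is illustrative, not part of this patch):

    import torch

    from mart.utils.modality_dispatch import DEFAULT_MODALITY, modality_dispatch

    def compose(data, *, input, target, modality):
        # The dispatched Callable receives modality as a keyword argument.
        return input + data

    input = torch.zeros(3, 32, 32)
    perturbation = torch.ones_like(input)
    input_adv = modality_dispatch(
        compose, perturbation, input=input, target=None, modality=DEFAULT_MODALITY
    )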
--- mart/utils/modality_dispatch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index 6ea6206e..c3a8b547 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -14,6 +14,8 @@ __all__ = ["modality_dispatch"] +DEFAULT_MODALITY = "default" + def modality_dispatch( modality_func: Callable | dict[str, Callable], @@ -21,7 +23,7 @@ def modality_dispatch( *, input: Tensor | tuple | list[Tensor] | dict[str, Tensor], target: torch.Tensor | dict[str, Any] | list[dict[str, Any]] | None, - modality: str = "default", + modality: str = DEFAULT_MODALITY, ): """Recursively dispatch data and input/target to functions of the same modality. From ad1f3721d426abfa50eec774cf50e8a760a6b016 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 18 Apr 2023 10:07:48 -0700 Subject: [PATCH 07/36] Add modality aware enforcer. --- mart/attack/enforcer.py | 30 +++++++++++------------ tests/test_enforcer.py | 54 ++++++++++++++++++++--------------------- 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 1c2347c2..8344fb3b 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -11,6 +11,8 @@ import torch +from ..utils.modality_dispatch import modality_dispatch + __all__ = ["Enforcer"] @@ -95,36 +97,32 @@ def verify(self, input_adv, *, input, target): class Enforcer: - def __init__(self, constraints: dict[str, Constraint]) -> None: - self.constraints = list(constraints.values()) # intentionally ignore keys + def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None: + self.modality_constraints = modality_constraints @torch.no_grad() def __call__( self, - input_adv: torch.Tensor | Iterable[torch.Tensor], + input_adv: torch.Tensor | Iterable[torch.Tensor | dict[str, torch.Tensor]], *, - input: torch.Tensor | Iterable[torch.Tensor], + input: torch.Tensor | Iterable[torch.Tensor | dict[str, torch.Tensor]], target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], **kwargs, ): - if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor): - self.enforce(input_adv, input=input, target=target) - - elif ( - isinstance(input_adv, Iterable) - and isinstance(input, Iterable) # noqa: W503 - and isinstance(target, Iterable) # noqa: W503 - ): - for input_adv_i, input_i, target_i in zip(input_adv, input, target): - self.enforce(input_adv_i, input=input_i, target=target_i) + # The default modality is set to "constraints", so that it is backward compatible with existing configs. + modality_dispatch( + self._enforce, input_adv, input=input, target=target, modality="constraints" + ) @torch.no_grad() - def enforce( + def _enforce( self, input_adv: torch.Tensor, *, input: torch.Tensor, target: torch.Tensor | dict[str, Any], + modality: str, ): - for constraint in self.constraints: + # intentionally ignore keys after modality. + for constraint in self.modality_constraints[modality].values(): constraint(input_adv, input=input, target=target) diff --git a/tests/test_enforcer.py b/tests/test_enforcer.py index e67b1034..2c56b3ad 100644 --- a/tests/test_enforcer.py +++ b/tests/test_enforcer.py @@ -97,30 +97,30 @@ def test_enforcer_non_modality(): enforcer((input_adv,), input=(input,), target=(target,)) -# def test_enforcer_modality(): -# # Assume a rgb modality. 
-# enforcer = Enforcer(rgb={"range": Range(min=0, max=255)}) -# -# input = torch.tensor([0, 0, 0]) -# perturbation = torch.tensor([0, 128, 255]) -# input_adv = input + perturbation -# target = None -# -# # Dictionary input. -# enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) -# # List of dictionary input. -# enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) -# # Tuple of dictionary input. -# enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) -# -# perturbation = torch.tensor([0, -1, 255]) -# input_adv = input + perturbation -# -# with pytest.raises(ConstraintViolated): -# enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) -# -# with pytest.raises(ConstraintViolated): -# enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) -# -# with pytest.raises(ConstraintViolated): -# enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) +def test_enforcer_modality(): + # Assume a rgb modality. + enforcer = Enforcer(rgb={"range": Range(min=0, max=255)}) + + input = torch.tensor([0, 0, 0]) + perturbation = torch.tensor([0, 128, 255]) + input_adv = input + perturbation + target = None + + # Dictionary input. + enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) + # List of dictionary input. + enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) + # Tuple of dictionary input. + enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) + + perturbation = torch.tensor([0, -1, 255]) + input_adv = input + perturbation + + with pytest.raises(ConstraintViolated): + enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) + + with pytest.raises(ConstraintViolated): + enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) + + with pytest.raises(ConstraintViolated): + enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) From 5acc632416c11b64d4f55a0832d677cb66b74bec Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 18 Apr 2023 10:53:10 -0700 Subject: [PATCH 08/36] Type annotation. --- mart/utils/modality_dispatch.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index c3a8b547..1b998d53 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -7,10 +7,9 @@ from __future__ import annotations from itertools import cycle -from typing import Any, Callable +from typing import Any, Callable, Iterable import torch -from torch import Tensor __all__ = ["modality_dispatch"] @@ -19,10 +18,10 @@ def modality_dispatch( modality_func: Callable | dict[str, Callable], - data: Tensor | tuple | list[Tensor] | dict[str, Tensor], + data: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any], *, - input: Tensor | tuple | list[Tensor] | dict[str, Tensor], - target: torch.Tensor | dict[str, Any] | list[dict[str, Any]] | None, + input: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any], + target: torch.Tensor | Iterable[Any] | None, modality: str = DEFAULT_MODALITY, ): """Recursively dispatch data and input/target to functions of the same modality. From 886143694fce699a6a25e2cae1ac3ba88b9017d6 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 18 Apr 2023 10:55:00 -0700 Subject: [PATCH 09/36] Type annotation. 
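
The annotations now spell out the shapes the enforcer accepts: a single tensor,
an Iterable of tensors, or an Iterable of modality dictionaries. A minimal
sketch of the first two call patterns, mirroring test_enforcer_non_modality
(values are illustrative and the import path is assumed):

    import torch

    from mart.attack.enforcer import Enforcer, Range

    # Non-modality constraints live under the default "constraints" key.
    enforcer = Enforcer(constraints={"range": Range(min=0, max=255)})

    input = torch.tensor([0, 0, 0])
    input_adv = input + torch.tensor([0, 128, 255])

    enforcer(input_adv, input=input, target=None)  # a single tensor
    enforcer((input_adv,), input=(input,), target=(None,))  # an iterable of tensors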
--- mart/attack/enforcer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 8344fb3b..6dec5b11 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -103,10 +103,10 @@ def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> @torch.no_grad() def __call__( self, - input_adv: torch.Tensor | Iterable[torch.Tensor | dict[str, torch.Tensor]], + input_adv: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], *, - input: torch.Tensor | Iterable[torch.Tensor | dict[str, torch.Tensor]], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], **kwargs, ): # The default modality is set to "constraints", so that it is backward compatible with existing configs. From e46151dd748a076e8bdcfbfb6d8eabe4bae7b627 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 12 Jun 2023 11:50:49 -0700 Subject: [PATCH 10/36] Make a single-level if-else in modality_dispatch(). --- mart/utils/modality_dispatch.py | 41 ++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index 1b998d53..e0be28c6 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -29,19 +29,25 @@ def modality_dispatch( The function returns an object that is homomorphic to input and data. """ - assert type(data) == type(input) if target is None: # Make target zips well with input. target = cycle([None]) - if isinstance(input, torch.Tensor): - if isinstance(modality_func, dict): - # A dictionary of Callable indexed by modality. - return modality_func[modality](data, input=input, target=target) - else: - # A Callable with modality=? as a keyword argument. - return modality_func(data, input=input, target=target, modality=modality) - elif isinstance(input, dict): + if ( + isinstance(input, torch.Tensor) + and isinstance(data, torch.Tensor) # noqa: W503 + and isinstance(modality_func, dict) # noqa: W503 + ): + # A dictionary of Callable indexed by modality. + return modality_func[modality](data, input=input, target=target) + elif ( + isinstance(input, torch.Tensor) + and isinstance(data, torch.Tensor) # noqa: W503 + and isinstance(modality_func, Callable) # noqa: W503 + ): + # A Callable with modality=? as a keyword argument. + return modality_func(data, input=input, target=target, modality=modality) + elif isinstance(input, dict) and isinstance(data, dict): # The dict input has modalities specified in keys, passing them recursively. output = {} for modality in input.keys(): @@ -53,7 +59,7 @@ def modality_dispatch( modality=modality, ) return output - elif isinstance(input, (list, tuple)): + elif isinstance(input, (list, tuple)) and isinstance(data, (list, tuple)): # The list or tuple input is a collection of sub-input and sub-target. output = [] for data_i, input_i, target_i in zip(data, input, target): @@ -64,5 +70,18 @@ def modality_dispatch( if isinstance(input, tuple): output = tuple(output) return output + elif isinstance(input, (list, tuple)) and isinstance(data, torch.Tensor): + # Data is shared for all input, e.g. universal perturbation. 
+ output = [] + for input_i, target_i in zip(input, target): + output_i = modality_dispatch( + modality_func, data, input=input_i, target=target_i, modality=modality + ) + output.append(output_i) + if isinstance(input, tuple): + output = tuple(output) + return output else: - raise ValueError(f"Unsupported data type of input: {type(input)}.") + raise ValueError( + f"Unsupported data type combination: type(input)={type(input)} and type(data)={type(data)}." + ) From 7bb33216e8d5ace56265add2f4eceae835f38ae0 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 12 Jun 2023 15:55:41 -0700 Subject: [PATCH 11/36] Remove unused keys early. --- mart/attack/enforcer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 14650d02..8c04fa78 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -98,7 +98,12 @@ def verify(self, input_adv, *, input, target): class Enforcer: def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None: - self.modality_constraints = modality_constraints + self.modality_constraints = {} + + for modality, constraints in modality_constraints.items(): + # Intentionally ignore keys after modality. + # The keys are there for combining constraints easily in Hydra. + self.modality_constraints[modality] = constraints.values() @torch.no_grad() def __call__( @@ -123,6 +128,6 @@ def _enforce( target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], modality: str, ): - # intentionally ignore keys after modality. - for constraint in self.modality_constraints[modality].values(): + + for constraint in self.modality_constraints[modality]: constraint(input_adv, input=input, target=target) From 20ffada9a3159e00e41f09cc8939a561bb770ea9 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 20 Jun 2023 11:54:21 -0700 Subject: [PATCH 12/36] Make it fancy with singledispatch. --- mart/attack/enforcer.py | 6 +- mart/utils/modality_dispatch.py | 171 ++++++++++++++++++++++---------- 2 files changed, 123 insertions(+), 54 deletions(-) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 8c04fa78..2e46e442 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -116,7 +116,11 @@ def __call__( ): # The default modality is set to "constraints", so that it is backward compatible with existing configs. modality_dispatch( - self._enforce, input_adv, input=input, target=target, modality="constraints" + input, + data=input_adv, + target=target, + modality_func=self._enforce, + modality="constraints", ) @torch.no_grad() diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index e0be28c6..082c128f 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -6,6 +6,7 @@ from __future__ import annotations +import functools from itertools import cycle from typing import Any, Callable, Iterable @@ -16,72 +17,136 @@ DEFAULT_MODALITY = "default" +# We make input the first non-keyword argument for singledispatch to work. 
+@functools.singledispatch def modality_dispatch( - modality_func: Callable | dict[str, Callable], - data: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any], - *, input: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any], + *, + data: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any], target: torch.Tensor | Iterable[Any] | None, + modality_func: Callable | dict[str, Callable], modality: str = DEFAULT_MODALITY, ): """Recursively dispatch data and input/target to functions of the same modality. - The function returns an object that is homomorphic to input and data. + The function returns an object that is homomorphic to input. """ - if target is None: - # Make target zips well with input. - target = cycle([None]) + raise ValueError( + f"Unsupported data type combination: type(input)={type(input)} and type(data)={type(data)}." + ) + + # if ( + # isinstance(input, torch.Tensor) + # and isinstance(data, torch.Tensor) # noqa: W503 + # and isinstance(modality_func, dict) # noqa: W503 + # ): + # # A dictionary of Callable indexed by modality. + # return modality_func[modality](data, input=input, target=target) + # elif ( + # isinstance(input, torch.Tensor) + # and isinstance(data, torch.Tensor) # noqa: W503 + # and isinstance(modality_func, Callable) # noqa: W503 + # ): + # # A Callable with modality=? as a keyword argument. + # return modality_func(data, input=input, target=target, modality=modality) + # elif isinstance(input, dict): + # # The dict input has modalities specified in keys, passing them recursively. + # output = {} + # for modality in input.keys(): + # output[modality] = modality_dispatch( + # modality_func, + # data[modality], + # input=input[modality], + # target=target, + # modality=modality, + # ) + # return output + # elif isinstance(input, (list, tuple)): + # # The list or tuple input is a collection of sub-input and sub-target. + # if not isinstance(target, (list, tuple)): + # # Make target zip well with input. + # target = cycle([target]) + # if not isinstance(data, (list, tuple)): + # # Data is shared for all input, e.g. universal perturbation. + # # Make data zip well with input. + # data = cycle([data]) + + # output = [] + # for data_i, input_i, target_i in zip(data, input, target): + # output_i = modality_dispatch( + # modality_func, data_i, input=input_i, target=target_i, modality=modality + # ) + # output.append(output_i) + # if isinstance(input, tuple): + # output = tuple(output) + # return output + # else: + # raise ValueError( + # f"Unsupported data type combination: type(input)={type(input)} and type(data)={type(data)}." + # ) + - if ( - isinstance(input, torch.Tensor) - and isinstance(data, torch.Tensor) # noqa: W503 - and isinstance(modality_func, dict) # noqa: W503 - ): +@modality_dispatch.register +def _(input: torch.Tensor, *, data, target, modality, modality_func): + if isinstance(modality_func, dict): # A dictionary of Callable indexed by modality. return modality_func[modality](data, input=input, target=target) - elif ( - isinstance(input, torch.Tensor) - and isinstance(data, torch.Tensor) # noqa: W503 - and isinstance(modality_func, Callable) # noqa: W503 - ): + elif isinstance(modality_func, Callable): # A Callable with modality=? as a keyword argument. return modality_func(data, input=input, target=target, modality=modality) - elif isinstance(input, dict) and isinstance(data, dict): - # The dict input has modalities specified in keys, passing them recursively. 
-    output = {}
+    output = {}
     for modality in input.keys():
-        output[modality] = modality_dispatch(
-            modality_func,
-            data[modality],
-            input=input[modality],
-            target=target,
-            modality=modality,
-        )
-    return output
+        output[modality] = modality_dispatch(
+            input[modality],
+            data=data[modality],
+            target=target,
+            modality=modality,
+            modality_func=modality_func,
+        )
+    return output
+
+
+@modality_dispatch.register
+def _(input: list, *, data, target, modality, modality_func):
+    # The list input implies a collection of sub-input and sub-target.
+    if not isinstance(target, (list, tuple)):
+        # Make target zip well with input.
+        target = cycle([target])
+    if not isinstance(data, (list, tuple)):
+        # Make data zip well with input.
+        # Data is shared for all input, e.g. universal perturbation.
+        data = cycle([data])
+
+    output = []
+    for data_i, input_i, target_i in zip(data, input, target):
+        output_i = modality_dispatch(
+            input_i,
+            data=data_i,
+            target=target_i,
+            modality=modality,
+            modality_func=modality_func,
+        )
+        output.append(output_i)
+
+    return output
+
+
+@modality_dispatch.register
+def _(input: tuple, *, data, target, modality, modality_func):
+    # The tuple input is similar to the list input.
+    output = modality_dispatch(
+        list(input),
+        data=data,
+        target=target,
+        modality=modality,
+        modality_func=modality_func,
+    )
+    # Make the output a tuple, the same as input.
+    output = tuple(output)
+    return output

From e77236f333b20035d865f0ee2d88a80a2d4d3f24 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 20 Jun 2023 14:20:56 -0700
Subject: [PATCH 13/36] Rename back to Enforcer.enforce().

---
 mart/attack/enforcer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py
index 2e46e442..48310d30 100644
--- a/mart/attack/enforcer.py
+++ b/mart/attack/enforcer.py
@@ -119,12 +119,12 @@ def __call__(
             input,
             data=input_adv,
             target=target,
-            modality_func=self._enforce,
+            modality_func=self.enforce,
             modality="constraints",
         )

     @torch.no_grad()
-    def _enforce(
+    def enforce(
         self,
         input_adv: torch.Tensor | Iterable[torch.Tensor],
         *,

From 3737c9876be1c3416b72d414f804a127736f318c Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 20 Jun 2023 14:28:45 -0700
Subject: [PATCH 14/36] Comment.
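
functools.singledispatch selects the implementation from the runtime type of
the first positional argument, which is why input has to stay positional while
everything else is keyword-only. A standard-library sketch of the mechanism
(illustrative, not code in this repository):

    import functools

    @functools.singledispatch
    def visit(input):
        raise ValueError(f"Unsupported data type of input: {type(input)}.")

    @visit.register
    def _(input: dict):
        return "dict"

    @visit.register
    def _(input: list):
        return "list"

    visit({"rgb": 0})  # returns "dict"
    visit([0, 1])      # returns "list"

Note that dispatch follows the registered type and its subclasses, so a tuple
input does not fall back to the list implementation; it needs its own register.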
--- mart/utils/modality_dispatch.py | 54 +-------------------------------- 1 file changed, 1 insertion(+), 53 deletions(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index 082c128f..dde7d6e3 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -32,59 +32,7 @@ def modality_dispatch( The function returns an object that is homomorphic to input. """ - raise ValueError( - f"Unsupported data type combination: type(input)={type(input)} and type(data)={type(data)}." - ) - - # if ( - # isinstance(input, torch.Tensor) - # and isinstance(data, torch.Tensor) # noqa: W503 - # and isinstance(modality_func, dict) # noqa: W503 - # ): - # # A dictionary of Callable indexed by modality. - # return modality_func[modality](data, input=input, target=target) - # elif ( - # isinstance(input, torch.Tensor) - # and isinstance(data, torch.Tensor) # noqa: W503 - # and isinstance(modality_func, Callable) # noqa: W503 - # ): - # # A Callable with modality=? as a keyword argument. - # return modality_func(data, input=input, target=target, modality=modality) - # elif isinstance(input, dict): - # # The dict input has modalities specified in keys, passing them recursively. - # output = {} - # for modality in input.keys(): - # output[modality] = modality_dispatch( - # modality_func, - # data[modality], - # input=input[modality], - # target=target, - # modality=modality, - # ) - # return output - # elif isinstance(input, (list, tuple)): - # # The list or tuple input is a collection of sub-input and sub-target. - # if not isinstance(target, (list, tuple)): - # # Make target zip well with input. - # target = cycle([target]) - # if not isinstance(data, (list, tuple)): - # # Data is shared for all input, e.g. universal perturbation. - # # Make data zip well with input. - # data = cycle([data]) - - # output = [] - # for data_i, input_i, target_i in zip(data, input, target): - # output_i = modality_dispatch( - # modality_func, data_i, input=input_i, target=target_i, modality=modality - # ) - # output.append(output_i) - # if isinstance(input, tuple): - # output = tuple(output) - # return output - # else: - # raise ValueError( - # f"Unsupported data type combination: type(input)={type(input)} and type(data)={type(data)}." - # ) + raise ValueError(f"Unsupported data type of input: type(input)={type(input)}.") @modality_dispatch.register From 3833c34ca68f63edc11d2a1665338a04cfd4fbfd Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 20 Jun 2023 14:30:30 -0700 Subject: [PATCH 15/36] Comment. --- mart/utils/modality_dispatch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index dde7d6e3..93379045 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -17,7 +17,6 @@ DEFAULT_MODALITY = "default" -# We make input the first non-keyword argument for singledispatch to work. @functools.singledispatch def modality_dispatch( input: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any], @@ -29,7 +28,8 @@ def modality_dispatch( ): """Recursively dispatch data and input/target to functions of the same modality. - The function returns an object that is homomorphic to input. + The function returns an object that is homomorphic to input. We make input the first non- + keyword argument for singledispatch to work. 
""" raise ValueError(f"Unsupported data type of input: type(input)={type(input)}.") @@ -37,6 +37,7 @@ def modality_dispatch( @modality_dispatch.register def _(input: torch.Tensor, *, data, target, modality, modality_func): + # Take action when input is a tensor. if isinstance(modality_func, dict): # A dictionary of Callable indexed by modality. return modality_func[modality](data, input=input, target=target) From f08510bf33bfd28480e850dbcdcbdc8868521ae3 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 20 Jun 2023 20:22:35 -0700 Subject: [PATCH 16/36] Loosen data type requirement. --- mart/utils/modality_dispatch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index 93379045..4fff257f 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -64,11 +64,12 @@ def _(input: dict, *, data, target, modality, modality_func): @modality_dispatch.register def _(input: list, *, data, target, modality, modality_func): # The list input implies a collection of sub-input and sub-target. - if not isinstance(target, (list, tuple)): + if not isinstance(target, Iterable): # Make target zip well with input. target = cycle([target]) - if not isinstance(data, (list, tuple)): + if not isinstance(data, Iterable): # Make data zip well with input. + # Besides list and tuple, data could be ParameterList too. # Data is shared for all input, e.g. universal perturbation. data = cycle([data]) From 403861029d5afb2282c68fea82fc47ccdeb1e6c3 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 20 Jun 2023 20:29:05 -0700 Subject: [PATCH 17/36] Modality-aware adversary. --- mart/attack/adversary.py | 22 +++++++++++----- mart/attack/initializer.py | 3 ++- mart/attack/perturber.py | 23 +++++++++++++--- mart/configs/attack/enforcer/default.yaml | 3 ++- .../object_detection_rgb_mask_adversary.yaml | 26 +++++++++++++++++++ mart/optim/optimizer.py | 21 +++++++++++++-- 6 files changed, 84 insertions(+), 14 deletions(-) create mode 100644 mart/configs/attack/object_detection_rgb_mask_adversary.yaml diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 8c5513d2..a5c2e558 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -16,6 +16,7 @@ from mart.utils import silent from ..optim import OptimizerFactory +from ..utils.modality_dispatch import DEFAULT_MODALITY, modality_dispatch if TYPE_CHECKING: from .composer import Composer @@ -149,10 +150,11 @@ def configure_gradient_clipping( if self.gradient_modifier: for group in optimizer.param_groups: - self.gradient_modifier(group["params"]) + modality = group["modality"] if "modality" in group else DEFAULT_MODALITY + self.gradient_modifier[modality](group["params"]) @silent() - def forward(self, *, model=None, sequence=None, **batch): + def forward(self, *, model=None, sequence=None, input, target, **batch): batch["model"] = model batch["sequence"] = sequence @@ -161,14 +163,20 @@ def forward(self, *, model=None, sequence=None, **batch): # Adversary lives inside the model, we also need the remaining sequence to be able to # get a loss. 
if model and sequence: - self._attack(**batch) - - perturbation = self.perturber(**batch) - input_adv = self.composer(perturbation, **batch) + self._attack(input=input, target=target, **batch) + + perturbation = self.perturber(input=input, target=target, **batch) + input_adv = modality_dispatch( + input, + data=perturbation, + target=target, + modality_func=self.composer, + modality=DEFAULT_MODALITY, + ) # Enforce constraints after the attack optimization ends. if model and sequence: - self.enforcer(input_adv, **batch) + self.enforcer(input_adv, input=input, target=target, **batch) return input_adv diff --git a/mart/attack/initializer.py b/mart/attack/initializer.py index 9b38e6a1..99a51cd7 100644 --- a/mart/attack/initializer.py +++ b/mart/attack/initializer.py @@ -21,7 +21,8 @@ class Initializer: """Initializer base class.""" @torch.no_grad() - def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: + def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor], **kwargs) -> None: + # Allow **kwargs to work with modality_dispatch(). if isinstance(parameters, torch.Tensor): parameters = [parameters] diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py index 29df3059..bd67a000 100644 --- a/mart/attack/perturber.py +++ b/mart/attack/perturber.py @@ -11,6 +11,7 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException +from ..utils.modality_dispatch import DEFAULT_MODALITY, modality_dispatch from .projector import Projector if TYPE_CHECKING: @@ -65,6 +66,10 @@ def create_from_tensor(tensor): return torch.nn.Parameter( torch.empty_like(tensor, dtype=torch.float, requires_grad=True) ) + elif isinstance(tensor, dict): + return torch.nn.ParameterDict( + {modality: create_from_tensor(t) for modality, t in tensor.items()} + ) elif isinstance(tensor, Iterable): return torch.nn.ParameterList([create_from_tensor(t) for t in tensor]) else: @@ -76,7 +81,13 @@ def create_from_tensor(tensor): self.perturbation = create_from_tensor(input) # Always (re)initialize perturbation. - self.initializer_(self.perturbation) + modality_dispatch( + input, + data=self.perturbation, + target=None, + modality_func=self.initializer_, + modality=DEFAULT_MODALITY, + ) def named_parameters(self, *args, **kwargs): if self.perturbation is None: @@ -90,12 +101,18 @@ def parameters(self, *args, **kwargs): return super().parameters(*args, **kwargs) - def forward(self, **batch): + def forward(self, *, input, target, **batch): if self.perturbation is None: raise MisconfigurationException( "You need to call the configure_perturbation before forward." ) - self.projector_(self.perturbation, **batch) + modality_dispatch( + input, + data=self.perturbation, + target=target, + modality_func=self.projector_, + modality=DEFAULT_MODALITY, + ) return self.perturbation diff --git a/mart/configs/attack/enforcer/default.yaml b/mart/configs/attack/enforcer/default.yaml index 46fc0bb1..a59d8d3e 100644 --- a/mart/configs/attack/enforcer/default.yaml +++ b/mart/configs/attack/enforcer/default.yaml @@ -1,2 +1,3 @@ _target_: mart.attack.Enforcer -constraints: ??? +# FIXME: Hydra does not detect modality-aware constraints defined as sub-components. +# constraints: ??? 
diff --git a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml
new file mode 100644
index 00000000..bea84349
--- /dev/null
+++ b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml
@@ -0,0 +1,26 @@
+defaults:
+  - adversary
+  - perturber: default
+  - perturber/initializer@perturber.initializer.rgb: constant
+  - perturber/projector@perturber.projector.rgb: mask_range
+  - composer@composer.rgb: overlay
+  - /optimizer@optimizer: sgd
+  - gain: rcnn_training_loss
+  - gradient_modifier@gradient_modifier.rgb: sign
+  - objective: zero_ap
+  - enforcer: default
+  - enforcer/constraints@enforcer.rgb: [mask, pixel_range]
+
+# Make a 5-step attack for demonstration purposes.
+optimizer:
+  # Though we only use modality-aware hyper-params, the config requires a value for optimizer.lr.
+  lr: 0
+  rgb:
+    lr: 55
+
+max_iters: 5
+
+perturber:
+  initializer:
+    rgb:
+      constant: 127

diff --git a/mart/optim/optimizer.py b/mart/optim/optimizer.py
index 3cc57131..94de38a0 100644
--- a/mart/optim/optimizer.py
+++ b/mart/optim/optimizer.py
@@ -5,6 +5,7 @@
 #

 import logging
+from collections import defaultdict

 logger = logging.getLogger(__name__)

@@ -32,6 +33,7 @@ def __call__(self, module):
         bias_params = []
         norm_params = []
         weight_params = []
+        modality_params = defaultdict(list)

         for param_name, param in module.named_parameters():
             if not param.requires_grad:
@@ -42,7 +44,12 @@ def __call__(self, module):
             _, param_module = next(filter(lambda nm: nm[0] == module_name, module.named_modules()))
             module_kind = param_module.__class__.__name__

-            if "Norm" in module_kind:
+            # FIXME: Other modules may also use ParameterDict.
+            if module_kind == "ParameterDict":
+                # Identify modality-aware parameters for adversary.
+                modality = param_name.split(".")[-1]
+                modality_params[modality].append(param)
+            elif "Norm" in module_kind:
                 assert len(param.shape) == 1
                 norm_params.append(param)
             elif isinstance(param, torch.nn.UninitializedParameter):
@@ -53,8 +60,18 @@ def __call__(self, module):
             else:
                 # Assume weights
                 weight_params.append(param)

-        # Set decay for bias and norm parameters
         params = []
+
+        # Set modality-aware params.
+        if len(modality_params) > 0:
+            for modality, param in modality_params.items():
+                # Take note of the modality for the gradient modifier later.
+                # Add modality-specific optim params.
+                params.append(
+                    {"params": param, "modality": modality} | self.kwargs.pop(modality, {})
+                )
+
+        # Set decay for bias and norm parameters
         if len(weight_params) > 0:
             params.append({"params": weight_params})  # use default weight decay
         if len(bias_params) > 0:

From 7f47ab6989aef572fb2e12b75e227ca8321397cf Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 20 Jun 2023 20:49:03 -0700
Subject: [PATCH 18/36] Backward compatible with existing non-modality configs
 of adversary.

---
 mart/attack/adversary.py | 8 ++++++++
 mart/attack/perturber.py | 11 ++++++++++-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py
index a5c2e558..ba4affcc 100644
--- a/mart/attack/adversary.py
+++ b/mart/attack/adversary.py
@@ -68,6 +68,14 @@ def __init__(
         # Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint.
         # and DDP won't try to get the uninitialized parameters of perturbation.
         self._perturber = [perturber]
+
+        # Modality-specific objects.
+        # Backward compatibility, in case modality is unknown and not given in input.
+        if not isinstance(gradient_modifier, dict):
+            gradient_modifier = {DEFAULT_MODALITY: gradient_modifier}
+        if not isinstance(composer, dict):
+            composer = {DEFAULT_MODALITY: composer}
+
         self.composer = composer
         self.optimizer = optimizer
         if not isinstance(self.optimizer, OptimizerFactory):
diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py
index bd67a000..cf5d7f5c 100644
--- a/mart/attack/perturber.py
+++ b/mart/attack/perturber.py
@@ -35,8 +35,17 @@ def __init__(
         """
         super().__init__()
 
+        projector = projector or Projector()
+
+        # Modality-specific objects.
+        # Backward compatibility, in case modality is unknown, and not given in input.
+        if not isinstance(initializer, dict):
+            initializer = {DEFAULT_MODALITY: initializer}
+        if not isinstance(projector, dict):
+            projector = {DEFAULT_MODALITY: projector}
+
         self.initializer_ = initializer
-        self.projector_ = projector or Projector()
+        self.projector_ = projector
 
         self.perturbation = None
From a121c80382ff64abdce0ee1871f46f50133db27d Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Tue, 20 Jun 2023 20:49:22 -0700
Subject: [PATCH 19/36] Fix test.

---
 tests/test_adversary.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_adversary.py b/tests/test_adversary.py
index bf9ff9c0..edf7afe1 100644
--- a/tests/test_adversary.py
+++ b/tests/test_adversary.py
@@ -218,7 +218,8 @@ def gain(logits):
     )
 
     # Perturbation initialized as zero.
-    def initializer(x):
+    # Initializer needs to absorb **kwargs from modality_dispatch().
+    def initializer(x, **kwargs):
         torch.nn.init.constant_(x, 0)
 
     perturber = Perturber(
From c50826673d008828d3382fcf66f74df2c15c6ade Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 21 Jun 2023 15:42:05 -0700
Subject: [PATCH 20/36] Type annotation for modality-aware components.

---
 mart/attack/adversary.py | 4 ++--
 mart/attack/perturber.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py
index ba4affcc..941461d8 100644
--- a/mart/attack/adversary.py
+++ b/mart/attack/adversary.py
@@ -36,10 +36,10 @@ def __init__(
         self,
         *,
         perturber: Perturber,
-        composer: Composer,
+        composer: Composer | dict[str, Composer],
         optimizer: OptimizerFactory | Callable[[Any], torch.optim.Optimizer],
         gain: Gain,
-        gradient_modifier: GradientModifier | None = None,
+        gradient_modifier: GradientModifier | dict[str, GradientModifier] | None = None,
         objective: Objective | None = None,
         enforcer: Enforcer | None = None,
         attacker: pl.Trainer | None = None,
diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py
index cf5d7f5c..cb6cf4d3 100644
--- a/mart/attack/perturber.py
+++ b/mart/attack/perturber.py
@@ -24,8 +24,8 @@ class Perturber(torch.nn.Module):
     def __init__(
         self,
         *,
-        initializer: Initializer,
-        projector: Projector | None = None,
+        initializer: Initializer | dict[str, Initializer],
+        projector: Projector | dict[str, Projector] | None = None,
     ):
         """_summary_
 
From 267695445a04790bf73ed328b98e98a9c6f03ca3 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 21 Jun 2023 15:57:32 -0700
Subject: [PATCH 21/36] Make a new name ModalityParameterDict for modality-aware parameters.
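A minimal sketch of the intent (editorial, not part of this diff; shapes are illustrative): the subclass adds no behavior over torch.nn.ParameterDict, but its distinct class name lets the optimizer factory single out modality-aware perturbation parameters without capturing every ParameterDict in a model.

    import torch

    from mart.utils.modality_dispatch import ModalityParameterDict

    perturbation = ModalityParameterDict(
        {
            "rgb": torch.nn.Parameter(torch.zeros(3, 8, 8)),
            "depth": torch.nn.Parameter(torch.zeros(1, 8, 8)),
        }
    )
    # The optimizer factory keys on this exact class name.
    assert perturbation.__class__.__name__ == "ModalityParameterDict"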
--- mart/attack/perturber.py | 8 ++++++-- mart/optim/optimizer.py | 3 +-- mart/utils/modality_dispatch.py | 8 ++++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py index cb6cf4d3..7b42baca 100644 --- a/mart/attack/perturber.py +++ b/mart/attack/perturber.py @@ -11,7 +11,11 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from ..utils.modality_dispatch import DEFAULT_MODALITY, modality_dispatch +from ..utils.modality_dispatch import ( + DEFAULT_MODALITY, + ModalityParameterDict, + modality_dispatch, +) from .projector import Projector if TYPE_CHECKING: @@ -76,7 +80,7 @@ def create_from_tensor(tensor): torch.empty_like(tensor, dtype=torch.float, requires_grad=True) ) elif isinstance(tensor, dict): - return torch.nn.ParameterDict( + return ModalityParameterDict( {modality: create_from_tensor(t) for modality, t in tensor.items()} ) elif isinstance(tensor, Iterable): diff --git a/mart/optim/optimizer.py b/mart/optim/optimizer.py index 94de38a0..2d96e59a 100644 --- a/mart/optim/optimizer.py +++ b/mart/optim/optimizer.py @@ -44,8 +44,7 @@ def __call__(self, module): _, param_module = next(filter(lambda nm: nm[0] == module_name, module.named_modules())) module_kind = param_module.__class__.__name__ - # FIXME: Other modules may also use ParameterDict. - if module_kind == "ParameterDict": + if module_kind == "ModalityParameterDict": # Identify modality-aware parameters for adversary. modality = param_name.split(".")[-1] modality_params[modality].append(param) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index 4fff257f..a94bc4da 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -12,8 +12,6 @@ import torch -__all__ = ["modality_dispatch"] - DEFAULT_MODALITY = "default" @@ -100,3 +98,9 @@ def _(input: tuple, *, data, target, modality, modality_func): # Make the output a tuple, the same as input. output = tuple(output) return output + + +class ModalityParameterDict(torch.nn.ParameterDict): + """Get a new name so we know when parameters are associated with modality.""" + + pass From e38627610f0b0ae8f919a117ebb7bd7fdab1e488 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 21 Jun 2023 16:06:09 -0700 Subject: [PATCH 22/36] Fix function arguments and type annotations. --- mart/utils/modality_dispatch.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index a94bc4da..f9654df1 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -17,10 +17,10 @@ @functools.singledispatch def modality_dispatch( - input: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any], + input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], *, - data: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any], - target: torch.Tensor | Iterable[Any] | None, + data: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]] | None, modality_func: Callable | dict[str, Callable], modality: str = DEFAULT_MODALITY, ): @@ -34,7 +34,7 @@ def modality_dispatch( @modality_dispatch.register -def _(input: torch.Tensor, *, data, target, modality, modality_func): +def _(input: torch.Tensor, *, data, target, modality_func, modality): # Take action when input is a tensor. 
if isinstance(modality_func, dict): # A dictionary of Callable indexed by modality. @@ -45,7 +45,7 @@ def _(input: torch.Tensor, *, data, target, modality, modality_func): @modality_dispatch.register -def _(input: dict, *, data, target, modality, modality_func): +def _(input: dict, *, data, target, modality_func, modality): # The dict input has modalities specified in keys, passing them recursively. output = {} for modality in input.keys(): @@ -53,14 +53,14 @@ def _(input: dict, *, data, target, modality, modality_func): input[modality], data=data[modality], target=target, - modality=modality, modality_func=modality_func, + modality=modality, ) return output @modality_dispatch.register -def _(input: list, *, data, target, modality, modality_func): +def _(input: list, *, data, target, modality_func, modality): # The list input implies a collection of sub-input and sub-target. if not isinstance(target, Iterable): # Make target zip well with input. @@ -77,8 +77,8 @@ def _(input: list, *, data, target, modality, modality_func): input_i, data=data_i, target=target_i, - modality=modality, modality_func=modality_func, + modality=modality, ) output.append(output_i) @@ -86,14 +86,14 @@ def _(input: list, *, data, target, modality, modality_func): @modality_dispatch.register -def _(input: tuple, *, data, target, modality, modality_func): +def _(input: tuple, *, data, target, modality_func, modality): # The tuple input is similar with the list input. output = modality_dispatch( list(input), data=data, target=target, - modality=modality, modality_func=modality_func, + modality=modality, ) # Make the output a tuple, the same as input. output = tuple(output) From 7f100d94bc935d08868abfeb254733df2805dcea Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 21 Jun 2023 16:09:12 -0700 Subject: [PATCH 23/36] Make modality an optional keyword argument. --- mart/utils/modality_dispatch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index f9654df1..4f0757e8 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -34,7 +34,7 @@ def modality_dispatch( @modality_dispatch.register -def _(input: torch.Tensor, *, data, target, modality_func, modality): +def _(input: torch.Tensor, *, data, target, modality_func, modality=DEFAULT_MODALITY): # Take action when input is a tensor. if isinstance(modality_func, dict): # A dictionary of Callable indexed by modality. @@ -45,7 +45,7 @@ def _(input: torch.Tensor, *, data, target, modality_func, modality): @modality_dispatch.register -def _(input: dict, *, data, target, modality_func, modality): +def _(input: dict, *, data, target, modality_func, modality=DEFAULT_MODALITY): # The dict input has modalities specified in keys, passing them recursively. output = {} for modality in input.keys(): @@ -60,7 +60,7 @@ def _(input: dict, *, data, target, modality_func, modality): @modality_dispatch.register -def _(input: list, *, data, target, modality_func, modality): +def _(input: list, *, data, target, modality_func, modality=DEFAULT_MODALITY): # The list input implies a collection of sub-input and sub-target. if not isinstance(target, Iterable): # Make target zip well with input. 
@@ -86,7 +86,7 @@ def _(input: list, *, data, target, modality_func, modality):
 
 
 @modality_dispatch.register
-def _(input: tuple, *, data, target, modality_func, modality):
+def _(input: tuple, *, data, target, modality_func, modality=DEFAULT_MODALITY):
     # The tuple input is similar with the list input.
     output = modality_dispatch(
         list(input),
From c5239077582a95ccaefd9267a12e3ecf55311366 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 21 Jun 2023 16:10:28 -0700
Subject: [PATCH 24/36] Fix type annotation.

---
 mart/utils/modality_dispatch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py
index 4f0757e8..1262ba8e 100644
--- a/mart/utils/modality_dispatch.py
+++ b/mart/utils/modality_dispatch.py
@@ -20,7 +20,7 @@ def modality_dispatch(
     input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
     *,
     data: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
-    target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]] | None,
+    target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
     modality_func: Callable | dict[str, Callable],
     modality: str = DEFAULT_MODALITY,
 ):
From 8866c53f07f4d0c65f3fb53fa0077fee373da5ca Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 21 Jun 2023 16:12:46 -0700
Subject: [PATCH 25/36] Fix type annotation.

---
 mart/attack/enforcer.py | 2 +-
 mart/utils/modality_dispatch.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py
index 48310d30..c7d64121 100644
--- a/mart/attack/enforcer.py
+++ b/mart/attack/enforcer.py
@@ -111,7 +111,7 @@ def __call__(
         input_adv: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
         *,
         input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
-        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor | str]],
         **kwargs,
     ):
         # The default modality is set to "constraints", so that it is backward compatible with existing configs.
diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py
index 1262ba8e..1020df23 100644
--- a/mart/utils/modality_dispatch.py
+++ b/mart/utils/modality_dispatch.py
@@ -20,7 +20,7 @@ def modality_dispatch(
     input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
     *,
     data: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
-    target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
+    target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor | str]],
     modality_func: Callable | dict[str, Callable],
     modality: str = DEFAULT_MODALITY,
 ):
From 5611b3a35cd1ad851e70b12d0ef53f3ef2901121 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 21 Jun 2023 16:24:31 -0700
Subject: [PATCH 26/36] Simplify composer, initializer and projector with modality_dispatch.
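Editorial sketch of the simplified contract (the additive composer below is an assumed stand-in for the real Composer subclasses): per-modality components now only ever see single tensors, and all structure handling lives in modality_dispatch.

    import torch

    from mart.utils.modality_dispatch import modality_dispatch

    input = {"rgb": torch.rand(3, 8, 8)}
    perturbation = {"rgb": torch.full((3, 8, 8), 0.5)}

    def additive(perturbation, *, input, target):
        # Sees plain tensors only; no isinstance() branching needed anymore.
        return input + perturbation

    input_adv = modality_dispatch(
        input, data=perturbation, target=None, modality_func={"rgb": additive}
    )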
---
 mart/attack/composer.py | 23 ++++-------------------
 mart/attack/initializer.py | 7 ++-----
 mart/attack/projector.py | 20 ++++----------------
 3 files changed, 10 insertions(+), 40 deletions(-)

diff --git a/mart/attack/composer.py b/mart/attack/composer.py
index 6b40950a..e3962160 100644
--- a/mart/attack/composer.py
+++ b/mart/attack/composer.py
@@ -15,28 +15,13 @@ class Composer(abc.ABC):
     def __call__(
         self,
-        perturbation: torch.Tensor | Iterable[torch.Tensor],
+        perturbation: torch.Tensor,
         *,
-        input: torch.Tensor | Iterable[torch.Tensor],
-        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
+        input: torch.Tensor,
+        target: torch.Tensor | dict[str, Any],
         **kwargs,
     ) -> torch.Tensor | Iterable[torch.Tensor]:
-        if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
-            return self.compose(perturbation, input=input, target=target)
-
-        elif (
-            isinstance(perturbation, Iterable)
-            and isinstance(input, Iterable)  # noqa: W503
-            and isinstance(target, Iterable)  # noqa: W503
-        ):
-            # FIXME: replace tuple with whatever input's type is
-            return tuple(
-                self.compose(perturbation_i, input=input_i, target=target_i)
-                for perturbation_i, input_i, target_i in zip(perturbation, input, target)
-            )
-
-        else:
-            raise NotImplementedError
+        return self.compose(perturbation, input=input, target=target)
 
     @abc.abstractmethod
     def compose(
diff --git a/mart/attack/initializer.py b/mart/attack/initializer.py
index 99a51cd7..4babbea1 100644
--- a/mart/attack/initializer.py
+++ b/mart/attack/initializer.py
@@ -21,12 +21,9 @@ class Initializer:
     """Initializer base class."""
 
     @torch.no_grad()
-    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor], **kwargs) -> None:
+    def __call__(self, parameter: torch.Tensor, **kwargs) -> None:
         # Allow **kwargs to work with modality_dispatch().
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-
-        [self.initialize_(parameter) for parameter in parameters]
+        self.initialize_(parameter)
 
     @torch.no_grad()
     def initialize_(self, parameter: torch.Tensor) -> None:
diff --git a/mart/attack/projector.py b/mart/attack/projector.py
index f9887354..a9eb6a25 100644
--- a/mart/attack/projector.py
+++ b/mart/attack/projector.py
@@ -17,25 +17,13 @@ class Projector:
     @torch.no_grad()
     def __call__(
         self,
-        perturbation: torch.Tensor | Iterable[torch.Tensor],
+        perturbation: torch.Tensor,
         *,
-        input: torch.Tensor | Iterable[torch.Tensor],
-        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
+        input: torch.Tensor,
+        target: torch.Tensor | dict[str, Any],
         **kwargs,
     ) -> None:
-        if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
-            self.project_(perturbation, input=input, target=target)
-
-        elif (
-            isinstance(perturbation, Iterable)
-            and isinstance(input, Iterable)  # noqa: W503
-            and isinstance(target, Iterable)  # noqa: W503
-        ):
-            for perturbation_i, input_i, target_i in zip(perturbation, input, target):
-                self.project_(perturbation_i, input=input_i, target=target_i)
-
-        else:
-            raise NotImplementedError
+        self.project_(perturbation, input=input, target=target)
 
     @torch.no_grad()
     def project_(
From 3fbd0483e380471cc371f0f0f699d90e2a334e5e Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 21 Jun 2023 16:39:31 -0700
Subject: [PATCH 27/36] Simplify type annotation with modality_dispatch().
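For illustration (editorial; Clamp is a hypothetical projector, not a MART class), the per-tensor contract makes custom projectors one-liners:

    import torch

    from mart.attack.projector import Projector

    class Clamp(Projector):
        # Keep the perturbation within [-eps, eps].
        def __init__(self, eps: float):
            self.eps = eps

        @torch.no_grad()
        def project_(self, perturbation, *, input, target):
            perturbation.clamp_(-self.eps, self.eps)

    pert = torch.nn.Parameter(torch.randn(3, 8, 8))
    Clamp(8 / 255)(pert, input=torch.rand(3, 8, 8), target=None)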
--- mart/attack/projector.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mart/attack/projector.py b/mart/attack/projector.py index a9eb6a25..58af6a7f 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Iterable +from typing import Any import torch @@ -28,10 +28,10 @@ def __call__( @torch.no_grad() def project_( self, - perturbation: torch.Tensor | Iterable[torch.Tensor], + perturbation: torch.Tensor, *, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], ) -> None: pass @@ -45,10 +45,10 @@ def __init__(self, projectors: list[Projector]): @torch.no_grad() def __call__( self, - perturbation: torch.Tensor | Iterable[torch.Tensor], + perturbation: torch.Tensor, *, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], **kwargs, ) -> None: for projector in self.projectors: From b675d3f54b2097947118415eb3869e7236afe4af Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 21 Jun 2023 16:41:24 -0700 Subject: [PATCH 28/36] Update type annotation. --- mart/utils/modality_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index 1020df23..d0b29594 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -20,7 +20,7 @@ def modality_dispatch( input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], *, data: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor | str]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor | str]] | None, modality_func: Callable | dict[str, Callable], modality: str = DEFAULT_MODALITY, ): From 249223277c57a4542bc8362d40cb944d1750edc2 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 21 Jun 2023 16:42:00 -0700 Subject: [PATCH 29/36] Make explicit function arguments from modality_dispatch(). --- mart/attack/initializer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mart/attack/initializer.py b/mart/attack/initializer.py index 4babbea1..e197000a 100644 --- a/mart/attack/initializer.py +++ b/mart/attack/initializer.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Iterable +from typing import Any import torch import torchvision @@ -21,8 +21,14 @@ class Initializer: """Initializer base class.""" @torch.no_grad() - def __call__(self, parameter: torch.Tensor, **kwargs) -> None: - # Allow **kwargs to work with modality_dispatch(). + def __call__( + self, + parameter: torch.Tensor, + *, + input: torch.Tensor | None = None, + target: torch.Tensor | dict[str, Any] | None = None, + ) -> None: + # Accept input and target from modality_dispatch(). self.initialize_(parameter) @torch.no_grad() From 430ed3f30850705ae772fa07b66bf70e281f9843 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 21 Jun 2023 16:44:22 -0700 Subject: [PATCH 30/36] Fix test. 
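Editorial note: the diff below swaps the ad-hoc closure for the library's Constant initializer, which, after the previous patch, absorbs the input and target keywords passed by modality_dispatch. A rough equivalent of the removed code:

    import torch

    from mart.attack.initializer import Constant

    param = torch.empty(3, 32, 32)
    Constant(0)(param)  # same effect as torch.nn.init.constant_(param, 0)
    assert param.eq(0).all()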
--- tests/test_adversary.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_adversary.py b/tests/test_adversary.py index edf7afe1..b77f8fdb 100644 --- a/tests/test_adversary.py +++ b/tests/test_adversary.py @@ -15,6 +15,7 @@ import mart from mart.attack import Adversary, Composer, Perturber from mart.attack.gradient_modifier import Sign +from mart.attack.initializer import Constant def test_adversary(input_data, target_data, perturbation): @@ -218,12 +219,8 @@ def gain(logits): ) # Perturbation initialized as zero. - # Initializer needs to absorb **kwargs from modality_dispatch(). - def initializer(x, **kwargs): - torch.nn.init.constant_(x, 0) - perturber = Perturber( - initializer=initializer, + initializer=Constant(0), projector=None, ) From 7bfb9cbcf411fdc77bd3a9385ce1b64dc22de991 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 21 Jun 2023 16:47:32 -0700 Subject: [PATCH 31/36] Simplify type annotation. --- mart/attack/enforcer.py | 2 +- mart/utils/modality_dispatch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index c7d64121..48310d30 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -111,7 +111,7 @@ def __call__( input_adv: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], *, input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor | str]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], **kwargs, ): # The default modality is set to "constraints", so that it is backward compatible with existing configs. diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py index d0b29594..4f0757e8 100644 --- a/mart/utils/modality_dispatch.py +++ b/mart/utils/modality_dispatch.py @@ -20,7 +20,7 @@ def modality_dispatch( input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], *, data: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor | str]] | None, + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]] | None, modality_func: Callable | dict[str, Callable], modality: str = DEFAULT_MODALITY, ): From 3607491bb22a29106750b917500c917761de0342 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 23 Jun 2023 11:39:24 -0700 Subject: [PATCH 32/36] Revert changes in Composer and make a new Modality(Composer). --- mart/attack/adversary.py | 22 +++----- mart/attack/composer.py | 56 +++++++++++++++++-- mart/configs/attack/composer/modality.yaml | 1 + .../object_detection_rgb_mask_adversary.yaml | 3 +- 4 files changed, 62 insertions(+), 20 deletions(-) create mode 100644 mart/configs/attack/composer/modality.yaml diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 941461d8..f8ff5049 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -36,7 +36,7 @@ def __init__( self, *, perturber: Perturber, - composer: Composer | dict[str, Composer], + composer: Composer, optimizer: OptimizerFactory | Callable[[Any], torch.optim.Optimizer], gain: Gain, gradient_modifier: GradientModifier | dict[str, GradientModifier] | None = None, @@ -73,8 +73,6 @@ def __init__( # Backward compatibility, in case modality is unknown, and not given in input. 
if not isinstance(gradient_modifier, dict): gradient_modifier = {DEFAULT_MODALITY: gradient_modifier} - if not isinstance(composer, dict): - composer = {DEFAULT_MODALITY: composer} self.composer = composer self.optimizer = optimizer @@ -162,7 +160,7 @@ def configure_gradient_clipping( self.gradient_modifier[modality](group["params"]) @silent() - def forward(self, *, model=None, sequence=None, input, target, **batch): + def forward(self, *, model=None, sequence=None, **batch): batch["model"] = model batch["sequence"] = sequence @@ -171,20 +169,14 @@ def forward(self, *, model=None, sequence=None, input, target, **batch): # Adversary lives inside the model, we also need the remaining sequence to be able to # get a loss. if model and sequence: - self._attack(input=input, target=target, **batch) - - perturbation = self.perturber(input=input, target=target, **batch) - input_adv = modality_dispatch( - input, - data=perturbation, - target=target, - modality_func=self.composer, - modality=DEFAULT_MODALITY, - ) + self._attack(**batch) + + perturbation = self.perturber(**batch) + input_adv = self.composer(perturbation, **batch) # Enforce constraints after the attack optimization ends. if model and sequence: - self.enforcer(input_adv, input=input, target=target, **batch) + self.enforcer(input_adv, **batch) return input_adv diff --git a/mart/attack/composer.py b/mart/attack/composer.py index e3962160..afe36e3e 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -11,17 +11,34 @@ import torch +from ..utils.modality_dispatch import DEFAULT_MODALITY, modality_dispatch + class Composer(abc.ABC): def __call__( self, - perturbation: torch.Tensor, + perturbation: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], **kwargs, ) -> torch.Tensor | Iterable[torch.Tensor]: - return self.compose(perturbation, input=input, target=target) + if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor): + return self.compose(perturbation, input=input, target=target) + + elif ( + isinstance(perturbation, Iterable) + and isinstance(input, Iterable) # noqa: W503 + and isinstance(target, Iterable) # noqa: W503 + ): + # FIXME: replace tuple with whatever input's type is + return tuple( + self.compose(perturbation_i, input=input_i, target=target_i) + for perturbation_i, input_i, target_i in zip(perturbation, input, target) + ) + + else: + raise NotImplementedError @abc.abstractmethod def compose( @@ -63,3 +80,34 @@ def compose(self, perturbation, *, input, target): masked_perturbation = perturbation * mask return input + masked_perturbation + + +class Modality(Composer): + def __init__(self, **modality_method): + self.modality_method = modality_method + + def __call__( + self, + perturbation: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], + *, + input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], + **kwargs, + ) -> torch.Tensor | Iterable[torch.Tensor]: + return modality_dispatch( + input, + data=perturbation, + target=target, + modality_func=self.compose, + modality=DEFAULT_MODALITY, + ) + + def compose( + self, + perturbation: torch.Tensor, + *, + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], + modality: str, + ) -> torch.Tensor: + return 
self.modality_method[modality](perturbation, input=input, target=target) diff --git a/mart/configs/attack/composer/modality.yaml b/mart/configs/attack/composer/modality.yaml new file mode 100644 index 00000000..34955313 --- /dev/null +++ b/mart/configs/attack/composer/modality.yaml @@ -0,0 +1 @@ +_target_: mart.attack.composer.Modality diff --git a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml index bea84349..beec9d35 100644 --- a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml @@ -3,11 +3,12 @@ defaults: - perturber: default - perturber/initializer@perturber.initializer.rgb: constant - perturber/projector@perturber.projector.rgb: mask_range - - composer@composer.rgb: overlay - /optimizer@optimizer: sgd - gain: rcnn_training_loss - gradient_modifier@gradient_modifier.rgb: sign - objective: zero_ap + - composer: modality + - composer@composer.rgb: overlay - enforcer: default - enforcer/constraints@enforcer.rgb: [mask, pixel_range] From 086bda8247ad733dbb6e3d361a23f4532bd9f3e0 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 23 Jun 2023 12:14:19 -0700 Subject: [PATCH 33/36] Add Modality(GradientModifier) and change the usage of GradientModifier to consume modality info. --- mart/attack/adversary.py | 11 ++--------- mart/attack/gradient_modifier.py | 19 +++++++++++++++++-- .../attack/gradient_modifier/modality.yaml | 1 + .../object_detection_rgb_mask_adversary.yaml | 3 ++- 4 files changed, 22 insertions(+), 12 deletions(-) create mode 100644 mart/configs/attack/gradient_modifier/modality.yaml diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index f8ff5049..213525db 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -16,7 +16,6 @@ from mart.utils import silent from ..optim import OptimizerFactory -from ..utils.modality_dispatch import DEFAULT_MODALITY, modality_dispatch if TYPE_CHECKING: from .composer import Composer @@ -39,7 +38,7 @@ def __init__( composer: Composer, optimizer: OptimizerFactory | Callable[[Any], torch.optim.Optimizer], gain: Gain, - gradient_modifier: GradientModifier | dict[str, GradientModifier] | None = None, + gradient_modifier: GradientModifier | None = None, objective: Objective | None = None, enforcer: Enforcer | None = None, attacker: pl.Trainer | None = None, @@ -69,11 +68,6 @@ def __init__( # and DDP won't try to get the uninitialized parameters of perturbation. self._perturber = [perturber] - # Modality-specific objects. - # Backward compatibility, in case modality is unknown, and not given in input. 
- if not isinstance(gradient_modifier, dict): - gradient_modifier = {DEFAULT_MODALITY: gradient_modifier} - self.composer = composer self.optimizer = optimizer if not isinstance(self.optimizer, OptimizerFactory): @@ -156,8 +150,7 @@ def configure_gradient_clipping( if self.gradient_modifier: for group in optimizer.param_groups: - modality = group["modality"] if "modality" in group else DEFAULT_MODALITY - self.gradient_modifier[modality](group["params"]) + self.gradient_modifier(group) @silent() def forward(self, *, model=None, sequence=None, **batch): diff --git a/mart/attack/gradient_modifier.py b/mart/attack/gradient_modifier.py index b2882574..a5fd68c6 100644 --- a/mart/attack/gradient_modifier.py +++ b/mart/attack/gradient_modifier.py @@ -6,17 +6,20 @@ from __future__ import annotations -from typing import Iterable +from typing import Any import torch +from ..utils.modality_dispatch import DEFAULT_MODALITY + __all__ = ["GradientModifier"] class GradientModifier: """Gradient modifier base class.""" - def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: + def __call__(self, param_group: dict[str, Any]) -> None: + parameters = param_group["params"] if isinstance(parameters, torch.Tensor): parameters = [parameters] @@ -43,3 +46,15 @@ def __init__(self, p: int | float): def modify_(self, parameter: torch.Tensor) -> None: p_norm = torch.norm(parameter.grad.detach(), p=self.p) parameter.grad.detach().div_(p_norm) + + +class Modality(GradientModifier): + def __init__(self, **modality_method): + if len(modality_method) == 0: + modality_method = {DEFAULT_MODALITY: self.modify_} + + self.modality_method_ = modality_method + + def __call__(self, param_group: dict[str, Any]) -> None: + modality = param_group["modality"] if "modality" in param_group else DEFAULT_MODALITY + self.modality_method_[modality](param_group) diff --git a/mart/configs/attack/gradient_modifier/modality.yaml b/mart/configs/attack/gradient_modifier/modality.yaml new file mode 100644 index 00000000..a5596dfd --- /dev/null +++ b/mart/configs/attack/gradient_modifier/modality.yaml @@ -0,0 +1 @@ +_target_: mart.attack.gradient_modifier.Modality diff --git a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml index beec9d35..30634925 100644 --- a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml @@ -5,8 +5,9 @@ defaults: - perturber/projector@perturber.projector.rgb: mask_range - /optimizer@optimizer: sgd - gain: rcnn_training_loss - - gradient_modifier@gradient_modifier.rgb: sign - objective: zero_ap + - gradient_modifier: modality + - gradient_modifier@gradient_modifier.rgb: sign - composer: modality - composer@composer.rgb: overlay - enforcer: default From c57465f3718312b72e81f4fae2f997086294ebda Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 23 Jun 2023 12:34:47 -0700 Subject: [PATCH 34/36] Fix test on gradient modifier. --- tests/test_gradient.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_gradient.py b/tests/test_gradient.py index a4ad49ee..fe366342 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -14,9 +14,10 @@ def test_gradient_sign(input_data): # Don't share input_data with other tests, because the gradient would be changed. 
    input_data = torch.tensor([1.0, 2.0, 3.0])
     input_data.grad = torch.tensor([-1.0, 3.0, 0.0])
+    param_group = {"params": input_data}
 
     grad_modifier = Sign()
-    grad_modifier(input_data)
+    grad_modifier(param_group)
 
     expected_grad = torch.tensor([-1.0, 1.0, 0.0])
     torch.testing.assert_close(input_data.grad, expected_grad)
@@ -25,9 +26,10 @@ def test_gradient_lp_normalizer():
     # Don't share input_data with other tests, because the gradient would be changed.
     input_data = torch.tensor([1.0, 2.0, 3.0])
     input_data.grad = torch.tensor([-1.0, 3.0, 0.0])
+    param_group = {"params": input_data}
 
     p = 1
     grad_modifier = LpNormalizer(p)
-    grad_modifier(input_data)
+    grad_modifier(param_group)
     expected_grad = torch.tensor([-0.25, 0.75, 0.0])
     torch.testing.assert_close(input_data.grad, expected_grad)
From bff59bc8505000542f40be43a8390750debaf0f5 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Fri, 23 Jun 2023 12:36:30 -0700
Subject: [PATCH 35/36] Cleanup.

---
 mart/attack/adversary.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py
index 213525db..c0aa3768 100644
--- a/mart/attack/adversary.py
+++ b/mart/attack/adversary.py
@@ -67,7 +67,6 @@ def __init__(
         # Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint.
         # and DDP won't try to get the uninitialized parameters of perturbation.
         self._perturber = [perturber]
-
         self.composer = composer
         self.optimizer = optimizer
         if not isinstance(self.optimizer, OptimizerFactory):
From a06db2304e7bebd81011c1360827cd02c6ad06e2 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 20 Sep 2023 08:16:27 -0700
Subject: [PATCH 36/36] Keep modality-wise params for weights for later iterations.

---
 mart/optim/optimizer.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/mart/optim/optimizer.py b/mart/optim/optimizer.py
index 2d96e59a..0fb3b694 100644
--- a/mart/optim/optimizer.py
+++ b/mart/optim/optimizer.py
@@ -26,6 +26,9 @@ def __init__(self, optimizer, **kwargs):
         self.bias_decay = kwargs.pop("bias_decay", weight_decay)
         self.norm_decay = kwargs.pop("norm_decay", weight_decay)
         self.optimizer = optimizer
+
+        # Separate modality-wise params from kwargs, because optimizers do not recognize them.
+        self.modality_wise_params = kwargs.pop("modality_wise", {})
         self.kwargs = kwargs
 
     def __call__(self, module):
@@ -61,14 +64,16 @@ def __call__(self, module):
 
         params = []
 
-        # Set modality-aware params.
+        # Set modality-aware weight params.
         if len(modality_params) > 0:
             for modality, param in modality_params.items():
                 # Take note of the modality for the gradient modifier later.
                 # Add modality-specific optim params.
-                params.append(
-                    {"params": param, "modality": modality} | self.kwargs.pop(modality, {})
-                )
+                if modality in self.modality_wise_params:
+                    modality_kwargs = self.modality_wise_params[modality]
+                else:
+                    modality_kwargs = {}
+                params.append({"params": param, "modality": modality} | modality_kwargs)
 
         # Set decay for bias and norm parameters
         if len(weight_params) > 0:
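Editor's note: with the change above, modality-specific hyper-parameters survive repeated factory calls instead of being popped out of self.kwargs on first use. An assumed usage sketch; the modality_wise key mirrors the new constructor, and the matching Hydra config update is not shown in this series.

    import torch

    from mart.optim import OptimizerFactory

    factory = OptimizerFactory(
        torch.optim.SGD,
        lr=0,  # required top-level default; the rgb group overrides it
        modality_wise={"rgb": {"lr": 55}},
    )
    # factory(adversary_module) would then build param groups like
    # {"params": [...], "modality": "rgb", "lr": 55} for ModalityParameterDict entries.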