diff --git a/mart/attack/composer.py b/mart/attack/composer.py index cd6300f7..3401fd12 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -134,3 +134,26 @@ def forward(self, perturbation, input, target): input = input * (1 - mask) + perturbation return perturbation, input, target + + +class InputFakeClamp(Function): + """A Clamp operation that preserves gradients. + + This should eliminate any assumption on Composer(e.g. additive) in Projector. + """ + + def __init__(self, *args, min_val, max_val, **kwargs): + super().__init__(*args, **kwargs) + self.min_val = min_val + self.max_val = max_val + + @staticmethod + def fake_clamp(x, *, min_val, max_val): + with torch.no_grad(): + x_clamped = x.clamp(min_val, max_val) + diff = x_clamped - x + return x + diff + + def forward(self, perturbation, input, target): + input = self.fake_clamp(input, min_val=self.min_val, max_val=self.max_val) + return perturbation, input, target diff --git a/mart/attack/projector.py b/mart/attack/projector.py index f9887354..6b09c648 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -6,15 +6,46 @@ from __future__ import annotations +import abc +from collections import OrderedDict from typing import Any, Iterable import torch +class Function: + def __init__(self, order=0) -> None: + """A stackable function for Projector. + + Args: + order (int, optional): The priority number. A smaller number makes a function run earlier than others in a sequence. Defaults to 0. + """ + self.order = order + + @abc.abstractmethod + def __call__(self, perturbation, input, target) -> None: + """It returns None because we only perform non-differentiable in-place operations.""" + pass + + def __repr__(self): + return f"{self.__class__.__name__}()" + + class Projector: """A projector modifies nn.Parameter's data.""" - @torch.no_grad() + def __init__(self, functions: dict[str, Function] = {}) -> None: + """_summary_ + + Args: + functions (dict[str, Function]): A dictionary of functions for perturbation projection. + """ + # Sort functions by function.order and the name. + self.functions_dict = OrderedDict( + sorted(functions.items(), key=lambda name_fn: (name_fn[1].order, name_fn[0])) + ) + self.functions = list(self.functions_dict.values()) + def __call__( self, perturbation: torch.Tensor | Iterable[torch.Tensor], @@ -45,56 +76,19 @@ def project_( input: torch.Tensor | Iterable[torch.Tensor], target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], ) -> None: - pass - - -class Compose(Projector): - """Apply a list of perturbation modifier.""" - - def __init__(self, projectors: list[Projector]): - self.projectors = projectors - - @torch.no_grad() - def __call__( - self, - perturbation: torch.Tensor | Iterable[torch.Tensor], - *, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], - **kwargs, - ) -> None: - for projector in self.projectors: - projector(perturbation, input=input, target=target) + for function in self.functions: + # Some functions such as Mask need access to target["perturbable_mask"] + function(perturbation, input, target) def __repr__(self): - projector_names = [repr(p) for p in self.projectors] - return f"{self.__class__.__name__}({projector_names})" - + function_names = [repr(p) for p in self.functions_dict] + return f"{self.__class__.__name__}({function_names})" -class Range(Projector): - """Clamp the perturbation so that the output is range-constrained.""" - def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255): - self.quantize = quantize - self.min = min - self.max = max - - @torch.no_grad() - def project_(self, perturbation, *, input, target): - if self.quantize: - perturbation.round_() - perturbation.clamp_(self.min, self.max) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(quantize={self.quantize}, min={self.min}, max={self.max})" - ) - - -class RangeAdditive(Projector): +class Range(Function): """Clamp the perturbation so that the output is range-constrained. - The projector assumes an additive perturbation threat model. + Maybe used in overlay composer. """ def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255): @@ -102,11 +96,10 @@ def __init__(self, quantize: bool = False, min: int | float = 0, max: int | floa self.min = min self.max = max - @torch.no_grad() - def project_(self, perturbation, *, input, target): + def __call__(self, perturbation, input, target): if self.quantize: perturbation.round_() - perturbation.clamp_(self.min - input, self.max - input) + perturbation.clamp_(self.min, self.max) def __repr__(self): return ( @@ -114,49 +107,47 @@ def __repr__(self): ) -class Lp(Projector): +class Lp(Function): """Project perturbations to Lp norm, only if the Lp norm is larger than eps.""" - def __init__(self, eps: int | float, p: int | float = torch.inf): + def __init__(self, eps: int | float, p: int | float = torch.inf, *args, **kwargs): """_summary_ Args: eps (float): The max norm. p (float): The p in L-p norm, which must be positive.. Defaults to torch.inf. """ + super().__init__(*args, **kwargs) self.p = p self.eps = eps - @torch.no_grad() - def project_(self, perturbation, *, input, target): - pert_norm = perturbation.norm(p=self.p) - if pert_norm > self.eps: - # We only upper-bound the norm. - perturbation.mul_(self.eps / pert_norm) - - -class LinfAdditiveRange(Projector): - """Make sure the perturbation is within the Linf norm ball, and "input + perturbation" is - within the [min, max] range.""" - - def __init__(self, eps: int | float, min: int | float = 0, max: int | float = 255): - self.eps = eps - self.min = min - self.max = max + @staticmethod + def linf(x, p, eps): + x.clamp_(min=-eps, max=eps) + + @staticmethod + def lp(x, p, eps): + x_norm = x.norm(p=p) + if x_norm > eps: + x.mul_(eps / x_norm) + + def __call__(self, perturbation, input, target): + if self.p == torch.inf: + method = self.linf + elif self.p == 0: + raise NotImplementedError("L-0 projection is not implemented.") + else: + method = self.lp - @torch.no_grad() - def project_(self, perturbation, *, input, target): - eps_min = (input - self.eps).clamp(self.min, self.max) - input - eps_max = (input + self.eps).clamp(self.min, self.max) - input + method(perturbation, self.p, self.eps) - perturbation.clamp_(eps_min, eps_max) +# TODO: We may move the mask projection to Initialzier, if we also have mask in composer, because no gradient to update the masked pixels. +class Mask(Function): + def __init__(self, *args, key="perturbable_mask", **kwargs): + super().__init__(*args, **kwargs) + self.key = key -class Mask(Projector): - @torch.no_grad() - def project_(self, perturbation, *, input, target): - perturbation.mul_(target["perturbable_mask"]) - - def __repr__(self): - return f"{self.__class__.__name__}()" + def __call__(self, perturbation, input, target): + perturbation.mul_(target[self.key]) diff --git a/mart/configs/attack/classification_fgsm_linf.yaml b/mart/configs/attack/classification_fgsm_linf.yaml index 74c9b959..fd647228 100644 --- a/mart/configs/attack/classification_fgsm_linf.yaml +++ b/mart/configs/attack/classification_fgsm_linf.yaml @@ -2,10 +2,18 @@ defaults: - adversary - fgm - linf - - composer/functions: additive + - composer/functions: [additive, input_fake_clamp] - gradient_modifier: sign - gain: cross_entropy - objective: misclassification +composer: + functions: + additive: + order: 0 + # input_fake_clamp ensures valid range of pixel values after addition. + input_fake_clamp: + order: 1 + eps: ??? max_iters: 1 diff --git a/mart/configs/attack/classification_pgd_linf.yaml b/mart/configs/attack/classification_pgd_linf.yaml index fec19029..89d5001d 100644 --- a/mart/configs/attack/classification_pgd_linf.yaml +++ b/mart/configs/attack/classification_pgd_linf.yaml @@ -2,11 +2,19 @@ defaults: - adversary - pgd - linf - - composer/functions: additive + - composer/functions: [additive, input_fake_clamp] - gradient_modifier: sign - gain: cross_entropy - objective: misclassification +composer: + functions: + additive: + order: 0 + # input_fake_clamp ensures valid range of pixel values after addition. + input_fake_clamp: + order: 1 + eps: ??? lr: ??? max_iters: ??? diff --git a/mart/configs/attack/composer/functions/input_fake_clamp.yaml b/mart/configs/attack/composer/functions/input_fake_clamp.yaml new file mode 100644 index 00000000..764f08d4 --- /dev/null +++ b/mart/configs/attack/composer/functions/input_fake_clamp.yaml @@ -0,0 +1,5 @@ +input_fake_clamp: + _target_: mart.attack.composer.InputFakeClamp + order: 0 + min_val: 0 + max_val: 255 diff --git a/mart/configs/attack/composer/perturber/default.yaml b/mart/configs/attack/composer/perturber/default.yaml index 322b1440..bc354ab6 100644 --- a/mart/configs/attack/composer/perturber/default.yaml +++ b/mart/configs/attack/composer/perturber/default.yaml @@ -1,4 +1,7 @@ +# Avoid null projector here due to the chance of overriding projectors defined in other config files? + +defaults: + - projector: default + _target_: mart.attack.Perturber initializer: ??? -# Avoid null projector here due to the chance of overriding projectors defined in other config files. -projector: ??? diff --git a/mart/configs/attack/composer/perturber/projector/default.yaml b/mart/configs/attack/composer/perturber/projector/default.yaml new file mode 100644 index 00000000..a1f0402e --- /dev/null +++ b/mart/configs/attack/composer/perturber/projector/default.yaml @@ -0,0 +1,2 @@ +_target_: mart.attack.Projector +functions: ??? diff --git a/mart/configs/attack/composer/perturber/projector/functions/linf.yaml b/mart/configs/attack/composer/perturber/projector/functions/linf.yaml new file mode 100644 index 00000000..90b4d465 --- /dev/null +++ b/mart/configs/attack/composer/perturber/projector/functions/linf.yaml @@ -0,0 +1,8 @@ +linf: + _target_: mart.attack.projector.Lp + order: 0 + # p is actually torch.inf by default. + p: + _target_: builtins.float + _args_: ["inf"] + eps: ??? diff --git a/mart/configs/attack/composer/perturber/projector/functions/lp.yaml b/mart/configs/attack/composer/perturber/projector/functions/lp.yaml new file mode 100644 index 00000000..b7be2f1c --- /dev/null +++ b/mart/configs/attack/composer/perturber/projector/functions/lp.yaml @@ -0,0 +1,5 @@ +lp: + _target_: mart.attack.projector.Lp + order: 0 + p: ??? + eps: ??? diff --git a/mart/configs/attack/composer/perturber/projector/functions/mask.yaml b/mart/configs/attack/composer/perturber/projector/functions/mask.yaml new file mode 100644 index 00000000..2e6e2794 --- /dev/null +++ b/mart/configs/attack/composer/perturber/projector/functions/mask.yaml @@ -0,0 +1,4 @@ +mask: + _target_: mart.attack.projector.Mask + order: 0 + key: perturbable_mask diff --git a/mart/configs/attack/composer/perturber/projector/functions/range.yaml b/mart/configs/attack/composer/perturber/projector/functions/range.yaml new file mode 100644 index 00000000..f65a4f40 --- /dev/null +++ b/mart/configs/attack/composer/perturber/projector/functions/range.yaml @@ -0,0 +1,6 @@ +range: + _target_: mart.attack.projector.Range + order: 0 + quantize: false + min: 0 + max: 255 diff --git a/mart/configs/attack/composer/perturber/projector/linf_additive_range.yaml b/mart/configs/attack/composer/perturber/projector/linf_additive_range.yaml deleted file mode 100644 index 0723ffbf..00000000 --- a/mart/configs/attack/composer/perturber/projector/linf_additive_range.yaml +++ /dev/null @@ -1,4 +0,0 @@ -_target_: mart.attack.projector.LinfAdditiveRange -eps: ??? -min: 0 -max: 255 diff --git a/mart/configs/attack/composer/perturber/projector/lp_additive_range.yaml b/mart/configs/attack/composer/perturber/projector/lp_additive_range.yaml deleted file mode 100644 index 37efa30d..00000000 --- a/mart/configs/attack/composer/perturber/projector/lp_additive_range.yaml +++ /dev/null @@ -1,5 +0,0 @@ -_target_: mart.attack.projector.LpAdditiveRangeProjector -p: ??? -eps: ??? -min: 0 -max: 255 diff --git a/mart/configs/attack/composer/perturber/projector/mask_range.yaml b/mart/configs/attack/composer/perturber/projector/mask_range.yaml deleted file mode 100644 index 4fe5c3de..00000000 --- a/mart/configs/attack/composer/perturber/projector/mask_range.yaml +++ /dev/null @@ -1,7 +0,0 @@ -_target_: mart.attack.projector.Compose -projectors: - - _target_: mart.attack.projector.Mask - - _target_: mart.attack.projector.Range - quantize: false - min: 0 - max: 255 diff --git a/mart/configs/attack/composer/perturber/projector/range.yaml b/mart/configs/attack/composer/perturber/projector/range.yaml deleted file mode 100644 index 8d10563b..00000000 --- a/mart/configs/attack/composer/perturber/projector/range.yaml +++ /dev/null @@ -1,4 +0,0 @@ -_target_: mart.attack.projector.Range -quantize: false -min: 0 -max: 255 diff --git a/mart/configs/attack/fgm.yaml b/mart/configs/attack/fgm.yaml index 73a8bbc5..9c00573a 100644 --- a/mart/configs/attack/fgm.yaml +++ b/mart/configs/attack/fgm.yaml @@ -13,7 +13,9 @@ composer: initializer: constant: 0 projector: - eps: ${....eps} + functions: + lp: + eps: ${......eps} # We can turn off progress bar for one-step attack. callbacks: diff --git a/mart/configs/attack/linf.yaml b/mart/configs/attack/linf.yaml index 71c50c9e..8b01a601 100644 --- a/mart/configs/attack/linf.yaml +++ b/mart/configs/attack/linf.yaml @@ -1,14 +1,23 @@ defaults: - - composer/perturber/projector: linf_additive_range + - composer/perturber/projector/functions: lp - enforcer: default - enforcer/constraints: lp +composer: + perturber: + projector: + functions: + lp: + p: ${......p} + eps: ${......eps} + enforcer: constraints: lp: - p: - _target_: builtins.float - _args_: ["inf"] + p: ${....p} eps: ${....eps} +p: + _target_: builtins.float + _args_: ["inf"] eps: ??? diff --git a/mart/configs/attack/mask.yaml b/mart/configs/attack/mask.yaml index 08206522..62380750 100644 --- a/mart/configs/attack/mask.yaml +++ b/mart/configs/attack/mask.yaml @@ -1,4 +1,4 @@ defaults: - - composer/perturber/projector: mask_range + - composer/perturber/projector/functions: mask - enforcer: default - enforcer/constraints: [mask, pixel_range] diff --git a/mart/configs/attack/object_detection_mask_adversary.yaml b/mart/configs/attack/object_detection_mask_adversary.yaml index b2b6b394..11fcda08 100644 --- a/mart/configs/attack/object_detection_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_mask_adversary.yaml @@ -3,7 +3,7 @@ defaults: - gradient_ascent - mask - composer/perturber/initializer: constant - - composer/functions: overlay + - composer/functions: [overlay, input_fake_clamp] - gradient_modifier: sign - gain: rcnn_training_loss - objective: zero_ap @@ -11,8 +11,15 @@ defaults: max_iters: ??? lr: ??? -# Start with grey perturbation in the overlay mode. composer: + # Use fake_clamp to ensure valid range of pixel values after overlay. + functions: + overlay: + order: 0 + input_fake_clamp: + order: 1 + + # Start with grey perturbation in the overlay mode. perturber: initializer: constant: 127 diff --git a/mart/configs/attack/object_detection_mask_adversary_missed.yaml b/mart/configs/attack/object_detection_mask_adversary_missed.yaml index f8b5298f..2eb83510 100644 --- a/mart/configs/attack/object_detection_mask_adversary_missed.yaml +++ b/mart/configs/attack/object_detection_mask_adversary_missed.yaml @@ -3,7 +3,7 @@ defaults: - gradient_ascent - mask - composer/perturber/initializer: constant - - composer/functions: overlay + - composer/functions: [overlay, input_fake_clamp] - gradient_modifier: sign - gain: rcnn_class_background - objective: object_detection_missed @@ -11,8 +11,15 @@ defaults: max_iters: ??? lr: ??? -# Start with grey perturbation in the overlay mode. composer: + # Use fake_clamp to ensure valid range of pixel values overlay. + functions: + overlay: + order: 0 + input_fake_clamp: + order: 1 + + # Start with grey perturbation in the overlay mode. perturber: initializer: constant: 127 diff --git a/mart/configs/attack/pgd.yaml b/mart/configs/attack/pgd.yaml index 9af4985a..646ed603 100644 --- a/mart/configs/attack/pgd.yaml +++ b/mart/configs/attack/pgd.yaml @@ -15,4 +15,6 @@ composer: min: ${negate:${....eps}} max: ${....eps} projector: - eps: ${....eps} + functions: + lp: + eps: ${......eps} diff --git a/tests/test_projector.py b/tests/test_projector.py index 19cb5c44..f8939a11 100644 --- a/tests/test_projector.py +++ b/tests/test_projector.py @@ -4,19 +4,10 @@ # SPDX-License-Identifier: BSD-3-Clause # -from unittest.mock import Mock - import pytest import torch -from mart.attack.projector import ( - Compose, - LinfAdditiveRange, - Lp, - Mask, - Range, - RangeAdditive, -) +from mart.attack.projector import Lp, Mask, Range def test_range_projector_repr(): @@ -42,36 +33,6 @@ def test_range_projector(quantize, min, max, input_data, target_data, perturbati assert torch.min(perturbation) >= min -def test_range_additive_projector_repr(): - min = 0 - max = 100 - quantize = True - projector = RangeAdditive(quantize, min, max) - representation = repr(projector) - expected_representation = ( - f"{projector.__class__.__name__}(quantize={quantize}, min={min}, max={max})" - ) - assert representation == expected_representation - - -@pytest.mark.parametrize("quantize", [False, True]) -@pytest.mark.parametrize("min", [-10, 0, 10]) -@pytest.mark.parametrize("max", [10, 100, 110]) -def test_range_additive_projector(quantize, min, max, input_data, target_data, perturbation): - expected_perturbation = torch.clone(perturbation) - - projector = RangeAdditive(quantize, min, max) - projector(perturbation, input=input_data, target=target_data) - - # modify expected_perturbation - if quantize: - expected_perturbation.round_() - expected_perturbation.clamp_(min - input_data, max - input_data) - - assert torch.max(perturbation) == torch.max(expected_perturbation) - assert torch.min(perturbation) == torch.min(expected_perturbation) - - @pytest.mark.parametrize("eps", [30, 40, 50, 60]) @pytest.mark.parametrize("p", [1, 2, 3]) def test_lp_projector(eps, p, input_data, target_data, perturbation): @@ -88,24 +49,6 @@ def test_lp_projector(eps, p, input_data, target_data, perturbation): torch.testing.assert_close(perturbation, expected_perturbation) -@pytest.mark.parametrize("min", [-10, 0, 10]) -@pytest.mark.parametrize("max", [10, 100, 110]) -@pytest.mark.parametrize("eps", [30, 40, 50]) -def test_linf_additive_range_projector(min, max, eps, input_data, target_data, perturbation): - expected_perturbation = torch.clone(perturbation) - - projector = LinfAdditiveRange(eps, min, max) - projector(perturbation, input=input_data, target=target_data) - - # get expected result - eps_min = (input_data - eps).clamp(min, max) - input_data - eps_max = (input_data + eps).clamp(min, max) - input_data - expected_perturbation.clamp_(eps_min, eps_max) - - assert torch.max(perturbation) == torch.max(expected_perturbation) - assert torch.min(perturbation) == torch.min(expected_perturbation) - - def test_mask_projector_repr(): projector = Mask() representation = repr(projector) @@ -123,42 +66,3 @@ def test_mask_projector(input_data, target_data, perturbation): expected_perturbation.mul_(target_data["perturbable_mask"]) torch.testing.assert_close(perturbation, expected_perturbation) - - -def test_compose_repr(): - eps = 5 - projectors = [ - Range(), - RangeAdditive(), - Lp(eps), - LinfAdditiveRange(eps), - Mask(), - ] - - compose = Compose(projectors) - - projector_names = [repr(p) for p in projectors] - expected_representation = f"{compose.__class__.__name__}({projector_names})" - representation = repr(compose) - assert representation == expected_representation - - -def test_compose(input_data, target_data): - eps = 5 - projectors = [ - Range(), - RangeAdditive(), - Lp(eps), - LinfAdditiveRange(eps), - Mask(), - ] - - compose = Compose(projectors) - tensor = Mock(spec=torch.Tensor) - tensor.norm.return_value = 10 - compose(tensor, input=input_data, target=target_data) - - # RangeProjector, RangeAdditiveProjector, and LinfAdditiveRangeProjector calls `clamp_` - assert tensor.clamp_.call_count == 3 - # LpProjector and MaskProjector calls `mul_` - assert tensor.mul_.call_count == 2