Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions mart/attack/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,26 @@ def forward(self, perturbation, input, target):

input = input * (1 - mask) + perturbation
return perturbation, input, target


class InputFakeClamp(Function):
"""A Clamp operation that preserves gradients.

This should eliminate any assumption on Composer(e.g. additive) in Projector.
"""

def __init__(self, *args, min_val, max_val, **kwargs):
super().__init__(*args, **kwargs)
self.min_val = min_val
self.max_val = max_val

@staticmethod
def fake_clamp(x, *, min_val, max_val):
with torch.no_grad():
x_clamped = x.clamp(min_val, max_val)
diff = x_clamped - x
return x + diff

def forward(self, perturbation, input, target):
input = self.fake_clamp(input, min_val=self.min_val, max_val=self.max_val)
return perturbation, input, target
147 changes: 69 additions & 78 deletions mart/attack/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,46 @@

from __future__ import annotations

import abc
from collections import OrderedDict
from typing import Any, Iterable

import torch


class Function:
def __init__(self, order=0) -> None:
"""A stackable function for Projector.

Args:
order (int, optional): The priority number. A smaller number makes a function run earlier than others in a sequence. Defaults to 0.
"""
self.order = order

@abc.abstractmethod
def __call__(self, perturbation, input, target) -> None:
"""It returns None because we only perform non-differentiable in-place operations."""
pass

def __repr__(self):
return f"{self.__class__.__name__}()"


class Projector:
"""A projector modifies nn.Parameter's data."""

@torch.no_grad()
def __init__(self, functions: dict[str, Function] = {}) -> None:
"""_summary_

Args:
functions (dict[str, Function]): A dictionary of functions for perturbation projection.
"""
# Sort functions by function.order and the name.
self.functions_dict = OrderedDict(
sorted(functions.items(), key=lambda name_fn: (name_fn[1].order, name_fn[0]))
)
self.functions = list(self.functions_dict.values())

def __call__(
self,
perturbation: torch.Tensor | Iterable[torch.Tensor],
Expand Down Expand Up @@ -45,118 +76,78 @@ def project_(
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
) -> None:
pass


class Compose(Projector):
"""Apply a list of perturbation modifier."""

def __init__(self, projectors: list[Projector]):
self.projectors = projectors

@torch.no_grad()
def __call__(
self,
perturbation: torch.Tensor | Iterable[torch.Tensor],
*,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
**kwargs,
) -> None:
for projector in self.projectors:
projector(perturbation, input=input, target=target)
for function in self.functions:
# Some functions such as Mask need access to target["perturbable_mask"]
function(perturbation, input, target)

def __repr__(self):
projector_names = [repr(p) for p in self.projectors]
return f"{self.__class__.__name__}({projector_names})"

function_names = [repr(p) for p in self.functions_dict]
return f"{self.__class__.__name__}({function_names})"

class Range(Projector):
"""Clamp the perturbation so that the output is range-constrained."""

def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255):
self.quantize = quantize
self.min = min
self.max = max

@torch.no_grad()
def project_(self, perturbation, *, input, target):
if self.quantize:
perturbation.round_()
perturbation.clamp_(self.min, self.max)

def __repr__(self):
return (
f"{self.__class__.__name__}(quantize={self.quantize}, min={self.min}, max={self.max})"
)


class RangeAdditive(Projector):
class Range(Function):
"""Clamp the perturbation so that the output is range-constrained.

The projector assumes an additive perturbation threat model.
Maybe used in overlay composer.
"""

def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255):
self.quantize = quantize
self.min = min
self.max = max

@torch.no_grad()
def project_(self, perturbation, *, input, target):
def __call__(self, perturbation, input, target):
if self.quantize:
perturbation.round_()
perturbation.clamp_(self.min - input, self.max - input)
perturbation.clamp_(self.min, self.max)

def __repr__(self):
return (
f"{self.__class__.__name__}(quantize={self.quantize}, min={self.min}, max={self.max})"
)


class Lp(Projector):
class Lp(Function):
"""Project perturbations to Lp norm, only if the Lp norm is larger than eps."""

def __init__(self, eps: int | float, p: int | float = torch.inf):
def __init__(self, eps: int | float, p: int | float = torch.inf, *args, **kwargs):
"""_summary_

Args:
eps (float): The max norm.
p (float): The p in L-p norm, which must be positive.. Defaults to torch.inf.
"""
super().__init__(*args, **kwargs)

self.p = p
self.eps = eps

@torch.no_grad()
def project_(self, perturbation, *, input, target):
pert_norm = perturbation.norm(p=self.p)
if pert_norm > self.eps:
# We only upper-bound the norm.
perturbation.mul_(self.eps / pert_norm)


class LinfAdditiveRange(Projector):
"""Make sure the perturbation is within the Linf norm ball, and "input + perturbation" is
within the [min, max] range."""

def __init__(self, eps: int | float, min: int | float = 0, max: int | float = 255):
self.eps = eps
self.min = min
self.max = max
@staticmethod
def linf(x, p, eps):
x.clamp_(min=-eps, max=eps)

@staticmethod
def lp(x, p, eps):
x_norm = x.norm(p=p)
if x_norm > eps:
x.mul_(eps / x_norm)

def __call__(self, perturbation, input, target):
if self.p == torch.inf:
method = self.linf
elif self.p == 0:
raise NotImplementedError("L-0 projection is not implemented.")
else:
method = self.lp

@torch.no_grad()
def project_(self, perturbation, *, input, target):
eps_min = (input - self.eps).clamp(self.min, self.max) - input
eps_max = (input + self.eps).clamp(self.min, self.max) - input
method(perturbation, self.p, self.eps)

perturbation.clamp_(eps_min, eps_max)

# TODO: We may move the mask projection to Initialzier, if we also have mask in composer, because no gradient to update the masked pixels.
class Mask(Function):
def __init__(self, *args, key="perturbable_mask", **kwargs):
super().__init__(*args, **kwargs)
self.key = key

class Mask(Projector):
@torch.no_grad()
def project_(self, perturbation, *, input, target):
perturbation.mul_(target["perturbable_mask"])

def __repr__(self):
return f"{self.__class__.__name__}()"
def __call__(self, perturbation, input, target):
perturbation.mul_(target[self.key])
10 changes: 9 additions & 1 deletion mart/configs/attack/classification_fgsm_linf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@ defaults:
- adversary
- fgm
- linf
- composer/functions: additive
- composer/functions: [additive, input_fake_clamp]
- gradient_modifier: sign
- gain: cross_entropy
- objective: misclassification

composer:
functions:
additive:
order: 0
# input_fake_clamp ensures valid range of pixel values after addition.
input_fake_clamp:
order: 1

eps: ???
max_iters: 1
10 changes: 9 additions & 1 deletion mart/configs/attack/classification_pgd_linf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@ defaults:
- adversary
- pgd
- linf
- composer/functions: additive
- composer/functions: [additive, input_fake_clamp]
- gradient_modifier: sign
- gain: cross_entropy
- objective: misclassification

composer:
functions:
additive:
order: 0
# input_fake_clamp ensures valid range of pixel values after addition.
input_fake_clamp:
order: 1

eps: ???
lr: ???
max_iters: ???
5 changes: 5 additions & 0 deletions mart/configs/attack/composer/functions/input_fake_clamp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
input_fake_clamp:
_target_: mart.attack.composer.InputFakeClamp
order: 0
min_val: 0
max_val: 255
7 changes: 5 additions & 2 deletions mart/configs/attack/composer/perturber/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Avoid null projector here due to the chance of overriding projectors defined in other config files?

defaults:
- projector: default

_target_: mart.attack.Perturber
initializer: ???
# Avoid null projector here due to the chance of overriding projectors defined in other config files.
projector: ???
2 changes: 2 additions & 0 deletions mart/configs/attack/composer/perturber/projector/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: mart.attack.Projector
functions: ???
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
linf:
_target_: mart.attack.projector.Lp
order: 0
# p is actually torch.inf by default.
p:
_target_: builtins.float
_args_: ["inf"]
eps: ???
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
lp:
_target_: mart.attack.projector.Lp
order: 0
p: ???
eps: ???
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mask:
_target_: mart.attack.projector.Mask
order: 0
key: perturbable_mask
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
range:
_target_: mart.attack.projector.Range
order: 0
quantize: false
min: 0
max: 255

This file was deleted.

This file was deleted.

This file was deleted.

4 changes: 0 additions & 4 deletions mart/configs/attack/composer/perturber/projector/range.yaml

This file was deleted.

4 changes: 3 additions & 1 deletion mart/configs/attack/fgm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ composer:
initializer:
constant: 0
projector:
eps: ${....eps}
functions:
lp:
eps: ${......eps}

# We can turn off progress bar for one-step attack.
callbacks:
Expand Down
17 changes: 13 additions & 4 deletions mart/configs/attack/linf.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
defaults:
- composer/perturber/projector: linf_additive_range
- composer/perturber/projector/functions: lp
- enforcer: default
- enforcer/constraints: lp

composer:
perturber:
projector:
functions:
lp:
p: ${......p}
eps: ${......eps}

enforcer:
constraints:
lp:
p:
_target_: builtins.float
_args_: ["inf"]
p: ${....p}
eps: ${....eps}

p:
_target_: builtins.float
_args_: ["inf"]
eps: ???
2 changes: 1 addition & 1 deletion mart/configs/attack/mask.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defaults:
- composer/perturber/projector: mask_range
- composer/perturber/projector/functions: mask
- enforcer: default
- enforcer/constraints: [mask, pixel_range]
Loading