diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 66dd1a6c..0cc7add3 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -123,6 +123,9 @@ def training_step(self, batch, batch_idx): # Use CallWith to dispatch **outputs. gain = self.gain_fn(**outputs) + # Log gain of all examples as a metric for LR scheduler to monitor, and show gain on progress bar. + self.log("gain", gain.sum(), prog_bar=True) + # objective_fn is optional, because adversaries may never reach their objective. if self.objective_fn is not None: found = self.objective_fn(**outputs) @@ -131,13 +134,7 @@ def training_step(self, batch, batch_idx): if len(gain.shape) > 0: gain = gain[~found] - if len(gain.shape) > 0: - gain = gain.sum() - - # Log gain as a metric for LR scheduler to monitor, and show gain on progress bar. - self.log("gain", gain, prog_bar=True) - - return gain + return gain.sum() def configure_gradient_clipping( self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None