diff --git a/aux_bn.py b/aux_bn.py index 40207c7..f946da1 100644 --- a/aux_bn.py +++ b/aux_bn.py @@ -38,11 +38,19 @@ def to_status(m, status): """ change the status of batch norm layer status can be 'clean', 'adv' or 'mix' + + Three statuses, meaning the training samples in this batch are: + - clean: all clean samples + - adv: all adversarial samples + - mix: *1st* half are *adversarial* samples, and the *2nd* half are *clean* samples """ if hasattr(m, 'batch_type'): m.batch_type = status to_clean_status = partial(to_status, status='clean') +'''all clean examples''' to_adv_status = partial(to_status, status='adv') +'''all adversarial samples''' to_mix_status = partial(to_status, status='mix') +'''*1st* half are *adversarial* samples, and the *2nd* half are *clean* samples''' diff --git a/imagenet.py b/imagenet.py index beb0d44..6c9bd17 100644 --- a/imagenet.py +++ b/imagenet.py @@ -38,7 +38,9 @@ import models.imagenet as customized_models from models.AdaIN import StyleTransfer -from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig +from progress.bar import Bar +from utils import Logger, AverageMeter, accuracy, mkdir_p, savefig +from utils.eval import accuracy_and_perclass from utils.imagenet_a import indices_in_1k from tensorboardX import SummaryWriter @@ -118,6 +120,8 @@ parser.add_argument('--manualSeed', type=int, help='manual seed') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') +parser.add_argument('-et', '--evaluate-train', dest='evaluate_train', action='store_true', + help='evaluate model on train set') # Device options parser.add_argument('--gpu-id', default='7', type=str, help='id(s) for CUDA_VISIBLE_DEVICES') @@ -208,6 +212,7 @@ def main(): num_workers=args.workers, pin_memory=True) if not args.evaluate_imagenet_c else None # create model + # only works for resnet or resnext if args.arch.startswith('resnext'): norm_layer = MixBatchNorm2d if args.mixbn else None model 
= models.__dict__[args.arch]( @@ -264,6 +269,7 @@ def main(): break if args.mixbn and not already_mixbn: + # update the model checkpoint with mixbn to_merge = {} for key in checkpoint['state_dict']: if 'bn' in key: @@ -301,9 +307,15 @@ def main(): if args.evaluate: print('\nEvaluation only') - test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda, args.FGSM) + test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda, args.FGSM, args.num_classes) print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) return + + if args.evaluate_train: + print('\nEvaluation on training set only') + test_loss, test_acc = test(train_loader, model, criterion, start_epoch, use_cuda, args.FGSM, args.num_classes) + print(' Train Loss: %.8f, Train Acc: %.2f' % (test_loss, test_acc)) + return if args.evaluate_imagenet_c: print("Evaluate ImageNet C") @@ -331,7 +343,7 @@ def main(): start_lr=args.warm_lr) if args.warm > 0 else None for epoch in range(start_epoch, args.epochs): if epoch >= args.warm and args.lr_schedule == 'step': - adjust_learning_rate(optimizer, epoch, args) + adjust_learning_rate(optimizer, epoch, args, state) print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[-1]['lr'])) @@ -339,7 +351,7 @@ def main(): label_mix_alpha=1 - args.label_gamma) if args.style else None train_func = partial(train, train_loader=train_loader, model=model, criterion=criterion, optimizer=optimizer, epoch=epoch, use_cuda=use_cuda, - warmup_scheduler=warmup_scheduler, mixbn=args.mixbn, + warmup_scheduler=warmup_scheduler, state=state, mixbn=args.mixbn, style_transfer=style_transfer, writer=writer) if args.mixbn: model.apply(to_mix_status) @@ -398,8 +410,13 @@ def img_size_scheduler(batch_idx, epoch, schedule): return ret_size, ret_size -def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_scheduler, mixbn=False, +def train(train_loader, model, criterion, optimizer, epoch, use_cuda, 
warmup_scheduler, state, mixbn=False, style_transfer=None, writer=None): + ''' + Train the model for a single epoch + + Core of shape-texture debiased training happens here + ''' # switch to train mode model.train() @@ -422,7 +439,7 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch if epoch < args.warm: warmup_scheduler.step() elif args.lr_schedule == 'cos': - adjust_learning_rate(optimizer, epoch, args, batch=batch_idx, nBatch=len(train_loader)) + adjust_learning_rate(optimizer, epoch, args, state, batch=batch_idx, nBatch=len(train_loader)) # measure data loading time data_time.update(time.time() - end) @@ -433,10 +450,16 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch if style_transfer is not None: if args.multi_grid: + ''' + Multigrid training + https://arxiv.org/pdf/1912.00998.pdf + Help with faster convergence. Might improve performance for small models + ''' img_size = img_size_scheduler(batch_idx, epoch, args.schedule) resized_inputs = torch.nn.functional.interpolate(inputs, size=img_size) inputs_aux, targets_aux = style_transfer(resized_inputs, targets, replace=True) inputs = (inputs, inputs_aux) + # get the new set of targets that includes the labels of the style-transferred images if len(targets_aux) == 3: n = targets.size(0) targets = (torch.cat([targets, targets_aux[0]]), @@ -462,16 +485,20 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch elif args.cutmix: inputs, targets = cutmix_data(inputs, targets, beta=args.cutmix, half=False) + # normalize AFTER style transfer if not args.multi_grid: inputs = (inputs - MEAN[:, None, None]) / STD[:, None, None] + # If using mixbn, model should be in mixed status here because inputs contain both the original and stylized images outputs = model(inputs) else: inputs = ((inputs[0] - MEAN[:, None, None]) / STD[:, None, None], (inputs[1] - MEAN[:, None, None]) / STD[:, None, None]) - if args.mixbn: + + # Run the batch on
model. Since the original and stylized images are in two separate variable in this case, the first one is considered all clean samples, and the second one is considered all adversarial samples + if mixbn: model.apply(to_clean_status) outputs1 = model(inputs[0]) - if args.mixbn: + if mixbn: model.apply(to_adv_status) outputs2 = model(inputs[1]) outputs = torch.cat([outputs1, outputs2]) @@ -488,6 +515,7 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch top1.update(prec1.item(), outputs.size(0)) top5.update(prec5.item(), outputs.size(0)) + # Compute main and aux loss/metrics separately when using mixbn if mixbn: with torch.no_grad(): batch_size = outputs.size(0) @@ -537,7 +565,7 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch return losses.avg, top1.avg -def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False): +def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False, num_classes=None): global best_acc batch_time = AverageMeter() @@ -545,6 +573,9 @@ def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False): losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() + if num_classes is not None: + total_num_per_class = torch.zeros(num_classes).int() + total_correct_per_class = torch.zeros(num_classes).int() # switch to evaluate mode model.eval() @@ -580,7 +611,12 @@ def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False): loss = criterion(outputs, targets).mean() # measure accuracy and record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) + if num_classes is not None: + prec1, prec5, num_per_class, correct_per_class = accuracy_and_perclass(outputs.data, targets.data, topk=(1, 5), numclasses=num_classes) + total_num_per_class += num_per_class + total_correct_per_class += correct_per_class + else: + prec1, prec5= accuracy(outputs.data, targets.data, topk=(1, 5)) losses.update(loss.item(), inputs.size(0)) 
top1.update(prec1.item(), inputs.size(0)) top5.update(prec5.item(), inputs.size(0)) @@ -603,6 +639,11 @@ def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False): ) bar.next() bar.finish() + if num_classes is not None: + accuracy_per_class = total_correct_per_class / total_num_per_class + print(f"class\tacc") + for i, acc in enumerate(accuracy_per_class.tolist()): + print(f"{i}\t{acc}") return (losses.avg, top1.avg) diff --git a/models/AdaIN.py b/models/AdaIN.py index 5d699f2..b256a82 100644 --- a/models/AdaIN.py +++ b/models/AdaIN.py @@ -225,14 +225,17 @@ def __call__(self, image, label, alpha, replace=True, label_mix_alpha=0): n, c, h, w = image.shape content = image.detach() random_index = torch.randperm(n) - style = image.detach()[random_index] - label_style = label.detach()[random_index] + style = image.detach()[random_index] # Style is created from randomly permuting a batch of images. Thus each (content[i], style[i]) pair is from the same batch and essentially could be the same image + label_style = label.detach()[random_index] # need to also interpolate the label with torch.no_grad(): + # Run AdaIN and get style-transferred image stylized_image = self.style_transfer(content, style, alpha) if replace: + # In replace mode, the original image is not kept. return stylized_image, (label, label_style, torch.ones(n).cuda() * label_mix_alpha) else: + # Return both original image and stylized image.
Thus, label and label style is copied twice label1 = torch.cat([label, label]) label2 = torch.cat([label_style, label_style]) label_weight = torch.cat([torch.zeros(n), torch.ones(n) * label_mix_alpha]).cuda() diff --git a/utils/eval.py b/utils/eval.py index 5051350..60a47a3 100755 --- a/utils/eval.py +++ b/utils/eval.py @@ -1,5 +1,5 @@ from __future__ import print_function, absolute_import - +import torch __all__ = ['accuracy'] def accuracy(output, target, topk=(1,)): @@ -13,6 +13,31 @@ def accuracy(output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) + correct_k = correct[:k].reshape(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) - return res \ No newline at end of file + return res + +def accuracy_and_perclass(output, target, topk=(1,), numclasses=200): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + + # top1 per-class stats + num = target.bincount() + wt = (pred[0] == target).int() + correct = target.bincount(wt) + num_per_class = torch.zeros(numclasses).int() + correct_per_class = torch.zeros(numclasses).int() + num_per_class[:len(num)] = num + correct_per_class[:len(correct)] = correct + + return *res, num_per_class, correct_per_class \ No newline at end of file diff --git a/utils/lr_scheduler.py b/utils/lr_scheduler.py index 4b89653..375eb86 100644 --- a/utils/lr_scheduler.py +++ b/utils/lr_scheduler.py @@ -1,8 +1,7 @@ from torch.optim.lr_scheduler import _LRScheduler +import math - -def adjust_learning_rate(optimizer, epoch, args, batch=None, nBatch=None): - global state +def adjust_learning_rate(optimizer, epoch, args, state, batch=None, nBatch=None): if args.lr_schedule == 
'cos': T_total = args.epochs * nBatch T_cur = (epoch % args.epochs) * nBatch + batch