▊ 1 引言
该论文出自牛津大学,主要是关于对抗训练的研究。已有研究表明,使用单步方法进行对抗训练会导致一种严重的过拟合现象(灾难性过拟合)。在该论文中,作者通过理论分析和实验验证,重新审视了对抗噪声和梯度剪切在单步对抗训练中的作用。
论文链接:https://arxiv.org/abs/2202.01181
▊ 2 预备知识
给定一个参数为 $\theta$ 的分类器 $f_\theta$,以及一个对抗扰动集合 $\Delta$。如果对于任意的对抗扰动 $\delta \in \Delta$,都有 $f_\theta(x+\delta) = f_\theta(x) = y$,则可以说 $f_\theta$ 在点 $x$ 处关于对抗扰动集合 $\Delta$ 是鲁棒的。对抗扰动集合的定义为:$\Delta = \{\delta : \|\delta\|_\infty \le \epsilon\}$。
▊ 3 N-FGSM对抗训练
证明:由 Jensen 不等式可知,当函数 $g$ 为凹函数时有 $\mathbb{E}[g(X)] \le g(\mathbb{E}[X])$,则有:
证毕。
定理1 令 $\delta_{\mathrm{FGSM}}$ 是 FGSM 方法生成的对抗扰动,$\delta_{\mathrm{RS\text{-}FGSM}}$ 是 RS-FGSM 方法生成的对抗扰动,$\delta_{\mathrm{N\text{-}FGSM}}$ 是 N-FGSM 方法生成的对抗扰动,则对于任意的 $\epsilon > 0$,有以下不等式成立:$\mathbb{E}\left[\|\delta_{\mathrm{RS\text{-}FGSM}}\|_2\right] \le \mathbb{E}\left[\|\delta_{\mathrm{FGSM}}\|_2\right] \le \mathbb{E}\left[\|\delta_{\mathrm{N\text{-}FGSM}}\|_2\right]$
证明:由引理1可知:
又因为
如果令超参数 $k = 2$(即噪声幅度为 $k\epsilon$)、步长 $\alpha = \epsilon$,则有:
▊ 4 实验结果
▊ 5 论文代码
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import os
import argparse
def get_args():
    """Parse command-line options for MNIST N-FGSM adversarial training."""
    p = argparse.ArgumentParser()
    p.add_argument('--batch-size', type=int, default=100)
    p.add_argument('--data-dir', type=str, default='mnist-data')
    p.add_argument('--epochs', type=int, default=10)
    p.add_argument('--epsilon', type=float, default=0.3)   # noise radius
    p.add_argument('--alpha', type=float, default=0.375)   # FGSM step size
    p.add_argument('--lr-max', type=float, default=5e-3)
    p.add_argument('--lr-type', default='cyclic')
    p.add_argument('--fname', type=str, default='mnist_model')
    p.add_argument('--seed', type=int, default=0)
    return p.parse_args()
class Flatten(nn.Module):
    """Collapse every dimension except the batch dimension into one."""

    def forward(self, x):
        # view (rather than reshape) preserves the original's contiguity
        # requirement and never copies data.
        return x.view(len(x), -1)
def mnist_net():
    """Build the small CNN used for the MNIST experiments.

    Two stride-2 convolutions halve the spatial size twice
    (28x28 -> 14x14 -> 7x7), then a two-layer MLP head emits
    10 class logits.
    """
    layers = [
        nn.Conv2d(1, 16, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 32, 4, stride=2, padding=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(32 * 7 * 7, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    ]
    return nn.Sequential(*layers)
class Attack_methods(object):
    """Single-step adversarial attacks used during adversarial training.

    Args:
        model: classifier under attack; must accept inputs shaped like X.
        X: batch of clean inputs.
        Y: ground-truth labels for X.
        epsilon: radius of the uniform noise used to initialise the attack.
        alpha: step size of the signed-gradient (FGSM) step.
    """

    def __init__(self, model, X, Y, epsilon, alpha):
        self.model = model
        self.X = X
        self.Y = Y
        # The original assigned self.epsilon twice; once is enough.
        self.epsilon = epsilon
        self.alpha = alpha

    def nfgsm(self):
        """N-FGSM: uniform noise init plus one FGSM step, with no projection.

        Following the N-FGSM paper, the result is NOT clipped back to the
        epsilon ball, so |delta| is only bounded by epsilon + alpha.
        """
        # Random start drawn uniformly from [-epsilon, epsilon].
        eta = torch.zeros_like(self.X).uniform_(-self.epsilon, self.epsilon)
        eta.requires_grad = True
        output = self.model(self.X + eta)
        loss = nn.CrossEntropyLoss()(output, self.Y)
        loss.backward()
        grad = eta.grad.detach()
        # One signed-gradient ascent step from the noisy start; assigning via
        # .data returns a tensor detached from the autograd graph.
        delta = torch.zeros_like(self.X)
        delta.data = eta + self.alpha * torch.sign(grad)
        return delta
class Adversarial_Trainings(object):
    """Epoch-based adversarial training driver.

    Each batch is perturbed with N-FGSM before the parameter update, so the
    network is trained on adversarial rather than clean examples. Requires a
    CUDA device (inputs and model are moved with .cuda() by the caller/loop).
    """

    def __init__(self, epochs, train_loader, model, opt, epsilon, alpha, iter_num, lr_max, lr_schedule,
                 fname, logger):
        # epochs: number of passes over train_loader
        # train_loader: yields (X, y) batches; moved to the GPU inside the loop
        # opt: optimizer; only its first param group's lr is updated each step
        # epsilon / alpha: noise radius and step size forwarded to the attack
        # iter_num: stored but never read in this class — dead parameter? TODO confirm
        # lr_max: stored but never read here (the schedule closure carries it)
        # lr_schedule: callable mapping a fractional epoch to a learning rate
        # fname: path the model state_dict is checkpointed to after each epoch
        self.epochs = epochs
        self.train_loader = train_loader
        self.model = model
        self.opt = opt
        self.epsilon = epsilon
        self.alpha = alpha
        self.iter_num = iter_num
        self.lr_max = lr_max
        self.lr_schedule = lr_schedule
        self.fname = fname
        self.logger = logger

    def fast_training(self):
        """Train for self.epochs epochs, logging per-epoch loss/accuracy."""
        for epoch in range(self.epochs):
            start_time = time.time()
            train_loss = 0
            train_acc = 0
            train_n = 0
            for i, (X, y) in enumerate(self.train_loader):
                X, y = X.cuda(), y.cuda()  # requires a CUDA device
                # Fractional-epoch position drives the (cyclic) LR schedule.
                lr = self.lr_schedule(epoch + (i + 1) / len(self.train_loader))
                self.opt.param_groups[0].update(lr=lr)
                # Generating adversarial example
                adversarial_attack = Attack_methods(self.model, X, y, self.epsilon, self.alpha)
                delta = adversarial_attack.nfgsm()
                # Update network parameters
                # clamp keeps the perturbed image inside the valid [0, 1] pixel range.
                output = self.model(torch.clamp(X + delta, 0, 1))
                loss = nn.CrossEntropyLoss()(output, y)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                # Running statistics, weighted by batch size.
                train_loss += loss.item() * y.size(0)
                train_acc += (output.max(1)[1] == y).sum().item()
                train_n += y.size(0)
            train_time = time.time()
            self.logger.info('%d \t %.1f \t %.4f \t %.4f \t %.4f', epoch, train_time - start_time, lr, train_loss/train_n, train_acc/train_n)
            # Checkpoint after every epoch (overwrites the previous file).
            torch.save(self.model.state_dict(), self.fname)
# Module-level logger; basicConfig routes DEBUG-and-above records to stderr
# with a timestamped message format.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='[%(asctime)s] - %(message)s',
    datefmt='%Y/%m/%d %H:%M:%S',
    level=logging.DEBUG)
def main():
    """Entry point: train the MNIST CNN with N-FGSM adversarial training."""
    args = get_args()
    logger.info(args)
    # Seed every RNG the script touches for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # BUG FIX: honour --data-dir instead of the hard-coded "mnist-data" path.
    mnist_train = datasets.MNIST(args.data_dir, train=True, download=True,
                                 transform=transforms.ToTensor())
    train_loader = DataLoader(mnist_train, batch_size=args.batch_size, shuffle=True)
    model = mnist_net().cuda()
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=args.lr_max)
    if args.lr_type == 'cyclic':
        # Triangular schedule: 0 -> lr_max at 40% of training -> 0 at the end.
        lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs],
                                          [0, args.lr_max, 0])[0]
    elif args.lr_type == 'flat':
        lr_schedule = lambda t: args.lr_max
    else:
        raise ValueError('Unknown lr_type')
    logger.info('Epoch \t Time \t LR \t \t Train Loss \t Train Acc')
    adversarial_training = Adversarial_Trainings(args.epochs, train_loader, model, opt,
                                                 args.epsilon, args.alpha, 40,
                                                 args.lr_max, lr_schedule, args.fname, logger)
    adversarial_training.fast_training()


if __name__ == "__main__":
    main()
END
联系客服