打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
【百战GAN】羡慕别人的美妆?那就用GAN复制粘贴过来
userphoto

2022.07.26 北京

关注

大家好,欢迎来到专栏《百战GAN》,在这个专栏里,我们会进行算法的核心思想讲解,代码的详解,模型的训练和测试等内容。

作者&编辑 | 言有三

本文资源与生成结果展示

本文篇幅:7000字

背景要求:会使用Python和Pytorch

附带资料:参考论文和项目

1 项目背景

美颜技术是祖国人民的刚需,当前的美颜技术已经从早期的美白,瘦脸隆鼻等技术发展到了更加复杂的应用,比如妆造迁移,就是将目标人脸面部完整的妆容迁移到自己的人脸图像,如下图a表示目标妆造图,b表示欲上妆的原图,c表示将目标妆迁移到原图上的效果。

要想实现如此复杂的上妆操作,需要使用GAN来进行风格迁移,本次我们来进行相关实战。

2 原理简介

最早期我们研究人脸妆造迁移算法,需要成对的妆造前后的图来进行模型训练,比如下面这样的。

但是化妆前后的图哪个妹子会公开给?怕是有点难呐,所以这限制了数据集的大小。

而基于GAN的框架中有一大类方法是无监督的,不需要成对的图就能实现两个风格(域)之间的转换,比如大家都熟悉的CycleGAN。

当然我们今天说的不是CycleGAN,而是BeautyGAN,整个框架示意图如下:

它采用了经典的图像翻译结构,生成器G包括两个输入,分别是无妆图Isrc、有妆图Iref,通过编码器(encoder)、若干个残差模块(residual blocks)、解码器(decoder)组成的生成器G 得到两个输出,分别是上妆图IBsrc、卸妆图IAref。

BeautyGAN使用了两个判别器DA和DB,其中DA用于区分真假无妆图,DB用于区分真假有妆图。

除了基本的GAN损失之外,BeautyGAN包含了3个重要的损失,分别是循环一致性损失Cycle consistency loss,感知Perceptual loss,妆造损失Makeup loss。前两者是全局损失,最后一个是局部损失。

(1) 循环一致性损失。为了消除迁移细节的瑕疵,将上妆图IBsrc和卸妆图IAref再次输入给G,重新执行一次卸妆和上妆,得到两张重建图Iresrc和卸妆图Ireref,此时通过循环损失(cycle consistency loss)约束一张图经过两次G变换后与对应的原始图相同。

(2) 感知损失。上妆和卸妆不能改变原始的人物身份信息,这可以通过基于VGG模型的Perceptual loss进行约束。

(3) 妆造损失。所谓的妆造损失,就是对人脸的各个区域进行分别化妆处理了,这是为了更加精确的控制局部区域的妆造效果。BeautyGAN训练了一个语义分割网络提取人脸不同区域的掩膜(mask),使得无妆图和有妆图在脸部、眼部、嘴部三个区域需满足妆造损失(makeup loss),妆造损失通过直方图匹配实现。

完整的BeautyGAN生成器损失如下

3 模型训练

接下来我们对整个工程的代码进行解读:

3.1 数据预处理

首先是数据集,我们这次任务需要准备好有妆的图和无妆的图,分别放置在不同的文件夹下面,与之同时,由于本模型的训练需要掩膜作为监督信息来计算不同部位的损失,所以整个目录如下:

├── images

   ├── makeup

   └── non-makeup

├── segs

   ├── makeup

   └── non-makeup


包括两个文件夹images和segs,分别是RGB图像文件夹和对应的分割掩膜文件夹,各自包括一个没有妆容的数据集non-makeup,一个有妆容的数据集makeup。Makeup其中的有效图片共2719张,non-makeup其中的有效图片共1115。

然后我们将其准备好,一一对应存入txt文件中,如下:

images/non-makeup/xfsy_0327.png segs/non-makeup/xfsy_0327.png

images/non-makeup/vSYYZ572.png segs/non-makeup/vSYYZ572.png

images/non-makeup/vSYYZ214.png segs/non-makeup/vSYYZ214.png

images/non-makeup/vSYYZ200.png segs/non-makeup/vSYYZ200.png


准备好数据之后我们需要完成数据的读取和预处理,我们定义好数据集类MAKEUP,实现__init__函数,preprocess函数,__getitem__函数:

class MAKEUP(Dataset):

    def __init__(self, image_path, transform, mode, transform_mask, cls_list):

        self.image_path = image_path ##图片目录

        self.transform = transform ##图片预处理接口

        self.mode = mode ##模式,为训练或者测试

        self.transform_mask = transform_mask ##掩膜预处理接口

        self.cls_list = cls_list ##分类类别,为妆造和非妆造两类

        self.cls_A = cls_list[0] ##第一类:makeup

        self.cls_B = cls_list[1] ##第二类:non-makeup

        ##设置训练相关的属性变量,包括txt文件路径,每一行的内容以及行数

        for cls in self.cls_list:

            setattr(self, "train_" + cls + "_list_path", os.path.join(self.image_path, "train_" + cls + ".txt"))

            setattr(self, "train_" + cls + "_lines", open(getattr(self, "train_" + cls + "_list_path"), 'r').readlines())

            setattr(self, "num_of_train_" + cls + "_data", len(getattr(self, "train_" + cls + "_lines")))

        ##设置测试相关的属性变量,包括txt文件路径,每一行的内容以及行数

        for cls in self.cls_list:

            setattr(self, "test_" + cls + "_list_path", os.path.join(self.image_path, "test_" + cls + ".txt"))

            setattr(self, "test_" + cls + "_lines", open(getattr(self, "test_" + cls + "_list_path"), 'r').readlines())

            setattr(self, "num_of_test_" + cls + "_data", len(getattr(self, "test_" + cls + "_lines")))

        self.preprocess() ##对数据文件进行预处理

    def preprocess(self):

        ## 对makeup类和non-makeup类的训练txt文件进行随机打乱操作,取得RGB和MASK文件路径

        for cls in self.cls_list:

            setattr(self, "train_" + cls + "_filenames", []) 

            setattr(self, "train_" + cls + "_mask_filenames", []) 

            lines = getattr(self, "train_" + cls + "_lines")

            random.shuffle(lines) ##对txt文件进行shuffle

            for i, line in enumerate(lines):

                splits = line.split()

                getattr(self, "train_" + cls + "_filenames").append(splits[0]) 

                getattr(self, "train_" + cls + "_mask_filenames").append(splits[1]) 

        for cls in self.cls_list:

            setattr(self, "test_" + cls + "_filenames", [])

            setattr(self, "test_" + cls + "_mask_filenames", [])

            lines = getattr(self, "test_" + cls + "_lines")

            for i, line in enumerate(lines):

                splits = line.split()

                getattr(self, "test_" + cls + "_filenames").append(splits[0])

                getattr(self, "test_" + cls + "_mask_filenames").append(splits[1])

    ## 从文件路径中获取RGB图片文件和MASK掩膜文件

    def __getitem__(self, index):

        ##训练模式,随机设置A类(makeup)和B类(non-makeup)的indexA和indexB,需要读入RGB图像和对应的掩膜图像

        if self.mode == 'train':

            index_A = random.randint(0, getattr(self, "num_of_train_" + self.cls_A + "_data") - 1)

            index_B = random.randint(0, getattr(self, "num_of_train_" + self.cls_B + "_data") - 1)

            image_A = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_A + "_filenames")[index_A])).convert("RGB") ##读取RGB

            image_B = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_B + "_filenames")[index_B])).convert("RGB") ##读取RGB

            mask_A = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_A + "_mask_filenames")[index_A])) ##读取MASK

            mask_B = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_B + "_mask_filenames")[index_B])) ##读取MASK

            ## 调用transform和transform_mask处理RGB图像和MASK图像

            return self.transform(image_A), self.transform(image_B), self.transform_mask(mask_A), self.transform_mask(mask_B)

        ##测试模式,使用输入的index变量从A类(makeup)和B类(non-makeup)中各自取出一张图做测试,不需要读入掩膜

        if self.mode in ['test', 'test_all']:

            image_A = Image.open(os.path.join(self.image_path, getattr(self, "test_" + self.cls_A + "_filenames")[index // getattr(self, 'num_of_test_' + self.cls_list[1] + '_data')])).convert("RGB")

            image_B = Image.open(os.path.join(self.image_path, getattr(self, "test_" + self.cls_B + "_filenames")[index % getattr(self, 'num_of_test_' + self.cls_list[1] + '_data')])).convert("RGB")

            ## 调用transform和transform_mask处理RGB图像和MASK图像

            return self.transform(image_A), self.transform(image_B)    


从上面的数据接口定义可以看出,__init__函数完成了路径相关的属性变量的设置,preprocess函数完成了训练和测试需要的文件路径变量的设置,__getitem__函数实现了图片数据的读取和预处理。

transform和transform_mask的定义如下:

transform = transforms.Compose([

transforms.Resize(config.img_size),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])

transform_mask = transforms.Compose([transforms.Resize(config.img_size, interpolation=PIL.Image.NEAREST),ToTensor])


从上面可以看出,transform_mask重载了ToTensor函数,具体细节如下:

def ToTensor(pic):

    if pic.mode == 'I': ##32位int格式

        img = torch.from_numpy(np.array(pic, np.int32, copy=False))

    elif pic.mode == 'I;16': ##16位int格式

        img = torch.from_numpy(np.array(pic, np.int16, copy=False))

    else: ##8位uint格式

        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    # PIL的图像类型: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK

    if pic.mode == 'YCbCr':

        nchannel = 3

    elif pic.mode == 'I;16':

        nchannel = 1

    else:

        nchannel = len(pic.mode)

    img = img.view(pic.size[1], pic.size[0], nchannel)

    # 从图像的HWC顺序转换为CHW顺序

    img = img.transpose(0,1).transpose(0,2).contiguous()

    if isinstance(img, torch.ByteTensor):

        return img.float()

    else:

        return img


重载的ToTensor支持输入的掩膜为int32,int16以及uint8多种格式,自此就完成了数据接口的定义。

3.2 模型定义

接下来我们看生成模型的定义,根据对两个输入是否使用两个完全独立的分支,可以包括Generator_makeup和Generator_branch,前者定义如下:

class Generator_makeup(nn.Module):

    def __init__(self, conv_dim=64, repeat_num=6, input_nc=6):

        super(Generator_makeup, self).__init__()

        layers = []

        layers.append(nn.Conv2d(input_nc, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))

        layers.append(nn.InstanceNorm2d(conv_dim, affine=True)) ##InstanceNorm层

        layers.append(nn.ReLU(inplace=True)) ##ReLU层

        # 两层下采样编码器模块,每一层输出通道是输入通道数的2倍

        curr_dim = conv_dim

        for i in range(2):

            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))

            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True))

            layers.append(nn.ReLU(inplace=True))

            curr_dim = curr_dim * 2

        # Bottleneck模块,共重复repeat_num次

        for i in range(repeat_num):

            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # 两层上采样解码器模块,每一层输出通道是输入通道数的0.5倍

        for i in range(2):

            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))

            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True))

            layers.append(nn.ReLU(inplace=True))

            curr_dim = curr_dim // 2

        self.main = nn.Sequential(*layers) ##主干通道输出

        ##两个分支的定义

        ##分支1定义,包含一个7*7的卷积层和一个tanh激活函数层

        layers_1 = []

        layers_1.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))

        layers_1.append(nn.Tanh())

        self.branch_1 = nn.Sequential(*layers_1)

        ##分支2定义,包含一个7*7的卷积层和一个tanh激活函数层

        layers_2 = []

        layers_2.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))

        layers_2.append(nn.Tanh())

        self.branch_2 = nn.Sequential(*layers_2)

    def forward(self, x, y):

        input_x = torch.cat((x, y), dim=1) ##图像和标签按照维度1,即通道进行拼接

        out = self.main(input_x) ##主干通道输出

        out_A = self.branch_1(out) ##分支1通道输出

        out_B = self.branch_2(out) ##分支2通道输出

        return out_A, out_B


Generator_branch定义类似,不做赘述。接下来我们再看判别器模型Discriminator的定义:

class Discriminator(nn.Module):

    ## Discriminator使用了PatchGAN,来自于Pix2pix模型

    def __init__(self, image_size=128, conv_dim=64, repeat_num=3):

        super(Discriminator, self).__init__()

        layers = []

        ## 第一个卷积层定义,输入为3通道图像,卷积核大小为4×4,步长为2

        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))

        layers.append(nn.LeakyReLU(0.01, inplace=True))

        ## 重复repeat_num个卷积层定义,每一个卷积核大小为4×4,步长等于2,输出通道为输入的两倍

        curr_dim = conv_dim

        for i in range(1, repeat_num):

            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))

            layers.append(nn.LeakyReLU(0.01, inplace=True))

            curr_dim = curr_dim * 2

        # 主干模型最后一个卷积层定义,卷积核大小为4×4,步长为1

        layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1))

        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = curr_dim *2

        self.main = nn.Sequential(*layers)

        # 输出卷积层定义,卷积核大小为4×4,步长为1

        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

    def forward(self, x):

        h = self.main(x)

        out_makeup = self.conv1(h) ##输出特征图

        return out_makeup.squeeze()


从上述模型可以看出,判别器使用了pix2pix模型中提出的PatchGAN。

由于要保持人脸的身份信息,BeautyGAN中从VGG模型的relu4_1层获得特征添加了Perceptual损失,VGG模型的结构比较简单,就不展示代码。

3.3 损失函数定义

接下来我们看损失函数的定义,GAN的基本损失可以选择使用LSGAN的MSE损失或者分类任务常用的BCE损失,定义在类GANLoss中,读者可以后面直接阅读。

重点要介绍的是直方图损失,主要是直方图和直方图匹配的计算:

def cal_hist(image):

    ##累积概率直方图的计算

    hists = []

    for i in range(0, 3):

        channel = image[i]

        channel = torch.from_numpy(channel)

        hist = torch.histc(channel, bins=256, min=0, max=256) ##统计各个bins像素数

        hist = hist.numpy()

        sum = hist.sum()

        pdf = [v / sum for v in hist] ##计算概率pdf

        for i in range(1, 256):

            pdf[i] = pdf[i - 1] + pdf[i] ##计算累积概率

        hists.append(pdf) ##得到直方图

    return hists

def cal_trans(ref, adj):

    ##直方图匹配转换函数的计算

    table = list(range(0, 256))

    for i in list(range(1, 256)):

        for j in list(range(1, 256)):

            if ref[i] >= adj[j - 1] and ref[i] <= adj[j]:

                table[i] = j

                break

    table[255] = 255

return table

def histogram_matching(dstImg, refImg, index):

    ##直方图匹配操作,使得输出图refImg与dstImg拥有一样的直方图分布

    ##index[0], index[1]: 输出图dstImg的x,y坐标

    ##index[2], index[3]: 输入图refImg的x,y坐标

    index = [x.cpu().numpy() for x in index]

    dstImg = dstImg.detach().cpu().numpy()

    refImg = refImg.detach().cpu().numpy()

    dst_align = [dstImg[i, index[0], index[1]] for i in range(0, 3)] #取出需要直方图匹配的坐标

    ref_align = [refImg[i, index[2], index[3]] for i in range(0, 3)] #取出需要直方图匹配的坐标

    hist_ref = cal_hist(ref_align) ##计算输入图直方图

    hist_dst = cal_hist(dst_align) ##计算输出图直方图

    tables = [cal_trans(hist_dst[i], hist_ref[i]) for i in range(0, 3)] #计算转换函数

    mid = copy.deepcopy(dst_align)

    for i in range(0, 3):

        for k in range(0, len(index[0])):

            dst_align[i][k] = tables[i][int(mid[i][k])] ##完成直方图匹配转换

    for i in range(0, 3):

        dstImg[i, index[0], index[1]] = dst_align[i] #将转换后的像素赋值回原图

    dstImg = torch.FloatTensor(dstImg).cuda()

    return dstImg


得到了直方图匹配的结果图后,直接对掩膜所在的像素使用L1损失,调用torch.nn.L1Loss即可,具体实现的时候,需要区分眼睛,嘴唇,皮肤等区域。

感知损失的定义则是从VGG的中间特征层取得特征向量后,直接调用欧式距离接口torch.nn.MSELoss()进行计算。

至此就完成了工程中核心代码的解读。

 

4 模型训练与测试

接下来我们对模型进行训练和测试。

4.1 模型训练

模型训练需要配置一些参数,首先是判别器和生成器的第一个卷积层的通道数以及其中重复模块的数量,配置如下:

config.g_conv_dim = 64 ##生成器第一个卷积层通道数

config.d_conv_dim = 64 ##判别器第一个卷积层通道数

config.g_repeat_num = 6 ##生成器重复的瓶颈模块数量

config.d_repeat_num = 3 ##判别器重复的模块数量


本项目使用了Adam优化器,配置学习率,动量项等相关参数配置如下:

config.G_LR = 2e-5 ##生成器学习率

config.D_LR = 2e-5 ##判别器学习率

config.beta1 = 0.5 ##一阶动量项

config.beta2 = 0.999 ##二阶动量项


最后是各项损失函数的权重:

config.lambda_A = 10.0 ##类别A的循环损失权重

config.lambda_B =10.0 ##类别B的循环损失权重

config.lambda_idt = 0.5 ##身份一致性损失权重

config.lambda_vgg = 5e-3 ##感知损失权重

config.lambda_his_lip = 1 ##嘴唇直方图权重

config.lambda_his_eye = 1 ##眼睛直方图权重

config.lambda_his_skin = 0.1##皮肤直方图权重


下图是训练了80个epoch后若干损失目标的曲线。

4.2 模型推理

训练完模型后接下来我们对模型进行测试,需要完成模型的载入,数据的预处理,结果后处理等操作

## 后处理函数

def de_norm(x):

    out = (x + 1) / 2

return out.clamp(0, 1)

## cpu和gpu变量切换函数

def to_var(x, requires_grad=True):

    if torch.cuda.is_available():

        x = x.cuda()

    if not requires_grad:

        return Variable(x, requires_grad=requires_grad)

    else:

        return Variable(x)

if __name__ == '__main__':

    G = net.Generator_branch(64,6) ##定义生成器

    snapshot_path = '80_G.pth' ##训练好的模型权重

    G.load_state_dict(torch.load(os.path.join(snapshot_path)) 

    G.eval() ##设置为推理模式

    results_dir = 'results'

    if not os.path.isdir(results_dir):

        os.makedirs(results_dest)

    imagedir = 'images' ##要上妆的内容图

    styledir = 'styles' ##妆造风格图

    resultdir = 'results' ##结果

    transform = transforms.Compose([

       transforms.Resize(256),

       transforms.ToTensor(),

       transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]) 

    with torch.no_grad():

        imagepaths = os.listdir(imagedir) ##遍历内容图

        stylepaths = os.listdir(styledir) ##遍历风格图

        for imagepath in imagepaths:

            for stylepath in stylepaths:

                image =                  Image.open(os.path.join(imagedir,imagepath)) 

                style = Image.open(os.path.join(styledir,stylepath)) ##

                image = transform(image) 

                image.requires_grad = False 

                image = image.unsqueeze(0) 

                style = transform(style) 

                style.requires_grad = False 

                style = style.unsqueeze(0) 

                fake_makeup,fake_nomakeup = G(to_var(image,requires_grad=False),to_var(style,requires_grad=False)) 

                image_list = [] ##结果存储变量

                image_list.append(image)

                image_list.append(style)

                image_list.append(fake_makeup)

                image_list.append(fake_nomakeup)

                image_list.append(rec_makeup)

                image_list.append(rec_nomakeup)

                image_list = torch.cat(image_list, dim=3)

                save_path = os.path.join('results', imagepath.split('.')[0]+stylepath.split('.')[0]+'fake.png')

                save_image(de_norm(image_list.data), save_path, nrow=1, padding=0, normalize=True)


实验结果如下图展示:

本文参考的文献如下:

[1] Li T, Qian R, Dong C, et al. Beautygan: Instance-level facial makeup transfer with deep generative adversarial network[C]//Proceedings of the 26th ACM international conference on Multimedia. 2018: 645-653.

本文视频讲解和代码,请大家移步:

【项目实战课】基于Pytorch的BeautyGAN人脸智能美妆实战

总结

本次我们使用BeautyGAN完成了图像妆造迁移任务,这是生成对抗网络在人脸美颜上的重要应用,欢迎大家以后持续关注《百战GAN专栏》。

如何系统性地学习生成对抗网络GAN

欢迎大家关注有三AI-CV秋季划GAN小组,可以系统性学习GAN相关的内容,包括GAN的基础理论,《深度学习之图像生成GAN:理论与实践篇》,《深度学习之图像翻译GAN:理论与实践篇》以及各类GAN任务的实战。

介绍如下:【CV秋季划】生成对抗网络GAN有哪些研究和应用,如何循序渐进地学习好(2022年言有三一对一辅导)?

转载文章请后台联系

侵权必究

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
轻松学Pytorch – 构建生成对抗网络
卷积降维与池化降维的对比分析
从零搭建Pytorch模型教程 | 搭建Transformer网络
基于深度学习框架pytorch搭建循环神经网络LSTM完成手写字体识别
PyPy 和 CPython 的性能比较测试
engine重构:新增order_amount函数
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服