打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
【百战GAN】定制属于二次元宅们的专属动漫头像,这款GAN正好!
userphoto

2022.07.26 北京

关注

大家好,欢迎来到专栏《百战GAN》,在这个专栏里,我们会进行GAN相关项目的核心思想讲解,代码的详解,模型的训练和测试等内容。

作者&编辑 | 言有三

本文资源与生成结果展示

本文篇幅:7000字

背景要求:会使用Python和Pytorch

附带资料:参考论文和项目

1 项目背景

如今二次元文化拥有数以亿计的群众基础,有三自己也是一个动漫宅,这些年醉心于国产动漫,在知乎写的第一篇文章还是给国产动漫打call。二次元宅们,给自己做一个专属动漫头像可好!作为有技术的动漫粉,我们当然不能满足于看看,有机会就要自己亲自参与创作一把。本次我们使用GAN来完成定制个人专属动漫头像的任务。

2 项目解读

本次我们要完成的任务就是从一张真人头像,变成高质量的动漫风格,并且要保证身份信息不被篡改,这样才能实现个性化的需求。

简单来说就是实现下面的转换过程:

左边是我们的真人人脸图,右边是二次元头像,这是一个风格化问题或者说图像翻译问题。

当前有CycleGAN等框架可以实现,但是效果是不行的,本次我们给大家介绍一个更合适做这个问题的框架,原理图如下:

上图就是框架的生成器和判别器,名为UGATIT,下面我们首先来解读一下该框架的主要特点:

UGATIT是一个基于GAN的无监督风格化模型,它包含了一个判别器和一个生成器,在生成器和判别器中都添加了注意力机制了保证模型的效果,具体实现就是全局和平均池化下的类激活图(Class Activation Map-CAM)。

左边是生成器,右边是判别器,首先我们看生成器:

输入两个域的图经过下采样,编码器提取得到特征图,然后对特征图通道应用注意力机制,学习输出各个通道的权重。这个注意力机制模型的目标是学习到那些能够区分源域和目标域区别的重要区域。

得到了每个通道的权重后,再应用AdaLIN层,输入生成器进行生成。AdaLIN层是Instance Normalization(简称IN)和Layer Normalization(简称LN)的结合。IN因为对各个图像特征图单独进行归一化,会保留较多的内容结构,LN与IN相比,使用了多个通道进行归一化,能够更好地获取全局特征。

两者的计算以及融合方式如下:

判别器的设计采用一个全局判别器加一个局部判别器,区别就在于全局判别器更深,达到了32倍的步长,全局判别器的感受野已经超过256×256。

 

3 模型训练

接下来我们来实现U-GAT-IT模型,剖析完整的工程代码。

3.1 数据预处理

首先是数据集的处理,我们使用作者开源的数据,https://github.com/nagadomi/lbpcascade_animeface。女性动漫头像图共3500张,其中3400张作为训练集trainA,100张作为测试集testA。真实的女性人脸肖像图也是3500张,其中3400张作为训练trainB,100张作为测试testB,目录结构如下:

我们可以使用pytorch的ImageFolder来完成数据集的读取,训练预处理函数和测试预处理函数如下:

## 训练预处理函数

train_transform = transforms.Compose([

            transforms.RandomHorizontalFlip(),

            transforms.Resize((self.img_size + 30, self.img_size+30)),

            transforms.RandomCrop(self.img_size),

            transforms.ToTensor(),

            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

        ])

## 测试预处理函数

test_transform = transforms.Compose([

            transforms.Resize((self.img_size, self.img_size)),

            transforms.ToTensor(),

            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

        ])

self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)

self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)

self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)

self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)

self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)

self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)

self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)

self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)


训练的时候将图像大小缩放到了self.img_size + 30,测试时大小缩放为self.img_size,使用了随机裁剪数据增强操作。

3.2 模型定义

接下来我们再查看模型的定义,首先是生成器的一些基本模块,包括残差网络模块和AdaILN模块:

## AdaILN网络层实现

class adaILN(nn.Module):

    def __init__(self, num_features, eps=1e-5):

        super(adaILN, self).__init__()

        self.eps = eps

        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) #参数

        self.rho.data.fill_(0.9)

def forward(self, input, gamma, beta):

        in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True) ##计算通道均值和方差

        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) ##通道归一化

        ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True) ##计算层均值和方差

        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) ##层归一化

        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln 

        ##得到AdaLIN归一化结果

        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)

        return out

## AdaLIN残差块,其中归一化层为AdaLIN,包含两个卷积层

class ResnetAdaILNBlock(nn.Module):

    def __init__(self, dim, use_bias):

        super(ResnetAdaILNBlock, self).__init__()

        self.pad1 = nn.ReflectionPad2d(1)

        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)

        self.norm1 = adaILN(dim)

        self.relu1 = nn.ReLU(True)

        self.pad2 = nn.ReflectionPad2d(1)

        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)

        self.norm2 = adaILN(dim)

    def forward(self, x, gamma, beta):

        out = self.pad1(x)

        out = self.conv1(out)

        out = self.norm1(out, gamma, beta)

        out = self.relu1(out)

        out = self.pad2(out)

        out = self.conv2(out)

        out = self.norm2(out, gamma, beta)

        return out + x


然后是整个生成器的定义:

class ResnetGenerator(nn.Module):

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, light=False):

        assert(n_blocks >= 0)

        super(ResnetGenerator, self).__init__()

        self.input_nc = input_nc ##输入通道

        self.output_nc = output_nc ##输出通道

        self.ngf = ngf ##基准通道数

        self.n_blocks = n_blocks ##残差块数量

        self.img_size = img_size ##输入图像大小

        self.light = light ##是否使用轻量级网络

        ## 第一个卷积层,卷积核大小为7,步长为1

        DownBlock = []

        DownBlock += [nn.ReflectionPad2d(3),

                      nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False),

                      nn.InstanceNorm2d(ngf),

                      nn.ReLU(True)]

        ## n_downsampling个下采样卷积网络层

        n_downsampling = 2

        for i in range(n_downsampling):

            mult = 2**I  ## 该层通道乘因子

            ## 下采样卷积层,输入通道数ngf*mult,输出ngf*mult*2

            DownBlock += [nn.ReflectionPad2d(1),

                          nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False),

                          nn.InstanceNorm2d(ngf * mult * 2),

                          nn.ReLU(True)]

        ## n_block个残差网络层,输入输出通道相等

        mult = 2**n_downsampling

        for i in range(n_blocks):

            DownBlock += [ResnetBlock(ngf * mult, use_bias=False)]

        # CAM激活图(Class Activation Map)

        self.gap_fc = nn.Linear(ngf * mult, 1, bias=False) ##均值池化全连接层

        self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False) ##最大池化全连接层

        self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True)

        self.relu = nn.ReLU(True)

    ## 参数和学习模块

        if self.light: ##如果使用轻量级模型,则包含两个全连接层,输入输出大小都是ngf*mul

            FC = [nn.Linear(ngf * mult, ngf * mult, bias=False),

                  nn.ReLU(True),

                  nn.Linear(ngf * mult, ngf * mult, bias=False),

                  nn.ReLU(True)]

        else: ##使用更大的模型,则包含两个全连接层,第一个输入大小为img_size // mult * img_size // mult * ngf * mult,输出为ngf * mult,

            FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False),

                  nn.ReLU(True),

                  nn.Linear(ngf * mult, ngf * mult, bias=False),

                  nn.ReLU(True)]

        self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False) ##得到

        self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False) ##得到

        # n_block个自适应残差瓶颈模块

        for i in range(n_blocks):

            setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False))

        # n_downsampling个上采样模块,与下采样模型对应

        UpBlock2 = []

        for i in range(n_downsampling):

            mult = 2**(n_downsampling - i)

            UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'),

                         nn.ReflectionPad2d(1),

                         nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False),

                         ILN(int(ngf * mult / 2)), ##使用ILN层

                         nn.ReLU(True)]

        UpBlock2 += [nn.ReflectionPad2d(3),

                     nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False),

                     nn.Tanh()]

        self.DownBlock = nn.Sequential(*DownBlock)

        self.FC = nn.Sequential(*FC)

        self.UpBlock2 = nn.Sequential(*UpBlock2)

    def forward(self, input):

        x = self.DownBlock(input) ##卷积特征图

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) ##自适应均值池化

        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))

        gap_weight = list(self.gap_fc.parameters())[0]

        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) ##自适应最大池化

        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))

        gmp_weight = list(self.gmp_fc.parameters())[0]

        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1) ##均值池化和最大池化CAM图拼接

        x = torch.cat([gap, gmp], 1) ##均值池化和最大池化特征图拼接

        x = self.relu(self.conv1x1(x)) ##1*1卷积变换

        heatmap = torch.sum(x, dim=1, keepdim=True) ##得到heatmap

        if self.light:

            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1) ##进行均值池化

            x_ = self.FC(x_.view(x_.shape[0], -1))

        else:

            x_ = self.FC(x.view(x.shape[0], -1))

        gamma, beta = self.gamma(x_), self.beta(x_)

        for i in range(self.n_blocks):

            x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta)

        out = self.UpBlock2(x)

        return out, cam_logit, heatmap ##返回特征图,CAM图以及heatmap图


判别器的定义则比较简单,其中AdaIN层的应用类似,如下:

class Discriminator(nn.Module):

    def __init__(self, input_nc, ndf=64, n_layers=5):

        super(Discriminator, self).__init__()

        model = [nn.ReflectionPad2d(1),

                 nn.utils.spectral_norm(

                 nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),

                 nn.LeakyReLU(0.2, True)]

        for i in range(1, n_layers - 2):

            mult = 2 ** (i - 1)

            model += [nn.ReflectionPad2d(1),

                      nn.utils.spectral_norm(

                      nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),

                      nn.LeakyReLU(0.2, True)]

        mult = 2 ** (n_layers - 2 - 1)

        model += [nn.ReflectionPad2d(1),

                  nn.utils.spectral_norm(

                  nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),

                  nn.LeakyReLU(0.2, True)]

        ## CAM(Class Activation Map)图

        mult = 2 ** (n_layers - 2)

        self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False)) ##均值池化全连接层

        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False)) ##最大池化全连接层

        self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)

        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.pad = nn.ReflectionPad2d(1)

        self.conv = nn.utils.spectral_norm(

            nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))

        self.model = nn.Sequential(*model)

    def forward(self, input):

        x = self.model(input) ##卷积特征图

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) ##自适应均值池化

        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))

        gap_weight = list(self.gap_fc.parameters())[0]

        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) ##自适应最大池化

        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))

        gmp_weight = list(self.gmp_fc.parameters())[0]

        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1) ##均值池化和最大池化拼接

        x = torch.cat([gap, gmp], 1)

        x = self.leaky_relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True) ##得到heatmap

        x = self.pad(x)

        out = self.conv(x)

        return out, cam_logit, heatmap ##返回特征图,CAM图以及heatmap图


3.3 优化目标

首先我们看生成器的损失定义,令fake_GA_logit,fake_GB_logit表示生成器A和B的输出,fake_A2B2A和fake_B2A2B表示A和B经过一次完整的循环后的输出,fake_A2A和fake_B2B表示将A输入B到A方向的生成器输出,以及将B输入A到B方向的生成器的输出,所有的生成器损失定义如下:

G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device)) ##全局生成器A损失

G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device)) #全局CAM生成器A损失

G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device)) #局部生成器A损失

G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device)) #局部CAM生成器A损失

G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device)) #全局生成器B损失

G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device)) #全局CAM生成器B损失

G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device)) 局部生成器B损失

G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device)) 局部CAM生成器B损失

G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A) ##A的循环损失

G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B) ##B的循环损失

G_identity_loss_A = self.L1_loss(fake_A2A, real_A) ##A的身份保持损失

G_identity_loss_B = self.L1_loss(fake_B2B, real_B) ##B的身份保持损失

## A的CAM损失

G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))

## B的CAM损失

G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))

## 生成器A的损失

G_loss_A =  self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A

## 生成器B的损失

G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

# 总的损失

Generator_loss = G_loss_A + G_loss_B


接下来我们看判别器的损失定义,如下:

D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device)) #全局判别器A损失

D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device)) #全局CAM判别器A损失

D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device)) #局部判别器A损失

D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) + self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device)) #局部CAM判别器A损失

D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device)) #全局判别器B损失

D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device)) #全局CAM判别器B损失

D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device)) #局部判别器B损失

D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) + self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device)) #局部CAM判别器B损失

## 判别器A损失

D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)

## 判别器B损失

D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)

## 判别器损失

Discriminator_loss = D_loss_A + D_loss_B


 

4 模型训练与测试

接下来我们对模型进行训练和测试,主要包括模型结构参数,训练优化参数,损失权重。

4.1 模型训练

结构相关的参数主要是输入图像大小,网络通道数的配置,残差块的数目。

ch=32, 即通道单元数为32。

n_res=4, 即生成器中的残差模块数量为4。

n_dis=4, 即生成器中的残差模块数量为4。

img_size=256,即训练图像大小为256。


优化相关的参数主要是优化器,学习率,损失的权重。本次训练使用了Adam,一阶动量项系数为0.5,一阶动量项系数为0.999,固定学习率大小为0.0001。

损失权重为:

weight_decay==0.0001,权重正则项参数。

adv_weight=1,GAN的损失权重。

cycle_weight=10,CycleGAN损失权重。

identity_weight=10,身份一致性的损失权重。

cam_weight=1000,CAM的损失权重。


下图是训练的中间结果。

从上到下分别表示real_B,fake_B2B_heatmap,fake_B2B,fake_B2A_heatmap,fake_B2A,fake_B2A2B_heatmap,fake_B2A2B。

4.2 模型推理

训练完模型后接下来我们对模型进行测试,需要完成模型的载入,数据的预处理,结果后处理等操作。

## 后处理函数

def denorm(x):

    return x * 0.5 + 0.5

## tensor到numpy变量转换

def tensor2numpy(x):

    return x.detach().cpu().numpy().transpose(1,2,0)

## RGB到BGR转换

def RGB2BGR(x):

return cv2.cvtColor(x, cv2.COLOR_RGB2BGR)

if __name__ == '__main__':

    modelpath = 'selfie2anime_params_latest.pt'

    input_nc = 3 ##生成器输入通道

    output_nc = 3 ##判别器输出通道

    ch = 32 ##基准通道数

    n_res = 4 ##生成器残差块数量

    img_size = 256 ##测试图像大小

    device = 'cpu' ##CPU模式

    result_dir = 'results' ##结果文件夹

    input_dir = 'images' ##测试文件夹

    genA2B = ResnetGenerator(input_nc=input_nc, output_nc=output_nc, ngf=ch, n_blocks=n_res, img_size=img_size, light=True).to(device) ##生成器结构

    params = torch.load(modelpath,map_location='cpu') ##载入参数

    genA2B.load_state_dict(params['genA2B']) ##载入生成器参数

    genA2B.eval() ##测试模式

    if not os.path.isdir(result_dir):

        os.makedirs(result_dir)

    imagepaths = os.listdir(input_dir) ##遍历文件夹

    transform = transforms.Compose([transforms.Resize((img_size, img_size)),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) ##预处理函数

    with torch.no_grad():

        for imagepath in imagepaths:

            image = Image.open(os.path.join(input_dir,imagepath)) ##读取图片

            image = image.convert('RGB') ##转换为RGB

            image = transform(image) ##预处理

            image.requires_grad = False

            image = image.unsqueeze(0).to(device) ##维度扩充

            print(type(image))

            fake_A2B, _,_ = genA2B(image) ##前向预测

            result = RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))) ##得到BGR格式结果

            cv2.imwrite(os.path.join(result_dir,imagepath), result * 255.0)


实验结果如下图展示:

本文参考的文献如下:

[1] Kim J, Kim M, Kang H, et al. U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation[C]. international conference on learning representations, 2020.

本文视频讲解和代码,请大家移步:

【项目实战课】基于Pytorch的UGATIT人脸动漫风格化实战

总结

本次我们使用U-GAT-IT完成了动漫头像生成的任务,这是生成对抗网络非常有意思的一个应用,欢迎大家以后持续关注《百战GAN专栏》。

如何系统性地学习生成对抗网络GAN

欢迎大家关注有三AI-CV秋季划GAN小组,可以系统性学习GAN相关的内容,包括GAN的基础理论,《深度学习之图像生成GAN:理论与实践篇》,《深度学习之图像翻译GAN:理论与实践篇》以及各类GAN任务的实战。

介绍如下:【CV秋季划】生成对抗网络GAN有哪些研究和应用,如何循序渐进地学习好(2022年言有三一对一辅导)?

转载文章请后台联系

侵权必究

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
如何用 PyTorch 构建 GAN?
chatgpt中英文写作翻译提示词(4)
全新Backbone | ReXNet在CV全任务以超低FLOPs达到SOTA水平
基于微软开源深度学习算法,用 Python 实现图像和视频修复
半小时学会 PyTorch Hook
搞懂Vision Transformer 原理和代码,看这篇技术综述就够了(三)
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服