Loading...
墨滴

金立

2021/11/02  阅读:51  主题:极客黑

GAN生成对抗网络

首发: https://zhuanlan.zhihu.com/p/423930735

写在前面

之前看到视网膜血管分割的,有很多用GAN做的。然后就主要是对GAN的一些基础知识的学习,大概了解一下gan的思想。主要就是四篇文章 GAN ,CGAN, pix2pix和 cyclegan

主要还是看代码理解的,代码也加了一定量注释

一、【GAN】

Paper: Goodfellow, Ian, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. "Generative adversarial nets." Advances in neural information processing systems 27 (2014).

remark:GAN的提出

cited by: 34267

code:https://github.com/eriklindernoren/PyTorch-GAN

1. 概要

  1. 提出了一个基于对抗的新生成式模型, 它由一个生成器和一个判别器组成
  2. 生成器的目标是学习到样本的数据分布, 从而能生成样本欺骗判别器; 判别器的目标是判断输入样本是生成/真实的概率。对于任意的生成器和判别器, 都存在一个独特的全局最优解
  3. 这篇文章中,生成器和判别器都由多层感知机实现

判别式模型 • 模型学习的是条件概率分布P(Y|X) • 任务是从属性X(特征) 预测标记Y(类别) 生成式模型 • 模型学习的是联合概率分布P(X,Y) • 任务是得到属性为X且类别为Y时的联合概率

gan
gan

2. 基本架构

In this article, we explore the special case when the generative model generates samples by passing random noise through a multilayer perceptron, and the discriminative model is also a multilayer perceptron. We refer to this special case as adversarial nets.

In this case, we can train both models using only the highly successful backpropagation and dropout algorithms [16] and sample from the generative model using only forward propagation. No approximate inference or Markov chains are necessary

随机噪声作为生成器的输入,模型是一个多层感知机的结构,判别器同样也是多层感知机。

3. VAE

Variational Auto-Encoder

编码器把数据编码成mean vector和standard deviation vector

采样从构建的高斯分布中采样得到latent vector

解码器从latent vector生成数据

一个encoder-decoder的架构

VAE
VAE
generation_loss = mean(square(generated_image - real_image)) 

latent_loss =KL-Divergence(latent_variable,unit_gaussian) #构造的高斯分布和单位高斯分布的KL散度

loss=generation_loss +latent_loss

4. GAN

生成器 G:多层感知机, ReLU, Sigmoid 判别器 D:多层感知机, Maxout, Dropout

image-20210831102407463
image-20210831102407463

5. Value function

data: 真实数据 D: 判别器, 输出值为 [0, 1], 代表输入来自真实数据的概率 z: 随机噪声 G: 生成器, 输出为合成数据

D的目标, 是最大化价值函数V 对数函数log在底数大于1时, 为单调递增函数 最大化V, 就是最大化 D(x) 和 1-D(G(z)) 对于任意的x, 都有 D(x) = 1 对于任意的z, 都有 D(G(z))) = 0

G的目标, 是针对特定的D, 去最小化价值函数V 最小化V, 就是最小化 D(x) 和 1-D(G(z)) 对于任意的z, 都有 D(G(z))) = 1

6. 训练流程

• 训练k次判别器( 论文实验中k=1) • 训练1次生成器

image-20210831110722408
image-20210831110722408

7.代码

生成器和判别器均为全连接层

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128256),
            *block(256512),
            *block(5121024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)  #28*28
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(2561),
            nn.Sigmoid(),   #0-1
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

训练过程

# Loss function
adversarial_loss = torch.nn.BCELoss()  #对抗损失以交叉熵损失的形式出现

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers,生成两个优化器,分别针对生成器和判别器的参数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# ----------
#  Training 开始训练
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths  对抗时候的GT, 这里1代表真,0代表假
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  # 1
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)   # 0

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))    ##这个是真的图片

        # -----------------
        #  Train Generator    开始训练生成器(目的是使得生成的图片能以假乱真)
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(01, (imgs.shape[0], opt.latent_dim))))  #生成随机向量作为输入

        # Generate a batch of images
        gen_imgs = generator(z)  ##随机向量进入生成器  1*100 --> 1*5xx  _> 28*28

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)  # 希望能误导判别器,使得判别器认为该图是真

        g_loss.backward()    ##反向传播
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator  开始训练判别器(目的是希望能正确判断真假图片)
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid) #对于一张本来就是真的图片希望他能判断为真
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) #对于一张假的图片希望他能判断为假
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()
image-20211008151729538
image-20211008151729538

二、【CGAN】

Paper: Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014).

remark:Conditional GAN, 可以生成指定条件下的图像

cited by: 5877

code:https://github.com/caffeinism/cDC-GAN-pytorch

1. 概要

在原模型基础上,会输入额外的数据作为条件,对生成器和判别器都进行了修改。例如,在MNIST数据集上, 新模型可以生成以数字类别标签为条件的手写数字图像,模型还可以用来做多模态学习,可以生成输入图像相关的描述标签 。

2.多模态学习和图像描述

多模态学习

多模态学习
多模态学习

图像标记: 用词语对图像中不同内容进行多维度表述 图像描述: 把一幅图片翻译为一段描述文字,获取图像的标记词语,理解图像标记之间的关系,生成人类可读的句子

3. 网络结构

CGAN
CGAN

与原始GAN不同的一点就在于,加入了Y,作为一个条件输入,在生成器和判别器中的y是同一个y。这里是嵌入后concat到一起。

GAN的价值函数:

CGAN的价值函数:

两者主要的区别就在于有一个条件的形式:x|y

4. 实验

单模态任务【手写数字识别】

单模态任务
单模态任务

y一个十维的向量,代表了数字0-9,它作为条件信息输入生成器和判别器。

训练复杂:采用随机梯度下降,使用初始值为0.5的初始动量, 并逐渐增加到0.7。在生成器和判别器上都使用概率为0.5的Dropout。 使用验证集上的最大对数似然估计作为停止点 。

image-20210831215954261
image-20210831215954261

多模态任务【有图像有文本】

多模态任务
多模态任务

左边是在ImageNet上训练一个类似AlexNet的图像分类模型, 使用其最后一个全连接层的输出来提取图像特征。 右边是使用YFCC100M数据集, 训练一个词向量长度为200的 skip-gram模型(word2vector)。

多模态任务
多模态任务

基于MIR Flickr 25,000数据集, 使用上面的图像特征提取模型和skip-gram模型分别提取图像和标签特征。把提取的图像作为条件输入, 标签特征作为输出来训练CGAN。在训练CGAN时,不修改图像特征提取模型和skip-gram模型。在训练集内具有多个标签的图像, 每个标签训练一次。为每个条件输入生成100个样本, 对于每个样本输出的词向量,找到距离最近的20个单词。 在100*20个单词中,选择前10个最常见的单词 。

image-20210831214622412
image-20210831214622412

5.代码

生成器和判别器也都是多层感知机,区别在于需要concat一个条件信息

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
            *block(128256),
            *block(256512),
            *block(5121024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):  # 1x100 1x1(0-9)
        # Concatenate label embedding and image to produce input  1x10
        gen_input = torch.cat((self.label_emb(labels), noise), -1)  #64*100 emb 64  -> 64*110
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(5121),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity

训练过程

# Loss functions
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()


# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))



def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(01, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))   #是数字,代表上面的img是什么数字

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input  生成随机向量和条件作为输入
        z = Variable(FloatTensor(np.random.normal(01, (batch_size, opt.latent_dim))))   #64*100
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))   #64
       

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)  #64*1*32*32

        # Loss measures generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)  # 64*1
        g_loss = adversarial_loss(validity, valid)  ##希望是真

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()
image-20211008151906105
image-20211008151906105
number: 0
embedding: tensor([ 1.4811-1.7132,  0.9492-0.6573,  0.1672,  0.9248-0.7020,  0.9547,
         1.6032-2.3548], device='cuda:0')
number: 1
embedding: tensor([-1.5507-1.3809,  0.2168-0.8488,  2.8917,  0.2372,  0.1957-0.0985,
         0.1174,  0.3064], device='cuda:0')
number: 2
embedding: tensor([-0.6030,  1.5416,  1.3762,  0.6906-1.1198-0.1106,  0.3889-0.1090,
        -1.4397-0.3295], device='cuda:0')
number: 3
embedding: tensor([-0.6867,  0.7198-0.8587-0.7496-1.1232,  0.6080,  0.4279-1.7805,
        -0.7185-0.1199], device='cuda:0')
number: 4
embedding: tensor([ 0.7740-0.5720,  1.5042-1.2249-0.9243-0.2284,  1.9632,  0.7334,
         0.7445,  0.6742], device='cuda:0')
number: 5
embedding: tensor([-0.4957-1.7681,  0.8678-1.0104,  0.1090-2.6871-0.2063-1.4525,
         0.5074-0.4610], device='cuda:0')
number: 6
embedding: tensor([ 1.2048,  0.0191,  0.2202-0.2698,  1.0157,  0.0824,  1.4816,  0.5010,
        -2.1109-0.3817], device='cuda:0')
number: 7
embedding: tensor([-0.5756-0.3427,  0.7223,  1.2729,  0.3910,  1.4034-0.6651-0.1426,
         0.0193,  1.5144], device='cuda:0')
number: 8
embedding: tensor([ 0.8431,  0.6971-0.9441-1.3121,  1.0290,  0.1637-0.1292,  1.9960,
         0.0947,  1.1863], device='cuda:0')
number: 9
embedding: tensor([ 1.5576-1.5944-1.4717,  1.1670-1.0608-0.0957-0.6714-0.3822,
         2.3542,  0.6662], device='cuda:0')

三、【pix2pix】

Paper: Isola, Phillip, Jun-Yan Zhu, Tinghui Zhou, and Alexei A. Efros. "Image-to-image translation with conditional adversarial networks." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1125-1134. 2017.

remark:图像翻译

cited by: 10095

code:https://phillipi.github.io/pix2pix/?utm_source=catalyzex.com

1. 概要

研究条件生成式对抗网络在图像翻译任务中的通用解决方案。网络不仅学习从输入图像到输出图像的映射(生成器),还学习了用于训练该映射的损失函数(判别器)。把这种方法可以有效应用在图像合成,图像上色等多种图像翻译任务中。表明可以在不手工设计损失函数的情况下,也能获得理想的结果。

2. 网络结构

生成器是一个UNet。

unet
unet

判别器是PatchGAN

作者认为像素级的l1 loss能很好的捕捉到图像中的低频信息,GAN的判别器只需要关注高频信息。所以把图像切成 N*N 的patch,其中N显著小于图像尺寸。假设在大于N时,像素之间是相互独立的,从而可以把图像建模成马尔科夫随机场。把判别器在所有patch上的推断结果,求平均来作为最终输出。可以把PatchGAN理解为对图像纹理/style损失的计算。

目标函数

总的目标是:

它由一个cgan损失和L1损失加权相加而成。其中cgan的损失为:

pix2pix
pix2pix

这里的x是条件,也就是一个分割图。y是通过生成器生成的实景图。在判别器中,y可以是前面生成器的输出,也可以是GT。patchgan的体现就是,最后输出时,(16,16,1)中的每一个像素点,都代表着原图中的一个16x16的patch。

criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

        # Model inputs
        real_A = Variable(batch["B"].type(Tensor))   #真实的分割图 1*3*256*256
        real_B = Variable(batch["A"].type(Tensor))   #真实的建筑图 1*3*256*256

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False#全1(1*1*16*16)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False#全0(1*1*16*16)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # GAN loss
        fake_B = generator(real_A)   #先通过真的分割图生成假的建筑图 1*3*256*256
        pred_fake = discriminator(fake_B, real_A)  #判别一个这个假的建筑图 1*1*16*16
        loss_GAN = criterion_GAN(pred_fake, valid) #希望被判别器识别错误
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B) #计算与真实建筑图之间的差距 1*3*256*256

        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel 

        loss_G.backward()

        optimizer_G.step()
        
      #先训练生成器,对于生成器来讲,有两个损失GAN loss(由MSE实现)和Pixel-wise loss。其中GAN loss就是希望生成器生成的假图片逼近真的。Pixel-wise loss就是生成图和label的L1—loss。


        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()
   ##对于判别器来讲,一方面希望能将真实的图片识别为真。另外一方面,希望将假的图片识别为假,两个平均求和


        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

image-20211008152002595
image-20211008152002595

四、【CycleGAN】

Paper: Zhu, Jun-Yan, Taesung Park, Phillip Isola, and Alexei A. Efros. "Unpaired image-to-image translation using cycle-consistent adversarial networks." In Proceedings of the IEEE international conference on computer vision, pp. 2223-2232. 2017.

remark:图像翻译 无监督 Domain Adaptation

cited by: 9546

code:https://github.com/junyanz/CycleGAN

1. 概要

一般来说,图像翻译任务需要对齐的图像对, 但很多场景下无法获得这样的训练数据。于是作者提出了一个基于非配对数据的方法, 可以学习到不同 domain 图像间的映射。CycleGAN是在GAN loss的基础上加入循环一致性损失,使得 F(G(X)) 尽量接近X(反之亦然)。

2. 网络结构和设计框架

生成器

cyclegan生成器
cyclegan生成器

判别器

判别器使用了PatchGAN

架构

cyclegan
cyclegan

3. 损失函数

目标是在X和Y两个不同domain间,建立起双向的映射关系 G 和 F ;并使用两个判别器 ,来分别对{x}和{F(Y)}、{y}和{G(x)}进行区分,于是就存在两个损失: • 对抗损失——使得映射后的数据分布接近目标domain的数据分布 • 循环一致性损失——保证学习到的两个映射 G 和 F 不会相互矛盾

image-20210902110931366
image-20210902110931366

GAN损失使用的是和传统GAN网络一致的对抗损失函数

优化目标是两个min-max函数

循环一致性损失 ,对于任意一个x和y, 应该有:

使用L1距离时, 则损失函数为:

于是,完整的损失函数应该为:

论文里没有提,但是代码中还存在的一个损失,identity损失:

#先定义损失函数:
criterion_GAN = torch.nn.MSELoss()   ##判别器损失
criterion_cycle = torch.nn.L1Loss()  ##循环一致性损失
criterion_identity = torch.nn.L1Loss()  ##identity损失

#两个生成器
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)

#两个判别器
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

##定义优化器
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)


optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input 
        real_A = Variable(batch["A"].type(Tensor))  # A是油画图   1*3*256*256
        real_B = Variable(batch["B"].type(Tensor))  # B是真实的风景图  1*3*256*256

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)#表真
        #1*1*16*16
        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)#表假

        # ------------------
        #  Train Generators
        # ------------------
        ##训练生成器:

        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()  #生成器的优化器

        # Identity loss
        loss_id_A = criterion_identity(G_BA(real_A), real_A) #通过BA生成器后的输出和自身的损失,不要偏离太远 
        loss_id_B = criterion_identity(G_AB(real_B), real_B)     # 1*3*256*256

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = G_AB(real_A)   # 1*3*256*256
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid) #对于生成器来讲,希望生成的假图被判别器判断为真。                                                             D_B(fake_B)的结果是1*1*16*16
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)  #A到B,B到A各来一次

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recov_A = G_BA(fake_B) # 1*3*256*256
        loss_cycle_A = criterion_cycle(recov_A, real_A) #生成的假图再从B->A,得到循环一致性损失
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

        loss_G.backward()  ###同时优化两个生成器
        optimizer_G.step()
        
        ##训练判别器A:
  #两个判别器是分开来训练更新的

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)  #1*1*16*16
        # Fake loss (on batch of previously generated samples)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        ##同理训练判别器B:
        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)    #判别器希望能把对的认对
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake) #判别器希望吧错的辨别出来
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2
image-20211008152121452
image-20211008152121452

金立

2021/11/02  阅读:51  主题:极客黑

作者介绍

金立