Skip to content

GAN目前常用的loss函数

DingfengShi edited this page Feb 16, 2018 · 11 revisions

总览

以下几个是用于Tensorflow API:TFGAN中的几个LOSS的类型

  • ACGAN LOSS
  • Least Squares GAN LOSS
  • modified LOSS (原作者Ian Goodfellow提出的修改损失)
  • Wasserstein GAN LOSS
  • mutual_information_penalty(源自InfoGAN)

几种模型的结构


(图片来源于网络,侵删)

modified LOSS

  • 对于生成器,一开始Goodfellow提出的loss function是(D(X)=1表示判断为真实数据):
  • 后来,改成了:
  • 两者虽然在目标上是一致的,但是转换为距离度量结果不一样,具体可参考WGAN的分析,实验发现后者效果更好,但在WGAN的论文里指出其实该形式也不合理

ACGAN LOSS

Conditional Image Synthesis With Auxiliary Classifier GANs

主要特点:

  • 判别器输出两种类型:真假、类别

  • 用与类别一一对应的one-hot向量表示condition,与noise一并送入生成器

  • 损失函数定义了如下两项:

  • Ls用于分辨真假,Lc用于分辨类型
    判别器loss: 最大化Ls+Lc
    生成器loss: 最大化Lc-Ls

  • 也即是说判别器和生成器都共同致力于让分类更准确


Wasserstein GAN LOSS(WGAN)以及WGAN-GP

Wasserstein GAN

mode collapse(模式崩溃)

  mode collapse:一般的数据都是多模(multimode)的,也就是说,在某些数据点附近,数据的数量会比较大。比如有一个数据集:记录的是澳大利亚Alice Spring和南极点在夏天时候的数据,如下图

  可以看到数据分布在两个峰值处。而我们如果想预测一下某年夏天可能出现的温度(没有指定地点),那么理想的预测结果应该是35度左右和-20度左右的概率值分别是50%和50%。但是实际上GAN训练的网络产生的结果可能是某个峰值的概率非常大,而另一个峰值的概率非常小的情况。

而WGAN解决了以下几点问题:

  • 彻底解决GAN训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度
  • 基本解决了collapse mode的问题,确保了生成样本的多样性
  • 训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,这个数值越小代表GAN训练得越好,代表生成器产生的图像质量越高(如题图所示)
  • 以上一切好处不需要精心设计的网络架构,最简单的多层全连接网络就可以做到

对WGAN的分析可以参考:令人拍案叫绝的Wasserstein GAN

而对LOSS的定义如下:

  • 生成器LOSS:
  • 判别器LOSS:

原始WGAN模型的条件:

  • 判别器最后一层去掉sigmoid
  • 生成器和判别器的loss不取log
  • 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
  • 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行
  • 在理论分析中fw是一个Lipschitz连续(导数存在上界)的函数,这个函数作者是通过神经网络,通过改变权重w,学习这个函数的近似。
  • 对于以前的分类问题,判别器最后用的是sigmoid激活,但是现在变成回归问题去拟合Wasserstein距离,所以要取走sigmoid和log

在原始论文中,作者是通过Weight clipping,也就是把网络权重限制在一定范围内,去限制导数的上界。但是这个在后面的改进的论文也被指出,这样做的结果会使权重的分布集中在范围的左右两个端点处。因此,后来提出的改进方案gradient penalty,就是用来替代这种方法,这种网络被称为WGAN-GP

Gradient penalty的思想其实很简单:在Loss function中加入一个梯度的正则化项,惩罚导数值大于上界的权重。

而为了让样本更有多样性,让梯度尽量的大,惩罚项可改为让其固定在上界值(Lipschitz常数)K附近,最后得到的Loss function如下(这里把K值简单地设定成了1):

  • 惩罚项本来应该是在整个样本空间内采样,但是这在很多问题中不现实,于是作者就采用了一点技巧,改成在训练数据和生成数据附近的范围采样:
  • 即分别在真实数据r和生成数据g中采样,并加入一个随机噪声

    这便是惩罚空间

Least Squares GAN LOSS

Least Squares Generative Adversarial Networks

主要特点:

  • 收敛速度比WGAN更快,更稳定

  • 生成效果比普通GAN好

  • 用L2损失代替log损失

  • LSGAN的框架如下:

  • 在上面方程式中,我们选择 b=1 表明它为真实的数据,a=0 表明其为伪造数据。最后 c=1 表明我们想欺骗辨别器 D。
  • 但是这些值并不是唯一有效的值。LSGAN 作者提供了一些优化上述损失的理论,即如果 b-c=1 并且 b-a=2,那么优化上述损失就等同于最小化 Pearson χ^2 散度(Pearson χ^2 divergence)。因此,选择 a=-1、b=1 和 c=0 也是同样有效的。
  • 最终,得到的LOSS如下:
Clone this wiki locally