生成对抗网络 (Generative Adversarial Nets)
Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair†, Aaron Courville, Yoshua Bengio‡* 蒙特利尔大学,计算机科学与运筹学系 蒙特利尔,QC H3C 3J7
摘要
我们提出了一种通过对抗过程估计生成模型的新框架,在该框架中,我们同时训练两个模型:一个捕捉数据分布的生成模型 ,以及一个估计样本来自训练数据而非 的概率的判别模型 。 的训练过程是最大化 犯错的概率。该框架对应于一个极小极大博弈(minimax two-player game)。在任意函数 和 的空间中,存在唯一解,其中 恢复了训练数据分布,且 在任何地方都等于 。在 和 由多层感知机定义的情况下,整个系统可以通过反向传播进行训练。在训练或生成样本的过程中,不需要任何马尔可夫链或展开的近似推理网络。实验通过对生成样本的定性和定量评估,证明了该框架的潜力。
1 引言
深度学习的前景在于发现丰富的层次化模型 [2],这些模型能够表示人工智能应用中遇到的各类数据的概率分布,例如自然图像、包含语音的音频波形以及自然语言语料库中的符号。到目前为止,深度学习中最引人注目的成功涉及判别模型,通常是将高维、丰富的感官输入映射到类标签的模型 [14, 22]。这些显著的成功主要基于反向传播和 Dropout 算法,并使用了具有良好梯度特性的分段线性单元 [19, 9, 10]。深度生成模型的影响力较小,这是由于在最大似然估计及相关策略中,许多难以处理的概率计算难以近似,且难以在生成背景下利用分段线性单元的优势。我们提出了一种新的生成模型估计过程,以规避这些困难。¹
在所提出的对抗网络框架中,生成模型与一个对手进行博弈:一个学习确定样本是来自模型分布还是数据分布的判别模型。生成模型可以被认为类似于一个造假者团队,试图制造假币并使用它而不被发现,而判别模型则类似于警察,试图检测假币。这种博弈中的竞争促使双方改进各自的方法,直到假币与真币无法区分。
该框架可以为多种模型和优化算法产生特定的训练算法。在本文中,我们探讨了当生成模型通过将随机噪声传递给多层感知机来生成样本,且判别模型也是多层感知机时的特殊情况。我们将这种特殊情况称为对抗网络。在这种情况下,我们仅使用非常成功的反向传播和 Dropout 算法 [17] 即可训练这两个模型,并仅使用前向传播即可从生成模型中采样。不需要任何近似推理或马尔可夫链。
2 相关工作
除了带有隐变量的定向图模型外,另一种选择是带有隐变量的无向图模型,例如受限玻尔兹曼机 (RBMs) [27, 16]、深度玻尔兹曼机 (DBMs) [26] 及其众多变体。此类模型中的相互作用表示为非归一化势函数的乘积,并通过对随机变量的所有状态进行全局求和/积分进行归一化。这个量(配分函数)及其梯度对于除最简单的情况外都是难以处理的,尽管可以通过马尔可夫链蒙特卡洛 (MCMC) 方法进行估计。混合问题对于依赖 MCMC 的学习算法构成了重大挑战 [3, 5]。
深度信念网络 (DBNs) [16] 是包含单个无向层和多个定向层的混合模型。虽然存在快速的近似逐层训练准则,但 DBNs 带来了与无向和定向模型相关的计算困难。
人们还提出了不近似或限制对数似然的其他准则,例如得分匹配 (score matching) [18] 和噪声对比估计 (NCE) [13]。这两者都需要学习到的概率密度在归一化常数之前被解析地指定。请注意,在许多具有多层隐变量的有趣生成模型(如 DBNs 和 DBMs)中,甚至无法推导出可处理的非归一化概率密度。一些模型(如去噪自编码器 [30] 和收缩自编码器)的学习规则与应用于 RBMs 的得分匹配非常相似。在 NCE 中,正如在本文中一样,采用判别训练准则来拟合生成模型。然而,生成模型本身被用来区分生成数据与来自固定噪声分布的样本,而不是拟合一个单独的判别模型。由于 NCE 使用固定的噪声分布,在模型学习到观测变量的一个小子集上的近似正确分布后,学习速度会急剧下降。
最后,一些技术不涉及显式定义概率分布,而是训练一个生成机器从期望的分布中抽取样本。这种方法具有可以设计为通过反向传播进行训练的优势。该领域最近的突出工作包括生成随机网络 (GSN) 框架 [5],它扩展了广义去噪自编码器 [4]:两者都可以看作是定义了一个参数化的马尔可夫链,即学习执行生成马尔可夫链一步的机器的参数。与 GSNs 相比,对抗网络框架不需要马尔可夫链进行采样。由于对抗网络在生成过程中不需要反馈循环,它们能够更好地利用分段线性单元 [19, 9, 10],这些单元改善了反向传播的性能,但在反馈循环中使用时存在无界激活的问题。通过反向传播训练生成机器的更近期的例子包括自动编码变分贝叶斯 [20] 和随机反向传播 [24] 的近期工作。
3 对抗网络
当模型均为多层感知机时,对抗建模框架最容易应用。为了学习生成器在数据 上的分布 ,我们定义输入噪声变量 的先验,然后将映射表示为数据空间 ,其中 是由参数为 的多层感知机表示的可微函数。我们还定义了第二个输出单个标量的多层感知机 。 表示 来自数据而非 的概率。我们训练 以最大化为训练样本和来自 的样本分配正确标签的概率。我们同时训练 以最小化 :
换句话说, 和 进行以下价值函数为 的双人极小极大博弈:
在下一节中,我们将对对抗网络进行理论分析,本质上表明,当 和 被赋予足够的容量时,即在非参数极限下,训练准则允许恢复数据生成分布。参见图 1 以获取更非正式、更具教学意义的方法解释。在实践中,我们必须使用迭代数值方法来实现该博弈。在训练的内循环中将 优化到完成在计算上是禁止的,并且在有限数据集上会导致过拟合。相反,我们交替进行 步优化 和一步优化 。只要 变化足够缓慢,这就会导致 保持在其最优解附近。这种策略类似于 SML/PCD [31, 29] 训练从一个学习步骤到下一个学习步骤保持马尔可夫链样本的方式,以避免在学习的内循环中进行马尔可夫链的预热 (burning in)。该过程在算法 1 中正式呈现。
在实践中,公式 1 可能无法为 的良好学习提供足够的梯度。在学习初期,当 较差时, 可以高置信度地拒绝样本,因为它们明显不同于训练数据。在这种情况下, 会饱和。与其训练 最小化 ,不如训练 最大化 。该目标函数导致 和 动力学的相同不动点,但在学习初期提供了更强的梯度。

图 1: 生成对抗网络通过同时更新判别分布(,蓝色,虚线)进行训练,使其能够区分来自数据生成分布(黑色,点线) 的样本与来自生成分布 ()(绿色,实线)的样本。下方的水平线是 的采样域,在本例中是均匀分布的。上方的水平线是 域的一部分。向上的箭头显示了映射 如何在变换后的样本上施加非均匀分布 。 在 高密度区域收缩,在低密度区域扩展。(a) 考虑接近收敛的对抗对: 与 相似, 是一个部分准确的分类器。(b) 在算法的内循环中, 被训练以区分来自数据的样本,收敛到 。(c) 在 更新后, 的梯度引导 流向更有可能被分类为数据的区域。(d) 经过几步训练后,如果 和 有足够的容量,它们将达到一个点,此时两者都无法改进,因为 。判别器无法区分这两个分布,即 。
4 理论结果
生成器 隐式定义了一个概率分布 ,作为当 时获得的样本 的分布。因此,如果给予足够的容量和训练时间,我们希望算法 1 收敛到 的良好估计量。本节的结果是在非参数设置下完成的,例如,我们通过研究概率密度函数空间中的收敛性来表示具有无限容量的模型。
我们将在 4.1 节中证明该极小极大博弈对于 具有全局最优解。然后,我们将在 4.2 节中证明算法 1 优化了公式 1,从而获得预期的结果。
算法 1: 生成对抗网络的小批量随机梯度下降训练。应用于判别器的步数 是一个超参数。在我们的实验中,我们使用了 ,这是最廉价的选择。
- for 训练迭代次数 do
- for 步 do
- 从噪声先验 中采样 个噪声样本 。
- 从数据生成分布 中采样 个样本 。
- 通过沿其随机梯度上升来更新判别器:
- end for
- 从噪声先验 中采样 个噪声样本 。
- 通过沿其随机梯度下降来更新生成器:
- for 步 do
- end for
基于梯度的更新可以使用任何标准的基于梯度的学习规则。我们在实验中使用了动量 (momentum)。
4.1 的全局最优性
我们首先考虑对于任何给定的生成器 的最优判别器 。
命题 1. 对于固定的 ,最优判别器 为:
证明。 给定任何生成器 ,判别器 的训练准则是最大化量 : 对于任何 ,函数 在 上在 处达到其最大值。判别器不需要在 之外定义,证明完毕。
请注意, 的训练目标可以解释为最大化估计条件概率 的对数似然,其中 指示 是来自 ()还是来自 ()。公式 1 中的极小极大博弈现在可以重构为:
定理 1. 虚拟训练准则 的全局最小值当且仅当 时达到。此时, 达到值 。
证明。 对于 ,(考虑公式 2)。因此,通过在 处检查公式 4,我们发现 。要看到这是 的最佳可能值,且仅在 时达到,观察到: 通过从 中减去该表达式,我们得到: 其中 是 Kullback–Leibler 散度。我们在前面的表达式中识别出模型分布与数据生成过程之间的 Jensen–Shannon 散度: 由于两个分布之间的 Jensen–Shannon 散度始终是非负的,且仅在它们相等时为零,我们已经证明 是 的全局最小值,且唯一的解是 ,即生成模型完美复制了数据生成过程。
4.2 算法 1 的收敛性
命题 2. 如果 和 有足够的容量,并且在算法 1 的每一步中,判别器被允许在给定 的情况下达到其最优值,且 被更新以改进准则 则 收敛到 。
证明。 将 视为 的函数,如上述准则中所做的那样。注意 在 中是凸的。凸函数上确界的次导数包括在达到最大值点处的函数导数。换句话说,如果 且 对于每个 在 中是凸的,则如果 ,则 。这等同于在给定相应 的最优 下计算 的梯度下降更新。 在 中是凸的,且具有唯一的全局最优解,如定理 1 所证明,因此通过足够小的 更新, 收敛到 ,证明完毕。
在实践中,对抗网络通过函数 表示有限的 分布族,我们优化的是 而不是 本身。使用多层感知机来定义 会在参数空间中引入多个临界点。然而,多层感知机在实践中的出色表现表明,尽管缺乏理论保证,它们仍然是一个合理的模型。
5 实验
我们在一系列数据集上训练了对抗网络,包括 MNIST [23]、多伦多人脸数据库 (TFD) [28] 和 CIFAR-10 [21]。生成器网络使用了整流线性激活 [19, 9] 和 Sigmoid 激活的混合,而判别器网络使用了 Maxout [10] 激活。Dropout [17] 被应用于判别器网络的训练中。虽然我们的理论框架允许在生成器的中间层使用 Dropout 和其他噪声,但我们仅在生成器网络的底层使用了噪声作为输入。
我们通过将高斯 Parzen 窗口拟合到用 生成的样本,并报告该分布下的对数似然,来估计 下测试集数据的概率。
| 模型 | MNIST | TFD |
|---|---|---|
| DBN [3] | ||
| Stacked CAE [3] | ||
| Deep GSN [6] | ||
| Adversarial nets |
表 1: 基于 Parzen 窗口的对数似然估计。MNIST 上报告的数字是测试集样本的平均对数似然,标准误差是在样本间计算的。在 TFD 上,我们计算了数据集各折叠之间的标准误差,并在每个折叠的验证集上选择了不同的 。在 TFD 上, 在每个折叠上进行了交叉验证,并计算了每个折叠上的平均对数似然。对于 MNIST,我们将其与数据集的实值(而非二进制)版本的其他模型进行了比较。
高斯的 参数是通过验证集上的交叉验证获得的。该过程由 Breuleux 等人 [8] 引入,并用于各种难以处理精确似然的生成模型 [25, 3, 5]。结果报告在表 1 中。这种估计似然的方法具有较高的方差,在高维空间中表现不佳,但据我们所知,它是目前可用的最佳方法。能够采样但不能直接估计似然的生成模型的进步,激发了对如何评估此类模型的进一步研究。
在图 2 和图 3 中,我们展示了训练后从生成器网络中抽取的样本。虽然我们并不声称这些样本比现有方法生成的样本更好,但我们相信这些样本至少与文献中更好的生成模型具有竞争力,并突显了对抗框架的潜力。

图 2: 模型样本可视化。最右侧一列显示了相邻样本的最近训练示例,以证明模型没有记忆训练集。样本是公平的随机抽取,而非精挑细选。与大多数深度生成模型的可视化不同,这些图像显示的是来自模型分布的实际样本,而不是给定隐单元样本的条件均值。此外,这些样本是不相关的,因为采样过程不依赖于马尔可夫链混合。(a) MNIST (b) TFD (c) CIFAR-10(全连接模型)(d) CIFAR-10(卷积判别器和“反卷积”生成器)

图 3: 在完整模型的 空间坐标之间线性插值得到的数字。
| Deep directed graphical models | Deep undirected graphical models | Generative autoencoders | Adversarial models | |
|---|---|---|---|---|
| Training | Inference needed during training. | Inference needed during training. MCMC needed to approximate partition function gradient. | Enforced tradeoff between mixing and power of reconstruction generation | Synchronizing the discriminator with the generator. Helvetica. |
| Inference | Learned approximate inference | Variational inference | MCMC-based inference | Learned approximate inference |
| Sampling | No difficulties | Requires Markov chain | Requires Markov chain | No difficulties |
| Evaluating | Intractable, may be approximated with AIS | Intractable, may be approximated with AIS | Not explicitly represented, may be approximated with Parzen density estimation | Not explicitly represented, may be approximated with Parzen density estimation |
| Model design | Nearly all models incur extreme difficulty | Careful design needed to ensure multiple properties | Any differentiable function is theoretically permitted | Any differentiable function is theoretically permitted |
表 2: 生成建模中的挑战:针对涉及模型的每个主要操作,不同深度生成建模方法所遇到的困难总结。
6 优势和劣势
这个新框架相对于以前的建模框架既有优势也有劣势。劣势主要是没有 的显式表示,并且在训练期间 必须与 同步良好(特别是, 不能在不更新 的情况下训练过多,以避免“Helvetica 场景”,即 将太多的 值坍缩为相同的 值,从而没有足够的多样性来建模 ),就像玻尔兹曼机的负链必须在学习步骤之间保持最新一样。优势在于永远不需要马尔可夫链,仅使用反向传播来获得梯度,学习期间不需要推理,并且可以将各种各样的函数合并到模型中。表 2 总结了生成对抗网络与其他生成建模方法的比较。
上述优势主要是计算性的。对抗模型也可能从生成器网络不直接使用数据示例更新,而仅通过流经判别器的梯度更新中获得一些统计优势。这意味着输入的组件不会直接复制到生成器的参数中。对抗网络的另一个优势是它们可以表示非常尖锐,甚至是退化的分布,而基于马尔可夫链的方法要求分布必须有些模糊,以便链能够在模式之间混合。
7 结论和未来工作
该框架允许许多直接的扩展:
- 通过将 作为输入添加到 和 中,可以获得条件生成模型 。
- 可以通过训练一个辅助网络来预测给定 的 ,从而执行学习到的近似推理。这类似于通过唤醒-睡眠算法 [15] 训练的推理网络,但具有在生成器网络完成训练后,可以为固定的生成器网络训练推理网络的优势。
- 通过训练共享参数的条件模型族,可以近似建模所有条件 ,其中 是 索引的子集。本质上,可以使用对抗网络来实现确定性 MP-DBM [11] 的随机扩展。
- 半监督学习:当有限的标记数据可用时,来自判别器或推理网络的特征可以提高分类器的性能。
- 效率改进:通过设计更好的协调 和 的方法,或确定在训练期间从 采样的更好分布,可以大大加速训练。
本文证明了对抗建模框架的可行性,表明这些研究方向可能被证明是有用的。
致谢
我们要感谢 Patrice Marcotte、Olivier Delalleau、Kyunghyun Cho、Guillaume Alain 和 Jason Yosinski 的有益讨论。Yann Dauphin 与我们分享了他的 Parzen 窗口评估代码。我们要感谢 Pylearn2 [12] 和 Theano [7, 1] 的开发者,特别是 Frédéric Bastien,他专门为造福本项目而赶制了一个 Theano 功能。Arnaud Bergeron 为 LaTeX 排版提供了急需的支持。我们还要感谢 CIFAR 和加拿大研究主席的资助,以及 Compute Canada 和 Calcul Québec 提供的计算资源。Ian Goodfellow 得到了 2013 年深度学习 Google 奖学金的支持。最后,我们要感谢 Les Trois Brasseurs 激发了我们的创造力。