Deep InfoMax:基于互信息最大化的表示学习

论文标题:Learning deep representations by mutual information estimation and maximization
论文链接: https://arxiv.org/abs/1808.06670
论文来源:ICLR 2019

之前的相关博客: MINE:随机变量互信息的估计方法

本文提出的方法主要目的是训练一个表示学习的函数(也就是一个encoder)来最大化其输入与输出之间的互信息(Mutual Information, MI)。高维连续随机变量互信息的估计一向是个困难的问题,然而最近的方法(上面的博客)已经可以利用神经网络来对互信息进行有效地计算。

本文的方法利用互信息的估计来进行表示学习,并且表明完整的输入和encoder输出之间的互信息(global MI)对于学习到有用的表示来说是不足够的,输入的结构也起到一定作用,也就是说,表示与输入的局部部分(比如图片的patch)之间的互信息(local MI)能够提高表示的质量,而global MI在给定表示重构输入方面有重要的作用。

使用 \mathrm{X }\mathrm{Y } 分别代表encoder E_{\psi }: \mathrm {X}\rightarrow \mathrm {Y} 的定义域和值域, E_{\psi } 是由 \psi 参数化的神经网络。这些参数定义了一系列的encoder \mathrm{E }_{\Phi }= \left \{E_{\psi }\right \}_{\psi \in \Psi } 。现在已有输入空间上的一些训练样本 \mathrm {X}:X=\left \{x^{{(i)}}\in \mathrm {X}\right \}_{i=1}^{N} ,对应经验概率分布 \mathbb{P} 。定义 \mathbb{U}_{\psi ,\mathbb{P}} 代表encoder输出的分布,也就是说 \mathbb{U}_{\psi ,\mathbb{P}} 是首先采样 x\sim \mathrm{X} 然后采样 y\sim E_{\psi }(x) 所产生的 y\in \mathrm {Y} 的分布。

下图展示了本文中使用的一种encoder,大致过程就是将图片通过卷积网络得到 M\times M 的feature map,然后将这些feature map合并成图片的表示向量 Y

本文的encoder按照以下标准进行训练:
①优化参数 \psi ,使得互信息 I(X;E_{\psi }(X)) 最大,这里的互信息取决于目标的不同,可以是表示与完整输入 X 之间的,也可以是与输入的子集之间的;
②取决于目标的不同,分布 \mathbb{U}_{\psi ,\mathbb{P}} 可以匹配先验分布 \mathbb{V} ,这可以引导encoder的输出具有想要的特征。

  • 互信息最大化(Global MI)
  • 下图展示了互信息(global MI)最大化的基本框架:

    这里 G 代表“global”。注意 E_{\psi }f_{\psi }C_{\psi } 两部分组成,也就是 E_{\psi }=f_{\psi }\circ C_{\psi } ,其中 C_{\psi } 将图片映射成 M\times M 的feature map, f_{\psi } 将feature map映射成表示向量 Y=E_{\psi }(X) 。上面的式子中 X 其实指的是 C_{\psi }(X) ,global MI也就是要最大化 C_{\psi }(X)E_{\psi }(X) 。另外 T_{\psi ,\omega }=D_{\omega }\circ g\circ (C_{\psi },E_{\psi })g 代表连接encoder输出和Discriminator的函数。

  • JS散度形式
  • 另外,我们的目的是要最大化互信息,而不是估计其值(上述KL散度形式的互信息可以作为互信息的估计值),因此我们可以采用其他非KL散度的形式,比如Jensen-Shannon MI估计:

    x 是输入样本, x^{'} 是从 \tilde{\mathbb{P}} 中采样的样本, sp(z)=log(1+e^{z}) 是softplus激活函数。

    那么上面的式子是怎么来的呢?首先按照文章前面MINE博客中推导得到的f-divergence的对偶形式,有:

    而JS散度对应的共轭函数 f^{*}(t)=-log(1-e^{t}) ,激活函数为 g_{f}(z)=-log(1+e^{-z}) ,代入得到:

    XY 的互信息等价于 \mathbb{J}\mathbb{M} 的KL散度,但并不等价于其JS散度,不过通过最大化JS散度也可以最大化互信息,但是不能用来作为互信息的估计值。

    在上面的式子中, xE(x) 相当于正样本对, x^{'}E(x) 相当于负样本对,最大化互信息的过程就是让正样本对的得分 T_{\psi ,\omega }(x,E_{\psi }(x)) 变大,让负样本对的得分 T_{\psi ,\omega }(x^{'},E_{\psi }(x)) 的得分变小。

  • InfoNCE形式
  • 本文还利用了另外一种互信息的下界表示的形式,由InfoNCE损失而来。InfoNCE的目标是最大化正样本对的得分,最小化负样本对的得分,其形式为:

    因此类似的InfoNCE形式的互信息的下界就是:

    本文实验采用了JSD和InfoNCE两种形式,在下游任务上,使用InfoNCE通常优于JSD,尽管这种效果随着更具挑战性的数据而减弱。InfoNCE和DV需要大量的负样本,而JSD对负样本的数量够没那么敏感,在负样本数量较小时效果优于InfoNCE。

    Global MI的算法为:
    ①从数据集中采样原始图像 x_{+}^{(1)},\cdots ,x_{+}^{(n)}\sim \mathbb{P} ,接着计算feature map C_{\psi }(x_{+}^{(i)})\forall i
    ②获得图像的表示 y^{(i)}=E_{\psi }(x_{+}^{(i)})
    ③将 (C_{\psi }(x_{+}^{(i)}),y^{(i)}) 组成正样本对;
    ④从数据集中采样不同的图像 x_{-}^{(1)},\cdots ,x_{-}^{(n)}\sim \mathbb{P} ,接着计算feature map C_{\psi }(x_{-}^{(j)})\forall j
    ⑤将 (C_{\psi }(x_{-}^{(j)}),y^{(i)}) 组成负样本对;
    ⑥对JSD或者InfoNCE的目标函数进行梯度下降。

  • 局部互信息最大化(Local MI)
  • 对于global MI来说,某些任务可能是不必要的。举例来说,像素级别的噪声对于图像分类任务来说是不重要的,所以图片的表示不应该编码这些噪声。为了能够使得图像的表示更适合分类任务,我们可以最大化表示与图像的局部块之间的平均互信息,这样有利于表示包含图像块之间共享的信息。

    下图展示了local MI的框架: