How Batch Normalization works?

机器学习根据训练集得到模型在测试集上的结果通常会出现一定程度的下降,下降的多少取决于你对训练集的过拟合程度。

这是因为训练集的数据分布 $p(x)​$ 与测试集数据分布 $q(x)​$ 并不完全相同,那么根据训练集得到的参数模型 $p(y\vert x,\theta)​$ 与在测试集上所期望的的模型 $q(y\vert x)​$ 也会存在差异。对训练集进行过多的拟合导致这种差异性被无限放大。

这种由于数据集的分布差异 $p(x)\neq q(x)$,而造成的参数模型差异 $p(y\vert x,\theta) \neq q(y\vert x)$ 被 下平しもだいら 英寿ひでとし 定义为 covariate shift,协变量偏移,并使用KL散度来衡量这种分布差异。

如何最大限度地消除不同数据集之间的差异是 domain adaption 所要解决的问题。

$\star$ 然而,covariate shift 不只是模型面对不同数据集时会出现。

对于多层神经网络,在训练阶段参数优化时,前层网络的参数更新会造成的后层网络的输入分布变化。

Batch Normalization 这篇文章将训练神经网络时所出现的这种,不同网络层之间数据分布差异定义为 internal covariate shift。

1 作用机制

1.1 Normalization

如何处理不同网络层对数据改变所造成的 internal covariate shift?

答案是对数据做归一化(normalization),归一化即 scale + shift 两个操作的组合。

一般常用的Z-score归一化如下式:

\[\hat{x}=\cfrac{x-\mu}{\sigma}\]

之所以称为Batch Normalization,是因为此处的均值 $\mu$ 与标准差 $\sigma$ 是在每次迭代时当前batch的数据上求得

归一化可以减小 internal covariate shift 的影响,而且可以避免输入数据 $x$ 进入饱和区(saturated regime)从而加速模型的收敛。

但这样做的结果是抹去了前层网络所学习到的信息,即原始输入 $x$ 是上一层网络激活函数的输出,对其归一化后的 $\hat{x}$ 消除了上一层网络的激活作用。

因此为了保证归一化操作不会消除前层网络对数据的非线性变换,BN在归一化的基础上定义了两个可学习参数:再缩放参数 $\gamma$ 与再平移参数 $\beta$,来保证对 $x$ 的变换是恒等变换(identity transform)。

所以BN对原始数据 $x$ 的改变如下式:

\[\begin{aligned} y&=\gamma \hat{x}+\beta\\ &=\gamma \cfrac{x-\mu}{\sigma}+\beta \end{aligned}\]

所以BN引入这样一来一回、两个相互抵消的操作优点在哪里呢?

Juliuszh 的解释是,BN这样解除了不同网络层之间的复杂关联,而引入的两个新参数与模型中的其他参数一样可以通过梯度下降求解,从简化了神经网络的训练。

1.2 Regularization

与SGD求梯度时使用子样本代替全体样本引入随机噪声类似,BN使用mini-batch代替全体样本求均值 $\mu$ 与方差 $\sigma$ 同样引入了随机噪声,一定程度上也对模型起到了正则化作用。从而可以舍弃掉为了防止过拟合而采用的dropout操作。

2 参数更新

BN所引入的两个可学习参数 $\gamma$ 与 $\beta$ 求解方式与神经网络中的普通参数相同,所以只需要求梯度。

那么通过链式法则就可以推出 $\gamma$ 与 $\beta$ 相对于损失函数 $l$ 的梯度。

\[\cfrac{\partial l}{\partial \gamma}=\cfrac{\partial l}{\partial y}\cdot \hat{x}\\ \cfrac{\partial l}{\partial \beta}=\cfrac{\partial l}{\partial y}​$\]

3 TBC

これで今夜もくつろいで熟睡できるな。

4 Reference

[1] Improving predictive inference under covariate shift by weighting the log-likelihood function, JSPI2000

[2] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, arXiv2015

[3] 详解深度学习中的Normalization,BN/LN/WN