VGGNet和GoogLeNet等网络都表明有足够的深度是模型表现良好的前提,但是在网络深度增加到一定程度时,更深的网络意味着更高的训练误差。误差升高的原因是网络越深,梯度弥散[还有梯度爆炸的可能性]的现象就越明显,所以在后向传播的时候,无法有效的把梯度更新到前面的网络层,靠前的网络层参数无法更新,导致训练和测试效果变差。所以ResNet面临的问题是怎样在增加网络深度的情况下有可以有效解决梯度消失的问题。ResNet中解决深层网络梯度消失的问题的核心结构是残差网络。
1.ResNet残差学习单元
ResNet提出了2种mapping:一种是identity mapping[恒等映射],指的是图中的曲线,把当前输出直接传输给下一层网络,相当于走了一个捷径,跳过了本层运算,另一种是residual mapping[残差映射],指的是除了曲线的部分,最终的输出是。identity mapping指的是方程中的,而residual mapping指的是。其中,identity mapping就是shortcut connection。
2.两种ResNet残差学习单元设计
2种结构分别针对ResNet34[左图]和ResNet50/101/152[右图],通常称整个结构为一个building block,右图又称bottleneck design。bottleneck design主要是为了降低参数数目,第一个的卷积把256维channel降到64维,然后在最后通过卷积恢复,总共用的参数数目为,而不使用bottleneck时就是两个的卷积,总共参数数目为,相差16.94倍。左图ResNet模块设计通常用于34层或者更少层数的网络中,右图ResNet模块设计通常用于更深的网络中,比如101层,目的是减少计算和参数量。
3.ResNet不同结构
上表列出了5种深度的ResNet,分别是18、34、50、101和152。所有的网络都分为5部分,分别是conv1、conv2_x、conv3_x、conv4_x、conv5_x。以ResNet101为例,首先有个输入的卷积,然后经过个building block,每个block为3层,所以有层,最后有个fc层用于分类,所以层。可见ResNet101中的101中的是网络的层数,需要说明的是101层网络仅仅指卷积或者全连接层,而激活层或者pooling层并没有计算在内。
4.两种Shortcut Connection方式
和是按照channel维度相加的,实线部分的Shortcut Connection两者的channel维度是相同的,执行的都是卷积操作,计算方程为。虚线部分的Shortcut Connection两者的channel维度是不同的,执行的分别是和卷积操作,计算方程为,其中是卷积操作,用来调整的channel维度。
5.残差学习
残差学习解决了深度神经网络退化的问题,但为什么残差学习比原始特征直接学习更容易呢?假设对于一个堆积层结构当输入为时,其学习到的特征记为,现在希望其可以学习到残差,这样其原始的学习特征为。当残差为0时,此时堆积层仅仅做了恒等映射,至少网络性能不会下降,实际上残差不会为0,这也会使得堆积层在输入特征基础上学习到新的特征。
接下来从数学的角度来分析这个问题,残差单元可以表示为:
其中,和分别表示第个残差单元的输入和输出,表示学习到的残差,而表示恒等映射,是ReLU激活函数。
基于上述方程,当和都是恒等映射时[即、],可以求得从浅层到深层的学习特征:
利用链式规则,可以求得反向过程的梯度:
其中,第一个因子表示损失函数到达的梯度,小括号中的1表示Shortcut Connection可以无损的传播梯度,而另一项残差梯度则需要经过带有权重的层,梯度不是直接传递过来的。残差梯度不会那么巧全为-1,而且就算其比较小,有1的存在也不会导致梯度消失。所以残差学习比原始特征直接学习会更容易。
6.ResNet的TensorFlow和Keras实现
(1)TensorFlow实现:https://download.csdn.net/download/shengshengwang/10933002
(2)Keras实现:https://download.csdn.net/download/shengshengwang/10933009
参考文献:
[1]ResNet网络结构:https://blog.csdn.net/dcrmg/article/details/79263415
[2]论文笔记Deep Residual Learning:https://www.cnblogs.com/jermmyhsu/p/8228007.html
[3]解析卷积神经网络:深度学习实践手册
[4]你必须要知道CNN模型ResNet:https://zhuanlan.zhihu.com/p/31852747
[5]resnet-in-tensorflow:https://github.com/wenxinxu/resnet-in-tensorflow