SWA:让你的目标检测模型无痛涨点1% AP

机器学习算法工程师

共 2682字,需浏览 6分钟

 · 2021-01-05

点蓝色字关注“机器学习算法工程师

设为星标,干货直达!


最近目标检测领域可谓是百花齐放,无论是anchor-free的检测算法还是基于transformer的检测算法都比较耀眼。虽然COCO 数据集上的AP值已经刷到了0.61,但是其实很多模型在同样条件下的mAP值差异也只是在1~2%。一篇最新的论文SWA Object Detection介绍了一个让你的检测模型无痛涨点1% AP值的策略:采用周期式学习速率(余弦退火学习速率)额外再训练你的模型12个epoch,然后简单地平均每个epoch训练得到的weights作为最终的模型。这个做法只是额外增加了训练时间,但是对模型的推理没有任何影响,更重要的是作者通过实验证明了这个策略在实例分割模型(Mask R-CNN)、two-stage检测模型(Faster R-CNN),基于anchor的one-stage检测模型(RetinaNet,YOLOv3)以及anchor-free的检测模型(FCOS)上都简单有效。这个trick是来源于18年的一份工作所提出的Stochastic Weights Averaging (SWA),经过实验作者发现SWA在检测领域也有效。

SWA

SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。记训练过程第个epoch的checkpoint为,一般情况下我们会选择训练过程中最后的一个epoch的模型或者在验证集上效果最好的一个模型作为最终模型。但SWA一般在最后采用较高的固定学习速率或者周期式学习速率额外训练一段时间,取多个checkpoints的平均值作为最终模型。SWA的具体做法如下图所示,前75%的时间使用标准的衰减学习速率策略训练,然后剩余25%设置一个合理的固定学习速率进行训练,最后平均第二阶段每个epoch的weights。如下图b所示,也可以采用在每个epoch采用周期式的学习速率策略来训练。另外一点是模型中如果有BN层,那么应该用SWA得到的模型在训练数据中跑一遍得到BN层的running statistics。


那么SWA为什么有效呢,论文也给了简单的解释,由于模型的参数属于高维空间,SGD训练的模型往往收敛到最优解的边界区域,如下图a中的模型, 都落在边缘位置,但是平均它们可以接近最优解。那么SWA后面采用固定学习速率或者周期式学习速率来寻找更多的次优解,最后平均接近最优解。图b和c是说的是训练误差和测试误差往往不对齐,就是我们所说的模型泛化性能,那么平均的话其实是可以提升泛化性能的。


其实除了SWA,另外一个常用的策略是对训练过程的weights进行指数加权平均来提升泛化性能,这个TensorFlow有对应的实现tf.train.ExponentialMovingAverage:

  1. shadow_variable = decay * shadow_variable +(1- decay)* variable

SWA在检测上的应用

具体到目标检测模型,那么要通过实验来确定SWA的具体策略:学习速率策略以及训练epochs。论文中选择了Mask R-CNN模型进行实验,其中学习速率第一种是采用固定学习速率,共0.02, 0.002和0.0002三种学习速率,第二种是采用cos学习速率,如下图所示,每个epoch为一个周期,epoch开始时的学习速率最大,然后在epoch结束时学习速率衰减为最低,实验共选择了两套参数(0.01, 0.0001)和 (0.02, 0.0002)。至于训练epochs,共选择两套参数:24和48个epochs。这里对pretrained的模型进行finetune时,由于BN参数被frozen,所以不需要像原始的SWA那样重新计算训练集的running statistics。


具体实验结果如下表所示,从实验结果来看,采用固定学习速率最终的模型效果有所恶化,但是采用cos学习速率效果有提升,具体地采用cos lr为(0.02, 0.0002),额外训练12个epoch就可以额外提升约一个点。另外这个策略也在Faster R-CNN,RetinaNet,FCOS,YOLOv3和VFNet实验,最终都可以大约提升AP一个点左右。所以最后的策略是:

after the conventional training of an object detector with the initial learning rate and the ending learning rate , train it for an extra 12 epochs using the cyclical learning rates (, ) for each epoch, and then average these 12 checkpoints as the final detection model


参考文献

  1. SWA Object Detection
  2. Stochastic Weight Averaging in PyTorch


推荐阅读

PyTorch 源码解读之 torch.autograd

CondInst:性能和速度均超越Mask RCNN的实例分割模型

centerX: 用新的视角的方式打开CenterNet

mmdetection最小复刻版(十一):概率Anchor分配机制PAA深入分析

MMDetection新版本V2.7发布,支持DETR,还有YOLOV4在路上!

CNN:我不是你想的那样

TF Object Detection 终于支持TF2了!

无需tricks,知识蒸馏提升ResNet50在ImageNet上准确度至80%+

不妨试试MoCo,来替换ImageNet上pretrain模型!

重磅!一文深入深度学习模型压缩和加速

从源码学习Transformer!

mmdetection最小复刻版(七):anchor-base和anchor-free差异分析

mmdetection最小复刻版(四):独家yolo转化内幕


机器学习算法工程师


                                    一个用心的公众号


 


浏览 51
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报