【论文阅读笔记】Class-Incremental Learning with Strong Pre-trained Models

2023-06-26,,

Key_words: Continual learning, strong pretrained model, fix, fusion
Create_time: April 14, 2022 6:32 PM
Edited_by: Huang Yujun
Publisher: CVPR 2022
Score /5: ️
Status: Finished
Org: AWS AI Labs

[29]Class-Incremental Learning with Strong Pre-trained Models.pdf

1. Motivation

目前的研究并未考虑新类之间重叠(具有相同类标签)的情况,默认序列间类别不重叠,但这与实际情况不符
初始阶段使用大量类别数据初始化特征提取器,从而获得一个学习到丰富特征的特征提取器

本文认为训练得不错的特征提取器就已经能够在新类上表现很好,因此并不需要在每次学习新类时fintune整个网络,上图为本文所做假设的验证性实验。

左图,实验设置为,分别使用100, 200, ... ,800 个类别数据作为初始类预训练一个特征提取器后

    固定这个特征提取器并训练一个新类的全连接分类器
    微调整个特征提取器及全连接分类器

对上述两种设置进行对比,结果发现固定特征提取器的新类分类准确类,随初始化类别增多而升高,说明初始类别数越多越有助于新类的学习。

右图,为探索 仅finetune 对应层+初始阶段类别变化 对模型性能的影响。本文发现,初始阶段类别数非常多时,是否 finetune 高层对最终准确率影响不大

此外,本文预设了一种新的场景,即每个批次的类别数据中,可能包含类标号一样的数据。如第 i 批次数据出现过类别为”狗“的类,在后续批次数据中可能再次出现。

2. Contribution

提出了一种在新场景下(不同批次数据之间,旧类可能会再次出现)的持续学习方法
首次对初始阶段的训练进行研究
提出了一种能够融合各个阶段特征提取器特征的分类器

这篇工作并没有在结构或者损失函数上有跨越性的创新,但提出了一种比较有意思的解决方向。

个人觉得不好的地方:
1. 未充分探索base的层数
2. 只用了ImageNet1000 这个自然图像数据集
3. 融合 logits 的方式有点怪(虽然实验中说明了效果比concat feature 好)
3. 数学符号有点乱,表达有些地方不太清晰

3. Methodology

模型训练具体可以分为4个阶段:

初始阶段预训练得到一个学习过丰富特征的特征提取器(使用 ImageNet 中的800个类)
训练各阶段新类数据的特征提取器,同时学习如何融合对特征提取器的输出
如果发现序列数据中含学习过的旧类数据,通过 knowledge pooler 合并对应的特征

3.1 Pretrain Stage

本文的 backbone 包含两部分,一部分是 \(\phi_{s}\) 为 Resnet10(第1~3个block)使用800个类训练好后固定,另外一部分是并行连接的特征提取器 \(\phi_{b}\)(第4个block,每个批次数据对应一个,通过复制上一阶段的参数做新阶段参数的初始化)。

3.2 Training pipeline

Stage-I Feature augmentation(FA)

这样通过多个阶段学习,可以得到网络中具有多条分支的第4层参数集合:\(\{\Phi_{b},W_{b},\Phi_{n1},W_{n1},...,\Phi_{nT},W_{nT}\}\) ,其中 b 表示base,是初始阶段参数,nT表示第T个阶段,\(\Phi\)表示特征提取器(最后一层),\(W\) 表示各阶段对应的全连接分类器。后续的讨论,都基于这个特征提取器,后续训练均是在 freeze 这个特征提取器的条件下完成的。

Stage-II Fusion

这一阶段需要解决的问题是如何设计 Figure 3 中对输出特征向量进行融合的网络结构(途中打问号的区域)。针对这个问题,本文探索了两种 baseline 用作对比,即 Figure 4 中的(a)(b)。这两种 baseline 想解决的都是如何选择哪条通路的输出作为最终输出的问题-->是属于 base 的数据\(D_{b}\)(800个初始类),还是属于novel 的数据 \(D_{n}\)(后续学习到的类,文中的例子是T=1,即只有初始阶段和第一个阶段)。

\[\hat{y}_{d}=\hat{r}(x)=argmax_{l}\hat{p}^{(l)}(x;\Phi_{s},\Phi_{d},W_{d}),d\in \{b,n\}
\]

式子中 \(\hat{r}(x)\) 为分支选择函数,输出为 0 即选择base,,输出为 1 即选择 novel

Figure 4 (a)为文中提到的 Confidence-based routing,判断方式为通过对各个独立分类头输出的类别置信度进行比较,选出最大的,从而确定是属于 base 还是 novel,数学表达为

Figure 4 (b)为文中提到的 Learning-based routing,判断方式为对输出的两个特征向量进行拼接,然后使用一个全连接分类器学习如何区分 base 和 novel。数学表达为:(\(\oplus\)为concat操作)

考虑到学习如何区分新旧类时,数据中存在类别样本间的不均衡,本文针对存储样本以及新类样本的损失函数为:

其中,公式 (4) 是 binary cross-enctropy (不同于 cross-entropy 每一类对应输出均会产生loss输出,即同一时刻所有输出对应的通路都可更新,而 CE loss 只有一条通路可更新),\(r=1_{[x\in D_{n}]}\) 是 x 所对应的 onehot label;公式(5) 中 \(\varepsilon\) 为存储样本。考虑到新类旧类数据的不均衡,即 \(|D_{n}| \gg |\varepsilon|\) ,本文对loss进行了均衡化处理。

3.3 General score fusion network

经过上面2种 baseline 的对比测试后,作者提出了一种融合各个高维特征提取器输出的方法,示意图见 Figure 4 (c)(注意,本文是保留各个特征提取器的分类头的)。具体操作要点是:

固定特征提取器的参数 \(\{\Phi_{s},\Phi_{b},W_{b},\Phi_{n1},W_{n1},...,\Phi_{nT},W_{nT}\}\)
模型最终做推断时,只需要使用各个特征提取器的全连接输出logits socore \(z_{d}=W_{d}^{T}h_{d}, W_{d}\in R^{k\times |y_{d}|}\) concat到一起,然后做softmax即可得到各个类别的输出概率
为了能够让各个知识能够在各个分支融合,作者提出使用 \(\varepsilon \cup D_{nt}\) 来学习各个输出 logits 之间的关联知识权重 \(W_{dd'}\in R^{k\times |y_{d}|},d,d'\in \{b,n1,...,nt\},d\neq d'\) 。 \(W_{dd'}\) 表示其为连接第 d 个分支的特征到第 d’ 个分支。d 分支输出与其他分支输出融合的方式为直接相加。融合过程数学表达为

直接使用分类头logits而不是特征向量进行融合?

融合后的各分支输出直接 concat 形成一个完整的输出logits,从而完成推断的工作
\[z_{a}=\tilde{z}_{b}\oplus \tilde{z}_{n1} \oplus \cdot \cdot \cdot \oplus \tilde{z}_{nt}, \tilde{z}_{d}\in R^{|y_{d}|},d\in\{b,n1,...,nt\}
\]

Overlap knowledge integration

针对本文前面提到的新旧批次数据中类别 overlap 的情况,本文方法采取的策略是:直接对相同类别对应的输出做 pooling(average pooling 或 max pooling,实验结果表明 average pooling 比 max pooling效果好)见图 Figure 4 (c)

经过Pooling后的 logits 记为 \(\bar{x}\) ,当且仅当 \(y_{d}\cap y_{d'}=\emptyset\) 时, \(\tilde{z}_{a}=z_{a}\) 成立。

3.3 Balanced optimization

在最终得到一个融合后的拼接向量后,本文方法会 freeze 整个特征提取部分,单独训练全连接分类器,此时,的损失函数为公式(8):

这里同样考虑到了类别样本均衡(如\(|D_{nt}|\gg|\varepsilon|\)),公式中 \(B \in \varepsilon \cup D_{nt}\) 为分别从存储类别数据及最后一个阶段新类数据中采样的均衡数据集,损失函数为交叉熵。

分支输出选择函数可表示为:

其中,\(W_{r,aux}\in R^{(t+1)\times (t+1)}\) 为一个全连接路径分类器的参数。

最终的 loss 函数可表示为:

此外,为了防止训练过程中过度倾向于 base classes,作者会对 \(h_{d}\) 做 normalize 以及 scale。但需注意的是,对于 base 的logits,为了防止融合时过去倾向于base,文中还设置了一个超参数 \(\beta \in [0,1]\) 用于调节 base logits 的值(base logits 先乘以这个系数后在进行fusion操作)

4. Experiments

本文中只用到了一个 ImageNet1000 这个自然图像数据集(说服力有点弱),

4.1 参数\(\alpha,\beta\) 的灵敏度实验

4.2 与其他现有方法的对比(无 Overlap)

本文使用的指标为:

\(Acc_{all}\) ,最后一个阶段结束后,在所有类别数据上的acc
\(Acc_{base}\),第一阶段 800 个类的acc
\(Acc_{novel}\),新类的acc
\(Acc_{ovlp}\),overlap 部分数据的acc
\(Acc_{avg}=\frac{\sum_{d\in\{b,n1,...,nt\}}Acc_{d}}{t+1}\),

说明一下,文中按照 训练集、验证集、测试集 的方式划分数据,因此本文的参数是在验证集中挑选最好的产生。因此就需要选定一个性能指标去挑选验证集上“最好”的模型。下表中的 \(best-Acc_{all},best-balanced,best-Acc_{avg}\) 分别表示使用 \(Acc_{all},balanced=\frac{Acc_{all}+Acc_{avg}}{2},Acc_{avg}\) 指标下选择的模型。

Table 1 为只有一个阶段的新类的结果(800base+novel)

其中, joint learning(oracle) 指使用所有base类进行训练();“fc-only”是在分类器部分只用了全连接层,设置这组实验的目的是为了保证本方法的参数与其他对比方法的参数量接近。(奇怪?)

Table 2 为有多个含新类阶段的结果(800base+novel)

4.2 Fusion策略的对比实验(无overlap)

下表中的标识为:

FA:本文提出的 feature augmentaion(即初始阶段使用800个类训练)
RT:retrain
FT:fineture
FeatCat+RT:重新训练一个输入为 将所有特征concat在一起 的全连接分类器
LogitCat:重新训练(RT)或微调(FT)一个将 logits 拼接的全连接分类器

4.3 Base 特征提取器层数的影响(无Overlap)

4.4 Overlap情况下的表现

4.5 Base中包含所有新类的情况

论文阅读笔记】Class-Incremental Learning with Strong Pre-trained Models的相关教程结束。

《【论文阅读笔记】Class-Incremental Learning with Strong Pre-trained Models.doc》

下载本文的Word格式文档,以方便收藏与打印。