【論文読み】FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
FixMatchと呼ばれる半教師あり学習を用いた画像分類モデルの学習手法の論文を読んでみたいと思います。FixMatchは2020年1月にGoogleによって提案された手法で、NeurIPS 2020に採択されています。
なお、本手法をベースとした改良版のFlexMatchが2021年10月に提案されており、本手法は現在ではやや古い手法となっています。FlexMatchに興味がある方は以下の記事も参考にしていただければと思います。
概要
- FixMatchは、2つの一般的な半教師あり学習手法である consistency regularization (一貫性正則化) と pseudo-labeling (疑似ラベル付け) を組み合わせた手法
- 弱いデータ拡張を加えたラベル無しデータに対し信頼度の高い疑似ラベルを生成。さらに、同じラベル無しデータに対し強いデータ拡張を加え、疑似ラベルと予測結果が一致するように学習
- 250ラベルのCIFAR-10で94.93%の精度、40ラベルで88.61%の精度など、さまざまな標準的な半教師あり学習ベンチマークでSOTAを達成
FixMatchとは
FixMatchでは、ラベル無しデータに弱いデータ拡張を加えたデータを学習対象のモデルで推論 (Fig. 1の上側のパス) し、予測スコアがしきい値以上のクラスをラベル無しデータの擬似ラベルとして付与する。同時に、ラベル無しデータに強いデータ拡張を加えたデータを推論 (Fig. 1の下側のパス) し、この予測分布と前述の疑似ラベルとの間でCross Entropy Lossを取る。通常の疑似ラベルを用いた学習とは異なり、2つの推論結果(うち1つは疑似ラベル)のConsistency Lossを取ることで正則化の効果がある。
ラベル有りデータに対しても、通常の教師あり学習と同様にCross Entropy Lossによる損失を計算する。FixMatchでは、これらの2つのCross Entropy Lossを最小化するように学習する。
Consistency regularizationとは
Consistency regularizationは以下の式で表される。
上記の式で、ラベル無しデータ に対する弱いデータ拡張 は確率的な関数であるため、データ拡張後のデータ も確率的な値となり、2つの項は同じになっているが同じ値にはならない点に注意。 を敵対的変換に置き換えた手法 (Virtual Adversarial Training)、モデル に移動平均を用いたモデルを使用する手法 (Mean teacher)、 loss ではなく、Cross Entropy Loss を使用する手法などが提案されている。
Pseudo-labelingとは
以下の式で、ラベル無しデータに対するモデルの予測スコア の最大値がしきい値 以上の場合、 から求まるOne-hotラベルと、予測スコア の損失 が最小となるように学習する手法である。
(補足:上式中の1は数字の1ではなく、指示関数を表しています。はてブで指示関数の表示方法が分からないため、数字の1で代用しました)
FixMatchの損失関数
ラベル付きデータに対するCE損失 と ラベル無しデータに対するCE損失 の重み付き和 がFixMatchの損失関数である。ここで、 , であり、 は2つの損失をバランスするハイパーパラメータである。また、 はラベル無しデータに対する強いデータ拡張 (RandAugmentなど) を表す。一般的な半教師あり学習では、学習が進むにつれて の値を大きくするが、FixMatchでは不要であることがわかった。これは、学習初期は が より小さいため、信頼性の低い疑似ラベルが無視され、カリキュラム学習のような働きをするためであると考えられる。
FixMatchにおけるAugmentation
- 弱いデータ拡張:SVHNを除くすべてのデータセットで、RandomHorizontalFlip (確率50%) + Translation (垂直方向と水平方向に最大12.5%)
- 強いデータ拡張:RandAugment or CTAugment + CutOut
AutoAugmentを使用しない理由は、データ拡張戦略の学習に大量のラベル付きデータが必要であるため。
その他の重要な要素
実験
使用したデータセットとモデルの組み合わせは以下の通り。
データセット | モデル |
---|---|
CIFAR-10 | Wide ResNet-28-2 |
CIFAR-100 | Wide ResNet-28-8 |
SVHN | Wide ResNet-28-2 |
STL-10 | Wide ResNet-37-2 |
ImageNet | ResNet-50 |
使用したハイパーパラメータは以下の通り。
- ImageNet以外:, はSGDのmomentum, はIteration数を表す。
- ImageNet:論文のAppendix C を参照。
以下のTable 2より、FixMatchが一部のデータセットを除いて、他の手法より高い精度を達成。
結論
- FixMatchは他の手法と比較してハイパーパラメータが少なく、かつ高い精度が得られる
- 重要なハイパーパラメータはWeight DecayとOptimizerである (論文のAppendix B.3, B.4, B.6を参照)