【論文読み】FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

FixMatchと呼ばれる半教師あり学習を用いた画像分類モデルの学習手法の論文を読んでみたいと思います。FixMatchは2020年1月にGoogleによって提案された手法で、NeurIPS 2020に採択されています。

arxiv.org

なお、本手法をベースとした改良版のFlexMatchが2021年10月に提案されており、本手法は現在ではやや古い手法となっています。FlexMatchに興味がある方は以下の記事も参考にしていただければと思います。

sek165-ai.hatenablog.jp

概要

  • FixMatchは、2つの一般的な半教師あり学習手法である consistency regularization (一貫性正則化) と pseudo-labeling (疑似ラベル付け) を組み合わせた手法
  • 弱いデータ拡張を加えたラベル無しデータに対し信頼度の高い疑似ラベルを生成。さらに、同じラベル無しデータに対し強いデータ拡張を加え、疑似ラベルと予測結果が一致するように学習
  • 250ラベルのCIFAR-10で94.93%の精度、40ラベルで88.61%の精度など、さまざまな標準的な半教師あり学習ベンチマークでSOTAを達成

FixMatchとは

f:id:sek_165:20211026230625p:plain

FixMatchでは、ラベル無しデータに弱いデータ拡張を加えたデータを学習対象のモデルで推論 (Fig. 1の上側のパス) し、予測スコアがしきい値以上のクラスをラベル無しデータの擬似ラベルとして付与する。同時に、ラベル無しデータに強いデータ拡張を加えたデータを推論 (Fig. 1の下側のパス) し、この予測分布と前述の疑似ラベルとの間でCross Entropy Lossを取る。通常の疑似ラベルを用いた学習とは異なり、2つの推論結果(うち1つは疑似ラベル)のConsistency Lossを取ることで正則化の効果がある。

ラベル有りデータに対しても、通常の教師あり学習と同様にCross Entropy Lossによる損失を計算する。FixMatchでは、これらの2つのCross Entropy Lossを最小化するように学習する。

Consistency regularizationとは

Consistency regularizationは以下の式で表される。

 \displaystyle{\sum_{b=1}^{\mu B} \lVert p_m(y|\alpha(u_b)) - p_m(y|\alpha(u_b)) \rVert _2^2}

上記の式で、ラベル無しデータ  u_b に対する弱いデータ拡張  \alpha(\cdot) は確率的な関数であるため、データ拡張後のデータ  \alpha(u_b) も確率的な値となり、2つの項は同じになっているが同じ値にはならない点に注意。 \alpha を敵対的変換に置き換えた手法 (Virtual Adversarial Training)、モデル  p_{m}移動平均を用いたモデルを使用する手法 (Mean teacher)、 \ell^2 loss ではなく、Cross Entropy Loss を使用する手法などが提案されている。

Pseudo-labelingとは

以下の式で、ラベル無しデータに対するモデルの予測スコア  q_{b}=p_{m}(y|u_{b}) の最大値がしきい値  \tau 以上の場合、 \hat{q_{b}}=argmax (q_{b}) から求まるOne-hotラベルと、予測スコア  q_{b} の損失  H(\hat{q_{b}}, q_{b}) が最小となるように学習する手法である。

 \displaystyle{\frac{1}{\mu  B} \sum_{b=1}^{\mu B}}  1(max(q_{b}) \geq \tau) H(\hat{q_{b}}, q_{b})

(補足:上式中の1は数字の1ではなく、指示関数を表しています。はてブで指示関数の表示方法が分からないため、数字の1で代用しました)

FixMatchの損失関数

ラベル付きデータに対するCE損失  \ell_s と ラベル無しデータに対するCE損失  \ell_u の重み付き和  \ell_s + \lambda_u\ell_u がFixMatchの損失関数である。ここで、 \ell_s=\displaystyle{\frac{1}{B} \sum_{b=1}^{\mu B} H(p_b, p_m(y | \alpha (x_b)))} ,  \ell_u=\displaystyle{\frac{1}{\mu  B} \sum_{b=1}^{\mu B}}  1(max(q_{b}) \geq \tau) H(\hat{q_{b}}, p_m(y | \mathcal{A}(u_b))) であり、 \lambda_u は2つの損失をバランスするハイパーパラメータである。また、 \mathcal{A}(\cdot) はラベル無しデータに対する強いデータ拡張 (RandAugmentなど) を表す。一般的な半教師あり学習では、学習が進むにつれて  \lambda_u の値を大きくするが、FixMatchでは不要であることがわかった。これは、学習初期は  max(q_b) \tau より小さいため、信頼性の低い疑似ラベルが無視され、カリキュラム学習のような働きをするためであると考えられる。

FixMatchにおけるAugmentation

  • 弱いデータ拡張:SVHNを除くすべてのデータセットで、RandomHorizontalFlip (確率50%) + Translation (垂直方向と水平方向に最大12.5%)
  • 強いデータ拡張:RandAugment or CTAugment + CutOut

AutoAugmentを使用しない理由は、データ拡張戦略の学習に大量のラベル付きデータが必要であるため。

その他の重要な要素

  • OptimizerはAdamよりSGDの方が良い
  • LR schedulerはCosine learning rateが良い

実験

使用したデータセットとモデルの組み合わせは以下の通り。

データセット モデル
CIFAR-10 Wide ResNet-28-2
CIFAR-100 Wide ResNet-28-8
SVHN Wide ResNet-28-2
STL-10 Wide ResNet-37-2
ImageNet ResNet-50

使用したハイパーパラメータは以下の通り。

  • ImageNet以外: \lambda_u=1, \eta=0.03, \beta=0.9, \tau=0.95, \mu=7, B=64, K=2^{20},  \betaSGDのmomentum,  K はIteration数を表す。
  • ImageNet:論文のAppendix C を参照。

以下のTable 2より、FixMatchが一部のデータセットを除いて、他の手法より高い精度を達成。

f:id:sek_165:20211103224017p:plain

結論

  • FixMatchは他の手法と比較してハイパーパラメータが少なく、かつ高い精度が得られる
  • 重要なハイパーパラメータはWeight DecayとOptimizerである (論文のAppendix B.3, B.4, B.6を参照)