TorchSSLで半教師あり学習を試す~FlexMatch編~

TorchSSLは、9つの半教師あり学習手法のPyTorch実装を集めたリポジトリです。今回はその中でもSOTAな手法であるFlexMatchを使った半教師あり学習を試してみたいと思います。FlexMatchについての詳細な解説は以下の記事を参考にしていただければと思います。

sek165-ai.hatenablog.jp

github.com

ひとまずFlexMatchを動かしたい方は、こちらのGoogle Colabのノートブックを開き、上から順番に実行していただくと良いと思います。上記ノートブックでは、CIFAR-10データセット+ラベル数40(各クラス4枚のデータのみ)の学習が実行できるようになっています。バッチサイズはColab Proでの使用を前提に設定していますので、Free版などメモリが少ないGPUを使用する場合は適宜減らしていただければと思います。

パラメータの確認

FlexMatchの学習パラメータはconfig/flexmatch内のyamlファイルに記載されている。ここではCIFAR-10+ラベル付きデータ数40の場合の設定ファイルを例に主なパラメータの意味を説明する。まず、yamlファイルに記載されているパラメータは以下の通り。

save_dir: ./saved_models
save_name: flexmatch_cifar10_40_1
resume: False
load_path: None
overwrite: True
epoch: 1
num_train_iter: 1048576
num_eval_iter: 5000
num_labels: 40
batch_size: 64
eval_batch_size: 1024
hard_label: True
T: 0.5
p_cutoff: 0.95
ulb_loss_ratio: 1.0
uratio: 7
ema_m: 0.999
optim: SGD
lr: 0.03
momentum: 0.9
weight_decay: 0.0005
amp: False
net: WideResNet
net_from_name: False
depth: 28
widen_factor: 2
leaky_slope: 0.1
dropout: 0.0
data_dir: ./data
dataset: cifar10
train_sampler: RandomSampler
num_classes: 10
num_workers: 1
alg: flexmatch
seed: 1
world_size: 1
rank: 0
multiprocessing_distributed: True
dist_url: tcp://127.0.0.1:10001
dist_backend: nccl
gpu: None

上記のうち、FlexMatchのアルゴリズムに関わるパラメータは以下。以下のパラメータはImageNet以外のデータセットでは同じ値を使用している。

num_labels: 40
hard_label: True
p_cutoff: 0.95
ulb_loss_ratio: 1.0
uratio: 7

各パラメータの意味を説明する。

  • num_labels: 学習に用いるラベル付きデータ数。多ければ多いほどTop-1精度が高くなる。
  • hard_label: Trueの場合、CE損失計算時にハードラベル(疑似ラベル)を用いる。Falseの場合、ソフトラベル(温度付きソフトマックスの出力値)を用いる。蒸留の考え方に従うならば、ソフトラベルを使用したほうが良さそうな気がするが、半教師あり学習ではうまくいかないのだろうか(要実験)。当該処理の実装はこちら
  • p_cutoff: Confidenceスコアのしきい値の係数。 \tau_{t}(c)=\beta_{t}(c)\cdot\tau の式で示される  \tau の値。実装はこちら。Threshold warm-upも実装されている。
  • ulb_loss_ratio: 損失をバランスするハイパーパラメータ。損失関数 L_{t}=L_{s}+\lambda L_{u, t}  \lambda の値。
  • uratio: ラベル付きデータとラベル無しデータの比率。7の場合、ラベル付き:ラベル無し=1:7の比率で学習。実装はこちら

Data Augmentation実装部

Weak AugmentationはRandomCropとRandomHorizontalFlipのみ。実装部分はこちら。Strong AugmentationはWeak Augmentationに加え、RandAugmentが追加されている。RandAugmentクラスを見るとアフィン変換に加え、CutOutも入れられている。

学習実行

筆者作成のGoogle Colabのノートブックを参照。

学習結果(途中)

Iteration 125000時点(全体の約1/8の進捗)で、Top-1 Acc.は92.3 %となっている。論文によれば、CIFAR-10+40ラベル付きデータのError rateが4.99 %なので、あと3 %くらい精度が上がるはずだが、Colab Proだと約30分で5000 Iterationしか回らないので、学習完了まであと30時間ほどかかる模様。ImageNetデータセットのyamlファイルを見るとCIFAR-10と同じIteration数が設定されているので、CIFAR-10に対してIteration 1048576は値が大きすぎるのでは・・・という気がする。

まとめ

FlexMatchの学習・評価が簡単に実行できるTorchSSLは神。他の手法も試してみたい。