TorchSSLで半教師あり学習を試す~FlexMatch編~
TorchSSLは、9つの半教師あり学習手法のPyTorch実装を集めたリポジトリです。今回はその中でもSOTAな手法であるFlexMatchを使った半教師あり学習を試してみたいと思います。FlexMatchについての詳細な解説は以下の記事を参考にしていただければと思います。
ひとまずFlexMatchを動かしたい方は、こちらのGoogle Colabのノートブックを開き、上から順番に実行していただくと良いと思います。上記ノートブックでは、CIFAR-10データセット+ラベル数40(各クラス4枚のデータのみ)の学習が実行できるようになっています。バッチサイズはColab Proでの使用を前提に設定していますので、Free版などメモリが少ないGPUを使用する場合は適宜減らしていただければと思います。
パラメータの確認
FlexMatchの学習パラメータはconfig/flexmatch
内のyamlファイルに記載されている。ここではCIFAR-10+ラベル付きデータ数40の場合の設定ファイルを例に主なパラメータの意味を説明する。まず、yamlファイルに記載されているパラメータは以下の通り。
save_dir: ./saved_models save_name: flexmatch_cifar10_40_1 resume: False load_path: None overwrite: True epoch: 1 num_train_iter: 1048576 num_eval_iter: 5000 num_labels: 40 batch_size: 64 eval_batch_size: 1024 hard_label: True T: 0.5 p_cutoff: 0.95 ulb_loss_ratio: 1.0 uratio: 7 ema_m: 0.999 optim: SGD lr: 0.03 momentum: 0.9 weight_decay: 0.0005 amp: False net: WideResNet net_from_name: False depth: 28 widen_factor: 2 leaky_slope: 0.1 dropout: 0.0 data_dir: ./data dataset: cifar10 train_sampler: RandomSampler num_classes: 10 num_workers: 1 alg: flexmatch seed: 1 world_size: 1 rank: 0 multiprocessing_distributed: True dist_url: tcp://127.0.0.1:10001 dist_backend: nccl gpu: None
上記のうち、FlexMatchのアルゴリズムに関わるパラメータは以下。以下のパラメータはImageNet以外のデータセットでは同じ値を使用している。
num_labels: 40 hard_label: True p_cutoff: 0.95 ulb_loss_ratio: 1.0 uratio: 7
各パラメータの意味を説明する。
num_labels
: 学習に用いるラベル付きデータ数。多ければ多いほどTop-1精度が高くなる。hard_label
: Trueの場合、CE損失計算時にハードラベル(疑似ラベル)を用いる。Falseの場合、ソフトラベル(温度付きソフトマックスの出力値)を用いる。蒸留の考え方に従うならば、ソフトラベルを使用したほうが良さそうな気がするが、半教師あり学習ではうまくいかないのだろうか(要実験)。当該処理の実装はこちらp_cutoff
: Confidenceスコアのしきい値の係数。 の式で示される の値。実装はこちら。Threshold warm-upも実装されている。ulb_loss_ratio
: 損失をバランスするハイパーパラメータ。損失関数のの値。uratio
: ラベル付きデータとラベル無しデータの比率。7の場合、ラベル付き:ラベル無し=1:7の比率で学習。実装はこちら
Data Augmentation実装部
Weak AugmentationはRandomCropとRandomHorizontalFlipのみ。実装部分はこちら。Strong AugmentationはWeak Augmentationに加え、RandAugmentが追加されている。RandAugmentクラスを見るとアフィン変換に加え、CutOutも入れられている。
学習実行
筆者作成のGoogle Colabのノートブックを参照。
学習結果(途中)
Iteration 125000時点(全体の約1/8の進捗)で、Top-1 Acc.は92.3 %となっている。論文によれば、CIFAR-10+40ラベル付きデータのError rateが4.99 %なので、あと3 %くらい精度が上がるはずだが、Colab Proだと約30分で5000 Iterationしか回らないので、学習完了まであと30時間ほどかかる模様。ImageNetデータセットのyamlファイルを見るとCIFAR-10と同じIteration数が設定されているので、CIFAR-10に対してIteration 1048576は値が大きすぎるのでは・・・という気がする。
まとめ
FlexMatchの学習・評価が簡単に実行できるTorchSSLは神。他の手法も試してみたい。