こんにちは,DSOC R&Dグループ インターン生の内田です. 涼しくなってきたにも関わらずエアコンガンガンで寝てしまって風邪気味な今日この頃です. 皆様も体調にはお気をつけくださいませ.
さて,私はかねてより深層学習用フレームワークとしてTensorFlowを使っていたのですが,最近はPyTorchにドップリです. TensorFlowのメジャーアップデートがなかなか来ない*1ために,帰巣のタイミングを失っているという経緯です. PyTorchを書き下すだけでも十分簡単に実装可能ですが,実験サイクルを回したい場合には,もう少し高レベルなラッパがあると便利です. 普段はawesomeなレポジトリを参考にしながらオレオレラッパを書いているのですが,先日Twitterで次のような情報が回ってきました.
@pytorchLightin, the @PyTorch keras for #ml researchers will be added to the official @PyTorch ecosystem next week.
— William Falcon (@_willfalcon) 2019年8月2日
@MILAMontreal @NYUDataScience @berkeley_ai @StanfordAILab @MIT_CSAIL @karpathy @amuellerml @RichardSocher @soumithchintala @VectorInst
https://t.co/lIhXW9Su8V
これは使ってみるしか.
PyTorch Lightningについて
PyTorch LightningはML研究者向けに設計された軽量なPyTorchラッパで,TensorFlowにおけるKerasに相当するパッケージです*2. 学習ループや早期終了,モデルの保存と読み出しなどを自動化し,新規プロジェクトにおいて都度発生する研究の本質でない手間を減らしてくれます. 実際どれくらい手間を省いてくれるかというと,次の図で説明されています. 青色部分のみを記述すればよくなるらしい. それ以外にも,Test Tube*3と組み合わせることでロギングやコードのバージョニングをいい感じにやってくれたり,マルチノード学習や量子化をサポートしているなどの特徴があるようです. もっと詳細に知りたい方は,下記レポジトリのドキュメントをご覧いただければと思います. github.com
⚡️ライトニングPyTorch Lightning⚡️
最小限の学習コードは次のようになります.
from pytorch_lightning.models.trainer import Trainer model = MyLightningModule() # 自分で定義する trainer = Trainer() trainer.fit(model)
モデル自体はもちろんですが,データローダの定義とか誤差計算をMyLightningModule
クラス内に定義しておけばすぐに学習を回せるので,モデルの改良に集中できるというわけですね.
Kerasも同じようにfit()
一発で学習できると謳ってはいるものの,GANのように複数のモデルを順に学習するモデルではなんとなく癖のある感じになります*4.
正直辛い印象だったので,私個人こういうラッパはあまり使ってきませんでした.
今回はその感情を払拭できればいいなと思い,GANを用いた超解像モデルであるSRGAN[*5] を実装してみます*6.
SRGANの実装
以下ではかい摘んで説明するので,多少コードに抜けがあります. わからない部分などはコードは次に置いてありますので参照ください*7.
LightningModule
PyTorch Lightningのモデルクラスはpl.LightningModule
を継承して構築します.
pl.LightningModule
はデータローダやオプティマイザの定義,誤差計算など諸々を含むのですが,PyTorchの標準モジュールクラスnn.Module
の孫クラスに当たるため,
従来通りコンストラクタでモジュールを順々に定義しても良いですし,既存のネットワークを丸々読み込むことも可能です.
ネットワークやら自前の誤差関数は事前定義済みとすると,モデルクラスのコンストラクタおよびforward()
は次のようになります.
import pytorch_lightning as pl import torch.nn as nn from .networks import SRResNet, Discriminator from .losses import GANLoss, TVLoss, VGGLoss class SRGANModel(pl.LightningModule): def __init__(self, opt): super(SRGANModel, self).__init__() # パラメータの保存 self.scale_factor = opt.scale_factor self.batch_size = opt.batch_size self.patch_size = opt.patch_size # ネットワーク定義 self.net_G = SRResNet(opt.scale_factor, opt.ngf, opt.n_blocks) self.net_D = Discriminator(opt.ndf) # 誤差関数の定義 self.criterion_MSE = nn.MSELoss() self.criterion_VGG = VGGLoss(net_type='vgg19', layer='relu5_4') self.criterion_GAN = GANLoss(gan_mode='wgangp') self.criterion_TV = TVLoss() def forward(self, input): # メインのネットワークだけに通す return self.net_G(input)
ここまではnn.Module
と同じ感じですが,PyTorch Lightningではこれに加えて次の3メソッドを定義する必要があります *8.
tng_dataloader()
PyTorch標準のDataLoader
インスタンスを返却します.@pl.data_loader
の記述をお忘れなく.
from torch.utils.data.dataloader import DataLoader from .datasets import DatasetFromFolder ... @pl.data_loader def tng_dataloader(self): dataset = DatasetFromFolder( data_dir='./data/DIV2K/train', scale_factor=self.scale_factor, patch_size=self.patch_size ) return DataLoader(dataset, self.batch_size, shuffle=True, num_workers=4)
ちなみに今回はdata_dir
内の画像をランダムクロップして低/高解像度画像ペアを返却するデータローダを定義しています.
読み込むデータセットは超解像タスクでよく使われるDIV2K[*9]を用います.
学習用だけでなく,バリデーションにも同様のメソッドが定義でき,定義するとよしなにやってくれます. これに関しては複数のデータローダを返却することも可能です.
configure_optimizers()
PyTorch標準のオプティマイザおよび学習率スケジューラを返却します. GANの場合,生成器と識別器にそれぞれオプティマイザを用意するため,学習する順番にリストで格納して返却します. 対応するスケジューラも同様にリストで返却します.
import torch.optim as optim from torch.optim.lr_scheduler import StepLR ... def configure_optimizers(self): optimizer_G = optim.Adam(self.net_G.parameters(), lr=1e-4) optimizer_D = optim.Adam(self.net_D.parameters(), lr=1e-4) scheduler_G = StepLR(optimizer_G, step_size=1e+5, gamma=0.1) scheduler_D = StepLR(optimizer_D, step_size=1e+5, gamma=0.1) return [optimizer_D, optimizer_G], [scheduler_D, scheduler_G]
今回は識別器→生成器の順に学習するので,その順番に返却しています.
training_step(data_batch, batch_nb, optimizer_i)
DataLoader
から取得したミニバッチdata_batch
をネットワークに順伝播し,誤差を計算します.
batch_nb
,optimizer_i
にはそれぞれ何番目のバッチ・オプティマイザかを示す整数が入ってきます.
def training_step(self, data_batch, batch_nb, optimizer_i): img_lr = data_batch['lr'] img_hr = data_batch['hr'] if optimizer_i == 0: # 識別器 self.img_sr = self.forward(img_lr) # HR画像 d_out_real = self.net_D(img_hr) d_loss_real = self.criterion_GAN(d_out_real, True) # SR画像 d_out_fake = self.net_D(self.img_sr.detach()) d_loss_fake = self.criterion_GAN(d_out_fake, False) # がっちゃんこ d_loss = 1 + d_loss_real + d_loss_fake return {'loss': d_loss, 'prog': {'tng/d_loss': d_loss}} elif optimizer_i == 1: # 生成器 # content loss mse_loss = self.criterion_MSE(self.img_sr * 2 - 1, img_hr * 2 - 1) vgg_loss = self.criterion_VGG(self.img_sr, img_hr) content_loss = (vgg_loss + mse_loss) / 2 # adversarial loss adv_loss = self.criterion_GAN(self.net_D(self.img_sr), True) # tv loss tv_loss = self.criterion_TV(self.img_sr) # がっちゃんこ g_loss = content_loss + 1e-3 * adv_loss + 2e-8 * tv_loss return {'loss': g_loss, 'prog': {'tng/g_loss': g_loss, 'tng/content_loss': content_loss, 'tng/adv_loss': adv_loss, 'tng/tv_loss': tv_loss}}
GANの学習ではオプティマイザが複数あるため,optimizer_i
によって条件分岐してます.
1バッチの学習につき1度だけ推論が走ればいいため,optimizer_i == 0
の時の結果をインスタンス変数として確保しておいて,その後は確保した結果を用いて誤差計算を行います.
返却値は辞書型で,'loss'
に誤差そのものを格納して,'prog'
にはログを取りたい物の名前と値の辞書で格納します.
また,tng_dataloader()
と同様に,バリデーション用にvalidation_step
を定義するとよしなに実行されます.
その他
'prog'
で返した値はログを取ってくれますが,画像などのログは別途自分で登録するようです.
その時はself.experiment
というメンバがいるので,次のようにすればいいです.
記述方法はTensorBoardに準拠しています.
self.experiment.add_image('tag', image_tensor, self.global_step) self.experiment.add_histogram('tag', image_tensor, self.global_step)
学習してみる
学習コードは上の例とほぼ変わらずです.
今回は2,000epoch学習してみます.
下記では書いてないですが,ログとかバリデーションの頻度をTrainer
の引数で指定できるのでよく使います.
import argparse from pytorch_lightning.models.trainer import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from test_tube import Experiment import models def main(): parser = argparse.ArgumentParser() parser.add_argument('--scale_factor', type=int, default=4) ... opt = parser.parse_args() exp = Experiment(save_dir=f'./logs/') # 実験管理用のインスタンス model = models.SRGANModel(opt) # モデルのインスタンス化 # コールバックの定義(重みの保存) checkpoint_callback = ModelCheckpoint( filepath=exp.get_media_path(exp.name, exp.version), ) # Trainerの定義 trainer = Trainer( experiment=exp, max_nb_epochs=2000, checkpoint_callback=checkpoint_callback, gpus=[0] ) # 学習開始! trainer.fit(model) if __name__ == "__main__": main()
学習が始まるとこんな感じでプログレスバーが表示されます.
TensorBoardのログもいい感じに出てくれます.
結果
2000epoch学習後の結果はこんな感じです.概ねうまくいっていると思います.
ちなみに,tng_loader()
と同様にtest_loader()
というメソッドも存在するのですが,test_loader()
を呼び出してテストする機能はまだないらしい*10ので,テストは現状手で書く感じになります.
まとめ
ML研究者向けPyTorchラッパであるPyTorch Lightningについて紹介しました. Kerasで辛いと感じた部分はほぼ感じず,SRGANを実装することができました. 元々PyTorchを使っている人ならほぼ学習コストなく使える代物だと感じました. ただし,現状ではいくつか問題があって,ドキュメントが貧弱だったり速度面に不安があるので,すぐすぐ使えるかといったら微妙な気もします. まだまだ発展途上のライブラリということで,今後の発展に期待です!
*1:書いてる間になんか来ました. - https://github.com/tensorflow/tensorflow/releases
*2:本人達もPyTorch Kerasを謳ってます.
*3:同じ人が作ってる実験管理ツールです. - https://github.com/williamFalcon/test-tube
*4:仕様を理解してないだけですが…
*5:Ledig, Christian, et al. "Photo-realistic single image super-resolution using a generative adversarial network." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
*6:https://buildersbox.corp-sansan.com/entry/2019/04/29/110000 で解説しています.
*7:IssueもPRも歓迎です!
*8:いずれも定義しないとNotImplementedErrorを吐きます.
*9:Agustsson, Eirikur, and Radu Timofte. "Ntire 2017 challenge on single image super-resolution: Dataset and study." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops. 2017.
*10:https://github.com/williamFalcon/pytorch-lightning/issues/89
*11:Zeyde, Roman, Michael Elad, and Matan Protter. "On single image scale-up using sparse-representations." International conference on curves and surfaces. Springer, Berlin, Heidelberg, 2010.