Sansan Tech Blog

Sansanのものづくりを支えるメンバーの技術やデザイン、プロダクトマネジメントの情報を発信

【Intern CV Report】PyTorch Lightningで超解像してみる

f:id:S_aiueo321:20190404193803p:plain

こんにちは,DSOC R&Dグループ インターン生の内田です. 涼しくなってきたにも関わらずエアコンガンガンで寝てしまって風邪気味な今日この頃です. 皆様も体調にはお気をつけくださいませ.

さて,私はかねてより深層学習用フレームワークとしてTensorFlowを使っていたのですが,最近はPyTorchにドップリです. TensorFlowのメジャーアップデートがなかなか来ない*1ために,帰巣のタイミングを失っているという経緯です. PyTorchを書き下すだけでも十分簡単に実装可能ですが,実験サイクルを回したい場合には,もう少し高レベルなラッパがあると便利です. 普段はawesomeなレポジトリを参考にしながらオレオレラッパを書いているのですが,先日Twitterで次のような情報が回ってきました.

これは使ってみるしか.

PyTorch Lightningについて

PyTorch LightningはML研究者向けに設計された軽量なPyTorchラッパで,TensorFlowにおけるKerasに相当するパッケージです*2. 学習ループや早期終了,モデルの保存と読み出しなどを自動化し,新規プロジェクトにおいて都度発生する研究の本質でない手間を減らしてくれます. 実際どれくらい手間を省いてくれるかというと,次の図で説明されています. 青色部分のみを記述すればよくなるらしい.

f:id:S_aiueo321:20190813160317j:plain
生PyTorch vs PyTorch Lightning (レポジトリから引用)
それ以外にも,Test Tube*3と組み合わせることでロギングやコードのバージョニングをいい感じにやってくれたり,マルチノード学習や量子化をサポートしているなどの特徴があるようです. もっと詳細に知りたい方は,下記レポジトリのドキュメントをご覧いただければと思います. github.com

⚡️ライトニングPyTorch Lightning⚡️

最小限の学習コードは次のようになります.

from pytorch_lightning.models.trainer import Trainer

model = MyLightningModule()  # 自分で定義する

trainer = Trainer()
trainer.fit(model)

モデル自体はもちろんですが,データローダの定義とか誤差計算をMyLightningModuleクラス内に定義しておけばすぐに学習を回せるので,モデルの改良に集中できるというわけですね.

Kerasも同じようにfit()一発で学習できると謳ってはいるものの,GANのように複数のモデルを順に学習するモデルではなんとなく癖のある感じになります*4. 正直辛い印象だったので,私個人こういうラッパはあまり使ってきませんでした. 今回はその感情を払拭できればいいなと思い,GANを用いた超解像モデルであるSRGAN[*5] を実装してみます*6

SRGANの実装

以下ではかい摘んで説明するので,多少コードに抜けがあります. わからない部分などはコードは次に置いてありますので参照ください*7

github.com

LightningModule

PyTorch Lightningのモデルクラスはpl.LightningModuleを継承して構築します. pl.LightningModuleはデータローダやオプティマイザの定義,誤差計算など諸々を含むのですが,PyTorchの標準モジュールクラスnn.Moduleの孫クラスに当たるため, 従来通りコンストラクタでモジュールを順々に定義しても良いですし,既存のネットワークを丸々読み込むことも可能です. ネットワークやら自前の誤差関数は事前定義済みとすると,モデルクラスのコンストラクタおよびforward()は次のようになります.

import pytorch_lightning as pl
import torch.nn as nn

from .networks import SRResNet, Discriminator
from .losses import GANLoss, TVLoss, VGGLoss

class SRGANModel(pl.LightningModule):
    def __init__(self, opt):

        super(SRGANModel, self).__init__()

        # パラメータの保存
        self.scale_factor = opt.scale_factor
        self.batch_size = opt.batch_size
        self.patch_size = opt.patch_size

        # ネットワーク定義
        self.net_G = SRResNet(opt.scale_factor, opt.ngf, opt.n_blocks)
        self.net_D = Discriminator(opt.ndf)

        # 誤差関数の定義
        self.criterion_MSE = nn.MSELoss()
        self.criterion_VGG = VGGLoss(net_type='vgg19', layer='relu5_4')
        self.criterion_GAN = GANLoss(gan_mode='wgangp')
        self.criterion_TV = TVLoss()

    def forward(self, input):  # メインのネットワークだけに通す
        return self.net_G(input)

ここまではnn.Moduleと同じ感じですが,PyTorch Lightningではこれに加えて次の3メソッドを定義する必要があります *8

tng_dataloader()

PyTorch標準のDataLoaderインスタンスを返却します.@pl.data_loaderの記述をお忘れなく.

from torch.utils.data.dataloader import DataLoader

from .datasets import DatasetFromFolder

...

@pl.data_loader
def tng_dataloader(self):
    dataset = DatasetFromFolder(
        data_dir='./data/DIV2K/train',
        scale_factor=self.scale_factor,
        patch_size=self.patch_size
    )
    return DataLoader(dataset, self.batch_size, shuffle=True, num_workers=4)

ちなみに今回はdata_dir内の画像をランダムクロップして低/高解像度画像ペアを返却するデータローダを定義しています. 読み込むデータセットは超解像タスクでよく使われるDIV2K[*9]を用います.

学習用だけでなく,バリデーションにも同様のメソッドが定義でき,定義するとよしなにやってくれます. これに関しては複数のデータローダを返却することも可能です.

configure_optimizers()

PyTorch標準のオプティマイザおよび学習率スケジューラを返却します. GANの場合,生成器と識別器にそれぞれオプティマイザを用意するため,学習する順番にリストで格納して返却します. 対応するスケジューラも同様にリストで返却します.

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

...

def configure_optimizers(self):
    optimizer_G = optim.Adam(self.net_G.parameters(), lr=1e-4)
    optimizer_D = optim.Adam(self.net_D.parameters(), lr=1e-4)
    scheduler_G = StepLR(optimizer_G, step_size=1e+5, gamma=0.1)
    scheduler_D = StepLR(optimizer_D, step_size=1e+5, gamma=0.1)
    return [optimizer_D, optimizer_G], [scheduler_D, scheduler_G]

今回は識別器→生成器の順に学習するので,その順番に返却しています.

training_step(data_batch, batch_nb, optimizer_i)

DataLoaderから取得したミニバッチdata_batchをネットワークに順伝播し,誤差を計算します. batch_nboptimizer_iにはそれぞれ何番目のバッチ・オプティマイザかを示す整数が入ってきます.

def training_step(self, data_batch, batch_nb, optimizer_i):
    img_lr = data_batch['lr']
    img_hr = data_batch['hr']

    if optimizer_i == 0:  # 識別器
        self.img_sr = self.forward(img_lr)

        # HR画像
        d_out_real = self.net_D(img_hr)
        d_loss_real = self.criterion_GAN(d_out_real, True)
        # SR画像
        d_out_fake = self.net_D(self.img_sr.detach())
        d_loss_fake = self.criterion_GAN(d_out_fake, False)

        # がっちゃんこ
        d_loss = 1 + d_loss_real + d_loss_fake

        return {'loss': d_loss, 'prog': {'tng/d_loss': d_loss}}

    elif optimizer_i == 1:  # 生成器
        # content loss
        mse_loss = self.criterion_MSE(self.img_sr * 2 - 1, img_hr * 2 - 1)
        vgg_loss = self.criterion_VGG(self.img_sr, img_hr)
        content_loss = (vgg_loss + mse_loss) / 2
        # adversarial loss
        adv_loss = self.criterion_GAN(self.net_D(self.img_sr), True)
        # tv loss
        tv_loss = self.criterion_TV(self.img_sr)

        # がっちゃんこ
        g_loss = content_loss + 1e-3 * adv_loss + 2e-8 * tv_loss

        return {'loss': g_loss, 'prog': {'tng/g_loss': g_loss,
                                         'tng/content_loss': content_loss,
                                         'tng/adv_loss': adv_loss,
                                         'tng/tv_loss': tv_loss}}

GANの学習ではオプティマイザが複数あるため,optimizer_iによって条件分岐してます. 1バッチの学習につき1度だけ推論が走ればいいため,optimizer_i == 0の時の結果をインスタンス変数として確保しておいて,その後は確保した結果を用いて誤差計算を行います. 返却値は辞書型で,'loss'に誤差そのものを格納して,'prog'にはログを取りたい物の名前と値の辞書で格納します. また,tng_dataloader()と同様に,バリデーション用にvalidation_stepを定義するとよしなに実行されます.

その他

'prog'で返した値はログを取ってくれますが,画像などのログは別途自分で登録するようです. その時はself.experimentというメンバがいるので,次のようにすればいいです. 記述方法はTensorBoardに準拠しています.

self.experiment.add_image('tag', image_tensor, self.global_step)
self.experiment.add_histogram('tag', image_tensor, self.global_step)

学習してみる

学習コードは上の例とほぼ変わらずです. 今回は2,000epoch学習してみます. 下記では書いてないですが,ログとかバリデーションの頻度をTrainerの引数で指定できるのでよく使います.

import argparse

from pytorch_lightning.models.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from test_tube import Experiment

import models

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--scale_factor', type=int, default=4)
    ...
    opt = parser.parse_args()

    exp = Experiment(save_dir=f'./logs/')  # 実験管理用のインスタンス
    model = models.SRGANModel(opt)  # モデルのインスタンス化

    # コールバックの定義(重みの保存)
    checkpoint_callback = ModelCheckpoint(
        filepath=exp.get_media_path(exp.name, exp.version),
    )

    # Trainerの定義
    trainer = Trainer(
        experiment=exp,
        max_nb_epochs=2000,
        checkpoint_callback=checkpoint_callback,
        gpus=[0]
    )

    # 学習開始!
    trainer.fit(model)


if __name__ == "__main__":
    main()

学習が始まるとこんな感じでプログレスバーが表示されます.

f:id:S_aiueo321:20190827134709p:plain
学習の様子

TensorBoardのログもいい感じに出てくれます.

f:id:S_aiueo321:20190827133532p:plain
TensorBoardの出力①
f:id:S_aiueo321:20190827133606p:plain
TensorBoardの出力②

結果

2000epoch学習後の結果はこんな感じです.概ねうまくいっていると思います. ちなみに,tng_loader()と同様にtest_loader()というメソッドも存在するのですが,test_loader()を呼び出してテストする機能はまだないらしい*10ので,テストは現状手で書く感じになります.

f:id:S_aiueo321:20190828104246p:plain
2000epoch学習した結果(Set14[*11] 'baboon')

まとめ

ML研究者向けPyTorchラッパであるPyTorch Lightningについて紹介しました. Kerasで辛いと感じた部分はほぼ感じず,SRGANを実装することができました. 元々PyTorchを使っている人ならほぼ学習コストなく使える代物だと感じました. ただし,現状ではいくつか問題があって,ドキュメントが貧弱だったり速度面に不安があるので,すぐすぐ使えるかといったら微妙な気もします. まだまだ発展途上のライブラリということで,今後の発展に期待です!

*1:書いてる間になんか来ました. - https://github.com/tensorflow/tensorflow/releases

*2:本人達もPyTorch Kerasを謳ってます.

*3:同じ人が作ってる実験管理ツールです. - https://github.com/williamFalcon/test-tube

*4:仕様を理解してないだけですが…

*5:Ledig, Christian, et al. "Photo-realistic single image super-resolution using a generative adversarial network." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.

*6:https://buildersbox.corp-sansan.com/entry/2019/04/29/110000 で解説しています.

*7:IssueもPRも歓迎です!

*8:いずれも定義しないとNotImplementedErrorを吐きます.

*9:Agustsson, Eirikur, and Radu Timofte. "Ntire 2017 challenge on single image super-resolution: Dataset and study." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops. 2017.

*10:https://github.com/williamFalcon/pytorch-lightning/issues/89

*11:Zeyde, Roman, Michael Elad, and Matan Protter. "On single image scale-up using sparse-representations." International conference on curves and surfaces. Springer, Berlin, Heidelberg, 2010.

© Sansan, Inc.