Sansan Tech Blog

Sansanのものづくりを支えるメンバーの技術やデザイン、プロダクトマネジメントの情報を発信

【Intern CV Report】超解像の歴史探訪 -SRGAN編-

f:id:S_aiueo321:20190404193803p:plain

こんにちは,DSOC R&D インターン生の内田です. 時が流れるのは早いものでもう4月末,連載も2回目を迎えました. もう少しで平成も終わってしまうので,平成生まれの僕は少し寂しさを感じています.

さて,前回の記事では2016年に発表された超解像モデルについて紹介しました. 今回は2017年にフォーカスしたいと思います. その中でも特に有名なSRGAN[*1]についてまとめます. なお,本記事では以前に紹介した手法や略語が前置き無しに登場する場合がありますので,不明な点がございましたら過去の記事をご参照ください. buildersbox.corp-sansan.com buildersbox.corp-sansan.com

GAN

SRGANの説明に入る前に,流行りのGANについて触れておきます. ご存知の方は読み飛ばしていただければと思います.

GANはGenerative Adversarial Networkの略であり,日本語に訳すと「敵対的生成ネットワーク」となります. 「生成モデル」を「敵対的学習」によって獲得する枠組みであるためこのような名前がついています. 「生成モデル」とは,データを生成する確率分布を学習し新たなデータをサンプルできるモデルを指します. 「敵対的学習」とは,生成器  G を学習するために,識別器  D を同時に学習する方法を指します.

 Dには学習データ及び  G の生成結果を入力し,入力が学習データからサンプリングされている確率を出力します. ざっくり言うと,  D は「本物っぽさ」を出力します.  D の学習では,学習データの分布  p_\mathrm{real} からサンプルされた入力に対して 1,生成結果の分布  p_\mathrm{fake} からサンプルされた入力に対して 0を出力するようにしたいため,次の学習基準 \mathcal{C}_Dを最大化すればよいと言えます. $$ \mathcal{C}_D = \begin{cases} \log D(\boldsymbol{x}) & \text{if $\boldsymbol{x} \sim p_\mathrm{real}$} \\ \log \left( 1-D(\boldsymbol{x}) \right) & \text{if $\boldsymbol{x} \sim p_\mathrm{fake}$} \end{cases} $$

 G は,  D を間違えさせるようなデータを生成することを目的とします. したがって, G の学習では,何らかの分布 p_{\rm z}からサンプルされた入力ベクトル  \boldsymbol{z}に対して,次の学習基準 \mathcal{C}_Gを最小化すればよいと言えます. $$ \mathcal{C}_G = \log(1-D(G(\boldsymbol{z}))) \text{where $\boldsymbol{z} \sim p_\mathrm{z}$}$$

 \boldsymbol{x} \sim p_\mathrm{fake}のとき \boldsymbol{x}\ = G(\boldsymbol{z}) であることに注目して,  \mathcal{C}_G \mathcal{C}_D をまとめると次の目的関数 V(D, G)のようになります. ただし, D G とで最適化の方向が異なるため,GANの学習は V(D, G)のmin-max最適化として定式化されます. $$ \mathop{\rm min}\limits_{G}\mathop{\rm max}\limits_{D} V(D,G) = \mathbb{E}_{\boldsymbol{x} \sim p_{\rm data}}[\log D(\boldsymbol{x})] + \mathbb{E}_{\boldsymbol{z} \sim p_{\rm z}}[\log (1-D(G(\boldsymbol{z})))] $$

ここで紹介したのは,Vanilla GANという最もベーシックな手法です. SRGANをはじめとするGANの派生では誤差関数にMSE(Mean Square Error)などの項も入ってくるため,必ずしも上記の形をとる訳ではありませんが, D Gを逆方向に最適化するというお気持ちは変わりません.

SRGAN

SRGAN(Generative Adversarial Network for Super-Resolurion)は超解像にGANを初めて適用したモデルで,論文のタイトルに「Photo-realistic」とあるように,従来手法に比べて結果の知覚品質が高く,自然であるとして有名です. 近年提案されている超解像モデルの多くがSRGANをベースとしていて,直近のECCV2018に併設されたPIRM Challenge on Perceptual Super Resolution[*2]においても,SRGANの拡張であるESRGAN[*3]が知覚品質を競うトラックで優勝を納めています.

以下では,SRGANの構造の解説と実装,GANを導入する効果について実験を行います. コードは下記で公開しておりますので,不明な点や詳細はぜひこちらでご覧ください.「Issue」を投げていただいたら,喜んでお受けします. github.com

また,実装にあたり次のレポジトリを参考にしています.

生成器

SRGANの生成器はResNet[*4]を参考にしており,識別器を伴わずに学習した場合SRResNetと呼ばれます. ResNetは,畳み込みの入出力を足し合わせて出力するResidual Blockを積み重ねたネットワークです. SRGANでも同様に,次のようなResidual Blockを積み重ねた構造をしています.

f:id:S_aiueo321:20190408132248p:plain
SRResNetの構造(x4)
ここでは,図中の各ブロックを定義した後,生成器を組んでいきます.

Residual Block

畳み込み,バッチ正規化,活性化を直列に適用するため,nn.Sequentialを用いて1つのモジュールにまとめることができます. forward()で,入力xとコンストラクタで定義したモジュールの出力self.net(x)を足し合わせて返り値とします.

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return self.net(x) + x

Upsample Block

ここでは,前回紹介したESPCN同様,Subpixel Convolutionを用います. まず,入力特徴マップを倍率upscale_factorの2乗になるよう,畳み込みを用いてチャネル方向に拡張します. 次に,拡張した特徴マップを再配置します.PyTorchではnn.PixelShuffleとして用意されているのでそちらを用います. 最後の活性化も含め,nn.Sequentialを用いて1つにまとめます.

class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, upscale_factor):
        super(UpsampleBLock, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels *
                      (upscale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor),
            nn.PReLU()
        )

    def forward(self, x):
        return self.net(x)

全体

SRResNetにはSkip-connectionがあり,単純にnn.Sequentialでまとめられないため,図のようにHead,Body,Tailの3部分に分けます. Residual Blockは16回,UpsampleBLockはint(log2(upscale_factor))回繰り返すように内包表記で記述します. また,出力はnn.Tanhを用いて (-1, 1)に正規化し,forward時に (0, 1)にスケーリングします.

class Generator(nn.Module):
    def __init__(self, scale_factor=4):
        super(Generator, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.body = nn.Sequential(
            *[ResidualBlock(64) for _ in range(16)],
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.PReLU()
        )
        self.tail = nn.Sequential(
            *[UpsampleBLock(64, 2) for _ in range(int(log2(scale_factor)))],
            nn.Conv2d(64, 3, kernel_size=9, padding=4),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.head(x)
        x = self.body(x) + x
        x = self.tail(x)
        return (x + 1) / 2

識別器

識別器は次のような構造をしています.

f:id:S_aiueo321:20190415110340p:plain
識別器の構造
単純に畳み込みを積み重ねていくだけなので,nn.Sequentialにまとめて書いてしまいます. torch.nn内には画像をフラットにするモジュールが定義されていないため,nn.Sequentialに組み込めるように,自前でFlattenモジュールを定義しておきます.

class Discriminator(nn.Module):
    def __init__(self, patch_size=96):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            Flatten(),
            nn.Linear(512 * (patch_size // 2 ** 4) ** 2, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)

損失関数

SRGANの学習では,生成器と識別器にそれぞれGenerator LossとDiscriminator Lossを定義します. Discriminator Lossは一般的なGANと変わらないですが,Generator Lossにはひと工夫あるため,ここではDiscriminator Lossをさらっと流してGenerator Lossについて重点的に解説します.

Discriminator Loss

識別器の解く問題は画像が本物かどうかの二値分類問題なので, nn.BCELossを用いて誤差関数を定義します. forward()には高解像度画像と超解像画像を入力し,高解像度画像を本物,超解像画像を偽物と判別できているか誤差を計算しています. 少しややこしいですが,以後もrealとつく変数は高解像度画像,fakeとつく変数は超解像画像に関わる変数とします.

class DiscriminatorLoss(nn.Module):
    def __init__(self):
        super(DiscriminatorLoss, self).__init__()
        self.bce_loss = nn.BCELoss()

    def forward(self, d_out_real, d_out_fake):
        loss_real = self.bce_loss(d_out_real, torch.ones_like(d_out_real))
        loss_fake = self.bce_loss(d_out_fake, torch.zeros_like(d_out_fake))
        return loss_real + loss_fake

Generator Loss

Generator Lossでは,生成画像と正解画像を何らかの距離尺度で比較するContent Loss,識別器の出力を最小化するAdversarial Lossを同時に最適化します. 論文ではContent LossにMSEとVGG Loss(後述)のどちらかを用いますが,手元での学習が安定しないため,今回は両方を足し合わせて用います. 実装すると次のようになります.

class GeneratorLoss(nn.Module):
    def __init__(self, loss_type='vgg22', adv_coefficient=1e-3):
        super(GeneratorLoss, self).__init__()
        self.content_loss = VGGLoss(loss_type)
        self.mse_loss = nn.MSELoss()
        self.adv_coefficient = adv_coefficient

    def forward(self, d_out_fake, real_img, fake_img):
        mse_loss = self.mse_loss(real_img, fake_img)
        content_loss = self.content_loss(real_img, fake_img)
        adv_loss = torch.mean(-torch.log(d_out_fake + 1e-3))
        return mse_loss + 2e-6 * content_loss + self.adv_coefficient * adv_loss

VGG Lossは,生成画像と正解画像をVGG19[*5]に入力し,中間特徴マップのMSEを取ることで画像を比較します.  \phi_{i,j} j番目のプーリング層の前の i番目の畳み込み層を表すとして,中間特徴マップには \phi_{2, 2} \phi_{5, 4}を用います. 一般に,MSEで学習したモデルはぼやけた画像を生成すると言われていますが,VGG Lossを用いることで改善されると言われています. VGG Lossの中身は次のようになっています.

from torchvision.models.vgg import vgg19

class VGGLoss(nn.Module):
    def __init__(self, loss_type):
        super(VGGLoss, self).__init__()
        vgg = vgg19(pretrained=True)
        if loss_type == 'vgg22':
            vgg_net = nn.Sequential(*list(vgg.features[:9]))
        elif loss_type == 'vgg54':
            vgg_net = nn.Sequential(*list(vgg.features[:36]))
        
        for param in vgg_net.parameters():
            param.requires_grad = False

        self.vgg_net = vgg_net.eval()
        self.mse_loss = nn.MSELoss()

        self.register_buffer('vgg_mean', torch.tensor([0.485, 0.456, 0.406], requires_grad=False))
        self.register_buffer('vgg_std', torch.tensor([0.229, 0.224, 0.225], requires_grad=False))

    def forward(self, real_img, fake_img):
        real_img = real_img.sub(self.vgg_mean[:, None, None]).div(self.vgg_std[:, None, None])
        fake_img = fake_img.sub(self.vgg_mean[:, None, None]).div(self.vgg_std[:, None, None])
        feature_real = self.vgg_net(real_img)
        feature_fake = self.vgg_net(fake_img)
        return self.mse_loss(feature_real, feature_fake)

コンストラクタでは,vgg19の学習済みモデルをロードして指定した層までをnn.Sequentialに確保しておきます. このとき,パラメータ更新は必要ないのでrequires_grad=Falseとして勾配計算を行わないようにします. また,VGGの入力をスケーリングするための平均,標準偏差をバッファとして確保しておきます. forward()では,確保した値を用いて入力をスケーリングし,中間特徴を計算した後,二乗誤差を計算しています.

学習

SRGANの学習では,ミニバッチ毎に識別器と生成器を1回ずつ学習します. データローダとネットワーク,オプティマイザはすでに定義済みとして,1epoch分のループを書くと次のようになります. GANの学習では,同じ計算グラフに2回以上誤差を逆伝播させることがよくあるので,backward(retain_graph=True)とすることを忘れないように注意が必要です.

for iteration, (input_img, real_img) in enumerate(train_loader, 1):
    input_img = input_img.to(device)
    real_img = real_img.to(device)
    fake_img = netG(input_img)

    # Update D
    optimizerD.zero_grad()
    d_out_real = netD(real_img)
    d_out_fake = netD(fake_img)
    d_loss = criterionD(d_out_real, d_out_fake)
    d_loss.backward(retain_graph=True)
    optimizerD.step()

    # Update G
    optimizerG.zero_grad()
    g_loss = criterionG(d_out_fake, real_img, fake_img)
    g_loss.backward()
    optimizerG.step()

実験

実験条件

ここでは,下記条件のもとでSRResNetとSRGANを学習し,比較を行います.

  • 倍率: 4
  • エポック数: 200
  • バッチサイズ: 16
  • 画像サイズ(HR): 96
  • 学習率:  10^{-4} (SRGANでは 10^{5}イテレーション毎に \times 0.1)
  • 誤差関数
    • SRResNet: MSE loss
    • SRGAN: MSE loss + VGG loss + Adversarial loss

学習データにはPASCAL VOC2012[*6]を用います. PASCAL VOC2012の画像総数は17125枚で,学習データに16700枚,バリデーションデータに425枚にランダムに振り分けます. なお,短辺96px以下のデータが含まれていたため,そちらはバリデーションデータに含めるようにしています. また,テストデータにはSet5[*7]を用います.

結果

200epoch学習したモデルをPSNR/SSIMを算出すると,SRResNetは29.74/0.8610,SRGANは25.41/0.7352となり,数値としてはSRResNetが圧倒的に優位です. それでは実際の出力画像がどうなっているか下図で確認してみましょう. 図中左から,低解像度画像をBicubic補間で拡大した画像,SRResNetの出力,SRGANの出力,高解像度画像となっています. さらにSRResNetとSRGANの比較のため,ズームした画像を右下に示してあります.

f:id:S_aiueo321:20190417160935p:plain
出力画像の比較 (画像はSet5のBaby)
見た目上,SRResNetとSRGANどちらも低解像度画像よりくっきりとしていることが確認でき,改善が見られます. ズームした画像の帽子部分を見ると,SRResNetにはのっぺりしたような印象がある一方で,SRGANには細かなざらつきが見られます. 感じ方は人それぞれですが,人間が目にする自然画像には定常的なノイズが乗っているため,後者の方が自然であると捉えられることが多いです.

実験結果から,GANの有無によって評価指標や見た目に一定の差が出ることがわかりました. 後に,この差についてThe Perception-Distortion Tradeoff[*8]という論文が発表され,評価指標と自然さにはトレードオフがあることが理論的に考察されています. 論文中で,これらのトレードオフはAdversarial lossの係数(今回は 10^{-3}に設定)によってコントロールできるとし,最適な動作点は応用先によって個別に決められるべきと主張されています.

まとめ

今回は,超解像における大きなブレークスルーであるSRGANについてまとめ,実験により超解像におけるGANの効果について確認しました. 現段階で業務で扱う超解像APIにGANの導入には至っていませんが,OCR精度に寄与する範囲を,トレードオフのコントロールをして検証してみる価値はあるかもしれません. とはいえ,SRGAN以後にもGANを用いないモデルは数多く発表されています. ちょうど来月頃にはCV系のトップ会議であるCVPRの論文にアクセスできるようになるため,次回は最新の動向についてまとめたいと思います.

参考文献

*1:Ledig, Christian, et al. "Photo-realistic single image super-resolution using a generative adversarial network." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.

*2:Blau, Yochai, et al. "The 2018 PIRM Challenge on Perceptual Image Super-Resolution." European Conference on Computer Vision. Springer, Cham, 2018.

*3:Wang, Xintao, et al. "ESRGAN: Enhanced super-resolution generative adversarial networks." European Conference on Computer Vision. Springer, Cham, 2018.

*4:He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

*5:Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).

*6:Everingham, Mark, et al. "The pascal visual object classes challenge: A retrospective." International journal of computer vision 111.1 (2015): 98-136.

*7:Bevilacqua, Marco, et al. "Low-complexity single-image super-resolution based on nonnegative neighbor embedding." (2012): 135-1.

*8:Blau, Yochai, and Tomer Michaeli. "The perception-distortion tradeoff." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.

© Sansan, Inc.