Sansan Tech Blog

Sansanのものづくりを支えるメンバーの技術やデザイン、プロダクトマネジメントの情報を発信

【Zoom or Die】第3回 torchvisionのI/O・前処理が新しくなった話

f:id:S_aiueo321:20200529093604p:plain

こんにちは,DSOC研究開発部 Automation Groupの内田です. 普段オフィスではスタンディングデスクと曲面ディスプレイという環境で作業をしているのですが,秋は設備投資の季節ということで,一念発起して自宅にも曲面ディスプレイを導入しました.ディスプレイの広さは心の余裕ということで,QOLが爆上がりしています.皆さんも導入を検討してみてはいかがでしょうか?

www.amazon.co.jp

宣伝はさておき,今回は PyTorch のエコシステムである torchvision が少し進化した話をしたいと思います.

torchvision

近年の深層学習の隆盛は,簡単に深層モデルを実装できるフレームワークや周辺のエコシステムが整備されたことに起因していると言っても過言ではないでしょう.その中で,TensorFlowとPyTorchは深層学習フレームワークの二大巨塔と目されています.PyTorchは後発でありながら,その直感的な仕様から,近年TensorFlowを追い落とす勢いでシェアを獲得しています.Computer Vision (以下CV) 界隈では,トップ会議に採択された論文の半分以上が実装にPyTorchを利用しているということも話題になっていました.

CV分野においてPyTorchが人気を博す理由の1つとして,torchvisionの存在があります.torchvisionは,画像の前処理や学習済みモデルなどを提供するエコシステムです.画像の前処理に関しては,次のように書けます.

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class OleOleDataset(Dataset):
    def __init__(self, data_dir: str) -> None:
        self.filenames = list(Path(data_dir).glob('*.png'))
        self.transforms = transforms.Compose([  # 前処置を積んでく
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def __getitem__(self, index: int) -> torch.Tensor:
        filename = self.filenames[index]
        img = Image.open(filename).convert('RGB')
        return self.transforms(img)

    def __len__(self) -> int:
        return len(self.filenames)


dataset = OleOleDataset(data_dir='./foo')
dataloader = DataLoader(dataset)

ここでは,PIL (Python Image Library)によって一度画像を読み込み,PIL上で前処理を行い,最後にテンソルに変換する流れになっています. 一般的にPILはそこまで速度が速いわけではなく,画像を高速に読み込めるライブラリを忍ばせたり,魔改造が割と普通に行われてきました.


そんな一長一短な感もあるtorchvisionでしたが,つい先日,PyTorch v1.7のリリースに伴ってv0.8がリリースされました 🎉

PyTorch 1.7 released w/ CUDA 11, New APIs for FFTs, Windows support for Distributed training and more | PyTorch

主な変更点として,

  1. torchvision.transforms の変換に対するテンソルの入力,バッチ処理,GPU利用,TorchScript化に対応
  2. JPEG/PNGの入出力をネイティブサポート

という期待できそうな文言が見られます.しかもStable.これは試すしかない.

実験

データのI/O・前処理について,次の3つの観点で処理時間の比較を行います.

  1. 画像ファイルからテンソルへの変換
  2. バイト列からテンソルへの変換
  3. GPU上のテンソルへの変換 (前処理込)

全ての実験に共通して,得られるテンソルのデータ型は Float32 であり,[0, 1] に正規化されたRGB値を持っているとします. 実験に用いた画像は baboon で,サイズは512x512,拡張子はPNGです.また,実行環境は次の通りです.

  • インスタンスタイプ: g4dn.xlarge
  • OS: Ubuntu 18.04
  • Python: 3.8.6
  • PyTorch/torchvision: 1.7/0.8

画像ファイルからテンソルへの変換

torchvison 0.8では,torchvision.io.read_image というメソッドが用意されており,JPEGかPNGのファイル名を指定すると直接テンソルを返却します.ここでは,画像読み込みを100回繰り返し,その合計時間をPIL/OpenCVと比較します.

以下に実験に用いたコードと結果を示します.

[PIL] done in 928.280 ms
[OpenCV] done in 901.126 ms
[torchvision] done in 761.186 ms

torchvision が PIL/OpenCV に比べ高速に動作していることがわかります. PIL/OpenCVでは一度データを読み込んでからテンソルへの変換するのに対し,torchvision では直接テンソルへの変換しているためボトルネックが少ないと考えられます *1

バイト列からテンソルへの変換

API上でPyTorchに推論させる場合,バイト列として送り付けられる画像をテンソルに変換する必要があります. 一般的には PIL や NumPy を介して変換しますが,今回のアップデートで decode_image というメソッドが追加されたました. バイト列を直接テンソルに変換できれば,依存ライブラリを減らす意味合いで有用に思えるので,ここでは速度と合わせて検証します. 条件はファイルからの変換とほぼ同様です.

[PIL] done in 949.447 ms
[OpenCV] done in 711.606 ms
[torchvision] done in 1587.219 ms

先ほどとは打って変わり,torchvision では倍近くの時間を要してしまいました. decode_image の入力がテンソルである必要があり,Python組み込みの bytes からの変換に時間がかかることが原因と考えています. 処理速度はともかくとして,decode_image の入力として bytes も扱えるようにした方が直感的ですし,改善されることを願いたいです.

GPUテンソルへの変換 (前処理込)

最後に,ファイルからGPUテンソルへの変換を,前処理も合わせて比較します. ここでは torchvision.transforms のAPIを共通化して比較したいため,PILベースのデータセットとTorchベースのデータセットを比較します. PILベースの方はデータを1枚ずつ読込・前処理をCPU上で行っていく方式です. Torchベースの方は先ほどの read_image で読込だけCPU上で済ませ,以降の前処理はGPU上でバッチ処理していきます. どちらも同じ画像をひたすら読んでくる仕様になっていて,1024枚をバッチサイズ64で取得するため,計16個のバッチを取得します. 各クラスの主なコードの相違点は,T.Composenn.Sequential に,T.ToTensorT.ConvertImageDtype に置き換わっていることです.

[PIL-based] done in 5235.337 ms
[Torch-based] done in 2294.152 ms

流石にGPUを使っているだけあって,2倍以上高速化されています. 前処理をGPU上で行う考え自体,同じくPyTorchのエコシステムである Kornia などで以前から実装されていましたが,torchvisionに組み込まれたことで信頼性が増しました. 個人的には,同じ torchvision.transforms のAPIから,T.Composenn.Sequential を切り替えるだけでGPU実行ができるのは感動しました. 以前は,torchvisionでできることとKorniaでできることが微妙に違ったりしてコードを共通化できない問題がありましたが,このアップデートで互いにすっきりする部分が大きいのでないでしょうか.

まとめ

今回は torchvision 0.8 で追加された画像の入出力および前処理機能についてまとめ,パフォーマンスの比較実験を行いました. 結果からは,学習時におけるI/O簡素化と前処理の高速化については効果が大きいことが分かりました. 対して,バイト列のデコードが低速なこともあり,今回追加された機能を業務利用するにはまだ尚早と感じました. 学習時のみ利用するという選択肢もありますが,学習コードとAPIでデコード方法が変わることになるので,デコードの精度も追加で比較する必要がありそうです.

DSOCにおける名刺データ化では,Frugality*2に常に向き合うこととなります. その中で,パフォーマンスの改善は非常に重要なファクターとなります*3. 今後もこのようなアップデートには常にアンテナを張り,サービスを改善していけたらと思います.


最後に,Sansan DSOCでは最新の技術にキャッチアップしながら,名刺のデータ化および活用を支える仲間を募集しています. もし少しでもご興味をお持ちいただけましたら,下記募集などからカジュアルにご連絡いただければと思います.

hrmos.co

hrmos.co

▼これまでの記事

buildersbox.corp-sansan.com

buildersbox.corp-sansan.com

*1:内部的なところは追えていませんが,画像読み込みだけを比較した場合,どのライブラリも同等の性能を示したため,torchvisionでは1つ処理が少ないイメージを持っています.

*2:英語で倹約,質素を意味します.

*3:なのでデコードが遅いのは少しショックでした😭

© Sansan, Inc.