Sansan Tech Blog

Sansanのものづくりを支えるメンバーの技術やデザイン、プロダクトマネジメントの情報を発信

Flash Attention 2 + 量子化でVLLMはどこまで軽くなる?ローカル運用に向けた画像枚数とメモリ使用量の検証

はじめに

Sansan 技術本部 研究開発部の齋藤 慎一朗です。

最近、VLLM(Vision Large Language Model)やLLM(Large Language Model)をプロダクト応用できるかの検証、そのリリース関連の仕事をすることが増えています。 VLLMやLLMをローカル運用(ベンダーが提供するAPIを利用するのではなく、商用利用可能なモデルを自社の環境で運用すること)することを想定する場合、多くの不確実性が存在します。例えば、以下のような点が不確実となります。

  1. モデルが、プロダクトの価値につながるほど高い精度を出せるのか?
  2. 運用環境はどの程度の費用が必要なのか?
  3. 運用環境の費用は、モデルが生み出す価値と比較して高すぎないか?

もちろんPJや企業によるかと思いますが、弊社のPJについて、1は本格的な検証(例えば大量のデータを用いたFine-Tuningなど)を通して判断することが多いです。一方、簡単な検証により判断可能な不確実性については、事前に可能な限り不確実性を減らすことで、PJを成功させる確率を高めることができます。

本記事では、複数存在する不確実性のうち、「運用環境はどの程度の費用が必要なのか?」について、VLLMの利用を想定し、実際に私が簡易検証した方法を共有します。

運用環境の費用の見積もり

運用環境の費用を見積もりたい場合、主な要件はメモリとスループットになるかと思います。

  • どの程度のメモリがあれば、問題なく推論が行えるのか?
  • どの程度のスループットを担保すれば、運用面で問題ないか?(もしくは、スループットが重要な同期処理ではなく、非同期処理による運用が可能か?)

本記事では、メモリに関する検証方法とその結果を紹介します。

メモリの観点で運用費用がどの程度変わるのか?

例えば、必ず画像4枚をinputにして推論したいとします。4枚の画像をinputにしたモデルが安定して稼働できる環境を選び、その運用費用を計算する必要があります。

仮に、AWSのg5.xlarge(A10, GPU メモリ 24GB)を運用環境に選択する場合、本ブログ執筆時点では、オンデマンド料金/時間が1.006USDです。よって、インスタンス1つを1ヶ月常時運用した場合は、約724USD、つまり約10万円/月となります。

一方、p4d.24xlarge(A100 * 8, GPU メモリ 40 * 8 = 320GB)が必要な場合、同様の計算をすると、約338万円/月の運用費用がかかります。

どの程度運用費用を負担可能かは、VLLM、LLMを使ってどのようなプロダクト価値を生み出すかに依存します。

本記事の目的

利用モデル、想定の運用環境、出力したい最大トークン数が決まっている際に、VLLMに対して画像を何枚まで入力可能か検証します。 検証した結果、想定よりも入力可能な画像枚数が少なかった場合、モデルをより軽量にする、運用環境をよりハイメモリにする、などの意思決定を行えます。

条件

利用モデル

想定の運用環境

  • AWS EC2 g5.xlarge(A10 GPU メモリ 24GB)

  • 選定理由: FlashAttentionが利用可能であること、費用がそこまで高くないこと

利用する画像サイズ

  • 872 × 1242

常に1枚の同じ画像を利用します。複数枚inputする際には、同じ画像を複数回利用することになります。

プロンプト

続きを、可能な限り長く生成してください。昔々あるところに、

確実に512token出力されるようにするため、inputの画像に関わらず限界までテキストが生成されるプロンプトにします。 別のVLLMのモデルを試した際には、プロンプトの指示に従わず、すぐに出力をやめてしまったこともあったため、より良い方法はありそうです。

注意点として、今回は入力と出力が全く無関係の状況を想定して実験を行なっています。厳密に考えると、入力に応じて出力されるトークンが変化し、そのメモリ使用量も変化するかと考えられます。ただし、大きな違いはないと考え、今回の検証を行いました。

実験する枚数

  • 1 ~ 30枚

出力したい最大トークン数

512

実験パターン

Flash Attention 2はAttention機構における、入出力行列以外に追加で必要となるメモリ使用量を、 O(N2) から O(N)にすることができます。 また、量子化の手法はいろいろありますが、bnbを用いた場合、8bit量子化では、LLM.int8、4bit量子化では QLoRA における NF4が利用されています。

実験

実験に用いたコードは以下です。

from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)
from qwen_vl_utils import process_vision_info
import torch

MAX_TOKENS = 512


def prediction(model, processor, image_num: int):
    image_contents = [
        {
            "type": "image",
            "image": "for_blog/input/1.png",
        }
        for _ in range(image_num)
    ]

    image_contents.append(
        {
            "type": "text",
            "text": "続きを、可能な限り長く生成してください。昔々あるところに、",
        }
    )

    messages = [
        {
            "role": "user",
            "content": image_contents,
        }
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    try:
        print(f"{MAX_TOKENS=}")
        generated_ids = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        print(output_text)
    except torch.cuda.OutOfMemoryError:
        print(f"OOM at {image_num} pages\n")
    finally:
        print(f"Peak memory: {torch.cuda.max_memory_allocated() / (1024**3):.1f} GB")
        del inputs, generated_ids, generated_ids_trimmed
        torch.cuda.empty_cache()


if __name__ == "__main__":
    torch.cuda.reset_peak_memory_stats()

    # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
    )

    model_path = "Qwen/Qwen2-VL-7B-Instruct"
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=quantization_config,
        # attn_implementation="eager", # ベースライン
        attn_implementation="flash_attention_2", 
    )
    processor = AutoProcessor.from_pretrained(
        model_path,
    )

    for image_num in range(1, 30 + 1):
        print(f"{image_num=}")
        prediction(model, processor, image_num)

結果

まずは、設定ごとに、OOMせずに処理可能な画像の枚数を示します。

実験名 最大処理画像数
1: ベースライン 1枚
2: Flash Attention 2 10枚
3: 8bit量子化 1枚
4: 4bit量子化 1枚
5: 8bit量子化 + Flash Attention 2 20枚
6: 4bit量子化 + Flash Attention 2 28枚

次に、利用した GPUメモリをグラフとして表示します。24GBのGPUを利用しているため、最大は24GBとなります。

結果から、Flash Attention 2の効果が強力だと分かります。Flash Attention 2の導入により、画像2枚でOOMしていた構成が、10枚まで処理可能になりました。Flash Attention 2は、Attention機構のメモリ使用量を O(N2)から O(N) にするため、入力画像、つまりinputトークンが線形に増えていく今回の実験では効果が出やすいアルゴリズムであったと考えられます。

次に、量子化にも効果はありました。ただし、量子化のみだと、2 枚目でOOMしてしまいました。量子化はモデル自体の重みを小さくすることでメモリ節約はできますが、Self Attentionが O(N2) で増えていく影響の方が今回は大きかったことが分かります。

最後に、量子化 + Flash Attention 2を用いると、4bitにおいては画像を28枚、8bitにおいては画像を20枚入れてもOOMしませんでした。

よって、今回の検証を通して、「4bit量子化 + Flash Attention 2を用いた Qwen2-VL-7B-Instructを、AWS EC2 g5.xlargeで運用した場合、メモリの観点に限定すると、画像を28枚までinput可能である。」ことが分かりました。

一方、実際の運用時には、限界に対して少しメモリ的なマージンを持つ方が安全であると考えられます。

また、ある出力例は以下でした。文章は破綻していますが、限界までトークンが出力されていることが分かります。

昔々あるところに、一つの小さな村がありました。その村には、人々が暮らすために必要なすべての物が豊富に存在していました。しかし、村の中心には一つの大きな問題がありました。それは、村の水道が常に不足していたことです。\n\n村の水道は、村の周りに広がる森の一部を水源としていました。しかし、森の一部が開けられ、村の周りに建物が建てられると、水道の水量が減少しました。村の住民たちは、水道の水量が不足するのを心配していました。\n\nある日、村の長老たちは、村の水道の水量を増やすために森の一部を再開けることを提案しました。しかし、村の住民たちは、森の一部を再開けることで、森の生態系が破壊され、村の環境が悪化する可能性があると懸念していました。\n\n村の長老たちは、村の住民たちの懸念を理解しましたが、村の水道の水量を増やすために森の一部を再開けることを提案しました。村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じました。\n\n村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じました。村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じました。村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じました。村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じました。村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じました。村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じました。村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じました。村の住民たちは、森の一部を再開けることで、村の水道の水量を増やすことができると信じ

注意点

今回の検証結果は、あくまでメモリの観点での調査のみに留めています。画像をinputできても、スループットが低すぎて運用できない、8bit or 4bitでは性能が出ないなどの問題が発生する可能性があります。スループットが低すぎる場合、スループットを改善するためには並列化も手段として考えられますが、運用費用が増加する可能性があります。それらを全て考慮し、最終的なモデルを決定する必要があります。

終わりに

以上となります。VLLM、LLMの運用を考えるのは楽しいですね。

また、Sansan技術本部ではカジュアル面談を実施しています。

Sansan技術本部での働き方、仕事の魅力について、現役エンジニアの視点からお話しします。「実際に働く人の話を直接聞きたい」「どんな人が働いているのかを事前に知っておきたい」とお考えの方は、ぜひエントリーをご検討ください。

© Sansan, Inc.