Japanese Stable CLIP を試してみた

Japanese Stable VLM を試してみた


画像言語モデル、画像を読み込ませると、画像のキャプションを考えてくれたり、画像についての質問に答えてくれたりします。

要約

  • GeForce 2070 (VRAM 8GB) でも動作した(FP32 は NG。FP16、または INT8 を指定する必要あり)
  • 1 タスクに 30 秒~60 秒はかかる

準備

いつもの Docker イメージで起動します。
docker run -it --gpus=all --rm -p 7860:7860 -v /work:/work nvidia/cuda:11.8.0-base-ubuntu22.04 /bin/bash

必要なツールをインストールします。
apt update
apt install -y python3-pip
pip install scipy ftfy regex tqdm gradio transformers sentencepiece 'accelerate>=0.12.0' 'bitsandbytes>=0.31.5'

Hugging Face へログインします。(事前にハッシュを生成しておきます)
huggingface-cli login

生成

利用例の Colab のコードを使います。
まずモデルをロードします。FP16 でも 20GB 程度のダウンロードが発生します。FP32 だとそれのほぼ倍くらいの容量のダウンロードが必要になります。
コード中の「load_in」の値を fp32、fp16、int8 に変更することで、ロードするモデルを指定できます。
import torch
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoImageProcessor
from PIL import Image
import requests
# helper function to format input prompts
TASK2INSTRUCTION = {
    "caption": "画像を詳細に述べてください。",
    "tag": "与えられた単語を使って、画像を詳細に述べてください。",
    "vqa": "与えられた画像を下に、質問に答えてください。",
}

def build_prompt(task="caption", input=None, sep="\n\n### "):
    assert (
        task in TASK2INSTRUCTION
    ), f"Please choose from {list(TASK2INSTRUCTION.keys())}"
    if task in ["tag", "vqa"]:
        assert input is not None, "Please fill in `input`!"
        if task == "tag" and isinstance(input, list):
            input = "、".join(input)
    else:
        assert input is None, f"`{task}` mode doesn't support to input questions"
    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    p = sys_msg
    roles = ["指示", "応答"]
    instruction = TASK2INSTRUCTION[task]
    msgs = [": \n" + instruction, ": \n"]
    if input:
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + input)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p

# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
load_in = "fp16" # @param ["fp32", "fp16", "int8"]
# @markdown If you use Colab free plan, please set `load_in` to `int8`. But, please remember that `int8` degrades the performance. In general, `fp32` is better than `fp16` and `fp16` is better than `int8`.
model_kwargs = {"trust_remote_code": True, "low_cpu_mem_usage": True}
if load_in == "fp16":
  model_kwargs["variant"] = "fp16"
  model_kwargs["torch_dtype"] = torch.float16
elif load_in == "int8":
  model_kwargs["variant"] = "fp16"
  model_kwargs["load_in_8bit"] = True
  model_kwargs["max_memory"] = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
model = AutoModelForVision2Seq.from_pretrained("stabilityai/japanese-stable-vlm", **model_kwargs)
processor = AutoImageProcessor.from_pretrained("stabilityai/japanese-stable-vlm")
tokenizer = AutoTokenizer.from_pretrained("stabilityai/japanese-stable-vlm")
if load_in != "int8":
  model.to(device)
model = model.eval()

関数を定義します。これ自体は一瞬で完了します。
@torch.inference_mode()
def inference_fn(
    image,
    task,
    prompt,
    min_len,
    max_len,
    beam_size,
    len_penalty,
    repetition_penalty,
    top_p,
    decoding_method,
    num_return_sequences=3,
):
    prompt = build_prompt(task=task, input=prompt if not task == "caption" else None)
    print(f"instruction: {prompt}")
    inputs = processor(images=image, return_tensors="pt")
    text_encoding = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs.update(text_encoding)
    generation_kwargs = {
        "do_sample": decoding_method == "Nucleus sampling",
        "length_penalty": float(len_penalty),
        "repetition_penalty": float(repetition_penalty),
        "num_beams": beam_size,
        "max_new_tokens": max_len,
        "min_length": min_len,
        "top_p": top_p,
        "num_return_sequences": 3,
    }
    outputs = model.generate(
        **inputs.to(device, dtype=model.dtype), **generation_kwargs
    )
    generated = [
        txt.strip() for txt in tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ]
    if num_return_sequences > 1:
        generated = "\n".join([f"{i}: {g}" for i, g in enumerate(generated)])
    else:
        generated = generated[0]
    del inputs
    del outputs
    torch.cuda.empty_cache()
    return generated

gradio で WebUI を作ります。
import gradio as gr
with gr.Blocks() as demo:
    gr.Markdown(f"# Japanese VLM Demo")
    gr.Markdown(
        """[Japanese Stable VLM](https://huggingface.co/stabilityai/japanese-stable-vlm) is a Japanese vision-language model by [Stability AI](https://ja.stability.ai/).
                - Blog: https://ja.stability.ai/blog/japanese-stable-vlm
                - Twitter: https://twitter.com/StabilityAI_JP
                - Discord: https://discord.com/invite/StableJP"""
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="image")
            task = gr.Radio(
                choices=list(TASK2INSTRUCTION.keys()), value=0, label="task"
            )
            prompt = gr.Textbox(label="input", value="")
            with gr.Accordion(label="Configs", open=False):
                min_len = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=1,
                    step=1,
                    interactive=True,
                    label="Min Length",
                )
                max_len = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=65,
                    step=5,
                    interactive=True,
                    label="Max New Tokens",
                )
                sampling = gr.Radio(
                    choices=["Beam search", "Nucleus sampling"],
                    value="Beam search",
                    label="Text Decoding Method",
                    interactive=True,
                )
                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    interactive=True,
                    label="Top p",
                )
                beam_size = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    interactive=True,
                    label="Beam Size",
                )
                len_penalty = gr.Slider(
                    minimum=-1,
                    maximum=2,
                    value=1,
                    step=0.2,
                    interactive=True,
                    label="Length Penalty",
                )
                repetition_penalty = gr.Slider(
                    minimum=-1,
                    maximum=3,
                    value=1.5,
                    step=0.2,
                    interactive=True,
                    label="Repetition Penalty",
                )
                num_return_sequences = gr.Number(
                    value=3, label="Number of Outputs", precision=0
                )
            # button
            input_button = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Output")
    inputs = [
        input_image,
        task,
        prompt,
        min_len,
        max_len,
        beam_size,
        len_penalty,
        repetition_penalty,
        top_p,
        sampling,
        num_return_sequences,
    ]
    input_button.click(inference_fn, inputs=inputs, outputs=[output])
    prompt.submit(inference_fn, inputs=inputs, outputs=[output])
    img2txt_examples = gr.Examples(
        examples=[
            [
                "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
                "caption",
                "",
                1,
                32,
                5,
                1.0,
                1.5,
                0.9,
                "Beam search",
                1,
            ],
            [
                "https://images.unsplash.com/photo-1589467397966-5e600cb93d3d?auto=format&fit=crop&q=60&w=900&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8MjB8fEphcGFuJTIwc3RyZWV0fGVufDB8fDB8fHww",
                "vqa",
                "道路に書かれた速度制限は?",
                1,
                32,
                5,
                1.0,
                1.5,
                0.9,
                "Beam search",
                1,
            ],
        ],
        inputs=inputs,
    )



if __name__ == "__main__":
    demo.launch(share=True, debug=True, show_error=True)

これで以下のような WebUI が起動します。

ページ下部にサンプルがあるので、実行してみることが可能です。
まずはキャプションを作る方。ちゃんと画像の説明を生成できていますね。

VQA の方です。画像に書かれた制限速度の数字を正しく出力できていますね。

質問は自由に記述できます。入力画像も、質問文も、あまり複雑だとうまく回答できないことがありますが、それでも結構回答してくれるので面白いですね。


コメント