- リンクを取得
- ×
- メール
- 他のアプリ
画像言語モデル、画像を読み込ませると、画像のキャプションを考えてくれたり、画像についての質問に答えてくれたりします。
要約
- GeForce 2070 (VRAM 8GB) でも動作した(FP32 は NG。FP16、または INT8 を指定する必要あり)
- 1 タスクに 30 秒~60 秒はかかる
準備
いつもの Docker イメージで起動します。
docker run -it --gpus=all --rm -p 7860:7860 -v /work:/work nvidia/cuda:11.8.0-base-ubuntu22.04 /bin/bash
必要なツールをインストールします。
apt update apt install -y python3-pip pip install scipy ftfy regex tqdm gradio transformers sentencepiece 'accelerate>=0.12.0' 'bitsandbytes>=0.31.5'
Hugging Face へログインします。(事前にハッシュを生成しておきます)
huggingface-cli login
生成
利用例の Colab のコードを使います。
まずモデルをロードします。FP16 でも 20GB 程度のダウンロードが発生します。FP32 だとそれのほぼ倍くらいの容量のダウンロードが必要になります。
コード中の「load_in」の値を fp32、fp16、int8 に変更することで、ロードするモデルを指定できます。
コード中の「load_in」の値を fp32、fp16、int8 に変更することで、ロードするモデルを指定できます。
import torch from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoImageProcessor from PIL import Image import requests # helper function to format input prompts TASK2INSTRUCTION = { "caption": "画像を詳細に述べてください。", "tag": "与えられた単語を使って、画像を詳細に述べてください。", "vqa": "与えられた画像を下に、質問に答えてください。", } def build_prompt(task="caption", input=None, sep="\n\n### "): assert ( task in TASK2INSTRUCTION ), f"Please choose from {list(TASK2INSTRUCTION.keys())}" if task in ["tag", "vqa"]: assert input is not None, "Please fill in `input`!" if task == "tag" and isinstance(input, list): input = "、".join(input) else: assert input is None, f"`{task}` mode doesn't support to input questions" sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。" p = sys_msg roles = ["指示", "応答"] instruction = TASK2INSTRUCTION[task] msgs = [": \n" + instruction, ": \n"] if input: roles.insert(1, "入力") msgs.insert(1, ": \n" + input) for role, msg in zip(roles, msgs): p += sep + role + msg return p # load model device = "cuda" if torch.cuda.is_available() else "cpu" load_in = "fp16" # @param ["fp32", "fp16", "int8"] # @markdown If you use Colab free plan, please set `load_in` to `int8`. But, please remember that `int8` degrades the performance. In general, `fp32` is better than `fp16` and `fp16` is better than `int8`. model_kwargs = {"trust_remote_code": True, "low_cpu_mem_usage": True} if load_in == "fp16": model_kwargs["variant"] = "fp16" model_kwargs["torch_dtype"] = torch.float16 elif load_in == "int8": model_kwargs["variant"] = "fp16" model_kwargs["load_in_8bit"] = True model_kwargs["max_memory"] = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' model = AutoModelForVision2Seq.from_pretrained("stabilityai/japanese-stable-vlm", **model_kwargs) processor = AutoImageProcessor.from_pretrained("stabilityai/japanese-stable-vlm") tokenizer = AutoTokenizer.from_pretrained("stabilityai/japanese-stable-vlm") if load_in != "int8": model.to(device) model = model.eval()
関数を定義します。これ自体は一瞬で完了します。
@torch.inference_mode() def inference_fn( image, task, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method, num_return_sequences=3, ): prompt = build_prompt(task=task, input=prompt if not task == "caption" else None) print(f"instruction: {prompt}") inputs = processor(images=image, return_tensors="pt") text_encoding = tokenizer(prompt, add_special_tokens=False, return_tensors="pt") inputs.update(text_encoding) generation_kwargs = { "do_sample": decoding_method == "Nucleus sampling", "length_penalty": float(len_penalty), "repetition_penalty": float(repetition_penalty), "num_beams": beam_size, "max_new_tokens": max_len, "min_length": min_len, "top_p": top_p, "num_return_sequences": 3, } outputs = model.generate( **inputs.to(device, dtype=model.dtype), **generation_kwargs ) generated = [ txt.strip() for txt in tokenizer.batch_decode(outputs, skip_special_tokens=True) ] if num_return_sequences > 1: generated = "\n".join([f"{i}: {g}" for i, g in enumerate(generated)]) else: generated = generated[0] del inputs del outputs torch.cuda.empty_cache() return generated
gradio で WebUI を作ります。
import gradio as gr with gr.Blocks() as demo: gr.Markdown(f"# Japanese VLM Demo") gr.Markdown( """[Japanese Stable VLM](https://huggingface.co/stabilityai/japanese-stable-vlm) is a Japanese vision-language model by [Stability AI](https://ja.stability.ai/). - Blog: https://ja.stability.ai/blog/japanese-stable-vlm - Twitter: https://twitter.com/StabilityAI_JP - Discord: https://discord.com/invite/StableJP""" ) with gr.Row(): with gr.Column(): input_image = gr.Image(type="pil", label="image") task = gr.Radio( choices=list(TASK2INSTRUCTION.keys()), value=0, label="task" ) prompt = gr.Textbox(label="input", value="") with gr.Accordion(label="Configs", open=False): min_len = gr.Slider( minimum=1, maximum=50, value=1, step=1, interactive=True, label="Min Length", ) max_len = gr.Slider( minimum=10, maximum=100, value=65, step=5, interactive=True, label="Max New Tokens", ) sampling = gr.Radio( choices=["Beam search", "Nucleus sampling"], value="Beam search", label="Text Decoding Method", interactive=True, ) top_p = gr.Slider( minimum=0.5, maximum=1.0, value=0.9, step=0.1, interactive=True, label="Top p", ) beam_size = gr.Slider( minimum=1, maximum=10, value=5, step=1, interactive=True, label="Beam Size", ) len_penalty = gr.Slider( minimum=-1, maximum=2, value=1, step=0.2, interactive=True, label="Length Penalty", ) repetition_penalty = gr.Slider( minimum=-1, maximum=3, value=1.5, step=0.2, interactive=True, label="Repetition Penalty", ) num_return_sequences = gr.Number( value=3, label="Number of Outputs", precision=0 ) # button input_button = gr.Button(value="Submit") with gr.Column(): output = gr.Textbox(label="Output") inputs = [ input_image, task, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling, num_return_sequences, ] input_button.click(inference_fn, inputs=inputs, outputs=[output]) prompt.submit(inference_fn, inputs=inputs, outputs=[output]) img2txt_examples = gr.Examples( examples=[ [ "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", "caption", "", 1, 32, 5, 1.0, 1.5, 0.9, "Beam search", 1, ], [ "https://images.unsplash.com/photo-1589467397966-5e600cb93d3d?auto=format&fit=crop&q=60&w=900&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8MjB8fEphcGFuJTIwc3RyZWV0fGVufDB8fDB8fHww", "vqa", "道路に書かれた速度制限は?", 1, 32, 5, 1.0, 1.5, 0.9, "Beam search", 1, ], ], inputs=inputs, ) if __name__ == "__main__": demo.launch(share=True, debug=True, show_error=True)
これで以下のような WebUI が起動します。
ページ下部にサンプルがあるので、実行してみることが可能です。
まずはキャプションを作る方。ちゃんと画像の説明を生成できていますね。
VQA の方です。画像に書かれた制限速度の数字を正しく出力できていますね。
質問は自由に記述できます。入力画像も、質問文も、あまり複雑だとうまく回答できないことがありますが、それでも結構回答してくれるので面白いですね。
コメント
コメントを投稿