- リンクを取得
- ×
- メール
- 他のアプリ
画像言語モデル、画像を読み込ませると、画像のキャプションを考えてくれたり、画像についての質問に答えてくれたりします。
要約
- GeForce 2070 (VRAM 8GB) でも動作した(FP32 は NG。FP16、または INT8 を指定する必要あり)
- 1 タスクに 30 秒~60 秒はかかる
準備
いつもの Docker イメージで起動します。
docker run -it --gpus=all --rm -p 7860:7860 -v /work:/work nvidia/cuda:11.8.0-base-ubuntu22.04 /bin/bash
必要なツールをインストールします。
apt update apt install -y python3-pip pip install scipy ftfy regex tqdm gradio transformers sentencepiece 'accelerate>=0.12.0' 'bitsandbytes>=0.31.5'
Hugging Face へログインします。(事前にハッシュを生成しておきます)
huggingface-cli login
生成
利用例の Colab のコードを使います。
まずモデルをロードします。FP16 でも 20GB 程度のダウンロードが発生します。FP32 だとそれのほぼ倍くらいの容量のダウンロードが必要になります。
コード中の「load_in」の値を fp32、fp16、int8 に変更することで、ロードするモデルを指定できます。
コード中の「load_in」の値を fp32、fp16、int8 に変更することで、ロードするモデルを指定できます。
import torch
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoImageProcessor
from PIL import Image
import requests
# helper function to format input prompts
TASK2INSTRUCTION = {
"caption": "画像を詳細に述べてください。",
"tag": "与えられた単語を使って、画像を詳細に述べてください。",
"vqa": "与えられた画像を下に、質問に答えてください。",
}
def build_prompt(task="caption", input=None, sep="\n\n### "):
assert (
task in TASK2INSTRUCTION
), f"Please choose from {list(TASK2INSTRUCTION.keys())}"
if task in ["tag", "vqa"]:
assert input is not None, "Please fill in `input`!"
if task == "tag" and isinstance(input, list):
input = "、".join(input)
else:
assert input is None, f"`{task}` mode doesn't support to input questions"
sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
p = sys_msg
roles = ["指示", "応答"]
instruction = TASK2INSTRUCTION[task]
msgs = [": \n" + instruction, ": \n"]
if input:
roles.insert(1, "入力")
msgs.insert(1, ": \n" + input)
for role, msg in zip(roles, msgs):
p += sep + role + msg
return p
# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
load_in = "fp16" # @param ["fp32", "fp16", "int8"]
# @markdown If you use Colab free plan, please set `load_in` to `int8`. But, please remember that `int8` degrades the performance. In general, `fp32` is better than `fp16` and `fp16` is better than `int8`.
model_kwargs = {"trust_remote_code": True, "low_cpu_mem_usage": True}
if load_in == "fp16":
model_kwargs["variant"] = "fp16"
model_kwargs["torch_dtype"] = torch.float16
elif load_in == "int8":
model_kwargs["variant"] = "fp16"
model_kwargs["load_in_8bit"] = True
model_kwargs["max_memory"] = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
model = AutoModelForVision2Seq.from_pretrained("stabilityai/japanese-stable-vlm", **model_kwargs)
processor = AutoImageProcessor.from_pretrained("stabilityai/japanese-stable-vlm")
tokenizer = AutoTokenizer.from_pretrained("stabilityai/japanese-stable-vlm")
if load_in != "int8":
model.to(device)
model = model.eval()
関数を定義します。これ自体は一瞬で完了します。
@torch.inference_mode()
def inference_fn(
image,
task,
prompt,
min_len,
max_len,
beam_size,
len_penalty,
repetition_penalty,
top_p,
decoding_method,
num_return_sequences=3,
):
prompt = build_prompt(task=task, input=prompt if not task == "caption" else None)
print(f"instruction: {prompt}")
inputs = processor(images=image, return_tensors="pt")
text_encoding = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
inputs.update(text_encoding)
generation_kwargs = {
"do_sample": decoding_method == "Nucleus sampling",
"length_penalty": float(len_penalty),
"repetition_penalty": float(repetition_penalty),
"num_beams": beam_size,
"max_new_tokens": max_len,
"min_length": min_len,
"top_p": top_p,
"num_return_sequences": 3,
}
outputs = model.generate(
**inputs.to(device, dtype=model.dtype), **generation_kwargs
)
generated = [
txt.strip() for txt in tokenizer.batch_decode(outputs, skip_special_tokens=True)
]
if num_return_sequences > 1:
generated = "\n".join([f"{i}: {g}" for i, g in enumerate(generated)])
else:
generated = generated[0]
del inputs
del outputs
torch.cuda.empty_cache()
return generated
gradio で WebUI を作ります。
import gradio as gr
with gr.Blocks() as demo:
gr.Markdown(f"# Japanese VLM Demo")
gr.Markdown(
"""[Japanese Stable VLM](https://huggingface.co/stabilityai/japanese-stable-vlm) is a Japanese vision-language model by [Stability AI](https://ja.stability.ai/).
- Blog: https://ja.stability.ai/blog/japanese-stable-vlm
- Twitter: https://twitter.com/StabilityAI_JP
- Discord: https://discord.com/invite/StableJP"""
)
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="image")
task = gr.Radio(
choices=list(TASK2INSTRUCTION.keys()), value=0, label="task"
)
prompt = gr.Textbox(label="input", value="")
with gr.Accordion(label="Configs", open=False):
min_len = gr.Slider(
minimum=1,
maximum=50,
value=1,
step=1,
interactive=True,
label="Min Length",
)
max_len = gr.Slider(
minimum=10,
maximum=100,
value=65,
step=5,
interactive=True,
label="Max New Tokens",
)
sampling = gr.Radio(
choices=["Beam search", "Nucleus sampling"],
value="Beam search",
label="Text Decoding Method",
interactive=True,
)
top_p = gr.Slider(
minimum=0.5,
maximum=1.0,
value=0.9,
step=0.1,
interactive=True,
label="Top p",
)
beam_size = gr.Slider(
minimum=1,
maximum=10,
value=5,
step=1,
interactive=True,
label="Beam Size",
)
len_penalty = gr.Slider(
minimum=-1,
maximum=2,
value=1,
step=0.2,
interactive=True,
label="Length Penalty",
)
repetition_penalty = gr.Slider(
minimum=-1,
maximum=3,
value=1.5,
step=0.2,
interactive=True,
label="Repetition Penalty",
)
num_return_sequences = gr.Number(
value=3, label="Number of Outputs", precision=0
)
# button
input_button = gr.Button(value="Submit")
with gr.Column():
output = gr.Textbox(label="Output")
inputs = [
input_image,
task,
prompt,
min_len,
max_len,
beam_size,
len_penalty,
repetition_penalty,
top_p,
sampling,
num_return_sequences,
]
input_button.click(inference_fn, inputs=inputs, outputs=[output])
prompt.submit(inference_fn, inputs=inputs, outputs=[output])
img2txt_examples = gr.Examples(
examples=[
[
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
"caption",
"",
1,
32,
5,
1.0,
1.5,
0.9,
"Beam search",
1,
],
[
"https://images.unsplash.com/photo-1589467397966-5e600cb93d3d?auto=format&fit=crop&q=60&w=900&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8MjB8fEphcGFuJTIwc3RyZWV0fGVufDB8fDB8fHww",
"vqa",
"道路に書かれた速度制限は?",
1,
32,
5,
1.0,
1.5,
0.9,
"Beam search",
1,
],
],
inputs=inputs,
)
if __name__ == "__main__":
demo.launch(share=True, debug=True, show_error=True)
これで以下のような WebUI が起動します。
ページ下部にサンプルがあるので、実行してみることが可能です。
まずはキャプションを作る方。ちゃんと画像の説明を生成できていますね。
VQA の方です。画像に書かれた制限速度の数字を正しく出力できていますね。
質問は自由に記述できます。入力画像も、質問文も、あまり複雑だとうまく回答できないことがありますが、それでも結構回答してくれるので面白いですね。
.png)
.png)
.png)
.png)
コメント
コメントを投稿