Introduction

LLaVA is a large multimodal model. This post gives a rough introduction to it, based on the following code:

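```python
import os
# Optional: download through a Hugging Face mirror. This must be set before
# transformers / huggingface_hub are imported.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'  # noqa

import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, LlavaConfig

# Building the model from its config gives randomly initialized weights. That is
# enough to inspect the architecture and tensor shapes, but the generated text
# will be meaningless. For real answers, load the trained weights instead:
# model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
config = LlavaConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = LlavaForConditionalGeneration(config=config)

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt")

generate_ids = model.generate(**inputs, max_new_tokens=15)
output = processor.batch_decode(generate_ids, skip_special_tokens=True,
                                clean_up_tokenization_spaces=False)[0]
print(output)
```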


Model architecture

```text
LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=1024, out_features=4096, bias=True)
              (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (multi_modal_projector): LlavaMultiModalProjector(
    (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
    (act): GELUActivation()
    (linear_2): Linear(in_features=4096, out_features=4096, bias=True)
  )
  (language_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32064, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): LlamaRMSNorm()
        )
      )
      (norm): LlamaRMSNorm()
    )
    (lm_head): Linear(in_features=4096, out_features=32064, bias=False)
  )
)
```

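It consists of three parts:

  1. `vision_tower` (CLIPVisionModel) encodes the image into patch features with hidden size 1024.
  2. `language_model` (LlamaForCausalLM, i.e. Llama) handles the text: it takes the text embeddings together with the projected image features and generates the answer.
  3. `multi_modal_projector`, a two-layer MLP (Linear, GELU, Linear), projects the image hidden size (1024) to Llama's hidden size (4096).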
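To get a feel for how the three parts compare in size, the parameter counts can be worked out by hand from the shapes printed above. This is a sketch under a few assumptions not visible in the repr: `CLIPVisionEmbeddings` also holds a 1024-dim `class_embedding` parameter (the CLS token), LayerNorm has weight and bias, LlamaRMSNorm has only a weight, rotary embeddings have no trainable parameters, and `embed_tokens` and `lm_head` are not tied (as in Llama 2). The helper names are hypothetical.

```python
# Parameter counts derived from the module shapes printed above.
def linear(n_in: int, n_out: int, bias: bool) -> int:
    return n_in * n_out + (n_out if bias else 0)

def layer_norm(dim: int) -> int:
    return 2 * dim  # weight + bias

def rms_norm(dim: int) -> int:
    return dim  # weight only

# Vision tower: CLIP ViT, hidden 1024, 24 layers, 14x14 patches, 577 positions.
clip_layer = (4 * linear(1024, 1024, True)                          # q/k/v/out projections
              + linear(1024, 4096, True) + linear(4096, 1024, True)  # MLP fc1 + fc2
              + 2 * layer_norm(1024))                                # layer_norm1 + layer_norm2
vision = (3 * 1024 * 14 * 14      # patch_embedding Conv2d, no bias
          + 577 * 1024            # position_embedding
          + 1024                  # class_embedding (assumed; not shown in the repr)
          + 2 * layer_norm(1024)  # pre_layrnorm + post_layernorm
          + 24 * clip_layer)

# Projector: Linear(1024 -> 4096) + GELU + Linear(4096 -> 4096).
projector = linear(1024, 4096, True) + linear(4096, 4096, True)

# Language model: Llama, hidden 4096, 32 layers, vocab 32064, untied lm_head (assumed).
llama_layer = (4 * linear(4096, 4096, False)     # q/k/v/o projections
               + 3 * linear(4096, 11008, False)  # gate/up/down projections
               + 2 * rms_norm(4096))             # input + post-attention norms
language = (2 * 32064 * 4096       # embed_tokens + lm_head
            + 32 * llama_layer
            + rms_norm(4096))      # final norm

for name, n in [("vision_tower", vision), ("multi_modal_projector", projector),
                ("language_model", language), ("total", vision + projector + language)]:
    print(f"{name:>22}: {n:,}")
```

Under these assumptions the vision tower is about 0.30B parameters, the projector about 0.02B and Llama about 6.74B, roughly 7.06B in total, so almost all of the capacity sits in the language model.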

Data processing

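The image goes through the standard CLIP preprocessing: it is resized (shortest edge to 336), center-cropped to 336×336 and normalized, so each image's pixel_values has shape (3, 336, 336), or (1, 3, 336, 336) with the batch dimension returned by the processor. Nothing else is special. The text goes through the Llama tokenizer, which will be familiar.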
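A minimal numpy sketch of these preprocessing steps, to show where the (3, 336, 336) shape comes from. Assumptions: the input is an H×W×3 uint8 RGB array, the mean/std are the OpenAI CLIP normalization constants, and a nearest-neighbour resize stands in for the bicubic resize used by the real image processor. `clip_preprocess` is a hypothetical helper, not a transformers API.

```python
import numpy as np

# OpenAI CLIP normalization constants (RGB order).
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def clip_preprocess(img: np.ndarray, size: int = 336) -> np.ndarray:
    """Approximate CLIP preprocessing for an HxWx3 uint8 RGB array.

    Simplification: nearest-neighbour resize instead of bicubic.
    """
    h, w, _ = img.shape
    # 1. Resize so that the shortest edge equals `size`, keeping the aspect ratio.
    scale = size / min(h, w)
    new_h, new_w = max(size, round(h * scale)), max(size, round(w * scale))
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    resized = img[rows][:, cols]
    # 2. Center-crop to size x size.
    top, left = (new_h - size) // 2, (new_w - size) // 2
    cropped = resized[top:top + size, left:left + size]
    # 3. Rescale to [0, 1] and normalize each channel.
    x = cropped.astype(np.float32) / 255.0
    x = (x - CLIP_MEAN) / CLIP_STD
    # 4. HWC -> CHW, the layout expected by the Conv2d patch embedding.
    return x.transpose(2, 0, 1)

# Example: a 480x640 dummy image becomes (3, 336, 336).
dummy = np.zeros((480, 640, 3), dtype=np.uint8)
print(clip_preprocess(dummy).shape)  # (3, 336, 336)
```

The processor then stacks the per-image arrays, which adds the leading batch dimension.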

Aligning vision and text

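The image goes through a ViT whose patch_embedding is a Conv2d with kernel_size=14 and stride=14, so the number of patches works out as follows: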

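```python
num_patches_per_side = (336 - 14) // 14 + 1  # 24
num_patches = num_patches_per_side ** 2      # 576
# Counting CLIP's CLS token gives 577, matching position_embedding = Embedding(577, 1024).
```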

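CLIP's hidden size is 1024. LLaVA takes the features from the penultimate CLIP layer (vision_feature_layer=-2 in LlavaConfig) and drops the CLS token (vision_feature_select_strategy="default"), leaving a (1, 576, 1024) tensor. multi_modal_projector then maps it to (1, 576, 4096), which is also the shape of image_features below.
At this point image and text share the same hidden size, 4096.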
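The shape flow from CLIP hidden states to image_features can be sketched in numpy as follows. This is an illustration, not the transformers implementation: the CLIP hidden states and projector weights are random stand-ins, the weights are stored as (in, out) rather than torch.nn.Linear's (out, in), and GELU uses the tanh approximation for brevity while the real projector uses the exact erf-based GELU. `project_image_features` is a hypothetical helper.

```python
import math
import numpy as np

def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # tanh approximation of GELU (the real projector uses the exact erf form).
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def project_image_features(vision_hidden, w1, b1, w2, b2):
    """vision_hidden: (batch, 1 + num_patches, vision_dim), CLS token first.

    Returns (batch, num_patches, text_dim).
    """
    patches = vision_hidden[:, 1:, :]  # drop CLS: 577 -> 576 tokens
    h = gelu_tanh(patches @ w1 + b1)   # linear_1 + act: 1024 -> 4096
    return h @ w2 + b2                 # linear_2: 4096 -> 4096

rng = np.random.default_rng(0)
vision_dim, text_dim = 1024, 4096
# Random stand-in for the penultimate CLIP layer output of one image.
hidden = rng.standard_normal((1, 577, vision_dim), dtype=np.float32)
w1 = rng.standard_normal((vision_dim, text_dim), dtype=np.float32) * 0.02
b1 = np.zeros(text_dim, dtype=np.float32)
w2 = rng.standard_normal((text_dim, text_dim), dtype=np.float32) * 0.02
b2 = np.zeros(text_dim, dtype=np.float32)

image_features = project_image_features(hidden, w1, b1, w2, b2)
print(image_features.shape)  # (1, 576, 4096)
```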

Where the image is inserted

The original prompt is:

```python
prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
```

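`<image>` marks a specific position in the prompt, so "alignment" here takes on a second meaning: besides bringing image and text into the same dimension, the image features must also be placed where the `<image>` token sits.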

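Looking at the `_merge_input_ids_with_image_features` function in the figure above, together with the information shown there, it is easy to reach the following conclusion: the text tokens before `<image>` keep positions 0 to 4, and the 576 image features fill positions 5 to 580 of final_embedding: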

```python
(final_embedding[:, 5:576+5, :] == image_features[:, :, :]).all()
# > tensor(True)
```
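The core idea of this merge, replacing the single `<image>` token embedding with the whole sequence of image features, can be sketched in numpy. Assumptions: one sequence, exactly one `<image>` token, no padding; the real `_merge_input_ids_with_image_features` also handles batches, padding, attention masks and labels. `merge_image_features` is a hypothetical helper, 32000 is the default `image_token_index` in LlavaConfig, and the other token ids are arbitrary placeholders.

```python
import numpy as np

def merge_image_features(input_ids, text_embeds, image_features, image_token_id=32000):
    """Splice image features in place of the single <image> token.

    input_ids: (seq_len,), text_embeds: (seq_len, dim), image_features: (num_patches, dim).
    Returns (seq_len - 1 + num_patches, dim).
    """
    positions = np.flatnonzero(input_ids == image_token_id)
    assert positions.size == 1, "this sketch handles exactly one <image> token"
    pos = positions[0]
    return np.concatenate([text_embeds[:pos], image_features, text_embeds[pos + 1:]], axis=0)

# Toy example with <image> at index 5, as in the prompt above.
rng = np.random.default_rng(0)
dim, num_patches = 8, 576
input_ids = np.array([1, 11, 12, 13, 14, 32000, 15, 16, 17])  # placeholder ids except 32000
text_embeds = rng.standard_normal((len(input_ids), dim))
image_features = rng.standard_normal((num_patches, dim))

final_embedding = merge_image_features(input_ids, text_embeds, image_features)
print(final_embedding.shape)  # (584, 8)
print(np.array_equal(final_embedding[5:5 + num_patches], image_features))  # True
```

The merged sequence grows by num_patches - 1 tokens, which is why the attention mask (and labels, during training) must be expanded in the same way.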

Since the downstream target tasks are things like VQA and image captioning, I'll stop here for now.