最近面试了一北京候选者,之前使用电子病例以及CT图像两种模态信息,训练ViLT多模态预测模型,提高肺结节良恶性预测准确率。
正好我对多模态如何对齐也比较感兴趣,以Transformers-Tutorials提供的代码为例,来看下其内部是如何实现的。
数据集我没有从VQA下载,太大了,这里也强烈安利huggingface提供的lmms-lab/VQAv2 dataset。

剩下就改下VQADataset部分,其他保持不变。

这里记录比较有趣的几个点。

1. text和image如何对齐?

答案:在第二维对齐。

具体来说,text部分使用的是BertTokenizer,max_position为40(所以如果有长文本,这里就坑了),假设batch_size为4,text embedding出来后就是(4, 40, 768)。
image部分使用的也是ViT,所以整体思想和CLIP是一样的,得到patch后flatten,然后concat到一起即可, 即embeddings = torch.cat([text_embeds, image_embeds], dim=1)

相比CLIP,不同之处在于支持不同宽高的image,CLIP默认只支持(224, 224)宽高的图像。所以ViLT多了个visual pad操作,另外kernel size也不一样。如果深究的话,可以看看visual_embed部分。

2. 竟然有ViltForTokenClassification

是的,上面既然已经对齐了,那么前面都是text部分,所以做文本相关例如token classification等这部分也是没问题的。

3. ViltForQuestionAnswering是QA吗?

先入为主来讲,这个A呢,是不固定的,比如使用generation来生成。但是这里呢,是有限的!!!。

怎么说呢,人家提供了个类似[CLS],然后你可以指定自己的类别,所以他是个分类器。

这里作者使用了finetuned的label来作为演示,可以看看”dandelin/vilt-b32-finetuned-vqa”下config.json中的label2id查看标签数量。

1
2
3
4

from transformers import ViltConfig

config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

4. 排序新操作

正如作者所说的那样,VQAv2是个多标签分类数据集,什么意思呢,例如下图。

它的question:where is he looking?

它的标注结果annotations:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

{'question_type': 'none of the above',
'multiple_choice_answer': 'down',
'answers': [{'answer': 'down', 'answer_confidence': 'yes', 'answer_id': 1},
{'answer': 'down', 'answer_confidence': 'yes', 'answer_id': 2},
{'answer': 'at table', 'answer_confidence': 'yes', 'answer_id': 3},
{'answer': 'skateboard', 'answer_confidence': 'yes', 'answer_id': 4},
{'answer': 'down', 'answer_confidence': 'yes', 'answer_id': 5},
{'answer': 'table', 'answer_confidence': 'yes', 'answer_id': 6},
{'answer': 'down', 'answer_confidence': 'yes', 'answer_id': 7},
{'answer': 'down', 'answer_confidence': 'yes', 'answer_id': 8},
{'answer': 'down', 'answer_confidence': 'yes', 'answer_id': 9},
{'answer': 'down', 'answer_confidence': 'yes', 'answer_id': 10}],
'image_id': 262148,
'answer_type': 'other',
'question_id': 262148000}

answer是有多个的,也就是说它的answers是个众包的,有的人认为是down,有的人认为是table。

4.1 作者实现思路

1、处理answer

例如上面标注结果,首先统计answer次数,得到下面结果:

1
answers_count = {'down': 7, 'at table': 1, 'skateboard': 1, 'table': 1}

接着对每个answer count使用get_score进行平滑处理,get_score有三种可能,那每种可能就认为是每种答案的权重。

1
2
3
4
def get_score(count: int) -> float:
return min(1.0, count / 3)

# [0.3333333333333333, 0.6666666666666666, 1.0]
2、ViltForQuestionAnswering

在ViLT出来后,接了个分类器,这个分类器的标签数量来自自己实际业务指定,比如answers的总数量。假设一共有10个类别,down类别是1,at table是2,依次类推,得到label如下:

1
label = [1.0, 0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0, 0, 0, 0, 0, 0]

接着计算loss如下:

1
2
loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]

这里又乘了个labels总数量,可以简单理解成loss scale。

3、解码

作者使用topk进行解码:

1
2
probs, classes = torch.topk(predicted_classes, 5)

至此,大致理解了整体流程。

这里是get到一种新的实现方式了哈,那接下来我想对其扩展下,看看还有哪些方式可以做这种。

4.2 其他排序思路

1、rank loss。排序中常用到。
2、目前大模型中的打分模型(score model),在最开始我没看其实现的时候,我曾深深纠结于这又是什么新的技术,等我看完后,我感觉更应该关注的是“打分”俩字,即强化学习思想中的critic model。本质技术层面没有太大变化,作为打分模型为Actor提供优化方向。
3、DPO。对,你没看错,强化学习算法DPO,本质思想就在于好的要逐渐和坏的拉开差距,使其更符合人类偏好。

看的越多,会发现还是有共通之处的。