Introduction

This post opens a new topic: prompt learning. Prompt learning is commonly used in zero-shot and few-shot settings. The idea is to convert downstream training data into the pre-trained model's own task format, and then either fine-tune with a small number of samples or predict directly in a zero-shot fashion.

Prompts already came up in the earlier UIE event extraction post, but UIE's approach is to construct a prompt input and fine-tune the model on it together with the original sentence to perform extraction. Here, let's look at how OpenPrompt does it.

I won't introduce OpenPrompt itself in detail here; please refer to the official site.

The official docs include a diagram illustrating the OpenPrompt architecture, shown below.

Below, I'll walk through the official example to explain how it works.

Main Modules

1. Template

Search for prompt learning online and you will inevitably run into templates. A template converts downstream training data into data that matches the pre-trained model's task, keeping it consistent with the pre-training objective.
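
To make this concrete, here is a minimal sketch in plain Python (illustrative only, not the OpenPrompt API) of what a template does for an MLM-style PLM: it wraps the raw input into a cloze sentence with a [MASK] slot, matching the pre-training task.

# Minimal sketch of what a template does (not the OpenPrompt API):
# wrap the raw input into the pre-training task format, here an MLM cloze.
def apply_template(text_a: str) -> str:
    # hypothetical template: '{"placeholder":"text_a"} It was {"mask"}'
    return f"{text_a} It was [MASK]"

print(apply_template("The film was badly made."))
# The film was badly made. It was [MASK]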

2. Verbalizer

How should this be translated? Let's tentatively call it the label-word mapping. Its job is to pick words from the pre-trained model's vocab that stand for each label, for example:

{
    "negative": ["bad"],
    "positive": ["good", "wonderful", "great"],
}
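
As a quick sketch of how these label words are used later, each one is mapped to its id in the PLM's vocab (the ids below assume the bert-base-cased tokenizer from the appendix):

from transformers import BertTokenizer

# Sketch: map each label word to its id in the PLM vocab.
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
label_words = {
    "negative": ["bad"],
    "positive": ["good", "wonderful", "great"],
}
label_word_ids = {label: [tokenizer.vocab[w] for w in words]
                  for label, words in label_words.items()}
print(label_word_ids)
# {'negative': [2213], 'positive': [1363, 7310, 1632]}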

I won't expand on it further here; its role and some of its pitfalls are covered later.

3. PLM

Nothing much to explain here: this is simply the pre-trained language model.

Workflow

The complete pipeline code is in the appendix.

1. Inference

For this sentiment classification task, the inputs look like this after the Template transformation:

Albert Einstein was one of the greatest intellects of his time. It was [MASK]
The film was badly made. It was [MASK]

These are fed to the BERT model, and we take the prediction output at the [MASK] position. Its shape is (2, 28996): the batch_size is 2 and the BERT vocab size is 28996.
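
To make that shape concrete, here is a hedged sketch of extracting the [MASK]-position logits with plain transformers (OpenPrompt handles this internally; the model name follows the appendix):

import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = BertForMaskedLM.from_pretrained("bert-base-cased")

texts = [
    "Albert Einstein was one of the greatest intellects of his time. It was [MASK]",
    "The film was badly made. It was [MASK]",
]
inputs = tokenizer(texts, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(**inputs).logits              # (batch, seq_len, vocab_size)

mask_positions = inputs["input_ids"] == tokenizer.mask_token_id
output = logits[mask_positions]                  # one [MASK] per example -> (2, 28996)
print(output.shape)                              # torch.Size([2, 28996])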

So how do we map this prediction back to negative / positive? This is where the Verbalizer comes in.

Convert ['bad', 'good', 'wonderful', 'great'] into tokenizer vocab indices, then gather the corresponding entries from output, as in the pseudocode below:

import torch


def aggregate(label_words_logits: torch.Tensor, label_words_mask: torch.Tensor) -> torch.Tensor:
    r"""Use weight to aggregate the logits of label words.

    Args:
        label_words_logits (:obj:`torch.Tensor`): The logits of the label words.
        label_words_mask (:obj:`torch.Tensor`): 0/1 mask marking the valid label words per class.

    Returns:
        :obj:`torch.Tensor`: The aggregated logits from the label words.
    """
    label_words_logits = (label_words_logits * label_words_mask).sum(-1) / label_words_mask.sum(-1)
    return label_words_logits


output = torch.rand(2, 28996)
# [tokenizer.vocab.get(i) for i in ['bad', 'good', 'wonderful', 'great']]
# [2213, 1363, 7310, 1632]
verbalizer = torch.tensor([[2213, 0, 0],
                           [1363, 7310, 1632]])
mask = torch.tensor([[1, 0, 0],
                     [1, 1, 1]])
logits = output[:, verbalizer]   # shape (2, 2, 3): (batch, num_classes, num_label_words)
# tensor([[[0.0508, 0.0000, 0.0000],
#          [0.1770, 0.9680, 0.0091]],
#         [[0.2045, 0.0000, 0.0000],
#          [0.9504, 0.2127, 0.1417]]])
print(aggregate(logits, mask))
# negative, positive
# tensor([[0.0508, 0.3847],
#         [0.2045, 0.4349]])

Sample    bad       good      wonderful   great
1         0.0508    0.1770    0.9680      0.0091
2         0.2045    0.9504    0.2127      0.1417

After aggregate, we have the final logits corresponding to negative and positive.
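
To finish inference, a hypothetical continuation maps the aggregated logits to class names via argmax (classes is assumed to be ordered [negative, positive], matching the columns above):

classes = ["negative", "positive"]
class_logits = aggregate(logits, mask)     # (2, 2): columns follow [negative, positive]
preds = class_logits.argmax(dim=-1)        # tensor([1, 1]) for the random example values
print([classes[int(p)] for p in preds])    # ['positive', 'positive']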

2. Training

print(aggregate(logits, mask))
# tensor([[0.0508, 0.3847],
#         [0.2045, 0.4349]])

With this in hand, just take the cross-entropy against the labels and you have the loss.
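
A minimal sketch, assuming the aggregate / logits / mask from the inference section and hypothetical gold labels (1 = positive, 0 = negative):

import torch
import torch.nn.functional as F

labels = torch.tensor([1, 0])                  # hypothetical gold labels for the two examples
class_logits = aggregate(logits, mask)         # (2, 2): [negative, positive], as above
loss = F.cross_entropy(class_logits, labels)   # cross-entropy over the two classes
print(loss)                                    # in real training this loss is backpropagated into the PLM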

3. Verbalizer drawbacks

Notice that ['bad', 'good', 'wonderful', 'great'] are all words that exist in the BERT vocab. Chinese, however, is tokenized at the character level, so multi-character words like 不好, 精彩, 漂亮, 优秀, ... do not exist as single vocab entries; even WWM (whole word masking) models do not add them. This makes it awkward to build such verbalizers for Chinese.
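
A quick check of this limitation (assuming the bert-base-chinese tokenizer; the exact vocab may differ for other Chinese PLMs):

from transformers import BertTokenizer

zh_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
print(zh_tokenizer.tokenize("不好"))     # ['不', '好'] -> split into single characters
print("不好" in zh_tokenizer.vocab)      # False: the two-character word has no single vocab entry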

This style is referred to as hard decoding (a hard verbalizer); more on that later.

Appendix

1. Full code

# -*- coding: utf8 -*-
#
# step1:
import torch

from openprompt.data_utils import InputExample

classes = [  # There are two classes in Sentiment Analysis, one for negative and one for positive
    "negative",
    "positive"
]
dataset = [  # For simplicity, there's only two examples
    # text_a is the input text of the data, some other datasets may have multiple input sentences in one example.
    InputExample(
        guid=0,
        text_a="Albert Einstein was one of the greatest intellects of his time.",
    ),
    InputExample(
        guid=1,
        text_a="The film was badly made.",
    ),
]
# step 2
from openprompt.plms import load_plm

plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-base-cased")
# step3
from openprompt.prompts import ManualTemplate

promptTemplate = ManualTemplate(
    text='{"placeholder":"text_a"} It was {"mask"}',
    tokenizer=tokenizer,
)
# step4
from openprompt.prompts import ManualVerbalizer

promptVerbalizer = ManualVerbalizer(
    classes=classes,
    label_words={
        "negative": ["bad"],
        "positive": ["good", "wonderful", "great"],
    },
    tokenizer=tokenizer,
)
# step 5
from openprompt import PromptForClassification

promptModel = PromptForClassification(
    template=promptTemplate,
    plm=plm,
    verbalizer=promptVerbalizer,
)
# step 6
from openprompt import PromptDataLoader

data_loader = PromptDataLoader(
    dataset=dataset,
    tokenizer=tokenizer,
    template=promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)
# step 7
# making zero-shot inference using pretrained MLM with prompt
promptModel.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)
        preds = torch.argmax(logits, dim=-1)
        print(classes[preds])
# predictions would be 1, 0 for classes 'positive', 'negative'