Mask operations and index operations very often go together, so for index operations see also: 活到老学到老之index操作

1. Getting sentence lengths back from a padded matrix

import torch
from torch.nn.utils.rnn import pad_sequence

a = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
b = pad_sequence(a, batch_first=True)   # -> tensor([[1, 2, 3], [4, 5, 0]])
mask = b.not_equal(0)                   # True for real tokens, False for padding
b[mask].split(mask.sum(1).tolist())     # recover the original variable-length sequences
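For the example above, b is tensor([[1, 2, 3], [4, 5, 0]]), mask.sum(1).tolist() is [3, 2] (the real length of each sentence), and the final split returns (tensor([1, 2, 3]), tensor([4, 5])), i.e. the original variable-length sequences.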

2. Adding the mask when computing the loss

Details omitted; a minimal sketch is given below.
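A minimal sketch of the usual pattern (the shapes and names here are made up for illustration): compute a per-token loss with reduction='none', zero out the padded positions with the mask, then average over the real tokens only.

import torch
import torch.nn.functional as F

# hypothetical shapes: logits (batch, seq_len, num_labels), labels (batch, seq_len)
logits = torch.randn(2, 4, 5)
labels = torch.randint(0, 5, (2, 4))
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]]).bool()       # True = real token, False = padding

loss = F.cross_entropy(logits.flatten(end_dim=1), labels.flatten(), reduction='none')
loss = loss.masked_fill(~mask.flatten(), 0.0)    # zero out the padded positions
loss = loss.sum() / mask.sum()                   # average over real tokens only

Equivalently, the padded labels can be set to -100 and handled by cross_entropy's ignore_index, which skips them automatically.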

3. Masking a 3-D tensor with a 2-D mask

a = torch.arange(24).reshape(2, 3, 4)

Out[28]:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])


mask = torch.tril(torch.ones(2, 3)).bool()

Out[29]:
tensor([[ True, False, False],
        [ True,  True, False]])


# with the mask above, how should a[mask] be read?
a[mask]

Out[30]:
tensor([[ 0,  1,  2,  3],
        [12, 13, 14, 15],
        [16, 17, 18, 19]])

Put simply, a[mask] drops the padding positions along the sequence_length dimension and keeps only the rows (feature vectors) of the real positions.
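As a quick follow-up with the same a and mask as above, the flattened rows can be regrouped per sample exactly like in section 1:

rows = a[mask]                                 # shape (mask.sum(), 4): only the unmasked rows
per_sample = rows.split(mask.sum(1).tolist())  # 1 row for sample 0, 2 rows for sample 1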

4. gather

See also: bert以首字表示词向量(2)

# For example, this line appears in ltp: https://github.com/HIT-SCIR/ltp/blob/f3d4a25ee2fbb71613f76c99a47e70a5445b8c03/ltp/transformer_rel_linear.py#L58
input = torch.gather(input, dim=1, index=word_index.unsqueeze(-1).expand(-1, -1, input.size(-1)))

The full experiment code is as follows:


from typing import List

from transformers import PreTrainedTokenizerBase


def tokenize(form: List[List[str]], tokenizer: PreTrainedTokenizerBase, max_length: int = 128, char_base: bool = False):
    """
    Args:
        form:
        tokenizer:
        max_length:
        char_base: whether form (i.e. the words) is already character-level
    Returns:
    """
    res = tokenizer.batch_encode_plus(
        form,
        is_split_into_words=True,
        max_length=max_length,
        truncation=True,
    )
    result = res.data
    # can be used to filter out examples that are too long; "overflow" means the token length
    # exceeds max_length (cls/sep are counted if present)
    result['overflow'] = [len(encoding.overflowing) > 0 for encoding in res.encodings]
    if not char_base:
        word_index = []
        for encoding in res.encodings:
            word_index.append([])

            last_word_idx = -1
            current_length = 0
            for word_idx in encoding.word_ids[1:-1]:
                if word_idx != last_word_idx:
                    word_index[-1].append(current_length)

                current_length += 1
                last_word_idx = word_idx
        result['word_index'] = word_index
        result['word_attention_mask'] = [[True] * len(index) for index in word_index]
    return result
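For example, for ['我', '爱', '北京'] with a character-level Chinese tokenizer (assuming 北京 is split into 北/京), encoding.word_ids[1:-1] is [0, 1, 2, 2], so word_index becomes [0, 1, 2]: the position of each word's first subtoken, counting from 0 and excluding [CLS], and word_attention_mask is [True, True, True].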
# -*- coding: utf8 -*-
#
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, AutoTokenizer

from src.conf import PRETRAINED_NAME_OR_PATH
from src.transform import tokenize

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME_OR_PATH)


class DemoModel(nn.Module):
    def __init__(self):
        super(DemoModel, self).__init__()
        self.transformer = AutoModel.from_pretrained(PRETRAINED_NAME_OR_PATH)

    def forward(self, input_ids, token_type_ids, attention_mask, word_index):
        seq_out = self.transformer(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )[0]
        word_out = torch.cat([seq_out[:, :1, :], torch.gather(
            seq_out[:, 1:, :], dim=1, index=word_index.unsqueeze(-1).expand(-1, -1, seq_out.size(-1))
        )], dim=1)

        return seq_out, word_out[:, 1:, :]


if __name__ == '__main__':
    inputs = [
        ['我', '爱', '北京'],
        ['我', '爱', '北京', '和', '你']
    ]

    rs = tokenize(inputs, tokenizer, 10)

    batch_input_ids = pad_sequence([torch.tensor(i) for i in rs['input_ids']], batch_first=True)
    batch_token_type_ids = pad_sequence([torch.tensor(i) for i in rs['token_type_ids']], batch_first=True)
    batch_attention_mask = pad_sequence([torch.tensor(i) for i in rs['attention_mask']], batch_first=True)
    batch_word_index = pad_sequence([torch.tensor(i) for i in rs['word_index']], batch_first=True)
    batch_word_attention_mask = pad_sequence([torch.tensor(i) for i in rs['word_attention_mask']],
                                             batch_first=True)

    model = DemoModel()
    seq_out, word_out = model(batch_input_ids, batch_token_type_ids, batch_attention_mask, batch_word_index)
    print(seq_out.shape, word_out.shape)
    # torch.Size([2, 6, 256]) torch.Size([2, 5, 256])
    # (1) the first subtoken of each word lines up with the corresponding word vector
    assert (seq_out[0][1] == word_out[0][0]).all()  # 我
    assert (seq_out[0][2] == word_out[0][1]).all()  # 爱
    assert (seq_out[0][3] == word_out[0][2]).all()  # 北京

    assert (seq_out[1][1] == word_out[1][0]).all()  # 我
    assert (seq_out[1][2] == word_out[1][1]).all()  # 爱
    assert (seq_out[1][3] == word_out[1][2]).all()  # 北京
    #
    assert (seq_out[1][5] == word_out[1][3]).all()  # 和
    assert (seq_out[1][6] == word_out[1][4]).all()  # 你

    # (2) setting a value at the padded positions
    new_mask = batch_word_attention_mask.unsqueeze(-1).expand(-1, -1, word_out.size(-1))
    # word_out.masked_fill_(~new_mask, float('-inf'))  # in-place version; doing this can speed up model convergence
    # for demonstration the out-of-place version is used here
    new_word_out = word_out.masked_fill(~new_mask, float("-inf"))
    assert (seq_out[0][1] == word_out[0][0]).all()  # 我
    assert (seq_out[0][2] == word_out[0][1]).all()  # 爱
    assert (seq_out[0][3] == word_out[0][2]).all()  # 北京

    assert (new_word_out[0][3] == torch.full(new_word_out[0][3].shape, fill_value=float("-inf"))).all()
    assert (new_word_out[0][4] == torch.full(new_word_out[0][4].shape, fill_value=float("-inf"))).all()

    assert (seq_out[1][1] == word_out[1][0]).all()  # 我
    assert (seq_out[1][2] == word_out[1][1]).all()  # 爱
    assert (seq_out[1][3] == word_out[1][2]).all()  # 北京
    #
    assert (seq_out[1][5] == word_out[1][3]).all()  # 和
    assert (seq_out[1][6] == word_out[1][4]).all()  # 你

Explanation:
word_out takes the vector of each word's first subtoken as the vector of the whole word. For padded word positions, however, the gathered index is 0 (the first token), so those positions have to be masked out.

For example:
For ['我', '爱', '北京'], seq_out is (2, 6, 256), where 6 is the character length of ['我', '爱', '北京', '和', '你']; the last two word positions of ['我', '爱', '北京'] are therefore padding whose word_index is 0, so gather fetches the vector at position 0 for them. That is why the padded part has to be masked out in later computations.
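A minimal, self-contained illustration of the gather step itself (toy numbers, not outputs of the model above):

import torch

seq_out = torch.arange(24).float().reshape(2, 4, 3)   # (batch, seq_len, hidden)
word_index = torch.tensor([[0, 2, 0],                  # last position is padding, its index defaults to 0
                           [0, 1, 3]])
word_out = torch.gather(
    seq_out, dim=1,
    index=word_index.unsqueeze(-1).expand(-1, -1, seq_out.size(-1))
)
# word_out[0, 2] equals seq_out[0, 0] because the padded index is 0 --
# exactly the positions that word_attention_mask later masks out.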

5. masked_fill

a = torch.arange(4).float().reshape(2, 2)
b = torch.tensor([[1, 0], [0, 1]]).bool()
a.masked_fill(b, float('-inf'))
Out[44]:
tensor([[-inf, 1.],
        [2., -inf]])
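There is also an in-place variant, masked_fill_, which modifies a directly instead of returning a new tensor:

a.masked_fill_(b, float('-inf'))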

6. 2-D mask -> 3-D mask

This is used when going up to 4-D (with the last dimension being the feature dimension).

mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0]
])
mask3d = mask.unsqueeze(-1) & mask.unsqueeze(-2)
print(mask3d)

Out:
tensor([[[1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [0, 0, 0, 0]],

        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]]])
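A typical use is masking a pairwise (batch, seq_len, seq_len) score matrix, e.g. attention or biaffine arc scores, so that pairs involving padding never receive probability mass (a sketch with made-up scores):

scores = torch.randn(2, 4, 4)                               # hypothetical pairwise scores
scores = scores.masked_fill(~mask3d.bool(), float('-inf'))  # padded pairs get -inf before softmax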

7. The flatten operation

a = torch.arange(12).view(2, 2, 3)
a
Out[39]:
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])
a.flatten(end_dim=1)
Out[40]:
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
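Here a.flatten(end_dim=1) merges dimensions 0 and 1 and is equivalent to a.reshape(-1, 3); a common use is flattening (batch, seq_len, ...) tensors together with mask.flatten() before computing a token-level loss, as in the sketch in section 2.

assert (a.flatten(end_dim=1) == a.reshape(-1, 3)).all()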

8. Diagonal (triangular) masks

# -*- coding: utf8 -*-
#
import torch

seq_len = 3


def a():
    """
    tensor([[-inf, 0., 0.],
            [-inf, -inf, 0.],
            [-inf, -inf, -inf]])
    Returns:

    """
    a = torch.arange(0, seq_len)
    mask = a.unsqueeze(0) - a.unsqueeze(1)
    mask = torch.log((mask > 0).to(torch.float))
    print(mask)


def b():
    """
    tensor([[0., 0., 0.],
            [-inf, 0., 0.],
            [-inf, -inf, 0.]])
    Returns:

    """
    mask = torch.triu(torch.ones(seq_len, seq_len))
    mask = torch.log(mask)
    print(mask)
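A sketch of how such a mask is typically used: add it to the attention scores before softmax, so the -inf positions end up with zero probability (the scores here are random, just for illustration):

scores = torch.randn(seq_len, seq_len)
mask = torch.log(torch.triu(torch.ones(seq_len, seq_len)))
probs = torch.softmax(scores + mask, dim=-1)  # entries strictly below the diagonal become 0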