本文是对《bert以首字表示词向量(2)》一文的扩充:对指定 span 索引范围内的 token 向量做注意力加权求和,得到每个 span(词)的向量表示。

整体思想来自 coreference resolution(指代消解)word-level 实现的 demo 部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf8 -*-
#


import torch


class WordEncoder(torch.nn.Module):
    """Pool token vectors into one vector per span using learned attention.

    A single linear layer assigns a scalar score to every token. For each
    ``(start, end)`` span the scores are softmax-normalised over the positions
    ``start <= i < end`` only, and the span vector is the resulting weighted
    sum of the token vectors.
    """

    def __init__(self, n_in, p=0.1):
        """
        :param n_in: size of the last (feature) dimension of the token tensor.
        :param p: dropout probability applied to the pooled span vectors.
        """
        super(WordEncoder, self).__init__()
        self.n_in = n_in
        self.p = p
        # Produces one attention score per token.
        self.linear = torch.nn.Linear(in_features=n_in, out_features=1)
        self.dp = torch.nn.Dropout(p)

    def forward(self, x, indices):
        """
        Compute one attention-pooled vector per span.

        Padding slots may be given either as the full range
        ``(0, x.size(1))`` or as an empty span (``start >= end``, e.g.
        ``(0, 0)``). An empty span pools to an all-zero vector instead of
        NaN.

        :param x: token features, shape (BS, SL, HS).
        :param indices: nested sequence (or tensor) of shape (BS, L, 2)
            holding half-open ``(start, end)`` token ranges; every batch
            entry must contain the same number ``L`` of spans.
        :return: tensor of shape (BS, L, HS), after dropout.
        """
        # as_tensor accepts nested lists and avoids a redundant copy (and the
        # copy-construct warning) when the caller already passes a tensor.
        spans = torch.as_tensor(indices, device=x.device)
        starts = spans[:, :, 0].unsqueeze(2)  # BS, L, 1
        ends = spans[:, :, 1].unsqueeze(2)    # BS, L, 1
        n_spans = spans.shape[1]

        positions = torch.arange(x.shape[1], device=x.device)  # SL
        in_span = (positions >= starts) & (positions < ends)   # BS, L, SL
        has_tokens = in_span.any(dim=2, keepdim=True)          # BS, L, 1

        # The same per-token scores are shared by every span of a sentence.
        scores = self.linear(x).permute(0, 2, 1)                 # BS, 1, SL
        scores = scores.expand(x.shape[0], n_spans, x.shape[1])  # BS, L, SL

        # Exclude out-of-span positions from the softmax. Rows without any
        # in-span position are left unmasked on purpose: masking all of them
        # to -inf would make softmax return NaN (also in the backward pass).
        scores = scores.masked_fill(~in_span & has_tokens, float('-inf'))
        weights = torch.softmax(scores, dim=2)
        # Empty spans contribute nothing: zero their weights explicitly.
        weights = weights.masked_fill(~has_tokens, 0.0)

        output = weights.bmm(x)  # BS, L, HS
        return self.dp(output)


if __name__ == '__main__':
    # Demo batch: two sentences, 5 tokens each, 6 features per token.
    #   sentence 1 words: 我 / 是 / 中国人
    #   sentence 2 words: 我 / 中国人 / null (third slot is padding: (0, 0))
    start_end_indices = [
        [(0, 1), (1, 2), (2, 5)],
        [(0, 1), (1, 4), (0, 0)],
    ]
    word_feature = torch.arange(60, dtype=torch.float32).reshape(2, 5, 6)
    w = WordEncoder(n_in=6)
    w(word_feature, start_end_indices)

    # Single-sentence variant, kept for reference:
    # word_feature = torch.arange(50, dtype=torch.float).view(1, 5, 10)
    # # the matching span indices are:
    # start_end_indices = [[(0, 1), (1, 2), (2, 5)]]
    #
    # we = WordEncoder(n_in=10, p=0.1)
    # output = we.forward(word_feature, start_end_indices)
    # print(output)