Quick quiz: how many common indexing operations in torch can you name?

1. gather

>>> a = torch.tensor([[1, 2, 3],
...                   [4, 5, 6]])
>>> a.gather(dim=1, index=torch.tensor([[0,1], [1,2]]))
tensor([[1, 2],
        [5, 6]])
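
To make the rule behind `gather` explicit, here is a minimal sketch that compares the built-in call with a plain loop. For `dim=1` the rule is `out[i][j] = a[i][index[i][j]]`; the variable names `gathered` and `manual` are only illustrative.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
index = torch.tensor([[0, 1],
                      [1, 2]])

# Built-in gather along dim=1.
gathered = a.gather(dim=1, index=index)

# The same rule spelled out: out[i][j] = a[i][index[i][j]].
manual = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        manual[i, j] = a[i, index[i, j]]

print(gathered)
print(torch.equal(gathered, manual))
```

Each row of `index` picks columns from the matching row of `a`, which is why different rows can select different columns.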

2. index_select

>>> a
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> a.index_select(dim=1, index=torch.tensor([1,2]))
tensor([[2, 3],
        [5, 6]])
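
As a small sketch of how `index_select` relates to ordinary indexing: with a 1-D index it should give the same result as advanced indexing on that dimension. The names `cols`, `selected` and `sliced` are made up for this example.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
cols = torch.tensor([1, 2])  # one 1-D index, shared by every row

# index_select keeps the chosen columns of every row.
selected = a.index_select(dim=1, index=cols)

# Advanced indexing with the same 1-D index along dim 1.
sliced = a[:, cols]

print(selected)
print(torch.equal(selected, sliced))
```

Unlike `gather`, the index here is one-dimensional and applies to all rows at once.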

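3. Now for the fancy one

In the examples above, a is a matrix and index picks elements out of it. The operation below works the other way round: it is a map (lookup) operation, where every entry of index is replaced by the value of a at that position.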

>>> index = torch.tensor([[1, 2, 3],
...                       [4, 5, 6]])
>>> a = torch.tensor([11, 22, 33, 44, 55, 66, 77])
>>> a
tensor([11, 22, 33, 44, 55, 66, 77])
>>> index
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> a[index]
tensor([[22, 33, 44],
        [55, 66, 77]])
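
For comparison, the same lookup can be written in a couple of other ways. This is only a sketch for the 1-D case shown above; `via_take` and `via_select` are illustrative names.

```python
import torch

a = torch.tensor([11, 22, 33, 44, 55, 66, 77])
index = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])

# Map every entry of index to the value of a at that position.
mapped = a[index]

# torch.take treats its input as 1-D and returns a tensor shaped like index.
via_take = torch.take(a, index)

# Flatten the index, select along dim 0, then restore the index shape.
via_select = a.index_select(0, index.flatten()).view(index.shape)

print(mapped)
print(torch.equal(mapped, via_take), torch.equal(mapped, via_select))
```

The result always has the shape of `index`, not the shape of `a`.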

This operation shows up in real scenarios. For example:

# 1. Two features
>>> words = ['我', '爱', '中', '国']
>>> pos = ['n', 'v', 'n', 'n']

# 2. Suppose words has been turned into a 4 x 4 adjacency matrix that scores
#    how strongly each token relates to every other token
>>> words_attn = torch.rand(4,4)

>>> words_attn
tensor([[0.6279, 0.6234, 0.9831, 0.5267],
        [0.2265, 0.8453, 0.5740, 0.4772],
        [0.7759, 0.6952, 0.1758, 0.3800],
        [0.9998, 0.3138, 0.5078, 0.5565]])

>>> scores, indices = words_attn.topk(k=2, dim=1)

>>> indices
tensor([[2, 0],
        [1, 2],
        [0, 1],
        [0, 3]])

# 3. Suppose pos has been converted to ids
>>> pos_tensor = torch.tensor([111, 222, 333, 444])

# 4. The map operation
>>> pos_tensor[indices]
tensor([[333, 111],
        [222, 333],
        [111, 222],
        [111, 444]])

# 5. An embedding layer can then be applied on top
#    (pos_embedding is not defined in this post; presumably an nn.Embedding)
>>> pos_embedding(pos_tensor[indices])

# 6. Summary: the benefit of this pattern is speed. Take the top-k, then combine the result with other features.
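
The scenario above can be sketched end to end as follows. Several pieces are assumptions that the post does not specify: the tag-to-id mapping `pos_vocab`, the embedding size of 8, and the final score-weighted sum, which is just one possible way of combining the features.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

pos = ['n', 'v', 'n', 'n']
pos_vocab = {'n': 0, 'v': 1}  # hypothetical tag-to-id mapping
pos_tensor = torch.tensor([pos_vocab[p] for p in pos])  # shape (4,)

# Token-to-token relevance scores, random here as in the post.
words_attn = torch.rand(4, 4)
scores, indices = words_attn.topk(k=2, dim=1)  # both have shape (4, 2)

# Map operation: POS id of each token's top-2 related tokens.
neighbor_pos = pos_tensor[indices]  # shape (4, 2)

# Assumed embedding layer; the size 8 is arbitrary.
pos_embedding = nn.Embedding(num_embeddings=len(pos_vocab), embedding_dim=8)
neighbor_emb = pos_embedding(neighbor_pos)  # shape (4, 2, 8)

# One possible combination: weight each neighbor embedding by its score.
weighted = (scores.unsqueeze(-1) * neighbor_emb).sum(dim=1)  # shape (4, 8)

print(neighbor_pos.shape, neighbor_emb.shape, weighted.shape)
```

The ids passed to `nn.Embedding` must be smaller than `num_embeddings`, which is why this sketch uses small contiguous ids rather than values like 111 or 444.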

4. batch_gather

import torch


def batch_gather(data: torch.Tensor, index: torch.Tensor):
    # data: (bs, seq_len, hidden), index: (bs, n)
    index = index.unsqueeze(-1).repeat_interleave(data.size()[-1], dim=-1)  # (bs, n, hidden)
    # index = index.unsqueeze(-1).expand(*(*index.shape, data.shape[-1]))
    return torch.gather(data, 1, index)


if __name__ == '__main__':
    a = torch.randn(3, 128, 312)
    indices = torch.tensor([
        [1, 2],
        [5, 6],
        [7, 7]
    ])
    output = batch_gather(a, indices)
    print(output.shape)
    print((output[0][0] == a[0][1]).all())
    print((output[0][1] == a[0][2]).all())
    print((output[1][0] == a[1][5]).all())
    print((output[1][1] == a[1][6]).all())
    print((output[2][0] == a[2][7]).all())
    print((output[2][1] == a[2][7]).all())
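
As an alternative sketch, the same batched selection can be done with advanced indexing, without building a `(bs, n, hidden)` index. `batch_gather_indexing` is a hypothetical helper name, and the `batch_gather` below uses `expand` in place of `repeat_interleave`, which should give the same index values without copying.

```python
import torch


def batch_gather(data: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Broadcast the (bs, n) index to (bs, n, hidden) as a view, then gather on dim 1.
    index = index.unsqueeze(-1).expand(-1, -1, data.size(-1))
    return torch.gather(data, 1, index)


def batch_gather_indexing(data: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Pair every index[b, j] with its batch id b, then index dims 0 and 1 together.
    batch_idx = torch.arange(data.size(0), device=data.device).unsqueeze(1)  # (bs, 1)
    return data[batch_idx, index]  # (bs, n, hidden)


data = torch.randn(3, 128, 312)
indices = torch.tensor([[1, 2],
                        [5, 6],
                        [7, 7]])

out_gather = batch_gather(data, indices)
out_index = batch_gather_indexing(data, indices)

print(out_gather.shape)
print(torch.equal(out_gather, out_index))
```

`batch_idx` has shape `(bs, 1)` and broadcasts against the `(bs, n)` index, so each row of `indices` only ever reads from its own batch element.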