import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, AutoTokenizer

from src.conf import PRETRAINED_NAME_OR_PATH
from src.transform import tokenize

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME_OR_PATH)

class DemoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(PRETRAINED_NAME_OR_PATH)

    def forward(self, input_ids, token_type_ids, attention_mask, word_index):
        # Subword-level hidden states: (batch_size, seq_len, hidden_size).
        seq_out = self.transformer(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )[0]
        # Gather the hidden state of each word's first subword. word_index is
        # relative to seq_out[:, 1:, :], i.e. it already skips the [CLS] slot.
        # The [CLS] vector is prepended before the gather and stripped again on
        # return, so the second output holds word-level vectors only.
        word_out = torch.cat([seq_out[:, :1, :], torch.gather(
            seq_out[:, 1:, :], dim=1,
            index=word_index.unsqueeze(-1).expand(-1, -1, seq_out.size(-1))
        )], dim=1)
        return seq_out, word_out[:, 1:, :]
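
# Worked alignment example for the demo inputs below, assuming a
# character-level Chinese BERT vocabulary (one token per character):
#
#   tokens:    [CLS]  我  爱  北  京  和  你  [SEP]
#   position:    0    1   2   3   4   5   6    7
#
#   word_index offsets (into seq_out[:, 1:, :], one per word start):
#     ['我', '爱', '北京']             -> [0, 1, 2]
#     ['我', '爱', '北京', '和', '你'] -> [0, 1, 2, 4, 5]
#
# torch.gather with these offsets picks out the first-subword vector of each
# word, which is exactly what the asserts in __main__ verify.
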
if __name__ == '__main__':
    inputs = [
        ['我', '爱', '北京'],
        ['我', '爱', '北京', '和', '你']
    ]
    # tokenize comes from src.transform; it is assumed to return, per sentence,
    # the usual BERT fields plus 'word_index' (first-subword offsets) and
    # 'word_attention_mask' (True for real words), truncated to length 10.
    rs = tokenize(inputs, tokenizer, 10)

    # Pad the variable-length sentences into rectangular batch tensors.
    batch_input_ids = pad_sequence([torch.tensor(i) for i in rs['input_ids']], batch_first=True)
    batch_token_type_ids = pad_sequence([torch.tensor(i) for i in rs['token_type_ids']], batch_first=True)
    batch_attention_mask = pad_sequence([torch.tensor(i) for i in rs['attention_mask']], batch_first=True)
    batch_word_index = pad_sequence([torch.tensor(i) for i in rs['word_index']], batch_first=True)
    batch_word_attention_mask = pad_sequence([torch.tensor(i) for i in rs['word_attention_mask']], batch_first=True)
    model = DemoModel()
    seq_out, word_out = model(batch_input_ids, batch_token_type_ids,
                              batch_attention_mask, batch_word_index)
    print(seq_out.shape, word_out.shape)

    # Sentence 0 ('我 爱 北京'): word vectors must equal the subword vectors at
    # positions 1, 2, 3 (我, 爱, and 北, the first subword of 北京).
    assert (seq_out[0][1] == word_out[0][0]).all()
    assert (seq_out[0][2] == word_out[0][1]).all()
    assert (seq_out[0][3] == word_out[0][2]).all()

    # Sentence 1 ('我 爱 北京 和 你'): positions 1, 2, 3, 5, 6 (position 4 is 京,
    # a non-initial subword of 北京, so it is skipped).
    assert (seq_out[1][1] == word_out[1][0]).all()
    assert (seq_out[1][2] == word_out[1][1]).all()
    assert (seq_out[1][3] == word_out[1][2]).all()
    assert (seq_out[1][5] == word_out[1][3]).all()
    assert (seq_out[1][6] == word_out[1][4]).all()
    # Mask out the padded word slots: sentence 0 has only 3 real words, so its
    # last two rows in word_out are garbage (gathered via the padding index 0).
    # .bool() guards against an integer mask, where ~ would be bitwise-not.
    new_mask = batch_word_attention_mask.bool().unsqueeze(-1).expand(-1, -1, word_out.size(-1))
    new_word_out = word_out.masked_fill(~new_mask, float("-inf"))

    # Real word positions are untouched by the (out-of-place) masked_fill...
    assert (new_word_out[0][0] == seq_out[0][1]).all()
    assert (new_word_out[0][1] == seq_out[0][2]).all()
    assert (new_word_out[0][2] == seq_out[0][3]).all()
    # ...while the two padded slots of sentence 0 are now all -inf.
    assert (new_word_out[0][3] == torch.full(new_word_out[0][3].shape, fill_value=float("-inf"))).all()
    assert (new_word_out[0][4] == torch.full(new_word_out[0][4].shape, fill_value=float("-inf"))).all()
    # Sentence 1 has 5 real words, so nothing there is masked.
    assert (new_word_out[1][0] == seq_out[1][1]).all()
    assert (new_word_out[1][1] == seq_out[1][2]).all()
    assert (new_word_out[1][2] == seq_out[1][3]).all()
    assert (new_word_out[1][3] == seq_out[1][5]).all()
    assert (new_word_out[1][4] == seq_out[1][6]).all()
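
    # --- Illustrative extension, not part of the original script: one way a
    # consumer might use the word-level vectors is masked mean pooling, which
    # collapses each sentence into a single fixed-size vector. A minimal
    # sketch; the name sent_vec is made up here. Note it pools from word_out,
    # not new_word_out, since 0 * -inf would produce NaN.
    float_mask = batch_word_attention_mask.unsqueeze(-1).float()
    sent_vec = (word_out * float_mask).sum(dim=1) / float_mask.sum(dim=1).clamp(min=1.0)
    print(sent_vec.shape)  # (batch_size, hidden_size)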