"""Minimal DPO-style preference-loss demo.

Builds one preferred ("good") and several dispreferred ("bad") toy
sequences sharing a prompt, scores them under a trainable policy model
and a frozen reference copy, and computes the pairwise loss

    -log sigmoid(beta * ((pi_good - ref_good) - (pi_bad - ref_bad)))

averaged over the bad responses, then prints it.
"""
from copy import deepcopy

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM

torch.manual_seed(0)

# Temperature of the implicit reward in the preference objective.
beta = 0.1

# Toy token ids: one shared prompt, one good completion, two bad completions.
prompt_ids = [1, 2, 3, 4, 5, 6]
good_response_ids = [7, 8, 9, 10]
bad_response_ids_list = [[1, 2, 3, 0], [4, 5, 6, 0]]

# Row 0 is the good sequence; the remaining rows are the bad sequences.
input_ids = torch.LongTensor(
    [prompt_ids + good_response_ids]
    + [prompt_ids + bad_ids for bad_ids in bad_response_ids_list]
)

# Labels mask the prompt with -100, then shift left by one position so that
# label t is the target for the logits at position t (next-token prediction).
labels = torch.LongTensor(
    [[-100] * len(prompt_ids) + good_response_ids]
    + [[-100] * len(prompt_ids) + bad_ids for bad_ids in bad_response_ids_list]
)[:, 1:]
loss_mask = labels != -100
# Replace the -100 sentinels with a valid index (0) so gather() below does not
# fail; those positions are zeroed out by loss_mask anyway.
labels[labels == -100] = 0


def _sequence_logps(model, input_ids, labels, loss_mask):
    """Return the masked sum of per-token log-probs of `labels` under `model`.

    The last logit position is dropped to align logits[t] with labels[t]
    (labels were already shifted left by one above).
    """
    logits = model(input_ids)["logits"][:, :-1, :]
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1)


policy_model = LlamaForCausalLM(
    config=LlamaConfig(vocab_size=1000, num_hidden_layers=1, hidden_size=128)
)
# The reference starts as an exact copy of the (untrained) policy.
reference_model = deepcopy(policy_model)

all_logps = _sequence_logps(policy_model, input_ids, labels, loss_mask)
policy_good_logps, policy_bad_logps = all_logps[:1], all_logps[1:]

# No gradients flow through the reference model.
with torch.no_grad():
    all_logps = _sequence_logps(reference_model, input_ids, labels, loss_mask)
    reference_good_logps, reference_bad_logps = all_logps[:1], all_logps[1:]

# Margin by which the policy prefers good over bad, relative to the reference;
# broadcasting pairs the single good row against every bad row.
logits = (policy_good_logps - reference_good_logps) - (
    policy_bad_logps - reference_bad_logps
)
loss = -F.logsigmoid(beta * logits).mean()
print(loss)