def aggregate(label_words_logits: torch.Tensor, label_words_mask: torch.Tensor) -> torch.Tensor:
    r"""Aggregate the logits of label words into one logit per class.

    Computes a masked mean over the last dimension: positions where
    ``label_words_mask`` is zero are excluded, and the masked sum is
    divided by the number of valid label words per class.

    Args:
        label_words_logits (:obj:`torch.Tensor`): The logits of the label words.
        label_words_mask (:obj:`torch.Tensor`): Mask (1 for a real label word,
            0 for padding) broadcastable to ``label_words_logits``.
            NOTE(review): a class with an all-zero mask row divides by zero
            and yields NaN/inf — assumed not to occur upstream.

    Returns:
        :obj:`torch.Tensor`: The aggregated logits from the label words,
        with the last dimension reduced away.
    """
    # Fix: original source read `defaggregate` (missing space after `def`),
    # which is a SyntaxError.
    label_words_logits = (label_words_logits * label_words_mask).sum(-1) / label_words_mask.sum(-1)
    return label_words_logits
# step 1: define the task — the label set and a toy dataset.
# There are two classes in Sentiment Analysis, one for negative and one for positive
classes = [
    "negative",
    "positive",
]

# For simplicity, there's only two examples.
# text_a is the input text of the data; some other datasets may have
# multiple input sentences in one example.
dataset = [
    InputExample(
        guid=0,
        text_a="Albert Einstein was one of the greatest intellects of his time.",
    ),
    InputExample(
        guid=1,
        text_a="The film was badly made.",
    ),
]

# step 2: load a pretrained language model (helper imported here)
from openprompt.plms import load_plm
# step 3: define a manual template — the placeholder is filled with each
# example's text_a, and the {"mask"} token is what the MLM predicts.
promptTemplate = ManualTemplate(
    text='{"placeholder":"text_a"} It was {"mask"}',
    tokenizer=tokenizer,
)

# step4
from openprompt.prompts import ManualVerbalizer
# step 6: wrap the dataset in a PromptDataLoader, which applies the
# template and tokenizes/batches the examples.
data_loader = PromptDataLoader(
    dataset=dataset,
    tokenizer=tokenizer,
    template=promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)

# step 7
# making zero-shot inference using pretrained MLM with prompt
promptModel.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)
        # pick the highest-scoring class per example
        preds = torch.argmax(logits, dim=-1)
        print(classes[preds])
# predictions would be 1, 0 for classes 'positive', 'negative'