首页 > 有人用过transformer框架么,请教一个问题

有人用过transformer框架么,请教一个问题

代码如下:
from transformers import *
import torch 
import logging
logging.basicConfig(level=logging.INFO)
bert_model_path = "../pretrain_model/bert_base_cased"
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
model = BertForSequenceClassification.from_pretrained(bert_model_path)
classes = ["not paraphrase", "is paraphrase"]
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="pt")
paraphrase_classification_logits = model(**paraphrase)[0]
not_paraphrase_classification_logits = model(**not_paraphrase)[0]
paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]
print("Should be paraphrase")
for i in range(len(classes)):
    print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")
print("\nShould not be paraphrase")
for i in range(len(classes)):
    print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")
预期输出格式:
Should be paraphrase not paraphrase: 10% is paraphrase: 90% Should not be paraphrase not paraphrase: 94% is paraphrase: 6% 
在服务器上输出:
weight.t() size:  torch.Size([768, 3072])
input size:  torch.Size([1, 21, 3072])
weight.t() size:  torch.Size([3072, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 3072])
input size:  torch.Size([1, 21, 3072])
weight.t() size:  torch.Size([3072, 768])
#还有很多行像上面一样格式数据输出 Should be paraphrase not paraphrase: 10%  is paraphrase: 90%  Should not be paraphrase  not paraphrase: 94%  is paraphrase: 6% 
求问中间输出是怎么回事,我并没有主动打印上面的信息啊,我换了好几个模型都是这样子,请问有人碰到过这样的情况么。
我去看transformer的loging信息,但是并没有发现有weight.t()这样的信息打印。去github上看issue也没发现有人碰到这样的问题。
google直接搜不到
###救救孩子吧,开学不出成果会被延毕啊啊啊啊啊。



全部评论

(4) 回帖
加载中...
话题 回帖

推荐话题

相关热帖

近期热帖

近期精华帖

热门推荐