提交 eb4a1c00 作者: 朱学凯

fix test output

上级 98bd5425
......@@ -132,10 +132,9 @@ def test(args, model, dataset):
attention_mask=input_mask.cuda())
else:
pred_affinity = model(input_ids=input, token_type_ids=token_type_ids, attention_mask=input_mask)
pred_affinity = pred_affinity.cpu().numpy()
for res in range(args.batch_size):
pred = pred_affinity[res, :][0]
f.write(str(pred) + '\n')
pred_affinity = pred_affinity.cpu().numpy().squeeze(-1)
for res in pred_affinity:
f.write(str(res) + '\n')
if args.do_eval:
os.system('python eval.py')
......@@ -203,8 +202,8 @@ if __name__ == '__main__':
# args.task = 'test'
# args.init = './model/lr-1e-5-batch-32-e-10-layer3-0417-add-type-ids-and-mask/epoch-9-step-82370-loss-0.8841055645024439.pth'
# args.init = './model/lr-1e-5-batch-32-e-10-layer6-0428/epoch-8-step-74133-loss-0.6730387237803921.pth'
# args.output = './predict/test'
# args.config = './config/config_layer_3.json'
# args.config = './config/config_layer_6.json'
main(args)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论