提交 aba94ef4 作者: 朱学凯

fix pad

上级 a736cc2c
......@@ -186,6 +186,9 @@ class Data_Encoder_mol(data.Dataset):
input_seq = [self.begin_id] + d + [self.sep_id] + p + [self.sep_id]
token_type_ids = np.concatenate((np.zeros((len(d) + 2), dtype=np.int), np.ones((len(p) + 1), dtype=np.int)))
if len(input_seq) > self.max_len:
input_seq = input_seq[:self.max_len-1] + [self.sep_id]
else:
token_type_ids = np.pad(token_type_ids, (0, self.max_len - len(input_seq)), 'constant', constant_values=0)
input, input_mask = seq2emb_encoder(input_seq, self.max_len, self.vocab)
return torch.from_numpy(input).long(), torch.from_numpy(token_type_ids).long(), torch.from_numpy(input_mask).long(), y
......
......@@ -108,6 +108,11 @@ def main(args):
print('------------------creat model---------------------------')
config = BertConfig.from_pretrained(args.config)
model = BertAffinityModel(config)
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model, dim=0)
print('model name : BertAffinity')
print('task name : {}'.format(args.task))
......@@ -150,11 +155,11 @@ if __name__ == '__main__':
# local test
# args.task = 'train_mol'
# args.savedir = 'local_test_train'
# args.epochs = 10
# args.lr = 1e-5
# args.config = './config/config_layer_3_mol.json'
args.task = 'train_mol'
args.savedir = 'local_test_train'
args.epochs = 10
args.lr = 1e-5
args.config = './config/config_layer_3_mol.json'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论