提交 92740d2f 作者: 朱学凯

fix path

上级 3d456595
......@@ -96,9 +96,6 @@ def seq2emb_encoder(input_seq, max_len, vocab):
ids = np.array([0])
l = len(ids)
with open('./utils/data_analyse_train.tsv', 'w') as f:
tsv = csv.writer(f)
tsv.writerow([ids, l, '\n'])
if l < max_len:
ids = np.pad(ids, (0, max_len - l), 'constant', constant_values=0)
......@@ -158,17 +155,17 @@ if __name__ == "__main__":
# vocab = load_vocab(vocab_file)
# test train
df_train = {"sps": './IC50/SPS/train_sps',
"smile": './IC50/SPS/train_smile',
"affinity": './IC50/SPS/train_ic50',
"vocab_file": './ESPF/vocab.txt',
df_train = {"sps": './data/train_sps',
"smile": './data/train_smile',
"affinity": './data/train_ic50',
"vocab_file": './config/vocab.txt',
"begin_id": '[CLS]',
"separate_id": "[SEP]",
"max_len": 256
}
tokenizer_config = {"vocab_file": './ESPF/vocab.txt',
"vocab_pair": './ESPF/drug_codes_chembl.txt'
tokenizer_config = {"vocab_file": './config/vocab.txt',
"vocab_pair": './config/drug_codes_chembl.txt'
}
params = {'batch_size': 5,
'shuffle': False,
......@@ -177,4 +174,4 @@ if __name__ == "__main__":
trainset = Data_Encoder(df_train, tokenizer_config)
training_generator = data.DataLoader(trainset)
for i, (input, affinity) in tqdm(enumerate(training_generator)):
print('----------------')
print('')
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论