Commit c24ead73 Author: 朱学凯

change test

Parent e341a5dc
@@ -140,11 +140,13 @@ class Data_Encoder(data.Dataset):
         input_seq = [self.begin_id] + d + [self.sep_id] + p + [self.sep_id]
         token_type_ids = np.concatenate((np.zeros((len(d) + 2), dtype=np.int), np.ones((len(p) + 1), dtype=np.int)))
-        token_type_ids = np.pad(token_type_ids, (0, self.max_len-len(input_seq)), 'constant', constant_values=0)
+        token_type_ids = np.pad(token_type_ids, (0, self.max_len - len(input_seq)), 'constant', constant_values=0)
         input, input_mask = seq2emb_encoder(input_seq, self.max_len, self.vocab)
-        return torch.from_numpy(input).long(), torch.from_numpy(token_type_ids).long(), torch.from_numpy(input_mask).long(), y
+        return torch.from_numpy(input).long(), torch.from_numpy(token_type_ids).long(), torch.from_numpy(
+            input_mask).long(), y
         # return len(d), len(p)

 class Data_Encoder_mol(data.Dataset):
     def __init__(self, train_file, tokenizer_config):
         'Initialization'
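For context, the packing this hunk touches is: [CLS] + drug tokens + [SEP] + protein tokens + [SEP], with token_type_ids set to 0 over the drug segment and 1 over the protein segment, then zero-padded out to max_len. A minimal standalone sketch of the same scheme follows; it is illustrative only (pack_pair and the toy ids are not from this repo, it inlines what seq2emb_encoder appears to do here, and it uses np.int64 since the bare np.int alias is deprecated in current NumPy):

import numpy as np
import torch

def pack_pair(d, p, max_len, cls_id=2, sep_id=3, pad_id=0):
    # [CLS] + drug + [SEP] + protein + [SEP]
    seq = [cls_id] + d + [sep_id] + p + [sep_id]
    # segment ids: 0 over [CLS]+drug+[SEP], 1 over protein+[SEP], 0 over padding
    token_type_ids = np.concatenate((np.zeros(len(d) + 2, dtype=np.int64),
                                     np.ones(len(p) + 1, dtype=np.int64)))
    token_type_ids = np.pad(token_type_ids, (0, max_len - len(seq)), 'constant', constant_values=0)
    # ids padded with pad_id; mask is 1 on real tokens, 0 on padding
    input_ids = np.pad(np.asarray(seq, dtype=np.int64), (0, max_len - len(seq)),
                       'constant', constant_values=pad_id)
    input_mask = (input_ids != pad_id).astype(np.int64)
    return (torch.from_numpy(input_ids).long(),
            torch.from_numpy(token_type_ids).long(),
            torch.from_numpy(input_mask).long())

ids, segs, mask = pack_pair(d=[5, 6], p=[7, 8, 9], max_len=12)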
@@ -169,7 +171,6 @@ class Data_Encoder_mol(data.Dataset):
         bpe_codes_prot = codecs.open(tokenizer_config["vocab_pair_p"])
         self.pbpe = BPE(bpe_codes_prot, merges=-1, separator='')
-
     def __len__(self):
         'Denotes the total number of samples'
         return len(self.smile)
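The pbpe object set up above comes from the subword-nmt package: BPE(codes, merges=-1, separator='') loads learned merge operations, and the empty separator makes it emit clean subword units. A minimal usage sketch (the codes path and the protein string are placeholders, not the repo's actual files):

import codecs
from subword_nmt.apply_bpe import BPE

bpe_codes_prot = codecs.open('./config/protein_codes.txt')  # placeholder path
pbpe = BPE(bpe_codes_prot, merges=-1, separator='')

# split a protein sequence into learned subword tokens
tokens = pbpe.process_line('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ').split()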
@@ -185,14 +186,13 @@ class Data_Encoder_mol(data.Dataset):
         input_seq = [self.begin_id] + d + [self.sep_id] + p + [self.sep_id]
         token_type_ids = np.concatenate((np.zeros((len(d) + 2), dtype=np.int), np.ones((len(p) + 1), dtype=np.int)))
-        token_type_ids = np.pad(token_type_ids, (0, self.max_len-len(input_seq)), 'constant', constant_values=0)
+        token_type_ids = np.pad(token_type_ids, (0, self.max_len - len(input_seq)), 'constant', constant_values=0)
         input, input_mask = seq2emb_encoder(input_seq, self.max_len, self.vocab)
         return torch.from_numpy(input).long(), torch.from_numpy(token_type_ids).long(), torch.from_numpy(input_mask).long(), y
         # return len(d), len(p)

 def get_task(task_name):
     tokenizer_config = {"vocab_file": './config/vocab.txt',
                         "vocab_pair": './config/drug_codes_chembl.txt',
                         "begin_id": '[CLS]',
@@ -210,9 +210,9 @@ def get_task(task_name):
     elif task_name.lower() == 'test':
         df_test = {"sps": './data/test/test_sps',
                    "smile": './data/test/test_smile',
                    "affinity": './data/test/test_ic50',
                    }
         return df_test, tokenizer_config
@@ -243,9 +243,9 @@ def get_task(task_name):
     elif task_name.lower() == 'train_mol':
         df_train = {"sps": './data/train/train_sps',
                     'seq': './data/train/train_protein_seq',
                     "smile": './data/train/train_smile',
                     "affinity": './data/train/train_ic50',
                     }
         tokenizer_config = {"vocab_file": './config/vocab_mol.txt',
                             "vocab_pair": './config/drug_codes_chembl.txt',
@@ -255,12 +255,9 @@ def get_task(task_name):
                             "max_len": 595
                             }
         return df_train, tokenizer_config

 if __name__ == "__main__":
     # local test
     # dataFolder = './IC50/SPS/train_smile'
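Putting the pieces together: get_task hands back the data-file paths plus the matching tokenizer config, and the Data_Encoder_* datasets consume them directly. A sketch of the typical wiring, matching the visible constructor and __getitem__ signatures (the batch size and shuffle flag are arbitrary here, not values from the repo):

from torch.utils.data import DataLoader

df_train, tokenizer_config = get_task('train_mol')
dataset = Data_Encoder_mol(df_train, tokenizer_config)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for input_ids, token_type_ids, input_mask, y in loader:
    pass  # feed the BERT-style encoder here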
@@ -3,11 +3,15 @@ from modeling_bert import BertForMaskedLM
 import torch
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+# model = BertForMaskedLM.from_pretrained('bert-base-uncased')
-inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
+seq_a = "The capital of France is [MASK]."
+seq_b = "The capital of France is Paris."
+choice0 = "It is eaten with a fork and a knife."
+choice1 = "It is eaten while held in the hand."
+inputs = tokenizer([[seq_a, seq_b], [choice0, choice1]], padding=True)
 labels = tokenizer("The capital of France is Paris.", return_tensors="pt") #["input_ids"]
-outputs = model(**inputs, labels=labels)
-loss = outputs.loss
-logits = outputs.logits
 print('----------------')
\ No newline at end of file
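The removed lines were the masked-LM smoke test. For reference, a working version of that check needs the ["input_ids"] indexing that the old code left commented out, since BertForMaskedLM expects labels as a tensor of token ids, not a BatchEncoding. A sketch using the stock transformers classes (the repo imports its own modeling_bert instead):

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
# both sentences tokenize to the same length, so the label ids line up position-wise
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]

with torch.no_grad():
    outputs = model(**inputs, labels=labels)
print(outputs.loss, outputs.logits.shape)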
Two further source diffs in this commit could not be displayed because they are too large. You can view the blobs instead.