提交 da2551fa 作者: 朱学凯

change data

上级 9d55deb2
...@@ -146,6 +146,58 @@ class Data_Encoder(data.Dataset): ...@@ -146,6 +146,58 @@ class Data_Encoder(data.Dataset):
# return len(d), len(p) # return len(d), len(p)
def get_task(task_name):
    """Return the data-file paths and tokenizer config for a named task.

    Args:
        task_name: one of 'train', 'test', 'train_z_1', 'train_z_10',
            'train_z_100' (case-insensitive).

    Returns:
        (df, tokenizer_config): ``df`` maps 'sps'/'smile'/'affinity' to
        input file paths; ``tokenizer_config`` holds the vocab files and
        tokenizer settings shared by every task.

    Raises:
        ValueError: if ``task_name`` is not a known task. (The original
        if/elif chain silently returned None for unknown names, which
        made callers fail later with an opaque unpacking TypeError.)
    """
    tokenizer_config = {
        "vocab_file": './config/vocab.txt',
        "vocab_pair": './config/drug_codes_chembl.txt',
        "begin_id": '[CLS]',
        "separate_id": "[SEP]",
        "max_len": 256,
    }
    # All path wiring in one table instead of a repetitive if/elif chain.
    task_files = {
        'train': {"sps": './data/train/train_sps',
                  "smile": './data/train/train_smile',
                  "affinity": './data/train/train_ic50'},
        'test': {"sps": './data/test/test_sps',
                 "smile": './data/test/test_smile',
                 "affinity": './data/test/test_ic50'},
        'train_z_1': {"sps": './data/train_sps',
                      "smile": './data/train_smile',
                      "affinity": './data/train_z_1_ic50'},
        'train_z_10': {"sps": './data/train_sps',
                       "smile": './data/train_smile',
                       "affinity": './data/train_z_10_ic50'},
        'train_z_100': {"sps": './data/train_sps',
                        "smile": './data/train_smile',
                        "affinity": './data/train_z_100_ic50'},
    }
    key = task_name.lower()
    try:
        return task_files[key], tokenizer_config
    except KeyError:
        # Fail loudly with the valid options instead of returning None.
        raise ValueError(
            f"Unknown task name: {task_name!r}; "
            f"expected one of {sorted(task_files)}"
        ) from None
if __name__ == "__main__": if __name__ == "__main__":
# local test # local test
# dataFolder = './IC50/SPS/train_smile' # dataFolder = './IC50/SPS/train_smile'
......
from argparse import ArgumentParser from argparse import ArgumentParser
from dataset import Data_Encoder from dataset import Data_Encoder, get_task
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from configuration_bert import BertConfig from configuration_bert import BertConfig
...@@ -9,55 +9,6 @@ import os ...@@ -9,55 +9,6 @@ import os
from tqdm import tqdm from tqdm import tqdm
torch.set_default_tensor_type(torch.DoubleTensor) torch.set_default_tensor_type(torch.DoubleTensor)
def get_task(task_name):
    """Return the data-file paths and tokenizer config for a named task.

    Args:
        task_name: one of 'train', 'test', 'train_z_1', 'train_z_10',
            'train_z_100' (case-insensitive).

    Returns:
        (df, tokenizer_config): ``df`` maps 'sps'/'smile'/'affinity' to
        input file paths (flat layout directly under ./data/);
        ``tokenizer_config`` holds vocab files and tokenizer settings
        shared by every task.

    Raises:
        ValueError: if ``task_name`` is not a known task. (The original
        if/elif chain silently returned None for unknown names, which
        made callers fail later with an opaque unpacking TypeError.)
    """
    tokenizer_config = {
        "vocab_file": './config/vocab.txt',
        "vocab_pair": './config/drug_codes_chembl.txt',
        "begin_id": '[CLS]',
        "separate_id": "[SEP]",
        "max_len": 256,
    }
    # All path wiring in one table instead of a repetitive if/elif chain.
    task_files = {
        'train': {"sps": './data/train_sps',
                  "smile": './data/train_smile',
                  "affinity": './data/train_ic50'},
        'test': {"sps": './data/test_sps',
                 "smile": './data/test_smile',
                 "affinity": './data/test_ic50'},
        'train_z_1': {"sps": './data/train_sps',
                      "smile": './data/train_smile',
                      "affinity": './data/train_z_1_ic50'},
        'train_z_10': {"sps": './data/train_sps',
                       "smile": './data/train_smile',
                       "affinity": './data/train_z_10_ic50'},
        'train_z_100': {"sps": './data/train_sps',
                        "smile": './data/train_smile',
                        "affinity": './data/train_z_100_ic50'},
    }
    key = task_name.lower()
    try:
        return task_files[key], tokenizer_config
    except KeyError:
        # Fail loudly with the valid options instead of returning None.
        raise ValueError(
            f"Unknown task name: {task_name!r}; "
            f"expected one of {sorted(task_files)}"
        ) from None
...@@ -144,8 +95,8 @@ def test(args, model, dataset): ...@@ -144,8 +95,8 @@ def test(args, model, dataset):
for res in pred_affinity: for res in pred_affinity:
f.write(str(res) + '\n') f.write(str(res) + '\n')
if args.do_eval: # if args.do_eval:
os.system('python eval.py') # os.system('python eval.py')
def main(args): def main(args):
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased') model = BertForMaskedLM.from_pretrained('bert-base-uncased')
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt") inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] labels = tokenizer("The capital of France is Paris.", return_tensors="pt") #["input_ids"]
outputs = model(**inputs, labels=labels) outputs = model(**inputs, labels=labels)
loss = outputs.loss loss = outputs.loss
logits = outputs.logits logits = outputs.logits
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论