提交 0ed3bb0e 作者: 朱学凯

add trainer

上级 e37fe169
No preview for this file type
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.6 (py3.6)" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="Python 3.6 (code)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
</module> </module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (py3.6)" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (code)" project-jdk-type="Python SDK" />
</project> </project>
\ No newline at end of file
...@@ -2,27 +2,11 @@ ...@@ -2,27 +2,11 @@
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="7d3a4caa-6c7d-4ed2-9017-145cec64d9a3" name="Default Changelist" comment=""> <list default="true" id="7d3a4caa-6c7d-4ed2-9017-145cec64d9a3" name="Default Changelist" comment="">
<change afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/.gitignore" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/modules.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/vcs.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/activations.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/configuration_bert.py" beforeDir="false" afterPath="$PROJECT_DIR$/configuration_bert.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/configuration_utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/dataset.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/file_utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/generation_beam_search.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/generation_logits_process.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/generation_stopping_criteria.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/generation_utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_bert.py" beforeDir="false" afterPath="$PROJECT_DIR$/modeling_bert.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_outputs.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils/hf_api.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/utils/logging.py" beforeDir="false" />
</list> </list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />
...@@ -113,6 +97,7 @@ ...@@ -113,6 +97,7 @@
<option name="presentableId" value="Default" /> <option name="presentableId" value="Default" />
<updated>1616659651684</updated> <updated>1616659651684</updated>
<workItem from="1616659659765" duration="24696000" /> <workItem from="1616659659765" duration="24696000" />
<workItem from="1616983248339" duration="1434000" />
</task> </task>
<servers /> <servers />
</component> </component>
...@@ -129,20 +114,10 @@ ...@@ -129,20 +114,10 @@
</entry> </entry>
</map> </map>
</option> </option>
</component> <option name="oldMeFiltersMigrated" value="true" />
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/run_interaction.py</url>
<line>57</line>
<option name="timeStamp" value="16" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component> </component>
<component name="com.intellij.coverage.CoverageDataManagerImpl"> <component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1616929902995" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1616847821413" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/CPI$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1616847821413" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1616929902995" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component> </component>
</project> </project>
\ No newline at end of file
File mode changed from 100644 to 100755
...@@ -136,7 +136,7 @@ class Data_Encoder(data.Dataset): ...@@ -136,7 +136,7 @@ class Data_Encoder(data.Dataset):
# tokenization # tokenization
d = self.dbpe.process_line(self.smile[index].strip()).split() d = self.dbpe.process_line(self.smile[index].strip()).split()
p = self.sps[index].strip().split(',') p = self.sps[index].strip().split(',')
y = self.affinity[index].strip() y = np.float32(self.affinity[index].strip())
input_seq = [self.begin_id] + d + [self.sep_id] + p + [self.sep_id] input_seq = [self.begin_id] + d + [self.sep_id] + p + [self.sep_id]
input, input_mask = seq2emb_encoder(input_seq, self.max_len, self.vocab) input, input_mask = seq2emb_encoder(input_seq, self.max_len, self.vocab)
...@@ -172,7 +172,7 @@ if __name__ == "__main__": ...@@ -172,7 +172,7 @@ if __name__ == "__main__":
'shuffle': True, 'shuffle': True,
'num_workers': 0, 'num_workers': 0,
'drop_last': True} 'drop_last': True}
trainset = Data_Encoder(df_train, tokenizer_config) # trainset = Data_Encoder(df_train, tokenizer_config)
training_generator = data.DataLoader(trainset, **params) # training_generator = data.DataLoader(trainset, **params)
for i, (input, affinity) in tqdm(enumerate(training_generator)): # for i, (input, affinity) in tqdm(enumerate(training_generator)):
print(input.size()) # print(input.size())
File mode changed from 100644 to 100755
...@@ -4,7 +4,8 @@ import torch ...@@ -4,7 +4,8 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from configuration_bert import BertConfig from configuration_bert import BertConfig
from modeling_bert import BertAffinityModel from modeling_bert import BertAffinityModel
from torch.utils.tensorboard import SummaryWriter
import os
def get_task(task_name): def get_task(task_name):
if task_name.lower() == 'train': if task_name.lower() == 'train':
...@@ -40,26 +41,51 @@ def main(args): ...@@ -40,26 +41,51 @@ def main(args):
data_file, tokenizer_config = get_task(args.task) data_file, tokenizer_config = get_task(args.task)
dataset = Data_Encoder(data_file, tokenizer_config) dataset = Data_Encoder(data_file, tokenizer_config)
data_loder_para = {'batch_size': args.batch_size, data_loder_para = {'batch_size': args.batch_size,
'shuffle': True, 'shuffle': False,
'num_workers': args.workers, 'num_workers': args.workers,
'drop_last': True 'drop_last': True
} }
data_generator = DataLoader(dataset, **data_loder_para) data_generator = DataLoader(dataset, **data_loder_para)
# creat model # creat model
print('------------------creat model---------------------------')
config = BertConfig.from_pretrained(args.config) config = BertConfig.from_pretrained(args.config)
model = BertAffinityModel(config) model = BertAffinityModel(config)
print('model name : BertAffinity')
if args.task == 'train': if args.task == 'train':
opt = torch.optim.Adam(model.parameters(), lr=args.lr) opt = torch.optim.Adam(model.parameters(), lr=args.lr)
loss_fct = torch.nn.MSELoss() loss_fct = torch.nn.MSELoss()
writer = SummaryWriter('./log/' + args.log)
num_step = args.epochs * len(data_generator)
# detect GPU
if torch.cuda.is_available():
model.cuda(args.device)
print('begin training')
for epoch in range(args.epochs): for epoch in range(args.epochs):
for i, (input, affinity) in enumerate(data_generator): for i, (input, affinity) in enumerate(data_generator):
pred_affinity = model(input.long()) # use cuda
loss = loss_fct(pred_affinity, affinity) if torch.cuda.is_available():
input.long().cuda(args.device)
affinity.cuda(args.divice)
# input model
pred_affinity = model(input)
loss = loss_fct(pred_affinity, affinity.unsqueeze(-1))
writer.add_scalar('loss', loss, num_step)
# Update gradient
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
opt.step() opt.step()
print('------------------')
if (i % 100 == 0):
print('Training at Epoch ' + str(epoch + 1) + ' step ' + str(i) + ' with loss ' + str(
loss.cpu().detach().numpy()))
# save
if epoch > args.epoches/3:
save_path = './model/' + args.save + '/'
if not os.path.exists(save_path):
os.mkdir(save_path)
torch.save(model.state_dict(), save_path + 'epoch-{}-step-{}-loss-{}.pth'.format(epoch, i, loss))
...@@ -69,7 +95,7 @@ def main(args): ...@@ -69,7 +95,7 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
# get parameter # get parameter
parser = ArgumentParser(description='BertAffinity') parser = ArgumentParser(description='BertAffinity')
parser.add_argument('-b', '--batch-size', default=1, type=int, parser.add_argument('-b', '--batch-size', default=8, type=int,
metavar='N', metavar='N',
help='mini-batch size (default: 16), this is the total ' help='mini-batch size (default: 16), this is the total '
'batch size of all GPUs on the current node when ' 'batch size of all GPUs on the current node when '
...@@ -84,6 +110,9 @@ if __name__ == '__main__': ...@@ -84,6 +110,9 @@ if __name__ == '__main__':
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
metavar='LR', help='initial learning rate', dest='lr') metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--config', default='./config/config.json', type=str, help='model config file path') parser.add_argument('--config', default='./config/config.json', type=str, help='model config file path')
parser.add_argument('--log', default='training_log', type=str, help='training log')
parser.add_argument('--savedir', default='train', type=str, help='model save path')
parser.add_argument('--device', default='0', type=str, help='name of GPU')
args = parser.parse_args() args = parser.parse_args()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论