提交 9ad087cc 作者: 朱学凯

Merge branch 'main' of github.com:Xuekai-Zhu/CPI into main

# Conflicts:
#	.idea/workspace.xml
#	eval.py
......@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.6 (code)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.6 (py3.6)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (code)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (py3.6)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
......@@ -5,14 +5,14 @@
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"hidden_size": 384,
"initializer_range": 0.02,
"intermediate_size": 1536,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"max_position_embeddings": 384,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 3,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 23615
......
......@@ -9,7 +9,7 @@
"initializer_range": 0.02,
"intermediate_size": 1536,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"max_position_embeddings": 384,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 3,
......
......@@ -5,11 +5,11 @@
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"hidden_size": 384,
"initializer_range": 0.02,
"intermediate_size": 3072,
"intermediate_size": 1536,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"max_position_embeddings": 384,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 6,
......
......@@ -5,11 +5,11 @@
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"hidden_size": 384,
"initializer_range": 0.02,
"intermediate_size": 3072,
"intermediate_size": 1536,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"max_position_embeddings": 384,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 9,
......
RMSE : 1.4722575781012768 ; Pearson Correlation Coefficient : 0.2622760899895615
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -50,6 +50,7 @@ def train(args, model, dataset):
writer = SummaryWriter('./log/' + args.savedir)
num_step = args.epochs * len(data_generator)
step = 0
save_step = num_step // 5
# detect GPU
if torch.cuda.is_available():
model.cuda()
......@@ -57,7 +58,9 @@ def train(args, model, dataset):
print('epoch num : {}'.format(args.epochs))
print('step num : {}'.format(num_step))
print('batch size : {}'.format(args.batch_size))
print('learning rate : {}'.format(args.lr))
print('begin training')
# training
for epoch in range(args.epochs):
for i, (input, affinity) in enumerate(data_generator):
# use cuda
......@@ -79,7 +82,7 @@ def train(args, model, dataset):
print('Training at Epoch ' + str(epoch + 1) + ' step ' + str(step) + ' with loss ' + str(
loss.cpu().detach().numpy()))
# save
if epoch > 1 and epoch % 2 == 0 and i % 1200 == 0:
if epoch >= 1 and step % save_step == 0:
save_path = './model/' + args.savedir + '/'
if not os.path.exists(save_path):
os.mkdir(save_path)
......@@ -153,7 +156,7 @@ if __name__ == '__main__':
parser.add_argument('--task', choices=['train', 'test', 'channel', 'ER', 'GPCR', 'kinase'],
default='train', type=str, metavar='TASK',
help='Task name. Could be train, test, channel, ER, GPCR, kinase.')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float,
metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--config', default='./config/config.json', type=str, help='model config file path')
# parser.add_argument('--log', default='training_log', type=str, help='training log')
......@@ -184,4 +187,3 @@ if __name__ == '__main__':
# assert args.shuffle == False
main(args)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论