提交 257550c3 作者: root

add train.sh

上级 38686e31
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="b4fb7f33-5387-4628-bcdb-b1b79dd926d0" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="ProjectId" id="1qTnCnbIt5Qh7Sqj8SOdw3SKDsD" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent">
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="WebServerToolWindowFactoryState" value="false" />
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="b4fb7f33-5387-4628-bcdb-b1b79dd926d0" name="Default Changelist" comment="" />
<created>1617112323231</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1617112323231</updated>
<workItem from="1617112324809" duration="1883000" />
</task>
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="3" />
</component>
</project>
\ No newline at end of file
......@@ -59,28 +59,28 @@ def main(args):
# detect GPU
if torch.cuda.is_available():
model.cuda(args.device)
model.cuda()
print('begin training')
for epoch in range(args.epochs):
for i, (input, affinity) in enumerate(data_generator):
# use cuda
if torch.cuda.is_available():
input.cuda(args.device)
affinity.cuda(args.divice)
# if torch.cuda.is_available():
# input.cuda()
# affinity.cuda()
# input model
pred_affinity = model(input.long())
loss = loss_fct(pred_affinity, affinity.unsqueeze(-1))
pred_affinity = model(input.cuda().long())
loss = loss_fct(pred_affinity, affinity.cuda().unsqueeze(-1))
writer.add_scalar('loss', loss, num_step)
# Update gradient
opt.zero_grad()
loss.backward()
opt.step()
if (i % 100 == 0):
print('Training at Epoch ' + str(epoch + 1) + ' step ' + str(i) + ' with loss ' + str(
# if (i % 100 == 0):
print('Training at Epoch ' + str(epoch + 1) + ' step ' + str(i) + ' with loss ' + str(
loss.cpu().detach().numpy()))
# save
if epoch > args.epoches/3:
if epoch > 1:
save_path = './model/' + args.save + '/'
if not os.path.exists(save_path):
os.mkdir(save_path)
......@@ -112,7 +112,7 @@ if __name__ == '__main__':
parser.add_argument('--config', default='./config/config.json', type=str, help='model config file path')
parser.add_argument('--log', default='training_log', type=str, help='training log')
parser.add_argument('--savedir', default='train', type=str, help='model save path')
parser.add_argument('--device', default='0', type=str, help='name of GPU')
# parser.add_argument('--device', default='0', type=str, help='name of GPU')
args = parser.parse_args()
......
CUDA_VISIBLE_DEVICES=0 python run_interaction.py --b=30 --task=train --epochs=10 --lr=5e-5 --log=lr-5e-5-batch-30 --savedir=lr-5e-5-batch-30
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论