提交 77e5c7c0 作者: 朱学凯

fix step error

上级 d2a30458
......@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="jdk" jdkName="Python 3.6 (py3.6)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (code)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (py3.6)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<list default="true" id="b4fb7f33-5387-4628-bcdb-b1b79dd926d0" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
</list>
......@@ -84,6 +86,7 @@
</option>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1617888322915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1617283608264" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$draft.coverage" NAME="draft Coverage Results" MODIFIED="1617456765793" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
</project>
\ No newline at end of file
......@@ -41,7 +41,7 @@ def train(args, model, data_generator):
loss_fct = torch.nn.MSELoss()
writer = SummaryWriter('./log/' + args.savedir)
num_step = args.epochs * len(data_generator)
step = 0
# detect GPU
if torch.cuda.is_available():
model.cuda()
......@@ -56,7 +56,8 @@ def train(args, model, data_generator):
else:
pred_affinity = model(input.long())
loss = loss_fct(pred_affinity, affinity.unsqueeze(-1))
writer.add_scalar('loss', loss, global_step=num_step)
step += 1
writer.add_scalar('loss', loss, global_step=step)
# Update gradient
opt.zero_grad()
loss.backward()
......@@ -70,7 +71,7 @@ def train(args, model, data_generator):
save_path = './model/' + args.savedir + '/'
if not os.path.exists(save_path):
os.mkdir(save_path)
torch.save(model.state_dict(), save_path + 'epoch-{}-step-{}-loss-{}.pth'.format(epoch, i, loss))
torch.save(model.state_dict(), save_path + 'epoch-{}-step-{}-loss-{}.pth'.format(epoch, step, loss))
print('training over')
writer.close()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论