提交 77e5c7c0 作者: 朱学凯

fix step error

上级 d2a30458
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" /> <orderEntry type="jdk" jdkName="Python 3.6 (py3.6)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
</module> </module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (code)" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (py3.6)" project-jdk-type="Python SDK" />
</project> </project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment=""> <list default="true" id="b4fb7f33-5387-4628-bcdb-b1b79dd926d0" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
</list> </list>
...@@ -84,6 +86,7 @@ ...@@ -84,6 +86,7 @@
</option> </option>
</component> </component>
<component name="com.intellij.coverage.CoverageDataManagerImpl"> <component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1617888322915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1617283608264" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$draft.coverage" NAME="draft Coverage Results" MODIFIED="1617456765793" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component> </component>
</project> </project>
\ No newline at end of file
...@@ -41,7 +41,7 @@ def train(args, model, data_generator): ...@@ -41,7 +41,7 @@ def train(args, model, data_generator):
loss_fct = torch.nn.MSELoss() loss_fct = torch.nn.MSELoss()
writer = SummaryWriter('./log/' + args.savedir) writer = SummaryWriter('./log/' + args.savedir)
num_step = args.epochs * len(data_generator) num_step = args.epochs * len(data_generator)
step = 0
# detect GPU # detect GPU
if torch.cuda.is_available(): if torch.cuda.is_available():
model.cuda() model.cuda()
...@@ -56,7 +56,8 @@ def train(args, model, data_generator): ...@@ -56,7 +56,8 @@ def train(args, model, data_generator):
else: else:
pred_affinity = model(input.long()) pred_affinity = model(input.long())
loss = loss_fct(pred_affinity, affinity.unsqueeze(-1)) loss = loss_fct(pred_affinity, affinity.unsqueeze(-1))
writer.add_scalar('loss', loss, global_step=num_step) step += 1
writer.add_scalar('loss', loss, global_step=step)
# Update gradient # Update gradient
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
...@@ -70,7 +71,7 @@ def train(args, model, data_generator): ...@@ -70,7 +71,7 @@ def train(args, model, data_generator):
save_path = './model/' + args.savedir + '/' save_path = './model/' + args.savedir + '/'
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.mkdir(save_path) os.mkdir(save_path)
torch.save(model.state_dict(), save_path + 'epoch-{}-step-{}-loss-{}.pth'.format(epoch, i, loss)) torch.save(model.state_dict(), save_path + 'epoch-{}-step-{}-loss-{}.pth'.format(epoch, step, loss))
print('training over') print('training over')
writer.close() writer.close()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论