提交 4d4d4982 作者: 朱学凯

fix output

上级 6e471cf2
......@@ -2,12 +2,9 @@
<project version="4">
<component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.gitignore" beforeDir="false" afterPath="$PROJECT_DIR$/.gitignore" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils/data_analyse_train.tsv" beforeDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......@@ -54,7 +51,7 @@
<recent name="$PROJECT_DIR$/experment_result/learning_rate" />
</key>
</component>
<component name="RunManager" selected="Python.test">
<component name="RunManager" selected="Python.run_interaction">
<configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
......@@ -145,8 +142,8 @@
</configuration>
<recent_temporary>
<list>
<item itemvalue="Python.test" />
<item itemvalue="Python.run_interaction" />
<item itemvalue="Python.test" />
<item itemvalue="Python.dataset" />
<item itemvalue="Python.eval" />
</list>
......@@ -193,6 +190,6 @@
<SUITE FILE_PATH="coverage/CPI$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1618641059668" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$test.coverage" NAME="test Coverage Results" MODIFIED="1618643206375" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$eval.coverage" NAME="eval Coverage Results" MODIFIED="1618396849549" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618642769537" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618917155441" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
</project>
\ No newline at end of file
......@@ -116,8 +116,10 @@ def test(args, model, dataset):
attention_mask=input_mask.cuda())
else:
pred_affinity = model(input_ids=input, token_type_ids=token_type_ids, attention_mask=input_mask)
pred_affinity = pred_affinity.cpu().numpy()
for res in range(args.batch_size):
f.write(str(pred_affinity[res, :][0]) + '\n')
pred = pred_affinity[res, :][0]
f.write(str(pred) + '\n')
if args.do_eval:
os.system('python eval.py')
......@@ -184,10 +186,9 @@ if __name__ == '__main__':
# args.task = 'test'
# assert args.init == './model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth'
# assert args.output == './predict/lr-1e-6-batch-64-layer3-0411-2-9439_test'
# args.config = './config/config_layer_3.json'
# assert args.shuffle == False
args.task = 'test'
args.init = './model/lr-1e-5-batch-32-e-10-layer3-0417-add-type-ids-and-mask/epoch-9-step-82370-loss-0.8841055645024439.pth'
args.output = './predict/test'
args.config = './config/config_layer_3.json'
main(args)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论