提交 d2a30458 作者: 朱学凯

adding predict

上级 dc50914b
...@@ -2,18 +2,7 @@ ...@@ -2,18 +2,7 @@
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment=""> <list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/modules.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/vcs.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/log/lr-5e-6-batch-30-layer3/events.out.tfevents.1617189040.b1393040f57d.421.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/log/train/events.out.tfevents.1617283406.DESKTOP-K4GGDLG.1656.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/log/train/events.out.tfevents.1617283575.DESKTOP-K4GGDLG.15080.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/log/train/events.out.tfevents.1617283616.DESKTOP-K4GGDLG.14676.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/log/training_log/events.out.tfevents.1617113424.DESKTOP-K4GGDLG.11896.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_bert.py" beforeDir="false" afterPath="$PROJECT_DIR$/modeling_bert.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
</list> </list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
...@@ -25,6 +14,9 @@ ...@@ -25,6 +14,9 @@
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" /> <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component> </component>
<component name="ProjectId" id="1qpu2Wq6VU5TQVQOm73pQEwAahA" /> <component name="ProjectId" id="1qpu2Wq6VU5TQVQOm73pQEwAahA" />
<component name="ProjectLevelVcsManager">
<ConfirmationsSetting value="1" id="Add" />
</component>
<component name="ProjectViewState"> <component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" /> <option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" /> <option name="showLibraryContents" value="true" />
...@@ -73,7 +65,7 @@ ...@@ -73,7 +65,7 @@
<option name="number" value="Default" /> <option name="number" value="Default" />
<option name="presentableId" value="Default" /> <option name="presentableId" value="Default" />
<updated>1617788646167</updated> <updated>1617788646167</updated>
<workItem from="1617788647548" duration="4985000" /> <workItem from="1617788647548" duration="5550000" />
</task> </task>
<servers /> <servers />
</component> </component>
......
CUDA_VISIBLE_DEVICES=1 python run_interaction.py --task=test --output=./predict/test --config=./config/config_layer_3.json --init=./model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth
\ No newline at end of file
...@@ -112,7 +112,7 @@ def main(args): ...@@ -112,7 +112,7 @@ def main(args):
if args.task == 'train': if args.task == 'train':
train(args, model, data_generator) train(args, model, data_generator)
elif args.task == 'test': elif args.task in ['test']:
test(args, model, data_generator) test(args, model, data_generator)
...@@ -148,11 +148,12 @@ if __name__ == '__main__': ...@@ -148,11 +148,12 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
args.task = 'test' # local test
args.init = './model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth' # args.task = 'test'
args.output = './predict/test' # args.init = './model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth'
args.config = './config/config_layer_3.json' # args.output = './predict/test'
# input dict # args.config = './config/config_layer_3.json'
main(args) main(args)
# load_config(args.config)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论