提交 e92d1555 作者: 朱学凯

add auto eval

上级 9ad087cc
...@@ -2,10 +2,8 @@ ...@@ -2,10 +2,8 @@
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment=""> <list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/pre_test.sh" beforeDir="false" afterPath="$PROJECT_DIR$/pre_test.sh" afterDir="false" />
<change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/experment_result/loss.png" beforeDir="false" afterPath="$PROJECT_DIR$/experment_result/learning_rate/loss.png" afterDir="false" />
<change beforePath="$PROJECT_DIR$/experment_result/loss.svg" beforeDir="false" />
</list> </list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />
...@@ -36,7 +34,7 @@ ...@@ -36,7 +34,7 @@
<recent name="$PROJECT_DIR$/experment_result/learning_rate" /> <recent name="$PROJECT_DIR$/experment_result/learning_rate" />
</key> </key>
</component> </component>
<component name="RunManager" selected="Python.eval"> <component name="RunManager" selected="Python.run_interaction">
<configuration name="eval" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> <configuration name="eval" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" /> <module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" /> <option name="INTERPRETER_OPTIONS" value="" />
...@@ -83,8 +81,8 @@ ...@@ -83,8 +81,8 @@
</configuration> </configuration>
<recent_temporary> <recent_temporary>
<list> <list>
<item itemvalue="Python.eval" />
<item itemvalue="Python.run_interaction" /> <item itemvalue="Python.run_interaction" />
<item itemvalue="Python.eval" />
</list> </list>
</recent_temporary> </recent_temporary>
</component> </component>
...@@ -114,19 +112,9 @@ ...@@ -114,19 +112,9 @@
</map> </map>
</option> </option>
</component> </component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/eval.py</url>
<line>36</line>
<option name="timeStamp" value="21" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl"> <component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$eval.coverage" NAME="eval Coverage Results" MODIFIED="1618396849549" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618475585639" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618133791228" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/CPI$eval.coverage" NAME="eval Coverage Results" MODIFIED="1618326256952" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$draft.coverage" NAME="draft Coverage Results" MODIFIED="1617456765793" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component> </component>
</project> </project>
\ No newline at end of file
CUDA_VISIBLE_DEVICES=1 python run_interaction.py --task=test --output=./predict/lr-1e-6-batch-64-layer3-0411-2-9439_test_sh --config=./config/config_layer_3.json --init=./model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth CUDA_VISIBLE_DEVICES=1 python run_interaction.py --task=test --output=./predict/lr-1e-6-batch-64-layer3-0411-2-9439_test_sh --config=./config/config_layer_3.json --init=./model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth --do_eval=True
\ No newline at end of file \ No newline at end of file
...@@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter ...@@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
import os import os
from tqdm import tqdm from tqdm import tqdm
def get_task(task_name): def get_task(task_name):
if task_name.lower() == 'train': if task_name.lower() == 'train':
df_train = {"sps": './data/train_sps', df_train = {"sps": './data/train_sps',
...@@ -117,6 +118,9 @@ def test(args, model, dataset): ...@@ -117,6 +118,9 @@ def test(args, model, dataset):
for res in range(args.batch_size): for res in range(args.batch_size):
f.write(str(pred_affinity[res, :][0]) + '\n') f.write(str(pred_affinity[res, :][0]) + '\n')
if args.do_eval:
os.system('python eval.py')
def main(args): def main(args):
# load data # load data
...@@ -165,7 +169,7 @@ if __name__ == '__main__': ...@@ -165,7 +169,7 @@ if __name__ == '__main__':
parser.add_argument('--init', default='model', type=str, help='init checkpoint') parser.add_argument('--init', default='model', type=str, help='init checkpoint')
parser.add_argument('--output', default='predict', type=str, help='result save path') parser.add_argument('--output', default='predict', type=str, help='result save path')
# parser.add_argument('--shuffle', default=True, type=str, help='shuffle data') # parser.add_argument('--shuffle', default=True, type=str, help='shuffle data')
parser.add_argument('--do_eval', default=False, type=bool, help='do eval')
args = parser.parse_args() args = parser.parse_args()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论