提交 c27bacec 作者: 朱学凯

fix output

上级 825c3d12
...@@ -2,13 +2,12 @@ ...@@ -2,13 +2,12 @@
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment=""> <list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth" beforeDir="false" /> <change beforePath="$PROJECT_DIR$/pre_test.sh" beforeDir="false" afterPath="$PROJECT_DIR$/pre_test.sh" afterDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_bert.py" beforeDir="false" afterPath="$PROJECT_DIR$/modeling_bert.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/predict/lr-1e-6-batch-64-layer3/results.txt" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/predict/test/results.txt" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/predict/test/test.txt" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.sh" beforeDir="false" afterPath="$PROJECT_DIR$/train.sh" afterDir="false" />
</list> </list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />
...@@ -89,18 +88,7 @@ ...@@ -89,18 +88,7 @@
</map> </map>
</option> </option>
</component> </component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/run_interaction.py</url>
<line>98</line>
<option name="timeStamp" value="13" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl"> <component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618127107728" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618133791228" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component> </component>
</project> </project>
\ No newline at end of file
CUDA_VISIBLE_DEVICES=1 python run_interaction.py --task=test --output=./predict/test --config=./config/config_layer_3.json --init=./model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth --shuffle=False CUDA_VISIBLE_DEVICES=1 python run_interaction.py --task=test --output=./predict/lr-1e-6-batch-64-layer3-0411-2-9439_test_sh --config=./config/config_layer_3.json --init=./model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth
\ No newline at end of file \ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -36,7 +36,14 @@ def get_task(task_name): ...@@ -36,7 +36,14 @@ def get_task(task_name):
return df_test, tokenizer_config return df_test, tokenizer_config
def train(args, model, data_generator): def train(args, model, dataset):
data_loder_para = {'batch_size': args.batch_size,
'shuffle': True,
'num_workers': args.workers,
}
data_generator = DataLoader(dataset, **data_loder_para)
model.train() model.train()
opt = torch.optim.Adam(model.parameters(), lr=args.lr) opt = torch.optim.Adam(model.parameters(), lr=args.lr)
loss_fct = torch.nn.MSELoss() loss_fct = torch.nn.MSELoss()
...@@ -80,9 +87,18 @@ def train(args, model, data_generator): ...@@ -80,9 +87,18 @@ def train(args, model, data_generator):
print('training over') print('training over')
writer.close() writer.close()
def test(args, model, data_generator): def test(args, model, dataset):
data_loder_para = {'batch_size': args.batch_size,
'shuffle': False,
'num_workers': args.workers,
}
data_generator = DataLoader(dataset, **data_loder_para)
with torch.no_grad(): with torch.no_grad():
model.load_state_dict(torch.load(args.init, map_location=torch.device('cpu'))) if torch.cuda.is_available():
model.load_state_dict(torch.load(args.init), strict=True)
else:
model.load_state_dict(torch.load(args.init, map_location=torch.device('cpu')), strict=True)
model.eval() model.eval()
if not os.path.exists(args.output): if not os.path.exists(args.output):
os.mkdir(args.output) os.mkdir(args.output)
...@@ -103,11 +119,6 @@ def main(args): ...@@ -103,11 +119,6 @@ def main(args):
# load data # load data
data_file, tokenizer_config = get_task(args.task) data_file, tokenizer_config = get_task(args.task)
dataset = Data_Encoder(data_file, tokenizer_config) dataset = Data_Encoder(data_file, tokenizer_config)
data_loder_para = {'batch_size': args.batch_size,
'shuffle': args.shuffle,
'num_workers': args.workers,
}
data_generator = DataLoader(dataset, **data_loder_para)
# creat model # creat model
print('------------------creat model---------------------------') print('------------------creat model---------------------------')
config = BertConfig.from_pretrained(args.config) config = BertConfig.from_pretrained(args.config)
...@@ -116,10 +127,10 @@ def main(args): ...@@ -116,10 +127,10 @@ def main(args):
print('task name : {}'.format(args.task)) print('task name : {}'.format(args.task))
if args.task == 'train': if args.task == 'train':
train(args, model, data_generator) train(args, model, dataset)
elif args.task in ['test']: elif args.task in ['test']:
test(args, model, data_generator) test(args, model, dataset)
...@@ -150,7 +161,7 @@ if __name__ == '__main__': ...@@ -150,7 +161,7 @@ if __name__ == '__main__':
# parser.add_argument('--device', default='0', type=str, help='name of GPU') # parser.add_argument('--device', default='0', type=str, help='name of GPU')
parser.add_argument('--init', default='model', type=str, help='init checkpoint') parser.add_argument('--init', default='model', type=str, help='init checkpoint')
parser.add_argument('--output', default='predict', type=str, help='result save path') parser.add_argument('--output', default='predict', type=str, help='result save path')
parser.add_argument('--shuffle', default=True, type=str, help='shuffle data') # parser.add_argument('--shuffle', default=True, type=str, help='shuffle data')
args = parser.parse_args() args = parser.parse_args()
...@@ -166,12 +177,11 @@ if __name__ == '__main__': ...@@ -166,12 +177,11 @@ if __name__ == '__main__':
args.task = 'test' # args.task = 'test'
args.init = './model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth' # assert args.init == './model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth'
args.output = './predict/test' # assert args.output == './predict/lr-1e-6-batch-64-layer3-0411-2-9439_test'
args.config = './config/config_layer_3.json' # args.config = './config/config_layer_3.json'
args.shuffle = False # assert args.shuffle == False
main(args) main(args)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论