提交 c27bacec 作者: 朱学凯

fix output

上级 825c3d12
......@@ -2,13 +2,12 @@
<project version="4">
<component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_bert.py" beforeDir="false" afterPath="$PROJECT_DIR$/modeling_bert.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/pre_test.sh" beforeDir="false" afterPath="$PROJECT_DIR$/pre_test.sh" afterDir="false" />
<change beforePath="$PROJECT_DIR$/predict/lr-1e-6-batch-64-layer3/results.txt" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/predict/test/results.txt" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/predict/test/test.txt" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.sh" beforeDir="false" afterPath="$PROJECT_DIR$/train.sh" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......@@ -89,18 +88,7 @@
</map>
</option>
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/run_interaction.py</url>
<line>98</line>
<option name="timeStamp" value="13" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618127107728" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618133791228" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
</project>
\ No newline at end of file
CUDA_VISIBLE_DEVICES=1 python run_interaction.py --task=test --output=./predict/test --config=./config/config_layer_3.json --init=./model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth --shuffle=False
\ No newline at end of file
CUDA_VISIBLE_DEVICES=1 python run_interaction.py --task=test --output=./predict/lr-1e-6-batch-64-layer3-0411-2-9439_test_sh --config=./config/config_layer_3.json --init=./model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -36,7 +36,14 @@ def get_task(task_name):
return df_test, tokenizer_config
def train(args, model, data_generator):
def train(args, model, dataset):
data_loder_para = {'batch_size': args.batch_size,
'shuffle': True,
'num_workers': args.workers,
}
data_generator = DataLoader(dataset, **data_loder_para)
model.train()
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
loss_fct = torch.nn.MSELoss()
......@@ -80,9 +87,18 @@ def train(args, model, data_generator):
print('training over')
writer.close()
def test(args, model, data_generator):
def test(args, model, dataset):
data_loder_para = {'batch_size': args.batch_size,
'shuffle': False,
'num_workers': args.workers,
}
data_generator = DataLoader(dataset, **data_loder_para)
with torch.no_grad():
model.load_state_dict(torch.load(args.init, map_location=torch.device('cpu')))
if torch.cuda.is_available():
model.load_state_dict(torch.load(args.init), strict=True)
else:
model.load_state_dict(torch.load(args.init, map_location=torch.device('cpu')), strict=True)
model.eval()
if not os.path.exists(args.output):
os.mkdir(args.output)
......@@ -103,11 +119,6 @@ def main(args):
# load data
data_file, tokenizer_config = get_task(args.task)
dataset = Data_Encoder(data_file, tokenizer_config)
data_loder_para = {'batch_size': args.batch_size,
'shuffle': args.shuffle,
'num_workers': args.workers,
}
data_generator = DataLoader(dataset, **data_loder_para)
# creat model
print('------------------creat model---------------------------')
config = BertConfig.from_pretrained(args.config)
......@@ -116,10 +127,10 @@ def main(args):
print('task name : {}'.format(args.task))
if args.task == 'train':
train(args, model, data_generator)
train(args, model, dataset)
elif args.task in ['test']:
test(args, model, data_generator)
test(args, model, dataset)
......@@ -150,7 +161,7 @@ if __name__ == '__main__':
# parser.add_argument('--device', default='0', type=str, help='name of GPU')
parser.add_argument('--init', default='model', type=str, help='init checkpoint')
parser.add_argument('--output', default='predict', type=str, help='result save path')
parser.add_argument('--shuffle', default=True, type=str, help='shuffle data')
# parser.add_argument('--shuffle', default=True, type=str, help='shuffle data')
args = parser.parse_args()
......@@ -166,12 +177,11 @@ if __name__ == '__main__':
args.task = 'test'
args.init = './model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth'
args.output = './predict/test'
args.config = './config/config_layer_3.json'
args.shuffle = False
# args.task = 'test'
# assert args.init == './model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth'
# assert args.output == './predict/lr-1e-6-batch-64-layer3-0411-2-9439_test'
# args.config = './config/config_layer_3.json'
# assert args.shuffle == False
main(args)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论