提交 dc50914b 作者: 朱学凯

adding predict

上级 ef8fe61a
No preview for this file type
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="1">
<item index="0" class="java.lang.String" itemvalue="subword-nmt" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (py3.6)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (code)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="b4fb7f33-5387-4628-bcdb-b1b79dd926d0" name="Default Changelist" comment="">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/modules.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/vcs.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/config/config.json" beforeDir="false" afterPath="$PROJECT_DIR$/config/config.json" afterDir="false" />
<change beforePath="$PROJECT_DIR$/config/config_layer_3.json" beforeDir="false" afterPath="$PROJECT_DIR$/config/config_layer_3.json" afterDir="false" />
<change beforePath="$PROJECT_DIR$/log/lr-5e-6-batch-30-layer3/events.out.tfevents.1617189040.b1393040f57d.421.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/log/train/events.out.tfevents.1617283406.DESKTOP-K4GGDLG.1656.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/log/train/events.out.tfevents.1617283575.DESKTOP-K4GGDLG.15080.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/log/train/events.out.tfevents.1617283616.DESKTOP-K4GGDLG.14676.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/log/training_log/events.out.tfevents.1617113424.DESKTOP-K4GGDLG.11896.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_bert.py" beforeDir="false" afterPath="$PROJECT_DIR$/modeling_bert.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.sh" beforeDir="false" afterPath="$PROJECT_DIR$/train.sh" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......@@ -17,24 +24,18 @@
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="ProjectId" id="1qTnCnbIt5Qh7Sqj8SOdw3SKDsD" />
<component name="ProjectId" id="1qpu2Wq6VU5TQVQOm73pQEwAahA" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent">
<property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="WebServerToolWindowFactoryState" value="false" />
<property name="restartRequiresConfirmation" value="false" />
</component>
<component name="RecentsManager">
<key name="MoveFile.RECENT_KEYS">
<recent name="E:\CPI\project\CPI" />
<recent name="E:\CPI\project\CPI\config" />
<recent name="E:\CPI\project\CPI\log\lr-5e-6-batch-30-layer3" />
</key>
</component>
<component name="RunManager">
<configuration name="run_interaction" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
......@@ -67,12 +68,12 @@
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="b4fb7f33-5387-4628-bcdb-b1b79dd926d0" name="Default Changelist" comment="" />
<created>1617112323231</created>
<changelist id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="" />
<created>1617788646167</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1617112323231</updated>
<workItem from="1617112324809" duration="10450000" />
<updated>1617788646167</updated>
<workItem from="1617788647548" duration="4985000" />
</task>
<servers />
</component>
......@@ -90,18 +91,7 @@
</map>
</option>
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/modeling_bert.py</url>
<line>1989</line>
<option name="timeStamp" value="1" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1617283608264" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1617888322915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
</project>
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1828,18 +1828,20 @@ class BertForQuestionAnswering(BertPreTrainedModel):
class Multilayer_perceptron(nn.Module):
def __init__(self, config):
super(Multilayer_perceptron, self).__init__()
self.layer_1 = nn.Linear(config.hidden_size, 512)
self.layer_1 = nn.Linear(config.hidden_size, 256)
self.layer_2 = nn.Linear(512, 256)
self.layer_3 = nn.Linear(256, 1)
self.drop_out = nn.Dropout(0.5)
def forward(self, bert_output):
x1 = self.layer_1(bert_output)
x1 = F.relu(x1, inplace=True)
x2 = self.layer_2(x1)
x2 = F.relu(x2, inplace=True)
x3 = self.layer_3(x2)
x1 = self.drop_out(x1)
# x2 = self.layer_2(x1)
# x2 = F.relu(x2, inplace=True)
# x2 = self.drop_out(x2)
x2 = self.layer_3(x1)
return x3
return x2
class BertAffinityModel(BertPreTrainedModel):
......
[6.3596783]
[6.149239]
[6.2786694]
[6.2630787]
[6.267824]
[6.3351264]
[6.330311]
[6.4089146]
[7.237949]
[7.134188]
[6.4724584]
[6.70895]
[6.5461855]
[6.462588]
[7.3434105]
[6.320741]
[5.8242393]
[6.1270766]
[5.861475]
[6.110052]
[6.1946096]
[6.8105106]
[6.7464366]
[7.4582233]
[7.283591]
[6.4545326]
[7.1453023]
[7.4970183]
[7.3376856]
[7.326717]
[7.376761]
[7.1191435]
[6.465303]
[6.118551]
[6.5188856]
[7.1967335]
[7.111934]
[6.4573274]
[6.8755574]
[7.310424]
[6.472396]
[6.7267003]
[7.2710166]
[6.549745]
[6.5536447]
[6.924736]
[6.8095484]
[6.9481206]
[6.5082817]
[7.0767627]
[6.4931135]
[6.120075]
[6.6630616]
[7.108821]
[6.150074]
[6.7352633]
[6.873269]
[6.0789833]
[6.507503]
[6.8817987]
[6.3812904]
[6.906433]
[7.3494034]
[6.3384047]
[6.792997]
[6.6977696]
[6.940779]
[6.7653384]
[7.098506]
[6.5834336]
[6.988887]
[6.3332806]
[6.938017]
[6.7161164]
[6.692145]
[6.300637]
[7.020443]
[6.6622524]
[6.6540895]
[6.9563365]
[6.433102]
[6.193932]
[7.1243067]
[7.1229043]
[7.304058]
[7.0475173]
[6.454037]
[6.411943]
[6.214428]
[7.13115]
[7.262042]
[7.285933]
[6.356842]
[6.159394]
[6.2058687]
[6.2058687]
[6.193981]
[6.194264]
[6.235496]
[6.16327]
[6.429044]
[6.2092743]
[6.0436616]
[6.37629]
[6.0988636]
[7.0545926]
[6.1575265]
[6.232292]
[6.150721]
[6.203047]
[6.5228686]
[6.3344517]
[6.530283]
[6.25474]
[6.2900543]
[6.118474]
[6.281304]
[6.458498]
[6.531927]
[6.2390556]
[6.172672]
[6.341542]
[6.535589]
[6.155416]
[6.6641674]
[6.2277637]
[6.0520573]
[6.3962502]
[6.3962502]
[6.400052]
[6.175363]
[6.260881]
[6.1761503]
[6.169209]
[5.889789]
[6.748949]
[5.9564304]
[6.8459992]
[6.5334473]
[6.563169]
[6.758628]
[6.911051]
[7.1392055]
[6.846547]
[6.8458405]
[6.6978965]
[6.676156]
[6.7089925]
[6.9050856]
[5.682077]
[7.0423]
[6.5804777]
[6.3241353]
[6.071189]
[6.713427]
[6.6073346]
[6.5932746]
[6.9655457]
[6.6886163]
[7.084936]
[6.5593066]
[6.8145075]
[6.85552]
[7.0440683]
[6.7756834]
[7.0484014]
[7.4163127]
[7.250523]
[7.3405814]
[6.773723]
[6.6779857]
[7.1147037]
[7.422548]
[7.427442]
[6.503265]
[6.878583]
[6.5071664]
[7.1635723]
[6.984735]
[6.5814033]
[6.760788]
[7.23371]
[7.00921]
[6.830633]
[6.8368325]
[6.34864]
[6.933655]
[6.9342484]
[6.777509]
[6.5005274]
[6.2883234]
[6.5519056]
[6.933951]
[6.889946]
[7.018728]
[6.8836727]
[7.0157747]
[7.101351]
[6.029176]
[6.831605]
[6.9702425]
[6.82664]
[7.1038437]
[6.9983196]
[6.8305345]
[7.2107124]
[6.966565]
[7.0524]
[6.8862734]
[6.886519]
[7.0750933]
[6.1289387]
[6.7051854]
[6.5207334]
[6.752949]
[6.787603]
[6.4389377]
[6.488281]
[6.9817696]
[6.8437653]
[6.8910346]
[5.7938113]
[5.7771935]
[5.551153]
[5.8725457]
[6.3399625]
[6.2708063]
[5.8033767]
[6.1500363]
[6.1604433]
[6.604599]
[6.724487]
[6.568345]
[6.245967]
[6.5355034]
[5.509695]
[5.5034876]
[6.145286]
[6.1130595]
[6.068443]
[6.20528]
[6.776703]
[6.3172307]
[6.323985]
[6.101276]
[5.768876]
[5.665632]
[6.5093927]
[6.829164]
[6.1964192]
[5.996513]
[6.125445]
[6.1920133]
[6.212432]
[6.369281]
[6.532722]
[6.590001]
[6.1853275]
[6.319353]
[6.4810433]
[6.9386272]
[6.738956]
[7.1271806]
[7.254016]
[5.917221]
[6.3355255]
[6.279427]
[6.154666]
[6.268687]
[6.750529]
[5.7007713]
[5.9609857]
[5.999047]
[6.296248]
[6.7443557]
[6.3195715]
[6.3289285]
[6.2859406]
[6.0315843]
[6.0488844]
[6.057394]
[6.6425843]
[6.3916855]
[6.3705444]
[6.681577]
[6.2595778]
[6.3593593]
[6.2299366]
[6.239003]
[6.167323]
[6.5075097]
[6.532083]
[6.2562227]
[6.264555]
[6.728636]
[6.1108727]
[6.780371]
[6.0572762]
[6.315681]
[7.579753]
[6.5053773]
[7.2419486]
[6.570504]
[6.680853]
[6.899315]
[6.899315]
[6.900144]
[6.9840407]
[5.6793914]
[5.9459867]
[6.2629614]
[6.0814476]
[6.3267517]
[5.539407]
[5.882191]
[5.494314]
[5.6211386]
[5.515706]
[5.801549]
[5.468592]
[6.0084925]
[5.7175364]
[5.6570005]
[6.2782784]
[6.7088523]
[6.7548356]
[6.747021]
[6.8559723]
[6.8976417]
[6.974851]
[6.8754845]
[6.8688793]
[7.139256]
[7.071559]
[6.897979]
[6.9007335]
[6.8505716]
[6.9822254]
[6.7001185]
[6.972347]
[6.8769293]
[6.6995554]
[6.8163023]
[6.8163023]
[6.6987977]
[6.873486]
[6.7295995]
[6.8747244]
[6.4349375]
[6.562923]
[6.523191]
[6.5495043]
[7.2503533]
[6.514708]
[6.4954786]
[6.378421]
[6.372544]
[6.3874154]
[6.3919272]
[6.5428963]
[6.489891]
[6.5428815]
[6.9015393]
[6.708762]
[6.455141]
[6.949461]
[6.9181933]
[6.879914]
[7.0192304]
[6.92454]
[6.8427377]
[6.7819786]
[6.9665337]
[6.792196]
[7.1155367]
[6.9954014]
[6.8167653]
[6.8305454]
[7.089679]
[6.7758846]
[7.0728493]
[6.8687367]
[6.978027]
[6.978027]
[5.4763994]
[5.7276473]
[6.090256]
[5.7736]
[5.5663285]
[5.706211]
[5.5991273]
[5.59329]
[5.6816583]
[5.688107]
[5.616148]
[5.522856]
[5.9210315]
[7.209167]
[7.187614]
[6.823409]
[6.660959]
[7.0327945]
[7.1511526]
[6.894468]
[6.9579277]
[6.410731]
[6.5899734]
[6.144295]
[5.8612866]
[6.352736]
[5.794497]
[6.6813993]
[6.5315557]
[6.703766]
[6.589502]
[6.5612297]
[6.662801]
[6.821087]
[6.9433713]
[6.918713]
[5.762656]
[5.915827]
[5.923086]
[6.1559224]
[5.7713246]
[6.0018964]
[6.3544703]
[6.1085567]
[6.103257]
[6.270781]
[6.1643667]
[6.1894827]
[6.0769196]
[6.0201955]
[6.1858096]
[6.3710337]
[6.31503]
[6.3787436]
[6.185884]
[6.3232837]
[6.1559224]
[6.6858497]
[5.6684117]
[5.985989]
[5.969078]
[6.208828]
[5.7348866]
[6.146168]
[5.8905745]
[6.2671876]
[7.0127273]
[6.800625]
[6.8149886]
[6.5758567]
[6.7115874]
[7.129005]
[6.6799273]
[6.7200594]
[6.801302]
[7.128217]
[6.541928]
[6.8825183]
[7.0010886]
[7.2036786]
[7.087343]
[6.231343]
[6.190108]
[6.439097]
[6.2274537]
[5.945822]
[6.5451407]
[6.215738]
[6.0463524]
[5.8921466]
[6.2184496]
[6.142793]
[5.9534755]
[5.9359536]
[6.0487404]
[6.6750526]
[5.58417]
[6.747651]
[6.653012]
[6.575788]
[6.3520694]
[6.954247]
[6.774977]
[7.261607]
[6.6315556]
[6.5408077]
[6.5227594]
[6.658483]
[6.4391403]
[6.637759]
[6.7642226]
[6.919209]
......@@ -6,6 +6,7 @@ from configuration_bert import BertConfig
from modeling_bert import BertAffinityModel
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm import tqdm
def get_task(task_name):
if task_name.lower() == 'train':
......@@ -35,23 +36,7 @@ def get_task(task_name):
return df_test, tokenizer_config
def main(args):
# load data
data_file, tokenizer_config = get_task(args.task)
dataset = Data_Encoder(data_file, tokenizer_config)
data_loder_para = {'batch_size': args.batch_size,
'shuffle': False,
'num_workers': args.workers,
'drop_last': True
}
data_generator = DataLoader(dataset, **data_loder_para)
# creat model
print('------------------creat model---------------------------')
config = BertConfig.from_pretrained(args.config)
model = BertAffinityModel(config)
print('model name : BertAffinity')
if args.task == 'train':
def train(args, model, data_generator):
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
loss_fct = torch.nn.MSELoss()
writer = SummaryWriter('./log/' + args.savedir)
......@@ -64,9 +49,6 @@ def main(args):
for epoch in range(args.epochs):
for i, (input, affinity) in enumerate(data_generator):
# use cuda
# if torch.cuda.is_available():
# input.cuda()
# affinity.cuda()
# input model
if torch.cuda.is_available():
pred_affinity = model(input.cuda().long())
......@@ -74,13 +56,13 @@ def main(args):
else:
pred_affinity = model(input.long())
loss = loss_fct(pred_affinity, affinity.unsqueeze(-1))
writer.add_scalar('loss', loss, num_step)
writer.add_scalar('loss', loss, global_step=num_step)
# Update gradient
opt.zero_grad()
loss.backward()
opt.step()
# if (i % 100 == 0):
# if (i % 100 == 0):
print('Training at Epoch ' + str(epoch + 1) + ' step ' + str(i) + ' with loss ' + str(
loss.cpu().detach().numpy()))
# save
......@@ -89,6 +71,49 @@ def main(args):
if not os.path.exists(save_path):
os.mkdir(save_path)
torch.save(model.state_dict(), save_path + 'epoch-{}-step-{}-loss-{}.pth'.format(epoch, i, loss))
print('training over')
writer.close()
def test(args, model, data_generator):
with torch.no_grad():
model.load_state_dict(torch.load(args.init, map_location=torch.device('cpu')))
model.eval()
if not os.path.exists(args.output):
os.mkdir(args.output)
result = args.output + '/' + 'results.txt'
print('begin predicting')
with open(result, 'w') as f:
for i, (input, affinity) in enumerate(tqdm(data_generator)):
if torch.cuda.is_available():
model.cuda()
pred_affinity = model(input.cuda().long()).detach().cpu().numpy()
else:
pred_affinity = model(input.long()).numpy()
for res in range(args.batch_size):
f.write(str(pred_affinity[res, :]) + '\n')
def main(args):
# load data
data_file, tokenizer_config = get_task(args.task)
dataset = Data_Encoder(data_file, tokenizer_config)
data_loder_para = {'batch_size': args.batch_size,
'shuffle': False,
'num_workers': args.workers,
'drop_last': True
}
data_generator = DataLoader(dataset, **data_loder_para)
# creat model
print('------------------creat model---------------------------')
config = BertConfig.from_pretrained(args.config)
model = BertAffinityModel(config)
print('model name : BertAffinity')
if args.task == 'train':
train(args, model, data_generator)
elif args.task == 'test':
test(args, model, data_generator)
......@@ -114,12 +139,19 @@ if __name__ == '__main__':
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--config', default='./config/config.json', type=str, help='model config file path')
# parser.add_argument('--log', default='training_log', type=str, help='training log')
# parser.add_argument('--log', default='training_log', type=str, help='training log')
parser.add_argument('--savedir', default='train', type=str, help='log and model save path')
# parser.add_argument('--device', default='0', type=str, help='name of GPU')
args = parser.parse_args()
parser.add_argument('--init', default='model', type=str, help='init checkpoint')
parser.add_argument('--output', default='predict', type=str, help='result save path')
args = parser.parse_args()
args.task = 'test'
args.init = './model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth'
args.output = './predict/test'
args.config = './config/config_layer_3.json'
# input dict
main(args)
# load_config(args.config)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论