提交 25b01fae 作者: 朱学凯

add attention_mask

上级 94ca8a36
......@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.6 (py3.6)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.6 (code)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (py3.6)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (code)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
Index: .idea/workspace.xml
Index: .idea/workspace.xml
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.BaseRevisionTextPatchEP
<+><?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ChangeListManager\">\n <list default=\"true\" id=\"f877ac68-9cea-46d8-9125-207eebe5b5d6\" name=\"Default Changelist\" comment=\"\">\n <change beforePath=\"$PROJECT_DIR$/.idea/workspace.xml\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/.idea/workspace.xml\" afterDir=\"false\" />\n <change beforePath=\"$PROJECT_DIR$/eval.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/eval.py\" afterDir=\"false\" />\n <change beforePath=\"$PROJECT_DIR$/experment_result/loss.png\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/experment_result/learning_rate/loss.png\" afterDir=\"false\" />\n <change beforePath=\"$PROJECT_DIR$/experment_result/loss.svg\" beforeDir=\"false\" />\n </list>\n <option name=\"SHOW_DIALOG\" value=\"false\" />\n <option name=\"HIGHLIGHT_CONFLICTS\" value=\"true\" />\n <option name=\"HIGHLIGHT_NON_ACTIVE_CHANGELIST\" value=\"false\" />\n <option name=\"LAST_RESOLUTION\" value=\"IGNORE\" />\n </component>\n <component name=\"Git.Settings\">\n <option name=\"RECENT_GIT_ROOT_PATH\" value=\"$PROJECT_DIR$\" />\n <option name=\"UPDATE_TYPE\" value=\"REBASE\" />\n </component>\n <component name=\"ProjectId\" id=\"1qpu2Wq6VU5TQVQOm73pQEwAahA\" />\n <component name=\"ProjectLevelVcsManager\">\n <ConfirmationsSetting value=\"1\" id=\"Add\" />\n </component>\n <component name=\"ProjectViewState\">\n <option name=\"hideEmptyMiddlePackages\" value=\"true\" />\n <option name=\"showLibraryContents\" value=\"true\" />\n </component>\n <component name=\"PropertiesComponent\">\n <property name=\"ASKED_ADD_EXTERNAL_FILES\" value=\"true\" />\n <property name=\"RunOnceActivity.OpenProjectViewOnStart\" value=\"true\" />\n <property name=\"RunOnceActivity.ShowReadmeOnStart\" value=\"true\" />\n <property name=\"WebServerToolWindowFactoryState\" value=\"false\" />\n <property name=\"restartRequiresConfirmation\" value=\"false\" />\n </component>\n <component name=\"RecentsManager\">\n <key name=\"MoveFile.RECENT_KEYS\">\n <recent name=\"$PROJECT_DIR$/experment_result/learning_rate\" />\n </key>\n </component>\n <component name=\"RunManager\" selected=\"Python.eval\">\n <configuration name=\"eval\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\" nameIsGenerated=\"true\">\n <module name=\"CPI\" />\n <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n <option name=\"PARENT_ENVS\" value=\"true\" />\n <envs>\n <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n </envs>\n <option name=\"SDK_HOME\" value=\"\" />\n <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$\" />\n <option name=\"IS_MODULE_SDK\" value=\"true\" />\n <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/eval.py\" />\n <option name=\"PARAMETERS\" value=\"\" />\n <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n <option name=\"MODULE_MODE\" value=\"false\" />\n <option name=\"REDIRECT_INPUT\" value=\"false\" />\n <option name=\"INPUT_FILE\" value=\"\" />\n <method v=\"2\" />\n </configuration>\n <configuration name=\"run_interaction\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\" nameIsGenerated=\"true\">\n <module name=\"CPI\" />\n <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n <option name=\"PARENT_ENVS\" value=\"true\" />\n <envs>\n <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n </envs>\n <option name=\"SDK_HOME\" value=\"\" />\n <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$\" />\n <option name=\"IS_MODULE_SDK\" value=\"true\" />\n <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/run_interaction.py\" />\n <option name=\"PARAMETERS\" value=\"\" />\n <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n <option name=\"MODULE_MODE\" value=\"false\" />\n <option name=\"REDIRECT_INPUT\" value=\"false\" />\n <option name=\"INPUT_FILE\" value=\"\" />\n <method v=\"2\" />\n </configuration>\n <recent_temporary>\n <list>\n <item itemvalue=\"Python.eval\" />\n <item itemvalue=\"Python.run_interaction\" />\n </list>\n </recent_temporary>\n </component>\n <component name=\"SpellCheckerSettings\" RuntimeDictionaries=\"0\" Folders=\"0\" CustomDictionaries=\"0\" DefaultDictionary=\"application-level\" UseSingleDictionary=\"true\" transferred=\"true\" />\n <component name=\"TaskManager\">\n <task active=\"true\" id=\"Default\" summary=\"Default task\">\n <changelist id=\"f877ac68-9cea-46d8-9125-207eebe5b5d6\" name=\"Default Changelist\" comment=\"\" />\n <created>1617788646167</created>\n <option name=\"number\" value=\"Default\" />\n <option name=\"presentableId\" value=\"Default\" />\n <updated>1617788646167</updated>\n <workItem from=\"1617788647548\" duration=\"5550000\" />\n </task>\n <servers />\n </component>\n <component name=\"TypeScriptGeneratedFilesManager\">\n <option name=\"version\" value=\"3\" />\n </component>\n <component name=\"Vcs.Log.Tabs.Properties\">\n <option name=\"TAB_STATES\">\n <map>\n <entry key=\"MAIN\">\n <value>\n <State />\n </value>\n </entry>\n </map>\n </option>\n </component>\n <component name=\"XDebuggerManager\">\n <breakpoint-manager>\n <breakpoints>\n <line-breakpoint enabled=\"true\" suspend=\"THREAD\" type=\"python-line\">\n <url>file://$PROJECT_DIR$/eval.py</url>\n <line>36</line>\n <option name=\"timeStamp\" value=\"21\" />\n </line-breakpoint>\n </breakpoints>\n </breakpoint-manager>\n </component>\n <component name=\"com.intellij.coverage.CoverageDataManagerImpl\">\n <SUITE FILE_PATH=\"coverage/CPI$eval.coverage\" NAME=\"eval Coverage Results\" MODIFIED=\"1618396849549\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n <SUITE FILE_PATH=\"coverage/CPI$run_interaction.coverage\" NAME=\"run_interaction Coverage Results\" MODIFIED=\"1618133791228\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n </component>\n</project>
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
--- a/.idea/workspace.xml (revision 9ad087cc206a1aaa224fb400aceeebf955c664e1)
+++ b/.idea/workspace.xml (date 1618543228916)
@@ -3,9 +3,6 @@
<component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/experment_result/loss.png" beforeDir="false" afterPath="$PROJECT_DIR$/experment_result/learning_rate/loss.png" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/experment_result/loss.svg" beforeDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<changelist name="Uncommitted_changes_before_Update_at_2021_4_16,_11_22_上午_[Default_Changelist]" date="1618543353862" recycled="true" deleted="true">
<option name="PATH" value="$PROJECT_DIR$/.idea/shelf/Uncommitted_changes_before_Update_at_2021_4_16,_11_22_上午_[Default_Changelist]/shelved.patch" />
<option name="DESCRIPTION" value="Uncommitted changes before Update at 2021/4/16, 11:22 上午 [Default Changelist]" />
</changelist>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="d29948e3-1642-45ab-9fe2-087a876b83b3" name="Default Changelist" comment="">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils/data_analyse_train.tsv" beforeDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
<option name="UPDATE_TYPE" value="REBASE" />
</component>
<component name="GitSEFilterConfiguration">
<file-type-list>
<filtered-out-file-type name="LOCAL_BRANCH" />
<filtered-out-file-type name="REMOTE_BRANCH" />
<filtered-out-file-type name="TAG" />
<filtered-out-file-type name="COMMIT_BY_MESSAGE" />
</file-type-list>
</component>
<component name="ProjectId" id="1qpu2Wq6VU5TQVQOm73pQEwAahA" />
<component name="ProjectLevelVcsManager">
<ConfirmationsSetting value="1" id="Add" />
</component>
<component name="ProjectId" id="1rCop1rMOjMHNiaO2kfyMOb9F7W" />
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
......@@ -23,12 +45,16 @@
<property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
<property name="WebServerToolWindowFactoryState" value="false" />
<property name="restartRequiresConfirmation" value="false" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RunManager" selected="Python.run_interaction">
<component name="RecentsManager">
<key name="MoveFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/utils" />
<recent name="$PROJECT_DIR$/experment_result/learning_rate" />
</key>
</component>
<component name="RunManager" selected="Python.test">
<configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
......@@ -51,6 +77,28 @@
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="eval" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/eval.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="run_interaction" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
......@@ -73,22 +121,46 @@
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="test" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/test.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<recent_temporary>
<list>
<item itemvalue="Python.test" />
<item itemvalue="Python.run_interaction" />
<item itemvalue="Python.dataset" />
<item itemvalue="Python.eval" />
</list>
</recent_temporary>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="d29948e3-1642-45ab-9fe2-087a876b83b3" name="Default Changelist" comment="" />
<created>1618489611823</created>
<changelist id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="" />
<created>1617788646167</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1618489611823</updated>
<workItem from="1618489616148" duration="6549000" />
<updated>1617788646167</updated>
<workItem from="1617788647548" duration="5550000" />
</task>
<servers />
</component>
......@@ -106,8 +178,21 @@
</map>
</option>
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/eval.py</url>
<line>36</line>
<option name="timeStamp" value="21" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618496487501" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1618494392648" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1618641059668" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$test.coverage" NAME="test Coverage Results" MODIFIED="1618643206375" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$eval.coverage" NAME="eval Coverage Results" MODIFIED="1618396849549" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618642769537" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
</project>
\ No newline at end of file
......@@ -142,7 +142,8 @@ class Data_Encoder(data.Dataset):
token_type_ids = np.concatenate((np.zeros((len(d) + 2), dtype=np.int), np.ones((len(p) + 1), dtype=np.int)))
token_type_ids = np.pad(token_type_ids, (0, self.max_len-len(input_seq)), 'constant', constant_values=0)
input, input_mask = seq2emb_encoder(input_seq, self.max_len, self.vocab)
return torch.from_numpy(input).long(), torch.from_numpy(token_type_ids).long(), y
return torch.from_numpy(input).long(), torch.from_numpy(token_type_ids).long(), torch.from_numpy(input_mask).long(), y
# return len(d), len(p)
if __name__ == "__main__":
......@@ -162,18 +163,26 @@ if __name__ == "__main__":
"smile": './data/train_smile',
"affinity": './data/train_ic50',
}
df_test = {"sps": './data/test_sps',
"smile": './data/test_smile',
"affinity": './data/test_ic50',
}
tokenizer_config = {"vocab_file": './config/vocab.txt',
"vocab_pair": './config/drug_codes_chembl.txt',
"begin_id": '[CLS]',
"separate_id": "[SEP]",
"max_len": 256
}
params = {'batch_size': 5,
'shuffle': True,
'num_workers': 0,
'drop_last': True}
params = {'batch_size': 1,
'shuffle': False,
'num_workers': 0
}
# trainset = Data_Encoder(df_train, tokenizer_config)
# training_generator = data.DataLoader(trainset, **params)
# for i, (input, affinity) in tqdm(enumerate(training_generator)):
# print(input.size())
# with open('utils/train_data_analyse.csv', 'w', newline='') as f:
# csv_f = csv.writer(f)
# csv_f.writerow(['drup_len', 'protein_len'])
# for i, (len_d, len_p) in tqdm(enumerate(training_generator)):
# d = len_d.numpy()[0]
# p = len_p.numpy()[0]
# csv_f.writerow([str(d), str(p)])
......@@ -51,7 +51,7 @@ def train(args, model, dataset):
writer = SummaryWriter('./log/' + args.savedir)
num_step = args.epochs * len(data_generator)
step = 0
save_step = num_step // 5
save_step = num_step // 10
# detect GPU
if torch.cuda.is_available():
model.cuda()
......@@ -63,14 +63,14 @@ def train(args, model, dataset):
print('begin training')
# training
for epoch in range(args.epochs):
for i, (input, token_type_ids, affinity) in enumerate(data_generator):
for i, (input, token_type_ids, input_mask, affinity) in enumerate(data_generator):
# use cuda
# input model
if torch.cuda.is_available():
pred_affinity = model(input_ids=input.cuda(), token_type_ids=token_type_ids.cuda())
pred_affinity = model(input_ids=input.cuda(), token_type_ids=token_type_ids.cuda(), attention_mask=input_mask.cuda())
loss = loss_fct(pred_affinity, affinity.cuda().unsqueeze(-1))
else:
pred_affinity = model(input_ids=input, token_type_ids=token_type_ids)
pred_affinity = model(input_ids=input, token_type_ids=token_type_ids, attention_mask=input_mask)
loss = loss_fct(pred_affinity, affinity.unsqueeze(-1))
step += 1
writer.add_scalar('loss', loss, global_step=step)
......@@ -109,12 +109,13 @@ def test(args, model, dataset):
result = args.output + '/' + '{}.txt'.format(args.task)
print('begin predicting')
with open(result, 'w') as f:
for i, (input, affinity) in enumerate(tqdm(data_generator)):
for i, (input, token_type_ids, input_mask, affinity) in enumerate(tqdm(data_generator)):
if torch.cuda.is_available():
model.cuda()
pred_affinity = model(input.cuda().long()).detach().cpu().numpy()
pred_affinity = model(input_ids=input.cuda(), token_type_ids=token_type_ids.cuda(),
attention_mask=input_mask.cuda())
else:
pred_affinity = model(input.long()).numpy()
pred_affinity = model(input_ids=input, token_type_ids=token_type_ids, attention_mask=input_mask)
for res in range(args.batch_size):
f.write(str(pred_affinity[res, :][0]) + '\n')
......@@ -176,7 +177,7 @@ if __name__ == '__main__':
# local test
args.task = 'train'
args.savedir='local_test_train'
args.savedir = 'local_test_train'
args.epochs = 10
args.lr = 1e-5
args.config = './config/config_layer_3.json'
......
from transformers import BertTokenizer, BertModel
import torch
a = torch.Tensor([0])
if a > -1:
print('------')
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论