提交 25b01fae 作者: 朱学凯

add attention_mask

上级 94ca8a36
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.6 (py3.6)" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="Python 3.6 (code)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
</module> </module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (py3.6)" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (code)" project-jdk-type="Python SDK" />
</project> </project>
\ No newline at end of file
Index: .idea/workspace.xml
Index: .idea/workspace.xml
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.BaseRevisionTextPatchEP
<+><?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ChangeListManager\">\n <list default=\"true\" id=\"f877ac68-9cea-46d8-9125-207eebe5b5d6\" name=\"Default Changelist\" comment=\"\">\n <change beforePath=\"$PROJECT_DIR$/.idea/workspace.xml\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/.idea/workspace.xml\" afterDir=\"false\" />\n <change beforePath=\"$PROJECT_DIR$/eval.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/eval.py\" afterDir=\"false\" />\n <change beforePath=\"$PROJECT_DIR$/experment_result/loss.png\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/experment_result/learning_rate/loss.png\" afterDir=\"false\" />\n <change beforePath=\"$PROJECT_DIR$/experment_result/loss.svg\" beforeDir=\"false\" />\n </list>\n <option name=\"SHOW_DIALOG\" value=\"false\" />\n <option name=\"HIGHLIGHT_CONFLICTS\" value=\"true\" />\n <option name=\"HIGHLIGHT_NON_ACTIVE_CHANGELIST\" value=\"false\" />\n <option name=\"LAST_RESOLUTION\" value=\"IGNORE\" />\n </component>\n <component name=\"Git.Settings\">\n <option name=\"RECENT_GIT_ROOT_PATH\" value=\"$PROJECT_DIR$\" />\n <option name=\"UPDATE_TYPE\" value=\"REBASE\" />\n </component>\n <component name=\"ProjectId\" id=\"1qpu2Wq6VU5TQVQOm73pQEwAahA\" />\n <component name=\"ProjectLevelVcsManager\">\n <ConfirmationsSetting value=\"1\" id=\"Add\" />\n </component>\n <component name=\"ProjectViewState\">\n <option name=\"hideEmptyMiddlePackages\" value=\"true\" />\n <option name=\"showLibraryContents\" value=\"true\" />\n </component>\n <component name=\"PropertiesComponent\">\n <property name=\"ASKED_ADD_EXTERNAL_FILES\" value=\"true\" />\n <property name=\"RunOnceActivity.OpenProjectViewOnStart\" value=\"true\" />\n <property name=\"RunOnceActivity.ShowReadmeOnStart\" value=\"true\" />\n <property name=\"WebServerToolWindowFactoryState\" value=\"false\" />\n <property name=\"restartRequiresConfirmation\" value=\"false\" />\n </component>\n <component name=\"RecentsManager\">\n <key name=\"MoveFile.RECENT_KEYS\">\n <recent name=\"$PROJECT_DIR$/experment_result/learning_rate\" />\n </key>\n </component>\n <component name=\"RunManager\" selected=\"Python.eval\">\n <configuration name=\"eval\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\" nameIsGenerated=\"true\">\n <module name=\"CPI\" />\n <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n <option name=\"PARENT_ENVS\" value=\"true\" />\n <envs>\n <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n </envs>\n <option name=\"SDK_HOME\" value=\"\" />\n <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$\" />\n <option name=\"IS_MODULE_SDK\" value=\"true\" />\n <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/eval.py\" />\n <option name=\"PARAMETERS\" value=\"\" />\n <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n <option name=\"MODULE_MODE\" value=\"false\" />\n <option name=\"REDIRECT_INPUT\" value=\"false\" />\n <option name=\"INPUT_FILE\" value=\"\" />\n <method v=\"2\" />\n </configuration>\n <configuration name=\"run_interaction\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\" nameIsGenerated=\"true\">\n <module name=\"CPI\" />\n <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n <option name=\"PARENT_ENVS\" value=\"true\" />\n <envs>\n <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n </envs>\n <option name=\"SDK_HOME\" value=\"\" />\n <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$\" />\n <option name=\"IS_MODULE_SDK\" value=\"true\" />\n <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/run_interaction.py\" />\n <option name=\"PARAMETERS\" value=\"\" />\n <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n <option name=\"MODULE_MODE\" value=\"false\" />\n <option name=\"REDIRECT_INPUT\" value=\"false\" />\n <option name=\"INPUT_FILE\" value=\"\" />\n <method v=\"2\" />\n </configuration>\n <recent_temporary>\n <list>\n <item itemvalue=\"Python.eval\" />\n <item itemvalue=\"Python.run_interaction\" />\n </list>\n </recent_temporary>\n </component>\n <component name=\"SpellCheckerSettings\" RuntimeDictionaries=\"0\" Folders=\"0\" CustomDictionaries=\"0\" DefaultDictionary=\"application-level\" UseSingleDictionary=\"true\" transferred=\"true\" />\n <component name=\"TaskManager\">\n <task active=\"true\" id=\"Default\" summary=\"Default task\">\n <changelist id=\"f877ac68-9cea-46d8-9125-207eebe5b5d6\" name=\"Default Changelist\" comment=\"\" />\n <created>1617788646167</created>\n <option name=\"number\" value=\"Default\" />\n <option name=\"presentableId\" value=\"Default\" />\n <updated>1617788646167</updated>\n <workItem from=\"1617788647548\" duration=\"5550000\" />\n </task>\n <servers />\n </component>\n <component name=\"TypeScriptGeneratedFilesManager\">\n <option name=\"version\" value=\"3\" />\n </component>\n <component name=\"Vcs.Log.Tabs.Properties\">\n <option name=\"TAB_STATES\">\n <map>\n <entry key=\"MAIN\">\n <value>\n <State />\n </value>\n </entry>\n </map>\n </option>\n </component>\n <component name=\"XDebuggerManager\">\n <breakpoint-manager>\n <breakpoints>\n <line-breakpoint enabled=\"true\" suspend=\"THREAD\" type=\"python-line\">\n <url>file://$PROJECT_DIR$/eval.py</url>\n <line>36</line>\n <option name=\"timeStamp\" value=\"21\" />\n </line-breakpoint>\n </breakpoints>\n </breakpoint-manager>\n </component>\n <component name=\"com.intellij.coverage.CoverageDataManagerImpl\">\n <SUITE FILE_PATH=\"coverage/CPI$eval.coverage\" NAME=\"eval Coverage Results\" MODIFIED=\"1618396849549\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n <SUITE FILE_PATH=\"coverage/CPI$run_interaction.coverage\" NAME=\"run_interaction Coverage Results\" MODIFIED=\"1618133791228\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n </component>\n</project>
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
--- a/.idea/workspace.xml (revision 9ad087cc206a1aaa224fb400aceeebf955c664e1)
+++ b/.idea/workspace.xml (date 1618543228916)
@@ -3,9 +3,6 @@
<component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/experment_result/loss.png" beforeDir="false" afterPath="$PROJECT_DIR$/experment_result/learning_rate/loss.png" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/experment_result/loss.svg" beforeDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<changelist name="Uncommitted_changes_before_Update_at_2021_4_16,_11_22_上午_[Default_Changelist]" date="1618543353862" recycled="true" deleted="true">
<option name="PATH" value="$PROJECT_DIR$/.idea/shelf/Uncommitted_changes_before_Update_at_2021_4_16,_11_22_上午_[Default_Changelist]/shelved.patch" />
<option name="DESCRIPTION" value="Uncommitted changes before Update at 2021/4/16, 11:22 上午 [Default Changelist]" />
</changelist>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="d29948e3-1642-45ab-9fe2-087a876b83b3" name="Default Changelist" comment=""> <list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils/data_analyse_train.tsv" beforeDir="false" />
</list> </list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" /> <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" /> <option name="LAST_RESOLUTION" value="IGNORE" />
</component> </component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="Git.Settings"> <component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" /> <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
<option name="UPDATE_TYPE" value="REBASE" />
</component>
<component name="GitSEFilterConfiguration">
<file-type-list>
<filtered-out-file-type name="LOCAL_BRANCH" />
<filtered-out-file-type name="REMOTE_BRANCH" />
<filtered-out-file-type name="TAG" />
<filtered-out-file-type name="COMMIT_BY_MESSAGE" />
</file-type-list>
</component>
<component name="ProjectId" id="1qpu2Wq6VU5TQVQOm73pQEwAahA" />
<component name="ProjectLevelVcsManager">
<ConfirmationsSetting value="1" id="Add" />
</component> </component>
<component name="ProjectId" id="1rCop1rMOjMHNiaO2kfyMOb9F7W" />
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
<component name="ProjectViewState"> <component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" /> <option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" /> <option name="showLibraryContents" value="true" />
...@@ -23,12 +45,16 @@ ...@@ -23,12 +45,16 @@
<property name="ASKED_ADD_EXTERNAL_FILES" value="true" /> <property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" /> <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" /> <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
<property name="WebServerToolWindowFactoryState" value="false" /> <property name="WebServerToolWindowFactoryState" value="false" />
<property name="restartRequiresConfirmation" value="false" /> <property name="restartRequiresConfirmation" value="false" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component> </component>
<component name="RunManager" selected="Python.run_interaction"> <component name="RecentsManager">
<key name="MoveFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/utils" />
<recent name="$PROJECT_DIR$/experment_result/learning_rate" />
</key>
</component>
<component name="RunManager" selected="Python.test">
<configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> <configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" /> <module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" /> <option name="INTERPRETER_OPTIONS" value="" />
...@@ -51,6 +77,28 @@ ...@@ -51,6 +77,28 @@
<option name="INPUT_FILE" value="" /> <option name="INPUT_FILE" value="" />
<method v="2" /> <method v="2" />
</configuration> </configuration>
<configuration name="eval" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/eval.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="run_interaction" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> <configuration name="run_interaction" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" /> <module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" /> <option name="INTERPRETER_OPTIONS" value="" />
...@@ -73,22 +121,46 @@ ...@@ -73,22 +121,46 @@
<option name="INPUT_FILE" value="" /> <option name="INPUT_FILE" value="" />
<method v="2" /> <method v="2" />
</configuration> </configuration>
<configuration name="test" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/test.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<recent_temporary> <recent_temporary>
<list> <list>
<item itemvalue="Python.test" />
<item itemvalue="Python.run_interaction" /> <item itemvalue="Python.run_interaction" />
<item itemvalue="Python.dataset" /> <item itemvalue="Python.dataset" />
<item itemvalue="Python.eval" />
</list> </list>
</recent_temporary> </recent_temporary>
</component> </component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" /> <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager"> <component name="TaskManager">
<task active="true" id="Default" summary="Default task"> <task active="true" id="Default" summary="Default task">
<changelist id="d29948e3-1642-45ab-9fe2-087a876b83b3" name="Default Changelist" comment="" /> <changelist id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="" />
<created>1618489611823</created> <created>1617788646167</created>
<option name="number" value="Default" /> <option name="number" value="Default" />
<option name="presentableId" value="Default" /> <option name="presentableId" value="Default" />
<updated>1618489611823</updated> <updated>1617788646167</updated>
<workItem from="1618489616148" duration="6549000" /> <workItem from="1617788647548" duration="5550000" />
</task> </task>
<servers /> <servers />
</component> </component>
...@@ -106,8 +178,21 @@ ...@@ -106,8 +178,21 @@
</map> </map>
</option> </option>
</component> </component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/eval.py</url>
<line>36</line>
<option name="timeStamp" value="21" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl"> <component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618496487501" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/CPI$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1618641059668" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1618494392648" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" /> <SUITE FILE_PATH="coverage/CPI$test.coverage" NAME="test Coverage Results" MODIFIED="1618643206375" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$eval.coverage" NAME="eval Coverage Results" MODIFIED="1618396849549" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1618642769537" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component> </component>
</project> </project>
\ No newline at end of file
...@@ -142,7 +142,8 @@ class Data_Encoder(data.Dataset): ...@@ -142,7 +142,8 @@ class Data_Encoder(data.Dataset):
token_type_ids = np.concatenate((np.zeros((len(d) + 2), dtype=np.int), np.ones((len(p) + 1), dtype=np.int))) token_type_ids = np.concatenate((np.zeros((len(d) + 2), dtype=np.int), np.ones((len(p) + 1), dtype=np.int)))
token_type_ids = np.pad(token_type_ids, (0, self.max_len-len(input_seq)), 'constant', constant_values=0) token_type_ids = np.pad(token_type_ids, (0, self.max_len-len(input_seq)), 'constant', constant_values=0)
input, input_mask = seq2emb_encoder(input_seq, self.max_len, self.vocab) input, input_mask = seq2emb_encoder(input_seq, self.max_len, self.vocab)
return torch.from_numpy(input).long(), torch.from_numpy(token_type_ids).long(), y return torch.from_numpy(input).long(), torch.from_numpy(token_type_ids).long(), torch.from_numpy(input_mask).long(), y
# return len(d), len(p)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -162,18 +163,26 @@ if __name__ == "__main__": ...@@ -162,18 +163,26 @@ if __name__ == "__main__":
"smile": './data/train_smile', "smile": './data/train_smile',
"affinity": './data/train_ic50', "affinity": './data/train_ic50',
} }
df_test = {"sps": './data/test_sps',
"smile": './data/test_smile',
"affinity": './data/test_ic50',
}
tokenizer_config = {"vocab_file": './config/vocab.txt', tokenizer_config = {"vocab_file": './config/vocab.txt',
"vocab_pair": './config/drug_codes_chembl.txt', "vocab_pair": './config/drug_codes_chembl.txt',
"begin_id": '[CLS]', "begin_id": '[CLS]',
"separate_id": "[SEP]", "separate_id": "[SEP]",
"max_len": 256 "max_len": 256
} }
params = {'batch_size': 5, params = {'batch_size': 1,
'shuffle': True, 'shuffle': False,
'num_workers': 0, 'num_workers': 0
'drop_last': True} }
# trainset = Data_Encoder(df_train, tokenizer_config) # trainset = Data_Encoder(df_train, tokenizer_config)
# training_generator = data.DataLoader(trainset, **params) # training_generator = data.DataLoader(trainset, **params)
# for i, (input, affinity) in tqdm(enumerate(training_generator)): # with open('utils/train_data_analyse.csv', 'w', newline='') as f:
# print(input.size()) # csv_f = csv.writer(f)
# csv_f.writerow(['drup_len', 'protein_len'])
# for i, (len_d, len_p) in tqdm(enumerate(training_generator)):
# d = len_d.numpy()[0]
# p = len_p.numpy()[0]
# csv_f.writerow([str(d), str(p)])
...@@ -51,7 +51,7 @@ def train(args, model, dataset): ...@@ -51,7 +51,7 @@ def train(args, model, dataset):
writer = SummaryWriter('./log/' + args.savedir) writer = SummaryWriter('./log/' + args.savedir)
num_step = args.epochs * len(data_generator) num_step = args.epochs * len(data_generator)
step = 0 step = 0
save_step = num_step // 5 save_step = num_step // 10
# detect GPU # detect GPU
if torch.cuda.is_available(): if torch.cuda.is_available():
model.cuda() model.cuda()
...@@ -63,14 +63,14 @@ def train(args, model, dataset): ...@@ -63,14 +63,14 @@ def train(args, model, dataset):
print('begin training') print('begin training')
# training # training
for epoch in range(args.epochs): for epoch in range(args.epochs):
for i, (input, token_type_ids, affinity) in enumerate(data_generator): for i, (input, token_type_ids, input_mask, affinity) in enumerate(data_generator):
# use cuda # use cuda
# input model # input model
if torch.cuda.is_available(): if torch.cuda.is_available():
pred_affinity = model(input_ids=input.cuda(), token_type_ids=token_type_ids.cuda()) pred_affinity = model(input_ids=input.cuda(), token_type_ids=token_type_ids.cuda(), attention_mask=input_mask.cuda())
loss = loss_fct(pred_affinity, affinity.cuda().unsqueeze(-1)) loss = loss_fct(pred_affinity, affinity.cuda().unsqueeze(-1))
else: else:
pred_affinity = model(input_ids=input, token_type_ids=token_type_ids) pred_affinity = model(input_ids=input, token_type_ids=token_type_ids, attention_mask=input_mask)
loss = loss_fct(pred_affinity, affinity.unsqueeze(-1)) loss = loss_fct(pred_affinity, affinity.unsqueeze(-1))
step += 1 step += 1
writer.add_scalar('loss', loss, global_step=step) writer.add_scalar('loss', loss, global_step=step)
...@@ -109,12 +109,13 @@ def test(args, model, dataset): ...@@ -109,12 +109,13 @@ def test(args, model, dataset):
result = args.output + '/' + '{}.txt'.format(args.task) result = args.output + '/' + '{}.txt'.format(args.task)
print('begin predicting') print('begin predicting')
with open(result, 'w') as f: with open(result, 'w') as f:
for i, (input, affinity) in enumerate(tqdm(data_generator)): for i, (input, token_type_ids, input_mask, affinity) in enumerate(tqdm(data_generator)):
if torch.cuda.is_available(): if torch.cuda.is_available():
model.cuda() model.cuda()
pred_affinity = model(input.cuda().long()).detach().cpu().numpy() pred_affinity = model(input_ids=input.cuda(), token_type_ids=token_type_ids.cuda(),
attention_mask=input_mask.cuda())
else: else:
pred_affinity = model(input.long()).numpy() pred_affinity = model(input_ids=input, token_type_ids=token_type_ids, attention_mask=input_mask)
for res in range(args.batch_size): for res in range(args.batch_size):
f.write(str(pred_affinity[res, :][0]) + '\n') f.write(str(pred_affinity[res, :][0]) + '\n')
...@@ -176,7 +177,7 @@ if __name__ == '__main__': ...@@ -176,7 +177,7 @@ if __name__ == '__main__':
# local test # local test
args.task = 'train' args.task = 'train'
args.savedir='local_test_train' args.savedir = 'local_test_train'
args.epochs = 10 args.epochs = 10
args.lr = 1e-5 args.lr = 1e-5
args.config = './config/config_layer_3.json' args.config = './config/config_layer_3.json'
......
from transformers import BertTokenizer, BertModel
import torch
a = torch.Tensor([0])
if a > -1:
print('------')
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论