Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
B
BiTransDPI
概览
概览
详情
活动
周期分析
版本库
存储库
文件
提交
分支
标签
贡献者
分支图
比较
统计图
问题
0
议题
0
列表
看板
标记
里程碑
CI / CD
CI / CD
流水线
日程表
维基
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
提交
问题看板
Open sidebar
杨志辉
BiTransDPI
Commits
825c3d12
提交
825c3d12
authored
4月 11, 2021
作者:
朱学凯
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix out
上级
d365ba55
显示空白字符变更
内嵌
并排
正在显示
14 个修改的文件
包含
86 行增加
和
34 行删除
+86
-34
CPI.iml
.idea/CPI.iml
+2
-1
misc.xml
.idea/misc.xml
+2
-1
shelved.patch
...at_2021_4_11,_11_03_上午_[Default_Changelist]/shelved.patch
+28
-0
Uncommitted_changes_before_Update_at_2021_4_11__11_03___Default_Changelist_.xml
...fore_Update_at_2021_4_11__11_03___Default_Changelist_.xml
+5
-0
workspace.xml
.idea/workspace.xml
+19
-17
modeling_bert.cpython-36.pyc
__pycache__/modeling_bert.cpython-36.pyc
+0
-0
events.out.tfevents.1618120716.b1393040f57d.2362.0
...3-0411/events.out.tfevents.1618120716.b1393040f57d.2362.0
+0
-0
events.out.tfevents.1618120794.b1393040f57d.2469.0
...3-0411/events.out.tfevents.1618120794.b1393040f57d.2469.0
+0
-0
epoch-4-step-2000-loss-2.9310436248779297.pth
...h-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth
+0
-0
epoch-2-step-9439-loss-2.063138484954834.pth
...-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth
+0
-0
modeling_bert.py
modeling_bert.py
+5
-6
test.txt
predict/test/test.txt
+0
-0
run_interaction.py
run_interaction.py
+23
-7
train.sh
train.sh
+2
-2
没有找到文件。
.idea/CPI.iml
浏览文件 @
825c3d12
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
<module
type=
"PYTHON_MODULE"
version=
"4"
>
<module
type=
"PYTHON_MODULE"
version=
"4"
>
<component
name=
"NewModuleRootManager"
>
<component
name=
"NewModuleRootManager"
>
<content
url=
"file://$MODULE_DIR$"
/>
<content
url=
"file://$MODULE_DIR$"
/>
<orderEntry
type=
"jdk"
jdkName=
"Python 3.6 (
py3.6
)"
jdkType=
"Python SDK"
/>
<orderEntry
type=
"jdk"
jdkName=
"Python 3.6 (
code
)"
jdkType=
"Python SDK"
/>
<orderEntry
type=
"sourceFolder"
forTests=
"false"
/>
<orderEntry
type=
"sourceFolder"
forTests=
"false"
/>
</component>
</component>
</module>
</module>
\ No newline at end of file
.idea/misc.xml
浏览文件 @
825c3d12
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<project
version=
"4"
>
<project
version=
"4"
>
<component
name=
"ProjectRootManager"
version=
"2"
project-jdk-name=
"Python 3.6 (
py3.6
)"
project-jdk-type=
"Python SDK"
/>
<component
name=
"ProjectRootManager"
version=
"2"
project-jdk-name=
"Python 3.6 (
code
)"
project-jdk-type=
"Python SDK"
/>
</project>
</project>
\ No newline at end of file
.idea/shelf/Uncommitted_changes_before_Update_at_2021_4_11,_11_03_上午_[Default_Changelist]/shelved.patch
0 → 100644
浏览文件 @
825c3d12
Index: .idea/workspace.xml
Index: .idea/workspace.xml
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.BaseRevisionTextPatchEP
<+><?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ChangeListManager\">\n <list default=\"true\" id=\"f877ac68-9cea-46d8-9125-207eebe5b5d6\" name=\"Default Changelist\" comment=\"\">\n <change beforePath=\"$PROJECT_DIR$/.idea/workspace.xml\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/.idea/workspace.xml\" afterDir=\"false\" />\n <change beforePath=\"$PROJECT_DIR$/run_interaction.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/run_interaction.py\" afterDir=\"false\" />\n </list>\n <option name=\"SHOW_DIALOG\" value=\"false\" />\n <option name=\"HIGHLIGHT_CONFLICTS\" value=\"true\" />\n <option name=\"HIGHLIGHT_NON_ACTIVE_CHANGELIST\" value=\"false\" />\n <option name=\"LAST_RESOLUTION\" value=\"IGNORE\" />\n </component>\n <component name=\"Git.Settings\">\n <option name=\"RECENT_GIT_ROOT_PATH\" value=\"$PROJECT_DIR$\" />\n </component>\n <component name=\"ProjectId\" id=\"1qpu2Wq6VU5TQVQOm73pQEwAahA\" />\n <component name=\"ProjectLevelVcsManager\">\n <ConfirmationsSetting value=\"1\" id=\"Add\" />\n </component>\n <component name=\"ProjectViewState\">\n <option name=\"hideEmptyMiddlePackages\" value=\"true\" />\n <option name=\"showLibraryContents\" value=\"true\" />\n </component>\n <component name=\"PropertiesComponent\">\n <property name=\"ASKED_ADD_EXTERNAL_FILES\" value=\"true\" />\n <property name=\"RunOnceActivity.OpenProjectViewOnStart\" value=\"true\" />\n <property name=\"RunOnceActivity.ShowReadmeOnStart\" value=\"true\" />\n <property name=\"WebServerToolWindowFactoryState\" value=\"false\" />\n <property name=\"restartRequiresConfirmation\" value=\"false\" />\n </component>\n <component name=\"RunManager\">\n <configuration name=\"run_interaction\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\" nameIsGenerated=\"true\">\n <module name=\"CPI\" />\n <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n <option name=\"PARENT_ENVS\" value=\"true\" />\n <envs>\n <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n </envs>\n <option name=\"SDK_HOME\" value=\"\" />\n <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$\" />\n <option name=\"IS_MODULE_SDK\" value=\"true\" />\n <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/run_interaction.py\" />\n <option name=\"PARAMETERS\" value=\"\" />\n <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n <option name=\"MODULE_MODE\" value=\"false\" />\n <option name=\"REDIRECT_INPUT\" value=\"false\" />\n <option name=\"INPUT_FILE\" value=\"\" />\n <method v=\"2\" />\n </configuration>\n <recent_temporary>\n <list>\n <item itemvalue=\"Python.run_interaction\" />\n </list>\n </recent_temporary>\n </component>\n <component name=\"SpellCheckerSettings\" RuntimeDictionaries=\"0\" Folders=\"0\" CustomDictionaries=\"0\" DefaultDictionary=\"application-level\" UseSingleDictionary=\"true\" transferred=\"true\" />\n <component name=\"TaskManager\">\n <task active=\"true\" id=\"Default\" summary=\"Default task\">\n <changelist id=\"f877ac68-9cea-46d8-9125-207eebe5b5d6\" name=\"Default Changelist\" comment=\"\" />\n <created>1617788646167</created>\n <option name=\"number\" value=\"Default\" />\n <option name=\"presentableId\" value=\"Default\" />\n <updated>1617788646167</updated>\n <workItem from=\"1617788647548\" duration=\"5550000\" />\n </task>\n <servers />\n </component>\n <component name=\"TypeScriptGeneratedFilesManager\">\n <option name=\"version\" value=\"3\" />\n </component>\n <component name=\"Vcs.Log.Tabs.Properties\">\n <option name=\"TAB_STATES\">\n <map>\n <entry key=\"MAIN\">\n <value>\n <State />\n </value>\n </entry>\n </map>\n </option>\n </component>\n <component name=\"com.intellij.coverage.CoverageDataManagerImpl\">\n <SUITE FILE_PATH=\"coverage/CPI$run_interaction.coverage\" NAME=\"run_interaction Coverage Results\" MODIFIED=\"1617888322915\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n </component>\n</project>
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
--- a/.idea/workspace.xml (revision d2a304581cb05de7f85d60774094ec940d9ff199)
+++ b/.idea/workspace.xml (date 1618051977139)
@@ -3,7 +3,6 @@
<component name="ChangeListManager">
<list default="true" id="f877ac68-9cea-46d8-9125-207eebe5b5d6" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -65,7 +64,7 @@
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1617788646167</updated>
- <workItem from="1617788647548" duration="5550000" />
+ <workItem from="1617788647548" duration="5869000" />
</task>
<servers />
</component>
.idea/shelf/Uncommitted_changes_before_Update_at_2021_4_11__11_03___Default_Changelist_.xml
0 → 100644
浏览文件 @
825c3d12
<changelist
name=
"Uncommitted_changes_before_Update_at_2021_4_11,_11_03_上午_[Default_Changelist]"
date=
"1618110223042"
recycled=
"true"
deleted=
"true"
>
<option
name=
"PATH"
value=
"$PROJECT_DIR$/.idea/shelf/Uncommitted_changes_before_Update_at_2021_4_11,_11_03_上午_[Default_Changelist]/shelved.patch"
/>
<option
name=
"DESCRIPTION"
value=
"Uncommitted changes before Update at 2021/4/11, 11:03 上午 [Default Changelist]"
/>
</changelist>
\ No newline at end of file
.idea/workspace.xml
浏览文件 @
825c3d12
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<project
version=
"4"
>
<project
version=
"4"
>
<component
name=
"ChangeListManager"
>
<component
name=
"ChangeListManager"
>
<list
default=
"true"
id=
"b4fb7f33-5387-4628-bcdb-b1b79dd926d0"
name=
"Default Changelist"
comment=
""
>
<list
default=
"true"
id=
"f877ac68-9cea-46d8-9125-207eebe5b5d6"
name=
"Default Changelist"
comment=
""
>
<change
beforePath=
"$PROJECT_DIR$/.idea/CPI.iml"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/.idea/CPI.iml"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/.idea/misc.xml"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/.idea/misc.xml"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/.idea/workspace.xml"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/.idea/workspace.xml"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/.idea/workspace.xml"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/.idea/workspace.xml"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/
pre_test.sh"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/pre_test.sh"
after
Dir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/
model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth"
before
Dir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/
predict/test/results.txt"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/predict/test/results.txt
"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/
modeling_bert.py"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/modeling_bert.py
"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/run_interaction.py"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/run_interaction.py"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/run_interaction.py"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/run_interaction.py"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/train.sh"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/train.sh"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/train.sh"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/train.sh"
afterDir=
"false"
/>
</list>
</list>
...
@@ -13,15 +15,9 @@
...
@@ -13,15 +15,9 @@
<option
name=
"HIGHLIGHT_NON_ACTIVE_CHANGELIST"
value=
"false"
/>
<option
name=
"HIGHLIGHT_NON_ACTIVE_CHANGELIST"
value=
"false"
/>
<option
name=
"LAST_RESOLUTION"
value=
"IGNORE"
/>
<option
name=
"LAST_RESOLUTION"
value=
"IGNORE"
/>
</component>
</component>
<component
name=
"FileTemplateManagerImpl"
>
<option
name=
"RECENT_TEMPLATES"
>
<list>
<option
value=
"Python Script"
/>
</list>
</option>
</component>
<component
name=
"Git.Settings"
>
<component
name=
"Git.Settings"
>
<option
name=
"RECENT_GIT_ROOT_PATH"
value=
"$PROJECT_DIR$"
/>
<option
name=
"RECENT_GIT_ROOT_PATH"
value=
"$PROJECT_DIR$"
/>
<option
name=
"UPDATE_TYPE"
value=
"REBASE"
/>
</component>
</component>
<component
name=
"ProjectId"
id=
"1qpu2Wq6VU5TQVQOm73pQEwAahA"
/>
<component
name=
"ProjectId"
id=
"1qpu2Wq6VU5TQVQOm73pQEwAahA"
/>
<component
name=
"ProjectLevelVcsManager"
>
<component
name=
"ProjectLevelVcsManager"
>
...
@@ -38,11 +34,6 @@
...
@@ -38,11 +34,6 @@
<property
name=
"WebServerToolWindowFactoryState"
value=
"false"
/>
<property
name=
"WebServerToolWindowFactoryState"
value=
"false"
/>
<property
name=
"restartRequiresConfirmation"
value=
"false"
/>
<property
name=
"restartRequiresConfirmation"
value=
"false"
/>
</component>
</component>
<component
name=
"RecentsManager"
>
<key
name=
"MoveFile.RECENT_KEYS"
>
<recent
name=
"E:\CPI\project\CPI\predict\lr-1e-6-batch-64-layer3"
/>
</key>
</component>
<component
name=
"RunManager"
>
<component
name=
"RunManager"
>
<configuration
name=
"run_interaction"
type=
"PythonConfigurationType"
factoryName=
"Python"
temporary=
"true"
nameIsGenerated=
"true"
>
<configuration
name=
"run_interaction"
type=
"PythonConfigurationType"
factoryName=
"Python"
temporary=
"true"
nameIsGenerated=
"true"
>
<module
name=
"CPI"
/>
<module
name=
"CPI"
/>
...
@@ -98,8 +89,18 @@
...
@@ -98,8 +89,18 @@
</map>
</map>
</option>
</option>
</component>
</component>
<component
name=
"XDebuggerManager"
>
<breakpoint-manager>
<breakpoints>
<line-breakpoint
enabled=
"true"
suspend=
"THREAD"
type=
"python-line"
>
<url>
file://$PROJECT_DIR$/run_interaction.py
</url>
<line>
98
</line>
<option
name=
"timeStamp"
value=
"13"
/>
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component
name=
"com.intellij.coverage.CoverageDataManagerImpl"
>
<component
name=
"com.intellij.coverage.CoverageDataManagerImpl"
>
<SUITE
FILE_PATH=
"coverage/CPI$run_interaction.coverage"
NAME=
"run_interaction Coverage Results"
MODIFIED=
"1618067424929"
SOURCE_PROVIDER=
"com.intellij.coverage.DefaultCoverageFileProvider"
RUNNER=
"coverage.py"
COVERAGE_BY_TEST_ENABLED=
"true"
COVERAGE_TRACING_ENABLED=
"false"
WORKING_DIRECTORY=
"$PROJECT_DIR$"
/>
<SUITE
FILE_PATH=
"coverage/CPI$run_interaction.coverage"
NAME=
"run_interaction Coverage Results"
MODIFIED=
"1618127107728"
SOURCE_PROVIDER=
"com.intellij.coverage.DefaultCoverageFileProvider"
RUNNER=
"coverage.py"
COVERAGE_BY_TEST_ENABLED=
"true"
COVERAGE_TRACING_ENABLED=
"false"
WORKING_DIRECTORY=
"$PROJECT_DIR$"
/>
<SUITE
FILE_PATH=
"coverage/CPI$draft.coverage"
NAME=
"draft Coverage Results"
MODIFIED=
"1617456765793"
SOURCE_PROVIDER=
"com.intellij.coverage.DefaultCoverageFileProvider"
RUNNER=
"coverage.py"
COVERAGE_BY_TEST_ENABLED=
"true"
COVERAGE_TRACING_ENABLED=
"false"
WORKING_DIRECTORY=
"$PROJECT_DIR$"
/>
</component>
</component>
</project>
</project>
\ No newline at end of file
__pycache__/modeling_bert.cpython-36.pyc
浏览文件 @
825c3d12
No preview for this file type
log/lr-1e-5-batch-64-layer3-0411/events.out.tfevents.1618120716.b1393040f57d.2362.0
0 → 100644
浏览文件 @
825c3d12
File added
log/lr-1e-6-batch-64-layer3-0411/events.out.tfevents.1618120794.b1393040f57d.2469.0
0 → 100644
浏览文件 @
825c3d12
File added
model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth
deleted
100644 → 0
浏览文件 @
d365ba55
File deleted
model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth
0 → 100644
浏览文件 @
825c3d12
File added
modeling_bert.py
浏览文件 @
825c3d12
...
@@ -1865,7 +1865,7 @@ class BertAffinityModel(BertPreTrainedModel):
...
@@ -1865,7 +1865,7 @@ class BertAffinityModel(BertPreTrainedModel):
self
.
embeddings
=
BertEmbeddings
(
config
)
self
.
embeddings
=
BertEmbeddings
(
config
)
self
.
encoder
=
BertEncoder
(
config
)
self
.
encoder
=
BertEncoder
(
config
)
self
.
mlp
=
Multilayer_perceptron
(
config
)
self
.
mlp
=
Multilayer_perceptron
(
config
)
self
.
pooler
=
BertPooler
(
config
)
if
add_pooling_layer
else
None
#
self.pooler = BertPooler(config) if add_pooling_layer else None
self
.
init_weights
()
self
.
init_weights
()
...
@@ -1988,7 +1988,6 @@ class BertAffinityModel(BertPreTrainedModel):
...
@@ -1988,7 +1988,6 @@ class BertAffinityModel(BertPreTrainedModel):
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
past_key_values_length
=
past_key_values_length
,
past_key_values_length
=
past_key_values_length
,
)
)
print
(
embedding_output
.
size
())
encoder_outputs
=
self
.
encoder
(
encoder_outputs
=
self
.
encoder
(
embedding_output
,
embedding_output
,
attention_mask
=
extended_attention_mask
,
attention_mask
=
extended_attention_mask
,
...
@@ -2002,11 +2001,11 @@ class BertAffinityModel(BertPreTrainedModel):
...
@@ -2002,11 +2001,11 @@ class BertAffinityModel(BertPreTrainedModel):
return_dict
=
return_dict
,
return_dict
=
return_dict
,
)
)
sequence_output
=
encoder_outputs
[
0
]
sequence_output
=
encoder_outputs
[
0
]
pooled_output
=
self
.
pooler
(
sequence_output
)
if
self
.
pooler
is
not
None
else
None
# pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if
not
return_dict
:
return
(
sequence_output
,
pooled_output
)
+
encoder_outputs
[
1
:]
# if not return_dict:
# return (sequence_output, pooled_output) + encoder_outputs[1:]
# print(sequence_output.size())
bert_pred
=
sequence_output
[:,
0
,:]
bert_pred
=
sequence_output
[:,
0
,:]
pred_affinity
=
self
.
mlp
.
forward
(
bert_pred
)
pred_affinity
=
self
.
mlp
.
forward
(
bert_pred
)
...
...
predict/test/test.txt
0 → 100644
浏览文件 @
825c3d12
run_interaction.py
浏览文件 @
825c3d12
...
@@ -37,6 +37,7 @@ def get_task(task_name):
...
@@ -37,6 +37,7 @@ def get_task(task_name):
return
df_test
,
tokenizer_config
return
df_test
,
tokenizer_config
def
train
(
args
,
model
,
data_generator
):
def
train
(
args
,
model
,
data_generator
):
model
.
train
()
opt
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
opt
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
loss_fct
=
torch
.
nn
.
MSELoss
()
loss_fct
=
torch
.
nn
.
MSELoss
()
writer
=
SummaryWriter
(
'./log/'
+
args
.
savedir
)
writer
=
SummaryWriter
(
'./log/'
+
args
.
savedir
)
...
@@ -45,6 +46,10 @@ def train(args, model, data_generator):
...
@@ -45,6 +46,10 @@ def train(args, model, data_generator):
# detect GPU
# detect GPU
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
model
.
cuda
()
model
.
cuda
()
# print(model)
print
(
'epoch num : {}'
.
format
(
args
.
epochs
))
print
(
'step num : {}'
.
format
(
num_step
))
print
(
'batch size : {}'
.
format
(
args
.
batch_size
))
print
(
'begin training'
)
print
(
'begin training'
)
for
epoch
in
range
(
args
.
epochs
):
for
epoch
in
range
(
args
.
epochs
):
for
i
,
(
input
,
affinity
)
in
enumerate
(
data_generator
):
for
i
,
(
input
,
affinity
)
in
enumerate
(
data_generator
):
...
@@ -64,10 +69,10 @@ def train(args, model, data_generator):
...
@@ -64,10 +69,10 @@ def train(args, model, data_generator):
opt
.
step
()
opt
.
step
()
# if (i % 100 == 0):
# if (i % 100 == 0):
print
(
'Training at Epoch '
+
str
(
epoch
+
1
)
+
' step '
+
str
(
i
)
+
' with loss '
+
str
(
print
(
'Training at Epoch '
+
str
(
epoch
+
1
)
+
' step '
+
str
(
step
)
+
' with loss '
+
str
(
loss
.
cpu
()
.
detach
()
.
numpy
()))
loss
.
cpu
()
.
detach
()
.
numpy
()))
# save
# save
if
epoch
>
1
:
if
epoch
>
1
and
epoch
%
2
==
0
and
i
%
1200
==
0
:
save_path
=
'./model/'
+
args
.
savedir
+
'/'
save_path
=
'./model/'
+
args
.
savedir
+
'/'
if
not
os
.
path
.
exists
(
save_path
):
if
not
os
.
path
.
exists
(
save_path
):
os
.
mkdir
(
save_path
)
os
.
mkdir
(
save_path
)
...
@@ -91,7 +96,7 @@ def test(args, model, data_generator):
...
@@ -91,7 +96,7 @@ def test(args, model, data_generator):
else
:
else
:
pred_affinity
=
model
(
input
.
long
())
.
numpy
()
pred_affinity
=
model
(
input
.
long
())
.
numpy
()
for
res
in
range
(
args
.
batch_size
):
for
res
in
range
(
args
.
batch_size
):
f
.
write
(
str
(
pred_affinity
[
res
,
:])
+
'
\n
'
)
f
.
write
(
str
(
pred_affinity
[
res
,
:]
[
0
]
)
+
'
\n
'
)
def
main
(
args
):
def
main
(
args
):
...
@@ -108,6 +113,7 @@ def main(args):
...
@@ -108,6 +113,7 @@ def main(args):
config
=
BertConfig
.
from_pretrained
(
args
.
config
)
config
=
BertConfig
.
from_pretrained
(
args
.
config
)
model
=
BertAffinityModel
(
config
)
model
=
BertAffinityModel
(
config
)
print
(
'model name : BertAffinity'
)
print
(
'model name : BertAffinity'
)
print
(
'task name : {}'
.
format
(
args
.
task
))
if
args
.
task
==
'train'
:
if
args
.
task
==
'train'
:
train
(
args
,
model
,
data_generator
)
train
(
args
,
model
,
data_generator
)
...
@@ -150,11 +156,21 @@ if __name__ == '__main__':
...
@@ -150,11 +156,21 @@ if __name__ == '__main__':
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# local test
# local test
# args.task = 'test'
# args.init = './model/lr-1e-5-batch-64-layer3/epoch-4-step-2000-loss-2.9310436248779297.pth'
# args.task = 'train'
# args.output = './predict/test'
# args.savedir='local_test_train'
# args.epochs = 10
# args.lr = 1e-7
# args.config = './config/config_layer_3.json'
# args.config = './config/config_layer_3.json'
# args.shuffle = False
# args.shuffle = True
args
.
task
=
'test'
args
.
init
=
'./model/lr-1e-6-batch-64-layer3-0411/epoch-2-step-9439-loss-2.063138484954834.pth'
args
.
output
=
'./predict/test'
args
.
config
=
'./config/config_layer_3.json'
args
.
shuffle
=
False
main
(
args
)
main
(
args
)
...
...
train.sh
浏览文件 @
825c3d12
CUDA_VISIBLE_DEVICES
=
1 python run_interaction.py
--b
=
64
--task
=
train
--epochs
=
10
--lr
=
1e-7
--savedir
=
lr-1e-7-batch-64-layer3
--config
=
./config/config_layer_3.json
--shuffle
=
True
CUDA_VISIBLE_DEVICES
=
1 python run_interaction.py
--b
=
64
--task
=
train
--epochs
=
5
--lr
=
1e-7
--savedir
=
lr-1e-7-batch-64-layer3
--config
=
./config/config_layer_3.json
--shuffle
=
True
\ No newline at end of file
\ No newline at end of file
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论