提交 e37fe169 作者: 朱学凯

add trainer

上级 8f39b5be
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" /> <orderEntry type="jdk" jdkName="Python 3.6 (py3.6)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
</module> </module>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="1">
<item index="0" class="java.lang.String" itemvalue="subword-nmt" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (code)" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (py3.6)" project-jdk-type="Python SDK" />
</project> </project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ProjectModuleManager"> <component name="ProjectModuleManager">
<modules> <modules>
<module fileurl="file://$PROJECT_DIR$/.idea/CPI.iml" filepath="$PROJECT_DIR$/.idea/CPI.iml" /> <module fileurl="file://$PROJECT_DIR$/.idea/CPI.iml" filepath="$PROJECT_DIR$/.idea/CPI.iml" />
</modules> </modules>
</component> </component>
</project> </project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="VcsDirectoryMappings"> <component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" /> <mapping directory="$PROJECT_DIR$" vcs="Git" />
</component> </component>
</project> </project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="7d3a4caa-6c7d-4ed2-9017-145cec64d9a3" name="Default Changelist" comment="">
<change afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/.gitignore" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/CPI.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CPI.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/modules.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/vcs.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/activations.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/configuration_bert.py" beforeDir="false" afterPath="$PROJECT_DIR$/configuration_bert.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/configuration_utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/file_utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/generation_beam_search.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/generation_logits_process.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/generation_stopping_criteria.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/generation_utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_bert.py" beforeDir="false" afterPath="$PROJECT_DIR$/modeling_bert.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_outputs.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/modeling_utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/run_interaction.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_interaction.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils/hf_api.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/utils/logging.py" beforeDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="ProjectId" id="1qEzgr8swJnGyU4GzuvJCGLLPSb" />
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent">
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
<property name="WebServerToolWindowFactoryState" value="false" />
<property name="last_opened_file_path" value="$PROJECT_DIR$" />
<property name="restartRequiresConfirmation" value="false" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RecentsManager">
<key name="CopyFile.RECENT_KEYS">
<recent name="E:\CPI\project\CPI" />
</key>
</component>
<component name="RunManager" selected="Python.run_interaction">
<configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/dataset.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="run_interaction" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="CPI" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/run_interaction.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<recent_temporary>
<list>
<item itemvalue="Python.run_interaction" />
<item itemvalue="Python.dataset" />
</list>
</recent_temporary>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="7d3a4caa-6c7d-4ed2-9017-145cec64d9a3" name="Default Changelist" comment="" />
<created>1616659651684</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1616659651684</updated>
<workItem from="1616659659765" duration="24696000" />
</task>
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="3" />
</component>
<component name="Vcs.Log.Tabs.Properties">
<option name="TAB_STATES">
<map>
<entry key="MAIN">
<value>
<State />
</value>
</entry>
</map>
</option>
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/run_interaction.py</url>
<line>57</line>
<option name="timeStamp" value="16" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/CPI$run_interaction.coverage" NAME="run_interaction Coverage Results" MODIFIED="1616929902995" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CPI$dataset.coverage" NAME="dataset Coverage Results" MODIFIED="1616847821413" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
</project>
\ No newline at end of file
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn.functional as F
from packaging import version
from utils import logging
logger = logging.get_logger(__name__)
def _gelu_python(x):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in
torch.nn.functional Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if version.parse(torch.__version__) < version.parse("1.4"):
gelu = _gelu_python
else:
gelu = F.gelu
def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
def _silu_python(x):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
return x * torch.sigmoid(x)
if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python
else:
silu = F.silu
def mish(x):
return x * torch.tanh(torch.nn.functional.softplus(x))
def linear_act(x):
return x
ACT2FN = {
"relu": F.relu,
"silu": silu,
"swish": silu,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"mish": mish,
"linear": linear_act,
"sigmoid": torch.sigmoid,
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
...@@ -15,11 +15,7 @@ ...@@ -15,11 +15,7 @@
# limitations under the License. # limitations under the License.
""" BERT model configuration """ """ BERT model configuration """
from configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from file_utils import logging
logger = logging.get_logger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json", "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
...@@ -44,6 +40,7 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -44,6 +40,7 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json", "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json", "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json",
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json", "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
"BertAffinity": "./config/config.json"
# See all BERT models at https://huggingface.co/models?filter=bert # See all BERT models at https://huggingface.co/models?filter=bert
} }
......
...@@ -173,6 +173,6 @@ if __name__ == "__main__": ...@@ -173,6 +173,6 @@ if __name__ == "__main__":
'num_workers': 0, 'num_workers': 0,
'drop_last': True} 'drop_last': True}
trainset = Data_Encoder(df_train, tokenizer_config) trainset = Data_Encoder(df_train, tokenizer_config)
training_generator = data.DataLoader(trainset) training_generator = data.DataLoader(trainset, **params)
for i, (input, affinity) in tqdm(enumerate(training_generator)): for i, (input, affinity) in tqdm(enumerate(training_generator)):
print('') print(input.size())
import time
import warnings
from abc import ABC
from typing import Optional
import torch
from file_utils import add_start_docstrings
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
kwargs:
Additional stopping critera specific kwargs.
Return:
:obj:`bool`. :obj:`False` indicates we should continue, :obj:`True` indicates we should stop.
"""
class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
raise NotImplementedError("StoppingCriteria needs to be subclassed")
class MaxLengthCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generated number of tokens exceeds :obj:`max_length`.
Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens.
Args:
max_length (:obj:`int`):
The maximum length that the output sequence can have in number of tokens.
"""
def __init__(self, max_length: int):
self.max_length = max_length
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids.shape[-1] > self.max_length
class MaxTimeCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
time will start being counted when you initialize this function. You can override this by passing an
:obj:`initial_time`.
Args:
max_time (:obj:`float`):
The maximum allowed time in seconds for the generation.
initial_time (:obj:`float`, `optional`, defaults to :obj:`time.time()`):
The start of the generation allowed time.
"""
def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
self.max_time = max_time
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return time.time() - self.initial_timestamp > self.max_time
class StoppingCriteriaList(list):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores) for criteria in self)
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
found = False
for stopping_criterium in stopping_criteria:
if isinstance(stopping_criterium, MaxLengthCriteria):
found = True
if stopping_criterium.max_length != max_length:
warnings.warn(
"You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
)
if not found:
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
from argparse import ArgumentParser from argparse import ArgumentParser
from dataset import Data_Encoder from dataset import Data_Encoder
import torch import torch
from transformers import BertModel, BertTokenizer, BertLayer from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel from configuration_bert import BertConfig
# from modeling_bert import BertModel from modeling_bert import BertAffinityModel
parser = ArgumentParser(description='BertAff Training.')
parser.add_argument('-b', '--batch-size', default=16, type=int,
metavar='N',
help='mini-batch size (default: 16), this is the total '
'batch size of all GPUs on the current node when '
'using Data Parallel or Distributed Data Parallel')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
help='number of data loading workers (default: 0)')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--task', choices=['train', 'test', 'channel', 'ER', 'GPCR', 'kinase'],
default='', type=str, metavar='TASK',
help='Task name. Could be train, test, channel, ER, GPCR, kinase.')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
metavar='LR', help='initial learning rate', dest='lr')
def get_task(task_name): def get_task(task_name):
...@@ -39,11 +19,7 @@ def get_task(task_name): ...@@ -39,11 +19,7 @@ def get_task(task_name):
"separate_id": "[SEP]", "separate_id": "[SEP]",
"max_len": 256 "max_len": 256
} }
params = {'batch_size': 5, return df_train, tokenizer_config
'shuffle': True,
'num_workers': 0,
'drop_last': True}
return df_train, tokenizer_config, params
elif task_name.lower() == 'test': elif task_name.lower() == 'test':
df_test = {"sps": './data/test_sps', df_test = {"sps": './data/test_sps',
"smile": './data/test_smile', "smile": './data/test_smile',
...@@ -55,25 +31,63 @@ def get_task(task_name): ...@@ -55,25 +31,63 @@ def get_task(task_name):
"separate_id": "[SEP]", "separate_id": "[SEP]",
"max_len": 256 "max_len": 256
} }
params = {'batch_size': 5,
'shuffle': True,
'num_workers': 0,
'drop_last': True}
return df_test, tokenizer_config, params return df_test, tokenizer_config
def main(args):
# load data
data_file, tokenizer_config = get_task(args.task)
dataset = Data_Encoder(data_file, tokenizer_config)
data_loder_para = {'batch_size': args.batch_size,
'shuffle': True,
'num_workers': args.workers,
'drop_last': True
}
data_generator = DataLoader(dataset, **data_loder_para)
# creat model
config = BertConfig.from_pretrained(args.config)
model = BertAffinityModel(config)
if args.task == 'train':
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
loss_fct = torch.nn.MSELoss()
for epoch in range(args.epochs):
for i, (input, affinity) in enumerate(data_generator):
pred_affinity = model(input.long())
loss = loss_fct(pred_affinity, affinity)
opt.zero_grad()
loss.backward()
opt.step()
print('------------------')
def main():
args = parser.parse_args()
dataset, config, params = get_task("train")
model = BertModel.from_pretrained("bert-base-uncased")
if __name__ == '__main__': if __name__ == '__main__':
# main() # get parameter
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") parser = ArgumentParser(description='BertAffinity')
# model = BertModel('./config/config.json') parser.add_argument('-b', '--batch-size', default=1, type=int,
model = BertModel.from_pretrained("bert-base-uncased") metavar='N',
inputs = tokenizer("Hello world!", return_tensors="pt") help='mini-batch size (default: 16), this is the total '
outputs = model(**inputs) 'batch size of all GPUs on the current node when '
'using Data Parallel or Distributed Data Parallel')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
help='number of data loading workers (default: 0)')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--task', choices=['train', 'test', 'channel', 'ER', 'GPCR', 'kinase'],
default='train', type=str, metavar='TASK',
help='Task name. Could be train, test, channel, ER, GPCR, kinase.')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--config', default='./config/config.json', type=str, help='model config file path')
args = parser.parse_args()
# input dict
main(args)
# load_config(args.config)
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from os.path import expanduser
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm
import requests
ENDPOINT = "https://huggingface.co"
class RepoObj:
"""
HuggingFace git-based system, data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
self.filename = filename
self.lastModified = lastModified
self.commit = commit
self.size = size
class ModelSibling:
"""
Data structure that represents a public file inside a model, accessible from huggingface.co
"""
def __init__(self, rfilename: str, **kwargs):
self.rfilename = rfilename # filename relative to the model root
for k, v in kwargs.items():
setattr(self, k, v)
class ModelInfo:
"""
Info about a public model accessible from huggingface.co
"""
def __init__(
self,
modelId: Optional[str] = None, # id of model
tags: List[str] = [],
pipeline_tag: Optional[str] = None,
siblings: Optional[List[Dict]] = None, # list of files that constitute the model
**kwargs
):
self.modelId = modelId
self.tags = tags
self.pipeline_tag = pipeline_tag
self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
for k, v in kwargs.items():
setattr(self, k, v)
class HfApi:
def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT
def login(self, username: str, password: str) -> str:
"""
Call HF API to sign in a user and get a token if credentials are valid.
Outputs: token if credentials are valid
Throws: requests.exceptions.HTTPError if credentials are invalid
"""
path = "{}/api/login".format(self.endpoint)
r = requests.post(path, json={"username": username, "password": password})
r.raise_for_status()
d = r.json()
return d["token"]
def whoami(self, token: str) -> Tuple[str, List[str]]:
"""
Call HF API to know "whoami"
"""
path = "{}/api/whoami".format(self.endpoint)
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
d = r.json()
return d["user"], d["orgs"]
def logout(self, token: str) -> None:
"""
Call HF API to log out.
"""
path = "{}/api/logout".format(self.endpoint)
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
def model_list(self) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface.co
"""
path = "{}/api/models".format(self.endpoint)
r = requests.get(path)
r.raise_for_status()
d = r.json()
return [ModelInfo(**x) for x in d]
def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
"""
HuggingFace git-based system, used for models.
Call HF API to list all stored files for user (or one of their organizations).
"""
path = "{}/api/repos/ls".format(self.endpoint)
params = {"organization": organization} if organization is not None else None
r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
d = r.json()
return [RepoObj(**x) for x in d]
def create_repo(
self,
token: str,
name: str,
organization: Optional[str] = None,
private: Optional[bool] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
) -> str:
"""
HuggingFace git-based system, used for models.
Call HF API to create a whole repo.
Params:
private: Whether the model repo should be private (requires a paid huggingface.co account)
exist_ok: Do not raise an error if repo already exists
lfsmultipartthresh: Optional: internal param for testing purposes.
"""
path = "{}/api/repos/create".format(self.endpoint)
json = {"name": name, "organization": organization, "private": private}
if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
path,
headers={"authorization": "Bearer {}".format(token)},
json=json,
)
if exist_ok and r.status_code == 409:
return ""
r.raise_for_status()
d = r.json()
return d["url"]
def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
"""
HuggingFace git-based system, used for models.
Call HF API to delete a whole repo.
CAUTION(this is irreversible).
"""
path = "{}/api/repos/delete".format(self.endpoint)
r = requests.delete(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"name": name, "organization": organization},
)
r.raise_for_status()
class TqdmProgressFileReader:
"""
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a
tqdm progress bar.
see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.
"""
def __init__(self, f: io.BufferedReader):
self.f = f
self.total_size = os.fstat(f.fileno()).st_size
self.pbar = tqdm(total=self.total_size, leave=False)
self.read = f.read
f.read = self._read
def _read(self, n=-1):
self.pbar.update(n)
return self.read(n)
def close(self):
self.pbar.close()
class HfFolder:
path_token = expanduser("~/.huggingface/token")
@classmethod
def save_token(cls, token):
"""
Save token, creating folder as needed.
"""
os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
with open(cls.path_token, "w+") as f:
f.write(token)
@classmethod
def get_token(cls):
"""
Get token or None if not existent.
"""
try:
with open(cls.path_token, "r") as f:
return f.read()
except FileNotFoundError:
pass
@classmethod
def delete_token(cls):
"""
Delete token. Do not fail if token does not exist.
"""
try:
os.remove(cls.path_token)
except FileNotFoundError:
pass
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities. """
import logging
import os
import sys
import threading
from logging import CRITICAL # NOQA
from logging import DEBUG # NOQA
from logging import ERROR # NOQA
from logging import FATAL # NOQA
from logging import INFO # NOQA
from logging import NOTSET # NOQA
from logging import WARN # NOQA
from logging import WARNING # NOQA
from typing import Optional
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
_default_log_level = logging.WARNING
def _get_default_logging_level():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to ``_default_log_level``
"""
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
)
return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> logging.Logger:
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
global _default_handler
with _lock:
if _default_handler:
# This library has already configured the library root logger.
return
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
_default_handler.flush = sys.stderr.flush
# Apply our default configuration to the library root logger.
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def _reset_library_root_logger() -> None:
global _default_handler
with _lock:
if not _default_handler:
return
library_root_logger = _get_library_root_logger()
library_root_logger.removeHandler(_default_handler)
library_root_logger.setLevel(logging.NOTSET)
_default_handler = None
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Return a logger with the specified name.
This function is not supposed to be directly accessed unless you are writing a custom transformers module.
"""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
def get_verbosity() -> int:
"""
Return the current level for the 🤗 Transformers's root logger as an int.
Returns:
:obj:`int`: The logging level.
.. note::
🤗 Transformers has following logging levels:
- 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``
- 40: ``transformers.logging.ERROR``
- 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN``
- 20: ``transformers.logging.INFO``
- 10: ``transformers.logging.DEBUG``
"""
_configure_library_root_logger()
return _get_library_root_logger().getEffectiveLevel()
def set_verbosity(verbosity: int) -> None:
"""
Set the vebosity level for the 🤗 Transformers's root logger.
Args:
verbosity (:obj:`int`):
Logging level, e.g., one of:
- ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``
- ``transformers.logging.ERROR``
- ``transformers.logging.WARNING`` or ``transformers.logging.WARN``
- ``transformers.logging.INFO``
- ``transformers.logging.DEBUG``
"""
_configure_library_root_logger()
_get_library_root_logger().setLevel(verbosity)
def set_verbosity_info():
"""Set the verbosity to the :obj:`INFO` level."""
return set_verbosity(INFO)
def set_verbosity_warning():
"""Set the verbosity to the :obj:`WARNING` level."""
return set_verbosity(WARNING)
def set_verbosity_debug():
"""Set the verbosity to the :obj:`DEBUG` level."""
return set_verbosity(DEBUG)
def set_verbosity_error():
"""Set the verbosity to the :obj:`ERROR` level."""
return set_verbosity(ERROR)
def disable_default_handler() -> None:
"""Disable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert _default_handler is not None
_get_library_root_logger().removeHandler(_default_handler)
def enable_default_handler() -> None:
"""Enable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert _default_handler is not None
_get_library_root_logger().addHandler(_default_handler)
def add_handler(handler: logging.Handler) -> None:
"""adds a handler to the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert handler is not None
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
"""removes given handler from the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert handler is not None and handler not in _get_library_root_logger().handlers
_get_library_root_logger().removeHandler(handler)
def disable_propagation() -> None:
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""
_configure_library_root_logger()
_get_library_root_logger().propagate = False
def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
prevent double logging if the root logger has been configured.
"""
_configure_library_root_logger()
_get_library_root_logger().propagate = True
def enable_explicit_format() -> None:
"""
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
::
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
All handlers currently bound to the root logger are affected by this method.
"""
handlers = _get_library_root_logger().handlers
for handler in handlers:
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
handler.setFormatter(formatter)
def reset_format() -> None:
"""
Resets the formatting for HuggingFace Transformers's loggers.
All handlers currently bound to the root logger are affected by this method.
"""
handlers = _get_library_root_logger().handlers
for handler in handlers:
handler.setFormatter(None)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论