提交 8f39b5be 作者: 朱学凯

v1

上级 92740d2f
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn.functional as F
from packaging import version
from utils import logging
logger = logging.get_logger(__name__)
def _gelu_python(x):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in
torch.nn.functional Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if version.parse(torch.__version__) < version.parse("1.4"):
gelu = _gelu_python
else:
gelu = F.gelu
def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
def _silu_python(x):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
return x * torch.sigmoid(x)
if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python
else:
silu = F.silu
def mish(x):
return x * torch.tanh(torch.nn.functional.softplus(x))
def linear_act(x):
return x
ACT2FN = {
"relu": F.relu,
"silu": silu,
"swish": silu,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"mish": mish,
"linear": linear_act,
"sigmoid": torch.sigmoid,
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 23615
}
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT model configuration """
from configuration_utils import PretrainedConfig
from file_utils import logging
logger = logging.get_logger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json",
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json",
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json",
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json",
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json",
"bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json",
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json",
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json",
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json",
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json",
"cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json",
"cl-tohoku/bert-base-japanese-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json",
"cl-tohoku/bert-base-japanese-char": "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json",
"cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json",
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json",
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
# See all BERT models at https://huggingface.co/models?filter=bert
}
class BertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.BertModel` or a
:class:`~transformers.TFBertModel`. It is used to instantiate a BERT model according to the specified arguments,
defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
to that of the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
:class:`~transformers.TFBertModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
:class:`~transformers.TFBertModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
Examples::
>>> from transformers import BertModel, BertConfig
>>> # Initializing a BERT bert-base-uncased style configuration
>>> configuration = BertConfig()
>>> # Initializing a model from the bert-base-uncased style configuration
>>> model = BertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "bert"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
......@@ -110,15 +110,17 @@ def seq2emb_encoder(input_seq, max_len, vocab):
class Data_Encoder(data.Dataset):
def __init__(self, train_file, tokenizer_config):
'Initialization'
self.begin_id = train_file["begin_id"]
self.sep_id = train_file["separate_id"]
self.max_len = train_file["max_len"]
# load data
with open(train_file["sps"], 'r') as f:
self.sps = f.readlines()
with open(train_file["smile"], 'r') as f:
self.smile = f.readlines()
with open(train_file["affinity"], 'r') as f:
self.affinity = f.readlines()
# define tokenizer
self.begin_id = tokenizer_config["begin_id"]
self.sep_id = tokenizer_config["separate_id"]
self.max_len = tokenizer_config["max_len"]
self.vocab = load_vocab(tokenizer_config["vocab_file"])
bpe_codes_drug = codecs.open(tokenizer_config["vocab_pair"])
self.dbpe = BPE(bpe_codes_drug, merges=-1, separator='')
......@@ -158,17 +160,16 @@ if __name__ == "__main__":
df_train = {"sps": './data/train_sps',
"smile": './data/train_smile',
"affinity": './data/train_ic50',
"vocab_file": './config/vocab.txt',
"begin_id": '[CLS]',
"separate_id": "[SEP]",
"max_len": 256
}
tokenizer_config = {"vocab_file": './config/vocab.txt',
"vocab_pair": './config/drug_codes_chembl.txt'
"vocab_pair": './config/drug_codes_chembl.txt',
"begin_id": '[CLS]',
"separate_id": "[SEP]",
"max_len": 256
}
params = {'batch_size': 5,
'shuffle': False,
'shuffle': True,
'num_workers': 0,
'drop_last': True}
trainset = Data_Encoder(df_train, tokenizer_config)
......
import time
import warnings
from abc import ABC
from typing import Optional
import torch
from file_utils import add_start_docstrings
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
kwargs:
Additional stopping critera specific kwargs.
Return:
:obj:`bool`. :obj:`False` indicates we should continue, :obj:`True` indicates we should stop.
"""
class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
raise NotImplementedError("StoppingCriteria needs to be subclassed")
class MaxLengthCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generated number of tokens exceeds :obj:`max_length`.
Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens.
Args:
max_length (:obj:`int`):
The maximum length that the output sequence can have in number of tokens.
"""
def __init__(self, max_length: int):
self.max_length = max_length
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids.shape[-1] > self.max_length
class MaxTimeCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
time will start being counted when you initialize this function. You can override this by passing an
:obj:`initial_time`.
Args:
max_time (:obj:`float`):
The maximum allowed time in seconds for the generation.
initial_time (:obj:`float`, `optional`, defaults to :obj:`time.time()`):
The start of the generation allowed time.
"""
def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
self.max_time = max_time
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return time.time() - self.initial_timestamp > self.max_time
class StoppingCriteriaList(list):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores) for criteria in self)
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
found = False
for stopping_criterium in stopping_criteria:
if isinstance(stopping_criterium, MaxLengthCriteria):
found = True
if stopping_criterium.max_length != max_length:
warnings.warn(
"You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
)
if not found:
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
This source diff could not be displayed because it is too large. You can view the blob instead.
from argparse import ArgumentParser
from dataset import Data_Encoder
import torch
from transformers import BertModel, BertTokenizer, BertLayer
from transformers import AutoTokenizer, AutoModel
# from modeling_bert import BertModel
parser = ArgumentParser(description='BertAff Training.')
parser.add_argument('-b', '--batch-size', default=16, type=int,
metavar='N',
help='mini-batch size (default: 16), this is the total '
'batch size of all GPUs on the current node when '
'using Data Parallel or Distributed Data Parallel')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
help='number of data loading workers (default: 0)')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--task', choices=['train', 'test', 'channel', 'ER', 'GPCR', 'kinase'],
default='', type=str, metavar='TASK',
help='Task name. Could be train, test, channel, ER, GPCR, kinase.')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
metavar='LR', help='initial learning rate', dest='lr')
def get_task(task_name):
if task_name.lower() == 'train':
df_train = {"sps": './data/train_sps',
"smile": './data/train_smile',
"affinity": './data/train_ic50',
}
tokenizer_config = {"vocab_file": './config/vocab.txt',
"vocab_pair": './config/drug_codes_chembl.txt',
"begin_id": '[CLS]',
"separate_id": "[SEP]",
"max_len": 256
}
params = {'batch_size': 5,
'shuffle': True,
'num_workers': 0,
'drop_last': True}
return df_train, tokenizer_config, params
elif task_name.lower() == 'test':
df_test = {"sps": './data/test_sps',
"smile": './data/test_smile',
"affinity": './data/test_ic50',
}
tokenizer_config = {"vocab_file": './config/vocab.txt',
"vocab_pair": './config/drug_codes_chembl.txt',
"begin_id": '[CLS]',
"separate_id": "[SEP]",
"max_len": 256
}
params = {'batch_size': 5,
'shuffle': True,
'num_workers': 0,
'drop_last': True}
return df_test, tokenizer_config, params
def main():
args = parser.parse_args()
dataset, config, params = get_task("train")
model = BertModel.from_pretrained("bert-base-uncased")
if __name__ == '__main__':
# main()
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# model = BertModel('./config/config.json')
model = BertModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello world!", return_tensors="pt")
outputs = model(**inputs)
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from os.path import expanduser
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm
import requests
ENDPOINT = "https://huggingface.co"
class RepoObj:
"""
HuggingFace git-based system, data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
self.filename = filename
self.lastModified = lastModified
self.commit = commit
self.size = size
class ModelSibling:
"""
Data structure that represents a public file inside a model, accessible from huggingface.co
"""
def __init__(self, rfilename: str, **kwargs):
self.rfilename = rfilename # filename relative to the model root
for k, v in kwargs.items():
setattr(self, k, v)
class ModelInfo:
"""
Info about a public model accessible from huggingface.co
"""
def __init__(
self,
modelId: Optional[str] = None, # id of model
tags: List[str] = [],
pipeline_tag: Optional[str] = None,
siblings: Optional[List[Dict]] = None, # list of files that constitute the model
**kwargs
):
self.modelId = modelId
self.tags = tags
self.pipeline_tag = pipeline_tag
self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
for k, v in kwargs.items():
setattr(self, k, v)
class HfApi:
def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT
def login(self, username: str, password: str) -> str:
"""
Call HF API to sign in a user and get a token if credentials are valid.
Outputs: token if credentials are valid
Throws: requests.exceptions.HTTPError if credentials are invalid
"""
path = "{}/api/login".format(self.endpoint)
r = requests.post(path, json={"username": username, "password": password})
r.raise_for_status()
d = r.json()
return d["token"]
def whoami(self, token: str) -> Tuple[str, List[str]]:
"""
Call HF API to know "whoami"
"""
path = "{}/api/whoami".format(self.endpoint)
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
d = r.json()
return d["user"], d["orgs"]
def logout(self, token: str) -> None:
"""
Call HF API to log out.
"""
path = "{}/api/logout".format(self.endpoint)
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
def model_list(self) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface.co
"""
path = "{}/api/models".format(self.endpoint)
r = requests.get(path)
r.raise_for_status()
d = r.json()
return [ModelInfo(**x) for x in d]
def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
"""
HuggingFace git-based system, used for models.
Call HF API to list all stored files for user (or one of their organizations).
"""
path = "{}/api/repos/ls".format(self.endpoint)
params = {"organization": organization} if organization is not None else None
r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
d = r.json()
return [RepoObj(**x) for x in d]
def create_repo(
self,
token: str,
name: str,
organization: Optional[str] = None,
private: Optional[bool] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
) -> str:
"""
HuggingFace git-based system, used for models.
Call HF API to create a whole repo.
Params:
private: Whether the model repo should be private (requires a paid huggingface.co account)
exist_ok: Do not raise an error if repo already exists
lfsmultipartthresh: Optional: internal param for testing purposes.
"""
path = "{}/api/repos/create".format(self.endpoint)
json = {"name": name, "organization": organization, "private": private}
if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
path,
headers={"authorization": "Bearer {}".format(token)},
json=json,
)
if exist_ok and r.status_code == 409:
return ""
r.raise_for_status()
d = r.json()
return d["url"]
def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
"""
HuggingFace git-based system, used for models.
Call HF API to delete a whole repo.
CAUTION(this is irreversible).
"""
path = "{}/api/repos/delete".format(self.endpoint)
r = requests.delete(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"name": name, "organization": organization},
)
r.raise_for_status()
class TqdmProgressFileReader:
"""
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a
tqdm progress bar.
see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.
"""
def __init__(self, f: io.BufferedReader):
self.f = f
self.total_size = os.fstat(f.fileno()).st_size
self.pbar = tqdm(total=self.total_size, leave=False)
self.read = f.read
f.read = self._read
def _read(self, n=-1):
self.pbar.update(n)
return self.read(n)
def close(self):
self.pbar.close()
class HfFolder:
path_token = expanduser("~/.huggingface/token")
@classmethod
def save_token(cls, token):
"""
Save token, creating folder as needed.
"""
os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
with open(cls.path_token, "w+") as f:
f.write(token)
@classmethod
def get_token(cls):
"""
Get token or None if not existent.
"""
try:
with open(cls.path_token, "r") as f:
return f.read()
except FileNotFoundError:
pass
@classmethod
def delete_token(cls):
"""
Delete token. Do not fail if token does not exist.
"""
try:
os.remove(cls.path_token)
except FileNotFoundError:
pass
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities. """
import logging
import os
import sys
import threading
from logging import CRITICAL # NOQA
from logging import DEBUG # NOQA
from logging import ERROR # NOQA
from logging import FATAL # NOQA
from logging import INFO # NOQA
from logging import NOTSET # NOQA
from logging import WARN # NOQA
from logging import WARNING # NOQA
from typing import Optional
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
_default_log_level = logging.WARNING
def _get_default_logging_level():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to ``_default_log_level``
"""
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
)
return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> logging.Logger:
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
global _default_handler
with _lock:
if _default_handler:
# This library has already configured the library root logger.
return
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
_default_handler.flush = sys.stderr.flush
# Apply our default configuration to the library root logger.
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def _reset_library_root_logger() -> None:
global _default_handler
with _lock:
if not _default_handler:
return
library_root_logger = _get_library_root_logger()
library_root_logger.removeHandler(_default_handler)
library_root_logger.setLevel(logging.NOTSET)
_default_handler = None
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Return a logger with the specified name.
This function is not supposed to be directly accessed unless you are writing a custom transformers module.
"""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
def get_verbosity() -> int:
"""
Return the current level for the 🤗 Transformers's root logger as an int.
Returns:
:obj:`int`: The logging level.
.. note::
🤗 Transformers has following logging levels:
- 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``
- 40: ``transformers.logging.ERROR``
- 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN``
- 20: ``transformers.logging.INFO``
- 10: ``transformers.logging.DEBUG``
"""
_configure_library_root_logger()
return _get_library_root_logger().getEffectiveLevel()
def set_verbosity(verbosity: int) -> None:
"""
Set the vebosity level for the 🤗 Transformers's root logger.
Args:
verbosity (:obj:`int`):
Logging level, e.g., one of:
- ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``
- ``transformers.logging.ERROR``
- ``transformers.logging.WARNING`` or ``transformers.logging.WARN``
- ``transformers.logging.INFO``
- ``transformers.logging.DEBUG``
"""
_configure_library_root_logger()
_get_library_root_logger().setLevel(verbosity)
def set_verbosity_info():
"""Set the verbosity to the :obj:`INFO` level."""
return set_verbosity(INFO)
def set_verbosity_warning():
"""Set the verbosity to the :obj:`WARNING` level."""
return set_verbosity(WARNING)
def set_verbosity_debug():
"""Set the verbosity to the :obj:`DEBUG` level."""
return set_verbosity(DEBUG)
def set_verbosity_error():
"""Set the verbosity to the :obj:`ERROR` level."""
return set_verbosity(ERROR)
def disable_default_handler() -> None:
"""Disable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert _default_handler is not None
_get_library_root_logger().removeHandler(_default_handler)
def enable_default_handler() -> None:
"""Enable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert _default_handler is not None
_get_library_root_logger().addHandler(_default_handler)
def add_handler(handler: logging.Handler) -> None:
"""adds a handler to the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert handler is not None
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
"""removes given handler from the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert handler is not None and handler not in _get_library_root_logger().handlers
_get_library_root_logger().removeHandler(handler)
def disable_propagation() -> None:
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""
_configure_library_root_logger()
_get_library_root_logger().propagate = False
def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
prevent double logging if the root logger has been configured.
"""
_configure_library_root_logger()
_get_library_root_logger().propagate = True
def enable_explicit_format() -> None:
"""
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
::
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
All handlers currently bound to the root logger are affected by this method.
"""
handlers = _get_library_root_logger().handlers
for handler in handlers:
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
handler.setFormatter(formatter)
def reset_format() -> None:
"""
Resets the formatting for HuggingFace Transformers's loggers.
All handlers currently bound to the root logger are affected by this method.
"""
handlers = _get_library_root_logger().handlers
for handler in handlers:
handler.setFormatter(None)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论