
Commit 5fd612e

update t5 (#763)
1 parent 79b7d52 commit 5fd612e

31 files changed: +3336, -637 lines

mindnlp/__init__.py (+2, -1)

@@ -17,7 +17,8 @@
 MindNLP library.
 """
 import os
-os.environ["HF_ENDPOINT"] = 'https://hf-mirror.com/'
+if os.environ.get('HF_ENDPOINT', None) is None:
+    os.environ["HF_ENDPOINT"] = 'https://hf-mirror.com/'
 os.environ["MS_DEV_FORCE_ACL"] = '1'
 
 import mindspore
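Note: a user-supplied HF_ENDPOINT is now respected instead of being overwritten at import time; the hf-mirror.com value is only a fallback. A minimal usage sketch (the endpoint below is just an example):

import os

# Hypothetical override: must be set before `import mindnlp`,
# because the fallback assignment runs at import time.
os.environ["HF_ENDPOINT"] = "https://huggingface.co/"

import mindnlp  # keeps the value set above; falls back to hf-mirror.com only when unset
print(os.environ["HF_ENDPOINT"])  # -> https://huggingface.co/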

mindnlp/_legacy/functional.py (-2)

@@ -1463,8 +1463,6 @@ def sumproduct_pair(left_, right_, sum_dims_, keep_dim_):
 ELLIPSIS = 52
 
 def einsum(equation, *operands):
-    if mindspore.get_context('device_target') == 'GPU':
-        return _get_cache_prim(ops.Einsum)(equation)(operands)
     assert operands, "einsum(): must provide at least one operand"
 
     arrow_pos = equation.find("->")

mindnlp/configs.py (+1, -1)

@@ -32,7 +32,7 @@
 # for modelscope models
 MS_URL_BASE = "https://modelscope.cn/api/v1/models/mindnlp/{}/repo?Revision=master&FilePath={}"
 # for huggingface url
-HF_URL_BASE = 'https://hf-mirror.com/{}/resolve/main/{}'
+HF_URL_BASE = os.environ.get('HF_ENDPOINT', 'https://hf-mirror.com/') + '{}/resolve/main/{}'
 
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
 MINDNLP_CACHE = os.getenv("MINDNLP_CACHE", DEFAULT_ROOT)
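Note: HF_URL_BASE is now derived from the same HF_ENDPOINT variable, so download URLs follow whatever endpoint is configured; the string concatenation assumes the endpoint ends with a trailing slash, as the default does. A rough illustration with a hypothetical repo id and file name:

import os

HF_URL_BASE = os.environ.get('HF_ENDPOINT', 'https://hf-mirror.com/') + '{}/resolve/main/{}'
url = HF_URL_BASE.format('t5-small', 'config.json')  # repo id and file are examples
print(url)  # e.g. https://hf-mirror.com/t5-small/resolve/main/config.json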

mindnlp/injection.py (+28, -8)

@@ -275,10 +275,16 @@ def custom_multinomial(probabilities, num_samples, replacement=True):
         samples = ops.searchsorted(cumulative_probs, uniform_samples, right=True)
     else:
         # without replacement
-        indices = ops.arange(probabilities.shape[-1])
-        shuffled_indices = ops.randperm(probabilities.shape[-1]).unsqueeze(0).broadcast_to((probabilities.shape[:-1], -1))
-        selected_indices = shuffled_indices[:, :num_samples]
-        samples = indices[selected_indices]
+        n_dist = 1
+        if probabilities.ndim > 1:
+            n_dist = probabilities.shape[-2]
+        random_uniform = ops.rand((n_dist * probabilities.shape[-1],))
+        if n_dist != 1:
+            random_uniform = random_uniform.reshape(n_dist, probabilities.shape[-1])
+
+        vals = ops.div(ops.log(random_uniform), probabilities + 1e-6)
+        _, samples = ops.top_k(vals, num_samples)
+
     return samples
 
 if DEVICE_TARGET == 'GPU':

@@ -291,6 +297,22 @@ def eq(self, other):
     Tensor.eq = eq
     StubTensor.eq = eq
 
+
+    def _eq(self, other):
+        if not isinstance(other, (int, float, Tensor)):
+            return False
+        if isinstance(other, Tensor) and self.shape != other.shape:
+            return False
+        if id(self) == id(other):
+            return True
+        # bool type is not supported for `Equal` operator in backend.
+        if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
+            self = self.to(mstype.int32)
+            other = other.to(mstype.int32)
+        return ops.eq(self, other)
+
+    Parameter.__eq__ = _eq
+
 class Dense(nn.Cell):
     """patched Dense"""
     def __init__(self,

@@ -481,12 +503,10 @@ def extend_repr(self):
         return f'normalized_shape={self.normalized_shape}, begin_norm_axis={self.begin_norm_axis}, ' \
             f'begin_params_axis={self.begin_params_axis}, gamma={self.gamma}, beta={self.beta}'
 
-
 def half(self):
     """patched nn.Cell.half"""
-    for param in self.get_parameters():
-        if param.dtype in (mindspore.float32, mindspore.float16):
-            param.set_dtype(mindspore.float16)
+    self.to_float(mindspore.float16)
+    return self
 
 nn.Cell.half = half
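Note: the rewritten without-replacement branch of custom_multinomial replaces the randperm shuffle, which ignored the weights, with a weighted-key trick: draw one uniform value per candidate, score each candidate as log(U) divided by its probability, and keep the top-k keys. Higher-probability entries get keys closer to zero and are therefore picked more often, in the spirit of Efraimidis-Spirakis weighted sampling. A small NumPy sketch of the same idea, independent of MindSpore and only for illustration:

import numpy as np

def sample_without_replacement(probs, num_samples, rng=None):
    """Draw `num_samples` distinct indices with probability roughly proportional to `probs`."""
    rng = np.random.default_rng() if rng is None else rng
    u = rng.random(probs.shape[-1])              # one uniform draw per candidate
    keys = np.log(u) / (probs + 1e-6)            # log(u) < 0, so larger probs give keys closer to 0
    return np.argsort(keys)[::-1][:num_samples]  # the k largest keys are the sampled indices

probs = np.array([0.1, 0.2, 0.3, 0.4])
print(sample_without_replacement(probs, 2))      # e.g. [3 2]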

mindnlp/transformers/configuration_utils.py (+11, -1)

@@ -97,7 +97,7 @@ def __init__(self, **kwargs):
         if not isinstance(self.id2label, dict):
             raise ValueError("Argument id2label should be a dictionary.")
         num_labels = kwargs.pop("num_labels", None)
-        print(num_labels)
+
         if num_labels is not None and len(self.id2label) != num_labels:
             logger.warning(
                 f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "

@@ -532,6 +532,16 @@ def to_file(self, save_path):
         with open(os.path.join(save_path, 'config.json'), encoding='utf-8') as f:
             json.dump(output_dict, f, sort_keys=True, indent=2)
 
+    def update(self, config_dict: Dict[str, Any]):
+        """
+        Updates attributes of this class with attributes from `config_dict`.
+
+        Args:
+            config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
+        """
+        for key, value in config_dict.items():
+            setattr(self, key, value)
+
     def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
         """
         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
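Note: the new update helper simply assigns each key/value pair onto the config instance via setattr. A self-contained sketch using a stand-in class (the real method lives on the pretrained config class; the field names are made up):

from typing import Any, Dict

class DemoConfig:
    """Stand-in mimicking the update() method added above."""
    def update(self, config_dict: Dict[str, Any]):
        for key, value in config_dict.items():
            setattr(self, key, value)

config = DemoConfig()
config.update({"num_labels": 3, "dropout_rate": 0.1})  # hypothetical fields
print(config.num_labels, config.dropout_rate)          # 3 0.1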

mindnlp/transformers/generation/utils.py (+2, -2)

@@ -3237,7 +3237,7 @@ def beam_sample(
 
             probs = ops.softmax(next_token_scores, axis=-1)
 
-            next_tokens = ops.multinomial(probs, num_samples=2 * num_beams, replacement=False)
+            next_tokens = ops.multinomial(probs, num_samples=2 * num_beams)
             next_token_scores = ops.gather_elements(next_token_scores, -1, next_tokens)
 
             next_token_scores, _indices = ops.sort(next_token_scores, descending=True, axis=1)

@@ -4292,7 +4292,7 @@ def assisted_decoding(
             # 3. Obtain the next tokens from the original model logits.
             if do_sample:
                 probs = ops.softmax(new_logits, axis=-1)
-                selected_tokens = ops.multinomial(probs[0, :, :], num_samples=1, replacement=False).squeeze(1)[None, :]
+                selected_tokens = ops.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
             else:
                 selected_tokens = new_logits.argmax(axis=-1)
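Note: dropping replacement=False means both call sites now rely on ops.multinomial's default behaviour. For orientation, here is a rough NumPy mirror of the surrounding beam_sample steps (sample 2 * num_beams candidates per batch row, gather their scores, then re-sort descending); it is only a sketch, not the MindSpore implementation:

import numpy as np

rng = np.random.default_rng(0)
num_beams, vocab_size = 2, 5
scores = rng.standard_normal((1, num_beams * vocab_size))        # flattened (batch, beams*vocab) scores
probs = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)

next_tokens = np.stack([rng.choice(p.size, size=2 * num_beams, p=p) for p in probs])  # sample
next_scores = np.take_along_axis(scores, next_tokens, axis=-1)                        # gather
order = np.argsort(-next_scores, axis=-1)                                              # sort, descending
next_tokens = np.take_along_axis(next_tokens, order, axis=-1)
next_scores = np.take_along_axis(next_scores, order, axis=-1)
print(next_tokens, next_scores)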

mindnlp/transformers/modeling_utils.py (+84, -15)

@@ -23,6 +23,7 @@
 # pylint: disable=unused-argument
 # pylint: disable=attribute-defined-outside-init
 # pylint: disable=self-cls-assignment
+# pylint: disable=no-name-in-module
 """
 Abstract class for Pretrained models.
 """

@@ -38,6 +39,7 @@
 import mindspore
 from mindspore import load_checkpoint, save_checkpoint
 from mindspore import nn, ops, Tensor, Parameter
+from mindspore._c_expression import MixedPrecisionType
 
 from mindnlp.configs import MS_URL_BASE, HF_URL_BASE, PT_WEIGHTS_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PT_WEIGHTS_INDEX_NAME
 from mindnlp.utils.download import is_remote_url, download_url, cached_file, get_checkpoint_shard_files

@@ -59,11 +61,20 @@ class CellUtilMixin:
     """
 
     @property
-    def dtype(self) -> mindspore.dtype:
+    def dtype(self) -> mindspore.TensorType:
         """
         `mindspore.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
         """
-        return mindspore.float32
+        if not hasattr(self, 'get_mixed_precision_type'):
+            return mindspore.float32
+        mixed_type = self.get_mixed_precision_type()
+        if mixed_type == MixedPrecisionType.FP16:
+            cast_type = mindspore.float16
+        elif mixed_type == MixedPrecisionType.BF16:
+            cast_type = mindspore.bfloat16
+        else:
+            cast_type = mindspore.float32
+        return cast_type
 
     @staticmethod
     def create_extended_attention_mask_for_decoder(input_shape, attention_mask):

@@ -387,7 +398,7 @@ def tie_weights(self):
             self._tie_encoder_decoder_weights(
                 self.encoder, self.decoder, self.base_model_prefix)
 
-        for cell in self.cells():
+        for _, cell in self.cells_and_names():
            if hasattr(cell, "_tie_weights"):
                cell._tie_weights()
 

@@ -398,20 +409,27 @@ def _tie_encoder_decoder_weights(encoder: nn.Cell, decoder: nn.Cell, base_model_
     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """ Tie or clone module weights depending of weither we are using or not
         """
-        output_embeddings.weight = input_embeddings.embedding_table
-        output_embeddings._params['weight'] = input_embeddings.embedding_table
+        if hasattr(output_embeddings, 'weight'):
+            output_embeddings.weight = input_embeddings.embedding_table
+            output_embeddings._params['weight'] = input_embeddings.embedding_table
+
+        if hasattr(output_embeddings, 'embedding_table'):
+            output_embeddings.embedding_table = input_embeddings.embedding_table
+            output_embeddings._params['embedding_table'] = input_embeddings.embedding_table
+
         if getattr(output_embeddings, "bias", None) is not None:
             if output_embeddings.weight.shape[0] == output_embeddings.bias.shape[0]:
                 pass
             else:
                 # instantial a new Parameter since mindspore.Parameter do not support assign_value with different shape
-                output_embeddings.bias = Parameter(ops.pad(
+                replace_references(output_embeddings.bias, Parameter(ops.pad(
                     output_embeddings.bias.data,
                     (0, output_embeddings.weight.shape[0] -
                      output_embeddings.bias.shape[0]),
                     "constant",
                     0,
-                ))
+                ), name=output_embeddings.bias.name, requires_grad=output_embeddings.bias.requires_grad))
+
         if hasattr(output_embeddings, "out_channels") and hasattr(input_embeddings, "vocab_size"):
             output_embeddings.out_channels = input_embeddings.vocab_size

@@ -435,7 +453,6 @@ def resize_token_embeddings(
         model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
         if new_num_tokens is None and pad_to_multiple_of is None:
             return model_embeds
-
         # Update base model and current model config
         self.config.vocab_size = model_embeds.embedding_table.shape[0]
         self.vocab_size = model_embeds.embedding_table.shape[0]

@@ -641,6 +658,8 @@ def from_pretrained(
         output_loading_info = kwargs.pop("output_loading_info", False)
         subfolder = kwargs.pop("subfolder", "")
         variant = kwargs.pop("variant", None)
+        ms_dtype = kwargs.pop("ms_dtype", None)
+        _ = kwargs.pop('low_cpu_mem_usage', None)
 
         is_sharded = False
         # Load config if we don't provide a configuration

@@ -800,6 +819,8 @@
 
         # Instantiate model.
         model = cls(config, *model_args, **model_kwargs)
+        if ms_dtype:
+            model = model.to_float(ms_dtype)
 
         if from_pt:
             if is_sharded:

@@ -827,43 +848,66 @@ def load_ckpt(resolved_archive_file):
         keys_missing = list(model.parameters_dict().keys())
         param_id_set = set()
 
+        use_keep_in_fp32_modules = False
+        if model._keep_in_fp32_modules:
+            use_keep_in_fp32_modules = True
 
         def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str):
+            keep_in_fp32_modules = model._keep_in_fp32_modules
             keys_unexpected = list(param_dict.keys())
 
             has_prefix_module = any(s.startswith(prefix) for s in keys_unexpected)
             expects_prefix_module = any(s.startswith(prefix) for s in keys_missing)
 
             for pname_in_net, param in model.parameters_and_names():
                 if has_prefix_module and not expects_prefix_module:
-                    param_name = prefix + '.' + param.name
+                    param_name = prefix + '.' + pname_in_net
                 elif not has_prefix_module and expects_prefix_module:
-                    param_name = param.name.replace(f'{prefix}.', '')
+                    param_name = pname_in_net.replace(f'{prefix}.', '')
                 else:
-                    param_name = param.name
+                    param_name = pname_in_net
 
                 if id(param) in param_id_set:
                     # for tied params
                     if pname_in_net in keys_missing:
                         keys_missing.remove(pname_in_net)
 
-                    if pname_in_net in keys_unexpected:
-                        keys_unexpected.remove(pname_in_net)
+                    if param_name in keys_missing:
+                        keys_missing.remove(param_name)
+
+                    if param_name in keys_unexpected:
+                        keys_unexpected.remove(param_name)
                     continue
                 new_param = param_dict.pop(param_name, None)
+
                 if new_param is not None:
+                    use_replace = False
                     if new_param.shape != param.shape:
                         if not ignore_mismatched_sizes:
                             raise RuntimeError(f'The shape of parameter `{param.name} is {param.shape}, but got mismatch parameter'
                                                f' `{param_name} with shape {new_param.shape} in checkpoint, '
                                                f'\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.')
                         logger.warning(f'The shape of parameter `{param.name} is {param.shape}, but got mismatch parameter'
                                        f' `{param_name} with shape {new_param.shape} in checkpoint, ')
-                        param = Parameter(new_param.data, param.name)
+                        continue
+
+                    if new_param.dtype != param.dtype:
+                        use_replace = True
+
+                    if ms_dtype:
+                        use_replace = True
+                        new_param = new_param.astype(ms_dtype)
+
+                    if use_keep_in_fp32_modules and \
+                        any(module_to_keep_in_fp32 in pname_in_net.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                        new_param = new_param.astype(mindspore.float32)
+
+                    if use_replace:
+                        replace_references(param, Parameter(new_param, name=param.name, requires_grad=param.requires_grad))
                     else:
                         param.set_data(new_param)
                     keys_unexpected.remove(param_name)
-                    keys_missing.remove(param.name)
+                    keys_missing.remove(pname_in_net)
                     param_id_set.add(id(param))
 
             return keys_unexpected, keys_missing

@@ -1340,6 +1384,7 @@ def convert_torch_to_mindspore(pth_file):
             key = key.replace('.bias', '.beta')
         if 'wpe' in key or 'wte' in key or \
             'embeddings' in key or 'embedding' in key or \
+            'shared' in key or 'relative_attention_bias' in key or \
             'embed_' in key or '_embed' in key and \
             'embedding_hidden_mapping_in' not in key: # for albert
             key = key.replace('weight', 'embedding_table')

@@ -1734,3 +1779,27 @@ def construct(self, hidden_states: Tensor, cls_index: Optional[Tensor] = None) -
         output = self.activation(output)
         output = self.last_dropout(output)
         return output
+
+def replace_references(old_obj, new_obj):
+    """use replace_references instead of Tensor.set_data due to mindspore errors."""
+    # Get all objects referring to old_obj
+    referrers = gc.get_referrers(old_obj)
+
+    # Replace references
+    for referrer in referrers:
+        if isinstance(referrer, dict):
+            # If the reference is in a dictionary
+            for key, value in referrer.items():
+                if value is old_obj:
+                    referrer[key] = new_obj
+        elif isinstance(referrer, list):
+            # If the reference is in a list or tuple
+            index = referrer.index(old_obj)
+            referrer[index] = new_obj
+        elif isinstance(referrer, tuple):
+            pass
+        elif hasattr(referrer, '__dict__'):
+            # If the reference is in the __dict__ of an object
+            for key, value in referrer.__dict__.items():
+                if value is old_obj:
+                    setattr(referrer, key, new_obj)
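Note: replace_references rewrites every dict, list, and instance-attribute reference to the old object in place, which is why the loader above can swap in a freshly created Parameter (for example after a dtype cast) where set_data is not suitable. A self-contained toy demonstration of the mechanism, reusing the function body from the diff above with stand-in classes:

import gc

def replace_references(old_obj, new_obj):
    """Redirect dict / list / attribute references from old_obj to new_obj (as in the diff above)."""
    for referrer in gc.get_referrers(old_obj):
        if isinstance(referrer, dict):
            for key, value in referrer.items():
                if value is old_obj:
                    referrer[key] = new_obj
        elif isinstance(referrer, list):
            referrer[referrer.index(old_obj)] = new_obj
        elif isinstance(referrer, tuple):
            pass  # tuples are immutable; skipped, as in the original
        elif hasattr(referrer, '__dict__'):
            for key, value in referrer.__dict__.items():
                if value is old_obj:
                    setattr(referrer, key, new_obj)

class Payload:
    """Toy stand-in for a Parameter."""
    def __init__(self, tag):
        self.tag = tag

class Holder:
    """Toy stand-in for a Cell that keeps a reference in its __dict__."""

old, new = Payload("fp32"), Payload("fp16")
holder = Holder()
holder.param = old
registry = {"weight": old}

replace_references(old, new)
print(holder.param.tag, registry["weight"].tag)  # fp16 fp16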

mindnlp/transformers/models/auto/modeling_auto.py (+1)

@@ -33,6 +33,7 @@
         ("bert", "BertModel"),
         ("roberta", "RobertaModel"),
         ("gpt_bigcode", "GPTBigCodeModel"),
+        ("t5", "T5Model"),
     ]
 )
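Note: with "t5" registered in the auto mapping, T5 checkpoints can be resolved through the auto class, and the new ms_dtype argument added in modeling_utils.py can be passed through from_pretrained. A hedged usage sketch; the checkpoint id and the import path are assumptions based on mindnlp's usual layout:

import mindspore
from mindnlp.transformers import AutoModel  # assumed export path

model = AutoModel.from_pretrained("t5-small", ms_dtype=mindspore.float16)  # "t5-small" is only an example id
print(type(model).__name__)  # expected: T5Model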
