@@ -23,6 +23,7 @@
 # pylint: disable=unused-argument
 # pylint: disable=attribute-defined-outside-init
 # pylint: disable=self-cls-assignment
+# pylint: disable=no-name-in-module
 """
 Abstract class for Pretrained models.
 """
@@ -38,6 +39,7 @@
 import mindspore
 from mindspore import load_checkpoint, save_checkpoint
 from mindspore import nn, ops, Tensor, Parameter
+from mindspore._c_expression import MixedPrecisionType
 
 from mindnlp.configs import MS_URL_BASE, HF_URL_BASE, PT_WEIGHTS_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PT_WEIGHTS_INDEX_NAME
 from mindnlp.utils.download import is_remote_url, download_url, cached_file, get_checkpoint_shard_files
@@ -59,11 +61,20 @@ class CellUtilMixin:
     """
 
     @property
-    def dtype(self) -> mindspore.dtype:
+    def dtype(self) -> mindspore.TensorType:
         """
         `mindspore.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
         """
-        return mindspore.float32
+        if not hasattr(self, 'get_mixed_precision_type'):
+            return mindspore.float32
+        mixed_type = self.get_mixed_precision_type()
+        if mixed_type == MixedPrecisionType.FP16:
+            cast_type = mindspore.float16
+        elif mixed_type == MixedPrecisionType.BF16:
+            cast_type = mindspore.bfloat16
+        else:
+            cast_type = mindspore.float32
+        return cast_type
 
     @staticmethod
     def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
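The reworked `dtype` property above reports the cast type recorded for the cell (FP16/BF16 via `MixedPrecisionType`) and falls back to `mindspore.float32` when no mixed-precision information is available. A minimal usage sketch, not part of the patch; `TinyModel` is a made-up cell and the printed value depends on whether a mixed-precision type has been recorded:

import mindspore
from mindspore import nn

class TinyModel(CellUtilMixin, nn.Cell):  # CellUtilMixin as defined above in this module
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 4)

model = TinyModel()
# Returns mindspore.float16 / mindspore.bfloat16 when the recorded
# mixed-precision type is FP16 / BF16, otherwise mindspore.float32.
print(model.dtype)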
@@ -387,7 +398,7 @@ def tie_weights(self):
             self._tie_encoder_decoder_weights(
                 self.encoder, self.decoder, self.base_model_prefix)
 
-        for cell in self.cells():
+        for _, cell in self.cells_and_names():
             if hasattr(cell, "_tie_weights"):
                 cell._tie_weights()
 
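`cells_and_names()` walks the cell itself and every nested sub-cell and yields `(name, cell)` pairs, whereas `cells()` only yields immediate children, so the new loop reaches `_tie_weights` hooks buried deeper in the network. A small illustrative sketch with toy cells, not from the patch:

from mindspore import nn

class Inner(nn.Cell):
    def __init__(self):
        super().__init__()
        self.proj = nn.Dense(8, 8)

class Outer(nn.Cell):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

net = Outer()
# Recursively visits net itself, net.inner and net.inner.proj,
# so a hook defined on any sub-cell would be found.
for name, cell in net.cells_and_names():
    print(name or '<root>', type(cell).__name__)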
@@ -398,20 +409,27 @@ def _tie_encoder_decoder_weights(encoder: nn.Cell, decoder: nn.Cell, base_model_
     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """ Tie or clone module weights depending of weither we are using or not
         """
-        output_embeddings.weight = input_embeddings.embedding_table
-        output_embeddings._params['weight'] = input_embeddings.embedding_table
+        if hasattr(output_embeddings, 'weight'):
+            output_embeddings.weight = input_embeddings.embedding_table
+            output_embeddings._params['weight'] = input_embeddings.embedding_table
+
+        if hasattr(output_embeddings, 'embedding_table'):
+            output_embeddings.embedding_table = input_embeddings.embedding_table
+            output_embeddings._params['embedding_table'] = input_embeddings.embedding_table
+
         if getattr(output_embeddings, "bias", None) is not None:
             if output_embeddings.weight.shape[0] == output_embeddings.bias.shape[0]:
                 pass
             else:
                 # instantial a new Parameter since mindspore.Parameter do not support assign_value with different shape
-                output_embeddings.bias = Parameter(ops.pad(
+                replace_references(output_embeddings.bias, Parameter(ops.pad(
                     output_embeddings.bias.data,
                     (0, output_embeddings.weight.shape[0] -
                      output_embeddings.bias.shape[0]),
                     "constant",
                     0,
-                ))
+                ), name=output_embeddings.bias.name, requires_grad=output_embeddings.bias.requires_grad))
+
         if hasattr(output_embeddings, "out_channels") and hasattr(input_embeddings, "vocab_size"):
             output_embeddings.out_channels = input_embeddings.vocab_size
 
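`_tie_or_clone_weights` now tolerates output heads that expose either a `weight` (e.g. `nn.Dense`) or an `embedding_table` (e.g. `nn.Embedding`), pointing whichever attribute exists at the input embedding's table. A conceptual sketch of what tying achieves, with illustrative layer sizes rather than the patch's own code:

import mindspore
from mindspore import nn

embedding = nn.Embedding(vocab_size=100, embedding_size=16)
lm_head = nn.Dense(16, 100, has_bias=False)

# Tying makes the output projection reuse the input embedding matrix,
# so both attributes refer to the same Parameter object.
lm_head.weight = embedding.embedding_table
assert lm_head.weight is embedding.embedding_table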
@@ -435,7 +453,6 @@ def resize_token_embeddings(
         model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
         if new_num_tokens is None and pad_to_multiple_of is None:
             return model_embeds
-
         # Update base model and current model config
         self.config.vocab_size = model_embeds.embedding_table.shape[0]
         self.vocab_size = model_embeds.embedding_table.shape[0]
@@ -641,6 +658,8 @@ def from_pretrained(
         output_loading_info = kwargs.pop("output_loading_info", False)
         subfolder = kwargs.pop("subfolder", "")
         variant = kwargs.pop("variant", None)
+        ms_dtype = kwargs.pop("ms_dtype", None)
+        _ = kwargs.pop('low_cpu_mem_usage', None)
 
         is_sharded = False
         # Load config if we don't provide a configuration
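With the two new kwargs, callers can request a target dtype at load time (`ms_dtype`), while `low_cpu_mem_usage` is accepted purely for compatibility with Hugging Face-style call sites and otherwise ignored. A hedged usage sketch; `SomeModel` and the checkpoint id are placeholders, not real names:

import mindspore

model = SomeModel.from_pretrained(
    'org/model-name',
    ms_dtype=mindspore.float16,   # cast the network and the loaded weights to fp16
    low_cpu_mem_usage=True,       # popped and discarded for HF compatibility
)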
@@ -800,6 +819,8 @@ def from_pretrained(
 
         # Instantiate model.
         model = cls(config, *model_args, **model_kwargs)
+        if ms_dtype:
+            model = model.to_float(ms_dtype)
 
         if from_pt:
             if is_sharded:
@@ -827,43 +848,66 @@ def load_ckpt(resolved_archive_file):
         keys_missing = list(model.parameters_dict().keys())
         param_id_set = set()
 
+        use_keep_in_fp32_modules = False
+        if model._keep_in_fp32_modules:
+            use_keep_in_fp32_modules = True
 
         def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str):
+            keep_in_fp32_modules = model._keep_in_fp32_modules
             keys_unexpected = list(param_dict.keys())
 
             has_prefix_module = any(s.startswith(prefix) for s in keys_unexpected)
             expects_prefix_module = any(s.startswith(prefix) for s in keys_missing)
 
             for pname_in_net, param in model.parameters_and_names():
                 if has_prefix_module and not expects_prefix_module:
-                    param_name = prefix + '.' + param.name
+                    param_name = prefix + '.' + pname_in_net
                 elif not has_prefix_module and expects_prefix_module:
-                    param_name = param.name.replace(f'{prefix}.', '')
+                    param_name = pname_in_net.replace(f'{prefix}.', '')
                 else:
-                    param_name = param.name
+                    param_name = pname_in_net
 
                 if id(param) in param_id_set:
                     # for tied params
                     if pname_in_net in keys_missing:
                         keys_missing.remove(pname_in_net)
 
-                    if pname_in_net in keys_unexpected:
-                        keys_unexpected.remove(pname_in_net)
+                    if param_name in keys_missing:
+                        keys_missing.remove(param_name)
+
+                    if param_name in keys_unexpected:
+                        keys_unexpected.remove(param_name)
                     continue
                 new_param = param_dict.pop(param_name, None)
+
                 if new_param is not None:
+                    use_replace = False
                     if new_param.shape != param.shape:
                         if not ignore_mismatched_sizes:
                             raise RuntimeError(f'The shape of parameter `{param.name} is {param.shape}, but got mismatch parameter'
                                                f' `{param_name} with shape {new_param.shape} in checkpoint, '
                                                f'\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.')
                         logger.warning(f'The shape of parameter `{param.name} is {param.shape}, but got mismatch parameter'
                                        f' `{param_name} with shape {new_param.shape} in checkpoint, ')
-                        param = Parameter(new_param.data, param.name)
+                        continue
+
+                    if new_param.dtype != param.dtype:
+                        use_replace = True
+
+                    if ms_dtype:
+                        use_replace = True
+                        new_param = new_param.astype(ms_dtype)
+
+                    if use_keep_in_fp32_modules and \
+                        any(module_to_keep_in_fp32 in pname_in_net.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                        new_param = new_param.astype(mindspore.float32)
+
+                    if use_replace:
+                        replace_references(param, Parameter(new_param, name=param.name, requires_grad=param.requires_grad))
                     else:
                         param.set_data(new_param)
                     keys_unexpected.remove(param_name)
-                    keys_missing.remove(param.name)
+                    keys_missing.remove(pname_in_net)
                     param_id_set.add(id(param))
 
             return keys_unexpected, keys_missing
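Inside `load_param_into_net`, checkpoint keys are now resolved from the in-network name `pname_in_net` instead of `param.name`, adding or stripping the base-model prefix so that a bare backbone and a task-specific head can load each other's checkpoints. The mapping rule can be isolated as a plain-Python sketch (the parameter names below are illustrative):

def map_name(pname_in_net, prefix, has_prefix_module, expects_prefix_module):
    # Mirrors the key-mapping branch of load_param_into_net above.
    if has_prefix_module and not expects_prefix_module:
        return prefix + '.' + pname_in_net              # checkpoint has the prefix, net does not
    if not has_prefix_module and expects_prefix_module:
        return pname_in_net.replace(f'{prefix}.', '')   # net has the prefix, checkpoint does not
    return pname_in_net

assert map_name('encoder.layer.0.weight', 'bert', True, False) == 'bert.encoder.layer.0.weight'
assert map_name('bert.encoder.layer.0.weight', 'bert', False, True) == 'encoder.layer.0.weight'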
@@ -1340,6 +1384,7 @@ def convert_torch_to_mindspore(pth_file):
                 key = key.replace('.bias', '.beta')
             if 'wpe' in key or 'wte' in key or \
                 'embeddings' in key or 'embedding' in key or \
+                'shared' in key or 'relative_attention_bias' in key or \
                 'embed_' in key or '_embed' in key and \
                 'embedding_hidden_mapping_in' not in key: # for albert
                 key = key.replace('weight', 'embedding_table')
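The added condition extends the `weight` to `embedding_table` rename in `convert_torch_to_mindspore` to keys containing `shared` or `relative_attention_bias` (e.g. T5-style embedding tables). A standalone sketch on illustrative checkpoint keys:

# Illustrative PyTorch-style keys now covered by the extra branch.
for key in ('shared.weight',
            'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'):
    if 'shared' in key or 'relative_attention_bias' in key:
        key = key.replace('weight', 'embedding_table')
    print(key)  # ends in '.embedding_table' after the rename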
@@ -1734,3 +1779,27 @@ def construct(self, hidden_states: Tensor, cls_index: Optional[Tensor] = None) -
         output = self.activation(output)
         output = self.last_dropout(output)
         return output
+
+def replace_references(old_obj, new_obj):
+    """use replace_references instead of Tensor.set_data due to mindspore errors."""
+    # Get all objects referring to old_obj
+    referrers = gc.get_referrers(old_obj)
+
+    # Replace references
+    for referrer in referrers:
+        if isinstance(referrer, dict):
+            # If the reference is in a dictionary
+            for key, value in referrer.items():
+                if value is old_obj:
+                    referrer[key] = new_obj
+        elif isinstance(referrer, list):
+            # If the reference is in a list or tuple
+            index = referrer.index(old_obj)
+            referrer[index] = new_obj
+        elif isinstance(referrer, tuple):
+            pass
+        elif hasattr(referrer, '__dict__'):
+            # If the reference is in the __dict__ of an object
+            for key, value in referrer.__dict__.items():
+                if value is old_obj:
+                    setattr(referrer, key, new_obj)
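`replace_references` swaps every discoverable reference to the old object for the new one via `gc.get_referrers`, covering dicts (including instance `__dict__`s), lists and attribute slots, while immutable tuples are skipped; it assumes `gc` is imported in this module. A plain-Python sketch of the behaviour using toy objects rather than Parameters:

class Holder:
    def __init__(self, value):
        self.value = value

old = ['old']                    # stand-in for the old Parameter
holder = Holder(old)
registry = {'param': old}

new = ['new']
replace_references(old, new)     # replace_references as defined above
# Every dict entry and attribute that pointed at `old` now points at `new`.
print(holder.value is new, registry['param'] is new)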