 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -448,6 +448,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                        (lora_config.max_loras or 1)) if lora_config else 0)
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
+        self.config = config
 
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
@@ -504,85 +505,6 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
-
-class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
-    fall_back_to_pt_during_load = False
-
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        lora_config = vllm_config.lora_config
-        self.config = config
-        self.lora_config = lora_config
-        self.quant_config = vllm_config.quant_config
-
-        self.model = PhiMoEModel(vllm_config=vllm_config,
-                                 prefix=maybe_prefix(prefix, "model"))
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=(
-                DEFAULT_VOCAB_PADDING_SIZE
-                # We need bigger padding if using lora for kernel
-                # compatibility
-                if not lora_config else lora_config.lora_vocab_padding_size),
-            quant_config=None,
-            bias=True,
-        )
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size)
-        self.sampler = get_sampler()
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -601,9 +523,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache quantization scales
@@ -667,3 +586,90 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    fall_back_to_pt_during_load = False
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        lora_config = vllm_config.lora_config
+        self.config = config
+        self.lora_config = lora_config
+        self.quant_config = vllm_config.quant_config
+
+        self.model = PhiMoEModel(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=(
+                DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size),
+            quant_config=None,
+            bias=True,
+        )
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["rotary_emb.inv_freq"]),
+        )
+        return loader.load_weights(weights)
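
For readers who have not seen the helper imported in the first hunk, the following is a minimal, self-contained sketch of the delegation pattern that the new PhiMoEForCausalLM.load_weights relies on. It is an approximation under stated assumptions, not vLLM's actual AutoWeightsLoader: the names ToyWeightsLoader, ToyInner, and ToyOuter are hypothetical, and the real helper recursively dispatches to each submodule's own load_weights rather than copying tensors directly.

# Toy approximation of the delegation pattern above -- NOT vLLM's actual
# AutoWeightsLoader. ToyWeightsLoader, ToyInner, and ToyOuter are
# hypothetical names used only for illustration.
from typing import Iterable, Set, Tuple

import torch
from torch import nn


class ToyWeightsLoader:
    """Copies checkpoint weights into a module, skipping given prefixes."""

    def __init__(self, module: nn.Module, skip_prefixes=None):
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        loaded: Set[str] = set()
        params = dict(self.module.named_parameters())
        for name, tensor in weights:
            # Substring check mirroring the removed
            # `if "rotary_emb.inv_freq" in name: continue` guard; the real
            # helper's matching rules may differ.
            if any(prefix in name for prefix in self.skip_prefixes):
                continue
            if name in params:
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded


class ToyInner(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(4, 4, bias=False)


class ToyOuter(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyInner()

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        # Mirrors the shape of the new PhiMoEForCausalLM.load_weights:
        # delegate everything to a loader configured to ignore rotary
        # inverse-frequency buffers.
        loader = ToyWeightsLoader(self,
                                  skip_prefixes=["rotary_emb.inv_freq"])
        return loader.load_weights(weights)


if __name__ == "__main__":
    outer = ToyOuter()
    ckpt = [
        ("model.embed.weight", torch.zeros(4, 4)),
        ("model.rotary_emb.inv_freq", torch.zeros(2)),  # skipped
    ]
    assert outer.load_weights(ckpt) == {"model.embed.weight"}

The point of the pattern is that checkpoint quirks such as the rotary_emb.inv_freq buffers, previously filtered inside each model's own loading loop, can be declared once via skip_prefixes while the detailed stacked-parameter handling stays in the inner model's load_weights.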