Commit 86a4522

[Model] use AutoWeightsLoader for phimoe,qwen2_moe,qwen3_moe
Signed-off-by: rongfu.leng <[email protected]>
1 parent 027b204 · commit 86a4522
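
The change is the same in all three model files: the detailed per-tensor loading loop (stacked-parameter mapping, expert handling, KV-cache scale handling) moves from the outer *ForCausalLM class onto the inner *Model class, and the wrapper's load_weights becomes a thin delegation through AutoWeightsLoader, with the old "rotary_emb.inv_freq" filtering expressed as a skip prefix. The snippet below is a self-contained toy sketch of that delegation shape only; ToyLoader, ToyInnerModel, and ToyForCausalLM are invented stand-ins, and the real AutoWeightsLoader's dispatch and skip-matching rules are more involved than shown here.

# Toy stand-in for illustration only; ToyLoader is NOT vLLM's AutoWeightsLoader.
# It only shows the delegation shape introduced by this commit: the wrapper no
# longer iterates checkpoint tensors itself, a loader object drops unwanted
# entries and forwards the rest to submodules that define their own
# load_weights (here, the inner "model").
from typing import Dict, Iterable, List, Set, Tuple

import torch
import torch.nn as nn


class ToyInnerModel(nn.Module):
    # Plays the role of PhiMoEModel / Qwen2MoeModel: it owns the detailed
    # per-parameter loading logic.

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(8, 4)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        params = dict(self.named_parameters())
        loaded: Set[str] = set()
        for name, tensor in weights:
            if name in params:
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded


class ToyLoader:
    # Skips unwanted checkpoint entries, then dispatches the rest to child
    # modules by the leading name component.

    def __init__(self, module: nn.Module, skip_substrs: List[str]) -> None:
        self.module = module
        self.skip_substrs = skip_substrs

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        per_child: Dict[str, List[Tuple[str, torch.Tensor]]] = {}
        for name, tensor in weights:
            if any(s in name for s in self.skip_substrs):
                continue  # e.g. rotary_emb.inv_freq is recomputed at runtime
            child, _, rest = name.partition(".")
            per_child.setdefault(child, []).append((rest, tensor))

        loaded: Set[str] = set()
        for child_name, child in self.module.named_children():
            if child_name in per_child and hasattr(child, "load_weights"):
                for sub in child.load_weights(per_child[child_name]):
                    loaded.add(f"{child_name}.{sub}")
        return loaded


class ToyForCausalLM(nn.Module):
    # Plays the role of the refactored *ForCausalLM wrappers.

    def __init__(self) -> None:
        super().__init__()
        self.model = ToyInnerModel()

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = ToyLoader(self, skip_substrs=["rotary_emb.inv_freq"])
        return loader.load_weights(weights)


if __name__ == "__main__":
    ckpt = [
        ("model.embed.weight", torch.zeros(8, 4)),
        ("model.layers.0.rotary_emb.inv_freq", torch.zeros(2)),  # skipped
    ]
    print(ToyForCausalLM().load_weights(ckpt))  # {'model.embed.weight'}

The effect visible in the diffs below is that the wrapper classes no longer need to know anything about checkpoint naming beyond the skip list.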

3 files changed (+219, -198 lines)

vllm/model_executor/models/phimoe.py

Lines changed: 89 additions & 83 deletions
@@ -49,7 +49,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -448,6 +448,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                        (lora_config.max_loras or 1)) if lora_config else 0)
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
+        self.config = config

         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
@@ -504,85 +505,6 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states

-
-class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
-    fall_back_to_pt_during_load = False
-
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        lora_config = vllm_config.lora_config
-        self.config = config
-        self.lora_config = lora_config
-        self.quant_config = vllm_config.quant_config
-
-        self.model = PhiMoEModel(vllm_config=vllm_config,
-                                 prefix=maybe_prefix(prefix, "model"))
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=(
-                DEFAULT_VOCAB_PADDING_SIZE
-                # We need bigger padding if using lora for kernel
-                # compatibility
-                if not lora_config else lora_config.lora_vocab_padding_size),
-            quant_config=None,
-            bias=True,
-        )
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size)
-        self.sampler = get_sampler()
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -601,9 +523,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache quantization scales
@@ -667,3 +586,90 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    fall_back_to_pt_during_load = False
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        lora_config = vllm_config.lora_config
+        self.config = config
+        self.lora_config = lora_config
+        self.quant_config = vllm_config.quant_config
+
+        self.model = PhiMoEModel(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=(
+                DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size),
+            quant_config=None,
+            bias=True,
+        )
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["rotary_emb.inv_freq"]),
+        )
+        return loader.load_weights(weights)
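
One detail worth noting in the hunks above: PhiMoEModel.__init__ gains `self.config = config` (Qwen2MoeModel gets the same line in the next file). Since load_weights now lives on the inner Model class, any configuration the loading loop consults has to be reachable from that instance rather than from the wrapper; reading the addition that way is an inference from the diff, not something stated in the commit message.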

vllm/model_executor/models/qwen2_moe.py

Lines changed: 65 additions & 57 deletions
@@ -55,7 +55,8 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
-from .utils import (extract_layer_index, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -329,6 +330,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config

         self.vocab_size = config.vocab_size
+        self.config = config

         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -377,60 +379,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

-
-class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Qwen2MoeModel(vllm_config=vllm_config,
-                                   prefix=maybe_prefix(prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -453,8 +401,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
@@ -531,3 +477,65 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen2MoeModel(vllm_config=vllm_config,
+                                   prefix=maybe_prefix(prefix, "model"))
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["rotary_emb.inv_freq"]),
+        )
+        return loader.load_weights(weights)
