diff --git a/tests/kernels/moe/test_expert_usage_histogram.py b/tests/kernels/moe/test_expert_usage_histogram.py
new file mode 100644
index 00000000000..378b32309c8
--- /dev/null
+++ b/tests/kernels/moe/test_expert_usage_histogram.py
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.fused_moe.utils import (
+    collect_expert_usage_histogram)
+
+
+@pytest.mark.parametrize("topk_experts,expert_count,topk_ids_dtype",
+                         [(4, 32, torch.int32), (1, 1, torch.int64)])
+@pytest.mark.parametrize("token_count", [256, 7])
+def test_collect_expert_usage_histogram(topk_experts: int, expert_count: int,
+                                        token_count: int,
+                                        topk_ids_dtype: torch.dtype):
+    device = torch.device('cuda')
+
+    # Make a uniform distribution of expert usage
+    topk_ids = torch.stack([torch.arange(topk_experts, dtype=topk_ids_dtype)] *
+                           token_count)
+
+    topk_ids_gpu = topk_ids.to(device)
+
+    expert_usage_histogram_gpu = torch.zeros(expert_count,
+                                             dtype=torch.int32,
+                                             device=device)
+
+    collect_expert_usage_histogram(topk_ids_gpu, expert_usage_histogram_gpu)
+
+    # Every expert is used equally often, so each expert referenced in
+    # topk_ids should have a count of exactly token_count.
+    assert torch.equal(
+        expert_usage_histogram_gpu[:topk_experts],
+        torch.full([topk_experts],
+                   token_count,
+                   dtype=torch.int32,
+                   device=device))
+
+    # The rest of the experts weren't used, so they should be zero.
+    assert expert_usage_histogram_gpu[topk_experts:].sum() == 0
+
+
+@pytest.mark.parametrize("topk_experts,expert_count", [(16, 32)])
+@pytest.mark.parametrize("token_count", [1])
+@pytest.mark.parametrize("seed", [0xDEADBEEF, 0xCAFEBABE])
+def test_collect_expert_usage_histogram_random(topk_experts: int,
+                                               expert_count: int,
+                                               token_count: int, seed: int):
+    device = torch.device('cuda')
+
+    generator = torch.Generator()
+    generator.manual_seed(seed)
+
+    # Make a random distribution of expert usage
+    topk_ids_cpu = torch.stack(
+        [torch.randperm(topk_experts, generator=generator, dtype=torch.int32)
+         ] * token_count)
+
+    # Compute the ground truth with torch.histogram
+    torch_histogram = torch.histogram(topk_ids_cpu.to(torch.float),
+                                      bins=expert_count,
+                                      range=(0, expert_count - 1))
+
+    # Compute the histogram with the kernel under test
+    expert_usage_histogram_gpu = torch.zeros(expert_count,
+                                             dtype=torch.int32,
+                                             device=device)
+
+    topk_ids_gpu = topk_ids_cpu.to(device)
+
+    collect_expert_usage_histogram(topk_ids_gpu, expert_usage_histogram_gpu)
+
+    assert torch.equal(expert_usage_histogram_gpu,
+                       torch_histogram.hist.to(torch.int32).to(device))
diff --git a/vllm/config.py b/vllm/config.py
index 96ea47a0dce..77fa4dce562 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -990,7 +990,7 @@ def _verify_bnb_config(self) -> None:
 
             self.enforce_eager = True
 
-    def _verify_with_expert_parallelism(self) -> None:
+    def get_total_num_experts(self) -> int:
         num_expert_names = [
             "moe_num_experts",  # Dbrx
             "num_experts",  # Jamba
@@ -1002,7 +1002,10 @@ def _verify_with_expert_parallelism(self) -> None:
             num_experts = getattr(self.hf_text_config, name, 0)
             if num_experts > 0:
                 break
-        if num_experts < 1:
+        return num_experts
+
+    def _verify_with_expert_parallelism(self) -> None:
+        if self.get_total_num_experts() < 1:
             raise ValueError(
                 "Number of experts in the model must be greater than 0 "
                 "when expert parallelism is enabled.")
@@ -1231,9 +1234,7 @@ def get_num_attention_heads(self,
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size
 
-    def get_layers_start_end_indices(
-            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
-        from vllm.distributed.utils import get_pp_indices
+    def get_total_num_hidden_layers(self) -> int:
         if (self.hf_text_config.model_type == "deepseek_mtp"
                 or self.hf_config.model_type == "mimo_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
@@ -1241,6 +1242,13 @@ def get_layers_start_end_indices(
         else:
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_hidden_layers", 0)
+        return total_num_hidden_layers
+
+    def get_layers_start_end_indices(
+            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
+        from vllm.distributed.utils import get_pp_indices
+        total_num_hidden_layers = self.get_total_num_hidden_layers()
+
         # the layout order is: DP x PP x TP
         pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
                    ) % parallel_config.pipeline_parallel_size
diff --git a/vllm/envs.py b/vllm/envs.py
index 04c80807cd4..d8542a566f9 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -133,6 +133,7 @@
     VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
     VLLM_KV_CACHE_LAYOUT: Optional[str] = None
     VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
+    VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM: bool = False
 
 
 def get_default_cache_root():
@@ -918,6 +919,10 @@ def get_vllm_port() -> Optional[int]:
     # or bad hardware but it may add compute overhead.
     "VLLM_COMPUTE_NANS_IN_LOGITS":
     lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))),
+
+    # Collect a per-layer histogram of MoE expert routing (expert usage)
+    "VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM":
+    lambda: bool(int(os.getenv("VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index dd55b19feea..1b4a998811c 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -95,6 +95,8 @@ class ForwardContext:
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
    skip_cuda_graphs: bool = False
+    # Set when recording the expert usage histogram
+    expert_usage_histogram: Optional[torch.Tensor] = None
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -116,6 +118,7 @@ def set_forward_context(
     num_tokens: Optional[int] = None,
     num_tokens_across_dp: Optional[torch.Tensor] = None,
     skip_cuda_graphs: bool = False,
+    expert_usage_histogram: Optional[torch.Tensor] = None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -132,6 +135,9 @@ def set_forward_context(
             attn_metadata, num_tokens or 0, num_tokens_across_dp)
 
+    if expert_usage_histogram is not None:
+        expert_usage_histogram.zero_()
+
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
@@ -141,6 +147,7 @@
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
         skip_cuda_graphs=skip_cuda_graphs,
+        expert_usage_histogram=expert_usage_histogram,
     )
 
     try:
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index c1bae033c2b..ea152d221bd 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -25,8 +25,11 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
+from vllm.model_executor.layers.fused_moe.utils import (
+    collect_expert_usage_histogram)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
@@ -423,6 +426,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -562,6 +566,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -579,6 +584,7 @@ def apply(
             router_logits=router_logits,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             use_grouped_topk=use_grouped_topk,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
@@ -598,6 +604,7 @@ def forward_cuda(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         global_num_experts: int = -1,
@@ -615,6 +622,7 @@
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
@@ -654,6 +662,7 @@ def forward_cpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         global_num_experts: int = -1,
@@ -688,6 +697,7 @@ def forward_hpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         global_num_experts: int = -1,
@@ -721,6 +731,7 @@ def forward_tpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         global_num_experts: int = -1,
@@ -867,6 +878,8 @@ def __init__(
             compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix
 
+        self.layer_index = extract_layer_index(prefix)
+
         # Determine expert maps
         if self.use_ep:
             self.local_num_experts, self.expert_map = determine_expert_map(
@@ -1288,6 +1301,7 @@ def select_experts(hidden_states: torch.Tensor,
                        top_k: int,
                        use_grouped_topk: bool,
                        renormalize: bool,
+                       layer_index: int,
                        topk_group: Optional[int] = None,
                        num_expert_group: Optional[int] = None,
                       custom_routing_function: Optional[Callable] = None,
@@ -1328,6 +1342,12 @@ def select_experts(hidden_states: torch.Tensor,
         if indices_type is not None:
             topk_ids = topk_ids.to(dtype=indices_type)
 
+        expert_usage_histogram = get_forward_context().expert_usage_histogram
+
+        if expert_usage_histogram is not None:
+            collect_expert_usage_histogram(topk_ids,
+                                           expert_usage_histogram[layer_index])
+
         return topk_weights, topk_ids
 
     def must_reduce_shared_expert_outputs(self) -> bool:
@@ -1360,7 +1380,8 @@ def maybe_all_reduce_tensor_model_parallel(
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                          self.layer_name)
+                                          self.layer_name,
+                                          self.layer_index)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
@@ -1399,6 +1420,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 router_logits=staged_router_logits,
                 top_k=self.top_k,
                 renormalize=self.renormalize,
+                layer_index=self.layer_index,
                 use_grouped_topk=self.use_grouped_topk,
                 global_num_experts=self.global_num_experts,
                 expert_map=self.expert_map,
@@ -1435,7 +1457,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
         return full_final_hidden_states
 
     def forward_impl(self, hidden_states: torch.Tensor,
-                     router_logits: torch.Tensor):
+                     router_logits: torch.Tensor, layer_index: int):
         assert self.quant_method is not None
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels):
@@ -1455,6 +1477,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
             router_logits=router_logits,
             top_k=self.top_k,
             renormalize=self.renormalize,
+            layer_index=layer_index,
             use_grouped_topk=self.use_grouped_topk,
             global_num_experts=self.global_num_experts,
             expert_map=self.expert_map,
@@ -1517,16 +1540,16 @@ def extra_repr(self) -> str:
 
 def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
-                layer_name: str) -> torch.Tensor:
+                layer_name: str, layer_index: int) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
     assert self.quant_method is not None
-    return self.forward_impl(hidden_states, router_logits)
+    return self.forward_impl(hidden_states, router_logits, layer_index)
 
 
 def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
-                     layer_name: str) -> torch.Tensor:
+                     layer_name: str, layer_index: int) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 692482c2ea6..be1ec36d3d5 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -4,6 +4,8 @@
 from typing import Optional
 
 import torch
+import triton
+import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -13,6 +15,62 @@
 from vllm.utils import cdiv
 
 
+@triton.jit
+def _collect_expert_usage_histogram(
+        topk_experts_ptr,  # [M, K]
+        histogram_ptr,  # [E]
+        M: int,
+        K: tl.constexpr,
+        stride_m: int,
+        stride_k: int,
+        E: tl.constexpr,
+        stride_e: int,
+        BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    topk_experts_load_offsets = topk_experts_ptr + pid * stride_m + tl.arange(
+        0, K) * stride_k
+
+    expert_indices = tl.load(topk_experts_load_offsets).cast(dtype=tl.int32)
+
+    expert_usage_histogram_layer = expert_indices.histogram(E)
+
+    histogram_store_offsets = histogram_ptr + tl.arange(0, E) * stride_e
+
+    tl.atomic_add(histogram_store_offsets,
+                  expert_usage_histogram_layer,
+                  mask=tl.arange(0, BLOCK_SIZE) < E)
+
+
+def collect_expert_usage_histogram(
+        topk_experts: torch.Tensor,
+        expert_usage_histogram_layer: torch.Tensor) -> None:
+    assert len(topk_experts.shape) == 2
+    M = topk_experts.shape[0]
+    K = topk_experts.shape[1]
+
+    E = expert_usage_histogram_layer.shape[0]
+
+    block_size = triton.next_power_of_2(E)
+    assert block_size == E  # Don't allow padding
+
+    assert block_size >= K
+    assert block_size >= E
+
+    _collect_expert_usage_histogram[(M, )](
+        topk_experts,
+        expert_usage_histogram_layer,
+        M=M,
+        K=K,
+        stride_m=topk_experts.stride(0),
+        stride_k=topk_experts.stride(1),
+        E=E,
+        stride_e=expert_usage_histogram_layer.stride(0),
+        BLOCK_SIZE=block_size,
+    )
+
+
 def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
     """
     Shrink the given tensor and apply the given view to it. This is
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 56d803c6baf..a773492b201 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -472,6 +472,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -496,6 +497,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index f14131c5f05..caa51c6256a 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -321,6 +321,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -339,6 +340,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
@@ -583,6 +585,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -601,6 +604,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
@@ -712,6 +716,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -731,6 +736,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
@@ -1002,6 +1008,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -1024,6 +1031,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
@@ -1218,6 +1226,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -1237,6 +1246,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 01b0064f080..a7889882466 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -107,6 +107,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -126,6 +127,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b3042bfaed3..0dc48fdcc97 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -815,6 +815,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -833,6 +834,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 9c8f74545d3..c8e6676e092 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -510,6 +510,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -533,6 +534,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index e9b8dc3266b..d9fe86524c7 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -625,6 +625,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -648,6 +649,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 3f79b203aa1..bbd563d9682 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -654,6 +654,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -672,6 +673,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
@@ -708,6 +710,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 3aa23f06825..39ca5b18e51 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -287,6 +287,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -306,6 +307,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 4c2da4c8b04..86b7bf84544 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -195,6 +195,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -214,6 +215,7 @@ def apply(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 00b0844a566..2f4281158c6 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -8,6 +8,8 @@
 from collections.abc import Iterable
 from typing import Any, Optional, Union
 
+import torch
+
 from vllm.config import VllmConfig
 from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
 from vllm.distributed.kv_transfer.kv_connector.factory import (
@@ -906,7 +908,9 @@ def update_from_output(
 
         if engine_core_outputs:
             # Return stats to only one of the front-ends.
             next(iter(engine_core_outputs.values())).scheduler_stats = (
-                self.make_stats(spec_decoding_stats))
+                self.make_stats(
+                    spec_decoding_stats,
+                    model_runner_output.expert_usage_histogram_cpu))
 
         return engine_core_outputs
 
@@ -1000,6 +1004,7 @@ def reset_prefix_cache(self) -> bool:
     def make_stats(
         self,
         spec_decoding_stats: Optional[SpecDecodingStats] = None,
+        expert_usage_histogram_cpu: Optional[torch.Tensor] = None,
     ) -> Optional[SchedulerStats]:
         if not self.log_stats:
             return None
@@ -1013,6 +1018,7 @@ def make_stats(
             spec_decoding_stats=spec_decoding_stats,
             num_corrupted_reqs=sum(req.is_output_corrupted
                                    for req in self.running),
+            expert_usage_histogram_cpu=expert_usage_histogram_cpu,
         )
 
     def make_spec_decoding_stats(
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index c720ca13e51..e90622c58e6 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -9,6 +9,7 @@
 import numpy as np
 import prometheus_client
 
+import vllm.envs as envs
 from vllm.config import SupportsMetricsInfo, VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
@@ -374,6 +375,24 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
             buckets=request_latency_buckets,
             labelnames=labelnames).labels(*labelvalues)
 
+        if envs.VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM:
+            layer_count = vllm_config.model_config.get_total_num_hidden_layers(
+            )
+            expert_count = vllm_config.model_config.get_total_num_experts()
+
+            moe_expert_selection = self._counter_cls(
+                name="vllm:moe_expert_selection_counter",
+                documentation=
+                "Count of MoE expert selections, labeled by layer and expert.",
+                labelnames=labelnames + ['layer', 'expert'])
+
+            self.counter_moe_expert_selection_array = [[
+                moe_expert_selection.labels(
+                    *(labelvalues +
+                      [str(layer_index), str(expert_index)]))
+                for expert_index in range(expert_count)
+            ] for layer_index in range(layer_count)]
+
         #
         # LoRA metrics
         #
@@ -444,6 +463,14 @@ def record(self, scheduler_stats: Optional[SchedulerStats],
             self.spec_decoding_prom.observe(
                 scheduler_stats.spec_decoding_stats)
 
+        if scheduler_stats.expert_usage_histogram_cpu is not None:
+            histogram = scheduler_stats.expert_usage_histogram_cpu
+
+            for i in range(histogram.shape[0]):
+                for j in range(histogram.shape[1]):
+                    self.counter_moe_expert_selection_array[i][j].inc(
+                        histogram[i, j].item())
+
         if iteration_stats is None:
             return
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 1eb10ccb6c4..21caf1700ed 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -5,6 +5,8 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Optional
 
+import torch
+
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 if TYPE_CHECKING:
@@ -42,6 +44,8 @@ class SchedulerStats:
 
     num_corrupted_reqs: int = 0
 
+    expert_usage_histogram_cpu: Optional[torch.Tensor] = None
+
 
 @dataclass
 class LoRAStats:
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f78623f571b..d0dc5df5c14 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -111,6 +111,9 @@ class ModelRunnerOutput:
     # req_id -> num_nans_in_logits
     num_nans_in_logits: Optional[dict[str, int]] = None
 
+    # [num_layer, num_expert]
+    expert_usage_histogram_cpu: Optional[torch.Tensor] = None
+
 
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               req_id_to_index={},
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 40639fdf243..a8ab5586cdf 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -301,6 +301,19 @@ def __init__(
             pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
+        self.expert_usage_histogram: Optional[torch.Tensor] = None
+
+        if envs.VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM:
+            logger.warning_once(
+                "Collecting the per-layer expert routing histogram; "
+                "this may negatively affect performance.")
+
+            self.expert_usage_histogram = torch.zeros(
+                model_config.get_total_num_hidden_layers(),
+                model_config.get_total_num_experts(),
+                dtype=torch.int32,
+                device=self.device)
+
         # Layer pairings for cross-layer KV sharing.
         # If an Attention layer `layer_name` is in the keys of this dict, it
         # means this layer will perform attention using the keys and values
@@ -1344,6 +1357,7 @@ def execute_model(
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
                 skip_cuda_graphs=skip_cuda_graphs,
+                expert_usage_histogram=self.expert_usage_histogram,
         ):
             self.maybe_setup_kv_connector(scheduler_output)
 
@@ -1591,6 +1605,10 @@ def execute_model(
             )
             spec_token_ids = draft_token_ids.tolist()
 
+        expert_usage_histogram_cpu: Optional[torch.Tensor] = None
+        if self.expert_usage_histogram is not None:
+            expert_usage_histogram_cpu = self.expert_usage_histogram.cpu()
+
         # Clear KVConnector state after all KVs are generated.
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
@@ -1606,6 +1624,7 @@ def execute_model(
             finished_sending=finished_sending,
             finished_recving=finished_recving,
             num_nans_in_logits=num_nans_in_logits,
+            expert_usage_histogram_cpu=expert_usage_histogram_cpu,
         )
 
     def kv_connector_no_forward(
@@ -1967,7 +1986,8 @@ def _dummy_run(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
-                num_tokens_across_dp=num_tokens_across_dp):
+                num_tokens_across_dp=num_tokens_across_dp,
+                expert_usage_histogram=self.expert_usage_histogram):
             outputs = model(
                 input_ids=input_ids,
                 positions=positions,
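For reference, the per-layer quantity that collect_expert_usage_histogram accumulates on the GPU can be reproduced with plain PyTorch. This is a minimal sketch outside the patch itself; the function name and shapes are illustrative and mirror the new test rather than any code in this diff:

import torch

def expert_usage_histogram_reference(topk_ids: torch.Tensor,
                                     num_experts: int) -> torch.Tensor:
    # topk_ids: [num_tokens, top_k] tensor of selected expert indices.
    # Returns a [num_experts] int32 tensor counting how often each expert was
    # routed to, i.e. one layer's row of the histogram that the Triton kernel
    # accumulates with atomic adds.
    return torch.bincount(topk_ids.reshape(-1).to(torch.int64),
                          minlength=num_experts).to(torch.int32)

With VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM=1, these per-layer counts are zeroed at the start of each forward pass in set_forward_context, accumulated in select_experts, copied to the CPU in execute_model, and exported through the vllm:moe_expert_selection_counter metric labeled by layer and expert.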