diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 78bfda9bcf4..34edfebc65d 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -16,7 +16,7 @@ def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m") + llm = LLM(model="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", trust_remote_code=True) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index b2f54f37a6e..135db4a35a0 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -564,6 +564,8 @@ def forward( """ assert output is not None, "Output tensor must be provided." + print("kv_cache.shape = {}".format(kv_cache.shape)) + if output_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1878ae74dbc..a9b57eb852d 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -212,6 +212,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable +from vllm.config import VllmConfig, get_layers_from_vllm_config + try: from vllm.vllm_flash_attn import flash_attn_varlen_func is_vllm_fa = True @@ -225,6 +227,9 @@ from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.attention.layer import Attention +from flashinfer import BatchPrefillWithRaggedKVCacheWrapper + logger = init_logger(__name__) @@ -278,6 +283,77 @@ class ChunkedContextMetadata: chunked_context: Optional[ChunkedContextMetadata] = None +@dataclass +class FIPrefillMetadata: + + # num_actual_tokens: int # Number of tokens excluding padding. + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + # qo_indptr: torch.Tensor + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + # paged_kv_indptr: torch.Tensor + # The page indices of the paged kv cache + # paged_kv_indices: torch.Tensor + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + # paged_kv_last_page_len: torch.Tensor + # The number of query/output heads + # num_qo_heads: int + # The number of key/value heads + # num_kv_heads: int + # The dimension of the attention heads + # head_dim: int + # Block size of vllm + # page_size: int + # The data type of the paged kv cache + # data_type: torch.dtype + # The data type of the query + # q_data_type: torch.dtype + + # slot_mapping: torch.Tensor + + # For handling prefill decode split + # num_decodes: int + # num_decode_tokens: int + # num_prefills: int + # num_prefill_tokens: int + + # # For cascade attention. 
+ # use_cascade: bool + # shared_qo_indptr: Optional[torch.Tensor] = None + # shared_kv_page_indptr: Optional[torch.Tensor] = None + # shared_kv_page_indices: Optional[torch.Tensor] = None + # shared_kv_last_page_len: Optional[torch.Tensor] = None + + prefill_wrapper: Optional[BatchPrefillWithRaggedKVCacheWrapper] = None + prefill_wrapper_ctx: Optional[BatchPrefillWithRaggedKVCacheWrapper] = None + + # @property + # def query_start_loc(self): + # # The GPUModelRunner expects to be able to access this property. + # return self.qo_indptr + + # def __post_init__(self): + # # Refer to + # # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + # supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f" received {self.head_dim}.") + + @dataclass class MLACommonDecodeMetadata: block_table: torch.Tensor @@ -317,6 +393,7 @@ class MLACommonMetadata(Generic[D]): decode: Optional[D] = None prefill: Optional[MLACommonPrefillMetadata] = None + fi_prefill: Optional[FIPrefillMetadata] = None def __post_init__(self): supported_head_sizes = MLACommonBackend.get_supported_head_sizes() @@ -329,6 +406,72 @@ def __post_init__(self): M = TypeVar("M", bound=MLACommonMetadata) +FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024 + + +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = get_layers_from_vllm_config(vllm_config, Attention) + per_layer_params: dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + impl = layer.impl + assert isinstance(impl, MLACommonImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." 
+ + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ @@ -384,6 +527,204 @@ def __init__(self, ) self.block_table = block_table + # FI + self._workspace_buffer = None + self._prefill_wrapper = None # Wrapper for prefill/append + self._prefill_wrapper_ctx = [] # Wrapper for prefill/append + self.global_hyperparameters = None + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithRaggedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + + return self._prefill_wrapper + + def _get_prefill_wrapper_ctx(self, num_chunks): + if len(self._prefill_wrapper_ctx) < num_chunks: + for _ in range(len(self._prefill_wrapper_ctx), num_chunks): + self._prefill_wrapper_ctx.append( + BatchPrefillWithRaggedKVCacheWrapper( + self._get_workspace_buffer(), "NHD")) + + return self._prefill_wrapper_ctx + + def _build_fi_prefill(self, common_attn_metadata: CommonAttentionMetadata, + attn_metadata: MLACommonMetadata): + # print("INSIDE _build_fi_prefill") + if self.global_hyperparameters is None: + self.global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(self.runner.vllm_config)) + + assert attn_metadata.prefill is not None + qo_indptr = attn_metadata.prefill.query_start_loc + + # print(" qo_indptr.shape = {} qo_indptr = {}".format(qo_indptr.shape, qo_indptr)) + + # slot_mapping = attn_metadata.slot_mapping + + # num_reqs = common_attn_metadata.num_reqs + # num_actual_tokens = common_attn_metadata.num_actual_tokens + # print(" num_reqs = {} num_actual_tokens = {}".format(num_reqs, num_actual_tokens)) + + # assert self._num_decodes + self._num_prefills == num_reqs + # assert (self._num_decode_tokens + + # self._num_prefill_tokens == num_actual_tokens) + + # page_size = self.kv_cache_spec.block_size + # device = self.runner.device + # print(" page_size = {}".format(page_size)) + + # prefill_seq_lens = common_attn_metadata.seq_lens[self._num_decodes:] + # print(" prefill_seq_lens.shape = {} prefill_seq_lens = {}".format(prefill_seq_lens.shape, prefill_seq_lens)) + + # prefill_block_table_bounds = (prefill_seq_lens + page_size - + # 1) // page_size + # print(" prefill_block_table_bounds.shape = {} prefill_block_table_bounds = {}".format(prefill_block_table_bounds.shape, prefill_block_table_bounds)) + + # prefill_block_table = attn_metadata.prefill.block_table + # print(" prefill_block_table.shape = {}".format(prefill_block_table.shape)) + + # mask = (torch.arange(prefill_block_table.size(1), + # dtype=prefill_block_table.dtype, + # device=prefill_block_table.device).unsqueeze(0) + # < prefill_block_table_bounds.unsqueeze(1)) + # # print(" mask.shape = {} mask = {}".format(mask.shape, mask)) + + # prefill_paged_kv_indices = prefill_block_table[mask] + # # print(" prefill_paged_kv_indices.shape = {} prefill_paged_kv_indices = {}".format(prefill_paged_kv_indices.shape, prefill_paged_kv_indices)) + + # prefill_paged_kv_indptr = torch.cat([ + 
# torch.zeros(1, + # dtype=prefill_block_table_bounds.dtype, + # device=prefill_block_table_bounds.device), + # prefill_block_table_bounds.cumsum(dim=0, dtype=torch.int32) + # ]) + + # # print(" prefill_paged_kv_indptr.shape = {} prefill_paged_kv_indptr = {}".format(prefill_paged_kv_indptr.shape, prefill_paged_kv_indptr)) + + # prefill_paged_kv_last_page_len = prefill_seq_lens % page_size + # prefill_paged_kv_last_page_len = torch.where( + # prefill_paged_kv_last_page_len == 0, page_size, + # prefill_paged_kv_last_page_len) + + # print(" prefill_paged_kv_last_page_len.shape = {} prefill_paged_kv_last_page_len = {}".format(prefill_paged_kv_last_page_len.shape, prefill_paged_kv_last_page_len)) + + prefill_wrapper = self._get_prefill_wrapper() + + has_context = attn_metadata.prefill.chunked_context is not None + + if has_context: + num_chunks = attn_metadata.prefill.chunked_context.cu_seq_lens.shape[ + 0] + prefill_wrapper_ctx = self._get_prefill_wrapper_ctx(num_chunks) + else: + prefill_wrapper_ctx = [] + num_qo_heads = self.runner.num_query_heads + num_kv_heads = self.kv_cache_spec.num_kv_heads + head_dim_qk = self.kv_cache_spec.head_size + + # print("num_qo_heads = {}".format(num_qo_heads)) + # print("num_kv_heads = {}".format(num_kv_heads)) + # print("head_dim_qk = {}".format(head_dim_qk)) + # print("global_hyperparameters.sm_scale = {}".format(self.global_hyperparameters.sm_scale)) + # print("global_hyperparameters.window_left = {}".format(self.global_hyperparameters.window_left)) + # print("global_hyperparameters.logits_soft_cap = {}".format(self.global_hyperparameters.logits_soft_cap)) + + kv_indptr = qo_indptr.clone() + + prefill_wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + 192, #head_dim_qk, + causal=True, + head_dim_vo=128, + + # qo_indptr, + # prefill_paged_kv_indptr, + # prefill_paged_kv_indices, + # prefill_paged_kv_last_page_len, + # num_qo_heads, + # num_kv_heads, + # head_dim_qk, + # page_size, + # 128, + # causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + q_data_type=self.runner.dtype, + kv_data_type=self.kv_cache_spec.dtype, + ) + + if has_context: + for i in range(num_chunks): + kv_indptr_ctx = attn_metadata.prefill.chunked_context.cu_seq_lens[ + i] + + prefill_wrapper_ctx[i].plan( + qo_indptr, + kv_indptr_ctx, + num_qo_heads, + num_kv_heads, + 192, #head_dim_qk, + causal=False, + head_dim_vo=128, + + # qo_indptr, + # prefill_paged_kv_indptr, + # prefill_paged_kv_indices, + # prefill_paged_kv_last_page_len, + # num_qo_heads, + # num_kv_heads, + # head_dim_qk, + # page_size, + # 128, + # causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. 
+ logits_soft_cap, + q_data_type=self.runner.dtype, + kv_data_type=self.kv_cache_spec.dtype, + ) + + attn_metadata.fi_prefill = attn_metadata = FIPrefillMetadata( + # num_actual_tokens=0, #num_actual_tokens, + # qo_indptr=qo_indptr, + # paged_kv_indptr=prefill_paged_kv_indptr, + # paged_kv_indices=prefill_paged_kv_indices, + # paged_kv_last_page_len=prefill_paged_kv_last_page_len, + # num_qo_heads=self.runner.num_query_heads, + # num_kv_heads=self.kv_cache_spec.num_kv_heads, + # head_dim=self.kv_cache_spec.head_size, + # page_size=page_size, + # data_type=self.kv_cache_spec.dtype, + # q_data_type=self.runner.dtype, + # slot_mapping=slot_mapping, + # num_decodes=self._num_decodes, + # num_decode_tokens=self._num_decode_tokens, + # num_prefills=self._num_prefills, + # num_prefill_tokens=self._num_prefill_tokens, + # use_cascade=False, + # shared_qo_indptr=None, + # shared_kv_page_indptr=None, + # shared_kv_page_indices=None, + # shared_kv_last_page_len=None, + prefill_wrapper=prefill_wrapper, + prefill_wrapper_ctx=prefill_wrapper_ctx, + ) + def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are and @@ -578,7 +919,7 @@ def build(self, common_prefix_len: int, seq_lens=seq_lens[:self._num_decodes], ) - return self.metadata_cls( + attn_metadata = self.metadata_cls( num_actual_tokens=num_actual_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, @@ -591,6 +932,11 @@ def build(self, common_prefix_len: int, decode=decode_metadata, ) + if self._num_prefills > 0: + self._build_fi_prefill(common_attn_metadata, attn_metadata) + + return attn_metadata + def can_run_in_cudagraph( self, common_attn_metadata: CommonAttentionMetadata) -> bool: return common_attn_metadata.max_query_len == 1 @@ -660,6 +1006,16 @@ def __init__( self.vllm_flash_attn_version == 3 and current_platform.get_device_capability()[0] == 9) + # print("!!! 
_pad_v = {}".format(self._pad_v)) + self._pad_v = False + # FI + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + + self.logits_soft_cap = logits_soft_cap + def _flash_attn_varlen_diff_headdims(self, q, k, @@ -692,6 +1048,58 @@ def _flash_attn_varlen_diff_headdims(self, return attn_out, lse return attn_out + def _fi_prefill_run(self, + q, + k, + v, + kv_cache, + prefill_wrapper, + return_softmax_lse=False, + softmax_scale=None, + layer=None, + **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + # print("q.shape = {}".format(q.shape)) + # print("k.shape = {}".format(k.shape)) + # print("v.shape = {}".format(v.shape)) + # print("maybe_padded_v.shape = {}".format(maybe_padded_v.shape)) + + # print("self.scale = {}".format(self.scale)) + # print("return_softmax_lse = {}".format(return_softmax_lse)) + attn_out = prefill_wrapper.run( + q, + k, + maybe_padded_v, + # k_scale=layer._k_scale_float, + # v_scale=layer._v_scale_float, + return_lse=return_softmax_lse, + ) + + # attn_out = self.flash_attn_varlen_func( + # q=q, + + # k=k, + # v=maybe_padded_v, + # return_softmax_lse=return_softmax_lse, + # softmax_scale=softmax_scale, + # **kwargs, + # ) + + # Unpack the output if there is multiple results + lse = None + if isinstance(attn_out, tuple): + attn_out, lse = attn_out[0], attn_out[1] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + return attn_out, lse + return attn_out + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -790,19 +1198,40 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - attn_output, attn_softmax_lse = \ - self._flash_attn_varlen_diff_headdims( + print("_compute_prefill_context") + print(" q.shape = {}".format(q.shape)) + print(" k.shape = {}".format(k.shape)) + print(" v.shape = {}".format(v.shape)) + + attn_output, attn_softmax_lse = self._fi_prefill_run( q=q, k=k, v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], + kv_cache=kv_c_and_k_pe_cache, + prefill_wrapper=attn_metadata.fi_prefill. 
+ prefill_wrapper_ctx[i], + cu_seqlens_q=attn_metadata.prefill.query_start_loc, + cu_seqlens_k=attn_metadata.prefill.query_start_loc, + max_seqlen_q=attn_metadata.prefill.max_query_len, + max_seqlen_k=attn_metadata.prefill.max_query_len, softmax_scale=self.scale, - causal=False, # Context is unmasked + causal=False, return_softmax_lse=True, + layer=None, ) + # attn_output, attn_softmax_lse = \ + # self._flash_attn_varlen_diff_headdims( + # q=q, + # k=k, + # v=v, + # cu_seqlens_q=prefill_metadata.query_start_loc, + # cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], + # max_seqlen_q=prefill_metadata.max_query_len, + # max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], + # softmax_scale=self.scale, + # causal=False, # Context is unmasked + # return_softmax_lse=True, + # ) if output is None: output = attn_output @@ -830,8 +1259,10 @@ def _forward_prefill( k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer, ) -> torch.Tensor: assert attn_metadata.prefill is not None + assert attn_metadata.fi_prefill is not None has_context = attn_metadata.prefill.chunked_context is not None kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ @@ -841,10 +1272,19 @@ def _forward_prefill( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - output = self._flash_attn_varlen_diff_headdims( + # print("has_context = {}".format(has_context)) + + print("_forward_prefill") + print(" q.shape = {}".format(q.shape)) + print(" k.shape = {}".format(k.shape)) + print(" v.shape = {}".format(v.shape)) + + output = self._fi_prefill_run( q=q, k=k, v=v, + kv_cache=kv_c_and_k_pe_cache, + prefill_wrapper=attn_metadata.fi_prefill.prefill_wrapper, cu_seqlens_q=attn_metadata.prefill.query_start_loc, cu_seqlens_k=attn_metadata.prefill.query_start_loc, max_seqlen_q=attn_metadata.prefill.max_query_len, @@ -852,7 +1292,20 @@ def _forward_prefill( softmax_scale=self.scale, causal=True, return_softmax_lse=has_context, + layer=layer, ) + # output = self._flash_attn_varlen_diff_headdims( + # q=q, + # k=k, + # v=v, + # cu_seqlens_q=attn_metadata.prefill.query_start_loc, + # cu_seqlens_k=attn_metadata.prefill.query_start_loc, + # max_seqlen_q=attn_metadata.prefill.max_query_len, + # max_seqlen_k=attn_metadata.prefill.max_query_len, + # softmax_scale=self.scale, + # causal=True, + # return_softmax_lse=has_context, + # ) if has_context: suffix_output, suffix_lse = output @@ -896,6 +1349,9 @@ def forward( output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + # print("INSIDE forward") + # print(" kv_cache.shape = {}".format(kv_cache.shape)) + assert output is not None, "Output tensor must be provided." if output_scale is not None: @@ -946,7 +1402,7 @@ def forward( if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer) if has_decode: assert attn_metadata.decode is not None
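
Reviewer note on the FlashInfer prefill path added above: _build_fi_prefill() plans one BatchPrefillWithRaggedKVCacheWrapper for the new-token (causal) portion of each prefill and one wrapper per chunked-context chunk (non-causal), and _fi_prefill_run() then executes the planned wrapper in place of flash_attn_varlen_func. The standalone sketch below shows that plan/run pattern in isolation; it is not part of the patch. Workspace size, head counts, sequence lengths and dtypes are illustrative assumptions, and the 192/128 head dims simply mirror the values hard-coded in the patch (MLA qk head dim including the rope part vs. the vo head dim). The plan() keyword names follow the calls in the patch and may differ across flashinfer releases.

# Standalone sketch (assumptions noted above), not part of the patch.
import torch
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper

device = "cuda"
# Workspace buffer for the wrapper; 128 MiB here is an illustrative size.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

num_qo_heads = num_kv_heads = 16             # illustrative head counts
head_dim_qk, head_dim_vo = 192, 128          # MLA: qk carries the rope dims, vo does not
# Two requests with 7 and 8 new tokens; ragged layout via cumulative offsets.
qo_indptr = torch.tensor([0, 7, 15], dtype=torch.int32, device=device)
kv_indptr = qo_indptr.clone()                # new-token attention: KV laid out like Q

wrapper.plan(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim_qk,
    head_dim_vo=head_dim_vo,
    causal=True,                             # causal for new tokens; False for context chunks
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
)

total_tokens = int(qo_indptr[-1])
q = torch.randn(total_tokens, num_qo_heads, head_dim_qk, dtype=torch.bfloat16, device=device)
k = torch.randn(total_tokens, num_kv_heads, head_dim_qk, dtype=torch.bfloat16, device=device)
v = torch.randn(total_tokens, num_kv_heads, head_dim_vo, dtype=torch.bfloat16, device=device)

out, lse = wrapper.run(q, k, v, return_lse=True)
print(out.shape)                             # (15, 16, 128): per-token, per-head, vo-dim output
print(lse.shape)                             # per-token, per-head log-sum-exp

Running with return_lse=True is what keeps the per-chunk log-sum-exp available, so the existing suffix/context combination in _compute_prefill_context() and _forward_prefill() can merge the causal new-token output with the non-causal context-chunk outputs unchanged.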