Commit 0630d45

afeldman-nm, rshaw@neuralmagic.com, and njhill authored
[V1] Logprobs and prompt logprobs support (vllm-project#9880)
This PR adds support for sample logprobs & prompt logprobs to vLLM v1.

New behavior:

- During model execution, the model runner computes sample logprobs (if the user-provided logprobs setting is not None) and prompt logprobs (if the user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, and token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure, which is transferred to the engine client. If multiprocessing is enabled, sample and prompt logprobs are (de)serialized along with the EngineCoreOutput data structure.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprob values, and token ranks into the OpenAI-compatible List[Dict[token id, Logprob]] format (for sample and prompt logprobs respectively).
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor, not the detokenizer.

Signed-off-by: Andrew Feldman <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
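As a quick illustration of the rank definition above, here is a minimal sketch (not code from this PR) of computing a token's 1-indexed rank from a vocabulary-sized logprob vector; the token_rank helper name and the strictly-greater tie handling are assumptions.

import torch


def token_rank(logprobs: torch.Tensor, token_id: int) -> int:
    """1-indexed rank of `token_id` after sorting the vocab by logprob, descending.

    Assumption: ties are resolved by counting only strictly greater entries.
    """
    return int((logprobs > logprobs[token_id]).sum().item()) + 1


# Example: token 3 holds the 3rd-highest logprob in a 4-token vocabulary.
lp = torch.log_softmax(torch.tensor([2.0, 0.5, 3.0, 1.0]), dim=-1)
assert token_rank(lp, 3) == 3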
1 parent 538fab9 commit 0630d45

30 files changed: +2869 −287 lines
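Before the per-file diffs, here is a hedged sketch of the output-processing step described in the commit message: folding the engine core's (token ids, logprob values, ranks) vectors into the OpenAI-compatible List[Dict[token id, Logprob]] shape. The Logprob dataclass fields and the to_logprob_dicts helper are illustrative stand-ins, not vLLM's actual internal types.

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Logprob:
    logprob: float                # log-probability of the token
    rank: Optional[int]           # 1-indexed rank in the logprob-sorted vocab
    decoded_token: Optional[str]  # detokenized string representation


def to_logprob_dicts(
    token_ids: List[List[int]],         # candidate token ids per position
    logprob_values: List[List[float]],  # matching logprob values per position
    ranks: List[List[int]],             # matching 1-indexed vocab ranks
    decode=lambda tok_id: f"<tok {tok_id}>",  # stand-in detokenizer
) -> List[Dict[int, Logprob]]:
    """Fold per-position id/logprob/rank vectors into {token_id: Logprob} dicts."""
    return [{
        tok_id: Logprob(logprob=lp, rank=rk, decoded_token=decode(tok_id))
        for tok_id, lp, rk in zip(ids, lps, rks)
    } for ids, lps, rks in zip(token_ids, logprob_values, ranks)]


# One generated position with two candidate tokens.
dicts = to_logprob_dicts([[17, 42]], [[-0.1, -2.3]], [[1, 5]])
assert dicts[0][42].rank == 5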

Diff for: tests/v1/core/test_scheduler.py (+2 −2)

@@ -195,8 +195,8 @@ def test_schedule_partial_requests():
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
         sampled_token_ids=[0] * len(requests),
-        logprob_token_ids_cpu=None,
-        logprobs_cpu=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
     )
     scheduler.update_from_output(output, model_runner_output)

Diff for: tests/v1/engine/conftest.py (+90)

@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Tuple
+
+import pytest
+import torch
+from transformers import AutoTokenizer
+
+from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
+                                   NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN,
+                                   TOKENIZER_NAME,
+                                   DummyOutputProcessorTestVectors,
+                                   generate_dummy_prompt_logprobs_tensors,
+                                   generate_dummy_sample_logprobs)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+
+from tests.v1.engine.utils import FULL_STRINGS  # isort: skip
+
+EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]]
+EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
+    """Generate output processor dummy test vectors, without logprobs
+
+    Returns:
+      DummyOutputProcessorTestVectors instance with no logprobs
+    """
+
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
+    vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
+    # Tokenize prompts under test & create dummy generated tokens
+    prompt_tokens = [
+        tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
+    ]
+    generation_tokens = [
+        tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
+    ]
+    # Generate prompt strings
+    prompt_strings = [
+        tokenizer.decode(prompt_tokens, skip_special_tokens=True)
+        for prompt_tokens in prompt_tokens
+    ]
+    prompt_strings_len = [
+        len(prompt_string) for prompt_string in prompt_strings
+    ]
+    return DummyOutputProcessorTestVectors(
+        tokenizer=tokenizer,
+        tokenizer_group=init_tokenizer_from_configs(
+            vllm_config.model_config, vllm_config.scheduler_config,
+            vllm_config.parallel_config, vllm_config.lora_config),
+        vllm_config=vllm_config,
+        full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
+        prompt_tokens=prompt_tokens,
+        generation_tokens=generation_tokens,
+        prompt_strings=prompt_strings,
+        prompt_strings_len=prompt_strings_len,
+        generation_strings=[
+            text[prompt_len:]
+            for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len)
+        ],
+        prompt_logprobs=[],
+        generation_logprobs=[])
+
+
+@pytest.fixture
+def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
+    """Generate output processor dummy test vectors, with logprobs
+
+    Returns:
+      DummyOutputProcessorTestVectors instance with logprobs
+    """
+    # Build dummy test vectors without logprobs
+    dtv = _build_test_vectors_no_logprobs()
+    # Inject logprobs into dummy test vectors data structure
+    dtv.generation_logprobs = [
+        generate_dummy_sample_logprobs(
+            sampled_tokens_list=tokens_list,
+            num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST,
+            tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens
+    ]
+    dtv.prompt_logprobs = [
+        generate_dummy_prompt_logprobs_tensors(
+            prompt_tokens_list=tokens_list,
+            num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST,
+            tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens
+    ]
+    return dtv
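For orientation only, below is a hypothetical sketch of a dummy sample-logprobs generator consistent with the EngineCoreSampleLogprobsType alias above (one (logprob values, token ids) tensor pair per sampled token). It is not the actual tests.v1.engine.utils.generate_dummy_sample_logprobs; the tuple ordering, vocab_size default, and placing the sampled token in slot 0 are assumptions.

from typing import List, Tuple

import torch


def make_dummy_sample_logprobs(
        sampled_token_ids: List[int],
        num_logprobs: int,
        vocab_size: int = 32000) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Build one (logprob values, token ids) tensor pair per sampled token."""
    dummy: List[Tuple[torch.Tensor, torch.Tensor]] = []
    for sampled_id in sampled_token_ids:
        # Random candidate token ids, with the sampled token forced into slot 0.
        token_ids = torch.randint(0, vocab_size, (num_logprobs + 1,))
        token_ids[0] = sampled_id
        # Descending dummy logprob values, one per candidate token.
        logprob_values = torch.log_softmax(
            torch.randn(num_logprobs + 1), dim=-1).sort(descending=True).values
        dummy.append((logprob_values, token_ids))
    return dummy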

Diff for: tests/v1/engine/test_async_llm.py (+45 −4)

@@ -2,10 +2,11 @@
 
 import asyncio
 from contextlib import ExitStack
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import pytest
 
+from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
@@ -21,13 +22,19 @@
                               disable_log_requests=True)
 
 
-async def generate(engine: AsyncLLM, request_id: str,
+async def generate(engine: AsyncLLM,
+                   request_id: str,
                    output_kind: RequestOutputKind,
-                   max_tokens: int) -> Tuple[int, str]:
+                   max_tokens: int,
+                   prompt_logprobs: Optional[int] = None) -> Tuple[int, str]:
+    # Ensure generate doesn't complete too fast for cancellation test.
+    await asyncio.sleep(0.2)
+
     count = 0
     sampling_params = SamplingParams(max_tokens=max_tokens,
                                      output_kind=output_kind,
-                                     temperature=0)
+                                     temperature=0,
+                                     prompt_logprobs=prompt_logprobs)
     async for out in engine.generate(request_id=request_id,
                                      prompt="Hello my name is Robert and",
                                      sampling_params=sampling_params):
@@ -43,6 +50,40 @@ async def generate(engine: AsyncLLM, request_id: str,
     return count, request_id
 
 
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+@pytest.mark.asyncio
+async def test_async_llm_refuses_prompt_logprobs_with_apc(
+        monkeypatch, output_kind: RequestOutputKind):
+    """Test passes if AsyncLLM raises an exception when it is configured
+    for automatic prefix caching and it receives a request with
+    prompt_logprobs enabled, which is incompatible."""
+    # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
+    # better way to test V1 so that in the future when we switch, we don't
+    # have to change all the tests.
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    # Create AsyncLLM engine with APC
+    apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m",
+                                      enable_prefix_caching=True,
+                                      gpu_memory_utilization=0.8,
+                                      disable_log_requests=True)
+    engine = AsyncLLM.from_engine_args(apc_engine_args)
+    try:
+        with pytest.raises(ValueError) as excinfo:
+            # Issue a request with prompt logprobs enabled, which should fail
+            await asyncio.create_task(
+                generate(engine,
+                         "request-0",
+                         output_kind,
+                         10,
+                         prompt_logprobs=5))
+        # Validate exception string is correct
+        assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
+    finally:
+        # Shut down engine
+        engine.shutdown()
+
+
 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio

Diff for: tests/v1/engine/test_llm_engine.py (+23)

@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
+from vllm import LLM, SamplingParams
+
+
+def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
+    """Test passes if LLMEngine raises an exception when it is configured
+    for automatic prefix caching and it receives a request with
+    prompt_logprobs enabled, which is incompatible."""
+
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+    with pytest.raises(ValueError) as excinfo:
+        LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
+            "Hello, my name is",
+            SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))
+
+    # Validate exception string is correct
+    assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
