Skip to content

Commit 25f560a

Browse files
authored Mar 25, 2025
[V1][Spec Decode] Update target_logits in place for rejection sampling (#15427)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent a09ad90 commit 25f560a

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed
 

‎vllm/v1/sample/rejection_sampler.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def forward(
6767
Shape is [num_tokens, vocab_size]. Here, probabilities from
6868
different requests are flattened into a single tensor because
6969
this is the shape of the output logits.
70+
NOTE: `target_logits` can be updated in place to save memory.
7071
bonus_token_ids_tensor (torch.Tensor):
7172
A tensor containing bonus tokens. Shape is [batch_size, 1].
7273
Bonus tokens are added to the end of the sequence if all
@@ -83,6 +84,8 @@ def forward(
8384
'''
8485
assert metadata.max_spec_len <= MAX_SPEC_LEN
8586
# [num_tokens, vocab_size]
87+
# NOTE(woosuk): `target_logits` can be updated in place inside the
88+
# `compute_probs` function.
8689
target_probs = compute_probs(
8790
target_logits,
8891
metadata.cu_num_draft_tokens,
@@ -252,8 +255,8 @@ def compute_probs(
252255
replace_from=GREEDY_TEMPERATURE,
253256
replace_to=1,
254257
)
255-
# TODO(woosuk): Consider using in-place op to reduce memory usage.
256-
logits = logits / temperature.unsqueeze(-1)
258+
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
259+
logits.div_(temperature.unsqueeze(-1))
257260

258261
# Get expanded top_k and top_p tensors.
259262
top_k = None

‎vllm/v1/worker/gpu_model_runner.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1059,15 +1059,20 @@ def execute_model(
10591059
sampling_metadata=sampling_metadata,
10601060
)
10611061
else:
1062-
# TODO(woosuk): Optimize the memory usage.
1062+
# When indexing with a tensor (bonus_logits_indices), PyTorch
1063+
# creates a new tensor with separate storage from the original
1064+
# logits tensor. This means any in-place operations on bonus_logits
1065+
# won't affect the original logits tensor.
10631066
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
10641067
sampler_output = self.model.sample(
10651068
logits=bonus_logits,
10661069
sampling_metadata=sampling_metadata,
10671070
)
10681071
bonus_token_ids = sampler_output.sampled_token_ids
10691072

1070-
# TODO(woosuk): Optimize the memory usage.
1073+
# Just like `bonus_logits`, `target_logits` is a new tensor with
1074+
# separate storage from the original `logits` tensor. Therefore,
1075+
# it is safe to update `target_logits` in place.
10711076
target_logits = logits[spec_decode_metadata.target_logits_indices]
10721077
output_token_ids = self.rejection_sampler(
10731078
spec_decode_metadata,

0 commit comments

Comments (0)