@@ -67,6 +67,7 @@ def forward(
                 Shape is [num_tokens, vocab_size]. Here, probabilities from
                 different requests are flattened into a single tensor because
                 this is the shape of the output logits.
+                NOTE: `target_logits` can be updated in place to save memory.
             bonus_token_ids_tensor (torch.Tensor):
                 A tensor containing bonus tokens. Shape is [batch_size, 1].
                 Bonus tokens are added to the end of the sequence if all
@@ -83,6 +84,8 @@ def forward(
         '''
         assert metadata.max_spec_len <= MAX_SPEC_LEN
         # [num_tokens, vocab_size]
+        # NOTE(woosuk): `target_logits` can be updated in place inside the
+        # `compute_probs` function.
         target_probs = compute_probs(
             target_logits,
             metadata.cu_num_draft_tokens,
@@ -252,8 +255,8 @@ def compute_probs(
         replace_from=GREEDY_TEMPERATURE,
         replace_to=1,
     )
-    # TODO(woosuk): Consider using in-place op to reduce memory usage.
-    logits = logits / temperature.unsqueeze(-1)
+    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
+    logits.div_(temperature.unsqueeze(-1))
 
     # Get expanded top_k and top_p tensors.
     top_k = None
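
The change above swaps the out-of-place division for `torch.Tensor.div_`, which rewrites `logits` in its existing storage instead of allocating a second [num_tokens, vocab_size] tensor. A minimal sketch of the difference, using made-up shapes and values purely for illustration:

import torch

# Hypothetical sizes; the real tensor is the sampler's flattened
# [num_tokens, vocab_size] logits.
num_tokens, vocab_size = 4, 8
logits = torch.randn(num_tokens, vocab_size)
temperature = torch.full((num_tokens,), 0.7)

# Out-of-place `/` allocates a new tensor of the same size.
scaled = logits / temperature.unsqueeze(-1)
assert scaled.data_ptr() != logits.data_ptr()

# In-place `div_` reuses the original storage and returns the same tensor,
# so peak memory stays at one logits-sized buffer.
out = logits.div_(temperature.unsqueeze(-1))
assert out.data_ptr() == logits.data_ptr()

This is also why the new docstring NOTE warns that `target_logits` may be modified in place inside `compute_probs`.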