Skip to content

Commit eaa2e51

Browse files
authored
[Bugfix] Re-enable use_cudagraph in vLLM v1 (#19299)
Signed-off-by: Richard Zou <[email protected]>
1 parent d77f7fb commit eaa2e51

File tree

6 files changed

+52
-8
lines changed

6 files changed

+52
-8
lines changed

tests/compile/piecewise/test_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
9595
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
9696
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
9797
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
98-
num_cudagraph_caputured=
98+
num_cudagraph_captured=
9999
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
100100
):
101101

tests/compile/piecewise/test_toy_llama.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _test_toy_llama(*, use_inductor):
327327
num_piecewise_graphs_seen=0,
328328
num_piecewise_capturable_graphs_seen=0,
329329
num_backend_compilations=0,
330-
num_cudagraph_caputured=0,
330+
num_cudagraph_captured=0,
331331
):
332332
outputs.append(
333333
run_model(llama_config, use_inductor=False, use_compile=False))
@@ -343,7 +343,7 @@ def _test_toy_llama(*, use_inductor):
343343
num_piecewise_graphs_seen=1,
344344
num_piecewise_capturable_graphs_seen=1,
345345
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
346-
num_cudagraph_caputured=
346+
num_cudagraph_captured=
347347
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
348348
**kwargs,
349349
):
@@ -361,7 +361,7 @@ def _test_toy_llama(*, use_inductor):
361361
llama_config.num_layers, # 1 + num_layers
362362
num_backend_compilations=1 +
363363
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
364-
num_cudagraph_caputured=2 *
364+
num_cudagraph_captured=2 *
365365
(1 + llama_config.num_layers
366366
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
367367
):

tests/compile/test_config.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import pytest
4+
import torch
5+
6+
import vllm
7+
from vllm.compilation.counter import compilation_counter
8+
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
9+
set_current_vllm_config)
10+
11+
from .piecewise.test_simple import SillyModel
12+
13+
14+
@pytest.fixture(scope="function", autouse=True)
15+
def use_v1(monkeypatch):
16+
"""
17+
TODO(rzou): The rest of tests/compile runs VLLM_USE_V1=0 right now,
18+
I'll switch them over later.
19+
"""
20+
monkeypatch.setenv('VLLM_USE_V1', '1')
21+
22+
23+
@pytest.mark.parametrize("enabled", [True, False])
24+
def test_use_cudagraphs(enabled):
25+
assert vllm.envs.VLLM_USE_V1
26+
vllm_config = VllmConfig(compilation_config=CompilationConfig(
27+
level=CompilationLevel.PIECEWISE,
28+
use_cudagraph=enabled,
29+
cudagraph_capture_sizes=[100],
30+
))
31+
with set_current_vllm_config(vllm_config):
32+
model = SillyModel(vllm_config=vllm_config, prefix='')
33+
34+
inputs = torch.randn(100, device="cuda")
35+
36+
with compilation_counter.expect(
37+
num_graphs_seen=1, # one graph for the model
38+
num_cudagraph_captured=1 if enabled else 0,
39+
):
40+
# first run is warmup
41+
model(inputs)
42+
# second run does CUDAGraphs recording (if enabled)
43+
model(inputs)

vllm/compilation/counter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class CompilationCounter:
1515
# not including the splitting ops
1616
num_piecewise_capturable_graphs_seen: int = 0
1717
num_backend_compilations: int = 0
18-
num_cudagraph_caputured: int = 0
18+
num_cudagraph_captured: int = 0
1919
# InductorAdapter.compile calls
2020
num_inductor_compiles: int = 0
2121
# EagerAdapter.compile calls

vllm/compilation/cuda_piecewise_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def __call__(self, *args) -> Any:
193193
entry.output = weak_ref_tensors(output)
194194
entry.cudagraph = cudagraph
195195

196-
compilation_counter.num_cudagraph_caputured += 1
196+
compilation_counter.num_cudagraph_captured += 1
197197

198198
# important: we need to return the output, rather than
199199
# the weak ref of the output, so that pytorch can correctly

vllm/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3918,12 +3918,14 @@ class CompilationConfig:
39183918
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
39193919

39203920
# CudaGraph compilation
3921-
use_cudagraph: bool = False
3921+
use_cudagraph: bool = envs.VLLM_USE_V1
39223922
"""Whether to use cudagraph inside compilation.
39233923
- False: cudagraph inside compilation is not used.
39243924
- True: cudagraph inside compilation is used. It requires
39253925
that all input buffers have fixed addresses, and all
39263926
splitting ops write their outputs to input buffers.
3927+
In the vLLM V1 Engine, this flag only applies for
3928+
CompilationLevel.PIECEWISE (aka -O3).
39273929
Note that this is orthogonal to the cudagraph capture logic
39283930
outside of compilation.
39293931
TODO: move outside cudagraph logic into compilation.
@@ -4425,7 +4427,6 @@ def __post_init__(self):
44254427
# FIXME(rob): Add function to set all of these.
44264428
if not self.compilation_config.custom_ops:
44274429
self.compilation_config.custom_ops = ["none"]
4428-
self.compilation_config.use_cudagraph = True
44294430
self.compilation_config.cudagraph_num_of_warmups = 1
44304431
self.compilation_config.pass_config.enable_fusion = False
44314432
self.compilation_config.pass_config.enable_noop = False

0 commit comments

Comments (0)