
Commit da02cb4

[core] further polish memory profiling (vllm-project#12126)

Authored by youkaichao
Signed-off-by: youkaichao <[email protected]>

1 parent c09503d · commit da02cb4

3 files changed: +85 -67 lines changed

tests/test_utils.py (+12, -14)

@@ -9,10 +9,10 @@
 from vllm_test_utils import monitor

 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
-                        StoreBoolean, bind_kv_cache, deprecate_kwargs,
-                        get_open_port, memory_profiling, merge_async_iterators,
-                        supports_kw)
+from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
+                        PlaceholderModule, StoreBoolean, bind_kv_cache,
+                        deprecate_kwargs, get_open_port, memory_profiling,
+                        merge_async_iterators, supports_kw)

 from .utils import error_on_warning, fork_new_process_for_each_test

@@ -284,14 +284,13 @@ def test_memory_profiling():
     # 512 MiB allocation outside of this instance
     handle1 = lib.cudaMalloc(512 * 1024 * 1024)

-    baseline_memory_in_bytes = \
-        torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
+    baseline_snapshot = MemorySnapshot()

     # load weights

     weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)

-    weights_memory_in_bytes = 128 * 1024 * 1024 * 4  # 512 MiB
+    weights_memory = 128 * 1024 * 1024 * 4  # 512 MiB

     def measure_current_non_torch():
         free, total = torch.cuda.mem_get_info()

@@ -300,8 +299,8 @@ def measure_current_non_torch():
         current_non_torch = current_used - current_torch
         return current_non_torch

-    with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
-                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
+    with memory_profiling(baseline_snapshot=baseline_snapshot,
+                          weights_memory=weights_memory) as result, \
             monitor(measure_current_non_torch) as monitored_values:
         # make a memory spike, 1 GiB
         spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)

@@ -316,13 +315,12 @@ def measure_current_non_torch():
     assert measured_diff == 256 * 1024 * 1024

     # Check that the memory usage is within 5% of the expected values
-    # 5% tolerance is caused by PyTorch caching allocator,
-    # we cannot control PyTorch's behavior of its internal buffers,
+    # 5% tolerance is caused by the CUDA runtime:
+    # we cannot control the CUDA runtime at the granularity of bytes,
     # which causes a small error (<10 MiB in practice)
-    non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024)  # noqa
-    torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024)  # noqa
+    non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024)  # noqa
     assert abs(non_torch_ratio - 1) <= 0.05
-    assert abs(torch_peak_ratio - 1) <= 0.05
+    assert result.torch_peak_increase == 1024 * 1024 * 1024
     del weights
     lib.cudaFree(handle1)
     lib.cudaFree(handle2)
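
A quick check of the numbers the updated test asserts (a sketch; the matching 256 MiB cudaMalloc for `handle2` presumably happens inside the profiling block, in unchanged test code not shown in these hunks):

GiB = 1024 ** 3
MiB = 1024 ** 2

# The spike tensor holds 256 * 1024 * 1024 float32 values at 4 bytes each,
# so the torch peak increase is exactly 1 GiB; that is why the commit can
# replace the old 5% ratio check with an exact equality for this tensor,
# since allocated_bytes.all.peak tracks PyTorch allocations directly.
assert 256 * 1024 * 1024 * 4 == 1 * GiB

# The non-torch increase comes from a raw cudaMalloc, which goes through the
# CUDA runtime, so only a 5% tolerance (<10 MiB in practice) is kept there.
expected_non_torch = 256 * MiB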

vllm/utils.py (+56, -39)

@@ -1923,55 +1923,72 @@ def kill_process_tree(pid: int):
 @dataclass
 class MemorySnapshot:
     """Memory snapshot."""
-    torch_peak_in_bytes: int = 0
-    torch_memory_in_bytes: int = 0
+    torch_peak: int = 0
+    cuda_memory: int = 0
+    torch_memory: int = 0
+    non_torch_memory: int = 0
     timestamp: float = 0.0
+    auto_measure: bool = True
+
+    def __post_init__(self):
+        if self.auto_measure:
+            self.measure()

     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # we measure the torch peak memory usage via allocated_bytes,
+        # rather than `torch.cuda.memory_reserved()`.
+        # After `torch.cuda.reset_peak_memory_stats()`,
+        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
+        # when we call `torch.cuda.empty_cache()` or OOM happens.
+        self.torch_peak = torch.cuda.memory_stats().get(
+            "allocated_bytes.all.peak", 0)
+
+        self.cuda_memory = torch.cuda.mem_get_info(
+        )[1] - torch.cuda.mem_get_info()[0]
+
         # torch.cuda.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
-        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
+        # this is used to measure the non-torch memory usage
+        self.torch_memory = torch.cuda.memory_reserved()
+
+        self.non_torch_memory = self.cuda_memory - self.torch_memory
         self.timestamp = time.time()

     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
-        """support a - b"""
         return MemorySnapshot(
-            torch_peak_in_bytes=self.torch_peak_in_bytes -
-            other.torch_peak_in_bytes,
-            torch_memory_in_bytes=self.torch_memory_in_bytes -
-            other.torch_memory_in_bytes,
-            timestamp=self.timestamp - other.timestamp)
+            torch_peak=self.torch_peak - other.torch_peak,
+            cuda_memory=self.cuda_memory - other.cuda_memory,
+            torch_memory=self.torch_memory - other.torch_memory,
+            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
+            timestamp=self.timestamp - other.timestamp,
+            auto_measure=False,
+        )


 @dataclass
 class MemoryProfilingResult:
-    """Memory profiling result.
-    """  # noqa
-    baseline_memory_in_bytes: int = 0
-    non_kv_cache_memory_in_bytes: int = 0
-    torch_peak_increase_in_bytes: int = 0
-    non_torch_increase_in_bytes: int = 0
-    weights_memory_in_bytes: float = 0
+    """Memory profiling result. All numbers are in bytes.
+    """
+    non_kv_cache_memory: int = 0
+    torch_peak_increase: int = 0
+    non_torch_increase: int = 0
+    weights_memory: float = 0
+    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
     before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     profile_time: float = 0.0


 @contextlib.contextmanager
 def memory_profiling(
-    baseline_memory_in_bytes: int, weights_memory_in_bytes: int
-) -> Generator[MemoryProfilingResult, None, None]:
+        baseline_snapshot: MemorySnapshot,
+        weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
     """Memory profiling context manager.
-    baseline_memory_in_bytes: memory used by all the components other than
-    the current vLLM instance. It contains: memory used by other processes, memory
-    used by another vLLM instance in the same process, etc. It is usually measured
-    before the current vLLM instance initialize the device. And we assume it is
-    constant during the profiling of the current vLLM instance.
-    weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
+    baseline_snapshot: the memory snapshot before the current vLLM instance.
+    weights_memory: memory used by PyTorch when loading the model weights.
     Note that, before loading the model weights, we also initialize the device
     and distributed environment, which may consume some memory. This part is not
-    included in the weights_memory_in_bytes because PyTorch does not control it.
+    included in the weights_memory because PyTorch does not control it.

     The memory in one GPU can be classified into 3 categories:
     1. memory used by anything other than the current vLLM instance.
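
For context on the hunk above: MemorySnapshot now measures itself on construction (auto_measure=True) and supports subtraction, so call sites work with snapshot deltas instead of raw byte counts. A minimal usage sketch, with illustrative variable names that are not part of the diff:

import torch
from vllm.utils import MemorySnapshot

baseline = MemorySnapshot()                  # __post_init__ calls measure()

x = torch.empty(1024, 1024, device="cuda")   # any work that touches the GPU

after = MemorySnapshot()
delta = after - baseline                     # field-wise diff, auto_measure=False
# delta.cuda_memory      -> growth of total used device memory (mem_get_info)
# delta.torch_memory     -> growth of PyTorch's reserved pool (memory_reserved)
# delta.non_torch_memory -> growth attributed to allocations outside PyTorch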
@@ -2006,20 +2023,21 @@ def memory_profiling(
     b. 2 GiB reserved for the peak activation tensors (category 2)
     c. 1 GiB used by non-torch components (category 3)

-    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
+    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.

-    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).

-    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
+    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling gives (c.).
     """  # noqa
+    gc.collect()
+    torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()

     result = MemoryProfilingResult()

-    result.baseline_memory_in_bytes = baseline_memory_in_bytes
+    result.before_create = baseline_snapshot
     # the part of memory used for holding the model weights
-    result.weights_memory_in_bytes = weights_memory_in_bytes
+    result.weights_memory = weights_memory

     result.before_profile.measure()

@@ -2030,13 +2048,12 @@ def memory_profiling(

     result.after_profile.measure()

-    diff = result.after_profile - result.before_profile
-    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
-    current_cuda_memory_bytes = torch.cuda.mem_get_info(
-    )[1] - torch.cuda.mem_get_info()[0]
-    result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes  # noqa
-    result.profile_time = diff.timestamp
-    result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes  # noqa
+    diff_profile = result.after_profile - result.before_profile
+    diff_from_create = result.after_profile - result.before_create
+    result.torch_peak_increase = diff_profile.torch_peak
+    result.non_torch_increase = diff_from_create.non_torch_memory
+    result.profile_time = diff_profile.timestamp
+    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory  # noqa


 # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630  # noqa: E501
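
The docstring's example (3 GiB weights, 2 GiB peak activations, 1 GiB non-torch) worked through with the renamed fields; a sketch using the docstring's illustrative numbers, not measured values:

GiB = 1024 ** 3

weights_memory = 3 * GiB       # (a.) supplied by the caller of memory_profiling()
torch_peak_increase = 2 * GiB  # (b.) rise of allocated_bytes.all.peak during profiling
non_torch_increase = 1 * GiB   # (c.) rise of non_torch_memory since the baseline snapshot

# Everything that is not available for the KV cache:
non_kv_cache_memory = weights_memory + torch_peak_increase + non_torch_increase
assert non_kv_cache_memory == 6 * GiB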

vllm/worker/worker.py (+17, -14)

@@ -21,7 +21,8 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
-from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
+from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
+                        memory_profiling)
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner

@@ -137,7 +138,8 @@ def init_device(self) -> None:
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
-            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+            torch.cuda.reset_peak_memory_stats()
+            self.baseline_snapshot = MemorySnapshot()
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")

@@ -192,18 +194,17 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:

         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
-                              self.init_gpu_memory,
-                              weights_memory_in_bytes=self.model_runner.
-                              model_memory_usage) as result:
+        with memory_profiling(
+                self.baseline_snapshot,
+                weights_memory=self.model_runner.model_memory_usage) as result:
             self.model_runner.profile_run()

         self._assert_memory_footprint_increased_during_profiling()

         memory_for_current_instance = total_gpu_memory * \
             self.cache_config.gpu_memory_utilization
         available_kv_cache_memory = (memory_for_current_instance -
-                                     result.non_kv_cache_memory_in_bytes)
+                                     result.non_kv_cache_memory)

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.

@@ -226,11 +227,11 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
             f"({self.cache_config.gpu_memory_utilization:.2f})"
             f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
             "model weights take "
-            f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;"
+            f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
             " non_torch_memory takes "
-            f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;"
+            f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
             " PyTorch activation peak memory takes "
-            f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;"
+            f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
             " the rest of the memory reserved for KV Cache is "
             f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")

@@ -246,11 +247,13 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
     def _assert_memory_footprint_increased_during_profiling(self):
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        free_gpu_memory, _ = torch.cuda.mem_get_info()
-        assert self.init_gpu_memory - free_gpu_memory > 0, (
+        free_gpu_memory, total = torch.cuda.mem_get_info()
+        cuda_memory = total - free_gpu_memory
+        assert self.baseline_snapshot.cuda_memory < cuda_memory, (
             "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
+            f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
+            f"currently used memory {cuda_memory}. "
+            f"This happens when the GPU memory was "
             "not properly cleaned up before initializing the vLLM instance.")

     def initialize_cache(self, num_gpu_blocks: int,
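
Putting the worker-side arithmetic in `determine_num_available_blocks` together; a sketch with made-up numbers (an 80 GiB device and 0.9 utilization are assumptions, not values from the diff):

GiB = 1024 ** 3

total_gpu_memory = 80 * GiB      # torch.cuda.mem_get_info()[1]
gpu_memory_utilization = 0.9     # self.cache_config.gpu_memory_utilization
non_kv_cache_memory = 6 * GiB    # result.non_kv_cache_memory from memory_profiling()

memory_for_current_instance = total_gpu_memory * gpu_memory_utilization
available_kv_cache_memory = memory_for_current_instance - non_kv_cache_memory
# 80 GiB * 0.9 - 6 GiB = 66 GiB would be left for the KV cache in this example.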
