@@ -1923,55 +1923,72 @@ def kill_process_tree(pid: int):
 @dataclass
 class MemorySnapshot:
     """Memory snapshot."""
-    torch_peak_in_bytes: int = 0
-    torch_memory_in_bytes: int = 0
+    torch_peak: int = 0
+    cuda_memory: int = 0
+    torch_memory: int = 0
+    non_torch_memory: int = 0
     timestamp: float = 0.0
+    auto_measure: bool = True
+
+    def __post_init__(self):
+        if self.auto_measure:
+            self.measure()
 
     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # we measure the torch peak memory usage via allocated_bytes,
+        # rather than `torch.cuda.memory_reserved()`.
+        # After `torch.cuda.reset_peak_memory_stats()`,
+        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
+        # when we call `torch.cuda.empty_cache()` or OOM happens.
+        self.torch_peak = torch.cuda.memory_stats().get(
+            "allocated_bytes.all.peak", 0)
+
+        self.cuda_memory = torch.cuda.mem_get_info(
+        )[1] - torch.cuda.mem_get_info()[0]
+
         # torch.cuda.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
-        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
+        # this is used to measure the non-torch memory usage
+        self.torch_memory = torch.cuda.memory_reserved()
+
+        self.non_torch_memory = self.cuda_memory - self.torch_memory
         self.timestamp = time.time()
 
     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
-        """support a - b"""
         return MemorySnapshot(
-            torch_peak_in_bytes=self.torch_peak_in_bytes -
-            other.torch_peak_in_bytes,
-            torch_memory_in_bytes=self.torch_memory_in_bytes -
-            other.torch_memory_in_bytes,
-            timestamp=self.timestamp - other.timestamp)
+            torch_peak=self.torch_peak - other.torch_peak,
+            cuda_memory=self.cuda_memory - other.cuda_memory,
+            torch_memory=self.torch_memory - other.torch_memory,
+            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
+            timestamp=self.timestamp - other.timestamp,
+            auto_measure=False,
+        )
 
 
 @dataclass
 class MemoryProfilingResult:
-    """Memory profiling result.
-    """  # noqa
-    baseline_memory_in_bytes: int = 0
-    non_kv_cache_memory_in_bytes: int = 0
-    torch_peak_increase_in_bytes: int = 0
-    non_torch_increase_in_bytes: int = 0
-    weights_memory_in_bytes: float = 0
+    """Memory profiling result. All numbers are in bytes.
+    """
+    non_kv_cache_memory: int = 0
+    torch_peak_increase: int = 0
+    non_torch_increase: int = 0
+    weights_memory: float = 0
+    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
     before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     profile_time: float = 0.0
 
 
 @contextlib.contextmanager
 def memory_profiling(
-    baseline_memory_in_bytes: int, weights_memory_in_bytes: int
-) -> Generator[MemoryProfilingResult, None, None]:
+        baseline_snapshot: MemorySnapshot,
+        weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
     """Memory profiling context manager.
-    baseline_memory_in_bytes: memory used by all the components other than
-    the current vLLM instance. It contains: memory used by other processes, memory
-    used by another vLLM instance in the same process, etc. It is usually measured
-    before the current vLLM instance initialize the device. And we assume it is
-    constant during the profiling of the current vLLM instance.
-    weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
+    baseline_snapshot: the memory snapshot before the current vLLM instance.
+    weights_memory: memory used by PyTorch when loading the model weights.
 
     Note that, before loading the model weights, we also initialize the device
     and distributed environment, which may consume some memory. This part is not
-    included in the weights_memory_in_bytes because PyTorch does not control it.
+    included in the weights_memory because PyTorch does not control it.
 
     The memory in one GPU can be classified into 3 categories:
     1. memory used by anything other than the current vLLM instance.
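
For review context, a minimal sketch of how the reworked MemorySnapshot is meant to be used (my own illustration, not part of the diff; it assumes the post-change vllm.utils and a CUDA device):

import torch

from vllm.utils import MemorySnapshot

before = MemorySnapshot()  # auto_measure=True: measures at construction time
x = torch.empty(1024, 1024, device="cuda")  # ~4 MiB allocated through PyTorch
after = MemorySnapshot()

# __sub__ builds a delta snapshot with auto_measure=False, so the
# subtraction result is not overwritten by a fresh measurement.
delta = after - before
print(delta.torch_memory)      # growth of PyTorch's reserved pool
print(delta.non_torch_memory)  # growth outside PyTorch (CUDA context, NCCL, ...)
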
@@ -2006,20 +2023,21 @@ def memory_profiling(
     b. 2 GiB reserved for the peak activation tensors (category 2)
     c. 1 GiB used by non-torch components (category 3)
 
-    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
+    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
 
-    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
 
-    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
+    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling gives (c.).
     """  # noqa
+    gc.collect()
+    torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
 
     result = MemoryProfilingResult()
 
-    result.baseline_memory_in_bytes = baseline_memory_in_bytes
+    result.before_create = baseline_snapshot
     # the part of memory used for holding the model weights
-    result.weights_memory_in_bytes = weights_memory_in_bytes
+    result.weights_memory = weights_memory
 
     result.before_profile.measure()
 
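
To make the docstring's example concrete, this is the accounting the function performs at the end (hypothetical numbers taken from the a./b./c. example above):

GiB = 1024**3
weights_memory = 1 * GiB       # (a.) passed in by the caller
torch_peak_increase = 2 * GiB  # (b.) peak activation tensors during profiling
non_torch_increase = 1 * GiB   # (c.) CUDA context, NCCL buffers, etc.

# mirrors `result.non_kv_cache_memory = ...` computed after profiling below
non_kv_cache_memory = non_torch_increase + torch_peak_increase + weights_memory
assert non_kv_cache_memory == 4 * GiB
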
@@ -2030,13 +2048,12 @@ def memory_profiling(
 
     result.after_profile.measure()
 
-    diff = result.after_profile - result.before_profile
-    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
-    current_cuda_memory_bytes = torch.cuda.mem_get_info(
-    )[1] - torch.cuda.mem_get_info()[0]
-    result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes  # noqa
-    result.profile_time = diff.timestamp
-    result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes  # noqa
+    diff_profile = result.after_profile - result.before_profile
+    diff_from_create = result.after_profile - result.before_create
+    result.torch_peak_increase = diff_profile.torch_peak
+    result.non_torch_increase = diff_from_create.non_torch_memory
+    result.profile_time = diff_profile.timestamp
+    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory  # noqa
 
 
 # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630  # noqa: E501
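
A sketch of how a caller might drive the new signature end to end (again my own illustration, not part of the diff; the tensor allocations stand in for real weight loading and the profiling forward pass):

import torch

from vllm.utils import MemorySnapshot, memory_profiling

# Taken before this vLLM instance initializes the device, so memory held by
# other processes (or other instances in the same process) is excluded.
baseline_snapshot = MemorySnapshot()

# Stand-in for model loading; real callers pass the measured weight bytes.
weights = torch.empty(128 * 1024**2, dtype=torch.float16, device="cuda")
weights_memory = weights.numel() * weights.element_size()

with memory_profiling(baseline_snapshot, weights_memory) as result:
    # Stand-in for the profiling forward pass.
    activations = torch.empty(64 * 1024**2, dtype=torch.float16, device="cuda")
    del activations

print(f"torch peak increase: {result.torch_peak_increase / 1024**2:.0f} MiB")
print(f"non-torch increase:  {result.non_torch_increase / 1024**2:.0f} MiB")
print(f"non-KV-cache memory: {result.non_kv_cache_memory / 1024**2:.0f} MiB")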