
[Bug]: Huge memory overhead with V1 (multiprocessing) when handling several multimodal inputs #16185


Closed
p88h opened this issue Apr 7, 2025 · 20 comments · Fixed by #16432
Labels: bug (Something isn't working)

@p88h
Contributor

p88h commented Apr 7, 2025

Your current environment

The output of `python collect_env.py`
INFO 04-07 10:56:24 [__init__.py:239] Automatically detected platform cuda.
Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.8 | packaged by conda-forge | (main, Dec  5 2024, 14:24:40) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.6.75.1-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4080
Nvidia driver version: 572.47
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               24
On-line CPU(s) list:                  0-23
Vendor ID:                            GenuineIntel
Model name:                           13th Gen Intel(R) Core(TM) i7-13700K
CPU family:                           6
Model:                                183
Thread(s) per core:                   2
Core(s) per socket:                   12
Socket(s):                            1
Stepping:                             1
BogoMIPS:                             6835.19
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni vnmi umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    Microsoft
Virtualization type:                  full
L1d cache:                            576 KiB (12 instances)
L1i cache:                            384 KiB (12 instances)
L2 cache:                             24 MiB (12 instances)
L3 cache:                             30 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-23
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.20.1
[pip3] pytorch-lightning==2.5.0.post0
[pip3] pytorch-metric-learning==2.8.1
[pip3] pyzmq==26.2.1
[pip3] torch==2.6.0
[pip3] torch-audiomentations==0.12.0
[pip3] torch_pitch_shift==1.2.5
[pip3] torchaudio==2.6.0
[pip3] torchmetrics==1.6.1
[pip3] torchvision==0.21.0
[pip3] transformers==4.50.3
[pip3] triton==3.2.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
[conda] nvidia-ml-py              12.570.86                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pytorch-lightning         2.5.0.post0              pypi_0    pypi
[conda] pytorch-metric-learning   2.8.1                    pypi_0    pypi
[conda] pyzmq                     26.2.1                   pypi_0    pypi
[conda] torch                     2.6.0                    pypi_0    pypi
[conda] torch-audiomentations     0.12.0                   pypi_0    pypi
[conda] torch-pitch-shift         1.2.5                    pypi_0    pypi
[conda] torchaudio                2.6.0                    pypi_0    pypi
[conda] torchmetrics              1.6.1                    pypi_0    pypi
[conda] torchvision               0.21.0                   pypi_0    pypi
[conda] transformers              4.50.3                   pypi_0    pypi
[conda] triton                    3.2.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.8.3.dev212+g58e234a75.d20250405
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X                              N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

This should be reproducible with QWEN VL 2.5 and using vision_language_multi_image.py offline inference.
When configured to use several images as input (e.g. just multiply IMAGE_URLS 8-16 times), the CPU memory usage spikes dramatically. With just around 20 or so images, vLLM will try to consume around 20-30 GB of RAM; with 40 it gets into the 50-60 GB range.
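
For reference, a minimal sketch along the lines of that reproduction (the URLs and the multiplication factor below are placeholders, not the exact contents of vision_language_multi_image.py):

```python
# Hypothetical repro sketch: one multi-image request carrying many image URLs.
from vllm import LLM, SamplingParams

IMAGE_URLS = [
    "https://example.com/image1.jpg",  # placeholder URLs
    "https://example.com/image2.jpg",
] * 16  # multiply the list to reach a few dozen images per request

llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct-AWQ",
    limit_mm_per_prompt={"image": len(IMAGE_URLS)},
)

messages = [{
    "role": "user",
    "content": [{"type": "image_url", "image_url": {"url": u}} for u in IMAGE_URLS]
    + [{"type": "text", "text": "Describe each image briefly."}],
}]

outputs = llm.chat(messages, SamplingParams(max_tokens=1024))
print(outputs[0].outputs[0].text)
```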

Interestingly, the number of passed elements seems to be the problem, while their size matters less (i.e. when merging multiple images into one, without scaling, it's possible to send ~60 images via 4 collages and still fit within 30 GB). This happens regardless of whether qwen-vl-utils is installed to resize the images.

At the same time, request preprocessing starts to slow down. Profiling was not very useful due to multiprocessing, but it did help in that the next step was disabling it: with VLLM_ENABLE_V1_MULTIPROCESSING=0 the issue completely disappears, the processing delays are gone, and even with ~100 input files the memory usage stays in the low GBs.
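
(For completeness, a sketch of that toggle for offline runs; setting the variable before importing vllm is the safest way to make sure it takes effect:)

```python
# Disable V1 multiprocessing for offline inference.
import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

from vllm import LLM  # imported after the variable is set

llm = LLM(model="Qwen/Qwen2.5-VL-7B-Instruct-AWQ")
```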

I haven't tried profiling MessageQueue yet, but perhaps someone else will run into this as well, or already has?

BTW - the environment above was collected running within WSL, but the OS doesn't seem to be a significant factor here: while WSL runs into memory issues earlier than raw Linux, raw Linux exhibits the same behavior.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
p88h added the bug (Something isn't working) label on Apr 7, 2025
@DarkLight1337
Member

Can you see if --disable-mm-preprocessor-cache alleviates the issue?
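
(For offline runs this flag should correspond to an engine argument of the same name - an assumption; check EngineArgs if the keyword differs:)

```python
# Hedged sketch: pass the CLI flag's assumed engine-argument equivalent when
# constructing the LLM.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct-AWQ",
    disable_mm_preprocessor_cache=True,
)
```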

@sfc-gh-spasko

--disable-mm-preprocessor-cache doesn't seem to affect this.

@DarkLight1337
Member

DarkLight1337 commented Apr 7, 2025

cc @ywang96 @robertgshaw2-redhat maybe it's related to transferring the multimodal inputs between processes?

@njhill
Member

njhill commented Apr 7, 2025

> cc @ywang96 @robertgshaw2-redhat maybe it's related to transferring the multimodal inputs between processes?

We have some in-progress improvements for this aspect.

@DarkLight1337
Member

Can you try out #16273 and see if it improves the memory overhead?

@p88h
Contributor Author

p88h commented Apr 8, 2025

It's much better ... but it's unfortunately still pretty much unusable.

With multiprocessing, handling ~64 1Mpix images used to exceed 52 GB and OOM at that point (on a 64 GB system...).
After the patch, and with mm_preprocessor_cache enabled, it now fits within 37 GB and runs correctly.
At 32 images it peaked at 37 GB before the patch; with the patch (& preproc cache) it's closer to 18 GB.

So while that's likely a 2x improvement, it's still very slow in the preparation phase and the memory use is still very significant.
Without multiprocessing the same workload requires only <4GB.

~16GB memory load to process ~128MB worth of data seems pretty steep.

@DarkLight1337
Member

DarkLight1337 commented Apr 8, 2025

How much memory does vLLM use if you disable preprocessor cache?

@p88h
Contributor Author

p88h commented Apr 8, 2025

If the preprocessor cache is disabled, the patch has no effect.
Without the patch, on 32 images, it consumes ~4 GB regardless of the cache setting in single-process mode, and ~37 GB in multi-process mode.
With the patch, on 32 images, it consumes 37/18 GB in multi-process mode depending on the cache setting, and ~4 GB in single-process mode.

There may be some difference at that ~4 GB level, since I was mostly looking at the multi-GB consumption.

@DarkLight1337
Member

DarkLight1337 commented Apr 8, 2025

Quite surprising that disabling the cache actually increases the memory overhead by that much. I guess @njhill's WIP should address that. Thanks for reporting the results!

@p88h
Contributor Author

p88h commented Apr 8, 2025

Well, the patch depends on the items being in the cache, no? So I would assume that if the cache is disabled, it basically reverts to the pre-patch behavior.

@DarkLight1337
Member

Yes, you're correct. So the main problem is still about how the data is being transferred in multiprocessing mode.

@p88h
Contributor Author

p88h commented Apr 8, 2025

Oh, okay - it seems that serialization itself is not really that big of a problem; the actual issue was in encoding the multimodal args properly.

This POC seems to fix the problem:
#16279

@njhill
Member

njhill commented Apr 8, 2025

Thanks @p88h, aren't encoding and serialization essentially the same thing? :)

This is similar to what I was planning - it just needs to be a bit more general, since AFAIK NestedTensors may not always be a single tensor, and it has an _items_by_modality field. And it can be combined with #13790 to be even more efficient.

Oh, and also to do this in a way that avoids using pickle, since we're planning to disable the use of pickle by default (that was actually a secondary motivation). cc @russellb

@p88h
Contributor Author

p88h commented Apr 8, 2025

In this case they really aren't the same: just as with the previous, similar workaround for bare torch.Tensor, serializing it via pickle uses the print format, not raw bytes. I agree that something more generic that can do this in a zero-copy fashion (and without pickles) would be preferable. I will have a look at your PR; I've updated mine to handle NestedTensors properly in case it helps.
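
(Illustrative sketch only, not the code from #16279: the "raw bytes plus minimal metadata instead of a pickled/printed representation" idea, with reconstruction on the receiving side:)

```python
# Not vLLM's actual serializer - just the bytes-not-print-format idea.
import numpy as np
import torch

def tensor_to_wire(t: torch.Tensor) -> tuple[str, tuple[int, ...], bytes]:
    arr = t.detach().cpu().numpy()
    return str(arr.dtype), arr.shape, arr.tobytes()

def tensor_from_wire(dtype: str, shape: tuple[int, ...], buf: bytes) -> torch.Tensor:
    # np.frombuffer returns a read-only view, hence the copy before wrapping;
    # this mirrors the read-only tensor warnings mentioned later in the thread.
    arr = np.frombuffer(buf, dtype=np.dtype(dtype)).reshape(shape)
    return torch.from_numpy(arr.copy())

dtype, shape, payload = tensor_to_wire(torch.randn(3, 224, 224))
restored = tensor_from_wire(dtype, shape, payload)
```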

@p88h
Contributor Author

p88h commented Apr 10, 2025

I completed a simple benchmark to quantify the problem along with the impact of the PRs affecting it. Results below:

All times are the best of 2 runs each; memory is averaged.
Time = elapsed real time via 'time'.
Memory = max RSS reported by 'time'.

Note that the 32-image case has 50% duplication and the 64-image case 75% (the image files are repeated 2/4 times). Each image is about 1 Mpix and generates ~1000 tokens.

Model used: Qwen/Qwen2.5-VL-7B-Instruct-AWQ. Output limited to 1024 tokens.

Times marked with * mean the model hallucinated and actually hit the limit.
Times marked with ! mean the model truncated the response.
This is less relevant for the 64-image case, as that one actually requires about that many tokens.

patches and parameters applied             | 16 images  | 32 images  | 64 images   |
-------------------------------------------+------------+------------+-------------+
baseline, cache on, single                 | 5.2G 33s   | 5.3G 53s*  | 5.8G 69s*   |
baseline, cache off, single                | 5.2G 34s   | 5.3G 54s*  | 5.4G 68s*   |
baseline, cache on, multi                  | 11.3G 44s  | 39.5G 92s* | OOM         |
baseline, cache off, multi                 | 11.3G 43s  | 39.5G 93s* | OOM         |
#16273, cache on, multi                    | 10.2G 45s  | 18.6G 72s* | 35.5G 107s* |
#16273, cache off, multi                   | 11.4G 46s  | 39.5G 94s* | OOM         |
#16273 + #13790, cache on, multi           | 10.2G 46s  | 18.6G 73s* | 35.5G 108s* |
#16273 + #13790, cache off, multi          | 11.4G 44s  | 39.5G 96s* | OOM         |
#16273 + #13790 + #16279, cache on, multi  | 5.4G 34s   | 9.9G 53s*  | 18.6G 71s*  |
#16273 + #13790 + #16279, cache off, multi | 6.0G 33s   | 20.2G 53s* | OOM         |

Note that disabling the cache now effectively disables the effect of #16273 when multi-processing. None of these changes has any impact on single-process mode. #16279 is built on top of #13790 now, so it can't be deployed separately.

With the combined communication improvements this is now usable, but something is still generating a ton of memory use in multi-process mode that does not occur in the single-process version.
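
(For anyone reproducing this, a rough in-process approximation of the measurements above; the numbers in the table themselves were collected with the external 'time' tool:)

```python
# Approximate in-process timing / max-RSS measurement. With multiprocessing
# enabled, RUSAGE_SELF only covers the parent process; worker memory would
# need RUSAGE_CHILDREN or an external tool such as 'time'.
import resource
import time

start = time.monotonic()
# ... run the multi-image inference here ...
elapsed = time.monotonic() - start

max_rss_kib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss  # KiB on Linux
print(f"elapsed: {elapsed:.1f}s, max RSS: {max_rss_kib / (1024 * 1024):.1f} GiB")
```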

@DarkLight1337
Member

DarkLight1337 commented Apr 10, 2025

Just to be sure, what are you referring to by "baseline" here? Is it main branch?

@p88h
Contributor Author

p88h commented Apr 10, 2025

@DarkLight1337 baseline = main branch @24f6b9a71397539a3d02c801963220b0e9a2aef9 (=yesterday, before #16273 was applied)

@njhill
Member

njhill commented Apr 10, 2025

Thanks @p88h, this is great! It's expected that #13790 by itself doesn't change much but I expect that the current combination of #13790 with your PR should give a significant reduction over your original PR without it. I understand that to test that you'd have to revert to an older version of the PR though. An easy way to measure the "zero copy" benefits would be to disable that by changing this line in serial_utils.py:

```python
        if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
```

to

```python
        if True:
```
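
(Rough sketch of the pattern that condition controls - not vLLM's actual serial_utils.py; the threshold value below is a placeholder:)

```python
# Tensors below a size threshold are copied inline into the serialized
# message; larger ones are handed over as separate out-of-band buffers so
# they can be sent without an extra copy.
import torch

INLINE_BUF_SIZE_THRESHOLD = 256  # bytes; placeholder, not the real constant

def encode_tensor(t: torch.Tensor, out_of_band: list) -> dict:
    arr = t.detach().cpu().numpy()
    if not t.shape or t.nbytes < INLINE_BUF_SIZE_THRESHOLD:
        # Small (or zero-dim) tensor: an inline copy travels inside the message.
        return {"dtype": str(arr.dtype), "shape": tuple(t.shape), "inline": arr.tobytes()}
    # Large tensor: only metadata goes into the message; the buffer itself is
    # appended as a separate frame (e.g. an extra zmq frame) and referenced by index.
    out_of_band.append(arr.data)
    return {"dtype": str(arr.dtype), "shape": tuple(t.shape), "buf": len(out_of_band) - 1}
```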

@p88h
Contributor Author

p88h commented Apr 11, 2025

Interesting...

With inline mode, I do get warnings about read-only Tensors.
I can either ignore them or copy the memory out in this case.

But that's not the interesting bit: it seems that the overall memory usage is significantly lower with inline mode, and performance is back to single-process levels, at least for the one test case I've run so far.

I'll add some benchmark results to the PR and figure out a way to make this work efficiently by default.

@p88h
Contributor Author

p88h commented Apr 11, 2025

I added benchmark results in the PR.
I've uploaded the benchmark I wrote for this here: https://github.com/p88h/fake-vqa
