Commit 6de80f9

offload weights to cpu before fp8 online quant

Signed-off-by: yan <[email protected]>
1 parent f17aec0, commit 6de80f9

6 files changed: +56 -25 lines
docs/features/quantization/fp8.md (1 addition, 1 deletion)

````diff
@@ -135,4 +135,4 @@ print(result[0].outputs[0].text)
 ```
 
 !!! warning
-    Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model.
+    Currently, by default we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. To avoid this, set `VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT=1` to offload the weights to CPU before quantization; the quantized weights are then kept on the device.
````
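For reference, a minimal usage sketch of the documented flow (an illustration, not part of the commit; it assumes an FP8-capable GPU and borrows the model name from the new test below):

```python
import os

# Enable CPU offload of the unquantized weights before FP8 online quantization.
# Set this before constructing the engine so vLLM picks it up.
os.environ["VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT"] = "1"

from vllm import LLM, SamplingParams

# Online FP8 quantization of an unquantized checkpoint; with the flag set,
# only the quantized weights end up on the GPU.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", quantization="fp8")
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)))
```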

tests/quantization/test_cpu_offload.py (11 additions, 1 deletion)

```diff
@@ -1,4 +1,4 @@
-# SPDX-License-Identifier: Apache-2.0
+# SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 # Expanded quantized model tests for CPU offloading
@@ -11,6 +11,16 @@
 from ..utils import compare_two_settings
 
 
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="fp8 is not supported on this GPU type.")
+def test_offload_weights_before_quant_fp8():
+    # Test quantization of an unquantized checkpoint
+    compare_two_settings("meta-llama/Llama-3.2-1B-Instruct",
+                         ["--quantization", "fp8"], ["--quantization", "fp8"],
+                         {"VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT": "1"},
+                         max_wait_seconds=480)
+
+
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
 def test_cpu_offload_fp8():
```
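In this test, `compare_two_settings` presumably launches one server per argument set (both using `--quantization fp8`), applies the extra environment dict only to the first, and checks that the outputs match, so the offload-before-quant path is validated against the default loading path.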

vllm/config.py (4 additions, 0 deletions)

```diff
@@ -1551,6 +1551,10 @@ def _verify_args(self) -> Self:
             raise ValueError("CPU offload space must be non-negative"
                              f", but got {self.cpu_offload_gb}")
 
+        if self.cpu_offload_gb > 0 and envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT:
+            raise ValueError("CPU offload can't work together with "
+                             "VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT")
+
         if self.gpu_memory_utilization > 1.0:
             raise ValueError(
                 "GPU memory utilization must be less than 1.0. Got "
```

vllm/envs.py (5 additions, 0 deletions)

```diff
@@ -133,6 +133,7 @@
     VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
     VLLM_KV_CACHE_LAYOUT: Optional[str] = None
     VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
+    VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT: bool = False
 
 
 def get_default_cache_root():
@@ -918,6 +919,10 @@ def get_vllm_port() -> Optional[int]:
     # or bad hardware but it may add compute overhead.
     "VLLM_COMPUTE_NANS_IN_LOGITS":
     lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))),
+
+    # Offload model weights to cpu before online fp8 quantization
+    "VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT":
+    lambda: os.environ.get("VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT", "0") == "1",
 }
 
 # --8<-- [end:env-vars-definition]
```
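A small sketch of how the new flag is read (illustrative; note that, unlike the neighboring `bool(int(...))` flags, only the literal string `"1"` enables it):

```python
import os

os.environ["VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT"] = "1"

import vllm.envs as envs

# vllm.envs resolves variables lazily through the lambda above, so the value
# is taken from the environment at attribute-access time.
print(envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT)  # True

os.environ["VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT"] = "true"
print(envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT)  # False: only "1" enables it
```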

vllm/model_executor/layers/quantization/fp8.py (10 additions, 3 deletions)

```diff
@@ -246,10 +246,14 @@ def create_weights(
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
 
+        # Force offloading weights to cpu if VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT
+        # enabled, otherwise use original device config which can be gpu or cpu
+        # (may happen when cpu_offload_gb > 0)
         weight = ModelWeightParameter(data=torch.empty(
             output_size_per_partition,
             input_size_per_partition,
-            dtype=weight_dtype),
+            dtype=weight_dtype,
+            device="cpu" if envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT else None),
             input_dim=1,
             output_dim=0,
             weight_loader=weight_loader)
@@ -513,16 +517,19 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             num_experts,
             2 * intermediate_size_per_partition,
             hidden_size,
-            dtype=params_dtype),
+            dtype=params_dtype,
+            device="cpu" if envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT else None),
             requires_grad=False)
+
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(torch.empty(
             num_experts,
             hidden_size,
             intermediate_size_per_partition,
-            dtype=params_dtype),
+            dtype=params_dtype,
+            device="cpu" if envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT else None),
             requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
```
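The `device="cpu" if ... else None` pattern relies on `torch.empty` treating `device=None` as "use the ambient default device". A standalone illustration (not vLLM code; the flag is replaced by a plain bool):

```python
import torch

offload = True  # stands in for envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT

# Under a default-device context, device=None inherits that device, while
# "cpu" forces host allocation so the full-precision weights never touch the
# GPU before quantization.
with torch.device("cuda" if torch.cuda.is_available() else "cpu"):
    w = torch.empty(4, 4, dtype=torch.bfloat16,
                    device="cpu" if offload else None)

print(w.device)  # always cpu when offload is True
```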

vllm/model_executor/model_loader/utils.py (25 additions, 20 deletions)

```diff
@@ -16,6 +16,7 @@
 from vllm.attention import Attention
 from vllm.config import (ModelConfig, ModelImpl, VllmConfig,
                          set_current_vllm_config)
+from vllm.envs import VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
 from vllm.model_executor.layers.quantization.base_config import (
@@ -144,26 +145,30 @@ def device_loading_context(module: torch.nn.Module,
         yield module
 
     finally:
-        # Restore parameters to their original devices, ignoring new parameters
-        pin_memory = is_pin_memory_available()
-        for name, p in module.named_parameters():
-            if name in original_device_states:
-                original_device: torch.device = original_device_states[name]
-                if original_device.type == "cpu":
-                    # `torch.empty_like` does not support `pin_memory` argument
-                    cpu_data = torch.empty_strided(
-                        size=p.data.size(),
-                        stride=p.data.stride(),
-                        dtype=p.data.dtype,
-                        layout=p.data.layout,
-                        device="cpu",
-                        pin_memory=pin_memory,
-                    )
-                    cpu_data.copy_(p.data)
-                    p.data = cpu_data
-                else:
-                    p.data = p.data.to(original_device)
-            # New parameters or parameters already on target device are untouched
+        # If weights were loaded onto the CPU for FP8 online quantization, there
+        # is no need to move them back to the original device.
+        if not VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT:
+            # Restore parameters to their original devices, ignoring new parameters # noqa: E501
+            pin_memory = is_pin_memory_available()
+            for name, p in module.named_parameters():
+                if name in original_device_states:
+                    original_device: torch.device = original_device_states[
+                        name]
+                    if original_device.type == "cpu":
+                        # `torch.empty_like` does not support `pin_memory` argument # noqa: E501
+                        cpu_data = torch.empty_strided(
+                            size=p.data.size(),
+                            stride=p.data.stride(),
+                            dtype=p.data.dtype,
+                            layout=p.data.layout,
+                            device="cpu",
+                            pin_memory=pin_memory,
+                        )
+                        cpu_data.copy_(p.data)
+                        p.data = cpu_data
+                    else:
+                        p.data = p.data.to(original_device)
+                # New parameters or parameters already on target device are untouched # noqa: E501
 
 
 def resolve_transformers_arch(model_config: ModelConfig,
```
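For orientation, a simplified, hypothetical sketch of the pattern this `finally` block belongs to (`simple_device_loading_context` and `skip_restore` are illustrative names; the real `device_loading_context` additionally uses pinned CPU buffers and leaves newly created parameters untouched):

```python
from contextlib import contextmanager

import torch


@contextmanager
def simple_device_loading_context(module: torch.nn.Module,
                                  target_device: torch.device,
                                  skip_restore: bool):
    """Temporarily move a module's parameters to target_device.

    skip_restore mirrors VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT: when True, the
    (now quantized) parameters stay wherever quantization left them instead
    of being copied back to their original devices.
    """
    original = {name: p.device for name, p in module.named_parameters()}
    module.to(target_device)
    try:
        yield module
    finally:
        if not skip_restore:
            for name, p in module.named_parameters():
                if name in original and p.device != original[name]:
                    p.data = p.data.to(original[name])
```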
