
Commit ba76f6d

trying to add quantization to Flux

1 parent: 9e390da

6 files changed: +381 −9 lines
+168
@@ -0,0 +1,168 @@
# %%
# Import the following libraries
# -----------------------------
import re

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export

# %%
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float32,
)
# Swap in a tiny single-layer transformer as the backbone so the example runs
# quickly; keep the full pretrained transformer for real workloads.
pipe.transformer = FluxTransformer2DModel(
    num_layers=1, num_single_layers=1, guidance_embeds=True
)

pipe.to(DEVICE).to(torch.float32)
# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer
backbone.eval()


def filter_func(name):
    # Keep precision-sensitive embedding, projection, and normalization
    # layers out of quantization.
    pattern = re.compile(
        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
    )
    return pattern.match(name) is not None
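
# Quick sanity check of what the filter matches; the module names below are
# illustrative of the diffusers Flux layout, not taken from this commit.
assert filter_func("time_text_embed.timestep_embedder.linear_1")
assert not filter_func("transformer_blocks.0.attn.to_q")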

def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
# %%
# Quantization


def do_calibrate(
    pipe,
    prompt: str,
) -> None:
    """
    Run a calibration pass on the pipeline using the given prompt.
    """
    pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]


def forward_loop(mod):
    # Swap the quantized module in as the pipeline's backbone, then run a
    # forward pass so ModelOpt can collect calibration statistics.
    pipe.transformer = mod
    do_calibrate(
        pipe=pipe,
        prompt="test",
    )


ptq_config = mtq.FP8_DEFAULT_CFG
backbone = mtq.quantize(backbone, ptq_config, forward_loop)
mtq.disable_quantizer(backbone, filter_func)
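
# modelopt.torch.opt (mto) is imported above but not yet used; presumably the
# intent is to checkpoint the calibrated quantizer state so calibration need
# not be repeated. A minimal sketch, assuming ModelOpt's save/restore API and
# a hypothetical file path:
mto.save(backbone, "flux_fp8_state.pth")
# On a later run, rebuild the FP32 backbone and restore the quantized state:
# backbone = mto.restore(backbone, "flux_fp8_state.pth")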

# %%
# Export the backbone using torch.export
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the
# transformer backbone with dynamic shapes and a ``batch_size=2`` due to
# `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
# (a standalone illustration follows below).
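
# The 0/1 specialization issue, in isolation: torch.export specializes example
# sizes of 0 or 1 into the graph unless the dimension is marked dynamic, hence
# the size-2 batch below. A standalone illustration (names are hypothetical):
class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2


# A size-2 example plus an explicit Dim keeps the batch symbolic; a size-1
# example would risk the graph being specialized to batch == 1.
B = torch.export.Dim("b", min=1, max=8)
ep_demo = torch.export.export(Double(), (torch.randn(2, 4),), dynamic_shapes=({0: B},))
print(ep_demo)
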
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These min/max values for the img_id input are the ones torch.dynamo recommends
# during export; to see the recommendation, try exporting with min=1, max=4096.
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float32).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float32
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float32).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float32).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float32).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}

# Create an exported program, which is then compiled with Torch-TensorRT
with export_torch_mode():
    ep = _export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
        allow_complex_guards_as_runtime_asserts=True,
    )

with torch_tensorrt.logging.debug():
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
        enabled_precisions={torch.float8_e4m3fn},
        truncate_double=True,
        min_block_size=1,
        debug=False,
        use_python_runtime=True,
        immutable_weights=True,
        offload_module_to_cpu=True,
    )

del ep
pipe.transformer = trt_gm
pipe.transformer.config = config


# %%
trt_gm.device = torch.device(DEVICE)
# Generate images from the Flux pipeline using the TensorRT-compiled backbone
for _ in range(2):
    generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# For this dummy model, the fp16 engine is around 1 GB and the fp32 engine around 2 GB
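
# Rough, hypothetical timing harness (not part of the original commit): time
# one generation with the TensorRT-compiled backbone to see what compilation buys.
import time

torch.cuda.synchronize()
start = time.perf_counter()
generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code_trt_timed")
torch.cuda.synchronize()
print(f"TRT-backed generation: {time.perf_counter() - start:.1f} s for 20 steps")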

examples/apps/flux-quantization.py

+163
@@ -0,0 +1,163 @@
# %%
# Import the following libraries
# -----------------------------
import re

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export

# Load the model architecture and weights using Hugging Face APIs; this is the
# FP16 variant of the script above.

# %%
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.transformer = FluxTransformer2DModel(
    num_layers=1, num_single_layers=1, guidance_embeds=True
)

pipe.to(DEVICE).to(torch.float16)
# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer
backbone.eval()


def filter_func(name):
    # Keep precision-sensitive embedding, projection, and normalization
    # layers out of quantization.
    pattern = re.compile(
        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
    )
    return pattern.match(name) is not None

def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
# %%
# Quantization


def do_calibrate(
    pipe,
    prompt: str,
) -> None:
    """
    Run a calibration pass on the pipeline using the given prompt.
    """
    pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]


def forward_loop(mod):
    # Swap the quantized module in as the pipeline's backbone, then run a
    # forward pass so ModelOpt can collect calibration statistics.
    pipe.transformer = mod
    do_calibrate(
        pipe=pipe,
        prompt="test",
    )


ptq_config = mtq.FP8_DEFAULT_CFG
backbone = mtq.quantize(backbone, ptq_config, forward_loop)
mtq.disable_quantizer(backbone, filter_func)

batch_size = 1
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These min/max values for the img_id input are the ones torch.dynamo recommends
# during export; to see the recommendation, try exporting with min=1, max=4096.
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}

# Create an exported program, which is then compiled with Torch-TensorRT.
# Note: the dynamic_shapes argument is currently commented out, so this
# export is static-shape only.
with export_torch_mode():
    ep = _export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        # dynamic_shapes=dynamic_shapes,
        strict=False,
        allow_complex_guards_as_runtime_asserts=True,
    )

with torch_tensorrt.logging.debug():
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
        enabled_precisions={torch.float8_e4m3fn, torch.float16},
        truncate_double=True,
        min_block_size=1,
        debug=True,
        use_python_runtime=True,
        immutable_weights=True,
        offload_module_to_cpu=True,
    )

del ep
pipe.transformer = trt_gm
pipe.transformer.config = config


# %%
trt_gm.device = torch.device(DEVICE)
# Generate images from the Flux pipeline using the TensorRT-compiled backbone
for _ in range(2):
    generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# For this dummy model, the fp16 engine is around 1 GB and the fp32 engine around 2 GB
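
# A rough, hypothetical way to corroborate the size note above (not part of the
# original commit): report peak device memory after compilation and generation.
print(f"Peak CUDA memory: {torch.cuda.max_memory_allocated(DEVICE) / 2**30:.1f} GiB")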

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+3 −1

This marks the quantize_op converter as supporting dynamic shapes, so it is not skipped when the graph it appears in (such as the Flux backbone exported above with dynamic dimensions) is compiled.

@@ -597,7 +597,9 @@ def aten_ops_neg(
     )
 else:
 
-    @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default)
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True
+    )
     def aten_ops_quantize_op(
         ctx: ConversionContext,
         target: Target,
