Added flux demo #3418

Open
wants to merge 23 commits into main
Commits (23)
f85820c
Added CPU offloading
cehongwang Mar 26, 2025
4242743
Changed CPU offload to default
cehongwang Mar 27, 2025
e87d27e
Added support to module with graph break
cehongwang Mar 27, 2025
76cab94
Added back the control flag and fixed the CI
cehongwang Apr 7, 2025
6352110
Changed CPU offload to default
cehongwang Mar 27, 2025
214e2e6
Added flux demo
cehongwang Feb 27, 2025
c9d8456
changed the file place and deleted unnecessary code
cehongwang Feb 28, 2025
e77737d
Fixed memory overhead and enabled Flux with Mutable Module
cehongwang Mar 3, 2025
42c384d
Supported LoRA
cehongwang Mar 13, 2025
33db1cb
Refined Flux demo, solved a bug of device mismatch, and prototyped Cu…
cehongwang Mar 18, 2025
c69f41c
Enabled Cuda Graph
cehongwang Mar 18, 2025
8f44d7f
Enabled weight streaming and CudaGraph. Supported MTTM saving with dy…
cehongwang Mar 18, 2025
044f4e6
Changed the Refitting test to disable CPU offload
cehongwang Mar 23, 2025
d383be4
Fixed Cuda Error
cehongwang Mar 23, 2025
580fc03
Fixed the bug of SDXL Cuda Error
cehongwang Mar 25, 2025
3e8323f
Changed the way to enable CudaGraph for MTTM
cehongwang Mar 25, 2025
92ae47d
Finalize the refit revision
cehongwang Mar 26, 2025
98cbd76
Fixed the comments
cehongwang Mar 27, 2025
39ac60e
Correct the flux export example
cehongwang Mar 27, 2025
6caf833
Added a textbox to display time the generation process takes
cehongwang Mar 31, 2025
c018151
Added perf script
cehongwang Apr 3, 2025
9e390da
added back control flag
cehongwang Apr 9, 2025
ba76f6d
trying to add quantization to Flux
cehongwang Apr 12, 2025
Binary file added examples/apps/NGRVNG.safetensors
Binary file not shown.
154 changes: 154 additions & 0 deletions examples/apps/flux-demo.py
@@ -0,0 +1,154 @@
import time

import gradio as gr
import torch
import torch_tensorrt
from diffusers import FluxPipeline

DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)
backbone = pipe.transformer


batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=8)

# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
dynamic_shapes = {
"hidden_states": {0: BATCH},
"encoder_hidden_states": {0: BATCH},
"pooled_projections": {0: BATCH},
"timestep": {0: BATCH},
"txt_ids": {},
"img_ids": {},
"guidance": {0: BATCH},
"joint_attention_kwargs": {},
"return_dict": None,
}

settings = {
"strict": False,
"allow_complex_guards_as_runtime_asserts": True,
"enabled_precisions": {torch.float32},
"truncate_double": True,
"min_block_size": 1,
"use_fp32_acc": True,
"use_explicit_typing": True,
"debug": False,
"use_python_runtime": True,
"immutable_weights": False,
"enable_cuda_graph": True,
}

trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
pipe.transformer = trt_gm


def generate_image(prompt, inference_step, batch_size=2):
start_time = time.time()
image = pipe(
prompt,
output_type="pil",
num_inference_steps=inference_step,
num_images_per_prompt=batch_size,
).images
end_time = time.time()
return image, end_time - start_time


generate_image(["Test"], 2)
torch.cuda.empty_cache()


def model_change(model):
if model == "Torch Model":
pipe.transformer = backbone
backbone.to(DEVICE)
else:
backbone.to("cpu")
pipe.transformer = trt_gm
torch.cuda.empty_cache()


def load_lora(path):

pipe.load_lora_weights(
path,
adapter_name="lora1",
)
pipe.set_adapters(["lora1"], adapter_weights=[1])
pipe.fuse_lora()
pipe.unload_lora_weights()
print("LoRA loaded! Begin refitting")
generate_image(["Test"], 2)
print("Refitting Finished!")


# Create Gradio interface
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

with gr.Row():
with gr.Column():
# Input components
prompt_input = gr.Textbox(
label="Prompt", placeholder="Enter your prompt here...", lines=3
)
model_dropdown = gr.Dropdown(
choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
value="Torch-TensorRT Accelerated Model",
label="Model Variant",
)

lora_upload_path = gr.Textbox(
label="LoRA Path",
placeholder="Enter the LoRA checkpoint path here",
value="/home/TensorRT/examples/apps/NGRVNG.safetensors",
lines=2,
)
num_steps = gr.Slider(
minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
)
batch_size = gr.Slider(
minimum=1, maximum=8, value=1, step=1, label="Batch Size"
)

generate_btn = gr.Button("Generate Image")
load_lora_btn = gr.Button("Load LoRA")

with gr.Column():
# Output component
output_image = gr.Gallery(label="Generated Image")
time_taken = gr.Textbox(
label="Generation Time (seconds)", interactive=False
)

# Connect the button to the generation function
model_dropdown.change(model_change, inputs=[model_dropdown])
load_lora_btn.click(
fn=load_lora,
inputs=[
lora_upload_path,
],
)

# Update generate button click to include time output
generate_btn.click(
fn=generate_image,
inputs=[
prompt_input,
num_steps,
batch_size,
],
outputs=[output_image, time_taken],
)

# Launch the interface
if __name__ == "__main__":
demo.launch()
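
For readers who want to exercise the dynamic batch dimension without launching the Gradio UI, a minimal headless sketch follows. It assumes the script above has already run (so pipe, trt_gm, and generate_image exist and the warm-up call has triggered compilation); the prompt, step count, and batch sizes are illustrative and stay within the 1-8 range declared for the BATCH dim.

# Hypothetical headless benchmark reusing generate_image() defined above.
# Assumes the TensorRT engine was already built by the earlier warm-up call.
for bs in (1, 2, 4):
    torch.cuda.synchronize()
    images, elapsed = generate_image(["A lighthouse at dawn"], 20, batch_size=bs)
    torch.cuda.synchronize()
    print(f"batch_size={bs}: {elapsed:.2f} s total, {elapsed / bs:.2f} s per image")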
168 changes: 168 additions & 0 deletions examples/apps/flux-quantization-fp32.py
@@ -0,0 +1,168 @@
# %%
# Import the following libraries
# -----------------------------
import re

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export
from transformers import AutoModelForCausalLM

# %%
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.float32,
)
pipe.transformer = FluxTransformer2DModel(
num_layers=1, num_single_layers=1, guidance_embeds=True
)

pipe.to(DEVICE).to(torch.float32)
# Store the config and transformer backbone
config = pipe.transformer.config
# global backbone
backbone = pipe.transformer
backbone.eval()


def filter_func(name):
pattern = re.compile(
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
)
return pattern.match(name) is not None


def generate_image(pipe, prompt, image_name):
seed = 42
image = pipe(
prompt,
output_type="pil",
num_inference_steps=20,
generator=torch.Generator("cuda").manual_seed(seed),
).images[0]
image.save(f"{image_name}.png")
print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# %%
# Quantization


def do_calibrate(
pipe,
prompt: str,
) -> None:
"""
Run calibration steps on the pipeline using the given prompts.
"""
image = pipe(
prompt,
output_type="pil",
num_inference_steps=20,
generator=torch.Generator("cuda").manual_seed(0),
).images[0]


def forward_loop(mod):
# Switch the pipeline's backbone, run calibration
pipe.transformer = mod
do_calibrate(
pipe=pipe,
prompt="test",
)


ptq_config = mtq.FP8_DEFAULT_CFG
backbone = mtq.quantize(backbone, ptq_config, forward_loop)
mtq.disable_quantizer(backbone, filter_func)


# %%
# Export the backbone using torch.export
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a ``batch_size=2``
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_

batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
"hidden_states": {0: BATCH},
"encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
"pooled_projections": {0: BATCH},
"timestep": {0: BATCH},
"txt_ids": {0: SEQ_LEN},
"img_ids": {0: IMG_ID},
"guidance": {0: BATCH},
"joint_attention_kwargs": {},
"return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
"hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float32).to(
DEVICE
),
"encoder_hidden_states": torch.randn(
(batch_size, 512, 4096), dtype=torch.float32
).to(DEVICE),
"pooled_projections": torch.randn((batch_size, 768), dtype=torch.float32).to(
DEVICE
),
"timestep": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
"txt_ids": torch.randn((512, 3), dtype=torch.float32).to(DEVICE),
"img_ids": torch.randn((4096, 3), dtype=torch.float32).to(DEVICE),
"guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
"joint_attention_kwargs": {},
"return_dict": False,
}

# This will create an exported program which is going to be compiled with Torch-TensorRT
with export_torch_mode():
ep = _export(
backbone,
args=(),
kwargs=dummy_inputs,
dynamic_shapes=dynamic_shapes,
strict=False,
allow_complex_guards_as_runtime_asserts=True,
)

with torch_tensorrt.logging.debug():
trt_gm = torch_tensorrt.dynamo.compile(
ep,
inputs=dummy_inputs,
enabled_precisions={torch.float8_e4m3fn},
truncate_double=True,
min_block_size=1,
debug=False,
use_python_runtime=True,
immutable_weights=True,
offload_module_to_cpu=True,
)


del ep
pipe.transformer = trt_gm
pipe.transformer.config = config


# %%
trt_gm.device = torch.device(DEVICE)
# Generate images from the Flux pipeline using the TensorRT-compiled backbone

for _ in range(2):
generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB
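
The engine-size comparison in the closing comment implies an FP16 build alongside the FP32 one above. A minimal sketch of how that variant might look is shown below; it is an assumption, not part of this PR: it supposes the backbone is re-exported (the ep above is deleted after the FP32 compile) and that FP16 is simply added to the enabled precisions, mirroring the half-precision settings used in flux-demo.py.

# Hypothetical FP16 variant (illustrative only; not the PR's actual script).
with export_torch_mode():
    ep_fp16 = _export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
        allow_complex_guards_as_runtime_asserts=True,
    )

trt_gm_fp16 = torch_tensorrt.dynamo.compile(
    ep_fp16,
    inputs=dummy_inputs,
    enabled_precisions={torch.float8_e4m3fn, torch.float16},
    truncate_double=True,
    min_block_size=1,
    use_python_runtime=True,
    immutable_weights=True,
    offload_module_to_cpu=True,
)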