
Commit 87261bd

FLUX memory management improvements (invoke-ai#6791)
## Summary

This PR contains several improvements to memory management for FLUX workflows. It is now possible to achieve better FLUX model caching performance, but this still requires users to manually configure their `ram`/`vram` settings. For example, a `vram` setting of 16.0 should allow all of the quantized FLUX models to be kept in memory on the GPU.

Changes:

- Check the size of a model on disk and free the requisite space in the model cache before loading it. (This behaviour existed previously, but was removed in https://github.com/invoke-ai/InvokeAI/pull/6072/files. The removal did not seem to be intentional.) A simplified sketch of this idea appears after the checklist below.
- Removed the hack that freed 24GB of space in the cache before loading the FLUX model.
- Split the T5 embedding and CLIP embedding steps into separate functions so that the two models do not both have to be held in RAM at the same time.
- Fixed a bug in `InvokeLinear8bitLt` that was causing some tensors to be left on the GPU when the model was offloaded to the CPU. (This class is getting very messy due to the non-standard state_dict handling in `bnb.nn.Linear8bitLt`.)
- Tidied up some dtype handling in `FluxTextToImageInvocation` to avoid situations where we hold references to two copies of the same tensor unnecessarily.
- (minor) Misc cleanup of `ModelCache`: improved docs and removed unused vars.

Future: we should revisit our default `ram`/`vram` configs. The current defaults are very conservative, and users could see major performance improvements from tuning these values.

## QA Instructions

I tested the FLUX workflow with the following configurations and verified that the cache hit rates and memory usage matched the expected behaviour:

- `ram = 16` and `vram = 16`
- `ram = 16` and `vram = 1`
- `ram = 1` and `vram = 1`

Note that the changes in this PR are not isolated to FLUX. Since we now check the size of models on disk, we may see slight changes in model cache offload patterns for other models as well.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
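To illustrate the "make room before load" idea from the first bullet, here is a minimal sketch of an LRU cache that evicts entries until a model of a given on-disk size fits. This is not the InvokeAI `ModelCache` API: `SimpleModelCache`, its methods, and the file path are hypothetical placeholders.

```python
import os
from collections import OrderedDict


class SimpleModelCache:
    """Toy LRU cache keyed by model name; sizes are tracked in bytes (illustrative only)."""

    def __init__(self, max_bytes: int) -> None:
        self.max_bytes = max_bytes
        self._entries = OrderedDict()  # key -> (model, size_bytes), oldest first

    def _used_bytes(self) -> int:
        return sum(size for _, size in self._entries.values())

    def make_room(self, needed_bytes: int) -> None:
        # Evict least-recently-used entries until `needed_bytes` fits within the budget.
        while self._entries and self._used_bytes() + needed_bytes > self.max_bytes:
            key, (_, size) = self._entries.popitem(last=False)
            print(f"evicted {key} ({size / 2**30:.1f} GB)")

    def put(self, key: str, model: object, size_bytes: int) -> None:
        self.make_room(size_bytes)
        self._entries[key] = (model, size_bytes)


# Example: reserve space based on the checkpoint's size on disk before loading it.
cache = SimpleModelCache(max_bytes=16 * 2**30)   # roughly a `ram = 16` budget
model_path = "flux-transformer.safetensors"      # placeholder path
if os.path.exists(model_path):
    cache.make_room(os.path.getsize(model_path))
```

In the actual change, the bytes passed to `make_room` come from the model's size on disk; see the `load_default.py` diff below.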
2 parents 3e569c8 + 4e4b6c6 commit 87261bd

File tree: 10 files changed (+114, -118 lines)


invokeai/app/invocations/flux_text_encoder.py

Lines changed: 15 additions & 9 deletions
@@ -40,20 +40,18 @@ class FluxTextEncoderInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
-        t5_embeddings, clip_embeddings = self._encode_prompt(context)
+        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
+        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
+        t5_embeddings = self._t5_encode(context)
+        clip_embeddings = self._clip_encode(context)
 
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
 
         conditioning_name = context.conditioning.save(conditioning_data)
         return FluxConditioningOutput.build(conditioning_name)
 
-    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        # Load CLIP.
-        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
-
-        # Load T5.
+    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
         t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
 
@@ -70,6 +68,15 @@ def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
 
         prompt_embeds = t5_encoder(prompt)
 
+        assert isinstance(prompt_embeds, torch.Tensor)
+        return prompt_embeds
+
+    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        prompt = [self.prompt]
+
         with (
             clip_text_encoder_info as clip_text_encoder,
             clip_tokenizer_info as clip_tokenizer,
@@ -81,6 +88,5 @@ def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
 
         pooled_prompt_embeds = clip_encoder(prompt)
 
-        assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
-        return prompt_embeds, pooled_prompt_embeds
+        return pooled_prompt_embeds
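The note in the diff above relies on each encoder reference being local to its own function. A minimal, self-contained sketch of that scoping pattern follows; `_DummyEncoder` and the surrounding function names are stand-ins, not InvokeAI code.

```python
import gc

import torch


class _DummyEncoder:
    """Stand-in for a large text encoder; holds a big weight tensor to make the point."""

    def __init__(self, dim: int) -> None:
        self.weight = torch.zeros(dim, dim)

    def encode(self, prompt: str) -> torch.Tensor:
        return torch.zeros(1, len(prompt), self.weight.shape[0])


def _t5_encode(prompt: str) -> torch.Tensor:
    t5 = _DummyEncoder(4096)  # stands in for loading the T5 encoder
    return t5.encode(prompt)  # `t5` goes out of scope when this function returns


def _clip_encode(prompt: str) -> torch.Tensor:
    clip = _DummyEncoder(768)  # stands in for loading the CLIP encoder
    return clip.encode(prompt)


def encode_prompt(prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
    t5_embeds = _t5_encode(prompt)
    gc.collect()  # the T5 stand-in is unreachable here, so its memory can be reclaimed
    clip_embeds = _clip_encode(prompt)
    return t5_embeds, clip_embeds


t5_embeds, clip_embeds = encode_prompt("a photo of a cat")
```

Because the T5 object only ever lives inside `_t5_encode`, nothing keeps it reachable once that function returns, so its memory can be reclaimed before the CLIP model is loaded.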

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 15 additions & 18 deletions
@@ -58,26 +58,28 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
         inference_dtype = torch.bfloat16
 
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
         # Prepare input noise.
         x = get_noise(
             num_samples=1,
@@ -88,24 +90,19 @@
             seed=self.seed,
         )
 
-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)
 
         is_schnell = "schnell" in transformer_info.config.config_path
 
         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )
 
         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
 
-        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
-        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
-        # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
-
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)
 
@@ -140,7 +137,7 @@ def step_callback() -> None:
 
             x = denoise(
                 model=transformer,
-                img=img,
+                img=x,
                 img_ids=img_ids,
                 txt=t5_embeddings,
                 txt_ids=txt_ids,
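The dtype changes above avoid keeping both the original and the converted tensors alive at the same time. A small sketch of the underlying behaviour (the shapes and dtypes here are arbitrary examples, not values from the workflow):

```python
import torch

embeds = torch.randn(1, 512, 4096, dtype=torch.float32)

# Anti-pattern: `.to(dtype=...)` returns a new tensor when the dtype changes, so
# keeping both names alive means both copies stay in memory.
converted = embeds.to(dtype=torch.bfloat16)
assert embeds.element_size() == 4 and converted.element_size() == 2
del converted

# Preferred: rebind the same name, so the float32 copy loses its last reference
# and can be freed once the conversion completes.
embeds = embeds.to(dtype=torch.bfloat16)
assert embeds.dtype == torch.bfloat16
```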

invokeai/backend/flux/sampling.py

Lines changed: 4 additions & 13 deletions
@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     img = repeat(img, "1 ... -> bs ...", bs=bs)
 
     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
     return img, img_ids

invokeai/backend/model_manager/load/load_default.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
             pass
 
         config.path = str(self._get_model_path(config))
+        self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
         loaded_model = self._load_model(config, submodel_type)
 
         self._ram_cache.put(

invokeai/backend/model_manager/load/model_cache/model_cache_base.py

Lines changed: 0 additions & 9 deletions
@@ -193,15 +193,6 @@ def get(
         """
         pass
 
-    @abstractmethod
-    def exists(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> bool:
-        """Return true if the model identified by key and submodel_type is in the cache."""
-        pass
-
     @abstractmethod
     def cache_size(self) -> int:
         """Get the total size of the models currently cached."""