Commit 4e4b6c6

Tidy variable management and dtype handling in FluxTextToImageInvocation.

1 parent: 5e8cf9f
3 files changed: +24 -26 lines

invokeai/app/invocations/flux_text_to_image.py
15 additions & 13 deletions

@@ -58,26 +58,28 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
         inference_dtype = torch.bfloat16
 
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
         # Prepare input noise.
         x = get_noise(
             num_samples=1,

@@ -88,13 +90,13 @@ def _run_diffusion(
             seed=self.seed,
         )
 
-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)
 
         is_schnell = "schnell" in transformer_info.config.config_path
 
         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )
 

@@ -135,7 +137,7 @@ def step_callback() -> None:
 
         x = denoise(
             model=transformer,
-            img=img,
+            img=x,
             img_ids=img_ids,
             txt=t5_embeddings,
             txt_ids=txt_ids,
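
Taken together, these hunks move conditioning loading out of invoke() and into _run_diffusion(), cast the embeddings to inference_dtype as soon as they are loaded, and reuse x rather than a separate img variable after patching. Below is a minimal sketch of why the up-front cast matters; the shapes are hypothetical stand-ins for the tensors that context.conditioning.load actually returns:

import torch

inference_dtype = torch.bfloat16
t5_embeddings = torch.randn(1, 512, 4096)   # loaded as float32 (illustrative shape)
clip_embeddings = torch.randn(1, 768)       # loaded as float32 (illustrative shape)

# Cast once, right after loading, so every tensor handed to the transformer
# already matches its bfloat16 weights.
t5_embeddings = t5_embeddings.to(dtype=inference_dtype)
clip_embeddings = clip_embeddings.to(dtype=inference_dtype)

# A bfloat16 matmul inside the model now works without the defensive
# per-call casts that denoise() used to perform (removed below).
w = torch.randn(4096, 3072, dtype=inference_dtype)  # stand-in for a weight matrix
hidden = t5_embeddings @ w
assert hidden.dtype == inference_dtype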

invokeai/backend/flux/sampling.py
4 additions & 13 deletions

@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
 for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor,
     img = repeat(img, "1 ... -> bs ...", bs=bs)
 
     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
     return img, img_ids
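
Constructing img_ids with the latent's device and dtype is what lets the defensive casts in denoise() go away, plausibly also avoiding an implicit float32 upcast when the ids meet bfloat16 tensors downstream. A self-contained toy version of the position-id construction, assuming a hypothetical 8x8 latent packed into 2x2 patches (einops.repeat in the original is replaced by reshape/expand here to avoid the dependency):

import torch

bs, h, w = 2, 8, 8
img = torch.randn(bs, (h // 2) * (w // 2), 64, dtype=torch.bfloat16)

# Position ids inherit img's device and dtype, as in the change above.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
img_ids[..., 1] += torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
img_ids[..., 2] += torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
img_ids = img_ids.reshape(1, -1, 3).expand(bs, -1, 3)

assert img_ids.dtype == img.dtype == torch.bfloat16
assert img_ids.shape == (bs, (h // 2) * (w // 2), 3)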

invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
5 additions & 0 deletions

@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
     clip_embeds: torch.Tensor
     t5_embeds: torch.Tensor
 
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
+        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
+        return self
+
 
 @dataclass
 class ConditioningFieldData:
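
The new to() helper mirrors torch.Tensor.to: it accepts an optional device and dtype, mutates the dataclass in place, and returns self so callers can reassign in one expression, as _run_diffusion now does. A runnable usage sketch, with the class copied from the hunk above and hypothetical embedding shapes:

import torch
from dataclasses import dataclass

@dataclass
class FLUXConditioningInfo:
    clip_embeds: torch.Tensor
    t5_embeds: torch.Tensor

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        # Mutates in place and returns self, so calls chain like torch.Tensor.to().
        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
        return self

cond = FLUXConditioningInfo(
    clip_embeds=torch.randn(1, 768),       # illustrative shape
    t5_embeds=torch.randn(1, 512, 4096),   # illustrative shape
)
cond = cond.to(dtype=torch.bfloat16)
assert cond.clip_embeds.dtype == cond.t5_embeds.dtype == torch.bfloat16

Note that because to() mutates rather than copies, the underlying conditioning object is changed for any other holder of the reference; that matches how the invocation uses it here.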
