@@ -58,26 +58,28 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
58
58
59
59
@torch .no_grad ()
60
60
def invoke (self , context : InvocationContext ) -> ImageOutput :
61
- # Load the conditioning data.
62
- cond_data = context .conditioning .load (self .positive_text_conditioning .conditioning_name )
63
- assert len (cond_data .conditionings ) == 1
64
- flux_conditioning = cond_data .conditionings [0 ]
65
- assert isinstance (flux_conditioning , FLUXConditioningInfo )
66
-
67
- latents = self ._run_diffusion (context , flux_conditioning .clip_embeds , flux_conditioning .t5_embeds )
61
+ latents = self ._run_diffusion (context )
68
62
image = self ._run_vae_decoding (context , latents )
69
63
image_dto = context .images .save (image = image )
70
64
return ImageOutput .build (image_dto )
71
65
72
66
def _run_diffusion (
73
67
self ,
74
68
context : InvocationContext ,
75
- clip_embeddings : torch .Tensor ,
76
- t5_embeddings : torch .Tensor ,
77
69
):
78
- transformer_info = context .models .load (self .transformer .transformer )
79
70
inference_dtype = torch .bfloat16
80
71
72
+ # Load the conditioning data.
73
+ cond_data = context .conditioning .load (self .positive_text_conditioning .conditioning_name )
74
+ assert len (cond_data .conditionings ) == 1
75
+ flux_conditioning = cond_data .conditionings [0 ]
76
+ assert isinstance (flux_conditioning , FLUXConditioningInfo )
77
+ flux_conditioning = flux_conditioning .to (dtype = inference_dtype )
78
+ t5_embeddings = flux_conditioning .t5_embeds
79
+ clip_embeddings = flux_conditioning .clip_embeds
80
+
81
+ transformer_info = context .models .load (self .transformer .transformer )
82
+
81
83
# Prepare input noise.
82
84
x = get_noise (
83
85
num_samples = 1 ,
@@ -88,13 +90,13 @@ def _run_diffusion(
88
90
seed = self .seed ,
89
91
)
90
92
91
- img , img_ids = prepare_latent_img_patches (x )
93
+ x , img_ids = prepare_latent_img_patches (x )
92
94
93
95
is_schnell = "schnell" in transformer_info .config .config_path
94
96
95
97
timesteps = get_schedule (
96
98
num_steps = self .num_steps ,
97
- image_seq_len = img .shape [1 ],
99
+ image_seq_len = x .shape [1 ],
98
100
shift = not is_schnell ,
99
101
)
100
102
@@ -135,7 +137,7 @@ def step_callback() -> None:
135
137
136
138
x = denoise (
137
139
model = transformer ,
138
- img = img ,
140
+ img = x ,
139
141
img_ids = img_ids ,
140
142
txt = t5_embeddings ,
141
143
txt_ids = txt_ids ,
0 commit comments