@@ -130,17 +130,31 @@ def get_models(
     )
 
 
-def text2img_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+@torch.no_grad()
+def text2img_dataloader(
+    train_dataset,
+    train_batch_size,
+    tokenizer,
+    vae,
+    text_encoder,
+    cached_latents: bool = False,
+):
+
+    if cached_latents:
+        cached_latents_dataset = []
+        for idx in tqdm(range(len(train_dataset))):
+            batch = train_dataset[idx]
+            # print(batch)
+            latents = vae.encode(
+                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
+            ).latent_dist.sample()
+            latents = latents * 0.18215
+            batch["instance_images"] = latents.squeeze(0)
+            cached_latents_dataset.append(batch)
+
     def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
-
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
-        if examples[0].get("class_prompt_ids", None) is not None:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
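The hunk above is the heart of the change: when `cached_latents` is set, every training image is pushed through the frozen VAE exactly once, under `torch.no_grad()`, and the resulting latent replaces the raw pixels in a cached copy of the dataset. The `0.18215` factor is Stable Diffusion's standard latent scaling constant. A minimal, self-contained sketch of the same idea, assuming a diffusers-style `AutoencoderKL` and a list of CHW image tensors (names here are illustrative, not the PR's API):

```python
import torch

@torch.no_grad()
def precompute_latents(images, vae, scale=0.18215):
    """Encode each image once so training never has to run the VAE again."""
    cached = []
    for image in images:
        # Sample from the encoder posterior, then apply the SD scale factor.
        latent = vae.encode(
            image.unsqueeze(0).to(vae.device, dtype=vae.dtype)
        ).latent_dist.sample()
        cached.append(latent.squeeze(0) * scale)
    return cached
```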
@@ -168,27 +182,60 @@ def collate_fn(examples):
         shuffle=True,
         collate_fn=collate_fn,
     )
+    if cached_latents:
+
+        train_dataloader = torch.utils.data.DataLoader(
+            cached_latents_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )
+
+        print("PTI : Using cached latents.")
+
+    else:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )
 
     return train_dataloader
 
-def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+
+def inpainting_dataloader(
+    train_dataset, train_batch_size, tokenizer, vae, text_encoder
+):
     def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
         mask_values = [example["instance_masks"] for example in examples]
-        masked_image_values = [example["instance_masked_images"] for example in examples]
+        masked_image_values = [
+            example["instance_masked_images"] for example in examples
+        ]
 
         # Concat class and instance examples for prior preservation.
         # We do this to avoid doing two forward passes.
         if examples[0].get("class_prompt_ids", None) is not None:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
             mask_values += [example["class_masks"] for example in examples]
-            masked_image_values += [example["class_masked_images"] for example in examples]
+            masked_image_values += [
+                example["class_masked_images"] for example in examples
+            ]
 
-        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
-        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
-        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()
+        pixel_values = (
+            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+        )
+        mask_values = (
+            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
+        )
+        masked_image_values = (
+            torch.stack(masked_image_values)
+            .to(memory_format=torch.contiguous_format)
+            .float()
+        )
 
         input_ids = tokenizer.pad(
             {"input_ids": input_ids},
@@ -201,7 +248,7 @@ def collate_fn(examples):
             "input_ids": input_ids,
             "pixel_values": pixel_values,
             "mask_values": mask_values,
-            "masked_image_values": masked_image_values
+            "masked_image_values": masked_image_values,
         }
 
         if examples[0].get("mask", None) is not None:
@@ -218,6 +265,7 @@ def collate_fn(examples):
 
     return train_dataloader
 
+
 def loss_step(
     batch,
     unet,
@@ -228,23 +276,30 @@ def loss_step(
     t_mutliplier=1.0,
     mixed_precision=False,
     mask_temperature=1.0,
+    cached_latents: bool = False,
 ):
     weight_dtype = torch.float32
-
-    latents = vae.encode(
-        batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
-    ).latent_dist.sample()
-    latents = latents * 0.18215
-
-    if train_inpainting:
-        masked_image_latents = vae.encode(
-            batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+    if not cached_latents:
+        latents = vae.encode(
+            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
         ).latent_dist.sample()
-        masked_image_latents = masked_image_latents * 0.18215
-        mask = F.interpolate(
-            batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
-            scale_factor=1 / 8
-        )
+        latents = latents * 0.18215
+
+        if train_inpainting:
+            masked_image_latents = vae.encode(
+                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+            ).latent_dist.sample()
+            masked_image_latents = masked_image_latents * 0.18215
+            mask = F.interpolate(
+                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
+                scale_factor=1 / 8,
+            )
+    else:
+        latents = batch["pixel_values"]
+
+        if train_inpainting:
+            masked_image_latents = batch["masked_image_latents"]
+            mask = batch["mask_values"]
 
     noise = torch.randn_like(latents)
     bsz = latents.shape[0]
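The restructured `loss_step` now has two entry paths that converge on the same latents: the uncached path encodes `batch["pixel_values"]` through the VAE as before, while the cached path treats `batch["pixel_values"]` as latents that were already encoded and scaled by the dataloader. Condensed, the branch logic amounts to something like this (a sketch, not the exact code):

```python
import torch

def resolve_latents(batch, vae, cached_latents, weight_dtype=torch.float32):
    # `batch["pixel_values"]` carries raw images when uncached,
    # pre-scaled latents when cached.
    if not cached_latents:
        pixels = batch["pixel_values"].to(dtype=weight_dtype)
        return vae.encode(pixels.to(vae.device)).latent_dist.sample() * 0.18215
    return batch["pixel_values"]
```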
@@ -260,7 +315,9 @@ def loss_step(
     noisy_latents = scheduler.add_noise(latents, noise, timesteps)
 
     if train_inpainting:
-        latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
+        latent_model_input = torch.cat(
+            [noisy_latents, mask, masked_image_latents], dim=1
+        )
     else:
         latent_model_input = noisy_latents
 
@@ -271,7 +328,9 @@ def loss_step(
                 batch["input_ids"].to(text_encoder.device)
             )[0]
 
-            model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+            model_pred = unet(
+                latent_model_input, timesteps, encoder_hidden_states
+            ).sample
     else:
 
         encoder_hidden_states = text_encoder(
@@ -311,7 +370,12 @@ def loss_step(
 
         target = target * mask
 
-    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+    loss = (
+        F.mse_loss(model_pred.float(), target.float(), reduction="none")
+        .mean([1, 2, 3])
+        .mean()
+    )
+
     return loss
 
 
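The new loss is algebraically identical to `reduction="mean"` whenever every sample in the batch has the same shape: averaging each sample over its channel and spatial dims, then averaging over the batch, reproduces the global mean. What's new is the per-sample intermediate, which makes it easy to weight or log individual examples later. A quick sanity check:

```python
import torch
import torch.nn.functional as F

pred, target = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
per_sample = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3])
# Per-sample mean of means equals the global mean for equal-shaped samples.
assert torch.allclose(per_sample.mean(), F.mse_loss(pred, target))
```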
@@ -331,6 +395,7 @@ def train_inversion(
     tokenizer,
     lr_scheduler,
     test_image_path: str,
+    cached_latents: bool,
     accum_iter: int = 1,
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
@@ -372,6 +437,7 @@ def train_inversion(
                     scheduler,
                     train_inpainting=train_inpainting,
                     mixed_precision=mixed_precision,
+                    cached_latents=cached_latents,
                 )
                 / accum_iter
             )
@@ -381,6 +447,13 @@ def train_inversion(
             loss_sum += loss.detach().item()
 
             if global_step % accum_iter == 0:
+                # print gradient of text encoder embedding
+                print(
+                    text_encoder.get_input_embeddings()
+                    .weight.grad[index_updates, :]
+                    .norm(dim=-1)
+                    .mean()
+                )
                 optimizer.step()
                 optimizer.zero_grad()
 
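The added `print` is a debugging probe: it reports the mean L2 norm of the gradient flowing into just the embedding rows of the newly added placeholder tokens (`index_updates` is presumably the boolean complement of the `index_no_updates` mask used elsewhere in the script). Wrapped as a helper, the idea looks like this (hypothetical names):

```python
import torch

def placeholder_grad_norm(text_encoder, index_updates: torch.Tensor) -> torch.Tensor:
    """Mean gradient norm over the placeholder-token embedding rows.

    A healthy inversion run should show this staying clearly non-zero;
    `index_updates` is assumed to be a boolean mask over vocab indices.
    """
    grad = text_encoder.get_input_embeddings().weight.grad  # (vocab, dim)
    return grad[index_updates, :].norm(dim=-1).mean()
```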
@@ -455,7 +528,11 @@ def train_inversion(
                     # open all images in test_image_path
                     images = []
                     for file in os.listdir(test_image_path):
-                        if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"):
+                        if (
+                            file.lower().endswith(".png")
+                            or file.lower().endswith(".jpg")
+                            or file.lower().endswith(".jpeg")
+                        ):
                             images.append(
                                 Image.open(os.path.join(test_image_path, file))
                             )
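The reformatted extension check is behavior-preserving. If it ever gets touched again, note that `str.endswith` also accepts a tuple of suffixes, which collapses the chained `or`s into one call; a sketch of that alternative:

```python
import os
from PIL import Image

def load_test_images(test_image_path: str) -> list:
    # endswith() with a tuple matches any of the suffixes in one call.
    return [
        Image.open(os.path.join(test_image_path, f))
        for f in os.listdir(test_image_path)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
```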
@@ -507,6 +584,7 @@ def perform_tuning(
    out_name: str,
    tokenizer,
    test_image_path: str,
+   cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
@@ -547,6 +625,7 @@ def perform_tuning(
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
+                cached_latents=cached_latents,
            )
            loss_sum += loss.detach().item()
 
@@ -671,18 +750,12 @@ def train(
     train_text_encoder: bool = True,
     pretrained_vae_name_or_path: str = None,
     revision: Optional[str] = None,
-    class_data_dir: Optional[str] = None,
-    stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
     use_template: Literal[None, "object", "style"] = None,
     train_inpainting: bool = False,
     placeholder_tokens: str = "",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: Optional[str] = None,
-    class_prompt: Optional[str] = None,
-    with_prior_preservation: bool = False,
-    prior_loss_weight: float = 1.0,
-    num_class_images: int = 100,
     seed: int = 42,
     resolution: int = 512,
     color_jitter: bool = True,
@@ -693,7 +766,6 @@ def train(
     save_steps: int = 100,
     gradient_accumulation_steps: int = 4,
     gradient_checkpointing: bool = False,
-    mixed_precision="fp16",
     lora_rank: int = 4,
     lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
     lora_clip_target_modules={"CLIPAttention"},
@@ -707,6 +779,7 @@ def train(
     continue_inversion: bool = False,
     continue_inversion_lr: Optional[float] = None,
     use_face_segmentation_condition: bool = False,
+    cached_latents: bool = True,
     use_mask_captioned_data: bool = False,
     mask_temperature: float = 1.0,
     scale_lr: bool = False,
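Note that `cached_latents` defaults to `True`, so existing callers of `train` silently gain caching; pass `cached_latents=False` to keep the old behavior (and you must, for inpainting, per the assert added below). A hypothetical direct call — argument values are illustrative and the module path is an assumption:

```python
from lora_diffusion.cli_lora_pti import train  # assumed module path

train(
    instance_data_dir="./data/subject",
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    placeholder_tokens="<s1>",
    initializer_tokens="person",
    cached_latents=False,  # opt out of the new default
)
```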
@@ -820,11 +893,8 @@ def train(
 
     train_dataset = PivotalTuningDatasetCapation(
         instance_data_root=instance_data_dir,
-        stochastic_attribute=stochastic_attribute,
         token_map=token_map,
         use_template=use_template,
-        class_data_root=class_data_dir if with_prior_preservation else None,
-        class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=resolution,
         color_jitter=color_jitter,
@@ -836,12 +906,19 @@ def train(
     train_dataset.blur_amount = 200
 
     if train_inpainting:
+        assert not cached_latents, "Cached latents are not supported for inpainting"
+
         train_dataloader = inpainting_dataloader(
             train_dataset, train_batch_size, tokenizer, vae, text_encoder
         )
     else:
         train_dataloader = text2img_dataloader(
-            train_dataset, train_batch_size, tokenizer, vae, text_encoder
+            train_dataset,
+            train_batch_size,
+            tokenizer,
+            vae,
+            text_encoder,
+            cached_latents=cached_latents,
         )
 
     index_no_updates = torch.arange(len(tokenizer)) != -1
@@ -860,6 +937,8 @@ def train(
     for param in params_to_freeze:
         param.requires_grad = False
 
+    if cached_latents:
+        vae = None
     # STEP 1 : Perform Inversion
     if perform_inversion:
         preview_training_batch(train_dataloader, "inversion")
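Setting `vae = None` once the latents are cached drops the script's only reference to the encoder so it can be garbage-collected; nothing downstream needs it, since `loss_step` skips the encode in the cached path. To make the freed VRAM visible to later allocations you might go one step further — a sketch, not part of the PR:

```python
import gc
import torch

def release_vae(vae):
    """Free the VAE once latents are cached; it is dead weight from here on."""
    del vae
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached blocks to the driver
```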
@@ -889,6 +968,7 @@ def train(
             text_encoder,
             train_dataloader,
             max_train_steps_ti,
+            cached_latents=cached_latents,
             accum_iter=gradient_accumulation_steps,
             scheduler=noise_scheduler,
             index_no_updates=index_no_updates,
@@ -1003,6 +1083,7 @@ def train(
             text_encoder,
             train_dataloader,
             max_train_steps_tuning,
+            cached_latents=cached_latents,
             scheduler=noise_scheduler,
             optimizer=lora_optimizers,
             save_steps=save_steps,