
Commit 23119e5

merge
2 parents: a87954c + bdd51b0

File tree: 4 files changed (+135 −84 lines)


lora_diffusion/cli_lora_pti.py (+125 −44)
@@ -130,17 +130,31 @@ def get_models(
     )


-def text2img_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+@torch.no_grad()
+def text2img_dataloader(
+    train_dataset,
+    train_batch_size,
+    tokenizer,
+    vae,
+    text_encoder,
+    cached_latents: bool = False,
+):
+
+    if cached_latents:
+        cached_latents_dataset = []
+        for idx in tqdm(range(len(train_dataset))):
+            batch = train_dataset[idx]
+            # print(batch)
+            latents = vae.encode(
+                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
+            ).latent_dist.sample()
+            latents = latents * 0.18215
+            batch["instance_images"] = latents.squeeze(0)
+            cached_latents_dataset.append(batch)
+
     def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
-
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
-        if examples[0].get("class_prompt_ids", None) is not None:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

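The new `cached_latents` path runs every training image through the VAE exactly once, up front, storing the scaled latents in place of the pixels so later steps can skip the encoder entirely. A minimal standalone sketch of the same idea, assuming a diffusers `AutoencoderKL` and an image tensor already normalized to [-1, 1] (the checkpoint name is an illustrative assumption, not from this commit):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae"  # hypothetical checkpoint
).to("cuda")

@torch.no_grad()
def encode_once(image):
    # image: (3, H, W) float tensor in [-1, 1], as the training transforms produce
    posterior = vae.encode(image.unsqueeze(0).to(vae.device, dtype=vae.dtype))
    # 0.18215 is the Stable Diffusion v1 latent scale factor used throughout this file
    return (posterior.latent_dist.sample() * 0.18215).squeeze(0)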
@@ -168,27 +182,60 @@ def collate_fn(examples):
         shuffle=True,
         collate_fn=collate_fn,
     )
+    if cached_latents:
+
+        train_dataloader = torch.utils.data.DataLoader(
+            cached_latents_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )
+
+        print("PTI : Using cached latent.")
+
+    else:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )

     return train_dataloader

-def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+
+def inpainting_dataloader(
+    train_dataset, train_batch_size, tokenizer, vae, text_encoder
+):
     def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
         mask_values = [example["instance_masks"] for example in examples]
-        masked_image_values = [example["instance_masked_images"] for example in examples]
+        masked_image_values = [
+            example["instance_masked_images"] for example in examples
+        ]

         # Concat class and instance examples for prior preservation.
         # We do this to avoid doing two forward passes.
         if examples[0].get("class_prompt_ids", None) is not None:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
             mask_values += [example["class_masks"] for example in examples]
-            masked_image_values += [example["class_masked_images"] for example in examples]
+            masked_image_values += [
+                example["class_masked_images"] for example in examples
+            ]

-        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
-        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
-        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()
+        pixel_values = (
+            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+        )
+        mask_values = (
+            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
+        )
+        masked_image_values = (
+            torch.stack(masked_image_values)
+            .to(memory_format=torch.contiguous_format)
+            .float()
+        )

         input_ids = tokenizer.pad(
             {"input_ids": input_ids},
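Two details in the hunk above are easy to miss. First, as merged, the unconditional `DataLoader` built just before the `if cached_latents:` block is immediately reassigned by one of the two branches, so the first construction is redundant but harmless. Second, the cached branch hands a plain Python list to `torch.utils.data.DataLoader`, which is valid: anything with `__len__` and `__getitem__` is treated as a map-style dataset. A self-contained check of that fact (illustrative shapes only):

import torch
from torch.utils.data import DataLoader

# A plain list works as a map-style dataset: it has __len__ and __getitem__.
samples = [{"instance_images": torch.randn(4, 64, 64)} for _ in range(8)]
loader = DataLoader(samples, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["instance_images"].shape)  # torch.Size([4, 4, 64, 64])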
@@ -201,7 +248,7 @@ def collate_fn(examples):
             "input_ids": input_ids,
             "pixel_values": pixel_values,
             "mask_values": mask_values,
-            "masked_image_values": masked_image_values
+            "masked_image_values": masked_image_values,
         }

         if examples[0].get("mask", None) is not None:
@@ -218,6 +265,7 @@ def collate_fn(examples):

     return train_dataloader

+
 def loss_step(
     batch,
     unet,
@@ -228,23 +276,30 @@ def loss_step(
     t_mutliplier=1.0,
     mixed_precision=False,
     mask_temperature=1.0,
+    cached_latents: bool = False,
 ):
     weight_dtype = torch.float32
-
-    latents = vae.encode(
-        batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
-    ).latent_dist.sample()
-    latents = latents * 0.18215
-
-    if train_inpainting:
-        masked_image_latents = vae.encode(
-            batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+    if not cached_latents:
+        latents = vae.encode(
+            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
         ).latent_dist.sample()
-        masked_image_latents = masked_image_latents * 0.18215
-        mask = F.interpolate(
-            batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
-            scale_factor=1/8
-        )
+        latents = latents * 0.18215
+
+        if train_inpainting:
+            masked_image_latents = vae.encode(
+                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+            ).latent_dist.sample()
+            masked_image_latents = masked_image_latents * 0.18215
+            mask = F.interpolate(
+                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
+                scale_factor=1 / 8,
+            )
+    else:
+        latents = batch["pixel_values"]
+
+        if train_inpainting:
+            masked_image_latents = batch["masked_image_latents"]
+            mask = batch["mask_values"]

     noise = torch.randn_like(latents)
     bsz = latents.shape[0]
@@ -260,7 +315,9 @@ def loss_step(
     noisy_latents = scheduler.add_noise(latents, noise, timesteps)

     if train_inpainting:
-        latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
+        latent_model_input = torch.cat(
+            [noisy_latents, mask, masked_image_latents], dim=1
+        )
     else:
         latent_model_input = noisy_latents


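For inpainting, the UNet input concatenates the 4-channel noisy latents, the 1-channel mask (downsampled by `scale_factor=1/8` above), and the 4-channel masked-image latents, giving 9 channels, which matches the `in_channels=9` of standard Stable Diffusion inpainting UNets. A quick shape check with illustrative tensors:

import torch

noisy_latents = torch.randn(1, 4, 64, 64)
mask = torch.randn(1, 1, 64, 64)  # after F.interpolate(..., scale_factor=1/8)
masked_image_latents = torch.randn(1, 4, 64, 64)
x = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
print(x.shape)  # torch.Size([1, 9, 64, 64])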
@@ -271,7 +328,9 @@ def loss_step(
             batch["input_ids"].to(text_encoder.device)
         )[0]

-        model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+        model_pred = unet(
+            latent_model_input, timesteps, encoder_hidden_states
+        ).sample
     else:

         encoder_hidden_states = text_encoder(
@@ -311,7 +370,12 @@ def loss_step(

         target = target * mask

-    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+    loss = (
+        F.mse_loss(model_pred.float(), target.float(), reduction="none")
+        .mean([1, 2, 3])
+        .mean()
+    )
+
     return loss


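For a `(B, C, H, W)` batch the rewritten reduction is mathematically identical to `reduction="mean"` (a mean of per-sample means over equal-sized samples equals the global mean); the explicit form just makes the per-example reduction visible. A quick equivalence check:

import torch
import torch.nn.functional as F

pred, target = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)
a = F.mse_loss(pred, target, reduction="mean")
b = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3]).mean()
print(torch.allclose(a, b))  # True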
@@ -331,6 +395,7 @@ def train_inversion(
     tokenizer,
     lr_scheduler,
     test_image_path: str,
+    cached_latents: bool,
     accum_iter: int = 1,
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
@@ -372,6 +437,7 @@ def train_inversion(
                     scheduler,
                     train_inpainting=train_inpainting,
                     mixed_precision=mixed_precision,
+                    cached_latents=cached_latents,
                 )
                 / accum_iter
             )
@@ -381,6 +447,13 @@ def train_inversion(
             loss_sum += loss.detach().item()

             if global_step % accum_iter == 0:
+                # print gradient of text encoder embedding
+                print(
+                    text_encoder.get_input_embeddings()
+                    .weight.grad[index_updates, :]
+                    .norm(dim=-1)
+                    .mean()
+                )
                 optimizer.step()
                 optimizer.zero_grad()

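The added `print` reports the mean L2 norm of the gradient on just the rows of the token-embedding matrix selected by `index_updates`, a quick sanity check that inversion is actually moving the placeholder tokens (note it fires on every optimizer step). The same pattern on a toy embedding, with all names illustrative:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
emb(torch.tensor([3, 7])).sum().backward()

index_updates = torch.zeros(10, dtype=torch.bool)
index_updates[[3, 7]] = True
# Mean L2 norm of the gradient rows being learned; untouched rows are excluded.
print(emb.weight.grad[index_updates, :].norm(dim=-1).mean())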
@@ -455,7 +528,11 @@ def train_inversion(
                 # open all images in test_image_path
                 images = []
                 for file in os.listdir(test_image_path):
-                    if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"):
+                    if (
+                        file.lower().endswith(".png")
+                        or file.lower().endswith(".jpg")
+                        or file.lower().endswith(".jpeg")
+                    ):
                         images.append(
                             Image.open(os.path.join(test_image_path, file))
                         )
@@ -507,6 +584,7 @@ def perform_tuning(
     out_name: str,
     tokenizer,
     test_image_path: str,
+    cached_latents: bool,
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
     class_token: str = "person",
@@ -547,6 +625,7 @@ def perform_tuning(
                 t_mutliplier=0.8,
                 mixed_precision=True,
                 mask_temperature=mask_temperature,
+                cached_latents=cached_latents,
             )
             loss_sum += loss.detach().item()

@@ -671,18 +750,12 @@ def train(
     train_text_encoder: bool = True,
     pretrained_vae_name_or_path: str = None,
     revision: Optional[str] = None,
-    class_data_dir: Optional[str] = None,
-    stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
     use_template: Literal[None, "object", "style"] = None,
     train_inpainting: bool = False,
    placeholder_tokens: str = "",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: Optional[str] = None,
-    class_prompt: Optional[str] = None,
-    with_prior_preservation: bool = False,
-    prior_loss_weight: float = 1.0,
-    num_class_images: int = 100,
     seed: int = 42,
     resolution: int = 512,
     color_jitter: bool = True,
@@ -693,7 +766,6 @@ def train(
     save_steps: int = 100,
     gradient_accumulation_steps: int = 4,
     gradient_checkpointing: bool = False,
-    mixed_precision="fp16",
     lora_rank: int = 4,
     lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
     lora_clip_target_modules={"CLIPAttention"},
@@ -707,6 +779,7 @@ def train(
     continue_inversion: bool = False,
     continue_inversion_lr: Optional[float] = None,
     use_face_segmentation_condition: bool = False,
+    cached_latents: bool = True,
     use_mask_captioned_data: bool = False,
     mask_temperature: float = 1.0,
     scale_lr: bool = False,
@@ -820,11 +893,8 @@ def train(

     train_dataset = PivotalTuningDatasetCapation(
         instance_data_root=instance_data_dir,
-        stochastic_attribute=stochastic_attribute,
         token_map=token_map,
         use_template=use_template,
-        class_data_root=class_data_dir if with_prior_preservation else None,
-        class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=resolution,
         color_jitter=color_jitter,
@@ -836,12 +906,19 @@ def train(
         train_dataset.blur_amount = 200

     if train_inpainting:
+        assert not cached_latents, "Cached latents not supported for inpainting"
+
         train_dataloader = inpainting_dataloader(
             train_dataset, train_batch_size, tokenizer, vae, text_encoder
         )
     else:
         train_dataloader = text2img_dataloader(
-            train_dataset, train_batch_size, tokenizer, vae, text_encoder
+            train_dataset,
+            train_batch_size,
+            tokenizer,
+            vae,
+            text_encoder,
+            cached_latents=cached_latents,
         )

     index_no_updates = torch.arange(len(tokenizer)) != -1
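Note that `cached_latents` now defaults to `True` in `train()` (see the signature change above), so inpainting runs will hit this assert unless the caller explicitly passes `cached_latents=False`.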
@@ -860,6 +937,8 @@ def train(
     for param in params_to_freeze:
         param.requires_grad = False

+    if cached_latents:
+        vae = None
     # STEP 1 : Perform Inversion
     if perform_inversion:
         preview_training_batch(train_dataloader, "inversion")
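Once the latents are cached there is no further use for the VAE, so dropping the reference lets Python reclaim it; the `cached_latents=cached_latents` plumbing below ensures `loss_step` never touches the now-`None` `vae`. To actually release GPU memory promptly, a slightly fuller pattern looks like this (a sketch only; the commit itself just assigns `None`):

import gc
import torch

def release_module(module):
    module.to("cpu")          # move weights off the GPU first
    del module                # drop the (local) reference
    gc.collect()              # collect promptly rather than eventually
    torch.cuda.empty_cache()  # return cached CUDA blocks to the driver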
@@ -889,6 +968,7 @@ def train(
             text_encoder,
             train_dataloader,
             max_train_steps_ti,
+            cached_latents=cached_latents,
             accum_iter=gradient_accumulation_steps,
             scheduler=noise_scheduler,
             index_no_updates=index_no_updates,
@@ -1003,6 +1083,7 @@ def train(
             text_encoder,
             train_dataloader,
             max_train_steps_tuning,
+            cached_latents=cached_latents,
             scheduler=noise_scheduler,
             optimizer=lora_optimizers,
             save_steps=save_steps,
