Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: cloneofsimo/lora
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: master
Choose a base ref
...
head repository: cloneofsimo/lora
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: develop
Choose a head ref
Able to merge. These branches can be automatically merged.

Commits on Feb 13, 2023

  1. Copy the full SHA
    388d666 View commit details
  2. Copy the full SHA
    a87954c View commit details
  3. merge

    aiXander committed Feb 13, 2023
    Copy the full SHA
    23119e5 View commit details
  4. xander updates

    aiXander committed Feb 13, 2023
    Copy the full SHA
    f476828 View commit details
  5. make import more flexible

    aiXander committed Feb 13, 2023
    Copy the full SHA
    720037c View commit details
  6. cleanup

    aiXander committed Feb 13, 2023
    Copy the full SHA
    ef14040 View commit details
  7. cleanup

    aiXander committed Feb 13, 2023
    Copy the full SHA
    cfde3e7 View commit details
  8. fix bug

    aiXander committed Feb 13, 2023
    Copy the full SHA
    9a77b65 View commit details
  9. fix bugs

    aiXander committed Feb 13, 2023
    Copy the full SHA
    61f1a19 View commit details
  10. bugfixes

    aiXander committed Feb 13, 2023
    Copy the full SHA
    cf69309 View commit details
  11. bugfixes

    aiXander committed Feb 13, 2023
    Copy the full SHA
    7399d82 View commit details
  12. more bugfixes

    aiXander committed Feb 13, 2023
    Copy the full SHA
    49610b5 View commit details
  13. Copy the full SHA
    7445169 View commit details
  14. Copy the full SHA
    fd154d1 View commit details
  15. minor updates

    aiXander committed Feb 13, 2023
    Copy the full SHA
    3f10e64 View commit details

Commits on Feb 14, 2023

  1. fix ti continuation

    aiXander committed Feb 14, 2023
    Copy the full SHA
    2b214d0 View commit details
  2. Various Loggings, bugfixes, enhancements by Xander

    Xander updates
    cloneofsimo authored Feb 14, 2023
    Copy the full SHA
    871bf81 View commit details

Commits on Feb 15, 2023

  1. format and bugfix

    cloneofsimo committed Feb 15, 2023
    Copy the full SHA
    3026c58 View commit details
  2. now unet trains

    cloneofsimo committed Feb 15, 2023
    Copy the full SHA
    c52dfae View commit details
  3. format : black

    cloneofsimo committed Feb 15, 2023
    Copy the full SHA
    b1853d1 View commit details
  4. bugfix : to L

    cloneofsimo committed Feb 15, 2023
    Copy the full SHA
    9a6b552 View commit details
  5. version

    cloneofsimo committed Feb 15, 2023
    Copy the full SHA
    71c8c1d View commit details

Commits on Feb 17, 2023

  1. Copy the full SHA
    cdbaf7a View commit details
  2. Change typing import to support python 3.7

    * Remove unused import
    A2va committed Feb 17, 2023
    Copy the full SHA
    c45d59b View commit details
  3. Other typing fixes

    A2va committed Feb 17, 2023
    Copy the full SHA
    246c068 View commit details
  4. Update requirements.txt

    A2va committed Feb 17, 2023
    Copy the full SHA
    46116b8 View commit details

Commits on Feb 21, 2023

  1. Merge pull request #193 from A2va/caption-list

    Support captions list in preprocess files
    cloneofsimo authored Feb 21, 2023
    Copy the full SHA
    a0fb8e6 View commit details
  2. Copy the full SHA
    b85aa80 View commit details
  3. Merge pull request #194 from A2va/typing-fix

    Typing fix
    cloneofsimo authored Feb 21, 2023
    Copy the full SHA
    9714747 View commit details
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -8,4 +8,5 @@ wandb
exps*
.vscode
build
lora_diffusion.egg-info
lora_diffusion.egg-info
training_batch_preview
41 changes: 32 additions & 9 deletions lora_diffusion/cli_lora_add.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
from typing import Literal, Union, Dict
import sys
if sys.version_info >= (3,8):
from typing import Literal
else :
from typing_extensions import Literal
from typing import Union, Dict
import os
import shutil
import fire
from diffusers import StableDiffusionPipeline
from safetensors.torch import safe_open, save_file

import torch
from .lora import (
tune_lora_scale,
patch_pipe,
collapse_lora,
monkeypatch_remove_lora,
)
from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt

try:
from .lora import (
tune_lora_scale,
patch_pipe,
collapse_lora,
monkeypatch_remove_lora,
)

from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt

except: # allows running the repo without installing it (can mess up existing dependencies)
from lora_diffusion import (
tune_lora_scale,
patch_pipe,
collapse_lora,
monkeypatch_remove_lora,
)

from lora_diffusion.lora_manager import lora_join
from lora_diffusion.to_ckpt_v2 import convert_to_ckpt


def _text_lora_path(path: str) -> str:
@@ -185,3 +204,7 @@ def add(

def main():
fire.Fire(add)


if __name__ == "__main__":
main()
172 changes: 155 additions & 17 deletions lora_diffusion/cli_lora_pti.py

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions lora_diffusion/cli_pt_to_safetensors.py
Original file line number Diff line number Diff line change
@@ -62,9 +62,11 @@ def convert(*paths, outpath, overwrite=False, **settings):
}

prefix = f"{name}."

arg_settings = { k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix) }
model_settings = { **model_settings, **arg_settings }

arg_settings = {
k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix)
}
model_settings = {**model_settings, **arg_settings}

print(f"Loading Lora for {name} from {path} with settings {model_settings}")

126 changes: 114 additions & 12 deletions lora_diffusion/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from PIL import Image
from torch import zeros_like
from torch.utils.data import Dataset
@@ -39,7 +39,38 @@
"a photo of a small {}",
]

STYLE_TEMPLATE = [
PERSON_TEMPLATE = [
"{}",
"{}",
"a picture of {}",
"a closeup of {}",
"a closeup photo of {}",
"a close-up picture of {}",
"a photo of {}",
"a photo of {}",
"the photo of {}",
"a cropped photo of {}",
"a funny photo of {}",
"a selfie of {}",
"a photo of the handsome {}",
"a photo of the beautiful {}",
"a selfie taken by the handsome {}",
"a selfie taken by {}",
"{} taking a selfie",
"{} is having fun, 4k photograph",
"{} wearing a plaidered shirt standing next to another person",
"smiling {} in a hoodie and sweater",
"a photo of the cool {}",
"a close-up photo of {}",
"a bright photo of {}",
"a cropped photo of {}",
"a brilliant HD photo of {}",
"a beautiful picture of {}",
"a photo showing {}",
"a great photo of {}",
]

STYLE_TEMPLATE_ORIG = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
@@ -61,10 +92,28 @@
"a large painting in the style of {}",
]

STYLE_TEMPLATE = [
"a painting in the style of {}",
"a rendering in the style of {}",
"an artwork in the style of {}",
"a magnificent painting in the style of {}",
"a picture in the style of {}",
"a photograph, {} style",
"{} style painting",
"a {}-styled artwork",
"a nice painting in the style of {}",
"a goregous example of {} style",
"image in the style of {}",
"{}, painting",
"{} artwork",
]


NULL_TEMPLATE = ["{}"]

TEMPLATE_MAP = {
"object": OBJECT_TEMPLATE,
"person": PERSON_TEMPLATE,
"style": STYLE_TEMPLATE,
"null": NULL_TEMPLATE,
}
@@ -116,6 +165,35 @@ def _generate_random_mask(image):
return mask, masked_image


def expand_rectangle(mask, f):
rows, cols = np.where(mask == 255)
top_row, bottom_row = np.min(rows), np.max(rows)
left_col, right_col = np.min(cols), np.max(cols)

rect_height, rect_width = bottom_row - top_row + 1, right_col - left_col + 1
new_height, new_width = np.round(rect_height * f), np.round(rect_width * f)

center_row, center_col = top_row + rect_height // 2, left_col + rect_width // 2
top_row, bottom_row = np.round(center_row - new_height / 2), np.round(
center_row + new_height / 2
)
left_col, right_col = np.round(center_col - new_width / 2), np.round(
center_col + new_width / 2
)

top_row, bottom_row = int(np.clip(top_row, 0, mask.shape[0] - 1)), int(
np.clip(bottom_row, 0, mask.shape[0] - 1)
)
left_col, right_col = int(np.clip(left_col, 0, mask.shape[1] - 1)), int(
np.clip(right_col, 0, mask.shape[1] - 1)
)

expanded_mask = np.ones_like(mask)
expanded_mask[top_row : bottom_row + 1, left_col : right_col + 1] = 255

return expanded_mask


class PivotalTuningDatasetCapation(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
@@ -141,6 +219,8 @@ def __init__(
self.tokenizer = tokenizer
self.resize = resize
self.train_inpainting = train_inpainting
self.h_flip_prob = 0.5
self.final_flip_prob = 0.33 if use_template == "person" else 0.5

instance_data_root = Path(instance_data_root)
if not instance_data_root.exists():
@@ -156,6 +236,10 @@ def __init__(
# Prepare the instance images
if use_mask_captioned_data:
src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg")
src_imgs = sorted(
src_imgs, key=lambda x: int(str(Path(x).stem).split(".")[0])
)

for f in src_imgs:
idx = int(str(Path(f).stem).split(".")[0])
mask_path = f"{instance_data_root}/{idx}.mask.png"
@@ -218,6 +302,18 @@ def __init__(
]
)
for idx, mask in enumerate(masks):
avg_pixel_value = np.array(mask.getdata()).mean()
if avg_pixel_value == 1.0:
print(f"No mask detected for {idx}..")
else:
if 1:
# convert to numpy array:
mask = np.array(mask)
# Make the rectangular mask region bigger:
mask = expand_rectangle(mask, 1.25)
# convert back to PIL image:
mask = Image.fromarray(mask).convert("L")

mask.save(f"{instance_data_root}/{idx}.mask.png")

break
@@ -237,12 +333,13 @@ def __init__(
self.h_flip = h_flip
self.image_transforms = transforms.Compose(
[
transforms.RandomAffine(degrees=0, translate=(0, 0), scale=(1.0, 1.2)),
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
)
if resize
else transforms.Lambda(lambda x: x),
transforms.ColorJitter(0.1, 0.1)
transforms.ColorJitter(0.1, 0.1, 0.02, 0.02)
if color_jitter
else transforms.Lambda(lambda x: x),
transforms.CenterCrop(size),
@@ -253,6 +350,15 @@ def __init__(

self.blur_amount = blur_amount

print("Captions:")
print(self.captions)

def tune_h_flip_prob(self, training_progress):
if self.h_flip:
# Tune the h_flip probability to be 0.5 training_progress is 0 and end_prob when training_progress is 1
self.h_flip_prob = 0.5 + (self.final_flip_prob - 0.5) * training_progress
print(f"h_flip_prob: {self.h_flip_prob:.3f}")

def __len__(self):
return self._length

@@ -283,18 +389,14 @@ def __getitem__(self, index):
for token, value in self.token_map.items():
text = text.replace(token, value)

print(text)
if random.random() < 0.1:
print(text)

if self.use_mask:
example["mask"] = (
self.image_transforms(
Image.open(self.mask_path[index % self.num_instance_images])
)
* 0.5
+ 1.0
)
img_mask = Image.open(self.mask_path[index % self.num_instance_images])
example["mask"] = self.image_transforms(img_mask) * 0.5 + 1.0

if self.h_flip and random.random() > 0.5:
if self.h_flip and random.random() < self.h_flip_prob:
hflip = transforms.RandomHorizontalFlip(p=1)

example["instance_images"] = hflip(example["instance_images"])
9 changes: 7 additions & 2 deletions lora_diffusion/lora.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import json
import math
from itertools import groupby
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
import sys
if sys.version_info >= (3,9):
from typing import Type
else :
from typing_extensions import Type
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import PIL
@@ -914,7 +919,7 @@ def apply_learned_embed_in_clip(
trained_tokens = list(learned_embeds.keys())

for token in trained_tokens:
print(token)
print("Adding new token: ", token)
embeds = learned_embeds[token]

# cast to dtype of text_encoder
25 changes: 19 additions & 6 deletions lora_diffusion/preprocess_files.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,12 @@
# Have BLIP auto caption
# Have CLIPSeg auto mask concept

from typing import List, Literal, Union, Optional, Tuple
import sys
if sys.version_info >= (3,8):
from typing import Literal
else :
from typing_extensions import Literal
from typing import List, Union, Optional, Tuple
import os
from PIL import Image, ImageFilter
import torch
@@ -244,7 +249,7 @@ def _center_of_mass(mask: Image.Image):
def load_and_save_masks_and_captions(
files: Union[str, List[str]],
output_dir: str,
caption_text: Optional[str] = None,
captions_text: Optional[Union[List[str], str]] = None,
target_prompts: Optional[Union[List[str], str]] = None,
target_size: int = 512,
crop_based_on_salience: bool = True,
@@ -263,8 +268,10 @@ def load_and_save_masks_and_captions(
# check if it is a directory
if os.path.isdir(files):
# get all the .png .jpg in the directory
files = glob.glob(os.path.join(files, "*.png")) + glob.glob(
os.path.join(files, "*.jpg")
files = (
glob.glob(os.path.join(files, "*.png"))
+ glob.glob(os.path.join(files, "*.jpg"))
+ glob.glob(os.path.join(files, "*.jpeg"))
)

if len(files) == 0:
@@ -278,8 +285,10 @@ def load_and_save_masks_and_captions(
images = [Image.open(file) for file in files]

# captions
print(f"Generating {len(images)} captions...")
captions = blip_captioning_dataset(images, text=caption_text)
captions = caption_text
if not isinstance(caption_text, List):
print(f"Generating {len(images)} captions...")
captions = blip_captioning_dataset(images, text=caption_text)

if target_prompts is None:
target_prompts = captions
@@ -325,3 +334,7 @@ def load_and_save_masks_and_captions(

def main():
fire.Fire(load_and_save_masks_and_captions)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -8,3 +8,4 @@ safetensors
opencv-python
torchvision
mediapipe
typing_extensions; python_version < '3.9'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
setup(
name="lora_diffusion",
py_modules=["lora_diffusion"],
version="0.1.7",
version="0.1.8",
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
author="Simo Ryu",
packages=find_packages(),
Binary file added textual_inversion.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.