import sys
import torch
import numpy as np
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from einops import repeat, rearrange
from pytorch_lightning import seed_everything
from imwatermark import WatermarkEncoder

from scripts.txt2img import put_watermark
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.data.util import AddMiDaS

torch.set_grad_enabled(False)


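# Load the model config and checkpoint, move the model to the GPU when one
# is available, and wrap it in a DDIM sampler.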
def initialize_model(config, ckpt):
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)
    return sampler


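# Assemble the conditioning batch: the RGB image as an (h, w, c) tensor in
# [-1, 1] plus the MiDaS-preprocessed copy used for depth estimation, each
# repeated num_samples times along the batch dimension.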
def make_batch_sd(
        image,
        txt,
        device,
        num_samples=1,
        model_type="dpt_hybrid"
):
    image = np.array(image.convert("RGB"))
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
    # the image is now an (h, w, c) tensor in [-1, 1]
    midas_trafo = AddMiDaS(model_type=model_type)
    batch = {
        "jpg": image,
        "txt": num_samples * [txt],
    }
    batch = midas_trafo(batch)
    batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
    batch["jpg"] = repeat(batch["jpg"].to(device=device),
                          "1 ... -> n ...", n=num_samples)
    batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(
        device=device), "1 ... -> n ...", n=num_samples)
    return batch


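# Depth-conditioned img2img: encode the input image into latent space,
# condition on the prompt and a MiDaS depth map, partially noise the latent
# up to step t_enc (or start from pure noise when do_full_sample), then
# DDIM-decode with classifier-free guidance.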
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
          do_full_sample=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = sampler.model
    seed_everything(seed)

    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
    wm = "SDV2"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

    with torch.no_grad(), torch.autocast("cuda"):
        batch = make_batch_sd(
            image, txt=prompt, device=device, num_samples=num_samples)
        z = model.get_first_stage_encoding(model.encode_first_stage(
            batch[model.first_stage_key]))  # move to latent space
        c = model.cond_stage_model.encode(batch["txt"])
        c_cat = list()
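        # Build the concat conditioning: predict a depth map with MiDaS,
        # keep a normalized copy for display, resize it to the latent
        # resolution, and rescale it to [-1, 1].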
        for ck in model.concat_keys:
            cc = batch[ck]
            cc = model.depth_model(cc)
            depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), \
                torch.amax(cc, dim=[1, 2, 3], keepdim=True)
            display_depth = (cc - depth_min) / (depth_max - depth_min)
            depth_image = Image.fromarray(
                (display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8))
            cc = torch.nn.functional.interpolate(
                cc,
                size=z.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
            depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), \
                torch.amax(cc, dim=[1, 2, 3], keepdim=True)
            cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        # conditioning for classifier-free guidance
        cond = {"c_concat": [c_cat], "c_crossattn": [c]}

        # unconditional counterpart (empty prompt, same depth map)
        uc_cross = model.get_unconditional_conditioning(num_samples, "")
        uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
        if not do_full_sample:
            # noise the latent up to step t_enc (img2img)
            z_enc = sampler.stochastic_encode(
                z, torch.tensor([t_enc] * num_samples).to(model.device))
        else:
            # strength == 1.0: ignore the input latent and start from noise
            z_enc = torch.randn_like(z)
        # DDIM-decode the noised latent under guidance
        samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
                                 unconditional_conditioning=uc_full, callback=callback)
        x_samples_ddim = model.decode_first_stage(samples)
        result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
    return [depth_image] + [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]


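# Pad the right and bottom edges (replicating border pixels) so that both
# sides are multiples of 64 and at least 128 px, as the latent downsampling
# requires.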
def pad_image(input_image):
    pad_w, pad_h = np.max(((2, 2), np.ceil(
        np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
    im_padded = Image.fromarray(
        np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
    return im_padded


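# Gradio callback: pad the input image, set up the DDIM schedule, and map
# strength in [0, 1] to the number of noising steps t_enc before sampling.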
def predict(input_image, prompt, steps, num_samples, scale, seed, eta, strength):
    init_image = input_image.convert("RGB")
    image = pad_image(init_image)  # pad to an integer multiple of 64

    sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
    assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
    do_full_sample = strength == 1.
    t_enc = min(int(strength * steps), steps - 1)
    result = paint(
        sampler=sampler,
        image=image,
        prompt=prompt,
        t_enc=t_enc,
        seed=seed,
        scale=scale,
        num_samples=num_samples,
        callback=None,
        do_full_sample=do_full_sample
    )
    return result


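# Usage (assuming this script is saved as depth2img.py):
#   python depth2img.py <path/to/config.yaml> <path/to/checkpoint.ckpt>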
sampler = initialize_model(sys.argv[1], sys.argv[2])

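# Gradio UI: an image/prompt input column with sampler settings, and an
# output gallery (the first gallery image is the predicted depth map).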
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Stable Diffusion Depth2Img")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="pil")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(
                    label="Images", minimum=1, maximum=4, value=1, step=1)
                ddim_steps = gr.Slider(label="Steps", minimum=1,
                                       maximum=50, value=50, step=1)
                scale = gr.Slider(
                    label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1
                )
                strength = gr.Slider(
                    label="Strength", minimum=0.0, maximum=1.0, value=0.9, step=0.01
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                eta = gr.Number(label="eta (DDIM)", value=0.0)
        with gr.Column():
            gallery = gr.Gallery(label="Generated images", show_label=False).style(
                grid=[2], height="auto")

    run_button.click(fn=predict, inputs=[
        input_image, prompt, ddim_steps, num_samples, scale, seed, eta, strength], outputs=[gallery])


block.launch()