Skip to content

Commit 05aea71

Browse files
author
multimodalart
committed
Add depth2img Gradio demo
1 parent cccfb98 commit 05aea71

File tree

2 files changed

+192
-2
lines changed

2 files changed

+192
-2
lines changed

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,18 @@ To augment the well-established [img2img](https://github.com/CompVis/stable-diff
136136
Note that the original method for image modification introduces significant semantic changes w.r.t. the initial image.
137137
If that is not desired, download our [depth-conditional stable diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-depth) model and the `dpt_hybrid` MiDaS [model weights](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt), place the latter in a folder `midas_models` and sample via
138138
```
139-
python scripts/streamlit/depth2img.py streamlit run scripts/demo/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml <path-to-ckpt>
139+
python scripts/gradio/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml <path-to-ckpt>
140+
```
141+
142+
or
143+
144+
```
145+
streamlit run scripts/streamlit/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml <path-to-ckpt>
140146
```
141147

142148
This method can be used on the samples of the base model itself.
143149
For example, take [this sample](assets/stable-samples/depth2img/old_man.png) generated by an anonymous discord user.
144-
Using the [streamlit](https://streamlit.io/) script `depth2img.py`, the MiDaS model first infers a monocular depth estimate given this input,
150+
Using the [gradio](https://gradio.app) or [streamlit](https://streamlit.io/) script `depth2img.py`, the MiDaS model first infers a monocular depth estimate given this input,
145151
and the diffusion model is then conditioned on the (relative) depth output.
146152

147153
<p align="center">

scripts/gradio/depth2img.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import sys
2+
import torch
3+
import numpy as np
4+
import gradio as gr
5+
from PIL import Image
6+
from omegaconf import OmegaConf
7+
from einops import repeat, rearrange
8+
from pytorch_lightning import seed_everything
9+
from imwatermark import WatermarkEncoder
10+
11+
from scripts.txt2img import put_watermark
12+
from ldm.util import instantiate_from_config
13+
from ldm.models.diffusion.ddim import DDIMSampler
14+
from ldm.data.util import AddMiDaS
15+
16+
torch.set_grad_enabled(False)
17+
18+
19+
def initialize_model(config, ckpt):
20+
config = OmegaConf.load(config)
21+
model = instantiate_from_config(config.model)
22+
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
23+
24+
device = torch.device(
25+
"cuda") if torch.cuda.is_available() else torch.device("cpu")
26+
model = model.to(device)
27+
sampler = DDIMSampler(model)
28+
return sampler
29+
30+
31+
def make_batch_sd(
32+
image,
33+
txt,
34+
device,
35+
num_samples=1,
36+
model_type="dpt_hybrid"
37+
):
38+
image = np.array(image.convert("RGB"))
39+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
40+
# sample['jpg'] is tensor hwc in [-1, 1] at this point
41+
midas_trafo = AddMiDaS(model_type=model_type)
42+
batch = {
43+
"jpg": image,
44+
"txt": num_samples * [txt],
45+
}
46+
batch = midas_trafo(batch)
47+
batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
48+
batch["jpg"] = repeat(batch["jpg"].to(device=device),
49+
"1 ... -> n ...", n=num_samples)
50+
batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(
51+
device=device), "1 ... -> n ...", n=num_samples)
52+
return batch
53+
54+
55+
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
56+
do_full_sample=False):
57+
device = torch.device(
58+
"cuda") if torch.cuda.is_available() else torch.device("cpu")
59+
model = sampler.model
60+
seed_everything(seed)
61+
62+
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
63+
wm = "SDV2"
64+
wm_encoder = WatermarkEncoder()
65+
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
66+
67+
with torch.no_grad(),\
68+
torch.autocast("cuda"):
69+
batch = make_batch_sd(
70+
image, txt=prompt, device=device, num_samples=num_samples)
71+
z = model.get_first_stage_encoding(model.encode_first_stage(
72+
batch[model.first_stage_key])) # move to latent space
73+
c = model.cond_stage_model.encode(batch["txt"])
74+
c_cat = list()
75+
for ck in model.concat_keys:
76+
cc = batch[ck]
77+
cc = model.depth_model(cc)
78+
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
79+
keepdim=True)
80+
display_depth = (cc - depth_min) / (depth_max - depth_min)
81+
depth_image = Image.fromarray(
82+
(display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8))
83+
cc = torch.nn.functional.interpolate(
84+
cc,
85+
size=z.shape[2:],
86+
mode="bicubic",
87+
align_corners=False,
88+
)
89+
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
90+
keepdim=True)
91+
cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
92+
c_cat.append(cc)
93+
c_cat = torch.cat(c_cat, dim=1)
94+
# cond
95+
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
96+
97+
# uncond cond
98+
uc_cross = model.get_unconditional_conditioning(num_samples, "")
99+
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
100+
if not do_full_sample:
101+
# encode (scaled latent)
102+
z_enc = sampler.stochastic_encode(
103+
z, torch.tensor([t_enc] * num_samples).to(model.device))
104+
else:
105+
z_enc = torch.randn_like(z)
106+
# decode it
107+
samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
108+
unconditional_conditioning=uc_full, callback=callback)
109+
x_samples_ddim = model.decode_first_stage(samples)
110+
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
111+
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
112+
return [depth_image] + [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
113+
114+
115+
def pad_image(input_image):
116+
pad_w, pad_h = np.max(((2, 2), np.ceil(
117+
np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
118+
im_padded = Image.fromarray(
119+
np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
120+
return im_padded
121+
122+
123+
def predict(input_image, prompt, steps, num_samples, scale, seed, eta, strength):
124+
init_image = input_image.convert("RGB")
125+
image = pad_image(init_image) # resize to integer multiple of 32
126+
127+
sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
128+
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
129+
do_full_sample = strength == 1.
130+
t_enc = min(int(strength * steps), steps-1)
131+
result = paint(
132+
sampler=sampler,
133+
image=image,
134+
prompt=prompt,
135+
t_enc=t_enc,
136+
seed=seed,
137+
scale=scale,
138+
num_samples=num_samples,
139+
callback=None,
140+
do_full_sample=do_full_sample
141+
)
142+
return result
143+
144+
145+
sampler = initialize_model(sys.argv[1], sys.argv[2])
146+
147+
block = gr.Blocks().queue()
148+
with block:
149+
with gr.Row():
150+
gr.Markdown("## Stable Diffusion Depth2Img")
151+
152+
with gr.Row():
153+
with gr.Column():
154+
input_image = gr.Image(source='upload', type="pil")
155+
prompt = gr.Textbox(label="Prompt")
156+
run_button = gr.Button(label="Run")
157+
with gr.Accordion("Advanced options", open=False):
158+
num_samples = gr.Slider(
159+
label="Images", minimum=1, maximum=4, value=1, step=1)
160+
ddim_steps = gr.Slider(label="Steps", minimum=1,
161+
maximum=50, value=50, step=1)
162+
scale = gr.Slider(
163+
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1
164+
)
165+
strength = gr.Slider(
166+
label="Strength", minimum=0.0, maximum=1.0, value=0.9, step=0.01
167+
)
168+
seed = gr.Slider(
169+
label="Seed",
170+
minimum=0,
171+
maximum=2147483647,
172+
step=1,
173+
randomize=True,
174+
)
175+
eta = gr.Number(label="eta (DDIM)", value=0.0)
176+
with gr.Column():
177+
gallery = gr.Gallery(label="Generated images", show_label=False).style(
178+
grid=[2], height="auto")
179+
180+
run_button.click(fn=predict, inputs=[
181+
input_image, prompt, ddim_steps, num_samples, scale, seed, eta, strength], outputs=[gallery])
182+
183+
184+
block.launch()

0 commit comments

Comments
 (0)