Skip to content

Commit a829ca5

Browse files
forge and gradio 4 compatibility
1 parent e4df29b commit a829ca5

File tree

4 files changed

+99
-39
lines changed

4 files changed

+99
-39
lines changed

scripts/depthmap_api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def process(
5858

5959
if len(depth_input_images) == 0:
6060
raise HTTPException(status_code=422, detail="No images supplied")
61-
print(f"Processing {str(len(depth_input_images))} images trough the API")
61+
print(f"Processing {str(len(depth_input_images))} images through the API")
6262

6363
pil_images = []
6464
for input_image in depth_input_images:
@@ -81,7 +81,7 @@ async def process_video(
8181
):
8282
if len(depth_input_images) == 0:
8383
raise HTTPException(status_code=422, detail="No images supplied")
84-
print(f"Processing {str(len(depth_input_images))} images trough the API")
84+
print(f"Processing {str(len(depth_input_images))} images through the API")
8585

8686
# You can use either these strings, or integers
8787
available_models = {

src/backbone.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,39 @@ def get_outpath():
6060

6161
def unload_sd_model():
6262
from modules import shared, devices
63-
if shared.sd_model is not None:
64-
if shared.sd_model.cond_stage_model is not None:
65-
shared.sd_model.cond_stage_model.to(devices.cpu)
66-
if shared.sd_model.first_stage_model is not None:
67-
shared.sd_model.first_stage_model.to(devices.cpu)
63+
try:
64+
if shared.sd_model is not None:
65+
if shared.sd_model.cond_stage_model is not None:
66+
shared.sd_model.cond_stage_model.to(devices.cpu)
67+
if shared.sd_model.first_stage_model is not None:
68+
shared.sd_model.first_stage_model.to(devices.cpu)
69+
except Exception as e:
70+
from backend import memory_management
71+
print('trying to catch forge (might be a attribute error)')
72+
if type(e)== AttributeError:
73+
memory_management.unload_all_models()
74+
memory_management.soft_empty_cache()
75+
else:
76+
raise
6877
# Maybe something else???
6978

7079

7180
def reload_sd_model():
7281
from modules import shared, devices
73-
if shared.sd_model is not None:
74-
if shared.sd_model.cond_stage_model is not None:
75-
shared.sd_model.cond_stage_model.to(devices.device)
76-
if shared.sd_model.first_stage_model:
77-
shared.sd_model.first_stage_model.to(devices.device)
82+
try:
83+
if shared.sd_model is not None:
84+
if shared.sd_model.cond_stage_model is not None:
85+
shared.sd_model.cond_stage_model.to(devices.device)
86+
if shared.sd_model.first_stage_model:
87+
shared.sd_model.first_stage_model.to(devices.device)
88+
except Exception as e:
89+
from backend import memory_management
90+
print('trying to catch forge (might be a attribute error)')
91+
if type(e)== AttributeError:
92+
memory_management.unload_all_models()
93+
memory_management.soft_empty_cache()
94+
else:
95+
raise
7896
# Maybe something else???
7997

8098
def get_hide_dirs():

src/common_ui.py

+56-25
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ def ensure_gradio_temp_directory():
2525

2626

2727
def main_ui_panel(is_depth_tab):
28+
29+
is_gradio4 = int(gr.__version__[0])>3
30+
if is_gradio4:
31+
Box = gr.Group
32+
else:
33+
Box = gr.Box
34+
2835
inp = GradioComponentBundle()
2936
# TODO: Greater visual separation
3037
with gr.Blocks():
@@ -41,7 +48,7 @@ def main_ui_panel(is_depth_tab):
4148
'Marigold v1', 'Depth Anything', 'Depth Anything v2 Small',
4249
'Depth Anything v2 Base', 'Depth Anything v2 Large'],
4350
value='Depth Anything v2 Base', type="index")
44-
with gr.Box() as cur_option_root:
51+
with Box() as cur_option_root:
4552
inp -= 'depthmap_gen_row_1', cur_option_root
4653
with gr.Row():
4754
inp += go.BOOST, gr.Checkbox(label="BOOST",
@@ -57,7 +64,7 @@ def main_ui_panel(is_depth_tab):
5764
label='Tiling mode', info='Reduces seams that appear if the depthmap is tiled into a grid'
5865
)
5966

60-
with gr.Box() as cur_option_root:
67+
with Box() as cur_option_root:
6168
inp -= 'depthmap_gen_row_2', cur_option_root
6269
with gr.Row():
6370
with gr.Group(): # 50% of width
@@ -71,7 +78,7 @@ def main_ui_panel(is_depth_tab):
7178
inp += go.OUTPUT_DEPTH_COMBINE_AXIS, gr.Radio(
7279
label="Combine axis", choices=['Vertical', 'Horizontal'], type="value", visible=False)
7380

74-
with gr.Box() as cur_option_root:
81+
with Box() as cur_option_root:
7582
inp -= 'depthmap_gen_row_3', cur_option_root
7683
with gr.Row():
7784
inp += go.CLIPDEPTH, gr.Checkbox(label="Clip and renormalize DepthMap")
@@ -81,7 +88,7 @@ def main_ui_panel(is_depth_tab):
8188
inp += go.CLIPDEPTH_FAR, gr.Slider(minimum=0, maximum=1, step=0.001, label='Far clip')
8289
inp += go.CLIPDEPTH_NEAR, gr.Slider(minimum=0, maximum=1, step=0.001, label='Near clip')
8390

84-
with gr.Box():
91+
with Box():
8592
with gr.Row():
8693
inp += go.GEN_STEREO, gr.Checkbox(label="Generate stereoscopic (3D) image(s)")
8794
with gr.Column(visible=False) as stereo_options:
@@ -104,7 +111,7 @@ def main_ui_panel(is_depth_tab):
104111
inp += go.STEREO_BALANCE, gr.Slider(minimum=-1.0, maximum=1.0, step=0.05,
105112
label='Balance between eyes')
106113

107-
with gr.Box():
114+
with Box():
108115
with gr.Row():
109116
inp += go.GEN_NORMALMAP, gr.Checkbox(label="Generate NormalMap")
110117
with gr.Column(visible=False) as normalmap_options:
@@ -124,11 +131,11 @@ def main_ui_panel(is_depth_tab):
124131
inp += go.NORMALMAP_INVERT, gr.Checkbox(label="Invert")
125132

126133
if backbone.get_opt('depthmap_script_gen_heatmap_from_ui', False):
127-
with gr.Box():
134+
with Box():
128135
with gr.Row():
129136
inp += go.GEN_HEATMAP, gr.Checkbox(label="Generate HeatMap")
130137

131-
with gr.Box():
138+
with Box():
132139
with gr.Column():
133140
inp += go.GEN_SIMPLE_MESH, gr.Checkbox(label="Generate simple 3D mesh")
134141
with gr.Column(visible=False) as mesh_options:
@@ -139,7 +146,7 @@ def main_ui_panel(is_depth_tab):
139146
inp += go.SIMPLE_MESH_SPHERICAL, gr.Checkbox(label="Equirectangular projection")
140147

141148
if is_depth_tab:
142-
with gr.Box():
149+
with Box():
143150
with gr.Column():
144151
inp += go.GEN_INPAINTED_MESH, gr.Checkbox(
145152
label="Generate 3D inpainted mesh")
@@ -149,7 +156,7 @@ def main_ui_panel(is_depth_tab):
149156
label="Generate 4 demo videos with 3D inpainted mesh.")
150157
gr.HTML("More options for generating video can be found in the Generate video tab.")
151158

152-
with gr.Box():
159+
with Box():
153160
# TODO: it should be clear from the UI that there is an option of the background removal
154161
# that does not use the model selected above
155162
with gr.Row():
@@ -163,33 +170,49 @@ def main_ui_panel(is_depth_tab):
163170
label="Rembg Model", type="value",
164171
choices=['u2net', 'u2netp', 'u2net_human_seg', 'silueta', "isnet-general-use", "isnet-anime"])
165172

166-
with gr.Box():
173+
with Box():
167174
gr.HTML(f"{SCRIPT_FULL_NAME}<br/>")
168175
gr.HTML("Information, comment and share @ <a "
169176
"href='https://github.com/thygate/stable-diffusion-webui-depthmap-script'>"
170177
"https://github.com/thygate/stable-diffusion-webui-depthmap-script</a>")
171178

172179
def update_default_net_size(model_type):
173180
w, h = ModelHolder.get_default_net_size(model_type)
174-
return inp[go.NET_WIDTH].update(value=w), inp[go.NET_HEIGHT].update(value=h)
175-
181+
if is_gradio4:
182+
return gr.Slider(step=w), gr.Slider(step=h)
183+
else:
184+
return inp[go.NET_WIDTH].update(value=w), inp[go.NET_HEIGHT].update(value=h)
185+
176186
inp[go.MODEL_TYPE].change(
177187
fn=update_default_net_size,
178188
inputs=inp[go.MODEL_TYPE],
179189
outputs=[inp[go.NET_WIDTH], inp[go.NET_HEIGHT]]
180190
)
191+
def update_boost(a, b):
192+
if is_gradio4:
193+
return (gr.Checkbox(visible= not a), gr.Row(visible = not a and not b ))
194+
else:
195+
return (inp[go.NET_SIZE_MATCH].update(visible=not a),
196+
options_depend_on_match_size.update(visible=not a and not b))
181197

182198
inp[go.BOOST].change( # Go boost! Wroom!..
183-
fn=lambda a, b: (inp[go.NET_SIZE_MATCH].update(visible=not a),
184-
options_depend_on_match_size.update(visible=not a and not b)),
199+
fn=update_boost,
185200
inputs=[inp[go.BOOST], inp[go.NET_SIZE_MATCH]],
186201
outputs=[inp[go.NET_SIZE_MATCH], options_depend_on_match_size]
187202
)
188203
inp.add_rule(options_depend_on_match_size, 'visible-if-not', go.NET_SIZE_MATCH)
204+
205+
def update_tiling(a):
206+
if is_gradio4:
207+
if a:
208+
return (gr.Checkbox(value=False), gr.Checkbox(value= True))
209+
return (inp[go.BOOST], inp[go.NET_SIZE_MATCH])
210+
else:
211+
(inp[go.BOOST].update(value=False), inp[go.NET_SIZE_MATCH].update(value=True)
212+
) if a else (inp[go.BOOST].update(), inp[go.NET_SIZE_MATCH].update())
213+
189214
inp[go.TILING_MODE].change( # Go boost! Wroom!..
190-
fn=lambda a: (
191-
inp[go.BOOST].update(value=False), inp[go.NET_SIZE_MATCH].update(value=True)
192-
) if a else (inp[go.BOOST].update(), inp[go.NET_SIZE_MATCH].update()),
215+
fn= update_tiling,
193216
inputs=[inp[go.TILING_MODE]],
194217
outputs=[inp[go.BOOST], inp[go.NET_SIZE_MATCH]]
195218
)
@@ -248,13 +271,13 @@ def depthmap_mode_video(inp):
248271
"pick settings so that the generation is not too slow. For the best results, "
249272
"use a zoedepth model, since they provide the highest level of coherency between frames.")
250273
inp += gr.File(elem_id='depthmap_vm_input', label="Video or animated file",
251-
file_count="single", interactive=True, type="file")
274+
file_count="single", interactive=True, type="binary")
252275
inp += gr.Checkbox(elem_id="depthmap_vm_custom_checkbox",
253276
label="Use custom/pregenerated DepthMap video", value=False)
254277
inp += gr.Dropdown(elem_id="depthmap_vm_smoothening_mode", label="Smoothening",
255278
type="value", choices=['none', 'experimental'], value='experimental')
256279
inp += gr.File(elem_id='depthmap_vm_custom', file_count="single",
257-
interactive=True, type="file", visible=False)
280+
interactive=True, type="binary", visible=False)
258281
with gr.Row():
259282
inp += gr.Checkbox(elem_id='depthmap_vm_compress_checkbox', label="Compress colorvideos?", value=False)
260283
inp += gr.Slider(elem_id='depthmap_vm_compress_bitrate', label="Bitrate (kbit)", visible=False,
@@ -287,11 +310,11 @@ def on_ui_tabs():
287310
elem_id="depthmap_input_image")
288311
# TODO: depthmap generation settings should disappear when using this
289312
inp += gr.File(label="Custom DepthMap", file_count="single", interactive=True,
290-
type="file", elem_id='custom_depthmap_img', visible=False)
313+
type="binary", elem_id='custom_depthmap_img', visible=False)
291314
inp += gr.Checkbox(elem_id="custom_depthmap", label="Use custom DepthMap", value=False)
292315
with gr.TabItem('Batch Process') as depthmap_mode_1:
293316
inp += gr.File(elem_id='image_batch', label="Batch Process", file_count="multiple",
294-
interactive=True, type="file")
317+
interactive=True, type="binary")
295318
with gr.TabItem('Batch from Directory') as depthmap_mode_2:
296319
inp += gr.Textbox(elem_id="depthmap_batch_input_dir", label="Input directory",
297320
**backbone.get_hide_dirs(),
@@ -370,12 +393,20 @@ def on_ui_tabs():
370393
depthmap_mode_2.select(lambda: '2', None, inp['depthmap_mode'])
371394
depthmap_mode_3.select(lambda: '3', None, inp['depthmap_mode'])
372395

396+
is_gradio4 = int(gr.__version__[0])>3
373397
def custom_depthmap_change_fn(mode, zero_on, three_on):
374398
hide = mode == '0' and zero_on or mode == '3' and three_on
375-
return inp['custom_depthmap_img'].update(visible=hide), \
376-
inp['depthmap_gen_row_0'].update(visible=not hide), \
377-
inp['depthmap_gen_row_1'].update(visible=not hide), \
378-
inp['depthmap_gen_row_3'].update(visible=not hide), not hide
399+
if is_gradio4:
400+
return gr.Row(visible=hide), \
401+
gr.Group(visible = not hide), \
402+
gr.Group(visible = not hide), \
403+
gr.Group(visible = not hide), not hide
404+
else:
405+
return inp['custom_depthmap_img'].update(visible=hide), \
406+
inp['depthmap_gen_row_0'].update(visible=not hide), \
407+
inp['depthmap_gen_row_1'].update(visible=not hide), \
408+
inp['depthmap_gen_row_3'].update(visible=not hide), not hide
409+
379410
custom_depthmap_change_els = ['depthmap_mode', 'custom_depthmap', 'depthmap_vm_custom_checkbox']
380411
for el in custom_depthmap_change_els:
381412
inp[el].change(

src/gradio_args_transport.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,23 @@ def enkey_body(self):
6969
return [self.internal[x] for x in sorted(list(self.internal.keys()))]
7070

7171
def add_rule(self, first, rule, second):
72+
is_gradio4 = int(gr.__version__[0])>3
7273
first = self[first] if first in self else first
7374
second = self[second] if second in self else second
7475
if rule == 'visible-if-not':
75-
second.change(fn=lambda v: first.update(visible=not v), inputs=[second], outputs=[first])
76+
def update_visible_if_not(v):
77+
if is_gradio4:
78+
return gr.Column(visible=not v)
79+
else:
80+
return first.update(visible=not v)
81+
second.change(update_visible_if_not, [second], [first])
7682
elif rule == 'visible-if':
77-
second.change(fn=lambda v: first.update(visible=v), inputs=[second], outputs=[first])
83+
def update_visible_if(v):
84+
if is_gradio4:
85+
return gr.Column(visible=v)
86+
else:
87+
return first.update(visible=v)
88+
second.change(update_visible_if, [second], [first])
7889
else:
7990
raise Exception(f'Unknown rule type {rule}')
8091

0 commit comments

Comments
 (0)