Skip to content

Commit 4887a9c

Browse files
davidmartinriussemjon00
authored andcommitted
Add video depth generation endpoint
This commit adds a new API endpoint for generating depth maps from input images in a video format. The endpoint supports various depth model options, including different pre-trained models. It also validates and processes video parameters such as number of frames, frames per second, trajectory, shift, border, dolly, format, and super-sampling anti-aliasing. The commit includes error handling for missing input images, invalid model types, and required video parameters. Additionally, it checks if a mesh file already exists, and if not, it generates a new one. The generated mesh is then used to create a depth video based on the specified parameters. See more information in the pull request description.
1 parent b8120b4 commit 4887a9c

File tree

2 files changed

+123
-22
lines changed

2 files changed

+123
-22
lines changed

scripts/depthmap_api.py

+103-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
# Non-public API. Don't host publicly - SECURITY RISKS!
2-
# (will only be on with --api starting option)
3-
# Currently no API stability guarantees are provided - API may break on any new commit.
1+
# DO NOT HOST PUBLICLY - SECURITY RISKS!
2+
# (the API will only be on with --api starting option)
3+
# Currently no API stability guarantees are provided - API may break on any new commit (but hopefully won't).
44

5+
import os
56
import numpy as np
67
from fastapi import FastAPI, Body
78
from fastapi.exceptions import HTTPException
@@ -12,7 +13,7 @@
1213
from typing import Dict, List
1314
from modules.api import api
1415

15-
from src.core import core_generation_funnel
16+
from src.core import core_generation_funnel, run_makevideo
1617
from src.misc import SCRIPT_VERSION
1718
from src import backbone
1819
from src.common_constants import GenerationOptions as go
@@ -70,12 +71,110 @@ async def process(
7071
if not isinstance(result, Image.Image):
7172
continue
7273
results_based += [encode_to_base64(result)]
74+
7375
return {"images": results_based, "info": "Success"}
7476

77+
@app.post("/depth/generate/video")
78+
async def process_video(
79+
depth_input_images: List[str] = Body([], title='Input Images'),
80+
options: Dict[str, object] = Body("options", title='Generation options'),
81+
):
82+
if len(depth_input_images) == 0:
83+
raise HTTPException(status_code=422, detail="No images supplied")
84+
print(f"Processing {str(len(depth_input_images))} images trough the API")
85+
86+
available_models = {
87+
'res101': 0,
88+
'dpt_beit_large_512': 1, #midas 3.1
89+
'dpt_beit_large_384': 2, #midas 3.1
90+
'dpt_large_384': 3, #midas 3.0
91+
'dpt_hybrid_384': 4, #midas 3.0
92+
'midas_v21': 5,
93+
'midas_v21_small': 6,
94+
'zoedepth_n': 7, #indoor
95+
'zoedepth_k': 8, #outdoor
96+
'zoedepth_nk': 9,
97+
}
98+
99+
model_type = options["model_type"]
100+
101+
model_id = None
102+
if isinstance(model_type, str):
103+
# Check if the string is in the available_models dictionary
104+
if model_type in available_models:
105+
model_id = available_models[model_type]
106+
else:
107+
available_strings = list(available_models.keys())
108+
raise HTTPException(status_code=400, detail={'error': 'Invalid model string', 'available_models': available_strings})
109+
elif isinstance(model_type, int):
110+
model_id = model_type
111+
else:
112+
raise HTTPException(status_code=400, detail={'error': 'Invalid model parameter type'})
113+
114+
options["model_type"] = model_id
115+
116+
video_parameters = options["video_parameters"]
117+
118+
required_params = ["vid_numframes", "vid_fps", "vid_traj", "vid_shift", "vid_border", "dolly", "vid_format", "vid_ssaa", "output_filename"]
119+
120+
missing_params = [param for param in required_params if param not in video_parameters]
121+
122+
if missing_params:
123+
raise HTTPException(status_code=400, detail={'error': f"Missing required parameter(s): {', '.join(missing_params)}"})
124+
125+
vid_numframes = video_parameters["vid_numframes"]
126+
vid_fps = video_parameters["vid_fps"]
127+
vid_traj = video_parameters["vid_traj"]
128+
vid_shift = video_parameters["vid_shift"]
129+
vid_border = video_parameters["vid_border"]
130+
dolly = video_parameters["dolly"]
131+
vid_format = video_parameters["vid_format"]
132+
vid_ssaa = int(video_parameters["vid_ssaa"])
133+
134+
output_filename = video_parameters["output_filename"]
135+
output_path = os.path.dirname(output_filename)
136+
basename, extension = os.path.splitext(os.path.basename(output_filename))
137+
138+
# Comparing video_format with the extension
139+
if vid_format != extension[1:]:
140+
raise HTTPException(status_code=400, detail={'error': f"Video format '{vid_format}' does not match with the extension '{extension}'."})
141+
142+
pil_images = []
143+
for input_image in depth_input_images:
144+
pil_images.append(to_base64_PIL(input_image))
145+
outpath = backbone.get_outpath()
146+
147+
mesh_fi_filename = video_parameters.get('mesh_fi_filename', None)
148+
149+
if mesh_fi_filename and os.path.exists(mesh_fi_filename):
150+
mesh_fi = mesh_fi_filename
151+
print("Loaded existing mesh from: ", mesh_fi)
152+
else:
153+
# If there is no mesh file generate it.
154+
options["GEN_INPAINTED_MESH"] = True
155+
156+
gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)
157+
158+
mesh_fi = None
159+
for count, type, result in gen_obj:
160+
if type == 'inpainted_mesh':
161+
mesh_fi = result
162+
break
163+
164+
if mesh_fi:
165+
print("Created mesh in: ", mesh_fi)
166+
else:
167+
raise HTTPException(status_code=400, detail={'error': "The mesh has not been created"})
168+
169+
run_makevideo(mesh_fi, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa, output_path, basename)
170+
171+
return {"info": "Success"}
172+
75173

76174
try:
77175
import modules.script_callbacks as script_callbacks
78176
if backbone.get_cmd_opt('api', False):
79177
script_callbacks.on_app_started(depth_api)
178+
print("Started the depthmap API. DO NOT HOST PUBLICLY - SECURITY RISKS!")
80179
except:
81180
print('DepthMap API could not start')

src/core.py

+20-18
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,8 @@ def run_3dphoto_videos(mesh_fi, basename, outpath, num_frames, fps, crop_border,
578578
fnExt=vid_format)
579579
return fn_saved
580580

581-
582-
# called from gen vid tab button
583-
def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa):
581+
def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa,
582+
outpath=None, basename=None):
584583
if len(fn_mesh) == 0 or not os.path.exists(fn_mesh):
585584
raise Exception("Could not open mesh.")
586585

@@ -608,20 +607,24 @@ def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_bord
608607
raise Exception("Crop Border requires 4 elements.")
609608
crop_border = [float(borders[0]), float(borders[1]), float(borders[2]), float(borders[3])]
610609

611-
# output path and filename mess ..
612-
basename = Path(fn_mesh).stem
613-
outpath = backbone.get_outpath()
614-
# unique filename
615-
basecount = backbone.get_next_sequence_number(outpath, basename)
616-
if basecount > 0: basecount = basecount - 1
617-
fullfn = None
618-
for i in range(500):
619-
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
620-
fullfn = os.path.join(outpath, f"{fn}_." + vid_format)
621-
if not os.path.exists(fullfn):
622-
break
623-
basename = Path(fullfn).stem
624-
basename = basename[:-1]
610+
if not outpath:
611+
outpath = backbone.get_outpath()
612+
613+
if not basename:
614+
# output path and filename mess ..
615+
basename = Path(fn_mesh).stem
616+
617+
# unique filename
618+
basecount = backbone.get_next_sequence_number(outpath, basename)
619+
if basecount > 0: basecount = basecount - 1
620+
fullfn = None
621+
for i in range(500):
622+
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
623+
fullfn = os.path.join(outpath, f"{fn}_." + vid_format)
624+
if not os.path.exists(fullfn):
625+
break
626+
basename = Path(fullfn).stem
627+
basename = basename[:-1]
625628

626629
print("Loading mesh ..")
627630

@@ -630,7 +633,6 @@ def run_makevideo(fn_mesh, vid_numframes, vid_fps, vid_traj, vid_shift, vid_bord
630633

631634
return fn_saved[-1], fn_saved[-1], ''
632635

633-
634636
def unload_models():
635637
model_holder.unload_models()
636638

0 commit comments

Comments
 (0)