Skip to content

Commit 615edde

Browse files
psychedelicioushipsterusername
authored andcommitted
feat(nodes): add PiDiNetEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager. All code related to the invocation now lives in the Invoke repo.
1 parent b3d60bd commit 615edde

File tree

5 files changed

+824
-2
lines changed

5 files changed

+824
-2
lines changed

invokeai/app/invocations/pidi.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
2+
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
3+
from invokeai.app.invocations.primitives import ImageOutput
4+
from invokeai.app.services.shared.invocation_context import InvocationContext
5+
from invokeai.backend.image_util.pidi import PIDINetDetector
6+
from invokeai.backend.image_util.pidi.model import PiDiNet
7+
8+
9+
@invocation(
10+
"pidi_edge_detection",
11+
title="PiDiNet Edge Detection",
12+
tags=["controlnet", "edge"],
13+
category="controlnet",
14+
version="1.0.0",
15+
)
16+
class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
17+
"""Generates an edge map using PiDiNet."""
18+
19+
image: ImageField = InputField(description="The image to process")
20+
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
21+
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
22+
23+
def invoke(self, context: InvocationContext) -> ImageOutput:
24+
image = context.images.get_pil(self.image.image_name, "RGB")
25+
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(), PIDINetDetector.load_model)
26+
27+
with loaded_model as model:
28+
assert isinstance(model, PiDiNet)
29+
detector = PIDINetDetector(model)
30+
edge_map = detector.run(image=image, safe=self.safe, scribble=self.scribble)
31+
32+
image_dto = context.images.save(image=edge_map)
33+
return ImageOutput.build(image_dto)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Adapted from https://github.com/huggingface/controlnet_aux
2+
3+
import pathlib
4+
5+
import cv2
6+
import huggingface_hub
7+
import numpy as np
8+
import torch
9+
from einops import rearrange
10+
from PIL import Image
11+
12+
from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
13+
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
14+
15+
16+
class PIDINetDetector:
17+
"""Simple wrapper around a PiDiNet model for edge detection."""
18+
19+
hf_repo_id = "lllyasviel/Annotators"
20+
hf_filename = "table5_pidinet.pth"
21+
22+
@classmethod
23+
def get_model_url(cls) -> str:
24+
"""Get the URL to download the model from the Hugging Face Hub."""
25+
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
26+
27+
@classmethod
28+
def load_model(cls, model_path: pathlib.Path) -> PiDiNet:
29+
"""Load the model from a file."""
30+
31+
model = pidinet()
32+
model.load_state_dict({k.replace("module.", ""): v for k, v in torch.load(model_path)["state_dict"].items()})
33+
model.eval()
34+
return model
35+
36+
def __init__(self, model: PiDiNet) -> None:
37+
self.model = model
38+
39+
def to(self, device: torch.device):
40+
self.model.to(device)
41+
return self
42+
43+
def run(
44+
self, image: Image.Image, safe: bool = False, scribble: bool = False, apply_filter: bool = False
45+
) -> Image.Image:
46+
"""Processes an image and returns the detected edges."""
47+
48+
device = next(iter(self.model.parameters())).device
49+
50+
np_img = pil_to_np(image)
51+
np_img = normalize_image_channel_count(np_img)
52+
53+
assert np_img.ndim == 3
54+
55+
bgr_img = np_img[:, :, ::-1].copy()
56+
57+
with torch.no_grad():
58+
image_pidi = torch.from_numpy(bgr_img).float().to(device)
59+
image_pidi = image_pidi / 255.0
60+
image_pidi = rearrange(image_pidi, "h w c -> 1 c h w")
61+
edge = self.model(image_pidi)[-1]
62+
edge = edge.cpu().numpy()
63+
if apply_filter:
64+
edge = edge > 0.5
65+
if safe:
66+
edge = safe_step(edge)
67+
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
68+
69+
detected_map = edge[0, 0]
70+
71+
if scribble:
72+
detected_map = nms(detected_map, 127, 3.0)
73+
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
74+
detected_map[detected_map > 4] = 255
75+
detected_map[detected_map < 255] = 0
76+
77+
output_img = np_to_pil(detected_map)
78+
79+
return output_img

0 commit comments

Comments
 (0)