Commit b3d60bd

psychedelicious authored and hipsterusername committed
feat(nodes): add NormalMapInvocation
Similar to the existing node, but without any resizing, and with a revised model-loading API that uses the model manager. All code related to the invocation now lives in the Invoke repo. Unfortunately, this includes a whole git repo for EfficientNet. I believe we could use the package `timm` instead of this (a rough sketch follows the baseline network below), but it's beyond me.
1 parent fd42da5 commit b3d60bd

40 files changed: +6234 −0 lines
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.normal_bae import NormalMapDetector
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET


@invocation(
    "normal_map",
    title="Normal Map",
    tags=["controlnet", "normal"],
    category="controlnet",
    version="1.0.0",
)
class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates a normal map."""

    image: ImageField = InputField(description="The image to process")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name, "RGB")
        loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)

        with loaded_model as model:
            assert isinstance(model, NNET)
            detector = NormalMapDetector(model)
            normal_map = detector.run(image=image)

        image_dto = context.images.save(image=normal_map)
        return ImageOutput.build(image_dto)
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Caroline Chan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# Adapted from https://github.com/huggingface/controlnet_aux

import pathlib
import types

import cv2
import huggingface_hub
import numpy as np
import torch
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image

from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple


class NormalMapDetector:
    """Simple wrapper around the Normal BAE model for normal map generation."""

    hf_repo_id = "lllyasviel/Annotators"
    hf_filename = "scannet.pt"

    @classmethod
    def get_model_url(cls) -> str:
        """Get the URL to download the model from the Hugging Face Hub."""
        return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)

    @classmethod
    def load_model(cls, model_path: pathlib.Path) -> NNET:
        """Load the model from a file."""

        args = types.SimpleNamespace()
        args.mode = "client"
        args.architecture = "BN"
        args.pretrained = "scannet"
        args.sampling_ratio = 0.4
        args.importance_ratio = 0.7

        model = NNET(args)

        ckpt = torch.load(model_path, map_location="cpu")["model"]
        # Strip the "module." prefix that DataParallel training leaves on
        # checkpoint keys so they match the bare model's state dict.
        load_dict = {}
        for k, v in ckpt.items():
            if k.startswith("module."):
                k_ = k.replace("module.", "")
                load_dict[k_] = v
            else:
                load_dict[k] = v

        model.load_state_dict(load_dict)
        model.eval()

        return model

    def __init__(self, model: NNET) -> None:
        self.model = model
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def to(self, device: torch.device):
        self.model.to(device)
        return self

    def run(self, image: Image.Image):
        """Processes an image and returns the detected normal map."""

        device = next(iter(self.model.parameters())).device
        np_image = pil_to_np(image)

        height, width, _channels = np_image.shape

        # The model requires the image to be a multiple of 8
        np_image = resize_to_multiple(np_image, 8)

        image_normal = np_image

        with torch.no_grad():
            image_normal = torch.from_numpy(image_normal).float().to(device)
            image_normal = image_normal / 255.0
            image_normal = rearrange(image_normal, "h w c -> 1 c h w")
            image_normal = self.norm(image_normal)

            normal = self.model(image_normal)
            # normal[0] holds the model's intermediate predictions; take the
            # final one and keep the first three of its four channels (the
            # normal vector; the fourth is the concentration parameter kappa).
            normal = normal[0][-1][:, :3]
            # Map normals from [-1, 1] to [0, 1]
            normal = ((normal + 1) * 0.5).clip(0, 1)

            normal = rearrange(normal[0], "c h w -> h w c").cpu().numpy()
            normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)

        # Back to the original size
        output_image = cv2.resize(normal_image, (width, height), interpolation=cv2.INTER_LINEAR)

        return np_to_pil(output_image)
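
As a sanity check, here is a minimal standalone sketch of the detector outside the invocation flow. The input and output file names are hypothetical, and the checkpoint is fetched directly with `huggingface_hub` rather than through the model manager, which normally handles the download and caching:

import pathlib

import huggingface_hub
from PIL import Image

from invokeai.backend.image_util.normal_bae import NormalMapDetector

# Fetch scannet.pt into the local Hugging Face cache; in the app, the model
# manager downloads and caches it via get_model_url() instead.
ckpt_path = pathlib.Path(
    huggingface_hub.hf_hub_download(NormalMapDetector.hf_repo_id, NormalMapDetector.hf_filename)
)

model = NormalMapDetector.load_model(ckpt_path)
detector = NormalMapDetector(model)

# Hypothetical input/output paths.
normal_map = detector.run(image=Image.open("input.png").convert("RGB"))
normal_map.save("normal_map.png")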
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules.encoder import Encoder
from .submodules.decoder import Decoder


class NNET(nn.Module):
    def __init__(self, args):
        super(NNET, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(args)

    def get_1x_lr_params(self):  # lr/10 learning rate
        return self.encoder.parameters()

    def get_10x_lr_params(self):  # lr learning rate
        return self.decoder.parameters()

    def forward(self, img, **kwargs):
        return self.decoder(self.encoder(img), **kwargs)
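
Note that `NNET` defines no defaults for `args`. A minimal construction sketch, mirroring the namespace that `NormalMapDetector.load_model` builds above:

import types

from invokeai.backend.image_util.normal_bae.nets.NNET import NNET

# Same hyperparameters as NormalMapDetector.load_model uses.
args = types.SimpleNamespace(
    mode="client",
    architecture="BN",
    pretrained="scannet",
    sampling_ratio=0.4,
    importance_ratio=0.7,
)
model = NNET(args)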

invokeai/backend/image_util/normal_bae/nets/__init__.py

Whitespace-only changes.
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules.submodules import UpSampleBN, norm_normalize


# This is the baseline encoder-decoder we used in the ablation study
class NNET(nn.Module):
    def __init__(self, args=None):
        super(NNET, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(num_classes=4)

    def forward(self, x, **kwargs):
        out = self.decoder(self.encoder(x), **kwargs)

        # Bilinearly upsample the output to match the input resolution
        up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)

        # L2-normalize the first three channels / ensure positive value for concentration parameters (kappa)
        up_out = norm_normalize(up_out)
        return up_out

    def get_1x_lr_params(self):  # lr/10 learning rate
        return self.encoder.parameters()

    def get_10x_lr_params(self):  # lr learning rate
        modules = [self.decoder]
        for m in modules:
            yield from m.parameters()


# Encoder
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        basemodel_name = 'tf_efficientnet_b5_ap'
        basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)

        # Remove last layer
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()

        self.original_model = basemodel

    def forward(self, x):
        features = [x]
        for k, v in self.original_model._modules.items():
            if (k == 'blocks'):
                for ki, vi in v._modules.items():
                    features.append(vi(features[-1]))
            else:
                features.append(v(features[-1]))
        return features


# Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
class Decoder(nn.Module):
    def __init__(self, num_classes=4):
        super(Decoder, self).__init__()
        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
        self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
        self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
        self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
        self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
        self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1)

    def forward(self, features):
        x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
        x_d0 = self.conv2(x_block4)
        x_d1 = self.up1(x_d0, x_block3)
        x_d2 = self.up2(x_d1, x_block2)
        x_d3 = self.up3(x_d2, x_block1)
        x_d4 = self.up4(x_d3, x_block0)
        out = self.conv3(x_d4)
        return out


if __name__ == '__main__':
    model = NNET()
    x = torch.rand(2, 3, 480, 640)
    out = model(x)
    print(out.shape)
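
The commit message above floats `timm` as a replacement for the vendored gen-efficientnet-pytorch checkout. A rough, untested sketch of what that swap might look like; the timm model name and the use of `features_only` are assumptions, and since the hand-rolled `forward` above collects twelve feature maps (the decoder reads indices 4, 5, 6, 8, and 11), the timm feature taps would need to be matched up before this could be a drop-in replacement:

import timm
import torch.nn as nn


class TimmEncoder(nn.Module):
    """Hypothetical timm-backed encoder; not a drop-in replacement as written."""

    def __init__(self):
        super().__init__()
        # features_only=True makes the model return intermediate feature maps
        # (one per stage) instead of classification logits.
        self.backbone = timm.create_model(
            "tf_efficientnet_b5_ap",  # assumed timm alias for the same weights
            pretrained=True,
            features_only=True,
        )

    def forward(self, x):
        # Returns a list of stage feature maps; the Decoder above expects the
        # specific taps produced by the gen-efficientnet forward loop, so the
        # indexing would need to be reconciled.
        return self.backbone(x)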

invokeai/backend/image_util/normal_bae/nets/submodules/__init__.py

Whitespace-only changes.
