Skip to content

[keras_hub/src/models/control_net] Add ControlNet #2209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions keras_hub/src/models/control_net/EncodeDecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import keras

from .kerasCVDiffusionModels import GroupNormalization


class Decoder(keras.Sequential):
def __init__(self, img_height, img_width, name=None, download_weights=False):
super().__init__(
[
keras.layers.Input((img_height // 8, img_width // 8, 4)),
keras.layers.Rescaling(1.0 / 0.18215),
PaddedConv2D(4, 1, name="PostQuantConvolutionalIn"),
PaddedConv2D(512, 3, padding="same", name="ConvolutionalIn"),
ResnetBlock(512),
AttentionBlock(512),
ResnetBlock(512),
ResnetBlock(512),
ResnetBlock(512),
ResnetBlock(512),
keras.layers.UpSampling2D(size=(2, 2)),
PaddedConv2D(512, 3, padding="same"),
ResnetBlock(512),
ResnetBlock(512),
ResnetBlock(512),
keras.layers.UpSampling2D(size=(2, 2)),
PaddedConv2D(512, 3, padding="same"),
ResnetBlock(256),
ResnetBlock(256),
ResnetBlock(256),
keras.layers.UpSampling2D(size=(2, 2)),
PaddedConv2D(256, 3, padding="same"),
ResnetBlock(128),
ResnetBlock(128),
ResnetBlock(128),
GroupNormalization(epsilon=1e-5),
keras.layers.Activation("swish"),
PaddedConv2D(3, 3, padding="same", name="ConvolutionalOut"),
],
name=name,
)

if download_weights:
decoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
)
self.load_weights(decoder_weights_fpath)


class ImageEncoder(keras.Sequential):
"""ImageEncoder is the VAE Encoder for StableDiffusion."""

def __init__(self, img_height=512, img_width=512, download_weights=False):
super().__init__(
[
keras.layers.Input((img_height, img_width, 3)),
PaddedConv2D(128, 3, padding="same"),
ResnetBlock(128),
ResnetBlock(128),
PaddedConv2D(128, 3, padding="same", strides=2),
ResnetBlock(256),
ResnetBlock(256),
PaddedConv2D(256, 3, padding="same", strides=2),
ResnetBlock(512),
ResnetBlock(512),
PaddedConv2D(512, 3, padding="same", strides=2),
ResnetBlock(512),
ResnetBlock(512),
ResnetBlock(512),
AttentionBlock(512),
ResnetBlock(512),
GroupNormalization(epsilon=1e-5),
keras.layers.Activation("swish"),
PaddedConv2D(8, 3, padding="same"),
PaddedConv2D(8, 1),
# TODO(lukewood): can this be refactored to be a Rescaling layer?
# Perhaps some sort of rescale and gather?
# Either way, we may need a lambda to gather the first 4 dimensions.
keras.layers.Lambda(lambda x: x[..., :4] * 0.18215),
]
)


"""
Blocks
"""


class ResnetBlock(keras.layers.Layer):
def __init__(self, output_dim, **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.norm1 = GroupNormalization(epsilon=1e-5)
self.conv1 = PaddedConv2D(output_dim, 3, padding="same")
self.norm2 = GroupNormalization(epsilon=1e-5)
self.conv2 = PaddedConv2D(output_dim, 3, padding="same")

def build(self, input_shape):
if input_shape[-1] != self.output_dim:
self.residual_projection = PaddedConv2D(self.output_dim, 1)
else:
self.residual_projection = lambda x: x

def call(self, inputs):
x = self.conv1(keras.activations.swish(self.norm1(inputs)))
x = self.conv2(keras.activations.swish(self.norm2(x)))
return x + self.residual_projection(inputs)

def get_config(self):
config = super().get_config()
config.update(
{
"output_dim": self.output_dim,
}
)
return config


class AttentionBlock(keras.layers.Layer):
def __init__(self, output_dim, **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.norm = GroupNormalization(epsilon=1e-5)
self.q = PaddedConv2D(output_dim, 1)
self.k = PaddedConv2D(output_dim, 1)
self.v = PaddedConv2D(output_dim, 1)
self.proj_out = PaddedConv2D(output_dim, 1)

def get_config(self):
config = super().get_config()
config.update(
{
"output_dim": self.output_dim,
}
)
return config

def call(self, inputs):
x = self.norm(inputs)
q, k, v = self.q(x), self.k(x), self.v(x)

# Compute attention
_, h, w, c = q.shape
q = keras.ops.reshape(q, (-1, h * w, c)) # b, hw, c
k = keras.ops.transpose(k, (0, 3, 1, 2))
k = keras.ops.reshape(k, (-1, c, h * w)) # b, c, hw
y = q @ k
y = y * (c**-0.5)
y = keras.activations.softmax(y)

# Attend to values
v = keras.ops.transpose(v, (0, 3, 1, 2))
v = keras.ops.reshape(v, (-1, c, h * w))
y = keras.ops.transpose(y, (0, 2, 1))
x = v @ y
x = keras.ops.transpose(x, (0, 2, 1))
x = keras.ops.reshape(x, (-1, h, w, c))
return self.proj_out(x) + inputs


class PaddedConv2D(keras.layers.Layer):
def __init__(
self, filters, kernel_size, padding="valid", strides=1, name=None, **kwargs
):
super().__init__(**kwargs)
self.conv2d = keras.layers.Conv2D(
filters, kernel_size, strides=strides, padding=padding, name=name
)
self.filters = filters
self.kernel_size = kernel_size
self.padding = padding
self.strides = strides

def call(self, inputs):
return self.conv2d(inputs)

def get_config(self):
config = super().get_config()
config.update(
{
"filters": self.filters,
"kernel_size": self.kernel_size,
"padding": self.padding,
"strides": self.strides,
}
)
return config
5 changes: 5 additions & 0 deletions keras_hub/src/models/control_net/ReadMe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
### Stable Diffusion TensorFlow ###

Originally implemented by Divum Gupta, then heavily modified by https://github.com/soten355 into [MetalDiffusion](https://github.com/soten355/MetalDiffusion).

Finally, the TensorFlow variant was extracted out and converted to Keras 3.
Empty file.
155 changes: 155 additions & 0 deletions keras_hub/src/models/control_net/clipEncoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import keras
import numpy as np

from .layers import quick_gelu


# Step 1
# Create and return the CLIP Embeddings
class CLIPTextTransformer(keras.models.Model):
def __init__(self, maxLength=77, vocabularySize=49408):
super().__init__()

# Create embeddings -> Step 2
self.embeddings = CLIPTextEmbeddings(
maxLength=maxLength, vocabularySize=vocabularySize
)

# Create encoder -> Step 3
self.encoder = CLIPEncoder()

self.final_layer_norm = keras.layers.LayerNormalization(
epsilon=1e-5, name="FinalLayerNormalization"
)
self.causal_attention_mask = keras.initializers.Constant(
np.triu(np.ones((1, 1, 77, 77), dtype="float32") * -np.inf, k=1),
name="CausalAttentionMask",
)

def call(self, inputs):
input_ids, position_ids = inputs
x = self.embeddings([input_ids, position_ids])
x = self.encoder([x, self.causal_attention_mask])
return self.final_layer_norm(x)


# Step 2
# Create and return word and position embeddings


class CLIPTextEmbeddings(keras.layers.Layer):
def __init__(self, maxLength=77, vocabularySize=49408, embeddingSize=768):
super().__init__()
self.token_embedding_layer = keras.layers.Embedding(
vocabularySize, embeddingSize, name="token_embedding"
)
self.position_embedding_layer = keras.layers.Embedding(
maxLength, embeddingSize, name="position_embedding"
)

def call(self, inputs):
input_ids, position_ids = inputs
word_embeddings = self.token_embedding_layer(input_ids)
position_embeddings = self.position_embedding_layer(position_ids)
return word_embeddings + position_embeddings


# Step 3
# Create and return the hidden states (aka hidden size)
class CLIPEncoder(keras.layers.Layer):
def __init__(self):
super().__init__()
self.layers = [CLIPEncoderLayer() for i in range(12)]

def call(self, inputs):
[hidden_states, causal_attention_mask] = inputs
for l in self.layers:
hidden_states = l([hidden_states, causal_attention_mask])
return hidden_states


# Step 4 (also creatd in step 3)
# Create the layers
class CLIPEncoderLayer(keras.layers.Layer):
def __init__(self, intermediateSize=3072, embeddingSize=768):
super().__init__()
self.layer_norm1 = keras.layers.LayerNormalization(
epsilon=1e-5, name="LayerNormalization001"
)
self.self_attn = CLIPAttention()
self.layer_norm2 = keras.layers.LayerNormalization(
epsilon=1e-5, name="LayerNormalization002"
)
self.fc1 = keras.layers.Dense(intermediateSize, name="FC1")
self.fc2 = keras.layers.Dense(embeddingSize, name="FC2")

def call(self, inputs):
hidden_states, causal_attention_mask = inputs
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn([hidden_states, causal_attention_mask])
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)

hidden_states = self.fc1(hidden_states)
hidden_states = quick_gelu(hidden_states)
hidden_states = self.fc2(hidden_states)

return residual + hidden_states


class CLIPAttention(keras.layers.Layer):
def __init__(self):
super().__init__()
self.embed_dim = 768
self.num_heads = 12
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.q_proj = keras.layers.Dense(self.embed_dim, name="QueryState")
self.k_proj = keras.layers.Dense(self.embed_dim, name="KeyState")
self.v_proj = keras.layers.Dense(self.embed_dim, name="ValueState")
self.out_proj = keras.layers.Dense(self.embed_dim, name="OutProjection")

def _shape(self, tensor, seq_len: int, bsz: int):
a = keras.ops.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim))
return keras.layers.Permute((2, 1, 3))(a) # bs , n_head , seq_len , head_dim

def call(self, inputs):
hidden_states, causal_attention_mask = inputs
bsz, tgt_len, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), tgt_len, -1)
value_states = self._shape(self.v_proj(hidden_states), tgt_len, -1)

proj_shape = (-1, tgt_len, self.head_dim)
query_states = self._shape(query_states, tgt_len, -1)
query_states = keras.ops.reshape(query_states, proj_shape)
key_states = keras.ops.reshape(key_states, proj_shape)

src_len = tgt_len
value_states = keras.ops.reshape(value_states, proj_shape)
attn_weights = query_states @ keras.layers.Permute((2, 1))(key_states)

attn_weights = keras.ops.reshape(
attn_weights, (-1, self.num_heads, tgt_len, src_len)
)
# print("attn_weights dtype:",attn_weights.dtype)
# print('casual dtype:',causal_attention_mask.dtype)
# Convert the causal_attention_mask tensor to the same data type as attn_weights
# causal_attention_mask = keras.ops.cast(causal_attention_mask, dtype=attn_weights.dtype)
attn_weights = attn_weights + causal_attention_mask
attn_weights = keras.ops.reshape(attn_weights, (-1, tgt_len, src_len))

attn_weights = keras.ops.softmax(attn_weights)
attn_output = attn_weights @ value_states

attn_output = keras.ops.reshape(
attn_output, (-1, self.num_heads, tgt_len, self.head_dim)
)
attn_output = keras.layers.Permute((2, 1, 3))(attn_output)
attn_output = keras.ops.reshape(attn_output, (-1, tgt_len, embed_dim))

return self.out_proj(attn_output)
3 changes: 3 additions & 0 deletions keras_hub/src/models/control_net/clipTokenizer/ReadMe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## CLIP Tokenizer

This folder contains the files necessary for the CLIP Tokenizer.
Loading