diff --git a/keras_hub/src/models/control_net/EncodeDecode.py b/keras_hub/src/models/control_net/EncodeDecode.py new file mode 100644 index 0000000000..9bac94cb79 --- /dev/null +++ b/keras_hub/src/models/control_net/EncodeDecode.py @@ -0,0 +1,187 @@ +import keras + +from .kerasCVDiffusionModels import GroupNormalization + + +class Decoder(keras.Sequential): + def __init__(self, img_height, img_width, name=None, download_weights=False): + super().__init__( + [ + keras.layers.Input((img_height // 8, img_width // 8, 4)), + keras.layers.Rescaling(1.0 / 0.18215), + PaddedConv2D(4, 1, name="PostQuantConvolutionalIn"), + PaddedConv2D(512, 3, padding="same", name="ConvolutionalIn"), + ResnetBlock(512), + AttentionBlock(512), + ResnetBlock(512), + ResnetBlock(512), + ResnetBlock(512), + ResnetBlock(512), + keras.layers.UpSampling2D(size=(2, 2)), + PaddedConv2D(512, 3, padding="same"), + ResnetBlock(512), + ResnetBlock(512), + ResnetBlock(512), + keras.layers.UpSampling2D(size=(2, 2)), + PaddedConv2D(512, 3, padding="same"), + ResnetBlock(256), + ResnetBlock(256), + ResnetBlock(256), + keras.layers.UpSampling2D(size=(2, 2)), + PaddedConv2D(256, 3, padding="same"), + ResnetBlock(128), + ResnetBlock(128), + ResnetBlock(128), + GroupNormalization(epsilon=1e-5), + keras.layers.Activation("swish"), + PaddedConv2D(3, 3, padding="same", name="ConvolutionalOut"), + ], + name=name, + ) + + if download_weights: + decoder_weights_fpath = keras.utils.get_file( + origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5", + file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962", + ) + self.load_weights(decoder_weights_fpath) + + +class ImageEncoder(keras.Sequential): + """ImageEncoder is the VAE Encoder for StableDiffusion.""" + + def __init__(self, img_height=512, img_width=512, download_weights=False): + super().__init__( + [ + keras.layers.Input((img_height, img_width, 3)), + PaddedConv2D(128, 3, padding="same"), + ResnetBlock(128), + ResnetBlock(128), + PaddedConv2D(128, 3, padding="same", strides=2), + ResnetBlock(256), + ResnetBlock(256), + PaddedConv2D(256, 3, padding="same", strides=2), + ResnetBlock(512), + ResnetBlock(512), + PaddedConv2D(512, 3, padding="same", strides=2), + ResnetBlock(512), + ResnetBlock(512), + ResnetBlock(512), + AttentionBlock(512), + ResnetBlock(512), + GroupNormalization(epsilon=1e-5), + keras.layers.Activation("swish"), + PaddedConv2D(8, 3, padding="same"), + PaddedConv2D(8, 1), + # TODO(lukewood): can this be refactored to be a Rescaling layer? + # Perhaps some sort of rescale and gather? + # Either way, we may need a lambda to gather the first 4 dimensions. 
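+ # The 1x1 convolution above emits 8 channels: the mean and log-variance of
+ # the VAE posterior. The Lambda below keeps the 4 mean channels (no sampling)
+ # and applies the Stable Diffusion latent scale factor 0.18215, which the
+ # Decoder undoes with its Rescaling(1.0 / 0.18215) layer.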
+ keras.layers.Lambda(lambda x: x[..., :4] * 0.18215), + ] + ) + + +""" +Blocks +""" + + +class ResnetBlock(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.norm1 = GroupNormalization(epsilon=1e-5) + self.conv1 = PaddedConv2D(output_dim, 3, padding="same") + self.norm2 = GroupNormalization(epsilon=1e-5) + self.conv2 = PaddedConv2D(output_dim, 3, padding="same") + + def build(self, input_shape): + if input_shape[-1] != self.output_dim: + self.residual_projection = PaddedConv2D(self.output_dim, 1) + else: + self.residual_projection = lambda x: x + + def call(self, inputs): + x = self.conv1(keras.activations.swish(self.norm1(inputs))) + x = self.conv2(keras.activations.swish(self.norm2(x))) + return x + self.residual_projection(inputs) + + def get_config(self): + config = super().get_config() + config.update( + { + "output_dim": self.output_dim, + } + ) + return config + + +class AttentionBlock(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.norm = GroupNormalization(epsilon=1e-5) + self.q = PaddedConv2D(output_dim, 1) + self.k = PaddedConv2D(output_dim, 1) + self.v = PaddedConv2D(output_dim, 1) + self.proj_out = PaddedConv2D(output_dim, 1) + + def get_config(self): + config = super().get_config() + config.update( + { + "output_dim": self.output_dim, + } + ) + return config + + def call(self, inputs): + x = self.norm(inputs) + q, k, v = self.q(x), self.k(x), self.v(x) + + # Compute attention + _, h, w, c = q.shape + q = keras.ops.reshape(q, (-1, h * w, c)) # b, hw, c + k = keras.ops.transpose(k, (0, 3, 1, 2)) + k = keras.ops.reshape(k, (-1, c, h * w)) # b, c, hw + y = q @ k + y = y * (c**-0.5) + y = keras.activations.softmax(y) + + # Attend to values + v = keras.ops.transpose(v, (0, 3, 1, 2)) + v = keras.ops.reshape(v, (-1, c, h * w)) + y = keras.ops.transpose(y, (0, 2, 1)) + x = v @ y + x = keras.ops.transpose(x, (0, 2, 1)) + x = keras.ops.reshape(x, (-1, h, w, c)) + return self.proj_out(x) + inputs + + +class PaddedConv2D(keras.layers.Layer): + def __init__( + self, filters, kernel_size, padding="valid", strides=1, name=None, **kwargs + ): + super().__init__(**kwargs) + self.conv2d = keras.layers.Conv2D( + filters, kernel_size, strides=strides, padding=padding, name=name + ) + self.filters = filters + self.kernel_size = kernel_size + self.padding = padding + self.strides = strides + + def call(self, inputs): + return self.conv2d(inputs) + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "padding": self.padding, + "strides": self.strides, + } + ) + return config diff --git a/keras_hub/src/models/control_net/ReadMe.md b/keras_hub/src/models/control_net/ReadMe.md new file mode 100644 index 0000000000..9f82630aa2 --- /dev/null +++ b/keras_hub/src/models/control_net/ReadMe.md @@ -0,0 +1,5 @@ +### Stable Diffusion TensorFlow ### + +Originally implemented by Divum Gupta, then heavily modified by https://github.com/soten355 into [MetalDiffusion](https://github.com/soten355/MetalDiffusion). + +Finally, the TensorFlow variant was extracted out and converted to Keras 3. 
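For review context, a minimal sketch of how the two models in `EncodeDecode.py` compose: the encoder maps an image to a `(h/8, w/8, 4)` latent that the decoder maps back to pixels. The import path, the `[-1, 1]` pixel range, and the random-weight round trip are assumptions for illustration only; real use would load converted VAE weights first.

```python
import numpy as np

# Assumed import path, based on the file location added in this diff.
from keras_hub.src.models.control_net.EncodeDecode import Decoder, ImageEncoder

encoder = ImageEncoder(img_height=512, img_width=512)
decoder = Decoder(img_height=512, img_width=512, name="decoder")

# Assumption: pixels pre-scaled to roughly [-1, 1]; random data stands in for an image.
images = np.random.uniform(-1.0, 1.0, size=(1, 512, 512, 3)).astype("float32")

latents = encoder.predict(images)          # shape (1, 64, 64, 4), already scaled by 0.18215
reconstruction = decoder.predict(latents)  # shape (1, 512, 512, 3)
```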
diff --git a/keras_hub/src/models/control_net/__init__.py b/keras_hub/src/models/control_net/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/control_net/clipEncoder.py b/keras_hub/src/models/control_net/clipEncoder.py new file mode 100644 index 0000000000..118a8135f1 --- /dev/null +++ b/keras_hub/src/models/control_net/clipEncoder.py @@ -0,0 +1,155 @@ +import keras +import numpy as np + +from .layers import quick_gelu + + +# Step 1 +# Create and return the CLIP Embeddings +class CLIPTextTransformer(keras.models.Model): + def __init__(self, maxLength=77, vocabularySize=49408): + super().__init__() + + # Create embeddings -> Step 2 + self.embeddings = CLIPTextEmbeddings( + maxLength=maxLength, vocabularySize=vocabularySize + ) + + # Create encoder -> Step 3 + self.encoder = CLIPEncoder() + + self.final_layer_norm = keras.layers.LayerNormalization( + epsilon=1e-5, name="FinalLayerNormalization" + ) + self.causal_attention_mask = keras.initializers.Constant( + np.triu(np.ones((1, 1, 77, 77), dtype="float32") * -np.inf, k=1), + name="CausalAttentionMask", + ) + + def call(self, inputs): + input_ids, position_ids = inputs + x = self.embeddings([input_ids, position_ids]) + x = self.encoder([x, self.causal_attention_mask]) + return self.final_layer_norm(x) + + +# Step 2 +# Create and return word and position embeddings + + +class CLIPTextEmbeddings(keras.layers.Layer): + def __init__(self, maxLength=77, vocabularySize=49408, embeddingSize=768): + super().__init__() + self.token_embedding_layer = keras.layers.Embedding( + vocabularySize, embeddingSize, name="token_embedding" + ) + self.position_embedding_layer = keras.layers.Embedding( + maxLength, embeddingSize, name="position_embedding" + ) + + def call(self, inputs): + input_ids, position_ids = inputs + word_embeddings = self.token_embedding_layer(input_ids) + position_embeddings = self.position_embedding_layer(position_ids) + return word_embeddings + position_embeddings + + +# Step 3 +# Create and return the hidden states (aka hidden size) +class CLIPEncoder(keras.layers.Layer): + def __init__(self): + super().__init__() + self.layers = [CLIPEncoderLayer() for i in range(12)] + + def call(self, inputs): + [hidden_states, causal_attention_mask] = inputs + for l in self.layers: + hidden_states = l([hidden_states, causal_attention_mask]) + return hidden_states + + +# Step 4 (also creatd in step 3) +# Create the layers +class CLIPEncoderLayer(keras.layers.Layer): + def __init__(self, intermediateSize=3072, embeddingSize=768): + super().__init__() + self.layer_norm1 = keras.layers.LayerNormalization( + epsilon=1e-5, name="LayerNormalization001" + ) + self.self_attn = CLIPAttention() + self.layer_norm2 = keras.layers.LayerNormalization( + epsilon=1e-5, name="LayerNormalization002" + ) + self.fc1 = keras.layers.Dense(intermediateSize, name="FC1") + self.fc2 = keras.layers.Dense(embeddingSize, name="FC2") + + def call(self, inputs): + hidden_states, causal_attention_mask = inputs + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn([hidden_states, causal_attention_mask]) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = quick_gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + + return residual + hidden_states + + +class CLIPAttention(keras.layers.Layer): + def __init__(self): + super().__init__() + 
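+ # CLIP ViT-L/14 text-encoder attention sizes: 768-dim hidden states split
+ # across 12 heads of 64 dims each; queries are scaled by head_dim ** -0.5.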
self.embed_dim = 768 + self.num_heads = 12 + self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim**-0.5 + self.q_proj = keras.layers.Dense(self.embed_dim, name="QueryState") + self.k_proj = keras.layers.Dense(self.embed_dim, name="KeyState") + self.v_proj = keras.layers.Dense(self.embed_dim, name="ValueState") + self.out_proj = keras.layers.Dense(self.embed_dim, name="OutProjection") + + def _shape(self, tensor, seq_len: int, bsz: int): + a = keras.ops.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)) + return keras.layers.Permute((2, 1, 3))(a) # bs , n_head , seq_len , head_dim + + def call(self, inputs): + hidden_states, causal_attention_mask = inputs + bsz, tgt_len, embed_dim = hidden_states.shape + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), tgt_len, -1) + value_states = self._shape(self.v_proj(hidden_states), tgt_len, -1) + + proj_shape = (-1, tgt_len, self.head_dim) + query_states = self._shape(query_states, tgt_len, -1) + query_states = keras.ops.reshape(query_states, proj_shape) + key_states = keras.ops.reshape(key_states, proj_shape) + + src_len = tgt_len + value_states = keras.ops.reshape(value_states, proj_shape) + attn_weights = query_states @ keras.layers.Permute((2, 1))(key_states) + + attn_weights = keras.ops.reshape( + attn_weights, (-1, self.num_heads, tgt_len, src_len) + ) + # print("attn_weights dtype:",attn_weights.dtype) + # print('casual dtype:',causal_attention_mask.dtype) + # Convert the causal_attention_mask tensor to the same data type as attn_weights + # causal_attention_mask = keras.ops.cast(causal_attention_mask, dtype=attn_weights.dtype) + attn_weights = attn_weights + causal_attention_mask + attn_weights = keras.ops.reshape(attn_weights, (-1, tgt_len, src_len)) + + attn_weights = keras.ops.softmax(attn_weights) + attn_output = attn_weights @ value_states + + attn_output = keras.ops.reshape( + attn_output, (-1, self.num_heads, tgt_len, self.head_dim) + ) + attn_output = keras.layers.Permute((2, 1, 3))(attn_output) + attn_output = keras.ops.reshape(attn_output, (-1, tgt_len, embed_dim)) + + return self.out_proj(attn_output) diff --git a/keras_hub/src/models/control_net/clipTokenizer/ReadMe.md b/keras_hub/src/models/control_net/clipTokenizer/ReadMe.md new file mode 100644 index 0000000000..78118d531c --- /dev/null +++ b/keras_hub/src/models/control_net/clipTokenizer/ReadMe.md @@ -0,0 +1,3 @@ +## CLIP Tokenizer + +This folder contains the files necessary for the CLIP Tokenizer. diff --git a/keras_hub/src/models/control_net/clipTokenizer/__init__.py b/keras_hub/src/models/control_net/clipTokenizer/__init__.py new file mode 100644 index 0000000000..73714c4bb9 --- /dev/null +++ b/keras_hub/src/models/control_net/clipTokenizer/__init__.py @@ -0,0 +1,284 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import keras +import regex as re + + +@lru_cache() +def default_bpe(): + p = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + if os.path.exists(p): + return p + else: + return keras.utils.get_file( + "bpe_simple_vocab_16e6.txt.gz", + "https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true", + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. 
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), specialTokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + + """ + Special Tokens are words we want to add to the vocabularly that don't exist in real life. 
+ We can use these special tokens to activate pre-trained vectors for the text-encoder + + The only special tokens that are always added indicate when the text starts and ends + """ + if not specialTokens: + addedTokens = None + specialTokens = ["", ""] + else: + addedTokens = specialTokens + specialTokens = ["", ""] + specialTokens + vocab.extend(specialTokens) + + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in specialTokens} + + # Create special words to recognize + special = "|".join(specialTokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in specialTokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + +class LegacySimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), specialTokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + + """ + Special Tokens are words we want to add to the vocabularly that don't exist in real life. 
+ We can use these special tokens to activate pre-trained vectors for the text-encoder + + The only special tokens that are always added indicate when the text starts and ends + """ + # Create the words to add to the vocabulary + if specialTokens is None: + addedTokens = None + specialTokens = ["<|startoftext|>", "<|endoftext|>"] + else: + addedTokens = specialTokens + specialTokens = ["<|startoftext|>", "<|endoftext|>"] + specialTokens + vocab.extend(specialTokens) + + # Create the list for the program to recognize the words + if addedTokens is not None: + special = "|".join(addedTokens) + special = special + "|" + else: + special = "" + + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in specialTokens} + + self.pat = re.compile( + special + + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + def bpe(self, token): + # print("Tokenzing this: ",token) + if token in self.cache: + # print("Found in cache! Returning cache value") + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + # return [49406] + bpe_tokens + [49407] + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text diff --git a/keras_hub/src/models/control_net/clipTokenizer/bpe_simple_vocab_16e6.txt.gz b/keras_hub/src/models/control_net/clipTokenizer/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000..7b5088a527 Binary files /dev/null and b/keras_hub/src/models/control_net/clipTokenizer/bpe_simple_vocab_16e6.txt.gz differ diff --git a/keras_hub/src/models/control_net/constants.py b/keras_hub/src/models/control_net/constants.py new file mode 100644 index 0000000000..ecca03ca88 --- /dev/null +++ b/keras_hub/src/models/control_net/constants.py @@ -0,0 +1,4636 @@ +PYTORCH_CKPT_MAPPING = { + "text_encoder_legacy": [ + ( + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight", + None, + ), + ( + 
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight", + (1, 0), + ), + ( + 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight", + (1, 0), + 
), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight", + (1, 0), + ), + 
("cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight", + None, + ), + ( + 
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight", + (1, 0), + ), + ("cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias", None), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias", + None, + ), + ( + 
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias", + None, + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight", + (1, 0), + ), + ( + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias", + None, + ), + ("cond_stage_model.transformer.text_model.final_layer_norm.weight", None), + ("cond_stage_model.transformer.text_model.final_layer_norm.bias", None), + ], + # new Stable Diffusion + "text_encoder": [ + ("cond_stage_model.model.token_embedding.weight", None), + ("cond_stage_model.model.positional_embedding", None), + ("cond_stage_model.model.transformer.resblocks.0.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.0.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.0.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.0.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.0.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.1.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.1.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.1.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.1.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.1.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.1.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.2.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.2.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight", 
None), + ("cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.2.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.2.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.2.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.2.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.3.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.3.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.3.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.3.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.3.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.3.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.4.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.4.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.4.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.4.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.4.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.4.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.5.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.5.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.5.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.5.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.5.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.5.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.6.ln_1.weight", None), + 
("cond_stage_model.model.transformer.resblocks.6.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.6.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.6.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.6.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.6.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.7.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.7.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.7.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.7.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.7.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.7.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.8.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.8.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.8.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.8.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.8.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.8.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.9.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.9.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias", None), + ("cond_stage_model.model.transformer.resblocks.9.attn.out_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.9.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.9.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.9.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.weight", (1, 0)), + 
("cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.10.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.10.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.10.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.10.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.10.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.10.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.11.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.11.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.11.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.11.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.11.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.11.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.12.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.12.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.12.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.12.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.12.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.12.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.13.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.13.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.13.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.13.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.13.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.13.ln_2.bias", None), + 
("cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.14.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.14.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.14.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.14.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.14.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.14.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.15.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.15.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.15.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.15.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.15.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.15.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.16.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.16.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.16.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.16.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.16.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.16.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.17.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.17.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.17.attn.out_proj.weight", + (1, 0), + ), + 
("cond_stage_model.model.transformer.resblocks.17.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.17.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.17.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.18.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.18.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.18.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.18.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.18.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.18.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.19.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.19.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.19.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.19.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.19.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.19.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.20.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.20.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.20.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.20.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.20.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.20.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.21.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.21.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight", 
None), + ("cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.21.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.21.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.21.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.21.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.22.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.22.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.22.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.22.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.22.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.23.ln_1.weight", None), + ("cond_stage_model.model.transformer.resblocks.23.ln_1.bias", None), + ("cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight", None), + ("cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", None), + ( + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight", + (1, 0), + ), + ("cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias", None), + ("cond_stage_model.model.transformer.resblocks.23.ln_2.weight", None), + ("cond_stage_model.model.transformer.resblocks.23.ln_2.bias", None), + ("cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias", None), + ("cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight", (1, 0)), + ("cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias", None), + ("cond_stage_model.model.ln_final.weight", None), + ("cond_stage_model.model.ln_final.bias", None), + ], + "diffusion_model": [ + ("model.diffusion_model.time_embed.0.weight", (1, 0)), + ("model.diffusion_model.time_embed.0.bias", None), + ("model.diffusion_model.input_blocks.0.0.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.0.0.bias", None), + ("model.diffusion_model.time_embed.2.weight", (1, 0)), + ("model.diffusion_model.time_embed.2.bias", None), + ("model.diffusion_model.input_blocks.1.0.in_layers.0.weight", None), + ("model.diffusion_model.input_blocks.1.0.in_layers.0.bias", None), + ("model.diffusion_model.input_blocks.1.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.1.0.in_layers.2.bias", None), + ("model.diffusion_model.input_blocks.1.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.input_blocks.1.0.emb_layers.1.bias", None), + ("model.diffusion_model.input_blocks.1.0.out_layers.0.weight", 
None), + ("model.diffusion_model.input_blocks.1.0.out_layers.0.bias", None), + ("model.diffusion_model.input_blocks.1.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.1.0.out_layers.3.bias", None), + ("model.diffusion_model.input_blocks.1.1.norm.weight", None), + ("model.diffusion_model.input_blocks.1.1.norm.bias", None), + ("model.diffusion_model.input_blocks.1.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.1.1.proj_in.bias", None), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.input_blocks.1.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.1.1.proj_out.bias", None), + ("model.diffusion_model.input_blocks.2.0.in_layers.0.weight", None), + ("model.diffusion_model.input_blocks.2.0.in_layers.0.bias", None), + ("model.diffusion_model.input_blocks.2.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.2.0.in_layers.2.bias", None), + ("model.diffusion_model.input_blocks.2.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.input_blocks.2.0.emb_layers.1.bias", None), + ("model.diffusion_model.input_blocks.2.0.out_layers.0.weight", None), + ("model.diffusion_model.input_blocks.2.0.out_layers.0.bias", None), + ("model.diffusion_model.input_blocks.2.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.2.0.out_layers.3.bias", None), + ("model.diffusion_model.input_blocks.2.1.norm.weight", None), + ("model.diffusion_model.input_blocks.2.1.norm.bias", None), + ("model.diffusion_model.input_blocks.2.1.proj_in.weight", (2, 3, 1, 0)), 
+ ("model.diffusion_model.input_blocks.2.1.proj_in.bias", None), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.input_blocks.2.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.2.1.proj_out.bias", None), + ("model.diffusion_model.input_blocks.3.0.op.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.3.0.op.bias", None), + ("model.diffusion_model.input_blocks.4.0.in_layers.0.weight", None), + ("model.diffusion_model.input_blocks.4.0.in_layers.0.bias", None), + ("model.diffusion_model.input_blocks.4.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.4.0.in_layers.2.bias", None), + ("model.diffusion_model.input_blocks.4.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.input_blocks.4.0.emb_layers.1.bias", None), + ("model.diffusion_model.input_blocks.4.0.out_layers.0.weight", None), + ("model.diffusion_model.input_blocks.4.0.out_layers.0.bias", None), + ("model.diffusion_model.input_blocks.4.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.4.0.out_layers.3.bias", None), + ("model.diffusion_model.input_blocks.4.0.skip_connection.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.4.0.skip_connection.bias", None), + ("model.diffusion_model.input_blocks.4.1.norm.weight", None), + ("model.diffusion_model.input_blocks.4.1.norm.bias", None), + ("model.diffusion_model.input_blocks.4.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.4.1.proj_in.bias", None), + ( + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.input_blocks.4.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.4.1.proj_out.bias", None), + ("model.diffusion_model.input_blocks.5.0.in_layers.0.weight", None), + ("model.diffusion_model.input_blocks.5.0.in_layers.0.bias", None), + ("model.diffusion_model.input_blocks.5.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.5.0.in_layers.2.bias", None), + ("model.diffusion_model.input_blocks.5.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.input_blocks.5.0.emb_layers.1.bias", None), + ("model.diffusion_model.input_blocks.5.0.out_layers.0.weight", None), + ("model.diffusion_model.input_blocks.5.0.out_layers.0.bias", None), + ("model.diffusion_model.input_blocks.5.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.5.0.out_layers.3.bias", None), + ("model.diffusion_model.input_blocks.5.1.norm.weight", None), + ("model.diffusion_model.input_blocks.5.1.norm.bias", None), + ("model.diffusion_model.input_blocks.5.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.5.1.proj_in.bias", None), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.input_blocks.5.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.5.1.proj_out.bias", None), + ("model.diffusion_model.input_blocks.6.0.op.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.6.0.op.bias", None), + ("model.diffusion_model.input_blocks.7.0.in_layers.0.weight", None), + ("model.diffusion_model.input_blocks.7.0.in_layers.0.bias", None), + ("model.diffusion_model.input_blocks.7.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.7.0.in_layers.2.bias", None), + ("model.diffusion_model.input_blocks.7.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.input_blocks.7.0.emb_layers.1.bias", None), + ("model.diffusion_model.input_blocks.7.0.out_layers.0.weight", None), + ("model.diffusion_model.input_blocks.7.0.out_layers.0.bias", None), + ("model.diffusion_model.input_blocks.7.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.7.0.out_layers.3.bias", None), + ("model.diffusion_model.input_blocks.7.0.skip_connection.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.7.0.skip_connection.bias", None), + ("model.diffusion_model.input_blocks.7.1.norm.weight", None), + ("model.diffusion_model.input_blocks.7.1.norm.bias", None), + ("model.diffusion_model.input_blocks.7.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.7.1.proj_in.bias", None), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.input_blocks.7.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.7.1.proj_out.bias", None), + ("model.diffusion_model.input_blocks.8.0.in_layers.0.weight", None), + ("model.diffusion_model.input_blocks.8.0.in_layers.0.bias", None), + ("model.diffusion_model.input_blocks.8.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.8.0.in_layers.2.bias", None), + ("model.diffusion_model.input_blocks.8.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.input_blocks.8.0.emb_layers.1.bias", None), + ("model.diffusion_model.input_blocks.8.0.out_layers.0.weight", None), + ("model.diffusion_model.input_blocks.8.0.out_layers.0.bias", None), + ("model.diffusion_model.input_blocks.8.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.8.0.out_layers.3.bias", None), + ("model.diffusion_model.input_blocks.8.1.norm.weight", None), + ("model.diffusion_model.input_blocks.8.1.norm.bias", None), + ("model.diffusion_model.input_blocks.8.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.8.1.proj_in.bias", None), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.input_blocks.8.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.8.1.proj_out.bias", None), + ("model.diffusion_model.input_blocks.9.0.op.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.9.0.op.bias", None), + ("model.diffusion_model.input_blocks.10.0.in_layers.0.weight", None), + ("model.diffusion_model.input_blocks.10.0.in_layers.0.bias", None), + ("model.diffusion_model.input_blocks.10.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.10.0.in_layers.2.bias", None), + ("model.diffusion_model.input_blocks.10.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.input_blocks.10.0.emb_layers.1.bias", None), + ("model.diffusion_model.input_blocks.10.0.out_layers.0.weight", None), + ("model.diffusion_model.input_blocks.10.0.out_layers.0.bias", None), + ("model.diffusion_model.input_blocks.10.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.10.0.out_layers.3.bias", None), + ("model.diffusion_model.input_blocks.11.0.in_layers.0.weight", None), + ("model.diffusion_model.input_blocks.11.0.in_layers.0.bias", None), + ("model.diffusion_model.input_blocks.11.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.11.0.in_layers.2.bias", None), + ("model.diffusion_model.input_blocks.11.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.input_blocks.11.0.emb_layers.1.bias", None), + ("model.diffusion_model.input_blocks.11.0.out_layers.0.weight", None), + ("model.diffusion_model.input_blocks.11.0.out_layers.0.bias", None), + ("model.diffusion_model.input_blocks.11.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.input_blocks.11.0.out_layers.3.bias", None), + ("model.diffusion_model.middle_block.0.in_layers.0.weight", None), + ("model.diffusion_model.middle_block.0.in_layers.0.bias", None), + ("model.diffusion_model.middle_block.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.middle_block.0.in_layers.2.bias", None), + ("model.diffusion_model.middle_block.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.middle_block.0.emb_layers.1.bias", None), + ("model.diffusion_model.middle_block.0.out_layers.0.weight", None), + ("model.diffusion_model.middle_block.0.out_layers.0.bias", None), + ("model.diffusion_model.middle_block.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.middle_block.0.out_layers.3.bias", None), + 
("model.diffusion_model.middle_block.1.norm.weight", None), + ("model.diffusion_model.middle_block.1.norm.bias", None), + ("model.diffusion_model.middle_block.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.middle_block.1.proj_in.bias", None), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight", + None, + ), + ("model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias", None), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight", + None, + ), + ("model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias", None), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight", + None, + ), + ("model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias", None), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.middle_block.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.middle_block.1.proj_out.bias", None), + ("model.diffusion_model.middle_block.2.in_layers.0.weight", None), + ("model.diffusion_model.middle_block.2.in_layers.0.bias", None), + ("model.diffusion_model.middle_block.2.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.middle_block.2.in_layers.2.bias", None), + ("model.diffusion_model.middle_block.2.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.middle_block.2.emb_layers.1.bias", None), + ("model.diffusion_model.middle_block.2.out_layers.0.weight", None), + ("model.diffusion_model.middle_block.2.out_layers.0.bias", None), + ("model.diffusion_model.middle_block.2.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.middle_block.2.out_layers.3.bias", None), + ("model.diffusion_model.output_blocks.0.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.0.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.0.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.0.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.0.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.0.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.0.0.out_layers.0.weight", None), + 
("model.diffusion_model.output_blocks.0.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.0.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.0.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.0.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.0.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.1.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.1.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.1.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.1.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.1.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.1.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.1.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.1.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.1.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.1.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.1.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.1.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.2.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.2.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.2.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.2.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.2.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.2.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.2.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.2.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.2.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.2.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.2.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.2.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.2.1.conv.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.2.1.conv.bias", None), + ("model.diffusion_model.output_blocks.3.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.3.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.3.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.3.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.3.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.3.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.3.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.3.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.3.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.3.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.3.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.3.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.3.1.norm.weight", None), + ("model.diffusion_model.output_blocks.3.1.norm.bias", None), + ("model.diffusion_model.output_blocks.3.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.3.1.proj_in.bias", None), + ( + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.3.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.3.1.proj_out.bias", None), + ("model.diffusion_model.output_blocks.4.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.4.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.4.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.4.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.4.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.4.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.4.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.4.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.4.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.4.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.4.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.4.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.4.1.norm.weight", None), + ("model.diffusion_model.output_blocks.4.1.norm.bias", None), + ("model.diffusion_model.output_blocks.4.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.4.1.proj_in.bias", None), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.4.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.4.1.proj_out.bias", None), + ("model.diffusion_model.output_blocks.5.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.5.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.5.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.5.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.5.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.5.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.5.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.5.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.5.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.5.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.5.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.5.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.5.1.norm.weight", None), + ("model.diffusion_model.output_blocks.5.1.norm.bias", None), + ("model.diffusion_model.output_blocks.5.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.5.1.proj_in.bias", None), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.5.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.5.1.proj_out.bias", None), + ("model.diffusion_model.output_blocks.5.2.conv.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.5.2.conv.bias", None), + ("model.diffusion_model.output_blocks.6.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.6.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.6.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.6.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.6.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.6.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.6.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.6.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.6.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.6.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.6.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.6.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.6.1.norm.weight", None), + ("model.diffusion_model.output_blocks.6.1.norm.bias", None), + ("model.diffusion_model.output_blocks.6.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.6.1.proj_in.bias", None), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight", + (1, 
0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.6.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.6.1.proj_out.bias", None), + ("model.diffusion_model.output_blocks.7.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.7.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.7.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.7.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.7.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.7.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.7.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.7.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.7.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.7.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.7.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.7.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.7.1.norm.weight", None), + ("model.diffusion_model.output_blocks.7.1.norm.bias", None), + ("model.diffusion_model.output_blocks.7.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.7.1.proj_in.bias", None), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + 
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.7.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.7.1.proj_out.bias", None), + ("model.diffusion_model.output_blocks.8.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.8.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.8.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.8.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.8.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.8.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.8.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.8.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.8.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.8.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.8.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.8.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.8.1.norm.weight", None), + ("model.diffusion_model.output_blocks.8.1.norm.bias", None), + ("model.diffusion_model.output_blocks.8.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.8.1.proj_in.bias", None), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.8.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.8.1.proj_out.bias", None), + ("model.diffusion_model.output_blocks.8.2.conv.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.8.2.conv.bias", None), + ("model.diffusion_model.output_blocks.9.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.9.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.9.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.9.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.9.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.9.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.9.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.9.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.9.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.9.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.9.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.9.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.9.1.norm.weight", None), + ("model.diffusion_model.output_blocks.9.1.norm.bias", None), + ("model.diffusion_model.output_blocks.9.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.9.1.proj_in.bias", None), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight", + (1, 
0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.9.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.9.1.proj_out.bias", None), + ("model.diffusion_model.output_blocks.10.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.10.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.10.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.10.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.10.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.10.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.10.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.10.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.10.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.10.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.10.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.10.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.10.1.norm.weight", None), + ("model.diffusion_model.output_blocks.10.1.norm.bias", None), + ("model.diffusion_model.output_blocks.10.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.10.1.proj_in.bias", None), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + 
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.10.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.10.1.proj_out.bias", None), + ("model.diffusion_model.output_blocks.11.0.in_layers.0.weight", None), + ("model.diffusion_model.output_blocks.11.0.in_layers.0.bias", None), + ("model.diffusion_model.output_blocks.11.0.in_layers.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.11.0.in_layers.2.bias", None), + ("model.diffusion_model.output_blocks.11.0.emb_layers.1.weight", (1, 0)), + ("model.diffusion_model.output_blocks.11.0.emb_layers.1.bias", None), + ("model.diffusion_model.output_blocks.11.0.out_layers.0.weight", None), + ("model.diffusion_model.output_blocks.11.0.out_layers.0.bias", None), + ("model.diffusion_model.output_blocks.11.0.out_layers.3.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.11.0.out_layers.3.bias", None), + ( + "model.diffusion_model.output_blocks.11.0.skip_connection.weight", + (2, 3, 1, 0), + ), + ("model.diffusion_model.output_blocks.11.0.skip_connection.bias", None), + ("model.diffusion_model.output_blocks.11.1.norm.weight", None), + ("model.diffusion_model.output_blocks.11.1.norm.bias", None), + ("model.diffusion_model.output_blocks.11.1.proj_in.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.11.1.proj_in.bias", None), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + 
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight", + (1, 0), + ), + ( + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias", + None, + ), + ("model.diffusion_model.output_blocks.11.1.proj_out.weight", (2, 3, 1, 0)), + ("model.diffusion_model.output_blocks.11.1.proj_out.bias", None), + ("model.diffusion_model.out.0.weight", None), + ("model.diffusion_model.out.0.bias", None), + ("model.diffusion_model.out.2.weight", (2, 3, 1, 0)), + ("model.diffusion_model.out.2.bias", None), + ], + "decoder": [ + ("first_stage_model.post_quant_conv.weight", (2, 3, 1, 0)), + ("first_stage_model.post_quant_conv.bias", None), + ("first_stage_model.decoder.conv_in.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.conv_in.bias", None), + ("first_stage_model.decoder.mid.block_1.norm1.weight", None), + ("first_stage_model.decoder.mid.block_1.norm1.bias", None), + ("first_stage_model.decoder.mid.block_1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.mid.block_1.conv1.bias", None), + ("first_stage_model.decoder.mid.block_1.norm2.weight", None), + ("first_stage_model.decoder.mid.block_1.norm2.bias", None), + ("first_stage_model.decoder.mid.block_1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.mid.block_1.conv2.bias", None), + ("first_stage_model.decoder.mid.attn_1.norm.weight", None), + ("first_stage_model.decoder.mid.attn_1.norm.bias", None), + ("first_stage_model.decoder.mid.attn_1.q.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.mid.attn_1.q.bias", None), + ("first_stage_model.decoder.mid.attn_1.k.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.mid.attn_1.k.bias", None), + ("first_stage_model.decoder.mid.attn_1.v.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.mid.attn_1.v.bias", None), + ("first_stage_model.decoder.mid.attn_1.proj_out.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.mid.attn_1.proj_out.bias", None), + ("first_stage_model.decoder.mid.block_2.norm1.weight", None), + ("first_stage_model.decoder.mid.block_2.norm1.bias", None), + ("first_stage_model.decoder.mid.block_2.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.mid.block_2.conv1.bias", None), + ("first_stage_model.decoder.mid.block_2.norm2.weight", None), + ("first_stage_model.decoder.mid.block_2.norm2.bias", None), + ("first_stage_model.decoder.mid.block_2.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.mid.block_2.conv2.bias", None), + ("first_stage_model.decoder.up.3.block.0.norm1.weight", None), + ("first_stage_model.decoder.up.3.block.0.norm1.bias", None), + ("first_stage_model.decoder.up.3.block.0.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.3.block.0.conv1.bias", None), + ("first_stage_model.decoder.up.3.block.0.norm2.weight", None), + ("first_stage_model.decoder.up.3.block.0.norm2.bias", None), + ("first_stage_model.decoder.up.3.block.0.conv2.weight", (2, 3, 1, 0)), + 
("first_stage_model.decoder.up.3.block.0.conv2.bias", None), + ("first_stage_model.decoder.up.3.block.1.norm1.weight", None), + ("first_stage_model.decoder.up.3.block.1.norm1.bias", None), + ("first_stage_model.decoder.up.3.block.1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.3.block.1.conv1.bias", None), + ("first_stage_model.decoder.up.3.block.1.norm2.weight", None), + ("first_stage_model.decoder.up.3.block.1.norm2.bias", None), + ("first_stage_model.decoder.up.3.block.1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.3.block.1.conv2.bias", None), + ("first_stage_model.decoder.up.3.block.2.norm1.weight", None), + ("first_stage_model.decoder.up.3.block.2.norm1.bias", None), + ("first_stage_model.decoder.up.3.block.2.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.3.block.2.conv1.bias", None), + ("first_stage_model.decoder.up.3.block.2.norm2.weight", None), + ("first_stage_model.decoder.up.3.block.2.norm2.bias", None), + ("first_stage_model.decoder.up.3.block.2.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.3.block.2.conv2.bias", None), + ("first_stage_model.decoder.up.3.upsample.conv.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.3.upsample.conv.bias", None), + ("first_stage_model.decoder.up.2.block.0.norm1.weight", None), + ("first_stage_model.decoder.up.2.block.0.norm1.bias", None), + ("first_stage_model.decoder.up.2.block.0.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.2.block.0.conv1.bias", None), + ("first_stage_model.decoder.up.2.block.0.norm2.weight", None), + ("first_stage_model.decoder.up.2.block.0.norm2.bias", None), + ("first_stage_model.decoder.up.2.block.0.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.2.block.0.conv2.bias", None), + ("first_stage_model.decoder.up.2.block.1.norm1.weight", None), + ("first_stage_model.decoder.up.2.block.1.norm1.bias", None), + ("first_stage_model.decoder.up.2.block.1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.2.block.1.conv1.bias", None), + ("first_stage_model.decoder.up.2.block.1.norm2.weight", None), + ("first_stage_model.decoder.up.2.block.1.norm2.bias", None), + ("first_stage_model.decoder.up.2.block.1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.2.block.1.conv2.bias", None), + ("first_stage_model.decoder.up.2.block.2.norm1.weight", None), + ("first_stage_model.decoder.up.2.block.2.norm1.bias", None), + ("first_stage_model.decoder.up.2.block.2.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.2.block.2.conv1.bias", None), + ("first_stage_model.decoder.up.2.block.2.norm2.weight", None), + ("first_stage_model.decoder.up.2.block.2.norm2.bias", None), + ("first_stage_model.decoder.up.2.block.2.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.2.block.2.conv2.bias", None), + ("first_stage_model.decoder.up.2.upsample.conv.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.2.upsample.conv.bias", None), + ("first_stage_model.decoder.up.1.block.0.norm1.weight", None), + ("first_stage_model.decoder.up.1.block.0.norm1.bias", None), + ("first_stage_model.decoder.up.1.block.0.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.1.block.0.conv1.bias", None), + ("first_stage_model.decoder.up.1.block.0.norm2.weight", None), + ("first_stage_model.decoder.up.1.block.0.norm2.bias", None), + ("first_stage_model.decoder.up.1.block.0.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.1.block.0.conv2.bias", None), + 
("first_stage_model.decoder.up.1.block.0.nin_shortcut.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.1.block.0.nin_shortcut.bias", None), + ("first_stage_model.decoder.up.1.block.1.norm1.weight", None), + ("first_stage_model.decoder.up.1.block.1.norm1.bias", None), + ("first_stage_model.decoder.up.1.block.1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.1.block.1.conv1.bias", None), + ("first_stage_model.decoder.up.1.block.1.norm2.weight", None), + ("first_stage_model.decoder.up.1.block.1.norm2.bias", None), + ("first_stage_model.decoder.up.1.block.1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.1.block.1.conv2.bias", None), + ("first_stage_model.decoder.up.1.block.2.norm1.weight", None), + ("first_stage_model.decoder.up.1.block.2.norm1.bias", None), + ("first_stage_model.decoder.up.1.block.2.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.1.block.2.conv1.bias", None), + ("first_stage_model.decoder.up.1.block.2.norm2.weight", None), + ("first_stage_model.decoder.up.1.block.2.norm2.bias", None), + ("first_stage_model.decoder.up.1.block.2.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.1.block.2.conv2.bias", None), + ("first_stage_model.decoder.up.1.upsample.conv.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.1.upsample.conv.bias", None), + ("first_stage_model.decoder.up.0.block.0.norm1.weight", None), + ("first_stage_model.decoder.up.0.block.0.norm1.bias", None), + ("first_stage_model.decoder.up.0.block.0.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.0.block.0.conv1.bias", None), + ("first_stage_model.decoder.up.0.block.0.norm2.weight", None), + ("first_stage_model.decoder.up.0.block.0.norm2.bias", None), + ("first_stage_model.decoder.up.0.block.0.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.0.block.0.conv2.bias", None), + ("first_stage_model.decoder.up.0.block.0.nin_shortcut.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.0.block.0.nin_shortcut.bias", None), + ("first_stage_model.decoder.up.0.block.1.norm1.weight", None), + ("first_stage_model.decoder.up.0.block.1.norm1.bias", None), + ("first_stage_model.decoder.up.0.block.1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.0.block.1.conv1.bias", None), + ("first_stage_model.decoder.up.0.block.1.norm2.weight", None), + ("first_stage_model.decoder.up.0.block.1.norm2.bias", None), + ("first_stage_model.decoder.up.0.block.1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.0.block.1.conv2.bias", None), + ("first_stage_model.decoder.up.0.block.2.norm1.weight", None), + ("first_stage_model.decoder.up.0.block.2.norm1.bias", None), + ("first_stage_model.decoder.up.0.block.2.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.0.block.2.conv1.bias", None), + ("first_stage_model.decoder.up.0.block.2.norm2.weight", None), + ("first_stage_model.decoder.up.0.block.2.norm2.bias", None), + ("first_stage_model.decoder.up.0.block.2.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.up.0.block.2.conv2.bias", None), + ("first_stage_model.decoder.norm_out.weight", None), + ("first_stage_model.decoder.norm_out.bias", None), + ("first_stage_model.decoder.conv_out.weight", (2, 3, 1, 0)), + ("first_stage_model.decoder.conv_out.bias", None), + ], + "encoder": [ + ("first_stage_model.encoder.conv_in.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.conv_in.bias", None), + ("first_stage_model.encoder.down.0.block.0.norm1.weight", None), + 
("first_stage_model.encoder.down.0.block.0.norm1.bias", None), + ("first_stage_model.encoder.down.0.block.0.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.0.block.0.conv1.bias", None), + ("first_stage_model.encoder.down.0.block.0.norm2.weight", None), + ("first_stage_model.encoder.down.0.block.0.norm2.bias", None), + ("first_stage_model.encoder.down.0.block.0.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.0.block.0.conv2.bias", None), + ("first_stage_model.encoder.down.0.block.1.norm1.weight", None), + ("first_stage_model.encoder.down.0.block.1.norm1.bias", None), + ("first_stage_model.encoder.down.0.block.1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.0.block.1.conv1.bias", None), + ("first_stage_model.encoder.down.0.block.1.norm2.weight", None), + ("first_stage_model.encoder.down.0.block.1.norm2.bias", None), + ("first_stage_model.encoder.down.0.block.1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.0.block.1.conv2.bias", None), + ("first_stage_model.encoder.down.0.downsample.conv.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.0.downsample.conv.bias", None), + ("first_stage_model.encoder.down.1.block.0.norm1.weight", None), + ("first_stage_model.encoder.down.1.block.0.norm1.bias", None), + ("first_stage_model.encoder.down.1.block.0.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.1.block.0.conv1.bias", None), + ("first_stage_model.encoder.down.1.block.0.norm2.weight", None), + ("first_stage_model.encoder.down.1.block.0.norm2.bias", None), + ("first_stage_model.encoder.down.1.block.0.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.1.block.0.conv2.bias", None), + ("first_stage_model.encoder.down.1.block.0.nin_shortcut.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.1.block.0.nin_shortcut.bias", None), + ("first_stage_model.encoder.down.1.block.1.norm1.weight", None), + ("first_stage_model.encoder.down.1.block.1.norm1.bias", None), + ("first_stage_model.encoder.down.1.block.1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.1.block.1.conv1.bias", None), + ("first_stage_model.encoder.down.1.block.1.norm2.weight", None), + ("first_stage_model.encoder.down.1.block.1.norm2.bias", None), + ("first_stage_model.encoder.down.1.block.1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.1.block.1.conv2.bias", None), + ("first_stage_model.encoder.down.1.downsample.conv.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.1.downsample.conv.bias", None), + ("first_stage_model.encoder.down.2.block.0.norm1.weight", None), + ("first_stage_model.encoder.down.2.block.0.norm1.bias", None), + ("first_stage_model.encoder.down.2.block.0.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.2.block.0.conv1.bias", None), + ("first_stage_model.encoder.down.2.block.0.norm2.weight", None), + ("first_stage_model.encoder.down.2.block.0.norm2.bias", None), + ("first_stage_model.encoder.down.2.block.0.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.2.block.0.conv2.bias", None), + ("first_stage_model.encoder.down.2.block.0.nin_shortcut.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.2.block.0.nin_shortcut.bias", None), + ("first_stage_model.encoder.down.2.block.1.norm1.weight", None), + ("first_stage_model.encoder.down.2.block.1.norm1.bias", None), + ("first_stage_model.encoder.down.2.block.1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.2.block.1.conv1.bias", None), + 
("first_stage_model.encoder.down.2.block.1.norm2.weight", None), + ("first_stage_model.encoder.down.2.block.1.norm2.bias", None), + ("first_stage_model.encoder.down.2.block.1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.2.block.1.conv2.bias", None), + ("first_stage_model.encoder.down.2.downsample.conv.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.2.downsample.conv.bias", None), + ("first_stage_model.encoder.down.3.block.0.norm1.weight", None), + ("first_stage_model.encoder.down.3.block.0.norm1.bias", None), + ("first_stage_model.encoder.down.3.block.0.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.3.block.0.conv1.bias", None), + ("first_stage_model.encoder.down.3.block.0.norm2.weight", None), + ("first_stage_model.encoder.down.3.block.0.norm2.bias", None), + ("first_stage_model.encoder.down.3.block.0.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.3.block.0.conv2.bias", None), + ("first_stage_model.encoder.down.3.block.1.norm1.weight", None), + ("first_stage_model.encoder.down.3.block.1.norm1.bias", None), + ("first_stage_model.encoder.down.3.block.1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.3.block.1.conv1.bias", None), + ("first_stage_model.encoder.down.3.block.1.norm2.weight", None), + ("first_stage_model.encoder.down.3.block.1.norm2.bias", None), + ("first_stage_model.encoder.down.3.block.1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.down.3.block.1.conv2.bias", None), + ("first_stage_model.encoder.mid.block_1.norm1.weight", None), + ("first_stage_model.encoder.mid.block_1.norm1.bias", None), + ("first_stage_model.encoder.mid.block_1.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.mid.block_1.conv1.bias", None), + ("first_stage_model.encoder.mid.block_1.norm2.weight", None), + ("first_stage_model.encoder.mid.block_1.norm2.bias", None), + ("first_stage_model.encoder.mid.block_1.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.mid.block_1.conv2.bias", None), + ("first_stage_model.encoder.mid.attn_1.norm.weight", None), + ("first_stage_model.encoder.mid.attn_1.norm.bias", None), + ("first_stage_model.encoder.mid.attn_1.q.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.mid.attn_1.q.bias", None), + ("first_stage_model.encoder.mid.attn_1.k.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.mid.attn_1.k.bias", None), + ("first_stage_model.encoder.mid.attn_1.v.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.mid.attn_1.v.bias", None), + ("first_stage_model.encoder.mid.attn_1.proj_out.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.mid.attn_1.proj_out.bias", None), + ("first_stage_model.encoder.mid.block_2.norm1.weight", None), + ("first_stage_model.encoder.mid.block_2.norm1.bias", None), + ("first_stage_model.encoder.mid.block_2.conv1.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.mid.block_2.conv1.bias", None), + ("first_stage_model.encoder.mid.block_2.norm2.weight", None), + ("first_stage_model.encoder.mid.block_2.norm2.bias", None), + ("first_stage_model.encoder.mid.block_2.conv2.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.mid.block_2.conv2.bias", None), + ("first_stage_model.encoder.norm_out.weight", None), + ("first_stage_model.encoder.norm_out.bias", None), + ("first_stage_model.encoder.conv_out.weight", (2, 3, 1, 0)), + ("first_stage_model.encoder.conv_out.bias", None), + ("first_stage_model.quant_conv.weight", (2, 3, 1, 0)), + ("first_stage_model.quant_conv.bias", None), + ], + "controlNet": [ + 
("control_model.time_embed.0.weight", (1, 0)), + ("control_model.time_embed.0.bias", None), + ("control_model.input_blocks.0.0.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.0.0.bias", None), + ("control_model.input_hint_block.0.weight", (2, 3, 1, 0)), + ("control_model.input_hint_block.0.bias", None), + ("control_model.input_hint_block.2.weight", (2, 3, 1, 0)), + ("control_model.input_hint_block.2.bias", None), + ("control_model.input_hint_block.4.weight", (2, 3, 1, 0)), + ("control_model.input_hint_block.4.bias", None), + ("control_model.input_hint_block.6.weight", (2, 3, 1, 0)), + ("control_model.input_hint_block.6.bias", None), + ("control_model.input_hint_block.8.weight", (2, 3, 1, 0)), + ("control_model.input_hint_block.8.bias", None), + ("control_model.input_hint_block.10.weight", (2, 3, 1, 0)), + ("control_model.input_hint_block.10.bias", None), + ("control_model.input_hint_block.12.weight", (2, 3, 1, 0)), + ("control_model.input_hint_block.12.bias", None), + ("control_model.input_hint_block.14.weight", (2, 3, 1, 0)), + ("control_model.input_hint_block.14.bias", None), + ("control_model.time_embed.2.weight", (1, 0)), + ("control_model.time_embed.2.bias", None), + ("control_model.input_blocks.1.0.in_layers.0.weight", None), + ("control_model.input_blocks.1.0.in_layers.0.bias", None), + ("control_model.input_blocks.1.0.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.1.0.in_layers.2.bias", None), + ("control_model.input_blocks.1.0.emb_layers.1.weight", (1, 0)), + ("control_model.input_blocks.1.0.emb_layers.1.bias", None), + ("control_model.input_blocks.1.0.out_layers.0.weight", None), + ("control_model.input_blocks.1.0.out_layers.0.bias", None), + ("control_model.input_blocks.1.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.1.0.out_layers.3.bias", None), + ("control_model.input_blocks.1.1.norm.weight", None), + ("control_model.input_blocks.1.1.norm.bias", None), + ("control_model.input_blocks.1.1.proj_in.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.1.1.proj_in.bias", None), + ("control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight", None), + ("control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias", None), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ("control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight", None), + ("control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias", None), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ("control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight", None), + ("control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias", None), + ( + 
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ("control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight", (1, 0)), + ("control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias", None), + ("control_model.input_blocks.1.1.proj_out.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.1.1.proj_out.bias", None), + ("control_model.input_blocks.2.0.in_layers.0.weight", None), + ("control_model.input_blocks.2.0.in_layers.0.bias", None), + ("control_model.input_blocks.2.0.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.2.0.in_layers.2.bias", None), + ("control_model.input_blocks.2.0.emb_layers.1.weight", (1, 0)), + ("control_model.input_blocks.2.0.emb_layers.1.bias", None), + ("control_model.input_blocks.2.0.out_layers.0.weight", None), + ("control_model.input_blocks.2.0.out_layers.0.bias", None), + ("control_model.input_blocks.2.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.2.0.out_layers.3.bias", None), + ("control_model.input_blocks.2.1.norm.weight", None), + ("control_model.input_blocks.2.1.norm.bias", None), + ("control_model.input_blocks.2.1.proj_in.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.2.1.proj_in.bias", None), + ("control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight", None), + ("control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias", None), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ("control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight", None), + ("control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias", None), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ("control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight", None), + ("control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias", None), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ("control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight", (1, 0)), + ("control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias", None), + ("control_model.input_blocks.2.1.proj_out.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.2.1.proj_out.bias", None), + ("control_model.input_blocks.3.0.op.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.3.0.op.bias", None), + ("control_model.input_blocks.4.0.in_layers.0.weight", None), + ("control_model.input_blocks.4.0.in_layers.0.bias", None), + ("control_model.input_blocks.4.0.in_layers.2.weight", (2, 3, 1, 0)), + 
("control_model.input_blocks.4.0.in_layers.2.bias", None), + ("control_model.input_blocks.4.0.emb_layers.1.weight", (1, 0)), + ("control_model.input_blocks.4.0.emb_layers.1.bias", None), + ("control_model.input_blocks.4.0.out_layers.0.weight", None), + ("control_model.input_blocks.4.0.out_layers.0.bias", None), + ("control_model.input_blocks.4.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.4.0.out_layers.3.bias", None), + ("control_model.input_blocks.4.0.skip_connection.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.4.0.skip_connection.bias", None), + ("control_model.input_blocks.4.1.norm.weight", None), + ("control_model.input_blocks.4.1.norm.bias", None), + ("control_model.input_blocks.4.1.proj_in.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.4.1.proj_in.bias", None), + ("control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight", None), + ("control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias", None), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ("control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight", None), + ("control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias", None), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ("control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight", None), + ("control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias", None), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ("control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight", (1, 0)), + ("control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias", None), + ("control_model.input_blocks.4.1.proj_out.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.4.1.proj_out.bias", None), + ("control_model.input_blocks.5.0.in_layers.0.weight", None), + ("control_model.input_blocks.5.0.in_layers.0.bias", None), + ("control_model.input_blocks.5.0.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.5.0.in_layers.2.bias", None), + ("control_model.input_blocks.5.0.emb_layers.1.weight", (1, 0)), + ("control_model.input_blocks.5.0.emb_layers.1.bias", None), + ("control_model.input_blocks.5.0.out_layers.0.weight", None), + ("control_model.input_blocks.5.0.out_layers.0.bias", None), + ("control_model.input_blocks.5.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.5.0.out_layers.3.bias", None), + ("control_model.input_blocks.5.1.norm.weight", None), + ("control_model.input_blocks.5.1.norm.bias", None), + ("control_model.input_blocks.5.1.proj_in.weight", (2, 3, 1, 0)), + 
("control_model.input_blocks.5.1.proj_in.bias", None), + ("control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight", None), + ("control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias", None), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ("control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight", None), + ("control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias", None), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ("control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight", None), + ("control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias", None), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ("control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight", (1, 0)), + ("control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias", None), + ("control_model.input_blocks.5.1.proj_out.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.5.1.proj_out.bias", None), + ("control_model.input_blocks.6.0.op.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.6.0.op.bias", None), + ("control_model.input_blocks.7.0.in_layers.0.weight", None), + ("control_model.input_blocks.7.0.in_layers.0.bias", None), + ("control_model.input_blocks.7.0.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.7.0.in_layers.2.bias", None), + ("control_model.input_blocks.7.0.emb_layers.1.weight", (1, 0)), + ("control_model.input_blocks.7.0.emb_layers.1.bias", None), + ("control_model.input_blocks.7.0.out_layers.0.weight", None), + ("control_model.input_blocks.7.0.out_layers.0.bias", None), + ("control_model.input_blocks.7.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.7.0.out_layers.3.bias", None), + ("control_model.input_blocks.7.0.skip_connection.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.7.0.skip_connection.bias", None), + ("control_model.input_blocks.7.1.norm.weight", None), + ("control_model.input_blocks.7.1.norm.bias", None), + ("control_model.input_blocks.7.1.proj_in.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.7.1.proj_in.bias", None), + ("control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight", None), + ("control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias", None), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + 
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ("control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight", None), + ("control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias", None), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ("control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight", None), + ("control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias", None), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ("control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight", (1, 0)), + ("control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias", None), + ("control_model.input_blocks.7.1.proj_out.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.7.1.proj_out.bias", None), + ("control_model.input_blocks.8.0.in_layers.0.weight", None), + ("control_model.input_blocks.8.0.in_layers.0.bias", None), + ("control_model.input_blocks.8.0.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.8.0.in_layers.2.bias", None), + ("control_model.input_blocks.8.0.emb_layers.1.weight", (1, 0)), + ("control_model.input_blocks.8.0.emb_layers.1.bias", None), + ("control_model.input_blocks.8.0.out_layers.0.weight", None), + ("control_model.input_blocks.8.0.out_layers.0.bias", None), + ("control_model.input_blocks.8.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.8.0.out_layers.3.bias", None), + ("control_model.input_blocks.8.1.norm.weight", None), + ("control_model.input_blocks.8.1.norm.bias", None), + ("control_model.input_blocks.8.1.proj_in.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.8.1.proj_in.bias", None), + ("control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight", None), + ("control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias", None), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias", + None, + ), + ("control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight", None), + ("control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias", None), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight", + (1, 0), + ), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight", + (1, 0), + ), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight", + (1, 0), + ), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ( + 
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias", + None, + ), + ("control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight", None), + ("control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias", None), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ( + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias", + None, + ), + ("control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight", (1, 0)), + ("control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias", None), + ("control_model.input_blocks.8.1.proj_out.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.8.1.proj_out.bias", None), + ("control_model.input_blocks.9.0.op.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.9.0.op.bias", None), + ("control_model.input_blocks.10.0.in_layers.0.weight", None), + ("control_model.input_blocks.10.0.in_layers.0.bias", None), + ("control_model.input_blocks.10.0.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.10.0.in_layers.2.bias", None), + ("control_model.input_blocks.10.0.emb_layers.1.weight", (1, 0)), + ("control_model.input_blocks.10.0.emb_layers.1.bias", None), + ("control_model.input_blocks.10.0.out_layers.0.weight", None), + ("control_model.input_blocks.10.0.out_layers.0.bias", None), + ("control_model.input_blocks.10.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.10.0.out_layers.3.bias", None), + ("control_model.input_blocks.11.0.in_layers.0.weight", None), + ("control_model.input_blocks.11.0.in_layers.0.bias", None), + ("control_model.input_blocks.11.0.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.11.0.in_layers.2.bias", None), + ("control_model.input_blocks.11.0.emb_layers.1.weight", (1, 0)), + ("control_model.input_blocks.11.0.emb_layers.1.bias", None), + ("control_model.input_blocks.11.0.out_layers.0.weight", None), + ("control_model.input_blocks.11.0.out_layers.0.bias", None), + ("control_model.input_blocks.11.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.input_blocks.11.0.out_layers.3.bias", None), + ("control_model.middle_block.0.in_layers.0.weight", None), + ("control_model.middle_block.0.in_layers.0.bias", None), + ("control_model.middle_block.0.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.middle_block.0.in_layers.2.bias", None), + ("control_model.middle_block.0.emb_layers.1.weight", (1, 0)), + ("control_model.middle_block.0.emb_layers.1.bias", None), + ("control_model.middle_block.0.out_layers.0.weight", None), + ("control_model.middle_block.0.out_layers.0.bias", None), + ("control_model.middle_block.0.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.middle_block.0.out_layers.3.bias", None), + ("control_model.middle_block.1.norm.weight", None), + ("control_model.middle_block.1.norm.bias", None), + ("control_model.middle_block.1.proj_in.weight", (2, 3, 1, 0)), + ("control_model.middle_block.1.proj_in.bias", None), + ("control_model.middle_block.1.transformer_blocks.0.norm1.weight", None), + ("control_model.middle_block.1.transformer_blocks.0.norm1.bias", None), + ("control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight", (1, 0)), + ("control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight", (1, 0)), + ("control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight", (1, 0)), + ( + "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight", + (1, 0), + ), + 
("control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias", None), + ("control_model.middle_block.1.transformer_blocks.0.norm2.weight", None), + ("control_model.middle_block.1.transformer_blocks.0.norm2.bias", None), + ("control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight", (1, 0)), + ("control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight", (1, 0)), + ("control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight", (1, 0)), + ( + "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight", + (1, 0), + ), + ("control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias", None), + ("control_model.middle_block.1.transformer_blocks.0.norm3.weight", None), + ("control_model.middle_block.1.transformer_blocks.0.norm3.bias", None), + ( + "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight", + (1, 0), + ), + ("control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias", None), + ("control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight", (1, 0)), + ("control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias", None), + ("control_model.middle_block.1.proj_out.weight", (2, 3, 1, 0)), + ("control_model.middle_block.1.proj_out.bias", None), + ("control_model.middle_block.2.in_layers.0.weight", None), + ("control_model.middle_block.2.in_layers.0.bias", None), + ("control_model.middle_block.2.in_layers.2.weight", (2, 3, 1, 0)), + ("control_model.middle_block.2.in_layers.2.bias", None), + ("control_model.middle_block.2.emb_layers.1.weight", (1, 0)), + ("control_model.middle_block.2.emb_layers.1.bias", None), + ("control_model.middle_block.2.out_layers.0.weight", None), + ("control_model.middle_block.2.out_layers.0.bias", None), + ("control_model.middle_block.2.out_layers.3.weight", (2, 3, 1, 0)), + ("control_model.middle_block.2.out_layers.3.bias", None), + ("control_model.zero_convs.0.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.0.0.bias", None), + ("control_model.zero_convs.1.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.1.0.bias", None), + ("control_model.zero_convs.2.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.2.0.bias", None), + ("control_model.zero_convs.3.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.3.0.bias", None), + ("control_model.zero_convs.4.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.4.0.bias", None), + ("control_model.zero_convs.5.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.5.0.bias", None), + ("control_model.zero_convs.6.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.6.0.bias", None), + ("control_model.zero_convs.7.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.7.0.bias", None), + ("control_model.zero_convs.8.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.8.0.bias", None), + ("control_model.zero_convs.9.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.9.0.bias", None), + ("control_model.zero_convs.10.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.10.0.bias", None), + ("control_model.zero_convs.11.0.weight", (2, 3, 1, 0)), + ("control_model.zero_convs.11.0.bias", None), + ("control_model.middle_block_out.0.weight", (2, 3, 1, 0)), + ("control_model.middle_block_out.0.bias", None), + ], +} + + +_UNCONDITIONAL_TOKENS = [ + 49406, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, 
+ 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, +] +_ALPHAS_CUMPROD = [ + 0.99915, + 0.998296, + 0.9974381, + 0.9965762, + 0.99571025, + 0.9948404, + 0.9939665, + 0.9930887, + 0.9922069, + 0.9913211, + 0.9904313, + 0.98953754, + 0.9886398, + 0.9877381, + 0.9868324, + 0.98592263, + 0.98500896, + 0.9840913, + 0.9831696, + 0.982244, + 0.98131436, + 0.9803808, + 0.97944313, + 0.97850156, + 0.977556, + 0.9766064, + 0.97565293, + 0.9746954, + 0.9737339, + 0.9727684, + 0.97179896, + 0.97082555, + 0.96984816, + 0.96886677, + 0.9678814, + 0.96689206, + 0.96589875, + 0.9649015, + 0.96390027, + 0.9628951, + 0.9618859, + 0.96087277, + 0.95985574, + 0.95883465, + 0.9578097, + 0.95678073, + 0.95574784, + 0.954711, + 0.95367026, + 0.9526256, + 0.9515769, + 0.95052433, + 0.94946784, + 0.94840735, + 0.947343, + 0.94627476, + 0.9452025, + 0.9441264, + 0.9430464, + 0.9419625, + 0.9408747, + 0.939783, + 0.9386874, + 0.93758786, + 0.9364845, + 0.93537724, + 0.9342661, + 0.9331511, + 0.9320323, + 0.9309096, + 0.929783, + 0.9286526, + 0.9275183, + 0.9263802, + 0.92523825, + 0.92409253, + 0.92294294, + 0.9217895, + 0.92063236, + 0.9194713, + 0.9183065, + 0.9171379, + 0.91596556, + 0.9147894, + 0.9136095, + 0.91242576, + 0.9112383, + 0.9100471, + 0.9088522, + 0.9076535, + 0.9064511, + 0.90524495, + 0.9040351, + 0.90282154, + 0.9016043, + 0.90038335, + 0.8991587, + 0.8979304, + 0.8966984, + 0.89546275, + 0.89422345, + 0.8929805, + 0.89173394, + 0.89048374, + 0.88922995, + 0.8879725, + 0.8867115, + 0.88544685, + 0.88417864, + 0.88290685, + 0.8816315, + 0.88035256, + 0.8790701, + 0.87778413, + 0.8764946, + 0.8752016, + 0.873905, + 0.87260497, + 0.8713014, + 0.8699944, + 0.86868393, + 0.86737, + 0.8660526, + 0.8647318, + 0.86340755, + 0.8620799, + 0.8607488, + 0.85941434, + 0.8580765, + 0.8567353, + 0.8553907, + 0.8540428, + 0.85269153, + 0.85133696, + 0.84997904, + 0.84861785, + 0.8472533, + 0.8458856, + 0.8445145, + 0.84314024, + 0.84176266, + 0.8403819, + 0.8389979, + 0.8376107, + 0.8362203, + 0.83482677, + 0.83343, + 0.8320301, + 0.8306271, + 0.8292209, + 0.82781166, + 0.82639927, + 0.8249838, + 0.82356524, + 0.8221436, + 0.82071894, + 0.81929123, + 0.81786054, + 0.8164268, + 0.8149901, + 0.8135504, + 0.81210774, + 0.81066215, + 0.8092136, + 0.8077621, + 0.80630773, + 0.80485046, + 0.8033903, + 0.80192727, + 0.8004614, + 0.79899275, + 0.79752123, + 0.7960469, + 0.7945698, + 0.7930899, + 0.79160726, + 0.7901219, + 0.7886338, + 0.787143, + 0.7856495, + 0.7841533, + 0.78265446, + 0.78115296, + 0.7796488, + 0.77814204, + 0.7766327, + 0.7751208, + 0.7736063, + 0.77208924, + 0.7705697, + 0.7690476, + 0.767523, + 0.7659959, + 0.7644664, + 0.76293445, + 0.7614, + 0.7598632, + 0.75832397, + 0.75678235, + 0.75523835, + 0.75369203, + 0.7521434, + 0.75059247, + 0.7490392, + 0.7474837, + 0.7459259, + 0.7443659, + 0.74280363, + 0.7412392, + 0.7396726, + 0.7381038, + 0.73653287, + 0.7349598, + 0.7333846, + 0.73180735, + 0.730228, + 0.7286466, + 0.7270631, + 0.7254777, + 0.72389024, + 0.72230077, + 0.7207094, + 0.71911603, + 0.7175208, + 0.7159236, + 0.71432453, + 0.7127236, + 0.71112084, + 0.7095162, + 0.7079098, + 0.7063016, + 0.70469165, + 
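+ # _ALPHAS_CUMPROD appears to hold the precomputed cumulative products of + # (1 - beta_t) for the 1000-step Stable Diffusion v1 noise schedule. A rough, + # illustrative sketch of how such a table could be regenerated, assuming the + # usual "scaled linear" beta schedule: + #   import numpy as np + #   betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2 + #   alphas_cumprod = np.cumprod(1.0 - betas)  # first entry: 1 - 0.00085 = 0.99915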
0.70307994, + 0.7014665, + 0.69985133, + 0.6982345, + 0.696616, + 0.6949958, + 0.69337404, + 0.69175065, + 0.69012564, + 0.6884991, + 0.68687093, + 0.6852413, + 0.68361014, + 0.6819775, + 0.6803434, + 0.67870784, + 0.6770708, + 0.6754324, + 0.6737926, + 0.67215145, + 0.670509, + 0.66886514, + 0.66722, + 0.6655736, + 0.66392595, + 0.662277, + 0.6606269, + 0.65897554, + 0.657323, + 0.65566933, + 0.6540145, + 0.6523586, + 0.6507016, + 0.6490435, + 0.64738435, + 0.6457241, + 0.64406294, + 0.6424008, + 0.64073765, + 0.63907355, + 0.63740855, + 0.6357426, + 0.6340758, + 0.6324082, + 0.6307397, + 0.6290704, + 0.6274003, + 0.6257294, + 0.62405777, + 0.6223854, + 0.62071234, + 0.6190386, + 0.61736417, + 0.6156891, + 0.61401343, + 0.6123372, + 0.6106603, + 0.6089829, + 0.607305, + 0.6056265, + 0.6039476, + 0.60226816, + 0.6005883, + 0.598908, + 0.59722733, + 0.5955463, + 0.59386486, + 0.5921831, + 0.59050107, + 0.5888187, + 0.5871361, + 0.5854532, + 0.5837701, + 0.5820868, + 0.5804033, + 0.5787197, + 0.5770359, + 0.575352, + 0.57366806, + 0.571984, + 0.5702999, + 0.5686158, + 0.56693166, + 0.56524754, + 0.5635635, + 0.5618795, + 0.56019557, + 0.5585118, + 0.5568281, + 0.55514455, + 0.5534612, + 0.551778, + 0.5500951, + 0.5484124, + 0.54673, + 0.5450478, + 0.54336596, + 0.54168445, + 0.54000324, + 0.53832245, + 0.5366421, + 0.53496206, + 0.5332825, + 0.53160346, + 0.5299248, + 0.52824676, + 0.5265692, + 0.52489215, + 0.5232157, + 0.5215398, + 0.51986456, + 0.51818997, + 0.51651603, + 0.51484275, + 0.5131702, + 0.5114983, + 0.5098272, + 0.50815684, + 0.5064873, + 0.50481856, + 0.50315064, + 0.50148356, + 0.4998174, + 0.4981521, + 0.49648774, + 0.49482432, + 0.49316183, + 0.49150035, + 0.48983985, + 0.4881804, + 0.486522, + 0.48486462, + 0.4832084, + 0.48155323, + 0.4798992, + 0.47824633, + 0.47659463, + 0.4749441, + 0.47329482, + 0.4716468, + 0.47, + 0.46835446, + 0.46671024, + 0.46506736, + 0.4634258, + 0.46178558, + 0.46014675, + 0.45850933, + 0.45687333, + 0.45523876, + 0.45360568, + 0.45197406, + 0.45034397, + 0.44871536, + 0.44708833, + 0.44546285, + 0.44383895, + 0.44221666, + 0.440596, + 0.43897697, + 0.43735963, + 0.43574396, + 0.43412998, + 0.43251774, + 0.43090722, + 0.4292985, + 0.42769152, + 0.42608637, + 0.42448303, + 0.4228815, + 0.42128187, + 0.4196841, + 0.41808826, + 0.4164943, + 0.4149023, + 0.41331223, + 0.41172415, + 0.41013804, + 0.40855396, + 0.4069719, + 0.4053919, + 0.40381396, + 0.4022381, + 0.40066436, + 0.39909273, + 0.39752322, + 0.3959559, + 0.39439073, + 0.39282778, + 0.39126703, + 0.3897085, + 0.3881522, + 0.3865982, + 0.38504648, + 0.38349706, + 0.38194993, + 0.38040516, + 0.37886274, + 0.37732267, + 0.375785, + 0.37424973, + 0.37271687, + 0.37118647, + 0.36965853, + 0.36813304, + 0.36661002, + 0.36508954, + 0.36357155, + 0.3620561, + 0.36054322, + 0.3590329, + 0.35752517, + 0.35602003, + 0.35451752, + 0.35301763, + 0.3515204, + 0.3500258, + 0.3485339, + 0.3470447, + 0.34555823, + 0.34407446, + 0.34259343, + 0.34111515, + 0.33963963, + 0.33816692, + 0.336697, + 0.3352299, + 0.33376563, + 0.3323042, + 0.33084565, + 0.32938993, + 0.32793713, + 0.3264872, + 0.32504022, + 0.32359615, + 0.32215503, + 0.32071686, + 0.31928164, + 0.31784943, + 0.3164202, + 0.314994, + 0.3135708, + 0.31215066, + 0.31073356, + 0.3093195, + 0.30790854, + 0.30650064, + 0.30509588, + 0.30369422, + 0.30229566, + 0.30090025, + 0.299508, + 0.2981189, + 0.29673296, + 0.29535022, + 0.2939707, + 0.29259437, + 0.29122123, + 0.28985137, + 0.28848472, + 0.28712133, + 0.2857612, + 0.28440437, + 0.2830508, + 
0.28170055, + 0.2803536, + 0.27900997, + 0.27766964, + 0.27633268, + 0.27499905, + 0.2736688, + 0.27234194, + 0.27101842, + 0.2696983, + 0.26838157, + 0.26706827, + 0.26575837, + 0.26445192, + 0.26314887, + 0.2618493, + 0.26055318, + 0.2592605, + 0.25797132, + 0.2566856, + 0.2554034, + 0.25412467, + 0.25284946, + 0.25157773, + 0.2503096, + 0.24904492, + 0.24778382, + 0.24652626, + 0.24527225, + 0.2440218, + 0.24277493, + 0.24153163, + 0.24029191, + 0.23905578, + 0.23782326, + 0.23659433, + 0.23536903, + 0.23414734, + 0.23292927, + 0.23171483, + 0.23050404, + 0.22929688, + 0.22809339, + 0.22689353, + 0.22569734, + 0.22450483, + 0.22331597, + 0.2221308, + 0.22094932, + 0.21977153, + 0.21859743, + 0.21742703, + 0.21626033, + 0.21509734, + 0.21393807, + 0.21278252, + 0.21163069, + 0.21048258, + 0.20933822, + 0.20819758, + 0.2070607, + 0.20592754, + 0.20479813, + 0.20367248, + 0.20255059, + 0.20143245, + 0.20031808, + 0.19920748, + 0.19810064, + 0.19699757, + 0.19589828, + 0.19480278, + 0.19371104, + 0.1926231, + 0.19153893, + 0.19045855, + 0.18938197, + 0.18830918, + 0.18724018, + 0.18617497, + 0.18511358, + 0.18405597, + 0.18300217, + 0.18195218, + 0.18090598, + 0.1798636, + 0.17882504, + 0.17779027, + 0.1767593, + 0.17573217, + 0.17470883, + 0.1736893, + 0.1726736, + 0.1716617, + 0.17065361, + 0.16964935, + 0.1686489, + 0.16765225, + 0.16665943, + 0.16567042, + 0.16468522, + 0.16370384, + 0.16272627, + 0.16175252, + 0.16078258, + 0.15981644, + 0.15885411, + 0.1578956, + 0.15694089, + 0.15599, + 0.15504292, + 0.15409963, + 0.15316014, + 0.15222447, + 0.15129258, + 0.1503645, + 0.14944021, + 0.14851972, + 0.14760303, + 0.14669013, + 0.14578101, + 0.14487568, + 0.14397413, + 0.14307636, + 0.14218238, + 0.14129217, + 0.14040573, + 0.13952307, + 0.13864417, + 0.13776903, + 0.13689767, + 0.13603005, + 0.13516618, + 0.13430607, + 0.13344972, + 0.1325971, + 0.13174823, + 0.1309031, + 0.13006169, + 0.12922402, + 0.12839006, + 0.12755983, + 0.12673332, + 0.12591052, + 0.12509143, + 0.12427604, + 0.12346435, + 0.12265636, + 0.121852055, + 0.12105144, + 0.1202545, + 0.11946124, + 0.11867165, + 0.11788572, + 0.11710346, + 0.11632485, + 0.115549885, + 0.11477857, + 0.11401089, + 0.11324684, + 0.11248643, + 0.11172963, + 0.11097645, + 0.110226884, + 0.10948092, + 0.10873855, + 0.10799977, + 0.107264586, + 0.106532976, + 0.105804935, + 0.10508047, + 0.10435956, + 0.1036422, + 0.10292839, + 0.10221813, + 0.1015114, + 0.10080819, + 0.100108504, + 0.09941233, + 0.098719664, + 0.0980305, + 0.09734483, + 0.09666264, + 0.09598393, + 0.095308684, + 0.09463691, + 0.093968585, + 0.09330372, + 0.092642285, + 0.09198428, + 0.09132971, + 0.09067855, + 0.090030804, + 0.089386456, + 0.088745505, + 0.088107936, + 0.08747375, + 0.08684293, + 0.08621547, + 0.085591376, + 0.084970616, + 0.08435319, + 0.0837391, + 0.08312833, + 0.08252087, + 0.08191671, + 0.08131585, + 0.08071827, + 0.080123976, + 0.07953294, + 0.078945175, + 0.078360654, + 0.077779375, + 0.07720133, + 0.07662651, + 0.07605491, + 0.07548651, + 0.07492131, + 0.0743593, + 0.07380046, + 0.073244795, + 0.07269229, + 0.07214294, + 0.07159673, + 0.07105365, + 0.070513695, + 0.06997685, + 0.069443114, + 0.06891247, + 0.06838491, + 0.067860425, + 0.06733901, + 0.066820644, + 0.06630533, + 0.06579305, + 0.0652838, + 0.06477757, + 0.06427433, + 0.0637741, + 0.063276865, + 0.06278259, + 0.062291294, + 0.061802953, + 0.06131756, + 0.0608351, + 0.060355574, + 0.05987896, + 0.059405252, + 0.058934443, + 0.05846652, + 0.058001474, + 0.057539295, + 0.05707997, + 0.056623492, 
+ 0.05616985, + 0.05571903, + 0.055271026, + 0.054825824, + 0.05438342, + 0.053943794, + 0.053506944, + 0.05307286, + 0.052641522, + 0.052212927, + 0.051787063, + 0.051363923, + 0.05094349, + 0.050525755, + 0.05011071, + 0.04969834, + 0.049288645, + 0.0488816, + 0.048477206, + 0.048075445, + 0.04767631, + 0.047279786, + 0.04688587, + 0.046494544, + 0.046105802, + 0.04571963, + 0.04533602, + 0.04495496, + 0.04457644, + 0.044200446, + 0.04382697, + 0.043456003, + 0.043087535, + 0.042721547, + 0.042358037, + 0.04199699, + 0.041638397, + 0.041282244, + 0.040928524, + 0.040577225, + 0.040228333, + 0.039881844, + 0.039537743, + 0.039196018, + 0.038856663, + 0.038519662, + 0.038185004, + 0.037852682, + 0.037522685, + 0.037195, + 0.036869615, + 0.036546525, + 0.036225714, + 0.03590717, + 0.035590887, + 0.035276853, + 0.034965057, + 0.034655485, + 0.03434813, + 0.03404298, + 0.033740025, + 0.033439253, + 0.033140652, + 0.032844216, + 0.03254993, + 0.032257784, + 0.03196777, + 0.031679876, + 0.031394087, + 0.031110398, + 0.030828796, + 0.030549273, + 0.030271813, + 0.02999641, + 0.029723052, + 0.029451728, + 0.029182427, + 0.02891514, + 0.028649855, + 0.028386563, + 0.028125253, + 0.02786591, + 0.027608532, + 0.027353102, + 0.027099613, + 0.026848052, + 0.026598409, + 0.026350675, + 0.02610484, + 0.02586089, + 0.02561882, + 0.025378617, + 0.025140269, + 0.024903767, + 0.0246691, + 0.02443626, + 0.024205236, + 0.023976017, + 0.023748592, + 0.023522953, + 0.023299087, + 0.023076987, + 0.022856642, + 0.02263804, + 0.022421172, + 0.022206029, + 0.0219926, + 0.021780876, + 0.021570845, + 0.021362498, + 0.021155827, + 0.020950818, + 0.020747466, + 0.020545758, + 0.020345684, + 0.020147236, + 0.019950403, + 0.019755175, + 0.019561544, + 0.019369498, + 0.019179028, + 0.018990126, + 0.01880278, + 0.018616982, + 0.018432721, + 0.01824999, + 0.018068777, + 0.017889075, + 0.017710872, + 0.01753416, + 0.017358929, + 0.017185168, + 0.017012872, + 0.016842028, + 0.016672628, + 0.016504662, + 0.016338123, + 0.016173, + 0.016009282, + 0.015846964, + 0.015686033, + 0.015526483, + 0.015368304, + 0.015211486, + 0.0150560215, + 0.014901901, + 0.014749114, + 0.014597654, + 0.014447511, + 0.0142986765, + 0.014151142, + 0.014004898, + 0.013859936, + 0.013716248, + 0.0135738235, + 0.013432656, + 0.013292736, + 0.013154055, + 0.013016605, + 0.012880377, + 0.012745362, + 0.012611552, + 0.012478939, + 0.012347515, + 0.01221727, + 0.012088198, + 0.0119602885, + 0.0118335355, + 0.011707929, + 0.011583461, + 0.011460125, + 0.011337912, + 0.011216813, + 0.011096821, + 0.010977928, + 0.0108601255, + 0.010743406, + 0.010627762, + 0.0105131855, + 0.010399668, + 0.010287202, + 0.01017578, + 0.010065395, + 0.009956039, + 0.009847702, + 0.009740381, + 0.0096340645, + 0.009528747, + 0.009424419, + 0.009321076, + 0.009218709, + 0.00911731, + 0.009016872, + 0.008917389, + 0.008818853, + 0.008721256, + 0.008624591, + 0.008528852, + 0.00843403, + 0.00834012, + 0.008247114, + 0.008155004, + 0.008063785, + 0.007973449, + 0.007883989, + 0.007795398, + 0.0077076694, + 0.0076207966, + 0.0075347726, + 0.007449591, + 0.0073652444, + 0.007281727, + 0.0071990318, + 0.007117152, + 0.0070360815, + 0.0069558136, + 0.0068763415, + 0.006797659, + 0.00671976, + 0.0066426382, + 0.0065662866, + 0.006490699, + 0.0064158696, + 0.006341792, + 0.00626846, + 0.0061958674, + 0.0061240084, + 0.0060528764, + 0.0059824656, + 0.0059127696, + 0.0058437833, + 0.0057755, + 0.0057079145, + 0.00564102, + 0.0055748112, + 0.0055092825, + 0.005444428, + 0.005380241, + 
0.0053167176, + 0.005253851, + 0.005191636, + 0.005130066, + 0.0050691366, + 0.0050088423, + 0.0049491767, + 0.004890135, + 0.0048317118, + 0.004773902, + 0.004716699, + 0.0046600983, +] diff --git a/keras_hub/src/models/control_net/controlNetDiffusionModels.py b/keras_hub/src/models/control_net/controlNetDiffusionModels.py new file mode 100644 index 0000000000..860c862e19 --- /dev/null +++ b/keras_hub/src/models/control_net/controlNetDiffusionModels.py @@ -0,0 +1,683 @@ +""" +Copyright 2022 The KerasCV Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +ControlNet Version + +by AJ Young + +""" + +import keras + +""" +Models +""" + + +class DiffusionModel(keras.Model): + def __init__( + self, img_height, img_width, max_text_length, name="LockedDiffusionModel" + ): + context = keras.layers.Input((max_text_length, 768), name="Context_Input") + t_embed_input = keras.layers.Input((320,), name="TimeStepEmbed_Input") + latent = keras.layers.Input( + (img_height // 8, img_width // 8, 4), name="LatentImage_Input" + ) + + ### ControlNet Input ### + + controlNet1 = keras.layers.Input(shape=(1,), name="ControlNet_Input001") + controlNet2 = keras.layers.Input(shape=(1,), name="ControlNet_Input002") + controlNet3 = keras.layers.Input(shape=(1,), name="ControlNet_Input003") + controlNet4 = keras.layers.Input(shape=(1,), name="ControlNet_Input004") + controlNet5 = keras.layers.Input(shape=(1,), name="ControlNet_Input005") + controlNet6 = keras.layers.Input(shape=(1,), name="ControlNet_Input006") + controlNet7 = keras.layers.Input(shape=(1,), name="ControlNet_Input007") + controlNet8 = keras.layers.Input(shape=(1,), name="ControlNet_Input008") + controlNet9 = keras.layers.Input(shape=(1,), name="ControlNet_Input009") + controlNet10 = keras.layers.Input(shape=(1,), name="ControlNet_Input010") + controlNet11 = keras.layers.Input(shape=(1,), name="ControlNet_Input011") + controlNet12 = keras.layers.Input(shape=(1,), name="ControlNet_Input012") + controlNet13 = keras.layers.Input(shape=(1,), name="ControlNet_Input013") + + controlNetResults = [ + controlNet1, + controlNet2, + controlNet3, + controlNet4, + controlNet5, + controlNet6, + controlNet7, + controlNet8, + controlNet9, + controlNet10, + controlNet11, + controlNet12, + controlNet13, + ] + + t_emb = keras.layers.Dense(1280, name="TimeEmbed1")(t_embed_input) + t_emb = keras.layers.Activation("swish", name="swishActivation")(t_emb) + t_emb = keras.layers.Dense(1280, name="TimeEmbed2")(t_emb) + + # Downsampling flow, aka input_blocks + + outputs = [] + x = PaddedConv2D(320, kernel_size=3, padding=1, name="inputBlocks")(latent) + outputs.append(x) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected=False)([x, context]) + outputs.append(x) + x = PaddedConv2D(320, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected=False)([x, context]) + outputs.append(x) + x = PaddedConv2D(640, 3, strides=2, padding=1)(x) # Downsample 2x +
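+ # Each tensor collected in `outputs` below is a skip connection; during the + # upsampling flow it is summed with the matching ControlNet residual input + # before being concatenated back into the decoder path.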
outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected=False)([x, context]) + outputs.append(x) + x = PaddedConv2D(1280, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(x) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected=False)([x, context]) + x = ResBlock(1280)([x, t_emb]) + # controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = x + controlNetResults.pop() + + # Upsampling flow + + for _ in range(3): + # controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = keras.layers.Concatenate()([x, outputs.pop() + controlNetResults.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = Upsample(1280)(x) + + for _ in range(3): + # controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = keras.layers.Concatenate()([x, outputs.pop() + controlNetResults.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected=False)([x, context]) + x = Upsample(1280)(x) + + for _ in range(3): + # controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = keras.layers.Concatenate()([x, outputs.pop() + controlNetResults.pop()]) + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected=False)([x, context]) + x = Upsample(640)(x) + + for _ in range(3): + # controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = keras.layers.Concatenate()([x, outputs.pop() + controlNetResults.pop()]) + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected=False)([x, context]) + + # Exit flow + + x = GroupNormalization(epsilon=1e-5)(x) + x = keras.layers.Activation("swish")(x) + output = PaddedConv2D(4, kernel_size=3, padding=1)(x) + + super().__init__( + inputs=[ + latent, + t_embed_input, + context, + controlNet1, + controlNet2, + controlNet3, + controlNet4, + controlNet5, + controlNet6, + controlNet7, + controlNet8, + controlNet9, + controlNet10, + controlNet11, + controlNet12, + controlNet13, + ], + outputs=output, + name=name, + ) + + +class ControlNetDiffusionModel(keras.Model): + def __init__( + self, + img_height, + img_width, + max_text_length, + name="ControlNetModel", + ): + context = keras.layers.Input((max_text_length, 768), name="Context_Input") + inputHint = keras.layers.Input((img_height, img_width, 3), name="Hint_Input") + t_embed_input = keras.layers.Input((320,), name="TimeStepEmbed_Input") + latent = keras.layers.Input( + (img_height // 8, img_width // 8, 4), name="LatentImage_Input" + ) + + t_emb = keras.layers.Dense(1280, name="ControlTimeEmbed1")(t_embed_input) + t_emb = keras.layers.Activation("swish", name="swishActivation")(t_emb) + t_emb = keras.layers.Dense(1280, name="ControlTimeEmbed2")(t_emb) + + # Input Hint Blocks + + guidedHint = HintBlocks()(inputHint) + + # Downsampling flow, aka input_blocks + + outputs = [] + x = PaddedConv2D(320, kernel_size=3, padding=1, name="inputBlocks")(latent) + x = x + guidedHint + outputs.append(zeroConv(x, 320, "zeroConv1")) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected=False)([x, context]) + outputs.append(zeroConv(x, 320)) + x = PaddedConv2D(320, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(zeroConv(x, 320, "zeroConv4")) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected=False)([x, context]) + 
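+            # Every stored skip passes through a 1x1 "zero convolution"
+            # (see zeroConv below), whose weights start at zero, so the
+            # ControlNet initially adds nothing to the locked diffusion model.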
outputs.append(zeroConv(x, 640)) + x = PaddedConv2D(640, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(zeroConv(x, 640, "zeroConv7")) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected=False)([x, context]) + outputs.append(zeroConv(x, 1280)) + x = PaddedConv2D(1280, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(zeroConv(x, 1280, "zeroConv10")) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(zeroConv(x, 1280)) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected=False)([x, context]) + x = ResBlock(1280)([x, t_emb]) + outputs.append(zeroConv(x, 1280, "zeroConv13")) + + super().__init__( + [latent, t_embed_input, context, inputHint], outputs, name=name + ) + # Input: Latent, TimestepEmbed, Context, Input Hint + # Output: Python List of each zeroConv (Zero Convolution Layer) + + +class DiffusionModelV2(keras.Model): + def __init__(self, img_height, img_width, max_text_length, name=None): + context = keras.layers.Input((max_text_length, 1024)) + t_embed_input = keras.layers.Input((320,)) + latent = keras.layers.Input((img_height // 8, img_width // 8, 4)) + + t_emb = keras.layers.Dense(1280)(t_embed_input) + t_emb = keras.layers.Activation("swish")(t_emb) + t_emb = keras.layers.Dense(1280)(t_emb) + + # Downsampling flow + + outputs = [] + x = PaddedConv2D(320, kernel_size=3, padding=1)(latent) + outputs.append(x) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(5, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(320, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(10, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(640, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(1280, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(x) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + x = ResBlock(1280)([x, t_emb]) + + # Upsampling flow + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(10, 64, fully_connected=True)([x, context]) + x = Upsample(640)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(5, 64, fully_connected=True)([x, context]) + + # Exit flow + + x = GroupNormalization(epsilon=1e-5)(x) + x = keras.layers.Activation("swish")(x) + output = PaddedConv2D(4, kernel_size=3, padding=1)(x) + + super().__init__([latent, t_embed_input, context], output, name=name) + + +""" +Blocks +""" + + +class GroupNormalization(keras.layers.Layer): + """ + GroupNormalization layer. 
+ + This layer is only here temporarily and will be removed + as we introduce GroupNormalization in core Keras. + """ + + def __init__( + self, + groups=32, + axis=-1, + epsilon=1e-5, + name="GroupNormalization", + **kwargs, + ): + super().__init__(**kwargs) + self.groups = groups + self.axis = axis + self.epsilon = epsilon + + def get_config(self): + config = super().get_config() + config.update( + {"groups": self.groups, "axis": self.axis, "epsilon": self.epsilon} + ) + return config + + def build(self, input_shape): + dim = input_shape[self.axis] + self.gamma = self.add_weight( + shape=(dim,), + name="gamma", + initializer="ones", + ) + self.beta = self.add_weight( + shape=(dim,), + name="beta", + initializer="zeros", + ) + + ## @tf.function + def call(self, inputs): + input_shape = keras.ops.shape(inputs) + reshaped_inputs = self._reshape_into_groups(inputs, input_shape) + normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape) + return keras.ops.reshape(normalized_inputs, input_shape) + + def _reshape_into_groups(self, inputs, input_shape): + group_shape = [input_shape[i] for i in range(inputs.shape.rank)] + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(self.axis, self.groups) + group_shape = keras.ops.stack(group_shape) + return keras.ops.reshape(inputs, group_shape) + + def _apply_normalization(self, reshaped_inputs, input_shape): + group_reduction_axes = list(range(1, reshaped_inputs.shape.rank)) + axis = -2 if self.axis == -1 else self.axis - 1 + group_reduction_axes.pop(axis) + mean, variance = keras.ops.moments( + reshaped_inputs, group_reduction_axes, keepdims=True + ) + gamma, beta = self._get_reshaped_weights(input_shape) + return keras.ops.batch_normalization( + reshaped_inputs, + mean=mean, + variance=variance, + scale=gamma, + offset=beta, + variance_epsilon=self.epsilon, + ) + + def _get_reshaped_weights(self, input_shape): + broadcast_shape = self._create_broadcast_shape(input_shape) + gamma = keras.ops.reshape(self.gamma, broadcast_shape) + beta = keras.ops.reshape(self.beta, broadcast_shape) + return gamma, beta + + def _create_broadcast_shape(self, input_shape): + broadcast_shape = [1] * input_shape.shape.rank + broadcast_shape[self.axis] = input_shape[self.axis] // self.groups + broadcast_shape.insert(self.axis, self.groups) + return broadcast_shape + + +class HintBlocks(keras.layers.Layer): + def __init__(self, hint_channels=16, model_channels=320, **kwargs): + super().__init__(**kwargs) + self.layers = [ + PaddedConv2D(filters=16, kernel_size=3, padding=1), + keras.layers.Activation("swish"), + PaddedConv2D(filters=16, kernel_size=3, padding=1), + keras.layers.Activation("swish"), + PaddedConv2D(filters=32, kernel_size=3, padding=1, strides=2), + keras.layers.Activation("swish"), + PaddedConv2D(filters=32, kernel_size=3, padding=1), + keras.layers.Activation("swish"), + PaddedConv2D(filters=96, kernel_size=3, padding=1, strides=2), + keras.layers.Activation("swish"), + PaddedConv2D(filters=96, kernel_size=3, padding=1), + keras.layers.Activation("swish"), + PaddedConv2D(filters=256, kernel_size=3, padding=1, strides=2), + keras.layers.Activation("swish"), + PaddedConv2D(filters=model_channels, kernel_size=3, padding=1), + ] + + ## @tf.function + def call(self, inputs): + x = inputs + layerNumber = 0 + layerLength = len(self.layers) + for layer in self.layers: + if layerNumber == layerLength: + for weight in layer.weights: + weight.assign(keras.ops.zeros_like(weight)) + x = layer(x) + layerNumber += 1 + 
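+        # Three stride-2 convolutions bring the hint image down to the latent
+        # resolution (height/8, width/8) with `model_channels` filters so it
+        # can be added to the first diffusion feature map; zeroing the last
+        # layer's weights above is ControlNet's "zero module" trick, meant to
+        # make the hint branch contribute nothing at the start of training.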
return x + + +class PaddedConv2D(keras.layers.Layer): + def __init__(self, filters, kernel_size, padding=0, strides=1, name=None, **kwargs): + super().__init__(**kwargs) + self.padding2d = keras.layers.ZeroPadding2D(padding, name=name) + self.conv2d = keras.layers.Conv2D( + filters, kernel_size, strides=strides, name=name + ) + self.filters = filters + self.kernel_size = kernel_size + self.padding = padding + self.strides = strides + + ## @tf.function + def call(self, inputs): + x = self.padding2d(inputs) + return self.conv2d(x) + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "padding": self.padding, + "strides": self.strides, + } + ) + return config + + +class ResBlock(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.entry_flow = [ + GroupNormalization(epsilon=1e-5), + keras.layers.Activation("swish", name="ResBlock_swish1"), + PaddedConv2D(output_dim, 3, padding=1, name="inLayers2"), + ] + self.embedding_flow = [ + keras.layers.Activation("swish"), + keras.layers.Dense(output_dim, name="embeddingLayer"), + ] + self.exit_flow = [ + GroupNormalization(epsilon=1e-5), + keras.layers.Activation("swish", name="ResBlock_swish2"), + PaddedConv2D(output_dim, 3, padding=1, name="outLayers3"), + ] + + def build(self, input_shape): + if input_shape[0][-1] != self.output_dim: + self.residual_projection = PaddedConv2D(self.output_dim, 1) + else: + self.residual_projection = lambda x: x + + ## @tf.function + def call(self, inputs): + inputs, embeddings = inputs + x = inputs + for layer in self.entry_flow: + x = layer(x) + for layer in self.embedding_flow: + embeddings = layer(embeddings) + x = x + embeddings[:, None, None] + for layer in self.exit_flow: + x = layer(x) + return x + self.residual_projection(inputs) + + def get_config(self): + config = super().get_config() + config.update( + { + "output_dim": self.output_dim, + } + ) + return config + + +class SpatialTransformer(keras.layers.Layer): + def __init__(self, num_heads, head_size, fully_connected=False, **kwargs): + super().__init__(**kwargs) + self.norm = GroupNormalization(epsilon=1e-5) + self.num_heads = num_heads + self.head_size = head_size + self.fully_connected = fully_connected + channels = num_heads * head_size + if fully_connected: + self.proj1 = keras.layers.Dense( + num_heads * head_size, name="proj_in1_fullyConnected" + ) + else: + self.proj1 = PaddedConv2D(num_heads * head_size, 1, name="proj_in") + self.transformer_block = BasicTransformerBlock(channels, num_heads, head_size) + if fully_connected: + self.proj2 = keras.layers.Dense(channels) + else: + self.proj2 = PaddedConv2D(channels, 1, name="proj_in2_fullyConnected") + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "head_size": self.head_size, + "fully_connected": self.fully_connected, + } + ) + return config + + ## @tf.function + def call(self, inputs): + inputs, context = inputs + _, h, w, c = inputs.shape + x = self.norm(inputs) + x = self.proj1(x) + x = keras.ops.reshape(x, (-1, h * w, c)) + x = self.transformer_block([x, context]) + x = keras.ops.reshape(x, (-1, h, w, c)) + return self.proj2(x) + inputs + + +class BasicTransformerBlock(keras.layers.Layer): + def __init__(self, dim, num_heads, head_size, **kwargs): + super().__init__(**kwargs) + self.norm1 = keras.layers.LayerNormalization(epsilon=1e-5, name="norm1") + self.attn1 = 
CrossAttention(num_heads, head_size) + + self.norm2 = keras.layers.LayerNormalization(epsilon=1e-5, name="norm2") + self.attn2 = CrossAttention(num_heads, head_size) + + self.norm3 = keras.layers.LayerNormalization(epsilon=1e-5, name="norm3") + self.geglu = GEGLU(dim * 4) + self.dense = keras.layers.Dense(dim) + + ## @tf.function + def call(self, inputs): + inputs, context = inputs + x = self.attn1([self.norm1(inputs), None]) + inputs + x = self.attn2([self.norm2(x), context]) + x + return self.dense(self.geglu(self.norm3(x))) + x + + +class CrossAttention(keras.layers.Layer): + def __init__(self, num_heads, head_size, **kwargs): + super().__init__(**kwargs) + self.to_q = keras.layers.Dense( + num_heads * head_size, use_bias=False, name="to_q" + ) + self.to_k = keras.layers.Dense( + num_heads * head_size, use_bias=False, name="to_k" + ) + self.to_v = keras.layers.Dense( + num_heads * head_size, use_bias=False, name="to_v" + ) + self.scale = head_size**-0.5 + self.num_heads = num_heads + self.head_size = head_size + self.out_proj = keras.layers.Dense(num_heads * head_size, name="out_projection") + + ## @tf.function + def call(self, inputs): + inputs, context = inputs + context = inputs if context is None else context + q, k, v = self.to_q(inputs), self.to_k(context), self.to_v(context) + q = keras.ops.reshape(q, (-1, inputs.shape[1], self.num_heads, self.head_size)) + k = keras.ops.reshape(k, (-1, context.shape[1], self.num_heads, self.head_size)) + v = keras.ops.reshape(v, (-1, context.shape[1], self.num_heads, self.head_size)) + + q = keras.ops.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) + k = keras.ops.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time) + v = keras.ops.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) + + score = td_dot(q, k) * self.scale + weights = keras.activations.softmax(score) # (bs, num_heads, time, time) + attn = td_dot(weights, v) + attn = keras.ops.transpose( + attn, (0, 2, 1, 3) + ) # (bs, time, num_heads, head_size) + out = keras.ops.reshape( + attn, (-1, inputs.shape[1], self.num_heads * self.head_size) + ) + return self.out_proj(out) + + +class Upsample(keras.layers.Layer): + def __init__(self, channels, **kwargs): + super().__init__(**kwargs) + self.channels = channels + self.ups = keras.layers.UpSampling2D(2) + self.conv = PaddedConv2D(channels, 3, padding=1, name="Upsample") + + ## @tf.function + def call(self, inputs): + return self.conv(self.ups(inputs)) + + def get_config(self): + config = super().get_config() + config.update( + { + "channels": self.channels, + } + ) + return config + + +class GEGLU(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.dense = keras.layers.Dense(output_dim * 2) + + ## @tf.function + def call(self, inputs): + x = self.dense(inputs) + x, gate = x[..., : self.output_dim], x[..., self.output_dim :] + tanh_res = keras.activations.tanh( + gate * 0.7978845608 * (1 + 0.044715 * (gate**2)) + ) + return x * 0.5 * gate * (1 + tanh_res) + + +def td_dot(a, b): + aa = keras.ops.reshape(a, (-1, a.shape[2], a.shape[3])) + bb = keras.ops.reshape(b, (-1, b.shape[2], b.shape[3])) + cc = keras.backend.batch_dot(aa, bb) + return keras.ops.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2])) + + +def zeroConv(tensor, channels, name=None): + layer = keras.layers.Conv2D( + filters=channels, kernel_size=1, padding="same", name=name + ) + for weight in layer.weights: + weight.assign(keras.ops.zeros_like(weight)) + + tensor = 
layer(tensor) + + return tensor diff --git a/keras_hub/src/models/control_net/kerasCVDiffusionModels.py b/keras_hub/src/models/control_net/kerasCVDiffusionModels.py new file mode 100644 index 0000000000..23efbe20a4 --- /dev/null +++ b/keras_hub/src/models/control_net/kerasCVDiffusionModels.py @@ -0,0 +1,521 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import keras + +""" +Models +""" + + +class DiffusionModel(keras.Model): + def __init__( + self, + img_height, + img_width, + max_text_length, + name="DiffusionModel", + download_weights=False, + ): + context = keras.layers.Input((max_text_length, 768), name="Context_Input") + t_embed_input = keras.layers.Input((320,), name="TimeStepEmbed_Input") + latent = keras.layers.Input( + (img_height // 8, img_width // 8, 4), name="LatentImage_Input" + ) + + t_emb = keras.layers.Dense(1280, name="TimeEmbed1")(t_embed_input) + t_emb = keras.layers.Activation("swish", name="swishActivation")(t_emb) + t_emb = keras.layers.Dense(1280, name="TimeEmbed2")(t_emb) + + # Downsampling flow, aka input_blocks + + outputs = [] + x = PaddedConv2D(320, kernel_size=3, padding=1, name="inputBlocks")(latent) + outputs.append(x) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected=False)([x, context]) + outputs.append(x) + x = PaddedConv2D(320, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected=False)([x, context]) + outputs.append(x) + x = PaddedConv2D(640, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected=False)([x, context]) + outputs.append(x) + x = PaddedConv2D(1280, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(x) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected=False)([x, context]) + x = ResBlock(1280)([x, t_emb]) + + # Upsampling flow + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected=False)([x, context]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected=False)([x, context]) + x = Upsample(640)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected=False)([x, context]) + + # Exit flow + + x = GroupNormalization(epsilon=1e-5)(x) + x = keras.layers.Activation("swish")(x) + output = PaddedConv2D(4, kernel_size=3, padding=1)(x) + + 
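+        # The 4-channel output matches the latent's shape; it is the noise
+        # prediction (epsilon) that the samplers subtract back out.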
super().__init__([latent, t_embed_input, context], output, name=name) + + if download_weights: + diffusion_model_weights_fpath = keras.utils.get_file( + origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5", + file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe", + ) + self.load_weights(diffusion_model_weights_fpath) + + +class DiffusionModelV2(keras.Model): + def __init__( + self, img_height, img_width, max_text_length, name=None, download_weights=False + ): + context = keras.layers.Input((max_text_length, 1024)) + t_embed_input = keras.layers.Input((320,)) + latent = keras.layers.Input((img_height // 8, img_width // 8, 4)) + + t_emb = keras.layers.Dense(1280)(t_embed_input) + t_emb = keras.layers.Activation("swish")(t_emb) + t_emb = keras.layers.Dense(1280)(t_emb) + + # Downsampling flow + + outputs = [] + x = PaddedConv2D(320, kernel_size=3, padding=1)(latent) + outputs.append(x) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(5, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(320, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(10, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(640, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(1280, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(x) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + x = ResBlock(1280)([x, t_emb]) + + # Upsampling flow + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(10, 64, fully_connected=True)([x, context]) + x = Upsample(640)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(5, 64, fully_connected=True)([x, context]) + + # Exit flow + + x = GroupNormalization(epsilon=1e-5)(x) + x = keras.layers.Activation("swish")(x) + output = PaddedConv2D(4, kernel_size=3, padding=1)(x) + + super().__init__([latent, t_embed_input, context], output, name=name) + + if download_weights: + diffusion_model_weights_fpath = keras.utils.get_file( + origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/diffusion_model_v2_1.h5", + file_hash="c31730e91111f98fe0e2dbde4475d381b5287ebb9672b1821796146a25c5132d", + ) + self.load_weights(diffusion_model_weights_fpath) + + +""" +Blocks +""" + + +class GroupNormalization(keras.layers.Layer): + """GroupNormalization layer. + This layer is only here temporarily and will be removed + as we introduce GroupNormalization in core Keras. 
+ """ + + def __init__( + self, + groups=32, + axis=-1, + epsilon=1e-5, + name="GroupNormalization", + **kwargs, + ): + super().__init__(**kwargs) + self.groups = groups + self.axis = axis + self.epsilon = epsilon + + def get_config(self): + config = super().get_config() + config.update( + {"groups": self.groups, "axis": self.axis, "epsilon": self.epsilon} + ) + return config + + def build(self, input_shape): + dim = input_shape[self.axis] + self.gamma = self.add_weight( + shape=(dim,), + name="gamma", + initializer="ones", + ) + self.beta = self.add_weight( + shape=(dim,), + name="beta", + initializer="zeros", + ) + + # @tf.function + def call(self, inputs): + input_shape = keras.ops.shape(inputs) + reshaped_inputs = self._reshape_into_groups(inputs, input_shape) + normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape) + return keras.ops.reshape(normalized_inputs, input_shape) + + def _reshape_into_groups(self, inputs, input_shape): + group_shape = [input_shape[i] for i in range(inputs.shape.rank)] + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(self.axis, self.groups) + group_shape = keras.ops.stack(group_shape) + return keras.ops.reshape(inputs, group_shape) + + def _apply_normalization(self, reshaped_inputs, input_shape): + group_reduction_axes = list(range(1, reshaped_inputs.shape.rank)) + axis = -2 if self.axis == -1 else self.axis - 1 + group_reduction_axes.pop(axis) + mean, variance = keras.ops.moments( + reshaped_inputs, group_reduction_axes, keepdims=True + ) + gamma, beta = self._get_reshaped_weights(input_shape) + return keras.ops.batch_normalization( + reshaped_inputs, + mean=mean, + variance=variance, + scale=gamma, + offset=beta, + variance_epsilon=self.epsilon, + ) + + def _get_reshaped_weights(self, input_shape): + broadcast_shape = self._create_broadcast_shape(input_shape) + gamma = keras.ops.reshape(self.gamma, broadcast_shape) + beta = keras.ops.reshape(self.beta, broadcast_shape) + return gamma, beta + + def _create_broadcast_shape(self, input_shape): + broadcast_shape = [1] * input_shape.shape.rank + broadcast_shape[self.axis] = input_shape[self.axis] // self.groups + broadcast_shape.insert(self.axis, self.groups) + return broadcast_shape + + +class PaddedConv2D(keras.layers.Layer): + def __init__(self, filters, kernel_size, padding=0, strides=1, name=None, **kwargs): + super().__init__(**kwargs) + self.padding2d = keras.layers.ZeroPadding2D(padding, name=name) + self.conv2d = keras.layers.Conv2D( + filters, kernel_size, strides=strides, name=name + ) + self.filters = filters + self.kernel_size = kernel_size + self.padding = padding + self.strides = strides + + # @tf.function + def call(self, inputs): + x = self.padding2d(inputs) + return self.conv2d(x) + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "padding": self.padding, + "strides": self.strides, + } + ) + return config + + +class ResBlock(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.entry_flow = [ + GroupNormalization(epsilon=1e-5), + keras.layers.Activation("swish", name="ResBlock_swish1"), + PaddedConv2D(output_dim, 3, padding=1, name="inLayers2"), + ] + self.embedding_flow = [ + keras.layers.Activation("swish"), + keras.layers.Dense(output_dim, name="embeddingLayer"), + ] + self.exit_flow = [ + GroupNormalization(epsilon=1e-5), + keras.layers.Activation("swish", 
name="ResBlock_swish2"), + PaddedConv2D(output_dim, 3, padding=1, name="outLayers3"), + ] + + def build(self, input_shape): + if input_shape[0][-1] != self.output_dim: + self.residual_projection = PaddedConv2D(self.output_dim, 1) + else: + self.residual_projection = lambda x: x + + # @tf.function + def call(self, inputs): + inputs, embeddings = inputs + x = inputs + for layer in self.entry_flow: + x = layer(x) + for layer in self.embedding_flow: + embeddings = layer(embeddings) + x = x + embeddings[:, None, None] + for layer in self.exit_flow: + x = layer(x) + return x + self.residual_projection(inputs) + + def get_config(self): + config = super().get_config() + config.update( + { + "output_dim": self.output_dim, + } + ) + return config + + +class SpatialTransformer(keras.layers.Layer): + def __init__(self, num_heads, head_size, fully_connected=False, **kwargs): + super().__init__(**kwargs) + self.norm = GroupNormalization(epsilon=1e-5) + self.num_heads = num_heads + self.head_size = head_size + self.fully_connected = fully_connected + channels = num_heads * head_size + if fully_connected: + self.proj1 = keras.layers.Dense( + num_heads * head_size, name="proj_in1_fullyConnected" + ) + else: + self.proj1 = PaddedConv2D(num_heads * head_size, 1, name="proj_in") + self.transformer_block = BasicTransformerBlock(channels, num_heads, head_size) + if fully_connected: + self.proj2 = keras.layers.Dense(channels) + else: + self.proj2 = PaddedConv2D(channels, 1, name="proj_in2_fullyConnected") + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "head_size": self.head_size, + "fully_connected": self.fully_connected, + } + ) + return config + + # @tf.function + def call(self, inputs): + inputs, context = inputs + _, h, w, c = inputs.shape + x = self.norm(inputs) + x = self.proj1(x) + x = keras.ops.reshape(x, (-1, h * w, c)) + x = self.transformer_block([x, context]) + x = keras.ops.reshape(x, (-1, h, w, c)) + return self.proj2(x) + inputs + + +class BasicTransformerBlock(keras.layers.Layer): + def __init__(self, dim, num_heads, head_size, **kwargs): + super().__init__(**kwargs) + self.norm1 = keras.layers.LayerNormalization(epsilon=1e-5, name="norm1") + self.attn1 = CrossAttention(num_heads, head_size) + + self.norm2 = keras.layers.LayerNormalization(epsilon=1e-5, name="norm2") + self.attn2 = CrossAttention(num_heads, head_size) + + self.norm3 = keras.layers.LayerNormalization(epsilon=1e-5, name="norm3") + self.geglu = GEGLU(dim * 4) + self.dense = keras.layers.Dense(dim) + + # @tf.function + def call(self, inputs): + inputs, context = inputs + x = self.attn1([self.norm1(inputs), None]) + inputs + x = self.attn2([self.norm2(x), context]) + x + return self.dense(self.geglu(self.norm3(x))) + x + + +class CrossAttention(keras.layers.Layer): + def __init__(self, num_heads, head_size, **kwargs): + super().__init__(**kwargs) + self.to_q = keras.layers.Dense( + num_heads * head_size, use_bias=False, name="to_q" + ) + self.to_k = keras.layers.Dense( + num_heads * head_size, use_bias=False, name="to_k" + ) + self.to_v = keras.layers.Dense( + num_heads * head_size, use_bias=False, name="to_v" + ) + self.scale = head_size**-0.5 + self.num_heads = num_heads + self.head_size = head_size + self.out_proj = keras.layers.Dense(num_heads * head_size, name="out_projection") + + # @tf.function + def call(self, inputs): + inputs, context = inputs + context = inputs if context is None else context + q, k, v = self.to_q(inputs), self.to_k(context), 
self.to_v(context) + q = keras.ops.reshape(q, (-1, inputs.shape[1], self.num_heads, self.head_size)) + k = keras.ops.reshape(k, (-1, context.shape[1], self.num_heads, self.head_size)) + v = keras.ops.reshape(v, (-1, context.shape[1], self.num_heads, self.head_size)) + + q = keras.ops.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) + k = keras.ops.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time) + v = keras.ops.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) + + score = td_dot(q, k) * self.scale + weights = keras.activations.softmax(score) # (bs, num_heads, time, time) + attn = td_dot(weights, v) + attn = keras.ops.transpose( + attn, (0, 2, 1, 3) + ) # (bs, time, num_heads, head_size) + out = keras.ops.reshape( + attn, (-1, inputs.shape[1], self.num_heads * self.head_size) + ) + return self.out_proj(out) + + +class Upsample(keras.layers.Layer): + def __init__(self, channels, **kwargs): + super().__init__(**kwargs) + self.channels = channels + self.ups = keras.layers.UpSampling2D(2) + self.conv = PaddedConv2D(channels, 3, padding=1, name="Upsample") + + # @tf.function + def call(self, inputs): + return self.conv(self.ups(inputs)) + + def get_config(self): + config = super().get_config() + config.update( + { + "channels": self.channels, + } + ) + return config + + +class GEGLU(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.dense = keras.layers.Dense(output_dim * 2) + + # @tf.function + def call(self, inputs): + x = self.dense(inputs) + x, gate = x[..., : self.output_dim], x[..., self.output_dim :] + tanh_res = keras.activations.tanh( + gate * 0.7978845608 * (1 + 0.044715 * (gate**2)) + ) + return x * 0.5 * gate * (1 + tanh_res) + + +def td_dot(a, b): + aa = keras.ops.reshape(a, (-1, a.shape[2], a.shape[3])) + bb = keras.ops.reshape(b, (-1, b.shape[2], b.shape[3])) + cc = keras.backend.batch_dot(aa, bb) + return keras.ops.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2])) diff --git a/keras_hub/src/models/control_net/layers.py b/keras_hub/src/models/control_net/layers.py new file mode 100644 index 0000000000..7595b4b867 --- /dev/null +++ b/keras_hub/src/models/control_net/layers.py @@ -0,0 +1,57 @@ +import keras + + +class PaddedConv2D(keras.layers.Layer): + def __init__(self, channels, kernel_size, padding=0, stride=1, name=None): + super().__init__() + self.padding2d = keras.layers.ZeroPadding2D((padding, padding), name=name) + self.conv2d = keras.layers.Conv2D( + channels, kernel_size, strides=(stride, stride), name=name + ) + + def call(self, x): + x = self.padding2d(x) + return self.conv2d(x) + + +class GEGLU(keras.layers.Layer): + def __init__(self, dim_out, name=None): + super().__init__() + self.proj = keras.layers.Dense(dim_out * 2, name=name) + self.dim_out = dim_out + + def call(self, x): + xp = self.proj(x) + x, gate = xp[..., : self.dim_out], xp[..., self.dim_out :] + return x * gelu(gate) + + +def gelu(x): + tanh_res = keras.activations.tanh(x * 0.7978845608 * (1 + 0.044715 * (x**2))) + return 0.5 * x * (1 + tanh_res) + + +def quick_gelu(x): + return x * keras.ops.sigmoid(x * 1.702) + + +"""def apply_seq(x, layers): + for l in layers: + x = l(x) + return x""" + + +def apply_seq(x, seq_layer): + if isinstance(seq_layer, keras.Sequential): + x = seq_layer(x) + else: + for l in seq_layer: + x = l(x) + return x + + +def td_dot(a, b): + aa = keras.ops.reshape(a, (-1, a.shape[2], a.shape[3])) + bb = keras.ops.reshape(b, (-1, b.shape[2], b.shape[3])) + cc = 
keras.backend.batch_dot(aa, bb) + return keras.ops.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2])) diff --git a/keras_hub/src/models/control_net/openClipEncoder.py b/keras_hub/src/models/control_net/openClipEncoder.py new file mode 100644 index 0000000000..c41fcfbca5 --- /dev/null +++ b/keras_hub/src/models/control_net/openClipEncoder.py @@ -0,0 +1,160 @@ +import keras +import numpy as np + +from .layers import gelu + + +# Step 1 +# Create and return the CLIP Embeddings +class OpenCLIPTextTransformer(keras.models.Model): + def __init__(self, maxLength=77, vocabularySize=49408): + super().__init__() + + # Create embeddings -> Step 2 + self.embeddings = OpenCLIPTextEmbeddings( + maxLength=maxLength, vocabularySize=vocabularySize + ) + + # Create encoder -> Step 3 + self.encoder = OpenCLIPEncoder() + + self.final_layer_norm = keras.layers.LayerNormalization( + epsilon=1e-5, name="FinalLayerNormalization" + ) + self.causal_attention_mask = keras.initializers.Constant( + np.triu(np.ones((1, 1, 77, 77), dtype="float32") * -np.inf, k=1) + ) + + def call(self, inputs): + input_ids, position_ids = inputs + x = self.embeddings([input_ids, position_ids]) + x = self.encoder([x, self.causal_attention_mask]) + return self.final_layer_norm(x) + + +# Step 2 +# Create and return word and position embeddings +class OpenCLIPTextEmbeddings(keras.layers.Layer): + def __init__(self, maxLength=77, vocabularySize=49408, embeddingSize=1024): + super().__init__() + # Token Embedding Layer - Representing a sequence of tokens (words) + self.token_embedding_layer = keras.layers.Embedding( + vocabularySize, embeddingSize, name="token_embedding" + ) + # Position Embedding layer - Where is the word in the sentence? What does it mean in the context of the sentence? + self.position_embedding_layer = keras.layers.Embedding( + maxLength, embeddingSize, name="position_embedding" + ) + + def call(self, inputs): + input_ids, position_ids = inputs + word_embeddings = self.token_embedding_layer(input_ids) + position_embeddings = self.position_embedding_layer(position_ids) + return word_embeddings + position_embeddings + + +# Step 3 +# Create and return the hidden states (aka hidden size) +class OpenCLIPEncoder(keras.layers.Layer): + def __init__(self): + super().__init__() + self.layers = [OpenCLIPEncoderLayer() for i in range(24)] + + def call(self, inputs): + [hidden_states, causal_attention_mask] = inputs + for l in self.layers: + hidden_states = l([hidden_states, causal_attention_mask]) + return hidden_states + + +# Step 4 (also creatd in step 3) +# Create the layers +class OpenCLIPEncoderLayer(keras.layers.Layer): + def __init__(self, intermediateSize=4096, embeddingSize=1024): + super().__init__() + self.layer_norm1 = keras.layers.LayerNormalization( + epsilon=1e-5, name="LayerNormalization01" + ) # Layer Normalization 1 + self.self_attn = OpenCLIPAttention() # Attention Layers + self.layer_norm2 = keras.layers.LayerNormalization( + epsilon=1e-5, name="LayerNormalization02" + ) # Layer Normalization 2 + self.fc1 = keras.layers.Dense(intermediateSize, name="FC1") # MLP layer? + self.fc2 = keras.layers.Dense(embeddingSize, name="FC2") # ??? 
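+        # fc1 and fc2 form the transformer feed-forward (MLP) block: hidden
+        # states are expanded to intermediateSize, passed through GELU, then
+        # projected back to embeddingSize (see call below).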
+ + def call(self, inputs): + hidden_states, causal_attention_mask = inputs + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn([hidden_states, causal_attention_mask]) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + + # MLP Steps + hidden_states = self.fc1(hidden_states) + hidden_states = gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + + return residual + hidden_states + + +class OpenCLIPAttention(keras.layers.Layer): + def __init__(self): + super().__init__() + self.embed_dim = 1024 + self.num_heads = 16 + self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim**-0.5 + self.q_proj = keras.layers.Dense( + self.embed_dim, name="QueryState" + ) # Query states, the given word + self.k_proj = keras.layers.Dense( + self.embed_dim, name="KeyState" + ) # Key states, all other words + self.v_proj = keras.layers.Dense( + self.embed_dim, name="ValueState" + ) # Value states, the sentence + self.out_proj = keras.layers.Dense( + self.embed_dim, name="OutProjection" + ) # Out Projection? + + def _shape(self, tensor, seq_len: int, bsz: int): + # Keys + a = keras.ops.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)) + return keras.layers.Permute((2, 1, 3))(a) # bs , n_head , seq_len , head_dim + + def call(self, inputs): + hidden_states, causal_attention_mask = inputs + bsz, tgt_len, embed_dim = hidden_states.shape + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), tgt_len, -1) + value_states = self._shape(self.v_proj(hidden_states), tgt_len, -1) + + proj_shape = (-1, tgt_len, self.head_dim) + query_states = self._shape(query_states, tgt_len, -1) + query_states = keras.ops.reshape(query_states, proj_shape) + key_states = keras.ops.reshape(key_states, proj_shape) + + src_len = tgt_len + value_states = keras.ops.reshape(value_states, proj_shape) + attn_weights = query_states @ keras.layers.Permute((2, 1))(key_states) + + attn_weights = keras.ops.reshape( + attn_weights, (-1, self.num_heads, tgt_len, src_len) + ) + attn_weights = attn_weights + causal_attention_mask + attn_weights = keras.ops.reshape(attn_weights, (-1, tgt_len, src_len)) + + attn_weights = keras.ops.softmax(attn_weights) + attn_output = attn_weights @ value_states + + attn_output = keras.ops.reshape( + attn_output, (-1, self.num_heads, tgt_len, self.head_dim) + ) + attn_output = keras.layers.Permute((2, 1, 3))(attn_output) + attn_output = keras.ops.reshape(attn_output, (-1, tgt_len, embed_dim)) + + return self.out_proj(attn_output) diff --git a/keras_hub/src/models/control_net/samplers/DPMSolverKerasCV.py b/keras_hub/src/models/control_net/samplers/DPMSolverKerasCV.py new file mode 100644 index 0000000000..da4181ab4f --- /dev/null +++ b/keras_hub/src/models/control_net/samplers/DPMSolverKerasCV.py @@ -0,0 +1,196 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""StableDiffusion Noise scheduler + +Adapted from https://github.com/huggingface/diffusers/blob/v0.3.0/src/diffusers/schedulers/scheduling_ddpm.py#L56 + +From https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion/noise_scheduler.py +""" +import keras + + +class NoiseScheduler: + """ + Args: + train_timesteps: number of diffusion steps used to train the model. + beta_start: the starting `beta` value of inference. + beta_end: the final `beta` value. + beta_schedule: + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `quadratic`. + betas: a complete set of betas, in lieu of using one of the existing schedules. + variance_type: + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample: + option to clip predicted sample between -1 and 1 for numerical stability. + """ + + def __init__( + self, + train_timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + betas=None, + variance_type="fixed_small", + clip_sample=True, + ): + self.train_timesteps = train_timesteps + + if beta_schedule == "linear": + self.betas = keras.ops.linspace(beta_start, beta_end, train_timesteps) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + keras.ops.linspace(beta_start**0.5, beta_end**0.5, train_timesteps) ** 2 + ) + else: + raise ValueError(f"Invalid beta schedule: {beta_schedule}.") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = keras.ops.cumprod(self.alphas) + + self.variance_type = variance_type + self.clip_sample = clip_sample + + def _get_variance(self, timestep, predicted_variance=None): + alpha_prod = self.alphas_cumprod[timestep] + alpha_prod_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else 1.0 + + variance = (1 - alpha_prod_prev) / (1 - alpha_prod) * self.betas[timestep] + + if self.variance_type == "fixed_small": + variance = keras.ops.clip(variance, clip_value_min=1e-20, clip_value_max=1) + elif self.variance_type == "fixed_small_log": + variance = keras.ops.log( + (keras.ops.clip(variance, clip_value_min=1e-20, clip_value_max=1)) + ) + elif self.variance_type == "fixed_large": + variance = self.betas[timestep] + elif self.variance_type == "fixed_large_log": + variance = keras.ops.log(self.betas[timestep]) + elif self.variance_type == "learned": + return predicted_variance + elif self.variance_type == "learned_range": + min_log = variance + max_log = self.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + else: + raise ValueError(f"Invalid variance type: {self.variance_type}") + + return variance + + def step( + self, + model_output, + timestep, + sample, + predict_epsilon=True, + ): + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (usually the predicted noise). + Args: + model_output: a Tensor containing direct output from learned diffusion model + timestep: current discrete timestep in the diffusion chain. + sample: a Tensor containing the current instance of sample being created by diffusion process. 
+ predict_epsilon: whether the model is predicting noise (epsilon) or samples + Returns: + The predicted sample at the previous timestep + """ + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: + model_output, predicted_variance = keras.ops.split( + model_output, sample.shape[1], axis=1 + ) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod = self.alphas_cumprod[timestep] + alpha_prod_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else 1.0 + beta_prod = 1 - alpha_prod + beta_prod_prev = 1 - alpha_prod_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if predict_epsilon: + pred_original_sample = ( + sample - beta_prod ** (0.5) * model_output + ) / alpha_prod ** (0.5) + else: + pred_original_sample = model_output + + # 3. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = keras.ops.clip(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = ( + alpha_prod_prev ** (0.5) * self.betas[timestep] + ) / beta_prod + current_sample_coeff = ( + self.alphas[timestep] ** (0.5) * beta_prod_prev / beta_prod + ) + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = ( + pred_original_sample_coeff * pred_original_sample + + current_sample_coeff * sample + ) + + # 6. Add noise + variance = 0 + if timestep > 0: + noise = keras.random.normal(model_output.shape) + variance = ( + self._get_variance(timestep, predicted_variance=predicted_variance) + ** 0.5 + ) * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample + + def add_noise( + self, + original_samples, + noise, + timesteps, + ): + sqrt_alpha_prod = keras.ops.take(self.alphas_cumprod, timesteps) ** 0.5 + sqrt_one_minus_alpha_prod = ( + 1 - keras.ops.take(self.alphas_cumprod, timesteps) + ) ** 0.5 + + for _ in range(3): + sqrt_alpha_prod = keras.ops.expand_dims(sqrt_alpha_prod, axis=-1) + sqrt_one_minus_alpha_prod = keras.ops.expand_dims( + sqrt_one_minus_alpha_prod, axis=-1 + ) + + noisy_samples = ( + sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + ) + return noisy_samples + + def __len__(self): + return self.train_timesteps diff --git a/keras_hub/src/models/control_net/samplers/ReadMe.md b/keras_hub/src/models/control_net/samplers/ReadMe.md new file mode 100644 index 0000000000..7cbaba5bcc --- /dev/null +++ b/keras_hub/src/models/control_net/samplers/ReadMe.md @@ -0,0 +1,5 @@ +### Samplers ### + +This folder contains the samplers used for calculating and generating the images. + +Because this is a TensorFlow implementation of Stable Diffusion, there are only a few options for sampling. 
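+Below is a minimal usage sketch (not part of the package itself) showing how the
+`NoiseScheduler` from `DPMSolverKerasCV.py` applies forward-diffusion noise; the
+latent and noise tensors here are placeholders, and the import path assumes the
+directory layout of this folder.
+
+```python
+import keras
+
+from keras_hub.src.models.control_net.samplers.DPMSolverKerasCV import (
+    NoiseScheduler,
+)
+
+scheduler = NoiseScheduler(train_timesteps=1000, beta_schedule="scaled_linear")
+
+# Placeholder latent (e.g. a 512x512 image encoded by the VAE to 64x64x4).
+latent = keras.random.normal((1, 64, 64, 4))
+noise = keras.random.normal((1, 64, 64, 4))
+
+# Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
+noisy_latent = scheduler.add_noise(latent, noise, timesteps=500)
+
+# During sampling, scheduler.step(model_output, t, sample) reverses one step.
+```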
diff --git a/keras_hub/src/models/control_net/samplers/__init__.py b/keras_hub/src/models/control_net/samplers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/control_net/samplers/basicSampler.py b/keras_hub/src/models/control_net/samplers/basicSampler.py new file mode 100644 index 0000000000..3917d55ed9 --- /dev/null +++ b/keras_hub/src/models/control_net/samplers/basicSampler.py @@ -0,0 +1,402 @@ +### Basic Modules +import math +import random + +import keras +from keras_hub.src.models.control_net.utils import keras_print +### Modules for image building +from PIL import Image + +### TensorFlow Modules + + +#import cv2 # OpenCV + + + +class BasicSampler: + def __init__( + self, + model=None, + timesteps=keras.ops.arange(1, 1000, 1000 // 50), + batchSize=1, + seed=1990, + inputImage=None, # Expecting a tensor + inputMask=None, # Expecting a tensor + inputImageStrength=0.5, + temperature=1, + AlphasCumprod=None, + controlNetInput=None, + ): + print("...starting Basic Sampler...") + self.model = model + self.timesteps = timesteps + self.batchSize = batchSize + self.seed = seed + self.inputImage = inputImage + self.inputMask = inputMask + self.inputImageStrength = inputImageStrength + self.inputImageNoise_T = self.timesteps[ + int(len(self.timesteps) * self.inputImageStrength) + ] + self.temperature = temperature + self.AlphasCumprod = AlphasCumprod # Length = 1000 + + self.latent, self.alphas, self.alphas_prev, self.controlNetInput = ( + self.getStartingParameters( + self.timesteps, + self.batchSize, + seed, + inputImage=self.inputImage, + inputImageNoise_T=self.inputImageNoise_T, + controlNetInput=controlNetInput, + ) + ) + + if self.inputImage is not None: + self.timesteps = self.timesteps[ + : int(len(self.timesteps) * self.inputImageStrength) + ] + + print("...sampler ready...") + + def addNoise(self, x, t, noise=None, DType=keras.config.floatx()): + batch_size, w, h = x.shape[0], x.shape[1], x.shape[2] + if noise is None: + # Post-Encode version: + noise = keras.random.normal((batch_size, w, h, 4), dtype=DType) + # Pre-Encode version: + # noise = keras.random.normal((batch_size,w,h,3), dtype = DType) + sqrt_alpha_prod = self.AlphasCumprod[t] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.AlphasCumprod[t]) ** 0.5 + + return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise + + def getStartingParameters( + self, + timesteps, + batchSize, + seed, + inputImage=None, + inputImageNoise_T=None, + controlNetInput=None, + ): + # Use floor division to get minimum height/width of image size + # for the Diffusion and Decoder models + floorDividedImageHeight = self.model.imageHeight // 8 + floorDividedImageWidth = self.model.imageWidth // 8 + + alphas = [self.AlphasCumprod[t] for t in timesteps] # sample steps length + alphas_prev = [1.0] + alphas[:-1] + + if inputImage is None: + # Create a random input image from noise + latent = keras.random.stateless_normal( + (batchSize, floorDividedImageHeight, floorDividedImageWidth, 4), + seed=[seed, seed], + ) + else: + ## Debug Variables + randomNumber = str(random.randint(0, 2**31)) + + # Noise the input image before encoding + # latent = self.addNoise(inputImage, inputImageNoise_T) + + # Encode the given image + print(inputImage.shape) + latent = self.model.encoder(inputImage, training=False) + print(latent.shape) + # self.displayImage(latent,("encoded" + randomNumber)) + # Repeat it within the tensor for the given batch size + latent = keras.ops.repeat(latent, batchSize, axis=0) + # Noise the image 
after encode + latent = self.addNoise(latent, inputImageNoise_T) + + if controlNetInput is None: + # Create a random input image from noise + controlNetLatent = keras.random.normal( + (batchSize, floorDividedImageHeight, floorDividedImageWidth, 3), + seed=seed, + ) + else: + controlNetLatent = keras.ops.repeat(controlNetInput, batchSize, axis=0) + + return latent, alphas, alphas_prev, controlNetLatent + + def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed): + sigma_t = keras.initializers.Constant(0.0) + sqrt_one_minus_at = keras.ops.sqrt(keras.initializers.Constant(1.0) - a_t) + pred_x0 = (x - sqrt_one_minus_at * e_t) / keras.ops.sqrt(a_t) + + # Direction pointing to x_t + dir_xt = ( + keras.ops.sqrt( + keras.initializers.Constant(1.0) - a_prev - keras.ops.square(sigma_t) + ) + * e_t + ) + noise = sigma_t * keras.random.normal(x.shape, seed=seed) * temperature + x_prev = keras.ops.sqrt(a_prev) * pred_x0 + dir_xt + return x_prev, pred_x0 + + # Keras Version + def sample( + self, + context, + unconditionalContext, + unconditionalGuidanceScale, + controlNet=[None, 1, None], # [0]Use ControlNet, [1]Strength, [2] Cache Input + vPrediction=False, + device=None, + ): + with keras.device(device): + # Progress Bar set-up + progbar = keras.utils.Progbar(len(self.timesteps)) + iteration = 0 + + # ControlNet Cache + if controlNet[2] is not None: + keras_print("...using controlNet cache...") + controlNetCache = controlNet[2] + else: + if controlNet[0] is True: + keras_print("...creating controlNet cache...") + controlNetCache = [] + + if controlNet[2] is not None and len(controlNet[2]) != len( + list(enumerate(self.timesteps))[::-1] + ): + keras_print("...updating controlNet cache...") + controlNetCache = [] + controlNet[2] = None + + keras_print("...sampling:") + + # Iteration loop + for index, timestep in list(enumerate(self.timesteps))[::-1]: + + latentPrevious = self.latent + + # Establish timestep embedding + # t_emb = self.timestepEmbedding(float(timestep)) + t_emb = self.timestepEmbedding(int(timestep)) + t_emb = keras.ops.repeat( + t_emb, self.batchSize, axis=0 + ) # shape is (1, 320) + + inputsConditional = [self.latent, t_emb, context] + inputsUnconditional = [self.latent, t_emb, unconditionalContext] + + if controlNet[0] is True: + + if controlNet[2] is None: + # No cache was given, so we're starting from scratch + + # Get unconditional and conditional tensors(arrays) + controlNetUnconditionalArray = self.model.controlNet( + [ + self.latent, + t_emb, + unconditionalContext, + keras.ops.concatenate(self.controlNetInput, axis=3), + ], + training=False, + ) + controlNetConditionalArray = self.model.controlNet( + [ + self.latent, + t_emb, + context, + keras.ops.concatenate(self.controlNetInput, axis=3), + ], + training=False, + ) + + # Apply strength + controlNetUnconditionalArray = [ + result * scale + for result, scale in zip( + controlNetUnconditionalArray, controlNet[1] + ) + ] + controlNetConditionalArray = [ + result * scale + for result, scale in zip( + controlNetConditionalArray, controlNet[1] + ) + ] + + # Update Cache + controlNetCacheData = { + "unconditional": controlNetUnconditionalArray, + "conditional": controlNetConditionalArray, + } + controlNetCache.insert(0, controlNetCacheData) + + # Add the resulting tensors from the contorlNet models to the list of inputs for the diffusion models + inputsUnconditional.append(controlNetUnconditionalArray) + inputsConditional.append(controlNetConditionalArray) + else: + # Use ControlNet Cache + 
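+                        # A cache from a previous call was passed in, so the
+                        # stored ControlNet residuals for this timestep are
+                        # reused instead of running the ControlNet again.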
inputsUnconditional.extend( + controlNetCache[index]["unconditional"] + ) + inputsConditional.extend(controlNetCache[index]["conditional"]) + + # Get unconditional (negative prompt) latent image + unconditionalLatent = self.model.diffusion_model( + inputsUnconditional, training=False + ) + + # Get conditional (positive prompt) latent image + self.latent = self.model.diffusion_model( + inputsConditional, training=False + ) + + # Combine the two latent images + self.latent = unconditionalLatent + unconditionalGuidanceScale * ( + self.latent - unconditionalLatent + ) + + # Alphas + a_t, a_prev = self.alphas[index], self.alphas_prev[index] + + # Predictions + if vPrediction is False: + # Debug Info + if iteration == 0: + print("Latent Previous dtype:", latentPrevious.dtype) + print("Latent dtype:", self.latent.dtype) + + # Make the data types (dtypes) match + if latentPrevious.dtype != self.latent.dtype: + latentPrevious = keras.ops.cast( + latentPrevious, dtype=self.latent.dtype + ) + + pred_x0 = ( + latentPrevious - math.sqrt(1.0 - a_t) * self.latent + ) / math.sqrt(a_t) + + self.latent = ( + self.latent * math.sqrt(1.0 - a_prev) + + math.sqrt(a_prev) * pred_x0 + ) + else: + # v-Prediction for SD 2.1-V models + self.latent = self.predictEpsFromZandV( + latentPrevious, index, self.latent + ) + + # Keras Progress Bar Update + iteration += 1 + progbar.update(iteration) + + keras_print("...finished! Returning latent image...") + + return self.latent, controlNetCache + + def predictEpsFromZandV(self, latent, timestep, velocity): + + # sqrt_alphas_cumprod = keras.ops.sqrt(keras.ops.cumprod([1 - alpha for alpha in self.alphas], axis = 0, exclusive = True)) + sqrt_alphas_cumprod = keras.ops.sqrt(self.alphas) + # keras_print("\nSquare Root Alphas Cumprod:\n",len(sqrt_alphas_cumprod)) + tensorShape = sqrt_alphas_cumprod.shape[0] + # sqrt_alphas_cumprod = sqrt_alphas_cumprod[timestep] + # sqrt_alphas_cumprod = keras.ops.reshape(sqrt_alphas_cumprod, (tensorShape,) + (1,) * (len(latent.shape) - 1)) + + sqrt_one_minus_alphas_cumprod = keras.ops.sqrt( + [1 - alpha for alpha in self.alphas] + ) + # keras_print("\nSquare Root Alphas Cumprod Minus One:\n",len(sqrt_one_minus_alphas_cumprod)) + tensorShape = sqrt_one_minus_alphas_cumprod.shape[0] + # sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[timestep] + # sqrt_one_minus_alphas_cumprod = keras.ops.reshape(sqrt_one_minus_alphas_cumprod, (tensorShape,) + (1,) * (len(latent.shape) - 1)) + + return ( + sqrt_alphas_cumprod[timestep] * latent + - sqrt_one_minus_alphas_cumprod[timestep] * velocity + ) + + def predictStartFromZandV(self, latent, timestep, velocity): + # sqrt_alphas_cumprod = keras.ops.sqrt(keras.ops.cumprod([1 - alpha for alpha in self.alphas], axis = 0, exclusive = True)) + sqrt_alphas_cumprod = keras.ops.sqrt(self.alphas) + tensorShape = sqrt_alphas_cumprod.shape[0] + # sqrt_alphas_cumprod = sqrt_alphas_cumprod[timestep] + # sqrt_alphas_cumprod = keras.ops.reshape(sqrt_alphas_cumprod, (tensorShape,) + (1,) * (len(latent.shape) - 1)) + + # sqrt_one_minus_alphas_cumprod = keras.ops.sqrt(1 - keras.ops.cumprod(self.alphas, axis = 0, exclusive = True)) + sqrt_one_minus_alphas_cumprod = keras.ops.sqrt( + [1 - alpha for alpha in self.alphas] + ) + tensorShape = sqrt_one_minus_alphas_cumprod.shape[0] + # sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[timestep] + # sqrt_one_minus_alphas_cumprod = keras.ops.reshape(sqrt_one_minus_alphas_cumprod, (tensorShape,) + (1,) * (len(latent.shape) - 1)) + + """sqrt_alphas_cumprod_t = 
extractIntoTensor(sqrt_alphas_cumprod, timestep, latent.shape) + sqrt_one_minus_alphas_cumprod_t = extractIntoTensor(sqrt_one_minus_alphas_cumprod, timestep, latent.shape) + + return sqrt_alphas_cumprod_t * latent - sqrt_one_minus_alphas_cumprod_t * velocity""" + + """print(sqrt_alphas_cumprod.shape) + print(timestep) + print(sqrt_alphas_cumprod[timestep])""" + + return ( + sqrt_alphas_cumprod[timestep] * velocity + + sqrt_one_minus_alphas_cumprod[timestep] * latent + ) + + def timestepEmbedding(self, timesteps, dimensions=320, max_period=10000.0): + half = dimensions // 2 + freqs = keras.ops.exp( + -keras.ops.log(max_period) + * keras.ops.arange(0, half, dtype=keras.config.floatx()) + / half + ) + args = ( + keras.ops.convert_to_tensor([timesteps], dtype=keras.config.floatx()) + * freqs + ) + embedding = keras.ops.concatenate([keras.ops.cos(args), keras.ops.sin(args)], 0) + embedding = keras.ops.reshape(embedding, [1, -1]) + return embedding + + def displayImage(self, image, name="sampler"): + # Assuming input_image_tensor is a TensorFlow tensor representing the image + + try: + input_image_tensor = self.model.decoder(image, training=False) + except Exception as e: + print(e) + input_image_tensor = image + + # Assuming input_image_tensor is a TensorFlow tensor representing the image + # Remove the batch dimension + input_image_tensor = keras.ops.squeeze(input_image_tensor, axis=0) + + # keras.ops.image.resize(input_image_tensor, [self.model.imageWidth, self.model.imageHeight]) + + # Convert the tensor to a NumPy array + input_image_array = input_image_tensor.numpy() + + # Rescale the array to the range [0, 255] + input_image_array = ((input_image_array + 1) / 2.0) * 255.0 + + # Convert the array to uint8 data type + input_image_array = input_image_array.astype("uint8") + + # Display the image using Matplotlib + imageFromBatch = Image.fromarray(input_image_array) + imageFromBatch.save("debug/" + name + ".png") + + +""" +Utilities +""" + + +def extractIntoTensor(a, t, x_shape): + b, *_ = keras.ops.shape(t) + out = keras.ops.take(a, t, axis=-1) + return keras.ops.reshape(out, (b,) + (1,) * (len(x_shape) - 1)) diff --git a/keras_hub/src/models/control_net/samplers/basicVSampler.py b/keras_hub/src/models/control_net/samplers/basicVSampler.py new file mode 100644 index 0000000000..fb588cc3a3 --- /dev/null +++ b/keras_hub/src/models/control_net/samplers/basicVSampler.py @@ -0,0 +1,202 @@ +# TensorFlow Modules + +import keras +from keras_hub.src.models.control_net.utils import keras_print + + +class BasicSampler: + def __init__( + self, + model=None, + timesteps=keras.ops.numpy.arange(1, 1000, 1000 // 5), + batchSize=1, + seed=1990, + inputImage=None, # Expecting a tensor + inputMask=None, # Expecting a tensor + inputImageStrength=0.5, + temperature=1, + AlphasCumprod=None, + ): + print("...starting Basic Sampler...") + self.model = model + self.timesteps = timesteps + self.batchSize = batchSize + self.seed = seed + self.inputImage = inputImage + self.inputMask = inputMask + self.inputImageStrength = inputImageStrength + self.inputImageNoise_T = self.timesteps[ + int(len(self.timesteps) * self.inputImageStrength) + ] + self.temperature = temperature + self.AlphasCumprod = AlphasCumprod + + self.latent, self.alphas, self.alphas_prev = self.getStartingParameters( + self.timesteps, + self.batchSize, + seed, + inputImage=self.inputImage, + inputImageNoise_T=self.inputImageNoise_T, + ) + + if self.inputImage is not None: + self.timesteps = self.timesteps[ + : int(len(self.timesteps) * 
self.inputImageStrength) + ] + + print("...sampler ready...") + + def addNoise(self, x, t, noise=None, DType=keras.config.floatx()): + batch_size, w, h = x.shape[0], x.shape[1], x.shape[2] + if noise is None: + noise = keras.random.normal((batch_size, w, h, 4), dtype=DType) + sqrt_alpha_prod = self.AlphasCumprod[t] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.AlphasCumprod[t]) ** 0.5 + + return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise + + def getStartingParameters( + self, timesteps, batchSize, seed, inputImage=None, inputImageNoise_T=None + ): + # Use floor division to get minimum height/width of image size + # for the Diffusion and Decoder models + floorDividedImageHeight = self.model.imageHeight // 8 + floorDividedImageWidth = self.model.imageWidth // 8 + + alphas = [self.AlphasCumprod[t] for t in timesteps] + alphas_prev = [1.0] + alphas[:-1] + + if inputImage is None: + # Create a random input image from noise + latent = keras.random.normal( + (batchSize, floorDividedImageHeight, floorDividedImageWidth, 4), + seed=seed, + ) + else: + # Encode the given image + latent = self.model.encoder(inputImage, training=False) + # Repeat it within the tensor for the given batch size + latent = keras.ops.repeat(latent, batchSize, axis=0) + # Noise the image + latent = self.addNoise(latent, inputImageNoise_T) + + return latent, alphas, alphas_prev + + def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed): + sigma_t = keras.initializers.Constant(0.0) + sqrt_one_minus_at = keras.ops.sqrt(keras.initializers.Constant(1.0).value - a_t) + pred_x0 = (x - sqrt_one_minus_at * e_t) / keras.ops.sqrt(a_t) + + # Direction pointing to x_t + dir_xt = ( + keras.ops.sqrt( + keras.initializers.Constant(1.0).value + - a_prev + - keras.ops.square(sigma_t) + ) + * e_t + ) + noise = sigma_t.value * keras.random.normal(x.shape, seed=seed) * temperature + x_prev = keras.ops.sqrt(a_prev) * pred_x0 + dir_xt + return x_prev, pred_x0 + + # Keras Version + def sample(self, context, unconditionalContext, unconditionalGuidanceScale): + keras_print("...sampling:") + + # Progress Bar set-up + progbar = keras.utils.Progbar(len(self.timesteps)) + iteration = 0 + + # Iteration loop + for index, timestep in list(enumerate(self.timesteps))[::-1]: + + latentPrevious = self.latent + + # Establish timestep embedding + t_emb = self.timestepEmbedding(float(timestep)) + t_emb = keras.ops.repeat(t_emb, self.batchSize, axis=0) + + # Get unconditional (negative prompt) latent image + unconditionalLatent = self.model.diffusion_model( + [self.latent, t_emb, unconditionalContext], training=False + ) + # Get conditional (positive prompt) latent image + self.latent = self.model.diffusion_model( + [self.latent, t_emb, context], training=False + ) + + # Combine the two latent images, the et + self.latent = unconditionalLatent + unconditionalGuidanceScale * ( + self.latent - unconditionalLatent + ) + + # Alphas, the sigma + a_t, a_prev = self.alphas[index], self.alphas_prev[index] + + """# Predictions + predictV = (latentPrevious - keras.ops.sqrt(keras.initializers.Constant(1.0) - a_t) * self.latent) / keras.ops.sqrt( + a_t + ) + self.latent = ( + self.latent * keras.ops.sqrt(1.0 - a_prev) + keras.ops.sqrt(a_prev) * predictV + )""" + + # Predictions + predictV = ( + latentPrevious + - keras.ops.sqrt(keras.initializers.Constant(1.0) - a_t) * self.latent + ) / keras.ops.sqrt(a_t) + self.latent = ( + self.latent * keras.ops.sqrt(1.0 - a_prev) + + keras.ops.sqrt(a_prev) * predictV + ) + + # Keras Progress Bar 
Update + iteration += 1 + progbar.update(iteration) + + keras_print("...finished! Returning latent image...") + + return self.latent + + def getModelOutput( + self, + latent, + inputTimesteps, + context, + unconditionalContext, + unconditionalGuidanceScale, + batch_size, + ): + + # Establish timestep embedding + t_emb = self.timestepEmbedding(float(inputTimesteps)) + t_emb = keras.ops.repeat(t_emb, batch_size, axis=0) + + # Get unconditional (negative prompt) latent image + unconditionalLatent = self.model.diffusion_model( + [latent, t_emb, unconditionalContext], training=False + ) + # Get conditional (positive prompt) latent image + latent = self.model.diffusion_model([latent, t_emb, context], training=False) + + # Combine the images and return the result + return unconditionalLatent + unconditionalGuidanceScale * ( + latent - unconditionalLatent + ) + + def timestepEmbedding(self, timesteps, dimensions=320, max_period=10000.0): + half = dimensions // 2 + freqs = keras.ops.exp( + -keras.ops.log(max_period) + * keras.ops.arange(0, half, dtype=keras.config.floatx()) + / half + ) + args = ( + keras.ops.convert_to_tensor([timesteps], dtype=keras.config.floatx()) + * freqs + ) + embedding = keras.ops.concatenate([keras.ops.cos(args), keras.ops.sin(args)], 0) + embedding = keras.ops.reshape(embedding, [1, -1]) + return embedding diff --git a/keras_hub/src/models/control_net/stableDiffusion.py b/keras_hub/src/models/control_net/stableDiffusion.py new file mode 100644 index 0000000000..261430ce24 --- /dev/null +++ b/keras_hub/src/models/control_net/stableDiffusion.py @@ -0,0 +1,1003 @@ +### System modules +### Time modules +import datetime +### Memmory Management +import gc # Garbage Collector +import logging +import os +import random +import sys +import warnings + +### Math modules +import numpy as np +from jax import Array +### Console GUI +from rich import box, print +from rich.panel import Panel +from rich.text import Text + +from .utils import keras_print + +### Import TensorFlow module +### but with supressed warnings to clear up the terminal outputs +# Filter tensorflow version warnings +# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709 +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # or any {'0', '1', '2'} +# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning +warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=Warning) + +# TensorFlow module + + +# More suppressed warnings from TensorFlow +# tf.get_logger().setLevel('INFO') +# tf.autograph.set_verbosity(0) +# tf.get_logger().setLevel(logging.ERROR) + +### Keras module +import keras +### Pytorch (for converting pytorch weights) +import torch as torch +from keras import backend as K +### Modules for image building +from PIL import Image +### Safetensors (for converting safetensor weights) +from safetensors.torch import load_file + +## Text encoder +from .clipEncoder import CLIPTextTransformer # SD 1.4/1.5 +## Tokenizer +from .clipTokenizer import LegacySimpleTokenizer, SimpleTokenizer +## ControlNet +from .controlNetDiffusionModels import \ + ControlNetDiffusionModel as ControlNetModel +from .controlNetDiffusionModels import DiffusionModel as ControlDiffusionModel +### Models from Modules +## VAE, encode and decode +from .EncodeDecode import Decoder, ImageEncoder +# from .autoencoderKl import Decoder, Encoder +## Diffusion +from .kerasCVDiffusionModels import DiffusionModel, DiffusionModelV2 
+from .openClipEncoder import OpenCLIPTextTransformer # SD 2.x +### Sampler modules +from .samplers import DPMSolverKerasCV as DPMSolver +from .samplers.basicSampler import BasicSampler +### Tools +from .tools import textEmbeddings as textEmbeddingTools + +#import cv2 # OpenCV + + +### Global Variables +MAX_TEXT_LEN = 77 +from .constants import _ALPHAS_CUMPROD, PYTORCH_CKPT_MAPPING + +### Main Class + + +class StableDiffusion: + ### Base class/object for Stable Diffusion + def __init__( + self, + imageHeight=512, + imageWidth=512, + jit_compile=False, + weights=None, + legacy=True, + VAE="Original", + textEmbeddings=None, + mixedPrecision=False, + optimizer="nadam", + device=None, + controlNet=[ + False, + None, + ], # [0] = Use ControlNet? [1] = ControlNet Weights [2] = Input [3] = Strength + ): + self.device = device + + with keras.device(self.device): + ### Step 1: Establish image dimensions for UNet ### + ## requires multiples of 2**7, 2 to the power of 7 + self.imageHeight = round(imageHeight / 128) * 128 + self.imageWidth = round(imageWidth / 128) * 128 + + # Global policy + self.dtype = keras.config.floatx() # Default + + # Maaaybe float16 will result in faster images? + if mixedPrecision is True: + self.changePolicy("mixed_float16") + + ### Step 2: Load Text Embeddings ### + textEmbeddingTokens = [] + if textEmbeddings == None: + keras_print("\nIgnoring Text Embeddings") + self.textEmbeddings = None + self.textEmbeddingsTokens = None + else: + keras_print("\nUsing Text Embeddings") + self.textEmbeddings, self.textEmbeddingsTokens = ( + textEmbeddingTools.loadTextEmbedding(textEmbeddings) + ) + + ### Step 3: Which version of Stable Diffusion ### + + self.legacy = legacy + + ### Step 4: Create Tokenizer ### + if self.legacy is True: + if self.textEmbeddings is None: + # If no textEmbeddings were given, we're not adding to the special tokens list in the tokenizer + self.tokenizer = LegacySimpleTokenizer() + else: + self.tokenizer = LegacySimpleTokenizer( + specialTokens=self.textEmbeddingsTokens + ) + else: + if self.textEmbeddings is None: + self.tokenizer = SimpleTokenizer() + else: + self.tokenizer = SimpleTokenizer( + specialTokens=self.textEmbeddingsTokens + ) + + ### Step 5: Create Models ### + """ + We need to create empty models before we can compile them with + the weights of the trained models. + First, let's check for pytorch weights. If given, we will load them later. 
+ If not, then we're loading in a pre-compiled model OR weights made for TensorFlow + """ + + ## Step 5.1: Create weightless models ## + if controlNet[0] == True: + keras_print("\nUsing ControlNet", controlNet[1]) + + text_encoder, diffusion_model, decoder, encoder, control_net = CreateModels( + self.imageHeight, + self.imageWidth, + preCompiled=None, # If not None, then we're passing on Keras weights ".h5" + legacy=legacy, + addedTokens=self.textEmbeddings, + useControlNet=[controlNet[0]], + device=self.device, + ) + + ## Step 5.2 Create object/class variables that point to the compiled models + self.text_encoder = text_encoder + self.diffusion_model = diffusion_model + self.decoder = decoder + self.encoder = encoder + self.controlNet = control_net + + ## Step 5.4: Load Weights + # NOTE: must be done after creating models + self.weights = weights + + self.setWeights(weights, VAE) + + ### Step 6: Load Text Embedding Weights ### + if self.textEmbeddings is not None: + if legacy is True: + CLIP = CLIPTextTransformer + else: + CLIP = OpenCLIPTextTransformer + self.text_encoder = textEmbeddingTools.loadTextEmbeddingWeight( + textEncoder=text_encoder, + CLIP=CLIP, + maxTextLength=MAX_TEXT_LEN, + embeddings=self.textEmbeddings, + legacy=legacy, + ) + + ### Step 7: Load ControlNet Weights ### + if controlNet[0] == True: + if ".safetensors" in controlNet[1]: + loadWeightsFromSafeTensor( + self, + controlNet[ + 1 + ], # Which weights to load, in this case maybe all four models + legacy, # Which version of Stable Diffusion + ["controlNet"], # Which specific Models to load + ) + elif ".pth" in controlNet[1]: + loadWeightsFromPytorchCKPT( + self, + controlNet[ + 1 + ], # Which weights to load, in this case maybe all four models + legacy, # Which version of Stable Diffusion + ["controlNet"], # Which specific Models to load + ) + + ### Step 8: Compile Models ### + self.jitCompile = jit_compile + self.compileModels(optimizer, self.jitCompile) + + ## Cache + self.prompt = None + self.negativePrompt = None + self.encodedPrompt = None + self.encodedNegativePrompt = None + self.batch_size = None + self.controlNetCache = None + + def compileModels(self, optimizer="nadam", jitCompile=False): + modules = ["text_encoder", "diffusion_model", "decoder", "encoder"] + + if jitCompile is True: + keras_print("\nCompiling models with XLA (Accelerated Linear Algebra):") + else: + keras_print("\nCompiling models") + + with keras.device(self.device): + for module in modules: + getattr(self, module).compile( + optimizer=keras.optimizers.Adam(), jit_compile=jitCompile + ) + print(module, "compiled.") + + """ + Generate and image, the key function + """ + + def generate( + self, + prompt, + negativePrompt=None, + batch_size=1, + num_steps=25, + unconditional_guidance_scale=7.5, + temperature=1, + seed=None, + input_image=None, # expecting file path as a string or np.ndarray + input_image_strength=0.5, + input_mask=None, # expecting a file path as a string + sampler=None, + controlNetStrength=1, + controlNetImage=None, + controlNetCache=False, + vPrediction=False, + ): + with keras.device(self.device): + ## Memory Efficiency + # Clear up tensorflow memory + keras_print("\n...cleaning memory...") + keras.backend.clear_session() + gc.collect() + + keras_print("...getting to work...") + + ### Step 1: Cache Prompts + if self.prompt != prompt: # New prompt? + # Create prompt cache + self.prompt = prompt + self.encodedPrompt = None + + if self.negativePrompt != negativePrompt: # New negative prompt? 
+ # Create negative prompt cache + self.negativePrompt = negativePrompt + self.encodedNegativePrompt = None + + if self.batch_size != batch_size: # New batch size? + # clear prompt caches if batch_size has changed + self.encodedPrompt = None + self.encodedNegativePrompt = None + self.batch_size = batch_size + + ### Step 2: Tokenize prompts + # the tokenized prompts are AKA "starting context" + # we'll also tokenize the negative prompt, the "unconditional context" + + if self.encodedPrompt is None: + # No cached encoded prompt exists + keras_print("\n...tokenizing prompt...") + + if self.textEmbeddings is not None: + keras_print("...checking for text embeddings...") + prompt = textEmbeddingTools.injectTokens( + prompt=prompt, embeddings=self.textEmbeddings + ) + + phrase, pos_ids = self.encodeText(prompt, batch_size, self.legacy) + + keras_print("...encoding the tokenized prompt...") + context = self.text_encoder([phrase, pos_ids], training=False) + + # Cache encoded prompt + self.encodedPrompt = context + else: + # Load cached encoded prompt + keras_print("...using cached encoded prompt...") + context = self.encodedPrompt + + if self.encodedNegativePrompt is None: + keras_print("...tokenizing negative prompt...") + if negativePrompt is None: + # Encoding text requires a string variable + negativePrompt = "" + + if self.textEmbeddings is not None: + keras_print("...checking for text embeddings...") + negativePrompt = textEmbeddingTools.injectTokens( + prompt=negativePrompt, embeddings=self.textEmbeddings + ) + + unconditional_tokens, pos_ids = self.encodeText( + negativePrompt, batch_size, self.legacy + ) + + keras_print("...encoding the tokenized negative prompt...") + unconditionalContext = self.text_encoder( + [unconditional_tokens, pos_ids], training=False + ) + + # Cache encoded negative prompt + self.encodedNegativePrompt = unconditionalContext + else: + keras_print("...using cached encoded negative prompt...") + unconditionalContext = self.encodedNegativePrompt + + ### Step 3: Prepare the input image, if it was given + ## If given, we're expecting an np.ndarry + input_image_tensor = None + if input_image is not None: + + if isinstance(input_image, np.ndarray): + print("...received NumPy Array...") + print(input_image.shape) + + input_image = keras.ops.convert_to_tensor( + input_image, dtype=keras.config.floatx() + ) + + # Resize the image to self.imageHeight x self.imageWidth + input_image = keras.ops.image.resize( + input_image, [self.imageHeight, self.imageWidth] + ) + + inputImageArray = keras.initializers.Constant( + input_image, dtype=keras.config.floatx() + ) + inputImageArray = keras.ops.expand_dims( + input_image[..., :3], axis=0 + ) + input_image_tensor = keras.ops.cast( + (inputImageArray / 255.0) * 2 - 1, self.dtype + ) + + print(input_image_tensor.shape) + # displayImage(input_image_tensor, name = "1preppedImage") + elif isinstance(input_image, Array): + print("...received jax.Array (JAX Array)...") + input_image_tensor = input_image + # displayImage(input_image_tensor, name = "1preppedImage") + + ### Step 4: Prepare the image mask, if it was given + if type(input_mask) is str: + print("...preparing input mask...") + input_mask = Image.open(input_mask) + input_mask = input_mask.resize((self.imageWidth, self.imageHeight)) + input_mask_array = np.array(input_mask, dtype=np.float32)[ + None, ..., None + ] + input_mask_array = input_mask_array / 255.0 + + latent_mask = input_mask.resize( + (self.imageWidth // 8, self.imageHeight // 8) + ) + latent_mask = np.array(latent_mask, 
dtype=np.float32)[None, ..., None] + latent_mask = 1 - (latent_mask.astype("float") / 255.0) + latent_mask_tensor = keras.ops.cast( + keras.ops.repeat(latent_mask, batch_size, axis=0), self.dtype + ) + else: + latent_mask_tensor = None + + ### Step 5: Create a random seed if one is not provided + if seed is None: + keras_print("...generating random seed...") + seed = random.randint(1000, sys.maxsize) + seed = int(seed) + else: + seed = int(seed) + + ### Step 6: Create time steps + keras_print("...creating time steps...") + timesteps = keras.ops.arange(1, 1000, 1000 // num_steps) + + ### Step 7: Load Sampler and: + ### Step 8: Start Diffusion + if sampler == "DPMSolver": + keras_print("...using DPM Solver...\n...starting sampler...") + + alphasCumprod = keras.initializers.Constant(_ALPHAS_CUMPROD) + + noiseScheduler = DPMSolver.NoiseScheduler(beta_schedule="scaled_linear") + + print( + "...starting diffusion...\n...this solver not supported yet!\nDividing by zero now:\n" + ) + + x = 5 / 0 + else: + if sampler is None: + keras_print("...no sampler given...") + + # ControlNet + # Parameters: [0]Use ControlNet, [1] Input Image, [2]Strength, [3] Cache Input + if self.controlNet is not None: + controlNetImage = [ + keras.initializers.Constant( + controlNetImage[0].copy(), dtype=keras.config.floatx() + ) + / 255.0 + ] + if controlNetCache is False: + self.controlNetCache = None + if type(self.controlNetCache) is dict: + if len(self.controlNetCache["unconditional"]) != timesteps: + keras_print("Incompatible cache!") + self.controlNetCache = None + controlNetParamters = [ + True, + controlNetImage, + controlNetStrength, + self.controlNetCache, + ] + else: + controlNetParamters = [False, None, 1, None] + + # Create Sampler + sampler = BasicSampler( + model=self, + timesteps=timesteps, + batchSize=batch_size, + seed=seed, + inputImage=input_image_tensor, + inputMask=latent_mask_tensor, + inputImageStrength=input_image_strength, + temperature=temperature, + AlphasCumprod=_ALPHAS_CUMPROD, + controlNetInput=controlNetParamters[ + 1 + ], # Input Image, assuming pre-processed + ) + + if vPrediction is True: + keras_print("...using v-prediction...") + + # Sample, create image essentially + latentImage, self.controlNetCache = sampler.sample( + context, + unconditionalContext, + unconditional_guidance_scale, + controlNet=[ + controlNetParamters[0], + controlNetParamters[2], + controlNetParamters[3], + ], # [0]Use Control Net, [2]Strength, [3]Cache + vPrediction=vPrediction, + device=self.device, + ) + + ### Step 9: Decoding stage + keras_print("\n...decoding latent image...") + decoded = self.decoder(latentImage, training=False) + decoded = ((decoded + 1) / 2) * 255 + + ### Step 10: Merge inpainting result of input mask with original image + if input_mask is not None: + decoded = ( + inputImageArray * (1 - input_mask_array) + + np.array(decoded) * input_mask_array + ) + + ### Memory cleanup + gc.collect() + + ### Step 11: return final image as an array + return np.clip(decoded, 0, 255).astype("uint8") + + def changePolicy(self, policy): + + if policy == "mixed_float16": + # self.dtype = tf.float16 + if keras.mixed_precision.global_policy().name != "mixed_float16": + print("\n...using mixed precision...") + keras.mixed_precision.set_global_policy("mixed_float16") + # self.dtype = tf.float16 + + if policy == "float32": + # self.dtype = keras.config.floatx() + if keras.mixed_precision.global_policy().name != "float32": + print("\n...using regular precision...") + 
keras.mixed_precision.set_global_policy("float32") + # self.dtype = keras.config.floatx() + + def encodeText(self, prompt, batch_size, legacy): + TextLimit = MAX_TEXT_LEN - 1 + with keras.device(self.device): + if legacy is True: + # First, encode the prompt + inputs = self.tokenizer.encode(prompt) + # Then check the inputs length and truncate if too long + if len(inputs) > TextLimit: + keras_print( + "Prompt is too long (should be less than 77 words). Truncating down to 77 words..." + ) + inputs = inputs[:TextLimit] + + """## Create numpy array with the inputs + # Phrase - aka the prompt + phrase = [49406] + inputs + [49407] * (TextLimit - len(inputs)) + phrase = np.array(phrase)[None].astype("int32") + phrase = np.repeat(phrase, batch_size, axis = 0) + + # Position ID + pos_ids = np.array(list(range(77)))[None].astype("int32") + pos_ids = np.repeat(pos_ids, batch_size, axis = 0)""" + + # Phrase - aka the prompt + phrase = keras.ops.concatenate( + [[49406], inputs, [49407] * (TextLimit - len(inputs))], axis=0 + ) + phrase = keras.ops.expand_dims(phrase, axis=0) + phrase = keras.ops.repeat(phrase, batch_size, axis=0) + phrase = keras.ops.cast(phrase, dtype="int32") + + # Position ID + pos_ids = keras.ops.expand_dims(keras.ops.arange(77), axis=0) + pos_ids = keras.ops.repeat(pos_ids, batch_size, axis=0) + pos_ids = keras.ops.cast(pos_ids, dtype="int32") + else: + # First, encode the prompt + TextLimit += 1 + if isinstance(prompt, str): + inputs = [prompt] + # Then tokenize the prompt + startOfToken = self.tokenizer.encoder[""] + endOfToken = self.tokenizer.encoder[""] + allTokens = [ + [startOfToken] + self.tokenizer.encode(input) + [endOfToken] + for input in inputs + ] + # Create the empty tensor/numpy array to load the tokens into + phrase = np.zeros((len(allTokens), TextLimit), dtype=np.int32) + + for i, tokens in enumerate(allTokens): + if len(tokens) > TextLimit: + tokens = tokens[:TextLimit] # Truncate + tokens[-1] = endOfToken + phrase[i, : len(tokens)] = np.array(tokens) + + phrase = np.repeat(phrase, batch_size, axis=0) + + pos_ids = np.array(list(range(TextLimit)))[None].astype("int32") + pos_ids = np.repeat(pos_ids, batch_size, axis=0) + + return phrase, pos_ids + + def setWeights(self, weights, VAE="Original"): + self.weights = weights + # Load weights for VAE models, if given + if VAE != "Original": + if ".ckpt" in VAE: + loadWeightsFromPytorchCKPT( + self, + VAE, # Which weights to load, in this case weights for VAE + self.legacy, # Which version of Stable Diffusion + ["decoder", "encoder"], # Models to load + True, + ) + elif ".safetensors" in VAE: + loadWeightsFromSafeTensor( + self, + VAE, # Which weights to load, in this case weights for VAE + self.legacy, # Which version of Stable Diffusion + ["decoder", "encoder"], # Models to load + True, + ) + else: + loadWeightsFromKeras(self, VAE, VAEOnly=True) + + # Load all weights + if ".ckpt" in self.weights: + if VAE == "Original": # Load all weights from PyTorch .ckpt + modules = ["text_encoder", "diffusion_model", "decoder", "encoder"] + else: # only load weights for the text encoder and diffusion model if VAE was given + modules = ["text_encoder", "diffusion_model"] + loadWeightsFromPytorchCKPT( + self, + self.weights, # Which weights to load, in this case maybe all four models + self.legacy, # Which version of Stable Diffusion + modules, # Which specific Models to load + ) + elif ".safetensors" in self.weights: + if VAE == "Original": # Load all weights from PyTorch .ckpt + modules = ["text_encoder", "diffusion_model", 
"decoder", "encoder"] + else: # only load weights for the text encoder and diffusion model if VAE was given + modules = ["text_encoder", "diffusion_model"] + loadWeightsFromSafeTensor( + self, + self.weights, # Which weights to load, in this case maybe all four models + self.legacy, # Which version of Stable Diffusion + modules, # Which specific Models to load + ) + else: + if VAE == "Original": + loadWeightsFromKeras(self, self.weights, VAEOnly=False) + else: + loadWeightsFromKeras(self, self.weights, VAEOnly=VAE) + + +### Functions ### + + +def CreateModels( + imageHeight=512, + imageWidth=512, + preCompiled=None, + legacy=True, + addedTokens=0, + useControlNet=[False], + device=None, +): + with keras.device(device): + # Memory Clean up + keras.backend.clear_session() + gc.collect() + + controlNet = None + + if legacy is True: + # Are we using Pre-Stable Diffusion 2.0? + + keras_print("\nCreating models in legacy mode...") + + # Create Text Encoder model + input_word_ids = keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32") + input_pos_ids = keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32") + embeds = CLIPTextTransformer()([input_word_ids, input_pos_ids]) + text_encoder = keras.models.Model([input_word_ids, input_pos_ids], embeds) + keras_print("Created text encoder model") + + if useControlNet[0] is False: + # Create Diffusion model + diffusion_model = DiffusionModel(imageHeight, imageWidth, MAX_TEXT_LEN) + keras_print("Created diffusion model") + else: + # Create seperate control net model + controlNet = ControlNetModel(imageHeight, imageWidth, MAX_TEXT_LEN) + + keras_print("Created ControlNet Model") + + # Create Diffusion model + diffusion_model = ControlDiffusionModel( + imageHeight, imageWidth, MAX_TEXT_LEN + ) + keras_print("Created diffusion model") + + # Create Decoder model + decoder = Decoder( + img_height=imageHeight, + img_width=imageWidth, + ) + keras_print("Created decoder model") + + # Create Image Encoder model + encoder = ImageEncoder(img_height=imageHeight, img_width=imageWidth) + keras_print("Created encoder model") + + else: + # We're using SD 2.0 and newer + + print("\nCreating models in contemporary mode...") + + # Create Text Encoder model + input_word_ids = keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32") + input_pos_ids = keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32") + embeds = OpenCLIPTextTransformer()([input_word_ids, input_pos_ids]) + text_encoder = keras.models.Model([input_word_ids, input_pos_ids], embeds) + print("Created text encoder model") + + # Create Diffusion model + diffusion_model = DiffusionModelV2(imageHeight, imageWidth, MAX_TEXT_LEN) + print("Created diffusion model") + + # Create Decoder model + decoder = Decoder( + img_height=imageHeight, + img_width=imageWidth, + ) + print("Created decoder model") + + # Create Image Encoder model + encoder = ImageEncoder(img_height=imageHeight, img_width=imageWidth) + print("Created encoder model") + + # return created models + return text_encoder, diffusion_model, decoder, encoder, controlNet + + +def loadWeightsFromKeras(models, weightsPath, VAEOnly=False): + keras_print("\nLoading Keras weights for:", weightsPath) + textEncoderWeights = weightsPath + "/text_encoder.h5" + diffusionModelWeights = weightsPath + "/diffusion_model.h5" + imageEncoderWeights = weightsPath + "/encoder.h5" + decoderWeights = weightsPath + "/decoder.h5" + + if VAEOnly is False: + models.text_encoder.load_weights(textEncoderWeights) + keras_print("...Text Encoder weights loaded!") + 
models.diffusion_model.load_weights(diffusionModelWeights) + keras_print("...diffusion model weights loaded") + models.encoder.load_weights(imageEncoderWeights) + keras_print("...Image Encoder weights loaded!") + models.decoder.load_weights(decoderWeights) + keras_print("...Decoder weights loaded!") + keras_print("All weights loaded!") + + +def loadWeightsFromPytorchCKPT( + model, + pytorch_ckpt_path, + legacy=True, + moduleName=["text_encoder", "diffusion_model", "decoder", "encoder"], + VAEoverride=False, +): + print("\nLoading pytorch checkpoint " + pytorch_ckpt_path) + pytorchWeights = torch.load(pytorch_ckpt_path, map_location="mps") + if legacy is True: + ## Legacy Mode + print("...loading pytroch weights in legacy mode...") + for module in moduleName: + module_weights = [] + if module == "text_encoder": + module = "text_encoder_legacy" + for i, (key, perm) in enumerate(PYTORCH_CKPT_MAPPING[module]): + if VAEoverride is True: + key = key.replace("first_stage_model.", "") + if "state_dict" in pytorchWeights: + weight = pytorchWeights["state_dict"][key].detach().numpy() + else: + weight = pytorchWeights[key].detach().numpy() + if perm is not None: + weight = np.transpose(weight, perm) + module_weights.append(weight) + if module == "text_encoder_legacy": + module = "text_encoder" + + getattr(model, module).set_weights(module_weights) + + print("Loaded %d pytorch weights for %s" % (len(module_weights), module)) + else: + ## Contemporary Mode + print("...loading pytorch weights in contemporary mode...") + for module in moduleName: + module_weights = [] + in_projWeightConversion = [] + in_projBiasConversion = [] + for i, (key, perm) in enumerate(PYTORCH_CKPT_MAPPING[module]): + if "in_proj" not in key: + if VAEoverride is True: + key = key.replace("first_stage_model.", "") + weight = pytorchWeights["state_dict"][key].detach().numpy() + + if module == "diffusion_model": + if "proj_in.weight" in key or "proj_out.weight" in key: + # print(i+1," Overriding premuation from constants:\n",key) + # This is so the constants.py "diffusion_model" dictionary keeps its legacy state + perm = (1, 0) + + if perm is not None: + weight = np.transpose(weight, perm) + module_weights.append(weight) + else: + if module == "text_encoder": + # "in_proj" layer of SD2.x is a matrix multiplcation of the query, key, and value layers of SD1.4/5 + # We will slice this layer into the the three vectors + if "weight" in key: + # Get the in_proj.weight + originalWeight = ( + pytorchWeights["state_dict"][key].float().numpy() + ) + + queryWeight = originalWeight[:1024, ...] + queryWeight = np.transpose(queryWeight, (1, 0)) + + keyWeight = originalWeight[1024:2048, ...] + keyWeight = np.transpose(keyWeight, (1, 0)) + + valueWeight = originalWeight[2048:, ...] 
+ valueWeight = np.transpose(valueWeight, (1, 0)) + + # Clear local variable to carry forward for bias + in_projWeightConversion = [] + + in_projWeightConversion.append(queryWeight) # Query states + in_projWeightConversion.append(keyWeight) # Key states + in_projWeightConversion.append(valueWeight) # Value states + elif "bias" in key: + originalBias = ( + pytorchWeights["state_dict"][key].float().numpy() + ) + + queryBias = originalBias[:1024] + + keyBias = originalBias[1024:2048] + + valueBias = originalBias[2048:] + + # Clear local variable to carry forward for bias + in_projBiasConversion = [] + + in_projBiasConversion.append(queryBias) # Query states + in_projBiasConversion.append(keyBias) # Key states + in_projBiasConversion.append(valueBias) # Value states + + # add the converted weights/biases in the correct order + # Query + module_weights.append(in_projWeightConversion[0]) + module_weights.append(in_projBiasConversion[0]) + # Key + module_weights.append(in_projWeightConversion[1]) + module_weights.append(in_projBiasConversion[1]) + # Value + module_weights.append(in_projWeightConversion[2]) + module_weights.append(in_projBiasConversion[2]) + + print("Loading weights for ", module) + + getattr(model, module).set_weights(module_weights) + print("Loaded %d pytorch weights for %s" % (len(module_weights), module)) + + ## Memory Clean up + del pytorchWeights + + +def loadWeightsFromSafeTensor( + model, + safetensor_path, + legacy=True, + moduleName=["text_encoder", "diffusion_model", "decoder", "encoder"], + VAEoverride=False, +): + print("\nLoading safetensor " + safetensor_path) + safeTensorWeights = load_file(safetensor_path) + if legacy is True: + ## Legacy Mode + print("...loading safetensors weights in legacy mode...") + for module in moduleName: + module_weights = [] + if module == "text_encoder": + module = "text_encoder_legacy" + for i, (key, perm) in enumerate(PYTORCH_CKPT_MAPPING[module]): + if VAEoverride is True: + key = key.replace("first_stage_model.", "") + if "state_dict" in safeTensorWeights: + weight = safeTensorWeights["state_dict"][key].detach().numpy() + else: + if module == "controlNet": + # Repalce "control_model." 
in case the safetensor doesn't have that key + key = key.replace("control_model.", "") + weight = safeTensorWeights[key].detach().numpy() + if perm is not None: + weight = np.transpose(weight, perm) + module_weights.append(weight) + if module == "text_encoder_legacy": + module = "text_encoder" + + getattr(model, module).set_weights(module_weights) + + print( + "Loaded %d safetensors weights for %s" % (len(module_weights), module) + ) + else: + ## Contemporary Mode + print("...loading safetensors weights in contemporary mode...") + for module in moduleName: + module_weights = [] + in_projWeightConversion = [] + in_projBiasConversion = [] + for i, (key, perm) in enumerate(PYTORCH_CKPT_MAPPING[module]): + if "in_proj" not in key: + if VAEoverride is True: + key = key.replace("first_stage_model.", "") + weight = safeTensorWeights["state_dict"][key].detach().numpy() + + if module == "diffusion_model": + if "proj_in.weight" in key or "proj_out.weight" in key: + # print(i+1," Overriding premuation from constants:\n",key) + # This is so the constants.py "diffusion_model" dictionary keeps its legacy state + perm = (1, 0) + + if perm is not None: + weight = np.transpose(weight, perm) + module_weights.append(weight) + else: + if module == "text_encoder": + # "in_proj" layer of SD2.x is a matrix multiplcation of the query, key, and value layers of SD1.4/5 + # We will slice this layer into the the three vectors + if "weight" in key: + # Get the in_proj.weight + originalWeight = ( + safeTensorWeights["state_dict"][key].float().numpy() + ) + + queryWeight = originalWeight[:1024, ...] + queryWeight = np.transpose(queryWeight, (1, 0)) + + keyWeight = originalWeight[1024:2048, ...] + keyWeight = np.transpose(keyWeight, (1, 0)) + + valueWeight = originalWeight[2048:, ...] 
+ valueWeight = np.transpose(valueWeight, (1, 0)) + + # Clear local variable to carry forward for bias + in_projWeightConversion = [] + + in_projWeightConversion.append(queryWeight) # Query states + in_projWeightConversion.append(keyWeight) # Key states + in_projWeightConversion.append(valueWeight) # Value states + elif "bias" in key: + originalBias = ( + safeTensorWeights["state_dict"][key].float().numpy() + ) + + queryBias = originalBias[:1024] + + keyBias = originalBias[1024:2048] + + valueBias = originalBias[2048:] + + # Clear local variable to carry forward for bias + in_projBiasConversion = [] + + in_projBiasConversion.append(queryBias) # Query states + in_projBiasConversion.append(keyBias) # Key states + in_projBiasConversion.append(valueBias) # Value states + + # add the converted weights/biases in the correct order + # Query + module_weights.append(in_projWeightConversion[0]) + module_weights.append(in_projBiasConversion[0]) + # Key + module_weights.append(in_projWeightConversion[1]) + module_weights.append(in_projBiasConversion[1]) + # Value + module_weights.append(in_projWeightConversion[2]) + module_weights.append(in_projBiasConversion[2]) + + print("Loading weights for ", module) + + getattr(model, module).set_weights(module_weights) + print( + "Loaded %d safetensors weights for %s" % (len(module_weights), module) + ) + + ## Memory Clean up + del safeTensorWeights + + +def displayImage(input_image_tensor, name="image"): + # Assuming input_image_tensor is a TensorFlow tensor representing the image + # Remove the batch dimension + input_image_tensor = keras.ops.squeeze(input_image_tensor, axis=0) + + # Convert the tensor to a NumPy array + input_image_array = input_image_tensor.numpy() + + # Rescale the array to the range [0, 255] + input_image_array = ((input_image_array + 1) / 2.0) * 255.0 + + # Convert the array to uint8 data type + input_image_array = input_image_array.astype("uint8") + + # Display the image using Matplotlib + imageFromBatch = Image.fromarray(input_image_array) + imageFromBatch.save("debug/" + name + ".png") diff --git a/keras_hub/src/models/control_net/tools/ReadMe.md b/keras_hub/src/models/control_net/tools/ReadMe.md new file mode 100644 index 0000000000..ea9cba85e8 --- /dev/null +++ b/keras_hub/src/models/control_net/tools/ReadMe.md @@ -0,0 +1,3 @@ +## Tools ## + +These files are helpful tools the TensorFlow pipeline uses. 
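Both samplers above condition the UNet with the same sinusoidal `timestepEmbedding` helper: half of the 320 channels are cosines and half are sines over log-spaced frequencies. A minimal NumPy sketch of that computation, for reference only (the standalone function name and the plain-NumPy form are illustrative, not part of this diff):

```python
import numpy as np


def timestep_embedding(timestep, dimensions=320, max_period=10000.0):
    # One log-spaced frequency per half-channel.
    half = dimensions // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half, dtype="float32") / half)
    args = np.asarray([timestep], dtype="float32") * freqs
    # Cosine terms first, sine terms second, flattened to a (1, dimensions) row.
    embedding = np.concatenate([np.cos(args), np.sin(args)], axis=0)
    return embedding.reshape(1, -1)


print(timestep_embedding(801).shape)  # (1, 320); repeated per batch with keras.ops.repeat
```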
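Each loop iteration then runs the diffusion model twice (negative and positive prompt), blends the two outputs with classifier-free guidance, and takes a deterministic DDIM-style update from the cumulative alphas. A compact sketch of that per-step math under the same assumptions as the sampler code above (`guided_ddim_step` and its argument names are hypothetical; `eps_uncond`/`eps_cond` stand for the two `diffusion_model` outputs, and `a_t`/`a_prev` for `alphas[index]`/`alphas_prev[index]`):

```python
import math


def guided_ddim_step(latent_prev, eps_uncond, eps_cond, guidance_scale, a_t, a_prev):
    # Classifier-free guidance: push the conditional prediction away from the
    # unconditional one by the guidance scale.
    e_t = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
    # Estimate x0 from the noise prediction at cumulative alpha a_t...
    pred_x0 = (latent_prev - math.sqrt(1.0 - a_t) * e_t) / math.sqrt(a_t)
    # ...then step to the previous, less noisy latent with sigma_t = 0.
    return math.sqrt(1.0 - a_prev) * e_t + math.sqrt(a_prev) * pred_x0
```

Since `alphas_prev` is `[1.0] + alphas[:-1]`, the final update (index 0) uses `a_prev = 1.0` and returns the predicted clean latent that is handed to the decoder.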
diff --git a/keras_hub/src/models/control_net/tools/__init__.py b/keras_hub/src/models/control_net/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/control_net/tools/textEmbeddings.py b/keras_hub/src/models/control_net/tools/textEmbeddings.py new file mode 100644 index 0000000000..efeb62d398 --- /dev/null +++ b/keras_hub/src/models/control_net/tools/textEmbeddings.py @@ -0,0 +1,252 @@ +import keras +import numpy as np +import torch as torch +from keras_hub.src.models.control_net.utils import keras_print + + +class Embedding: + """ + This is an object class that stores the loaded Text Embedding + It only needs two key variables to exist: + Name: unique name of embedding + Vector: vector(s) of the embedding + """ + + def __init__(self, vector, name, step=None): + self.vector = vector + self.name = name + + # Adjust the vector shape to (x, 768) + # This is for single vector text embeddings, which may come as a (768,) instead of (1,768) + if self.vector.ndim < 2: + if self.vector.shape[0] == 768: # Stable Diffusion 1.4/1.5 + self.vector = self.vector.reshape((1, 768)) + elif self.vector.shape[0] == 1024: # Stable Diffusion 2.x + self.vector = self.vector.reshape((1, 1024)) + + # Create the unique tokens + if self.vector.shape[0] > 1: + # If we have a multidimensional vector, then we'll split up the token per dimension + self.token = [] + for dimension in range(self.vector.shape[0]): + self.token.append("<" + self.name + "_" + str(dimension) + ">") + self.name = "<" + self.name + ">" + else: + # Single dimension vector, so the token is the name + self.name = "<" + self.name + ">" + self.token = self.name + + # Extra info + self.step = step + self.shape = self.vector.shape + self.vectors = 0 + self.cached_checksum = None + self.sd_checkpoint = None + self.sd_checkpoint_name = None + self.optimizer_state_dict = None + self.filename = self.name + ".pt" + + def save(self, filename): + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + "sd_checkpoint": self.sd_checkpoint, + "sd_checkpoint_name": self.sd_checkpoint_name, + } + + print("If I could save, I'd save this:\n", embedding_data) + + +def injectTokens(prompt, embeddings): + """ + This code searches the given prompt for any of the given embeddings and replaces it with the proper embedding for the tokenizer + + Only necessary for multi-vector embeddings because we've split the token up per vector. + + For example, if we have a multi-vector embedding like this: + Token: + Vectors: (3,768) + + Then, in the creation of the embedding class, we've automatically created the actual token for the tokenizer: + Token: + Vecors: (3,768) + + So, if we find a multi-vector token, then we replace it's user-friendly token name with the actual token name. 
For example: + + Prompt: A picture of , painted by Caravaggio + + becomes + + Prompt: A picture of , painted by Caravaggio + """ + prompt = prompt.lower() + foundTokens = 0 + + for embedding in embeddings: + if embedding.name in prompt: + # First, let's prepare the replacement tokens + replacementToken = "" + if type(embedding.token) is str: + replacementToken = embedding.name + elif type(embedding.token) is list: + for token in embedding.token: + replacementToken = replacementToken + " " + token + foundTokens += 1 + prompt = prompt.replace(embedding.name, replacementToken) + + keras_print("...found", foundTokens, "text embedding token(s)...") + + return prompt + + +def loadTextEmbedding(textEmbeddings): + """ + Using pytorch, we load in the text embedding weights as numpy arrays and store them in the Embeddings object class + + textEmbeddings REQUIRES a list expecting the first index to have the file path. For example: + + ['models/embeddings/','myEmbedding.pt','myOtherEmbedding.bin','etc.pt'] + + The code then seperates the file path as a variable and uses it to find the embeddings + """ + finalTextEmbeddings = [] + tokensToAdd = [] + # save file path into seperate location + embeddingsPath = textEmbeddings[0] + # delete file path from list + del textEmbeddings[0] + + for textEmbedding in textEmbeddings: + print("\nLoading text embedding " + textEmbedding) + # Load the text embedding file + textEmbeddingFile = torch.load( + embeddingsPath + textEmbedding, map_location="cpu" + ) + + # Debug Info + # print("Data for",textEmbedding,"\n",textEmbeddingFile) + # print(textEmbeddingFile.keys()) # Shows the entire file data, which should be a dictionary + + if "pt" in textEmbedding: + # load the necessary values + stringToToken = textEmbeddingFile[ + "string_to_token" + ] # Token assigned to vector + stringToParam = textEmbeddingFile["string_to_param"] # The vector(s) + textEmbeddingName = textEmbedding.replace(".pt", "") + elif "bin" in textEmbedding: + # load the necessary values + for key, value in textEmbeddingFile.items(): + stringToToken = key # Token assigned to vector + stringToParam = value # The vector + textEmbeddingName = textEmbedding.replace(".bin", "") + + # Save the token for finding the vector + if type(stringToToken) is dict: + token = list(stringToToken.keys())[ + 0 + ] # Convert dictionary to a list and then pull the first value + else: + token = stringToToken + + # Save the vector by finding it with the token + if type(stringToToken) is dict: + textEmbeddingVector = stringToParam[token] + else: + textEmbeddingVector = stringToParam + + # Debug info + # print("Weight type:\n",type(textEmbeddingVector)) + # print("Vector shape:\n", textEmbeddingVector.shape) + + # Make the token lowercase + token = textEmbeddingName.lower() + print("Unique Token: ", "<" + token + ">") + + embedding = Embedding(name=token, vector=textEmbeddingVector.detach().numpy()) + try: + embedding.step = textEmbeddingFile["step"] + embedding.sd_checkpoint_name = textEmbeddingFile["sd_checkpoint_name"] + except Exception as e: + embedding.step = 0 + embedding.sd_checkpoint_name = "N/A" + + finalTextEmbeddings.append(embedding) + + if type(embedding.token) is str: + tokensToAdd.append(embedding.token) + elif type(embedding.token) is list: + tokensToAdd.extend(embedding.token) + + # Memory Clean up + del textEmbeddingFile + + # add file path back to list for re-compiling later, if needed + textEmbeddings.insert(0, embeddingsPath) + + return finalTextEmbeddings, tokensToAdd + + +def 
loadTextEmbeddingWeight(textEncoder, CLIP, maxTextLength, embeddings, legacy): + """ + This code is where the magic happens with Text Embeddings. + We're going to add our text embeddings to the Text Encoder Model + """ + keras_print("\nLoading Text Embedding weights...") + + if legacy == True: + columnLength = 768 + else: + columnLength = 1024 + + # First get the current weights of the text encoder + originalWeights = textEncoder.get_weights() + + # Find the "token_embedding" weights + updatedWeights = originalWeights[0] + successfulTokenCount = 0 + + # Add our token vectors to the "token_embedding" weights + for embedding in embeddings: + if np.size(embedding.vector[0]) != columnLength: + # if our vector column length doesn't match our version of stable diffusion, then skip this embedding + print( + embedding.name, + "not compatible with current version of Stable Diffusion", + ) + continue + + # Add our vectors to the weights for the "token_embeddings" + updatedWeights = np.vstack((updatedWeights, embedding.vector)) + + # Update our token count, taking multidimensional vectors into account + if type(embedding.token) is list: + successfulTokenCount += len(embedding.token) + else: + successfulTokenCount += 1 + + keras_print( + "...found all compatible embeddings, total:", successfulTokenCount, "..." + ) + + # Create new Text Encoder model, increasing the size of tokens for the CLIP model + keras_print("...creating new text encoder model with embeddings") + input_word_ids = keras.layers.Input(shape=(maxTextLength,), dtype="int32") + input_pos_ids = keras.layers.Input(shape=(maxTextLength,), dtype="int32") + embeds = CLIP(vocabularySize=49408 + successfulTokenCount)( + [input_word_ids, input_pos_ids] + ) + textEncoder = keras.models.Model([input_word_ids, input_pos_ids], embeds) + keras_print( + "...created text encoder model with", successfulTokenCount, "token(s) added" + ) + + # Update the weights for "token_embedding" and then set the weights of the model + keras_print("...setting updated weights for token_embedding...") + originalWeights[0] = updatedWeights + textEncoder.set_weights(originalWeights) + keras_print("...weights loaded!") + + return textEncoder diff --git a/keras_hub/src/models/control_net/tools/tools.py b/keras_hub/src/models/control_net/tools/tools.py new file mode 100644 index 0000000000..9ae797f4e4 --- /dev/null +++ b/keras_hub/src/models/control_net/tools/tools.py @@ -0,0 +1,10 @@ +from keras_hub.src.models.control_net.utils import keras_print + + +def getWeightsAndNames(model): + # For finding the order of weights + names = [weight.name for layer in model.layers for weight in layer.weights] + weights = model.get_weights() + + for name, weight in zip(names, weights): + keras_print(name, "\n", weight.shape) diff --git a/keras_hub/src/models/control_net/utils.py b/keras_hub/src/models/control_net/utils.py new file mode 100644 index 0000000000..c8d1956aaf --- /dev/null +++ b/keras_hub/src/models/control_net/utils.py @@ -0,0 +1,25 @@ +from keras.api import backend + + +# redundant once https://github.com/keras-team/keras/issues/21137 is addressed +def keras_print(*args, **kwargs): + back_end = backend.backend() + if back_end == "tensorflow": + import tensorflow as tf + + return tf.print(*args, **kwargs) + elif back_end == "jax": + import jax.debug + + return jax.debug.print(*args, **kwargs) + else: + return print(*args, **kwargs) + # print_fn = {"jax": jax.debug.print, + # "tensorflow": keras_print}.get(backend, print) + # "torch" 
https://pytorch.org/docs/stable/generated/torch.set_printoptions.html ? + # "openvino" + # "numpy" + # return print_fn(*args, **kwargs) + + +__all__ = ["keras_print"] diff --git a/requirements-common.txt b/requirements-common.txt index da331b567a..5d7dae073d 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -18,3 +18,5 @@ sentencepiece tensorflow-datasets safetensors pillow +# Optional deps for control_net models +ftfy
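For SD 2.x checkpoints, both `loadWeightsFromPytorchCKPT` and `loadWeightsFromSafeTensor` slice the text encoder's fused `in_proj` attention weights into separate query/key/value kernels before calling `set_weights`. A standalone sketch of that conversion (the helper name is made up; `width=1024` is the SD 2.x hidden size used by the slicing above, and the returned order mirrors the weight-then-bias q/k/v order the loaders append):

```python
import numpy as np


def split_in_proj(in_proj_weight, in_proj_bias, width=1024):
    """Split a fused (3 * width, width) projection into q/k/v Dense weights."""
    weights = []
    for i in range(3):  # 0 = query, 1 = key, 2 = value
        kernel = in_proj_weight[i * width : (i + 1) * width, :]
        # PyTorch Linear kernels are (out, in); Keras Dense expects (in, out).
        weights.append(np.transpose(kernel, (1, 0)))
        weights.append(in_proj_bias[i * width : (i + 1) * width])
    return weights  # [q_w, q_b, k_w, k_b, v_w, v_b]
```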
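`injectTokens` in `tools/textEmbeddings.py` only rewrites multi-vector embeddings: an embedding loaded with an `(n, 768)` (or `(n, 1024)`) vector exposes per-dimension tokens `<name_0> ... <name_n-1>`, and the user-facing `<name>` in the prompt is expanded to that sequence before tokenization. A hypothetical illustration (the embedding file name is invented):

```python
# Suppose "myStyle.pt" carries a (3, 768) vector. After loadTextEmbedding the
# Embedding object has name "<mystyle>" and tokens
# ["<mystyle_0>", "<mystyle_1>", "<mystyle_2>"].
prompt = "a portrait of <mystyle>, oil on canvas"
# injectTokens(prompt, embeddings) then returns (modulo spacing):
#   "a portrait of <mystyle_0> <mystyle_1> <mystyle_2>, oil on canvas"
```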