import keras
from keras import initializers
from keras import ops


class RoformerNorm(keras.layers.Layer):
    """A normalization layer for Roformer that implements RMS normalization."""

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        dim = input_shape[-1]
        self.scale = self.add_weight(
            name="scale",
            trainable=True,
            shape=(dim,),
            initializer="ones",
            dtype=self.variable_dtype,
        )
        self.built = True

    def call(self, x):
        # Compute the RMS statistic in float32 for numerical stability, then
        # cast the rescaled activations back to the layer's compute dtype.
        x = ops.cast(x, "float32")
        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
        x = x * ops.rsqrt(var + self.epsilon)
        return ops.cast(x * self.scale, self.compute_dtype)

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
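
# A minimal usage sketch for `RoformerNorm` (illustrative only; the input
# shape below is arbitrary). The layer rescales activations by their
# root mean square over the last axis, so the output shape matches the input:
#
#     norm = RoformerNorm(epsilon=1e-6)
#     y = norm(keras.random.normal((2, 16, 64)))  # -> shape (2, 16, 64)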

class RoformrPositionalEmbedding(keras.layers.Layer):
    """Native rotary positional embedding by Jianlin Su.

    Adapted from the original implementation at
    https://github.com/bojone/bert4keras
    """

    def __init__(self, output_dim, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.output_dim = output_dim

    def call(self, tensors):
        input_shape = ops.shape(tensors[0])
        seq_len = input_shape[1]
        position_ids = ops.arange(0, seq_len, dtype=tensors[0].dtype)[None]
        embeddings = self.sinusoidal_embeddings(
            position_ids, self.output_dim, self.max_wavelength
        )
        embeddings = ops.cast(embeddings, self.compute_dtype)

        ndim = ops.ndim(tensors[0])
        sinusoidal = self.align(embeddings, [0, 1, -1], ndim)
        cos_pos = ops.repeat(sinusoidal[..., 1::2], 2, -1)
        sin_pos = ops.repeat(sinusoidal[..., ::2], 2, -1)
        outputs = []
        for tensor in tensors:
            # Rotate channel pairs (x1, x2) -> (-x2, x1) and combine them with
            # the cosine/sine tables to apply the rotary transform.
            tensor2 = ops.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
            tensor2 = ops.reshape(tensor2, ops.shape(tensor))
            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
        return outputs[0] if len(outputs) == 1 else outputs

    def align(self, tensor, axes, ndim=None):
        # Expand `tensor` to `ndim` dimensions, keeping the given axes and
        # inserting size-1 dimensions elsewhere so it broadcasts correctly.
        ndim = ndim or max(axes) + 1
        indices = [None] * ndim
        for i in axes:
            indices[i] = slice(None)
        if keras.config.backend() == "jax":
            return tensor[tuple(indices)]
        return tensor[indices]

    def sinusoidal_embeddings(self, pos, dim, base=10000):
        if dim % 2 != 0:
            raise ValueError("Dimension must be even")

        indices = ops.arange(0, dim // 2, dtype="float32")
        indices = ops.power(ops.cast(base, dtype="float32"), -2 * indices / dim)
        embeddings = ops.einsum("...,d->...d", pos, indices)
        embeddings = ops.stack(
            [ops.sin(embeddings), ops.cos(embeddings)], axis=-1
        )
        shape = list(ops.shape(embeddings))
        embeddings = ops.reshape(embeddings, shape[:-2] + [-1])
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "output_dim": self.output_dim,
                "max_wavelength": self.max_wavelength,
            }
        )
        return config
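
# A minimal usage sketch for `RoformrPositionalEmbedding` (illustrative only;
# the shapes are assumptions). The layer applies the rotary transform to each
# tensor in the input list and returns tensors of unchanged shape:
#
#     rope = RoformrPositionalEmbedding(output_dim=64)
#     q = keras.random.normal((2, 16, 8, 64))  # (batch, seq, heads, head_size)
#     k = keras.random.normal((2, 16, 8, 64))
#     q_rot, k_rot = rope([q, k])  # both keep shape (2, 16, 8, 64)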


@keras.saving.register_keras_serializable(package="keras_hub")
class RoformerAttention(keras.layers.Layer):
    """Multi-head attention for RoformerV2.

    Adapted from the original implementation at
    https://github.com/bojone/bert4keras
    """

    def __init__(
        self,
        heads,
        head_size,
        out_dim=None,
        use_bias=False,
        max_wavelength=10000,
        kernel_initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.heads = heads
        self.head_size = head_size
        self.out_dim = out_dim or heads * head_size
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.max_wavelength = max_wavelength

    def build(self, input_shape):
        super().build(input_shape)
        self.q_dense = keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            name="q_dense_layer",
            dtype=self.dtype_policy,
        )
        self.q_dense.build(input_shape)

        self.k_dense = keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            name="k_dense_layer",
            dtype=self.dtype_policy,
        )
        self.k_dense.build(input_shape)

        self.v_dense = keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            name="v_dense_layer",
            dtype=self.dtype_policy,
        )
        self.v_dense.build(input_shape)

        self.o_dense = keras.layers.Dense(
            units=self.out_dim,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            name="o_dense_layer",
            dtype=self.dtype_policy,
        )
        self.o_dense.build([None, None, self.head_size * self.heads])

        self.rotary_embedding_layer = RoformrPositionalEmbedding(
            self.head_size, self.max_wavelength, dtype=self.dtype_policy
        )
        self.rotary_embedding_layer.build([])

    def call(self, x, attention_mask=None):
        qw = self.q_dense(x)
        kw = self.k_dense(x)
        vw = self.v_dense(x)

        b, s = ops.shape(qw)[:2]
        qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
        kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
        vw = ops.reshape(vw, (b, s, self.heads, self.head_size))

        qw, kw = self.rotary_embedding_layer([qw, kw])
        if keras.__version__ < "3.6":
            raise ValueError("Please make sure your Keras version is >=3.6.")
        flash_attention = keras.config.is_flash_attention_enabled()
        # Reshape the (batch, seq_len) padding mask to (batch, 1, seq_len, 1)
        # so it broadcasts inside `ops.dot_product_attention`.
        attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
        if keras.config.backend() == "torch":
            # On the torch backend, materialize the mask as
            # (batch, 1, seq_len, seq_len) so padded key positions are masked
            # for every query.
            attention_mask = ops.repeat(attention_mask, s, -1)
            attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
        o = ops.dot_product_attention(
            qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
        )

        return self.o_dense(ops.reshape(o, [b, s, -1]))

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "heads": self.heads,
                "head_size": self.head_size,
                "out_dim": self.out_dim,
                "use_bias": self.use_bias,
                "max_wavelength": self.max_wavelength,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer
                ),
            }
        )
        return config
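

# A minimal usage sketch for `RoformerAttention` (illustrative only; the
# shapes and mask are assumptions). With `out_dim` left at its default of
# `heads * head_size`, the output has the same shape as the input:
#
#     attn = RoformerAttention(heads=8, head_size=64)
#     x = keras.random.normal((2, 16, 512))
#     mask = ops.ones((2, 16), dtype="bool")  # True = attend, False = padding
#     y = attn(x, attention_mask=mask)  # -> shape (2, 16, 512)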