
Commit 050d032

Implement RoformerV2 (#2145)
* Implement RoformerV2
* Reduce one input parameter
* Fix some bugs
* Add test case and fix notes
* Pass test cases locally when the backend is tf
* Remove RMSNorm from keras
* Remove RMSNorm from keras
* Remove RMSNorm from keras
* Locally pass tests on the jax backend
* Renew tests
* Renew tests
* Skip test when Keras version is lower than 3.6
* Add version raise
* Fix
* Modify flash attn check
* Modify flash attn check
* Modify flash attn check
* Modify
* Fix bug on the torch backend
* Fix bug on the torch backend
1 parent 8a55e85 commit 050d032

19 files changed (+1465 −1 lines)

keras_hub/api/models/__init__.py (+18 lines)
@@ -323,6 +323,24 @@
     RobertaTextClassifierPreprocessor as RobertaPreprocessor,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.roformer_v2.roformer_v2_backbone import (
+    RoformerV2Backbone as RorformerV2Backbone,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import (
+    RoformerV2MaskedLM,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import (
+    RoformerV2MaskedLMPreprocessor,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import (
+    RorformerV2TextClassifier,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import (
+    RoformerV2TextClassifierPreprocessor,
+)
+from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
+    RoformerV2Tokenizer,
+)
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
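After this hunk, the new classes become part of KerasHub's public surface. A minimal import sketch, assuming `keras_hub/api/models` is what `keras_hub` exposes as `keras_hub.models`; it only references names added in this diff and keeps the "Rorformer" spelling exactly as exported:

# Hedged sketch: public imports made available by this change.
from keras_hub.models import (
    RorformerV2Backbone,  # exported under the alias spelled "Rorformer" in this commit
    RoformerV2MaskedLM,
    RoformerV2MaskedLMPreprocessor,
    RorformerV2TextClassifier,
    RoformerV2TextClassifierPreprocessor,
    RoformerV2Tokenizer,
)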

keras_hub/api/tokenizers/__init__.py (+3 lines)
@@ -35,6 +35,9 @@
     QwenTokenizer as Qwen2Tokenizer,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
+    RoformerV2Tokenizer,
+)
 from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
 from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer

keras_hub/src/models/roformer_v2/__init__.py (whitespace-only changes)
New file (+212 lines):
@@ -0,0 +1,212 @@
import keras
from keras import initializers
from keras import ops


class RoformerNorm(keras.layers.Layer):
    """A normalization layer for Roformer that implements RMS normalization."""

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        dim = input_shape[-1]
        self.scale = self.add_weight(
            name="scale",
            trainable=True,
            shape=(dim,),
            initializer="ones",
            dtype=self.variable_dtype,
        )
        self.built = True

    def call(self, x):
        x = ops.cast(x, "float32")
        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
        x = x * ops.rsqrt(var + self.epsilon)
        return ops.cast(x * self.scale, self.compute_dtype)

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
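For reference, RMS normalization rescales each feature vector by its root mean square and applies only a learned scale, with no mean-centering and no bias. A minimal sketch of the same computation on a made-up tensor using keras.ops (shapes and values below are illustrative, not from the model):

import keras
from keras import ops

# Toy input: batch of 2 vectors with hidden size 4 (illustrative values only).
x = ops.convert_to_tensor([[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]])

# Divide by the root mean square of each row; epsilon guards against division by zero.
epsilon = 1e-6
rms = ops.sqrt(ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + epsilon)
x_normed = x / rms

# RoformerNorm additionally multiplies by a learned per-feature `scale` vector,
# initialized to ones, so at initialization it matches this plain computation.
print(x_normed)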

class RoformrPositionalEmbedding(keras.layers.Layer):
    """Rotary positional embedding by Jianlin Su.

    Adapted from the native implementation:
    https://github.com/bojone/bert4keras
    """

    def __init__(self, output_dim, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.output_dim = output_dim

    def call(self, tensors):
        input_shape = ops.shape(tensors[0])
        seq_len = input_shape[1]
        position_ids = ops.arange(0, seq_len, dtype=tensors[0].dtype)[None]
        embeddings = self.sinusoidal_embeddings(
            position_ids, self.output_dim, self.max_wavelength
        )
        embeddings = ops.cast(embeddings, self.compute_dtype)

        ndim = ops.ndim(tensors[0])
        sinusoidal = self.align(embeddings, [0, 1, -1], ndim)
        cos_pos = ops.repeat(sinusoidal[..., 1::2], 2, -1)
        sin_pos = ops.repeat(sinusoidal[..., ::2], 2, -1)
        outputs = []
        for tensor in tensors:
            tensor2 = ops.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
            tensor2 = ops.reshape(tensor2, ops.shape(tensor))
            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
        return outputs[0] if len(outputs) == 1 else outputs

    def align(self, tensor, axes, ndim=None):
        ndim = ndim or max(axes) + 1
        indices = [None] * ndim
        for i in axes:
            indices[i] = slice(None)
        if keras.config.backend() == "jax":
            return tensor[tuple(indices)]
        return tensor[indices]

    def sinusoidal_embeddings(self, pos, dim, base=10000):
        if dim % 2 != 0:
            raise ValueError("Dimension must be even")

        indices = ops.arange(0, dim // 2, dtype="float32")
        indices = ops.power(ops.cast(base, dtype="float32"), -2 * indices / dim)
        embeddings = ops.einsum("...,d->...d", pos, indices)
        embeddings = ops.stack(
            [ops.sin(embeddings), ops.cos(embeddings)], axis=-1
        )
        shape = list(ops.shape(embeddings))
        embeddings = ops.reshape(embeddings, shape[:-2] + [-1])
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "output_dim": self.output_dim,
                "max_wavelength": self.max_wavelength,
            }
        )
        return config
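Rotary position embedding rotates every even/odd channel pair of the query and key by an angle proportional to the token position, so attention scores end up depending on relative offsets. A self-contained sketch of the same rotation as the cos_pos/sin_pos construction above, on a toy 2-D tensor (sequence length 3, head size 4; sizes are illustrative only):

import keras
from keras import ops

seq_len, dim = 3, 4  # toy sizes; in the layer, `dim` is the attention head size
pos = ops.expand_dims(ops.arange(0, seq_len, dtype="float32"), -1)  # (seq, 1)
inv_freq = ops.power(
    ops.cast(10000.0, "float32"),
    -2.0 * ops.arange(0, dim // 2, dtype="float32") / dim,
)
angles = pos * inv_freq  # (seq, dim // 2), one angle per frequency and position

# Expand each half-width angle table to full width, mirroring cos_pos / sin_pos.
cos_pos = ops.repeat(ops.cos(angles), 2, -1)  # (seq, dim)
sin_pos = ops.repeat(ops.sin(angles), 2, -1)  # (seq, dim)

x = keras.random.normal((seq_len, dim))  # toy query rows
# Pair up channels as (-x_odd, x_even), matching the `tensor2` trick above.
x2 = ops.reshape(ops.stack([-x[..., 1::2], x[..., ::2]], axis=-1), ops.shape(x))
x_rot = x * cos_pos + x2 * sin_pos  # rotated queries
print(ops.shape(x_rot))  # (3, 4)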

@keras.saving.register_keras_serializable(package="keras_hub")
class RoformerAttention(keras.layers.Layer):
    """Multi-head attention for RoformerV2.

    Modified from the native implementation:
    https://github.com/bojone/bert4keras
    """

    def __init__(
        self,
        heads,
        head_size,
        out_dim=None,
        use_bias=False,
        max_wavelength=10000,
        kernel_initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.heads = heads
        self.head_size = head_size
        self.out_dim = out_dim or heads * head_size
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.max_wavelength = max_wavelength

    def build(self, input_shape):
        super().build(input_shape)
        self.q_dense = keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            name="q_dense_layer",
            dtype=self.dtype_policy,
        )
        self.q_dense.build(input_shape)

        self.k_dense = keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            name="k_dense_layer",
            dtype=self.dtype_policy,
        )
        self.k_dense.build(input_shape)

        self.v_dense = keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            name="v_dense_layer",
            dtype=self.dtype_policy,
        )
        self.v_dense.build(input_shape)

        self.o_dense = keras.layers.Dense(
            units=self.out_dim,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            name="o_dense_layer",
            dtype=self.dtype_policy,
        )
        self.o_dense.build([None, None, self.head_size * self.heads])

        self.rotary_embedding_layer = RoformrPositionalEmbedding(
            self.head_size, self.max_wavelength, dtype=self.dtype_policy
        )
        self.rotary_embedding_layer.build([])

    def call(self, x, attention_mask=None):
        qw = self.q_dense(x)
        kw = self.k_dense(x)
        vw = self.v_dense(x)

        b, s = ops.shape(qw)[:2]
        qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
        kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
        vw = ops.reshape(vw, (b, s, self.heads, self.head_size))

        qw, kw = self.rotary_embedding_layer([qw, kw])
        if keras.__version__ < "3.6":
            raise ValueError("Please make sure your Keras version is >=3.6.")
        flash_attention = keras.config.is_flash_attention_enabled()
        attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
        if keras.config.backend() == "torch":
            attention_mask = ops.repeat(attention_mask, s, -1)
            attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
        o = ops.dot_product_attention(
            qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
        )

        return self.o_dense(ops.reshape(o, [b, s, -1]))

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "heads": self.heads,
                "head_size": self.head_size,
                "out_dim": self.out_dim,
                "use_bias": self.use_bias,
                "max_wavelength": self.max_wavelength,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer
                ),
            }
        )
        return config
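A hedged usage sketch of the attention layer defined above, with made-up sizes (4 heads of size 16, so a hidden width of 64); it assumes Keras >= 3.6, which the layer itself checks before calling `ops.dot_product_attention`:

import keras
from keras import ops

layer = RoformerAttention(heads=4, head_size=16)  # illustrative hyperparameters

x = keras.random.normal((2, 8, 64))            # (batch, seq_len, hidden)
padding_mask = ops.ones((2, 8), dtype="bool")  # toy batch with no padding

y = layer(x, attention_mask=padding_mask)
print(ops.shape(y))  # (2, 8, 64); out_dim defaults to heads * head_size

Note that `call` expects a per-token padding mask: it is reshaped to `(batch, 1, seq_len, 1)` so that it broadcasts over heads and query positions, and on the torch backend it is expanded to a dense `(batch, 1, seq_len, seq_len)` mask before the fused attention call.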
