
Commit 73a2d80

weight conversion wip
1 parent: 8b0ef49

11 files changed: +193, -74 lines


keras_hub/api/models/__init__.py (+1, -3)

```diff
@@ -292,9 +292,7 @@
     QwenTokenizer as Qwen2Tokenizer,
 )
 from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone
-from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import (
-    QwenmOECausalLM as QwenMoeCausalLM,
-)
+from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import QwenMoeCausalLM
 from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import (
     QwenMoeCausalLMPreprocessor,
 )
```

keras_hub/api/tokenizers/__init__.py (+1)

```diff
@@ -34,6 +34,7 @@
 from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as Qwen2Tokenizer,
 )
+from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import QwenMoeTokenizer
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
```

keras_hub/src/models/qwen/qwen_decoder.py (+2, -2)

```diff
@@ -7,8 +7,8 @@
 from keras_hub.src.layers.modeling.transformer_layer_utils import (
     merge_padding_and_attention_mask,
 )
-from keras_hub.src.models.qwen.qwen_attention import QwenAttention
 from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm
+from keras_hub.src.models.qwen_moe.qwen_moe_attention import QwenMoeAttention
 from keras_hub.src.utils.keras_utils import clone_initializer


@@ -79,7 +79,7 @@ def build(self, decoder_sequence_shape):
         self.hidden_dim = decoder_sequence_shape[-1]

         # Self attention layer.
-        self._self_attention_layer = QwenAttention(
+        self._self_attention_layer = QwenMoeAttention(
             num_query_heads=self.num_query_heads,
             num_key_value_heads=self.num_key_value_heads,
             rope_max_wavelength=self.rope_max_wavelength,
```

keras_hub/src/models/qwen_moe/README.md (+45, -1)

````diff
@@ -10,4 +10,48 @@ Immediate TODOs
 1. What about new caching mechanism in HF?


-Reference - https://huggingface.co/docs/transformers/en/model_doc/qwen2_moe
+Reference - https://huggingface.co/docs/transformers/en/model_doc/qwen2_moe
+
+Model Architecture:
+
+```
+Qwen2MoeForCausalLM(
+  (model): Qwen2MoeModel(
+    (embed_tokens): Embedding(151936, 2048)
+    (layers): ModuleList(
+      (0-23): 24 x Qwen2MoeDecoderLayer(
+        (self_attn): Qwen2MoeSdpaAttention(
+          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
+          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
+          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
+          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
+          (rotary_emb): Qwen2MoeRotaryEmbedding()
+        )
+        (mlp): Qwen2MoeSparseMoeBlock(
+          (gate): Linear(in_features=2048, out_features=60, bias=False)
+          (experts): ModuleList(
+            (0-59): 60 x Qwen2MoeMLP(
+              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
+              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
+              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
+              (act_fn): SiLU()
+            )
+          )
+          (shared_expert): Qwen2MoeMLP(
+            (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
+            (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
+            (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
+            (act_fn): SiLU()
+          )
+          (shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
+        )
+        (input_layernorm): Qwen2MoeRMSNorm((2048,), eps=1e-06)
+        (post_attention_layernorm): Qwen2MoeRMSNorm((2048,), eps=1e-06)
+      )
+    )
+    (norm): Qwen2MoeRMSNorm((2048,), eps=1e-06)
+    (rotary_emb): Qwen2MoeRotaryEmbedding()
+  )
+  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
+)
+```
````
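For reference, the module printout added to the README looks like the output of printing the Hugging Face model. A minimal sketch of how to regenerate it, assuming the `Qwen/Qwen1.5-MoE-A2.7B` checkpoint (the checkpoint is not named in this commit; its 2048 hidden size, 24 layers, and 60 experts appear to match the printout):

```python
# Sketch: reproduce the reference architecture printout from the HF side.
# The checkpoint name is an assumption (not stated in the commit); any
# Qwen2-MoE checkpoint with the same config would print the same tree.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
print(model)  # emits the Qwen2MoeForCausalLM module tree shown above
```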
keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py

```diff
@@ -1,7 +1,7 @@
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
-from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
 from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone
+from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import QwenMoeTokenizer


 @keras_hub_export(
@@ -11,7 +11,7 @@
 )
 class QwenMoeCausalLMPreprocessor(CausalLMPreprocessor):
     backbone_cls = QwenMoeBackbone
-    tokenizer_cls = QwenTokenizer
+    tokenizer_cls = QwenMoeTokenizer

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
```
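As a quick check of the wiring above (`backbone_cls` / `tokenizer_cls`), the preprocessor can be exercised end to end. A minimal sketch with a made-up toy vocabulary and merge list, not the real Qwen assets:

```python
# Sketch: wiring QwenMoeTokenizer into the causal-LM preprocessor.
# The toy vocabulary/merges are made up for illustration; real use would load
# the converted Qwen MoE vocab.json / merges.txt.
from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import (
    QwenMoeCausalLMPreprocessor,
)
from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import QwenMoeTokenizer

vocab = {"<|endoftext|>": 0, "a": 1, "b": 2, "ab": 3}
merges = ["a b"]  # single BPE merge rule: "a" + "b" -> "ab"

preprocessor = QwenMoeCausalLMPreprocessor(
    tokenizer=QwenMoeTokenizer(vocabulary=vocab, merges=merges),
    sequence_length=8,
)
x, y, sample_weight = preprocessor(["abab"])
print(x["token_ids"])  # merged token ids padded to sequence_length, EOS appended
```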

keras_hub/src/models/qwen_moe/qwen_moe_decoder.py (+51, -13)

```diff
@@ -13,10 +13,21 @@


 class QwenMoeMLP(keras.layers.Layer):
-    def __init__(self, intermediate_dim, hidden_dim, activation_fn="silu"):
+    def __init__(
+        self,
+        intermediate_dim,
+        hidden_dim,
+        activation_fn="silu",
+        layer_norm_epsilon=1e-5,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
         self.intermediate_dim = intermediate_dim
         self.hidden_dim = hidden_dim
         self.activation_fn = activation_fn
+        self.kernel_initializer = kernel_initializer
+        self.layer_norm_epsilon = layer_norm_epsilon

     def build(self, decoder_sequence_shape):
         # Feedforward layers.
@@ -91,42 +102,59 @@ def __init__(
         num_experts,
         top_k,
         norm_topk_prob,
-        kernel_initializer,
+        kernel_initializer="glorot_uniform",
+        layer_norm_epsilon=1e-5,
+        **kwargs,
     ):
+        super().__init__(**kwargs)
         self.hidden_dim = hidden_dim
         self.moe_intermediate_dim = moe_intermediate_dim
         self.shared_expert_intermediate_dim = shared_expert_intermediate_dim
         self.num_experts = num_experts
         self.top_k = top_k
         self.norm_topk_prob = norm_topk_prob
         self.kernel_initializer = kernel_initializer
+        self.layer_norm_epsilon = layer_norm_epsilon

-    def build(self, input_shape):
-        self.gate_proj = keras.layers.Dense(
-            self.hidden_dim,
+    def build(self, decoder_sequence_shape):
+        self._sparse_feedforward_gate_dense = keras.layers.Dense(
+            self.num_experts,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             use_bias=False,
             dtype=self.dtype_policy,
-            name="sparse_block_gate_proj",
+            name="sparse_feedforward_gate_dense",
         )
+        self._sparse_feedforward_gate_dense.build(decoder_sequence_shape)

         self.experts = [
             QwenMoeMLP(
                 intermediate_dim=self.moe_intermediate_dim,
                 hidden_dim=self.hidden_dim,
+                kernel_initializer=self.kernel_initializer,
+                layer_norm_epsilon=self.layer_norm_epsilon,
             )
             for _ in range(self.num_experts)
         ]
-        self.shared_expert = QwenMoeMLP(
-            intermediate_dim=self.shared_expert_intermediate_dim
+        for expert in self.experts:
+            expert.build(decoder_sequence_shape)
+
+        self.shared_expert_dense = QwenMoeMLP(
+            intermediate_dim=self.shared_expert_intermediate_dim,
+            hidden_dim=self.hidden_dim,
+            kernel_initializer=self.kernel_initializer,
+            layer_norm_epsilon=self.layer_norm_epsilon,
         )
-        self.shared_expert_gate_proj = keras.layers.Dense(1, use_bias=False)
+        self.shared_expert_dense.build(decoder_sequence_shape)
+
+        self.shared_expert_gate_dense = keras.layers.Dense(1, use_bias=False)
+        self.shared_expert_gate_dense.build(decoder_sequence_shape)
+        self.built = True

     def call(self, hidden_states):
         batch_size, seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.reshape(-1, hidden_dim)

-        router_logits = self.gate_proj(hidden_states)
+        router_logits = self._sparse_feedforward_gate_dense(hidden_states)

         routing_weights = ops.softmax(router_logits, axis=1)
         routing_weights, selected_experts = ops.top_k(
@@ -175,7 +203,7 @@ def call(self, hidden_states):

         shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = (
-            ops.sigmoid(self.shared_expert_gate_proj(hidden_states))
+            ops.sigmoid(self.shared_expert_gate_dense(hidden_states))
             * shared_expert_output
         )

@@ -210,6 +238,7 @@ def __init__(
         sliding_window_size=4096,
         layer_index=0,
         mlp_only_layers=[],
+        output_router_logits=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -238,6 +267,7 @@ def __init__(
         self.top_k = top_k
         self.norm_topk_prob = norm_topk_prob
         self.decoder_sparse_step = decoder_sparse_step
+        self.output_router_logits = output_router_logits

         self.supports_masking = True

@@ -287,11 +317,20 @@ def build(self, decoder_sequence_shape):
                 norm_topk_prob=self.norm_topk_prob,
                 kernel_initializer=self.kernel_initializer,
             )
+            self.mlp.build(decoder_sequence_shape)
         else:
             self.mlp = QwenMoeMLP(
                 intermediate_dim=self.intermediate_dim,
                 hidden_dim=self.hidden_dim,
             )
+            self.mlp.build(decoder_sequence_shape)
+
+        self._feedforward_layernorm = QwenLayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layernorm",
+        )
+        self._feedforward_layernorm.build(decoder_sequence_shape)

         self.built = True

@@ -301,7 +340,6 @@ def call(
         decoder_padding_mask=None,
         decoder_attention_mask=None,
         self_attention_cache=None,
-        output_router_logits=False,
         self_attention_cache_update_index=None,
         training=None,
     ):
@@ -364,7 +402,7 @@ def call(
         if self_attention_cache is not None:
             output += self_attention_cache

-        if output_router_logits:
+        if self.output_router_logits:
             output += (router_logits,)

         return output
```
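For orientation while the sparse block is still wip, here is a standalone sketch of the routing computation it is converging on: softmax over the router logits, top-k selection with optional renormalization (`norm_topk_prob`), a weighted sum of the selected expert outputs, plus a sigmoid-gated shared expert. This is reference math with placeholder weights and identity stand-ins for the expert MLPs, not the layer's implementation:

```python
# Reference sketch of top-k MoE routing with a shared expert (placeholder
# weights, identity stand-ins for expert MLPs; not the layer's code).
import numpy as np
from keras import ops

hidden_dim, num_experts, top_k = 8, 4, 2
tokens = ops.convert_to_tensor(np.random.rand(6, hidden_dim).astype("float32"))
w_router = ops.convert_to_tensor(
    np.random.rand(hidden_dim, num_experts).astype("float32")
)
w_shared_gate = ops.convert_to_tensor(
    np.random.rand(hidden_dim, 1).astype("float32")
)

# Router: one logit per expert, softmax, keep the top-k experts per token.
router_logits = ops.matmul(tokens, w_router)
routing_weights = ops.softmax(router_logits, axis=-1)
top_weights, top_indices = ops.top_k(routing_weights, k=top_k)
# With norm_topk_prob=True the kept weights are renormalized to sum to one.
top_weights = top_weights / ops.sum(top_weights, axis=-1, keepdims=True)

# Stand-in per-expert outputs, shape (num_tokens, num_experts, hidden_dim).
expert_outputs = ops.stack(
    [tokens * (i + 1.0) for i in range(num_experts)], axis=1
)
one_hot = ops.one_hot(top_indices, num_experts)           # (tokens, k, experts)
combine = ops.einsum("tk,tke->te", top_weights, one_hot)  # dense combine weights
sparse_out = ops.einsum("te,teh->th", combine, expert_outputs)

# Shared expert (identity stand-in) is always applied, scaled by a sigmoid gate.
shared_gate = ops.sigmoid(ops.matmul(tokens, w_shared_gate))
output = sparse_out + shared_gate * tokens
print(ops.shape(output))  # (6, 8)
```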
keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py (new file, +46)

```diff
@@ -0,0 +1,46 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export(
+    "keras_hub.tokenizers.QwenMoeTokenizer",
+)
+class QwenMoeTokenizer(BytePairTokenizer):
+    """Tokenizer for Qwen models.
+
+    This tokenizer implements byte-pair encoding (BPE) for Qwen models,
+    handling special tokens like BOS (beginning of sequence) and EOS (end of
+    sequence).
+
+    Args:
+        vocabulary: Dictionary mapping tokens to token IDs, or path to
+            vocabulary file.
+        merges: List of BPE merges, or path to merges file.
+        bos_token: Beginning of sequence token. Defaults to None.
+        eos_token: End of sequence token. Defaults to "<|endoftext|>".
+        misc_special_tokens: Set of additional special tokens. Defaults to
+            empty set.
+    """
+
+    backbone_cls = QwenMoeBackbone
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        # Add EOS token
+        eos_token = "<|endoftext|>"
+        self._add_special_token(eos_token, "end_token")
+
+        self.start_token_id = None
+        self.start_token = None
+        self.pad_token_id = 0
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            **kwargs,
+        )
```
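A small smoke test for the new tokenizer; the toy vocabulary and single merge rule below are made up for illustration, real use would load the converted Qwen vocab.json / merges.txt:

```python
# Smoke-test sketch for QwenMoeTokenizer with a toy vocabulary (illustrative
# only; not the real Qwen assets).
from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import QwenMoeTokenizer

vocab = {"<|endoftext|>": 0, "a": 1, "b": 2, "ab": 3}
merges = ["a b"]  # single BPE merge rule: "a" + "b" -> "ab"

tokenizer = QwenMoeTokenizer(vocabulary=vocab, merges=merges)
print(tokenizer.end_token_id)    # 0 -> id of "<|endoftext|>"
print(tokenizer.start_token_id)  # None -> Qwen uses no BOS token
print(tokenizer("abab"))         # [3, 3] once the merge rule applies
```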
