 class QwenMoeMLP(keras.layers.Layer):
-    def __init__(self, intermediate_dim, hidden_dim, activation_fn="silu"):
+    def __init__(
+        self,
+        intermediate_dim,
+        hidden_dim,
+        activation_fn="silu",
+        layer_norm_epsilon=1e-5,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
         self.intermediate_dim = intermediate_dim
         self.hidden_dim = hidden_dim
         self.activation_fn = activation_fn
+        self.kernel_initializer = kernel_initializer
+        self.layer_norm_epsilon = layer_norm_epsilon

     def build(self, decoder_sequence_shape):
         # Feedforward layers.
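As a quick illustration of the widened constructor: the MLP can now be configured with an explicit epsilon and initializer, and standard Keras keyword arguments are forwarded through `**kwargs`. A minimal sketch; the dimensions, `name`, and `dtype` below are illustrative, not values from this change:

```python
# Hypothetical instantiation of the refactored QwenMoeMLP.
mlp = QwenMoeMLP(
    intermediate_dim=1408,
    hidden_dim=2048,
    activation_fn="silu",
    layer_norm_epsilon=1e-5,
    kernel_initializer="glorot_uniform",
    name="expert_mlp_0",  # reaches keras.layers.Layer via **kwargs
    dtype="bfloat16",     # likewise forwarded through **kwargs
)
```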
@@ -91,42 +102,59 @@ def __init__(
         num_experts,
         top_k,
         norm_topk_prob,
-        kernel_initializer,
+        kernel_initializer="glorot_uniform",
+        layer_norm_epsilon=1e-5,
+        **kwargs,
     ):
+        super().__init__(**kwargs)
         self.hidden_dim = hidden_dim
         self.moe_intermediate_dim = moe_intermediate_dim
         self.shared_expert_intermediate_dim = shared_expert_intermediate_dim
         self.num_experts = num_experts
         self.top_k = top_k
         self.norm_topk_prob = norm_topk_prob
         self.kernel_initializer = kernel_initializer
+        self.layer_norm_epsilon = layer_norm_epsilon

-    def build(self, input_shape):
-        self.gate_proj = keras.layers.Dense(
-            self.hidden_dim,
+    def build(self, decoder_sequence_shape):
+        self._sparse_feedforward_gate_dense = keras.layers.Dense(
+            self.num_experts,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             use_bias=False,
             dtype=self.dtype_policy,
-            name="sparse_block_gate_proj",
+            name="sparse_feedforward_gate_dense",
         )
+        self._sparse_feedforward_gate_dense.build(decoder_sequence_shape)

         self.experts = [
             QwenMoeMLP(
                 intermediate_dim=self.moe_intermediate_dim,
                 hidden_dim=self.hidden_dim,
+                kernel_initializer=self.kernel_initializer,
+                layer_norm_epsilon=self.layer_norm_epsilon,
             )
             for _ in range(self.num_experts)
         ]
-        self.shared_expert = QwenMoeMLP(
-            intermediate_dim=self.shared_expert_intermediate_dim
+        for expert in self.experts:
+            expert.build(decoder_sequence_shape)
+
+        self.shared_expert_dense = QwenMoeMLP(
+            intermediate_dim=self.shared_expert_intermediate_dim,
+            hidden_dim=self.hidden_dim,
+            kernel_initializer=self.kernel_initializer,
+            layer_norm_epsilon=self.layer_norm_epsilon,
         )
-        self.shared_expert_gate_proj = keras.layers.Dense(1, use_bias=False)
+        self.shared_expert_dense.build(decoder_sequence_shape)
+
+        self.shared_expert_gate_dense = keras.layers.Dense(1, use_bias=False)
+        self.shared_expert_gate_dense.build(decoder_sequence_shape)
+        self.built = True

     def call(self, hidden_states):
         batch_size, seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.reshape(-1, hidden_dim)

-        router_logits = self.gate_proj(hidden_states)
+        router_logits = self._sparse_feedforward_gate_dense(hidden_states)

         routing_weights = ops.softmax(router_logits, axis=1)
         routing_weights, selected_experts = ops.top_k(
@@ -175,7 +203,7 @@ def call(self, hidden_states):

         shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = (
-            ops.sigmoid(self.shared_expert_gate_proj(hidden_states))
+            ops.sigmoid(self.shared_expert_gate_dense(hidden_states))
             * shared_expert_output
         )

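For readers following the routing math in `QwenSparseMoeBlock.call`: the router softmaxes per-token expert scores, keeps the top-k, optionally re-normalizes them, and a 1-unit sigmoid gate scales the shared expert. A self-contained sketch with `keras.ops`; the shapes and the stand-in shared-expert output are illustrative, not lifted from this diff:

```python
import keras
from keras import ops

num_experts, top_k, norm_topk_prob = 4, 2, True

hidden_states = keras.random.normal((6, 8))            # (tokens, hidden_dim)
router_logits = keras.random.normal((6, num_experts))  # per-token expert scores

routing_weights = ops.softmax(router_logits, axis=1)
routing_weights, selected_experts = ops.top_k(routing_weights, k=top_k)
if norm_topk_prob:
    # Re-normalize so the selected experts' weights sum to 1 per token.
    routing_weights = routing_weights / ops.sum(
        routing_weights, axis=-1, keepdims=True
    )

# Shared-expert path: a 1-unit gate yields a per-token scalar in (0, 1)
# that scales the shared expert's output (raw hidden states stand in here).
shared_expert_gate_dense = keras.layers.Dense(1, use_bias=False)
shared_expert_output = (
    ops.sigmoid(shared_expert_gate_dense(hidden_states)) * hidden_states
)
```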
@@ -210,6 +238,7 @@ def __init__(
         sliding_window_size=4096,
         layer_index=0,
         mlp_only_layers=[],
+        output_router_logits=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -238,6 +267,7 @@ def __init__(
         self.top_k = top_k
         self.norm_topk_prob = norm_topk_prob
         self.decoder_sparse_step = decoder_sparse_step
+        self.output_router_logits = output_router_logits

         self.supports_masking = True

@@ -287,11 +317,20 @@ def build(self, decoder_sequence_shape):
                 norm_topk_prob=self.norm_topk_prob,
                 kernel_initializer=self.kernel_initializer,
             )
+            self.mlp.build(decoder_sequence_shape)
         else:
             self.mlp = QwenMoeMLP(
                 intermediate_dim=self.intermediate_dim,
                 hidden_dim=self.hidden_dim,
             )
+            self.mlp.build(decoder_sequence_shape)
+
+        self._feedforward_layernorm = QwenLayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layernorm",
+        )
+        self._feedforward_layernorm.build(decoder_sequence_shape)

         self.built = True

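The condition that picks the sparse MoE block over a plain `QwenMoeMLP` lies outside the lines shown in this hunk. A plausible sketch, assuming the upstream Qwen2-MoE convention built from `mlp_only_layers` and `decoder_sparse_step` (an assumption, not taken from this diff):

```python
def is_sparse_layer(layer_index, mlp_only_layers, decoder_sparse_step, num_experts):
    # Assumed selection rule (upstream Qwen2-MoE convention); the actual
    # branch condition is not visible in this hunk.
    return (
        layer_index not in mlp_only_layers
        and num_experts > 0
        and (layer_index + 1) % decoder_sparse_step == 0
    )
```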
@@ -301,7 +340,6 @@ def call(
         decoder_padding_mask=None,
         decoder_attention_mask=None,
         self_attention_cache=None,
-        output_router_logits=False,
         self_attention_cache_update_index=None,
         training=None,
     ):
@@ -364,7 +402,7 @@ def call(
         if self_attention_cache is not None:
             output += self_attention_cache

-        if output_router_logits:
+        if self.output_router_logits:
             output += (router_logits,)

         return output
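Because `output_router_logits` is now fixed at construction time rather than passed to each `call`, a caller opts in once when building the decoder layer (its class name is not shown in these hunks), and the tail of `call` simply appends the logits to the output tuple. A minimal sketch with illustrative shapes:

```python
import keras

output_router_logits = True                      # stored on the layer in __init__
hidden_output = keras.random.normal((2, 4, 8))   # (batch, seq_len, hidden_dim)
router_logits = keras.random.normal((2 * 4, 4))  # (batch * seq_len, num_experts)

output = (hidden_output,)
if output_router_logits:
    output += (router_logits,)
# output is (hidden_output,) or (hidden_output, router_logits)
```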