@@ -103,27 +103,53 @@ def forward(


class QEffGrok1MoeBlock(nn.Module):
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        B, S, D = hidden_states.shape  # [1, 8, 2304]
-        hidden_states = hidden_states.reshape(-1, D)  # [8, 2304]
-        T = hidden_states.size(0)  # 8 tokens
-        router_logits = self.gate(hidden_states)  # [8, 8]
-        probs = F.softmax(router_logits, dim=-1)  # [8, 8]
-
-        topk_scores, topk_indices = torch.topk(probs, self.top_k, dim=-1)  # [8, top_k]; top_k is 2 for Grok-1
-        topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # normalize per token
-        topk_scores = topk_scores.to(hidden_states.dtype)  # [8, top_k]
-        route = torch.zeros((T, self.num_experts), dtype=hidden_states.dtype)
-        route.scatter_(1, topk_indices, topk_scores)  # [8, num_experts]
-        final_output = torch.zeros_like(hidden_states)  # [8, 2304]
-
-        for e, expert in enumerate(self.experts):
-            scores = route[:, e].unsqueeze(1)  # [8, 1]
-            masked_out = torch.where(
-                scores > 0, expert(hidden_states) * scores, 0.0
-            )  # [8, 2304] × [8, 1] → [8, 2304]
-            final_output += masked_out  # accumulate expert outputs
-        return final_output.reshape(B, S, D), router_logits  # ([1, 8, 2304], [8, num_experts])
+    def forward(self, hidden: torch.Tensor):
+        B, S, H = hidden.shape
+        T = B * S
+        x = hidden.view(T, H)
+
+        router_logits = self.gate(x)  # [T, E]
+        prob = F.softmax(router_logits, -1, dtype=torch.float)
+        top_w, top_i = torch.topk(prob, self.top_k, -1)
+        # if self.norm_topk_prob:  # only diff with Mixtral sparse MoE block!
+        #     top_w /= top_w.sum(-1, keepdim=True)
+        top_w = top_w.to(x.dtype)
+
+        # Split the top-k routing into two fixed per-token expert indices (top_k is 2 for Grok-1)
+        expert1_idx, expert2_idx = top_i[:, 0], top_i[:, 1]  # [T]
+        weight1, weight2 = top_w[:, 0], top_w[:, 1]  # [T]
+
+        # I = self.config.ffn_dim
+        I = 32768  # TODO: find a way to identify this from the config (intermediate size)
+        upgate1 = x.new_zeros((T, I))
+        upgate2 = x.new_zeros((T, I))
+        expert_out1 = x.new_zeros((T, H))
+        expert_out2 = x.new_zeros((T, H))
+
+        # Run every expert on every token and mask, so all shapes stay static
+        for e in range(self.num_experts):
+            exp = self.experts[e]
+            mask1 = (expert1_idx == e).unsqueeze(1)  # [T, 1]
+            mask2 = (expert2_idx == e).unsqueeze(1)  # [T, 1]
+
+            hidden_gate = exp.act_fn(exp.linear_v(x)) * exp.linear(x)
+            # Accumulate weighted contributions
+            upgate1 += torch.where(mask1, hidden_gate, torch.zeros_like(upgate1))
+            upgate2 += torch.where(mask2, hidden_gate, torch.zeros_like(upgate2))
+
+        for e in range(self.num_experts):
+            exp = self.experts[e]
+            mask1 = (expert1_idx == e).unsqueeze(1)
+            mask2 = (expert2_idx == e).unsqueeze(1)
+
+            expert_out1 += torch.where(
+                mask1, exp.linear_1(upgate1) * weight1.unsqueeze(1), torch.zeros_like(expert_out1)
+            )
+            expert_out2 += torch.where(
+                mask2, exp.linear_1(upgate2) * weight2.unsqueeze(1), torch.zeros_like(expert_out2)
+            )
+
+        expert_out = expert_out1 + expert_out2
+        return expert_out.view(B, S, H), router_logits


class QEffGrok1DecoderLayer(nn.Module):
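For review purposes, here is a self-contained sketch checking that the fixed two-expert mask-and-accumulate path above reproduces direct per-token top-2 routing. It uses toy sizes and a hypothetical ToyExpert with the same linear / linear_v / linear_1 / act_fn layout; none of this is code from the PR, and the activation choice is an assumption.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T, H, I, E, K = 8, 16, 32, 4, 2  # tokens, hidden, intermediate, experts, top_k (toy sizes)

class ToyExpert(nn.Module):  # hypothetical stand-in for a Grok-1 expert
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(H, I, bias=False)    # up projection
        self.linear_v = nn.Linear(H, I, bias=False)  # gate projection
        self.linear_1 = nn.Linear(I, H, bias=False)  # down projection
        self.act_fn = nn.GELU()                      # assumed activation

experts = nn.ModuleList(ToyExpert() for _ in range(E))
gate = nn.Linear(H, E, bias=False)
x = torch.randn(T, H)

top_w, top_i = torch.topk(F.softmax(gate(x), -1), K, -1)  # [T, K]

# Reference: send each token through its two selected experts directly.
ref = torch.zeros(T, H)
for t in range(T):
    for k in range(K):
        exp = experts[int(top_i[t, k])]
        ref[t] += top_w[t, k] * exp.linear_1(exp.act_fn(exp.linear_v(x[t])) * exp.linear(x[t]))

# Fixed-shape path: mask-and-accumulate over all experts, as in the new block.
upgate1, upgate2 = x.new_zeros(T, I), x.new_zeros(T, I)
out1, out2 = x.new_zeros(T, H), x.new_zeros(T, H)
for e, exp in enumerate(experts):
    m1 = (top_i[:, 0] == e).unsqueeze(1)  # [T, 1]
    m2 = (top_i[:, 1] == e).unsqueeze(1)
    hg = exp.act_fn(exp.linear_v(x)) * exp.linear(x)
    upgate1 += torch.where(m1, hg, torch.zeros_like(upgate1))
    upgate2 += torch.where(m2, hg, torch.zeros_like(upgate2))
for e, exp in enumerate(experts):
    m1 = (top_i[:, 0] == e).unsqueeze(1)
    m2 = (top_i[:, 1] == e).unsqueeze(1)
    out1 += torch.where(m1, exp.linear_1(upgate1) * top_w[:, :1], torch.zeros_like(out1))
    out2 += torch.where(m2, exp.linear_1(upgate2) * top_w[:, 1:], torch.zeros_like(out2))

assert torch.allclose(ref, out1 + out2, atol=1e-5)
print("fixed-shape MoE path matches direct top-2 routing")

The mask-and-accumulate formulation trades redundant compute (every expert runs on every token) for fully static tensor shapes with no data-dependent gathers, which is the usual trade-off when preparing an MoE block for export.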