@@ -103,53 +103,52 @@ def forward(
 
 
 class QEffGrok1MoeBlock(nn.Module):
-    def forward(self, hidden: torch.Tensor):
-        B, S, H = hidden.shape
-        T = B * S
-        x = hidden.view(T, H)
-
-        router_logits = self.gate(x)  # [T, E]
-        prob = F.softmax(router_logits, -1, dtype=torch.float)
-        top_w, top_i = torch.topk(prob, self.top_k, -1)
-        # if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-        #     top_w /= top_w.sum(-1, keepdim=True)
-        top_w = top_w.to(x.dtype)
-
-        # Create 2 expert idx based on the topk
-        expert1_idx, expert2_idx = top_i[:, 0], top_i[:, 1]  # [T]
-        weight1, weight2 = top_w[:, 0], top_w[:, 1]  # [T]
-
-        # I = self.config.ffn_dim
-        I = 32768  # TODO: Find a way to identify from config  # Intermediate Size
-        upgate1 = x.new_zeros((T, I))
-        upgate2 = x.new_zeros((T, I))
-        expert_out1 = x.new_zeros((T, H))
-        expert_out2 = x.new_zeros((T, H))
-
-        for e in range(self.num_experts):
-            exp = self.experts[e]
-            mask1 = (expert1_idx == e).unsqueeze(1)  # [T, 1]
-            mask2 = (expert2_idx == e).unsqueeze(1)  # [T, 1]
-
-            hidden_gate = (exp.act_fn(exp.linear_v(x))) * exp.linear(x)
-            # Accumulate weighted contributions
-            upgate1 += torch.where(mask1, hidden_gate, torch.zeros_like(upgate1))
-            upgate2 += torch.where(mask2, hidden_gate, torch.zeros_like(upgate2))
-
-        for e in range(self.num_experts):
-            exp = self.experts[e]
-            mask1 = (expert1_idx == e).unsqueeze(1)
-            mask2 = (expert2_idx == e).unsqueeze(1)
-
-            expert_out1 += torch.where(
-                mask1, exp.linear_1(upgate1) * weight1.unsqueeze(1), torch.zeros_like(expert_out1)
-            )
-            expert_out2 += torch.where(
-                mask2, exp.linear_1(upgate2) * weight2.unsqueeze(1), torch.zeros_like(expert_out2)
-            )
-
-        expert_out = expert_out1 + expert_out2
-        return expert_out.view(B, S, H), router_logits
+    def forward(self, hidden_states: torch.Tensor):
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        # Creating experts mask and routing weights masked
+        awesome_experts_mask_1 = (
+            torch.nn.functional.one_hot(selected_experts[:, 0], num_classes=self.num_experts).bool().T.unsqueeze(-1)
+        )
+        awesome_experts_mask_2 = (
+            torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.num_experts).bool().T.unsqueeze(-1)
+        )
+
+        gateupout1 = torch.zeros(hidden_states.shape[0], 32768)  # T, hs
+        gateupout2 = torch.zeros(hidden_states.shape[0], 32768)  # T, hs
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            current_expert_output = expert_layer.act_fn(expert_layer.linear(hidden_states)) * expert_layer.linear_v(
+                hidden_states
+            )
+            gateupout1 += torch.where(
+                awesome_experts_mask_1[expert_idx], current_expert_output, torch.zeros_like(gateupout1)
+            )
+            gateupout2 += torch.where(
+                awesome_experts_mask_2[expert_idx], current_expert_output, torch.zeros_like(gateupout2)
+            )
+
+        downout1 = torch.zeros_like(hidden_states)
+        downout2 = torch.zeros_like(hidden_states)
+        concat_mask = torch.cat((awesome_experts_mask_1.unsqueeze(0), awesome_experts_mask_2.unsqueeze(0)), dim=0)
+        concat_down = torch.cat((downout1.unsqueeze(0), downout2.unsqueeze(0)), dim=0)
+        concat_gateout = torch.cat((gateupout1.unsqueeze(0), gateupout2.unsqueeze(0)), dim=0)
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            concat_down += torch.where(
+                concat_mask[:, expert_idx, :], expert_layer.linear_1(concat_gateout), torch.zeros_like(concat_down)
+            )
+
+        downout1, downout2 = concat_down[0], concat_down[1]
+        hidden_states = (
+            downout1 * routing_weights[:, 0].unsqueeze(-1) + downout2 * routing_weights[:, 1].unsqueeze(-1)
+        ).reshape(batch_size, sequence_length, hidden_dim)
+
+        return hidden_states, router_logits
 
 
 class QEffGrok1DecoderLayer(nn.Module):
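Note: both versions of the block route each token to two experts by building per-expert boolean masks and accumulating with torch.where, so every token flows through every expert and tensor shapes never depend on the routing decision (presumably to keep the graph static for export). The standalone sketch below illustrates that same mask-and-where top-2 pattern outside the model; ToyExpert, moe_forward, and the toy sizes are illustrative assumptions, not classes from this repository.

```python
# Minimal sketch of top-2 MoE routing via one-hot masks + torch.where
# (assumption: a stand-in ToyExpert with linear / linear_v / linear_1 projections).
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyExpert(nn.Module):
    def __init__(self, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, intermediate_dim, bias=False)    # gate projection
        self.linear_v = nn.Linear(hidden_dim, intermediate_dim, bias=False)  # up projection
        self.linear_1 = nn.Linear(intermediate_dim, hidden_dim, bias=False)  # down projection
        self.act_fn = nn.GELU()


def moe_forward(x: torch.Tensor, gate: nn.Linear, experts: nn.ModuleList) -> torch.Tensor:
    # x: [T, H] (tokens already flattened, as in the block above)
    logits = gate(x)                                        # [T, E]
    probs = F.softmax(logits, dim=-1, dtype=torch.float)
    weights, selected = torch.topk(probs, k=2, dim=-1)      # [T, 2]

    # Static boolean masks, one row per expert: [E, T, 1]
    mask1 = F.one_hot(selected[:, 0], num_classes=len(experts)).bool().T.unsqueeze(-1)
    mask2 = F.one_hot(selected[:, 1], num_classes=len(experts)).bool().T.unsqueeze(-1)

    out1 = torch.zeros_like(x)
    out2 = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        # Every token is pushed through every expert; torch.where keeps only the
        # rows routed to this expert, so shapes stay independent of routing.
        y = expert.linear_1(expert.act_fn(expert.linear(x)) * expert.linear_v(x))
        out1 = torch.where(mask1[e], y, out1)
        out2 = torch.where(mask2[e], y, out2)

    # Weight the two selected experts' outputs and sum them.
    return out1 * weights[:, 0:1].to(x.dtype) + out2 * weights[:, 1:2].to(x.dtype)


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden_dim, intermediate_dim, num_experts, num_tokens = 16, 32, 4, 6
    experts = nn.ModuleList(ToyExpert(hidden_dim, intermediate_dim) for _ in range(num_experts))
    gate = nn.Linear(hidden_dim, num_experts, bias=False)
    out = moe_forward(torch.randn(num_tokens, hidden_dim), gate, experts)
    print(out.shape)  # torch.Size([6, 16])
```

The trade-off, as in the block above, is extra compute (dense execution of all experts) in exchange for a routing-independent, export-friendly graph.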