
Commit 2997d24

Updated latest MOE changes
Signed-off-by: Amit Raj <[email protected]>
1 parent: 88e8bd9

File tree

1 file changed (+47, -21 lines)

QEfficient/transformers/models/grok_1/modeling_grok1.py

@@ -103,27 +103,53 @@ def forward(
 
 
 class QEffGrok1MoeBlock(nn.Module):
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        B, S, D = hidden_states.shape  # [1, 8, 2304]
-        hidden_states = hidden_states.reshape(-1, D)  # [8, 2304]
-        T = hidden_states.size(0)  # 8 tokens
-        router_logits = self.gate(hidden_states)  # [8, 8]
-        probs = F.softmax(router_logits, dim=-1)  # [8, 8]
-
-        topk_scores, topk_indices = torch.topk(probs, self.top_k, dim=-1)  # [8, top_k]; top_k is 2 for Grok-1
-        topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # normalize per token
-        topk_scores = topk_scores.to(hidden_states.dtype)  # [8, top_k]
-        route = torch.zeros((T, self.num_experts), dtype=hidden_states.dtype)
-        route.scatter_(1, topk_indices, topk_scores)  # [8, num_experts]
-        final_output = torch.zeros_like(hidden_states)  # [8, 2304]
-
-        for e, expert in enumerate(self.experts):
-            scores = route[:, e].unsqueeze(1)  # [8, 1]
-            masked_out = torch.where(
-                scores > 0, expert(hidden_states) * scores, 0.0
-            )  # [8, 2304] × [8, 1] → [8, 2304]
-            final_output += masked_out  # accumulate expert outputs
-        return final_output.reshape(B, S, D), router_logits  # ([1, 8, 2304], [8, num_experts])
+    def forward(self, hidden: torch.Tensor):
+        B, S, H = hidden.shape
+        T = B * S
+        x = hidden.view(T, H)
+
+        router_logits = self.gate(x)  # [T, E]
+        prob = F.softmax(router_logits, -1, dtype=torch.float)
+        top_w, top_i = torch.topk(prob, self.top_k, -1)
+        # if self.norm_topk_prob:  # only diff with Mixtral sparse MoE block!
+        #     top_w /= top_w.sum(-1, keepdim=True)
+        top_w = top_w.to(x.dtype)
+
+        # Create the two per-token expert indices from the top-k result (top_k == 2 for Grok-1)
+        expert1_idx, expert2_idx = top_i[:, 0], top_i[:, 1]  # [T]
+        weight1, weight2 = top_w[:, 0], top_w[:, 1]  # [T]
+
+        # I = self.config.ffn_dim
+        I = 32768  # TODO: Find a way to identify from config  # intermediate size
+        upgate1 = x.new_zeros((T, I))
+        upgate2 = x.new_zeros((T, I))
+        expert_out1 = x.new_zeros((T, H))
+        expert_out2 = x.new_zeros((T, H))
+
+        for e in range(self.num_experts):
+            exp = self.experts[e]
+            mask1 = (expert1_idx == e).unsqueeze(1)  # [T, 1]
+            mask2 = (expert2_idx == e).unsqueeze(1)  # [T, 1]
+
+            hidden_gate = exp.act_fn(exp.linear_v(x)) * exp.linear(x)
+            # Accumulate weighted contributions
+            upgate1 += torch.where(mask1, hidden_gate, torch.zeros_like(upgate1))
+            upgate2 += torch.where(mask2, hidden_gate, torch.zeros_like(upgate2))
+
+        for e in range(self.num_experts):
+            exp = self.experts[e]
+            mask1 = (expert1_idx == e).unsqueeze(1)
+            mask2 = (expert2_idx == e).unsqueeze(1)
+
+            expert_out1 += torch.where(
+                mask1, exp.linear_1(upgate1) * weight1.unsqueeze(1), torch.zeros_like(expert_out1)
+            )
+            expert_out2 += torch.where(
+                mask2, exp.linear_1(upgate2) * weight2.unsqueeze(1), torch.zeros_like(expert_out2)
+            )
+
+        expert_out = expert_out1 + expert_out2
+        return expert_out.view(B, S, H), router_logits
 
 
 class QEffGrok1DecoderLayer(nn.Module):
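
For readers skimming the diff: the old implementation built a dense [T, num_experts] routing table with scatter_ and ran each expert over all tokens once, while the new one keeps the same fixed-shape, gather-free idea but unrolls the two top-k slots explicitly and splits each expert MLP into an up/gate pass and a down-projection pass. The sketch below is a minimal, self-contained illustration of that torch.where-based dispatch pattern, collapsed into a single loop for brevity. ToyExpert, ToyMoeBlock, and all dimensions are hypothetical stand-ins chosen for illustration; only the expert attribute names (linear, linear_v, linear_1, act_fn) are taken from the diff, and this is not QEfficient's actual class.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ToyExpert(nn.Module):
        def __init__(self, hidden, intermediate):
            super().__init__()
            self.linear = nn.Linear(hidden, intermediate, bias=False)    # up projection
            self.linear_v = nn.Linear(hidden, intermediate, bias=False)  # gate projection
            self.linear_1 = nn.Linear(intermediate, hidden, bias=False)  # down projection
            self.act_fn = nn.GELU()

    class ToyMoeBlock(nn.Module):
        def __init__(self, hidden=8, intermediate=16, num_experts=4, top_k=2):
            super().__init__()
            self.gate = nn.Linear(hidden, num_experts, bias=False)
            self.experts = nn.ModuleList(ToyExpert(hidden, intermediate) for _ in range(num_experts))
            self.top_k = top_k

        def forward(self, hidden_states):
            B, S, H = hidden_states.shape
            x = hidden_states.view(B * S, H)
            prob = F.softmax(self.gate(x), -1, dtype=torch.float)
            top_w, top_i = torch.topk(prob, self.top_k, -1)
            top_w = top_w.to(x.dtype)
            out = torch.zeros_like(x)
            # Fixed-shape dispatch: every expert runs on every token; torch.where
            # zeroes the rows of tokens that were not routed to this expert.
            for k in range(self.top_k):
                idx, w = top_i[:, k], top_w[:, k].unsqueeze(1)  # [T], [T, 1]
                for e, exp in enumerate(self.experts):
                    mask = (idx == e).unsqueeze(1)              # [T, 1]
                    y = exp.linear_1(exp.act_fn(exp.linear_v(x)) * exp.linear(x))
                    out = out + torch.where(mask, y * w, torch.zeros_like(y))
            return out.view(B, S, H)

    x = torch.randn(1, 3, 8)
    print(ToyMoeBlock()(x).shape)  # torch.Size([1, 3, 8])

A plausible motivation for this style is export-friendliness: because every expert runs on every token and non-routed rows are merely zeroed, tensor shapes never depend on the router's output, which keeps the traced graph static at the cost of redundant FLOPs.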
