Commit a71f631

added efficient MoE avoiding redundant reads in prefill for down weights
Signed-off-by: Onkar Chougule <[email protected]>
1 parent 2997d24 commit a71f631

1 file changed (+43, −44 lines)

QEfficient/transformers/models/grok_1/modeling_grok1.py

Lines changed: 43 additions & 44 deletions
@@ -103,53 +103,52 @@ def forward(
 
 
 class QEffGrok1MoeBlock(nn.Module):
-    def forward(self, hidden: torch.Tensor):
-        B, S, H = hidden.shape
-        T = B * S
-        x = hidden.view(T, H)
-
-        router_logits = self.gate(x)  # [T, E]
-        prob = F.softmax(router_logits, -1, dtype=torch.float)
-        top_w, top_i = torch.topk(prob, self.top_k, -1)
-        # if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-        #     top_w /= top_w.sum(-1, keepdim=True)
-        top_w = top_w.to(x.dtype)
-
-        # Create 2 expert idx based on the topk
-        expert1_idx, expert2_idx = top_i[:, 0], top_i[:, 1]  # [T]
-        weight1, weight2 = top_w[:, 0], top_w[:, 1]  # [T]
-
-        # I = self.config.ffn_dim
-        I = 32768  # TODO: Find a way to identify from config # Intermediate Size
-        upgate1 = x.new_zeros((T, I))
-        upgate2 = x.new_zeros((T, I))
-        expert_out1 = x.new_zeros((T, H))
-        expert_out2 = x.new_zeros((T, H))
-
-        for e in range(self.num_experts):
-            exp = self.experts[e]
-            mask1 = (expert1_idx == e).unsqueeze(1)  # [T, 1]
-            mask2 = (expert2_idx == e).unsqueeze(1)  # [T, 1]
-
-            hidden_gate = (exp.act_fn(exp.linear_v(x))) * exp.linear(x)
-            # Accumulate weighted contributions
-            upgate1 += torch.where(mask1, hidden_gate, torch.zeros_like(upgate1))
-            upgate2 += torch.where(mask2, hidden_gate, torch.zeros_like(upgate2))
-
-        for e in range(self.num_experts):
-            exp = self.experts[e]
-            mask1 = (expert1_idx == e).unsqueeze(1)
-            mask2 = (expert2_idx == e).unsqueeze(1)
-
-            expert_out1 += torch.where(
-                mask1, exp.linear_1(upgate1) * weight1.unsqueeze(1), torch.zeros_like(expert_out1)
+    def forward(self, hidden_states: torch.Tensor):
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        # Creating experts mask and routing weights masked
+        awesome_experts_mask_1 = (
+            torch.nn.functional.one_hot(selected_experts[:, 0], num_classes=self.num_experts).bool().T.unsqueeze(-1)
+        )
+        awesome_experts_mask_2 = (
+            torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.num_experts).bool().T.unsqueeze(-1)
+        )
+
+        gateupout1 = torch.zeros(hidden_states.shape[0], 32768)  # T, hs
+        gateupout2 = torch.zeros(hidden_states.shape[0], 32768)  # T, hs
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            current_expert_output = expert_layer.act_fn(expert_layer.linear(hidden_states)) * expert_layer.linear_v(
+                hidden_states
+            )
+            gateupout1 += torch.where(
+                awesome_experts_mask_1[expert_idx], current_expert_output, torch.zeros_like(gateupout1)
             )
-            expert_out2 += torch.where(
-                mask2, exp.linear_1(upgate2) * weight2.unsqueeze(1), torch.zeros_like(expert_out2)
+            gateupout2 += torch.where(
+                awesome_experts_mask_2[expert_idx], current_expert_output, torch.zeros_like(gateupout2)
             )
 
-        expert_out = expert_out1 + expert_out2
-        return expert_out.view(B, S, H), router_logits
+        downout1 = torch.zeros_like(hidden_states)
+        downout2 = torch.zeros_like(hidden_states)
+        concat_mask = torch.cat((awesome_experts_mask_1.unsqueeze(0), awesome_experts_mask_2.unsqueeze(0)), dim=0)
+        concat_down = torch.cat((downout1.unsqueeze(0), downout2.unsqueeze(0)), dim=0)
+        concat_gateout = torch.cat((gateupout1.unsqueeze(0), gateupout2.unsqueeze(0)), dim=0)
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            concat_down += torch.where(
+                concat_mask[:, expert_idx, :], expert_layer.linear_1(concat_gateout), torch.zeros_like(concat_down)
+            )
+
+        downout1, downout2 = concat_down[0], concat_down[1]
+        hidden_states = (
+            downout1 * routing_weights[:, 0].unsqueeze(-1) + downout2 * routing_weights[:, 1].unsqueeze(-1)
+        ).reshape(batch_size, sequence_length, hidden_dim)
+
+        return hidden_states, router_logits
 
 
 class QEffGrok1DecoderLayer(nn.Module):
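
The core change is in the down-projection loop: instead of applying each expert's down projection (linear_1) separately to the top-1 and top-2 activations, the two routing paths are stacked along a leading axis and projected in one call, so each expert's down-projection weights are read once per layer during prefill rather than twice. The sketch below illustrates that equivalence in isolation; it is a minimal standalone example with toy sizes, and the down_proj name is illustrative rather than taken from the repository (the committed code hard-codes an intermediate size of 32768 for Grok-1).

import torch
import torch.nn as nn

# Toy sizes for illustration only; the diff uses intermediate_size = 32768.
num_tokens, hidden_dim, intermediate_size = 8, 16, 32

# Stand-in for one expert's down projection (expert_layer.linear_1 in the diff).
down_proj = nn.Linear(intermediate_size, hidden_dim, bias=False)

# Gate/up outputs for the top-1 and top-2 routing choices (gateupout1 / gateupout2 in the diff).
gateupout1 = torch.randn(num_tokens, intermediate_size)
gateupout2 = torch.randn(num_tokens, intermediate_size)

# Previous pattern: two separate matmuls, so the down-projection weights are read twice per expert.
out_two_passes = torch.stack((down_proj(gateupout1), down_proj(gateupout2)), dim=0)

# New pattern: stack both paths along a leading axis; nn.Linear broadcasts over leading
# dimensions, so a single call (one weight read) covers both routing paths.
concat_gateout = torch.cat((gateupout1.unsqueeze(0), gateupout2.unsqueeze(0)), dim=0)  # [2, T, I]
out_one_pass = down_proj(concat_gateout)  # [2, T, H]

assert torch.allclose(out_two_passes, out_one_pass, atol=1e-6)
print("single stacked down projection matches two separate projections")

In the committed forward pass the same stacking is combined with per-expert torch.where masks (concat_mask), so expert_layer.linear_1 is invoked once on concat_gateout and the masked result is accumulated into both routing slots, replacing the separate exp.linear_1(upgate1) and exp.linear_1(upgate2) calls of the previous version; the top-k routing weights are then applied once at the end instead of inside the expert loop.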
