Skip to content

Commit f007978

Browse files
authored
【开源实习】MindSpore自定义RWKV算子开发(Python接口实现) (#1862)
1 parent 54161e2 commit f007978

File tree

4 files changed

+548
-0
lines changed

4 files changed

+548
-0
lines changed

llm/inference/rwkv6/main.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# coding=utf-8
2+
# Copyright 2024 BlinkDL, et al.
3+
# Copyright 2024 yuunnn-w, et al.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Modifications copyright 2024 [Huawei Technologies Co., Ltd]
12+
# Changes: Migrated to MindSpore interface
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
# ============================================================================
20+
import time
21+
import mindspore
22+
from mindspore import ops
23+
from mindnlp.experimental.rwkv6.modeling_rwkv6 import RWKV_RNN
24+
from mindnlp.experimental.rwkv6.tokenizer_rwkv6 import RWKV_TOKENIZER
25+
from mindnlp.experimental.rwkv6.sampler_rwkv6 import sample_logits
26+
27+
if __name__ == '__main__':
    # Run configuration handed to RWKV_RNN.
    args = {
        # Base name of the weight file; RWKV_RNN appends '.pth' when loading it.
        'MODEL_NAME': 'RWKV-x060-World-1B6-v2.1-20240328-ctx4096',
        # Vocabulary size.
        'vocab_size': 65536,
        'batch_size': 3
    }

    # Build the model and the tokenizer.
    print("Loading model and tokenizer...")
    model = RWKV_RNN(args)
    tokenizer = RWKV_TOKENIZER()
    print("Done.")

    # Prompt to continue and sampling settings.
    BATCH_SIZE = args['batch_size']
    initial_string = "User: 帮我用python写一个打印字符三角形的代码.\n\nAssistant: "
    TEMPERATURE = 2.5  # sampling temperature
    TOP_P = 0.1  # top-p sampling threshold
    LENGTH_PER_TRIAL = 50  # number of tokens to generate

    # Encode the same prompt once per batch row.
    prompt_ids = tokenizer.encode([initial_string] * BATCH_SIZE)
    token = mindspore.Tensor(prompt_ids, dtype=mindspore.int64)

    # Feed the prompt one position at a time; only the output of the last
    # position is used to sample the first generated token.
    for t in ops.unbind(token, dim=-1):
        out = model(t)
    token_sampled = sample_logits(out, TEMPERATURE, TOP_P).type_as(token)
    token = ops.cat((token, token_sampled.unsqueeze(1)), 1)

    # Timed generation loop: each step feeds back the previously sampled token.
    start_time = time.time()
    for _ in range(LENGTH_PER_TRIAL):
        out = model(token_sampled)
        token_sampled = sample_logits(out, TEMPERATURE, TOP_P).type_as(token)
        token = ops.cat((token, token_sampled.unsqueeze(1)), 1)
    end_time = time.time()

    # Print the decoded text of every batch row.
    decoded_sequences = tokenizer.decode(token.tolist())
    for i, seq in enumerate(decoded_sequences):
        print(f"Batch {i+1}: {seq}")

    # Throughput statistics for the timed loop only (the prompt pass is excluded).
    total_time = end_time - start_time
    tokens_generated = LENGTH_PER_TRIAL * BATCH_SIZE
    speed = tokens_generated / total_time
    speed_per_batch = speed / BATCH_SIZE
    print(f"\nTotal time: {total_time:.2f} seconds")
    print(f"Tokens generated: {tokens_generated}")
    print(f"Token generation speed: {speed:.2f} tokens/second")
    print(f"Token generation speed per batch: {speed_per_batch:.2f} tokens/second")
75+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# coding=utf-8
2+
# Copyright 2024 BlinkDL, et al.
3+
# Copyright 2024 yuunnn-w, et al.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Modifications copyright 2024 [Huawei Technologies Co., Ltd]
12+
# Changes: Migrated to MindSpore interface
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
# ============================================================================
20+
import mindspore
21+
import mindnlp
22+
import mindnlp.core.nn as nn
23+
import mindnlp.core.ops as ops
24+
from typing import Tuple
25+
26+
27+
class RWKV_BLOCK(nn.Module):
    """
    One block of the RWKV model: a time-mixing step followed by a channel-mixing step,
    each added back onto its input.

    The block keeps its own recurrent state in three attributes
    (``state_view_channel``, ``state_view_time_1``, ``state_view_time_2``) that are
    initialised from slices of ``state`` and replaced on every call, so calling the
    block is stateful.

    Args:
        block_w (dict): Weights of this block, keyed by name (e.g. 'ln1.weight').
        n_embd (int): Embedding dimension.
        n_head (int): Number of heads.
        state (mindspore.Tensor): Hidden state tensor. [Batch_size, State_size, N_embd]
        i (int): Index of this block; selects the block's rows inside ``state``.
    """
    def __init__(self, block_w: dict, n_embd: int, n_head: int, state: mindspore.Tensor, i: int):
        super().__init__()
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head

        # Row offsets of this block inside `state`. Each block owns
        # (2 + head_size) consecutive rows: one for channel mixing, one for
        # time mixing, and head_size rows for the time-mixing matrix state.
        i0 = (2 + self.head_size) * i + 0
        i1 = (2 + self.head_size) * i + 1
        i2 = (2 + self.head_size) * i + 2
        i3 = (2 + self.head_size) * (i + 1)

        # Initial per-block state taken from `state`.
        # NOTE(review): these are named "views", but the methods below rebind the
        # attributes instead of writing into `state`; whether the slices alias
        # `state` at all depends on the tensor library -- confirm.
        self.state_view_channel = state[:, i0]
        self.state_view_time_1 = state[:, i1]
        self.state_view_time_2 = state[:, i2: i3, :]

        # Layer norms; weight and bias are replaced by the loaded tensors.
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln1.weight = nn.Parameter(block_w['ln1.weight'])
        self.ln1.bias = nn.Parameter(block_w['ln1.bias'])
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln2.weight = nn.Parameter(block_w['ln2.weight'])
        self.ln2.bias = nn.Parameter(block_w['ln2.bias'])

        # Activation used for the gate in time mixing.
        self.silu = nn.SiLU()

        # Time-mixing (attention) parameters.
        self.att_time_maa_x = nn.Parameter(block_w['att.time_maa_x'])
        # The five mixing vectors are stacked in the order w, k, v, r, g, which is
        # the order time_mixing() unbinds them in.
        self.att_time_maa = nn.Parameter(ops.stack([block_w['att.time_maa_w'],
                                                    block_w['att.time_maa_k'],
                                                    block_w['att.time_maa_v'],
                                                    block_w['att.time_maa_r'],
                                                    block_w['att.time_maa_g']]))
        self.att_time_maa_w1 = nn.Parameter(block_w['att.time_maa_w1'])
        self.att_time_maa_w2 = nn.Parameter(block_w['att.time_maa_w2'])
        self.att_time_decay = nn.Parameter(block_w['att.time_decay'])
        self.att_time_decay_w1 = nn.Parameter(block_w['att.time_decay_w1'])
        self.att_time_decay_w2 = nn.Parameter(block_w['att.time_decay_w2'])
        self.att_time_faaaa = nn.Parameter(block_w['att.time_faaaa'])
        # Each Linear below is built with an (n_embd, n_embd) shape and its weight is
        # then replaced by the loaded tensor on the following line.
        self.att_receptance = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_receptance.weight = nn.Parameter(block_w['att.receptance.weight'])
        self.att_key = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_key.weight = nn.Parameter(block_w['att.key.weight'])
        self.att_value = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_value.weight = nn.Parameter(block_w['att.value.weight'])
        self.att_output = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_output.weight = nn.Parameter(block_w['att.output.weight'])
        self.att_gate = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_gate.weight = nn.Parameter(block_w['att.gate.weight'])
        self.att_group_norm = nn.GroupNorm(num_groups=n_head, num_channels=n_embd, eps=1e-5, affine=True)
        self.att_group_norm.weight = nn.Parameter(block_w['att.ln_x.weight'])
        self.att_group_norm.bias = nn.Parameter(block_w['att.ln_x.bias'])

        # Channel-mixing (feed-forward) parameters.
        # NOTE(review): the Linear layers are declared (n_embd, n_embd) but take
        # their weights from the checkpoint; the declared shape may not match the
        # loaded tensors -- confirm against the checkpoint.
        self.ffn_time_maa_k = nn.Parameter(block_w['ffn.time_maa_k'])
        self.ffn_time_maa_r = nn.Parameter(block_w['ffn.time_maa_r'])
        self.ffn_key = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.ffn_key.weight = nn.Parameter(block_w['ffn.key.weight'])
        self.ffn_receptance = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.ffn_receptance.weight = nn.Parameter(block_w['ffn.receptance.weight'])
        self.ffn_value = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.ffn_value.weight = nn.Parameter(block_w['ffn.value.weight'])

    def channel_mixing(self, x: mindspore.Tensor) -> mindspore.Tensor:
        """
        Channel-mixing step.

        Replaces ``self.state_view_channel`` with ``x`` as a side effect.

        Args:
            x (mindspore.Tensor): Input tensor, shape [Batch, N_embd].
        Returns:
            mindspore.Tensor: ``sigmoid(receptance) * value(relu(key) ** 2)``.
        """
        # Difference between the stored previous input and the current input.
        sx = self.state_view_channel - x
        # The current input becomes the stored input for the next call.
        self.state_view_channel = x
        xk = x + sx * self.ffn_time_maa_k
        xr = x + sx * self.ffn_time_maa_r
        r = nn.functional.sigmoid(self.ffn_receptance(xr))
        k = nn.functional.relu(self.ffn_key(xk)).pow(2)
        output = r * self.ffn_value(k)
        return output

    def time_mixing(self, x: mindspore.Tensor) -> mindspore.Tensor:
        """
        Time-mixing step.

        Replaces ``self.state_view_time_1`` and ``self.state_view_time_2`` as a
        side effect.

        Args:
            x (mindspore.Tensor): Input tensor, shape [Batch, N_embd].
        Returns:
            mindspore.Tensor: Result of ``att_output`` applied to the gated,
            group-normalised mixing output.
        """
        batch_size, H, S = x.shape[0], self.n_head, self.head_size

        # Difference between the stored previous input and the current input;
        # the current input is then stored for the next call.
        sx = (self.state_view_time_1 - x)
        self.state_view_time_1 = x

        # Input-dependent mixing offsets, reshaped to one vector per branch
        # (five branches: w, k, v, r, g).
        xxx = x + sx * self.att_time_maa_x
        xxx = ops.tanh(xxx @ self.att_time_maa_w1).view(batch_size, 5, 1, -1)
        xxx = ops.matmul(xxx, self.att_time_maa_w2).view(batch_size, 5, -1)

        # Mix the current and previous inputs separately for each branch.
        xw, xk, xv, xr, xg = ops.unbind(x.unsqueeze(1) + sx.unsqueeze(1) * (self.att_time_maa + xxx), dim=1)

        w = (self.att_time_decay + (ops.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2))

        # Turn w into a decay factor exp(-exp(w)), viewed as [Batch, H, S, 1].
        w = ops.exp(-ops.exp(w.view(batch_size, H, S, 1)))

        # Per-head receptance, key, value and the gate.
        r = self.att_receptance(xr).view(batch_size, H, 1, S)
        k = self.att_key(xk).view(batch_size, H, S, 1)
        v = self.att_value(xv).view(batch_size, H, 1, S)
        g = self.silu(self.att_gate(xg))

        # Read the output from the matrix state, then update the state.
        s = self.state_view_time_2.view(batch_size, H, S, S)
        # [B, H, S, 1] @ [B, H, 1, S]: per-head outer product of key and value.
        a = k @ v
        # The output uses the state from before this step, plus the current
        # contribution scaled by att_time_faaaa.
        x = r @ (self.att_time_faaaa * a + s)
        # New state: current contribution plus the decayed previous state.
        s = a + w * s
        self.state_view_time_2 = s.view(batch_size, S, -1)

        # Flatten the heads, apply group normalisation and the gate.
        x = self.att_group_norm(x.flatten(start_dim=1)) * g

        # Apply the output projection and return the result.
        return self.att_output(x)

    def forward(self, x: mindspore.Tensor) -> mindspore.Tensor:
        """
        Forward pass of the block.

        Args:
            x (mindspore.Tensor): Input tensor, shape [Batch, N_embd].
        Returns:
            mindspore.Tensor: ``x`` plus the time-mixing and channel-mixing
            results, each computed on a layer-normalised input.
        """
        x = x + self.time_mixing(self.ln1(x))
        x = x + self.channel_mixing(self.ln2(x))
        return x
176+
177+
178+
class RWKV_RNN(nn.Module):
    """
    RWKV model run as an RNN: one token per batch row goes in, one output row
    per batch row comes out.

    The constructor loads the weights from ``args['MODEL_NAME'] + '.pth'``,
    derives the layer count, head count and embedding size from the loaded
    tensors, and builds one RWKV_BLOCK per layer. The blocks keep recurrent
    state between calls, so successive calls to the model continue the same
    sequence.

    Args:
        args (dict): Settings. Reads 'MODEL_NAME' (weight file name without the
            '.pth' suffix), 'batch_size' and 'vocab_size'.
    """
    def __init__(self, args: dict):
        super().__init__()
        self.args = args
        self.set_train(False)

        # Load the weights.
        w = mindnlp.core.serialization.load(args['MODEL_NAME'] + '.pth')

        # Convert every weight to float32 and find the highest block index.
        # Iterate over a snapshot of the keys because values are reassigned
        # inside the loop.
        self.num_layer = 0
        for k in list(w):
            w[k] = w[k].float()
            if '.time_' in k:
                w[k] = w[k].squeeze()
            if '.time_faaaa' in k:
                w[k] = w[k].unsqueeze(-1)
            # Only keys of the form 'blocks.<index>.<name>' carry a layer index.
            if k.startswith('blocks.'):
                self.num_layer = max(self.num_layer, int(k.split(".")[1]))

        # Block indices start at 0, so the count is the highest index plus one.
        self.num_layer += 1

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.n_embd = w['blocks.0.ln1.weight'].shape[0]
        self.head_size = self.n_embd // self.n_head
        # Each block uses (2 + head_size) rows of width n_embd.
        self.state_size = [self.num_layer * (2 + self.head_size), self.n_embd]
        self.batch_size = args['batch_size']

        print(f"state_size: {self.state_size}")  # shape of the state per batch row

        # Embedding and input layer norm.
        self.emb = nn.Embedding.from_pretrained(w['emb.weight'], freeze=True)
        self.ln0 = nn.LayerNorm(self.n_embd)
        self.ln0.weight = nn.Parameter(w['blocks.0.ln0.weight'])
        self.ln0.bias = nn.Parameter(w['blocks.0.ln0.bias'])
        self.blocks = nn.ModuleList()

        # Zero-initialised state; every block takes its initial state from it.
        self.state = ops.zeros([self.batch_size, *self.state_size])

        for i in range(self.num_layer):
            # Collect this block's weights with the 'blocks.<i>.' prefix removed.
            # The match is anchored at the start of the key so that it agrees
            # with the prefix-length slice.
            prefix = f'blocks.{i}.'
            block_w = {k[len(prefix):]: v for k, v in w.items() if k.startswith(prefix)}
            self.blocks.append(RWKV_BLOCK(block_w, self.n_embd, self.n_head, self.state, i))
            print(f"Loading blocks...[{i + 1}/{self.num_layer}]", end='\r')
        print()

        # Output layer norm and the projection to the vocabulary.
        self.ln_out = nn.LayerNorm(self.n_embd)
        self.ln_out.weight = nn.Parameter(w['ln_out.weight'])
        self.ln_out.bias = nn.Parameter(w['ln_out.bias'])
        self.head = nn.Linear(self.n_embd, args['vocab_size'], bias=False)
        self.head.weight = nn.Parameter(w['head.weight'])

    def forward(self, token: mindspore.Tensor) -> mindspore.Tensor:
        """
        Forward pass for one token per batch row.

        Every block updates its recurrent state during the call.

        Args:
            token (mindspore.Tensor): Input token ids. [Batch_size]
        Returns:
            mindspore.Tensor: Output of the ``head`` projection, one row per
            batch element. This is a single tensor, not a tuple.
        """
        x = self.emb(token)
        x = self.ln0(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_out(x)
        x = self.head(x)
        return x
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# coding=utf-8
2+
# Copyright 2024 BlinkDL, et al.
3+
# Copyright 2024 yuunnn-w, et al.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Modifications copyright 2024 [Huawei Technologies Co., Ltd]
12+
# Changes: Migrated to MindSpore interface
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
# ============================================================================
20+
import mindspore
21+
from mindnlp.core import ops
22+
23+
24+
def sample_logits(out: mindspore.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> mindspore.Tensor:
    """
    Sample token indices from logits with temperature scaling and top-p truncation.

    Args:
        out (mindspore.Tensor): Logits from the model, last dimension is the
            vocabulary. The original note gives the shape as [*, vocab_size];
            more than one leading dimension is unverified here.
        temperature (float): Divides the logits before the softmax. Must be
            positive. Default is 1.0.
        top_p (float): Cumulative-probability threshold. Tokens are kept in
            descending probability order up to and including the first one
            that pushes the cumulative probability above ``top_p``. Default is 0.8.

    Returns:
        mindspore.Tensor: Sampled indices, with the last dimension removed.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    # Dividing by zero (or a negative value) would silently give inf/NaN or an
    # inverted distribution, so reject it up front.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Apply temperature scaling and convert the logits to probabilities.
    scaled_logits = out / temperature
    probabilities = ops.softmax(scaled_logits, dim=-1)

    # Sort in descending order and accumulate to find the top-p cut-off.
    sorted_probs, sorted_indices = ops.sort(probabilities, descending=True)
    cumulative_probs = ops.cumsum(sorted_probs, dim=-1)

    # Mark sorted positions whose cumulative probability exceeds top_p ...
    sorted_indices_to_remove = cumulative_probs > top_p
    # ... then shift the mask right by one so that the first token crossing the
    # threshold is kept, and never remove the most likely token.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].copy()
    sorted_indices_to_remove[..., 0] = 0

    # Map the mask from sorted order back to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(axis=-1, index=sorted_indices, src=sorted_indices_to_remove)

    # Zero out the probabilities of the removed tokens.
    probabilities[indices_to_remove] = 0.0

    # Fall back to a uniform distribution for any row that ended up all zero.
    # This is decided per row, so one degenerate row does not affect the others.
    row_sums = ops.sum(probabilities, dim=-1, keepdim=True)
    probabilities = ops.where(row_sums == 0, ops.ones_like(probabilities), probabilities)

    # Renormalise each row separately so that every row sums to one. Dividing
    # by the sum over the whole tensor would scale rows by the batch total.
    probabilities = probabilities / ops.sum(probabilities, dim=-1, keepdim=True)

    # Sample one index per row from the truncated distribution.
    sampled_indices = ops.multinomial(probabilities, 1)

    return sampled_indices.squeeze(-1)

0 commit comments

Comments
 (0)