# coding=utf-8
# Copyright 2024 BlinkDL, et al.
# Copyright 2024 yuunnn-w, et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Modifications copyright 2024 [Huawei Technologies Co., Ltd]
# Changes: Migrated to MindSpore interface
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore
import mindnlp
import mindnlp.core.serialization  # ensure the serialization submodule is loaded
import mindnlp.core.nn as nn
import mindnlp.core.ops as ops


class RWKV_BLOCK(nn.Module):
    """
    A single block of the RWKV model.

    Args:
        block_w (dict): Weight dictionary for this block.
        n_embd (int): Embedding dimension.
        n_head (int): Number of heads.
        state (mindspore.Tensor): Hidden-state tensor, [Batch_size, State_size, N_embd].
        i (int): Block (layer) index, used to locate this layer's slice of the state.
    """
    def __init__(self, block_w: dict, n_embd: int, n_head: int, state: mindspore.Tensor, i: int):
        super().__init__()
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
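
        # Note (added for clarity): the shared state tensor packs
        # (2 + head_size) rows per layer along dim 1: one row for the
        # channel-mixing token shift, one for the time-mixing token shift,
        # and head_size rows for the flattened WKV matrix state.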

        # Offsets of this layer's slices in the shared state tensor
        i0 = (2 + self.head_size) * i + 0
        i1 = (2 + self.head_size) * i + 1
        i2 = (2 + self.head_size) * i + 2
        i3 = (2 + self.head_size) * (i + 1)

        # Views into the time state
        self.state_view_channel = state[:, i0]
        self.state_view_time_1 = state[:, i1]
        self.state_view_time_2 = state[:, i2: i3, :]

        # Layer normalizations
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln1.weight = nn.Parameter(block_w['ln1.weight'])
        self.ln1.bias = nn.Parameter(block_w['ln1.bias'])
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln2.weight = nn.Parameter(block_w['ln2.weight'])
        self.ln2.bias = nn.Parameter(block_w['ln2.bias'])

        # Activation function
        self.silu = nn.SiLU()

        # Attention (time-mixing) parameters
        self.att_time_maa_x = nn.Parameter(block_w['att.time_maa_x'])
        self.att_time_maa = nn.Parameter(ops.stack([block_w['att.time_maa_w'],
                                                    block_w['att.time_maa_k'],
                                                    block_w['att.time_maa_v'],
                                                    block_w['att.time_maa_r'],
                                                    block_w['att.time_maa_g']]))
        self.att_time_maa_w1 = nn.Parameter(block_w['att.time_maa_w1'])
        self.att_time_maa_w2 = nn.Parameter(block_w['att.time_maa_w2'])
        self.att_time_decay = nn.Parameter(block_w['att.time_decay'])
        self.att_time_decay_w1 = nn.Parameter(block_w['att.time_decay_w1'])
        self.att_time_decay_w2 = nn.Parameter(block_w['att.time_decay_w2'])
        self.att_time_faaaa = nn.Parameter(block_w['att.time_faaaa'])
        self.att_receptance = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_receptance.weight = nn.Parameter(block_w['att.receptance.weight'])
        self.att_key = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_key.weight = nn.Parameter(block_w['att.key.weight'])
        self.att_value = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_value.weight = nn.Parameter(block_w['att.value.weight'])
        self.att_output = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_output.weight = nn.Parameter(block_w['att.output.weight'])
        self.att_gate = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.att_gate.weight = nn.Parameter(block_w['att.gate.weight'])
        self.att_group_norm = nn.GroupNorm(num_groups=n_head, num_channels=n_embd, eps=1e-5, affine=True)
        self.att_group_norm.weight = nn.Parameter(block_w['att.ln_x.weight'])
        self.att_group_norm.bias = nn.Parameter(block_w['att.ln_x.bias'])

        # Feed-forward (channel-mixing) parameters. The Linear shapes
        # declared here are placeholders; assigning the loaded weights
        # below determines the actual dimensions.
        self.ffn_time_maa_k = nn.Parameter(block_w['ffn.time_maa_k'])
        self.ffn_time_maa_r = nn.Parameter(block_w['ffn.time_maa_r'])
        self.ffn_key = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.ffn_key.weight = nn.Parameter(block_w['ffn.key.weight'])
        self.ffn_receptance = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.ffn_receptance.weight = nn.Parameter(block_w['ffn.receptance.weight'])
        self.ffn_value = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.ffn_value.weight = nn.Parameter(block_w['ffn.value.weight'])

    def channel_mixing(self, x: mindspore.Tensor) -> mindspore.Tensor:
        """
        Channel mixing (feed-forward with token shift).

        Args:
            x (mindspore.Tensor): Input tensor of shape [Batch, N_embd].
        Returns:
            mindspore.Tensor: Mixed tensor with the same shape as x.
        """
        sx = self.state_view_channel - x
        self.state_view_channel = x
        xk = x + sx * self.ffn_time_maa_k
        xr = x + sx * self.ffn_time_maa_r
        r = nn.functional.sigmoid(self.ffn_receptance(xr))
        k = nn.functional.relu(self.ffn_key(xk)).pow(2)
        output = r * self.ffn_value(k)
        return output

    def time_mixing(self, x: mindspore.Tensor) -> mindspore.Tensor:
        """
        Time mixing (RWKV linear attention over the recurrent state).

        Args:
            x (mindspore.Tensor): Input tensor of shape [Batch, N_embd].
        Returns:
            mindspore.Tensor: Mixed tensor with the same shape as x. The
                time-state views are updated in place as a side effect.
        """
        batch_size, H, S = x.shape[0], self.n_head, self.head_size

        # Token shift against the previous time step
        sx = self.state_view_time_1 - x
        self.state_view_time_1 = x

        # Data-dependent interpolation coefficients (low-rank "maa" projection)
        xxx = x + sx * self.att_time_maa_x
        xxx = ops.tanh(xxx @ self.att_time_maa_w1).view(batch_size, 5, 1, -1)
        xxx = ops.matmul(xxx, self.att_time_maa_w2).view(batch_size, 5, -1)

        xw, xk, xv, xr, xg = ops.unbind(x.unsqueeze(1) + sx.unsqueeze(1) * (self.att_time_maa + xxx), dim=1)

        w = self.att_time_decay + (ops.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2)

        # Per-channel decay weights, mapped into (0, 1)
        w = ops.exp(-ops.exp(w.view(batch_size, H, S, 1)))

        # Attention components
        r = self.att_receptance(xr).view(batch_size, H, 1, S)
        k = self.att_key(xk).view(batch_size, H, S, 1)
        v = self.att_value(xv).view(batch_size, H, 1, S)
        g = self.silu(self.att_gate(xg))

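        # A sketch of the per-head linear-attention step implemented below
        # (the RWKV-v6 recurrence): with a = k @ v as the current
        # outer-product contribution, the output is r @ (u * a + s), where
        # u = att_time_faaaa is the per-head bonus for the current token,
        # and the state decays as s <- a + w * s.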
        # Update the recurrent state with the attention result
        s = self.state_view_time_2.view(batch_size, H, S, S)
        a = k @ v
        x = r @ (self.att_time_faaaa * a + s)
        s = a + w * s
        self.state_view_time_2 = s.view(batch_size, S, -1)

        # Flatten x, then apply group normalization and gating
        x = self.att_group_norm(x.flatten(start_dim=1)) * g

        # Output projection
        return self.att_output(x)

    def forward(self, x: mindspore.Tensor) -> mindspore.Tensor:
        """
        Forward pass of the block.

        Args:
            x (mindspore.Tensor): Input tensor of shape [Batch, N_embd].
        Returns:
            mindspore.Tensor: Output tensor with the same shape as x.
        """
        x = x + self.time_mixing(self.ln1(x))
        x = x + self.channel_mixing(self.ln2(x))
        return x


class RWKV_RNN(nn.Module):
    """
    RNN-mode wrapper around the RWKV model.

    Args:
        args (dict): Configuration dictionary; expects the keys
            'MODEL_NAME', 'batch_size' and 'vocab_size'.
    """
    def __init__(self, args: dict):
        super().__init__()
        self.args = args
        self.set_train(False)

        # Load the weights
        w = mindnlp.core.serialization.load(args['MODEL_NAME'] + '.pth')

        # Cast all weights to float32, reshape the time parameters,
        # and infer the number of layers from the weight names
        self.num_layer = 0
        for k in w.keys():
            w[k] = w[k].float()
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
            if "blocks" in k: self.num_layer = max(self.num_layer, int(k.split(".")[1]))

        self.num_layer += 1

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.n_embd = w['blocks.0.ln1.weight'].shape[0]
        self.head_size = self.n_embd // self.n_head
        self.state_size = [self.num_layer * (2 + self.head_size), self.n_embd]
        self.batch_size = args['batch_size']

        print(f"state_size: {self.state_size}")  # print the shape of the state

        # Initialize model parameters
        self.emb = nn.Embedding.from_pretrained(w['emb.weight'], freeze=True)
        self.ln0 = nn.LayerNorm(self.n_embd)
        self.ln0.weight = nn.Parameter(w['blocks.0.ln0.weight'])
        self.ln0.bias = nn.Parameter(w['blocks.0.ln0.bias'])
        self.blocks = nn.ModuleList()

        # Initialize the recurrent state
        self.state = ops.zeros([self.batch_size, *self.state_size])

        for i in range(self.num_layer):
            # Extract the weights of the current block
            block_w = {k[len(f'blocks.{i}.'):]: v for k, v in w.items() if f'blocks.{i}.' in k}
            self.blocks.append(RWKV_BLOCK(block_w, self.n_embd, self.n_head, self.state, i))
            print(f"Loading blocks...[{i + 1}/{self.num_layer}]", end='\r')
        print()

        self.ln_out = nn.LayerNorm(self.n_embd)
        self.ln_out.weight = nn.Parameter(w['ln_out.weight'])
        self.ln_out.bias = nn.Parameter(w['ln_out.bias'])
        self.head = nn.Linear(self.n_embd, args['vocab_size'], bias=False)
        self.head.weight = nn.Parameter(w['head.weight'])

    def forward(self, token: mindspore.Tensor) -> mindspore.Tensor:
        """
        Forward pass of the model.

        Args:
            token (mindspore.Tensor): Input token tensor of shape [Batch_size].
        Returns:
            mindspore.Tensor: Output logits of shape [Batch_size, vocab_size].
        """
        x = self.emb(token)
        x = self.ln0(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_out(x)
        x = self.head(x)
        return x
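

# A minimal usage sketch, added for illustration (not part of the upstream
# file). 'model' and the vocab size below are placeholder assumptions;
# point MODEL_NAME at a converted RWKV-v6 checkpoint saved as '<name>.pth'
# and set vocab_size to match it.
if __name__ == '__main__':
    args = {
        'MODEL_NAME': 'model',  # hypothetical prefix; loads 'model.pth'
        'batch_size': 1,
        'vocab_size': 65536,    # placeholder; must match the checkpoint
    }
    model = RWKV_RNN(args)
    # One token per step; the recurrent state lives inside the model.
    token = ops.zeros([args['batch_size']], dtype=mindspore.int32)
    logits = model(token)
    print(logits.shape)  # expected: [batch_size, vocab_size]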