+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
+
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
 SEMANTIC_TOKEN = "<|semantic|>"
+MEL_TOKEN = "<|mel|>"
+PHONEME_START_TOKEN = "<|phoneme_start|>"
+PHONEME_END_TOKEN = "<|phoneme_end|>"
+ALL_SPECIAL_TOKENS = [
+    IM_START_TOKEN,
+    IM_END_TOKEN,
+    SEMANTIC_TOKEN,
+    MEL_TOKEN,
+    PHONEME_START_TOKEN,
+    PHONEME_END_TOKEN,
+]
+
 CODEBOOK_PAD_TOKEN_ID = 0
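+# VQ codes get a +1 shift in Message.encode below, so real codebook entries
+# never collide with CODEBOOK_PAD_TOKEN_ID.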
+
+
+class FishTokenizerConfig(PretrainedConfig):
+    share_codebook_embeddings: bool = True
+    codebook_size: int = 1024
+    num_codebooks: int = 8
+
+
+class FishTokenizerFast(PreTrainedTokenizerFast):
+    def __init__(self, *args, **kwargs):
+        # Pop the custom kwargs before calling super().__init__ so the base
+        # class never receives them as unexpected arguments.
+        self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
+        self.codebook_size = kwargs.pop("codebook_size", 1024)
+        self.num_codebooks = kwargs.pop("num_codebooks", 8)
+        super().__init__(*args, **kwargs)
+
+
+AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
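+
+# With the registration above, AutoTokenizer.from_pretrained can resolve a
+# checkpoint whose tokenizer config declares FishTokenizerConfig straight to
+# FishTokenizerFast. Illustrative sketch only; the path is hypothetical:
+#
+#   tokenizer = AutoTokenizer.from_pretrained("path/to/fish-checkpoint")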
+
+
+@dataclass(kw_only=True)
+class BasePart:
+    pass
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+    codes: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+    text: str
+
+
+@dataclass(kw_only=True)
+class MelPart(BasePart):
+    mels: torch.Tensor
+
+
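+# EncodedMessage bundles the flattened token/label streams with the raw
+# multimodal payloads (VQ codes, mel frames) that get spliced back in later.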
+@dataclass(kw_only=True)
+class EncodedMessage:
+    tokens: torch.Tensor
+    labels: torch.Tensor
+    vq_parts: list[torch.Tensor]
+    mel_parts: list[torch.Tensor]
+    vq_require_losses: torch.Tensor | None = None
+
+
+@dataclass(kw_only=True)
+class Message:
+    role: Literal["system", "user", "assistant"]
+    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
+    add_im_start: bool = True
+    add_im_end: bool = True
+    cal_loss: bool = False
+
+    # By default, ignore the loss of the auto-generated im_start token
+    ignore_im_start_loss: bool = True
+
+    def encode(
+        self: "Message",
+        tokenizer: AutoTokenizer,
+    ) -> EncodedMessage:
+        all_tokens = []
+        all_labels = []
+
+        # Multi-modal tokens
+        vq_parts = []
+        mel_parts = []
+
+        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+            [SEMANTIC_TOKEN, MEL_TOKEN]
+        )
+
+        parts = self.parts.copy()
+        if self.add_im_start:
+            parts.insert(0, TextPart(text=f"{IM_START_TOKEN}{self.role}\n"))
+
+        if self.add_im_end:
+            parts.append(TextPart(text=IM_END_TOKEN))
+
+        for part in parts:
+            if isinstance(part, TextPart):
+                tokens = tokenizer.encode(
+                    part.text,
+                    add_special_tokens=False,
+                    truncation=False,
+                    return_tensors="pt",
+                ).int()[0]
+            elif isinstance(part, VQPart):
+                tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
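+                # Shift by +1 so real codes never collide with
+                # CODEBOOK_PAD_TOKEN_ID (0).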
+                codes = part.codes.clone() + 1
+
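+                # Without shared codebook embeddings, offset each codebook
+                # into its own id range of size `codebook_size`.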
+                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
+                    for i in range(len(codes)):
+                        codes[i] += tokenizer.codebook_size * i
+
+                vq_parts.append(codes)
+            elif isinstance(part, MelPart):
+                tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
+                mel_parts.append(part.mels)
+            else:
+                raise ValueError(f"Unsupported part type: {type(part)}")
+
+            all_tokens.append(tokens)
+            if self.cal_loss:
+                all_labels.append(tokens.clone())
+            else:
+                all_labels.append(torch.full_like(tokens, -100))
+
+        tokens = torch.cat(all_tokens, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        assert tokens.shape == labels.shape
+
+        if self.ignore_im_start_loss and self.add_im_start:
+            labels[: len(all_tokens[0])] = -100
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            mel_parts=mel_parts,
+        )
+
+
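+# Usage sketch (hedged; assumes a tokenizer that knows the special tokens
+# above): a Message encodes to aligned token/label streams, with one
+# <|semantic|> placeholder per VQ frame.
+#
+#   msg = Message(role="user", parts=[TextPart(text="hi")])
+#   enc = msg.encode(tokenizer)
+#   assert enc.tokens.shape == enc.labels.shape
+
+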
+@dataclass
+class Conversation:
+    messages: list[Message]
+
+    def encode(
+        self: "Conversation",
+        tokenizer: AutoTokenizer,
+        add_shift: bool = True,
+    ) -> EncodedMessage:
+        # Build the input_ids and labels
+        tokens = []
+        labels = []
+        vq_parts = []
+        mel_parts = []
+        vq_require_losses = []
+
+        for message in self.messages:
+            encoded = message.encode(
+                tokenizer,
+            )
+            tokens.append(encoded.tokens)
+            labels.append(encoded.labels)
+            vq_parts.extend(encoded.vq_parts)
+            mel_parts.extend(encoded.mel_parts)
+            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
+
+        tokens = torch.cat(tokens, dim=0)
+        labels = torch.cat(labels, dim=0)
+        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
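+        # Standard next-token shift: tokens[:-1] are trained to predict
+        # labels[1:].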
+        if add_shift:
+            tokens = tokens[:-1]
+            labels = labels[1:]
+
+        assert tokens.dtype in [
+            torch.int,
+            torch.long,
+        ], f"Invalid dtype: {tokens.dtype}, conv: {self}"
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            mel_parts=mel_parts,
+            vq_require_losses=vq_require_losses,
+        )
+
+    def encode_for_inference(
+        self: "Conversation",
+        tokenizer: AutoTokenizer,
+        num_codebooks: int,
+    ) -> torch.Tensor:
+        # Returns a (num_codebooks + 1, seq_len) int tensor rather than an
+        # EncodedMessage: row 0 is the text-token stream, rows 1.. hold the
+        # codebook ids at the <|semantic|> positions (zeros elsewhere).
+        encoded = self.encode(tokenizer, add_shift=False)
+        tokens = encoded.tokens
+        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+        values[0] = tokens
+
+        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
+            return values
+
+        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+            [SEMANTIC_TOKEN, MEL_TOKEN]
+        )
+        vq_parts = encoded.vq_parts
+        vq_parts = torch.cat(vq_parts, dim=1)
+        values[1:, tokens == semantic_id] = vq_parts
+        return values
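+
+    # Layout sketch for num_codebooks = 2 over tokens [T, S, S, T]
+    # (T = text id, S = <|semantic|> id, cij = +1-shifted codes):
+    #
+    #   values[0] = [T,   S,   S,   T]
+    #   values[1] = [0, c00, c01,   0]   # codebook 0
+    #   values[2] = [0, c10, c11,   0]   # codebook 1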
+
+    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
+        encoded = self.encode(tokenizer, add_shift=False)
+
+        # Blue marks supervised tokens; green marks tokens the loss ignores
+        # (label == -100).
+        print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
+        print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
+
+        for tok, lab in zip(encoded.tokens, encoded.labels):
+            val = tokenizer.decode(tok, skip_special_tokens=False)
+            if val == "\n":
+                val = "\\n\n"
+
+            if lab == -100:
+                print_in_green(val)
+            else:
+                print_in_blue(val)
+
+        print()
+
+
+if __name__ == "__main__":
+    message0 = Message(
+        role="user",
+        parts=[
+            TextPart(text="Hello, how are you?"),
+            # Integer codes, so they can later be written into the int tensor
+            # built by encode_for_inference.
+            VQPart(codes=torch.zeros((4, 10), dtype=torch.int)),
+        ],
+        cal_loss=False,
+    )
+
+    message1 = Message(
+        role="assistant",
+        parts=[TextPart(text="I'm fine, thank you.")],
+        cal_loss=True,
+    )
+    conversation = Conversation([message0, message1])
+    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+    conversation.visualize(tokenizer)
+
+    encoded = conversation.encode(tokenizer)
+    print(encoded)
+    print(tokenizer.batch_decode(encoded.tokens))
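+
+    # Hedged sketch (not in the original demo): pack the conversation for
+    # inference. Row 0 carries text ids; rows 1..4 the shifted codebook ids.
+    #
+    #   values = conversation.encode_for_inference(tokenizer, num_codebooks=4)
+    #   print(values.shape)  # -> torch.Size([5, seq_len])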