Skip to content

Commit 4c96475

Browse files
codingl2k1 and qinxuye authored
ENH: Update fish audio (#2555)
Co-authored-by: qinxuye <[email protected]>
1 parent 7a0bb60 commit 4c96475

40 files changed

+2505
-275
lines changed

Diff for: .github/workflows/python.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ jobs:
180180
${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y "faster_whisper"
181181
${{ env.SELF_HOST_PYTHON }} -m pip install -U accelerate
182182
${{ env.SELF_HOST_PYTHON }} -m pip install -U verovio
183+
${{ env.SELF_HOST_PYTHON }} -m pip install -U cachetools
184+
${{ env.SELF_HOST_PYTHON }} -m pip install -U silero-vad
185+
${{ env.SELF_HOST_PYTHON }} -m pip install -U pydantic
183186
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
184187
--disable-warnings \
185188
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_continuous_batching.py && \

Diff for: setup.cfg

+4
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ all =
129129
natsort # For Fish Speech
130130
loralib # For Fish Speech
131131
ormsgpack # For Fish Speech
132+
cachetools # For Fish Speech
133+
silero-vad # For Fish Speech
132134
qwen-vl-utils # For qwen2-vl
133135
datamodel_code_generator # for minicpm-4B
134136
jsonschema # for minicpm-4B
@@ -210,6 +212,8 @@ audio =
210212
natsort # For Fish Speech
211213
loralib # For Fish Speech
212214
ormsgpack # For Fish Speech
215+
cachetools # For Fish Speech
216+
silero-vad # For Fish Speech
213217
doc =
214218
ipython>=6.5.0
215219
sphinx>=3.0.0

Diff for: xinference/deploy/docker/requirements.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ click
77
tqdm>=4.27
88
tabulate
99
requests
10-
pydantic
10+
pydantic>2
1111
fastapi>=0.110.3
1212
uvicorn
1313
huggingface-hub>=0.19.4
@@ -72,6 +72,8 @@ loguru # For Fish Speech
7272
natsort # For Fish Speech
7373
loralib # For Fish Speech
7474
ormsgpack # For Fish Speech
75+
cachetools # For Fish Speech
76+
silero-vad # For Fish Speech
7577
qwen-vl-utils # For qwen2-vl
7678
datamodel_code_generator # for minicpm-4B
7779
jsonschema # for minicpm-4B

Diff for: xinference/deploy/docker/requirements_cpu.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ click
66
tqdm>=4.27
77
tabulate
88
requests
9-
pydantic
9+
pydantic>2
1010
fastapi>=0.110.3
1111
uvicorn
1212
huggingface-hub>=0.19.4
@@ -67,6 +67,8 @@ loguru # For Fish Speech
6767
natsort # For Fish Speech
6868
loralib # For Fish Speech
6969
ormsgpack # For Fish Speech
70+
cachetools # For Fish Speech
71+
silero-vad # For Fish Speech
7072
qwen-vl-utils # For qwen2-vl
7173
datamodel_code_generator # for minicpm-4B
7274
jsonschema # for minicpm-4B

Diff for: xinference/model/audio/model_spec.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@
159159
"model_name": "FishSpeech-1.4",
160160
"model_family": "FishAudio",
161161
"model_id": "fishaudio/fish-speech-1.4",
162-
"model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d",
162+
"model_revision": "069c573759936b35191d3380deb89183c0656f59",
163163
"model_ability": "text-to-audio",
164164
"multilingual": true
165165
}

Diff for: xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py

Whitespace-only changes.

Diff for: xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,256 @@
1+
from dataclasses import dataclass, field
2+
from typing import Literal
3+
4+
import torch
5+
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
6+
7+
IM_START_TOKEN = "<|im_start|>"
8+
IM_END_TOKEN = "<|im_end|>"
19
SEMANTIC_TOKEN = "<|semantic|>"
10+
MEL_TOKEN = "<|mel|>"
11+
PHONEME_START_TOKEN = "<|phoneme_start|>"
12+
PHONEME_END_TOKEN = "<|phoneme_end|>"
13+
ALL_SPECIAL_TOKENS = [
14+
IM_START_TOKEN,
15+
IM_END_TOKEN,
16+
SEMANTIC_TOKEN,
17+
MEL_TOKEN,
18+
PHONEME_START_TOKEN,
19+
PHONEME_END_TOKEN,
20+
]
21+
222
CODEBOOK_PAD_TOKEN_ID = 0
23+
24+
25+
class FishTokenizerConfig(PretrainedConfig):
    """Config class that carries the codebook defaults for the Fish tokenizer.

    This class is the key under which ``FishTokenizerFast`` is registered
    with ``AutoTokenizer`` further down in this module.
    """

    # Default read by Message.encode via getattr(tokenizer, ...); when it is
    # False, each codebook row is offset by codebook_size * row index.
    share_codebook_embeddings: bool = True
    # Default codebook size used for that per-row offset.
    codebook_size: int = 1024
    # Default number of codebooks.
    num_codebooks: int = 8
29+
30+
31+
class FishTokenizerFast(PreTrainedTokenizerFast):
    """Fast tokenizer that also records codebook settings as attributes."""

    def __init__(self, *args, **kwargs):
        """Initialise the base tokenizer, then store the codebook settings.

        All positional and keyword arguments are forwarded unchanged to
        ``PreTrainedTokenizerFast.__init__``. The three codebook options are
        then read from the same keyword arguments, falling back to defaults
        when absent.
        """
        super().__init__(*args, **kwargs)
        # The base class has already received every keyword argument, so the
        # values only need to be looked up here, not removed.
        self.share_codebook_embeddings = kwargs.get("share_codebook_embeddings", True)
        self.codebook_size = kwargs.get("codebook_size", 1024)
        self.num_codebooks = kwargs.get("num_codebooks", 8)
37+
38+
39+
AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
40+
41+
42+
@dataclass(kw_only=True)
43+
class BasePart:
44+
pass
45+
46+
47+
@dataclass(kw_only=True)
class VQPart(BasePart):
    """Message part holding VQ codes."""

    # Message.encode reads ``codes.shape[1]`` as the number of token positions
    # and iterates over the first dimension one codebook row at a time.
    codes: torch.Tensor
50+
51+
52+
@dataclass(kw_only=True)
class TextPart(BasePart):
    """Message part holding plain text to be tokenized."""

    # Passed to ``tokenizer.encode`` without adding special tokens.
    text: str
55+
56+
57+
@dataclass(kw_only=True)
class MelPart(BasePart):
    """Message part holding mel features."""

    # Message.encode reads ``mels.shape[1]`` as the number of token positions.
    mels: torch.Tensor
60+
61+
62+
@dataclass(kw_only=True)
63+
class EncodedMessage:
64+
tokens: torch.Tensor
65+
labels: torch.Tensor
66+
vq_parts: list[torch.Tensor]
67+
mel_parts: list[torch.Tensor]
68+
vq_require_losses: torch.Tensor | None = None
69+
70+
71+
@dataclass(kw_only=True)
class Message:
    """A single chat turn composed of text, VQ-code and mel parts."""

    role: Literal["system", "user", "assistant"]
    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
    # Wrap the parts in "<|im_start|>{role}\n" ... "<|im_end|>" text parts.
    add_im_start: bool = True
    add_im_end: bool = True
    # When False every label of this message is set to -100.
    cal_loss: bool = False

    # By default, ignore the loss of the auto-generated im_start token
    ignore_im_start_loss: bool = True

    def encode(
        self: "Message",
        tokenizer: AutoTokenizer,
    ) -> EncodedMessage:
        """Turn this message into token ids, labels and multi-modal parts.

        Text parts are tokenized without special tokens. VQ and mel parts
        contribute one placeholder token (the semantic or mel token id) per
        position along their second dimension, and their tensors are
        collected separately.

        Args:
            tokenizer: Tokenizer used for text parts and for looking up the
                semantic / mel placeholder ids.

        Returns:
            An ``EncodedMessage`` whose ``vq_require_losses`` is left unset.

        Raises:
            ValueError: If a part is not a text, VQ or mel part.
        """
        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
            [SEMANTIC_TOKEN, MEL_TOKEN]
        )

        # Work on a new list so the optional framing parts never leak into
        # ``self.parts``.
        sequence = list(self.parts)
        if self.add_im_start:
            sequence = [TextPart(text=f"<|im_start|>{self.role}\n"), *sequence]
        if self.add_im_end:
            sequence = [*sequence, TextPart(text="<|im_end|>")]

        token_chunks = []
        label_chunks = []

        # Multi-modal tokens
        vq_parts = []
        mel_parts = []

        for part in sequence:
            if isinstance(part, TextPart):
                chunk = tokenizer.encode(
                    part.text,
                    add_special_tokens=False,
                    truncation=False,
                    return_tensors="pt",
                ).int()[0]
            elif isinstance(part, VQPart):
                chunk = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
                # Codes are shifted up by one; presumably 0 is kept free for
                # CODEBOOK_PAD_TOKEN_ID -- TODO confirm.
                shifted = part.codes.clone() + 1

                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
                    # Separate embeddings: give every codebook row its own
                    # id range.
                    for row in range(len(shifted)):
                        shifted[row] += tokenizer.codebook_size * row

                vq_parts.append(shifted)
            elif isinstance(part, MelPart):
                chunk = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
                mel_parts.append(part.mels)
            else:
                raise ValueError(f"Unsupported part type: {type(part)}")

            token_chunks.append(chunk)
            label_chunks.append(
                chunk.clone() if self.cal_loss else torch.full_like(chunk, -100)
            )

        tokens = torch.cat(token_chunks, dim=0)
        labels = torch.cat(label_chunks, dim=0)
        assert tokens.shape == labels.shape

        if self.add_im_start and self.ignore_im_start_loss:
            # The first chunk is the auto-generated "<|im_start|>{role}\n".
            labels[: len(token_chunks[0])] = -100

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            mel_parts=mel_parts,
        )
146+
147+
148+
@dataclass
class Conversation:
    """An ordered list of messages that is encoded as one token sequence."""

    messages: list[Message]

    def encode(
        self: "Conversation",
        tokenizer: AutoTokenizer,
        add_shift: bool = True,
    ) -> EncodedMessage:
        """Encode every message and concatenate the results.

        Args:
            tokenizer: Tokenizer forwarded to ``Message.encode``.
            add_shift: When true, drop the last token and the first label so
                that each remaining label lines up with the token that
                precedes it.

        Returns:
            An ``EncodedMessage`` with the concatenated tokens and labels,
            all VQ / mel parts in message order, and ``vq_require_losses``
            holding one boolean per VQ part (the ``cal_loss`` flag of the
            message that part came from).
        """
        # Build the input_ids and labels
        tokens = []
        labels = []
        vq_parts = []
        mel_parts = []
        vq_require_losses = []

        for message in self.messages:
            encoded = message.encode(
                tokenizer,
            )
            tokens.append(encoded.tokens)
            labels.append(encoded.labels)
            vq_parts.extend(encoded.vq_parts)
            mel_parts.extend(encoded.mel_parts)
            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))

        tokens = torch.cat(tokens, dim=0)
        labels = torch.cat(labels, dim=0)
        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

        if add_shift:
            tokens = tokens[:-1]
            labels = labels[1:]

        # Report ``self`` in the message: the previous text referenced a
        # module-level name that only exists when this file runs as a script,
        # so a failing check raised NameError instead of AssertionError.
        assert tokens.dtype in [
            torch.int,
            torch.long,
        ], f"Invalid dtype: {tokens.dtype}, conv: {self}"

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            mel_parts=mel_parts,
            vq_require_losses=vq_require_losses,
        )

    def encode_for_inference(
        self: "Conversation",
        tokenizer: AutoTokenizer,
        num_codebooks: int,
    ) -> torch.Tensor:
        """Build the ``(num_codebooks + 1, length)`` integer prompt tensor.

        Row 0 holds the token ids of the unshifted encoding. Rows 1 and up
        hold the VQ codes at the positions whose token is the semantic
        placeholder, and are zero everywhere else.

        Args:
            tokenizer: Tokenizer forwarded to ``encode``.
            num_codebooks: Number of code rows to allocate below the token row.

        Returns:
            The prompt tensor (not an ``EncodedMessage``).
        """
        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens
        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
        values[0] = tokens

        if not encoded.vq_parts:
            return values

        # Only the semantic id is needed; the mel id is looked up in the same
        # call, as before, and discarded.
        semantic_id, _ = tokenizer.convert_tokens_to_ids(
            [SEMANTIC_TOKEN, MEL_TOKEN]
        )
        # Concatenating along dim 1 lines the codes up, in order, with the
        # semantic placeholder positions.
        values[1:, tokens == semantic_id] = torch.cat(encoded.vq_parts, dim=1)
        return values

    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
        """Print the decoded conversation, coloured by label.

        Tokens whose label is -100 (ignored by the loss) are printed in green;
        all other tokens are printed in blue.
        """
        encoded = self.encode(tokenizer, add_shift=False)

        blue = "\033[94m"
        green = "\033[92m"
        reset = "\033[0m"

        for tok, lab in zip(encoded.tokens, encoded.labels):
            val = tokenizer.decode(tok, skip_special_tokens=False)
            if val == "\n":
                # Show the newline explicitly as well as emitting it.
                val = "\\n\n"

            colour = green if lab == -100 else blue
            print(colour + val + reset, end="")

        print()
233+
234+
235+
if __name__ == "__main__":
    # Manual demo: encode a two-turn conversation with a local Qwen2
    # checkpoint, print it coloured, then print the shifted encoding.
    user_parts = [
        TextPart(text="Hello, how are you?"),
        VQPart(codes=torch.zeros((4, 10))),
    ]
    user_message = Message(role="user", parts=user_parts, cal_loss=False)
    assistant_message = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,
    )

    conversation = Conversation([user_message, assistant_message])
    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
    conversation.visualize(tokenizer)

    encoded = conversation.encode(tokenizer)
    print(encoded)
    print(tokenizer.batch_decode(encoded.tokens))

Diff for: xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py

Whitespace-only changes.

Diff for: xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py

Whitespace-only changes.

Diff for: xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py

Whitespace-only changes.

Diff for: xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -118,5 +118,6 @@
118118
"new": "new",
119119
"Realtime Transform Text": "Realtime Transform Text",
120120
"Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
121-
"Text Normalization": "Text Normalization"
121+
"Text Normalization": "Text Normalization",
122+
"Select Example Audio": "Select Example Audio"
122123
}

Diff for: xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -118,5 +118,6 @@
118118
"new": "nuevo",
119119
"Realtime Transform Text": "Transformación de Texto en Tiempo Real",
120120
"Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
121-
"Text Normalization": "Normalización de Texto"
121+
"Text Normalization": "Normalización de Texto",
122+
"Select Example Audio": "Seleccionar audio de ejemplo"
122123
}

Diff for: xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,6 @@
118118
"new": "新規",
119119
"Realtime Transform Text": "リアルタイム変換テキスト",
120120
"Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
121-
"Text Normalization": "テキスト正規化"
122-
121+
"Text Normalization": "テキスト正規化",
122+
"Select Example Audio": "サンプル音声を選択"
123123
}

0 commit comments

Comments
 (0)