Commit 1a60b69

chore(format): run black on dev
1 parent 6feb586

3 files changed (+48 additions, −26 deletions)
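
The hunks below are mechanical rewrites by black: lines over the default 88-column limit are rewrapped, and a collection or call written with a trailing comma (black's "magic trailing comma") is kept exploded, one element per line. As a hedged illustration, one of the rewrites in this commit can be reproduced through black's Python API (black.format_str); the input line is taken from the model_runner.py hunk near the end:

    import black

    # Pre-commit form of a line from ChatTTS/model/velocity/model_runner.py.
    src = "pad_i = [0,] * len(x[0][0])\n"

    # black.Mode() matches the CLI defaults (88-column lines).
    print(black.format_str(src, mode=black.Mode()))
    # The magic trailing comma inside [0,] keeps the list exploded:
    #     pad_i = [
    #         0,
    #     ] * len(x[0][0])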

ChatTTS/core.py

Lines changed: 5 additions & 4 deletions

@@ -402,6 +402,7 @@ async def _infer(
         else:
             # Hacker:Check if there are any silent segments; if so, take the last segment. Otherwise, try waiting for another loop.
             import librosa
+
             silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
             silence_left = 0
             if len(silence_intervals) == 0:
@@ -532,7 +533,9 @@ async def _infer_code(
         async for i in results_generator:
             token_ids = []
             hidden_states = []
-            if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
+            if (
+                stream and len(i.outputs[0].token_ids) % stream_batch_size == 0
+            ) or i.finished:
                 token_ids.append(torch.tensor(i.outputs[0].token_ids))
                 hidden_states.append(
                     i.outputs[0].hidden_states.to(torch.float32).to(self.device)
@@ -568,9 +571,7 @@ async def _infer_code(
             hidden_states = []
             if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
                 token_ids.append(i.ids[0])
-                hidden_states.append(
-                    i.hiddens[0].to(torch.float32).to(self.device)
-                )
+                hidden_states.append(i.hiddens[0].to(torch.float32).to(self.device))
                 yield GPT.GenerationOutputs(
                     ids=token_ids,
                     finished=i.finished,
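
For context on the first hunk: librosa.effects.split returns the [start, end) sample intervals of non-silent audio, so the code above can detect trailing silence in the freshly generated tail of the waveform. A minimal standalone sketch of the same call, with a synthetic waveform standing in for wavs[0][length:] and the hunk's top_db=10 threshold:

    import numpy as np
    import librosa

    # One second of noise followed by half a second of digital silence (24 kHz mono).
    sr = 24000
    wav = np.concatenate([
        np.random.uniform(-0.5, 0.5, sr).astype(np.float32),
        np.zeros(sr // 2, dtype=np.float32),
    ])

    # Returns an (m, 2) array of sample indices of NON-silent spans; anything
    # more than top_db below the peak counts as silence.
    intervals = librosa.effects.split(wav, top_db=10)
    print(intervals)  # one span covering the noise; the zero tail is excluded

    # The hunk's logic (only partially visible above) keys off these intervals,
    # e.g. an empty result means the whole tail is silent.
    if len(intervals) == 0:
        print("tail is entirely silent")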

ChatTTS/model/gpt.py

Lines changed: 3 additions & 9 deletions

@@ -68,7 +68,7 @@ def from_pretrained(
                 num_audio_tokens=self.num_audio_tokens,
                 num_text_tokens=self.num_text_tokens,
                 post_model_path=embed_file_path,
-                dtype="float32"
+                dtype="float32",
             )
             self.logger.info("vLLM model loaded")
             return
@@ -585,7 +585,7 @@ async def generate(
                     attentions,
                     hiddens,
                     infer_text,
-                    False
+                    False,
                 )
                 del not_finished

@@ -609,11 +609,5 @@ async def generate(
        del finish, inputs_ids_buf

        yield self._prepare_generation_outputs(
-            inputs_ids,
-            start_idx,
-            end_idx,
-            attentions,
-            hiddens,
-            infer_text,
-            True
+            inputs_ids, start_idx, end_idx, attentions, hiddens, infer_text, True
        )
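
A hedged note on the last hunk: the seven-line argument list had no trailing comma after True, so black was free to join it, and the joined arguments fit within 88 columns. Conversely, where black keeps a call exploded it appends a trailing comma (the dtype="float32", and False, changes above), which then pins the exploded shape on future runs. A sketch of the collapse using black's API, with prepare as a hypothetical stand-in for self._prepare_generation_outputs:

    import black

    # Exploded call with NO trailing comma after the last argument.
    src = (
        "out = prepare(\n"
        "    inputs_ids,\n"
        "    start_idx,\n"
        "    end_idx,\n"
        "    attentions,\n"
        "    hiddens,\n"
        "    infer_text,\n"
        "    True\n"
        ")\n"
    )
    # Without a magic trailing comma, black joins the arguments onto one
    # line, mirroring the rewrite in the hunk above.
    print(black.format_str(src, mode=black.Mode()), end="")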

ChatTTS/model/velocity/model_runner.py

Lines changed: 40 additions & 13 deletions

@@ -107,7 +107,9 @@ def set_block_size(self, block_size: int) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> tuple[list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]]:
+    ) -> tuple[
+        list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]
+    ]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -360,17 +362,23 @@ def _prepare_sample(
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]]:
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]
+    ]:
         speaker_embedding = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata, prompt_lens, speaker_embedding) = (
-                    self._prepare_prompt(seq_group_metadata_list)
-                )
+                (
+                    input_tokens,
+                    input_positions,
+                    input_metadata,
+                    prompt_lens,
+                    speaker_embedding,
+                ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions, input_metadata) = self._prepare_decode(
                     seq_group_metadata_list
@@ -462,7 +470,13 @@ def get_size_or_none(x: Optional[torch.Tensor]):
             perform_sampling=False,
         )

-        return input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        )

     @torch.inference_mode()
     def execute_model(
@@ -471,9 +485,13 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:

-        input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding = (
-            self.prepare_input_tensors(seq_group_metadata_list)
-        )
+        (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        ) = self.prepare_input_tensors(seq_group_metadata_list)
         # print(sampling_metadata.seq_data)
         seq_groups = []
         for i, rtn in enumerate(sampling_metadata.seq_groups):
@@ -522,7 +540,9 @@ def execute_model(
                 if speaker_embedding_params is None:
                     speaker_embedding_params = speaker_embedding[i]
                 else:
-                    speaker_embedding_params = torch.cat((speaker_embedding_params, speaker_embedding[i]))
+                    speaker_embedding_params = torch.cat(
+                        (speaker_embedding_params, speaker_embedding[i])
+                    )

         else:
             speaker_embedding_params = self.post_model(input_tokens, text_mask)
@@ -560,7 +580,7 @@ def execute_model(
         #     sampling_metadata=sampling_metadata,
         # )
         results = []
-        for i,val in enumerate(seq_groups):
+        for i, val in enumerate(seq_groups):
            idx_next_i = idx_next[i, 0, :].tolist()
            logprob_i = logprob[i].tolist()
            tmp_hidden_states = hidden_states[i]
@@ -781,7 +801,9 @@ def _make_tensor_with_pad(
     for x_i in x:
         pad_i = pad
         if isinstance(x[0][0], list):
-            pad_i = [0,] * len(x[0][0])
+            pad_i = [
+                0,
+            ] * len(x[0][0])
         elif isinstance(x[0][0], tuple):
             pad_i = (0,) * len(x[0][0])
         padded_x.append(_pad_to_max(x_i, max_len, pad_i))
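
_pad_to_max, called on the last line of this hunk, does not appear in this commit. In the vLLM model runner this file derives from, it is a one-line right-padding helper, so the following sketch is an assumption based on that upstream code:

    from typing import List, TypeVar

    T = TypeVar("T")


    def _pad_to_max(x: List[T], max_len: int, pad: T) -> List[T]:
        # Right-pad the sequence out to max_len entries with the pad value
        # chosen above (a scalar, or a zeroed list/tuple for nested inputs).
        assert len(x) <= max_len
        return x + [pad] * (max_len - len(x))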
@@ -791,6 +813,7 @@ def _make_tensor_with_pad(
         device=device,
     )

+
 def _make_with_pad(
     x: List[torch.Tensor],
     max_len: int,
@@ -805,11 +828,15 @@ def _make_with_pad(
             padded_x.append(x_i)
         else:
             padded_x.append(
-                torch.cat((torch.zeros(1, max_len-x_i.shape[-2], 768).to(device), x_i), dim=1)
+                torch.cat(
+                    (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i),
+                    dim=1,
+                )
             )

     return padded_x

+
 def _get_graph_batch_size(batch_size: int) -> int:
     if batch_size <= 2:
         return batch_size
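
The _make_with_pad hunk left-pads each hidden-state tensor with zeros along the time axis so that every tensor reaches max_len; the 768 is the hidden size hard-coded in the call. A small runnable illustration of the reformatted expression:

    import torch

    device = "cpu"
    max_len = 5

    # A (1, T, 768) hidden-state tensor shorter than max_len, as in the hunk.
    x_i = torch.ones(1, 3, 768)

    # Same expression as the post-black code: zeros are prepended along dim 1
    # (the time axis), giving shape (1, max_len, 768).
    padded = torch.cat(
        (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i),
        dim=1,
    )
    print(padded.shape)  # torch.Size([1, 5, 768])

The tail of the diff also brushes _get_graph_batch_size, whose body lies mostly outside this hunk; in upstream vLLM that helper rounds a batch size up to one of the sizes pre-captured for CUDA graphs.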
