From 99b90122e86fdd4b480f9dedfe53ad7301717bcb Mon Sep 17 00:00:00 2001 From: xwang0415 Date: Fri, 20 Dec 2024 11:54:22 +0800 Subject: [PATCH 1/2] update webui & fastapi for cosy2.0 --- runtime/python/fastapi/client.py | 14 +++++- runtime/python/fastapi/server.py | 74 ++++++++++++++++++++++++-------- webui.py | 12 ++++-- 3 files changed, 77 insertions(+), 23 deletions(-) diff --git a/runtime/python/fastapi/client.py b/runtime/python/fastapi/client.py index 84e06aed..24328e8c 100644 --- a/runtime/python/fastapi/client.py +++ b/runtime/python/fastapi/client.py @@ -40,13 +40,23 @@ def main(): } files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] response = requests.request("GET", url, data=payload, files=files, stream=True) - else: + elif args.mode == 'instruct': payload = { 'tts_text': args.tts_text, 'spk_id': args.spk_id, 'instruct_text': args.instruct_text } response = requests.request("GET", url, data=payload, stream=True) + else: + # instruct2 + url = url + "_v2" + payload = { + 'tts_text': args.tts_text, + 'instruct_text': args.instruct_text, + 'format': 'pcm' # option + } + files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] + response = requests.request("GET", url, data=payload, files=files, stream=True) tts_audio = b'' for r in response.iter_content(chunk_size=16000): tts_audio += r @@ -66,7 +76,7 @@ def main(): default='50000') parser.add_argument('--mode', default='sft', - choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'], + choices=['sft', 'zero_shot', 'cross_lingual', 'instruct', 'instruct2'], help='request mode') parser.add_argument('--tts_text', type=str, diff --git a/runtime/python/fastapi/server.py b/runtime/python/fastapi/server.py index bfe4a56b..6a3f081e 100644 --- a/runtime/python/fastapi/server.py +++ b/runtime/python/fastapi/server.py @@ -13,20 +13,27 @@ # limitations under the License. import os import sys +import io import argparse import logging logging.getLogger('matplotlib').setLevel(logging.WARNING) from fastapi import FastAPI, UploadFile, Form, File -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, Response from fastapi.middleware.cors import CORSMiddleware import uvicorn import numpy as np -ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) -sys.path.append('{}/../../..'.format(ROOT_DIR)) -sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) -from cosyvoice.cli.cosyvoice import CosyVoice +import torch +import torchaudio +CURR_DIR = os.path.dirname(os.path.abspath(__file__)) +ROOT_DIR = f'{CURR_DIR}/../../..' +sys.path.append(f'{ROOT_DIR}') +sys.path.append(f'{ROOT_DIR}/third_party/Matcha-TTS') +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 from cosyvoice.utils.file_utils import load_wav +model_dir = f"{ROOT_DIR}/pretrained_models/CosyVoice2-0.5B" +cosyvoice = CosyVoice2(model_dir) if 'CosyVoice2' in model_dir else CosyVoice(model_dir) + app = FastAPI() # set cross region allowance app.add_middleware( @@ -37,6 +44,20 @@ allow_headers=["*"]) +# 非流式wav数据 +def build_data(model_output): + tts_speeches = [] + for i in model_output: + tts_speeches.append(i['tts_speech']) + output = torch.concat(tts_speeches, dim=1) + + buffer = io.BytesIO() + torchaudio.save(buffer, output, 22050, format="wav") + buffer.seek(0) + return buffer.read(-1) + + +# 流式pcm数据 def generate_data(model_output): for i in model_output: tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() @@ -44,29 +65,51 @@ def generate_data(model_output): @app.get("/inference_sft") -async def inference_sft(tts_text: str = Form(), spk_id: str = Form()): +async def inference_sft(tts_text: str = Form(), spk_id: str = Form(), format: str = Form(default="pcm")): model_output = cosyvoice.inference_sft(tts_text, spk_id) - return StreamingResponse(generate_data(model_output)) + if format == "pcm": + return StreamingResponse(generate_data(model_output)) + else: + return Response(build_data(model_output), media_type="audio/wav") @app.get("/inference_zero_shot") -async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()): +async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File(), format: str = Form(default="pcm")): prompt_speech_16k = load_wav(prompt_wav.file, 16000) model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k) - return StreamingResponse(generate_data(model_output)) + if format == "pcm": + return StreamingResponse(generate_data(model_output)) + else: + return Response(build_data(model_output), media_type="audio/wav") @app.get("/inference_cross_lingual") -async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()): +async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File(), format: str = Form(default="pcm")): prompt_speech_16k = load_wav(prompt_wav.file, 16000) model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k) - return StreamingResponse(generate_data(model_output)) + if format == "pcm": + return StreamingResponse(generate_data(model_output)) + else: + return Response(build_data(model_output), media_type="audio/wav") @app.get("/inference_instruct") -async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()): +async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form(), format: str = Form(default="pcm")): model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text) - return StreamingResponse(generate_data(model_output)) + if format == "pcm": + return StreamingResponse(generate_data(model_output)) + else: + return Response(build_data(model_output), media_type="audio/wav") + + +@app.get("/inference_instruct_v2") +async def inference_instruct_v2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File(), format: str = Form(default="pcm")): + prompt_speech_16k = load_wav(prompt_wav.file, 16000) + model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k) + if format == "pcm": + return StreamingResponse(generate_data(model_output)) + else: + return Response(build_data(model_output), media_type="audio/wav") if __name__ == '__main__': @@ -74,10 +117,5 @@ async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instr parser.add_argument('--port', type=int, default=50000) - parser.add_argument('--model_dir', - type=str, - default='iic/CosyVoice-300M', - help='local path or modelscope repo id') args = parser.parse_args() - cosyvoice = CosyVoice(args.model_dir) uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/webui.py b/webui.py index 196718dc..3bdf5d66 100644 --- a/webui.py +++ b/webui.py @@ -33,6 +33,7 @@ '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'} stream_mode_list = [('否', False), ('是', True)] max_val = 0.8 +v2 = True def generate_seed(): @@ -128,8 +129,13 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro else: logging.info('get instruct inference request') set_all_random_seed(seed) - for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed): - yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) + if v2: + prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr)) + for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream, speed=speed): + yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) + else: + for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed): + yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) def main(): @@ -181,7 +187,7 @@ def main(): default='pretrained_models/CosyVoice2-0.5B', help='local path or modelscope repo id') args = parser.parse_args() - cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir) + cosyvoice, v2 = (CosyVoice2(args.model_dir),True) if 'CosyVoice2' in args.model_dir else (CosyVoice(args.model_dir),False) sft_spk = cosyvoice.list_avaliable_spks() prompt_sr = 16000 default_data = np.zeros(cosyvoice.sample_rate) From f26aacf8373674bdfa5e8577212881aa03021504 Mon Sep 17 00:00:00 2001 From: xwang0415 Date: Thu, 26 Dec 2024 13:45:58 +0800 Subject: [PATCH 2/2] update fastapi's server interface and webui --- runtime/python/fastapi/server.py | 20 ++++++++++---------- webui.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/runtime/python/fastapi/server.py b/runtime/python/fastapi/server.py index 6a3f081e..bcc9ebb3 100644 --- a/runtime/python/fastapi/server.py +++ b/runtime/python/fastapi/server.py @@ -65,8 +65,8 @@ def generate_data(model_output): @app.get("/inference_sft") -async def inference_sft(tts_text: str = Form(), spk_id: str = Form(), format: str = Form(default="pcm")): - model_output = cosyvoice.inference_sft(tts_text, spk_id) +async def inference_sft(tts_text: str = Form(), spk_id: str = Form(), stream: bool = Form(default=False), format: str = Form(default="pcm")): + model_output = cosyvoice.inference_sft(tts_text, spk_id, stream=stream) if format == "pcm": return StreamingResponse(generate_data(model_output)) else: @@ -74,9 +74,9 @@ async def inference_sft(tts_text: str = Form(), spk_id: str = Form(), format: st @app.get("/inference_zero_shot") -async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File(), format: str = Form(default="pcm")): +async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File(), stream: bool = Form(default=False), format: str = Form(default="pcm")): prompt_speech_16k = load_wav(prompt_wav.file, 16000) - model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k) + model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream) if format == "pcm": return StreamingResponse(generate_data(model_output)) else: @@ -84,9 +84,9 @@ async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), @app.get("/inference_cross_lingual") -async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File(), format: str = Form(default="pcm")): +async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File(), stream: bool = Form(default=False), format: str = Form(default="pcm")): prompt_speech_16k = load_wav(prompt_wav.file, 16000) - model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k) + model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream) if format == "pcm": return StreamingResponse(generate_data(model_output)) else: @@ -94,8 +94,8 @@ async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile @app.get("/inference_instruct") -async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form(), format: str = Form(default="pcm")): - model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text) +async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form(), stream: bool = Form(default=False), format: str = Form(default="pcm")): + model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text, stream=stream) if format == "pcm": return StreamingResponse(generate_data(model_output)) else: @@ -103,9 +103,9 @@ async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instr @app.get("/inference_instruct_v2") -async def inference_instruct_v2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File(), format: str = Form(default="pcm")): +async def inference_instruct_v2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File(), stream: bool = Form(default=False), format: str = Form(default="pcm")): prompt_speech_16k = load_wav(prompt_wav.file, 16000) - model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k) + model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream) if format == "pcm": return StreamingResponse(generate_data(model_output)) else: diff --git a/webui.py b/webui.py index 3bdf5d66..bce829d6 100644 --- a/webui.py +++ b/webui.py @@ -165,7 +165,7 @@ def main(): generate_button = gr.Button("生成音频") - audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True) + audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=False) seed_button.click(generate_seed, inputs=[], outputs=seed) generate_button.click(generate_audio,