Make WebUI and API code cleaner (+ 1.5 fixes) (#703)

* rename webui.py to run_webui.py

* remove unused imports

* remove unused code

* move inference code and fix all warnings

* move web app code

* make code easier to read

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused function

* remove msgpack_api.py

* rename API files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* finish updating the doc with the new file names

* finish updating the doc with the new file names

* fix CPU use in the API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor WebUI inference into a class with submodules

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* re-enable streaming in webui inference code

* generalize inference code in webui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* make a unique inference engine class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* cleaning code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement new structure of the API (not working)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reimplement chat endpoint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Picus303 2024-12-07 07:13:19 +01:00 committed by GitHub
parent 954cae1b5d
commit 62eae262c2
45 changed files with 1959 additions and 1697 deletions

View File

@ -45,7 +45,7 @@ body:
description: |
Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
placeholder: |
1. Run the command `python -m tools.post_api -t "xxxxx"`
1. Run the command `python -m tools.api_client -t "xxxxx"`
2. Observe the console output error: `ModuleNotFoundError: No module named 'pyaudio'` (screenshots or logs are even better)
validations:
required: true

View File

@ -185,7 +185,7 @@ pip install -e .[stable]
4. Configure environment variables and access WebUI
In the terminal inside the docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside docker.
Then in the terminal inside the docker container, enter `python tools/webui.py` to start the WebUI service.
Then in the terminal inside the docker container, enter `python tools/run_webui.py` to start the WebUI service.
If you're using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.

View File

@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
We provide a HTTP API for inference. You can use the following command to start the server:
```bash
python -m tools.api \
python -m tools.api_server \
--listen 0.0.0.0:8080 \
--llama-checkpoint-path "checkpoints/fish-speech-1.5" \
--decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@ -78,10 +78,10 @@ python -m tools.api \
After that, you can view and test the API at http://127.0.0.1:8080/.
Below is an example of sending a request using `tools/post_api.py`.
Below is an example of sending a request using `tools/api_client.py`.
```bash
python -m tools.post_api \
python -m tools.api_client \
--text "Text to be input" \
--reference_audio "Path to reference audio" \
--reference_text "Text content of the reference audio" \
@ -93,7 +93,7 @@ The above command indicates synthesizing the desired audio according to the refe
The following example demonstrates that you can use **multiple** reference audio paths and reference audio texts at once. Separate them with spaces in the command.
```bash
python -m tools.post_api \
python -m tools.api_client \
--text "Text to input" \
--reference_audio "reference audio path1" "reference audio path2" \
--reference_text "reference audio text1" "reference audio text2"\
@ -109,7 +109,7 @@ The currently supported reference audio has a maximum total duration of 90 secon
!!! info
To learn more about available parameters, you can use the command `python -m tools.post_api -h`
To learn more about available parameters, you can use the command `python -m tools.api_client -h`
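For reference, the same request can be sent directly from Python. A minimal sketch, assuming the server started above is reachable locally and that `/v1/tts` still accepts plain JSON bodies (as the previous `tools/api.py` did):
```python
# Minimal sketch: synthesize speech over the HTTP API and save it as a WAV file.
# Assumes a local server on port 8080 that accepts "application/json" bodies.
import httpx

payload = {
    "text": "Hello, world!",  # text to synthesize
    "format": "wav",          # output format
    "streaming": False,       # return the whole file at once
}

response = httpx.post("http://127.0.0.1:8080/v1/tts", json=payload, timeout=None)
response.raise_for_status()

with open("output.wav", "wb") as f:
    f.write(response.content)
```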
## GUI Inference
[Download client](https://github.com/AnyaCoder/fish-speech-gui/releases)

View File

@ -44,7 +44,7 @@ pip install -e .[stable]
To build fish-agent, please use the command below under the main folder:
```bash
python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
```
The `--compile` arg is only supported on Python < 3.12 and will greatly speed up token generation.

View File

@ -184,7 +184,7 @@ pip install -e .[stable]
4. 環境変数の設定と WebUI へのアクセス
Docker コンテナ内のターミナルで、`export GRADIO_SERVER_NAME="0.0.0.0"` と入力して、外部から Docker 内の gradio サービスにアクセスできるようにします。
次に、Docker コンテナ内のターミナルで `python tools/webui.py` と入力して WebUI サービスを起動します。
次に、Docker コンテナ内のターミナルで `python tools/run_webui.py` と入力して WebUI サービスを起動します。
WSL または MacOS の場合は、[http://localhost:7860](http://localhost:7860) にアクセスして WebUI インターフェースを開くことができます。

View File

@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます:
```bash
python -m tools.api \
python -m tools.api_server \
--listen 0.0.0.0:8080 \
--llama-checkpoint-path "checkpoints/fish-speech-1.5" \
--decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@ -78,10 +78,10 @@ python -m tools.api \
その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。
以下は、`tools/post_api.py` を使用してリクエストを送信する例です。
以下は、`tools/api_client.py` を使用してリクエストを送信する例です。
```bash
python -m tools.post_api \
python -m tools.api_client \
--text "入力するテキスト" \
--reference_audio "参照音声へのパス" \
--reference_text "参照音声テキスト" \
@ -91,7 +91,7 @@ python -m tools.post_api \
上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。
!!! info
使用可能なパラメータの詳細については、コマンド` python -m tools.post_api -h `を使用してください
使用可能なパラメータの詳細については、コマンド` python -m tools.api_client -h `を使用してください
## WebUI 推論

View File

@ -47,7 +47,7 @@ pip install -e .[stable]
fish-agentを構築するには、メインフォルダで以下のコマンドを使用してください:
```bash
python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
```
`--compile`引数はPython < 3.12でのみサポートされておりトークン生成を大幅に高速化します

View File

@ -185,7 +185,7 @@ pip install -e .[stable]
4. 환경 변수 설정 및 WebUI 접근
Docker 컨테이너 내부의 터미널에서 `export GRADIO_SERVER_NAME="0.0.0.0"`를 입력하여 Docker 내부에서 Gradio 서비스에 외부 접근을 허용합니다.
이후, 터미널에서 `python tools/webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다.
이후, 터미널에서 `python tools/run_webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다.
WSL 또는 macOS를 사용하는 경우 [http://localhost:7860](http://localhost:7860)에서 WebUI 인터페이스를 열 수 있습니다.

View File

@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
추론을 위한 HTTP API를 제공하고 있습니다. 아래의 명령어로 서버를 시작할 수 있습니다:
```bash
python -m tools.api \
python -m tools.api_server \
--listen 0.0.0.0:8080 \
--llama-checkpoint-path "checkpoints/fish-speech-1.5" \
--decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@ -78,10 +78,10 @@ python -m tools.api \
이후, http://127.0.0.1:8080/ 에서 API를 확인하고 테스트할 수 있습니다.
아래는 `tools/post_api.py`를 사용하여 요청을 보내는 예시입니다.
아래는 `tools/api_client.py`를 사용하여 요청을 보내는 예시입니다.
```bash
python -m tools.post_api \
python -m tools.api_client \
--text "입력할 텍스트" \
--reference_audio "참고 음성 경로" \
--reference_text "참고 음성의 텍스트 내용" \
@ -93,7 +93,7 @@ python -m tools.post_api \
다음 예시는 여러 개의 참고 음성 경로와 텍스트를 한꺼번에 사용할 수 있음을 보여줍니다. 명령에서 공백으로 구분하여 입력합니다.
```bash
python -m tools.post_api \
python -m tools.api_client \
--text "입력할 텍스트" \
--reference_audio "참고 음성 경로1" "참고 음성 경로2" \
--reference_text "참고 음성 텍스트1" "참고 음성 텍스트2"\
@ -107,7 +107,7 @@ python -m tools.post_api \
`--reference_audio``--reference_text` 대신에 `--reference_id`(하나만 사용 가능)를 사용할 수 있습니다. 프로젝트 루트 디렉토리에 `references/<your reference_id>` 폴더를 만들어 해당 음성과 주석 텍스트를 넣어야 합니다. 참고 음성은 최대 90초까지 지원됩니다.
!!! info
제공되는 파라미터는 `python -m tools.post_api -h`를 사용하여 확인할 수 있습니다.
제공되는 파라미터는 `python -m tools.api_client -h`를 사용하여 확인할 수 있습니다.
## GUI 추론
[클라이언트 다운로드](https://github.com/AnyaCoder/fish-speech-gui/releases)

View File

@ -47,7 +47,7 @@ pip install -e .[stable]
fish-agent를 구축하려면 메인 폴더에서 아래 명령어를 사용하세요:
```bash
python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
```
`--compile` 인자는 Python < 3.12에서만 지원되며, 토큰 생성 속도를 크게 향상시킵니다.

View File

@ -181,7 +181,7 @@ pip install -e .[stable]
4. Configure as variáveis de ambiente e acesse a WebUI
No terminal do contêiner Docker, digite `export GRADIO_SERVER_NAME="0.0.0.0"` para permitir o acesso externo ao serviço gradio dentro do Docker.
Em seguida, no terminal do contêiner Docker, digite `python tools/webui.py` para iniciar o serviço WebUI.
Em seguida, no terminal do contêiner Docker, digite `python tools/run_webui.py` para iniciar o serviço WebUI.
Se estiver usando WSL ou MacOS, acesse [http://localhost:7860](http://localhost:7860) para abrir a interface WebUI.

View File

@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para iniciar o servidor:
```bash
python -m tools.api \
python -m tools.api_server \
--listen 0.0.0.0:8080 \
--llama-checkpoint-path "checkpoints/fish-speech-1.5" \
--decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@ -78,10 +78,10 @@ python -m tools.api \
Depois disso, é possível visualizar e testar a API em http://127.0.0.1:8080/.
Abaixo está um exemplo de envio de uma solicitação usando `tools/post_api.py`.
Abaixo está um exemplo de envio de uma solicitação usando `tools/api_client.py`.
```bash
python -m tools.post_api \
python -m tools.api_client \
--text "Texto a ser inserido" \
--reference_audio "Caminho para o áudio de referência" \
--reference_text "Conteúdo de texto do áudio de referência" \
@ -91,7 +91,7 @@ python -m tools.post_api \
O comando acima indica a síntese do áudio desejada de acordo com as informações do áudio de referência e a retorna em modo de streaming.
!!! info
Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.post_api -h`
Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.api_client -h`
## Inferência por WebUI

View File

@ -47,7 +47,7 @@ pip install -e .[stable]
Para construir o fish-agent, use o comando abaixo na pasta principal:
```bash
python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
```
O argumento `--compile` só suporta Python < 3.12, o que aumentará muito a velocidade de geração de tokens.

View File

@ -188,7 +188,7 @@ pip install -e .[stable]
4. 配置环境变量,访问 WebUI
在 docker 容器内的终端,输入 `export GRADIO_SERVER_NAME="0.0.0.0"` ,从而让外部可以访问 docker 内的 gradio 服务。
接着在 docker 容器内的终端,输入 `python tools/webui.py` 即可开启 WebUI 服务。
接着在 docker 容器内的终端,输入 `python tools/run_webui.py` 即可开启 WebUI 服务。
如果是 WSL 或者是 MacOS ,访问 [http://localhost:7860](http://localhost:7860) 即可打开 WebUI 界面。

View File

@ -73,7 +73,7 @@ python tools/vqgan/inference.py \
运行以下命令来启动 HTTP 服务:
```bash
python -m tools.api \
python -m tools.api_server \
--listen 0.0.0.0:8080 \
--llama-checkpoint-path "checkpoints/fish-speech-1.5" \
--decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@ -88,10 +88,10 @@ HF_ENDPOINT=https://hf-mirror.com python -m ...(同上)
随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API.
下面是使用`tools/post_api.py`发送请求的示例。
下面是使用`tools/api_client.py`发送请求的示例。
```bash
python -m tools.post_api \
python -m tools.api_client \
--text "要输入的文本" \
--reference_audio "参考音频路径" \
--reference_text "参考音频的文本内容" \
@ -102,7 +102,7 @@ python -m tools.post_api \
下面的示例展示了, 可以一次使用**多个** `参考音频路径``参考音频的文本内容`。在命令里用空格隔开即可。
```bash
python -m tools.post_api \
python -m tools.api_client \
--text "要输入的文本" \
--reference_audio "参考音频路径1" "参考音频路径2" \
--reference_text "参考音频的文本内容1" "参考音频的文本内容2"\
@ -117,7 +117,7 @@ python -m tools.post_api \
里面放上任意对音频与标注文本。 目前支持的参考音频最多加起来总时长90s。
!!! info
要了解有关可用参数的更多信息,可以使用命令`python -m tools.post_api -h`
要了解有关可用参数的更多信息,可以使用命令`python -m tools.api_client -h`
## GUI 推理
[下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases)

View File

@ -49,7 +49,7 @@ pip install -e .[stable]
你需要使用以下指令来构建 fish-agent
```bash
python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
```
`--compile`只能在小于 3.12 版本的 Python 使用,这个功能可以极大程度上提高生成速度。

View File

@ -7,4 +7,4 @@ if [ "${CUDA_ENABLED}" != "true" ]; then
DEVICE="--device cpu"
fi
exec python tools/webui.py ${DEVICE}
exec python tools/run_webui.py ${DEVICE}

View File

@ -176,7 +176,7 @@ def change_infer(
p_infer = subprocess.Popen(
[
PYTHON,
"tools/webui.py",
"tools/run_webui.py",
"--decoder-checkpoint-path",
infer_decoder_model,
"--decoder-config-name",

View File

@ -83,7 +83,7 @@
},
"outputs": [],
"source": [
"!python tools/webui.py \\\n",
"!python tools/run_webui.py \\\n",
" --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n",
" --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
" # --compile"

View File

@ -82,7 +82,7 @@ if not "!flags!"=="" set "flags=!flags:~1!"
echo Debug: flags = !flags!
if "!mode!"=="api" (
%PYTHON_CMD% -m tools.api !flags!
%PYTHON_CMD% -m tools.api_server !flags!
) else if "!mode!"=="infer" (
%PYTHON_CMD% -m tools.webui !flags!
)

View File

@ -1,951 +0,0 @@
import io
import json
import os
import queue
import re
import time
import traceback
import wave
from argparse import ArgumentParser
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Any
import librosa
import numpy as np
import ormsgpack
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from baize.datastructures import ContentType
from kui.asgi import (
Body,
FactoryClass,
HTTPException,
HttpRequest,
HttpView,
JSONResponse,
Kui,
OpenAPI,
StreamResponse,
request,
)
from kui.asgi.routing import MultimethodRoutes
from loguru import logger
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
import struct
from threading import Lock
import httpx
from cachetools import LRUCache, cached
from funasr import AutoModel
from silero_vad import get_speech_timestamps, load_silero_vad
from fish_speech.models.text2semantic.llama import BaseModelArgs
# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
from fish_speech.utils import autocast_exclude_mps, set_seed
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
GenerateRequest,
GenerateResponse,
WrappedGenerateResponse,
launch_thread_safe_queue,
launch_thread_safe_queue_agent,
)
from tools.schema import (
GLOBAL_NUM_SAMPLES,
ASRPackRequest,
ServeASRRequest,
ServeASRResponse,
ServeASRSegment,
ServeAudioPart,
ServeForwardMessage,
ServeMessage,
ServeRequest,
ServeResponse,
ServeStreamDelta,
ServeStreamResponse,
ServeTextPart,
ServeTimedASRResponse,
ServeTTSRequest,
ServeVQGANDecodeRequest,
ServeVQGANDecodeResponse,
ServeVQGANEncodeRequest,
ServeVQGANEncodeResponse,
ServeVQPart,
)
from tools.vqgan.inference import load_model as load_decoder_model
global_lock = Lock()
# Whether to disable keepalive (which is helpful if the server is in the same cluster)
DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
async_client = httpx.AsyncClient(
timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
)
backends = torchaudio.list_audio_backends()
if "ffmpeg" in backends:
backend = "ffmpeg"
else:
backend = "soundfile"
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(bit_depth // 8)
wav_file.setframerate(sample_rate)
wav_header_bytes = buffer.getvalue()
buffer.close()
return wav_header_bytes
# Define utils for web server
async def http_execption_handler(exc: HTTPException):
return JSONResponse(
dict(
statusCode=exc.status_code,
message=exc.content,
error=HTTPStatus(exc.status_code).phrase,
),
exc.status_code,
exc.headers,
)
async def other_exception_handler(exc: "Exception"):
traceback.print_exc()
status = HTTPStatus.INTERNAL_SERVER_ERROR
return JSONResponse(
dict(statusCode=status, message=str(exc), error=status.phrase),
status,
)
def load_audio(reference_audio, sr):
if len(reference_audio) > 255 or not Path(reference_audio).exists():
audio_data = reference_audio
reference_audio = io.BytesIO(audio_data)
waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
if original_sr != sr:
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
waveform = resampler(waveform)
audio = waveform.squeeze().numpy()
return audio
def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
if enable_reference_audio and reference_audio is not None:
# Load audios, and prepare basic info here
reference_audio_content = load_audio(
reference_audio, decoder_model.spec_transform.sample_rate
)
audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
None, None, :
]
audio_lengths = torch.tensor(
[audios.shape[2]], device=decoder_model.device, dtype=torch.long
)
logger.info(
f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
)
# VQ Encoder
if isinstance(decoder_model, FireflyArchitecture):
prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
else:
prompt_tokens = None
logger.info("No reference audio provided")
return prompt_tokens
def decode_vq_tokens(
*,
decoder_model,
codes,
):
feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
logger.info(f"VQ features: {codes.shape}")
if isinstance(decoder_model, FireflyArchitecture):
# VQGAN Inference
return decoder_model.decode(
indices=codes[None],
feature_lengths=feature_lengths,
)[0].squeeze()
raise ValueError(f"Unknown model type: {type(decoder_model)}")
routes = MultimethodRoutes(base_class=HttpView)
def get_content_type(audio_format):
if audio_format == "wav":
return "audio/wav"
elif audio_format == "flac":
return "audio/flac"
elif audio_format == "mp3":
return "audio/mpeg"
else:
return "application/octet-stream"
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_encode(model, audios: list[bytes | torch.Tensor]):
audios = [
(
torch.from_numpy(
librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
)[None]
if isinstance(audio, bytes)
else audio
)
for audio in audios
]
# if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
# raise ValueError("Single audio length is too long (>120s)")
max_length = max(audio.shape[-1] for audio in audios)
print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
max_length = lengths.max().item()
padded = torch.stack(
[
torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
for audio in audios
]
).to(model.device)
features, feature_lengths = model.encode(padded, audio_lengths=lengths)
features, feature_lengths = features.cpu(), feature_lengths.cpu()
return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
@cached(
cache=LRUCache(maxsize=10000),
key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
return batch_encode(model, audios)
@routes.http.post("/v1/vqgan/encode")
def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
start_time = time.time()
tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
return ormsgpack.packb(
ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def vqgan_decode(model, features):
lengths = torch.tensor(
[feature.shape[-1] for feature in features], device=model.device
)
max_length = lengths.max().item()
padded = torch.stack(
[
torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
for feature in features
]
).to(model.device)
# If bs too large, we do micro batch decode
audios, audio_lengths = [], []
for i in range(0, padded.shape[0], 8):
audio, audio_length = model.decode(
padded[i : i + 8], feature_lengths=lengths[i : i + 8]
)
audios.append(audio)
audio_lengths.append(audio_length)
audios = torch.cat(audios, dim=0)
audio_lengths = torch.cat(audio_lengths, dim=0)
audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
@routes.http.post("/v1/vqgan/decode")
def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
start_time = time.time()
audios = vqgan_decode(decoder_model, tokens)
logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
audios = [audio.astype(np.float16).tobytes() for audio in audios]
return ormsgpack.packb(
ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
)
@torch.no_grad()
def batch_asr(model, audios, sr, language="auto"):
resampled_audios = []
for audio in audios:
audio = torchaudio.functional.resample(audio, sr, 16000)
assert audio.ndim == 1
resampled_audios.append(audio)
with global_lock:
res = model.generate(
input=resampled_audios,
batch_size=len(resampled_audios),
language=language,
use_itn=True,
)
results = []
for r, audio in zip(res, audios):
text = r["text"]
text = re.sub(r"<\|.*?\|>", "", text)
duration = len(audio) / sr * 1000
huge_gap = False
if "timestamp" in r and len(r["timestamp"]) > 2:
for timestamp_a, timestamp_b in zip(
r["timestamp"][:-1], r["timestamp"][1:]
):
# If there is a gap of more than 5 seconds, we consider it as a huge gap
if timestamp_b[0] - timestamp_a[1] > 5000:
huge_gap = True
break
# Doesn't make sense to have a huge gap at the end
if duration - r["timestamp"][-1][1] > 3000:
huge_gap = True
results.append(
{
"text": text,
"duration": duration,
"huge_gap": huge_gap,
}
)
return results
@routes.http.post("/v1/asr")
def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
start_time = time.time()
audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
audios = [torch.from_numpy(audio).float() for audio in audios]
if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
raise HTTPException(status_code=400, detail="Audio length is too long")
transcriptions = batch_asr(
asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
)
logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
return ormsgpack.packb(
ServeASRResponse(transcriptions=transcriptions),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
from fish_speech.conversation import Conversation, Message
def execute_request(
input_queue: queue.Queue,
tokenizer: FishTokenizer,
config: BaseModelArgs,
request: ServeRequest,
device: str = "cuda:0",
):
im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
messages = []
for message in request.messages:
messages.append(message.to_conversation_message())
assert len(messages) >= 1, "At least one message is required"
# assert messages[-1].role == "user", "The last message must be from the user"
if messages[-1].role == "user":
messages.append(
Message(role="assistant", parts=[], add_im_end=False, modality="voice")
)
elif messages[-1].role == "raw":
messages[-1].add_im_start = False
messages[-1].add_im_end = False
messages[-1].modality = "voice"
else:
assert (
messages[-1].role == "assistant"
), "The last message must be from the assistant"
messages[-1].add_im_end = False
conv = Conversation(messages=messages)
# conv.visualize(tokenizer)
prompt = conv.encode_for_inference(
tokenizer=tokenizer, num_codebooks=config.num_codebooks
).to(device)
if request.streaming:
for i in range(request.num_samples):
yield ServeStreamResponse(
sample_id=i,
delta=ServeStreamDelta(
role="assistant",
),
)
req = {
"prompt": prompt,
"max_new_tokens": request.max_new_tokens,
"im_end_id": im_end_id,
"temperature": request.temperature,
"top_p": request.top_p,
"repetition_penalty": request.repetition_penalty,
"num_samples": request.num_samples,
"early_stop_threshold": request.early_stop_threshold,
}
start = time.time()
response_queue = queue.Queue()
input_queue.put(GenerateRequest(req, response_queue))
# Decoding
decode_buffer = [[] for _ in range(request.num_samples)]
parts = [[] for _ in range(request.num_samples)]
def send_reset_buffer(sample_id):
nonlocal decode_buffer
if len(decode_buffer[sample_id]) == 0:
return
decoded = tokenizer.decode(decode_buffer[sample_id])
part = ServeTextPart(text=decoded)
if request.streaming:
yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
else:
parts[sample_id].append(part)
decode_buffer[sample_id] = []
# Decode process
finished = [False for _ in range(request.num_samples)]
stats = {}
idx = 0
while True:
response = response_queue.get()
if response in ["stop", "error"]:
break
for sample_id, tokens in enumerate(response):
if finished[sample_id]:
continue
if tokens[0] == im_end_id:
finished[sample_id] = True
if request.streaming:
yield from send_reset_buffer(sample_id)
yield ServeStreamResponse(
sample_id=sample_id,
finish_reason="stop",
stats=stats,
)
continue
is_semantic = (
tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
)
if is_semantic and request.streaming:
yield from send_reset_buffer(sample_id)
# Streaming vq
_tokens = tokens[1:].clone()
if config.share_codebook_embeddings is False:
for i in range(len(_tokens)):
_tokens[i] -= config.codebook_size * i
yield ServeStreamResponse(
sample_id=sample_id,
delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
)
continue
# Not streaming vq
if is_semantic:
yield from send_reset_buffer(sample_id)
# None streaming vq
if len(parts[sample_id]) == 0 or not isinstance(
parts[sample_id][-1], ServeVQPart
):
_tokens = tokens[1:].clone()
if config.share_codebook_embeddings is False:
for i in range(len(_tokens)):
_tokens[i] -= config.codebook_size * i
parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
else:
for codebook_id, value in enumerate(tokens[1:, :]):
val = value.item()
if config.share_codebook_embeddings is False:
val -= config.codebook_size * codebook_id
parts[sample_id][-1].codes[codebook_id].append(val)
continue
if not is_semantic:
# Stream text decode is not supported now
decode_buffer[sample_id].append(tokens[0, 0])
if idx == 0:
stats["time_to_first_token"] = (time.time() - start) * 1000
idx += 1
for sample_id in range(request.num_samples):
yield from send_reset_buffer(sample_id)
stats["total_time"] = (time.time() - start) * 1000
stats["total_tokens"] = idx
if request.streaming:
for sample_id in range(request.num_samples):
if finished[sample_id]:
continue
yield ServeStreamResponse(
finish_reason=response, stats=stats, sample_id=sample_id
)
return
yield ServeResponse(
messages=[
ServeMessage(role="assistant", parts=parts[i])
for i in range(request.num_samples)
],
finish_reason=response,
stats=stats,
)
@routes.http.post("/v1/chat")
def api_invoke_chat(
req: Annotated[ServeRequest, Body(exclusive=True)],
):
"""
Invoke model and generate audio
"""
# This makes torch compile happy
assert (
req.num_samples == GLOBAL_NUM_SAMPLES
), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
content_type = request.headers.get("Content-Type", "application/json")
json_mode = "application/json" in content_type
async def wrapped_generator():
generator = execute_request(llama_queue, tokenizer, config, req, args.device)
for i in generator:
if json_mode:
body = i.model_dump_json().encode("utf-8")
yield b"data: " + body + b"\n\n"
else:
body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
yield struct.pack("I", len(body)) + body
# Naive mode
if req.streaming is False:
result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
if json_mode:
return JSONResponse(result.model_dump())
else:
return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
return StreamResponse(
iterable=wrapped_generator(), content_type="text/event-stream"
)
@torch.inference_mode()
def inference(req: ServeTTSRequest):
idstr: str | None = req.reference_id
if idstr is not None:
ref_folder = Path("references") / idstr
ref_folder.mkdir(parents=True, exist_ok=True)
ref_audios = list_files(
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
)
if req.use_memory_cache == "never" or (
req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
):
prompt_tokens = [
encode_reference(
decoder_model=decoder_model,
reference_audio=audio_to_bytes(str(ref_audio)),
enable_reference_audio=True,
)
for ref_audio in ref_audios
]
prompt_texts = [
read_ref_text(str(ref_audio.with_suffix(".lab")))
for ref_audio in ref_audios
]
else:
logger.info("Use same references")
else:
# Parse reference audio aka prompt
refs = req.references
if req.use_memory_cache == "never" or (
req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
):
prompt_tokens = [
encode_reference(
decoder_model=decoder_model,
reference_audio=ref.audio,
enable_reference_audio=True,
)
for ref in refs
]
prompt_texts = [ref.text for ref in refs]
else:
logger.info("Use same references")
if req.seed is not None:
set_seed(req.seed)
logger.warning(f"set seed: {req.seed}")
# LLAMA Inference
request = dict(
device=decoder_model.device,
max_new_tokens=req.max_new_tokens,
text=(
req.text
if not req.normalize
else ChnNormedText(raw_text=req.text).normalize()
),
top_p=req.top_p,
repetition_penalty=req.repetition_penalty,
temperature=req.temperature,
compile=args.compile,
iterative_prompt=req.chunk_length > 0,
chunk_length=req.chunk_length,
max_length=4096,
prompt_tokens=prompt_tokens,
prompt_text=prompt_texts,
)
response_queue = queue.Queue()
llama_queue.put(
GenerateRequest(
request=request,
response_queue=response_queue,
)
)
if req.streaming:
yield wav_chunk_header()
segments = []
while True:
result: WrappedGenerateResponse = response_queue.get()
if result.status == "error":
raise result.response
break
result: GenerateResponse = result.response
if result.action == "next":
break
with autocast_exclude_mps(
device_type=decoder_model.device.type, dtype=args.precision
):
fake_audios = decode_vq_tokens(
decoder_model=decoder_model,
codes=result.codes,
)
fake_audios = fake_audios.float().cpu().numpy()
if req.streaming:
yield (fake_audios * 32768).astype(np.int16).tobytes()
else:
segments.append(fake_audios)
if req.streaming:
return
if len(segments) == 0:
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content="No audio generated, please check the input text.",
)
fake_audios = np.concatenate(segments, axis=0)
yield fake_audios
async def inference_async(req: ServeTTSRequest):
for chunk in inference(req):
yield chunk
async def buffer_to_async_generator(buffer):
yield buffer
@routes.http.post("/v1/tts")
async def api_invoke_model(
req: Annotated[ServeTTSRequest, Body(exclusive=True)],
):
"""
Invoke model and generate audio
"""
if args.max_text_length > 0 and len(req.text) > args.max_text_length:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
content=f"Text is too long, max length is {args.max_text_length}",
)
if req.streaming and req.format != "wav":
raise HTTPException(
HTTPStatus.BAD_REQUEST,
content="Streaming only supports WAV format",
)
if req.streaming:
return StreamResponse(
iterable=inference_async(req),
headers={
"Content-Disposition": f"attachment; filename=audio.{req.format}",
},
content_type=get_content_type(req.format),
)
else:
fake_audios = next(inference(req))
buffer = io.BytesIO()
sf.write(
buffer,
fake_audios,
decoder_model.spec_transform.sample_rate,
format=req.format,
)
return StreamResponse(
iterable=buffer_to_async_generator(buffer.getvalue()),
headers={
"Content-Disposition": f"attachment; filename=audio.{req.format}",
},
content_type=get_content_type(req.format),
)
@routes.http.post("/v1/health")
async def api_health():
"""
Health check
"""
return JSONResponse({"status": "ok"})
def parse_args():
parser = ArgumentParser()
parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
parser.add_argument("--load-asr-model", action="store_true")
parser.add_argument(
"--llama-checkpoint-path",
type=str,
default="checkpoints/fish-speech-1.4",
)
parser.add_argument(
"--decoder-checkpoint-path",
type=str,
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--half", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--max-text-length", type=int, default=0)
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
parser.add_argument("--workers", type=int, default=1)
return parser.parse_args()
# Define Kui app
openapi = OpenAPI(
{
"title": "Fish Speech API",
"version": "1.4.2",
},
).routes
class MsgPackRequest(HttpRequest):
async def data(
self,
) -> Annotated[
Any, ContentType("application/msgpack"), ContentType("application/json")
]:
if self.content_type == "application/msgpack":
return ormsgpack.unpackb(await self.body)
elif self.content_type == "application/json":
return await self.json
raise HTTPException(
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
headers={"Accept": "application/msgpack, application/json"},
)
app = Kui(
routes=routes + openapi[1:], # Remove the default route
exception_handlers={
HTTPException: http_execption_handler,
Exception: other_exception_handler,
},
factory_class=FactoryClass(http=MsgPackRequest),
cors_config={},
)
def load_asr_model(*, device="cuda", hub="ms"):
return AutoModel(
model="iic/SenseVoiceSmall",
device=device,
disable_pbar=True,
hub=hub,
)
# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any global variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.
# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.
@app.on_startup
def initialize_app(app: Kui):
global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
prompt_tokens, prompt_texts = [], []
args = parse_args() # args same as ones in other processes
args.precision = torch.half if args.half else torch.bfloat16
if args.load_asr_model:
logger.info(f"Loading ASR model...")
asr_model = load_asr_model(device=args.device)
logger.info("Loading Llama model...")
if args.mode == "tts":
llama_queue = launch_thread_safe_queue(
checkpoint_path=args.llama_checkpoint_path,
device=args.device,
precision=args.precision,
compile=args.compile,
)
else:
llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
checkpoint_path=args.llama_checkpoint_path,
device=args.device,
precision=args.precision,
compile=args.compile,
)
logger.info("Llama model loaded, loading VQ-GAN model...")
decoder_model = load_decoder_model(
config_name=args.decoder_config_name,
checkpoint_path=args.decoder_checkpoint_path,
device=args.device,
)
logger.info("VQ-GAN model loaded, warming up...")
vad_model = load_silero_vad()
logger.info("VAD model loaded, warming up...")
if args.mode == "tts":
# Dry run to ensure models work and avoid first-time latency
list(
inference(
ServeTTSRequest(
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=0,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.5,
temperature=0.7,
emotion=None,
format="wav",
)
)
)
logger.info(f"Warming up done, starting server at http://{args.listen}")
if __name__ == "__main__":
import uvicorn
args = parse_args()
host, port = args.listen.split(":")
uvicorn.run(
"tools.api:app",
host=host,
port=int(port),
workers=args.workers,
log_level="info",
)

View File

@ -69,10 +69,6 @@ def parse_args():
parser.add_argument(
"--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
)
parser.add_argument(
"--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
)
parser.add_argument("--opus_bitrate", type=int, default=-1000)
parser.add_argument(
"--latency",
type=str,
@ -112,11 +108,9 @@ def parse_args():
parser.add_argument(
"--use_memory_cache",
type=str,
default="never",
choices=["on-demand", "never"],
help="Cache encoded references codes in memory.\n"
"If `on-demand`, the server will use cached encodings\n "
"instead of encoding reference audio again.",
default="off",
choices=["on", "off"],
help="Cache encoded references codes in memory.\n",
)
parser.add_argument(
"--seed",
@ -154,14 +148,14 @@ if __name__ == "__main__":
data = {
"text": args.text,
"references": [
ServeReferenceAudio(audio=ref_audio, text=ref_text)
ServeReferenceAudio(
audio=ref_audio if ref_audio is not None else b"", text=ref_text
)
for ref_text, ref_audio in zip(ref_texts, byte_audios)
],
"reference_id": idstr,
"normalize": args.normalize,
"format": args.format,
"mp3_bitrate": args.mp3_bitrate,
"opus_bitrate": args.opus_bitrate,
"max_new_tokens": args.max_new_tokens,
"chunk_length": args.chunk_length,
"top_p": args.top_p,

tools/api_server.py — new file, 98 lines
View File

@ -0,0 +1,98 @@
from threading import Lock
import pyrootutils
import uvicorn
from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
from loguru import logger
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import (
ASRView,
ChatView,
HealthView,
TTSView,
VQGANDecodeView,
VQGANEncodeView,
)
class API(ExceptionHandler):
def __init__(self):
self.args = parse_args()
self.routes = [
("/v1/health", HealthView),
("/v1/vqgan/encode", VQGANEncodeView),
("/v1/vqgan/decode", VQGANDecodeView),
("/v1/asr", ASRView),
("/v1/tts", TTSView),
("/v1/chat", ChatView),
]
self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
self.openapi = OpenAPI(
{
"title": "Fish Speech API",
"version": "1.5.0",
},
).routes
# Initialize the app
self.app = Kui(
routes=self.routes + self.openapi[1:], # Remove the default route
exception_handlers={
HTTPException: self.http_exception_handler,
Exception: self.other_exception_handler,
},
factory_class=FactoryClass(http=MsgPackRequest),
cors_config={},
)
# Add the state variables
self.app.state.lock = Lock()
self.app.state.device = self.args.device
self.app.state.max_text_length = self.args.max_text_length
# Associate the app with the model manager
self.app.on_startup(self.initialize_app)
async def initialize_app(self, app: Kui):
# Make the ModelManager available to the views
app.state.model_manager = ModelManager(
mode=self.args.mode,
device=self.args.device,
half=self.args.half,
compile=self.args.compile,
asr_enabled=self.args.load_asr_model,
llama_checkpoint_path=self.args.llama_checkpoint_path,
decoder_checkpoint_path=self.args.decoder_checkpoint_path,
decoder_config_name=self.args.decoder_config_name,
)
logger.info(f"Startup done, listening server at http://{self.args.listen}")
# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.
# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.
if __name__ == "__main__":
api = API()
host, port = api.args.listen.split(":")
uvicorn.run(
api.app,
host=host,
port=int(port),
workers=api.args.workers,
log_level="info",
)

View File

@ -14,8 +14,8 @@ import ormsgpack
import soundfile as sf
from .schema import (
ServeChatRequest,
ServeMessage,
ServeRequest,
ServeTextPart,
ServeVQGANDecodeRequest,
ServeVQGANEncodeRequest,
@ -163,7 +163,7 @@ class FishE2EAgent:
else:
user_codes = None
request = ServeRequest(
request = ServeChatRequest(
messages=prev_messages
+ (
[

View File

@ -0,0 +1,193 @@
import gc
import queue
from typing import Generator
import numpy as np
import torch
from loguru import logger
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from tools.inference_engine.reference_loader import ReferenceLoader
from tools.inference_engine.utils import InferenceResult, wav_chunk_header
from tools.inference_engine.vq_manager import VQManager
from tools.llama.generate import (
GenerateRequest,
GenerateResponse,
WrappedGenerateResponse,
)
from tools.schema import ServeTTSRequest
class TTSInferenceEngine(ReferenceLoader, VQManager):
def __init__(
self,
llama_queue: queue.Queue,
decoder_model: FireflyArchitecture,
precision: torch.dtype,
compile: bool,
) -> None:
super().__init__()
self.llama_queue = llama_queue
self.decoder_model = decoder_model
self.precision = precision
self.compile = compile
@torch.inference_mode()
def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
"""
Main inference function:
- Loads the reference audio and text.
- Calls the LLAMA model for inference.
- Decodes the VQ tokens to audio.
"""
ref_id: str | None = req.reference_id
prompt_tokens, prompt_texts = [], []
# Load the reference audio and text based on id or hash
if ref_id is not None:
prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
elif req.references:
prompt_tokens, prompt_texts = self.load_by_hash(
req.references, req.use_memory_cache
)
# Set the random seed if provided
if req.seed is not None:
set_seed(req.seed)
logger.warning(f"set seed: {req.seed}")
# Get the symbolic tokens from the LLAMA model
response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
# Get the sample rate from the decoder model
sample_rate = self.decoder_model.spec_transform.sample_rate
# If streaming, send the header
if req.streaming:
yield InferenceResult(
code="header",
audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
error=None,
)
segments = []
while True:
# Get the response from the LLAMA model
wrapped_result: WrappedGenerateResponse = response_queue.get()
if wrapped_result.status == "error":
yield InferenceResult(
code="error",
audio=None,
error=(
wrapped_result.response
if isinstance(wrapped_result.response, Exception)
else Exception("Unknown error")
),
)
break
# Check the response type
if not isinstance(wrapped_result.response, GenerateResponse):
raise TypeError(
"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
)
result: GenerateResponse = wrapped_result.response
if result.action != "next":
segment = self.get_audio_segment(result)
if req.streaming: # Used only by the API server
yield InferenceResult(
code="segment",
audio=(sample_rate, segment),
error=None,
)
else:
segments.append(segment)
else:
break
# Clean up the memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Edge case: no audio generated
if len(segments) == 0:
yield InferenceResult(
code="error",
audio=None,
error=RuntimeError("No audio generated, please check the input text."),
)
else:
# Streaming or not, return the final audio
audio = np.concatenate(segments, axis=0)
yield InferenceResult(
code="final",
audio=(sample_rate, audio),
error=None,
)
return None
def send_Llama_request(
self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
) -> queue.Queue:
"""
Send a request to the LLAMA model to generate the symbolic tokens.
"""
# Prepare the request
request = dict(
device=self.decoder_model.device,
max_new_tokens=req.max_new_tokens,
text=(
req.text
if not req.normalize
else ChnNormedText(raw_text=req.text).normalize()
),
top_p=req.top_p,
repetition_penalty=req.repetition_penalty,
temperature=req.temperature,
compile=self.compile,
iterative_prompt=req.chunk_length > 0,
chunk_length=req.chunk_length,
max_length=4096,
prompt_tokens=prompt_tokens,
prompt_text=prompt_texts,
)
# Create a queue to get the response
response_queue = queue.Queue()
# Send the request to the LLAMA model
self.llama_queue.put(
GenerateRequest(
request=request,
response_queue=response_queue,
)
)
return response_queue
def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
"""
Decode the VQ tokens to audio.
"""
# Don't use autocast on MPS devices
with autocast_exclude_mps(
device_type=self.decoder_model.device.type, dtype=self.precision
):
# Decode the symbolic tokens to audio
segment = self.decode_vq_tokens(codes=result.codes)
# Convert the audio to numpy
return segment.float().cpu().numpy()
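For context, this is roughly how the engine is driven, mirroring the warm-up call in `tools/run_webui.py` further down in this diff; the checkpoint paths and device are assumptions:
```python
# Sketch: build a TTSInferenceEngine and collect the generated audio.
# Mirrors the warm-up in tools/run_webui.py; paths and device are assumptions.
import torch

from tools.inference_engine import TTSInferenceEngine
from tools.llama.generate import launch_thread_safe_queue
from tools.schema import ServeTTSRequest
from tools.vqgan.inference import load_model as load_decoder_model

llama_queue = launch_thread_safe_queue(
    checkpoint_path="checkpoints/fish-speech-1.5",  # assumed checkpoint location
    device="cuda",
    precision=torch.bfloat16,
    compile=False,
)
decoder_model = load_decoder_model(
    config_name="firefly_gan_vq",
    checkpoint_path="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    device="cuda",
)
engine = TTSInferenceEngine(
    llama_queue=llama_queue,
    decoder_model=decoder_model,
    precision=torch.bfloat16,
    compile=False,
)

request = ServeTTSRequest(text="Hello world.", references=[], format="wav")
for result in engine.inference(request):
    if result.code == "final":
        sample_rate, audio = result.audio  # numpy waveform, ready to save
```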

View File

@ -0,0 +1,128 @@
import io
from hashlib import sha256
from pathlib import Path
from typing import Callable, Literal, Tuple
import torch
import torchaudio
from loguru import logger
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.schema import ServeReferenceAudio
class ReferenceLoader:
def __init__(self) -> None:
"""
Component of the TTSInferenceEngine class.
Loads and manages the cache for the reference audio and text.
"""
self.ref_by_id: dict = {}
self.ref_by_hash: dict = {}
# Make Pylance happy (attribute/method not defined...)
self.decoder_model: FireflyArchitecture
self.encode_reference: Callable
# Define the torchaudio backend
backends = torchaudio.list_audio_backends()
if "ffmpeg" in backends:
self.backend = "ffmpeg"
else:
self.backend = "soundfile"
def load_by_id(
self,
id: str,
use_cache: Literal["on", "off"],
) -> Tuple:
# Load the references audio and text by id
ref_folder = Path("references") / id
ref_folder.mkdir(parents=True, exist_ok=True)
ref_audios = list_files(
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
)
if use_cache == "off" or id not in self.ref_by_id:
# If the references are not already loaded, encode them
prompt_tokens = [
self.encode_reference(
decoder_model=self.decoder_model,
reference_audio=audio_to_bytes(str(ref_audio)),
enable_reference_audio=True,
)
for ref_audio in ref_audios
]
prompt_texts = [
read_ref_text(str(ref_audio.with_suffix(".lab")))
for ref_audio in ref_audios
]
self.ref_by_id[id] = (prompt_tokens, prompt_texts)
else:
# Reuse already encoded references
logger.info("Use same references")
prompt_tokens, prompt_texts = self.ref_by_id[id]
return prompt_tokens, prompt_texts
def load_by_hash(
self,
references: list[ServeReferenceAudio],
use_cache: Literal["on", "off"],
) -> Tuple:
# Load the references audio and text by hash
audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
cache_used = False
prompt_tokens, prompt_texts = [], []
for i, ref in enumerate(references):
if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
# If the references are not already loaded, encode them
prompt_tokens.append(
self.encode_reference(
decoder_model=self.decoder_model,
reference_audio=ref.audio,
enable_reference_audio=True,
)
)
prompt_texts.append(ref.text)
self.ref_by_hash[audio_hashes[i]] = (prompt_tokens[-1], prompt_texts[-1])
else:
# Reuse already encoded references
prompt_token, prompt_text = self.ref_by_hash[audio_hashes[i]]
prompt_texts.append(prompt_text)
prompt_tokens.append(prompt_token)
cache_used = True
if cache_used:
logger.info("Use same references")
return prompt_tokens, prompt_texts
def load_audio(self, reference_audio, sr):
"""
Load the audio data from a file or bytes.
"""
if len(reference_audio) > 255 or not Path(reference_audio).exists():
audio_data = reference_audio
reference_audio = io.BytesIO(audio_data)
waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
if original_sr != sr:
resampler = torchaudio.transforms.Resample(
orig_freq=original_sr, new_freq=sr
)
waveform = resampler(waveform)
audio = waveform.squeeze().numpy()
return audio
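For clarity, a small sketch of the `references/<id>` layout that `load_by_id` expects (audio files with matching `.lab` transcripts) and how a request points at it; the id shown is hypothetical:
```python
# Sketch: reference folder layout consumed by ReferenceLoader.load_by_id.
#
#   references/
#   └── my_voice/            # <- reference_id (hypothetical)
#       ├── sample.wav       # any supported audio extension
#       └── sample.lab       # transcript of sample.wav
#
# A request can then point at the id instead of inlining audio bytes:
from tools.schema import ServeTTSRequest

request = ServeTTSRequest(
    text="Text to be synthesized.",
    reference_id="my_voice",  # hypothetical id matching the folder above
    use_memory_cache="on",    # reuse the encoded references on later calls
)
```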

View File

@ -0,0 +1,42 @@
import io
import wave
from dataclasses import dataclass
from typing import Literal, Optional, Tuple
import numpy as np
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
@dataclass
class InferenceResult:
code: Literal["header", "segment", "error", "final"]
audio: Optional[Tuple[int, np.ndarray]]
error: Optional[Exception]
def normalize_text(user_input: str, use_normalization: bool) -> str:
"""Normalize user input text if needed."""
if use_normalization:
return ChnNormedText(raw_text=user_input).normalize()
else:
return user_input
def wav_chunk_header(
sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
) -> np.ndarray:
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(bit_depth // 8)
wav_file.setframerate(sample_rate)
wav_header_bytes = buffer.getvalue()
buffer.close()
# Convert to numpy array
wav_header = np.frombuffer(wav_header_bytes, dtype=np.uint8)
return wav_header
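A short sketch of how this header is intended to be used for streaming, following the pattern of the old `tools/api.py` (send the header once, then 16-bit PCM chunks); the int16 conversion is an assumption carried over from that old path:
```python
# Sketch: stream a WAV response by sending the chunk header first,
# then 16-bit PCM segments (pattern taken from the removed tools/api.py).
import numpy as np

from tools.inference_engine.utils import wav_chunk_header

sample_rate = 44100
header = wav_chunk_header(sample_rate=sample_rate)  # np.uint8 header bytes

def pcm_chunks(segments):
    """Yield WAV header bytes, then each float waveform as 16-bit PCM."""
    yield header.tobytes()
    for segment in segments:  # each segment: float waveform in [-1, 1]
        yield (segment * 32768).astype(np.int16).tobytes()
```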

View File

@ -0,0 +1,57 @@
from typing import Callable
import torch
from loguru import logger
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
class VQManager:
def __init__(self):
# Make Pylance happy (attribute/method not defined...)
self.decoder_model: FireflyArchitecture
self.load_audio: Callable
def decode_vq_tokens(self, codes):
feature_lengths = torch.tensor(
[codes.shape[1]], device=self.decoder_model.device
)
logger.info(f"VQ features: {codes.shape}")
if isinstance(self.decoder_model, FireflyArchitecture):
return self.decoder_model.decode(
indices=codes[None],
feature_lengths=feature_lengths,
)[0].squeeze()
raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
def encode_reference(self, reference_audio, enable_reference_audio):
if enable_reference_audio and reference_audio is not None:
# Load audios, and prepare basic info here
reference_audio_content = self.load_audio(
reference_audio, self.decoder_model.spec_transform.sample_rate
)
audios = torch.from_numpy(reference_audio_content).to(
self.decoder_model.device
)[None, None, :]
audio_lengths = torch.tensor(
[audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
)
logger.info(
f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
)
# VQ Encoder
if isinstance(self.decoder_model, FireflyArchitecture):
prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
else:
raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
else:
prompt_tokens = None
logger.info("No reference audio provided")
return prompt_tokens

View File

@ -1,95 +0,0 @@
import os
from argparse import ArgumentParser
from pathlib import Path
import httpx
import ormsgpack
from tools.schema import ServeReferenceAudio, ServeTTSRequest
api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
def audio_request():
# priority: ref_id > references
request = ServeTTSRequest(
text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
# reference_id="114514",
references=[
ServeReferenceAudio(
audio=open("lengyue.wav", "rb").read(),
text=open("lengyue.lab", "r", encoding="utf-8").read(),
)
],
streaming=True,
)
api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
with (
httpx.Client() as client,
open("hello.wav", "wb") as f,
):
with client.stream(
"POST",
"http://127.0.0.1:8080/v1/tts",
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
headers={
"authorization": f"Bearer {api_key}",
"content-type": "application/msgpack",
},
timeout=None,
) as response:
for chunk in response.iter_bytes():
f.write(chunk)
def asr_request(audio_path: Path):
# Read the audio file
with open(
str(audio_path),
"rb",
) as audio_file:
audio_data = audio_file.read()
# Prepare the request data
request_data = {
"audio": audio_data,
"language": "en", # Optional: specify the language
"ignore_timestamps": False, # Optional: set to True to ignore precise timestamps
}
# Send the request
with httpx.Client() as client:
response = client.post(
"https://api.fish.audio/v1/asr",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/msgpack",
},
content=ormsgpack.packb(request_data),
)
# Parse the response
result = response.json()
print(f"Transcribed text: {result['text']}")
print(f"Audio duration: {result['duration']} seconds")
for segment in result["segments"]:
print(f"Segment: {segment['text']}")
print(f"Start time: {segment['start']}, End time: {segment['end']}")
def parse_args():
parser = ArgumentParser()
parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
asr_request(args.audio_path)

tools/run_webui.py — new file, 101 lines
View File

@ -0,0 +1,101 @@
import os
from argparse import ArgumentParser
from pathlib import Path
import pyrootutils
import torch
from loguru import logger
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from tools.inference_engine import TTSInferenceEngine
from tools.llama.generate import launch_thread_safe_queue
from tools.schema import ServeTTSRequest
from tools.vqgan.inference import load_model as load_decoder_model
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper
# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--llama-checkpoint-path",
type=Path,
default="checkpoints/fish-speech-1.5",
)
parser.add_argument(
"--decoder-checkpoint-path",
type=Path,
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--half", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--max-gradio-length", type=int, default=0)
parser.add_argument("--theme", type=str, default="light")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
args.precision = torch.half if args.half else torch.bfloat16
# Check if CUDA is available
if not torch.cuda.is_available():
logger.info("CUDA is not available, running on CPU.")
args.device = "cpu"
logger.info("Loading Llama model...")
llama_queue = launch_thread_safe_queue(
checkpoint_path=args.llama_checkpoint_path,
device=args.device,
precision=args.precision,
compile=args.compile,
)
logger.info("Loading VQ-GAN model...")
decoder_model = load_decoder_model(
config_name=args.decoder_config_name,
checkpoint_path=args.decoder_checkpoint_path,
device=args.device,
)
logger.info("Decoder model loaded, warming up...")
# Create the inference engine
inference_engine = TTSInferenceEngine(
llama_queue=llama_queue,
decoder_model=decoder_model,
compile=args.compile,
precision=args.precision,
)
# Dry run to check if the model is loaded correctly and avoid the first-time latency
list(
inference_engine.inference(
ServeTTSRequest(
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=0,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.5,
temperature=0.7,
format="wav",
)
)
)
logger.info("Warming up done, launching the web UI...")
# Get the inference function with the immutable arguments
inference_fct = get_inference_wrapper(inference_engine)
app = build_app(inference_fct, args.theme)
app.launch(show_api=True)
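For Docker or remote access, the final `app.launch(...)` call above can bind to an explicit host and port. A minimal sketch using the standard Gradio launch arguments `server_name` and `server_port`, mirroring the `export GRADIO_SERVER_NAME="0.0.0.0"` workflow described in the docs:

```python
import os

# Read host/port from the environment, falling back to Gradio's defaults
app.launch(
    server_name=os.getenv("GRADIO_SERVER_NAME", "127.0.0.1"),
    server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
    show_api=True,
)
```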

View File

@ -1,16 +1,14 @@
import os
import queue
from dataclasses import dataclass
from typing import Annotated, Literal, Optional
from typing import Annotated, Literal
import torch
from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
from pydantic import BaseModel, Field, conint, conlist
from pydantic.functional_validators import SkipValidation
from fish_speech.conversation import Message, TextPart, VQPart
GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
class ServeVQPart(BaseModel):
type: Literal["vq"] = "vq"
@ -64,7 +62,7 @@ class ServeASRResponse(BaseModel):
class ServeMessage(BaseModel):
role: Literal["system", "assistant", "user", "raw"]
role: Literal["system", "assistant", "user"]
parts: list[ServeVQPart | ServeTextPart]
def to_conversation_message(self):
@ -85,7 +83,7 @@ class ServeMessage(BaseModel):
return new_message
class ServeRequest(BaseModel):
class ServeChatRequest(BaseModel):
messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
max_new_tokens: int = 1024
top_p: float = 0.7
@ -114,11 +112,6 @@ class ServeVQGANDecodeResponse(BaseModel):
audios: list[bytes]
class ServeReferenceAudio(BaseModel):
audio: bytes
text: str
class ServeForwardMessage(BaseModel):
role: str
content: str
@ -150,24 +143,11 @@ class ServeReferenceAudio(BaseModel):
return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
class ServeChatRequestV1(BaseModel):
model: str = "llama3-8b"
messages: list[ServeForwardMessage] = []
audio: bytes | None = None
temperature: float = 1.0
top_p: float = 1.0
max_tokens: int = 256
voice: str = "jessica"
tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
class ServeTTSRequest(BaseModel):
text: str
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# Audio format
format: Literal["wav", "pcm", "mp3"] = "wav"
mp3_bitrate: Literal[64, 128, 192] = 128
# References audios for in-context learning
references: list[ServeReferenceAudio] = []
# Reference id
@ -175,16 +155,16 @@ class ServeTTSRequest(BaseModel):
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
seed: int | None = None
use_memory_cache: Literal["on-demand", "never"] = "never"
use_memory_cache: Literal["on", "off"] = "off"
# Normalize text for en & zh, this increase stability for numbers
normalize: bool = True
mp3_bitrate: Optional[int] = 64
opus_bitrate: Optional[int] = -1000
# Balance mode will reduce latency to 300ms, but may decrease stability
latency: Literal["normal", "balanced"] = "normal"
# not usually used below
streaming: bool = False
max_new_tokens: int = 1024
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
class Config:
# Allow arbitrary types for pytorch related types
arbitrary_types_allowed = True
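For reference, a request built against the updated schema looks like this. A minimal sketch assuming the defaults shown above; only `text` is strictly required, and `ref.wav` is a placeholder path:

```python
from tools.schema import ServeReferenceAudio, ServeTTSRequest

# In-context reference: raw audio bytes plus their transcript
with open("ref.wav", "rb") as f:
    reference = ServeReferenceAudio(audio=f.read(), text="Reference transcript.")

req = ServeTTSRequest(
    text="Hello world.",
    references=[reference],
    use_memory_cache="on",  # new literal values: "on" / "off"
)
```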

View File

@ -0,0 +1,57 @@
import struct
from functools import partial
import ormsgpack
from tools.server.agent.generate import generate_responses
from tools.server.agent.pre_generation_utils import prepare_messages
def execute_request(input_queue, tokenizer, config, request, device):
"""
This function prepares the conversation, encodes the request,
sends the generation request, and handles decoding/streaming.
It returns a response generator (ServeResponse or ServeStreamResponse).
"""
prompt, im_end_id = prepare_messages(request, tokenizer, config)
yield from generate_responses(
input_queue, tokenizer, config, request, prompt, im_end_id, device
)
def response_generator(req, llama_queue, tokenizer, config, device):
"""
Non-streaming response wrapper for the chat endpoint.
Only returns the final result.
"""
generator = execute_request(llama_queue, tokenizer, config, req, device)
return next(generator)
async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
"""
Streaming response wrapper for the chat endpoint.
Returns the response in chunks.
"""
generator = execute_request(llama_queue, tokenizer, config, req, device)
for i in generator:
if json_mode:
body = i.model_dump_json().encode("utf-8")
yield b"data: " + body + b"\n\n"
else:
body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
yield struct.pack("I", len(body)) + body
def get_response_generator(
llama_queue, tokenizer, config, req, device, json_mode
) -> partial:
"""
Get the correct response generator based on the request.
"""
if not req.streaming:
return partial(response_generator, req, llama_queue, tokenizer, config, device)
else:
return partial(
streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
)
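When the chat endpoint streams in msgpack mode, each chunk is a 4-byte native-endian length prefix (little-endian on most platforms) followed by a msgpack-encoded `ServeStreamResponse`, as produced by `struct.pack("I", len(body)) + body` above. A minimal client-side reader sketch, assuming `resp` is an `httpx` streaming response and that client and server share the same byte order:

```python
import struct

import ormsgpack


def read_msgpack_frames(resp):
    """Yield decoded msgpack frames from a length-prefixed byte stream."""
    buffer = b""
    for chunk in resp.iter_bytes():
        buffer += chunk
        while len(buffer) >= 4:
            (length,) = struct.unpack("I", buffer[:4])
            if len(buffer) < 4 + length:
                break  # wait for the rest of the frame
            frame, buffer = buffer[4 : 4 + length], buffer[4 + length :]
            yield ormsgpack.unpackb(frame)
```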

View File

@ -0,0 +1,119 @@
import time
from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
from tools.server.agent.generation_utils import (
initialize_decode_buffers,
process_response_tokens,
send_reset_buffer,
)
from tools.server.agent.pre_generation_utils import (
create_generation_request,
send_generation_request,
)
def generate_responses(
input_queue, tokenizer, config, request, prompt, im_end_id, device
):
"""
Main generation function that handles the conversation, encodes the request,
sends the generation request, and handles decoding/streaming.
It returns a response generator (ServeResponse or ServeStreamResponse).
"""
stats = {}
start = time.time()
stats["start_time"] = start
stats["tokens_count"] = 0
# Prepare and send the generation request
req = create_generation_request(prompt, request, im_end_id, device)
response_queue = send_generation_request(input_queue, req)
decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
while True:
response = response_queue.get()
# Handle abnormal finish or error
if response in ["stop", "error"]:
finish_reason = response
break
# Process the response tokens
is_first_token = stats["tokens_count"] == 0
responses = process_response_tokens(
response,
tokenizer,
config,
request,
decode_buffer,
parts,
finished,
im_end_id,
stats,
start,
is_first_token,
)
# Yield the responses if streaming
if request.streaming and responses:
for r in responses:
yield r
stats["tokens_count"] += 1
# Check if all samples are finished
if all(finished):
finish_reason = "stop"
break
# Finalize the response
final_responses = finalize_response(
request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
)
for fr in final_responses:
yield fr
def finalize_response(
request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
):
"""
Finalize the response by sending the remaining text buffers.
"""
responses = []
# Send the remaining text buffers
for sample_id in range(request.num_samples):
responses.extend(
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
)
# Calculate the final stats
stats["total_time"] = (time.time() - stats["start_time"]) * 1000
stats["total_tokens"] = stats["tokens_count"]
# If streaming, send the final chunks for each sample
if request.streaming:
for sample_id in range(request.num_samples):
if finished[sample_id]:
continue
responses.append(
ServeStreamResponse(
finish_reason=finish_reason, stats=stats, sample_id=sample_id
)
)
else:
# If not streaming, send the full messages for each sample
full_messages = [
ServeMessage(role="assistant", parts=parts[i])
for i in range(request.num_samples)
]
responses.append(
ServeResponse(
messages=full_messages,
finish_reason=finish_reason,
stats=stats,
)
)
return responses

View File

@ -0,0 +1,122 @@
import time
from tools.schema import (
ServeStreamDelta,
ServeStreamResponse,
ServeTextPart,
ServeVQPart,
)
def initialize_decode_buffers(num_samples):
"""Initialise the decode buffers for each sample."""
decode_buffer = [[] for _ in range(num_samples)]
parts = [[] for _ in range(num_samples)]
finished = [False for _ in range(num_samples)]
return decode_buffer, parts, finished
def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
"""Send the remaining text buffer for a sample."""
if len(decode_buffer[sample_id]) == 0:
return []
decoded = tokenizer.decode(decode_buffer[sample_id])
part = ServeTextPart(text=decoded)
responses = []
if request.streaming:
responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
else:
parts[sample_id].append(part)
decode_buffer[sample_id] = []
return responses
def handle_semantic_tokens(tokens, config, sample_id, parts, request):
"""Handle the semantic tokens returned by the model."""
responses = []
_tokens = tokens[1:].clone()
if not config.share_codebook_embeddings:
for i in range(len(_tokens)):
_tokens[i] -= config.codebook_size * i
# If streaming, send the VQ parts directly
if request.streaming:
responses.append(
ServeStreamResponse(
sample_id=sample_id,
delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
)
)
else:
# If not streaming, accumulate the VQ parts
if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
else:
# Accumulate the codes
for codebook_id, value in enumerate(_tokens):
parts[sample_id][-1].codes[codebook_id].append(value.item())
return responses
def process_response_tokens(
response,
tokenizer,
config,
request,
decode_buffer,
parts,
finished,
im_end_id,
stats,
start,
is_first_token,
):
"""Process the response tokens returned by the model."""
responses = []
for sample_id, tokens in enumerate(response):
if finished[sample_id]:
continue
# End of the conversation
if tokens[0] == im_end_id:
finished[sample_id] = True
# Send the remaining text buffer
responses.extend(
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
)
if request.streaming:
responses.append(
ServeStreamResponse(
sample_id=sample_id,
finish_reason="stop",
stats=stats,
)
)
continue
# Check if the token is semantic
is_semantic = (
tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
)
if is_semantic:
# Before the semantic tokens, send the remaining text buffer
responses.extend(
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
)
responses.extend(
handle_semantic_tokens(tokens, config, sample_id, parts, request)
)
else:
# Accumulate the text tokens; they are decoded and flushed later by send_reset_buffer
decode_buffer[sample_id].append(tokens[0, 0])
if is_first_token:
stats["time_to_first_token"] = (time.time() - start) * 1000
return responses
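The offset removal in `handle_semantic_tokens` assumes that, when codebook embeddings are not shared, codebook `i` occupies its own id range of size `codebook_size`. A worked mini-example of that mapping, with an illustrative codebook size:

```python
import torch

codebook_size = 1024  # illustrative value
# One token per codebook, as emitted by the model (global ids)
tokens = torch.tensor([5, 1024 + 7, 2 * 1024 + 3])

for i in range(len(tokens)):
    tokens[i] -= codebook_size * i

print(tokens.tolist())  # [5, 7, 3] -> per-codebook local ids
```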

View File

@ -0,0 +1,72 @@
import queue
from fish_speech.conversation import Conversation, Message
from fish_speech.tokenizer import IM_END_TOKEN
from tools.llama.generate import GenerateRequest
def prepare_messages(request, tokenizer, config):
"""
Reorganise the provided list of messages into a conversation.
Encode the conversation for inference.
"""
# Convert the messages to ConversationMessage objects
messages = [msg.to_conversation_message() for msg in request.messages]
if len(messages) < 1:
raise ValueError("At least one message is required")
# Check the last message to determine the next step
last_role = messages[-1].role
match last_role:
case "user":
# The last message is from the user, ask the assistant to respond with a new message
messages.append(
Message(role="assistant", parts=[], add_im_end=False, modality="voice")
)
case "raw":
# The last message is raw text, ask the assistant to complete it
messages[-1].add_im_start = False
messages[-1].add_im_end = False
messages[-1].modality = "voice"
case "assistant":
# The last message is from the assistant, ask the assistant to continue
messages[-1].add_im_end = False
case _:
# Any other role cannot start or continue a generation
raise ValueError("The last message must come from the user, the assistant, or be raw text")
# Create a conversation object and encode it for inference
conv = Conversation(messages=messages)
prompt = conv.encode_for_inference(
tokenizer=tokenizer, num_codebooks=config.num_codebooks
)
im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
return prompt, im_end_id
def create_generation_request(prompt, request, im_end_id, device):
"""
Convert the request into a dictionary that can be sent to the model for generation.
"""
req = {
"prompt": prompt.to(device),
"max_new_tokens": request.max_new_tokens,
"im_end_id": im_end_id,
"temperature": request.temperature,
"top_p": request.top_p,
"repetition_penalty": request.repetition_penalty,
"num_samples": request.num_samples,
"early_stop_threshold": request.early_stop_threshold,
}
return req
def send_generation_request(input_queue, req):
"""
Send the generation request to the model and return a queue to get the response.
"""
response_queue = queue.Queue()
input_queue.put(GenerateRequest(req, response_queue))
return response_queue
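`prepare_messages` expects the request messages in the shape defined by `tools/schema.py`. A minimal sketch of building such a request, assuming `streaming` is a field of `ServeChatRequest` (the handlers above read `request.streaming`):

```python
from tools.schema import ServeChatRequest, ServeMessage, ServeTextPart

req = ServeChatRequest(
    messages=[
        ServeMessage(
            role="system",
            parts=[ServeTextPart(text="You are a helpful voice assistant.")],
        ),
        ServeMessage(role="user", parts=[ServeTextPart(text="Tell me a short story.")]),
    ],
    streaming=True,
)
```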

75
tools/server/api_utils.py Normal file
View File

@ -0,0 +1,75 @@
from argparse import ArgumentParser
from http import HTTPStatus
from typing import Annotated, Any
import ormsgpack
from baize.datastructures import ContentType
from kui.asgi import HTTPException, HttpRequest
from tools.inference_engine import TTSInferenceEngine
from tools.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
def parse_args():
parser = ArgumentParser()
parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
parser.add_argument("--load-asr-model", action="store_true")
parser.add_argument(
"--llama-checkpoint-path",
type=str,
default="checkpoints/fish-speech-1.5",
)
parser.add_argument(
"--decoder-checkpoint-path",
type=str,
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--half", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--max-text-length", type=int, default=0)
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
parser.add_argument("--workers", type=int, default=1)
return parser.parse_args()
class MsgPackRequest(HttpRequest):
async def data(
self,
) -> Annotated[
Any, ContentType("application/msgpack"), ContentType("application/json")
]:
if self.content_type == "application/msgpack":
return ormsgpack.unpackb(await self.body)
elif self.content_type == "application/json":
return await self.json
raise HTTPException(
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
headers={"Accept": "application/msgpack, application/json"},
)
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
for chunk in inference(req, engine):
if isinstance(chunk, bytes):
yield chunk
async def buffer_to_async_generator(buffer):
yield buffer
def get_content_type(audio_format):
if audio_format == "wav":
return "audio/wav"
elif audio_format == "flac":
return "audio/flac"
elif audio_format == "mp3":
return "audio/mpeg"
else:
return "application/octet-stream"

View File

@ -0,0 +1,27 @@
import traceback
from http import HTTPStatus
from kui.asgi import HTTPException, JSONResponse
class ExceptionHandler:
async def http_exception_handler(self, exc: HTTPException):
return JSONResponse(
dict(
statusCode=exc.status_code,
message=exc.content,
error=HTTPStatus(exc.status_code).phrase,
),
exc.status_code,
exc.headers,
)
async def other_exception_handler(self, exc: Exception):
traceback.print_exc()
status = HTTPStatus.INTERNAL_SERVER_ERROR
return JSONResponse(
dict(statusCode=status, message=str(exc), error=status.phrase),
status,
)

41
tools/server/inference.py Normal file
View File

@ -0,0 +1,41 @@
from http import HTTPStatus
import numpy as np
from kui.asgi import HTTPException
from tools.inference_engine import TTSInferenceEngine
from tools.schema import ServeTTSRequest
AMPLITUDE = 32768  # 2**15: scales float audio in [-1, 1] to the int16 PCM range
def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
"""
Wrapper for the inference function.
Used in the API server.
"""
for result in engine.inference(req):
match result.code:
case "header":
if isinstance(result.audio, tuple):
yield result.audio[1]
case "error":
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content=str(result.error),
)
case "segment":
if isinstance(result.audio, tuple):
yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
case "final":
if isinstance(result.audio, tuple):
yield result.audio[1]
return None # Stop the generator
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content="No audio generated, please check the input text.",
)
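On the client side, the streamed `segment` chunks are therefore raw 16-bit PCM scaled by `AMPLITUDE`. A minimal sketch of turning them back into floating-point samples, assuming the collected bytes are in `pcm_bytes` and that client and server share native (little-endian) byte order:

```python
import numpy as np

AMPLITUDE = 32768  # must match the server-side scaling factor


def pcm16_to_float(pcm_bytes: bytes) -> np.ndarray:
    """Convert int16 PCM bytes back to float32 samples in [-1, 1]."""
    samples = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32)
    return samples / AMPLITUDE
```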

View File

@ -0,0 +1,119 @@
import torch
from funasr import AutoModel
from loguru import logger
from tools.inference_engine import TTSInferenceEngine
from tools.llama.generate import (
launch_thread_safe_queue,
launch_thread_safe_queue_agent,
)
from tools.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
from tools.vqgan.inference import load_model as load_decoder_model
ASR_MODEL_NAME = "iic/SenseVoiceSmall"
class ModelManager:
def __init__(
self,
mode: str,
device: str,
half: bool,
compile: bool,
asr_enabled: bool,
llama_checkpoint_path: str,
decoder_checkpoint_path: str,
decoder_config_name: str,
) -> None:
self.mode = mode
self.device = device
self.half = half
self.compile = compile
self.precision = torch.half if half else torch.bfloat16
# Check if CUDA is available
if not torch.cuda.is_available():
self.device = "cpu"
logger.info("CUDA is not available, running on CPU.")
# Load the ASR model if enabled
if asr_enabled:
self.load_asr_model(self.device)
# Load the TTS models
self.load_llama_model(
llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
)
self.load_decoder_model(
decoder_config_name, decoder_checkpoint_path, self.device
)
self.tts_inference_engine = TTSInferenceEngine(
llama_queue=self.llama_queue,
decoder_model=self.decoder_model,
precision=self.precision,
compile=self.compile,
)
# Warm up the models
if self.mode == "tts":
self.warm_up(self.tts_inference_engine)
def load_asr_model(self, device, hub="ms") -> None:
self.asr_model = AutoModel(
model=ASR_MODEL_NAME,
device=device,
disable_pbar=True,
hub=hub,
)
logger.info("ASR model loaded.")
def load_llama_model(
self, checkpoint_path, device, precision, compile, mode
) -> None:
if mode == "tts":
self.llama_queue = launch_thread_safe_queue(
checkpoint_path=checkpoint_path,
device=device,
precision=precision,
compile=compile,
)
elif mode == "agent":
self.llama_queue, self.tokenizer, self.config = (
launch_thread_safe_queue_agent(
checkpoint_path=checkpoint_path,
device=device,
precision=precision,
compile=compile,
)
)
else:
raise ValueError(f"Invalid mode: {mode}")
logger.info("LLAMA model loaded.")
def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
self.decoder_model = load_decoder_model(
config_name=config_name,
checkpoint_path=checkpoint_path,
device=device,
)
logger.info("Decoder model loaded.")
def warm_up(self, tts_inference_engine) -> None:
request = ServeTTSRequest(
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=0,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.5,
temperature=0.7,
format="wav",
)
list(inference(request, tts_inference_engine))
logger.info("Models warmed up.")

129
tools/server/model_utils.py Normal file
View File

@ -0,0 +1,129 @@
import io
import re
import librosa
import torch
import torchaudio
from cachetools import LRUCache, cached
CACHE_MAXSIZE = 10000
MICRO_BATCH_SIZE = 8
ASR_SAMPLE_RATE = 16000
HUGE_GAP_THRESHOLD = 4000
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_encode(model, audios_list: list[bytes]):
audios: list[torch.Tensor] = [
(
torch.from_numpy(
librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
)[None]
if isinstance(audio, bytes)
else audio
)
for audio in audios_list
]
lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
max_length = lengths.max().item()
print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
padded = torch.stack(
[
torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
for audio in audios
]
).to(model.device)
features, feature_lengths = model.encode(padded, audio_lengths=lengths)
features, feature_lengths = features.cpu(), feature_lengths.cpu()
return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
@cached(
cache=LRUCache(maxsize=CACHE_MAXSIZE),
key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
return batch_encode(model, audios)
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def vqgan_decode(model, features):
lengths = torch.tensor(
[feature.shape[-1] for feature in features], device=model.device
)
max_length = lengths.max().item()
padded = torch.stack(
[
torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
for feature in features
]
).to(model.device)
# If the batch size is too large, decode in micro-batches
audios, audio_lengths = [], []
for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
audio, audio_length = model.decode(
padded[i : i + MICRO_BATCH_SIZE],
feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
)
audios.append(audio)
audio_lengths.append(audio_length)
audios = torch.cat(audios, dim=0)
audio_lengths = torch.cat(audio_lengths, dim=0)
audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
@torch.no_grad()
def batch_asr(model, lock, audios, sr, language="auto"):
resampled_audios = []
for audio in audios:
audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
assert audio.ndim == 1
resampled_audios.append(audio)
with lock:
res = model.generate(
input=resampled_audios,
batch_size=len(resampled_audios),
language=language,
use_itn=True,
)
results = []
for r, audio in zip(res, audios):
text = r["text"]
text = re.sub(r"<\|.*?\|>", "", text)
duration = len(audio) / sr * 1000
huge_gap = False
if "timestamp" in r and len(r["timestamp"]) > 2:
for timestamp_a, timestamp_b in zip(
r["timestamp"][:-1], r["timestamp"][1:]
):
# If there is a gap of more than 4 seconds (timestamps are in ms), treat it as a huge gap
if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
huge_gap = True
break
# A large trailing gap after the last timestamp also counts as a huge gap
if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
huge_gap = True
results.append(
{
"text": text,
"duration": duration,
"huge_gap": huge_gap,
}
)
return results
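The encode cache is keyed on `(model.device, tuple(audios))`, so repeated requests with identical reference audio bytes skip the VQ-GAN forward pass. A minimal usage sketch, assuming `decoder_model` has already been loaded as in `ModelManager` above and `ref.wav` is a placeholder path:

```python
from tools.server.model_utils import cached_vqgan_batch_encode

with open("ref.wav", "rb") as f:
    ref_bytes = f.read()

# First call runs batch_encode; an identical second call is served from the LRU cache.
tokens_first = cached_vqgan_batch_encode(decoder_model, [ref_bytes])
tokens_again = cached_vqgan_batch_encode(decoder_model, [ref_bytes])
```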

246
tools/server/views.py Normal file
View File

@ -0,0 +1,246 @@
import io
import os
import time
from http import HTTPStatus
import numpy as np
import ormsgpack
import soundfile as sf
import torch
from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request
from loguru import logger
from tools.schema import (
ServeASRRequest,
ServeASRResponse,
ServeChatRequest,
ServeTTSRequest,
ServeVQGANDecodeRequest,
ServeVQGANDecodeResponse,
ServeVQGANEncodeRequest,
ServeVQGANEncodeResponse,
)
from tools.server.agent import get_response_generator
from tools.server.api_utils import (
buffer_to_async_generator,
get_content_type,
inference_async,
)
from tools.server.inference import inference_wrapper as inference
from tools.server.model_manager import ModelManager
from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode
MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
class HealthView(HttpView):
"""
Return the health status of the server.
"""
@classmethod
async def post(cls):
return JSONResponse({"status": "ok"})
class VQGANEncodeView(HttpView):
"""
Encode the audio into symbolic tokens.
"""
@classmethod
async def post(cls):
# Decode the request
payload = await request.data()
req = ServeVQGANEncodeRequest(**payload)
# Get the model from the app
model_manager: ModelManager = request.app.state.model_manager
decoder_model = model_manager.decoder_model
# Encode the audio
start_time = time.time()
tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
logger.info(
f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
)
# Return the response
return ormsgpack.packb(
ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
class VQGANDecodeView(HttpView):
"""
Decode the symbolic tokens into audio.
"""
@classmethod
async def post(cls):
# Decode the request
payload = await request.data()
req = ServeVQGANDecodeRequest(**payload)
# Get the model from the app
model_manager: ModelManager = request.app.state.model_manager
decoder_model = model_manager.decoder_model
# Decode the audio
tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
start_time = time.time()
audios = vqgan_decode(decoder_model, tokens)
logger.info(
f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
)
audios = [audio.astype(np.float16).tobytes() for audio in audios]
# Return the response
return ormsgpack.packb(
ServeVQGANDecodeResponse(audios=audios),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
class ASRView(HttpView):
"""
Perform automatic speech recognition on the audio.
"""
@classmethod
async def post(cls):
# Decode the request
payload = await request.data()
req = ServeASRRequest(**payload)
# Get the model from the app
model_manager: ModelManager = request.app.state.model_manager
asr_model = model_manager.asr_model
lock = request.app.state.lock
# Perform ASR
start_time = time.time()
audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
audios = [torch.from_numpy(audio).float() for audio in audios]
if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
raise HTTPException(status_code=400, content="Audio length is too long")
transcriptions = batch_asr(
asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
)
logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
# Return the response
return ormsgpack.packb(
ServeASRResponse(transcriptions=transcriptions),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
class TTSView(HttpView):
"""
Perform text-to-speech on the input text.
"""
@classmethod
async def post(cls):
# Decode the request
payload = await request.data()
req = ServeTTSRequest(**payload)
# Get the model from the app
app_state = request.app.state
model_manager: ModelManager = app_state.model_manager
engine = model_manager.tts_inference_engine
sample_rate = engine.decoder_model.spec_transform.sample_rate
# Check if the text is too long
if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
content=f"Text is too long, max length is {app_state.max_text_length}",
)
# Check if streaming is enabled
if req.streaming and req.format != "wav":
raise HTTPException(
HTTPStatus.BAD_REQUEST,
content="Streaming only supports WAV format",
)
# Perform TTS
if req.streaming:
return StreamResponse(
iterable=inference_async(req, engine),
headers={
"Content-Disposition": f"attachment; filename=audio.{req.format}",
},
content_type=get_content_type(req.format),
)
else:
fake_audios = next(inference(req, engine))
buffer = io.BytesIO()
sf.write(
buffer,
fake_audios,
sample_rate,
format=req.format,
)
return StreamResponse(
iterable=buffer_to_async_generator(buffer.getvalue()),
headers={
"Content-Disposition": f"attachment; filename=audio.{req.format}",
},
content_type=get_content_type(req.format),
)
class ChatView(HttpView):
"""
Perform chatbot inference on the input text.
"""
@classmethod
async def post(cls):
# Decode the request
payload = await request.data()
req = ServeChatRequest(**payload)
# Check that the number of samples requested is correct
if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
)
# Get the type of content provided
content_type = request.headers.get("Content-Type", "application/json")
json_mode = "application/json" in content_type
# Get the models from the app
model_manager: ModelManager = request.app.state.model_manager
llama_queue = model_manager.llama_queue
tokenizer = model_manager.tokenizer
config = model_manager.config
device = request.app.state.device
# Get the response generators
response_generator = get_response_generator(
llama_queue, tokenizer, config, req, device, json_mode
)
# Return the response in the correct format
if req.streaming is False:
result = response_generator()
if json_mode:
return JSONResponse(result.model_dump())
else:
return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
return StreamResponse(
iterable=response_generator(), content_type="text/event-stream"
)
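A matching client for the chat view: in JSON mode the streaming response is a server-sent-event style stream where each line is `data: <ServeStreamResponse JSON>`. A minimal sketch; the `/v1/chat` path is an assumption, since the route table is defined outside this file, and the `"type": "text"` tag mirrors how `ServeVQPart` is tagged `"vq"`:

```python
import json

import httpx

payload = {
    "messages": [
        {"role": "user", "parts": [{"type": "text", "text": "Hello!"}]}
    ],
    "streaming": True,
}

with httpx.Client() as client, client.stream(
    "POST",
    "http://127.0.0.1:8080/v1/chat",  # hypothetical route; see the server app for the actual mapping
    json=payload,
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk)
```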

View File

@ -1,570 +0,0 @@
import gc
import html
import io
import os
import queue
import wave
from argparse import ArgumentParser
from functools import partial
from pathlib import Path
import gradio as gr
import librosa
import numpy as np
import pyrootutils
import torch
from loguru import logger
from transformers import AutoTokenizer
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from fish_speech.i18n import i18n
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from tools.api import decode_vq_tokens, encode_reference
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
GenerateRequest,
GenerateResponse,
WrappedGenerateResponse,
launch_thread_safe_queue,
)
from tools.schema import (
GLOBAL_NUM_SAMPLES,
ASRPackRequest,
ServeASRRequest,
ServeASRResponse,
ServeASRSegment,
ServeAudioPart,
ServeForwardMessage,
ServeMessage,
ServeReferenceAudio,
ServeRequest,
ServeResponse,
ServeStreamDelta,
ServeStreamResponse,
ServeTextPart,
ServeTimedASRResponse,
ServeTTSRequest,
ServeVQGANDecodeRequest,
ServeVQGANDecodeResponse,
ServeVQGANEncodeRequest,
ServeVQGANEncodeResponse,
ServeVQPart,
)
from tools.vqgan.inference import load_model as load_decoder_model
# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"
HEADER_MD = f"""# Fish Speech
{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
SPACE_IMPORTED = False
def build_html_error_message(error):
return f"""
<div style="color: red;
font-weight: bold;">
{html.escape(str(error))}
</div>
"""
@torch.inference_mode()
def inference(req: ServeTTSRequest):
idstr: str | None = req.reference_id
prompt_tokens, prompt_texts = [], []
if idstr is not None:
ref_folder = Path("references") / idstr
ref_folder.mkdir(parents=True, exist_ok=True)
ref_audios = list_files(
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
)
if req.use_memory_cache == "never" or (
req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
):
prompt_tokens = [
encode_reference(
decoder_model=decoder_model,
reference_audio=audio_to_bytes(str(ref_audio)),
enable_reference_audio=True,
)
for ref_audio in ref_audios
]
prompt_texts = [
read_ref_text(str(ref_audio.with_suffix(".lab")))
for ref_audio in ref_audios
]
else:
logger.info("Use same references")
else:
# Parse reference audio aka prompt
refs = req.references
if req.use_memory_cache == "never" or (
req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
):
prompt_tokens = [
encode_reference(
decoder_model=decoder_model,
reference_audio=ref.audio,
enable_reference_audio=True,
)
for ref in refs
]
prompt_texts = [ref.text for ref in refs]
else:
logger.info("Use same references")
if req.seed is not None:
set_seed(req.seed)
logger.warning(f"set seed: {req.seed}")
# LLAMA Inference
request = dict(
device=decoder_model.device,
max_new_tokens=req.max_new_tokens,
text=(
req.text
if not req.normalize
else ChnNormedText(raw_text=req.text).normalize()
),
top_p=req.top_p,
repetition_penalty=req.repetition_penalty,
temperature=req.temperature,
compile=args.compile,
iterative_prompt=req.chunk_length > 0,
chunk_length=req.chunk_length,
max_length=4096,
prompt_tokens=prompt_tokens,
prompt_text=prompt_texts,
)
response_queue = queue.Queue()
llama_queue.put(
GenerateRequest(
request=request,
response_queue=response_queue,
)
)
segments = []
while True:
result: WrappedGenerateResponse = response_queue.get()
if result.status == "error":
yield None, None, build_html_error_message(result.response)
break
result: GenerateResponse = result.response
if result.action == "next":
break
with autocast_exclude_mps(
device_type=decoder_model.device.type, dtype=args.precision
):
fake_audios = decode_vq_tokens(
decoder_model=decoder_model,
codes=result.codes,
)
fake_audios = fake_audios.float().cpu().numpy()
segments.append(fake_audios)
if len(segments) == 0:
return (
None,
None,
build_html_error_message(
i18n("No audio generated, please check the input text.")
),
)
# No matter streaming or not, we need to return the final audio
audio = np.concatenate(segments, axis=0)
yield None, (decoder_model.spec_transform.sample_rate, audio), None
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
n_audios = 4
global_audio_list = []
global_error_list = []
def inference_wrapper(
text,
enable_reference_audio,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
batch_infer_num,
):
audios = []
errors = []
for _ in range(batch_infer_num):
result = inference(
text,
enable_reference_audio,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
)
_, audio_data, error_message = next(result)
audios.append(
gr.Audio(value=audio_data if audio_data else None, visible=True),
)
errors.append(
gr.HTML(value=error_message if error_message else None, visible=True),
)
for _ in range(batch_infer_num, n_audios):
audios.append(
gr.Audio(value=None, visible=False),
)
errors.append(
gr.HTML(value=None, visible=False),
)
return None, *audios, *errors
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(bit_depth // 8)
wav_file.setframerate(sample_rate)
wav_header_bytes = buffer.getvalue()
buffer.close()
return wav_header_bytes
def normalize_text(user_input, use_normalization):
if use_normalization:
return ChnNormedText(raw_text=user_input).normalize()
else:
return user_input
def update_examples():
examples_dir = Path("references")
examples_dir.mkdir(parents=True, exist_ok=True)
example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
return gr.Dropdown(choices=example_audios + [""])
def build_app():
with gr.Blocks(theme=gr.themes.Base()) as app:
gr.Markdown(HEADER_MD)
# Use light theme by default
app.load(
None,
None,
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
% args.theme,
)
# Inference
with gr.Row():
with gr.Column(scale=3):
text = gr.Textbox(
label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
)
refined_text = gr.Textbox(
label=i18n("Realtime Transform Text"),
placeholder=i18n(
"Normalization Result Preview (Currently Only Chinese)"
),
lines=5,
interactive=False,
)
with gr.Row():
normalize = gr.Checkbox(
label=i18n("Text Normalization"),
value=False,
)
with gr.Row():
with gr.Column():
with gr.Tab(label=i18n("Advanced Config")):
with gr.Row():
chunk_length = gr.Slider(
label=i18n("Iterative Prompt Length, 0 means off"),
minimum=0,
maximum=300,
value=200,
step=8,
)
max_new_tokens = gr.Slider(
label=i18n(
"Maximum tokens per batch, 0 means no limit"
),
minimum=0,
maximum=2048,
value=0,
step=8,
)
with gr.Row():
top_p = gr.Slider(
label="Top-P",
minimum=0.6,
maximum=0.9,
value=0.7,
step=0.01,
)
repetition_penalty = gr.Slider(
label=i18n("Repetition Penalty"),
minimum=1,
maximum=1.5,
value=1.2,
step=0.01,
)
with gr.Row():
temperature = gr.Slider(
label="Temperature",
minimum=0.6,
maximum=0.9,
value=0.7,
step=0.01,
)
seed = gr.Number(
label="Seed",
info="0 means randomized inference, otherwise deterministic",
value=0,
)
with gr.Tab(label=i18n("Reference Audio")):
with gr.Row():
gr.Markdown(
i18n(
"5 to 10 seconds of reference audio, useful for specifying speaker."
)
)
with gr.Row():
reference_id = gr.Textbox(
label=i18n("Reference ID"),
placeholder="Leave empty to use uploaded references",
)
with gr.Row():
use_memory_cache = gr.Radio(
label=i18n("Use Memory Cache"),
choices=["never", "on-demand", "always"],
value="on-demand",
)
with gr.Row():
reference_audio = gr.Audio(
label=i18n("Reference Audio"),
type="filepath",
)
with gr.Row():
reference_text = gr.Textbox(
label=i18n("Reference Text"),
lines=1,
placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
value="",
)
with gr.Column(scale=3):
with gr.Row():
error = gr.HTML(
label=i18n("Error Message"),
visible=True,
)
with gr.Row():
audio = gr.Audio(
label=i18n("Generated Audio"),
type="numpy",
interactive=False,
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
generate = gr.Button(
value="\U0001F3A7 " + i18n("Generate"), variant="primary"
)
text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
def inference_wrapper(
text,
normalize,
reference_id,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
use_memory_cache,
):
references = []
if reference_audio:
# Convert the file path to audio bytes
with open(reference_audio, "rb") as audio_file:
audio_bytes = audio_file.read()
references = [
ServeReferenceAudio(audio=audio_bytes, text=reference_text)
]
req = ServeTTSRequest(
text=text,
normalize=normalize,
reference_id=reference_id if reference_id else None,
references=references,
max_new_tokens=max_new_tokens,
chunk_length=chunk_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
seed=int(seed) if seed else None,
use_memory_cache=use_memory_cache,
)
for result in inference(req):
if result[2]: # Error message
return None, result[2]
elif result[1]: # Audio data
return result[1], None
return None, i18n("No audio generated")
# Submit
generate.click(
inference_wrapper,
[
refined_text,
normalize,
reference_id,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
use_memory_cache,
],
[audio, error],
concurrency_limit=1,
)
return app
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--llama-checkpoint-path",
type=Path,
default="checkpoints/fish-speech-1.5",
)
parser.add_argument(
"--decoder-checkpoint-path",
type=Path,
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--half", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--max-gradio-length", type=int, default=0)
parser.add_argument("--theme", type=str, default="light")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
args.precision = torch.half if args.half else torch.bfloat16
# Check if CUDA is available
if not torch.cuda.is_available():
logger.info("CUDA is not available, running on CPU.")
args.device = "cpu"
logger.info("Loading Llama model...")
llama_queue = launch_thread_safe_queue(
checkpoint_path=args.llama_checkpoint_path,
device=args.device,
precision=args.precision,
compile=args.compile,
)
logger.info("Llama model loaded, loading VQ-GAN model...")
decoder_model = load_decoder_model(
config_name=args.decoder_config_name,
checkpoint_path=args.decoder_checkpoint_path,
device=args.device,
)
logger.info("Decoder model loaded, warming up...")
# Dry run to check if the model is loaded correctly and avoid the first-time latency
list(
inference(
ServeTTSRequest(
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=0,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.5,
temperature=0.7,
emotion=None,
format="wav",
)
)
)
logger.info("Warming up done, launching the web UI...")
app = build_app()
app.launch(show_api=True)

173
tools/webui/__init__.py Normal file
View File

@ -0,0 +1,173 @@
from typing import Callable
import gradio as gr
from fish_speech.i18n import i18n
from tools.inference_engine.utils import normalize_text
from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Base()) as app:
gr.Markdown(HEADER_MD)
# Use light theme by default
app.load(
None,
None,
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
% theme,
)
# Inference
with gr.Row():
with gr.Column(scale=3):
text = gr.Textbox(
label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
)
refined_text = gr.Textbox(
label=i18n("Realtime Transform Text"),
placeholder=i18n(
"Normalization Result Preview (Currently Only Chinese)"
),
lines=5,
interactive=False,
)
with gr.Row():
normalize = gr.Checkbox(
label=i18n("Text Normalization"),
value=False,
)
with gr.Row():
with gr.Column():
with gr.Tab(label=i18n("Advanced Config")):
with gr.Row():
chunk_length = gr.Slider(
label=i18n("Iterative Prompt Length, 0 means off"),
minimum=0,
maximum=300,
value=200,
step=8,
)
max_new_tokens = gr.Slider(
label=i18n(
"Maximum tokens per batch, 0 means no limit"
),
minimum=0,
maximum=2048,
value=0,
step=8,
)
with gr.Row():
top_p = gr.Slider(
label="Top-P",
minimum=0.6,
maximum=0.9,
value=0.7,
step=0.01,
)
repetition_penalty = gr.Slider(
label=i18n("Repetition Penalty"),
minimum=1,
maximum=1.5,
value=1.2,
step=0.01,
)
with gr.Row():
temperature = gr.Slider(
label="Temperature",
minimum=0.6,
maximum=0.9,
value=0.7,
step=0.01,
)
seed = gr.Number(
label="Seed",
info="0 means randomized inference, otherwise deterministic",
value=0,
)
with gr.Tab(label=i18n("Reference Audio")):
with gr.Row():
gr.Markdown(
i18n(
"5 to 10 seconds of reference audio, useful for specifying speaker."
)
)
with gr.Row():
reference_id = gr.Textbox(
label=i18n("Reference ID"),
placeholder="Leave empty to use uploaded references",
)
with gr.Row():
use_memory_cache = gr.Radio(
label=i18n("Use Memory Cache"),
choices=["on", "off"],
value="on",
)
with gr.Row():
reference_audio = gr.Audio(
label=i18n("Reference Audio"),
type="filepath",
)
with gr.Row():
reference_text = gr.Textbox(
label=i18n("Reference Text"),
lines=1,
placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
value="",
)
with gr.Column(scale=3):
with gr.Row():
error = gr.HTML(
label=i18n("Error Message"),
visible=True,
)
with gr.Row():
audio = gr.Audio(
label=i18n("Generated Audio"),
type="numpy",
interactive=False,
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
generate = gr.Button(
value="\U0001F3A7 " + i18n("Generate"),
variant="primary",
)
text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
# Submit
generate.click(
inference_fct,
[
refined_text,
normalize,
reference_id,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
use_memory_cache,
],
[audio, error],
concurrency_limit=1,
)
return app

91
tools/webui/inference.py Normal file
View File

@ -0,0 +1,91 @@
import html
from functools import partial
from typing import Any, Callable
from fish_speech.i18n import i18n
from tools.schema import ServeReferenceAudio, ServeTTSRequest
def inference_wrapper(
text,
normalize,
reference_id,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
use_memory_cache,
engine,
):
"""
Wrapper for the inference function.
Used in the Gradio interface.
"""
if reference_audio:
references = get_reference_audio(reference_audio, reference_text)
else:
references = []
req = ServeTTSRequest(
text=text,
normalize=normalize,
reference_id=reference_id if reference_id else None,
references=references,
max_new_tokens=max_new_tokens,
chunk_length=chunk_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
seed=int(seed) if seed else None,
use_memory_cache=use_memory_cache,
)
for result in engine.inference(req):
match result.code:
case "final":
return result.audio, None
case "error":
return None, build_html_error_message(i18n(result.error))
case _:
pass
return None, i18n("No audio generated")
def get_reference_audio(reference_audio: str, reference_text: str) -> list:
"""
Get the reference audio bytes.
"""
with open(reference_audio, "rb") as audio_file:
audio_bytes = audio_file.read()
return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]
def build_html_error_message(error: Any) -> str:
error = error if isinstance(error, Exception) else Exception(str(error) if error else "Unknown error")
return f"""
<div style="color: red;
font-weight: bold;">
{html.escape(str(error))}
</div>
"""
def get_inference_wrapper(engine) -> Callable:
"""
Get the inference function with the immutable arguments.
"""
return partial(
inference_wrapper,
engine=engine,
)

14
tools/webui/variables.py Normal file
View File

@ -0,0 +1,14 @@
from fish_speech.i18n import i18n
HEADER_MD = f"""# Fish Speech
{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")