Make WebUI and API code cleaner (+ 1.5 fixes) (#703)
* rename webui.py to run_webui.py
* remove unused imports
* remove unused code
* move inference code and fix all warnings
* move web app code
* make code easier to read
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove unused function
* remove msgpack_api.py
* rename API files
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* finish updating the doc with the new file names
* finish updating the doc with the new file names
* fix CPU use in the API
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refactor WebUI inference in a class with submodules
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* re-enable streaming in webui inference code
* generalize inference code in webui
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* minor fix
* make a unique inference engine class
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* minor fix
* cleaning code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* implement new structure of the API (not working)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refactor API
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* minor fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* reimplement chat endpoint
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 954cae1b5d
commit 62eae262c2

2  .github/ISSUE_TEMPLATE/bug_report.yml (vendored)
@@ -45,7 +45,7 @@ body:
       description: |
         Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
       placeholder: |
-        1. Run the command `python -m tools.post_api -t "xxxxx"`
+        1. Run the command `python -m tools.api_client -t "xxxxx"`
         2. Observe the console output error: `ModuleNotFoundError: No module named 'pyaudio'` (with screenshots or logs will be better)
     validations:
       required: true
@@ -185,7 +185,7 @@ pip install -e .[stable]
 4. Configure environment variables and access WebUI

     In the terminal inside the docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside docker.
-    Then in the terminal inside the docker container, enter `python tools/webui.py` to start the WebUI service.
+    Then in the terminal inside the docker container, enter `python tools/run_webui.py` to start the WebUI service.

     If you're using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 We provide a HTTP API for inference. You can use the following command to start the server:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \

 After that, you can view and test the API at http://127.0.0.1:8080/.

-Below is an example of sending a request using `tools/post_api.py`.
+Below is an example of sending a request using `tools/api_client.py`.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to be input" \
     --reference_audio "Path to reference audio" \
     --reference_text "Text content of the reference audio" \
@@ -93,7 +93,7 @@ The above command indicates synthesizing the desired audio according to the refe
 The following example demonstrates that you can use **multiple** reference audio paths and reference audio texts at once. Separate them with spaces in the command.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to input" \
     --reference_audio "reference audio path1" "reference audio path2" \
     --reference_text "reference audio text1" "reference audio text2"\
@@ -109,7 +109,7 @@ The currently supported reference audio has a maximum total duration of 90 secon


 !!! info
-    To learn more about available parameters, you can use the command `python -m tools.post_api -h`
+    To learn more about available parameters, you can use the command `python -m tools.api_client -h`

 ## GUI Inference
 [Download client](https://github.com/AnyaCoder/fish-speech-gui/releases)
@@ -44,7 +44,7 @@ pip install -e .[stable]
 To build fish-agent, please use the command below under the main folder:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 The `--compile` args only support Python < 3.12 , which will greatly speed up the token generation.
@@ -184,7 +184,7 @@ pip install -e .[stable]
 4. 環境変数の設定と WebUI へのアクセス

     Docker コンテナ内のターミナルで、`export GRADIO_SERVER_NAME="0.0.0.0"` と入力して、外部から Docker 内の gradio サービスにアクセスできるようにします。
-    次に、Docker コンテナ内のターミナルで `python tools/webui.py` と入力して WebUI サービスを起動します。
+    次に、Docker コンテナ内のターミナルで `python tools/run_webui.py` と入力して WebUI サービスを起動します。

     WSL または MacOS の場合は、[http://localhost:7860](http://localhost:7860) にアクセスして WebUI インターフェースを開くことができます。

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \

 その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。

-以下は、`tools/post_api.py` を使用してリクエストを送信する例です。
+以下は、`tools/api_client.py` を使用してリクエストを送信する例です。

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "入力するテキスト" \
     --reference_audio "参照音声へのパス" \
     --reference_text "参照音声テキスト" \
@@ -91,7 +91,7 @@ python -m tools.post_api \
 上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。

 !!! info
-    使用可能なパラメータの詳細については、コマンド` python -m tools.post_api -h `を使用してください
+    使用可能なパラメータの詳細については、コマンド` python -m tools.api_client -h `を使用してください

 ## WebUI 推論

@@ -47,7 +47,7 @@ pip install -e .[stable]
 fish-agentを構築するには、メインフォルダで以下のコマンドを使用してください:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 `--compile`引数はPython < 3.12でのみサポートされており、トークン生成を大幅に高速化します。
@@ -185,7 +185,7 @@ pip install -e .[stable]
 4. 환경 변수 설정 및 WebUI 접근

     Docker 컨테이너 내부의 터미널에서 `export GRADIO_SERVER_NAME="0.0.0.0"`를 입력하여 Docker 내부에서 Gradio 서비스에 외부 접근을 허용합니다.
-    이후, 터미널에서 `python tools/webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다.
+    이후, 터미널에서 `python tools/run_webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다.

     WSL 또는 macOS를 사용하는 경우 [http://localhost:7860](http://localhost:7860)에서 WebUI 인터페이스를 열 수 있습니다.

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 추론을 위한 HTTP API를 제공하고 있습니다. 아래의 명령어로 서버를 시작할 수 있습니다:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \

 이후, http://127.0.0.1:8080/ 에서 API를 확인하고 테스트할 수 있습니다.

-아래는 `tools/post_api.py`를 사용하여 요청을 보내는 예시입니다.
+아래는 `tools/api_client.py`를 사용하여 요청을 보내는 예시입니다.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "입력할 텍스트" \
     --reference_audio "참고 음성 경로" \
     --reference_text "참고 음성의 텍스트 내용" \
@@ -93,7 +93,7 @@ python -m tools.post_api \
 다음 예시는 여러 개의 참고 음성 경로와 텍스트를 한꺼번에 사용할 수 있음을 보여줍니다. 명령에서 공백으로 구분하여 입력합니다.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "입력할 텍스트" \
     --reference_audio "참고 음성 경로1" "참고 음성 경로2" \
     --reference_text "참고 음성 텍스트1" "참고 음성 텍스트2"\
@@ -107,7 +107,7 @@ python -m tools.post_api \
 `--reference_audio`와 `--reference_text` 대신에 `--reference_id`(하나만 사용 가능)를 사용할 수 있습니다. 프로젝트 루트 디렉토리에 `references/<your reference_id>` 폴더를 만들어 해당 음성과 주석 텍스트를 넣어야 합니다. 참고 음성은 최대 90초까지 지원됩니다.

 !!! info
-    제공되는 파라미터는 `python -m tools.post_api -h`를 사용하여 확인할 수 있습니다.
+    제공되는 파라미터는 `python -m tools.api_client -h`를 사용하여 확인할 수 있습니다.

 ## GUI 추론
 [클라이언트 다운로드](https://github.com/AnyaCoder/fish-speech-gui/releases)
@@ -47,7 +47,7 @@ pip install -e .[stable]
 fish-agent를 구축하려면 메인 폴더에서 아래 명령어를 사용하세요:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 `--compile` 인자는 Python < 3.12에서만 지원되며, 토큰 생성 속도를 크게 향상시킵니다.
@@ -181,7 +181,7 @@ pip install -e .[stable]
 4. Configure as variáveis de ambiente e acesse a WebUI

     No terminal do contêiner Docker, digite `export GRADIO_SERVER_NAME="0.0.0.0"` para permitir o acesso externo ao serviço gradio dentro do Docker.
-    Em seguida, no terminal do contêiner Docker, digite `python tools/webui.py` para iniciar o serviço WebUI.
+    Em seguida, no terminal do contêiner Docker, digite `python tools/run_webui.py` para iniciar o serviço WebUI.

     Se estiver usando WSL ou MacOS, acesse [http://localhost:7860](http://localhost:7860) para abrir a interface WebUI.

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para iniciar o servidor:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \

 Depois disso, é possível visualizar e testar a API em http://127.0.0.1:8080/.

-Abaixo está um exemplo de envio de uma solicitação usando `tools/post_api.py`.
+Abaixo está um exemplo de envio de uma solicitação usando `tools/api_client.py`.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Texto a ser inserido" \
     --reference_audio "Caminho para o áudio de referência" \
     --reference_text "Conteúdo de texto do áudio de referência" \
@@ -91,7 +91,7 @@ python -m tools.post_api \
 O comando acima indica a síntese do áudio desejada de acordo com as informações do áudio de referência e a retorna em modo de streaming.

 !!! info
-    Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.post_api -h`
+    Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.api_client -h`

 ## Inferência por WebUI

@@ -47,7 +47,7 @@ pip install -e .[stable]
 Para construir o fish-agent, use o comando abaixo na pasta principal:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 O argumento `--compile` só suporta Python < 3.12, o que aumentará muito a velocidade de geração de tokens.
@@ -188,7 +188,7 @@ pip install -e .[stable]
 4. 配置环境变量,访问 WebUI

     在 docker 容器内的终端,输入 `export GRADIO_SERVER_NAME="0.0.0.0"` ,从而让外部可以访问 docker 内的 gradio 服务。
-    接着在 docker 容器内的终端,输入 `python tools/webui.py` 即可开启 WebUI 服务。
+    接着在 docker 容器内的终端,输入 `python tools/run_webui.py` 即可开启 WebUI 服务。

     如果是 WSL 或者是 MacOS ,访问 [http://localhost:7860](http://localhost:7860) 即可打开 WebUI 界面。

@@ -73,7 +73,7 @@ python tools/vqgan/inference.py \
 运行以下命令来启动 HTTP 服务:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -88,10 +88,10 @@ HF_ENDPOINT=https://hf-mirror.com python -m ...(同上)

 随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API.

-下面是使用`tools/post_api.py`发送请求的示例。
+下面是使用`tools/api_client.py`发送请求的示例。

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "要输入的文本" \
     --reference_audio "参考音频路径" \
     --reference_text "参考音频的文本内容" \
@@ -102,7 +102,7 @@ python -m tools.post_api \

 下面的示例展示了, 可以一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "要输入的文本" \
     --reference_audio "参考音频路径1" "参考音频路径2" \
     --reference_text "参考音频的文本内容1" "参考音频的文本内容2"\
@@ -117,7 +117,7 @@ python -m tools.post_api \
 里面放上任意对音频与标注文本。 目前支持的参考音频最多加起来总时长90s。

 !!! info
-    要了解有关可用参数的更多信息,可以使用命令`python -m tools.post_api -h`
+    要了解有关可用参数的更多信息,可以使用命令`python -m tools.api_client -h`

 ## GUI 推理
 [下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases)
@@ -49,7 +49,7 @@ pip install -e .[stable]
 你需要使用以下指令来构建 fish-agent

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 `--compile`只能在小于 3.12 版本的 Python 使用,这个功能可以极大程度上提高生成速度。
@@ -7,4 +7,4 @@ if [ "${CUDA_ENABLED}" != "true" ]; then
     DEVICE="--device cpu"
 fi

-exec python tools/webui.py ${DEVICE}
+exec python tools/run_webui.py ${DEVICE}
@@ -176,7 +176,7 @@ def change_infer(
     p_infer = subprocess.Popen(
         [
             PYTHON,
-            "tools/webui.py",
+            "tools/run_webui.py",
             "--decoder-checkpoint-path",
             infer_decoder_model,
             "--decoder-config-name",
@@ -83,7 +83,7 @@
    },
    "outputs": [],
    "source": [
-    "!python tools/webui.py \\\n",
+    "!python tools/run_webui.py \\\n",
     " --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n",
     " --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
     " # --compile"
@@ -82,7 +82,7 @@ if not "!flags!"=="" set "flags=!flags:~1!"
 echo Debug: flags = !flags!

 if "!mode!"=="api" (
-    %PYTHON_CMD% -m tools.api !flags!
+    %PYTHON_CMD% -m tools.api_server !flags!
 ) else if "!mode!"=="infer" (
     %PYTHON_CMD% -m tools.webui !flags!
 )
951  tools/api.py (deleted)
@@ -1,951 +0,0 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import wave
|
||||
from argparse import ArgumentParser
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import ormsgpack
|
||||
import pyrootutils
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torchaudio
|
||||
from baize.datastructures import ContentType
|
||||
from kui.asgi import (
|
||||
Body,
|
||||
FactoryClass,
|
||||
HTTPException,
|
||||
HttpRequest,
|
||||
HttpView,
|
||||
JSONResponse,
|
||||
Kui,
|
||||
OpenAPI,
|
||||
StreamResponse,
|
||||
request,
|
||||
)
|
||||
from kui.asgi.routing import MultimethodRoutes
|
||||
from loguru import logger
|
||||
|
||||
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
import struct
|
||||
from threading import Lock
|
||||
|
||||
import httpx
|
||||
from cachetools import LRUCache, cached
|
||||
from funasr import AutoModel
|
||||
from silero_vad import get_speech_timestamps, load_silero_vad
|
||||
|
||||
from fish_speech.models.text2semantic.llama import BaseModelArgs
|
||||
|
||||
# from fish_speech.models.vqgan.lit_module import VQGAN
|
||||
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
||||
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
||||
|
||||
# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
|
||||
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
|
||||
from fish_speech.utils import autocast_exclude_mps, set_seed
|
||||
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
||||
from tools.llama.generate import (
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
WrappedGenerateResponse,
|
||||
launch_thread_safe_queue,
|
||||
launch_thread_safe_queue_agent,
|
||||
)
|
||||
from tools.schema import (
|
||||
GLOBAL_NUM_SAMPLES,
|
||||
ASRPackRequest,
|
||||
ServeASRRequest,
|
||||
ServeASRResponse,
|
||||
ServeASRSegment,
|
||||
ServeAudioPart,
|
||||
ServeForwardMessage,
|
||||
ServeMessage,
|
||||
ServeRequest,
|
||||
ServeResponse,
|
||||
ServeStreamDelta,
|
||||
ServeStreamResponse,
|
||||
ServeTextPart,
|
||||
ServeTimedASRResponse,
|
||||
ServeTTSRequest,
|
||||
ServeVQGANDecodeRequest,
|
||||
ServeVQGANDecodeResponse,
|
||||
ServeVQGANEncodeRequest,
|
||||
ServeVQGANEncodeResponse,
|
||||
ServeVQPart,
|
||||
)
|
||||
from tools.vqgan.inference import load_model as load_decoder_model
|
||||
|
||||
global_lock = Lock()
|
||||
|
||||
# Whether to disable keepalive (which is helpful if the server is in the same cluster)
|
||||
DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
|
||||
async_client = httpx.AsyncClient(
|
||||
timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
|
||||
)
|
||||
backends = torchaudio.list_audio_backends()
|
||||
|
||||
if "ffmpeg" in backends:
|
||||
backend = "ffmpeg"
|
||||
else:
|
||||
backend = "soundfile"
|
||||
|
||||
|
||||
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
||||
buffer = io.BytesIO()
|
||||
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(channels)
|
||||
wav_file.setsampwidth(bit_depth // 8)
|
||||
wav_file.setframerate(sample_rate)
|
||||
|
||||
wav_header_bytes = buffer.getvalue()
|
||||
buffer.close()
|
||||
return wav_header_bytes
|
||||
|
||||
|
||||
# Define utils for web server
|
||||
async def http_execption_handler(exc: HTTPException):
|
||||
return JSONResponse(
|
||||
dict(
|
||||
statusCode=exc.status_code,
|
||||
message=exc.content,
|
||||
error=HTTPStatus(exc.status_code).phrase,
|
||||
),
|
||||
exc.status_code,
|
||||
exc.headers,
|
||||
)
|
||||
|
||||
|
||||
async def other_exception_handler(exc: "Exception"):
|
||||
traceback.print_exc()
|
||||
|
||||
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
return JSONResponse(
|
||||
dict(statusCode=status, message=str(exc), error=status.phrase),
|
||||
status,
|
||||
)
|
||||
|
||||
|
||||
def load_audio(reference_audio, sr):
|
||||
if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
||||
audio_data = reference_audio
|
||||
reference_audio = io.BytesIO(audio_data)
|
||||
|
||||
waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
|
||||
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
|
||||
if original_sr != sr:
|
||||
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
|
||||
waveform = resampler(waveform)
|
||||
|
||||
audio = waveform.squeeze().numpy()
|
||||
return audio
|
||||
|
||||
|
||||
def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
|
||||
if enable_reference_audio and reference_audio is not None:
|
||||
# Load audios, and prepare basic info here
|
||||
reference_audio_content = load_audio(
|
||||
reference_audio, decoder_model.spec_transform.sample_rate
|
||||
)
|
||||
|
||||
audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
|
||||
None, None, :
|
||||
]
|
||||
audio_lengths = torch.tensor(
|
||||
[audios.shape[2]], device=decoder_model.device, dtype=torch.long
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
|
||||
)
|
||||
|
||||
# VQ Encoder
|
||||
if isinstance(decoder_model, FireflyArchitecture):
|
||||
prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
|
||||
|
||||
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
|
||||
else:
|
||||
prompt_tokens = None
|
||||
logger.info("No reference audio provided")
|
||||
|
||||
return prompt_tokens
|
||||
|
||||
|
||||
def decode_vq_tokens(
|
||||
*,
|
||||
decoder_model,
|
||||
codes,
|
||||
):
|
||||
feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
|
||||
logger.info(f"VQ features: {codes.shape}")
|
||||
|
||||
if isinstance(decoder_model, FireflyArchitecture):
|
||||
# VQGAN Inference
|
||||
return decoder_model.decode(
|
||||
indices=codes[None],
|
||||
feature_lengths=feature_lengths,
|
||||
)[0].squeeze()
|
||||
|
||||
raise ValueError(f"Unknown model type: {type(decoder_model)}")
|
||||
|
||||
|
||||
routes = MultimethodRoutes(base_class=HttpView)
|
||||
|
||||
|
||||
def get_content_type(audio_format):
|
||||
if audio_format == "wav":
|
||||
return "audio/wav"
|
||||
elif audio_format == "flac":
|
||||
return "audio/flac"
|
||||
elif audio_format == "mp3":
|
||||
return "audio/mpeg"
|
||||
else:
|
||||
return "application/octet-stream"
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast(device_type="cuda", dtype=torch.half)
|
||||
def batch_encode(model, audios: list[bytes | torch.Tensor]):
|
||||
audios = [
|
||||
(
|
||||
torch.from_numpy(
|
||||
librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
|
||||
)[None]
|
||||
if isinstance(audio, bytes)
|
||||
else audio
|
||||
)
|
||||
for audio in audios
|
||||
]
|
||||
|
||||
# if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
|
||||
# raise ValueError("Single audio length is too long (>120s)")
|
||||
|
||||
max_length = max(audio.shape[-1] for audio in audios)
|
||||
print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
|
||||
|
||||
lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
|
||||
max_length = lengths.max().item()
|
||||
padded = torch.stack(
|
||||
[
|
||||
torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
|
||||
for audio in audios
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
features, feature_lengths = model.encode(padded, audio_lengths=lengths)
|
||||
features, feature_lengths = features.cpu(), feature_lengths.cpu()
|
||||
|
||||
return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
|
||||
|
||||
|
||||
@cached(
|
||||
cache=LRUCache(maxsize=10000),
|
||||
key=lambda model, audios: (model.device, tuple(audios)),
|
||||
)
|
||||
def cached_vqgan_batch_encode(model, audios: list[bytes]):
|
||||
return batch_encode(model, audios)
|
||||
|
||||
|
||||
@routes.http.post("/v1/vqgan/encode")
|
||||
def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
|
||||
|
||||
start_time = time.time()
|
||||
tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
|
||||
logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
|
||||
|
||||
return ormsgpack.packb(
|
||||
ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
|
||||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast(device_type="cuda", dtype=torch.half)
|
||||
def vqgan_decode(model, features):
|
||||
lengths = torch.tensor(
|
||||
[feature.shape[-1] for feature in features], device=model.device
|
||||
)
|
||||
max_length = lengths.max().item()
|
||||
padded = torch.stack(
|
||||
[
|
||||
torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
|
||||
for feature in features
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
# If bs too large, we do micro batch decode
|
||||
audios, audio_lengths = [], []
|
||||
for i in range(0, padded.shape[0], 8):
|
||||
audio, audio_length = model.decode(
|
||||
padded[i : i + 8], feature_lengths=lengths[i : i + 8]
|
||||
)
|
||||
audios.append(audio)
|
||||
audio_lengths.append(audio_length)
|
||||
audios = torch.cat(audios, dim=0)
|
||||
audio_lengths = torch.cat(audio_lengths, dim=0)
|
||||
audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
|
||||
|
||||
return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
|
||||
|
||||
|
||||
@routes.http.post("/v1/vqgan/decode")
|
||||
def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
|
||||
tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
|
||||
start_time = time.time()
|
||||
audios = vqgan_decode(decoder_model, tokens)
|
||||
logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
|
||||
audios = [audio.astype(np.float16).tobytes() for audio in audios]
|
||||
return ormsgpack.packb(
|
||||
ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def batch_asr(model, audios, sr, language="auto"):
|
||||
resampled_audios = []
|
||||
for audio in audios:
|
||||
audio = torchaudio.functional.resample(audio, sr, 16000)
|
||||
assert audio.ndim == 1
|
||||
resampled_audios.append(audio)
|
||||
|
||||
with global_lock:
|
||||
res = model.generate(
|
||||
input=resampled_audios,
|
||||
batch_size=len(resampled_audios),
|
||||
language=language,
|
||||
use_itn=True,
|
||||
)
|
||||
|
||||
results = []
|
||||
for r, audio in zip(res, audios):
|
||||
text = r["text"]
|
||||
text = re.sub(r"<\|.*?\|>", "", text)
|
||||
duration = len(audio) / sr * 1000
|
||||
huge_gap = False
|
||||
|
||||
if "timestamp" in r and len(r["timestamp"]) > 2:
|
||||
for timestamp_a, timestamp_b in zip(
|
||||
r["timestamp"][:-1], r["timestamp"][1:]
|
||||
):
|
||||
# If there is a gap of more than 5 seconds, we consider it as a huge gap
|
||||
if timestamp_b[0] - timestamp_a[1] > 5000:
|
||||
huge_gap = True
|
||||
break
|
||||
|
||||
# Doesn't make sense to have a huge gap at the end
|
||||
if duration - r["timestamp"][-1][1] > 3000:
|
||||
huge_gap = True
|
||||
|
||||
results.append(
|
||||
{
|
||||
"text": text,
|
||||
"duration": duration,
|
||||
"huge_gap": huge_gap,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@routes.http.post("/v1/asr")
|
||||
def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
|
||||
start_time = time.time()
|
||||
audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
|
||||
audios = [torch.from_numpy(audio).float() for audio in audios]
|
||||
|
||||
if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
|
||||
raise HTTPException(status_code=400, detail="Audio length is too long")
|
||||
|
||||
transcriptions = batch_asr(
|
||||
asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
|
||||
)
|
||||
logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
|
||||
|
||||
return ormsgpack.packb(
|
||||
ServeASRResponse(transcriptions=transcriptions),
|
||||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
||||
)
|
||||
|
||||
|
||||
from fish_speech.conversation import Conversation, Message
|
||||
|
||||
|
||||
def execute_request(
|
||||
input_queue: queue.Queue,
|
||||
tokenizer: FishTokenizer,
|
||||
config: BaseModelArgs,
|
||||
request: ServeRequest,
|
||||
device: str = "cuda:0",
|
||||
):
|
||||
|
||||
im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
|
||||
messages = []
|
||||
for message in request.messages:
|
||||
messages.append(message.to_conversation_message())
|
||||
|
||||
assert len(messages) >= 1, "At least one message is required"
|
||||
# assert messages[-1].role == "user", "The last message must be from the user"
|
||||
|
||||
if messages[-1].role == "user":
|
||||
messages.append(
|
||||
Message(role="assistant", parts=[], add_im_end=False, modality="voice")
|
||||
)
|
||||
elif messages[-1].role == "raw":
|
||||
messages[-1].add_im_start = False
|
||||
messages[-1].add_im_end = False
|
||||
messages[-1].modality = "voice"
|
||||
else:
|
||||
assert (
|
||||
messages[-1].role == "assistant"
|
||||
), "The last message must be from the assistant"
|
||||
messages[-1].add_im_end = False
|
||||
|
||||
conv = Conversation(messages=messages)
|
||||
|
||||
# conv.visualize(tokenizer)
|
||||
prompt = conv.encode_for_inference(
|
||||
tokenizer=tokenizer, num_codebooks=config.num_codebooks
|
||||
).to(device)
|
||||
|
||||
if request.streaming:
|
||||
for i in range(request.num_samples):
|
||||
yield ServeStreamResponse(
|
||||
sample_id=i,
|
||||
delta=ServeStreamDelta(
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
|
||||
req = {
|
||||
"prompt": prompt,
|
||||
"max_new_tokens": request.max_new_tokens,
|
||||
"im_end_id": im_end_id,
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"num_samples": request.num_samples,
|
||||
"early_stop_threshold": request.early_stop_threshold,
|
||||
}
|
||||
|
||||
start = time.time()
|
||||
response_queue = queue.Queue()
|
||||
input_queue.put(GenerateRequest(req, response_queue))
|
||||
|
||||
# Decoding
|
||||
decode_buffer = [[] for _ in range(request.num_samples)]
|
||||
parts = [[] for _ in range(request.num_samples)]
|
||||
|
||||
def send_reset_buffer(sample_id):
|
||||
nonlocal decode_buffer
|
||||
if len(decode_buffer[sample_id]) == 0:
|
||||
return
|
||||
|
||||
decoded = tokenizer.decode(decode_buffer[sample_id])
|
||||
part = ServeTextPart(text=decoded)
|
||||
|
||||
if request.streaming:
|
||||
yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
|
||||
else:
|
||||
parts[sample_id].append(part)
|
||||
|
||||
decode_buffer[sample_id] = []
|
||||
|
||||
# Decode process
|
||||
finished = [False for _ in range(request.num_samples)]
|
||||
stats = {}
|
||||
idx = 0
|
||||
while True:
|
||||
response = response_queue.get()
|
||||
|
||||
if response in ["stop", "error"]:
|
||||
break
|
||||
|
||||
for sample_id, tokens in enumerate(response):
|
||||
if finished[sample_id]:
|
||||
continue
|
||||
|
||||
if tokens[0] == im_end_id:
|
||||
finished[sample_id] = True
|
||||
if request.streaming:
|
||||
yield from send_reset_buffer(sample_id)
|
||||
yield ServeStreamResponse(
|
||||
sample_id=sample_id,
|
||||
finish_reason="stop",
|
||||
stats=stats,
|
||||
)
|
||||
continue
|
||||
|
||||
is_semantic = (
|
||||
tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
|
||||
)
|
||||
if is_semantic and request.streaming:
|
||||
yield from send_reset_buffer(sample_id)
|
||||
# Streaming vq
|
||||
_tokens = tokens[1:].clone()
|
||||
|
||||
if config.share_codebook_embeddings is False:
|
||||
for i in range(len(_tokens)):
|
||||
_tokens[i] -= config.codebook_size * i
|
||||
|
||||
yield ServeStreamResponse(
|
||||
sample_id=sample_id,
|
||||
delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
|
||||
)
|
||||
continue
|
||||
|
||||
# Not streaming vq
|
||||
if is_semantic:
|
||||
yield from send_reset_buffer(sample_id)
|
||||
# None streaming vq
|
||||
if len(parts[sample_id]) == 0 or not isinstance(
|
||||
parts[sample_id][-1], ServeVQPart
|
||||
):
|
||||
_tokens = tokens[1:].clone()
|
||||
|
||||
if config.share_codebook_embeddings is False:
|
||||
for i in range(len(_tokens)):
|
||||
_tokens[i] -= config.codebook_size * i
|
||||
|
||||
parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
|
||||
else:
|
||||
for codebook_id, value in enumerate(tokens[1:, :]):
|
||||
val = value.item()
|
||||
if config.share_codebook_embeddings is False:
|
||||
val -= config.codebook_size * codebook_id
|
||||
|
||||
parts[sample_id][-1].codes[codebook_id].append(val)
|
||||
continue
|
||||
|
||||
if not is_semantic:
|
||||
# Stream text decode is not supported now
|
||||
decode_buffer[sample_id].append(tokens[0, 0])
|
||||
|
||||
if idx == 0:
|
||||
stats["time_to_first_token"] = (time.time() - start) * 1000
|
||||
|
||||
idx += 1
|
||||
|
||||
for sample_id in range(request.num_samples):
|
||||
yield from send_reset_buffer(sample_id)
|
||||
|
||||
stats["total_time"] = (time.time() - start) * 1000
|
||||
stats["total_tokens"] = idx
|
||||
|
||||
if request.streaming:
|
||||
for sample_id in range(request.num_samples):
|
||||
if finished[sample_id]:
|
||||
continue
|
||||
yield ServeStreamResponse(
|
||||
finish_reason=response, stats=stats, sample_id=sample_id
|
||||
)
|
||||
return
|
||||
|
||||
yield ServeResponse(
|
||||
messages=[
|
||||
ServeMessage(role="assistant", parts=parts[i])
|
||||
for i in range(request.num_samples)
|
||||
],
|
||||
finish_reason=response,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
|
||||
@routes.http.post("/v1/chat")
|
||||
def api_invoke_chat(
|
||||
req: Annotated[ServeRequest, Body(exclusive=True)],
|
||||
):
|
||||
"""
|
||||
Invoke model and generate audio
|
||||
"""
|
||||
|
||||
# This makes torch compile happy
|
||||
assert (
|
||||
req.num_samples == GLOBAL_NUM_SAMPLES
|
||||
), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
|
||||
|
||||
content_type = request.headers.get("Content-Type", "application/json")
|
||||
json_mode = "application/json" in content_type
|
||||
|
||||
async def wrapped_generator():
|
||||
generator = execute_request(llama_queue, tokenizer, config, req, args.device)
|
||||
|
||||
for i in generator:
|
||||
if json_mode:
|
||||
body = i.model_dump_json().encode("utf-8")
|
||||
yield b"data: " + body + b"\n\n"
|
||||
else:
|
||||
body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
||||
yield struct.pack("I", len(body)) + body
|
||||
|
||||
# Naive mode
|
||||
if req.streaming is False:
|
||||
result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
|
||||
|
||||
if json_mode:
|
||||
return JSONResponse(result.model_dump())
|
||||
else:
|
||||
return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
||||
|
||||
return StreamResponse(
|
||||
iterable=wrapped_generator(), content_type="text/event-stream"
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(req: ServeTTSRequest):
|
||||
|
||||
idstr: str | None = req.reference_id
|
||||
if idstr is not None:
|
||||
ref_folder = Path("references") / idstr
|
||||
ref_folder.mkdir(parents=True, exist_ok=True)
|
||||
ref_audios = list_files(
|
||||
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
|
||||
)
|
||||
|
||||
if req.use_memory_cache == "never" or (
|
||||
req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
|
||||
):
|
||||
prompt_tokens = [
|
||||
encode_reference(
|
||||
decoder_model=decoder_model,
|
||||
reference_audio=audio_to_bytes(str(ref_audio)),
|
||||
enable_reference_audio=True,
|
||||
)
|
||||
for ref_audio in ref_audios
|
||||
]
|
||||
prompt_texts = [
|
||||
read_ref_text(str(ref_audio.with_suffix(".lab")))
|
||||
for ref_audio in ref_audios
|
||||
]
|
||||
else:
|
||||
logger.info("Use same references")
|
||||
|
||||
else:
|
||||
# Parse reference audio aka prompt
|
||||
refs = req.references
|
||||
|
||||
if req.use_memory_cache == "never" or (
|
||||
req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
|
||||
):
|
||||
prompt_tokens = [
|
||||
encode_reference(
|
||||
decoder_model=decoder_model,
|
||||
reference_audio=ref.audio,
|
||||
enable_reference_audio=True,
|
||||
)
|
||||
for ref in refs
|
||||
]
|
||||
prompt_texts = [ref.text for ref in refs]
|
||||
else:
|
||||
logger.info("Use same references")
|
||||
|
||||
if req.seed is not None:
|
||||
set_seed(req.seed)
|
||||
logger.warning(f"set seed: {req.seed}")
|
||||
|
||||
# LLAMA Inference
|
||||
request = dict(
|
||||
device=decoder_model.device,
|
||||
max_new_tokens=req.max_new_tokens,
|
||||
text=(
|
||||
req.text
|
||||
if not req.normalize
|
||||
else ChnNormedText(raw_text=req.text).normalize()
|
||||
),
|
||||
top_p=req.top_p,
|
||||
repetition_penalty=req.repetition_penalty,
|
||||
temperature=req.temperature,
|
||||
compile=args.compile,
|
||||
iterative_prompt=req.chunk_length > 0,
|
||||
chunk_length=req.chunk_length,
|
||||
max_length=4096,
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_text=prompt_texts,
|
||||
)
|
||||
|
||||
response_queue = queue.Queue()
|
||||
llama_queue.put(
|
||||
GenerateRequest(
|
||||
request=request,
|
||||
response_queue=response_queue,
|
||||
)
|
||||
)
|
||||
|
||||
if req.streaming:
|
||||
yield wav_chunk_header()
|
||||
|
||||
segments = []
|
||||
while True:
|
||||
result: WrappedGenerateResponse = response_queue.get()
|
||||
if result.status == "error":
|
||||
raise result.response
|
||||
break
|
||||
|
||||
result: GenerateResponse = result.response
|
||||
if result.action == "next":
|
||||
break
|
||||
|
||||
with autocast_exclude_mps(
|
||||
device_type=decoder_model.device.type, dtype=args.precision
|
||||
):
|
||||
fake_audios = decode_vq_tokens(
|
||||
decoder_model=decoder_model,
|
||||
codes=result.codes,
|
||||
)
|
||||
|
||||
fake_audios = fake_audios.float().cpu().numpy()
|
||||
|
||||
if req.streaming:
|
||||
yield (fake_audios * 32768).astype(np.int16).tobytes()
|
||||
else:
|
||||
segments.append(fake_audios)
|
||||
|
||||
if req.streaming:
|
||||
return
|
||||
|
||||
if len(segments) == 0:
|
||||
raise HTTPException(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
content="No audio generated, please check the input text.",
|
||||
)
|
||||
|
||||
fake_audios = np.concatenate(segments, axis=0)
|
||||
yield fake_audios
|
||||
|
||||
|
||||
async def inference_async(req: ServeTTSRequest):
|
||||
for chunk in inference(req):
|
||||
yield chunk
|
||||
|
||||
|
||||
async def buffer_to_async_generator(buffer):
|
||||
yield buffer
|
||||
|
||||
|
||||
@routes.http.post("/v1/tts")
|
||||
async def api_invoke_model(
|
||||
req: Annotated[ServeTTSRequest, Body(exclusive=True)],
|
||||
):
|
||||
"""
|
||||
Invoke model and generate audio
|
||||
"""
|
||||
|
||||
if args.max_text_length > 0 and len(req.text) > args.max_text_length:
|
||||
raise HTTPException(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
content=f"Text is too long, max length is {args.max_text_length}",
|
||||
)
|
||||
|
||||
if req.streaming and req.format != "wav":
|
||||
raise HTTPException(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
content="Streaming only supports WAV format",
|
||||
)
|
||||
|
||||
if req.streaming:
|
||||
return StreamResponse(
|
||||
iterable=inference_async(req),
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
||||
},
|
||||
content_type=get_content_type(req.format),
|
||||
)
|
||||
else:
|
||||
fake_audios = next(inference(req))
|
||||
buffer = io.BytesIO()
|
||||
sf.write(
|
||||
buffer,
|
||||
fake_audios,
|
||||
decoder_model.spec_transform.sample_rate,
|
||||
format=req.format,
|
||||
)
|
||||
|
||||
return StreamResponse(
|
||||
iterable=buffer_to_async_generator(buffer.getvalue()),
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
||||
},
|
||||
content_type=get_content_type(req.format),
|
||||
)
|
||||
|
||||
|
||||
@routes.http.post("/v1/health")
|
||||
async def api_health():
|
||||
"""
|
||||
Health check
|
||||
"""
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
|
||||
parser.add_argument("--load-asr-model", action="store_true")
|
||||
parser.add_argument(
|
||||
"--llama-checkpoint-path",
|
||||
type=str,
|
||||
default="checkpoints/fish-speech-1.4",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder-checkpoint-path",
|
||||
type=str,
|
||||
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
||||
)
|
||||
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
parser.add_argument("--half", action="store_true")
|
||||
parser.add_argument("--compile", action="store_true")
|
||||
parser.add_argument("--max-text-length", type=int, default=0)
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
|
||||
parser.add_argument("--workers", type=int, default=1)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# Define Kui app
|
||||
openapi = OpenAPI(
|
||||
{
|
||||
"title": "Fish Speech API",
|
||||
"version": "1.4.2",
|
||||
},
|
||||
).routes
|
||||
|
||||
|
||||
class MsgPackRequest(HttpRequest):
|
||||
async def data(
|
||||
self,
|
||||
) -> Annotated[
|
||||
Any, ContentType("application/msgpack"), ContentType("application/json")
|
||||
]:
|
||||
if self.content_type == "application/msgpack":
|
||||
return ormsgpack.unpackb(await self.body)
|
||||
|
||||
elif self.content_type == "application/json":
|
||||
return await self.json
|
||||
|
||||
raise HTTPException(
|
||||
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
||||
headers={"Accept": "application/msgpack, application/json"},
|
||||
)
|
||||
|
||||
|
||||
app = Kui(
|
||||
routes=routes + openapi[1:], # Remove the default route
|
||||
exception_handlers={
|
||||
HTTPException: http_execption_handler,
|
||||
Exception: other_exception_handler,
|
||||
},
|
||||
factory_class=FactoryClass(http=MsgPackRequest),
|
||||
cors_config={},
|
||||
)
|
||||
|
||||
|
||||
def load_asr_model(*, device="cuda", hub="ms"):
|
||||
return AutoModel(
|
||||
model="iic/SenseVoiceSmall",
|
||||
device=device,
|
||||
disable_pbar=True,
|
||||
hub=hub,
|
||||
)
|
||||
|
||||
|
||||
# Each worker process created by Uvicorn has its own memory space,
|
||||
# meaning that models and variables are not shared between processes.
|
||||
# Therefore, any global variables (like `llama_queue` or `decoder_model`)
|
||||
# will not be shared across workers.
|
||||
|
||||
|
||||
# Multi-threading for deep learning can cause issues, such as inconsistent
|
||||
# outputs if multiple threads access the same buffers simultaneously.
|
||||
# Instead, it's better to use multiprocessing or independent models per thread.
|
||||
@app.on_startup
|
||||
def initialize_app(app: Kui):
|
||||
|
||||
global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
|
||||
|
||||
prompt_tokens, prompt_texts = [], []
|
||||
|
||||
args = parse_args() # args same as ones in other processes
|
||||
args.precision = torch.half if args.half else torch.bfloat16
|
||||
|
||||
if args.load_asr_model:
|
||||
logger.info(f"Loading ASR model...")
|
||||
asr_model = load_asr_model(device=args.device)
|
||||
|
||||
logger.info("Loading Llama model...")
|
||||
|
||||
if args.mode == "tts":
|
||||
llama_queue = launch_thread_safe_queue(
|
||||
checkpoint_path=args.llama_checkpoint_path,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
compile=args.compile,
|
||||
)
|
||||
else:
|
||||
llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
|
||||
checkpoint_path=args.llama_checkpoint_path,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
compile=args.compile,
|
||||
)
|
||||
|
||||
logger.info("Llama model loaded, loading VQ-GAN model...")
|
||||
|
||||
decoder_model = load_decoder_model(
|
||||
config_name=args.decoder_config_name,
|
||||
checkpoint_path=args.decoder_checkpoint_path,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
logger.info("VQ-GAN model loaded, warming up...")
|
||||
|
||||
vad_model = load_silero_vad()
|
||||
|
||||
logger.info("VAD model loaded, warming up...")
|
||||
|
||||
if args.mode == "tts":
|
||||
# Dry run to ensure models work and avoid first-time latency
|
||||
list(
|
||||
inference(
|
||||
ServeTTSRequest(
|
||||
text="Hello world.",
|
||||
references=[],
|
||||
reference_id=None,
|
||||
max_new_tokens=0,
|
||||
chunk_length=200,
|
||||
top_p=0.7,
|
||||
repetition_penalty=1.5,
|
||||
temperature=0.7,
|
||||
emotion=None,
|
||||
format="wav",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Warming up done, starting server at http://{args.listen}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
import uvicorn
|
||||
|
||||
args = parse_args()
|
||||
host, port = args.listen.split(":")
|
||||
uvicorn.run(
|
||||
"tools.api:app",
|
||||
host=host,
|
||||
port=int(port),
|
||||
workers=args.workers,
|
||||
log_level="info",
|
||||
)
|
@@ -69,10 +69,6 @@ def parse_args():
     parser.add_argument(
         "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
     )
-    parser.add_argument(
-        "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
-    )
-    parser.add_argument("--opus_bitrate", type=int, default=-1000)
     parser.add_argument(
         "--latency",
         type=str,
@@ -112,11 +108,9 @@ def parse_args():
     parser.add_argument(
         "--use_memory_cache",
         type=str,
-        default="never",
-        choices=["on-demand", "never"],
-        help="Cache encoded references codes in memory.\n"
-        "If `on-demand`, the server will use cached encodings\n "
-        "instead of encoding reference audio again.",
+        default="off",
+        choices=["on", "off"],
+        help="Cache encoded references codes in memory.\n",
     )
     parser.add_argument(
         "--seed",
@@ -154,14 +148,14 @@ if __name__ == "__main__":
     data = {
         "text": args.text,
         "references": [
-            ServeReferenceAudio(audio=ref_audio, text=ref_text)
+            ServeReferenceAudio(
+                audio=ref_audio if ref_audio is not None else b"", text=ref_text
+            )
             for ref_text, ref_audio in zip(ref_texts, byte_audios)
         ],
         "reference_id": idstr,
         "normalize": args.normalize,
         "format": args.format,
-        "mp3_bitrate": args.mp3_bitrate,
-        "opus_bitrate": args.opus_bitrate,
         "max_new_tokens": args.max_new_tokens,
         "chunk_length": args.chunk_length,
         "top_p": args.top_p,
98  tools/api_server.py (new file)
@@ -0,0 +1,98 @@
|
||||
from threading import Lock
|
||||
|
||||
import pyrootutils
|
||||
import uvicorn
|
||||
from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
|
||||
from loguru import logger
|
||||
|
||||
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from tools.server.api_utils import MsgPackRequest, parse_args
|
||||
from tools.server.exception_handler import ExceptionHandler
|
||||
from tools.server.model_manager import ModelManager
|
||||
from tools.server.views import (
|
||||
ASRView,
|
||||
ChatView,
|
||||
HealthView,
|
||||
TTSView,
|
||||
VQGANDecodeView,
|
||||
VQGANEncodeView,
|
||||
)
|
||||
|
||||
|
||||
class API(ExceptionHandler):
|
||||
def __init__(self):
|
||||
self.args = parse_args()
|
||||
self.routes = [
|
||||
("/v1/health", HealthView),
|
||||
("/v1/vqgan/encode", VQGANEncodeView),
|
||||
("/v1/vqgan/decode", VQGANDecodeView),
|
||||
("/v1/asr", ASRView),
|
||||
("/v1/tts", TTSView),
|
||||
("/v1/chat", ChatView),
|
||||
]
|
||||
self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
|
||||
|
||||
self.openapi = OpenAPI(
|
||||
{
|
||||
"title": "Fish Speech API",
|
||||
"version": "1.5.0",
|
||||
},
|
||||
).routes
|
||||
|
||||
# Initialize the app
|
||||
self.app = Kui(
|
||||
routes=self.routes + self.openapi[1:], # Remove the default route
|
||||
exception_handlers={
|
||||
HTTPException: self.http_exception_handler,
|
||||
Exception: self.other_exception_handler,
|
||||
},
|
||||
factory_class=FactoryClass(http=MsgPackRequest),
|
||||
cors_config={},
|
||||
)
|
||||
|
||||
# Add the state variables
|
||||
self.app.state.lock = Lock()
|
||||
self.app.state.device = self.args.device
|
||||
self.app.state.max_text_length = self.args.max_text_length
|
||||
|
||||
# Associate the app with the model manager
|
||||
self.app.on_startup(self.initialize_app)
|
||||
|
||||
async def initialize_app(self, app: Kui):
|
||||
# Make the ModelManager available to the views
|
||||
app.state.model_manager = ModelManager(
|
||||
mode=self.args.mode,
|
||||
device=self.args.device,
|
||||
half=self.args.half,
|
||||
compile=self.args.compile,
|
||||
asr_enabled=self.args.load_asr_model,
|
||||
llama_checkpoint_path=self.args.llama_checkpoint_path,
|
||||
decoder_checkpoint_path=self.args.decoder_checkpoint_path,
|
||||
decoder_config_name=self.args.decoder_config_name,
|
||||
)
|
||||
|
||||
logger.info(f"Startup done, listening server at http://{self.args.listen}")
|
||||
|
||||
|
||||
# Each worker process created by Uvicorn has its own memory space,
|
||||
# meaning that models and variables are not shared between processes.
|
||||
# Therefore, any variables (like `llama_queue` or `decoder_model`)
|
||||
# will not be shared across workers.
|
||||
|
||||
# Multi-threading for deep learning can cause issues, such as inconsistent
|
||||
# outputs if multiple threads access the same buffers simultaneously.
|
||||
# Instead, it's better to use multiprocessing or independent models per thread.
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
api = API()
|
||||
host, port = api.args.listen.split(":")
|
||||
|
||||
uvicorn.run(
|
||||
api.app,
|
||||
host=host,
|
||||
port=int(port),
|
||||
workers=api.args.workers,
|
||||
log_level="info",
|
||||
)
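As a usage note (not part of this commit's diff): below is a minimal, hypothetical smoke test against the refactored server, assuming it was started as shown in the updated docs (`python -m tools.api_server --listen 0.0.0.0:8080 ...`) and that the new `MsgPackRequest` from `tools/server/api_utils.py` still accepts `application/json` bodies the way the deleted `tools/api.py` version did. Field names mirror the `ServeTTSRequest` warm-up call visible above; omitted fields are assumed to keep their server-side defaults.

```python
# Hedged sketch: exercises the /v1/health and /v1/tts routes registered in
# tools/api_server.py. httpx is already a dependency of the old tools/api.py.
import httpx

BASE = "http://127.0.0.1:8080"

# Health check (the deleted tools/api.py exposed this route as POST).
print(httpx.post(f"{BASE}/v1/health").json())

# Minimal TTS request sent as JSON; values mirror the warm-up request
# in the deleted tools/api.py and are assumptions, not an official spec.
payload = {
    "text": "Hello world.",
    "references": [],
    "reference_id": None,
    "chunk_length": 200,
    "top_p": 0.7,
    "repetition_penalty": 1.5,
    "temperature": 0.7,
    "format": "wav",
}
resp = httpx.post(f"{BASE}/v1/tts", json=payload, timeout=120)
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)  # raw WAV bytes returned by the server
```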
|
@@ -14,8 +14,8 @@ import ormsgpack
 import soundfile as sf

 from .schema import (
+    ServeChatRequest,
     ServeMessage,
-    ServeRequest,
     ServeTextPart,
     ServeVQGANDecodeRequest,
     ServeVQGANEncodeRequest,
@@ -163,7 +163,7 @@ class FishE2EAgent:
         else:
             user_codes = None

-        request = ServeRequest(
+        request = ServeChatRequest(
             messages=prev_messages
             + (
                 [
193  tools/inference_engine/__init__.py (new file)
@@ -0,0 +1,193 @@
|
||||
import gc
|
||||
import queue
|
||||
from typing import Generator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
||||
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
||||
from fish_speech.utils import autocast_exclude_mps, set_seed
|
||||
from tools.inference_engine.reference_loader import ReferenceLoader
|
||||
from tools.inference_engine.utils import InferenceResult, wav_chunk_header
|
||||
from tools.inference_engine.vq_manager import VQManager
|
||||
from tools.llama.generate import (
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
WrappedGenerateResponse,
|
||||
)
|
||||
from tools.schema import ServeTTSRequest
|
||||
|
||||
|
||||
class TTSInferenceEngine(ReferenceLoader, VQManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llama_queue: queue.Queue,
|
||||
decoder_model: FireflyArchitecture,
|
||||
precision: torch.dtype,
|
||||
compile: bool,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.llama_queue = llama_queue
|
||||
self.decoder_model = decoder_model
|
||||
self.precision = precision
|
||||
self.compile = compile
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
|
||||
"""
|
||||
Main inference function:
|
||||
- Loads the reference audio and text.
|
||||
- Calls the LLAMA model for inference.
|
||||
- Decodes the VQ tokens to audio.
|
||||
"""
|
||||
|
||||
ref_id: str | None = req.reference_id
|
||||
prompt_tokens, prompt_texts = [], []
|
||||
# Load the reference audio and text based on id or hash
|
||||
if ref_id is not None:
|
||||
prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
|
||||
|
||||
elif req.references:
|
||||
prompt_tokens, prompt_texts = self.load_by_hash(
|
||||
req.references, req.use_memory_cache
|
||||
)
|
||||
|
||||
# Set the random seed if provided
|
||||
if req.seed is not None:
|
||||
set_seed(req.seed)
|
||||
logger.warning(f"set seed: {req.seed}")
|
||||
|
||||
# Get the symbolic tokens from the LLAMA model
|
||||
response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
|
||||
|
||||
# Get the sample rate from the decoder model
|
||||
sample_rate = self.decoder_model.spec_transform.sample_rate
|
||||
|
||||
# If streaming, send the header
|
||||
if req.streaming:
|
||||
yield InferenceResult(
|
||||
code="header",
|
||||
audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
|
||||
error=None,
|
||||
)
|
||||
|
||||
segments = []
|
||||
|
||||
while True:
|
||||
# Get the response from the LLAMA model
|
||||
wrapped_result: WrappedGenerateResponse = response_queue.get()
|
||||
if wrapped_result.status == "error":
|
||||
yield InferenceResult(
|
||||
code="error",
|
||||
audio=None,
|
||||
error=(
|
||||
wrapped_result.response
|
||||
if isinstance(wrapped_result.response, Exception)
|
||||
else Exception("Unknown error")
|
||||
),
|
||||
)
|
||||
break
|
||||
|
||||
# Check the response type
|
||||
if not isinstance(wrapped_result.response, GenerateResponse):
|
||||
raise TypeError(
|
||||
"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
|
||||
)
|
||||
|
||||
result: GenerateResponse = wrapped_result.response
|
||||
if result.action != "next":
|
||||
segment = self.get_audio_segment(result)
|
||||
|
||||
if req.streaming: # Used only by the API server
|
||||
yield InferenceResult(
|
||||
code="segment",
|
||||
audio=(sample_rate, segment),
|
||||
error=None,
|
||||
)
|
||||
else:
|
||||
segments.append(segment)
|
||||
else:
|
||||
break
|
||||
|
||||
# Clean up the memory
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# Edge case: no audio generated
|
||||
if len(segments) == 0:
|
||||
yield InferenceResult(
|
||||
code="error",
|
||||
audio=None,
|
||||
error=RuntimeError("No audio generated, please check the input text."),
|
||||
)
|
||||
else:
|
||||
# Streaming or not, return the final audio
|
||||
audio = np.concatenate(segments, axis=0)
|
||||
yield InferenceResult(
|
||||
code="final",
|
||||
audio=(sample_rate, audio),
|
||||
error=None,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def send_Llama_request(
|
||||
self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
|
||||
) -> queue.Queue:
|
||||
"""
|
||||
Send a request to the LLAMA model to generate the symbolic tokens.
|
||||
"""
|
||||
|
||||
# Prepare the request
|
||||
request = dict(
|
||||
device=self.decoder_model.device,
|
||||
max_new_tokens=req.max_new_tokens,
|
||||
text=(
|
||||
req.text
|
||||
if not req.normalize
|
||||
else ChnNormedText(raw_text=req.text).normalize()
|
||||
),
|
||||
top_p=req.top_p,
|
||||
repetition_penalty=req.repetition_penalty,
|
||||
temperature=req.temperature,
|
||||
compile=self.compile,
|
||||
iterative_prompt=req.chunk_length > 0,
|
||||
chunk_length=req.chunk_length,
|
||||
max_length=4096,
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_text=prompt_texts,
|
||||
)
|
||||
|
||||
# Create a queue to get the response
|
||||
response_queue = queue.Queue()
|
||||
|
||||
# Send the request to the LLAMA model
|
||||
self.llama_queue.put(
|
||||
GenerateRequest(
|
||||
request=request,
|
||||
response_queue=response_queue,
|
||||
)
|
||||
)
|
||||
|
||||
return response_queue
|
||||
|
||||
def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
|
||||
"""
|
||||
Decode the VQ tokens to audio.
|
||||
"""
|
||||
|
||||
# Don't use autocast on MPS devices
|
||||
with autocast_exclude_mps(
|
||||
device_type=self.decoder_model.device.type, dtype=self.precision
|
||||
):
|
||||
# Decode the symbolic tokens to audio
|
||||
segment = self.decode_vq_tokens(codes=result.codes)
|
||||
|
||||
# Convert the audio to numpy
|
||||
return segment.float().cpu().numpy()
|
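For reference, a minimal sketch of how a caller can consume this generator outside the WebUI/API wrappers; the `engine` and `req` names are assumptions (an engine built as in tools/run_webui.py and a non-streaming ServeTTSRequest), not part of this diff.

```python
# Minimal consumption sketch (assumptions: `engine` is a TTSInferenceEngine
# built as in tools/run_webui.py, `req` is a non-streaming ServeTTSRequest).
import soundfile as sf


def synthesize_to_file(engine, req, path="out.wav"):
    for result in engine.inference(req):
        if result.code == "error":
            raise result.error  # the engine wraps failures in an Exception
        if result.code == "final":
            sample_rate, audio = result.audio
            sf.write(path, audio, sample_rate)
            return path
    raise RuntimeError("engine finished without yielding final audio")
```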
128
tools/inference_engine/reference_loader.py
Normal file
@ -0,0 +1,128 @@
|
||||
import io
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, Tuple
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from loguru import logger
|
||||
|
||||
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
||||
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
||||
from tools.schema import ServeReferenceAudio
|
||||
|
||||
|
||||
class ReferenceLoader:
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Component of the TTSInferenceEngine class.
|
||||
Loads and manages the cache for the reference audio and text.
|
||||
"""
|
||||
self.ref_by_id: dict = {}
|
||||
self.ref_by_hash: dict = {}
|
||||
|
||||
# Make Pylance happy (attribute/method not defined...)
|
||||
self.decoder_model: FireflyArchitecture
|
||||
self.encode_reference: Callable
|
||||
|
||||
# Define the torchaudio backend
|
||||
backends = torchaudio.list_audio_backends()
|
||||
if "ffmpeg" in backends:
|
||||
self.backend = "ffmpeg"
|
||||
else:
|
||||
self.backend = "soundfile"
|
||||
|
||||
def load_by_id(
|
||||
self,
|
||||
id: str,
|
||||
use_cache: Literal["on", "off"],
|
||||
) -> Tuple:
|
||||
|
||||
# Load the references audio and text by id
|
||||
ref_folder = Path("references") / id
|
||||
ref_folder.mkdir(parents=True, exist_ok=True)
|
||||
ref_audios = list_files(
|
||||
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
|
||||
)
|
||||
|
||||
if use_cache == "off" or id not in self.ref_by_id:
|
||||
# If the references are not already loaded, encode them
|
||||
prompt_tokens = [
|
||||
self.encode_reference(
|
||||
decoder_model=self.decoder_model,
|
||||
reference_audio=audio_to_bytes(str(ref_audio)),
|
||||
enable_reference_audio=True,
|
||||
)
|
||||
for ref_audio in ref_audios
|
||||
]
|
||||
prompt_texts = [
|
||||
read_ref_text(str(ref_audio.with_suffix(".lab")))
|
||||
for ref_audio in ref_audios
|
||||
]
|
||||
self.ref_by_id[id] = (prompt_tokens, prompt_texts)
|
||||
|
||||
else:
|
||||
# Reuse already encoded references
|
||||
logger.info("Use same references")
|
||||
prompt_tokens, prompt_texts = self.ref_by_id[id]
|
||||
|
||||
return prompt_tokens, prompt_texts
|
||||
|
||||
def load_by_hash(
|
||||
self,
|
||||
references: list[ServeReferenceAudio],
|
||||
use_cache: Literal["on", "off"],
|
||||
) -> Tuple:
|
||||
|
||||
# Load the references audio and text by hash
|
||||
audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
|
||||
|
||||
cache_used = False
|
||||
prompt_tokens, prompt_texts = [], []
|
||||
for i, ref in enumerate(references):
|
||||
if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
|
||||
# If the references are not already loaded, encode them
|
||||
prompt_tokens.append(
|
||||
self.encode_reference(
|
||||
decoder_model=self.decoder_model,
|
||||
reference_audio=ref.audio,
|
||||
enable_reference_audio=True,
|
||||
)
|
||||
)
|
||||
prompt_texts.append(ref.text)
|
||||
self.ref_by_hash[audio_hashes[i]] = (prompt_tokens[-1], ref.text)
|
||||
|
||||
else:
|
||||
# Reuse already encoded references
|
||||
prompt_token, prompt_text = self.ref_by_hash[audio_hashes[i]]
|
||||
prompt_texts.append(prompt_text)
|
||||
prompt_tokens.append(prompt_token)
|
||||
cache_used = True
|
||||
|
||||
if cache_used:
|
||||
logger.info("Use same references")
|
||||
|
||||
return prompt_tokens, prompt_texts
|
||||
|
||||
def load_audio(self, reference_audio, sr):
|
||||
"""
|
||||
Load the audio data from a file or bytes.
|
||||
"""
|
||||
if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
||||
audio_data = reference_audio
|
||||
reference_audio = io.BytesIO(audio_data)
|
||||
|
||||
waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
|
||||
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
|
||||
if original_sr != sr:
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
orig_freq=original_sr, new_freq=sr
|
||||
)
|
||||
waveform = resampler(waveform)
|
||||
|
||||
audio = waveform.squeeze().numpy()
|
||||
return audio
|
42
tools/inference_engine/utils.py
Normal file
@ -0,0 +1,42 @@
|
||||
import io
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
code: Literal["header", "segment", "error", "final"]
|
||||
audio: Optional[Tuple[int, np.ndarray]]
|
||||
error: Optional[Exception]
|
||||
|
||||
|
||||
def normalize_text(user_input: str, use_normalization: bool) -> str:
|
||||
"""Normalize user input text if needed."""
|
||||
if use_normalization:
|
||||
return ChnNormedText(raw_text=user_input).normalize()
|
||||
else:
|
||||
return user_input
|
||||
|
||||
|
||||
def wav_chunk_header(
|
||||
sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
|
||||
) -> np.ndarray:
|
||||
buffer = io.BytesIO()
|
||||
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(channels)
|
||||
wav_file.setsampwidth(bit_depth // 8)
|
||||
wav_file.setframerate(sample_rate)
|
||||
|
||||
wav_header_bytes = buffer.getvalue()
|
||||
buffer.close()
|
||||
|
||||
# Convert to numpy array
|
||||
wav_header = np.frombuffer(wav_header_bytes, dtype=np.uint8)
|
||||
|
||||
return wav_header
|
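A small usage sketch for the header helper: in the streaming path the header bytes go out first, followed by raw int16 PCM segments (the 32768 scale factor mirrors tools/server/inference.py; the variable names are illustrative).

```python
# Illustrative only: prefix a streamed WAV with the generated header, then
# append PCM16 frames for each decoded segment (names are placeholders).
import numpy as np

header = wav_chunk_header(sample_rate=44100)         # np.uint8 header bytes
segment = np.zeros(44100, dtype=np.float32)          # stand-in for one decoded segment
pcm16 = (segment * 32768).astype(np.int16).tobytes()

with open("stream.wav", "wb") as f:
    f.write(header.tobytes())
    f.write(pcm16)
```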
57
tools/inference_engine/vq_manager.py
Normal file
@ -0,0 +1,57 @@
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
||||
|
||||
|
||||
class VQManager:
|
||||
|
||||
def __init__(self):
|
||||
# Make Pylance happy (attribute/method not defined...)
|
||||
self.decoder_model: FireflyArchitecture
|
||||
self.load_audio: Callable
|
||||
|
||||
def decode_vq_tokens(self, codes):
|
||||
feature_lengths = torch.tensor(
|
||||
[codes.shape[1]], device=self.decoder_model.device
|
||||
)
|
||||
logger.info(f"VQ features: {codes.shape}")
|
||||
|
||||
if isinstance(self.decoder_model, FireflyArchitecture):
|
||||
return self.decoder_model.decode(
|
||||
indices=codes[None],
|
||||
feature_lengths=feature_lengths,
|
||||
)[0].squeeze()
|
||||
|
||||
raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
|
||||
|
||||
def encode_reference(self, reference_audio, enable_reference_audio):
|
||||
if enable_reference_audio and reference_audio is not None:
|
||||
# Load audios, and prepare basic info here
|
||||
reference_audio_content = self.load_audio(
|
||||
reference_audio, self.decoder_model.spec_transform.sample_rate
|
||||
)
|
||||
|
||||
audios = torch.from_numpy(reference_audio_content).to(
|
||||
self.decoder_model.device
|
||||
)[None, None, :]
|
||||
audio_lengths = torch.tensor(
|
||||
[audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
|
||||
)
|
||||
|
||||
# VQ Encoder
|
||||
if isinstance(self.decoder_model, FireflyArchitecture):
|
||||
prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
|
||||
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
|
||||
else:
|
||||
prompt_tokens = None
|
||||
logger.info("No reference audio provided")
|
||||
|
||||
return prompt_tokens
|
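A hedged round-trip sketch of these two methods as they are combined in the engine; the `engine` object and the mono `ref_bytes` clip are assumptions, not part of this diff.

```python
# Sketch, assuming `engine` mixes in VQManager and ReferenceLoader the way
# TTSInferenceEngine does, and `ref_bytes` holds a mono reference recording.
codes = engine.encode_reference(ref_bytes, enable_reference_audio=True)
# `codes` is the (num_codebooks, T) prompt token tensor; decoding it returns
# a float waveform tensor on the decoder device.
waveform = engine.decode_vq_tokens(codes).float().cpu().numpy()
```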
@ -1,95 +0,0 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import ormsgpack
|
||||
|
||||
from tools.schema import ServeReferenceAudio, ServeTTSRequest
|
||||
|
||||
api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
|
||||
|
||||
|
||||
def audio_request():
|
||||
# priority: ref_id > references
|
||||
request = ServeTTSRequest(
|
||||
text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
||||
# reference_id="114514",
|
||||
references=[
|
||||
ServeReferenceAudio(
|
||||
audio=open("lengyue.wav", "rb").read(),
|
||||
text=open("lengyue.lab", "r", encoding="utf-8").read(),
|
||||
)
|
||||
],
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
|
||||
|
||||
with (
|
||||
httpx.Client() as client,
|
||||
open("hello.wav", "wb") as f,
|
||||
):
|
||||
with client.stream(
|
||||
"POST",
|
||||
"http://127.0.0.1:8080/v1/tts",
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
headers={
|
||||
"authorization": f"Bearer {api_key}",
|
||||
"content-type": "application/msgpack",
|
||||
},
|
||||
timeout=None,
|
||||
) as response:
|
||||
for chunk in response.iter_bytes():
|
||||
f.write(chunk)
|
||||
|
||||
|
||||
def asr_request(audio_path: Path):
|
||||
|
||||
# Read the audio file
|
||||
with open(
|
||||
str(audio_path),
|
||||
"rb",
|
||||
) as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
|
||||
# Prepare the request data
|
||||
request_data = {
|
||||
"audio": audio_data,
|
||||
"language": "en", # Optional: specify the language
|
||||
"ignore_timestamps": False, # Optional: set to True to ignore precise timestamps
|
||||
}
|
||||
|
||||
# Send the request
|
||||
with httpx.Client() as client:
|
||||
response = client.post(
|
||||
"https://api.fish.audio/v1/asr",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/msgpack",
|
||||
},
|
||||
content=ormsgpack.packb(request_data),
|
||||
)
|
||||
|
||||
# Parse the response
|
||||
result = response.json()
|
||||
|
||||
print(f"Transcribed text: {result['text']}")
|
||||
print(f"Audio duration: {result['duration']} seconds")
|
||||
|
||||
for segment in result["segments"]:
|
||||
print(f"Segment: {segment['text']}")
|
||||
print(f"Start time: {segment['start']}, End time: {segment['end']}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
asr_request(args.audio_path)
|
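The standalone msgpack client removed above has its replacement in tools/api_client (see the updated docs). A hand-rolled request against the renamed server is kept here only as a hedged sketch; the endpoint and field names follow the removed code, everything else is illustrative.

```python
# Hedged sketch of the same TTS request against `python -m tools.api_server`.
import httpx
import ormsgpack

from tools.schema import ServeTTSRequest

request = ServeTTSRequest(text="Hello world.", format="wav", streaming=False)

response = httpx.post(
    "http://127.0.0.1:8080/v1/tts",
    content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
    headers={"content-type": "application/msgpack"},
    timeout=None,
)
with open("hello.wav", "wb") as f:
    f.write(response.content)
```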
101
tools/run_webui.py
Normal file
@ -0,0 +1,101 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
import pyrootutils
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from tools.inference_engine import TTSInferenceEngine
|
||||
from tools.llama.generate import launch_thread_safe_queue
|
||||
from tools.schema import ServeTTSRequest
|
||||
from tools.vqgan.inference import load_model as load_decoder_model
|
||||
from tools.webui import build_app
|
||||
from tools.webui.inference import get_inference_wrapper
|
||||
|
||||
# Make einx happy
|
||||
os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--llama-checkpoint-path",
|
||||
type=Path,
|
||||
default="checkpoints/fish-speech-1.5",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder-checkpoint-path",
|
||||
type=Path,
|
||||
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
||||
)
|
||||
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
parser.add_argument("--half", action="store_true")
|
||||
parser.add_argument("--compile", action="store_true")
|
||||
parser.add_argument("--max-gradio-length", type=int, default=0)
|
||||
parser.add_argument("--theme", type=str, default="light")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
args.precision = torch.half if args.half else torch.bfloat16
|
||||
|
||||
# Check if CUDA is available
|
||||
if not torch.cuda.is_available():
|
||||
logger.info("CUDA is not available, running on CPU.")
|
||||
args.device = "cpu"
|
||||
|
||||
logger.info("Loading Llama model...")
|
||||
llama_queue = launch_thread_safe_queue(
|
||||
checkpoint_path=args.llama_checkpoint_path,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
compile=args.compile,
|
||||
)
|
||||
|
||||
logger.info("Loading VQ-GAN model...")
|
||||
decoder_model = load_decoder_model(
|
||||
config_name=args.decoder_config_name,
|
||||
checkpoint_path=args.decoder_checkpoint_path,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
logger.info("Decoder model loaded, warming up...")
|
||||
|
||||
# Create the inference engine
|
||||
inference_engine = TTSInferenceEngine(
|
||||
llama_queue=llama_queue,
|
||||
decoder_model=decoder_model,
|
||||
compile=args.compile,
|
||||
precision=args.precision,
|
||||
)
|
||||
|
||||
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
||||
list(
|
||||
inference_engine.inference(
|
||||
ServeTTSRequest(
|
||||
text="Hello world.",
|
||||
references=[],
|
||||
reference_id=None,
|
||||
max_new_tokens=0,
|
||||
chunk_length=200,
|
||||
top_p=0.7,
|
||||
repetition_penalty=1.5,
|
||||
temperature=0.7,
|
||||
format="wav",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Warming up done, launching the web UI...")
|
||||
|
||||
# Get the inference function with the immutable arguments
|
||||
inference_fct = get_inference_wrapper(inference_engine)
|
||||
|
||||
app = build_app(inference_fct, args.theme)
|
||||
app.launch(show_api=True)
|
@ -1,16 +1,14 @@
|
||||
import os
|
||||
import queue
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Optional
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
|
||||
from pydantic import BaseModel, Field, conint, conlist
|
||||
from pydantic.functional_validators import SkipValidation
|
||||
|
||||
from fish_speech.conversation import Message, TextPart, VQPart
|
||||
|
||||
GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
|
||||
|
||||
|
||||
class ServeVQPart(BaseModel):
|
||||
type: Literal["vq"] = "vq"
|
||||
@ -64,7 +62,7 @@ class ServeASRResponse(BaseModel):
|
||||
|
||||
|
||||
class ServeMessage(BaseModel):
|
||||
role: Literal["system", "assistant", "user", "raw"]
|
||||
role: Literal["system", "assistant", "user"]
|
||||
parts: list[ServeVQPart | ServeTextPart]
|
||||
|
||||
def to_conversation_message(self):
|
||||
@ -85,7 +83,7 @@ class ServeMessage(BaseModel):
|
||||
return new_message
|
||||
|
||||
|
||||
class ServeRequest(BaseModel):
|
||||
class ServeChatRequest(BaseModel):
|
||||
messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
|
||||
max_new_tokens: int = 1024
|
||||
top_p: float = 0.7
|
||||
@ -114,11 +112,6 @@ class ServeVQGANDecodeResponse(BaseModel):
|
||||
audios: list[bytes]
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
audio: bytes
|
||||
text: str
|
||||
|
||||
|
||||
class ServeForwardMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
@ -150,24 +143,11 @@ class ServeReferenceAudio(BaseModel):
|
||||
return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
|
||||
|
||||
|
||||
class ServeChatRequestV1(BaseModel):
|
||||
model: str = "llama3-8b"
|
||||
messages: list[ServeForwardMessage] = []
|
||||
audio: bytes | None = None
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
max_tokens: int = 256
|
||||
voice: str = "jessica"
|
||||
tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
|
||||
tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
|
||||
|
||||
|
||||
class ServeTTSRequest(BaseModel):
|
||||
text: str
|
||||
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
||||
# Audio format
|
||||
format: Literal["wav", "pcm", "mp3"] = "wav"
|
||||
mp3_bitrate: Literal[64, 128, 192] = 128
|
||||
# References audios for in-context learning
|
||||
references: list[ServeReferenceAudio] = []
|
||||
# Reference id
|
||||
@ -175,16 +155,16 @@ class ServeTTSRequest(BaseModel):
|
||||
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
|
||||
reference_id: str | None = None
|
||||
seed: int | None = None
|
||||
use_memory_cache: Literal["on-demand", "never"] = "never"
|
||||
use_memory_cache: Literal["on", "off"] = "off"
|
||||
# Normalize text for en & zh, this increase stability for numbers
|
||||
normalize: bool = True
|
||||
mp3_bitrate: Optional[int] = 64
|
||||
opus_bitrate: Optional[int] = -1000
|
||||
# Balance mode will reduce latency to 300ms, but may decrease stability
|
||||
latency: Literal["normal", "balanced"] = "normal"
|
||||
# not usually used below
|
||||
streaming: bool = False
|
||||
max_new_tokens: int = 1024
|
||||
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
||||
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
||||
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
||||
|
||||
class Config:
|
||||
# Allow arbitrary types for pytorch related types
|
||||
arbitrary_types_allowed = True
|
||||
|
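For clarity, a request built against the updated schema (note the new `use_memory_cache` literals `"on"`/`"off"`); the reference file paths are placeholders.

```python
# Placeholder file names; fields mirror the ServeTTSRequest diff above.
from tools.schema import ServeReferenceAudio, ServeTTSRequest

req = ServeTTSRequest(
    text="Hello world.",
    references=[
        ServeReferenceAudio(
            audio=open("ref.wav", "rb").read(),
            text=open("ref.lab", "r", encoding="utf-8").read(),
        )
    ],
    use_memory_cache="on",  # was "on-demand"/"never" before this change
    streaming=False,
)
```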
57
tools/server/agent/__init__.py
Normal file
@ -0,0 +1,57 @@
|
||||
import struct
|
||||
from functools import partial
|
||||
|
||||
import ormsgpack
|
||||
|
||||
from tools.server.agent.generate import generate_responses
|
||||
from tools.server.agent.pre_generation_utils import prepare_messages
|
||||
|
||||
|
||||
def execute_request(input_queue, tokenizer, config, request, device):
|
||||
"""
|
||||
This function prepares the conversation, encodes the request,
|
||||
sends the generation request, and handles decoding/streaming.
|
||||
It returns a response generator (ServeResponse or ServeStreamResponse).
|
||||
"""
|
||||
prompt, im_end_id = prepare_messages(request, tokenizer, config)
|
||||
yield from generate_responses(
|
||||
input_queue, tokenizer, config, request, prompt, im_end_id, device
|
||||
)
|
||||
|
||||
|
||||
def response_generator(req, llama_queue, tokenizer, config, device):
|
||||
"""
|
||||
Non-streaming response wrapper for the chat endpoint.
|
||||
Only returns the final result.
|
||||
"""
|
||||
generator = execute_request(llama_queue, tokenizer, config, req, device)
|
||||
return next(generator)
|
||||
|
||||
|
||||
async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
|
||||
"""
|
||||
Streaming response wrapper for the chat endpoint.
|
||||
Returns the response in chunks.
|
||||
"""
|
||||
generator = execute_request(llama_queue, tokenizer, config, req, device)
|
||||
for i in generator:
|
||||
if json_mode:
|
||||
body = i.model_dump_json().encode("utf-8")
|
||||
yield b"data: " + body + b"\n\n"
|
||||
else:
|
||||
body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
||||
yield struct.pack("I", len(body)) + body
|
||||
|
||||
|
||||
def get_response_generator(
|
||||
llama_queue, tokenizer, config, req, device, json_mode
|
||||
) -> partial:
|
||||
"""
|
||||
Get the correct response generator based on the request.
|
||||
"""
|
||||
if not req.streaming:
|
||||
return partial(response_generator, req, llama_queue, tokenizer, config, device)
|
||||
else:
|
||||
return partial(
|
||||
streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
|
||||
)
|
119
tools/server/agent/generate.py
Normal file
@ -0,0 +1,119 @@
|
||||
import time
|
||||
|
||||
from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
|
||||
from tools.server.agent.generation_utils import (
|
||||
initialize_decode_buffers,
|
||||
process_response_tokens,
|
||||
send_reset_buffer,
|
||||
)
|
||||
from tools.server.agent.pre_generation_utils import (
|
||||
create_generation_request,
|
||||
send_generation_request,
|
||||
)
|
||||
|
||||
|
||||
def generate_responses(
|
||||
input_queue, tokenizer, config, request, prompt, im_end_id, device
|
||||
):
|
||||
"""
|
||||
Main generation function that handles the conversation, encodes the request,
|
||||
sends the generation request, and handles decoding/streaming.
|
||||
It returns a response generator (ServeResponse or ServeStreamResponse).
|
||||
"""
|
||||
stats = {}
|
||||
start = time.time()
|
||||
stats["start_time"] = start
|
||||
stats["tokens_count"] = 0
|
||||
|
||||
# Prepare and send the generation request
|
||||
req = create_generation_request(prompt, request, im_end_id, device)
|
||||
response_queue = send_generation_request(input_queue, req)
|
||||
decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
|
||||
|
||||
while True:
|
||||
response = response_queue.get()
|
||||
|
||||
# Handle abnormal finish or error
|
||||
if response in ["stop", "error"]:
|
||||
finish_reason = response
|
||||
break
|
||||
|
||||
# Process the response tokens
|
||||
is_first_token = stats["tokens_count"] == 0
|
||||
responses = process_response_tokens(
|
||||
response,
|
||||
tokenizer,
|
||||
config,
|
||||
request,
|
||||
decode_buffer,
|
||||
parts,
|
||||
finished,
|
||||
im_end_id,
|
||||
stats,
|
||||
start,
|
||||
is_first_token,
|
||||
)
|
||||
|
||||
# Yield the responses if streaming
|
||||
if request.streaming and responses:
|
||||
for r in responses:
|
||||
yield r
|
||||
|
||||
stats["tokens_count"] += 1
|
||||
|
||||
# Check if all samples are finished
|
||||
if all(finished):
|
||||
finish_reason = "stop"
|
||||
break
|
||||
|
||||
# Finalize the response
|
||||
final_responses = finalize_response(
|
||||
request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
|
||||
)
|
||||
for fr in final_responses:
|
||||
yield fr
|
||||
|
||||
|
||||
def finalize_response(
|
||||
request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
|
||||
):
|
||||
"""
|
||||
Finalize the response by sending the remaining text buffers.
|
||||
"""
|
||||
responses = []
|
||||
|
||||
# Send the remaining text buffers
|
||||
for sample_id in range(request.num_samples):
|
||||
responses.extend(
|
||||
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
|
||||
)
|
||||
|
||||
# Calculate the final stats
|
||||
stats["total_time"] = (time.time() - stats["start_time"]) * 1000
|
||||
stats["total_tokens"] = stats["tokens_count"]
|
||||
|
||||
# If streaming, send the final chunks for each sample
|
||||
if request.streaming:
|
||||
for sample_id in range(request.num_samples):
|
||||
if finished[sample_id]:
|
||||
continue
|
||||
responses.append(
|
||||
ServeStreamResponse(
|
||||
finish_reason=finish_reason, stats=stats, sample_id=sample_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
# If not streaming, send the full messages for each sample
|
||||
full_messages = [
|
||||
ServeMessage(role="assistant", parts=parts[i])
|
||||
for i in range(request.num_samples)
|
||||
]
|
||||
responses.append(
|
||||
ServeResponse(
|
||||
messages=full_messages,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
)
|
||||
)
|
||||
|
||||
return responses
|
122
tools/server/agent/generation_utils.py
Normal file
@ -0,0 +1,122 @@
|
||||
import time
|
||||
|
||||
from tools.schema import (
|
||||
ServeStreamDelta,
|
||||
ServeStreamResponse,
|
||||
ServeTextPart,
|
||||
ServeVQPart,
|
||||
)
|
||||
|
||||
|
||||
def initialize_decode_buffers(num_samples):
|
||||
"""Initialise the decode buffers for each sample."""
|
||||
decode_buffer = [[] for _ in range(num_samples)]
|
||||
parts = [[] for _ in range(num_samples)]
|
||||
finished = [False for _ in range(num_samples)]
|
||||
return decode_buffer, parts, finished
|
||||
|
||||
|
||||
def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
|
||||
"""Send the remaining text buffer for a sample."""
|
||||
if len(decode_buffer[sample_id]) == 0:
|
||||
return []
|
||||
|
||||
decoded = tokenizer.decode(decode_buffer[sample_id])
|
||||
part = ServeTextPart(text=decoded)
|
||||
|
||||
responses = []
|
||||
if request.streaming:
|
||||
responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
|
||||
else:
|
||||
parts[sample_id].append(part)
|
||||
|
||||
decode_buffer[sample_id] = []
|
||||
return responses
|
||||
|
||||
|
||||
def handle_semantic_tokens(tokens, config, sample_id, parts, request):
|
||||
"""Handle the semantic tokens returned by the model."""
|
||||
responses = []
|
||||
_tokens = tokens[1:].clone()
|
||||
|
||||
if not config.share_codebook_embeddings:
|
||||
for i in range(len(_tokens)):
|
||||
_tokens[i] -= config.codebook_size * i
|
||||
|
||||
# If streaming, send the VQ parts directly
|
||||
if request.streaming:
|
||||
responses.append(
|
||||
ServeStreamResponse(
|
||||
sample_id=sample_id,
|
||||
delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# If not streaming, accumulate the VQ parts
|
||||
if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
|
||||
parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
|
||||
else:
|
||||
# Accumulate the codes
|
||||
for codebook_id, value in enumerate(_tokens):
|
||||
parts[sample_id][-1].codes[codebook_id].append(value.item())
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
def process_response_tokens(
|
||||
response,
|
||||
tokenizer,
|
||||
config,
|
||||
request,
|
||||
decode_buffer,
|
||||
parts,
|
||||
finished,
|
||||
im_end_id,
|
||||
stats,
|
||||
start,
|
||||
is_first_token,
|
||||
):
|
||||
"""Process the response tokens returned by the model."""
|
||||
responses = []
|
||||
for sample_id, tokens in enumerate(response):
|
||||
if finished[sample_id]:
|
||||
continue
|
||||
|
||||
# End of the conversation
|
||||
if tokens[0] == im_end_id:
|
||||
finished[sample_id] = True
|
||||
# Send the remaining text buffer
|
||||
responses.extend(
|
||||
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
|
||||
)
|
||||
if request.streaming:
|
||||
responses.append(
|
||||
ServeStreamResponse(
|
||||
sample_id=sample_id,
|
||||
finish_reason="stop",
|
||||
stats=stats,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if the token is semantic
|
||||
is_semantic = (
|
||||
tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
|
||||
)
|
||||
|
||||
if is_semantic:
|
||||
# Before the semantic tokens, send the remaining text buffer
|
||||
responses.extend(
|
||||
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
|
||||
)
|
||||
responses.extend(
|
||||
handle_semantic_tokens(tokens, config, sample_id, parts, request)
|
||||
)
|
||||
else:
|
||||
# Accumulate the text tokens into the decode buffer
|
||||
decode_buffer[sample_id].append(tokens[0, 0])
|
||||
|
||||
if is_first_token:
|
||||
stats["time_to_first_token"] = (time.time() - start) * 1000
|
||||
|
||||
return responses
|
72
tools/server/agent/pre_generation_utils.py
Normal file
@ -0,0 +1,72 @@
|
||||
import queue
|
||||
|
||||
from fish_speech.conversation import Conversation, Message
|
||||
from fish_speech.tokenizer import IM_END_TOKEN
|
||||
from tools.llama.generate import GenerateRequest
|
||||
|
||||
|
||||
def prepare_messages(request, tokenizer, config):
|
||||
"""
|
||||
Reorganise the provided list of messages into a conversation.
|
||||
Encode the conversation for inference.
|
||||
"""
|
||||
# Convert the messages to ConversationMessage objects
|
||||
messages = [msg.to_conversation_message() for msg in request.messages]
|
||||
|
||||
if len(messages) < 1:
|
||||
raise ValueError("At least one message is required")
|
||||
|
||||
# Check the last message to determine the next step
|
||||
last_role = messages[-1].role
|
||||
match last_role:
|
||||
case "user":
|
||||
# The last message is from the user, ask the assistant to respond with a new message
|
||||
messages.append(
|
||||
Message(role="assistant", parts=[], add_im_end=False, modality="voice")
|
||||
)
|
||||
case "raw":
|
||||
# The last message is raw text, ask the assistant to complete it
|
||||
messages[-1].add_im_start = False
|
||||
messages[-1].add_im_end = False
|
||||
messages[-1].modality = "voice"
|
||||
case "assistant":
|
||||
# The last message is from the assistant, ask the assistant to continue
|
||||
messages[-1].add_im_end = False
|
||||
case _:
|
||||
# We expect it to be assistant if not user or raw
|
||||
raise ValueError("The last message must be from the assistant, user or raw")
|
||||
|
||||
# Create a conversation object and encode it for inference
|
||||
conv = Conversation(messages=messages)
|
||||
prompt = conv.encode_for_inference(
|
||||
tokenizer=tokenizer, num_codebooks=config.num_codebooks
|
||||
)
|
||||
im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
|
||||
|
||||
return prompt, im_end_id
|
||||
|
||||
|
||||
def create_generation_request(prompt, request, im_end_id, device):
|
||||
"""
|
||||
Convert the request into a dictionary that can be sent to the model for generation.
|
||||
"""
|
||||
req = {
|
||||
"prompt": prompt.to(device),
|
||||
"max_new_tokens": request.max_new_tokens,
|
||||
"im_end_id": im_end_id,
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"num_samples": request.num_samples,
|
||||
"early_stop_threshold": request.early_stop_threshold,
|
||||
}
|
||||
return req
|
||||
|
||||
|
||||
def send_generation_request(input_queue, req):
|
||||
"""
|
||||
Send the generation request to the model and return a queue to get the response.
|
||||
"""
|
||||
response_queue = queue.Queue()
|
||||
input_queue.put(GenerateRequest(req, response_queue))
|
||||
return response_queue
|
75
tools/server/api_utils.py
Normal file
@ -0,0 +1,75 @@
|
||||
from argparse import ArgumentParser
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any
|
||||
|
||||
import ormsgpack
|
||||
from baize.datastructures import ContentType
|
||||
from kui.asgi import HTTPException, HttpRequest
|
||||
|
||||
from tools.inference_engine import TTSInferenceEngine
|
||||
from tools.schema import ServeTTSRequest
|
||||
from tools.server.inference import inference_wrapper as inference
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
|
||||
parser.add_argument("--load-asr-model", action="store_true")
|
||||
parser.add_argument(
|
||||
"--llama-checkpoint-path",
|
||||
type=str,
|
||||
default="checkpoints/fish-speech-1.5",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder-checkpoint-path",
|
||||
type=str,
|
||||
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
||||
)
|
||||
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
parser.add_argument("--half", action="store_true")
|
||||
parser.add_argument("--compile", action="store_true")
|
||||
parser.add_argument("--max-text-length", type=int, default=0)
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
|
||||
parser.add_argument("--workers", type=int, default=1)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class MsgPackRequest(HttpRequest):
|
||||
async def data(
|
||||
self,
|
||||
) -> Annotated[
|
||||
Any, ContentType("application/msgpack"), ContentType("application/json")
|
||||
]:
|
||||
if self.content_type == "application/msgpack":
|
||||
return ormsgpack.unpackb(await self.body)
|
||||
|
||||
elif self.content_type == "application/json":
|
||||
return await self.json
|
||||
|
||||
raise HTTPException(
|
||||
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
||||
headers={"Accept": "application/msgpack, application/json"},
|
||||
)
|
||||
|
||||
|
||||
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
|
||||
for chunk in inference(req, engine):
|
||||
if isinstance(chunk, bytes):
|
||||
yield chunk
|
||||
|
||||
|
||||
async def buffer_to_async_generator(buffer):
|
||||
yield buffer
|
||||
|
||||
|
||||
def get_content_type(audio_format):
|
||||
if audio_format == "wav":
|
||||
return "audio/wav"
|
||||
elif audio_format == "flac":
|
||||
return "audio/flac"
|
||||
elif audio_format == "mp3":
|
||||
return "audio/mpeg"
|
||||
else:
|
||||
return "application/octet-stream"
|
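`MsgPackRequest` accepts either encoding, so the same payload can be posted as msgpack or JSON; a hedged client-side sketch (the local server address is an assumption).

```python
# Sketch only; assumes the server runs locally via `python -m tools.api_server`.
import httpx
import ormsgpack

payload = {"text": "Hello world.", "format": "wav"}

# application/msgpack body
httpx.post(
    "http://127.0.0.1:8080/v1/tts",
    content=ormsgpack.packb(payload),
    headers={"content-type": "application/msgpack"},
)

# application/json body
httpx.post("http://127.0.0.1:8080/v1/tts", json=payload)
```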
27
tools/server/exception_handler.py
Normal file
@ -0,0 +1,27 @@
|
||||
import traceback
|
||||
from http import HTTPStatus
|
||||
|
||||
from kui.asgi import HTTPException, JSONResponse
|
||||
|
||||
|
||||
class ExceptionHandler:
|
||||
|
||||
async def http_exception_handler(self, exc: HTTPException):
|
||||
return JSONResponse(
|
||||
dict(
|
||||
statusCode=exc.status_code,
|
||||
message=exc.content,
|
||||
error=HTTPStatus(exc.status_code).phrase,
|
||||
),
|
||||
exc.status_code,
|
||||
exc.headers,
|
||||
)
|
||||
|
||||
async def other_exception_handler(self, exc: Exception):
|
||||
traceback.print_exc()
|
||||
|
||||
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
return JSONResponse(
|
||||
dict(statusCode=status, message=str(exc), error=status.phrase),
|
||||
status,
|
||||
)
|
41
tools/server/inference.py
Normal file
@ -0,0 +1,41 @@
|
||||
from http import HTTPStatus
|
||||
|
||||
import numpy as np
|
||||
from kui.asgi import HTTPException
|
||||
|
||||
from tools.inference_engine import TTSInferenceEngine
|
||||
from tools.schema import ServeTTSRequest
|
||||
|
||||
AMPLITUDE = 32768  # int16 full-scale factor: float audio in [-1, 1] is scaled to PCM16
|
||||
|
||||
|
||||
def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
|
||||
"""
|
||||
Wrapper for the inference function.
|
||||
Used in the API server.
|
||||
"""
|
||||
for result in engine.inference(req):
|
||||
match result.code:
|
||||
case "header":
|
||||
if isinstance(result.audio, tuple):
|
||||
yield result.audio[1]
|
||||
|
||||
case "error":
|
||||
raise HTTPException(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
content=str(result.error),
|
||||
)
|
||||
|
||||
case "segment":
|
||||
if isinstance(result.audio, tuple):
|
||||
yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
|
||||
|
||||
case "final":
|
||||
if isinstance(result.audio, tuple):
|
||||
yield result.audio[1]
|
||||
return None # Stop the generator
|
||||
|
||||
raise HTTPException(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
content="No audio generated, please check the input text.",
|
||||
)
|
119
tools/server/model_manager.py
Normal file
@ -0,0 +1,119 @@
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
from loguru import logger
|
||||
|
||||
from tools.inference_engine import TTSInferenceEngine
|
||||
from tools.llama.generate import (
|
||||
launch_thread_safe_queue,
|
||||
launch_thread_safe_queue_agent,
|
||||
)
|
||||
from tools.schema import ServeTTSRequest
|
||||
from tools.server.inference import inference_wrapper as inference
|
||||
from tools.vqgan.inference import load_model as load_decoder_model
|
||||
|
||||
ASR_MODEL_NAME = "iic/SenseVoiceSmall"
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(
|
||||
self,
|
||||
mode: str,
|
||||
device: str,
|
||||
half: bool,
|
||||
compile: bool,
|
||||
asr_enabled: bool,
|
||||
llama_checkpoint_path: str,
|
||||
decoder_checkpoint_path: str,
|
||||
decoder_config_name: str,
|
||||
) -> None:
|
||||
|
||||
self.mode = mode
|
||||
self.device = device
|
||||
self.half = half
|
||||
self.compile = compile
|
||||
|
||||
self.precision = torch.half if half else torch.bfloat16
|
||||
|
||||
# Check if CUDA is available
|
||||
if not torch.cuda.is_available():
|
||||
self.device = "cpu"
|
||||
logger.info("CUDA is not available, running on CPU.")
|
||||
|
||||
# Load the ASR model if enabled
|
||||
if asr_enabled:
|
||||
self.load_asr_model(self.device)
|
||||
|
||||
# Load the TTS models
|
||||
self.load_llama_model(
|
||||
llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
|
||||
)
|
||||
self.load_decoder_model(
|
||||
decoder_config_name, decoder_checkpoint_path, self.device
|
||||
)
|
||||
self.tts_inference_engine = TTSInferenceEngine(
|
||||
llama_queue=self.llama_queue,
|
||||
decoder_model=self.decoder_model,
|
||||
precision=self.precision,
|
||||
compile=self.compile,
|
||||
)
|
||||
|
||||
# Warm up the models
|
||||
if self.mode == "tts":
|
||||
self.warm_up(self.tts_inference_engine)
|
||||
|
||||
def load_asr_model(self, device, hub="ms") -> None:
|
||||
self.asr_model = AutoModel(
|
||||
model=ASR_MODEL_NAME,
|
||||
device=device,
|
||||
disable_pbar=True,
|
||||
hub=hub,
|
||||
)
|
||||
logger.info("ASR model loaded.")
|
||||
|
||||
def load_llama_model(
|
||||
self, checkpoint_path, device, precision, compile, mode
|
||||
) -> None:
|
||||
|
||||
if mode == "tts":
|
||||
self.llama_queue = launch_thread_safe_queue(
|
||||
checkpoint_path=checkpoint_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
compile=compile,
|
||||
)
|
||||
elif mode == "agent":
|
||||
self.llama_queue, self.tokenizer, self.config = (
|
||||
launch_thread_safe_queue_agent(
|
||||
checkpoint_path=checkpoint_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
compile=compile,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
logger.info("LLAMA model loaded.")
|
||||
|
||||
def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
|
||||
self.decoder_model = load_decoder_model(
|
||||
config_name=config_name,
|
||||
checkpoint_path=checkpoint_path,
|
||||
device=device,
|
||||
)
|
||||
logger.info("Decoder model loaded.")
|
||||
|
||||
def warm_up(self, tts_inference_engine) -> None:
|
||||
request = ServeTTSRequest(
|
||||
text="Hello world.",
|
||||
references=[],
|
||||
reference_id=None,
|
||||
max_new_tokens=0,
|
||||
chunk_length=200,
|
||||
top_p=0.7,
|
||||
repetition_penalty=1.5,
|
||||
temperature=0.7,
|
||||
format="wav",
|
||||
)
|
||||
list(inference(request, tts_inference_engine))
|
||||
logger.info("Models warmed up.")
|
129
tools/server/model_utils.py
Normal file
@ -0,0 +1,129 @@
|
||||
import io
|
||||
import re
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
import torchaudio
|
||||
from cachetools import LRUCache, cached
|
||||
|
||||
CACHE_MAXSIZE = 10000
|
||||
MICRO_BATCH_SIZE = 8
|
||||
ASR_SAMPLE_RATE = 16000
|
||||
HUGE_GAP_THRESHOLD = 4000
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast(device_type="cuda", dtype=torch.half)
|
||||
def batch_encode(model, audios_list: list[bytes]):
|
||||
audios: list[torch.Tensor] = [
|
||||
(
|
||||
torch.from_numpy(
|
||||
librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
|
||||
)[None]
|
||||
if isinstance(audio, bytes)
|
||||
else audio
|
||||
)
|
||||
for audio in audios_list
|
||||
]
|
||||
|
||||
lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
|
||||
max_length = lengths.max().item()
|
||||
|
||||
print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
|
||||
|
||||
padded = torch.stack(
|
||||
[
|
||||
torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
|
||||
for audio in audios
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
features, feature_lengths = model.encode(padded, audio_lengths=lengths)
|
||||
features, feature_lengths = features.cpu(), feature_lengths.cpu()
|
||||
|
||||
return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
|
||||
|
||||
|
||||
@cached(
|
||||
cache=LRUCache(maxsize=CACHE_MAXSIZE),
|
||||
key=lambda model, audios: (model.device, tuple(audios)),
|
||||
)
|
||||
def cached_vqgan_batch_encode(model, audios: list[bytes]):
|
||||
return batch_encode(model, audios)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast(device_type="cuda", dtype=torch.half)
|
||||
def vqgan_decode(model, features):
|
||||
lengths = torch.tensor(
|
||||
[feature.shape[-1] for feature in features], device=model.device
|
||||
)
|
||||
max_length = lengths.max().item()
|
||||
padded = torch.stack(
|
||||
[
|
||||
torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
|
||||
for feature in features
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
# If bs too large, we do micro batch decode
|
||||
audios, audio_lengths = [], []
|
||||
for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
|
||||
audio, audio_length = model.decode(
|
||||
padded[i : i + MICRO_BATCH_SIZE],
|
||||
feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
|
||||
)
|
||||
audios.append(audio)
|
||||
audio_lengths.append(audio_length)
|
||||
audios = torch.cat(audios, dim=0)
|
||||
audio_lengths = torch.cat(audio_lengths, dim=0)
|
||||
audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
|
||||
|
||||
return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def batch_asr(model, lock, audios, sr, language="auto"):
|
||||
resampled_audios = []
|
||||
for audio in audios:
|
||||
audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
|
||||
assert audio.ndim == 1
|
||||
resampled_audios.append(audio)
|
||||
|
||||
with lock:
|
||||
res = model.generate(
|
||||
input=resampled_audios,
|
||||
batch_size=len(resampled_audios),
|
||||
language=language,
|
||||
use_itn=True,
|
||||
)
|
||||
|
||||
results = []
|
||||
for r, audio in zip(res, audios):
|
||||
text = r["text"]
|
||||
text = re.sub(r"<\|.*?\|>", "", text)
|
||||
duration = len(audio) / sr * 1000
|
||||
huge_gap = False
|
||||
|
||||
if "timestamp" in r and len(r["timestamp"]) > 2:
|
||||
for timestamp_a, timestamp_b in zip(
|
||||
r["timestamp"][:-1], r["timestamp"][1:]
|
||||
):
|
||||
# If there is a gap of more than 4 seconds, we consider it as a huge gap
|
||||
if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
|
||||
huge_gap = True
|
||||
break
|
||||
|
||||
# Doesn't make sense to have a huge gap at the end
|
||||
if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
|
||||
huge_gap = True
|
||||
|
||||
results.append(
|
||||
{
|
||||
"text": text,
|
||||
"duration": duration,
|
||||
"huge_gap": huge_gap,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
246
tools/server/views.py
Normal file
@ -0,0 +1,246 @@
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
||||
import numpy as np
|
||||
import ormsgpack
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request
|
||||
from loguru import logger
|
||||
|
||||
from tools.schema import (
|
||||
ServeASRRequest,
|
||||
ServeASRResponse,
|
||||
ServeChatRequest,
|
||||
ServeTTSRequest,
|
||||
ServeVQGANDecodeRequest,
|
||||
ServeVQGANDecodeResponse,
|
||||
ServeVQGANEncodeRequest,
|
||||
ServeVQGANEncodeResponse,
|
||||
)
|
||||
from tools.server.agent import get_response_generator
|
||||
from tools.server.api_utils import (
|
||||
buffer_to_async_generator,
|
||||
get_content_type,
|
||||
inference_async,
|
||||
)
|
||||
from tools.server.inference import inference_wrapper as inference
|
||||
from tools.server.model_manager import ModelManager
|
||||
from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode
|
||||
|
||||
MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
|
||||
|
||||
|
||||
class HealthView(HttpView):
|
||||
"""
|
||||
Return the health status of the server.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def post(cls):
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
class VQGANEncodeView(HttpView):
|
||||
"""
|
||||
Encode the audio into symbolic tokens.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def post(cls):
|
||||
# Decode the request
|
||||
payload = await request.data()
|
||||
req = ServeVQGANEncodeRequest(**payload)
|
||||
|
||||
# Get the model from the app
|
||||
model_manager: ModelManager = request.app.state.model_manager
|
||||
decoder_model = model_manager.decoder_model
|
||||
|
||||
# Encode the audio
|
||||
start_time = time.time()
|
||||
tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
|
||||
logger.info(
|
||||
f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
# Return the response
|
||||
return ormsgpack.packb(
|
||||
ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
|
||||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
||||
)
|
||||
|
||||
|
||||
class VQGANDecodeView(HttpView):
|
||||
"""
|
||||
Decode the symbolic tokens into audio.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def post(cls):
|
||||
# Decode the request
|
||||
payload = await request.data()
|
||||
req = ServeVQGANDecodeRequest(**payload)
|
||||
|
||||
# Get the model from the app
|
||||
model_manager: ModelManager = request.app.state.model_manager
|
||||
decoder_model = model_manager.decoder_model
|
||||
|
||||
# Decode the audio
|
||||
tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
|
||||
start_time = time.time()
|
||||
audios = vqgan_decode(decoder_model, tokens)
|
||||
logger.info(
|
||||
f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
|
||||
)
|
||||
audios = [audio.astype(np.float16).tobytes() for audio in audios]
|
||||
|
||||
# Return the response
|
||||
return ormsgpack.packb(
|
||||
ServeVQGANDecodeResponse(audios=audios),
|
||||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
||||
)
|
||||
|
||||
|
||||
class ASRView(HttpView):
|
||||
"""
|
||||
Perform automatic speech recognition on the audio.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def post(cls):
|
||||
# Decode the request
|
||||
payload = await request.data()
|
||||
req = ServeASRRequest(**payload)
|
||||
|
||||
# Get the model from the app
|
||||
model_manager: ModelManager = request.app.state.model_manager
|
||||
asr_model = model_manager.asr_model
|
||||
lock = request.app.state.lock
|
||||
|
||||
# Perform ASR
|
||||
start_time = time.time()
|
||||
audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
|
||||
audios = [torch.from_numpy(audio).float() for audio in audios]
|
||||
|
||||
if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
|
||||
raise HTTPException(status_code=400, content="Audio length is too long")
|
||||
|
||||
transcriptions = batch_asr(
|
||||
asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
|
||||
)
|
||||
logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
|
||||
|
||||
# Return the response
|
||||
return ormsgpack.packb(
|
||||
ServeASRResponse(transcriptions=transcriptions),
|
||||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
||||
)
|
||||
|
||||
|
||||
class TTSView(HttpView):
|
||||
"""
|
||||
Perform text-to-speech on the input text.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def post(cls):
|
||||
# Decode the request
|
||||
payload = await request.data()
|
||||
req = ServeTTSRequest(**payload)
|
||||
|
||||
# Get the model from the app
|
||||
app_state = request.app.state
|
||||
model_manager: ModelManager = app_state.model_manager
|
||||
engine = model_manager.tts_inference_engine
|
||||
sample_rate = engine.decoder_model.spec_transform.sample_rate
|
||||
|
||||
# Check if the text is too long
|
||||
if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
|
||||
raise HTTPException(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
content=f"Text is too long, max length is {app_state.max_text_length}",
|
||||
)
|
||||
|
||||
# Check if streaming is enabled
|
||||
if req.streaming and req.format != "wav":
|
||||
raise HTTPException(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
content="Streaming only supports WAV format",
|
||||
)
|
||||
|
||||
# Perform TTS
|
||||
if req.streaming:
|
||||
return StreamResponse(
|
||||
iterable=inference_async(req, engine),
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
||||
},
|
||||
content_type=get_content_type(req.format),
|
||||
)
|
||||
else:
|
||||
fake_audios = next(inference(req, engine))
|
||||
buffer = io.BytesIO()
|
||||
sf.write(
|
||||
buffer,
|
||||
fake_audios,
|
||||
sample_rate,
|
||||
format=req.format,
|
||||
)
|
||||
|
||||
return StreamResponse(
|
||||
iterable=buffer_to_async_generator(buffer.getvalue()),
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
||||
},
|
||||
content_type=get_content_type(req.format),
|
||||
)
|
||||
|
||||
|
||||
class ChatView(HttpView):
|
||||
"""
|
||||
Perform chatbot inference on the input text.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def post(cls):
|
||||
# Decode the request
|
||||
payload = await request.data()
|
||||
req = ServeChatRequest(**payload)
|
||||
|
||||
# Check that the number of samples requested is correct
|
||||
if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
|
||||
raise HTTPException(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
|
||||
)
|
||||
|
||||
# Get the type of content provided
|
||||
content_type = request.headers.get("Content-Type", "application/json")
|
||||
json_mode = "application/json" in content_type
|
||||
|
||||
# Get the models from the app
|
||||
model_manager: ModelManager = request.app.state.model_manager
|
||||
llama_queue = model_manager.llama_queue
|
||||
tokenizer = model_manager.tokenizer
|
||||
config = model_manager.config
|
||||
|
||||
device = request.app.state.device
|
||||
|
||||
# Get the response generators
|
||||
response_generator = get_response_generator(
|
||||
llama_queue, tokenizer, config, req, device, json_mode
|
||||
)
|
||||
|
||||
# Return the response in the correct format
|
||||
if req.streaming is False:
|
||||
result = response_generator()
|
||||
if json_mode:
|
||||
return JSONResponse(result.model_dump())
|
||||
else:
|
||||
return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
||||
|
||||
return StreamResponse(
|
||||
iterable=response_generator(), content_type="text/event-stream"
|
||||
)
|
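When the chat request is streamed in JSON mode the view returns `text/event-stream` chunks built by `streaming_generator`; a hedged reading-side sketch follows (the `/v1/chat` path and the text-part layout are assumptions, not shown in this diff).

```python
# Assumptions: route mounted at /v1/chat, text parts serialized as
# {"type": "text", "text": ...}; both are illustrative, not from this diff.
import json

import httpx

body = {
    "messages": [{"role": "user", "parts": [{"type": "text", "text": "Hi there"}]}],
    "streaming": True,
}

with httpx.stream(
    "POST",
    "http://127.0.0.1:8080/v1/chat",
    json=body,
    headers={"content-type": "application/json"},
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event)
```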
570
tools/webui.py
@ -1,570 +0,0 @@
|
||||
import gc
|
||||
import html
|
||||
import io
|
||||
import os
|
||||
import queue
|
||||
import wave
|
||||
from argparse import ArgumentParser
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pyrootutils
|
||||
import torch
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
|
||||
from fish_speech.i18n import i18n
|
||||
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
||||
from fish_speech.utils import autocast_exclude_mps, set_seed
|
||||
from tools.api import decode_vq_tokens, encode_reference
|
||||
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
||||
from tools.llama.generate import (
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
WrappedGenerateResponse,
|
||||
launch_thread_safe_queue,
|
||||
)
|
||||
from tools.schema import (
|
||||
GLOBAL_NUM_SAMPLES,
|
||||
ASRPackRequest,
|
||||
ServeASRRequest,
|
||||
ServeASRResponse,
|
||||
ServeASRSegment,
|
||||
ServeAudioPart,
|
||||
ServeForwardMessage,
|
||||
ServeMessage,
|
||||
ServeReferenceAudio,
|
||||
ServeRequest,
|
||||
ServeResponse,
|
||||
ServeStreamDelta,
|
||||
ServeStreamResponse,
|
||||
ServeTextPart,
|
||||
ServeTimedASRResponse,
|
||||
ServeTTSRequest,
|
||||
ServeVQGANDecodeRequest,
|
||||
ServeVQGANDecodeResponse,
|
||||
ServeVQGANEncodeRequest,
|
||||
ServeVQGANEncodeResponse,
|
||||
ServeVQPart,
|
||||
)
|
||||
from tools.vqgan.inference import load_model as load_decoder_model
|
||||
|
||||
# Make einx happy
|
||||
os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
||||
|
||||
|
||||
HEADER_MD = f"""# Fish Speech
|
||||
|
||||
{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
|
||||
|
||||
{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
|
||||
|
||||
{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
|
||||
|
||||
{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
|
||||
"""
|
||||
|
||||
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
|
||||
SPACE_IMPORTED = False
|
||||
|
||||
|
||||
def build_html_error_message(error):
|
||||
return f"""
|
||||
<div style="color: red;
|
||||
font-weight: bold;">
|
||||
{html.escape(str(error))}
|
||||
</div>
|
||||
"""
|
||||
|
||||
|
||||
@torch.inference_mode()
def inference(req: ServeTTSRequest):

    idstr: str | None = req.reference_id
    prompt_tokens, prompt_texts = [], []
    if idstr is not None:
        ref_folder = Path("references") / idstr
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
        else:
            logger.info("Use same references")

    else:
        # Parse reference audio aka prompt
        refs = req.references

        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                for ref in refs
            ]
            prompt_texts = [ref.text for ref in refs]
        else:
            logger.info("Use same references")

    if req.seed is not None:
        set_seed(req.seed)
        logger.warning(f"set seed: {req.seed}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=(
            req.text
            if not req.normalize
            else ChnNormedText(raw_text=req.text).normalize()
        ),
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=4096,
        prompt_tokens=prompt_tokens,
        prompt_text=prompt_texts,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    segments = []

    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            yield None, None, build_html_error_message(result.response)
            break

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()
        segments.append(fake_audios)

    if len(segments) == 0:
        return (
            None,
            None,
            build_html_error_message(
                i18n("No audio generated, please check the input text.")
            ),
        )

    # No matter streaming or not, we need to return the final audio
    audio = np.concatenate(segments, axis=0)
    yield None, (decoder_model.spec_transform.sample_rate, audio), None

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()


n_audios = 4

global_audio_list = []
global_error_list = []


def inference_wrapper(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    batch_infer_num,
):
    audios = []
    errors = []

    for _ in range(batch_infer_num):
        result = inference(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            seed,
        )

        _, audio_data, error_message = next(result)

        audios.append(
            gr.Audio(value=audio_data if audio_data else None, visible=True),
        )
        errors.append(
            gr.HTML(value=error_message if error_message else None, visible=True),
        )

    for _ in range(batch_infer_num, n_audios):
        audios.append(
            gr.Audio(value=None, visible=False),
        )
        errors.append(
            gr.HTML(value=None, visible=False),
        )

    return None, *audios, *errors


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


def normalize_text(user_input, use_normalization):
    if use_normalization:
        return ChnNormedText(raw_text=user_input).normalize()
    else:
        return user_input


def update_examples():
    examples_dir = Path("references")
    examples_dir.mkdir(parents=True, exist_ok=True)
    example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
    return gr.Dropdown(choices=example_audios + [""])

def build_app():
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % args.theme,
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )
                refined_text = gr.Textbox(
                    label=i18n("Realtime Transform Text"),
                    placeholder=i18n(
                        "Normalization Result Preview (Currently Only Chinese)"
                    ),
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    normalize = gr.Checkbox(
                        label=i18n("Text Normalization"),
                        value=False,
                    )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=0,
                                    maximum=300,
                                    value=200,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.5,
                                    value=1.2,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["never", "on-demand", "always"],
                                    value="on-demand",
                                )

                            with gr.Row():
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                        )

        text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])

        def inference_wrapper(
            text,
            normalize,
            reference_id,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            seed,
            use_memory_cache,
        ):
            references = []
            if reference_audio:
                # Convert the file path to bytes
                with open(reference_audio, "rb") as audio_file:
                    audio_bytes = audio_file.read()
                references = [
                    ServeReferenceAudio(audio=audio_bytes, text=reference_text)
                ]

            req = ServeTTSRequest(
                text=text,
                normalize=normalize,
                reference_id=reference_id if reference_id else None,
                references=references,
                max_new_tokens=max_new_tokens,
                chunk_length=chunk_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                seed=int(seed) if seed else None,
                use_memory_cache=use_memory_cache,
            )

            for result in inference(req):
                if result[2]:  # Error message
                    return None, result[2]
                elif result[1]:  # Audio data
                    return result[1], None

            return None, i18n("No audio generated")

        # Submit
        generate.click(
            inference_wrapper,
            [
                refined_text,
                normalize,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            concurrency_limit=1,
        )

    return app

def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    # Check if CUDA is available
    if not torch.cuda.is_available():
        logger.info("CUDA is not available, running on CPU.")
        args.device = "cpu"

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=0,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.5,
                temperature=0.7,
                emotion=None,
                format="wav",
            )
        )
    )

    logger.info("Warming up done, launching the web UI...")

    app = build_app()
    app.launch(show_api=True)
173 tools/webui/__init__.py Normal file
@@ -0,0 +1,173 @@
from typing import Callable

import gradio as gr

from fish_speech.i18n import i18n
from tools.inference_engine.utils import normalize_text
from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER


def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % theme,
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )
                refined_text = gr.Textbox(
                    label=i18n("Realtime Transform Text"),
                    placeholder=i18n(
                        "Normalization Result Preview (Currently Only Chinese)"
                    ),
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    normalize = gr.Checkbox(
                        label=i18n("Text Normalization"),
                        value=False,
                    )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=0,
                                    maximum=300,
                                    value=200,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.5,
                                    value=1.2,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["on", "off"],
                                    value="on",
                                )

                            with gr.Row():
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + i18n("Generate"),
                            variant="primary",
                        )

        text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])

        # Submit
        generate.click(
            inference_fct,
            [
                refined_text,
                normalize,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            concurrency_limit=1,
        )

    return app
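Note: the refactored `build_app` no longer reads module-level globals; it receives the inference callable and theme explicitly. A minimal launcher sketch follows, assuming a hypothetical stand-in engine (the real engine class from `tools.inference_engine` is not part of this excerpt; the stand-in only satisfies the `inference(req)` generator contract used by `tools/webui/inference.py`):

```python
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper


class _NullEngine:
    # Stand-in for illustration only: yields nothing, so the UI reports
    # "No audio generated" instead of synthesizing speech.
    def inference(self, req):
        return iter(())


# Bind the (fake) engine into the Gradio callback, then build and launch the UI.
app = build_app(get_inference_wrapper(_NullEngine()), theme="light")

if __name__ == "__main__":
    app.launch()
```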
91 tools/webui/inference.py Normal file
@@ -0,0 +1,91 @@
import html
from functools import partial
from typing import Any, Callable

from fish_speech.i18n import i18n
from tools.schema import ServeReferenceAudio, ServeTTSRequest


def inference_wrapper(
    text,
    normalize,
    reference_id,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    use_memory_cache,
    engine,
):
    """
    Wrapper for the inference function.
    Used in the Gradio interface.
    """

    if reference_audio:
        references = get_reference_audio(reference_audio, reference_text)
    else:
        references = []

    req = ServeTTSRequest(
        text=text,
        normalize=normalize,
        reference_id=reference_id if reference_id else None,
        references=references,
        max_new_tokens=max_new_tokens,
        chunk_length=chunk_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        seed=int(seed) if seed else None,
        use_memory_cache=use_memory_cache,
    )

    for result in engine.inference(req):
        match result.code:
            case "final":
                return result.audio, None
            case "error":
                return None, build_html_error_message(i18n(result.error))
            case _:
                pass

    return None, i18n("No audio generated")


def get_reference_audio(reference_audio: str, reference_text: str) -> list:
    """
    Get the reference audio bytes.
    """

    with open(reference_audio, "rb") as audio_file:
        audio_bytes = audio_file.read()

    return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]


def build_html_error_message(error: Any) -> str:

    error = error if isinstance(error, Exception) else Exception("Unknown error")

    return f"""
    <div style="color: red;
    font-weight: bold;">
        {html.escape(str(error))}
    </div>
    """


def get_inference_wrapper(engine) -> Callable:
    """
    Get the inference function with the immutable arguments.
    """

    return partial(
        inference_wrapper,
        engine=engine,
    )
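Note: the `match result.code` dispatch above implies a small contract for whatever the engine yields: each item exposes `code`, `audio`, and `error`. A sketch of that shape with a hypothetical class name (the attribute names come from the code above; the class itself is not part of this diff):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class InferenceResult:
    # Hypothetical container; only the attribute names matter to inference_wrapper.
    code: str                  # "final", "error", or an intermediate code that is skipped
    audio: Any = None          # audio payload returned to Gradio when code == "final"
    error: str | None = None   # message wrapped by build_html_error_message when code == "error"
```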
14 tools/webui/variables.py Normal file
@@ -0,0 +1,14 @@
from fish_speech.i18n import i18n

HEADER_MD = f"""# Fish Speech

{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}

{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}

{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}

{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""

TEXTBOX_PLACEHOLDER = i18n("Put your text here.")