Make WebUI and API code cleaner (+ 1.5 fixes) (#703)
* rename webui.py to run_webui.py
* remove unused imports
* remove unused code
* move inference code and fix all warnings
* move web app code
* make code easier to read
* remove unused function
* remove msgpack_api.py
* rename API files
* finish updating the doc with the new file names
* fix CPU use in the API
* refactor WebUI inference in a class with submodules
* re-enable streaming in webui inference code
* generalize inference code in webui
* make a unique inference engine class
* clean up code
* implement new structure of the API (not working)
* refactor API
* minor fixes
* reimplement chat endpoint
* [pre-commit.ci] auto fixes from pre-commit.com hooks, applied throughout (see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 954cae1b5d
commit 62eae262c2

.github/ISSUE_TEMPLATE/bug_report.yml (vendored): 2 changed lines
@@ -45,7 +45,7 @@ body:
       description: |
         Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
       placeholder: |
-        1. Run the command `python -m tools.post_api -t "xxxxx"`
+        1. Run the command `python -m tools.api_client -t "xxxxx"`
         2. Observe the console output error: `ModuleNotFoundError: No module named 'pyaudio'` (with screenshots or logs will be better)
     validations:
       required: true
@@ -185,7 +185,7 @@ pip install -e .[stable]
 4. Configure environment variables and access WebUI

     In the terminal inside the docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside docker.
-    Then in the terminal inside the docker container, enter `python tools/webui.py` to start the WebUI service.
+    Then in the terminal inside the docker container, enter `python tools/run_webui.py` to start the WebUI service.

     If you're using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 We provide a HTTP API for inference. You can use the following command to start the server:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \

 After that, you can view and test the API at http://127.0.0.1:8080/.

-Below is an example of sending a request using `tools/post_api.py`.
+Below is an example of sending a request using `tools/api_client.py`.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to be input" \
     --reference_audio "Path to reference audio" \
     --reference_text "Text content of the reference audio" \
@@ -93,7 +93,7 @@ The above command indicates synthesizing the desired audio according to the refe
 The following example demonstrates that you can use **multiple** reference audio paths and reference audio texts at once. Separate them with spaces in the command.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to input" \
     --reference_audio "reference audio path1" "reference audio path2" \
     --reference_text "reference audio text1" "reference audio text2"\
@@ -109,7 +109,7 @@ The currently supported reference audio has a maximum total duration of 90 secon


 !!! info
-    To learn more about available parameters, you can use the command `python -m tools.post_api -h`
+    To learn more about available parameters, you can use the command `python -m tools.api_client -h`

 ## GUI Inference
 [Download client](https://github.com/AnyaCoder/fish-speech-gui/releases)
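For readers who prefer to call the endpoint directly instead of going through `tools/api_client.py`, here is a small illustrative sketch in Python. It assumes the server was started as shown above and is reachable at `http://127.0.0.1:8080`; the field names mirror the ones used by the client in this commit, while the values and the output file name are made up for the example. The old `tools/api.py` shown later in this commit accepts `application/json` as well as `application/msgpack`.

```python
# Hedged sketch: send a plain JSON TTS request and save the returned audio.
# Field names follow tools/api_client.py in this commit; values are examples only.
import requests

payload = {
    "text": "Text to be input",
    "references": [],      # inline references omitted; use reference_id or the client flags instead
    "reference_id": None,
    "format": "wav",
    "streaming": False,
}

resp = requests.post("http://127.0.0.1:8080/v1/tts", json=payload, timeout=120)
resp.raise_for_status()

with open("output.wav", "wb") as f:
    f.write(resp.content)
```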
@@ -44,7 +44,7 @@ pip install -e .[stable]
 To build fish-agent, please use the command below under the main folder:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 The `--compile` args only support Python < 3.12 , which will greatly speed up the token generation.
@@ -184,7 +184,7 @@ pip install -e .[stable]
 4. Configure environment variables and access the WebUI

     In the terminal inside the Docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside Docker.
-    Then, in the terminal inside the Docker container, enter `python tools/webui.py` to start the WebUI service.
+    Then, in the terminal inside the Docker container, enter `python tools/run_webui.py` to start the WebUI service.

     If you are using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 We provide an HTTP API for inference. You can start the server with the following command:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \

 After that, you can view and test the API at `http://127.0.0.1:8080/`.

-Below is an example of sending a request using `tools/post_api.py`.
+Below is an example of sending a request using `tools/api_client.py`.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to input" \
     --reference_audio "Path to the reference audio" \
     --reference_text "Text of the reference audio" \
@@ -91,7 +91,7 @@ python -m tools.post_api \
 The above command synthesizes the desired audio according to the reference audio information and returns it in streaming mode.

 !!! info
-    To learn more about the available parameters, use the command `python -m tools.post_api -h`
+    To learn more about the available parameters, use the command `python -m tools.api_client -h`

 ## WebUI Inference

@@ -47,7 +47,7 @@ pip install -e .[stable]
 To build fish-agent, use the following command in the main folder:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 The `--compile` argument is only supported on Python < 3.12 and greatly speeds up token generation.
@@ -185,7 +185,7 @@ pip install -e .[stable]
 4. Configure environment variables and access the WebUI

     In the terminal inside the Docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the Gradio service inside Docker.
-    Then, enter `python tools/webui.py` in the terminal to start the WebUI service.
+    Then, enter `python tools/run_webui.py` in the terminal to start the WebUI service.

     If you are using WSL or macOS, you can open the WebUI interface at [http://localhost:7860](http://localhost:7860).

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 We provide an HTTP API for inference. You can start the server with the command below:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \

 After that, you can view and test the API at http://127.0.0.1:8080/.

-Below is an example of sending a request using `tools/post_api.py`.
+Below is an example of sending a request using `tools/api_client.py`.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to input" \
     --reference_audio "Path to the reference audio" \
     --reference_text "Text content of the reference audio" \
@@ -93,7 +93,7 @@ python -m tools.post_api \
 The following example shows that you can use multiple reference audio paths and texts at once. Separate them with spaces in the command.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to input" \
     --reference_audio "reference audio path 1" "reference audio path 2" \
     --reference_text "reference audio text 1" "reference audio text 2"\
@@ -107,7 +107,7 @@ python -m tools.post_api \
 Instead of `--reference_audio` and `--reference_text`, you can use `--reference_id` (only one can be used). Create a `references/<your reference_id>` folder in the project root directory and place the corresponding audio and annotation text in it. Reference audio of up to 90 seconds in total is currently supported.

 !!! info
-    The available parameters can be checked with `python -m tools.post_api -h`.
+    The available parameters can be checked with `python -m tools.api_client -h`.

 ## GUI Inference
 [Download client](https://github.com/AnyaCoder/fish-speech-gui/releases)
@@ -47,7 +47,7 @@ pip install -e .[stable]
 To build fish-agent, use the command below in the main folder:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 The `--compile` argument is only supported on Python < 3.12 and greatly improves token generation speed.
@@ -181,7 +181,7 @@ pip install -e .[stable]
 4. Configure the environment variables and access the WebUI

     In the Docker container terminal, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside Docker.
-    Then, in the Docker container terminal, enter `python tools/webui.py` to start the WebUI service.
+    Then, in the Docker container terminal, enter `python tools/run_webui.py` to start the WebUI service.

     If you are using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 We provide an HTTP API for inference. The following command can be used to start the server:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \

 After that, you can view and test the API at http://127.0.0.1:8080/.

-Below is an example of sending a request using `tools/post_api.py`.
+Below is an example of sending a request using `tools/api_client.py`.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to be input" \
     --reference_audio "Path to the reference audio" \
     --reference_text "Text content of the reference audio" \
@@ -91,7 +91,7 @@ python -m tools.post_api \
 The above command synthesizes the desired audio according to the reference audio information and returns it in streaming mode.

 !!! info
-    To learn more about the available parameters, you can use the command `python -m tools.post_api -h`
+    To learn more about the available parameters, you can use the command `python -m tools.api_client -h`

 ## WebUI Inference

@@ -47,7 +47,7 @@ pip install -e .[stable]
 To build fish-agent, use the command below in the main folder:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 The `--compile` argument only supports Python < 3.12 and will greatly speed up token generation.
@@ -188,7 +188,7 @@ pip install -e .[stable]
 4. Configure environment variables and access the WebUI

     In the terminal inside the docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` so that the gradio service inside docker can be reached from outside.
-    Then, in the terminal inside the docker container, enter `python tools/webui.py` to start the WebUI service.
+    Then, in the terminal inside the docker container, enter `python tools/run_webui.py` to start the WebUI service.

     If you are on WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.

@@ -73,7 +73,7 @@ python tools/vqgan/inference.py \
 Run the following command to start the HTTP service:

 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -88,10 +88,10 @@ HF_ENDPOINT=https://hf-mirror.com python -m ... (same as above)

 After that, you can view and test the API at `http://127.0.0.1:8080/`.

-Below is an example of sending a request using `tools/post_api.py`.
+Below is an example of sending a request using `tools/api_client.py`.

 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to input" \
     --reference_audio "Reference audio path" \
     --reference_text "Text content of the reference audio" \
@@ -102,7 +102,7 @@ python -m tools.post_api \

 The following example shows that you can use **multiple** reference audio paths and reference audio texts at once. Just separate them with spaces in the command.
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to input" \
     --reference_audio "reference audio path 1" "reference audio path 2" \
     --reference_text "reference audio text 1" "reference audio text 2"\
@@ -117,7 +117,7 @@ python -m tools.post_api \
 Place any pairs of audio and annotation text inside. The currently supported reference audio has a maximum total duration of 90 seconds.

 !!! info
-    To learn more about the available parameters, you can use the command `python -m tools.post_api -h`
+    To learn more about the available parameters, you can use the command `python -m tools.api_client -h`

 ## GUI Inference
 [Download client](https://github.com/AnyaCoder/fish-speech-gui/releases)
@@ -49,7 +49,7 @@ pip install -e .[stable]
 You need to use the following command to build fish-agent:

 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```

 `--compile` can only be used with Python versions below 3.12; this feature greatly improves generation speed.
@@ -7,4 +7,4 @@ if [ "${CUDA_ENABLED}" != "true" ]; then
     DEVICE="--device cpu"
 fi

-exec python tools/webui.py ${DEVICE}
+exec python tools/run_webui.py ${DEVICE}
@@ -176,7 +176,7 @@ def change_infer(
     p_infer = subprocess.Popen(
         [
             PYTHON,
-            "tools/webui.py",
+            "tools/run_webui.py",
             "--decoder-checkpoint-path",
             infer_decoder_model,
             "--decoder-config-name",
@@ -83,7 +83,7 @@
    },
    "outputs": [],
    "source": [
-    "!python tools/webui.py \\\n",
+    "!python tools/run_webui.py \\\n",
     "    --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n",
     "    --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
     "    # --compile"
@@ -82,7 +82,7 @@ if not "!flags!"=="" set "flags=!flags:~1!"
 echo Debug: flags = !flags!

 if "!mode!"=="api" (
-    %PYTHON_CMD% -m tools.api !flags!
+    %PYTHON_CMD% -m tools.api_server !flags!
 ) else if "!mode!"=="infer" (
     %PYTHON_CMD% -m tools.webui !flags!
 )
tools/api.py (951 lines deleted)
@@ -1,951 +0,0 @@
import io
import json
import os
import queue
import re
import time
import traceback
import wave
from argparse import ArgumentParser
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Any

import librosa
import numpy as np
import ormsgpack
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from baize.datastructures import ContentType
from kui.asgi import (
    Body,
    FactoryClass,
    HTTPException,
    HttpRequest,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    StreamResponse,
    request,
)
from kui.asgi.routing import MultimethodRoutes
from loguru import logger

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
import struct
from threading import Lock

import httpx
from cachetools import LRUCache, cached
from funasr import AutoModel
from silero_vad import get_speech_timestamps, load_silero_vad

from fish_speech.models.text2semantic.llama import BaseModelArgs

# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText

# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
from fish_speech.utils import autocast_exclude_mps, set_seed
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
    launch_thread_safe_queue_agent,
)
from tools.schema import (
    GLOBAL_NUM_SAMPLES,
    ASRPackRequest,
    ServeASRRequest,
    ServeASRResponse,
    ServeASRSegment,
    ServeAudioPart,
    ServeForwardMessage,
    ServeMessage,
    ServeRequest,
    ServeResponse,
    ServeStreamDelta,
    ServeStreamResponse,
    ServeTextPart,
    ServeTimedASRResponse,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
    ServeVQGANEncodeRequest,
    ServeVQGANEncodeResponse,
    ServeVQPart,
)
from tools.vqgan.inference import load_model as load_decoder_model

global_lock = Lock()

# Whether to disable keepalive (which is helpful if the server is in the same cluster)
DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
async_client = httpx.AsyncClient(
    timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
)
backends = torchaudio.list_audio_backends()

if "ffmpeg" in backends:
    backend = "ffmpeg"
else:
    backend = "soundfile"

def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


# Define utils for web server
async def http_execption_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


async def other_exception_handler(exc: "Exception"):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )


def load_audio(reference_audio, sr):
    if len(reference_audio) > 255 or not Path(reference_audio).exists():
        audio_data = reference_audio
        reference_audio = io.BytesIO(audio_data)

    waveform, original_sr = torchaudio.load(reference_audio, backend=backend)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if original_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
        waveform = resampler(waveform)

    audio = waveform.squeeze().numpy()
    return audio


def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
    if enable_reference_audio and reference_audio is not None:
        # Load audios, and prepare basic info here
        reference_audio_content = load_audio(
            reference_audio, decoder_model.spec_transform.sample_rate
        )

        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
            None, None, :
        ]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        if isinstance(decoder_model, FireflyArchitecture):
            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
    else:
        prompt_tokens = None
        logger.info("No reference audio provided")

    return prompt_tokens


def decode_vq_tokens(
    *,
    decoder_model,
    codes,
):
    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
    logger.info(f"VQ features: {codes.shape}")

    if isinstance(decoder_model, FireflyArchitecture):
        # VQGAN Inference
        return decoder_model.decode(
            indices=codes[None],
            feature_lengths=feature_lengths,
        )[0].squeeze()

    raise ValueError(f"Unknown model type: {type(decoder_model)}")


routes = MultimethodRoutes(base_class=HttpView)


def get_content_type(audio_format):
    if audio_format == "wav":
        return "audio/wav"
    elif audio_format == "flac":
        return "audio/flac"
    elif audio_format == "mp3":
        return "audio/mpeg"
    else:
        return "application/octet-stream"

@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_encode(model, audios: list[bytes | torch.Tensor]):
    audios = [
        (
            torch.from_numpy(
                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
            )[None]
            if isinstance(audio, bytes)
            else audio
        )
        for audio in audios
    ]

    # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
    #     raise ValueError("Single audio length is too long (>120s)")

    max_length = max(audio.shape[-1] for audio in audios)
    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")

    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
    max_length = lengths.max().item()
    padded = torch.stack(
        [
            torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
            for audio in audios
        ]
    ).to(model.device)

    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
    features, feature_lengths = features.cpu(), feature_lengths.cpu()

    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]


@cached(
    cache=LRUCache(maxsize=10000),
    key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    return batch_encode(model, audios)


@routes.http.post("/v1/vqgan/encode")
def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):

    start_time = time.time()
    tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")

    return ormsgpack.packb(
        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )


@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def vqgan_decode(model, features):
    lengths = torch.tensor(
        [feature.shape[-1] for feature in features], device=model.device
    )
    max_length = lengths.max().item()
    padded = torch.stack(
        [
            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
            for feature in features
        ]
    ).to(model.device)

    # If bs too large, we do micro batch decode
    audios, audio_lengths = [], []
    for i in range(0, padded.shape[0], 8):
        audio, audio_length = model.decode(
            padded[i : i + 8], feature_lengths=lengths[i : i + 8]
        )
        audios.append(audio)
        audio_lengths.append(audio_length)
    audios = torch.cat(audios, dim=0)
    audio_lengths = torch.cat(audio_lengths, dim=0)
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]


@routes.http.post("/v1/vqgan/decode")
def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
    start_time = time.time()
    audios = vqgan_decode(decoder_model, tokens)
    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
    audios = [audio.astype(np.float16).tobytes() for audio in audios]
    return ormsgpack.packb(
        ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
    )


@torch.no_grad()
def batch_asr(model, audios, sr, language="auto"):
    resampled_audios = []
    for audio in audios:
        audio = torchaudio.functional.resample(audio, sr, 16000)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    with global_lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # If there is a gap of more than 5 seconds, we consider it as a huge gap
                if timestamp_b[0] - timestamp_a[1] > 5000:
                    huge_gap = True
                    break

            # Doesn't make sense to have a huge gap at the end
            if duration - r["timestamp"][-1][1] > 3000:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results


@routes.http.post("/v1/asr")
def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
    start_time = time.time()
    audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
    audios = [torch.from_numpy(audio).float() for audio in audios]

    if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
        raise HTTPException(status_code=400, detail="Audio length is too long")

    transcriptions = batch_asr(
        asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
    )
    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

    return ormsgpack.packb(
        ServeASRResponse(transcriptions=transcriptions),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )

from fish_speech.conversation import Conversation, Message


def execute_request(
    input_queue: queue.Queue,
    tokenizer: FishTokenizer,
    config: BaseModelArgs,
    request: ServeRequest,
    device: str = "cuda:0",
):

    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
    messages = []
    for message in request.messages:
        messages.append(message.to_conversation_message())

    assert len(messages) >= 1, "At least one message is required"
    # assert messages[-1].role == "user", "The last message must be from the user"

    if messages[-1].role == "user":
        messages.append(
            Message(role="assistant", parts=[], add_im_end=False, modality="voice")
        )
    elif messages[-1].role == "raw":
        messages[-1].add_im_start = False
        messages[-1].add_im_end = False
        messages[-1].modality = "voice"
    else:
        assert (
            messages[-1].role == "assistant"
        ), "The last message must be from the assistant"
        messages[-1].add_im_end = False

    conv = Conversation(messages=messages)

    # conv.visualize(tokenizer)
    prompt = conv.encode_for_inference(
        tokenizer=tokenizer, num_codebooks=config.num_codebooks
    ).to(device)

    if request.streaming:
        for i in range(request.num_samples):
            yield ServeStreamResponse(
                sample_id=i,
                delta=ServeStreamDelta(
                    role="assistant",
                ),
            )

    req = {
        "prompt": prompt,
        "max_new_tokens": request.max_new_tokens,
        "im_end_id": im_end_id,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "num_samples": request.num_samples,
        "early_stop_threshold": request.early_stop_threshold,
    }

    start = time.time()
    response_queue = queue.Queue()
    input_queue.put(GenerateRequest(req, response_queue))

    # Decoding
    decode_buffer = [[] for _ in range(request.num_samples)]
    parts = [[] for _ in range(request.num_samples)]

    def send_reset_buffer(sample_id):
        nonlocal decode_buffer
        if len(decode_buffer[sample_id]) == 0:
            return

        decoded = tokenizer.decode(decode_buffer[sample_id])
        part = ServeTextPart(text=decoded)

        if request.streaming:
            yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
        else:
            parts[sample_id].append(part)

        decode_buffer[sample_id] = []

    # Decode process
    finished = [False for _ in range(request.num_samples)]
    stats = {}
    idx = 0
    while True:
        response = response_queue.get()

        if response in ["stop", "error"]:
            break

        for sample_id, tokens in enumerate(response):
            if finished[sample_id]:
                continue

            if tokens[0] == im_end_id:
                finished[sample_id] = True
                if request.streaming:
                    yield from send_reset_buffer(sample_id)
                    yield ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                continue

            is_semantic = (
                tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
            )
            if is_semantic and request.streaming:
                yield from send_reset_buffer(sample_id)
                # Streaming vq
                _tokens = tokens[1:].clone()

                if config.share_codebook_embeddings is False:
                    for i in range(len(_tokens)):
                        _tokens[i] -= config.codebook_size * i

                yield ServeStreamResponse(
                    sample_id=sample_id,
                    delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
                )
                continue

            # Not streaming vq
            if is_semantic:
                yield from send_reset_buffer(sample_id)
                # None streaming vq
                if len(parts[sample_id]) == 0 or not isinstance(
                    parts[sample_id][-1], ServeVQPart
                ):
                    _tokens = tokens[1:].clone()

                    if config.share_codebook_embeddings is False:
                        for i in range(len(_tokens)):
                            _tokens[i] -= config.codebook_size * i

                    parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
                else:
                    for codebook_id, value in enumerate(tokens[1:, :]):
                        val = value.item()
                        if config.share_codebook_embeddings is False:
                            val -= config.codebook_size * codebook_id

                        parts[sample_id][-1].codes[codebook_id].append(val)
                continue

            if not is_semantic:
                # Stream text decode is not supported now
                decode_buffer[sample_id].append(tokens[0, 0])

        if idx == 0:
            stats["time_to_first_token"] = (time.time() - start) * 1000

        idx += 1

    for sample_id in range(request.num_samples):
        yield from send_reset_buffer(sample_id)

    stats["total_time"] = (time.time() - start) * 1000
    stats["total_tokens"] = idx

    if request.streaming:
        for sample_id in range(request.num_samples):
            if finished[sample_id]:
                continue
            yield ServeStreamResponse(
                finish_reason=response, stats=stats, sample_id=sample_id
            )
        return

    yield ServeResponse(
        messages=[
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ],
        finish_reason=response,
        stats=stats,
    )


@routes.http.post("/v1/chat")
def api_invoke_chat(
    req: Annotated[ServeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    # This makes torch compile happy
    assert (
        req.num_samples == GLOBAL_NUM_SAMPLES
    ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"

    content_type = request.headers.get("Content-Type", "application/json")
    json_mode = "application/json" in content_type

    async def wrapped_generator():
        generator = execute_request(llama_queue, tokenizer, config, req, args.device)

        for i in generator:
            if json_mode:
                body = i.model_dump_json().encode("utf-8")
                yield b"data: " + body + b"\n\n"
            else:
                body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
                yield struct.pack("I", len(body)) + body

    # Naive mode
    if req.streaming is False:
        result = next(execute_request(llama_queue, tokenizer, config, req, args.device))

        if json_mode:
            return JSONResponse(result.model_dump())
        else:
            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

    return StreamResponse(
        iterable=wrapped_generator(), content_type="text/event-stream"
    )

@torch.inference_mode()
def inference(req: ServeTTSRequest):

    idstr: str | None = req.reference_id
    if idstr is not None:
        ref_folder = Path("references") / idstr
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
        else:
            logger.info("Use same references")

    else:
        # Parse reference audio aka prompt
        refs = req.references

        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                for ref in refs
            ]
            prompt_texts = [ref.text for ref in refs]
        else:
            logger.info("Use same references")

    if req.seed is not None:
        set_seed(req.seed)
        logger.warning(f"set seed: {req.seed}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=(
            req.text
            if not req.normalize
            else ChnNormedText(raw_text=req.text).normalize()
        ),
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=4096,
        prompt_tokens=prompt_tokens,
        prompt_text=prompt_texts,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if req.streaming:
        yield wav_chunk_header()

    segments = []
    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            raise result.response
            break

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            yield (fake_audios * 32768).astype(np.int16).tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if len(segments) == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    fake_audios = np.concatenate(segments, axis=0)
    yield fake_audios


async def inference_async(req: ServeTTSRequest):
    for chunk in inference(req):
        yield chunk


async def buffer_to_async_generator(buffer):
    yield buffer


@routes.http.post("/v1/tts")
async def api_invoke_model(
    req: Annotated[ServeTTSRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )

    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            decoder_model.spec_transform.sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )


@routes.http.post("/v1/health")
async def api_health():
    """
    Health check
    """
    return JSONResponse({"status": "ok"})


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)

    return parser.parse_args()

# Define Kui app
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
        "version": "1.4.2",
    },
).routes


class MsgPackRequest(HttpRequest):
    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        if self.content_type == "application/msgpack":
            return ormsgpack.unpackb(await self.body)

        elif self.content_type == "application/json":
            return await self.json

        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": "application/msgpack, application/json"},
        )


app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
    factory_class=FactoryClass(http=MsgPackRequest),
    cors_config={},
)


def load_asr_model(*, device="cuda", hub="ms"):
    return AutoModel(
        model="iic/SenseVoiceSmall",
        device=device,
        disable_pbar=True,
        hub=hub,
    )


# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any global variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.


# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.
@app.on_startup
def initialize_app(app: Kui):

    global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts

    prompt_tokens, prompt_texts = [], []

    args = parse_args()  # args same as ones in other processes
    args.precision = torch.half if args.half else torch.bfloat16

    if args.load_asr_model:
        logger.info(f"Loading ASR model...")
        asr_model = load_asr_model(device=args.device)

    logger.info("Loading Llama model...")

    if args.mode == "tts":
        llama_queue = launch_thread_safe_queue(
            checkpoint_path=args.llama_checkpoint_path,
            device=args.device,
            precision=args.precision,
            compile=args.compile,
        )
    else:
        llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
            checkpoint_path=args.llama_checkpoint_path,
            device=args.device,
            precision=args.precision,
            compile=args.compile,
        )

    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    vad_model = load_silero_vad()

    logger.info("VAD model loaded, warming up...")

    if args.mode == "tts":
        # Dry run to ensure models work and avoid first-time latency
        list(
            inference(
                ServeTTSRequest(
                    text="Hello world.",
                    references=[],
                    reference_id=None,
                    max_new_tokens=0,
                    chunk_length=200,
                    top_p=0.7,
                    repetition_penalty=1.5,
                    temperature=0.7,
                    emotion=None,
                    format="wav",
                )
            )
        )

    logger.info(f"Warming up done, starting server at http://{args.listen}")


if __name__ == "__main__":

    import uvicorn

    args = parse_args()
    host, port = args.listen.split(":")
    uvicorn.run(
        "tools.api:app",
        host=host,
        port=int(port),
        workers=args.workers,
        log_level="info",
    )
@ -69,10 +69,6 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
|
"--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
|
|
||||||
)
|
|
||||||
parser.add_argument("--opus_bitrate", type=int, default=-1000)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--latency",
|
"--latency",
|
||||||
type=str,
|
type=str,
|
||||||
@ -112,11 +108,9 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_memory_cache",
|
"--use_memory_cache",
|
||||||
type=str,
|
type=str,
|
||||||
default="never",
|
default="off",
|
||||||
choices=["on-demand", "never"],
|
choices=["on", "off"],
|
||||||
help="Cache encoded references codes in memory.\n"
|
help="Cache encoded references codes in memory.\n",
|
||||||
"If `on-demand`, the server will use cached encodings\n "
|
|
||||||
"instead of encoding reference audio again.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
@ -154,14 +148,14 @@ if __name__ == "__main__":
|
|||||||
data = {
|
data = {
|
||||||
"text": args.text,
|
"text": args.text,
|
||||||
"references": [
|
"references": [
|
||||||
ServeReferenceAudio(audio=ref_audio, text=ref_text)
|
ServeReferenceAudio(
|
||||||
|
audio=ref_audio if ref_audio is not None else b"", text=ref_text
|
||||||
|
)
|
||||||
for ref_text, ref_audio in zip(ref_texts, byte_audios)
|
for ref_text, ref_audio in zip(ref_texts, byte_audios)
|
||||||
],
|
],
|
||||||
"reference_id": idstr,
|
"reference_id": idstr,
|
||||||
"normalize": args.normalize,
|
"normalize": args.normalize,
|
||||||
"format": args.format,
|
"format": args.format,
|
||||||
"mp3_bitrate": args.mp3_bitrate,
|
|
||||||
"opus_bitrate": args.opus_bitrate,
|
|
||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
"chunk_length": args.chunk_length,
|
"chunk_length": args.chunk_length,
|
||||||
"top_p": args.top_p,
|
"top_p": args.top_p,
|
98
tools/api_server.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
import pyrootutils
|
||||||
|
import uvicorn
|
||||||
|
from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from tools.server.api_utils import MsgPackRequest, parse_args
|
||||||
|
from tools.server.exception_handler import ExceptionHandler
|
||||||
|
from tools.server.model_manager import ModelManager
|
||||||
|
from tools.server.views import (
|
||||||
|
ASRView,
|
||||||
|
ChatView,
|
||||||
|
HealthView,
|
||||||
|
TTSView,
|
||||||
|
VQGANDecodeView,
|
||||||
|
VQGANEncodeView,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class API(ExceptionHandler):
|
||||||
|
def __init__(self):
|
||||||
|
self.args = parse_args()
|
||||||
|
self.routes = [
|
||||||
|
("/v1/health", HealthView),
|
||||||
|
("/v1/vqgan/encode", VQGANEncodeView),
|
||||||
|
("/v1/vqgan/decode", VQGANDecodeView),
|
||||||
|
("/v1/asr", ASRView),
|
||||||
|
("/v1/tts", TTSView),
|
||||||
|
("/v1/chat", ChatView),
|
||||||
|
]
|
||||||
|
self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
|
||||||
|
|
||||||
|
self.openapi = OpenAPI(
|
||||||
|
{
|
||||||
|
"title": "Fish Speech API",
|
||||||
|
"version": "1.5.0",
|
||||||
|
},
|
||||||
|
).routes
|
||||||
|
|
||||||
|
# Initialize the app
|
||||||
|
self.app = Kui(
|
||||||
|
routes=self.routes + self.openapi[1:], # Remove the default route
|
||||||
|
exception_handlers={
|
||||||
|
HTTPException: self.http_exception_handler,
|
||||||
|
Exception: self.other_exception_handler,
|
||||||
|
},
|
||||||
|
factory_class=FactoryClass(http=MsgPackRequest),
|
||||||
|
cors_config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the state variables
|
||||||
|
self.app.state.lock = Lock()
|
||||||
|
self.app.state.device = self.args.device
|
||||||
|
self.app.state.max_text_length = self.args.max_text_length
|
||||||
|
|
||||||
|
# Associate the app with the model manager
|
||||||
|
self.app.on_startup(self.initialize_app)
|
||||||
|
|
||||||
|
async def initialize_app(self, app: Kui):
|
||||||
|
# Make the ModelManager available to the views
|
||||||
|
app.state.model_manager = ModelManager(
|
||||||
|
mode=self.args.mode,
|
||||||
|
device=self.args.device,
|
||||||
|
half=self.args.half,
|
||||||
|
compile=self.args.compile,
|
||||||
|
asr_enabled=self.args.load_asr_model,
|
||||||
|
llama_checkpoint_path=self.args.llama_checkpoint_path,
|
||||||
|
decoder_checkpoint_path=self.args.decoder_checkpoint_path,
|
||||||
|
decoder_config_name=self.args.decoder_config_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Startup done, listening server at http://{self.args.listen}")
|
||||||
|
|
||||||
|
|
||||||
|
# Each worker process created by Uvicorn has its own memory space,
|
||||||
|
# meaning that models and variables are not shared between processes.
|
||||||
|
# Therefore, any variables (like `llama_queue` or `decoder_model`)
|
||||||
|
# will not be shared across workers.
|
||||||
|
|
||||||
|
# Multi-threading for deep learning can cause issues, such as inconsistent
|
||||||
|
# outputs if multiple threads access the same buffers simultaneously.
|
||||||
|
# Instead, it's better to use multiprocessing or independent models per thread.
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
api = API()
|
||||||
|
host, port = api.args.listen.split(":")
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
api.app,
|
||||||
|
host=host,
|
||||||
|
port=int(port),
|
||||||
|
workers=api.args.workers,
|
||||||
|
log_level="info",
|
||||||
|
)
|
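For quick manual testing of the routes registered above, a minimal client sketch (not part of this commit) can POST a msgpack body to the new server; it assumes the server is running on the default `127.0.0.1:8080` and writes the result to a placeholder file name:

```python
# Hypothetical smoke test for the /v1/tts route of tools/api_server.py.
# Host, port, and output filename are assumptions, not part of the commit.
import httpx
import ormsgpack

from tools.schema import ServeTTSRequest

request = ServeTTSRequest(text="Hello world.", format="wav")

with httpx.Client() as client:
    response = client.post(
        "http://127.0.0.1:8080/v1/tts",
        content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
        headers={"content-type": "application/msgpack"},
        timeout=None,
    )
    response.raise_for_status()
    with open("output.wav", "wb") as f:
        f.write(response.content)
```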
@ -14,8 +14,8 @@ import ormsgpack
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from .schema import (
|
from .schema import (
|
||||||
|
ServeChatRequest,
|
||||||
ServeMessage,
|
ServeMessage,
|
||||||
ServeRequest,
|
|
||||||
ServeTextPart,
|
ServeTextPart,
|
||||||
ServeVQGANDecodeRequest,
|
ServeVQGANDecodeRequest,
|
||||||
ServeVQGANEncodeRequest,
|
ServeVQGANEncodeRequest,
|
||||||
@ -163,7 +163,7 @@ class FishE2EAgent:
|
|||||||
else:
|
else:
|
||||||
user_codes = None
|
user_codes = None
|
||||||
|
|
||||||
request = ServeRequest(
|
request = ServeChatRequest(
|
||||||
messages=prev_messages
|
messages=prev_messages
|
||||||
+ (
|
+ (
|
||||||
[
|
[
|
||||||
|
193
tools/inference_engine/__init__.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import gc
|
||||||
|
import queue
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
||||||
|
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
||||||
|
from fish_speech.utils import autocast_exclude_mps, set_seed
|
||||||
|
from tools.inference_engine.reference_loader import ReferenceLoader
|
||||||
|
from tools.inference_engine.utils import InferenceResult, wav_chunk_header
|
||||||
|
from tools.inference_engine.vq_manager import VQManager
|
||||||
|
from tools.llama.generate import (
|
||||||
|
GenerateRequest,
|
||||||
|
GenerateResponse,
|
||||||
|
WrappedGenerateResponse,
|
||||||
|
)
|
||||||
|
from tools.schema import ServeTTSRequest
|
||||||
|
|
||||||
|
|
||||||
|
class TTSInferenceEngine(ReferenceLoader, VQManager):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llama_queue: queue.Queue,
|
||||||
|
decoder_model: FireflyArchitecture,
|
||||||
|
precision: torch.dtype,
|
||||||
|
compile: bool,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.llama_queue = llama_queue
|
||||||
|
self.decoder_model = decoder_model
|
||||||
|
self.precision = precision
|
||||||
|
self.compile = compile
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
|
||||||
|
"""
|
||||||
|
Main inference function:
|
||||||
|
- Loads the reference audio and text.
|
||||||
|
- Calls the LLAMA model for inference.
|
||||||
|
- Decodes the VQ tokens to audio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ref_id: str | None = req.reference_id
|
||||||
|
prompt_tokens, prompt_texts = [], []
|
||||||
|
# Load the reference audio and text based on id or hash
|
||||||
|
if ref_id is not None:
|
||||||
|
prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
|
||||||
|
|
||||||
|
elif req.references:
|
||||||
|
prompt_tokens, prompt_texts = self.load_by_hash(
|
||||||
|
req.references, req.use_memory_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the random seed if provided
|
||||||
|
if req.seed is not None:
|
||||||
|
set_seed(req.seed)
|
||||||
|
logger.warning(f"set seed: {req.seed}")
|
||||||
|
|
||||||
|
# Get the symbolic tokens from the LLAMA model
|
||||||
|
response_queue = self.send_llama_request(req, prompt_tokens, prompt_texts)
|
||||||
|
|
||||||
|
# Get the sample rate from the decoder model
|
||||||
|
sample_rate = self.decoder_model.spec_transform.sample_rate
|
||||||
|
|
||||||
|
# If streaming, send the header
|
||||||
|
if req.streaming:
|
||||||
|
yield InferenceResult(
|
||||||
|
code="header",
|
||||||
|
audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
segments = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Get the response from the LLAMA model
|
||||||
|
wrapped_result: WrappedGenerateResponse = response_queue.get()
|
||||||
|
if wrapped_result.status == "error":
|
||||||
|
yield InferenceResult(
|
||||||
|
code="error",
|
||||||
|
audio=None,
|
||||||
|
error=(
|
||||||
|
wrapped_result.response
|
||||||
|
if isinstance(wrapped_result.response, Exception)
|
||||||
|
else Exception("Unknown error")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check the response type
|
||||||
|
if not isinstance(wrapped_result.response, GenerateResponse):
|
||||||
|
raise TypeError(
|
||||||
|
"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
result: GenerateResponse = wrapped_result.response
|
||||||
|
if result.action != "next":
|
||||||
|
segment = self.get_audio_segment(result)
|
||||||
|
|
||||||
|
if req.streaming: # Used only by the API server
|
||||||
|
yield InferenceResult(
|
||||||
|
code="segment",
|
||||||
|
audio=(sample_rate, segment),
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
segments.append(segment)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Clean up the memory
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Edge case: no audio generated
|
||||||
|
if len(segments) == 0:
|
||||||
|
yield InferenceResult(
|
||||||
|
code="error",
|
||||||
|
audio=None,
|
||||||
|
error=RuntimeError("No audio generated, please check the input text."),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Streaming or not, return the final audio
|
||||||
|
audio = np.concatenate(segments, axis=0)
|
||||||
|
yield InferenceResult(
|
||||||
|
code="final",
|
||||||
|
audio=(sample_rate, audio),
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def send_llama_request(
|
||||||
|
self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
|
||||||
|
) -> queue.Queue:
|
||||||
|
"""
|
||||||
|
Send a request to the LLAMA model to generate the symbolic tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Prepare the request
|
||||||
|
request = dict(
|
||||||
|
device=self.decoder_model.device,
|
||||||
|
max_new_tokens=req.max_new_tokens,
|
||||||
|
text=(
|
||||||
|
req.text
|
||||||
|
if not req.normalize
|
||||||
|
else ChnNormedText(raw_text=req.text).normalize()
|
||||||
|
),
|
||||||
|
top_p=req.top_p,
|
||||||
|
repetition_penalty=req.repetition_penalty,
|
||||||
|
temperature=req.temperature,
|
||||||
|
compile=self.compile,
|
||||||
|
iterative_prompt=req.chunk_length > 0,
|
||||||
|
chunk_length=req.chunk_length,
|
||||||
|
max_length=4096,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
prompt_text=prompt_texts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a queue to get the response
|
||||||
|
response_queue = queue.Queue()
|
||||||
|
|
||||||
|
# Send the request to the LLAMA model
|
||||||
|
self.llama_queue.put(
|
||||||
|
GenerateRequest(
|
||||||
|
request=request,
|
||||||
|
response_queue=response_queue,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_queue
|
||||||
|
|
||||||
|
def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Decode the VQ tokens to audio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Don't use autocast on MPS devices
|
||||||
|
with autocast_exclude_mps(
|
||||||
|
device_type=self.decoder_model.device.type, dtype=self.precision
|
||||||
|
):
|
||||||
|
# Decode the symbolic tokens to audio
|
||||||
|
segment = self.decode_vq_tokens(codes=result.codes)
|
||||||
|
|
||||||
|
# Convert the audio to numpy
|
||||||
|
return segment.float().cpu().numpy()
|
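As a usage sketch, a caller consumes the `InferenceResult` generator above like this; `engine` is assumed to be an already constructed `TTSInferenceEngine`, and the helper name and output path are illustrative only:

```python
# Illustrative consumer of TTSInferenceEngine.inference() for the
# non-streaming case; only the "final" result carries the full audio.
import soundfile as sf

from tools.inference_engine import TTSInferenceEngine
from tools.schema import ServeTTSRequest


def synthesize(engine: TTSInferenceEngine, text: str, out_path: str = "out.wav") -> None:
    request = ServeTTSRequest(text=text, format="wav")
    for result in engine.inference(request):
        if result.code == "error" and result.error is not None:
            raise result.error
        if result.code == "final" and result.audio is not None:
            sample_rate, audio = result.audio  # audio is a float numpy array
            sf.write(out_path, audio, sample_rate)
```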
128
tools/inference_engine/reference_loader.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import io
|
||||||
|
from hashlib import sha256
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Literal, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
||||||
|
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
||||||
|
from tools.schema import ServeReferenceAudio
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceLoader:
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""
|
||||||
|
Component of the TTSInferenceEngine class.
|
||||||
|
Loads and manages the cache for the reference audio and text.
|
||||||
|
"""
|
||||||
|
self.ref_by_id: dict = {}
|
||||||
|
self.ref_by_hash: dict = {}
|
||||||
|
|
||||||
|
# Make Pylance happy (attribute/method defined elsewhere in TTSInferenceEngine)
|
||||||
|
self.decoder_model: FireflyArchitecture
|
||||||
|
self.encode_reference: Callable
|
||||||
|
|
||||||
|
# Define the torchaudio backend
|
||||||
|
backends = torchaudio.list_audio_backends()
|
||||||
|
if "ffmpeg" in backends:
|
||||||
|
self.backend = "ffmpeg"
|
||||||
|
else:
|
||||||
|
self.backend = "soundfile"
|
||||||
|
|
||||||
|
def load_by_id(
|
||||||
|
self,
|
||||||
|
id: str,
|
||||||
|
use_cache: Literal["on", "off"],
|
||||||
|
) -> Tuple:
|
||||||
|
|
||||||
|
# Load the reference audio and text by id
|
||||||
|
ref_folder = Path("references") / id
|
||||||
|
ref_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
ref_audios = list_files(
|
||||||
|
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_cache == "off" or id not in self.ref_by_id:
|
||||||
|
# If the references are not already loaded, encode them
|
||||||
|
prompt_tokens = [
|
||||||
|
self.encode_reference(
|
||||||
|
decoder_model=self.decoder_model,
|
||||||
|
reference_audio=audio_to_bytes(str(ref_audio)),
|
||||||
|
enable_reference_audio=True,
|
||||||
|
)
|
||||||
|
for ref_audio in ref_audios
|
||||||
|
]
|
||||||
|
prompt_texts = [
|
||||||
|
read_ref_text(str(ref_audio.with_suffix(".lab")))
|
||||||
|
for ref_audio in ref_audios
|
||||||
|
]
|
||||||
|
self.ref_by_id[id] = (prompt_tokens, prompt_texts)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Reuse already encoded references
|
||||||
|
logger.info("Use same references")
|
||||||
|
prompt_tokens, prompt_texts = self.ref_by_id[id]
|
||||||
|
|
||||||
|
return prompt_tokens, prompt_texts
|
||||||
|
|
||||||
|
def load_by_hash(
|
||||||
|
self,
|
||||||
|
references: list[ServeReferenceAudio],
|
||||||
|
use_cache: Literal["on", "off"],
|
||||||
|
) -> Tuple:
|
||||||
|
|
||||||
|
# Load the reference audio and text by hash
|
||||||
|
audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
|
||||||
|
|
||||||
|
cache_used = False
|
||||||
|
prompt_tokens, prompt_texts = [], []
|
||||||
|
for i, ref in enumerate(references):
|
||||||
|
if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
|
||||||
|
# If the references are not already loaded, encode them
|
||||||
|
prompt_tokens.append(
|
||||||
|
self.encode_reference(
|
||||||
|
decoder_model=self.decoder_model,
|
||||||
|
reference_audio=ref.audio,
|
||||||
|
enable_reference_audio=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
prompt_texts.append(ref.text)
|
||||||
|
self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Reuse already encoded references
|
||||||
|
prompt_text, prompt_token = self.ref_by_hash[audio_hashes[i]]
|
||||||
|
prompt_texts.append(prompt_text)
|
||||||
|
prompt_tokens.append(prompt_token)
|
||||||
|
cache_used = True
|
||||||
|
|
||||||
|
if cache_used:
|
||||||
|
logger.info("Use same references")
|
||||||
|
|
||||||
|
return prompt_tokens, prompt_texts
|
||||||
|
|
||||||
|
def load_audio(self, reference_audio, sr):
|
||||||
|
"""
|
||||||
|
Load the audio data from a file or bytes.
|
||||||
|
"""
|
||||||
|
if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
||||||
|
audio_data = reference_audio
|
||||||
|
reference_audio = io.BytesIO(audio_data)
|
||||||
|
|
||||||
|
waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
|
||||||
|
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||||
|
|
||||||
|
if original_sr != sr:
|
||||||
|
resampler = torchaudio.transforms.Resample(
|
||||||
|
orig_freq=original_sr, new_freq=sr
|
||||||
|
)
|
||||||
|
waveform = resampler(waveform)
|
||||||
|
|
||||||
|
audio = waveform.squeeze().numpy()
|
||||||
|
return audio
|
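For context, `load_by_id` expects each reference audio file to sit next to a `.lab` transcript under `references/<id>/`. A sketch of preparing such a folder; the id matches the example used in `tools/schema.py`, and the file names are placeholders:

```python
# Hypothetical helper: lay out a reference folder the way load_by_id expects.
from pathlib import Path
from shutil import copyfile

ref_id = "7f92f8afb8ec43bf81429cc1c9199cb1"  # example reference_id
ref_dir = Path("references") / ref_id
ref_dir.mkdir(parents=True, exist_ok=True)

copyfile("speaker.wav", ref_dir / "sample.wav")  # any supported audio extension
(ref_dir / "sample.lab").write_text("Transcript of sample.wav", encoding="utf-8")
```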
42
tools/inference_engine/utils.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import io
|
||||||
|
import wave
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceResult:
|
||||||
|
code: Literal["header", "segment", "error", "final"]
|
||||||
|
audio: Optional[Tuple[int, np.ndarray]]
|
||||||
|
error: Optional[Exception]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_text(user_input: str, use_normalization: bool) -> str:
|
||||||
|
"""Normalize user input text if needed."""
|
||||||
|
if use_normalization:
|
||||||
|
return ChnNormedText(raw_text=user_input).normalize()
|
||||||
|
else:
|
||||||
|
return user_input
|
||||||
|
|
||||||
|
|
||||||
|
def wav_chunk_header(
|
||||||
|
sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
|
||||||
|
) -> np.ndarray:
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
|
||||||
|
with wave.open(buffer, "wb") as wav_file:
|
||||||
|
wav_file.setnchannels(channels)
|
||||||
|
wav_file.setsampwidth(bit_depth // 8)
|
||||||
|
wav_file.setframerate(sample_rate)
|
||||||
|
|
||||||
|
wav_header_bytes = buffer.getvalue()
|
||||||
|
buffer.close()
|
||||||
|
|
||||||
|
# Convert to numpy array
|
||||||
|
wav_header = np.frombuffer(wav_header_bytes, dtype=np.uint8)
|
||||||
|
|
||||||
|
return wav_header
|
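The header produced above is what the streaming path emits once, before raw int16 PCM segments. A rough sketch of how the pieces concatenate into a playable stream; the file name and the silence payload are placeholders:

```python
# Rough sketch: WAV header followed by int16 PCM frames, mirroring the
# streaming path in the server code.
import numpy as np

from tools.inference_engine.utils import wav_chunk_header

header = wav_chunk_header(sample_rate=44100)   # np.uint8 array
silence = np.zeros(44100, dtype=np.int16)      # one second of int16 silence

with open("stream.wav", "wb") as f:
    f.write(header.tobytes())
    f.write(silence.tobytes())
```

Since no frames are written when the header is generated, its length fields are zero-valued placeholders; most decoders tolerate this, which is what allows the subsequent PCM segments to be streamed chunk by chunk.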
57
tools/inference_engine/vq_manager.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
||||||
|
|
||||||
|
|
||||||
|
class VQManager:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Make Pylance happy (attribute/method defined elsewhere in TTSInferenceEngine)
|
||||||
|
self.decoder_model: FireflyArchitecture
|
||||||
|
self.load_audio: Callable
|
||||||
|
|
||||||
|
def decode_vq_tokens(self, codes):
|
||||||
|
feature_lengths = torch.tensor(
|
||||||
|
[codes.shape[1]], device=self.decoder_model.device
|
||||||
|
)
|
||||||
|
logger.info(f"VQ features: {codes.shape}")
|
||||||
|
|
||||||
|
if isinstance(self.decoder_model, FireflyArchitecture):
|
||||||
|
return self.decoder_model.decode(
|
||||||
|
indices=codes[None],
|
||||||
|
feature_lengths=feature_lengths,
|
||||||
|
)[0].squeeze()
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
|
||||||
|
|
||||||
|
def encode_reference(self, reference_audio, enable_reference_audio):
|
||||||
|
if enable_reference_audio and reference_audio is not None:
|
||||||
|
# Load audios, and prepare basic info here
|
||||||
|
reference_audio_content = self.load_audio(
|
||||||
|
reference_audio, self.decoder_model.spec_transform.sample_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
audios = torch.from_numpy(reference_audio_content).to(
|
||||||
|
self.decoder_model.device
|
||||||
|
)[None, None, :]
|
||||||
|
audio_lengths = torch.tensor(
|
||||||
|
[audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
# VQ Encoder
|
||||||
|
if isinstance(self.decoder_model, FireflyArchitecture):
|
||||||
|
prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
|
||||||
|
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
|
||||||
|
else:
|
||||||
|
prompt_tokens = None
|
||||||
|
logger.info("No reference audio provided")
|
||||||
|
|
||||||
|
return prompt_tokens
|
@ -1,95 +0,0 @@
|
|||||||
import os
|
|
||||||
from argparse import ArgumentParser
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import ormsgpack
|
|
||||||
|
|
||||||
from tools.schema import ServeReferenceAudio, ServeTTSRequest
|
|
||||||
|
|
||||||
api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
|
|
||||||
|
|
||||||
|
|
||||||
def audio_request():
|
|
||||||
# priority: ref_id > references
|
|
||||||
request = ServeTTSRequest(
|
|
||||||
text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
|
||||||
# reference_id="114514",
|
|
||||||
references=[
|
|
||||||
ServeReferenceAudio(
|
|
||||||
audio=open("lengyue.wav", "rb").read(),
|
|
||||||
text=open("lengyue.lab", "r", encoding="utf-8").read(),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
streaming=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
|
|
||||||
|
|
||||||
with (
|
|
||||||
httpx.Client() as client,
|
|
||||||
open("hello.wav", "wb") as f,
|
|
||||||
):
|
|
||||||
with client.stream(
|
|
||||||
"POST",
|
|
||||||
"http://127.0.0.1:8080/v1/tts",
|
|
||||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
|
||||||
headers={
|
|
||||||
"authorization": f"Bearer {api_key}",
|
|
||||||
"content-type": "application/msgpack",
|
|
||||||
},
|
|
||||||
timeout=None,
|
|
||||||
) as response:
|
|
||||||
for chunk in response.iter_bytes():
|
|
||||||
f.write(chunk)
|
|
||||||
|
|
||||||
|
|
||||||
def asr_request(audio_path: Path):
|
|
||||||
|
|
||||||
# Read the audio file
|
|
||||||
with open(
|
|
||||||
str(audio_path),
|
|
||||||
"rb",
|
|
||||||
) as audio_file:
|
|
||||||
audio_data = audio_file.read()
|
|
||||||
|
|
||||||
# Prepare the request data
|
|
||||||
request_data = {
|
|
||||||
"audio": audio_data,
|
|
||||||
"language": "en", # Optional: specify the language
|
|
||||||
"ignore_timestamps": False, # Optional: set to True to ignore precise timestamps
|
|
||||||
}
|
|
||||||
|
|
||||||
# Send the request
|
|
||||||
with httpx.Client() as client:
|
|
||||||
response = client.post(
|
|
||||||
"https://api.fish.audio/v1/asr",
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
"Content-Type": "application/msgpack",
|
|
||||||
},
|
|
||||||
content=ormsgpack.packb(request_data),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse the response
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
print(f"Transcribed text: {result['text']}")
|
|
||||||
print(f"Audio duration: {result['duration']} seconds")
|
|
||||||
|
|
||||||
for segment in result["segments"]:
|
|
||||||
print(f"Segment: {segment['text']}")
|
|
||||||
print(f"Start time: {segment['start']}, End time: {segment['end']}")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = ArgumentParser()
|
|
||||||
parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
asr_request(args.audio_path)
|
|
101
tools/run_webui.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
import os
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pyrootutils
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from tools.inference_engine import TTSInferenceEngine
|
||||||
|
from tools.llama.generate import launch_thread_safe_queue
|
||||||
|
from tools.schema import ServeTTSRequest
|
||||||
|
from tools.vqgan.inference import load_model as load_decoder_model
|
||||||
|
from tools.webui import build_app
|
||||||
|
from tools.webui.inference import get_inference_wrapper
|
||||||
|
|
||||||
|
# Make einx happy
|
||||||
|
os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--llama-checkpoint-path",
|
||||||
|
type=Path,
|
||||||
|
default="checkpoints/fish-speech-1.5",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-checkpoint-path",
|
||||||
|
type=Path,
|
||||||
|
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
||||||
|
)
|
||||||
|
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
|
parser.add_argument("--half", action="store_true")
|
||||||
|
parser.add_argument("--compile", action="store_true")
|
||||||
|
parser.add_argument("--max-gradio-length", type=int, default=0)
|
||||||
|
parser.add_argument("--theme", type=str, default="light")
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
args.precision = torch.half if args.half else torch.bfloat16
|
||||||
|
|
||||||
|
# Check if CUDA is available
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
logger.info("CUDA is not available, running on CPU.")
|
||||||
|
args.device = "cpu"
|
||||||
|
|
||||||
|
logger.info("Loading Llama model...")
|
||||||
|
llama_queue = launch_thread_safe_queue(
|
||||||
|
checkpoint_path=args.llama_checkpoint_path,
|
||||||
|
device=args.device,
|
||||||
|
precision=args.precision,
|
||||||
|
compile=args.compile,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Loading VQ-GAN model...")
|
||||||
|
decoder_model = load_decoder_model(
|
||||||
|
config_name=args.decoder_config_name,
|
||||||
|
checkpoint_path=args.decoder_checkpoint_path,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Decoder model loaded, warming up...")
|
||||||
|
|
||||||
|
# Create the inference engine
|
||||||
|
inference_engine = TTSInferenceEngine(
|
||||||
|
llama_queue=llama_queue,
|
||||||
|
decoder_model=decoder_model,
|
||||||
|
compile=args.compile,
|
||||||
|
precision=args.precision,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
||||||
|
list(
|
||||||
|
inference_engine.inference(
|
||||||
|
ServeTTSRequest(
|
||||||
|
text="Hello world.",
|
||||||
|
references=[],
|
||||||
|
reference_id=None,
|
||||||
|
max_new_tokens=0,
|
||||||
|
chunk_length=200,
|
||||||
|
top_p=0.7,
|
||||||
|
repetition_penalty=1.5,
|
||||||
|
temperature=0.7,
|
||||||
|
format="wav",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Warming up done, launching the web UI...")
|
||||||
|
|
||||||
|
# Get the inference function with the immutable arguments
|
||||||
|
inference_fct = get_inference_wrapper(inference_engine)
|
||||||
|
|
||||||
|
app = build_app(inference_fct, args.theme)
|
||||||
|
app.launch(show_api=True)
|
@ -1,16 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import queue
|
import queue
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Annotated, Literal, Optional
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
|
from pydantic import BaseModel, Field, conint, conlist
|
||||||
from pydantic.functional_validators import SkipValidation
|
from pydantic.functional_validators import SkipValidation
|
||||||
|
|
||||||
from fish_speech.conversation import Message, TextPart, VQPart
|
from fish_speech.conversation import Message, TextPart, VQPart
|
||||||
|
|
||||||
GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
|
|
||||||
|
|
||||||
|
|
||||||
class ServeVQPart(BaseModel):
|
class ServeVQPart(BaseModel):
|
||||||
type: Literal["vq"] = "vq"
|
type: Literal["vq"] = "vq"
|
||||||
@ -64,7 +62,7 @@ class ServeASRResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ServeMessage(BaseModel):
|
class ServeMessage(BaseModel):
|
||||||
role: Literal["system", "assistant", "user", "raw"]
|
role: Literal["system", "assistant", "user"]
|
||||||
parts: list[ServeVQPart | ServeTextPart]
|
parts: list[ServeVQPart | ServeTextPart]
|
||||||
|
|
||||||
def to_conversation_message(self):
|
def to_conversation_message(self):
|
||||||
@ -85,7 +83,7 @@ class ServeMessage(BaseModel):
|
|||||||
return new_message
|
return new_message
|
||||||
|
|
||||||
|
|
||||||
class ServeRequest(BaseModel):
|
class ServeChatRequest(BaseModel):
|
||||||
messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
|
messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
|
||||||
max_new_tokens: int = 1024
|
max_new_tokens: int = 1024
|
||||||
top_p: float = 0.7
|
top_p: float = 0.7
|
||||||
@ -114,11 +112,6 @@ class ServeVQGANDecodeResponse(BaseModel):
|
|||||||
audios: list[bytes]
|
audios: list[bytes]
|
||||||
|
|
||||||
|
|
||||||
class ServeReferenceAudio(BaseModel):
|
|
||||||
audio: bytes
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class ServeForwardMessage(BaseModel):
|
class ServeForwardMessage(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
@ -150,24 +143,11 @@ class ServeReferenceAudio(BaseModel):
|
|||||||
return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
|
return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
|
||||||
|
|
||||||
|
|
||||||
class ServeChatRequestV1(BaseModel):
|
|
||||||
model: str = "llama3-8b"
|
|
||||||
messages: list[ServeForwardMessage] = []
|
|
||||||
audio: bytes | None = None
|
|
||||||
temperature: float = 1.0
|
|
||||||
top_p: float = 1.0
|
|
||||||
max_tokens: int = 256
|
|
||||||
voice: str = "jessica"
|
|
||||||
tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
|
|
||||||
tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
|
|
||||||
|
|
||||||
|
|
||||||
class ServeTTSRequest(BaseModel):
|
class ServeTTSRequest(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
||||||
# Audio format
|
# Audio format
|
||||||
format: Literal["wav", "pcm", "mp3"] = "wav"
|
format: Literal["wav", "pcm", "mp3"] = "wav"
|
||||||
mp3_bitrate: Literal[64, 128, 192] = 128
|
|
||||||
# References audios for in-context learning
|
# References audios for in-context learning
|
||||||
references: list[ServeReferenceAudio] = []
|
references: list[ServeReferenceAudio] = []
|
||||||
# Reference id
|
# Reference id
|
||||||
@ -175,16 +155,16 @@ class ServeTTSRequest(BaseModel):
|
|||||||
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
|
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
|
||||||
reference_id: str | None = None
|
reference_id: str | None = None
|
||||||
seed: int | None = None
|
seed: int | None = None
|
||||||
use_memory_cache: Literal["on-demand", "never"] = "never"
|
use_memory_cache: Literal["on", "off"] = "off"
|
||||||
# Normalize text for en & zh, this increases stability for numbers
|
# Normalize text for en & zh, this increases stability for numbers
|
||||||
normalize: bool = True
|
normalize: bool = True
|
||||||
mp3_bitrate: Optional[int] = 64
|
|
||||||
opus_bitrate: Optional[int] = -1000
|
|
||||||
# Balance mode will reduce latency to 300ms, but may decrease stability
|
|
||||||
latency: Literal["normal", "balanced"] = "normal"
|
|
||||||
# not usually used below
|
# not usually used below
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
max_new_tokens: int = 1024
|
max_new_tokens: int = 1024
|
||||||
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
||||||
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
||||||
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
# Allow arbitrary types for pytorch related types
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
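To illustrate the schema changes (renamed cache values, removed bitrate and latency fields), here is a sketch of a request built against the updated `ServeTTSRequest`; the reference audio and transcript file names are placeholders:

```python
# Example request against the updated schema; reference files are placeholders.
from tools.schema import ServeReferenceAudio, ServeTTSRequest

request = ServeTTSRequest(
    text="Hello world.",
    references=[
        ServeReferenceAudio(
            audio=open("speaker.wav", "rb").read(),
            text=open("speaker.lab", "r", encoding="utf-8").read(),
        )
    ],
    use_memory_cache="on",  # was "on-demand"/"never" before this commit
    format="wav",
)
```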
57
tools/server/agent/__init__.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import struct
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import ormsgpack
|
||||||
|
|
||||||
|
from tools.server.agent.generate import generate_responses
|
||||||
|
from tools.server.agent.pre_generation_utils import prepare_messages
|
||||||
|
|
||||||
|
|
||||||
|
def execute_request(input_queue, tokenizer, config, request, device):
|
||||||
|
"""
|
||||||
|
This function prepares the conversation, encodes the request,
|
||||||
|
sends the generation request, and handles decoding/streaming.
|
||||||
|
It returns a response generator (ServeResponse or ServeStreamResponse).
|
||||||
|
"""
|
||||||
|
prompt, im_end_id = prepare_messages(request, tokenizer, config)
|
||||||
|
yield from generate_responses(
|
||||||
|
input_queue, tokenizer, config, request, prompt, im_end_id, device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def response_generator(req, llama_queue, tokenizer, config, device):
|
||||||
|
"""
|
||||||
|
Non-streaming response wrapper for the chat endpoint.
|
||||||
|
Only returns the final result.
|
||||||
|
"""
|
||||||
|
generator = execute_request(llama_queue, tokenizer, config, req, device)
|
||||||
|
return next(generator)
|
||||||
|
|
||||||
|
|
||||||
|
async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
|
||||||
|
"""
|
||||||
|
Streaming response wrapper for the chat endpoint.
|
||||||
|
Returns the response in chunks.
|
||||||
|
"""
|
||||||
|
generator = execute_request(llama_queue, tokenizer, config, req, device)
|
||||||
|
for i in generator:
|
||||||
|
if json_mode:
|
||||||
|
body = i.model_dump_json().encode("utf-8")
|
||||||
|
yield b"data: " + body + b"\n\n"
|
||||||
|
else:
|
||||||
|
body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
||||||
|
yield struct.pack("I", len(body)) + body
|
||||||
|
|
||||||
|
|
||||||
|
def get_response_generator(
|
||||||
|
llama_queue, tokenizer, config, req, device, json_mode
|
||||||
|
) -> partial:
|
||||||
|
"""
|
||||||
|
Get the correct response generator based on the request.
|
||||||
|
"""
|
||||||
|
if not req.streaming:
|
||||||
|
return partial(response_generator, req, llama_queue, tokenizer, config, device)
|
||||||
|
else:
|
||||||
|
return partial(
|
||||||
|
streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
|
||||||
|
)
|
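On the wire, the non-JSON branch of `streaming_generator` prefixes every msgpack chunk with a 4-byte length. A client-side sketch for unpacking that framing; the reader function is hypothetical and `stream` is any readable binary stream:

```python
# Hypothetical reader for the length-prefixed msgpack frames emitted above.
import struct

import ormsgpack


def iter_chat_frames(stream):
    """Yield decoded msgpack payloads from a readable binary stream."""
    while True:
        prefix = stream.read(4)
        if len(prefix) < 4:
            break  # end of stream
        (length,) = struct.unpack("I", prefix)
        yield ormsgpack.unpackb(stream.read(length))
```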
119
tools/server/agent/generate.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
|
||||||
|
from tools.server.agent.generation_utils import (
|
||||||
|
initialize_decode_buffers,
|
||||||
|
process_response_tokens,
|
||||||
|
send_reset_buffer,
|
||||||
|
)
|
||||||
|
from tools.server.agent.pre_generation_utils import (
|
||||||
|
create_generation_request,
|
||||||
|
send_generation_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_responses(
|
||||||
|
input_queue, tokenizer, config, request, prompt, im_end_id, device
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Main generation function that handles the conversation, encodes the request,
|
||||||
|
sends the generation request, and handles decoding/streaming.
|
||||||
|
It returns a response generator (ServeResponse or ServeStreamResponse).
|
||||||
|
"""
|
||||||
|
stats = {}
|
||||||
|
start = time.time()
|
||||||
|
stats["start_time"] = start
|
||||||
|
stats["tokens_count"] = 0
|
||||||
|
|
||||||
|
# Prepare and send the generation request
|
||||||
|
req = create_generation_request(prompt, request, im_end_id, device)
|
||||||
|
response_queue = send_generation_request(input_queue, req)
|
||||||
|
decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
response = response_queue.get()
|
||||||
|
|
||||||
|
# Handle abnormal finish or error
|
||||||
|
if response in ["stop", "error"]:
|
||||||
|
finish_reason = response
|
||||||
|
break
|
||||||
|
|
||||||
|
# Process the response tokens
|
||||||
|
is_first_token = stats["tokens_count"] == 0
|
||||||
|
responses = process_response_tokens(
|
||||||
|
response,
|
||||||
|
tokenizer,
|
||||||
|
config,
|
||||||
|
request,
|
||||||
|
decode_buffer,
|
||||||
|
parts,
|
||||||
|
finished,
|
||||||
|
im_end_id,
|
||||||
|
stats,
|
||||||
|
start,
|
||||||
|
is_first_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield the responses if streaming
|
||||||
|
if request.streaming and responses:
|
||||||
|
for r in responses:
|
||||||
|
yield r
|
||||||
|
|
||||||
|
stats["tokens_count"] += 1
|
||||||
|
|
||||||
|
# Check if all samples are finished
|
||||||
|
if all(finished):
|
||||||
|
finish_reason = "stop"
|
||||||
|
break
|
||||||
|
|
||||||
|
# Finalize the response
|
||||||
|
final_responses = finalize_response(
|
||||||
|
request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
|
||||||
|
)
|
||||||
|
for fr in final_responses:
|
||||||
|
yield fr
|
||||||
|
|
||||||
|
|
||||||
|
def finalize_response(
|
||||||
|
request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Finalize the response by sending the remaining text buffers.
|
||||||
|
"""
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
# Send the remaining text buffers
|
||||||
|
for sample_id in range(request.num_samples):
|
||||||
|
responses.extend(
|
||||||
|
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate the final stats
|
||||||
|
stats["total_time"] = (time.time() - stats["start_time"]) * 1000
|
||||||
|
stats["total_tokens"] = stats["tokens_count"]
|
||||||
|
|
||||||
|
# If streaming, send the final chunks for each sample
|
||||||
|
if request.streaming:
|
||||||
|
for sample_id in range(request.num_samples):
|
||||||
|
if finished[sample_id]:
|
||||||
|
continue
|
||||||
|
responses.append(
|
||||||
|
ServeStreamResponse(
|
||||||
|
finish_reason=finish_reason, stats=stats, sample_id=sample_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If not streaming, send the full messages for each sample
|
||||||
|
full_messages = [
|
||||||
|
ServeMessage(role="assistant", parts=parts[i])
|
||||||
|
for i in range(request.num_samples)
|
||||||
|
]
|
||||||
|
responses.append(
|
||||||
|
ServeResponse(
|
||||||
|
messages=full_messages,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
stats=stats,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return responses
|
122
tools/server/agent/generation_utils.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from tools.schema import (
|
||||||
|
ServeStreamDelta,
|
||||||
|
ServeStreamResponse,
|
||||||
|
ServeTextPart,
|
||||||
|
ServeVQPart,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_decode_buffers(num_samples):
|
||||||
|
"""Initialise the decode buffers for each sample."""
|
||||||
|
decode_buffer = [[] for _ in range(num_samples)]
|
||||||
|
parts = [[] for _ in range(num_samples)]
|
||||||
|
finished = [False for _ in range(num_samples)]
|
||||||
|
return decode_buffer, parts, finished
|
||||||
|
|
||||||
|
|
||||||
|
def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
|
||||||
|
"""Send the remaining text buffer for a sample."""
|
||||||
|
if len(decode_buffer[sample_id]) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
decoded = tokenizer.decode(decode_buffer[sample_id])
|
||||||
|
part = ServeTextPart(text=decoded)
|
||||||
|
|
||||||
|
responses = []
|
||||||
|
if request.streaming:
|
||||||
|
responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
|
||||||
|
else:
|
||||||
|
parts[sample_id].append(part)
|
||||||
|
|
||||||
|
decode_buffer[sample_id] = []
|
||||||
|
return responses
|
||||||
|
|
||||||
|
|
||||||
|
def handle_semantic_tokens(tokens, config, sample_id, parts, request):
|
||||||
|
"""Handle the semantic tokens returned by the model."""
|
||||||
|
responses = []
|
||||||
|
_tokens = tokens[1:].clone()
|
||||||
|
|
||||||
|
if not config.share_codebook_embeddings:
|
||||||
|
for i in range(len(_tokens)):
|
||||||
|
_tokens[i] -= config.codebook_size * i
|
||||||
|
|
||||||
|
# If streaming, send the VQ parts directly
|
||||||
|
if request.streaming:
|
||||||
|
responses.append(
|
||||||
|
ServeStreamResponse(
|
||||||
|
sample_id=sample_id,
|
||||||
|
delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If not streaming, accumulate the VQ parts
|
||||||
|
if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
|
||||||
|
parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
|
||||||
|
else:
|
||||||
|
# Accumulate the codes
|
||||||
|
for codebook_id, value in enumerate(_tokens):
|
||||||
|
parts[sample_id][-1].codes[codebook_id].append(value.item())
|
||||||
|
|
||||||
|
return responses
|
||||||
|
|
||||||
|
|
||||||
|
def process_response_tokens(
|
||||||
|
response,
|
||||||
|
tokenizer,
|
||||||
|
config,
|
||||||
|
request,
|
||||||
|
decode_buffer,
|
||||||
|
parts,
|
||||||
|
finished,
|
||||||
|
im_end_id,
|
||||||
|
stats,
|
||||||
|
start,
|
||||||
|
is_first_token,
|
||||||
|
):
|
||||||
|
"""Process the response tokens returned by the model."""
|
||||||
|
responses = []
|
||||||
|
for sample_id, tokens in enumerate(response):
|
||||||
|
if finished[sample_id]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# End of the conversation
|
||||||
|
if tokens[0] == im_end_id:
|
||||||
|
finished[sample_id] = True
|
||||||
|
# Send the remaining text buffer
|
||||||
|
responses.extend(
|
||||||
|
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
|
||||||
|
)
|
||||||
|
if request.streaming:
|
||||||
|
responses.append(
|
||||||
|
ServeStreamResponse(
|
||||||
|
sample_id=sample_id,
|
||||||
|
finish_reason="stop",
|
||||||
|
stats=stats,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if the token is semantic
|
||||||
|
is_semantic = (
|
||||||
|
tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_semantic:
|
||||||
|
# Before the semantic tokens, send the remaining text buffer
|
||||||
|
responses.extend(
|
||||||
|
send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
|
||||||
|
)
|
||||||
|
responses.extend(
|
||||||
|
handle_semantic_tokens(tokens, config, sample_id, parts, request)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Accumulate the text tokens (not implemented?)
|
||||||
|
decode_buffer[sample_id].append(tokens[0, 0])
|
||||||
|
|
||||||
|
if is_first_token:
|
||||||
|
stats["time_to_first_token"] = (time.time() - start) * 1000
|
||||||
|
|
||||||
|
return responses
|
72
tools/server/agent/pre_generation_utils.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import queue
|
||||||
|
|
||||||
|
from fish_speech.conversation import Conversation, Message
|
||||||
|
from fish_speech.tokenizer import IM_END_TOKEN
|
||||||
|
from tools.llama.generate import GenerateRequest
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_messages(request, tokenizer, config):
|
||||||
|
"""
|
||||||
|
Reorganise the provided list of messages into a conversation.
|
||||||
|
Encode the conversation for inference.
|
||||||
|
"""
|
||||||
|
# Convert the messages to ConversationMessage objects
|
||||||
|
messages = [msg.to_conversation_message() for msg in request.messages]
|
||||||
|
|
||||||
|
if len(messages) < 1:
|
||||||
|
raise ValueError("At least one message is required")
|
||||||
|
|
||||||
|
# Check the last message to determine the next step
|
||||||
|
last_role = messages[-1].role
|
||||||
|
match last_role:
|
||||||
|
case "user":
|
||||||
|
# The last message is from the user, ask the assistant to respond with a new message
|
||||||
|
messages.append(
|
||||||
|
Message(role="assistant", parts=[], add_im_end=False, modality="voice")
|
||||||
|
)
|
||||||
|
case "raw":
|
||||||
|
# The last message is raw text, ask the assistant to complete it
|
||||||
|
messages[-1].add_im_start = False
|
||||||
|
messages[-1].add_im_end = False
|
||||||
|
messages[-1].modality = "voice"
|
||||||
|
case "assistant":
|
||||||
|
# The last message is from the assistant, ask the assistant to continue
|
||||||
|
messages[-1].add_im_end = False
|
||||||
|
case _:
|
||||||
|
# We expect it to be assistant if not user or raw
|
||||||
|
raise ValueError("The last message must be from the assistant, user or raw")
|
||||||
|
|
||||||
|
# Create a conversation object and encode it for inference
|
||||||
|
conv = Conversation(messages=messages)
|
||||||
|
prompt = conv.encode_for_inference(
|
||||||
|
tokenizer=tokenizer, num_codebooks=config.num_codebooks
|
||||||
|
)
|
||||||
|
im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
|
||||||
|
|
||||||
|
return prompt, im_end_id
|
||||||
|
|
||||||
|
|
||||||
|
def create_generation_request(prompt, request, im_end_id, device):
|
||||||
|
"""
|
||||||
|
Convert the request into a dictionary that can be sent to the model for generation.
|
||||||
|
"""
|
||||||
|
req = {
|
||||||
|
"prompt": prompt.to(device),
|
||||||
|
"max_new_tokens": request.max_new_tokens,
|
||||||
|
"im_end_id": im_end_id,
|
||||||
|
"temperature": request.temperature,
|
||||||
|
"top_p": request.top_p,
|
||||||
|
"repetition_penalty": request.repetition_penalty,
|
||||||
|
"num_samples": request.num_samples,
|
||||||
|
"early_stop_threshold": request.early_stop_threshold,
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
|
||||||
|
|
||||||
|
def send_generation_request(input_queue, req):
|
||||||
|
"""
|
||||||
|
Send the generation request to the model and return a queue to get the response.
|
||||||
|
"""
|
||||||
|
response_queue = queue.Queue()
|
||||||
|
input_queue.put(GenerateRequest(req, response_queue))
|
||||||
|
return response_queue
|
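The role-based branching in `prepare_messages` is easiest to see with a concrete request. A minimal text-only example, with illustrative values only:

```python
# Minimal chat request; prepare_messages appends an empty assistant turn
# because the last message here is from the user.
from tools.schema import ServeChatRequest, ServeMessage, ServeTextPart

chat_request = ServeChatRequest(
    messages=[
        ServeMessage(role="system", parts=[ServeTextPart(text="Speak out loud.")]),
        ServeMessage(role="user", parts=[ServeTextPart(text="Hello!")]),
    ],
)
```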
75
tools/server/api_utils.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from argparse import ArgumentParser
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import ormsgpack
|
||||||
|
from baize.datastructures import ContentType
|
||||||
|
from kui.asgi import HTTPException, HttpRequest
|
||||||
|
|
||||||
|
from tools.inference_engine import TTSInferenceEngine
|
||||||
|
from tools.schema import ServeTTSRequest
|
||||||
|
from tools.server.inference import inference_wrapper as inference
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
|
||||||
|
parser.add_argument("--load-asr-model", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llama-checkpoint-path",
|
||||||
|
type=str,
|
||||||
|
default="checkpoints/fish-speech-1.5",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-checkpoint-path",
|
||||||
|
type=str,
|
||||||
|
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
||||||
|
)
|
||||||
|
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
|
parser.add_argument("--half", action="store_true")
|
||||||
|
parser.add_argument("--compile", action="store_true")
|
||||||
|
parser.add_argument("--max-text-length", type=int, default=0)
|
||||||
|
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
|
||||||
|
parser.add_argument("--workers", type=int, default=1)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class MsgPackRequest(HttpRequest):
|
||||||
|
async def data(
|
||||||
|
self,
|
||||||
|
) -> Annotated[
|
||||||
|
Any, ContentType("application/msgpack"), ContentType("application/json")
|
||||||
|
]:
|
||||||
|
if self.content_type == "application/msgpack":
|
||||||
|
return ormsgpack.unpackb(await self.body)
|
||||||
|
|
||||||
|
elif self.content_type == "application/json":
|
||||||
|
return await self.json
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
||||||
|
headers={"Accept": "application/msgpack, application/json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
|
||||||
|
for chunk in inference(req, engine):
|
||||||
|
if isinstance(chunk, bytes):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
async def buffer_to_async_generator(buffer):
|
||||||
|
yield buffer
|
||||||
|
|
||||||
|
|
||||||
|
def get_content_type(audio_format):
|
||||||
|
if audio_format == "wav":
|
||||||
|
return "audio/wav"
|
||||||
|
elif audio_format == "flac":
|
||||||
|
return "audio/flac"
|
||||||
|
elif audio_format == "mp3":
|
||||||
|
return "audio/mpeg"
|
||||||
|
else:
|
||||||
|
return "application/octet-stream"
|
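Because `MsgPackRequest` above also falls back to `application/json`, simple requests can be sent without msgpack at all. A sketch, assuming the default host and port:

```python
# JSON fallback: the same /v1/tts route accepts a plain JSON body.
import httpx

response = httpx.post(
    "http://127.0.0.1:8080/v1/tts",
    json={"text": "Hello world.", "format": "wav"},
    timeout=None,
)
response.raise_for_status()
open("output.wav", "wb").write(response.content)
```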
27
tools/server/exception_handler.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import traceback
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from kui.asgi import HTTPException, JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
class ExceptionHandler:
|
||||||
|
|
||||||
|
async def http_exception_handler(self, exc: HTTPException):
|
||||||
|
return JSONResponse(
|
||||||
|
dict(
|
||||||
|
statusCode=exc.status_code,
|
||||||
|
message=exc.content,
|
||||||
|
error=HTTPStatus(exc.status_code).phrase,
|
||||||
|
),
|
||||||
|
exc.status_code,
|
||||||
|
exc.headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def other_exception_handler(self, exc: Exception):
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
return JSONResponse(
|
||||||
|
dict(statusCode=status, message=str(exc), error=status.phrase),
|
||||||
|
status,
|
||||||
|
)
|
41
tools/server/inference.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from kui.asgi import HTTPException
|
||||||
|
|
||||||
|
from tools.inference_engine import TTSInferenceEngine
|
||||||
|
from tools.schema import ServeTTSRequest
|
||||||
|
|
||||||
|
AMPLITUDE = 32768  # 2**15: scales float samples in [-1, 1] to the int16 PCM range
|
||||||
|
|
||||||
|
|
||||||
|
def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
|
||||||
|
"""
|
||||||
|
Wrapper for the inference function.
|
||||||
|
Used in the API server.
|
||||||
|
"""
|
||||||
|
for result in engine.inference(req):
|
||||||
|
match result.code:
|
||||||
|
case "header":
|
||||||
|
if isinstance(result.audio, tuple):
|
||||||
|
yield result.audio[1]
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
content=str(result.error),
|
||||||
|
)
|
||||||
|
|
||||||
|
case "segment":
|
||||||
|
if isinstance(result.audio, tuple):
|
||||||
|
yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
|
||||||
|
|
||||||
|
case "final":
|
||||||
|
if isinstance(result.audio, tuple):
|
||||||
|
yield result.audio[1]
|
||||||
|
return None # Stop the generator
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
content="No audio generated, please check the input text.",
|
||||||
|
)
|
119
tools/server/model_manager.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
import torch
|
||||||
|
from funasr import AutoModel
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from tools.inference_engine import TTSInferenceEngine
|
||||||
|
from tools.llama.generate import (
|
||||||
|
launch_thread_safe_queue,
|
||||||
|
launch_thread_safe_queue_agent,
|
||||||
|
)
|
||||||
|
from tools.schema import ServeTTSRequest
|
||||||
|
from tools.server.inference import inference_wrapper as inference
|
||||||
|
from tools.vqgan.inference import load_model as load_decoder_model
|
||||||
|
|
||||||
|
ASR_MODEL_NAME = "iic/SenseVoiceSmall"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManager:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mode: str,
|
||||||
|
device: str,
|
||||||
|
half: bool,
|
||||||
|
compile: bool,
|
||||||
|
asr_enabled: bool,
|
||||||
|
llama_checkpoint_path: str,
|
||||||
|
decoder_checkpoint_path: str,
|
||||||
|
decoder_config_name: str,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
self.mode = mode
|
||||||
|
self.device = device
|
||||||
|
self.half = half
|
||||||
|
self.compile = compile
|
||||||
|
|
||||||
|
self.precision = torch.half if half else torch.bfloat16
|
||||||
|
|
||||||
|
# Check if CUDA is available
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
self.device = "cpu"
|
||||||
|
logger.info("CUDA is not available, running on CPU.")
|
||||||
|
|
||||||
|
# Load the ASR model if enabled
|
||||||
|
if asr_enabled:
|
||||||
|
self.load_asr_model(self.device)
|
||||||
|
|
||||||
|
# Load the TTS models
|
||||||
|
self.load_llama_model(
|
||||||
|
llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
|
||||||
|
)
|
||||||
|
self.load_decoder_model(
|
||||||
|
decoder_config_name, decoder_checkpoint_path, self.device
|
||||||
|
)
|
||||||
|
self.tts_inference_engine = TTSInferenceEngine(
|
||||||
|
llama_queue=self.llama_queue,
|
||||||
|
decoder_model=self.decoder_model,
|
||||||
|
precision=self.precision,
|
||||||
|
compile=self.compile,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warm up the models
|
||||||
|
if self.mode == "tts":
|
||||||
|
self.warm_up(self.tts_inference_engine)
|
||||||
|
|
||||||
|
def load_asr_model(self, device, hub="ms") -> None:
|
||||||
|
self.asr_model = AutoModel(
|
||||||
|
model=ASR_MODEL_NAME,
|
||||||
|
device=device,
|
||||||
|
disable_pbar=True,
|
||||||
|
hub=hub,
|
||||||
|
)
|
||||||
|
logger.info("ASR model loaded.")
|
||||||
|
|
||||||
|
def load_llama_model(
|
||||||
|
self, checkpoint_path, device, precision, compile, mode
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
if mode == "tts":
|
||||||
|
self.llama_queue = launch_thread_safe_queue(
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
device=device,
|
||||||
|
precision=precision,
|
||||||
|
compile=compile,
|
||||||
|
)
|
||||||
|
elif mode == "agent":
|
||||||
|
self.llama_queue, self.tokenizer, self.config = (
|
||||||
|
launch_thread_safe_queue_agent(
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
device=device,
|
||||||
|
precision=precision,
|
||||||
|
compile=compile,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid mode: {mode}")
|
||||||
|
|
||||||
|
logger.info("LLAMA model loaded.")
|
||||||
|
|
||||||
|
def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
|
||||||
|
self.decoder_model = load_decoder_model(
|
||||||
|
config_name=config_name,
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
logger.info("Decoder model loaded.")
|
||||||
|
|
||||||
|
def warm_up(self, tts_inference_engine) -> None:
|
||||||
|
request = ServeTTSRequest(
|
||||||
|
text="Hello world.",
|
||||||
|
references=[],
|
||||||
|
reference_id=None,
|
||||||
|
max_new_tokens=0,
|
||||||
|
chunk_length=200,
|
||||||
|
top_p=0.7,
|
||||||
|
repetition_penalty=1.5,
|
||||||
|
temperature=0.7,
|
||||||
|
format="wav",
|
||||||
|
)
|
||||||
|
list(inference(request, tts_inference_engine))
|
||||||
|
logger.info("Models warmed up.")
|
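A hedged sketch of standalone usage of this class. The checkpoint paths and config name mirror the defaults used elsewhere in this PR (the old tools/webui.py argument parser) and may differ in your installation:

from tools.server.model_manager import ModelManager

manager = ModelManager(
    mode="tts",
    device="cuda",  # falls back to CPU automatically if CUDA is unavailable
    half=False,
    compile=False,
    asr_enabled=False,
    llama_checkpoint_path="checkpoints/fish-speech-1.5",
    decoder_checkpoint_path="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    decoder_config_name="firefly_gan_vq",
)

# The warmed-up engine can then be passed to inference_wrapper(req, engine).
engine = manager.tts_inference_engine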
129 tools/server/model_utils.py Normal file
@ -0,0 +1,129 @@
import io
import re

import librosa
import torch
import torchaudio
from cachetools import LRUCache, cached

CACHE_MAXSIZE = 10000
MICRO_BATCH_SIZE = 8
ASR_SAMPLE_RATE = 16000
HUGE_GAP_THRESHOLD = 4000


@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_encode(model, audios_list: list[bytes]):
    audios: list[torch.Tensor] = [
        (
            torch.from_numpy(
                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
            )[None]
            if isinstance(audio, bytes)
            else audio
        )
        for audio in audios_list
    ]

    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
    max_length = lengths.max().item()

    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")

    padded = torch.stack(
        [
            torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
            for audio in audios
        ]
    ).to(model.device)

    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
    features, feature_lengths = features.cpu(), feature_lengths.cpu()

    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]


@cached(
    cache=LRUCache(maxsize=CACHE_MAXSIZE),
    key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    return batch_encode(model, audios)


@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def vqgan_decode(model, features):
    lengths = torch.tensor(
        [feature.shape[-1] for feature in features], device=model.device
    )
    max_length = lengths.max().item()
    padded = torch.stack(
        [
            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
            for feature in features
        ]
    ).to(model.device)

    # If bs too large, we do micro batch decode
    audios, audio_lengths = [], []
    for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
        audio, audio_length = model.decode(
            padded[i : i + MICRO_BATCH_SIZE],
            feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
        )
        audios.append(audio)
        audio_lengths.append(audio_length)
    audios = torch.cat(audios, dim=0)
    audio_lengths = torch.cat(audio_lengths, dim=0)
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]


@torch.no_grad()
def batch_asr(model, lock, audios, sr, language="auto"):
    resampled_audios = []
    for audio in audios:
        audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    with lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # If there is a gap of more than 4 seconds, we consider it as a huge gap
                if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
                    huge_gap = True
                    break

            # Doesn't make sense to have a huge gap at the end
            if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results
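A sketch of calling batch_asr directly, assuming SenseVoice is loaded the same way ModelManager.load_asr_model does above; the random tensors are stand-ins for real mono clips and will not transcribe to anything meaningful:

import threading

import torch
from funasr import AutoModel

from tools.server.model_utils import batch_asr

asr_model = AutoModel(
    model="iic/SenseVoiceSmall", device="cuda", disable_pbar=True, hub="ms"
)
lock = threading.Lock()

clips = [torch.randn(44100), torch.randn(2 * 44100)]  # stand-ins for real audio
for item in batch_asr(asr_model, lock, clips, sr=44100):
    print(item["text"], f"{item['duration']:.0f}ms", item["huge_gap"])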
246 tools/server/views.py Normal file
@ -0,0 +1,246 @@
import io
import os
import time
from http import HTTPStatus

import numpy as np
import ormsgpack
import soundfile as sf
import torch
from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request
from loguru import logger

from tools.schema import (
    ServeASRRequest,
    ServeASRResponse,
    ServeChatRequest,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
    ServeVQGANEncodeRequest,
    ServeVQGANEncodeResponse,
)
from tools.server.agent import get_response_generator
from tools.server.api_utils import (
    buffer_to_async_generator,
    get_content_type,
    inference_async,
)
from tools.server.inference import inference_wrapper as inference
from tools.server.model_manager import ModelManager
from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode

MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))


class HealthView(HttpView):
    """
    Return the health status of the server.
    """

    @classmethod
    async def post(cls):
        return JSONResponse({"status": "ok"})


class VQGANEncodeView(HttpView):
    """
    Encode the audio into symbolic tokens.
    """

    @classmethod
    async def post(cls):
        # Decode the request
        payload = await request.data()
        req = ServeVQGANEncodeRequest(**payload)

        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # Encode the audio
        start_time = time.time()
        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
        logger.info(
            f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
        )

        # Return the response
        return ormsgpack.packb(
            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )


class VQGANDecodeView(HttpView):
    """
    Decode the symbolic tokens into audio.
    """

    @classmethod
    async def post(cls):
        # Decode the request
        payload = await request.data()
        req = ServeVQGANDecodeRequest(**payload)

        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # Decode the audio
        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
        start_time = time.time()
        audios = vqgan_decode(decoder_model, tokens)
        logger.info(
            f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
        )
        audios = [audio.astype(np.float16).tobytes() for audio in audios]

        # Return the response
        return ormsgpack.packb(
            ServeVQGANDecodeResponse(audios=audios),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )


class ASRView(HttpView):
    """
    Perform automatic speech recognition on the audio.
    """

    @classmethod
    async def post(cls):
        # Decode the request
        payload = await request.data()
        req = ServeASRRequest(**payload)

        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        asr_model = model_manager.asr_model
        lock = request.app.state.lock

        # Perform ASR
        start_time = time.time()
        audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
        audios = [torch.from_numpy(audio).float() for audio in audios]

        if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
            raise HTTPException(status_code=400, content="Audio length is too long")

        transcriptions = batch_asr(
            asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
        )
        logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

        # Return the response
        return ormsgpack.packb(
            ServeASRResponse(transcriptions=transcriptions),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )


class TTSView(HttpView):
    """
    Perform text-to-speech on the input text.
    """

    @classmethod
    async def post(cls):
        # Decode the request
        payload = await request.data()
        req = ServeTTSRequest(**payload)

        # Get the model from the app
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine
        sample_rate = engine.decoder_model.spec_transform.sample_rate

        # Check if the text is too long
        if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content=f"Text is too long, max length is {app_state.max_text_length}",
            )

        # Check if streaming is enabled
        if req.streaming and req.format != "wav":
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content="Streaming only supports WAV format",
            )

        # Perform TTS
        if req.streaming:
            return StreamResponse(
                iterable=inference_async(req, engine),
                headers={
                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
                },
                content_type=get_content_type(req.format),
            )
        else:
            fake_audios = next(inference(req, engine))
            buffer = io.BytesIO()
            sf.write(
                buffer,
                fake_audios,
                sample_rate,
                format=req.format,
            )

            return StreamResponse(
                iterable=buffer_to_async_generator(buffer.getvalue()),
                headers={
                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
                },
                content_type=get_content_type(req.format),
            )


class ChatView(HttpView):
    """
    Perform chatbot inference on the input text.
    """

    @classmethod
    async def post(cls):
        # Decode the request
        payload = await request.data()
        req = ServeChatRequest(**payload)

        # Check that the number of samples requested is correct
        if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
            )

        # Get the type of content provided
        content_type = request.headers.get("Content-Type", "application/json")
        json_mode = "application/json" in content_type

        # Get the models from the app
        model_manager: ModelManager = request.app.state.model_manager
        llama_queue = model_manager.llama_queue
        tokenizer = model_manager.tokenizer
        config = model_manager.config

        device = request.app.state.device

        # Get the response generators
        response_generator = get_response_generator(
            llama_queue, tokenizer, config, req, device, json_mode
        )

        # Return the response in the correct format
        if req.streaming is False:
            result = response_generator()
            if json_mode:
                return JSONResponse(result.model_dump())
            else:
                return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

        return StreamResponse(
            iterable=response_generator(), content_type="text/event-stream"
        )
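For completeness, a sketch of a msgpack client for TTSView, in the spirit of tools/api_client. The "/v1/tts" route, the port, and the use of httpx are assumptions, since route registration lives outside this file:

import httpx
import ormsgpack

from tools.schema import ServeTTSRequest

req = ServeTTSRequest(
    text="Hello world.",
    references=[],
    reference_id=None,
    max_new_tokens=0,
    chunk_length=200,
    top_p=0.7,
    repetition_penalty=1.5,
    temperature=0.7,
    format="wav",
)

response = httpx.post(
    "http://127.0.0.1:8080/v1/tts",  # assumed route and port
    content=ormsgpack.packb(req, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
    headers={"Content-Type": "application/msgpack"},
    timeout=None,
)
with open("generated.wav", "wb") as f:
    f.write(response.content)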
570 tools/webui.py
@ -1,570 +0,0 @@
import gc
import html
import io
import os
import queue
import wave
from argparse import ArgumentParser
from functools import partial
from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import pyrootutils
import torch
from loguru import logger
from transformers import AutoTokenizer

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)


from fish_speech.i18n import i18n
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from tools.api import decode_vq_tokens, encode_reference
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)
from tools.schema import (
    GLOBAL_NUM_SAMPLES,
    ASRPackRequest,
    ServeASRRequest,
    ServeASRResponse,
    ServeASRSegment,
    ServeAudioPart,
    ServeForwardMessage,
    ServeMessage,
    ServeReferenceAudio,
    ServeRequest,
    ServeResponse,
    ServeStreamDelta,
    ServeStreamResponse,
    ServeTextPart,
    ServeTimedASRResponse,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
    ServeVQGANEncodeRequest,
    ServeVQGANEncodeResponse,
    ServeVQPart,
)
from tools.vqgan.inference import load_model as load_decoder_model

# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"


HEADER_MD = f"""# Fish Speech

{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}

{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}

{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}

{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""

TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
SPACE_IMPORTED = False


def build_html_error_message(error):
    return f"""
    <div style="color: red;
    font-weight: bold;">
        {html.escape(str(error))}
    </div>
    """


@torch.inference_mode()
def inference(req: ServeTTSRequest):

    idstr: str | None = req.reference_id
    prompt_tokens, prompt_texts = [], []
    if idstr is not None:
        ref_folder = Path("references") / idstr
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
        else:
            logger.info("Use same references")

    else:
        # Parse reference audio aka prompt
        refs = req.references

        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                for ref in refs
            ]
            prompt_texts = [ref.text for ref in refs]
        else:
            logger.info("Use same references")

    if req.seed is not None:
        set_seed(req.seed)
        logger.warning(f"set seed: {req.seed}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=(
            req.text
            if not req.normalize
            else ChnNormedText(raw_text=req.text).normalize()
        ),
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=4096,
        prompt_tokens=prompt_tokens,
        prompt_text=prompt_texts,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    segments = []

    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            yield None, None, build_html_error_message(result.response)
            break

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()
        segments.append(fake_audios)

    if len(segments) == 0:
        return (
            None,
            None,
            build_html_error_message(
                i18n("No audio generated, please check the input text.")
            ),
        )

    # No matter streaming or not, we need to return the final audio
    audio = np.concatenate(segments, axis=0)
    yield None, (decoder_model.spec_transform.sample_rate, audio), None

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()


n_audios = 4

global_audio_list = []
global_error_list = []


def inference_wrapper(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    batch_infer_num,
):
    audios = []
    errors = []

    for _ in range(batch_infer_num):
        result = inference(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            seed,
        )

        _, audio_data, error_message = next(result)

        audios.append(
            gr.Audio(value=audio_data if audio_data else None, visible=True),
        )
        errors.append(
            gr.HTML(value=error_message if error_message else None, visible=True),
        )

    for _ in range(batch_infer_num, n_audios):
        audios.append(
            gr.Audio(value=None, visible=False),
        )
        errors.append(
            gr.HTML(value=None, visible=False),
        )

    return None, *audios, *errors


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


def normalize_text(user_input, use_normalization):
    if use_normalization:
        return ChnNormedText(raw_text=user_input).normalize()
    else:
        return user_input


def update_examples():
    examples_dir = Path("references")
    examples_dir.mkdir(parents=True, exist_ok=True)
    example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
    return gr.Dropdown(choices=example_audios + [""])


def build_app():
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % args.theme,
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )
                refined_text = gr.Textbox(
                    label=i18n("Realtime Transform Text"),
                    placeholder=i18n(
                        "Normalization Result Preview (Currently Only Chinese)"
                    ),
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    normalize = gr.Checkbox(
                        label=i18n("Text Normalization"),
                        value=False,
                    )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=0,
                                    maximum=300,
                                    value=200,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.5,
                                    value=1.2,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["never", "on-demand", "always"],
                                    value="on-demand",
                                )

                            with gr.Row():
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                        )

        text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])

        def inference_wrapper(
            text,
            normalize,
            reference_id,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            seed,
            use_memory_cache,
        ):
            references = []
            if reference_audio:
                # Convert the file path to bytes
                with open(reference_audio, "rb") as audio_file:
                    audio_bytes = audio_file.read()
                references = [
                    ServeReferenceAudio(audio=audio_bytes, text=reference_text)
                ]

            req = ServeTTSRequest(
                text=text,
                normalize=normalize,
                reference_id=reference_id if reference_id else None,
                references=references,
                max_new_tokens=max_new_tokens,
                chunk_length=chunk_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                seed=int(seed) if seed else None,
                use_memory_cache=use_memory_cache,
            )

            for result in inference(req):
                if result[2]:  # Error message
                    return None, result[2]
                elif result[1]:  # Audio data
                    return result[1], None

            return None, i18n("No audio generated")

        # Submit
        generate.click(
            inference_wrapper,
            [
                refined_text,
                normalize,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            concurrency_limit=1,
        )

    return app


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    # Check if CUDA is available
    if not torch.cuda.is_available():
        logger.info("CUDA is not available, running on CPU.")
        args.device = "cpu"

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=0,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.5,
                temperature=0.7,
                emotion=None,
                format="wav",
            )
        )
    )

    logger.info("Warming up done, launching the web UI...")

    app = build_app()
    app.launch(show_api=True)
173 tools/webui/__init__.py Normal file
@ -0,0 +1,173 @@
from typing import Callable

import gradio as gr

from fish_speech.i18n import i18n
from tools.inference_engine.utils import normalize_text
from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER


def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % theme,
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )
                refined_text = gr.Textbox(
                    label=i18n("Realtime Transform Text"),
                    placeholder=i18n(
                        "Normalization Result Preview (Currently Only Chinese)"
                    ),
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    normalize = gr.Checkbox(
                        label=i18n("Text Normalization"),
                        value=False,
                    )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=0,
                                    maximum=300,
                                    value=200,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.5,
                                    value=1.2,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["on", "off"],
                                    value="on",
                                )

                            with gr.Row():
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + i18n("Generate"),
                            variant="primary",
                        )

        text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])

        # Submit
        generate.click(
            inference_fct,
            [
                refined_text,
                normalize,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            concurrency_limit=1,
        )

    return app
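build_app only needs a callable wired to the Generate button that accepts the UI inputs and returns an (audio, error) pair, so the layout can be smoke-tested without loading any models; a stub sketch:

from tools.webui import build_app

def fake_inference(*args):
    # Mirrors the (audio, error) return shape expected by the [audio, error] outputs.
    return None, "models not loaded"

build_app(fake_inference, theme="dark").launch()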
91 tools/webui/inference.py Normal file
@ -0,0 +1,91 @@
import html
from functools import partial
from typing import Any, Callable

from fish_speech.i18n import i18n
from tools.schema import ServeReferenceAudio, ServeTTSRequest


def inference_wrapper(
    text,
    normalize,
    reference_id,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    use_memory_cache,
    engine,
):
    """
    Wrapper for the inference function.
    Used in the Gradio interface.
    """

    if reference_audio:
        references = get_reference_audio(reference_audio, reference_text)
    else:
        references = []

    req = ServeTTSRequest(
        text=text,
        normalize=normalize,
        reference_id=reference_id if reference_id else None,
        references=references,
        max_new_tokens=max_new_tokens,
        chunk_length=chunk_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        seed=int(seed) if seed else None,
        use_memory_cache=use_memory_cache,
    )

    for result in engine.inference(req):
        match result.code:
            case "final":
                return result.audio, None
            case "error":
                return None, build_html_error_message(i18n(result.error))
            case _:
                pass

    return None, i18n("No audio generated")


def get_reference_audio(reference_audio: str, reference_text: str) -> list:
    """
    Get the reference audio bytes.
    """

    with open(reference_audio, "rb") as audio_file:
        audio_bytes = audio_file.read()

    return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]


def build_html_error_message(error: Any) -> str:

    error = error if isinstance(error, Exception) else Exception("Unknown error")

    return f"""
    <div style="color: red;
    font-weight: bold;">
        {html.escape(str(error))}
    </div>
    """


def get_inference_wrapper(engine) -> Callable:
    """
    Get the inference function with the immutable arguments.
    """

    return partial(
        inference_wrapper,
        engine=engine,
    )
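A sketch of the intended wiring, assuming a TTSInferenceEngine named engine has already been built (for example by ModelManager above); this is presumably how tools/run_webui.py connects the pieces:

from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper

inference_fct = get_inference_wrapper(engine)  # engine built elsewhere (assumption)
app = build_app(inference_fct, theme="light")
app.launch(show_api=True)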
14 tools/webui/variables.py Normal file
@ -0,0 +1,14 @@
from fish_speech.i18n import i18n

HEADER_MD = f"""# Fish Speech

{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}

{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}

{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}

{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""

TEXTBOX_PLACEHOLDER = i18n("Put your text here.")