fish-speech/tools/server/api_utils.py
Picus303 62eae262c2
Make WebUI and API code cleaner (+ 1.5 fixes) (#703)
* rename webui.py to run_webui.py

* remove unused imports

* remove unused code

* move inference code and fix all warnings

* move web app code

* make code easier to read

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused function

* remove msgpack_api.py

* rename API files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* finish updating the doc with the new file names

* finish updating the doc with the new file names

* fix CPU use in the API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor WebUI inference into a class with submodules

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* re-enable streaming in webui inference code

* generalize inference code in webui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* make a unique inference engine class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* cleaning code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement new structure of the API (not working)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reimplement chat endpoint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-12-07 14:13:19 +08:00

76 lines
2.4 KiB
Python

from argparse import ArgumentParser
from http import HTTPStatus
from typing import Annotated, Any
import ormsgpack
from baize.datastructures import ContentType
from kui.asgi import HTTPException, HttpRequest
from tools.inference_engine import TTSInferenceEngine
from tools.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
def parse_args(argv=None):
    """Parse the API server's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as the
            zero-argument form always did; passing an explicit list lets
            tests and programmatic launchers reuse this parser without
            touching the process command line.

    Returns:
        argparse.Namespace with one attribute per option below (dashes in
        option names become underscores, e.g. ``--max-text-length`` ->
        ``max_text_length``).
    """
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    # NOTE(review): 0 is the default; how 0 is interpreted (presumably
    # "no limit") is decided by the consumer of this value — confirm.
    parser.add_argument("--max-text-length", type=int, default=0)
    # Expected as "host:port"; this parser does not validate the format.
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)
    # argparse treats None as "use sys.argv[1:]", preserving the old behaviour.
    return parser.parse_args(argv)
class MsgPackRequest(HttpRequest):
    """HTTP request whose body is decoded as MessagePack or JSON.

    The decoder is picked from the request's content type; any other
    media type is rejected with HTTP 415 (Unsupported Media Type).
    """

    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Return the decoded request body.

        ``application/msgpack`` bodies are decoded with ``ormsgpack.unpackb``;
        ``application/json`` bodies are returned via the base request's
        ``json`` accessor.

        Raises:
            HTTPException: status 415 with an ``Accept`` header naming the
                two supported media types, when the content type matches
                neither of them.
        """
        media_type = self.content_type

        if media_type == "application/msgpack":
            raw_body = await self.body
            return ormsgpack.unpackb(raw_body)

        if media_type == "application/json":
            return await self.json

        # No supported format matched: advertise what the endpoint accepts.
        accepted = {"Accept": "application/msgpack, application/json"}
        raise HTTPException(HTTPStatus.UNSUPPORTED_MEDIA_TYPE, headers=accepted)
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Re-expose the ``inference(req, engine)`` stream as an async generator.

    Only items that are ``bytes`` are forwarded to the consumer; every
    other item produced by the stream is silently skipped.

    NOTE(review): the wrapped stream is iterated synchronously, so each
    step runs on the event-loop thread. If ``inference`` blocks (likely for
    model inference — confirm), other requests stall while it runs.
    """
    stream = inference(req, engine)
    for piece in stream:
        if not isinstance(piece, bytes):
            continue
        yield piece
async def buffer_to_async_generator(buffer):
    """Wrap a single, already-complete buffer as a one-item async generator.

    Lets code that expects an async iterable of chunks also accept a
    buffer that was produced in one piece.
    """
    for whole in (buffer,):
        yield whole
def get_content_type(audio_format):
    """Map an audio format name to the MIME type used in HTTP responses.

    Matching is exact (case-sensitive): ``"wav"``, ``"flac"`` and ``"mp3"``
    map to their audio MIME types; anything else falls back to the generic
    ``application/octet-stream``.
    """
    known_formats = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    for name, mime in known_formats:
        if audio_format == name:
            return mime
    return "application/octet-stream"