Picus303 62eae262c2
Make WebUI and API code cleaner (+ 1.5 fixes) (#703)
* rename webui.py to run_webui.py

* remove unused imports

* remove unused code

* move inference code and fix all warnings

* move web app code

* make code easier to read

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused function

* remove msgpack_api.py

* rename API files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* finish updating the doc with the new file names

* finish updating the doc with the new file names

* fix CPU use in the API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor WebUI inference into a class with submodules

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* re-enable streaming in webui inference code

* generalize inference code in webui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* make a unique inference engine class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* cleaning code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement new structure of the API (not working)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reimplement chat endpoint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-12-07 14:13:19 +08:00

58 lines
2.0 KiB
Python

from typing import Callable

import torch
from loguru import logger

from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture


class VQManager:

    def __init__(self):
        # Make Pylance happy (attribute/method not defined...)
        self.decoder_model: FireflyArchitecture
        self.load_audio: Callable

    def decode_vq_tokens(self, codes):
        feature_lengths = torch.tensor(
            [codes.shape[1]], device=self.decoder_model.device
        )
        logger.info(f"VQ features: {codes.shape}")

        if isinstance(self.decoder_model, FireflyArchitecture):
            # Decode the codebook indices back into a waveform
            return self.decoder_model.decode(
                indices=codes[None],
                feature_lengths=feature_lengths,
            )[0].squeeze()

        raise ValueError(f"Unknown model type: {type(self.decoder_model)}")

    def encode_reference(self, reference_audio, enable_reference_audio):
        if enable_reference_audio and reference_audio is not None:
            # Load audios, and prepare basic info here
            reference_audio_content = self.load_audio(
                reference_audio, self.decoder_model.spec_transform.sample_rate
            )
            audios = torch.from_numpy(reference_audio_content).to(
                self.decoder_model.device
            )[None, None, :]
            audio_lengths = torch.tensor(
                [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
            )
            logger.info(
                f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
            )

            # VQ Encoder
            if isinstance(self.decoder_model, FireflyArchitecture):
                prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
                logger.info(f"Encoded prompt: {prompt_tokens.shape}")
            else:
                raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
        else:
            prompt_tokens = None
            logger.info("No reference audio provided")

        return prompt_tokens
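
VQManager is written as a mixin: whatever class composes it must supply the decoder_model and load_audio attributes declared in __init__. The sketch below is a minimal, hypothetical illustration of such a composing class; the engine name and the torchaudio-based loader are assumptions for illustration, not the repository's actual inference engine.

# Minimal usage sketch (hypothetical): a concrete engine mixing in VQManager.
# Only the contract (decoder_model + load_audio) is taken from the code above.
import numpy as np
import torchaudio

class ExampleEngine(VQManager):
    def __init__(self, decoder_model: FireflyArchitecture):
        self.decoder_model = decoder_model  # an already-loaded FireflyArchitecture

    def load_audio(self, path: str, sample_rate: int) -> np.ndarray:
        # Assumed loader: read the file, resample to the decoder's sample rate,
        # and return a mono float32 numpy array as encode_reference expects.
        waveform, sr = torchaudio.load(path)
        if sr != sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
        return waveform.mean(dim=0).numpy()

# engine = ExampleEngine(decoder_model)
# prompt_tokens = engine.encode_reference("reference.wav", enable_reference_audio=True)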