* rename webui.py to run_webui.py * remove unused imports * remove unsued code * move inference code and fix all warnings * move web app code * make code easier to read * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove unused function * remove msgpack_api.py * rename API files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * finish updating the doc with the new file names * finish updating the doc with the new file names * fix CPU use in the API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor WebUIinference in a class with submodules * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * re-enable streaming in webui inference code * generalize inference code in webui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix * make a unique inference engine class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix * cleaning code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement new structure of the API (not working) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * reimplement chat endpoint * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
from typing import Callable
|
|
|
|
import torch
|
|
from loguru import logger
|
|
|
|
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
|
|
|
|
|
class VQManager:
    """Mixin handling VQ-GAN encode/decode for the inference engine.

    The host class is expected to provide ``decoder_model`` (a
    ``FireflyArchitecture``) and ``load_audio`` (a callable loading an
    audio file at a given sample rate into a numpy array); this mixin
    only consumes them.
    """

    def __init__(self):
        # Make Pylance happy (attribute/method not defined...).
        # These are declarations only; the host class must assign them.
        self.decoder_model: FireflyArchitecture
        self.load_audio: Callable

    def decode_vq_tokens(self, codes):
        """Decode VQ token indices into an audio waveform tensor.

        Args:
            codes: Tensor of VQ indices; indexed as (codebooks, frames)
                here — assumed, confirm against callers.

        Returns:
            The decoded audio tensor (batch/channel dims squeezed out).

        Raises:
            ValueError: If ``decoder_model`` is not a FireflyArchitecture.
        """
        # Validate the model type first so an unsupported model raises the
        # intended ValueError instead of a possible AttributeError when
        # `.device` is accessed below.
        if not isinstance(self.decoder_model, FireflyArchitecture):
            raise ValueError(f"Unknown model type: {type(self.decoder_model)}")

        feature_lengths = torch.tensor(
            [codes.shape[1]], device=self.decoder_model.device
        )
        logger.info(f"VQ features: {codes.shape}")

        return self.decoder_model.decode(
            indices=codes[None],
            feature_lengths=feature_lengths,
        )[0].squeeze()

    def encode_reference(self, reference_audio, enable_reference_audio):
        """Encode a reference audio clip into VQ prompt tokens.

        Args:
            reference_audio: Input accepted by ``self.load_audio``
                (e.g. a path or raw bytes), or None.
            enable_reference_audio: Whether the reference audio should
                be used at all.

        Returns:
            The encoded prompt-token tensor, or None when reference
            audio is disabled or missing.

        Raises:
            ValueError: If ``decoder_model`` is not a FireflyArchitecture.
        """
        # Guard clause: nothing to encode.
        if not enable_reference_audio or reference_audio is None:
            logger.info("No reference audio provided")
            return None

        # Validate the model type before doing any (potentially costly)
        # audio loading and device transfer.
        if not isinstance(self.decoder_model, FireflyArchitecture):
            raise ValueError(f"Unknown model type: {type(self.decoder_model)}")

        # Load audios, and prepare basic info here
        reference_audio_content = self.load_audio(
            reference_audio, self.decoder_model.spec_transform.sample_rate
        )

        # Shape to (batch=1, channels=1, samples) and move to the model device.
        audios = torch.from_numpy(reference_audio_content).to(
            self.decoder_model.device
        )[None, None, :]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
        logger.info(f"Encoded prompt: {prompt_tokens.shape}")

        return prompt_tokens