# History (squashed PR commits):
#   * rename webui.py to run_webui.py
#   * remove unused imports, unused code, and unused functions
#   * move inference code and fix all warnings
#   * move web app code; make code easier to read
#   * remove msgpack_api.py; rename API files
#   * finish updating the doc with the new file names
#   * fix CPU use in the API
#   * refactor WebUI inference in a class with submodules
#   * re-enable streaming in webui inference code
#   * generalize inference code in webui; make a unique inference engine class
#   * clean code; implement new structure of the API; refactor API
#   * reimplement chat endpoint
#   (interleaved with pre-commit.ci auto-fix commits — see https://pre-commit.ci)
# Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
import io
|
|
from hashlib import sha256
|
|
from pathlib import Path
|
|
from typing import Callable, Literal, Tuple
|
|
|
|
import torch
|
|
import torchaudio
|
|
from loguru import logger
|
|
|
|
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
|
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
|
from tools.schema import ServeReferenceAudio
|
|
|
|
|
|
class ReferenceLoader:
    """Mixin that loads and caches reference audio/text prompts for TTS.

    Component of the TTSInferenceEngine class. The composing class is
    expected to provide ``decoder_model`` and ``encode_reference``.
    """

    def __init__(self) -> None:
        """
        Component of the TTSInferenceEngine class.

        Loads and manages the cache for the reference audio and text.
        """
        # Caches: reference id -> (prompt_tokens, prompt_texts) and
        # audio sha256 hex digest -> (prompt_token, prompt_text).
        self.ref_by_id: dict = {}
        self.ref_by_hash: dict = {}

        # Make Pylance happy (attribute/method not defined here — they are
        # supplied by the class that composes this mixin).
        self.decoder_model: FireflyArchitecture
        self.encode_reference: Callable

        # Define the torchaudio backend: prefer ffmpeg (wider codec
        # support) and fall back to soundfile.
        backends = torchaudio.list_audio_backends()
        self.backend = "ffmpeg" if "ffmpeg" in backends else "soundfile"

    def load_by_id(
        self,
        id: str,
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """Load (and cache) the references stored under ``references/<id>``.

        Each reference audio file must have a sibling ``.lab`` file holding
        its transcript. Returns ``(prompt_tokens, prompt_texts)``.
        """
        # Load the references audio and text by id
        ref_folder = Path("references") / id
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        if use_cache == "off" or id not in self.ref_by_id:
            # If the references are not already loaded, encode them
            prompt_tokens = [
                self.encode_reference(
                    decoder_model=self.decoder_model,
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
            # Refresh the cache even when use_cache == "off".
            self.ref_by_id[id] = (prompt_tokens, prompt_texts)

        else:
            # Reuse already encoded references
            logger.info("Use same references")
            prompt_tokens, prompt_texts = self.ref_by_id[id]

        return prompt_tokens, prompt_texts

    def load_by_hash(
        self,
        references: "list[ServeReferenceAudio]",
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """Load (and cache) in-memory references, keyed by audio sha256 hash.

        Returns ``(prompt_tokens, prompt_texts)`` with one entry per
        reference, in input order.
        """
        # Load the references audio and text by hash
        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]

        cache_used = False
        prompt_tokens, prompt_texts = [], []
        for audio_hash, ref in zip(audio_hashes, references):
            if use_cache == "off" or audio_hash not in self.ref_by_hash:
                # If the reference is not already loaded, encode it
                prompt_token = self.encode_reference(
                    decoder_model=self.decoder_model,
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                prompt_tokens.append(prompt_token)
                prompt_texts.append(ref.text)
                # BUGFIX: cache only this reference's (token, text) pair.
                # The previous code stored the whole accumulating lists
                # (aliased, so they kept growing after being cached) and
                # later unpacked them in swapped (text, token) order, so
                # every cache hit returned corrupted entries.
                self.ref_by_hash[audio_hash] = (prompt_token, ref.text)

            else:
                # Reuse the already encoded reference
                prompt_token, prompt_text = self.ref_by_hash[audio_hash]
                prompt_tokens.append(prompt_token)
                prompt_texts.append(prompt_text)
                cache_used = True

        if cache_used:
            logger.info("Use same references")

        return prompt_tokens, prompt_texts

    def load_audio(self, reference_audio, sr):
        """
        Load audio from a file path or raw bytes and return a mono
        waveform resampled to ``sr`` as a 1-D numpy array.
        """
        # Heuristic: anything too long to be a path, or not an existing
        # file, is treated as raw audio bytes.
        # NOTE(review): byte payloads shorter than 256 bytes would reach
        # Path(...) and raise TypeError — presumably real audio is always
        # longer; confirm with callers.
        if len(reference_audio) > 255 or not Path(reference_audio).exists():
            audio_data = reference_audio
            reference_audio = io.BytesIO(audio_data)

        waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)

        # Downmix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample if the file's rate differs from the requested rate.
        if original_sr != sr:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=sr
            )
            waveform = resampler(waveform)

        audio = waveform.squeeze().numpy()
        return audio
|