Remove unused code. (#1016)

* Remove unused code.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove the rest of the ASR code.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo 2025-06-07 16:59:24 +08:00 committed by GitHub
parent 94f9fa6c43
commit dbec3212ef
14 changed files with 1 addition and 706 deletions

View File

@@ -21,7 +21,6 @@ from fish_speech.content_sequence import (
    TextPart,
    VQPart,
)
from fish_speech.text import split_text
from fish_speech.tokenizer import IM_END_TOKEN

os.environ["TOKENIZERS_PARALLELISM"] = "false"

View File

@@ -1,202 +0,0 @@
from typing import Any, Optional

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.models.text2semantic.llama import NaiveTransformer

log = utils.RankedLogger(__name__, rank_zero_only=True)


class TextToSemantic(L.LightningModule):
    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        # Save only LoRA parameters
        state_dict = checkpoint["state_dict"]
        use_lora = any("lora" in name for name in state_dict.keys())
        if not use_lora:
            return

        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Get weight decay parameters
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        is_train = stage == "train"

        if is_train:
            # Key part to make lora work
            # Otherwise the parameters are merged, which lead to incorrect gradients
            self.model.train()

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )
        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )
        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        return loss

    def get_accuracy(self, logits, labels):
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=logits.device)

        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~mask] = 0
        correct = correct.sum()
        accuracy = correct / mask.sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

View File

@@ -1,4 +1,3 @@
from .clean import clean_text
from .spliter import split_text

__all__ = ["clean_text", "split_text"]
__all__ = ["clean_text"]

View File

@@ -1,130 +0,0 @@
import re
import string

from fish_speech.text.clean import clean_text


def utf_8_len(text: str):
    return len(text.encode("utf-8"))


def break_text(texts, length, splits: set):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if char in splits:
                yield curr
                curr = ""

        if curr:
            yield curr


def break_text_by_length(texts, length):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if utf_8_len(curr) >= length:
                yield curr
                curr = ""

        if curr:
            yield curr


def add_cleaned(curr, segments):
    curr = curr.strip()
    if curr and not all(c.isspace() or c in string.punctuation for c in curr):
        segments.append(curr)


def protect_float(text):
    # Turns 3.14 into <3_f_14> to prevent splitting
    return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)


def unprotect_float(text):
    # Turns <3_f_14> into 3.14
    return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)


def split_text(text, length):
    text = clean_text(text)

    # Break the text into pieces with following rules:
    # 1. Split the text at ".", "!", "?" if text is NOT a float
    # 2. If the text is longer than length, split at ","
    # 3. If the text is still longer than length, split at " "
    # 4. If the text is still longer than length, split at any character to length

    texts = [text]
    texts = map(protect_float, texts)
texts = break_text(texts, length, {".", "!", "?", "", "", ""})
texts = map(unprotect_float, texts)
texts = break_text(texts, length, {",", ""})
texts = break_text(texts, length, {" "})
texts = list(break_text_by_length(texts, length))
# Then, merge the texts into segments with length <= length
segments = []
curr = ""
for text in texts:
if utf_8_len(curr) + utf_8_len(text) <= length:
curr += text
else:
add_cleaned(curr, segments)
curr = text
if curr:
add_cleaned(curr, segments)
return segments
if __name__ == "__main__":
# Test the split_text function
text = "This is a test sentence. This is another test sentence. And a third one."
assert split_text(text, 50) == [
"This is a test sentence.",
"This is another test sentence. And a third one.",
]
assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
assert split_text(" ", 10) == []
assert split_text("a", 10) == ["a"]
text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
assert split_text(text, 50) == [
"This is a test sentence with only commas,",
"and no dots, and no exclamation marks,",
"and no question marks, and no newlines.",
]
text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
# First half split at " ", second half split at ","
assert split_text(text, 50) == [
"This is a test sentence This is a test sentence",
"This is a test sentence. This is a test sentence,",
"This is a test sentence, This is a test sentence.",
]
text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
assert split_text(text, 50) == [
"这是一段很长的中文文本,",
"而且没有句号,也没有感叹号,",
"也没有问号,也没有换行符.",
]

View File

@@ -27,35 +27,6 @@ class ServeAudioPart(BaseModel):
    audio: bytes


class ServeASRRequest(BaseModel):
    # The audio should be an uncompressed PCM float16 audio
    audios: list[bytes]
    sample_rate: int = 44100
    language: Literal["zh", "en", "ja", "auto"] = "auto"


class ServeASRTranscription(BaseModel):
    text: str
    duration: float
    huge_gap: bool


class ServeASRSegment(BaseModel):
    text: str
    start: float
    end: float


class ServeTimedASRResponse(BaseModel):
    text: str
    segments: list[ServeASRSegment]
    duration: float


class ServeASRResponse(BaseModel):
    transcriptions: list[ServeASRTranscription]


class ServeRequest(BaseModel):
    # Raw content sequence dict that we can use with ContentSequence(**content)
    content: dict
@@ -86,18 +57,6 @@ class ServeVQGANDecodeResponse(BaseModel):
    audios: list[bytes]


class ServeStreamDelta(BaseModel):
    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None


class ServeStreamResponse(BaseModel):
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None


class ServeReferenceAudio(BaseModel):
    audio: bytes
    text: str

View File

@@ -36,9 +36,7 @@ dependencies = [
    "zstandard>=0.22.0",
    "pydub",
    "pyaudio",
    "faster_whisper",
    "modelscope==1.17.1",
    "funasr==1.1.5",
    "opencc-python-reimplemented==0.1.7",
    "silero-vad",
    "ormsgpack",

View File

@@ -85,7 +85,6 @@ class API(ExceptionHandler):
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            asr_enabled=self.args.load_asr_model,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,

View File

@@ -14,7 +14,6 @@ from tools.server.inference import inference_wrapper as inference
def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,

View File

@@ -1,5 +1,4 @@
import torch
from funasr import AutoModel
from loguru import logger

from fish_speech.inference_engine import TTSInferenceEngine
@@ -8,8 +7,6 @@ from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference

ASR_MODEL_NAME = "iic/SenseVoiceSmall"


class ModelManager:
    def __init__(
@@ -61,15 +58,6 @@ class ModelManager:
        if self.mode == "tts":
            self.warm_up(self.tts_inference_engine)

    def load_asr_model(self, device, hub="ms") -> None:
        self.asr_model = AutoModel(
            model=ASR_MODEL_NAME,
            device=device,
            disable_pbar=True,
            hub=hub,
        )
        logger.info("ASR model loaded.")

    def load_llama_model(
        self, checkpoint_path, device, precision, compile, mode
    ) -> None:

View File

@@ -80,50 +80,3 @@ def batch_vqgan_decode(model, features):
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]


@torch.no_grad()
def batch_asr(model, lock, audios, sr, language="auto"):
    resampled_audios = []
    for audio in audios:
        audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    with lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # If there is a gap of more than 4 seconds, we consider it as a huge gap
                if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
                    huge_gap = True
                    break

            # Doesn't make sense to have a huge gap at the end
            if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results

View File

@@ -20,8 +20,6 @@ from loguru import logger
from typing_extensions import Annotated

from fish_speech.utils.schema import (
    ServeASRRequest,
    ServeASRResponse,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
@@ -95,33 +93,6 @@ async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    )


@routes.http.post("/v1/asr")
async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    asr_model = model_manager.asr_model
    lock = request.app.state.lock

    # Perform ASR
    start_time = time.time()
    audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
    audios = [torch.from_numpy(audio).float() for audio in audios]

    if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
        raise HTTPException(status_code=400, content="Audio length is too long")

    transcriptions = batch_asr(
        asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
    )
    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

    # Return the response
    return ormsgpack.packb(
        ServeASRResponse(transcriptions=transcriptions),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )


@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    # Get the model from the app

View File

@@ -1,60 +0,0 @@
import random
from multiprocessing import Pool
from pathlib import Path

import click
import librosa
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files

threshold = 10 ** (-50 / 20.0)


def process(file):
    waveform, sample_rate = torchaudio.load(str(file), backend="sox")
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    loudness = librosa.feature.rms(
        y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
    )[0]

    for i in range(len(loudness) - 1, 0, -1):
        if loudness[i] > threshold:
            break

    end_silent_time = (len(loudness) - i) * 512 / sample_rate

    if end_silent_time <= 0.3:
        random_time = random.uniform(0.3, 0.7) - end_silent_time
        waveform = F.pad(
            waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
        )

    for i in range(len(loudness)):
        if loudness[i] > threshold:
            break

    start_silent_time = i * 512 / sample_rate

    if start_silent_time > 0.02:
        waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]

    torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)


@click.command()
@click.argument("source", type=Path)
@click.option("--num-workers", type=int, default=12)
def main(source, num_workers):
    files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))

    with Pool(num_workers) as p:
        list(tqdm(p.imap_unordered(process, files), total=len(files)))


if __name__ == "__main__":
    main()

View File

@@ -1,176 +0,0 @@
"""
Used to transcribe all audio files in one folder into another folder.

e.g.
Directory structure:
--pre_data_root
----SP_1
------01.wav
------02.wav
------......
----SP_2
------01.wav
------02.wav
------......

Use
python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
to transcribe the first speaker.

Use
python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
to transcribe the second speaker.

Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
"""

import re
from pathlib import Path

import click
import soundfile as sf
from faster_whisper import WhisperModel
from loguru import logger
from pydub import AudioSegment
from tqdm import tqdm

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files


@click.command()
@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
@click.option(
    "--compute-type",
    default="float16",
    help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
)
@click.option("--audio-dir", required=True, help="Directory containing audio files")
@click.option(
    "--save-dir", required=True, help="Directory to save processed audio files"
)
@click.option(
    "--sample-rate",
    default=44100,
    type=int,
    help="Output sample rate, default to input sample rate",
)
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
@click.option("--language", default="auto", help="Language of the transcription")
@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
def main(
    model_size,
    compute_type,
    audio_dir,
    save_dir,
    sample_rate,
    device,
    language,
    initial_prompt,
):
    logger.info("Loading / Downloading Faster Whisper model...")

    model = WhisperModel(
        model_size,
        device=device,
        compute_type=compute_type,
        download_root="faster_whisper",
    )

    logger.info("Model loaded.")

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    audio_files = list_files(
        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
    )

    for file_path in tqdm(audio_files, desc="Processing audio file"):
        file_stem = file_path.stem
        file_suffix = file_path.suffix

        rel_path = Path(file_path).relative_to(audio_dir)
        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)

        audio = AudioSegment.from_file(file_path)

        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            language=None if language == "auto" else language,
            initial_prompt=initial_prompt,
        )

        print(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )
        print("Total len(ms): ", len(audio))

        whole_text = None
        for segment in segments:
            id, start, end, text = (
                segment.id,
                segment.start,
                segment.end,
                segment.text,
            )
            print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
            if not whole_text:
                whole_text = text
            else:
                whole_text += ", " + text

        whole_text += "."

        audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
        audio.export(audio_save_path, format=file_suffix[1:])
        print(f"Exported {audio_save_path}")

        transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
        with open(
            transcript_save_path,
            "w",
            encoding="utf-8",
        ) as f:
            f.write(whole_text)


if __name__ == "__main__":
    main()
    exit(0)

    audio = AudioSegment.from_wav(
        r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
    )

    model_size = "large-v3"

    model = WhisperModel(
        model_size,
        device="cuda",
        compute_type="float16",
        download_root="faster_whisper",
    )

    segments, info = model.transcribe(
        r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
        beam_size=5,
    )

    print(
        "Detected language '%s' with probability %f"
        % (info.language, info.language_probability)
    )
    print("Total len(ms): ", len(audio))

    for i, segment in enumerate(segments):
        print(
            "Segment %03d [%.2fs -> %.2fs] %s"
            % (i, segment.start, segment.end, segment.text)
        )
        start_ms = int(segment.start * 1000)
        end_ms = int(segment.end * 1000)
        segment_audio = audio[start_ms:end_ms]
        segment_audio.export(f"segment_{i:03d}.wav", format="wav")
        print(f"Exported segment_{i:03d}.wav")

    print("All segments have been exported.")

uv.lock (generated)
View File

@@ -945,8 +945,6 @@ dependencies = [
    { name = "descript-audiotools" },
    { name = "einops" },
    { name = "einx", extra = ["torch"] },
    { name = "faster-whisper" },
    { name = "funasr" },
    { name = "gradio" },
    { name = "grpcio" },
    { name = "hydra-core" },