Remove unused code. (#1016)
* Remove unused code.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Remove rest asr code.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
94f9fa6c43
commit
dbec3212ef
@@ -21,7 +21,6 @@ from fish_speech.content_sequence import (
    TextPart,
    VQPart,
)
from fish_speech.text import split_text
from fish_speech.tokenizer import IM_END_TOKEN

os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -1,202 +0,0 @@
from typing import Any, Optional

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.models.text2semantic.llama import NaiveTransformer

log = utils.RankedLogger(__name__, rank_zero_only=True)


class TextToSemantic(L.LightningModule):
    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        # Save only LoRA parameters
        state_dict = checkpoint["state_dict"]
        use_lora = any("lora" in name for name in state_dict.keys())
        if not use_lora:
            return

        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Get weight decay parameters
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        is_train = stage == "train"

        if is_train:
            # Key part to make lora work
            # Otherwise the parameters are merged, which lead to incorrect gradients
            self.model.train()

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        return loss

    def get_accuracy(self, logits, labels):
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=logits.device)

        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~mask] = 0
        correct = correct.sum()
        accuracy = correct / mask.sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

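Note on the removed loss computation above: the labels tensor stacks the text-token row and the codebook rows along dim 1, and `.mT` swaps the last two dims so the codebook labels line up with codebook_logits before both are flattened for cross entropy. A minimal shape sketch with made-up dimensions (batch=2, seq=5, codebooks=4, vocab=32; these numbers are illustrative and not part of this diff):

import torch
import torch.nn.functional as F

token_logits = torch.randn(2, 5, 32)          # text-token head: (batch, seq, vocab)
codebook_logits = torch.randn(2, 5, 4, 32)    # codebook head: (batch, seq, codebooks, vocab)
labels = torch.randint(0, 32, (2, 1 + 4, 5))  # row 0 = text labels, rows 1..4 = codebook labels

base_loss = F.cross_entropy(
    token_logits.view(-1, 32), labels[:, 0].reshape(-1), ignore_index=-100
)
codebook_labels = labels[:, 1 : 1 + 4].mT      # (2, 4, 5) -> (2, 5, 4), matching codebook_logits
semantic_loss = F.cross_entropy(
    codebook_logits.view(-1, 32), codebook_labels.reshape(-1), ignore_index=-100
)
loss = base_loss + semantic_loss
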
@@ -1,4 +1,3 @@
from .clean import clean_text
from .spliter import split_text

__all__ = ["clean_text", "split_text"]
__all__ = ["clean_text"]

@@ -1,130 +0,0 @@
import re
import string

from fish_speech.text.clean import clean_text


def utf_8_len(text: str):
    return len(text.encode("utf-8"))


def break_text(texts, length, splits: set):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if char in splits:
                yield curr
                curr = ""

        if curr:
            yield curr


def break_text_by_length(texts, length):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if utf_8_len(curr) >= length:
                yield curr
                curr = ""

        if curr:
            yield curr


def add_cleaned(curr, segments):
    curr = curr.strip()
    if curr and not all(c.isspace() or c in string.punctuation for c in curr):
        segments.append(curr)


def protect_float(text):
    # Turns 3.14 into <3_f_14> to prevent splitting
    return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)


def unprotect_float(text):
    # Turns <3_f_14> into 3.14
    return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)


def split_text(text, length):
    text = clean_text(text)

    # Break the text into pieces with following rules:
    # 1. Split the text at ".", "!", "?" if text is NOT a float
    # 2. If the text is longer than length, split at ","
    # 3. If the text is still longer than length, split at " "
    # 4. If the text is still longer than length, split at any character to length

    texts = [text]
    texts = map(protect_float, texts)
    texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
    texts = map(unprotect_float, texts)
    texts = break_text(texts, length, {",", ","})
    texts = break_text(texts, length, {" "})
    texts = list(break_text_by_length(texts, length))

    # Then, merge the texts into segments with length <= length
    segments = []
    curr = ""

    for text in texts:
        if utf_8_len(curr) + utf_8_len(text) <= length:
            curr += text
        else:
            add_cleaned(curr, segments)
            curr = text

    if curr:
        add_cleaned(curr, segments)

    return segments


if __name__ == "__main__":
    # Test the split_text function

    text = "This is a test sentence. This is another test sentence. And a third one."

    assert split_text(text, 50) == [
        "This is a test sentence.",
        "This is another test sentence. And a third one.",
    ]
    assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
    assert split_text(" ", 10) == []
    assert split_text("a", 10) == ["a"]

    text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
    assert split_text(text, 50) == [
        "This is a test sentence with only commas,",
        "and no dots, and no exclamation marks,",
        "and no question marks, and no newlines.",
    ]

    text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
    # First half split at " ", second half split at ","
    assert split_text(text, 50) == [
        "This is a test sentence This is a test sentence",
        "This is a test sentence. This is a test sentence,",
        "This is a test sentence, This is a test sentence.",
    ]

    text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
    assert split_text(text, 50) == [
        "这是一段很长的中文文本,",
        "而且没有句号,也没有感叹号,",
        "也没有问号,也没有换行符.",
    ]

@@ -27,35 +27,6 @@ class ServeAudioPart(BaseModel):
    audio: bytes


class ServeASRRequest(BaseModel):
    # The audio should be an uncompressed PCM float16 audio
    audios: list[bytes]
    sample_rate: int = 44100
    language: Literal["zh", "en", "ja", "auto"] = "auto"


class ServeASRTranscription(BaseModel):
    text: str
    duration: float
    huge_gap: bool


class ServeASRSegment(BaseModel):
    text: str
    start: float
    end: float


class ServeTimedASRResponse(BaseModel):
    text: str
    segments: list[ServeASRSegment]
    duration: float


class ServeASRResponse(BaseModel):
    transcriptions: list[ServeASRTranscription]


class ServeRequest(BaseModel):
    # Raw content sequence dict that we can use with ContentSequence(**content)
    content: dict

@@ -86,18 +57,6 @@ class ServeVQGANDecodeResponse(BaseModel):
    audios: list[bytes]


class ServeStreamDelta(BaseModel):
    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None


class ServeStreamResponse(BaseModel):
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None


class ServeReferenceAudio(BaseModel):
    audio: bytes
    text: str

@@ -36,9 +36,7 @@ dependencies = [
    "zstandard>=0.22.0",
    "pydub",
    "pyaudio",
    "faster_whisper",
    "modelscope==1.17.1",
    "funasr==1.1.5",
    "opencc-python-reimplemented==0.1.7",
    "silero-vad",
    "ormsgpack",

@@ -85,7 +85,6 @@ class API(ExceptionHandler):
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            asr_enabled=self.args.load_asr_model,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,

@@ -14,7 +14,6 @@ from tools.server.inference import inference_wrapper as inference
def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,

@@ -1,5 +1,4 @@
import torch
from funasr import AutoModel
from loguru import logger

from fish_speech.inference_engine import TTSInferenceEngine
@@ -8,8 +7,6 @@ from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference

ASR_MODEL_NAME = "iic/SenseVoiceSmall"


class ModelManager:
    def __init__(
@@ -61,15 +58,6 @@ class ModelManager:
        if self.mode == "tts":
            self.warm_up(self.tts_inference_engine)

    def load_asr_model(self, device, hub="ms") -> None:
        self.asr_model = AutoModel(
            model=ASR_MODEL_NAME,
            device=device,
            disable_pbar=True,
            hub=hub,
        )
        logger.info("ASR model loaded.")

    def load_llama_model(
        self, checkpoint_path, device, precision, compile, mode
    ) -> None:

@@ -80,50 +80,3 @@ def batch_vqgan_decode(model, features):
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]


@torch.no_grad()
def batch_asr(model, lock, audios, sr, language="auto"):
    resampled_audios = []
    for audio in audios:
        audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    with lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # If there is a gap of more than 4 seconds, we consider it as a huge gap
                if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
                    huge_gap = True
                    break

            # Doesn't make sense to have a huge gap at the end
            if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results

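For reference, the removed huge_gap check above flags a transcription whenever two consecutive timestamps are separated by more than HUGE_GAP_THRESHOLD ("more than 4 seconds" per the inline comment). A minimal sketch with dummy timestamps; the constant itself is defined outside this diff, so the millisecond value below is an assumption:

HUGE_GAP_THRESHOLD = 4000  # ms; assumed value, the real constant is not shown in this diff

# (start_ms, end_ms) pairs for consecutive segments
timestamps = [(0, 1200), (1300, 2500), (7100, 8000)]

huge_gap = any(
    nxt[0] - cur[1] > HUGE_GAP_THRESHOLD
    for cur, nxt in zip(timestamps[:-1], timestamps[1:])
)
print(huge_gap)  # True: 7100 - 2500 = 4600 ms between the last two segments
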
@@ -20,8 +20,6 @@ from loguru import logger
from typing_extensions import Annotated

from fish_speech.utils.schema import (
    ServeASRRequest,
    ServeASRResponse,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
@@ -95,33 +93,6 @@ async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    )


@routes.http.post("/v1/asr")
async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    asr_model = model_manager.asr_model
    lock = request.app.state.lock

    # Perform ASR
    start_time = time.time()
    audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
    audios = [torch.from_numpy(audio).float() for audio in audios]

    if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
        raise HTTPException(status_code=400, content="Audio length is too long")

    transcriptions = batch_asr(
        asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
    )
    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

    # Return the response
    return ormsgpack.packb(
        ServeASRResponse(transcriptions=transcriptions),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )


@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    # Get the model from the app

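For context, the removed /v1/asr route above accepted a ServeASRRequest (raw float16 PCM bytes in `audios`, plus `sample_rate` and `language`) and answered with an ormsgpack-packed ServeASRResponse. A rough client sketch, under the assumption that requests are msgpack-encoded like the other fish-speech endpoints; the host, port, and content-type header are illustrative:

import numpy as np
import ormsgpack
import requests  # assumed client-side dependency, not part of this repo

# One second of silence as float16 PCM (dummy data), well under the 30 s limit enforced above
pcm = np.zeros(16000, dtype=np.float16).tobytes()
payload = {"audios": [pcm], "sample_rate": 16000, "language": "auto"}

resp = requests.post(
    "http://127.0.0.1:8080/v1/asr",                    # illustrative host/port
    data=ormsgpack.packb(payload),
    headers={"Content-Type": "application/msgpack"},   # assumed content type
)
print(ormsgpack.unpackb(resp.content)["transcriptions"])
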
@@ -1,60 +0,0 @@
import random
from multiprocessing import Pool
from pathlib import Path

import click
import librosa
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files

threshold = 10 ** (-50 / 20.0)


def process(file):
    waveform, sample_rate = torchaudio.load(str(file), backend="sox")
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    loudness = librosa.feature.rms(
        y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
    )[0]

    for i in range(len(loudness) - 1, 0, -1):
        if loudness[i] > threshold:
            break

    end_silent_time = (len(loudness) - i) * 512 / sample_rate

    if end_silent_time <= 0.3:
        random_time = random.uniform(0.3, 0.7) - end_silent_time
        waveform = F.pad(
            waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
        )

    for i in range(len(loudness)):
        if loudness[i] > threshold:
            break

    start_silent_time = i * 512 / sample_rate

    if start_silent_time > 0.02:
        waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]

    torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)


@click.command()
@click.argument("source", type=Path)
@click.option("--num-workers", type=int, default=12)
def main(source, num_workers):
    files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))

    with Pool(num_workers) as p:
        list(tqdm(p.imap_unordered(process, files), total=len(files)))


if __name__ == "__main__":
    main()

@@ -1,176 +0,0 @@
"""
Used to transcribe all audio files in one folder into another folder.
e.g.
Directory structure:
--pre_data_root
----SP_1
------01.wav
------02.wav
------......
----SP_2
------01.wav
------02.wav
------......

Use
python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
to transcribe the first speaker.

Use
python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
to transcribe the second speaker.

Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
"""

import re
from pathlib import Path

import click
import soundfile as sf
from faster_whisper import WhisperModel
from loguru import logger
from pydub import AudioSegment
from tqdm import tqdm

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files


@click.command()
@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
@click.option(
    "--compute-type",
    default="float16",
    help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
)
@click.option("--audio-dir", required=True, help="Directory containing audio files")
@click.option(
    "--save-dir", required=True, help="Directory to save processed audio files"
)
@click.option(
    "--sample-rate",
    default=44100,
    type=int,
    help="Output sample rate, default to input sample rate",
)
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
@click.option("--language", default="auto", help="Language of the transcription")
@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
def main(
    model_size,
    compute_type,
    audio_dir,
    save_dir,
    sample_rate,
    device,
    language,
    initial_prompt,
):
    logger.info("Loading / Downloading Faster Whisper model...")

    model = WhisperModel(
        model_size,
        device=device,
        compute_type=compute_type,
        download_root="faster_whisper",
    )

    logger.info("Model loaded.")

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    audio_files = list_files(
        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
    )

    for file_path in tqdm(audio_files, desc="Processing audio file"):
        file_stem = file_path.stem
        file_suffix = file_path.suffix

        rel_path = Path(file_path).relative_to(audio_dir)
        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)

        audio = AudioSegment.from_file(file_path)

        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            language=None if language == "auto" else language,
            initial_prompt=initial_prompt,
        )

        print(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )
        print("Total len(ms): ", len(audio))

        whole_text = None
        for segment in segments:
            id, start, end, text = (
                segment.id,
                segment.start,
                segment.end,
                segment.text,
            )
            print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
            if not whole_text:
                whole_text = text
            else:
                whole_text += ", " + text

        whole_text += "."

        audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
        audio.export(audio_save_path, format=file_suffix[1:])
        print(f"Exported {audio_save_path}")

        transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
        with open(
            transcript_save_path,
            "w",
            encoding="utf-8",
        ) as f:
            f.write(whole_text)


if __name__ == "__main__":
    main()
    exit(0)

    audio = AudioSegment.from_wav(
        r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
    )

    model_size = "large-v3"

    model = WhisperModel(
        model_size,
        device="cuda",
        compute_type="float16",
        download_root="faster_whisper",
    )

    segments, info = model.transcribe(
        r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
        beam_size=5,
    )

    print(
        "Detected language '%s' with probability %f"
        % (info.language, info.language_probability)
    )
    print("Total len(ms): ", len(audio))

    for i, segment in enumerate(segments):
        print(
            "Segment %03d [%.2fs -> %.2fs] %s"
            % (i, segment.start, segment.end, segment.text)
        )
        start_ms = int(segment.start * 1000)
        end_ms = int(segment.end * 1000)
        segment_audio = audio[start_ms:end_ms]
        segment_audio.export(f"segment_{i:03d}.wav", format="wav")
        print(f"Exported segment_{i:03d}.wav")

    print("All segments have been exported.")