Handle adaptive number of codebooks
This commit is contained in:
parent
3ec16024dd
commit
5707699dfd
0
.project-root
Normal file
0
.project-root
Normal file
@ -26,11 +26,13 @@ train_dataset:
|
||||
_target_: fish_speech.datasets.text.AutoAugTextDataset
|
||||
tokenizer: ${tokenizer}
|
||||
max_length: ${max_length}
|
||||
num_codebooks: ${model.model.config.num_codebooks}
|
||||
|
||||
val_dataset:
|
||||
_target_: fish_speech.datasets.text.AutoAugTextDataset
|
||||
tokenizer: ${tokenizer}
|
||||
max_length: ${max_length}
|
||||
num_codebooks: ${model.model.config.num_codebooks}
|
||||
|
||||
data:
|
||||
_target_: fish_speech.datasets.text.TextDataModule
|
||||
|
@ -27,11 +27,13 @@ train_dataset:
|
||||
_target_: fish_speech.datasets.text.AutoAugTextDataset
|
||||
tokenizer: ${tokenizer}
|
||||
max_length: ${max_length}
|
||||
num_codebooks: ${model.model.config.num_codebooks}
|
||||
|
||||
val_dataset:
|
||||
_target_: fish_speech.datasets.text.AutoAugTextDataset
|
||||
tokenizer: ${tokenizer}
|
||||
max_length: ${max_length}
|
||||
num_codebooks: ${model.model.config.num_codebooks}
|
||||
|
||||
data:
|
||||
_target_: fish_speech.datasets.text.TextDataModule
|
||||
|
@ -2,7 +2,7 @@ defaults:
|
||||
- base
|
||||
- _self_
|
||||
|
||||
project: text2semantic_pretrain_400m_8_codebooks
|
||||
project: text2semantic_pretrain_400m_4_codebooks
|
||||
max_length: 2048
|
||||
|
||||
# Lightning Trainer
|
||||
@ -24,6 +24,7 @@ train_dataset:
|
||||
_target_: fish_speech.datasets.text.AutoAugTextDataset
|
||||
tokenizer: ${tokenizer}
|
||||
max_length: ${max_length}
|
||||
num_codebooks: ${model.model.config.num_codebooks}
|
||||
use_speaker: false
|
||||
phones_prob: 0.5
|
||||
interactive_prob: 0.5
|
||||
@ -32,6 +33,7 @@ val_dataset:
|
||||
_target_: fish_speech.datasets.text.AutoAugTextDataset
|
||||
tokenizer: ${tokenizer}
|
||||
max_length: ${max_length}
|
||||
num_codebooks: ${model.model.config.num_codebooks}
|
||||
use_speaker: false
|
||||
phones_prob: 0.5
|
||||
interactive_prob: 0.5
|
||||
@ -61,7 +63,7 @@ model:
|
||||
dim: 1024
|
||||
rope_base: 10000
|
||||
norm_eps: 1e-5
|
||||
num_codebooks: 8 # single codebook
|
||||
num_codebooks: 4 # single codebook
|
||||
codebook_size: 264 # codebook size 256 + 2 special tokens
|
||||
dropout: 0.1
|
||||
neft_alpha: 10
|
||||
|
@ -24,6 +24,7 @@ train_dataset:
|
||||
_target_: fish_speech.datasets.text.AutoAugTextDataset
|
||||
tokenizer: ${tokenizer}
|
||||
max_length: ${max_length}
|
||||
num_codebooks: ${model.model.config.num_codebooks}
|
||||
use_speaker: true
|
||||
phones_prob: 0.5
|
||||
interactive_prob: 0.5
|
||||
@ -33,6 +34,7 @@ val_dataset:
|
||||
_target_: fish_speech.datasets.text.AutoAugTextDataset
|
||||
tokenizer: ${tokenizer}
|
||||
max_length: ${max_length}
|
||||
num_codebooks: ${model.model.config.num_codebooks}
|
||||
use_speaker: true
|
||||
phones_prob: 0.5
|
||||
interactive_prob: 0.5
|
||||
@ -50,7 +52,6 @@ data:
|
||||
# Model Configuration
|
||||
model:
|
||||
_target_: fish_speech.models.text2semantic.TextToSemantic
|
||||
use_dpo: true
|
||||
|
||||
model:
|
||||
# ~ 130M parameters, for debug purpose
|
||||
|
@ -198,6 +198,7 @@ class AutoAugTextDataset(IterableDataset):
|
||||
causual: bool = True,
|
||||
mix_text_phone_prob: float = 0.5,
|
||||
use_negative_samples: bool = False,
|
||||
num_codebooks: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -214,6 +215,7 @@ class AutoAugTextDataset(IterableDataset):
|
||||
causual: use causual sampling when using local data, disable will lead to random sampling
|
||||
mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
|
||||
use_negative_samples: generate negative samples
|
||||
num_codebooks: number of codebooks, if None, it will be automatically detected
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@ -235,6 +237,7 @@ class AutoAugTextDataset(IterableDataset):
|
||||
self.causual = causual
|
||||
self.mix_text_phone_prob = mix_text_phone_prob
|
||||
self.use_negative_samples = use_negative_samples
|
||||
self.num_codebooks = num_codebooks
|
||||
|
||||
if use_data_server is True:
|
||||
self.channel = grpc.insecure_channel(server)
|
||||
@ -484,7 +487,9 @@ class AutoAugTextDataset(IterableDataset):
|
||||
)
|
||||
semantic_length = sum([len(i[0].values) for i in semantics])
|
||||
prompt_length = len(encoded)
|
||||
num_codebooks = len(semantics[0])
|
||||
num_codebooks = (
|
||||
len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
|
||||
)
|
||||
|
||||
bos_bias = 1 if add_bos else 0
|
||||
|
||||
@ -505,7 +510,7 @@ class AutoAugTextDataset(IterableDataset):
|
||||
for i in range(num_codebooks)
|
||||
]
|
||||
for segment in semantics:
|
||||
for book_idx, book in enumerate(segment):
|
||||
for book_idx, book in zip(range(num_codebooks), segment):
|
||||
for j in book.values:
|
||||
codes[book_idx].append(int(j) + 2)
|
||||
|
||||
@ -520,8 +525,7 @@ class AutoAugTextDataset(IterableDataset):
|
||||
|
||||
# Mask out the <s> tokens for semantic, predict semantic tokens only
|
||||
# Since we don't mask out the input tokens, the language modeling still works
|
||||
# labels[1:, : (prompt_length + bos_bias)] = -100
|
||||
labels[:, : (prompt_length + bos_bias)] = -100
|
||||
labels[1:, : (prompt_length + bos_bias)] = -100
|
||||
|
||||
tokens = tokens[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
@ -677,6 +681,7 @@ if __name__ == "__main__":
|
||||
interactive_prob=1.0,
|
||||
phones_prob=1.0,
|
||||
use_negative_samples=False,
|
||||
num_codebooks=4,
|
||||
)
|
||||
|
||||
# ds = AutoAugTextDataset(
|
||||
|
@ -1,7 +1,9 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import hydra
|
||||
import lightning as L
|
||||
import pyrootutils
|
||||
import torch
|
||||
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
||||
from lightning.pytorch.loggers import Logger
|
||||
@ -9,6 +11,13 @@ from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import fish_speech.utils as utils
|
||||
|
||||
os.environ.pop("SLURM_NTASKS", None)
|
||||
os.environ.pop("SLURM_JOB_NAME", None)
|
||||
os.environ.pop("SLURM_NTASKS_PER_NODE", None)
|
||||
|
||||
# register eval resolver and root
|
||||
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
# Allow TF32 on Ampere GPUs
|
||||
torch.set_float32_matmul_precision("high")
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user