Handle adaptive number of codebooks

This commit is contained in:
Lengyue 2024-03-02 01:45:29 +00:00
parent 3ec16024dd
commit 5707699dfd
7 changed files with 28 additions and 7 deletions

0
.project-root Normal file
View File

View File

@ -26,11 +26,13 @@ train_dataset:
_target_: fish_speech.datasets.text.AutoAugTextDataset
tokenizer: ${tokenizer}
max_length: ${max_length}
num_codebooks: ${model.model.config.num_codebooks}
val_dataset:
_target_: fish_speech.datasets.text.AutoAugTextDataset
tokenizer: ${tokenizer}
max_length: ${max_length}
num_codebooks: ${model.model.config.num_codebooks}
data:
_target_: fish_speech.datasets.text.TextDataModule

View File

@ -27,11 +27,13 @@ train_dataset:
_target_: fish_speech.datasets.text.AutoAugTextDataset
tokenizer: ${tokenizer}
max_length: ${max_length}
num_codebooks: ${model.model.config.num_codebooks}
val_dataset:
_target_: fish_speech.datasets.text.AutoAugTextDataset
tokenizer: ${tokenizer}
max_length: ${max_length}
num_codebooks: ${model.model.config.num_codebooks}
data:
_target_: fish_speech.datasets.text.TextDataModule

View File

@ -2,7 +2,7 @@ defaults:
- base
- _self_
project: text2semantic_pretrain_400m_8_codebooks
project: text2semantic_pretrain_400m_4_codebooks
max_length: 2048
# Lightning Trainer
@ -24,6 +24,7 @@ train_dataset:
_target_: fish_speech.datasets.text.AutoAugTextDataset
tokenizer: ${tokenizer}
max_length: ${max_length}
num_codebooks: ${model.model.config.num_codebooks}
use_speaker: false
phones_prob: 0.5
interactive_prob: 0.5
@ -32,6 +33,7 @@ val_dataset:
_target_: fish_speech.datasets.text.AutoAugTextDataset
tokenizer: ${tokenizer}
max_length: ${max_length}
num_codebooks: ${model.model.config.num_codebooks}
use_speaker: false
phones_prob: 0.5
interactive_prob: 0.5
@ -61,7 +63,7 @@ model:
dim: 1024
rope_base: 10000
norm_eps: 1e-5
num_codebooks: 8 # single codebook
num_codebooks: 4 # single codebook
codebook_size: 264 # codebook size 256 + 2 special tokens
dropout: 0.1
neft_alpha: 10

View File

@ -24,6 +24,7 @@ train_dataset:
_target_: fish_speech.datasets.text.AutoAugTextDataset
tokenizer: ${tokenizer}
max_length: ${max_length}
num_codebooks: ${model.model.config.num_codebooks}
use_speaker: true
phones_prob: 0.5
interactive_prob: 0.5
@ -33,6 +34,7 @@ val_dataset:
_target_: fish_speech.datasets.text.AutoAugTextDataset
tokenizer: ${tokenizer}
max_length: ${max_length}
num_codebooks: ${model.model.config.num_codebooks}
use_speaker: true
phones_prob: 0.5
interactive_prob: 0.5
@ -50,7 +52,6 @@ data:
# Model Configuration
model:
_target_: fish_speech.models.text2semantic.TextToSemantic
use_dpo: true
model:
# ~ 130M parameters, for debug purpose

View File

@ -198,6 +198,7 @@ class AutoAugTextDataset(IterableDataset):
causual: bool = True,
mix_text_phone_prob: float = 0.5,
use_negative_samples: bool = False,
num_codebooks: Optional[int] = None,
):
"""
Args:
@ -214,6 +215,7 @@ class AutoAugTextDataset(IterableDataset):
causual: use causual sampling when using local data, disable will lead to random sampling
mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
use_negative_samples: generate negative samples
num_codebooks: number of codebooks, if None, it will be automatically detected
"""
super().__init__()
@ -235,6 +237,7 @@ class AutoAugTextDataset(IterableDataset):
self.causual = causual
self.mix_text_phone_prob = mix_text_phone_prob
self.use_negative_samples = use_negative_samples
self.num_codebooks = num_codebooks
if use_data_server is True:
self.channel = grpc.insecure_channel(server)
@ -484,7 +487,9 @@ class AutoAugTextDataset(IterableDataset):
)
semantic_length = sum([len(i[0].values) for i in semantics])
prompt_length = len(encoded)
num_codebooks = len(semantics[0])
num_codebooks = (
len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
)
bos_bias = 1 if add_bos else 0
@ -505,7 +510,7 @@ class AutoAugTextDataset(IterableDataset):
for i in range(num_codebooks)
]
for segment in semantics:
for book_idx, book in enumerate(segment):
for book_idx, book in zip(range(num_codebooks), segment):
for j in book.values:
codes[book_idx].append(int(j) + 2)
@ -520,8 +525,7 @@ class AutoAugTextDataset(IterableDataset):
# Mask out the <s> tokens for semantic, predict semantic tokens only
# Since we don't mask out the input tokens, the language modeling still works
# labels[1:, : (prompt_length + bos_bias)] = -100
labels[:, : (prompt_length + bos_bias)] = -100
labels[1:, : (prompt_length + bos_bias)] = -100
tokens = tokens[:, :-1]
labels = labels[:, 1:]
@ -677,6 +681,7 @@ if __name__ == "__main__":
interactive_prob=1.0,
phones_prob=1.0,
use_negative_samples=False,
num_codebooks=4,
)
# ds = AutoAugTextDataset(

View File

@ -1,7 +1,9 @@
import os
from typing import Optional
import hydra
import lightning as L
import pyrootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
@ -9,6 +11,13 @@ from omegaconf import DictConfig, OmegaConf
import fish_speech.utils as utils
os.environ.pop("SLURM_NTASKS", None)
os.environ.pop("SLURM_JOB_NAME", None)
os.environ.pop("SLURM_NTASKS_PER_NODE", None)
# register eval resolver and root
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True