523 lines
19 KiB
Python
523 lines
19 KiB
Python
import math
|
||
import os
|
||
import sys
|
||
|
||
import numpy as np
|
||
import torch
|
||
from torch import nn
|
||
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||
from torch.nn import functional as F
|
||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||
|
||
from . import attentions, commons, modules
|
||
from .commons import get_padding, init_weights
|
||
from .modules import (CausalConvTranspose1d, ConvNext2d, DilatedCausalConv1d,
|
||
LoRALinear1d, ResBlock1, WaveConv1D)
|
||
|
||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
sys.path.append(parent_dir)
|
||
|
||
|
||
class TextEncoder(nn.Module):
|
||
def __init__(
|
||
self,
|
||
out_channels: int,
|
||
hidden_channels: int,
|
||
filter_channels: int,
|
||
emb_channels: int,
|
||
gin_channels: int,
|
||
n_heads: int,
|
||
n_layers: int,
|
||
kernel_size: int,
|
||
p_dropout: int,
|
||
f0: bool = True,
|
||
):
|
||
super().__init__()
|
||
self.out_channels = out_channels
|
||
self.hidden_channels = hidden_channels
|
||
self.filter_channels = filter_channels
|
||
self.emb_channels = emb_channels
|
||
self.n_heads = n_heads
|
||
self.n_layers = n_layers
|
||
self.kernel_size = kernel_size
|
||
self.p_dropout = p_dropout
|
||
self.emb_phone = nn.Linear(emb_channels, hidden_channels)
|
||
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
|
||
if f0 == True:
|
||
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
|
||
self.emb_g = nn.Conv1d(gin_channels, hidden_channels, 1)
|
||
self.encoder = attentions.Encoder(
|
||
hidden_channels, filter_channels, gin_channels, n_heads, n_layers, kernel_size, p_dropout
|
||
)
|
||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||
|
||
def forward(self, phone, pitch, lengths, g):
|
||
if pitch == None:
|
||
x = self.emb_phone(phone)
|
||
else:
|
||
x = self.emb_phone(phone) + self.emb_pitch(pitch)
|
||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||
x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
|
||
x.dtype
|
||
)
|
||
x = self.encoder(x * x_mask, x_mask, g)
|
||
x = self.proj(x)
|
||
|
||
return x, None, x_mask
|
||
|
||
|
||
class SineGen(torch.nn.Module):
|
||
"""Definition of sine generator
|
||
SineGen(samp_rate, harmonic_num = 0,
|
||
sine_amp = 0.1, noise_std = 0.003,
|
||
voiced_threshold = 0,
|
||
flag_for_pulse=False)
|
||
samp_rate: sampling rate in Hz
|
||
harmonic_num: number of harmonic overtones (default 0)
|
||
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||
noise_std: std of Gaussian noise (default 0.003)
|
||
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||
segment is always sin(np.pi) or cos(0)
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
samp_rate,
|
||
harmonic_num=0,
|
||
sine_amp=0.1,
|
||
noise_std=0.003,
|
||
voiced_threshold=0,
|
||
flag_for_pulse=False,
|
||
):
|
||
super(SineGen, self).__init__()
|
||
self.sine_amp = sine_amp
|
||
self.noise_std = noise_std
|
||
self.harmonic_num = harmonic_num
|
||
self.dim = self.harmonic_num + 1
|
||
self.sampling_rate = samp_rate
|
||
self.voiced_threshold = voiced_threshold
|
||
|
||
def _f02uv(self, f0):
|
||
# generate uv signal
|
||
uv = torch.ones_like(f0)
|
||
uv = uv * (f0 > self.voiced_threshold)
|
||
return uv
|
||
|
||
def forward(self, f0, upp):
|
||
"""sine_tensor, uv = forward(f0)
|
||
input F0: tensor(batchsize=1, length, dim=1)
|
||
f0 for unvoiced steps should be 0
|
||
output sine_tensor: tensor(batchsize=1, length, dim)
|
||
output uv: tensor(batchsize=1, length, 1)
|
||
"""
|
||
with torch.no_grad():
|
||
f0 = f0[:, None].transpose(1, 2)
|
||
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
||
# fundamental component
|
||
f0_buf[:, :, 0] = f0[:, :, 0]
|
||
for idx in np.arange(self.harmonic_num):
|
||
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
|
||
idx + 2
|
||
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
|
||
rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
|
||
rand_ini = torch.rand(
|
||
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
|
||
)
|
||
rand_ini[:, 0] = 0
|
||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||
tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
|
||
tmp_over_one *= upp
|
||
tmp_over_one = F.interpolate(
|
||
tmp_over_one.transpose(2, 1),
|
||
scale_factor=upp,
|
||
mode="linear",
|
||
align_corners=True,
|
||
).transpose(2, 1)
|
||
rad_values = F.interpolate(
|
||
rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
|
||
).transpose(
|
||
2, 1
|
||
) #######
|
||
tmp_over_one %= 1
|
||
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
|
||
cumsum_shift = torch.zeros_like(rad_values)
|
||
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||
sine_waves = torch.sin(
|
||
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
|
||
)
|
||
sine_waves = sine_waves * self.sine_amp
|
||
uv = self._f02uv(f0)
|
||
uv = F.interpolate(
|
||
uv.transpose(2, 1), scale_factor=upp, mode="nearest"
|
||
).transpose(2, 1)
|
||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||
noise = noise_amp * torch.randn_like(sine_waves)
|
||
sine_waves = sine_waves * uv + noise
|
||
return sine_waves, uv, noise
|
||
|
||
|
||
class SourceModuleHnNSF(torch.nn.Module):
|
||
"""SourceModule for hn-nsf
|
||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||
add_noise_std=0.003, voiced_threshod=0)
|
||
sampling_rate: sampling_rate in Hz
|
||
harmonic_num: number of harmonic above F0 (default: 0)
|
||
sine_amp: amplitude of sine source signal (default: 0.1)
|
||
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||
note that amplitude of noise in unvoiced is decided
|
||
by sine_amp
|
||
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||
F0_sampled (batchsize, length, 1)
|
||
Sine_source (batchsize, length, 1)
|
||
noise_source (batchsize, length 1)
|
||
uv (batchsize, length, 1)
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
sampling_rate,
|
||
gin_channels,
|
||
harmonic_num=0,
|
||
sine_amp=0.1,
|
||
add_noise_std=0.003,
|
||
voiced_threshod=0,
|
||
is_half=True,
|
||
):
|
||
super(SourceModuleHnNSF, self).__init__()
|
||
|
||
self.sine_amp = sine_amp
|
||
self.noise_std = add_noise_std
|
||
self.is_half = is_half
|
||
# to produce sine waveforms
|
||
self.l_sin_gen = SineGen(
|
||
sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
|
||
)
|
||
|
||
# to merge source harmonics into a single excitation
|
||
self.l_linear = torch.nn.Conv1d(gin_channels, harmonic_num + 1, 1)
|
||
self.l_tanh = torch.nn.Tanh()
|
||
|
||
def forward(self, x, upp=None):
|
||
sine_wavs, uv, _ = self.l_sin_gen(x, upp)
|
||
sine_raw = torch.transpose(sine_wavs, 1, 2).to(device=x.device, dtype=x.dtype)
|
||
return sine_raw, None, None # noise, uv
|
||
|
||
|
||
class GeneratorNSF(torch.nn.Module):
|
||
def __init__(
|
||
self,
|
||
initial_channel,
|
||
resblock,
|
||
resblock_kernel_sizes,
|
||
resblock_dilation_sizes,
|
||
upsample_rates,
|
||
upsample_initial_channel,
|
||
upsample_kernel_sizes,
|
||
gin_channels,
|
||
sr,
|
||
harmonic_num=16,
|
||
is_half=False,
|
||
):
|
||
super(GeneratorNSF, self).__init__()
|
||
self.num_kernels = len(resblock_kernel_sizes)
|
||
self.num_upsamples = len(upsample_rates)
|
||
self.upsample_rates = upsample_rates
|
||
|
||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
|
||
self.m_source = SourceModuleHnNSF(
|
||
sampling_rate=sr, gin_channels=gin_channels, harmonic_num=harmonic_num, is_half=is_half
|
||
)
|
||
self.gpre = Conv1d(gin_channels, initial_channel, 1)
|
||
self.conv_pre = ResBlock1(initial_channel, upsample_initial_channel, gin_channels, [7] * 5, [1] * 5, [1, 2, 4, 8, 1], 1, 2)
|
||
|
||
self.ups = nn.ModuleList()
|
||
self.resblocks = nn.ModuleList()
|
||
c_cur = upsample_initial_channel
|
||
for i, u in enumerate(upsample_rates):
|
||
c_pre = c_cur
|
||
c_cur = c_pre // 2
|
||
self.ups.append(
|
||
CausalConvTranspose1d(
|
||
c_pre,
|
||
c_pre,
|
||
kernel_rate=3,
|
||
stride=u,
|
||
groups=c_pre,
|
||
)
|
||
)
|
||
self.resblocks.append(ResBlock1(c_pre, c_cur, gin_channels, [11] * 5, [1] * 5, [1, 2, 4, 8, 1], 1, r=2))
|
||
self.conv_post = DilatedCausalConv1d(c_cur, 1, 5, stride=1, groups=1, dilation=1, bias=False)
|
||
self.noise_convs = nn.ModuleList()
|
||
self.noise_pre = LoRALinear1d(1 + harmonic_num, c_pre, gin_channels, r=2+harmonic_num)
|
||
for i, u in enumerate(upsample_rates[::-1]):
|
||
c_pre = c_pre * 2
|
||
c_cur = c_cur * 2
|
||
if i + 1 < len(upsample_rates):
|
||
self.noise_convs.append(DilatedCausalConv1d(c_cur, c_pre, kernel_size=u*3, stride=u, groups=c_cur, dilation=1))
|
||
else:
|
||
self.noise_convs.append(DilatedCausalConv1d(c_cur, initial_channel, kernel_size=u*3, stride=u, groups=math.gcd(c_cur, initial_channel), dilation=1))
|
||
self.upp = np.prod(upsample_rates)
|
||
|
||
def forward(self, x, x_mask, f0f, g):
|
||
har_source, noi_source, uv = self.m_source(f0f, self.upp)
|
||
har_source = self.noise_pre(har_source, g)
|
||
x_sources = [har_source]
|
||
for c in self.noise_convs:
|
||
har_source = c(har_source)
|
||
x_sources.append(har_source)
|
||
|
||
x = x + x_sources[-1]
|
||
x = x + self.gpre(g)
|
||
x = self.conv_pre(x, x_mask, g)
|
||
for i, u in enumerate(self.upsample_rates):
|
||
x_mask = torch.repeat_interleave(x_mask, u, 2)
|
||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||
x = self.ups[i](x)
|
||
x = self.resblocks[i](x + x_sources[-i-2], x_mask, g)
|
||
|
||
x = F.leaky_relu(x)
|
||
x = self.conv_post(x)
|
||
if x_mask is not None:
|
||
x *= x_mask
|
||
x = torch.tanh(x)
|
||
return x
|
||
|
||
def remove_weight_norm(self):
|
||
for l in self.ups:
|
||
remove_weight_norm(l)
|
||
for l in self.resblocks:
|
||
l.remove_weight_norm()
|
||
remove_weight_norm(self.noise_pre)
|
||
remove_weight_norm(self.conv_post)
|
||
|
||
|
||
sr2sr = {
|
||
"32k": 32000,
|
||
"40k": 40000,
|
||
"48k": 48000,
|
||
}
|
||
|
||
|
||
class SynthesizerTrnMs256NSFSid(nn.Module):
|
||
def __init__(
|
||
self,
|
||
spec_channels,
|
||
segment_size,
|
||
inter_channels,
|
||
hidden_channels,
|
||
filter_channels,
|
||
n_heads,
|
||
n_layers,
|
||
kernel_size,
|
||
p_dropout,
|
||
resblock,
|
||
resblock_kernel_sizes,
|
||
resblock_dilation_sizes,
|
||
upsample_rates,
|
||
upsample_initial_channel,
|
||
upsample_kernel_sizes,
|
||
spk_embed_dim,
|
||
gin_channels,
|
||
emb_channels,
|
||
sr,
|
||
**kwargs
|
||
):
|
||
super().__init__()
|
||
if type(sr) == type("strr"):
|
||
sr = sr2sr[sr]
|
||
self.spec_channels = spec_channels
|
||
self.inter_channels = inter_channels
|
||
self.hidden_channels = hidden_channels
|
||
self.filter_channels = filter_channels
|
||
self.n_heads = n_heads
|
||
self.n_layers = n_layers
|
||
self.kernel_size = kernel_size
|
||
self.p_dropout = p_dropout
|
||
self.resblock = resblock
|
||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||
self.upsample_rates = upsample_rates
|
||
self.upsample_initial_channel = upsample_initial_channel
|
||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||
self.segment_size = segment_size
|
||
self.gin_channels = gin_channels
|
||
self.emb_channels = emb_channels
|
||
self.sr = sr
|
||
# self.hop_length = hop_length#
|
||
self.spk_embed_dim = spk_embed_dim
|
||
|
||
self.emb_pitch = nn.Embedding(256, emb_channels) # pitch 256
|
||
self.dec = GeneratorNSF(
|
||
emb_channels,
|
||
resblock,
|
||
resblock_kernel_sizes,
|
||
resblock_dilation_sizes,
|
||
upsample_rates,
|
||
upsample_initial_channel,
|
||
upsample_kernel_sizes,
|
||
gin_channels=gin_channels,
|
||
sr=sr,
|
||
)
|
||
|
||
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
||
print(
|
||
"gin_channels:",
|
||
gin_channels,
|
||
"self.spk_embed_dim:",
|
||
self.spk_embed_dim,
|
||
"emb_channels:",
|
||
emb_channels,
|
||
)
|
||
|
||
def remove_weight_norm(self):
|
||
self.dec.remove_weight_norm()
|
||
|
||
def forward(
|
||
self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
|
||
): # 这里ds是id,[bs,1]
|
||
# print(1,pitch.shape)#[bs,t]
|
||
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
|
||
# m_p, _, x_mask = self.enc_p(phone, pitch, phone_lengths, g)
|
||
# z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
||
# z_p = self.flow(m_p * x_mask, x_mask, g=g)
|
||
|
||
x = phone + self.emb_pitch(pitch)
|
||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||
x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(
|
||
phone.dtype
|
||
)
|
||
|
||
m_p_slice, ids_slice = commons.rand_slice_segments(
|
||
x, phone_lengths, self.segment_size
|
||
)
|
||
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
|
||
pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
|
||
mask_slice = commons.slice_segments(x_mask, ids_slice, self.segment_size)
|
||
# print(-2,pitchf.shape,z_slice.shape)
|
||
o = self.dec(m_p_slice, mask_slice, pitchf, g)
|
||
return o, ids_slice, x_mask, g
|
||
|
||
def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
|
||
g = self.emb_g(sid).unsqueeze(-1)
|
||
x = phone + self.emb_pitch(pitch)
|
||
x = torch.transpose(x, 1, -1)
|
||
x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(
|
||
phone.dtype
|
||
)
|
||
o = self.dec((x * x_mask)[:, :, :max_len], x_mask, nsff0, g)
|
||
return o, x_mask, (None, None, None, None)
|
||
|
||
|
||
class DiscriminatorS(torch.nn.Module):
|
||
def __init__(
|
||
self,
|
||
hidden_channels: int,
|
||
filter_channels: int,
|
||
gin_channels: int,
|
||
n_heads: int,
|
||
n_layers: int,
|
||
kernel_size: int,
|
||
p_dropout: int,
|
||
):
|
||
super(DiscriminatorS, self).__init__()
|
||
self.convs = WaveConv1D(2, hidden_channels, gin_channels, [10, 7, 7, 7, 5, 3, 3], [5, 4, 4, 4, 3, 2, 2], [1] * 7, hidden_channels // 2, False)
|
||
self.encoder = attentions.Encoder(
|
||
hidden_channels, filter_channels, gin_channels, n_heads, n_layers//2, kernel_size, p_dropout
|
||
)
|
||
self.cross = weight_norm(torch.nn.Conv1d(gin_channels, hidden_channels, 1, 1))
|
||
self.conv_post = weight_norm(torch.nn.Conv1d(hidden_channels, 1, 3, 1, padding=get_padding(5, 1)))
|
||
|
||
def forward(self, x, g):
|
||
x = self.convs(x)
|
||
x_mask = torch.ones([x.shape[0], 1, x.shape[2]], device=x.device, dtype=x.dtype)
|
||
x = self.encoder(x, x_mask, g)
|
||
fmap = [x]
|
||
x = x + x * self.cross(g)
|
||
y = self.conv_post(x)
|
||
return y, fmap
|
||
|
||
|
||
class DiscriminatorP(torch.nn.Module):
|
||
def __init__(self, period, gin_channels, upsample_rates, final_dim=256, use_spectral_norm=False):
|
||
super(DiscriminatorP, self).__init__()
|
||
self.period = period
|
||
self.use_spectral_norm = use_spectral_norm
|
||
self.init_kernel_size = upsample_rates[-1] * 3
|
||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||
N = len(upsample_rates)
|
||
self.init_conv = norm_f(Conv2d(1, final_dim // (2 ** (N - 1)), (self.init_kernel_size, 1), (upsample_rates[-1], 1)))
|
||
self.convs = nn.ModuleList()
|
||
for i, u in enumerate(upsample_rates[::-1][1:], start=1):
|
||
self.convs.append(
|
||
ConvNext2d(
|
||
final_dim // (2 ** (N - i)),
|
||
final_dim // (2 ** (N - i - 1)),
|
||
gin_channels,
|
||
(u*3, 1),
|
||
(u, 1),
|
||
4,
|
||
r=2
|
||
)
|
||
)
|
||
self.conv_post = weight_norm(Conv2d(final_dim, 1, (3, 1), (1, 1)))
|
||
|
||
def forward(self, x, g):
|
||
fmap = []
|
||
|
||
# 1d to 2d
|
||
b, c, t = x.shape
|
||
if t % self.period != 0: # pad first
|
||
n_pad = self.period - (t % self.period)
|
||
x = F.pad(x, (n_pad, 0), "reflect")
|
||
t = t + n_pad
|
||
x = x.view(b, c, t // self.period, self.period)
|
||
|
||
x = torch.flip(x, dims=[2])
|
||
x = F.pad(x, [0, 0, 0, self.init_kernel_size - 1], mode="constant")
|
||
x = self.init_conv(x)
|
||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||
x = torch.flip(x, dims=[2])
|
||
|
||
for i, l in enumerate(self.convs):
|
||
x = l(x, g)
|
||
if i >= 1:
|
||
fmap.append(x)
|
||
|
||
x = F.pad(x, [0, 0, 2, 0], mode="constant")
|
||
x = self.conv_post(x)
|
||
x = torch.flatten(x, 1, -1)
|
||
|
||
return x, fmap
|
||
|
||
|
||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||
def __init__(self, upsample_rates, gin_channels, periods=[2, 3, 5, 7, 11, 17], **kwargs):
|
||
super(MultiPeriodDiscriminator, self).__init__()
|
||
|
||
# discs = [DiscriminatorS(hidden_channels, filter_channels, gin_channels, n_heads, n_layers, kernel_size, p_dropout)]
|
||
discs = [
|
||
DiscriminatorP(i, gin_channels, upsample_rates, use_spectral_norm=False) for i in periods
|
||
]
|
||
self.ups = np.prod(upsample_rates)
|
||
self.discriminators = nn.ModuleList(discs)
|
||
|
||
def forward(self, y, y_hat, g):
|
||
fmap_rs = []
|
||
fmap_gs = []
|
||
y_d_rs = []
|
||
y_d_gs = []
|
||
for i, d in enumerate(self.discriminators):
|
||
y_d_r, fmap_r = d(y, g)
|
||
y_d_g, fmap_g = d(y_hat, g)
|
||
# for j in range(len(fmap_r)):
|
||
# print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
|
||
y_d_rs.append(y_d_r)
|
||
y_d_gs.append(y_d_g)
|
||
fmap_rs.append(fmap_r)
|
||
fmap_gs.append(fmap_g)
|
||
|
||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|