protect + crepe対応 (add support for the "protect" parameter and for the crepe pitch extractor)

This commit is contained in:
nadare 2023-05-31 23:50:43 +09:00
parent dcfc1d83cf
commit 6b9777f3a2
8 changed files with 37 additions and 5 deletions

View File

@ -12,6 +12,7 @@ class ModelSlot:
indexFile: str = "" indexFile: str = ""
defaultTune: int = 0 defaultTune: int = 0
defaultIndexRatio: int = 1 defaultIndexRatio: int = 1
# defaultProtect: float = .5
isONNX: bool = False isONNX: bool = False
modelType: EnumInferenceTypes = EnumInferenceTypes.pyTorchRVC modelType: EnumInferenceTypes = EnumInferenceTypes.pyTorchRVC
samplingRate: int = -1 samplingRate: int = -1

View File

@ -36,6 +36,9 @@ def generateModelSlot(slotDir: str):
modelSlot.defaultIndexRatio = ( modelSlot.defaultIndexRatio = (
params["defaultIndexRatio"] if "defaultIndexRatio" in params else 0 params["defaultIndexRatio"] if "defaultIndexRatio" in params else 0
) )
# modelSlot.defaultProtect = (
# params["defaultProtect"] if "defaultProtect" in params else 0.5
# )
modelSlot.name = params["name"] if "name" in params else None modelSlot.name = params["name"] if "name" in params else None
modelSlot.description = params["description"] if "description" in params else None modelSlot.description = params["description"] if "description" in params else None
modelSlot.credit = params["credit"] if "credit" in params else None modelSlot.credit = params["credit"] if "credit" in params else None

View File

@ -242,6 +242,7 @@ class RVC:
# その他の設定 # その他の設定
self.next_trans = modelSlot.defaultTune self.next_trans = modelSlot.defaultTune
self.next_index_ratio = modelSlot.defaultIndexRatio self.next_index_ratio = modelSlot.defaultIndexRatio
# self.next_protect = modelSlot.defaultProtect
self.next_samplingRate = modelSlot.samplingRate self.next_samplingRate = modelSlot.samplingRate
self.next_framework = "ONNX" if modelSlot.isONNX else "PyTorch" self.next_framework = "ONNX" if modelSlot.isONNX else "PyTorch"
# self.needSwitch = True # self.needSwitch = True
@ -254,6 +255,7 @@ class RVC:
self.pipeline = self.next_pipeline self.pipeline = self.next_pipeline
self.settings.tran = self.next_trans self.settings.tran = self.next_trans
self.settings.indexRatio = self.next_index_ratio self.settings.indexRatio = self.next_index_ratio
# self.settings.protect = self.next_protect
self.settings.modelSamplingRate = self.next_samplingRate self.settings.modelSamplingRate = self.next_samplingRate
self.settings.framework = self.next_framework self.settings.framework = self.next_framework
@ -336,6 +338,7 @@ class RVC:
sid = 0 sid = 0
f0_up_key = self.settings.tran f0_up_key = self.settings.tran
index_rate = self.settings.indexRatio index_rate = self.settings.indexRatio
protect = .5# self.settings.protect
if_f0 = 1 if self.settings.modelSlots[self.currentSlot].f0 else 0 if_f0 = 1 if self.settings.modelSlots[self.currentSlot].f0 else 0
embOutputLayer = self.settings.modelSlots[self.currentSlot].embOutputLayer embOutputLayer = self.settings.modelSlots[self.currentSlot].embOutputLayer
useFinalProj = self.settings.modelSlots[self.currentSlot].useFinalProj useFinalProj = self.settings.modelSlots[self.currentSlot].useFinalProj
@ -350,6 +353,7 @@ class RVC:
embOutputLayer, embOutputLayer,
useFinalProj, useFinalProj,
repeat, repeat,
protect
) )
result = audio_out.detach().cpu().numpy() * np.sqrt(vol) result = audio_out.detach().cpu().numpy() * np.sqrt(vol)
@ -411,6 +415,7 @@ class RVC:
params = { params = {
"defaultTune": req.defaultTune, "defaultTune": req.defaultTune,
"defaultIndexRatio": req.defaultIndexRatio, "defaultIndexRatio": req.defaultIndexRatio,
# "defaultProtect": req.defaultProtect
"sampleId": "", "sampleId": "",
"files": {"rvcModel": storeFile}, "files": {"rvcModel": storeFile},
} }
@ -432,6 +437,7 @@ class RVC:
) )
params["defaultTune"] = self.settings.tran params["defaultTune"] = self.settings.tran
params["defaultIndexRatio"] = self.settings.indexRatio params["defaultIndexRatio"] = self.settings.indexRatio
# params["defaultProtect"] = self.settings.protect
json.dump(params, open(os.path.join(slotDir, "params.json"), "w")) json.dump(params, open(os.path.join(slotDir, "params.json"), "w"))
self.loadSlots() self.loadSlots()

View File

@ -29,6 +29,7 @@ class RVCSettings:
sampleModels: list[RVCModelSample] = field(default_factory=lambda: []) sampleModels: list[RVCModelSample] = field(default_factory=lambda: [])
indexRatio: float = 0 indexRatio: float = 0
# protect: float = 0.5
rvcQuality: int = 0 rvcQuality: int = 0
silenceFront: int = 1 # 0:off, 1:on silenceFront: int = 1 # 0:off, 1:on
modelSamplingRate: int = 48000 modelSamplingRate: int = 48000
@ -50,5 +51,5 @@ class RVCSettings:
"isHalf", "isHalf",
"enableDirectML", "enableDirectML",
] ]
floatData = ["silentThreshold", "indexRatio"] floatData = ["silentThreshold", "indexRatio"] # , "protect"]
strData = ["framework", "f0Detector"] strData = ["framework", "f0Detector"]

View File

@ -81,6 +81,7 @@ def downloadInitialSampleModels(sampleJsons: list[str], model_dir: str):
sampleParams["sampleId"] = sample.id sampleParams["sampleId"] = sample.id
sampleParams["defaultTune"] = 0 sampleParams["defaultTune"] = 0
sampleParams["defaultIndexRatio"] = 1 sampleParams["defaultIndexRatio"] = 1
# sampleParams["defaultProtect"] = 0.5
sampleParams["credit"] = sample.credit sampleParams["credit"] = sample.credit
sampleParams["description"] = sample.description sampleParams["description"] = sample.description
sampleParams["name"] = sample.name sampleParams["name"] = sample.name

View File

@ -17,4 +17,5 @@ class MergeModelRequest:
slot: int = -1 slot: int = -1
defaultTune: int = 0 defaultTune: int = 0
defaultIndexRatio: int = 1 defaultIndexRatio: int = 1
# defaultProtect: float = .5
files: List[MergeFile] = field(default_factory=lambda: []) files: List[MergeFile] = field(default_factory=lambda: [])

View File

@ -85,7 +85,9 @@ class Pipeline(object):
embOutputLayer, embOutputLayer,
useFinalProj, useFinalProj,
repeat, repeat,
protect=0.5,
): ):
search_index = self.index is not None and self.big_npy is not None and index_rate != 0
self.t_pad = self.sr * repeat self.t_pad = self.sr * repeat
self.t_pad_tgt = self.targetSR * repeat self.t_pad_tgt = self.targetSR * repeat
@ -136,10 +138,12 @@ class Pipeline(object):
raise DeviceChangingException() raise DeviceChangingException()
else: else:
raise e raise e
if protect < 0.5 and search_index:
feats0 = feats.clone()
# Index - feature抽出 # Index - feature抽出
# if self.index is not None and self.feature is not None and index_rate != 0: # if self.index is not None and self.feature is not None and index_rate != 0:
if self.index is not None and self.big_npy is not None and index_rate != 0: if search_index:
npy = feats[0].cpu().numpy() npy = feats[0].cpu().numpy()
if self.isHalf is True: if self.isHalf is True:
npy = npy.astype("float32") npy = npy.astype("float32")
@ -165,7 +169,10 @@ class Pipeline(object):
+ (1 - index_rate) * feats + (1 - index_rate) * feats
) )
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
if protect < 0.5 and search_index:
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
0, 2, 1
)
# ピッチサイズ調整 # ピッチサイズ調整
p_len = audio_pad.shape[0] // self.window p_len = audio_pad.shape[0] // self.window
if feats.shape[1] < p_len: if feats.shape[1] < p_len:
@ -173,6 +180,15 @@ class Pipeline(object):
if pitch is not None and pitchf is not None: if pitch is not None and pitchf is not None:
pitch = pitch[:, :p_len] pitch = pitch[:, :p_len]
pitchf = pitchf[:, :p_len] pitchf = pitchf[:, :p_len]
# pitchの推定が上手くいかない(pitchf=0)場合、検索前の特徴を混ぜる
if protect < 0.5 and search_index:
pitchff = pitchf.clone()
pitchff[pitchf > 0] = 1
pitchff[pitchf < 1] = protect
pitchff = pitchff.unsqueeze(-1)
feats = feats * pitchff + feats0 * (1 - pitchff)
feats = feats.to(feats0.dtype)
p_len = torch.tensor([p_len], device=self.device).long() p_len = torch.tensor([p_len], device=self.device).long()
# 推論実行 # 推論実行

View File

@ -26,7 +26,7 @@ class CrepePitchExtractor(PitchExtractor):
f0_mel_min = 1127 * np.log(1 + f0_min / 700) f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700) f0_mel_max = 1127 * np.log(1 + f0_max / 700)
f0 = torchcrepe.predict( f0, pd = torchcrepe.predict(
audio.unsqueeze(0), audio.unsqueeze(0),
sr, sr,
hop_length=window, hop_length=window,
@ -37,8 +37,11 @@ class CrepePitchExtractor(PitchExtractor):
batch_size=256, batch_size=256,
decoder=torchcrepe.decode.weighted_argmax, decoder=torchcrepe.decode.weighted_argmax,
device=self.device, device=self.device,
return_periodicity=True,
) )
f0 = torchcrepe.filter.median(f0, 3) f0 = torchcrepe.filter.median(f0, 3) # 本家だとmeanですが、harvestに合わせmedianフィルタ
pd = torchcrepe.filter.median(pd, 3)
f0[pd < 0.1] = 0
f0 = f0.squeeze() f0 = f0.squeeze()
f0 = torch.nn.functional.pad( f0 = torch.nn.functional.pad(