WIP:improve model selector (MMVCv13)
This commit is contained in:
parent
3e0772d955
commit
19e70606c8
@ -32,22 +32,8 @@
|
||||
],
|
||||
"modelSetting": [
|
||||
{
|
||||
"name": "modelUploader",
|
||||
"options": {
|
||||
"showConfig": true,
|
||||
"showOnnx": false,
|
||||
"showPyTorch": true,
|
||||
"showCorrespondence": false,
|
||||
"showPyTorchCluster": false,
|
||||
|
||||
"showFeature": false,
|
||||
"showIndex": false,
|
||||
"showHalfPrecision": false,
|
||||
"showPyTorchEnableCheckBox": true,
|
||||
"defaultEnablePyTorch": true,
|
||||
|
||||
"showOnnxExportButton": false
|
||||
}
|
||||
"name": "modelUploaderv2",
|
||||
"options": {}
|
||||
},
|
||||
{
|
||||
"name": "commonFileSelect",
|
||||
@ -84,16 +70,6 @@
|
||||
{
|
||||
"name": "modelUploadButtonRow2",
|
||||
"options": {}
|
||||
},
|
||||
{
|
||||
"name": "framework",
|
||||
"options": {
|
||||
"showFramework": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "modelSamplingRate",
|
||||
"options": {}
|
||||
}
|
||||
],
|
||||
"lab": [],
|
||||
|
@ -32,23 +32,28 @@
|
||||
],
|
||||
"modelSetting": [
|
||||
{
|
||||
"name": "modelUploader",
|
||||
"name": "modelUploaderv2",
|
||||
"options": {}
|
||||
},
|
||||
{
|
||||
"name": "commonFileSelect",
|
||||
"options": {
|
||||
"showConfig": true,
|
||||
"showOnnx": true,
|
||||
"showPyTorch": true,
|
||||
"showCorrespondence": false,
|
||||
"showPyTorchCluster": false,
|
||||
|
||||
"showPyTorchEnableCheckBox": true,
|
||||
"defaultEnablePyTorch": false
|
||||
"title": "Config(.json)",
|
||||
"acceptExtentions": ["json"],
|
||||
"fileKind": "mmvcv13Config"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "framework",
|
||||
"name": "commonFileSelect",
|
||||
"options": {
|
||||
"showFramework": true
|
||||
"title": "Model(.pt,.pth,.onxx)",
|
||||
"acceptExtentions": ["pt", "pth", "onnx"],
|
||||
"fileKind": "mmvcv13Model"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "modelUploadButtonRow2",
|
||||
"options": {}
|
||||
}
|
||||
],
|
||||
"lab": [],
|
||||
|
21
client/demo/dist/index.js
vendored
21
client/demo/dist/index.js
vendored
File diff suppressed because one or more lines are too long
@ -32,22 +32,8 @@
|
||||
],
|
||||
"modelSetting": [
|
||||
{
|
||||
"name": "modelUploader",
|
||||
"options": {
|
||||
"showConfig": true,
|
||||
"showOnnx": false,
|
||||
"showPyTorch": true,
|
||||
"showCorrespondence": false,
|
||||
"showPyTorchCluster": false,
|
||||
|
||||
"showFeature": false,
|
||||
"showIndex": false,
|
||||
"showHalfPrecision": false,
|
||||
"showPyTorchEnableCheckBox": true,
|
||||
"defaultEnablePyTorch": true,
|
||||
|
||||
"showOnnxExportButton": false
|
||||
}
|
||||
"name": "modelUploaderv2",
|
||||
"options": {}
|
||||
},
|
||||
{
|
||||
"name": "commonFileSelect",
|
||||
@ -84,16 +70,6 @@
|
||||
{
|
||||
"name": "modelUploadButtonRow2",
|
||||
"options": {}
|
||||
},
|
||||
{
|
||||
"name": "framework",
|
||||
"options": {
|
||||
"showFramework": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "modelSamplingRate",
|
||||
"options": {}
|
||||
}
|
||||
],
|
||||
"lab": [],
|
||||
|
@ -32,23 +32,28 @@
|
||||
],
|
||||
"modelSetting": [
|
||||
{
|
||||
"name": "modelUploader",
|
||||
"name": "modelUploaderv2",
|
||||
"options": {}
|
||||
},
|
||||
{
|
||||
"name": "commonFileSelect",
|
||||
"options": {
|
||||
"showConfig": true,
|
||||
"showOnnx": true,
|
||||
"showPyTorch": true,
|
||||
"showCorrespondence": false,
|
||||
"showPyTorchCluster": false,
|
||||
|
||||
"showPyTorchEnableCheckBox": true,
|
||||
"defaultEnablePyTorch": false
|
||||
"title": "Config(.json)",
|
||||
"acceptExtentions": ["json"],
|
||||
"fileKind": "mmvcv13Config"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "framework",
|
||||
"name": "commonFileSelect",
|
||||
"options": {
|
||||
"showFramework": true
|
||||
"title": "Model(.pt,.pth,.onxx)",
|
||||
"acceptExtentions": ["pt", "pth", "onnx"],
|
||||
"fileKind": "mmvcv13Model"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "modelUploadButtonRow2",
|
||||
"options": {}
|
||||
}
|
||||
],
|
||||
"lab": [],
|
||||
|
@ -50,6 +50,7 @@ import { AudioDeviceModeRow, AudioDeviceModeRowProps } from "./components/410_Au
|
||||
import { IOBufferRow, IOBufferRowProps } from "./components/411_IOBufferRow"
|
||||
import { CommonFileSelectRow, CommonFileSelectRowProps } from "./components/301-e_CommonFileSelectRow"
|
||||
import { ModelUploadButtonRow2, ModelUploadButtonRow2Props } from "./components/301-f_ModelUploadButtonRow"
|
||||
import { ModelUploaderRowv2, ModelUploaderRowv2Props } from "./components/301_ModelUploaderRowv2"
|
||||
|
||||
export const catalog: { [key: string]: (props: any) => JSX.Element } = {}
|
||||
|
||||
@ -81,6 +82,7 @@ const initialize = () => {
|
||||
|
||||
|
||||
addToCatalog("modelUploader", (props: ModelUploaderRowProps) => { return <ModelUploaderRow {...props} /> })
|
||||
addToCatalog("modelUploaderv2", (props: ModelUploaderRowv2Props) => { return <ModelUploaderRowv2 {...props} /> })
|
||||
addToCatalog("framework", (props: FrameworkRowProps) => { return <FrameworkRow {...props} /> })
|
||||
addToCatalog("modelSamplingRate", (props: ModelSamplingRateRowProps) => { return <ModelSamplingRateRow {...props} /> })
|
||||
addToCatalog("commonFileSelect", (props: CommonFileSelectRowProps) => { return <CommonFileSelectRow {...props} /> })
|
||||
|
@ -47,7 +47,7 @@ export const PerformanceRow = (_props: PerformanceRowProps) => {
|
||||
setTimeout(updatePerformance, 1000 * 2)
|
||||
}
|
||||
}
|
||||
updatePerformance()
|
||||
// updatePerformance()
|
||||
return () => {
|
||||
execNext = false
|
||||
}
|
||||
|
@ -10,6 +10,8 @@ export type CommonFileSelectRowProps = {
|
||||
}
|
||||
|
||||
export const Filekinds = {
|
||||
"mmvcv13Config": "mmvcv13Config",
|
||||
"mmvcv13Model": "mmvcv13Model",
|
||||
"ddspSvcModel": "ddspSvcModel",
|
||||
"ddspSvcModelConfig": "ddspSvcModelConfig",
|
||||
"ddspSvcDiffusion": "ddspSvcDiffusion",
|
||||
|
@ -0,0 +1,24 @@
|
||||
import React, { useMemo } from "react"
|
||||
import { useGuiState } from "../001_GuiStateProvider"
|
||||
|
||||
export type ModelUploaderRowv2Props = {}
|
||||
|
||||
export const ModelUploaderRowv2 = (_props: ModelUploaderRowv2Props) => {
|
||||
const guiState = useGuiState()
|
||||
|
||||
const modelUploaderRow = useMemo(() => {
|
||||
|
||||
return (
|
||||
<div className="body-row split-3-3-4 left-padding-1 guided">
|
||||
<div className="body-item-title left-padding-1">Model Uploader</div>
|
||||
<div className="body-item-text">
|
||||
<div></div>
|
||||
</div>
|
||||
<div className="body-item-text">
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}, [guiState.showPyTorchModelUpload])
|
||||
|
||||
return modelUploaderRow
|
||||
}
|
@ -31,7 +31,7 @@ export const AudioDeviceModeRow = (_props: AudioDeviceModeRowProps) => {
|
||||
</div>
|
||||
<div className="left-padding-1">
|
||||
<input className="left-padding-1" type="radio" id="server-device" name="device-mode" checked={serverChecked} onChange={() => { onDeviceModeChanged(1) }} />
|
||||
<label htmlFor="server-device">server device</label>
|
||||
<label htmlFor="server-device">server device(exp.)</label>
|
||||
</div>
|
||||
</div>
|
||||
<div></div>
|
||||
|
@ -25,6 +25,9 @@ export type FileUploadSetting = {
|
||||
framework: Framework
|
||||
params: string
|
||||
|
||||
mmvcv13Config: ModelData | null
|
||||
mmvcv13Model: ModelData | null
|
||||
|
||||
ddspSvcModel: ModelData | null
|
||||
ddspSvcModelConfig: ModelData | null
|
||||
ddspSvcDiffusion: ModelData | null
|
||||
@ -41,17 +44,21 @@ const InitialFileUploadSetting: FileUploadSetting = {
|
||||
feature: null,
|
||||
index: null,
|
||||
|
||||
ddspSvcModel: null,
|
||||
ddspSvcModelConfig: null,
|
||||
ddspSvcDiffusion: null,
|
||||
ddspSvcDiffusionConfig: null,
|
||||
|
||||
isHalf: true,
|
||||
uploaded: false,
|
||||
defaultTune: 0,
|
||||
framework: Framework.PyTorch,
|
||||
params: "{}",
|
||||
|
||||
mmvcv13Config: null,
|
||||
mmvcv13Model: null,
|
||||
|
||||
ddspSvcModel: null,
|
||||
ddspSvcModelConfig: null,
|
||||
ddspSvcDiffusion: null,
|
||||
ddspSvcDiffusionConfig: null,
|
||||
|
||||
|
||||
}
|
||||
|
||||
export type UseServerSettingProps = {
|
||||
@ -213,7 +220,16 @@ export const useServerSetting = (props: UseServerSettingProps): ServerSettingSta
|
||||
|
||||
const loadModel = useMemo(() => {
|
||||
return async (slot: number) => {
|
||||
if (props.clientType == "DDSP-SVC") {
|
||||
if (props.clientType == "MMVCv13") {
|
||||
if (!fileUploadSettings[slot].mmvcv13Config) {
|
||||
alert("Configファイルを指定する必要があります。")
|
||||
return
|
||||
}
|
||||
if (!fileUploadSettings[slot].mmvcv13Model) {
|
||||
alert("モデルファイルを指定する必要があります。")
|
||||
return
|
||||
}
|
||||
} else if (props.clientType == "DDSP-SVC") {
|
||||
if (!fileUploadSettings[slot].ddspSvcModel) {
|
||||
alert("DDSPモデルを指定する必要があります。")
|
||||
return
|
||||
@ -304,6 +320,22 @@ export const useServerSetting = (props: UseServerSettingProps): ServerSettingSta
|
||||
})
|
||||
}
|
||||
|
||||
// MMVCv13
|
||||
const mmvcv13Models = [fileUploadSetting.mmvcv13Config, fileUploadSetting.mmvcv13Model].filter(x => { return x != null }) as ModelData[]
|
||||
for (let i = 0; i < mmvcv13Models.length; i++) {
|
||||
if (!mmvcv13Models[i].data) {
|
||||
mmvcv13Models[i].data = await mmvcv13Models[i].file!.arrayBuffer()
|
||||
mmvcv13Models[i].filename = await mmvcv13Models[i].file!.name
|
||||
}
|
||||
}
|
||||
for (let i = 0; i < mmvcv13Models.length; i++) {
|
||||
const progRate = 1 / mmvcv13Models.length
|
||||
const progOffset = 100 * i * progRate
|
||||
await _uploadFile(mmvcv13Models[i], (progress: number, _end: boolean) => {
|
||||
setUploadProgress(progress * progRate + progOffset)
|
||||
})
|
||||
}
|
||||
|
||||
// DDSP-SVC
|
||||
const ddspSvcModels = [fileUploadSetting.ddspSvcModel, fileUploadSetting.ddspSvcModelConfig, fileUploadSetting.ddspSvcDiffusion, fileUploadSetting.ddspSvcDiffusionConfig].filter(x => { return x != null }) as ModelData[]
|
||||
for (let i = 0; i < ddspSvcModels.length; i++) {
|
||||
@ -325,6 +357,8 @@ export const useServerSetting = (props: UseServerSettingProps): ServerSettingSta
|
||||
const params = JSON.stringify({
|
||||
trans: fileUploadSetting.defaultTune || 0,
|
||||
files: {
|
||||
mmvcv13Config: fileUploadSetting.mmvcv13Config?.filename || "",
|
||||
mmvcv13Models: fileUploadSetting.mmvcv13Model?.filename || "",
|
||||
ddspSvcModel: fileUploadSetting.ddspSvcModel?.filename ? "ddsp_mod/" + fileUploadSetting.ddspSvcModel?.filename : "",
|
||||
ddspSvcModelConfig: fileUploadSetting.ddspSvcModelConfig?.filename ? "ddsp_mod/" + fileUploadSetting.ddspSvcModelConfig?.filename : "",
|
||||
ddspSvcDiffusion: fileUploadSetting.ddspSvcDiffusion?.filename ? "ddsp_diff/" + fileUploadSetting.ddspSvcDiffusion?.filename : "",
|
||||
@ -396,6 +430,10 @@ export const useServerSetting = (props: UseServerSettingProps): ServerSettingSta
|
||||
defaultTune: fileUploadSetting.defaultTune,
|
||||
framework: fileUploadSetting.framework,
|
||||
params: fileUploadSetting.params,
|
||||
|
||||
mmvcv13Config: fileUploadSetting.mmvcv13Config ? { data: fileUploadSetting.mmvcv13Config.data, filename: fileUploadSetting.mmvcv13Config.filename } : null,
|
||||
mmvcv13Model: fileUploadSetting.mmvcv13Model ? { data: fileUploadSetting.mmvcv13Model.data, filename: fileUploadSetting.mmvcv13Model.filename } : null,
|
||||
|
||||
ddspSvcModel: fileUploadSetting.ddspSvcModel ? { data: fileUploadSetting.ddspSvcModel.data, filename: fileUploadSetting.ddspSvcModel.filename } : null,
|
||||
ddspSvcModelConfig: fileUploadSetting.ddspSvcModelConfig ? { data: fileUploadSetting.ddspSvcModelConfig.data, filename: fileUploadSetting.ddspSvcModelConfig.filename } : null,
|
||||
ddspSvcDiffusion: fileUploadSetting.ddspSvcDiffusion ? { data: fileUploadSetting.ddspSvcDiffusion.data, filename: fileUploadSetting.ddspSvcDiffusion.filename } : null,
|
||||
|
@ -1,11 +1,11 @@
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchaudio.transforms import Resample
|
||||
from torch.nn import functional as F
|
||||
from voice_changer.DDSP_SVC.ModelSlot import ModelSlot
|
||||
|
||||
from voice_changer.DDSP_SVC.deviceManager.DeviceManager import DeviceManager
|
||||
|
||||
if sys.platform.startswith("darwin"):
|
||||
baseDir = [x for x in sys.path if x.endswith("Contents/MacOS")]
|
||||
@ -17,9 +17,6 @@ if sys.platform.startswith("darwin"):
|
||||
else:
|
||||
sys.path.append("DDSP-SVC")
|
||||
|
||||
import ddsp.vocoder as vo # type:ignore
|
||||
from ddsp.core import upsample # type:ignore
|
||||
from enhancer import Enhancer # type:ignore
|
||||
from diffusion.infer_gt_mel import DiffGtMel # type: ignore
|
||||
|
||||
from voice_changer.utils.VoiceChangerModel import AudioInOut
|
||||
@ -27,18 +24,11 @@ from voice_changer.utils.VoiceChangerParams import VoiceChangerParams
|
||||
from voice_changer.utils.LoadModelParams import LoadModelParams
|
||||
from voice_changer.DDSP_SVC.DDSP_SVCSetting import DDSP_SVCSettings
|
||||
from voice_changer.RVC.embedder.EmbedderManager import EmbedderManager
|
||||
from Exceptions import NoModeLoadedException
|
||||
|
||||
# from Exceptions import NoModeLoadedException
|
||||
from voice_changer.DDSP_SVC.SvcDDSP import SvcDDSP
|
||||
|
||||
|
||||
providers = [
|
||||
"OpenVINOExecutionProvider",
|
||||
"CUDAExecutionProvider",
|
||||
"DmlExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
]
|
||||
|
||||
|
||||
def phase_vocoder(a, b, fade_out, fade_in):
|
||||
fa = torch.fft.rfft(a)
|
||||
fb = torch.fft.rfft(b)
|
||||
@ -67,6 +57,8 @@ class DDSP_SVC:
|
||||
settings: DDSP_SVCSettings = DDSP_SVCSettings()
|
||||
diff_model: DiffGtMel = DiffGtMel()
|
||||
svc_model: SvcDDSP = SvcDDSP()
|
||||
|
||||
deviceManager = DeviceManager.get_instance()
|
||||
# diff_model: DiffGtMel = DiffGtMel()
|
||||
|
||||
audio_buffer: AudioInOut | None = None
|
||||
@ -80,147 +72,62 @@ class DDSP_SVC:
|
||||
EmbedderManager.initialize(params)
|
||||
print("DDSP-SVC initialization:", params)
|
||||
|
||||
# def useDevice(self):
|
||||
# if self.settings.gpu >= 0 and torch.cuda.is_available():
|
||||
# return torch.device("cuda", index=self.settings.gpu)
|
||||
# else:
|
||||
# return torch.device("cpu")
|
||||
|
||||
def loadModel(self, props: LoadModelParams):
|
||||
# target_slot_idx = props.slot
|
||||
self.device = torch.device("cuda", index=0)
|
||||
target_slot_idx = props.slot
|
||||
params = props.params
|
||||
|
||||
modelFile = params["files"]["ddspSvcModel"]
|
||||
diffusionFile = params["files"]["ddspSvcDiffusion"]
|
||||
self.svc_model.update_model(modelFile)
|
||||
modelSlot = ModelSlot(
|
||||
modelFile=modelFile,
|
||||
diffusionFile=diffusionFile,
|
||||
defaultTrans=params["trans"] if "trans" in params else 0,
|
||||
)
|
||||
self.settings.modelSlots[target_slot_idx] = modelSlot
|
||||
|
||||
print("diffusion file", diffusionFile)
|
||||
self.diff_model.flush_model(diffusionFile, ddsp_config=self.svc_model.args)
|
||||
# 初回のみロード
|
||||
# if self.initialLoad:
|
||||
# self.prepareModel(target_slot_idx)
|
||||
# self.settings.modelSlotIndex = target_slot_idx
|
||||
# self.switchModel()
|
||||
# self.initialLoad = False
|
||||
# elif target_slot_idx == self.currentSlot:
|
||||
# self.prepareModel(target_slot_idx)
|
||||
self.settings.modelSlotIndex = target_slot_idx
|
||||
self.reloadModel()
|
||||
|
||||
print("params:", params)
|
||||
# print("params_arg:", self.args)
|
||||
|
||||
# self.settings.pyTorchModelFile = props.files.pyTorchModelFilename
|
||||
# # model
|
||||
# model, args = vo.load_model(
|
||||
# self.settings.pyTorchModelFile, device=self.useDevice()
|
||||
# )
|
||||
# self.model = model
|
||||
# self.args = args
|
||||
# self.sampling_rate = args.data.sampling_rate
|
||||
# self.hop_size = int(
|
||||
# self.args.data.block_size
|
||||
# * self.sampling_rate
|
||||
# / self.args.data.sampling_rate
|
||||
# )
|
||||
|
||||
# # hubert
|
||||
# self.vec_path = self.params.hubert_soft
|
||||
# self.encoder = vo.Units_Encoder(
|
||||
# self.args.data.encoder,
|
||||
# self.vec_path,
|
||||
# self.args.data.encoder_sample_rate,
|
||||
# self.args.data.encoder_hop_size,
|
||||
# device=self.useDevice(),
|
||||
# )
|
||||
|
||||
# # f0dec
|
||||
# self.f0_detector = vo.F0_Extractor(
|
||||
# # "crepe",
|
||||
# self.settings.f0Detector,
|
||||
# self.sampling_rate,
|
||||
# self.hop_size,
|
||||
# float(50),
|
||||
# float(1100),
|
||||
# )
|
||||
|
||||
# self.volume_extractor = vo.Volume_Extractor(self.hop_size)
|
||||
# self.enhancer_path = self.params.nsf_hifigan
|
||||
# self.enhancer = Enhancer(
|
||||
# self.args.enhancer.type, self.enhancer_path, device=self.useDevice()
|
||||
# )
|
||||
return self.get_info()
|
||||
|
||||
def reloadModel(self):
|
||||
self.device = self.deviceManager.getDevice(self.settings.gpu)
|
||||
modelFile = self.settings.modelSlots[self.settings.modelSlotIndex].modelFile
|
||||
diffusionFile = self.settings.modelSlots[
|
||||
self.settings.modelSlotIndex
|
||||
].diffusionFile
|
||||
|
||||
self.svc_model = SvcDDSP()
|
||||
self.svc_model.setVCParams(self.params)
|
||||
self.svc_model.update_model(modelFile, self.device)
|
||||
self.diff_model = DiffGtMel(device=self.device)
|
||||
self.diff_model.flush_model(diffusionFile, ddsp_config=self.svc_model.args)
|
||||
|
||||
def update_settings(self, key: str, val: int | float | str):
|
||||
# if key == "onnxExecutionProvider" and self.onnx_session is not None:
|
||||
# if val == "CUDAExecutionProvider":
|
||||
# if self.settings.gpu < 0 or self.settings.gpu >= self.gpu_num:
|
||||
# self.settings.gpu = 0
|
||||
# provider_options = [{"device_id": self.settings.gpu}]
|
||||
# self.onnx_session.set_providers(
|
||||
# providers=[val], provider_options=provider_options
|
||||
# )
|
||||
# else:
|
||||
# self.onnx_session.set_providers(providers=[val])
|
||||
# elif key in self.settings.intData:
|
||||
# val = int(val)
|
||||
# setattr(self.settings, key, val)
|
||||
# if (
|
||||
# key == "gpu"
|
||||
# and val >= 0
|
||||
# and val < self.gpu_num
|
||||
# and self.onnx_session is not None
|
||||
# ):
|
||||
# providers = self.onnx_session.get_providers()
|
||||
# print("Providers:", providers)
|
||||
# if "CUDAExecutionProvider" in providers:
|
||||
# provider_options = [{"device_id": self.settings.gpu}]
|
||||
# self.onnx_session.set_providers(
|
||||
# providers=["CUDAExecutionProvider"],
|
||||
# provider_options=provider_options,
|
||||
# )
|
||||
# if key == "gpu" and len(self.settings.pyTorchModelFile) > 0:
|
||||
# model, _args = vo.load_model(
|
||||
# self.settings.pyTorchModelFile, device=self.useDevice()
|
||||
# )
|
||||
# self.model = model
|
||||
# self.enhancer = Enhancer(
|
||||
# self.args.enhancer.type, self.enhancer_path, device=self.useDevice()
|
||||
# )
|
||||
# self.encoder = vo.Units_Encoder(
|
||||
# self.args.data.encoder,
|
||||
# self.vec_path,
|
||||
# self.args.data.encoder_sample_rate,
|
||||
# self.args.data.encoder_hop_size,
|
||||
# device=self.useDevice(),
|
||||
# )
|
||||
|
||||
# elif key in self.settings.floatData:
|
||||
# setattr(self.settings, key, float(val))
|
||||
# elif key in self.settings.strData:
|
||||
# setattr(self.settings, key, str(val))
|
||||
# if key == "f0Detector":
|
||||
# print("f0Detector update", val)
|
||||
# # if val == "dio":
|
||||
# # val = "parselmouth"
|
||||
|
||||
# if hasattr(self, "sampling_rate") is False:
|
||||
# self.sampling_rate = 44100
|
||||
# self.hop_size = 512
|
||||
|
||||
# self.f0_detector = vo.F0_Extractor(
|
||||
# val, self.sampling_rate, self.hop_size, float(50), float(1100)
|
||||
# )
|
||||
# else:
|
||||
# return False
|
||||
|
||||
if key in self.settings.intData:
|
||||
val = int(val)
|
||||
setattr(self.settings, key, val)
|
||||
if key == "gpu":
|
||||
self.reloadModel()
|
||||
elif key in self.settings.floatData:
|
||||
setattr(self.settings, key, float(val))
|
||||
elif key in self.settings.strData:
|
||||
setattr(self.settings, key, str(val))
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_info(self):
|
||||
# data = asdict(self.settings)
|
||||
|
||||
# data["onnxExecutionProviders"] = (
|
||||
# self.onnx_session.get_providers() if self.onnx_session is not None else []
|
||||
# )
|
||||
# files = ["configFile", "pyTorchModelFile", "onnxModelFile"]
|
||||
# for f in files:
|
||||
# if data[f] is not None and os.path.exists(data[f]):
|
||||
# data[f] = os.path.basename(data[f])
|
||||
# else:
|
||||
# data[f] = ""
|
||||
|
||||
data = {}
|
||||
data = asdict(self.settings)
|
||||
return data
|
||||
|
||||
def get_processing_sampling_rate(self):
|
||||
@ -252,45 +159,7 @@ class DDSP_SVC:
|
||||
|
||||
convertOffset = -1 * convertSize
|
||||
self.audio_buffer = self.audio_buffer[convertOffset:] # 変換対象の部分だけ抽出
|
||||
|
||||
# # f0
|
||||
# f0 = self.f0_detector.extract(
|
||||
# self.audio_buffer * 32768.0,
|
||||
# uv_interp=True,
|
||||
# silence_front=self.settings.extraConvertSize / self.sampling_rate,
|
||||
# )
|
||||
# f0 = torch.from_numpy(f0).float().unsqueeze(-1).unsqueeze(0)
|
||||
# f0 = f0 * 2 ** (float(self.settings.tran) / 12)
|
||||
|
||||
# # volume, mask
|
||||
# volume = self.volume_extractor.extract(self.audio_buffer)
|
||||
# mask = (volume > 10 ** (float(-60) / 20)).astype("float")
|
||||
# mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
|
||||
# mask = np.array(
|
||||
# [np.max(mask[n : n + 9]) for n in range(len(mask) - 8)] # noqa: E203
|
||||
# )
|
||||
# mask = torch.from_numpy(mask).float().unsqueeze(-1).unsqueeze(0)
|
||||
# mask = upsample(mask, self.args.data.block_size).squeeze(-1)
|
||||
# volume = torch.from_numpy(volume).float().unsqueeze(-1).unsqueeze(0)
|
||||
|
||||
# # embed
|
||||
# audio = (
|
||||
# torch.from_numpy(self.audio_buffer)
|
||||
# .float()
|
||||
# .to(self.useDevice())
|
||||
# .unsqueeze(0)
|
||||
# )
|
||||
# seg_units = self.encoder.encode(audio, self.sampling_rate, self.hop_size)
|
||||
|
||||
# cropOffset = -1 * (inputSize + crossfadeSize)
|
||||
# cropEnd = -1 * (crossfadeSize)
|
||||
# crop = self.audio_buffer[cropOffset:cropEnd]
|
||||
|
||||
# rms = np.sqrt(np.square(crop).mean(axis=0))
|
||||
# vol = max(rms, self.prevVol * 0.0)
|
||||
# self.prevVol = vol
|
||||
|
||||
return (self.audio_buffer, inputSize, crossfadeSize, solaSearchFrame)
|
||||
return (self.audio_buffer,)
|
||||
|
||||
# def _onnx_inference(self, data):
|
||||
# if hasattr(self, "onnx_session") is False or self.onnx_session is None:
|
||||
@ -305,32 +174,21 @@ class DDSP_SVC:
|
||||
# raise NoModeLoadedException("pytorch")
|
||||
|
||||
input_wav = data[0]
|
||||
# inputSize = data[1]
|
||||
# crossfadeSize = data[2]
|
||||
# solaSearchFrame = data[3]
|
||||
# last_delay_frame = int(0.02 * self.svc_model.args.data.sampling_rate)
|
||||
|
||||
# fade_in_window = (
|
||||
# torch.sin(
|
||||
# np.pi * torch.arange(0, 1, 1 / crossfadeSize, device=self.device) / 2
|
||||
# )
|
||||
# ** 2
|
||||
# )
|
||||
# fade_out_window = 1 - fade_in_window
|
||||
|
||||
_audio, _model_sr = self.svc_model.infer(
|
||||
input_wav,
|
||||
44100,
|
||||
self.svc_model.args.data.sampling_rate,
|
||||
spk_id=1,
|
||||
threhold=-45,
|
||||
pitch_adjust=10,
|
||||
pitch_adjust=self.settings.tran,
|
||||
use_spk_mix=False,
|
||||
spk_mix_dict=None,
|
||||
use_enhancer=False,
|
||||
pitch_extractor_type="harvest",
|
||||
pitch_extractor_type=self.settings.f0Detector,
|
||||
f0_min=50,
|
||||
f0_max=1100,
|
||||
safe_prefix_pad_length=0, # TBD なにこれ?
|
||||
# safe_prefix_pad_length=0, # TBD なにこれ?
|
||||
safe_prefix_pad_length=self.settings.extraConvertSize
|
||||
/ self.svc_model.args.data.sampling_rate,
|
||||
diff_model=self.diff_model,
|
||||
diff_acc=20, # TBD なにこれ?
|
||||
diff_spk_id=1,
|
||||
@ -340,94 +198,8 @@ class DDSP_SVC:
|
||||
diff_silence=False, # TBD なにこれ?
|
||||
)
|
||||
|
||||
print(" _model_sr", _model_sr)
|
||||
print("_audio", _audio.shape)
|
||||
print("_audio", _audio)
|
||||
return _audio.cpu().numpy() * 32768.0
|
||||
|
||||
# if _model_sr != self.svc_model.args.data.sampling_rate:
|
||||
# key_str = str(_model_sr) + "_" + str(self.svc_model.args.data.sampling_rate)
|
||||
# if key_str not in self.resample_kernel:
|
||||
# self.resample_kernel[key_str] = Resample(
|
||||
# _model_sr,
|
||||
# self.svc_model.args.data.sampling_rate,
|
||||
# lowpass_filter_width=128,
|
||||
# ).to(self.device)
|
||||
# _audio = self.resample_kernel[key_str](_audio)
|
||||
# temp_wav = _audio[
|
||||
# -inputSize
|
||||
# - crossfadeSize
|
||||
# - solaSearchFrame
|
||||
# - last_delay_frame : -last_delay_frame
|
||||
# ]
|
||||
|
||||
# # sola shift
|
||||
# conv_input = temp_wav[None, None, : crossfadeSize + solaSearchFrame]
|
||||
# cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
|
||||
# cor_den = torch.sqrt(
|
||||
# F.conv1d(
|
||||
# conv_input**2,
|
||||
# torch.ones(1, 1, crossfadeSize, device=self.device),
|
||||
# )
|
||||
# + 1e-8
|
||||
# )
|
||||
# sola_shift = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
||||
# temp_wav = temp_wav[sola_shift : sola_shift + inputSize + crossfadeSize]
|
||||
# print("sola_shift: " + str(int(sola_shift)))
|
||||
|
||||
# # phase vocoder
|
||||
# # if self.config.use_phase_vocoder:
|
||||
# if False:
|
||||
# temp_wav[:crossfadeSize] = phase_vocoder(
|
||||
# self.sola_buffer,
|
||||
# temp_wav[:crossfadeSize],
|
||||
# fade_out_window,
|
||||
# fade_in_window,
|
||||
# )
|
||||
# else:
|
||||
# temp_wav[:crossfadeSize] *= fade_in_window
|
||||
# temp_wav[:crossfadeSize] += self.sola_buffer * fade_out_window
|
||||
|
||||
# self.sola_buffer = temp_wav[-crossfadeSize:]
|
||||
|
||||
# result = temp_wav[:-crossfadeSize, None].repeat(1, 2).cpu().numpy()
|
||||
|
||||
###########################################
|
||||
# c = data[0].to(self.useDevice())
|
||||
# f0 = data[1].to(self.useDevice())
|
||||
# volume = data[2].to(self.useDevice())
|
||||
# mask = data[3].to(self.useDevice())
|
||||
|
||||
# # convertSize = data[4]
|
||||
# # vol = data[5]
|
||||
# # if vol < self.settings.silentThreshold:
|
||||
# # print("threshold")
|
||||
# # return np.zeros(convertSize).astype(np.int16)
|
||||
|
||||
# with torch.no_grad():
|
||||
# spk_id = torch.LongTensor(np.array([[self.settings.dstId]])).to(
|
||||
# self.useDevice()
|
||||
# )
|
||||
# seg_output, _, (s_h, s_n) = self.model(
|
||||
# c, f0, volume, spk_id=spk_id, spk_mix_dict=None
|
||||
# )
|
||||
# seg_output *= mask
|
||||
|
||||
# if self.settings.enableEnhancer:
|
||||
# seg_output, output_sample_rate = self.enhancer.enhance(
|
||||
# seg_output,
|
||||
# self.args.data.sampling_rate,
|
||||
# f0,
|
||||
# self.args.data.block_size,
|
||||
# # adaptive_key=float(self.settings.enhancerTune),
|
||||
# adaptive_key="auto",
|
||||
# silence_front=self.settings.extraConvertSize / self.sampling_rate,
|
||||
# )
|
||||
|
||||
# result = seg_output.squeeze().cpu().numpy() * 32768.0
|
||||
|
||||
# return np.array(result).astype(np.int16)
|
||||
|
||||
def inference(self, data):
|
||||
if self.settings.framework == "ONNX":
|
||||
audio = self._onnx_inference(data)
|
||||
|
@ -1,14 +1,15 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from voice_changer.DDSP_SVC.ModelSlot import ModelSlot
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDSP_SVCSettings:
|
||||
gpu: int = 0
|
||||
dstId: int = 0
|
||||
dstId: int = 1
|
||||
|
||||
f0Detector: str = "dio" # dio or harvest # parselmouth
|
||||
f0Detector: str = "dio" # dio or harvest or crepe # parselmouth
|
||||
tran: int = 20
|
||||
predictF0: int = 0 # 0:False, 1:True
|
||||
silentThreshold: float = 0.00001
|
||||
extraConvertSize: int = 1024 * 32
|
||||
|
||||
@ -21,16 +22,16 @@ class DDSP_SVCSettings:
|
||||
configFile: str = ""
|
||||
|
||||
speakers: dict[str, int] = field(default_factory=lambda: {})
|
||||
|
||||
modelSlotIndex: int = -1
|
||||
modelSlots: list[ModelSlot] = field(default_factory=lambda: [ModelSlot()])
|
||||
# ↓mutableな物だけ列挙
|
||||
intData = [
|
||||
"gpu",
|
||||
"dstId",
|
||||
"tran",
|
||||
"predictF0",
|
||||
"extraConvertSize",
|
||||
"enableEnhancer",
|
||||
"enhancerTune",
|
||||
]
|
||||
floatData = ["silentThreshold", "clusterInferRatio"]
|
||||
floatData = ["silentThreshold"]
|
||||
strData = ["framework", "f0Detector"]
|
||||
|
@ -1,16 +1,8 @@
|
||||
from const import EnumInferenceTypes, EnumEmbedderTypes
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSlot:
|
||||
pyTorchModelFile: str = ""
|
||||
pyTorchDiffusionModelFile: str = ""
|
||||
modelFile: str = ""
|
||||
diffusionFile: str = ""
|
||||
defaultTrans: int = 0
|
||||
# modelType: EnumDDSPSVCInferenceTypes = EnumDDSPSVCInferenceTypes.pyTorchRVC
|
||||
# samplingRate: int = -1
|
||||
# f0: bool = True
|
||||
# embChannels: int = 256
|
||||
# deprecated: bool = False
|
||||
# embedder: EnumEmbedderTypes = EnumEmbedderTypes.hubert
|
||||
|
@ -1,107 +0,0 @@
|
||||
from const import EnumEmbedderTypes, EnumInferenceTypes
|
||||
from voice_changer.RVC.ModelSlot import ModelSlot
|
||||
|
||||
from voice_changer.utils.LoadModelParams import FilePaths
|
||||
import torch
|
||||
import onnxruntime
|
||||
import json
|
||||
|
||||
|
||||
def generateModelSlot(files: FilePaths, params):
|
||||
modelSlot = ModelSlot()
|
||||
modelSlot.pyTorchModelFile = files.pyTorchModelFilename
|
||||
modelSlot.onnxModelFile = files.onnxModelFilename
|
||||
modelSlot.featureFile = files.featureFilename
|
||||
modelSlot.indexFile = files.indexFilename
|
||||
modelSlot.defaultTrans = params["trans"] if "trans" in params else 0
|
||||
|
||||
modelSlot.isONNX = True if modelSlot.onnxModelFile is not None else False
|
||||
|
||||
if modelSlot.isONNX:
|
||||
_setInfoByONNX(modelSlot, modelSlot.onnxModelFile)
|
||||
else:
|
||||
_setInfoByPytorch(modelSlot, modelSlot.pyTorchModelFile)
|
||||
return modelSlot
|
||||
|
||||
|
||||
def _setInfoByPytorch(slot: ModelSlot, file: str):
|
||||
cpt = torch.load(file, map_location="cpu")
|
||||
config_len = len(cpt["config"])
|
||||
if config_len == 18:
|
||||
slot.f0 = True if cpt["f0"] == 1 else False
|
||||
slot.modelType = (
|
||||
EnumInferenceTypes.pyTorchRVC
|
||||
if slot.f0
|
||||
else EnumInferenceTypes.pyTorchRVCNono
|
||||
)
|
||||
slot.embChannels = 256
|
||||
slot.embedder = EnumEmbedderTypes.hubert
|
||||
else:
|
||||
slot.f0 = True if cpt["f0"] == 1 else False
|
||||
slot.modelType = (
|
||||
EnumInferenceTypes.pyTorchWebUI
|
||||
if slot.f0
|
||||
else EnumInferenceTypes.pyTorchWebUINono
|
||||
)
|
||||
slot.embChannels = cpt["config"][17]
|
||||
slot.embedder = cpt["embedder_name"]
|
||||
if slot.embedder.endswith("768"):
|
||||
slot.embedder = slot.embedder[:-3]
|
||||
|
||||
if slot.embedder == EnumEmbedderTypes.hubert.value:
|
||||
slot.embedder = EnumEmbedderTypes.hubert
|
||||
elif slot.embedder == EnumEmbedderTypes.contentvec.value:
|
||||
slot.embedder = EnumEmbedderTypes.contentvec
|
||||
elif slot.embedder == EnumEmbedderTypes.hubert_jp.value:
|
||||
slot.embedder = EnumEmbedderTypes.hubert_jp
|
||||
else:
|
||||
raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")
|
||||
|
||||
slot.samplingRate = cpt["config"][-1]
|
||||
|
||||
del cpt
|
||||
|
||||
|
||||
def _setInfoByONNX(slot: ModelSlot, file: str):
    """Fill slot metadata from the custom metadata embedded in an ONNX model.

    Opens the model with a CPU-only session just to read its metadata.
    Models exported without a parsable ``metadata`` entry (or with an
    unknown embedder) are treated as deprecated: the slot is filled with
    legacy defaults and flagged so the user is warned to re-export.
    """
    tmp_onnx_session = onnxruntime.InferenceSession(
        file, providers=["CPUExecutionProvider"]
    )
    modelmeta = tmp_onnx_session.get_modelmeta()
    try:
        metadata = json.loads(modelmeta.custom_metadata_map["metadata"])

        # slot.modelType = metadata["modelType"]
        slot.embChannels = metadata["embChannels"]

        # Older exports may omit the embedder name; assume hubert.
        if "embedder" not in metadata:
            slot.embedder = EnumEmbedderTypes.hubert
        elif metadata["embedder"] == EnumEmbedderTypes.hubert.value:
            slot.embedder = EnumEmbedderTypes.hubert
        elif metadata["embedder"] == EnumEmbedderTypes.contentvec.value:
            slot.embedder = EnumEmbedderTypes.contentvec
        elif metadata["embedder"] == EnumEmbedderTypes.hubert_jp.value:
            slot.embedder = EnumEmbedderTypes.hubert_jp
        else:
            # Caught by the except below → model is flagged deprecated.
            raise RuntimeError("[Voice Changer][setInfoByONNX] unknown embedder")

        slot.f0 = metadata["f0"]
        slot.modelType = (
            EnumInferenceTypes.onnxRVC if slot.f0 else EnumInferenceTypes.onnxRVCNono
        )
        slot.samplingRate = metadata["samplingRate"]
        slot.deprecated = False

    except Exception as e:
        # Metadata missing or unparsable: fall back to legacy defaults and
        # mark the model deprecated so the UI can warn the user.
        slot.modelType = EnumInferenceTypes.onnxRVC
        slot.embChannels = 256
        slot.embedder = EnumEmbedderTypes.hubert
        slot.f0 = True
        slot.samplingRate = 48000
        slot.deprecated = True

        print("[Voice Changer] setInfoByONNX", e)
        print("[Voice Changer] ############## !!!! CAUTION !!!! ####################")
        # Fixed typos in the user-facing warning ("onnxfie"/"depricated").
        print("[Voice Changer] This onnx file is deprecated. Please regenerate the onnx file.")
        print("[Voice Changer] ############## !!!! CAUTION !!!! ####################")

    del tmp_onnx_session
@ -21,8 +21,8 @@ class SvcDDSP:
|
||||
def setVCParams(self, params: VoiceChangerParams):
|
||||
self.params = params
|
||||
|
||||
def update_model(self, model_path):
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
def update_model(self, model_path: str, device: torch.device):
|
||||
self.device = device
|
||||
|
||||
# load ddsp model
|
||||
if self.model is None or self.model_path != model_path:
|
||||
@ -42,35 +42,33 @@ class SvcDDSP:
|
||||
else:
|
||||
cnhubertsoft_gate = 10
|
||||
|
||||
# if self.args.data.encoder == "hubertsoft":
|
||||
# encoderPath = self.params.hubert_soft
|
||||
# elif self.args.data.encoder == "hubertbase":
|
||||
# encoderPath = self.params.hubert_base
|
||||
# elif self.args.data.encoder == "hubertbase768":
|
||||
# encoderPath = self.params.hubert_base
|
||||
# elif self.args.data.encoder == "hubertbase768l12":
|
||||
# encoderPath = self.params.hubert_base
|
||||
# elif self.args.data.encoder == "hubertlarge1024l24":
|
||||
# encoderPath = self.params.hubert_base
|
||||
# elif self.args.data.encoder == "contentvec":
|
||||
# encoderPath = self.params.hubert_base
|
||||
# elif self.args.data.encoder == "contentvec768":
|
||||
# encoderPath = self.params.hubert_base
|
||||
# elif self.args.data.encoder == "contentvec768l12":
|
||||
# encoderPath = self.params.hubert_base
|
||||
if self.args.data.encoder == "hubertsoft":
|
||||
encoderPath = self.params.hubert_soft
|
||||
elif self.args.data.encoder == "hubertbase":
|
||||
encoderPath = self.params.hubert_base
|
||||
elif self.args.data.encoder == "hubertbase768":
|
||||
encoderPath = self.params.hubert_base
|
||||
elif self.args.data.encoder == "hubertbase768l12":
|
||||
encoderPath = self.params.hubert_base
|
||||
elif self.args.data.encoder == "hubertlarge1024l24":
|
||||
encoderPath = self.params.hubert_base
|
||||
elif self.args.data.encoder == "contentvec":
|
||||
encoderPath = self.params.hubert_base
|
||||
elif self.args.data.encoder == "contentvec768":
|
||||
encoderPath = self.params.hubert_base
|
||||
elif self.args.data.encoder == "contentvec768l12":
|
||||
encoderPath = self.params.hubert_base
|
||||
|
||||
self.units_encoder = Units_Encoder(
|
||||
self.args.data.encoder,
|
||||
# encoderPath,
|
||||
self.args.data.encoder_ckpt,
|
||||
encoderPath,
|
||||
self.args.data.encoder_sample_rate,
|
||||
self.args.data.encoder_hop_size,
|
||||
cnhubertsoft_gate=cnhubertsoft_gate,
|
||||
device=self.device,
|
||||
)
|
||||
self.encoder_type = self.args.data.encoder
|
||||
# self.encoder_ckpt = encoderPath
|
||||
self.encoder_ckpt = self.args.data.encoder_ckpt
|
||||
self.encoder_ckpt = encoderPath
|
||||
|
||||
# load enhancer
|
||||
if (
|
||||
@ -109,8 +107,8 @@ class SvcDDSP:
|
||||
diff_silence=False,
|
||||
audio_alignment=False,
|
||||
):
|
||||
print("Infering...")
|
||||
print("audio", audio)
|
||||
# print("Infering...")
|
||||
# print("audio", audio)
|
||||
# load input
|
||||
# audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
|
||||
hop_size = (
|
||||
|
@ -32,13 +32,6 @@ from voice_changer.MMVCv13.TrainerFunctions import (
|
||||
|
||||
from Exceptions import NoModeLoadedException
|
||||
|
||||
providers = [
|
||||
"OpenVINOExecutionProvider",
|
||||
"CUDAExecutionProvider",
|
||||
"DmlExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMVCv13Settings:
|
||||
@ -69,11 +62,18 @@ class MMVCv13:
|
||||
self.text_norm = torch.LongTensor([0, 6, 0])
|
||||
|
||||
def loadModel(self, props: LoadModelParams):
|
||||
self.settings.configFile = props.files.configFilename
|
||||
params = props.params
|
||||
|
||||
self.settings.configFile = params["files"]["mmvcv13Config"]
|
||||
self.hps = get_hparams_from_file(self.settings.configFile)
|
||||
|
||||
self.settings.pyTorchModelFile = props.files.pyTorchModelFilename
|
||||
self.settings.onnxModelFile = props.files.onnxModelFilename
|
||||
modelFile = params["files"]["mmvcv13Models"]
|
||||
if modelFile.endswith(".onnx"):
|
||||
self.settings.pyTorchModelFile = None
|
||||
self.settings.onnxModelFile = modelFile
|
||||
else:
|
||||
self.settings.pyTorchModelFile = modelFile
|
||||
self.settings.onnxModelFile = None
|
||||
|
||||
# PyTorchモデル生成
|
||||
if self.settings.pyTorchModelFile is not None:
|
||||
@ -89,41 +89,58 @@ class MMVCv13:
|
||||
|
||||
# ONNXモデル生成
|
||||
if self.settings.onnxModelFile is not None:
|
||||
ort_options = onnxruntime.SessionOptions()
|
||||
ort_options.intra_op_num_threads = 8
|
||||
# ort_options = onnxruntime.SessionOptions()
|
||||
# ort_options.intra_op_num_threads = 8
|
||||
# ort_options.execution_mode = ort_options.ExecutionMode.ORT_PARALLEL
|
||||
# ort_options.inter_op_num_threads = 8
|
||||
providers, options = self.getOnnxExecutionProvider()
|
||||
self.onnx_session = onnxruntime.InferenceSession(
|
||||
self.settings.onnxModelFile, providers=providers
|
||||
self.settings.onnxModelFile,
|
||||
providers=providers,
|
||||
provider_options=options,
|
||||
)
|
||||
return self.get_info()
|
||||
|
||||
def update_settings(self, key: str, val: int | float | str):
|
||||
if key == "onnxExecutionProvider" and self.onnx_session is not None:
|
||||
if val == "CUDAExecutionProvider":
|
||||
if self.settings.gpu < 0 or self.settings.gpu >= self.gpu_num:
|
||||
self.settings.gpu = 0
|
||||
provider_options = [{"device_id": self.settings.gpu}]
|
||||
self.onnx_session.set_providers(
|
||||
providers=[val], provider_options=provider_options
|
||||
)
|
||||
def getOnnxExecutionProvider(self):
|
||||
if self.settings.gpu >= 0:
|
||||
return ["CUDAExecutionProvider"], [{"device_id": self.settings.gpu}]
|
||||
elif "DmlExecutionProvider" in onnxruntime.get_available_providers():
|
||||
return ["DmlExecutionProvider"], []
|
||||
else:
|
||||
self.onnx_session.set_providers(providers=[val])
|
||||
elif key in self.settings.intData:
|
||||
return ["CPUExecutionProvider"], [
|
||||
{
|
||||
"intra_op_num_threads": 8,
|
||||
"execution_mode": onnxruntime.ExecutionMode.ORT_PARALLEL,
|
||||
"inter_op_num_threads": 8,
|
||||
}
|
||||
]
|
||||
|
||||
def isOnnx(self):
|
||||
if self.settings.onnxModelFile is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def update_settings(self, key: str, val: int | float | str):
|
||||
if key in self.settings.intData:
|
||||
val = int(val)
|
||||
setattr(self.settings, key, val)
|
||||
if (
|
||||
key == "gpu"
|
||||
and val >= 0
|
||||
and val < self.gpu_num
|
||||
and self.onnx_session is not None
|
||||
):
|
||||
providers = self.onnx_session.get_providers()
|
||||
print("Providers:", providers)
|
||||
if "CUDAExecutionProvider" in providers:
|
||||
provider_options = [{"device_id": self.settings.gpu}]
|
||||
self.onnx_session.set_providers(
|
||||
providers=["CUDAExecutionProvider"],
|
||||
provider_options=provider_options,
|
||||
|
||||
if key == "gpu" and self.isOnnx():
|
||||
providers, options = self.getOnnxExecutionProvider()
|
||||
self.onnx_session = onnxruntime.InferenceSession(
|
||||
self.settings.onnxModelFile,
|
||||
providers=providers,
|
||||
provider_options=options,
|
||||
)
|
||||
# providers = self.onnx_session.get_providers()
|
||||
# print("Providers:", providers)
|
||||
# if "CUDAExecutionProvider" in providers:
|
||||
# provider_options = [{"device_id": self.settings.gpu}]
|
||||
# self.onnx_session.set_providers(
|
||||
# providers=["CUDAExecutionProvider"],
|
||||
# provider_options=provider_options,
|
||||
# )
|
||||
elif key in self.settings.floatData:
|
||||
setattr(self.settings, key, float(val))
|
||||
elif key in self.settings.strData:
|
||||
@ -254,7 +271,7 @@ class MMVCv13:
|
||||
return result
|
||||
|
||||
def inference(self, data):
|
||||
if self.settings.framework == "ONNX":
|
||||
if self.isOnnx():
|
||||
audio = self._onnx_inference(data)
|
||||
else:
|
||||
audio = self._pyTorch_inference(data)
|
||||
|
Loading…
x
Reference in New Issue
Block a user