support onnx generater
This commit is contained in:
parent
0d3c97e9fe
commit
f1b700ea5d
@ -55,6 +55,10 @@
|
||||
{
|
||||
"name": "modelSamplingRate",
|
||||
"options": {}
|
||||
},
|
||||
{
|
||||
"name": "onnxExport",
|
||||
"options": {}
|
||||
}
|
||||
],
|
||||
"deviceSetting": [
|
||||
|
11
client/demo/dist/index.html
vendored
11
client/demo/dist/index.html
vendored
@ -1 +1,10 @@
|
||||
<!doctype html><html style="width:100%;height:100%;overflow:hidden"><head><meta charset="utf-8"/><title>Voice Changer Client Demo</title><script defer="defer" src="index.js"></script></head><body style="width:100%;height:100%;margin:0"><div id="app" style="width:100%;height:100%"></div></body></html>
|
||||
<!DOCTYPE html>
|
||||
<html style="width: 100%; height: 100%; overflow: hidden">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>Voice Changer Client Demo</title>
|
||||
<script defer src="index.js"></script></head>
|
||||
<body style="width: 100%; height: 100%; margin: 0px">
|
||||
<div id="app" style="width: 100%; height: 100%"></div>
|
||||
</body>
|
||||
</html>
|
||||
|
1585
client/demo/dist/index.js
vendored
1585
client/demo/dist/index.js
vendored
File diff suppressed because one or more lines are too long
31
client/demo/dist/index.js.LICENSE.txt
vendored
31
client/demo/dist/index.js.LICENSE.txt
vendored
@ -1,31 +0,0 @@
|
||||
/*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
|
||||
|
||||
/**
|
||||
* @license React
|
||||
* react-dom.production.min.js
|
||||
*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @license React
|
||||
* react.production.min.js
|
||||
*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @license React
|
||||
* scheduler.production.min.js
|
||||
*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
@ -55,6 +55,10 @@
|
||||
{
|
||||
"name": "modelSamplingRate",
|
||||
"options": {}
|
||||
},
|
||||
{
|
||||
"name": "onnxExport",
|
||||
"options": {}
|
||||
}
|
||||
],
|
||||
"deviceSetting": [
|
||||
|
@ -11,6 +11,7 @@ export const OpenConverterSettingCheckbox = "open-converter-setting-checkbox"
|
||||
export const OpenAdvancedSettingCheckbox = "open-advanced-setting-checkbox"
|
||||
|
||||
export const OpenLicenseDialogCheckbox = "open-license-dialog-checkbox"
|
||||
export const OpenWaitingDialogCheckbox = "open-waiting-dialog-checkbox"
|
||||
|
||||
type Props = {
|
||||
children: ReactNode;
|
||||
@ -26,6 +27,7 @@ export type StateControls = {
|
||||
openAdvancedSettingCheckbox: StateControlCheckbox
|
||||
|
||||
showLicenseCheckbox: StateControlCheckbox
|
||||
showWaitingCheckbox: StateControlCheckbox
|
||||
}
|
||||
|
||||
type GuiStateAndMethod = {
|
||||
@ -135,6 +137,7 @@ export const GuiStateProvider = ({ children }: Props) => {
|
||||
const openAdvancedSettingCheckbox = useStateControlCheckbox(OpenAdvancedSettingCheckbox);
|
||||
|
||||
const showLicenseCheckbox = useStateControlCheckbox(OpenLicenseDialogCheckbox);
|
||||
const showWaitingCheckbox = useStateControlCheckbox(OpenWaitingDialogCheckbox);
|
||||
|
||||
useEffect(() => {
|
||||
openServerControlCheckbox.updateState(true)
|
||||
@ -145,6 +148,7 @@ export const GuiStateProvider = ({ children }: Props) => {
|
||||
openQualityControlCheckbox.updateState(true)
|
||||
|
||||
showLicenseCheckbox.updateState(true)
|
||||
showWaitingCheckbox.updateState(false)
|
||||
|
||||
}, [])
|
||||
|
||||
@ -159,7 +163,8 @@ export const GuiStateProvider = ({ children }: Props) => {
|
||||
openConverterSettingCheckbox,
|
||||
openAdvancedSettingCheckbox,
|
||||
|
||||
showLicenseCheckbox
|
||||
showLicenseCheckbox,
|
||||
showWaitingCheckbox
|
||||
},
|
||||
isConverting,
|
||||
setIsConverting,
|
||||
|
@ -36,6 +36,7 @@ import { TrancateNumTresholdRow, TrancateNumTresholdRowProps } from "./component
|
||||
import { IndexRatioRow, IndexRatioRowProps } from "./components/609_IndexRatioRow"
|
||||
import { RVCQualityRow, RVCQualityRowProps } from "./components/810_RVCQuality"
|
||||
import { ModelSamplingRateRow, ModelSamplingRateRowProps } from "./components/303_ModelSamplingRateRow"
|
||||
import { OnnxExportRow, OnnxExportRowProps } from "./components/304_OnnxExportRow"
|
||||
|
||||
export const catalog: { [key: string]: (props: any) => JSX.Element } = {}
|
||||
|
||||
@ -64,7 +65,7 @@ const initialize = () => {
|
||||
addToCatalog("modelUploader", (props: ModelUploaderRowProps) => { return <ModelUploaderRow {...props} /> })
|
||||
addToCatalog("framework", (props: FrameworkRowProps) => { return <FrameworkRow {...props} /> })
|
||||
addToCatalog("modelSamplingRate", (props: ModelSamplingRateRowProps) => { return <ModelSamplingRateRow {...props} /> })
|
||||
|
||||
addToCatalog("onnxExport", (props: OnnxExportRowProps) => { return <OnnxExportRow {...props} /> })
|
||||
|
||||
addToCatalog("audioInput", (props: AudioInputRowProps) => { return <AudioInputRow {...props} /> })
|
||||
addToCatalog("audioOutput", (props: AudioOutputRowProps) => { return <AudioOutputRow {...props} /> })
|
||||
|
@ -1,6 +1,7 @@
|
||||
import React from "react";
|
||||
import { useGuiState } from "./001_GuiStateProvider";
|
||||
import { LicenseDialog } from "./901_LicenseDialog";
|
||||
import { WaitingDialog } from "./902_WaitingDialog";
|
||||
|
||||
export const Dialogs = () => {
|
||||
const guiState = useGuiState()
|
||||
@ -11,6 +12,12 @@ export const Dialogs = () => {
|
||||
{guiState.stateControls.showLicenseCheckbox.trigger}
|
||||
<LicenseDialog></LicenseDialog>
|
||||
</div>
|
||||
|
||||
{guiState.stateControls.showWaitingCheckbox.trigger}
|
||||
<div className="dialog-container" id="dialog">
|
||||
{guiState.stateControls.showWaitingCheckbox.trigger}
|
||||
<WaitingDialog></WaitingDialog>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
|
@ -45,5 +45,4 @@ export const LicenseDialog = () => {
|
||||
}, [licenses]);
|
||||
return dialog;
|
||||
|
||||
return <></>
|
||||
};
|
||||
|
42
client/demo/src/components/demo/902_WaitingDialog.tsx
Normal file
42
client/demo/src/components/demo/902_WaitingDialog.tsx
Normal file
@ -0,0 +1,42 @@
|
||||
import React, { useMemo } from "react";
|
||||
import { useGuiState } from "./001_GuiStateProvider";
|
||||
|
||||
|
||||
export const WaitingDialog = () => {
|
||||
const guiState = useGuiState()
|
||||
|
||||
const dialog = useMemo(() => {
|
||||
const closeButtonRow = (
|
||||
<div className="body-row split-3-4-3 left-padding-1">
|
||||
<div className="body-item-text">
|
||||
</div>
|
||||
<div className="body-button-container body-button-container-space-around">
|
||||
<div className="body-button" onClick={() => { guiState.stateControls.showWaitingCheckbox.updateState(false) }} >close</div>
|
||||
</div>
|
||||
<div className="body-item-text"></div>
|
||||
</div>
|
||||
)
|
||||
const content = (
|
||||
<div className="body-row split-3-4-3 left-padding-1">
|
||||
<div className="body-item-text">
|
||||
</div>
|
||||
<div className="body-item-text">
|
||||
please wait... (about 1 min)
|
||||
</div>
|
||||
<div className="body-item-text"></div>
|
||||
</div>
|
||||
)
|
||||
|
||||
return (
|
||||
<div className="dialog-frame">
|
||||
<div className="dialog-title">export onnx file</div>
|
||||
<div className="dialog-content">
|
||||
{content}
|
||||
{closeButtonRow}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}, []);
|
||||
return dialog;
|
||||
|
||||
};
|
@ -18,7 +18,7 @@ export const ModelSamplingRateRow = (_props: ModelSamplingRateRowProps) => {
|
||||
|
||||
return (
|
||||
<div className="body-row split-3-3-4 left-padding-1 guided">
|
||||
<div className="body-item-title left-padding-2">Model Sampling Rate</div>
|
||||
<div className="body-item-title left-padding-2">Model Sampling Rate(only for onnx)</div>
|
||||
<div className="body-item-text">
|
||||
<div></div>
|
||||
</div>
|
||||
|
@ -0,0 +1,50 @@
|
||||
import React, { useMemo } from "react"
|
||||
import { useAppState } from "../../../001_provider/001_AppStateProvider"
|
||||
import { OnnxExporterInfo } from "@dannadori/voice-changer-client-js";
|
||||
import { useGuiState } from "../001_GuiStateProvider";
|
||||
export type OnnxExportRowProps = {
|
||||
}
|
||||
|
||||
export const OnnxExportRow = (_props: OnnxExportRowProps) => {
|
||||
const appState = useAppState()
|
||||
|
||||
const onnxExportRow = useMemo(() => {
|
||||
const guiState = useGuiState()
|
||||
|
||||
const onnxExportButtonClassName = "body-button"
|
||||
const onnxExportButtonAction = async () => {
|
||||
|
||||
if (guiState.isConverting) {
|
||||
alert("cannot export onnx when voice conversion is enabled")
|
||||
}
|
||||
document.getElementById("dialog")?.classList.add("dialog-container-show")
|
||||
guiState.stateControls.showWaitingCheckbox.updateState(true)
|
||||
const res = await appState.serverSetting.getOnnx() as OnnxExporterInfo
|
||||
const a = document.createElement("a")
|
||||
a.href = res.path
|
||||
a.download = res.filename;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
guiState.stateControls.showWaitingCheckbox.updateState(false)
|
||||
|
||||
}
|
||||
const onnxExportButtonLabel = "onnx export"
|
||||
|
||||
|
||||
return (
|
||||
<div className="body-row split-3-3-4 left-padding-1 guided">
|
||||
<div className="body-item-title left-padding-2">Onnx Exporter</div>
|
||||
<div className="body-item-text">
|
||||
<div></div>
|
||||
</div>
|
||||
<div className="body-button-container">
|
||||
<div className={onnxExportButtonClassName} onClick={onnxExportButtonAction}>{onnxExportButtonLabel}</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}, [appState.serverSetting.serverSetting, appState.serverSetting.updateServerSettings])
|
||||
|
||||
return onnxExportRow
|
||||
}
|
@ -10,10 +10,16 @@ export const InputChunkNumRow = (_props: InputChunkNumRowProps) => {
|
||||
<div className="body-row split-3-2-1-4 left-padding-1 guided">
|
||||
<div className="body-item-title left-padding-1">Input Chunk Num(128sample/chunk)</div>
|
||||
<div className="body-input-container">
|
||||
<input type="number" min={1} max={256} step={1} value={appState.workletNodeSetting.workletNodeSetting.inputChunkNum} onChange={(e) => {
|
||||
<select className="body-select" value={appState.workletNodeSetting.workletNodeSetting.inputChunkNum} onChange={(e) => {
|
||||
appState.workletNodeSetting.updateWorkletNodeSetting({ ...appState.workletNodeSetting.workletNodeSetting, inputChunkNum: Number(e.target.value) })
|
||||
appState.workletNodeSetting.trancateBuffer()
|
||||
}} />
|
||||
}}>
|
||||
{
|
||||
[16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 2048].map(x => {
|
||||
return <option key={x} value={x}>{x}</option>
|
||||
})
|
||||
}
|
||||
</select>
|
||||
</div>
|
||||
<div className="body-item-text">
|
||||
<div>buff: {(appState.workletNodeSetting.workletNodeSetting.inputChunkNum * 128 * 1000 / 48000).toFixed(1)}ms</div>
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { ClientType, ServerInfo, ServerSettingKey } from "./const";
|
||||
import { ClientType, OnnxExporterInfo, ServerInfo, ServerSettingKey } from "./const";
|
||||
|
||||
|
||||
type FileChunk = {
|
||||
@ -160,4 +160,16 @@ export class ServerConfigurator {
|
||||
return await info
|
||||
}
|
||||
|
||||
export2onnx = async () => {
|
||||
const url = this.serverUrl + "/onnx"
|
||||
const info = new Promise<OnnxExporterInfo>(async (resolve) => {
|
||||
const request = new Request(url, {
|
||||
method: 'GET',
|
||||
});
|
||||
const res = await (await fetch(request)).json() as OnnxExporterInfo
|
||||
resolve(res)
|
||||
})
|
||||
return await info
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -267,6 +267,9 @@ export class VoiceChangerClient {
|
||||
getModelType = () => {
|
||||
return this.configurator.getModelType()
|
||||
}
|
||||
getOnnx = async () => {
|
||||
return this.configurator.export2onnx()
|
||||
}
|
||||
|
||||
|
||||
updateServerSettings = (key: ServerSettingKey, val: string) => {
|
||||
|
@ -539,3 +539,9 @@ export const INDEXEDDB_KEY_MODEL_DATA = "INDEXEDDB_KEY_VOICE_CHANGER_LIB_MODEL_D
|
||||
export const INDEXEDDB_KEY_WORKLET = "INDEXEDDB_KEY_VOICE_CHANGER_LIB_WORKLET"
|
||||
|
||||
|
||||
// ONNX
|
||||
export type OnnxExporterInfo = {
|
||||
"status": string
|
||||
"path": string
|
||||
"filename": string
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
import { useState, useMemo, useEffect } from "react"
|
||||
import { VoiceChangerServerSetting, ServerInfo, ServerSettingKey, INDEXEDDB_KEY_SERVER, INDEXEDDB_KEY_MODEL_DATA, ClientType, DefaultServerSetting_MMVCv13, DefaultServerSetting_MMVCv15, DefaultServerSetting_so_vits_svc_40v2, DefaultServerSetting_so_vits_svc_40, DefaultServerSetting_so_vits_svc_40_c, DefaultServerSetting_RVC } from "../const"
|
||||
import { VoiceChangerServerSetting, ServerInfo, ServerSettingKey, INDEXEDDB_KEY_SERVER, INDEXEDDB_KEY_MODEL_DATA, ClientType, DefaultServerSetting_MMVCv13, DefaultServerSetting_MMVCv15, DefaultServerSetting_so_vits_svc_40v2, DefaultServerSetting_so_vits_svc_40, DefaultServerSetting_so_vits_svc_40_c, DefaultServerSetting_RVC, OnnxExporterInfo } from "../const"
|
||||
import { VoiceChangerClient } from "../VoiceChangerClient"
|
||||
import { useIndexedDB } from "./useIndexedDB"
|
||||
|
||||
@ -54,6 +54,8 @@ export type ServerSettingState = {
|
||||
uploadProgress: number
|
||||
isUploading: boolean
|
||||
|
||||
getOnnx: () => Promise<OnnxExporterInfo>
|
||||
|
||||
}
|
||||
|
||||
export const useServerSetting = (props: UseServerSettingProps): ServerSettingState => {
|
||||
@ -315,6 +317,11 @@ export const useServerSetting = (props: UseServerSettingProps): ServerSettingSta
|
||||
await removeItem(INDEXEDDB_KEY_MODEL_DATA)
|
||||
}
|
||||
|
||||
|
||||
const getOnnx = async () => {
|
||||
return props.voiceChangerClient!.getOnnx()
|
||||
}
|
||||
|
||||
return {
|
||||
serverSetting,
|
||||
updateServerSettings,
|
||||
@ -326,5 +333,6 @@ export const useServerSetting = (props: UseServerSettingProps): ServerSettingSta
|
||||
loadModel,
|
||||
uploadProgress,
|
||||
isUploading,
|
||||
getOnnx,
|
||||
}
|
||||
}
|
@ -8,8 +8,8 @@ class UvicornSuppressFilter(logging.Filter):
|
||||
return False
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger.addFilter(UvicornSuppressFilter())
|
||||
# logger = logging.getLogger("uvicorn.error")
|
||||
# logger.addFilter(UvicornSuppressFilter())
|
||||
|
||||
logger = logging.getLogger("fairseq.tasks.hubert_pretraining")
|
||||
logger.addFilter(UvicornSuppressFilter())
|
||||
|
@ -27,6 +27,7 @@ class MMVC_Rest_Fileuploader:
|
||||
self.router.add_api_route("/extract_voices", self.post_extract_voices, methods=["POST"])
|
||||
self.router.add_api_route("/model_type", self.post_model_type, methods=["POST"])
|
||||
self.router.add_api_route("/model_type", self.get_model_type, methods=["GET"])
|
||||
self.router.add_api_route("/onnx", self.get_onnx, methods=["GET"])
|
||||
|
||||
def post_upload_file(self, file: UploadFile = File(...), filename: str = Form(...)):
|
||||
res = upload_file(UPLOAD_DIR, file, filename)
|
||||
@ -110,3 +111,8 @@ class MMVC_Rest_Fileuploader:
|
||||
info = self.voiceChangerManager.getModelType(modelType)
|
||||
json_compatible_item_data = jsonable_encoder(info)
|
||||
return JSONResponse(content=json_compatible_item_data)
|
||||
|
||||
def get_onnx(self):
|
||||
info = self.voiceChangerManager.export2onnx()
|
||||
json_compatible_item_data = jsonable_encoder(info)
|
||||
return JSONResponse(content=json_compatible_item_data)
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
import resampy
|
||||
from voice_changer.RVC.ModelWrapper import ModelWrapper
|
||||
|
||||
|
||||
# avoiding parse arg error in RVC
|
||||
sys.argv = ["MMVCServerSIO.py"]
|
||||
|
||||
@ -23,7 +24,7 @@ import numpy as np
|
||||
import torch
|
||||
import onnxruntime
|
||||
# onnxruntime.set_default_logger_severity(3)
|
||||
from const import HUBERT_ONNX_MODEL_PATH
|
||||
from const import HUBERT_ONNX_MODEL_PATH, TMP_DIR
|
||||
|
||||
import pyworld as pw
|
||||
|
||||
@ -293,3 +294,20 @@ class RVC:
|
||||
sys.modules.pop(key)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def export2onnx(self):
|
||||
if hasattr(self, "net_g") == False or self.net_g == None:
|
||||
print("[Voice Changer] export2onnx, No pyTorch session.")
|
||||
return {"status": "ng", "path": f""}
|
||||
if self.settings.pyTorchModelFile == None:
|
||||
print("[Voice Changer] export2onnx, No pyTorch filepath.")
|
||||
return {"status": "ng", "path": f""}
|
||||
import voice_changer.RVC.export2onnx as onnxExporter
|
||||
|
||||
output_file = os.path.splitext(os.path.basename(self.settings.pyTorchModelFile))[0] + ".onnx"
|
||||
output_file_simple = os.path.splitext(os.path.basename(self.settings.pyTorchModelFile))[0] + "_simple.onnx"
|
||||
output_path = os.path.join(TMP_DIR, output_file)
|
||||
output_path_simple = os.path.join(TMP_DIR, output_file_simple)
|
||||
|
||||
onnxExporter.export2onnx(self.settings.pyTorchModelFile, output_path, output_path_simple, True)
|
||||
return {"status": "ok", "path": f"/tmp/{output_file_simple}", "filename": output_file_simple}
|
||||
|
149
server/voice_changer/RVC/export2onnx.py
Normal file
149
server/voice_changer/RVC/export2onnx.py
Normal file
@ -0,0 +1,149 @@
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from distutils.util import strtobool
|
||||
import torch
|
||||
from torch import nn
|
||||
from onnxsim import simplify
|
||||
import onnx
|
||||
|
||||
from infer_pack.models import TextEncoder256, GeneratorNSF, PosteriorEncoder, ResidualCouplingBlock
|
||||
|
||||
|
||||
class SynthesizerTrnMs256NSFsid_ONNX(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
spk_embed_dim,
|
||||
gin_channels,
|
||||
sr,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
if (type(sr) == type("strr")):
|
||||
sr = sr2sr[sr]
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.segment_size = segment_size
|
||||
self.gin_channels = gin_channels
|
||||
# self.hop_length = hop_length#
|
||||
self.spk_embed_dim = spk_embed_dim
|
||||
self.enc_p = TextEncoder256(
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
)
|
||||
self.dec = GeneratorNSF(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels, sr=sr, is_half=kwargs["is_half"]
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
5,
|
||||
1,
|
||||
16,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
||||
)
|
||||
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
||||
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
|
||||
|
||||
def forward(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
|
||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
|
||||
def export2onnx(input_model, output_model, output_model_simple, is_half):
|
||||
cpt = torch.load(input_model, map_location="cpu")
|
||||
if is_half:
|
||||
dev = torch.device("cuda", index=0)
|
||||
else:
|
||||
dev = torch.device("cpu")
|
||||
|
||||
net_g_onnx = SynthesizerTrnMs256NSFsid_ONNX(*cpt["config"], is_half=is_half)
|
||||
net_g_onnx.eval().to(dev)
|
||||
net_g_onnx.load_state_dict(cpt["weight"], strict=False)
|
||||
if is_half:
|
||||
net_g_onnx = net_g_onnx.half()
|
||||
|
||||
if is_half:
|
||||
feats = torch.HalfTensor(1, 2192, 256).to(dev)
|
||||
else:
|
||||
feats = torch.FloatTensor(1, 2192, 256).to(dev)
|
||||
p_len = torch.LongTensor([2192]).to(dev)
|
||||
pitch = torch.zeros(1, 2192, dtype=torch.int64).to(dev)
|
||||
|
||||
pitchf = torch.FloatTensor(1, 2192).to(dev)
|
||||
sid = torch.LongTensor([0]).to(dev)
|
||||
|
||||
input_names = ["feats", "p_len", "pitch", "pitchf", "sid"]
|
||||
output_names = ["audio", ]
|
||||
|
||||
torch.onnx.export(net_g_onnx,
|
||||
(
|
||||
feats,
|
||||
p_len,
|
||||
pitch,
|
||||
pitchf,
|
||||
sid,
|
||||
),
|
||||
output_model,
|
||||
dynamic_axes={
|
||||
"feats": [1],
|
||||
"pitch": [1],
|
||||
"pitchf": [1],
|
||||
},
|
||||
do_constant_folding=False,
|
||||
opset_version=17,
|
||||
verbose=False,
|
||||
input_names=input_names,
|
||||
output_names=output_names)
|
||||
|
||||
model_onnx2 = onnx.load(output_model)
|
||||
model_simp, check = simplify(model_onnx2)
|
||||
onnx.save(model_simp, output_model_simple)
|
@ -60,30 +60,6 @@ class VoiceChanger():
|
||||
self.currentCrossFadeOverlapSize = 0 # setting
|
||||
self.crossfadeSize = 0 # calculated
|
||||
|
||||
# self.modelType = getModelType()
|
||||
# print("[VoiceChanger] activate model type:", self.modelType)
|
||||
# if self.modelType == "MMVCv15":
|
||||
# from voice_changer.MMVCv15.MMVCv15 import MMVCv15
|
||||
# self.voiceChanger = MMVCv15() # type: ignore
|
||||
# elif self.modelType == "MMVCv13":
|
||||
# from voice_changer.MMVCv13.MMVCv13 import MMVCv13
|
||||
# self.voiceChanger = MMVCv13()
|
||||
# elif self.modelType == "so-vits-svc-40v2":
|
||||
# from voice_changer.SoVitsSvc40v2.SoVitsSvc40v2 import SoVitsSvc40v2
|
||||
# self.voiceChanger = SoVitsSvc40v2(params)
|
||||
# elif self.modelType == "so-vits-svc-40" or self.modelType == "so-vits-svc-40_c":
|
||||
# from voice_changer.SoVitsSvc40.SoVitsSvc40 import SoVitsSvc40
|
||||
# self.voiceChanger = SoVitsSvc40(params)
|
||||
# elif self.modelType == "DDSP-SVC":
|
||||
# from voice_changer.DDSP_SVC.DDSP_SVC import DDSP_SVC
|
||||
# self.voiceChanger = DDSP_SVC(params)
|
||||
# elif self.modelType == "RVC":
|
||||
# from voice_changer.RVC.RVC import RVC
|
||||
# self.voiceChanger = RVC(params)
|
||||
# else:
|
||||
# from voice_changer.MMVCv13.MMVCv13 import MMVCv13
|
||||
# self.voiceChanger = MMVCv13()
|
||||
|
||||
self.voiceChanger = None
|
||||
self.modelType = None
|
||||
self.params = params
|
||||
@ -324,8 +300,10 @@ class VoiceChanger():
|
||||
perf = [preprocess_time, mainprocess_time, postprocess_time]
|
||||
return outputData, perf
|
||||
|
||||
def export2onnx(self):
|
||||
return self.voiceChanger.export2onnx()
|
||||
|
||||
##############
|
||||
##############
|
||||
PRINT_CONVERT_PROCESSING: bool = False
|
||||
# PRINT_CONVERT_PROCESSING = True
|
||||
|
||||
|
@ -44,3 +44,6 @@ class VoiceChangerManager():
|
||||
|
||||
def getModelType(self):
|
||||
return self.voiceChanger.getModelType()
|
||||
|
||||
def export2onnx(self):
|
||||
return self.voiceChanger.export2onnx()
|
||||
|
Loading…
x
Reference in New Issue
Block a user