optimize swap model
This commit is contained in:
parent
47d69f19f3
commit
f95b63ea5f
@ -88,6 +88,7 @@ class RVC:
|
||||
self.prevVol = 0
|
||||
self.params = params
|
||||
self.mps_enabled: bool = getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
|
||||
self.currentSlot = -1
|
||||
print("RVC initialization: ", params)
|
||||
print("mps: ", self.mps_enabled)
|
||||
|
||||
@ -100,16 +101,16 @@ class RVC:
|
||||
# self.index_file = props["files"]["indexFilename"]
|
||||
|
||||
self.is_half = props["isHalf"]
|
||||
self.slot = props["slot"]
|
||||
self.tmp_slot = props["slot"]
|
||||
|
||||
self.settings.modelSlots[self.slot] = ModelSlot(
|
||||
self.settings.modelSlots[self.tmp_slot] = ModelSlot(
|
||||
pyTorchModelFile=props["files"]["pyTorchModelFilename"],
|
||||
onnxModelFile=props["files"]["onnxModelFilename"],
|
||||
featureFile=props["files"]["featureFilename"],
|
||||
indexFile=props["files"]["indexFilename"]
|
||||
)
|
||||
|
||||
print("[Voice Changer] RVC loading... slot:", self.slot)
|
||||
print("[Voice Changer] RVC loading... slot:", self.tmp_slot)
|
||||
|
||||
try:
|
||||
hubert_path = self.params["hubert_base"]
|
||||
@ -123,13 +124,14 @@ class RVC:
|
||||
except Exception as e:
|
||||
print("EXCEPTION during loading hubert/contentvec model", e)
|
||||
|
||||
self.switchModel(self.slot)
|
||||
# self.switchModel(self.slot)
|
||||
self.prepareModel(self.tmp_slot)
|
||||
self.slot = self.tmp_slot
|
||||
|
||||
return self.get_info()
|
||||
|
||||
def switchModel(self, slot: int):
|
||||
print("[Voice Changer] Switch Model to:", slot)
|
||||
self.slot = slot
|
||||
def prepareModel(self, slot: int):
|
||||
print("[Voice Changer] Prepare Model of slot:", slot)
|
||||
pyTorchModelFile = self.settings.modelSlots[slot].pyTorchModelFile
|
||||
onnxModelFile = self.settings.modelSlots[slot].onnxModelFile
|
||||
# PyTorchモデル生成
|
||||
@ -141,21 +143,31 @@ class RVC:
|
||||
net_g.load_state_dict(cpt["weight"], strict=False)
|
||||
if self.is_half:
|
||||
net_g = net_g.half()
|
||||
self.net_g = net_g
|
||||
self.next_net_g = net_g
|
||||
else:
|
||||
self.net_g = None
|
||||
self.next_net_g = None
|
||||
|
||||
# ONNXモデル生成
|
||||
if onnxModelFile != None and onnxModelFile != "":
|
||||
self.onnx_session = ModelWrapper(onnxModelFile)
|
||||
self.next_onnx_session = ModelWrapper(onnxModelFile)
|
||||
else:
|
||||
self.onnx_session = None
|
||||
self.next_onnx_session = None
|
||||
|
||||
self.feature_file = self.settings.modelSlots[slot].featureFile
|
||||
self.index_file = self.settings.modelSlots[slot].indexFile
|
||||
self.next_feature_file = self.settings.modelSlots[slot].featureFile
|
||||
self.next_index_file = self.settings.modelSlots[slot].indexFile
|
||||
|
||||
return self.get_info()
|
||||
|
||||
def switchModel(self):
|
||||
del self.net_g
|
||||
del self.onnx_session
|
||||
self.net_g = self.next_net_g
|
||||
self.onnx_session = self.next_onnx_session
|
||||
self.feature_file = self.next_feature_file
|
||||
self.index_file = self.next_index_file
|
||||
self.next_net_g = None
|
||||
self.next_onnx_session = None
|
||||
|
||||
def update_settings(self, key: str, val: any):
|
||||
if key == "onnxExecutionProvider" and self.onnx_session != None:
|
||||
if val == "CUDAExecutionProvider":
|
||||
@ -181,7 +193,10 @@ class RVC:
|
||||
provider_options = [{'device_id': self.settings.gpu}]
|
||||
self.onnx_session.set_providers(providers=["CUDAExecutionProvider"], provider_options=provider_options)
|
||||
if key == "modelSlotIndex":
|
||||
self.switchModel(int(val))
|
||||
# self.switchModel(int(val))
|
||||
self.tmp_slot = int(val)
|
||||
self.prepareModel(self.tmp_slot)
|
||||
self.slot = self.tmp_slot
|
||||
elif key in self.settings.floatData:
|
||||
setattr(self.settings, key, float(val))
|
||||
elif key in self.settings.strData:
|
||||
@ -321,6 +336,10 @@ class RVC:
|
||||
return result
|
||||
|
||||
def inference(self, data):
|
||||
if self.currentSlot != self.slot:
|
||||
self.currentSlot = self.slot
|
||||
self.switchModel()
|
||||
|
||||
if self.settings.framework == "ONNX":
|
||||
audio = self._onnx_inference(data)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user