hubert
This commit is contained in:
parent
f0fbf58258
commit
e10c832b46
@ -83,7 +83,7 @@ class SoVitsSvc40:
|
|||||||
ort_options.intra_op_num_threads = 8
|
ort_options.intra_op_num_threads = 8
|
||||||
self.hubert_onnx = onnxruntime.InferenceSession(
|
self.hubert_onnx = onnxruntime.InferenceSession(
|
||||||
"model_hubert/hubert_simple.onnx",
|
"model_hubert/hubert_simple.onnx",
|
||||||
providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
|
providers=providers
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
||||||
@ -143,8 +143,12 @@ class SoVitsSvc40:
|
|||||||
self.settings.gpu = 0
|
self.settings.gpu = 0
|
||||||
provider_options = [{'device_id': self.settings.gpu}]
|
provider_options = [{'device_id': self.settings.gpu}]
|
||||||
self.onnx_session.set_providers(providers=[val], provider_options=provider_options)
|
self.onnx_session.set_providers(providers=[val], provider_options=provider_options)
|
||||||
|
if hasattr(self, "hubert_onnx"):
|
||||||
|
self.hubert_onnx.set_providers(providers=[val], provider_options=provider_options)
|
||||||
else:
|
else:
|
||||||
self.onnx_session.set_providers(providers=[val])
|
self.onnx_session.set_providers(providers=[val])
|
||||||
|
if hasattr(self, "hubert_onnx"):
|
||||||
|
self.hubert_onnx.set_providers(providers=[val])
|
||||||
elif key == "onnxExecutionProvider" and self.onnx_session == None:
|
elif key == "onnxExecutionProvider" and self.onnx_session == None:
|
||||||
print("Onnx is not enabled. Please load model.")
|
print("Onnx is not enabled. Please load model.")
|
||||||
return False
|
return False
|
||||||
@ -217,6 +221,7 @@ class SoVitsSvc40:
|
|||||||
"audio": wav16k_numpy.reshape(1, -1),
|
"audio": wav16k_numpy.reshape(1, -1),
|
||||||
})
|
})
|
||||||
c = torch.from_numpy(np.array(c)).squeeze(0).transpose(1, 2)
|
c = torch.from_numpy(np.array(c)).squeeze(0).transpose(1, 2)
|
||||||
|
# print("onnx hubert:", self.hubert_onnx.get_providers())
|
||||||
else:
|
else:
|
||||||
self.hubert_model = self.hubert_model.to(dev)
|
self.hubert_model = self.hubert_model.to(dev)
|
||||||
wav16k_tensor = wav16k_tensor.to(dev)
|
wav16k_tensor = wav16k_tensor.to(dev)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user