From 3a4e534091d6c19849c776fcb5b21fc90ef72037 Mon Sep 17 00:00:00 2001 From: wataru Date: Wed, 3 May 2023 22:36:59 +0900 Subject: [PATCH] update --- server/MMVCServerSIO.py | 65 ++++++++++++++++++++++++++++++++- server/voice_changer/RVC/RVC.py | 3 +- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py index 032efaa7..906ec22a 100755 --- a/server/MMVCServerSIO.py +++ b/server/MMVCServerSIO.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor import sys from distutils.util import strtobool @@ -6,6 +7,9 @@ import socket import platform import os import argparse +import requests + +from tqdm import tqdm from voice_changer.utils.VoiceChangerParams import VoiceChangerParams import uvicorn @@ -106,6 +110,36 @@ def localServer(): ) +def download(params): + url = params["url"] + saveTo = params["saveTo"] + position = params["position"] + dirname = os.path.dirname(saveTo) + os.makedirs(dirname, exist_ok=True) + + try: + req = requests.get(url, stream=True, allow_redirects=True) + content_length = req.headers.get("content-length") + progress_bar = tqdm( + total=int(content_length) if content_length is not None else None, + leave=False, + unit="B", + unit_scale=True, + unit_divisor=1024, + position=position, + ) + + # with tqdm + with open(saveTo, "wb") as f: + for chunk in req.iter_content(chunk_size=1024): + if chunk: + progress_bar.update(len(chunk)) + f.write(chunk) + + except Exception as e: + print(e) + + if __name__ == "MMVCServerSIO": voiceChangerParams = VoiceChangerParams( content_vec_500=args.content_vec_500, @@ -116,9 +150,36 @@ if __name__ == "MMVCServerSIO": hubert_soft=args.hubert_soft, nsf_hifigan=args.nsf_hifigan, ) - voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams) - print("voiceChangerManager", voiceChangerManager) + # file exists check (currently only for rvc) + downloadParams = [] + if os.path.exists(voiceChangerParams.hubert_base) is False: + downloadParams.append( + { + "url": "https://huggingface.co/ddPn08/rvc-webui-models/resolve/main/embeddings/hubert_base.pt", + "saveTo": voiceChangerParams.hubert_base, + "position": 0, + } + ) + if os.path.exists(voiceChangerParams.hubert_base_jp) is False: + downloadParams.append( + { + "url": "https://huggingface.co/rinna/japanese-hubert-base/resolve/main/fairseq/model.pt", + "saveTo": voiceChangerParams.hubert_base_jp, + "position": 1, + } + ) + with ThreadPoolExecutor() as pool: + pool.map(download, downloadParams) + + if ( + os.path.exists(voiceChangerParams.hubert_base) is False + or os.path.exists(voiceChangerParams.hubert_base_jp) is False + ): + printMessage("RVC用のモデルファイルのダウンロードに失敗しました。", level=2) + printMessage("failed to download weight for rvc", level=2) + + voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams) app_fastapi = MMVC_Rest.get_instance(voiceChangerManager) app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager) diff --git a/server/voice_changer/RVC/RVC.py b/server/voice_changer/RVC/RVC.py index 2bdfaf9c..ae749046 100644 --- a/server/voice_changer/RVC/RVC.py +++ b/server/voice_changer/RVC/RVC.py @@ -114,7 +114,8 @@ class RVC: if modelSlot.embedder == EnumEmbedderTypes.hubert: emmbedderFilename = self.params.hubert_base elif modelSlot.embedder == EnumEmbedderTypes.contentvec: - emmbedderFilename = self.params.content_vec_500 + # emmbedderFilename = self.params.content_vec_500 + emmbedderFilename = self.params.hubert_base elif modelSlot.embedder == EnumEmbedderTypes.hubert_jp: emmbedderFilename = self.params.hubert_base_jp else: