From d7f5828710730e562d9ebc92f6f0b87ce98cea0b Mon Sep 17 00:00:00 2001 From: wataru Date: Fri, 16 Jun 2023 17:12:03 +0900 Subject: [PATCH] WIP:common sample --- server/const.py | 2 +- server/data/ModelSlot.py | 10 +++++++++- server/voice_changer/RVC/RVC.py | 23 +++++++++++------------ server/voice_changer/RVC/RVCSettings.py | 9 +++------ server/voice_changer/VoiceChanger.py | 7 ++++--- 5 files changed, 28 insertions(+), 23 deletions(-) diff --git a/server/const.py b/server/const.py index c5c18cc7..df45f70b 100644 --- a/server/const.py +++ b/server/const.py @@ -171,4 +171,4 @@ def getSampleJsonAndModelIds(mode: RVCSampleMode): RVC_MODEL_DIRNAME = "rvc" -RVC_MAX_SLOT_NUM = 10 +MAX_SLOT_NUM = 10 diff --git a/server/data/ModelSlot.py b/server/data/ModelSlot.py index 1087d90c..5d5d72bc 100644 --- a/server/data/ModelSlot.py +++ b/server/data/ModelSlot.py @@ -1,5 +1,5 @@ from typing import TypeAlias, Union -from const import EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType +from const import MAX_SLOT_NUM, EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType from dataclasses import dataclass, asdict @@ -54,6 +54,14 @@ def loadSlotInfo(model_dir: str, slotIndex: int) -> ModelSlots: return ModelSlot() +def loadAllSlotInfo(model_dir: str): + slotInfos: list[ModelSlots] = [] + for slotIndex in range(MAX_SLOT_NUM): + slotInfo = loadSlotInfo(model_dir, slotIndex) + slotInfos.append(slotInfo) + return slotInfos + + def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots): slotDir = os.path.join(model_dir, str(slotIndex)) json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w")) diff --git a/server/voice_changer/RVC/RVC.py b/server/voice_changer/RVC/RVC.py index 9999e6f9..3a10f16c 100644 --- a/server/voice_changer/RVC/RVC.py +++ b/server/voice_changer/RVC/RVC.py @@ -5,6 +5,7 @@ from typing import cast import numpy as np import torch import torchaudio +from data.ModelSlot import loadAllSlotInfo from utils.downloader.SampleDownloader import getSampleInfos from voice_changer.RVC.ModelSlot import ModelSlot from voice_changer.RVC.SampleDownloader import downloadModelFiles @@ -67,18 +68,12 @@ class RVC: self.pitchExtractor = PitchExtractorManager.getPitchExtractor(self.settings.f0Detector) self.params = params EmbedderManager.initialize(params) - self.loadSlots() + # self.loadSlots() + self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir) print("[Voice Changer] RVC initialization: ", params) # サンプルカタログ作成 - # sampleJsons: list[str] = [] samples = getSampleInfos(params.sample_mode) - # for url in sampleJsonUrls: - # filename = os.path.basename(url) - # sampleJsons.append(filename) - # sampleModels = getModelSamples(sampleJsons, "RVC") - # if sampleModels is not None: - # self.settings.sampleModels = sampleModels self.settings.sampleModels = samples # 起動時にスロットにモデルがある場合はロードしておく if len(self.settings.modelSlots) > 0: @@ -160,7 +155,8 @@ class RVC: if slotInfo.iconFile is not None and len(slotInfo.iconFile) > 0: slotInfo.iconFile = self.moveToModelDir(slotInfo.iconFile, slotDir) json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w")) - self.loadSlots() + # self.loadSlots() + self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir) # 初回のみロード(起動時にスロットにモデルがあった場合はinitialLoadはFalseになっている) if self.initialLoad: @@ -444,7 +440,8 @@ class RVC: params["defaultProtect"] = self.settings.protect json.dump(params, open(os.path.join(slotDir, "params.json"), "w")) - self.loadSlots() + # self.loadSlots() + self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir) def update_model_info(self, newData: str): print("[Voice Changer] UPDATE MODEL INFO", newData) @@ -456,7 +453,8 @@ class RVC: params = json.load(open(os.path.join(slotDir, "params.json"), "r", encoding="utf-8")) params[newDataDict["key"]] = newDataDict["val"] json.dump(params, open(os.path.join(slotDir, "params.json"), "w")) - self.loadSlots() + # self.loadSlots() + self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir) def upload_model_assets(self, params: str): print("[Voice Changer] UPLOAD ASSETS", params) @@ -479,4 +477,5 @@ class RVC: except Exception as e: print("Exception::::", e) - self.loadSlots() + # self.loadSlots() + self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir) diff --git a/server/voice_changer/RVC/RVCSettings.py b/server/voice_changer/RVC/RVCSettings.py index 1a25d624..d25b054d 100644 --- a/server/voice_changer/RVC/RVCSettings.py +++ b/server/voice_changer/RVC/RVCSettings.py @@ -1,8 +1,7 @@ from dataclasses import dataclass, field from ModelSample import RVCModelSample -from const import RVC_MAX_SLOT_NUM - -from voice_changer.RVC.ModelSlot import ModelSlot +from const import MAX_SLOT_NUM +from data.ModelSlot import ModelSlot, ModelSlots @dataclass @@ -17,9 +16,7 @@ class RVCSettings: clusterInferRatio: float = 0.1 framework: str = "PyTorch" # PyTorch or ONNX - modelSlots: list[ModelSlot] = field( - default_factory=lambda: [ModelSlot() for _x in range(RVC_MAX_SLOT_NUM)] - ) + modelSlots: list[ModelSlots] = field(default_factory=lambda: [ModelSlot() for _x in range(MAX_SLOT_NUM)]) sampleModels: list[RVCModelSample] = field(default_factory=lambda: []) diff --git a/server/voice_changer/VoiceChanger.py b/server/voice_changer/VoiceChanger.py index 4877b6c9..79183289 100755 --- a/server/voice_changer/VoiceChanger.py +++ b/server/voice_changer/VoiceChanger.py @@ -14,7 +14,7 @@ from voice_changer.IORecorder import IORecorder from voice_changer.utils.LoadModelParams import LoadModelParams from voice_changer.utils.Timer import Timer -from voice_changer.utils.VoiceChangerModel import VoiceChangerModel, AudioInOut +from voice_changer.utils.VoiceChangerModel import AudioInOut from Exceptions import ( DeviceCannotSupportHalfPrecisionException, DeviceChangingException, @@ -60,8 +60,6 @@ class VoiceChangerSettings: class VoiceChanger: - settings: VoiceChangerSettings = VoiceChangerSettings() - voiceChanger: VoiceChangerModel | None = None ioRecorder: IORecorder sola_buffer: AudioInOut namespace: socketio.AsyncNamespace | None = None @@ -148,7 +146,10 @@ class VoiceChanger: def get_info(self): data = asdict(self.settings) if self.voiceChanger is not None: + print("------------------ self.voiceChanger is not None") data.update(self.voiceChanger.get_info()) + else: + print("------------------ self.voiceChanger is None") return data def get_performance(self):