WIP:common sample

This commit is contained in:
wataru 2023-06-16 17:12:03 +09:00
parent 87de3dc10f
commit d7f5828710
5 changed files with 28 additions and 23 deletions

View File

@ -171,4 +171,4 @@ def getSampleJsonAndModelIds(mode: RVCSampleMode):
RVC_MODEL_DIRNAME = "rvc" RVC_MODEL_DIRNAME = "rvc"
RVC_MAX_SLOT_NUM = 10 MAX_SLOT_NUM = 10

View File

@ -1,5 +1,5 @@
from typing import TypeAlias, Union from typing import TypeAlias, Union
from const import EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType from const import MAX_SLOT_NUM, EnumInferenceTypes, EnumEmbedderTypes, VoiceChangerType
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
@ -54,6 +54,14 @@ def loadSlotInfo(model_dir: str, slotIndex: int) -> ModelSlots:
return ModelSlot() return ModelSlot()
def loadAllSlotInfo(model_dir: str):
slotInfos: list[ModelSlots] = []
for slotIndex in range(MAX_SLOT_NUM):
slotInfo = loadSlotInfo(model_dir, slotIndex)
slotInfos.append(slotInfo)
return slotInfos
def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots): def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots):
slotDir = os.path.join(model_dir, str(slotIndex)) slotDir = os.path.join(model_dir, str(slotIndex))
json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w")) json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w"))

View File

@ -5,6 +5,7 @@ from typing import cast
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio
from data.ModelSlot import loadAllSlotInfo
from utils.downloader.SampleDownloader import getSampleInfos from utils.downloader.SampleDownloader import getSampleInfos
from voice_changer.RVC.ModelSlot import ModelSlot from voice_changer.RVC.ModelSlot import ModelSlot
from voice_changer.RVC.SampleDownloader import downloadModelFiles from voice_changer.RVC.SampleDownloader import downloadModelFiles
@ -67,18 +68,12 @@ class RVC:
self.pitchExtractor = PitchExtractorManager.getPitchExtractor(self.settings.f0Detector) self.pitchExtractor = PitchExtractorManager.getPitchExtractor(self.settings.f0Detector)
self.params = params self.params = params
EmbedderManager.initialize(params) EmbedderManager.initialize(params)
self.loadSlots() # self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)
print("[Voice Changer] RVC initialization: ", params) print("[Voice Changer] RVC initialization: ", params)
# サンプルカタログ作成 # サンプルカタログ作成
# sampleJsons: list[str] = []
samples = getSampleInfos(params.sample_mode) samples = getSampleInfos(params.sample_mode)
# for url in sampleJsonUrls:
# filename = os.path.basename(url)
# sampleJsons.append(filename)
# sampleModels = getModelSamples(sampleJsons, "RVC")
# if sampleModels is not None:
# self.settings.sampleModels = sampleModels
self.settings.sampleModels = samples self.settings.sampleModels = samples
# 起動時にスロットにモデルがある場合はロードしておく # 起動時にスロットにモデルがある場合はロードしておく
if len(self.settings.modelSlots) > 0: if len(self.settings.modelSlots) > 0:
@ -160,7 +155,8 @@ class RVC:
if slotInfo.iconFile is not None and len(slotInfo.iconFile) > 0: if slotInfo.iconFile is not None and len(slotInfo.iconFile) > 0:
slotInfo.iconFile = self.moveToModelDir(slotInfo.iconFile, slotDir) slotInfo.iconFile = self.moveToModelDir(slotInfo.iconFile, slotDir)
json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w")) json.dump(asdict(slotInfo), open(os.path.join(slotDir, "params.json"), "w"))
self.loadSlots() # self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)
# 初回のみロード(起動時にスロットにモデルがあった場合はinitialLoadはFalseになっている) # 初回のみロード(起動時にスロットにモデルがあった場合はinitialLoadはFalseになっている)
if self.initialLoad: if self.initialLoad:
@ -444,7 +440,8 @@ class RVC:
params["defaultProtect"] = self.settings.protect params["defaultProtect"] = self.settings.protect
json.dump(params, open(os.path.join(slotDir, "params.json"), "w")) json.dump(params, open(os.path.join(slotDir, "params.json"), "w"))
self.loadSlots() # self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)
def update_model_info(self, newData: str): def update_model_info(self, newData: str):
print("[Voice Changer] UPDATE MODEL INFO", newData) print("[Voice Changer] UPDATE MODEL INFO", newData)
@ -456,7 +453,8 @@ class RVC:
params = json.load(open(os.path.join(slotDir, "params.json"), "r", encoding="utf-8")) params = json.load(open(os.path.join(slotDir, "params.json"), "r", encoding="utf-8"))
params[newDataDict["key"]] = newDataDict["val"] params[newDataDict["key"]] = newDataDict["val"]
json.dump(params, open(os.path.join(slotDir, "params.json"), "w")) json.dump(params, open(os.path.join(slotDir, "params.json"), "w"))
self.loadSlots() # self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)
def upload_model_assets(self, params: str): def upload_model_assets(self, params: str):
print("[Voice Changer] UPLOAD ASSETS", params) print("[Voice Changer] UPLOAD ASSETS", params)
@ -479,4 +477,5 @@ class RVC:
except Exception as e: except Exception as e:
print("Exception::::", e) print("Exception::::", e)
self.loadSlots() # self.loadSlots()
self.settings.modelSlots = loadAllSlotInfo(self.params.model_dir)

View File

@ -1,8 +1,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from ModelSample import RVCModelSample from ModelSample import RVCModelSample
from const import RVC_MAX_SLOT_NUM from const import MAX_SLOT_NUM
from data.ModelSlot import ModelSlot, ModelSlots
from voice_changer.RVC.ModelSlot import ModelSlot
@dataclass @dataclass
@ -17,9 +16,7 @@ class RVCSettings:
clusterInferRatio: float = 0.1 clusterInferRatio: float = 0.1
framework: str = "PyTorch" # PyTorch or ONNX framework: str = "PyTorch" # PyTorch or ONNX
modelSlots: list[ModelSlot] = field( modelSlots: list[ModelSlots] = field(default_factory=lambda: [ModelSlot() for _x in range(MAX_SLOT_NUM)])
default_factory=lambda: [ModelSlot() for _x in range(RVC_MAX_SLOT_NUM)]
)
sampleModels: list[RVCModelSample] = field(default_factory=lambda: []) sampleModels: list[RVCModelSample] = field(default_factory=lambda: [])

View File

@ -14,7 +14,7 @@ from voice_changer.IORecorder import IORecorder
from voice_changer.utils.LoadModelParams import LoadModelParams from voice_changer.utils.LoadModelParams import LoadModelParams
from voice_changer.utils.Timer import Timer from voice_changer.utils.Timer import Timer
from voice_changer.utils.VoiceChangerModel import VoiceChangerModel, AudioInOut from voice_changer.utils.VoiceChangerModel import AudioInOut
from Exceptions import ( from Exceptions import (
DeviceCannotSupportHalfPrecisionException, DeviceCannotSupportHalfPrecisionException,
DeviceChangingException, DeviceChangingException,
@ -60,8 +60,6 @@ class VoiceChangerSettings:
class VoiceChanger: class VoiceChanger:
settings: VoiceChangerSettings = VoiceChangerSettings()
voiceChanger: VoiceChangerModel | None = None
ioRecorder: IORecorder ioRecorder: IORecorder
sola_buffer: AudioInOut sola_buffer: AudioInOut
namespace: socketio.AsyncNamespace | None = None namespace: socketio.AsyncNamespace | None = None
@ -148,7 +146,10 @@ class VoiceChanger:
def get_info(self): def get_info(self):
data = asdict(self.settings) data = asdict(self.settings)
if self.voiceChanger is not None: if self.voiceChanger is not None:
print("------------------ self.voiceChanger is not None")
data.update(self.voiceChanger.get_info()) data.update(self.voiceChanger.get_info())
else:
print("------------------ self.voiceChanger is None")
return data return data
def get_performance(self): def get_performance(self):