import glob
import logging
import os
import shutil
import socket
import sys

import ffmpeg
import matplotlib
import matplotlib.pylab as plt
import numpy as np
import torch
from scipy.io.wavfile import read
from torch.nn import functional as F

from modules.shared import ROOT_DIR

from .config import TrainConfig

matplotlib.use("Agg")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# NOTE(review): this binds the logging *module* (not a Logger) to `logger`;
# kept as-is because other modules may call `logger.<fn>` on it — confirm
# before switching to logging.getLogger(__name__).
logger = logging


class AWP:
    """
    Fast AWP (Adversarial Weight Perturbation).

    https://www.kaggle.com/code/junkoda/fast-awp

    Intended call order within one training step:
        awp.perturb()       # before computing the loss and loss.backward()
        loss.backward()
        awp.restore()       # after loss.backward(), before optimizer.step()

    Only parameters whose name contains ``adv_param``, that require grad and
    already have a ``.grad`` are perturbed.
    """

    def __init__(self, model, optimizer, *, adv_param='weight', adv_lr=0.01, adv_eps=0.01):
        self.model = model
        self.optimizer = optimizer
        self.adv_param = adv_param  # substring a parameter name must contain
        self.adv_lr = adv_lr  # step size relative to |w| / |grad|
        self.adv_eps = adv_eps  # max per-element change, relative to |w|
        self.backup = {}  # name -> saved copy of the unperturbed parameter

    def perturb(self):
        """
        Perturb model parameters for AWP gradient.

        Call before loss and loss.backward().
        """
        self._save()  # save model parameters
        self._attack_step()  # perturb weights

    def _attack_step(self):
        """Move each selected parameter along the optimizer's ``exp_avg``.

        The step is ``adv_lr * |w| / |exp_avg|`` times ``exp_avg``; the
        per-element change is then clamped to ``adv_eps * |w|``.
        Parameters without an ``exp_avg`` entry in the optimizer state (e.g.
        before the first ``optimizer.step()``, or with an optimizer that does
        not keep one) are skipped instead of raising ``KeyError``.
        """
        e = 1e-6
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and param.grad is not None and self.adv_param in name):
                continue
            grad = self.optimizer.state[param].get('exp_avg')
            if grad is None:
                continue
            norm_grad = torch.norm(grad)
            norm_data = torch.norm(param.detach())
            # Skip zero and non-finite norms: an inf norm would give
            # alpha == 0 and 0 * inf == NaN in the weights.
            if norm_grad != 0 and torch.isfinite(norm_grad):
                # Set lower and upper limit in change
                limit_eps = self.adv_eps * param.detach().abs()
                param_min = param.data - limit_eps
                param_max = param.data + limit_eps

                # Perturb along gradient
                # w += (adv_lr * |w| / |grad|) * grad
                param.data.add_(grad, alpha=(self.adv_lr * (norm_data + e) / (norm_grad + e)))

                # Apply the limit to the change
                param.data.clamp_(param_min, param_max)

    def _save(self):
        """Snapshot every parameter that ``_attack_step`` may perturb."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                if name not in self.backup:
                    self.backup[name] = param.clone().detach()
                else:
                    # Reuse the existing buffer to avoid a fresh allocation.
                    self.backup[name].copy_(param.data)

    def restore(self):
        """
        Restore model parameters to their saved (unperturbed) values.

        AWP does not permanently change the weights; the perturbation only
        affects the gradients computed while it is applied.
        Call after loss.backward(), before optimizer.step().
        """
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])


def load_audio(file: str, sr):
    """Decode an audio file into a mono float32 waveform at ``sr`` Hz.

    Decoding, down-mixing to one channel and resampling are delegated to the
    ``ffmpeg`` CLI through the ``ffmpeg-python`` package (both required).

    Args:
        file: Path to the audio file. Leading/trailing spaces, double quotes
            and newlines (common in copy-pasted paths) are stripped.
        sr: Target sample rate in Hz.

    Returns:
        1-D ``np.float32`` array of samples.

    Raises:
        RuntimeError: If the path cannot be processed or ffmpeg fails.
    """
    try:
        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        # Strip spaces, quotes and newlines that users often copy along with a path.
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load audio: {e}") from e

    return np.frombuffer(out, np.float32).flatten()


def find_empty_port():
    """Return a TCP port number the OS reports as currently free.

    The socket is closed before returning, so another process could still
    grab the port before the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0 asks the OS to pick a free port
        s.listen(1)
        return s.getsockname()[1]


def _interpolate_2d(weight, size):
    """Bilinearly resize the 2-D tensor ``weight`` to ``size``, keeping its dtype.

    Half/bfloat16 tensors are upcast to float32 for the interpolation and
    cast back afterwards.
    """
    upcast = weight.dtype in (torch.half, torch.bfloat16)
    src = weight.float() if upcast else weight
    # F.interpolate(mode="bilinear") works on (N, C, H, W): add two leading
    # singleton dims and remove them again afterwards.
    out = (
        F.interpolate(src.unsqueeze(0).unsqueeze(0), size=size, mode="bilinear")
        .squeeze(0)
        .squeeze(0)
    )
    return out.to(weight.dtype) if upcast else out


def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    """Load model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is a dict with keys "model", "epoch", "learning_rate" and
    "optimizer". Parameters are matched by name:

    * missing in the checkpoint -> keep the model's current value;
    * same shape -> copy from the checkpoint;
    * 2-D with a different shape -> bilinearly resized to the model's shape
      (e.g. to continue from checkpoints trained with a different embedding
      width such as 256 vs 768);
    * any other shape mismatch -> keep the model's current value.

    Args:
        checkpoint_path: Path to the checkpoint file.
        model: Model to load into; if it has a ``.module`` attribute (wrapped
            model), that inner module is loaded instead.
        optimizer: Optional optimizer whose state is restored when
            ``load_opt == 1``.
        load_opt: 1 to load optimizer state, anything else to skip it.

    Returns:
        ``(model, optimizer, learning_rate, epoch)``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is None or not a file.
    """
    if checkpoint_path is None or not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model"]

    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():  # v has the shape the model requires
        if k not in saved_state_dict:
            print(f"{k} is not in the checkpoint")
            new_state_dict[k] = v  # the model's own (randomly initialized) values
            continue

        saved = saved_state_dict[k]
        if saved.shape == v.shape:
            new_state_dict[k] = saved
            continue

        print(f"shape-{k}-mismatch|need-{v.shape}|get-{saved.shape}")
        if saved.dim() != 2:
            print(f"{k} cannot be resized; keeping the model's own values")
            new_state_dict[k] = v
            continue

        # NOTE: check is this ok?
        # Resizing 2-D weights (e.g. embedding input 256 <==> 768) lets training
        # continue from the original pretrained checkpoints when using an
        # embedder with a different output width.
        try:
            new_state_dict[k] = _interpolate_2d(saved, v.shape)
        except (RuntimeError, ValueError) as e:
            print(f"{k} could not be interpolated; keeping the model's own values")
            print("error: %s" % e)
            new_state_dict[k] = v
            continue
        print(
            "interpolated new_state_dict",
            k,
            "from",
            saved.shape,
            "to",
            new_state_dict[k].shape,
        )

    target.load_state_dict(new_state_dict, strict=False)
    print("Loaded model weights")

    epoch = checkpoint_dict["epoch"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    print("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, epoch))
    return model, optimizer, learning_rate, epoch


def save_state(model, optimizer, learning_rate, epoch, checkpoint_path):
    """Save model/optimizer state, epoch and learning rate to ``checkpoint_path``.

    Writes the same dict layout that ``load_checkpoint`` reads. If ``model``
    has a ``.module`` attribute, the inner module's state is saved.
    """
    print(
        "Saving model and optimizer state at epoch {} to {}".format(
            epoch, checkpoint_path
        )
    )
    target = model.module if hasattr(model, "module") else model
    torch.save(
        {
            "model": target.state_dict(),
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )


def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Log scalars, histograms, images (HWC) and audio clips to ``writer``.

    Each mapping goes from tag to value; ``None`` means "nothing to log".
    """
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the checkpoint in ``dir_path`` with the highest step number.

    ``regex`` is a glob pattern. The step is the number formed by the digits
    in the file's *basename*, so digits in the directory path do not affect
    the ordering. Files with no digits sort first instead of raising.
    Returns None if no file matches.
    """
    filelist = glob.glob(os.path.join(dir_path, regex))
    if not filelist:
        return None

    def _step(path):
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    return sorted(filelist, key=_step)[-1]


def _figure_to_rgb_array(fig):
    """Render ``fig`` and return its pixels as an (H, W, 3) uint8 array."""
    fig.canvas.draw()
    # `canvas.tostring_rgb` (used with the deprecated `np.fromstring`) was
    # removed in Matplotlib 3.10; `buffer_rgba` is the supported accessor.
    # Drop the alpha channel and copy so the result owns its memory.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return rgba[..., :3].copy()


def plot_spectrogram_to_numpy(spectrogram):
    """Plot a 2-D spectrogram and return the rendered image as (H, W, 3) uint8."""
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    data = _figure_to_rgb_array(fig)
    plt.close(fig)
    return data


def plot_alignment_to_numpy(alignment, info=None):
    """Plot an alignment matrix (transposed) and return the image as (H, W, 3) uint8.

    ``info``, if given, is appended below the x-axis label.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    data = _figure_to_rgb_array(fig)
    plt.close(fig)
    return data


def load_wav_to_torch(full_path):
    """Read a WAV file; return ``(FloatTensor of samples, sampling_rate)``.

    Samples are cast to float32 without rescaling, so integer PCM keeps its
    integer range.
    """
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_config(training_dir: str, sample_rate: int, emb_channels: int):
    """Copy the matching training config into ``training_dir`` and parse it.

    ``configs/{sample_rate}.json`` is used for 256 embedding channels,
    otherwise ``configs/{sample_rate}-{emb_channels}.json`` (under ROOT_DIR).
    The copy is written to ``training_dir/config.json``.
    """
    if emb_channels == 256:
        config_path = os.path.join(ROOT_DIR, "configs", f"{sample_rate}.json")
    else:
        config_path = os.path.join(
            ROOT_DIR, "configs", f"{sample_rate}-{emb_channels}.json"
        )

    config_save_path = os.path.join(training_dir, "config.json")

    shutil.copyfile(config_path, config_save_path)

    return TrainConfig.parse_file(config_save_path)