2023-07-27 04:06:25 +09:00

287 lines
9.7 KiB
Python

import glob
import logging
import os
import shutil
import socket
import sys
import ffmpeg
import matplotlib
import matplotlib.pylab as plt
import numpy as np
import torch
from scipy.io.wavfile import read
from torch.nn import functional as F
from modules.shared import ROOT_DIR
from .config import TrainConfig
# Select the Agg backend so figures are rendered off-screen into a buffer;
# the plot_*_to_numpy helpers in this module read pixels back from the canvas.
matplotlib.use("Agg")
# Only let WARNING and above through from the "matplotlib" logger.
logging.getLogger("matplotlib").setLevel(logging.WARNING)
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# NOTE: `logger` is bound to the `logging` module itself, not to a Logger
# instance, so calls such as `logger.info(...)` go through the module-level
# (root logger) functions.
logger = logging
class AWP:
    """
    Fast AWP (adversarial weight perturbation).

    https://www.kaggle.com/code/junkoda/fast-awp

    The perturbation direction is taken from the optimizer's ``exp_avg``
    state of each parameter. Typical use inside a training step:

        awp.perturb()      # before computing the loss and loss.backward()
        loss.backward()
        awp.restore()      # after loss.backward(), before optimizer.step()

    Args:
        model: module whose ``named_parameters()`` are perturbed.
        optimizer: optimizer whose per-parameter state holds ``exp_avg``.
        adv_param: only parameters whose name contains this substring
            are perturbed.
        adv_lr: step size of the perturbation, relative to |w| / |grad|.
        adv_eps: maximum relative change allowed per weight element.
    """

    def __init__(self, model, optimizer, *, adv_param='weight',
                 adv_lr=0.01, adv_eps=0.01):
        self.model = model
        self.optimizer = optimizer
        self.adv_param = adv_param
        self.adv_lr = adv_lr
        self.adv_eps = adv_eps
        # name -> copy of the unperturbed parameter, reused across steps
        self.backup = {}

    def perturb(self):
        """
        Perturb model parameters for AWP gradient
        Call before loss and loss.backward()
        """
        self._save()  # save model parameters
        self._attack_step()  # perturb weights

    def _is_target(self, name, param):
        """Return True when ``param`` is trainable, has a gradient and matches ``adv_param``."""
        return param.requires_grad and param.grad is not None and self.adv_param in name

    def _attack_step(self):
        e = 1e-6
        for name, param in self.model.named_parameters():
            if not self._is_target(name, param):
                continue
            # The optimizer creates its per-parameter state lazily, so a
            # parameter can have a gradient but no ``exp_avg`` yet (for
            # example before the first optimizer.step()). Use .get() so we
            # neither raise KeyError nor insert an empty state entry; such
            # parameters are simply left unperturbed.
            state = self.optimizer.state.get(param)
            grad = state.get('exp_avg') if state else None
            if grad is None:
                continue
            norm_grad = torch.norm(grad)
            norm_data = torch.norm(param.detach())
            if norm_grad != 0 and not torch.isnan(norm_grad):
                # Set lower and upper limit in change
                limit_eps = self.adv_eps * param.detach().abs()
                param_min = param.data - limit_eps
                param_max = param.data + limit_eps
                # Perturb along gradient
                # w += (adv_lr * |w| / |grad|) * grad
                param.data.add_(grad, alpha=(self.adv_lr * (norm_data + e) / (norm_grad + e)))
                # Apply the limit to the change
                param.data.clamp_(param_min, param_max)

    def _save(self):
        """Copy the current value of every targeted parameter into ``self.backup``."""
        for name, param in self.model.named_parameters():
            if self._is_target(name, param):
                if name not in self.backup:
                    self.backup[name] = param.clone().detach()
                else:
                    # Reuse the existing buffer instead of allocating a new one.
                    self.backup[name].copy_(param.data)

    def restore(self):
        """
        Restore model parameters to their unperturbed values; the
        perturbation is only meant to influence the gradients, not the
        weights that the optimizer updates.
        Call after loss.backward(), before optimizer.step()
        """
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])
def load_audio(file: str, sr):
    """Decode an audio file to a mono float32 waveform via the ffmpeg CLI.

    Args:
        file: path to the audio file; surrounding spaces, double quotes and
            newlines (as left by a copy-pasted path) are stripped.
        sr: target sample rate passed to ffmpeg as ``ar``.

    Returns:
        1-D ``np.float32`` array with the decoded samples.

    Raises:
        RuntimeError: if anything goes wrong while preparing or running
            ffmpeg. The original exception is chained as ``__cause__`` and,
            when it carries captured ``stderr``, that text is appended to
            the message.
    """
    try:
        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        file = (
            file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        )  # Tolerate copy-pasted paths wrapped in spaces, quotes or a newline
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except Exception as e:
        detail = str(e)
        # With capture_stderr=True the failure raised by run() carries the
        # process's stderr, which holds the actual reason for the failure.
        stderr = getattr(e, "stderr", None)
        if isinstance(stderr, bytes):
            stderr = stderr.decode(errors="replace")
        if stderr:
            detail = f"{detail}: {stderr.strip()}"
        raise RuntimeError(f"Failed to load audio: {detail}") from e
    return np.frombuffer(out, np.float32).flatten()
def find_empty_port():
    """Return a TCP port number that was free at the time of the call.

    Binds a temporary socket to port 0 so the OS assigns a port, reads the
    assigned number and closes the socket again. Another process may take
    the port between this call and its later use.
    """
    # The context manager closes the socket even if bind()/listen() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        return s.getsockname()[1]
def _resize_2d_weight(saved, target_shape):
    """Bilinearly resample a 2-D checkpoint tensor to ``target_shape``."""
    # Half-precision tensors are resampled in float32 and cast back
    # (presumably because bilinear interpolation is not available for
    # fp16 on CPU - confirm before changing).
    is_half = saved.dtype == torch.half
    source = saved.float() if is_half else saved
    resized = F.interpolate(
        source.unsqueeze(0).unsqueeze(0),
        size=target_shape,
        mode="bilinear",
    )
    if is_half:
        resized = resized.half()
    return resized.squeeze(0).squeeze(0)


def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    """Load model (and optionally optimizer) state from a checkpoint file.

    For every entry the model expects:
      * same shape in the checkpoint -> the checkpoint tensor is used;
      * different shape, checkpoint tensor is 2-D -> it is bilinearly
        resized to the expected shape (e.g. embedding inputs 256 <==> 768,
        so training can continue from pretrained checkpoints made for a
        different embedder output size);
      * missing, or mismatched and not resizable -> the model keeps its
        own current value.

    Args:
        checkpoint_path: file holding a dict with the keys ``"model"``,
            ``"epoch"``, ``"learning_rate"`` and ``"optimizer"``.
        model: module to load into; if it has a ``module`` attribute
            that inner module is used.
        optimizer: optimizer to restore, or None.
        load_opt: the optimizer state is restored only when this equals 1.

    Returns:
        Tuple ``(model, optimizer, learning_rate, epoch)``.
    """
    assert os.path.isfile(checkpoint_path), f"checkpoint not found: {checkpoint_path}"
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():  # v has the shape required by the model
        if k not in saved_state_dict:
            print(f"{k} is not in the checkpoint")
            new_state_dict[k] = v  # keep the model's own (initial) value
            continue
        saved = saved_state_dict[k]
        if saved.shape == v.shape:
            new_state_dict[k] = saved
            continue
        print(f"shape-{k}-mismatch|need-{v.shape}|get-{saved.shape}")
        if saved.dim() != 2:  # NOTE: only 2-D tensors are resized; check is this ok?
            print(f"{k} cannot be resized (not 2-D); keeping the model's own value")
            new_state_dict[k] = v
            continue
        try:
            new_state_dict[k] = _resize_2d_weight(saved, v.shape)
        except Exception as e:
            # Best effort: a failed resize must not abort loading.
            print(f"{k} could not be interpolated; keeping the model's own value")
            print("error: %s" % e)
            new_state_dict[k] = v
            continue
        print(
            "interpolated new_state_dict",
            k,
            "from",
            saved.shape,
            "to",
            new_state_dict[k].shape,
        )
    target.load_state_dict(new_state_dict, strict=False)
    print("Loaded model weights")
    epoch = checkpoint_dict["epoch"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    print("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, epoch))
    return model, optimizer, learning_rate, epoch
def save_state(model, optimizer, learning_rate, epoch, checkpoint_path):
    """Write model and optimizer state to ``checkpoint_path`` with torch.save.

    The saved object is a dict with the keys ``"model"``, ``"epoch"``,
    ``"optimizer"`` and ``"learning_rate"``. If ``model`` has a ``module``
    attribute, the state dict of that inner module is stored.
    """
    message = "Saving model and optimizer state at epoch {} to {}".format(
        epoch, checkpoint_path
    )
    print(message)

    # Take the weights from the wrapped module when there is one.
    source = model.module if hasattr(model, "module") else model
    payload = {
        "model": source.state_dict(),
        "epoch": epoch,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Forward a batch of named values to a summary writer.

    Args:
        writer: object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio``.
        global_step: step value passed with every entry.
        scalars: mapping tag -> value for ``add_scalar``; None means none.
        histograms: mapping tag -> value for ``add_histogram``; None means none.
        images: mapping tag -> image for ``add_image``, written with
            ``dataformats="HWC"``; None means none.
        audios: mapping tag -> audio for ``add_audio``; None means none.
        audio_sampling_rate: sample rate passed to ``add_audio``.
    """
    # None defaults replace the former shared mutable ``{}`` defaults.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the checkpoint in ``dir_path`` with the highest number in its name.

    Args:
        dir_path: directory to search.
        regex: glob pattern (despite the name) matched against file names.

    Returns:
        Path of the matching file whose file name contains the largest
        number (all digits of the name joined together), or None when
        nothing matches. Names without any digit rank below all numbered
        ones.
    """
    filelist = glob.glob(os.path.join(dir_path, regex))
    if not filelist:
        return None

    def step_of(path):
        # Only the file name counts: digits in the directory part would
        # otherwise be glued in front of the step number and distort the
        # ordering (e.g. zero-padded steps).
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    filelist.sort(key=step_of)
    return filelist[-1]
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram with matplotlib and return the picture as an array.

    Args:
        spectrogram: 2-D array-like shown with ``imshow`` (origin at the
            bottom), with the x axis labelled "Frames" and the y axis
            "Channels".

    Returns:
        ``np.uint8`` array of shape (height, width, 3) holding the RGB
        pixels of the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb() + np.fromstring
        # pair; drop the alpha channel and copy so the result does not alias
        # the canvas buffer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = rgba[..., :3].copy()
    finally:
        # Close this specific figure, also when rendering fails.
        plt.close(fig)
    return data
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix with matplotlib and return the picture as an array.

    Args:
        alignment: 2-D array; its transpose is shown with ``imshow``
            (origin at the bottom), x axis "Decoder timestep", y axis
            "Encoder timestep".
        info: optional text appended below the x-axis label.

    Returns:
        ``np.uint8`` array of shape (height, width, 3) holding the RGB
        pixels of the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(
            alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
        )
        fig.colorbar(im, ax=ax)
        xlabel = "Decoder timestep"
        if info is not None:
            xlabel += "\n\n" + info
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Encoder timestep")
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb() + np.fromstring
        # pair; drop the alpha channel and copy so the result does not alias
        # the canvas buffer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = rgba[..., :3].copy()
    finally:
        # Close this specific figure, also when rendering fails.
        plt.close(fig)
    return data
def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The samples are converted to float32 without rescaling and wrapped in
    a ``torch.FloatTensor``; the sampling rate is returned as read from
    the file.
    """
    rate, samples = read(full_path)
    as_float = samples.astype(np.float32)
    waveform = torch.FloatTensor(as_float)
    return waveform, rate
def load_config(training_dir: str, sample_rate: int, emb_channels: int):
    """Copy the matching config template into ``training_dir`` and parse it.

    The template is ``<ROOT_DIR>/configs/<sample_rate>.json`` when
    ``emb_channels`` is 256 and
    ``<ROOT_DIR>/configs/<sample_rate>-<emb_channels>.json`` otherwise.
    It is copied to ``<training_dir>/config.json`` and the copy is parsed
    with ``TrainConfig.parse_file``, whose result is returned.
    """
    if emb_channels == 256:
        template_name = f"{sample_rate}.json"
    else:
        template_name = f"{sample_rate}-{emb_channels}.json"
    template_path = os.path.join(ROOT_DIR, "configs", template_name)

    saved_path = os.path.join(training_dir, "config.json")
    shutil.copyfile(template_path, saved_path)
    return TrainConfig.parse_file(saved_path)