import glob
import math
import os
import sys
import time

import torch
import torch.nn as nn
import torchaudio
import torchvision.transforms as tvt
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion
from diffusers import Mel
from PIL import Image
from torch.utils.data import DataLoader, Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

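# Optional positional command-line arguments: image_size, num_epochs, batch_size.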
args = sys.argv[1:]

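# Dataset of AudioMNIST waveforms; the code assumes the spoken digit (0-9)
# is the first character of each .wav filename and uses it as the label.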
class Audio(Dataset):
    def __init__(self, folder):
        self.waveforms = []
        self.labels = []
        print("Loading files...")
        for file in glob.iglob(folder + '/**/*.wav', recursive=True):
            self.labels.append(int(file.split('/')[-1][0]))
            waveform, _ = torchaudio.load(file)
            self.waveforms.append(waveform)

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, index):
        return self.waveforms[index], self.labels[index]

image_size = 256
if len(args) >= 1:
    image_size = int(args[0])

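# Mel converts between raw audio and image_size x image_size mel-spectrogram images.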
# AudioMNIST is recorded at 48 kHz (the rate used when saving samples in
# train below); Mel defaults to 22050 Hz, so set the sample rate explicitly
# to keep the audio reconstruction at the right pitch.
MEL = Mel(x_res=image_size, y_res=image_size, sample_rate=48000)
img_to_tensor = tvt.PILToTensor()

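# Each waveform is cut into fixed-length slices and rendered as mel-spectrogram
# images, so one batch of waveforms can yield more than batch_size spectrograms.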
def collate(batch):
    spectros = []
    labels = []
    for waveform, label in batch:
        # Mel expects raw samples as a numpy array, not a torch tensor.
        MEL.load_audio(raw_audio=waveform[0].numpy())
        for s in range(MEL.get_number_of_slices()):
            spectro = MEL.audio_slice_to_image(s)
            spectro = img_to_tensor(spectro) / 255.0
            spectros.append(spectro)
            labels.append(label)

    spectros = torch.stack(spectros)
    labels = torch.tensor(labels)
    return spectros.to(device), labels.to(device)

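# Build a class-conditional U-Net (10 digit classes, single-channel images)
# wrapped in classifier-free-guidance Gaussian diffusion, plus its optimizer
# and an optional cyclic learning-rate scheduler.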
def initialize(use_scheduler=False):
    model = Unet(
        dim=64,
        num_classes=10,
        dim_mults=(1, 2, 4, 8),
        channels=1
    )
    diffusion = GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=1000,
        loss_type='l2',
        objective='pred_x0',
    )
    diffusion.to(device)

    optim = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = None
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optim, base_lr=1e-5, max_lr=1e-3, mode="exp_range", cycle_momentum=False
        )
    return diffusion, optim, scheduler

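# Format the wall-clock time elapsed since 'since' as "Xm Ys".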
def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

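# Standard training loop with gradient clipping and optional LR scheduling.
# After every epoch, sample 8 spectrograms for random digit labels and save
# both the image and the waveform reconstructed by Mel.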
def train(model, optim, train_dl, batch_size=32, epochs=5, scheduler=None):
    size = len(train_dl.dataset)
    model.train()
    losses = []

    for e in range(epochs):
        batch_loss, batch_counts = 0, 0
        for step, batch in enumerate(train_dl):
            model.zero_grad()
            batch_counts += 1
            spectros, labels = batch
            loss = model(spectros, classes=labels)

            batch_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1)
            optim.step()
            if scheduler is not None:
                scheduler.step()

            if (step % 100 == 0 and step != 0) or (step == len(train_dl) - 1):
                to_print = f"{e + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {timeSince(start)} | {step * batch_size:>5d}/{size:>5d}"
                print(to_print)
                losses.append(batch_loss / batch_counts)  # record the window's mean loss
                batch_loss, batch_counts = 0, 0

        # Sample a batch of spectrograms at the end of each epoch. randint's
        # upper bound is exclusive, so use 10 to cover all digit classes 0-9.
        labels = torch.randint(0, 10, (8,)).to(device)
        print(labels)
        samples = model.sample(labels)
        for i, sample in enumerate(samples):
            # Scale [0, 1] floats to 8-bit grayscale before building the PIL image.
            im = Image.fromarray((sample[0].cpu().numpy() * 255).astype("uint8"))
            audio = torch.tensor([MEL.image_to_audio(im)])
            torchaudio.save(f"audio/sample{e}_{i}_{labels[i]}.wav", audio, 48000)
            im.save(f"images/sample{e}_{i}_{labels[i]}.jpg")
    return losses

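# Entry point: read optional CLI overrides, build the model, train, and save
# the final weights.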
if __name__ == "__main__":
    num_epochs = 10
    if len(args) >= 2:
        num_epochs = int(args[1])

    batch_size = 32
    if len(args) >= 3:
        batch_size = int(args[2])

    print(image_size, num_epochs, batch_size)
    # The sampling loop writes into audio/ and images/; create them up front.
    os.makedirs("audio", exist_ok=True)
    os.makedirs("images", exist_ok=True)

    model, optim, scheduler = initialize(use_scheduler=True)
    train_data = Audio("AudioMNIST/data")
    print("Done Loading")
    train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate)
    train(model, optim, train_dl, batch_size, num_epochs, scheduler)
    torch.save(model.state_dict(), "diffusion_condition_model.pt")