| import torch |
| import torch.optim as optim |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from torchvision.utils import make_grid, save_image |
| from tqdm import tqdm |
| from ddt_model import LocalSongModel |
| from transformers import get_cosine_schedule_with_warmup |
| from datasets import load_from_disk |
| from accelerate import Accelerator |
| import os |
| import argparse |
| from torch.utils.tensorboard import SummaryWriter |
| from datetime import datetime |
| from collections import deque |
| import torchaudio |
| import re |
| import sys |
| import math |
| from tag_embedder import TagEmbedder |
|
|
| |
| from acestep.music_dcae.music_dcae_pipeline import MusicDCAE |
|
|
| |
# Make the repository root importable.
# NOTE(review): this runs AFTER the `ddt_model`/`acestep` imports above, so it
# cannot help resolve them — consider moving it before those imports.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import timm.optim


# NOTE(review): duplicate import — `os` is already imported above.
import os


# Avoid HF tokenizer fork warnings/deadlocks when DataLoader workers are used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
def save(model, optimizer, scheduler, global_step, accelerator):
    """Save a training checkpoint (main process only) and prune old ones.

    Writes model, optimizer and scheduler state plus the global step to
    ``checkpoints/checkpoint_<step>.pth`` and keeps only the 5 newest files.

    Args:
        model: The (possibly accelerate-wrapped) model being trained.
        optimizer: Optimizer whose state is persisted.
        scheduler: LR scheduler; its state is persisted so a resumed run
            continues the same warmup/cosine curve (previously this argument
            was ignored and the schedule restarted on resume).
        global_step: Step counter embedded in the filename and the payload.
        accelerator: accelerate.Accelerator used for unwrapping and saving.
    """
    if accelerator.is_main_process:
        checkpoint_dir = "checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Strip the accelerate/DDP/compile wrapper so stored keys are plain.
        unwrapped_model = accelerator.unwrap_model(model)

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{global_step}.pth")
        save_dict = {
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }
        # Persist the LR schedule position too (guard keeps old call sites
        # that pass scheduler=None working).
        if scheduler is not None:
            save_dict['scheduler_state_dict'] = scheduler.state_dict()

        accelerator.save(save_dict, checkpoint_path)
        print(f"Checkpoint saved at step {global_step}: {checkpoint_path}")

        # Prune: keep only the 5 most recent checkpoints (by step number).
        checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pth")],
                             key=lambda x: int(x.split("_")[1].split(".")[0]), reverse=True)

        for old_checkpoint in checkpoints[5:]:
            os.remove(os.path.join(checkpoint_dir, old_checkpoint))
            print(f"Removed old checkpoint: {old_checkpoint}")
|
|
|
|
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator):
    """Load a checkpoint into model/optimizer/scheduler and return its step.

    Args:
        model: Model (possibly wrapped) to restore weights into.
        optimizer: Optimizer to restore, when the checkpoint contains it.
        scheduler: LR scheduler to restore, when the checkpoint contains it
            (older checkpoints may lack the key; previously this argument was
            ignored entirely).
        checkpoint_path: Path to a ``checkpoint_<step>.pth`` file.
        accelerator: accelerate.Accelerator used to unwrap the model.

    Returns:
        The ``global_step`` stored in the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    unwrapped_model = accelerator.unwrap_model(model)
    # Drop any torch.compile "_orig_mod." prefix so keys match the plain module.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model_state_dict'].items()}
    missing, unexpected = unwrapped_model.load_state_dict(state_dict, strict=True)
    print("MISSING:", missing)
    print("UNEXPECTED:", unexpected)

    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer loaded")

    # Restore the LR schedule position when available so warmup/decay
    # continues where it left off instead of restarting.
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Scheduler loaded")

    global_step = checkpoint['global_step']
    print(f"Resumed from step {global_step}")
    return global_step
|
|
def resume(model, optimizer, scheduler, accelerator):
    """Resume training from the newest checkpoint in ./checkpoints.

    Scans the checkpoint directory for ``checkpoint_<step>.pth`` files,
    restores the one with the highest step via ``load_checkpoint``, and
    returns the restored global step. Returns 0 when nothing can be resumed.
    """
    checkpoint_dir = "checkpoints"

    # Guard clause: no directory at all -> fresh start.
    if not os.path.exists(checkpoint_dir):
        if accelerator.is_main_process:
            print("Checkpoint directory not found. Starting from scratch.")
        return 0

    candidates = [name for name in os.listdir(checkpoint_dir)
                  if name.startswith("checkpoint_") and name.endswith(".pth")]

    # Guard clause: directory exists but holds no checkpoints.
    if not candidates:
        if accelerator.is_main_process:
            print("No checkpoints found. Starting from scratch.")
        return 0

    # The highest embedded step number is the most recent checkpoint.
    newest = max(candidates, key=lambda name: int(name.split("_")[1].split(".")[0]))
    checkpoint_path = os.path.join(checkpoint_dir, newest)
    if accelerator.is_main_process:
        print(f"Resuming from checkpoint: {checkpoint_path}")

    return load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator)
|
|
class AudioVAE:
    """Thin wrapper around MusicDCAE that whitens latents per channel."""

    def __init__(self, device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device

        # Per-channel statistics used to normalize the DCAE latent space,
        # shaped (1, C, 1, 1) for broadcasting over (B, C, H, W) latents.
        self.latent_mean = torch.tensor(
            [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
            device=device,
        ).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(
            [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
            device=device,
        ).view(1, -1, 1, 1)

    def encode(self, audio):
        """Encode an audio batch into whitened latents.

        Assumes the time axis is dim 2 of `audio` (inferred from the
        `audio.shape[2]` length computation) — confirm against MusicDCAE.
        """
        with torch.no_grad():
            batch_size = audio.shape[0]
            lengths = torch.tensor([audio.shape[2]] * batch_size).to(self.device)
            raw_latents, _ = self.model.encode(audio, lengths, sr=48000)
            # Whiten with the precomputed channel statistics.
            return (raw_latents - self.latent_mean) / self.latent_std

    def decode(self, latents):
        """Decode whitened latents back into a stacked audio batch tensor."""
        with torch.no_grad():
            # Undo the whitening before handing latents to the DCAE decoder.
            denormalized = latents * self.latent_std + self.latent_mean
            sr, audio_list = self.model.decode(denormalized, sr=48000)
            return torch.stack(audio_list).to(self.device)
|
|
class RF:
    """Rectified-flow training and sampling wrapper around a velocity model.

    The wrapped model is called as ``model(z_t, t, cond)`` and is trained to
    predict the clean sample x; velocities are derived from that prediction.
    """

    def __init__(self, model, time_sampling="sigmoid"):
        self.model = model
        # One of "sigmoid", "warped", "uniform".
        self.time_sampling = time_sampling

    def sample_timesteps(self, batch, device):
        """Sample a batch of training timesteps in [0, 1] per the configured strategy."""
        if self.time_sampling == "sigmoid":
            # Logit-normal sampling: concentrates t away from the endpoints.
            return torch.sigmoid(torch.randn((batch,), device=device))
        elif self.time_sampling == "warped":
            # Resolution-dependent warp: alpha grows with the latent token
            # count so larger latents see relatively more high-noise steps.
            pm = 128 * 16 * 16
            alpha = max(1.0, math.sqrt(pm / 4096.0))
            u = torch.rand(batch, device=device)
            return alpha * u / (1.0 + (alpha - 1.0) * u)
        elif self.time_sampling == "uniform":
            return torch.rand(batch, device=device)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def forward(self, x, cond):
        """Compute the rectified-flow training loss for clean latents `x`.

        Returns a scalar MSE between the target velocity and the velocity
        implied by the model's x-prediction.
        """
        b = x.size(0)

        t = self.sample_timesteps(b, x.device)

        # Broadcast t over all non-batch dims, then build the interpolant
        # z_t = (1 - t) * x + t * noise  (t=0 is data, t=1 is pure noise).
        texp = t.view([b, *([1] * len(x.shape[1:]))])
        z1 = torch.randn_like(x)
        zt = (1 - texp) * x + texp * z1

        x_pred = self.model(zt, t, cond)

        # The model predicts x; convert target and prediction to velocities.
        # The +0.05 keeps the division stable near t=0 and down-weights
        # low-noise timesteps.
        target = (zt - x) / (texp + 0.05)
        v_pred = (zt - x_pred) / (texp + 0.05)
        loss = F.mse_loss(target, v_pred)

        return loss

    def get_sampling_timesteps(self, steps, device):
        """Generate a decreasing schedule of `steps` timesteps starting at 1.0."""
        if self.time_sampling == "uniform" or self.time_sampling == "sigmoid":
            return torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
        elif self.time_sampling == "warped":
            # Apply the same warp used at train time so sampling visits the
            # timesteps the model was actually trained on.
            pm = 128 * 16 * 16
            alpha = max(1.0, math.sqrt(pm / 4096.0))
            u = torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
            return alpha * u / (1.0 + (alpha - 1.0) * u)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def sample(self, z, cond, null_cond=None, sample_steps=100, cfg=3.0):
        """Euler-integrate from noise `z` (t=1) down to data (t=0).

        Args:
            z: Initial noise tensor, batch first.
            cond: Conditioning passed to the model.
            null_cond: Unconditional conditioning; when given, classifier-free
                guidance with scale `cfg` is applied.
            sample_steps: Number of Euler steps.
            cfg: Guidance scale.

        Returns:
            The full trajectory: a list of ``sample_steps + 1`` tensors,
            starting with the input noise and ending with the final sample.
        """
        b = z.size(0)
        device = z.device
        # (Removed an unused `latent_shape` local that was computed here.)

        timesteps = self.get_sampling_timesteps(sample_steps, device)
        images = [z]

        for idx in range(sample_steps):
            t_curr = timesteps[idx]
            # The final step integrates all the way down to t=0.
            t_next = timesteps[idx + 1] if idx + 1 < sample_steps else torch.tensor(0.0, device=device)
            dt = t_curr - t_next
            t = t_curr.expand(b)

            # Convert the model's x-prediction into a velocity estimate.
            vc = self.model(z, t, cond)
            vc = (z - vc) / t_curr
            if null_cond is not None:
                # Classifier-free guidance: extrapolate away from the
                # unconditional velocity.
                vu = self.model(z, t, null_cond)
                vu = (z - vu) / t_curr
                vc = vu + cfg * (vc - vu)

            z = z - dt * vc
            images.append(z)
        return images
|
|
def save_audio_samples(audio_batch, sample_rate, filename):
    """Write the first sample of `audio_batch` to audio_samples/<filename>."""
    output_dir = "audio_samples"
    os.makedirs(output_dir, exist_ok=True)

    # Despite the plural name, only the first item of the batch is written.
    first_sample = audio_batch[0].cpu()

    filepath = os.path.join(output_dir, filename)
    torchaudio.save(filepath, first_sample, sample_rate)
    print(f"Saved audio sample: {filepath}")
|
|
def parse_args(argv=None):
    """Parse command-line options for the training script.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``
            (keeps the original zero-argument call sites working).

    Returns:
        argparse.Namespace with all training/model/logging options.
    """
    def _str2bool(value):
        # The previous `type=bool` treated ANY non-empty string — including
        # "False" — as True; parse common boolean spellings explicitly.
        if isinstance(value, bool):
            return value
        if value.lower() in ('true', '1', 'yes', 'y'):
            return True
        if value.lower() in ('false', '0', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"Invalid boolean value: {value!r}")

    parser = argparse.ArgumentParser(description='Audio training script with TensorBoard logging')

    # Model / data geometry.
    parser.add_argument('--channels', type=int, default=8, help='Number of input channels in the audio latents')
    parser.add_argument('--audio_height', type=int, default=16, help='Height of audio latents')
    parser.add_argument('--max_audio_width', type=int, default=4096, help='Max width of audio latents')
    parser.add_argument('--subsection_length', type=int, default=256, help='Length of random subsection to sample from each audio latent')
    parser.add_argument('--n_layers', type=int, default=36, help='Number of layers in the model')
    parser.add_argument('--n_encoder_layers', type=int, default=36, help='Number of encoder layers in the model')
    parser.add_argument('--n_heads', type=int, default=16, help='Number of heads in the model')
    parser.add_argument('--dim', type=int, default=768, help='Dimension of the encoder')
    parser.add_argument('--decoder_dim', type=int, default=1536, help='Dimension of the decoder (if None, uses --dim)')
    parser.add_argument('--dataset_name', type=str, default="cache", help='Audio dataset name')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of workers for dataloader')

    # Optimization.
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--warmup_steps', type=int, default=0, help='Number of warmup steps')

    # Logging / checkpointing / sampling cadence.
    parser.add_argument('--sample_every', type=int, default=500, help='Audio sampling interval (batches)')
    parser.add_argument('--save_every', type=int, default=1000, help='Model saving interval (batches)')
    parser.add_argument('--num_samples', type=int, default=16, help='Number of samples to generate')
    parser.add_argument('--resume', type=_str2bool, default=True, help='Resume training from checkpoint')
    parser.add_argument('--pad_to_length', action='store_true', help='Pad short samples to subsection_length instead of filtering them out')
    parser.add_argument('--time_sampling', type=str, default='warped', choices=['sigmoid', 'warped', 'uniform'], help='Timestep sampling strategy')

    return parser.parse_args(argv)
|
|
def main():
    """Train the rectified-flow audio model on pre-encoded DCAE latents."""
    args = parse_args()

    # bf16 autocast only makes sense on CUDA; full precision on CPU.
    accelerator = Accelerator(mixed_precision="bf16" if torch.cuda.is_available() else "no")

    is_main_process = accelerator.is_main_process

    # TensorBoard logging happens on the main process only.
    writer = None
    if is_main_process:
        run_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        writer = SummaryWriter(log_dir=f"runs/{run_datetime}")

    dataset = load_from_disk(args.dataset_name).with_format(type="torch")

    # Without padding, drop clips too short to crop a subsection from
    # (2x margin so random crops retain some variety).
    if not args.pad_to_length:
        def filter_by_length(example):
            latent_width = example['latents'].shape[-1]
            return latent_width >= args.subsection_length * 2

        dataset = dataset.filter(filter_by_length)

        if is_main_process:
            print(f"Dataset filtered to {len(dataset)} samples with width >= {args.subsection_length * 2}")
    else:
        if is_main_process:
            print(f"Padding enabled: short samples will be zero-padded to {args.subsection_length}")

    # Per-channel DCAE latent statistics (same values as AudioVAE) used to
    # whiten latents on the CPU side inside collate_fn.
    latent_mean = torch.tensor([0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]).view(1, -1, 1, 1)
    latent_std = torch.tensor([0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]).view(1, -1, 1, 1)

    num_classes = 2304
    tag_embedder = TagEmbedder(num_classes=num_classes)

    def collate_fn(batch):
        """Crop/pad each latent to subsection_length, whiten, and attach tags."""
        subsection_length = args.subsection_length
        # NOTE(review): hard-coded False — args.pad_to_length is ignored here,
        # so when --pad_to_length is set (and the length filter above was
        # skipped) a short clip leaves `sampled_latent` unassigned below.
        pad_to_length = False

        sampled_latents = []
        album_names = []
        song_names = []
        ids = []
        tags = []

        for item in batch:
            latent = item['latents']
            # Ensure a leading batch dim: (C, H, W) -> (1, C, H, W).
            if len(latent.shape) == 3:
                latent = latent.unsqueeze(0)

            _, _, _, width = latent.shape

            if width < subsection_length:
                if pad_to_length:
                    # Zero-pad the time axis up to subsection_length.
                    pad_amount = subsection_length - width
                    sampled_latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)

            else:
                # Random temporal crop of subsection_length frames.
                max_start = width - subsection_length
                start_idx = torch.randint(0, max_start + 1, (1,)).item()

                sampled_latent = latent[:, :, :, start_idx:start_idx + subsection_length]

            sampled_latents.append(sampled_latent.squeeze(0))
            album_name = item['album_name']
            song_name = item['song_name']
            album_names.append(album_name)
            song_names.append(song_name)

            sample_tags = tag_embedder.get_tags(album_name, song_name)
            tags.append(sample_tags)

        # Whiten with the dataset-level channel statistics.
        stacked_latents = torch.stack(sampled_latents)
        normalized_latents = (stacked_latents - latent_mean) / latent_std

        # NOTE(review): album_names/song_names/ids are collected but never
        # returned — code below that reads batch["album_names"] will KeyError.
        return {
            'latents': normalized_latents,
            'tags': tags
        }

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        persistent_workers=True,
        num_workers=args.num_workers if torch.cuda.is_available() else 0,
        pin_memory=True,
        collate_fn=collate_fn
    )

    channels = args.channels

    model = LocalSongModel(
        in_channels=channels,
        num_groups=args.n_heads,
        hidden_size=args.dim,
        decoder_hidden_size=args.decoder_dim,
        num_blocks=args.n_layers,
        patch_size=(16, 1),
        num_classes=num_classes,
        max_tags=8,
    )

    vae = AudioVAE(accelerator.device)

    rf = RF(model, time_sampling=args.time_sampling)

    optimizer = timm.optim.Muon(model.parameters(),lr=args.lr)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.epochs * len(dataloader))

    # Resume (if requested) before accelerator.prepare wraps everything.
    global_step = 0
    if args.resume:
        global_step = resume(model, optimizer, scheduler, accelerator)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        model.forward_emb = torch.compile(model.forward_emb)

    model, optimizer, scheduler, dataloader = accelerator.prepare(
        model, optimizer, scheduler, dataloader
    )

    # Point the RF wrapper at the prepared (possibly wrapped) model.
    rf.model = model

    if is_main_process:
        model_size = sum(p.numel() for p in accelerator.unwrap_model(model).parameters() if p.requires_grad)
        print(f"Number of parameters: {model_size}, {model_size / 1e6}M")

    os.makedirs("audio_samples", exist_ok=True)
    num_samples = args.num_samples

    # Fixed batch/noise so periodic samples are comparable across steps.
    fixed_batch = None
    fixed_latents = None
    fixed_labels = None
    fixed_noise = None

    if is_main_process:
        data_iter = iter(dataloader)
        fixed_batch = next(data_iter)
        fixed_latents = fixed_batch["latents"][:num_samples]

        # NOTE(review): collate_fn does not return an "album_names" key —
        # this line will raise KeyError at runtime.
        print("Fixed ids:", fixed_batch["album_names"])

        # NOTE(review): fixed_tags is left empty, so the label printout below
        # is a no-op and sampling later runs with an empty condition list;
        # presumably this was meant to be fixed_batch["tags"][:num_samples].
        fixed_tags = []

        # Reverse mapping: tag index -> tag string, for readable logging.
        idx_to_tag = {v: k for k, v in tag_embedder.tag_mapping.items()}

        print("Fixed tag labels:")
        for i, tag_list in enumerate(fixed_tags):
            labels = [idx_to_tag.get(idx, f"<unknown:{idx}>") for idx in tag_list]
            print(f"  Sample {i}: {labels}")

        # Noise matches the latent layout but always uses the configured
        # subsection length for the time axis.
        B, C, H, W = fixed_latents.shape
        fixed_noise = torch.randn(num_samples, C, H, args.subsection_length, device=accelerator.device)

        fixed_latents = fixed_latents.to(accelerator.device)

    if is_main_process:
        print("Begin training")

    # Smoothed loss over the last 100 batches for display/logging.
    mse_loss_window = deque(maxlen=100)
    start_epoch = 0
    for epoch in range(start_epoch, args.epochs):

        pbar = tqdm(dataloader) if is_main_process else dataloader
        for batch in pbar:
            x = batch["latents"]

            tags = batch["tags"]

            # Classifier-free guidance training: drop the condition for ~10%
            # of samples so the model also learns an unconditional estimate.
            dropout_tags = []
            for tag_list in tags:
                if torch.rand(1).item() < 0.1:

                    dropout_tags.append([])
                else:
                    dropout_tags.append(tag_list)

            c = dropout_tags

            with accelerator.accumulate(model):
                optimizer.zero_grad()
                mse_loss = rf.forward(x, c)

                loss = mse_loss

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                # NOTE(review): stepping the scheduler inside accumulate()
                # advances it every micro-batch, not once per optimizer step.
                scheduler.step()

            if is_main_process:

                mse_loss_window.append(mse_loss.item())

                avg_mse_loss = sum(mse_loss_window) / len(mse_loss_window)

                if isinstance(pbar, tqdm):
                    pbar.set_postfix({"mse_loss": avg_mse_loss, "lr": optimizer.param_groups[0]['lr']})

                if writer is not None:
                    writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('MSE_Loss', avg_mse_loss, global_step)

            global_step += 1

            if is_main_process and global_step % args.save_every == 0:
                save(model, optimizer, scheduler, global_step, accelerator)

            if is_main_process and global_step % args.sample_every == 0:
                model.eval()

                with torch.no_grad():
                    # NOTE(review): fixed_tags is the empty list built above,
                    # so cond and null_cond are both effectively empty here.
                    cond = fixed_tags

                    null_cond = [[] for _ in range(len(cond))]

                    # Keep only the final state of the sampling trajectory.
                    sampled_latents = rf.sample(fixed_noise, cond, null_cond)[-1]

                    try:
                        sampled_audio = vae.decode(sampled_latents)

                        for i in range(min(8, sampled_audio.shape[0])):
                            save_audio_samples(
                                sampled_audio[i:i+1],
                                48000,
                                f"sample_{global_step}_generated_{i}.wav"
                            )

                        # Decode the ground-truth latents once, on the first
                        # sampling pass, for reference listening.
                        if global_step == args.sample_every:
                            original_audio = vae.decode(fixed_latents)
                            for i in range(min(8, original_audio.shape[0])):
                                save_audio_samples(
                                    original_audio[i:i+1],
                                    48000,
                                    f"sample_{global_step}_original_{i}.wav"
                                )

                    except Exception as e:
                        print(f"Error during audio generation: {e}")

                model.train()

    print("Saving final model")
    save(model, optimizer, scheduler, global_step, accelerator)

    if writer is not None:
        writer.close()
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|
|