| | import os |
| | import re |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torchaudio |
| | import numpy as np |
| | import pytorch_lightning as pl |
| | import random |
| | import librosa |
| | from os.path import basename, exists, join |
| | from torch.utils.data import Dataset, DataLoader |
| | import hydra |
| | import utils |
| | import torchaudio |
| | from transformers import AutoFeatureExtractor |
| | from torchaudio.transforms import Resample |
| | from tqdm import tqdm |
| |
|
| | class DataModule(pl.LightningDataModule): |
| | def __init__(self, cfg): |
| | super().__init__() |
| | self.cfg = cfg |
| | |
| | ocwd = hydra.utils.get_original_cwd() |
| | self.ocwd = ocwd |
| |
|
| | def get_loader(self, phase): |
| | phase_cfg = self.cfg.dataset.get(phase) |
| | batch_size = phase_cfg.batch_size |
| | ds = FSDataset(phase, self.cfg) |
| | |
| |
|
| | dl = DataLoader(ds, |
| | batch_size=batch_size, |
| | shuffle=phase_cfg.shuffle, |
| | num_workers=8, |
| | collate_fn=ds.collate_fn, |
| | pin_memory=True, |
| | persistent_workers=False) |
| |
|
| | return dl |
| |
|
| | def train_dataloader(self): |
| | return self.get_loader('train') |
| |
|
| | def val_dataloader(self): |
| | return self.get_loader('val') |
| |
|
| | def test_dataloader(self): |
| | pass |
| |
|
| | class FSDataset(Dataset): |
| | """Dataset batching wav, mel |
| | and other acoustic features |
| | |
| | Args: |
| | phase: train, val, test |
| | cfg: hydra config |
| | """ |
| | def __init__(self, phase, cfg): |
| | self.phase = phase |
| | self.cfg = cfg |
| | self.phase_cfg = cfg.dataset.get(phase) |
| | self.ocwd = hydra.utils.get_original_cwd() |
| | |
| | self.sr = cfg.preprocess.audio.sr |
| | |
| | |
| | self.filelist = self.get_filelist(self.phase_cfg.filelist) |
| | self.min_audio_length = cfg.dataset.min_audio_length |
| | self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0") |
| | self.resample_to_16k = Resample(24000, 16000) |
| | |
| | def __len__(self): |
| | return len(self.filelist) |
| |
|
| | def load_wav(self, path): |
| | wav, sr = librosa.load(path, sr=self.sr) |
| | return wav |
| |
|
| | def get_filelist(self, fpath): |
| | with open(fpath, 'r') as f: |
| | |
| | flist = [l.strip().split('\t')[0] for l in f if l.strip()] |
| | return flist |
| |
|
| |
|
| | def __getitem__(self, idx): |
| | wavpath = self.filelist[idx] |
| | |
| | try: |
| | wav, sr = torchaudio.load(wavpath) |
| | except Exception as e: |
| | print(f"Error loading {wavpath}: {e}") |
| | wav = torch.zeros((1, self.min_audio_length)) |
| | sr = self.sr |
| | |
| | if sr != 24000: |
| | wav = Resample(sr, 24000)(wav) |
| | |
| | wav = wav[0,:] |
| | length = wav.shape[0] |
| | |
| | if length < self.min_audio_length: |
| | wav = F.pad(wav, (0, self.min_audio_length - length)) |
| | length = wav.shape[0] |
| | |
| | i = random.randint(0, length - self.min_audio_length) |
| | wav = wav[i:i + self.min_audio_length] |
| | |
| | |
| | wav_16k = self.resample_to_16k(wav) |
| | wav_16k_pad = F.pad(wav_16k, (160, 160)) |
| | |
| | feat = self.feature_extractor(wav_16k_pad, sampling_rate=16000, return_tensors="pt").data['input_features'].squeeze(0) |
| | |
| | out = { |
| | 'wav': wav, |
| | 'feat': feat, |
| | } |
| | |
| | return out |
| | |
| | def collate_fn(self, bs): |
| | wavs = [b['wav'] for b in bs] |
| | wavs = torch.stack(wavs) |
| | feats = [b['feat'] for b in bs] |
| | feats = torch.stack(feats) |
| | out = { |
| | 'wav': wavs, |
| | 'feats': feats, |
| | |
| | } |
| | return out |
| |
|
| | @hydra.main(config_path='config', config_name='default', version_base=None) |
| | def main(cfg): |
| | data_module = DataModule(cfg) |
| | train_loader = data_module.val_dataloader() |
| | |
| | valid_filelist = [] |
| | |
| | for batch_idx, batch in enumerate(tqdm(train_loader, desc="Processing batches", unit="batch")): |
| | wavs = batch['wav'] |
| |
|
| | if __name__ == "__main__": |
| | main() |