| import os |
| import torchaudio |
| import torch |
| import numpy as np |
| import soundfile |
class AudioLoader:
    """Load audio files as mono 1-D tensors at a fixed target sample rate."""

    def __init__(self, sample_rate=16000):
        # Target sample rate (Hz) every loaded file is resampled to.
        self.sample_rate = sample_rate

    def load_audio(self, file_path):
        """Read *file_path*, downmix it to mono and resample it to ``self.sample_rate``.

        Args:
            file_path: Path to an audio file, read with torchaudio's
                ``soundfile`` backend.

        Returns:
            A 1-D tensor of samples at ``self.sample_rate``.
        """
        audio, sample_rate = torchaudio.load(file_path, backend='soundfile')
        # Downmix before resampling so only a single channel is resampled and
        # the caller always receives a 1-D signal.
        audio = self._to_mono(audio)
        if sample_rate != self.sample_rate:
            audio = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )(audio)
        return audio

    @staticmethod
    def _to_mono(audio):
        """Collapse a ``(channels, frames)`` tensor into a 1-D ``(frames,)`` tensor.

        A single channel is squeezed unchanged; several channels are averaged.
        Without averaging, multi-channel input would keep its channel axis and
        shift every later axis of the resulting spectrogram by one.
        """
        if audio.shape[0] == 1:
            return audio.squeeze(0)
        return audio.mean(dim=0)
|
|
class STFT:
    """Short-time Fourier transform using a Hamming window."""

    def __init__(self, n_fft=1024, hop_length=512, win_length=1024):
        # FFT size, hop between successive frames and window length, in samples.
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def compute_stft(self, signal):
        """Return the complex STFT of *signal* (``return_complex=True``).

        Args:
            signal: Tensor of samples, time on the last axis.

        Returns:
            A complex tensor produced by ``torch.stft`` with a Hamming window of
            ``self.win_length`` samples.
        """
        # torch.stft requires the window on the same device as the input, so
        # build it there (e.g. CUDA); match precision for real floating input.
        window_dtype = signal.dtype if signal.is_floating_point() else None
        window = torch.hamming_window(
            self.win_length, dtype=window_dtype, device=signal.device
        )
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=True,
        )
|
|
class SpectrogramSaver:
    """Persist spectrogram tensors to disk via ``torch.save``."""

    @staticmethod
    def save_spectrogram(spectrogram, save_path):
        """Serialize *spectrogram* into the file at *save_path*.

        Args:
            spectrogram: Object to store (a tensor in this pipeline).
            save_path: Destination file path.
        """
        torch.save(obj=spectrogram, f=save_path)
|
|
class Preprocessing:
    """Convert a directory of ``.wav`` files into fixed-length spectrogram tensors.

    Each file is loaded, transformed with an STFT, split into real/imaginary
    parts, padded or truncated along axis 1 to ``fixed_length`` frames and
    saved as a ``.pt`` file under ``<save_dir>/noisy``.
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512, win_length=1024):
        self.loader = AudioLoader(sample_rate)
        self.stft = STFT(n_fft, hop_length, win_length)
        self.saver = SpectrogramSaver()
        # Target size of axis 1 for every saved spectrogram; filled lazily by
        # determine_fixed_length() when create_dataset() finds it unset.
        self.fixed_length = None

    @staticmethod
    def _list_wav_files(directory):
        """Return paths of the ``.wav`` entries directly inside *directory*, sorted by name."""
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith('.wav')
        ]

    @staticmethod
    def _output_path(noisy_file, noisy_save_dir):
        """Return the ``noisy_<stem>.pt`` path under *noisy_save_dir* for *noisy_file*.

        Only the final extension is removed, so ``a.1.wav`` and ``a.2.wav``
        get distinct outputs instead of both overwriting ``noisy_a.pt``.
        """
        stem = os.path.splitext(os.path.basename(noisy_file))[0]
        return os.path.join(noisy_save_dir, f"noisy_{stem}.pt")

    def preprocess(self, signal):
        """Return the STFT of *signal* as a real tensor with a trailing (real, imag) axis of size 2."""
        spectrogram = self.stft.compute_stft(signal)
        return torch.stack((spectrogram.real, spectrogram.imag), dim=-1)

    def determine_fixed_length(self, noisy_dir):
        """Set ``self.fixed_length`` to the median axis-1 size of all spectrograms in *noisy_dir*.

        Raises:
            ValueError: If *noisy_dir* contains no ``.wav`` files.
        """
        noisy_files = self._list_wav_files(noisy_dir)
        if not noisy_files:
            # np.median([]) is NaN and int(NaN) fails obscurely; say why instead.
            raise ValueError(
                f"No .wav files found in {noisy_dir!r}; cannot determine fixed length"
            )

        # Axis 1 is the frame axis when the loaded signal is 1-D
        # (spectrogram shape (freq, frames, 2)).
        lengths = [
            self.preprocess(self.loader.load_audio(noisy_file)).shape[1]
            for noisy_file in noisy_files
        ]

        self.fixed_length = int(np.median(lengths))
        print(f"Determined fixed length: {self.fixed_length}")

    def create_dataset(self, noisy_dir, save_dir):
        """Write one fixed-length spectrogram ``.pt`` per ``.wav`` in *noisy_dir* to ``save_dir/noisy``."""
        if self.fixed_length is None:
            self.determine_fixed_length(noisy_dir)

        noisy_save_dir = os.path.join(save_dir, 'noisy')
        os.makedirs(noisy_save_dir, exist_ok=True)

        for noisy_file in self._list_wav_files(noisy_dir):
            noisy_audio = self.loader.load_audio(noisy_file)
            noisy_spectrogram = self.pad_spectrogram(self.preprocess(noisy_audio))
            self.saver.save_spectrogram(
                noisy_spectrogram, self._output_path(noisy_file, noisy_save_dir)
            )

    def pad_spectrogram(self, spectrogram):
        """Zero-pad or truncate *spectrogram* along axis 1 to ``self.fixed_length``.

        Padding keeps the input's dtype and device. Truncated results are
        copied so that saving them does not also store the discarded frames.
        """
        pad_length = self.fixed_length - spectrogram.shape[1]
        if pad_length > 0:
            pad = spectrogram.new_zeros(
                (spectrogram.shape[0], pad_length, *spectrogram.shape[2:])
            )
            spectrogram = torch.cat((spectrogram, pad), dim=1)
        elif pad_length < 0:
            # A slice is a view of the full tensor; torch.save would write the
            # entire underlying storage, so materialize only the kept frames.
            spectrogram = spectrogram[:, :self.fixed_length].clone()
        return spectrogram
|
|
| |
| |
| |
| |
|
|
| |
| |
|
|