|
|
from __future__ import absolute_import |
|
|
from __future__ import division |
|
|
from __future__ import print_function |
|
|
from __future__ import unicode_literals |
|
|
|
|
|
|
|
|
import torch.utils.data as data |
|
|
|
|
|
import soundfile as sf |
|
|
import PIL |
|
|
import os |
|
|
import os.path |
|
|
import pickle |
|
|
import random |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from scipy import signal |
|
|
|
|
|
from miscc.config import cfg |
|
|
|
|
|
|
|
|
class TextDataset(data.Dataset):
    """Dataset pairing room-impulse-response (RIR) audio with embeddings.

    On-disk layout read by this class::

        <data_dir>/<split>/filenames.pickle   # indexable sequence of keys
        <data_dir>/<split>/embeddings.pickle  # object indexable by key
        <data_dir>/RIR/<key>.wav              # one audio file per key

    Each item is an ``(RIR, embedding)`` tuple of ``float32`` NumPy arrays.

    NOTE(review): both pickles are loaded with ``pickle.load``, which can
    execute arbitrary code. Only point this dataset at trusted directories.
    """

    def __init__(self, data_dir, split='train', rirsize=4096):
        """Load the filename list and the embedding table for ``split``.

        Args:
            data_dir: Root directory holding the split sub-directories and
                the ``RIR`` directory.
            split: Name of the sub-directory that contains the pickles.
            rirsize: Stored as ``self.rirsize``. NOTE(review): it does not
                currently influence how RIRs are cropped; ``get_RIR`` uses
                its own ``crop_length`` argument.
        """
        self.rirsize = rirsize
        # ``data`` and ``bbox`` are not read anywhere in this class; they are
        # kept in case external code accesses them.
        self.data = []
        self.data_dir = data_dir
        self.bbox = None

        split_dir = os.path.join(data_dir, split)

        self.filenames = self.load_filenames(split_dir)
        self.embeddings = self.load_embedding(split_dir)

    @staticmethod
    def _crop_or_pad(wav, crop_length):
        """Return exactly ``crop_length`` frames of ``wav`` along axis 0.

        Longer inputs are truncated to their first ``crop_length`` frames.
        Shorter inputs are zero-padded at the end. Any trailing axes (for
        example a channel axis of a ``(frames, channels)`` array) are left
        untouched, unlike a flat ``size``-based check.
        """
        wav = np.asarray(wav)
        num_frames = wav.shape[0]
        if num_frames >= crop_length:
            return wav[:crop_length]
        # Pad only the frame axis; every other axis gets (0, 0) padding.
        pad_width = [(0, crop_length - num_frames)] + [(0, 0)] * (wav.ndim - 1)
        return np.pad(wav, pad_width, mode='constant')

    def get_RIR(self, RIR_path, crop_length=4096):
        """Read an RIR audio file and return it as a fixed-length batch of one.

        Args:
            RIR_path: Path to an audio file readable by ``soundfile.read``.
            crop_length: Number of frames to keep. Shorter files are
                zero-padded at the end. Defaults to 4096, the value the
                previous implementation hard-coded.

        Returns:
            ``float32`` array of shape ``(1, crop_length)`` for a 1-D (mono)
            signal. A 2-D ``(frames, channels)`` signal yields
            ``(1, crop_length, channels)``.
        """
        wav, _fs = sf.read(RIR_path)
        # The former ``rirsize == 16384`` branch did the same thing in both
        # arms (no resampling), so it has been removed.
        rir = self._crop_or_pad(wav, int(crop_length))
        return np.array([rir]).astype('float32')

    def load_embedding(self, data_dir):
        """Unpickle and return ``<data_dir>/embeddings.pickle``.

        The result is indexed by the keys from ``filenames.pickle`` in
        ``__getitem__``.
        """
        # os.path.join keeps this consistent with ``load_filenames``.
        filepath = os.path.join(data_dir, 'embeddings.pickle')
        with open(filepath, 'rb') as f:
            # NOTE(review): pickle.load on on-disk data; trusted input only.
            embeddings = pickle.load(f)
        return embeddings

    def load_filenames(self, data_dir):
        """Unpickle and return ``<data_dir>/filenames.pickle`` (the item keys)."""
        filepath = os.path.join(data_dir, 'filenames.pickle')
        with open(filepath, 'rb') as f:
            # NOTE(review): pickle.load on on-disk data; trusted input only.
            filenames = pickle.load(f)
        print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
        return filenames

    def __getitem__(self, index):
        """Return ``(RIR, embedding)`` for the key at position ``index``."""
        key = self.filenames[index]
        # Look up the embedding first so a missing key fails before any I/O.
        embeddings = self.embeddings[key]
        RIR_name = '%s/RIR/%s.wav' % (self.data_dir, key)
        RIR = self.get_RIR(RIR_name)
        embedding = np.array(embeddings).astype('float32')
        return RIR, embedding

    def __len__(self):
        """Return the number of keys loaded from ``filenames.pickle``."""
        return len(self.filenames)
|
|
|