|
|
from __future__ import print_function |
|
|
from six.moves import range |
|
|
from PIL import Image |
|
|
|
|
|
import torch.backends.cudnn as cudnn |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.autograd import Variable |
|
|
import torch.optim as optim |
|
|
import os |
|
|
import time |
|
|
|
|
|
import numpy as np |
|
|
import torchfile |
|
|
import pickle |
|
|
|
|
|
import soundfile as sf |
|
|
import re |
|
|
import math |
|
|
from wavefile import WaveWriter, Format |
|
|
|
|
|
from miscc.config import cfg |
|
|
from miscc.utils import mkdir_p |
|
|
from miscc.utils import weights_init |
|
|
from miscc.utils import save_RIR_results, save_model |
|
|
from miscc.utils import KL_loss |
|
|
from miscc.utils import compute_discriminator_loss, compute_generator_loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GANTrainer(object):
    """Two-stage conditional GAN trainer for generating RIRs (presumably
    room impulse responses, given the 16 kHz .wav output — TODO confirm).

    Handles network construction/loading (stage I or stage II), the
    alternating D/G optimization loop, periodic snapshotting, and offline
    sampling of generated waveforms from pickled embeddings.
    """

    def __init__(self, output_dir):
        # Output folders are only created in training mode; in sampling
        # mode (cfg.TRAIN.FLAG false) none of the model_dir/RIR_dir
        # attributes exist, so train() must not be called then.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            # Separate directory for the best-RT checkpoints saved by train().
            self.model_dir_RT = os.path.join(output_dir, 'Model_RT')
            self.RIR_dir = os.path.join(output_dir, 'RIR')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.model_dir_RT)
            mkdir_p(self.RIR_dir)
            mkdir_p(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # cfg.GPU_ID is a comma-separated string, e.g. "0,1".
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # Effective batch is the per-GPU batch scaled by GPU count
        # (data_parallel splits it back across self.gpus).
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    def load_network_stageI(self):
        """Build stage-I generator/discriminator, optionally restoring
        weights from cfg.NET_G / cfg.NET_D; returns (netG, netD)."""
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            # map_location keeps tensors on CPU regardless of where they
            # were saved; .cuda() below moves them if needed.
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def load_network_stageII(self):
        """Build stage-II generator (wrapping a stage-I generator) and
        discriminator; returns (netG, netD).

        Weight restore priority: full stage-II weights (cfg.NET_G),
        else stage-I-only weights (cfg.STAGE1_G), else bail out with a
        message (returns None implicitly — caller gets no networks).
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            # Only the embedded stage-I sub-network is restored here;
            # the stage-II layers keep their fresh weights_init values.
            state_dict = \
                torch.load(cfg.STAGE1_G,
                           map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', cfg.STAGE1_G)
        else:
            print("Please give the Stage1_G path")
            # NOTE(review): returns None instead of raising — callers that
            # unpack `netG, netD` will fail with a TypeError downstream.
            return

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Run the full adversarial training loop.

        data_loader yields (real_RIR, text_embedding) batches. Each batch
        does one discriminator step followed by three generator steps
        (one plus the extra `for p in range(2)` passes). Checkpoints go
        to self.model_dir every snapshot_interval epochs and to
        self.model_dir_RT whenever RT_error improves.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        batch_size = self.batch_size

        # Fixed label tensors reused every batch (1 = real, 0 = fake).
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        optimizerD = \
            optim.RMSprop(netD.parameters(),
                          lr=cfg.TRAIN.DISCRIMINATOR_LR)
        # Only optimize trainable generator params (stage II freezes the
        # embedded stage-I weights — presumably; TODO confirm in model.py).
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.RMSprop(netG_para,
                                   lr=cfg.TRAIN.GENERATOR_LR)
        count = 0
        # Best RT error seen so far; 10 is the initial "worse than
        # anything" sentinel.
        least_RT = 10
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Decay both learning rates by 0.7 every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.7
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.7
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                real_RIR_cpu, txt_embedding = data
                real_RIRs = Variable(real_RIR_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_RIRs = real_RIRs.cuda()
                    txt_embedding = txt_embedding.cuda()

                # Generate fake RIRs; generator returns
                # (low-res output, final output, condition code).
                inputs = (txt_embedding)
                _, fake_RIRs, c_code = nn.parallel.data_parallel(netG, inputs, self.gpus)

                # ---- discriminator update ----
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_RIRs, fake_RIRs,
                                               real_labels, fake_labels,
                                               c_code, self.gpus)
                # Loss scaled by 5 — presumably an empirical weighting;
                # TODO confirm rationale.
                errD_total = errD * 5
                errD_total.backward()
                optimizerD.step()

                # ---- generator update (first pass) ----
                netG.zero_grad()
                errG, MSE_error, RT_error = compute_generator_loss(epoch, netD, real_RIRs, fake_RIRs,
                                                                   real_labels, c_code, self.gpus)
                errG_total = errG * 5
                errG_total.backward()
                optimizerG.step()
                # Two extra generator passes per batch: the generator is
                # re-run (fresh fake_RIRs) and stepped again, so G is
                # trained 3x per D step.
                for p in range(2):
                    inputs = (txt_embedding)
                    _, fake_RIRs, c_code = nn.parallel.data_parallel(netG, inputs, self.gpus)
                    netG.zero_grad()
                    errG, MSE_error, RT_error = compute_generator_loss(epoch, netD, real_RIRs, fake_RIRs,
                                                                       real_labels, c_code, self.gpus)
                    errG_total = errG * 5
                    errG_total.backward()
                    optimizerG.step()

                count = count + 1
                # Every 100 batches: regenerate from the current batch's
                # embeddings and (on snapshot epochs) dump sample RIRs.
                if i % 100 == 0:
                    inputs = (txt_embedding)
                    lr_fake, fake, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    if (epoch % self.snapshot_interval == 0):
                        save_RIR_results(real_RIR_cpu, fake, epoch, self.RIR_dir)
                        # Stage II also exposes the low-res stage-I output.
                        if lr_fake is not None:
                            save_RIR_results(None, lr_fake, epoch, self.RIR_dir)
            end_t = time.time()

            # Epoch summary: last batch's losses only, not epoch averages.
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f MSE_ERROR %.4f RT_error %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.data, errG.data,
                     errD_real, errD_wrong, errD_fake, MSE_error*4096, RT_error, (end_t - start_t)))

            # Append the same summary line to a cumulative log file in the
            # current working directory.
            store_to_file = "[{}/{}][{}/{}] Loss_D: {:.4f} Loss_G: {:.4f} Loss_real: {:.4f} Loss_wrong:{:.4f} Loss_fake {:.4f} MSE Error:{:.4f} RT_error{:.4f} Total Time: {:.2f}sec".format(epoch, self.max_epoch, i, len(data_loader),
                errD.data, errG.data, errD_real, errD_wrong, errD_fake, MSE_error*4096, RT_error, (end_t - start_t))
            store_to_file = store_to_file + "\n"
            with open("errors.txt", "a") as myfile:
                myfile.write(store_to_file)

            # Track the best RT error and snapshot that model separately.
            # NOTE(review): RT_error comes from compute_generator_loss and
            # may be a tensor; comparing against the float `least_RT`
            # works but keeps a graph reference — consider .item().
            if (RT_error < least_RT):
                least_RT = RT_error
                save_model(netG, netD, epoch, self.model_dir_RT)
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)

        # Final checkpoint after the last epoch.
        save_model(netG, netD, self.max_epoch, self.model_dir)

    def sample(self, file_path, stage=1):
        """Generate RIR .wav files from a pickled list of embeddings.

        file_path: pickle file containing an indexable sequence of
        embeddings (one per RIR to synthesize). Output is written as
        16 kHz mono wav files named "RIR-<index>.wav" under
        ./Generated_RIRs.
        """
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Per-batch generation wall-clock times (collected but not
        # reported here).
        time_list = []

        embedding_path = file_path
        # NOTE(review): pickle.load on an externally supplied path is
        # unsafe on untrusted input.
        with open(embedding_path, 'rb') as f:
            embeddings_pickle = pickle.load(f)

        embeddings_list = []
        num_embeddings = len(embeddings_pickle)
        for b in range(num_embeddings):
            embeddings_list.append(embeddings_pickle[b])

        embeddings = np.array(embeddings_list)

        save_dir_GAN = "Generated_RIRs"
        mkdir_p(save_dir_GAN)

        normalize_embedding = []

        # Never ask the network for more samples than we have embeddings.
        batch_size = np.minimum(num_embeddings, self.batch_size)

        count = 0
        count_this = 0
        while count < num_embeddings:
            iend = count + batch_size
            if iend > num_embeddings:
                # Last partial batch: slide the window back so a full
                # batch is fed (re-generating some earlier items).
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]

            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            start_t = time.time()
            inputs = (txt_embedding)
            _, fake_RIRs, c_code = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            end_t = time.time()
            diff_t = end_t - start_t
            time_list.append(diff_t)

            RIR_batch_size = batch_size
            print("batch_size ", RIR_batch_size)
            # NOTE(review): hard-coded 64 — assumes the generated batch
            # holds exactly 64 RIRs; must match self.batch_size or the
            # loop over-indexes / skips items. TODO confirm.
            channel_size = 64

            for i in range(channel_size):
                fs = 16000
                wave_name = "RIR-" + str(count + i) + ".wav"
                save_name_GAN = '%s/%s' % (save_dir_GAN, wave_name)
                print("wave : ", save_name_GAN)
                res = {}
                res_buffer = []
                rate = 16000
                res['rate'] = rate

                # Take channel 0 of the i-th generated RIR as a 1-D array.
                wave_GAN = fake_RIRs[i].data.cpu().numpy()
                wave_GAN = np.array(wave_GAN[0])

                res_buffer.append(wave_GAN)
                # Pad all buffered waves to the longest length (buffer
                # holds a single wave here, so this is effectively 1 row).
                res['samples'] = np.zeros((len(res_buffer), np.max([len(ps) for ps in res_buffer])))
                # NOTE(review): this `i` shadows the outer channel loop
                # variable; harmless (the outer `for` rebinds from its
                # iterator) but confusing — rename on next touch.
                for i, c in enumerate(res_buffer):
                    res['samples'][i, :len(c)] = c

                # NOTE(review): WaveWriter is never closed/used as a
                # context manager — relies on GC to flush the file.
                w = WaveWriter(save_name_GAN, channels=np.shape(res['samples'])[0], samplerate=int(res['rate']))
                w.write(np.array(res['samples']))

            print("counter = ", count)
            # NOTE(review): advances by the hard-coded 64, not by
            # batch_size — same assumption as channel_size above.
            count = count + 64
            count_this = count_this + 1