class Discriminator(nn.Module):
    """Small fully-convolutional discriminator producing a score map.

    Three 3x3, stride-1, padding-1 convolutions (16, 32 and 64 channels), each
    followed by ``LeakyReLU``, and a final 3x3 convolution down to one channel.
    Every layer preserves the spatial size, so the output has shape
    ``(batch, 1, H, W)`` for an input of shape ``(batch, in_channels, H, W)``.

    Args:
        in_channels: Number of channels of the input tensor. Defaults to 1,
            the value that was previously hard-coded.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Attribute names and layer order are kept so existing state dicts
        # ("discriminator.N.*", "out.*") still load.
        self.discriminator = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )
        self.out = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return the raw score map for ``x``.

        Only the score tensor is returned (no intermediate feature list).
        No sigmoid is applied because the scores are meant for a
        least-squares GAN objective (https://arxiv.org/abs/1611.04076).
        """
        return self.out(self.discriminator(x))


class MultiScaleDiscriminator(nn.Module):
    """Three independent ``Discriminator`` networks, one per band of axis 2.

    ``forward`` crops a window of ``window`` positions along the last axis and
    hands each sub-discriminator its own slice of axis 2 (``bands``).
    """

    # (lower, upper) bounds on axis 2 for disc1, disc2 and disc3 respectively.
    bands = ((0, 20), (20, 40), (40, 80))
    # Number of positions taken along the last axis, beginning at ``start``.
    window = 40

    def __init__(self):
        super().__init__()
        self.disc1 = Discriminator()
        self.disc2 = Discriminator()
        self.disc3 = Discriminator()

        self.apply(weights_init)

    def forward(self, x, start):
        """Score one window of ``x`` with every sub-discriminator.

        Args:
            x: 4-D tensor; axis 2 is split into ``bands`` and the last axis is
                cropped to ``[start, start + window)``.
            start: First index of the window on the last axis. Plain slicing
                is used, so a window running past the end is silently
                truncated; callers should pick ``start`` accordingly.

        Returns:
            List with the three score maps, in ``disc1``..``disc3`` order.
        """
        stop = start + self.window
        discriminators = (self.disc1, self.disc2, self.disc3)
        return [
            disc(x[:, :, low:high, start:stop])
            for disc, (low, high) in zip(discriminators, self.bands)
        ]
def create_dataloader(hp, train):
    """Build the ``DataLoader`` for the training or the validation split.

    Args:
        hp: Hyper-parameter object; ``hp.train.batch_size`` is read here and
            ``hp.data`` / ``hp.model`` are read by ``MelFromDisk``.
        train: ``True`` for the shuffled training loader (batches of
            ``hp.train.batch_size``, incomplete last batch dropped),
            ``False`` for the ordered validation loader with batch size 1.

    Returns:
        A ``DataLoader`` whose batches are ``[mel_gta, mel_gt]`` pairs.
    """
    dataset = MelFromDisk(hp, train)

    if train:
        return DataLoader(dataset=dataset, batch_size=hp.train.batch_size, shuffle=True,
                          num_workers=0, pin_memory=True, drop_last=True)
    return DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                      num_workers=0, pin_memory=False, drop_last=False)


class MelFromDisk(Dataset):
    """Paired ``.npy`` arrays: a model input and its same-named target.

    Inputs are discovered recursively under ``hp.data.input_dir`` (training)
    or ``hp.data.valid_input_dir`` (validation). The target for an input
    ``<name>.npy`` is ``hp.data.output_dir/<name>.npy``.
    """

    def __init__(self, hp, train):
        self.hp = hp
        self.train = train
        self.path = hp.data.input_dir if train else hp.data.valid_input_dir
        # Sorted so the item order does not depend on filesystem enumeration.
        self.wav_list = sorted(
            glob.glob(os.path.join(self.path, '**', '*.npy'), recursive=True)
        )
        # NOTE(review): `idim` is not among the `model` keys visible in
        # configs/default.yaml in this change -- confirm it is defined.
        self.mel_segment_length = hp.model.idim
        self.mapping = list(range(len(self.wav_list)))

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        """Return a random, aligned ``(mel_gta, mel_gt)`` crop for item ``idx``.

        Both arrays are cropped along axis 1 to ``hp.model.idim`` columns
        using the same offset (assumes arrays are [num_mel, T], as the
        original inline comment states -- TODO confirm).

        Raises:
            ValueError: If either array has fewer than ``hp.model.idim``
                columns, so no full-length crop exists.
        """
        # Load the file that was actually globbed: this is the right file for
        # the validation split and for inputs found in sub-directories.
        input_mel_path = self.wav_list[idx]
        utt_id = os.path.splitext(os.path.basename(input_mel_path))[0]
        output_mel_path = os.path.join(self.hp.data.output_dir, utt_id + ".npy")

        mel_gt = torch.from_numpy(np.load(output_mel_path))
        mel_gta = torch.from_numpy(np.load(input_mel_path))

        # Bound the crop by the shorter array so both slices have equal width.
        usable = min(mel_gta.size(1), mel_gt.size(1))
        max_mel_start = usable - self.mel_segment_length
        if max_mel_start < 0:
            raise ValueError(
                "'{}' has {} frames, fewer than the segment length {}".format(
                    input_mel_path, usable, self.mel_segment_length
                )
            )
        mel_start = random.randint(0, max_mel_start)
        mel_end = mel_start + self.mel_segment_length
        return mel_gta[:, mel_start:mel_end], mel_gt[:, mel_start:mel_end]

    def shuffle_mapping(self):
        """Shuffle ``self.mapping`` in place."""
        random.shuffle(self.mapping)
# Width of the window that MultiScaleDiscriminator.forward crops from the
# last axis; window start offsets must leave room for it.
_DISC_WINDOW = 40


def _unpack_batch(data, device):
    """Return ``(inputs, labels)`` from a loader batch, moved to ``device``.

    Accepts the ``(mel_gta, mel_gt)`` pairs yielded by the mel dataset as well
    as dict batches keyed by ``"sat_img"`` / ``"map_img"``.
    """
    if isinstance(data, dict):
        inputs, labels = data["sat_img"], data["map_img"]
    else:
        inputs, labels = data
    return inputs.to(device), labels.to(device)


def _random_start(num_frames, window=_DISC_WINDOW):
    """Pick a window start such that ``start + window`` stays within ``num_frames``."""
    return int(np.random.randint(0, max(1, num_frames - window + 1)))


def _mean_mse(criterion_mse, scores, target_fn):
    """Average ``criterion_mse(score, target_fn(score))`` over the score maps."""
    return sum(criterion_mse(s, target_fn(s)) for s in scores) / len(scores)


def main(hp, num_epochs, resume, name, hp_str=None):
    """Train the generator (and, after a warm-up, the discriminator).

    Args:
        hp: Hyper-parameter object loaded from the YAML config.
        num_epochs: Epoch index at which training stops (exclusive).
        resume: Path of a checkpoint to resume from; empty string for none.
        name: Experiment name, used for the log and checkpoint directories
            and as the checkpoint file prefix.
        hp_str: Raw text of the config, stored in every checkpoint. Optional;
            an empty string is stored when it is not given.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_dir = "{}/{}".format(hp.checkpoints, name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    os.makedirs("{}/{}".format(hp.log, name), exist_ok=True)
    writer = LogWriter("{}/{}".format(hp.log, name))
    githash = get_commit_hash()

    # The generator is fed `inputs.unsqueeze(1)`, i.e. a single-channel map.
    # The former `ResUnet(3, 64)` passed an int where the constructor expects
    # a list of filter sizes.
    # NOTE(review): confirm 1 input channel against the generator classes.
    if hp.RESNET_PLUS_PLUS:
        model_g = ResUnetPlusPlus(1).to(device)
    else:
        model_g = ResUnet(1).to(device)
    model_d = MultiScaleDiscriminator().to(device)

    adam_betas = (hp.train.adam.beta1, hp.train.adam.beta2)
    optim_g = torch.optim.Adam(model_g.parameters(), lr=hp.train.adam.lr, betas=adam_betas)
    optim_d = torch.optim.Adam(model_d.parameters(), lr=hp.train.adam.lr, betas=adam_betas)

    start_epoch = 0
    step = 0
    # optionally resume from a checkpoint
    if resume:
        if os.path.isfile(resume):
            checkpoint = torch.load(resume, map_location=device)
            model_g.load_state_dict(checkpoint['model_g'])
            model_d.load_state_dict(checkpoint['model_d'])
            optim_g.load_state_dict(checkpoint['optim_g'])
            optim_d.load_state_dict(checkpoint['optim_d'])
            step = checkpoint['step']
            # Checkpoints are written mid-epoch, so continue with that epoch.
            start_epoch = checkpoint['epoch']
        else:
            print("=> no checkpoint found at '{}'".format(resume))

    trainloader = create_dataloader(hp, True)
    validloader = create_dataloader(hp, False)

    model_g.train()
    model_d.train()

    criterion_mse = torch.nn.MSELoss()
    criterion_l1 = torch.nn.L1Loss()

    for epoch in range(start_epoch, num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # running per-epoch histories, stored as plain floats
        avg_g_loss = []
        avg_d_loss = []
        avg_adv_loss = []
        loader = tqdm(trainloader, desc="training")
        for data in loader:
            inputs, labels = _unpack_batch(data, device)
            use_disc = step > hp.train.discriminator_train_start_steps

            # generator
            optim_g.zero_grad()
            fake_mel = model_g(inputs.unsqueeze(1))
            loss_g = criterion_l1(fake_mel.squeeze(1), labels)
            adv_loss_value = 0.0
            if use_disc:
                start = _random_start(fake_mel.size(-1))
                disc_fake = model_d(fake_mel, start)
                adv_loss = _mean_mse(criterion_mse, disc_fake, torch.ones_like)
                loss_g = loss_g + hp.model.lambda_adv * adv_loss
                adv_loss_value = adv_loss.item()
            loss_g.backward()
            optim_g.step()

            # discriminator
            loss_d_avg = 0.0
            if use_disc:
                start = _random_start(labels.size(-1))
                # Fresh fake from the just-updated generator; no graph needed.
                with torch.no_grad():
                    fake_mel = model_g(inputs.unsqueeze(1))
                real_mel = labels.unsqueeze(1)
                loss_d_sum = 0.0
                for _ in range(hp.train.rep_discriminator):
                    # Also clears gradients left by the generator's backward.
                    optim_d.zero_grad()
                    disc_fake = model_d(fake_mel, start)
                    disc_real = model_d(real_mel, start)
                    loss_d_real = _mean_mse(criterion_mse, disc_real, torch.ones_like)
                    loss_d_fake = _mean_mse(criterion_mse, disc_fake, torch.zeros_like)
                    loss_d = loss_d_real + loss_d_fake
                    loss_d.backward()
                    optim_d.step()
                    loss_d_sum += loss_d.item()
                loss_d_avg = loss_d_sum / hp.train.rep_discriminator

            # one optimisation step == one increment
            step += 1

            avg_g_loss.append(loss_g.item())
            avg_d_loss.append(loss_d_avg)
            avg_adv_loss.append(adv_loss_value)

            # tensorboard logging
            if step % hp.logging_step == 0:
                mean_g = sum(avg_g_loss) / len(avg_g_loss)
                mean_d = sum(avg_d_loss) / len(avg_d_loss)
                mean_adv = sum(avg_adv_loss) / len(avg_adv_loss)
                writer.log_scaler("g_loss", mean_g, step)
                writer.log_scaler("adv_loss", mean_adv, step)
                writer.log_scaler("d_loss", mean_d, step)
                loader.set_description(
                    "Avg : g %.04f d %.04f ad %.04f| step %d" % (mean_g, mean_d, mean_adv, step)
                )

            # validation + checkpoint
            if step % hp.validation_interval == 0:
                validation(
                    validloader, model_g, model_d, criterion_l1, criterion_mse, hp, writer, step
                )
                save_path = os.path.join(checkpoint_dir, '%s_%s_%04d.pt'
                                         % (name, githash, epoch))
                torch.save({
                    'model_g': model_g.state_dict(),
                    'model_d': model_d.state_dict(),
                    'optim_g': optim_g.state_dict(),
                    'optim_d': optim_d.state_dict(),
                    'step': step,
                    'epoch': epoch,
                    'hp_str': hp_str if hp_str is not None else "",
                    'githash': githash,
                }, save_path)
                print("Saved checkpoint to: %s" % save_path)


def validation(val_dataloader, model_g, model_d, criterion_l1, criterion_mse, hp, writer, step):
    """Evaluate both models on the validation loader and log the results.

    Runs in eval mode without gradients, logs the first batch as images,
    logs the mean generator and discriminator losses, then puts both models
    back in train mode.

    Returns:
        Dict with ``"g_loss"`` and ``"d_loss"``. ``"d_acc"`` is kept as an
        alias of ``"d_loss"`` for existing callers (it has always held a
        loss, not an accuracy).
    """
    device = next(model_g.parameters()).device

    # switch to evaluate mode
    model_g.eval()
    model_d.eval()
    loss_g_sum = 0.0
    loss_d_sum = 0.0

    with torch.no_grad():
        for idx, data in enumerate(tqdm(val_dataloader, desc="validation")):
            inputs, labels = _unpack_batch(data, device)

            fake_mel = model_g(inputs.unsqueeze(1))
            if idx < 1:
                writer.log_image("actual", labels.squeeze(), "Validation")
                writer.log_image("input", inputs.squeeze(), "Validation")
                writer.log_image("generated", fake_mel.squeeze(), "Validation")

            start = _random_start(labels.size(-1))
            disc_fake = model_d(fake_mel, start)
            # The discriminator slices four axes, so add the channel axis.
            disc_real = model_d(labels.unsqueeze(1), start)

            adv_loss = _mean_mse(criterion_mse, disc_fake, torch.ones_like)
            loss_d_real = _mean_mse(criterion_mse, disc_real, torch.ones_like)
            loss_d_fake = _mean_mse(criterion_mse, disc_fake, torch.zeros_like)

            loss_g = criterion_l1(fake_mel.squeeze(1), labels) + hp.model.lambda_adv * adv_loss
            loss_d = loss_d_real + loss_d_fake
            loss_g_sum += loss_g.item()
            loss_d_sum += loss_d.item()

    # Guard against an empty validation set.
    num_items = max(1, len(val_dataloader.dataset))
    loss_g_avg = loss_g_sum / num_items
    loss_d_avg = loss_d_sum / num_items
    writer.log_scaler("g_loss", loss_g_avg, step, "Validation")
    writer.log_scaler("d_loss", loss_d_avg, step, "Validation")
    print("G Loss: {:.4f} D Loss: {:.4f}".format(loss_g_avg, loss_d_avg))
    model_g.train()
    model_d.train()
    return {"g_loss": loss_g_avg, "d_loss": loss_d_avg, "d_acc": loss_d_avg}


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Road and Building Extraction")
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="yaml file for configuration"
    )
    parser.add_argument(
        "--epochs",
        default=75,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument("--name", default="default", type=str, help="Experiment name")

    args = parser.parse_args()

    hp = HParam(args.config)
    with open(args.config, "r") as f:
        hp_str = f.read()

    main(hp, num_epochs=args.epochs, resume=args.resume, name=args.name, hp_str=hp_str)
def weights_init(m):
    """Initialise convolution and 2-D batch-norm parameters in place.

    Intended for ``module.apply(weights_init)``. Modules are matched by class
    name: any class whose name contains ``"Conv"`` gets weights drawn from
    N(0, 0.02); ``BatchNorm2d`` gets weights from N(1, 0.02) and a zero bias.
    Matching modules that carry no ``weight`` / ``bias`` (for example a
    wrapper block whose name contains "Conv", or ``BatchNorm2d`` built with
    ``affine=False``) are left untouched instead of raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            bias.data.fill_(0)


def get_commit_hash():
    """Return the short hash of the current git ``HEAD``.

    Returns ``"unknown"`` when git is not installed or the working directory
    is not a usable repository, so the caller is not aborted for that reason.
    """
    try:
        message = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"], stderr=subprocess.DEVNULL
        )
    except (OSError, subprocess.CalledProcessError):
        return "unknown"
    return message.strip().decode("utf-8")