| import torch |
| from torch.utils.data import Dataset, DataLoader |
| import torchvision |
| from torchvision import transforms |
| from torchvision.transforms.functional import to_pil_image, to_tensor |
| import glob |
| from PIL import Image |
| import tqdm |
| import gc |
|
|
class TestModel(torch.nn.Module):
    """Small fully-convolutional image-to-image network.

    Layout: a 3->16 stem conv, batch norm, three residual 3x3 convs at 16
    channels, a second batch norm, and a 16->3 output conv whose result is
    clamped to [-1, 1]. Every conv is 3x3 / stride 1 / padding 1 without
    bias, so the spatial size of the input is preserved.

    NOTE(review): there are no activation functions between layers, so apart
    from batch norm and the final clamp the network is linear -- confirm this
    is intended.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Same-padded 3x3 convolution (stride 1, no bias).
            return torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1,
                                   padding=1, bias=False)

        # Registration order matters: it fixes the sequence of random weight
        # initialisation and the key order of state_dict().
        self.start = conv3x3(3, 16)
        self.conv1 = conv3x3(16, 16)
        self.conv2 = conv3x3(16, 16)
        self.conv3 = conv3x3(16, 16)
        self.final = conv3x3(16, 3)
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.bn2 = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to an (N, 3, H, W) batch clamped to [-1, 1]."""
        h = self.bn1(self.start(x))
        for block in (self.conv1, self.conv2, self.conv3):
            h = block(h) + h  # residual connection
        h = self.final(self.bn2(h))
        return h.clamp(-1, 1)
| |
class DS(Dataset):
    """Dataset of image files returned as float tensors scaled to [-1, 1].

    ``__getitem__`` returns a random ``crop_size`` crop of one image;
    ``gettest`` returns the first image (in sorted path order) uncropped.

    Args:
        pattern: glob pattern selecting the image files.
        crop_size: size of the random training crop, as accepted by
            ``transforms.RandomCrop`` (int or ``(height, width)``).
    """

    def __init__(self, pattern="./15k/*", crop_size=(256, 256)):
        super().__init__()
        # Sorted so that indices -- and in particular the image gettest()
        # returns -- are stable across runs; raw glob order is
        # filesystem-dependent.
        self.g = sorted(glob.glob(pattern))
        self.trans = transforms.Compose([
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.g)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            # convert() loads the pixel data and returns a new image, so the
            # result stays valid after the underlying file is closed.
            return img.convert("RGB")

    @staticmethod
    def _to_signed(x):
        """Rescale a [0, 1] tensor (as produced by ToTensor) to [-1, 1]."""
        # ToTensor already divides uint8 pixels by 255, so the [0, 255]
        # formula x / 127.5 - 1 would squash everything to about -1.
        return x * 2 - 1

    def __getitem__(self, idx):
        """Return a random crop of image ``idx`` as a (3, H, W) tensor in [-1, 1]."""
        return self._to_signed(self.trans(self._load_rgb(self.g[idx])))

    def gettest(self):
        """Return the first image, uncropped, as a (3, H, W) tensor in [-1, 1].

        Raises:
            IndexError: if no file matched the glob pattern.
        """
        if not self.g:
            raise IndexError("DS.gettest: no image files matched the glob pattern")
        return self._to_signed(to_tensor(self._load_rgb(self.g[0])))
| |
def main():
    """Train ``TestModel`` to reconstruct random crops from ``DS`` with an MSE loss.

    Runs a fixed number of epochs with Adam. Before training and after every
    epoch, the model's reconstruction of ``dataset.gettest()`` is saved as
    ``./test/<tag>.png``, and the last batch loss of each epoch is printed.
    """
    import os  # local import: only needed to create the sample directory

    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 64
    epochs = 10
    out_dir = "./test"

    model = TestModel()
    dataset = DS()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = torch.nn.MSELoss().to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    model = model.to(device)
    model.train()

    # PIL's Image.save does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    def log(tag):
        """Save the model's reconstruction of the test image as ``<tag>.png``."""
        model.eval()
        # Inference only: avoid building an autograd graph for the
        # full-resolution test image.
        with torch.no_grad():
            x = dataset.gettest().to(device).unsqueeze(0)
            out = model(x)
        # Model output is clamped to [-1, 1]; map it to [0, 1] for to_pil_image.
        to_pil_image(((out[0] + 1) / 2).cpu()).save(
            os.path.join(out_dir, f"{tag}.png")
        )
        model.train()

    log("test")

    for epoch in range(epochs):
        last_loss = None  # stays None if the loader yields no batches
        for step, batch in enumerate(tqdm.tqdm(dataloader)):
            batch = batch.to(device)
            optim.zero_grad()
            out = model(batch)
            loss = criterion(out, batch)
            loss.backward()
            optim.step()
            # Keep a detached reference; .item() is deferred to the end of the
            # epoch so there is no host/device sync on every step.
            last_loss = loss.detach()
            if step % 100 == 0:
                # NOTE(review): likely unnecessary; empty_cache() only returns
                # cached CUDA blocks to the driver.
                gc.collect()
                torch.cuda.empty_cache()
        print("EPOCH", epoch)
        print("LAST LOSS", None if last_loss is None else last_loss.item())
        log(epoch)
| |
| |
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()