from torch.utils.data import TensorDataset, DataLoader from torch import nn import torch as t import numpy as np from einops import rearrange mean = t.tensor([[[[[0.0352]], [[0.1046]], [[0.1046]]]]]) std = t.tensor([[[[[0.1066]], [[0.0995]], [[0.0995]]]]]) def fixed2frame(y, lam=1e-6): y = y.clamp(-1, 1) * 0.5 + 0.5 frames = (y * 255.0).round().byte() return frames def z2frame(y, lam=1e-6, mean=mean, std=std): y = y*std.to(y.dtype).to(y.device) + mean.to(y.dtype).to(y.device) frames = (y.clamp(0, 1) * 255.0).round().byte() return frames def get_loader(batch_size=64, fps=30, duration=5, shuffle=True, debug=False, mode="-1,1", mean=mean, std=std, drop_duration=False): frames = t.from_numpy(np.load("./datasets/pong1M/frames.npy")) actions = t.from_numpy(np.load("./datasets/pong1M/actions.npy")) height, width, channels = frames.shape[-3:] n = frames.shape[0]//(fps*duration) frames = frames[:n*fps*duration] frames = frames.reshape(n, fps*duration, height, width, channels) frames = frames.permute(0, 1, 4, 2, 3) actions = actions[:n*fps*duration] actions = actions.reshape(-1, fps*duration) b, dur, c, h, w = frames.shape if mode == "-1,1": z = rearrange(frames, "b dur c h w -> (b dur h w) c") mask = (z == t.tensor([6, 24, 24], dtype=z.dtype)).all(dim=1) z = (z.float()/255.0 - 0.5)*2 z[mask] = 0 z = rearrange(z, "(b dur h w) c -> b dur c h w", b=b, dur=dur, c=c, h=h, w=w) frames = z pred2frame = fixed2frame elif mode == "z": frames = frames.float()/255.0 frames = (frames - mean) / (std + 1e-6) pred2frame = z2frame else: raise ValueError(f"Invalid mode: {mode}") firstf = frames[0] firsta = actions[0] if debug: frames = 0*frames + firstf[None] actions = 0*actions + firsta[None] frames = 0*frames + frames[:,0].unsqueeze(1) if drop_duration: dataset = TensorDataset(frames[:, 0], actions[:,0]*0) else: dataset = TensorDataset(frames, actions) loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) print(f"{frames.shape[0]//batch_size} batches") return loader, pred2frame