"""Data loading utilities for the pong1M dataset (pong/src/datasets/pong1m.py)."""
from torch.utils.data import TensorDataset, DataLoader
import torch as t
import numpy as np
from einops import rearrange
# Per-channel mean/std used by the "z" normalization mode, shaped (1, 1, 3, 1, 1)
# so they broadcast over tensors laid out as (batch, time, channel, height, width).
mean = t.tensor([0.0352, 0.1046, 0.1046]).reshape(1, 1, 3, 1, 1)
std = t.tensor([0.1066, 0.0995, 0.0995]).reshape(1, 1, 3, 1, 1)
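
# Hypothetical helper (not part of the original pipeline): a minimal sketch of how
# per-channel statistics like the constants above could be recomputed, assuming
# ./datasets/pong1M/frames.npy is a uint8 array of shape (N, H, W, C).
def compute_channel_stats(path="./datasets/pong1M/frames.npy"):
    x = t.from_numpy(np.load(path)).float() / 255.0    # (N, H, W, C) scaled to [0, 1]
    m = x.mean(dim=(0, 1, 2)).reshape(1, 1, -1, 1, 1)  # broadcastable over (b, dur, c, h, w)
    s = x.std(dim=(0, 1, 2)).reshape(1, 1, -1, 1, 1)
    return m, s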
def fixed2frame(y, lam=1e-6):
    """Map a prediction in [-1, 1] back to uint8 pixel values (`lam` is unused)."""
    y = y.clamp(-1, 1) * 0.5 + 0.5
    frames = (y * 255.0).round().byte()
    return frames
def z2frame(y, lam=1e-6, mean=mean, std=std):
    """Invert the per-channel z-normalization and map back to uint8 pixel values (`lam` is unused)."""
    y = y * std.to(y.dtype).to(y.device) + mean.to(y.dtype).to(y.device)
    frames = (y.clamp(0, 1) * 255.0).round().byte()
    return frames
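
# Hypothetical sanity check (not part of the original file): a frame normalized the
# way get_loader's "z" mode does it should round-trip through z2frame back to
# (almost) the same uint8 values.
def _roundtrip_check():
    raw = t.randint(0, 256, (1, 1, 3, 8, 8), dtype=t.uint8)
    z = (raw.float() / 255.0 - mean) / (std + 1e-6)
    restored = z2frame(z)
    return (restored.int() - raw.int()).abs().max()  # expected to be 0 or 1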
def get_loader(batch_size=64, fps=30, duration=5, shuffle=True, debug=False,
               mode="-1,1", mean=mean, std=std, drop_duration=False):
    """Build a DataLoader of fixed-length pong clips and return it together with
    the function that maps model predictions back to uint8 frames."""
    frames = t.from_numpy(np.load("./datasets/pong1M/frames.npy"))
    actions = t.from_numpy(np.load("./datasets/pong1M/actions.npy"))
    height, width, channels = frames.shape[-3:]

    # Cut the frame stream into n clips of fps * duration frames each and move
    # the channel axis in front of the spatial axes: (n, dur, c, h, w).
    n = frames.shape[0] // (fps * duration)
    frames = frames[:n * fps * duration]
    frames = frames.reshape(n, fps * duration, height, width, channels)
    frames = frames.permute(0, 1, 4, 2, 3)
    actions = actions[:n * fps * duration]
    actions = actions.reshape(-1, fps * duration)
    b, dur, c, h, w = frames.shape

    if mode == "-1,1":
        # Scale pixels to [-1, 1] and zero out background pixels (RGB 6, 24, 24).
        z = rearrange(frames, "b dur c h w -> (b dur h w) c")
        mask = (z == t.tensor([6, 24, 24], dtype=z.dtype)).all(dim=1)
        z = (z.float() / 255.0 - 0.5) * 2
        z[mask] = 0
        z = rearrange(z, "(b dur h w) c -> b dur c h w", b=b, dur=dur, c=c, h=h, w=w)
        frames = z
        pred2frame = fixed2frame
    elif mode == "z":
        # Per-channel z-normalization with the dataset statistics defined above.
        frames = frames.float() / 255.0
        frames = (frames - mean) / (std + 1e-6)
        pred2frame = z2frame
    else:
        raise ValueError(f"Invalid mode: {mode}")

    firstf = frames[0]
    firsta = actions[0]
    if debug:
        # Overfitting aid: replace every clip with the first clip (frames and
        # actions), then repeat that clip's first frame across all timesteps.
        frames = 0 * frames + firstf[None]
        actions = 0 * actions + firsta[None]
        frames = 0 * frames + frames[:, 0].unsqueeze(1)

    if drop_duration:
        # Single-frame variant: keep only each clip's first frame and zero the actions.
        dataset = TensorDataset(frames[:, 0], actions[:, 0] * 0)
    else:
        dataset = TensorDataset(frames, actions)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    print(f"{frames.shape[0] // batch_size} batches")
    return loader, pred2frame
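
# Minimal usage sketch; assumes the pong1M .npy files are present at the paths
# used in get_loader above.
if __name__ == "__main__":
    loader, pred2frame = get_loader(batch_size=4, mode="-1,1")
    clips, acts = next(iter(loader))  # clips: (4, fps * duration, 3, H, W), float in [-1, 1]
    print(clips.shape, acts.shape)
    # Map the normalized clips back to uint8 frames for visualization.
    imgs = pred2frame(clips)
    print(imgs.dtype, imgs.min().item(), imgs.max().item())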