# pong/src/inference/sampling.py
# chrisxx's picture
# Update src/inference/sampling.py
# 0ade316 verified
import torch as t
@t.no_grad()
def sample(v, z, actions, num_steps=10, cfg=0, negative_actions=None, cache=None):
    """Run ``sample_with_grad`` with autograd disabled.

    Every argument is forwarded unchanged to ``sample_with_grad`` and its
    result is returned as-is; the only difference is that the whole call
    executes under ``torch.no_grad()``.
    """
    return sample_with_grad(
        v,
        z,
        actions,
        num_steps=num_steps,
        cfg=cfg,
        negative_actions=negative_actions,
        cache=cache,
    )
def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None, cache=None):
    """Integrate ``z`` along the field predicted by ``v`` using Euler steps.

    The time schedule starts at 1 and ends at 0: ``num_steps + 1`` evenly
    spaced points are remapped through ``3*t / (2*t + 1)``, and each step
    applies ``z <- z + (t_i - t_{i+1}) * v_pred``.

    Args:
        v: Callable exposing a ``device`` attribute. It is invoked as
            ``v(z, actions, t_cond, cached_k=..., cached_v=...)`` and must
            return a 3-tuple ``(prediction, k_new, v_new)``.
        z: Starting tensor. It is cloned, so the caller's tensor is not
            modified.
        actions: Conditioning tensor passed to ``v`` on every step.
        num_steps: Number of Euler steps to take.
        cfg: Guidance scale. When ``cfg > 0`` a second, "negative" prediction
            is computed each step and the two are combined as
            ``v_neg + cfg * (v_pred - v_neg)``.
        negative_actions: Conditioning used for the negative prediction.
            Defaults to an all-zero ``long`` tensor shaped like ``actions``.
        cache: Optional object providing ``get() -> (cached_k, cached_v)`` and
            ``extend(k_new, v_new)``. ``get`` is consulted before every model
            call; ``extend`` is called once, with the outputs of the final
            step.

    Returns:
        The tensor obtained after the last step, on ``v.device``.

    Raises:
        NotImplementedError: If guidance (``cfg > 0``) is requested together
            with a cache and at least one step would run.
    """
    use_cfg = cfg > 0
    # Reject the unsupported combination before doing any work: previously the
    # error surfaced only after a forward pass, and with num_steps == 1 the
    # cache had already been extended by then.
    if use_cfg and cache is not None and num_steps > 0:
        raise NotImplementedError("this is not implemented yet")

    device = v.device
    ts = 1 - t.linspace(0, 1, num_steps + 1, device=device)
    ts = 3 * ts / (2 * ts + 1)

    # Device transfers and the default negative conditioning do not change
    # between steps, so they are prepared once.
    z_prev = z.clone().to(device)
    actions = actions.to(device)
    if use_cfg:
        if negative_actions is not None:
            negative_actions = negative_actions.to(device)
        else:
            negative_actions = t.zeros_like(actions, dtype=t.long, device=device)

    final_step = len(ts) - 2
    for i, (t_cur, t_next) in enumerate(zip(ts[:-1], ts[1:])):
        # One copy of the current time per batch row, shape (batch, 1).
        t_cond = t_cur.repeat(z_prev.shape[0], 1)

        cached_k = cached_v = None
        if cache is not None:
            cached_k, cached_v = cache.get()

        v_pred, k_new, v_new = v(z_prev, actions, t_cond, cached_k=cached_k, cached_v=cached_v)

        # Only the keys/values produced by the last step are stored.
        if cache is not None and i == final_step:
            cache.extend(k_new, v_new)

        if use_cfg:
            v_neg, _, _ = v(z_prev, negative_actions, t_cond)
            v_pred = v_neg + cfg * (v_pred - v_neg)

        z_prev = z_prev + (t_cur - t_next) * v_pred

    return z_prev