import torch as t


@t.no_grad()
def sample(v, z, actions, num_steps=10, cfg=0, negative_actions=None, cache=None):
    """Gradient-free wrapper around `sample_with_grad`."""
    return sample_with_grad(v, z, actions, num_steps, cfg, negative_actions, cache=cache)


def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None, cache=None):
    """Euler-integrate the velocity model `v` from t=1 (noise) to t=0 (data),
    with optional classifier-free guidance (`cfg`) and KV caching (`cache`)."""
    device = v.device  # assumes the model exposes a `.device` attribute

    # Time schedule: uniform steps from t=1 down to t=0, warped by
    # t -> 3t / (2t + 1), which concentrates steps near t=0.
    ts = 1 - t.linspace(0, 1, num_steps + 1, device=device)
    ts = 3 * ts / (2 * ts + 1)

    z_prev = z.clone().to(device)
    for i in range(len(ts) - 1):
        # Broadcast the current timestep to a per-sample conditioning tensor.
        t_cond = ts[i].repeat(z_prev.shape[0], 1)

        cached_k, cached_v = (None, None) if cache is None else cache.get()
        v_pred, k_new, v_new = v(
            z_prev, actions.to(device), t_cond,
            cached_k=cached_k, cached_v=cached_v,
        )

        # Only persist the new KV entries computed on the final step.
        if i == len(ts) - 2 and cache is not None:
            cache.extend(k_new, v_new)

        # Classifier-free guidance: push the prediction away from the
        # unconditional (or negative-action) velocity.
        if cfg > 0:
            if cache is not None:
                raise NotImplementedError("CFG with KV caching is not implemented yet")
            if negative_actions is not None:
                v_neg, _, _ = v(z_prev, negative_actions.to(device), t_cond)
            else:
                v_neg, _, _ = v(z_prev, t.zeros_like(actions, dtype=t.long, device=device), t_cond)
            v_pred = v_neg + cfg * (v_pred - v_neg)

        # Euler step along the predicted velocity field.
        z_prev = z_prev + (ts[i] - ts[i + 1]) * v_pred
    return z_prev
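

# --- Usage sketch (illustrative only) ---
# A minimal example of driving `sample` end to end. `DummyVelocity` is a
# hypothetical stand-in, not the real model: the sampler above only assumes
# the model takes (z, actions, t_cond, cached_k=..., cached_v=...), returns
# (v_pred, k_new, v_new), and exposes a `.device` attribute.
if __name__ == "__main__":
    class DummyVelocity(t.nn.Module):
        def __init__(self, dim=8):
            super().__init__()
            self.proj = t.nn.Linear(dim, dim)

        @property
        def device(self):
            # `sample_with_grad` reads `v.device`; nn.Module has no such
            # attribute by default, so expose one here.
            return next(self.parameters()).device

        def forward(self, z, actions, t_cond, cached_k=None, cached_v=None):
            # Ignore actions/timestep/caching; just predict a velocity from z.
            return self.proj(z), None, None

    model = DummyVelocity()
    z0 = t.randn(2, 8)                  # initial noise, batch of 2
    acts = t.zeros(2, 1, dtype=t.long)  # placeholder action ids
    out = sample(model, z0, acts, num_steps=10)
    print(out.shape)  # torch.Size([2, 8])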