Update src/inference/sampling.py
Browse files- src/inference/sampling.py +17 -6
src/inference/sampling.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
import torch as t
|
| 2 |
|
| 3 |
@t.no_grad()
|
| 4 |
-
def sample(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
|
| 5 |
-
return sample_with_grad(v, z, actions, num_steps, cfg, negative_actions)
|
| 6 |
|
| 7 |
-
def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
|
| 8 |
device = v.device
|
| 9 |
ts = 1 - t.linspace(0, 1, num_steps+1, device=device)
|
| 10 |
ts = 3*ts/(2*ts + 1)
|
|
@@ -12,12 +12,23 @@ def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
|
|
| 12 |
z_prev = z_prev.to(device)
|
| 13 |
for i in range(len(ts)-1):
|
| 14 |
t_cond = ts[i].repeat(z_prev.shape[0], 1)
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
if cfg > 0:
|
|
|
|
|
|
|
| 17 |
if negative_actions is not None:
|
| 18 |
-
v_neg = v(z_prev.to(device), negative_actions.to(device), t_cond.to(device))
|
| 19 |
else:
|
| 20 |
-
v_neg = v(z_prev.to(device), t.zeros_like(actions, dtype=t.long, device=device), t_cond.to(device))
|
| 21 |
v_pred = v_neg + cfg * (v_pred - v_neg)
|
| 22 |
z_prev = z_prev + (ts[i] - ts[i+1])*v_pred
|
| 23 |
return z_prev
|
|
|
|
| 1 |
import torch as t
|
| 2 |
|
| 3 |
@t.no_grad()
def sample(v, z, actions, num_steps=10, cfg=0, negative_actions=None, cache=None):
    """Gradient-free entry point for sampling.

    Runs the same integration loop as ``sample_with_grad`` but under
    ``torch.no_grad()``, so no autograd graph is recorded. See
    ``sample_with_grad`` for parameter semantics.
    """
    return sample_with_grad(
        v,
        z,
        actions,
        num_steps,
        cfg,
        negative_actions,
        cache=cache,
    )
|
| 6 |
|
| 7 |
+
def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None, cache=None):
    """Integrate the velocity model ``v`` from the initial latent ``z``.

    Explicit Euler integration over a warped, decreasing time grid, with
    optional classifier-free guidance and an optional external KV cache.

    Args:
        v: model exposing ``v.device``; called as
           ``v(z, actions, t_cond, cached_k=..., cached_v=...)`` and returning
           ``(velocity, new_keys, new_values)``.
        z: initial latent tensor, shape ``(batch, ...)``.
        actions: conditioning tensor; a long zero tensor of the same shape is
           used as the unconditional input when guidance is on and no
           ``negative_actions`` are given.
        num_steps: number of Euler steps.
        cfg: classifier-free-guidance scale; ``0`` disables guidance.
        negative_actions: optional negative conditioning for guidance.
        cache: optional KV cache with ``get() -> (k, v)`` and
           ``extend(k, v)``; extended only with the final step's keys/values.

    Returns:
        The integrated latent (same shape as ``z``).

    Raises:
        NotImplementedError: if ``cfg > 0`` is combined with a cache.
    """
    device = v.device
    # Time grid runs 1 -> 0; the warp t' = 3t / (2t + 1) keeps the endpoints
    # fixed while concentrating step density near t = 0.
    ts = 1 - t.linspace(0, 1, num_steps + 1, device=device)
    ts = 3 * ts / (2 * ts + 1)
    # NOTE(review): the initialisation of z_prev was elided in the diff this
    # block was recovered from; `z` is otherwise unused, so it is presumed to
    # seed the integration state — confirm against the original file.
    z_prev = z
    z_prev = z_prev.to(device)

    # The cached keys/values cannot change while the loop runs: the cache is
    # only extended on the final iteration, after which the loop exits. Fetch
    # them once instead of once per step.
    cached_k = None
    cached_v = None
    if cache is not None:
        cached_k, cached_v = cache.get()

    for i in range(len(ts) - 1):
        # Broadcast the scalar time to one conditioning value per batch row.
        t_cond = ts[i].repeat(z_prev.shape[0], 1)

        v_pred, k_new, v_new = v(
            z_prev.to(device),
            actions.to(device),
            t_cond.to(device),
            cached_k=cached_k,
            cached_v=cached_v,
        )

        # Persist only the final step's keys/values into the cache.
        if i == len(ts) - 2 and cache is not None:
            cache.extend(k_new, v_new)

        if cfg > 0:
            if cache is not None:
                raise NotImplementedError("this is not implemented yet")
            if negative_actions is not None:
                v_neg, _, _ = v(z_prev.to(device), negative_actions.to(device), t_cond.to(device))
            else:
                # Unconditional branch: zeroed (long) actions.
                v_neg, _, _ = v(z_prev.to(device), t.zeros_like(actions, dtype=t.long, device=device), t_cond.to(device))
            # Classifier-free guidance: extrapolate away from the
            # unconditional prediction.
            v_pred = v_neg + cfg * (v_pred - v_neg)

        # Euler step; ts is decreasing, so (ts[i] - ts[i+1]) > 0.
        z_prev = z_prev + (ts[i] - ts[i + 1]) * v_pred
    return z_prev
|