chrisxx committed on
Commit
0ade316
·
verified ·
1 Parent(s): 232e1a9

Update src/inference/sampling.py

Browse files
Files changed (1) hide show
  1. src/inference/sampling.py +17 -6
src/inference/sampling.py CHANGED
@@ -1,10 +1,10 @@
1
  import torch as t
2
 
3
  @t.no_grad()
4
- def sample(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
5
- return sample_with_grad(v, z, actions, num_steps, cfg, negative_actions)
6
 
7
- def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
8
  device = v.device
9
  ts = 1 - t.linspace(0, 1, num_steps+1, device=device)
10
  ts = 3*ts/(2*ts + 1)
@@ -12,12 +12,23 @@ def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
12
  z_prev = z_prev.to(device)
13
  for i in range(len(ts)-1):
14
  t_cond = ts[i].repeat(z_prev.shape[0], 1)
15
- v_pred = v(z_prev.to(device), actions.to(device), t_cond.to(device))
 
 
 
 
 
 
 
 
 
16
  if cfg > 0:
 
 
17
  if negative_actions is not None:
18
- v_neg = v(z_prev.to(device), negative_actions.to(device), t_cond.to(device))
19
  else:
20
- v_neg = v(z_prev.to(device), t.zeros_like(actions, dtype=t.long, device=device), t_cond.to(device))
21
  v_pred = v_neg + cfg * (v_pred - v_neg)
22
  z_prev = z_prev + (ts[i] - ts[i+1])*v_pred
23
  return z_prev
 
1
  import torch as t
2
 
3
  @t.no_grad()
4
+ def sample(v, z, actions, num_steps=10, cfg=0, negative_actions=None, cache=None):
5
+ return sample_with_grad(v, z, actions, num_steps, cfg, negative_actions, cache=cache)
6
 
7
+ def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None, cache=None):
8
  device = v.device
9
  ts = 1 - t.linspace(0, 1, num_steps+1, device=device)
10
  ts = 3*ts/(2*ts + 1)
 
12
  z_prev = z_prev.to(device)
13
  for i in range(len(ts)-1):
14
  t_cond = ts[i].repeat(z_prev.shape[0], 1)
15
+ cached_k = None
16
+ cached_v = None
17
+ if cache is not None:
18
+ cached_k, cached_v = cache.get()
19
+
20
+ v_pred, k_new, v_new = v(z_prev.to(device), actions.to(device), t_cond.to(device), cached_k=cached_k, cached_v=cached_v)
21
+
22
+ if i == len(ts)-2 and cache is not None:
23
+ cache.extend(k_new, v_new)
24
+
25
  if cfg > 0:
26
+ if cache is not None:
27
+ raise NotImplementedError("this is not implemented yet")
28
  if negative_actions is not None:
29
+ v_neg, _, _ = v(z_prev.to(device), negative_actions.to(device), t_cond.to(device))
30
  else:
31
+ v_neg, _, _ = v(z_prev.to(device), t.zeros_like(actions, dtype=t.long, device=device), t_cond.to(device))
32
  v_pred = v_neg + cfg * (v_pred - v_neg)
33
  z_prev = z_prev + (ts[i] - ts[i+1])*v_pred
34
  return z_prev