fedebotu committed
Commit 2011e8a · 1 Parent(s): edebed2

[Chore] propagate changes

Files changed (3):
  1. configuration_rnd.py +1 -1
  2. generation_config.py +1 -5
  3. sampling.py +27 -58
configuration_rnd.py CHANGED
@@ -30,7 +30,7 @@ class RND1Config(Qwen3MoeConfig):
         self,
         moe_backend: str = "hf",
         num_diffusion_steps: int = 256,
-        mask_token_id: int = 151669,  # Default for Qwen-based RND1 models
+        mask_token_id: int = 151669,
         use_cache: bool = False,
         **kwargs,
     ):
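Note: the only change here is dropping the inline comment, so 151669 remains the default mask token for Qwen-based RND1 models. A minimal usage sketch (assuming configuration_rnd.py is importable and that the constructor stores the argument as an attribute, as HF config classes conventionally do):

from configuration_rnd import RND1Config

# mask_token_id still defaults to 151669, the Qwen-based RND1 mask token;
# override it explicitly when pairing the model with a different tokenizer.
config = RND1Config(moe_backend="hf", num_diffusion_steps=256)
assert config.mask_token_id == 151669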
generation_config.py CHANGED
@@ -31,12 +31,11 @@ class RND1GenerationConfig(GenerationConfig):
         self,
         max_length: int = 256,
         num_diffusion_steps: int = 256,
-        mask_token_id: int = 151669,  # Default for Qwen-based RND1 models
+        mask_token_id: int = 151669,
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         greedy: bool = True,
-        seed: Optional[int] = None,  # For reproducible generation
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
@@ -64,7 +63,6 @@ class RND1GenerationConfig(GenerationConfig):
         self.mask_token_id = mask_token_id
         self.greedy = greedy
         self.temperature = float(temperature)  # Ensure it's a float
-        self.seed = seed

     def to_dict(self):
         """Convert configuration to dictionary."""
@@ -72,6 +70,4 @@ class RND1GenerationConfig(GenerationConfig):
         output["num_diffusion_steps"] = self.num_diffusion_steps
         output["mask_token_id"] = self.mask_token_id
         output["greedy"] = self.greedy
-        if self.seed is not None:
-            output["seed"] = self.seed
         return output
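Note: with seed removed from both the constructor and to_dict(), the generation config no longer carries reproducibility state, and serialized configs will not round-trip a seed. A minimal sketch of seeding at the call site instead (the pre-commit sampler accepted a torch.Generator; whether the current entry point still threads one through is an assumption here):

import torch
from generation_config import RND1GenerationConfig  # assumes this module is importable

gen_config = RND1GenerationConfig(greedy=False, top_p=0.9)
assert "seed" not in gen_config.to_dict()  # seed is no longer serialized

# Reproducibility is now the caller's concern, e.g. via an explicit RNG
# passed to the sampler rather than stored on the config:
generator = torch.Generator().manual_seed(42)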
sampling.py CHANGED
@@ -16,41 +16,25 @@ def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
     """
     Apply top-k filtering to logits: with non-top-k values set to -inf
     """
-    if k is None or k <= 0:
-        return torch.full_like(logits, float("-inf"))
-    k = min(k, logits.size(-1))
-    top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
-    filtered = torch.full_like(logits, float("-inf"))
-    filtered.scatter_(-1, top_k_indices, top_k_values)
-    return filtered
+    top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
+    filtered_logits = torch.full_like(logits, float('-inf'))
+    filtered_logits.scatter_(-1, top_k_indices, top_k_values)
+    return filtered_logits


-def apply_top_p_filtering(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
+def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
     """
     Apply top-p (nucleus) filtering to logits: with tokens beyond threshold set to -inf
     """
-    if p <= 0:
-        p = 1e-8
-    if p >= 1:
-        return logits
-
     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-
-    probs = torch.softmax(sorted_logits, dim=-1)
-    cumulative_probs = torch.cumsum(probs, dim=-1)
-
+    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+    # Remove tokens with cumulative probability above threshold
     sorted_indices_to_remove = cumulative_probs > p
-
-    if min_tokens_to_keep > 0:
-        sorted_indices_to_remove[..., :min_tokens_to_keep] = False
-
+    sorted_indices_to_remove[..., 0] = False  # Keep at least one token
     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-    sorted_indices_to_remove[..., 0] = False
-
-    indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
-    indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
-
-    return logits.masked_fill(indices_to_remove, float("-inf"))
+    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+    return logits.masked_fill(indices_to_remove, float('-inf'))


 @torch.no_grad()
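Note: the defensive guards (k <= 0, p outside (0, 1)) moved out of these helpers; the call site in the last hunk now checks them, so the helpers assume valid arguments. A quick sanity check of the committed behavior (assuming sampling.py is on the import path):

import torch
from sampling import apply_top_k_filtering, apply_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # softmax probs ≈ [0.61, 0.22, 0.14, 0.03]
print(apply_top_k_filtering(logits, k=2))    # tensor([[2., 1., -inf, -inf]])
print(apply_top_p_filtering(logits, p=0.6))  # tensor([[2., 1., -inf, -inf]])

Because the new top-p code clears index 0 before the right-shift, the two highest-probability tokens always survive: with p=0.6 a strict nucleus filter would keep only the 0.61 token, while this version also keeps the runner-up.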
@@ -171,22 +155,13 @@
     if bos_token_id is not None:
         x[0, 0] = bos_token_id
     if eos_token_id is not None:
-        # If eos_token_id is a list, use the first one
-        if isinstance(eos_token_id, (list, tuple)):
-            x[0, -1] = eos_token_id[0]
-        else:
-            x[0, -1] = eos_token_id
+        x[0, -1] = eos_token_id
     init_maskable = x.eq(mask_token_id)

     if bos_token_id is not None:
         init_maskable[:, 0] = False
     if eos_token_id is not None:
-        # Handle both single token and list of tokens
-        if isinstance(eos_token_id, (list, tuple)):
-            for eos_id in eos_token_id:
-                init_maskable &= x.ne(eos_id)
-        else:
-            init_maskable &= x.ne(eos_token_id)
+        init_maskable &= x.ne(eos_token_id)
     init_maskable &= x.ne(pad_token_id)

     maskable = init_maskable.clone()
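Note: the simplified branches assume eos_token_id is a single int. With the isinstance handling removed, a list-valued eos_token_id (which some HF generation configs provide) would now raise or misbehave unless the caller normalizes it first, e.g. with a hypothetical caller-side shim mirroring the removed logic:

if isinstance(eos_token_id, (list, tuple)):
    eos_token_id = eos_token_id[0]  # keep the first EOS id, as the old code did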
@@ -204,34 +179,28 @@
             # Fall back to positional argument
             model_output = model(tokens)

-        # Apply temperature scaling (if temperature == 0, treat as 1.0 for greedy)
-        logits = model_output.logits
-        if temperature > 0:
-            logits = logits / temperature
-
-        # Apply filtering only when not in greedy mode
-        # Order matches reference: top_p before top_k
-        if not greedy:
-            if top_p is not None and 0 < top_p < 1.0:
-                logits = apply_top_p_filtering(logits, top_p)
-
-            if top_k is not None and top_k > 0:
-                logits = apply_top_k_filtering(logits, top_k)
-
-        # Compute probabilities for sampling and metrics
-        probs = torch.softmax(logits, dim=-1)
-        logp = torch.log(probs + 1e-10)  # Add epsilon for numerical stability
-
+        # Apply temperature scaling (with safety for near-zero temperature)
+        safe_temperature = max(temperature, 1e-8)  # Prevent division by zero
+        logits = model_output.logits / safe_temperature
+
+
+        # Apply filtering strategies
+        # Note: When both top_k and top_p are provided, they are applied sequentially:
+        # First top_k filters to k tokens, then top_p filters from those k tokens
+        if top_k is not None and top_k > 0:
+            logits = apply_top_k_filtering(logits, top_k)
+
+        if top_p is not None and 0 < top_p < 1.0:
+            logits = apply_top_p_filtering(logits, top_p)
+
+        # Convert to log probabilities
+        logp = torch.log_softmax(logits, dim=-1)
+
+        # Greedy or stochastic sampling
         if greedy:
             pred_next = logp.argmax(-1)
         else:
-            # Sample from categorical distribution with proper RNG handling
-            if generator is not None:
-                # Use multinomial with generator for reproducible sampling
-                pred_next = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=generator).squeeze(-1).view(probs.shape[:-1])
-            else:
-                # Sample from categorical using probabilities
-                pred_next = torch.distributions.Categorical(probs=probs).sample()
+            pred_next = torch.distributions.Categorical(logits=logp).sample(generator=generator)

         conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)

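Note: one caveat on the new sampling branch. torch.distributions.Categorical.sample() accepts only a sample_shape argument, not a generator, so the committed line will raise a TypeError whenever greedy=False. If seeded sampling is still wanted, a sketch that keeps the generator by routing through torch.multinomial (which does accept one, exactly as the pre-commit code did); the helper name is hypothetical:

from typing import Optional

import torch

def sample_with_generator(logp: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # Hypothetical replacement for the greedy=False branch:
    # torch.multinomial accepts a torch.Generator, Categorical.sample() does not.
    probs = logp.exp()
    flat = probs.reshape(-1, probs.size(-1))
    pred = torch.multinomial(flat, num_samples=1, generator=generator)
    return pred.squeeze(-1).reshape(probs.shape[:-1])

Simply dropping generator=generator from the Categorical call would also run, at the cost of reproducible sampling.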