[Chore] propagate changes
- configuration_rnd.py +1 -1
- generation_config.py +1 -5
- sampling.py +27 -58
configuration_rnd.py CHANGED

@@ -30,7 +30,7 @@ class RND1Config(Qwen3MoeConfig):
         self,
         moe_backend: str = "hf",
         num_diffusion_steps: int = 256,
-        mask_token_id: int = 151669,
+        mask_token_id: int = 151669,
         use_cache: bool = False,
         **kwargs,
     ):
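For reference, a minimal sketch of constructing the config whose mask_token_id default is touched above; the import path is an assumption about the repo layout, not something shown in this commit.

# Hypothetical import path; adjust to the actual repo layout.
from configuration_rnd import RND1Config

config = RND1Config(moe_backend="hf", num_diffusion_steps=256, mask_token_id=151669)
print(config.mask_token_id, config.num_diffusion_steps)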
generation_config.py CHANGED

@@ -31,12 +31,11 @@ class RND1GenerationConfig(GenerationConfig):
         self,
         max_length: int = 256,
         num_diffusion_steps: int = 256,
-        mask_token_id: int = 151669,
+        mask_token_id: int = 151669,
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         greedy: bool = True,
-        seed: Optional[int] = None,  # For reproducible generation
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,

@@ -64,7 +63,6 @@ class RND1GenerationConfig(GenerationConfig):
         self.mask_token_id = mask_token_id
         self.greedy = greedy
         self.temperature = float(temperature)  # Ensure it's a float
-        self.seed = seed

     def to_dict(self):
         """Convert configuration to dictionary."""

@@ -72,6 +70,4 @@ class RND1GenerationConfig(GenerationConfig):
         output["num_diffusion_steps"] = self.num_diffusion_steps
         output["mask_token_id"] = self.mask_token_id
         output["greedy"] = self.greedy
-        if self.seed is not None:
-            output["seed"] = self.seed
         return output
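For reference, a minimal sketch of the config after this change: seed is no longer a constructor argument or a to_dict() key, so any reproducibility would be handled outside the config, for example with a torch.Generator passed at generation time. That call-time pattern is an assumption about intended usage, and the import path below is hypothetical.

import torch
# Hypothetical import path; adjust to the actual repo layout.
from generation_config import RND1GenerationConfig

cfg = RND1GenerationConfig(max_length=256, num_diffusion_steps=256, greedy=False, top_p=0.9)
out = cfg.to_dict()
print(out["num_diffusion_steps"], out["mask_token_id"], out["greedy"])
print("seed" in out)  # False after this change

# Assumed pattern for reproducible runs: seed a generator at call time instead.
gen = torch.Generator().manual_seed(1234)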
sampling.py CHANGED

@@ -16,41 +16,25 @@ def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
     """
     Apply top-k filtering to logits: with non-top-k values set to -inf
     """
-    filtered = torch.full_like(logits, float("-inf"))
-    filtered.scatter_(-1, top_k_indices, top_k_values)
-    return filtered
+    top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
+    filtered_logits = torch.full_like(logits, float('-inf'))
+    filtered_logits.scatter_(-1, top_k_indices, top_k_values)
+    return filtered_logits


-def apply_top_p_filtering(logits: torch.Tensor, p: float
+def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
     """
     Apply top-p (nucleus) filtering to logits: with tokens beyond threshold set to -inf
     """
-    if p <= 0:
-        p = 1e-8
-    if p >= 1:
-        return logits
-
     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-
-    cumulative_probs = torch.cumsum(probs, dim=-1)
-
+    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+    # Remove tokens with cumulative probability above threshold
     sorted_indices_to_remove = cumulative_probs > p
-
-    if min_tokens_to_keep > 0:
-        sorted_indices_to_remove[..., :min_tokens_to_keep] = False
-
+    sorted_indices_to_remove[..., 0] = False  # Keep at least one token
     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-    sorted_indices_to_remove
-
-    indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
-    indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
-
-    return logits.masked_fill(indices_to_remove, float("-inf"))
+    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+    return logits.masked_fill(indices_to_remove, float('-inf'))


 @torch.no_grad()

@@ -171,22 +155,13 @@ def diffusion_sample(
     if bos_token_id is not None:
         x[0, 0] = bos_token_id
     if eos_token_id is not None:
-        if isinstance(eos_token_id, (list, tuple)):
-            x[0, -1] = eos_token_id[0]
-        else:
-            x[0, -1] = eos_token_id
+        x[0, -1] = eos_token_id
     init_maskable = x.eq(mask_token_id)

     if bos_token_id is not None:
         init_maskable[:, 0] = False
     if eos_token_id is not None:
-        if isinstance(eos_token_id, (list, tuple)):
-            for eos_id in eos_token_id:
-                init_maskable &= x.ne(eos_id)
-        else:
-            init_maskable &= x.ne(eos_token_id)
+        init_maskable &= x.ne(eos_token_id)
     init_maskable &= x.ne(pad_token_id)

     maskable = init_maskable.clone()

@@ -204,34 +179,28 @@ def diffusion_sample(
            # Fall back to positional argument
            model_output = model(tokens)

-        # Apply temperature scaling (
-        logits = logits / temperature
+        # Apply temperature scaling (with safety for near-zero temperature)
+        safe_temperature = max(temperature, 1e-8)  # Prevent division by zero
+        logits = model_output.logits / safe_temperature

-        # Apply filtering only when not in greedy mode
-        # Order matches reference: top_p before top_k
-        if not greedy:
-            if top_p is not None and 0 < top_p < 1.0:
-                logits = apply_top_p_filtering(logits, top_p)
+        # Apply filtering strategies
+        # Note: When both top_k and top_p are provided, they are applied sequentially:
+        # First top_k filters to k tokens, then top_p filters from those k tokens
+        if top_k is not None and top_k > 0:
+            logits = apply_top_k_filtering(logits, top_k)

-        logp = torch.log(probs + 1e-10)  # Add epsilon for numerical stability
+        if top_p is not None and 0 < top_p < 1.0:
+            logits = apply_top_p_filtering(logits, top_p)

+        # Convert to log probabilities
+        logp = torch.log_softmax(logits, dim=-1)
+
+        # Greedy or stochastic sampling
         if greedy:
             pred_next = logp.argmax(-1)
         else:
-            if generator is not None:
-                # Use multinomial with generator for reproducible sampling
-                pred_next = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=generator).squeeze(-1).view(probs.shape[:-1])
-            else:
-                # Sample from categorical using probabilities
-                pred_next = torch.distributions.Categorical(probs=probs).sample()
+            pred_next = torch.distributions.Categorical(logits=logp).sample(generator=generator)

         conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
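For reference, a minimal, self-contained sketch of the filtering-and-sampling path introduced above, exercised on toy logits. The tensor values and the seeded-sampling pattern are illustrative assumptions; in particular, torch.distributions.Categorical.sample() does not accept a generator argument, so this sketch uses torch.multinomial when a seeded torch.Generator is wanted.

import torch

def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Same helper as above: keep the k largest logits, set the rest to -inf.
    top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
    filtered_logits = torch.full_like(logits, float('-inf'))
    filtered_logits.scatter_(-1, top_k_indices, top_k_values)
    return filtered_logits

# Toy logits: batch of 2 positions over a 6-token vocabulary (illustrative values).
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -2.0],
                       [0.3, 0.2, 0.1, 0.0, -0.1, -0.2]])

temperature, top_k = 0.7, 3
safe_temperature = max(temperature, 1e-8)      # mirrors the near-zero guard above
scaled = logits / safe_temperature
filtered = apply_top_k_filtering(scaled, top_k)

logp = torch.log_softmax(filtered, dim=-1)
probs = logp.exp()

# Seeded sampling via torch.multinomial, which does accept a generator.
gen = torch.Generator().manual_seed(0)
pred_next = torch.multinomial(probs, num_samples=1, generator=gen).squeeze(-1)
conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
print(pred_next.tolist(), conf_next.tolist())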