# dream_rcr/generation_utils.py
# coding=utf-8
import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging
logger = logging.get_logger(__name__)
def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
if temperature and temperature > 0:
logits = logits / temperature
if top_p is not None and top_p < 1:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
if top_k is not None:
top_k = int(min(top_k, logits.size(-1)))
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
return logits
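# Illustrative behavior of the filter above (example values only, not part of the decoding path):
# logits [[2.0, 1.0, 0.5, 0.1]] with temperature=0.5 are rescaled to [[4.0, 2.0, 1.0, 0.2]];
# top_k=2 then keeps only the two largest entries and sets the rest to the dtype minimum,
# so a subsequent softmax/sample can only pick one of those two tokens.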
def _confidence_from_probs(
probs: torch.Tensor, # [..., V]
chosen_ids: Optional[torch.Tensor], # [...]
mode: str # 'entropy' | 'maskgit_plus' | 'topk_margin'
) -> torch.Tensor:
"""返回“越大越自信”的标量分数,与解码一致。"""
if mode == "entropy":
eps = 1e-10
logp = torch.log(probs + eps)
return -(probs * logp).sum(dim=-1) # -H(p)
elif mode == "maskgit_plus":
        assert chosen_ids is not None, "maskgit_plus requires chosen_ids"
return torch.gather(probs, -1, chosen_ids.unsqueeze(-1)).squeeze(-1) # p(x0)
elif mode == "topk_margin":
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
return sorted_probs[..., 0] - sorted_probs[..., 1] # top1 - top2
else:
raise ValueError(f"Unknown conf mode: {mode}")
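# Worked example (illustrative): for probs = [0.7, 0.2, 0.1] at a single position,
#   'maskgit_plus' with chosen id 0 -> 0.7              (probability of the decoded token)
#   'topk_margin'                   -> 0.7 - 0.2 = 0.5  (gap between top-1 and top-2)
#   'entropy'                       -> -H(p) ~= -0.802  (negative entropy, in nats)
# All three scores are "larger = more confident", so the same top-k selection can rank them.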
@dataclass
class DreamModelOutput(ModelOutput):
sequences: torch.LongTensor = None
history: Optional[Tuple[torch.FloatTensor]] = None
class DreamGenerationConfig(GenerationConfig):
def __init__(self, **kwargs):
# sampling
self.temperature: float = kwargs.pop("temperature", 0.0)
self.top_p: Optional[float] = kwargs.pop("top_p", None)
self.top_k: Optional[int] = kwargs.pop("top_k", None)
# length
self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
# diffusion
self.eps: float = kwargs.pop("eps", 1e-3)
self.steps: int = kwargs.pop("steps", 512)
        # scoring algorithm for vanilla decoding (used when rcr=False)
        self.alg: str = kwargs.pop("alg", 'maskgit_plus')  # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
        # === RCR ===
        self.rcr: bool = kwargs.pop("rcr", False)
        # confidence definition shared by decoding and the history scores when rcr=True
        self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')  # 'maskgit_plus' | 'topk_margin' | 'entropy'
        # note: the two fields below are always overridden inside _sample, which hard-codes the window to steps//4 .. 3*steps//4
        self.rcr_start_step: int = kwargs.pop("rcr_start_step", 0)
        self.rcr_end_step: int = kwargs.pop("rcr_end_step", None) or self.steps
        # whether to protect tokens written in the current step from being re-masked
        self.rcr_protect_current_step: bool = kwargs.pop("rcr_protect_current_step", False)
# outputs
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
self.output_history: bool = kwargs.pop("output_history", False)
# special tokens
self.mask_token_id = kwargs.pop("mask_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
# misc
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# bookkeeping
self._from_model_config = kwargs.pop("_from_model_config", False)
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)
if not self._from_model_config:
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
self.validate(is_init=True)
def validate(self, is_init=False):
        # simple bounds checks
self.rcr_start_step = max(0, int(self.rcr_start_step))
self.rcr_end_step = max(self.rcr_start_step, int(self.rcr_end_step))
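# Illustrative configuration for RCR decoding (a sketch only; `model` is a placeholder for a
# Dream-style model that mixes in DreamGenerationMixin below):
#   cfg = DreamGenerationConfig(steps=256, max_new_tokens=128, temperature=0.2, top_p=0.95,
#                               rcr=True, conf_alg="maskgit_plus", rcr_protect_current_step=True)
#   out = model.diffusion_generate(input_ids, generation_config=cfg,
#                                  return_dict_in_generate=True, output_history=True)
# Note that rcr_start_step / rcr_end_step are currently ignored: _sample hard-codes the window.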
class DreamGenerationMixin:
@staticmethod
def _expand_inputs_for_generation(
expand_size: int = 1,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None
):
if expand_size == 1:
return input_ids, attention_mask
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
if attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
return input_ids, attention_mask
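    # Example: with expand_size=2 (e.g. num_return_sequences=2), input rows [a, b] become
    # [a, a, b, b]; repeat_interleave keeps the copies of each prompt adjacent in the batch.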
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
if is_torchdynamo_compiling():
return
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
warnings.warn(
f"Using default `max_length` (={generation_config.max_length}). Prefer `max_new_tokens`.",
UserWarning,
)
if input_ids_length >= generation_config.max_length:
raise ValueError(
f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
"Increase `max_length` or set `max_new_tokens`."
)
def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning("Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence.")
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
elif has_default_max_length:
if generation_config.max_length == DreamGenerationConfig().max_length:
generation_config.max_length = generation_config.max_length + input_ids_length
mpe = getattr(self.config, "max_position_embeddings", None)
if mpe is not None:
generation_config.max_length = min(generation_config.max_length, mpe)
return generation_config
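    # Worked example: a prompt of length 10 with max_new_tokens=128 yields max_length=138;
    # if the model's max_position_embeddings were 128, max_length would be clamped to 128.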
def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
using_model_generation_config = False
if generation_config is None:
generation_config = DreamGenerationConfig.from_model_config(self.config)
using_model_generation_config = True
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
_ = generation_config.update(**kwargs)
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.mask_token_id is None:
generation_config.mask_token_id = self.generation_config.mask_token_id
return generation_config
def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device=None):
def _tensor_or_none(token, device=None):
if token is None:
return token
device = device if device is not None else self.device
if isinstance(token, torch.Tensor):
return token.to(device)
return torch.tensor(token, device=device, dtype=torch.long)
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
eos_token_tensor = eos_token_tensor.unsqueeze(0)
if pad_token_tensor is None and eos_token_tensor is not None:
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
generation_config._bos_token_tensor = bos_token_tensor
generation_config._eos_token_tensor = eos_token_tensor
generation_config._pad_token_tensor = pad_token_tensor
generation_config._mask_token_tensor = mask_token_tensor
@torch.no_grad()
def diffusion_generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[DreamGenerationConfig] = None,
**kwargs,
):
generation_config = self._prepare_generation_config(generation_config, **kwargs)
generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
assert inputs is not None
input_ids = inputs
device = input_ids.device
attention_mask = kwargs.pop("attention_mask", None)
self._prepare_special_tokens(generation_config, device=device)
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
input_ids_length=input_ids_length,
)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with `input_ids` on a device different from the model.",
UserWarning,
)
        if (
            hasattr(generation_config, "pad_token_id")
            and generation_config.pad_token_id is not None
            and torch.any(input_ids == generation_config.pad_token_id)
            and attention_mask is None
        ):
warnings.warn(
"Padding detected but no attention mask was passed. Set `attention_mask` for correct generation.",
UserWarning,
)
input_ids, attention_mask = self._expand_inputs_for_generation(
expand_size=generation_config.num_return_sequences,
input_ids=input_ids,
attention_mask=attention_mask,
)
return self._sample(
input_ids,
attention_mask=attention_mask,
generation_config=generation_config,
generation_tokens_hook_func=generation_tokens_hook_func,
generation_logits_hook_func=generation_logits_hook_func,
)
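    # Sketch of a custom tokens hook (names are placeholders, not part of the API): record the
    # intermediate sequence at every step without modifying it. The hook signature is
    # (step, x, logits) -> x, matching the identity default above:
    #   snapshots = []
    #   def record_hook(step, x, logits):
    #       snapshots.append(x.clone())
    #       return x
    #   model.diffusion_generate(input_ids, generation_config=cfg,
    #                            generation_tokens_hook_func=record_hook)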
def _sample(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor],
generation_config: DreamGenerationConfig,
generation_tokens_hook_func,
generation_logits_hook_func
):
output_history = generation_config.output_history
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
mask_token_id = generation_config.mask_token_id
steps = generation_config.steps
eps = generation_config.eps
alg = generation_config.alg
alg_temp = generation_config.alg_temp
temperature = generation_config.temperature
top_p = generation_config.top_p
top_k = generation_config.top_k
rcr = generation_config.rcr
conf_alg = generation_config.conf_alg if rcr else generation_config.alg
        # === Hard-coded RCR window: from 1/4 to 3/4 of the total steps (half-open [start, end)) ===
rcr_start = max(0, steps // 4)
rcr_end = max(rcr_start, min(steps, (3 * steps) // 4))
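        # Example: with steps=512 the window is [128, 384); re-masking only happens for step
        # indices i in that range, before and after it the loop only writes tokens.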
protect_cur = bool(generation_config.rcr_protect_current_step)
histories = [] if (return_dict_in_generate and output_history) else None
# pad input_ids to max_length
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
if attention_mask is not None and torch.any(attention_mask == 0.0):
attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
tok_idx = attention_mask.long().cumsum(-1) - 1
tok_idx.masked_fill_(attention_mask == 0, 1)
attention_mask = torch.logical_and(
attention_mask.unsqueeze(1).unsqueeze(-2),
attention_mask.unsqueeze(1).unsqueeze(-1),
)
else:
tok_idx = None
attention_mask = "full"
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
        # ==== RCR state ====
        if rcr:
            init_mask_bool = (x == mask_token_id)  # initial generation region (positions to be filled)
            init_mask_count = init_mask_bool.sum(dim=1)  # [B]
            hist_conf = torch.zeros_like(x, dtype=torch.float32, device=x.device)  # running max confidence per position
            gen_mask = torch.zeros_like(x, dtype=torch.bool, device=x.device)  # positions already committed
written_step = torch.full_like(x, -1, dtype=torch.int32, device=x.device)
x = generation_tokens_hook_func(None, x, None)
for i in range(steps):
mask_index = (x == mask_token_id)
            # forward pass + Dream's right-shift logit alignment
logits = self(x, attention_mask, tok_idx).logits
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
logits = generation_logits_hook_func(i, x, logits)
            # current and next timesteps
t = timesteps[i]
s = timesteps[i + 1]
            # -- pull out logits at masked positions only and apply temperature/top-p/top-k filtering --
mask_logits = logits[mask_index]
if mask_logits.numel() == 0:
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
continue
mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
probs = torch.softmax(mask_logits, dim=-1)
            # sample (or greedily pick) x0
if temperature and temperature > 0:
try:
x0 = dists.Categorical(probs=probs).sample()
except Exception:
x0 = probs.argmax(dim=-1)
else:
x0 = probs.argmax(dim=-1)
            # unified confidence score (same definition as used for decoding)
conf_now = _confidence_from_probs(
probs=probs,
chosen_ids=x0 if conf_alg == "maskgit_plus" else None,
mode=conf_alg
).to(torch.float32) # [M]
            # ====== compute this step's write quota k_t (same schedule as vanilla) ======
Mt = mask_index.sum().item()
ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
k_t = int(Mt * ratio)
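            # Example: with Mt=100 currently masked positions, t=1.0 and s=0.9, ratio=0.1 and
            # k_t=10, i.e. roughly 10% of the masked tokens are committed this step; the final
            # step uses ratio=1.0 so everything still masked gets written. Note that Mt counts
            # masked positions over the entire batch.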
            # -- commit the top-k_t most confident predictions (always write first, whether or not we are in the RCR window) --
full_conf_now = torch.full((x.size(0), x.size(1)), -1e9, dtype=torch.float32, device=x.device)
full_x0 = torch.full_like(x, mask_token_id, dtype=torch.long)
full_conf_now[mask_index] = conf_now
full_x0[mask_index] = x0
for b in range(x.size(0)):
masked_b = int(mask_index[b].sum().item())
if masked_b == 0 or k_t <= 0:
continue
k_b = min(k_t, masked_b)
_, sel_idx = torch.topk(full_conf_now[b], k=k_b, largest=True)
x[b, sel_idx] = full_x0[b, sel_idx]
if rcr:
gen_mask[b, sel_idx] = True
written_step[b, sel_idx] = i
                    # update the running max confidence (same definition as decoding)
hist_conf[b, sel_idx] = torch.maximum(hist_conf[b, sel_idx], full_conf_now[b, sel_idx])
            # -- outside the RCR window: no re-masking, only history tracking; inside the window: re-mask down to the target cumulative count --
if rcr and (rcr_start <= i < rcr_end):
for b in range(x.size(0)):
M0 = int(init_mask_count[b].item())
target_cum = M0 if i >= steps - 1 else int(M0 * (1.0 - (s.item() / t.item())))
                    # current cumulative commitments: number of confirmed positions inside the initial generation region
C_t = int((gen_mask[b] & init_mask_bool[b]).sum().item())
over = max(0, C_t - target_cum)
if over <= 0:
continue
                    # candidates: positions in the initial region that are currently committed (optionally excluding ones written this step)
cand = torch.where(gen_mask[b] & init_mask_bool[b])[0]
if cand.numel() == 0:
continue
if protect_cur:
mask_old = (written_step[b, cand] < i)
cand = cand[mask_old]
if cand.numel() == 0:
                            # every candidate was written this step and protection is on, so skip re-masking
continue
over = min(over, int(cand.numel()))
                    scores = hist_conf[b, cand]  # larger = more confident
_, low_local = torch.topk(scores, k=over, largest=False)
low_global = cand[low_local]
                    # re-mask the least confident committed positions
x[b, low_global] = mask_token_id
gen_mask[b, low_global] = False
                    # history scores and written_step are kept (not reset)
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
if return_dict_in_generate:
return DreamModelOutput(sequences=x, history=histories)
else:
return x
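# End-to-end sketch (illustrative; `model`, `tokenizer`, and the prompt are placeholders for a
# Dream-style checkpoint whose model class mixes in DreamGenerationMixin):
#   inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#   cfg = DreamGenerationConfig(steps=512, max_new_tokens=256, rcr=True, conf_alg="entropy")
#   out = model.diffusion_generate(inputs.input_ids, attention_mask=inputs.attention_mask,
#                                  generation_config=cfg, return_dict_in_generate=True)
#   text = tokenizer.decode(out.sequences[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)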