|
|
|
|
|
import copy
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging
|
|
|
|
|
# Module-level logger obtained from transformers' logging utilities; used below
# for the error/warning messages emitted by the config and generation code.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
    """Rescale and filter logits before sampling.

    Steps, in this order:

    1. divide by ``temperature`` when it is a positive number;
    2. nucleus (top-p) filtering when ``top_p`` is set and below 1 — the
       highest-scoring entry is always kept, as is the entry whose cumulative
       probability first exceeds ``top_p``;
    3. top-k filtering when ``top_k`` is set; ``top_k`` is clipped to the size
       of the last dimension and ignored when it is not positive. Entries tied
       with the k-th largest value are kept.

    Removed entries are set to the smallest finite value of the logits' dtype.

    Args:
        logits: Tensor whose last dimension holds the per-token scores.
        temperature: Divisor applied when positive; ``0``/``None`` disables it.
        top_p: Cumulative-probability threshold, or ``None``.
        top_k: Number of highest-scoring entries to keep, or ``None``.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    scaled = logits / temperature if (temperature and temperature > 0) else logits

    if top_p is not None and top_p < 1:
        desc_vals, desc_order = torch.sort(scaled, descending=True)
        running_mass = F.softmax(desc_vals, dim=-1).cumsum(dim=-1)
        past_threshold = running_mass > top_p
        # Shift the flags right by one slot: the entry that crosses top_p is
        # kept, and slot 0 (the largest logit) can never be dropped.
        drop_sorted = torch.zeros_like(past_threshold)
        drop_sorted[..., 1:] = past_threshold[..., :-1]
        # Map the decisions from sorted order back to the original positions.
        drop = torch.zeros_like(scaled, dtype=torch.bool).scatter(-1, desc_order, drop_sorted)
        scaled = scaled.masked_fill(drop, torch.finfo(scaled.dtype).min)

    if top_k is not None:
        keep = int(min(top_k, scaled.size(-1)))
        if keep > 0:
            kth_best = torch.topk(scaled, keep).values[..., -1:]
            scaled = scaled.masked_fill(scaled < kth_best, torch.finfo(scaled.dtype).min)

    return scaled
|
|
|
|
|
|
|
|
def _confidence_from_probs(
    probs: torch.Tensor,
    chosen_ids: Optional[torch.Tensor],
    mode: str
) -> torch.Tensor:
    """Return a per-position confidence score where larger means more confident.

    Args:
        probs: Probabilities whose last dimension indexes the vocabulary.
        chosen_ids: Token id chosen at each position (``probs`` shape without
            the last dimension). Required for ``"maskgit_plus"``, ignored
            otherwise.
        mode: Scoring rule:
            ``"entropy"`` - negative entropy ``sum(p * log p)``; close to 0
            for a one-hot distribution and about ``-log V`` for a uniform one;
            ``"maskgit_plus"`` - probability assigned to ``chosen_ids``;
            ``"topk_margin"`` - gap between the two largest probabilities.

    Returns:
        Tensor with the shape of ``probs`` minus its last dimension.

    Raises:
        ValueError: for an unknown ``mode``, or when ``chosen_ids`` is missing
            in ``"maskgit_plus"`` mode.
    """
    if mode == "entropy":
        eps = 1e-10  # keeps log() finite for zero probabilities
        logp = torch.log(probs + eps)
        # Negative entropy, so that peaked (confident) distributions score
        # higher, consistent with the other modes and with callers that
        # select positions via topk(largest=True).
        return (probs * logp).sum(dim=-1)
    elif mode == "maskgit_plus":
        if chosen_ids is None:
            raise ValueError("maskgit_plus 需要 chosen_ids")
        return torch.gather(probs, -1, chosen_ids.unsqueeze(-1)).squeeze(-1)
    elif mode == "topk_margin":
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        return sorted_probs[..., 0] - sorted_probs[..., 1]
    else:
        raise ValueError(f"Unknown conf mode: {mode}")
|
|
|
|
|
|
|
|
@dataclass
class DreamModelOutput(ModelOutput):
    """Structured result of the diffusion sampler.

    Attributes:
        sequences: Token ids after the final denoising step.
        history: Token-id snapshots recorded after each denoising step, or
            ``None`` when history recording was not requested.
    """

    sequences: Optional[torch.LongTensor] = None
    # Stored as a Python list of per-step token-id tensors by the sampler.
    history: Optional[List[torch.LongTensor]] = None
|
|
|
|
|
|
|
|
class DreamGenerationConfig(GenerationConfig):
    """Generation settings for the Dream diffusion sampler, including RCR options.

    ``__init__`` does not call ``GenerationConfig.__init__``. It sets the
    attributes below from ``kwargs`` (falling back to the shown defaults),
    copies any remaining ``kwargs`` onto the instance unless
    ``_from_model_config`` is true, and finally calls :meth:`validate`.
    """

    def __init__(self, **kwargs):
        # Token-value sampling.
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)

        # Output length.
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)

        # Denoising schedule.
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)

        # Position-selection algorithm.
        self.alg: str = kwargs.pop("alg", 'maskgit_plus')
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)

        # RCR (re-masking) options.
        self.rcr: bool = kwargs.pop("rcr", False)
        self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
        self.rcr_start_step: int = kwargs.pop("rcr_start_step", 0)
        # Compare with None explicitly: an explicit 0 must not be replaced by `steps`.
        rcr_end_step = kwargs.pop("rcr_end_step", None)
        self.rcr_end_step: int = self.steps if rcr_end_step is None else rcr_end_step
        self.rcr_protect_current_step: bool = kwargs.pop("rcr_protect_current_step", False)

        # Output format.
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)

        # Special tokens.
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # Bookkeeping fields mirrored from transformers' GenerationConfig.
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)

        # Unknown keys are copied verbatim, except when built from a model config.
        if not self._from_model_config:
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise

        self.validate(is_init=True)

    def validate(self, is_init=False, **kwargs):
        """Normalize the RCR step window in place.

        ``rcr_start_step`` becomes a non-negative int (``None`` -> 0) and
        ``rcr_end_step`` an int no smaller than it (``None`` -> ``steps``).
        The base-class checks are not run.

        Args:
            is_init: Accepted for the call made from ``__init__``; unused.
            **kwargs: Extra validation flags (such as ``strict``) are accepted
                and ignored so base-class callers passing them do not fail.
        """
        start = 0 if self.rcr_start_step is None else int(self.rcr_start_step)
        self.rcr_start_step = max(0, start)
        end = self.steps if self.rcr_end_step is None else int(self.rcr_end_step)
        self.rcr_end_step = max(self.rcr_start_step, end)
|
|
|
|
|
|
|
|
class DreamGenerationMixin:
    """Mixin adding masked-diffusion generation through ``diffusion_generate``.

    The host class must provide ``config``, ``generation_config``, ``device``
    and be callable as ``self(x, attention_mask, tok_idx)`` returning an
    object with a ``.logits`` attribute. Presumably the host is a transformers
    model - TODO confirm.
    """

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None
    ):
        """Repeat every batch row ``expand_size`` times along dim 0.

        Returns ``(input_ids, attention_mask)``; a ``None`` argument is passed
        through unchanged.
        """
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
        """Warn about relying on the default length and reject prompts that leave no room.

        Skipped entirely while ``is_torchdynamo_compiling()`` is true.

        Raises:
            ValueError: if the prompt is at least ``max_length`` tokens long.
        """
        if is_torchdynamo_compiling():
            return
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            warnings.warn(
                f"Using default `max_length` (={generation_config.max_length}). Prefer `max_new_tokens`.",
                UserWarning,
            )
        if input_ids_length >= generation_config.max_length:
            raise ValueError(
                f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
                "Increase `max_length` or set `max_new_tokens`."
            )

    def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
        """Resolve ``generation_config.max_length`` to prompt length plus generated length.

        ``max_new_tokens`` takes precedence. Otherwise, when ``max_length`` was
        not given explicitly and still equals the config default, it is read
        as a number of new tokens: the prompt length is added and the result
        is clamped to ``self.config.max_position_embeddings`` if present.
        Mutates and returns ``generation_config``.
        """
        if generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning("Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence.")
            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
        elif has_default_max_length:
            if generation_config.max_length == DreamGenerationConfig().max_length:
                generation_config.max_length = generation_config.max_length + input_ids_length
                mpe = getattr(self.config, "max_position_embeddings", None)
                if mpe is not None:
                    generation_config.max_length = min(generation_config.max_length, mpe)
        return generation_config

    def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
        """Return the generation config to use, with ``kwargs`` applied.

        Without an explicit config, one is built from ``self.config``. When
        ``is_torchdynamo_compiling()`` is false, the config is deep-copied,
        updated with ``kwargs``, and - for a caller-supplied config - any unset
        special-token id is filled from ``self.generation_config``.
        """
        using_model_generation_config = False
        if generation_config is None:
            generation_config = DreamGenerationConfig.from_model_config(self.config)
            using_model_generation_config = True

        if not is_torchdynamo_compiling():
            generation_config = copy.deepcopy(generation_config)
            _ = generation_config.update(**kwargs)
            if not using_model_generation_config:
                if generation_config.bos_token_id is None:
                    generation_config.bos_token_id = self.generation_config.bos_token_id
                if generation_config.eos_token_id is None:
                    generation_config.eos_token_id = self.generation_config.eos_token_id
                if generation_config.pad_token_id is None:
                    generation_config.pad_token_id = self.generation_config.pad_token_id
                if generation_config.mask_token_id is None:
                    generation_config.mask_token_id = self.generation_config.mask_token_id

        return generation_config

    def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device=None):
        """Attach tensor versions of the special-token ids to ``generation_config``.

        Sets ``_bos_token_tensor``, ``_eos_token_tensor`` (always at least 1-D),
        ``_pad_token_tensor`` (falls back to the first EOS id, with a warning)
        and ``_mask_token_tensor``. Missing ids stay ``None``.
        """
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)

        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)
        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor
        generation_config._mask_token_tensor = mask_token_tensor

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[DreamGenerationConfig] = None,
        **kwargs,
    ):
        """Generate by iteratively unmasking a mask-padded continuation of ``inputs``.

        Args:
            inputs: Prompt token ids; required.
            generation_config: Optional ``DreamGenerationConfig``; fields can
                also be overridden through ``kwargs``.
            **kwargs: Config overrides plus the optional ``attention_mask``,
                ``generation_tokens_hook_func(step, x, logits) -> x`` and
                ``generation_logits_hook_func(step, x, logits) -> logits``.

        Returns:
            The token tensor, or a ``DreamModelOutput`` when
            ``return_dict_in_generate`` is set.

        Raises:
            ValueError: if ``inputs`` is None or the prompt does not fit.
        """
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
        generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

        if inputs is None:
            raise ValueError("`inputs` (prompt token ids) must be provided to `diffusion_generate`.")
        input_ids = inputs
        device = input_ids.device
        attention_mask = kwargs.pop("attention_mask", None)
        self._prepare_special_tokens(generation_config, device=device)

        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            input_ids_length=input_ids_length,
        )

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with `input_ids` on a device different from the model.",
                UserWarning,
            )
        # Guard on None: `tensor == None` yields a Python bool, which torch.any rejects.
        pad_token_id = getattr(generation_config, "pad_token_id", None)
        if (
            pad_token_id is not None
            and attention_mask is None
            and torch.any(input_ids == pad_token_id)
        ):
            warnings.warn(
                "Padding detected but no attention mask was passed. Set `attention_mask` for correct generation.",
                UserWarning,
            )

        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        return self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            generation_tokens_hook_func=generation_tokens_hook_func,
            generation_logits_hook_func=generation_logits_hook_func,
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: DreamGenerationConfig,
        generation_tokens_hook_func,
        generation_logits_hook_func
    ):
        """Run ``generation_config.steps`` denoising steps over a mask-padded sequence.

        Each step predicts every masked position, scores it with
        ``_confidence_from_probs`` and writes the most confident predictions.
        With ``rcr`` enabled, steps inside a fixed window may re-mask the
        generated tokens with the lowest recorded confidence.

        Returns:
            ``DreamModelOutput(sequences=x, history=histories)`` when
            ``return_dict_in_generate`` is set, otherwise the final tensor.

        Raises:
            ValueError: if ``mask_token_id`` is not configured.
        """
        output_history = generation_config.output_history
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        steps = generation_config.steps
        eps = generation_config.eps
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k
        # NOTE(review): `generation_config.alg_temp` is not applied here;
        # positions are always chosen greedily by confidence.

        if mask_token_id is None:
            raise ValueError("`mask_token_id` must be set for diffusion generation.")

        rcr = generation_config.rcr
        conf_alg = generation_config.conf_alg if rcr else generation_config.alg

        # Re-masking runs only for steps in [rcr_start, rcr_end): a fixed
        # 25%..75% window of the schedule.
        # NOTE(review): `generation_config.rcr_start_step` / `rcr_end_step`
        # are not consulted here - confirm whether they should be.
        rcr_start = max(0, steps // 4)
        rcr_end = max(rcr_start, min(steps, (3 * steps) // 4))
        protect_cur = bool(generation_config.rcr_protect_current_step)

        histories = [] if (return_dict_in_generate and output_history) else None

        # Append mask tokens to the prompt up to max_length.
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)

        if attention_mask is not None and torch.any(attention_mask == 0.0):
            attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
            tok_idx = attention_mask.long().cumsum(-1) - 1
            tok_idx.masked_fill_(attention_mask == 0, 1)
            attention_mask = torch.logical_and(
                attention_mask.unsqueeze(1).unsqueeze(-2),
                attention_mask.unsqueeze(1).unsqueeze(-1),
            )
        else:
            tok_idx = None
            attention_mask = "full"

        timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

        if rcr:
            init_mask_bool = (x == mask_token_id)
            init_mask_count = init_mask_bool.sum(dim=1)
            # Start at -inf so the running maximum below is correct for
            # confidence scores of either sign (e.g. negative entropy).
            hist_conf = torch.full_like(x, float("-inf"), dtype=torch.float32)
            gen_mask = torch.zeros_like(x, dtype=torch.bool)
            written_step = torch.full_like(x, -1, dtype=torch.int32)

        x = generation_tokens_hook_func(None, x, None)

        for i in range(steps):
            mask_index = (x == mask_token_id)

            logits = self(x, attention_mask, tok_idx).logits
            # Shift logits right by one position (position 0 keeps its own),
            # so position j uses the logits produced at position j - 1.
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            logits = generation_logits_hook_func(i, x, logits)

            t = timesteps[i].item()
            s = timesteps[i + 1].item()
            is_last_step = i >= steps - 1
            # Fraction of the currently masked positions to fill at this step.
            step_ratio = 1.0 if is_last_step else 1.0 - s / t

            mask_logits = logits[mask_index]
            if mask_logits.numel() == 0:
                x = generation_tokens_hook_func(i, x, logits)
                if histories is not None:
                    histories.append(x.clone())
                continue

            mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
            probs = torch.softmax(mask_logits, dim=-1)

            if temperature and temperature > 0:
                try:
                    x0 = dists.Categorical(probs=probs).sample()
                except (ValueError, RuntimeError):
                    # Invalid distribution (e.g. NaN probabilities): fall back to greedy.
                    x0 = probs.argmax(dim=-1)
            else:
                x0 = probs.argmax(dim=-1)

            conf_now = _confidence_from_probs(
                probs=probs,
                chosen_ids=x0 if conf_alg == "maskgit_plus" else None,
                mode=conf_alg
            ).to(torch.float32)

            # Scatter per-masked-position results back to full (batch, length)
            # tensors; unmasked positions get -1e9 so topk never picks them.
            full_conf_now = torch.full((x.size(0), x.size(1)), -1e9, dtype=torch.float32, device=x.device)
            full_x0 = torch.full_like(x, mask_token_id, dtype=torch.long)
            full_conf_now[mask_index] = conf_now
            full_x0[mask_index] = x0

            for b in range(x.size(0)):
                masked_b = int(mask_index[b].sum().item())
                if masked_b == 0:
                    continue
                # Budget from this sequence's own mask count, so other rows in
                # the batch do not inflate it.
                k_b = min(masked_b, int(masked_b * step_ratio))
                if k_b <= 0:
                    continue
                _, sel_idx = torch.topk(full_conf_now[b], k=k_b, largest=True)
                x[b, sel_idx] = full_x0[b, sel_idx]
                if rcr:
                    gen_mask[b, sel_idx] = True
                    written_step[b, sel_idx] = i
                    hist_conf[b, sel_idx] = torch.maximum(hist_conf[b, sel_idx], full_conf_now[b, sel_idx])

            if rcr and (rcr_start <= i < rcr_end):
                for b in range(x.size(0)):
                    M0 = int(init_mask_count[b].item())
                    # NOTE(review): this reuses the per-step ratio (1 - s/t) as a
                    # cumulative target over the initial mask count; the
                    # schedule's cumulative fraction after this step is about
                    # (1 - s). Confirm which is intended.
                    target_cum = M0 if is_last_step else int(M0 * step_ratio)

                    generated = gen_mask[b] & init_mask_bool[b]
                    over = int(generated.sum().item()) - target_cum
                    if over <= 0:
                        continue

                    cand = torch.where(generated)[0]
                    if protect_cur:
                        # Never re-mask tokens written during this same step.
                        cand = cand[written_step[b, cand] < i]
                    if cand.numel() == 0:
                        continue

                    over = min(over, int(cand.numel()))
                    _, low_local = torch.topk(hist_conf[b, cand], k=over, largest=False)
                    low_global = cand[low_local]
                    x[b, low_global] = mask_token_id
                    gen_mask[b, low_global] = False

            x = generation_tokens_hook_func(i, x, logits)
            if histories is not None:
                histories.append(x.clone())

        if return_dict_in_generate:
            return DreamModelOutput(sequences=x, history=histories)
        return x
|
|
|