| | from typing import Optional, Tuple, Union
|
| |
|
| | import torch
|
| | from einops import rearrange
|
| | import torch.nn.functional as F
|
| |
|
| | import triton
|
| | import triton.language as tl
|
| |
|
| |
|
@triton.jit
def rotary_kernel(
    OUT,  # pointer to the output tensor
    X,  # pointer to the input tensor
    COS,  # pointer to the cos table
    SIN,  # pointer to the sin table (same layout as COS)
    CU_SEQLENS,  # cumulative sequence lengths; only read when IS_VARLEN
    SEQLEN_OFFSETS,  # int, or pointer to per-batch offsets when IS_SEQLEN_OFFSETS_TENSOR
    seqlen,
    nheads,  # not read in the kernel body
    rotary_dim,
    seqlen_ro,
    CACHE_KEY_SEQLEN,  # not read in the kernel body
    # strides
    stride_out_batch,
    stride_out_nheads,
    stride_out_seqlen,
    stride_out_headdim,
    stride_x_batch,
    stride_x_nheads,
    stride_x_seqlen,
    stride_x_headdim,
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Rotates the first `rotary_dim` channels of one tile: BLOCK_M sequence
    # positions (program axis 0) of one batch element (axis 1) and one head
    # (axis 2). Channels >= rotary_dim are never touched here.
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    if not IS_VARLEN:
        X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
        OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
        # COS/SIN are addressed as a contiguous (batch, seqlen_ro, rotary_dim_half) table.
        COS = COS + pid_batch * seqlen_ro * rotary_dim_half
        SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
    else:
        # Packed layout: this batch element's rows start at CU_SEQLENS[pid_batch].
        # NOTE(review): COS/SIN get no per-batch offset here, i.e. a single shared
        # table is assumed in varlen mode -- confirm before enabling IS_VARLEN.
        start_idx = tl.load(CU_SEQLENS + pid_batch)
        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
        X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
        OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads

    # Programs whose row block lies entirely past the sequence end do nothing.
    if pid_m * BLOCK_M >= seqlen:
        return
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Row index into the cos/sin table, shifted by the sequence offset.
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
    rk = tl.arange(0, BLOCK_K)
    rk_half = tl.arange(0, BLOCK_K // 2)

    if not INTERLEAVED:
        # Half-split layout: channel j pairs with channel j + rotary_dim_half.
        mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half)
        mask_x = (rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)
        X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        # Compute in fp32, like the interleaved branch below; tl.store converts
        # the result to OUT's element type on write.
        cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32)
        sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32)
        x0 = tl.load(X, mask=mask_x, other=0.0).to(tl.float32)
        x1 = tl.load(X + rotary_dim_half * stride_x_headdim, mask=mask_x, other=0.0).to(
            tl.float32
        )
        if CONJUGATE:
            # Rotation by the negated angle (the inverse rotation).
            sin = -sin
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos

        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
        tl.store(OUT, o0, mask=mask_x)
        tl.store(OUT + rotary_dim_half * stride_out_headdim, o1, mask=mask_x)
    else:
        # Interleaved layout: channels (2i, 2i+1) form a pair sharing cos/sin entry i.
        # rk_swap maps each channel to its partner (even -> +1, odd -> -1).
        rk_swap = rk + ((rk + 1) % 2) * 2 - 1
        rk_repeat = tl.arange(0, BLOCK_K) // 2
        mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half)
        X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
        X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32)
        sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32)
        x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
            tl.float32
        )
        x1 = tl.load(
            X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin
        x0_cos = x0 * cos
        x1_sin = x1 * sin
        # Even channels: x_even*cos - x_odd*sin; odd channels: x_odd*cos + x_even*sin.
        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
        tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
|
| |
|
| |
|
def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """Apply rotary embedding to the first ``rotary_dim`` channels of ``x``.

    Launches ``rotary_kernel`` on a grid of
    (ceil(seqlen / BLOCK_M), batch, nheads) programs.

    Arguments:
        x: (batch, nheads, seqlen, headdim). The head axis comes BEFORE the
            sequence axis; this is the order the shape is unpacked in below.
        cos: (batch, seqlen_ro, rotary_dim / 2), same dtype as ``x``.
        sin: same shape and dtype as ``cos``.
        seqlen_offsets: integer, or integer tensor of shape (batch,), added to
            each position before looking up ``cos``/``sin``.
        cu_seqlens: accepted for interface compatibility but ignored. The
            kernel is always launched with IS_VARLEN=False, so ``x`` must be
            the 4-D layout above.
        max_seqlen: unused.
        interleaved: rotate even/odd channel pairs instead of the first and
            second halves of the rotary channels.
        inplace: write the result into ``x`` instead of a new tensor.
        conjugate: rotate by the negated angle (used for the backward pass).
    Returns:
        Tensor with the same shape as ``x`` (``x`` itself when ``inplace``).
        Channels ``rotary_dim:`` are copied through unchanged.
    Raises:
        AssertionError: on shape, dtype or offset-range mismatches.
    """
    batch, nheads, seqlen, headdim = x.shape
    batch_ro, seqlen_ro, rotary_dim = cos.shape

    assert batch == batch_ro
    assert sin.shape == cos.shape
    # cos/sin store one entry per rotated channel pair.
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    # The kernel indexes cos/sin assuming a contiguous layout.
    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    # The kernel writes only the rotary channels; copy the rest through.
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    # Smallest power-of-two tile (>= 32) covering all rotary channels.
    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

    with torch.cuda.device(x.device.index):
        rotary_kernel[grid](
            output,
            x,
            cos,
            sin,
            cu_seqlens,
            seqlen_offsets,
            seqlen,
            nheads,
            rotary_dim,
            seqlen_ro,
            seqlen // 128,
            output.stride(0),
            output.stride(-3),
            output.stride(-2),
            output.stride(-1),
            x.stride(0),
            x.stride(-3),
            x.stride(-2),
            x.stride(-1),
            BLOCK_K,
            isinstance(seqlen_offsets, torch.Tensor),
            False,  # IS_VARLEN: variable-length (cu_seqlens) mode is not used
            interleaved,
            conjugate,
            BLOCK_M,
        )
    return output
|
| |
|
| |
|
class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd wrapper around ``apply_rotary``.

    Backward applies ``apply_rotary`` with ``conjugate=True`` (negated sine)
    to the incoming gradient. Only ``x`` receives a gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        rotated = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        # Integer offsets are stashed on ctx. Tensor offsets go through
        # save_for_backward, and ctx.seqlen_offsets = None flags that case.
        offsets_are_int = isinstance(seqlen_offsets, int)
        if offsets_are_int:
            ctx.save_for_backward(cos, sin, cu_seqlens)
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
        ctx.seqlen_offsets = seqlen_offsets if offsets_are_int else None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        # NOTE(review): in inplace mode x is modified but not passed to
        # ctx.mark_dirty -- confirm this is intended.
        return x if inplace else rotated

    @staticmethod
    def backward(ctx, do):
        saved = ctx.saved_tensors
        if ctx.seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = saved
        else:
            cos, sin, cu_seqlens = saved
            seqlen_offsets = ctx.seqlen_offsets
        # The non-interleaved, out-of-place path works on a private copy of the
        # incoming gradient. NOTE(review): the reason is not recorded here --
        # verify before removing.
        if not (ctx.interleaved or ctx.inplace):
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        # One entry per forward input: only x is differentiable.
        return (dx,) + (None,) * 7
|
| |
|
| |
|
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """Differentiable rotary embedding applied to the first rotary_dim channels of x.

    Arguments:
        x: (batch_size, nheads, seqlen, headdim). The head axis precedes the
            sequence axis, matching how ``apply_rotary`` unpacks the shape.
        cos, sin: (batch_size, seqlen_rotary, rotary_dim / 2), same dtype as x.
            rotary_dim must be <= headdim.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style)
            instead of the 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply the rotary embedding in place and return x.
        seqlen_offsets: (batch_size,) tensor or int. Each sequence in x is shifted
            by this amount when indexing cos/sin. Most commonly used in inference
            with a KV cache.
        cu_seqlens: accepted for interface compatibility; ``apply_rotary``
            ignores it (no variable-length support).
        max_seqlen: unused by ``apply_rotary``.
    Return:
        out: same shape as x. Channels beyond rotary_dim are passed through
        unchanged.
    """
    # autograd.Function.apply takes positional arguments only.
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
|
| |
|
| |
|
| |
|
# Alias of apply_rotary_emb; FastRotaryEmbedding.forward calls it under this name.
apply_rotary_emb_func = apply_rotary_emb
|
| |
|
| |
|
class FastRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        self.pos_idx_in_fp32 = pos_idx_in_fp32

        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq)
        self.interleaved = interleaved
        self.scale_base = scale_base
        # Per-channel-pair XPos scale; None disables XPos.
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        # cos/sin tables of shape (_seq_len_cached, dim // 2), built lazily.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        # Key-side tables, only built when XPos is enabled.
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.cos = None
        self.sin = None

    def _compute_inv_freq(self, device=None):
        """Return inverse frequencies 1 / base^(2i / dim), i in [0, dim // 2), in fp32.

        The dtype is explicit so the fp32 recomputation in
        _update_cos_sin_cache does not depend on torch's default dtype.
        """
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        return 1.0 / (self.base ** exponents)

    def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
        """(Re)build the cos/sin tables covering positions [0, seqlen) when needed.

        The tables are rebuilt when seqlen exceeds the cached length, when
        nothing is cached yet, or when the requested device/dtype differs from
        the cached tables (e.g. the same module is called with inputs of a new
        dtype). A rebuild never shrinks the table below the previously cached
        length. ``position_id`` is not used.
        """
        cached = self._cos_cached
        stale = (
            seqlen > self._seq_len_cached
            or cached is None
            or (device is not None and cached.device != torch.device(device))
            or (dtype is not None and cached.dtype != dtype)
        )
        if not stale:
            return

        # Keep the old length on a device/dtype-triggered rebuild. The XPos
        # scaling depends on the table length, so its values stay the same too.
        seqlen = max(seqlen, self._seq_len_cached)
        self._seq_len_cached = seqlen

        if self.pos_idx_in_fp32:
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            # Position indices are fp32; use fp32 frequencies as well, even if
            # the inv_freq buffer has been cast to a lower precision.
            if self.inv_freq.dtype != torch.float32:
                inv_freq = self._compute_inv_freq(device=device)
            else:
                inv_freq = self.inv_freq
        else:
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq
        # Outer product: freqs[i, j] = t[i] * inv_freq[j].
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
        else:
            power = (
                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                - seqlen // 2
            ) / self.scale_base
            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
            # Queries are multiplied by the XPos scale, keys divided by it.
            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
        max_seqlen,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding *in place* to q and k.

        q: (batch, nheads, seqlen, headdim)
        k: (batch, nheads, seqlen, headdim)
        position_ids: (batch, seqlen) integer positions; every entry must be
            smaller than the cached table length (at least max_seqlen).
        max_seqlen: int, minimum number of positions the cos/sin tables cover.

        When XPos is enabled (scale_base is not None), keys use the key-side
        tables (cos/sin divided by the scale) instead of the query tables.
        """
        self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
        # Gather per-token rows: (batch, seqlen, dim // 2).
        cos = F.embedding(position_ids, self._cos_cached)
        sin = F.embedding(position_ids, self._sin_cached)
        if self.scale is None:
            cos_k, sin_k = cos, sin
        else:
            cos_k = F.embedding(position_ids, self._cos_k_cached)
            sin_k = F.embedding(position_ids, self._sin_k_cached)

        q = apply_rotary_emb_func(
            q,
            cos,
            sin,
            interleaved=self.interleaved,
            inplace=True,
        )
        k = apply_rotary_emb_func(
            k,
            cos_k,
            sin_k,
            interleaved=self.interleaved,
            inplace=True,
        )
        return q, k
|
| |
|