#!/usr/bin/env python3
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0

"""Flexible cross-attention module with xFormers support and automatic fallback."""

import torch
import torch.nn as nn


class FlexibleCrossAttention(nn.MultiheadAttention):
    """Cross-attention with optional xFormers support and automatic fallback to SDPA.

    This module inherits from nn.MultiheadAttention to ensure weight compatibility.
    It overrides forward() to use xFormers when available and requested.

    Backends, in order of preference:
    1. xFormers memory-efficient attention (CUDA only, if installed and use_xformers=True)
    2. PyTorch native SDPA (scaled dot-product attention, used internally by
       nn.MultiheadAttention on PyTorch 2.0+, default)
    3. Standard nn.MultiheadAttention math path (fallback on older PyTorch versions)

    Args:
        embed_dim: Total dimension of the model
        num_heads: Number of parallel attention heads
        use_xformers: Whether to attempt using xFormers (only works on CUDA)
    """

    def __init__(self, embed_dim, num_heads, use_xformers=False, **kwargs):
        # Initialize parent with batch_first=True to match original usage
        super().__init__(embed_dim, num_heads, batch_first=True, **kwargs)

        # nn.MultiheadAttention already exposes self.embed_dim, self.num_heads and
        # self.head_dim, so no additional bookkeeping is required here.

        # Check if xFormers is available and requested
        self.use_xformers = use_xformers and self._check_xformers()

    def _check_xformers(self):
        """Check if xFormers is available for import.

        Returns:
            bool: True if xFormers can be imported, False otherwise
        """
        try:
            import importlib.util

            return importlib.util.find_spec("xformers.ops") is not None
        except (ImportError, ValueError):
            return False

    def forward(self, query, key, value, **kwargs):
        """Forward pass with automatic backend selection.

        Args:
            query: Query tensor of shape [B, N, C]
            key: Key tensor of shape [B, M, C] (M may differ from N in cross-attention)
            value: Value tensor of shape [B, M, C]

        Returns:
            tuple: (output, attention_weights)
                - output: Attention output of shape [B, N, C]
                - attention_weights: None (not computed, for efficiency)
        """
        if not self.use_xformers:
            # Standard path using parent nn.MultiheadAttention (SDPA-backed in PyTorch 2.0+).
            # This reuses the original weights (in_proj_weight, out_proj) from the checkpoint.
            kwargs.setdefault("need_weights", False)
            return super().forward(query, key, value, **kwargs)
        else:
            # xFormers memory-efficient attention path
            import xformers.ops as xops

            # Use parent's projection weights for Q, K, V
            # in_proj_weight contains concatenated [W_q; W_k; W_v]
            # This ensures we use the exact same weights as standard MultiheadAttention
            if self.in_proj_weight is not None:
                # Split the combined in_proj_weight into Q, K, V weights
                w_q, w_k, w_v = self.in_proj_weight.chunk(3, dim=0)
                b_q, b_k, b_v = None, None, None
                if self.in_proj_bias is not None:
                    b_q, b_k, b_v = self.in_proj_bias.chunk(3, dim=0)

                # Apply projections using the same weights as standard attention
                q = torch.nn.functional.linear(query, w_q, b_q)
                k = torch.nn.functional.linear(key, w_k, b_k)
                v = torch.nn.functional.linear(value, w_v, b_v)
            else:
                # Separate projection weights (used when kdim/vdim differ from embed_dim);
                # in_proj_bias, when present, is still packed as [b_q; b_k; b_v]
                b_q = b_k = b_v = None
                if self.in_proj_bias is not None:
                    b_q, b_k, b_v = self.in_proj_bias.chunk(3, dim=0)
                q = torch.nn.functional.linear(query, self.q_proj_weight, b_q)
                k = torch.nn.functional.linear(key, self.k_proj_weight, b_k)
                v = torch.nn.functional.linear(value, self.v_proj_weight, b_v)

            # Reshape for multi-head attention: [B, L, C] -> [B, L, H, C//H].
            # Query and key/value sequence lengths may differ in cross-attention.
            B, N, C = q.shape
            M = k.shape[1]
            q = q.reshape(B, N, self.num_heads, self.head_dim)
            k = k.reshape(B, M, self.num_heads, self.head_dim)
            v = v.reshape(B, M, self.num_heads, self.head_dim)

            # Apply xFormers memory-efficient attention
            # This is significantly faster and uses less memory than standard attention
            # Scale is 1/sqrt(head_dim) as per standard scaled dot-product attention
            scale = self.head_dim**-0.5
            out = xops.memory_efficient_attention(q, k, v, scale=scale)

            # Reshape back: [B, N, H, C//H] -> [B, N, C]
            out = out.reshape(B, N, C)

            # Use parent's output projection (same weights as standard attention)
            out = torch.nn.functional.linear(out, self.out_proj.weight, self.out_proj.bias)

            return out, None
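

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes
# hypothetical shapes and hyperparameters, exercises the default SDPA path,
# and re-derives the xFormers-branch projections with PyTorch's own
# scaled_dot_product_attention to show that the chunked in_proj weights
# reproduce nn.MultiheadAttention's output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn.functional as F

    torch.manual_seed(0)

    embed_dim, num_heads = 256, 8
    B, N, M = 2, 16, 32  # batch, query length, key/value length

    query = torch.randn(B, N, embed_dim)
    key = torch.randn(B, M, embed_dim)
    value = torch.randn(B, M, embed_dim)

    # Default path: parent nn.MultiheadAttention (SDPA-backed on PyTorch 2.0+).
    attn = FlexibleCrossAttention(embed_dim, num_heads, use_xformers=False).eval()
    with torch.no_grad():
        out, _ = attn(query, key, value)
    print("output shape:", tuple(out.shape))  # (2, 16, 256)

    # Manual recomputation mirroring the xFormers branch, but using PyTorch's
    # F.scaled_dot_product_attention (which expects a [B, H, L, D] layout).
    with torch.no_grad():
        w_q, w_k, w_v = attn.in_proj_weight.chunk(3, dim=0)
        b_q, b_k, b_v = attn.in_proj_bias.chunk(3, dim=0)
        q = F.linear(query, w_q, b_q).reshape(B, N, num_heads, -1).transpose(1, 2)
        k = F.linear(key, w_k, b_k).reshape(B, M, num_heads, -1).transpose(1, 2)
        v = F.linear(value, w_v, b_v).reshape(B, M, num_heads, -1).transpose(1, 2)
        manual = F.scaled_dot_product_attention(q, k, v)
        manual = manual.transpose(1, 2).reshape(B, N, embed_dim)
        manual = F.linear(manual, attn.out_proj.weight, attn.out_proj.bias)
    print("matches nn.MultiheadAttention:", torch.allclose(out, manual, atol=1e-5))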