sohv committed
Commit 7ec2f3b · verified · 1 Parent(s): 2fe1ac9

Upload src/attention.py

Files changed (1)
  1. src/attention.py +159 -0
src/attention.py ADDED
@@ -0,0 +1,159 @@
"""
Latent Attention Implementation for nanoKimi

This module implements the Latent Attention mechanism used in Kimi-K2,
which compresses attention representations to reduce memory footprint
while maintaining performance on long sequences.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class LatentAttention(nn.Module):
    """
    Latent Attention mechanism that compresses attention representations.

    The key idea is to project keys and values into a lower-dimensional
    latent space, reducing memory usage while preserving attention quality.

    Args:
        n_embd: embedding dimension
        n_head: number of attention heads
        latent_dim: dimension of the latent space
        dropout: dropout probability
        bias: whether to use bias in linear layers
    """

    def __init__(self, n_embd, n_head, latent_dim=64, dropout=0.0, bias=True):
        super().__init__()
        assert n_embd % n_head == 0

        self.n_embd = n_embd
        self.n_head = n_head
        self.latent_dim = latent_dim
        self.head_dim = n_embd // n_head

        # Query projection (full dimension)
        self.q_proj = nn.Linear(n_embd, n_embd, bias=bias)

        # Key and Value projections to latent space
        self.k_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)
        self.v_proj = nn.Linear(n_embd, n_head * latent_dim, bias=bias)

        # Learnable query compression to the latent space, registered here
        # (rather than lazily in forward) so it is part of state_dict and
        # visible to the optimizer
        self.q_compress = nn.Linear(self.head_dim, latent_dim, bias=False)

        # Output projection
        self.o_proj = nn.Linear(n_head * latent_dim, n_embd, bias=bias)

        # Dropout
        self.dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Scale factor for attention (scores are computed in the latent space)
        self.scale = 1.0 / math.sqrt(latent_dim)

    def forward(self, x, mask=None):
        B, T, C = x.size()  # batch, sequence length, embedding dim

        # Project to query, key, value
        q = self.q_proj(x)  # (B, T, n_embd)
        k = self.k_proj(x)  # (B, T, n_head * latent_dim)
        v = self.v_proj(x)  # (B, T, n_head * latent_dim)

        # Reshape for multi-head attention
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)    # (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, self.latent_dim).transpose(1, 2)  # (B, n_head, T, latent_dim)
        v = v.view(B, T, self.n_head, self.latent_dim).transpose(1, 2)  # (B, n_head, T, latent_dim)

        # Compress queries to the latent dimension so scores can be computed
        # against the latent keys
        q_compressed = self.q_compress(q)  # (B, n_head, T, latent_dim)

        # Compute attention scores in latent space
        att = torch.matmul(q_compressed, k.transpose(-2, -1)) * self.scale  # (B, n_head, T, T)

        # Apply causal mask
        if mask is not None:
            att = att.masked_fill(mask == 0, float('-inf'))
        else:
            # Create causal mask
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            att = att.masked_fill(causal_mask == 0, float('-inf'))

        # Apply softmax
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        # Apply attention to values
        y = torch.matmul(att, v)  # (B, n_head, T, latent_dim)

        # Reshape and project back
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.latent_dim)
        y = self.o_proj(y)
        y = self.resid_dropout(y)

        return y
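
# A rough sense of the memory saving (an illustrative back-of-the-envelope
# estimate, not a measured benchmark): per token, standard multi-head
# attention produces K and V activations of size n_embd (= n_head * head_dim)
# each, while LatentAttention produces n_head * latent_dim each, a reduction
# of head_dim / latent_dim. For example, with n_embd=768 and n_head=12
# (head_dim=64), choosing latent_dim=32 halves per-token K/V storage:
# 2 * 768 = 1536 values down to 2 * 12 * 32 = 768. The (T, T) attention
# score matrix itself is unchanged.
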
class MultiHeadAttention(nn.Module):
    """
    Standard multi-head attention for comparison
    """

    def __init__(self, n_embd, n_head, dropout=0.0, bias=True):
        super().__init__()
        assert n_embd % n_head == 0

        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        # QKV projection
        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=bias)

        # Output projection
        self.o_proj = nn.Linear(n_embd, n_embd, bias=bias)

        # Dropout
        self.dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Scale factor
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x, mask=None):
        B, T, C = x.size()

        # Compute QKV
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for multi-head attention
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Compute attention
        att = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply causal mask
        if mask is not None:
            att = att.masked_fill(mask == 0, float('-inf'))
        else:
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            att = att.masked_fill(causal_mask == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        # Apply attention to values
        y = torch.matmul(att, v)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)
        y = self.resid_dropout(y)

        return y
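
A quick smoke test for the two modules (a minimal sketch, not part of the committed file; it assumes the module is importable as src.attention and uses arbitrary example sizes):

    import torch
    from src.attention import LatentAttention, MultiHeadAttention

    torch.manual_seed(0)
    x = torch.randn(2, 16, 768)  # (batch, seq_len, n_embd)

    # latent_dim=32 is half of head_dim (768 / 12 = 64), so the K/V
    # projections are half the size of their standard counterparts
    latent = LatentAttention(n_embd=768, n_head=12, latent_dim=32)
    standard = MultiHeadAttention(n_embd=768, n_head=12)

    print(latent(x).shape)    # torch.Size([2, 16, 768])
    print(standard(x).shape)  # torch.Size([2, 16, 768])

    # Compare parameter counts; the latent variant's k_proj, v_proj and
    # o_proj are roughly half the size of the standard equivalents here.
    print(sum(p.numel() for p in latent.parameters()),
          sum(p.numel() for p in standard.parameters()))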