Update src/nn/attn.py
Browse files- src/nn/attn.py +0 -4
src/nn/attn.py
CHANGED
|
@@ -5,10 +5,6 @@ import einops
|
|
| 5 |
from jaxtyping import Float, Bool
|
| 6 |
from torch import Tensor
|
| 7 |
from typing import Optional
|
| 8 |
-
from torch.nn.attention.flex_attention import flex_attention
|
| 9 |
-
from matplotlib import pyplot as plt
|
| 10 |
-
|
| 11 |
-
from pdb import set_trace
|
| 12 |
|
| 13 |
class KVCache(nn.Module):
|
| 14 |
"""
|
|
|
|
| 5 |
from jaxtyping import Float, Bool
|
| 6 |
from torch import Tensor
|
| 7 |
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class KVCache(nn.Module):
|
| 10 |
"""
|