Spaces:
Runtime error
Runtime error
Update
Browse files- flux/math.py +1 -1
flux/math.py
CHANGED
|
@@ -11,7 +11,7 @@ def check_tpu():
|
|
| 11 |
return any('TPU' in d.device_kind for d in jax.devices())
|
| 12 |
|
| 13 |
# from torch import Tensor
|
| 14 |
-
if
|
| 15 |
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
|
| 16 |
# q, # [batch_size, num_heads, q_seq_len, d_model]
|
| 17 |
# k, # [batch_size, num_heads, kv_seq_len, d_model]
|
|
|
|
| 11 |
return any('TPU' in d.device_kind for d in jax.devices())
|
| 12 |
|
| 13 |
# from torch import Tensor
|
| 14 |
+
if False:
|
| 15 |
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
|
| 16 |
# q, # [batch_size, num_heads, q_seq_len, d_model]
|
| 17 |
# k, # [batch_size, num_heads, kv_seq_len, d_model]
|