Spaces:
Running on Zero
Running on Zero
[Admin maintenance] Support new ZeroGPU hardware
Browse filesThank you so much for having shared this Space with the community on this demo. We have upgraded the ZeroGPU infra-structure to run on modern blackwell architecture.
For that, we need to upgrade your demo to support that. This PR fixes your demo to work with the new architecture. As this is something we broke on our end, we may merge this PR autonomously. If this breaks unexpectedly or brings unintended consequences, feel free to revert, modify or otherwise. Any issues you can email apolinario@huggingface.co
- app.py +0 -2
- flash_attn/__init__.py +24 -0
- flash_attn/flash_attn_interface.py +112 -0
- flash_attn/funcs.py +54 -0
- flash_attn/layers/__init__.py +0 -0
- flash_attn/layers/rotary.py +51 -0
- requirements.txt +6 -5
app.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
import os
|
| 2 |
-
import subprocess
|
| 3 |
-
subprocess.run('pip install flash-attn==2.6.3 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 4 |
import random
|
| 5 |
import spaces
|
| 6 |
import numpy as np
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import random
|
| 3 |
import spaces
|
| 4 |
import numpy as np
|
flash_attn/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal torch-native shim for flash_attn used by AIDC-AI/Ovis-U1-3B.
|
| 2 |
+
|
| 3 |
+
The upstream modeling file imports:
|
| 4 |
+
from flash_attn.layers.rotary import apply_rotary_emb
|
| 5 |
+
from flash_attn import flash_attn_varlen_func
|
| 6 |
+
|
| 7 |
+
Blackwell/CUDA-13 has no flash-attn prebuilt wheel for cp310+torch>=2.10, and the
|
| 8 |
+
package's CUDA build doesn't fit within the @spaces.GPU 1500s budget, so we
|
| 9 |
+
provide a small torch-native equivalent that satisfies the two call sites the
|
| 10 |
+
model actually exercises.
|
| 11 |
+
|
| 12 |
+
We also fake a version string within the range xformers tolerates so that
|
| 13 |
+
``xformers/ops/fmha/flash.py`` (loaded transitively by ``diffusers``) does not
|
| 14 |
+
explode at import time. The xformers FA backend it then registers will never
|
| 15 |
+
be invoked along the user-facing demo path (the model uses transformers SDPA
|
| 16 |
+
attention + this shim's varlen path; diffusers' xformers backend is only
|
| 17 |
+
engaged via an explicit ``set_use_memory_efficient_attention_xformers`` opt-in
|
| 18 |
+
which the demo never makes).
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
__version__ = "2.8.3"
|
| 22 |
+
|
| 23 |
+
from .funcs import flash_attn_varlen_func # noqa: F401
|
| 24 |
+
from . import flash_attn_interface # noqa: F401 -- expose submodule eagerly
|
flash_attn/flash_attn_interface.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stub of flash_attn.flash_attn_interface.
|
| 2 |
+
|
| 3 |
+
Re-exports the torch-native ``flash_attn_varlen_func`` from the shim plus a
|
| 4 |
+
``flash_attn_func`` fallback that uses ``torch.nn.functional.scaled_dot_product_attention``
|
| 5 |
+
for the padded (batch, seqlen, nheads, headdim) call signature.
|
| 6 |
+
|
| 7 |
+
Also exposes a placeholder ``flash_attn_gpu`` object so xformers'
|
| 8 |
+
``hasattr(flash_attn.flash_attn_interface, "flash_attn_gpu")`` probe in
|
| 9 |
+
``xformers/ops/fmha/flash.py`` succeeds. The backend xformers registers from
|
| 10 |
+
this probe is never invoked along the demo's user-facing path.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
from .funcs import flash_attn_varlen_func # noqa: F401
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _UnavailableBackend:
|
| 20 |
+
"""Opaque placeholder; calling any attribute raises a clear error."""
|
| 21 |
+
|
| 22 |
+
def __getattr__(self, name):
|
| 23 |
+
raise RuntimeError(
|
| 24 |
+
"flash_attn shim: real CUDA backend is not installed. "
|
| 25 |
+
"The demo's user-facing path should not need it."
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
flash_attn_gpu = _UnavailableBackend()
|
| 30 |
+
flash_attn_cuda = _UnavailableBackend()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def flash_attn_func(
|
| 34 |
+
q,
|
| 35 |
+
k,
|
| 36 |
+
v,
|
| 37 |
+
dropout_p=0.0,
|
| 38 |
+
softmax_scale=None,
|
| 39 |
+
causal=False,
|
| 40 |
+
window_size=(-1, -1),
|
| 41 |
+
softcap=0.0,
|
| 42 |
+
alibi_slopes=None,
|
| 43 |
+
deterministic=False,
|
| 44 |
+
return_attn_probs=False,
|
| 45 |
+
):
|
| 46 |
+
"""Padded attention. q/k/v shape: (B, L, H, D). Returns (B, L, H, D)."""
|
| 47 |
+
if softmax_scale is None:
|
| 48 |
+
softmax_scale = q.shape[-1] ** -0.5
|
| 49 |
+
# SDPA expects (B, H, L, D)
|
| 50 |
+
q_t = q.transpose(1, 2)
|
| 51 |
+
k_t = k.transpose(1, 2)
|
| 52 |
+
v_t = v.transpose(1, 2)
|
| 53 |
+
out = F.scaled_dot_product_attention(
|
| 54 |
+
q_t, k_t, v_t,
|
| 55 |
+
dropout_p=dropout_p,
|
| 56 |
+
is_causal=causal,
|
| 57 |
+
scale=softmax_scale,
|
| 58 |
+
)
|
| 59 |
+
return out.transpose(1, 2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def flash_attn_qkvpacked_func(
|
| 63 |
+
qkv,
|
| 64 |
+
dropout_p=0.0,
|
| 65 |
+
softmax_scale=None,
|
| 66 |
+
causal=False,
|
| 67 |
+
window_size=(-1, -1),
|
| 68 |
+
softcap=0.0,
|
| 69 |
+
alibi_slopes=None,
|
| 70 |
+
deterministic=False,
|
| 71 |
+
return_attn_probs=False,
|
| 72 |
+
):
|
| 73 |
+
"""qkv shape: (B, L, 3, H, D)."""
|
| 74 |
+
q, k, v = qkv.unbind(dim=2)
|
| 75 |
+
return flash_attn_func(
|
| 76 |
+
q, k, v,
|
| 77 |
+
dropout_p=dropout_p,
|
| 78 |
+
softmax_scale=softmax_scale,
|
| 79 |
+
causal=causal,
|
| 80 |
+
window_size=window_size,
|
| 81 |
+
softcap=softcap,
|
| 82 |
+
alibi_slopes=alibi_slopes,
|
| 83 |
+
deterministic=deterministic,
|
| 84 |
+
return_attn_probs=return_attn_probs,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def flash_attn_kvpacked_func(
|
| 89 |
+
q,
|
| 90 |
+
kv,
|
| 91 |
+
dropout_p=0.0,
|
| 92 |
+
softmax_scale=None,
|
| 93 |
+
causal=False,
|
| 94 |
+
window_size=(-1, -1),
|
| 95 |
+
softcap=0.0,
|
| 96 |
+
alibi_slopes=None,
|
| 97 |
+
deterministic=False,
|
| 98 |
+
return_attn_probs=False,
|
| 99 |
+
):
|
| 100 |
+
"""q shape: (B, Lq, H, D), kv shape: (B, Lk, 2, H, D)."""
|
| 101 |
+
k, v = kv.unbind(dim=2)
|
| 102 |
+
return flash_attn_func(
|
| 103 |
+
q, k, v,
|
| 104 |
+
dropout_p=dropout_p,
|
| 105 |
+
softmax_scale=softmax_scale,
|
| 106 |
+
causal=causal,
|
| 107 |
+
window_size=window_size,
|
| 108 |
+
softcap=softcap,
|
| 109 |
+
alibi_slopes=alibi_slopes,
|
| 110 |
+
deterministic=deterministic,
|
| 111 |
+
return_attn_probs=return_attn_probs,
|
| 112 |
+
)
|
flash_attn/funcs.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Torch-native equivalent of flash_attn.flash_attn_varlen_func.
|
| 2 |
+
|
| 3 |
+
Only the forward path used by AIMv2's vision tower is implemented:
|
| 4 |
+
flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
|
| 5 |
+
|
| 6 |
+
q, k, v are packed (total_tokens, num_heads, head_dim) tensors and the function
|
| 7 |
+
returns the same packed (total_tokens, num_heads, head_dim) shape after applying
|
| 8 |
+
self-attention per sub-sequence as encoded by `cu_seqlens_q == cu_seqlens_k`.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def flash_attn_varlen_func(
|
| 16 |
+
q,
|
| 17 |
+
k,
|
| 18 |
+
v,
|
| 19 |
+
cu_seqlens_q,
|
| 20 |
+
cu_seqlens_k,
|
| 21 |
+
max_seqlen_q,
|
| 22 |
+
max_seqlen_k,
|
| 23 |
+
dropout_p=0.0,
|
| 24 |
+
softmax_scale=None,
|
| 25 |
+
causal=False,
|
| 26 |
+
window_size=(-1, -1),
|
| 27 |
+
softcap=0.0,
|
| 28 |
+
alibi_slopes=None,
|
| 29 |
+
deterministic=False,
|
| 30 |
+
return_attn_probs=False,
|
| 31 |
+
):
|
| 32 |
+
if softmax_scale is None:
|
| 33 |
+
softmax_scale = q.shape[-1] ** -0.5
|
| 34 |
+
|
| 35 |
+
cu_q = cu_seqlens_q.tolist()
|
| 36 |
+
cu_k = cu_seqlens_k.tolist()
|
| 37 |
+
|
| 38 |
+
out_chunks = []
|
| 39 |
+
for i in range(len(cu_q) - 1):
|
| 40 |
+
sq, eq = cu_q[i], cu_q[i + 1]
|
| 41 |
+
sk, ek = cu_k[i], cu_k[i + 1]
|
| 42 |
+
q_i = q[sq:eq].transpose(0, 1).unsqueeze(0) # (1, H, Lq, D)
|
| 43 |
+
k_i = k[sk:ek].transpose(0, 1).unsqueeze(0)
|
| 44 |
+
v_i = v[sk:ek].transpose(0, 1).unsqueeze(0)
|
| 45 |
+
o_i = F.scaled_dot_product_attention(
|
| 46 |
+
q_i, k_i, v_i,
|
| 47 |
+
dropout_p=dropout_p,
|
| 48 |
+
is_causal=causal,
|
| 49 |
+
scale=softmax_scale,
|
| 50 |
+
)
|
| 51 |
+
out_chunks.append(o_i.squeeze(0).transpose(0, 1)) # (Lq, H, D)
|
| 52 |
+
|
| 53 |
+
out = torch.cat(out_chunks, dim=0)
|
| 54 |
+
return out
|
flash_attn/layers/__init__.py
ADDED
|
File without changes
|
flash_attn/layers/rotary.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Torch-native equivalent of flash_attn.layers.rotary.apply_rotary_emb.
|
| 2 |
+
|
| 3 |
+
Mirrors the flash_attn `apply_rotary_emb_torch` reference implementation:
|
| 4 |
+
x: (batch_size, seqlen, nheads, headdim)
|
| 5 |
+
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _rotate_half(x, interleaved=False):
|
| 12 |
+
if not interleaved:
|
| 13 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 14 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 15 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 16 |
+
out = torch.stack((-x2, x1), dim=-1)
|
| 17 |
+
return out.flatten(-2)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def apply_rotary_emb(
|
| 21 |
+
x,
|
| 22 |
+
cos,
|
| 23 |
+
sin,
|
| 24 |
+
interleaved=False,
|
| 25 |
+
inplace=False,
|
| 26 |
+
seqlen_offsets=0,
|
| 27 |
+
cu_seqlens=None,
|
| 28 |
+
max_seqlen=None,
|
| 29 |
+
):
|
| 30 |
+
"""Pure-torch rotary embedding application.
|
| 31 |
+
|
| 32 |
+
The Ovis aimv2 call site uses the simple case: no `cu_seqlens`, no
|
| 33 |
+
`seqlen_offsets`, default `interleaved=False`.
|
| 34 |
+
"""
|
| 35 |
+
ro_dim = cos.shape[-1] * 2
|
| 36 |
+
assert ro_dim <= x.shape[-1], f"rotary dim {ro_dim} exceeds head dim {x.shape[-1]}"
|
| 37 |
+
|
| 38 |
+
# Broadcast cos/sin from (..., rotary_dim/2) up to (..., 1, rotary_dim)
|
| 39 |
+
if interleaved:
|
| 40 |
+
cos = cos.unsqueeze(-2).repeat_interleave(2, dim=-1)
|
| 41 |
+
sin = sin.unsqueeze(-2).repeat_interleave(2, dim=-1)
|
| 42 |
+
else:
|
| 43 |
+
cos = cos.unsqueeze(-2)
|
| 44 |
+
sin = sin.unsqueeze(-2)
|
| 45 |
+
cos = torch.cat([cos, cos], dim=-1)
|
| 46 |
+
sin = torch.cat([sin, sin], dim=-1)
|
| 47 |
+
|
| 48 |
+
x_rot = x[..., :ro_dim]
|
| 49 |
+
x_pass = x[..., ro_dim:]
|
| 50 |
+
out_rot = x_rot * cos + _rotate_half(x_rot, interleaved) * sin
|
| 51 |
+
return torch.cat([out_rot, x_pass], dim=-1)
|
requirements.txt
CHANGED
|
@@ -1,17 +1,18 @@
|
|
| 1 |
-
torch==2.
|
|
|
|
| 2 |
transformers==4.51.3
|
| 3 |
tokenizers==0.21.1
|
| 4 |
sentencepiece==0.1.99
|
| 5 |
pyarrow==18.0.0
|
| 6 |
accelerate==1.1.0
|
| 7 |
-
pydantic
|
| 8 |
markdown2[all]
|
| 9 |
-
numpy
|
| 10 |
scikit-learn==1.2.2
|
| 11 |
requests
|
| 12 |
httpx
|
| 13 |
uvicorn
|
| 14 |
-
fastapi
|
| 15 |
einops==0.6.1
|
| 16 |
einops-exts==0.0.4
|
| 17 |
timm==1.0.11
|
|
@@ -19,7 +20,7 @@ tiktoken
|
|
| 19 |
transformers_stream_generator==0.0.4
|
| 20 |
scipy
|
| 21 |
pandas
|
| 22 |
-
torchaudio
|
| 23 |
xformers
|
| 24 |
pillow==10.3.0
|
| 25 |
pysubs2==1.7.2
|
|
|
|
| 1 |
+
torch==2.10.0
|
| 2 |
+
torchvision==0.25.0
|
| 3 |
transformers==4.51.3
|
| 4 |
tokenizers==0.21.1
|
| 5 |
sentencepiece==0.1.99
|
| 6 |
pyarrow==18.0.0
|
| 7 |
accelerate==1.1.0
|
| 8 |
+
pydantic
|
| 9 |
markdown2[all]
|
| 10 |
+
numpy<2
|
| 11 |
scikit-learn==1.2.2
|
| 12 |
requests
|
| 13 |
httpx
|
| 14 |
uvicorn
|
| 15 |
+
fastapi
|
| 16 |
einops==0.6.1
|
| 17 |
einops-exts==0.0.4
|
| 18 |
timm==1.0.11
|
|
|
|
| 20 |
transformers_stream_generator==0.0.4
|
| 21 |
scipy
|
| 22 |
pandas
|
| 23 |
+
torchaudio==2.10.0
|
| 24 |
xformers
|
| 25 |
pillow==10.3.0
|
| 26 |
pysubs2==1.7.2
|