multimodalart HF Staff committed on
Commit
37094bb
·
verified ·
1 Parent(s): 9a6780b

[Admin maintenance] Support new ZeroGPU hardware

Browse files

Thank you so much for sharing this Space and its demo with the community. We have upgraded the ZeroGPU infrastructure to run on the modern Blackwell architecture.
As a result, your demo needs to be updated to support it, and this PR makes it work with the new architecture. As this is something we broke on our end, we may merge this PR autonomously. If this breaks unexpectedly or brings unintended consequences, feel free to revert or modify it. For any issues, you can email apolinario@huggingface.co

app.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- import subprocess
3
- subprocess.run('pip install flash-attn==2.6.3 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
4
  import random
5
  import spaces
6
  import numpy as np
 
1
  import os
 
 
2
  import random
3
  import spaces
4
  import numpy as np
flash_attn/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal torch-native shim for flash_attn used by AIDC-AI/Ovis-U1-3B.
2
+
3
+ The upstream modeling file imports:
4
+ from flash_attn.layers.rotary import apply_rotary_emb
5
+ from flash_attn import flash_attn_varlen_func
6
+
7
+ Blackwell/CUDA-13 has no flash-attn prebuilt wheel for cp310+torch>=2.10, and the
8
+ package's CUDA build doesn't fit within the @spaces.GPU 1500s budget, so we
9
+ provide a small torch-native equivalent that satisfies the two call sites the
10
+ model actually exercises.
11
+
12
+ We also fake a version string within the range xformers tolerates so that
13
+ ``xformers/ops/fmha/flash.py`` (loaded transitively by ``diffusers``) does not
14
+ explode at import time. The xformers FA backend it then registers will never
15
+ be invoked along the user-facing demo path (the model uses transformers SDPA
16
+ attention + this shim's varlen path; diffusers' xformers backend is only
17
+ engaged via an explicit ``set_use_memory_efficient_attention_xformers`` opt-in
18
+ which the demo never makes).
19
+ """
20
+
21
+ __version__ = "2.8.3"
22
+
23
+ from .funcs import flash_attn_varlen_func # noqa: F401
24
+ from . import flash_attn_interface # noqa: F401 -- expose submodule eagerly
flash_attn/flash_attn_interface.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stub of flash_attn.flash_attn_interface.
2
+
3
+ Re-exports the torch-native ``flash_attn_varlen_func`` from the shim plus a
4
+ ``flash_attn_func`` fallback that uses ``torch.nn.functional.scaled_dot_product_attention``
5
+ for the padded (batch, seqlen, nheads, headdim) call signature.
6
+
7
+ Also exposes a placeholder ``flash_attn_gpu`` object so xformers'
8
+ ``hasattr(flash_attn.flash_attn_interface, "flash_attn_gpu")`` probe in
9
+ ``xformers/ops/fmha/flash.py`` succeeds. The backend xformers registers from
10
+ this probe is never invoked along the demo's user-facing path.
11
+ """
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from .funcs import flash_attn_varlen_func # noqa: F401
17
+
18
+
19
class _UnavailableBackend:
    """Opaque placeholder standing in for the compiled flash-attn extension.

    Instances exist only so that attribute probes on the parent module find
    an object. Looking up any ordinary attribute on the instance raises
    ``RuntimeError``, because no real CUDA backend is installed.
    """

    def __getattr__(self, name):
        """Reject attribute access on the placeholder.

        Raises:
            AttributeError: for dunder names, so that ``hasattr``,
                ``getattr(obj, name, default)`` and similar introspection
                report "absent" instead of crashing.
            RuntimeError: for every other name, i.e. any attempt to reach a
                real backend entry point.
        """
        # __getattr__ only runs after normal lookup failed. Protocol probes
        # treat AttributeError (and nothing else) as "not present".
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        raise RuntimeError(
            "flash_attn shim: real CUDA backend is not installed. "
            "The demo's user-facing path should not need it."
        )


flash_attn_gpu = _UnavailableBackend()
flash_attn_cuda = _UnavailableBackend()
31
+
32
+
33
+ def flash_attn_func(
34
+ q,
35
+ k,
36
+ v,
37
+ dropout_p=0.0,
38
+ softmax_scale=None,
39
+ causal=False,
40
+ window_size=(-1, -1),
41
+ softcap=0.0,
42
+ alibi_slopes=None,
43
+ deterministic=False,
44
+ return_attn_probs=False,
45
+ ):
46
+ """Padded attention. q/k/v shape: (B, L, H, D). Returns (B, L, H, D)."""
47
+ if softmax_scale is None:
48
+ softmax_scale = q.shape[-1] ** -0.5
49
+ # SDPA expects (B, H, L, D)
50
+ q_t = q.transpose(1, 2)
51
+ k_t = k.transpose(1, 2)
52
+ v_t = v.transpose(1, 2)
53
+ out = F.scaled_dot_product_attention(
54
+ q_t, k_t, v_t,
55
+ dropout_p=dropout_p,
56
+ is_causal=causal,
57
+ scale=softmax_scale,
58
+ )
59
+ return out.transpose(1, 2)
60
+
61
+
62
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention over a packed tensor; qkv shape: (B, L, 3, H, D).

    Splits ``qkv`` along dim 2 into query, key and value and forwards them,
    together with every option unchanged, to ``flash_attn_func``.
    """
    query, key, value = torch.unbind(qkv, dim=2)
    options = dict(
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )
    return flash_attn_func(query, key, value, **options)
86
+
87
+
88
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention with packed keys/values.

    q shape: (B, Lq, H, D), kv shape: (B, Lk, 2, H, D). ``kv`` is split along
    dim 2 into key and value, and the call is delegated to
    ``flash_attn_func`` with every option passed through unchanged.
    """
    key, value = torch.unbind(kv, dim=2)
    options = dict(
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )
    return flash_attn_func(q, key, value, **options)
flash_attn/funcs.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Torch-native equivalent of flash_attn.flash_attn_varlen_func.
2
+
3
+ Only the forward path used by AIMv2's vision tower is implemented:
4
+ flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
5
+
6
+ q, k, v are packed (total_tokens, num_heads, head_dim) tensors and the function
7
+ returns the same packed (total_tokens, num_heads, head_dim) shape after applying
8
+ self-attention per sub-sequence as encoded by `cu_seqlens_q == cu_seqlens_k`.
9
+ """
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def flash_attn_varlen_func(
16
+ q,
17
+ k,
18
+ v,
19
+ cu_seqlens_q,
20
+ cu_seqlens_k,
21
+ max_seqlen_q,
22
+ max_seqlen_k,
23
+ dropout_p=0.0,
24
+ softmax_scale=None,
25
+ causal=False,
26
+ window_size=(-1, -1),
27
+ softcap=0.0,
28
+ alibi_slopes=None,
29
+ deterministic=False,
30
+ return_attn_probs=False,
31
+ ):
32
+ if softmax_scale is None:
33
+ softmax_scale = q.shape[-1] ** -0.5
34
+
35
+ cu_q = cu_seqlens_q.tolist()
36
+ cu_k = cu_seqlens_k.tolist()
37
+
38
+ out_chunks = []
39
+ for i in range(len(cu_q) - 1):
40
+ sq, eq = cu_q[i], cu_q[i + 1]
41
+ sk, ek = cu_k[i], cu_k[i + 1]
42
+ q_i = q[sq:eq].transpose(0, 1).unsqueeze(0) # (1, H, Lq, D)
43
+ k_i = k[sk:ek].transpose(0, 1).unsqueeze(0)
44
+ v_i = v[sk:ek].transpose(0, 1).unsqueeze(0)
45
+ o_i = F.scaled_dot_product_attention(
46
+ q_i, k_i, v_i,
47
+ dropout_p=dropout_p,
48
+ is_causal=causal,
49
+ scale=softmax_scale,
50
+ )
51
+ out_chunks.append(o_i.squeeze(0).transpose(0, 1)) # (Lq, H, D)
52
+
53
+ out = torch.cat(out_chunks, dim=0)
54
+ return out
flash_attn/layers/__init__.py ADDED
File without changes
flash_attn/layers/rotary.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Torch-native equivalent of flash_attn.layers.rotary.apply_rotary_emb.
2
+
3
+ Mirrors the flash_attn `apply_rotary_emb_torch` reference implementation:
4
+ x: (batch_size, seqlen, nheads, headdim)
5
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
6
+ """
7
+
8
+ import torch
9
+
10
+
11
def _rotate_half(x, interleaved=False):
    """Return ``x`` with each rotary pair ``(a, b)`` replaced by ``(-b, a)``.

    With ``interleaved=False`` the pairs are (first half, second half) of the
    last dimension; with ``interleaved=True`` they are adjacent even/odd
    elements.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.stack((-x2, x1), dim=-1)
    return out.flatten(-2)


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    """Pure-torch rotary embedding application.

    Args:
        x: input whose last dimension is the head dimension; ``cos``/``sin``
            are broadcast against it after a singleton axis is inserted
            before their last dimension.
        cos, sin: tables whose last dimension is ``rotary_dim / 2``.
        interleaved: rotate adjacent even/odd pairs instead of the two halves.
        inplace: when true, write the result back into ``x`` and return ``x``.
        seqlen_offsets: only the integer ``0`` is supported.
        cu_seqlens: only ``None`` is supported.
        max_seqlen: accepted for signature compatibility; unused.

    Returns:
        Tensor with the first ``rotary_dim`` features rotated and the rest
        passed through unchanged (``x`` itself when ``inplace`` is true).

    Raises:
        NotImplementedError: if ``cu_seqlens`` or a non-zero/tensor
            ``seqlen_offsets`` is given; ignoring them would apply the wrong
            positions silently.
        ValueError: if the rotary dimension exceeds the head dimension.
    """
    if cu_seqlens is not None:
        raise NotImplementedError("flash_attn shim: cu_seqlens is not supported by apply_rotary_emb")
    if not (isinstance(seqlen_offsets, int) and seqlen_offsets == 0):
        raise NotImplementedError("flash_attn shim: seqlen_offsets is not supported by apply_rotary_emb")

    ro_dim = cos.shape[-1] * 2
    # Explicit raise rather than assert: asserts vanish under ``python -O``.
    if ro_dim > x.shape[-1]:
        raise ValueError(f"rotary dim {ro_dim} exceeds head dim {x.shape[-1]}")

    # Broadcast cos/sin from (..., rotary_dim/2) up to (..., 1, rotary_dim)
    if interleaved:
        cos = cos.unsqueeze(-2).repeat_interleave(2, dim=-1)
        sin = sin.unsqueeze(-2).repeat_interleave(2, dim=-1)
    else:
        cos = cos.unsqueeze(-2)
        sin = sin.unsqueeze(-2)
        cos = torch.cat([cos, cos], dim=-1)
        sin = torch.cat([sin, sin], dim=-1)

    x_rot = x[..., :ro_dim]
    out_rot = x_rot * cos + _rotate_half(x_rot, interleaved) * sin
    if inplace:
        # Only the rotated slice changes; the pass-through tail is already in x.
        x[..., :ro_dim] = out_rot
        return x
    x_pass = x[..., ro_dim:]
    return torch.cat([out_rot, x_pass], dim=-1)
requirements.txt CHANGED
@@ -1,17 +1,18 @@
1
- torch==2.4.0
 
2
  transformers==4.51.3
3
  tokenizers==0.21.1
4
  sentencepiece==0.1.99
5
  pyarrow==18.0.0
6
  accelerate==1.1.0
7
- pydantic==2.8.2
8
  markdown2[all]
9
- numpy==1.24.3
10
  scikit-learn==1.2.2
11
  requests
12
  httpx
13
  uvicorn
14
- fastapi==0.112.4
15
  einops==0.6.1
16
  einops-exts==0.0.4
17
  timm==1.0.11
@@ -19,7 +20,7 @@ tiktoken
19
  transformers_stream_generator==0.0.4
20
  scipy
21
  pandas
22
- torchaudio
23
  xformers
24
  pillow==10.3.0
25
  pysubs2==1.7.2
 
1
+ torch==2.10.0
2
+ torchvision==0.25.0
3
  transformers==4.51.3
4
  tokenizers==0.21.1
5
  sentencepiece==0.1.99
6
  pyarrow==18.0.0
7
  accelerate==1.1.0
8
+ pydantic
9
  markdown2[all]
10
+ numpy<2
11
  scikit-learn==1.2.2
12
  requests
13
  httpx
14
  uvicorn
15
+ fastapi
16
  einops==0.6.1
17
  einops-exts==0.0.4
18
  timm==1.0.11
 
20
  transformers_stream_generator==0.0.4
21
  scipy
22
  pandas
23
+ torchaudio==2.10.0
24
  xformers
25
  pillow==10.3.0
26
  pysubs2==1.7.2