[Admin maintenance] Support new ZeroGPU hardware

#4
by multimodalart HF Staff - opened
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- import subprocess
3
- subprocess.run('pip install flash-attn==2.6.3 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
4
  import random
5
  import spaces
6
  import numpy as np
 
1
  import os
 
 
2
  import random
3
  import spaces
4
  import numpy as np
flash_attn/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal torch-native shim for flash_attn used by AIDC-AI/Ovis-U1-3B.
2
+
3
+ The upstream modeling file imports:
4
+ from flash_attn.layers.rotary import apply_rotary_emb
5
+ from flash_attn import flash_attn_varlen_func
6
+
7
+ Blackwell/CUDA-13 has no flash-attn prebuilt wheel for cp310+torch>=2.10, and the
8
+ package's CUDA build doesn't fit within the @spaces.GPU 1500s budget, so we
9
+ provide a small torch-native equivalent that satisfies the two call sites the
10
+ model actually exercises.
11
+
12
+ We also fake a version string within the range xformers tolerates so that
13
+ ``xformers/ops/fmha/flash.py`` (loaded transitively by ``diffusers``) does not
14
+ explode at import time. The xformers FA backend it then registers will never
15
+ be invoked along the user-facing demo path (the model uses transformers SDPA
16
+ attention + this shim's varlen path; diffusers' xformers backend is only
17
+ engaged via an explicit ``set_use_memory_efficient_attention_xformers`` opt-in
18
+ which the demo never makes).
19
+ """
20
+
21
+ __version__ = "2.8.3"
22
+
23
+ from .funcs import flash_attn_varlen_func # noqa: F401
24
+ from . import flash_attn_interface # noqa: F401 -- expose submodule eagerly
flash_attn/flash_attn_interface.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stub of flash_attn.flash_attn_interface.
2
+
3
+ Re-exports the torch-native ``flash_attn_varlen_func`` from the shim plus a
4
+ ``flash_attn_func`` fallback that uses ``torch.nn.functional.scaled_dot_product_attention``
5
+ for the padded (batch, seqlen, nheads, headdim) call signature.
6
+
7
+ Also exposes a placeholder ``flash_attn_gpu`` object so xformers'
8
+ ``hasattr(flash_attn.flash_attn_interface, "flash_attn_gpu")`` probe in
9
+ ``xformers/ops/fmha/flash.py`` succeeds. The backend xformers registers from
10
+ this probe is never invoked along the demo's user-facing path.
11
+ """
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from .funcs import flash_attn_varlen_func # noqa: F401
17
+
18
+
19
class _UnavailableBackend:
    """Placeholder for flash-attn's compiled extension module.

    It exists only so that attribute-presence probes on this module (for
    example ``hasattr(flash_attn_interface, "flash_attn_gpu")``) succeed.
    Touching any regular attribute on the placeholder raises ``RuntimeError``
    with a message naming the attribute that was requested.

    Dunder lookups are answered with a plain ``AttributeError`` instead, so
    Python's own introspection protocols (``hasattr``, ``getattr`` with a
    default, ``copy``, ``pickle``) see "attribute absent" rather than crashing
    on the placeholder.
    """

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup has already failed, so
        # attributes inherited from ``object`` never reach this point.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        raise RuntimeError(
            "flash_attn shim: real CUDA backend is not installed "
            f"(attempted to access {name!r}). "
            "The demo's user-facing path should not need it."
        )


# Both names are exposed so that either spelling of the extension module
# resolves to the placeholder.
flash_attn_gpu = _UnavailableBackend()
flash_attn_cuda = _UnavailableBackend()
31
+
32
+
33
+ def flash_attn_func(
34
+ q,
35
+ k,
36
+ v,
37
+ dropout_p=0.0,
38
+ softmax_scale=None,
39
+ causal=False,
40
+ window_size=(-1, -1),
41
+ softcap=0.0,
42
+ alibi_slopes=None,
43
+ deterministic=False,
44
+ return_attn_probs=False,
45
+ ):
46
+ """Padded attention. q/k/v shape: (B, L, H, D). Returns (B, L, H, D)."""
47
+ if softmax_scale is None:
48
+ softmax_scale = q.shape[-1] ** -0.5
49
+ # SDPA expects (B, H, L, D)
50
+ q_t = q.transpose(1, 2)
51
+ k_t = k.transpose(1, 2)
52
+ v_t = v.transpose(1, 2)
53
+ out = F.scaled_dot_product_attention(
54
+ q_t, k_t, v_t,
55
+ dropout_p=dropout_p,
56
+ is_causal=causal,
57
+ scale=softmax_scale,
58
+ )
59
+ return out.transpose(1, 2)
60
+
61
+
62
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention over a packed ``qkv`` tensor of shape (B, L, 3, H, D).

    The packed tensor is split along dim 2 into query, key and value, and
    every option is handed to ``flash_attn_func`` unchanged.
    """
    options = {
        "dropout_p": dropout_p,
        "softmax_scale": softmax_scale,
        "causal": causal,
        "window_size": window_size,
        "softcap": softcap,
        "alibi_slopes": alibi_slopes,
        "deterministic": deterministic,
        "return_attn_probs": return_attn_probs,
    }
    queries, keys, values = qkv.unbind(2)
    return flash_attn_func(queries, keys, values, **options)
86
+
87
+
88
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention with separate ``q`` (B, Lq, H, D) and packed ``kv`` (B, Lk, 2, H, D).

    The packed tensor is split along dim 2 into key and value, and every
    option is handed to ``flash_attn_func`` unchanged.
    """
    options = {
        "dropout_p": dropout_p,
        "softmax_scale": softmax_scale,
        "causal": causal,
        "window_size": window_size,
        "softcap": softcap,
        "alibi_slopes": alibi_slopes,
        "deterministic": deterministic,
        "return_attn_probs": return_attn_probs,
    }
    keys, values = kv.unbind(2)
    return flash_attn_func(q, keys, values, **options)
flash_attn/funcs.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Torch-native equivalent of flash_attn.flash_attn_varlen_func.
2
+
3
+ Only the forward path used by AIMv2's vision tower is implemented:
4
+ flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
5
+
6
+ q, k, v are packed (total_tokens, num_heads, head_dim) tensors and the function
7
+ returns the same packed (total_tokens, num_heads, head_dim) shape after applying
8
+ self-attention per sub-sequence as encoded by `cu_seqlens_q == cu_seqlens_k`.
9
+ """
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def flash_attn_varlen_func(
16
+ q,
17
+ k,
18
+ v,
19
+ cu_seqlens_q,
20
+ cu_seqlens_k,
21
+ max_seqlen_q,
22
+ max_seqlen_k,
23
+ dropout_p=0.0,
24
+ softmax_scale=None,
25
+ causal=False,
26
+ window_size=(-1, -1),
27
+ softcap=0.0,
28
+ alibi_slopes=None,
29
+ deterministic=False,
30
+ return_attn_probs=False,
31
+ ):
32
+ if softmax_scale is None:
33
+ softmax_scale = q.shape[-1] ** -0.5
34
+
35
+ cu_q = cu_seqlens_q.tolist()
36
+ cu_k = cu_seqlens_k.tolist()
37
+
38
+ out_chunks = []
39
+ for i in range(len(cu_q) - 1):
40
+ sq, eq = cu_q[i], cu_q[i + 1]
41
+ sk, ek = cu_k[i], cu_k[i + 1]
42
+ q_i = q[sq:eq].transpose(0, 1).unsqueeze(0) # (1, H, Lq, D)
43
+ k_i = k[sk:ek].transpose(0, 1).unsqueeze(0)
44
+ v_i = v[sk:ek].transpose(0, 1).unsqueeze(0)
45
+ o_i = F.scaled_dot_product_attention(
46
+ q_i, k_i, v_i,
47
+ dropout_p=dropout_p,
48
+ is_causal=causal,
49
+ scale=softmax_scale,
50
+ )
51
+ out_chunks.append(o_i.squeeze(0).transpose(0, 1)) # (Lq, H, D)
52
+
53
+ out = torch.cat(out_chunks, dim=0)
54
+ return out
flash_attn/layers/__init__.py ADDED
File without changes
flash_attn/layers/rotary.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Torch-native equivalent of flash_attn.layers.rotary.apply_rotary_emb.
2
+
3
+ Mirrors the flash_attn `apply_rotary_emb_torch` reference implementation:
4
+ x: (batch_size, seqlen, nheads, headdim)
5
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
6
+ """
7
+
8
+ import torch
9
+
10
+
11
def _rotate_half(x, interleaved=False):
    """Return ``x`` with each rotary pair (a, b) replaced by (-b, a).

    With ``interleaved=False`` pairs are (first half, second half) of the last
    dimension; with ``interleaved=True`` pairs are adjacent (even, odd)
    elements.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.stack((-x2, x1), dim=-1)
    return out.flatten(-2)


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    """Pure-torch rotary embedding application.

    Args:
        x: Input of shape (..., seqlen, nheads, headdim).
        cos, sin: Tables of shape (seqlen_ro, rotary_dim / 2), optionally with
            leading batch dimensions. ``seqlen_ro`` may exceed
            ``seqlen_offsets + seqlen``; only the rows
            ``[seqlen_offsets, seqlen_offsets + seqlen)`` are used.
        interleaved: Rotate adjacent (even, odd) pairs instead of
            (first half, second half) pairs.
        inplace: Accepted for signature compatibility; the result is always a
            new tensor and ``x`` is left untouched.
        seqlen_offsets: Non-negative integer shift into the cos/sin tables.
            Per-sequence tensor offsets are not supported.
        cu_seqlens: Must be ``None``; packed variable-length input is not
            supported.
        max_seqlen: Accepted for signature compatibility; unused.

    Returns:
        Tensor with the same shape and dtype as ``x``. The first
        ``rotary_dim`` features are rotated; the rest pass through unchanged.

    Raises:
        NotImplementedError: for ``cu_seqlens`` or tensor ``seqlen_offsets``
            (silently ignoring them would rotate by the wrong positions).
        ValueError: for a negative ``seqlen_offsets``.
    """
    if cu_seqlens is not None:
        raise NotImplementedError(
            "flash_attn shim: apply_rotary_emb does not support cu_seqlens"
        )
    if not isinstance(seqlen_offsets, int):
        raise NotImplementedError(
            "flash_attn shim: apply_rotary_emb supports only integer seqlen_offsets"
        )
    if seqlen_offsets < 0:
        raise ValueError(f"seqlen_offsets must be >= 0, got {seqlen_offsets}")

    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1], f"rotary dim {ro_dim} exceeds head dim {x.shape[-1]}"

    # Select the rows of the tables that correspond to this chunk of the
    # sequence. When the tables already match the sequence length and the
    # offset is zero this is skipped entirely.
    if x.dim() >= 3 and cos.dim() >= 2:
        seqlen = x.shape[-3]
        if seqlen_offsets or cos.shape[-2] > seqlen:
            cos = cos[..., seqlen_offsets:seqlen_offsets + seqlen, :]
            sin = sin[..., seqlen_offsets:seqlen_offsets + seqlen, :]

    # Broadcast cos/sin from (..., rotary_dim/2) up to (..., 1, rotary_dim)
    if interleaved:
        cos = cos.unsqueeze(-2).repeat_interleave(2, dim=-1)
        sin = sin.unsqueeze(-2).repeat_interleave(2, dim=-1)
    else:
        cos = cos.unsqueeze(-2)
        sin = sin.unsqueeze(-2)
        cos = torch.cat([cos, cos], dim=-1)
        sin = torch.cat([sin, sin], dim=-1)

    x_rot = x[..., :ro_dim]
    x_pass = x[..., ro_dim:]
    out_rot = x_rot * cos + _rotate_half(x_rot, interleaved) * sin
    # The product may be promoted when cos/sin are higher precision than x;
    # cast back so the result keeps x's dtype.
    return torch.cat([out_rot.to(x.dtype), x_pass], dim=-1)
requirements.txt CHANGED
@@ -1,17 +1,18 @@
1
- torch==2.4.0
 
2
  transformers==4.51.3
3
  tokenizers==0.21.1
4
  sentencepiece==0.1.99
5
  pyarrow==18.0.0
6
  accelerate==1.1.0
7
- pydantic==2.8.2
8
  markdown2[all]
9
- numpy==1.24.3
10
  scikit-learn==1.2.2
11
  requests
12
  httpx
13
  uvicorn
14
- fastapi==0.112.4
15
  einops==0.6.1
16
  einops-exts==0.0.4
17
  timm==1.0.11
@@ -19,7 +20,7 @@ tiktoken
19
  transformers_stream_generator==0.0.4
20
  scipy
21
  pandas
22
- torchaudio
23
  xformers
24
  pillow==10.3.0
25
  pysubs2==1.7.2
 
1
+ torch==2.10.0
2
+ torchvision==0.25.0
3
  transformers==4.51.3
4
  tokenizers==0.21.1
5
  sentencepiece==0.1.99
6
  pyarrow==18.0.0
7
  accelerate==1.1.0
8
+ pydantic
9
  markdown2[all]
10
+ numpy<2
11
  scikit-learn==1.2.2
12
  requests
13
  httpx
14
  uvicorn
15
+ fastapi
16
  einops==0.6.1
17
  einops-exts==0.0.4
18
  timm==1.0.11
 
20
  transformers_stream_generator==0.0.4
21
  scipy
22
  pandas
23
+ torchaudio==2.10.0
24
  xformers
25
  pillow==10.3.0
26
  pysubs2==1.7.2