Ovis-U1-3B / flash_attn /flash_attn_interface.py
multimodalart's picture
multimodalart HF Staff
[Admin maintenance] Support new ZeroGPU hardware (#4)
e49bb9d
"""Stub of flash_attn.flash_attn_interface.
Re-exports the torch-native ``flash_attn_varlen_func`` from the shim plus a
``flash_attn_func`` fallback that uses ``torch.nn.functional.scaled_dot_product_attention``
for the padded (batch, seqlen, nheads, headdim) call signature.
Also exposes a placeholder ``flash_attn_gpu`` object so xformers'
``hasattr(flash_attn.flash_attn_interface, "flash_attn_gpu")`` probe in
``xformers/ops/fmha/flash.py`` succeeds. The backend xformers registers from
this probe is never invoked along the demo's user-facing path.
"""
import torch
import torch.nn.functional as F
from .funcs import flash_attn_varlen_func # noqa: F401
class _UnavailableBackend:
    """Placeholder for a compiled flash-attn kernel module that is not installed.

    Accessing any regular attribute (e.g. ``backend.fwd``) raises
    ``RuntimeError`` right away, so a code path that truly needs the CUDA
    kernels fails loudly instead of misbehaving.

    Dunder lookups (``__deepcopy__``, ``__wrapped__``, ...) raise
    ``AttributeError`` instead. Python's own protocols (``copy.deepcopy``,
    ``getattr(obj, name, default)``, ``hasattr``) only treat
    ``AttributeError`` as "attribute missing"; any other exception escaping
    from them would crash unrelated introspection of this module.
    """

    def __getattr__(self, name):
        # Only reached for names that normal attribute lookup did not find.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        raise RuntimeError(
            "flash_attn shim: real CUDA backend is not installed. "
            "The demo's user-facing path should not need it."
        )
# Two independent placeholder objects standing in for the compiled kernel
# modules. The module docstring notes that xformers probes for
# ``flash_attn_gpu``; ``flash_attn_cuda`` presumably covers callers that use
# the older name (verify against consumers).
flash_attn_gpu, flash_attn_cuda = _UnavailableBackend(), _UnavailableBackend()
def flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""Padded attention. q/k/v shape: (B, L, H, D). Returns (B, L, H, D)."""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
# SDPA expects (B, H, L, D)
q_t = q.transpose(1, 2)
k_t = k.transpose(1, 2)
v_t = v.transpose(1, 2)
out = F.scaled_dot_product_attention(
q_t, k_t, v_t,
dropout_p=dropout_p,
is_causal=causal,
scale=softmax_scale,
)
return out.transpose(1, 2)
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Run ``flash_attn_func`` on a packed QKV tensor.

    ``qkv`` has shape (B, L, 3, H, D). Slots 0/1/2 along dim 2 hold the
    query, key and value. Every other argument is forwarded unchanged.
    """
    forwarded = {
        "dropout_p": dropout_p,
        "softmax_scale": softmax_scale,
        "causal": causal,
        "window_size": window_size,
        "softcap": softcap,
        "alibi_slopes": alibi_slopes,
        "deterministic": deterministic,
        "return_attn_probs": return_attn_probs,
    }
    query, key, value = qkv.unbind(2)
    return flash_attn_func(query, key, value, **forwarded)
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Run ``flash_attn_func`` with keys and values packed into one tensor.

    ``q`` has shape (B, Lq, H, D) and ``kv`` has shape (B, Lk, 2, H, D).
    Slot 0 along dim 2 of ``kv`` is the key and slot 1 is the value. Every
    other argument is forwarded unchanged.
    """
    forwarded = {
        "dropout_p": dropout_p,
        "softmax_scale": softmax_scale,
        "causal": causal,
        "window_size": window_size,
        "softcap": softcap,
        "alibi_slopes": alibi_slopes,
        "deterministic": deterministic,
        "return_attn_probs": return_attn_probs,
    }
    key, value = kv.unbind(2)
    return flash_attn_func(q, key, value, **forwarded)