"""Stub of ``flash_attn.flash_attn_interface``.

Re-exports the torch-native ``flash_attn_varlen_func`` from the shim plus a
``flash_attn_func`` fallback that uses
``torch.nn.functional.scaled_dot_product_attention`` for the padded
``(batch, seqlen, nheads, headdim)`` call signature.

Also exposes placeholder ``flash_attn_gpu`` / ``flash_attn_cuda`` objects so
xformers' ``hasattr(flash_attn.flash_attn_interface, "flash_attn_gpu")``
probe in ``xformers/ops/fmha/flash.py`` succeeds. The backend xformers
registers from this probe is never invoked along the demo's user-facing path.
"""

import torch
import torch.nn.functional as F

from .funcs import flash_attn_varlen_func  # noqa: F401


class _UnavailableBackend:
    """Opaque stand-in for the compiled flash-attention extension module.

    Looking up an ordinary attribute succeeds and yields a stub; *calling*
    that stub raises ``RuntimeError`` with a clear message. Dunder lookups
    raise ``AttributeError`` so protocol probes such as
    ``getattr(obj, "__deepcopy__", None)`` or ``hasattr`` (used by ``copy``,
    ``pickle``, ``inspect``) fall back normally instead of crashing.
    """

    def __getattr__(self, name):
        # Only reached for attributes not found by normal lookup. Special
        # names must fail the standard way or protocol probes blow up.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        def _unavailable(*args, **kwargs):
            raise RuntimeError(
                "flash_attn shim: real CUDA backend is not installed. "
                "The demo's user-facing path should not need it."
            )

        _unavailable.__name__ = name
        return _unavailable

    def __repr__(self):
        return f"<{type(self).__name__}: flash_attn CUDA backend not installed>"


flash_attn_gpu = _UnavailableBackend()
flash_attn_cuda = _UnavailableBackend()


def _key_offsets(seqlen_q, seqlen_k, device):
    """Return the (seqlen_q, seqlen_k) signed key offsets ``j - (i + seqlen_k - seqlen_q)``.

    Entry ``[i, j]`` is how far key ``j`` sits from query ``i``'s diagonal
    when the diagonal is aligned to the bottom-right corner of the score
    matrix, which is the alignment flash_attn uses for causal/local masks
    and ALiBi distances.
    """
    rows = torch.arange(seqlen_q, device=device).unsqueeze(1)
    cols = torch.arange(seqlen_k, device=device).unsqueeze(0)
    return cols - rows - (seqlen_k - seqlen_q)


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Padded attention with flash_attn's signature, computed with torch ops.

    Args:
        q: (batch, seqlen_q, nheads, headdim).
        k, v: (batch, seqlen_k, nheads_k, headdim). ``nheads`` must be a
            multiple of ``nheads_k`` (MQA/GQA); KV heads are repeated so
            query head ``h`` uses KV head ``h // (nheads // nheads_k)``.
        dropout_p: dropout probability on the attention weights; applied
            whenever > 0 (there is no train/eval switch, as with SDPA).
        softmax_scale: score scale; defaults to ``headdim ** -0.5``.
        causal: causal masking. As in flash_attn, when
            ``seqlen_q != seqlen_k`` the mask is aligned to the bottom-right
            corner of the score matrix. Forces ``window_size[1] = 0``.
        window_size: ``(left, right)``; query ``i`` attends keys ``j`` with
            ``i + seqlen_k - seqlen_q - left <= j <= i + seqlen_k - seqlen_q + right``.
            A negative bound means unbounded on that side.
        softcap: if > 0, scaled scores become ``softcap * tanh(s / softcap)``.
        alibi_slopes: optional (nheads,) or (batch, nheads) tensor; adds
            ``-slope * |i + seqlen_k - seqlen_q - j|`` to the scaled scores.
        deterministic: accepted for compatibility; has no effect here.
        return_attn_probs: not supported; must be False.

    Returns:
        (batch, seqlen_q, nheads, headdim). Query rows that can see no key
        (possible with causal/local masks) are all zeros, matching flash_attn.

    Raises:
        NotImplementedError: if ``return_attn_probs`` is True.
        ValueError: if ``nheads`` is not a multiple of ``nheads_k``.
    """
    if return_attn_probs:
        raise NotImplementedError(
            "flash_attn shim: return_attn_probs=True is not supported."
        )
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5

    seqlen_q, nheads = q.shape[1], q.shape[2]
    seqlen_k, nheads_k = k.shape[1], k.shape[2]

    # MQA/GQA: SDPA needs matching head counts, so expand the KV heads.
    if nheads != nheads_k:
        if nheads_k == 0 or nheads % nheads_k:
            raise ValueError(
                f"flash_attn shim: nheads ({nheads}) must be a multiple of "
                f"nheads_k ({nheads_k})."
            )
        rep = nheads // nheads_k
        k = k.repeat_interleave(rep, dim=2)
        v = v.repeat_interleave(rep, dim=2)

    left, right = window_size
    if causal:
        right = 0
    windowed = left >= 0 or right >= 0

    # SDPA expects (B, H, L, D)
    q_t = q.transpose(1, 2)
    k_t = k.transpose(1, 2)
    v_t = v.transpose(1, 2)

    # Fast path: no mask at all, or plain causal on a square score matrix
    # (SDPA's is_causal is top-left aligned, which equals flash_attn's
    # bottom-right alignment only when seqlen_q == seqlen_k).
    plain = alibi_slopes is None and softcap <= 0
    if plain and (
        not windowed or (left < 0 and right == 0 and seqlen_q == seqlen_k)
    ):
        out = F.scaled_dot_product_attention(
            q_t,
            k_t,
            v_t,
            dropout_p=dropout_p,
            is_causal=windowed,
            scale=softmax_scale,
        )
        return out.transpose(1, 2)

    # General path: build an explicit float32 additive bias.
    offsets = _key_offsets(seqlen_q, seqlen_k, q.device)
    bias = torch.zeros(seqlen_q, seqlen_k, dtype=torch.float32, device=q.device)
    if alibi_slopes is not None:
        slopes = alibi_slopes.to(device=q.device, dtype=torch.float32)
        # (nheads,) -> (1, H, 1, 1); (batch, nheads) -> (B, H, 1, 1)
        bias = bias - slopes.reshape(-1, nheads, 1, 1) * offsets.abs()

    empty_rows = None
    if windowed:
        allowed = torch.ones_like(offsets, dtype=torch.bool)
        if left >= 0:
            allowed &= offsets >= -left
        if right >= 0:
            allowed &= offsets <= right
        # Rows with no visible key would softmax over all -inf (NaN).
        # flash_attn returns zeros there, so unmask them to keep the math
        # finite and zero their outputs afterwards.
        empty_rows = ~allowed.any(dim=-1, keepdim=True)  # (Lq, 1)
        allowed |= empty_rows
        bias = bias.masked_fill(~allowed, float("-inf"))

    if softcap > 0:
        # SDPA cannot apply a nonlinearity to the scores; compute by hand.
        scores = torch.matmul(q_t.float(), k_t.float().transpose(-2, -1))
        scores = softcap * torch.tanh(scores * softmax_scale / softcap) + bias
        probs = torch.softmax(scores, dim=-1)
        if dropout_p > 0:
            probs = F.dropout(probs, p=dropout_p)
        out = torch.matmul(probs.to(v_t.dtype), v_t)
    else:
        out = F.scaled_dot_product_attention(
            q_t,
            k_t,
            v_t,
            attn_mask=bias.to(q_t.dtype),
            dropout_p=dropout_p,
            scale=softmax_scale,
        )

    if empty_rows is not None:
        out = out.masked_fill(empty_rows, 0.0)
    return out.transpose(1, 2)


def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention over a packed tensor; see ``flash_attn_func``.

    ``qkv`` has shape (B, L, 3, H, D); it is split along dim 2 into q, k, v.
    Returns (B, L, H, D).
    """
    q, k, v = qkv.unbind(dim=2)
    return flash_attn_func(
        q,
        k,
        v,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )


def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention with packed keys/values; see ``flash_attn_func``.

    ``q`` has shape (B, Lq, H, D) and ``kv`` has shape (B, Lk, 2, H_k, D);
    ``kv`` is split along dim 2 into k, v. Returns (B, Lq, H, D).
    """
    k, v = kv.unbind(dim=2)
    return flash_attn_func(
        q,
        k,
        v,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )