Spaces:
Running on Zero
Running on Zero
| """Stub of flash_attn.flash_attn_interface. | |
| Re-exports the torch-native ``flash_attn_varlen_func`` from the shim plus a | |
| ``flash_attn_func`` fallback that uses ``torch.nn.functional.scaled_dot_product_attention`` | |
| for the padded (batch, seqlen, nheads, headdim) call signature. | |
| Also exposes placeholder ``flash_attn_gpu`` / ``flash_attn_cuda`` objects so xformers' | |
| ``hasattr(flash_attn.flash_attn_interface, "flash_attn_gpu")`` probe in | |
| ``xformers/ops/fmha/flash.py`` succeeds. The backend xformers registers from | |
| this probe is never invoked along the demo's user-facing path. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from .funcs import flash_attn_varlen_func # noqa: F401 | |
class _UnavailableBackend:
    """Opaque placeholder for the compiled flash-attn extension module.

    Any ordinary attribute access raises ``RuntimeError`` with a clear message,
    because no real CUDA backend is installed behind this shim.

    Dunder lookups (``__deepcopy__``, ``__wrapped__``, ``__getstate__``, ...)
    raise ``AttributeError`` instead. ``hasattr``, ``getattr(obj, name, default)``
    and the ``copy`` machinery only treat ``AttributeError`` as "not present";
    any other exception type propagates and would crash harmless introspection
    of the placeholder.
    """

    def __getattr__(self, name):
        # __getattr__ is only reached after normal lookup has failed, so a
        # dunder name arriving here genuinely does not exist on this object.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        raise RuntimeError(
            "flash_attn shim: real CUDA backend is not installed. "
            "The demo's user-facing path should not need it."
        )


# Module-level placeholders: their mere presence satisfies attribute probes on
# this module; using them (any non-dunder attribute) raises RuntimeError.
flash_attn_gpu = _UnavailableBackend()
flash_attn_cuda = _UnavailableBackend()
| def flash_attn_func( | |
| q, | |
| k, | |
| v, | |
| dropout_p=0.0, | |
| softmax_scale=None, | |
| causal=False, | |
| window_size=(-1, -1), | |
| softcap=0.0, | |
| alibi_slopes=None, | |
| deterministic=False, | |
| return_attn_probs=False, | |
| ): | |
| """Padded attention. q/k/v shape: (B, L, H, D). Returns (B, L, H, D).""" | |
| if softmax_scale is None: | |
| softmax_scale = q.shape[-1] ** -0.5 | |
| # SDPA expects (B, H, L, D) | |
| q_t = q.transpose(1, 2) | |
| k_t = k.transpose(1, 2) | |
| v_t = v.transpose(1, 2) | |
| out = F.scaled_dot_product_attention( | |
| q_t, k_t, v_t, | |
| dropout_p=dropout_p, | |
| is_causal=causal, | |
| scale=softmax_scale, | |
| ) | |
| return out.transpose(1, 2) | |
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention over a packed QKV tensor of shape (B, L, 3, H, D).

    Splits ``qkv`` along dim 2 into query, key and value and forwards them,
    together with every option unchanged, to ``flash_attn_func``; whatever
    that call returns is returned as-is.
    """
    query, key, value = qkv.unbind(2)
    forwarded = dict(
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )
    return flash_attn_func(query, key, value, **forwarded)
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Attention with separate queries and a packed KV tensor.

    ``q`` has shape (B, Lq, H, D) and ``kv`` has shape (B, Lk, 2, H, D).
    ``kv`` is split along dim 2 into key and value, and the call is forwarded,
    with every option unchanged, to ``flash_attn_func``; whatever that call
    returns is returned as-is.
    """
    key, value = kv.unbind(2)
    forwarded = dict(
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )
    return flash_attn_func(q, key, value, **forwarded)