Ovis-U1-3B / flash_attn /funcs.py
multimodalart's picture
multimodalart HF Staff
[Admin maintenance] Support new ZeroGPU hardware (#4)
e49bb9d
"""Torch-native equivalent of flash_attn.flash_attn_varlen_func.
Only the forward path used by AIMv2's vision tower is implemented:
flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
q, k, v are packed (total_tokens, num_heads, head_dim) tensors. The function
returns a tensor of the same packed (total_tokens, num_heads, head_dim) shape
after applying attention independently within each sub-sequence delimited by
the cumulative offsets (`cu_seqlens_q == cu_seqlens_k` for self-attention).
"""
import torch
import torch.nn.functional as F
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
cu_q = cu_seqlens_q.tolist()
cu_k = cu_seqlens_k.tolist()
out_chunks = []
for i in range(len(cu_q) - 1):
sq, eq = cu_q[i], cu_q[i + 1]
sk, ek = cu_k[i], cu_k[i + 1]
q_i = q[sq:eq].transpose(0, 1).unsqueeze(0) # (1, H, Lq, D)
k_i = k[sk:ek].transpose(0, 1).unsqueeze(0)
v_i = v[sk:ek].transpose(0, 1).unsqueeze(0)
o_i = F.scaled_dot_product_attention(
q_i, k_i, v_i,
dropout_p=dropout_p,
is_causal=causal,
scale=softmax_scale,
)
out_chunks.append(o_i.squeeze(0).transpose(0, 1)) # (Lq, H, D)
out = torch.cat(out_chunks, dim=0)
return out