"""Torch-native equivalent of flash_attn.flash_attn_varlen_func. Only the forward path used by AIMv2's vision tower is implemented: flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) q, k, v are packed (total_tokens, num_heads, head_dim) tensors and the function returns the same packed (total_tokens, num_heads, head_dim) shape after applying self-attention per sub-sequence as encoded by `cu_seqlens_q == cu_seqlens_k`. """ import torch import torch.nn.functional as F def flash_attn_varlen_func( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** -0.5 cu_q = cu_seqlens_q.tolist() cu_k = cu_seqlens_k.tolist() out_chunks = [] for i in range(len(cu_q) - 1): sq, eq = cu_q[i], cu_q[i + 1] sk, ek = cu_k[i], cu_k[i + 1] q_i = q[sq:eq].transpose(0, 1).unsqueeze(0) # (1, H, Lq, D) k_i = k[sk:ek].transpose(0, 1).unsqueeze(0) v_i = v[sk:ek].transpose(0, 1).unsqueeze(0) o_i = F.scaled_dot_product_attention( q_i, k_i, v_i, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale, ) out_chunks.append(o_i.squeeze(0).transpose(0, 1)) # (Lq, H, D) out = torch.cat(out_chunks, dim=0) return out