"""Torch-native equivalent of flash_attn.flash_attn_varlen_func.
Only the forward path used by AIMv2's vision tower is implemented:
flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
q, k, v are packed (total_tokens, num_heads, head_dim) tensors and the function
returns the same packed (total_tokens, num_heads, head_dim) shape after applying
self-attention per sub-sequence as encoded by `cu_seqlens_q == cu_seqlens_k`.
"""
import torch
import torch.nn.functional as F
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
cu_q = cu_seqlens_q.tolist()
cu_k = cu_seqlens_k.tolist()
out_chunks = []
for i in range(len(cu_q) - 1):
sq, eq = cu_q[i], cu_q[i + 1]
sk, ek = cu_k[i], cu_k[i + 1]
q_i = q[sq:eq].transpose(0, 1).unsqueeze(0) # (1, H, Lq, D)
k_i = k[sk:ek].transpose(0, 1).unsqueeze(0)
v_i = v[sk:ek].transpose(0, 1).unsqueeze(0)
o_i = F.scaled_dot_product_attention(
q_i, k_i, v_i,
dropout_p=dropout_p,
is_causal=causal,
scale=softmax_scale,
)
out_chunks.append(o_i.squeeze(0).transpose(0, 1)) # (Lq, H, D)
out = torch.cat(out_chunks, dim=0)
return out
|