Spaces:
Running on Zero
Running on Zero
| """Torch-native equivalent of flash_attn.flash_attn_varlen_func. | |
| Only the forward path used by AIMv2's vision tower is implemented: | |
| flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) | |
| q, k, v are packed (total_tokens, num_heads, head_dim) tensors and the function | |
| returns the same packed (total_tokens, num_heads, head_dim) shape after applying | |
| self-attention per sub-sequence as encoded by `cu_seqlens_q == cu_seqlens_k`. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| def flash_attn_varlen_func( | |
| q, | |
| k, | |
| v, | |
| cu_seqlens_q, | |
| cu_seqlens_k, | |
| max_seqlen_q, | |
| max_seqlen_k, | |
| dropout_p=0.0, | |
| softmax_scale=None, | |
| causal=False, | |
| window_size=(-1, -1), | |
| softcap=0.0, | |
| alibi_slopes=None, | |
| deterministic=False, | |
| return_attn_probs=False, | |
| ): | |
| if softmax_scale is None: | |
| softmax_scale = q.shape[-1] ** -0.5 | |
| cu_q = cu_seqlens_q.tolist() | |
| cu_k = cu_seqlens_k.tolist() | |
| out_chunks = [] | |
| for i in range(len(cu_q) - 1): | |
| sq, eq = cu_q[i], cu_q[i + 1] | |
| sk, ek = cu_k[i], cu_k[i + 1] | |
| q_i = q[sq:eq].transpose(0, 1).unsqueeze(0) # (1, H, Lq, D) | |
| k_i = k[sk:ek].transpose(0, 1).unsqueeze(0) | |
| v_i = v[sk:ek].transpose(0, 1).unsqueeze(0) | |
| o_i = F.scaled_dot_product_attention( | |
| q_i, k_i, v_i, | |
| dropout_p=dropout_p, | |
| is_causal=causal, | |
| scale=softmax_scale, | |
| ) | |
| out_chunks.append(o_i.squeeze(0).transpose(0, 1)) # (Lq, H, D) | |
| out = torch.cat(out_chunks, dim=0) | |
| return out | |