File size: 1,574 Bytes
e49bb9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Torch-native equivalent of flash_attn.flash_attn_varlen_func.

Only the forward path used by AIMv2's vision tower is implemented:
    flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)

q, k, v are packed (total_tokens, num_heads, head_dim) tensors, and the function
returns the same packed (total_tokens, num_heads, head_dim) shape after applying
attention independently within each sub-sequence delimited by `cu_seqlens_q`
and `cu_seqlens_k` (for AIMv2's self-attention the two are identical).
"""

import torch
import torch.nn.functional as F


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5

    cu_q = cu_seqlens_q.tolist()
    cu_k = cu_seqlens_k.tolist()

    out_chunks = []
    for i in range(len(cu_q) - 1):
        sq, eq = cu_q[i], cu_q[i + 1]
        sk, ek = cu_k[i], cu_k[i + 1]
        q_i = q[sq:eq].transpose(0, 1).unsqueeze(0)  # (1, H, Lq, D)
        k_i = k[sk:ek].transpose(0, 1).unsqueeze(0)
        v_i = v[sk:ek].transpose(0, 1).unsqueeze(0)
        o_i = F.scaled_dot_product_attention(
            q_i, k_i, v_i,
            dropout_p=dropout_p,
            is_causal=causal,
            scale=softmax_scale,
        )
        out_chunks.append(o_i.squeeze(0).transpose(0, 1))  # (Lq, H, D)

    out = torch.cat(out_chunks, dim=0)
    return out