DeepBeepMeep commited on
Commit
370ee56
·
1 Parent(s): 26c29b6

This UI color is the good one + slightly reduced VRAM when using Sage2 attention

Browse files
gradio_server.py CHANGED
@@ -1595,7 +1595,7 @@ def create_demo():
1595
  }
1596
  """
1597
  default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
1598
- with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="stone", neutral_hue="slate", text_size= "md")) as demo:
1599
  state_dict = {}
1600
 
1601
  if use_image2video:
 
1595
  }
1596
  """
1597
  default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
1598
+ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size= "md")) as demo:
1599
  state_dict = {}
1600
 
1601
  if use_image2video:
wan/modules/attention.py CHANGED
@@ -38,27 +38,35 @@ import warnings
38
 
39
  try:
40
  from sageattention import sageattn
41
- @torch.compiler.disable()
42
- def sageattn_wrapper(
43
- qkv_list,
44
- attention_length
45
- ):
46
- q,k, v = qkv_list
47
- padding_length = q.shape[0] -attention_length
48
- q = q[:attention_length, :, : ].unsqueeze(0)
49
- k = k[:attention_length, :, : ].unsqueeze(0)
50
- v = v[:attention_length, :, : ].unsqueeze(0)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
53
  del q, k ,v
54
- qkv_list.clear()
55
 
56
- if padding_length > 0:
57
- o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
58
 
59
- return o
60
- except ImportError:
61
- sageattn = None
 
62
 
63
  # # try:
64
  # if True:
 
38
 
39
  try:
40
  from sageattention import sageattn
41
+ from .sage2_core import sageattn as alt_sageattn
42
+ except ImportError:
43
+ sageattn = None
44
+ alt_sageattn = None
 
 
 
 
 
 
45
 
46
+ # @torch.compiler.disable()
47
+ def sageattn_wrapper(
48
+ qkv_list,
49
+ attention_length
50
+ ):
51
+ q,k, v = qkv_list
52
+ padding_length = q.shape[0] -attention_length
53
+ q = q[:attention_length, :, : ].unsqueeze(0)
54
+ k = k[:attention_length, :, : ].unsqueeze(0)
55
+ v = v[:attention_length, :, : ].unsqueeze(0)
56
+ if True:
57
+ qkv_list = [q,k,v]
58
+ del q, k ,v
59
+ o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0)
60
+ else:
61
  o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
62
  del q, k ,v
 
63
 
64
+ qkv_list.clear()
 
65
 
66
+ if padding_length > 0:
67
+ o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
68
+
69
+ return o
70
 
71
  # # try:
72
  # if True:
wan/modules/sage2_core.py ADDED
@@ -0,0 +1,1094 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ from sageattention.triton.quant_per_block import per_block_int8 as per_block_int8_triton
21
+ from sageattention.triton.quant_per_block_varlen import per_block_int8 as per_block_int8_varlen_triton
22
+ from sageattention.triton.attn_qk_int8_per_block import forward as attn_false
23
+ from sageattention.triton.attn_qk_int8_per_block_causal import forward as attn_true
24
+ from sageattention.triton.attn_qk_int8_block_varlen import forward as attn_false_varlen
25
+ from sageattention.triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
26
+
27
+ from sageattention.triton.quant_per_thread import per_thread_int8 as per_thread_int8_triton
28
+
29
+ try:
30
+ from sageattention import _qattn_sm80
31
+ SM80_ENABLED = True
32
+ except:
33
+ SM80_ENABLED = False
34
+
35
+ try:
36
+ from sageattention import _qattn_sm89
37
+ SM89_ENABLED = True
38
+ except:
39
+ SM89_ENABLED = False
40
+
41
+ try:
42
+ from sageattention import _qattn_sm90
43
+ SM90_ENABLED = True
44
+ except:
45
+ SM90_ENABLED = False
46
+
47
+ from sageattention.quant import per_block_int8 as per_block_int8_cuda
48
+ from sageattention.quant import per_warp_int8 as per_warp_int8_cuda
49
+ from sageattention.quant import sub_mean
50
+ from sageattention.quant import per_channel_fp8
51
+
52
+ from typing import Any, List, Literal, Optional, Tuple, Union
53
+ import warnings
54
+
55
+ def get_cuda_arch_versions():
56
+ cuda_archs = []
57
+ for i in range(torch.cuda.device_count()):
58
+ major, minor = torch.cuda.get_device_capability(i)
59
+ cuda_archs.append(f"sm{major}{minor}")
60
+ return cuda_archs
61
+
62
+ def sageattn(
63
+ qkv_list,
64
+ tensor_layout: str = "HND",
65
+ is_causal: bool = False,
66
+ sm_scale: Optional[float] = None,
67
+ return_lse: bool = False,
68
+ **kwargs: Any,
69
+ ):
70
+ """
71
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
72
+
73
+ Parameters
74
+ ----------
75
+ q : torch.Tensor
76
+ The query tensor. Shape:
77
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
78
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
79
+
80
+ k : torch.Tensor
81
+ The key tensor. Shape:
82
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
83
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
84
+
85
+ v : torch.Tensor
86
+ The value tensor. Shape:
87
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
88
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
89
+
90
+ tensor_layout : str
91
+ The tensor layout, either "HND" or "NHD".
92
+ Default: "HND".
93
+
94
+ is_causal : bool
95
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
96
+ Default: False.
97
+
98
+ sm_scale : Optional[float]
99
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
100
+
101
+ return_lse : bool
102
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
103
+ Default: False.
104
+
105
+ Returns
106
+ -------
107
+ torch.Tensor
108
+ The output tensor. Shape:
109
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
110
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
111
+
112
+ torch.Tensor
113
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
114
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
115
+ Only returned if `return_lse` is True.
116
+
117
+ Note
118
+ ----
119
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
120
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
121
+ - All tensors must be on the same cuda device.
122
+ """
123
+
124
+ arch = get_cuda_arch_versions()[qkv_list[0].device.index]
125
+ if arch == "sm80":
126
+ return sageattn_qk_int8_pv_fp16_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")
127
+ elif arch == "sm86":
128
+ return sageattn_qk_int8_pv_fp16_triton(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse)
129
+ elif arch == "sm89":
130
+ return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
131
+ elif arch == "sm90":
132
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
133
+ elif arch == "sm120":
134
+ return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
135
+ else:
136
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
137
+
138
+ @torch.compiler.disable
139
+ def sageattn_qk_int8_pv_fp16_triton(
140
+ q: torch.Tensor,
141
+ k: torch.Tensor,
142
+ v: torch.Tensor,
143
+ tensor_layout: str = "HND",
144
+ quantization_backend: str = "triton",
145
+ is_causal: bool =False,
146
+ sm_scale: Optional[float] = None,
147
+ smooth_k: bool = True,
148
+ return_lse: bool = False,
149
+ **kwargs: Any,
150
+ ) -> torch.Tensor:
151
+ """
152
+ SageAttention with per-block INT8 quantization for Q and K, FP16 PV with FP16 accumulation, implemented using Triton.
153
+ The FP16 accumulator is added to a FP32 buffer immediately after each iteration.
154
+
155
+ Parameters
156
+ ----------
157
+ q : torch.Tensor
158
+ The query tensor. Shape:
159
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
160
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
161
+
162
+ k : torch.Tensor
163
+ The key tensor. Shape:
164
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
165
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
166
+
167
+ v : torch.Tensor
168
+ The value tensor. Shape:
169
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
170
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
171
+
172
+ tensor_layout : str
173
+ The tensor layout, either "HND" or "NHD".
174
+ Default: "HND".
175
+
176
+ quantization_backend : str
177
+ The quantization backend, either "triton" or "cuda".
178
+ "cuda" backend offers better performance due to kernel fusion.
179
+
180
+ is_causal : bool
181
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
182
+ Default: False.
183
+
184
+ sm_scale : Optional[float]
185
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
186
+
187
+ smooth_k : bool
188
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
189
+ Default: True.
190
+
191
+ return_lse : bool
192
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
193
+ Default: False.
194
+
195
+ Returns
196
+ -------
197
+ torch.Tensor
198
+ The output tensor. Shape:
199
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
200
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
201
+
202
+ torch.Tensor
203
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
204
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
205
+ Only returned if `return_lse` is True.
206
+
207
+ Note
208
+ ----
209
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
210
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``.
211
+ - All tensors must be on the same cuda device.
212
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
213
+ """
214
+
215
+ dtype = q.dtype
216
+ assert q.is_cuda, "Input tensors must be on cuda."
217
+ assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
218
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
219
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
220
+
221
+ # FIXME(DefTruth): make sage attention work compatible with distributed
222
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
223
+ # sage attention will run into illegal memory access error after first
224
+ # inference step in distributed env for multi gpus inference. This small
225
+ # workaround also make sage attention work compatible with torch.compile
226
+ # through non-fullgraph compile mode.
227
+ torch.cuda.set_device(v.device)
228
+
229
+ head_dim_og = q.size(-1)
230
+
231
+ if head_dim_og < 64:
232
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
233
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
234
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
235
+ elif head_dim_og > 64 and head_dim_og < 128:
236
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
237
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
238
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
239
+ elif head_dim_og > 128:
240
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
241
+
242
+ # assert last dim is contiguous
243
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
244
+
245
+ seq_dim = 1 if tensor_layout == "NHD" else 2
246
+
247
+ if smooth_k:
248
+ km = k.mean(dim=seq_dim, keepdim=True)
249
+ if return_lse:
250
+ if tensor_layout == "NHD":
251
+ lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
252
+ else:
253
+ lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
254
+ else:
255
+ km = None
256
+
257
+ if dtype == torch.bfloat16 or dtype == torch.float32:
258
+ v = v.to(torch.float16)
259
+
260
+ if sm_scale is None:
261
+ sm_scale = 1.0 / (head_dim_og ** 0.5)
262
+
263
+ if quantization_backend == "triton":
264
+ q_int8, q_scale, k_int8, k_scale = per_block_int8_triton(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
265
+ elif quantization_backend == "cuda":
266
+ q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
267
+ else:
268
+ raise ValueError(f"Unsupported quantization backend: {quantization_backend}")
269
+ if is_causal:
270
+ o, lse = attn_true(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
271
+ else:
272
+ o, lse = attn_false(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
273
+
274
+ o = o[..., :head_dim_og]
275
+
276
+ if return_lse:
277
+ return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
278
+ else:
279
+ return o
280
+
281
+ @torch.compiler.disable
282
+ def sageattn_varlen(
283
+ q: torch.Tensor,
284
+ k: torch.Tensor,
285
+ v: torch.Tensor,
286
+ cu_seqlens_q: torch.Tensor,
287
+ cu_seqlens_k: torch.Tensor,
288
+ max_seqlen_q: int,
289
+ max_seqlen_k: int,
290
+ is_causal: bool = False,
291
+ sm_scale: Optional[float] = None,
292
+ smooth_k: bool = True,
293
+ **kwargs: Any,
294
+ ) -> torch.Tensor:
295
+ """
296
+
297
+ Parameters
298
+ ----------
299
+ q : torch.Tensor
300
+ The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
301
+
302
+ k : torch.Tensor
303
+ The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
304
+
305
+ v : torch.Tensor
306
+ The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
307
+
308
+ cu_seqlens_q : torch.Tensor
309
+ The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
310
+ Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
311
+
312
+ cu_seqlens_k : torch.Tensor
313
+ The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
314
+ Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
315
+
316
+ max_seqlen_q : int
317
+ The maximum sequence length for the query tensor in the batch.
318
+
319
+ max_seqlen_k : int
320
+ The maximum sequence length for the key and value tensors in the batch.
321
+
322
+ is_causal : bool
323
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence.
324
+ Default: False.
325
+
326
+ sm_scale : Optional[float]
327
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
328
+
329
+ smooth_k : bool
330
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
331
+ Default: True.
332
+
333
+ Returns
334
+ -------
335
+ torch.Tensor
336
+ The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
337
+
338
+ Note
339
+ ----
340
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
341
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``.
342
+ - The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
343
+ - All tensors must be on the same cuda device.
344
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
345
+ """
346
+
347
+ dtype = q.dtype
348
+ assert q.is_cuda, "Input tensors must be on cuda."
349
+ assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
350
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
351
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
352
+
353
+ # FIXME(DefTruth): make sage attention work compatible with distributed
354
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
355
+ # sage attention will run into illegal memory access error after first
356
+ # inference step in distributed env for multi gpus inference. This small
357
+ # workaround also make sage attention work compatible with torch.compile
358
+ # through non-fullgraph compile mode.
359
+ torch.cuda.set_device(v.device)
360
+
361
+ head_dim_og = q.size(-1)
362
+
363
+ if head_dim_og < 64:
364
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
365
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
366
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
367
+ elif head_dim_og > 64 and head_dim_og < 128:
368
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
369
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
370
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
371
+ elif head_dim_og > 128:
372
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
373
+
374
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
375
+ assert cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous(), "cu_seqlens_q and cu_seqlens_k must be contiguous."
376
+
377
+ if dtype == torch.bfloat16 or dtype == torch.float32:
378
+ v = v.to(torch.float16)
379
+
380
+ if smooth_k:
381
+ km = k.mean(dim=0, keepdim=True) # ! km is calculated on the all the batches. Calculate over each individual sequence requires dedicated kernel.
382
+ k = k - km
383
+
384
+ if sm_scale is None:
385
+ sm_scale = 1.0 / (head_dim_og ** 0.5)
386
+
387
+ q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale = per_block_int8_varlen_triton(q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale)
388
+
389
+ if is_causal:
390
+ o = attn_true_varlen(q_int8, k_int8, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, output_dtype=dtype)
391
+ else:
392
+ o = attn_false_varlen(q_int8, k_int8, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, output_dtype=dtype)
393
+
394
+ o = o[..., :head_dim_og]
395
+
396
+ return o
397
+
398
+ @torch.compiler.disable
399
+ def sageattn_qk_int8_pv_fp16_cuda(
400
+ qkv_list,
401
+ # q: torch.Tensor,
402
+ # k: torch.Tensor,
403
+ # v: torch.Tensor,
404
+ tensor_layout: str = "HND",
405
+ is_causal: bool = False,
406
+ qk_quant_gran: str = "per_thread",
407
+ sm_scale: Optional[float] = None,
408
+ pv_accum_dtype: str = "fp32",
409
+ smooth_k: bool = True,
410
+ smooth_v: bool = False,
411
+ return_lse: bool = False,
412
+ **kwargs: Any,
413
+ ) -> torch.Tensor:
414
+ """
415
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
416
+
417
+ Parameters
418
+ ----------
419
+ q : torch.Tensor
420
+ The query tensor. Shape:
421
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
422
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
423
+
424
+ k : torch.Tensor
425
+ The key tensor. Shape:
426
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
427
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
428
+
429
+ v : torch.Tensor
430
+ The value tensor. Shape:
431
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
432
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
433
+
434
+ tensor_layout : str
435
+ The tensor layout, either "HND" or "NHD".
436
+ Default: "HND".
437
+
438
+ is_causal : bool
439
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
440
+ Default: False.
441
+
442
+ qk_quant_gran : str
443
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
444
+ Default: "per_thread".
445
+
446
+ sm_scale : Optional[float]
447
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
448
+
449
+ pv_accum_dtype : str
450
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
451
+ - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b).
452
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
453
+ - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
454
+ Default: "fp32".
455
+
456
+ smooth_k : bool
457
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
458
+ Default: True.
459
+
460
+ smooth_v : bool
461
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
462
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
463
+ Default: False.
464
+
465
+ return_lse : bool
466
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
467
+ Default: False.
468
+
469
+ Returns
470
+ -------
471
+ torch.Tensor
472
+ The output tensor. Shape:
473
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
474
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
475
+
476
+ torch.Tensor
477
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
478
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
479
+ Only returned if `return_lse` is True.
480
+
481
+ Note
482
+ ----
483
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
484
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
485
+ - All tensors must be on the same cuda device.
486
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
487
+ """
488
+ q,k,v = qkv_list
489
+ qkv_list.clear()
490
+ dtype = q.dtype
491
+ assert SM80_ENABLED, "SM80 kernel is not available. make sure you GPUs with compute capability 8.0 or higher."
492
+ assert q.is_cuda, "Input tensors must be on cuda."
493
+ assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
494
+ assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
495
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
496
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
497
+
498
+ # FIXME(DefTruth): make sage attention work compatible with distributed
499
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
500
+ # sage attention will run into illegal memory access error after first
501
+ # inference step in distributed env for multi gpus inference. This small
502
+ # workaround also make sage attention work compatible with torch.compile
503
+ # through non-fullgraph compile mode.
504
+ torch.cuda.set_device(v.device)
505
+
506
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
507
+ _is_caual = 1 if is_causal else 0
508
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
509
+ _return_lse = 1 if return_lse else 0
510
+
511
+ head_dim_og = q.size(-1)
512
+
513
+ if head_dim_og < 64:
514
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
515
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
516
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
517
+ elif head_dim_og > 64 and head_dim_og < 128:
518
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
519
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
520
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
521
+ elif head_dim_og > 128:
522
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
523
+
524
+ # assert last dim is contiguous
525
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
526
+
527
+ if sm_scale is None:
528
+ sm_scale = head_dim_og**-0.5
529
+
530
+ seq_dim = 1 if _tensor_layout == 0 else 2
531
+
532
+ if smooth_k:
533
+ km = k.mean(dim=seq_dim, keepdim=True)
534
+ if return_lse:
535
+ if tensor_layout == "NHD":
536
+ lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
537
+ else:
538
+ lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
539
+ else:
540
+ km = None
541
+
542
+ if qk_quant_gran == "per_warp":
543
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), BLKK=64)
544
+ elif qk_quant_gran == "per_thread":
545
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), BLKK=64, WARPK=64)
546
+
547
+ q_size = q.size()
548
+ q_device = q.device
549
+ del q,k, km
550
+ o = torch.empty(q_size, dtype=dtype, device=q_device)
551
+
552
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
553
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
554
+ smooth_v = False
555
+
556
+ if pv_accum_dtype == 'fp32':
557
+ v = v.to(torch.float16)
558
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
559
+ elif pv_accum_dtype == "fp16":
560
+ if smooth_v:
561
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
562
+ del v
563
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
564
+ else:
565
+ v = v.to(torch.float16)
566
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
567
+ elif pv_accum_dtype == "fp16+fp32":
568
+ v = v.to(torch.float16)
569
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
570
+ else:
571
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
572
+
573
+ o = o[..., :head_dim_og]
574
+
575
+ if return_lse:
576
+ return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
577
+ else:
578
+ return o
579
+
580
+ @torch.compiler.disable
581
+ def sageattn_qk_int8_pv_fp8_cuda(
582
+ qkv_list,
583
+ tensor_layout: str = "HND",
584
+ is_causal: bool = False,
585
+ qk_quant_gran: str = "per_thread",
586
+ sm_scale: Optional[float] = None,
587
+ pv_accum_dtype: str = "fp32+fp32",
588
+ smooth_k: bool = True,
589
+ smooth_v: bool = False,
590
+ return_lse: bool = False,
591
+ **kwargs: Any,
592
+ ) -> torch.Tensor:
593
+ """
594
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
595
+
596
+ Parameters
597
+ ----------
598
+ q : torch.Tensor
599
+ The query tensor. Shape:
600
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
601
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
602
+
603
+ k : torch.Tensor
604
+ The key tensor. Shape:
605
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
606
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
607
+
608
+ v : torch.Tensor
609
+ The value tensor. Shape:
610
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
611
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
612
+
613
+ tensor_layout : str
614
+ The tensor layout, either "HND" or "NHD".
615
+ Default: "HND".
616
+
617
+ is_causal : bool
618
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
619
+ Default: False.
620
+
621
+ qk_quant_gran : str
622
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
623
+ Default: "per_thread".
624
+
625
+ sm_scale : Optional[float]
626
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
627
+
628
+ pv_accum_dtype : str
629
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
630
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
631
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
632
+ Default: "fp32+fp32".
633
+
634
+ smooth_k : bool
635
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
636
+ Default: True.
637
+
638
+ smooth_v : bool
639
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
640
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
641
+ Default: False.
642
+
643
+ return_lse : bool
644
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
645
+ Default: False.
646
+
647
+ Returns
648
+ -------
649
+ torch.Tensor
650
+ The output tensor. Shape:
651
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
652
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
653
+
654
+ torch.Tensor
655
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
656
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
657
+ Only returned if `return_lse` is True.
658
+
659
+ Note
660
+ ----
661
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
662
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
663
+ - All tensors must be on the same cuda device.
664
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
665
+ """
666
+ q, k, v = qkv_list
667
+ qkv_list.clear()
668
+
669
+ dtype = q.dtype
670
+ assert SM89_ENABLED, "SM89 kernel is not available. Make sure you GPUs with compute capability 8.9."
671
+ assert q.is_cuda, "Input tensors must be on cuda."
672
+ assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
673
+ assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
674
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
675
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
676
+
677
+ # FIXME(DefTruth): make sage attention work compatible with distributed
678
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
679
+ # sage attention will run into illegal memory access error after first
680
+ # inference step in distributed env for multi gpus inference. This small
681
+ # workaround also make sage attention work compatible with torch.compile
682
+ # through non-fullgraph compile mode.
683
+ torch.cuda.set_device(v.device)
684
+
685
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
686
+ _is_caual = 1 if is_causal else 0
687
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
688
+ _return_lse = 1 if return_lse else 0
689
+
690
+ head_dim_og = q.size(-1)
691
+
692
+ if head_dim_og < 64:
693
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
694
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
695
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
696
+ elif head_dim_og > 64 and head_dim_og < 128:
697
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
698
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
699
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
700
+ elif head_dim_og > 128:
701
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
702
+
703
+ # assert last dim is contiguous
704
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
705
+
706
+ if sm_scale is None:
707
+ sm_scale = head_dim_og**-0.5
708
+
709
+ seq_dim = 1 if _tensor_layout == 0 else 2
710
+
711
+ if smooth_k:
712
+ km = k.mean(dim=seq_dim, keepdim=True)
713
+ if return_lse:
714
+ if tensor_layout == "NHD":
715
+ lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
716
+ else:
717
+ lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
718
+ else:
719
+ km = None
720
+
721
+ if qk_quant_gran == "per_warp":
722
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64)
723
+ elif qk_quant_gran == "per_thread":
724
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)
725
+ q_size = q.size()
726
+ q_device = q.device
727
+ del q,k,km
728
+
729
+ if pv_accum_dtype == 'fp32+fp32' and smooth_v:
730
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
731
+ smooth_v = False
732
+
733
+ v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v)
734
+ del v
735
+ o = torch.empty(q_size, dtype=dtype, device=q_device)
736
+ if pv_accum_dtype == "fp32":
737
+ if smooth_v:
738
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
739
+ else:
740
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
741
+ elif pv_accum_dtype == "fp32+fp32":
742
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
743
+
744
+ o = o[..., :head_dim_og]
745
+
746
+ if return_lse:
747
+ return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
748
+ else:
749
+ return o
750
+
751
+
752
+ @torch.compiler.disable
753
+ def sageattn_qk_int8_pv_fp8_window_cuda(
754
+ qkv_list,
755
+ # q: torch.Tensor,
756
+ # k: torch.Tensor,
757
+ # v: torch.Tensor,
758
+ tensor_layout: str = "HND",
759
+ is_causal: bool = False,
760
+ qk_quant_gran: str = "per_thread",
761
+ sm_scale: Optional[float] = None,
762
+ pv_accum_dtype: str = "fp32+fp32",
763
+ smooth_k: bool = True,
764
+ smooth_v: bool = False,
765
+ return_lse: bool = False,
766
+ window = -1,
767
+ **kwargs: Any,
768
+ ) -> torch.Tensor:
769
+ """
770
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
771
+
772
+ Parameters
773
+ ----------
774
+ q : torch.Tensor
775
+ The query tensor. Shape:
776
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
777
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
778
+
779
+ k : torch.Tensor
780
+ The key tensor. Shape:
781
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
782
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
783
+
784
+ v : torch.Tensor
785
+ The value tensor. Shape:
786
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
787
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
788
+
789
+ tensor_layout : str
790
+ The tensor layout, either "HND" or "NHD".
791
+ Default: "HND".
792
+
793
+ is_causal : bool
794
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
795
+ Default: False.
796
+
797
+ qk_quant_gran : str
798
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
799
+ Default: "per_thread".
800
+
801
+ sm_scale : Optional[float]
802
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
803
+
804
+ pv_accum_dtype : str
805
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
806
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
807
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
808
+ Default: "fp32+fp32".
809
+
810
+ smooth_k : bool
811
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
812
+ Default: True.
813
+
814
+ smooth_v : bool
815
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
816
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
817
+ Default: False.
818
+
819
+ return_lse : bool
820
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
821
+ Default: False.
822
+
823
+ Returns
824
+ -------
825
+ torch.Tensor
826
+ The output tensor. Shape:
827
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
828
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
829
+
830
+ torch.Tensor
831
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
832
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
833
+ Only returned if `return_lse` is True.
834
+
835
+ Note
836
+ ----
837
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
838
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
839
+ - All tensors must be on the same cuda device.
840
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
841
+ """
842
+ q,k,v = qkv_list
843
+ qkv_list.clear()
844
+ dtype = q.dtype
845
+ assert SM89_ENABLED, "SM89 kernel is not available. Make sure you GPUs with compute capability 8.9."
846
+ assert q.is_cuda, "Input tensors must be on cuda."
847
+ assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
848
+ assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
849
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
850
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
851
+
852
+ # FIXME(DefTruth): make sage attention work compatible with distributed
853
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
854
+ # sage attention will run into illegal memory access error after first
855
+ # inference step in distributed env for multi gpus inference. This small
856
+ # workaround also make sage attention work compatible with torch.compile
857
+ # through non-fullgraph compile mode.
858
+ torch.cuda.set_device(v.device)
859
+
860
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
861
+ _is_caual = 1 if is_causal else 0
862
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
863
+ _return_lse = 1 if return_lse else 0
864
+
865
+ head_dim_og = q.size(-1)
866
+
867
+ if head_dim_og < 64:
868
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
869
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
870
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
871
+ elif head_dim_og > 64 and head_dim_og < 128:
872
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
873
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
874
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
875
+ elif head_dim_og > 128:
876
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
877
+
878
+ # assert last dim is contiguous
879
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
880
+
881
+ if sm_scale is None:
882
+ sm_scale = head_dim_og**-0.5
883
+
884
+ seq_dim = 1 if _tensor_layout == 0 else 2
885
+
886
+ if smooth_k:
887
+ km = k.mean(dim=seq_dim, keepdim=True)
888
+ if return_lse:
889
+ if tensor_layout == "NHD":
890
+ lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
891
+ else:
892
+ lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
893
+ else:
894
+ km = None
895
+
896
+ if qk_quant_gran == "per_warp":
897
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64)
898
+ elif qk_quant_gran == "per_thread":
899
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)
900
+
901
+ q_size = q.size()
902
+ q_device = q.device
903
+ del q,k
904
+
905
+ if pv_accum_dtype == 'fp32+fp32' and smooth_v:
906
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
907
+ smooth_v = False
908
+
909
+ v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v)
910
+ del v
911
+ o = torch.empty(q_size, dtype=dtype, device=q_device)
912
+
913
+ if pv_accum_dtype == "fp32":
914
+ if smooth_v:
915
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
916
+ else:
917
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
918
+ elif pv_accum_dtype == "fp32+fp32":
919
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
920
+
921
+ o = o[..., :head_dim_og]
922
+
923
+ if return_lse:
924
+ return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
925
+ else:
926
+ return o
927
+
928
+ @torch.compiler.disable
929
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
930
+ qkv_list,
931
+ # q: torch.Tensor,
932
+ # k: torch.Tensor,
933
+ # v: torch.Tensor,
934
+ tensor_layout: str = "HND",
935
+ is_causal: bool = False,
936
+ qk_quant_gran: str = "per_thread",
937
+ sm_scale: Optional[float] = None,
938
+ pv_accum_dtype: str = "fp32+fp32",
939
+ smooth_k: bool = True,
940
+ return_lse: bool = False,
941
+ **kwargs: Any,
942
+ ) -> torch.Tensor:
943
+ """
944
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
945
+
946
+ Parameters
947
+ ----------
948
+ q : torch.Tensor
949
+ The query tensor. Shape:
950
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
951
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
952
+
953
+ k : torch.Tensor
954
+ The key tensor. Shape:
955
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
956
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
957
+
958
+ v : torch.Tensor
959
+ The value tensor. Shape:
960
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
961
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
962
+
963
+ tensor_layout : str
964
+ The tensor layout, either "HND" or "NHD".
965
+ Default: "HND".
966
+
967
+ is_causal : bool
968
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
969
+ Default: False.
970
+
971
+ qk_quant_gran : str
972
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
973
+ Default: "per_thread".
974
+
975
+ sm_scale : Optional[float]
976
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
977
+
978
+ pv_accum_dtype : str
979
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
980
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
981
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
982
+ Default: "fp32+fp32".
983
+
984
+ smooth_k : bool
985
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
986
+ Default: True.
987
+
988
+ return_lse : bool
989
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
990
+ Default: False.
991
+
992
+ Returns
993
+ -------
994
+ torch.Tensor
995
+ The output tensor. Shape:
996
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
997
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
998
+
999
+ torch.Tensor
1000
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
1001
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
1002
+ Only returned if `return_lse` is True.
1003
+
1004
+ Note
1005
+ ----
1006
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
1007
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
1008
+ - All tensors must be on the same cuda device.
1009
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
1010
+ """
1011
+ q,k,v = qkv_list
1012
+ qkv_list.clear()
1013
+ dtype = q.dtype
1014
+ assert SM90_ENABLED, "SM90 kernel is not available. Make sure you GPUs with compute capability 9.0."
1015
+ assert q.is_cuda, "Input tensors must be on cuda."
1016
+ assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
1017
+ assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
1018
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
1019
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
1020
+
1021
+ torch.cuda.set_device(v.device)
1022
+
1023
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
1024
+ _is_caual = 1 if is_causal else 0
1025
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
1026
+ _return_lse = 1 if return_lse else 0
1027
+
1028
+ head_dim_og = q.size(-1)
1029
+
1030
+ if head_dim_og < 64:
1031
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
1032
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
1033
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
1034
+ elif head_dim_og > 64 and head_dim_og < 128:
1035
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
1036
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
1037
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
1038
+ elif head_dim_og > 128:
1039
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
1040
+
1041
+ # assert last dim is contiguous
1042
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
1043
+
1044
+ if sm_scale is None:
1045
+ sm_scale = head_dim_og**-0.5
1046
+
1047
+ seq_dim = 1 if _tensor_layout == 0 else 2
1048
+
1049
+ if smooth_k:
1050
+ km = k.mean(dim=seq_dim, keepdim=True)
1051
+ if return_lse:
1052
+ if tensor_layout == "NHD":
1053
+ lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
1054
+ else:
1055
+ lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
1056
+ else:
1057
+ km = None
1058
+
1059
+ if qk_quant_gran == "per_warp":
1060
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128)
1061
+ elif qk_quant_gran == "per_thread":
1062
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
1063
+
1064
+ q_size = q.size()
1065
+ q_device = q.device
1066
+ del q,k
1067
+
1068
+
1069
+ # pad v to multiple of 128
1070
+ # TODO: modify per_channel_fp8 kernel to handle this
1071
+ kv_len = k.size(seq_dim)
1072
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
1073
+ if v_pad_len > 0:
1074
+ if tensor_layout == "HND":
1075
+ v = torch.cat([v, torch.zeros(v.size(0), v.size(1), v_pad_len, v.size(3), dtype=v.dtype, device=v.device)], dim=2)
1076
+ else:
1077
+ v = torch.cat([v, torch.zeros(v.size(0), v_pad_len, v.size(2), v.size(3), dtype=v.dtype, device=v.device)], dim=1)
1078
+
1079
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
1080
+ del v
1081
+ o = torch.empty(q_size, dtype=dtype, device=q_device)
1082
+
1083
+ if pv_accum_dtype == "fp32":
1084
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
1085
+ lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
1086
+ elif pv_accum_dtype == "fp32+fp32":
1087
+ lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
1088
+
1089
+ o = o[..., :head_dim_og]
1090
+
1091
+ if return_lse:
1092
+ return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
1093
+ else:
1094
+ return o