Spaces:
Running on T4
Running on T4
Commit ·
3a168d4
1
Parent(s): 9403a3d
fix
Browse files
- cuda/rwkv7_v3a_ops.cu +2 -2
cuda/rwkv7_v3a_ops.cu
CHANGED
|
@@ -2980,7 +2980,7 @@ at::Tensor linear_f16_orig_lt_cfg_cuda(at::Tensor x, at::Tensor weight_orig, int
|
|
| 2980 |
check_cublaslt(cublasLtMatmulAlgoGetHeuristic(lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, pref, static_cast<int>(heuristics.size()), heuristics.data(), &returned),
|
| 2981 |
"linear_f16_orig_lt heuristic");
|
| 2982 |
TORCH_CHECK(returned > 0, "linear_f16_orig_lt found no algorithm");
|
| 2983 |
-
|
| 2984 |
const float alpha = 1.0f;
|
| 2985 |
const float beta = 0.0f;
|
| 2986 |
check_cublaslt(cublasLtMatmul(
|
|
@@ -2996,7 +2996,7 @@ at::Tensor linear_f16_orig_lt_cfg_cuda(at::Tensor x, at::Tensor weight_orig, int
|
|
| 2996 |
c_desc,
|
| 2997 |
y.data_ptr<dtype>(),
|
| 2998 |
c_desc,
|
| 2999 |
-
&heuristics[
|
| 3000 |
workspace_ptr,
|
| 3001 |
workspace_size,
|
| 3002 |
at::cuda::getCurrentCUDAStream()),
|
|
|
|
| 2980 |
check_cublaslt(cublasLtMatmulAlgoGetHeuristic(lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, pref, static_cast<int>(heuristics.size()), heuristics.data(), &returned),
|
| 2981 |
"linear_f16_orig_lt heuristic");
|
| 2982 |
TORCH_CHECK(returned > 0, "linear_f16_orig_lt found no algorithm");
|
| 2983 |
+
const int selected_algo = algo_index < returned ? static_cast<int>(algo_index) : 0;
|
| 2984 |
const float alpha = 1.0f;
|
| 2985 |
const float beta = 0.0f;
|
| 2986 |
check_cublaslt(cublasLtMatmul(
|
|
|
|
| 2996 |
c_desc,
|
| 2997 |
y.data_ptr<dtype>(),
|
| 2998 |
c_desc,
|
| 2999 |
+
&heuristics[selected_algo].algo,
|
| 3000 |
workspace_ptr,
|
| 3001 |
workspace_size,
|
| 3002 |
at::cuda::getCurrentCUDAStream()),
|