NeverlandPeter commited on
Commit
3a168d4
·
1 Parent(s): 9403a3d
Files changed (1) hide show
  1. cuda/rwkv7_v3a_ops.cu +2 -2
cuda/rwkv7_v3a_ops.cu CHANGED
@@ -2980,7 +2980,7 @@ at::Tensor linear_f16_orig_lt_cfg_cuda(at::Tensor x, at::Tensor weight_orig, int
2980
  check_cublaslt(cublasLtMatmulAlgoGetHeuristic(lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, pref, static_cast<int>(heuristics.size()), heuristics.data(), &returned),
2981
  "linear_f16_orig_lt heuristic");
2982
  TORCH_CHECK(returned > 0, "linear_f16_orig_lt found no algorithm");
2983
- TORCH_CHECK(algo_index < returned, "linear_f16_orig_lt_cfg algo_index=", algo_index, " returned=", returned);
2984
  const float alpha = 1.0f;
2985
  const float beta = 0.0f;
2986
  check_cublaslt(cublasLtMatmul(
@@ -2996,7 +2996,7 @@ at::Tensor linear_f16_orig_lt_cfg_cuda(at::Tensor x, at::Tensor weight_orig, int
2996
  c_desc,
2997
  y.data_ptr<dtype>(),
2998
  c_desc,
2999
- &heuristics[algo_index].algo,
3000
  workspace_ptr,
3001
  workspace_size,
3002
  at::cuda::getCurrentCUDAStream()),
 
2980
  check_cublaslt(cublasLtMatmulAlgoGetHeuristic(lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, pref, static_cast<int>(heuristics.size()), heuristics.data(), &returned),
2981
  "linear_f16_orig_lt heuristic");
2982
  TORCH_CHECK(returned > 0, "linear_f16_orig_lt found no algorithm");
2983
+ const int selected_algo = algo_index < returned ? static_cast<int>(algo_index) : 0;
2984
  const float alpha = 1.0f;
2985
  const float beta = 0.0f;
2986
  check_cublaslt(cublasLtMatmul(
 
2996
  c_desc,
2997
  y.data_ptr<dtype>(),
2998
  c_desc,
2999
+ &heuristics[selected_algo].algo,
3000
  workspace_ptr,
3001
  workspace_size,
3002
  at::cuda::getCurrentCUDAStream()),