NeverlandPeter commited on
Commit
45d682f
·
1 Parent(s): c954ce5
app.py CHANGED
@@ -1,45 +1,49 @@
1
- import os, copy
2
- os.environ["RWKV_V7_ON"] = '1'
3
- os.environ["RWKV_JIT_ON"] = '1'
4
- os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
5
-
6
- from rwkv.model import RWKV
7
-
8
- import gc, re
9
  import gradio as gr
10
- import base64
11
- from io import BytesIO
12
  import torch
13
- import torch.nn.functional as F
14
  from datetime import datetime
15
  from huggingface_hub import hf_hub_download
16
  from pynvml import *
 
 
 
 
17
  nvmlInit()
18
- gpu_h = nvmlDeviceGetHandleByIndex(0)
19
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  ctx_limit = 7000
22
  gen_limit = 1000
23
 
24
  ########################## text rwkv ################################################################
25
- from rwkv.utils import PIPELINE, PIPELINE_ARGS
26
 
27
  title = "rwkv7-g1f-2.9b-20260420-ctx8192"
28
  model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
29
- model = RWKV(model=model_path.replace('.pth',''), strategy='cuda fp16')
 
 
 
 
 
 
 
 
 
 
30
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
31
 
32
- args = model.args
 
 
 
 
 
 
 
 
33
 
34
- _, _ = model.forward([0], None)
35
- state = model.generate_zero_state()
36
- static_input = torch.empty((model.n_embd), device="cuda", dtype=torch.half)
37
- static_state_in = [torch.empty_like(x, device="cuda") for x in state]
38
- static_state_out = [torch.empty_like(x, device="cuda") for x in state]
39
- static_output = torch.empty((model.args.vocab_size), device="cuda", dtype=torch.half)
40
- graph = torch.cuda.CUDAGraph()
41
- with torch.cuda.graph(graph):
42
- static_output, static_state_out = model.forward_one_alt(static_input, static_state_in)
43
 
44
  def generate_prompt(instruction, input=""):
45
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -73,25 +77,30 @@ def evaluate(
73
  out_last = 0
74
  out_str = ''
75
  occurrence = {}
76
- state = None
 
77
  for i in range(int(token_count)):
78
-
79
  if i == 0:
80
  input_ids = pipeline.encode(ctx)[-ctx_limit:]
81
- out, state = model.forward(input_ids, state)
82
- for j in range(len(state)):
83
- static_state_in[j].copy_(state[j])
84
- static_output.copy_(out)
 
 
 
 
 
85
  else:
86
- static_input.copy_(model.z['emb.weight'][token])
87
- graph.replay()
88
- for j in range(len(state)):
89
- static_state_in[j].copy_(static_state_out[j])
90
 
91
  for n in occurrence:
92
- static_output[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
93
 
94
- token = pipeline.sample_logits(static_output, temperature=args.temperature, top_p=args.top_p)
95
  if token in args.token_stop:
96
  break
97
  all_tokens += [token]
@@ -168,4 +177,4 @@ with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
168
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
169
 
170
  demo.queue(default_concurrency_limit=1, max_size=10)
171
- demo.launch(share=False)
 
1
+ import gc, os, re
 
 
 
 
 
 
 
2
  import gradio as gr
 
 
3
  import torch
 
4
  from datetime import datetime
5
  from huggingface_hub import hf_hub_download
6
  from pynvml import *
7
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
8
+
9
+ import rwkv7_fast_v3a as v3a
10
+
11
  nvmlInit()
12
+ # gpu_h = nvmlDeviceGetHandleByIndex(0)
 
13
 
14
  ctx_limit = 7000
15
  gen_limit = 1000
16
 
17
  ########################## text rwkv ################################################################
 
18
 
19
  title = "rwkv7-g1f-2.9b-20260420-ctx8192"
20
  model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
21
+ # model_path = "/dev/shm/rwkv7-g1f-7.2b-20260414-ctx8192.pth"
22
+
23
+ v3a.MODEL_PATH = model_path
24
+ v3a.WKV_MODE = "fp32io16"
25
+ v3a.EMB_DEVICE = "cpu"
26
+ v3a.RKV_MODE = "off"
27
+ v3a.CMIX_SPARSE = "no-fc"
28
+ v3a.LOWRANK_WEIGHT = "transpose"
29
+ v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"}
30
+ v3a.load_extensions(v3a.WKV_MODE)
31
+ model = v3a.RWKV7()
32
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
33
 
34
+ decode_state = model.zero_state(1)
35
+ decode_x = torch.empty((1, 1, v3a.C), device="cuda", dtype=torch.half)
36
+ decode_path = v3a.select_path(1, 1)
37
+ for _ in range(2):
38
+ model.forward_from_x(decode_x, decode_state, decode_path)
39
+ torch.cuda.synchronize()
40
+ decode_graph = torch.cuda.CUDAGraph()
41
+ with torch.cuda.graph(decode_graph):
42
+ decode_output = model.forward_from_x(decode_x, decode_state, decode_path)
43
 
44
+ def token_to_x(token: int):
45
+ token_tensor = torch.tensor([[int(token)]], dtype=torch.long, device="cpu" if model.emb_cpu else "cuda")
46
+ return model.embed(token_tensor)
 
 
 
 
 
 
47
 
48
  def generate_prompt(instruction, input=""):
49
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
 
77
  out_last = 0
78
  out_str = ''
79
  occurrence = {}
80
+ state = model.zero_state(1)
81
+ out = None
82
  for i in range(int(token_count)):
83
+
84
  if i == 0:
85
  input_ids = pipeline.encode(ctx)[-ctx_limit:]
86
+ CHUNK_LEN = 8192 # chunk prefill, save VRAM
87
+ while len(input_ids) > 0:
88
+ token_device = "cpu" if model.emb_cpu else "cuda"
89
+ tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
90
+ out = model.forward(tokens, state).view(-1)
91
+ input_ids = input_ids[CHUNK_LEN:]
92
+ for dst, src in zip(decode_state, state):
93
+ dst.copy_(src)
94
+ logits = out
95
  else:
96
+ decode_x.copy_(token_to_x(token))
97
+ decode_graph.replay()
98
+ logits = decode_output.view(-1)
 
99
 
100
  for n in occurrence:
101
+ logits[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
102
 
103
+ token = pipeline.sample_logits(logits, temperature=args.temperature, top_p=args.top_p)
104
  if token in args.token_stop:
105
  break
106
  all_tokens += [token]
 
177
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
178
 
179
  demo.queue(default_concurrency_limit=1, max_size=10)
180
+ demo.launch(share=False, server_name="0.0.0.0")
cuda/rwkv7_fast_ops_fp16.cu CHANGED
@@ -12,7 +12,8 @@ namespace {
12
 
13
  constexpr int HEAD_SIZE = 64;
14
  constexpr int WARPS_PER_BLOCK = 4;
15
- constexpr float NORM_EPS = 1.0e-12f;
 
16
  constexpr int FFN_SPMV_THREADS = 128;
17
  constexpr int FFN_TILE = 128;
18
 
@@ -202,7 +203,7 @@ __global__ void tmix_kk_a_gate_kernel(
202
  float sum_sq = u0 * u0 + u1 * u1;
203
  sum_sq = warp_sum(sum_sq);
204
  const float total = __shfl_sync(0xffffffffu, sum_sq, 0);
205
- const float inv_d = 1.0f / fmaxf(sqrtf(total), NORM_EPS);
206
  const float kk0 = u0 * inv_d;
207
  const float kk1 = u1 * inv_d;
208
 
@@ -264,7 +265,7 @@ __global__ void tmix_lnx_rkvres_xg_kernel(
264
  }
265
  __syncthreads();
266
  const float var = (partial[0] + partial[1]) * (1.0f / 64.0f);
267
- const float rstd = rsqrtf(var + 64.0e-5f);
268
  __syncthreads();
269
 
270
  const float rv = load_h1(r + idx);
@@ -296,7 +297,7 @@ __global__ void tmix_vres_gate_kernel(
296
  if (idx >= total) {
297
  return;
298
  }
299
- const int c = static_cast<int>(idx & (static_cast<int64_t>(C) - 1));
300
  const float vv = load_h1(v + idx);
301
  const float gate = sigmoid_fast(load_h1(v0 + c) + load_h1(v12 + idx));
302
  store_h1(out + idx, fmaf(load_h1(v_first + idx) - vv, gate, vv));
 
12
 
13
  constexpr int HEAD_SIZE = 64;
14
  constexpr int WARPS_PER_BLOCK = 4;
15
+ constexpr float KK_NORMALIZE_EPS = 1.0e-12f;
16
+ constexpr float TMIX_LN_X_EPS = 64.0e-5f;
17
  constexpr int FFN_SPMV_THREADS = 128;
18
  constexpr int FFN_TILE = 128;
19
 
 
203
  float sum_sq = u0 * u0 + u1 * u1;
204
  sum_sq = warp_sum(sum_sq);
205
  const float total = __shfl_sync(0xffffffffu, sum_sq, 0);
206
+ const float inv_d = 1.0f / fmaxf(sqrtf(total), KK_NORMALIZE_EPS);
207
  const float kk0 = u0 * inv_d;
208
  const float kk1 = u1 * inv_d;
209
 
 
265
  }
266
  __syncthreads();
267
  const float var = (partial[0] + partial[1]) * (1.0f / 64.0f);
268
+ const float rstd = rsqrtf(var + TMIX_LN_X_EPS);
269
  __syncthreads();
270
 
271
  const float rv = load_h1(r + idx);
 
297
  if (idx >= total) {
298
  return;
299
  }
300
+ const int c = static_cast<int>(idx % static_cast<int64_t>(C));
301
  const float vv = load_h1(v + idx);
302
  const float gate = sigmoid_fast(load_h1(v0 + c) + load_h1(v12 + idx));
303
  store_h1(out + idx, fmaf(load_h1(v_first + idx) - vv, gate, vv));
cuda/rwkv7_v3a_ops.cpp CHANGED
@@ -1,6 +1,8 @@
1
  #include <torch/extension.h>
2
  #include <vector>
3
 
 
 
4
  torch::Tensor layer_norm_f16_cuda(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, double eps);
5
  torch::Tensor emb_ln0_bf16_to_f16_cuda(torch::Tensor emb, torch::Tensor weight, torch::Tensor bias, double eps);
6
  torch::Tensor layer_norm_f16_small_cuda(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, double eps);
@@ -93,7 +95,6 @@ torch::Tensor emb_ln0_bf16_to_f16(torch::Tensor emb, torch::Tensor weight, torch
93
  check_bf16_cuda_contig(bias, "bias");
94
  TORCH_CHECK(emb.dim() == 2, "emb must have shape [V, C]");
95
  const int64_t c = emb.size(1);
96
- TORCH_CHECK(c == 4096, "emb_ln0_bf16_to_f16 currently requires C=4096");
97
  TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
98
  TORCH_CHECK(bias.dim() == 1 && bias.size(0) == c, "bias shape mismatch");
99
  return emb_ln0_bf16_to_f16_cuda(emb, weight, bias, eps);
@@ -436,7 +437,7 @@ std::vector<torch::Tensor> add_layer_norm_cmix_mix_f16(torch::Tensor x, torch::T
436
  TORCH_CHECK(x.sizes() == residual.sizes(), "add_layer_norm_cmix_mix_f16 x/residual shape mismatch");
437
  TORCH_CHECK(x.dim() == 3 && x.size(1) == 1, "add_layer_norm_cmix_mix_f16 requires shape [B,1,C]");
438
  const int64_t c = x.size(2);
439
- TORCH_CHECK(c == 4096, "add_layer_norm_cmix_mix_f16 currently requires C=4096");
440
  TORCH_CHECK(shift_state.dim() == 2 && shift_state.size(0) == x.size(0) && shift_state.size(1) == c,
441
  "shift_state shape mismatch");
442
  TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
@@ -470,13 +471,15 @@ std::vector<torch::Tensor> add_layer_norm_tmix_mix6_f16(
470
  check_half_cuda_contig(x_a, "x_a");
471
  check_half_cuda_contig(x_g, "x_g");
472
  TORCH_CHECK(x.sizes() == residual.sizes(), "add_layer_norm_tmix_mix6_f16 x/residual shape mismatch");
473
- TORCH_CHECK(x.dim() == 3 && x.size(1) == 1 && x.size(2) == 4096, "add_layer_norm_tmix_mix6_f16 requires shape [B,1,4096]");
474
- TORCH_CHECK(shift_state.dim() == 2 && shift_state.size(0) == x.size(0) && shift_state.size(1) == 4096,
 
 
475
  "shift_state shape mismatch");
476
- TORCH_CHECK(weight.dim() == 1 && weight.size(0) == 4096, "weight shape mismatch");
477
- TORCH_CHECK(bias.dim() == 1 && bias.size(0) == 4096, "bias shape mismatch");
478
- TORCH_CHECK(x_r.numel() == 4096 && x_w.numel() == 4096 && x_k.numel() == 4096 &&
479
- x_v.numel() == 4096 && x_a.numel() == 4096 && x_g.numel() == 4096,
480
  "mix vector shape mismatch");
481
  return add_layer_norm_tmix_mix6_f16_cuda(
482
  x, residual, shift_state, weight, bias, x_r, x_w, x_k, x_v, x_a, x_g, eps);
@@ -602,10 +605,10 @@ void advance_i32(torch::Tensor x, int64_t amount) {
602
  } // namespace
603
 
604
  TORCH_LIBRARY(rwkv7_v3a_ops, m) {
605
- m.def("layer_norm_f16(Tensor x, Tensor weight, Tensor bias, float eps=1e-5) -> Tensor");
606
- m.def("emb_ln0_bf16_to_f16(Tensor emb, Tensor weight, Tensor bias, float eps=1e-5) -> Tensor");
607
- m.def("layer_norm_f16_small(Tensor x, Tensor weight, Tensor bias, float eps=1e-5) -> Tensor");
608
- m.def("layer_norm_f16_small512(Tensor x, Tensor weight, Tensor bias, float eps=1e-5) -> Tensor");
609
  m.def("linear_f16(Tensor x, Tensor weight) -> Tensor");
610
  m.def("linear_f16_orig(Tensor x, Tensor weight_orig) -> Tensor");
611
  m.def("linear_orig_rows_f16(Tensor x, Tensor weight_orig, int row_tile, int out_tile) -> Tensor");
@@ -628,14 +631,14 @@ TORCH_LIBRARY(rwkv7_v3a_ops, m) {
628
  m.def("linear_wag_rank_out_f16(Tensor w1, Tensor a1, Tensor g1, Tensor w2_t, Tensor a2_t, Tensor g2_t) -> Tensor[]");
629
  m.def("linear_wagv_rank_out_f16(Tensor w1, Tensor a1, Tensor g1, Tensor v1, Tensor w2_t, Tensor a2_t, Tensor g2_t, Tensor v2_t, Tensor v, Tensor v_first, Tensor v0) -> Tensor[]");
630
  m.def("add_f16(Tensor x, Tensor y) -> Tensor");
631
- m.def("add_layer_norm_f16(Tensor x, Tensor residual, Tensor weight, Tensor bias, float eps=1e-5) -> Tensor[]");
632
- m.def("add_last_layer_norm_f16(Tensor x, Tensor residual, Tensor weight, Tensor bias, float eps=1e-5) -> Tensor");
633
- m.def("add_layer_norm_cmix_mix_f16(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps=1e-5) -> Tensor[]");
634
- m.def("add_layer_norm_tmix_mix6_f16(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps=1e-5) -> Tensor[]");
635
  m.def("add_layer_norm_tmix_mix6_f16_cfg(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps, int threads) -> Tensor[]");
636
- m.def("add_layer_norm_tmix_mix6_f16_scalar_stats(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps=1e-5) -> Tensor[]");
637
  m.def("add_layer_norm_cmix_mix_f16_cfg(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps, int threads) -> Tensor[]");
638
- m.def("add_layer_norm_cmix_mix_f16_scalar_stats(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps=1e-5) -> Tensor[]");
639
  m.def("advance_i32(Tensor(a!) x, int amount) -> ()");
640
  }
641
 
 
1
  #include <torch/extension.h>
2
  #include <vector>
3
 
4
+ #define RWKV7_LAYER_NORM_EPS_SCHEMA "1e-5"
5
+
6
  torch::Tensor layer_norm_f16_cuda(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, double eps);
7
  torch::Tensor emb_ln0_bf16_to_f16_cuda(torch::Tensor emb, torch::Tensor weight, torch::Tensor bias, double eps);
8
  torch::Tensor layer_norm_f16_small_cuda(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, double eps);
 
95
  check_bf16_cuda_contig(bias, "bias");
96
  TORCH_CHECK(emb.dim() == 2, "emb must have shape [V, C]");
97
  const int64_t c = emb.size(1);
 
98
  TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
99
  TORCH_CHECK(bias.dim() == 1 && bias.size(0) == c, "bias shape mismatch");
100
  return emb_ln0_bf16_to_f16_cuda(emb, weight, bias, eps);
 
437
  TORCH_CHECK(x.sizes() == residual.sizes(), "add_layer_norm_cmix_mix_f16 x/residual shape mismatch");
438
  TORCH_CHECK(x.dim() == 3 && x.size(1) == 1, "add_layer_norm_cmix_mix_f16 requires shape [B,1,C]");
439
  const int64_t c = x.size(2);
440
+ TORCH_CHECK((c % 2) == 0 && c > 0 && c <= 8192, "unsupported C");
441
  TORCH_CHECK(shift_state.dim() == 2 && shift_state.size(0) == x.size(0) && shift_state.size(1) == c,
442
  "shift_state shape mismatch");
443
  TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
 
471
  check_half_cuda_contig(x_a, "x_a");
472
  check_half_cuda_contig(x_g, "x_g");
473
  TORCH_CHECK(x.sizes() == residual.sizes(), "add_layer_norm_tmix_mix6_f16 x/residual shape mismatch");
474
+ TORCH_CHECK(x.dim() == 3 && x.size(1) == 1, "add_layer_norm_tmix_mix6_f16 requires shape [B,1,C]");
475
+ const int64_t c = x.size(2);
476
+ TORCH_CHECK((c % 2) == 0 && c > 0 && c <= 8192, "unsupported C");
477
+ TORCH_CHECK(shift_state.dim() == 2 && shift_state.size(0) == x.size(0) && shift_state.size(1) == c,
478
  "shift_state shape mismatch");
479
+ TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
480
+ TORCH_CHECK(bias.dim() == 1 && bias.size(0) == c, "bias shape mismatch");
481
+ TORCH_CHECK(x_r.numel() == c && x_w.numel() == c && x_k.numel() == c &&
482
+ x_v.numel() == c && x_a.numel() == c && x_g.numel() == c,
483
  "mix vector shape mismatch");
484
  return add_layer_norm_tmix_mix6_f16_cuda(
485
  x, residual, shift_state, weight, bias, x_r, x_w, x_k, x_v, x_a, x_g, eps);
 
605
  } // namespace
606
 
607
  TORCH_LIBRARY(rwkv7_v3a_ops, m) {
608
+ m.def("layer_norm_f16(Tensor x, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
609
+ m.def("emb_ln0_bf16_to_f16(Tensor emb, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
610
+ m.def("layer_norm_f16_small(Tensor x, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
611
+ m.def("layer_norm_f16_small512(Tensor x, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
612
  m.def("linear_f16(Tensor x, Tensor weight) -> Tensor");
613
  m.def("linear_f16_orig(Tensor x, Tensor weight_orig) -> Tensor");
614
  m.def("linear_orig_rows_f16(Tensor x, Tensor weight_orig, int row_tile, int out_tile) -> Tensor");
 
631
  m.def("linear_wag_rank_out_f16(Tensor w1, Tensor a1, Tensor g1, Tensor w2_t, Tensor a2_t, Tensor g2_t) -> Tensor[]");
632
  m.def("linear_wagv_rank_out_f16(Tensor w1, Tensor a1, Tensor g1, Tensor v1, Tensor w2_t, Tensor a2_t, Tensor g2_t, Tensor v2_t, Tensor v, Tensor v_first, Tensor v0) -> Tensor[]");
633
  m.def("add_f16(Tensor x, Tensor y) -> Tensor");
634
+ m.def("add_layer_norm_f16(Tensor x, Tensor residual, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
635
+ m.def("add_last_layer_norm_f16(Tensor x, Tensor residual, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
636
+ m.def("add_layer_norm_cmix_mix_f16(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
637
+ m.def("add_layer_norm_tmix_mix6_f16(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
638
  m.def("add_layer_norm_tmix_mix6_f16_cfg(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps, int threads) -> Tensor[]");
639
+ m.def("add_layer_norm_tmix_mix6_f16_scalar_stats(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
640
  m.def("add_layer_norm_cmix_mix_f16_cfg(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps, int threads) -> Tensor[]");
641
+ m.def("add_layer_norm_cmix_mix_f16_scalar_stats(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
642
  m.def("advance_i32(Tensor(a!) x, int amount) -> ()");
643
  }
644
 
cuda/rwkv7_v3a_ops.cu CHANGED
@@ -1990,6 +1990,182 @@ __global__ __launch_bounds__(Threads, 1) void add_last_layer_norm_f16_small_kern
1990
  }
1991
  }
1992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1993
  } // namespace
1994
 
1995
  at::Tensor add_f16_cuda(at::Tensor x, at::Tensor y) {
@@ -2154,9 +2330,16 @@ at::Tensor add_last_layer_norm_f16_cuda(at::Tensor x, at::Tensor residual, at::T
2154
  const int64_t B = x.size(0);
2155
  const int64_t T = x.size(1);
2156
  const int64_t C = x.size(2);
2157
- TORCH_CHECK(C == LN_SMALL_C, "add_last_layer_norm_f16 currently requires C=4096");
2158
  auto y = at::empty({B, C}, x.options());
2159
  auto stream = at::cuda::getCurrentCUDAStream();
 
 
 
 
 
 
 
2160
  if (B >= 1024) {
2161
  add_last_layer_norm_f16_small_kernel<LN_SMALL512_THREADS, true, true><<<static_cast<int>(B), LN_SMALL512_THREADS, 0, stream>>>(
2162
  x.data_ptr<dtype>(), residual.data_ptr<dtype>(), weight.data_ptr<dtype>(), bias.data_ptr<dtype>(),
@@ -2184,19 +2367,36 @@ std::vector<at::Tensor> add_layer_norm_cmix_mix_f16_cuda(
2184
  double eps) {
2185
  auto x_out = at::empty_like(x);
2186
  auto mixed = at::empty_like(x);
2187
- const int64_t rows = x.numel() / LN_SMALL_C;
 
 
2188
  auto stream = at::cuda::getCurrentCUDAStream();
2189
- add_layer_norm_cmix_mix_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
2190
- x.data_ptr<dtype>(),
2191
- residual.data_ptr<dtype>(),
2192
- shift_state.data_ptr<dtype>(),
2193
- weight.data_ptr<dtype>(),
2194
- bias.data_ptr<dtype>(),
2195
- x_k.data_ptr<dtype>(),
2196
- x_out.data_ptr<dtype>(),
2197
- mixed.data_ptr<dtype>(),
2198
- rows,
2199
- static_cast<float>(eps));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2200
  C10_CUDA_KERNEL_LAUNCH_CHECK();
2201
  return {x_out, mixed};
2202
  }
@@ -2211,19 +2411,36 @@ std::vector<at::Tensor> add_layer_norm_cmix_mix_f16_scalar_stats_cuda(
2211
  double eps) {
2212
  auto x_out = at::empty_like(x);
2213
  auto mixed = at::empty_like(x);
2214
- const int64_t rows = x.numel() / LN_SMALL_C;
 
 
2215
  auto stream = at::cuda::getCurrentCUDAStream();
2216
- add_layer_norm_cmix_mix_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
2217
- x.data_ptr<dtype>(),
2218
- residual.data_ptr<dtype>(),
2219
- shift_state.data_ptr<dtype>(),
2220
- weight.data_ptr<dtype>(),
2221
- bias.data_ptr<dtype>(),
2222
- x_k.data_ptr<dtype>(),
2223
- x_out.data_ptr<dtype>(),
2224
- mixed.data_ptr<dtype>(),
2225
- rows,
2226
- static_cast<float>(eps));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2227
  C10_CUDA_KERNEL_LAUNCH_CHECK();
2228
  return {x_out, mixed};
2229
  }
@@ -2248,29 +2465,56 @@ std::vector<at::Tensor> add_layer_norm_tmix_mix6_f16_cuda(
2248
  auto out_v = at::empty_like(x);
2249
  auto out_a = at::empty_like(x);
2250
  auto out_g = at::empty_like(x);
2251
- const int64_t rows = x.numel() / LN_SMALL_C;
 
 
2252
  auto stream = at::cuda::getCurrentCUDAStream();
2253
- add_layer_norm_tmix_mix6_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
2254
- x.data_ptr<dtype>(),
2255
- residual.data_ptr<dtype>(),
2256
- shift_state.data_ptr<dtype>(),
2257
- weight.data_ptr<dtype>(),
2258
- bias.data_ptr<dtype>(),
2259
- x_r.data_ptr<dtype>(),
2260
- x_w.data_ptr<dtype>(),
2261
- x_k.data_ptr<dtype>(),
2262
- x_v.data_ptr<dtype>(),
2263
- x_a.data_ptr<dtype>(),
2264
- x_g.data_ptr<dtype>(),
2265
- x_out.data_ptr<dtype>(),
2266
- out_r.data_ptr<dtype>(),
2267
- out_w.data_ptr<dtype>(),
2268
- out_k.data_ptr<dtype>(),
2269
- out_v.data_ptr<dtype>(),
2270
- out_a.data_ptr<dtype>(),
2271
- out_g.data_ptr<dtype>(),
2272
- rows,
2273
- static_cast<float>(eps));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2274
  C10_CUDA_KERNEL_LAUNCH_CHECK();
2275
  return {x_out, out_r, out_w, out_k, out_v, out_a, out_g};
2276
  }
 
1990
  }
1991
  }
1992
 
1993
+ template <int Threads>
1994
+ __global__ __launch_bounds__(Threads, 1) void add_last_layer_norm_f16_generic_kernel(
1995
+ const dtype* __restrict__ x,
1996
+ const dtype* __restrict__ residual,
1997
+ const dtype* __restrict__ weight,
1998
+ const dtype* __restrict__ bias,
1999
+ dtype* __restrict__ y,
2000
+ int64_t B,
2001
+ int64_t T,
2002
+ int C,
2003
+ float eps) {
2004
+ const int64_t bidx = blockIdx.x;
2005
+ if (bidx >= B) {
2006
+ return;
2007
+ }
2008
+ const int64_t src = (bidx * T + (T - 1)) * static_cast<int64_t>(C);
2009
+ const int64_t dst = bidx * static_cast<int64_t>(C);
2010
+ float sum = 0.0f;
2011
+ for (int c = threadIdx.x; c < C; c += Threads) {
2012
+ sum += __half2float(*reinterpret_cast<const __half*>(x + src + c)) +
2013
+ __half2float(*reinterpret_cast<const __half*>(residual + src + c));
2014
+ }
2015
+ sum = block_sum_t<Threads>(sum);
2016
+ const float mean = sum / static_cast<float>(C);
2017
+ float sum_var = 0.0f;
2018
+ for (int c = threadIdx.x; c < C; c += Threads) {
2019
+ const float v = __half2float(*reinterpret_cast<const __half*>(x + src + c)) +
2020
+ __half2float(*reinterpret_cast<const __half*>(residual + src + c));
2021
+ const float d = v - mean;
2022
+ sum_var += d * d;
2023
+ }
2024
+ sum_var = block_sum_t<Threads>(sum_var);
2025
+ const float rstd = rsqrtf(sum_var / static_cast<float>(C) + eps);
2026
+ const int pairs = C >> 1;
2027
+ for (int p = threadIdx.x; p < pairs; p += Threads) {
2028
+ const float2 xv = __half22float2(reinterpret_cast<const __half2*>(x + src)[p]);
2029
+ const float2 rv = __half22float2(reinterpret_cast<const __half2*>(residual + src)[p]);
2030
+ const float sx = xv.x + rv.x;
2031
+ const float sy = xv.y + rv.y;
2032
+ const float2 w = __half22float2(reinterpret_cast<const __half2*>(weight)[p]);
2033
+ const float2 bb = __half22float2(reinterpret_cast<const __half2*>(bias)[p]);
2034
+ reinterpret_cast<__half2*>(y + dst)[p] = __floats2half2_rn(
2035
+ (sx - mean) * rstd * w.x + bb.x,
2036
+ (sy - mean) * rstd * w.y + bb.y);
2037
+ }
2038
+ }
2039
+
2040
+ template <int Threads>
2041
+ __global__ __launch_bounds__(Threads, 1) void add_layer_norm_cmix_mix_f16_generic_kernel(
2042
+ const dtype* __restrict__ x,
2043
+ const dtype* __restrict__ residual,
2044
+ dtype* __restrict__ shift_state,
2045
+ const dtype* __restrict__ weight,
2046
+ const dtype* __restrict__ bias,
2047
+ const dtype* __restrict__ x_k,
2048
+ dtype* __restrict__ x_out,
2049
+ dtype* __restrict__ mixed,
2050
+ int64_t rows,
2051
+ int C,
2052
+ float eps) {
2053
+ const int64_t row = blockIdx.x;
2054
+ if (row >= rows) {
2055
+ return;
2056
+ }
2057
+ const int64_t base = row * static_cast<int64_t>(C);
2058
+ float sum = 0.0f;
2059
+ for (int c = threadIdx.x; c < C; c += Threads) {
2060
+ sum += __half2float(*reinterpret_cast<const __half*>(x + base + c)) +
2061
+ __half2float(*reinterpret_cast<const __half*>(residual + base + c));
2062
+ }
2063
+ sum = block_sum_t<Threads>(sum);
2064
+ const float mean = sum / static_cast<float>(C);
2065
+ float sum_var = 0.0f;
2066
+ for (int c = threadIdx.x; c < C; c += Threads) {
2067
+ const float v = __half2float(*reinterpret_cast<const __half*>(x + base + c)) +
2068
+ __half2float(*reinterpret_cast<const __half*>(residual + base + c));
2069
+ const float d = v - mean;
2070
+ sum_var += d * d;
2071
+ }
2072
+ sum_var = block_sum_t<Threads>(sum_var);
2073
+ const float rstd = rsqrtf(sum_var / static_cast<float>(C) + eps);
2074
+ const int pairs = C >> 1;
2075
+ const int64_t base2 = base >> 1;
2076
+ for (int p = threadIdx.x; p < pairs; p += Threads) {
2077
+ const float2 xv = __half22float2(reinterpret_cast<const __half2*>(x)[base2 + p]);
2078
+ const float2 rv = __half22float2(reinterpret_cast<const __half2*>(residual)[base2 + p]);
2079
+ const float2 w = __half22float2(reinterpret_cast<const __half2*>(weight)[p]);
2080
+ const float2 b = __half22float2(reinterpret_cast<const __half2*>(bias)[p]);
2081
+ const float2 prev = __half22float2(reinterpret_cast<const __half2*>(shift_state)[base2 + p]);
2082
+ const float2 mix = __half22float2(reinterpret_cast<const __half2*>(x_k)[p]);
2083
+ const float x0 = xv.x + rv.x;
2084
+ const float x1 = xv.y + rv.y;
2085
+ const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y);
2086
+ const float2 yv = __half22float2(y2);
2087
+ reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1);
2088
+ reinterpret_cast<__half2*>(mixed)[base2 + p] =
2089
+ __floats2half2_rn(yv.x + (prev.x - yv.x) * mix.x, yv.y + (prev.y - yv.y) * mix.y);
2090
+ reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2;
2091
+ }
2092
+ }
2093
+
2094
+ template <int Threads>
2095
+ __global__ __launch_bounds__(Threads, 1) void add_layer_norm_tmix_mix6_f16_generic_kernel(
2096
+ const dtype* __restrict__ x,
2097
+ const dtype* __restrict__ residual,
2098
+ dtype* __restrict__ shift_state,
2099
+ const dtype* __restrict__ weight,
2100
+ const dtype* __restrict__ bias,
2101
+ const dtype* __restrict__ x_r,
2102
+ const dtype* __restrict__ x_w,
2103
+ const dtype* __restrict__ x_k,
2104
+ const dtype* __restrict__ x_v,
2105
+ const dtype* __restrict__ x_a,
2106
+ const dtype* __restrict__ x_g,
2107
+ dtype* __restrict__ x_out,
2108
+ dtype* __restrict__ out_r,
2109
+ dtype* __restrict__ out_w,
2110
+ dtype* __restrict__ out_k,
2111
+ dtype* __restrict__ out_v,
2112
+ dtype* __restrict__ out_a,
2113
+ dtype* __restrict__ out_g,
2114
+ int64_t rows,
2115
+ int C,
2116
+ float eps) {
2117
+ const int64_t row = blockIdx.x;
2118
+ if (row >= rows) {
2119
+ return;
2120
+ }
2121
+ const int64_t base = row * static_cast<int64_t>(C);
2122
+ float sum = 0.0f;
2123
+ for (int c = threadIdx.x; c < C; c += Threads) {
2124
+ sum += __half2float(*reinterpret_cast<const __half*>(x + base + c)) +
2125
+ __half2float(*reinterpret_cast<const __half*>(residual + base + c));
2126
+ }
2127
+ sum = block_sum_t<Threads>(sum);
2128
+ const float mean = sum / static_cast<float>(C);
2129
+ float sum_var = 0.0f;
2130
+ for (int c = threadIdx.x; c < C; c += Threads) {
2131
+ const float v = __half2float(*reinterpret_cast<const __half*>(x + base + c)) +
2132
+ __half2float(*reinterpret_cast<const __half*>(residual + base + c));
2133
+ const float d = v - mean;
2134
+ sum_var += d * d;
2135
+ }
2136
+ sum_var = block_sum_t<Threads>(sum_var);
2137
+ const float rstd = rsqrtf(sum_var / static_cast<float>(C) + eps);
2138
+ const int pairs = C >> 1;
2139
+ const int64_t base2 = base >> 1;
2140
+ for (int p = threadIdx.x; p < pairs; p += Threads) {
2141
+ const float2 xv = __half22float2(reinterpret_cast<const __half2*>(x)[base2 + p]);
2142
+ const float2 rv = __half22float2(reinterpret_cast<const __half2*>(residual)[base2 + p]);
2143
+ const float2 w = __half22float2(reinterpret_cast<const __half2*>(weight)[p]);
2144
+ const float2 b = __half22float2(reinterpret_cast<const __half2*>(bias)[p]);
2145
+ const float2 prev = __half22float2(reinterpret_cast<const __half2*>(shift_state)[base2 + p]);
2146
+ const float x0 = xv.x + rv.x;
2147
+ const float x1 = xv.y + rv.y;
2148
+ const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y);
2149
+ const float2 yv = __half22float2(y2);
2150
+ const float dx0 = prev.x - yv.x;
2151
+ const float dx1 = prev.y - yv.y;
2152
+ const float2 mr = __half22float2(reinterpret_cast<const __half2*>(x_r)[p]);
2153
+ const float2 mw = __half22float2(reinterpret_cast<const __half2*>(x_w)[p]);
2154
+ const float2 mk = __half22float2(reinterpret_cast<const __half2*>(x_k)[p]);
2155
+ const float2 mv = __half22float2(reinterpret_cast<const __half2*>(x_v)[p]);
2156
+ const float2 ma = __half22float2(reinterpret_cast<const __half2*>(x_a)[p]);
2157
+ const float2 mg = __half22float2(reinterpret_cast<const __half2*>(x_g)[p]);
2158
+ reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1);
2159
+ reinterpret_cast<__half2*>(out_r)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mr.x, yv.y + dx1 * mr.y);
2160
+ reinterpret_cast<__half2*>(out_w)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mw.x, yv.y + dx1 * mw.y);
2161
+ reinterpret_cast<__half2*>(out_k)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mk.x, yv.y + dx1 * mk.y);
2162
+ reinterpret_cast<__half2*>(out_v)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mv.x, yv.y + dx1 * mv.y);
2163
+ reinterpret_cast<__half2*>(out_a)[base2 + p] = __floats2half2_rn(yv.x + dx0 * ma.x, yv.y + dx1 * ma.y);
2164
+ reinterpret_cast<__half2*>(out_g)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mg.x, yv.y + dx1 * mg.y);
2165
+ reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2;
2166
+ }
2167
+ }
2168
+
2169
  } // namespace
2170
 
2171
  at::Tensor add_f16_cuda(at::Tensor x, at::Tensor y) {
 
2330
  const int64_t B = x.size(0);
2331
  const int64_t T = x.size(1);
2332
  const int64_t C = x.size(2);
2333
+ TORCH_CHECK((C % 2) == 0, "add_last_layer_norm_f16 requires even C");
2334
  auto y = at::empty({B, C}, x.options());
2335
  auto stream = at::cuda::getCurrentCUDAStream();
2336
+ if (C != LN_SMALL_C) {
2337
+ add_last_layer_norm_f16_generic_kernel<LN_THREADS><<<static_cast<int>(B), LN_THREADS, 0, stream>>>(
2338
+ x.data_ptr<dtype>(), residual.data_ptr<dtype>(), weight.data_ptr<dtype>(), bias.data_ptr<dtype>(),
2339
+ y.data_ptr<dtype>(), B, T, static_cast<int>(C), static_cast<float>(eps));
2340
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
2341
+ return y;
2342
+ }
2343
  if (B >= 1024) {
2344
  add_last_layer_norm_f16_small_kernel<LN_SMALL512_THREADS, true, true><<<static_cast<int>(B), LN_SMALL512_THREADS, 0, stream>>>(
2345
  x.data_ptr<dtype>(), residual.data_ptr<dtype>(), weight.data_ptr<dtype>(), bias.data_ptr<dtype>(),
 
2367
  double eps) {
2368
  auto x_out = at::empty_like(x);
2369
  auto mixed = at::empty_like(x);
2370
+ const int64_t C = x.size(-1);
2371
+ TORCH_CHECK((C % 2) == 0, "add_layer_norm_cmix_mix_f16 requires even C");
2372
+ const int64_t rows = x.numel() / C;
2373
  auto stream = at::cuda::getCurrentCUDAStream();
2374
+ if (C == LN_SMALL_C) {
2375
+ add_layer_norm_cmix_mix_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
2376
+ x.data_ptr<dtype>(),
2377
+ residual.data_ptr<dtype>(),
2378
+ shift_state.data_ptr<dtype>(),
2379
+ weight.data_ptr<dtype>(),
2380
+ bias.data_ptr<dtype>(),
2381
+ x_k.data_ptr<dtype>(),
2382
+ x_out.data_ptr<dtype>(),
2383
+ mixed.data_ptr<dtype>(),
2384
+ rows,
2385
+ static_cast<float>(eps));
2386
+ } else {
2387
+ add_layer_norm_cmix_mix_f16_generic_kernel<LN_THREADS><<<static_cast<int>(rows), LN_THREADS, 0, stream>>>(
2388
+ x.data_ptr<dtype>(),
2389
+ residual.data_ptr<dtype>(),
2390
+ shift_state.data_ptr<dtype>(),
2391
+ weight.data_ptr<dtype>(),
2392
+ bias.data_ptr<dtype>(),
2393
+ x_k.data_ptr<dtype>(),
2394
+ x_out.data_ptr<dtype>(),
2395
+ mixed.data_ptr<dtype>(),
2396
+ rows,
2397
+ static_cast<int>(C),
2398
+ static_cast<float>(eps));
2399
+ }
2400
  C10_CUDA_KERNEL_LAUNCH_CHECK();
2401
  return {x_out, mixed};
2402
  }
 
2411
  double eps) {
2412
  auto x_out = at::empty_like(x);
2413
  auto mixed = at::empty_like(x);
2414
+ const int64_t C = x.size(-1);
2415
+ TORCH_CHECK((C % 2) == 0, "add_layer_norm_cmix_mix_f16 requires even C");
2416
+ const int64_t rows = x.numel() / C;
2417
  auto stream = at::cuda::getCurrentCUDAStream();
2418
+ if (C == LN_SMALL_C) {
2419
+ add_layer_norm_cmix_mix_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
2420
+ x.data_ptr<dtype>(),
2421
+ residual.data_ptr<dtype>(),
2422
+ shift_state.data_ptr<dtype>(),
2423
+ weight.data_ptr<dtype>(),
2424
+ bias.data_ptr<dtype>(),
2425
+ x_k.data_ptr<dtype>(),
2426
+ x_out.data_ptr<dtype>(),
2427
+ mixed.data_ptr<dtype>(),
2428
+ rows,
2429
+ static_cast<float>(eps));
2430
+ } else {
2431
+ add_layer_norm_cmix_mix_f16_generic_kernel<LN_THREADS><<<static_cast<int>(rows), LN_THREADS, 0, stream>>>(
2432
+ x.data_ptr<dtype>(),
2433
+ residual.data_ptr<dtype>(),
2434
+ shift_state.data_ptr<dtype>(),
2435
+ weight.data_ptr<dtype>(),
2436
+ bias.data_ptr<dtype>(),
2437
+ x_k.data_ptr<dtype>(),
2438
+ x_out.data_ptr<dtype>(),
2439
+ mixed.data_ptr<dtype>(),
2440
+ rows,
2441
+ static_cast<int>(C),
2442
+ static_cast<float>(eps));
2443
+ }
2444
  C10_CUDA_KERNEL_LAUNCH_CHECK();
2445
  return {x_out, mixed};
2446
  }
 
2465
  auto out_v = at::empty_like(x);
2466
  auto out_a = at::empty_like(x);
2467
  auto out_g = at::empty_like(x);
2468
+ const int64_t C = x.size(-1);
2469
+ TORCH_CHECK((C % 2) == 0, "add_layer_norm_tmix_mix6_f16 requires even C");
2470
+ const int64_t rows = x.numel() / C;
2471
  auto stream = at::cuda::getCurrentCUDAStream();
2472
+ if (C == LN_SMALL_C) {
2473
+ add_layer_norm_tmix_mix6_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
2474
+ x.data_ptr<dtype>(),
2475
+ residual.data_ptr<dtype>(),
2476
+ shift_state.data_ptr<dtype>(),
2477
+ weight.data_ptr<dtype>(),
2478
+ bias.data_ptr<dtype>(),
2479
+ x_r.data_ptr<dtype>(),
2480
+ x_w.data_ptr<dtype>(),
2481
+ x_k.data_ptr<dtype>(),
2482
+ x_v.data_ptr<dtype>(),
2483
+ x_a.data_ptr<dtype>(),
2484
+ x_g.data_ptr<dtype>(),
2485
+ x_out.data_ptr<dtype>(),
2486
+ out_r.data_ptr<dtype>(),
2487
+ out_w.data_ptr<dtype>(),
2488
+ out_k.data_ptr<dtype>(),
2489
+ out_v.data_ptr<dtype>(),
2490
+ out_a.data_ptr<dtype>(),
2491
+ out_g.data_ptr<dtype>(),
2492
+ rows,
2493
+ static_cast<float>(eps));
2494
+ } else {
2495
+ add_layer_norm_tmix_mix6_f16_generic_kernel<LN_THREADS><<<static_cast<int>(rows), LN_THREADS, 0, stream>>>(
2496
+ x.data_ptr<dtype>(),
2497
+ residual.data_ptr<dtype>(),
2498
+ shift_state.data_ptr<dtype>(),
2499
+ weight.data_ptr<dtype>(),
2500
+ bias.data_ptr<dtype>(),
2501
+ x_r.data_ptr<dtype>(),
2502
+ x_w.data_ptr<dtype>(),
2503
+ x_k.data_ptr<dtype>(),
2504
+ x_v.data_ptr<dtype>(),
2505
+ x_a.data_ptr<dtype>(),
2506
+ x_g.data_ptr<dtype>(),
2507
+ x_out.data_ptr<dtype>(),
2508
+ out_r.data_ptr<dtype>(),
2509
+ out_w.data_ptr<dtype>(),
2510
+ out_k.data_ptr<dtype>(),
2511
+ out_v.data_ptr<dtype>(),
2512
+ out_a.data_ptr<dtype>(),
2513
+ out_g.data_ptr<dtype>(),
2514
+ rows,
2515
+ static_cast<int>(C),
2516
+ static_cast<float>(eps));
2517
+ }
2518
  C10_CUDA_KERNEL_LAUNCH_CHECK();
2519
  return {x_out, out_r, out_w, out_k, out_v, out_a, out_g};
2520
  }
rwkv7_fast_v3a.py CHANGED
@@ -25,7 +25,7 @@ ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"}
25
  LOWRANK_SUFFIXES = ("att.w1", "att.w2", "att.a1", "att.a2", "att.g1", "att.g2", "att.v1", "att.v2")
26
  LOWRANK_IN_ROWS_T = 7
27
  LOWRANK_OUT_ROWS_T = 4
28
- CMIX_NOFC_MAX_ROWS = 19
29
  CMIX_NOFC_ROW20_MAX_T = 5
30
  CMIX_NOFC_T512_MIN_ROWS = 8
31
  LN1_TMIX_FUSE = True
@@ -36,8 +36,9 @@ CMIX_ROWS2_NOFC = "rows2_nofc"
36
  CMIX_DENSE = "dense"
37
 
38
  def main() -> None:
39
- global WKV_MODE, EMB_DEVICE, RKV_MODE, CMIX_SPARSE, LOWRANK_WEIGHT, ORIG_LINEAR_GROUPS
40
  parser = argparse.ArgumentParser()
 
41
  parser.add_argument("--warmup", type=int, default=1)
42
  parser.add_argument("--iters", type=int, default=3)
43
  parser.add_argument("--cases", default="1x1,1x2,1x4,1x8,1x16,1x32,1x64,1x128,1x256,2x1,4x1,8x1,16x1,32x1,64x1,128x1,256x1,2x2,4x4,8x8,16x16") # try 1x1024 1024x1 32x32 for extreme tps
@@ -54,6 +55,7 @@ def main() -> None:
54
  parser.add_argument("--orig-linear-groups", default="att_c2c,ffn_key,head") # comma list: none, att_c2c, ffn_key, head
55
  args = parser.parse_args()
56
 
 
57
  WKV_MODE = args.wkv
58
  EMB_DEVICE = args.emb
59
  RKV_MODE = args.batched_rkv
@@ -62,7 +64,7 @@ def main() -> None:
62
  ORIG_LINEAR_GROUPS = parse_orig_linear_groups(args.orig_linear_groups)
63
  groups = ",".join(sorted(ORIG_LINEAR_GROUPS)) if ORIG_LINEAR_GROUPS else "none"
64
  log(f"start model={MODEL_PATH} wkv={WKV_MODE} emb={EMB_DEVICE} batched_rkv={RKV_MODE} cmix_sparse={CMIX_SPARSE} lowrank_weight={LOWRANK_WEIGHT} orig_linear_groups={groups}")
65
- log(f"fixed fast path: ln=v3a linear=v3a/splitk lowrank={LOWRANK_IN_ROWS_T}/{LOWRANK_OUT_ROWS_T} nofc_rows<={CMIX_NOFC_MAX_ROWS} row20_t<={CMIX_NOFC_ROW20_MAX_T} nofc_t512_rows>={CMIX_NOFC_T512_MIN_ROWS}")
66
  load_extensions(WKV_MODE)
67
  model = RWKV7()
68
  if args.eval_json:
@@ -97,7 +99,7 @@ def select_path(B: int, T: int) -> PathConfig:
97
  if CMIX_SPARSE == "off":
98
  cmix_mode = CMIX_DENSE
99
  elif CMIX_SPARSE == "no-fc":
100
- use_nofc = rows <= CMIX_NOFC_MAX_ROWS or (rows == 20 and T <= CMIX_NOFC_ROW20_MAX_T)
101
  cmix_mode = CMIX_B1T1_NOFC if rows == 1 else (CMIX_ROWS2_NOFC if use_nofc else CMIX_DENSE)
102
  elif rows == 1:
103
  cmix_mode = CMIX_B1T1_SPARSE
@@ -115,6 +117,12 @@ def select_path(B: int, T: int) -> PathConfig:
115
  use_batched_rkv = False
116
  return PathConfig(rows=rows, use_batched_rkv=use_batched_rkv, cmix_mode=cmix_mode)
117
 
 
 
 
 
 
 
118
  def parse_orig_linear_groups(text: str) -> set[str]:
119
  groups = {x.strip() for x in text.replace(",", " ").split() if x.strip()}
120
  if not groups or groups == {"none"}:
@@ -130,6 +138,12 @@ def use_orig_linear(group: str) -> bool:
130
  def is_lowrank_weight(key: str) -> bool:
131
  return key.endswith(LOWRANK_SUFFIXES)
132
 
 
 
 
 
 
 
133
  def is_att_c2c_weight(key: str) -> bool:
134
  return ".att." in key and key.endswith(("receptance.weight", "key.weight", "value.weight", "output.weight"))
135
 
@@ -173,6 +187,7 @@ class RWKV7:
173
  C, V = H * N, z["emb.weight"].shape[0]
174
  assert N == HEAD_SIZE
175
  log(f"detected model C={C} H={H} N={N} V={V}")
 
176
 
177
  emb_src = z["emb.weight"].squeeze()
178
  ln0_w_src = z["blocks.0.ln0.weight"].squeeze()
@@ -271,7 +286,7 @@ class RWKV7:
271
  dev.copy_(host.view(B,T,C), non_blocking=True)
272
  return dev
273
 
274
- def forward_from_x(self, x: torch.Tensor, state: list[torch.Tensor], path: PathConfig, all_logits: bool = False) -> torch.Tensor:
275
  z = self.z
276
  B, T, _ = x.shape
277
  v_first = x
@@ -302,7 +317,11 @@ class RWKV7:
302
  else:
303
  x, xx = self.add_ln(x, xx, z[p_next+"ln1.weight"], z[p_next+"ln1.bias"])
304
  elif not all_logits:
305
- x = self.add_last_ln(x, xx, z["ln_out.weight"], z["ln_out.bias"])
 
 
 
 
306
  torch.ops.rwkv7_v3a_ops.advance_i32(state[2], T) # !!! IMPORTANT FOR WKV16 DITHERING !!!
307
  return self.linear_head(x)
308
  else:
@@ -323,6 +342,14 @@ class RWKV7:
323
  x = self.embed(tokens)
324
  return self.forward_from_x(x, state, path, all_logits=True)
325
 
 
 
 
 
 
 
 
 
326
  def tmix(self, layer: int, x: torch.Tensor, shift_state: torch.Tensor, wkv_state: torch.Tensor, elapsed_t: torch.Tensor, v_first: torch.Tensor, p: str, path: PathConfig, pre_mix=None) -> tuple[torch.Tensor, torch.Tensor]:
327
  z = self.z
328
  ops = torch.ops.rwkv7_fast_ops_fp16
@@ -351,11 +378,11 @@ class RWKV7:
351
  v = self.linear_orig_layout(xv, z[p+"value.weight"], path, "att_c2c")
352
 
353
  v1 = None
354
- if LOWRANK_WEIGHT != "orig" and path.rows <= LOWRANK_IN_ROWS_T and path.rows <= LOWRANK_OUT_ROWS_T and layer != 0:
355
  w1, a1, g1, v1 = torch.ops.rwkv7_v3a_ops.linear_wagv_rank_in_f16(
356
  xw.contiguous(), xa.contiguous(), xg.contiguous(), xv.contiguous(),
357
  z[p+"w1.t"], z[p+"a1.t"], z[p+"g1.t"], z[p+"v1.t"])
358
- elif LOWRANK_WEIGHT != "orig" and path.rows <= LOWRANK_IN_ROWS_T:
359
  w1, a1, g1 = torch.ops.rwkv7_v3a_ops.linear_wag_rank_in_f16(
360
  xw.contiguous(), xa.contiguous(), xg.contiguous(), z[p+"w1.t"], z[p+"a1.t"], z[p+"g1.t"])
361
  else:
@@ -363,13 +390,13 @@ class RWKV7:
363
  a1 = self.linear_rank_in(xa, z.get(p+"a1"), z.get(p+"a1.t"), path.rows)
364
  g1 = self.linear_rank_in(xg, z.get(p+"g1"), z.get(p+"g1.t"), path.rows)
365
  v_done = False
366
- if LOWRANK_WEIGHT != "orig" and path.rows <= LOWRANK_OUT_ROWS_T and layer != 0 and v1 is not None:
367
  w, a, g, v = torch.ops.rwkv7_v3a_ops.linear_wagv_rank_out_f16(
368
  w1.contiguous(), a1.contiguous(), g1.contiguous(), v1.contiguous(),
369
  z[p+"w2.t"], z[p+"a2.t"], z[p+"g2.t"], z[p+"v2.t"],
370
  v.contiguous(), v_first.contiguous(), z[p+"v0"])
371
  v_done = True
372
- elif LOWRANK_WEIGHT != "orig" and path.rows <= LOWRANK_OUT_ROWS_T:
373
  w, a, g = torch.ops.rwkv7_v3a_ops.linear_wag_rank_out_f16(
374
  w1.contiguous(), a1.contiguous(), g1.contiguous(), z[p+"w2.t"], z[p+"a2.t"], z[p+"g2.t"])
375
  else:
@@ -381,7 +408,7 @@ class RWKV7:
381
  if layer == 0:
382
  v_first = v
383
  elif not v_done:
384
- if LOWRANK_WEIGHT != "orig" and path.rows <= LOWRANK_OUT_ROWS_T:
385
  if v1 is None:
386
  v1 = self.linear_rank_in(xv, z.get(p+"v1"), z.get(p+"v1.t"), path.rows)
387
  v = torch.ops.rwkv7_v3a_ops.linear_t_vres_f16(v1.contiguous(), z[p+"v2.t"], v.contiguous(), v_first.contiguous(), z[p+"v0"])
@@ -447,28 +474,102 @@ class RWKV7:
447
  return self.linear(x, weight)
448
  if path.rows == 1:
449
  if group == "ffn_key":
450
- return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, False)
451
- return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, True)
 
 
452
  if path.rows == 2:
453
  if group == "att_c2c":
454
  return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
455
  if group == "ffn_key":
456
- return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 256, 1, True)
 
 
 
 
 
 
457
  return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
458
  if path.rows == 3:
459
  if group == "head":
 
 
 
 
460
  return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 3, 2)
461
  if group == "ffn_key":
 
 
 
 
 
 
462
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
463
  if group == "att_c2c":
 
 
 
 
 
 
 
 
464
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
465
  return torch.ops.rwkv7_v3a_ops.linear_orig_rows_cfg_f16(x.contiguous(), weight, 64, 3, 4)
466
  if path.rows == 4:
467
  if group == "ffn_key":
 
 
 
 
 
 
468
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
469
  if group == "att_c2c":
 
 
 
 
 
 
470
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
471
  if group == "head":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  if path.rows >= 1024:
473
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 0)
474
  if path.rows >= 512:
@@ -492,6 +593,40 @@ class RWKV7:
492
  if path.rows >= 72:
493
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 2)
494
  if group == "att_c2c":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  if path.rows >= 1024:
496
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 4)
497
  if path.rows >= 768:
@@ -523,6 +658,39 @@ class RWKV7:
523
  if path.rows >= 5:
524
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
525
  if group == "ffn_key":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  if path.rows >= 1024:
527
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
528
  if path.rows >= 768:
@@ -559,12 +727,12 @@ class RWKV7:
559
  return self.linear_lowrank_orig(x, weight) if weight is not None else self.linear_t_orig(x, weight_t)
560
 
561
  def linear_rank_out(self, x: torch.Tensor, weight: torch.Tensor, weight_t: torch.Tensor, rows: int) -> torch.Tensor:
562
- if weight_t is not None and rows <= LOWRANK_OUT_ROWS_T:
563
  return torch.ops.rwkv7_v3a_ops.linear_t_f16(x.contiguous(), weight_t)
564
  return self.linear_lowrank_orig(x, weight) if weight is not None else self.linear_t_orig(x, weight_t)
565
 
566
  def linear_rank_out_act(self, x: torch.Tensor, weight: torch.Tensor, weight_t: torch.Tensor, rows: int, act: int) -> torch.Tensor:
567
- if weight_t is not None and rows <= LOWRANK_OUT_ROWS_T:
568
  return torch.ops.rwkv7_v3a_ops.linear_t_act_f16(x.contiguous(), weight_t, act)
569
  ops = torch.ops.rwkv7_fast_ops_fp16
570
  x = ops.act_tanh(x.contiguous()) if act == 1 else ops.act_sigmoid(x.contiguous())
 
25
  LOWRANK_SUFFIXES = ("att.w1", "att.w2", "att.a1", "att.a2", "att.g1", "att.g2", "att.v1", "att.v2")
26
  LOWRANK_IN_ROWS_T = 7
27
  LOWRANK_OUT_ROWS_T = 4
28
+ LOWRANK_FUSED_MIN_C = 1024
29
  CMIX_NOFC_ROW20_MAX_T = 5
30
  CMIX_NOFC_T512_MIN_ROWS = 8
31
  LN1_TMIX_FUSE = True
 
36
  CMIX_DENSE = "dense"
37
 
38
  def main() -> None:
39
+ global MODEL_PATH, WKV_MODE, EMB_DEVICE, RKV_MODE, CMIX_SPARSE, LOWRANK_WEIGHT, ORIG_LINEAR_GROUPS
40
  parser = argparse.ArgumentParser()
41
+ parser.add_argument("--model", default=MODEL_PATH)
42
  parser.add_argument("--warmup", type=int, default=1)
43
  parser.add_argument("--iters", type=int, default=3)
44
  parser.add_argument("--cases", default="1x1,1x2,1x4,1x8,1x16,1x32,1x64,1x128,1x256,2x1,4x1,8x1,16x1,32x1,64x1,128x1,256x1,2x2,4x4,8x8,16x16") # try 1x1024 1024x1 32x32 for extreme tps
 
55
  parser.add_argument("--orig-linear-groups", default="att_c2c,ffn_key,head") # comma list: none, att_c2c, ffn_key, head
56
  args = parser.parse_args()
57
 
58
+ MODEL_PATH = args.model
59
  WKV_MODE = args.wkv
60
  EMB_DEVICE = args.emb
61
  RKV_MODE = args.batched_rkv
 
64
  ORIG_LINEAR_GROUPS = parse_orig_linear_groups(args.orig_linear_groups)
65
  groups = ",".join(sorted(ORIG_LINEAR_GROUPS)) if ORIG_LINEAR_GROUPS else "none"
66
  log(f"start model={MODEL_PATH} wkv={WKV_MODE} emb={EMB_DEVICE} batched_rkv={RKV_MODE} cmix_sparse={CMIX_SPARSE} lowrank_weight={LOWRANK_WEIGHT} orig_linear_groups={groups}")
67
+ log(f"fixed fast path: ln=v3a linear=v3a/splitk lowrank={LOWRANK_IN_ROWS_T}/{LOWRANK_OUT_ROWS_T} nofc_rows=by_C row20_t=by_C nofc_t512_rows>={CMIX_NOFC_T512_MIN_ROWS}")
68
  load_extensions(WKV_MODE)
69
  model = RWKV7()
70
  if args.eval_json:
 
99
  if CMIX_SPARSE == "off":
100
  cmix_mode = CMIX_DENSE
101
  elif CMIX_SPARSE == "no-fc":
102
+ use_nofc = rows <= cmix_nofc_max_rows() or (rows == 20 and T <= cmix_nofc_row20_max_t())
103
  cmix_mode = CMIX_B1T1_NOFC if rows == 1 else (CMIX_ROWS2_NOFC if use_nofc else CMIX_DENSE)
104
  elif rows == 1:
105
  cmix_mode = CMIX_B1T1_SPARSE
 
117
  use_batched_rkv = False
118
  return PathConfig(rows=rows, use_batched_rkv=use_batched_rkv, cmix_mode=cmix_mode)
119
 
120
+ def cmix_nofc_max_rows() -> int:
121
+ return 19
122
+
123
+ def cmix_nofc_row20_max_t() -> int:
124
+ return CMIX_NOFC_ROW20_MAX_T
125
+
126
  def parse_orig_linear_groups(text: str) -> set[str]:
127
  groups = {x.strip() for x in text.replace(",", " ").split() if x.strip()}
128
  if not groups or groups == {"none"}:
 
138
  def is_lowrank_weight(key: str) -> bool:
139
  return key.endswith(LOWRANK_SUFFIXES)
140
 
141
+ def can_use_lowrank_fused(rows: int) -> bool:
142
+ return C >= LOWRANK_FUSED_MIN_C and rows <= LOWRANK_IN_ROWS_T
143
+
144
+ def can_use_lowrank_out_fused(rows: int) -> bool:
145
+ return C >= LOWRANK_FUSED_MIN_C and rows <= LOWRANK_OUT_ROWS_T
146
+
147
  def is_att_c2c_weight(key: str) -> bool:
148
  return ".att." in key and key.endswith(("receptance.weight", "key.weight", "value.weight", "output.weight"))
149
 
 
187
  C, V = H * N, z["emb.weight"].shape[0]
188
  assert N == HEAD_SIZE
189
  log(f"detected model C={C} H={H} N={N} V={V}")
190
+ log(f"cmix no-fc path: rows<={cmix_nofc_max_rows()} row20_t<={cmix_nofc_row20_max_t()}")
191
 
192
  emb_src = z["emb.weight"].squeeze()
193
  ln0_w_src = z["blocks.0.ln0.weight"].squeeze()
 
286
  dev.copy_(host.view(B,T,C), non_blocking=True)
287
  return dev
288
 
289
+ def forward_from_x(self, x: torch.Tensor, state: list[torch.Tensor], path: PathConfig, all_logits: bool = False, last_indices=None) -> torch.Tensor:
290
  z = self.z
291
  B, T, _ = x.shape
292
  v_first = x
 
317
  else:
318
  x, xx = self.add_ln(x, xx, z[p_next+"ln1.weight"], z[p_next+"ln1.bias"])
319
  elif not all_logits:
320
+ if last_indices is not None:
321
+ x = self.ln(self.add(x, xx), z["ln_out.weight"], z["ln_out.bias"])
322
+ x = x[torch.arange(B, device=x.device), last_indices].contiguous()
323
+ else:
324
+ x = self.add_last_ln(x, xx, z["ln_out.weight"], z["ln_out.bias"])
325
  torch.ops.rwkv7_v3a_ops.advance_i32(state[2], T) # !!! IMPORTANT FOR WKV16 DITHERING !!!
326
  return self.linear_head(x)
327
  else:
 
342
  x = self.embed(tokens)
343
  return self.forward_from_x(x, state, path, all_logits=True)
344
 
345
+ def forward_last_at(self, tokens: torch.Tensor, state: list[torch.Tensor], last_indices: torch.Tensor) -> torch.Tensor:
346
+ if tokens.dim() == 1:
347
+ tokens = tokens.unsqueeze(0)
348
+ B, T = tokens.shape
349
+ path = select_path(B, T)
350
+ x = self.embed(tokens)
351
+ return self.forward_from_x(x, state, path, last_indices=last_indices)
352
+
353
  def tmix(self, layer: int, x: torch.Tensor, shift_state: torch.Tensor, wkv_state: torch.Tensor, elapsed_t: torch.Tensor, v_first: torch.Tensor, p: str, path: PathConfig, pre_mix=None) -> tuple[torch.Tensor, torch.Tensor]:
354
  z = self.z
355
  ops = torch.ops.rwkv7_fast_ops_fp16
 
378
  v = self.linear_orig_layout(xv, z[p+"value.weight"], path, "att_c2c")
379
 
380
  v1 = None
381
+ if LOWRANK_WEIGHT != "orig" and can_use_lowrank_fused(path.rows) and can_use_lowrank_out_fused(path.rows) and layer != 0:
382
  w1, a1, g1, v1 = torch.ops.rwkv7_v3a_ops.linear_wagv_rank_in_f16(
383
  xw.contiguous(), xa.contiguous(), xg.contiguous(), xv.contiguous(),
384
  z[p+"w1.t"], z[p+"a1.t"], z[p+"g1.t"], z[p+"v1.t"])
385
+ elif LOWRANK_WEIGHT != "orig" and can_use_lowrank_fused(path.rows):
386
  w1, a1, g1 = torch.ops.rwkv7_v3a_ops.linear_wag_rank_in_f16(
387
  xw.contiguous(), xa.contiguous(), xg.contiguous(), z[p+"w1.t"], z[p+"a1.t"], z[p+"g1.t"])
388
  else:
 
390
  a1 = self.linear_rank_in(xa, z.get(p+"a1"), z.get(p+"a1.t"), path.rows)
391
  g1 = self.linear_rank_in(xg, z.get(p+"g1"), z.get(p+"g1.t"), path.rows)
392
  v_done = False
393
+ if LOWRANK_WEIGHT != "orig" and can_use_lowrank_out_fused(path.rows) and layer != 0 and v1 is not None:
394
  w, a, g, v = torch.ops.rwkv7_v3a_ops.linear_wagv_rank_out_f16(
395
  w1.contiguous(), a1.contiguous(), g1.contiguous(), v1.contiguous(),
396
  z[p+"w2.t"], z[p+"a2.t"], z[p+"g2.t"], z[p+"v2.t"],
397
  v.contiguous(), v_first.contiguous(), z[p+"v0"])
398
  v_done = True
399
+ elif LOWRANK_WEIGHT != "orig" and can_use_lowrank_out_fused(path.rows):
400
  w, a, g = torch.ops.rwkv7_v3a_ops.linear_wag_rank_out_f16(
401
  w1.contiguous(), a1.contiguous(), g1.contiguous(), z[p+"w2.t"], z[p+"a2.t"], z[p+"g2.t"])
402
  else:
 
408
  if layer == 0:
409
  v_first = v
410
  elif not v_done:
411
+ if LOWRANK_WEIGHT != "orig" and can_use_lowrank_out_fused(path.rows):
412
  if v1 is None:
413
  v1 = self.linear_rank_in(xv, z.get(p+"v1"), z.get(p+"v1.t"), path.rows)
414
  v = torch.ops.rwkv7_v3a_ops.linear_t_vres_f16(v1.contiguous(), z[p+"v2.t"], v.contiguous(), v_first.contiguous(), z[p+"v0"])
 
474
  return self.linear(x, weight)
475
  if path.rows == 1:
476
  if group == "ffn_key":
477
+ if C == 2560:
478
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, True)
479
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, C <= 1024)
480
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, group != "att_c2c" or C < 2048)
481
  if path.rows == 2:
482
  if group == "att_c2c":
483
  return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
484
  if group == "ffn_key":
485
+ if C == 2560:
486
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, False)
487
+ if C < 4096:
488
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
489
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, False)
490
+ if group == "head" and C == 2560:
491
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, False)
492
  return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
493
  if path.rows == 3:
494
  if group == "head":
495
+ if C <= 2048:
496
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
497
+ if C == 2560:
498
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
499
  return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 3, 2)
500
  if group == "ffn_key":
501
+ if C <= 1024:
502
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_cfg_f16(x.contiguous(), weight, 64, 3, 4)
503
+ if C == 2048:
504
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
505
+ if C == 2560:
506
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
507
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
508
  if group == "att_c2c":
509
+ if C == 768:
510
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 1, 2)
511
+ if C == 1024:
512
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 2, 2)
513
+ if C == 2048:
514
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 3, 4)
515
+ if C == 2560:
516
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 3, 2)
517
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
518
  return torch.ops.rwkv7_v3a_ops.linear_orig_rows_cfg_f16(x.contiguous(), weight, 64, 3, 4)
519
  if path.rows == 4:
520
  if group == "ffn_key":
521
+ if C <= 1024:
522
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_cfg_f16(x.contiguous(), weight, 64, 2, 4)
523
+ if C == 2048:
524
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
525
+ if C == 2560:
526
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
527
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
528
  if group == "att_c2c":
529
+ if C <= 1024:
530
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 2, 2)
531
+ if C == 2048:
532
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 4, 2)
533
+ if C == 2560:
534
+ return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 4, 2)
535
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
536
  if group == "head":
537
+ if C == 768:
538
+ if 192 <= path.rows < 256:
539
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 3)
540
+ if 96 <= path.rows < 160:
541
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 1)
542
+ if C == 1024:
543
+ if 256 <= path.rows < 384:
544
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
545
+ if 192 <= path.rows < 256:
546
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
547
+ if 96 <= path.rows < 160:
548
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 1)
549
+ if C == 2048:
550
+ if 256 <= path.rows < 384:
551
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 0)
552
+ if 192 <= path.rows < 256:
553
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 6)
554
+ if 128 <= path.rows < 160:
555
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 1)
556
+ if 96 <= path.rows < 112:
557
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
558
+ if C == 2560:
559
+ if path.rows >= 256:
560
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 0)
561
+ if path.rows >= 192:
562
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 5)
563
+ if path.rows >= 160:
564
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 5)
565
+ if path.rows >= 128:
566
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 1)
567
+ if path.rows >= 96:
568
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 0)
569
+ if path.rows >= 80:
570
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
571
+ if path.rows >= 72:
572
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 1)
573
  if path.rows >= 1024:
574
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 0)
575
  if path.rows >= 512:
 
593
  if path.rows >= 72:
594
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 2)
595
  if group == "att_c2c":
596
+ if C == 2560 and 17 <= path.rows <= 20:
597
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
598
+ if C == 768:
599
+ if 256 <= path.rows < 384:
600
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 1)
601
+ if 96 <= path.rows < 112:
602
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 3)
603
+ if C == 1024:
604
+ if 256 <= path.rows < 384:
605
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 0)
606
+ if 96 <= path.rows < 112:
607
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 6)
608
+ if C == 2048:
609
+ if 256 <= path.rows < 384:
610
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 3)
611
+ if 192 <= path.rows < 256:
612
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 0)
613
+ if 96 <= path.rows < 112:
614
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 4)
615
+ if C == 2560:
616
+ if path.rows >= 256:
617
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 1)
618
+ if path.rows >= 160:
619
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
620
+ if path.rows >= 128:
621
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 2)
622
+ if path.rows >= 112:
623
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 3)
624
+ if path.rows >= 96:
625
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 2)
626
+ if path.rows >= 72:
627
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 2)
628
+ if path.rows >= 5:
629
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
630
  if path.rows >= 1024:
631
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 4)
632
  if path.rows >= 768:
 
658
  if path.rows >= 5:
659
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
660
  if group == "ffn_key":
661
+ if C == 2560 and 17 <= path.rows <= 20:
662
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
663
+ if C == 768:
664
+ if 256 <= path.rows < 384:
665
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
666
+ if 96 <= path.rows < 112:
667
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
668
+ if C == 1024:
669
+ if 256 <= path.rows < 384:
670
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 2)
671
+ if 192 <= path.rows < 256:
672
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
673
+ if 96 <= path.rows < 160:
674
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 2)
675
+ if C == 2048 and 128 <= path.rows < 160:
676
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 3)
677
+ if C == 2560:
678
+ if path.rows >= 192:
679
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 5)
680
+ if path.rows >= 160:
681
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 4)
682
+ if path.rows >= 128:
683
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 5)
684
+ if path.rows >= 112:
685
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 4)
686
+ if path.rows >= 96:
687
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 4)
688
+ if path.rows >= 80:
689
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 3)
690
+ if path.rows >= 72:
691
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 4)
692
+ if path.rows >= 3:
693
+ return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
694
  if path.rows >= 1024:
695
  return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
696
  if path.rows >= 768:
 
727
  return self.linear_lowrank_orig(x, weight) if weight is not None else self.linear_t_orig(x, weight_t)
728
 
729
  def linear_rank_out(self, x: torch.Tensor, weight: torch.Tensor, weight_t: torch.Tensor, rows: int) -> torch.Tensor:
730
+ if weight_t is not None and C >= LOWRANK_FUSED_MIN_C and rows <= LOWRANK_OUT_ROWS_T:
731
  return torch.ops.rwkv7_v3a_ops.linear_t_f16(x.contiguous(), weight_t)
732
  return self.linear_lowrank_orig(x, weight) if weight is not None else self.linear_t_orig(x, weight_t)
733
 
734
  def linear_rank_out_act(self, x: torch.Tensor, weight: torch.Tensor, weight_t: torch.Tensor, rows: int, act: int) -> torch.Tensor:
735
+ if weight_t is not None and C >= LOWRANK_FUSED_MIN_C and rows <= LOWRANK_OUT_ROWS_T:
736
  return torch.ops.rwkv7_v3a_ops.linear_t_act_f16(x.contiguous(), weight_t, act)
737
  ops = torch.ops.rwkv7_fast_ops_fp16
738
  x = ops.act_tanh(x.contiguous()) if act == 1 else ops.act_sigmoid(x.contiguous())