Spaces:
Running on T4
Running on T4
Commit ·
45d682f
1
Parent(s): c954ce5
faster
Browse files- app.py +47 -38
- cuda/rwkv7_fast_ops_fp16.cu +5 -4
- cuda/rwkv7_v3a_ops.cpp +21 -18
- cuda/rwkv7_v3a_ops.cu +291 -47
- rwkv7_fast_v3a.py +184 -16
app.py
CHANGED
|
@@ -1,45 +1,49 @@
|
|
| 1 |
-
import os,
|
| 2 |
-
os.environ["RWKV_V7_ON"] = '1'
|
| 3 |
-
os.environ["RWKV_JIT_ON"] = '1'
|
| 4 |
-
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
|
| 5 |
-
|
| 6 |
-
from rwkv.model import RWKV
|
| 7 |
-
|
| 8 |
-
import gc, re
|
| 9 |
import gradio as gr
|
| 10 |
-
import base64
|
| 11 |
-
from io import BytesIO
|
| 12 |
import torch
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
from datetime import datetime
|
| 15 |
from huggingface_hub import hf_hub_download
|
| 16 |
from pynvml import *
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
nvmlInit()
|
| 18 |
-
gpu_h = nvmlDeviceGetHandleByIndex(0)
|
| 19 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
|
| 21 |
ctx_limit = 7000
|
| 22 |
gen_limit = 1000
|
| 23 |
|
| 24 |
########################## text rwkv ################################################################
|
| 25 |
-
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
| 26 |
|
| 27 |
title = "rwkv7-g1f-2.9b-20260420-ctx8192"
|
| 28 |
model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
| 31 |
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
static_state_in = [torch.empty_like(x, device="cuda") for x in state]
|
| 38 |
-
static_state_out = [torch.empty_like(x, device="cuda") for x in state]
|
| 39 |
-
static_output = torch.empty((model.args.vocab_size), device="cuda", dtype=torch.half)
|
| 40 |
-
graph = torch.cuda.CUDAGraph()
|
| 41 |
-
with torch.cuda.graph(graph):
|
| 42 |
-
static_output, static_state_out = model.forward_one_alt(static_input, static_state_in)
|
| 43 |
|
| 44 |
def generate_prompt(instruction, input=""):
|
| 45 |
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
|
@@ -73,25 +77,30 @@ def evaluate(
|
|
| 73 |
out_last = 0
|
| 74 |
out_str = ''
|
| 75 |
occurrence = {}
|
| 76 |
-
state =
|
|
|
|
| 77 |
for i in range(int(token_count)):
|
| 78 |
-
|
| 79 |
if i == 0:
|
| 80 |
input_ids = pipeline.encode(ctx)[-ctx_limit:]
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
else:
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
static_state_in[j].copy_(static_state_out[j])
|
| 90 |
|
| 91 |
for n in occurrence:
|
| 92 |
-
|
| 93 |
|
| 94 |
-
token = pipeline.sample_logits(
|
| 95 |
if token in args.token_stop:
|
| 96 |
break
|
| 97 |
all_tokens += [token]
|
|
@@ -168,4 +177,4 @@ with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
|
|
| 168 |
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
|
| 169 |
|
| 170 |
demo.queue(default_concurrency_limit=1, max_size=10)
|
| 171 |
-
demo.launch(share=False)
|
|
|
|
| 1 |
+
import gc, os, re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import gradio as gr
|
|
|
|
|
|
|
| 3 |
import torch
|
|
|
|
| 4 |
from datetime import datetime
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
from pynvml import *
|
| 7 |
+
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
| 8 |
+
|
| 9 |
+
import rwkv7_fast_v3a as v3a
|
| 10 |
+
|
| 11 |
nvmlInit()
|
| 12 |
+
# gpu_h = nvmlDeviceGetHandleByIndex(0)
|
|
|
|
| 13 |
|
| 14 |
ctx_limit = 7000
|
| 15 |
gen_limit = 1000
|
| 16 |
|
| 17 |
########################## text rwkv ################################################################
|
|
|
|
| 18 |
|
| 19 |
title = "rwkv7-g1f-2.9b-20260420-ctx8192"
|
| 20 |
model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
|
| 21 |
+
# model_path = "/dev/shm/rwkv7-g1f-7.2b-20260414-ctx8192.pth"
|
| 22 |
+
|
| 23 |
+
v3a.MODEL_PATH = model_path
|
| 24 |
+
v3a.WKV_MODE = "fp32io16"
|
| 25 |
+
v3a.EMB_DEVICE = "cpu"
|
| 26 |
+
v3a.RKV_MODE = "off"
|
| 27 |
+
v3a.CMIX_SPARSE = "no-fc"
|
| 28 |
+
v3a.LOWRANK_WEIGHT = "transpose"
|
| 29 |
+
v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"}
|
| 30 |
+
v3a.load_extensions(v3a.WKV_MODE)
|
| 31 |
+
model = v3a.RWKV7()
|
| 32 |
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
| 33 |
|
| 34 |
+
decode_state = model.zero_state(1)
|
| 35 |
+
decode_x = torch.empty((1, 1, v3a.C), device="cuda", dtype=torch.half)
|
| 36 |
+
decode_path = v3a.select_path(1, 1)
|
| 37 |
+
for _ in range(2):
|
| 38 |
+
model.forward_from_x(decode_x, decode_state, decode_path)
|
| 39 |
+
torch.cuda.synchronize()
|
| 40 |
+
decode_graph = torch.cuda.CUDAGraph()
|
| 41 |
+
with torch.cuda.graph(decode_graph):
|
| 42 |
+
decode_output = model.forward_from_x(decode_x, decode_state, decode_path)
|
| 43 |
|
| 44 |
+
def token_to_x(token: int):
    """Look up the embedding for a single token id.

    The id is wrapped in a 1x1 ``torch.long`` tensor that is created on the
    CPU when ``model.emb_cpu`` is truthy and on CUDA otherwise, then passed
    to ``model.embed``.

    Args:
        token: Token id; coerced with ``int()`` before the tensor is built.

    Returns:
        Whatever ``model.embed`` returns for the 1x1 id tensor.
        NOTE(review): presumably a (1, 1, C) activation — confirm against
        ``model.embed``.
    """
    # The id tensor has to live where the embedding table lives.
    target_device = "cpu" if model.emb_cpu else "cuda"
    token_ids = torch.tensor([[int(token)]], dtype=torch.long, device=target_device)
    return model.embed(token_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def generate_prompt(instruction, input=""):
|
| 49 |
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
|
|
|
| 77 |
out_last = 0
|
| 78 |
out_str = ''
|
| 79 |
occurrence = {}
|
| 80 |
+
state = model.zero_state(1)
|
| 81 |
+
out = None
|
| 82 |
for i in range(int(token_count)):
|
| 83 |
+
|
| 84 |
if i == 0:
|
| 85 |
input_ids = pipeline.encode(ctx)[-ctx_limit:]
|
| 86 |
+
CHUNK_LEN = 8192 # chunk prefill, save VRAM
|
| 87 |
+
while len(input_ids) > 0:
|
| 88 |
+
token_device = "cpu" if model.emb_cpu else "cuda"
|
| 89 |
+
tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
|
| 90 |
+
out = model.forward(tokens, state).view(-1)
|
| 91 |
+
input_ids = input_ids[CHUNK_LEN:]
|
| 92 |
+
for dst, src in zip(decode_state, state):
|
| 93 |
+
dst.copy_(src)
|
| 94 |
+
logits = out
|
| 95 |
else:
|
| 96 |
+
decode_x.copy_(token_to_x(token))
|
| 97 |
+
decode_graph.replay()
|
| 98 |
+
logits = decode_output.view(-1)
|
|
|
|
| 99 |
|
| 100 |
for n in occurrence:
|
| 101 |
+
logits[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
|
| 102 |
|
| 103 |
+
token = pipeline.sample_logits(logits, temperature=args.temperature, top_p=args.top_p)
|
| 104 |
if token in args.token_stop:
|
| 105 |
break
|
| 106 |
all_tokens += [token]
|
|
|
|
| 177 |
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
|
| 178 |
|
| 179 |
demo.queue(default_concurrency_limit=1, max_size=10)
|
| 180 |
+
demo.launch(share=False, server_name="0.0.0.0")
|
cuda/rwkv7_fast_ops_fp16.cu
CHANGED
|
@@ -12,7 +12,8 @@ namespace {
|
|
| 12 |
|
| 13 |
constexpr int HEAD_SIZE = 64;
|
| 14 |
constexpr int WARPS_PER_BLOCK = 4;
|
| 15 |
-
constexpr float
|
|
|
|
| 16 |
constexpr int FFN_SPMV_THREADS = 128;
|
| 17 |
constexpr int FFN_TILE = 128;
|
| 18 |
|
|
@@ -202,7 +203,7 @@ __global__ void tmix_kk_a_gate_kernel(
|
|
| 202 |
float sum_sq = u0 * u0 + u1 * u1;
|
| 203 |
sum_sq = warp_sum(sum_sq);
|
| 204 |
const float total = __shfl_sync(0xffffffffu, sum_sq, 0);
|
| 205 |
-
const float inv_d = 1.0f / fmaxf(sqrtf(total),
|
| 206 |
const float kk0 = u0 * inv_d;
|
| 207 |
const float kk1 = u1 * inv_d;
|
| 208 |
|
|
@@ -264,7 +265,7 @@ __global__ void tmix_lnx_rkvres_xg_kernel(
|
|
| 264 |
}
|
| 265 |
__syncthreads();
|
| 266 |
const float var = (partial[0] + partial[1]) * (1.0f / 64.0f);
|
| 267 |
-
const float rstd = rsqrtf(var +
|
| 268 |
__syncthreads();
|
| 269 |
|
| 270 |
const float rv = load_h1(r + idx);
|
|
@@ -296,7 +297,7 @@ __global__ void tmix_vres_gate_kernel(
|
|
| 296 |
if (idx >= total) {
|
| 297 |
return;
|
| 298 |
}
|
| 299 |
-
const int c = static_cast<int>(idx
|
| 300 |
const float vv = load_h1(v + idx);
|
| 301 |
const float gate = sigmoid_fast(load_h1(v0 + c) + load_h1(v12 + idx));
|
| 302 |
store_h1(out + idx, fmaf(load_h1(v_first + idx) - vv, gate, vv));
|
|
|
|
| 12 |
|
| 13 |
constexpr int HEAD_SIZE = 64;
|
| 14 |
constexpr int WARPS_PER_BLOCK = 4;
|
| 15 |
+
constexpr float KK_NORMALIZE_EPS = 1.0e-12f;
|
| 16 |
+
constexpr float TMIX_LN_X_EPS = 64.0e-5f;
|
| 17 |
constexpr int FFN_SPMV_THREADS = 128;
|
| 18 |
constexpr int FFN_TILE = 128;
|
| 19 |
|
|
|
|
| 203 |
float sum_sq = u0 * u0 + u1 * u1;
|
| 204 |
sum_sq = warp_sum(sum_sq);
|
| 205 |
const float total = __shfl_sync(0xffffffffu, sum_sq, 0);
|
| 206 |
+
const float inv_d = 1.0f / fmaxf(sqrtf(total), KK_NORMALIZE_EPS);
|
| 207 |
const float kk0 = u0 * inv_d;
|
| 208 |
const float kk1 = u1 * inv_d;
|
| 209 |
|
|
|
|
| 265 |
}
|
| 266 |
__syncthreads();
|
| 267 |
const float var = (partial[0] + partial[1]) * (1.0f / 64.0f);
|
| 268 |
+
const float rstd = rsqrtf(var + TMIX_LN_X_EPS);
|
| 269 |
__syncthreads();
|
| 270 |
|
| 271 |
const float rv = load_h1(r + idx);
|
|
|
|
| 297 |
if (idx >= total) {
|
| 298 |
return;
|
| 299 |
}
|
| 300 |
+
const int c = static_cast<int>(idx % static_cast<int64_t>(C));
|
| 301 |
const float vv = load_h1(v + idx);
|
| 302 |
const float gate = sigmoid_fast(load_h1(v0 + c) + load_h1(v12 + idx));
|
| 303 |
store_h1(out + idx, fmaf(load_h1(v_first + idx) - vv, gate, vv));
|
cuda/rwkv7_v3a_ops.cpp
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
#include <torch/extension.h>
|
| 2 |
#include <vector>
|
| 3 |
|
|
|
|
|
|
|
| 4 |
torch::Tensor layer_norm_f16_cuda(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, double eps);
|
| 5 |
torch::Tensor emb_ln0_bf16_to_f16_cuda(torch::Tensor emb, torch::Tensor weight, torch::Tensor bias, double eps);
|
| 6 |
torch::Tensor layer_norm_f16_small_cuda(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, double eps);
|
|
@@ -93,7 +95,6 @@ torch::Tensor emb_ln0_bf16_to_f16(torch::Tensor emb, torch::Tensor weight, torch
|
|
| 93 |
check_bf16_cuda_contig(bias, "bias");
|
| 94 |
TORCH_CHECK(emb.dim() == 2, "emb must have shape [V, C]");
|
| 95 |
const int64_t c = emb.size(1);
|
| 96 |
-
TORCH_CHECK(c == 4096, "emb_ln0_bf16_to_f16 currently requires C=4096");
|
| 97 |
TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
|
| 98 |
TORCH_CHECK(bias.dim() == 1 && bias.size(0) == c, "bias shape mismatch");
|
| 99 |
return emb_ln0_bf16_to_f16_cuda(emb, weight, bias, eps);
|
|
@@ -436,7 +437,7 @@ std::vector<torch::Tensor> add_layer_norm_cmix_mix_f16(torch::Tensor x, torch::T
|
|
| 436 |
TORCH_CHECK(x.sizes() == residual.sizes(), "add_layer_norm_cmix_mix_f16 x/residual shape mismatch");
|
| 437 |
TORCH_CHECK(x.dim() == 3 && x.size(1) == 1, "add_layer_norm_cmix_mix_f16 requires shape [B,1,C]");
|
| 438 |
const int64_t c = x.size(2);
|
| 439 |
-
TORCH_CHECK(c ==
|
| 440 |
TORCH_CHECK(shift_state.dim() == 2 && shift_state.size(0) == x.size(0) && shift_state.size(1) == c,
|
| 441 |
"shift_state shape mismatch");
|
| 442 |
TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
|
|
@@ -470,13 +471,15 @@ std::vector<torch::Tensor> add_layer_norm_tmix_mix6_f16(
|
|
| 470 |
check_half_cuda_contig(x_a, "x_a");
|
| 471 |
check_half_cuda_contig(x_g, "x_g");
|
| 472 |
TORCH_CHECK(x.sizes() == residual.sizes(), "add_layer_norm_tmix_mix6_f16 x/residual shape mismatch");
|
| 473 |
-
TORCH_CHECK(x.dim() == 3 && x.size(1) == 1
|
| 474 |
-
|
|
|
|
|
|
|
| 475 |
"shift_state shape mismatch");
|
| 476 |
-
TORCH_CHECK(weight.dim() == 1 && weight.size(0) ==
|
| 477 |
-
TORCH_CHECK(bias.dim() == 1 && bias.size(0) ==
|
| 478 |
-
TORCH_CHECK(x_r.numel() ==
|
| 479 |
-
x_v.numel() ==
|
| 480 |
"mix vector shape mismatch");
|
| 481 |
return add_layer_norm_tmix_mix6_f16_cuda(
|
| 482 |
x, residual, shift_state, weight, bias, x_r, x_w, x_k, x_v, x_a, x_g, eps);
|
|
@@ -602,10 +605,10 @@ void advance_i32(torch::Tensor x, int64_t amount) {
|
|
| 602 |
} // namespace
|
| 603 |
|
| 604 |
TORCH_LIBRARY(rwkv7_v3a_ops, m) {
|
| 605 |
-
m.def("layer_norm_f16(Tensor x, Tensor weight, Tensor bias, float eps=
|
| 606 |
-
m.def("emb_ln0_bf16_to_f16(Tensor emb, Tensor weight, Tensor bias, float eps=
|
| 607 |
-
m.def("layer_norm_f16_small(Tensor x, Tensor weight, Tensor bias, float eps=
|
| 608 |
-
m.def("layer_norm_f16_small512(Tensor x, Tensor weight, Tensor bias, float eps=
|
| 609 |
m.def("linear_f16(Tensor x, Tensor weight) -> Tensor");
|
| 610 |
m.def("linear_f16_orig(Tensor x, Tensor weight_orig) -> Tensor");
|
| 611 |
m.def("linear_orig_rows_f16(Tensor x, Tensor weight_orig, int row_tile, int out_tile) -> Tensor");
|
|
@@ -628,14 +631,14 @@ TORCH_LIBRARY(rwkv7_v3a_ops, m) {
|
|
| 628 |
m.def("linear_wag_rank_out_f16(Tensor w1, Tensor a1, Tensor g1, Tensor w2_t, Tensor a2_t, Tensor g2_t) -> Tensor[]");
|
| 629 |
m.def("linear_wagv_rank_out_f16(Tensor w1, Tensor a1, Tensor g1, Tensor v1, Tensor w2_t, Tensor a2_t, Tensor g2_t, Tensor v2_t, Tensor v, Tensor v_first, Tensor v0) -> Tensor[]");
|
| 630 |
m.def("add_f16(Tensor x, Tensor y) -> Tensor");
|
| 631 |
-
m.def("add_layer_norm_f16(Tensor x, Tensor residual, Tensor weight, Tensor bias, float eps=
|
| 632 |
-
m.def("add_last_layer_norm_f16(Tensor x, Tensor residual, Tensor weight, Tensor bias, float eps=
|
| 633 |
-
m.def("add_layer_norm_cmix_mix_f16(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps=
|
| 634 |
-
m.def("add_layer_norm_tmix_mix6_f16(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps=
|
| 635 |
m.def("add_layer_norm_tmix_mix6_f16_cfg(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps, int threads) -> Tensor[]");
|
| 636 |
-
m.def("add_layer_norm_tmix_mix6_f16_scalar_stats(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps=
|
| 637 |
m.def("add_layer_norm_cmix_mix_f16_cfg(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps, int threads) -> Tensor[]");
|
| 638 |
-
m.def("add_layer_norm_cmix_mix_f16_scalar_stats(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps=
|
| 639 |
m.def("advance_i32(Tensor(a!) x, int amount) -> ()");
|
| 640 |
}
|
| 641 |
|
|
|
|
| 1 |
#include <torch/extension.h>
|
| 2 |
#include <vector>
|
| 3 |
|
| 4 |
+
#define RWKV7_LAYER_NORM_EPS_SCHEMA "1e-5"
|
| 5 |
+
|
| 6 |
torch::Tensor layer_norm_f16_cuda(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, double eps);
|
| 7 |
torch::Tensor emb_ln0_bf16_to_f16_cuda(torch::Tensor emb, torch::Tensor weight, torch::Tensor bias, double eps);
|
| 8 |
torch::Tensor layer_norm_f16_small_cuda(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, double eps);
|
|
|
|
| 95 |
check_bf16_cuda_contig(bias, "bias");
|
| 96 |
TORCH_CHECK(emb.dim() == 2, "emb must have shape [V, C]");
|
| 97 |
const int64_t c = emb.size(1);
|
|
|
|
| 98 |
TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
|
| 99 |
TORCH_CHECK(bias.dim() == 1 && bias.size(0) == c, "bias shape mismatch");
|
| 100 |
return emb_ln0_bf16_to_f16_cuda(emb, weight, bias, eps);
|
|
|
|
| 437 |
TORCH_CHECK(x.sizes() == residual.sizes(), "add_layer_norm_cmix_mix_f16 x/residual shape mismatch");
|
| 438 |
TORCH_CHECK(x.dim() == 3 && x.size(1) == 1, "add_layer_norm_cmix_mix_f16 requires shape [B,1,C]");
|
| 439 |
const int64_t c = x.size(2);
|
| 440 |
+
TORCH_CHECK((c % 2) == 0 && c > 0 && c <= 8192, "unsupported C");
|
| 441 |
TORCH_CHECK(shift_state.dim() == 2 && shift_state.size(0) == x.size(0) && shift_state.size(1) == c,
|
| 442 |
"shift_state shape mismatch");
|
| 443 |
TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
|
|
|
|
| 471 |
check_half_cuda_contig(x_a, "x_a");
|
| 472 |
check_half_cuda_contig(x_g, "x_g");
|
| 473 |
TORCH_CHECK(x.sizes() == residual.sizes(), "add_layer_norm_tmix_mix6_f16 x/residual shape mismatch");
|
| 474 |
+
TORCH_CHECK(x.dim() == 3 && x.size(1) == 1, "add_layer_norm_tmix_mix6_f16 requires shape [B,1,C]");
|
| 475 |
+
const int64_t c = x.size(2);
|
| 476 |
+
TORCH_CHECK((c % 2) == 0 && c > 0 && c <= 8192, "unsupported C");
|
| 477 |
+
TORCH_CHECK(shift_state.dim() == 2 && shift_state.size(0) == x.size(0) && shift_state.size(1) == c,
|
| 478 |
"shift_state shape mismatch");
|
| 479 |
+
TORCH_CHECK(weight.dim() == 1 && weight.size(0) == c, "weight shape mismatch");
|
| 480 |
+
TORCH_CHECK(bias.dim() == 1 && bias.size(0) == c, "bias shape mismatch");
|
| 481 |
+
TORCH_CHECK(x_r.numel() == c && x_w.numel() == c && x_k.numel() == c &&
|
| 482 |
+
x_v.numel() == c && x_a.numel() == c && x_g.numel() == c,
|
| 483 |
"mix vector shape mismatch");
|
| 484 |
return add_layer_norm_tmix_mix6_f16_cuda(
|
| 485 |
x, residual, shift_state, weight, bias, x_r, x_w, x_k, x_v, x_a, x_g, eps);
|
|
|
|
| 605 |
} // namespace
|
| 606 |
|
| 607 |
TORCH_LIBRARY(rwkv7_v3a_ops, m) {
|
| 608 |
+
m.def("layer_norm_f16(Tensor x, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
|
| 609 |
+
m.def("emb_ln0_bf16_to_f16(Tensor emb, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
|
| 610 |
+
m.def("layer_norm_f16_small(Tensor x, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
|
| 611 |
+
m.def("layer_norm_f16_small512(Tensor x, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
|
| 612 |
m.def("linear_f16(Tensor x, Tensor weight) -> Tensor");
|
| 613 |
m.def("linear_f16_orig(Tensor x, Tensor weight_orig) -> Tensor");
|
| 614 |
m.def("linear_orig_rows_f16(Tensor x, Tensor weight_orig, int row_tile, int out_tile) -> Tensor");
|
|
|
|
| 631 |
m.def("linear_wag_rank_out_f16(Tensor w1, Tensor a1, Tensor g1, Tensor w2_t, Tensor a2_t, Tensor g2_t) -> Tensor[]");
|
| 632 |
m.def("linear_wagv_rank_out_f16(Tensor w1, Tensor a1, Tensor g1, Tensor v1, Tensor w2_t, Tensor a2_t, Tensor g2_t, Tensor v2_t, Tensor v, Tensor v_first, Tensor v0) -> Tensor[]");
|
| 633 |
m.def("add_f16(Tensor x, Tensor y) -> Tensor");
|
| 634 |
+
m.def("add_layer_norm_f16(Tensor x, Tensor residual, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
|
| 635 |
+
m.def("add_last_layer_norm_f16(Tensor x, Tensor residual, Tensor weight, Tensor bias, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor");
|
| 636 |
+
m.def("add_layer_norm_cmix_mix_f16(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
|
| 637 |
+
m.def("add_layer_norm_tmix_mix6_f16(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
|
| 638 |
m.def("add_layer_norm_tmix_mix6_f16_cfg(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps, int threads) -> Tensor[]");
|
| 639 |
+
m.def("add_layer_norm_tmix_mix6_f16_scalar_stats(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_r, Tensor x_w, Tensor x_k, Tensor x_v, Tensor x_a, Tensor x_g, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
|
| 640 |
m.def("add_layer_norm_cmix_mix_f16_cfg(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps, int threads) -> Tensor[]");
|
| 641 |
+
m.def("add_layer_norm_cmix_mix_f16_scalar_stats(Tensor x, Tensor residual, Tensor(a!) shift_state, Tensor weight, Tensor bias, Tensor x_k, float eps=" RWKV7_LAYER_NORM_EPS_SCHEMA ") -> Tensor[]");
|
| 642 |
m.def("advance_i32(Tensor(a!) x, int amount) -> ()");
|
| 643 |
}
|
| 644 |
|
cuda/rwkv7_v3a_ops.cu
CHANGED
|
@@ -1990,6 +1990,182 @@ __global__ __launch_bounds__(Threads, 1) void add_last_layer_norm_f16_small_kern
|
|
| 1990 |
}
|
| 1991 |
}
|
| 1992 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1993 |
} // namespace
|
| 1994 |
|
| 1995 |
at::Tensor add_f16_cuda(at::Tensor x, at::Tensor y) {
|
|
@@ -2154,9 +2330,16 @@ at::Tensor add_last_layer_norm_f16_cuda(at::Tensor x, at::Tensor residual, at::T
|
|
| 2154 |
const int64_t B = x.size(0);
|
| 2155 |
const int64_t T = x.size(1);
|
| 2156 |
const int64_t C = x.size(2);
|
| 2157 |
-
TORCH_CHECK(C ==
|
| 2158 |
auto y = at::empty({B, C}, x.options());
|
| 2159 |
auto stream = at::cuda::getCurrentCUDAStream();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2160 |
if (B >= 1024) {
|
| 2161 |
add_last_layer_norm_f16_small_kernel<LN_SMALL512_THREADS, true, true><<<static_cast<int>(B), LN_SMALL512_THREADS, 0, stream>>>(
|
| 2162 |
x.data_ptr<dtype>(), residual.data_ptr<dtype>(), weight.data_ptr<dtype>(), bias.data_ptr<dtype>(),
|
|
@@ -2184,19 +2367,36 @@ std::vector<at::Tensor> add_layer_norm_cmix_mix_f16_cuda(
|
|
| 2184 |
double eps) {
|
| 2185 |
auto x_out = at::empty_like(x);
|
| 2186 |
auto mixed = at::empty_like(x);
|
| 2187 |
-
const int64_t
|
|
|
|
|
|
|
| 2188 |
auto stream = at::cuda::getCurrentCUDAStream();
|
| 2189 |
-
|
| 2190 |
-
|
| 2191 |
-
|
| 2192 |
-
|
| 2193 |
-
|
| 2194 |
-
|
| 2195 |
-
|
| 2196 |
-
|
| 2197 |
-
|
| 2198 |
-
|
| 2199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2200 |
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 2201 |
return {x_out, mixed};
|
| 2202 |
}
|
|
@@ -2211,19 +2411,36 @@ std::vector<at::Tensor> add_layer_norm_cmix_mix_f16_scalar_stats_cuda(
|
|
| 2211 |
double eps) {
|
| 2212 |
auto x_out = at::empty_like(x);
|
| 2213 |
auto mixed = at::empty_like(x);
|
| 2214 |
-
const int64_t
|
|
|
|
|
|
|
| 2215 |
auto stream = at::cuda::getCurrentCUDAStream();
|
| 2216 |
-
|
| 2217 |
-
|
| 2218 |
-
|
| 2219 |
-
|
| 2220 |
-
|
| 2221 |
-
|
| 2222 |
-
|
| 2223 |
-
|
| 2224 |
-
|
| 2225 |
-
|
| 2226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2227 |
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 2228 |
return {x_out, mixed};
|
| 2229 |
}
|
|
@@ -2248,29 +2465,56 @@ std::vector<at::Tensor> add_layer_norm_tmix_mix6_f16_cuda(
|
|
| 2248 |
auto out_v = at::empty_like(x);
|
| 2249 |
auto out_a = at::empty_like(x);
|
| 2250 |
auto out_g = at::empty_like(x);
|
| 2251 |
-
const int64_t
|
|
|
|
|
|
|
| 2252 |
auto stream = at::cuda::getCurrentCUDAStream();
|
| 2253 |
-
|
| 2254 |
-
|
| 2255 |
-
|
| 2256 |
-
|
| 2257 |
-
|
| 2258 |
-
|
| 2259 |
-
|
| 2260 |
-
|
| 2261 |
-
|
| 2262 |
-
|
| 2263 |
-
|
| 2264 |
-
|
| 2265 |
-
|
| 2266 |
-
|
| 2267 |
-
|
| 2268 |
-
|
| 2269 |
-
|
| 2270 |
-
|
| 2271 |
-
|
| 2272 |
-
|
| 2273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2274 |
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 2275 |
return {x_out, out_r, out_w, out_k, out_v, out_a, out_g};
|
| 2276 |
}
|
|
|
|
| 1990 |
}
|
| 1991 |
}
|
| 1992 |
|
| 1993 |
+
// Residual add + LayerNorm applied to the final time step of each batch row.
//
// One thread block handles one batch index (blockIdx.x). It reads row
// (b, T - 1) of `x` and `residual`, normalizes their element-wise sum over the
// C channels, scales/shifts by `weight`/`bias`, and writes the result at
// y + b * C. Statistics are accumulated in fp32 in two passes (mean first,
// then variance around that mean). The output pass works on __half2 pairs, so
// only the first 2 * (C >> 1) channels are written.
//
// NOTE(review): block_sum_t<Threads> is defined elsewhere; every thread uses
// its return value, so it is presumably a block-wide all-reduce — confirm.
// NOTE(review): the __half2 accesses assume C is even and the buffers are
// 4-byte aligned — confirm the host-side launcher enforces this.
template <int Threads>
__global__ __launch_bounds__(Threads, 1) void add_last_layer_norm_f16_generic_kernel(
    const dtype* __restrict__ x,
    const dtype* __restrict__ residual,
    const dtype* __restrict__ weight,
    const dtype* __restrict__ bias,
    dtype* __restrict__ y,
    int64_t B,
    int64_t T,
    int C,
    float eps) {
    const int64_t batch = blockIdx.x;
    if (batch >= B) {
        return;
    }

    // Element offsets of the last time step (input) and of the output row.
    const int64_t in_off = (batch * T + (T - 1)) * static_cast<int64_t>(C);
    const int64_t out_off = batch * static_cast<int64_t>(C);
    const dtype* x_row = x + in_off;
    const dtype* r_row = residual + in_off;

    // Pass 1: mean of (x + residual) over the channel dimension.
    float total = 0.0f;
    for (int i = threadIdx.x; i < C; i += Threads) {
        total += __half2float(*reinterpret_cast<const __half*>(x_row + i)) +
                 __half2float(*reinterpret_cast<const __half*>(r_row + i));
    }
    total = block_sum_t<Threads>(total);
    const float mean = total / static_cast<float>(C);

    // Pass 2: sum of squared deviations from that mean.
    float sq_total = 0.0f;
    for (int i = threadIdx.x; i < C; i += Threads) {
        const float v = __half2float(*reinterpret_cast<const __half*>(x_row + i)) +
                        __half2float(*reinterpret_cast<const __half*>(r_row + i));
        const float d = v - mean;
        sq_total += d * d;
    }
    sq_total = block_sum_t<Threads>(sq_total);
    const float rstd = rsqrtf(sq_total / static_cast<float>(C) + eps);

    // Pass 3: normalize two channels at a time and store as __half2.
    const __half2* x_pairs = reinterpret_cast<const __half2*>(x_row);
    const __half2* r_pairs = reinterpret_cast<const __half2*>(r_row);
    const __half2* w_pairs = reinterpret_cast<const __half2*>(weight);
    const __half2* b_pairs = reinterpret_cast<const __half2*>(bias);
    __half2* y_pairs = reinterpret_cast<__half2*>(y + out_off);
    const int pairs = C >> 1;
    for (int p = threadIdx.x; p < pairs; p += Threads) {
        const float2 xv = __half22float2(x_pairs[p]);
        const float2 rv = __half22float2(r_pairs[p]);
        const float sx = xv.x + rv.x;
        const float sy = xv.y + rv.y;
        const float2 gain = __half22float2(w_pairs[p]);
        const float2 shift = __half22float2(b_pairs[p]);
        y_pairs[p] = __floats2half2_rn(
            (sx - mean) * rstd * gain.x + shift.x,
            (sy - mean) * rstd * gain.y + shift.y);
    }
}
|
| 2039 |
+
|
| 2040 |
+
// Fused residual add + LayerNorm + channel-mix token-shift blend.
//
// One thread block handles one row (blockIdx.x) of C channels. Per row it:
//   1. forms s = x + residual and writes s (rounded to fp16) to `x_out`;
//   2. computes y = LayerNorm(s) * weight + bias with fp32 statistics, then
//      rounds y to fp16;
//   3. writes mixed = y + (prev - y) * x_k, where `prev` is the value read
//      from `shift_state` for this row;
//   4. overwrites `shift_state` with y.
// The blend uses the fp16-rounded y, i.e. exactly the value that is stored
// into `shift_state`. Only the first 2 * (C >> 1) channels are written.
//
// NOTE(review): block_sum_t<Threads> is defined elsewhere; every thread uses
// its return value, so it is presumably a block-wide all-reduce — confirm.
// NOTE(review): the __half2 accesses assume C is even and the buffers are
// 4-byte aligned — confirm the host-side wrapper enforces this.
template <int Threads>
__global__ __launch_bounds__(Threads, 1) void add_layer_norm_cmix_mix_f16_generic_kernel(
    const dtype* __restrict__ x,
    const dtype* __restrict__ residual,
    dtype* __restrict__ shift_state,
    const dtype* __restrict__ weight,
    const dtype* __restrict__ bias,
    const dtype* __restrict__ x_k,
    dtype* __restrict__ x_out,
    dtype* __restrict__ mixed,
    int64_t rows,
    int C,
    float eps) {
    const int64_t row = blockIdx.x;
    if (row >= rows) {
        return;
    }

    // Element offset of this row in the row-major [rows, C] buffers.
    const int64_t base = row * static_cast<int64_t>(C);
    const dtype* x_row = x + base;
    const dtype* r_row = residual + base;

    // Pass 1: mean of (x + residual) over the channel dimension.
    float total = 0.0f;
    for (int i = threadIdx.x; i < C; i += Threads) {
        total += __half2float(*reinterpret_cast<const __half*>(x_row + i)) +
                 __half2float(*reinterpret_cast<const __half*>(r_row + i));
    }
    total = block_sum_t<Threads>(total);
    const float mean = total / static_cast<float>(C);

    // Pass 2: sum of squared deviations from that mean.
    float sq_total = 0.0f;
    for (int i = threadIdx.x; i < C; i += Threads) {
        const float v = __half2float(*reinterpret_cast<const __half*>(x_row + i)) +
                        __half2float(*reinterpret_cast<const __half*>(r_row + i));
        const float d = v - mean;
        sq_total += d * d;
    }
    sq_total = block_sum_t<Threads>(sq_total);
    const float rstd = rsqrtf(sq_total / static_cast<float>(C) + eps);

    // Pass 3: work in __half2 pairs; pair_base is this row's offset in pairs.
    const int pairs = C >> 1;
    const int64_t pair_base = base >> 1;
    const __half2* x_pairs = reinterpret_cast<const __half2*>(x) + pair_base;
    const __half2* r_pairs = reinterpret_cast<const __half2*>(residual) + pair_base;
    const __half2* w_pairs = reinterpret_cast<const __half2*>(weight);
    const __half2* b_pairs = reinterpret_cast<const __half2*>(bias);
    const __half2* k_pairs = reinterpret_cast<const __half2*>(x_k);
    __half2* state_pairs = reinterpret_cast<__half2*>(shift_state) + pair_base;
    __half2* sum_pairs = reinterpret_cast<__half2*>(x_out) + pair_base;
    __half2* mixed_pairs = reinterpret_cast<__half2*>(mixed) + pair_base;
    for (int p = threadIdx.x; p < pairs; p += Threads) {
        const float2 xv = __half22float2(x_pairs[p]);
        const float2 rv = __half22float2(r_pairs[p]);
        const float2 w = __half22float2(w_pairs[p]);
        const float2 b = __half22float2(b_pairs[p]);
        // The previous shift state must be read before it is overwritten below.
        const float2 prev = __half22float2(state_pairs[p]);
        const float2 mix = __half22float2(k_pairs[p]);
        const float x0 = xv.x + rv.x;
        const float x1 = xv.y + rv.y;
        // Round the normalized value to fp16 first, then blend with it.
        const __half2 normed = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y);
        const float2 yv = __half22float2(normed);
        sum_pairs[p] = __floats2half2_rn(x0, x1);
        mixed_pairs[p] =
            __floats2half2_rn(yv.x + (prev.x - yv.x) * mix.x, yv.y + (prev.y - yv.y) * mix.y);
        state_pairs[p] = normed;
    }
}
|
| 2093 |
+
|
| 2094 |
+
template <int Threads>
|
| 2095 |
+
__global__ __launch_bounds__(Threads, 1) void add_layer_norm_tmix_mix6_f16_generic_kernel(
|
| 2096 |
+
const dtype* __restrict__ x,
|
| 2097 |
+
const dtype* __restrict__ residual,
|
| 2098 |
+
dtype* __restrict__ shift_state,
|
| 2099 |
+
const dtype* __restrict__ weight,
|
| 2100 |
+
const dtype* __restrict__ bias,
|
| 2101 |
+
const dtype* __restrict__ x_r,
|
| 2102 |
+
const dtype* __restrict__ x_w,
|
| 2103 |
+
const dtype* __restrict__ x_k,
|
| 2104 |
+
const dtype* __restrict__ x_v,
|
| 2105 |
+
const dtype* __restrict__ x_a,
|
| 2106 |
+
const dtype* __restrict__ x_g,
|
| 2107 |
+
dtype* __restrict__ x_out,
|
| 2108 |
+
dtype* __restrict__ out_r,
|
| 2109 |
+
dtype* __restrict__ out_w,
|
| 2110 |
+
dtype* __restrict__ out_k,
|
| 2111 |
+
dtype* __restrict__ out_v,
|
| 2112 |
+
dtype* __restrict__ out_a,
|
| 2113 |
+
dtype* __restrict__ out_g,
|
| 2114 |
+
int64_t rows,
|
| 2115 |
+
int C,
|
| 2116 |
+
float eps) {
|
| 2117 |
+
const int64_t row = blockIdx.x;
|
| 2118 |
+
if (row >= rows) {
|
| 2119 |
+
return;
|
| 2120 |
+
}
|
| 2121 |
+
const int64_t base = row * static_cast<int64_t>(C);
|
| 2122 |
+
float sum = 0.0f;
|
| 2123 |
+
for (int c = threadIdx.x; c < C; c += Threads) {
|
| 2124 |
+
sum += __half2float(*reinterpret_cast<const __half*>(x + base + c)) +
|
| 2125 |
+
__half2float(*reinterpret_cast<const __half*>(residual + base + c));
|
| 2126 |
+
}
|
| 2127 |
+
sum = block_sum_t<Threads>(sum);
|
| 2128 |
+
const float mean = sum / static_cast<float>(C);
|
| 2129 |
+
float sum_var = 0.0f;
|
| 2130 |
+
for (int c = threadIdx.x; c < C; c += Threads) {
|
| 2131 |
+
const float v = __half2float(*reinterpret_cast<const __half*>(x + base + c)) +
|
| 2132 |
+
__half2float(*reinterpret_cast<const __half*>(residual + base + c));
|
| 2133 |
+
const float d = v - mean;
|
| 2134 |
+
sum_var += d * d;
|
| 2135 |
+
}
|
| 2136 |
+
sum_var = block_sum_t<Threads>(sum_var);
|
| 2137 |
+
const float rstd = rsqrtf(sum_var / static_cast<float>(C) + eps);
|
| 2138 |
+
const int pairs = C >> 1;
|
| 2139 |
+
const int64_t base2 = base >> 1;
|
| 2140 |
+
for (int p = threadIdx.x; p < pairs; p += Threads) {
|
| 2141 |
+
const float2 xv = __half22float2(reinterpret_cast<const __half2*>(x)[base2 + p]);
|
| 2142 |
+
const float2 rv = __half22float2(reinterpret_cast<const __half2*>(residual)[base2 + p]);
|
| 2143 |
+
const float2 w = __half22float2(reinterpret_cast<const __half2*>(weight)[p]);
|
| 2144 |
+
const float2 b = __half22float2(reinterpret_cast<const __half2*>(bias)[p]);
|
| 2145 |
+
const float2 prev = __half22float2(reinterpret_cast<const __half2*>(shift_state)[base2 + p]);
|
| 2146 |
+
const float x0 = xv.x + rv.x;
|
| 2147 |
+
const float x1 = xv.y + rv.y;
|
| 2148 |
+
const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y);
|
| 2149 |
+
const float2 yv = __half22float2(y2);
|
| 2150 |
+
const float dx0 = prev.x - yv.x;
|
| 2151 |
+
const float dx1 = prev.y - yv.y;
|
| 2152 |
+
const float2 mr = __half22float2(reinterpret_cast<const __half2*>(x_r)[p]);
|
| 2153 |
+
const float2 mw = __half22float2(reinterpret_cast<const __half2*>(x_w)[p]);
|
| 2154 |
+
const float2 mk = __half22float2(reinterpret_cast<const __half2*>(x_k)[p]);
|
| 2155 |
+
const float2 mv = __half22float2(reinterpret_cast<const __half2*>(x_v)[p]);
|
| 2156 |
+
const float2 ma = __half22float2(reinterpret_cast<const __half2*>(x_a)[p]);
|
| 2157 |
+
const float2 mg = __half22float2(reinterpret_cast<const __half2*>(x_g)[p]);
|
| 2158 |
+
reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1);
|
| 2159 |
+
reinterpret_cast<__half2*>(out_r)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mr.x, yv.y + dx1 * mr.y);
|
| 2160 |
+
reinterpret_cast<__half2*>(out_w)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mw.x, yv.y + dx1 * mw.y);
|
| 2161 |
+
reinterpret_cast<__half2*>(out_k)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mk.x, yv.y + dx1 * mk.y);
|
| 2162 |
+
reinterpret_cast<__half2*>(out_v)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mv.x, yv.y + dx1 * mv.y);
|
| 2163 |
+
reinterpret_cast<__half2*>(out_a)[base2 + p] = __floats2half2_rn(yv.x + dx0 * ma.x, yv.y + dx1 * ma.y);
|
| 2164 |
+
reinterpret_cast<__half2*>(out_g)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mg.x, yv.y + dx1 * mg.y);
|
| 2165 |
+
reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2;
|
| 2166 |
+
}
|
| 2167 |
+
}
|
| 2168 |
+
|
| 2169 |
} // namespace
|
| 2170 |
|
| 2171 |
at::Tensor add_f16_cuda(at::Tensor x, at::Tensor y) {
|
|
|
|
| 2330 |
const int64_t B = x.size(0);
|
| 2331 |
const int64_t T = x.size(1);
|
| 2332 |
const int64_t C = x.size(2);
|
| 2333 |
+
TORCH_CHECK((C % 2) == 0, "add_last_layer_norm_f16 requires even C");
|
| 2334 |
auto y = at::empty({B, C}, x.options());
|
| 2335 |
auto stream = at::cuda::getCurrentCUDAStream();
|
| 2336 |
+
if (C != LN_SMALL_C) {
|
| 2337 |
+
add_last_layer_norm_f16_generic_kernel<LN_THREADS><<<static_cast<int>(B), LN_THREADS, 0, stream>>>(
|
| 2338 |
+
x.data_ptr<dtype>(), residual.data_ptr<dtype>(), weight.data_ptr<dtype>(), bias.data_ptr<dtype>(),
|
| 2339 |
+
y.data_ptr<dtype>(), B, T, static_cast<int>(C), static_cast<float>(eps));
|
| 2340 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 2341 |
+
return y;
|
| 2342 |
+
}
|
| 2343 |
if (B >= 1024) {
|
| 2344 |
add_last_layer_norm_f16_small_kernel<LN_SMALL512_THREADS, true, true><<<static_cast<int>(B), LN_SMALL512_THREADS, 0, stream>>>(
|
| 2345 |
x.data_ptr<dtype>(), residual.data_ptr<dtype>(), weight.data_ptr<dtype>(), bias.data_ptr<dtype>(),
|
|
|
|
| 2367 |
double eps) {
|
| 2368 |
auto x_out = at::empty_like(x);
|
| 2369 |
auto mixed = at::empty_like(x);
|
| 2370 |
+
const int64_t C = x.size(-1);
|
| 2371 |
+
TORCH_CHECK((C % 2) == 0, "add_layer_norm_cmix_mix_f16 requires even C");
|
| 2372 |
+
const int64_t rows = x.numel() / C;
|
| 2373 |
auto stream = at::cuda::getCurrentCUDAStream();
|
| 2374 |
+
if (C == LN_SMALL_C) {
|
| 2375 |
+
add_layer_norm_cmix_mix_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
|
| 2376 |
+
x.data_ptr<dtype>(),
|
| 2377 |
+
residual.data_ptr<dtype>(),
|
| 2378 |
+
shift_state.data_ptr<dtype>(),
|
| 2379 |
+
weight.data_ptr<dtype>(),
|
| 2380 |
+
bias.data_ptr<dtype>(),
|
| 2381 |
+
x_k.data_ptr<dtype>(),
|
| 2382 |
+
x_out.data_ptr<dtype>(),
|
| 2383 |
+
mixed.data_ptr<dtype>(),
|
| 2384 |
+
rows,
|
| 2385 |
+
static_cast<float>(eps));
|
| 2386 |
+
} else {
|
| 2387 |
+
add_layer_norm_cmix_mix_f16_generic_kernel<LN_THREADS><<<static_cast<int>(rows), LN_THREADS, 0, stream>>>(
|
| 2388 |
+
x.data_ptr<dtype>(),
|
| 2389 |
+
residual.data_ptr<dtype>(),
|
| 2390 |
+
shift_state.data_ptr<dtype>(),
|
| 2391 |
+
weight.data_ptr<dtype>(),
|
| 2392 |
+
bias.data_ptr<dtype>(),
|
| 2393 |
+
x_k.data_ptr<dtype>(),
|
| 2394 |
+
x_out.data_ptr<dtype>(),
|
| 2395 |
+
mixed.data_ptr<dtype>(),
|
| 2396 |
+
rows,
|
| 2397 |
+
static_cast<int>(C),
|
| 2398 |
+
static_cast<float>(eps));
|
| 2399 |
+
}
|
| 2400 |
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 2401 |
return {x_out, mixed};
|
| 2402 |
}
|
|
|
|
| 2411 |
double eps) {
|
| 2412 |
auto x_out = at::empty_like(x);
|
| 2413 |
auto mixed = at::empty_like(x);
|
| 2414 |
+
const int64_t C = x.size(-1);
|
| 2415 |
+
TORCH_CHECK((C % 2) == 0, "add_layer_norm_cmix_mix_f16 requires even C");
|
| 2416 |
+
const int64_t rows = x.numel() / C;
|
| 2417 |
auto stream = at::cuda::getCurrentCUDAStream();
|
| 2418 |
+
if (C == LN_SMALL_C) {
|
| 2419 |
+
add_layer_norm_cmix_mix_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
|
| 2420 |
+
x.data_ptr<dtype>(),
|
| 2421 |
+
residual.data_ptr<dtype>(),
|
| 2422 |
+
shift_state.data_ptr<dtype>(),
|
| 2423 |
+
weight.data_ptr<dtype>(),
|
| 2424 |
+
bias.data_ptr<dtype>(),
|
| 2425 |
+
x_k.data_ptr<dtype>(),
|
| 2426 |
+
x_out.data_ptr<dtype>(),
|
| 2427 |
+
mixed.data_ptr<dtype>(),
|
| 2428 |
+
rows,
|
| 2429 |
+
static_cast<float>(eps));
|
| 2430 |
+
} else {
|
| 2431 |
+
add_layer_norm_cmix_mix_f16_generic_kernel<LN_THREADS><<<static_cast<int>(rows), LN_THREADS, 0, stream>>>(
|
| 2432 |
+
x.data_ptr<dtype>(),
|
| 2433 |
+
residual.data_ptr<dtype>(),
|
| 2434 |
+
shift_state.data_ptr<dtype>(),
|
| 2435 |
+
weight.data_ptr<dtype>(),
|
| 2436 |
+
bias.data_ptr<dtype>(),
|
| 2437 |
+
x_k.data_ptr<dtype>(),
|
| 2438 |
+
x_out.data_ptr<dtype>(),
|
| 2439 |
+
mixed.data_ptr<dtype>(),
|
| 2440 |
+
rows,
|
| 2441 |
+
static_cast<int>(C),
|
| 2442 |
+
static_cast<float>(eps));
|
| 2443 |
+
}
|
| 2444 |
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 2445 |
return {x_out, mixed};
|
| 2446 |
}
|
|
|
|
| 2465 |
auto out_v = at::empty_like(x);
|
| 2466 |
auto out_a = at::empty_like(x);
|
| 2467 |
auto out_g = at::empty_like(x);
|
| 2468 |
+
const int64_t C = x.size(-1);
|
| 2469 |
+
TORCH_CHECK((C % 2) == 0, "add_layer_norm_tmix_mix6_f16 requires even C");
|
| 2470 |
+
const int64_t rows = x.numel() / C;
|
| 2471 |
auto stream = at::cuda::getCurrentCUDAStream();
|
| 2472 |
+
if (C == LN_SMALL_C) {
|
| 2473 |
+
add_layer_norm_tmix_mix6_f16_scalar_stats_kernel<LN_SMALL_THREADS><<<static_cast<int>(rows), LN_SMALL_THREADS, 0, stream>>>(
|
| 2474 |
+
x.data_ptr<dtype>(),
|
| 2475 |
+
residual.data_ptr<dtype>(),
|
| 2476 |
+
shift_state.data_ptr<dtype>(),
|
| 2477 |
+
weight.data_ptr<dtype>(),
|
| 2478 |
+
bias.data_ptr<dtype>(),
|
| 2479 |
+
x_r.data_ptr<dtype>(),
|
| 2480 |
+
x_w.data_ptr<dtype>(),
|
| 2481 |
+
x_k.data_ptr<dtype>(),
|
| 2482 |
+
x_v.data_ptr<dtype>(),
|
| 2483 |
+
x_a.data_ptr<dtype>(),
|
| 2484 |
+
x_g.data_ptr<dtype>(),
|
| 2485 |
+
x_out.data_ptr<dtype>(),
|
| 2486 |
+
out_r.data_ptr<dtype>(),
|
| 2487 |
+
out_w.data_ptr<dtype>(),
|
| 2488 |
+
out_k.data_ptr<dtype>(),
|
| 2489 |
+
out_v.data_ptr<dtype>(),
|
| 2490 |
+
out_a.data_ptr<dtype>(),
|
| 2491 |
+
out_g.data_ptr<dtype>(),
|
| 2492 |
+
rows,
|
| 2493 |
+
static_cast<float>(eps));
|
| 2494 |
+
} else {
|
| 2495 |
+
add_layer_norm_tmix_mix6_f16_generic_kernel<LN_THREADS><<<static_cast<int>(rows), LN_THREADS, 0, stream>>>(
|
| 2496 |
+
x.data_ptr<dtype>(),
|
| 2497 |
+
residual.data_ptr<dtype>(),
|
| 2498 |
+
shift_state.data_ptr<dtype>(),
|
| 2499 |
+
weight.data_ptr<dtype>(),
|
| 2500 |
+
bias.data_ptr<dtype>(),
|
| 2501 |
+
x_r.data_ptr<dtype>(),
|
| 2502 |
+
x_w.data_ptr<dtype>(),
|
| 2503 |
+
x_k.data_ptr<dtype>(),
|
| 2504 |
+
x_v.data_ptr<dtype>(),
|
| 2505 |
+
x_a.data_ptr<dtype>(),
|
| 2506 |
+
x_g.data_ptr<dtype>(),
|
| 2507 |
+
x_out.data_ptr<dtype>(),
|
| 2508 |
+
out_r.data_ptr<dtype>(),
|
| 2509 |
+
out_w.data_ptr<dtype>(),
|
| 2510 |
+
out_k.data_ptr<dtype>(),
|
| 2511 |
+
out_v.data_ptr<dtype>(),
|
| 2512 |
+
out_a.data_ptr<dtype>(),
|
| 2513 |
+
out_g.data_ptr<dtype>(),
|
| 2514 |
+
rows,
|
| 2515 |
+
static_cast<int>(C),
|
| 2516 |
+
static_cast<float>(eps));
|
| 2517 |
+
}
|
| 2518 |
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 2519 |
return {x_out, out_r, out_w, out_k, out_v, out_a, out_g};
|
| 2520 |
}
|
rwkv7_fast_v3a.py
CHANGED
|
@@ -25,7 +25,7 @@ ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"}
|
|
| 25 |
LOWRANK_SUFFIXES = ("att.w1", "att.w2", "att.a1", "att.a2", "att.g1", "att.g2", "att.v1", "att.v2")
|
| 26 |
LOWRANK_IN_ROWS_T = 7
|
| 27 |
LOWRANK_OUT_ROWS_T = 4
|
| 28 |
-
|
| 29 |
CMIX_NOFC_ROW20_MAX_T = 5
|
| 30 |
CMIX_NOFC_T512_MIN_ROWS = 8
|
| 31 |
LN1_TMIX_FUSE = True
|
|
@@ -36,8 +36,9 @@ CMIX_ROWS2_NOFC = "rows2_nofc"
|
|
| 36 |
CMIX_DENSE = "dense"
|
| 37 |
|
| 38 |
def main() -> None:
|
| 39 |
-
global WKV_MODE, EMB_DEVICE, RKV_MODE, CMIX_SPARSE, LOWRANK_WEIGHT, ORIG_LINEAR_GROUPS
|
| 40 |
parser = argparse.ArgumentParser()
|
|
|
|
| 41 |
parser.add_argument("--warmup", type=int, default=1)
|
| 42 |
parser.add_argument("--iters", type=int, default=3)
|
| 43 |
parser.add_argument("--cases", default="1x1,1x2,1x4,1x8,1x16,1x32,1x64,1x128,1x256,2x1,4x1,8x1,16x1,32x1,64x1,128x1,256x1,2x2,4x4,8x8,16x16") # try 1x1024 1024x1 32x32 for extreme tps
|
|
@@ -54,6 +55,7 @@ def main() -> None:
|
|
| 54 |
parser.add_argument("--orig-linear-groups", default="att_c2c,ffn_key,head") # comma list: none, att_c2c, ffn_key, head
|
| 55 |
args = parser.parse_args()
|
| 56 |
|
|
|
|
| 57 |
WKV_MODE = args.wkv
|
| 58 |
EMB_DEVICE = args.emb
|
| 59 |
RKV_MODE = args.batched_rkv
|
|
@@ -62,7 +64,7 @@ def main() -> None:
|
|
| 62 |
ORIG_LINEAR_GROUPS = parse_orig_linear_groups(args.orig_linear_groups)
|
| 63 |
groups = ",".join(sorted(ORIG_LINEAR_GROUPS)) if ORIG_LINEAR_GROUPS else "none"
|
| 64 |
log(f"start model={MODEL_PATH} wkv={WKV_MODE} emb={EMB_DEVICE} batched_rkv={RKV_MODE} cmix_sparse={CMIX_SPARSE} lowrank_weight={LOWRANK_WEIGHT} orig_linear_groups={groups}")
|
| 65 |
-
log(f"fixed fast path: ln=v3a linear=v3a/splitk lowrank={LOWRANK_IN_ROWS_T}/{LOWRANK_OUT_ROWS_T} nofc_rows
|
| 66 |
load_extensions(WKV_MODE)
|
| 67 |
model = RWKV7()
|
| 68 |
if args.eval_json:
|
|
@@ -97,7 +99,7 @@ def select_path(B: int, T: int) -> PathConfig:
|
|
| 97 |
if CMIX_SPARSE == "off":
|
| 98 |
cmix_mode = CMIX_DENSE
|
| 99 |
elif CMIX_SPARSE == "no-fc":
|
| 100 |
-
use_nofc = rows <=
|
| 101 |
cmix_mode = CMIX_B1T1_NOFC if rows == 1 else (CMIX_ROWS2_NOFC if use_nofc else CMIX_DENSE)
|
| 102 |
elif rows == 1:
|
| 103 |
cmix_mode = CMIX_B1T1_SPARSE
|
|
@@ -115,6 +117,12 @@ def select_path(B: int, T: int) -> PathConfig:
|
|
| 115 |
use_batched_rkv = False
|
| 116 |
return PathConfig(rows=rows, use_batched_rkv=use_batched_rkv, cmix_mode=cmix_mode)
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
def parse_orig_linear_groups(text: str) -> set[str]:
|
| 119 |
groups = {x.strip() for x in text.replace(",", " ").split() if x.strip()}
|
| 120 |
if not groups or groups == {"none"}:
|
|
@@ -130,6 +138,12 @@ def use_orig_linear(group: str) -> bool:
|
|
| 130 |
def is_lowrank_weight(key: str) -> bool:
    """Return True when *key* ends with any suffix listed in LOWRANK_SUFFIXES."""
    return any(key.endswith(suffix) for suffix in LOWRANK_SUFFIXES)
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def is_att_c2c_weight(key: str) -> bool:
    """Return True for attention receptance/key/value/output weight keys.

    A key qualifies when it contains ".att." and ends with one of the four
    weight-name suffixes checked below.
    """
    if ".att." not in key:
        return False
    att_suffixes = ("receptance.weight", "key.weight", "value.weight", "output.weight")
    return key.endswith(att_suffixes)
|
| 135 |
|
|
@@ -173,6 +187,7 @@ class RWKV7:
|
|
| 173 |
C, V = H * N, z["emb.weight"].shape[0]
|
| 174 |
assert N == HEAD_SIZE
|
| 175 |
log(f"detected model C={C} H={H} N={N} V={V}")
|
|
|
|
| 176 |
|
| 177 |
emb_src = z["emb.weight"].squeeze()
|
| 178 |
ln0_w_src = z["blocks.0.ln0.weight"].squeeze()
|
|
@@ -271,7 +286,7 @@ class RWKV7:
|
|
| 271 |
dev.copy_(host.view(B,T,C), non_blocking=True)
|
| 272 |
return dev
|
| 273 |
|
| 274 |
-
def forward_from_x(self, x: torch.Tensor, state: list[torch.Tensor], path: PathConfig, all_logits: bool = False) -> torch.Tensor:
|
| 275 |
z = self.z
|
| 276 |
B, T, _ = x.shape
|
| 277 |
v_first = x
|
|
@@ -302,7 +317,11 @@ class RWKV7:
|
|
| 302 |
else:
|
| 303 |
x, xx = self.add_ln(x, xx, z[p_next+"ln1.weight"], z[p_next+"ln1.bias"])
|
| 304 |
elif not all_logits:
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
torch.ops.rwkv7_v3a_ops.advance_i32(state[2], T) # !!! IMPORTANT FOR WKV16 DITHERING !!!
|
| 307 |
return self.linear_head(x)
|
| 308 |
else:
|
|
@@ -323,6 +342,14 @@ class RWKV7:
|
|
| 323 |
x = self.embed(tokens)
|
| 324 |
return self.forward_from_x(x, state, path, all_logits=True)
|
| 325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
def tmix(self, layer: int, x: torch.Tensor, shift_state: torch.Tensor, wkv_state: torch.Tensor, elapsed_t: torch.Tensor, v_first: torch.Tensor, p: str, path: PathConfig, pre_mix=None) -> tuple[torch.Tensor, torch.Tensor]:
|
| 327 |
z = self.z
|
| 328 |
ops = torch.ops.rwkv7_fast_ops_fp16
|
|
@@ -351,11 +378,11 @@ class RWKV7:
|
|
| 351 |
v = self.linear_orig_layout(xv, z[p+"value.weight"], path, "att_c2c")
|
| 352 |
|
| 353 |
v1 = None
|
| 354 |
-
if LOWRANK_WEIGHT != "orig" and path.rows
|
| 355 |
w1, a1, g1, v1 = torch.ops.rwkv7_v3a_ops.linear_wagv_rank_in_f16(
|
| 356 |
xw.contiguous(), xa.contiguous(), xg.contiguous(), xv.contiguous(),
|
| 357 |
z[p+"w1.t"], z[p+"a1.t"], z[p+"g1.t"], z[p+"v1.t"])
|
| 358 |
-
elif LOWRANK_WEIGHT != "orig" and path.rows
|
| 359 |
w1, a1, g1 = torch.ops.rwkv7_v3a_ops.linear_wag_rank_in_f16(
|
| 360 |
xw.contiguous(), xa.contiguous(), xg.contiguous(), z[p+"w1.t"], z[p+"a1.t"], z[p+"g1.t"])
|
| 361 |
else:
|
|
@@ -363,13 +390,13 @@ class RWKV7:
|
|
| 363 |
a1 = self.linear_rank_in(xa, z.get(p+"a1"), z.get(p+"a1.t"), path.rows)
|
| 364 |
g1 = self.linear_rank_in(xg, z.get(p+"g1"), z.get(p+"g1.t"), path.rows)
|
| 365 |
v_done = False
|
| 366 |
-
if LOWRANK_WEIGHT != "orig" and path.rows
|
| 367 |
w, a, g, v = torch.ops.rwkv7_v3a_ops.linear_wagv_rank_out_f16(
|
| 368 |
w1.contiguous(), a1.contiguous(), g1.contiguous(), v1.contiguous(),
|
| 369 |
z[p+"w2.t"], z[p+"a2.t"], z[p+"g2.t"], z[p+"v2.t"],
|
| 370 |
v.contiguous(), v_first.contiguous(), z[p+"v0"])
|
| 371 |
v_done = True
|
| 372 |
-
elif LOWRANK_WEIGHT != "orig" and path.rows
|
| 373 |
w, a, g = torch.ops.rwkv7_v3a_ops.linear_wag_rank_out_f16(
|
| 374 |
w1.contiguous(), a1.contiguous(), g1.contiguous(), z[p+"w2.t"], z[p+"a2.t"], z[p+"g2.t"])
|
| 375 |
else:
|
|
@@ -381,7 +408,7 @@ class RWKV7:
|
|
| 381 |
if layer == 0:
|
| 382 |
v_first = v
|
| 383 |
elif not v_done:
|
| 384 |
-
if LOWRANK_WEIGHT != "orig" and path.rows
|
| 385 |
if v1 is None:
|
| 386 |
v1 = self.linear_rank_in(xv, z.get(p+"v1"), z.get(p+"v1.t"), path.rows)
|
| 387 |
v = torch.ops.rwkv7_v3a_ops.linear_t_vres_f16(v1.contiguous(), z[p+"v2.t"], v.contiguous(), v_first.contiguous(), z[p+"v0"])
|
|
@@ -447,28 +474,102 @@ class RWKV7:
|
|
| 447 |
return self.linear(x, weight)
|
| 448 |
if path.rows == 1:
|
| 449 |
if group == "ffn_key":
|
| 450 |
-
|
| 451 |
-
|
|
|
|
|
|
|
| 452 |
if path.rows == 2:
|
| 453 |
if group == "att_c2c":
|
| 454 |
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
|
| 455 |
if group == "ffn_key":
|
| 456 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
|
| 458 |
if path.rows == 3:
|
| 459 |
if group == "head":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 3, 2)
|
| 461 |
if group == "ffn_key":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 463 |
if group == "att_c2c":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
|
| 465 |
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_cfg_f16(x.contiguous(), weight, 64, 3, 4)
|
| 466 |
if path.rows == 4:
|
| 467 |
if group == "ffn_key":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 469 |
if group == "att_c2c":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
|
| 471 |
if group == "head":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
if path.rows >= 1024:
|
| 473 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 0)
|
| 474 |
if path.rows >= 512:
|
|
@@ -492,6 +593,40 @@ class RWKV7:
|
|
| 492 |
if path.rows >= 72:
|
| 493 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 2)
|
| 494 |
if group == "att_c2c":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
if path.rows >= 1024:
|
| 496 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 4)
|
| 497 |
if path.rows >= 768:
|
|
@@ -523,6 +658,39 @@ class RWKV7:
|
|
| 523 |
if path.rows >= 5:
|
| 524 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
|
| 525 |
if group == "ffn_key":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
if path.rows >= 1024:
|
| 527 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 528 |
if path.rows >= 768:
|
|
@@ -559,12 +727,12 @@ class RWKV7:
|
|
| 559 |
return self.linear_lowrank_orig(x, weight) if weight is not None else self.linear_t_orig(x, weight_t)
|
| 560 |
|
| 561 |
def linear_rank_out(self, x: torch.Tensor, weight: torch.Tensor, weight_t: torch.Tensor, rows: int) -> torch.Tensor:
|
| 562 |
-
if weight_t is not None and rows <= LOWRANK_OUT_ROWS_T:
|
| 563 |
return torch.ops.rwkv7_v3a_ops.linear_t_f16(x.contiguous(), weight_t)
|
| 564 |
return self.linear_lowrank_orig(x, weight) if weight is not None else self.linear_t_orig(x, weight_t)
|
| 565 |
|
| 566 |
def linear_rank_out_act(self, x: torch.Tensor, weight: torch.Tensor, weight_t: torch.Tensor, rows: int, act: int) -> torch.Tensor:
|
| 567 |
-
if weight_t is not None and rows <= LOWRANK_OUT_ROWS_T:
|
| 568 |
return torch.ops.rwkv7_v3a_ops.linear_t_act_f16(x.contiguous(), weight_t, act)
|
| 569 |
ops = torch.ops.rwkv7_fast_ops_fp16
|
| 570 |
x = ops.act_tanh(x.contiguous()) if act == 1 else ops.act_sigmoid(x.contiguous())
|
|
|
|
| 25 |
LOWRANK_SUFFIXES = ("att.w1", "att.w2", "att.a1", "att.a2", "att.g1", "att.g2", "att.v1", "att.v2")
|
| 26 |
LOWRANK_IN_ROWS_T = 7
|
| 27 |
LOWRANK_OUT_ROWS_T = 4
|
| 28 |
+
LOWRANK_FUSED_MIN_C = 1024
|
| 29 |
CMIX_NOFC_ROW20_MAX_T = 5
|
| 30 |
CMIX_NOFC_T512_MIN_ROWS = 8
|
| 31 |
LN1_TMIX_FUSE = True
|
|
|
|
| 36 |
CMIX_DENSE = "dense"
|
| 37 |
|
| 38 |
def main() -> None:
|
| 39 |
+
global MODEL_PATH, WKV_MODE, EMB_DEVICE, RKV_MODE, CMIX_SPARSE, LOWRANK_WEIGHT, ORIG_LINEAR_GROUPS
|
| 40 |
parser = argparse.ArgumentParser()
|
| 41 |
+
parser.add_argument("--model", default=MODEL_PATH)
|
| 42 |
parser.add_argument("--warmup", type=int, default=1)
|
| 43 |
parser.add_argument("--iters", type=int, default=3)
|
| 44 |
parser.add_argument("--cases", default="1x1,1x2,1x4,1x8,1x16,1x32,1x64,1x128,1x256,2x1,4x1,8x1,16x1,32x1,64x1,128x1,256x1,2x2,4x4,8x8,16x16") # try 1x1024 1024x1 32x32 for extreme tps
|
|
|
|
| 55 |
parser.add_argument("--orig-linear-groups", default="att_c2c,ffn_key,head") # comma list: none, att_c2c, ffn_key, head
|
| 56 |
args = parser.parse_args()
|
| 57 |
|
| 58 |
+
MODEL_PATH = args.model
|
| 59 |
WKV_MODE = args.wkv
|
| 60 |
EMB_DEVICE = args.emb
|
| 61 |
RKV_MODE = args.batched_rkv
|
|
|
|
| 64 |
ORIG_LINEAR_GROUPS = parse_orig_linear_groups(args.orig_linear_groups)
|
| 65 |
groups = ",".join(sorted(ORIG_LINEAR_GROUPS)) if ORIG_LINEAR_GROUPS else "none"
|
| 66 |
log(f"start model={MODEL_PATH} wkv={WKV_MODE} emb={EMB_DEVICE} batched_rkv={RKV_MODE} cmix_sparse={CMIX_SPARSE} lowrank_weight={LOWRANK_WEIGHT} orig_linear_groups={groups}")
|
| 67 |
+
log(f"fixed fast path: ln=v3a linear=v3a/splitk lowrank={LOWRANK_IN_ROWS_T}/{LOWRANK_OUT_ROWS_T} nofc_rows=by_C row20_t=by_C nofc_t512_rows>={CMIX_NOFC_T512_MIN_ROWS}")
|
| 68 |
load_extensions(WKV_MODE)
|
| 69 |
model = RWKV7()
|
| 70 |
if args.eval_json:
|
|
|
|
| 99 |
if CMIX_SPARSE == "off":
|
| 100 |
cmix_mode = CMIX_DENSE
|
| 101 |
elif CMIX_SPARSE == "no-fc":
|
| 102 |
+
use_nofc = rows <= cmix_nofc_max_rows() or (rows == 20 and T <= cmix_nofc_row20_max_t())
|
| 103 |
cmix_mode = CMIX_B1T1_NOFC if rows == 1 else (CMIX_ROWS2_NOFC if use_nofc else CMIX_DENSE)
|
| 104 |
elif rows == 1:
|
| 105 |
cmix_mode = CMIX_B1T1_SPARSE
|
|
|
|
| 117 |
use_batched_rkv = False
|
| 118 |
return PathConfig(rows=rows, use_batched_rkv=use_batched_rkv, cmix_mode=cmix_mode)
|
| 119 |
|
| 120 |
+
def cmix_nofc_max_rows() -> int:
    """Return the largest `rows` value that select_path accepts for the no-fc cmix path.

    select_path compares `rows <= cmix_nofc_max_rows()` when CMIX_SPARSE is "no-fc".
    """
    max_rows = 19
    return max_rows
|
| 122 |
+
|
| 123 |
+
def cmix_nofc_row20_max_t() -> int:
    """Return the T limit that select_path applies to the rows == 20 no-fc special case.

    The value is the module constant CMIX_NOFC_ROW20_MAX_T.
    """
    max_t = CMIX_NOFC_ROW20_MAX_T
    return max_t
|
| 125 |
+
|
| 126 |
def parse_orig_linear_groups(text: str) -> set[str]:
|
| 127 |
groups = {x.strip() for x in text.replace(",", " ").split() if x.strip()}
|
| 128 |
if not groups or groups == {"none"}:
|
|
|
|
| 138 |
def is_lowrank_weight(key: str) -> bool:
    """Return True when *key* ends with any suffix listed in LOWRANK_SUFFIXES."""
    return any(key.endswith(suffix) for suffix in LOWRANK_SUFFIXES)
|
| 140 |
|
| 141 |
+
def can_use_lowrank_fused(rows: int) -> bool:
    """Return True when the fused low-rank input ops may be used for *rows*.

    Requires the module-level C to be at least LOWRANK_FUSED_MIN_C and
    *rows* to be at most LOWRANK_IN_ROWS_T.
    """
    if C < LOWRANK_FUSED_MIN_C:
        return False
    return rows <= LOWRANK_IN_ROWS_T
|
| 143 |
+
|
| 144 |
+
def can_use_lowrank_out_fused(rows: int) -> bool:
    """Return True when the fused low-rank output ops may be used for *rows*.

    Requires the module-level C to be at least LOWRANK_FUSED_MIN_C and
    *rows* to be at most LOWRANK_OUT_ROWS_T.
    """
    if C < LOWRANK_FUSED_MIN_C:
        return False
    return rows <= LOWRANK_OUT_ROWS_T
|
| 146 |
+
|
| 147 |
def is_att_c2c_weight(key: str) -> bool:
    """Return True for attention receptance/key/value/output weight keys.

    A key qualifies when it contains ".att." and ends with one of the four
    weight-name suffixes checked below.
    """
    if ".att." not in key:
        return False
    att_suffixes = ("receptance.weight", "key.weight", "value.weight", "output.weight")
    return key.endswith(att_suffixes)
|
| 149 |
|
|
|
|
| 187 |
C, V = H * N, z["emb.weight"].shape[0]
|
| 188 |
assert N == HEAD_SIZE
|
| 189 |
log(f"detected model C={C} H={H} N={N} V={V}")
|
| 190 |
+
log(f"cmix no-fc path: rows<={cmix_nofc_max_rows()} row20_t<={cmix_nofc_row20_max_t()}")
|
| 191 |
|
| 192 |
emb_src = z["emb.weight"].squeeze()
|
| 193 |
ln0_w_src = z["blocks.0.ln0.weight"].squeeze()
|
|
|
|
| 286 |
dev.copy_(host.view(B,T,C), non_blocking=True)
|
| 287 |
return dev
|
| 288 |
|
| 289 |
+
def forward_from_x(self, x: torch.Tensor, state: list[torch.Tensor], path: PathConfig, all_logits: bool = False, last_indices=None) -> torch.Tensor:
|
| 290 |
z = self.z
|
| 291 |
B, T, _ = x.shape
|
| 292 |
v_first = x
|
|
|
|
| 317 |
else:
|
| 318 |
x, xx = self.add_ln(x, xx, z[p_next+"ln1.weight"], z[p_next+"ln1.bias"])
|
| 319 |
elif not all_logits:
|
| 320 |
+
if last_indices is not None:
|
| 321 |
+
x = self.ln(self.add(x, xx), z["ln_out.weight"], z["ln_out.bias"])
|
| 322 |
+
x = x[torch.arange(B, device=x.device), last_indices].contiguous()
|
| 323 |
+
else:
|
| 324 |
+
x = self.add_last_ln(x, xx, z["ln_out.weight"], z["ln_out.bias"])
|
| 325 |
torch.ops.rwkv7_v3a_ops.advance_i32(state[2], T) # !!! IMPORTANT FOR WKV16 DITHERING !!!
|
| 326 |
return self.linear_head(x)
|
| 327 |
else:
|
|
|
|
| 342 |
x = self.embed(tokens)
|
| 343 |
return self.forward_from_x(x, state, path, all_logits=True)
|
| 344 |
|
| 345 |
+
def forward_last_at(self, tokens: torch.Tensor, state: list[torch.Tensor], last_indices: torch.Tensor) -> torch.Tensor:
    """Run the model on *tokens* and return head outputs at one chosen position per row.

    Args:
        tokens: token ids; a 1-D tensor is treated as a single-row batch.
        state: model state list, passed through to forward_from_x.
        last_indices: per-row position indices, forwarded to forward_from_x,
            which uses them to pick one time step per batch row before the head.

    Returns:
        Whatever forward_from_x returns for the selected positions.
    """
    # Promote a flat token sequence to shape (1, T) so B and T are always defined.
    batch_tokens = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens
    n_batch, n_steps = batch_tokens.shape
    chosen_path = select_path(n_batch, n_steps)
    hidden = self.embed(batch_tokens)
    return self.forward_from_x(hidden, state, chosen_path, last_indices=last_indices)
|
| 352 |
+
|
| 353 |
def tmix(self, layer: int, x: torch.Tensor, shift_state: torch.Tensor, wkv_state: torch.Tensor, elapsed_t: torch.Tensor, v_first: torch.Tensor, p: str, path: PathConfig, pre_mix=None) -> tuple[torch.Tensor, torch.Tensor]:
|
| 354 |
z = self.z
|
| 355 |
ops = torch.ops.rwkv7_fast_ops_fp16
|
|
|
|
| 378 |
v = self.linear_orig_layout(xv, z[p+"value.weight"], path, "att_c2c")
|
| 379 |
|
| 380 |
v1 = None
|
| 381 |
+
if LOWRANK_WEIGHT != "orig" and can_use_lowrank_fused(path.rows) and can_use_lowrank_out_fused(path.rows) and layer != 0:
|
| 382 |
w1, a1, g1, v1 = torch.ops.rwkv7_v3a_ops.linear_wagv_rank_in_f16(
|
| 383 |
xw.contiguous(), xa.contiguous(), xg.contiguous(), xv.contiguous(),
|
| 384 |
z[p+"w1.t"], z[p+"a1.t"], z[p+"g1.t"], z[p+"v1.t"])
|
| 385 |
+
elif LOWRANK_WEIGHT != "orig" and can_use_lowrank_fused(path.rows):
|
| 386 |
w1, a1, g1 = torch.ops.rwkv7_v3a_ops.linear_wag_rank_in_f16(
|
| 387 |
xw.contiguous(), xa.contiguous(), xg.contiguous(), z[p+"w1.t"], z[p+"a1.t"], z[p+"g1.t"])
|
| 388 |
else:
|
|
|
|
| 390 |
a1 = self.linear_rank_in(xa, z.get(p+"a1"), z.get(p+"a1.t"), path.rows)
|
| 391 |
g1 = self.linear_rank_in(xg, z.get(p+"g1"), z.get(p+"g1.t"), path.rows)
|
| 392 |
v_done = False
|
| 393 |
+
if LOWRANK_WEIGHT != "orig" and can_use_lowrank_out_fused(path.rows) and layer != 0 and v1 is not None:
|
| 394 |
w, a, g, v = torch.ops.rwkv7_v3a_ops.linear_wagv_rank_out_f16(
|
| 395 |
w1.contiguous(), a1.contiguous(), g1.contiguous(), v1.contiguous(),
|
| 396 |
z[p+"w2.t"], z[p+"a2.t"], z[p+"g2.t"], z[p+"v2.t"],
|
| 397 |
v.contiguous(), v_first.contiguous(), z[p+"v0"])
|
| 398 |
v_done = True
|
| 399 |
+
elif LOWRANK_WEIGHT != "orig" and can_use_lowrank_out_fused(path.rows):
|
| 400 |
w, a, g = torch.ops.rwkv7_v3a_ops.linear_wag_rank_out_f16(
|
| 401 |
w1.contiguous(), a1.contiguous(), g1.contiguous(), z[p+"w2.t"], z[p+"a2.t"], z[p+"g2.t"])
|
| 402 |
else:
|
|
|
|
| 408 |
if layer == 0:
|
| 409 |
v_first = v
|
| 410 |
elif not v_done:
|
| 411 |
+
if LOWRANK_WEIGHT != "orig" and can_use_lowrank_out_fused(path.rows):
|
| 412 |
if v1 is None:
|
| 413 |
v1 = self.linear_rank_in(xv, z.get(p+"v1"), z.get(p+"v1.t"), path.rows)
|
| 414 |
v = torch.ops.rwkv7_v3a_ops.linear_t_vres_f16(v1.contiguous(), z[p+"v2.t"], v.contiguous(), v_first.contiguous(), z[p+"v0"])
|
|
|
|
| 474 |
return self.linear(x, weight)
|
| 475 |
if path.rows == 1:
|
| 476 |
if group == "ffn_key":
|
| 477 |
+
if C == 2560:
|
| 478 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, True)
|
| 479 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, C <= 1024)
|
| 480 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, group != "att_c2c" or C < 2048)
|
| 481 |
if path.rows == 2:
|
| 482 |
if group == "att_c2c":
|
| 483 |
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
|
| 484 |
if group == "ffn_key":
|
| 485 |
+
if C == 2560:
|
| 486 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, False)
|
| 487 |
+
if C < 4096:
|
| 488 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
|
| 489 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, False)
|
| 490 |
+
if group == "head" and C == 2560:
|
| 491 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 128, 2, False)
|
| 492 |
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_exact_f16(x.contiguous(), weight, 64, 2, True)
|
| 493 |
if path.rows == 3:
|
| 494 |
if group == "head":
|
| 495 |
+
if C <= 2048:
|
| 496 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 497 |
+
if C == 2560:
|
| 498 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 499 |
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 3, 2)
|
| 500 |
if group == "ffn_key":
|
| 501 |
+
if C <= 1024:
|
| 502 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_cfg_f16(x.contiguous(), weight, 64, 3, 4)
|
| 503 |
+
if C == 2048:
|
| 504 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 505 |
+
if C == 2560:
|
| 506 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 507 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 508 |
if group == "att_c2c":
|
| 509 |
+
if C == 768:
|
| 510 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 1, 2)
|
| 511 |
+
if C == 1024:
|
| 512 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 2, 2)
|
| 513 |
+
if C == 2048:
|
| 514 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 3, 4)
|
| 515 |
+
if C == 2560:
|
| 516 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 3, 2)
|
| 517 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
|
| 518 |
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_cfg_f16(x.contiguous(), weight, 64, 3, 4)
|
| 519 |
if path.rows == 4:
|
| 520 |
if group == "ffn_key":
|
| 521 |
+
if C <= 1024:
|
| 522 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_cfg_f16(x.contiguous(), weight, 64, 2, 4)
|
| 523 |
+
if C == 2048:
|
| 524 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 525 |
+
if C == 2560:
|
| 526 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 527 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 528 |
if group == "att_c2c":
|
| 529 |
+
if C <= 1024:
|
| 530 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 2, 2)
|
| 531 |
+
if C == 2048:
|
| 532 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 4, 2)
|
| 533 |
+
if C == 2560:
|
| 534 |
+
return torch.ops.rwkv7_v3a_ops.linear_orig_rows_f16(x.contiguous(), weight, 4, 2)
|
| 535 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
|
| 536 |
if group == "head":
|
| 537 |
+
if C == 768:
|
| 538 |
+
if 192 <= path.rows < 256:
|
| 539 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 3)
|
| 540 |
+
if 96 <= path.rows < 160:
|
| 541 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 1)
|
| 542 |
+
if C == 1024:
|
| 543 |
+
if 256 <= path.rows < 384:
|
| 544 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 545 |
+
if 192 <= path.rows < 256:
|
| 546 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
|
| 547 |
+
if 96 <= path.rows < 160:
|
| 548 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 1)
|
| 549 |
+
if C == 2048:
|
| 550 |
+
if 256 <= path.rows < 384:
|
| 551 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 0)
|
| 552 |
+
if 192 <= path.rows < 256:
|
| 553 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 6)
|
| 554 |
+
if 128 <= path.rows < 160:
|
| 555 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 1)
|
| 556 |
+
if 96 <= path.rows < 112:
|
| 557 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 558 |
+
if C == 2560:
|
| 559 |
+
if path.rows >= 256:
|
| 560 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 0)
|
| 561 |
+
if path.rows >= 192:
|
| 562 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 5)
|
| 563 |
+
if path.rows >= 160:
|
| 564 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 5)
|
| 565 |
+
if path.rows >= 128:
|
| 566 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 1)
|
| 567 |
+
if path.rows >= 96:
|
| 568 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 0)
|
| 569 |
+
if path.rows >= 80:
|
| 570 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 571 |
+
if path.rows >= 72:
|
| 572 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 1)
|
| 573 |
if path.rows >= 1024:
|
| 574 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 0)
|
| 575 |
if path.rows >= 512:
|
|
|
|
| 593 |
if path.rows >= 72:
|
| 594 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 2)
|
| 595 |
if group == "att_c2c":
|
| 596 |
+
if C == 2560 and 17 <= path.rows <= 20:
|
| 597 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 598 |
+
if C == 768:
|
| 599 |
+
if 256 <= path.rows < 384:
|
| 600 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 1)
|
| 601 |
+
if 96 <= path.rows < 112:
|
| 602 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 3)
|
| 603 |
+
if C == 1024:
|
| 604 |
+
if 256 <= path.rows < 384:
|
| 605 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 0)
|
| 606 |
+
if 96 <= path.rows < 112:
|
| 607 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 6)
|
| 608 |
+
if C == 2048:
|
| 609 |
+
if 256 <= path.rows < 384:
|
| 610 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 3)
|
| 611 |
+
if 192 <= path.rows < 256:
|
| 612 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 0)
|
| 613 |
+
if 96 <= path.rows < 112:
|
| 614 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 4)
|
| 615 |
+
if C == 2560:
|
| 616 |
+
if path.rows >= 256:
|
| 617 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 1)
|
| 618 |
+
if path.rows >= 160:
|
| 619 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
|
| 620 |
+
if path.rows >= 128:
|
| 621 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 2)
|
| 622 |
+
if path.rows >= 112:
|
| 623 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 3)
|
| 624 |
+
if path.rows >= 96:
|
| 625 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 2)
|
| 626 |
+
if path.rows >= 72:
|
| 627 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 2)
|
| 628 |
+
if path.rows >= 5:
|
| 629 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 630 |
if path.rows >= 1024:
|
| 631 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 4)
|
| 632 |
if path.rows >= 768:
|
|
|
|
| 658 |
if path.rows >= 5:
|
| 659 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 2)
|
| 660 |
if group == "ffn_key":
|
| 661 |
+
if C == 2560 and 17 <= path.rows <= 20:
|
| 662 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 663 |
+
if C == 768:
|
| 664 |
+
if 256 <= path.rows < 384:
|
| 665 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 666 |
+
if 96 <= path.rows < 112:
|
| 667 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 668 |
+
if C == 1024:
|
| 669 |
+
if 256 <= path.rows < 384:
|
| 670 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 2)
|
| 671 |
+
if 192 <= path.rows < 256:
|
| 672 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 673 |
+
if 96 <= path.rows < 160:
|
| 674 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 2)
|
| 675 |
+
if C == 2048 and 128 <= path.rows < 160:
|
| 676 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 3)
|
| 677 |
+
if C == 2560:
|
| 678 |
+
if path.rows >= 192:
|
| 679 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 5)
|
| 680 |
+
if path.rows >= 160:
|
| 681 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 4)
|
| 682 |
+
if path.rows >= 128:
|
| 683 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 5)
|
| 684 |
+
if path.rows >= 112:
|
| 685 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 4)
|
| 686 |
+
if path.rows >= 96:
|
| 687 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 128, 4)
|
| 688 |
+
if path.rows >= 80:
|
| 689 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 3)
|
| 690 |
+
if path.rows >= 72:
|
| 691 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 32, 4)
|
| 692 |
+
if path.rows >= 3:
|
| 693 |
+
return torch.ops.rwkv7_v3a_ops.linear_f16_orig(x.contiguous(), weight)
|
| 694 |
if path.rows >= 1024:
|
| 695 |
return torch.ops.rwkv7_v3a_ops.linear_f16_orig_lt_cfg(x.contiguous(), weight, 0, 0)
|
| 696 |
if path.rows >= 768:
|
|
|
|
| 727 |
return self.linear_lowrank_orig(x, weight) if weight is not None else self.linear_t_orig(x, weight_t)
|
| 728 |
|
| 729 |
def linear_rank_out(self, x: torch.Tensor, weight: torch.Tensor, weight_t: torch.Tensor, rows: int) -> torch.Tensor:
    """Run a linear op on ``x``, picking one of three kernels.

    Selection order:
      1. ``torch.ops.rwkv7_v3a_ops.linear_t_f16(x.contiguous(), weight_t)``
         when ``weight_t`` is provided, ``C >= LOWRANK_FUSED_MIN_C`` and
         ``rows <= LOWRANK_OUT_ROWS_T``.
      2. ``self.linear_lowrank_orig(x, weight)`` when ``weight`` is provided.
      3. ``self.linear_t_orig(x, weight_t)`` otherwise.

    Args:
        x: Input tensor; made contiguous only on the first path.
        weight: Weight for ``linear_lowrank_orig``; may be ``None``.
        weight_t: Weight for the ``linear_t_*`` kernels; may be ``None``
            (presumably the transposed layout of ``weight`` - confirm).
        rows: Row count compared against ``LOWRANK_OUT_ROWS_T``.

    Returns:
        The tensor produced by whichever kernel was selected.
    """
    # NOTE(review): ``C``, ``LOWRANK_FUSED_MIN_C`` and ``LOWRANK_OUT_ROWS_T``
    # are not bound in this method; they are resolved from an outer scope
    # that is not visible here - confirm where ``C`` comes from.
    # The ``and`` chain keeps the original short-circuit order, so ``C`` is
    # only read when ``weight_t`` is not None.
    use_t_kernel = (
        weight_t is not None
        and C >= LOWRANK_FUSED_MIN_C
        and rows <= LOWRANK_OUT_ROWS_T
    )
    if use_t_kernel:
        return torch.ops.rwkv7_v3a_ops.linear_t_f16(x.contiguous(), weight_t)
    if weight is None:
        return self.linear_t_orig(x, weight_t)
    return self.linear_lowrank_orig(x, weight)
|
| 733 |
|
| 734 |
def linear_rank_out_act(self, x: torch.Tensor, weight: torch.Tensor, weight_t: torch.Tensor, rows: int, act: int) -> torch.Tensor:
|
| 735 |
+
if weight_t is not None and C >= LOWRANK_FUSED_MIN_C and rows <= LOWRANK_OUT_ROWS_T:
|
| 736 |
return torch.ops.rwkv7_v3a_ops.linear_t_act_f16(x.contiguous(), weight_t, act)
|
| 737 |
ops = torch.ops.rwkv7_fast_ops_fp16
|
| 738 |
x = ops.act_tanh(x.contiguous()) if act == 1 else ops.act_sigmoid(x.contiguous())
|