// Hosting-page header residue captured with this listing ("Spaces: Running on T4"); not part of the CUDA source.
namespace {
// Element type used by the v1 "clone" kernel signature.
using F = half;

// Tile edge: every kernel below is launched with N threads per block and
// works on an N x N fp16 state tile per (batch, head).
constexpr int N = 64;
// Number of half2 pairs needed to hold N halves.
constexpr int HALF2_N = N / 2;
// Halves moved by one 16-byte (int4) vector load/store.
constexpr int LDG_ELEMS = sizeof(int4) / sizeof(half);

// 2^-41: scales the 32-bit hashed phase in rotator1() down to a tiny float.
constexpr float TWO_NEG_41 = 4.547473508864641e-13f;
// -exp(-0.5) * log2(e), so exp2f(NEXP_HALF_LOG2_E * x) == exp(-exp(-0.5) * x).
constexpr float NEXP_HALF_LOG2_E = -0.8750387749145276f;
// -log2(e), so exp2f(NLOG2_E * x) == exp(-x).
constexpr float NLOG2_E = -1.4426950408889634f;
// 2654435769 (0x9E3779B9) reinterpreted as a signed 32-bit multiplier.
constexpr int ROT1 = static_cast<int>(0x9E3779B9u);
// Hashes an integer phase into a tiny float offset: multiplies `x` by ROT1
// modulo 2^32, reinterprets the product as a signed 32-bit value and scales it
// by 2^-41 (TWO_NEG_41).
//
// The multiply is done in unsigned arithmetic on purpose: the product is meant
// to wrap around, and a wrapping *signed* multiply is undefined behaviour in
// C++. The unsigned product has the same bit pattern as the wrapped signed one.
__device__ __forceinline__ float rotator1(int x) {
  const unsigned int wrapped = static_cast<unsigned int>(ROT1) * static_cast<unsigned int>(x);
  return TWO_NEG_41 * static_cast<float>(static_cast<int>(wrapped));
}
// Converts a raw decay logit `w` into the fp16 value stored in the shared `w`
// vector: exp(-exp(-0.5) * sigmoid(w)) - 1, plus the small phase-dependent
// offset from rotator1(phase). The "- 1" pairs with the state update in the
// kernels, which adds the old state back in separately.
__device__ __forceinline__ half w_delta(float w, int phase) {
  // 1 + exp(-w): the sigmoid denominator.
  const float sigmoid_denominator = 1.0f + exp2f(NLOG2_E * w);
  // exp(-exp(-0.5) / (1 + exp(-w)))
  const float decay = exp2f(NEXP_HALF_LOG2_E / sigmoid_denominator);
  const float delta = decay - 1.0f + rotator1(phase);
  return __float2half_rn(delta);
}
// Wrapper around w_delta() that optionally adds a per-channel bias first.
//   AddW0 == true : uses w_raw + w0_ptr[c] (summed in fp32).
//   AddW0 == false: uses w_raw alone; w0_ptr and c are not touched, so w0_ptr
//                   may be null.
template <bool AddW0>
__device__ __forceinline__ half w_delta_maybe_w0(half w_raw, const half* __restrict__ w0_ptr, int c, int phase) {
  if constexpr (AddW0) {
    return w_delta(__half2float(w_raw) + __half2float(w0_ptr[c]), phase);
  } else {
    return w_delta(__half2float(w_raw), phase);
  }
}
// Issues one asynchronous global -> shared copy of `Bytes` bytes (4, 8 or 16).
// 16-byte copies use the `.cg` form, smaller ones the `.ca` form.
//
// `cond` selects the source size operand: Bytes when true, 0 when false.
// NOTE(review): per the PTX cp.async description a source size smaller than
// the copy size zero-fills the remainder, so cond == false is expected to
// write zeros to the destination rather than leave it untouched - confirm
// before relying on it as a "skip".
template <int Bytes>
__device__ __forceinline__ void clone_cp_async(void const* smem_addr, void const* global_ptr, bool cond) {
  static_assert(Bytes == 16 || Bytes == 8 || Bytes == 4);
  // cp.async takes the destination as a 32-bit shared-space address.
  const unsigned int dst = __cvta_generic_to_shared(smem_addr);
  const int src_bytes = cond ? Bytes : 0;
  if constexpr (Bytes != 16) {
    asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;" ::"r"(dst), "l"(global_ptr), "n"(Bytes), "r"(src_bytes));
  } else {
    asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;" ::"r"(dst), "l"(global_ptr), "n"(Bytes), "r"(src_bytes));
  }
}
// Waits on this thread's outstanding cp.async work.
//   NWait == 0: emits cp.async.wait_all.
//   NWait  > 0: emits cp.async.wait_group NWait (the operand is the number of
//               most recent committed groups allowed to remain pending).
template <int NWait>
__device__ __forceinline__ void clone_cp_wait() {
  if constexpr (NWait != 0) {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(NWait));
  } else {
    asm volatile("cp.async.wait_all;\n" ::);
  }
}
// Closes the current batch of cp.async copies issued by this thread into one
// commit group, so that clone_cp_wait<N>() can wait on it as a unit.
__device__ __forceinline__ void clone_cp_commit()
{
  asm volatile("cp.async.commit_group;\n" ::);
}
// fp16 recurrent step over T tokens for one (batch, head) pair per block.
//
// Launch layout: blockIdx.x = b * H + h, blockDim.x = CLONE_N. The per-token
// copies move exactly 32 half2 values per vector using two groups of 32
// threads, so the kernel only works for CLONE_N == 64.
//
// Template parameters:
//   Tis1  - promises the compiler that T == 1.
//   AddW0 - add w0_ptr[h * CLONE_N + i] to w before the decay transform.
//
// Memory layout as indexed below:
//   state_ptr        : tile of CLONE_N x CLONE_N halves at
//                      b * C * CLONE_N + h * CLONE_N * CLONE_N (updated in place).
//   r/w/k/v/a/b/y    : element (b, tt, h, i) at (b * T + tt) * C + h * CLONE_N + i.
//   elapsed_t[b]     : integer folded into the per-element phase for w_delta().
// `B` is accepted but not read.
//
// Per token, thread i (which owns state row i in registers) computes
//   sa          = sum_j a[j] * state[i][j]
//   w'[j]       = w_delta(w[j] (+ w0))
//   state[i][j] = state[i][j] * w'[j] + k[j] * v[i] + sa * b[j] + state[i][j]
//   y[i]        = sum_j state[i][j] * r[j]
// with all products and sums done in half2 arithmetic.
template <bool Tis1 = false, bool AddW0 = false>
__global__ void __launch_bounds__(CLONE_N, 2) wkv_fp16_v1_clone_kernel(
    const int B,
    const int T,
    const int C,
    const int H,
    F* __restrict__ state_ptr,
    const F* __restrict__ r_ptr,
    const F* __restrict__ w_ptr,
    const F* __restrict__ w0_ptr,
    const F* __restrict__ k_ptr,
    const F* __restrict__ v_ptr,
    const F* __restrict__ a_ptr,
    const F* __restrict__ b_ptr,
    F* __restrict__ y_ptr,
    const int* __restrict__ elapsed_t) {
  if constexpr (Tis1) {
    __builtin_assume(T == 1);
  }
  const int b = blockIdx.x / H;
  const int h = blockIdx.x % H;
  const int i = threadIdx.x;
  const int lane = i % 32;
  // Threads 0..31 handle one set of vectors, threads 32..63 the other.
  const bool lower_half = i < 32;

  __shared__ __align__(256) half2 state_smem[CLONE_N][CLONE_N / 2];

  // 64-bit offset: b * C * CLONE_N can exceed the int range for large batches.
  state_ptr += static_cast<int64_t>(b) * C * CLONE_N + h * CLONE_N * CLONE_N;

  // Stage the state tile into shared memory with 16-byte loads. The column is
  // XOR-ed with (row % 32) so the per-row register load below is swizzled.
  constexpr int ldg_size = sizeof(int4) / sizeof(F);
  for (int j0 = 0; j0 < CLONE_N / ldg_size; j0++) {
    int4 state_vec = ((int4*)state_ptr)[j0 * CLONE_N + i];
    for (int j1 = 0; j1 < ldg_size / 2; j1++) {
      int row = j0 * ldg_size + i * ldg_size / CLONE_N;
      int col = i * ldg_size % CLONE_N / 2 + j1;
      state_smem[row][(row % 32) ^ col] = ((half2*)&state_vec)[j1];
    }
  }
  __syncthreads();

  // Thread i keeps row i of the state in registers for the whole token loop.
  half2 state[CLONE_N / 2];
  for (int j = 0; j < CLONE_N / 2; j++) {
    state[j] = state_smem[i][lane ^ j];
  }

  __shared__ __align__(128) half2 r[CLONE_N / 2], k[CLONE_N / 2], w[CLONE_N / 2], a[CLONE_N / 2], bvec[CLONE_N / 2];

  for (int tt = 0; tt < T; tt++) {
    // 64-bit token offset: (b * T + tt) * C overflows int for long sequences.
    const int64_t t = (static_cast<int64_t>(b) * T + tt) * C + h * CLONE_N;

    // Everyone must be done reading the previous token's vectors.
    __syncthreads();

    // Group 1: w (lower threads) and a (upper threads).
    clone_cp_async<4>((lower_half ? w : a) + lane,
                      reinterpret_cast<const half2*>((lower_half ? w_ptr : a_ptr) + t) + lane, true);
    clone_cp_commit();

    // Group 2: r (lower threads), k (upper threads) and b.
    clone_cp_async<4>((lower_half ? r : k) + lane,
                      reinterpret_cast<const half2*>((lower_half ? r_ptr : k_ptr) + t) + lane, true);
    // Only the lower threads copy b. A cp.async with source size 0 zero-fills
    // its destination, so a "disabled" copy from the upper threads to the
    // same bvec slots would race with the real load.
    if (lower_half) {
      clone_cp_async<4>(bvec + lane, reinterpret_cast<const half2*>(b_ptr + t) + lane, true);
    }
    clone_cp_commit();

    half vv = v_ptr[t + i];
    half2 vv2 = {vv, vv};
    half2 y2 = {0.0f, 0.0f};
    half2 sa2 = {0.0f, 0.0f};

    // Group 1 (w, a) has landed once at most one group is still pending.
    clone_cp_wait<1>();
    __syncthreads();

    for (int j = 0; j < CLONE_N / 2; j++) {
      sa2 = __hfma2(a[j], state[j], sa2);
    }
    half sa = sa2.x + sa2.y;
    sa2 = {sa, sa};

    // Each thread transforms its own element of w in place.
    ((F*)w)[i] = w_delta_maybe_w0<AddW0>(((F*)w)[i], w0_ptr, h * CLONE_N + i, elapsed_t[b] + h * CLONE_N + i + tt);

    clone_cp_wait<0>();
    __syncthreads();

    for (int j = 0; j < CLONE_N / 2; j++) {
      half2& s = state[j];
      s = __hfma2(s, w[j], __hfma2(k[j], vv2, __hfma2(sa2, bvec[j], s)));
      y2 = __hfma2(s, r[j], y2);
    }
    y_ptr[t + i] = y2.x + y2.y;
  }

  // Write the updated rows back through the same swizzled shared tile.
  for (int j = 0; j < CLONE_N / 2; j++) {
    state_smem[i][lane ^ j] = state[j];
  }
  __syncthreads();
  for (int j0 = 0; j0 < CLONE_N / ldg_size; j0++) {
    int4 state_vec;
    for (int j1 = 0; j1 < ldg_size / 2; j1++) {
      int row = j0 * ldg_size + i * ldg_size / CLONE_N;
      int col = i * ldg_size % CLONE_N / 2 + j1;
      ((half2*)&state_vec)[j1] = state_smem[row][(row % 32) ^ col];
    }
    ((int4*)state_ptr)[j0 * CLONE_N + i] = state_vec;
  }
}
// Asynchronous global -> shared copy of `Bytes` bytes (4, 8 or 16) issued by
// the calling thread. 16-byte copies use `.cg`, smaller ones `.ca`.
//
// `pred` only selects the source-size operand (Bytes or 0).
// NOTE(review): per the PTX cp.async description a source size of 0 is
// expected to zero-fill the destination, not skip the write - confirm before
// using pred == false on a slot another thread is loading.
template <int Bytes>
__device__ __forceinline__ void cp_async(void* smem, const void* global, bool pred) {
  static_assert(Bytes == 16 || Bytes == 8 || Bytes == 4);
  // Destination must be a 32-bit shared-space address.
  const unsigned dst = __cvta_generic_to_shared(smem);
  const int src_bytes = pred ? Bytes : 0;
  if constexpr (Bytes != 16) {
    asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;" ::"r"(dst), "l"(global), "n"(Bytes), "r"(src_bytes));
  } else {
    asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;" ::"r"(dst), "l"(global), "n"(Bytes), "r"(src_bytes));
  }
}
// Bundles the cp.async copies this thread has issued since the previous commit
// into one group that cp_wait<N>() can wait on.
__device__ __forceinline__ void cp_commit()
{
  asm volatile("cp.async.commit_group;\n" ::);
}
// Blocks the calling thread on its own cp.async work.
//   NWait == 0: cp.async.wait_all.
//   NWait  > 0: cp.async.wait_group NWait (that many of the most recent
//               groups may still be in flight on return).
template <int NWait>
__device__ __forceinline__ void cp_wait() {
  if constexpr (NWait != 0) {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(NWait));
  } else {
    asm volatile("cp.async.wait_all;\n" ::);
  }
}
// Starts the asynchronous copies of one token's r/w/k/a/b vectors into the
// given shared-memory buffers, as two commit groups:
//   group 1: w (threads tid < 32) and a (threads tid >= 32)
//   group 2: r (tid < 32), k (tid >= 32) and b (tid < 32)
// Each participating thread copies the half2 at index `lane`, so 32 threads
// cover 64 halves per vector.
//
// Parameters:
//   tid, lane   - thread index in the block and its index within the 32-wide group.
//   token       - element offset of the token's first half in every *_ptr array.
//   r,w,k,a,b   - destination buffers (at least 32 half2 each).
//   *_ptr       - global source arrays.
//
// The caller is responsible for cp_wait<>() + __syncthreads() before reading.
__device__ __forceinline__ void prefetch_token(
    int tid,
    int lane,
    int token,
    half2* r,
    half2* w,
    half2* k,
    half2* a,
    half2* b,
    const half* r_ptr,
    const half* w_ptr,
    const half* k_ptr,
    const half* a_ptr,
    const half* b_ptr) {
  const bool lower_half = tid < 32;

  cp_async<4>((lower_half ? w : a) + lane,
              reinterpret_cast<const half2*>((lower_half ? w_ptr : a_ptr) + token) + lane, true);
  cp_commit();

  cp_async<4>((lower_half ? r : k) + lane,
              reinterpret_cast<const half2*>((lower_half ? r_ptr : k_ptr) + token) + lane, true);
  // Only the lower threads copy b. Issuing a cp.async with source size 0 from
  // the upper threads would zero-fill the same b[lane] slots and race with
  // the real load, so those threads skip the instruction entirely.
  if (lower_half) {
    cp_async<4>(b + lane, reinterpret_cast<const half2*>(b_ptr + token) + lane, true);
  }
  cp_commit();
}
// fp16 recurrent step over T tokens for one (batch, head) pair per block,
// loading each token's vectors at the top of its own iteration.
//
// Launch layout: blockIdx.x = b * H + h, blockDim.x = N (= 64). The per-token
// copies rely on exactly two groups of 32 threads.
//
// Template parameters:
//   Tis1  - promises the compiler that T == 1.
//   AddW0 - add w0_ptr[h * N + i] to w before the decay transform.
//
// Memory layout as indexed below:
//   state_ptr      : N x N halves per (b, h) at b * C * N + h * N * N, updated in place.
//   r/w/k/v/a/b/y  : element (b, tt, h, i) at (b * T + tt) * C + h * N + i.
//   elapsed_t[b]   : integer folded into the per-element phase for w_delta().
// `B` is accepted but not read.
//
// Per token, thread i (owner of state row i) computes
//   sa          = sum_j a[j] * state[i][j]
//   w'[j]       = w_delta(w[j] (+ w0))
//   state[i][j] = state[i][j] * w'[j] + k[j] * v[i] + sa * b[j] + state[i][j]
//   y[i]        = sum_j state[i][j] * r[j]
template <bool Tis1 = false, bool AddW0 = false>
__global__ __launch_bounds__(N, 2) void wkv_fp16_v1_exact_kernel(
    const int B,
    const int T,
    const int C,
    const int H,
    half* __restrict__ state_ptr,
    const half* __restrict__ r_ptr,
    const half* __restrict__ w_ptr,
    const half* __restrict__ w0_ptr,
    const half* __restrict__ k_ptr,
    const half* __restrict__ v_ptr,
    const half* __restrict__ a_ptr,
    const half* __restrict__ b_ptr,
    half* __restrict__ y_ptr,
    const int* __restrict__ elapsed_t) {
  if constexpr (Tis1) {
    __builtin_assume(T == 1);
  }
  const int b_id = blockIdx.x / H;
  const int h = blockIdx.x % H;
  const int i = threadIdx.x;
  const int lane = i % 32;
  // Threads 0..31 handle one set of vectors, threads 32..63 the other.
  const bool lower_half = i < 32;

  __shared__ __align__(256) half2 state_smem[N][HALF2_N];

  // 64-bit offset, matching the v2 kernels: b_id * C * N can exceed int.
  state_ptr += static_cast<int64_t>(b_id) * C * N + h * N * N;

  // Stage the state tile with 16-byte loads; columns are XOR-swizzled by row.
  for (int j0 = 0; j0 < N / LDG_ELEMS; j0++) {
    int4 state_vec = ((int4*)state_ptr)[j0 * N + i];
    for (int j1 = 0; j1 < LDG_ELEMS / 2; j1++) {
      int row = j0 * LDG_ELEMS + i * LDG_ELEMS / N;
      int col = i * LDG_ELEMS % N / 2 + j1;
      state_smem[row][(row % 32) ^ col] = ((half2*)&state_vec)[j1];
    }
  }
  __syncthreads();

  // Thread i keeps row i of the state in registers.
  half2 state[HALF2_N];
  for (int j = 0; j < HALF2_N; j++) {
    state[j] = state_smem[i][lane ^ j];
  }

  __shared__ __align__(128) half2 r[HALF2_N], k[HALF2_N], w[HALF2_N], a[HALF2_N], bvec[HALF2_N];

  for (int tt = 0; tt < T; tt++) {
    // 64-bit token offset: (b_id * T + tt) * C overflows int for long inputs.
    const int64_t t = (static_cast<int64_t>(b_id) * T + tt) * C + h * N;

    // Previous token's vectors must no longer be in use.
    __syncthreads();

    // Group 1: w (lower threads) and a (upper threads).
    cp_async<4>((lower_half ? w : a) + lane,
                reinterpret_cast<const half2*>((lower_half ? w_ptr : a_ptr) + t) + lane, true);
    cp_commit();

    // Group 2: r (lower threads), k (upper threads) and b.
    cp_async<4>((lower_half ? r : k) + lane,
                reinterpret_cast<const half2*>((lower_half ? r_ptr : k_ptr) + t) + lane, true);
    // Only the lower threads copy b: a cp.async with source size 0 zero-fills
    // its destination, so a disabled copy from the upper threads would race
    // with the real load into the same bvec slots.
    if (lower_half) {
      cp_async<4>(bvec + lane, reinterpret_cast<const half2*>(b_ptr + t) + lane, true);
    }
    cp_commit();

    half vv = v_ptr[t + i];
    half2 vv2 = {vv, vv};
    half2 y2 = {0.0f, 0.0f};
    half2 sa2 = {0.0f, 0.0f};

    // Group 1 (w, a) is complete once at most one group is pending.
    cp_wait<1>();
    __syncthreads();

    for (int j = 0; j < HALF2_N; j++) {
      sa2 = __hfma2(a[j], state[j], sa2);
    }
    half sa = sa2.x + sa2.y;
    sa2 = {sa, sa};

    // Each thread transforms its own element of w in place.
    ((half*)w)[i] = w_delta_maybe_w0<AddW0>(((half*)w)[i], w0_ptr, h * N + i, elapsed_t[b_id] + h * N + i + tt);

    cp_wait<0>();
    __syncthreads();

    for (int j = 0; j < HALF2_N; j++) {
      half2& s = state[j];
      s = __hfma2(s, w[j], __hfma2(k[j], vv2, __hfma2(sa2, bvec[j], s)));
      y2 = __hfma2(s, r[j], y2);
    }
    y_ptr[t + i] = y2.x + y2.y;
  }

  // Write the updated rows back through the swizzled shared tile.
  for (int j = 0; j < HALF2_N; j++) {
    state_smem[i][lane ^ j] = state[j];
  }
  __syncthreads();
  for (int j0 = 0; j0 < N / LDG_ELEMS; j0++) {
    int4 state_vec;
    for (int j1 = 0; j1 < LDG_ELEMS / 2; j1++) {
      int row = j0 * LDG_ELEMS + i * LDG_ELEMS / N;
      int col = i * LDG_ELEMS % N / 2 + j1;
      ((half2*)&state_vec)[j1] = state_smem[row][(row % 32) ^ col];
    }
    ((int4*)state_ptr)[j0 * N + i] = state_vec;
  }
}
// Multi-token fp16 recurrent kernel with double-buffered token vectors: while
// token tt is being applied to the state, token tt + 1 is already being copied
// into the other shared-memory buffer.
//
// Launch layout: blockIdx.x = b * H + h, blockDim.x = N (= 64).
//
// Memory layout as indexed below:
//   state_ptr      : N x N halves per (b, h) at b * C * N + h * N * N, updated in place.
//   r/w/k/v/a/b/y  : element (b, tt, h, i) at (b * T + tt) * C + h * N + i.
//   elapsed_t[b]   : integer folded into the per-element phase for w_delta().
//   AddW0          : add w0_ptr[h * N + i] to w before the decay transform.
//
// Per token, thread i (owner of state row i) computes
//   sa          = sum_j a[j] * state[i][j]
//   w'[j]       = w_delta(w[j] (+ w0))
//   state[i][j] = state[i][j] * w'[j] + k[j] * v[i] + sa * b[j] + state[i][j]
//   y[i]        = sum_j state[i][j] * r[j]
//
// Token offsets are tracked in 64 bits; the source pointers are pre-offset
// before calling prefetch_token() so its 32-bit `token` argument stays 0.
template <bool AddW0 = false>
__global__ __launch_bounds__(N, 2) void wkv_fp16_seq_v2_kernel(
    int T,
    int C,
    int H,
    half* __restrict__ state_ptr,
    const half* __restrict__ r_ptr,
    const half* __restrict__ w_ptr,
    const half* __restrict__ w0_ptr,
    const half* __restrict__ k_ptr,
    const half* __restrict__ v_ptr,
    const half* __restrict__ a_ptr,
    const half* __restrict__ b_ptr,
    half* __restrict__ y_ptr,
    const int* __restrict__ elapsed_t) {
  const int bh = blockIdx.x;
  const int b_id = bh / H;
  const int h = bh - b_id * H;
  const int i = threadIdx.x;
  const int lane = i & 31;

  __shared__ __align__(256) half2 state_smem[N][HALF2_N];
  state_ptr += static_cast<int64_t>(b_id) * C * N + h * N * N;

  // Stage the state tile with 16-byte loads; columns are XOR-swizzled by row.
  for (int j0 = 0; j0 < N / LDG_ELEMS; ++j0) {
    int4 state_vec = ((int4*)state_ptr)[j0 * N + i];
    for (int j1 = 0; j1 < LDG_ELEMS / 2; ++j1) {
      int row = j0 * LDG_ELEMS + i * LDG_ELEMS / N;
      int col = i * LDG_ELEMS % N / 2 + j1;
      state_smem[row][(row & 31) ^ col] = ((half2*)&state_vec)[j1];
    }
  }
  __syncthreads();

  // Thread i keeps row i of the state in registers.
  half2 state[HALF2_N];
  for (int j = 0; j < HALF2_N; ++j) {
    state[j] = state_smem[i][lane ^ j];
  }

  // Two copies of every token vector: index `cur` is consumed, `cur ^ 1` filled.
  __shared__ __align__(128) half2 r[2][HALF2_N], w[2][HALF2_N], k[2][HALF2_N], a[2][HALF2_N], bvec[2][HALF2_N];

  // 64-bit element offset of the current token; b_id * T * C and the running
  // sum can exceed the int range for long sequences.
  int64_t token = static_cast<int64_t>(b_id) * T * C + h * N;

  // Kick off the first token (skipped for T <= 0 so nothing is read then).
  if (T > 0) {
    prefetch_token(i, lane, 0, r[0], w[0], k[0], a[0], bvec[0],
                   r_ptr + token, w_ptr + token, k_ptr + token, a_ptr + token, b_ptr + token);
  }

  for (int tt = 0; tt < T; ++tt) {
    const int cur = tt & 1;

    // Current token's vectors have landed and are visible block-wide.
    cp_wait<0>();
    __syncthreads();

    half2 sa2 = {0.0f, 0.0f};
    for (int j = 0; j < HALF2_N; ++j) {
      sa2 = __hfma2(a[cur][j], state[j], sa2);
    }
    half sa = sa2.x + sa2.y;
    sa2 = {sa, sa};

    // Each thread transforms its own element of w in place.
    ((half*)w[cur])[i] = w_delta_maybe_w0<AddW0>(((half*)w[cur])[i], w0_ptr, h * N + i, elapsed_t[b_id] + h * N + i + tt);
    __syncthreads();

    // Start fetching the next token into the buffer nobody reads any more.
    if (tt + 1 < T) {
      const int64_t next_token = token + C;
      const int nxt = cur ^ 1;
      prefetch_token(i, lane, 0, r[nxt], w[nxt], k[nxt], a[nxt], bvec[nxt],
                     r_ptr + next_token, w_ptr + next_token, k_ptr + next_token, a_ptr + next_token,
                     b_ptr + next_token);
    }

    half vv = v_ptr[token + i];
    half2 vv2 = {vv, vv};
    half2 y2 = {0.0f, 0.0f};
    for (int j = 0; j < HALF2_N; ++j) {
      half2 s = state[j];
      s = __hfma2(s, w[cur][j], __hfma2(k[cur][j], vv2, __hfma2(sa2, bvec[cur][j], s)));
      state[j] = s;
      y2 = __hfma2(s, r[cur][j], y2);
    }
    y_ptr[token + i] = y2.x + y2.y;
    token += C;
  }

  // Write the updated rows back through the swizzled shared tile.
  for (int j = 0; j < HALF2_N; ++j) {
    state_smem[i][lane ^ j] = state[j];
  }
  __syncthreads();
  for (int j0 = 0; j0 < N / LDG_ELEMS; ++j0) {
    int4 state_vec;
    for (int j1 = 0; j1 < LDG_ELEMS / 2; ++j1) {
      int row = j0 * LDG_ELEMS + i * LDG_ELEMS / N;
      int col = i * LDG_ELEMS % N / 2 + j1;
      ((half2*)&state_vec)[j1] = state_smem[row][(row & 31) ^ col];
    }
    ((int4*)state_ptr)[j0 * N + i] = state_vec;
  }
}
// Single-token fp16 recurrent step that fetches the token vectors with plain
// read-only loads (__ldg) instead of cp.async.
//
// Launch layout: blockIdx.x = b * H + h, blockDim.x = N.
//
// Memory layout as indexed below:
//   state_ptr      : N x N halves per (b, h) at b * C * N + h * N * N, updated in place.
//   r/w/k/v/a/b/y  : element (b, h, i) at b * C + h * N + i.
//   elapsed_t[b]   : integer folded into the per-element phase for w_delta().
//   AddW0          : add w0_ptr[h * N + i] to w before the decay transform.
//
// Thread `tid` owns state row `tid` and computes
//   sa            = sum_j a[j] * state[tid][j]
//   state[tid][j] = state[tid][j] * w'[j] + k[j] * v[tid] + sa * b[j] + state[tid][j]
//   y[tid]        = sum_j state[tid][j] * r[j]
template <bool AddW0 = false>
__global__ __launch_bounds__(N, 1) void wkv_fp16_one_direct_kernel(
    int C,
    int H,
    half* __restrict__ state_ptr,
    const half* __restrict__ r_ptr,
    const half* __restrict__ w_ptr,
    const half* __restrict__ w0_ptr,
    const half* __restrict__ k_ptr,
    const half* __restrict__ v_ptr,
    const half* __restrict__ a_ptr,
    const half* __restrict__ b_ptr,
    half* __restrict__ y_ptr,
    const int* __restrict__ elapsed_t) {
  const int block_id = blockIdx.x;
  const int b_id = block_id / H;
  const int h = block_id - b_id * H;
  const int tid = threadIdx.x;
  const int lane = tid & 31;

  __shared__ __align__(256) half2 state_smem[N][HALF2_N];
  half* state_base = state_ptr + static_cast<int64_t>(b_id) * C * N + h * N * N;
  int4* state_base_vec = (int4*)state_base;

  // Position of this thread's 16-byte chunk inside each group of rows.
  const int row_in_group = tid * LDG_ELEMS / N;
  const int first_col = tid * LDG_ELEMS % N / 2;

  // Stage the state tile; columns are XOR-swizzled by (row & 31).
  for (int j0 = 0; j0 < N / LDG_ELEMS; ++j0) {
    int4 chunk = state_base_vec[j0 * N + tid];
    const half2* pairs = (half2*)&chunk;
    const int row = j0 * LDG_ELEMS + row_in_group;
    for (int j1 = 0; j1 < LDG_ELEMS / 2; ++j1) {
      state_smem[row][(row & 31) ^ (first_col + j1)] = pairs[j1];
    }
  }
  __syncthreads();

  // Row `tid` of the state lives in registers from here on.
  half2 state[HALF2_N];
  for (int j = 0; j < HALF2_N; ++j) {
    state[j] = state_smem[tid][lane ^ j];
  }

  __shared__ __align__(128) half2 r[HALF2_N], w[HALF2_N], k[HALF2_N], a[HALF2_N], bvec[HALF2_N];
  const int token = b_id * C + h * N;

  // The first HALF2_N threads each fetch one half2 of all five vectors.
  if (tid < HALF2_N) {
    const int pair_index = (token >> 1) + tid;
    r[tid] = __ldg(reinterpret_cast<const half2*>(r_ptr) + pair_index);
    w[tid] = __ldg(reinterpret_cast<const half2*>(w_ptr) + pair_index);
    k[tid] = __ldg(reinterpret_cast<const half2*>(k_ptr) + pair_index);
    a[tid] = __ldg(reinterpret_cast<const half2*>(a_ptr) + pair_index);
    bvec[tid] = __ldg(reinterpret_cast<const half2*>(b_ptr) + pair_index);
  }
  __syncthreads();

  // sa = dot(a, state row), accumulated pairwise and then folded.
  half2 sa2 = {0.0f, 0.0f};
  for (int j = 0; j < HALF2_N; ++j) {
    sa2 = __hfma2(a[j], state[j], sa2);
  }
  half sa = sa2.x + sa2.y;
  sa2 = {sa, sa};

  // Each thread rewrites its own element of w in place.
  half* w_elems = (half*)w;
  w_elems[tid] = w_delta_maybe_w0<AddW0>(w_elems[tid], w0_ptr, h * N + tid, elapsed_t[b_id] + h * N + tid);
  __syncthreads();

  half vv = __ldg(v_ptr + token + tid);
  half2 vv2 = {vv, vv};
  half2 y2 = {0.0f, 0.0f};
  for (int j = 0; j < HALF2_N; ++j) {
    const half2 prev = state[j];
    const half2 injected = __hfma2(k[j], vv2, __hfma2(sa2, bvec[j], prev));
    const half2 next = __hfma2(prev, w[j], injected);
    state[j] = next;
    y2 = __hfma2(next, r[j], y2);
  }
  y_ptr[token + tid] = y2.x + y2.y;

  // Write the updated rows back through the swizzled shared tile.
  for (int j = 0; j < HALF2_N; ++j) {
    state_smem[tid][lane ^ j] = state[j];
  }
  __syncthreads();
  for (int j0 = 0; j0 < N / LDG_ELEMS; ++j0) {
    int4 chunk;
    half2* pairs = (half2*)&chunk;
    const int row = j0 * LDG_ELEMS + row_in_group;
    for (int j1 = 0; j1 < LDG_ELEMS / 2; ++j1) {
      pairs[j1] = state_smem[row][(row & 31) ^ (first_col + j1)];
    }
    state_base_vec[j0 * N + tid] = chunk;
  }
}
// Single-token fp16 recurrent step that fetches the token vectors with
// cp.async in two commit groups, so the `a` dot product can start while r, k
// and b are still in flight.
//
// Launch layout: blockIdx.x = b * H + h, blockDim.x = N (= 64); the copies
// rely on exactly two groups of 32 threads.
//
// Memory layout as indexed below:
//   state_ptr      : N x N halves per (b, h) at b * C * N + h * N * N, updated in place.
//   r/w/k/v/a/b/y  : element (b, h, i) at b * C + h * N + i.
//   elapsed_t[b]   : integer folded into the per-element phase for w_delta().
//   AddW0          : add w0_ptr[h * N + i] to w before the decay transform.
//
// Thread i owns state row i and computes
//   sa          = sum_j a[j] * state[i][j]
//   state[i][j] = state[i][j] * w'[j] + k[j] * v[i] + sa * b[j] + state[i][j]
//   y[i]        = sum_j state[i][j] * r[j]
template <bool AddW0 = false>
__global__ __launch_bounds__(N, 1) void wkv_fp16_one_cp_kernel(
    int C,
    int H,
    half* __restrict__ state_ptr,
    const half* __restrict__ r_ptr,
    const half* __restrict__ w_ptr,
    const half* __restrict__ w0_ptr,
    const half* __restrict__ k_ptr,
    const half* __restrict__ v_ptr,
    const half* __restrict__ a_ptr,
    const half* __restrict__ b_ptr,
    half* __restrict__ y_ptr,
    const int* __restrict__ elapsed_t) {
  const int bh = blockIdx.x;
  const int b_id = bh / H;
  const int h = bh - b_id * H;
  const int i = threadIdx.x;
  const int lane = i & 31;
  // Threads 0..31 handle one set of vectors, threads 32..63 the other.
  const bool lower_half = i < 32;

  __shared__ __align__(256) half2 state_smem[N][HALF2_N];
  half* state_base = state_ptr + static_cast<int64_t>(b_id) * C * N + h * N * N;

  // Stage the state tile with 16-byte loads; columns are XOR-swizzled by row.
  for (int j0 = 0; j0 < N / LDG_ELEMS; ++j0) {
    int4 state_vec = ((int4*)state_base)[j0 * N + i];
    for (int j1 = 0; j1 < LDG_ELEMS / 2; ++j1) {
      int row = j0 * LDG_ELEMS + i * LDG_ELEMS / N;
      int col = i * LDG_ELEMS % N / 2 + j1;
      state_smem[row][(row & 31) ^ col] = ((half2*)&state_vec)[j1];
    }
  }
  __syncthreads();

  // Thread i keeps row i of the state in registers.
  half2 state[HALF2_N];
  for (int j = 0; j < HALF2_N; ++j) {
    state[j] = state_smem[i][lane ^ j];
  }

  __shared__ __align__(128) half2 r[HALF2_N], w[HALF2_N], k[HALF2_N], a[HALF2_N], bvec[HALF2_N];
  const int token = b_id * C + h * N;

  // Group 1: w (lower threads) and a (upper threads).
  cp_async<4>((lower_half ? w : a) + lane,
              reinterpret_cast<const half2*>((lower_half ? w_ptr : a_ptr) + token) + lane, true);
  cp_commit();

  // Group 2: r (lower threads), k (upper threads) and b.
  cp_async<4>((lower_half ? r : k) + lane,
              reinterpret_cast<const half2*>((lower_half ? r_ptr : k_ptr) + token) + lane, true);
  // Only the lower threads copy b: a cp.async with source size 0 zero-fills
  // its destination, so a disabled copy from the upper threads would race
  // with the real load into the same bvec slots.
  if (lower_half) {
    cp_async<4>(bvec + lane, reinterpret_cast<const half2*>(b_ptr + token) + lane, true);
  }
  cp_commit();

  half vv = __ldg(v_ptr + token + i);
  half2 vv2 = {vv, vv};
  half2 sa2 = {0.0f, 0.0f};

  // Group 1 (w, a) is complete once at most one group is pending.
  cp_wait<1>();
  __syncthreads();

  for (int j = 0; j < HALF2_N; ++j) {
    sa2 = __hfma2(a[j], state[j], sa2);
  }
  half sa = sa2.x + sa2.y;
  sa2 = {sa, sa};

  // Each thread transforms its own element of w in place.
  ((half*)w)[i] = w_delta_maybe_w0<AddW0>(((half*)w)[i], w0_ptr, h * N + i, elapsed_t[b_id] + h * N + i);

  cp_wait<0>();
  __syncthreads();

  half2 y2 = {0.0f, 0.0f};
  for (int j = 0; j < HALF2_N; ++j) {
    half2 s = state[j];
    s = __hfma2(s, w[j], __hfma2(k[j], vv2, __hfma2(sa2, bvec[j], s)));
    state[j] = s;
    y2 = __hfma2(s, r[j], y2);
  }
  y_ptr[token + i] = y2.x + y2.y;

  // Write the updated rows back through the swizzled shared tile.
  for (int j = 0; j < HALF2_N; ++j) {
    state_smem[i][lane ^ j] = state[j];
  }
  __syncthreads();
  for (int j0 = 0; j0 < N / LDG_ELEMS; ++j0) {
    int4 state_vec;
    for (int j1 = 0; j1 < LDG_ELEMS / 2; ++j1) {
      int row = j0 * LDG_ELEMS + i * LDG_ELEMS / N;
      int col = i * LDG_ELEMS % N / 2 + j1;
      ((half2*)&state_vec)[j1] = state_smem[row][(row & 31) ^ col];
    }
    ((int4*)state_base)[j0 * N + i] = state_vec;
  }
}
// Dispatch table: returns true for the (batch size, sequence length)
// combinations that should run the double-buffered v2 sequence kernel.
// Any batch size not listed here returns false.
bool use_v2_seq(int B, int T) {
  switch (B) {
    case 1:
    case 8:
      return T >= 8;
    case 4:
      return T >= 4;
    case 64:
    case 128:
      return T == 1;
    default:
      return false;
  }
}
}  // namespace
// Forward declaration: single-token launcher, defined later in this file.
// Declared here because wkv_seq_v2_cuda_impl() forwards to it when T == 1.
void wkv_one_v2_cuda_impl(
    int B, int C, int H,
    at::Tensor state,
    at::Tensor r, at::Tensor w,
    const half* w0_ptr, bool add_w0,
    at::Tensor k, at::Tensor v, at::Tensor a, at::Tensor b,
    at::Tensor y,
    at::Tensor elapsed_t);
// Launches the multi-token fp16 WKV kernels on the current CUDA stream.
//
// Parameters:
//   B, T, C, H  - batch size, tokens per sequence, channels, heads; C must equal H * N.
//   state       - recurrent state, updated in place by the kernel.
//   r,w,k,v,a,b - per-token inputs; y receives the per-token output.
//   w0_ptr      - device pointer to the per-channel bias added to w; only
//                 read when add_w0 is true.
//   add_w0      - selects the kernel instantiation that adds w0.
//   elapsed_t   - int tensor indexed by batch element inside the kernels.
// NOTE(review): the tensors are reinterpreted as fp16 without a dtype,
// device or contiguity check - callers are presumably responsible; confirm.
//
// Dispatch: T == 1 forwards to wkv_one_v2_cuda_impl(); otherwise
// use_v2_seq(B, T) picks between the double-buffered v2 kernel and the v1
// "exact" kernel. One block of N threads is launched per (batch, head).
//
// Raises (via TORCH_CHECK) if C != H * N or if add_w0 is set without a w0
// pointer. These checks stay active in release builds, unlike assert().
void wkv_seq_v2_cuda_impl(
    int B,
    int T,
    int C,
    int H,
    at::Tensor state,
    at::Tensor r,
    at::Tensor w,
    const half* w0_ptr,
    bool add_w0,
    at::Tensor k,
    at::Tensor v,
    at::Tensor a,
    at::Tensor b,
    at::Tensor y,
    at::Tensor elapsed_t) {
  TORCH_CHECK(C == H * N, "wkv_seq_v2: expected C == H * ", N, ", got C=", C, ", H=", H);
  TORCH_CHECK(!add_w0 || w0_ptr != nullptr, "wkv_seq_v2: add_w0 requires a non-null w0 pointer");

  // Nothing to do for an empty batch/head grid (a zero-sized grid would
  // otherwise be reported as a launch error).
  if (B * H == 0) {
    return;
  }

  if (T == 1) {
    wkv_one_v2_cuda_impl(B, C, H, state, r, w, w0_ptr, add_w0, k, v, a, b, y, elapsed_t);
    return;
  }

  auto stream = at::cuda::getCurrentCUDAStream();

  half* state_p = reinterpret_cast<half*>(state.data_ptr());
  const half* r_p = reinterpret_cast<const half*>(r.data_ptr());
  const half* w_p = reinterpret_cast<const half*>(w.data_ptr());
  const half* k_p = reinterpret_cast<const half*>(k.data_ptr());
  const half* v_p = reinterpret_cast<const half*>(v.data_ptr());
  const half* a_p = reinterpret_cast<const half*>(a.data_ptr());
  const half* b_p = reinterpret_cast<const half*>(b.data_ptr());
  half* y_p = reinterpret_cast<half*>(y.data_ptr());
  const int* elapsed_p = elapsed_t.data_ptr<int>();

  const dim3 grid(B * H);
  const dim3 block(N);

  if (use_v2_seq(B, T)) {
    if (add_w0) {
      wkv_fp16_seq_v2_kernel<true><<<grid, block, 0, stream>>>(
          T, C, H, state_p, r_p, w_p, w0_ptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    } else {
      wkv_fp16_seq_v2_kernel<false><<<grid, block, 0, stream>>>(
          T, C, H, state_p, r_p, w_p, nullptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    }
  } else {
    if (add_w0) {
      wkv_fp16_v1_exact_kernel<false, true><<<grid, block, 0, stream>>>(
          B, T, C, H, state_p, r_p, w_p, w0_ptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    } else {
      wkv_fp16_v1_exact_kernel<false, false><<<grid, block, 0, stream>>>(
          B, T, C, H, state_p, r_p, w_p, nullptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    }
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Public entry point: multi-token WKV step without a w0 bias.
// Forwards to wkv_seq_v2_cuda_impl() with a null w0 pointer and add_w0 off.
void wkv_seq_v2_cuda(
    int B, int T, int C, int H,
    at::Tensor state,
    at::Tensor r, at::Tensor w, at::Tensor k, at::Tensor v, at::Tensor a, at::Tensor b,
    at::Tensor y,
    at::Tensor elapsed_t) {
  wkv_seq_v2_cuda_impl(
      B, T, C, H, state, r, w,
      /*w0_ptr=*/nullptr, /*add_w0=*/false,
      k, v, a, b, y, elapsed_t);
}
// Public entry point: multi-token WKV step with a per-channel w0 bias.
// The w0 tensor's storage is reinterpreted as fp16 and handed to
// wkv_seq_v2_cuda_impl() with add_w0 on.
void wkv_seq_w0_v2_cuda(
    int B, int T, int C, int H,
    at::Tensor state,
    at::Tensor r, at::Tensor w, at::Tensor w0,
    at::Tensor k, at::Tensor v, at::Tensor a, at::Tensor b,
    at::Tensor y,
    at::Tensor elapsed_t) {
  const half* w0_ptr = reinterpret_cast<const half*>(w0.data_ptr());
  wkv_seq_v2_cuda_impl(
      B, T, C, H, state, r, w,
      w0_ptr, /*add_w0=*/true,
      k, v, a, b, y, elapsed_t);
}
// Public entry point: single-token WKV step without a w0 bias.
// Forwards to wkv_one_v2_cuda_impl() with a null w0 pointer and add_w0 off.
void wkv_one_v2_cuda(
    int B, int C, int H,
    at::Tensor state,
    at::Tensor r, at::Tensor w, at::Tensor k, at::Tensor v, at::Tensor a, at::Tensor b,
    at::Tensor y,
    at::Tensor elapsed_t) {
  wkv_one_v2_cuda_impl(
      B, C, H, state, r, w,
      /*w0_ptr=*/nullptr, /*add_w0=*/false,
      k, v, a, b, y, elapsed_t);
}
// Public entry point: single-token WKV step with a per-channel w0 bias.
// The w0 tensor's storage is reinterpreted as fp16 and handed to
// wkv_one_v2_cuda_impl() with add_w0 on.
void wkv_one_w0_v2_cuda(
    int B, int C, int H,
    at::Tensor state,
    at::Tensor r, at::Tensor w, at::Tensor w0,
    at::Tensor k, at::Tensor v, at::Tensor a, at::Tensor b,
    at::Tensor y,
    at::Tensor elapsed_t) {
  const half* w0_ptr = reinterpret_cast<const half*>(w0.data_ptr());
  wkv_one_v2_cuda_impl(
      B, C, H, state, r, w,
      w0_ptr, /*add_w0=*/true,
      k, v, a, b, y, elapsed_t);
}
// Launches the single-token (T == 1) fp16 WKV kernels on the current CUDA stream.
//
// Parameters:
//   B, C, H     - batch size, channels, heads; C must equal H * N.
//   state       - recurrent state, updated in place by the kernel.
//   r,w,k,v,a,b - token inputs; y receives the output.
//   w0_ptr      - device pointer to the per-channel bias added to w; only
//                 read when add_w0 is true.
//   add_w0      - selects the kernel instantiation that adds w0.
//   elapsed_t   - int tensor indexed by batch element inside the kernels.
// NOTE(review): the tensors are reinterpreted as fp16 without a dtype,
// device or contiguity check - callers are presumably responsible; confirm.
//
// Kernel choice by batch size (one block of N threads per (batch, head)):
//   2 < B <= 64    -> wkv_fp16_one_cp_kernel     (cp.async loads)
//   64 < B <= 128  -> wkv_fp16_one_direct_kernel (__ldg loads)
//   otherwise      -> wkv_fp16_v1_clone_kernel with Tis1 = true
//
// Raises (via TORCH_CHECK) if C != H * N or if add_w0 is set without a w0
// pointer. These checks stay active in release builds, unlike assert().
void wkv_one_v2_cuda_impl(
    int B,
    int C,
    int H,
    at::Tensor state,
    at::Tensor r,
    at::Tensor w,
    const half* w0_ptr,
    bool add_w0,
    at::Tensor k,
    at::Tensor v,
    at::Tensor a,
    at::Tensor b,
    at::Tensor y,
    at::Tensor elapsed_t) {
  TORCH_CHECK(C == H * N, "wkv_one_v2: expected C == H * ", N, ", got C=", C, ", H=", H);
  TORCH_CHECK(!add_w0 || w0_ptr != nullptr, "wkv_one_v2: add_w0 requires a non-null w0 pointer");

  // Nothing to do for an empty batch/head grid (a zero-sized grid would
  // otherwise be reported as a launch error).
  if (B * H == 0) {
    return;
  }

  auto stream = at::cuda::getCurrentCUDAStream();

  half* state_p = reinterpret_cast<half*>(state.data_ptr());
  const half* r_p = reinterpret_cast<const half*>(r.data_ptr());
  const half* w_p = reinterpret_cast<const half*>(w.data_ptr());
  const half* k_p = reinterpret_cast<const half*>(k.data_ptr());
  const half* v_p = reinterpret_cast<const half*>(v.data_ptr());
  const half* a_p = reinterpret_cast<const half*>(a.data_ptr());
  const half* b_p = reinterpret_cast<const half*>(b.data_ptr());
  half* y_p = reinterpret_cast<half*>(y.data_ptr());
  const int* elapsed_p = elapsed_t.data_ptr<int>();

  const dim3 grid(B * H);
  const dim3 block(N);

  if (B > 2 && B <= 64) {
    if (add_w0) {
      wkv_fp16_one_cp_kernel<true><<<grid, block, 0, stream>>>(
          C, H, state_p, r_p, w_p, w0_ptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    } else {
      wkv_fp16_one_cp_kernel<false><<<grid, block, 0, stream>>>(
          C, H, state_p, r_p, w_p, nullptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    }
  } else if (B > 64 && B <= 128) {
    if (add_w0) {
      wkv_fp16_one_direct_kernel<true><<<grid, block, 0, stream>>>(
          C, H, state_p, r_p, w_p, w0_ptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    } else {
      wkv_fp16_one_direct_kernel<false><<<grid, block, 0, stream>>>(
          C, H, state_p, r_p, w_p, nullptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    }
  } else {
    // B <= 2 and B > 128 both use the v1 clone kernel with T fixed to 1.
    if (add_w0) {
      wkv_fp16_v1_clone_kernel<true, true><<<grid, block, 0, stream>>>(
          B, 1, C, H, state_p, r_p, w_p, w0_ptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    } else {
      wkv_fp16_v1_clone_kernel<true, false><<<grid, block, 0, stream>>>(
          B, 1, C, H, state_p, r_p, w_p, nullptr, k_p, v_p, a_p, b_p, y_p, elapsed_p);
    }
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}