// NOTE(review): in this view every angle-bracketed token sequence (header names,
// template parameter lists, cast target types) appears to have been stripped.
// They are left exactly as shown rather than guessed.
#include #include #include #include #include #include #include #include #include #include
using dtype = at::Half;
namespace wmma = nvcuda::wmma;

namespace {

// Block sizes / channel count referenced by the LayerNorm kernels in this file.
constexpr int LN_THREADS = 256;
constexpr int LN_SMALL_THREADS = 1024;
constexpr int LN_SMALL512_THREADS = 512;
constexpr int LN_SMALL_C = 4096;

// Integer ceiling division: smallest q with q * d >= n (for positive d).
inline int64_t ceil_div(int64_t n, int64_t d) { return (n + d - 1) / d; }

// Raises through TORCH_CHECK when a cuBLAS call did not return CUBLAS_STATUS_SUCCESS.
// `what` is prepended to the error message.
inline void check_cublas(cublasStatus_t status, const char* what) {
    TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, what, " failed with cublas status ", static_cast(status));
}

// Same as check_cublas, with a cublasLt-specific message.
inline void check_cublaslt(cublasStatus_t status, const char* what) {
    TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, what, " failed with cublasLt status ", static_cast(status));
}

// Compile-time selected activation: Act == 1 -> tanh, anything else -> logistic sigmoid.
// (`Act` is a template parameter whose declaration is not visible in this view.)
template __device__ __forceinline__ float apply_act(float x) {
    if constexpr (Act == 1) {
        return tanhf(x);
    } else {
        return 1.0f / (1.0f + expf(-x));
    }
}

// Warp-wide sum via shuffle-down. All 32 lanes must call it (full mask).
// Only lane 0 is guaranteed to hold the complete sum afterwards.
__device__ __forceinline__ float warp_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_down_sync(0xffffffffu, x, offset);
    }
    return x;
}

// Reinterprets 16 raw bits as the upper half of a 32-bit float pattern
// (lower 16 bits zero) and returns that float.
__device__ __forceinline__ float bf16_bits_to_float_dev(uint16_t bits) {
    union {
        uint32_t u;
        float f;
    } v;
    v.u = static_cast(bits) << 16;
    return v.f;
}

// Block-wide sum of `x`; every thread of the block receives the total.
//
// Must be called by all threads of the block (it contains __syncthreads()).
// The per-warp scratch array has Threads / 32 entries, so this presumably
// requires blockDim.x == Threads with Threads a multiple of 32 and at most 1024
// -- NOTE(review): confirm at the launch sites.
// (`Threads` is a template parameter whose declaration is not visible in this view.)
//
// The total is published through its own shared slot instead of partial[0].
// Callers invoke this helper several times in a row (mean, then variance); with
// the result stored in partial[0], warp 0 of the *next* call could overwrite it
// with its per-warp partial before slower warps had read the previous total.
// With a dedicated slot, `total` is only rewritten after the next call's first
// barrier, which no thread can pass until every thread has returned from this call.
template __device__ __forceinline__ float block_sum_t(float x) {
    __shared__ float partial[Threads / 32];
    __shared__ float total;
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
    x = warp_sum(x);
    if (lane == 0) {
        partial[warp] = x;
    }
    __syncthreads();
    // The first Threads / 32 threads (all in warp 0) pick up one per-warp partial each.
    x = (threadIdx.x < (Threads / 32)) ? partial[lane] : 0.0f;
    if (warp == 0) {
        x = warp_sum(x);
    }
    if (threadIdx.x == 0) {
        total = x;
    }
    __syncthreads();
    return total;
}

// Per-row normalisation of a bf16 table into fp16: one block per row `tok` (blockIdx.x),
// threads stride over the C columns. Inputs are raw bf16 bit patterns in uint16_t.
__global__ void emb_ln0_bf16_to_f16_kernel(
    int V, int C, const uint16_t* __restrict__ emb, const uint16_t* __restrict__ weight,
    const uint16_t* __restrict__ bias, dtype* __restrict__ out, float eps) {
    // Precision path: bf16 inputs -> fp32 two-pass stats/affine -> fp16 output.
// Body of emb_ln0_bf16_to_f16_kernel: one block normalises row `tok`.
    const int tok = blockIdx.x;
    const int tid = threadIdx.x;
    if (tok >= V) { return; }  // block-uniform exit, so the barriers inside block_sum_t stay safe
    const uint16_t* er = emb + static_cast(tok) * C;
    // Pass 1: mean.
    float sum = 0.0f;
    for (int c = tid; c < C; c += blockDim.x) { sum += bf16_bits_to_float_dev(er[c]); }
    // NOTE(review): the reduction is instantiated for 256 threads while the loops stride by
    // blockDim.x; this presumably requires a 256-thread launch -- confirm at the call site.
    const float mean = block_sum_t<256>(sum) / static_cast(C);
    // Pass 2: variance around the mean.
    float var = 0.0f;
    for (int c = tid; c < C; c += blockDim.x) {
        const float d = bf16_bits_to_float_dev(er[c]) - mean;
        var += d * d;
    }
    const float rstd = rsqrtf(block_sum_t<256>(var) / static_cast(C) + eps);
    // Pass 3: normalise, apply per-column scale/shift, store.
    dtype* yr = out + static_cast(tok) * C;
    for (int c = tid; c < C; c += blockDim.x) {
        const float x = bf16_bits_to_float_dev(er[c]);
        const float w = bf16_bits_to_float_dev(weight[c]);
        const float b = bf16_bits_to_float_dev(bias[c]);
        yr[c] = static_cast((x - mean) * rstd * w + b);
    }
}

// Element-wise out = x + y, processed two halves at a time.
// `n_pairs` is the number of 2-element pairs; thread i handles pair i.
// The sum is formed in fp32 and rounded back with __floats2half2_rn.
__global__ void add_f16_kernel(
    const dtype* __restrict__ x, const dtype* __restrict__ y, dtype* __restrict__ out, int64_t n_pairs) {
    const int64_t i = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i < n_pairs) {
        const float2 xv = __half22float2(reinterpret_cast(x)[i]);
        const float2 yv = __half22float2(reinterpret_cast(y)[i]);
        reinterpret_cast<__half2*>(out)[i] = __floats2half2_rn(xv.x + yv.x, xv.y + yv.y);
    }
}

// In-place x[i] += amount for the first n ints; one thread per element.
__global__ void advance_i32_kernel(int* __restrict__ x, int amount, int64_t n) {
    const int64_t i = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i < n) { x[i] += amount; }
}

// Split-K partial product for a single input row.
// Each thread owns two adjacent output columns (n, n + 1); blockIdx.y selects the
// K-chunk [k0, k1). `weight` is indexed as weight[k * N + n].
// `Warps` and `ChunkK` are template parameters not visible in this view.
// NOTE(review): only n is bounds-checked while the paired load/store also touches n + 1,
// so this appears to assume N is even -- confirm at the launch site.
template __global__ __launch_bounds__(128, 2) void linear_f16_m1_splitk_partial_kernel(
    int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight, float* __restrict__ partial) {
    const int warp = threadIdx.x >> 5;
    const int lane = threadIdx.x & 31;
    const int pair = (blockIdx.x * Warps + warp) * 32 + lane;
    const int n = pair << 1;
    if (n >= N) { return; }
    const int k0 = blockIdx.y * ChunkK;
    const int k1 = min(k0 + ChunkK, K);  // last chunk may be short
    float acc0 = 0.0f;
    float acc1 = 0.0f;
    for (int k = k0; k < k1; ++k) {
        const float xv = __half2float(*reinterpret_cast(x + k));
        const float2 wv =
__half22float2(*reinterpret_cast(weight + static_cast(k) * N + n));
        acc0 = fmaf(xv, wv.x, acc0);
        acc1 = fmaf(xv, wv.y, acc1);
    }
    // One float2 per (chunk, column pair): partial[blockIdx.y * N + 2 * pair .. + 1].
    reinterpret_cast(partial + static_cast(blockIdx.y) * N)[pair] = make_float2(acc0, acc1);
}

// Sums the `chunks` split-K partials for one column pair and stores the fp16 result.
// One thread per column pair; the chunk loop is sequential.
__global__ void linear_f16_m1_splitk_reduce_kernel(
    int chunks, int N, const float* __restrict__ partial, dtype* __restrict__ y) {
    const int pair = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;
    const int n = pair << 1;
    if (n >= N) { return; }
    float acc0 = 0.0f;
    float acc1 = 0.0f;
    for (int c = 0; c < chunks; ++c) {
        const float2 v = reinterpret_cast(partial + static_cast(c) * N)[pair];
        acc0 += v.x;
        acc1 += v.y;
    }
    reinterpret_cast<__half2*>(y)[pair] = __floats2half2_rn(acc0, acc1);
}

// Same reduction, but one warp per column pair: lanes stride over the chunks and the
// lane partials are combined with shuffles. `pair = blockIdx.x * 4 + warp` hard-codes
// four warps per block -- NOTE(review): presumably launched with 128 threads; confirm.
__global__ void linear_f16_m1_splitk_reduce_warp_kernel(
    int chunks, int N, const float* __restrict__ partial, dtype* __restrict__ y) {
    const int warp = threadIdx.x >> 5;
    const int lane = threadIdx.x & 31;
    const int pair = blockIdx.x * 4 + warp;
    const int n = pair << 1;
    if (n >= N) { return; }  // depends on warp only, so whole warps exit together
    float acc0 = 0.0f;
    float acc1 = 0.0f;
    for (int c = lane; c < chunks; c += 32) {
        const float2 v = reinterpret_cast(partial + static_cast(c) * N)[pair];
        acc0 += v.x;
        acc1 += v.y;
    }
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        acc0 += __shfl_down_sync(0xffffffffu, acc0, offset);
        acc1 += __shfl_down_sync(0xffffffffu, acc1, offset);
    }
    // After the shuffle-down tree only lane 0 holds the full sums.
    if (lane == 0) { reinterpret_cast<__half2*>(y)[pair] = __floats2half2_rn(acc0, acc1); }
}

// Multi-row variant of the split-K partial kernel: blockIdx.y = K-chunk, blockIdx.z = row m.
// `Warps` and `ChunkK` are template parameters not visible in this view.
template __global__ __launch_bounds__(128, 2) void linear_f16_rows_splitk_partial_kernel(
    int K, int N, int chunks, const dtype* __restrict__ x, const dtype* __restrict__ weight,
    float* __restrict__ partial) {
    const int warp = threadIdx.x >> 5;
    const int lane = threadIdx.x & 31;
    const int pair = (blockIdx.x * Warps + warp) * 32 + lane;
    const int n = pair << 1;
    if (n >= N) { return; }
    const int chunk = blockIdx.y;
    const int m = blockIdx.z;
    const int k0 = chunk * ChunkK;
    const int k1 = min(k0 + ChunkK, K);
    const dtype* x_row = x + static_cast(m) *
K;
    float acc0 = 0.0f;
    float acc1 = 0.0f;
    for (int k = k0; k < k1; ++k) {
        const float xv = __half2float(*reinterpret_cast(x_row + k));
        const float2 wv = __half22float2(*reinterpret_cast(weight + static_cast(k) * N + n));
        acc0 = fmaf(xv, wv.x, acc0);
        acc1 = fmaf(xv, wv.y, acc1);
    }
    // Partials are laid out as [m][chunk][N] floats.
    reinterpret_cast(partial + (static_cast(m) * chunks + chunk) * N)[pair] = make_float2(acc0, acc1);
}

// Sums the split-K partials of row m (blockIdx.y) for one column pair and writes
// the fp16 pair into y[m * N + n .. n + 1].
__global__ void linear_f16_rows_splitk_reduce_kernel(
    int chunks, int N, const float* __restrict__ partial, dtype* __restrict__ y) {
    const int pair = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;
    const int m = blockIdx.y;
    const int n = pair << 1;
    if (n >= N) { return; }
    float acc0 = 0.0f;
    float acc1 = 0.0f;
    for (int c = 0; c < chunks; ++c) {
        const float2 v = reinterpret_cast(partial + (static_cast(m) * chunks + c) * N)[pair];
        acc0 += v.x;
        acc1 += v.y;
    }
    reinterpret_cast<__half2*>(y + static_cast(m) * N)[pair] = __floats2half2_rn(acc0, acc1);
}

// y[m][n] = dot(x[m, :], weight_t[n, :]); one block per output element
// (blockIdx.x = n, blockIdx.y = m). `weight_t` is indexed as weight_t[n * K + k].
// `Threads` is a template parameter not visible in this view; the loop strides by it,
// so it presumably has to equal blockDim.x -- NOTE(review): confirm.
template __global__ __launch_bounds__(Threads, 2) void linear_t_f16_kernel(
    int M, int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_t, dtype* __restrict__ y) {
    const int n = blockIdx.x;
    const int m = blockIdx.y;
    if (m >= M || n >= N) { return; }  // block-uniform, safe before the block reduction
    float acc = 0.0f;
    const dtype* x_row = x + static_cast(m) * K;
    const dtype* w_row = weight_t + static_cast(n) * K;
    const int K2 = K >> 1;  // number of complete 2-element pairs along K
    for (int k2 = threadIdx.x; k2 < K2; k2 += Threads) {
        const float2 xv = __half22float2(*reinterpret_cast(x_row + (k2 << 1)));
        const float2 wv = __half22float2(*reinterpret_cast(w_row + (k2 << 1)));
        acc = fmaf(xv.x, wv.x, acc);
        acc = fmaf(xv.y, wv.y, acc);
    }
    // Odd K: thread 0 adds the final unpaired element.
    if ((K & 1) && threadIdx.x == 0) {
        acc = fmaf(__half2float(*reinterpret_cast(x_row + K - 1)), __half2float(*reinterpret_cast(w_row + K - 1)), acc);
    }
    acc = block_sum_t(acc);
    if (threadIdx.x == 0) { *reinterpret_cast<__half*>(y + static_cast(m) * N + n) = __float2half_rn(acc); }
}

// Tiled variant: one block computes OutTile consecutive outputs of row m.
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_t_f16_ntile_kernel(
    int M, int K, int N, const
dtype* __restrict__ x, const dtype* __restrict__ weight_t, dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;  // first output column of this block
    const int m = blockIdx.y;
    if (m >= M) { return; }
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    const dtype* x_row = x + static_cast(m) * K;
    const int K2 = K >> 1;
    // Each x pair is loaded once and reused for all OutTile weight rows.
    for (int k2 = threadIdx.x; k2 < K2; k2 += Threads) {
        const int k = k2 << 1;
        const float2 xv = __half22float2(*reinterpret_cast(x_row + k));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                const float2 wv = __half22float2(*reinterpret_cast(weight_t + static_cast(n) * K + k));
                acc[j] = fmaf(xv.x, wv.x, acc[j]);
                acc[j] = fmaf(xv.y, wv.y, acc[j]);
            }
        }
    }
    // Odd K: thread 0 adds the final unpaired element for every output in the tile.
    if ((K & 1) && threadIdx.x == 0) {
        const float xv = __half2float(*reinterpret_cast(x_row + K - 1));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                acc[j] = fmaf(xv, __half2float(*reinterpret_cast(weight_t + static_cast(n) * K + K - 1)), acc[j]);
            }
        }
    }
    // Two-stage reduction: shuffle within each warp, then thread 0 adds the per-warp partials.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        acc[j] = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = acc[j]; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
            for (int w = 0; w < Threads / 32; ++w) { sum += partial[w][j]; }
            const int n = n0 + j;
            if (n < N) { *reinterpret_cast<__half*>(y + static_cast(m) * N + n) = __float2half_rn(sum); }
        }
    }
}

// Same tiling as linear_t_f16_ntile_kernel, but K is walked one element at a time
// (no paired loads). `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_t_f16_ntile_scalar_kernel(
    int M, int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_t, dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    const int m = blockIdx.y;
    if (m >= M) { return; }
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    const dtype* x_row = x + static_cast(m) * K;
    for (int k =
threadIdx.x; k < K; k += Threads) {
        const float xv = __half2float(*reinterpret_cast(x_row + k));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                acc[j] = fmaf(xv, __half2float(*reinterpret_cast(weight_t + static_cast(n) * K + k)), acc[j]);
            }
        }
    }
    // Warp shuffle reduction, then thread 0 combines the per-warp partials.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        acc[j] = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = acc[j]; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
            for (int w = 0; w < Threads / 32; ++w) { sum += partial[w][j]; }
            const int n = n0 + j;
            if (n < N) { *reinterpret_cast<__half*>(y + static_cast(m) * N + n) = __float2half_rn(sum); }
        }
    }
}

// Row- and column-tiled product: one block computes a RowTile x OutTile patch of y,
// with y[m][n] = dot(x[m, :], weight_orig[n, :]) (weight indexed as weight_orig[n * K + k]).
// blockIdx.x selects the column tile, blockIdx.y the row tile.
// `Threads`, `OutTile` and `RowTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 1) void linear_orig_rows_f16_kernel(
    int M, int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_orig,
    dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    const int m0 = blockIdx.y * RowTile;
    float acc[RowTile][OutTile];
#pragma unroll
    for (int r = 0; r < RowTile; ++r) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) { acc[r][j] = 0.0f; }
    }
    const int K2 = K >> 1;
    for (int k2 = threadIdx.x; k2 < K2; k2 += Threads) {
        const int k = k2 << 1;
        // Weight pairs are loaded once per k and reused for every row in the tile;
        // out-of-range columns contribute zeros.
        float2 wv[OutTile];
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            wv[j] = (n < N) ?
__half22float2(*reinterpret_cast(weight_orig + static_cast(n) * K + k)) : make_float2(0.0f, 0.0f);
        }
#pragma unroll
        for (int r = 0; r < RowTile; ++r) {
            const int m = m0 + r;
            if (m < M) {
                const float2 xv = __half22float2(*reinterpret_cast(x + static_cast(m) * K + k));
#pragma unroll
                for (int j = 0; j < OutTile; ++j) {
                    acc[r][j] = fmaf(xv.x, wv[j].x, acc[r][j]);
                    acc[r][j] = fmaf(xv.y, wv[j].y, acc[r][j]);
                }
            }
        }
    }
    // Odd K: thread 0 adds the final unpaired element for every in-range (row, column).
    if ((K & 1) && threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                const float wv = __half2float(*reinterpret_cast(weight_orig + static_cast(n) * K + K - 1));
#pragma unroll
                for (int r = 0; r < RowTile; ++r) {
                    const int m = m0 + r;
                    if (m < M) {
                        const float xv = __half2float(*reinterpret_cast(x + static_cast(m) * K + K - 1));
                        acc[r][j] = fmaf(xv, wv, acc[r][j]);
                    }
                }
            }
        }
    }
    // Warp shuffle reduction per accumulator, then thread 0 sums the per-warp partials.
    __shared__ float partial[Threads / 32][RowTile][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int r = 0; r < RowTile; ++r) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const float v = warp_sum(acc[r][j]);
            if (lane == 0) { partial[warp][r][j] = v; }
        }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int r = 0; r < RowTile; ++r) {
            const int m = m0 + r;
            if (m < M) {
#pragma unroll
                for (int j = 0; j < OutTile; ++j) {
                    const int n = n0 + j;
                    if (n < N) {
                        float sum = 0.0f;
#pragma unroll
                        for (int w = 0; w < Threads / 32; ++w) { sum += partial[w][r][j]; }
                        *reinterpret_cast<__half*>(y + static_cast(m) * N + n) = __float2half_rn(sum);
                    }
                }
            }
        }
    }
}

// Single-row product without any bounds checks: columns n0 .. n0 + OutTile - 1 are
// read and written unconditionally and only complete K pairs are consumed.
// NOTE(review): presumably requires N to be a multiple of OutTile and K to be even
// ("exact") -- confirm at the launch site. N is not used in the body.
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 1) void linear_orig_row1_exact_f16_kernel(
    int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_orig, dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    for (int k2 = threadIdx.x; k2 < (K >> 1); k2 += Threads) {
        const int k = k2 << 1;
        const float2 xv =
__half22float2(*reinterpret_cast(x + k));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const float2 wv = __half22float2(*reinterpret_cast(weight_orig + static_cast(n0 + j) * K + k));
            acc[j] = fmaf(xv.x, wv.x, acc[j]);
            acc[j] = fmaf(xv.y, wv.y, acc[j]);
        }
    }
    // Warp shuffle reduction, then thread 0 sums the per-warp partials and stores.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        const float v = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = v; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
            for (int w = 0; w < Threads / 32; ++w) { sum += partial[w][j]; }
            y[n0 + j] = __float2half_rn(sum);
        }
    }
}

// Single-row product that consumes K four elements per thread step (two paired loads).
// No bounds checks on columns, and the k loop reads x[k .. k + 3] whenever k < K.
// NOTE(review): presumably requires K to be a multiple of 4 and N a multiple of
// OutTile -- confirm at the launch site. N is not used in the body.
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 1) void linear_orig_row1_exact4_f16_kernel(
    int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_orig, dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    for (int k = threadIdx.x << 2; k < K; k += Threads << 2) {
        const float2 x0 = __half22float2(*reinterpret_cast(x + k));
        const float2 x1 = __half22float2(*reinterpret_cast(x + k + 2));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const dtype* wj = weight_orig + static_cast(n0 + j) * K + k;
            const float2 w0 = __half22float2(*reinterpret_cast(wj));
            const float2 w1 = __half22float2(*reinterpret_cast(wj + 2));
            acc[j] = fmaf(x0.x, w0.x, acc[j]);
            acc[j] = fmaf(x0.y, w0.y, acc[j]);
            acc[j] = fmaf(x1.x, w1.x, acc[j]);
            acc[j] = fmaf(x1.y, w1.y, acc[j]);
        }
    }
    // Warp shuffle reduction, then thread 0 sums the per-warp partials and stores.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        const float v = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = v; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
            for
(int w = 0; w < Threads / 32; ++w) { sum += partial[w][j]; }
            y[n0 + j] = __float2half_rn(sum);
        }
    }
}

// Two-row product without bounds checks: row 0 is read at x[k], row 1 at x[K + k];
// results go to y[n] and y[N + n]. Each weight pair is loaded once and used for both rows.
// NOTE(review): presumably requires exactly two input rows, even K and N a multiple of
// OutTile -- confirm at the launch site.
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 1) void linear_orig_row2_exact_f16_kernel(
    int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_orig, dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    float acc0[OutTile];
    float acc1[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        acc0[j] = 0.0f;
        acc1[j] = 0.0f;
    }
    for (int k2 = threadIdx.x; k2 < (K >> 1); k2 += Threads) {
        const int k = k2 << 1;
        const float2 x0 = __half22float2(*reinterpret_cast(x + k));
        const float2 x1 = __half22float2(*reinterpret_cast(x + K + k));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const float2 wv = __half22float2(*reinterpret_cast(weight_orig + static_cast(n0 + j) * K + k));
            acc0[j] = fmaf(x0.x, wv.x, acc0[j]);
            acc0[j] = fmaf(x0.y, wv.y, acc0[j]);
            acc1[j] = fmaf(x1.x, wv.x, acc1[j]);
            acc1[j] = fmaf(x1.y, wv.y, acc1[j]);
        }
    }
    // Warp shuffle reduction for both rows, then thread 0 sums the per-warp partials.
    __shared__ float partial[Threads / 32][2][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        const float v0 = warp_sum(acc0[j]);
        const float v1 = warp_sum(acc1[j]);
        if (lane == 0) {
            partial[warp][0][j] = v0;
            partial[warp][1][j] = v1;
        }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum0 = 0.0f;
            float sum1 = 0.0f;
#pragma unroll
            for (int w = 0; w < Threads / 32; ++w) {
                sum0 += partial[w][0][j];
                sum1 += partial[w][1][j];
            }
            const int n = n0 + j;
            y[n] = __float2half_rn(sum0);
            y[N + n] = __float2half_rn(sum1);
        }
    }
}

// Two-row product consuming K four elements per thread step; otherwise the same
// contract as linear_orig_row2_exact_f16_kernel.
// NOTE(review): presumably requires K to be a multiple of 4 -- confirm at the launch site.
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 1) void linear_orig_row2_exact4_f16_kernel(
    int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_orig, dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    float acc0[OutTile];
    float acc1[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        acc0[j] = 0.0f;
        acc1[j] = 0.0f;
    }
    for (int k =
threadIdx.x << 2; k < K; k += Threads << 2) {
        // x00/x01: row 0 elements k..k+3; x10/x11: row 1 elements k..k+3.
        const float2 x00 = __half22float2(*reinterpret_cast(x + k));
        const float2 x01 = __half22float2(*reinterpret_cast(x + k + 2));
        const float2 x10 = __half22float2(*reinterpret_cast(x + K + k));
        const float2 x11 = __half22float2(*reinterpret_cast(x + K + k + 2));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const dtype* wj = weight_orig + static_cast(n0 + j) * K + k;
            const float2 w0 = __half22float2(*reinterpret_cast(wj));
            const float2 w1 = __half22float2(*reinterpret_cast(wj + 2));
            acc0[j] = fmaf(x00.x, w0.x, acc0[j]);
            acc0[j] = fmaf(x00.y, w0.y, acc0[j]);
            acc0[j] = fmaf(x01.x, w1.x, acc0[j]);
            acc0[j] = fmaf(x01.y, w1.y, acc0[j]);
            acc1[j] = fmaf(x10.x, w0.x, acc1[j]);
            acc1[j] = fmaf(x10.y, w0.y, acc1[j]);
            acc1[j] = fmaf(x11.x, w1.x, acc1[j]);
            acc1[j] = fmaf(x11.y, w1.y, acc1[j]);
        }
    }
    // Warp shuffle reduction for both rows, then thread 0 sums the per-warp partials.
    __shared__ float partial[Threads / 32][2][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        const float v0 = warp_sum(acc0[j]);
        const float v1 = warp_sum(acc1[j]);
        if (lane == 0) {
            partial[warp][0][j] = v0;
            partial[warp][1][j] = v1;
        }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum0 = 0.0f;
            float sum1 = 0.0f;
#pragma unroll
            for (int w = 0; w < Threads / 32; ++w) {
                sum0 += partial[w][0][j];
                sum1 += partial[w][1][j];
            }
            const int n = n0 + j;
            y[n] = __float2half_rn(sum0);
            y[N + n] = __float2half_rn(sum1);
        }
    }
}

// Tensor-core (WMMA) product: one 32-thread block computes a 16 x 16 tile of y
// (blockIdx.x = column tile, blockIdx.y = row tile), stepping K in slabs of 16.
// Tiles are staged through shared memory with zero fill outside M / N / K.
// NOTE(review): the fragment template arguments (shape, element type, layout) are not
// visible in this view.
__global__ __launch_bounds__(32, 8) void linear_orig_wmma16_f16_kernel(
    int M, int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_orig,
    dtype* __restrict__ y) {
    const int n0 = blockIdx.x * 16;
    const int m0 = blockIdx.y * 16;
    __shared__ __half a_tile[16 * 16];
    __shared__ __half b_tile[16 * 16];
    __shared__ float c_tile[16 * 16];
    wmma::fragment a_frag;
    wmma::fragment b_frag;
    wmma::fragment c_frag;
    wmma::fill_fragment(c_frag, 0.0f);
    for (int k0 = 0; k0 < K; k0 += 16) {
        for (int
idx = threadIdx.x; idx < 16 * 16; idx += 32) {
            const int r = idx >> 4;   // tile row (0..15)
            const int kk = idx & 15;  // offset inside the current K slab
            const int m = m0 + r;
            // a_tile[r][kk] = x[m][k0 + kk], zero outside the matrix.
            a_tile[idx] = (m < M && k0 + kk < K) ? *reinterpret_cast(x + static_cast(m) * K + k0 + kk) : __float2half(0.0f);
            const int n = n0 + r;
            // b_tile[r][kk] = weight_orig[n][k0 + kk], zero outside the matrix.
            b_tile[r * 16 + kk] = (n < N && k0 + kk < K) ? *reinterpret_cast(weight_orig + static_cast(n) * K + k0 + kk) : __float2half(0.0f);
        }
        __syncwarp();  // the block is a single warp: tiles fully written before the fragment loads
        wmma::load_matrix_sync(a_frag, a_tile, 16);
        wmma::load_matrix_sync(b_frag, b_tile, 16);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        __syncwarp();  // fragment loads finished before the next slab overwrites the tiles
    }
    wmma::store_matrix_sync(c_tile, c_frag, 16, wmma::mem_row_major);
    __syncwarp();
    // Write the in-range part of the 16 x 16 fp32 tile back as fp16.
    for (int idx = threadIdx.x; idx < 16 * 16; idx += 32) {
        const int r = idx >> 4;
        const int j = idx & 15;
        const int m = m0 + r;
        const int n = n0 + j;
        if (m < M && n < N) { *reinterpret_cast<__half*>(y + static_cast(m) * N + n) = __float2half_rn(c_tile[idx]); }
    }
}

// y[m][n] = dot(act(x[m, :]), weight_t[n, :]) for an OutTile-wide column tile, walking K
// one element at a time. The activation is applied to each x element via apply_act.
// NOTE(review): the template parameters (presumably thread count, tile width and the
// activation selector passed to apply_act) are not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_t_act_f16_ntile_scalar_kernel(
    int M, int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_t, dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    const int m = blockIdx.y;
    if (m >= M) { return; }
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    const dtype* x_row = x + static_cast(m) * K;
    for (int k = threadIdx.x; k < K; k += Threads) {
        const float xv = apply_act(__half2float(*reinterpret_cast(x_row + k)));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                acc[j] = fmaf(xv, __half2float(*reinterpret_cast(weight_t + static_cast(n) * K + k)), acc[j]);
            }
        }
    }
    // Warp shuffle reduction, then thread 0 sums the per-warp partials.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        acc[j] = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = acc[j]; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
for (int w = 0; w < Threads / 32; ++w) { sum += partial[w][j]; }
            const int n = n0 + j;
            if (n < N) { *reinterpret_cast<__half*>(y + static_cast(m) * N + n) = __float2half_rn(sum); }
        }
    }
}

// Paired-load variant of the activated product: y[m][n] = dot(act(x[m, :]), weight_t[n, :])
// for an OutTile-wide column tile; K is consumed two elements at a time with a scalar
// tail for odd K.
// NOTE(review): the template parameters are not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_t_act_f16_ntile_kernel(
    int M, int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_t, dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    const int m = blockIdx.y;
    if (m >= M) { return; }
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    const dtype* x_row = x + static_cast(m) * K;
    const int K2 = K >> 1;
    for (int k2 = threadIdx.x; k2 < K2; k2 += Threads) {
        const int k = k2 << 1;
        float2 xv = __half22float2(*reinterpret_cast(x_row + k));
        // Activation is applied once per x element, before the tile loop.
        xv.x = apply_act(xv.x);
        xv.y = apply_act(xv.y);
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                const float2 wv = __half22float2(*reinterpret_cast(weight_t + static_cast(n) * K + k));
                acc[j] = fmaf(xv.x, wv.x, acc[j]);
                acc[j] = fmaf(xv.y, wv.y, acc[j]);
            }
        }
    }
    // Odd K: thread 0 adds the final unpaired (activated) element.
    if ((K & 1) && threadIdx.x == 0) {
        const float xv = apply_act(__half2float(*reinterpret_cast(x_row + K - 1)));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                acc[j] = fmaf(xv, __half2float(*reinterpret_cast(weight_t + static_cast(n) * K + K - 1)), acc[j]);
            }
        }
    }
    // Warp shuffle reduction, then thread 0 sums the per-warp partials.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        acc[j] = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = acc[j]; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
            for (int w = 0; w < Threads / 32; ++w) { sum += partial[w][j]; }
            const int n = n0 + j;
            if (n < N) { *reinterpret_cast<__half*>(y + static_cast(m) * N + n) = __float2half_rn(sum); }
        }
    }
}

// Three independent products batched into one launch; blockIdx.z selects the group.
// `Threads` is a template parameter not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void
linear_wag_rank_in_f16_kernel(
    int M, int K, int Rw, int Ra, int Rg, int Rmax,
    const dtype* __restrict__ xw, const dtype* __restrict__ xa, const dtype* __restrict__ xg,
    const dtype* __restrict__ w1_t, const dtype* __restrict__ a1_t, const dtype* __restrict__ g1_t,
    dtype* __restrict__ w1, dtype* __restrict__ a1, dtype* __restrict__ g1) {
    // Grid: x = output index r, y = row m, z = group (0: w, 1: a, 2: g).
    // Each block computes y[m * R + r] = dot(x[m, :], wt[r, :]) for its group.
    const int r = blockIdx.x;
    const int m = blockIdx.y;
    const int group = blockIdx.z;
    int R = Rw;
    const dtype* x = xw;
    const dtype* wt = w1_t;
    dtype* y = w1;
    if (group == 1) {
        R = Ra; x = xa; wt = a1_t; y = a1;
    } else if (group == 2) {
        R = Rg; x = xg; wt = g1_t; y = g1;
    }
    // Block-uniform exit: groups with fewer than gridDim.x outputs skip the surplus blocks.
    if (m >= M || r >= R || r >= Rmax) { return; }
    float acc = 0.0f;
    const dtype* x_row = x + static_cast(m) * K;
    const dtype* w_row = wt + static_cast(r) * K;
    const int K2 = K >> 1;
    for (int k2 = threadIdx.x; k2 < K2; k2 += Threads) {
        const int k = k2 << 1;
        const float2 xv = __half22float2(*reinterpret_cast(x_row + k));
        const float2 wv = __half22float2(*reinterpret_cast(w_row + k));
        acc = fmaf(xv.x, wv.x, acc);
        acc = fmaf(xv.y, wv.y, acc);
    }
    // Odd K: thread 0 adds the final unpaired element.
    if ((K & 1) && threadIdx.x == 0) {
        acc = fmaf(__half2float(*reinterpret_cast(x_row + K - 1)), __half2float(*reinterpret_cast(w_row + K - 1)), acc);
    }
    acc = block_sum_t(acc);
    if (threadIdx.x == 0) { *reinterpret_cast<__half*>(y + static_cast(m) * R + r) = __float2half_rn(acc); }
}

// Four-group version of linear_wag_rank_in_f16_kernel (adds group 3: v).
// `Threads` is a template parameter not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_wagv_rank_in_f16_kernel(
    int M, int K, int Rw, int Ra, int Rg, int Rv, int Rmax,
    const dtype* __restrict__ xw, const dtype* __restrict__ xa, const dtype* __restrict__ xg,
    const dtype* __restrict__ xv,
    const dtype* __restrict__ w1_t, const dtype* __restrict__ a1_t, const dtype* __restrict__ g1_t,
    const dtype* __restrict__ v1_t,
    dtype* __restrict__ w1, dtype* __restrict__ a1, dtype* __restrict__ g1, dtype* __restrict__ v1) {
    const int r = blockIdx.x;
    const int m = blockIdx.y;
    const int group = blockIdx.z;
    int R = Rw;
    const dtype* x = xw;
    const dtype* wt = w1_t;
    dtype* y = w1;
    if (group == 1) {
        R =
Ra; x = xa; wt = a1_t; y = a1;
    } else if (group == 2) {
        R = Rg; x = xg; wt = g1_t; y = g1;
    } else if (group == 3) {
        R = Rv; x = xv; wt = v1_t; y = v1;
    }
    // Block-uniform exit for out-of-range rows / outputs of the selected group.
    if (m >= M || r >= R || r >= Rmax) { return; }
    float acc = 0.0f;
    const dtype* x_row = x + static_cast(m) * K;
    const dtype* w_row = wt + static_cast(r) * K;
    const int K2 = K >> 1;
    for (int k2 = threadIdx.x; k2 < K2; k2 += Threads) {
        const int k = k2 << 1;
        const float2 xv2 = __half22float2(*reinterpret_cast(x_row + k));
        const float2 wv = __half22float2(*reinterpret_cast(w_row + k));
        acc = fmaf(xv2.x, wv.x, acc);
        acc = fmaf(xv2.y, wv.y, acc);
    }
    // Odd K: thread 0 adds the final unpaired element.
    if ((K & 1) && threadIdx.x == 0) {
        acc = fmaf(__half2float(*reinterpret_cast(x_row + K - 1)), __half2float(*reinterpret_cast(w_row + K - 1)), acc);
    }
    acc = block_sum_t(acc);
    if (threadIdx.x == 0) { *reinterpret_cast<__half*>(y + static_cast(m) * R + r) = __float2half_rn(acc); }
}

// Three batched products with a per-group input activation; blockIdx.z selects the group:
//   group 0: w[m][n] = dot(tanh(w1[m, :]),    w2_t[n, :])   (inner length Kw)
//   group 1: a[m][n] = dot(a1[m, :],          a2_t[n, :])   (inner length Ka)
//   group 2: g[m][n] = dot(sigmoid(g1[m, :]), g2_t[n, :])   (inner length Kg)
// blockIdx.x selects an OutTile-wide tile of the C output columns, blockIdx.y the row.
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_wag_rank_out_f16_kernel(
    int M, int C, int Kw, int Ka, int Kg,
    const dtype* __restrict__ w1, const dtype* __restrict__ a1, const dtype* __restrict__ g1,
    const dtype* __restrict__ w2_t, const dtype* __restrict__ a2_t, const dtype* __restrict__ g2_t,
    dtype* __restrict__ w, dtype* __restrict__ a, dtype* __restrict__ g) {
    const int n0 = blockIdx.x * OutTile;
    const int m = blockIdx.y;
    const int group = blockIdx.z;
    int K = Kw;
    const dtype* x = w1;
    const dtype* wt = w2_t;
    dtype* y = w;
    if (group == 1) {
        K = Ka; x = a1; wt = a2_t; y = a;
    } else if (group == 2) {
        K = Kg; x = g1; wt = g2_t; y = g;
    }
    if (m >= M) { return; }
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    const dtype* x_row = x + static_cast(m) * K;
    for (int k = threadIdx.x; k < K; k += Threads) {
        float xv = __half2float(*reinterpret_cast(x_row + k));
        // Group-specific activation on the input element (group 1 is left linear).
        if (group == 0) {
            xv = tanhf(xv);
        } else if (group == 2) {
            xv = 1.0f / (1.0f + expf(-xv));
        }
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < C) {
                acc[j] = fmaf(xv,
__half2float(*reinterpret_cast(wt + static_cast(n) * K + k)), acc[j]);
            }
        }
    }
    // Warp shuffle reduction, then thread 0 sums the per-warp partials.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        acc[j] = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = acc[j]; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
            for (int u = 0; u < Threads / 32; ++u) { sum += partial[u][j]; }
            const int n = n0 + j;
            if (n < C) { *reinterpret_cast<__half*>(y + static_cast(m) * C + n) = __float2half_rn(sum); }
        }
    }
}

// Four-group version of linear_wag_rank_out_f16_kernel. Groups 0-2 behave the same
// (tanh / linear / sigmoid on the input); group 3 (v1, v2_t, inner length Kv) does not
// store the product directly but uses it in a gated blend of `v` and `v_first`
// (see the store at the end of the kernel).
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_wagv_rank_out_f16_kernel(
    int M, int C, int Kw, int Ka, int Kg, int Kv,
    const dtype* __restrict__ w1, const dtype* __restrict__ a1, const dtype* __restrict__ g1,
    const dtype* __restrict__ v1,
    const dtype* __restrict__ w2_t, const dtype* __restrict__ a2_t, const dtype* __restrict__ g2_t,
    const dtype* __restrict__ v2_t,
    const dtype* __restrict__ v, const dtype* __restrict__ v_first, const dtype* __restrict__ v0,
    dtype* __restrict__ w, dtype* __restrict__ a, dtype* __restrict__ g, dtype* __restrict__ v_out) {
    const int n0 = blockIdx.x * OutTile;
    const int m = blockIdx.y;
    const int group = blockIdx.z;
    int K = Kw;
    const dtype* x = w1;
    const dtype* wt = w2_t;
    dtype* y = w;
    if (group == 1) {
        K = Ka; x = a1; wt = a2_t; y = a;
    } else if (group == 2) {
        K = Kg; x = g1; wt = g2_t; y = g;
    } else if (group == 3) {
        K = Kv; x = v1; wt = v2_t; y = v_out;
    }
    if (m >= M) { return; }
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    const dtype* x_row = x + static_cast(m) * K;
    for (int k = threadIdx.x; k < K; k += Threads) {
        float xv = __half2float(*reinterpret_cast(x_row + k));
        // Group-specific activation on the input element (groups 1 and 3 are left linear).
        if (group == 0) {
            xv = tanhf(xv);
        } else if (group == 2) {
            xv = 1.0f / (1.0f + expf(-xv));
        }
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < C) {
                acc[j] =
fmaf(xv, __half2float(*reinterpret_cast(wt + static_cast(n) * K + k)), acc[j]);
            }
        }
    }
    // Warp shuffle reduction, then thread 0 sums the per-warp partials.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
        acc[j] = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = acc[j]; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
            for (int u = 0; u < Threads / 32; ++u) { sum += partial[u][j]; }
            const int n = n0 + j;
            if (n < C) {
                if (group == 3) {
                    // gate = sigmoid(v0[n] + sum); out = v + (v_first - v) * gate.
                    const int64_t idx = static_cast(m) * C + n;
                    const float vv = __half2float(*reinterpret_cast(v + idx));
                    const float vf = __half2float(*reinterpret_cast(v_first + idx));
                    const float gate = 1.0f / (1.0f + expf(-(__half2float(*reinterpret_cast(v0 + n)) + sum)));
                    *reinterpret_cast<__half*>(y + idx) = __float2half_rn(vv + (vf - vv) * gate);
                } else {
                    *reinterpret_cast<__half*>(y + static_cast(m) * C + n) = __float2half_rn(sum);
                }
            }
        }
    }
}

// Product followed by a gated blend, walking K one element at a time:
//   s    = dot(x[m, :], weight_t[n, :])
//   gate = sigmoid(v0[n] + s)
//   y[m][n] = v[m][n] + (v_first[m][n] - v[m][n]) * gate
// blockIdx.x selects an OutTile-wide column tile, blockIdx.y the row.
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_t_vres_f16_ntile_scalar_kernel(
    int M, int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_t,
    const dtype* __restrict__ v, const dtype* __restrict__ v_first, const dtype* __restrict__ v0,
    dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    const int m = blockIdx.y;
    if (m >= M) { return; }
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    const dtype* x_row = x + static_cast(m) * K;
    for (int k = threadIdx.x; k < K; k += Threads) {
        const float xv = __half2float(*reinterpret_cast(x_row + k));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                acc[j] = fmaf(xv, __half2float(*reinterpret_cast(weight_t + static_cast(n) * K + k)), acc[j]);
            }
        }
    }
    // Warp shuffle reduction; per-warp partials go to shared memory.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for (int j = 0; j < OutTile; ++j) {
acc[j] = warp_sum(acc[j]);
        if (lane == 0) { partial[warp][j] = acc[j]; }
    }
    __syncthreads();
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            float sum = 0.0f;
#pragma unroll
            for (int w = 0; w < Threads / 32; ++w) { sum += partial[w][j]; }
            const int n = n0 + j;
            if (n < N) {
                // gate = sigmoid(v0[n] + sum); out = v + (v_first - v) * gate.
                const int64_t idx = static_cast(m) * N + n;
                const float vv = __half2float(*reinterpret_cast(v + idx));
                const float vf = __half2float(*reinterpret_cast(v_first + idx));
                const float gate = 1.0f / (1.0f + expf(-(__half2float(*reinterpret_cast(v0 + n)) + sum)));
                *reinterpret_cast<__half*>(y + idx) = __float2half_rn(vv + (vf - vv) * gate);
            }
        }
    }
}

// Paired-load variant of linear_t_vres_f16_ntile_scalar_kernel: same product and gated
// blend, but K is consumed two elements at a time with a scalar tail for odd K.
// `Threads` and `OutTile` are template parameters not visible in this view.
template __global__ __launch_bounds__(Threads, 2) void linear_t_vres_f16_ntile_kernel(
    int M, int K, int N, const dtype* __restrict__ x, const dtype* __restrict__ weight_t,
    const dtype* __restrict__ v, const dtype* __restrict__ v_first, const dtype* __restrict__ v0,
    dtype* __restrict__ y) {
    const int n0 = blockIdx.x * OutTile;
    const int m = blockIdx.y;
    if (m >= M) { return; }
    float acc[OutTile];
#pragma unroll
    for (int j = 0; j < OutTile; ++j) { acc[j] = 0.0f; }
    const dtype* x_row = x + static_cast(m) * K;
    const int K2 = K >> 1;
    for (int k2 = threadIdx.x; k2 < K2; k2 += Threads) {
        const int k = k2 << 1;
        const float2 xv = __half22float2(*reinterpret_cast(x_row + k));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                const float2 wv = __half22float2(*reinterpret_cast(weight_t + static_cast(n) * K + k));
                acc[j] = fmaf(xv.x, wv.x, acc[j]);
                acc[j] = fmaf(xv.y, wv.y, acc[j]);
            }
        }
    }
    // Odd K: thread 0 adds the final unpaired element.
    if ((K & 1) && threadIdx.x == 0) {
        const float xv = __half2float(*reinterpret_cast(x_row + K - 1));
#pragma unroll
        for (int j = 0; j < OutTile; ++j) {
            const int n = n0 + j;
            if (n < N) {
                acc[j] = fmaf(xv, __half2float(*reinterpret_cast(weight_t + static_cast(n) * K + K - 1)), acc[j]);
            }
        }
    }
    // Warp shuffle reduction; per-warp partials go to shared memory.
    __shared__ float partial[Threads / 32][OutTile];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
#pragma unroll
    for
(int j = 0; j < OutTile; ++j) { acc[j] = warp_sum(acc[j]); if (lane == 0) { partial[warp][j] = acc[j]; } } __syncthreads(); /* Thread 0 adds the per-warp partials of each column and applies the gated blend for columns that are in range. */ if (threadIdx.x == 0) { #pragma unroll for (int j = 0; j < OutTile; ++j) { float sum = 0.0f; #pragma unroll for (int w = 0; w < Threads / 32; ++w) { sum += partial[w][j]; } const int n = n0 + j; if (n < N) { const int64_t idx = static_cast(m) * N + n; const float vv = __half2float(*reinterpret_cast(v + idx)); const float vf = __half2float(*reinterpret_cast(v_first + idx)); const float gate = 1.0f / (1.0f + expf(-(__half2float(*reinterpret_cast(v0 + n)) + sum))); *reinterpret_cast<__half*>(y + idx) = __float2half_rn(vv + (vf - vv) * gate); } } } } /* layer_norm_f16_kernel: layer normalization over the C elements of one row per block (row = blockIdx.x; blocks with row >= rows exit). Three passes, each strided by blockDim.x and computed in fp32: (1) sum of the row to get the mean, (2) sum of squared deviations from that mean, divided by C to get the variance, (3) y = (x - mean) * rsqrt(var + eps) * weight[c] + bias[c], rounded to half. weight and bias are indexed by channel only. NOTE(review): the template argument of block_sum_t is not visible in this chunk; presumably it has to equal the launch block size (layer_norm_f16_cuda launches with LN_THREADS) - confirm. NOTE(review): block_sum_t is called twice and returns its result from a shared scratch slot; there is no barrier between a thread reading the first result and another warp storing partials for the second call - confirm this cannot race. */ __global__ void layer_norm_f16_kernel( int C, const dtype* __restrict__ x, const dtype* __restrict__ weight, const dtype* __restrict__ bias, dtype* __restrict__ y, int64_t rows, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * C; float sum = 0.0f; for (int c = threadIdx.x; c < C; c += blockDim.x) { const float v = __half2float(*reinterpret_cast(x + base + c)); sum += v; } sum = block_sum_t(sum); const float inv_c = 1.0f / static_cast(C); const float mean = sum * inv_c; float sum_var = 0.0f; for (int c = threadIdx.x; c < C; c += blockDim.x) { const float v = __half2float(*reinterpret_cast(x + base + c)); const float d = v - mean; sum_var += d * d; } sum_var = block_sum_t(sum_var); const float var = sum_var * inv_c; const float rstd = rsqrtf(var + eps); for (int c = threadIdx.x; c < C; c += blockDim.x) { const float v = __half2float(*reinterpret_cast(x + base + c)); const float w = __half2float(*reinterpret_cast(weight + c)); const float b = __half2float(*reinterpret_cast(bias + c)); *reinterpret_cast<__half*>(y + base + c) = __float2half_rn((v - mean) * rstd * w + b); } } /* add_layer_norm_f16_kernel: residual add fused with layer norm, one block per row of C elements. The normalized quantity is v = x + residual, formed in fp32 and recomputed in each of the three passes (sum, squared deviations, output). Outputs: x_out receives v rounded to half, y receives (v - mean) * rstd * weight[c] + bias[c] rounded to half; the statistics use the unrounded fp32 v. Blocks with row >= rows exit immediately. */ __global__ void add_layer_norm_f16_kernel( int C, const dtype* __restrict__ x, const dtype* __restrict__ residual, const dtype* __restrict__ weight, const dtype* __restrict__ bias, 
dtype* __restrict__ x_out, dtype* __restrict__ y, int64_t rows, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * C; float sum = 0.0f; /* Pass 1: per-thread partial sum of x + residual. */ for (int c = threadIdx.x; c < C; c += blockDim.x) { const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); sum += v; } sum = block_sum_t(sum); const float inv_c = 1.0f / static_cast(C); const float mean = sum * inv_c; float sum_var = 0.0f; /* Pass 2: squared deviations from the mean; the sum x + residual is recomputed rather than cached. */ for (int c = threadIdx.x; c < C; c += blockDim.x) { const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); const float d = v - mean; sum_var += d * d; } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var * inv_c + eps); /* Pass 3: store the half-rounded sum to x_out and the normalized, affine-transformed value to y. */ for (int c = threadIdx.x; c < C; c += blockDim.x) { const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); const float w = __half2float(*reinterpret_cast(weight + c)); const float b = __half2float(*reinterpret_cast(bias + c)); *reinterpret_cast<__half*>(x_out + base + c) = __float2half_rn(v); *reinterpret_cast<__half*>(y + base + c) = __float2half_rn((v - mean) * rstd * w + b); } } /* layer_norm_f16_small_kernel: layer norm specialised for rows of exactly LN_SMALL_C elements, one block per row (blocks with row >= rows exit). VecStats selects two-elements-at-a-time reads for the mean and variance passes and VecOut does the same for the output pass; otherwise elements are accessed one at a time. Loop trip counts are the compile-time quotients (LN_SMALL_C / 2) / Threads or LN_SMALL_C / Threads, so a row is fully covered only when Threads divides those values evenly. All arithmetic is fp32 and the result is rounded to half. NOTE(review): the template parameter list is not visible in this chunk; Threads, VecStats and VecOut are inferred from their uses in the body. */ template __global__ __launch_bounds__(Threads, 1) void layer_norm_f16_small_kernel( const dtype* __restrict__ x, const dtype* __restrict__ weight, const dtype* __restrict__ bias, dtype* __restrict__ y, int64_t rows, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * LN_SMALL_C; float sum = 0.0f; if constexpr (VecStats) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 v = __half22float2(reinterpret_cast(x + base)[idx]); sum += v.x + v.y; } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + base + c)); sum += v; } } 
sum = block_sum_t(sum); const float mean = sum * (1.0f / static_cast(LN_SMALL_C)); float sum_var = 0.0f; /* Variance pass: same access pattern as the mean pass, accumulating squared deviations. */ if constexpr (VecStats) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 v = __half22float2(reinterpret_cast(x + base)[idx]); const float dx = v.x - mean; const float dy = v.y - mean; sum_var += dx * dx + dy * dy; } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + base + c)); const float d = v - mean; sum_var += d * d; } } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var * (1.0f / static_cast(LN_SMALL_C)) + eps); /* Output pass: normalize, scale by weight, add bias, round to half. */ if constexpr (VecOut) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 v = __half22float2(reinterpret_cast(x + base)[idx]); const float2 w = __half22float2(reinterpret_cast(weight)[idx]); const float2 b = __half22float2(reinterpret_cast(bias)[idx]); reinterpret_cast<__half2*>(y + base)[idx] = __floats2half2_rn( (v.x - mean) * rstd * w.x + b.x, (v.y - mean) * rstd * w.y + b.y); } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + base + c)); const float w = __half2float(*reinterpret_cast(weight + c)); const float b = __half2float(*reinterpret_cast(bias + c)); *reinterpret_cast<__half*>(y + base + c) = __float2half_rn((v - mean) * rstd * w + b); } } } /* add_layer_norm_f16_small_kernel: residual add fused with layer norm for rows of exactly LN_SMALL_C elements, one block per row (blocks with row >= rows exit). The normalized quantity is x + residual, formed in fp32 and recomputed in each of the three passes (sum, squared deviations, output). VecStats and VecOut choose two-element versus single-element accesses for the statistics passes and the output pass respectively. Outputs: x_out receives the half-rounded sum, y the normalized value scaled by weight and shifted by bias. NOTE(review): the template parameter list is not visible in this chunk; the parameter names are taken from their uses in the body. */ template __global__ __launch_bounds__(Threads, 1) void add_layer_norm_f16_small_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, const dtype* __restrict__ weight, const dtype* __restrict__ bias, dtype* __restrict__ x_out, dtype* __restrict__ y, int64_t rows, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * LN_SMALL_C; float sum = 0.0f; 
/* Pass 1: per-thread partial sum of x + residual over the row. */ if constexpr (VecStats) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x + base)[idx]); const float2 rv = __half22float2(reinterpret_cast(residual + base)[idx]); sum += xv.x + rv.x + xv.y + rv.y; } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); sum += v; } } sum = block_sum_t(sum); const float mean = sum * (1.0f / static_cast(LN_SMALL_C)); float sum_var = 0.0f; /* Pass 2: squared deviations of x + residual from the mean. */ if constexpr (VecStats) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x + base)[idx]); const float2 rv = __half22float2(reinterpret_cast(residual + base)[idx]); const float dx = xv.x + rv.x - mean; const float dy = xv.y + rv.y - mean; sum_var += dx * dx + dy * dy; } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); const float d = v - mean; sum_var += d * d; } } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var * (1.0f / static_cast(LN_SMALL_C)) + eps); /* Pass 3: write the sum to x_out and its normalized form to y. */ if constexpr (VecOut) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x + base)[idx]); const float2 rv = __half22float2(reinterpret_cast(residual + base)[idx]); const float sx = xv.x + rv.x; const float sy = xv.y + rv.y; const float2 w = __half22float2(reinterpret_cast(weight)[idx]); const float2 b = __half22float2(reinterpret_cast(bias)[idx]); reinterpret_cast<__half2*>(x_out + base)[idx] = __floats2half2_rn(sx, sy); 
reinterpret_cast<__half2*>(y + base)[idx] = __floats2half2_rn( (sx - mean) * rstd * w.x + b.x, (sy - mean) * rstd * w.y + b.y); } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); const float w = __half2float(*reinterpret_cast(weight + c)); const float b = __half2float(*reinterpret_cast(bias + c)); *reinterpret_cast<__half*>(x_out + base + c) = __float2half_rn(v); *reinterpret_cast<__half*>(y + base + c) = __float2half_rn((v - mean) * rstd * w + b); } } } /* add_layer_norm_cmix_mix_f16_kernel: one block per row of LN_SMALL_C elements (blocks with row >= rows exit). Computes s = x + residual, layer-normalizes s with fp32 statistics and the weight/bias affine, and rounds the result y to half. It then (a) stores s to x_out, (b) stores mixed = y + (prev - y) * x_k, where prev is the value currently held in shift_state at the same position, and (c) overwrites shift_state with y. All accesses are two-element pairs: per-row buffers are indexed by base2 + p (base2 is half the row's element offset) and the per-channel vectors weight, bias and x_k by p alone. NOTE(review): the template parameter list is not visible in this chunk; Threads is inferred from its uses. */ template __global__ __launch_bounds__(Threads, 1) void add_layer_norm_cmix_mix_f16_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, dtype* __restrict__ shift_state, const dtype* __restrict__ weight, const dtype* __restrict__ bias, const dtype* __restrict__ x_k, dtype* __restrict__ x_out, dtype* __restrict__ mixed, int64_t rows, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * LN_SMALL_C; float sum = 0.0f; const int64_t base2 = base >> 1; constexpr int pairs = LN_SMALL_C >> 1; #pragma unroll for (int k = 0; k < pairs / Threads; ++k) { const int p = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); sum += xv.x + rv.x + xv.y + rv.y; } sum = block_sum_t(sum); const float mean = sum * (1.0f / static_cast(LN_SMALL_C)); float sum_var = 0.0f; #pragma unroll for (int k = 0; k < pairs / Threads; ++k) { const int p = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); const float x0 = xv.x + rv.x; const float x1 = xv.y + rv.y; const float d0 = x0 - mean; const float d1 = x1 - mean; sum_var += d0 * d0 + d1 * d1; } sum_var = block_sum_t(sum_var); 
const float rstd = rsqrtf(sum_var * (1.0f / static_cast(LN_SMALL_C)) + eps); #pragma unroll for (int k = 0; k < pairs / Threads; ++k) { const int p = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); const float2 w = __half22float2(reinterpret_cast(weight)[p]); const float2 b = __half22float2(reinterpret_cast(bias)[p]); const float2 prev = __half22float2(reinterpret_cast(shift_state)[base2 + p]); const float2 mix = __half22float2(reinterpret_cast(x_k)[p]); const float x0 = xv.x + rv.x; const float x1 = xv.y + rv.y; /* y2 is the normalized value already rounded to half; yv converts it back so the blend uses exactly the value that is stored into shift_state below. */ const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y); const float2 yv = __half22float2(y2); reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1); reinterpret_cast<__half2*>(mixed)[base2 + p] = __floats2half2_rn(yv.x + (prev.x - yv.x) * mix.x, yv.y + (prev.y - yv.y) * mix.y); reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2; } } /* add_layer_norm_cmix_mix_f16_scalar_stats_kernel: one block per row of LN_SMALL_C elements. Produces x_out = x + residual, mixed = y + (prev - y) * x_k and shift_state = y, where y is the half-rounded layer norm of x + residual and prev is the previous content of shift_state. The mean and variance passes read one element at a time (element c = threadIdx.x + k * Threads); only the final pass uses two-element pairs. NOTE(review): reading single elements changes which values each thread accumulates compared with a paired read, so the fp32 rounding of mean and variance may differ - presumably that is why this variant exists; confirm against the dispatch code. */ template __global__ __launch_bounds__(Threads, 1) void add_layer_norm_cmix_mix_f16_scalar_stats_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, dtype* __restrict__ shift_state, const dtype* __restrict__ weight, const dtype* __restrict__ bias, const dtype* __restrict__ x_k, dtype* __restrict__ x_out, dtype* __restrict__ mixed, int64_t rows, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * LN_SMALL_C; const int64_t base2 = base >> 1; constexpr int pairs = LN_SMALL_C >> 1; float sum = 0.0f; #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; sum += __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); } sum = block_sum_t(sum); const float mean = sum * (1.0f / static_cast(LN_SMALL_C)); float sum_var = 0.0f; #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { 
const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); const float d = v - mean; sum_var += d * d; } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var * (1.0f / static_cast(LN_SMALL_C)) + eps); /* Output pass on two-element pairs: x_out gets the sum, mixed gets the blend with the old shift_state, then shift_state is replaced. */ #pragma unroll for (int k = 0; k < pairs / Threads; ++k) { const int p = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); const float2 w = __half22float2(reinterpret_cast(weight)[p]); const float2 b = __half22float2(reinterpret_cast(bias)[p]); const float2 prev = __half22float2(reinterpret_cast(shift_state)[base2 + p]); const float2 mix = __half22float2(reinterpret_cast(x_k)[p]); const float x0 = xv.x + rv.x; const float x1 = xv.y + rv.y; const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y); const float2 yv = __half22float2(y2); reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1); reinterpret_cast<__half2*>(mixed)[base2 + p] = __floats2half2_rn(yv.x + (prev.x - yv.x) * mix.x, yv.y + (prev.y - yv.y) * mix.y); reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2; } } /* add_layer_norm_tmix_mix6_f16_kernel: one block per row of LN_SMALL_C elements (blocks with row >= rows exit). Computes s = x + residual, layer-normalizes s with fp32 statistics and the weight/bias affine, and rounds the result y to half. Writes x_out = s and six blends of the form y + (prev - y) * coeff, one per coefficient vector x_r, x_w, x_k, x_v, x_a, x_g, into out_r, out_w, out_k, out_v, out_a, out_g; prev is read from shift_state, which is overwritten with y at the end of each iteration. Everything is accessed as two-element pairs: per-row buffers at base2 + p, per-channel vectors at p. NOTE(review): the template parameter list is not visible in this chunk; Threads is inferred from its uses. */ template __global__ __launch_bounds__(Threads, 1) void add_layer_norm_tmix_mix6_f16_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, dtype* __restrict__ shift_state, const dtype* __restrict__ weight, const dtype* __restrict__ bias, const dtype* __restrict__ x_r, const dtype* __restrict__ x_w, const dtype* __restrict__ x_k, const dtype* __restrict__ x_v, const dtype* __restrict__ x_a, const dtype* __restrict__ x_g, dtype* __restrict__ x_out, dtype* __restrict__ out_r, dtype* __restrict__ out_w, dtype* __restrict__ out_k, dtype* __restrict__ out_v, dtype* __restrict__ out_a, dtype* __restrict__ out_g, int64_t rows, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } 
/* base2 is the row offset counted in two-element pairs; pairs is the number of such pairs per row. */ const int64_t base2 = row * (LN_SMALL_C >> 1); constexpr int pairs = LN_SMALL_C >> 1; float sum = 0.0f; /* Pass 1: partial sum of x + residual. */ #pragma unroll for (int k = 0; k < pairs / Threads; ++k) { const int p = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); sum += xv.x + rv.x + xv.y + rv.y; } sum = block_sum_t(sum); const float mean = sum * (1.0f / static_cast(LN_SMALL_C)); float sum_var = 0.0f; /* Pass 2: squared deviations from the mean. */ #pragma unroll for (int k = 0; k < pairs / Threads; ++k) { const int p = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); const float x0 = xv.x + rv.x; const float x1 = xv.y + rv.y; const float d0 = x0 - mean; const float d1 = x1 - mean; sum_var += d0 * d0 + d1 * d1; } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var * (1.0f / static_cast(LN_SMALL_C)) + eps); /* Pass 3: normalize, then form the six blends from the shared difference prev - y. */ #pragma unroll for (int k = 0; k < pairs / Threads; ++k) { const int p = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); const float2 w = __half22float2(reinterpret_cast(weight)[p]); const float2 b = __half22float2(reinterpret_cast(bias)[p]); const float2 prev = __half22float2(reinterpret_cast(shift_state)[base2 + p]); const float x0 = xv.x + rv.x; const float x1 = xv.y + rv.y; /* y2 is rounded to half first; yv converts it back so the blends use exactly the value later stored in shift_state. */ const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y); const float2 yv = __half22float2(y2); const float dx0 = prev.x - yv.x; const float dx1 = prev.y - yv.y; const float2 mr = __half22float2(reinterpret_cast(x_r)[p]); const float2 mw = __half22float2(reinterpret_cast(x_w)[p]); const float2 mk = __half22float2(reinterpret_cast(x_k)[p]); const float2 mv = __half22float2(reinterpret_cast(x_v)[p]); const float2 ma = __half22float2(reinterpret_cast(x_a)[p]); const float2 
mg = __half22float2(reinterpret_cast(x_g)[p]); reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1); reinterpret_cast<__half2*>(out_r)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mr.x, yv.y + dx1 * mr.y); reinterpret_cast<__half2*>(out_w)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mw.x, yv.y + dx1 * mw.y); reinterpret_cast<__half2*>(out_k)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mk.x, yv.y + dx1 * mk.y); reinterpret_cast<__half2*>(out_v)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mv.x, yv.y + dx1 * mv.y); reinterpret_cast<__half2*>(out_a)[base2 + p] = __floats2half2_rn(yv.x + dx0 * ma.x, yv.y + dx1 * ma.y); reinterpret_cast<__half2*>(out_g)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mg.x, yv.y + dx1 * mg.y); /* The state is replaced only after prev has been read for this pair by the same thread. */ reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2; } } /* add_layer_norm_tmix_mix6_f16_scalar_stats_kernel: one block per row of LN_SMALL_C elements. Produces x_out = x + residual, the six blends y + (prev - y) * coeff for coeff in x_r, x_w, x_k, x_v, x_a, x_g (written to out_r ... out_g) and shift_state = y, where y is the half-rounded layer norm of x + residual and prev is the previous content of shift_state. The mean and variance passes read one element at a time (element c = threadIdx.x + k * Threads); only the final pass uses two-element pairs. NOTE(review): the template parameter list is not visible in this chunk; Threads is inferred from its uses. */ template __global__ __launch_bounds__(Threads, 1) void add_layer_norm_tmix_mix6_f16_scalar_stats_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, dtype* __restrict__ shift_state, const dtype* __restrict__ weight, const dtype* __restrict__ bias, const dtype* __restrict__ x_r, const dtype* __restrict__ x_w, const dtype* __restrict__ x_k, const dtype* __restrict__ x_v, const dtype* __restrict__ x_a, const dtype* __restrict__ x_g, dtype* __restrict__ x_out, dtype* __restrict__ out_r, dtype* __restrict__ out_w, dtype* __restrict__ out_k, dtype* __restrict__ out_v, dtype* __restrict__ out_a, dtype* __restrict__ out_g, int64_t rows, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * LN_SMALL_C; const int64_t base2 = row * (LN_SMALL_C >> 1); constexpr int pairs = LN_SMALL_C >> 1; float sum = 0.0f; #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; sum += __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); } sum = block_sum_t(sum); const float mean = sum * (1.0f / static_cast(LN_SMALL_C)); float sum_var = 
0.0f; /* Variance pass, one element per iteration: squared deviations of x + residual from the mean. */ #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); const float d = v - mean; sum_var += d * d; } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var * (1.0f / static_cast(LN_SMALL_C)) + eps); /* Output pass on two-element pairs: normalize, then form the six blends from the shared difference prev - y. */ #pragma unroll for (int k = 0; k < pairs / Threads; ++k) { const int p = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); const float2 w = __half22float2(reinterpret_cast(weight)[p]); const float2 b = __half22float2(reinterpret_cast(bias)[p]); const float2 prev = __half22float2(reinterpret_cast(shift_state)[base2 + p]); const float x0 = xv.x + rv.x; const float x1 = xv.y + rv.y; /* y2 is rounded to half first; yv converts it back so the blends use exactly the value later stored in shift_state. */ const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y); const float2 yv = __half22float2(y2); const float dx0 = prev.x - yv.x; const float dx1 = prev.y - yv.y; const float2 mr = __half22float2(reinterpret_cast(x_r)[p]); const float2 mw = __half22float2(reinterpret_cast(x_w)[p]); const float2 mk = __half22float2(reinterpret_cast(x_k)[p]); const float2 mv = __half22float2(reinterpret_cast(x_v)[p]); const float2 ma = __half22float2(reinterpret_cast(x_a)[p]); const float2 mg = __half22float2(reinterpret_cast(x_g)[p]); reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1); reinterpret_cast<__half2*>(out_r)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mr.x, yv.y + dx1 * mr.y); reinterpret_cast<__half2*>(out_w)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mw.x, yv.y + dx1 * mw.y); reinterpret_cast<__half2*>(out_k)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mk.x, yv.y + dx1 * mk.y); reinterpret_cast<__half2*>(out_v)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mv.x, yv.y + dx1 * mv.y); reinterpret_cast<__half2*>(out_a)[base2 + p] = 
__floats2half2_rn(yv.x + dx0 * ma.x, yv.y + dx1 * ma.y); reinterpret_cast<__half2*>(out_g)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mg.x, yv.y + dx1 * mg.y); reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2; } } /* add_last_layer_norm_f16_small_kernel: one block per index bidx in [0, B) (blocks with bidx >= B exit). Reads LN_SMALL_C elements of x and residual starting at offset (bidx * T + (T - 1)) * LN_SMALL_C, i.e. the last of the T consecutive rows that belong to bidx, layer-normalizes x + residual with fp32 statistics, and writes the result to y at offset bidx * LN_SMALL_C. The summed input itself is not written anywhere. VecStats and VecOut select two-element versus single-element accesses for the statistics passes and the output pass. NOTE(review): the template parameter list is not visible in this chunk; Threads, VecStats and VecOut are inferred from their uses. */ template __global__ __launch_bounds__(Threads, 1) void add_last_layer_norm_f16_small_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, const dtype* __restrict__ weight, const dtype* __restrict__ bias, dtype* __restrict__ y, int64_t B, int64_t T, float eps) { const int64_t bidx = blockIdx.x; if (bidx >= B) { return; } const int64_t src = (bidx * T + (T - 1)) * LN_SMALL_C; const int64_t dst = bidx * LN_SMALL_C; float sum = 0.0f; if constexpr (VecStats) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x + src)[idx]); const float2 rv = __half22float2(reinterpret_cast(residual + src)[idx]); sum += xv.x + rv.x + xv.y + rv.y; } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + src + c)) + __half2float(*reinterpret_cast(residual + src + c)); sum += v; } } sum = block_sum_t(sum); const float mean = sum * (1.0f / static_cast(LN_SMALL_C)); float sum_var = 0.0f; if constexpr (VecStats) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x + src)[idx]); const float2 rv = __half22float2(reinterpret_cast(residual + src)[idx]); const float dx = xv.x + rv.x - mean; const float dy = xv.y + rv.y - mean; sum_var += dx * dx + dy * dy; } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + src + c)) + __half2float(*reinterpret_cast(residual + src + c)); const float d = v - mean; sum_var += d * d; } 
} sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var * (1.0f / static_cast(LN_SMALL_C)) + eps); /* Output pass: reads come from the source row (src), writes go to the compacted destination row (dst). */ if constexpr (VecOut) { #pragma unroll for (int k = 0; k < (LN_SMALL_C / 2) / Threads; ++k) { const int idx = threadIdx.x + k * Threads; const float2 xv = __half22float2(reinterpret_cast(x + src)[idx]); const float2 rv = __half22float2(reinterpret_cast(residual + src)[idx]); const float sx = xv.x + rv.x; const float sy = xv.y + rv.y; const float2 w = __half22float2(reinterpret_cast(weight)[idx]); const float2 bb = __half22float2(reinterpret_cast(bias)[idx]); reinterpret_cast<__half2*>(y + dst)[idx] = __floats2half2_rn( (sx - mean) * rstd * w.x + bb.x, (sy - mean) * rstd * w.y + bb.y); } } else { #pragma unroll for (int k = 0; k < LN_SMALL_C / Threads; ++k) { const int c = threadIdx.x + k * Threads; const float v = __half2float(*reinterpret_cast(x + src + c)) + __half2float(*reinterpret_cast(residual + src + c)); const float w = __half2float(*reinterpret_cast(weight + c)); const float bb = __half2float(*reinterpret_cast(bias + c)); *reinterpret_cast<__half*>(y + dst + c) = __float2half_rn((v - mean) * rstd * w + bb); } } } /* add_last_layer_norm_f16_generic_kernel: runtime-C counterpart of the fixed-width last-row kernel. One block per bidx in [0, B); the source row starts at (bidx * T + (T - 1)) * C and the destination row at bidx * C. The two statistics passes read single elements with stride Threads and accumulate x + residual in fp32; the output pass iterates over C >> 1 two-element pairs. NOTE(review): for odd C the pair loop never writes the final channel, and src / dst can be odd so the paired accesses would not start on a pair boundary - confirm callers only pass even C. NOTE(review): the template parameter list is not visible in this chunk; Threads is inferred from its uses. */ template __global__ __launch_bounds__(Threads, 1) void add_last_layer_norm_f16_generic_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, const dtype* __restrict__ weight, const dtype* __restrict__ bias, dtype* __restrict__ y, int64_t B, int64_t T, int C, float eps) { const int64_t bidx = blockIdx.x; if (bidx >= B) { return; } const int64_t src = (bidx * T + (T - 1)) * static_cast(C); const int64_t dst = bidx * static_cast(C); float sum = 0.0f; for (int c = threadIdx.x; c < C; c += Threads) { sum += __half2float(*reinterpret_cast(x + src + c)) + __half2float(*reinterpret_cast(residual + src + c)); } sum = block_sum_t(sum); const float mean = sum / static_cast(C); float sum_var = 0.0f; for (int c = threadIdx.x; c < C; c += Threads) { const float v = __half2float(*reinterpret_cast(x + src + c)) + 
__half2float(*reinterpret_cast(residual + src + c)); const float d = v - mean; sum_var += d * d; } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var / static_cast(C) + eps); /* Output pass on two-element pairs; only the first 2 * (C >> 1) channels are produced. */ const int pairs = C >> 1; for (int p = threadIdx.x; p < pairs; p += Threads) { const float2 xv = __half22float2(reinterpret_cast(x + src)[p]); const float2 rv = __half22float2(reinterpret_cast(residual + src)[p]); const float sx = xv.x + rv.x; const float sy = xv.y + rv.y; const float2 w = __half22float2(reinterpret_cast(weight)[p]); const float2 bb = __half22float2(reinterpret_cast(bias)[p]); reinterpret_cast<__half2*>(y + dst)[p] = __floats2half2_rn( (sx - mean) * rstd * w.x + bb.x, (sy - mean) * rstd * w.y + bb.y); } } /* add_layer_norm_cmix_mix_f16_generic_kernel: runtime-C version of the fused residual add + layer norm + single blend. One block per row (blocks with row >= rows exit). The statistics passes read single elements with stride Threads and accumulate x + residual in fp32. The output pass works on C >> 1 two-element pairs and writes x_out = x + residual, mixed = y + (prev - y) * x_k and shift_state = y, where y is the half-rounded normalized value and prev the previous shift_state content. NOTE(review): base2 = base >> 1 and pairs = C >> 1 both truncate, so with odd C the last channel is skipped and the pair index of odd rows no longer matches the element offset - confirm callers only pass even C. NOTE(review): the template parameter list is not visible in this chunk; Threads is inferred from its uses. */ template __global__ __launch_bounds__(Threads, 1) void add_layer_norm_cmix_mix_f16_generic_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, dtype* __restrict__ shift_state, const dtype* __restrict__ weight, const dtype* __restrict__ bias, const dtype* __restrict__ x_k, dtype* __restrict__ x_out, dtype* __restrict__ mixed, int64_t rows, int C, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * static_cast(C); float sum = 0.0f; for (int c = threadIdx.x; c < C; c += Threads) { sum += __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); } sum = block_sum_t(sum); const float mean = sum / static_cast(C); float sum_var = 0.0f; for (int c = threadIdx.x; c < C; c += Threads) { const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); const float d = v - mean; sum_var += d * d; } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var / static_cast(C) + eps); const int pairs = C >> 1; const int64_t base2 = base >> 1; for (int p = threadIdx.x; p < pairs; p += Threads) { const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = 
__half22float2(reinterpret_cast(residual)[base2 + p]); const float2 w = __half22float2(reinterpret_cast(weight)[p]); const float2 b = __half22float2(reinterpret_cast(bias)[p]); const float2 prev = __half22float2(reinterpret_cast(shift_state)[base2 + p]); const float2 mix = __half22float2(reinterpret_cast(x_k)[p]); const float x0 = xv.x + rv.x; const float x1 = xv.y + rv.y; /* y2 is rounded to half first; yv converts it back so the blend uses exactly the value later stored in shift_state. */ const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y); const float2 yv = __half22float2(y2); reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1); reinterpret_cast<__half2*>(mixed)[base2 + p] = __floats2half2_rn(yv.x + (prev.x - yv.x) * mix.x, yv.y + (prev.y - yv.y) * mix.y); reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2; } } /* add_layer_norm_tmix_mix6_f16_generic_kernel: runtime-C version of the fused residual add + layer norm + six blends. One block per row (blocks with row >= rows exit). The statistics passes read single elements with stride Threads and accumulate x + residual in fp32. The output pass works on C >> 1 two-element pairs and writes x_out = x + residual, the blends y + (prev - y) * coeff for coeff in x_r, x_w, x_k, x_v, x_a, x_g (to out_r ... out_g) and finally shift_state = y, where y is the half-rounded normalized value and prev the previous shift_state content. NOTE(review): base2 = base >> 1 and pairs = C >> 1 both truncate, so with odd C the last channel is skipped and the pair index of odd rows no longer matches the element offset - confirm callers only pass even C. NOTE(review): the template parameter list is not visible in this chunk; Threads is inferred from its uses. */ template __global__ __launch_bounds__(Threads, 1) void add_layer_norm_tmix_mix6_f16_generic_kernel( const dtype* __restrict__ x, const dtype* __restrict__ residual, dtype* __restrict__ shift_state, const dtype* __restrict__ weight, const dtype* __restrict__ bias, const dtype* __restrict__ x_r, const dtype* __restrict__ x_w, const dtype* __restrict__ x_k, const dtype* __restrict__ x_v, const dtype* __restrict__ x_a, const dtype* __restrict__ x_g, dtype* __restrict__ x_out, dtype* __restrict__ out_r, dtype* __restrict__ out_w, dtype* __restrict__ out_k, dtype* __restrict__ out_v, dtype* __restrict__ out_a, dtype* __restrict__ out_g, int64_t rows, int C, float eps) { const int64_t row = blockIdx.x; if (row >= rows) { return; } const int64_t base = row * static_cast(C); float sum = 0.0f; for (int c = threadIdx.x; c < C; c += Threads) { sum += __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); } sum = block_sum_t(sum); const float mean = sum / static_cast(C); float sum_var = 0.0f; for (int c = threadIdx.x; c < C; c += Threads) { const float v = __half2float(*reinterpret_cast(x + base + c)) + __half2float(*reinterpret_cast(residual + base + c)); const 
float d = v - mean; sum_var += d * d; } sum_var = block_sum_t(sum_var); const float rstd = rsqrtf(sum_var / static_cast(C) + eps); const int pairs = C >> 1; const int64_t base2 = base >> 1; for (int p = threadIdx.x; p < pairs; p += Threads) { const float2 xv = __half22float2(reinterpret_cast(x)[base2 + p]); const float2 rv = __half22float2(reinterpret_cast(residual)[base2 + p]); const float2 w = __half22float2(reinterpret_cast(weight)[p]); const float2 b = __half22float2(reinterpret_cast(bias)[p]); const float2 prev = __half22float2(reinterpret_cast(shift_state)[base2 + p]); const float x0 = xv.x + rv.x; const float x1 = xv.y + rv.y; const __half2 y2 = __floats2half2_rn((x0 - mean) * rstd * w.x + b.x, (x1 - mean) * rstd * w.y + b.y); const float2 yv = __half22float2(y2); const float dx0 = prev.x - yv.x; const float dx1 = prev.y - yv.y; const float2 mr = __half22float2(reinterpret_cast(x_r)[p]); const float2 mw = __half22float2(reinterpret_cast(x_w)[p]); const float2 mk = __half22float2(reinterpret_cast(x_k)[p]); const float2 mv = __half22float2(reinterpret_cast(x_v)[p]); const float2 ma = __half22float2(reinterpret_cast(x_a)[p]); const float2 mg = __half22float2(reinterpret_cast(x_g)[p]); reinterpret_cast<__half2*>(x_out)[base2 + p] = __floats2half2_rn(x0, x1); reinterpret_cast<__half2*>(out_r)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mr.x, yv.y + dx1 * mr.y); reinterpret_cast<__half2*>(out_w)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mw.x, yv.y + dx1 * mw.y); reinterpret_cast<__half2*>(out_k)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mk.x, yv.y + dx1 * mk.y); reinterpret_cast<__half2*>(out_v)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mv.x, yv.y + dx1 * mv.y); reinterpret_cast<__half2*>(out_a)[base2 + p] = __floats2half2_rn(yv.x + dx0 * ma.x, yv.y + dx1 * ma.y); reinterpret_cast<__half2*>(out_g)[base2 + p] = __floats2half2_rn(yv.x + dx0 * mg.x, yv.y + dx1 * mg.y); reinterpret_cast<__half2*>(shift_state)[base2 + p] = y2; } } } // namespace at::Tensor 
add_f16_cuda(at::Tensor x, at::Tensor y) {
  // Element-wise fp16 sum of x and y, returned in a new tensor shaped like x.
  // The kernel processes half2 pairs, hence the even-numel requirement.
  TORCH_CHECK((x.numel() % 2) == 0, "add_f16 requires even numel");
  // The kernel reads x.numel() elements from y; a shorter y would be read
  // out of bounds.
  TORCH_CHECK(y.numel() >= x.numel(), "add_f16 requires y to have at least as many elements as x");
  auto out = at::empty_like(x);
  const int64_t n_pairs = x.numel() / 2;
  if (n_pairs == 0) {
    // A grid of 0 blocks is an invalid launch configuration.
    return out;
  }
  constexpr int threads = 256;
  auto stream = at::cuda::getCurrentCUDAStream();
  add_f16_kernel<<<static_cast<unsigned int>(ceil_div(n_pairs, threads)), threads, 0, stream>>>(
      x.data_ptr<dtype>(), y.data_ptr<dtype>(), out.data_ptr<dtype>(), n_pairs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}

// Adds `amount` in place to every element of the int32 tensor x on the
// current CUDA stream. `amount` must fit in an int.
void advance_i32_cuda(at::Tensor x, int64_t amount) {
  TORCH_CHECK(amount >= INT_MIN && amount <= INT_MAX, "advance_i32 amount out of int range");
  const int64_t n = x.numel();
  if (n == 0) {
    // Nothing to update; a grid of 0 blocks is an invalid launch configuration.
    return;
  }
  constexpr int threads = 256;
  auto stream = at::cuda::getCurrentCUDAStream();
  advance_i32_kernel<<<static_cast<unsigned int>(ceil_div(n, threads)), threads, 0, stream>>>(
      x.data_ptr<int>(), static_cast<int>(amount), n);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

at::Tensor layer_norm_f16_cuda(at::Tensor x, at::Tensor weight, at::Tensor bias, double eps) { auto y = at::empty_like(x); const int64_t c64 = x.size(-1); TORCH_CHECK(c64 <= INT_MAX, "C too large"); const int C = static_cast(c64); const int64_t rows = x.numel() / C; auto stream = at::cuda::getCurrentCUDAStream(); if (C == LN_SMALL_C) { if (rows >= 1024) { layer_norm_f16_small_kernel<<(rows), LN_SMALL512_THREADS, 0, stream>>>( x.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), rows, static_cast(eps)); } else if (rows >= 512) { layer_norm_f16_small_kernel<<(rows), LN_SMALL512_THREADS, 0, stream>>>( x.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), rows, static_cast(eps)); } else { layer_norm_f16_small_kernel<<(rows), LN_SMALL_THREADS, 0, stream>>>( x.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), rows, static_cast(eps)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } layer_norm_f16_kernel<<(rows), LN_THREADS, 0, stream>>>( C, x.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), rows, static_cast(eps)); C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; }

at::Tensor layer_norm_f16_small_cuda(at::Tensor x, at::Tensor weight,
at::Tensor bias, double eps) { auto y = at::empty_like(x); const int64_t rows = x.numel() / LN_SMALL_C; auto stream = at::cuda::getCurrentCUDAStream(); layer_norm_f16_small_kernel<<(rows), LN_SMALL_THREADS, 0, stream>>>( x.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), rows, static_cast(eps)); C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } at::Tensor layer_norm_f16_small512_cuda(at::Tensor x, at::Tensor weight, at::Tensor bias, double eps) { auto y = at::empty_like(x); const int64_t rows = x.numel() / LN_SMALL_C; auto stream = at::cuda::getCurrentCUDAStream(); layer_norm_f16_small_kernel<<(rows), LN_SMALL512_THREADS, 0, stream>>>( x.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), rows, static_cast(eps)); C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } at::Tensor emb_ln0_bf16_to_f16_cuda(at::Tensor emb, at::Tensor weight, at::Tensor bias, double eps) { auto out = at::empty(emb.sizes(), emb.options().dtype(at::kHalf)); const int64_t v64 = emb.size(0); const int64_t c64 = emb.size(1); TORCH_CHECK(v64 <= INT_MAX && c64 <= INT_MAX, "emb shape too large"); const int V = static_cast(v64); const int C = static_cast(c64); auto stream = at::cuda::getCurrentCUDAStream(); emb_ln0_bf16_to_f16_kernel<<>>( V, C, reinterpret_cast(emb.data_ptr()), reinterpret_cast(weight.data_ptr()), reinterpret_cast(bias.data_ptr()), out.data_ptr(), static_cast(eps)); C10_CUDA_KERNEL_LAUNCH_CHECK(); return out; } std::vector add_layer_norm_f16_cuda(at::Tensor x, at::Tensor residual, at::Tensor weight, at::Tensor bias, double eps) { auto x_out = at::empty_like(x); auto y = at::empty_like(x); const int64_t c64 = x.size(-1); TORCH_CHECK(c64 <= INT_MAX, "C too large"); const int C = static_cast(c64); const int64_t rows = x.numel() / C; auto stream = at::cuda::getCurrentCUDAStream(); if (C == LN_SMALL_C) { if (rows >= 1024) { add_layer_norm_f16_small_kernel<<(rows), LN_SMALL512_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), weight.data_ptr(), bias.data_ptr(), 
x_out.data_ptr(), y.data_ptr(), rows, static_cast(eps)); } else if (rows >= 512) { add_layer_norm_f16_small_kernel<<(rows), LN_SMALL512_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_out.data_ptr(), y.data_ptr(), rows, static_cast(eps)); } else { add_layer_norm_f16_small_kernel<<(rows), LN_SMALL_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_out.data_ptr(), y.data_ptr(), rows, static_cast(eps)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return {x_out, y}; } add_layer_norm_f16_kernel<<(rows), LN_THREADS, 0, stream>>>( C, x.data_ptr(), residual.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_out.data_ptr(), y.data_ptr(), rows, static_cast(eps)); C10_CUDA_KERNEL_LAUNCH_CHECK(); return {x_out, y}; } at::Tensor add_last_layer_norm_f16_cuda(at::Tensor x, at::Tensor residual, at::Tensor weight, at::Tensor bias, double eps) { const int64_t B = x.size(0); const int64_t T = x.size(1); const int64_t C = x.size(2); TORCH_CHECK((C % 2) == 0, "add_last_layer_norm_f16 requires even C"); auto y = at::empty({B, C}, x.options()); auto stream = at::cuda::getCurrentCUDAStream(); if (C != LN_SMALL_C) { add_last_layer_norm_f16_generic_kernel<<(B), LN_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), B, T, static_cast(C), static_cast(eps)); C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } if (B >= 1024) { add_last_layer_norm_f16_small_kernel<<(B), LN_SMALL512_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), B, T, static_cast(eps)); } else if (B >= 512) { add_last_layer_norm_f16_small_kernel<<(B), LN_SMALL512_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), weight.data_ptr(), bias.data_ptr(), y.data_ptr(), B, T, static_cast(eps)); } else { add_last_layer_norm_f16_small_kernel<<(B), LN_SMALL_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), weight.data_ptr(), 
bias.data_ptr(), y.data_ptr(), B, T, static_cast(eps)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } std::vector add_layer_norm_cmix_mix_f16_cuda( at::Tensor x, at::Tensor residual, at::Tensor shift_state, at::Tensor weight, at::Tensor bias, at::Tensor x_k, double eps) { auto x_out = at::empty_like(x); auto mixed = at::empty_like(x); const int64_t C = x.size(-1); TORCH_CHECK((C % 2) == 0, "add_layer_norm_cmix_mix_f16 requires even C"); const int64_t rows = x.numel() / C; auto stream = at::cuda::getCurrentCUDAStream(); if (C == LN_SMALL_C) { add_layer_norm_cmix_mix_f16_scalar_stats_kernel<<(rows), LN_SMALL_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_k.data_ptr(), x_out.data_ptr(), mixed.data_ptr(), rows, static_cast(eps)); } else { add_layer_norm_cmix_mix_f16_generic_kernel<<(rows), LN_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_k.data_ptr(), x_out.data_ptr(), mixed.data_ptr(), rows, static_cast(C), static_cast(eps)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return {x_out, mixed}; } std::vector add_layer_norm_cmix_mix_f16_scalar_stats_cuda( at::Tensor x, at::Tensor residual, at::Tensor shift_state, at::Tensor weight, at::Tensor bias, at::Tensor x_k, double eps) { auto x_out = at::empty_like(x); auto mixed = at::empty_like(x); const int64_t C = x.size(-1); TORCH_CHECK((C % 2) == 0, "add_layer_norm_cmix_mix_f16 requires even C"); const int64_t rows = x.numel() / C; auto stream = at::cuda::getCurrentCUDAStream(); if (C == LN_SMALL_C) { add_layer_norm_cmix_mix_f16_scalar_stats_kernel<<(rows), LN_SMALL_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_k.data_ptr(), x_out.data_ptr(), mixed.data_ptr(), rows, static_cast(eps)); } else { add_layer_norm_cmix_mix_f16_generic_kernel<<(rows), LN_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), 
shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_k.data_ptr(), x_out.data_ptr(), mixed.data_ptr(), rows, static_cast(C), static_cast(eps)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return {x_out, mixed}; } std::vector add_layer_norm_tmix_mix6_f16_cuda( at::Tensor x, at::Tensor residual, at::Tensor shift_state, at::Tensor weight, at::Tensor bias, at::Tensor x_r, at::Tensor x_w, at::Tensor x_k, at::Tensor x_v, at::Tensor x_a, at::Tensor x_g, double eps) { auto x_out = at::empty_like(x); auto out_r = at::empty_like(x); auto out_w = at::empty_like(x); auto out_k = at::empty_like(x); auto out_v = at::empty_like(x); auto out_a = at::empty_like(x); auto out_g = at::empty_like(x); const int64_t C = x.size(-1); TORCH_CHECK((C % 2) == 0, "add_layer_norm_tmix_mix6_f16 requires even C"); const int64_t rows = x.numel() / C; auto stream = at::cuda::getCurrentCUDAStream(); if (C == LN_SMALL_C) { add_layer_norm_tmix_mix6_f16_scalar_stats_kernel<<(rows), LN_SMALL_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_r.data_ptr(), x_w.data_ptr(), x_k.data_ptr(), x_v.data_ptr(), x_a.data_ptr(), x_g.data_ptr(), x_out.data_ptr(), out_r.data_ptr(), out_w.data_ptr(), out_k.data_ptr(), out_v.data_ptr(), out_a.data_ptr(), out_g.data_ptr(), rows, static_cast(eps)); } else { add_layer_norm_tmix_mix6_f16_generic_kernel<<(rows), LN_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_r.data_ptr(), x_w.data_ptr(), x_k.data_ptr(), x_v.data_ptr(), x_a.data_ptr(), x_g.data_ptr(), x_out.data_ptr(), out_r.data_ptr(), out_w.data_ptr(), out_k.data_ptr(), out_v.data_ptr(), out_a.data_ptr(), out_g.data_ptr(), rows, static_cast(C), static_cast(eps)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return {x_out, out_r, out_w, out_k, out_v, out_a, out_g}; } std::vector add_layer_norm_tmix_mix6_f16_cfg_cuda( at::Tensor x, at::Tensor residual, at::Tensor shift_state, at::Tensor 
weight, at::Tensor bias, at::Tensor x_r, at::Tensor x_w, at::Tensor x_k, at::Tensor x_v, at::Tensor x_a, at::Tensor x_g, double eps, int threads) { auto x_out = at::empty_like(x); auto out_r = at::empty_like(x); auto out_w = at::empty_like(x); auto out_k = at::empty_like(x); auto out_v = at::empty_like(x); auto out_a = at::empty_like(x); auto out_g = at::empty_like(x); const int64_t rows = x.numel() / LN_SMALL_C; auto stream = at::cuda::getCurrentCUDAStream(); if (threads == 256) { add_layer_norm_tmix_mix6_f16_kernel<256><<(rows), 256, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_r.data_ptr(), x_w.data_ptr(), x_k.data_ptr(), x_v.data_ptr(), x_a.data_ptr(), x_g.data_ptr(), x_out.data_ptr(), out_r.data_ptr(), out_w.data_ptr(), out_k.data_ptr(), out_v.data_ptr(), out_a.data_ptr(), out_g.data_ptr(), rows, static_cast(eps)); } else if (threads == 512) { add_layer_norm_tmix_mix6_f16_kernel<512><<(rows), 512, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_r.data_ptr(), x_w.data_ptr(), x_k.data_ptr(), x_v.data_ptr(), x_a.data_ptr(), x_g.data_ptr(), x_out.data_ptr(), out_r.data_ptr(), out_w.data_ptr(), out_k.data_ptr(), out_v.data_ptr(), out_a.data_ptr(), out_g.data_ptr(), rows, static_cast(eps)); } else { add_layer_norm_tmix_mix6_f16_kernel<1024><<(rows), 1024, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_r.data_ptr(), x_w.data_ptr(), x_k.data_ptr(), x_v.data_ptr(), x_a.data_ptr(), x_g.data_ptr(), x_out.data_ptr(), out_r.data_ptr(), out_w.data_ptr(), out_k.data_ptr(), out_v.data_ptr(), out_a.data_ptr(), out_g.data_ptr(), rows, static_cast(eps)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return {x_out, out_r, out_w, out_k, out_v, out_a, out_g}; } std::vector add_layer_norm_tmix_mix6_f16_scalar_stats_cuda( at::Tensor x, at::Tensor residual, at::Tensor shift_state, at::Tensor weight, 
at::Tensor bias, at::Tensor x_r, at::Tensor x_w, at::Tensor x_k, at::Tensor x_v, at::Tensor x_a, at::Tensor x_g, double eps) { auto x_out = at::empty_like(x); auto out_r = at::empty_like(x); auto out_w = at::empty_like(x); auto out_k = at::empty_like(x); auto out_v = at::empty_like(x); auto out_a = at::empty_like(x); auto out_g = at::empty_like(x); const int64_t rows = x.numel() / LN_SMALL_C; auto stream = at::cuda::getCurrentCUDAStream(); add_layer_norm_tmix_mix6_f16_scalar_stats_kernel<<(rows), LN_SMALL_THREADS, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_r.data_ptr(), x_w.data_ptr(), x_k.data_ptr(), x_v.data_ptr(), x_a.data_ptr(), x_g.data_ptr(), x_out.data_ptr(), out_r.data_ptr(), out_w.data_ptr(), out_k.data_ptr(), out_v.data_ptr(), out_a.data_ptr(), out_g.data_ptr(), rows, static_cast(eps)); C10_CUDA_KERNEL_LAUNCH_CHECK(); return {x_out, out_r, out_w, out_k, out_v, out_a, out_g}; } std::vector add_layer_norm_cmix_mix_f16_cfg_cuda( at::Tensor x, at::Tensor residual, at::Tensor shift_state, at::Tensor weight, at::Tensor bias, at::Tensor x_k, double eps, int threads) { auto x_out = at::empty_like(x); auto mixed = at::empty_like(x); const int64_t rows = x.numel() / LN_SMALL_C; auto stream = at::cuda::getCurrentCUDAStream(); if (threads == 256) { add_layer_norm_cmix_mix_f16_kernel<256><<(rows), 256, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_k.data_ptr(), x_out.data_ptr(), mixed.data_ptr(), rows, static_cast(eps)); } else if (threads == 512) { add_layer_norm_cmix_mix_f16_kernel<512><<(rows), 512, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), weight.data_ptr(), bias.data_ptr(), x_k.data_ptr(), x_out.data_ptr(), mixed.data_ptr(), rows, static_cast(eps)); } else { add_layer_norm_cmix_mix_f16_kernel<1024><<(rows), 1024, 0, stream>>>( x.data_ptr(), residual.data_ptr(), shift_state.data_ptr(), 
weight.data_ptr(), bias.data_ptr(), x_k.data_ptr(), x_out.data_ptr(), mixed.data_ptr(), rows, static_cast(eps)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return {x_out, mixed}; } at::Tensor linear_f16_cuda(at::Tensor x, at::Tensor weight) { const int64_t k64 = x.size(-1); const int64_t n64 = weight.size(1); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_f16 K/N too large"); const int k = static_cast(k64); const int n = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 <= INT_MAX, "linear_f16 M too large"); const int m = static_cast(m64); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if (m == 0 || n == 0 || k == 0) { return y; } // Row-major y[M,N] = x[M,K] @ weight[K,N] is column-major // y^T[N,M] = weight^T[N,K] @ x^T[K,M]. const float alpha = 1.0f; const float beta = 0.0f; cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); check_cublas(cublasGemmEx( handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, weight.data_ptr(), CUDA_R_16F, n, x.data_ptr(), CUDA_R_16F, k, &beta, y.data_ptr(), CUDA_R_16F, n, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP), "linear_f16 cublasGemmEx"); return y; } at::Tensor linear_f16_orig_cuda(at::Tensor x, at::Tensor weight_orig) { const int64_t k64 = x.size(-1); const int64_t n64 = weight_orig.size(0); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_f16_orig K/N too large"); const int k = static_cast(k64); const int n = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 <= INT_MAX, "linear_f16_orig M too large"); const int m = static_cast(m64); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if (m == 0 || n == 0 || k == 0) { return y; } // weight_orig is row-major [N,K], i.e. column-major [K,N]. // Row-major y[M,N] = x[M,K] @ weight_orig[N,K]^T becomes // column-major y^T[N,M] = opT(weight_orig_col[K,N]) @ x_col[K,M]. 
const float alpha = 1.0f; const float beta = 0.0f; cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); check_cublas(cublasGemmEx( handle, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, weight_orig.data_ptr(), CUDA_R_16F, k, x.data_ptr(), CUDA_R_16F, k, &beta, y.data_ptr(), CUDA_R_16F, n, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP), "linear_f16_orig cublasGemmEx"); return y; } template at::Tensor linear_orig_rows_f16_cuda_impl(at::Tensor x, at::Tensor weight_orig) { const int64_t k64 = x.size(-1); const int64_t n64 = weight_orig.size(0); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_orig_rows_f16 K/N too large"); const int K = static_cast(k64); const int N = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 <= INT_MAX, "linear_orig_rows_f16 M too large"); const int M = static_cast(m64); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if (M == 0 || N == 0 || K == 0) { return y; } auto stream = at::cuda::getCurrentCUDAStream(); linear_orig_rows_f16_kernel<128, RowTile, OutTile><<>>( M, K, N, x.data_ptr(), weight_orig.data_ptr(), y.data_ptr()); C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } template at::Tensor linear_orig_rows_cfg_f16_cuda_impl(at::Tensor x, at::Tensor weight_orig) { const int64_t k64 = x.size(-1); const int64_t n64 = weight_orig.size(0); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_orig_rows_cfg_f16 K/N too large"); const int K = static_cast(k64); const int N = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 <= INT_MAX, "linear_orig_rows_cfg_f16 M too large"); const int M = static_cast(m64); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if (M == 0 || N == 0 || K == 0) { return y; } auto stream = at::cuda::getCurrentCUDAStream(); linear_orig_rows_f16_kernel<<>>( M, K, N, x.data_ptr(), weight_orig.data_ptr(), y.data_ptr()); 
C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } at::Tensor linear_orig_rows_f16_cuda(at::Tensor x, at::Tensor weight_orig, int64_t row_tile, int64_t out_tile) { if (row_tile == 1 && out_tile == 2) return linear_orig_rows_f16_cuda_impl<1, 2>(x, weight_orig); if (row_tile == 1 && out_tile == 4) return linear_orig_rows_f16_cuda_impl<1, 4>(x, weight_orig); if (row_tile == 1 && out_tile == 8) return linear_orig_rows_f16_cuda_impl<1, 8>(x, weight_orig); if (row_tile == 1 && out_tile == 16) return linear_orig_rows_f16_cuda_impl<1, 16>(x, weight_orig); if (row_tile == 2 && out_tile == 2) return linear_orig_rows_f16_cuda_impl<2, 2>(x, weight_orig); if (row_tile == 2 && out_tile == 4) return linear_orig_rows_f16_cuda_impl<2, 4>(x, weight_orig); if (row_tile == 2 && out_tile == 8) return linear_orig_rows_f16_cuda_impl<2, 8>(x, weight_orig); if (row_tile == 3 && out_tile == 2) return linear_orig_rows_f16_cuda_impl<3, 2>(x, weight_orig); if (row_tile == 3 && out_tile == 4) return linear_orig_rows_f16_cuda_impl<3, 4>(x, weight_orig); if (row_tile == 3 && out_tile == 8) return linear_orig_rows_f16_cuda_impl<3, 8>(x, weight_orig); if (row_tile == 4 && out_tile == 2) return linear_orig_rows_f16_cuda_impl<4, 2>(x, weight_orig); if (row_tile == 4 && out_tile == 4) return linear_orig_rows_f16_cuda_impl<4, 4>(x, weight_orig); if (row_tile == 4 && out_tile == 8) return linear_orig_rows_f16_cuda_impl<4, 8>(x, weight_orig); if (row_tile == 8 && out_tile == 2) return linear_orig_rows_f16_cuda_impl<8, 2>(x, weight_orig); if (row_tile == 8 && out_tile == 4) return linear_orig_rows_f16_cuda_impl<8, 4>(x, weight_orig); if (row_tile == 16 && out_tile == 1) return linear_orig_rows_f16_cuda_impl<16, 1>(x, weight_orig); if (row_tile == 16 && out_tile == 2) return linear_orig_rows_f16_cuda_impl<16, 2>(x, weight_orig); if (row_tile == 16 && out_tile == 4) return linear_orig_rows_f16_cuda_impl<16, 4>(x, weight_orig); TORCH_CHECK(false, "unsupported linear_orig_rows_f16 row_tile/out_tile"); } at::Tensor 
linear_orig_rows_cfg_f16_cuda(at::Tensor x, at::Tensor weight_orig, int64_t threads, int64_t row_tile, int64_t out_tile) { if (threads == 64 && row_tile == 1 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<64, 1, 4>(x, weight_orig); if (threads == 64 && row_tile == 1 && out_tile == 8) return linear_orig_rows_cfg_f16_cuda_impl<64, 1, 8>(x, weight_orig); if (threads == 128 && row_tile == 1 && out_tile == 8) return linear_orig_rows_cfg_f16_cuda_impl<128, 1, 8>(x, weight_orig); if (threads == 256 && row_tile == 1 && out_tile == 1) return linear_orig_rows_cfg_f16_cuda_impl<256, 1, 1>(x, weight_orig); if (threads == 32 && row_tile == 4 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<32, 4, 4>(x, weight_orig); if (threads == 64 && row_tile == 4 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<64, 4, 4>(x, weight_orig); if (threads == 96 && row_tile == 4 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<96, 4, 4>(x, weight_orig); if (threads == 32 && row_tile == 4 && out_tile == 8) return linear_orig_rows_cfg_f16_cuda_impl<32, 4, 8>(x, weight_orig); if (threads == 64 && row_tile == 4 && out_tile == 8) return linear_orig_rows_cfg_f16_cuda_impl<64, 4, 8>(x, weight_orig); if (threads == 32 && row_tile == 8 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<32, 8, 4>(x, weight_orig); if (threads == 64 && row_tile == 8 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<64, 8, 4>(x, weight_orig); if (threads == 32 && row_tile == 2 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<32, 2, 4>(x, weight_orig); if (threads == 64 && row_tile == 2 && out_tile == 2) return linear_orig_rows_cfg_f16_cuda_impl<64, 2, 2>(x, weight_orig); if (threads == 64 && row_tile == 2 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<64, 2, 4>(x, weight_orig); if (threads == 32 && row_tile == 3 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<32, 3, 4>(x, weight_orig); if (threads == 64 && row_tile == 3 && 
out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<64, 3, 4>(x, weight_orig); if (threads == 96 && row_tile == 3 && out_tile == 4) return linear_orig_rows_cfg_f16_cuda_impl<96, 3, 4>(x, weight_orig); if (threads == 32 && row_tile == 3 && out_tile == 8) return linear_orig_rows_cfg_f16_cuda_impl<32, 3, 8>(x, weight_orig); if (threads == 64 && row_tile == 3 && out_tile == 8) return linear_orig_rows_cfg_f16_cuda_impl<64, 3, 8>(x, weight_orig); TORCH_CHECK(false, "unsupported linear_orig_rows_cfg_f16 threads/row_tile/out_tile"); } template at::Tensor linear_orig_row1_exact_f16_cuda_impl(at::Tensor x, at::Tensor weight_orig) { const int64_t k64 = x.size(-1); const int64_t n64 = weight_orig.size(0); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_orig_row1_exact_f16 K/N too large"); TORCH_CHECK((n64 % OutTile) == 0, "linear_orig_row1_exact_f16 requires N divisible by out_tile"); TORCH_CHECK((k64 % (Use4 ? 4 : 2)) == 0, "linear_orig_row1_exact_f16 unsupported K alignment"); const int K = static_cast(k64); const int N = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 == 1, "linear_orig_row1_exact_f16 requires one row"); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if constexpr (Use4) { linear_orig_row1_exact4_f16_kernel<<>>( K, N, x.data_ptr(), weight_orig.data_ptr(), y.data_ptr()); } else { linear_orig_row1_exact_f16_kernel<<>>( K, N, x.data_ptr(), weight_orig.data_ptr(), y.data_ptr()); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } template at::Tensor linear_orig_row2_exact_f16_cuda_impl(at::Tensor x, at::Tensor weight_orig) { const int64_t k64 = x.size(-1); const int64_t n64 = weight_orig.size(0); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_orig_row2_exact_f16 K/N too large"); TORCH_CHECK((n64 % OutTile) == 0, "linear_orig_row2_exact_f16 requires N divisible by out_tile"); TORCH_CHECK((k64 % (Use4 ? 
4 : 2)) == 0, "linear_orig_row2_exact_f16 unsupported K alignment"); const int K = static_cast(k64); const int N = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 == 2, "linear_orig_row2_exact_f16 requires two rows"); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if constexpr (Use4) { linear_orig_row2_exact4_f16_kernel<<>>( K, N, x.data_ptr(), weight_orig.data_ptr(), y.data_ptr()); } else { linear_orig_row2_exact_f16_kernel<<>>( K, N, x.data_ptr(), weight_orig.data_ptr(), y.data_ptr()); } C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } at::Tensor linear_orig_rows_exact_f16_cuda(at::Tensor x, at::Tensor weight_orig, int64_t threads, int64_t out_tile, bool use4) { const int64_t rows = x.numel() / x.size(-1); if (rows == 1) { if (!use4 && threads == 128 && out_tile == 2) return linear_orig_row1_exact_f16_cuda_impl<128, 2, false>(x, weight_orig); if (use4 && threads == 128 && out_tile == 2) return linear_orig_row1_exact_f16_cuda_impl<128, 2, true>(x, weight_orig); } if (rows == 2) { if (use4 && threads == 64 && out_tile == 2) return linear_orig_row2_exact_f16_cuda_impl<64, 2, true>(x, weight_orig); if (use4 && threads == 256 && out_tile == 1) return linear_orig_row2_exact_f16_cuda_impl<256, 1, true>(x, weight_orig); if (!use4 && threads == 128 && out_tile == 2) return linear_orig_row2_exact_f16_cuda_impl<128, 2, false>(x, weight_orig); } TORCH_CHECK(false, "unsupported linear_orig_rows_exact_f16 rows/threads/out_tile/use4"); } at::Tensor linear_orig_wmma16_f16_cuda(at::Tensor x, at::Tensor weight_orig) { const int64_t k64 = x.size(-1); const int64_t n64 = weight_orig.size(0); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_orig_wmma16_f16 K/N too large"); const int K = static_cast(k64); const int N = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 <= INT_MAX, "linear_orig_wmma16_f16 M too large"); const int M = static_cast(m64); 
TORCH_CHECK((K % 16) == 0 && (N % 16) == 0, "linear_orig_wmma16_f16 requires K/N multiple of 16"); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if (M == 0 || N == 0 || K == 0) { return y; } auto stream = at::cuda::getCurrentCUDAStream(); linear_orig_wmma16_f16_kernel<<>>( M, K, N, x.data_ptr(), weight_orig.data_ptr(), y.data_ptr()); C10_CUDA_KERNEL_LAUNCH_CHECK(); return y; } at::Tensor linear_f16_orig_lt_cfg_cuda(at::Tensor x, at::Tensor weight_orig, int64_t workspace_mb, int64_t algo_index); at::Tensor linear_f16_orig_lt_cuda(at::Tensor x, at::Tensor weight_orig) { return linear_f16_orig_lt_cfg_cuda(x, weight_orig, 0, 0); } at::Tensor linear_f16_orig_lt_cfg_cuda(at::Tensor x, at::Tensor weight_orig, int64_t workspace_mb, int64_t algo_index) { const int64_t k64 = x.size(-1); const int64_t n64 = weight_orig.size(0); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_f16_orig_lt_cfg K/N too large"); const int k = static_cast(k64); const int n = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 <= INT_MAX, "linear_f16_orig_lt_cfg M too large"); const int m = static_cast(m64); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if (m == 0 || n == 0 || k == 0) { return y; } const size_t workspace_size = static_cast(workspace_mb) << 20; at::Tensor workspace; void* workspace_ptr = nullptr; if (workspace_size > 0) { workspace = at::empty({static_cast(workspace_size)}, x.options().dtype(at::kByte)); workspace_ptr = workspace.data_ptr(); } static cublasLtHandle_t lt_handle = nullptr; if (lt_handle == nullptr) { check_cublaslt(cublasLtCreate(<_handle), "cublasLtCreate"); } cublasLtMatmulDesc_t op_desc = nullptr; cublasLtMatrixLayout_t a_desc = nullptr; cublasLtMatrixLayout_t b_desc = nullptr; cublasLtMatrixLayout_t c_desc = nullptr; cublasLtMatmulPreference_t pref = nullptr; 
check_cublaslt(cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F), "linear_f16_orig_lt desc"); const cublasOperation_t transa = CUBLAS_OP_T; const cublasOperation_t transb = CUBLAS_OP_N; check_cublaslt(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)), "linear_f16_orig_lt transa"); check_cublaslt(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)), "linear_f16_orig_lt transb"); check_cublaslt(cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_16F, k, n, k), "linear_f16_orig_lt a layout"); check_cublaslt(cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_16F, k, m, k), "linear_f16_orig_lt b layout"); check_cublaslt(cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_16F, n, m, n), "linear_f16_orig_lt c layout"); check_cublaslt(cublasLtMatmulPreferenceCreate(&pref), "linear_f16_orig_lt preference"); check_cublaslt(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)), "linear_f16_orig_lt workspace"); std::vector heuristics(64); int returned = 0; check_cublaslt(cublasLtMatmulAlgoGetHeuristic(lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, pref, static_cast(heuristics.size()), heuristics.data(), &returned), "linear_f16_orig_lt heuristic"); TORCH_CHECK(returned > 0, "linear_f16_orig_lt found no algorithm"); TORCH_CHECK(algo_index < returned, "linear_f16_orig_lt_cfg algo_index=", algo_index, " returned=", returned); const float alpha = 1.0f; const float beta = 0.0f; check_cublaslt(cublasLtMatmul( lt_handle, op_desc, &alpha, weight_orig.data_ptr(), a_desc, x.data_ptr(), b_desc, &beta, y.data_ptr(), c_desc, y.data_ptr(), c_desc, &heuristics[algo_index].algo, workspace_ptr, workspace_size, at::cuda::getCurrentCUDAStream()), "linear_f16_orig_lt matmul"); cublasLtMatmulPreferenceDestroy(pref); cublasLtMatrixLayoutDestroy(c_desc); cublasLtMatrixLayoutDestroy(b_desc); cublasLtMatrixLayoutDestroy(a_desc); 
cublasLtMatmulDescDestroy(op_desc); return y; } at::Tensor linear_f16_lt_cuda(at::Tensor x, at::Tensor weight) { const int64_t k64 = x.size(-1); const int64_t n64 = weight.size(1); TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_f16_lt K/N too large"); const int k = static_cast(k64); const int n = static_cast(n64); const int64_t m64 = x.numel() / k64; TORCH_CHECK(m64 <= INT_MAX, "linear_f16_lt M too large"); const int m = static_cast(m64); std::vector out_sizes(x.sizes().begin(), x.sizes().end()); out_sizes.back() = n64; auto y = at::empty(out_sizes, x.options()); if (m == 0 || n == 0 || k == 0) { return y; } static cublasLtHandle_t lt_handle = nullptr; if (lt_handle == nullptr) { check_cublaslt(cublasLtCreate(<_handle), "cublasLtCreate"); } cublasLtMatmulDesc_t op_desc = nullptr; cublasLtMatrixLayout_t a_desc = nullptr; cublasLtMatrixLayout_t b_desc = nullptr; cublasLtMatrixLayout_t c_desc = nullptr; cublasLtMatmulPreference_t pref = nullptr; check_cublaslt(cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F), "cublasLtMatmulDescCreate"); const cublasOperation_t trans = CUBLAS_OP_N; check_cublaslt(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans, sizeof(trans)), "cublasLt set transa"); check_cublaslt(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans, sizeof(trans)), "cublasLt set transb"); check_cublaslt(cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_16F, n, k, n), "cublasLt a layout"); check_cublaslt(cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_16F, k, m, k), "cublasLt b layout"); check_cublaslt(cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_16F, n, m, n), "cublasLt c layout"); check_cublaslt(cublasLtMatmulPreferenceCreate(&pref), "cublasLt preference"); const size_t workspace_size = 0; check_cublaslt(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)), "cublasLt set workspace"); cublasLtMatmulHeuristicResult_t heuristic = {}; 
// (continuation of linear_f16_lt_cuda) Ask for a single heuristic candidate and
// run the matmul with it on the current CUDA stream.
    int returned = 0;
    check_cublaslt(cublasLtMatmulAlgoGetHeuristic(lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, pref, 1, &heuristic, &returned), "cublasLt heuristic");
    TORCH_CHECK(returned > 0, "cublasLt found no algorithm");
    // y = 1 * (A * B) + 0 * y; y is passed as both C and D, no workspace.
    const float alpha = 1.0f;
    const float beta = 0.0f;
    check_cublaslt(cublasLtMatmul(
        lt_handle, op_desc, &alpha, weight.data_ptr(), a_desc, x.data_ptr(), b_desc, &beta, y.data_ptr(), c_desc, y.data_ptr(), c_desc, &heuristic.algo, nullptr, 0, at::cuda::getCurrentCUDAStream()), "cublasLtMatmul");
    // Release the per-call descriptors (the handle is kept).
    cublasLtMatmulPreferenceDestroy(pref);
    cublasLtMatrixLayoutDestroy(c_desc);
    cublasLtMatrixLayoutDestroy(b_desc);
    cublasLtMatrixLayoutDestroy(a_desc);
    cublasLtMatmulDescDestroy(op_desc);
    return y;
}

// Split-K linear layer for a single input row (M = 1).
//
// K = x.size(-1), N = weight.size(1). K is cut into `chunks` =
// ceil_div(K, ChunkK) pieces; a first kernel writes per-chunk results into an
// fp32 scratch tensor of shape {chunks, N}, and a second kernel reduces them
// into the output (x's shape with the last dimension replaced by N).
// `WarpReduce` selects between the two reduction kernels at compile time.
//
// Raises (via TORCH_CHECK) when K or N exceed INT_MAX, when x holds more than
// one row (x.numel() != K), or when N is not a multiple of 64.
//
// NOTE(review): the template parameter list and the launch configuration of
// the partial kernel are not visible in this view; the body uses `ChunkK` and
// `WarpReduce`, and callers pass two or three template arguments.
template at::Tensor linear_f16_m1_splitk_cuda_impl(at::Tensor x, at::Tensor weight) {
    const int64_t k64 = x.size(-1);
    const int64_t n64 = weight.size(1);
    TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_f16_m1_splitk K/N too large");
    const int K = static_cast(k64);
    const int N = static_cast(n64);
    TORCH_CHECK(x.numel() == k64, "linear_f16_m1_splitk requires M=1");
    TORCH_CHECK((N % 64) == 0, "linear_f16_m1_splitk requires N multiple of 64");
    std::vector out_sizes(x.sizes().begin(), x.sizes().end());
    out_sizes.back() = n64;
    auto y = at::empty(out_sizes, x.options());
    // Nothing to compute: return the (uninitialized) output without launching.
    if (K == 0 || N == 0) {
        return y;
    }
    const int chunks = static_cast(ceil_div(K, ChunkK));
    // fp32 scratch: one row of N partial sums per K chunk.
    auto partial = at::empty({chunks, n64}, x.options().dtype(at::kFloat));
    auto stream = at::cuda::getCurrentCUDAStream();
    linear_f16_m1_splitk_partial_kernel<<>>(
        K, N, x.data_ptr(), weight.data_ptr(), partial.data_ptr());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    // Both reductions use 128-thread blocks; the grid is sized from N / 2
    // (presumably one output pair per work item -- TODO confirm in the kernels).
    if constexpr (WarpReduce) {
        linear_f16_m1_splitk_reduce_warp_kernel<<(ceil_div(N / 2, 4)), 128, 0, stream>>>(
            chunks, N, partial.data_ptr(), y.data_ptr());
    } else {
        linear_f16_m1_splitk_reduce_kernel<<(ceil_div(N / 2, 128)), 128, 0, stream>>>(
            chunks, N, partial.data_ptr(), y.data_ptr());
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return y;
}

at::Tensor
// Shape-based dispatcher for the M = 1 split-K linear layer.
//
// Picks a linear_f16_m1_splitk_cuda_impl instantiation from K = x.size(-1) and
// N = weight.size(1). The first template argument is the K chunk size; the
// conditions are tested in order, so the first match wins. The constants look
// like tuning results -- their origin is not visible here.
linear_f16_m1_splitk_cuda(at::Tensor x, at::Tensor weight) {
    const int64_t K = x.size(-1);
    const int64_t N = weight.size(1);
    // Only this shape uses the three-argument (`true`) instantiation.
    if (K == 4096 && N == 4096) {
        return linear_f16_m1_splitk_cuda_impl<160, 1, true>(x, weight);
    }
    if (N >= 65536) {
        return linear_f16_m1_splitk_cuda_impl<768, 2>(x, weight);
    }
    if (K == 4096 && N == 16384) {
        return linear_f16_m1_splitk_cuda_impl<512, 2>(x, weight);
    }
    if (K >= 8192) {
        return linear_f16_m1_splitk_cuda_impl<512, 2>(x, weight);
    }
    // Fallback for every other shape.
    return linear_f16_m1_splitk_cuda_impl<256, 4>(x, weight);
}

// Runs the M = 1 split-K linear layer with a caller-chosen K chunk size.
//
// `chunk_k` must be one of the values listed in the switch; the second
// template argument is fixed to 4. Any other `chunk_k` raises
// "unsupported chunk_k" via TORCH_CHECK.
at::Tensor linear_f16_m1_splitk_cfg_cuda(at::Tensor x, at::Tensor weight, int64_t chunk_k) {
    switch (chunk_k) {
        case 64: return linear_f16_m1_splitk_cuda_impl<64, 4>(x, weight);
        case 96: return linear_f16_m1_splitk_cuda_impl<96, 4>(x, weight);
        case 112: return linear_f16_m1_splitk_cuda_impl<112, 4>(x, weight);
        case 128: return linear_f16_m1_splitk_cuda_impl<128, 4>(x, weight);
        case 144: return linear_f16_m1_splitk_cuda_impl<144, 4>(x, weight);
        case 152: return linear_f16_m1_splitk_cuda_impl<152, 4>(x, weight);
        case 160: return linear_f16_m1_splitk_cuda_impl<160, 4>(x, weight);
        case 168: return linear_f16_m1_splitk_cuda_impl<168, 4>(x, weight);
        case 176: return linear_f16_m1_splitk_cuda_impl<176, 4>(x, weight);
        case 184: return linear_f16_m1_splitk_cuda_impl<184, 4>(x, weight);
        case 192: return linear_f16_m1_splitk_cuda_impl<192, 4>(x, weight);
        case 208: return linear_f16_m1_splitk_cuda_impl<208, 4>(x, weight);
        case 224: return linear_f16_m1_splitk_cuda_impl<224, 4>(x, weight);
        case 256: return linear_f16_m1_splitk_cuda_impl<256, 4>(x, weight);
        case 384: return linear_f16_m1_splitk_cuda_impl<384, 4>(x, weight);
        case 512: return linear_f16_m1_splitk_cuda_impl<512, 4>(x, weight);
        case 640: return linear_f16_m1_splitk_cuda_impl<640, 4>(x, weight);
        case 768: return linear_f16_m1_splitk_cuda_impl<768, 4>(x, weight);
        case 896: return linear_f16_m1_splitk_cuda_impl<896, 4>(x, weight);
        case 1024: return linear_f16_m1_splitk_cuda_impl<1024, 4>(x, weight);
// (remaining cases of the chunk_k switch in linear_f16_m1_splitk_cfg_cuda)
        case 2048: return linear_f16_m1_splitk_cuda_impl<2048, 4>(x, weight);
        case 4096: return linear_f16_m1_splitk_cuda_impl<4096, 4>(x, weight);
        default: TORCH_CHECK(false, "unsupported chunk_k");
    }
}

// Runs the M = 1 split-K linear layer with a caller-chosen K chunk size and
// column tile width.
//
// `tile_cols` of 64 / 128 selects 1 / 2 as the second template argument;
// 256 forwards to linear_f16_m1_splitk_cfg_cuda (second argument 4). Any other
// `tile_cols` raises "unsupported tile_cols", and a `chunk_k` missing from the
// selected table raises "unsupported chunk_k".
// NOTE(review): the supported chunk_k sets differ per tile width (the 64-wide
// table stops at 896) -- confirm that this is intentional.
at::Tensor linear_f16_m1_splitk_tile_cuda(at::Tensor x, at::Tensor weight, int64_t chunk_k, int64_t tile_cols) {
    if (tile_cols == 64) {
        switch (chunk_k) {
            case 64: return linear_f16_m1_splitk_cuda_impl<64, 1>(x, weight);
            case 96: return linear_f16_m1_splitk_cuda_impl<96, 1>(x, weight);
            case 112: return linear_f16_m1_splitk_cuda_impl<112, 1>(x, weight);
            case 128: return linear_f16_m1_splitk_cuda_impl<128, 1>(x, weight);
            case 144: return linear_f16_m1_splitk_cuda_impl<144, 1>(x, weight);
            case 152: return linear_f16_m1_splitk_cuda_impl<152, 1>(x, weight);
            case 160: return linear_f16_m1_splitk_cuda_impl<160, 1>(x, weight);
            case 168: return linear_f16_m1_splitk_cuda_impl<168, 1>(x, weight);
            case 176: return linear_f16_m1_splitk_cuda_impl<176, 1>(x, weight);
            case 184: return linear_f16_m1_splitk_cuda_impl<184, 1>(x, weight);
            case 192: return linear_f16_m1_splitk_cuda_impl<192, 1>(x, weight);
            case 208: return linear_f16_m1_splitk_cuda_impl<208, 1>(x, weight);
            case 224: return linear_f16_m1_splitk_cuda_impl<224, 1>(x, weight);
            case 256: return linear_f16_m1_splitk_cuda_impl<256, 1>(x, weight);
            case 384: return linear_f16_m1_splitk_cuda_impl<384, 1>(x, weight);
            case 512: return linear_f16_m1_splitk_cuda_impl<512, 1>(x, weight);
            case 640: return linear_f16_m1_splitk_cuda_impl<640, 1>(x, weight);
            case 768: return linear_f16_m1_splitk_cuda_impl<768, 1>(x, weight);
            case 896: return linear_f16_m1_splitk_cuda_impl<896, 1>(x, weight);
            default: TORCH_CHECK(false, "unsupported chunk_k");
        }
    }
    if (tile_cols == 128) {
        switch (chunk_k) {
            case 64: return linear_f16_m1_splitk_cuda_impl<64, 2>(x, weight);
            case 96: return linear_f16_m1_splitk_cuda_impl<96, 2>(x, weight);
            case 112: return linear_f16_m1_splitk_cuda_impl<112, 2>(x, weight);
            case 128: return linear_f16_m1_splitk_cuda_impl<128, 2>(x,
// (remaining tile_cols == 128 cases of linear_f16_m1_splitk_tile_cuda)
                weight);
            case 144: return linear_f16_m1_splitk_cuda_impl<144, 2>(x, weight);
            case 152: return linear_f16_m1_splitk_cuda_impl<152, 2>(x, weight);
            case 160: return linear_f16_m1_splitk_cuda_impl<160, 2>(x, weight);
            case 168: return linear_f16_m1_splitk_cuda_impl<168, 2>(x, weight);
            case 176: return linear_f16_m1_splitk_cuda_impl<176, 2>(x, weight);
            case 184: return linear_f16_m1_splitk_cuda_impl<184, 2>(x, weight);
            case 192: return linear_f16_m1_splitk_cuda_impl<192, 2>(x, weight);
            case 208: return linear_f16_m1_splitk_cuda_impl<208, 2>(x, weight);
            case 224: return linear_f16_m1_splitk_cuda_impl<224, 2>(x, weight);
            case 256: return linear_f16_m1_splitk_cuda_impl<256, 2>(x, weight);
            case 384: return linear_f16_m1_splitk_cuda_impl<384, 2>(x, weight);
            case 512: return linear_f16_m1_splitk_cuda_impl<512, 2>(x, weight);
            case 640: return linear_f16_m1_splitk_cuda_impl<640, 2>(x, weight);
            case 768: return linear_f16_m1_splitk_cuda_impl<768, 2>(x, weight);
            case 896: return linear_f16_m1_splitk_cuda_impl<896, 2>(x, weight);
            case 1024: return linear_f16_m1_splitk_cuda_impl<1024, 2>(x, weight);
            default: TORCH_CHECK(false, "unsupported chunk_k");
        }
    }
    // Only 256 is left as a valid width; it reuses the cfg dispatch table.
    TORCH_CHECK(tile_cols == 256, "unsupported tile_cols");
    return linear_f16_m1_splitk_cfg_cuda(x, weight, chunk_k);
}

// Same dispatch as linear_f16_m1_splitk_tile_cuda, but every instantiation
// passes `true` as the third template argument (the warp-reduction variant of
// the impl). `tile_cols` of 64 / 128 / 256 selects 1 / 2 / 4 as the second
// template argument. An unsupported `chunk_k` raises
// "unsupported warpred chunk_k"; an unsupported `tile_cols` raises
// "unsupported warpred tile_cols".
at::Tensor linear_f16_m1_splitk_warpred_tile_cuda(at::Tensor x, at::Tensor weight, int64_t chunk_k, int64_t tile_cols) {
    if (tile_cols == 64) {
        switch (chunk_k) {
            case 64: return linear_f16_m1_splitk_cuda_impl<64, 1, true>(x, weight);
            case 96: return linear_f16_m1_splitk_cuda_impl<96, 1, true>(x, weight);
            case 112: return linear_f16_m1_splitk_cuda_impl<112, 1, true>(x, weight);
            case 128: return linear_f16_m1_splitk_cuda_impl<128, 1, true>(x, weight);
            case 144: return linear_f16_m1_splitk_cuda_impl<144, 1, true>(x, weight);
            case 152: return linear_f16_m1_splitk_cuda_impl<152, 1, true>(x, weight);
            case 160: return linear_f16_m1_splitk_cuda_impl<160, 1, true>(x, weight);
            case 168: return
// (remaining tile_cols == 64 cases of linear_f16_m1_splitk_warpred_tile_cuda)
                linear_f16_m1_splitk_cuda_impl<168, 1, true>(x, weight);
            case 176: return linear_f16_m1_splitk_cuda_impl<176, 1, true>(x, weight);
            case 184: return linear_f16_m1_splitk_cuda_impl<184, 1, true>(x, weight);
            case 192: return linear_f16_m1_splitk_cuda_impl<192, 1, true>(x, weight);
            case 208: return linear_f16_m1_splitk_cuda_impl<208, 1, true>(x, weight);
            case 224: return linear_f16_m1_splitk_cuda_impl<224, 1, true>(x, weight);
            case 256: return linear_f16_m1_splitk_cuda_impl<256, 1, true>(x, weight);
            default: TORCH_CHECK(false, "unsupported warpred chunk_k");
        }
    }
    // 128-wide tiles: second template argument 2.
    if (tile_cols == 128) {
        switch (chunk_k) {
            case 64: return linear_f16_m1_splitk_cuda_impl<64, 2, true>(x, weight);
            case 96: return linear_f16_m1_splitk_cuda_impl<96, 2, true>(x, weight);
            case 112: return linear_f16_m1_splitk_cuda_impl<112, 2, true>(x, weight);
            case 128: return linear_f16_m1_splitk_cuda_impl<128, 2, true>(x, weight);
            case 144: return linear_f16_m1_splitk_cuda_impl<144, 2, true>(x, weight);
            case 152: return linear_f16_m1_splitk_cuda_impl<152, 2, true>(x, weight);
            case 160: return linear_f16_m1_splitk_cuda_impl<160, 2, true>(x, weight);
            case 168: return linear_f16_m1_splitk_cuda_impl<168, 2, true>(x, weight);
            case 176: return linear_f16_m1_splitk_cuda_impl<176, 2, true>(x, weight);
            case 184: return linear_f16_m1_splitk_cuda_impl<184, 2, true>(x, weight);
            case 192: return linear_f16_m1_splitk_cuda_impl<192, 2, true>(x, weight);
            case 208: return linear_f16_m1_splitk_cuda_impl<208, 2, true>(x, weight);
            case 224: return linear_f16_m1_splitk_cuda_impl<224, 2, true>(x, weight);
            case 256: return linear_f16_m1_splitk_cuda_impl<256, 2, true>(x, weight);
            default: TORCH_CHECK(false, "unsupported warpred chunk_k");
        }
    }
    // Only 256 is left as a valid width: second template argument 4.
    TORCH_CHECK(tile_cols == 256, "unsupported warpred tile_cols");
    switch (chunk_k) {
        case 64: return linear_f16_m1_splitk_cuda_impl<64, 4, true>(x, weight);
        case 96: return linear_f16_m1_splitk_cuda_impl<96, 4, true>(x, weight);
        case 112: return linear_f16_m1_splitk_cuda_impl<112, 4, true>(x,
// (remaining tile_cols == 256 cases of linear_f16_m1_splitk_warpred_tile_cuda)
            weight);
        case 128: return linear_f16_m1_splitk_cuda_impl<128, 4, true>(x, weight);
        case 144: return linear_f16_m1_splitk_cuda_impl<144, 4, true>(x, weight);
        case 152: return linear_f16_m1_splitk_cuda_impl<152, 4, true>(x, weight);
        case 160: return linear_f16_m1_splitk_cuda_impl<160, 4, true>(x, weight);
        case 168: return linear_f16_m1_splitk_cuda_impl<168, 4, true>(x, weight);
        case 176: return linear_f16_m1_splitk_cuda_impl<176, 4, true>(x, weight);
        case 184: return linear_f16_m1_splitk_cuda_impl<184, 4, true>(x, weight);
        case 192: return linear_f16_m1_splitk_cuda_impl<192, 4, true>(x, weight);
        case 208: return linear_f16_m1_splitk_cuda_impl<208, 4, true>(x, weight);
        case 224: return linear_f16_m1_splitk_cuda_impl<224, 4, true>(x, weight);
        case 256: return linear_f16_m1_splitk_cuda_impl<256, 4, true>(x, weight);
        default: TORCH_CHECK(false, "unsupported warpred chunk_k");
    }
}

// Split-K linear layer for M >= 1 input rows.
//
// K = x.size(-1), N = weight.size(1), M = x.numel() / K. A first kernel writes
// per-chunk results into an fp32 scratch tensor of shape {M, chunks, N}
// (chunks = ceil_div(K, ChunkK)); a second kernel reduces them into the output
// (x's shape with the last dimension replaced by N).
//
// Raises (via TORCH_CHECK) when K, N or M exceed INT_MAX or when N is not a
// multiple of 64.
//
// NOTE(review): the template parameter list and the launch configuration of
// the partial kernel are not visible in this view; callers pass two template
// arguments and the body uses `ChunkK`.
template at::Tensor linear_f16_rows_splitk_cuda_impl(at::Tensor x, at::Tensor weight) {
    const int64_t k64 = x.size(-1);
    const int64_t n64 = weight.size(1);
    TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_f16_rows_splitk K/N too large");
    const int K = static_cast(k64);
    const int N = static_cast(n64);
    // NOTE(review): this division runs before the `K == 0` early return below,
    // so an input whose last dimension is 0 divides by zero here.
    const int64_t m64 = x.numel() / k64;
    TORCH_CHECK(m64 <= INT_MAX, "linear_f16_rows_splitk M too large");
    const int M = static_cast(m64);
    TORCH_CHECK((N % 64) == 0, "linear_f16_rows_splitk requires N multiple of 64");
    std::vector out_sizes(x.sizes().begin(), x.sizes().end());
    out_sizes.back() = n64;
    auto y = at::empty(out_sizes, x.options());
    // Nothing to compute: return the (uninitialized) output without launching.
    if (M == 0 || K == 0 || N == 0) {
        return y;
    }
    const int chunks = static_cast(ceil_div(K, ChunkK));
    // fp32 scratch: for each row, one vector of N partial sums per K chunk.
    auto partial = at::empty({m64, chunks, n64}, x.options().dtype(at::kFloat));
    auto stream = at::cuda::getCurrentCUDAStream();
    linear_f16_rows_splitk_partial_kernel<<>>(
        K, N, chunks, x.data_ptr(), weight.data_ptr(), partial.data_ptr());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    // Reduction: 128-thread blocks; the grid's second dimension is M.
    linear_f16_rows_splitk_reduce_kernel<<(ceil_div(N / 2, 128)), M, 1), 128, 0, stream>>>(
        chunks,
// (remaining arguments of the reduce launch in linear_f16_rows_splitk_cuda_impl)
        N, partial.data_ptr(), y.data_ptr());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return y;
}

// Runs the multi-row split-K linear layer with a caller-chosen K chunk size.
// Supported `chunk_k`: 128, 256, 512, 1024 (second template argument fixed to
// 2); anything else raises "unsupported chunk_k" via TORCH_CHECK.
at::Tensor linear_f16_rows_splitk_cuda(at::Tensor x, at::Tensor weight, int64_t chunk_k) {
    switch (chunk_k) {
        case 128: return linear_f16_rows_splitk_cuda_impl<128, 2>(x, weight);
        case 256: return linear_f16_rows_splitk_cuda_impl<256, 2>(x, weight);
        case 512: return linear_f16_rows_splitk_cuda_impl<512, 2>(x, weight);
        case 1024: return linear_f16_rows_splitk_cuda_impl<1024, 2>(x, weight);
        default: TORCH_CHECK(false, "unsupported chunk_k");
    }
}

// Linear layer against a weight whose first dimension is the output width.
//
// K = x.size(-1), N = weight_t.size(0), M = x.numel() / K. The output has x's
// shape with the last dimension replaced by N. The name suggests `weight_t` is
// stored as [N, K] -- TODO confirm; its second dimension is never checked.
//
// Kernel selection:
//   * K <= 512, N >= 1024, M <= 4: the "ntile" kernels (scalar variant for M == 1);
//   * otherwise K >= 1024: linear_t_f16_kernel<256>;
//   * otherwise: linear_t_f16_kernel<128>.
//
// Raises (via TORCH_CHECK) when K, N or M exceed INT_MAX.
// NOTE(review): the launch configurations are not visible in this view.
at::Tensor linear_t_f16_cuda(at::Tensor x, at::Tensor weight_t) {
    const int64_t k64 = x.size(-1);
    const int64_t n64 = weight_t.size(0);
    TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_t_f16 K/N too large");
    const int K = static_cast(k64);
    const int N = static_cast(n64);
    // NOTE(review): this division runs before the `K == 0` early return below,
    // so an input whose last dimension is 0 divides by zero here.
    const int64_t m64 = x.numel() / k64;
    TORCH_CHECK(m64 <= INT_MAX, "linear_t_f16 M too large");
    const int M = static_cast(m64);
    std::vector out_sizes(x.sizes().begin(), x.sizes().end());
    out_sizes.back() = n64;
    auto y = at::empty(out_sizes, x.options());
    // Nothing to compute: return the (uninitialized) output without launching.
    if (M == 0 || N == 0 || K == 0) {
        return y;
    }
    auto stream = at::cuda::getCurrentCUDAStream();
    if (K <= 512 && N >= 1024 && M <= 4) {
        if (M == 1) {
            linear_t_f16_ntile_scalar_kernel<128, 2><<>>(
                M, K, N, x.data_ptr(), weight_t.data_ptr(), y.data_ptr());
        } else {
            linear_t_f16_ntile_kernel<128, 4><<>>(
                M, K, N, x.data_ptr(), weight_t.data_ptr(), y.data_ptr());
        }
    } else if (K >= 1024) {
        linear_t_f16_kernel<256><<>>(
            M, K, N, x.data_ptr(), weight_t.data_ptr(), y.data_ptr());
    } else {
        linear_t_f16_kernel<128><<>>(
            M, K, N, x.data_ptr(), weight_t.data_ptr(), y.data_ptr());
    }
    // One launch check covers whichever branch ran.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return y;
}

// Linear layer with an activation selected by the `Act` template argument,
// restricted to K <= 512, N >= 1024 and M <= 4.
//
// K = x.size(-1), N = weight_t.size(0), M = x.numel() / K; the output has x's
// shape with the last dimension replaced by N.
// NOTE(review): the template parameter list is not visible in this view; the
// body forwards `Act` to the kernels.
template at::Tensor linear_t_act_f16_cuda_impl(at::Tensor x, at::Tensor weight_t) {
    const int64_t k64 = x.size(-1);
    const int64_t n64 = weight_t.size(0);
    TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_t_act_f16 K/N too large");
    const int K = static_cast(k64);
    const int N = static_cast(n64);
    // NOTE(review): the division that completes this statement runs before the
    // `K == 0` early return, so a zero-sized last dimension divides by zero.
    const int64_t m64 =
// (continuation of linear_t_act_f16_cuda_impl: row count = total elements / K)
        x.numel() / k64;
    TORCH_CHECK(m64 <= INT_MAX, "linear_t_act_f16 M too large");
    const int M = static_cast(m64);
    std::vector out_sizes(x.sizes().begin(), x.sizes().end());
    out_sizes.back() = n64;
    auto y = at::empty(out_sizes, x.options());
    // Empty problems return before the shape restriction below is enforced.
    if (M == 0 || N == 0 || K == 0) {
        return y;
    }
    auto stream = at::cuda::getCurrentCUDAStream();
    TORCH_CHECK(K <= 512 && N >= 1024 && M <= 4, "linear_t_act_f16 currently supports only small-rank rank-out");
    // Single-row inputs use the scalar kernel variant.
    if (M == 1) {
        linear_t_act_f16_ntile_scalar_kernel<128, 2, Act><<>>(
            M, K, N, x.data_ptr(), weight_t.data_ptr(), y.data_ptr());
    } else {
        linear_t_act_f16_ntile_kernel<128, 4, Act><<>>(
            M, K, N, x.data_ptr(), weight_t.data_ptr(), y.data_ptr());
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return y;
}

// Runtime entry point for the linear+activation kernels.
//
// `act == 1` selects the Act = 1 instantiation; every other value selects
// Act = 2. The file's apply_act helper maps Act == 1 to tanh and anything else
// to the logistic sigmoid; presumably the kernels use it -- TODO confirm.
// NOTE(review): an unrecognized `act` is not rejected, it silently takes the
// Act = 2 path.
at::Tensor linear_t_act_f16_cuda(at::Tensor x, at::Tensor weight_t, int64_t act) {
    if (act == 1) {
        return linear_t_act_f16_cuda_impl<1>(x, weight_t);
    }
    return linear_t_act_f16_cuda_impl<2>(x, weight_t);
}

// Projects three inputs (xw, xa, xg) down to low-rank outputs in one launch.
//
// K = xw.size(-1), M = xw.numel() / K, and each rank R* is size(0) of the
// matching *_t weight. Returns {w1, a1, g1}, each with xw's shape and options
// and the last dimension replaced by Rw / Ra / Rg.
//
// Requires K >= 1024, max(Rw, Ra, Rg) <= 512 and M <= 8; raises via
// TORCH_CHECK otherwise, or when a dimension exceeds INT_MAX.
//
// NOTE(review): only xw's shape is inspected; xa and xg are presumably expected
// to match it -- confirm against callers.
std::vector linear_wag_rank_in_f16_cuda(
    at::Tensor xw, at::Tensor xa, at::Tensor xg, at::Tensor w1_t, at::Tensor a1_t, at::Tensor g1_t) {
    const int64_t k64 = xw.size(-1);
    const int64_t rw64 = w1_t.size(0);
    const int64_t ra64 = a1_t.size(0);
    const int64_t rg64 = g1_t.size(0);
    // NOTE(review): divides by zero if xw's last dimension is 0; the K >= 1024
    // check only runs afterwards.
    const int64_t m64 = xw.numel() / k64;
    TORCH_CHECK(k64 <= INT_MAX && rw64 <= INT_MAX && ra64 <= INT_MAX && rg64 <= INT_MAX && m64 <= INT_MAX, "linear_wag_rank_in_f16 shape too large");
    const int K = static_cast(k64);
    const int Rw = static_cast(rw64);
    const int Ra = static_cast(ra64);
    const int Rg = static_cast(rg64);
    // Largest of the three ranks; also passed to the kernel.
    const int Rmax = std::max(Rw, std::max(Ra, Rg));
    const int M = static_cast(m64);
    TORCH_CHECK(K >= 1024 && Rmax <= 512 && M <= 8, "linear_wag_rank_in_f16 supports only K>=1024,R<=512,M<=8");
    // All three outputs start from xw's shape; only the last dimension differs.
    std::vector w_sizes(xw.sizes().begin(), xw.sizes().end());
    std::vector a_sizes = w_sizes;
    std::vector g_sizes = w_sizes;
    w_sizes.back() = rw64;
    a_sizes.back() = ra64;
    g_sizes.back() = rg64;
    auto w1 = at::empty(w_sizes, xw.options());
    auto a1 = at::empty(a_sizes, xw.options());
// (continuation of linear_wag_rank_in_f16_cuda: third output, then the launch)
    auto g1 = at::empty(g_sizes, xw.options());
    // Empty problems return the (uninitialized) outputs without launching.
    // The K == 0 term cannot be true here because K >= 1024 was checked earlier.
    if (M == 0 || K == 0 || Rmax == 0) {
        return {w1, a1, g1};
    }
    auto stream = at::cuda::getCurrentCUDAStream();
    // NOTE(review): the launch configuration is not visible in this view.
    linear_wag_rank_in_f16_kernel<256><<>>(
        M, K, Rw, Ra, Rg, Rmax, xw.data_ptr(), xa.data_ptr(), xg.data_ptr(), w1_t.data_ptr(), a1_t.data_ptr(), g1_t.data_ptr(), w1.data_ptr(), a1.data_ptr(), g1.data_ptr());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return {w1, a1, g1};
}

// Four-input variant of linear_wag_rank_in_f16_cuda: projects xw, xa, xg and xv
// to low-rank outputs in one launch.
//
// K = xw.size(-1), M = xw.numel() / K, and each rank R* is size(0) of the
// matching *_t weight. Returns {w1, a1, g1, v1}, each with xw's shape and
// options and the last dimension replaced by Rw / Ra / Rg / Rv.
//
// Requires K >= 1024, max rank <= 512 and M <= 8; raises via TORCH_CHECK
// otherwise, or when a dimension exceeds INT_MAX.
//
// NOTE(review): only xw's shape is inspected; xa, xg and xv are presumably
// expected to match it -- confirm against callers.
std::vector linear_wagv_rank_in_f16_cuda(
    at::Tensor xw, at::Tensor xa, at::Tensor xg, at::Tensor xv, at::Tensor w1_t, at::Tensor a1_t, at::Tensor g1_t, at::Tensor v1_t) {
    const int64_t k64 = xw.size(-1);
    const int64_t rw64 = w1_t.size(0);
    const int64_t ra64 = a1_t.size(0);
    const int64_t rg64 = g1_t.size(0);
    const int64_t rv64 = v1_t.size(0);
    // NOTE(review): divides by zero if xw's last dimension is 0; the K >= 1024
    // check only runs afterwards.
    const int64_t m64 = xw.numel() / k64;
    TORCH_CHECK(k64 <= INT_MAX && rw64 <= INT_MAX && ra64 <= INT_MAX && rg64 <= INT_MAX && rv64 <= INT_MAX && m64 <= INT_MAX, "linear_wagv_rank_in_f16 shape too large");
    const int K = static_cast(k64);
    const int Rw = static_cast(rw64);
    const int Ra = static_cast(ra64);
    const int Rg = static_cast(rg64);
    const int Rv = static_cast(rv64);
    // Largest of the four ranks; also passed to the kernel.
    const int Rmax = std::max(std::max(Rw, Ra), std::max(Rg, Rv));
    const int M = static_cast(m64);
    TORCH_CHECK(K >= 1024 && Rmax <= 512 && M <= 8, "linear_wagv_rank_in_f16 supports only K>=1024,R<=512,M<=8");
    // All outputs start from xw's shape; only the last dimension differs.
    std::vector w_sizes(xw.sizes().begin(), xw.sizes().end());
    std::vector a_sizes = w_sizes;
    std::vector g_sizes = w_sizes;
    std::vector v_sizes = w_sizes;
    w_sizes.back() = rw64;
    a_sizes.back() = ra64;
    g_sizes.back() = rg64;
    v_sizes.back() = rv64;
    auto w1 = at::empty(w_sizes, xw.options());
    auto a1 = at::empty(a_sizes, xw.options());
    auto g1 = at::empty(g_sizes, xw.options());
    auto v1 = at::empty(v_sizes, xw.options());
    // Empty problems return the (uninitialized) outputs without launching.
    if (M == 0 || K == 0 || Rmax == 0) {
        return {w1, a1, g1, v1};
    }
    auto stream = at::cuda::getCurrentCUDAStream();
    // NOTE(review): the launch configuration is not visible in this view.
    linear_wagv_rank_in_f16_kernel<256><<>>(
        M, K, Rw, Ra, Rg, Rv, Rmax, xw.data_ptr(), xa.data_ptr(), xg.data_ptr(), xv.data_ptr(),
// (remaining kernel arguments of linear_wagv_rank_in_f16_cuda: weights, then outputs)
        w1_t.data_ptr(), a1_t.data_ptr(), g1_t.data_ptr(), v1_t.data_ptr(), w1.data_ptr(), a1.data_ptr(), g1.data_ptr(), v1.data_ptr());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return {w1, a1, g1, v1};
}

// Expands three low-rank inputs (w1, a1, g1) to width C in one launch.
//
// Kw / Ka / Kg are the last dimensions of w1 / a1 / g1, C = w2_t.size(0) and
// M = w1.numel() / Kw. Returns {w, a, g}, all with w1's shape and options and
// the last dimension replaced by C.
//
// Requires every K* <= 512, C >= 1024 and M <= 4; raises via TORCH_CHECK
// otherwise, or when a dimension exceeds INT_MAX.
//
// NOTE(review): C is read from w2_t only, and all three outputs copy w1's
// leading dimensions; a2_t / g2_t and a1 / g1 are presumably expected to agree
// -- confirm against callers.
std::vector linear_wag_rank_out_f16_cuda(
    at::Tensor w1, at::Tensor a1, at::Tensor g1, at::Tensor w2_t, at::Tensor a2_t, at::Tensor g2_t) {
    const int64_t kw64 = w1.size(-1);
    const int64_t ka64 = a1.size(-1);
    const int64_t kg64 = g1.size(-1);
    const int64_t c64 = w2_t.size(0);
    // NOTE(review): divides by zero if w1's last dimension is 0; the Kw == 0
    // early return only comes later.
    const int64_t m64 = w1.numel() / kw64;
    TORCH_CHECK(kw64 <= INT_MAX && ka64 <= INT_MAX && kg64 <= INT_MAX && c64 <= INT_MAX && m64 <= INT_MAX, "linear_wag_rank_out_f16 shape too large");
    const int Kw = static_cast(kw64);
    const int Ka = static_cast(ka64);
    const int Kg = static_cast(kg64);
    const int C = static_cast(c64);
    const int M = static_cast(m64);
    TORCH_CHECK(Kw <= 512 && Ka <= 512 && Kg <= 512 && C >= 1024 && M <= 4, "linear_wag_rank_out_f16 supports only small-rank M<=4");
    std::vector out_sizes(w1.sizes().begin(), w1.sizes().end());
    out_sizes.back() = c64;
    auto w = at::empty(out_sizes, w1.options());
    auto a = at::empty(out_sizes, w1.options());
    auto g = at::empty(out_sizes, w1.options());
    // Empty problems return the (uninitialized) outputs without launching.
    // The C == 0 term cannot be true here because C >= 1024 was checked above.
    if (M == 0 || C == 0 || Kw == 0 || Ka == 0 || Kg == 0) {
        return {w, a, g};
    }
    auto stream = at::cuda::getCurrentCUDAStream();
    // NOTE(review): both branches name the same <128, 4> instantiation with the
    // same arguments; any difference would be in the launch configuration,
    // which is not visible in this view.
    if (M == 1) {
        linear_wag_rank_out_f16_kernel<128, 4><<>>(
            M, C, Kw, Ka, Kg, w1.data_ptr(), a1.data_ptr(), g1.data_ptr(), w2_t.data_ptr(), a2_t.data_ptr(), g2_t.data_ptr(), w.data_ptr(), a.data_ptr(), g.data_ptr());
    } else {
        linear_wag_rank_out_f16_kernel<128, 4><<>>(
            M, C, Kw, Ka, Kg, w1.data_ptr(), a1.data_ptr(), g1.data_ptr(), w2_t.data_ptr(), a2_t.data_ptr(), g2_t.data_ptr(), w.data_ptr(), a.data_ptr(), g.data_ptr());
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return {w, a, g};
}

// Four-input variant of linear_wag_rank_out_f16_cuda. Besides the low-rank
// inputs and their weights it forwards three extra tensors (v, v_first, v0) to
// the kernel; how the kernel combines them is not visible here.
// Returns {w, a, g, v_out}, all with w1's shape and the last dimension replaced
// by C = w2_t.size(0).
std::vector linear_wagv_rank_out_f16_cuda(
    at::Tensor w1, at::Tensor a1, at::Tensor g1, at::Tensor v1, at::Tensor w2_t, at::Tensor a2_t, at::Tensor g2_t, at::Tensor v2_t, at::Tensor v, at::Tensor v_first, at::Tensor v0) {
    const
// (body of linear_wagv_rank_out_f16_cuda) Input ranks are the last dimensions
// of the low-rank tensors.
        int64_t kw64 = w1.size(-1);
    const int64_t ka64 = a1.size(-1);
    const int64_t kg64 = g1.size(-1);
    const int64_t kv64 = v1.size(-1);
    const int64_t c64 = w2_t.size(0);
    // NOTE(review): divides by zero if w1's last dimension is 0; the Kw == 0
    // early return only comes later.
    const int64_t m64 = w1.numel() / kw64;
    TORCH_CHECK(kw64 <= INT_MAX && ka64 <= INT_MAX && kg64 <= INT_MAX && kv64 <= INT_MAX && c64 <= INT_MAX && m64 <= INT_MAX, "linear_wagv_rank_out_f16 shape too large");
    const int Kw = static_cast(kw64);
    const int Ka = static_cast(ka64);
    const int Kg = static_cast(kg64);
    const int Kv = static_cast(kv64);
    const int C = static_cast(c64);
    const int M = static_cast(m64);
    TORCH_CHECK(Kw <= 512 && Ka <= 512 && Kg <= 512 && Kv <= 512 && C >= 1024 && M <= 4, "linear_wagv_rank_out_f16 supports only small-rank M<=4");
    // Every output has w1's shape with the last dimension replaced by C.
    std::vector out_sizes(w1.sizes().begin(), w1.sizes().end());
    out_sizes.back() = c64;
    auto w = at::empty(out_sizes, w1.options());
    auto a = at::empty(out_sizes, w1.options());
    auto g = at::empty(out_sizes, w1.options());
    auto v_out = at::empty(out_sizes, w1.options());
    // Empty problems return the (uninitialized) outputs without launching.
    if (M == 0 || C == 0 || Kw == 0 || Ka == 0 || Kg == 0 || Kv == 0) {
        return {w, a, g, v_out};
    }
    auto stream = at::cuda::getCurrentCUDAStream();
    // NOTE(review): both branches name the same <128, 4> instantiation with the
    // same arguments; any difference would be in the launch configuration,
    // which is not visible in this view.
    if (M == 1) {
        linear_wagv_rank_out_f16_kernel<128, 4><<>>(
            M, C, Kw, Ka, Kg, Kv, w1.data_ptr(), a1.data_ptr(), g1.data_ptr(), v1.data_ptr(), w2_t.data_ptr(), a2_t.data_ptr(), g2_t.data_ptr(), v2_t.data_ptr(), v.data_ptr(), v_first.data_ptr(), v0.data_ptr(), w.data_ptr(), a.data_ptr(), g.data_ptr(), v_out.data_ptr());
    } else {
        linear_wagv_rank_out_f16_kernel<128, 4><<>>(
            M, C, Kw, Ka, Kg, Kv, w1.data_ptr(), a1.data_ptr(), g1.data_ptr(), v1.data_ptr(), w2_t.data_ptr(), a2_t.data_ptr(), g2_t.data_ptr(), v2_t.data_ptr(), v.data_ptr(), v_first.data_ptr(), v0.data_ptr(), w.data_ptr(), a.data_ptr(), g.data_ptr(), v_out.data_ptr());
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return {w, a, g, v_out};
}

// Linear layer fused with a residual step that also reads v, v_first and v0.
//
// K = x.size(-1), N = weight_t.size(0), M = x.numel() / K. The output is
// allocated with at::empty_like(v), so it takes v's shape and options rather
// than being derived from x and N. How the kernel combines the projection with
// v / v_first / v0 is not visible here.
//
// Non-empty problems must satisfy K <= 512, N >= 1024 and M <= 4; raises via
// TORCH_CHECK otherwise, or when K, N or M exceed INT_MAX.
at::Tensor linear_t_vres_f16_cuda(at::Tensor x, at::Tensor weight_t, at::Tensor v, at::Tensor v_first, at::Tensor v0) {
    const int64_t k64 = x.size(-1);
    const int64_t n64 =
// (continuation of linear_t_vres_f16_cuda: N is the first dimension of weight_t)
        weight_t.size(0);
    TORCH_CHECK(k64 <= INT_MAX && n64 <= INT_MAX, "linear_t_vres_f16 K/N too large");
    const int K = static_cast(k64);
    const int N = static_cast(n64);
    // NOTE(review): this division runs before the `K == 0` early return below,
    // so an input whose last dimension is 0 divides by zero here.
    const int64_t m64 = x.numel() / k64;
    TORCH_CHECK(m64 <= INT_MAX, "linear_t_vres_f16 M too large");
    const int M = static_cast(m64);
    // The output mirrors v (shape and options), not x.
    // NOTE(review): v is presumably expected to hold M x N elements -- not
    // checked here.
    auto y = at::empty_like(v);
    // Empty problems return before the shape restriction below is enforced.
    if (M == 0 || N == 0 || K == 0) {
        return y;
    }
    auto stream = at::cuda::getCurrentCUDAStream();
    TORCH_CHECK(K <= 512 && N >= 1024 && M <= 4, "linear_t_vres_f16 currently supports only small-rank rank-out");
    // Single-row inputs use the scalar kernel variant.
    // NOTE(review): the launch configurations are not visible in this view.
    if (M == 1) {
        linear_t_vres_f16_ntile_scalar_kernel<128, 2><<>>(
            M, K, N, x.data_ptr(), weight_t.data_ptr(), v.data_ptr(), v_first.data_ptr(), v0.data_ptr(), y.data_ptr());
    } else {
        linear_t_vres_f16_ntile_kernel<128, 4><<>>(
            M, K, N, x.data_ptr(), weight_t.data_ptr(), v.data_ptr(), v_first.data_ptr(), v0.data_ptr(), y.data_ptr());
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return y;
}