From 3d7c4a424d2608baa565d02db3971b33f42fdbd9 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Mon, 10 Nov 2025 02:27:12 +0100 Subject: [PATCH 01/16] added custom gemm implementation --- CMakeLists.txt | 4 + src/kernels/gemm_mma.cu | 218 ++++++++++++++++++++++++++++++++ src/kernels/kernels.h | 3 + src/kernels/matmul.cpp | 37 +++++- src/kernels/tensor_core_utils.h | 144 +++++++++++++++++++++ src/testing/kernels/gemm.cpp | 143 +++++++++++++++++++++ train.cpp | 7 + 7 files changed, 550 insertions(+), 6 deletions(-) create mode 100644 src/kernels/gemm_mma.cu create mode 100644 src/kernels/tensor_core_utils.h create mode 100644 src/testing/kernels/gemm.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 965ae78..c6a040c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -178,6 +178,7 @@ add_library(llmq-common SHARED src/kernels/random.cu src/kernels/fill.cu src/kernels/convert.cu + src/kernels/gemm_mma.cu ) target_link_libraries(llmq-common PRIVATE ${CUFILE_LIBS} cudnn_frontend ${CUDNN_LIBRARY} nvidia::nccl fmt::fmt-header-only) @@ -194,6 +195,9 @@ target_link_libraries(train PRIVATE llmq-common CLI11::CLI11 fmt::fmt-header-onl add_executable(export-checkpoint export-checkpoint.cpp) target_link_libraries(export-checkpoint PRIVATE llmq-common CLI11::CLI11) +add_executable(gemm-test src/testing/kernels/gemm.cpp) +target_link_libraries(gemm-test PUBLIC llmq-common CUDA::cublas) + if (NOT SKBUILD) install(TARGETS llmq-common train export-checkpoint LIBRARY DESTINATION lib diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu new file mode 100644 index 0000000..6f35740 --- /dev/null +++ b/src/kernels/gemm_mma.cu @@ -0,0 +1,218 @@ +#include +#include +#include +#include "tensor_core_utils.h" +#include "utilities/vec.cuh" +#include +#undef NDEBUG +#include +#include + +unsigned div_ceil(unsigned a, unsigned b) { + return (a + b - 1) / b; +} + +template +using int_c = std::integral_constant; + +template +std::type_identity type_v = {}; + +template +__global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __restrict__ out, + const AType* __restrict__ a, const BType* __restrict__ b, + int m, int n, int k, const float* scale, bool accumulate, + std::type_identity acc_type) { + static_assert(sizeof(AType) == sizeof(BType), "index calculations assume sz(AType) == sz(BType)"); + // Note: you cannot change these numbers without breaking the kernel. + // they are here only for convenience, not to parametrize the algorithm. + // also note that some of the smaller loops have been unrolled by hand + // to have a nicer experience in ncu + constexpr int BI = 2; + constexpr int BJ = 2; + constexpr int WI = 4; + constexpr int WJ = 4; + constexpr int DEPTH = 4; + + constexpr int TI = 16; + constexpr int TJ = 16; + constexpr int TK = 2; // in units of uint4 + + int bi; + int bj; + if( m > n ) { + bi = blockIdx.y * BI * WI; + bj = blockIdx.x * BJ * WJ; + } else { + bi = blockIdx.x * BI * WI; + bj = blockIdx.y * BJ * WJ; + } + int i = bi + threadIdx.y * WI; + int j = bj + threadIdx.z * WJ; + + int wid = threadIdx.y + 2 * threadIdx.z; + constexpr int NW = BI * BJ; + + int stride = k / (sizeof(uint4)/sizeof(AType)); + + m16_n16_k32_c_fragment acc[WI][WJ]; + __shared__ uint4 input_tiles[2 * DEPTH * WI * BI * TI * TK]; + int2 offsets = ldmatrix_offsets(); + constexpr int PIPE_OFFSET = WI * BI * TI * TK; + constexpr int ROW_OFFSET = TI * TK; + + // instead of each thread computing addresses in both A and B, specialize 2 warps + // for A-loading and 2 warps for B-loading. + + // const uint4* a_ptr = reinterpret_cast(a) + (bi + wid) * TI * stride; + // const uint4* b_ptr = reinterpret_cast(b) + (bj + wid) * TJ * stride; + // uint4* as_store_ptr = input_tiles + wid * ROW_OFFSET; + // uint4* bs_store_ptr = input_tiles + wid * ROW_OFFSET + DEPTH * PIPE_OFFSET; + + const uint4* g_ptr; + uint4* s_ptr; + if(wid < 2) { + g_ptr = reinterpret_cast(a) + (bi + wid) * TI * stride; + s_ptr = input_tiles + wid * ROW_OFFSET; + } else { + g_ptr = reinterpret_cast(b) + (bj + wid - 2) * TJ * stride; + s_ptr = input_tiles + (wid - 2) * ROW_OFFSET + DEPTH * PIPE_OFFSET; + } + + global_to_shared_16_32_swizzle(&s_ptr, &g_ptr, stride); + + const uint4* as_load_ptr = input_tiles + offsets.x + threadIdx.y * WI*ROW_OFFSET; + const uint4* bs_load_ptr = input_tiles + offsets.y + threadIdx.z * WJ*ROW_OFFSET + DEPTH * PIPE_OFFSET; + + static_assert(WI * BI % NW == 0, "WI * BI must be divisible by the number of warps per block"); + static_assert(WJ * BJ % NW == 0, "WI * BI must be divisible by the number of warps per block"); + + auto loop_fraction = [&](auto stage_c, auto load_next_c, int ks) { + constexpr int load_stage = decltype(stage_c)::value; + constexpr int store_stage = (load_stage + 3) % DEPTH; + constexpr bool load_next = decltype(load_next_c)::value; + + m16_n16_k32_a_fragment a_frag[WI]; + ptx_ldmatrix(a_frag[0].v, as_load_ptr + 0 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[1].v, as_load_ptr + 1 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[2].v, as_load_ptr + 2 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[3].v, as_load_ptr + 3 * ROW_OFFSET + PIPE_OFFSET * load_stage); + + if constexpr(load_next) { + __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 0 * ROW_OFFSET, g_ptr + 0 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 2 * ROW_OFFSET, g_ptr + 2 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 4 * ROW_OFFSET, g_ptr + 4 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 6 * ROW_OFFSET, g_ptr + 6 * TI * stride + ks, 16); + } + __pipeline_commit(); + for(int jj = 0; jj < WJ; jj++) { + m16_n16_k32_b_fragment b_frag; + ptx_ldmatrix(b_frag.v, bs_load_ptr + jj * ROW_OFFSET + PIPE_OFFSET * load_stage); + for(int ii = 0; ii < WI; ii++) { + mma_m16_n16_k32_sync(acc[ii][jj], a_frag[ii], b_frag, acc[ii][jj]); + } + } + __pipeline_wait_prior(2); + __syncthreads(); + }; + + auto ldg_sts = [&](auto stage_c, int ks) { + constexpr int stage = decltype(stage_c)::value; + __pipeline_memcpy_async(s_ptr + stage * PIPE_OFFSET + 0 * ROW_OFFSET, g_ptr + 0 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + stage * PIPE_OFFSET + 2 * ROW_OFFSET, g_ptr + 2 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + stage * PIPE_OFFSET + 4 * ROW_OFFSET, g_ptr + 4 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + stage * PIPE_OFFSET + 6 * ROW_OFFSET, g_ptr + 6 * TI * stride + ks, 16); + }; + + // start up the pipeline + ldg_sts(int_c<0>{}, 0); + __pipeline_commit(); + ldg_sts(int_c<1>{}, 1 * TK); + __pipeline_commit(); + ldg_sts(int_c<2>{}, 2 * TK); + __pipeline_commit(); + + std::bool_constant true_v; + std::bool_constant false_v; + + __pipeline_wait_prior(2); + __syncthreads(); + + int ks = 0; + while (ks + 6 * TK < stride) { + loop_fraction(int_c<0>{}, true_v, ks + 3 * TK); + loop_fraction(int_c<1>{}, true_v, ks + 4 * TK); + loop_fraction(int_c<2>{}, true_v, ks + 5 * TK); + loop_fraction(int_c<3>{}, true_v, ks + 6 * TK); + + ks += 4 * TK; + } + + // last iteration + loop_fraction(int_c<0>{}, true_v, ks + 3 * TK); + loop_fraction(int_c<1>{}, false_v, ks + 4 * TK); + loop_fraction(int_c<2>{}, false_v, ks + 5 * TK); + loop_fraction(int_c<3>{}, false_v, ks + 6 * TK); + + if(scale != nullptr) { + for (auto ii = 0; ii < WI; ++ii) { + for (int jj = 0; jj < WJ; ++jj) { + acc[ii][jj].v[0] *= *scale; + acc[ii][jj].v[1] *= *scale; + acc[ii][jj].v[2] *= *scale; + acc[ii][jj].v[3] *= *scale; + acc[ii][jj].v[4] *= *scale; + acc[ii][jj].v[5] *= *scale; + acc[ii][jj].v[6] *= *scale; + acc[ii][jj].v[7] *= *scale; + } + } + } + + // note: loop_fraction ends with __syncthreads, so no need to sync here + nv_bfloat16* out_shared = reinterpret_cast(input_tiles) + (threadIdx.y + 2 * threadIdx.z) * TJ * TI; + + for(int ii = 0; ii < WI; ii++) { + for (int jj = 0; jj < WJ; jj++) { + store_fragment_row_major_sync(acc[ii][jj], out_shared, TJ); + __syncwarp(); + int c = threadIdx.x % 2; + int r = threadIdx.x / 2; + + if(accumulate) { + auto old = GenericVector::load(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c); + auto upd = GenericVector::load(out_shared + (c + 2 * r) * 8); + for(int l = 0; l < 8; ++l) { + old[l] += upd[l]; + } + old.store(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c); + } else { + uint4 load = reinterpret_cast(out_shared)[c + 2 * r]; + *reinterpret_cast(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c) = load; + } + } + } +} + +template +void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale, bool accumulate, std::type_identity, cudaStream_t stream) { + // our kernel is row-major, so to match cublas, we need to transpose everything => swapped a<->b, m<->n + dim3 grid{(unsigned)n / 128, (unsigned)m / 128, 1}; + if( n > m ) { + grid = {(unsigned)m / 128, (unsigned)n / 128, 1}; + } else { + grid = {(unsigned)n / 128, (unsigned)m / 128, 1}; + } + dim3 block{32, 2, 2}; + gemm_mma_tn_kernel<<>>(out, b, a, n, m, k, scale, accumulate, type_v); +} + +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream) { + gemm_mma_tn_launcher(out, a, b, m, n, k, scale, accumulate, type_v, stream); + assert(cudaGetLastError() == cudaSuccess); +} + +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream) { + gemm_mma_tn_launcher(out, a, b, m, n, k, scale, accumulate, type_v, stream); + assert(cudaGetLastError() == cudaSuccess); +} diff --git a/src/kernels/kernels.h b/src/kernels/kernels.h index 6266445..c10ba4e 100644 --- a/src/kernels/kernels.h +++ b/src/kernels/kernels.h @@ -20,6 +20,9 @@ struct Tensor; enum class ETensorDType: int; enum class EMMTranspose { TT, TN, NT, NN }; +enum class EMatmulBackend {CuBLAS, Custom}; + +EMatmulBackend& get_matmul_backend(); void encoder_forward(float* out, const int* inp, const float* wte, const float* wpe, int B, int T, int C, int V, cudaStream_t stream); void encoder_forward(nv_bfloat16* out, const int* inp, const nv_bfloat16* wte, const nv_bfloat16* wpe, int B, int T, int C, int V, cudaStream_t stream); diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index 7d67462..f6cb352 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -12,6 +12,11 @@ cublasComputeType_t cublas_compute = CUBLAS_COMPUTE_32F; +EMatmulBackend& get_matmul_backend() { + static EMatmulBackend backend = EMatmulBackend::CuBLAS; + return backend; +} + // ---------------------------------------------------------------------------- // Error checking @@ -151,40 +156,60 @@ void matmul_cublaslt(FloatC* d, const FloatA* a, const FloatB* b, const FloatBia CUDA_CHECK(cudaGetLastError()); } +// custom matmuls +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream); +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream); + + +template +void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* bias, + std::byte* workspace, std::size_t workspace_size, + int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, + const float* scale, EMMTranspose mode, bool accumulate) +{ + if(get_matmul_backend() == EMatmulBackend::CuBLAS || bias != nullptr || mode != EMMTranspose::TN) { + matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale, mode, accumulate); + } else if constexpr (std::is_same_v){ + gemm_mma_tn(d, a, b, m, n, k, scale, accumulate, stream); + } else { + matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale, mode, accumulate); + } +} + void matmul(float* c, const float* a, const float* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); } void matmul(float* c, const nv_bfloat16* a, const nv_bfloat16* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); } void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); } void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); } void matmul(nv_bfloat16* c, const nv_bfloat16* a, const nv_bfloat16* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); } void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); } void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, diff --git a/src/kernels/tensor_core_utils.h b/src/kernels/tensor_core_utils.h new file mode 100644 index 0000000..32b053e --- /dev/null +++ b/src/kernels/tensor_core_utils.h @@ -0,0 +1,144 @@ +#include "utilities/dtype.h" +#include +#include +#include +#include +#include + +template +struct m16_n16_k32_a_fragment { + uint4 v; +}; + +template +struct m16_n16_k32_b_fragment { + uint4 v; +}; + +template +struct m16_n16_k32_c_fragment { + AccDType v[8] = {0.f, 0.f, 0.f, 0.f}; +}; + +template +constexpr char ptx_type_name[] = "unknown_dtype"; + +template<> +constexpr char ptx_type_name[4] = "f32"; + +template<> +constexpr char ptx_type_name[4] = "f16"; + +template<> +constexpr char ptx_type_name[5] = "bf16"; + +template<> +constexpr char ptx_type_name<__nv_fp8_e4m3>[5] = "e4m3"; + +template<> +constexpr char ptx_type_name<__nv_fp8_e5m2>[5] = "e5m2"; + + +__device__ __forceinline__ void global_to_shared_16_32_swizzle(uint4** shared, const uint4** global, int stride) { + int col = threadIdx.x % 2; + int row = threadIdx.x / 2; + + int g8 = threadIdx.x / 8; + int t8 = threadIdx.x % 8; + + *shared = *shared + (t8 ^ g8) + 8 * g8; + *global = *global + row * stride + col; +} + +__device__ __forceinline__ int load_address(int row, int col) { + int lin = col + 2 * row; + int g8 = lin / 8; + int t8 = lin % 8; + return (t8 ^ g8) + 8 * g8; +} + +__device__ __forceinline__ int2 ldmatrix_offsets() { + int t8 = threadIdx.x % 8; + int g8 = threadIdx.x / 8; + int a = load_address(t8 + 8 * (g8%2), g8 / 2); + int b = load_address(t8 + 8 * (g8/2), g8 % 2); + return make_int2(a, b); +} + +__device__ __forceinline__ void ptx_ldmatrix(uint4& dst, const void* src) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "l"(__cvta_generic_to_shared(src)) + ); +} + +template +__device__ __forceinline__ void store_fragment_row_major_sync(m16_n16_k32_c_fragment& c, nv_bfloat16* ptr, int row_stride) { + int g4 = threadIdx.x / 4; + int c4 = threadIdx.x % 4; + nv_bfloat162* vptr = reinterpret_cast(ptr); + row_stride /= 2; + vptr[row_stride * (g4 + 0) + c4 + 0] = make_bfloat162((nv_bfloat16)c.v[0], (nv_bfloat16)c.v[1]); + vptr[row_stride * (g4 + 8) + c4 + 0] = make_bfloat162((nv_bfloat16)c.v[2], (nv_bfloat16)c.v[3]); + + vptr[row_stride * (g4 + 0) + c4 + 4] = make_bfloat162((nv_bfloat16)c.v[4], (nv_bfloat16)c.v[5]); + vptr[row_stride * (g4 + 8) + c4 + 4] = make_bfloat162((nv_bfloat16)c.v[6], (nv_bfloat16)c.v[7]); +} + +template +__device__ __forceinline__ void mma_m16_n16_k32_sync(m16_n16_k32_c_fragment& d, + m16_n16_k32_a_fragment a, + m16_n16_k32_b_fragment b, + m16_n16_k32_c_fragment c) { + static_assert(sizeof(AType) == sizeof(BType), "a and b type must have the same size"); + + constexpr int k = 32 / sizeof(AType); + asm volatile("mma.sync.aligned.m16n8k%26.row.col.f32.%24.%25.f32 " + "{%0, %1, %2, %3}," + "{%8, %9, %10, %11}," + "{%12, %13}," + "{%16, %17, %18, %19};\n" + "mma.sync.aligned.m16n8k%26.row.col.f32.%24.%25.f32 " + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%14, %15}," + "{%20, %21, %22, %23};\n" + : "=f"(d.v[0]), "=f"(d.v[1]), "=f"(d.v[2]), "=f"(d.v[3]), + "=f"(d.v[4]), "=f"(d.v[5]), "=f"(d.v[6]), "=f"(d.v[7]) + : "r"(a.v.x), "r"(a.v.y), "r"(a.v.z), "r"(a.v.w), + "r"(b.v.x), "r"(b.v.y), "r"(b.v.z), "r"(b.v.w), + "f"(c.v[0]), "f"(c.v[1]), "f"(c.v[2]), "f"(c.v[3]), + "f"(c.v[4]), "f"(c.v[5]), "f"(c.v[6]), "f"(c.v[7]), + "C"(ptx_type_name), "C"(ptx_type_name), "n"(k)); +} + +template +__device__ __forceinline__ void mma_m16_n16_k32_sync(m16_n16_k32_c_fragment& d, + m16_n16_k32_a_fragment a, + m16_n16_k32_b_fragment b, + m16_n16_k32_c_fragment c) { + auto to_raw = [](half& h) -> unsigned int& { + return *reinterpret_cast(&h); + }; + asm volatile("mma.sync.aligned.m16n8k32.row.col.f16.%10.%11.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};" + : "=r"(to_raw(d.v[0])), "=r"(to_raw(d.v[2])) + : "r"(a.v.x), "r"(a.v.y), "r"(a.v.z), "r"(a.v.w), + "r"(b.v.x), "r"(b.v.y), + "r"(to_raw(c.v[0])), "r"(to_raw(c.v[2])), + "C"(ptx_type_name), "C"(ptx_type_name)); + + asm volatile("mma.sync.aligned.m16n8k32.row.col.f16.%10.%11.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};" + : "=r"(to_raw(d.v[4])), "=r"(to_raw(d.v[6])) + : "r"(a.v.x), "r"(a.v.y), "r"(a.v.z), "r"(a.v.w), + "r"(b.v.z), "r"(b.v.w), + "r"(to_raw(c.v[4])), "r"(to_raw(c.v[6])), + "C"(ptx_type_name), "C"(ptx_type_name)); +} diff --git a/src/testing/kernels/gemm.cpp b/src/testing/kernels/gemm.cpp new file mode 100644 index 0000000..aa4a448 --- /dev/null +++ b/src/testing/kernels/gemm.cpp @@ -0,0 +1,143 @@ +#include "kernels/kernels.h" +#include "utilities/utils.h" +#include +#include +#include +#include +#include + +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream); +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream); + +template +extern void matmul_cublaslt(floatO* d, const floatX* a, const floatX* b, const floatB* bias, + std::byte* workspace, std::size_t workspace_size, + int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, + const float* scale, EMMTranspose mode, bool accumulate); + + +template +void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false) { + Atype* a; + Btype* b; + Ctype* c; + cudaMallocManaged(&a, m * k * sizeof(Atype)); + cudaMallocManaged(&b, n * k * sizeof(Btype)); + cudaMallocManaged(&c, m * n * sizeof(Ctype)); + cudaMemset(a, 0, m * k * sizeof(Atype)); + cudaMemset(c, 0, m * n * sizeof(Ctype)); + cudaMemset(b, 0, n * k * sizeof(Btype)); + + float* a_float = nullptr; + float* b_float = nullptr; + float* c_float = nullptr; + cudaMallocManaged(&a_float, m * k * sizeof(float)); + cudaMallocManaged(&b_float, n * k * sizeof(float)); + cudaMallocManaged(&c_float, m * n * sizeof(float)); + cudaMemset(a_float, 0, m * k * sizeof(float)); + cudaMemset(b_float, 0, n * k * sizeof(float)); + cudaMemset(c_float, 0, m * n * sizeof(float)); + + for(int i = 0; i < m; ++i) { + for(int j = 0; j < k; ++j) { + auto val = static_cast(rand() % 31 - 15); + a[i*k+j] = val; + a_float[i*k+j] = static_cast(val); + } + } + + for(int i = 0; i < n; ++i) { + for(int j = 0; j < k; ++j) { + auto val = static_cast(rand() % 31 - 15); + b[i*k+j] = val; + b_float[i*k+j] = static_cast(val); + } + } + + for(int i = 0; i < m; ++i) { + for(int j = 0; j < n; ++j) { + auto val = static_cast(rand() % 31 - 15); + c[i*n+j] = val; + c_float[i*n+j] = static_cast(val); + } + } + + float* scale_ptr; + cudaMallocManaged(&scale_ptr, sizeof(float)); + *scale_ptr = scale; + + cudaMemPrefetchAsync(a, m*k * sizeof(Atype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(b, n*k * sizeof(Btype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(c, m*n * sizeof(Ctype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(scale_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + + CUDA_CHECK(cudaDeviceSynchronize()); + gemm_mma_tn(c, a, b, m, n, k, scale_ptr, accumulate, 0); + CUDA_CHECK(cudaDeviceSynchronize()); + + cublasLtHandle_t handle; + std::byte* workspace; + size_t workspace_size = 128 * 1024 * 1024; + assert(cublasLtCreate(&handle) == CUBLAS_STATUS_SUCCESS); + cudaMalloc(&workspace, workspace_size); + setup_cublas(); + + + cudaMemPrefetchAsync(a_float, m*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + CUDA_CHECK(cudaDeviceSynchronize()); + + matmul_cublaslt(c_float, a_float, b_float, (float*)nullptr, workspace, workspace_size, m, n, k, nullptr, handle, scale_ptr, EMMTranspose::TN, accumulate); + CUDA_CHECK(cudaDeviceSynchronize()); + + double r_tol = std::is_same_v ? 1e-2 : 0.125; + + bool equal = true; + for(int i = 0; i < m; ++i) { + for(int j = 0; j < n; ++j) { + if(fabsf(c_float[j * m + i] - (float)c[j * m + i]) > std::max(r_tol * fabsf((float)c[j * m + i]), 1e-4)) { + printf("%d %d: %f != %f\n", i, j, (float) c_float[j * m + i], (float) c[j * m + i]); + equal = false; + } + } + if(!equal) { + break; + } + } + + if(equal) { + printf("PASS\n"); + } else { + printf("FAIL\n"); + } + + cudaFree(a); + cudaFree(b); + cudaFree(c); + cudaFree(a_float); + cudaFree(b_float); + cudaFree(c_float); + cudaFree(scale_ptr); + cudaFree(workspace); + cublasLtDestroy(handle); +} + +int main() { + int m = 1536; + int n = 1024; + int k = 1664; + + // larger shape for benchmarking + if (false) { + m = 2*4864; + n = 1024*8; + k = 896; + } + + run_test(m, n, k, 1.f, false); + run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, false); + + run_test(m, n, k, 1.f, true); + run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, true); +} diff --git a/train.cpp b/train.cpp index 8409676..f5a1882 100644 --- a/train.cpp +++ b/train.cpp @@ -93,6 +93,7 @@ struct TrainingRunner { int NGPUs = 0; bool MemcpyAllGather = false; bool MemcpySendRecv = false; + bool UseCustomMatmul = false; LLamaOptions Options; @@ -199,6 +200,8 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { app.add_flag("--memcpy-send-recv", MemcpySendRecv, "Use memcpy to perform send/receive (all-to-all). Currently only supported by the threads backend."); app.add_flag("--all-to-all-reduce", Options.UseAllToAllReduce, "Uses an all-to-all-based reduce algorithm. Combine with --memcpy-send-recv."); app.add_flag("--write-combined", Options.UseWriteCombined, "Uses write-combined memory for offloaded tensors."); + app.add_flag("--custom-matmul", UseCustomMatmul, "Use a self-written matmul instead of cublas. This is *not* going to be faster, this" + "option is mostly for the purists who want to minimize the dependencies.\n"); try { app.parse(argc, argv); @@ -206,6 +209,10 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { std::exit(app.exit(e)); } + if( UseCustomMatmul ) { + get_matmul_backend() = EMatmulBackend::Custom; + } + if (!std::filesystem::exists(ModelRootPath)) { if (ModelRootPath.find('/') != std::string::npos) { std::string hf_path = get_hf_model_files(ModelRootPath); From eb932af2a6e51db32c0532247e036f65a03a39ac Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 11 Nov 2025 13:48:49 +0100 Subject: [PATCH 02/16] added bias --- src/kernels/gemm_mma.cu | 30 +++++++++----- src/kernels/matmul.cpp | 10 ++--- src/testing/kernels/gemm.cpp | 78 ++++++++++++++++++++++++------------ 3 files changed, 77 insertions(+), 41 deletions(-) diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu index 6f35740..447e1fe 100644 --- a/src/kernels/gemm_mma.cu +++ b/src/kernels/gemm_mma.cu @@ -18,10 +18,12 @@ using int_c = std::integral_constant; template std::type_identity type_v = {}; -template +template __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __restrict__ out, const AType* __restrict__ a, const BType* __restrict__ b, - int m, int n, int k, const float* scale, bool accumulate, + int m, int n, int k, const float* __restrict__ scale, + const BiasType* __restrict__ bias, + bool accumulate, std::type_identity acc_type) { static_assert(sizeof(AType) == sizeof(BType), "index calculations assume sz(AType) == sz(BType)"); // Note: you cannot change these numbers without breaking the kernel. @@ -154,7 +156,7 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r loop_fraction(int_c<2>{}, false_v, ks + 5 * TK); loop_fraction(int_c<3>{}, false_v, ks + 6 * TK); - if(scale != nullptr) { + if(scale != nullptr && *scale != 1.f) { for (auto ii = 0; ii < WI; ++ii) { for (int jj = 0; jj < WJ; ++jj) { acc[ii][jj].v[0] *= *scale; @@ -186,6 +188,13 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r old[l] += upd[l]; } old.store(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c); + } else if (bias != nullptr) { + auto old = GenericVector::load(bias + (j + jj) * TJ + 8 * c); + auto upd = GenericVector::load(out_shared + (c + 2 * r) * 8); + for(int l = 0; l < 8; ++l) { + old[l] += (nv_bfloat16)upd[l]; + } + old.store(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c); } else { uint4 load = reinterpret_cast(out_shared)[c + 2 * r]; *reinterpret_cast(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c) = load; @@ -194,8 +203,9 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r } } -template -void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale, bool accumulate, std::type_identity, cudaStream_t stream) { +template +void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale, const BiasType* bias, + bool accumulate, std::type_identity, cudaStream_t stream) { // our kernel is row-major, so to match cublas, we need to transpose everything => swapped a<->b, m<->n dim3 grid{(unsigned)n / 128, (unsigned)m / 128, 1}; if( n > m ) { @@ -204,15 +214,15 @@ void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int grid = {(unsigned)n / 128, (unsigned)m / 128, 1}; } dim3 block{32, 2, 2}; - gemm_mma_tn_kernel<<>>(out, b, a, n, m, k, scale, accumulate, type_v); + gemm_mma_tn_kernel<<>>(out, b, a, n, m, k, scale, bias, accumulate, type_v); } -void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream) { - gemm_mma_tn_launcher(out, a, b, m, n, k, scale, accumulate, type_v, stream); +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) { + gemm_mma_tn_launcher(out, a, b, m, n, k, scale, bias, accumulate, type_v, stream); assert(cudaGetLastError() == cudaSuccess); } -void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream) { - gemm_mma_tn_launcher(out, a, b, m, n, k, scale, accumulate, type_v, stream); +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) { + gemm_mma_tn_launcher(out, a, b, m, n, k, scale, bias, accumulate, type_v, stream); assert(cudaGetLastError() == cudaSuccess); } diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index f6cb352..f8ddfe8 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -157,8 +157,8 @@ void matmul_cublaslt(FloatC* d, const FloatA* a, const FloatB* b, const FloatBia } // custom matmuls -void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream); -void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream); +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); template @@ -167,10 +167,10 @@ void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, const float* scale, EMMTranspose mode, bool accumulate) { - if(get_matmul_backend() == EMatmulBackend::CuBLAS || bias != nullptr || mode != EMMTranspose::TN) { + if(get_matmul_backend() == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN) { matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale, mode, accumulate); - } else if constexpr (std::is_same_v){ - gemm_mma_tn(d, a, b, m, n, k, scale, accumulate, stream); + } else if constexpr (std::is_same_v && std::is_same_v){ + gemm_mma_tn(d, a, b, m, n, k, scale, bias, accumulate, stream); } else { matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale, mode, accumulate); } diff --git a/src/testing/kernels/gemm.cpp b/src/testing/kernels/gemm.cpp index aa4a448..7bb60b5 100644 --- a/src/testing/kernels/gemm.cpp +++ b/src/testing/kernels/gemm.cpp @@ -6,9 +6,6 @@ #include #include -void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream); -void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, bool accumulate, cudaStream_t stream); - template extern void matmul_cublaslt(floatO* d, const floatX* a, const floatX* b, const floatB* bias, std::byte* workspace, std::size_t workspace_size, @@ -17,26 +14,32 @@ extern void matmul_cublaslt(floatO* d, const floatX* a, const floatX* b, const f template -void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false) { +void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, bool use_bias=false, bool check=true) { Atype* a; Btype* b; Ctype* c; + Ctype* bias; cudaMallocManaged(&a, m * k * sizeof(Atype)); cudaMallocManaged(&b, n * k * sizeof(Btype)); cudaMallocManaged(&c, m * n * sizeof(Ctype)); + cudaMallocManaged(&bias, m * sizeof(Ctype)); cudaMemset(a, 0, m * k * sizeof(Atype)); cudaMemset(c, 0, m * n * sizeof(Ctype)); cudaMemset(b, 0, n * k * sizeof(Btype)); + cudaMemset(bias, 0, m * sizeof(Ctype)); float* a_float = nullptr; float* b_float = nullptr; float* c_float = nullptr; + float* bias_float = nullptr; cudaMallocManaged(&a_float, m * k * sizeof(float)); cudaMallocManaged(&b_float, n * k * sizeof(float)); cudaMallocManaged(&c_float, m * n * sizeof(float)); + cudaMallocManaged(&bias_float, m* sizeof(float)); cudaMemset(a_float, 0, m * k * sizeof(float)); cudaMemset(b_float, 0, n * k * sizeof(float)); cudaMemset(c_float, 0, m * n * sizeof(float)); + cudaMemset(bias_float, 0, m * sizeof(float)); for(int i = 0; i < m; ++i) { for(int j = 0; j < k; ++j) { @@ -62,6 +65,12 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false) { } } + for(int i = 0; i < m; ++i) { + auto val = static_cast(rand() % 31 - 15); + bias[i] = val; + bias_float[i] = static_cast(val); + } + float* scale_ptr; cudaMallocManaged(&scale_ptr, sizeof(float)); *scale_ptr = scale; @@ -71,53 +80,67 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false) { cudaMemPrefetchAsync(c, m*n * sizeof(Ctype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); cudaMemPrefetchAsync(scale_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - CUDA_CHECK(cudaDeviceSynchronize()); - gemm_mma_tn(c, a, b, m, n, k, scale_ptr, accumulate, 0); - CUDA_CHECK(cudaDeviceSynchronize()); - cublasLtHandle_t handle; std::byte* workspace; size_t workspace_size = 128 * 1024 * 1024; assert(cublasLtCreate(&handle) == CUBLAS_STATUS_SUCCESS); cudaMalloc(&workspace, workspace_size); setup_cublas(); + get_matmul_backend() = EMatmulBackend::Custom; + + CUDA_CHECK(cudaDeviceSynchronize()); + matmul(c, a, b, use_bias ? bias : nullptr, scale_ptr, handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); + CUDA_CHECK(cudaDeviceSynchronize()); cudaMemPrefetchAsync(a_float, m*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(bias_float, m * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + get_matmul_backend() = EMatmulBackend::CuBLAS; CUDA_CHECK(cudaDeviceSynchronize()); - - matmul_cublaslt(c_float, a_float, b_float, (float*)nullptr, workspace, workspace_size, m, n, k, nullptr, handle, scale_ptr, EMMTranspose::TN, accumulate); + matmul(c_float, a_float, b_float, use_bias ? bias_float : nullptr, scale_ptr, handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); CUDA_CHECK(cudaDeviceSynchronize()); - double r_tol = std::is_same_v ? 1e-2 : 0.125; - - bool equal = true; - for(int i = 0; i < m; ++i) { - for(int j = 0; j < n; ++j) { - if(fabsf(c_float[j * m + i] - (float)c[j * m + i]) > std::max(r_tol * fabsf((float)c[j * m + i]), 1e-4)) { - printf("%d %d: %f != %f\n", i, j, (float) c_float[j * m + i], (float) c[j * m + i]); - equal = false; + if(check) { + double r_tol = 1e-2; + bool equal = true; + int approx_count = 0; + int far_count = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + float r_tol_max = r_tol * std::max(fabsf((float) c[j * m + i]), fabsf((float) c_float[j * m + i])); + float err = fabsf(c_float[j * m + i] - (float) c[j * m + i]); + float tol = std::max(r_tol_max, 1e-4f); + if (err > 10 * tol) { + printf(" %d %d: %f != %f\n", i, j, (float) c_float[j * m + i], (float) c[j * m + i]); + ++far_count; + } else if (err > tol) { + ++approx_count; + } + } + if (far_count > 0) { + break; } } - if(!equal) { - break; - } - } - if(equal) { - printf("PASS\n"); - } else { - printf("FAIL\n"); + if (far_count == 0 && approx_count == 0) { + printf("PASS\n"); + } else if(far_count < m * n / 100 && approx_count < m * n / 10) { + printf("CLOSE %d%%\n", 100 - (approx_count + far_count) * 100 / (m * n)); + } else { + printf("FAIL\n"); + } } cudaFree(a); cudaFree(b); cudaFree(c); + cudaFree(bias); cudaFree(a_float); cudaFree(b_float); cudaFree(c_float); + cudaFree(bias_float); cudaFree(scale_ptr); cudaFree(workspace); cublasLtDestroy(handle); @@ -140,4 +163,7 @@ int main() { run_test(m, n, k, 1.f, true); run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, true); + + run_test(m, n, k, 1.f, false, true); + run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, false, true); } From dfc21d3b69e10398172c3d4ac2e4d71feeeab786 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 11 Nov 2025 14:02:15 +0100 Subject: [PATCH 03/16] print warning when falling back --- src/kernels/matmul.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index f8ddfe8..c37a3c0 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -167,6 +167,12 @@ void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, const float* scale, EMMTranspose mode, bool accumulate) { + static bool warning = false; + if(get_matmul_backend() == EMatmulBackend::Custom && mode != EMMTranspose::TN && !warning) { + fprintf(stderr, "WARNING: Custom matmuls are not supported for non-TN mode! Falling back to cublas.\n"); + warning = true; + } + if(get_matmul_backend() == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN) { matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale, mode, accumulate); } else if constexpr (std::is_same_v && std::is_same_v){ From 5fd033f822f343f0c3dfcf2ac7b30bcb3d2d400a Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 11 Nov 2025 14:02:31 +0100 Subject: [PATCH 04/16] interleave scale and out writing --- src/kernels/gemm_mma.cu | 18 ++++++++---------- src/testing/kernels/gemm.cpp | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu index 447e1fe..10856fe 100644 --- a/src/kernels/gemm_mma.cu +++ b/src/kernels/gemm_mma.cu @@ -156,9 +156,14 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r loop_fraction(int_c<2>{}, false_v, ks + 5 * TK); loop_fraction(int_c<3>{}, false_v, ks + 6 * TK); - if(scale != nullptr && *scale != 1.f) { - for (auto ii = 0; ii < WI; ++ii) { - for (int jj = 0; jj < WJ; ++jj) { + // note: loop_fraction ends with __syncthreads, so no need to sync here + nv_bfloat16* out_shared = reinterpret_cast(input_tiles) + (threadIdx.y + 2 * threadIdx.z) * TJ * TI; + + for(int ii = 0; ii < WI; ii++) { + for (int jj = 0; jj < WJ; jj++) { + + // interleave scaling and output writing + if(scale != nullptr && *scale != 1.f) { acc[ii][jj].v[0] *= *scale; acc[ii][jj].v[1] *= *scale; acc[ii][jj].v[2] *= *scale; @@ -168,14 +173,7 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r acc[ii][jj].v[6] *= *scale; acc[ii][jj].v[7] *= *scale; } - } - } - - // note: loop_fraction ends with __syncthreads, so no need to sync here - nv_bfloat16* out_shared = reinterpret_cast(input_tiles) + (threadIdx.y + 2 * threadIdx.z) * TJ * TI; - for(int ii = 0; ii < WI; ii++) { - for (int jj = 0; jj < WJ; jj++) { store_fragment_row_major_sync(acc[ii][jj], out_shared, TJ); __syncwarp(); int c = threadIdx.x % 2; diff --git a/src/testing/kernels/gemm.cpp b/src/testing/kernels/gemm.cpp index 7bb60b5..655fa0f 100644 --- a/src/testing/kernels/gemm.cpp +++ b/src/testing/kernels/gemm.cpp @@ -127,7 +127,7 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b if (far_count == 0 && approx_count == 0) { printf("PASS\n"); } else if(far_count < m * n / 100 && approx_count < m * n / 10) { - printf("CLOSE %d%%\n", 100 - (approx_count + far_count) * 100 / (m * n)); + printf("CLOSE %d%% [%d+%d]\n", 100 - (approx_count + far_count) * 100 / (m * n), far_count, approx_count); } else { printf("FAIL\n"); } From 2df94131b04bd7dbfe2ac92a2fc55da4f5e56eff Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Wed, 12 Nov 2025 02:56:44 +0100 Subject: [PATCH 05/16] improve performance to match/slightly surpass cublas (in some configs) --- src/kernels/gemm_mma.cu | 45 +++++++++++++++++++++++++++--------- src/testing/kernels/gemm.cpp | 27 +++++++++++----------- 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu index 10856fe..91fbc04 100644 --- a/src/kernels/gemm_mma.cu +++ b/src/kernels/gemm_mma.cu @@ -89,17 +89,16 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r static_assert(WI * BI % NW == 0, "WI * BI must be divisible by the number of warps per block"); static_assert(WJ * BJ % NW == 0, "WI * BI must be divisible by the number of warps per block"); + m16_n16_k32_a_fragment a_frag[WI]; + m16_n16_k32_b_fragment b_frag[WI]; + auto loop_fraction = [&](auto stage_c, auto load_next_c, int ks) { - constexpr int load_stage = decltype(stage_c)::value; - constexpr int store_stage = (load_stage + 3) % DEPTH; + constexpr int stage = decltype(stage_c)::value; + constexpr int load_stage = (stage + 1) % DEPTH; + constexpr int store_stage = (stage + 3) % DEPTH; constexpr bool load_next = decltype(load_next_c)::value; - m16_n16_k32_a_fragment a_frag[WI]; - ptx_ldmatrix(a_frag[0].v, as_load_ptr + 0 * ROW_OFFSET + PIPE_OFFSET * load_stage); - ptx_ldmatrix(a_frag[1].v, as_load_ptr + 1 * ROW_OFFSET + PIPE_OFFSET * load_stage); - ptx_ldmatrix(a_frag[2].v, as_load_ptr + 2 * ROW_OFFSET + PIPE_OFFSET * load_stage); - ptx_ldmatrix(a_frag[3].v, as_load_ptr + 3 * ROW_OFFSET + PIPE_OFFSET * load_stage); - + // only load more inputs if we're not winding down the pipeline if constexpr(load_next) { __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 0 * ROW_OFFSET, g_ptr + 0 * TI * stride + ks, 16); __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 2 * ROW_OFFSET, g_ptr + 2 * TI * stride + ks, 16); @@ -108,14 +107,24 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r } __pipeline_commit(); for(int jj = 0; jj < WJ; jj++) { - m16_n16_k32_b_fragment b_frag; - ptx_ldmatrix(b_frag.v, bs_load_ptr + jj * ROW_OFFSET + PIPE_OFFSET * load_stage); for(int ii = 0; ii < WI; ii++) { - mma_m16_n16_k32_sync(acc[ii][jj], a_frag[ii], b_frag, acc[ii][jj]); + mma_m16_n16_k32_sync(acc[ii][jj], a_frag[ii], b_frag[jj], acc[ii][jj]); } } __pipeline_wait_prior(2); __syncthreads(); + + // only load more registers if this is not the last step + if constexpr(load_next || stage != 3) { + ptx_ldmatrix(a_frag[0].v, as_load_ptr + 0 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[1].v, as_load_ptr + 1 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[2].v, as_load_ptr + 2 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[3].v, as_load_ptr + 3 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(b_frag[0].v, bs_load_ptr + 0 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(b_frag[1].v, bs_load_ptr + 1 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(b_frag[2].v, bs_load_ptr + 2 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(b_frag[3].v, bs_load_ptr + 3 * ROW_OFFSET + PIPE_OFFSET * load_stage); + } }; auto ldg_sts = [&](auto stage_c, int ks) { @@ -140,6 +149,15 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r __pipeline_wait_prior(2); __syncthreads(); + ptx_ldmatrix(a_frag[0].v, as_load_ptr + 0 * ROW_OFFSET); + ptx_ldmatrix(a_frag[1].v, as_load_ptr + 1 * ROW_OFFSET); + ptx_ldmatrix(a_frag[2].v, as_load_ptr + 2 * ROW_OFFSET); + ptx_ldmatrix(a_frag[3].v, as_load_ptr + 3 * ROW_OFFSET); + ptx_ldmatrix(b_frag[0].v, bs_load_ptr + 0 * ROW_OFFSET); + ptx_ldmatrix(b_frag[1].v, bs_load_ptr + 1 * ROW_OFFSET); + ptx_ldmatrix(b_frag[2].v, bs_load_ptr + 2 * ROW_OFFSET); + ptx_ldmatrix(b_frag[3].v, bs_load_ptr + 3 * ROW_OFFSET); + int ks = 0; while (ks + 6 * TK < stride) { loop_fraction(int_c<0>{}, true_v, ks + 3 * TK); @@ -159,7 +177,12 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r // note: loop_fraction ends with __syncthreads, so no need to sync here nv_bfloat16* out_shared = reinterpret_cast(input_tiles) + (threadIdx.y + 2 * threadIdx.z) * TJ * TI; + // on 40xx, for some reason, these loops don't get unrolled, and then + // all the accumulators end up in local memory and we get terrible + // performance. + #pragma unroll for(int ii = 0; ii < WI; ii++) { + #pragma unroll for (int jj = 0; jj < WJ; jj++) { // interleave scaling and output writing diff --git a/src/testing/kernels/gemm.cpp b/src/testing/kernels/gemm.cpp index 655fa0f..9311978 100644 --- a/src/testing/kernels/gemm.cpp +++ b/src/testing/kernels/gemm.cpp @@ -92,17 +92,16 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b matmul(c, a, b, use_bias ? bias : nullptr, scale_ptr, handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); CUDA_CHECK(cudaDeviceSynchronize()); - - cudaMemPrefetchAsync(a_float, m*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(bias_float, m * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - get_matmul_backend() = EMatmulBackend::CuBLAS; - CUDA_CHECK(cudaDeviceSynchronize()); - matmul(c_float, a_float, b_float, use_bias ? bias_float : nullptr, scale_ptr, handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); - CUDA_CHECK(cudaDeviceSynchronize()); - if(check) { + cudaMemPrefetchAsync(a_float, m*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + cudaMemPrefetchAsync(bias_float, m * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + get_matmul_backend() = EMatmulBackend::CuBLAS; + CUDA_CHECK(cudaDeviceSynchronize()); + matmul(c_float, a_float, b_float, use_bias ? bias_float : nullptr, scale_ptr, handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); + CUDA_CHECK(cudaDeviceSynchronize()); + double r_tol = 1e-2; bool equal = true; int approx_count = 0; @@ -152,18 +151,20 @@ int main() { int k = 1664; // larger shape for benchmarking - if (false) { + if (true) { m = 2*4864; n = 1024*8; k = 896; } + std::swap(m, n); + run_test(m, n, k, 1.f, false); run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, false); - +/* run_test(m, n, k, 1.f, true); run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, true); run_test(m, n, k, 1.f, false, true); - run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, false, true); + run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, false, true);*/ } From f3d595c193c5d2a4ac8aa6d3b287d8463ca123c2 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 21 Feb 2026 12:00:16 +0100 Subject: [PATCH 06/16] handle a and b scale --- src/kernels/gemm_mma.cu | 45 ++++++++++++++++++++++++++--------------- src/kernels/matmul.cpp | 12 +++++------ 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu index 91fbc04..4a2dd18 100644 --- a/src/kernels/gemm_mma.cu +++ b/src/kernels/gemm_mma.cu @@ -21,7 +21,8 @@ std::type_identity type_v = {}; template __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __restrict__ out, const AType* __restrict__ a, const BType* __restrict__ b, - int m, int n, int k, const float* __restrict__ scale, + int m, int n, int k, + const float* __restrict__ scale_a, const float* __restrict__ scale_b, const BiasType* __restrict__ bias, bool accumulate, std::type_identity acc_type) { @@ -177,6 +178,14 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r // note: loop_fraction ends with __syncthreads, so no need to sync here nv_bfloat16* out_shared = reinterpret_cast(input_tiles) + (threadIdx.y + 2 * threadIdx.z) * TJ * TI; + float scale = 1.f; + if (scale_a != nullptr) { + scale = *scale_a; + } + if (scale_b != nullptr) { + scale *= *scale_b; + } + // on 40xx, for some reason, these loops don't get unrolled, and then // all the accumulators end up in local memory and we get terrible // performance. @@ -186,15 +195,15 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r for (int jj = 0; jj < WJ; jj++) { // interleave scaling and output writing - if(scale != nullptr && *scale != 1.f) { - acc[ii][jj].v[0] *= *scale; - acc[ii][jj].v[1] *= *scale; - acc[ii][jj].v[2] *= *scale; - acc[ii][jj].v[3] *= *scale; - acc[ii][jj].v[4] *= *scale; - acc[ii][jj].v[5] *= *scale; - acc[ii][jj].v[6] *= *scale; - acc[ii][jj].v[7] *= *scale; + if(scale != 1.f) { + acc[ii][jj].v[0] *= scale; + acc[ii][jj].v[1] *= scale; + acc[ii][jj].v[2] *= scale; + acc[ii][jj].v[3] *= scale; + acc[ii][jj].v[4] *= scale; + acc[ii][jj].v[5] *= scale; + acc[ii][jj].v[6] *= scale; + acc[ii][jj].v[7] *= scale; } store_fragment_row_major_sync(acc[ii][jj], out_shared, TJ); @@ -225,7 +234,7 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r } template -void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale, const BiasType* bias, +void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale_a, const float* scale_b, const BiasType* bias, bool accumulate, std::type_identity, cudaStream_t stream) { // our kernel is row-major, so to match cublas, we need to transpose everything => swapped a<->b, m<->n dim3 grid{(unsigned)n / 128, (unsigned)m / 128, 1}; @@ -235,15 +244,19 @@ void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int grid = {(unsigned)n / 128, (unsigned)m / 128, 1}; } dim3 block{32, 2, 2}; - gemm_mma_tn_kernel<<>>(out, b, a, n, m, k, scale, bias, accumulate, type_v); + gemm_mma_tn_kernel<<>>(out, b, a, n, m, k, scale_a, scale_b, bias, accumulate, type_v); } -void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) { - gemm_mma_tn_launcher(out, a, b, m, n, k, scale, bias, accumulate, type_v, stream); +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, + const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) +{ + gemm_mma_tn_launcher(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, type_v, stream); assert(cudaGetLastError() == cudaSuccess); } -void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) { - gemm_mma_tn_launcher(out, a, b, m, n, k, scale, bias, accumulate, type_v, stream); +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, + const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) +{ + gemm_mma_tn_launcher(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, type_v, stream); assert(cudaGetLastError() == cudaSuccess); } diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index c37a3c0..ddacd24 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -157,15 +157,15 @@ void matmul_cublaslt(FloatC* d, const FloatA* a, const FloatB* b, const FloatBia } // custom matmuls -void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); -void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); template void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* bias, std::byte* workspace, std::size_t workspace_size, int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, - const float* scale, EMMTranspose mode, bool accumulate) + const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate) { static bool warning = false; if(get_matmul_backend() == EMatmulBackend::Custom && mode != EMMTranspose::TN && !warning) { @@ -174,11 +174,11 @@ void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* } if(get_matmul_backend() == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN) { - matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale, mode, accumulate); + matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale_a, scale_b, mode, accumulate); } else if constexpr (std::is_same_v && std::is_same_v){ - gemm_mma_tn(d, a, b, m, n, k, scale, bias, accumulate, stream); + gemm_mma_tn(d, a, b, m, n, k, scale_a, scale_b, bias, accumulate, stream); } else { - matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale, mode, accumulate); + matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale_a, scale_b, mode, accumulate); } } From ddb41cb6a9bd4ac18da222f39f478b6303cecacd Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 21 Feb 2026 13:02:31 +0100 Subject: [PATCH 07/16] cleanups --- src/kernels/gemm_mma.cu | 18 +++++++++++------- src/kernels/matmul.cpp | 3 ++- src/kernels/tensor_core_utils.h | 5 +++++ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu index 4a2dd18..f2a735c 100644 --- a/src/kernels/gemm_mma.cu +++ b/src/kernels/gemm_mma.cu @@ -4,10 +4,10 @@ #include "tensor_core_utils.h" #include "utilities/vec.cuh" #include -#undef NDEBUG -#include #include +#include "utilities/utils.h" + unsigned div_ceil(unsigned a, unsigned b) { return (a + b - 1) / b; } @@ -236,12 +236,16 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r template void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale_a, const float* scale_b, const BiasType* bias, bool accumulate, std::type_identity, cudaStream_t stream) { + if (n % 128 != 0 || m % 128 != 0) { + throw std::invalid_argument("gemm_mma_tn_launcher: n and m must be divisible by 128"); + } + // our kernel is row-major, so to match cublas, we need to transpose everything => swapped a<->b, m<->n - dim3 grid{(unsigned)n / 128, (unsigned)m / 128, 1}; + dim3 grid; if( n > m ) { - grid = {(unsigned)m / 128, (unsigned)n / 128, 1}; + grid = {(unsigned)div_exact(m, 128), (unsigned)div_exact(n, 128), 1}; } else { - grid = {(unsigned)n / 128, (unsigned)m / 128, 1}; + grid = {(unsigned)div_exact(n, 128), (unsigned)div_exact(m, 128), 1}; } dim3 block{32, 2, 2}; gemm_mma_tn_kernel<<>>(out, b, a, n, m, k, scale_a, scale_b, bias, accumulate, type_v); @@ -251,12 +255,12 @@ void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) { gemm_mma_tn_launcher(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, type_v, stream); - assert(cudaGetLastError() == cudaSuccess); + CUDA_CHECK(cudaGetLastError()); } void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) { gemm_mma_tn_launcher(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, type_v, stream); - assert(cudaGetLastError() == cudaSuccess); + CUDA_CHECK(cudaGetLastError()); } diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index ddacd24..4242dc6 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -3,6 +3,7 @@ // // Based on llm.c https://github.com/karpathy/llm.c +#include #include #include @@ -167,7 +168,7 @@ void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate) { - static bool warning = false; + static std::atomic warning = false; if(get_matmul_backend() == EMatmulBackend::Custom && mode != EMMTranspose::TN && !warning) { fprintf(stderr, "WARNING: Custom matmuls are not supported for non-TN mode! Falling back to cublas.\n"); warning = true; diff --git a/src/kernels/tensor_core_utils.h b/src/kernels/tensor_core_utils.h index 32b053e..47bb361 100644 --- a/src/kernels/tensor_core_utils.h +++ b/src/kernels/tensor_core_utils.h @@ -1,3 +1,6 @@ +#ifndef LLMQ_TENSOR_CORE_UTILS_CUH +#define LLMQ_TENSOR_CORE_UTILS_CUH + #include "utilities/dtype.h" #include #include @@ -142,3 +145,5 @@ __device__ __forceinline__ void mma_m16_n16_k32_sync(m16_n16_k32_c_fragment), "C"(ptx_type_name)); } + +#endif From 168a2df6f8b221b1d27e44fecf0313e81c61b69e Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 21 Feb 2026 14:19:20 +0100 Subject: [PATCH 08/16] move test file --- src/testing/{kernels/gemm.cpp => test-gemm.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/testing/{kernels/gemm.cpp => test-gemm.cpp} (100%) diff --git a/src/testing/kernels/gemm.cpp b/src/testing/test-gemm.cpp similarity index 100% rename from src/testing/kernels/gemm.cpp rename to src/testing/test-gemm.cpp From 9a3374f8f0e44748bd36109b8b8afade3acbf729 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 21 Feb 2026 14:54:53 +0100 Subject: [PATCH 09/16] fix existing rope test --- src/testing/test-rope.cu | 6 +++--- src/testing/test_utils.h | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/testing/test-rope.cu b/src/testing/test-rope.cu index d72b2e9..461f1d9 100644 --- a/src/testing/test-rope.cu +++ b/src/testing/test-rope.cu @@ -204,10 +204,10 @@ TEST_CASE("rope forward/backward bfloat16 matches CPU (emulated)", "[rope][bf16] std::vector h_inp_f = uniform_host(size_inp, -1.f, 1.f, 1337ull); std::vector h_inp_bf16 = to_bf16(h_inp_f); - // Prepare freqs and quantize to bf16 (kernel expects bf16 freqs as well) + // Prepare freqs and quantize to fp16 (kernel expects fp16 freqs) std::vector h_freqs_f(size_freqs); precompute_freqs_cis(h_freqs_f.data(), HD, T, 10000.0f); - std::vector h_freqs_bf16 = to_bf16(h_freqs_f); + std::vector h_freqs_fp16 = to_fp16(h_freqs_f); // CPU baseline with bf16 emulation: quantize inputs/freqs to bf16, do math in float, quantize outputs std::vector h_inp_q = round_bf16(h_inp_f); @@ -222,7 +222,7 @@ TEST_CASE("rope forward/backward bfloat16 matches CPU (emulated)", "[rope][bf16] thrust::device_vector d_inp = to_device(h_inp_bf16); thrust::device_vector d_out(size_inp); thrust::device_vector d_dinp(size_inp); - thrust::device_vector d_freqs = to_device(h_freqs_bf16); + thrust::device_vector d_freqs = to_device(h_freqs_fp16); rope_forward(thrust::raw_pointer_cast(d_out.data()), thrust::raw_pointer_cast(d_inp.data()), diff --git a/src/testing/test_utils.h b/src/testing/test_utils.h index dc1e8d5..5bec8b5 100644 --- a/src/testing/test_utils.h +++ b/src/testing/test_utils.h @@ -62,6 +62,13 @@ inline std::vector to_bf16(const std::vector& vec) { } return result; } +inline std::vector to_fp16(const std::vector& vec) { + std::vector result(vec.size()); + for(size_t i = 0; i < vec.size(); ++i) { + result[i] = half(vec[i]); + } + return result; +} inline std::vector round_bf16(const std::vector& vec) { std::vector result(vec.size()); From 0c1e6e14e3e9316c8ae1d6c0445575018b01caf9 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 21 Feb 2026 19:22:30 +0100 Subject: [PATCH 10/16] fix gemm test --- src/testing/test-gemm.cpp | 190 +++++++++++++++++++++++++++----------- 1 file changed, 138 insertions(+), 52 deletions(-) diff --git a/src/testing/test-gemm.cpp b/src/testing/test-gemm.cpp index 9311978..9fc0012 100644 --- a/src/testing/test-gemm.cpp +++ b/src/testing/test-gemm.cpp @@ -3,14 +3,19 @@ #include #include #include + +#include + #include #include +#include "test_config.h" + template extern void matmul_cublaslt(floatO* d, const floatX* a, const floatX* b, const floatB* bias, std::byte* workspace, std::size_t workspace_size, int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, - const float* scale, EMMTranspose mode, bool accumulate); + const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate); template @@ -19,27 +24,27 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b Btype* b; Ctype* c; Ctype* bias; - cudaMallocManaged(&a, m * k * sizeof(Atype)); - cudaMallocManaged(&b, n * k * sizeof(Btype)); - cudaMallocManaged(&c, m * n * sizeof(Ctype)); - cudaMallocManaged(&bias, m * sizeof(Ctype)); - cudaMemset(a, 0, m * k * sizeof(Atype)); - cudaMemset(c, 0, m * n * sizeof(Ctype)); - cudaMemset(b, 0, n * k * sizeof(Btype)); - cudaMemset(bias, 0, m * sizeof(Ctype)); + CUDA_CHECK(cudaMallocManaged(&a, m * k * sizeof(Atype))); + CUDA_CHECK(cudaMallocManaged(&b, n * k * sizeof(Btype))); + CUDA_CHECK(cudaMallocManaged(&c, m * n * sizeof(Ctype))); + CUDA_CHECK(cudaMallocManaged(&bias, n * sizeof(Ctype))); + CUDA_CHECK(cudaMemset(a, 0, m * k * sizeof(Atype))); + CUDA_CHECK(cudaMemset(c, 0, m * n * sizeof(Ctype))); + CUDA_CHECK(cudaMemset(b, 0, n * k * sizeof(Btype))); + CUDA_CHECK(cudaMemset(bias, 0, n * sizeof(Ctype))); float* a_float = nullptr; float* b_float = nullptr; float* c_float = nullptr; float* bias_float = nullptr; - cudaMallocManaged(&a_float, m * k * sizeof(float)); - cudaMallocManaged(&b_float, n * k * sizeof(float)); - cudaMallocManaged(&c_float, m * n * sizeof(float)); - cudaMallocManaged(&bias_float, m* sizeof(float)); - cudaMemset(a_float, 0, m * k * sizeof(float)); - cudaMemset(b_float, 0, n * k * sizeof(float)); - cudaMemset(c_float, 0, m * n * sizeof(float)); - cudaMemset(bias_float, 0, m * sizeof(float)); + CUDA_CHECK(cudaMallocManaged(&a_float, m * k * sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&b_float, n * k * sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&c_float, m * n * sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&bias_float, m* sizeof(float))); + CUDA_CHECK(cudaMemset(a_float, 0, m * k * sizeof(float))); + CUDA_CHECK(cudaMemset(b_float, 0, n * k * sizeof(float))); + CUDA_CHECK(cudaMemset(c_float, 0, m * n * sizeof(float))); + CUDA_CHECK(cudaMemset(bias_float, 0, m * sizeof(float))); for(int i = 0; i < m; ++i) { for(int j = 0; j < k; ++j) { @@ -71,35 +76,38 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b bias_float[i] = static_cast(val); } - float* scale_ptr; - cudaMallocManaged(&scale_ptr, sizeof(float)); - *scale_ptr = scale; + float* scale_a_ptr, *scale_b_ptr; + CUDA_CHECK(cudaMallocManaged(&scale_a_ptr, sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&scale_b_ptr, sizeof(float))); + *scale_a_ptr = sqrtf(scale); + *scale_b_ptr = sqrtf(scale); - cudaMemPrefetchAsync(a, m*k * sizeof(Atype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(b, n*k * sizeof(Btype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(c, m*n * sizeof(Ctype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(scale_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + CUDA_CHECK(cudaMemPrefetchAsync(a, m*k * sizeof(Atype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(b, n*k * sizeof(Btype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(c, m*n * sizeof(Ctype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(scale_a_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(scale_b_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); cublasLtHandle_t handle; std::byte* workspace; size_t workspace_size = 128 * 1024 * 1024; assert(cublasLtCreate(&handle) == CUBLAS_STATUS_SUCCESS); cudaMalloc(&workspace, workspace_size); - setup_cublas(); get_matmul_backend() = EMatmulBackend::Custom; CUDA_CHECK(cudaDeviceSynchronize()); - matmul(c, a, b, use_bias ? bias : nullptr, scale_ptr, handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); + matmul(c, a, b, use_bias ? bias : nullptr, scale_a_ptr, scale_b_ptr, handle, workspace, workspace_size, m, n, k, EMMTranspose::TN, accumulate, nullptr); CUDA_CHECK(cudaDeviceSynchronize()); if(check) { - cudaMemPrefetchAsync(a_float, m*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); - cudaMemPrefetchAsync(bias_float, m * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0); + CUDA_CHECK(cudaMemPrefetchAsync(a_float, m*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(bias_float, m * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); get_matmul_backend() = EMatmulBackend::CuBLAS; CUDA_CHECK(cudaDeviceSynchronize()); - matmul(c_float, a_float, b_float, use_bias ? bias_float : nullptr, scale_ptr, handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); + matmul(c_float, a_float, b_float, use_bias ? bias_float : nullptr, nullptr, nullptr, + handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); CUDA_CHECK(cudaDeviceSynchronize()); double r_tol = 1e-2; @@ -108,11 +116,13 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b int far_count = 0; for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { - float r_tol_max = r_tol * std::max(fabsf((float) c[j * m + i]), fabsf((float) c_float[j * m + i])); - float err = fabsf(c_float[j * m + i] - (float) c[j * m + i]); + float expected = c_float[j * m + i] * (*scale_a_ptr) * (*scale_b_ptr); + float received = (float) c[j * m + i]; + float r_tol_max = r_tol * std::max(fabsf((float) c[j * m + i]), fabsf(expected)); + float err = fabsf(expected - received); float tol = std::max(r_tol_max, 1e-4f); if (err > 10 * tol) { - printf(" %d %d: %f != %f\n", i, j, (float) c_float[j * m + i], (float) c[j * m + i]); + printf(" %d %d: %f != %f\n", i, j, expected, received); ++far_count; } else if (err > tol) { ++approx_count; @@ -140,31 +150,107 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b cudaFree(b_float); cudaFree(c_float); cudaFree(bias_float); - cudaFree(scale_ptr); + cudaFree(scale_a_ptr); + cudaFree(scale_b_ptr); cudaFree(workspace); cublasLtDestroy(handle); } -int main() { - int m = 1536; - int n = 1024; - int k = 1664; +TEST_CASE("tiny matmul bfloat16 x bfloat16 -> bfloat16", "[gemm][bf16]") { + bool accumulate = false; + bool bias = false; + SECTION("set-nobias") { + accumulate = false; + bias = false; + } + SECTION("set-bias") { + accumulate = false; + bias = true; + } + SECTION("accumulate-nobias") { + accumulate = true; + bias = false; + } + SECTION("accumulate-bias") { + accumulate = true; + bias = true; + } + run_test(128, 128, 128, 1.0f, accumulate, bias); +} - // larger shape for benchmarking - if (true) { - m = 2*4864; - n = 1024*8; - k = 896; +TEST_CASE("tiny matmul fp8 x fp8 -> bfloat16", "[gemm][fp8]") { + bool accumulate = false; + bool bias = false; + SECTION("set-nobias") { + accumulate = false; + bias = false; + } + SECTION("set-bias") { + accumulate = false; + bias = true; + } + SECTION("accumulate-nobias") { + accumulate = true; + bias = false; + } + SECTION("accumulate-bias") { + accumulate = true; + bias = true; } - std::swap(m, n); + run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(128, 128, 128, 4.0f / 128, accumulate, bias); +} - run_test(m, n, k, 1.f, false); - run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, false); -/* - run_test(m, n, k, 1.f, true); - run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, true); +TEST_CASE("matmul bfloat16 x bfloat16 -> bfloat16", "[gemm][bf16]") { + const auto& cfg = testing_config::get_test_config(); + int m = cfg.B * cfg.T; + int k = cfg.C; + int n = div_ceil(2 * m / 3, 128) * 128; - run_test(m, n, k, 1.f, false, true); - run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0/k, false, true);*/ + bool accumulate = false; + bool bias = false; + SECTION("set-nobias") { + accumulate = false; + bias = false; + } + SECTION("set-bias") { + accumulate = false; + bias = true; + } + SECTION("accumulate-nobias") { + accumulate = true; + bias = false; + } + SECTION("accumulate-bias") { + accumulate = true; + bias = true; + } + run_test(m, n, k, 1.0f, accumulate, bias); +} + +TEST_CASE("matmul fp8 x fp8 -> bfloat16", "[gemm][fp8]") { + const auto& cfg = testing_config::get_test_config(); + int m = cfg.B * cfg.T; + int k = cfg.C; + int n = div_ceil(2 * m / 3, 128) * 128; + + bool accumulate = false; + bool bias = false; + SECTION("set-nobias") { + accumulate = false; + bias = false; + } + SECTION("set-bias") { + accumulate = false; + bias = true; + } + SECTION("accumulate-nobias") { + accumulate = true; + bias = false; + } + SECTION("accumulate-bias") { + accumulate = true; + bias = true; + } + run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0f / k, accumulate, bias); } From 01f8c49b6470e9a24efd33701da8f7adee4f3702 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 21 Feb 2026 19:40:56 +0100 Subject: [PATCH 11/16] more fixes --- CMakeLists.txt | 4 +--- src/kernels/gemm_mma.cu | 8 ++++---- src/kernels/matmul.cpp | 6 ++++-- src/kernels/tensor_core_utils.h | 6 +++++- src/testing/test-gemm.cpp | 34 ++++++++++++++++++--------------- train.cpp | 5 +++-- 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c6a040c..bac1314 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,9 +195,6 @@ target_link_libraries(train PRIVATE llmq-common CLI11::CLI11 fmt::fmt-header-onl add_executable(export-checkpoint export-checkpoint.cpp) target_link_libraries(export-checkpoint PRIVATE llmq-common CLI11::CLI11) -add_executable(gemm-test src/testing/kernels/gemm.cpp) -target_link_libraries(gemm-test PUBLIC llmq-common CUDA::cublas) - if (NOT SKBUILD) install(TARGETS llmq-common train export-checkpoint LIBRARY DESTINATION lib @@ -260,6 +257,7 @@ if (BUILD_TESTS) src/testing/test-swiglu.cu src/testing/test-rope.cu src/testing/test-classifier.cu + src/testing/test-gemm.cpp ) target_link_libraries(unit-tests PRIVATE llmq-common Catch2::Catch2 CLI11::CLI11) target_compile_options(unit-tests PUBLIC $<$:--expt-relaxed-constexpr> $<$:-lineinfo>) diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu index f2a735c..0cc5096 100644 --- a/src/kernels/gemm_mma.cu +++ b/src/kernels/gemm_mma.cu @@ -1,3 +1,7 @@ +// Copyright (c) 2026, IST Austria, developed by Erik Schultheis +// SPDX-License-Identifier: Apache-2.0 +// + #include #include #include @@ -8,10 +12,6 @@ #include "utilities/utils.h" -unsigned div_ceil(unsigned a, unsigned b) { - return (a + b - 1) / b; -} - template using int_c = std::integral_constant; diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index 4242dc6..36f3613 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -14,6 +14,7 @@ cublasComputeType_t cublas_compute = CUBLAS_COMPUTE_32F; EMatmulBackend& get_matmul_backend() { + // TODO: this is global state right now. Ideally, we could make this local. static EMatmulBackend backend = EMatmulBackend::CuBLAS; return backend; } @@ -168,8 +169,9 @@ void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate) { - static std::atomic warning = false; - if(get_matmul_backend() == EMatmulBackend::Custom && mode != EMMTranspose::TN && !warning) { + static std::atomic warning{false}; + bool expected = false; + if(get_matmul_backend() == EMatmulBackend::Custom && mode != EMMTranspose::TN && warning.compare_exchange_strong(expected, true)) { fprintf(stderr, "WARNING: Custom matmuls are not supported for non-TN mode! Falling back to cublas.\n"); warning = true; } diff --git a/src/kernels/tensor_core_utils.h b/src/kernels/tensor_core_utils.h index 47bb361..9e29f20 100644 --- a/src/kernels/tensor_core_utils.h +++ b/src/kernels/tensor_core_utils.h @@ -1,3 +1,7 @@ +// Copyright (c) 2026, IST Austria, developed by Erik Schultheis +// SPDX-License-Identifier: Apache-2.0 +// + #ifndef LLMQ_TENSOR_CORE_UTILS_CUH #define LLMQ_TENSOR_CORE_UTILS_CUH @@ -20,7 +24,7 @@ struct m16_n16_k32_b_fragment { template struct m16_n16_k32_c_fragment { - AccDType v[8] = {0.f, 0.f, 0.f, 0.f}; + AccDType v[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; }; template diff --git a/src/testing/test-gemm.cpp b/src/testing/test-gemm.cpp index 9fc0012..98510b3 100644 --- a/src/testing/test-gemm.cpp +++ b/src/testing/test-gemm.cpp @@ -1,3 +1,7 @@ +// Copyright (c) 2026, IST Austria, developed by Erik Schultheis +// SPDX-License-Identifier: Apache-2.0 +// + #include "kernels/kernels.h" #include "utilities/utils.h" #include @@ -40,11 +44,11 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b CUDA_CHECK(cudaMallocManaged(&a_float, m * k * sizeof(float))); CUDA_CHECK(cudaMallocManaged(&b_float, n * k * sizeof(float))); CUDA_CHECK(cudaMallocManaged(&c_float, m * n * sizeof(float))); - CUDA_CHECK(cudaMallocManaged(&bias_float, m* sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&bias_float, n * sizeof(float))); CUDA_CHECK(cudaMemset(a_float, 0, m * k * sizeof(float))); CUDA_CHECK(cudaMemset(b_float, 0, n * k * sizeof(float))); CUDA_CHECK(cudaMemset(c_float, 0, m * n * sizeof(float))); - CUDA_CHECK(cudaMemset(bias_float, 0, m * sizeof(float))); + CUDA_CHECK(cudaMemset(bias_float, 0, n * sizeof(float))); for(int i = 0; i < m; ++i) { for(int j = 0; j < k; ++j) { @@ -70,7 +74,7 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b } } - for(int i = 0; i < m; ++i) { + for(int i = 0; i < n; ++i) { auto val = static_cast(rand() % 31 - 15); bias[i] = val; bias_float[i] = static_cast(val); @@ -103,7 +107,7 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b CUDA_CHECK(cudaMemPrefetchAsync(a_float, m*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); CUDA_CHECK(cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); CUDA_CHECK(cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); - CUDA_CHECK(cudaMemPrefetchAsync(bias_float, m * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(bias_float, n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); get_matmul_backend() = EMatmulBackend::CuBLAS; CUDA_CHECK(cudaDeviceSynchronize()); matmul(c_float, a_float, b_float, use_bias ? bias_float : nullptr, nullptr, nullptr, @@ -142,17 +146,17 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b } } - cudaFree(a); - cudaFree(b); - cudaFree(c); - cudaFree(bias); - cudaFree(a_float); - cudaFree(b_float); - cudaFree(c_float); - cudaFree(bias_float); - cudaFree(scale_a_ptr); - cudaFree(scale_b_ptr); - cudaFree(workspace); + CUDA_CHECK(cudaFree(a)); + CUDA_CHECK(cudaFree(b)); + CUDA_CHECK(cudaFree(c)); + CUDA_CHECK(cudaFree(bias)); + CUDA_CHECK(cudaFree(a_float)); + CUDA_CHECK(cudaFree(b_float)); + CUDA_CHECK(cudaFree(c_float)); + CUDA_CHECK(cudaFree(bias_float)); + CUDA_CHECK(cudaFree(scale_a_ptr)); + CUDA_CHECK(cudaFree(scale_b_ptr)); + CUDA_CHECK(cudaFree(workspace)); cublasLtDestroy(handle); } diff --git a/train.cpp b/train.cpp index f5a1882..5810f10 100644 --- a/train.cpp +++ b/train.cpp @@ -200,7 +200,7 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { app.add_flag("--memcpy-send-recv", MemcpySendRecv, "Use memcpy to perform send/receive (all-to-all). Currently only supported by the threads backend."); app.add_flag("--all-to-all-reduce", Options.UseAllToAllReduce, "Uses an all-to-all-based reduce algorithm. Combine with --memcpy-send-recv."); app.add_flag("--write-combined", Options.UseWriteCombined, "Uses write-combined memory for offloaded tensors."); - app.add_flag("--custom-matmul", UseCustomMatmul, "Use a self-written matmul instead of cublas. This is *not* going to be faster, this" + app.add_flag("--custom-matmul", UseCustomMatmul, "Use a self-written matmul instead of cublas. This is *not* going to be faster, this " "option is mostly for the purists who want to minimize the dependencies.\n"); try { @@ -209,7 +209,8 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { std::exit(app.exit(e)); } - if( UseCustomMatmul ) { + // set-up matmul before any threads are started + if (UseCustomMatmul) { get_matmul_backend() = EMatmulBackend::Custom; } From 10d2013a904878943be13c229c5ff2c2cd9b08b7 Mon Sep 17 00:00:00 2001 From: Erik Schultheis <7938269+ngc92@users.noreply.github.com> Date: Sat, 21 Feb 2026 22:48:28 +0200 Subject: [PATCH 12/16] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/kernels/gemm_mma.cu | 10 ++++---- src/kernels/matmul.cpp | 1 - ...sor_core_utils.h => tensor_core_utils.cuh} | 2 +- src/testing/test-gemm.cpp | 24 ++++++++++++------- train.cpp | 2 +- 5 files changed, 22 insertions(+), 17 deletions(-) rename src/kernels/{tensor_core_utils.h => tensor_core_utils.cuh} (98%) diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu index 0cc5096..21e2e73 100644 --- a/src/kernels/gemm_mma.cu +++ b/src/kernels/gemm_mma.cu @@ -5,7 +5,7 @@ #include #include #include -#include "tensor_core_utils.h" +#include "tensor_core_utils.cuh" #include "utilities/vec.cuh" #include #include @@ -88,10 +88,10 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r const uint4* bs_load_ptr = input_tiles + offsets.y + threadIdx.z * WJ*ROW_OFFSET + DEPTH * PIPE_OFFSET; static_assert(WI * BI % NW == 0, "WI * BI must be divisible by the number of warps per block"); - static_assert(WJ * BJ % NW == 0, "WI * BI must be divisible by the number of warps per block"); + static_assert(WJ * BJ % NW == 0, "WJ * BJ must be divisible by the number of warps per block"); m16_n16_k32_a_fragment a_frag[WI]; - m16_n16_k32_b_fragment b_frag[WI]; + m16_n16_k32_b_fragment b_frag[WJ]; auto loop_fraction = [&](auto stage_c, auto load_next_c, int ks) { constexpr int stage = decltype(stage_c)::value; @@ -236,8 +236,8 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r template void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale_a, const float* scale_b, const BiasType* bias, bool accumulate, std::type_identity, cudaStream_t stream) { - if (n % 128 != 0 || m % 128 != 0) { - throw std::invalid_argument("gemm_mma_tn_launcher: n and m must be divisible by 128"); + if (n % 128 != 0 || m % 128 != 0 || k % 128 != 0) { + throw std::invalid_argument("gemm_mma_tn_launcher: n, m, k must be divisible by 128"); } // our kernel is row-major, so to match cublas, we need to transpose everything => swapped a<->b, m<->n diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index 36f3613..1396156 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -173,7 +173,6 @@ void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* bool expected = false; if(get_matmul_backend() == EMatmulBackend::Custom && mode != EMMTranspose::TN && warning.compare_exchange_strong(expected, true)) { fprintf(stderr, "WARNING: Custom matmuls are not supported for non-TN mode! Falling back to cublas.\n"); - warning = true; } if(get_matmul_backend() == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN) { diff --git a/src/kernels/tensor_core_utils.h b/src/kernels/tensor_core_utils.cuh similarity index 98% rename from src/kernels/tensor_core_utils.h rename to src/kernels/tensor_core_utils.cuh index 9e29f20..3f3658d 100644 --- a/src/kernels/tensor_core_utils.h +++ b/src/kernels/tensor_core_utils.cuh @@ -24,7 +24,7 @@ struct m16_n16_k32_b_fragment { template struct m16_n16_k32_c_fragment { - AccDType v[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + AccDType v[8] = {}; }; template diff --git a/src/testing/test-gemm.cpp b/src/testing/test-gemm.cpp index 98510b3..1b0fb6a 100644 --- a/src/testing/test-gemm.cpp +++ b/src/testing/test-gemm.cpp @@ -21,9 +21,11 @@ extern void matmul_cublaslt(floatO* d, const floatX* a, const floatX* b, const f int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate); +extern cublasLtHandle_t create_cublaslt_handle(); template void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, bool use_bias=false, bool check=true) { + auto saved_backend = get_matmul_backend(); Atype* a; Btype* b; Ctype* c; @@ -50,9 +52,12 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b CUDA_CHECK(cudaMemset(c_float, 0, m * n * sizeof(float))); CUDA_CHECK(cudaMemset(bias_float, 0, n * sizeof(float))); + std::mt19937 rng(12345); + std::uniform_int_distribution dist(-15, 15); + for(int i = 0; i < m; ++i) { for(int j = 0; j < k; ++j) { - auto val = static_cast(rand() % 31 - 15); + auto val = static_cast(dist(rng)); a[i*k+j] = val; a_float[i*k+j] = static_cast(val); } @@ -60,7 +65,7 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b for(int i = 0; i < n; ++i) { for(int j = 0; j < k; ++j) { - auto val = static_cast(rand() % 31 - 15); + auto val = static_cast(dist(rng)); b[i*k+j] = val; b_float[i*k+j] = static_cast(val); } @@ -68,14 +73,14 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b for(int i = 0; i < m; ++i) { for(int j = 0; j < n; ++j) { - auto val = static_cast(rand() % 31 - 15); + auto val = static_cast(dist(rng)); c[i*n+j] = val; c_float[i*n+j] = static_cast(val); } } for(int i = 0; i < n; ++i) { - auto val = static_cast(rand() % 31 - 15); + auto val = static_cast(dist(rng)); bias[i] = val; bias_float[i] = static_cast(val); } @@ -92,11 +97,10 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b CUDA_CHECK(cudaMemPrefetchAsync(scale_a_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); CUDA_CHECK(cudaMemPrefetchAsync(scale_b_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); - cublasLtHandle_t handle; + cublasLtHandle_t handle = create_cublaslt_handle(); std::byte* workspace; size_t workspace_size = 128 * 1024 * 1024; - assert(cublasLtCreate(&handle) == CUBLAS_STATUS_SUCCESS); - cudaMalloc(&workspace, workspace_size); + CUDA_CHECK(cudaMalloc(&workspace, workspace_size)); get_matmul_backend() = EMatmulBackend::Custom; CUDA_CHECK(cudaDeviceSynchronize()); @@ -115,7 +119,6 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b CUDA_CHECK(cudaDeviceSynchronize()); double r_tol = 1e-2; - bool equal = true; int approx_count = 0; int far_count = 0; for (int i = 0; i < m; ++i) { @@ -138,11 +141,13 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b } if (far_count == 0 && approx_count == 0) { + SUCCEED(); printf("PASS\n"); } else if(far_count < m * n / 100 && approx_count < m * n / 10) { + SUCCEED(); printf("CLOSE %d%% [%d+%d]\n", 100 - (approx_count + far_count) * 100 / (m * n), far_count, approx_count); } else { - printf("FAIL\n"); + FAIL(); } } @@ -158,6 +163,7 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b CUDA_CHECK(cudaFree(scale_b_ptr)); CUDA_CHECK(cudaFree(workspace)); cublasLtDestroy(handle); + get_matmul_backend() = saved_backend; } TEST_CASE("tiny matmul bfloat16 x bfloat16 -> bfloat16", "[gemm][bf16]") { diff --git a/train.cpp b/train.cpp index 5810f10..3e1124a 100644 --- a/train.cpp +++ b/train.cpp @@ -200,7 +200,7 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { app.add_flag("--memcpy-send-recv", MemcpySendRecv, "Use memcpy to perform send/receive (all-to-all). Currently only supported by the threads backend."); app.add_flag("--all-to-all-reduce", Options.UseAllToAllReduce, "Uses an all-to-all-based reduce algorithm. Combine with --memcpy-send-recv."); app.add_flag("--write-combined", Options.UseWriteCombined, "Uses write-combined memory for offloaded tensors."); - app.add_flag("--custom-matmul", UseCustomMatmul, "Use a self-written matmul instead of cublas. This is *not* going to be faster, this " + app.add_flag("--custom-matmul", UseCustomMatmul, "Use a self-written matmul instead of cuBLAS. This is *not* going to be faster, this " "option is mostly for the purists who want to minimize the dependencies.\n"); try { From 9c7ade489a97b493d7912cf597bf773210fcd803 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sat, 21 Feb 2026 23:36:39 +0100 Subject: [PATCH 13/16] explicit matmul backend argument --- .github/workflows/wheel.yml | 2 ++ scripts/train.py | 2 ++ src/binding/binding.cpp | 7 +++-- src/binding/python/tests/run.py | 4 +++ src/binding/python/training.py | 1 + src/kernels/kernels.cpp | 26 ++++++++++------ src/kernels/kernels.h | 20 ++++++------- src/kernels/matmul.cpp | 51 +++++++++++++++---------------- src/models/llama_model.cpp | 53 ++++++++++++++++++++------------- src/models/llama_model.h | 1 + src/testing/test-gemm.cpp | 8 ++--- src/training/model.h | 3 ++ src/utilities/sol.cpp | 4 +-- train.cpp | 8 +---- 14 files changed, 105 insertions(+), 85 deletions(-) diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml index 57dbb7e..e5a1770 100644 --- a/.github/workflows/wheel.yml +++ b/.github/workflows/wheel.yml @@ -181,6 +181,8 @@ jobs: args: "torch-step --grad-accum 2 --model-dtype fp32 --matmul-dtype bf16" - name: "Torch Chunking" args: "torch-step --grad-accum 4 --model-dtype bf16 --matmul-dtype bf16 --lmhead-chunks 4" + - name: "Torch BF16 Custom Matmul" + args: "torch-step --grad-accum 1 --model-dtype bf16 --matmul-dtype bf16 --custom-matmul" steps: - name: Checkout code uses: actions/checkout@v4 diff --git a/scripts/train.py b/scripts/train.py index 46cda1f..0c14c9a 100755 --- a/scripts/train.py +++ b/scripts/train.py @@ -49,6 +49,7 @@ def setup_options(config: pyllmq.TrainingConfig) -> pyllmq.LLamaOptions: # Performance options options.use_cuda_graphs = config.use_cuda_graphs options.use_all_to_all_reduce = config.all_to_all_reduce + options.use_custom_matmul = config.custom_matmul options.use_write_combined = config.write_combined # Other options @@ -99,6 +100,7 @@ def add_toggle(arg: str, default: bool, help: str): add_toggle("memcpy-send-recv", True, "Use cudaMemcpyAsync for send/recv (faster on PCIe). Only meaningful in conjunction with all-to-all-reduce") add_toggle("all-to-all-reduce", True, "Use custom all-to-all reduce which can be used with memcpy-send-recv") add_toggle("write-combined", False, "Use write-combined memory. May give faster PCIe transfers.") + add_toggle("custom-matmul", False, "Use custom matmul implementation.") args = parser.parse_args() return pyllmq.TrainingConfig(**vars(args)) diff --git a/src/binding/binding.cpp b/src/binding/binding.cpp index 7e51888..a4977b6 100644 --- a/src/binding/binding.cpp +++ b/src/binding/binding.cpp @@ -151,7 +151,7 @@ NB_MODULE(_pyllmq, m) { bool recompute_ffn, bool recompute_qkv, bool recompute_att, bool recompute_block, bool offload_residual, bool use_cuda_graphs, bool offload_master, bool offload_quants, bool offload_opt_m, bool offload_opt_v, bool offload_grads, bool use_zero_copy, - bool use_write_combined, bool shard_weights, bool persistent_quants, bool shard_gradients, bool use_all_to_all_reduce, + bool use_write_combined, bool shard_weights, bool persistent_quants, bool shard_gradients, bool use_all_to_all_reduce, bool use_custom_matmul, bool init_projections_to_zero, int lmhead_chunks, int attn_bwd_chunks, const std::string matmul_type, const std::string gradient_type, const std::string master_dtype, const std::string momentum_type, const std::string variance_type) { new (t) LLamaOptions{ @@ -176,6 +176,7 @@ NB_MODULE(_pyllmq, m) { .PersistentQuants = persistent_quants, .ShardGradients = shard_gradients, .UseAllToAllReduce = use_all_to_all_reduce, + .UseCustomMatmul = use_custom_matmul, .InitProjectionsToZero = init_projections_to_zero, .MatmulType = opt_dtype_from_str(matmul_type), .GradientType = opt_dtype_from_str(gradient_type), @@ -193,7 +194,8 @@ NB_MODULE(_pyllmq, m) { nb::arg("offload_opt_v") = false, nb::arg("offload_grads") = false, nb::arg("use_zero_copy") = false, nb::arg("use_write_combined") = false, nb::arg("shard_weights") = false, nb::arg("persistent_quants") = false, - nb::arg("shard_gradients") = false, nb::arg("use_all_to_all_reduce") = false, + nb::arg("shard_gradients") = false, + nb::arg("use_all_to_all_reduce") = false, nb::arg("use_custom_matmul") = false, nb::arg("init_projections_to_zero") = false, nb::arg("lmhead_chunks") = 1, nb::arg("attn_bwd_chunks") = 1, nb::arg("matmul_type") = "", nb::arg("gradient_type") = "", @@ -221,6 +223,7 @@ NB_MODULE(_pyllmq, m) { .def_rw("persistent_quants", &LLamaOptions::PersistentQuants) .def_rw("shard_gradients", &LLamaOptions::ShardGradients) .def_rw("use_all_to_all_reduce", &LLamaOptions::UseAllToAllReduce) + .def_rw("use_custom_matmul", &LLamaOptions::UseCustomMatmul) .def_rw("init_projections_to_zero", &LLamaOptions::InitProjectionsToZero) .def_prop_rw("matmul_type", [](const LLamaOptions* opt){ return opt->matmul_dtype(); }, [](LLamaOptions* opt, const std::string& dtype_str){ opt->MatmulType = opt_dtype_from_str(dtype_str); }) diff --git a/src/binding/python/tests/run.py b/src/binding/python/tests/run.py index 3b5f3c3..8057a6c 100644 --- a/src/binding/python/tests/run.py +++ b/src/binding/python/tests/run.py @@ -78,6 +78,9 @@ def _create_options(config: TrainingConfig) -> pyllmq.LLamaOptions: options.shard_gradients = config.shard_gradients options.shard_weights = config.shard_weights + options.use_all_to_all_reduce = config.all_to_all_reduce + options.use_custom_matmul = config.custom_matmul + if config.matmul_dtype: options.matmul_type = config.matmul_dtype if config.gradient_dtype: @@ -167,6 +170,7 @@ def parse_args(args: list = None) -> TrainingConfig: parser.add_argument("--memcpy-send-recv", action="store_true") parser.add_argument("--all-to-all-reduce", action="store_true") parser.add_argument("--write-combined", action="store_true") + parser.add_argument("--custom-matmul", action="store_true") args = parser.parse_args(args=args) cfg = TrainingConfig(**vars(args)) diff --git a/src/binding/python/training.py b/src/binding/python/training.py index f25e7e1..790311e 100644 --- a/src/binding/python/training.py +++ b/src/binding/python/training.py @@ -90,6 +90,7 @@ class TrainingConfig: memcpy_all_gather: bool = False memcpy_send_recv: bool = False all_to_all_reduce: bool = False + custom_matmul: bool = False write_combined: bool = False use_zero_copy: bool = False diff --git a/src/kernels/kernels.cpp b/src/kernels/kernels.cpp index 62d2a18..6e135d0 100644 --- a/src/kernels/kernels.cpp +++ b/src/kernels/kernels.cpp @@ -296,34 +296,42 @@ void fill_constant(Tensor& dest, float value, std::size_t count, cudaStream_t st void matmul(Tensor& c, const Tensor& a, const Tensor& b, const Tensor& bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, Tensor& workspace, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { std::byte* ws = workspace.get(); std::size_t ws_size = workspace.bytes(); if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::FP32) { const float* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get(), b.get(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get(), b.get(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::BF16) { const float* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get(), b.get(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get(), b.get(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::FP8_E4M3) { if(!bias.empty()) { if(bias.DType == ETensorDType::BF16) { - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get(), scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get(), + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else { - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get(), scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get(), + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } } else { - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), (nv_bfloat16*)nullptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), (nv_bfloat16*)nullptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } } else if(c.DType == ETensorDType::BF16 && a.DType == ETensorDType::FP8_E4M3 && b.DType == ETensorDType::FP8_E4M3) { const nv_bfloat16* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else if(c.DType == ETensorDType::BF16 && a.DType == ETensorDType::FP8_E4M3 && b.DType == ETensorDType::FP8_E5M2) { const nv_bfloat16* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e5m2>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e5m2>(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else if(c.DType == ETensorDType::BF16) { const nv_bfloat16* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get(), b.get(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get(), b.get(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else { UNSUPPORTED_DTYPE(c, a, b, bias, workspace); } diff --git a/src/kernels/kernels.h b/src/kernels/kernels.h index c10ba4e..ea74863 100644 --- a/src/kernels/kernels.h +++ b/src/kernels/kernels.h @@ -22,8 +22,6 @@ enum class ETensorDType: int; enum class EMMTranspose { TT, TN, NT, NN }; enum class EMatmulBackend {CuBLAS, Custom}; -EMatmulBackend& get_matmul_backend(); - void encoder_forward(float* out, const int* inp, const float* wte, const float* wpe, int B, int T, int C, int V, cudaStream_t stream); void encoder_forward(nv_bfloat16* out, const int* inp, const nv_bfloat16* wte, const nv_bfloat16* wpe, int B, int T, int C, int V, cudaStream_t stream); void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, const Tensor& wpe, int B, int T, int C, int V, cudaStream_t stream); @@ -65,39 +63,39 @@ void fused_residual_rmsnorm_forward(Tensor& residual, Tensor& normed, Tensor& rr void matmul(float* c, const float* a, const float* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(float* c, const nv_bfloat16* a, const nv_bfloat16* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(nv_bfloat16* c, const nv_bfloat16* a, const nv_bfloat16* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(Tensor& c, const Tensor& a, const Tensor& b, const Tensor& bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, Tensor& workspace, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void add_bias(float* out, const float* bias, int B, int T, int OC, cudaStream_t stream); void add_bias(nv_bfloat16* out, const nv_bfloat16* bias, int B, int T, int OC, cudaStream_t stream); diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index 1396156..171d06b 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2025, IST Austria, developed by Erik Schultheis +// Copyright (c) 2025-2026, IST Austria, developed by Erik Schultheis // SPDX-License-Identifier: Apache-2.0 // // Based on llm.c https://github.com/karpathy/llm.c @@ -13,12 +13,6 @@ cublasComputeType_t cublas_compute = CUBLAS_COMPUTE_32F; -EMatmulBackend& get_matmul_backend() { - // TODO: this is global state right now. Ideally, we could make this local. - static EMatmulBackend backend = EMatmulBackend::CuBLAS; - return backend; -} - // ---------------------------------------------------------------------------- // Error checking @@ -163,21 +157,24 @@ void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); -template -void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* bias, +template +void matmul_dispatch(floatO* d, const FloatA* a, const FloatB* b, const FloatBias* bias, std::byte* workspace, std::size_t workspace_size, int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, - const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate) + const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate, EMatmulBackend backend) { static std::atomic warning{false}; bool expected = false; - if(get_matmul_backend() == EMatmulBackend::Custom && mode != EMMTranspose::TN && warning.compare_exchange_strong(expected, true)) { + if(backend == EMatmulBackend::Custom && mode != EMMTranspose::TN && warning.compare_exchange_strong(expected, true)) { fprintf(stderr, "WARNING: Custom matmuls are not supported for non-TN mode! Falling back to cublas.\n"); } - if(get_matmul_backend() == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN) { + if(backend == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN) { matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale_a, scale_b, mode, accumulate); - } else if constexpr (std::is_same_v && std::is_same_v){ + } else if constexpr (std::is_same_v && std::is_same_v && + ((std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v))) + { gemm_mma_tn(d, a, b, m, n, k, scale_a, scale_b, bias, accumulate, stream); } else { matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale_a, scale_b, mode, accumulate); @@ -186,44 +183,44 @@ void matmul_dispatch(floatO* d, const floatX* a, const floatX* b, const floatB* void matmul(float* c, const float* a, const float* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(float* c, const nv_bfloat16* a, const nv_bfloat16* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(nv_bfloat16* c, const nv_bfloat16* a, const nv_bfloat16* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } /* diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp index c265127..affdd37 100644 --- a/src/models/llama_model.cpp +++ b/src/models/llama_model.cpp @@ -26,9 +26,10 @@ void forward_qmm(Tensor& out, QuantizableTensor& inp, Tensor& weight, const Tens cublasLtHandle_t handle, Tensor& workspace, int B, int T, int C, int OC, const cudaDeviceProp& dp, bool reuse_inp_quant, - cudaStream_t stream) { + cudaStream_t stream, EMatmulBackend backend) { if (weight.DType == inp.Value.DType) { - matmul(out, weight, inp.Value, bias, nullptr, nullptr, handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream); + matmul(out, weight, inp.Value, bias, nullptr, nullptr, + handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream, backend); return; } @@ -37,9 +38,11 @@ void forward_qmm(Tensor& out, QuantizableTensor& inp, Tensor& weight, const Tens } if (weight.DType == ETensorDType::BF16) { - matmul(out, weight, inp.Quant, bias, nullptr, nullptr, handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream); + matmul(out, weight, inp.Quant, bias, nullptr, nullptr, + handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream, backend); } else { - matmul(out, weight, inp.Quant, bias, weight.scale(), inp.Quant.scale(), handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream); + matmul(out, weight, inp.Quant, bias, weight.scale(), inp.Quant.scale(), + handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream, backend); } } @@ -188,7 +191,7 @@ void LLamaModel::_forward_block(sLLamaBlockWeights& weights, sLLamaLayer forward_qmm(acts.QKV, acts.LN1, weights.Attn_QKV_w, weights.Attn_QKV_b, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, Config.qkv_channels(), - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); // 2) apply RoPE to q,k (potentially in place) rope_forward(acts.QKV, acts.QKV, rs->FreqCis, nullptr, B, T, Hq, Hkv, Hs, main_stream); // 3) attention: att <- softmax(qk^T)v @@ -201,7 +204,7 @@ void LLamaModel::_forward_block(sLLamaBlockWeights& weights, sLLamaLayer forward_qmm(acts.AttO, acts.Att, weights.Attn_Out_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, C, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); fused_residual_rmsnorm_forward(acts.ResidualAtt, acts.LN2.Value, acts.LN2_Rstd, residual, acts.AttO, weights.LN2_w, acts.LN2.Quant.abs_max(), Config.RmsNormEps, B * T, C, main_stream); @@ -209,13 +212,13 @@ void LLamaModel::_forward_block(sLLamaBlockWeights& weights, sLLamaLayer forward_qmm(acts.MlpUp, acts.LN2, weights.MLP_Up_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, 2 * D, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); swiglu_forward(acts.SwiGLu.Value, acts.MlpUp, acts.SwiGLu.Quant.abs_max(), B, T, D, main_stream); forward_qmm(acts.MlpDown, acts.SwiGLu, weights.MLP_Down_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, D, C, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); } std::pair LLamaModel::validate(Tensor inputs, Tensor targets, NCCLCommunicator& comm, int micro_step) { @@ -259,7 +262,7 @@ std::pair LLamaModel::validate(Tensor inputs, Tensor targets, NCCL lse.Data += nano_step * nano_batch_size * get_dtype_size(lse.DType); matmul(rs->Output, Parameters->get_head(main_stream), lnf_slice, - Tensor{}, nullptr, nullptr, rs->CublasLtHandle, rs->CuBlasWorkspace, V, nano_batch_size, C, EMMTranspose::TN, false, main_stream); + Tensor{}, nullptr, nullptr, rs->CublasLtHandle, rs->CuBlasWorkspace, V, nano_batch_size, C, EMMTranspose::TN, false, main_stream, rs->MatmulBackend); // accumulate the losses inside rs->losses, and kick off the backward pass inside the fused classifier fused_classifier(rs->Output, losses, lse, d_loss, tgt, 0.f, nano_batch_size, V, Vp, false, main_stream); @@ -281,8 +284,10 @@ void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias, int B, int T, int C, int OC, bool reuse_inp, cudaStream_t stream) { if (weight.DType == inp.Value.DType) { - matmul(dinp, weight, dout.Value, Tensor{}, nullptr, nullptr, rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream); - matmul(dweight, inp.Value, dout.Value, Tensor{}, nullptr, nullptr, rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream); + matmul(dinp, weight, dout.Value, Tensor{}, nullptr, nullptr, + rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream, rs.MatmulBackend); + matmul(dweight, inp.Value, dout.Value, Tensor{}, nullptr, nullptr, + rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream, rs.MatmulBackend); if (dbias) { backward_bias(dbias, dout.Value, nullptr, nullptr, bias_buffer, B, T, OC, rs.DeviceProp, stream); @@ -293,8 +298,10 @@ void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias, quantize_with_abs_max(inp.Quant, dout.Quant.scale(), inp.Value, nullptr, B*T*C, rs.DeviceProp, stream); } - matmul(dinp, weight, dout.Quant, Tensor{}, nullptr, nullptr, rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream); - matmul(dweight, inp.Quant, dout.Quant, Tensor{}, nullptr, nullptr, rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream); + matmul(dinp, weight, dout.Quant, Tensor{}, nullptr, nullptr, + rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream, rs.MatmulBackend); + matmul(dweight, inp.Quant, dout.Quant, Tensor{}, nullptr, nullptr, + rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream, rs.MatmulBackend); if (dbias) { backward_bias(dbias, dout.Value, nullptr, nullptr, bias_buffer, B, T, OC, rs.DeviceProp, stream); @@ -307,7 +314,7 @@ void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias, transpose(weight_tp, weight, OC, C, stream); matmul(dinp, weight_tp, dout.Quant, Tensor{}, weight.scale(), dout.Quant.scale(), - rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::TN, false, stream); + rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::TN, false, stream, rs.MatmulBackend); rs.temp_free(weight_tp); auto activation_tp = rs.temp_alloc(inp_q.DType, {C, B*T}); @@ -322,7 +329,8 @@ void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias, } transpose(grad_tp, dout.Quant, B*T, OC, stream); - matmul(dweight, activation_tp, grad_tp, Tensor{}, inp_q.scale(), dout.Quant.scale(), rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::TN, accumulate_gradient, stream); + matmul(dweight, activation_tp, grad_tp, Tensor{}, inp_q.scale(), dout.Quant.scale(), + rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::TN, accumulate_gradient, stream, rs.MatmulBackend); if (dbias) { backward_bias(dbias, dout.Quant, inp_q.scale(), dout.Quant.scale(), bias_buffer, B, T, OC, rs.DeviceProp, stream); } @@ -484,7 +492,7 @@ void LLamaModel::_backward_lmhead(long B, long T, float z_loss, int micro_step, matmul(rs->Output, Parameters->get_head(main_stream), lnf_slice, Tensor{}, nullptr, nullptr, rs->CublasLtHandle, rs->CuBlasWorkspace, V, nano_batch_size, C, EMMTranspose::TN, - false, main_stream); + false, main_stream, rs->MatmulBackend); if(nano_step == 0) { // make sure Targets have been copied @@ -504,13 +512,13 @@ void LLamaModel::_backward_lmhead(long B, long T, float z_loss, int micro_step, auto& d_lmhead = Grads->get_lmhead_full(main_stream, comm, accumulate); accumulate |= nano_step != 0; matmul(d_lmhead, lnf_slice, rs->Output, Tensor{}, nullptr, nullptr, - rs->CublasLtHandle, rs->CuBlasWorkspace, C, V, nano_batch_size, EMMTranspose::NT, accumulate, main_stream); + rs->CublasLtHandle, rs->CuBlasWorkspace, C, V, nano_batch_size, EMMTranspose::NT, accumulate, main_stream, rs->MatmulBackend); if (nano_step == nano_batches - 1) { Grads->notify_lmhead(main_stream, comm); } matmul(dlnf_slice, Parameters->get_head(main_stream), rs->Output, Tensor{}, nullptr, nullptr, - rs->CublasLtHandle, rs->CuBlasWorkspace, C, nano_batch_size, V, EMMTranspose::NN, false, main_stream); + rs->CublasLtHandle, rs->CuBlasWorkspace, C, nano_batch_size, V, EMMTranspose::NN, false, main_stream, rs->MatmulBackend); } rs->temp_free(rs->Output); @@ -559,7 +567,7 @@ void LLamaModel::_recompute_block(sLLamaBlockWeights& weights, sLLamaLay forward_qmm(acts.QKV, acts.LN1, weights.Attn_QKV_w, weights.Attn_QKV_b, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, Config.qkv_channels(), - rs->DeviceProp, !recompute_ln1, main_stream); + rs->DeviceProp, !recompute_ln1, main_stream, rs->MatmulBackend); rope_forward(acts.QKV, acts.QKV, rs->FreqCis, nullptr, B, T, Hq, Hkv, Hs, main_stream); } @@ -571,7 +579,7 @@ void LLamaModel::_recompute_block(sLLamaBlockWeights& weights, sLLamaLay forward_qmm(acts.AttO, acts.Att, weights.Attn_Out_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, C, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); } } @@ -590,7 +598,7 @@ void LLamaModel::_recompute_block(sLLamaBlockWeights& weights, sLLamaLay forward_qmm(acts.MlpUp, acts.LN2, weights.MLP_Up_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, 2 * D, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); } if(recompute_swiglu) { @@ -911,6 +919,9 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato OptimizerRNG = std::minstd_rand{42}; RunState = std::make_unique(std::move(acts)); + if (options.UseCustomMatmul) { + RunState->MatmulBackend = EMatmulBackend::Custom; + } comm.barrier(); // make sure *all* GPUs have allocated the model before returning } diff --git a/src/models/llama_model.h b/src/models/llama_model.h index c184320..32c0bbe 100644 --- a/src/models/llama_model.h +++ b/src/models/llama_model.h @@ -43,6 +43,7 @@ struct LLamaOptions { bool ShardGradients = false; bool UseAllToAllReduce = false; + bool UseCustomMatmul = false; bool InitProjectionsToZero = false; diff --git a/src/testing/test-gemm.cpp b/src/testing/test-gemm.cpp index 1b0fb6a..8a055e0 100644 --- a/src/testing/test-gemm.cpp +++ b/src/testing/test-gemm.cpp @@ -25,7 +25,6 @@ extern cublasLtHandle_t create_cublaslt_handle(); template void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, bool use_bias=false, bool check=true) { - auto saved_backend = get_matmul_backend(); Atype* a; Btype* b; Ctype* c; @@ -101,10 +100,9 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b std::byte* workspace; size_t workspace_size = 128 * 1024 * 1024; CUDA_CHECK(cudaMalloc(&workspace, workspace_size)); - get_matmul_backend() = EMatmulBackend::Custom; CUDA_CHECK(cudaDeviceSynchronize()); - matmul(c, a, b, use_bias ? bias : nullptr, scale_a_ptr, scale_b_ptr, handle, workspace, workspace_size, m, n, k, EMMTranspose::TN, accumulate, nullptr); + matmul(c, a, b, use_bias ? bias : nullptr, scale_a_ptr, scale_b_ptr, handle, workspace, workspace_size, m, n, k, EMMTranspose::TN, accumulate, nullptr, EMatmulBackend::Custom); CUDA_CHECK(cudaDeviceSynchronize()); if(check) { @@ -112,10 +110,9 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b CUDA_CHECK(cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); CUDA_CHECK(cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); CUDA_CHECK(cudaMemPrefetchAsync(bias_float, n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); - get_matmul_backend() = EMatmulBackend::CuBLAS; CUDA_CHECK(cudaDeviceSynchronize()); matmul(c_float, a_float, b_float, use_bias ? bias_float : nullptr, nullptr, nullptr, - handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr); + handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr, EMatmulBackend::CuBLAS); CUDA_CHECK(cudaDeviceSynchronize()); double r_tol = 1e-2; @@ -163,7 +160,6 @@ void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, b CUDA_CHECK(cudaFree(scale_b_ptr)); CUDA_CHECK(cudaFree(workspace)); cublasLtDestroy(handle); - get_matmul_backend() = saved_backend; } TEST_CASE("tiny matmul bfloat16 x bfloat16 -> bfloat16", "[gemm][bf16]") { diff --git a/src/training/model.h b/src/training/model.h index 8e83453..00b6825 100644 --- a/src/training/model.h +++ b/src/training/model.h @@ -10,10 +10,12 @@ #include #include +#include "kernels/kernels.h" #include "utilities/stack.h" #include "utilities/tensor.h" #include "training/transformer_config.h" +enum class EMatmulBackend; class AdamWStateManager; class ITensorContainer; class NCCLCommunicator; @@ -188,6 +190,7 @@ class IRunState { cudnnHandle_t CudnnHandle = nullptr; cublasLtHandle_t CublasLtHandle = nullptr; Tensor CuBlasWorkspace; + EMatmulBackend MatmulBackend = EMatmulBackend{0}; // events for debugging timings void setup_timing_events(int micro_steps); diff --git a/src/utilities/sol.cpp b/src/utilities/sol.cpp index c64ab96..a0d7036 100644 --- a/src/utilities/sol.cpp +++ b/src/utilities/sol.cpp @@ -352,14 +352,14 @@ double measure_real_peak() { } ++trip_count; matmul(c, a, b, nullptr, nullptr, nullptr, handle, workspace, 32 * 1024 * 1024, - 16384, 16384, 16384, EMMTranspose::TN, false, nullptr); + 16384, 16384, 16384, EMMTranspose::TN, false, nullptr, EMatmulBackend::CuBLAS); } // now, actual measurement CUDA_CHECK(cudaEventRecord(start_event)); for(int i = 0; i < trip_count; ++i) { matmul(c, a, b, nullptr, nullptr, nullptr, handle, workspace, 32 * 1024 * 1024, - 16384, 16384, 16384, EMMTranspose::TN, false, nullptr); + 16384, 16384, 16384, EMMTranspose::TN, false, nullptr, EMatmulBackend::CuBLAS); } CUDA_CHECK(cudaEventRecord(stop_event)); CUDA_CHECK(cudaEventSynchronize(stop_event)); diff --git a/train.cpp b/train.cpp index 3e1124a..d4d4c3a 100644 --- a/train.cpp +++ b/train.cpp @@ -93,7 +93,6 @@ struct TrainingRunner { int NGPUs = 0; bool MemcpyAllGather = false; bool MemcpySendRecv = false; - bool UseCustomMatmul = false; LLamaOptions Options; @@ -200,7 +199,7 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { app.add_flag("--memcpy-send-recv", MemcpySendRecv, "Use memcpy to perform send/receive (all-to-all). Currently only supported by the threads backend."); app.add_flag("--all-to-all-reduce", Options.UseAllToAllReduce, "Uses an all-to-all-based reduce algorithm. Combine with --memcpy-send-recv."); app.add_flag("--write-combined", Options.UseWriteCombined, "Uses write-combined memory for offloaded tensors."); - app.add_flag("--custom-matmul", UseCustomMatmul, "Use a self-written matmul instead of cuBLAS. This is *not* going to be faster, this " + app.add_flag("--custom-matmul", Options.UseCustomMatmul, "Use a self-written matmul instead of cuBLAS. This is *not* going to be faster, this " "option is mostly for the purists who want to minimize the dependencies.\n"); try { @@ -209,11 +208,6 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { std::exit(app.exit(e)); } - // set-up matmul before any threads are started - if (UseCustomMatmul) { - get_matmul_backend() = EMatmulBackend::Custom; - } - if (!std::filesystem::exists(ModelRootPath)) { if (ModelRootPath.find('/') != std::string::npos) { std::string hf_path = get_hf_model_files(ModelRootPath); From 413de0207be087c21f7f7b3adf187f6d573cb89e Mon Sep 17 00:00:00 2001 From: Erik Schultheis <7938269+ngc92@users.noreply.github.com> Date: Sun, 22 Feb 2026 01:34:58 +0200 Subject: [PATCH 14/16] Update src/models/llama_model.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/models/llama_model.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp index affdd37..2667333 100644 --- a/src/models/llama_model.cpp +++ b/src/models/llama_model.cpp @@ -921,6 +921,8 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato RunState = std::make_unique(std::move(acts)); if (options.UseCustomMatmul) { RunState->MatmulBackend = EMatmulBackend::Custom; + } else { + RunState->MatmulBackend = EMatmulBackend::CuBLAS; } comm.barrier(); // make sure *all* GPUs have allocated the model before returning } From 188c72b1673a67fa052213ff1ae2eecd58a5ab95 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Wed, 4 Mar 2026 14:49:31 +0100 Subject: [PATCH 15/16] fix bindings --- src/binding/kernel_binding.cpp | 9 ++++++--- train.cpp | 3 +-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/binding/kernel_binding.cpp b/src/binding/kernel_binding.cpp index 743ba8a..c601988 100644 --- a/src/binding/kernel_binding.cpp +++ b/src/binding/kernel_binding.cpp @@ -330,11 +330,12 @@ void bind_grouped_loss_sum(const CudaArray& out, const CudaArray& per_token_loss void bind_matmul(const CudaArray& c, const CudaArray& a, const CudaArray& b, const std::optional& bias, const std::optional& scale_a, const std::optional& scale_b, std::uintptr_t cublaslt_handle, const CudaArray& workspace, - int mode_, bool accumulate, const std::uintptr_t stream) { + int mode_, bool accumulate, const std::uintptr_t stream, int backend_) { NB_CHECK_NDIMS(a, 2); NB_CHECK_NDIMS(b, 2); NB_CHECK_NDIMS(c, 2); EMMTranspose mode = static_cast(mode_); + EMatmulBackend backend = static_cast(backend_); // torch vs cublas: a @ b <=> b^t @ a^ t const bool a_transposed = (mode == EMMTranspose::TN || mode == EMMTranspose::TT); @@ -363,7 +364,7 @@ void bind_matmul(const CudaArray& c, const CudaArray& a, const CudaArray& b, con Tensor c_t = to_tensor(c); Tensor ws_t = to_tensor(workspace); matmul(c_t, to_tensor(b), to_tensor(a), to_tensor(bias), scale_b_ptr, scale_a_ptr, - reinterpret_cast(cublaslt_handle), ws_t, M, N, K, inv_mode, accumulate, as_stream(stream)); + reinterpret_cast(cublaslt_handle), ws_t, M, N, K, inv_mode, accumulate, as_stream(stream), backend); } cublasLtHandle_t create_cublaslt_handle(); @@ -546,7 +547,9 @@ void register_kernels(nanobind::module_& m) { m.def("grouped_loss_sum", &bind_grouped_loss_sum, nb::arg("out"), nb::arg("per_token_loss"), nb::arg("stream") = 0); // Matmul - m.def("matmul", &bind_matmul, nb::arg("c"), nb::arg("a"), nb::arg("b"), nb::arg("bias") = std::nullopt, nb::arg("scale_a") = std::nullopt, nb::arg("scale_b") = std::nullopt, nb::arg("cublaslt_handle"), nb::arg("workspace"), nb::arg("mode"), nb::arg("accumulate") = false, nb::arg("stream") = 0); + m.def("matmul", &bind_matmul, nb::arg("c"), nb::arg("a"), nb::arg("b"), nb::arg("bias") = std::nullopt, + nb::arg("scale_a") = std::nullopt, nb::arg("scale_b") = std::nullopt, nb::arg("cublaslt_handle"), nb::arg("workspace"), + nb::arg("mode"), nb::arg("accumulate") = false, nb::arg("stream") = 0, nb::arg("backend") = EMatmulBackend::CuBLAS); m.def("create_cublas_handle", &bind_create_cublas_handle); m.def("destroy_cublas_handle", &bind_destroy_cublas_handle); diff --git a/train.cpp b/train.cpp index d4d4c3a..eb0bf7b 100644 --- a/train.cpp +++ b/train.cpp @@ -199,8 +199,7 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { app.add_flag("--memcpy-send-recv", MemcpySendRecv, "Use memcpy to perform send/receive (all-to-all). Currently only supported by the threads backend."); app.add_flag("--all-to-all-reduce", Options.UseAllToAllReduce, "Uses an all-to-all-based reduce algorithm. Combine with --memcpy-send-recv."); app.add_flag("--write-combined", Options.UseWriteCombined, "Uses write-combined memory for offloaded tensors."); - app.add_flag("--custom-matmul", Options.UseCustomMatmul, "Use a self-written matmul instead of cuBLAS. This is *not* going to be faster, this " - "option is mostly for the purists who want to minimize the dependencies.\n"); + app.add_flag("--custom-matmul", Options.UseCustomMatmul, "Use a self-written matmul instead of cuBLAS."); try { app.parse(argc, argv); From 59e1844082b083b6f12a5f3d77eabc5c3e5b24cb Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Wed, 4 Mar 2026 16:22:02 +0100 Subject: [PATCH 16/16] make it build on Ampere again + other fixes --- src/binding/kernel_binding.cpp | 2 +- src/kernels/gemm_mma.cu | 39 ++++++++++++++++++++++++++++------ src/kernels/matmul.cpp | 8 ++++--- 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/binding/kernel_binding.cpp b/src/binding/kernel_binding.cpp index c601988..aba6a90 100644 --- a/src/binding/kernel_binding.cpp +++ b/src/binding/kernel_binding.cpp @@ -549,7 +549,7 @@ void register_kernels(nanobind::module_& m) { // Matmul m.def("matmul", &bind_matmul, nb::arg("c"), nb::arg("a"), nb::arg("b"), nb::arg("bias") = std::nullopt, nb::arg("scale_a") = std::nullopt, nb::arg("scale_b") = std::nullopt, nb::arg("cublaslt_handle"), nb::arg("workspace"), - nb::arg("mode"), nb::arg("accumulate") = false, nb::arg("stream") = 0, nb::arg("backend") = EMatmulBackend::CuBLAS); + nb::arg("mode"), nb::arg("accumulate") = false, nb::arg("stream") = 0, nb::arg("backend") = static_cast(EMatmulBackend::CuBLAS)); m.def("create_cublas_handle", &bind_create_cublas_handle); m.def("destroy_cublas_handle", &bind_destroy_cublas_handle); diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu index 21e2e73..6349542 100644 --- a/src/kernels/gemm_mma.cu +++ b/src/kernels/gemm_mma.cu @@ -19,13 +19,14 @@ template std::type_identity type_v = {}; template -__global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __restrict__ out, - const AType* __restrict__ a, const BType* __restrict__ b, - int m, int n, int k, - const float* __restrict__ scale_a, const float* __restrict__ scale_b, - const BiasType* __restrict__ bias, - bool accumulate, - std::type_identity acc_type) { +__device__ void gemm_mma_tn_impl(nv_bfloat16* __restrict__ out, + const AType* __restrict__ a, const BType* __restrict__ b, + int m, int n, int k, + const float* __restrict__ scale_a, const float* __restrict__ scale_b, + const BiasType* __restrict__ bias, + bool accumulate, + std::type_identity acc_type) +{ static_assert(sizeof(AType) == sizeof(BType), "index calculations assume sz(AType) == sz(BType)"); // Note: you cannot change these numbers without breaking the kernel. // they are here only for convenience, not to parametrize the algorithm. @@ -233,6 +234,27 @@ __global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __r } } +template +__global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __restrict__ out, + const AType* __restrict__ a, const BType* __restrict__ b, + int m, int n, int k, + const float* __restrict__ scale_a, const float* __restrict__ scale_b, + const BiasType* __restrict__ bias, + bool accumulate, + std::type_identity acc_type) { + if constexpr(std::is_same_v || std::is_same_v) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + gemm_mma_tn_impl(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, acc_type); +#else + __trap(); +#endif + } else { + gemm_mma_tn_impl(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, acc_type); + } + +} + + template void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale_a, const float* scale_b, const BiasType* bias, bool accumulate, std::type_identity, cudaStream_t stream) { @@ -240,6 +262,9 @@ void gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int throw std::invalid_argument("gemm_mma_tn_launcher: n, m, k must be divisible by 128"); } + if (bias && accumulate) + throw std::invalid_argument("gemm_mma_tn_launcher: cannot specify both bias and accumulate"); + // our kernel is row-major, so to match cublas, we need to transpose everything => swapped a<->b, m<->n dim3 grid; if( n > m ) { diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index 171d06b..3ccda59 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -165,11 +165,13 @@ void matmul_dispatch(floatO* d, const FloatA* a, const FloatB* b, const FloatBia { static std::atomic warning{false}; bool expected = false; - if(backend == EMatmulBackend::Custom && mode != EMMTranspose::TN && warning.compare_exchange_strong(expected, true)) { - fprintf(stderr, "WARNING: Custom matmuls are not supported for non-TN mode! Falling back to cublas.\n"); + if(backend == EMatmulBackend::Custom && (mode != EMMTranspose::TN || m % 128 != 0 || n % 128 != 0 || k % 128 != 0) + && warning.compare_exchange_strong(expected, true)) + { + fprintf(stderr, "WARNING: Custom matmuls are not supported for non-TN mode and multiples of 128! Falling back to cublas.\n"); } - if(backend == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN) { + if(backend == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN || m % 128 != 0 || n % 128 != 0 || k % 128 != 0) { matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale_a, scale_b, mode, accumulate); } else if constexpr (std::is_same_v && std::is_same_v && ((std::is_same_v && std::is_same_v) ||