diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml index 57dbb7e..e5a1770 100644 --- a/.github/workflows/wheel.yml +++ b/.github/workflows/wheel.yml @@ -181,6 +181,8 @@ jobs: args: "torch-step --grad-accum 2 --model-dtype fp32 --matmul-dtype bf16" - name: "Torch Chunking" args: "torch-step --grad-accum 4 --model-dtype bf16 --matmul-dtype bf16 --lmhead-chunks 4" + - name: "Torch BF16 Custom Matmul" + args: "torch-step --grad-accum 1 --model-dtype bf16 --matmul-dtype bf16 --custom-matmul" steps: - name: Checkout code uses: actions/checkout@v4 diff --git a/CMakeLists.txt b/CMakeLists.txt index 965ae78..bac1314 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -178,6 +178,7 @@ add_library(llmq-common SHARED src/kernels/random.cu src/kernels/fill.cu src/kernels/convert.cu + src/kernels/gemm_mma.cu ) target_link_libraries(llmq-common PRIVATE ${CUFILE_LIBS} cudnn_frontend ${CUDNN_LIBRARY} nvidia::nccl fmt::fmt-header-only) @@ -256,6 +257,7 @@ if (BUILD_TESTS) src/testing/test-swiglu.cu src/testing/test-rope.cu src/testing/test-classifier.cu + src/testing/test-gemm.cpp ) target_link_libraries(unit-tests PRIVATE llmq-common Catch2::Catch2 CLI11::CLI11) target_compile_options(unit-tests PUBLIC $<$:--expt-relaxed-constexpr> $<$:-lineinfo>) diff --git a/scripts/train.py b/scripts/train.py index 46cda1f..0c14c9a 100755 --- a/scripts/train.py +++ b/scripts/train.py @@ -49,6 +49,7 @@ def setup_options(config: pyllmq.TrainingConfig) -> pyllmq.LLamaOptions: # Performance options options.use_cuda_graphs = config.use_cuda_graphs options.use_all_to_all_reduce = config.all_to_all_reduce + options.use_custom_matmul = config.custom_matmul options.use_write_combined = config.write_combined # Other options @@ -99,6 +100,7 @@ def add_toggle(arg: str, default: bool, help: str): add_toggle("memcpy-send-recv", True, "Use cudaMemcpyAsync for send/recv (faster on PCIe). 
Only meaningful in conjunction with all-to-all-reduce") add_toggle("all-to-all-reduce", True, "Use custom all-to-all reduce which can be used with memcpy-send-recv") add_toggle("write-combined", False, "Use write-combined memory. May give faster PCIe transfers.") + add_toggle("custom-matmul", False, "Use custom matmul implementation.") args = parser.parse_args() return pyllmq.TrainingConfig(**vars(args)) diff --git a/src/binding/binding.cpp b/src/binding/binding.cpp index 7e51888..a4977b6 100644 --- a/src/binding/binding.cpp +++ b/src/binding/binding.cpp @@ -151,7 +151,7 @@ NB_MODULE(_pyllmq, m) { bool recompute_ffn, bool recompute_qkv, bool recompute_att, bool recompute_block, bool offload_residual, bool use_cuda_graphs, bool offload_master, bool offload_quants, bool offload_opt_m, bool offload_opt_v, bool offload_grads, bool use_zero_copy, - bool use_write_combined, bool shard_weights, bool persistent_quants, bool shard_gradients, bool use_all_to_all_reduce, + bool use_write_combined, bool shard_weights, bool persistent_quants, bool shard_gradients, bool use_all_to_all_reduce, bool use_custom_matmul, bool init_projections_to_zero, int lmhead_chunks, int attn_bwd_chunks, const std::string matmul_type, const std::string gradient_type, const std::string master_dtype, const std::string momentum_type, const std::string variance_type) { new (t) LLamaOptions{ @@ -176,6 +176,7 @@ NB_MODULE(_pyllmq, m) { .PersistentQuants = persistent_quants, .ShardGradients = shard_gradients, .UseAllToAllReduce = use_all_to_all_reduce, + .UseCustomMatmul = use_custom_matmul, .InitProjectionsToZero = init_projections_to_zero, .MatmulType = opt_dtype_from_str(matmul_type), .GradientType = opt_dtype_from_str(gradient_type), @@ -193,7 +194,8 @@ NB_MODULE(_pyllmq, m) { nb::arg("offload_opt_v") = false, nb::arg("offload_grads") = false, nb::arg("use_zero_copy") = false, nb::arg("use_write_combined") = false, nb::arg("shard_weights") = false, nb::arg("persistent_quants") = false, - 
nb::arg("shard_gradients") = false, nb::arg("use_all_to_all_reduce") = false, + nb::arg("shard_gradients") = false, + nb::arg("use_all_to_all_reduce") = false, nb::arg("use_custom_matmul") = false, nb::arg("init_projections_to_zero") = false, nb::arg("lmhead_chunks") = 1, nb::arg("attn_bwd_chunks") = 1, nb::arg("matmul_type") = "", nb::arg("gradient_type") = "", @@ -221,6 +223,7 @@ NB_MODULE(_pyllmq, m) { .def_rw("persistent_quants", &LLamaOptions::PersistentQuants) .def_rw("shard_gradients", &LLamaOptions::ShardGradients) .def_rw("use_all_to_all_reduce", &LLamaOptions::UseAllToAllReduce) + .def_rw("use_custom_matmul", &LLamaOptions::UseCustomMatmul) .def_rw("init_projections_to_zero", &LLamaOptions::InitProjectionsToZero) .def_prop_rw("matmul_type", [](const LLamaOptions* opt){ return opt->matmul_dtype(); }, [](LLamaOptions* opt, const std::string& dtype_str){ opt->MatmulType = opt_dtype_from_str(dtype_str); }) diff --git a/src/binding/kernel_binding.cpp b/src/binding/kernel_binding.cpp index 743ba8a..aba6a90 100644 --- a/src/binding/kernel_binding.cpp +++ b/src/binding/kernel_binding.cpp @@ -330,11 +330,12 @@ void bind_grouped_loss_sum(const CudaArray& out, const CudaArray& per_token_loss void bind_matmul(const CudaArray& c, const CudaArray& a, const CudaArray& b, const std::optional<CudaArray>& bias, const std::optional<CudaArray>& scale_a, const std::optional<CudaArray>& scale_b, std::uintptr_t cublaslt_handle, const CudaArray& workspace, - int mode_, bool accumulate, const std::uintptr_t stream) { + int mode_, bool accumulate, const std::uintptr_t stream, int backend_) { NB_CHECK_NDIMS(a, 2); NB_CHECK_NDIMS(b, 2); NB_CHECK_NDIMS(c, 2); EMMTranspose mode = static_cast<EMMTranspose>(mode_); + EMatmulBackend backend = static_cast<EMatmulBackend>(backend_); // torch vs cublas: a @ b <=> b^t @ a^ t const bool a_transposed = (mode == EMMTranspose::TN || mode == EMMTranspose::TT); @@ -363,7 +364,7 @@ Tensor c_t = to_tensor(c); Tensor ws_t = 
to_tensor(workspace); matmul(c_t, to_tensor(b), to_tensor(a), to_tensor(bias), scale_b_ptr, scale_a_ptr, - reinterpret_cast(cublaslt_handle), ws_t, M, N, K, inv_mode, accumulate, as_stream(stream)); + reinterpret_cast(cublaslt_handle), ws_t, M, N, K, inv_mode, accumulate, as_stream(stream), backend); } cublasLtHandle_t create_cublaslt_handle(); @@ -546,7 +547,9 @@ void register_kernels(nanobind::module_& m) { m.def("grouped_loss_sum", &bind_grouped_loss_sum, nb::arg("out"), nb::arg("per_token_loss"), nb::arg("stream") = 0); // Matmul - m.def("matmul", &bind_matmul, nb::arg("c"), nb::arg("a"), nb::arg("b"), nb::arg("bias") = std::nullopt, nb::arg("scale_a") = std::nullopt, nb::arg("scale_b") = std::nullopt, nb::arg("cublaslt_handle"), nb::arg("workspace"), nb::arg("mode"), nb::arg("accumulate") = false, nb::arg("stream") = 0); + m.def("matmul", &bind_matmul, nb::arg("c"), nb::arg("a"), nb::arg("b"), nb::arg("bias") = std::nullopt, + nb::arg("scale_a") = std::nullopt, nb::arg("scale_b") = std::nullopt, nb::arg("cublaslt_handle"), nb::arg("workspace"), + nb::arg("mode"), nb::arg("accumulate") = false, nb::arg("stream") = 0, nb::arg("backend") = static_cast(EMatmulBackend::CuBLAS)); m.def("create_cublas_handle", &bind_create_cublas_handle); m.def("destroy_cublas_handle", &bind_destroy_cublas_handle); diff --git a/src/binding/python/tests/run.py b/src/binding/python/tests/run.py index 3b5f3c3..8057a6c 100644 --- a/src/binding/python/tests/run.py +++ b/src/binding/python/tests/run.py @@ -78,6 +78,9 @@ def _create_options(config: TrainingConfig) -> pyllmq.LLamaOptions: options.shard_gradients = config.shard_gradients options.shard_weights = config.shard_weights + options.use_all_to_all_reduce = config.all_to_all_reduce + options.use_custom_matmul = config.custom_matmul + if config.matmul_dtype: options.matmul_type = config.matmul_dtype if config.gradient_dtype: @@ -167,6 +170,7 @@ def parse_args(args: list = None) -> TrainingConfig: 
parser.add_argument("--memcpy-send-recv", action="store_true") parser.add_argument("--all-to-all-reduce", action="store_true") parser.add_argument("--write-combined", action="store_true") + parser.add_argument("--custom-matmul", action="store_true") args = parser.parse_args(args=args) cfg = TrainingConfig(**vars(args)) diff --git a/src/binding/python/training.py b/src/binding/python/training.py index f25e7e1..790311e 100644 --- a/src/binding/python/training.py +++ b/src/binding/python/training.py @@ -90,6 +90,7 @@ class TrainingConfig: memcpy_all_gather: bool = False memcpy_send_recv: bool = False all_to_all_reduce: bool = False + custom_matmul: bool = False write_combined: bool = False use_zero_copy: bool = False diff --git a/src/kernels/gemm_mma.cu b/src/kernels/gemm_mma.cu new file mode 100644 index 0000000..6349542 --- /dev/null +++ b/src/kernels/gemm_mma.cu @@ -0,0 +1,291 @@ +// Copyright (c) 2026, IST Austria, developed by Erik Schultheis +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include "tensor_core_utils.cuh" +#include "utilities/vec.cuh" +#include +#include + +#include "utilities/utils.h" + +template +using int_c = std::integral_constant; + +template +std::type_identity type_v = {}; + +template +__device__ void gemm_mma_tn_impl(nv_bfloat16* __restrict__ out, + const AType* __restrict__ a, const BType* __restrict__ b, + int m, int n, int k, + const float* __restrict__ scale_a, const float* __restrict__ scale_b, + const BiasType* __restrict__ bias, + bool accumulate, + std::type_identity acc_type) +{ + static_assert(sizeof(AType) == sizeof(BType), "index calculations assume sz(AType) == sz(BType)"); + // Note: you cannot change these numbers without breaking the kernel. + // they are here only for convenience, not to parametrize the algorithm. 
+ // also note that some of the smaller loops have been unrolled by hand + // to have a nicer experience in ncu + constexpr int BI = 2; + constexpr int BJ = 2; + constexpr int WI = 4; + constexpr int WJ = 4; + constexpr int DEPTH = 4; + + constexpr int TI = 16; + constexpr int TJ = 16; + constexpr int TK = 2; // in units of uint4 + + int bi; + int bj; + if( m > n ) { + bi = blockIdx.y * BI * WI; + bj = blockIdx.x * BJ * WJ; + } else { + bi = blockIdx.x * BI * WI; + bj = blockIdx.y * BJ * WJ; + } + int i = bi + threadIdx.y * WI; + int j = bj + threadIdx.z * WJ; + + int wid = threadIdx.y + 2 * threadIdx.z; + constexpr int NW = BI * BJ; + + int stride = k / (sizeof(uint4)/sizeof(AType)); + + m16_n16_k32_c_fragment acc[WI][WJ]; + __shared__ uint4 input_tiles[2 * DEPTH * WI * BI * TI * TK]; + int2 offsets = ldmatrix_offsets(); + constexpr int PIPE_OFFSET = WI * BI * TI * TK; + constexpr int ROW_OFFSET = TI * TK; + + // instead of each thread computing addresses in both A and B, specialize 2 warps + // for A-loading and 2 warps for B-loading. 
+ + // const uint4* a_ptr = reinterpret_cast(a) + (bi + wid) * TI * stride; + // const uint4* b_ptr = reinterpret_cast(b) + (bj + wid) * TJ * stride; + // uint4* as_store_ptr = input_tiles + wid * ROW_OFFSET; + // uint4* bs_store_ptr = input_tiles + wid * ROW_OFFSET + DEPTH * PIPE_OFFSET; + + const uint4* g_ptr; + uint4* s_ptr; + if(wid < 2) { + g_ptr = reinterpret_cast(a) + (bi + wid) * TI * stride; + s_ptr = input_tiles + wid * ROW_OFFSET; + } else { + g_ptr = reinterpret_cast(b) + (bj + wid - 2) * TJ * stride; + s_ptr = input_tiles + (wid - 2) * ROW_OFFSET + DEPTH * PIPE_OFFSET; + } + + global_to_shared_16_32_swizzle(&s_ptr, &g_ptr, stride); + + const uint4* as_load_ptr = input_tiles + offsets.x + threadIdx.y * WI*ROW_OFFSET; + const uint4* bs_load_ptr = input_tiles + offsets.y + threadIdx.z * WJ*ROW_OFFSET + DEPTH * PIPE_OFFSET; + + static_assert(WI * BI % NW == 0, "WI * BI must be divisible by the number of warps per block"); + static_assert(WJ * BJ % NW == 0, "WJ * BJ must be divisible by the number of warps per block"); + + m16_n16_k32_a_fragment a_frag[WI]; + m16_n16_k32_b_fragment b_frag[WJ]; + + auto loop_fraction = [&](auto stage_c, auto load_next_c, int ks) { + constexpr int stage = decltype(stage_c)::value; + constexpr int load_stage = (stage + 1) % DEPTH; + constexpr int store_stage = (stage + 3) % DEPTH; + constexpr bool load_next = decltype(load_next_c)::value; + + // only load more inputs if we're not winding down the pipeline + if constexpr(load_next) { + __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 0 * ROW_OFFSET, g_ptr + 0 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 2 * ROW_OFFSET, g_ptr + 2 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 4 * ROW_OFFSET, g_ptr + 4 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + store_stage * PIPE_OFFSET + 6 * ROW_OFFSET, g_ptr + 6 * TI * stride + ks, 16); + } + __pipeline_commit(); + for(int jj = 0; jj < 
WJ; jj++) { + for(int ii = 0; ii < WI; ii++) { + mma_m16_n16_k32_sync(acc[ii][jj], a_frag[ii], b_frag[jj], acc[ii][jj]); + } + } + __pipeline_wait_prior(2); + __syncthreads(); + + // only load more registers if this is not the last step + if constexpr(load_next || stage != 3) { + ptx_ldmatrix(a_frag[0].v, as_load_ptr + 0 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[1].v, as_load_ptr + 1 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[2].v, as_load_ptr + 2 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(a_frag[3].v, as_load_ptr + 3 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(b_frag[0].v, bs_load_ptr + 0 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(b_frag[1].v, bs_load_ptr + 1 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(b_frag[2].v, bs_load_ptr + 2 * ROW_OFFSET + PIPE_OFFSET * load_stage); + ptx_ldmatrix(b_frag[3].v, bs_load_ptr + 3 * ROW_OFFSET + PIPE_OFFSET * load_stage); + } + }; + + auto ldg_sts = [&](auto stage_c, int ks) { + constexpr int stage = decltype(stage_c)::value; + __pipeline_memcpy_async(s_ptr + stage * PIPE_OFFSET + 0 * ROW_OFFSET, g_ptr + 0 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + stage * PIPE_OFFSET + 2 * ROW_OFFSET, g_ptr + 2 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + stage * PIPE_OFFSET + 4 * ROW_OFFSET, g_ptr + 4 * TI * stride + ks, 16); + __pipeline_memcpy_async(s_ptr + stage * PIPE_OFFSET + 6 * ROW_OFFSET, g_ptr + 6 * TI * stride + ks, 16); + }; + + // start up the pipeline + ldg_sts(int_c<0>{}, 0); + __pipeline_commit(); + ldg_sts(int_c<1>{}, 1 * TK); + __pipeline_commit(); + ldg_sts(int_c<2>{}, 2 * TK); + __pipeline_commit(); + + std::bool_constant true_v; + std::bool_constant false_v; + + __pipeline_wait_prior(2); + __syncthreads(); + + ptx_ldmatrix(a_frag[0].v, as_load_ptr + 0 * ROW_OFFSET); + ptx_ldmatrix(a_frag[1].v, as_load_ptr + 1 * ROW_OFFSET); + ptx_ldmatrix(a_frag[2].v, as_load_ptr + 2 * ROW_OFFSET); + 
ptx_ldmatrix(a_frag[3].v, as_load_ptr + 3 * ROW_OFFSET); + ptx_ldmatrix(b_frag[0].v, bs_load_ptr + 0 * ROW_OFFSET); + ptx_ldmatrix(b_frag[1].v, bs_load_ptr + 1 * ROW_OFFSET); + ptx_ldmatrix(b_frag[2].v, bs_load_ptr + 2 * ROW_OFFSET); + ptx_ldmatrix(b_frag[3].v, bs_load_ptr + 3 * ROW_OFFSET); + + int ks = 0; + while (ks + 6 * TK < stride) { + loop_fraction(int_c<0>{}, true_v, ks + 3 * TK); + loop_fraction(int_c<1>{}, true_v, ks + 4 * TK); + loop_fraction(int_c<2>{}, true_v, ks + 5 * TK); + loop_fraction(int_c<3>{}, true_v, ks + 6 * TK); + + ks += 4 * TK; + } + + // last iteration + loop_fraction(int_c<0>{}, true_v, ks + 3 * TK); + loop_fraction(int_c<1>{}, false_v, ks + 4 * TK); + loop_fraction(int_c<2>{}, false_v, ks + 5 * TK); + loop_fraction(int_c<3>{}, false_v, ks + 6 * TK); + + // note: loop_fraction ends with __syncthreads, so no need to sync here + nv_bfloat16* out_shared = reinterpret_cast(input_tiles) + (threadIdx.y + 2 * threadIdx.z) * TJ * TI; + + float scale = 1.f; + if (scale_a != nullptr) { + scale = *scale_a; + } + if (scale_b != nullptr) { + scale *= *scale_b; + } + + // on 40xx, for some reason, these loops don't get unrolled, and then + // all the accumulators end up in local memory and we get terrible + // performance. 
+ #pragma unroll + for(int ii = 0; ii < WI; ii++) { + #pragma unroll + for (int jj = 0; jj < WJ; jj++) { + + // interleave scaling and output writing + if(scale != 1.f) { + acc[ii][jj].v[0] *= scale; + acc[ii][jj].v[1] *= scale; + acc[ii][jj].v[2] *= scale; + acc[ii][jj].v[3] *= scale; + acc[ii][jj].v[4] *= scale; + acc[ii][jj].v[5] *= scale; + acc[ii][jj].v[6] *= scale; + acc[ii][jj].v[7] *= scale; + } + + store_fragment_row_major_sync(acc[ii][jj], out_shared, TJ); + __syncwarp(); + int c = threadIdx.x % 2; + int r = threadIdx.x / 2; + + if(accumulate) { + auto old = GenericVector::load(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c); + auto upd = GenericVector::load(out_shared + (c + 2 * r) * 8); + for(int l = 0; l < 8; ++l) { + old[l] += upd[l]; + } + old.store(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c); + } else if (bias != nullptr) { + auto old = GenericVector::load(bias + (j + jj) * TJ + 8 * c); + auto upd = GenericVector::load(out_shared + (c + 2 * r) * 8); + for(int l = 0; l < 8; ++l) { + old[l] += (nv_bfloat16)upd[l]; + } + old.store(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c); + } else { + uint4 load = reinterpret_cast(out_shared)[c + 2 * r]; + *reinterpret_cast(out + ((i + ii) * TI + r) * n + (j + jj) * TJ + 8 * c) = load; + } + } + } +} + +template +__global__ __launch_bounds__(32*2*2, 2) void gemm_mma_tn_kernel(nv_bfloat16* __restrict__ out, + const AType* __restrict__ a, const BType* __restrict__ b, + int m, int n, int k, + const float* __restrict__ scale_a, const float* __restrict__ scale_b, + const BiasType* __restrict__ bias, + bool accumulate, + std::type_identity acc_type) { + if constexpr(std::is_same_v || std::is_same_v) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + gemm_mma_tn_impl(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, acc_type); +#else + __trap(); +#endif + } else { + gemm_mma_tn_impl(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, acc_type); + } + +} + + +template +void 
gemm_mma_tn_launcher(nv_bfloat16* out, const AType* a, const BType* b, int m, int n, int k, const float* scale_a, const float* scale_b, const BiasType* bias, + bool accumulate, std::type_identity, cudaStream_t stream) { + if (n % 128 != 0 || m % 128 != 0 || k % 128 != 0) { + throw std::invalid_argument("gemm_mma_tn_launcher: n, m, k must be divisible by 128"); + } + + if (bias && accumulate) + throw std::invalid_argument("gemm_mma_tn_launcher: cannot specify both bias and accumulate"); + + // our kernel is row-major, so to match cublas, we need to transpose everything => swapped a<->b, m<->n + dim3 grid; + if( n > m ) { + grid = {(unsigned)div_exact(m, 128), (unsigned)div_exact(n, 128), 1}; + } else { + grid = {(unsigned)div_exact(n, 128), (unsigned)div_exact(m, 128), 1}; + } + dim3 block{32, 2, 2}; + gemm_mma_tn_kernel<<>>(out, b, a, n, m, k, scale_a, scale_b, bias, accumulate, type_v); +} + +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, + const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) +{ + gemm_mma_tn_launcher(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, type_v, stream); + CUDA_CHECK(cudaGetLastError()); +} + +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, + const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream) +{ + gemm_mma_tn_launcher(out, a, b, m, n, k, scale_a, scale_b, bias, accumulate, type_v, stream); + CUDA_CHECK(cudaGetLastError()); +} diff --git a/src/kernels/kernels.cpp b/src/kernels/kernels.cpp index 62d2a18..6e135d0 100644 --- a/src/kernels/kernels.cpp +++ b/src/kernels/kernels.cpp @@ -296,34 +296,42 @@ void fill_constant(Tensor& dest, float value, std::size_t count, cudaStream_t st void matmul(Tensor& c, const Tensor& a, const Tensor& b, const Tensor& bias, const float* scale_a, const float* scale_b, cublasLtHandle_t 
handle, Tensor& workspace, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { std::byte* ws = workspace.get(); std::size_t ws_size = workspace.bytes(); if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::FP32) { const float* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get(), b.get(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get(), b.get(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::BF16) { const float* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get(), b.get(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get(), b.get(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::FP8_E4M3) { if(!bias.empty()) { if(bias.DType == ETensorDType::BF16) { - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get(), scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get(), + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else { - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get(), scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get(), + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } } else { - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), (nv_bfloat16*)nullptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + 
matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), (nv_bfloat16*)nullptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } } else if(c.DType == ETensorDType::BF16 && a.DType == ETensorDType::FP8_E4M3 && b.DType == ETensorDType::FP8_E4M3) { const nv_bfloat16* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else if(c.DType == ETensorDType::BF16 && a.DType == ETensorDType::FP8_E4M3 && b.DType == ETensorDType::FP8_E5M2) { const nv_bfloat16* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e5m2>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e5m2>(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else if(c.DType == ETensorDType::BF16) { const nv_bfloat16* bias_ptr = bias.get_optional(); - matmul(c.get(), a.get(), b.get(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream); + matmul(c.get(), a.get(), b.get(), bias_ptr, + scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream, backend); } else { UNSUPPORTED_DTYPE(c, a, b, bias, workspace); } diff --git a/src/kernels/kernels.h b/src/kernels/kernels.h index 6266445..ea74863 100644 --- a/src/kernels/kernels.h +++ b/src/kernels/kernels.h @@ -20,6 +20,7 @@ struct Tensor; enum class ETensorDType: int; enum class EMMTranspose { TT, TN, NT, NN }; +enum class EMatmulBackend {CuBLAS, Custom}; void encoder_forward(float* out, const int* inp, const float* wte, const float* wpe, int B, int T, int C, int V, cudaStream_t stream); void 
encoder_forward(nv_bfloat16* out, const int* inp, const nv_bfloat16* wte, const nv_bfloat16* wpe, int B, int T, int C, int V, cudaStream_t stream); @@ -62,39 +63,39 @@ void fused_residual_rmsnorm_forward(Tensor& residual, Tensor& normed, Tensor& rr void matmul(float* c, const float* a, const float* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(float* c, const nv_bfloat16* a, const nv_bfloat16* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, 
std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(nv_bfloat16* c, const nv_bfloat16* a, const nv_bfloat16* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void matmul(Tensor& c, const Tensor& a, const Tensor& b, const Tensor& bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, Tensor& workspace, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend); void add_bias(float* out, const float* bias, int B, int T, int OC, cudaStream_t stream); void add_bias(nv_bfloat16* out, const nv_bfloat16* bias, int B, int T, int OC, cudaStream_t stream); diff --git 
a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index 7d67462..3ccda59 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -1,8 +1,9 @@ -// Copyright (c) 2025, IST Austria, developed by Erik Schultheis +// Copyright (c) 2025-2026, IST Austria, developed by Erik Schultheis // SPDX-License-Identifier: Apache-2.0 // // Based on llm.c https://github.com/karpathy/llm.c +#include <atomic> #include #include @@ -151,46 +152,77 @@ void matmul_cublaslt(FloatC* d, const FloatA* a, const FloatB* b, const FloatBia CUDA_CHECK(cudaGetLastError()); } +// custom matmuls +void gemm_mma_tn(nv_bfloat16* out, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, int m, int n, int k, const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); +void gemm_mma_tn(nv_bfloat16* out, const nv_bfloat16* a, const nv_bfloat16* b, int m, int n, int k, const float* scale_a, const float* scale_b, const nv_bfloat16* bias, bool accumulate, cudaStream_t stream); + + +template<class floatO, class FloatA, class FloatB, class FloatBias> +void matmul_dispatch(floatO* d, const FloatA* a, const FloatB* b, const FloatBias* bias, + std::byte* workspace, std::size_t workspace_size, + int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, + const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate, EMatmulBackend backend) +{ + static std::atomic<bool> warning{false}; + bool expected = false; + if(backend == EMatmulBackend::Custom && (mode != EMMTranspose::TN || m % 128 != 0 || n % 128 != 0 || k % 128 != 0) + && warning.compare_exchange_strong(expected, true)) + { + fprintf(stderr, "WARNING: Custom matmuls are not supported for non-TN mode and multiples of 128! 
Falling back to cublas.\n"); + } + + if(backend == EMatmulBackend::CuBLAS || mode != EMMTranspose::TN || m % 128 != 0 || n % 128 != 0 || k % 128 != 0) { + matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale_a, scale_b, mode, accumulate); + } else if constexpr (std::is_same_v && std::is_same_v && + ((std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v))) + { + gemm_mma_tn(d, a, b, m, n, k, scale_a, scale_b, bias, accumulate, stream); + } else { + matmul_cublaslt(d, a, b, bias, workspace, workspace_size, m, n, k, stream, handle, scale_a, scale_b, mode, accumulate); + } +} + void matmul(float* c, const float* a, const float* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(float* c, const nv_bfloat16* a, const nv_bfloat16* b, const float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const 
float* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(nv_bfloat16* c, const nv_bfloat16* a, const nv_bfloat16* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const 
float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b, cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size, - int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) { - matmul_cublaslt(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate); + int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend) { + matmul_dispatch(c, a, b, bias, workspace, workspace_size, M, N, K, stream, handle, scale_a, scale_b, mode, accumulate, backend); } /* diff --git a/src/kernels/tensor_core_utils.cuh b/src/kernels/tensor_core_utils.cuh new file mode 100644 index 0000000..3f3658d --- /dev/null +++ b/src/kernels/tensor_core_utils.cuh @@ -0,0 +1,153 @@ +// Copyright (c) 2026, IST Austria, developed by Erik Schultheis +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef LLMQ_TENSOR_CORE_UTILS_CUH +#define LLMQ_TENSOR_CORE_UTILS_CUH + +#include "utilities/dtype.h" +#include +#include +#include +#include +#include + +template +struct m16_n16_k32_a_fragment { + uint4 v; +}; + +template +struct m16_n16_k32_b_fragment { + uint4 v; +}; + +template +struct m16_n16_k32_c_fragment { + AccDType v[8] = {}; +}; + +template +constexpr char ptx_type_name[] = "unknown_dtype"; + +template<> +constexpr char ptx_type_name[4] = "f32"; + 
+template<> +constexpr char ptx_type_name[4] = "f16"; + +template<> +constexpr char ptx_type_name[5] = "bf16"; + +template<> +constexpr char ptx_type_name<__nv_fp8_e4m3>[5] = "e4m3"; + +template<> +constexpr char ptx_type_name<__nv_fp8_e5m2>[5] = "e5m2"; + + +__device__ __forceinline__ void global_to_shared_16_32_swizzle(uint4** shared, const uint4** global, int stride) { + int col = threadIdx.x % 2; + int row = threadIdx.x / 2; + + int g8 = threadIdx.x / 8; + int t8 = threadIdx.x % 8; + + *shared = *shared + (t8 ^ g8) + 8 * g8; + *global = *global + row * stride + col; +} + +__device__ __forceinline__ int load_address(int row, int col) { + int lin = col + 2 * row; + int g8 = lin / 8; + int t8 = lin % 8; + return (t8 ^ g8) + 8 * g8; +} + +__device__ __forceinline__ int2 ldmatrix_offsets() { + int t8 = threadIdx.x % 8; + int g8 = threadIdx.x / 8; + int a = load_address(t8 + 8 * (g8%2), g8 / 2); + int b = load_address(t8 + 8 * (g8/2), g8 % 2); + return make_int2(a, b); +} + +__device__ __forceinline__ void ptx_ldmatrix(uint4& dst, const void* src) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "l"(__cvta_generic_to_shared(src)) + ); +} + +template +__device__ __forceinline__ void store_fragment_row_major_sync(m16_n16_k32_c_fragment& c, nv_bfloat16* ptr, int row_stride) { + int g4 = threadIdx.x / 4; + int c4 = threadIdx.x % 4; + nv_bfloat162* vptr = reinterpret_cast(ptr); + row_stride /= 2; + vptr[row_stride * (g4 + 0) + c4 + 0] = make_bfloat162((nv_bfloat16)c.v[0], (nv_bfloat16)c.v[1]); + vptr[row_stride * (g4 + 8) + c4 + 0] = make_bfloat162((nv_bfloat16)c.v[2], (nv_bfloat16)c.v[3]); + + vptr[row_stride * (g4 + 0) + c4 + 4] = make_bfloat162((nv_bfloat16)c.v[4], (nv_bfloat16)c.v[5]); + vptr[row_stride * (g4 + 8) + c4 + 4] = make_bfloat162((nv_bfloat16)c.v[6], (nv_bfloat16)c.v[7]); +} + +template +__device__ __forceinline__ void 
mma_m16_n16_k32_sync(m16_n16_k32_c_fragment& d, + m16_n16_k32_a_fragment a, + m16_n16_k32_b_fragment b, + m16_n16_k32_c_fragment c) { + static_assert(sizeof(AType) == sizeof(BType), "a and b type must have the same size"); + + constexpr int k = 32 / sizeof(AType); + asm volatile("mma.sync.aligned.m16n8k%26.row.col.f32.%24.%25.f32 " + "{%0, %1, %2, %3}," + "{%8, %9, %10, %11}," + "{%12, %13}," + "{%16, %17, %18, %19};\n" + "mma.sync.aligned.m16n8k%26.row.col.f32.%24.%25.f32 " + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%14, %15}," + "{%20, %21, %22, %23};\n" + : "=f"(d.v[0]), "=f"(d.v[1]), "=f"(d.v[2]), "=f"(d.v[3]), + "=f"(d.v[4]), "=f"(d.v[5]), "=f"(d.v[6]), "=f"(d.v[7]) + : "r"(a.v.x), "r"(a.v.y), "r"(a.v.z), "r"(a.v.w), + "r"(b.v.x), "r"(b.v.y), "r"(b.v.z), "r"(b.v.w), + "f"(c.v[0]), "f"(c.v[1]), "f"(c.v[2]), "f"(c.v[3]), + "f"(c.v[4]), "f"(c.v[5]), "f"(c.v[6]), "f"(c.v[7]), + "C"(ptx_type_name), "C"(ptx_type_name), "n"(k)); +} + +template +__device__ __forceinline__ void mma_m16_n16_k32_sync(m16_n16_k32_c_fragment& d, + m16_n16_k32_a_fragment a, + m16_n16_k32_b_fragment b, + m16_n16_k32_c_fragment c) { + auto to_raw = [](half& h) -> unsigned int& { + return *reinterpret_cast(&h); + }; + asm volatile("mma.sync.aligned.m16n8k32.row.col.f16.%10.%11.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};" + : "=r"(to_raw(d.v[0])), "=r"(to_raw(d.v[2])) + : "r"(a.v.x), "r"(a.v.y), "r"(a.v.z), "r"(a.v.w), + "r"(b.v.x), "r"(b.v.y), + "r"(to_raw(c.v[0])), "r"(to_raw(c.v[2])), + "C"(ptx_type_name), "C"(ptx_type_name)); + + asm volatile("mma.sync.aligned.m16n8k32.row.col.f16.%10.%11.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};" + : "=r"(to_raw(d.v[4])), "=r"(to_raw(d.v[6])) + : "r"(a.v.x), "r"(a.v.y), "r"(a.v.z), "r"(a.v.w), + "r"(b.v.z), "r"(b.v.w), + "r"(to_raw(c.v[4])), "r"(to_raw(c.v[6])), + "C"(ptx_type_name), "C"(ptx_type_name)); +} + +#endif diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp index 
c265127..2667333 100644 --- a/src/models/llama_model.cpp +++ b/src/models/llama_model.cpp @@ -26,9 +26,10 @@ void forward_qmm(Tensor& out, QuantizableTensor& inp, Tensor& weight, const Tens cublasLtHandle_t handle, Tensor& workspace, int B, int T, int C, int OC, const cudaDeviceProp& dp, bool reuse_inp_quant, - cudaStream_t stream) { + cudaStream_t stream, EMatmulBackend backend) { if (weight.DType == inp.Value.DType) { - matmul(out, weight, inp.Value, bias, nullptr, nullptr, handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream); + matmul(out, weight, inp.Value, bias, nullptr, nullptr, + handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream, backend); return; } @@ -37,9 +38,11 @@ void forward_qmm(Tensor& out, QuantizableTensor& inp, Tensor& weight, const Tens } if (weight.DType == ETensorDType::BF16) { - matmul(out, weight, inp.Quant, bias, nullptr, nullptr, handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream); + matmul(out, weight, inp.Quant, bias, nullptr, nullptr, + handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream, backend); } else { - matmul(out, weight, inp.Quant, bias, weight.scale(), inp.Quant.scale(), handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream); + matmul(out, weight, inp.Quant, bias, weight.scale(), inp.Quant.scale(), + handle, workspace, OC, B*T, C, EMMTranspose::TN, false, stream, backend); } } @@ -188,7 +191,7 @@ void LLamaModel::_forward_block(sLLamaBlockWeights& weights, sLLamaLayer forward_qmm(acts.QKV, acts.LN1, weights.Attn_QKV_w, weights.Attn_QKV_b, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, Config.qkv_channels(), - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); // 2) apply RoPE to q,k (potentially in place) rope_forward(acts.QKV, acts.QKV, rs->FreqCis, nullptr, B, T, Hq, Hkv, Hs, main_stream); // 3) attention: att <- softmax(qk^T)v @@ -201,7 +204,7 @@ void LLamaModel::_forward_block(sLLamaBlockWeights& weights, sLLamaLayer 
forward_qmm(acts.AttO, acts.Att, weights.Attn_Out_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, C, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); fused_residual_rmsnorm_forward(acts.ResidualAtt, acts.LN2.Value, acts.LN2_Rstd, residual, acts.AttO, weights.LN2_w, acts.LN2.Quant.abs_max(), Config.RmsNormEps, B * T, C, main_stream); @@ -209,13 +212,13 @@ void LLamaModel::_forward_block(sLLamaBlockWeights& weights, sLLamaLayer forward_qmm(acts.MlpUp, acts.LN2, weights.MLP_Up_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, 2 * D, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); swiglu_forward(acts.SwiGLu.Value, acts.MlpUp, acts.SwiGLu.Quant.abs_max(), B, T, D, main_stream); forward_qmm(acts.MlpDown, acts.SwiGLu, weights.MLP_Down_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, D, C, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); } std::pair LLamaModel::validate(Tensor inputs, Tensor targets, NCCLCommunicator& comm, int micro_step) { @@ -259,7 +262,7 @@ std::pair LLamaModel::validate(Tensor inputs, Tensor targets, NCCL lse.Data += nano_step * nano_batch_size * get_dtype_size(lse.DType); matmul(rs->Output, Parameters->get_head(main_stream), lnf_slice, - Tensor{}, nullptr, nullptr, rs->CublasLtHandle, rs->CuBlasWorkspace, V, nano_batch_size, C, EMMTranspose::TN, false, main_stream); + Tensor{}, nullptr, nullptr, rs->CublasLtHandle, rs->CuBlasWorkspace, V, nano_batch_size, C, EMMTranspose::TN, false, main_stream, rs->MatmulBackend); // accumulate the losses inside rs->losses, and kick off the backward pass inside the fused classifier fused_classifier(rs->Output, losses, lse, d_loss, tgt, 0.f, nano_batch_size, V, Vp, false, main_stream); @@ -281,8 +284,10 @@ void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias, int B, int T, int C, int OC, bool reuse_inp, cudaStream_t stream) { 
if (weight.DType == inp.Value.DType) { - matmul(dinp, weight, dout.Value, Tensor{}, nullptr, nullptr, rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream); - matmul(dweight, inp.Value, dout.Value, Tensor{}, nullptr, nullptr, rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream); + matmul(dinp, weight, dout.Value, Tensor{}, nullptr, nullptr, + rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream, rs.MatmulBackend); + matmul(dweight, inp.Value, dout.Value, Tensor{}, nullptr, nullptr, + rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream, rs.MatmulBackend); if (dbias) { backward_bias(dbias, dout.Value, nullptr, nullptr, bias_buffer, B, T, OC, rs.DeviceProp, stream); @@ -293,8 +298,10 @@ void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias, quantize_with_abs_max(inp.Quant, dout.Quant.scale(), inp.Value, nullptr, B*T*C, rs.DeviceProp, stream); } - matmul(dinp, weight, dout.Quant, Tensor{}, nullptr, nullptr, rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream); - matmul(dweight, inp.Quant, dout.Quant, Tensor{}, nullptr, nullptr, rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream); + matmul(dinp, weight, dout.Quant, Tensor{}, nullptr, nullptr, + rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::NN, false, stream, rs.MatmulBackend); + matmul(dweight, inp.Quant, dout.Quant, Tensor{}, nullptr, nullptr, + rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::NT, accumulate_gradient, stream, rs.MatmulBackend); if (dbias) { backward_bias(dbias, dout.Value, nullptr, nullptr, bias_buffer, B, T, OC, rs.DeviceProp, stream); @@ -307,7 +314,7 @@ void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias, transpose(weight_tp, weight, OC, C, stream); matmul(dinp, weight_tp, dout.Quant, Tensor{}, weight.scale(), 
dout.Quant.scale(), - rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::TN, false, stream); + rs.CublasLtHandle, rs.CuBlasWorkspace, C, B*T, OC, EMMTranspose::TN, false, stream, rs.MatmulBackend); rs.temp_free(weight_tp); auto activation_tp = rs.temp_alloc(inp_q.DType, {C, B*T}); @@ -322,7 +329,8 @@ void backward_qmm(Tensor& dinp, Tensor& dweight, Tensor dbias, } transpose(grad_tp, dout.Quant, B*T, OC, stream); - matmul(dweight, activation_tp, grad_tp, Tensor{}, inp_q.scale(), dout.Quant.scale(), rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::TN, accumulate_gradient, stream); + matmul(dweight, activation_tp, grad_tp, Tensor{}, inp_q.scale(), dout.Quant.scale(), + rs.CublasLtHandle, rs.CuBlasWorkspace, C, OC, B*T, EMMTranspose::TN, accumulate_gradient, stream, rs.MatmulBackend); if (dbias) { backward_bias(dbias, dout.Quant, inp_q.scale(), dout.Quant.scale(), bias_buffer, B, T, OC, rs.DeviceProp, stream); } @@ -484,7 +492,7 @@ void LLamaModel::_backward_lmhead(long B, long T, float z_loss, int micro_step, matmul(rs->Output, Parameters->get_head(main_stream), lnf_slice, Tensor{}, nullptr, nullptr, rs->CublasLtHandle, rs->CuBlasWorkspace, V, nano_batch_size, C, EMMTranspose::TN, - false, main_stream); + false, main_stream, rs->MatmulBackend); if(nano_step == 0) { // make sure Targets have been copied @@ -504,13 +512,13 @@ void LLamaModel::_backward_lmhead(long B, long T, float z_loss, int micro_step, auto& d_lmhead = Grads->get_lmhead_full(main_stream, comm, accumulate); accumulate |= nano_step != 0; matmul(d_lmhead, lnf_slice, rs->Output, Tensor{}, nullptr, nullptr, - rs->CublasLtHandle, rs->CuBlasWorkspace, C, V, nano_batch_size, EMMTranspose::NT, accumulate, main_stream); + rs->CublasLtHandle, rs->CuBlasWorkspace, C, V, nano_batch_size, EMMTranspose::NT, accumulate, main_stream, rs->MatmulBackend); if (nano_step == nano_batches - 1) { Grads->notify_lmhead(main_stream, comm); } matmul(dlnf_slice, Parameters->get_head(main_stream), 
rs->Output, Tensor{}, nullptr, nullptr, - rs->CublasLtHandle, rs->CuBlasWorkspace, C, nano_batch_size, V, EMMTranspose::NN, false, main_stream); + rs->CublasLtHandle, rs->CuBlasWorkspace, C, nano_batch_size, V, EMMTranspose::NN, false, main_stream, rs->MatmulBackend); } rs->temp_free(rs->Output); @@ -559,7 +567,7 @@ void LLamaModel::_recompute_block(sLLamaBlockWeights& weights, sLLamaLay forward_qmm(acts.QKV, acts.LN1, weights.Attn_QKV_w, weights.Attn_QKV_b, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, Config.qkv_channels(), - rs->DeviceProp, !recompute_ln1, main_stream); + rs->DeviceProp, !recompute_ln1, main_stream, rs->MatmulBackend); rope_forward(acts.QKV, acts.QKV, rs->FreqCis, nullptr, B, T, Hq, Hkv, Hs, main_stream); } @@ -571,7 +579,7 @@ void LLamaModel::_recompute_block(sLLamaBlockWeights& weights, sLLamaLay forward_qmm(acts.AttO, acts.Att, weights.Attn_Out_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, C, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); } } @@ -590,7 +598,7 @@ void LLamaModel::_recompute_block(sLLamaBlockWeights& weights, sLLamaLay forward_qmm(acts.MlpUp, acts.LN2, weights.MLP_Up_w, Tensor{}, rs->CublasLtHandle, rs->CuBlasWorkspace, B, T, C, 2 * D, - rs->DeviceProp, false, main_stream); + rs->DeviceProp, false, main_stream, rs->MatmulBackend); } if(recompute_swiglu) { @@ -911,6 +919,11 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato OptimizerRNG = std::minstd_rand{42}; RunState = std::make_unique(std::move(acts)); + if (options.UseCustomMatmul) { + RunState->MatmulBackend = EMatmulBackend::Custom; + } else { + RunState->MatmulBackend = EMatmulBackend::CuBLAS; + } comm.barrier(); // make sure *all* GPUs have allocated the model before returning } diff --git a/src/models/llama_model.h b/src/models/llama_model.h index c184320..32c0bbe 100644 --- a/src/models/llama_model.h +++ b/src/models/llama_model.h @@ -43,6 +43,7 @@ struct 
LLamaOptions { bool ShardGradients = false; bool UseAllToAllReduce = false; + bool UseCustomMatmul = false; bool InitProjectionsToZero = false; diff --git a/src/testing/test-gemm.cpp b/src/testing/test-gemm.cpp new file mode 100644 index 0000000..8a055e0 --- /dev/null +++ b/src/testing/test-gemm.cpp @@ -0,0 +1,262 @@ +// Copyright (c) 2026, IST Austria, developed by Erik Schultheis +// SPDX-License-Identifier: Apache-2.0 +// + +#include "kernels/kernels.h" +#include "utilities/utils.h" +#include +#include +#include + +#include + +#include +#include + +#include "test_config.h" + +template +extern void matmul_cublaslt(floatO* d, const floatX* a, const floatX* b, const floatB* bias, + std::byte* workspace, std::size_t workspace_size, + int m, int n, int k, cudaStream_t stream, cublasLtHandle_t handle, + const float* scale_a, const float* scale_b, EMMTranspose mode, bool accumulate); + +extern cublasLtHandle_t create_cublaslt_handle(); + +template +void run_test(int m, int n, int k, float scale = 1.f, bool accumulate = false, bool use_bias=false, bool check=true) { + Atype* a; + Btype* b; + Ctype* c; + Ctype* bias; + CUDA_CHECK(cudaMallocManaged(&a, m * k * sizeof(Atype))); + CUDA_CHECK(cudaMallocManaged(&b, n * k * sizeof(Btype))); + CUDA_CHECK(cudaMallocManaged(&c, m * n * sizeof(Ctype))); + CUDA_CHECK(cudaMallocManaged(&bias, n * sizeof(Ctype))); + CUDA_CHECK(cudaMemset(a, 0, m * k * sizeof(Atype))); + CUDA_CHECK(cudaMemset(c, 0, m * n * sizeof(Ctype))); + CUDA_CHECK(cudaMemset(b, 0, n * k * sizeof(Btype))); + CUDA_CHECK(cudaMemset(bias, 0, n * sizeof(Ctype))); + + float* a_float = nullptr; + float* b_float = nullptr; + float* c_float = nullptr; + float* bias_float = nullptr; + CUDA_CHECK(cudaMallocManaged(&a_float, m * k * sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&b_float, n * k * sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&c_float, m * n * sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&bias_float, n * sizeof(float))); + CUDA_CHECK(cudaMemset(a_float, 
0, m * k * sizeof(float))); + CUDA_CHECK(cudaMemset(b_float, 0, n * k * sizeof(float))); + CUDA_CHECK(cudaMemset(c_float, 0, m * n * sizeof(float))); + CUDA_CHECK(cudaMemset(bias_float, 0, n * sizeof(float))); + + std::mt19937 rng(12345); + std::uniform_int_distribution dist(-15, 15); + + for(int i = 0; i < m; ++i) { + for(int j = 0; j < k; ++j) { + auto val = static_cast(dist(rng)); + a[i*k+j] = val; + a_float[i*k+j] = static_cast(val); + } + } + + for(int i = 0; i < n; ++i) { + for(int j = 0; j < k; ++j) { + auto val = static_cast(dist(rng)); + b[i*k+j] = val; + b_float[i*k+j] = static_cast(val); + } + } + + for(int i = 0; i < m; ++i) { + for(int j = 0; j < n; ++j) { + auto val = static_cast(dist(rng)); + c[i*n+j] = val; + c_float[i*n+j] = static_cast(val); + } + } + + for(int i = 0; i < n; ++i) { + auto val = static_cast(dist(rng)); + bias[i] = val; + bias_float[i] = static_cast(val); + } + + float* scale_a_ptr, *scale_b_ptr; + CUDA_CHECK(cudaMallocManaged(&scale_a_ptr, sizeof(float))); + CUDA_CHECK(cudaMallocManaged(&scale_b_ptr, sizeof(float))); + *scale_a_ptr = sqrtf(scale); + *scale_b_ptr = sqrtf(scale); + + CUDA_CHECK(cudaMemPrefetchAsync(a, m*k * sizeof(Atype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(b, n*k * sizeof(Btype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(c, m*n * sizeof(Ctype), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(scale_a_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(scale_b_ptr, 4, cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + + cublasLtHandle_t handle = create_cublaslt_handle(); + std::byte* workspace; + size_t workspace_size = 128 * 1024 * 1024; + CUDA_CHECK(cudaMalloc(&workspace, workspace_size)); + + CUDA_CHECK(cudaDeviceSynchronize()); + matmul(c, a, b, use_bias ? 
bias : nullptr, scale_a_ptr, scale_b_ptr, handle, workspace, workspace_size, m, n, k, EMMTranspose::TN, accumulate, nullptr, EMatmulBackend::Custom); + CUDA_CHECK(cudaDeviceSynchronize()); + + if(check) { + CUDA_CHECK(cudaMemPrefetchAsync(a_float, m*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(b_float, n*k * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(c_float, m*n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaMemPrefetchAsync(bias_float, n * sizeof(float), cudaMemLocation{cudaMemLocationTypeDevice, 0}, 0)); + CUDA_CHECK(cudaDeviceSynchronize()); + matmul(c_float, a_float, b_float, use_bias ? bias_float : nullptr, nullptr, nullptr, + handle, workspace, workspace_size, m, n, k , EMMTranspose::TN, accumulate, nullptr, EMatmulBackend::CuBLAS); + CUDA_CHECK(cudaDeviceSynchronize()); + + double r_tol = 1e-2; + int approx_count = 0; + int far_count = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + float expected = c_float[j * m + i] * (*scale_a_ptr) * (*scale_b_ptr); + float received = (float) c[j * m + i]; + float r_tol_max = r_tol * std::max(fabsf((float) c[j * m + i]), fabsf(expected)); + float err = fabsf(expected - received); + float tol = std::max(r_tol_max, 1e-4f); + if (err > 10 * tol) { + printf(" %d %d: %f != %f\n", i, j, expected, received); + ++far_count; + } else if (err > tol) { + ++approx_count; + } + } + if (far_count > 0) { + break; + } + } + + if (far_count == 0 && approx_count == 0) { + SUCCEED(); + printf("PASS\n"); + } else if(far_count < m * n / 100 && approx_count < m * n / 10) { + SUCCEED(); + printf("CLOSE %d%% [%d+%d]\n", 100 - (approx_count + far_count) * 100 / (m * n), far_count, approx_count); + } else { + FAIL(); + } + } + + CUDA_CHECK(cudaFree(a)); + CUDA_CHECK(cudaFree(b)); + CUDA_CHECK(cudaFree(c)); + CUDA_CHECK(cudaFree(bias)); + CUDA_CHECK(cudaFree(a_float)); + 
CUDA_CHECK(cudaFree(b_float)); + CUDA_CHECK(cudaFree(c_float)); + CUDA_CHECK(cudaFree(bias_float)); + CUDA_CHECK(cudaFree(scale_a_ptr)); + CUDA_CHECK(cudaFree(scale_b_ptr)); + CUDA_CHECK(cudaFree(workspace)); + cublasLtDestroy(handle); +} + +TEST_CASE("tiny matmul bfloat16 x bfloat16 -> bfloat16", "[gemm][bf16]") { + bool accumulate = false; + bool bias = false; + SECTION("set-nobias") { + accumulate = false; + bias = false; + } + SECTION("set-bias") { + accumulate = false; + bias = true; + } + SECTION("accumulate-nobias") { + accumulate = true; + bias = false; + } + SECTION("accumulate-bias") { + accumulate = true; + bias = true; + } + run_test(128, 128, 128, 1.0f, accumulate, bias); +} + +TEST_CASE("tiny matmul fp8 x fp8 -> bfloat16", "[gemm][fp8]") { + bool accumulate = false; + bool bias = false; + SECTION("set-nobias") { + accumulate = false; + bias = false; + } + SECTION("set-bias") { + accumulate = false; + bias = true; + } + SECTION("accumulate-nobias") { + accumulate = true; + bias = false; + } + SECTION("accumulate-bias") { + accumulate = true; + bias = true; + } + + run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(128, 128, 128, 4.0f / 128, accumulate, bias); +} + +TEST_CASE("matmul bfloat16 x bfloat16 -> bfloat16", "[gemm][bf16]") { + const auto& cfg = testing_config::get_test_config(); + int m = cfg.B * cfg.T; + int k = cfg.C; + int n = div_ceil(2 * m / 3, 128) * 128; + + bool accumulate = false; + bool bias = false; + SECTION("set-nobias") { + accumulate = false; + bias = false; + } + SECTION("set-bias") { + accumulate = false; + bias = true; + } + SECTION("accumulate-nobias") { + accumulate = true; + bias = false; + } + SECTION("accumulate-bias") { + accumulate = true; + bias = true; + } + run_test(m, n, k, 1.0f, accumulate, bias); +} + +TEST_CASE("matmul fp8 x fp8 -> bfloat16", "[gemm][fp8]") { + const auto& cfg = testing_config::get_test_config(); + int m = cfg.B * cfg.T; + int k = cfg.C; + int n = div_ceil(2 * m / 3, 128) * 128; + + bool 
accumulate = false; + bool bias = false; + SECTION("set-nobias") { + accumulate = false; + bias = false; + } + SECTION("set-bias") { + accumulate = false; + bias = true; + } + SECTION("accumulate-nobias") { + accumulate = true; + bias = false; + } + SECTION("accumulate-bias") { + accumulate = true; + bias = true; + } + run_test<__nv_fp8_e4m3, __nv_fp8_e4m3, nv_bfloat16>(m, n, k, 4.0f / k, accumulate, bias); +} diff --git a/src/testing/test-rope.cu b/src/testing/test-rope.cu index d72b2e9..461f1d9 100644 --- a/src/testing/test-rope.cu +++ b/src/testing/test-rope.cu @@ -204,10 +204,10 @@ TEST_CASE("rope forward/backward bfloat16 matches CPU (emulated)", "[rope][bf16] std::vector h_inp_f = uniform_host(size_inp, -1.f, 1.f, 1337ull); std::vector h_inp_bf16 = to_bf16(h_inp_f); - // Prepare freqs and quantize to bf16 (kernel expects bf16 freqs as well) + // Prepare freqs and quantize to fp16 (kernel expects fp16 freqs) std::vector h_freqs_f(size_freqs); precompute_freqs_cis(h_freqs_f.data(), HD, T, 10000.0f); - std::vector h_freqs_bf16 = to_bf16(h_freqs_f); + std::vector h_freqs_fp16 = to_fp16(h_freqs_f); // CPU baseline with bf16 emulation: quantize inputs/freqs to bf16, do math in float, quantize outputs std::vector h_inp_q = round_bf16(h_inp_f); @@ -222,7 +222,7 @@ TEST_CASE("rope forward/backward bfloat16 matches CPU (emulated)", "[rope][bf16] thrust::device_vector d_inp = to_device(h_inp_bf16); thrust::device_vector d_out(size_inp); thrust::device_vector d_dinp(size_inp); - thrust::device_vector d_freqs = to_device(h_freqs_bf16); + thrust::device_vector d_freqs = to_device(h_freqs_fp16); rope_forward(thrust::raw_pointer_cast(d_out.data()), thrust::raw_pointer_cast(d_inp.data()), diff --git a/src/testing/test_utils.h b/src/testing/test_utils.h index dc1e8d5..5bec8b5 100644 --- a/src/testing/test_utils.h +++ b/src/testing/test_utils.h @@ -62,6 +62,13 @@ inline std::vector to_bf16(const std::vector& vec) { } return result; } +inline std::vector to_fp16(const 
std::vector& vec) { + std::vector result(vec.size()); + for(size_t i = 0; i < vec.size(); ++i) { + result[i] = half(vec[i]); + } + return result; +} inline std::vector round_bf16(const std::vector& vec) { std::vector result(vec.size()); diff --git a/src/training/model.h b/src/training/model.h index 8e83453..00b6825 100644 --- a/src/training/model.h +++ b/src/training/model.h @@ -10,10 +10,12 @@ #include #include +#include "kernels/kernels.h" #include "utilities/stack.h" #include "utilities/tensor.h" #include "training/transformer_config.h" +enum class EMatmulBackend; class AdamWStateManager; class ITensorContainer; class NCCLCommunicator; @@ -188,6 +190,7 @@ class IRunState { cudnnHandle_t CudnnHandle = nullptr; cublasLtHandle_t CublasLtHandle = nullptr; Tensor CuBlasWorkspace; + EMatmulBackend MatmulBackend = EMatmulBackend{0}; // events for debugging timings void setup_timing_events(int micro_steps); diff --git a/src/utilities/sol.cpp b/src/utilities/sol.cpp index c64ab96..a0d7036 100644 --- a/src/utilities/sol.cpp +++ b/src/utilities/sol.cpp @@ -352,14 +352,14 @@ double measure_real_peak() { } ++trip_count; matmul(c, a, b, nullptr, nullptr, nullptr, handle, workspace, 32 * 1024 * 1024, - 16384, 16384, 16384, EMMTranspose::TN, false, nullptr); + 16384, 16384, 16384, EMMTranspose::TN, false, nullptr, EMatmulBackend::CuBLAS); } // now, actual measurement CUDA_CHECK(cudaEventRecord(start_event)); for(int i = 0; i < trip_count; ++i) { matmul(c, a, b, nullptr, nullptr, nullptr, handle, workspace, 32 * 1024 * 1024, - 16384, 16384, 16384, EMMTranspose::TN, false, nullptr); + 16384, 16384, 16384, EMMTranspose::TN, false, nullptr, EMatmulBackend::CuBLAS); } CUDA_CHECK(cudaEventRecord(stop_event)); CUDA_CHECK(cudaEventSynchronize(stop_event)); diff --git a/train.cpp b/train.cpp index 8409676..eb0bf7b 100644 --- a/train.cpp +++ b/train.cpp @@ -199,6 +199,7 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { app.add_flag("--memcpy-send-recv", 
MemcpySendRecv, "Use memcpy to perform send/receive (all-to-all). Currently only supported by the threads backend."); app.add_flag("--all-to-all-reduce", Options.UseAllToAllReduce, "Uses an all-to-all-based reduce algorithm. Combine with --memcpy-send-recv."); app.add_flag("--write-combined", Options.UseWriteCombined, "Uses write-combined memory for offloaded tensors."); + app.add_flag("--custom-matmul", Options.UseCustomMatmul, "Use a self-written matmul instead of cuBLAS."); try { app.parse(argc, argv);