diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index c0f1a3c315..3469d99788 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -1359,4 +1359,8 @@ void QQMatmul::eval_cpu(const std::vector& inputs, array& out) { } } +void GatherQQMM::eval_cpu(const std::vector& inputs, array& out) { + throw std::runtime_error("[GatherQQMM] NYI"); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/device/qmm_naive.cuh b/mlx/backend/cuda/device/qmm_naive.cuh index 01e5f444d5..e2ad855b9f 100644 --- a/mlx/backend/cuda/device/qmm_naive.cuh +++ b/mlx/backend/cuda/device/qmm_naive.cuh @@ -2,6 +2,7 @@ #include "mlx/backend/cuda/device/cute_dequant.cuh" #include "mlx/backend/cuda/device/gemm_sm70.cuh" +#include "mlx/backend/cuda/device/utils.cuh" #include @@ -25,6 +26,7 @@ CUTE_DEVICE void qmm_naive_mainloop( TensorS gS, TensorZ gZ, TensorC gC, + const float* global_scale, int m_max_coord, int n_max_coord, int k_residue, @@ -32,6 +34,7 @@ CUTE_DEVICE void qmm_naive_mainloop( // Get the types of operands. using Element = typename decltype(gA)::value_type; using Quant = typename decltype(gB)::value_type; + using Scale = typename decltype(gS)::value_type; // Shift tensor so we handle residue of K in the 0th tile. gA = domain_offset(make_coord(0, k_residue, 0), gA); @@ -196,6 +199,15 @@ CUTE_DEVICE void qmm_naive_mainloop( CUTE_UNROLL for (int i = 0; i < size(tCrC); ++i) { if ((get<0>(tCcC(i)) < m_max_coord) && (get<1>(tCcC(i)) < n_max_coord)) { + if constexpr ( + cuda::std::is_same_v && + cuda::std::is_same_v) { + // Only nvfp4 supports global scale. + if (global_scale) { + tCgC(i) = Element(tCrC(i) * (*global_scale / (F8E4M3_MAX * F4E2M1_MAX))); + continue; + } + } tCgC(i) = Element(tCrC(i)); } } @@ -224,6 +236,7 @@ void qmm_naive_kernel( const Quant* B, const Scale* S, const Element* Z, + const float* global_scale, const uint32_t* lhs_indices, const uint32_t* rhs_indices, Element* C, @@ -295,6 +308,7 @@ void qmm_naive_kernel( gS, gZ, gC, + global_scale, m_max_coord, n_max_coord, k_residue, thread_idx); } diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index 8d998cda40..64cfecfd5a 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -74,6 +74,7 @@ void qmm_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, const std::optional& lhs_indices, const std::optional& rhs_indices, array& out, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu index cb47d7f1aa..c7c5c6049a 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu @@ -29,6 +29,7 @@ void qmm_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, const std::optional& lhs_indices, const std::optional& rhs_indices, array& out, @@ -75,6 +76,9 @@ void qmm_naive( if (biases) { encoder.set_input_array(*biases); } + if (global_scale) { + encoder.set_input_array(*global_scale); + } if (lhs_indices) { encoder.set_input_array(*lhs_indices); } @@ -103,6 +107,7 @@ void qmm_naive( gpu_ptr(w), gpu_ptr(scales), biases ? gpu_ptr(*biases) : nullptr, + global_scale ? gpu_ptr(*global_scale) : nullptr, lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, gpu_ptr(out), diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index eaec2ac8f4..196bd9c05e 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -21,7 +21,7 @@ std::tuple quantize_input( QuantizationMode mode, int bits, int group_size, - std::optional global_scale = std::nullopt) { + std::optional global_scale) { const array x = ensure_contiguous(input, encoder, s); // Compute output shapes @@ -52,6 +52,27 @@ std::tuple quantize_input( return {std::move(x_q), std::move(scales_x)}; } +array quantize_dequantize_input( + const array& x_pre, + const std::optional& global_scale, + int bits, + int group_size, + cu::CommandEncoder& encoder, + Stream s) { + bool donate_x = x_pre.is_donatable(); + array x = ensure_row_contiguous(x_pre, encoder, s); + // If x is a copy it should be donatable + donate_x |= x.is_donatable(); + auto xhat = donate_x + ? x + : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype()); + if (!donate_x) { + encoder.add_temporary(xhat); + } + fp_quantize_dequantize(x, xhat, group_size, bits, global_scale, encoder, s); + return xhat; +} + GemmScalars create_nvfp4_scalars( const array& global_scale_x, const array& global_scale_w, @@ -75,77 +96,81 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& encoder = cu::get_command_encoder(s); auto& device = encoder.device(); - bool w_quantized = (inputs[1].dtype() == uint32); + + const array& x_pre = inputs[0]; + const array& w_pre = inputs[1]; + + out.set_data(cu::malloc_async(out.nbytes(), encoder)); // - 2 inputs: x, w (non-quantized w) // - 3 inputs: x, w, scales_w (quantized w) + bool w_quantized = (w_pre.dtype() == uint32); int base_size = w_quantized ? 3 : 2; - assert( - inputs.size() == base_size || - (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2)); - // For nvfp4, global scales are optional but must be both present or both // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) bool has_global_scales = - mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; - std::optional global_scale_x = std::nullopt; - std::optional global_scale_w = std::nullopt; + mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2; + assert(inputs.size() == base_size || has_global_scales); + + std::optional global_scale_x; + std::optional global_scale_w; if (has_global_scales) { global_scale_x = inputs[inputs.size() - 2]; global_scale_w = inputs[inputs.size() - 1]; } - if (w_quantized && inputs[0].shape(-2) == 1) { - out.set_data(cu::malloc_async(out.nbytes(), encoder)); - - bool donate_x = inputs[0].is_donatable(); - array x = ensure_row_contiguous(inputs[0], encoder, s); - // If x is a copy it should be donatable - donate_x |= x.is_donatable(); - auto xhat = donate_x - ? x - : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype()); - if (!donate_x) { - encoder.add_temporary(xhat); + // Quantize weights. + auto [w_q, scales_w] = !w_quantized + ? quantize_input( + w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w) + : std::make_tuple( + ensure_contiguous(w_pre, encoder, s), + ensure_contiguous(inputs[base_size - 1], encoder, s)); + + // Reroute to qmm when: no support in cuBLAS, or doing GEMV. + bool can_use_cublas = + (mode_ == QuantizationMode::Nvfp4 || mode_ == QuantizationMode::Mxfp8) && + (device.compute_capability_major() >= 10); + int M = x_pre.shape(-2); + bool use_qmm = (!can_use_cublas) || (M == 1); + + if (use_qmm) { + array x = quantize_dequantize_input( + x_pre, global_scale_x, bits_, group_size_, encoder, s); + if (M < 8) { + qmv(x, + w_q, + scales_w, + std::nullopt, + global_scale_w, + out, + bits_, + group_size_, + mode_, + encoder); + } else { + qmm_naive( + x, + w_q, + scales_w, + std::nullopt, + global_scale_w, + std::nullopt, + std::nullopt, + out, + true, // transpose + bits_, + group_size_, + mode_, + encoder); } - fp_quantize_dequantize( - x, xhat, group_size_, bits_, global_scale_x, encoder, s); - - const array& w = inputs[1]; - const array& scales = inputs[2]; - qmv(xhat, - w, - scales, - std::nullopt, - global_scale_w, - out, - bits_, - group_size_, - mode_, - encoder); return; } - auto cc = device.compute_capability_major() * 100 + - device.compute_capability_minor() * 10; - if (cc < 1000) { - throw std::runtime_error( - "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher."); - } - - // Quantize inputs (or use pre-quantized) - auto [x_q, scale_x_pre] = quantize_input( - inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x); - auto [w_q, scale_w_pre] = !w_quantized - ? quantize_input( - inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w) - : std::make_tuple( - ensure_contiguous(inputs[1], encoder, s), - ensure_contiguous(inputs[2], encoder, s)); - - out.set_data(cu::malloc_async(out.nbytes(), encoder)); + // Quantize activation. + auto [x_q, scales_x] = quantize_input( + x_pre, encoder, s, mode_, bits_, group_size_, global_scale_x); - int M = x_q.shape(-2); int N = w_q.shape(-2); // transposed int K = x_q.shape(-1) * (32 / bits_); @@ -155,8 +180,8 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { int64_t ldb = K; // Repack scales to tiled layout for tensor cores - array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s); - array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s); + scales_x = pad_and_swizzle_scales(scales_x, encoder, s); + scales_w = pad_and_swizzle_scales(scales_w, encoder, s); GemmScalars scalars; if (has_global_scales) { @@ -175,10 +200,69 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { out, x_q, w_q, - scale_x, - scale_w, + scales_x, + scales_w, mode_, scalars); } +void GatherQQMM::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("QQMatmul::eval_gpu"); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + const array& x_pre = inputs[0]; + const array& w_pre = inputs[1]; + const array& lhs_indices = ensure_row_contiguous(inputs[2], encoder, s); + const array& rhs_indices = ensure_row_contiguous(inputs[3], encoder, s); + + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + + // - 4 inputs: x, w, lhs_indices, rhs_indices (non-quantized w) + // - 5 inputs: x, w, lhs_indices, rhs_indices, scales_w (quantized w) + bool w_quantized = (w_pre.dtype() == uint32); + int base_size = w_quantized ? 5 : 4; + // For nvfp4, global scales are optional but must be both present or both + // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) + bool has_global_scales = + mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2; + assert(inputs.size() == base_size || has_global_scales); + + std::optional global_scale_x; + std::optional global_scale_w; + if (has_global_scales) { + global_scale_x = inputs[inputs.size() - 2]; + global_scale_w = inputs[inputs.size() - 1]; + } + + // Quantize weights. + auto [w_q, scales_w] = !w_quantized + ? quantize_input( + w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w) + : std::make_tuple( + ensure_contiguous(w_pre, encoder, s), + ensure_contiguous(inputs[base_size - 1], encoder, s)); + + // Quantize activation. + array x = quantize_dequantize_input( + x_pre, global_scale_x, bits_, group_size_, encoder, s); + + // Reroute to qmm. + qmm_naive( + x, + w_q, + scales_w, + std::nullopt, + global_scale_w, + lhs_indices, + rhs_indices, + out, + true, // transpose + bits_, + group_size_, + mode_, + encoder); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 2a1a268c91..4d25f3c3e0 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -72,6 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { biases, std::nullopt, std::nullopt, + std::nullopt, out, transpose_, bits_, @@ -211,6 +212,7 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { w, scales, biases, + std::nullopt, lhs_indices, rhs_indices, out, diff --git a/mlx/backend/metal/kernels/fp4.h b/mlx/backend/metal/kernels/fp4.h index 25642f2016..6922181577 100644 --- a/mlx/backend/metal/kernels/fp4.h +++ b/mlx/backend/metal/kernels/fp4.h @@ -1,5 +1,8 @@ #pragma once +constant constexpr float F8E4M3_MAX = 448.0f; +constant constexpr float F4E2M1_MAX = 6.0f; + struct fp4_e2m1 { fp4_e2m1(float x) { if (metal::isnan(x)) { diff --git a/mlx/backend/metal/kernels/fp_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h index 8d6740db5b..d43ae914e2 100644 --- a/mlx/backend/metal/kernels/fp_quantized.h +++ b/mlx/backend/metal/kernels/fp_quantized.h @@ -321,10 +321,11 @@ METAL_FUNC void fp_qmv_quad_impl( } } -template +template METAL_FUNC void fp_qmv_fast_impl( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, device T* y, const constant int& in_vec_size, @@ -374,18 +375,28 @@ METAL_FUNC void fp_qmv_fast_impl( x += block_size; } + float inv_scale_enc = 1.0f; + if constexpr (has_global_scale) { + inv_scale_enc = *global_scale / (F8E4M3_MAX * F4E2M1_MAX); + } + for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { - y[row] = static_cast(result[row]); + if constexpr (has_global_scale) { + y[row] = static_cast(result[row] * inv_scale_enc); + } else { + y[row] = static_cast(result[row]); + } } } } -template +template METAL_FUNC void fp_qmv_impl( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, device T* y, const constant int& in_vec_size, @@ -421,6 +432,11 @@ METAL_FUNC void fp_qmv_impl( return; } + float inv_scale_enc = 1.0f; + if constexpr (has_global_scale) { + inv_scale_enc = *global_scale / (F8E4M3_MAX * F4E2M1_MAX); + } + // In this case we need to properly guard all our reads because there isn't // even 1 tile in the matrix if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { @@ -471,7 +487,11 @@ METAL_FUNC void fp_qmv_impl( row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { - y[row] = static_cast(result[row]); + if constexpr (has_global_scale) { + y[row] = static_cast(result[row] * inv_scale_enc); + } else { + y[row] = static_cast(result[row]); + } } } } @@ -519,7 +539,11 @@ METAL_FUNC void fp_qmv_impl( for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { - y[row] = static_cast(result[row]); + if constexpr (has_global_scale) { + y[row] = static_cast(result[row] * inv_scale_enc); + } else { + y[row] = static_cast(result[row]); + } } } } @@ -644,10 +668,11 @@ METAL_FUNC void fp_qmv_wide_impl( } } -template +template METAL_FUNC void fp_qvm_impl( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, device T* y, const int in_vec_size, @@ -690,6 +715,11 @@ METAL_FUNC void fp_qvm_impl( return; } + float inv_scale_enc = 1.0f; + if constexpr (has_global_scale) { + inv_scale_enc = *global_scale / (F8E4M3_MAX * F4E2M1_MAX); + } + // Loop over in_vec in blocks of block_size int remaining = in_vec_size % block_size; if (remaining == 0) { @@ -739,7 +769,11 @@ METAL_FUNC void fp_qvm_impl( if (simd_lid == 0) { #pragma clang loop unroll(full) for (int k = 0; k < tn * pack_factor; k++) { - y[k] = static_cast(result[k]); + if constexpr (has_global_scale) { + y[k] = static_cast(result[k] * inv_scale_enc); + } else { + y[k] = static_cast(result[k]); + } } } } @@ -872,11 +906,11 @@ METAL_FUNC void fp_qmm_t_impl( template < typename T, - const int group_size, - const int bits, - const int BM = 32, - const int BK = 32, - const int BN = 32> + int group_size, + int bits, + int BM = 32, + int BK = 32, + int BN = 32> METAL_FUNC void fp_qmm_n_impl( const device uint32_t* w, const device uint8_t* scales, @@ -1128,10 +1162,16 @@ template w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid); } -template +template < + typename T, + int group_size, + int bits, + bool batched, + bool has_global_scale = false> [[kernel]] void fp_qmv_fast( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, device T* y, const constant int& in_vec_size, @@ -1163,14 +1203,29 @@ template s_strides, tid); } - fp_qmv_fast_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); + fp_qmv_fast_impl( + w, + scales, + global_scale, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); } -template +template < + typename T, + int group_size, + int bits, + bool batched, + bool has_global_scale = false> [[kernel]] void fp_qmv( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, device T* y, const constant int& in_vec_size, @@ -1202,8 +1257,17 @@ template s_strides, tid); } - fp_qmv_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); + fp_qmv_impl( + w, + scales, + global_scale, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); } template < @@ -1251,10 +1315,16 @@ template < w, scales, x, y, in_vec_size, out_vec_size, M, tid, simd_gid, simd_lid); } -template +template < + typename T, + int group_size, + int bits, + bool batched, + bool has_global_scale = false> [[kernel]] void fp_qvm( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, device T* y, const constant int& in_vec_size, @@ -1286,9 +1356,10 @@ template s_strides, tid); } - fp_qvm_impl( + fp_qvm_impl( w, scales, + global_scale, x, y, in_vec_size, @@ -1341,9 +1412,10 @@ template // The in_vec_stride is the full K dimension, not the partition size int in_vec_stride = (split_k - 1) * in_vec_size + final_block_size; - fp_qvm_impl( + fp_qvm_impl( w, scales, + nullptr, x, y, in_vec_size_adj, @@ -1411,12 +1483,13 @@ template < template < typename T, - const int group_size, - const int bits, - const bool batched, - const int BM = 32, - const int BK = 32, - const int BN = 32> + int group_size, + int bits, + bool batched, + bool has_global_scale = false, + int BM = 32, + int BK = 32, + int BN = 32> [[kernel]] void fp_qmm_n( const device uint32_t* w, const device uint8_t* scales, @@ -1465,10 +1538,11 @@ template < w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } -template +template [[kernel]] void fp_gather_qmv_fast( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, @@ -1510,14 +1584,24 @@ template w_strides, s_strides, tid); - fp_qmv_fast_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); + fp_qmv_fast_impl( + w, + scales, + global_scale, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); } -template +template [[kernel]] void fp_gather_qmv( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, @@ -1559,14 +1643,24 @@ template w_strides, s_strides, tid); - fp_qmv_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); + fp_qmv_impl( + w, + scales, + global_scale, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); } -template +template [[kernel]] void fp_gather_qvm( const device uint32_t* w, const device uint8_t* scales, + const device float* global_scale, const device T* x, const device uint32_t* lhs_indices, const device uint32_t* rhs_indices, @@ -1608,9 +1702,10 @@ template w_strides, s_strides, tid); - fp_qvm_impl( + fp_qvm_impl( w, scales, + global_scale, x, y, in_vec_size, @@ -1741,11 +1836,12 @@ template < template < typename T, - const int group_size, - const int bits, - const int BM = 32, - const int BK = 32, - const int BN = 32> + int group_size, + int bits, + bool has_global_scale = false, + int BM = 32, + int BK = 32, + int BN = 32> [[kernel]] void fp_gather_qmm_n( const device uint32_t* w, const device uint8_t* scales, @@ -1996,38 +2092,47 @@ template < } } -template +template [[kernel]] void fp_quantize( const device T* w [[buffer(0)]], device uint8_t* out [[buffer(1)]], device uint8_t* scales [[buffer(2)]], + const device float* global_scale [[buffer(3)]], uint2 tidx [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { constexpr bool use_mx_scale = group_size == 32; size_t index = tidx.x + grid_dim.x * size_t(tidx.y); - float scale; + float scale_enc = 1.0f; + if constexpr (has_global_scale) { + scale_enc = (F8E4M3_MAX * F4E2M1_MAX) / *global_scale; + } + + float scale_dec_b; float w_thread = w[index]; if (use_mx_scale) { - scale = simd_max(abs(w_thread)); + scale_dec_b = simd_max(abs(w_thread)); } else { float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); - scale = tidx.x < 16 ? w_max_l : w_max_r; + scale_dec_b = tidx.x < 16 ? w_max_l : w_max_r; + } + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + if constexpr (has_global_scale) { + scale_dec_b *= scale_enc; } - scale /= bits == 4 ? 6.0f : 448.0f; using ScaleType = metal::conditional_t; - auto s = ScaleType(scale); + auto s = ScaleType(scale_dec_b); uint8_t q_scale = s.bits; - scale = float(s); - size_t gindex = index / group_size; if (index % group_size == 0) { scales[gindex] = q_scale; } - uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); + float scale_enc_b = float(s); + scale_enc_b = (scale_enc_b == 0) ? 0.0f : (scale_enc / scale_enc_b); + uint8_t output = Quantize{}(w_thread * scale_enc_b); if (bits == 4) { uint8_t sval = simd_shuffle_down(output, 1); output |= sval << bits; @@ -2038,10 +2143,11 @@ template } } -template +template [[kernel]] void fp_dequantize( const device uint8_t* w [[buffer(0)]], const device uint8_t* scales [[buffer(1)]], + const device float* global_scale [[buffer(2)]], device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { @@ -2053,9 +2159,17 @@ template out += oindex; + float inv_scale_enc = 1.0f; + if constexpr (has_global_scale) { + inv_scale_enc = *global_scale / (F8E4M3_MAX * F4E2M1_MAX); + } + using ScaleType = metal::conditional_t; auto q_scale = ((device ScaleType*)(scales))[gindex]; auto scale = float(q_scale); + if constexpr (has_global_scale) { + scale *= inv_scale_enc; + } uint val = w[offset]; #pragma clang loop unroll(full) @@ -2070,31 +2184,40 @@ template } } -template +template [[kernel]] void fp_quantize_dequantize( const device T* w [[buffer(0)]], - device T* out [[buffer(1)]], + const device float* global_scale [[buffer(1)]], + device T* out [[buffer(2)]], uint2 tidx [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { constexpr bool use_mx_scale = group_size == 32; size_t index = tidx.x + grid_dim.x * size_t(tidx.y); - float scale; + float scale_enc = 1.0f; + if constexpr (has_global_scale) { + scale_enc = (F8E4M3_MAX * F4E2M1_MAX) / *global_scale; + } + + float scale_dec_b; float w_thread = w[index]; if (use_mx_scale) { - scale = simd_max(abs(w_thread)); + scale_dec_b = simd_max(abs(w_thread)); } else { float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); - scale = tidx.x < 16 ? w_max_l : w_max_r; + scale_dec_b = tidx.x < 16 ? w_max_l : w_max_r; + } + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + if constexpr (has_global_scale) { + scale_dec_b *= scale_enc; } - scale /= bits == 4 ? 6.0f : 448.0f; using ScaleType = metal::conditional_t; - auto s = ScaleType(scale); - scale = float(s); + auto scale = float(ScaleType(scale_dec_b)); - uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); + float scale_enc_b = (scale == 0) ? 0.0f : (scale_enc / scale); + uint8_t output = Quantize{}(w_thread * scale_enc_b); - out[index] = static_cast(scale * Dequantize{}(output)); + out[index] = static_cast((scale / scale_enc) * Dequantize{}(output)); } diff --git a/mlx/backend/metal/kernels/fp_quantized.metal b/mlx/backend/metal/kernels/fp_quantized.metal index 76980164f4..7404f2023f 100644 --- a/mlx/backend/metal/kernels/fp_quantized.metal +++ b/mlx/backend/metal/kernels/fp_quantized.metal @@ -9,19 +9,34 @@ #define instantiate_quantized(mode, name, type, group_size, bits) \ instantiate_kernel( \ #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits, \ - fp_ ## name, \ - type, \ - group_size, \ - bits) + fp_ ## name, \ + type, \ + group_size, \ + bits) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_hgs", \ + fp_ ## name, \ + type, \ + group_size, \ + bits, \ + true) #define instantiate_quantized_batched(mode, name, type, batched, group_size, bits) \ instantiate_kernel( \ #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \ fp_ ## name, \ - type, \ - group_size, \ - bits, \ - batched) + type, \ + group_size, \ + bits, \ + batched) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched "_hgs", \ + fp_ ## name, \ + type, \ + group_size, \ + bits, \ + batched, \ + true) #define instantiate_quantized_aligned(mode, name, type, aligned, group_size, bits) \ instantiate_kernel( \ @@ -137,25 +152,28 @@ instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true, mode, group_size, bits) \ instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false, mode, group_size, bits) -#define instantiate_quantize_dequantize(type, mode, group_size, bits) \ - instantiate_kernel( \ - #mode "_quantize_dequantize_" #type "_gs_" #group_size "_b_" #bits, \ +#define instantiate_quantize_dequantize(type, mode, group_size, bits, has_global_scale) \ + instantiate_kernel( \ + #mode "_quantize_dequantize_" #type "_gs_" #group_size "_b_" #bits "_hgs_" #has_global_scale, \ fp_quantize_dequantize, \ - type, \ - group_size, \ - bits) \ - instantiate_kernel( \ - #mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \ - fp_quantize, \ - type, \ - group_size, \ - bits) \ - instantiate_kernel( \ - #mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \ - fp_dequantize, \ - type, \ - group_size, \ - bits) + type, \ + group_size, \ + bits, \ + has_global_scale) \ + instantiate_kernel( \ + #mode "_quantize_" #type "_gs_" #group_size "_b_" #bits "_hgs_" #has_global_scale, \ + fp_quantize, \ + type, \ + group_size, \ + bits, \ + has_global_scale) \ + instantiate_kernel( \ + #mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits "_hgs_" #has_global_scale, \ + fp_dequantize, \ + type, \ + group_size, \ + bits, \ + has_global_scale) #define instantiate_quantized_modes(type, mode, group_size, bits) \ instantiate_quantized_all_batched(type, mode, group_size, bits) \ @@ -164,13 +182,16 @@ instantiate_quantized_all_wide(type, mode, group_size, bits) \ instantiate_quantized_all_splitk(type, mode, group_size, bits) \ instantiate_quantized_all_aligned(type, mode, group_size, bits) \ - instantiate_quantized_all_rhs(type, mode, group_size, bits) \ - instantiate_quantize_dequantize(type, mode, group_size, bits) + instantiate_quantized_all_rhs(type, mode, group_size, bits) #define instantiate_quantized_types(type) \ instantiate_quantized_modes(type, nvfp4, 16, 4) \ instantiate_quantized_modes(type, mxfp8, 32, 8) \ - instantiate_quantized_modes(type, mxfp4, 32, 4) + instantiate_quantized_modes(type, mxfp4, 32, 4) \ + instantiate_quantize_dequantize(type, nvfp4, 16, 4, false) \ + instantiate_quantize_dequantize(type, nvfp4, 16, 4, true) \ + instantiate_quantize_dequantize(type, mxfp8, 32, 8, false) \ + instantiate_quantize_dequantize(type, mxfp4, 32, 4, false) \ instantiate_quantized_types(float) instantiate_quantized_types(bfloat16_t) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 345dfc711e..ee00590266 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1591,7 +1591,12 @@ template quad_lid); } -template +template < + typename T, + int group_size, + int bits, + bool batched, + bool has_global_scale = false> [[kernel]] void affine_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1643,7 +1648,12 @@ template simd_lid); } -template +template < + typename T, + int group_size, + const int bits, + bool batched, + bool has_global_scale = false> [[kernel]] void affine_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1754,7 +1764,12 @@ template < simd_lid); } -template +template < + typename T, + const int group_size, + const int bits, + bool batched, + bool has_global_scale = false> [[kernel]] void affine_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -2001,12 +2016,13 @@ template < template < typename T, - const int group_size, - const int bits, - const bool batched, - const int BM = 32, - const int BK = 32, - const int BN = 32> + int group_size, + int bits, + bool batched, + bool has_global_scale = false, + int BM = 32, + int BK = 32, + int BN = 32> [[kernel]] void affine_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -2059,7 +2075,7 @@ template < w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } -template +template [[kernel]] void affine_gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -2121,7 +2137,7 @@ template simd_lid); } -template +template [[kernel]] void affine_gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -2183,7 +2199,7 @@ template simd_lid); } -template +template [[kernel]] void affine_gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -2330,11 +2346,12 @@ template < template < typename T, - const int group_size, - const int bits, - const int BM = 32, - const int BK = 32, - const int BN = 32> + int group_size, + int bits, + bool has_global_scale = false, + int BM = 32, + int BK = 32, + int BN = 32> [[kernel]] void affine_gather_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -2592,7 +2609,7 @@ template < } } -template +template [[kernel]] void affine_quantize( const device T* w [[buffer(0)]], device uint8_t* out [[buffer(1)]], @@ -2697,7 +2714,7 @@ template } } -template +template [[kernel]] void affine_dequantize( const device uint8_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 62d48714e0..5f601fb44e 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -134,7 +134,7 @@ inline int add_strides_and_shapes( const std::optional& biases, int offset) { if (skip) { - return 0; + return offset; } // TODO: Collapse batch dimensions @@ -172,6 +172,209 @@ inline int add_gather_strides_and_shapes( return offset; } +auto get_quantize_kernel_dims( + MTL::ComputePipelineState* kernel, + const array& w, + const array& out, + int group_size, + int bits, + bool dequantize = false) { + // Treat uint32 as uint8 in kernel + constexpr int uint8_per_uint32 = 4; + constexpr int simd_size = 32; + int packs_per_int = (bits == 3 || bits == 5) ? 8 : bits == 6 ? 4 : 8 / bits; + int per_thread = + dequantize ? packs_per_int : std::max(group_size / simd_size, 1); + size_t nthreads = + dequantize ? out.size() / packs_per_int : w.size() / per_thread; + + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + auto group_dims = MTL::Size(thread_group_size, 1, 1); + bool use_2d = nthreads > UINT_MAX; + auto grid_shape = w.shape(); + if (dequantize) { + grid_shape.back() *= uint8_per_uint32; + } else { + grid_shape.back() /= per_thread; + } + MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides()) + : MTL::Size(nthreads, 1, 1); + return std::make_tuple(grid_dims, group_dims); +} + +void quantize_impl( + const std::vector& inputs, + std::vector& outputs, + QuantizationMode mode, + int group_size, + int bits, + bool dequantize, + Stream s) { + auto& w_pre = inputs[0]; + auto& out = outputs[0]; + out.set_data(allocator::malloc(out.nbytes())); + + auto& d = metal::device(s.device); + auto& compute_encoder = metal::get_command_encoder(s); + + bool has_biases = (mode == QuantizationMode::Affine); + bool has_global_scale = !has_biases && (inputs.size() > (1 + dequantize)); + + auto w = ensure_row_contiguous(w_pre, d, s); + if (dequantize) { + auto scales = ensure_row_contiguous(inputs[1], d, s); + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + if (has_biases) { + auto biases = ensure_row_contiguous(inputs[2], d, s); + compute_encoder.set_input_array(biases, 2); + } else if (has_global_scale) { + compute_encoder.set_input_array(inputs[2], 2); + } + compute_encoder.set_output_array(out, 3); + } else { + auto& scales = outputs[1]; + scales.set_data(allocator::malloc(scales.nbytes())); + compute_encoder.set_input_array(w, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_output_array(scales, 2); + if (has_biases) { + auto& biases = outputs[2]; + biases.set_data(allocator::malloc(biases.nbytes())); + compute_encoder.set_output_array(biases, 3); + } else if (has_global_scale) { + compute_encoder.set_input_array(inputs[1], 3); + } + } + + auto type_string = dequantize ? get_type_string(out.dtype()) + : get_type_string(w_pre.dtype()); + auto mode_string = quantization_mode_to_string(mode); + std::string kname; + concatenate( + kname, + mode_string + (dequantize ? "_dequantize" : "_quantize"), + "_", + type_string, + "_gs_", + group_size, + "_b_", + bits); + if (!has_biases) { + concatenate(kname, "_hgs_", has_global_scale ? "true" : "false"); + } + auto kernel = get_quantized_kernel_wrapped( + d, + kname, + dequantize ? "dequantize" : "quantize", + mode_string, + type_string, + group_size, + bits, + has_global_scale); + + auto [grid_dims, group_dims] = + get_quantize_kernel_dims(kernel, w, out, group_size, bits, dequantize); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + +auto quantize_input( + const array& w, + const std::optional& global_scale, + QuantizationMode mode, + int group_size, + int bits, + metal::Device& d, + Stream s) { + auto wq_shape = w.shape(); + wq_shape.back() = w.shape(-1) * bits / 32; + auto scales_shape = w.shape(); + scales_shape.back() = w.shape(-1) / group_size; + + std::vector inputs{w}; + if (global_scale) { + inputs.push_back(*global_scale); + } + std::vector outputs{ + array(wq_shape, uint32, nullptr, {}), + array(scales_shape, uint8, nullptr, {})}; + auto& compute_encoder = metal::get_command_encoder(s); + compute_encoder.add_temporary(outputs[0]); + compute_encoder.add_temporary(outputs[1]); + quantize_impl(inputs, outputs, mode, group_size, bits, false, s); + return std::make_tuple(outputs[0], outputs[1]); +} + +void fp_quantize_dequantize( + const array& in, + const std::optional& global_scale, + array& out, + const std::string& mode, + int group_size, + int bits, + metal::Device& d, + const Stream& s) { + auto& compute_encoder = metal::get_command_encoder(s); + + auto w = ensure_row_contiguous(in, d, s); + compute_encoder.set_input_array(w, 0); + if (global_scale) { + compute_encoder.set_input_array(*global_scale, 1); + } + compute_encoder.set_output_array(out, 2); + auto type_string = get_type_string(in.dtype()); + std::string kname; + concatenate( + kname, + mode + "_quantize_dequantize_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_hgs_", + global_scale ? "true" : "false"); + auto kernel = get_quantized_kernel_wrapped( + d, + kname, + "quantize_dequantize", + mode, + type_string, + group_size, + bits, + global_scale.has_value()); + + auto [grid_dims, group_dims] = + get_quantize_kernel_dims(kernel, w, out, group_size, bits); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + +array quantize_dequantize_input( + const array& x_pre, + const std::optional& global_scale, + const std::string& mode, + int group_size, + int bits, + metal::Device& d, + Stream s) { + bool donate_x = x_pre.is_donatable(); + array x = ensure_row_contiguous(x_pre, d, s); + // If x is a copy it should be donatable + donate_x |= x.is_donatable(); + auto xhat = + donate_x ? x : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype()); + if (!donate_x) { + metal::get_command_encoder(s).add_temporary(xhat); + } + fp_quantize_dequantize(x, global_scale, xhat, mode, group_size, bits, d, s); + return xhat; +} + } // namespace void qmv_quad( @@ -237,6 +440,7 @@ void qmv( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, array& out, int group_size, int bits, @@ -266,7 +470,8 @@ void qmv( group_size, "_b_", bits, - B > 1 ? "_batch_1" : "_batch_0"); + B > 1 ? "_batch_1" : "_batch_0", + global_scale ? "_hgs" : ""); auto kernel = get_quantized_kernel_wrapped( d, kname, @@ -275,17 +480,20 @@ void qmv( type_string, group_size, bits, - B > 1); + B > 1, + global_scale.has_value()); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); - int c = 0; - compute_encoder.set_input_array(w, c++); - compute_encoder.set_input_array(scales, c++); + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); if (biases) { - compute_encoder.set_input_array(*biases, c++); + compute_encoder.set_input_array(*biases, 2); + } else if (global_scale) { + compute_encoder.set_input_array(*global_scale, 2); } + int c = 3; compute_encoder.set_input_array(x, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); @@ -511,6 +719,7 @@ void qvm( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, array& out, int group_size, int bits, @@ -539,18 +748,29 @@ void qvm( group_size, "_b_", bits, - B > 1 ? "_batch_1" : "_batch_0"); + B > 1 ? "_batch_1" : "_batch_0", + global_scale ? "_hgs" : ""); auto kernel = get_quantized_kernel_wrapped( - d, kname, "qvm", mode, type_string, group_size, bits, B > 1); + d, + kname, + "qvm", + mode, + type_string, + group_size, + bits, + B > 1, + global_scale.has_value()); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); - int c = 0; - compute_encoder.set_input_array(w, c++); - compute_encoder.set_input_array(scales, c++); + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); if (biases) { - compute_encoder.set_input_array(*biases, c++); + compute_encoder.set_input_array(*biases, 2); + } else if (global_scale) { + compute_encoder.set_input_array(*global_scale, 2); } + int c = 3; compute_encoder.set_input_array(x, c++); compute_encoder.set_output_array(out, c++); compute_encoder.set_bytes(K, c++); @@ -1054,6 +1274,7 @@ void gather_qmv( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, const array& lhs_indices, const array& rhs_indices, array& out, @@ -1083,7 +1304,8 @@ void gather_qmv( "_gs_", group_size, "_b_", - bits); + bits, + global_scale ? "_hgs" : ""); auto kernel = get_quantized_kernel_wrapped( d, @@ -1092,17 +1314,20 @@ void gather_qmv( mode, type_string, group_size, - bits); + bits, + global_scale.has_value()); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); - int c = 0; - compute_encoder.set_input_array(w, c++); - compute_encoder.set_input_array(scales, c++); + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); if (biases) { - compute_encoder.set_input_array(*biases, c++); + compute_encoder.set_input_array(*biases, 2); + } else if (global_scale) { + compute_encoder.set_input_array(*global_scale, 2); } + int c = 3; compute_encoder.set_input_array(x, c++); compute_encoder.set_input_array(lhs_indices, c++); compute_encoder.set_input_array(rhs_indices, c++); @@ -1120,6 +1345,7 @@ void gather_qvm( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, const array& lhs_indices, const array& rhs_indices, array& out, @@ -1149,18 +1375,28 @@ void gather_qvm( "_gs_", group_size, "_b_", - bits); + bits, + global_scale ? "_hgs" : ""); auto kernel = get_quantized_kernel_wrapped( - d, kname, "gather_qvm", mode, type_string, group_size, bits); + d, + kname, + "gather_qvm", + mode, + type_string, + group_size, + bits, + global_scale.has_value()); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); - int c = 0; - compute_encoder.set_input_array(w, c++); - compute_encoder.set_input_array(scales, c++); + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); if (biases) { - compute_encoder.set_input_array(*biases, c++); + compute_encoder.set_input_array(*biases, 2); + } else if (global_scale) { + compute_encoder.set_input_array(*global_scale, 2); } + int c = 3; compute_encoder.set_input_array(x, c++); compute_encoder.set_input_array(lhs_indices, c++); compute_encoder.set_input_array(rhs_indices, c++); @@ -1459,6 +1695,7 @@ void dispatch_qmv( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, array& out, int group_size, int bits, @@ -1469,18 +1706,31 @@ void dispatch_qmv( const Stream& s, const std::string& mode) { // It is a qmv with a small inner dimension so route to qmv_quad kernel - if ((K == 128 || K == 64) && is_power_of_2(bits)) { + if ((K == 128 || K == 64) && is_power_of_2(bits) && !global_scale) { qmv_quad(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode); return; } // Small batch so route to qmv_wide, which reuses each weight group across the // M vectors. - if (M >= 2 && use_qmv_wide(mode, d)) { + if (M >= 2 && use_qmv_wide(mode, d) && !global_scale) { qmv_wide(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode); return; } - qmv(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode); + qmv(x, + w, + scales, + biases, + global_scale, + out, + group_size, + bits, + M, + N, + K, + d, + s, + mode); } void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { @@ -1536,13 +1786,39 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { // Run of the mill qmv if (transpose_) { dispatch_qmv( - x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); + x, + w, + scales, + biases, + std::nullopt, + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode); return; } // Run of the mill qvm if (K < 1024) { - qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); + qvm(x, + w, + scales, + biases, + std::nullopt, + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode); return; } @@ -1628,6 +1904,7 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { w, scales, biases, + std::nullopt, lhs_indices, rhs_indices, out, @@ -1647,6 +1924,7 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { w, scales, biases, + std::nullopt, lhs_indices, rhs_indices, out, @@ -1660,194 +1938,133 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { mode); } -void quantize_dequantize( - const array& in, - array& out, - std::string mode, - int group_size, - int bits, - metal::Device& d, - const Stream& s) { - auto& compute_encoder = metal::get_command_encoder(s); - - auto w = ensure_row_contiguous(in, d, s); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_output_array(out, 1); - auto type_string = get_type_string(in.dtype()); - std::string kname; - concatenate( - kname, - mode + "_quantize_dequantize_", - type_string, - "_gs_", - group_size, - "_b_", - bits); - auto kernel = get_quantized_kernel_wrapped( - d, kname, "quantize_dequantize", mode, type_string, group_size, bits); - - compute_encoder.set_compute_pipeline_state(kernel); - - constexpr int uint8_per_uint32 = 4; - constexpr int simd_size = 32; - int packs_per_int = (bits == 3 || bits == 5) ? 8 : bits == 6 ? 4 : 8 / bits; - int per_thread = std::max(group_size / simd_size, 1); - size_t nthreads = w.size() / per_thread; - - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - auto group_dims = MTL::Size(thread_group_size, 1, 1); - bool use_2d = nthreads > UINT_MAX; - auto grid_shape = w.shape(); - grid_shape.back() /= per_thread; - MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides()) - : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); -} - void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); + const array& x_pre = inputs[0]; + const array& w_pre = inputs[1]; auto mode = quantization_mode_to_string(mode_); + + out.set_data(allocator::malloc(out.nbytes())); + + // - 2 inputs: x, w (non-quantized w) + // - 3 inputs: x, w, scales_w (quantized w) bool w_quantized = (inputs[1].dtype() == uint32); - // Tensor-scale nvfp4 (global_scale_x / global_scale_w) is packed into - // inputs by ops.cpp but no Metal qqmm kernel currently consumes the - // global scales. Reject the request rather than silently dropping them - // in the gemv path below. int base_size = w_quantized ? 3 : 2; - if (mode_ == QuantizationMode::Nvfp4 && - static_cast(inputs.size()) > base_size) { - throw std::runtime_error( - "[QQMatmul] Global scale (tensor-scale nvfp4) is not supported " - "on the Metal backend."); - } - if (w_quantized && inputs[0].shape(-2) == 1) { - out.set_data(allocator::malloc(out.nbytes())); - - bool donate_x = inputs[0].is_donatable(); - array x = ensure_row_contiguous(inputs[0], d, s); - // If x is a copy it should be donatable - donate_x |= x.is_donatable(); - auto xhat = donate_x - ? x - : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype()); - quantize_dequantize(x, xhat, mode, group_size_, bits_, d, s); - - // Make sure the last two dims of w and s are contiguous - array w = ensure_row_contiguous_matrix(inputs[1], d, s); - array scales = ensure_row_contiguous_matrix(inputs[2], d, s); - - bool non_batched = w.ndim() == 2; - int K = x.shape(-1); - int M = non_batched ? x.size() / K : x.shape(-2); - int N = out.shape(-1); - dispatch_qmv( - xhat, - w, - scales, - std::nullopt, - out, - group_size_, - bits_, - M, - N, - K, - d, - s, - mode); - return; - } else { - throw std::runtime_error("[QQMatmul] NYI for the general case"); + // For nvfp4, global scales are optional but must be both present or both + // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) + bool has_global_scales = + mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2; + assert(inputs.size() == base_size || has_global_scales); + + std::optional global_scale_x; + std::optional global_scale_w; + if (has_global_scales) { + global_scale_x = inputs[inputs.size() - 2]; + global_scale_w = inputs[inputs.size() - 1]; } -} -void fast::Quantize::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - auto& w_pre = inputs[0]; - auto& out = outputs[0]; - out.set_data(allocator::malloc(out.nbytes())); + // Quantize weights. + auto [w_q, scales_w] = !w_quantized + ? quantize_input(w_pre, global_scale_w, mode_, group_size_, bits_, d, s) + : std::make_tuple( + ensure_row_contiguous_matrix(w_pre, d, s), + ensure_row_contiguous_matrix(inputs[base_size - 1], d, s)); + + // Quantize activation. + array x = quantize_dequantize_input( + x_pre, global_scale_x, mode, group_size_, bits_, d, s); + bool non_batched = w_q.ndim() == 2; + int K = x.shape(-1); + int M = non_batched ? x.size() / K : x.shape(-2); + int N = out.shape(-1); + dispatch_qmv( + x, + w_q, + scales_w, + std::nullopt, + global_scale_w, + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode); +} + +void GatherQQMM::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); - auto& compute_encoder = metal::get_command_encoder(s); - auto w = ensure_row_contiguous(w_pre, d, s); - if (dequantize_) { - auto scales = ensure_row_contiguous(inputs[1], d, s); - if (mode_ == QuantizationMode::Affine) { - auto biases = ensure_row_contiguous(inputs[2], d, s); - compute_encoder.set_input_array(biases, 2); - } - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_output_array(out, 3); - } else { - auto& scales = outputs[1]; - scales.set_data(allocator::malloc(scales.nbytes())); - if (mode_ == QuantizationMode::Affine) { - auto& biases = outputs[2]; - biases.set_data(allocator::malloc(biases.nbytes())); - compute_encoder.set_output_array(biases, 3); - } - compute_encoder.set_input_array(w, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_output_array(scales, 2); + const array& x_pre = inputs[0]; + const array& w_pre = inputs[1]; + const array& lhs_indices = ensure_row_contiguous(inputs[2], d, s); + const array& rhs_indices = ensure_row_contiguous(inputs[3], d, s); + auto mode = quantization_mode_to_string(mode_); + + out.set_data(allocator::malloc(out.nbytes())); + + // - 4 inputs: x, w (non-quantized w) + // - 5 inputs: x, w, scales_w (quantized w) + bool w_quantized = (inputs[1].dtype() == uint32); + int base_size = w_quantized ? 5 : 4; + // For nvfp4, global scales are optional but must be both present or both + // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) + bool has_global_scales = + mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2; + assert(inputs.size() == base_size || has_global_scales); + + std::optional global_scale_x; + std::optional global_scale_w; + if (has_global_scales) { + global_scale_x = inputs[inputs.size() - 2]; + global_scale_w = inputs[inputs.size() - 1]; } - auto type_string = dequantize_ ? get_type_string(out.dtype()) - : get_type_string(w_pre.dtype()); - auto mode = quantization_mode_to_string(mode_); - std::string kname; - concatenate( - kname, - mode + (dequantize_ ? "_dequantize" : "_quantize"), - "_", - type_string, - "_gs_", - group_size_, - "_b_", - bits_); - auto kernel = get_quantized_kernel_wrapped( - d, - kname, - dequantize_ ? "dequantize" : "quantize", - mode, - type_string, - group_size_, - bits_); + // Quantize weights. + auto [w_q, scales_w] = !w_quantized + ? quantize_input(w_pre, global_scale_w, mode_, group_size_, bits_, d, s) + : std::make_tuple( + ensure_row_contiguous_matrix(w_pre, d, s), + ensure_row_contiguous_matrix(inputs[base_size - 1], d, s)); - compute_encoder.set_compute_pipeline_state(kernel); + // Quantize activation. + array x = quantize_dequantize_input( + x_pre, global_scale_x, mode, group_size_, bits_, d, s); - // Treat uint32 as uint8 in kernel - constexpr int uint8_per_uint32 = 4; - constexpr int simd_size = 32; - int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8 - : bits_ == 6 ? 4 - : 8 / bits_; - int per_thread = - dequantize_ ? packs_per_int : std::max(group_size_ / simd_size, 1); - size_t nthreads = - dequantize_ ? out.size() / packs_per_int : w.size() / per_thread; + bool non_batched = w_q.ndim() == 2; + int K = x.shape(-1); + int M = non_batched ? x.size() / K : x.shape(-2); + int N = out.shape(-1); + gather_qmv( + x, + w_q, + scales_w, + std::nullopt, + global_scale_w, + lhs_indices, + rhs_indices, + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode); +} - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - auto group_dims = MTL::Size(thread_group_size, 1, 1); - bool use_2d = nthreads > UINT_MAX; - auto grid_shape = w.shape(); - if (dequantize_) { - grid_shape.back() *= uint8_per_uint32; - } else { - grid_shape.back() /= per_thread; - } - MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides()) - : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); +void fast::Quantize::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + quantize_impl( + inputs, outputs, mode_, group_size_, bits_, dequantize_, stream()); } void fast::ConvertFP8::eval_gpu( diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index ae51dd9b2f..faaeb0c7c4 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -71,6 +71,7 @@ NO_CPU(Gather) NO_CPU(GatherAxis) NO_CPU(GatherMM) NO_CPU(GatherQMM) +NO_CPU(GatherQQMM) NO_CPU(Greater) NO_CPU(GreaterEqual) NO_CPU(Hadamard) diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..0e05e9d19f 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -98,6 +98,7 @@ NO_GPU(Gather) NO_GPU(GatherAxis) NO_GPU(GatherMM) NO_GPU(GatherQMM) +NO_GPU(GatherQQMM) NO_GPU(Greater) NO_GPU(GreaterEqual) NO_GPU(Hadamard) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 849b8081dd..b397fc6d80 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -65,11 +65,26 @@ Dtype at_least_float(const Dtype& d) { } array indices_or_default( - std::optional indices, + std::string_view tag, + const std::optional& indices, const array& x, StreamOrDevice s) { + if (x.ndim() < 2) { + std::ostringstream msg; + msg << tag + << " Input must have at least two dimensions but got input with shape " + << x.shape() << "."; + throw std::invalid_argument(msg.str()); + } + if (indices.has_value()) { - return indices.value(); + if (!issubdtype(indices->dtype(), integer)) { + std::ostringstream msg; + msg << tag + << " Got indices with invalid dtype. Indices must be integral."; + throw std::invalid_argument(msg.str()); + } + return astype(indices.value(), uint32); } Shape shape(x.shape().begin(), x.shape().end() - 2); @@ -4626,10 +4641,10 @@ void validate_global_scale( } array quantized_matmul( - array x, - array w, - array scales, - std::optional biases /* = std::nullopt */, + const array& x, + const array& w, + const array& scales, + const std::optional& biases /* = std::nullopt */, bool transpose /* = true */, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, @@ -4678,13 +4693,13 @@ array quantized_matmul( } void validate_qqmm_inputs( - array x, - array w, - std::optional scales_w, + const array& x, + const array& w, + const std::optional& scales_w, int group_size, int bits, - std::optional global_scale_x, - std::optional global_scale_w, + const std::optional& global_scale_x, + const std::optional& global_scale_w, QuantizationMode qmode) { // check 2D (for now) if (x.ndim() > 2 || w.ndim() > 2) { @@ -4738,9 +4753,9 @@ void validate_qqmm_inputs( } std::pair extract_qqmm_dims( - array x, - array w, - std::optional scales_w, + const array& x, + const array& w, + const std::optional& scales_w, int group_size, int bits) { if (w.dtype() != uint32) { @@ -4768,24 +4783,17 @@ std::pair extract_qqmm_dims( } array qqmm( - array in_x, - array w, - std::optional scales_w, + const array& in_x, + const array& w, + const std::optional& scales_w, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, - const std::optional global_scale_x /* = std::nullopt */, - const std::optional global_scale_w /* = std::nullopt */, + const std::optional& global_scale_x /* = std::nullopt */, + const std::optional& global_scale_w /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto stream = to_stream(s); auto qmode = string_to_quantization_mode(mode, "qqmm"); - // cuBLAS block scaled matmul only supports nvfp4 and mxfp8 - if (qmode != QuantizationMode::Nvfp4 && qmode != QuantizationMode::Mxfp8) { - std::ostringstream msg; - msg << "[qqmm] Only 'nvfp4' and 'mxfp8' quantization modes are supported but '" - << mode << "' was provided."; - throw std::invalid_argument(msg.str()); - } // we need to check 2 cases: // 1. w is quantized, scales is provided // 2. w is not quantized, scales is not provided @@ -5088,12 +5096,6 @@ std::vector quantize( << " matrix has shape " << w.shape(); throw std::invalid_argument(msg.str()); } - if (to_stream(s).device == Device::gpu && metal::is_available() && - global_scale.has_value()) { - std::ostringstream msg; - msg << "[quantize] Global scale is not supported on the Metal backend."; - throw std::invalid_argument(msg.str()); - } validate_global_scale("quantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { return affine_quantize(w, group_size, bits, s); @@ -5353,13 +5355,6 @@ array dequantize( << "but it has only " << w.ndim() << "."; throw std::invalid_argument(msg.str()); } - if (global_scale.has_value()) { - if (to_stream(s).device == Device::gpu && metal::is_available()) { - std::ostringstream msg; - msg << "[dequantize] Global scale is not supported on the Metal backend."; - throw std::invalid_argument(msg.str()); - } - } validate_global_scale("dequantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { @@ -5452,30 +5447,11 @@ array gather_qmm( } // Extract indices and broadcast them - array lhs_indices = indices_or_default(lhs_indices_, x, s); - array rhs_indices = indices_or_default(rhs_indices_, w, s); + array lhs_indices = indices_or_default("[gather_qmm]", lhs_indices_, x, s); + array rhs_indices = indices_or_default("[gather_qmm]", rhs_indices_, w, s); std::tie(lhs_indices, rhs_indices) = broadcast_arrays(lhs_indices, rhs_indices, s); - if (!issubdtype(lhs_indices.dtype(), integer)) { - throw std::invalid_argument( - "[gather_qmm] Got lhs_indices with invalid dtype. Indices must be integral."); - } - - if (!issubdtype(rhs_indices.dtype(), integer)) { - throw std::invalid_argument( - "[gather_qmm] Got rhs_indices with invalid dtype. Indices must be integral."); - } - if (x.ndim() < 2) { - std::ostringstream msg; - msg << "[gather_qmm] Non-quantized input must have at least two" - << " dimensions but got input with shape " << x.shape() << "."; - throw std::invalid_argument(msg.str()); - } - - lhs_indices = astype(lhs_indices, uint32, s); - rhs_indices = astype(rhs_indices, uint32, s); - // Compute the full output shape auto out_shape = lhs_indices.shape(); out_shape.push_back(x.shape(-2)); @@ -5511,6 +5487,56 @@ array gather_qmm( std::move(inputs)); } +array gather_qqmm( + const array& x, + const array& w, + const std::optional& scales_w, + const std::optional& lhs_indices_, + const std::optional& rhs_indices_, + std::optional group_size_, + std::optional bits_, + const std::string& mode, + const std::optional& global_scale_x, + const std::optional& global_scale_w, + bool sorted_indices, + StreamOrDevice s) { + auto stream = to_stream(s); + auto qmode = string_to_quantization_mode(mode, "gather_qqmm"); + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); + + // Extract indices and broadcast them + array lhs_indices = indices_or_default("[gather_qqmm]", lhs_indices_, x, s); + array rhs_indices = indices_or_default("[gather_qqmm]", rhs_indices_, w, s); + std::tie(lhs_indices, rhs_indices) = + broadcast_arrays(lhs_indices, rhs_indices, s); + + std::vector inputs = { + x, + w, + lhs_indices, + rhs_indices, + }; + if (scales_w.has_value()) { + inputs.push_back(*scales_w); + } + if (global_scale_x.has_value() && global_scale_w.has_value()) { + inputs.push_back(*global_scale_x); + inputs.push_back(*global_scale_w); + } + + auto [w_inner_dims, w_outer_dims] = + extract_qqmm_dims(x, w, scales_w, group_size, bits); + auto out_shape = lhs_indices.shape(); + out_shape.push_back(x.shape(-2)); + out_shape.push_back(w_outer_dims); + return array( + std::move(out_shape), + x.dtype(), + std::make_shared(stream, group_size, bits, qmode), + std::move(inputs)); +} + array tensordot( const array& a, const array& b, @@ -6009,28 +6035,14 @@ array gather_mm( b = astype(b, out_type, s); // Handle broadcasting - array lhs_indices = indices_or_default(lhs_indices_, a, s); - array rhs_indices = indices_or_default(rhs_indices_, b, s); - - if (!issubdtype(lhs_indices.dtype(), integer)) { - throw std::invalid_argument( - "[gather_mm] Got lhs_indices with invalid dtype. Indices must be integral."); - } - - if (!issubdtype(rhs_indices.dtype(), integer)) { - throw std::invalid_argument( - "[gather_mm] Got rhs_indices with invalid dtype. Indices must be integral."); - } - - lhs_indices = astype(lhs_indices, uint32, s); - rhs_indices = astype(rhs_indices, uint32, s); + array lhs_indices = indices_or_default("[gather_mm]", lhs_indices_, a, s); + array rhs_indices = indices_or_default("[gather_mm]", rhs_indices_, b, s); + std::tie(lhs_indices, rhs_indices) = + broadcast_arrays(lhs_indices, rhs_indices, s); int M = a.shape(-2); int N = b.shape(-1); - std::tie(lhs_indices, rhs_indices) = - broadcast_arrays(lhs_indices, rhs_indices, s); - auto out_shape = lhs_indices.shape(); out_shape.push_back(M); out_shape.push_back(N); diff --git a/mlx/ops.h b/mlx/ops.h index ed617c3441..20568024f6 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1546,10 +1546,10 @@ MLX_API array conv_transpose3d( /** Quantized matmul multiplies x with a quantized matrix w*/ MLX_API array quantized_matmul( - array x, - array w, - array scales, - std::optional biases = std::nullopt, + const array& x, + const array& w, + const array& scales, + const std::optional& biases = std::nullopt, bool transpose = true, std::optional group_size = std::nullopt, std::optional bits = std::nullopt, @@ -1578,15 +1578,15 @@ MLX_API array dequantize( StreamOrDevice s = {}); MLX_API array qqmm( - array x, // input activations - array w, // maybe quantized weights - const std::optional w_scales = std::nullopt, // optional scales if w - // is quantized + const array& x, // input activations + const array& w, // maybe quantized weights + const std::optional& w_scales = std::nullopt, // optional scales if w + // is quantized std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "nvfp4", - const std::optional global_scale_x = std::nullopt, - const std::optional global_scale_w = std::nullopt, + const std::optional& global_scale_x = std::nullopt, + const std::optional& global_scale_w = std::nullopt, StreamOrDevice s = {}); /** Convert an E4M3 float8 to the given floating point dtype. */ @@ -1610,6 +1610,20 @@ MLX_API array gather_qmm( bool sorted_indices = false, StreamOrDevice s = {}); +MLX_API array gather_qqmm( + const array& x, + const array& w, + const std::optional& scales_w = std::nullopt, + const std::optional& lhs_indices = std::nullopt, + const std::optional& rhs_indices = std::nullopt, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "nvfp4", + const std::optional& global_scale_x = std::nullopt, + const std::optional& global_scale_w = std::nullopt, + bool sorted_indices = false, + StreamOrDevice s = {}); + /** Returns a contraction of a and b over multiple dimensions. */ MLX_API array tensordot( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 6ac88eef6c..c4d1df2ee7 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3825,6 +3825,22 @@ std::vector GatherQMM::output_shapes(const std::vector& inputs) { return {out_shape}; } +bool GatherQQMM::is_equivalent(const Primitive& other) const { + const GatherQQMM& qm_other = static_cast(other); + return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && + mode_ == qm_other.mode_; +} + +std::vector GatherQQMM::output_shapes(const std::vector& inputs) { + const auto& x = inputs[0]; + const auto& w = inputs[1]; + const auto& lhs_indices = inputs[2]; + auto out_shape = lhs_indices.shape(); + out_shape.push_back(x.shape(-2)); + out_shape.push_back(w.shape(-2)); + return {out_shape}; +} + std::pair, std::vector> RandomBits::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 5b8517c56d..403490dbe8 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1716,6 +1716,41 @@ class GatherQMM : public UnaryPrimitive { bool right_sorted_; }; +class GatherQQMM : public UnaryPrimitive { + public: + explicit GatherQQMM( + Stream stream, + int group_size, + int bits, + QuantizationMode mode, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(GatherQQMM) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple( + group_size_, bits_, mode_, left_sorted_, right_sorted_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool left_sorted_; + bool right_sorted_; +}; + class RandomBits : public UnaryPrimitive { public: explicit RandomBits(Stream stream, const Shape& shape, int width) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 19b19bd78f..62b3a1458d 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4801,6 +4801,58 @@ void init_ops(nb::module_& m) { array: The result of the multiplication of ``x`` with ``w`` after gathering using ``lhs_indices`` and ``rhs_indices``. )pbdoc"); + m.def( + "gather_qqmm", + &mx::gather_qqmm, + nb::arg(), + nb::arg(), + "scales"_a = nb::none(), + "lhs_indices"_a = nb::none(), + "rhs_indices"_a = nb::none(), + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "nvfp4", + "global_scale_x"_a = nb::none(), + "global_scale_w"_a = nb::none(), + nb::kw_only(), + "sorted_indices"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def gather_qqmm(x: array, w: array, /, scales: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', global_scale_x: Optional[array] = None, global_scale_w: Optional[array] = None, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Fused :func:`qqmm` with matrix-level gather. + + Similar to :func:`gather_mm`, the indices ``lhs_indices`` and + ``rhs_indices`` contain flat indices along the batch dimensions (i.e. + all but the last two dimensions) of ``x`` and ``w`` respectively. + + Args: + x (array): Input array. + w (array): Weight matrix. If quantized, it is packed in unsigned integers. + scales (array, optional): The scales to use per ``group_size`` elements of + ``w`` if ``w`` is quantized. Default: ``None``. + lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. + rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. + group_size (int, optional): Number of elements in ``x`` and ``w`` that + share a scale. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + bits (int, optional): Number of bits used to represent each element of + ``x`` and ``w``. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + mode (str, optional): The quantization mode. Default: ``"nvfp4"``. + Supported modes are ``nvfp4`` and ``mxfp8``. See the + :ref:`table of quantization modes ` for details. + global_scale_x (array, optional): The per-input float32 scale used for x + with ``"nvfp4"`` quantization. Default: ``None``. + global_scale_w (array, optional): The per-input float32 scale used for w + with ``"nvfp4"`` quantization. Default: ``None``. + sorted_indices (bool, optional): May allow a faster implementation + if the passed indices are sorted. Default: ``False``. + + Returns: + array: The result of the multiplication of quantized ``x`` with quantized ``w``. + needed). + )pbdoc"); m.def( "segmented_mm", &mx::segmented_mm, diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 2850c7c357..40b24beb03 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -161,11 +161,7 @@ def test_nvfp4_quantize_dequantize(self): self.assertTrue(mx.all(w_hat == 0)) # Test nvfp4 quantize/dequantize with tensor-scale global_scale - # currently supported only on cpu and cuda - if not mx.metal.is_available(): - global_scale = w.abs().max().astype(mx.float32) - else: - global_scale = None + global_scale = w.abs().max().astype(mx.float32) w_q, scales = mx.quantize(w, mode="nvfp4", global_scale=global_scale) w_hat = mx.dequantize( @@ -178,7 +174,7 @@ def test_qqmv(self): k1, k2 = mx.random.split(key) tests = product( [256, 512, 67], # M - [64, 256], # N + [64, 256, 512], # N ["nvfp4", "mxfp8"], # mode ) for M, N, mode in tests: @@ -186,12 +182,8 @@ def test_qqmv(self): x_shape = (1, N) w_shape = (M, N) - # TODO: Fix qmv with global scale in Metal/CPU backends. - has_global_scale = ( - mode == "nvfp4" - and mx.cuda.is_available() - and mx.default_device() == mx.gpu - ) + # TODO: Fix qmv with global scale in CPU backend. + has_global_scale = mode == "nvfp4" and mx.default_device() == mx.gpu x = mx.random.normal(shape=x_shape, key=k1) global_scale_x = mx.max(mx.abs(x)) if has_global_scale else None @@ -224,31 +216,54 @@ def test_qqmv(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) - def test_qqmm_metal_global_scale_rejected(self): - # Tensor-scale nvfp4 (global_scale_x / global_scale_w) is not - # implemented in the Metal qqmm kernels. mx.qqmm must reject the - # request on Metal rather than silently dropping the global scales - # in the gemv path and producing incorrect results. - if not mx.metal.is_available(): + def test_qqmm(self): + if mx.default_device() == mx.cpu: + self.skipTest("Not implemented for CPU") return - w = mx.random.normal(shape=(64, 64)) - w_q, scales = mx.quantize(w, mode="nvfp4") - x = mx.random.normal(shape=(1, 64)) - gx = mx.array(1.0, dtype=mx.float32) - gw = mx.array(1.0, dtype=mx.float32) + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + tests = product( + [8, 32, 33, 64], # M + [128, 256], # N + [128, 256], # K + ["nvfp4", "mxfp8"], # mode + ) + for M, N, K, mode in tests: + with self.subTest(shape=(M, N, K), mode=mode): + x_shape = (M, K) + w_shape = (N, K) - with self.assertRaises(RuntimeError): - y = mx.qqmm( - x, - w_q, - scales, - mode="nvfp4", - global_scale_x=gx, - global_scale_w=gw, - stream=mx.gpu, - ) - mx.eval(y) + x = mx.random.normal(shape=x_shape, key=k1) + global_scale_x = mx.max(mx.abs(x)) if mode == "nvfp4" else None + x_hat = mx.dequantize( + *mx.quantize(x, mode=mode, global_scale=global_scale_x), + mode=mode, + dtype=mx.float32, + global_scale=global_scale_x, + ) + + w = mx.random.normal(shape=w_shape, key=k2) + global_scale_w = mx.max(mx.abs(w)) if mode == "nvfp4" else None + w_q, scales = mx.quantize(w, mode=mode, global_scale=global_scale_w) + w_hat = mx.dequantize( + w_q, + scales, + mode=mode, + global_scale=global_scale_w, + dtype=mx.float32, + ) + y_q = mx.qqmm( + x, + w_q, + scales, + mode=mode, + global_scale_x=global_scale_x, + global_scale_w=global_scale_w, + ) + y_hat = x_hat @ mx.swapaxes(w_hat, -1, -2) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) def test_qmm(self): key = mx.random.key(0) @@ -1072,6 +1087,100 @@ def test_shape( test_shape(32, 512, 32, transpose=False, **kwargs) test_shape(1, 512, 32, transpose=False, **kwargs) + def test_gather_qqmm(self): + if mx.default_device() == mx.cpu: + self.skipTest("Not implemented for CPU") + return + + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + batches = ( + { + "batch_A": (1,), + "lhs_indices": (0,), + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (1,), + "lhs_indices": None, + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (2,), + "lhs_indices": None, + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (3,), + "lhs_indices": (0, 2), + "batch_B": (1,), + "rhs_indices": (0,), + }, + { + "batch_A": (5,), + "lhs_indices": (0, 2), + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + ) + tests = product( + batches, + [1, 32], # M + [32, 256], # N + [32, 256], # K + ["nvfp4", "mxfp8"], # mode + ) + + for batch, M, N, K, mode in tests: + with self.subTest(shape=(M, N, K), mode=mode, **batch): + batch_A, lhs_indices, batch_B, rhs_indices = batch.values() + x_shape = (*batch_A, M, K) + w_shape = (*batch_B, N, K) + + x = mx.random.normal(shape=x_shape, key=k1) + global_scale_x = mx.max(mx.abs(x)) if mode == "nvfp4" else None + x_hat = mx.dequantize( + *mx.quantize(x, mode=mode, global_scale=global_scale_x), + mode=mode, + dtype=mx.float32, + global_scale=global_scale_x, + ) + + w = mx.random.normal(shape=w_shape, key=k2) + global_scale_w = mx.max(mx.abs(w)) if mode == "nvfp4" else None + w_q, scales = mx.quantize(w, mode=mode, global_scale=global_scale_w) + w_hat = mx.dequantize( + w_q, + scales, + mode=mode, + global_scale=global_scale_w, + dtype=mx.float32, + ) + + if lhs_indices is not None: + lhs_indices = mx.array(lhs_indices) + if rhs_indices is not None: + rhs_indices = mx.array(rhs_indices) + + y_q = mx.gather_qqmm( + x, + w_q, + scales, + lhs_indices, + rhs_indices, + mode=mode, + global_scale_x=global_scale_x, + global_scale_w=global_scale_w, + ) + y_hat = mx.gather_mm( + x_hat, mx.swapaxes(w_hat, -1, -2), lhs_indices, rhs_indices + ) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_qmm_fp_type(self): indices = mx.array([[2], [0], [1]], dtype=mx.uint32)