Skip to content

feat: add bf16 group gemm kernel#34

Open
ZelinMa557 wants to merge 1 commit intoTencent:mainfrom
ZelinMa557:bf16_group_gemm
Open

feat: add bf16 group gemm kernel#34
ZelinMa557 wants to merge 1 commit intoTencent:mainfrom
ZelinMa557:bf16_group_gemm

Conversation

@ZelinMa557
Copy link

This pr implement bf16 group gemm, which follows the pattern of fp8 kernels in hpc-ops.

I use the following script to benchmark this kernel against sglang's triton group gemm kernel, and torch naive implementation:

"""
Benchmark script for hpc.group_gemm_bf16 vs sglang fused_moe_kernel (BF16).

Usage:
    python tests/bench_group_gemm_bf16.py
    python tests/bench_group_gemm_bf16.py --E 128 --N 384 --K 4096
    python tests/bench_group_gemm_bf16.py --E 8 --N 4096 --K 7168 --warmup 10 --iters 100
"""

import argparse
import os
import sys
from pathlib import Path

sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0]))

import torch
import triton.language as tl
import hpc
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import invoke_fused_moe_kernel
from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import moe_align_block_size

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--E", type=int, default=8, help="Number of groups (experts)")
    p.add_argument("--N", type=int, default=4096, help="Output dimension per group")
    p.add_argument("--K", type=int, default=7168, help="Input dimension (hidden size)")
    p.add_argument("--warmup", type=int, default=5, help="Warmup iterations before graph capture")
    p.add_argument("--iters", type=int, default=100, help="Graph replay iterations for timing")
    return p.parse_args()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

M_VALUES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]


def make_inputs(E, m_per_group, N, K, device="cuda"):
    dtype = torch.bfloat16
    total_m = E * m_per_group
    x = torch.randn((total_m, K), dtype=dtype, device=device)
    w = torch.randn((E, N, K), dtype=dtype, device=device)
    seqlens = torch.full((E,), m_per_group, dtype=torch.int32, device=device)
    cu_seqlens = torch.zeros(E + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    output = torch.empty((total_m, N), dtype=dtype, device=device)
    return x, w, seqlens, cu_seqlens, output


def tflops(E, m_per_group, N, K, elapsed_ms):
    flops = 2 * E * m_per_group * N * K
    return flops / (elapsed_ms * 1e-3) / 1e12


# ---------------------------------------------------------------------------
# sglang default BF16 config
# (mirrors get_default_config logic for dtype=None, no server args needed)
# ---------------------------------------------------------------------------


def sglang_bf16_config(total_m, E):
    """
    Replicates sglang's get_default_config logic for plain BF16 (dtype=None).
      M <= E  ->  small-batch config (BLOCK_SIZE_M=16)
      M >  E  ->  regular config    (BLOCK_SIZE_M=64)
    num_warps / num_stages are typical Hopper-friendly defaults.
    """
    if total_m <= E:
        return {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
            "num_stages": 3,
        }
    else:
        return {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_warps": 4,
            "num_stages": 3,
        }


# ---------------------------------------------------------------------------
# Benchmark with CUDA graph
# ---------------------------------------------------------------------------


def bench_cuda_graph(fn, warmup, iters):
    """Warm up, capture a CUDA graph, replay `iters` times, return avg ms."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        fn()  # dry run on capture stream
        torch.cuda.synchronize()
        with torch.cuda.graph(g, stream=capture_stream):
            fn()

    torch.cuda.synchronize()

    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)
    t0.record()
    for _ in range(iters):
        g.replay()
    t1.record()
    torch.cuda.synchronize()
    return t0.elapsed_time(t1) / iters


# ---------------------------------------------------------------------------
# Per-kernel bench helpers
# ---------------------------------------------------------------------------


def bench_hpc(x, w, seqlens, cu_seqlens, output, mean_seq, warmup, iters):
    def fn():
        hpc.group_gemm_bf16(
            x, w, seqlens, cu_seqlens, num_seq_per_group_avg=mean_seq, output=output
        )

    return bench_cuda_graph(fn, warmup, iters)


def bench_torch_ref(x, w, seqlens, cu_seqlens, output, warmup, iters):
    E = seqlens.shape[0]
    slices = [(int(cu_seqlens[i].item()), int(cu_seqlens[i + 1].item())) for i in range(E)]

    def fn():
        for i, (s, e) in enumerate(slices):
            output[s:e] = torch.matmul(x[s:e], w[i].t())

    return bench_cuda_graph(fn, warmup, iters)


def bench_sglang(x, w, E, m_per_group, N, warmup, iters):
    """
    Benchmark sglang's invoke_fused_moe_kernel for a single up/gate projection
    (top_k=1, each token routes to its group, no routed-weight multiply).

    moe_align_block_size is called once outside the graph; only the GEMM
    kernel itself is captured and replayed.
    """
    total_m = E * m_per_group
    config = sglang_bf16_config(total_m, E)

    # top_k=1: token i belongs to expert floor(i / m_per_group)
    topk_ids = (
        (torch.arange(total_m, device="cuda") // m_per_group).view(-1, 1).to(torch.int32)
    )  # [total_m, 1]
    topk_weights = torch.ones(total_m, 1, dtype=torch.bfloat16, device="cuda")

    # Routing tensors – computed once, reused by every graph replay
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )

    # Output shape follows sglang convention: [padded_tokens, N]
    sgl_out = torch.empty((sorted_token_ids.shape[0], N), dtype=torch.bfloat16, device="cuda")

    def fn():
        invoke_fused_moe_kernel(
            A=x,
            B=w,
            bias=None,
            C=sgl_out,
            A_scale=None,
            B_scale=None,
            B_zp=None,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            sorted_token_ids=sorted_token_ids,
            expert_ids=expert_ids,
            num_tokens_post_padded=num_tokens_post_padded,
            mul_routed_weight=False,
            top_k=1,
            config=config,
            compute_type=tl.bfloat16,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            per_channel_quant=False,
            block_shape=None,
        )

    return bench_cuda_graph(fn, warmup, iters)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main():
    args = parse_args()
    E, N, K = args.E, args.N, args.K

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    prop = torch.cuda.get_device_properties(0)
    print(f"\nDevice : {prop.name}")
    print(f"Config : E={E}, N={N}, K={K}")
    print(f"Timing : warmup={args.warmup}, iters={args.iters} (CUDA graph replay)\n")

    hdr = (
        f"{'M/group':>8}  {'total_M':>8}  "
        f"{'hpc(ms)':>10}  {'hpc(TF)':>9}  "
        f"{'sgl(ms)':>10}  {'sgl(TF)':>9}  "
        f"{'ref(ms)':>10}  {'ref(TF)':>9}  "
        f"{'hpc/sgl':>8}"
    )
    sep = "-" * len(hdr)
    print(hdr)
    print(sep)

    for m in M_VALUES:
        x, w, seqlens, cu_seqlens, output = make_inputs(E, m, N, K)

        hpc_ms = bench_hpc(x, w, seqlens, cu_seqlens, output, m, args.warmup, args.iters)
        sgl_ms = bench_sglang(x, w, E, m, N, args.warmup, args.iters)
        ref_ms = bench_torch_ref(x, w, seqlens, cu_seqlens, output, args.warmup, args.iters)

        hpc_tf = tflops(E, m, N, K, hpc_ms)
        sgl_tf = tflops(E, m, N, K, sgl_ms)
        ref_tf = tflops(E, m, N, K, ref_ms)

        print(
            f"{m:>8}  {E*m:>8}  "
            f"{hpc_ms:>10.4f}  {hpc_tf:>9.2f}  "
            f"{sgl_ms:>10.4f}  {sgl_tf:>9.2f}  "
            f"{ref_ms:>10.4f}  {ref_tf:>9.2f}  "
            f"{sgl_ms/hpc_ms:>8.2f}x"
        )

    print(sep)
    print()


if __name__ == "__main__":
    main()

Result for Qwen3-235b-a22b:
TP8 gate+up:

python tests/bench_group_gemm_bf16.py --E 128 --N 384 --K 4096
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device : NVIDIA H20
Config : E=128, N=384, K=4096
Timing : warmup=5, iters=100 (CUDA graph replay)

 M/group   total_M     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)     ref(ms)    ref(TF)   hpc/sgl
-------------------------------------------------------------------------------------------------
       1       128      0.1229       3.28      0.1408       2.86      1.0639       0.38      1.15x
       2       256      0.1246       6.46      0.2095       3.84      1.0692       0.75      1.68x
       4       512      0.1255      12.84      0.2067       7.79      1.0766       1.50      1.65x
       8      1024      0.1275      25.26      0.2093      15.39      1.0747       3.00      1.64x
      16      2048      0.1304      49.40      0.2109      30.54      1.0967       5.87      1.62x
      32      4096      0.1442      89.34      0.2164      59.53      1.2433      10.36      1.50x
      64      8192      0.2058     125.25      0.2397     107.52      1.2878      20.01      1.16x
     128     16384      0.3979     129.52      0.4074     126.50      1.3215      39.00      1.02x
     256     32768      0.7723     133.47      0.7695     133.96      1.7923      57.51      1.00x
     512     65536      1.5234     135.32      1.4930     138.08      2.7824      74.09      0.98x
-------------------------------------------------------------------------------------------------

EP8 gate+up:

python tests/bench_group_gemm_bf16.py --E 16 --N 3072 --K 4096
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device : NVIDIA H20
Config : E=16, N=3072, K=4096
Timing : warmup=5, iters=100 (CUDA graph replay)

 M/group   total_M     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)     ref(ms)    ref(TF)   hpc/sgl
-------------------------------------------------------------------------------------------------
       1        16      0.1199       3.36      0.1407       2.86      0.2201       1.83      1.17x
       2        32      0.1203       6.69      0.2117       3.80      0.2222       3.62      1.76x
       4        64      0.1224      13.16      0.2097       7.68      0.2230       7.22      1.71x
       8       128      0.1233      26.11      0.2124      15.16      0.2253      14.30      1.72x
      16       256      0.1263      50.99      0.2106      30.59      0.2352      27.39      1.67x
      32       512      0.1359      94.82      0.2161      59.63      0.2619      49.20      1.59x
      64      1024      0.1995     129.18      0.2249     114.57      0.3295      78.21      1.13x
     128      2048      0.3858     133.58      0.3914     131.69      0.5546      92.93      1.01x
     256      4096      0.7645     134.83      0.7523     137.02      0.8352     123.42      0.98x
     512      8192      1.5100     136.53      1.4683     140.40      1.6535     124.68      0.97x
-------------------------------------------------------------------------------------------------

TP8 down:

python tests/bench_group_gemm_bf16.py --E 128 --N 4096 --K 192
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device : NVIDIA H20
Config : E=128, N=4096, K=192
Timing : warmup=5, iters=100 (CUDA graph replay)

 M/group   total_M     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)     ref(ms)    ref(TF)   hpc/sgl
-------------------------------------------------------------------------------------------------
       1       128      0.0684       2.94      0.0850       2.37      0.6110       0.33      1.24x
       2       256      0.0701       5.74      0.0978       4.12      0.6158       0.65      1.39x
       4       512      0.0717      11.23      0.0982       8.20      0.6268       1.28      1.37x
       8      1024      0.0731      22.05      0.1003      16.05      0.6322       2.55      1.37x
      16      2048      0.0753      42.76      0.1036      31.08      0.6405       5.03      1.38x
      32      4096      0.0920      70.00      0.1105      58.29      0.6807       9.46      1.20x
      64      8192      0.1314      98.03      0.1258     102.45      0.7466      17.26      0.96x
     128     16384      0.2342     110.05      0.2222     115.97      0.8785      29.33      0.95x
     256     32768      0.4343     118.66      0.4144     124.36      1.1870      43.42      0.95x
     512     65536      0.8436     122.19      0.8093     127.37      1.6995      60.65      0.96x
-------------------------------------------------------------------------------------------------

EP8 down:

python tests/bench_group_gemm_bf16.py --E 16 --N 4096 --K 1536
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

Device : NVIDIA H20
Config : E=16, N=4096, K=1536
Timing : warmup=5, iters=100 (CUDA graph replay)

 M/group   total_M     hpc(ms)    hpc(TF)     sgl(ms)    sgl(TF)     ref(ms)    ref(TF)   hpc/sgl
-------------------------------------------------------------------------------------------------
       1        16      0.0648       3.10      0.0701       2.87      0.1311       1.54      1.08x
       2        32      0.0669       6.02      0.1039       3.88      0.1327       3.03      1.55x
       4        64      0.0667      12.08      0.1039       7.75      0.1332       6.04      1.56x
       8       128      0.0677      23.80      0.1039      15.50      0.1349      11.93      1.54x
      16       256      0.0682      47.25      0.1038      31.03      0.1363      23.63      1.52x
      32       512      0.0746      86.41      0.1053      61.17      0.1497      43.04      1.41x
      64      1024      0.1103     116.84      0.1147     112.36      0.1857      69.38      1.04x
     128      2048      0.2098     122.83      0.1986     129.77      0.2839      90.77      0.95x
     256      4096      0.3959     130.18      0.3731     138.13      0.4848     106.30      0.94x
     512      8192      0.7692     134.01      0.7378     139.71      0.8854     116.42      0.96x
-------------------------------------------------------------------------------------------------

This kernel shows great speedup over sglang when M/group is small, and it is 1% ~ 5% slower when M/group >= 128.

I will implement bf16 fused moe in the next pr.

Signed-off-by: ZelinMa557 <3388706467@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant