From 2c80a1803195b88c6af843bccbce6d756cc74789 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 29 Jan 2026 10:28:55 +0000 Subject: [PATCH 1/6] add sqrtsoftplus Signed-off-by: Xin Yao --- .../fused_score_for_moe_aux_loss.cu | 75 +++++++++++-- .../fused_topk_with_score_function.cu | 101 +++++++++++++++--- .../common/fused_router/utils.h | 39 +++++++ .../include/transformer_engine/fused_router.h | 8 +- .../pytorch/csrc/extensions/router.cpp | 15 +-- transformer_engine/pytorch/router.py | 6 +- 6 files changed, 207 insertions(+), 37 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 7540b5c41d..fee9822ce9 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -78,11 +78,11 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi * Possible preprocess the scores before the topk operation * - Pre-softmax * - Sigmoid - * - Sigmoid post-processing when topk > 1 + * - Sqrtsoftplus + * - Sigmoid/Sqrtsoftplus post-processing when topk > 1 * This is in-place scores update */ - // score_function == 1 means softmax - if (score_function == 1) { + if (score_function == 1) { // score_function == 1 means softmax // Apply softmax to the logits before the topk apply_softmax_on_float(local_logits, num_experts, lane_id); __syncwarp(); @@ -90,10 +90,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } - } - - // score_function == 0 means sigmoid - if (score_function == 0) { + } else if (score_function == 0) { // score_function == 0 means sigmoid // Apply sigmoid to the logits apply_sigmoid_on_float(local_logits, num_experts, lane_id); __syncwarp(); @@ -101,11 +98,20 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } + } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus + // First save the original logits for backward (needed for gradient computation) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; // Save original logits + } + __syncwarp(); + // Apply sqrtsoftplus to the logits + apply_sqrtsoftplus_on_float(local_logits, num_experts, lane_id); } - __syncwarp(); //Confirm the scores is written to the softmax/sigmoid output + __syncwarp(); //Confirm the scores is written to the output - if (score_function == 0) { + // Sigmoid/Sqrtsoftplus post-processing when topk > 1 + if (score_function == 0 || score_function == 2) { if (topk > 1) { auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); @@ -227,8 +233,9 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int /*** * Section: Backward of ops before the topk * - Pre-softmax bwd - * - Sigmoid Post-processing bwd when topk > 1 + * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 * - Sigmoid bwd + * - Sqrtsoftplus bwd * - Write the grad_logits to the global mem */ // Sigmoid Post-processing bwd when topk > 1 @@ -250,8 +257,46 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int ((static_cast(sum_fwd_input) + 
epsilon) * (static_cast(sum_fwd_input) + epsilon)); } + __syncwarp(); + } + + // Sqrtsoftplus: First compute sqrtsoftplus output from original logits + // (needed for both post-processing bwd and activation bwd, compute once here) + // For sqrtsoftplus, intermediate_output stores original logits + if (score_function == 2) { + // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_comp_buf[i] = local_act_from_fwd[i]; + } + __syncwarp(); + apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); + __syncwarp(); + } + + // Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if (topk > 1 && score_function == 2) { + auto sum_fwd_input = + warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id); + // Compute sum of output * grad using registers instead of shared memory + double local_sum_Output_x_Grad = 0.0; + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_sum_Output_x_Grad += + static_cast(local_grad[i]) * static_cast(local_comp_buf[i]); + } + // Warp reduce the sum + for (int s = 16; s > 0; s /= 2) { + local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); + } + double sum_Output_x_Grad = local_sum_Output_x_Grad; + // In-place update + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_grad[i] = + static_cast(local_grad[i]) / (static_cast(sum_fwd_input) + epsilon) - + sum_Output_x_Grad / ((static_cast(sum_fwd_input) + epsilon) * + (static_cast(sum_fwd_input) + epsilon)); + } + __syncwarp(); } - __syncwarp(); // Pre-softmax bwd if (score_function == 1) { @@ -264,6 +309,14 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); __syncwarp(); } + // Sqrtsoftplus bwd + // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier + // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) + if (score_function == 2) { + apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, + lane_id); + __syncwarp(); + } // Write the grad_logits to the global mem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { grad_logits[pos_offset + i] = local_grad[i]; diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 2719c68c97..3237658ccd 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -96,11 +96,11 @@ __global__ void fused_topk_with_score_function_forward_kernel( * Possible preprocess the scores before the topk operation * - Pre-softmax * - Sigmoid + * - Sqrtsoftplus * - Expert bias * This is in-place scores update */ - // score_function == 1 means softmax - if (use_pre_softmax && score_function == 1) { + if (use_pre_softmax && score_function == 1) { // score_function == 1 means softmax // Apply softmax to the logits before the topk apply_softmax_on_float(scores, num_experts, lane_id); __syncwarp(); @@ -108,10 +108,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = scores[i]; } - } - - // score_function == 0 means sigmoid - if (score_function == 0) { + } else if (score_function 
== 0) { // score_function == 0 means sigmoid // Apply sigmoid to the logits apply_sigmoid_on_float(scores, num_experts, lane_id); __syncwarp(); @@ -119,12 +116,20 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = scores[i]; } + } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus + // First save the original logits for backward (needed for sigmoid computation) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; // Save original logits + } + __syncwarp(); + // Apply sqrtsoftplus to the logits + apply_sqrtsoftplus_on_float(scores, num_experts, lane_id); } - __syncwarp(); //Confirm the scores is written to the softmax/sigmoid output + __syncwarp(); //Confirm the scores is written to the output - // Expert bias is only used at the sigmoid case - if (expert_bias && score_function == 0) { + // Expert bias is only used at the sigmoid/sqrtsoftplus case + if (expert_bias && (score_function == 0 || score_function == 2)) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { scores[i] = static_cast(static_cast(scores[i]) + static_cast(expert_bias[i])); @@ -203,8 +208,8 @@ __global__ void fused_topk_with_score_function_forward_kernel( topk_scores[i] = static_cast(topk_scores[i]) - static_cast(expert_bias[topk_indices[i]]); } + __syncwarp(); } - __syncwarp(); // score_function == 1 means softmax if (!use_pre_softmax && score_function == 1) { @@ -215,10 +220,22 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < topk; i += kThreadsPerWarp) { intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; } + __syncwarp(); } - // score_function == 0 means sigmoid - if (score_function == 0) { + // Sigmoid/Sqrtsoftplus post-processing when topk > 1 + if (score_function == 0 || score_function == 2) { + if (topk > 1) { + double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = static_cast(topk_scores[i]) / (sum_scores + epsilon); + } + } + __syncwarp(); + } + + // score_function == 2 means sqrtsoftplus + if (score_function == 2) { if (topk > 1) { double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { @@ -357,6 +374,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( } } __syncwarp(); + // Sigmoid Post-processing bwd when topk > 1 if (topk > 1 && score_function == 0) { double sum_fwd_input = masked_warp_reduce_on_shmem( @@ -386,8 +404,56 @@ __global__ void fused_topk_with_score_function_backward_kernel( local_grad[i] = 0.0f; } } + __syncwarp(); } - __syncwarp(); + + // Sqrtsoftplus: First compute sqrtsoftplus output from original logits + // (needed for both post-processing bwd and activation bwd, compute once here) + // For sqrtsoftplus, intermediate_output stores original logits + if (score_function == 2) { + // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_comp_buf[i] = local_act_from_fwd[i]; + } + __syncwarp(); + apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); + __syncwarp(); + } + + // Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if (topk > 1 && score_function == 2) { + // Now do the 
normalization backward (same as sigmoid) + double sum_fwd_input = masked_warp_reduce_on_shmem( + /*data ptr = */ local_comp_buf, + /*mask ptr = */ local_routing_map, + /*data size = */ num_experts, + /*reduce func = */ ReduceFuncType::SUM, lane_id); + // Compute sum of output * grad using registers instead of shared memory + double local_sum_Output_x_Grad = 0.0; + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + if (local_routing_map[i]) { + local_sum_Output_x_Grad += + static_cast(local_grad[i]) * static_cast(local_comp_buf[i]); + } + } + // Warp reduce the sum + for (int s = 16; s > 0; s /= 2) { + local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); + } + double sum_Output_x_Grad = local_sum_Output_x_Grad; + // In-place update + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + if (local_routing_map[i]) { + local_grad[i] = + static_cast(local_grad[i]) / (sum_fwd_input + epsilon) - + sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); + } else { + local_grad[i] = 0.0f; + } + } + __syncwarp(); + } + // Softmax bwd if use_pre_softmax is false if (!use_pre_softmax && score_function == 1) { apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, local_routing_map, @@ -410,6 +476,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( * Section: Backward of ops before the topk * - Pre-softmax bwd * - Sigmoid bwd + * - Sqrtsoftplus bwd * - Write the grad_logits to the global mem */ // Pre-softmax bwd @@ -423,6 +490,14 @@ __global__ void fused_topk_with_score_function_backward_kernel( apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); __syncwarp(); } + // Sqrtsoftplus bwd + // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier + // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) + if (score_function == 2) { + apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, + lane_id); + __syncwarp(); + } // Write the grad_logits to the global mem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { grad_logits[pos_offset + i] = local_grad[i]; diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 4ae0b467b5..b752b4974a 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -112,6 +112,45 @@ __device__ inline void apply_sigmoid_bwd_on_float(DataType *grad, DataType *fwd_ } } +// sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x))) +// We store the sqrtsoftplus output (y) in intermediate_output for backward +template +__device__ inline void apply_sqrtsoftplus_on_float(DataType *scores, int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + float x = static_cast(scores[i]); + // softplus(x) = log(1 + exp(x)), numerically stable version + float softplus_val; + if (x > 20.0f) { + softplus_val = x; // for large x, softplus(x) ≈ x + } else if (x < -20.0f) { + softplus_val = expf(x); // for small x, softplus(x) ≈ exp(x) + } else { + softplus_val = log1pf(expf(x)); + } + scores[i] = static_cast(sqrtf(softplus_val)); + } +} + +// sqrtsoftplus backward: +// y = sqrt(softplus(x)) +// dy/dx = sigmoid(x) / (2 * y) +// where sigmoid(x) = 1 / (1 + exp(-x)) +// We need the original logits (x) to compute sigmoid, which we store in a separate buffer +template +__device__ inline void apply_sqrtsoftplus_bwd_on_float(DataType 
*grad, DataType *fwd_output, + DataType *logits_buf, int data_size, + int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + float x = static_cast(logits_buf[i]); // original logit + float y = static_cast(fwd_output[i]); // sqrtsoftplus output + // sigmoid(x) = 1 / (1 + exp(-x)) + float sigmoid_x = 1.0f / (1.0f + expf(-x)); + // dy/dx = sigmoid(x) / (2 * y) + float dy_dx = sigmoid_x / (2.0f * y + epsilon); + grad[i] = static_cast(static_cast(grad[i]) * dy_dx); + } +} + template __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_output, DataType *comp_buf, bool *mask, int data_size, diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h index 1f026a703d..fcf51b934a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_router.h +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -23,7 +23,7 @@ extern "C" { * \param[in] num_groups Number of groups in grouped topk. * \param[in] group_topk Grouped topk value. * \param[in] scaling_factor Scaling factor. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[in] expert_bias Expert bias. (Only used at the sigmoid case) * \param[out] probs Output tensor for probabilities. * \param[out] routing_map Output tensor for routing map. @@ -46,7 +46,7 @@ void nvte_fused_topk_with_score_function_forward( * \param[in] topk Topk value. * \param[in] use_pre_softmax Whether to use softmax before topk. * \param[in] scaling_factor Scaling factor. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] grad_logits Gradient of logits. * \param[in] stream CUDA stream used for the operation. */ @@ -63,7 +63,7 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, * \param[in] num_tokens Number of tokens. * \param[in] num_experts Number of experts. * \param[in] topk Topk value. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] scores Output tensor for scores. * \param[in] routing_map Routing map. * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) @@ -82,7 +82,7 @@ void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_ * \param[in] num_tokens Number of tokens. * \param[in] num_experts Number of experts. * \param[in] topk Topk value. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] grad_logits Gradient of logits. * \param[in] stream CUDA stream used for the operation. 
*/ diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 2ae0d648a1..c4d1503a37 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -9,7 +9,8 @@ namespace transformer_engine::pytorch { -static std::map score_function_map = {{"sigmoid", 0}, {"softmax", 1}}; +static std::map score_function_map = { + {"sigmoid", 0}, {"softmax", 1}, {"sqrtsoftplus", 2}}; std::tuple fused_topk_with_score_function_fwd( at::Tensor logits, int topk, bool use_pre_softmax, c10::optional num_groups, @@ -26,9 +27,10 @@ std::tuple fused_topk_with_score_function_fw "score_function must be sigmoid when expert_bias is not None"); } // Check if the score function is valid - TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid", - "score_function must be softmax or sigmoid for router fusion"); - if (score_function == "sigmoid") { + TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || + score_function == "sqrtsoftplus", + "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion"); + if (score_function == "sigmoid" || score_function == "sqrtsoftplus") { use_pre_softmax = false; // Pre-softmax only happens at the softmax case } @@ -99,8 +101,9 @@ std::tuple fused_score_for_moe_aux_loss_fwd( "num_tokens and num_experts must be greater than 0"); TORCH_CHECK(topk > 0, "topk must be greater than 0"); // Check if the score function is valid - TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid", - "score_function must be softmax or sigmoid for router fusion"); + TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || + score_function == "sqrtsoftplus", + "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion"); int score_function_value = score_function_map[score_function]; // Construct the output tensor diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index 52d1d9d6ca..88f665ca14 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -11,7 +11,7 @@ class FusedTopkScoreFunction(torch.autograd.Function): """ Fused Topk with Score Function router. - Currently, only support softmax and sigmoid. + Currently, support softmax, sigmoid and sqrtsoftplus. 
""" @staticmethod @@ -102,7 +102,7 @@ def fused_topk_with_score_function( used in the group topk scaling_factor : float score_function : str - currently only support softmax and sigmoid + currently support softmax, sigmoid and sqrtsoftplus expert_bias : torch.Tensor could be used in the sigmoid @@ -189,7 +189,7 @@ def fused_compute_score_for_moe_aux_loss( logits : torch.Tensor topk : int score_function : str - currently only support softmax and sigmoid + currently support softmax, sigmoid and sqrtsoftplus Returns ------- From 11bb63f017379c786186f8f599fbcde9b1689a4e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:31:04 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/fused_router/utils.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index b752b4974a..9c79efbcf7 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -138,8 +138,8 @@ __device__ inline void apply_sqrtsoftplus_on_float(DataType *scores, int data_si // We need the original logits (x) to compute sigmoid, which we store in a separate buffer template __device__ inline void apply_sqrtsoftplus_bwd_on_float(DataType *grad, DataType *fwd_output, - DataType *logits_buf, int data_size, - int lane_id) { + DataType *logits_buf, int data_size, + int lane_id) { for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { float x = static_cast(logits_buf[i]); // original logit float y = static_cast(fwd_output[i]); // sqrtsoftplus output From a6f32fffbc4cfd56a6517ce0acb114a76390746b Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 30 Jan 2026 09:51:51 +0000 Subject: [PATCH 3/6] update and add tests Signed-off-by: Xin Yao --- tests/pytorch/test_fused_router.py | 64 ++++++++++++----- .../fused_score_for_moe_aux_loss.cu | 37 +++------- .../fused_topk_with_score_function.cu | 69 ++++--------------- .../common/fused_router/utils.h | 22 +++--- .../pytorch/csrc/extensions/router.cpp | 2 +- 5 files changed, 85 insertions(+), 109 deletions(-) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index f559362d82..303f0ee372 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -47,7 +47,7 @@ def group_limited_topk( # Pytorch-based topk softmax/sigmoid -def topk_softmax_sigmoid_pytorch( +def topk_score_function_pytorch( logits: torch.Tensor, topk: int, use_pre_softmax: bool = False, @@ -79,8 +79,11 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): else: scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) - elif score_function == "sigmoid": - scores = torch.sigmoid(logits.float()).type_as(logits) + elif score_function in ("sigmoid", "sqrtsoftplus"): + if score_function == "sigmoid": + scores = torch.sigmoid(logits.float()).type_as(logits) + else: + scores = torch.nn.functional.softplus(logits.float()).sqrt().type_as(logits) if expert_bias is not None: scores_for_routing = scores + expert_bias _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) @@ -109,6 +112,9 @@ def compute_scores_for_aux_loss_pytorch( elif score_function == "sigmoid": scores = torch.sigmoid(logits) scores = 
scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + elif score_function == "sqrtsoftplus": + scores = torch.nn.functional.softplus(logits.float()).sqrt().type_as(logits) + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores else: raise ValueError(f"Invalid score_function: {score_function}") @@ -165,8 +171,8 @@ def run_comparison( ) logits = logits.view(num_tokens, num_experts) logits.requires_grad = True - if enable_bias and score_function == "sigmoid": - expert_bias = torch.arange(num_experts, device="cuda") * 0.1 + if enable_bias and score_function in ("sigmoid", "sqrtsoftplus"): + expert_bias = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1 expert_bias = torch.flip(expert_bias, dims=[0]) expert_bias.requires_grad = True else: @@ -183,7 +189,7 @@ def run_comparison( # Run the original implementation # We do not support the capacity factor case - probs, routing_map = topk_softmax_sigmoid_pytorch( + probs, routing_map = topk_score_function_pytorch( logits=logits, topk=topk, use_pre_softmax=use_pre_softmax, @@ -252,6 +258,37 @@ def test_topk_sigmoid( ) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992]) +@pytest.mark.parametrize("num_experts", [128, 32]) +@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("group_topk", [None, 4]) +@pytest.mark.parametrize("scaling_factor", [None, 1.2]) +@pytest.mark.parametrize("enable_bias", [True, False]) +def test_topk_sqrtsoftplus( + dtype, + num_tokens, + num_experts, + topk, + group_topk, + scaling_factor, + enable_bias, +): + num_groups = 8 if group_topk else None + run_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=False, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="sqrtsoftplus", + enable_bias=enable_bias, + ) + + @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [128, 32]) @@ -287,7 +324,7 @@ def test_topk_softmax( @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [256, 128, 32]) @pytest.mark.parametrize("topk", [4, 8]) -@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): if score_function == "sigmoid": # Construct the special logits to avoid inf in the sigmoid function @@ -396,15 +433,6 @@ def profile_topk_softmax( test_topk_softmax( torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor ) - - -if __name__ == "__main__": - test_topk_softmax( - dtype=torch.float32, - num_tokens=1024, - num_experts=128, - topk=4, - use_pre_softmax=False, - group_topk=None, - scaling_factor=None, + test_topk_sqrtsoftplus( + torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias ) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index fee9822ce9..c90152759e 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -238,28 +238,6 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const 
DataType *int * - Sqrtsoftplus bwd * - Write the grad_logits to the global mem */ - // Sigmoid Post-processing bwd when topk > 1 - if (topk > 1 && score_function == 0) { - auto sum_fwd_input = - warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id); - // Put the result of output * grad to the comp_buf - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i]; - } - __syncwarp(); - auto sum_Output_x_Grad = - warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id); - // In-place update - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = - static_cast(local_grad[i]) / (static_cast(sum_fwd_input) + epsilon) - - static_cast(sum_Output_x_Grad) / - ((static_cast(sum_fwd_input) + epsilon) * - (static_cast(sum_fwd_input) + epsilon)); - } - __syncwarp(); - } - // Sqrtsoftplus: First compute sqrtsoftplus output from original logits // (needed for both post-processing bwd and activation bwd, compute once here) // For sqrtsoftplus, intermediate_output stores original logits @@ -273,15 +251,20 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int __syncwarp(); } - // Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) - if (topk > 1 && score_function == 2) { + // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if (topk > 1 && (score_function == 0 || score_function == 2)) { + // Select the correct activation output buffer: + // - Sigmoid: local_act_from_fwd already contains sigmoid output + // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above + DataType *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf; + auto sum_fwd_input = - warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id); - // Compute sum of output * grad using registers instead of shared memory + warp_reduce_on_shmem(act_output, num_experts, ReduceFuncType::SUM, lane_id); + // Compute sum of output * grad using registers double local_sum_Output_x_Grad = 0.0; for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_sum_Output_x_Grad += - static_cast(local_grad[i]) * static_cast(local_comp_buf[i]); + static_cast(local_grad[i]) * static_cast(act_output[i]); } // Warp reduce the sum for (int s = 16; s > 0; s /= 2) { diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 3237658ccd..77ce4efd6f 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -134,8 +134,8 @@ __global__ void fused_topk_with_score_function_forward_kernel( scores[i] = static_cast(static_cast(scores[i]) + static_cast(expert_bias[i])); } + __syncwarp(); } - __syncwarp(); /*** * Section: Topk @@ -145,7 +145,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( * - topk with expert bias */ // Topk on the scores - // The bias is not empty only happens at the sigmod case + // The bias being not empty happens at the sigmoid/sqrtsoftplus case if (group_topk > 0) { int group_size = num_experts / num_groups; // Top2 @@ -199,11 +199,11 @@ __global__ void fused_topk_with_score_function_forward_kernel( * Possible postprocess the scores after the topk operation * - Revert Expert bias * - Softmax - * - Sigmoid post-processing when topk > 1 + 
* - Sigmoid/Sqrtsoftplus post-processing when topk > 1 * - Write the result with scaling_factor */ // Revert Expert bias from the topk scores - if (expert_bias && score_function == 0) { + if (expert_bias && (score_function == 0 || score_function == 2)) { for (int i = lane_id; i < topk; i += kThreadsPerWarp) { topk_scores[i] = static_cast(topk_scores[i]) - static_cast(expert_bias[topk_indices[i]]); @@ -234,17 +234,6 @@ __global__ void fused_topk_with_score_function_forward_kernel( __syncwarp(); } - // score_function == 2 means sqrtsoftplus - if (score_function == 2) { - if (topk > 1) { - double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - topk_scores[i] = static_cast(topk_scores[i]) / (sum_scores + epsilon); - } - } - __syncwarp(); - } - // Write the probs/routing_map to the output tensor for (int i = lane_id; i < topk; i += kThreadsPerWarp) { routing_map[pos_offset + topk_indices[i]] = true; @@ -363,7 +352,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( /*** * Section: Backward of ops after the topk * - Backward of the used scaling_factor - * - Sigmoid Post-processing bwd when topk > 1 + * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 * - Softmax bwd if use_pre_softmax is false */ // Backward of the used scaling_factor @@ -375,38 +364,6 @@ __global__ void fused_topk_with_score_function_backward_kernel( } __syncwarp(); - // Sigmoid Post-processing bwd when topk > 1 - if (topk > 1 && score_function == 0) { - double sum_fwd_input = masked_warp_reduce_on_shmem( - /*data ptr = */ local_act_from_fwd, - /*mask ptr = */ local_routing_map, - /*data size = */ num_experts, - /*reduce func = */ ReduceFuncType::SUM, lane_id); - // Put the result of output * grad to the comp_buf - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = (local_routing_map[i] ? static_cast(local_grad[i]) * - static_cast(local_act_from_fwd[i]) - : 0.0f); - } - __syncwarp(); - double sum_Output_x_Grad = masked_warp_reduce_on_shmem( - /*data ptr = */ local_comp_buf, - /*mask ptr = */ local_routing_map, - /*data size = */ num_experts, - /*reduce func = */ ReduceFuncType::SUM, lane_id); - // In-place update - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (local_routing_map[i]) { - local_grad[i] = - static_cast(local_grad[i]) / (sum_fwd_input + epsilon) - - sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); - } else { - local_grad[i] = 0.0f; - } - } - __syncwarp(); - } - // Sqrtsoftplus: First compute sqrtsoftplus output from original logits // (needed for both post-processing bwd and activation bwd, compute once here) // For sqrtsoftplus, intermediate_output stores original logits @@ -420,20 +377,24 @@ __global__ void fused_topk_with_score_function_backward_kernel( __syncwarp(); } - // Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) - if (topk > 1 && score_function == 2) { - // Now do the normalization backward (same as sigmoid) + // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if (topk > 1 && (score_function == 0 || score_function == 2)) { + // Select the correct activation output buffer: + // - Sigmoid: local_act_from_fwd already contains sigmoid output + // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above + DataType *act_output = (score_function == 0) ? 
local_act_from_fwd : local_comp_buf; + double sum_fwd_input = masked_warp_reduce_on_shmem( - /*data ptr = */ local_comp_buf, + /*data ptr = */ act_output, /*mask ptr = */ local_routing_map, /*data size = */ num_experts, /*reduce func = */ ReduceFuncType::SUM, lane_id); - // Compute sum of output * grad using registers instead of shared memory + // Compute sum of output * grad using registers double local_sum_Output_x_Grad = 0.0; for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (local_routing_map[i]) { local_sum_Output_x_Grad += - static_cast(local_grad[i]) * static_cast(local_comp_buf[i]); + static_cast(local_grad[i]) * static_cast(act_output[i]); } } // Warp reduce the sum diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 9c79efbcf7..db38d05899 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -119,11 +119,10 @@ __device__ inline void apply_sqrtsoftplus_on_float(DataType *scores, int data_si for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { float x = static_cast(scores[i]); // softplus(x) = log(1 + exp(x)), numerically stable version + // Matches PyTorch's Softplus(beta=1.0, threshold=20.0) float softplus_val; if (x > 20.0f) { softplus_val = x; // for large x, softplus(x) ≈ x - } else if (x < -20.0f) { - softplus_val = expf(x); // for small x, softplus(x) ≈ exp(x) } else { softplus_val = log1pf(expf(x)); } @@ -133,9 +132,8 @@ __device__ inline void apply_sqrtsoftplus_on_float(DataType *scores, int data_si // sqrtsoftplus backward: // y = sqrt(softplus(x)) -// dy/dx = sigmoid(x) / (2 * y) -// where sigmoid(x) = 1 / (1 + exp(-x)) -// We need the original logits (x) to compute sigmoid, which we store in a separate buffer +// Matches PyTorch's Softplus(beta=1.0, threshold=20.0) +// We need the original logits (x) to compute the gradient template __device__ inline void apply_sqrtsoftplus_bwd_on_float(DataType *grad, DataType *fwd_output, DataType *logits_buf, int data_size, @@ -143,10 +141,16 @@ __device__ inline void apply_sqrtsoftplus_bwd_on_float(DataType *grad, DataType for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { float x = static_cast(logits_buf[i]); // original logit float y = static_cast(fwd_output[i]); // sqrtsoftplus output - // sigmoid(x) = 1 / (1 + exp(-x)) - float sigmoid_x = 1.0f / (1.0f + expf(-x)); - // dy/dx = sigmoid(x) / (2 * y) - float dy_dx = sigmoid_x / (2.0f * y + epsilon); + float dy_dx; + if (x > 20.0f) { + // When softplus(x) = x, y = sqrt(x), dy/dx = 1/(2*y) + dy_dx = 1.0f / (2.0f * y + epsilon); + } else { + // When softplus(x) = log(1+exp(x)), dy/dx = sigmoid(x) / (2*y) + // where sigmoid(x) = 1 / (1 + exp(-x)) + float sigmoid_x = 1.0f / (1.0f + expf(-x)); + dy_dx = sigmoid_x / (2.0f * y + epsilon); + } grad[i] = static_cast(static_cast(grad[i]) * dy_dx); } } diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index c4d1503a37..84534a89ba 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -23,7 +23,7 @@ std::tuple fused_topk_with_score_function_fw "num_tokens and num_experts must be greater than 0"); // Expert bias only happens at the sigmoid case if (expert_bias.has_value()) { - TORCH_CHECK(score_function == "sigmoid", + TORCH_CHECK(score_function == "sigmoid" || score_function == "sqrtsoftplus", "score_function must be sigmoid when expert_bias 
is not None"); } // Check if the score function is valid From b67ba4b48b9f385489dd38125530b08d894ad265 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 6 Feb 2026 06:33:50 +0000 Subject: [PATCH 4/6] update and add tests Signed-off-by: Xin Yao --- tests/pytorch/test_fused_router.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 303f0ee372..5ce6fcc76f 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -152,8 +152,9 @@ def run_comparison( enable_bias, ): # Set some parameters - if score_function == "sigmoid": - # Construct the special logits to avoid inf in the sigmoid function + if score_function in ("sigmoid", "sqrtsoftplus"): + # Construct logits with a narrow range to avoid very small activation values, + # which would cause precision loss when adding/subtracting expert bias in float32. offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 logits = ( torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 @@ -326,8 +327,8 @@ def test_topk_softmax( @pytest.mark.parametrize("topk", [4, 8]) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): - if score_function == "sigmoid": - # Construct the special logits to avoid inf in the sigmoid function + if score_function in ("sigmoid", "sqrtsoftplus"): + # Construct logits with a narrow range to avoid very small activation values offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 logits = ( torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 From ef9e3ce57b0edf8d88d4703affcdb462ae665ac0 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 6 Feb 2026 06:55:17 +0000 Subject: [PATCH 5/6] update docstring Signed-off-by: Xin Yao --- .../common/include/transformer_engine/fused_router.h | 2 +- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 9 +++++---- transformer_engine/pytorch/csrc/extensions/router.cpp | 2 +- transformer_engine/pytorch/router.py | 4 ++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h index fcf51b934a..794880d324 100644 --- a/transformer_engine/common/include/transformer_engine/fused_router.h +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -24,7 +24,7 @@ extern "C" { * \param[in] group_topk Grouped topk value. * \param[in] scaling_factor Scaling factor. * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. - * \param[in] expert_bias Expert bias. (Only used at the sigmoid case) + * \param[in] expert_bias Expert bias. (Used at the sigmoid/sqrtsoftplus cases) * \param[out] probs Output tensor for probabilities. * \param[out] routing_map Output tensor for routing map. * \param[out] intermediate_output Output tensor for intermediate output. 
(Softmax/sigmoid output) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 79dd9ea5ce..f00aff6ee2 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -325,19 +325,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"), py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"), py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"), - "Fused topk softmax fwd"); + "Fused topk with score function fwd"); m.def("fused_topk_with_score_function_bwd", &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"), py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"), py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"), - py::arg("scaling_factor"), py::arg("score_function"), "Fused topk softmax bwd"); + py::arg("scaling_factor"), py::arg("score_function"), + "Fused topk with score function bwd"); m.def("fused_score_for_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"), - py::arg("topk"), py::arg("score_function"), "Fused topk softmax fwd"); + py::arg("topk"), py::arg("score_function"), "Fused aux loss with score function fwd"); m.def("fused_score_for_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"), py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"), - py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd"); + py::arg("topk"), py::arg("score_function"), "Fused aux loss with score function bwd"); m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd, py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"), py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 84534a89ba..93921d2b2b 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -24,7 +24,7 @@ std::tuple fused_topk_with_score_function_fw // Expert bias only happens at the sigmoid case if (expert_bias.has_value()) { TORCH_CHECK(score_function == "sigmoid" || score_function == "sqrtsoftplus", - "score_function must be sigmoid when expert_bias is not None"); + "score_function must be sigmoid or sqrtsoftplus when expert_bias is not None"); } // Check if the score function is valid TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index 88f665ca14..eb53dd1b95 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -102,9 +102,9 @@ def fused_topk_with_score_function( used in the group topk scaling_factor : float score_function : str - currently support softmax, sigmoid and sqrtsoftplus + currently support softmax, sigmoid and sqrtsoftplus. expert_bias : torch.Tensor - could be used in the sigmoid + could be used with the sigmoid/sqrtsoftplus score functions. 
Returns ------- From b9654c1962054c43ace661f098ce292a332ffd54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Feb 2026 06:57:58 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index f00aff6ee2..c851e422af 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -330,8 +330,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"), py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"), py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"), - py::arg("scaling_factor"), py::arg("score_function"), - "Fused topk with score function bwd"); + py::arg("scaling_factor"), py::arg("score_function"), "Fused topk with score function bwd"); m.def("fused_score_for_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"), py::arg("topk"), py::arg("score_function"), "Fused aux loss with score function fwd");
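
For reference, the math that the new sqrtsoftplus path implements can be summarized with a minimal PyTorch sketch. This is an illustration only, not code from the patch: the helper names (`sqrtsoftplus_ref`, `sqrtsoftplus_bwd_ref`, `normalize_scores_ref`) are hypothetical, and the `eps` value simply mirrors the `1e-20` used in the PyTorch reference tests above. The forward is y = sqrt(softplus(x)) with PyTorch's default Softplus threshold of 20, the backward is dy/dx = sigmoid(x) / (2 * y), and the topk > 1 normalization is shared with the sigmoid score function.

```python
import torch
import torch.nn.functional as F

def sqrtsoftplus_ref(logits: torch.Tensor) -> torch.Tensor:
    # Forward: y = sqrt(softplus(x)), matching Softplus(beta=1.0, threshold=20.0),
    # i.e. softplus(x) = x for x > 20, log(1 + exp(x)) otherwise.
    return F.softplus(logits, beta=1.0, threshold=20.0).sqrt()

def sqrtsoftplus_bwd_ref(grad_y: torch.Tensor, logits: torch.Tensor,
                         y: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    # Backward: dy/dx = softplus'(x) / (2 * y) = sigmoid(x) / (2 * y).
    # For x > 20, sigmoid(x) is effectively 1, so this reduces to 1 / (2 * y),
    # which is the branch the CUDA helper takes for large x.
    return grad_y * torch.sigmoid(logits) / (2.0 * y + eps)

def normalize_scores_ref(y: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    # topk > 1 post-processing shared by sigmoid and sqrtsoftplus: s = y / (sum(y) + eps).
    return y / (y.sum(dim=-1, keepdim=True) + eps)

# Quick check of the manual gradient against autograd (illustrative usage only).
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
y = sqrtsoftplus_ref(x)
g = torch.randn_like(y)
(auto_grad,) = torch.autograd.grad(y, x, grad_outputs=g)
manual_grad = sqrtsoftplus_bwd_ref(g, x.detach(), y.detach())
assert torch.allclose(auto_grad, manual_grad, atol=1e-8)
```

Because dy/dx depends on sigmoid(x) rather than on y alone, the kernels save the original logits (not the activation output) in `intermediate_output` for the sqrtsoftplus path, recompute the sqrtsoftplus output into `local_comp_buf` during backward, and then reuse the shared normalization backward with the activation buffer selected per score function.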