73 changes: 51 additions & 22 deletions tests/pytorch/test_fused_router.py
@@ -47,7 +47,7 @@ def group_limited_topk(


# Pytorch-based topk softmax/sigmoid
def topk_softmax_sigmoid_pytorch(
def topk_score_function_pytorch(
logits: torch.Tensor,
topk: int,
use_pre_softmax: bool = False,
@@ -79,8 +79,11 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
else:
scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits.float()).type_as(logits)
elif score_function in ("sigmoid", "sqrtsoftplus"):
if score_function == "sigmoid":
scores = torch.sigmoid(logits.float()).type_as(logits)
else:
scores = torch.nn.functional.softplus(logits.float()).sqrt().type_as(logits)
if expert_bias is not None:
scores_for_routing = scores + expert_bias
_, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
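For readers skimming the diff, here is a minimal standalone sketch of what the new branch computes (plain PyTorch, assuming no group-limited top-k; `sqrtsoftplus_scores` is an illustrative helper, not the actual test code):

```python
import torch
import torch.nn.functional as F

def sqrtsoftplus_scores(logits: torch.Tensor, topk: int, expert_bias=None):
    # Score function: sqrt(softplus(x)) = sqrt(log(1 + exp(x))), always non-negative.
    scores = F.softplus(logits.float()).sqrt().type_as(logits)
    # An expert bias (if any) only affects which experts are selected,
    # not the score values that are returned.
    routing_scores = scores if expert_bias is None else scores + expert_bias
    _, top_indices = torch.topk(routing_scores, k=topk, dim=-1)
    return scores, top_indices

# Example: 2 tokens, 4 experts, top-2 routing.
scores, idx = sqrtsoftplus_scores(torch.randn(2, 4), topk=2)
```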
@@ -109,6 +112,9 @@ def compute_scores_for_aux_loss_pytorch(
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
elif score_function == "sqrtsoftplus":
scores = torch.nn.functional.softplus(logits.float()).sqrt().type_as(logits)
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")

@@ -146,8 +152,9 @@ def run_comparison(
enable_bias,
):
# Set some parameters
if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function
if score_function in ("sigmoid", "sqrtsoftplus"):
# Construct logits with a narrow range to avoid very small activation values,
# which would cause precision loss when adding/subtracting expert bias in float32.
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
logits = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
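A small numeric illustration of the precision issue the comment is guarding against (hypothetical values, not taken from the test):

```python
import torch

# With wide-range logits the activation can be tiny, e.g. ~1e-9, while the
# expert bias is of order 1. In float32 the sum then rounds to the bias alone
# (float32 resolution near 1.0 is ~1.2e-7), so the score no longer influences
# the routing ranking and the fused/reference comparisons can diverge on ties.
score = torch.tensor(1e-9, dtype=torch.float32)
bias = torch.tensor(1.0, dtype=torch.float32)
print(score + bias == bias)  # tensor(True): the score was lost
```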
@@ -165,8 +172,8 @@
)
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
if enable_bias and score_function == "sigmoid":
expert_bias = torch.arange(num_experts, device="cuda") * 0.1
if enable_bias and score_function in ("sigmoid", "sqrtsoftplus"):
expert_bias = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1
expert_bias = torch.flip(expert_bias, dims=[0])
expert_bias.requires_grad = True
else:
@@ -183,7 +190,7 @@

# Run the original implementation
# We do not support the capacity factor case
probs, routing_map = topk_softmax_sigmoid_pytorch(
probs, routing_map = topk_score_function_pytorch(
logits=logits,
topk=topk,
use_pre_softmax=use_pre_softmax,
@@ -252,6 +259,37 @@ def test_topk_sigmoid(
)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
def test_topk_sqrtsoftplus(
dtype,
num_tokens,
num_experts,
topk,
group_topk,
scaling_factor,
enable_bias,
):
num_groups = 8 if group_topk else None
run_comparison(
dtype=dtype,
num_tokens=num_tokens,
num_experts=num_experts,
topk=topk,
use_pre_softmax=False,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function="sqrtsoftplus",
enable_bias=enable_bias,
)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@@ -287,10 +325,10 @@ def test_topk_softmax(
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function
if score_function in ("sigmoid", "sqrtsoftplus"):
# Construct logits with a narrow range to avoid very small activation values
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
logits = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
@@ -396,15 +434,6 @@ def profile_topk_softmax(
test_topk_softmax(
torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor
)


if __name__ == "__main__":
test_topk_softmax(
dtype=torch.float32,
num_tokens=1024,
num_experts=128,
topk=4,
use_pre_softmax=False,
group_topk=None,
scaling_factor=None,
test_topk_sqrtsoftplus(
torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias
)
@@ -78,34 +78,40 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
* Possibly preprocess the scores before the topk operation
* - Pre-softmax
* - Sigmoid
* - Sigmoid post-processing when topk > 1
* - Sqrtsoftplus
* - Sigmoid/Sqrtsoftplus post-processing when topk > 1
* The scores are updated in place
*/
// score_function == 1 means softmax
if (score_function == 1) {
if (score_function == 1) { // score_function == 1 means softmax
// Apply softmax to the logits before the topk
apply_softmax_on_float(local_logits, num_experts, lane_id);
__syncwarp();
// Save the softmax output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = local_logits[i];
}
}

// score_function == 0 means sigmoid
if (score_function == 0) {
} else if (score_function == 0) { // score_function == 0 means sigmoid
// Apply sigmoid to the logits
apply_sigmoid_on_float(local_logits, num_experts, lane_id);
__syncwarp();
// Save the sigmoid output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = local_logits[i];
}
} else if (score_function == 2) { // score_function == 2 means sqrtsoftplus
// First save the original logits for backward (needed for gradient computation)
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = local_logits[i]; // Save original logits
}
__syncwarp();
// Apply sqrtsoftplus to the logits
apply_sqrtsoftplus_on_float(local_logits, num_experts, lane_id);
}

__syncwarp(); //Confirm the scores is written to the softmax/sigmoid output
__syncwarp();  // Confirm the scores are written to the output

if (score_function == 0) {
// Sigmoid/Sqrtsoftplus post-processing when topk > 1
if (score_function == 0 || score_function == 2) {
if (topk > 1) {
auto sum_logits =
warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id);
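As a cross-check, here is a per-token reference of the score_function == 2 forward path in plain PyTorch (a sketch; the 1e-20 epsilon mirrors the PyTorch test reference, and the kernel's epsilon may differ):

```python
import torch
import torch.nn.functional as F

def sqrtsoftplus_forward_reference(logits_row: torch.Tensor, topk: int, eps: float = 1e-20):
    # The kernel stores the raw logits as the intermediate output, since the
    # sqrtsoftplus activation is cheap to recompute in the backward pass.
    intermediate = logits_row.clone()
    scores = F.softplus(logits_row.float()).sqrt()
    # Post-processing shared with the sigmoid path: normalize when topk > 1.
    if topk > 1:
        scores = scores / (scores.sum() + eps)
    return scores, intermediate
```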
@@ -227,31 +233,53 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int
/***
* Section: Backward of ops before the topk
* - Pre-softmax bwd
* - Sigmoid Post-processing bwd when topk > 1
* - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1
* - Sigmoid bwd
* - Sqrtsoftplus bwd
* - Write the grad_logits to the global mem
*/
// Sigmoid Post-processing bwd when topk > 1
if (topk > 1 && score_function == 0) {
auto sum_fwd_input =
warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id);
// Put the result of output * grad to the comp_buf
// Sqrtsoftplus: First compute sqrtsoftplus output from original logits
// (needed for both post-processing bwd and activation bwd, compute once here)
// For sqrtsoftplus, intermediate_output stores original logits
if (score_function == 2) {
// Copy original logits to local_comp_buf and apply sqrtsoftplus in-place
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
local_comp_buf[i] = local_act_from_fwd[i];
}
__syncwarp();
auto sum_Output_x_Grad =
warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id);
apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id);
__syncwarp();
}

// Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward)
if (topk > 1 && (score_function == 0 || score_function == 2)) {
// Select the correct activation output buffer:
// - Sigmoid: local_act_from_fwd already contains sigmoid output
// - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above
DataType *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf;

auto sum_fwd_input =
warp_reduce_on_shmem(act_output, num_experts, ReduceFuncType::SUM, lane_id);
// Compute sum of output * grad using registers
double local_sum_Output_x_Grad = 0.0;
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_sum_Output_x_Grad +=
static_cast<double>(local_grad[i]) * static_cast<double>(act_output[i]);
}
// Warp reduce the sum
for (int s = 16; s > 0; s /= 2) {
local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s);
}
double sum_Output_x_Grad = local_sum_Output_x_Grad;
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] =
static_cast<double>(local_grad[i]) / (static_cast<double>(sum_fwd_input) + epsilon) -
static_cast<double>(sum_Output_x_Grad) /
((static_cast<double>(sum_fwd_input) + epsilon) *
(static_cast<double>(sum_fwd_input) + epsilon));
sum_Output_x_Grad / ((static_cast<double>(sum_fwd_input) + epsilon) *
(static_cast<double>(sum_fwd_input) + epsilon));
}
__syncwarp();
}
__syncwarp();
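For reference, the normalization backward implemented above follows from differentiating $y_i = a_i/(S+\varepsilon)$ with $S = \sum_j a_j$, where $a$ is the sigmoid or sqrtsoftplus output and $g$ the incoming gradient:

$$\frac{\partial y_j}{\partial a_i} = \frac{\delta_{ij}}{S+\varepsilon} - \frac{a_j}{(S+\varepsilon)^2}
\quad\Longrightarrow\quad
\frac{\partial L}{\partial a_i} = \frac{g_i}{S+\varepsilon} - \frac{\sum_j g_j\, a_j}{(S+\varepsilon)^2},$$

with the numerator of the second term being the `sum_Output_x_Grad` accumulated in registers above.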

// Pre-softmax bwd
if (score_function == 1) {
@@ -264,6 +292,14 @@
apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id);
__syncwarp();
}
// Sqrtsoftplus bwd
// For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier
// Now compute gradient: dy/dx = sigmoid(x) / (2 * y)
if (score_function == 2) {
apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts,
lane_id);
__syncwarp();
}
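The analytic rule used here, $dy/dx = \sigma(x) / (2y)$ for $y = \sqrt{\operatorname{softplus}(x)}$, can be sanity-checked against autograd in a few lines (a standalone sketch, independent of the kernel):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, dtype=torch.float64, requires_grad=True)
y = F.softplus(x).sqrt()
y.sum().backward()  # autograd gradient of y w.r.t. x, element-wise

# d sqrt(softplus(x)) / dx = softplus'(x) / (2 * sqrt(softplus(x))) = sigmoid(x) / (2 * y)
analytic = torch.sigmoid(x) / (2 * y)
print(torch.allclose(x.grad, analytic))  # True
```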
// Write the grad_logits to the global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = local_grad[i];