[Common][PyTorch] Add a new score func sqrtsoftplus to the fused router #2633
yaox12 wants to merge 7 commits into NVIDIA:main
Signed-off-by: Xin Yao <xiny@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR adds a new score function, sqrtsoftplus, to the fused router.
Key changes:
Implementation details:
The implementation is mathematically sound, properly tested, and maintains backward compatibility.
Confidence Score: 5/5
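To make the "mathematically sound" claim easy to check by hand, here is a minimal pure-Python sketch of the score function and its derivative, using the formulas quoted in the sequence diagram (`y = sqrt(log(1 + exp(x)))` and `grad *= sigmoid(x) / (2*y)`). The function names and the finite-difference helper are illustrative assumptions, not identifiers from this PR, and the sketch does not claim to match the CUDA kernels' numerics.

```python
import math


def softplus(x: float) -> float:
    # Overflow-safe form of log(1 + exp(x)): max(x, 0) + log1p(exp(-|x|)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    # Branch on the sign so exp() never receives a large positive argument.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sqrtsoftplus(x: float) -> float:
    # Forward: y = sqrt(log(1 + exp(x))).
    return math.sqrt(softplus(x))


def sqrtsoftplus_grad(x: float) -> float:
    # Chain rule: dy/dx = softplus'(x) / (2 * sqrt(softplus(x))) = sigmoid(x) / (2 * y).
    # Not handled here: very negative x, where y underflows to zero.
    return sigmoid(x) / (2.0 * sqrtsoftplus(x))


def finite_difference(x: float, eps: float = 1e-6) -> float:
    # Central difference, used only to cross-check the analytic derivative.
    return (sqrtsoftplus(x + eps) - sqrtsoftplus(x - eps)) / (2.0 * eps)
```

Comparing `sqrtsoftplus_grad(x)` against `finite_difference(x)` at a few moderate inputs is a quick way to confirm that the backward formula agrees with the forward one.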
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant PyTorch as PyTorch Layer
participant Router as router.py
participant CPP as router.cpp
participant CUDA as CUDA Kernels
User->>PyTorch: Forward pass with logits
PyTorch->>Router: fused_topk_with_score_function(logits, score_function="sqrtsoftplus")
Router->>CPP: fused_topk_with_score_function_fwd(logits, score_function="sqrtsoftplus")
CPP->>CPP: Validate score_function in {softmax, sigmoid, sqrtsoftplus}
CPP->>CPP: Map "sqrtsoftplus" -> score_function_value=2
CPP->>CUDA: nvte_fused_topk_with_score_function_forward(score_function=2)
CUDA->>CUDA: Load logits to shared memory
CUDA->>CUDA: apply_sqrtsoftplus_on_float: y = sqrt(log(1 + exp(x)))
CUDA->>CUDA: Add expert_bias if provided
CUDA->>CUDA: Perform topk selection
CUDA->>CUDA: Revert expert_bias from topk scores
CUDA->>CUDA: Normalize: probs = scores / sum(scores) if topk > 1
CUDA-->>CPP: Return probs, routing_map, intermediate_output
CPP-->>Router: Return tensors
Router-->>PyTorch: Return probs, routing_map
User->>PyTorch: Backward pass with grad_probs
PyTorch->>Router: backward(grad_probs)
Router->>CPP: fused_topk_with_score_function_bwd(grad_probs, score_function="sqrtsoftplus")
CPP->>CUDA: nvte_fused_topk_with_score_function_backward(score_function=2)
CUDA->>CUDA: Load grad_probs and intermediate_output
CUDA->>CUDA: Backward through normalization (if topk > 1)
CUDA->>CUDA: Compute sqrtsoftplus output from saved logits
CUDA->>CUDA: apply_sqrtsoftplus_bwd_on_float: grad *= sigmoid(x) / (2*y)
CUDA-->>CPP: Return grad_logits
CPP-->>Router: Return grad tensor
Router-->>PyTorch: Return grad_logits
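The forward half of the diagram above can be sketched as a single-token reference in plain Python. This is an illustration under stated assumptions, not the fused kernel: `route_one_token` is a hypothetical name, the dense `probs` list and boolean `routing_map` layout are assumed, and anything the diagram does not show (for example scaling factors or grouped top-k) is left out.

```python
import math


def sqrtsoftplus(x: float) -> float:
    # y = sqrt(log(1 + exp(x))), written in an overflow-safe form.
    return math.sqrt(max(x, 0.0) + math.log1p(math.exp(-abs(x))))


def route_one_token(logits, topk, expert_bias=None):
    """Hypothetical reference for one token, following the diagram's forward steps."""
    num_experts = len(logits)

    # 1. Apply the score function to every expert logit.
    scores = [sqrtsoftplus(x) for x in logits]

    # 2. Add expert_bias if provided; the biased values are used only for selection.
    if expert_bias is None:
        biased = scores
    else:
        biased = [s + b for s, b in zip(scores, expert_bias)]

    # 3. Top-k selection on the (possibly biased) scores.
    chosen = sorted(range(num_experts), key=lambda i: biased[i], reverse=True)[:topk]

    # 4. Revert the bias: the selected experts keep their unbiased scores.
    picked = {i: scores[i] for i in chosen}

    # 5. Normalize over the selected experts only when topk > 1.
    if topk > 1:
        total = sum(picked.values())
        picked = {i: s / total for i, s in picked.items()}

    probs = [picked.get(i, 0.0) for i in range(num_experts)]
    routing_map = [i in picked for i in range(num_experts)]
    return probs, routing_map
```

For example, `route_one_token([0.0, 2.0, -1.0, 1.0], topk=2)` selects experts 1 and 3 and returns probabilities that sum to one across those two experts. Passing an `expert_bias` can change which experts are selected, but the returned probabilities are still computed from the unbiased scores.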
Additional Comments (1)
The header still documents ...
Description
sqrtsoftplus
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: