Refactor(linear): split LinearBackward kernel into 3 independent kernels #142
chen2021673 wants to merge 7 commits into master from …
Conversation
Force-pushed from 283d083 to 23d301b
Move grad_flags logic from kernel to autograd layer. The monolithic LinearBackward kernel is replaced by LinearBackwardInput, LinearBackwardWeight, and LinearBackwardBias — each a pure compute operation with no autograd-related parameters.
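A minimal sketch of what this looks like at the autograd layer, assuming hypothetical names for the context accessors and kernel signatures (the repo's actual API may differ):

```cpp
// Sketch only: Context, Tensor, SavedTensor() and NeedsInputGrad() are
// assumed names for illustration, not necessarily this repo's API.
std::vector<Tensor> Linear::Backward(Context* ctx, const Tensor& grad_output) {
  const Tensor& input  = ctx->SavedTensor(0);
  const Tensor& weight = ctx->SavedTensor(1);

  Tensor grad_input, grad_weight, grad_bias;
  // The "does this gradient need to be computed?" decision lives here now,
  // so each kernel below stays a pure compute function.
  if (ctx->NeedsInputGrad(0)) {
    grad_input = LinearBackwardInput(grad_output, weight);
  }
  if (ctx->NeedsInputGrad(1)) {
    grad_weight = LinearBackwardWeight(grad_output, input);
  }
  if (ctx->NeedsInputGrad(2)) {
    grad_bias = LinearBackwardBias(grad_output);
  }
  return {grad_input, grad_weight, grad_bias};
}
```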
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or cublasGemmStridedBatchedEx based on batch_count), plus GetCublasHandle() / GetCudaStream() shared across all GEMM-using kernels (see the sketch below)
- Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput and MatmulBackwardInput2 → MatmulBackwardOther for semantic clarity, matching the MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths); keep cublasSgemv for the fp32 backward path (more efficient, and bf16 is unsupported there)
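The GemmParams + GemmCuda() shape described above might look roughly like the following sketch. The struct fields, leading-dimension handling, and dispatch policy are assumptions for illustration; only cublasGemmEx and cublasGemmStridedBatchedEx are the real cuBLAS entry points being dispatched to:

```cpp
#include <cublas_v2.h>

// Sketch only: field names are assumed, not copied from this PR.
struct GemmParams {
  cublasOperation_t trans_a = CUBLAS_OP_N;
  cublasOperation_t trans_b = CUBLAS_OP_N;
  int m = 0, n = 0, k = 0;
  const void* a = nullptr;
  const void* b = nullptr;
  void* c = nullptr;
  cudaDataType_t a_dtype = CUDA_R_32F;
  cudaDataType_t b_dtype = CUDA_R_32F;
  cudaDataType_t c_dtype = CUDA_R_32F;
  int lda = 0, ldb = 0, ldc = 0;
  long long stride_a = 0, stride_b = 0, stride_c = 0;
  int batch_count = 1;  // 1 -> plain GEMM, >1 -> strided-batched GEMM
};

inline void GemmCuda(cublasHandle_t handle, const GemmParams& p) {
  const float alpha = 1.0f, beta = 0.0f;
  // Accumulate in fp32 regardless of storage dtype (required for bf16/fp16).
  const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
  if (p.batch_count == 1) {
    cublasGemmEx(handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &alpha,
                 p.a, p.a_dtype, p.lda,
                 p.b, p.b_dtype, p.ldb, &beta,
                 p.c, p.c_dtype, p.ldc,
                 compute_type, CUBLAS_GEMM_DEFAULT);
  } else {
    cublasGemmStridedBatchedEx(handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &alpha,
                               p.a, p.a_dtype, p.lda, p.stride_a,
                               p.b, p.b_dtype, p.ldb, p.stride_b, &beta,
                               p.c, p.c_dtype, p.ldc, p.stride_c,
                               p.batch_count, compute_type, CUBLAS_GEMM_DEFAULT);
  }
}
```

A caller fills GemmParams as a pure problem description and lets GemmCuda pick the cuBLAS entry point; a later commit in this PR moves handle acquisition behind a device parameter.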
Force-pushed from ae80cec to 88579ba
…es in linear kernels
Force-pushed from 88579ba to 252e6cd
Also, the CPU changes are fairly extensive. I can't spot any problems, but please also take the time to verify that the numerical accuracy is unaffected.
…s to designated initializers

- Save input1_dims_ / input2_dims_ in Matmul::SetupContext to avoid Dims() calls on potentially-null saved tensors in Backward
- Get the device from grad_output instead of input1 in Matmul::Backward
- Add CHECK guards before dereferencing nullable saved tensors
- Convert all GemmParams / SgemvParams construction in linear.cu, matmul.cu, outer.cu to C++20 designated-initializer form (see the sketch below)
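As a concrete, hypothetical illustration of the designated-initializer style, using the sketch GemmParams from earlier and made-up pointer/size variables (weight_ptr, grad_output_ptr, grad_input_ptr are not from this PR):

```cpp
// Sketch: C++20 designated initializers document each field at the call site.
// Field names follow the illustrative GemmParams above, not the repo's struct.
const GemmParams params{
    .trans_a = CUBLAS_OP_N,
    .trans_b = CUBLAS_OP_T,
    .m = static_cast<int>(in_features),
    .n = static_cast<int>(batch_size),
    .k = static_cast<int>(out_features),
    .a = weight_ptr,        // hypothetical raw device pointers
    .b = grad_output_ptr,
    .c = grad_input_ptr,
    .batch_count = 1,       // omitted fields keep their defaults
};
```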
…evice param

GemmParams and SgemvParams are pure problem descriptions and should not carry runtime state. Move handle acquisition into GemmCuda/SgemvCuda via a device parameter, and inline the dynamic_cast directly. Remove the public GetCublasHandle/GetCudaStream helpers from gemm.cuh.
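A minimal sketch of the resulting signature, assuming a CudaDevice class that owns the cuBLAS handle (the device class and accessor names here are guesses, not the repo's actual API):

```cpp
// Sketch: Device, CudaDevice and CublasHandle() are assumed names.
// The params struct stays a pure problem description; the runtime handle
// is derived from the device argument inside the primitive itself.
void GemmCuda(const Device* device, const GemmParams& p) {
  // dynamic_cast inlined here instead of hiding it behind a public
  // GetCublasHandle() helper in gemm.cuh.
  auto* cuda_device = dynamic_cast<const CudaDevice*>(device);
  CHECK(cuda_device != nullptr);
  cublasHandle_t handle = cuda_device->CublasHandle();
  // ... dispatch to cublasGemmEx / cublasGemmStridedBatchedEx as before ...
}
```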
Please put both gemm.cuh and gemm.cu under the src/kernel/cuda/common/ directory. In principle, the include directory should only contain externally exposed interfaces. Some of the earliest files weren't separated carefully; I'll clean up the legacy ones in one pass later, but new files should still follow this principle.
```cpp
const cudaDataType_t type_c = ToCudaDataType(p.output_dtype);
// Always use CUBLAS_COMPUTE_32F: required for bf16/fp16 correctness,
// and fine for fp32 (same compute path).
const cublasComputeType_t ctype = CUBLAS_COMPUTE_32F;
```
```cpp
// When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
// cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
if (bs == 1 && dtype == DataType::kFLOAT32) {
  SgemvCuda(device, SgemvParams{
```
I see that fp32 previously always went through sgemm, and now the bs=1 & fp32 case goes through sgemv. Have you measured a performance gain? If there is no clear benefit for now, I'd suggest keeping the original logic in this PR and introducing the gemv path later as part of a dedicated matmul performance optimization (which may require a more elaborate case analysis).
Currently no model case should hit the bs=1 branch, and the existing test cases don't cover this path.
```cpp
const std::vector<int64_t> weight_dims
    = transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
auto compute_dtype = weight->Dtype();
```
Here compute_dtype is no longer obtained via input/weight type promotion. Could that cause problems?
LinearBackwardInput only computes grad_input = grad_output × weight, so in principle, looking at this function in isolation, the current choice is actually the more reasonable source for the compute dtype. I'll run a precision test and keep this form if nothing breaks.
```cpp
// Compute dtype determined by saved tensors (forward compute dtype), not grad_output
DataType compute_dtype = PromoteDataTypes(input_dtype, weight_dtype);
// For bf16 compute, accumulate in fp32 to preserve precision.
auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype;
```
Writing it this way is a bit of a hack. Leave a FIXME for now, and we'll revisit how to handle it properly when we fix autograd/autocast later.
Please attach a screenshot showing the tests pass.
include/ is for public-facing interfaces only; gemm primitives are internal, so relocate them under src/. Update all include paths. Also rename ctype -> compute_type and add FIXME on bf16 output dtype promotion hack in linear backward passes.

Summary
Architecture refactoring of Linear/Matmul/Outer kernels.
The core idea is separation of concerns: the decision of whether a gradient should be computed moves from the kernel layer up to the autograd layer, so that kernels become pure compute functions. At the same time, unified GEMM/SGEMV primitives are factored out at the lowest layer to eliminate duplicated cuBLAS boilerplate.
Changes