Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ jobs:
args: "torch-step --grad-accum 2 --model-dtype fp32 --matmul-dtype bf16"
- name: "Torch Chunking"
args: "torch-step --grad-accum 4 --model-dtype bf16 --matmul-dtype bf16 --lmhead-chunks 4"
- name: "Torch BF16 Custom Matmul"
args: "torch-step --grad-accum 1 --model-dtype bf16 --matmul-dtype bf16 --custom-matmul"
steps:
- name: Checkout code
uses: actions/checkout@v4
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ add_library(llmq-common SHARED
src/kernels/random.cu
src/kernels/fill.cu
src/kernels/convert.cu
src/kernels/gemm_mma.cu
)

target_link_libraries(llmq-common PRIVATE ${CUFILE_LIBS} cudnn_frontend ${CUDNN_LIBRARY} nvidia::nccl fmt::fmt-header-only)
Expand Down Expand Up @@ -256,6 +257,7 @@ if (BUILD_TESTS)
src/testing/test-swiglu.cu
src/testing/test-rope.cu
src/testing/test-classifier.cu
src/testing/test-gemm.cpp
)
target_link_libraries(unit-tests PRIVATE llmq-common Catch2::Catch2 CLI11::CLI11)
target_compile_options(unit-tests PUBLIC $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr> $<$<COMPILE_LANGUAGE:CUDA>:-lineinfo>)
Expand Down
2 changes: 2 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def setup_options(config: pyllmq.TrainingConfig) -> pyllmq.LLamaOptions:
# Performance options
options.use_cuda_graphs = config.use_cuda_graphs
options.use_all_to_all_reduce = config.all_to_all_reduce
options.use_custom_matmul = config.custom_matmul
options.use_write_combined = config.write_combined

# Other options
Expand Down Expand Up @@ -99,6 +100,7 @@ def add_toggle(arg: str, default: bool, help: str):
add_toggle("memcpy-send-recv", True, "Use cudaMemcpyAsync for send/recv (faster on PCIe). Only meaningful in conjunction with all-to-all-reduce")
add_toggle("all-to-all-reduce", True, "Use custom all-to-all reduce which can be used with memcpy-send-recv")
add_toggle("write-combined", False, "Use write-combined memory. May give faster PCIe transfers.")
add_toggle("custom-matmul", False, "Use custom matmul implementation.")

args = parser.parse_args()
return pyllmq.TrainingConfig(**vars(args))
Expand Down
7 changes: 5 additions & 2 deletions src/binding/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ NB_MODULE(_pyllmq, m) {
bool recompute_ffn, bool recompute_qkv, bool recompute_att, bool recompute_block, bool offload_residual,
bool use_cuda_graphs,
bool offload_master, bool offload_quants, bool offload_opt_m, bool offload_opt_v, bool offload_grads, bool use_zero_copy,
bool use_write_combined, bool shard_weights, bool persistent_quants, bool shard_gradients, bool use_all_to_all_reduce,
bool use_write_combined, bool shard_weights, bool persistent_quants, bool shard_gradients, bool use_all_to_all_reduce, bool use_custom_matmul,
bool init_projections_to_zero, int lmhead_chunks, int attn_bwd_chunks,
const std::string matmul_type, const std::string gradient_type, const std::string master_dtype, const std::string momentum_type, const std::string variance_type) {
new (t) LLamaOptions{
Expand All @@ -176,6 +176,7 @@ NB_MODULE(_pyllmq, m) {
.PersistentQuants = persistent_quants,
.ShardGradients = shard_gradients,
.UseAllToAllReduce = use_all_to_all_reduce,
.UseCustomMatmul = use_custom_matmul,
.InitProjectionsToZero = init_projections_to_zero,
.MatmulType = opt_dtype_from_str(matmul_type),
.GradientType = opt_dtype_from_str(gradient_type),
Expand All @@ -193,7 +194,8 @@ NB_MODULE(_pyllmq, m) {
nb::arg("offload_opt_v") = false, nb::arg("offload_grads") = false,
nb::arg("use_zero_copy") = false, nb::arg("use_write_combined") = false,
nb::arg("shard_weights") = false, nb::arg("persistent_quants") = false,
nb::arg("shard_gradients") = false, nb::arg("use_all_to_all_reduce") = false,
nb::arg("shard_gradients") = false,
nb::arg("use_all_to_all_reduce") = false, nb::arg("use_custom_matmul") = false,
nb::arg("init_projections_to_zero") = false, nb::arg("lmhead_chunks") = 1,
nb::arg("attn_bwd_chunks") = 1,
nb::arg("matmul_type") = "", nb::arg("gradient_type") = "",
Expand Down Expand Up @@ -221,6 +223,7 @@ NB_MODULE(_pyllmq, m) {
.def_rw("persistent_quants", &LLamaOptions::PersistentQuants)
.def_rw("shard_gradients", &LLamaOptions::ShardGradients)
.def_rw("use_all_to_all_reduce", &LLamaOptions::UseAllToAllReduce)
.def_rw("use_custom_matmul", &LLamaOptions::UseCustomMatmul)
.def_rw("init_projections_to_zero", &LLamaOptions::InitProjectionsToZero)
.def_prop_rw("matmul_type", [](const LLamaOptions* opt){ return opt->matmul_dtype(); },
[](LLamaOptions* opt, const std::string& dtype_str){ opt->MatmulType = opt_dtype_from_str(dtype_str); })
Expand Down
9 changes: 6 additions & 3 deletions src/binding/kernel_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,12 @@ void bind_grouped_loss_sum(const CudaArray& out, const CudaArray& per_token_loss
void bind_matmul(const CudaArray& c, const CudaArray& a, const CudaArray& b, const std::optional<CudaArray>& bias,
const std::optional<CudaArray>& scale_a, const std::optional<CudaArray>& scale_b,
std::uintptr_t cublaslt_handle, const CudaArray& workspace,
int mode_, bool accumulate, const std::uintptr_t stream) {
int mode_, bool accumulate, const std::uintptr_t stream, int backend_) {
NB_CHECK_NDIMS(a, 2);
NB_CHECK_NDIMS(b, 2);
NB_CHECK_NDIMS(c, 2);
EMMTranspose mode = static_cast<EMMTranspose>(mode_);
Comment thread
ngc92 marked this conversation as resolved.
EMatmulBackend backend = static_cast<EMatmulBackend>(backend_);

Comment thread
ngc92 marked this conversation as resolved.
// torch vs cublas: a @ b <=> b^t @ a^t
const bool a_transposed = (mode == EMMTranspose::TN || mode == EMMTranspose::TT);
Expand Down Expand Up @@ -363,7 +364,7 @@ void bind_matmul(const CudaArray& c, const CudaArray& a, const CudaArray& b, con
Tensor c_t = to_tensor(c);
Tensor ws_t = to_tensor(workspace);
matmul(c_t, to_tensor(b), to_tensor(a), to_tensor(bias), scale_b_ptr, scale_a_ptr,
reinterpret_cast<cublasLtHandle_t>(cublaslt_handle), ws_t, M, N, K, inv_mode, accumulate, as_stream(stream));
reinterpret_cast<cublasLtHandle_t>(cublaslt_handle), ws_t, M, N, K, inv_mode, accumulate, as_stream(stream), backend);
}

cublasLtHandle_t create_cublaslt_handle();
Expand Down Expand Up @@ -546,7 +547,9 @@ void register_kernels(nanobind::module_& m) {
m.def("grouped_loss_sum", &bind_grouped_loss_sum, nb::arg("out"), nb::arg("per_token_loss"), nb::arg("stream") = 0);

// Matmul
m.def("matmul", &bind_matmul, nb::arg("c"), nb::arg("a"), nb::arg("b"), nb::arg("bias") = std::nullopt, nb::arg("scale_a") = std::nullopt, nb::arg("scale_b") = std::nullopt, nb::arg("cublaslt_handle"), nb::arg("workspace"), nb::arg("mode"), nb::arg("accumulate") = false, nb::arg("stream") = 0);
m.def("matmul", &bind_matmul, nb::arg("c"), nb::arg("a"), nb::arg("b"), nb::arg("bias") = std::nullopt,
nb::arg("scale_a") = std::nullopt, nb::arg("scale_b") = std::nullopt, nb::arg("cublaslt_handle"), nb::arg("workspace"),
nb::arg("mode"), nb::arg("accumulate") = false, nb::arg("stream") = 0, nb::arg("backend") = static_cast<int>(EMatmulBackend::CuBLAS));
m.def("create_cublas_handle", &bind_create_cublas_handle);
m.def("destroy_cublas_handle", &bind_destroy_cublas_handle);

Expand Down
4 changes: 4 additions & 0 deletions src/binding/python/tests/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _create_options(config: TrainingConfig) -> pyllmq.LLamaOptions:
options.shard_gradients = config.shard_gradients
options.shard_weights = config.shard_weights

options.use_all_to_all_reduce = config.all_to_all_reduce
options.use_custom_matmul = config.custom_matmul

if config.matmul_dtype:
options.matmul_type = config.matmul_dtype
if config.gradient_dtype:
Expand Down Expand Up @@ -167,6 +170,7 @@ def parse_args(args: list = None) -> TrainingConfig:
parser.add_argument("--memcpy-send-recv", action="store_true")
parser.add_argument("--all-to-all-reduce", action="store_true")
parser.add_argument("--write-combined", action="store_true")
parser.add_argument("--custom-matmul", action="store_true")
args = parser.parse_args(args=args)

cfg = TrainingConfig(**vars(args))
Expand Down
1 change: 1 addition & 0 deletions src/binding/python/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class TrainingConfig:
memcpy_all_gather: bool = False
memcpy_send_recv: bool = False
all_to_all_reduce: bool = False
custom_matmul: bool = False
write_combined: bool = False
use_zero_copy: bool = False

Expand Down
Loading