diff --git a/examples/broadcast.cc b/examples/broadcast.cc new file mode 100644 index 0000000..43cd2cb --- /dev/null +++ b/examples/broadcast.cc @@ -0,0 +1,328 @@ +/** + * InfiniCCL Example/Test: Broadcast + * + * Runs a small suite of boundary cases: + * 1. count = 0 → no-op success + * 2. out-of-place, root = size - 1 + * 3. out-of-place, non-root sendbuff = nullptr (documented contract) + * 4. in-place (sendbuff == recvbuff), root = 0 + * 5. in-place (sendbuff == recvbuff), root = size - 1 + * 6. count > INT_MAX bytes (chunking path), gated by INFINI_BROADCAST_LARGE=1 + * 7. invalid root (-1 and size) → infiniInvalidArgument + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "infiniccl.h" +#include "utils.h" + +#include "backend_manifest.h" +#include "device.h" +#include "runtime.h" +#include "traits.h" + +using namespace infini::ccl; + +namespace { + +struct CaseResult { + bool ok = true; + bool skipped = false; + std::string note; +}; + +bool AllRanksOk(bool local_ok) { + int local = local_ok ? 1 : 0; + int global = 0; + MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, MPI_COMM_WORLD); + return global != 0; +} + +void PrintCase(int rank, const std::string &name, const CaseResult &local, + bool global_ok) { + if (rank != 0) { + return; + } + const char *GREEN = "\033[32m"; + const char *YELLOW = "\033[33m"; + const char *RED = "\033[31m"; + const char *RESET = "\033[0m"; + + std::string status; + if (local.skipped) { + status = std::string(YELLOW) + "SKIP" + RESET; + } else if (global_ok) { + status = std::string(GREEN) + "PASS" + RESET; + } else { + status = std::string(RED) + "FAIL" + RESET; + } + + std::cout << "[" << name << "] " << status; + if (!local.note.empty()) { + std::cout << " (rank0: " << local.note << ")"; + } + std::cout << std::endl; +} + +// Broadcasts `count` floats from `root` and verifies every rank receives +// `expected`. All ranks must call with the same `root`, `count`, and `inplace`. +// +// Out-of-place: root passes a separate sendbuff; non-root passes `nullptr`. +// In-place: every rank passes `sendbuff == recvbuff`; root pre-fills the +// recv buffer with the source data, non-root with garbage that +// must be overwritten by the broadcast. +template +CaseResult RunBasicFloat32(infiniComm_t comm, int rank, int root, size_t count, + float expected, bool inplace) { + using Rt = Runtime; + const size_t total_bytes = count * sizeof(float); + + std::vector h_init(count, expected); + std::vector h_garbage(count, -1.0f); + + float *d_recv = nullptr; + float *d_send_owned = + nullptr; // separate allocation, only for out-of-place root + + CHECK_RT(Rt, Rt::Malloc(&d_recv, total_bytes)); + + // Pre-fill the receive buffer. + // - In-place root: recv holds the source data (and stays unchanged). + // - Everyone else: garbage that must be overwritten by the broadcast. + const bool recv_holds_source = inplace && (rank == root); + CHECK_RT(Rt, Rt::Memcpy(d_recv, + recv_holds_source ? h_init.data() : h_garbage.data(), + total_bytes, Rt::MemcpyHostToDevice)); + + // Out-of-place root needs a separate sendbuff carrying the source data. + if (!inplace && rank == root) { + CHECK_RT(Rt, Rt::Malloc(&d_send_owned, total_bytes)); + CHECK_RT(Rt, Rt::Memcpy(d_send_owned, h_init.data(), total_bytes, + Rt::MemcpyHostToDevice)); + } + + // Resolve the send pointer per mode. + const void *send_ptr = nullptr; + if (inplace) { + send_ptr = d_recv; // every rank passes the same pointer as recv + } else if (rank == root) { + send_ptr = d_send_owned; + } + // Out-of-place non-root: send_ptr stays nullptr (documented contract). + + infiniResult_t status = infiniBroadcast(send_ptr, d_recv, count, + infiniFloat32, root, comm, nullptr); + + CaseResult result; + if (status != infiniSuccess) { + result.ok = false; + result.note = + "infiniBroadcast returned " + std::to_string(static_cast(status)); + } else { + CHECK_RT(Rt, Rt::StreamSynchronize(nullptr)); + std::vector h_recv(count); + CHECK_RT(Rt, Rt::Memcpy(h_recv.data(), d_recv, total_bytes, + Rt::MemcpyDeviceToHost)); + + for (size_t i = 0; i < count; ++i) { + if (std::fabs(h_recv[i] - expected) > 1e-3) { + result.ok = false; + result.note = "value mismatch at index " + std::to_string(i); + break; + } + } + } + + if (d_send_owned) { + CHECK_RT(Rt, Rt::Free(d_send_owned)); + } + CHECK_RT(Rt, Rt::Free(d_recv)); + return result; +} + +CaseResult Case_Count0(infiniComm_t comm) { + // All ranks pass nullptrs; the impl must short-circuit before any buffer + // access. + infiniResult_t status = + infiniBroadcast(nullptr, nullptr, 0, infiniFloat32, 0, comm, nullptr); + if (status != infiniSuccess) { + return {false, false, + "expected infiniSuccess, got " + + std::to_string(static_cast(status))}; + } + return {}; +} + +template +CaseResult Case_OutOfPlaceRootLast(infiniComm_t comm, int rank, int size) { + return RunBasicFloat32(comm, rank, /*root=*/size - 1, /*count=*/1024, + /*expected=*/7.5f, /*inplace=*/false); +} + +template +CaseResult Case_OutOfPlaceNonRootNullSend(infiniComm_t comm, int rank) { + // Out-of-place mode passes nullptr as sendbuff on non-root ranks. This case + // locks that contract in as an explicit, named check. + return RunBasicFloat32(comm, rank, /*root=*/0, /*count=*/2048, + /*expected=*/-3.25f, /*inplace=*/false); +} + +template +CaseResult Case_InplaceRootZero(infiniComm_t comm, int rank) { + // sendbuff == recvbuff on every rank; root's value must survive, non-root + // must be overwritten. + return RunBasicFloat32(comm, rank, /*root=*/0, /*count=*/1024, + /*expected=*/11.25f, /*inplace=*/true); +} + +template +CaseResult Case_InplaceRootLast(infiniComm_t comm, int rank, int size) { + return RunBasicFloat32(comm, rank, /*root=*/size - 1, /*count=*/1024, + /*expected=*/-42.5f, /*inplace=*/true); +} + +template +CaseResult Case_LargeCount(infiniComm_t comm, int rank) { + using Rt = Runtime; + if (std::getenv("INFINI_BROADCAST_LARGE") == nullptr) { + return {true, true, "set INFINI_BROADCAST_LARGE=1 to enable (~2GB/rank)"}; + } + // Force the chunked MPI_Bcast path: byte count > INT_MAX. + const size_t count = static_cast(std::numeric_limits::max()) + + static_cast(1024); + const std::int8_t expected = 0x5A; + const size_t total_bytes = count * sizeof(std::int8_t); + + std::vector h_send; + if (rank == 0) { + h_send.assign(count, expected); + } + + std::int8_t *d_send = nullptr; + std::int8_t *d_recv = nullptr; + if (rank == 0) { + CHECK_RT(Rt, Rt::Malloc(&d_send, total_bytes)); + CHECK_RT(Rt, Rt::Memcpy(d_send, h_send.data(), total_bytes, + Rt::MemcpyHostToDevice)); + } + CHECK_RT(Rt, Rt::Malloc(&d_recv, total_bytes)); + + infiniResult_t status = + infiniBroadcast(rank == 0 ? d_send : nullptr, d_recv, count, infiniChar, + /*root=*/0, comm, nullptr); + + CaseResult result; + if (status != infiniSuccess) { + result.ok = false; + result.note = + "infiniBroadcast returned " + std::to_string(static_cast(status)); + } else { + CHECK_RT(Rt, Rt::StreamSynchronize(nullptr)); + // Sample head, middle, and tail to avoid scanning ~2GB. + std::int8_t probes[3] = {-1, -1, -1}; + CHECK_RT(Rt, Rt::Memcpy(&probes[0], d_recv, sizeof(std::int8_t), + Rt::MemcpyDeviceToHost)); + CHECK_RT(Rt, Rt::Memcpy(&probes[1], d_recv + (count / 2), + sizeof(std::int8_t), Rt::MemcpyDeviceToHost)); + CHECK_RT(Rt, Rt::Memcpy(&probes[2], d_recv + (count - 1), + sizeof(std::int8_t), Rt::MemcpyDeviceToHost)); + if (probes[0] != expected || probes[1] != expected || + probes[2] != expected) { + result.ok = false; + result.note = "head/mid/tail mismatch"; + } + } + + if (d_send) { + CHECK_RT(Rt, Rt::Free(d_send)); + } + CHECK_RT(Rt, Rt::Free(d_recv)); + return result; +} + +CaseResult Case_InvalidRoot(infiniComm_t comm, int size) { + // Tiny dummy buffers — the validator must reject `root` before touching + // them. Passing `count=1` with valid datatype ensures no other early exit + // (count=0, dtype) preempts the root check. + float dummy_send = 0.f; + float dummy_recv = 0.f; + + for (int bad_root : {-1, size}) { + infiniResult_t status = infiniBroadcast( + &dummy_send, &dummy_recv, 1, infiniFloat32, bad_root, comm, nullptr); + if (status != infiniInvalidArgument) { + return {false, false, + "root=" + std::to_string(bad_root) + " expected " + + std::to_string(static_cast(infiniInvalidArgument)) + + ", got " + std::to_string(static_cast(status))}; + } + } + return {}; +} + +} // namespace + +int main(int argc, char **argv) { + constexpr Device::Type kDevType = + ListGetBest(EnabledDevices{}); + + CHECK_INFINI(infiniInit(&argc, &argv)); + + int rank = 0; + int size = 0; + CHECK_INFINI(infiniGetRank(&rank)); + CHECK_INFINI(infiniGetSize(&size)); + + if (rank == 0) { + std::cout << "=== Broadcast Test Suite ===" << std::endl; + std::cout << "Device: " << Device::StringFromType(kDevType) << std::endl; + std::cout << "Ranks: " << size << std::endl; + } + + infiniComm_t comm = nullptr; + CHECK_INFINI(infiniCommInitAll(&comm, size, nullptr)); + + bool overall_ok = true; + + auto run = [&](const std::string &name, CaseResult local) { + bool global_ok = AllRanksOk(local.ok); + PrintCase(rank, name, local, global_ok); + if (!local.skipped) { + overall_ok = overall_ok && global_ok; + } + }; + + run("count=0", Case_Count0(comm)); + run("out-of-place, root=size-1", + Case_OutOfPlaceRootLast(comm, rank, size)); + run("out-of-place, non-root sendbuff=nullptr", + Case_OutOfPlaceNonRootNullSend(comm, rank)); + run("in-place, root=0", Case_InplaceRootZero(comm, rank)); + run("in-place, root=size-1", + Case_InplaceRootLast(comm, rank, size)); + run("large count (>INT_MAX bytes)", Case_LargeCount(comm, rank)); + run("invalid root", Case_InvalidRoot(comm, size)); + + if (rank == 0) { + const char *GREEN = "\033[32m"; + const char *RED = "\033[31m"; + const char *RESET = "\033[0m"; + std::cout << "\n=== Summary ===" << std::endl; + std::cout << (overall_ok ? (std::string(GREEN) + "ALL PASS" + RESET) + : (std::string(RED) + "FAILED" + RESET)) + << std::endl; + } + + CHECK_INFINI(infiniCommDestroy(comm)); + CHECK_INFINI(infiniFinalize()); + + return overall_ok ? EXIT_SUCCESS : EXIT_FAILURE; +} \ No newline at end of file diff --git a/include/comm.h b/include/comm.h index a778a0f..34d2b07 100644 --- a/include/comm.h +++ b/include/comm.h @@ -40,6 +40,9 @@ infiniResult_t infiniAllReduce(const void *sendbuff, void *recvbuff, size_t count, infiniDataType_t datatype, infiniRedOp_t op, infiniComm_t comm, void *stream); +infiniResult_t infiniBroadcast(const void *sendbuff, void *recvbuff, + size_t count, infiniDataType_t datatype, + int root, infiniComm_t comm, void *stream); #ifdef __cplusplus } diff --git a/scripts/icclrun_logic.py b/scripts/icclrun_logic.py index 011f128..0fa3704 100644 --- a/scripts/icclrun_logic.py +++ b/scripts/icclrun_logic.py @@ -96,6 +96,7 @@ def ensure_launcher_exists(self): bin_sub = "examples/$1" if is_internal else "$1" case_blocks = "" + first_case = True for node in self.config["nodes"]: n_type = node["type"] n_env = node.get("backend_env", {}) @@ -106,9 +107,13 @@ def ensure_launcher_exists(self): exports += f' export {k}="{v if k != "LD_LIBRARY_PATH" else v + ":${LD_LIBRARY_PATH}"}"\n' if n_type == "nvidia": - case_blocks += f'if [ -c "/dev/nvidia0" ] || [ -x "$(command -v nvidia-smi)" ]; then\n{exports} ARCH="nvidia"\n' + prefix = "if" if first_case else "elif" + case_blocks += f'{prefix} [ -c "/dev/nvidia0" ] || [ -x "$(command -v nvidia-smi)" ]; then\n{exports} ARCH="nvidia"\n' + first_case = False elif n_type == "metax": - case_blocks += f'elif [ -d "/opt/maca" ] || grep -l "9999" /sys/bus/pci/devices/*/vendor >/dev/null 2>&1; then\n{exports} ARCH="metax"\n' + prefix = "if" if first_case else "elif" + case_blocks += f'{prefix} [ -d "/opt/maca" ] || grep -l "9999" /sys/bus/pci/devices/*/vendor >/dev/null 2>&1; then\n{exports} ARCH="metax"\n' + first_case = False content = f"""#!/bin/bash {case_blocks}else diff --git a/src/base/broadcast.h b/src/base/broadcast.h new file mode 100644 index 0000000..b19addc --- /dev/null +++ b/src/base/broadcast.h @@ -0,0 +1,67 @@ +#ifndef INFINI_CCL_BASE_BROADCAST_H_ +#define INFINI_CCL_BASE_BROADCAST_H_ + +#include "communicator.h" +#include "data_type_impl.h" +#include "logging.h" +#include "operation.h" +#include "return_status_impl.h" + +namespace infini::ccl { + +template +struct BroadcastImpl; + +class Broadcast : public Operation { +public: + template + static ReturnStatus Execute(const void *send_buff, void *recv_buff, + size_t count, DataType datatype, int root, + void *comm_handle, void *stream) { + if (!comm_handle) { + LOG("Invalid communicator handle for Broadcast."); + return ReturnStatus::kInvalidArgument; + } + + auto *comm = static_cast(comm_handle); + if (HasInvalidArgs(send_buff, recv_buff, count, datatype, root, comm)) { + return ReturnStatus::kInvalidArgument; + } + if (count == 0) { + return ReturnStatus::kSuccess; + } + + return BroadcastImpl::Apply( + send_buff, recv_buff, count, datatype, root, comm, stream); + } + +private: + static bool HasInvalidArgs(const void *send_buff, void *recv_buff, + size_t count, DataType datatype, int root, + Communicator *comm) { + if (datatype < DataType::kChar || datatype >= DataType::kNumTypes) { + LOG("Invalid data type for Broadcast."); + return true; + } + if (root < 0 || root >= comm->size()) { + LOG("Invalid root rank for Broadcast."); + return true; + } + if (count == 0) { + return false; + } + if (!recv_buff) { + LOG("Invalid receive buffer pointer for Broadcast."); + return true; + } + if (comm->rank() == root && !send_buff) { + LOG("Invalid root send buffer pointer for Broadcast."); + return true; + } + return false; + } +}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_BASE_BROADCAST_H_ \ No newline at end of file diff --git a/src/ompi/impl/all_reduce.h b/src/ompi/impl/all_reduce.h index eb54030..1d8facf 100644 --- a/src/ompi/impl/all_reduce.h +++ b/src/ompi/impl/all_reduce.h @@ -8,6 +8,9 @@ #include "ompi/checks.h" #include "ompi/comm_instance.h" #include "ompi/type_map.h" +#include +#include +#include namespace infini::ccl { @@ -66,7 +69,15 @@ class AllReduceImpl { for (size_t i = 0; i < count; ++i) { // TODO(lzm): should later use the unified `Cast` function instead of // static_cast to support CPU custom types. - typed_buf[i] *= static_cast(scale); + if constexpr (std::is_same_v) { + typed_buf[i] = + __float2half(__half2float(typed_buf[i]) * static_cast(scale)); + } else if constexpr (std::is_same_v) { + typed_buf[i] = + __float2bfloat16(__bfloat162float(typed_buf[i]) * static_cast(scale)); + } else { + typed_buf[i] *= static_cast(scale); + } } }); } diff --git a/src/ompi/impl/broadcast.h b/src/ompi/impl/broadcast.h new file mode 100644 index 0000000..1d3f89f --- /dev/null +++ b/src/ompi/impl/broadcast.h @@ -0,0 +1,79 @@ +#ifndef INFINI_CCL_OMPI_IMPL_BROADCAST_H_ +#define INFINI_CCL_OMPI_IMPL_BROADCAST_H_ + +#include +#include + +#include "base/broadcast.h" +#include "communicator.h" +#include "data_type_impl.h" +#include "logging.h" +#include "ompi/checks.h" +#include "ompi/comm_instance.h" +#include "runtime.h" + +namespace infini::ccl { + +template +class BroadcastImpl { +public: + static ReturnStatus Apply(const void *send_buff, void *recv_buff, + size_t count, DataType data_type, int root, + Communicator *comm, void *stream) { + constexpr Device::Type kDev = + ListGetBest(ActiveDevices{}); + using Rt = Runtime; + + auto *inst = static_cast(comm->inter_comm()); + if (!inst || inst->handle == MPI_COMM_NULL) { + LOG("Invalid OpenMPI communicator instance for Broadcast."); + return ReturnStatus::kInternalError; + } + + size_t type_size = kDataTypeToSize.at(data_type); + if (count > std::numeric_limits::max() / type_size) { + LOG("Broadcast byte size overflow."); + return ReturnStatus::kInvalidArgument; + } + + size_t total_bytes = count * type_size; + void *host_buf = std::malloc(total_bytes); + if (!host_buf) { + LOG("Failed to allocate host buffer for Broadcast staging."); + return ReturnStatus::kSystemError; + } + + if (comm->rank() == root) { + CHECK_STATUS(Rt, Rt::Memcpy(host_buf, send_buff, total_bytes, + Rt::MemcpyDeviceToHost)); + CHECK_STATUS(Rt, Rt::StreamSynchronize(static_cast(stream))); + } + + auto *bytes = static_cast(host_buf); + size_t offset = 0; + constexpr size_t kMaxMpiCount = + static_cast(std::numeric_limits::max()); + while (offset < total_bytes) { + size_t chunk = total_bytes - offset; + if (chunk > kMaxMpiCount) { + chunk = kMaxMpiCount; + } + INFINI_CHECK_MPI(MPI_Bcast(bytes + offset, static_cast(chunk), + MPI_BYTE, root, inst->handle)); + offset += chunk; + } + + CHECK_STATUS(Rt, Rt::Memcpy(recv_buff, host_buf, total_bytes, + Rt::MemcpyHostToDevice)); + + std::free(host_buf); + return ReturnStatus::kSuccess; + } +}; + +template <> +struct BackendEnabled : std::true_type {}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_OMPI_IMPL_BROADCAST_H_ \ No newline at end of file diff --git a/src/ompi/impl/comm_init_all.h b/src/ompi/impl/comm_init_all.h index 0764b9b..98d51fc 100644 --- a/src/ompi/impl/comm_init_all.h +++ b/src/ompi/impl/comm_init_all.h @@ -1,8 +1,8 @@ #ifndef INFINI_CCL_OMPI_IMPL_COMM_INIT_ALL_H_ #define INFINI_CCL_OMPI_IMPL_COMM_INIT_ALL_H_ -#include "base/comm_init_all.h" #include "communicator.h" +#include "base/comm_init_all.h" #include "logging.h" #include "ompi/checks.h" #include "ompi/comm_instance.h" diff --git a/src/ompi/impl/finalize.h b/src/ompi/impl/finalize.h index 2787c52..cf47161 100644 --- a/src/ompi/impl/finalize.h +++ b/src/ompi/impl/finalize.h @@ -10,8 +10,13 @@ template class FinalizeImpl { public: static ReturnStatus Apply() { - int finalized; + int finalized = 0; INFINI_CHECK_MPI(MPI_Finalized(&finalized)); + + if (!finalized) { + INFINI_CHECK_MPI(MPI_Finalize()); + } + return ReturnStatus::kSuccess; } }; @@ -21,4 +26,4 @@ struct BackendEnabled : std::true_type {}; } // namespace infini::ccl -#endif // INFINI_CCL_OMPI_IMPL_FINALIZE_H_ +#endif // INFINI_CCL_OMPI_IMPL_FINALIZE_H_ \ No newline at end of file