diff --git a/python/csrc/env_py.cpp b/python/csrc/env_py.cpp index d4b2f5da..eab7b076 100644 --- a/python/csrc/env_py.cpp +++ b/python/csrc/env_py.cpp @@ -24,6 +24,11 @@ void register_env(nb::module_& m) { .def_ro("cache_dir", &Env::cacheDir) .def_ro("npkit_dump_dir", &Env::npkitDumpDir) .def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream) + .def_ro("nccl_shared_lib_path", &Env::ncclSharedLibPath) + .def_ro("force_nccl_fallback_operation", &Env::forceNcclFallbackOperation) + .def_ro("nccl_symmetric_memory", &Env::ncclSymmetricMemory) + .def_ro("force_disable_nvls", &Env::forceDisableNvls) + .def_ro("force_disable_gdr", &Env::forceDisableGdr) .def_ro("ib_gid_index", &Env::ibGidIndex); m.def("env", &env); diff --git a/src/core/executor/executor.cc b/src/core/executor/executor.cc index bf2caf97..9229f9ac 100644 --- a/src/core/executor/executor.cc +++ b/src/core/executor/executor.cc @@ -109,7 +109,7 @@ namespace mscclpp { struct ExecutionContext { std::shared_ptr proxyService; - std::unordered_map connections; + std::vector connections; std::vector> nvlsConnections; MemoryId localMemoryIdBegin = MemoryId(0); @@ -121,8 +121,6 @@ struct ExecutionContext { // local registered memories to keep resources alive std::vector localRegisteredMemories; - std::vector> memorySemaphores; - std::vector proxySemaphores; std::vector memoryChannels; std::vector portChannels; std::vector nvlsChannels; @@ -266,15 +264,28 @@ struct Executor::Impl { } }; - std::vector connectedPeers = plan.impl_->getConnectedPeers(); - std::vector> connectionFutures; - for (int peer : connectedPeers) { - Transport transport = - !useIB(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode]; - connectionFutures.push_back(this->comm->connect(transport, peer)); + std::unordered_map peerTags; + Transport ibTransport = IBs[rank % this->nranksPerNode]; + std::vector> connFutures; + for (ChannelType channelType : {ChannelType::MEMORY, ChannelType::PORT}) { + std::vector channelInfos = plan.impl_->getChannelInfos(channelType); + for (const auto& info : channelInfos) { + for (int peer : info.connectedPeers) { + Transport transport = useIB(rank, peer, this->nranksPerNode) ? ibTransport : Transport::CudaIpc; + connFutures.push_back(this->comm->connect(transport, peer, peerTags[peer]++)); + } + } + channelInfos = plan.impl_->getUnpairedChannelInfos(nranks, channelType); + for (const auto& info : channelInfos) { + for (int peer : info.connectedPeers) { + Transport transport = useIB(rank, peer, this->nranksPerNode) ? ibTransport : Transport::CudaIpc; + connFutures.push_back(this->comm->connect(transport, peer, peerTags[peer]++)); + } + } } - for (size_t i = 0; i < connectionFutures.size(); i++) { - context.connections[connectedPeers[i]] = connectionFutures[i].get(); + + for (auto& future : connFutures) { + context.connections.push_back(future.get()); } std::vector nvlsInfos = plan.impl_->nvlsInfos.at(rank); @@ -328,10 +339,11 @@ struct Executor::Impl { std::vector> futureProxySemaphores; std::vector> memorySemaphores; std::vector proxySemaphores; + int connIdx = 0; auto processChannelInfos = [&](std::vector& channelInfos) { for (ChannelInfo& info : channelInfos) { - for (int peer : info.connectedPeers) { - auto connection = context.connections.at(peer); + for (size_t i = 0; i < info.connectedPeers.size(); i++) { + auto& connection = context.connections[connIdx++]; if (info.channelType == ChannelType::MEMORY) { futureMemorySemaphores.push_back(this->comm->buildSemaphore( connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection))); @@ -360,18 +372,15 @@ struct Executor::Impl { proxySemaphores.push_back(context.proxyService->addSemaphore(sem.get())); } - context.memorySemaphores = std::move(memorySemaphores); - context.proxySemaphores = std::move(proxySemaphores); - for (ChannelType channelType : channelTypes) { std::vector channelInfos = plan.impl_->getChannelInfos(channelType); int index = 0; for (ChannelInfo& info : channelInfos) { for (size_t i = 0; i < info.connectedPeers.size(); i++) { if (channelType == ChannelType::MEMORY) { - context.memoryChannels.emplace_back(context.memorySemaphores[index++]); + context.memoryChannels.emplace_back(memorySemaphores[index++]); } else if (channelType == ChannelType::PORT) { - context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++])); + context.portChannels.emplace_back(context.proxyService->basePortChannel(proxySemaphores[index++])); } } }