diff --git a/.editorconfig b/.editorconfig index 0860c43..7b91b41 100644 --- a/.editorconfig +++ b/.editorconfig @@ -3,7 +3,7 @@ root = true [*.{h,hpp}] indent_style = tab indent_size = 4 -# end_of_line = crlf +end_of_line = crlf charset = utf-8 trim_trailing_whitespace = false insert_final_newline = false diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a2237bc --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +* text=auto eol=lf + +*.bat text eol=crlf +*.cmd text eol=crlf diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml new file mode 100644 index 0000000..98be95e --- /dev/null +++ b/.github/workflows/cmake.yml @@ -0,0 +1,68 @@ +name: Test with CMake + +on: + push: + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + build-and-test: + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install --no-install-recommends -y \ + libcereal-dev \ + libopenmpi-dev \ + ninja-build \ + openmpi-bin + + - name: Configure + run: | + cmake -S . -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX="${{ github.workspace }}/install" \ + -DBUILD_TESTING=ON + + - name: Check public headers + run: cmake --build build --target libcomm_public_header_check -j $(nproc) + + - name: Run tests + run: ctest --test-dir build --output-on-failure + + - name: Install + run: cmake --install build + + - name: Test installed package + run: | + mkdir -p consumer + cat > consumer/CMakeLists.txt <<'EOF' + cmake_minimum_required(VERSION 3.23) + project(LibCommConsumer LANGUAGES CXX) + + find_package(LibComm CONFIG REQUIRED) + + add_executable(libcomm_consumer main.cpp) + target_link_libraries(libcomm_consumer PRIVATE LibComm::LibComm) + EOF + + cat > consumer/main.cpp <<'EOF' + #include + + int main() + { + return 0; + } + EOF + + cmake -S consumer -B consumer/build -G Ninja \ + -DCMAKE_PREFIX_PATH="${{ github.workspace }}/install" + cmake --build consumer/build --parallel 2 diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..cee8b57 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,126 @@ +cmake_minimum_required(VERSION 3.23) +if(POLICY CMP0144) # https://cmake.org/cmake/help/git-stage/policy/CMP0144.html + cmake_policy(SET CMP0144 NEW) +endif() + +set(LIBCOMM_PACKAGE_VERSION "0.2.0" CACHE STRING + "Package version used for generated LibCommConfigVersion.cmake when the source tree has no explicit release version.") + +project(LibComm + VERSION "${LIBCOMM_PACKAGE_VERSION}" + DESCRIPTION "Header-only MPI communication helpers used by LibRI and ABACUS" + LANGUAGES CXX) + +include(CMakePackageConfigHelpers) +include(GNUInstallDirs) +include(CTest) + +list(PREPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +option(LIBCOMM_INSTALL_EXAMPLE_HEADERS + "Install example helper headers under include/Comm/example." + ON) +option(LIBCOMM_ENABLE_HEADER_CHECK + "Add an explicit, non-default target that checks public headers for self-contained compilation." + ON) + +find_package(Threads REQUIRED) +find_package(MPI REQUIRED COMPONENTS CXX) +find_package(OpenMP REQUIRED COMPONENTS CXX) +find_package(cereal REQUIRED) + +file(GLOB_RECURSE LIBCOMM_PUBLIC_HEADERS CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/include/Comm/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/include/Comm/*.hpp") + +if(NOT LIBCOMM_INSTALL_EXAMPLE_HEADERS) + list(FILTER LIBCOMM_PUBLIC_HEADERS EXCLUDE REGEX "/include/Comm/example/") +endif() + +add_library(LibComm INTERFACE) +add_library(LibComm::LibComm ALIAS LibComm) + +set_target_properties(LibComm PROPERTIES + EXPORT_NAME LibComm) + +target_sources(LibComm + INTERFACE + FILE_SET public_headers + TYPE HEADERS + BASE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/include" + FILES ${LIBCOMM_PUBLIC_HEADERS}) + +target_compile_features(LibComm INTERFACE cxx_std_17) + +target_include_directories(LibComm + INTERFACE + "$" + "$") + +target_link_libraries(LibComm + INTERFACE + MPI::MPI_CXX + OpenMP::OpenMP_CXX + Threads::Threads + cereal::cereal) + +install(TARGETS LibComm + EXPORT LibCommTargets + FILE_SET public_headers DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}") + +install(EXPORT LibCommTargets + NAMESPACE LibComm:: + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/LibComm") + +configure_package_config_file( + "${CMAKE_CURRENT_SOURCE_DIR}/cmake/LibCommConfig.cmake.in" + "${CMAKE_CURRENT_BINARY_DIR}/LibCommConfig.cmake" + INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/LibComm") + +write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/LibCommConfigVersion.cmake" + VERSION "${PROJECT_VERSION}" + COMPATIBILITY SameMajorVersion) + +install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/LibCommConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/LibCommConfigVersion.cmake" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/LibComm") + +install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE" + DESTINATION "${CMAKE_INSTALL_DOCDIR}") + +if(BUILD_TESTING AND LIBCOMM_ENABLE_HEADER_CHECK) + add_custom_target(libcomm_public_header_check) + set(_libcomm_header_check_index 0) + foreach(_libcomm_header IN LISTS LIBCOMM_PUBLIC_HEADERS) + math(EXPR _libcomm_header_check_index "${_libcomm_header_check_index} + 1") + file(RELATIVE_PATH _libcomm_header_rel + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${_libcomm_header}") + set(_libcomm_header_check_source + "${CMAKE_CURRENT_BINARY_DIR}/header_check_${_libcomm_header_check_index}.cpp") + file(WRITE "${_libcomm_header_check_source}" + "#include <${_libcomm_header_rel}>\nint main() { return 0; }\n") + add_executable(libcomm_header_check_${_libcomm_header_check_index} + EXCLUDE_FROM_ALL + "${_libcomm_header_check_source}") + target_link_libraries(libcomm_header_check_${_libcomm_header_check_index} + PRIVATE LibComm::LibComm) + add_dependencies(libcomm_public_header_check + libcomm_header_check_${_libcomm_header_check_index}) + endforeach() + + add_test(NAME libcomm_public_header_check + COMMAND "${CMAKE_COMMAND}" --build "${CMAKE_BINARY_DIR}" + --target libcomm_public_header_check + --config "$") +endif() + +set(CPACK_PACKAGE_NAME "LibComm") +set(CPACK_PACKAGE_VENDOR "ABACUS") +set(CPACK_PACKAGE_VERSION "${PROJECT_VERSION}") +set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "Header-only MPI communication helpers") +set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") +set(CPACK_GENERATOR "TGZ") +include(CPack) diff --git a/cmake/LibCommConfig.cmake.in b/cmake/LibCommConfig.cmake.in new file mode 100644 index 0000000..3ebb1ab --- /dev/null +++ b/cmake/LibCommConfig.cmake.in @@ -0,0 +1,17 @@ +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) + +list(PREPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}") + +find_dependency(Threads) +find_dependency(MPI COMPONENTS CXX) +find_dependency(OpenMP COMPONENTS CXX) +find_package(cereal CONFIG QUIET) +if(NOT cereal_FOUND) + find_dependency(Cereal) +endif() + +include("${CMAKE_CURRENT_LIST_DIR}/LibCommTargets.cmake") + +check_required_components(LibComm) diff --git a/include/Comm/Comm_Assemble/Comm_Assemble.h b/include/Comm/Comm_Assemble/Comm_Assemble.h index 850550b..0656540 100644 --- a/include/Comm/Comm_Assemble/Comm_Assemble.h +++ b/include/Comm/Comm_Assemble/Comm_Assemble.h @@ -1,64 +1,64 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-07-06 -//======================= - -#pragma once - -#include "../Comm_Keys/Comm_Keys_31-gather.h" -#include "../Comm_Trans/Comm_Trans.h" -#include "../Comm_Tools.h" - -#include -#include - -namespace Comm -{ - -template -class Comm_Assemble -{ -public: - Comm_Assemble(const MPI_Comm &mpi_comm_in); - - std::function< - void( - const Tdatas_provide &keys_provide_mine, - std::function &func)> - &traverse_keys_provide; - std::function< - const Tvalue&( - const Tkey &key, - const Tdatas_provide &datas_provide)> - get_value_provide; - std::function< - void( - Tkey &&key, - Tvalue &&value, - Tdatas_require &datas_require)> - &set_value_require; - - Comm_Tools::Lock_Type &flag_lock_set_value; - std::function< - Tdatas_require( - const int rank_recv)> - &init_datas_local; - std::function< - void( - Tdatas_require &&datas_local, - Tdatas_require &datas_recv)> - &add_datas; - - void communicate( - const Tdatas_provide &datas_provide, - const Tkeys_require &keys_require, - Tdatas_require &datas_require); - -private: - Comm_Keys_31_SenderTraversal comm_keys; - Comm_Trans comm_trans; -}; - -} - -#include "Comm_Assemble.hpp" \ No newline at end of file +//======================= +// AUTHOR : Peize Lin +// DATE : 2022-07-06 +//======================= + +#pragma once + +#include "../Comm_Keys/Comm_Keys_31-gather.h" +#include "../Comm_Trans/Comm_Trans.h" +#include "../Comm_Tools.h" + +#include +#include + +namespace Comm +{ + +template +class Comm_Assemble +{ +public: + Comm_Assemble(const MPI_Comm &mpi_comm_in); + + std::function< + void( + const Tdatas_provide &keys_provide_mine, + std::function &func)> + &traverse_keys_provide; + std::function< + const Tvalue&( + const Tkey &key, + const Tdatas_provide &datas_provide)> + get_value_provide; + std::function< + void( + Tkey &&key, + Tvalue &&value, + Tdatas_require &datas_require)> + &set_value_require; + + Comm_Tools::Lock_Type &flag_lock_set_value; + std::function< + Tdatas_require( + const int rank_recv)> + &init_datas_local; + std::function< + void( + Tdatas_require &&datas_local, + Tdatas_require &datas_recv)> + &add_datas; + + void communicate( + const Tdatas_provide &datas_provide, + const Tkeys_require &keys_require, + Tdatas_require &datas_require); + +private: + Comm_Keys_31_SenderTraversal comm_keys; + Comm_Trans comm_trans; +}; + +} + +#include "Comm_Assemble.hpp" diff --git a/include/Comm/Comm_Assemble/Comm_Assemble.hpp b/include/Comm/Comm_Assemble/Comm_Assemble.hpp index 761b9df..b67a930 100644 --- a/include/Comm/Comm_Assemble/Comm_Assemble.hpp +++ b/include/Comm/Comm_Assemble/Comm_Assemble.hpp @@ -1,41 +1,41 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-07-06 -//======================= - -#pragma once - -#include "Comm_Assemble.h" - -namespace Comm -{ - -template -Comm_Assemble::Comm_Assemble(const MPI_Comm &mpi_comm_in) - :traverse_keys_provide(comm_keys.traverse_keys_provide), - set_value_require(comm_trans.set_value_recv), - flag_lock_set_value(comm_trans.flag_lock_set_value), - init_datas_local(comm_trans.init_datas_local), - add_datas(comm_trans.add_datas), - comm_keys(mpi_comm_in), - comm_trans(mpi_comm_in){} - -template -void Comm_Assemble::communicate( - const Tdatas_provide &datas_provide, - const Tkeys_require &keys_require, - Tdatas_require &datas_require) -{ - const std::vector> keys_trans = comm_keys.trans( datas_provide, keys_require ); - comm_trans.traverse_isend = [&]( - const Tdatas_provide &datas_provide, - const int rank_isend, - std::function &func) - { - for(const Tkey &key : keys_trans[rank_isend]) - func(key, this->get_value_provide(key, datas_provide)); - }; - comm_trans.communicate(datas_provide, datas_require); -} - -} \ No newline at end of file +//======================= +// AUTHOR : Peize Lin +// DATE : 2022-07-06 +//======================= + +#pragma once + +#include "Comm_Assemble.h" + +namespace Comm +{ + +template +Comm_Assemble::Comm_Assemble(const MPI_Comm &mpi_comm_in) + :traverse_keys_provide(comm_keys.traverse_keys_provide), + set_value_require(comm_trans.set_value_recv), + flag_lock_set_value(comm_trans.flag_lock_set_value), + init_datas_local(comm_trans.init_datas_local), + add_datas(comm_trans.add_datas), + comm_keys(mpi_comm_in), + comm_trans(mpi_comm_in){} + +template +void Comm_Assemble::communicate( + const Tdatas_provide &datas_provide, + const Tkeys_require &keys_require, + Tdatas_require &datas_require) +{ + const std::vector> keys_trans = comm_keys.trans( datas_provide, keys_require ); + comm_trans.traverse_isend = [&]( + const Tdatas_provide &datas_provide, + const int rank_isend, + std::function &func) + { + for(const Tkey &key : keys_trans[rank_isend]) + func(key, this->get_value_provide(key, datas_provide)); + }; + comm_trans.communicate(datas_provide, datas_require); +} + +} diff --git a/include/Comm/Comm_Keys/Comm_Keys_31-gather.h b/include/Comm/Comm_Keys/Comm_Keys_31-gather.h index 8ac7489..242cbd8 100644 --- a/include/Comm/Comm_Keys/Comm_Keys_31-gather.h +++ b/include/Comm/Comm_Keys/Comm_Keys_31-gather.h @@ -1,95 +1,95 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include -#include -#include - -namespace Comm -{ - -template -class Comm_Keys_31 -{ -public: - Comm_Keys_31(const MPI_Comm &mpi_comm_in); - - std::vector> trans( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require_mine); - -protected: -#if MPI_VERSION>=4 - void send_keys_require_mine( - const Tkeys_require &keys_require_mine, - std::vector &sss_size, - std::vector &sss_displs, - std::vector &buffer_recv); -#else - void send_keys_require_mine( - const Tkeys_require &keys_require_mine, - std::vector &sss_size, - std::vector &sss_displs, - std::vector &buffer_recv); -#endif - - //void recv_require_intersection( - // const Tkeys_provide &keys_provide_mine, - // std::vector> &keys_trans_list); - - virtual void intersection( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans)=0; - - MPI_Comm mpi_comm; - int rank_mine; - int rank_size; - - const int tag_keys = 99; -}; - -template -class Comm_Keys_31_SenderTraversal: public Comm_Keys_31 -{ -public: - Comm_Keys_31_SenderTraversal(const MPI_Comm &mpi_comm); - - std::function< - void( - const Tkeys_provide &keys_provide_mine, - std::function &func )> - traverse_keys_provide; - -private: - void intersection( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans); -}; - -template -class Comm_Keys_31_SenderJudge: public Comm_Keys_31 -{ -public: - Comm_Keys_31_SenderJudge(const MPI_Comm &mpi_comm); - - std::function< - void( - std::function &func )> - traverse_keys_all; - -private: - void intersection( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans); -}; - -} - -#include "Comm_Keys_31-gather.hpp" \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include +#include +#include + +namespace Comm +{ + +template +class Comm_Keys_31 +{ +public: + Comm_Keys_31(const MPI_Comm &mpi_comm_in); + + std::vector> trans( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require_mine); + +protected: +#if MPI_VERSION>=4 + void send_keys_require_mine( + const Tkeys_require &keys_require_mine, + std::vector &sss_size, + std::vector &sss_displs, + std::vector &buffer_recv); +#else + void send_keys_require_mine( + const Tkeys_require &keys_require_mine, + std::vector &sss_size, + std::vector &sss_displs, + std::vector &buffer_recv); +#endif + + //void recv_require_intersection( + // const Tkeys_provide &keys_provide_mine, + // std::vector> &keys_trans_list); + + virtual void intersection( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans)=0; + + MPI_Comm mpi_comm; + int rank_mine; + int rank_size; + + const int tag_keys = 99; +}; + +template +class Comm_Keys_31_SenderTraversal: public Comm_Keys_31 +{ +public: + Comm_Keys_31_SenderTraversal(const MPI_Comm &mpi_comm); + + std::function< + void( + const Tkeys_provide &keys_provide_mine, + std::function &func )> + traverse_keys_provide; + +private: + void intersection( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans); +}; + +template +class Comm_Keys_31_SenderJudge: public Comm_Keys_31 +{ +public: + Comm_Keys_31_SenderJudge(const MPI_Comm &mpi_comm); + + std::function< + void( + std::function &func )> + traverse_keys_all; + +private: + void intersection( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans); +}; + +} + +#include "Comm_Keys_31-gather.hpp" diff --git a/include/Comm/Comm_Keys/Comm_Keys_31-gather.hpp b/include/Comm/Comm_Keys/Comm_Keys_31-gather.hpp index 62f33c2..02583ef 100644 --- a/include/Comm/Comm_Keys/Comm_Keys_31-gather.hpp +++ b/include/Comm/Comm_Keys/Comm_Keys_31-gather.hpp @@ -1,163 +1,163 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include "Comm_Keys_31-gather.h" -#include "../global/Cereal_Func.h" - -#include -#include -#include -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Comm -{ - -template -Comm_Keys_31::Comm_Keys_31( - const MPI_Comm &mpi_comm_in) - :mpi_comm(mpi_comm_in) -{ - MPI_CHECK( MPI_Comm_size( this->mpi_comm, &this->rank_size ) ); - MPI_CHECK( MPI_Comm_rank( this->mpi_comm, &this->rank_mine ) ); -} - -template -Comm_Keys_31_SenderTraversal::Comm_Keys_31_SenderTraversal( - const MPI_Comm &mpi_comm) - :Comm_Keys_31(mpi_comm) -{ - this->traverse_keys_provide = - [](const Tkeys_provide &keys_provide, std::function &func) - { throw std::logic_error("Function traverse not set."); }; -} - -template -Comm_Keys_31_SenderJudge::Comm_Keys_31_SenderJudge( - const MPI_Comm &mpi_comm) - :Comm_Keys_31(mpi_comm) -{ - this->traverse_keys_all = - [](std::function &func) - { throw std::logic_error("Function traverse not set."); }; -} - - -template -std::vector> Comm_Keys_31::trans( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require_mine) -{ - std::vector> keys_trans_list(this->rank_size); - -#if MPI_VERSION>=4 - std::vector sss_size; - std::vector sss_displs; -#else - std::vector sss_size; - std::vector sss_displs; -#endif - std::vector buffer_recv; - this->send_keys_require_mine( keys_require_mine, - sss_size, sss_displs, buffer_recv); - - #pragma omp parallel for - for(int rank_require=0; rank_requirerank_size; ++rank_require) - { - Tkeys_require keys_require; - std::stringstream ss_recv; - ss_recv.rdbuf()->pubsetbuf(buffer_recv.data()+sss_displs[rank_require], sss_size[rank_require]); - { - cereal::BinaryInputArchive ar(ss_recv); - ar(keys_require); - } - - this->intersection(keys_provide_mine, keys_require, keys_trans_list[rank_require]); - } - - return keys_trans_list; -} - - -#if MPI_VERSION>=4 -template -void Comm_Keys_31::send_keys_require_mine( - const Tkeys_require &keys_require_mine, - std::vector &sss_size, - std::vector &sss_displs, - std::vector &buffer_recv) -#else -template -void Comm_Keys_31::send_keys_require_mine( - const Tkeys_require &keys_require_mine, - std::vector &sss_size, - std::vector &sss_displs, - std::vector &buffer_recv) -#endif -{ - std::stringstream ss_send; - { - cereal::BinaryOutputArchive ar(ss_send); - ar(keys_require_mine); - } - - sss_size.resize(this->rank_size); -#if MPI_VERSION>=4 - const MPI_Count ss_size = ss_send.str().size(); - MPI_CHECK( MPI_Allgather( &ss_size, 1, MPI_COUNT, sss_size.data(), 1, MPI_COUNT, this->mpi_comm ) ); -#else - const int ss_size = ss_send.str().size(); - MPI_CHECK( MPI_Allgather( &ss_size, 1, MPI_INT, sss_size.data(), 1, MPI_INT, this->mpi_comm ) ); -#endif - - sss_displs.resize(this->rank_size); - sss_displs[0] = 0; - for(int i=1; irank_size; ++i) - sss_displs[i] = sss_displs[i-1] + sss_size[i-1]; - - buffer_recv.resize(sss_displs.back() + sss_size.back()); -#if MPI_VERSION>=4 - MPI_CHECK( MPI_Allgatherv_c( ss_send.str().c_str(), ss_send.str().size(), MPI_CHAR, buffer_recv.data(), sss_size.data(), sss_displs.data(), MPI_CHAR, this->mpi_comm ) ); -#else - MPI_CHECK( MPI_Allgatherv ( ss_send.str().c_str(), ss_send.str().size(), MPI_CHAR, buffer_recv.data(), sss_size.data(), sss_displs.data(), MPI_CHAR, this->mpi_comm ) ); -#endif -} - -template -void Comm_Keys_31_SenderTraversal::intersection( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans) -{ - std::function inter = [&](const Tkey &key) - { - if(keys_require.judge(key)) - keys_trans.push_back(key); - }; - this->traverse_keys_provide( keys_provide_mine, inter ); -} - - -template -void Comm_Keys_31_SenderJudge::intersection( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans) -{ - std::function inter = [&](const Tkey &key) - { - if(keys_provide_mine.judge(key) && keys_require.judge(key)) - keys_trans.push_back(key); - }; - this->traverse_keys_all(inter); -} - -} - -#undef MPI_CHECK \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include "Comm_Keys_31-gather.h" +#include "../global/Cereal_Func.h" + +#include +#include +#include +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Comm +{ + +template +Comm_Keys_31::Comm_Keys_31( + const MPI_Comm &mpi_comm_in) + :mpi_comm(mpi_comm_in) +{ + MPI_CHECK( MPI_Comm_size( this->mpi_comm, &this->rank_size ) ); + MPI_CHECK( MPI_Comm_rank( this->mpi_comm, &this->rank_mine ) ); +} + +template +Comm_Keys_31_SenderTraversal::Comm_Keys_31_SenderTraversal( + const MPI_Comm &mpi_comm) + :Comm_Keys_31(mpi_comm) +{ + this->traverse_keys_provide = + [](const Tkeys_provide &keys_provide, std::function &func) + { throw std::logic_error("Function traverse not set."); }; +} + +template +Comm_Keys_31_SenderJudge::Comm_Keys_31_SenderJudge( + const MPI_Comm &mpi_comm) + :Comm_Keys_31(mpi_comm) +{ + this->traverse_keys_all = + [](std::function &func) + { throw std::logic_error("Function traverse not set."); }; +} + + +template +std::vector> Comm_Keys_31::trans( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require_mine) +{ + std::vector> keys_trans_list(this->rank_size); + +#if MPI_VERSION>=4 + std::vector sss_size; + std::vector sss_displs; +#else + std::vector sss_size; + std::vector sss_displs; +#endif + std::vector buffer_recv; + this->send_keys_require_mine( keys_require_mine, + sss_size, sss_displs, buffer_recv); + + #pragma omp parallel for + for(int rank_require=0; rank_requirerank_size; ++rank_require) + { + Tkeys_require keys_require; + std::stringstream ss_recv; + ss_recv.rdbuf()->pubsetbuf(buffer_recv.data()+sss_displs[rank_require], sss_size[rank_require]); + { + cereal::BinaryInputArchive ar(ss_recv); + ar(keys_require); + } + + this->intersection(keys_provide_mine, keys_require, keys_trans_list[rank_require]); + } + + return keys_trans_list; +} + + +#if MPI_VERSION>=4 +template +void Comm_Keys_31::send_keys_require_mine( + const Tkeys_require &keys_require_mine, + std::vector &sss_size, + std::vector &sss_displs, + std::vector &buffer_recv) +#else +template +void Comm_Keys_31::send_keys_require_mine( + const Tkeys_require &keys_require_mine, + std::vector &sss_size, + std::vector &sss_displs, + std::vector &buffer_recv) +#endif +{ + std::stringstream ss_send; + { + cereal::BinaryOutputArchive ar(ss_send); + ar(keys_require_mine); + } + + sss_size.resize(this->rank_size); +#if MPI_VERSION>=4 + const MPI_Count ss_size = ss_send.str().size(); + MPI_CHECK( MPI_Allgather( &ss_size, 1, MPI_COUNT, sss_size.data(), 1, MPI_COUNT, this->mpi_comm ) ); +#else + const int ss_size = ss_send.str().size(); + MPI_CHECK( MPI_Allgather( &ss_size, 1, MPI_INT, sss_size.data(), 1, MPI_INT, this->mpi_comm ) ); +#endif + + sss_displs.resize(this->rank_size); + sss_displs[0] = 0; + for(int i=1; irank_size; ++i) + sss_displs[i] = sss_displs[i-1] + sss_size[i-1]; + + buffer_recv.resize(sss_displs.back() + sss_size.back()); +#if MPI_VERSION>=4 + MPI_CHECK( MPI_Allgatherv_c( ss_send.str().c_str(), ss_send.str().size(), MPI_CHAR, buffer_recv.data(), sss_size.data(), sss_displs.data(), MPI_CHAR, this->mpi_comm ) ); +#else + MPI_CHECK( MPI_Allgatherv ( ss_send.str().c_str(), ss_send.str().size(), MPI_CHAR, buffer_recv.data(), sss_size.data(), sss_displs.data(), MPI_CHAR, this->mpi_comm ) ); +#endif +} + +template +void Comm_Keys_31_SenderTraversal::intersection( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans) +{ + std::function inter = [&](const Tkey &key) + { + if(keys_require.judge(key)) + keys_trans.push_back(key); + }; + this->traverse_keys_provide( keys_provide_mine, inter ); +} + + +template +void Comm_Keys_31_SenderJudge::intersection( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans) +{ + std::function inter = [&](const Tkey &key) + { + if(keys_provide_mine.judge(key) && keys_require.judge(key)) + keys_trans.push_back(key); + }; + this->traverse_keys_all(inter); +} + +} + +#undef MPI_CHECK diff --git a/include/Comm/Comm_Keys/Comm_Keys_31-sr.h b/include/Comm/Comm_Keys/Comm_Keys_31-sr.h index a0309e2..1e7d977 100644 --- a/include/Comm/Comm_Keys/Comm_Keys_31-sr.h +++ b/include/Comm/Comm_Keys/Comm_Keys_31-sr.h @@ -1,87 +1,87 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include "../global/Cereal_Func.h" - -#include -#include -#include - -namespace Comm -{ - -template -class Comm_Keys_31 -{ - public: - Comm_Keys_31(const MPI_Comm &mpi_comm_in); - - std::vector> trans( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require_mine); - - protected: - void send_keys_require_mine( - const Tkeys_require &keys_require_mine); - - void recv_require_intersection( - const Tkeys_provide &keys_provide_mine, - std::vector> &keys_trans_list); - - virtual void intersection( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans)=0; - - MPI_Comm mpi_comm; - int rank_mine; - int rank_size; - - const int tag_keys = 99; - Comm::Cereal_Func cereal_func; -}; - -template -class Comm_Keys_31_SenderTraversal: public Comm_Keys_31 -{ -public: - Comm_Keys_31_SenderTraversal(const MPI_Comm &mpi_comm); - - std::function< - void( - const Tkeys_provide &keys_provide_mine, - std::function &func )> - traverse_keys_provide; - -private: - void intersection( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans); -}; - -template -class Comm_Keys_31_SenderJudge: public Comm_Keys_31 -{ -public: - Comm_Keys_31_SenderJudge(const MPI_Comm &mpi_comm); - - std::function< - void( - std::function &func )> - traverse_keys_all; - -private: - void intersection( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans); -}; - -} - -#include "Comm_Keys_31-sr.hpp" \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include "../global/Cereal_Func.h" + +#include +#include +#include + +namespace Comm +{ + +template +class Comm_Keys_31 +{ + public: + Comm_Keys_31(const MPI_Comm &mpi_comm_in); + + std::vector> trans( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require_mine); + + protected: + void send_keys_require_mine( + const Tkeys_require &keys_require_mine); + + void recv_require_intersection( + const Tkeys_provide &keys_provide_mine, + std::vector> &keys_trans_list); + + virtual void intersection( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans)=0; + + MPI_Comm mpi_comm; + int rank_mine; + int rank_size; + + const int tag_keys = 99; + Comm::Cereal_Func cereal_func; +}; + +template +class Comm_Keys_31_SenderTraversal: public Comm_Keys_31 +{ +public: + Comm_Keys_31_SenderTraversal(const MPI_Comm &mpi_comm); + + std::function< + void( + const Tkeys_provide &keys_provide_mine, + std::function &func )> + traverse_keys_provide; + +private: + void intersection( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans); +}; + +template +class Comm_Keys_31_SenderJudge: public Comm_Keys_31 +{ +public: + Comm_Keys_31_SenderJudge(const MPI_Comm &mpi_comm); + + std::function< + void( + std::function &func )> + traverse_keys_all; + +private: + void intersection( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans); +}; + +} + +#include "Comm_Keys_31-sr.hpp" diff --git a/include/Comm/Comm_Keys/Comm_Keys_31-sr.hpp b/include/Comm/Comm_Keys/Comm_Keys_31-sr.hpp index 15dfdea..e9f7330 100644 --- a/include/Comm/Comm_Keys/Comm_Keys_31-sr.hpp +++ b/include/Comm/Comm_Keys/Comm_Keys_31-sr.hpp @@ -179,4 +179,4 @@ void Comm_Keys_31_SenderJudge::intersection( } -#undef MPI_CHECK \ No newline at end of file +#undef MPI_CHECK diff --git a/include/Comm/Comm_Keys/Comm_Keys_32-gather.h b/include/Comm/Comm_Keys/Comm_Keys_32-gather.h index a27b034..880c28d 100644 --- a/include/Comm/Comm_Keys/Comm_Keys_32-gather.h +++ b/include/Comm/Comm_Keys/Comm_Keys_32-gather.h @@ -1,88 +1,88 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include -#include -#include -#include - -namespace Comm -{ - -template -class Comm_Keys_32 -{ -public: - Comm_Keys_32(const MPI_Comm &mpi_comm_in); - - std::vector> trans( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require_mine); - -protected: - void send_keys_require_mine( - const Tkeys_require &keys_require_mine, - std::vector &sss_size, - std::vector &sss_displs, - std::vector &buffer_recv); - - //void recv_require_intersection( - // std::vector &keys_provide_mine, - // std::vector> &keys_trans_list); - - void intersection( - std::vector &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans); - - virtual std::vector change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine)=0; - - MPI_Comm mpi_comm; - int rank_mine; - int rank_size; - - const int tag_keys = 88; - std::shared_mutex lock_provide; -}; - -template -class Comm_Keys_32_SenderTraversal: public Comm_Keys_32 -{ -public: - Comm_Keys_32_SenderTraversal(const MPI_Comm &mpi_comm); - - std::function< - void( - const Tkeys_provide &keys_provide_mine, - std::function &func )> - traverse_keys_provide; - -private: - std::vector change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine); -}; - -template -class Comm_Keys_32_SenderJudge: public Comm_Keys_32 -{ -public: - Comm_Keys_32_SenderJudge(const MPI_Comm &mpi_comm); - - std::function< - void( - std::function &func )> - traverse_keys_all; - -private: - std::vector change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine); -}; - -} - -#include "Comm_Keys_32-gather.hpp" \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include +#include +#include +#include + +namespace Comm +{ + +template +class Comm_Keys_32 +{ +public: + Comm_Keys_32(const MPI_Comm &mpi_comm_in); + + std::vector> trans( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require_mine); + +protected: + void send_keys_require_mine( + const Tkeys_require &keys_require_mine, + std::vector &sss_size, + std::vector &sss_displs, + std::vector &buffer_recv); + + //void recv_require_intersection( + // std::vector &keys_provide_mine, + // std::vector> &keys_trans_list); + + void intersection( + std::vector &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans); + + virtual std::vector change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine)=0; + + MPI_Comm mpi_comm; + int rank_mine; + int rank_size; + + const int tag_keys = 88; + std::shared_mutex lock_provide; +}; + +template +class Comm_Keys_32_SenderTraversal: public Comm_Keys_32 +{ +public: + Comm_Keys_32_SenderTraversal(const MPI_Comm &mpi_comm); + + std::function< + void( + const Tkeys_provide &keys_provide_mine, + std::function &func )> + traverse_keys_provide; + +private: + std::vector change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine); +}; + +template +class Comm_Keys_32_SenderJudge: public Comm_Keys_32 +{ +public: + Comm_Keys_32_SenderJudge(const MPI_Comm &mpi_comm); + + std::function< + void( + std::function &func )> + traverse_keys_all; + +private: + std::vector change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine); +}; + +} + +#include "Comm_Keys_32-gather.hpp" diff --git a/include/Comm/Comm_Keys/Comm_Keys_32-gather.hpp b/include/Comm/Comm_Keys/Comm_Keys_32-gather.hpp index edc2a1d..e890b64 100644 --- a/include/Comm/Comm_Keys/Comm_Keys_32-gather.hpp +++ b/include/Comm/Comm_Keys/Comm_Keys_32-gather.hpp @@ -1,159 +1,159 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include "Comm_Keys_32-gather.h" -#include "../global/Cereal_Func.h" - -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Comm -{ - -template -Comm_Keys_32::Comm_Keys_32( - const MPI_Comm &mpi_comm_in) - :mpi_comm(mpi_comm_in) -{ - MPI_CHECK( MPI_Comm_size( this->mpi_comm, &this->rank_size ) ); - MPI_CHECK( MPI_Comm_rank( this->mpi_comm, &this->rank_mine ) ); -} - -template -Comm_Keys_32_SenderTraversal::Comm_Keys_32_SenderTraversal( - const MPI_Comm &mpi_comm) - :Comm_Keys_32(mpi_comm) -{ - this->traverse_keys_provide = - [](const Tkeys_provide &keys_provide, std::function &func) - { throw std::logic_error("Function traverse not set."); }; -} - -template -Comm_Keys_32_SenderJudge::Comm_Keys_32_SenderJudge( - const MPI_Comm &mpi_comm) - :Comm_Keys_32(mpi_comm) -{ - this->traverse_keys_all = - [](std::function &func) - { throw std::logic_error("Function traverse not set."); }; -} - - -template -std::vector> Comm_Keys_32::trans( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require_mine) -{ - std::vector> keys_trans_list(this->rank_size); - - std::vector sss_size; - std::vector sss_displs; - std::vector buffer_recv; - this->send_keys_require_mine( keys_require_mine, - sss_size, sss_displs, buffer_recv); - - std::vector keys_provide_mine_vec = change_keys_provide_mine(keys_provide_mine); - - #pragma omp parallel for - for(int rank_require=0; rank_requirerank_size; ++rank_require) - { - Tkeys_require keys_require; - std::stringstream ss_recv; - ss_recv.rdbuf()->pubsetbuf(buffer_recv.data()+sss_displs[rank_require], sss_size[rank_require]); - { - cereal::BinaryInputArchive ar(ss_recv); - ar(keys_require); - } - - this->intersection(keys_provide_mine_vec, keys_require, keys_trans_list[rank_require]); - } - - return keys_trans_list; -} - - -template -void Comm_Keys_32::send_keys_require_mine( - const Tkeys_require &keys_require_mine, - std::vector &sss_size, - std::vector &sss_displs, - std::vector &buffer_recv) -{ - std::stringstream ss_send; - { - cereal::BinaryOutputArchive ar(ss_send); - ar(keys_require_mine); - } - - sss_size.resize(this->rank_size); - const int ss_size = ss_send.str().size(); - MPI_Allgather( &ss_size, 1, MPI_INT, sss_size.data(), 1, MPI_INT, this->mpi_comm ); - - sss_displs.resize(this->rank_size); - sss_displs[0] = 0; - for(int i=1; irank_size; ++i) - sss_displs[i] = sss_displs[i-1] + sss_size[i-1]; - - buffer_recv.resize(sss_displs.back() + sss_size.back()); - MPI_Allgatherv( ss_send.str().c_str(), ss_send.str().size(), MPI_CHAR, buffer_recv.data(), sss_size.data(), sss_displs.data(), MPI_CHAR, this->mpi_comm ); -} - - -template -void Comm_Keys_32::intersection( - std::vector &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans) -{ - std::vector keys_provide_mine_new; - this->lock_provide.lock(); - for(const Tkey &key : keys_provide_mine) - if(keys_require.judge(key)) - keys_trans.push_back(key); - else - keys_provide_mine_new.push_back(key); - this->lock_provide.unlock(); - - this->lock_provide.lock_shared(); - keys_provide_mine = keys_provide_mine_new; - this->lock_provide.unlock(); -} - - -template -std::vector Comm_Keys_32_SenderTraversal::change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine) -{ - std::vector keys_provide_mine_vec; - std::function change = [&](const Tkey &key) - { - keys_provide_mine_vec.push_back(key); - }; - this->traverse_keys_provide( keys_provide_mine, change ); - return keys_provide_mine_vec; -} - -template -std::vector Comm_Keys_32_SenderJudge::change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine) -{ - std::vector keys_provide_mine_vec; - std::function change = [&](const Tkey &key) - { - if(keys_provide_mine.judge(key)) - keys_provide_mine_vec.push_back(key); - }; - this->traverse_keys_all( change ); - return keys_provide_mine_vec; -} - -} - -#undef MPI_CHECK \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include "Comm_Keys_32-gather.h" +#include "../global/Cereal_Func.h" + +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Comm +{ + +template +Comm_Keys_32::Comm_Keys_32( + const MPI_Comm &mpi_comm_in) + :mpi_comm(mpi_comm_in) +{ + MPI_CHECK( MPI_Comm_size( this->mpi_comm, &this->rank_size ) ); + MPI_CHECK( MPI_Comm_rank( this->mpi_comm, &this->rank_mine ) ); +} + +template +Comm_Keys_32_SenderTraversal::Comm_Keys_32_SenderTraversal( + const MPI_Comm &mpi_comm) + :Comm_Keys_32(mpi_comm) +{ + this->traverse_keys_provide = + [](const Tkeys_provide &keys_provide, std::function &func) + { throw std::logic_error("Function traverse not set."); }; +} + +template +Comm_Keys_32_SenderJudge::Comm_Keys_32_SenderJudge( + const MPI_Comm &mpi_comm) + :Comm_Keys_32(mpi_comm) +{ + this->traverse_keys_all = + [](std::function &func) + { throw std::logic_error("Function traverse not set."); }; +} + + +template +std::vector> Comm_Keys_32::trans( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require_mine) +{ + std::vector> keys_trans_list(this->rank_size); + + std::vector sss_size; + std::vector sss_displs; + std::vector buffer_recv; + this->send_keys_require_mine( keys_require_mine, + sss_size, sss_displs, buffer_recv); + + std::vector keys_provide_mine_vec = change_keys_provide_mine(keys_provide_mine); + + #pragma omp parallel for + for(int rank_require=0; rank_requirerank_size; ++rank_require) + { + Tkeys_require keys_require; + std::stringstream ss_recv; + ss_recv.rdbuf()->pubsetbuf(buffer_recv.data()+sss_displs[rank_require], sss_size[rank_require]); + { + cereal::BinaryInputArchive ar(ss_recv); + ar(keys_require); + } + + this->intersection(keys_provide_mine_vec, keys_require, keys_trans_list[rank_require]); + } + + return keys_trans_list; +} + + +template +void Comm_Keys_32::send_keys_require_mine( + const Tkeys_require &keys_require_mine, + std::vector &sss_size, + std::vector &sss_displs, + std::vector &buffer_recv) +{ + std::stringstream ss_send; + { + cereal::BinaryOutputArchive ar(ss_send); + ar(keys_require_mine); + } + + sss_size.resize(this->rank_size); + const int ss_size = ss_send.str().size(); + MPI_Allgather( &ss_size, 1, MPI_INT, sss_size.data(), 1, MPI_INT, this->mpi_comm ); + + sss_displs.resize(this->rank_size); + sss_displs[0] = 0; + for(int i=1; irank_size; ++i) + sss_displs[i] = sss_displs[i-1] + sss_size[i-1]; + + buffer_recv.resize(sss_displs.back() + sss_size.back()); + MPI_Allgatherv( ss_send.str().c_str(), ss_send.str().size(), MPI_CHAR, buffer_recv.data(), sss_size.data(), sss_displs.data(), MPI_CHAR, this->mpi_comm ); +} + + +template +void Comm_Keys_32::intersection( + std::vector &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans) +{ + std::vector keys_provide_mine_new; + this->lock_provide.lock(); + for(const Tkey &key : keys_provide_mine) + if(keys_require.judge(key)) + keys_trans.push_back(key); + else + keys_provide_mine_new.push_back(key); + this->lock_provide.unlock(); + + this->lock_provide.lock_shared(); + keys_provide_mine = keys_provide_mine_new; + this->lock_provide.unlock(); +} + + +template +std::vector Comm_Keys_32_SenderTraversal::change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine) +{ + std::vector keys_provide_mine_vec; + std::function change = [&](const Tkey &key) + { + keys_provide_mine_vec.push_back(key); + }; + this->traverse_keys_provide( keys_provide_mine, change ); + return keys_provide_mine_vec; +} + +template +std::vector Comm_Keys_32_SenderJudge::change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine) +{ + std::vector keys_provide_mine_vec; + std::function change = [&](const Tkey &key) + { + if(keys_provide_mine.judge(key)) + keys_provide_mine_vec.push_back(key); + }; + this->traverse_keys_all( change ); + return keys_provide_mine_vec; +} + +} + +#undef MPI_CHECK diff --git a/include/Comm/Comm_Keys/Comm_Keys_32-sr.h b/include/Comm/Comm_Keys/Comm_Keys_32-sr.h index 26e40ac..54960d3 100644 --- a/include/Comm/Comm_Keys/Comm_Keys_32-sr.h +++ b/include/Comm/Comm_Keys/Comm_Keys_32-sr.h @@ -1,88 +1,88 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include "../global/Cereal_Func.h" - -#include -#include -#include -#include - -namespace Comm -{ - -template -class Comm_Keys_32 -{ -public: - Comm_Keys_32(const MPI_Comm &mpi_comm_in); - - std::vector> trans( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require_mine); - -protected: - void send_keys_require_mine( - const Tkeys_require &keys_require_mine); - - void recv_require_intersection( - std::vector &keys_provide_mine, - std::vector> &keys_trans_list); - - void intersection( - std::vector &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans); - - virtual std::vector change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine)=0; - - MPI_Comm mpi_comm; - int rank_mine; - int rank_size; - - const int tag_keys = 88; - Comm::Cereal_Func cereal_func; - std::shared_mutex lock_provide; -}; - -template -class Comm_Keys_32_SenderTraversal: public Comm_Keys_32 -{ -public: - Comm_Keys_32_SenderTraversal(const MPI_Comm &mpi_comm); - - std::function< - void( - const Tkeys_provide &keys_provide_mine, - std::function &func )> - traverse_keys_provide; - -private: - std::vector change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine); -}; - -template -class Comm_Keys_32_SenderJudge: public Comm_Keys_32 -{ -public: - Comm_Keys_32_SenderJudge(const MPI_Comm &mpi_comm); - - std::function< - void( - std::function &func )> - traverse_keys_all; - -private: - std::vector change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine); -}; - -} - -#include "Comm_Keys_32-sr.hpp" \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include "../global/Cereal_Func.h" + +#include +#include +#include +#include + +namespace Comm +{ + +template +class Comm_Keys_32 +{ +public: + Comm_Keys_32(const MPI_Comm &mpi_comm_in); + + std::vector> trans( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require_mine); + +protected: + void send_keys_require_mine( + const Tkeys_require &keys_require_mine); + + void recv_require_intersection( + std::vector &keys_provide_mine, + std::vector> &keys_trans_list); + + void intersection( + std::vector &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans); + + virtual std::vector change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine)=0; + + MPI_Comm mpi_comm; + int rank_mine; + int rank_size; + + const int tag_keys = 88; + Comm::Cereal_Func cereal_func; + std::shared_mutex lock_provide; +}; + +template +class Comm_Keys_32_SenderTraversal: public Comm_Keys_32 +{ +public: + Comm_Keys_32_SenderTraversal(const MPI_Comm &mpi_comm); + + std::function< + void( + const Tkeys_provide &keys_provide_mine, + std::function &func )> + traverse_keys_provide; + +private: + std::vector change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine); +}; + +template +class Comm_Keys_32_SenderJudge: public Comm_Keys_32 +{ +public: + Comm_Keys_32_SenderJudge(const MPI_Comm &mpi_comm); + + std::function< + void( + std::function &func )> + traverse_keys_all; + +private: + std::vector change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine); +}; + +} + +#include "Comm_Keys_32-sr.hpp" diff --git a/include/Comm/Comm_Keys/Comm_Keys_32-sr.hpp b/include/Comm/Comm_Keys/Comm_Keys_32-sr.hpp index 006899a..a1bd9ca 100644 --- a/include/Comm/Comm_Keys/Comm_Keys_32-sr.hpp +++ b/include/Comm/Comm_Keys/Comm_Keys_32-sr.hpp @@ -1,200 +1,201 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include "Comm_Keys_32-sr.h" -#include "../global/Cereal_Func.h" - -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Comm -{ - -template -Comm_Keys_32::Comm_Keys_32( - const MPI_Comm &mpi_comm_in) - :mpi_comm(mpi_comm_in) -{ - MPI_CHECK( MPI_Comm_size( this->mpi_comm, &this->rank_size ) ); - MPI_CHECK( MPI_Comm_rank( this->mpi_comm, &this->rank_mine ) ); -} - -template -Comm_Keys_32_SenderTraversal::Comm_Keys_32_SenderTraversal( - const MPI_Comm &mpi_comm) - :Comm_Keys_32(mpi_comm) -{ - this->traverse_keys_provide = - [](const Tkeys_provide &keys_provide, std::function &func) - { throw std::logic_error("Function traverse not set."); }; -} - -template -Comm_Keys_32_SenderJudge::Comm_Keys_32_SenderJudge( - const MPI_Comm &mpi_comm) - :Comm_Keys_32(mpi_comm) -{ - this->traverse_keys_all = - [](std::function &func) - { throw std::logic_error("Function traverse not set."); }; -} - - -template -std::vector> Comm_Keys_32::trans( - const Tkeys_provide &keys_provide_mine, - const Tkeys_require &keys_require_mine) -{ - std::vector> keys_trans_list(this->rank_size); - int unfinish_rank = this->rank_size; - std::vector threads; - threads.reserve(this->rank_size+2); - MPI_Status status_working; - status_working.MPI_SOURCE=-1; - - threads.emplace_back( - &Comm_Keys_32::send_keys_require_mine, this, - std::cref(keys_require_mine) ); - - std::vector keys_provide_mine_vec = change_keys_provide_mine(keys_provide_mine); - - threads.emplace_back( - &Comm_Keys_32::intersection, this, - std::ref(keys_provide_mine_vec), - std::cref(keys_require_mine), - std::ref(keys_trans_list[rank_mine])); - --unfinish_rank; - - while(unfinish_rank) - { - int flag_iprobe = false; - MPI_Status status; - MPI_CHECK( MPI_Iprobe( MPI_ANY_SOURCE, MPI_ANY_TAG, this->mpi_comm, &flag_iprobe, &status ) ); - if(flag_iprobe - && (status_working.MPI_SOURCE!=status.MPI_SOURCE)) - { - threads.emplace_back( - &Comm_Keys_32::recv_require_intersection, this, - std::ref(keys_provide_mine_vec), std::ref(keys_trans_list) ); - --unfinish_rank; - status_working = status; - } - else - { - std::this_thread::yield(); - } - } - - for(std::thread &t : threads) - t.join(); - - return keys_trans_list; -} - - -template -void Comm_Keys_32::send_keys_require_mine( - const Tkeys_require &keys_require_mine) -{ - std::stringstream ss_isend; - { - cereal::BinaryOutputArchive ar(ss_isend); - ar(keys_require_mine); - } - const std::size_t exponent_align = this->cereal_func.align_stringstream(ss_isend); - const std::string str_isend = ss_isend.str(); - - std::vector requests_isend(this->rank_size); - for(int rank_recv_tmp=1; rank_recv_tmprank_size; ++rank_recv_tmp) - { - const int rank_recv = (this->rank_mine + rank_recv_tmp) % this->rank_size; - this->cereal_func.mpi_isend(str_isend, exponent_align, rank_recv, this->tag_keys, this->mpi_comm, requests_isend[rank_recv]); - std::this_thread::yield(); - } - - for(int rank_recv_tmp=1; rank_recv_tmprank_size; ++rank_recv_tmp) - { - const int rank_recv = (this->rank_mine + rank_recv_tmp) % this->rank_size; - while(true) - { - int flag_finish = false; - MPI_CHECK( MPI_Test(&requests_isend[rank_recv], &flag_finish, MPI_STATUS_IGNORE) ); - if(flag_finish) break; - std::this_thread::yield(); - } - } -} - - -template -void Comm_Keys_32::recv_require_intersection( - std::vector &keys_provide_mine, - std::vector> &keys_trans_list) -{ - Tkeys_require keys_require; - const MPI_Status status_recv = this->cereal_func.mpi_recv( this->mpi_comm, - keys_require); - const int rank_require = status_recv.MPI_SOURCE; - assert(this->tag_keys==status_recv.MPI_TAG); - - this->intersection(keys_provide_mine, keys_require, keys_trans_list[rank_require]); -} - - -template -void Comm_Keys_32::intersection( - std::vector &keys_provide_mine, - const Tkeys_require &keys_require, - std::vector &keys_trans) -{ - std::vector keys_provide_mine_new; - this->lock_provide.lock(); - for(const Tkey &key : keys_provide_mine) - if(keys_require.judge(key)) - keys_trans.push_back(key); - else - keys_provide_mine_new.push_back(key); - this->lock_provide.unlock(); - - this->lock_provide.lock_shared(); - keys_provide_mine = keys_provide_mine_new; - this->lock_provide.unlock(); -} - - -template -std::vector Comm_Keys_32_SenderTraversal::change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine) -{ - std::vector keys_provide_mine_vec; - std::function change = [&](const Tkey &key) - { - keys_provide_mine_vec.push_back(key); - }; - this->traverse_keys_provide( keys_provide_mine, change ); - return keys_provide_mine_vec; -} - -template -std::vector Comm_Keys_32_SenderJudge::change_keys_provide_mine( - const Tkeys_provide &keys_provide_mine) -{ - std::vector keys_provide_mine_vec; - std::function change = [&](const Tkey &key) - { - if(keys_provide_mine.judge(key)) - keys_provide_mine_vec.push_back(key); - }; - this->traverse_keys_all( change ); - return keys_provide_mine_vec; -} - -} - -#undef MPI_CHECK \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include "Comm_Keys_32-sr.h" +#include "../global/Cereal_Func.h" + +#include +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Comm +{ + +template +Comm_Keys_32::Comm_Keys_32( + const MPI_Comm &mpi_comm_in) + :mpi_comm(mpi_comm_in) +{ + MPI_CHECK( MPI_Comm_size( this->mpi_comm, &this->rank_size ) ); + MPI_CHECK( MPI_Comm_rank( this->mpi_comm, &this->rank_mine ) ); +} + +template +Comm_Keys_32_SenderTraversal::Comm_Keys_32_SenderTraversal( + const MPI_Comm &mpi_comm) + :Comm_Keys_32(mpi_comm) +{ + this->traverse_keys_provide = + [](const Tkeys_provide &keys_provide, std::function &func) + { throw std::logic_error("Function traverse not set."); }; +} + +template +Comm_Keys_32_SenderJudge::Comm_Keys_32_SenderJudge( + const MPI_Comm &mpi_comm) + :Comm_Keys_32(mpi_comm) +{ + this->traverse_keys_all = + [](std::function &func) + { throw std::logic_error("Function traverse not set."); }; +} + + +template +std::vector> Comm_Keys_32::trans( + const Tkeys_provide &keys_provide_mine, + const Tkeys_require &keys_require_mine) +{ + std::vector> keys_trans_list(this->rank_size); + int unfinish_rank = this->rank_size; + std::vector threads; + threads.reserve(this->rank_size+2); + MPI_Status status_working; + status_working.MPI_SOURCE=-1; + + threads.emplace_back( + &Comm_Keys_32::send_keys_require_mine, this, + std::cref(keys_require_mine) ); + + std::vector keys_provide_mine_vec = change_keys_provide_mine(keys_provide_mine); + + threads.emplace_back( + &Comm_Keys_32::intersection, this, + std::ref(keys_provide_mine_vec), + std::cref(keys_require_mine), + std::ref(keys_trans_list[rank_mine])); + --unfinish_rank; + + while(unfinish_rank) + { + int flag_iprobe = false; + MPI_Status status; + MPI_CHECK( MPI_Iprobe( MPI_ANY_SOURCE, MPI_ANY_TAG, this->mpi_comm, &flag_iprobe, &status ) ); + if(flag_iprobe + && (status_working.MPI_SOURCE!=status.MPI_SOURCE)) + { + threads.emplace_back( + &Comm_Keys_32::recv_require_intersection, this, + std::ref(keys_provide_mine_vec), std::ref(keys_trans_list) ); + --unfinish_rank; + status_working = status; + } + else + { + std::this_thread::yield(); + } + } + + for(std::thread &t : threads) + t.join(); + + return keys_trans_list; +} + + +template +void Comm_Keys_32::send_keys_require_mine( + const Tkeys_require &keys_require_mine) +{ + std::stringstream ss_isend; + { + cereal::BinaryOutputArchive ar(ss_isend); + ar(keys_require_mine); + } + const std::size_t exponent_align = this->cereal_func.align_stringstream(ss_isend); + const std::string str_isend = ss_isend.str(); + + std::vector requests_isend(this->rank_size); + for(int rank_recv_tmp=1; rank_recv_tmprank_size; ++rank_recv_tmp) + { + const int rank_recv = (this->rank_mine + rank_recv_tmp) % this->rank_size; + this->cereal_func.mpi_isend(str_isend, exponent_align, rank_recv, this->tag_keys, this->mpi_comm, requests_isend[rank_recv]); + std::this_thread::yield(); + } + + for(int rank_recv_tmp=1; rank_recv_tmprank_size; ++rank_recv_tmp) + { + const int rank_recv = (this->rank_mine + rank_recv_tmp) % this->rank_size; + while(true) + { + int flag_finish = false; + MPI_CHECK( MPI_Test(&requests_isend[rank_recv], &flag_finish, MPI_STATUS_IGNORE) ); + if(flag_finish) break; + std::this_thread::yield(); + } + } +} + + +template +void Comm_Keys_32::recv_require_intersection( + std::vector &keys_provide_mine, + std::vector> &keys_trans_list) +{ + Tkeys_require keys_require; + const MPI_Status status_recv = this->cereal_func.mpi_recv( this->mpi_comm, + keys_require); + const int rank_require = status_recv.MPI_SOURCE; + assert(this->tag_keys==status_recv.MPI_TAG); + + this->intersection(keys_provide_mine, keys_require, keys_trans_list[rank_require]); +} + + +template +void Comm_Keys_32::intersection( + std::vector &keys_provide_mine, + const Tkeys_require &keys_require, + std::vector &keys_trans) +{ + std::vector keys_provide_mine_new; + this->lock_provide.lock(); + for(const Tkey &key : keys_provide_mine) + if(keys_require.judge(key)) + keys_trans.push_back(key); + else + keys_provide_mine_new.push_back(key); + this->lock_provide.unlock(); + + this->lock_provide.lock_shared(); + keys_provide_mine = keys_provide_mine_new; + this->lock_provide.unlock(); +} + + +template +std::vector Comm_Keys_32_SenderTraversal::change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine) +{ + std::vector keys_provide_mine_vec; + std::function change = [&](const Tkey &key) + { + keys_provide_mine_vec.push_back(key); + }; + this->traverse_keys_provide( keys_provide_mine, change ); + return keys_provide_mine_vec; +} + +template +std::vector Comm_Keys_32_SenderJudge::change_keys_provide_mine( + const Tkeys_provide &keys_provide_mine) +{ + std::vector keys_provide_mine_vec; + std::function change = [&](const Tkey &key) + { + if(keys_provide_mine.judge(key)) + keys_provide_mine_vec.push_back(key); + }; + this->traverse_keys_all( change ); + return keys_provide_mine_vec; +} + +} + +#undef MPI_CHECK diff --git a/include/Comm/Comm_Tools.h b/include/Comm/Comm_Tools.h index 8b2791e..00200ba 100644 --- a/include/Comm/Comm_Tools.h +++ b/include/Comm/Comm_Tools.h @@ -1,16 +1,16 @@ -// =================== -// Author: Peize Lin -// date: 2022.07.06 -// =================== - -#pragma once - -namespace Comm -{ - -namespace Comm_Tools -{ - enum class Lock_Type {Lock_free, Lock_item, Lock_Process, Copy_merge}; -} - -} \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.07.06 +// =================== + +#pragma once + +namespace Comm +{ + +namespace Comm_Tools +{ + enum class Lock_Type {Lock_free, Lock_item, Lock_Process, Copy_merge}; +} + +} diff --git a/include/Comm/Comm_Trans/Comm_Trans.h b/include/Comm/Comm_Trans/Comm_Trans.h index 29090a6..f10a6a2 100644 --- a/include/Comm/Comm_Trans/Comm_Trans.h +++ b/include/Comm/Comm_Trans/Comm_Trans.h @@ -1,75 +1,76 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-01-05 -//======================= - -#pragma once - -#include "../Comm_Tools.h" -#include "../global/Global_Func.h" -#include -#include -#include -#include - -namespace Comm -{ - -template -class Comm_Trans -{ -public: - std::function< - void( - const Tdatas_isend &datas_isend, - const int rank_isend, - std::function &func)> - traverse_isend; - std::function< - void( - Tkey &&key, - Tvalue &&value, - Tdatas_recv &datas_recv)> - set_value_recv; - - Comm_Tools::Lock_Type flag_lock_set_value; - std::function< - Tdatas_recv( - const int rank_recv)> - init_datas_local; - std::function< - void( - Tdatas_recv &&datas_local, - Tdatas_recv &datas_recv)> - add_datas; - -public: - Comm_Trans(const MPI_Comm &mpi_comm_in); -// Comm_Trans(const Comm_Trans &com); - void communicate( - const Tdatas_isend &datas_isend, - Tdatas_recv &datas_recv); - -private: - void isend_data (const int rank_isend, const Tdatas_isend &datas_isend, std::string &str_isend, MPI_Request &request_isend, std::atomic &memory_max_isend); - void recv_data (Tdatas_recv &datas_recv, const MPI_Status status_recv, MPI_Message message_recv, std::atomic_flag &lock_set_value, std::atomic &memory_max_isend); - void post_process( - std::vector &requests_isend, - std::vector &strs_isend, - std::vector> &futures_isend, - std::vector> &futures_recv) const; - bool memory_enough(const std::atomic &memory_max) const { return Global_Func::memory_available() > memory_max.load() * 2; } - -public: - const MPI_Comm &mpi_comm; - int rank_mine = 0; - int comm_size = 1; - -private: - const int tag_data = 0; - Comm::Cereal_Func cereal_func; -}; - -} - -#include "Comm_Trans.hpp" \ No newline at end of file +//======================= +// AUTHOR : Peize Lin +// DATE : 2022-01-05 +//======================= + +#pragma once + +#include "../Comm_Tools.h" +#include "../global/Cereal_Func.h" +#include "../global/Global_Func.h" +#include +#include +#include +#include + +namespace Comm +{ + +template +class Comm_Trans +{ +public: + std::function< + void( + const Tdatas_isend &datas_isend, + const int rank_isend, + std::function &func)> + traverse_isend; + std::function< + void( + Tkey &&key, + Tvalue &&value, + Tdatas_recv &datas_recv)> + set_value_recv; + + Comm_Tools::Lock_Type flag_lock_set_value; + std::function< + Tdatas_recv( + const int rank_recv)> + init_datas_local; + std::function< + void( + Tdatas_recv &&datas_local, + Tdatas_recv &datas_recv)> + add_datas; + +public: + Comm_Trans(const MPI_Comm &mpi_comm_in); +// Comm_Trans(const Comm_Trans &com); + void communicate( + const Tdatas_isend &datas_isend, + Tdatas_recv &datas_recv); + +private: + void isend_data (const int rank_isend, const Tdatas_isend &datas_isend, std::string &str_isend, MPI_Request &request_isend, std::atomic &memory_max_isend); + void recv_data (Tdatas_recv &datas_recv, const MPI_Status status_recv, MPI_Message message_recv, std::atomic_flag &lock_set_value, std::atomic &memory_max_isend); + void post_process( + std::vector &requests_isend, + std::vector &strs_isend, + std::vector> &futures_isend, + std::vector> &futures_recv) const; + bool memory_enough(const std::atomic &memory_max) const { return Global_Func::memory_available() > memory_max.load() * 2; } + +public: + const MPI_Comm &mpi_comm; + int rank_mine = 0; + int comm_size = 1; + +private: + const int tag_data = 0; + Comm::Cereal_Func cereal_func; +}; + +} + +#include "Comm_Trans.hpp" diff --git a/include/Comm/Comm_Trans/Comm_Trans.hpp b/include/Comm/Comm_Trans/Comm_Trans.hpp index 0bbd141..af1ceb3 100644 --- a/include/Comm/Comm_Trans/Comm_Trans.hpp +++ b/include/Comm/Comm_Trans/Comm_Trans.hpp @@ -1,287 +1,286 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-01-05 -//======================= - -#pragma once - -#include "Comm_Trans.h" -#include "../global/Cereal_Func.h" - -#include -#include -#include - -#include -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Comm -{ - -template -Comm_Trans::Comm_Trans(const MPI_Comm &mpi_comm_in) - :mpi_comm(mpi_comm_in) -{ - MPI_CHECK (MPI_Comm_size (this->mpi_comm, &this->comm_size)); - MPI_CHECK (MPI_Comm_rank (this->mpi_comm, &this->rank_mine)); - - this->set_value_recv - = [](Tkey &&key, Tvalue &&value, Tdatas_recv &datas_recv) - { throw std::logic_error("Function set_value not set."); }; - this->traverse_isend - = [](const Tdatas_isend &datas_isend, const int rank_isend, std::function &func) - { throw std::logic_error("Function traverse not set."); }; - this->init_datas_local - = [](const int rank_recv) -> Tdatas_recv - { throw std::logic_error("Function init_datas_local not set."); }; - this->add_datas - = [](Tdatas_recv &&datas_local, Tdatas_recv &datas_recv) - { throw std::logic_error("Function add_datas not set."); }; -} - -/* -template -Comm_Trans::Comm_Trans(const Comm_Trans &com) - :mpi_comm(com.mpi_comm) -{ - //ofs<<"C"<<" "; - MPI_CHECK (MPI_Comm_size (this->mpi_comm, &this->comm_size)); - MPI_CHECK (MPI_Comm_rank (this->mpi_comm, &this->rank_mine)); - this->set_value_recv = com.set_value_recv; - this->traverse_isend = com.traverse_isend; - this->flag_lock_set_value = com.flag_lock_set_value; - this->init_datas_local = com.init_datas_local; - this->add_datas = com.add_datas; -} -*/ - -template -void Comm_Trans::communicate( - const Tdatas_isend &datas_isend, - Tdatas_recv &datas_recv) -{ - // initialization - int rank_isend_tmp = 0; - int rank_recv_working = -1; - - std::vector requests_isend(comm_size); - std::vector strs_isend(comm_size); - std::vector> futures_isend(comm_size); - std::vector> futures_recv(comm_size); - std::atomic_flag lock_set_value = ATOMIC_FLAG_INIT; - std::atomic memory_max_isend(0); - std::atomic memory_max_recv(0); - - std::future future_post_process = std::async (std::launch::async, - &Comm_Trans::post_process, this, - std::ref(requests_isend), std::ref(strs_isend), std::ref(futures_isend), std::ref(futures_recv)); - - while (future_post_process.wait_for(std::chrono::seconds(0)) != std::future_status::ready) - { - int flag_iprobe=0; - MPI_Status status_recv; - MPI_Message message_recv; - MPI_CHECK (MPI_Improbe(MPI_ANY_SOURCE, this->tag_data, this->mpi_comm, &flag_iprobe, &message_recv, &status_recv)); - if (flag_iprobe && rank_recv_working!=status_recv.MPI_SOURCE && memory_enough(memory_max_recv)) - { - futures_recv[status_recv.MPI_SOURCE] = std::async (std::launch::async, - &Comm_Trans::recv_data, this, - std::ref(datas_recv), status_recv, message_recv, std::ref(lock_set_value), std::ref(memory_max_recv)); - rank_recv_working = status_recv.MPI_SOURCE; - } - - if (rank_isend_tmp < this->comm_size && memory_enough(memory_max_isend)) - { - const int rank_isend = (rank_isend_tmp + this->rank_mine) % this->comm_size; - futures_isend[rank_isend] = std::async (std::launch::async, - &Comm_Trans::isend_data, this, - rank_isend, std::cref(datas_isend), std::ref(strs_isend[rank_isend]), std::ref(requests_isend[rank_isend]), std::ref(memory_max_isend)); - ++rank_isend_tmp; - } - } - future_post_process.get(); -} - - -template -void Comm_Trans::isend_data( - const int rank_isend, - const Tdatas_isend &datas_isend, - std::string &str_isend, - MPI_Request &request_isend, - std::atomic &memory_max_isend) -{ - std::stringstream ss_isend; - { - cereal::BinaryOutputArchive oar(ss_isend); - - size_t size_item = 0; - oar(size_item); // 占位 - - std::function archive_data = [&oar, &size_item]( - const Tkey &key, const Tvalue &value) - { - oar(key, value); - ++size_item; - }; - this->traverse_isend(datas_isend, rank_isend, archive_data); - - ss_isend.rdbuf()->pubseekpos(0); // 返回size_item的占位,序列化真正的size_item值 - oar(size_item); - } // end cereal::BinaryOutputArchive - const std::size_t exponent_align = this->cereal_func.align_stringstream(ss_isend); - str_isend = ss_isend.str(); - memory_max_isend.store( std::max(str_isend.size()*sizeof(char), memory_max_isend.load()) ); - this->cereal_func.mpi_isend(str_isend, exponent_align, rank_isend, this->tag_data, this->mpi_comm, request_isend); -} - - - -template -void Comm_Trans::recv_data ( - Tdatas_recv &datas_recv, - const MPI_Status status_recv, - MPI_Message message_recv, - std::atomic_flag &lock_set_value, - std::atomic &memory_max_recv) -{ - std::vector buffer_recv = this->cereal_func.mpi_mrecv(message_recv, status_recv); - - std::stringstream ss_recv; - ss_recv.rdbuf()->pubsetbuf(buffer_recv.data(), buffer_recv.size()); - memory_max_recv.store( std::max(buffer_recv.size()*sizeof(char), memory_max_recv.load()) ); - - { - cereal::BinaryInputArchive iar(ss_recv); - size_t size_item; iar(size_item); - - if (this->flag_lock_set_value==Comm_Tools::Lock_Type::Lock_free) - { - for (size_t i=0; iset_value_recv(std::move(key), std::move(value), datas_recv); - } - } - else if (this->flag_lock_set_value==Comm_Tools::Lock_Type::Lock_item) - { - for (size_t i=0; iset_value_recv(std::move(key), std::move(value), datas_recv); - lock_set_value.clear(std::memory_order_seq_cst); - } - } - else if (this->flag_lock_set_value==Comm_Tools::Lock_Type::Lock_Process) - { - while (lock_set_value.test_and_set(std::memory_order_seq_cst)) std::this_thread::yield(); - for (size_t i=0; iset_value_recv(std::move(key), std::move(value), datas_recv); - } - lock_set_value.clear(std::memory_order_seq_cst); - } - else if (this->flag_lock_set_value==Comm_Tools::Lock_Type::Copy_merge) - { - Tdatas_recv datas_local = this->init_datas_local (status_recv.MPI_SOURCE); - for (size_t i=0; iset_value_recv (std::move(key), std::move(value), datas_local); - } - while (lock_set_value.test_and_set(std::memory_order_seq_cst)) std::this_thread::yield(); - this->add_datas (std::move(datas_local), datas_recv); - lock_set_value.clear(std::memory_order_seq_cst); - } - else - { - throw std::invalid_argument( - +" file "+std::string(__FILE__) - +" line "+std::to_string(__LINE__) - +" rank_mine "+std::to_string(this->rank_mine) - +" rank_recv "+std::to_string(status_recv.MPI_SOURCE)); - } - } // end cereal::BinaryInputArchive -} - - -template -void Comm_Trans::post_process( - std::vector &requests_isend, - std::vector &strs_isend, - std::vector> &futures_isend, - std::vector> &futures_recv) const -{ - int rank_isend_free_tmp = 0; - int rank_recv_free_tmp = 0; - while (rank_isend_free_tmp < this->comm_size - || rank_recv_free_tmp < this->comm_size) - { - while (rank_isend_free_tmp < this->comm_size) - { - const int rank_isend_free = (this->rank_mine+rank_isend_free_tmp)%this->comm_size; - if (futures_isend[rank_isend_free].valid() - && futures_isend[rank_isend_free].wait_for(std::chrono::seconds(0)) == std::future_status::ready) - { - int flag_finish=0; - MPI_CHECK (MPI_Test (&(requests_isend[rank_isend_free]), &flag_finish, MPI_STATUS_IGNORE)); - if (flag_finish) - { - // MPI_CHECK (MPI_Request_free (&requests_isend[rank_isend_free])); - futures_isend[rank_isend_free].get(); - strs_isend[rank_isend_free].clear(); - ++rank_isend_free_tmp; - } - else{ break; } - } - else{ break; } - } - - while (rank_recv_free_tmp < this->comm_size) - { - const int rank_recv_free = (this->rank_mine+rank_recv_free_tmp)%this->comm_size; - if (futures_recv[rank_recv_free].valid() - && futures_recv[rank_recv_free].wait_for(std::chrono::seconds(0)) == std::future_status::ready) - { - futures_recv[rank_recv_free].get(); - ++rank_recv_free_tmp; - } - else{ break; } - } - - std::this_thread::yield(); - } -} - -} - -#undef MPI_CHECK - -/* -get_send_keys() -{ - - if(unique) - { - for(irank in all) - send(irank_send, atom_pairs_remove); - } -} -*/ \ No newline at end of file +//======================= +// AUTHOR : Peize Lin +// DATE : 2022-01-05 +//======================= + +#pragma once + +#include "Comm_Trans.h" + +#include +#include +#include + +#include +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Comm +{ + +template +Comm_Trans::Comm_Trans(const MPI_Comm &mpi_comm_in) + :mpi_comm(mpi_comm_in) +{ + MPI_CHECK (MPI_Comm_size (this->mpi_comm, &this->comm_size)); + MPI_CHECK (MPI_Comm_rank (this->mpi_comm, &this->rank_mine)); + + this->set_value_recv + = [](Tkey &&key, Tvalue &&value, Tdatas_recv &datas_recv) + { throw std::logic_error("Function set_value not set."); }; + this->traverse_isend + = [](const Tdatas_isend &datas_isend, const int rank_isend, std::function &func) + { throw std::logic_error("Function traverse not set."); }; + this->init_datas_local + = [](const int rank_recv) -> Tdatas_recv + { throw std::logic_error("Function init_datas_local not set."); }; + this->add_datas + = [](Tdatas_recv &&datas_local, Tdatas_recv &datas_recv) + { throw std::logic_error("Function add_datas not set."); }; +} + +/* +template +Comm_Trans::Comm_Trans(const Comm_Trans &com) + :mpi_comm(com.mpi_comm) +{ + //ofs<<"C"<<" "; + MPI_CHECK (MPI_Comm_size (this->mpi_comm, &this->comm_size)); + MPI_CHECK (MPI_Comm_rank (this->mpi_comm, &this->rank_mine)); + this->set_value_recv = com.set_value_recv; + this->traverse_isend = com.traverse_isend; + this->flag_lock_set_value = com.flag_lock_set_value; + this->init_datas_local = com.init_datas_local; + this->add_datas = com.add_datas; +} +*/ + +template +void Comm_Trans::communicate( + const Tdatas_isend &datas_isend, + Tdatas_recv &datas_recv) +{ + // initialization + int rank_isend_tmp = 0; + int rank_recv_working = -1; + + std::vector requests_isend(comm_size); + std::vector strs_isend(comm_size); + std::vector> futures_isend(comm_size); + std::vector> futures_recv(comm_size); + std::atomic_flag lock_set_value = ATOMIC_FLAG_INIT; + std::atomic memory_max_isend(0); + std::atomic memory_max_recv(0); + + std::future future_post_process = std::async (std::launch::async, + &Comm_Trans::post_process, this, + std::ref(requests_isend), std::ref(strs_isend), std::ref(futures_isend), std::ref(futures_recv)); + + while (future_post_process.wait_for(std::chrono::seconds(0)) != std::future_status::ready) + { + int flag_iprobe=0; + MPI_Status status_recv; + MPI_Message message_recv; + MPI_CHECK (MPI_Improbe(MPI_ANY_SOURCE, this->tag_data, this->mpi_comm, &flag_iprobe, &message_recv, &status_recv)); + if (flag_iprobe && rank_recv_working!=status_recv.MPI_SOURCE && memory_enough(memory_max_recv)) + { + futures_recv[status_recv.MPI_SOURCE] = std::async (std::launch::async, + &Comm_Trans::recv_data, this, + std::ref(datas_recv), status_recv, message_recv, std::ref(lock_set_value), std::ref(memory_max_recv)); + rank_recv_working = status_recv.MPI_SOURCE; + } + + if (rank_isend_tmp < this->comm_size && memory_enough(memory_max_isend)) + { + const int rank_isend = (rank_isend_tmp + this->rank_mine) % this->comm_size; + futures_isend[rank_isend] = std::async (std::launch::async, + &Comm_Trans::isend_data, this, + rank_isend, std::cref(datas_isend), std::ref(strs_isend[rank_isend]), std::ref(requests_isend[rank_isend]), std::ref(memory_max_isend)); + ++rank_isend_tmp; + } + } + future_post_process.get(); +} + + +template +void Comm_Trans::isend_data( + const int rank_isend, + const Tdatas_isend &datas_isend, + std::string &str_isend, + MPI_Request &request_isend, + std::atomic &memory_max_isend) +{ + std::stringstream ss_isend; + { + cereal::BinaryOutputArchive oar(ss_isend); + + size_t size_item = 0; + oar(size_item); // 占位 + + std::function archive_data = [&oar, &size_item]( + const Tkey &key, const Tvalue &value) + { + oar(key, value); + ++size_item; + }; + this->traverse_isend(datas_isend, rank_isend, archive_data); + + ss_isend.rdbuf()->pubseekpos(0); // 返回size_item的占位,序列化真正的size_item值 + oar(size_item); + } // end cereal::BinaryOutputArchive + const std::size_t exponent_align = this->cereal_func.align_stringstream(ss_isend); + str_isend = ss_isend.str(); + memory_max_isend.store( std::max(str_isend.size()*sizeof(char), memory_max_isend.load()) ); + this->cereal_func.mpi_isend(str_isend, exponent_align, rank_isend, this->tag_data, this->mpi_comm, request_isend); +} + + + +template +void Comm_Trans::recv_data ( + Tdatas_recv &datas_recv, + const MPI_Status status_recv, + MPI_Message message_recv, + std::atomic_flag &lock_set_value, + std::atomic &memory_max_recv) +{ + std::vector buffer_recv = this->cereal_func.mpi_mrecv(message_recv, status_recv); + + std::stringstream ss_recv; + ss_recv.rdbuf()->pubsetbuf(buffer_recv.data(), buffer_recv.size()); + memory_max_recv.store( std::max(buffer_recv.size()*sizeof(char), memory_max_recv.load()) ); + + { + cereal::BinaryInputArchive iar(ss_recv); + size_t size_item; iar(size_item); + + if (this->flag_lock_set_value==Comm_Tools::Lock_Type::Lock_free) + { + for (size_t i=0; iset_value_recv(std::move(key), std::move(value), datas_recv); + } + } + else if (this->flag_lock_set_value==Comm_Tools::Lock_Type::Lock_item) + { + for (size_t i=0; iset_value_recv(std::move(key), std::move(value), datas_recv); + lock_set_value.clear(std::memory_order_seq_cst); + } + } + else if (this->flag_lock_set_value==Comm_Tools::Lock_Type::Lock_Process) + { + while (lock_set_value.test_and_set(std::memory_order_seq_cst)) std::this_thread::yield(); + for (size_t i=0; iset_value_recv(std::move(key), std::move(value), datas_recv); + } + lock_set_value.clear(std::memory_order_seq_cst); + } + else if (this->flag_lock_set_value==Comm_Tools::Lock_Type::Copy_merge) + { + Tdatas_recv datas_local = this->init_datas_local (status_recv.MPI_SOURCE); + for (size_t i=0; iset_value_recv (std::move(key), std::move(value), datas_local); + } + while (lock_set_value.test_and_set(std::memory_order_seq_cst)) std::this_thread::yield(); + this->add_datas (std::move(datas_local), datas_recv); + lock_set_value.clear(std::memory_order_seq_cst); + } + else + { + throw std::invalid_argument( + +" file "+std::string(__FILE__) + +" line "+std::to_string(__LINE__) + +" rank_mine "+std::to_string(this->rank_mine) + +" rank_recv "+std::to_string(status_recv.MPI_SOURCE)); + } + } // end cereal::BinaryInputArchive +} + + +template +void Comm_Trans::post_process( + std::vector &requests_isend, + std::vector &strs_isend, + std::vector> &futures_isend, + std::vector> &futures_recv) const +{ + int rank_isend_free_tmp = 0; + int rank_recv_free_tmp = 0; + while (rank_isend_free_tmp < this->comm_size + || rank_recv_free_tmp < this->comm_size) + { + while (rank_isend_free_tmp < this->comm_size) + { + const int rank_isend_free = (this->rank_mine+rank_isend_free_tmp)%this->comm_size; + if (futures_isend[rank_isend_free].valid() + && futures_isend[rank_isend_free].wait_for(std::chrono::seconds(0)) == std::future_status::ready) + { + int flag_finish=0; + MPI_CHECK (MPI_Test (&(requests_isend[rank_isend_free]), &flag_finish, MPI_STATUS_IGNORE)); + if (flag_finish) + { + // MPI_CHECK (MPI_Request_free (&requests_isend[rank_isend_free])); + futures_isend[rank_isend_free].get(); + strs_isend[rank_isend_free].clear(); + ++rank_isend_free_tmp; + } + else{ break; } + } + else{ break; } + } + + while (rank_recv_free_tmp < this->comm_size) + { + const int rank_recv_free = (this->rank_mine+rank_recv_free_tmp)%this->comm_size; + if (futures_recv[rank_recv_free].valid() + && futures_recv[rank_recv_free].wait_for(std::chrono::seconds(0)) == std::future_status::ready) + { + futures_recv[rank_recv_free].get(); + ++rank_recv_free_tmp; + } + else{ break; } + } + + std::this_thread::yield(); + } +} + +} + +#undef MPI_CHECK + +/* +get_send_keys() +{ + + if(unique) + { + for(irank in all) + send(irank_send, atom_pairs_remove); + } +} +*/ diff --git a/include/Comm/example/Communicate_Map-1.h b/include/Comm/example/Communicate_Map-1.h index 29ec62b..e95242d 100644 --- a/include/Comm/example/Communicate_Map-1.h +++ b/include/Comm/example/Communicate_Map-1.h @@ -1,342 +1,342 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-01-05 -//======================= - -#pragma once - -#include -#include -#include - -namespace Comm -{ - -namespace Communicate_Map -{ - /* - template - inline Tvalue &get_value( - const std::tuple &key) - { - return data[std::get<0>(key)][std::get<1>(key)]; - } - - Tvalue &get_value( - const std::tuple &key) - { - return data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)]; - } - */ - - // 等于 - /* - template - void set_value_assignment( - const std::tuple &key, - const Tvalue &value, - std::map> &data) - { - data[std::get<0>(key)][std::get<1>(key)] = value; - } - */ - - template - void set_value_assignment( - Tkey &&key, - Tvalue &&value, - std::map &data) - { - data[key] = std::move(value); - } - - template - void set_value_assignment( - std::tuple &&key, - Tvalue &&value, - std::map> &data) - { - data[std::get<0>(key)][std::get<1>(key)] = std::move(value); - } - - template - void set_value_assignment( - std::tuple &&key, - Tvalue &&value, - std::map>> &data) - { - data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)] = std::move(value); - } - - // 加 - /* - template - void set_value_add( - const std::tuple &key, - const Tvalue &value, - std::map> &data) - { - Tvalue &value_tmp = data[std::get<0>(key)][std::get<1>(key)]; - if (value_tmp.empty()) - value_tmp = value; - else - value_tmp = value_tmp + value; - } - */ - - template - void set_value_add( - Tkey &&key, - Tvalue &&value, - std::map &data) - { - auto ptr = data.find(key); - if(ptr==data.end()) - data[key] = std::move(value); - else - ptr->second = ptr->second + std::move(value); -// Tvalue &value_tmp = data[key]; -// if (!value_tmp) -// value_tmp = std::move(value); -// else -// value_tmp = value_tmp + std::move(value); - } - - template - void set_value_add( - std::tuple &&key, - Tvalue &&value, - std::map> &data) - { - set_value_add(std::move(std::get<1>(key)), std::move(value), data[std::get<0>(key)]); -// Tvalue &value_tmp = data[std::get<0>(key)][std::get<1>(key)]; -// if (!value_tmp) -// value_tmp = std::move(value); -// else -// value_tmp = value_tmp + std::move(value); - } - - template - void set_value_add( - std::tuple &&key, - Tvalue &&value, - std::map>> &data) - { - set_value_add(std::move(std::get<2>(key)), std::move(value), data[std::get<0>(key)][std::get<1>(key)]); -// Tvalue &value_tmp = data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)]; -// if (!value_tmp) -// value_tmp = std::move(value); -// else -// value_tmp = value_tmp + std::move(value); - } - - /* - // 等于 - template - set_value( - const std::tuple &key, - const Tvalue &value, - std::map>> &data) - { - data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)] = value; - } - - // 加 - set_value( - const std::tuple &key, - const Tvalue &value, - std::map>> &data) - { - Tensor<> &data_tmp = data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)]; - if (data_tmp.c) - data_tmp = data_tmp + value; - else - data_tmp = value; - } - */ - - // 无筛选,全部遍历 - template - void traverse_datas_all( - const std::map &data, - const int rank_isend, - std::function &func) - { - for (const auto &item : data) - func (item.first, item.second); - } - - template - void traverse_datas_all( - const std::map> &data, - const int rank_isend, - std::function&, const Tvalue&)> &func) - { - for (const auto &dataA : data) - for (const auto &dataB : dataA.second) - func (std::make_tuple(dataA.first,dataB.first), dataB.second); - } - - - - // 加横纵接收筛选。进程给出的label不做筛选。即,当set_value为=时需所有进程给出的label重复很少,或set_value为+。 - template - void traverse_datas_mask( - const std::map> &data, - const int rank_isend, - std::function&, const Tvalue&)> &func, - const std::function &mask0, - const std::function &mask1) - { - for (const auto &dataA : data) - { - const Tkey0 &key0 = dataA.first; - if (!mask0(rank_isend,key0)) continue; - for (const auto &dataB : dataA.second) - { - const Tkey1 &key1 = dataB.first; - if (!mask1(rank_isend,key1)) continue; - func (std::make_tuple(key0,key1), dataB.second); - } - } - } - - - template - std::map init_datas_local(const int rank_recv) - { - return std::map(); - } - - template - std::map> init_datas_local(const int rank_recv) - { - return std::map>(); - } - - template - std::map>> init_datas_local(const int rank_recv) - { - return std::map>>(); - } - - - template - void add_datas( - std::map &&data_local, - std::map &data_recv) - { - auto ptr_local=data_local.begin(); - auto ptr_recv=data_recv.begin(); - for(; ptr_local!=data_local.end() && ptr_recv!=data_recv.end(); ) - { - const Tkey &key_local = ptr_local->first; - const Tkey &key_recv = ptr_recv->first; - if(key_local == key_recv) - { - ptr_recv->second = ptr_recv->second + std::move(ptr_local->second); - ++ptr_local; - ++ptr_recv; - } - else if(key_local < key_recv) - { - ptr_recv = data_recv.emplace_hint(ptr_recv, key_local, std::move(ptr_local->second)); - ++ptr_local; - } - else - { - ++ptr_recv; - } - } - for(; ptr_local!=data_local.end(); ++ptr_local) - { - ptr_recv = data_recv.emplace_hint(ptr_recv, ptr_local->first, std::move(ptr_local->second)); - } - -// for (auto &&data_local_A : data_local) -// { -// auto ptr = data_recv.find(data_local_A.first); -// if(ptr==data_recv.end()) -// data_recv[data_local_A.first] = std::move(data_local_A.second); -// else -// ptr->second = ptr->second + std::move(data_local_A.second); -// } - -// for (auto &&data_local_A : data_local) -// { -// Tvalue &data = data_recv[std::move(data_local_A.first)]; -// if(!data) -// data = std::move(data_local_A.second); -// else -// data = data + std::move(data_local_A.second); -// } - } - - template - void add_datas( - std::map> &&data_local, - std::map> &data_recv) - { - auto ptr_local=data_local.begin(); - auto ptr_recv=data_recv.begin(); - for(; ptr_local!=data_local.end() && ptr_recv!=data_recv.end(); ) - { - const Tkey0 &key_local = ptr_local->first; - const Tkey0 &key_recv = ptr_recv->first; - if(key_local == key_recv) - { - add_datas(std::move(ptr_local->second), ptr_recv->second); - ++ptr_local; - ++ptr_recv; - } - else if(key_local < key_recv) - { - ptr_recv = data_recv.emplace_hint(ptr_recv, key_local, std::move(ptr_local->second)); - ++ptr_local; - } - else - { - ++ptr_recv; - } - } - for(; ptr_local!=data_local.end(); ++ptr_local) - { - ptr_recv = data_recv.emplace_hint(ptr_recv, ptr_local->first, std::move(ptr_local->second)); - } - -// for (auto &&data_local_A : data_local) -// { -// for (auto &&data_local_B : data_local_A.second) -// { -// Tvalue &data = data_recv[std::move(data_local_A.first)][std::move(data_local_B.first)]; -// if(!data) -// data = std::move(data_local_B.second); -// else -// data = data + std::move(data_local_B.second); -// } -// } - } - -// template -// void add_datas( -// std::map>> &&data_local, -// std::map>> &data_recv) -// { -// for (auto &&data_local_A : data_local) -// { -// for (auto &&data_local_B : data_local_A.second) -// { -// for (auto &&data_local_C : data_local_B.second) -// { -// Tvalue &data = data_recv[std::move(data_local_A.first)][std::move(data_local_B.first)][std::move(data_local_C.first)]; -// if(!data) -// data = std::move(data_local_C.second); -// else -// data = data + std::move(data_local_C.second); -// } -// } -// } -// } -} - -} \ No newline at end of file +//======================= +// AUTHOR : Peize Lin +// DATE : 2022-01-05 +//======================= + +#pragma once + +#include +#include +#include + +namespace Comm +{ + +namespace Communicate_Map +{ + /* + template + inline Tvalue &get_value( + const std::tuple &key) + { + return data[std::get<0>(key)][std::get<1>(key)]; + } + + Tvalue &get_value( + const std::tuple &key) + { + return data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)]; + } + */ + + // 等于 + /* + template + void set_value_assignment( + const std::tuple &key, + const Tvalue &value, + std::map> &data) + { + data[std::get<0>(key)][std::get<1>(key)] = value; + } + */ + + template + void set_value_assignment( + Tkey &&key, + Tvalue &&value, + std::map &data) + { + data[key] = std::move(value); + } + + template + void set_value_assignment( + std::tuple &&key, + Tvalue &&value, + std::map> &data) + { + data[std::get<0>(key)][std::get<1>(key)] = std::move(value); + } + + template + void set_value_assignment( + std::tuple &&key, + Tvalue &&value, + std::map>> &data) + { + data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)] = std::move(value); + } + + // 加 + /* + template + void set_value_add( + const std::tuple &key, + const Tvalue &value, + std::map> &data) + { + Tvalue &value_tmp = data[std::get<0>(key)][std::get<1>(key)]; + if (value_tmp.empty()) + value_tmp = value; + else + value_tmp = value_tmp + value; + } + */ + + template + void set_value_add( + Tkey &&key, + Tvalue &&value, + std::map &data) + { + auto ptr = data.find(key); + if(ptr==data.end()) + data[key] = std::move(value); + else + ptr->second = ptr->second + std::move(value); +// Tvalue &value_tmp = data[key]; +// if (!value_tmp) +// value_tmp = std::move(value); +// else +// value_tmp = value_tmp + std::move(value); + } + + template + void set_value_add( + std::tuple &&key, + Tvalue &&value, + std::map> &data) + { + set_value_add(std::move(std::get<1>(key)), std::move(value), data[std::get<0>(key)]); +// Tvalue &value_tmp = data[std::get<0>(key)][std::get<1>(key)]; +// if (!value_tmp) +// value_tmp = std::move(value); +// else +// value_tmp = value_tmp + std::move(value); + } + + template + void set_value_add( + std::tuple &&key, + Tvalue &&value, + std::map>> &data) + { + set_value_add(std::move(std::get<2>(key)), std::move(value), data[std::get<0>(key)][std::get<1>(key)]); +// Tvalue &value_tmp = data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)]; +// if (!value_tmp) +// value_tmp = std::move(value); +// else +// value_tmp = value_tmp + std::move(value); + } + + /* + // 等于 + template + set_value( + const std::tuple &key, + const Tvalue &value, + std::map>> &data) + { + data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)] = value; + } + + // 加 + set_value( + const std::tuple &key, + const Tvalue &value, + std::map>> &data) + { + Tensor<> &data_tmp = data[std::get<0>(key)][std::get<1>(key)][std::get<2>(key)]; + if (data_tmp.c) + data_tmp = data_tmp + value; + else + data_tmp = value; + } + */ + + // 无筛选,全部遍历 + template + void traverse_datas_all( + const std::map &data, + const int rank_isend, + std::function &func) + { + for (const auto &item : data) + func (item.first, item.second); + } + + template + void traverse_datas_all( + const std::map> &data, + const int rank_isend, + std::function&, const Tvalue&)> &func) + { + for (const auto &dataA : data) + for (const auto &dataB : dataA.second) + func (std::make_tuple(dataA.first,dataB.first), dataB.second); + } + + + + // 加横纵接收筛选。进程给出的label不做筛选。即,当set_value为=时需所有进程给出的label重复很少,或set_value为+。 + template + void traverse_datas_mask( + const std::map> &data, + const int rank_isend, + std::function&, const Tvalue&)> &func, + const std::function &mask0, + const std::function &mask1) + { + for (const auto &dataA : data) + { + const Tkey0 &key0 = dataA.first; + if (!mask0(rank_isend,key0)) continue; + for (const auto &dataB : dataA.second) + { + const Tkey1 &key1 = dataB.first; + if (!mask1(rank_isend,key1)) continue; + func (std::make_tuple(key0,key1), dataB.second); + } + } + } + + + template + std::map init_datas_local(const int rank_recv) + { + return std::map(); + } + + template + std::map> init_datas_local(const int rank_recv) + { + return std::map>(); + } + + template + std::map>> init_datas_local(const int rank_recv) + { + return std::map>>(); + } + + + template + void add_datas( + std::map &&data_local, + std::map &data_recv) + { + auto ptr_local=data_local.begin(); + auto ptr_recv=data_recv.begin(); + for(; ptr_local!=data_local.end() && ptr_recv!=data_recv.end(); ) + { + const Tkey &key_local = ptr_local->first; + const Tkey &key_recv = ptr_recv->first; + if(key_local == key_recv) + { + ptr_recv->second = ptr_recv->second + std::move(ptr_local->second); + ++ptr_local; + ++ptr_recv; + } + else if(key_local < key_recv) + { + ptr_recv = data_recv.emplace_hint(ptr_recv, key_local, std::move(ptr_local->second)); + ++ptr_local; + } + else + { + ++ptr_recv; + } + } + for(; ptr_local!=data_local.end(); ++ptr_local) + { + ptr_recv = data_recv.emplace_hint(ptr_recv, ptr_local->first, std::move(ptr_local->second)); + } + +// for (auto &&data_local_A : data_local) +// { +// auto ptr = data_recv.find(data_local_A.first); +// if(ptr==data_recv.end()) +// data_recv[data_local_A.first] = std::move(data_local_A.second); +// else +// ptr->second = ptr->second + std::move(data_local_A.second); +// } + +// for (auto &&data_local_A : data_local) +// { +// Tvalue &data = data_recv[std::move(data_local_A.first)]; +// if(!data) +// data = std::move(data_local_A.second); +// else +// data = data + std::move(data_local_A.second); +// } + } + + template + void add_datas( + std::map> &&data_local, + std::map> &data_recv) + { + auto ptr_local=data_local.begin(); + auto ptr_recv=data_recv.begin(); + for(; ptr_local!=data_local.end() && ptr_recv!=data_recv.end(); ) + { + const Tkey0 &key_local = ptr_local->first; + const Tkey0 &key_recv = ptr_recv->first; + if(key_local == key_recv) + { + add_datas(std::move(ptr_local->second), ptr_recv->second); + ++ptr_local; + ++ptr_recv; + } + else if(key_local < key_recv) + { + ptr_recv = data_recv.emplace_hint(ptr_recv, key_local, std::move(ptr_local->second)); + ++ptr_local; + } + else + { + ++ptr_recv; + } + } + for(; ptr_local!=data_local.end(); ++ptr_local) + { + ptr_recv = data_recv.emplace_hint(ptr_recv, ptr_local->first, std::move(ptr_local->second)); + } + +// for (auto &&data_local_A : data_local) +// { +// for (auto &&data_local_B : data_local_A.second) +// { +// Tvalue &data = data_recv[std::move(data_local_A.first)][std::move(data_local_B.first)]; +// if(!data) +// data = std::move(data_local_B.second); +// else +// data = data + std::move(data_local_B.second); +// } +// } + } + +// template +// void add_datas( +// std::map>> &&data_local, +// std::map>> &data_recv) +// { +// for (auto &&data_local_A : data_local) +// { +// for (auto &&data_local_B : data_local_A.second) +// { +// for (auto &&data_local_C : data_local_B.second) +// { +// Tvalue &data = data_recv[std::move(data_local_A.first)][std::move(data_local_B.first)][std::move(data_local_C.first)]; +// if(!data) +// data = std::move(data_local_C.second); +// else +// data = data + std::move(data_local_C.second); +// } +// } +// } +// } +} + +} diff --git a/include/Comm/example/Communicate_Map-2.h b/include/Comm/example/Communicate_Map-2.h index d01c4ce..a608a9c 100644 --- a/include/Comm/example/Communicate_Map-2.h +++ b/include/Comm/example/Communicate_Map-2.h @@ -1,111 +1,112 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-07-06 -//======================= - -#pragma once - -#include -#include -#include -#include - -namespace Comm -{ - -namespace Communicate_Map -{ - template - void traverse_keys( - const std::map &datas, - std::function &func) - { - for(const auto &data : datas) - func(data.first); - } - template - void traverse_keys( - const std::map> &datas, - std::function &)> &func) - { - for(const auto &data0 : datas) - for(const auto &data1 : data0.second) - func(std::make_tuple(std::cref(data0.first), std::cref(data1.first))); - } - template - void traverse_keys( - const std::map>> &datas, - std::function &)> &func) - { - for(const auto &data0 : datas) - for(const auto &data1 : data0.second) - for(const auto &data2 : data1.second) - func(std::make_tuple(std::cref(data0.first), std::cref(data1.first), std::cref(data2.first))); - } - - template - const Tvalue &get_value( - const Tkey &key, - const std::map &m) - { - return m.at(key); - } - template - const Tvalue &get_value( - const std::tuple &key, - const std::map> &m) - { - return m.at(std::get<0>(key)).at(std::get<1>(key)); - } - template - const Tvalue &get_value( - const std::tuple &key, - const std::map>> &m) - { - return m.at(std::get<0>(key)).at(std::get<1>(key)).at(std::get<2>(key)); - } - - template - class Judge_Map - { - public: - bool judge(const Tkey &key) const - { - return s.find(key)!=s.end(); - } - std::set s; - template void serialize( Archive & ar ){ ar(s); } - }; - - template - class Judge_Map2 - { - public: - bool judge(const std::tuple &key) const - { - return (s0.find(std::get<0>(key))!=s0.end()) - && (s1.find(std::get<1>(key))!=s1.end()); - } - std::set s0; - std::set s1; - template void serialize( Archive & ar ){ ar(s0); ar(s1); } - }; - - template - class Judge_Map3 - { - public: - bool judge(const std::tuple &key) const - { - return (s0.find(std::get<0>(key))!=s0.end()) - && (s1.find(std::get<1>(key))!=s1.end()) - && (s2.find(std::get<2>(key))!=s2.end()); - } - std::set s0; - std::set s1; - std::set s2; - template void serialize( Archive & ar ){ ar(s0); ar(s1); ar(s2); } - }; -} - -} \ No newline at end of file +//======================= +// AUTHOR : Peize Lin +// DATE : 2022-07-06 +//======================= + +#pragma once + +#include +#include +#include +#include +#include + +namespace Comm +{ + +namespace Communicate_Map +{ + template + void traverse_keys( + const std::map &datas, + std::function &func) + { + for(const auto &data : datas) + func(data.first); + } + template + void traverse_keys( + const std::map> &datas, + std::function &)> &func) + { + for(const auto &data0 : datas) + for(const auto &data1 : data0.second) + func(std::make_tuple(std::cref(data0.first), std::cref(data1.first))); + } + template + void traverse_keys( + const std::map>> &datas, + std::function &)> &func) + { + for(const auto &data0 : datas) + for(const auto &data1 : data0.second) + for(const auto &data2 : data1.second) + func(std::make_tuple(std::cref(data0.first), std::cref(data1.first), std::cref(data2.first))); + } + + template + const Tvalue &get_value( + const Tkey &key, + const std::map &m) + { + return m.at(key); + } + template + const Tvalue &get_value( + const std::tuple &key, + const std::map> &m) + { + return m.at(std::get<0>(key)).at(std::get<1>(key)); + } + template + const Tvalue &get_value( + const std::tuple &key, + const std::map>> &m) + { + return m.at(std::get<0>(key)).at(std::get<1>(key)).at(std::get<2>(key)); + } + + template + class Judge_Map + { + public: + bool judge(const Tkey &key) const + { + return s.find(key)!=s.end(); + } + std::set s; + template void serialize( Archive & ar ){ ar(s); } + }; + + template + class Judge_Map2 + { + public: + bool judge(const std::tuple &key) const + { + return (s0.find(std::get<0>(key))!=s0.end()) + && (s1.find(std::get<1>(key))!=s1.end()); + } + std::set s0; + std::set s1; + template void serialize( Archive & ar ){ ar(s0); ar(s1); } + }; + + template + class Judge_Map3 + { + public: + bool judge(const std::tuple &key) const + { + return (s0.find(std::get<0>(key))!=s0.end()) + && (s1.find(std::get<1>(key))!=s1.end()) + && (s2.find(std::get<2>(key))!=s2.end()); + } + std::set s0; + std::set s1; + std::set s2; + template void serialize( Archive & ar ){ ar(s0); ar(s1); ar(s2); } + }; +} + +} diff --git a/include/Comm/example/Communicate_Set.h b/include/Comm/example/Communicate_Set.h index 26396a7..337bad4 100644 --- a/include/Comm/example/Communicate_Set.h +++ b/include/Comm/example/Communicate_Set.h @@ -1,26 +1,27 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include - -namespace Comm -{ - -namespace Communicate_Set -{ - template - void traverse_keys( - const std::set &keys, - std::function &func) - { - for(const Tkey &key : keys) - func(key); - } - -} - -} \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include +#include + +namespace Comm +{ + +namespace Communicate_Set +{ + template + void traverse_keys( + const std::set &keys, + std::function &func) + { + for(const Tkey &key : keys) + func(key); + } + +} + +} diff --git a/include/Comm/example/Communicate_Vector.h b/include/Comm/example/Communicate_Vector.h index 8ed6f61..59b0afd 100644 --- a/include/Comm/example/Communicate_Vector.h +++ b/include/Comm/example/Communicate_Vector.h @@ -1,25 +1,26 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.22 -// =================== - -#pragma once - -#include - -namespace Comm -{ - -namespace Communicate_Vector -{ - template - void traverse_keys( - const std::vector &keys, - std::function &func) - { - for(const Tkey &key : keys) - func(key); - } -} - -} \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include +#include + +namespace Comm +{ + +namespace Communicate_Vector +{ + template + void traverse_keys( + const std::vector &keys, + std::function &func) + { + for(const Tkey &key : keys) + func(key); + } +} + +} diff --git a/include/Comm/global/Cereal_Func.h b/include/Comm/global/Cereal_Func.h index 192a162..7ec640f 100644 --- a/include/Comm/global/Cereal_Func.h +++ b/include/Comm/global/Cereal_Func.h @@ -1,61 +1,61 @@ -// =================== -// Author: Peize Lin -// date: 2022.05.01 -// =================== - -#pragma once - -#include "MPI_Wrapper.h" - -#include -#include -#include - -namespace Comm -{ - -class Cereal_Func -{ - public: - - inline Cereal_Func(); - - // every 2^exponent_align char concatenate to 1 word - inline std::size_t align_stringstream(std::stringstream &ss); - - // Send str - inline void mpi_send(const std::string &str, const std::size_t exponent_align, const int rank_recv, const int tag, const MPI_Comm &mpi_comm); - - // Send data - template - void mpi_send(const int rank_recv, const int tag, const MPI_Comm &mpi_comm, - const Ts&... data); - - // Isend str - inline void mpi_isend(const std::string &str, const std::size_t exponent_align, const int rank_recv, const int tag, const MPI_Comm &mpi_comm, MPI_Request &request); - - // Isend data using temporary memory str - template - void mpi_isend(const int rank_recv, const int tag, const MPI_Comm &mpi_comm, - std::string &str, MPI_Request &request, - const Ts&... data); - - // Recv to return - inline std::vector mpi_recv(const MPI_Comm &mpi_comm, MPI_Status &status); - - // Recv to data - template - MPI_Status mpi_recv(const MPI_Comm &mpi_comm, - Ts&... data); - - // Mrecv to return - inline std::vector mpi_mrecv(MPI_Message &message_recv, const MPI_Status &status); - - private: - - MPI_Wrapper::MPI_Type_Contiguous_Pool char_contiguous{MPI_CHAR}; -}; - -} - -#include "Cereal_Func.hpp" \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.05.01 +// =================== + +#pragma once + +#include "MPI_Wrapper.h" + +#include +#include +#include + +namespace Comm +{ + +class Cereal_Func +{ + public: + + inline Cereal_Func(); + + // every 2^exponent_align char concatenate to 1 word + inline std::size_t align_stringstream(std::stringstream &ss); + + // Send str + inline void mpi_send(const std::string &str, const std::size_t exponent_align, const int rank_recv, const int tag, const MPI_Comm &mpi_comm); + + // Send data + template + void mpi_send(const int rank_recv, const int tag, const MPI_Comm &mpi_comm, + const Ts&... data); + + // Isend str + inline void mpi_isend(const std::string &str, const std::size_t exponent_align, const int rank_recv, const int tag, const MPI_Comm &mpi_comm, MPI_Request &request); + + // Isend data using temporary memory str + template + void mpi_isend(const int rank_recv, const int tag, const MPI_Comm &mpi_comm, + std::string &str, MPI_Request &request, + const Ts&... data); + + // Recv to return + inline std::vector mpi_recv(const MPI_Comm &mpi_comm, MPI_Status &status); + + // Recv to data + template + MPI_Status mpi_recv(const MPI_Comm &mpi_comm, + Ts&... data); + + // Mrecv to return + inline std::vector mpi_mrecv(MPI_Message &message_recv, const MPI_Status &status); + + private: + + MPI_Wrapper::MPI_Type_Contiguous_Pool char_contiguous{MPI_CHAR}; +}; + +} + +#include "Cereal_Func.hpp" diff --git a/include/Comm/global/Cereal_Func.hpp b/include/Comm/global/Cereal_Func.hpp index 0aa498b..aee7c13 100644 --- a/include/Comm/global/Cereal_Func.hpp +++ b/include/Comm/global/Cereal_Func.hpp @@ -1,192 +1,192 @@ -// =================== -// Author: Peize Lin -// date: 2022.05.01 -// =================== - -#pragma once - -#include "Cereal_Func.h" -#include "Cereal_Types.h" -#include "MPI_Wrapper.h" -#include "Global_Func.h" - -#include -#include -#include -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Comm -{ - -inline Cereal_Func::Cereal_Func() -{ - #if MPI_VERSION>=4 - using int_type = MPI_Count; - #else - using int_type = int; - #endif - - // assuming MPI communication is less than max(1TB, memory availablle now) for most case, - // so initialize here to avoid thread conflict in the future. - const std::size_t TB = std::size_t(1)<<12; - const std::size_t memory_max = std::max(TB, Global_Func::memory_available()); - const std::size_t times = std::ceil( double(memory_max) / double(std::numeric_limits::max()) ); - const std::size_t exponent_align = std::ceil( std::log(times) / std::log(2) ); - this->char_contiguous.resize(exponent_align); -} - -// every 2^exponent_align char concatenate to 1 word - // <>exponent_align means /2^exponent_align -inline std::size_t Cereal_Func::align_stringstream(std::stringstream &ss) -{ - #if MPI_VERSION>=4 - using int_type = MPI_Count; - #else - using int_type = int; - #endif - - const std::size_t size_old = ss.str().size(); // Inefficient, should be optimized - const std::size_t times = std::ceil( double(size_old) / double(std::numeric_limits::max()) ); - const std::size_t exponent_align = std::ceil( std::log(times) / std::log(2) ); - this->char_contiguous.resize(exponent_align); - constexpr char c0 = 0; - const std::size_t size_align = 1<char_contiguous.resize(exponent_align); - #if MPI_VERSION>=4 - MPI_CHECK( MPI_Send_c( str.c_str(), str.size()>>exponent_align, this->char_contiguous(exponent_align), rank_recv, tag, mpi_comm ) ); - #else - MPI_CHECK( MPI_Send ( str.c_str(), str.size()>>exponent_align, this->char_contiguous(exponent_align), rank_recv, tag, mpi_comm ) ); - #endif -} - -// Send data -template -void Cereal_Func::mpi_send(const int rank_recv, const int tag, const MPI_Comm &mpi_comm, - const Ts&... data) -{ - std::stringstream ss; - { - cereal::BinaryOutputArchive ar(ss); - ar(data...); - } - const std::size_t exponent_align = align_stringstream(ss); - const std::string &str = ss.str(); - mpi_send(str, exponent_align, rank_recv, tag, mpi_comm); -} - - -// Isend str -inline void Cereal_Func::mpi_isend(const std::string &str, const std::size_t exponent_align, const int rank_recv, const int tag, const MPI_Comm &mpi_comm, MPI_Request &request) -{ - this->char_contiguous.resize(exponent_align); - #if MPI_VERSION>=4 - MPI_CHECK( MPI_Isend_c( str.c_str(), str.size()>>exponent_align, this->char_contiguous(exponent_align), rank_recv, tag, mpi_comm, &request ) ); - #else - MPI_CHECK( MPI_Isend ( str.c_str(), str.size()>>exponent_align, this->char_contiguous(exponent_align), rank_recv, tag, mpi_comm, &request ) ); - #endif -} - -// Isend data using temporary memory str -template -void Cereal_Func::mpi_isend(const int rank_recv, const int tag, const MPI_Comm &mpi_comm, - std::string &str, MPI_Request &request, - const Ts&... data) -{ - std::stringstream ss; - { - cereal::BinaryOutputArchive ar(ss); - ar(data...); - } - const std::size_t exponent_align = align_stringstream(ss); - str = ss.str(); - mpi_isend(str, exponent_align, rank_recv, tag, mpi_comm, request); -} - - -// Recv to return -inline std::vector Cereal_Func::mpi_recv(const MPI_Comm &mpi_comm, MPI_Status &status) -{ - for(std::size_t exponent_align=0; ; ++exponent_align) - { - this->char_contiguous.resize(exponent_align); - const MPI_Datatype mpi_type = this->char_contiguous(exponent_align); - #if MPI_VERSION>=4 - MPI_Count size; MPI_CHECK( MPI_Get_count_c( &status, mpi_type, &size ) ); - #else - int size; MPI_CHECK( MPI_Get_count ( &status, mpi_type, &size ) ); - #endif - if(size!=MPI_UNDEFINED) - { - std::vector c(std::size_t(size)<=4 - MPI_CHECK( MPI_Recv_c( c.data(), size, mpi_type, status.MPI_SOURCE, status.MPI_TAG, mpi_comm, MPI_STATUS_IGNORE ) ); - #else - MPI_CHECK( MPI_Recv ( c.data(), size, mpi_type, status.MPI_SOURCE, status.MPI_TAG, mpi_comm, MPI_STATUS_IGNORE ) ); - #endif - return c; - } - } -} - -// Recv to data -template -MPI_Status Cereal_Func::mpi_recv(const MPI_Comm &mpi_comm, - Ts&... data) -{ - MPI_Status status; - MPI_CHECK( MPI_Probe( MPI_ANY_SOURCE, MPI_ANY_TAG, mpi_comm, &status ) ); - - std::vector c = mpi_recv(mpi_comm, status); - - std::stringstream ss; - ss.rdbuf()->pubsetbuf(c.data(), c.size()); - { - cereal::BinaryInputArchive ar(ss); - ar(data...); - } - return status; -} - - -// Mrecv to return -inline std::vector Cereal_Func::mpi_mrecv(MPI_Message &message_recv, const MPI_Status &status) -{ - for(std::size_t exponent_align=0; ; ++exponent_align) - { - this->char_contiguous.resize(exponent_align); - const MPI_Datatype mpi_type = this->char_contiguous(exponent_align); - #if MPI_VERSION>=4 - MPI_Count size; MPI_CHECK( MPI_Get_count_c( &status, mpi_type, &size ) ); - #else - int size; MPI_CHECK( MPI_Get_count ( &status, mpi_type, &size ) ); - #endif - if(size!=MPI_UNDEFINED) - { - std::vector c(std::size_t(size)<=4 - MPI_CHECK( MPI_Mrecv_c( c.data(), size, mpi_type, &message_recv, MPI_STATUS_IGNORE ) ); - #else - MPI_CHECK( MPI_Mrecv ( c.data(), size, mpi_type, &message_recv, MPI_STATUS_IGNORE ) ); - #endif - return c; - } - } -} - -} - -#undef MPI_CHECK \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.05.01 +// =================== + +#pragma once + +#include "Cereal_Func.h" +#include "Cereal_Types.h" +#include "MPI_Wrapper.h" +#include "Global_Func.h" + +#include +#include +#include +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Comm +{ + +inline Cereal_Func::Cereal_Func() +{ + #if MPI_VERSION>=4 + using int_type = MPI_Count; + #else + using int_type = int; + #endif + + // assuming MPI communication is less than max(1TB, memory availablle now) for most case, + // so initialize here to avoid thread conflict in the future. + const std::size_t TB = std::size_t(1)<<12; + const std::size_t memory_max = std::max(TB, Global_Func::memory_available()); + const std::size_t times = std::ceil( double(memory_max) / double(std::numeric_limits::max()) ); + const std::size_t exponent_align = std::ceil( std::log(times) / std::log(2) ); + this->char_contiguous.resize(exponent_align); +} + +// every 2^exponent_align char concatenate to 1 word + // <>exponent_align means /2^exponent_align +inline std::size_t Cereal_Func::align_stringstream(std::stringstream &ss) +{ + #if MPI_VERSION>=4 + using int_type = MPI_Count; + #else + using int_type = int; + #endif + + const std::size_t size_old = ss.str().size(); // Inefficient, should be optimized + const std::size_t times = std::ceil( double(size_old) / double(std::numeric_limits::max()) ); + const std::size_t exponent_align = std::ceil( std::log(times) / std::log(2) ); + this->char_contiguous.resize(exponent_align); + constexpr char c0 = 0; + const std::size_t size_align = 1<char_contiguous.resize(exponent_align); + #if MPI_VERSION>=4 + MPI_CHECK( MPI_Send_c( str.c_str(), str.size()>>exponent_align, this->char_contiguous(exponent_align), rank_recv, tag, mpi_comm ) ); + #else + MPI_CHECK( MPI_Send ( str.c_str(), str.size()>>exponent_align, this->char_contiguous(exponent_align), rank_recv, tag, mpi_comm ) ); + #endif +} + +// Send data +template +void Cereal_Func::mpi_send(const int rank_recv, const int tag, const MPI_Comm &mpi_comm, + const Ts&... data) +{ + std::stringstream ss; + { + cereal::BinaryOutputArchive ar(ss); + ar(data...); + } + const std::size_t exponent_align = align_stringstream(ss); + const std::string &str = ss.str(); + mpi_send(str, exponent_align, rank_recv, tag, mpi_comm); +} + + +// Isend str +inline void Cereal_Func::mpi_isend(const std::string &str, const std::size_t exponent_align, const int rank_recv, const int tag, const MPI_Comm &mpi_comm, MPI_Request &request) +{ + this->char_contiguous.resize(exponent_align); + #if MPI_VERSION>=4 + MPI_CHECK( MPI_Isend_c( str.c_str(), str.size()>>exponent_align, this->char_contiguous(exponent_align), rank_recv, tag, mpi_comm, &request ) ); + #else + MPI_CHECK( MPI_Isend ( str.c_str(), str.size()>>exponent_align, this->char_contiguous(exponent_align), rank_recv, tag, mpi_comm, &request ) ); + #endif +} + +// Isend data using temporary memory str +template +void Cereal_Func::mpi_isend(const int rank_recv, const int tag, const MPI_Comm &mpi_comm, + std::string &str, MPI_Request &request, + const Ts&... data) +{ + std::stringstream ss; + { + cereal::BinaryOutputArchive ar(ss); + ar(data...); + } + const std::size_t exponent_align = align_stringstream(ss); + str = ss.str(); + mpi_isend(str, exponent_align, rank_recv, tag, mpi_comm, request); +} + + +// Recv to return +inline std::vector Cereal_Func::mpi_recv(const MPI_Comm &mpi_comm, MPI_Status &status) +{ + for(std::size_t exponent_align=0; ; ++exponent_align) + { + this->char_contiguous.resize(exponent_align); + const MPI_Datatype mpi_type = this->char_contiguous(exponent_align); + #if MPI_VERSION>=4 + MPI_Count size; MPI_CHECK( MPI_Get_count_c( &status, mpi_type, &size ) ); + #else + int size; MPI_CHECK( MPI_Get_count ( &status, mpi_type, &size ) ); + #endif + if(size!=MPI_UNDEFINED) + { + std::vector c(std::size_t(size)<=4 + MPI_CHECK( MPI_Recv_c( c.data(), size, mpi_type, status.MPI_SOURCE, status.MPI_TAG, mpi_comm, MPI_STATUS_IGNORE ) ); + #else + MPI_CHECK( MPI_Recv ( c.data(), size, mpi_type, status.MPI_SOURCE, status.MPI_TAG, mpi_comm, MPI_STATUS_IGNORE ) ); + #endif + return c; + } + } +} + +// Recv to data +template +MPI_Status Cereal_Func::mpi_recv(const MPI_Comm &mpi_comm, + Ts&... data) +{ + MPI_Status status; + MPI_CHECK( MPI_Probe( MPI_ANY_SOURCE, MPI_ANY_TAG, mpi_comm, &status ) ); + + std::vector c = mpi_recv(mpi_comm, status); + + std::stringstream ss; + ss.rdbuf()->pubsetbuf(c.data(), c.size()); + { + cereal::BinaryInputArchive ar(ss); + ar(data...); + } + return status; +} + + +// Mrecv to return +inline std::vector Cereal_Func::mpi_mrecv(MPI_Message &message_recv, const MPI_Status &status) +{ + for(std::size_t exponent_align=0; ; ++exponent_align) + { + this->char_contiguous.resize(exponent_align); + const MPI_Datatype mpi_type = this->char_contiguous(exponent_align); + #if MPI_VERSION>=4 + MPI_Count size; MPI_CHECK( MPI_Get_count_c( &status, mpi_type, &size ) ); + #else + int size; MPI_CHECK( MPI_Get_count ( &status, mpi_type, &size ) ); + #endif + if(size!=MPI_UNDEFINED) + { + std::vector c(std::size_t(size)<=4 + MPI_CHECK( MPI_Mrecv_c( c.data(), size, mpi_type, &message_recv, MPI_STATUS_IGNORE ) ); + #else + MPI_CHECK( MPI_Mrecv ( c.data(), size, mpi_type, &message_recv, MPI_STATUS_IGNORE ) ); + #endif + return c; + } + } +} + +} + +#undef MPI_CHECK diff --git a/include/Comm/global/Cereal_Types.h b/include/Comm/global/Cereal_Types.h index 7be32a1..db9dec3 100644 --- a/include/Comm/global/Cereal_Types.h +++ b/include/Comm/global/Cereal_Types.h @@ -1,13 +1,13 @@ -// =================== -// Author: Peize Lin -// date: 2022.05.01 -// =================== - -#pragma once - -#include - -// for example -#include -#include -#include \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2022.05.01 +// =================== + +#pragma once + +#include + +// for example +#include +#include +#include diff --git a/include/Comm/global/Global_Func.h b/include/Comm/global/Global_Func.h index a1ef138..f8a3314 100644 --- a/include/Comm/global/Global_Func.h +++ b/include/Comm/global/Global_Func.h @@ -1,20 +1,20 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2023-02-15 -//======================= - -#pragma once - -#include - -namespace Comm -{ - -namespace Global_Func -{ - extern inline std::size_t memory_available(); -} - -} - -#include "Global_Func.hpp" \ No newline at end of file +//======================= +// AUTHOR : Peize Lin +// DATE : 2023-02-15 +//======================= + +#pragma once + +#include + +namespace Comm +{ + +namespace Global_Func +{ + extern inline std::size_t memory_available(); +} + +} + +#include "Global_Func.hpp" diff --git a/include/Comm/global/Global_Func.hpp b/include/Comm/global/Global_Func.hpp index fd2b769..53e57bd 100644 --- a/include/Comm/global/Global_Func.hpp +++ b/include/Comm/global/Global_Func.hpp @@ -1,64 +1,64 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2023-02-15 -//======================= - -#pragma once - -#include "Global_Func.h" - -#include -#include -#include - -// BSD like macOS does not have /proc/meminfo -// include unistd.h to get the number of system memory. -#if defined(__MACH__) -#include -#endif - -namespace Comm -{ - -namespace Global_Func -{ -#if defined(__MACH__) - inline std::size_t memory_available() - { - // HACK: always assume half of the system memory free. Only for build and test on macOS - const std::size_t pages = sysconf(_SC_PHYS_PAGES); - const std::size_t page_size = sysconf(_SC_PAGE_SIZE); - return pages * page_size / 2; - } -#else - inline std::size_t memory_available() - { - constexpr std::size_t kB_to_B = 1024; - std::ifstream ifs("/proc/meminfo"); - int num = 0; - std::size_t mem_sum = 0; - while (ifs.good()) - { - std::string label, size, kB; - ifs >> label >> size >> kB; - if (label == "MemAvailable:") - { - return std::stol(size) * kB_to_B; - } - else if (label == "MemFree:" || label == "Buffers:" || label == "Cached:") - { - mem_sum += std::stol(size); - ++num; - } - - if(num==3) - { - return mem_sum * kB_to_B; - } - } - throw std::runtime_error("read /proc/meminfo error in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -#endif -} - -} \ No newline at end of file +//======================= +// AUTHOR : Peize Lin +// DATE : 2023-02-15 +//======================= + +#pragma once + +#include "Global_Func.h" + +#include +#include +#include + +// BSD like macOS does not have /proc/meminfo +// include unistd.h to get the number of system memory. +#if defined(__MACH__) +#include +#endif + +namespace Comm +{ + +namespace Global_Func +{ +#if defined(__MACH__) + inline std::size_t memory_available() + { + // HACK: always assume half of the system memory free. Only for build and test on macOS + const std::size_t pages = sysconf(_SC_PHYS_PAGES); + const std::size_t page_size = sysconf(_SC_PAGE_SIZE); + return pages * page_size / 2; + } +#else + inline std::size_t memory_available() + { + constexpr std::size_t kB_to_B = 1024; + std::ifstream ifs("/proc/meminfo"); + int num = 0; + std::size_t mem_sum = 0; + while (ifs.good()) + { + std::string label, size, kB; + ifs >> label >> size >> kB; + if (label == "MemAvailable:") + { + return std::stol(size) * kB_to_B; + } + else if (label == "MemFree:" || label == "Buffers:" || label == "Cached:") + { + mem_sum += std::stol(size); + ++num; + } + + if(num==3) + { + return mem_sum * kB_to_B; + } + } + throw std::runtime_error("read /proc/meminfo error in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +#endif +} + +} diff --git a/include/Comm/global/MPI_Wrapper.h b/include/Comm/global/MPI_Wrapper.h index be869bf..09a8907 100644 --- a/include/Comm/global/MPI_Wrapper.h +++ b/include/Comm/global/MPI_Wrapper.h @@ -1,93 +1,93 @@ -// =================== -// Author: Peize Lin -// date: 2022.06.02 -// =================== - -#pragma once - -#include -#include -#include -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Comm -{ - -namespace MPI_Wrapper -{ - inline int mpi_get_rank(const MPI_Comm &mpi_comm) - { - int rank_mine; - MPI_CHECK( MPI_Comm_rank (mpi_comm, &rank_mine) ); - return rank_mine; - } - - inline int mpi_get_size(const MPI_Comm &mpi_comm) - { - int rank_size; - MPI_CHECK( MPI_Comm_size (mpi_comm, &rank_size) ); - return rank_size; - } - - #if MPI_VERSION>=4 - inline MPI_Count mpi_get_count(const MPI_Status &status, const MPI_Datatype &datatype) - { - MPI_Count count; - MPI_CHECK( MPI_Get_count_c(&status, datatype, &count) ); - return count; - } - #else - inline int mpi_get_count(const MPI_Status &status, const MPI_Datatype &datatype) - { - int count; - MPI_CHECK( MPI_Get_count (&status, datatype, &count) ); - return count; - } - #endif - - - - // MPI_Type_Contiguous_Pool(ie) = MPI_Type_contiguous(2^ie, type_base); - class MPI_Type_Contiguous_Pool - { - public: - MPI_Datatype operator()(const std::size_t exponent) const - { - return this->type_pool.at(exponent); - } - void resize(const std::size_t exponent) - { - #pragma omp critical(MPI_Type_Contiguous_Pool) - if(this->type_pool.size()type_pool.size(); - this->type_pool.resize(exponent+1); - for(std::size_t ie=size_old; ietype_pool.size(); ++ie) - { - MPI_CHECK( MPI_Type_contiguous( 1<type_base, &this->type_pool[ie] ) ); - MPI_CHECK( MPI_Type_commit( &this->type_pool[ie] ) ); - } - } - } - MPI_Type_Contiguous_Pool(const MPI_Datatype &type_base_in) - { - MPI_CHECK( MPI_Type_dup(type_base_in, &this->type_base) ); - MPI_CHECK( MPI_Type_commit( &this->type_base ) ); - } - ~MPI_Type_Contiguous_Pool() - { - for(std::size_t ie=0; ietype_pool.size(); ++ie) - MPI_Type_free( &this->type_pool[ie] ); - MPI_Type_free( &this->type_base ); - } - std::vector type_pool; - MPI_Datatype type_base; - }; -} - -} - -#undef MPI_CHECK +// =================== +// Author: Peize Lin +// date: 2022.06.02 +// =================== + +#pragma once + +#include +#include +#include +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Comm +{ + +namespace MPI_Wrapper +{ + inline int mpi_get_rank(const MPI_Comm &mpi_comm) + { + int rank_mine; + MPI_CHECK( MPI_Comm_rank (mpi_comm, &rank_mine) ); + return rank_mine; + } + + inline int mpi_get_size(const MPI_Comm &mpi_comm) + { + int rank_size; + MPI_CHECK( MPI_Comm_size (mpi_comm, &rank_size) ); + return rank_size; + } + + #if MPI_VERSION>=4 + inline MPI_Count mpi_get_count(const MPI_Status &status, const MPI_Datatype &datatype) + { + MPI_Count count; + MPI_CHECK( MPI_Get_count_c(&status, datatype, &count) ); + return count; + } + #else + inline int mpi_get_count(const MPI_Status &status, const MPI_Datatype &datatype) + { + int count; + MPI_CHECK( MPI_Get_count (&status, datatype, &count) ); + return count; + } + #endif + + + + // MPI_Type_Contiguous_Pool(ie) = MPI_Type_contiguous(2^ie, type_base); + class MPI_Type_Contiguous_Pool + { + public: + MPI_Datatype operator()(const std::size_t exponent) const + { + return this->type_pool.at(exponent); + } + void resize(const std::size_t exponent) + { + #pragma omp critical(MPI_Type_Contiguous_Pool) + if(this->type_pool.size()type_pool.size(); + this->type_pool.resize(exponent+1); + for(std::size_t ie=size_old; ietype_pool.size(); ++ie) + { + MPI_CHECK( MPI_Type_contiguous( 1<type_base, &this->type_pool[ie] ) ); + MPI_CHECK( MPI_Type_commit( &this->type_pool[ie] ) ); + } + } + } + MPI_Type_Contiguous_Pool(const MPI_Datatype &type_base_in) + { + MPI_CHECK( MPI_Type_dup(type_base_in, &this->type_base) ); + MPI_CHECK( MPI_Type_commit( &this->type_base ) ); + } + ~MPI_Type_Contiguous_Pool() + { + for(std::size_t ie=0; ietype_pool.size(); ++ie) + MPI_Type_free( &this->type_pool[ie] ); + MPI_Type_free( &this->type_base ); + } + std::vector type_pool; + MPI_Datatype type_base; + }; +} + +} + +#undef MPI_CHECK diff --git a/unittests/Comm_Assemble/Communicate_Map-test-2.hpp b/unittests/Comm_Assemble/Communicate_Map-test-2.hpp index 5f2c10b..e6f8ad4 100644 --- a/unittests/Comm_Assemble/Communicate_Map-test-2.hpp +++ b/unittests/Comm_Assemble/Communicate_Map-test-2.hpp @@ -1,103 +1,103 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-07-06 -//======================= - -#pragma once - -#include "Comm/Comm_Assemble/Comm_Assemble.h" -#include "Comm/example/Communicate_Map-1.h" -#include "Comm/example/Communicate_Map-2.h" -#include "Comm/global/MPI_Wrapper.h" -#include "unittests/print_stl.h" - -#include -#include -#include -#include -#include -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Communicate_Map_Test -{ - static void test_assemble(int argc, char *argv[]) - { - int provided; - MPI_CHECK( MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) ); - assert(Comm::MPI_Wrapper::mpi_get_size(MPI_COMM_WORLD)==6); - const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD); - - std::map> m_in; - std::map> m_out; - - if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==0){ m_in[0][0]=1.2; m_in[2][1]=3.4;} - else if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==1){ m_in[2][0]=5.6; } - else if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==3){ m_in[2][1]=7.8; m_in[1][10]=9.0; } - - Comm::Communicate_Map::Judge_Map2 judge; - judge.s0 = {rank_mine%3}; - judge.s1 = {rank_mine/3}; - /* - s0\s1 0 1 - 0 0 3 - 1 1 4 - 2 2 5 - */ - //judge.s1 = {0,1,2}; - /* - s0 - 0 03 - 1 14 - 2 25 - */ - - Comm::Comm_Assemble< - std::tuple, - double, - std::map>, - Comm::Communicate_Map::Judge_Map2, - std::map> - > com(MPI_COMM_WORLD); - - com.traverse_keys_provide = Comm::Communicate_Map::traverse_keys; - com.get_value_provide = Comm::Communicate_Map::get_value; - com.set_value_require = Comm::Communicate_Map::set_value_add; - com.flag_lock_set_value = Comm::Comm_Tools::Lock_Type::Copy_merge; - com.init_datas_local = Comm::Communicate_Map::init_datas_local; - com.add_datas = Comm::Communicate_Map::add_datas; - - com.communicate( m_in, judge, m_out ); - - std::ofstream ofs_in("in_"+std::to_string(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD))); - ofs_in< +#include +#include +#include +#include +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Communicate_Map_Test +{ + static void test_assemble(int argc, char *argv[]) + { + int provided; + MPI_CHECK( MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) ); + assert(Comm::MPI_Wrapper::mpi_get_size(MPI_COMM_WORLD)==6); + const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD); + + std::map> m_in; + std::map> m_out; + + if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==0){ m_in[0][0]=1.2; m_in[2][1]=3.4;} + else if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==1){ m_in[2][0]=5.6; } + else if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==3){ m_in[2][1]=7.8; m_in[1][10]=9.0; } + + Comm::Communicate_Map::Judge_Map2 judge; + judge.s0 = {rank_mine%3}; + judge.s1 = {rank_mine/3}; + /* + s0\s1 0 1 + 0 0 3 + 1 1 4 + 2 2 5 + */ + //judge.s1 = {0,1,2}; + /* + s0 + 0 03 + 1 14 + 2 25 + */ + + Comm::Comm_Assemble< + std::tuple, + double, + std::map>, + Comm::Communicate_Map::Judge_Map2, + std::map> + > com(MPI_COMM_WORLD); + + com.traverse_keys_provide = Comm::Communicate_Map::traverse_keys; + com.get_value_provide = Comm::Communicate_Map::get_value; + com.set_value_require = Comm::Communicate_Map::set_value_add; + com.flag_lock_set_value = Comm::Comm_Tools::Lock_Type::Copy_merge; + com.init_datas_local = Comm::Communicate_Map::init_datas_local; + com.add_datas = Comm::Communicate_Map::add_datas; + + com.communicate( m_in, judge, m_out ); + + std::ofstream ofs_in("in_"+std::to_string(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD))); + ofs_in< -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Communicate_Map_Test -{ - static void test_speed(int argc, char *argv[]) - { - int provided; - MPI_CHECK( MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) ); - const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD); - const int rank_size = Comm::MPI_Wrapper::mpi_get_size(MPI_COMM_WORLD); - - std::map> m_in; - std::map> m_out; - - const int N = std::stoi(argv[1]); - const int M = std::stoi(argv[2]); - - const int misson_size = std::floor(N/rank_size); - for(int i=0; i(M); - - Comm::Communicate_Map::Judge_Map judge; - for(int i=0; i, - std::map>, - Comm::Communicate_Map::Judge_Map, - std::map> - > com(MPI_COMM_WORLD); - - com.traverse_keys_provide = Comm::Communicate_Map::traverse_keys>; - com.get_value_provide = Comm::Communicate_Map::get_value>; - com.set_value_require = Comm::Communicate_Map::set_value_add>; - com.flag_lock_set_value = Comm::Comm_Tools::Lock_Type::Copy_merge; - com.init_datas_local = Comm::Communicate_Map::init_datas_local>; - com.add_datas = Comm::Communicate_Map::add_datas>; - - MPI_Barrier(MPI_COMM_WORLD); - timeval t_begin; gettimeofday( &t_begin, NULL); - com.communicate( m_in, judge, m_out ); - MPI_Barrier(MPI_COMM_WORLD); - timeval t_end; gettimeofday( &t_end, NULL); - if(rank_mine==0) - std::cout< +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Communicate_Map_Test +{ + static void test_speed(int argc, char *argv[]) + { + int provided; + MPI_CHECK( MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) ); + const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD); + const int rank_size = Comm::MPI_Wrapper::mpi_get_size(MPI_COMM_WORLD); + + std::map> m_in; + std::map> m_out; + + const int N = std::stoi(argv[1]); + const int M = std::stoi(argv[2]); + + const int misson_size = std::floor(N/rank_size); + for(int i=0; i(M); + + Comm::Communicate_Map::Judge_Map judge; + for(int i=0; i, + std::map>, + Comm::Communicate_Map::Judge_Map, + std::map> + > com(MPI_COMM_WORLD); + + com.traverse_keys_provide = Comm::Communicate_Map::traverse_keys>; + com.get_value_provide = Comm::Communicate_Map::get_value>; + com.set_value_require = Comm::Communicate_Map::set_value_add>; + com.flag_lock_set_value = Comm::Comm_Tools::Lock_Type::Copy_merge; + com.init_datas_local = Comm::Communicate_Map::init_datas_local>; + com.add_datas = Comm::Communicate_Map::add_datas>; + + MPI_Barrier(MPI_COMM_WORLD); + timeval t_begin; gettimeofday( &t_begin, NULL); + com.communicate( m_in, judge, m_out ); + MPI_Barrier(MPI_COMM_WORLD); + timeval t_end; gettimeofday( &t_end, NULL); + if(rank_mine==0) + std::cout< -#include -#include -#include -#include - -#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - -namespace Comm_Keys_3_Test -{ - class Require_Judge - { - public: - Require_Judge()=default; - Require_Judge(const MPI_Comm &mpi_comm) - { - this->rank_mine = Comm::MPI_Wrapper::mpi_get_rank(mpi_comm); - this->rank_size = Comm::MPI_Wrapper::mpi_get_size(mpi_comm); - } - bool judge(const int &key) const - { - return key%(rank_size+1)==rank_mine; - } - template void serialize( Archive & ar ){ ar(rank_size, rank_mine); } - private: - int rank_size=-1; - int rank_mine=-1; - }; - - class Provider_Judge - { - public: - Provider_Judge()=default; - Provider_Judge(const MPI_Comm &mpi_comm) - { - this->rank_mine = Comm::MPI_Wrapper::mpi_get_rank(mpi_comm); - if(rank_mine==0) - v = {2,3,5,7,9}; - else if(rank_mine==1) - v = {3,1,4}; - else if(rank_mine==2) - v = {2}; - else - v = {}; - std::ofstream ofs("out."+std::to_string(rank_mine), std::ofstream::app); - ofs<<"v\t"< void serialize( Archive & ar ){ ar(rank_mine, v); } - private: - std::set v; - int rank_mine=-1; - }; - - template - void main1(int argc, char *argv[]) - { - int mpi_thread_provide; - MPI_CHECK( MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &mpi_thread_provide ) ); - assert(mpi_thread_provide==MPI_THREAD_MULTIPLE); - const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD); - - std::set v; - if(rank_mine==0) - v = {2,3,5,7,9}; - else if(rank_mine==1) - v = {3,1,4}; - else if(rank_mine==2) - v = {2}; - else - v = {}; - - std::ofstream ofs("out."+std::to_string(rank_mine)); - ofs<<"v\t"<; - std::vector> keys_trans_list = comm.trans(v, require_judge); - - ofs<<"keys_trans_list\t"< - void main2(int argc, char *argv[]) - { - int mpi_thread_provide; - MPI_CHECK( MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &mpi_thread_provide ) ); - assert(mpi_thread_provide==MPI_THREAD_MULTIPLE); - const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD); - - std::ofstream ofs("out."+std::to_string(rank_mine)); - - Provider_Judge provider_judge(MPI_COMM_WORLD); - Require_Judge require_judge(MPI_COMM_WORLD); - T_Comm_Keys_3_SenderJudge comm(MPI_COMM_WORLD); - comm.traverse_keys_all = [](std::function &func){ for(int i=-3; i<10; ++i) func(i); }; - std::vector> keys_trans_list = comm.trans(provider_judge, require_judge); - - ofs<<"keys_trans_list\t"<, Require_Judge> > (argc,argv); } - inline void main1_32(int argc, char *argv[]){ main1< Comm::Comm_Keys_32_SenderTraversal, Require_Judge> > (argc,argv); } - inline void main2_31(int argc, char *argv[]){ main2< Comm::Comm_Keys_31_SenderJudge > (argc,argv); } - inline void main2_32(int argc, char *argv[]){ main2< Comm::Comm_Keys_32_SenderJudge > (argc,argv); } - - inline void test_all(int argc, char *argv[]) - { - main1_31(argc,argv); - main1_32(argc,argv); - main2_31(argc,argv); - main2_32(argc,argv); - } - - /* - keys_trans_list - - mpirun -n 1 - 0: [2] - - mpirun -n 2 - 0: [3,9], [7] - 1: [3], [1,4] - - mpirun -n 3 - 0: [], [5,9], [2] - 1: [4], [1], [] - 2: [], [], [2] - - mpirun -n 4 - 0: [5], [], [2,7], [3] - 1: [], [1], [], [3] - 2: [], [], [2], [] - 3: [], [], [], [] - */ -} - -#undef MPI_CHECK +// =================== +// Author: Peize Lin +// date: 2022.06.22 +// =================== + +#pragma once + +#include "Comm/Comm_Keys/Comm_Keys_31-gather.h" +#include "Comm/Comm_Keys/Comm_Keys_32-gather.h" +//#include "Comm/Comm_Keys/Comm_Keys_31-sr.h" +//#include "Comm/Comm_Keys/Comm_Keys_32-sr.h" +#include "Comm/example/Communicate_Set.h" +#include "Comm/global/MPI_Wrapper.h" +#include "unittests/print_stl.h" + +#include +#include +#include +#include +#include + +#define MPI_CHECK(x) if((x)!=MPI_SUCCESS) throw std::runtime_error(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + +namespace Comm_Keys_3_Test +{ + class Require_Judge + { + public: + Require_Judge()=default; + Require_Judge(const MPI_Comm &mpi_comm) + { + this->rank_mine = Comm::MPI_Wrapper::mpi_get_rank(mpi_comm); + this->rank_size = Comm::MPI_Wrapper::mpi_get_size(mpi_comm); + } + bool judge(const int &key) const + { + return key%(rank_size+1)==rank_mine; + } + template void serialize( Archive & ar ){ ar(rank_size, rank_mine); } + private: + int rank_size=-1; + int rank_mine=-1; + }; + + class Provider_Judge + { + public: + Provider_Judge()=default; + Provider_Judge(const MPI_Comm &mpi_comm) + { + this->rank_mine = Comm::MPI_Wrapper::mpi_get_rank(mpi_comm); + if(rank_mine==0) + v = {2,3,5,7,9}; + else if(rank_mine==1) + v = {3,1,4}; + else if(rank_mine==2) + v = {2}; + else + v = {}; + std::ofstream ofs("out."+std::to_string(rank_mine), std::ofstream::app); + ofs<<"v\t"< void serialize( Archive & ar ){ ar(rank_mine, v); } + private: + std::set v; + int rank_mine=-1; + }; + + template + void main1(int argc, char *argv[]) + { + int mpi_thread_provide; + MPI_CHECK( MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &mpi_thread_provide ) ); + assert(mpi_thread_provide==MPI_THREAD_MULTIPLE); + const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD); + + std::set v; + if(rank_mine==0) + v = {2,3,5,7,9}; + else if(rank_mine==1) + v = {3,1,4}; + else if(rank_mine==2) + v = {2}; + else + v = {}; + + std::ofstream ofs("out."+std::to_string(rank_mine)); + ofs<<"v\t"<; + std::vector> keys_trans_list = comm.trans(v, require_judge); + + ofs<<"keys_trans_list\t"< + void main2(int argc, char *argv[]) + { + int mpi_thread_provide; + MPI_CHECK( MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &mpi_thread_provide ) ); + assert(mpi_thread_provide==MPI_THREAD_MULTIPLE); + const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD); + + std::ofstream ofs("out."+std::to_string(rank_mine)); + + Provider_Judge provider_judge(MPI_COMM_WORLD); + Require_Judge require_judge(MPI_COMM_WORLD); + T_Comm_Keys_3_SenderJudge comm(MPI_COMM_WORLD); + comm.traverse_keys_all = [](std::function &func){ for(int i=-3; i<10; ++i) func(i); }; + std::vector> keys_trans_list = comm.trans(provider_judge, require_judge); + + ofs<<"keys_trans_list\t"<, Require_Judge> > (argc,argv); } + inline void main1_32(int argc, char *argv[]){ main1< Comm::Comm_Keys_32_SenderTraversal, Require_Judge> > (argc,argv); } + inline void main2_31(int argc, char *argv[]){ main2< Comm::Comm_Keys_31_SenderJudge > (argc,argv); } + inline void main2_32(int argc, char *argv[]){ main2< Comm::Comm_Keys_32_SenderJudge > (argc,argv); } + + inline void test_all(int argc, char *argv[]) + { + main1_31(argc,argv); + main1_32(argc,argv); + main2_31(argc,argv); + main2_32(argc,argv); + } + + /* + keys_trans_list + + mpirun -n 1 + 0: [2] + + mpirun -n 2 + 0: [3,9], [7] + 1: [3], [1,4] + + mpirun -n 3 + 0: [], [5,9], [2] + 1: [4], [1], [] + 2: [], [], [2] + + mpirun -n 4 + 0: [5], [], [2,7], [3] + 1: [], [1], [], [3] + 2: [], [], [2], [] + 3: [], [], [], [] + */ +} + +#undef MPI_CHECK diff --git a/unittests/Comm_Trans/Communicate_Map-test-1.hpp b/unittests/Comm_Trans/Communicate_Map-test-1.hpp index 8d39331..85ed1ce 100644 --- a/unittests/Comm_Trans/Communicate_Map-test-1.hpp +++ b/unittests/Comm_Trans/Communicate_Map-test-1.hpp @@ -1,82 +1,82 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-01-05 -//======================= - -#pragma once - -#include -#include -#include -#include -#include "Comm/Comm_Trans/Comm_Trans.h" -#include "Comm/example/Communicate_Map-1.h" -#include "Comm/global/MPI_Wrapper.h" -#include "unittests/print_stl.h" - -namespace Communicate_Map_Test -{ - static void test_transmission(int argc, char *argv[]) - { - int provided; - MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ); - - std::map> m_in; - std::map> m_out; - - if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==0){ m_in[0][0]=1.2; m_in[2][1]=3.4;} - else if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==1){ m_in[2][0]=5.6; } - - Comm::Comm_Trans< - std::tuple, - double, - std::map>, - std::map> - > com(MPI_COMM_WORLD); - -// com.set_value_recv = Comm::Communicate_Map::set_value_assignment; - com.set_value_recv = Comm::Communicate_Map::set_value_add; - -// com.traverse_isend = Communicate_Map::traverse_datas_all; - com.traverse_isend = std::bind( - Comm::Communicate_Map::traverse_datas_mask, - std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, - [&com](const int rank_isend, const int &key0) ->bool { return (key0%com.comm_size==rank_isend) ? true : false; }, - [](const int rank_isend, const int &key1) ->bool { return true; }); - - com.flag_lock_set_value = Comm::Comm_Tools::Lock_Type::Copy_merge; - com.init_datas_local = Comm::Communicate_Map::init_datas_local; - com.add_datas = Comm::Communicate_Map::add_datas; - - com.communicate(m_in,m_out); - - std::ofstream ofs_in("in_"+std::to_string(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD))); - ofs_in< +#include +#include +#include +#include "Comm/Comm_Trans/Comm_Trans.h" +#include "Comm/example/Communicate_Map-1.h" +#include "Comm/global/MPI_Wrapper.h" +#include "unittests/print_stl.h" + +namespace Communicate_Map_Test +{ + static void test_transmission(int argc, char *argv[]) + { + int provided; + MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ); + + std::map> m_in; + std::map> m_out; + + if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==0){ m_in[0][0]=1.2; m_in[2][1]=3.4;} + else if(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD)==1){ m_in[2][0]=5.6; } + + Comm::Comm_Trans< + std::tuple, + double, + std::map>, + std::map> + > com(MPI_COMM_WORLD); + +// com.set_value_recv = Comm::Communicate_Map::set_value_assignment; + com.set_value_recv = Comm::Communicate_Map::set_value_add; + +// com.traverse_isend = Communicate_Map::traverse_datas_all; + com.traverse_isend = std::bind( + Comm::Communicate_Map::traverse_datas_mask, + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, + [&com](const int rank_isend, const int &key0) ->bool { return (key0%com.comm_size==rank_isend) ? true : false; }, + [](const int rank_isend, const int &key1) ->bool { return true; }); + + com.flag_lock_set_value = Comm::Comm_Tools::Lock_Type::Copy_merge; + com.init_datas_local = Comm::Communicate_Map::init_datas_local; + com.add_datas = Comm::Communicate_Map::add_datas; + + com.communicate(m_in,m_out); + + std::ofstream ofs_in("in_"+std::to_string(Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD))); + ofs_in< -#include -#include -#include - -#include - -namespace Cereal_Test -{ - static void main(int argc, char *argv[]) - { - MPI_Init(&argc, &argv); - int rank_size; MPI_Comm_size( MPI_COMM_WORLD, &rank_size ); - int rank_mine; MPI_Comm_rank( MPI_COMM_WORLD, &rank_mine ); - - Comm::Cereal_Func cereal_func; - if(rank_mine==0) - { - std::vector v = {1,2,3,4,5}; - std::map m = {{1,2.3}, {4,5.6}, {-7,-8.9}}; - std::string str; - MPI_Request request; - cereal_func.mpi_isend(1, 0, MPI_COMM_WORLD, str, request, - v, std::string("abc"), -100, m); - - std::cout<<"#\t"< v; - std::string s; - int i; - std::map m; - MPI_Status status = cereal_func.mpi_recv(MPI_COMM_WORLD, - v, s, i, m); - - std::cout<<"@\t"< +#include +#include +#include + +#include + +namespace Cereal_Test +{ + static void main(int argc, char *argv[]) + { + MPI_Init(&argc, &argv); + int rank_size; MPI_Comm_size( MPI_COMM_WORLD, &rank_size ); + int rank_mine; MPI_Comm_rank( MPI_COMM_WORLD, &rank_mine ); + + Comm::Cereal_Func cereal_func; + if(rank_mine==0) + { + std::vector v = {1,2,3,4,5}; + std::map m = {{1,2.3}, {4,5.6}, {-7,-8.9}}; + std::string str; + MPI_Request request; + cereal_func.mpi_isend(1, 0, MPI_COMM_WORLD, str, request, + v, std::string("abc"), -100, m); + + std::cout<<"#\t"< v; + std::string s; + int i; + std::map m; + MPI_Status status = cereal_func.mpi_recv(MPI_COMM_WORLD, + v, s, i, m); + + std::cout<<"@\t"< -#include -#include -#include -#include -#include - -template -std::ostream &operator<<(std::ostream &os, const std::vector &v) -{ - for(size_t i=0; i -std::ostream &operator<<(std::ostream &os, const std::set &s) -{ - for(const T &i : s) - os< -std::ostream &operator<<(std::ostream &os, const std::array &v) -{ - for(size_t i=0; i -std::ostream &operator<<(std::ostream &os, const std::pair &p) -{ - os<<"{ "< -std::ostream &operator<<(std::ostream &os, const std::map &m) -{ - for(const auto &i : m) - os< -std::ostream &operator<<(std::ostream &os, const std::tuple &t) -{ - os<<"[ "<(t)<<", "<(t)<<" ]"; - return os; -} \ No newline at end of file +// =================== +// Author: Peize Lin +// date: 2021.08.21 +// =================== + +#pragma once + +#include +#include +#include +#include +#include +#include + +template +std::ostream &operator<<(std::ostream &os, const std::vector &v) +{ + for(size_t i=0; i +std::ostream &operator<<(std::ostream &os, const std::set &s) +{ + for(const T &i : s) + os< +std::ostream &operator<<(std::ostream &os, const std::array &v) +{ + for(size_t i=0; i +std::ostream &operator<<(std::ostream &os, const std::pair &p) +{ + os<<"{ "< +std::ostream &operator<<(std::ostream &os, const std::map &m) +{ + for(const auto &i : m) + os< +std::ostream &operator<<(std::ostream &os, const std::tuple &t) +{ + os<<"[ "<(t)<<", "<(t)<<" ]"; + return os; +}