Ginkgo 1.8.0, generated from a branch based on master.
A numerical linear algebra library targeting many-core architectures
mpi.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_MPI_HPP_
6#define GKO_PUBLIC_CORE_BASE_MPI_HPP_
7
8
9#include <memory>
10#include <type_traits>
11#include <utility>
12
13
14#include <ginkgo/config.hpp>
15#include <ginkgo/core/base/exception.hpp>
16#include <ginkgo/core/base/exception_helpers.hpp>
17#include <ginkgo/core/base/executor.hpp>
18#include <ginkgo/core/base/types.hpp>
19#include <ginkgo/core/base/utils_helper.hpp>
20
21
22#if GINKGO_BUILD_MPI
23
24
25#include <mpi.h>
26
27
28namespace gko {
29namespace experimental {
36namespace mpi {
37
38
42inline constexpr bool is_gpu_aware()
43{
44#if GINKGO_HAVE_GPU_AWARE_MPI
45 return true;
46#else
47 return false;
48#endif
49}
50
51
59int map_rank_to_device_id(MPI_Comm comm, int num_devices);
60
61
62#define GKO_REGISTER_MPI_TYPE(input_type, mpi_type) \
63 template <> \
64 struct type_impl<input_type> { \
65 static MPI_Datatype get_type() { return mpi_type; } \
66 }
67
76template <typename T>
77struct type_impl {};
78
79
80GKO_REGISTER_MPI_TYPE(char, MPI_CHAR);
81GKO_REGISTER_MPI_TYPE(unsigned char, MPI_UNSIGNED_CHAR);
82GKO_REGISTER_MPI_TYPE(unsigned, MPI_UNSIGNED);
83GKO_REGISTER_MPI_TYPE(int, MPI_INT);
84GKO_REGISTER_MPI_TYPE(unsigned short, MPI_UNSIGNED_SHORT);
85GKO_REGISTER_MPI_TYPE(unsigned long, MPI_UNSIGNED_LONG);
86GKO_REGISTER_MPI_TYPE(long, MPI_LONG);
87GKO_REGISTER_MPI_TYPE(long long, MPI_LONG_LONG_INT);
88GKO_REGISTER_MPI_TYPE(unsigned long long, MPI_UNSIGNED_LONG_LONG);
89GKO_REGISTER_MPI_TYPE(float, MPI_FLOAT);
90GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE);
91GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE);
92GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
93GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
94
95
102class contiguous_type {
103public:
110 contiguous_type(int count, MPI_Datatype old_type) : type_(MPI_DATATYPE_NULL)
111 {
112 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_contiguous(count, old_type, &type_));
113 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_commit(&type_));
114 }
115
119 contiguous_type() : type_(MPI_DATATYPE_NULL) {}
120
125    contiguous_type(const contiguous_type&) = delete;
130    contiguous_type& operator=(const contiguous_type&) = delete;
136 contiguous_type(contiguous_type&& other) noexcept : type_(MPI_DATATYPE_NULL)
137 {
138 *this = std::move(other);
139 }
140
148    contiguous_type& operator=(contiguous_type&& other) noexcept
149    {
150 if (this != &other) {
151 this->type_ = std::exchange(other.type_, MPI_DATATYPE_NULL);
152 }
153 return *this;
154 }
155
159    ~contiguous_type()
160    {
161 if (type_ != MPI_DATATYPE_NULL) {
162 MPI_Type_free(&type_);
163 }
164 }
165
171 MPI_Datatype get() const { return type_; }
172
173private:
174 MPI_Datatype type_;
175};
176
177
182enum class thread_type {
183 serialized = MPI_THREAD_SERIALIZED,
184 funneled = MPI_THREAD_FUNNELED,
185 single = MPI_THREAD_SINGLE,
186 multiple = MPI_THREAD_MULTIPLE
187};
188
189
199class environment {
200public:
201 static bool is_finalized()
202 {
203 int flag = 0;
204 GKO_ASSERT_NO_MPI_ERRORS(MPI_Finalized(&flag));
205 return flag;
206 }
207
208 static bool is_initialized()
209 {
210 int flag = 0;
211 GKO_ASSERT_NO_MPI_ERRORS(MPI_Initialized(&flag));
212 return flag;
213 }
214
220 int get_provided_thread_support() const { return provided_thread_support_; }
221
230 environment(int& argc, char**& argv,
231 const thread_type thread_t = thread_type::serialized)
232 {
233 this->required_thread_support_ = static_cast<int>(thread_t);
234 GKO_ASSERT_NO_MPI_ERRORS(
235 MPI_Init_thread(&argc, &argv, this->required_thread_support_,
236 &(this->provided_thread_support_)));
237 }
238
242 ~environment() { MPI_Finalize(); }
243
244 environment(const environment&) = delete;
245 environment(environment&&) = delete;
246 environment& operator=(const environment&) = delete;
247 environment& operator=(environment&&) = delete;
248
249private:
250 int required_thread_support_;
251 int provided_thread_support_;
252};
253
254
255namespace {
256
257
262class comm_deleter {
263public:
264 using pointer = MPI_Comm*;
265 void operator()(pointer comm) const
266 {
267 GKO_ASSERT(*comm != MPI_COMM_NULL);
268 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(comm));
269 delete comm;
270 }
271};
272
273
274} // namespace
275
276
280struct status {
284 status() : status_(MPI_Status{}) {}
285
291 MPI_Status* get() { return &this->status_; }
292
303 template <typename T>
304 int get_count(const T* data) const
305 {
306 int count;
307 MPI_Get_count(&status_, type_impl<T>::get_type(), &count);
308 return count;
309 }
310
311private:
312 MPI_Status status_;
313};
314
315
320class request {
321public:
326 request() : req_(MPI_REQUEST_NULL) {}
327
328 request(const request&) = delete;
329
330 request& operator=(const request&) = delete;
331
332 request(request&& o) noexcept { *this = std::move(o); }
333
334 request& operator=(request&& o) noexcept
335 {
336 if (this != &o) {
337 this->req_ = std::exchange(o.req_, MPI_REQUEST_NULL);
338 }
339 return *this;
340 }
341
342 ~request()
343 {
344 if (req_ != MPI_REQUEST_NULL) {
345 if (MPI_Request_free(&req_) != MPI_SUCCESS) {
346 std::terminate(); // since we can't throw in destructors, we
347 // have to terminate the program
348 }
349 }
350 }
351
357 MPI_Request* get() { return &this->req_; }
358
365    status wait()
366    {
367        status status;
368 GKO_ASSERT_NO_MPI_ERRORS(MPI_Wait(&req_, status.get()));
369 return status;
370 }
371
372
373private:
374 MPI_Request req_;
375};
376
377
385inline std::vector<status> wait_all(std::vector<request>& req)
386{
387 std::vector<status> stat;
388 for (std::size_t i = 0; i < req.size(); ++i) {
389 stat.emplace_back(req[i].wait());
390 }
391 return stat;
392}
393
394
409class communicator {
410public:
421 communicator(const MPI_Comm& comm, bool force_host_buffer = false)
422 : comm_(), force_host_buffer_(force_host_buffer)
423 {
424 this->comm_.reset(new MPI_Comm(comm));
425 }
426
435 communicator(const MPI_Comm& comm, int color, int key)
436 {
437 MPI_Comm comm_out;
438 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split(comm, color, key, &comm_out));
439 this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
440 }
441
450 communicator(const communicator& comm, int color, int key)
451 {
452 MPI_Comm comm_out;
453 GKO_ASSERT_NO_MPI_ERRORS(
454 MPI_Comm_split(comm.get(), color, key, &comm_out));
455 this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
456 }
457
463 const MPI_Comm& get() const { return *(this->comm_.get()); }
464
465 bool force_host_buffer() const { return force_host_buffer_; }
466
472 int size() const { return get_num_ranks(); }
473
479 int rank() const { return get_my_rank(); };
480
486 int node_local_rank() const { return get_node_local_rank(); };
487
493 bool operator==(const communicator& rhs) const
494 {
495 return compare(rhs.get());
496 }
497
503 bool operator!=(const communicator& rhs) const { return !(*this == rhs); }
504
509 void synchronize() const
510 {
511 GKO_ASSERT_NO_MPI_ERRORS(MPI_Barrier(this->get()));
512 }
513
527 template <typename SendType>
528 void send(std::shared_ptr<const Executor> exec, const SendType* send_buffer,
529 const int send_count, const int destination_rank,
530 const int send_tag) const
531 {
532 auto guard = exec->get_scoped_device_id_guard();
533 GKO_ASSERT_NO_MPI_ERRORS(
534 MPI_Send(send_buffer, send_count, type_impl<SendType>::get_type(),
535 destination_rank, send_tag, this->get()));
536 }
537
554 template <typename SendType>
555 request i_send(std::shared_ptr<const Executor> exec,
556 const SendType* send_buffer, const int send_count,
557 const int destination_rank, const int send_tag) const
558 {
559 auto guard = exec->get_scoped_device_id_guard();
560 request req;
561 GKO_ASSERT_NO_MPI_ERRORS(
562 MPI_Isend(send_buffer, send_count, type_impl<SendType>::get_type(),
563 destination_rank, send_tag, this->get(), req.get()));
564 return req;
565 }
566
582 template <typename RecvType>
583 status recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
584 const int recv_count, const int source_rank,
585 const int recv_tag) const
586 {
587 auto guard = exec->get_scoped_device_id_guard();
588 status st;
589 GKO_ASSERT_NO_MPI_ERRORS(
590 MPI_Recv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
591 source_rank, recv_tag, this->get(), st.get()));
592 return st;
593 }
594
610 template <typename RecvType>
611 request i_recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
612 const int recv_count, const int source_rank,
613 const int recv_tag) const
614 {
615 auto guard = exec->get_scoped_device_id_guard();
616 request req;
617 GKO_ASSERT_NO_MPI_ERRORS(
618 MPI_Irecv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
619 source_rank, recv_tag, this->get(), req.get()));
620 return req;
621 }
622
635 template <typename BroadcastType>
636 void broadcast(std::shared_ptr<const Executor> exec, BroadcastType* buffer,
637 int count, int root_rank) const
638 {
639 auto guard = exec->get_scoped_device_id_guard();
640 GKO_ASSERT_NO_MPI_ERRORS(MPI_Bcast(buffer, count,
641 type_impl<BroadcastType>::get_type(),
642 root_rank, this->get()));
643 }
644
660 template <typename BroadcastType>
661 request i_broadcast(std::shared_ptr<const Executor> exec,
662 BroadcastType* buffer, int count, int root_rank) const
663 {
664 auto guard = exec->get_scoped_device_id_guard();
665 request req;
666 GKO_ASSERT_NO_MPI_ERRORS(
667 MPI_Ibcast(buffer, count, type_impl<BroadcastType>::get_type(),
668 root_rank, this->get(), req.get()));
669 return req;
670 }
671
686 template <typename ReduceType>
687 void reduce(std::shared_ptr<const Executor> exec,
688 const ReduceType* send_buffer, ReduceType* recv_buffer,
689 int count, MPI_Op operation, int root_rank) const
690 {
691 auto guard = exec->get_scoped_device_id_guard();
692 GKO_ASSERT_NO_MPI_ERRORS(MPI_Reduce(send_buffer, recv_buffer, count,
693 type_impl<ReduceType>::get_type(),
694 operation, root_rank, this->get()));
695 }
696
713 template <typename ReduceType>
714 request i_reduce(std::shared_ptr<const Executor> exec,
715 const ReduceType* send_buffer, ReduceType* recv_buffer,
716 int count, MPI_Op operation, int root_rank) const
717 {
718 auto guard = exec->get_scoped_device_id_guard();
719 request req;
720 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ireduce(
721 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
722 operation, root_rank, this->get(), req.get()));
723 return req;
724 }
725
739 template <typename ReduceType>
740 void all_reduce(std::shared_ptr<const Executor> exec,
741 ReduceType* recv_buffer, int count, MPI_Op operation) const
742 {
743 auto guard = exec->get_scoped_device_id_guard();
744 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
745 MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
746 operation, this->get()));
747 }
748
764 template <typename ReduceType>
765 request i_all_reduce(std::shared_ptr<const Executor> exec,
766 ReduceType* recv_buffer, int count,
767 MPI_Op operation) const
768 {
769 auto guard = exec->get_scoped_device_id_guard();
770 request req;
771 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
772 MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
773 operation, this->get(), req.get()));
774 return req;
775 }
776
791 template <typename ReduceType>
792 void all_reduce(std::shared_ptr<const Executor> exec,
793 const ReduceType* send_buffer, ReduceType* recv_buffer,
794 int count, MPI_Op operation) const
795 {
796 auto guard = exec->get_scoped_device_id_guard();
797 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
798 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
799 operation, this->get()));
800 }
801
818 template <typename ReduceType>
819 request i_all_reduce(std::shared_ptr<const Executor> exec,
820 const ReduceType* send_buffer, ReduceType* recv_buffer,
821 int count, MPI_Op operation) const
822 {
823 auto guard = exec->get_scoped_device_id_guard();
824 request req;
825 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
826 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
827 operation, this->get(), req.get()));
828 return req;
829 }
830
847 template <typename SendType, typename RecvType>
848 void gather(std::shared_ptr<const Executor> exec,
849 const SendType* send_buffer, const int send_count,
850 RecvType* recv_buffer, const int recv_count,
851 int root_rank) const
852 {
853 auto guard = exec->get_scoped_device_id_guard();
854 GKO_ASSERT_NO_MPI_ERRORS(
855 MPI_Gather(send_buffer, send_count, type_impl<SendType>::get_type(),
856 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
857 root_rank, this->get()));
858 }
859
879 template <typename SendType, typename RecvType>
880 request i_gather(std::shared_ptr<const Executor> exec,
881 const SendType* send_buffer, const int send_count,
882 RecvType* recv_buffer, const int recv_count,
883 int root_rank) const
884 {
885 auto guard = exec->get_scoped_device_id_guard();
886 request req;
887 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igather(
888 send_buffer, send_count, type_impl<SendType>::get_type(),
889 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
890 this->get(), req.get()));
891 return req;
892 }
893
912 template <typename SendType, typename RecvType>
913 void gather_v(std::shared_ptr<const Executor> exec,
914 const SendType* send_buffer, const int send_count,
915 RecvType* recv_buffer, const int* recv_counts,
916 const int* displacements, int root_rank) const
917 {
918 auto guard = exec->get_scoped_device_id_guard();
919 GKO_ASSERT_NO_MPI_ERRORS(MPI_Gatherv(
920 send_buffer, send_count, type_impl<SendType>::get_type(),
921 recv_buffer, recv_counts, displacements,
922 type_impl<RecvType>::get_type(), root_rank, this->get()));
923 }
924
945 template <typename SendType, typename RecvType>
946 request i_gather_v(std::shared_ptr<const Executor> exec,
947 const SendType* send_buffer, const int send_count,
948 RecvType* recv_buffer, const int* recv_counts,
949 const int* displacements, int root_rank) const
950 {
951 auto guard = exec->get_scoped_device_id_guard();
952 request req;
953 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igatherv(
954 send_buffer, send_count, type_impl<SendType>::get_type(),
955 recv_buffer, recv_counts, displacements,
956 type_impl<RecvType>::get_type(), root_rank, this->get(),
957 req.get()));
958 return req;
959 }
960
976 template <typename SendType, typename RecvType>
977 void all_gather(std::shared_ptr<const Executor> exec,
978 const SendType* send_buffer, const int send_count,
979 RecvType* recv_buffer, const int recv_count) const
980 {
981 auto guard = exec->get_scoped_device_id_guard();
982 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allgather(
983 send_buffer, send_count, type_impl<SendType>::get_type(),
984 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
985 this->get()));
986 }
987
1006 template <typename SendType, typename RecvType>
1007 request i_all_gather(std::shared_ptr<const Executor> exec,
1008 const SendType* send_buffer, const int send_count,
1009 RecvType* recv_buffer, const int recv_count) const
1010 {
1011 auto guard = exec->get_scoped_device_id_guard();
1012 request req;
1013 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallgather(
1014 send_buffer, send_count, type_impl<SendType>::get_type(),
1015 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1016 this->get(), req.get()));
1017 return req;
1018 }
1019
1035 template <typename SendType, typename RecvType>
1036 void scatter(std::shared_ptr<const Executor> exec,
1037 const SendType* send_buffer, const int send_count,
1038 RecvType* recv_buffer, const int recv_count,
1039 int root_rank) const
1040 {
1041 auto guard = exec->get_scoped_device_id_guard();
1042 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatter(
1043 send_buffer, send_count, type_impl<SendType>::get_type(),
1044 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1045 this->get()));
1046 }
1047
1066 template <typename SendType, typename RecvType>
1067 request i_scatter(std::shared_ptr<const Executor> exec,
1068 const SendType* send_buffer, const int send_count,
1069 RecvType* recv_buffer, const int recv_count,
1070 int root_rank) const
1071 {
1072 auto guard = exec->get_scoped_device_id_guard();
1073 request req;
1074 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscatter(
1075 send_buffer, send_count, type_impl<SendType>::get_type(),
1076 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1077 this->get(), req.get()));
1078 return req;
1079 }
1080
1099 template <typename SendType, typename RecvType>
1100 void scatter_v(std::shared_ptr<const Executor> exec,
1101 const SendType* send_buffer, const int* send_counts,
1102 const int* displacements, RecvType* recv_buffer,
1103 const int recv_count, int root_rank) const
1104 {
1105 auto guard = exec->get_scoped_device_id_guard();
1106 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatterv(
1107 send_buffer, send_counts, displacements,
1108 type_impl<SendType>::get_type(), recv_buffer, recv_count,
1109 type_impl<RecvType>::get_type(), root_rank, this->get()));
1110 }
1111
1132 template <typename SendType, typename RecvType>
1133 request i_scatter_v(std::shared_ptr<const Executor> exec,
1134 const SendType* send_buffer, const int* send_counts,
1135 const int* displacements, RecvType* recv_buffer,
1136 const int recv_count, int root_rank) const
1137 {
1138 auto guard = exec->get_scoped_device_id_guard();
1139 request req;
1140 GKO_ASSERT_NO_MPI_ERRORS(
1141 MPI_Iscatterv(send_buffer, send_counts, displacements,
1142 type_impl<SendType>::get_type(), recv_buffer,
1143 recv_count, type_impl<RecvType>::get_type(),
1144 root_rank, this->get(), req.get()));
1145 return req;
1146 }
1147
1164 template <typename RecvType>
1165 void all_to_all(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
1166 const int recv_count) const
1167 {
1168 auto guard = exec->get_scoped_device_id_guard();
1169 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1170 MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1171 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1172 this->get()));
1173 }
1174
1193 template <typename RecvType>
1194 request i_all_to_all(std::shared_ptr<const Executor> exec,
1195 RecvType* recv_buffer, const int recv_count) const
1196 {
1197 auto guard = exec->get_scoped_device_id_guard();
1198 request req;
1199 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1200 MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1201 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1202 this->get(), req.get()));
1203 return req;
1204 }
1205
1222 template <typename SendType, typename RecvType>
1223 void all_to_all(std::shared_ptr<const Executor> exec,
1224 const SendType* send_buffer, const int send_count,
1225 RecvType* recv_buffer, const int recv_count) const
1226 {
1227 auto guard = exec->get_scoped_device_id_guard();
1228 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1229 send_buffer, send_count, type_impl<SendType>::get_type(),
1230 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1231 this->get()));
1232 }
1233
1252 template <typename SendType, typename RecvType>
1253 request i_all_to_all(std::shared_ptr<const Executor> exec,
1254 const SendType* send_buffer, const int send_count,
1255 RecvType* recv_buffer, const int recv_count) const
1256 {
1257 auto guard = exec->get_scoped_device_id_guard();
1258 request req;
1259 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1260 send_buffer, send_count, type_impl<SendType>::get_type(),
1261 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1262 this->get(), req.get()));
1263 return req;
1264 }
1265
1285 template <typename SendType, typename RecvType>
1286 void all_to_all_v(std::shared_ptr<const Executor> exec,
1287 const SendType* send_buffer, const int* send_counts,
1288 const int* send_offsets, RecvType* recv_buffer,
1289 const int* recv_counts, const int* recv_offsets) const
1290 {
1291 this->all_to_all_v(std::move(exec), send_buffer, send_counts,
1292 send_offsets, type_impl<SendType>::get_type(),
1293 recv_buffer, recv_counts, recv_offsets,
1294 type_impl<RecvType>::get_type());
1295 }
1296
1312 void all_to_all_v(std::shared_ptr<const Executor> exec,
1313 const void* send_buffer, const int* send_counts,
1314 const int* send_offsets, MPI_Datatype send_type,
1315 void* recv_buffer, const int* recv_counts,
1316 const int* recv_offsets, MPI_Datatype recv_type) const
1317 {
1318 auto guard = exec->get_scoped_device_id_guard();
1319 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoallv(
1320 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1321 recv_counts, recv_offsets, recv_type, this->get()));
1322 }
1323
1343 request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1344 const void* send_buffer, const int* send_counts,
1345 const int* send_offsets, MPI_Datatype send_type,
1346 void* recv_buffer, const int* recv_counts,
1347 const int* recv_offsets,
1348 MPI_Datatype recv_type) const
1349 {
1350 auto guard = exec->get_scoped_device_id_guard();
1351 request req;
1352 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoallv(
1353 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1354 recv_counts, recv_offsets, recv_type, this->get(), req.get()));
1355 return req;
1356 }
1357
1378 template <typename SendType, typename RecvType>
1379 request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1380 const SendType* send_buffer, const int* send_counts,
1381 const int* send_offsets, RecvType* recv_buffer,
1382 const int* recv_counts,
1383 const int* recv_offsets) const
1384 {
1385 return this->i_all_to_all_v(
1386 std::move(exec), send_buffer, send_counts, send_offsets,
1387 type_impl<SendType>::get_type(), recv_buffer, recv_counts,
1388 recv_offsets, type_impl<RecvType>::get_type());
1389 }
1390
1405 template <typename ScanType>
1406 void scan(std::shared_ptr<const Executor> exec, const ScanType* send_buffer,
1407 ScanType* recv_buffer, int count, MPI_Op operation) const
1408 {
1409 auto guard = exec->get_scoped_device_id_guard();
1410 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scan(send_buffer, recv_buffer, count,
1411 type_impl<ScanType>::get_type(),
1412 operation, this->get()));
1413 }
1414
1431 template <typename ScanType>
1432 request i_scan(std::shared_ptr<const Executor> exec,
1433 const ScanType* send_buffer, ScanType* recv_buffer,
1434 int count, MPI_Op operation) const
1435 {
1436 auto guard = exec->get_scoped_device_id_guard();
1437 request req;
1438 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscan(send_buffer, recv_buffer, count,
1439 type_impl<ScanType>::get_type(),
1440 operation, this->get(), req.get()));
1441 return req;
1442 }
1443
1444private:
1445 std::shared_ptr<MPI_Comm> comm_;
1446 bool force_host_buffer_;
1447
1448 int get_my_rank() const
1449 {
1450 int my_rank = 0;
1451 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(get(), &my_rank));
1452 return my_rank;
1453 }
1454
1455 int get_node_local_rank() const
1456 {
1457 MPI_Comm local_comm;
1458 int rank;
1459 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split_type(
1460 this->get(), MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm));
1461 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(local_comm, &rank));
1462 MPI_Comm_free(&local_comm);
1463 return rank;
1464 }
1465
1466 int get_num_ranks() const
1467 {
1468 int size = 1;
1469 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_size(this->get(), &size));
1470 return size;
1471 }
1472
1473 bool compare(const MPI_Comm& other) const
1474 {
1475 int flag;
1476 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), other, &flag));
1477 return flag == MPI_IDENT;
1478 }
1479};
1480
1481
1486bool requires_host_buffer(const std::shared_ptr<const Executor>& exec,
1487 const communicator& comm);
1488
1489
1495inline double get_walltime() { return MPI_Wtime(); }
1496
1497
1506template <typename ValueType>
1507class window {
1508public:
1512 enum class create_type { allocate = 1, create = 2, dynamic_create = 3 };
1513
1517 enum class lock_type { shared = 1, exclusive = 2 };
1518
1522 window() : window_(MPI_WIN_NULL) {}
1523
1524 window(const window& other) = delete;
1525
1526 window& operator=(const window& other) = delete;
1527
1534 window(window&& other) : window_{std::exchange(other.window_, MPI_WIN_NULL)}
1535 {}
1536
1543    window& operator=(window&& other)
1544    {
1545 window_ = std::exchange(other.window_, MPI_WIN_NULL);
1546 }
1547
1560 window(std::shared_ptr<const Executor> exec, ValueType* base, int num_elems,
1561 const communicator& comm, const int disp_unit = sizeof(ValueType),
1562 MPI_Info input_info = MPI_INFO_NULL,
1563 create_type c_type = create_type::create)
1564 {
1565 auto guard = exec->get_scoped_device_id_guard();
1566 unsigned size = num_elems * sizeof(ValueType);
1567 if (c_type == create_type::create) {
1568 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_create(
1569 base, size, disp_unit, input_info, comm.get(), &this->window_));
1570 } else if (c_type == create_type::dynamic_create) {
1571 GKO_ASSERT_NO_MPI_ERRORS(
1572 MPI_Win_create_dynamic(input_info, comm.get(), &this->window_));
1573 } else if (c_type == create_type::allocate) {
1574 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_allocate(
1575 size, disp_unit, input_info, comm.get(), base, &this->window_));
1576 } else {
1577 GKO_NOT_IMPLEMENTED;
1578 }
1579 }
1580
1586 MPI_Win get_window() const { return this->window_; }
1587
1594 void fence(int assert = 0) const
1595 {
1596 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_fence(assert, this->window_));
1597 }
1598
1607 void lock(int rank, lock_type lock_t = lock_type::shared,
1608 int assert = 0) const
1609 {
1610 if (lock_t == lock_type::shared) {
1611 GKO_ASSERT_NO_MPI_ERRORS(
1612 MPI_Win_lock(MPI_LOCK_SHARED, rank, assert, this->window_));
1613 } else if (lock_t == lock_type::exclusive) {
1614 GKO_ASSERT_NO_MPI_ERRORS(
1615 MPI_Win_lock(MPI_LOCK_EXCLUSIVE, rank, assert, this->window_));
1616 } else {
1617 GKO_NOT_IMPLEMENTED;
1618 }
1619 }
1620
1627 void unlock(int rank) const
1628 {
1629 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock(rank, this->window_));
1630 }
1631
1638 void lock_all(int assert = 0) const
1639 {
1640 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_lock_all(assert, this->window_));
1641 }
1642
1647 void unlock_all() const
1648 {
1649 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock_all(this->window_));
1650 }
1651
1658 void flush(int rank) const
1659 {
1660 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush(rank, this->window_));
1661 }
1662
1669 void flush_local(int rank) const
1670 {
1671 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local(rank, this->window_));
1672 }
1673
1678 void flush_all() const
1679 {
1680 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_all(this->window_));
1681 }
1682
1687 void flush_all_local() const
1688 {
1689 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local_all(this->window_));
1690 }
1691
1695 void sync() const { GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_sync(this->window_)); }
1696
1700    ~window()
1701    {
1702 if (this->window_ && this->window_ != MPI_WIN_NULL) {
1703 MPI_Win_free(&this->window_);
1704 }
1705 }
1706
1717 template <typename PutType>
1718 void put(std::shared_ptr<const Executor> exec, const PutType* origin_buffer,
1719 const int origin_count, const int target_rank,
1720 const unsigned int target_disp, const int target_count) const
1721 {
1722 auto guard = exec->get_scoped_device_id_guard();
1723 GKO_ASSERT_NO_MPI_ERRORS(
1724 MPI_Put(origin_buffer, origin_count, type_impl<PutType>::get_type(),
1725 target_rank, target_disp, target_count,
1726 type_impl<PutType>::get_type(), this->get_window()));
1727    }
1728
1741 template <typename PutType>
1742 request r_put(std::shared_ptr<const Executor> exec,
1743 const PutType* origin_buffer, const int origin_count,
1744 const int target_rank, const unsigned int target_disp,
1745 const int target_count) const
1746 {
1747 auto guard = exec->get_scoped_device_id_guard();
1748 request req;
1749 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rput(
1750 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1751 target_rank, target_disp, target_count,
1752 type_impl<PutType>::get_type(), this->get_window(), req.get()));
1753 return req;
1754 }
1755
1767 template <typename PutType>
1768 void accumulate(std::shared_ptr<const Executor> exec,
1769 const PutType* origin_buffer, const int origin_count,
1770 const int target_rank, const unsigned int target_disp,
1771 const int target_count, MPI_Op operation) const
1772 {
1773 auto guard = exec->get_scoped_device_id_guard();
1774 GKO_ASSERT_NO_MPI_ERRORS(MPI_Accumulate(
1775 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1776 target_rank, target_disp, target_count,
1777 type_impl<PutType>::get_type(), operation, this->get_window()));
1778 }
1779
1793 template <typename PutType>
1794 request r_accumulate(std::shared_ptr<const Executor> exec,
1795 const PutType* origin_buffer, const int origin_count,
1796 const int target_rank, const unsigned int target_disp,
1797 const int target_count, MPI_Op operation) const
1798 {
1799 auto guard = exec->get_scoped_device_id_guard();
1800 request req;
1801 GKO_ASSERT_NO_MPI_ERRORS(MPI_Raccumulate(
1802 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1803 target_rank, target_disp, target_count,
1804 type_impl<PutType>::get_type(), operation, this->get_window(),
1805 req.get()));
1806 return req;
1807 }
1808
1819 template <typename GetType>
1820 void get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1821 const int origin_count, const int target_rank,
1822 const unsigned int target_disp, const int target_count) const
1823 {
1824 auto guard = exec->get_scoped_device_id_guard();
1825 GKO_ASSERT_NO_MPI_ERRORS(
1826 MPI_Get(origin_buffer, origin_count, type_impl<GetType>::get_type(),
1827 target_rank, target_disp, target_count,
1828 type_impl<GetType>::get_type(), this->get_window()));
1829    }
1830
1843 template <typename GetType>
1844 request r_get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1845 const int origin_count, const int target_rank,
1846 const unsigned int target_disp, const int target_count) const
1847 {
1848 auto guard = exec->get_scoped_device_id_guard();
1849 request req;
1850 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget(
1851 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1852 target_rank, target_disp, target_count,
1853 type_impl<GetType>::get_type(), this->get_window(), req.get()));
1854 return req;
1855 }
1856
1870 template <typename GetType>
1871 void get_accumulate(std::shared_ptr<const Executor> exec,
1872 GetType* origin_buffer, const int origin_count,
1873 GetType* result_buffer, const int result_count,
1874 const int target_rank, const unsigned int target_disp,
1875 const int target_count, MPI_Op operation) const
1876 {
1877 auto guard = exec->get_scoped_device_id_guard();
1878 GKO_ASSERT_NO_MPI_ERRORS(MPI_Get_accumulate(
1879 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1880 result_buffer, result_count, type_impl<GetType>::get_type(),
1881 target_rank, target_disp, target_count,
1882 type_impl<GetType>::get_type(), operation, this->get_window()));
1883 }
1884
1900 template <typename GetType>
1901 request r_get_accumulate(std::shared_ptr<const Executor> exec,
1902 GetType* origin_buffer, const int origin_count,
1903 GetType* result_buffer, const int result_count,
1904 const int target_rank,
1905 const unsigned int target_disp,
1906 const int target_count, MPI_Op operation) const
1907 {
1908 auto guard = exec->get_scoped_device_id_guard();
1909 request req;
1910 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget_accumulate(
1911 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1912 result_buffer, result_count, type_impl<GetType>::get_type(),
1913 target_rank, target_disp, target_count,
1914 type_impl<GetType>::get_type(), operation, this->get_window(),
1915 req.get()));
1916 return req;
1917 }
1918
1929 template <typename GetType>
1930 void fetch_and_op(std::shared_ptr<const Executor> exec,
1931 GetType* origin_buffer, GetType* result_buffer,
1932 const int target_rank, const unsigned int target_disp,
1933 MPI_Op operation) const
1934 {
1935 auto guard = exec->get_scoped_device_id_guard();
1936 GKO_ASSERT_NO_MPI_ERRORS(MPI_Fetch_and_op(
1937 origin_buffer, result_buffer, type_impl<GetType>::get_type(),
1938 target_rank, target_disp, operation, this->get_window()));
1939 }
1940
1941private:
1942 MPI_Win window_;
1943};
1944
1945
1946} // namespace mpi
1947} // namespace experimental
1948} // namespace gko
1949
1950
1951#endif  // GINKGO_BUILD_MPI
1952
1953
1954#endif // GKO_PUBLIC_CORE_BASE_MPI_HPP_
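
A minimal usage sketch of the wrappers above (illustrative only, not part of mpi.hpp). It assumes a Ginkgo build with GINKGO_BUILD_MPI enabled, the umbrella header ginkgo/ginkgo.hpp, and a host ReferenceExecutor; device executors work the same way and additionally enable GPU-aware transfers when available.

#include <iostream>
#include <ginkgo/ginkgo.hpp>

int main(int argc, char* argv[])
{
    namespace mpi = gko::experimental::mpi;
    // RAII wrapper: calls MPI_Init_thread here and MPI_Finalize on scope exit.
    mpi::environment env(argc, argv);
    // Non-owning wrapper around an existing MPI_Comm.
    mpi::communicator comm(MPI_COMM_WORLD);
    // Host executor; all communication routines take an executor for the buffers.
    auto exec = gko::ReferenceExecutor::create();

    std::cout << "rank " << comm.rank() << " of " << comm.size() << "\n";
    comm.synchronize();  // MPI_Barrier on the wrapped communicator
    return 0;
}
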
A thin wrapper of MPI_Comm that supports most MPI calls.
Definition mpi.hpp:409
status recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive data from source rank.
Definition mpi.hpp:583
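
For example, a hedged sketch of a blocking exchange between ranks 0 and 1, with comm and exec set up as in the sketch following the source listing (buffer size and tag are arbitrary):

std::vector<double> buf(4, static_cast<double>(comm.rank()));
if (comm.rank() == 0) {
    // Blocking MPI_Send of four doubles to rank 1 with tag 0.
    comm.send(exec, buf.data(), 4, 1, 0);
} else if (comm.rank() == 1) {
    // Blocking MPI_Recv from rank 0; returns a status wrapper.
    auto st = comm.recv(exec, buf.data(), 4, 0, 0);
}
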
void scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator with offsets.
Definition mpi.hpp:1100
request i_broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
(Non-blocking) Broadcast data from calling process to all ranks in the communicator
Definition mpi.hpp:661
void gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Gather data onto the root rank from all ranks in the communicator.
Definition mpi.hpp:848
request i_recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive (Non-blocking, Immediate return) data from source rank.
Definition mpi.hpp:611
request i_scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator with offsets.
Definition mpi.hpp:1133
void all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Communicate data from all ranks to all other ranks (MPI_Alltoall).
Definition mpi.hpp:1223
request i_all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Communicate data from all ranks to all other ranks (MPI_Ialltoall).
Definition mpi.hpp:1253
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition mpi.hpp:1343
bool operator!=(const communicator &rhs) const
Compare two communicator objects for non-equality.
Definition mpi.hpp:503
void scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator.
Definition mpi.hpp:1036
void synchronize() const
This function is used to synchronize the ranks in the communicator.
Definition mpi.hpp:509
int rank() const
Return the rank of the calling process in the communicator.
Definition mpi.hpp:479
request i_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
(Non-blocking) Reduce data into root from all calling processes on the same communicator.
Definition mpi.hpp:714
int size() const
Return the size of the communicator (number of ranks).
Definition mpi.hpp:472
void send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Blocking) data from calling process to destination rank.
Definition mpi.hpp:528
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition mpi.hpp:1379
request i_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator.
Definition mpi.hpp:880
void all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place) Communicate data from all ranks to all other ranks (MPI_Alltoall).
Definition mpi.hpp:1165
void all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition mpi.hpp:1286
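
As an illustration, a sketch where every rank sends one element to every other rank (comm and exec as in the setup sketch above; requires <vector> and <numeric>; counts and offsets are per-rank arrays as in MPI_Alltoallv):

std::vector<double> send_buf(comm.size(), static_cast<double>(comm.rank()));
std::vector<double> recv_buf(comm.size(), 0.0);
std::vector<int> counts(comm.size(), 1);
std::vector<int> offsets(comm.size());
std::iota(offsets.begin(), offsets.end(), 0);  // element i goes to / comes from rank i
comm.all_to_all_v(exec, send_buf.data(), counts.data(), offsets.data(),
                  recv_buf.data(), counts.data(), offsets.data());
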
request i_all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place, non-blocking) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:765
request i_all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place, Non-blocking) Communicate data from all ranks to all other ranks (MPI_Ialltoall).
Definition mpi.hpp:1194
void all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition mpi.hpp:1312
int node_local_rank() const
Return the node local rank of the calling process in the communicator.
Definition mpi.hpp:486
void broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
Broadcast data from calling process to all ranks in the communicator.
Definition mpi.hpp:636
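
For example (comm and exec as in the setup sketch above):

int num_iterations = (comm.rank() == 0) ? 100 : 0;
// Afterwards every rank holds the value set on root rank 0.
comm.broadcast(exec, &num_iterations, 1, 0);
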
const MPI_Comm & get() const
Return the underlying MPI_Comm object.
Definition mpi.hpp:463
communicator(const MPI_Comm &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition mpi.hpp:435
void all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:740
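
For example, a global sum computed in place (comm and exec as in the setup sketch above):

double local_value = 1.0 + comm.rank();
// MPI_Allreduce with MPI_IN_PLACE: every rank ends up with the sum over all ranks.
comm.all_reduce(exec, &local_value, 1, MPI_SUM);
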
void all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Gather data onto all ranks from all ranks in the communicator.
Definition mpi.hpp:977
request i_all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Gather data onto all ranks from all ranks in the communicator.
Definition mpi.hpp:1007
bool operator==(const communicator &rhs) const
Compare two communicator objects for equality.
Definition mpi.hpp:493
void all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:792
request i_gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator with offsets.
Definition mpi.hpp:946
request i_all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:819
communicator(const MPI_Comm &comm, bool force_host_buffer=false)
Non-owning constructor for an existing communicator of type MPI_Comm.
Definition mpi.hpp:421
request i_scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Does a scan operation with the given operator.
Definition mpi.hpp:1432
void reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
Reduce data into root from all calling processes on the same communicator.
Definition mpi.hpp:687
request i_scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator.
Definition mpi.hpp:1067
void scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
Does a scan operation with the given operator.
Definition mpi.hpp:1406
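
For example, an inclusive prefix sum over one value per rank (comm and exec as in the setup sketch above):

long local = comm.rank() + 1;
long prefix = 0;
// After the call, rank r holds 1 + 2 + ... + (r + 1), since MPI_Scan is inclusive.
comm.scan(exec, &local, &prefix, 1, MPI_SUM);
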
void gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
Gather data onto the root rank from all ranks in the communicator with offsets.
Definition mpi.hpp:913
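
A sketch where rank r contributes r + 1 values and rank 0 assembles them (comm and exec as in the setup sketch above; requires <vector> and <numeric>):

int local_count = comm.rank() + 1;
std::vector<double> local(local_count, static_cast<double>(comm.rank()));
std::vector<int> recv_counts(comm.size()), displs(comm.size(), 0);
// The root first needs the per-rank counts to size its receive buffer.
comm.gather(exec, &local_count, 1, recv_counts.data(), 1, 0);
std::vector<double> recv_buffer;
if (comm.rank() == 0) {
    std::partial_sum(recv_counts.begin(), recv_counts.end() - 1, displs.begin() + 1);
    recv_buffer.resize(displs.back() + recv_counts.back());
}
comm.gather_v(exec, local.data(), local_count, recv_buffer.data(),
              recv_counts.data(), displs.data(), 0);
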
request i_send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Non-blocking, Immediate return) data from calling process to destination rank.
Definition mpi.hpp:555
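
For example, a non-blocking ring exchange that posts both transfers and then waits on them (comm and exec as in the setup sketch above):

const int next = (comm.rank() + 1) % comm.size();
const int prev = (comm.rank() + comm.size() - 1) % comm.size();
double send_val = static_cast<double>(comm.rank());
double recv_val = -1.0;
std::vector<gko::experimental::mpi::request> reqs;
reqs.push_back(comm.i_recv(exec, &recv_val, 1, prev, 0));
reqs.push_back(comm.i_send(exec, &send_val, 1, next, 0));
// Blocks until both operations have completed; also see request::wait().
auto statuses = gko::experimental::mpi::wait_all(reqs);
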
communicator(const communicator &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition mpi.hpp:450
A move-only wrapper for a contiguous MPI_Datatype.
Definition mpi.hpp:102
MPI_Datatype get() const
Access the underlying MPI_Datatype.
Definition mpi.hpp:171
contiguous_type(int count, MPI_Datatype old_type)
Constructs a wrapper for a contiguous MPI_Datatype.
Definition mpi.hpp:110
contiguous_type()
Constructs empty wrapper with MPI_DATATYPE_NULL.
Definition mpi.hpp:119
contiguous_type(const contiguous_type &)=delete
Disallow copying of wrapper type.
contiguous_type(contiguous_type &&other) noexcept
Move constructor, leaves other with MPI_DATATYPE_NULL.
Definition mpi.hpp:136
contiguous_type & operator=(contiguous_type &&other) noexcept
Move assignment, leaves other with MPI_DATATYPE_NULL.
Definition mpi.hpp:148
contiguous_type & operator=(const contiguous_type &)=delete
Disallow copying of wrapper type.
~contiguous_type()
Destructs object by freeing wrapped MPI_Datatype.
Definition mpi.hpp:159
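
A sketch of how the wrapper can be combined with raw MPI calls through communicator::get() (comm as in the setup sketch above); the datatype is committed in the constructor and freed automatically in the destructor:

gko::experimental::mpi::contiguous_type block(4, MPI_DOUBLE);
std::vector<double> data(12, static_cast<double>(comm.rank()));
if (comm.rank() == 0) {
    // Send three blocks of four doubles as a single message.
    MPI_Send(data.data(), 3, block.get(), 1, 0, comm.get());
} else if (comm.rank() == 1) {
    MPI_Recv(data.data(), 3, block.get(), 0, 0, comm.get(), MPI_STATUS_IGNORE);
}
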
Class that sets up and finalizes the MPI environment.
Definition mpi.hpp:199
~environment()
Call MPI_Finalize at the end of the scope of this class.
Definition mpi.hpp:242
int get_provided_thread_support() const
Return the provided thread support.
Definition mpi.hpp:220
environment(int &argc, char **&argv, const thread_type thread_t=thread_type::serialized)
Call MPI_Init_thread and initialize the MPI environment.
Definition mpi.hpp:230
The request class is a light, move-only wrapper around the MPI_Request handle.
Definition mpi.hpp:320
request()
The default constructor.
Definition mpi.hpp:326
MPI_Request * get()
Get a pointer to the underlying MPI_Request handle.
Definition mpi.hpp:357
status wait()
Allows a rank to wait on a particular request handle.
Definition mpi.hpp:365
This class wraps the MPI_Window class with RAII functionality.
Definition mpi.hpp:1507
void get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data from the target window.
Definition mpi.hpp:1820
request r_put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition mpi.hpp:1742
window()
The default constructor.
Definition mpi.hpp:1522
void get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Get Accumulate data from the target window.
Definition mpi.hpp:1871
void put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition mpi.hpp:1718
~window()
The deleter which calls MPI_Win_free when the window leaves its scope.
Definition mpi.hpp:1700
lock_type
The lock type for passive target synchronization of the windows.
Definition mpi.hpp:1517
window & operator=(window &&other)
The move assignment operator.
Definition mpi.hpp:1543
request r_accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Accumulate data into the target window.
Definition mpi.hpp:1794
request r_get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Get Accumulate data (with handle) from the target window.
Definition mpi.hpp:1901
void fetch_and_op(std::shared_ptr< const Executor > exec, GetType *origin_buffer, GetType *result_buffer, const int target_rank, const unsigned int target_disp, MPI_Op operation) const
Fetch and operate on data from the target window (An optimized version of Get_accumulate).
Definition mpi.hpp:1930
void sync() const
Synchronize the public and private buffers for the window object.
Definition mpi.hpp:1695
void unlock(int rank) const
Close the epoch using MPI_Win_unlock for the window object.
Definition mpi.hpp:1627
void fence(int assert=0) const
The active target synchronization using MPI_Win_fence for the window object.
Definition mpi.hpp:1594
void flush(int rank) const
Flush the existing RDMA operations on the target rank for the calling process for the window object.
Definition mpi.hpp:1658
void unlock_all() const
Close the epoch on all ranks using MPI_Win_unlock_all for the window object.
Definition mpi.hpp:1647
create_type
The create type for the window object.
Definition mpi.hpp:1512
window(std::shared_ptr< const Executor > exec, ValueType *base, int num_elems, const communicator &comm, const int disp_unit=sizeof(ValueType), MPI_Info input_info=MPI_INFO_NULL, create_type c_type=create_type::create)
Create a window object with a given data pointer and type.
Definition mpi.hpp:1560
void accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Accumulate data into the target window.
Definition mpi.hpp:1768
void lock_all(int assert=0) const
Create the epoch on all ranks using MPI_Win_lock_all for the window object.
Definition mpi.hpp:1638
void lock(int rank, lock_type lock_t=lock_type::shared, int assert=0) const
Create an epoch using MPI_Win_lock for the window object.
Definition mpi.hpp:1607
void flush_all_local() const
Flush all the local existing RDMA operations on the calling rank for the window object.
Definition mpi.hpp:1687
window(window &&other)
The move constructor.
Definition mpi.hpp:1534
void flush_local(int rank) const
Flush the existing RDMA operations on the calling rank from the target rank for the window object.
Definition mpi.hpp:1669
MPI_Win get_window() const
Get the underlying window object of MPI_Win type.
Definition mpi.hpp:1586
request r_get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data (with handle) from the target window.
Definition mpi.hpp:1844
void flush_all() const
Flush all the existing RDMA operations for the calling process for the window object.
Definition mpi.hpp:1678
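
A sketch of active-target one-sided communication with the window wrapper (comm and exec as in the setup sketch above); the window is freed automatically by the destructor:

std::vector<double> base(4, static_cast<double>(comm.rank()));
gko::experimental::mpi::window<double> win(exec, base.data(), 4, comm);
win.fence();  // open an access/exposure epoch on all ranks
if (comm.rank() == 0 && comm.size() > 1) {
    // Write rank 0's four values into rank 1's window at displacement 0.
    win.put(exec, base.data(), 4, 1, 0, 4);
}
win.fence();  // complete the epoch; the data is now visible on the target
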
int map_rank_to_device_id(MPI_Comm comm, int num_devices)
Maps each MPI rank to a single device id in a round robin manner.
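
For example, one way to select a device per rank in a CUDA build (the executor calls are assumptions about the usual Ginkgo executor API, not part of this header):

int num_devices = gko::CudaExecutor::get_num_devices();
int device_id = gko::experimental::mpi::map_rank_to_device_id(MPI_COMM_WORLD, num_devices);
auto exec = gko::CudaExecutor::create(device_id, gko::ReferenceExecutor::create());
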
bool requires_host_buffer(const std::shared_ptr< const Executor > &exec, const communicator &comm)
Checks if the combination of Executor and communicator requires passing MPI buffers from host memory.
double get_walltime()
Get the current wall time via MPI_Wtime.
Definition mpi.hpp:1495
constexpr bool is_gpu_aware()
Return if GPU aware functionality is available.
Definition mpi.hpp:42
thread_type
This enum specifies the threading type to be used when creating an MPI environment.
Definition mpi.hpp:182
std::vector< status > wait_all(std::vector< request > &req)
Allows a rank to wait on multiple request handles.
Definition mpi.hpp:385
The Ginkgo namespace.
Definition abstract_factory.hpp:20
The status struct is a light wrapper around the MPI_Status struct.
Definition mpi.hpp:280
int get_count(const T *data) const
Get the count of the number of elements received by the communication call.
Definition mpi.hpp:304
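
For example (comm and exec as in the setup sketch above); the pointer argument is only used to deduce the element type:

std::vector<double> recv_buffer(16);
auto st = comm.recv(exec, recv_buffer.data(), 16, 0, 0);
// Number of double elements that actually arrived (may be fewer than 16).
int received = st.get_count(recv_buffer.data());
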
status()
The default constructor.
Definition mpi.hpp:284
MPI_Status * get()
Get a pointer to the underlying MPI_Status object.
Definition mpi.hpp:291
A struct that is used to determine the MPI_Datatype of a specified type.
Definition mpi.hpp:77