26 MPI_Datatype datatype;
27 if (std::is_same<T, double>::value) {
28 datatype = MPI_DOUBLE;
30 else if (std::is_same<T, float>::value) {
33 else if (std::is_same<T, int>::value) {
36 else if (std::is_same<T, unsigned int>::value) {
37 datatype = MPI_UNSIGNED;
39 else if (std::is_same<T, long>::value) {
42 else if (std::is_same<T, unsigned long>::value) {
43 datatype = MPI_UNSIGNED_LONG;
46 throw std::runtime_error(
"Type not recognised.");
52 T
MPIMin(
const T& localVal, MPI_Comm comm) {
54 MPI_Allreduce(&localVal, &globalVal, 1, MPIType<T>(), MPI_MIN, comm);
59 T
MPIMax(
const T& localVal, MPI_Comm comm) {
61 MPI_Allreduce(&localVal, &globalVal, 1, MPIType<T>(), MPI_MAX, comm);
66 T
MPISum(
const T& localVal, MPI_Comm comm) {
68 MPI_Allreduce(&localVal, &globalVal, 1, MPIType<T>(), MPI_SUM, comm);
75 int comm_size, comm_rank;
76 MPI_Comm_size(comm, &comm_size);
77 MPI_Comm_rank(comm, &comm_rank);
80 std::vector<T> globalVec;
82 globalVec.resize(comm_size);
84 MPI_Datatype dataType = MPIType<T>();
85 MPI_Gather(&localVal, 1, dataType, globalVec.data(), 1, dataType, 0, comm);
94 int comm_size, comm_rank;
95 MPI_Comm_size(comm, &comm_size);
96 MPI_Comm_rank(comm, &comm_rank);
99 int globalRecvCountSize = 0;
100 if (comm_rank == 0) {
101 globalRecvCountSize = comm_size;
103 std::vector<int> recvcounts(globalRecvCountSize);
104 int nLocal = localVec.size();
105 MPI_Gather(&nLocal, 1, MPI_INT, recvcounts.data(), 1, MPI_INT, 0, comm);
108 int globalVecSize = 0;
109 std::vector<int> displs;
110 if (comm_rank == 0) {
112 for (
int i=0; i<comm_size-1; i++) {
113 globalVecSize += recvcounts[i];
114 displs.push_back(globalVecSize);
116 globalVecSize += recvcounts[comm_size-1];
120 std::vector<T> globalVec(globalVecSize);
121 MPI_Datatype dataType = MPIType<T>();
122 MPI_Gatherv(localVec.data(), nLocal, dataType, globalVec.data(), recvcounts.data(),
123 displs.data(), dataType, 0, comm);
129 template <
typename T>
132 int comm_size, comm_rank;
133 MPI_Comm_size(comm, &comm_size);
134 MPI_Comm_rank(comm, &comm_rank);
137 T localVal = rootVal;
138 MPI_Datatype dataType = MPIType<T>();
139 MPI_Bcast(&localVal, 1, dataType, 0, comm);
145 template <
typename T>
148 int comm_size, comm_rank;
149 MPI_Comm_size(comm, &comm_size);
150 MPI_Comm_rank(comm, &comm_rank);
154 unsigned int size = rootVec.size();
158 std::vector<T> localVec(size);
159 if (comm_rank == 0) {
164 MPI_Datatype dataType = MPIType<T>();
165 MPI_Bcast(localVec.data(), size, dataType, 0, comm);
171 template <
typename T>
174 int comm_size, comm_rank;
175 MPI_Comm_size(comm, &comm_size);
176 MPI_Comm_rank(comm, &comm_rank);
179 int rootVecSize = -1;
180 int localVecSize = -1;
181 int localVecSizeFinalProc = -1;
182 std::vector<int> sendcounts(comm_size);
183 std::vector<int> displs(comm_size);
184 if (comm_rank == 0) {
186 rootVecSize = rootVec.size();
187 localVecSize = rootVecSize/comm_size;
188 localVecSizeFinalProc = rootVecSize-localVecSize*(comm_size-1);
192 for (
int i=0; i<comm_size-1; i++) {
193 sendcounts[i] = localVecSize;
194 displs[i+1] = displs[i] + localVecSize;
196 sendcounts.back() = localVecSizeFinalProc;
207 if (comm_rank == comm_size-1) {
208 localVecSize = localVecSizeFinalProc;
212 std::vector<T> localVec(localVecSize);
213 MPI_Scatterv(rootVec.data(), sendcounts.data(), displs.data(), MPIType<T>(),
214 localVec.data(), localVecSize, MPIType<T>(), 0, comm);
220 template <
typename T>
223 int comm_size, comm_rank;
224 MPI_Comm_size(comm, &comm_size);
225 MPI_Comm_rank(comm, &comm_rank);
228 int rootVecSize = -1;
229 int localVecSize = -1;
230 int localVecSizeFinalProc = -1;
231 std::vector<int> sendcounts(comm_size);
232 std::vector<int> displs(comm_size);
233 if (comm_rank == 0) {
235 rootVecSize = rootVec.size();
236 localVecSize = rootVecSize/comm_size;
237 localVecSizeFinalProc = rootVecSize-localVecSize*(comm_size-1);
241 for (
int i=0; i<comm_size-1; i++) {
242 sendcounts[i] = localVecSize;
243 displs[i+1] = displs[i] + localVecSize;
245 sendcounts.back() = localVecSizeFinalProc;
256 if (comm_rank == comm_size-1) {
257 localVecSize = localVecSizeFinalProc;
261 std::vector<T> localVec(localVecSize);
262 MPI_Scatterv(rootVec.data(), sendcounts.data(), displs.data(), dtype,
263 localVec.data(), localVecSize, dtype, 0, comm);
275 bool MPIAllTrue(
bool condition, MPI_Comm comm);
Definition: casesUtils.cpp:4
T MPISum(const T &localVal, MPI_Comm comm)
Definition: mpiUtils.h:66
bool MPIAllTrue(bool condition, MPI_Comm comm)
Determine if condition is true for all processes.
Definition: mpiUtils.cpp:6
T MPIMin(const T &localVal, MPI_Comm comm)
Definition: mpiUtils.h:52
std::vector< T > MPIConcatOnRoot(T localVal, MPI_Comm comm)
Definition: mpiUtils.h:73
MPI_Datatype MPIType()
Get the MPI datatype from the C++ type.
Definition: mpiUtils.h:25
std::vector< T > MPISplitVectorEvenly(const std::vector< T > &rootVec, MPI_Comm comm)
Definition: mpiUtils.h:172
T MPIBroadcastFromRoot(T rootVal, MPI_Comm comm)
Definition: mpiUtils.h:130
T MPIMax(const T &localVal, MPI_Comm comm)
Definition: mpiUtils.h:59