
Commit bea1456

Add complicated distributed shared array test
1 parent 56cfe71 commit bea1456

File tree

1 file changed (+79 -4 lines)


test/c++/mpi_window.cpp

Lines changed: 79 additions & 4 deletions
@@ -14,7 +14,8 @@
 //
 // Authors: Philipp Dumitrescu, Olivier Parcollet, Nils Wentzell
 
-#include "mpi/mpi.hpp"
+#include <mpi/mpi.hpp>
+#include <mpi/vector.hpp>
 #include <gtest/gtest.h>
 #include <numeric>
 
@@ -95,10 +96,10 @@ TEST(MPI_Window, SharedArray) {
   auto shm = world.split_shared();
   int const rank_shm = shm.rank();
 
-  constexpr int const size = 20;
+  constexpr int const array_size = 20;
   constexpr int const magic = 21;
 
-  mpi::shared_window<int> win{shm, rank_shm == 0 ? size : 0};
+  mpi::shared_window<int> win{shm, rank_shm == 0 ? array_size : 0};
   std::span array_view{win.base(0), static_cast<std::size_t>(win.size(0))};
 
   win.fence();
@@ -110,7 +111,81 @@
   win.fence();
 
   int sum = std::accumulate(array_view.begin(), array_view.end(), int{0});
-  EXPECT_EQ(sum, size * magic);
+  EXPECT_EQ(sum, array_size * magic);
+}
+
+TEST(MPI_Window, DistributedSharedArray) {
+  mpi::communicator world;
+  auto shm = world.split_shared();
+
+  // Number of total array elements (prime number to make it a bit more exciting)
+  constexpr int const array_size_total = 197;
+
+  // Create a communicator between rank 0 of all shared memory islands ("head node")
+  auto head = world.split(shm.rank() == 0 ? 0 : MPI_UNDEFINED);
+
+  // Determine number of shared memory islands and broadcast to everyone
+  int head_size = (world.rank() == 0 ? head.size() : -1);
+  mpi::broadcast(head_size, world);
+
+  // Determine rank in head node communicator and broadcast to all other ranks
+  // on the same shared memory island
+  int head_rank = (head.get() != MPI_COMM_NULL ? head.rank() : -1);
+  mpi::broadcast(head_rank, shm);
+
+  // Determine number of ranks on each shared memory island and broadcast to everyone
+  std::vector<int> shm_sizes(head_size, 0);
+  if (!head.is_null()) {
+    shm_sizes.at(head_rank) = shm.size();
+    shm_sizes = mpi::all_reduce(shm_sizes, head);
+  }
+  mpi::broadcast(shm_sizes, world);
+
+  // Chunk the total array such that each rank has approximately the same number
+  // of array elements
+  std::vector<int> array_sizes(head_size, 0);
+  for (auto &&[shm_size, array_size]: itertools::zip(shm_sizes, array_sizes)) {
+    array_size = array_size_total / world.size() * shm_size;
+  }
+  // Distribute the remainder evenly over the islands to reduce load imbalance
+  for (auto i: itertools::range(array_size_total % world.size())) {
+    array_sizes.at(i % array_sizes.size()) += 1;
+  }
+
+  EXPECT_EQ(array_size_total, std::accumulate(array_sizes.begin(), array_sizes.end(), int{0}));
+
+  // Determine the global index offset on the current shared memory island
+  auto begin = array_sizes.begin();
+  std::advance(begin, head_rank);
+  std::ptrdiff_t offset = std::accumulate(array_sizes.begin(), begin, std::ptrdiff_t{0});
+
+  // Allocate memory
+  mpi::shared_window<int> win{shm, shm.rank() == 0 ? array_sizes.at(head_rank) : 0};
+  std::span array_view{win.base(0), static_cast<std::size_t>(win.size(0))};
+
+  // Fill array with global index (= local index + global offset)
+  // We do this in parallel on each shared memory island by chunking the total range
+  win.fence();
+  auto slice = itertools::chunk_range(0, array_view.size(), shm.size(), shm.rank());
+  for (auto i = slice.first; i < slice.second; ++i) {
+    array_view[i] = i + offset;
+  }
+  win.fence();
+
+  // Calculate partial sum on head node of each shared memory island and
+  // all_reduce the partial sums into a total sum over the head node
+  // communicator and broadcast result to everyone
+  std::vector<int> partial_sum(head_size, 0);
+  int sum = 0;
+  if (!head.is_null()) {
+    partial_sum[head_rank] = std::accumulate(array_view.begin(), array_view.end(), int{0});
+    partial_sum = mpi::all_reduce(partial_sum, head);
+    sum = std::accumulate(partial_sum.begin(), partial_sum.end(), int{0});
+  }
+  mpi::broadcast(sum, world);
+
+  // Total sum is just sum of numbers in interval [0, array_size_total)
+  EXPECT_EQ(sum, (array_size_total * (array_size_total - 1)) / 2);
 }
 
 MPI_TEST_MAIN;
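
The chunking arithmetic in the new test can be replayed without MPI. The sketch below is illustrative only: it assumes a hypothetical layout of 2 shared memory islands with 4 ranks each (world size 8), and it mirrors the island-proportional split, the round-robin remainder distribution, and the Gauss sum identity behind the final EXPECT_EQ. None of these concrete numbers appear in the commit itself.

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical layout: 2 islands with 4 ranks each -> world size 8.
  constexpr int array_size_total = 197;
  std::vector<int> const shm_sizes{4, 4};
  int const world_size = std::accumulate(shm_sizes.begin(), shm_sizes.end(), 0);

  // Island-proportional split: (total / world size) elements per rank,
  // times the number of ranks on the island.
  std::vector<int> array_sizes(shm_sizes.size(), 0);
  for (std::size_t n = 0; n < shm_sizes.size(); ++n)
    array_sizes[n] = array_size_total / world_size * shm_sizes[n];

  // Distribute the remainder round-robin over the islands.
  for (int i = 0; i < array_size_total % world_size; ++i)
    array_sizes[i % array_sizes.size()] += 1;

  // 197 / 8 = 24 elements per rank -> 96 per island; remainder 5 -> sizes {99, 98}.
  assert(array_sizes[0] == 99 && array_sizes[1] == 98);
  assert(std::accumulate(array_sizes.begin(), array_sizes.end(), 0) == array_size_total);

  // The test's final check uses the Gauss sum over [0, n): n * (n - 1) / 2,
  // which is 19306 for n = 197.
  assert(array_size_total * (array_size_total - 1) / 2 == 19306);
}

The test itself exercises the same logic across real communicators; assuming the usual gtest build target for test/c++/mpi_window.cpp (binary name hypothetical), it would be launched under an MPI runner, e.g. mpirun with several ranks.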
