//
// Authors: Philipp Dumitrescu, Olivier Parcollet, Nils Wentzell

-#include "mpi/mpi.hpp"
+#include <mpi/mpi.hpp>
+#include <mpi/vector.hpp>
#include <gtest/gtest.h>
#include <numeric>

@@ -95,10 +96,10 @@ TEST(MPI_Window, SharedArray) {
  auto shm = world.split_shared();
  int const rank_shm = shm.rank();

-  constexpr int const size = 20;
+  constexpr int const array_size = 20;
  constexpr int const magic = 21;

-  mpi::shared_window<int> win{shm, rank_shm == 0 ? size : 0};
+  mpi::shared_window<int> win{shm, rank_shm == 0 ? array_size : 0};
  std::span array_view{win.base(0), static_cast<std::size_t>(win.size(0))};

  win.fence();
@@ -110,7 +111,81 @@ TEST(MPI_Window, SharedArray) {
  win.fence();

  int sum = std::accumulate(array_view.begin(), array_view.end(), int{0});
-  EXPECT_EQ(sum, size * magic);
+  EXPECT_EQ(sum, array_size * magic);
+}
+
+TEST(MPI_Window, DistributedSharedArray) {
+  mpi::communicator world;
+  auto shm = world.split_shared();
+
+  // Total number of array elements (a prime number to make it a bit more exciting)
+  constexpr int const array_size_total = 197;
+
+  // Create a communicator between rank 0 of all shared memory islands ("head node")
+  auto head = world.split(shm.rank() == 0 ? 0 : MPI_UNDEFINED);
+
+  // Determine number of shared memory islands and broadcast to everyone
+  int head_size = (world.rank() == 0 ? head.size() : -1);
+  mpi::broadcast(head_size, world);
+
+  // Determine rank in head node communicator and broadcast to all other ranks
+  // on the same shared memory island
+  int head_rank = (head.get() != MPI_COMM_NULL ? head.rank() : -1);
+  mpi::broadcast(head_rank, shm);
+
+  // Determine number of ranks on each shared memory island and broadcast to everyone
+  std::vector<int> shm_sizes(head_size, 0);
+  if (!head.is_null()) {
+    shm_sizes.at(head_rank) = shm.size();
+    shm_sizes = mpi::all_reduce(shm_sizes, head);
+  }
+  mpi::broadcast(shm_sizes, world);
+
+  // Chunk the total array such that each rank has approximately the same number
+  // of array elements
+  std::vector<int> array_sizes(head_size, 0);
+  for (auto &&[shm_size, array_size] : itertools::zip(shm_sizes, array_sizes)) {
+    array_size = array_size_total / world.size() * shm_size;
+  }
+  // Distribute the remainder evenly over the islands to reduce load imbalance
+  for (auto i : itertools::range(array_size_total % world.size())) {
+    array_sizes.at(i % array_sizes.size()) += 1;
+  }
+
+  EXPECT_EQ(array_size_total, std::accumulate(array_sizes.begin(), array_sizes.end(), int{0}));
+
+  // Determine the global index offset of the current shared memory island
+  auto begin = array_sizes.begin();
+  std::advance(begin, head_rank);
+  std::ptrdiff_t offset = std::accumulate(array_sizes.begin(), begin, std::ptrdiff_t{0});
+
+  // Allocate memory
+  mpi::shared_window<int> win{shm, shm.rank() == 0 ? array_sizes.at(head_rank) : 0};
+  std::span array_view{win.base(0), static_cast<std::size_t>(win.size(0))};
+
+  // Fill array with global index (= local index + global offset)
+  // We do this in parallel on each shared memory island by chunking the island's index range
+  win.fence();
+  auto slice = itertools::chunk_range(0, array_view.size(), shm.size(), shm.rank());
+  for (auto i = slice.first; i < slice.second; ++i) {
+    array_view[i] = i + offset;
+  }
+  win.fence();
+
+  // Calculate the partial sum on the head node of each shared memory island,
+  // all_reduce the partial sums into a total sum over the head node
+  // communicator, and broadcast the result to everyone
+  std::vector<int> partial_sum(head_size, 0);
+  int sum = 0;
+  if (!head.is_null()) {
+    partial_sum[head_rank] = std::accumulate(array_view.begin(), array_view.end(), int{0});
+    partial_sum = mpi::all_reduce(partial_sum, head);
+    sum = std::accumulate(partial_sum.begin(), partial_sum.end(), int{0});
+  }
+  mpi::broadcast(sum, world);
+
+  // The total sum is just the sum of the integers in the interval [0, array_size_total)
+  EXPECT_EQ(sum, (array_size_total * (array_size_total - 1)) / 2);
}
MPI_TEST_MAIN;
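
The chunking arithmetic in DistributedSharedArray is easy to check in isolation. Below is a minimal standalone sketch (not part of the commit) of the same proportional-split-plus-remainder scheme, with hypothetical island sizes of 4, 4 and 2 ranks standing in for the values the test obtains from MPI at run time:

#include <cassert>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical layout: three shared memory islands with 4, 4 and 2 ranks.
  std::vector<int> const shm_sizes{4, 4, 2};
  int const world_size = std::accumulate(shm_sizes.begin(), shm_sizes.end(), 0);
  constexpr int array_size_total = 197;

  // Give each island a whole multiple of the per-rank share, as in the test ...
  std::vector<int> array_sizes(shm_sizes.size(), 0);
  for (std::size_t n = 0; n < shm_sizes.size(); ++n)
    array_sizes[n] = array_size_total / world_size * shm_sizes[n];

  // ... then spread the remainder round-robin over the islands.
  for (int i = 0; i < array_size_total % world_size; ++i)
    array_sizes[i % array_sizes.size()] += 1;

  // Every global index is assigned to exactly one island.
  assert(std::accumulate(array_sizes.begin(), array_sizes.end(), 0) == array_size_total);

  // Each island stores the global indices [offset, offset + size), so the grand
  // total checked at the end of the test is 0 + 1 + ... + (array_size_total - 1).
  std::cout << "expected sum = " << array_size_total * (array_size_total - 1) / 2 << '\n';
}

With these numbers the islands receive 79, 78 and 40 elements (190 from the proportional split plus a remainder of 7 distributed round-robin), and the expected total sum is 197 * 196 / 2 = 19306.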