
Commit bc1ea84

nv-dlasalle authored and DominikaJedynak committed Mar 12, 2024
[Performance Improvement] Make GPU sampling and to_block use pinned memory to decrease required synchronization (dmlc#5685)
1 parent a878672 · commit bc1ea84

2 files changed: +52 −17
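Both files apply the same pattern: stage the device-side result in pinned (page-locked) host memory, issue cudaMemcpyAsync on the current stream, record a cudaEvent immediately after the copy, and synchronize on that event only at the point where the host actually needs the value. Below is a minimal, self-contained sketch of that pattern using the plain CUDA runtime API rather than DGL's NDArray/TensorDispatcher wrappers; the kernel and variable names are illustrative, not part of the commit.

    // Minimal sketch (illustrative names): pinned staging + event-based wait.
    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void ComputeLength(int64_t* out) {
      *out = 42;  // stand-in for a kernel that produces a size on the device
    }

    int main() {
      cudaStream_t stream;
      cudaStreamCreate(&stream);

      int64_t* d_len = nullptr;
      cudaMalloc(&d_len, sizeof(int64_t));

      // Pinned host memory: cudaMemcpyAsync can run truly asynchronously.
      // With pageable memory the same call silently blocks the host.
      int64_t* h_len = nullptr;
      cudaMallocHost(&h_len, sizeof(int64_t));

      ComputeLength<<<1, 1, 0, stream>>>(d_len);
      cudaMemcpyAsync(h_len, d_len, sizeof(int64_t), cudaMemcpyDeviceToHost,
                      stream);

      // Mark the point at which the copy finishes, so we can later wait on
      // just this transfer instead of the whole stream or device.
      cudaEvent_t copyEvent;
      cudaEventCreate(&copyEvent);
      cudaEventRecord(copyEvent, stream);

      // ... the host is free to queue more GPU work here ...

      cudaEventSynchronize(copyEvent);  // block only when the value is needed
      printf("length = %lld\n", static_cast<long long>(*h_len));

      cudaEventDestroy(copyEvent);
      cudaFreeHost(h_len);
      cudaFree(d_len);
      cudaStreamDestroy(stream);
      return 0;
    }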
 

src/array/cuda/rowwise_sampling.cu (+19 −8)
@@ -7,6 +7,7 @@
 #include <curand_kernel.h>
 #include <dgl/random.h>
 #include <dgl/runtime/device_api.h>
+#include <dgl/runtime/tensordispatch.h>

 #include <numeric>

@@ -15,9 +16,11 @@
 #include "./dgl_cub.cuh"
 #include "./utils.h"

+using namespace dgl::cuda;
+using namespace dgl::aten::cuda;
+using TensorDispatcher = dgl::runtime::TensorDispatcher;
+
 namespace dgl {
-using namespace cuda;
-using namespace aten::cuda;
 namespace aten {
 namespace impl {

@@ -287,13 +290,20 @@ COOMatrix _CSRRowWiseSamplingUniform(
   cudaEvent_t copyEvent;
   CUDA_CALL(cudaEventCreate(&copyEvent));

-  // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and
-  // wait on a cudaevent
-  IdType new_len;
+  NDArray new_len_tensor;
+  if (TensorDispatcher::Global()->IsAvailable()) {
+    new_len_tensor = NDArray::PinnedEmpty(
+        {1}, DGLDataTypeTraits<IdType>::dtype, DGLContext{kDGLCPU, 0});
+  } else {
+    // use pageable memory, it will unnecessarily block but be functional
+    new_len_tensor = NDArray::Empty(
+        {1}, DGLDataTypeTraits<IdType>::dtype, DGLContext{kDGLCPU, 0});
+  }
+
   // copy using the internal current stream
-  device->CopyDataFromTo(
-      out_ptr, num_rows * sizeof(new_len), &new_len, 0, sizeof(new_len), ctx,
-      DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
+  CUDA_CALL(cudaMemcpyAsync(
+      new_len_tensor->data, out_ptr + num_rows, sizeof(IdType),
+      cudaMemcpyDeviceToHost, stream));
   CUDA_CALL(cudaEventRecord(copyEvent, stream));

   const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
@@ -322,6 +332,7 @@ COOMatrix _CSRRowWiseSamplingUniform(
   CUDA_CALL(cudaEventSynchronize(copyEvent));
   CUDA_CALL(cudaEventDestroy(copyEvent));

+  const IdType new_len = static_cast<const IdType*>(new_len_tensor->data)[0];
   picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
   picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
   picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
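In the hunks above, out_ptr is the prefix-sum of per-row sample counts, so its final entry, out_ptr[num_rows], is the total number of sampled edges. The change stages that single value through a one-element pinned tensor instead of a pageable stack variable, which previously forced the copy to behave synchronously. A rough standalone equivalent of the staging logic (plain CUDA runtime API; ReadTotalSampled and all names here are hypothetical, and DGL's real code routes allocation through NDArray/TensorDispatcher as shown in the hunk):

    // Hypothetical sketch of the hunk's staging logic, minus DGL's wrappers.
    #include <cuda_runtime.h>

    template <typename IdType>
    IdType ReadTotalSampled(const IdType* d_out_ptr, int64_t num_rows,
                            cudaStream_t stream) {
      IdType* h_len = nullptr;
      cudaMallocHost(&h_len, sizeof(IdType));  // pinned, so the copy is async

      cudaEvent_t copyEvent;
      cudaEventCreate(&copyEvent);

      // d_out_ptr has num_rows + 1 entries; the last is the total length.
      cudaMemcpyAsync(h_len, d_out_ptr + num_rows, sizeof(IdType),
                      cudaMemcpyDeviceToHost, stream);
      cudaEventRecord(copyEvent, stream);

      // (The real code launches the sampling kernels here before waiting.)

      cudaEventSynchronize(copyEvent);
      IdType len = *h_len;
      cudaFreeHost(h_len);
      cudaEventDestroy(copyEvent);
      return len;
    }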

src/graph/transform/cuda/cuda_to_block.cu (+33 −9)
@@ -23,6 +23,7 @@
 #include <cuda_runtime.h>
 #include <dgl/immutable_graph.h>
 #include <dgl/runtime/device_api.h>
+#include <dgl/runtime/tensordispatch.h>

 #include <algorithm>
 #include <memory>
@@ -36,6 +37,7 @@
 using namespace dgl::aten;
 using namespace dgl::runtime::cuda;
 using namespace dgl::transform::cuda;
+using TensorDispatcher = dgl::runtime::TensorDispatcher;

 namespace dgl {
 namespace transform {
@@ -165,6 +167,9 @@ struct CUDAIdsMapper {
           NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8));
       }
     }
+
+    cudaEvent_t copyEvent;
+    NDArray new_len_tensor;
     // Populate the mappings.
     if (generate_lhs_nodes) {
       int64_t* count_lhs_device = static_cast<int64_t*>(
@@ -174,13 +179,23 @@ struct CUDAIdsMapper {
           src_nodes, rhs_nodes, &node_maps, count_lhs_device, &lhs_nodes,
           stream);

-      device->CopyDataFromTo(
-          count_lhs_device, 0, num_nodes_per_type.data(), 0,
-          sizeof(*num_nodes_per_type.data()) * num_ntypes, ctx,
-          DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});
-      device->StreamSync(ctx, stream);
+      CUDA_CALL(cudaEventCreate(&copyEvent));
+      if (TensorDispatcher::Global()->IsAvailable()) {
+        new_len_tensor = NDArray::PinnedEmpty(
+            {num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,
+            DGLContext{kDGLCPU, 0});
+      } else {
+        // use pageable memory, it will unnecessarily block but be functional
+        new_len_tensor = NDArray::Empty(
+            {num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,
+            DGLContext{kDGLCPU, 0});
+      }
+      CUDA_CALL(cudaMemcpyAsync(
+          new_len_tensor->data, count_lhs_device,
+          sizeof(*num_nodes_per_type.data()) * num_ntypes,
+          cudaMemcpyDeviceToHost, stream));
+      CUDA_CALL(cudaEventRecord(copyEvent, stream));

-      // Wait for the node counts to finish transferring.
       device->FreeWorkspace(ctx, count_lhs_device);
     } else {
       maker.Make(lhs_nodes, rhs_nodes, &node_maps, stream);
@@ -189,14 +204,23 @@ struct CUDAIdsMapper {
         num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0];
       }
     }
-    // Resize lhs nodes.
+    // Map node numberings from global to local, and build pointer for CSR.
+    auto ret = MapEdges(graph, edge_arrays, node_maps, stream);
+
     if (generate_lhs_nodes) {
+      // wait for the previous copy
+      CUDA_CALL(cudaEventSynchronize(copyEvent));
+      CUDA_CALL(cudaEventDestroy(copyEvent));
+
+      // Resize lhs nodes.
       for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
+        num_nodes_per_type[ntype] =
+            static_cast<int64_t*>(new_len_tensor->data)[ntype];
         lhs_nodes[ntype]->shape[0] = num_nodes_per_type[ntype];
       }
     }
-    // Map node numberings from global to local, and build pointer for CSR.
-    return MapEdges(graph, edge_arrays, node_maps, stream);
+
+    return ret;
   }
};
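The reordering in this last hunk is the other half of the improvement: MapEdges is now launched before the host waits on copyEvent, so instead of stalling on a full StreamSync right after the copy, the host queues the edge-mapping work first and then performs a cheap event wait (the copy precedes the mapping kernels on the same stream, so it is generally done by then). A compact sketch of that ordering, with placeholder kernels standing in for DGL's count and MapEdges work:

    // Sketch of the reordering (placeholder kernels, not DGL's API).
    #include <cuda_runtime.h>

    __global__ void ProduceCounts(int64_t* counts, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) counts[i] = i;  // stand-in for per-type node counting
    }

    __global__ void MapEdgesWork() {}  // stand-in for the MapEdges kernels

    void OverlappedCounts(int64_t* d_counts, int64_t* h_counts_pinned, int n,
                          cudaStream_t stream) {
      cudaEvent_t copyEvent;
      cudaEventCreate(&copyEvent);

      ProduceCounts<<<(n + 255) / 256, 256, 0, stream>>>(d_counts, n);
      cudaMemcpyAsync(h_counts_pinned, d_counts, n * sizeof(int64_t),
                      cudaMemcpyDeviceToHost, stream);
      cudaEventRecord(copyEvent, stream);

      // Old order: StreamSync here, stalling the host before MapEdges.
      // New order: queue the mapping work first ...
      MapEdgesWork<<<1, 1, 0, stream>>>();

      // ... then wait only for the copy; it completed before MapEdgesWork
      // began (same stream), so this wait is typically cheap.
      cudaEventSynchronize(copyEvent);
      // h_counts_pinned is now valid; resize lhs_nodes from it.

      cudaEventDestroy(copyEvent);
    }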
