15
15
#include < limits>
16
16
#include < numeric>
17
17
#include < tuple>
18
+ #include < type_traits>
18
19
#include < vector>
19
20
20
21
#include " ./macro.h"
@@ -660,26 +661,37 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
660
661
}
661
662
662
663
if (layer) {
663
- SamplerArgs<SamplerType::LABOR> args = [&] {
664
- if (random_seed.has_value ()) {
665
- return SamplerArgs<SamplerType::LABOR>{
666
- indices_,
667
- {random_seed.value (), static_cast <float >(seed2_contribution)},
668
- NumNodes ()};
669
- } else {
670
- return SamplerArgs<SamplerType::LABOR>{
671
- indices_,
672
- RandomEngine::ThreadLocal ()->RandInt (
673
- static_cast <int64_t >(0 ), std::numeric_limits<int64_t >::max ()),
674
- NumNodes ()};
675
- }
676
- }();
677
- return SampleNeighborsImpl (
678
- nodes.value (), return_eids,
679
- GetNumPickFn (fanouts, replace, type_per_edge_, probs_or_mask),
680
- GetPickFn (
681
- fanouts, replace, indptr_.options (), type_per_edge_, probs_or_mask,
682
- args));
664
+ if (random_seed.has_value () && random_seed->numel () >= 2 ) {
665
+ SamplerArgs<SamplerType::LABOR_DEPENDENT> args{
666
+ indices_,
667
+ {random_seed.value (), static_cast <float >(seed2_contribution)},
668
+ NumNodes ()};
669
+ return SampleNeighborsImpl (
670
+ nodes.value (), return_eids,
671
+ GetNumPickFn (fanouts, replace, type_per_edge_, probs_or_mask),
672
+ GetPickFn (
673
+ fanouts, replace, indptr_.options (), type_per_edge_,
674
+ probs_or_mask, args));
675
+ } else {
676
+ auto args = [&] {
677
+ if (random_seed.has_value () && random_seed->numel () == 1 ) {
678
+ return SamplerArgs<SamplerType::LABOR>{
679
+ indices_, random_seed.value (), NumNodes ()};
680
+ } else {
681
+ return SamplerArgs<SamplerType::LABOR>{
682
+ indices_,
683
+ RandomEngine::ThreadLocal ()->RandInt (
684
+ static_cast <int64_t >(0 ), std::numeric_limits<int64_t >::max ()),
685
+ NumNodes ()};
686
+ }
687
+ }();
688
+ return SampleNeighborsImpl (
689
+ nodes.value (), return_eids,
690
+ GetNumPickFn (fanouts, replace, type_per_edge_, probs_or_mask),
691
+ GetPickFn (
692
+ fanouts, replace, indptr_.options (), type_per_edge_,
693
+ probs_or_mask, args));
694
+ }
683
695
} else {
684
696
SamplerArgs<SamplerType::NEIGHBOR> args;
685
697
return SampleNeighborsImpl (
@@ -1297,7 +1309,7 @@ int64_t TemporalPick(
1297
1309
}
1298
1310
return picked_indices.numel ();
1299
1311
}
1300
- if constexpr (S == SamplerType::LABOR ) {
1312
+ if constexpr (is_labor (S) ) {
1301
1313
return Pick (
1302
1314
offset, num_neighbors, fanout, replace, options, masked_prob, args,
1303
1315
picked_data_ptr);
@@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype(
1383
1395
return pick_offset;
1384
1396
}
1385
1397
1386
- template <typename PickedType>
1387
- int64_t Pick (
1398
+ template <SamplerType S, typename PickedType>
1399
+ std:: enable_if_t <is_labor(S), int64_t > Pick (
1388
1400
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
1389
1401
const torch::TensorOptions& options,
1390
- const torch::optional<torch::Tensor>& probs_or_mask,
1391
- SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
1402
+ const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
1403
+ PickedType* picked_data_ptr) {
1392
1404
if (fanout == 0 ) return 0 ;
1393
1405
if (probs_or_mask.has_value ()) {
1394
1406
if (fanout < 0 ) {
@@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) {
1438
1450
return rem * (one - std::pow (one - u, one / n));
1439
1451
}
1440
1452
1441
- template <typename T>
1453
+ template <typename T, typename seed_t >
1442
1454
inline T jth_sorted_uniform_random (
1443
- continuous_seed seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
1455
+ seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
1444
1456
const T u = seed.uniform (t + j * c);
1445
1457
// https://mathematica.stackexchange.com/a/256707
1446
1458
rem -= invcdf (u, n, rem);
@@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random(
1474
1486
* should be put. Enough memory space should be allocated in advance.
1475
1487
*/
1476
1488
template <
1477
- bool NonUniform, bool Replace, typename ProbsType, typename PickedType ,
1478
- int StackSize>
1479
- inline int64_t LaborPick (
1489
+ bool NonUniform, bool Replace, typename ProbsType, SamplerType S ,
1490
+ typename PickedType, int StackSize>
1491
+ inline std:: enable_if_t <is_labor(S), int64_t > LaborPick (
1480
1492
int64_t offset, int64_t num_neighbors, int64_t fanout,
1481
1493
const torch::TensorOptions& options,
1482
- const torch::optional<torch::Tensor>& probs_or_mask,
1483
- SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
1494
+ const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
1495
+ PickedType* picked_data_ptr) {
1484
1496
fanout = Replace ? fanout : std::min (fanout, num_neighbors);
1485
1497
if (!NonUniform && !Replace && fanout >= num_neighbors) {
1486
1498
std::iota (picked_data_ptr, picked_data_ptr + num_neighbors, offset);
@@ -1504,8 +1516,8 @@ inline int64_t LaborPick(
1504
1516
}
1505
1517
AT_DISPATCH_INDEX_TYPES (
1506
1518
args.indices .scalar_type (), " LaborPickMain" , ([&] {
1507
- const index_t * local_indices_data =
1508
- args.indices .data_ptr < index_t >( ) + offset;
1519
+ const auto local_indices_data =
1520
+ reinterpret_cast < index_t *>( args.indices .data_ptr () ) + offset;
1509
1521
if constexpr (Replace) {
1510
1522
// [Algorithm] @mfbalin
1511
1523
// Use a max-heap to get rid of the big random numbers and filter the
0 commit comments