24
24
#include < mxnet/operator_util.h>
25
25
#include < dmlc/logging.h>
26
26
#include < dmlc/optional.h>
27
+ #include < algorithm>
28
+ #include < random>
29
+
27
30
#include " ../elemwise_op_common.h"
28
31
#include " ../../imperative/imperative_utils.h"
29
32
#include " ../subgraph_op_common.h"
@@ -41,7 +44,9 @@ typedef int64_t dgl_id_t;
41
44
*/
42
45
class ArrayHeap {
43
46
public:
44
- explicit ArrayHeap (const std::vector<float >& prob) {
47
+ explicit ArrayHeap (const std::vector<float >& prob, unsigned int seed) {
48
+ generator_ = std::mt19937 (seed);
49
+ distribution_ = std::uniform_real_distribution<float >(0.0 , 1.0 );
45
50
vec_size_ = prob.size ();
46
51
bit_len_ = ceil (log2 (vec_size_));
47
52
limit_ = 1 << bit_len_;
@@ -86,8 +91,8 @@ class ArrayHeap {
86
91
/*
87
92
* Sample from arrayHeap
88
93
*/
89
- size_t Sample (unsigned int * seed ) {
90
- float xi = heap_[1 ] * ( rand_r (seed)% 100 / 101.0 );
94
+ size_t Sample () {
95
+ float xi = heap_[1 ] * distribution_ (generator_ );
91
96
int i = 1 ;
92
97
while (i < limit_) {
93
98
i = i << 1 ;
@@ -102,10 +107,10 @@ class ArrayHeap {
102
107
/*
103
108
* Sample a vector by given the size n
104
109
*/
105
- void SampleWithoutReplacement (size_t n, std::vector<size_t >* samples, unsigned int * seed ) {
110
+ void SampleWithoutReplacement (size_t n, std::vector<size_t >* samples) {
106
111
// sample n elements
107
112
for (size_t i = 0 ; i < n; ++i) {
108
- samples->at (i) = this ->Sample (seed );
113
+ samples->at (i) = this ->Sample ();
109
114
this ->Delete (samples->at (i));
110
115
}
111
116
}
@@ -115,6 +120,8 @@ class ArrayHeap {
115
120
int bit_len_; // bit size
116
121
int limit_;
117
122
std::vector<float > heap_;
123
+ std::mt19937 generator_;
124
+ std::uniform_real_distribution<float > distribution_;
118
125
};
119
126
120
127
struct NeighborSampleParam : public dmlc ::Parameter<NeighborSampleParam> {
@@ -402,10 +409,12 @@ static bool CSRNeighborNonUniformSampleType(const nnvm::NodeAttrs& attrs,
402
409
static void RandomSample (size_t set_size,
403
410
size_t num,
404
411
std::vector<size_t >* out,
405
- unsigned int * seed) {
412
+ unsigned int seed) {
413
+ std::mt19937 generator (seed);
406
414
std::unordered_set<size_t > sampled_idxs;
415
+ std::uniform_int_distribution<size_t > distribution (0 , set_size - 1 );
407
416
while (sampled_idxs.size () < num) {
408
- sampled_idxs.insert (rand_r (seed) % set_size );
417
+ sampled_idxs.insert (distribution (generator) );
409
418
}
410
419
out->clear ();
411
420
for (auto it = sampled_idxs.begin (); it != sampled_idxs.end (); it++) {
@@ -441,7 +450,7 @@ static void GetUniformSample(const dgl_id_t* val_list,
441
450
const size_t max_num_neighbor,
442
451
std::vector<dgl_id_t >* out_ver,
443
452
std::vector<dgl_id_t >* out_edge,
444
- unsigned int * seed) {
453
+ unsigned int seed) {
445
454
// Copy ver_list to output
446
455
if (ver_len <= max_num_neighbor) {
447
456
for (size_t i = 0 ; i < ver_len; ++i) {
@@ -485,7 +494,7 @@ static void GetNonUniformSample(const float* probability,
485
494
const size_t max_num_neighbor,
486
495
std::vector<dgl_id_t >* out_ver,
487
496
std::vector<dgl_id_t >* out_edge,
488
- unsigned int * seed) {
497
+ unsigned int seed) {
489
498
// Copy ver_list to output
490
499
if (ver_len <= max_num_neighbor) {
491
500
for (size_t i = 0 ; i < ver_len; ++i) {
@@ -500,8 +509,8 @@ static void GetNonUniformSample(const float* probability,
500
509
for (size_t i = 0 ; i < ver_len; ++i) {
501
510
sp_prob[i] = probability[col_list[i]];
502
511
}
503
- ArrayHeap arrayHeap (sp_prob);
504
- arrayHeap.SampleWithoutReplacement (max_num_neighbor, &sp_index, seed );
512
+ ArrayHeap arrayHeap (sp_prob, seed );
513
+ arrayHeap.SampleWithoutReplacement (max_num_neighbor, &sp_index);
505
514
out_ver->resize (max_num_neighbor);
506
515
out_edge->resize (max_num_neighbor);
507
516
for (size_t i = 0 ; i < max_num_neighbor; ++i) {
@@ -536,8 +545,8 @@ static void SampleSubgraph(const NDArray &csr,
536
545
const float * probability,
537
546
int num_hops,
538
547
size_t num_neighbor,
539
- size_t max_num_vertices) {
540
- unsigned int time_seed = time ( nullptr );
548
+ size_t max_num_vertices,
549
+ unsigned int random_seed) {
541
550
size_t num_seeds = seed_arr.shape ().Size ();
542
551
CHECK_GE (max_num_vertices, num_seeds);
543
552
@@ -594,7 +603,7 @@ static void SampleSubgraph(const NDArray &csr,
594
603
num_neighbor,
595
604
&tmp_sampled_src_list,
596
605
&tmp_sampled_edge_list,
597
- &time_seed );
606
+ random_seed );
598
607
} else { // non-uniform-sample
599
608
GetNonUniformSample (probability,
600
609
val_list + *(indptr + dst_id),
@@ -603,7 +612,7 @@ static void SampleSubgraph(const NDArray &csr,
603
612
num_neighbor,
604
613
&tmp_sampled_src_list,
605
614
&tmp_sampled_edge_list,
606
- &time_seed );
615
+ random_seed );
607
616
}
608
617
CHECK_EQ (tmp_sampled_src_list.size (), tmp_sampled_edge_list.size ());
609
618
size_t pos = neighbor_list.size ();
@@ -720,12 +729,15 @@ static void CSRNeighborUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs,
720
729
const std::vector<NDArray>& inputs,
721
730
const std::vector<OpReqType>& req,
722
731
const std::vector<NDArray>& outputs) {
723
- const NeighborSampleParam& params =
724
- nnvm::get<NeighborSampleParam>(attrs.parsed );
732
+ const NeighborSampleParam& params = nnvm::get<NeighborSampleParam>(attrs.parsed );
725
733
726
734
int num_subgraphs = inputs.size () - 1 ;
727
735
CHECK_EQ (outputs.size (), 3 * num_subgraphs);
728
736
737
+ mshadow::Stream<cpu> *s = ctx.get_stream <cpu>();
738
+ mshadow::Random<cpu, unsigned int > *prnd = ctx.requested [0 ].get_random <cpu, unsigned int >(s);
739
+ unsigned int seed = prnd->GetRandInt ();
740
+
729
741
#pragma omp parallel for
730
742
for (int i = 0 ; i < num_subgraphs; i++) {
731
743
SampleSubgraph (inputs[0 ], // graph_csr
@@ -737,7 +749,12 @@ static void CSRNeighborUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs,
737
749
nullptr , // probability
738
750
params.num_hops ,
739
751
params.num_neighbor ,
740
- params.max_num_vertices );
752
+ params.max_num_vertices ,
753
+ #if defined(_OPENMP)
754
+ seed + omp_get_thread_num ());
755
+ #else
756
+ seed);
757
+ #endif
741
758
}
742
759
}
743
760
@@ -798,6 +815,9 @@ of max_num_vertices, and the valid number of vertices is the same as the ones in
798
815
.set_attr<mxnet::FInferShape>(" FInferShape" , CSRNeighborUniformSampleShape)
799
816
.set_attr<nnvm::FInferType>(" FInferType" , CSRNeighborUniformSampleType)
800
817
.set_attr<FComputeEx>(" FComputeEx<cpu>" , CSRNeighborUniformSampleComputeExCPU)
818
+ .set_attr<FResourceRequest>(" FResourceRequest" , [](const NodeAttrs& attrs) {
819
+ return std::vector<ResourceRequest>{ResourceRequest::kRandom };
820
+ })
801
821
.add_argument(" csr_matrix" , " NDArray-or-Symbol" , " csr matrix" )
802
822
.add_argument(" seed_arrays" , " NDArray-or-Symbol[]" , " seed vertices" )
803
823
.set_attr<std::string>(" key_var_num_args" , " num_args" )
@@ -811,14 +831,17 @@ static void CSRNeighborNonUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs
811
831
const std::vector<NDArray>& inputs,
812
832
const std::vector<OpReqType>& req,
813
833
const std::vector<NDArray>& outputs) {
814
- const NeighborSampleParam& params =
815
- nnvm::get<NeighborSampleParam>(attrs.parsed );
834
+ const NeighborSampleParam& params = nnvm::get<NeighborSampleParam>(attrs.parsed );
816
835
817
836
int num_subgraphs = inputs.size () - 2 ;
818
837
CHECK_EQ (outputs.size (), 4 * num_subgraphs);
819
838
820
839
const float * probability = inputs[1 ].data ().dptr <float >();
821
840
841
+ mshadow::Stream<cpu> *s = ctx.get_stream <cpu>();
842
+ mshadow::Random<cpu, unsigned int > *prnd = ctx.requested [0 ].get_random <cpu, unsigned int >(s);
843
+ unsigned int seed = prnd->GetRandInt ();
844
+
822
845
#pragma omp parallel for
823
846
for (int i = 0 ; i < num_subgraphs; i++) {
824
847
float * sub_prob = outputs[i+2 *num_subgraphs].data ().dptr <float >();
@@ -831,7 +854,12 @@ static void CSRNeighborNonUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs
831
854
probability,
832
855
params.num_hops ,
833
856
params.num_neighbor ,
834
- params.max_num_vertices );
857
+ params.max_num_vertices ,
858
+ #if defined(_OPENMP)
859
+ seed + omp_get_thread_num ());
860
+ #else
861
+ seed);
862
+ #endif
835
863
}
836
864
}
837
865
@@ -897,6 +925,9 @@ of max_num_vertices, and the valid number of vertices is the same as the ones in
897
925
.set_attr<mxnet::FInferShape>(" FInferShape" , CSRNeighborNonUniformSampleShape)
898
926
.set_attr<nnvm::FInferType>(" FInferType" , CSRNeighborNonUniformSampleType)
899
927
.set_attr<FComputeEx>(" FComputeEx<cpu>" , CSRNeighborNonUniformSampleComputeExCPU)
928
+ .set_attr<FResourceRequest>(" FResourceRequest" , [](const NodeAttrs& attrs) {
929
+ return std::vector<ResourceRequest>{ResourceRequest::kRandom };
930
+ })
900
931
.add_argument(" csr_matrix" , " NDArray-or-Symbol" , " csr matrix" )
901
932
.add_argument(" probability" , " NDArray-or-Symbol" , " probability vector" )
902
933
.add_argument(" seed_arrays" , " NDArray-or-Symbol[]" , " seed vertices" )
0 commit comments