Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit e6fad30

Browse files
xziyawkcn
authored andcommitted
Efficient MXNet sampling in the multinomial distribution (#15311)
* Effective multinomial * Meaningful uniform data pointer as input * Remove beginning Zeros from CDFs * Double precision for accumulated var
1 parent b4ce4e7 commit e6fad30

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

src/operator/random/sample_multinomial_op.h

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -122,25 +122,29 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
122122
struct SampleMultinomialKernel {
123123
template<typename DType, typename IType>
124124
MSHADOW_XINLINE static void Map(int i, index_t K, index_t M,
125-
DType* dist, float* uniform, IType* out,
126-
DType* prob) {
125+
DType* dist, float* uniform, float* cum_table,
126+
IType* out, DType* prob) {
127+
double acc = 0.0;
128+
// CDF table
129+
for (index_t c = 0; c < K; ++c) {
130+
acc += dist[i*K + c];
131+
cum_table[i*K + c] = static_cast<float>(acc);
132+
}
127133
for (index_t j = 0; j < M; ++j) {
134+
index_t left = 0, right = K;
135+
index_t middle = left + (right - left) / 2;
128136
DType loc = static_cast<DType>(uniform[i*M + j]);
129-
DType acc = 0;
130-
bool found = false;
131-
for (index_t k = 0; k < K; ++k) {
132-
acc += dist[i*K + k];
133-
if (acc > loc) {
134-
found = true;
135-
out[i*M + j] = static_cast<IType>(k);
136-
if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + k]);
137-
break;
137+
while (right - left > 0) {
138+
middle = left + (right - left) / 2;
139+
DType cum_prob = cum_table[i*K + middle];
140+
if (cum_prob < loc) {
141+
left = middle + 1;
142+
} else {
143+
right = middle;
138144
}
139145
}
140-
if (!found) {
141-
out[i*M + j] = static_cast<IType>(K-1);
142-
if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + K - 1]);
143-
}
146+
out[i*M + j] = static_cast<IType>(left);
147+
if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + left]);
144148
}
145149
}
146150
};
@@ -163,12 +167,14 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
163167
Stream<xpu> *s = ctx.get_stream<xpu>();
164168
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
165169
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
166-
Tensor<xpu, 1, float> uniform =
167-
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M), s);
170+
Tensor<xpu, 1, float> workspace =
171+
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M + N*K), s);
172+
Tensor<xpu, 1, float> uniform(workspace.dptr_, Shape1(N*M));
168173
prnd->SampleUniform(&uniform, 0, 1);
169174
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
170175
Kernel<SampleMultinomialKernel, xpu>::Launch(
171-
s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<IType>(),
176+
s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, workspace.dptr_ + N*M,
177+
outputs[0].dptr<IType>(),
172178
param.get_prob ? outputs[1].dptr<DType>() : nullptr);
173179
});
174180
});

0 commit comments

Comments
 (0)