@@ -122,25 +122,29 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
122
122
struct SampleMultinomialKernel {
123
123
template <typename DType, typename IType>
124
124
MSHADOW_XINLINE static void Map (int i, index_t K, index_t M,
125
- DType* dist, float * uniform, IType* out,
126
- DType* prob) {
125
+ DType* dist, float * uniform, float * cum_table,
126
+ IType* out, DType* prob) {
127
+ double acc = 0.0 ;
128
+ // CDF table
129
+ for (index_t c = 0 ; c < K; ++c) {
130
+ acc += dist[i*K + c];
131
+ cum_table[i*K + c] = static_cast <float >(acc);
132
+ }
127
133
for (index_t j = 0 ; j < M; ++j) {
134
+ index_t left = 0 , right = K;
135
+ index_t middle = left + (right - left) / 2 ;
128
136
DType loc = static_cast <DType>(uniform[i*M + j]);
129
- DType acc = 0 ;
130
- bool found = false ;
131
- for (index_t k = 0 ; k < K; ++k) {
132
- acc += dist[i*K + k];
133
- if (acc > loc) {
134
- found = true ;
135
- out[i*M + j] = static_cast <IType>(k);
136
- if (prob != nullptr ) prob[i*M + j] = logf (dist[i*K + k]);
137
- break ;
137
+ while (right - left > 0 ) {
138
+ middle = left + (right - left) / 2 ;
139
+ DType cum_prob = cum_table[i*K + middle];
140
+ if (cum_prob < loc) {
141
+ left = middle + 1 ;
142
+ } else {
143
+ right = middle;
138
144
}
139
145
}
140
- if (!found) {
141
- out[i*M + j] = static_cast <IType>(K-1 );
142
- if (prob != nullptr ) prob[i*M + j] = logf (dist[i*K + K - 1 ]);
143
- }
146
+ out[i*M + j] = static_cast <IType>(left);
147
+ if (prob != nullptr ) prob[i*M + j] = logf (dist[i*K + left]);
144
148
}
145
149
}
146
150
};
@@ -163,12 +167,14 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
163
167
Stream<xpu> *s = ctx.get_stream <xpu>();
164
168
MSHADOW_REAL_TYPE_SWITCH (inputs[0 ].type_flag_ , DType, {
165
169
Random<xpu, float > *prnd = ctx.requested [0 ].get_random <xpu, float >(s);
166
- Tensor<xpu, 1 , float > uniform =
167
- ctx.requested [1 ].get_space_typed <xpu, 1 , float >(Shape1 (N*M), s);
170
+ Tensor<xpu, 1 , float > workspace =
171
+ ctx.requested [1 ].get_space_typed <xpu, 1 , float >(Shape1 (N*M + N*K), s);
172
+ Tensor<xpu, 1 , float > uniform (workspace.dptr_ , Shape1 (N*M));
168
173
prnd->SampleUniform (&uniform, 0 , 1 );
169
174
MSHADOW_TYPE_SWITCH (outputs[0 ].type_flag_ , IType, {
170
175
Kernel<SampleMultinomialKernel, xpu>::Launch (
171
- s, N, K, M, inputs[0 ].dptr <DType>(), uniform.dptr_ , outputs[0 ].dptr <IType>(),
176
+ s, N, K, M, inputs[0 ].dptr <DType>(), uniform.dptr_ , workspace.dptr_ + N*M,
177
+ outputs[0 ].dptr <IType>(),
172
178
param.get_prob ? outputs[1 ].dptr <DType>() : nullptr );
173
179
});
174
180
});
0 commit comments