Skip to content

Commit 3285365

Browse files
authored
feat: add mutlinomial sampling for tensors (#88)
This change adss a `multinomial` method to the `Tensor` class, allowing you to draw samples from a multinomial distribution. The multinomial method accepts a tensor `input` and will produce an output `tensor` that samples the `input` probabilities. ``` Num::Rand.set_seed(0) input = [[0.5, 0.5], [0.5, 0.5]].to_tensor a = Tensor.multinomial(input, 5) puts a # => [[0, 1, 1, 0, 1], [1, 0, 1, 1, 0]] input2 = [0.5, 0.5, 0.5, 0.5].to_tensor b = Tensor.multinomial(input, 6) puts b # => [3, 2, 1, 1, 0, 2] ``` The logic of this method is based on the equivalent `torch.multinomial` method: https://pytorch.org/docs/stable/generated/torch.multinomial.html Signed-off-by: Lucian Buzzo <[email protected]>
1 parent 9c91420 commit 3285365

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

spec/tensor/random_spec.cr

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) 2023 Crystal Data Contributors
2+
#
3+
# MIT License
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining
6+
# a copy of this software and associated documentation files (the
7+
# "Software"), to deal in the Software without restriction, including
8+
# without limitation the rights to use, copy, modify, merge, publish,
9+
# distribute, sublicense, and/or sell copies of the Software, and to
10+
# permit persons to whom the Software is furnished to do so, subject to
11+
# the following conditions:
12+
#
13+
# The above copyright notice and this permission notice shall be
14+
# included in all copies or substantial portions of the Software.
15+
#
16+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19+
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20+
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21+
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22+
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23+
require "../spec_helper"
24+
25+
describe Tensor do
26+
it "can sample a multinomial distribution" do
27+
input = [[0.4, 0.6], [0.5, 0.5]].to_tensor
28+
a = Tensor.multinomial(input, 5)
29+
b = [[0, 1, 0, 1, 1], [0, 1, 1, 0, 0]].to_tensor
30+
Num::Testing.tensor_equal(a, b).should be_true
31+
end
32+
33+
it "can sample a multinomial distribution using a 1D input" do
34+
input = [0.2, 0.1, 0.3, 0.4].to_tensor
35+
a = Tensor.multinomial(input, 6)
36+
b = [2, 3, 3, 0, 3, 2].to_tensor
37+
Num::Testing.tensor_equal(a, b).should be_true
38+
end
39+
40+
it "can sample a multinomial distribution using a non-normalized input" do
41+
input = [6, 3, 9, 12].to_tensor
42+
a = Tensor.multinomial(input, 6)
43+
b = [1, 2, 3, 3, 2, 2].to_tensor
44+
Num::Testing.tensor_equal(a, b).should be_true
45+
end
46+
end

src/tensor/random.cr

+92
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,96 @@ class Tensor(T, S)
378378
success
379379
end
380380
end
381+
382+
# Draw samples from a multinomial distribution.
383+
# Returns a Tensor where each row contains `num_samples` samples from the multinomial distribution
384+
# located in the corresponding row of Tensor `input`.
385+
# The rows of `input` do not need to be normalized, but must sum to a positive number.
386+
# If `input` is a vector (1-D Tensor), returns a vector of length `num_samples`
387+
# If `input` is a matrix (2-D Tensor), returns a matrix where each row contains `num_samples` samples, with shape (*m* x `num_samples`).
388+
#
389+
# ## Arguments
390+
#
391+
# * input : `Tensor` - Tensor containing probabilities of different outcomes
392+
# * num_samples : `Int` - Number of samples to draw from the multinomial distribution
393+
#
394+
# ## Examples
395+
#
396+
# ```
397+
# Num::Rand.set_seed(0)
398+
# input = [[0.5, 0.5], [0.5, 0.5]].to_tensor
399+
# a = Tensor.multinomial(input, 5)
400+
# puts a # => [[0, 1, 1, 0, 1], [1, 0, 1, 1, 0]]
401+
402+
# input2 = [0.5, 0.5, 0.5, 0.5].to_tensor
403+
# b = Tensor.multinomial(input, 6)
404+
# puts b # => [3, 2, 1, 1, 0, 2]
405+
# ```
406+
def self.multinomial(input : Tensor(T, S), num_samples : Int32)
407+
sum = input.sum
408+
409+
if sum == 0
410+
raise "Sum of probabilities is 0, can't draw samples"
411+
end
412+
413+
# Normalize 1D tensors into 2D tensors
414+
if input.shape.size == 1
415+
input = input.expand_dims(0)
416+
end
417+
418+
# Normalize the probabilities
419+
probabilities = input / input.sum(axis: 1, dims: true)
420+
421+
samples = [] of Array(Int32)
422+
423+
probabilities.each_axis(0) do |p_row|
424+
sample_set = [] of Int32
425+
num_samples.times do
426+
rand_num = Num::Rand.generator.float32
427+
428+
# Calculate the cumulative probabilities
429+
cumulative_prob = 0.0
430+
431+
# default to return the last probability
432+
s_index = p_row.size - 1
433+
434+
# Loop through the probabilities
435+
p_row.each_with_index do |prob, index|
436+
cumulative_prob += prob
437+
if rand_num <= cumulative_prob
438+
s_index = index
439+
break
440+
end
441+
end
442+
443+
sample_set << s_index
444+
end
445+
446+
samples << sample_set
447+
end
448+
449+
# If the input is a vector, return a vector of size num_samples
450+
if input.shape[0] == 1
451+
samples[0].to_tensor
452+
else
453+
samples.to_tensor
454+
end
455+
end
456+
457+
private def self.draw_sample(probabilities : Array(Float64))
458+
# Generate a random number between 0 and 1
459+
rand_num = Random.new.rand
460+
461+
# Calculate the cumulative probabilities
462+
cumulative_prob = 0.0
463+
464+
# Loop through the probabilities
465+
probabilities.each_with_index do |prob, index|
466+
cumulative_prob += prob
467+
return index if rand_num <= cumulative_prob
468+
end
469+
470+
# If no index has been returned, return the last one
471+
probabilities.size - 1
472+
end
381473
end

0 commit comments

Comments
 (0)