Skip to content

Commit 0151911

Browse files
authored
Add parallel keyword to GMM (#114)
1 parent 9fd2dd0 commit 0151911

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/train.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@ GMM(x::Vector{T}) where T <: AbstractFloat = GMM(reshape(x, length(x), 1)) # st
2828

2929
## constructors based on data or matrix
3030
function GMM(n::Int, x::DataOrMatrix{T}; method::Symbol=:kmeans, kind=:diag,
31-
nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0) where T <: AbstractFloat
31+
nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0, parallel::Bool=true) where T <: AbstractFloat
3232
if n < 2
3333
return GMM(x, kind=kind)
3434
elseif method == :split
35-
return GMM2(n, x, kind=kind, nIter=nIter, nFinal=nFinal, sparse=sparse)
35+
return GMM2(n, x, kind=kind, nIter=nIter, nFinal=nFinal, sparse=sparse, parallel=parallel)
3636
elseif method == :kmeans
37-
return GMMk(n, x, kind=kind, nInit=nInit, nIter=nIter, sparse=sparse)
37+
return GMMk(n, x, kind=kind, nInit=nInit, nIter=nIter, sparse=sparse, parallel=parallel)
3838
else
3939
error("Unknown method ", method)
4040
end
4141
end
4242
## a 1-dimensional Gaussian can be initialized with a vector, skip kind=
43-
GMM(n::Int, x::Vector{T}; method::Symbol=:kmeans, nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0) where T <: AbstractFloat = GMM(n, reshape(x, length(x), 1); method=method, kind=:diag, nInit=nInit, nIter=nIter, nFinal=nFinal, sparse=sparse)
43+
GMM(n::Int, x::Vector{T}; method::Symbol=:kmeans, nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0, parallel::Bool=true) where T <: AbstractFloat = GMM(n, reshape(x, length(x), 1); method=method, kind=:diag, nInit=nInit, nIter=nIter, nFinal=nFinal, sparse=sparse, parallel=parallel)
4444

4545
## we sometimes end up with pathological gmms
4646
function sanitycheck!(gmm::GMM)
@@ -73,7 +73,7 @@ end
7373

7474

7575
## initialize GMM using Clustering.kmeans (which uses a method similar to kmeans++)
76-
function GMMk(n::Int, x::DataOrMatrix{T}; kind=:diag, nInit::Int=50, nIter::Int=10, sparse=0) where T <: AbstractFloat
76+
function GMMk(n::Int, x::DataOrMatrix{T}; kind=:diag, nInit::Int=50, nIter::Int=10, sparse=0, parallel::Bool=true) where T <: AbstractFloat
7777
nₓ, d = size(x)
7878
hist = [History(@sprintf("Initializing GMM, %d Gaussians %s covariance %d dimensions using %d data points", n, diag, d, nₓ))]
7979
@info(last(hist).s)
@@ -141,22 +141,22 @@ function GMMk(n::Int, x::DataOrMatrix{T}; kind=:diag, nInit::Int=50, nIter::Int=
141141
@info(last(hist).s)
142142
gmm = GMM(w, μ, Σ, hist, nxx)
143143
sanitycheck!(gmm)
144-
em!(gmm, x; nIter=nIter, sparse=sparse)
144+
em!(gmm, x; nIter=nIter, sparse=sparse, parallel=parallel)
145145
return gmm
146146
end
147147

148148
## Train a GMM by consecutively splitting all means. n most be a power of 2
149149
## This kind of initialization is deterministic, but doesn't work particularily well, its seems
150150
## We start with one Gaussian, and consecutively split.
151-
function GMM2(n::Int, x::DataOrMatrix; kind=:diag, nIter::Int=10, nFinal::Int=nIter, sparse=0)
151+
function GMM2(n::Int, x::DataOrMatrix; kind=:diag, nIter::Int=10, nFinal::Int=nIter, sparse=0, parallel::Bool=true)
152152
log2n = round(Int,log2(n))
153153
2^log2n == n || error("n must be power of 2")
154154
gmm = GMM(x, kind=kind)
155155
tll = [avll(gmm, x)]
156156
@info("0: avll = ", tll[1])
157157
for i in 1:log2n
158158
gmm = gmmsplit(gmm)
159-
avll = em!(gmm, x; nIter=(i==log2n ? nFinal : nIter), sparse=sparse)
159+
avll = em!(gmm, x; nIter=(i==log2n ? nFinal : nIter), sparse=sparse, parallel=parallel)
160160
@info(i, avll)
161161
append!(tll, avll)
162162
end
@@ -235,7 +235,7 @@ end
235235
# the log-likelihood history, per data frame per dimension
236236
## Note: 0 iterations is allowed, this just computes the average log likelihood
237237
## of the data and stores this in the history.
238-
function em!(gmm::GMM, x::DataOrMatrix; nIter::Int = 10, varfloor::Float64=1e-3, sparse=0, debug=1)
238+
function em!(gmm::GMM, x::DataOrMatrix; nIter::Int = 10, varfloor::Float64=1e-3, sparse=0, parallel::Bool=true, debug=1)
239239
size(x,2)==gmm.d || error("Inconsistent size gmm and x")
240240
d = gmm.d # dim
241241
ng = gmm.n # n gaussians
@@ -247,7 +247,7 @@ function em!(gmm::GMM, x::DataOrMatrix; nIter::Int = 10, varfloor::Float64=1e-3,
247247

248248
for i in 1:nIter
249249
## E-step
250-
nₓ, ll[i], N, F, S = stats(gmm, x, parallel=true)
250+
nₓ, ll[i], N, F, S = stats(gmm, x, parallel=parallel)
251251
## M-step
252252
gmm.w = N / nₓ
253253
gmm.μ = F ./ N

0 commit comments

Comments
 (0)