Apologies if this is the wrong repo for this issue.

I was experimenting with `UnitSimplex` transformations in a DynamicHMC model and noticed that the transformation seems to be incompatible with AD via Zygote. This is a pity, since Zygote is by far the fastest backend for my model.

The error occurs because `transform_with` on `UnitSimplex` mutates an array.
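For reference, here is the map involved, assuming `UnitSimplex` follows Stan's stick-breaking construction (an assumption, not checked against the TransformVariables source). It takes $x \in \mathbb{R}^{n-1}$ to $y$ on the unit simplex in $\mathbb{R}^n$:

$$
z_i = \operatorname{logistic}\bigl(x_i - \log(n - i)\bigr), \qquad
y_i = z_i \Bigl(1 - \sum_{j<i} y_j\Bigr), \quad i = 1, \dots, n - 1, \qquad
y_n = 1 - \sum_{j<n} y_j,
$$

$$
\log \lvert J \rvert = \sum_{i=1}^{n-1} \Bigl[\log z_i + \log(1 - z_i) + \log\Bigl(1 - \sum_{j<i} y_j\Bigr)\Bigr].
$$

Each $y_i$ needs the running remainder of the stick. Filling a preallocated vector in a loop is therefore the natural implementation, and Zygote does not support that kind of in-place write.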
I've written a minimal working example based on a Multinomial model here:
```julia
using TransformVariables, LogDensityProblems, LogDensityProblemsAD, SimpleUnPack
using TransformedLogDensities  # provides TransformedLogDensity for LogDensityProblems ≥ 2

# Multinomial distribution
n_cats = 3;
data = [100, 200, 300];
variable_transform = as((
    ρ = UnitSimplex(n_cats),
))

struct log_density
    data::Array{Int, 1}
end

function (problem::log_density)(θ, ::Type{T} = Float64) where {T}
    @unpack data = problem  # extract the data
    @unpack ρ = θ           # extract the parameters
    # log likelihood (up to an additive constant)
    return sum(data .* log.(ρ))
end

problem = log_density(data);
transformed_model = TransformedLogDensity(variable_transform, problem);

# LD calculation works
pars = zeros(LogDensityProblems.dimension(transformed_model));
problem(TransformVariables.transform(variable_transform, pars))

# ForwardDiff works
import ForwardDiff
model_gradient_fwd = ADgradient(:ForwardDiff, transformed_model);
LogDensityProblems.logdensity_and_gradient(model_gradient_fwd, pars)

# Zygote doesn't work
import Zygote
model_gradient_zy = ADgradient(:Zygote, transformed_model);
LogDensityProblems.logdensity_and_gradient(model_gradient_zy, pars)
```
Would it be possible to implement some sort of workaround for `UnitSimplex`?
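One possible workaround is to write the same stick-breaking map without in-place writes, using cumulative sums and concatenation instead of filling a preallocated vector. This is only a sketch, not tested against TransformVariables or Zygote. `simplex_nomutate` and `sigmoid` are hypothetical names, and the Stan-style centering `x[i] - log(n - i)` is an assumption about what `UnitSimplex` does.

```julia
# Hypothetical, non-mutating stick-breaking map from ℝⁿ⁻¹ to the unit simplex in ℝⁿ.
# Assumes Stan-style centering x[i] - log(n - i), so x = zeros(n - 1) maps to the
# uniform simplex. Returns (y, logjac). Only broadcasting, cumsum, indexing and
# vcat are used, so no array is written in place.
sigmoid(t) = inv(one(t) + exp(-t))

function simplex_nomutate(x::AbstractVector)
    n = length(x) + 1
    n ≥ 2 || throw(ArgumentError("need at least one unconstrained coordinate"))
    # fraction of the remaining stick broken off at step i
    z = sigmoid.(x .- log.(n .- (1:n-1)))
    # log of the stick length left after each break: cumsum of log(1 - z)
    log_rest = cumsum(log1p.(-z))
    # log of the stick length available before each break (0 before the first)
    log_before = vcat(zero(eltype(log_rest)), log_rest[1:end-1])
    y = vcat(z .* exp.(log_before), exp(log_rest[end]))
    # triangular Jacobian: log|J| = Σ [log(stick before) + log z + log(1 - z)]
    logjac = sum(log_before) + sum(log.(z) .+ log1p.(-z))
    return y, logjac
end
```

To use it, the model could take an unconstrained vector instead, e.g. `as((x = as(Vector, n_cats - 1),))`. The log density would call `simplex_nomutate(θ.x)` and add the returned `logjac` to the log likelihood, because the Jacobian correction would no longer come from `UnitSimplex`. Growing the output with `vcat` costs some allocation, so whether Zygote still beats ForwardDiff here would need benchmarking.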