FluxML/Flux.jl — issue #1704 (closed)

Description

MWE (minimal working example):
using Zygote, CUDA
# Turn any scalar indexing of a GPU array into an error instead of a slow fallback,
# so the problem below surfaces immediately.
CUDA.allowscalar(false)
W = CuArray(rand(4,4))
# Zygote's `OneElement`: presumably a lazy vector holding `1f0` at index 1 and zeros
# elsewhere, with the axes of a length-4 vector. It is a CPU-side object, not a CuArray.
x = Zygote.OneElement(1f0,(1,),axes(rand(4)))
# Multiplying the adjoint of the GPU matrix by the CPU-side OneElement falls back to
# scalar indexing, which errors because of `allowscalar(false)` above.
W' * x # Scalar indexing
The MWE above was reduced from the following script:
using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, CUDA, DiffEqSensitivity, Plots
# Initial state of the 2-component system, moved to the GPU.
u0 = [1.1; 1.1] |> gpu
tspan = (0.0f0,25.0f0)
# Neural network 2 -> 16 -> 16 -> 1; p1 is its flat parameter vector.
ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1))
p1 = initial_params(ann)
# Two coefficients used in the second component of `dudt_` (read there as `p[end-1:end]`).
p2 = Float32[0.5,-0.5]
p3 = [p1;p2]
# Full vector being differentiated: the initial state followed by all parameters.
# NOTE(review): `u0` is a GPU array and `p3` a CPU array at this point. Typed
# concatenation may index `u0` element-wise; this is scalar indexing is not disabled in
# this script. Confirm whether this is intentional.
θ = Float32[u0;p3]
# Right-hand side of the out-of-place ODE (used with `ODEProblem{false}` below).
#   u : 2-element state, unpacked as (x, y)
#   p : parameter vector laid out like `[p1; p2]`: the first `length(p1)` entries are
#       the network weights, the last two are the linear coefficients
#   t : time, unused here
# Returns a 2-element vector:
#   - first component: the first output of `ann`, evaluated on the GPU and brought back
#     to the CPU;
#   - second component: `p[end-1]*y + p[end]*x`.
function dudt_(u,p,t)
x, y = u
# Last two parameters (the linear coefficients), moved to the CPU.
pend = cpu(p[end-1:end])
# Debugging output: types of the network inputs and values of both components.
@show typeof(p[1:length(p1)])
@show typeof(gpu(u))
@show cpu(ann(gpu(u),p[1:length(p1)]))[1]
@show pend[1]*y + pend[2]*x
# The GPU -> CPU round trip in the first entry is presumably what leads to the
# scalar-indexing error during the adjoint pass. TODO confirm.
[cpu(ann(gpu(u),p[1:length(p1)]))[1],pend[1]*y + pend[2]*x]
end
# Out-of-place (`{false}`) ODE problem. `predict_adjoint` overrides `u0` and `p` on each
# call via `solve` keywords.
prob = ODEProblem{false}(dudt_,u0,tspan,p3)
"""
    predict_adjoint(θ)

Solve `prob` with `Tsit5`, using `θ[1:2]` (moved to the CPU) as the initial state and
`θ[3:end]` as the parameters. Save at t = 0, 1, …, 25 and use `QuadratureAdjoint` for
sensitivities. Return the solution converted to an `Array` and moved to the GPU.
"""
function predict_adjoint(θ)
    initial_state = cpu(θ[1:2])
    params = θ[3:end]
    sol = solve(prob, Tsit5();
                u0 = initial_state,
                p = params,
                saveat = 0.0:1:25.0,
                sensealg = QuadratureAdjoint())
    return gpu(Array(sol))
end
# Loss: sum of squared deviations of row 2 of the prediction (the second state
# component; columns presumably correspond to the save times) from 1.
loss_adjoint(θ) = sum(abs2,predict_adjoint(θ)[2,:].-1)
# Forward evaluation of the loss at the initial θ.
l = loss_adjoint(θ)
# Optimizer callback: print the current loss and never request an early stop
# (always returns `false`).
cb = (θ, l) -> begin
    println(l)
    false
end
# Forward evaluation of the loss (the same computation as `l` above).
loss1 = loss_adjoint(θ)
# Reverse-mode gradient of the loss with respect to θ. This is presumably the call that
# hits the scalar-indexing error reduced in the MWE. TODO confirm.
Zygote.gradient(loss_adjoint,θ)
Metadata
Assignees
Labels: no labels