
Commit 6692b45

Env interface overhaul (#6)
1. Updates LyceumAI to the new env interface (Lyceum/LyceumBase.jl#7)
2. Adds an NPG test on PointMass
1 parent c0a3ae1 commit 6692b45
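For orientation, the new interface used throughout this commit separates stepping from reward and evaluation queries: the environment is advanced with setaction! and step!, and getreward, geteval, and isdone are then called with explicit (state, action, observation) buffers. The sketch below is illustrative only; rollout_sketch and controller! are hypothetical names, while the accessor functions are the ones that appear in the diffs that follow.

using LyceumBase

# Minimal rollout sketch under the new env interface (not part of this commit).
function rollout_sketch(env, controller!, T)
    s = allocate(statespace(env))    # state buffer
    o = allocate(obsspace(env))      # observation buffer
    a = allocate(actionspace(env))   # action buffer
    total = 0.0
    for t = 1:T
        getstate!(s, env)
        getobs!(o, env)
        controller!(a, s, o)         # fill the action buffer in place
        setaction!(env, a)
        step!(env)
        total += getreward(s, a, o, env)
        isdone(s, a, o, env) && break
    end
    return total
end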

File tree

10 files changed, +265 −229 lines changed


Manifest.toml

+122 −85
Large diffs are not rendered by default.

Project.toml

+2 −1

@@ -14,6 +14,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
 LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LyceumBase = "db31fed1-ca1e-4084-8a49-12fae1996a55"
+LyceumMuJoCo = "48b9757e-04b8-4dbf-b6ed-75c13d9e4026"
 MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 MuJoCo = "93189219-7048-461c-94ec-443a161ed927"
@@ -28,8 +29,8 @@ UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
-julia = "1.3"
 Flux = "0.10"
+julia = "1.3"

 [extras]
 LyceumMuJoCo = "48b9757e-04b8-4dbf-b6ed-75c13d9e4026"

src/algorithms/MPPI.jl

+24 −13

@@ -1,5 +1,5 @@
 # Algorithm 2 from https://www.cc.gatech.edu/~bboots3/files/InformationTheoreticMPC.pdf
-struct MPPI{DT,nu,C<:AbstractMatrix{DT},V,E,F,O}
+struct MPPI{DT,nu,C<:AbstractMatrix{DT},V,E,F,O,S}
     # MPPI parameters
     K::Int
     H::Int
@@ -15,10 +15,11 @@ struct MPPI{DT,nu,C<:AbstractMatrix{DT},V,E,F,O}
     covar_ul::UpperTriangular{DT,C}
     meantrajectory::Matrix{DT}
     trajectorycosts::Vector{DT}
-    observationbuffers::Vector{O}
+    obsbuffers::Vector{O}
+    statebuffers::Vector{S}

     function MPPI{DT}(
-        sharedmemory_envctor,
+        env_tconstructor,
         K::Integer,
         H::Integer,
         covar0::AbstractMatrix{<:Real},
@@ -27,10 +28,11 @@ struct MPPI{DT,nu,C<:AbstractMatrix{DT},V,E,F,O}
         valuefn,
         initfn!,
     ) where {DT<:AbstractFloat}
-        envs = [e for e in sharedmemory_envctor(Threads.nthreads())]
+        envs = [e for e in env_tconstructor(Threads.nthreads())]

+        ssp = statespace(first(envs))
         asp = actionspace(first(envs))
-        osp = observationspace(first(envs))
+        osp = obsspace(first(envs))

         nd, elt = ndims(asp), eltype(asp)
         if nd != 1 || !(elt <: AbstractFloat)
@@ -55,7 +57,8 @@ struct MPPI{DT,nu,C<:AbstractMatrix{DT},V,E,F,O}
         meantrajectory = zeros(DT, asp, H)
         trajectorycosts = zeros(DT, K)
         noise = zeros(DT, asp, H, K)
-        observationbuffers = [allocate(osp) for _ = 1:Threads.nthreads()]
+        obsbuffers = [allocate(osp) for _ = 1:Threads.nthreads()]
+        statebuffers = [allocate(ssp) for _ = 1:Threads.nthreads()]

         new{
             DT,
@@ -64,7 +67,8 @@ struct MPPI{DT,nu,C<:AbstractMatrix{DT},V,E,F,O}
             typeof(valuefn),
             eltype(envs),
             typeof(initfn!),
-            eltype(observationbuffers),
+            eltype(obsbuffers),
+            eltype(statebuffers),
         }(
             K,
             H,
@@ -78,14 +82,15 @@ struct MPPI{DT,nu,C<:AbstractMatrix{DT},V,E,F,O}
             covar_ul,
             meantrajectory,
             trajectorycosts,
-            observationbuffers,
+            obsbuffers,
+            statebuffers
         )
     end
 end

 function MPPI(;
     dtype = Float64,
-    sharedmemory_envctor,
+    env_tconstructor,
     covar0,
     lambda,
     K,
@@ -94,7 +99,7 @@ function MPPI(;
     valuefn = zerofn,
     initfn! = default_initfn!,
 )
-    MPPI{dtype}(sharedmemory_envctor, K, H, covar0, lambda, gamma, valuefn, initfn!)
+    MPPI{dtype}(env_tconstructor, K, H, covar0, lambda, gamma, valuefn, initfn!)
 end

 LyceumBase.reset!(m::MPPI) = (fill!(m.meantrajectory, 0); m)
@@ -142,20 +147,26 @@ end

 function perturbedrollout!(m::MPPI{DT,nu}, state, k, tid) where {DT,nu}
     env = m.envs[tid]
-    obsbuf = m.observationbuffers[tid]
+    obsbuf = m.obsbuffers[tid]
+    statebuf = m.statebuffers[tid]
     mean = m.meantrajectory
     noise = m.noise

-    reset!(env, state)
+    setstate!(env, state)
     discountedreward = zero(DT)
     discountfactor = one(DT)
     @uviews mean noise @inbounds for t = 1:m.H
         mean_t = SVector{nu,DT}(view(mean, :, t))
         noise_tk = SVector{nu,DT}(view(noise, :, t, k))
         action_t = mean_t + noise_tk
         setaction!(env, action_t)
+
         step!(env)
-        reward = getreward(env)
+
+        getobs!(obsbuf, env)
+        getstate!(statebuf, env)
+        reward = getreward(statebuf, action_t, obsbuf, env)
+
         discountedreward += reward * discountfactor
         discountfactor *= m.gamma
     end # env at t=H+1
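The struct changes above replace the single observationbuffers field with per-thread obsbuffers and statebuffers, because the new getreward call needs the current state as well as the observation. A brief sketch of that per-thread allocation pattern, assuming LyceumMuJoCo is loaded and that tconstruct(etype, n) returns n independently steppable environments, as in the tests added below:

using LyceumBase, LyceumMuJoCo

# One env and one pair of scratch buffers per thread, so perturbed rollouts can
# run in parallel without allocating inside the hot loop.
envs = [e for e in tconstruct(LyceumMuJoCo.PointMass, Threads.nthreads())]
ssp, osp = statespace(first(envs)), obsspace(first(envs))
statebuffers = [allocate(ssp) for _ = 1:Threads.nthreads()]
obsbuffers = [allocate(osp) for _ = 1:Threads.nthreads()]

# After step!(env) on thread tid, the reward is computed from explicit buffers:
#   getstate!(statebuffers[tid], env)
#   getobs!(obsbuffers[tid], env)
#   reward = getreward(statebuffers[tid], action, obsbuffers[tid], env)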

src/algorithms/NPG.jl

+2 −2

@@ -27,7 +27,7 @@ struct NaturalPolicyGradient{DT,S,P,V,VF,CB}
     returns_vec::Vector{DT} # N

     function NaturalPolicyGradient(
-        env_ctor,
+        env_tconstructor,
         policy,
         value,
         valuefit!;
@@ -70,7 +70,7 @@ struct NaturalPolicyGradient{DT,S,P,V,VF,CB}
             DT = DTnew
         end

-        envsampler = EnvSampler(env_ctor, dtype=DT)
+        envsampler = EnvSampler(env_tconstructor, dtype=DT)

         z(d...) = zeros(DT, d...)
         new{
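The only change here is the rename of env_ctor to env_tconstructor, matching the convention used in MPPI above: the argument is a thread-aware constructor that, given an integer n, yields n environment instances, one per sampling thread. Any callable of that shape should work; a minimal hypothetical example (MyEnv is a stand-in for a real type such as LyceumMuJoCo.PointMass):

# Sketch only; not the actual LyceumBase API surface.
struct MyEnv end
my_tconstructor(n) = [MyEnv() for _ = 1:n]

envs = my_tconstructor(Threads.nthreads())
@assert length(envs) == Threads.nthreads()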

src/controller.jl

+9 −4

@@ -8,16 +8,16 @@ struct ControllerIterator{C,E,B}

     function ControllerIterator(
         controller,
-        env::AbstractEnv;
+        env::AbstractEnvironment;
         T = 1000,
         plotiter = 1,
     )
         trajectory = (
             states = Array(undef, statespace(env), T),
-            observations = Array(undef, observationspace(env), T),
+            observations = Array(undef, obsspace(env), T),
             actions = Array(undef, actionspace(env), T),
             rewards = Array(undef, rewardspace(env), T),
-            evaluations = Array(undef, evaluationspace(env), T),
+            evaluations = Array(undef, evalspace(env), T),
         )
         new{typeof(controller),typeof(env),typeof(trajectory)}(
             controller,
@@ -59,8 +59,13 @@ function rolloutstep!(controller, traj, env, t)
     getstate!(st, env)
     getobs!(ot, env)
     getaction!(at, st, ot, controller)
+    setaction!(env, at)
+
+    step!(env)
+    r = getreward(st, at, ot, env)
+    e = geteval(st, at, ot, env)
+    done = isdone(st, at, ot, env)

-    r, e, done = step!(env, at)
     traj.rewards[t] = r
     traj.evaluations[t] = e

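The rolloutstep! change above is the heart of the interface overhaul: the old step!(env, at), which returned (reward, eval, done) in one call, is decomposed into setaction! plus step!, followed by explicit getreward, geteval, and isdone queries taking (state, action, observation, env). A hedged sketch of what an environment might implement under this convention, assuming AbstractEnvironment and these generic functions live in LyceumBase as the diff suggests (PointEnv and its reward/eval definitions are hypothetical):

using LyceumBase

# Hypothetical environment; a real one would also implement statespace, obsspace,
# actionspace, getstate!, getobs!, setaction!, step!, and related accessors.
struct PointEnv <: AbstractEnvironment end

LyceumBase.getreward(state, action, obs, env::PointEnv) = -sum(abs2, obs)     # e.g. squared distance penalty
LyceumBase.geteval(state, action, obs, env::PointEnv) = sqrt(sum(abs2, obs))  # e.g. distance-to-goal metric
LyceumBase.isdone(state, action, obs, env::PointEnv) = false                  # fixed-horizon task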
test/NPG.jl

−92
This file was deleted.

test/algorithms/MPPI.jl

+22

@@ -0,0 +1,22 @@
+@testset "MPPI (PointMass)" begin
+    seed_threadrngs!(1)
+    etype = LyceumMuJoCo.PointMass
+    env = etype()
+    T = 300
+    K = 8
+    H = 10
+
+    mppi = MPPI(
+        env_tconstructor = n -> tconstruct(etype, n),
+        covar0 = Diagonal(0.1^2*I, size(actionspace(env), 1)),
+        lambda = 0.01,
+        K = K,
+        H = H,
+        gamma = 0.99
+    )
+    env = testrollout(env, T) do a, s, o
+        getaction!(a, s, o, mppi)
+    end
+    @test abs(geteval(env)) < 0.001
+end
+

test/algorithms/NPG.jl

+48

@@ -0,0 +1,48 @@
+@testset "NPG (PointMass)" begin
+    seed_threadrngs!(1)
+    etype = LyceumMuJoCo.PointMass
+
+    e = etype()
+    dobs, dact = length(obsspace(e)), length(actionspace(e))
+
+    DT = Float32
+    Hmax, K = 300, 16
+    N = Hmax * K
+
+    policy = DiagGaussianPolicy(
+        multilayer_perceptron(dobs, 32, 32, dact, σ=tanh),
+        zeros(dact)
+    )
+    policy = Flux.paramtype(DT, policy)
+
+    value = multilayer_perceptron(dobs, 32, 32, 1, σ=Flux.relu)
+    valueloss(bl, X, Y) = mse(vec(bl(X)), vec(Y))
+
+    valuetrainer = FluxTrainer(
+        optimiser = ADAM(1e-2),
+        szbatch = 32,
+        lossfn = valueloss,
+        stopcb = s->s.nepochs > 4
+    )
+    value = Flux.paramtype(DT, value)
+
+    npg = NaturalPolicyGradient(
+        n -> tconstruct(etype, n),
+        policy,
+        value,
+        gamma = 0.95,
+        gaelambda = 0.99,
+        valuetrainer,
+        Hmax = Hmax,
+        norm_step_size = 0.05,
+        N = N,
+    )
+
+    meanterminal_eval = nothing
+    for (i, state) in enumerate(npg)
+        i > 30 && break
+        meanterminal_eval = state.meanterminal_eval
+    end
+
+    @test meanterminal_eval < 0.1
+end
