Skip to content

revamped envsampler #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 109 additions & 100 deletions src/Tools/envsampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@ struct TrajectoryBuffer{B1<:ElasticBuffer,B2<:ElasticBuffer}
terminal::B2
end

function TrajectoryBuffer(env::AbstractEnvironment; sizehint::Union{Integer,Nothing} = nothing, dtype::Maybe{DataType} = nothing)
function TrajectoryBuffer(
env::AbstractEnvironment;
sizehint::Maybe{Integer} = nothing,
dtype::Maybe{DataType} = nothing,
)
sp = dtype === nothing ? spaces(env) : adapt(dtype, spaces(env))

trajectory = ElasticBuffer(
states = sp.statespace,
observations = sp.obsspace,
actions = sp.actionspace,
rewards = sp.rewardspace,
evaluations = sp.evalspace
evaluations = sp.evalspace,
)

terminal = ElasticBuffer(
Expand All @@ -21,7 +25,7 @@ function TrajectoryBuffer(env::AbstractEnvironment; sizehint::Union{Integer,Noth
lengths = Int,
)

if !isnothing(sizehint)
if sizehint !== nothing
sizehint!(trajectory, sizehint)
sizehint!(terminal, sizehint)
end
Expand All @@ -36,18 +40,19 @@ struct EnvSampler{E<:AbstractEnvironment,B<:TrajectoryBuffer,BA}
batch::BA
end

function EnvSampler(env_tconstructor::Function; sizehint::Union{Integer,Nothing} = nothing, dtype::Maybe{DataType} = nothing)
function EnvSampler(
env_tconstructor::Function;
sizehint::Maybe{Integer} = nothing,
dtype::Maybe{DataType} = nothing,
)
envs = [e for e in env_tconstructor(Threads.nthreads())]
bufs = [TrajectoryBuffer(first(envs), sizehint=sizehint, dtype=dtype) for _ = 1:Threads.nthreads()]
batch = makebatch(first(envs), sizehint=sizehint, dtype=dtype)
EnvSampler(envs, bufs, batch)
end
bufs = [
TrajectoryBuffer(first(envs), sizehint = sizehint, dtype = dtype)
for _ = 1:Threads.nthreads()
]
batch = _makebatch(first(envs), sizehint = sizehint, dtype = dtype)

function emptybufs!(sampler::EnvSampler)
for buf in sampler.bufs
empty!(buf.trajectory)
empty!(buf.terminal)
end
EnvSampler(envs, bufs, batch)
end


Expand All @@ -60,20 +65,19 @@ function sample!(
nthreads::Integer = Threads.nthreads(),
copy::Bool = false,
)
nsamples > 0 || error("nsamples must be > 0")
(0 < Hmax <= nsamples) || error("Hmax must be 0 < Hmax <= nsamples")
(
0 < nthreads <= Threads.nthreads()
) || error("nthreads must b 0 < nthreads < Threads.nthreads()")
nsamples > 0 || error("`nsamples` must be > 0")
0 < Hmax <= nsamples || error("`Hmax` must be in range (0, `nsamples`]")
if !(0 < nthreads <= Threads.nthreads())
error("`nthreads` must be in range (0, Threads.nthreads()]")
end

atomiccount = Threads.Atomic{Int}(0)
atomicidx = Threads.Atomic{Int}(1)

nthreads = _defaultnthreads(nsamples, Hmax, nthreads)
emptybufs!(sampler)
_emptybufs!(sampler)

if nthreads == 1
_threadsample!(actionfn!, resetfn!, sampler, nsamples, Hmax)
# short circuit
_sample!(actionfn!, resetfn!, sampler, nsamples, Hmax)
else
@sync for _ = 1:nthreads
Threads.@spawn _threadsample!(
Expand All @@ -87,37 +91,33 @@ function sample!(
end
end

@assert sum(b -> length(b.trajectory), sampler.bufs) >= nsamples

collate!(sampler, nsamples, copy)
_collate!(sampler, nsamples, copy)
_checkbatch(sampler.batch, nsamples)
sampler.batch
end

function _threadsample!(actionfn!::F, resetfn!::G, sampler, nsamples, Hmax) where {F,G}
env = sampler.envs[Threads.threadid()]
buf = sampler.bufs[Threads.threadid()]
function _sample!(actionfn!::F, resetfn!::G, sampler, nsamples, Hmax) where {F,G}
env = first(sampler.envs)
buf = first(sampler.bufs)
traj = buf.trajectory
term = buf.terminal

resetfn!(env)
trajlength = n = 0
while n < nsamples
done = rolloutstep!(actionfn!, traj, env)
done = _rolloutstep!(actionfn!, traj, env)
trajlength += 1

if done || trajlength == Hmax
terminate!(term, env, trajlength, done)
resetfn!(env)
n += trajlength
_terminate_trajectory!(term, env, trajlength, done)
resetfn!(env)
trajlength = 0
end
end
sampler
nothing
end




function _threadsample!(
actionfn!::F,
resetfn!::G,
Expand All @@ -134,51 +134,51 @@ function _threadsample!(
resetfn!(env)
trajlength = 0
while true
done = rolloutstep!(actionfn!, traj, env)
done = _rolloutstep!(actionfn!, traj, env)
trajlength += 1

if atomiccount[] >= nsamples
break
elseif atomiccount[] + trajlength >= nsamples
Threads.atomic_add!(atomiccount, trajlength)
terminate!(term, env, trajlength, done)
break
elseif done || trajlength == Hmax
if done || trajlength == Hmax
Threads.atomic_add!(atomiccount, trajlength)
terminate!(term, env, trajlength, done)
resetfn!(env)
trajlength = 0
_terminate_trajectory!(term, env, trajlength, done)
if atomiccount[] >= nsamples
break
else
resetfn!(env)
trajlength = 0
end
end
end
sampler
nothing
end


function rolloutstep!(actionfn!::F, traj::ElasticBuffer, env::AbstractEnvironment) where {F}
function _rolloutstep!(actionfn!::F, traj::ElasticBuffer, env::AbstractEnvironment) where {F}
grow!(traj)
t = lastindex(traj)
@uviews traj begin
st, ot, at =
view(traj.states, :, t), view(traj.observations, :, t), view(traj.actions, :, t)
st = view(traj.states, :, t)
ot = view(traj.observations, :, t)
at = view(traj.actions, :, t)

getstate!(st, env)
getobs!(ot, env)
actionfn!(at, st, ot)

actionfn!(at, st, ot)
setaction!(env, at)
step!(env)

r = getreward(st, at, ot, env)
e = geteval(st, at, ot, env)
done = isdone(st, at, ot, env)
step!(env)

traj.rewards[t] = r
traj.evaluations[t] = e
return done
traj.rewards[t] = getreward(st, at, ot, env)
traj.evaluations[t] = geteval(st, at, ot, env)
return isdone(st, at, ot, env)
end
end

function terminate!(term::ElasticBuffer, env::AbstractEnvironment, trajlength::Integer, done::Bool)
function _terminate_trajectory!(
term::ElasticBuffer,
env::AbstractEnvironment,
trajlength::Integer,
done::Bool,
)
grow!(term)
i = lastindex(term)
@uviews term begin
Expand All @@ -191,9 +191,21 @@ function terminate!(term::ElasticBuffer, env::AbstractEnvironment, trajlength::I
term
end

function _emptybufs!(sampler::EnvSampler)
for buf in sampler.bufs
empty!(buf.trajectory)
empty!(buf.terminal)
end
nothing
end

function makebatch(env::AbstractEnvironment; sizehint::Union{Integer,Nothing} = nothing, dtype::Maybe{DataType} = nothing)
function _makebatch(
env::AbstractEnvironment;
sizehint::Maybe{Integer} = nothing,
dtype::Maybe{DataType} = nothing,
)
sp = dtype === nothing ? spaces(env) : adapt(dtype, spaces(env))

batch = (
states = BatchedArray(sp.statespace),
observations = BatchedArray(sp.obsspace),
Expand All @@ -204,68 +216,65 @@ function makebatch(env::AbstractEnvironment; sizehint::Union{Integer,Nothing} =
terminal_observations = BatchedArray(sp.obsspace),
dones = Vector{Bool}(),
)
!isnothing(sizehint) && foreach(el -> sizehint!(el, sizehint), batch)

sizehint !== nothing && foreach(el -> sizehint!(el, sizehint), batch)

batch
end

function collate!(sampler::EnvSampler, n::Integer, copy::Bool)
batch = copy ? map(deepcopy, sampler.batch) : sampler.batch
function _collate!(sampler::EnvSampler, N::Integer, copy::Bool)
batch = copy ? map(deepcopy, sampler.batch) : sampler.batch # TODO
for b in batch
empty!(b)
sizehint!(b, n)
sizehint!(b, N)
end

count = 0
maxlen = maximum(map(buf -> length(buf.terminal), sampler.bufs))
for i = 1:maxlen, (tid, buf) in enumerate(sampler.bufs)
count >= n && break

togo = N - count
for buf in sampler.bufs
trajbuf = buf.trajectory
termbuf = buf.terminal
length(termbuf) < i && continue
from = firstindex(trajbuf)

@uviews trajbuf termbuf for episode_idx in eachindex(termbuf)
togo = N - count
togo == 0 && return batch

len = termbuf.lengths[episode_idx]
if togo < len # we only want `N` samples
len = togo
# because we cropped this trajectory, it didn't actually terminate early
termbuf.dones[episode_idx] = false
end

@uviews trajbuf termbuf begin
from = i == 1 ? 1 : sum(view(termbuf.lengths, 1:i-1)) + 1
len = termbuf.lengths[i]
len = count + len > n ? n - count : len
to = from + len - 1
until = from + len
to = until - 1
count += len

traj = view(trajbuf, from:to)
term = view(termbuf, i)
term = view(termbuf, episode_idx)

push!(batch.states, traj.states)
push!(batch.observations, traj.observations)
push!(batch.actions, traj.actions)
push!(batch.rewards, traj.rewards)
push!(batch.evaluations, traj.evaluations)
push!(batch.terminal_states, reshape(term.states, (:, 1)))
push!(batch.terminal_observations, reshape(term.observations, (:, 1)))
push!(batch.dones, termbuf.dones[i])
push!(batch.terminal_states, reshape(term.states, (:, 1))) # TODO (:, 1)
push!(batch.terminal_observations, reshape(term.observations, (:, 1))) #TODO (:, 1)
push!(batch.dones, termbuf.dones[episode_idx])

count += length(from:to)
from = to + 1
from = until
end
end

nbatches = length(batch.dones)
@assert length(batch.terminal_states) ==
length(batch.terminal_observations) ==
length(batch.dones)
@assert nsamples(batch.states) == n
@assert nsamples(batch.observations) == n
@assert nsamples(batch.actions) == n
@assert nsamples(batch.rewards) == n
@assert nsamples(batch.evaluations) == n


batch
end


function _defaultnthreads(nsamples, Hmax, nthreads)
d, r = divrem(nsamples, Hmax)
if r > 0
d += 1
end
min(d, nthreads)
function _checkbatch(b, n)
@assert n == nsamples(b.states)
@assert n == nsamples(b.observations)
@assert n == nsamples(b.actions)
@assert n == nsamples(b.rewards)
@assert n == nsamples(b.evaluations)
@assert length(b.dones) == length(b.terminal_states) == length(b.terminal_observations)
nothing
end