
Commit 59d34ae

implemented initial distributions and terminal states (#2)
1 parent e0d2596 commit 59d34ae

2 files changed: 52 additions & 23 deletions


src/discrete_explicit.jl

Lines changed: 43 additions & 23 deletions
@@ -1,4 +1,4 @@
-struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
+struct DiscreteExplicitPOMDP{S,A,O,OF,RF,D} <: POMDP{S,A,O}
     s::Vector{S}
     a::Vector{A}
     o::Vector{O}
@@ -10,16 +10,20 @@ struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
     amap::Dict{A,Int}
     omap::Dict{O,Int}
     discount::Float64
+    initial::D
+    terminals::Set{S}
 end

-struct DiscreteExplicitMDP{S,A,RF} <: MDP{S,A}
+struct DiscreteExplicitMDP{S,A,RF,D} <: MDP{S,A}
     s::Vector{S}
     a::Vector{A}
     tds::Dict{Tuple{S,A}, SparseCat{Vector{S}, Vector{Float64}}}
     r::RF
     smap::Dict{S,Int}
     amap::Dict{A,Int}
     discount::Float64
+    initial::D
+    terminals::Set{S}
 end

 const DEP = DiscreteExplicitPOMDP
@@ -42,38 +46,51 @@ POMDPs.transition(m::DE, s, a) = m.tds[s,a]
 POMDPs.observation(m::DEP, a, sp) = m.ods[a,sp]
 POMDPs.reward(m::DE, s, a) = m.r(s, a)

-POMDPs.initialstate_distribution(m::DEP) = uniform_belief(m)
-# XXX hack
-POMDPs.initialstate_distribution(m::DiscreteExplicitMDP) = uniform_belief(FullyObservablePOMDP(m))
+POMDPs.initialstate_distribution(m::DE) = m.initial
+
+POMDPs.isterminal(m::DE,s) = s in m.terminals
+
+#=
+POMDPs.convert_s(::Type{V}, s::W, m::DE) where {V<:AbstractArray,W<:AbstractArray} =
+POMDPs.convert_s(::Type{V}, s::W, m::DE) where {V<:AbstractVector} = convert_to_vec(V, s, m.smap)
+POMDPs.convert_s(::Type{V}, s::W, m::DE) where {V<:AbstractVector} = convert_to_vec(V, s, m.smap)
+
+POMDPs.convert_s(::Type{V}, s, m::DE) where {V<:AbstractArray} = convert_to_vec(V, s, m.smap)
+POMDPs.convert_a(::Type{V}, a, m::DE) where {V<:AbstractArray} = convert_to_vec(V, a, m.amap)
+POMDPs.convert_o(::Type{V}, o, m::DEP) where {V<:AbstractArray} = convert_to_vec(V, o, m.omap)
+POMDPs.convert_s(::Type{}
+=#
+
+#=
+convert_to_vec(V, x, map) = convert(V, [map[x]])
+convert_from_vec(T, v, space) = convert(T, space[convert(Integer, first(v))])
+=#

 POMDPModelTools.ordered_states(m::DE) = m.s
 POMDPModelTools.ordered_actions(m::DE) = m.a
 POMDPModelTools.ordered_observations(m::DEP) = m.o

-# TODO reward(m, s, a)
-# TODO support O(s, a, sp, o)
-# TODO initial state distribution
-# TODO convert_s, etc, dimensions
-# TODO better errors if T or Z return something unexpected
-
 """
-    DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ)
+    DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,[b₀],[terminal=Set()])

 Create a POMDP defined by the tuple (S,A,O,T,Z,R,γ).

 # Arguments

+## Required
 - `S`,`A`,`O`: State, action, and observation spaces (typically `Vector`s)
 - `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
 - `Z::Function`: Observation probability distribution function; ``Z(a, s', o)`` is the probability of receiving observation ``o`` when state ``s'`` is reached after action ``a``.
 - `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
 - `γ::Float64`: Discount factor.

-# Notes
-- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
-- Terminal states are not yet supported, but absorbing states with zero reward can be used.
+## Optional
+- `b₀=Uniform(S)`: Initial belief/state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
+
+## Keyword
+- `terminal=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
 """
-function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
+function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount, b0=Uniform(s); terminal=Set())
     ss = vec(collect(s))
     as = vec(collect(a))
     os = vec(collect(o))
@@ -107,7 +124,7 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
         Dict(ss[i]=>i for i in 1:length(ss)),
         Dict(as[i]=>i for i in 1:length(as)),
         Dict(os[i]=>i for i in 1:length(os)),
-        discount
+        discount, b0, terminal
     )

     probability_check(m)
@@ -116,22 +133,25 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
 end

 """
-    DiscreteExplicitMDP(S,A,T,R,γ)
+    DiscreteExplicitMDP(S,A,T,R,γ,[p₀])

 Create an MDP defined by the tuple (S,A,T,R,γ).

 # Arguments

+## Required
 - `S`,`A`: State and action spaces (typically `Vector`s)
 - `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
 - `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
 - `γ::Float64`: Discount factor.

-# Notes
-- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
-- Terminal states are not yet supported, but absorbing states with zero reward can be used.
+## Optional
+- `p₀=Uniform(S)`: Initial state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
+
+## Keyword
+- `terminal=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
 """
-function DiscreteExplicitMDP(s, a, t, r, discount)
+function DiscreteExplicitMDP(s, a, t, r, discount, p0=Uniform(s); terminal=Set())
     ss = vec(collect(s))
     as = vec(collect(a))

@@ -141,7 +161,7 @@ function DiscreteExplicitMDP(s, a, t, r, discount)
         ss, as, tds, r,
         Dict(ss[i]=>i for i in 1:length(ss)),
         Dict(as[i]=>i for i in 1:length(as)),
-        discount
+        discount, p0, terminal
     )

     trans_prob_consistency_check(m)
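
Taken together, the new `initial` and `terminals` fields, the `initialstate_distribution` and `isterminal` methods, and the extra constructor arguments change how a model is specified. A minimal sketch of the new POMDP constructor usage follows; it is not part of this commit, the tiger-style spaces and probabilities are purely illustrative, and `Deterministic`, `SparseCat`, and `Uniform` are assumed to come from POMDPModelTools as in the source above, with `DiscreteExplicitPOMDP` assumed to come from this package.

    using POMDPs, POMDPModelTools
    using QuickPOMDPs   # assumed to be the package providing DiscreteExplicitPOMDP

    S = [:left, :right]           # where the tiger is
    A = [:left, :right, :listen]
    O = [:left, :right]

    # T(s,a,s'): listening leaves the state unchanged; opening a door resets it uniformly
    T(s, a, sp) = a == :listen ? float(sp == s) : 0.5
    # Z(a,s',o): listening is 85% accurate; otherwise the observation carries no information
    Z(a, sp, o) = a == :listen ? (o == sp ? 0.85 : 0.15) : 0.5
    # R(s,a): small cost to listen, -100 for opening the tiger door, +10 otherwise
    R(s, a) = a == :listen ? -1.0 : (s == a ? -100.0 : 10.0)
    γ = 0.95

    # New in this commit: an explicit initial belief and a terminal state set
    b₀ = Deterministic(:left)     # or SparseCat([:left, :right], [0.7, 0.3]), or Uniform(S)
    m = DiscreteExplicitPOMDP(S, A, O, T, Z, R, γ, b₀, terminal=Set([:left]))

    initialstate_distribution(m)  # returns b₀ rather than a uniform belief
    isterminal(m, :left)          # true, only because :left was put in the terminal set above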

test/discrete_explicit.jl

Lines changed: 9 additions & 0 deletions
@@ -47,12 +47,18 @@
     end
     println("Undiscounted reward was $rsum.")
     @test rsum == -10.0
+
+    dm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,Deterministic(:left))
+    @test initialstate(dm, Random.GLOBAL_RNG) == :left
+    tm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,terminal=Set(S))
+    @test isterminal(tm, initialstate(tm, Random.GLOBAL_RNG))
 end

 @testset "Discrete Explicit MDP" begin
     S = 1:5
     A = [-1, 1]
     γ = 0.95
+    p₀ = Deterministic(1)

     function T(s, a, sp)
         if sp == clamp(s+a,1,5)
@@ -73,6 +79,9 @@ end
     end

     m = DiscreteExplicitMDP(S,A,T,R,γ)
+    m = DiscreteExplicitMDP(S,A,T,R,γ,p₀)
+    m = DiscreteExplicitMDP(S,A,T,R,γ,p₀,terminal=Set(5))
+    @test isterminal(m, 5)

     solver = FunctionSolver(x->1)
     policy = solve(solver, m)
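
These tests exercise the new behavior directly; at simulation time the same two pieces decide where a rollout starts and when it stops. A rough hand-rolled rollout sketch is shown below. It is not part of this commit: the `rollout` helper and `maxsteps` limit are hypothetical, `m` and `policy` refer to the MDP and FunctionSolver policy from the testset above, and standard interface methods such as `discount(m)` are assumed to be defined elsewhere in the source file.

    using Random
    using POMDPs

    # Hypothetical helper: accumulate discounted reward until a terminal state is reached
    function rollout(m, policy, rng::AbstractRNG; maxsteps=100)
        s = initialstate(m, rng)         # sampled from p₀, so always 1 for Deterministic(1)
        rtot = 0.0
        disc = 1.0
        for _ in 1:maxsteps
            isterminal(m, s) && break    # no more actions taken or reward received
            a = action(policy, s)
            rtot += disc * reward(m, s, a)
            s = rand(rng, transition(m, s, a))
            disc *= discount(m)
        end
        return rtot
    end

    rollout(m, policy, Random.GLOBAL_RNG)  # stops as soon as state 5 is reached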
