
Commit 779be37

Merge pull request #3 from JuliaPOMDP/initstates
implemented initial distributions and terminal states (#2)
2 parents 546396d + 0ee69b4 commit 779be37

4 files changed: +45 -25 lines changed


Project.toml

Lines changed: 3 additions & 1 deletion
@@ -10,12 +10,14 @@ POMDPTesting = "92e6a534-49c2-5324-9027-86e3c861ab81"
 POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"

 [compat]
+POMDPModelTools = ">=0.1.6"
 julia = "1"

 [extras]
 POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
 POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

 [targets]
-test = ["Test", "POMDPPolicies", "POMDPSimulators"]
+test = ["Test", "POMDPPolicies", "POMDPSimulators", "Random"]

src/discrete_explicit.jl

Lines changed: 27 additions & 23 deletions
@@ -1,4 +1,4 @@
-struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
+struct DiscreteExplicitPOMDP{S,A,O,OF,RF,D} <: POMDP{S,A,O}
     s::Vector{S}
     a::Vector{A}
     o::Vector{O}
@@ -10,16 +10,20 @@ struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
     amap::Dict{A,Int}
     omap::Dict{O,Int}
     discount::Float64
+    initial::D
+    terminals::Set{S}
 end

-struct DiscreteExplicitMDP{S,A,RF} <: MDP{S,A}
+struct DiscreteExplicitMDP{S,A,RF,D} <: MDP{S,A}
     s::Vector{S}
     a::Vector{A}
     tds::Dict{Tuple{S,A}, SparseCat{Vector{S}, Vector{Float64}}}
     r::RF
     smap::Dict{S,Int}
     amap::Dict{A,Int}
     discount::Float64
+    initial::D
+    terminals::Set{S}
 end

 const DEP = DiscreteExplicitPOMDP
@@ -42,38 +46,35 @@ POMDPs.transition(m::DE, s, a) = m.tds[s,a]
 POMDPs.observation(m::DEP, a, sp) = m.ods[a,sp]
 POMDPs.reward(m::DE, s, a) = m.r(s, a)

-POMDPs.initialstate_distribution(m::DEP) = uniform_belief(m)
-# XXX hack
-POMDPs.initialstate_distribution(m::DiscreteExplicitMDP) = uniform_belief(FullyObservablePOMDP(m))
+POMDPs.initialstate_distribution(m::DE) = m.initial
+
+POMDPs.isterminal(m::DE,s) = s in m.terminals

 POMDPModelTools.ordered_states(m::DE) = m.s
 POMDPModelTools.ordered_actions(m::DE) = m.a
 POMDPModelTools.ordered_observations(m::DEP) = m.o

-# TODO reward(m, s, a)
-# TODO support O(s, a, sp, o)
-# TODO initial state distribution
-# TODO convert_s, etc, dimensions
-# TODO better errors if T or Z return something unexpected
-
 """
-    DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ)
+    DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,[b₀],[terminal=Set()])

 Create a POMDP defined by the tuple (S,A,O,T,Z,R,γ).

 # Arguments

+## Required
 - `S`,`A`,`O`: State, action, and observation spaces (typically `Vector`s)
 - `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
 - `Z::Function`: Observation probability distribution function; ``O(a, s', o)`` is the probability of receiving observation ``o`` when state ``s'`` is reached after action ``a``.
 - `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
 - `γ::Float64`: Discount factor.

-# Notes
-- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
-- Terminal states are not yet supported, but absorbing states with zero reward can be used.
+## Optional
+- `b₀=Uniform(S)`: Initial belief/state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
+
+## Keyword
+- `terminals=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
 """
-function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
+function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount, b0=Uniform(s); terminals=Set())
     ss = vec(collect(s))
     as = vec(collect(a))
     os = vec(collect(o))
@@ -107,7 +108,7 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
         Dict(ss[i]=>i for i in 1:length(ss)),
         Dict(as[i]=>i for i in 1:length(as)),
         Dict(os[i]=>i for i in 1:length(os)),
-        discount
+        discount, b0, convert(Set{eltype(ss)}, terminals)
     )

     probability_check(m)
@@ -116,22 +117,25 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
 end

 """
-    DiscreteExplicitMDP(S,A,T,R,γ)
+    DiscreteExplicitMDP(S,A,T,R,γ,[p₀])

 Create an MDP defined by the tuple (S,A,T,R,γ).

 # Arguments

+## Required
 - `S`,`A`: State and action spaces (typically `Vector`s)
 - `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
 - `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
 - `γ::Float64`: Discount factor.

-# Notes
-- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
-- Terminal states are not yet supported, but absorbing states with zero reward can be used.
+## Optional
+- `p₀=Uniform(S)`: Initial state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
+
+## Keyword
+- `terminals=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
 """
-function DiscreteExplicitMDP(s, a, t, r, discount)
+function DiscreteExplicitMDP(s, a, t, r, discount, p0=Uniform(s); terminals=Set())
     ss = vec(collect(s))
     as = vec(collect(a))

@@ -141,7 +145,7 @@ function DiscreteExplicitMDP(s, a, t, r, discount)
         ss, as, tds, r,
         Dict(ss[i]=>i for i in 1:length(ss)),
         Dict(as[i]=>i for i in 1:length(as)),
-        discount
+        discount, p0, convert(Set{eltype(ss)}, terminals)
     )

     trans_prob_consistency_check(m)
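
The docstrings above describe the new optional initial-distribution argument (`b₀` / `p₀`) and the `terminals` keyword. As a quick illustration only (not part of this commit: the toy grid model, its reward function, and the fully observable observation model below are made-up assumptions), constructing models with these arguments looks roughly like this:

    using QuickPOMDPs
    using POMDPModelTools   # provides Deterministic, Uniform, SparseCat

    S = 1:5                 # states of a small 1-D grid (illustrative)
    A = [-1, 1]             # actions: step left / step right
    γ = 0.95

    # Deterministic walk: probability 1 of moving to clamp(s+a, 1, 5)
    T(s, a, sp) = sp == clamp(s + a, 1, 5) ? 1.0 : 0.0
    # Illustrative reward: +1 for stepping onto the right edge
    R(s, a) = clamp(s + a, 1, 5) == 5 ? 1.0 : 0.0

    # New in this commit: optional initial distribution and `terminals` keyword.
    # Always start in state 1; state 5 is terminal.
    m = DiscreteExplicitMDP(S, A, T, R, γ, Deterministic(1), terminals=Set(5))

    # The POMDP constructor takes the analogous arguments; here Z makes the
    # problem fully observable (O = S) just to keep the sketch self-contained.
    O = S
    Z(a, sp, o) = o == sp ? 1.0 : 0.0
    pm = DiscreteExplicitPOMDP(S, A, O, T, Z, R, γ, Deterministic(1), terminals=Set(5))
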

test/discrete_explicit.jl

Lines changed: 9 additions & 0 deletions
@@ -47,12 +47,18 @@
     end
     println("Undiscounted reward was $rsum.")
     @test rsum == -10.0
+
+    dm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,Deterministic(:left))
+    @test initialstate(dm, Random.GLOBAL_RNG) == :left
+    tm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,terminals=Set(S))
+    @test isterminal(tm, initialstate(tm, Random.GLOBAL_RNG))
 end

 @testset "Discrete Explicit MDP" begin
     S = 1:5
     A = [-1, 1]
     γ = 0.95
+    p₀ = Deterministic(1)

     function T(s, a, sp)
         if sp == clamp(s+a,1,5)
@@ -73,6 +79,9 @@ end
     end

     m = DiscreteExplicitMDP(S,A,T,R,γ)
+    m = DiscreteExplicitMDP(S,A,T,R,γ,p₀)
+    m = DiscreteExplicitMDP(S,A,T,R,γ,p₀,terminals=Set(5))
+    @test isterminal(m, 5)

     solver = FunctionSolver(x->1)
     policy = solve(solver, m)

test/runtests.jl

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,11 @@
 using QuickPOMDPs
 using Test

-using POMDPs, POMDPPolicies, POMDPSimulators, BeliefUpdaters
+using POMDPs
+using POMDPPolicies
+using POMDPSimulators
+using BeliefUpdaters
+using POMDPModelTools
+using Random

 include("discrete_explicit.jl")
