Skip to content

Commit a27a59e

Browse files
committed
Adds concrete pomdp implementation
1 parent 6b2a62e commit a27a59e

File tree

3 files changed

+113
-5
lines changed

3 files changed

+113
-5
lines changed

aima-core/src/main/java/aima/core/probability/mdp/search/POMDPValueIteration.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
public class POMDPValueIteration<S,A extends Action,E> {
1111
public POMDP<S,A,E> pomdp;
1212
public double maxError;
13+
public int depth;
1314

14-
public POMDPValueIteration(POMDP<S, A, E> pomdp, double maxError) {
15+
public POMDPValueIteration(POMDP<S, A, E> pomdp, double maxError, int maxDepth) {
1516
this.pomdp = pomdp;
1617
this.maxError = maxError;
18+
this.depth = maxDepth;
1719
}
1820

1921
public HashMap<List<A>, List<Double>> pomdpValueIteration(POMDP pomdp, double maxError){
@@ -26,10 +28,12 @@ public HashMap<List<A>, List<Double>> pomdpValueIteration(POMDP pomdp, double ma
2628
}
2729
uDash = new HashMap<>();
2830
uDash.put(new ArrayList<>(),utilities);
29-
while(maxDifference(u,uDash) < maxError*(1-pomdp.getDiscount())/pomdp.getDiscount()){
31+
int i = 0;
32+
while(maxDifference(u,uDash) < maxError*(1-pomdp.getDiscount())/pomdp.getDiscount() || (i<=this.depth)){
3033
u = new HashMap<>(uDash);
3134
uDash = increasePlanDepths(uDash);
3235
uDash = removeDominatedPlans(uDash);
36+
i++;
3337
}
3438
return u;
3539
}
@@ -54,7 +58,7 @@ private HashMap<List<A>, List<Double>> increasePlanDepths(HashMap<List<A>,
5458
this.pomdp.states()) {
5559
tempUtility+=this.pomdp.sensorModel(observation,
5660
actualState)*uDash.get(plan).
57-
get(((ArrayList)this.pomdp.states()).indexOf(actualState));
61+
get((new ArrayList<>(this.pomdp.states())).indexOf(actualState));
5862
}
5963
planUtility = tempUtility*this.pomdp.transitionProbability(actualState,
6064
currentState,action);
@@ -70,10 +74,10 @@ private HashMap<List<A>, List<Double>> increasePlanDepths(HashMap<List<A>,
7074
}
7175

7276
private HashMap<List<A>, List<Double>> removeDominatedPlans(HashMap<List<A>, List<Double>> uDash) {
73-
return null;
77+
return uDash;
7478
}
7579

7680
private double maxDifference(HashMap<List<A>, List<Double>> u, HashMap<List<A>, List<Double>> uDash) {
77-
return 0.0;
81+
return 2;
7882
}
7983
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package aima.test.core.unit.probability;
2+
3+
import java.util.HashSet;
4+
import java.util.Set;
5+
6+
public class POMDP implements aima.core.probability.mdp.POMDP {
7+
double gamma = 1.0;
8+
State initialState = State.ZERO;
9+
10+
@Override
11+
public double getDiscount() {
12+
return gamma;
13+
}
14+
15+
@Override
16+
public double sensorModel(Object observedState, Object actualState) {
17+
if (observedState.equals(actualState))
18+
return 0.9;
19+
else
20+
return 0.1;
21+
}
22+
23+
@Override
24+
public Set getAllActions() {
25+
HashSet<Action> actions = new HashSet<>();
26+
actions.add(Action.GO);
27+
actions.add(Action.STAY);
28+
return actions;
29+
}
30+
31+
@Override
32+
public Set states() {
33+
HashSet<State> states = new HashSet<>();
34+
states.add(State.ZERO);
35+
states.add(State.ONE);
36+
return states;
37+
}
38+
39+
@Override
40+
public Object getInitialState() {
41+
return this.initialState;
42+
}
43+
44+
@Override
45+
public Set actions(Object o) {
46+
return this.getAllActions();
47+
}
48+
49+
@Override
50+
public double transitionProbability(Object sDelta, Object o, aima.core.agent.Action action) {
51+
if (action.equals(Action.GO)) {
52+
if (sDelta.equals(o))
53+
return 0.1;
54+
else
55+
return 0.9;
56+
} else if (action.equals(Action.STAY)) {
57+
if (sDelta.equals(o))
58+
return 0.9;
59+
else
60+
return 0.1;
61+
}
62+
return 0;
63+
}
64+
65+
@Override
66+
public double reward(Object o) {
67+
if (o.equals(State.ZERO))
68+
return 0.0;
69+
else
70+
return 1.0;
71+
}
72+
73+
public enum State {
74+
ZERO, ONE
75+
}
76+
77+
public enum Action implements aima.core.agent.Action {
78+
GO, STAY;
79+
80+
@Override
81+
public boolean isNoOp() {
82+
return false;
83+
}
84+
}
85+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package aima.test.core.unit.probability;
2+
3+
import aima.core.probability.mdp.search.POMDPValueIteration;
import org.junit.Assert;
import org.junit.Test;

import java.util.List;
import java.util.Map;
5+
6+
public class POMDPValueIterationTest {
7+
@Test
8+
public void test(){
9+
POMDPValueIteration algo = new POMDPValueIteration(new POMDP(),0.1,2);
10+
System.out.println(algo.pomdpValueIteration(new POMDP(),0.1).toString());
11+
/**
12+
* Result comes out to be:
13+
* {[STAY, GO]=[1.71, 1.19],
14+
* [GO, STAY]=[0.11000000000000001, 1.9900000000000002],
15+
* [STAY, STAY]=[0.19, 2.71],
16+
* [GO, GO]=[0.9900000000000001, 1.11]}
17+
*/
18+
}
19+
}

0 commit comments

Comments
 (0)