Skip to content

Commit 3d6c13a

Browse files
authored
Merge pull request aimacode#471 from jsuyash1514/mcts3e
Monte Carlo Tree Search - AIMA3e
2 parents 0332c44 + ae52668 commit 3d6c13a

File tree

10 files changed

+585
-31
lines changed

10 files changed

+585
-31
lines changed

aima-core/src/main/java/aima/core/environment/tictactoe/TicTacToeState.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,22 @@ public class TicTacToeState implements Cloneable {
1919
public static final String X = "X";
2020
public static final String EMPTY = "-";
2121
//
22-
private String[] board = new String[] { EMPTY, EMPTY, EMPTY, EMPTY, EMPTY,
23-
EMPTY, EMPTY, EMPTY, EMPTY };
22+
private String[] board;
2423

25-
private String playerToMove = X;
24+
private String playerToMove;
2625
private double utility = -1; // 1: win for X, 0: win for O, 0.5: draw
26+
27+
/**
 * Creates the starting position: a board of nine cells all set to
 * {@code EMPTY}, with {@code X} as the player to move.
 */
public TicTacToeState(){
28+
this.board = new String[] { EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY };
29+
playerToMove = X;
30+
}
31+
32+
/**
 * Creates a state from an existing board and the player who moves next.
 *
 * @param board
 *            the cell contents; the array is stored as given (no copy is made
 *            here), so later changes to it by the caller are visible to this
 *            state
 * @param playerToMove
 *            the player whose turn it is in the created state
 */
public TicTacToeState(String[] board, String playerToMove){
33+
this.board = board;
34+
// Temporarily switch to the other player (O if X was given, otherwise X)
// before calling analyzeUtility().
// NOTE(review): presumably analyzeUtility() evaluates the position for the
// player who moved last - confirm against its implementation.
this.playerToMove = (Objects.equals(playerToMove, X) ? O : X);
35+
analyzeUtility();
36+
// Restore the player to move that the caller asked for.
this.playerToMove = playerToMove;
37+
}
2738

2839
public String getPlayerToMove() {
2940
return playerToMove;
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package aima.core.environment.twoply;
2+
3+
import aima.core.environment.map.MoveToAction;
4+
import aima.core.search.adversarial.Game;
5+
6+
import java.util.List;
7+
8+
public class TwoPlyGame implements Game<TwoPlyGameState, MoveToAction, String> {
9+
10+
@Override
11+
public TwoPlyGameState getInitialState() {
12+
return new TwoPlyGameState("A");
13+
}
14+
15+
@Override
16+
public String[] getPlayers() {
17+
return new String[]{"MAX", "MIN"};
18+
}
19+
20+
@Override
21+
public String getPlayer(TwoPlyGameState state) {
22+
switch (state.getLocation()) {
23+
case "B":
24+
case "C":
25+
case "D":
26+
return "MIN";
27+
default:
28+
return "MAX";
29+
}
30+
}
31+
32+
33+
@Override
34+
public List<MoveToAction> getActions(TwoPlyGameState state) {
35+
return new TwoPlyGameTree().getActions(state);
36+
}
37+
38+
@Override
39+
public TwoPlyGameState getResult(TwoPlyGameState state, MoveToAction action) {
40+
return new TwoPlyGameState(action.getToLocation());
41+
}
42+
43+
@Override
44+
public boolean isTerminal(TwoPlyGameState state) {
45+
return (state.getLocation().charAt(0) > 'D');
46+
}
47+
48+
@Override
49+
public double getUtility(TwoPlyGameState state, String player) {
50+
switch (state.getLocation()) {
51+
// B
52+
case "E":
53+
return 3;
54+
case "F":
55+
return 12;
56+
case "G":
57+
return 8;
58+
// C
59+
case "H":
60+
return 2;
61+
case "I":
62+
return 4;
63+
case "J":
64+
return 6;
65+
// D
66+
case "K":
67+
return 14;
68+
case "L":
69+
return 5;
70+
case "M":
71+
return 2;
72+
default:
73+
throw new IllegalArgumentException("State " + state.getLocation() + " unexpected.");
74+
}
75+
}
76+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package aima.core.environment.twoply;
2+
3+
public class TwoPlyGameState {
4+
private String location;
5+
6+
public TwoPlyGameState(String location) {
7+
this.location = location;
8+
}
9+
10+
public String getLocation() {
11+
return location;
12+
}
13+
14+
@Override
15+
public boolean equals(Object obj) {
16+
if (obj instanceof TwoPlyGameState) {
17+
return this.getLocation().equals(((TwoPlyGameState) obj).getLocation());
18+
}
19+
return super.equals(obj);
20+
}
21+
22+
@Override
23+
public int hashCode() {
24+
return getLocation().hashCode();
25+
}
26+
27+
@Override
28+
public String toString() {
29+
return location;
30+
}
31+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package aima.core.environment.twoply;
2+
3+
import aima.core.environment.map.ExtendableMap;
4+
import aima.core.environment.map.Map;
5+
import aima.core.environment.map.MoveToAction;
6+
7+
import java.util.ArrayList;
8+
import java.util.List;
9+
10+
public class TwoPlyGameTree {
11+
Map aima3eFig5_2;
12+
13+
public TwoPlyGameTree() {
14+
aima3eFig5_2 = new ExtendableMap() {
15+
{
16+
addUnidirectionalLink("A", "B", 1.0);
17+
addUnidirectionalLink("A", "C", 1.0);
18+
addUnidirectionalLink("A", "D", 1.0);
19+
addUnidirectionalLink("B", "E", 1.0);
20+
addUnidirectionalLink("B", "F", 1.0);
21+
addUnidirectionalLink("B", "G", 1.0);
22+
addUnidirectionalLink("C", "H", 1.0);
23+
addUnidirectionalLink("C", "I", 1.0);
24+
addUnidirectionalLink("C", "J", 1.0);
25+
addUnidirectionalLink("D", "K", 1.0);
26+
addUnidirectionalLink("D", "L", 1.0);
27+
addUnidirectionalLink("D", "M", 1.0);
28+
}
29+
};
30+
}
31+
32+
public List<MoveToAction> getActions(TwoPlyGameState state) {
33+
List<MoveToAction> possibleActions = new ArrayList<>();
34+
List<String> nextPossibleLocations = aima3eFig5_2.getPossibleNextLocations(state.getLocation());
35+
for (String nextLocation : nextPossibleLocations) {
36+
MoveToAction action = new MoveToAction(nextLocation);
37+
possibleActions.add(action);
38+
}
39+
return possibleActions;
40+
}
41+
42+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package aima.core.search.adversarial;
2+
3+
import aima.core.search.framework.GameTree;
4+
import aima.core.search.framework.Metrics;
5+
import aima.core.search.framework.Node;
6+
import aima.core.search.framework.NodeFactory;
7+
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
import java.util.Random;
11+
12+
/**
13+
* Artificial Intelligence A Modern Approach (4th Edition): page ???.<br>
14+
*
15+
* <pre>
16+
* <code>
17+
* function MONTE-CARLO-TREE-SEARCH(state) returns an action
18+
* tree &larr; NODE(state)
19+
* while TIME-REMAINING() do
20+
* leaf &larr; SELECT(tree)
21+
* child &larr; EXPAND(leaf)
22+
* result &larr; SIMULATE(child)
23+
* BACKPROPAGATE(result, child)
24+
* return the move in ACTIONS(state) whose node has highest number of playouts
25+
* </code>
26+
* </pre>
27+
*
28+
* Figure ?.? The Monte Carlo tree search algorithm. A game tree, tree, is initialized, and
29+
* then we repeat the cycle of SELECT / EXPAND / SIMULATE/ BACKPROPAGATE until we run out
30+
* of time, and return the move that led to the node with the highest number of playouts.
31+
*
32+
*
33+
* @author Suyash Jain
34+
*
35+
* @param <S>
36+
* Type which is used for states in the game.
37+
* @param <A>
38+
* Type which is used for actions in the game.
39+
* @param <P>
40+
* Type which is used for players in the game.
41+
*/
42+
43+
public class MonteCarloTreeSearch<S, A, P> implements AdversarialSearch<S, A> {
44+
private int iterations = 0;
45+
private Game<S, A, P> game;
46+
private GameTree<S, A> tree;
47+
48+
public MonteCarloTreeSearch(Game<S, A, P> game, int iterations) {
49+
this.game = game;
50+
this.iterations = iterations;
51+
tree = new GameTree<>();
52+
}
53+
54+
@Override
55+
public A makeDecision(S state) {
56+
// tree <-- NODE(state)
57+
tree.addRoot(state);
58+
// while TIME-REMAINING() do
59+
while (iterations != 0) {
60+
// leaf <-- SELECT(tree)
61+
Node<S, A> leaf = select(tree);
62+
// child <-- EXPAND(leaf)
63+
Node<S, A> child = expand(leaf);
64+
// result <-- SIMULATE(child)
65+
// result = true if player of root node wins
66+
boolean result = simulate(child);
67+
// BACKPROPAGATE(result, child)
68+
backpropagate(result, child);
69+
// repeat the four steps for set number of iterations
70+
--iterations;
71+
}
72+
// return the move in ACTIONS(state) whose node has highest number of playouts
73+
return bestAction(tree.getRoot());
74+
}
75+
76+
private Node<S, A> select(GameTree gameTree) {
77+
Node<S, A> node = gameTree.getRoot();
78+
while (!game.isTerminal(node.getState()) && isNodeFullyExpanded(node)) {
79+
node = gameTree.getChildWithMaxUCT(node);
80+
}
81+
return node;
82+
}
83+
84+
private Node<S, A> expand(Node<S, A> leaf) {
85+
if (game.isTerminal(leaf.getState())) return leaf;
86+
else {
87+
Node<S, A> child = randomlySelectUnvisitedChild(leaf);
88+
return child;
89+
}
90+
}
91+
92+
private boolean simulate(Node<S, A> node) {
93+
while (!game.isTerminal(node.getState())) {
94+
Random rand = new Random();
95+
A a = game.getActions(node.getState()).get(rand.nextInt(game.getActions(node.getState()).size()));
96+
S result = game.getResult(node.getState(), a);
97+
NodeFactory nodeFactory = new NodeFactory();
98+
node = nodeFactory.createNode(result);
99+
}
100+
if (game.getUtility(node.getState(), game.getPlayer(tree.getRoot().getState())) > 0) return true;
101+
else return false;
102+
}
103+
104+
private void backpropagate(boolean result, Node<S, A> node) {
105+
tree.updateStats(result, node);
106+
if (tree.getParent(node) != null) backpropagate(result, tree.getParent(node));
107+
}
108+
109+
private A bestAction(Node<S, A> root) {
110+
Node<S, A> bestChild = tree.getChildWithMaxPlayouts(root);
111+
for (A a : game.getActions(root.getState())) {
112+
S result = game.getResult(root.getState(), a);
113+
if (result.equals(bestChild.getState())) return a;
114+
}
115+
return null;
116+
}
117+
118+
private boolean isNodeFullyExpanded(Node<S, A> node) {
119+
List<S> visitedChildren = tree.getVisitedChildren(node);
120+
for (A a : game.getActions(node.getState())) {
121+
S result = game.getResult(node.getState(), a);
122+
if (!visitedChildren.contains(result)) {
123+
return false;
124+
}
125+
}
126+
return true;
127+
}
128+
129+
130+
private Node<S, A> randomlySelectUnvisitedChild(Node<S, A> node) {
131+
List<S> unvisitedChildren = new ArrayList<>();
132+
List<S> visitedChildren = tree.getVisitedChildren(node);
133+
for (A a : game.getActions(node.getState())) {
134+
S result = game.getResult(node.getState(), a);
135+
if (!visitedChildren.contains(result)) unvisitedChildren.add(result);
136+
}
137+
Random rand = new Random();
138+
Node<S, A> newChild = tree.addChild(node, unvisitedChildren.get(rand.nextInt(unvisitedChildren.size())));
139+
return newChild;
140+
}
141+
142+
@Override
143+
public Metrics getMetrics() {
144+
return null;
145+
}
146+
}

0 commit comments

Comments
 (0)