Commit b60ff87

1 parent bc7900a commit b60ff87
File tree: 1 file changed, +346 -0 lines

  • deep-learning/Deep-Reinforcement-Learning-Complete-Collection/DeepRL-Code/chapter01

Lines changed: 346 additions & 0 deletions
@@ -0,0 +1,346 @@
#######################################################################
# Copyright (C)                                                       #
# 2016 - 2018 Shangtong Zhang([email protected])                      #
# 2016 Jan Hakenberg([email protected])                               #
# 2016 Tian Jun([email protected])                                    #
# 2016 Kenta Shimada([email protected])                               #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

import numpy as np
import pickle

BOARD_ROWS = 3
BOARD_COLS = 3
BOARD_SIZE = BOARD_ROWS * BOARD_COLS


class State:
    def __init__(self):
        # the board is represented by an n * n array:
        # 1 represents a chessman of the player who moves first,
        # -1 represents a chessman of the other player,
        # 0 represents an empty position
        self.data = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.winner = None
        self.hash_val = None
        self.end = None

    # compute a unique hash value for this state
    def hash(self):
        if self.hash_val is None:
            self.hash_val = 0
            for i in np.nditer(self.data):
                self.hash_val = self.hash_val * 3 + i + 1
        return self.hash_val

    # check whether a player has won the game, or whether it's a tie
    def is_end(self):
        if self.end is not None:
            return self.end
        results = []
        # check rows
        for i in range(BOARD_ROWS):
            results.append(np.sum(self.data[i, :]))
        # check columns
        for i in range(BOARD_COLS):
            results.append(np.sum(self.data[:, i]))

        # check diagonals
        trace = 0
        reverse_trace = 0
        for i in range(BOARD_ROWS):
            trace += self.data[i, i]
            reverse_trace += self.data[i, BOARD_ROWS - 1 - i]
        results.append(trace)
        results.append(reverse_trace)

        for result in results:
            if result == 3:
                self.winner = 1
                self.end = True
                return self.end
            if result == -3:
                self.winner = -1
                self.end = True
                return self.end

        # check whether it's a tie (the board is full)
        sum_values = np.sum(np.abs(self.data))
        if sum_values == BOARD_SIZE:
            self.winner = 0
            self.end = True
            return self.end

        # the game is still going on
        self.end = False
        return self.end

    # @symbol: 1 or -1
    # put a chessman of @symbol in position (i, j)
    def next_state(self, i, j, symbol):
        new_state = State()
        new_state.data = np.copy(self.data)
        new_state.data[i, j] = symbol
        return new_state

    # print the board
    def print_state(self):
        for i in range(BOARD_ROWS):
            print('-------------')
            out = '| '
            for j in range(BOARD_COLS):
                if self.data[i, j] == 1:
                    token = '*'
                elif self.data[i, j] == -1:
                    token = 'x'
                else:
                    token = '0'
                out += token + ' | '
            print(out)
        print('-------------')

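# note: hash() above effectively reads the board as a 9-digit base-3 number,
# mapping each cell value -1/0/1 to the digit 0/1/2, so every configuration
# gets a distinct value in [0, 3^9). For the empty board every digit is 1:
#
#   State().hash()  # -> 9841.0 == (3**9 - 1) / 2 (a float, since data is a float array)
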

def get_all_states_impl(current_state, current_symbol, all_states):
    for i in range(BOARD_ROWS):
        for j in range(BOARD_COLS):
            if current_state.data[i][j] == 0:
                new_state = current_state.next_state(i, j, current_symbol)
                new_hash = new_state.hash()
                if new_hash not in all_states:
                    is_end = new_state.is_end()
                    all_states[new_hash] = (new_state, is_end)
                    if not is_end:
                        get_all_states_impl(new_state, -current_symbol, all_states)


def get_all_states():
    current_symbol = 1
    current_state = State()
    all_states = dict()
    all_states[current_state.hash()] = (current_state, current_state.is_end())
    get_all_states_impl(current_state, current_symbol, all_states)
    return all_states


# all possible board configurations
all_states = get_all_states()

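# sanity check (a small illustrative addition): the recursion only visits
# positions reachable in legal play and stops expanding terminal states, so
# the table stays far below the 3^9 = 19683 raw base-3 encodings
assert len(all_states) < 3 ** 9
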

class Judger:
    # @player1: the player who moves first, its chessman will be 1
    # @player2: the other player, with chessman -1
    def __init__(self, player1, player2):
        self.p1 = player1
        self.p2 = player2
        self.current_player = None
        self.p1_symbol = 1
        self.p2_symbol = -1
        self.p1.set_symbol(self.p1_symbol)
        self.p2.set_symbol(self.p2_symbol)
        self.current_state = State()

    def reset(self):
        self.p1.reset()
        self.p2.reset()

    # an infinite generator that alternates between the two players
    def alternate(self):
        while True:
            yield self.p1
            yield self.p2

    # @print_state: if True, print each board during the game
    def play(self, print_state=False):
        alternator = self.alternate()
        self.reset()
        current_state = State()
        self.p1.set_state(current_state)
        self.p2.set_state(current_state)
        if print_state:
            current_state.print_state()
        while True:
            player = next(alternator)
            i, j, symbol = player.act()
            next_state_hash = current_state.next_state(i, j, symbol).hash()
            current_state, is_end = all_states[next_state_hash]
            self.p1.set_state(current_state)
            self.p2.set_state(current_state)
            if print_state:
                current_state.print_state()
            if is_end:
                return current_state.winner

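# minimal usage sketch (hypothetical; it mirrors what train() and compete()
# do further below):
#
#   p1, p2 = Player(epsilon=0.1), Player(epsilon=0.1)
#   judger = Judger(p1, p2)
#   winner = judger.play(print_state=True)  # returns 1, -1, or 0 for a tie
#   judger.reset()
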

# AI player
class Player:
    # @step_size: the step size for updating estimations
    # @epsilon: the probability of taking an exploratory move
    def __init__(self, step_size=0.1, epsilon=0.1):
        self.estimations = dict()
        self.step_size = step_size
        self.epsilon = epsilon
        self.states = []
        self.greedy = []
        self.symbol = 0

    def reset(self):
        self.states = []
        self.greedy = []

    def set_state(self, state):
        self.states.append(state)
        self.greedy.append(True)

    def set_symbol(self, symbol):
        self.symbol = symbol
        for hash_val in all_states:
            state, is_end = all_states[hash_val]
            if is_end:
                if state.winner == self.symbol:
                    self.estimations[hash_val] = 1.0
                elif state.winner == 0:
                    # we need to distinguish between a tie and a loss
                    self.estimations[hash_val] = 0.5
                else:
                    self.estimations[hash_val] = 0
            else:
                self.estimations[hash_val] = 0.5

    # update value estimations
    def backup(self):
        states = [state.hash() for state in self.states]

        for i in reversed(range(len(states) - 1)):
            state = states[i]
            td_error = self.greedy[i] * (
                self.estimations[states[i + 1]] - self.estimations[state]
            )
            self.estimations[state] += self.step_size * td_error

    # choose an action based on the current state
    def act(self):
        state = self.states[-1]
        next_states = []
        next_positions = []
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if state.data[i, j] == 0:
                    next_positions.append([i, j])
                    next_states.append(state.next_state(
                        i, j, self.symbol).hash())

        if np.random.rand() < self.epsilon:
            action = next_positions[np.random.randint(len(next_positions))]
            action.append(self.symbol)
            self.greedy[-1] = False
            return action

        values = []
        for hash_val, pos in zip(next_states, next_positions):
            values.append((self.estimations[hash_val], pos))
        # shuffle first so that ties between equal-valued actions are broken
        # at random (Python's sort is stable)
        np.random.shuffle(values)
        values.sort(key=lambda x: x[0], reverse=True)
        action = values[0][1]
        action.append(self.symbol)
        return action

    def save_policy(self):
        with open('policy_%s.bin' % ('first' if self.symbol == 1 else 'second'), 'wb') as f:
            pickle.dump(self.estimations, f)

    def load_policy(self):
        with open('policy_%s.bin' % ('first' if self.symbol == 1 else 'second'), 'rb') as f:
            self.estimations = pickle.load(f)

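# note on backup(): it implements the temporal-difference update from the
# tic-tac-toe example in Chapter 1 of Sutton & Barto's "Reinforcement
# Learning: An Introduction",
#
#     V(S_t) <- V(S_t) + alpha * [V(S_{t+1}) - V(S_t)],
#
# applied backwards over the episode. Multiplying the TD error by
# self.greedy[i] (True -> 1, False -> 0) zeroes the update for states whose
# following move was exploratory, so only greedy moves feed value back.
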

# human interface
# input a key to place a chessman:
# | q | w | e |
# | a | s | d |
# | z | x | c |
class HumanPlayer:
    def __init__(self, **kwargs):
        self.symbol = None
        self.keys = ['q', 'w', 'e', 'a', 's', 'd', 'z', 'x', 'c']
        self.state = None

    def reset(self):
        pass

    def set_state(self, state):
        self.state = state

    def set_symbol(self, symbol):
        self.symbol = symbol

    def act(self):
        self.state.print_state()
        key = input("Input your position:")
        data = self.keys.index(key)
        i = data // BOARD_COLS
        j = data % BOARD_COLS
        return i, j, self.symbol

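# caveat (an observation, not a fix): HumanPlayer.act() does not validate
# input. An unknown key makes self.keys.index() raise ValueError, and
# choosing an already-occupied square typically produces a board that is not
# in all_states, so Judger.play() fails with a KeyError.
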

def train(epochs, print_every_n=500):
    player1 = Player(epsilon=0.01)
    player2 = Player(epsilon=0.01)
    judger = Judger(player1, player2)
    player1_win = 0.0
    player2_win = 0.0
    for i in range(1, epochs + 1):
        winner = judger.play(print_state=False)
        if winner == 1:
            player1_win += 1
        if winner == -1:
            player2_win += 1
        if i % print_every_n == 0:
            print('Epoch %d, player 1 winrate: %.02f, player 2 winrate: %.02f' % (
                i, player1_win / i, player2_win / i))
        player1.backup()
        player2.backup()
        judger.reset()
    player1.save_policy()
    player2.save_policy()

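# note: the printed win rates are cumulative over all epochs so far. As
# self-play training progresses, both rates should fall toward zero, since
# two near-optimal tic-tac-toe players tie almost every game.
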

def compete(turns):
    player1 = Player(epsilon=0)
    player2 = Player(epsilon=0)
    judger = Judger(player1, player2)
    player1.load_policy()
    player2.load_policy()
    player1_win = 0.0
    player2_win = 0.0
    for _ in range(turns):
        winner = judger.play()
        if winner == 1:
            player1_win += 1
        if winner == -1:
            player2_win += 1
        judger.reset()
    print('%d turns, player 1 win %.02f, player 2 win %.02f' % (
        turns, player1_win / turns, player2_win / turns))

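# note: even with epsilon=0 the competition is not fully deterministic,
# because act() breaks ties between equally valued moves at random.
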

# Tic-tac-toe is a zero-sum game: if both players follow an optimal strategy,
# every game ends in a tie. So here we test whether the AI can guarantee at
# least a tie when it goes second.
def play():
    while True:
        player1 = HumanPlayer()
        player2 = Player(epsilon=0)
        judger = Judger(player1, player2)
        player2.load_policy()
        winner = judger.play()
        if winner == player2.symbol:
            print("You lose!")
        elif winner == player1.symbol:
            print("You win!")
        else:
            print("It is a tie!")


if __name__ == '__main__':
    train(int(1e5))
    compete(int(1e3))
    play()
