#######################################################################
# Copyright (C)                                                       #
# 2016 - 2018 Shangtong Zhang([email protected])                      #
# 2016 Jan Hakenberg([email protected])                               #
# 2016 Tian Jun([email protected])                                    #
# 2016 Kenta Shimada([email protected])                               #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

import numpy as np
import pickle

BOARD_ROWS = 3
BOARD_COLS = 3
BOARD_SIZE = BOARD_ROWS * BOARD_COLS


class State:
    def __init__(self):
        # the board is represented by an n * n array,
        # 1 represents a chessman of the player who moves first,
        # -1 represents a chessman of the other player,
        # 0 represents an empty position
        self.data = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.winner = None
        self.hash_val = None
        self.end = None

    # compute a unique hash value for one state
    def hash(self):
        if self.hash_val is None:
            self.hash_val = 0
            for i in np.nditer(self.data):
                self.hash_val = self.hash_val * 3 + i + 1
        return self.hash_val
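
    # The encoding above reads the nine cells in row-major order and maps
    # each cell value in {-1, 0, 1} to a base-3 digit in {0, 1, 2}, so
    # every board corresponds to a distinct integer in [0, 3 ** 9).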

    # check whether a player has won the game, or whether it's a tie
    def is_end(self):
        if self.end is not None:
            return self.end
        results = []
        # check rows
        for i in range(BOARD_ROWS):
            results.append(np.sum(self.data[i, :]))
        # check columns
        for i in range(BOARD_COLS):
            results.append(np.sum(self.data[:, i]))

        # check diagonals
        trace = 0
        reverse_trace = 0
        for i in range(BOARD_ROWS):
            trace += self.data[i, i]
            reverse_trace += self.data[i, BOARD_ROWS - 1 - i]
        results.append(trace)
        results.append(reverse_trace)

        # a line summing to +3 or -3 means three identical symbols in a row
        for result in results:
            if result == 3:
                self.winner = 1
                self.end = True
                return self.end
            if result == -3:
                self.winner = -1
                self.end = True
                return self.end

        # whether it's a tie
        sum_values = np.sum(np.abs(self.data))
        if sum_values == BOARD_SIZE:
            self.winner = 0
            self.end = True
            return self.end

        # game is still going on
        self.end = False
        return self.end

    # @symbol: 1 or -1
    # put chessman symbol in position (i, j)
    def next_state(self, i, j, symbol):
        new_state = State()
        new_state.data = np.copy(self.data)
        new_state.data[i, j] = symbol
        return new_state

    # print the board
    def print_state(self):
        for i in range(BOARD_ROWS):
            print('-------------')
            out = '| '
            for j in range(BOARD_COLS):
                if self.data[i, j] == 1:
                    token = '*'
                elif self.data[i, j] == -1:
                    token = 'x'
                else:
                    token = '0'
                out += token + ' | '
            print(out)
        print('-------------')


def get_all_states_impl(current_state, current_symbol, all_states):
    for i in range(BOARD_ROWS):
        for j in range(BOARD_COLS):
            if current_state.data[i, j] == 0:
                new_state = current_state.next_state(i, j, current_symbol)
                new_hash = new_state.hash()
                if new_hash not in all_states:
                    is_end = new_state.is_end()
                    all_states[new_hash] = (new_state, is_end)
                    if not is_end:
                        get_all_states_impl(new_state, -current_symbol, all_states)
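
# The recursion above performs a depth-first search over the game tree:
# the moving symbol alternates at each ply and expansion stops at
# terminal states, so every reachable position is visited exactly once.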


def get_all_states():
    current_symbol = 1
    current_state = State()
    all_states = dict()
    all_states[current_state.hash()] = (current_state, current_state.is_end())
    get_all_states_impl(current_state, current_symbol, all_states)
    return all_states


# all possible board configurations
all_states = get_all_states()
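# Tic-tac-toe is small enough to enumerate exhaustively: all_states maps
# each position's hash to its (State, is_end) pair, a few thousand
# entries in total, so value estimates can live in a plain dict.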


class Judger:
    # @player1: the player who will move first, its chessman will be 1
    # @player2: the other player, whose chessman will be -1
    def __init__(self, player1, player2):
        self.p1 = player1
        self.p2 = player2
        self.current_player = None
        self.p1_symbol = 1
        self.p2_symbol = -1
        self.p1.set_symbol(self.p1_symbol)
        self.p2.set_symbol(self.p2_symbol)
        self.current_state = State()

    def reset(self):
        self.p1.reset()
        self.p2.reset()

    # yield the players in alternating order, starting with p1
    def alternate(self):
        while True:
            yield self.p1
            yield self.p2

    # @print_state: if True, print each board during the game
    def play(self, print_state=False):
        alternator = self.alternate()
        self.reset()
        current_state = State()
        self.p1.set_state(current_state)
        self.p2.set_state(current_state)
        if print_state:
            current_state.print_state()
        while True:
            player = next(alternator)
            i, j, symbol = player.act()
            next_state_hash = current_state.next_state(i, j, symbol).hash()
            current_state, is_end = all_states[next_state_hash]
            self.p1.set_state(current_state)
            self.p2.set_state(current_state)
            if print_state:
                current_state.print_state()
            if is_end:
                return current_state.winner
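
    # Note that play() advances by hashing the tentative next state and
    # looking up the canonical (State, is_end) pair in all_states, so both
    # players always observe the same pre-computed State objects.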


# AI player
class Player:
    # @step_size: the step size to update estimations
    # @epsilon: the probability to explore
    def __init__(self, step_size=0.1, epsilon=0.1):
        self.estimations = dict()
        self.step_size = step_size
        self.epsilon = epsilon
        self.states = []
        self.greedy = []
        self.symbol = 0

    def reset(self):
        self.states = []
        self.greedy = []

    def set_state(self, state):
        self.states.append(state)
        self.greedy.append(True)

    def set_symbol(self, symbol):
        self.symbol = symbol
        for hash_val in all_states:
            state, is_end = all_states[hash_val]
            if is_end:
                if state.winner == self.symbol:
                    self.estimations[hash_val] = 1.0
                elif state.winner == 0:
                    # we need to distinguish between a tie and a loss
                    self.estimations[hash_val] = 0.5
                else:
                    self.estimations[hash_val] = 0
            else:
                self.estimations[hash_val] = 0.5
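
    # The initialization above assigns every terminal state its true value
    # (win = 1.0, tie = 0.5, loss = 0) and every non-terminal state the
    # neutral default of 0.5.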

    # update value estimation
    def backup(self):
        states = [state.hash() for state in self.states]

        for i in reversed(range(len(states) - 1)):
            state = states[i]
            td_error = self.greedy[i] * (
                self.estimations[states[i + 1]] - self.estimations[state]
            )
            self.estimations[state] += self.step_size * td_error
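
    # The update above is the tabular TD(0) rule
    #     V(S_t) <- V(S_t) + alpha * [V(S_{t+1}) - V(S_t)]
    # applied backwards through the episode; exploratory moves set
    # greedy[i] to False, which zeroes the TD error so that random
    # moves do not corrupt the value estimates.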

    # choose an action based on the state
    def act(self):
        state = self.states[-1]
        next_states = []
        next_positions = []
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if state.data[i, j] == 0:
                    next_positions.append([i, j])
                    next_states.append(state.next_state(
                        i, j, self.symbol).hash())

        # with probability epsilon, explore by picking a random legal move
        if np.random.rand() < self.epsilon:
            action = next_positions[np.random.randint(len(next_positions))]
            action.append(self.symbol)
            self.greedy[-1] = False
            return action

        values = []
        for hash_val, pos in zip(next_states, next_positions):
            values.append((self.estimations[hash_val], pos))
        # shuffle before the stable sort so that ties between equally
        # valued actions are broken at random
        np.random.shuffle(values)
        values.sort(key=lambda x: x[0], reverse=True)
        action = values[0][1]
        action.append(self.symbol)
        return action

    def save_policy(self):
        with open('policy_%s.bin' % ('first' if self.symbol == 1 else 'second'), 'wb') as f:
            pickle.dump(self.estimations, f)

    def load_policy(self):
        with open('policy_%s.bin' % ('first' if self.symbol == 1 else 'second'), 'rb') as f:
            self.estimations = pickle.load(f)
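
    # Note: compete() and play() below rely on load_policy(), so train()
    # must have been run at least once to create policy_first.bin and
    # policy_second.bin.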


# human interface
# input a key to put a chessman:
# | q | w | e |
# | a | s | d |
# | z | x | c |
class HumanPlayer:
    def __init__(self, **kwargs):
        self.symbol = None
        self.keys = ['q', 'w', 'e', 'a', 's', 'd', 'z', 'x', 'c']
        self.state = None

    def reset(self):
        pass

    def set_state(self, state):
        self.state = state

    def set_symbol(self, symbol):
        self.symbol = symbol

    def act(self):
        self.state.print_state()
        key = input("Input your position:")
        # map the pressed key to board coordinates (row-major order)
        data = self.keys.index(key)
        i = data // BOARD_COLS
        j = data % BOARD_COLS
        return i, j, self.symbol


def train(epochs, print_every_n=500):
    player1 = Player(epsilon=0.01)
    player2 = Player(epsilon=0.01)
    judger = Judger(player1, player2)
    player1_win = 0.0
    player2_win = 0.0
    for i in range(1, epochs + 1):
        winner = judger.play(print_state=False)
        if winner == 1:
            player1_win += 1
        if winner == -1:
            player2_win += 1
        if i % print_every_n == 0:
            print('Epoch %d, player 1 win rate: %.02f, player 2 win rate: %.02f' % (
                i, player1_win / i, player2_win / i))
        player1.backup()
        player2.backup()
        judger.reset()
    player1.save_policy()
    player2.save_policy()
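
# With a small epsilon both learners keep exploring during self-play; as
# the value tables converge toward optimal play, both win rates should
# drop toward zero, since optimal tic-tac-toe always ends in a tie.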


def compete(turns):
    player1 = Player(epsilon=0)
    player2 = Player(epsilon=0)
    judger = Judger(player1, player2)
    player1.load_policy()
    player2.load_policy()
    player1_win = 0.0
    player2_win = 0.0
    for _ in range(turns):
        winner = judger.play()
        if winner == 1:
            player1_win += 1
        if winner == -1:
            player2_win += 1
        judger.reset()
    print('%d turns, player 1 win rate %.02f, player 2 win rate %.02f' % (
        turns, player1_win / turns, player2_win / turns))


# The game is a zero-sum game. If both players play optimally, every game
# ends in a tie. So we test whether the AI can guarantee at least a tie
# when it goes second.
def play():
    while True:
        player1 = HumanPlayer()
        player2 = Player(epsilon=0)
        judger = Judger(player1, player2)
        player2.load_policy()
        winner = judger.play()
        if winner == player2.symbol:
            print("You lose!")
        elif winner == player1.symbol:
            print("You win!")
        else:
            print("It is a tie!")


if __name__ == '__main__':
    train(int(1e5))
    compete(int(1e3))
    play()
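
# A minimal usage sketch (assuming this file is saved as, e.g.,
# tic_tac_toe.py):
#     python tic_tac_toe.py
# trains both policies via self-play, pits the two greedy policies
# against each other, and then starts an interactive loop in which the
# human moves first.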