|
| 1 | +import matplotlib.pyplot as plt |
| 2 | +import random |
| 3 | +import heapq |
| 4 | +import math |
| 5 | +import sys |
| 6 | +from collections import defaultdict, deque, Counter |
| 7 | +from itertools import combinations |
| 8 | + |
| 9 | +class Problem(object): |
| 10 | + """The abstract class for a formal problem. A new domain subclasses this, |
| 11 | + overriding `actions` and `results`, and perhaps other methods. |
| 12 | + The default heuristic is 0 and the default action cost is 1 for all states. |
| 13 | + When yiou create an instance of a subclass, specify `initial`, and `goal` states |
| 14 | + (or give an `is_goal` method) and perhaps other keyword args for the subclass.""" |
| 15 | + |
| 16 | + def __init__(self, initial=None, goal=None, **kwds): |
| 17 | + self.__dict__.update(initial=initial, goal=goal, **kwds) |
| 18 | + |
| 19 | + def actions(self, state): raise NotImplementedError |
| 20 | + def result(self, state, action): raise NotImplementedError |
| 21 | + def is_goal(self, state): return state == self.goal |
| 22 | + def action_cost(self, s, a, s1): return 1 |
| 23 | + def h(self, node): return 0 |
| 24 | + |
| 25 | + def __str__(self): |
| 26 | + return '{}({!r}, {!r})'.format( |
| 27 | + type(self).__name__, self.initial, self.goal) |
| 28 | + |
| 29 | + |
| 30 | +class Node: |
| 31 | + "A Node in a search tree." |
| 32 | + def __init__(self, state, parent=None, action=None, path_cost=0): |
| 33 | + self.__dict__.update(state=state, parent=parent, action=action, path_cost=path_cost) |
| 34 | + |
| 35 | + def __repr__(self): return '<{}>'.format(self.state) |
| 36 | + def __len__(self): return 0 if self.parent is None else (1 + len(self.parent)) |
| 37 | + def __lt__(self, other): return self.path_cost < other.path_cost |
| 38 | + |
| 39 | + |
| 40 | +failure = Node('failure', path_cost=math.inf) # Indicates an algorithm couldn't find a solution. |
| 41 | +cutoff = Node('cutoff', path_cost=math.inf) # Indicates iterative deepening search was cut off. |
| 42 | + |
| 43 | + |
| 44 | +def expand(problem, node): |
| 45 | + "Expand a node, generating the children nodes." |
| 46 | + s = node.state |
| 47 | + for action in problem.actions(s): |
| 48 | + s1 = problem.result(s, action) |
| 49 | + cost = node.path_cost + problem.action_cost(s, action, s1) |
| 50 | + yield Node(s1, node, action, cost) |
| 51 | + |
| 52 | + |
| 53 | +def path_actions(node): |
| 54 | + "The sequence of actions to get to this node." |
| 55 | + if node.parent is None: |
| 56 | + return [] |
| 57 | + return path_actions(node.parent) + [node.action] |
| 58 | + |
| 59 | + |
| 60 | +def path_states(node): |
| 61 | + "The sequence of states to get to this node." |
| 62 | + if node in (cutoff, failure, None): |
| 63 | + return [] |
| 64 | + return path_states(node.parent) + [node.state] |
| 65 | + |
| 66 | +FIFOQueue = deque |
| 67 | + |
| 68 | +LIFOQueue = list |
| 69 | + |
| 70 | +class PriorityQueue: |
| 71 | + """A queue in which the item with minimum f(item) is always popped first.""" |
| 72 | + |
| 73 | + def __init__(self, items=(), key=lambda x: x): |
| 74 | + self.key = key |
| 75 | + self.items = [] # a heap of (score, item) pairs |
| 76 | + for item in items: |
| 77 | + self.add(item) |
| 78 | + |
| 79 | + def add(self, item): |
| 80 | + """Add item to the queuez.""" |
| 81 | + pair = (self.key(item), item) |
| 82 | + heapq.heappush(self.items, pair) |
| 83 | + |
| 84 | + def pop(self): |
| 85 | + """Pop and return the item with min f(item) value.""" |
| 86 | + return heapq.heappop(self.items)[1] |
| 87 | + |
| 88 | + def top(self): return self.items[0][1] |
| 89 | + |
| 90 | + def __len__(self): return len(self.items) |
| 91 | + |
| 92 | +def best_first_search(problem, f): |
| 93 | + "Search nodes with minimum f(node) value first." |
| 94 | + node = Node(problem.initial) |
| 95 | + frontier = PriorityQueue([node], key=f) |
| 96 | + reached = {problem.initial: node} |
| 97 | + while frontier: |
| 98 | + node = frontier.pop() |
| 99 | + if problem.is_goal(node.state): |
| 100 | + return node |
| 101 | + for child in expand(problem, node): |
| 102 | + s = child.state |
| 103 | + if s not in reached or child.path_cost < reached[s].path_cost: |
| 104 | + reached[s] = child |
| 105 | + frontier.add(child) |
| 106 | + return failure |
| 107 | + |
| 108 | + |
| 109 | +def best_first_tree_search(problem, f): |
| 110 | + "A version of best_first_search without the `reached` table." |
| 111 | + frontier = PriorityQueue([Node(problem.initial)], key=f) |
| 112 | + while frontier: |
| 113 | + node = frontier.pop() |
| 114 | + if problem.is_goal(node.state): |
| 115 | + return node |
| 116 | + for child in expand(problem, node): |
| 117 | + if not is_cycle(child): |
| 118 | + frontier.add(child) |
| 119 | + return failure |
| 120 | + |
| 121 | + |
| 122 | +def g(n): return n.path_cost |
| 123 | + |
| 124 | + |
| 125 | +def astar_search(problem, h=None): |
| 126 | + """Search nodes with minimum f(n) = g(n) + h(n).""" |
| 127 | + h = h or problem.h |
| 128 | + return best_first_search(problem, f=lambda n: g(n) + h(n)) |
| 129 | + |
| 130 | + |
| 131 | +def astar_tree_search(problem, h=None): |
| 132 | + """Search nodes with minimum f(n) = g(n) + h(n), with no `reached` table.""" |
| 133 | + h = h or problem.h |
| 134 | + return best_first_tree_search(problem, f=lambda n: g(n) + h(n)) |
| 135 | + |
| 136 | + |
| 137 | +def weighted_astar_search(problem, h=None, weight=1.4): |
| 138 | + """Search nodes with minimum f(n) = g(n) + weight * h(n).""" |
| 139 | + h = h or problem.h |
| 140 | + return best_first_search(problem, f=lambda n: g(n) + weight * h(n)) |
| 141 | + |
| 142 | + |
| 143 | +def greedy_bfs(problem, h=None): |
| 144 | + """Search nodes with minimum h(n).""" |
| 145 | + h = h or problem.h |
| 146 | + return best_first_search(problem, f=h) |
| 147 | + |
| 148 | + |
| 149 | +def uniform_cost_search(problem): |
| 150 | + "Search nodes with minimum path cost first." |
| 151 | + return best_first_search(problem, f=g) |
| 152 | + |
| 153 | + |
| 154 | +def breadth_first_bfs(problem): |
| 155 | + "Search shallowest nodes in the search tree first; using best-first." |
| 156 | + return best_first_search(problem, f=len) |
| 157 | + |
| 158 | + |
| 159 | +def depth_first_bfs(problem): |
| 160 | + "Search deepest nodes in the search tree first; using best-first." |
| 161 | + return best_first_search(problem, f=lambda n: -len(n)) |
| 162 | + |
| 163 | + |
| 164 | +def is_cycle(node, k=30): |
| 165 | + "Does this node form a cycle of length k or less?" |
| 166 | + def find_cycle(ancestor, k): |
| 167 | + return (ancestor is not None and k > 0 and |
| 168 | + (ancestor.state == node.state or find_cycle(ancestor.parent, k - 1))) |
| 169 | + return find_cycle(node.parent, k) |
| 170 | + |
| 171 | +def breadth_first_search(problem): |
| 172 | + "Search shallowest nodes in the search tree first." |
| 173 | + node = Node(problem.initial) |
| 174 | + if problem.is_goal(problem.initial): |
| 175 | + return node |
| 176 | + frontier = FIFOQueue([node]) |
| 177 | + reached = {problem.initial} |
| 178 | + while frontier: |
| 179 | + node = frontier.pop() |
| 180 | + for child in expand(problem, node): |
| 181 | + s = child.state |
| 182 | + if problem.is_goal(s): |
| 183 | + return child |
| 184 | + if s not in reached: |
| 185 | + reached.add(s) |
| 186 | + frontier.appendleft(child) |
| 187 | + return failure |
| 188 | + |
| 189 | + |
| 190 | +def iterative_deepening_search(problem): |
| 191 | + "Do depth-limited search with increasing depth limits." |
| 192 | + for limit in range(1, sys.maxsize): |
| 193 | + result = depth_limited_search(problem, limit) |
| 194 | + if result != cutoff: |
| 195 | + return result |
| 196 | + |
| 197 | + |
| 198 | +def depth_limited_search(problem, limit=10): |
| 199 | + "Search deepest nodes in the search tree first." |
| 200 | + frontier = LIFOQueue([Node(problem.initial)]) |
| 201 | + result = failure |
| 202 | + while frontier: |
| 203 | + node = frontier.pop() |
| 204 | + if problem.is_goal(node.state): |
| 205 | + return node |
| 206 | + elif len(node) >= limit: |
| 207 | + result = cutoff |
| 208 | + elif not is_cycle(node): |
| 209 | + for child in expand(problem, node): |
| 210 | + frontier.append(child) |
| 211 | + return result |
| 212 | + |
| 213 | + |
| 214 | +def depth_first_recursive_search(problem, node=None): |
| 215 | + if node is None: |
| 216 | + node = Node(problem.initial) |
| 217 | + if problem.is_goal(node.state): |
| 218 | + return node |
| 219 | + elif is_cycle(node): |
| 220 | + return failure |
| 221 | + else: |
| 222 | + for child in expand(problem, node): |
| 223 | + result = depth_first_recursive_search(problem, child) |
| 224 | + if result: |
| 225 | + return result |
| 226 | + return failure |
| 227 | + |
| 228 | +class Map: |
| 229 | + """A map of places in a 2D world: a graph with vertexes and links between them. |
| 230 | + In `Map(links, locations)`, `links` can be either [(v1, v2)...] pairs, |
| 231 | + or a {(v1, v2): distance...} dict. Optional `locations` can be {v1: (x, y)} |
| 232 | + If `directed=False` then for every (v1, v2) link, we add a (v2, v1) link.""" |
| 233 | + |
| 234 | + def __init__(self, links, locations=None, directed=False): |
| 235 | + if not hasattr(links, 'items'): # Distances are 1 by default |
| 236 | + links = {link: 1 for link in links} |
| 237 | + if not directed: |
| 238 | + for (v1, v2) in list(links): |
| 239 | + links[v2, v1] = links[v1, v2] |
| 240 | + self.distances = links |
| 241 | + self.neighbors = multimap(links) |
| 242 | + self.locations = locations or defaultdict(lambda: (0, 0)) |
| 243 | + |
| 244 | + |
| 245 | +def multimap(pairs) -> dict: |
| 246 | + "Given (key, val) pairs, make a dict of {key: [val,...]}." |
| 247 | + result = defaultdict(list) |
| 248 | + for key, val in pairs: |
| 249 | + result[key].append(val) |
| 250 | + return result |
| 251 | + |
| 252 | +class RouteProblem(Problem): |
| 253 | + """A problem to find a route between locations on a `Map`. |
| 254 | + Create a problem with RouteProblem(start, goal, map=Map(...)}). |
| 255 | + States are the vertexes in the Map graph; actions are destination states.""" |
| 256 | + |
| 257 | + def actions(self, state): |
| 258 | + """The places neighboring `state`.""" |
| 259 | + return self.map.neighbors[state] |
| 260 | + |
| 261 | + def result(self, state, action): |
| 262 | + """Go to the `action` place, if the map says that is possible.""" |
| 263 | + return action if action in self.map.neighbors[state] else state |
| 264 | + |
| 265 | + def action_cost(self, s, action, s1): |
| 266 | + """The distance (cost) to go from s to s1.""" |
| 267 | + return self.map.distances[s, s1] |
| 268 | + |
| 269 | + def h(self, node): |
| 270 | + "Straight-line distance between state and the goal." |
| 271 | + locs = self.map.locations |
| 272 | + return straight_line_distance(locs[node.state], locs[self.goal]) |
| 273 | + |
| 274 | + |
| 275 | +def straight_line_distance(A, B): |
| 276 | + "Straight-line distance between two points." |
| 277 | + return sum(abs(a - b)**2 for (a, b) in zip(A, B)) ** 0.5 |
0 commit comments