Skip to content

Commit 5d3de7f

Browse files
authored
Merge pull request udacity#40 from udacity/smartcab
smartcab: Fix penalty logic and add metrics reporting (experimental)
2 parents eb319ec + 60dbb78 commit 5d3de7f

File tree

3 files changed

+192
-10
lines changed

3 files changed

+192
-10
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import time
2+
from collections import OrderedDict
3+
4+
import numpy as np
5+
import pandas as pd
6+
import matplotlib.pyplot as plt
7+
8+
class Metric(object):
    """A named series of (x, y) samples, with helpers to draw and live-update
    a matplotlib line for the series."""

    def __init__(self, name):
        self.name = name
        self.reset()  # start with empty sample lists

    def reset(self):
        """Discard all collected samples."""
        self.xdata, self.ydata = [], []

    def collect(self, x, y):
        """Append a single (x, y) sample to the series."""
        self.xdata.append(x)
        self.ydata.append(y)

    def plot(self, ax):
        """Draw this series on the given axes as a dotted line, keeping a
        handle to the line so refresh() can update it in place."""
        self.plot_obj, = ax.plot(self.xdata, self.ydata, 'o-', label=self.name)

    def refresh(self):
        """Push the current samples into the previously created plot line."""
        self.plot_obj.set_data(self.xdata, self.ydata)
28+
29+
30+
class Reporter(object):
    """Collect metrics, analyze and report summary statistics.

    Metrics are stored in an OrderedDict of name -> Metric so that plot
    legends and summaries keep insertion order. When live_plot is True,
    an interactive matplotlib figure is kept up to date as samples arrive.
    """

    def __init__(self, metrics=None, live_plot=False):
        """Initialize the reporter.

        metrics: optional iterable of metric names to pre-register
            (default None; avoids the mutable-default-argument pitfall
            of the original `metrics=[]`).
        live_plot: if True, open an interactive plot updated on collect().
        """
        # Normalize early so the debug print below shows [] for the default,
        # matching the original behavior.
        metrics = [] if metrics is None else metrics

        self.metrics = OrderedDict()
        self.live_plot = live_plot

        for name in metrics:
            self.metrics[name] = Metric(name)

        if self.live_plot:
            if not plt.isinteractive():
                plt.ion()  # interactive mode so draws don't block
            self.plot()

        print("Reporter.__init__(): Initialized with metrics: {}".format(metrics))  # [debug]

    def collect(self, name, x, y):
        """Record sample (x, y) under metric `name`, creating it on first use."""
        if name not in self.metrics:
            self.metrics[name] = Metric(name)
            if self.live_plot:
                self.metrics[name].plot(self.ax)
                self.ax.legend()  # add new metric to legend
            print("Reporter.collect(): New metric added: {}".format(name))  # [debug]
        self.metrics[name].collect(x, y)
        if self.live_plot:
            self.metrics[name].refresh()

    def plot(self):
        """Create the figure/axes on first call; afterwards refresh lines."""
        if not hasattr(self, 'fig') or not hasattr(self, 'ax'):
            self.fig, self.ax = plt.subplots()
            for name in self.metrics:
                self.metrics[name].plot(self.ax)
            #self.ax.set_autoscalex_on(True)
            #self.ax.set_autoscaley_on(True)
            self.ax.grid()
            self.ax.legend()
        else:
            for name in self.metrics:
                self.metrics[name].refresh()
        self.refresh_plot()

    def refresh_plot(self):
        """Rescale axes to current data and flush pending draw events."""
        self.ax.relim()
        self.ax.autoscale_view()
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
        plt.draw()

    def show_plot(self):
        """Display the plot in blocking mode (holds until window is closed)."""
        if plt.isinteractive():
            plt.ioff()  # leave interactive mode so show() blocks
        self.plot()
        plt.show()

    def summary(self):
        """Return one pandas Series per metric, indexed by collected x values.

        Uses .items() instead of the Python-2-only .iteritems() so the
        module also runs under Python 3.
        """
        return [pd.Series(metric.ydata, index=metric.xdata, name=name)
                for name, metric in self.metrics.items()]

    def reset(self):
        """Clear all collected samples (and refresh plot lines if live)."""
        for name in self.metrics:
            self.metrics[name].reset()
            if self.live_plot:
                self.metrics[name].refresh()
93+
94+
95+
def test_reporter():
    """Interactive self-test: feed two randomly-sampled metrics into a
    Reporter with live plotting, then print summary statistics."""
    plt.ion()
    reporter = Reporter(metrics=['reward', 'flubber'], live_plot=True)
    for step in range(100):
        reporter.collect('reward', step, np.random.random())
        # 'flubber' is only sampled on every 10th step (offset by one)
        if step % 10 == 1:
            reporter.collect('flubber', step, np.random.random() * 2 + 1)
        reporter.refresh_plot()
        time.sleep(0.01)
    reporter.plot()

    summary = reporter.summary()
    print("Summary ({} metrics):-".format(len(summary)))
    for metric in summary:
        print("Name: {}, samples: {}, type: {}".format(metric.name, len(metric), metric.dtype))
        print("Mean: {}, s.d.: {}".format(metric.mean(), metric.std()))

    plt.ioff()
    plt.show()
113+
114+
115+
# Run the interactive self-test when this module is executed as a script.
if __name__ == '__main__':
    test_reporter()

projects/smartcab/smartcab/environment.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ class Environment(object):
3131
valid_headings = [(1, 0), (0, -1), (-1, 0), (0, 1)] # ENWS
3232
hard_time_limit = -100 # even if enforce_deadline is False, end trial when deadline reaches this value (to avoid deadlocks)
3333

34-
def __init__(self):
34+
def __init__(self, num_dummies=3):
35+
self.num_dummies = num_dummies # no. of dummy agents
36+
37+
# Initialize simulation variables
3538
self.done = False
3639
self.t = 0
3740
self.agent_states = OrderedDict()
@@ -55,14 +58,30 @@ def __init__(self):
5558
self.roads.append((a, b))
5659

5760
# Dummy agents
58-
self.num_dummies = 3 # no. of dummy agents
5961
for i in xrange(self.num_dummies):
6062
self.create_agent(DummyAgent)
6163

62-
# Primary agent
64+
# Primary agent and associated parameters
6365
self.primary_agent = None # to be set explicitly
6466
self.enforce_deadline = False
6567

68+
# Step data (updated after each environment step)
69+
self.step_data = {
70+
't': 0,
71+
'deadline': 0,
72+
'waypoint': None,
73+
'inputs': None,
74+
'action': None,
75+
'reward': 0.0
76+
}
77+
78+
# Trial data (updated at the end of each trial)
79+
self.trial_data = {
80+
'net_reward': 0.0, # total reward earned in current trial
81+
'final_deadline': None, # deadline value (time remaining)
82+
'success': 0 # whether the agent reached the destination in time
83+
}
84+
6685
def create_agent(self, agent_class, *args, **kwargs):
6786
agent = agent_class(self, *args, **kwargs)
6887
self.agent_states[agent] = {'location': random.choice(self.intersections.keys()), 'heading': (0, 1)}
@@ -101,6 +120,11 @@ def reset(self):
101120
'destination': destination if agent is self.primary_agent else None,
102121
'deadline': deadline if agent is self.primary_agent else None}
103122
agent.reset(destination=(destination if agent is self.primary_agent else None))
123+
if agent is self.primary_agent:
124+
# Reset metrics for this trial (step data will be set during the step)
125+
self.trial_data['net_reward'] = 0.0
126+
self.trial_data['final_deadline'] = deadline
127+
self.trial_data['success'] = 0
104128

105129
def step(self):
106130
#print "Environment.step(): t = {}".format(self.t) # [debug]
@@ -113,7 +137,9 @@ def step(self):
113137
for agent in self.agent_states.iterkeys():
114138
agent.update(self.t)
115139

116-
self.t += 1
140+
if self.done:
141+
return # primary agent might have reached destination
142+
117143
if self.primary_agent is not None:
118144
agent_deadline = self.agent_states[self.primary_agent]['deadline']
119145
if agent_deadline <= self.hard_time_limit:
@@ -124,6 +150,8 @@ def step(self):
124150
print "Environment.step(): Primary agent ran out of time! Trial aborted."
125151
self.agent_states[self.primary_agent]['deadline'] = agent_deadline - 1
126152

153+
self.t += 1
154+
127155
def sense(self, agent):
128156
assert agent in self.agent_states, "Unknown agent!"
129157

@@ -150,7 +178,7 @@ def sense(self, agent):
150178
if left != 'forward': # we don't want to override left == 'forward'
151179
left = other_heading
152180

153-
return {'light': light, 'oncoming': oncoming, 'left': left, 'right': right} # TODO: make this a namedtuple
181+
return {'light': light, 'oncoming': oncoming, 'left': left, 'right': right}
154182

155183
def get_deadline(self, agent):
156184
return self.agent_states[agent]['deadline'] if agent is self.primary_agent else None
@@ -163,7 +191,7 @@ def act(self, agent, action):
163191
location = state['location']
164192
heading = state['heading']
165193
light = 'green' if (self.intersections[location].state and heading[1] != 0) or ((not self.intersections[location].state) and heading[0] != 0) else 'red'
166-
sense = self.sense(agent)
194+
inputs = self.sense(agent)
167195

168196
# Move agent if within bounds and obeys traffic rules
169197
reward = 0 # reward/penalty
@@ -172,12 +200,12 @@ def act(self, agent, action):
172200
if light != 'green':
173201
move_okay = False
174202
elif action == 'left':
175-
if light == 'green' and (sense['oncoming'] == None or sense['oncoming'] == 'left'):
203+
if light == 'green' and (inputs['oncoming'] == None or inputs['oncoming'] == 'left'):
176204
heading = (heading[1], -heading[0])
177205
else:
178206
move_okay = False
179207
elif action == 'right':
180-
if light == 'green' or sense['left'] != 'straight':
208+
if light == 'green' or (inputs['oncoming'] != 'left' and inputs['left'] != 'forward'):
181209
heading = (-heading[1], heading[0])
182210
else:
183211
move_okay = False
@@ -203,11 +231,22 @@ def act(self, agent, action):
203231
if state['location'] == state['destination']:
204232
if state['deadline'] >= 0:
205233
reward += 10 # bonus
234+
self.trial_data['success'] = 1
206235
self.done = True
207236
print "Environment.act(): Primary agent has reached destination!" # [debug]
208237
self.status_text = "state: {}\naction: {}\nreward: {}".format(agent.get_state(), action, reward)
209238
#print "Environment.act() [POST]: location: {}, heading: {}, action: {}, reward: {}".format(location, heading, action, reward) # [debug]
210239

240+
# Update metrics
241+
self.step_data['t'] = self.t
242+
self.trial_data['final_deadline'] = self.step_data['deadline'] = state['deadline']
243+
self.step_data['waypoint'] = agent.get_next_waypoint()
244+
self.step_data['inputs'] = inputs
245+
self.step_data['action'] = action
246+
self.step_data['reward'] = reward
247+
self.trial_data['net_reward'] += reward
248+
print "Environment.act(): Step data: {}".format(self.step_data) # [debug]
249+
211250
return reward
212251

213252
def compute_dist(self, a, b):

projects/smartcab/smartcab/simulator.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
import random
44
import importlib
55

6+
import numpy as np
7+
8+
from analysis import Reporter
9+
610
class Simulator(object):
711
"""Simulates agents in a dynamic smartcab environment.
812
@@ -21,7 +25,7 @@ class Simulator(object):
2125
'orange' : (255, 128, 0)
2226
}
2327

24-
def __init__(self, env, size=None, update_delay=1.0, display=True):
28+
def __init__(self, env, size=None, update_delay=1.0, display=True, live_plot=False):
2529
self.env = env
2630
self.size = size if size is not None else ((self.env.grid_size[0] + 1) * self.env.block_size, (self.env.grid_size[1] + 1) * self.env.block_size)
2731
self.width, self.height = self.size
@@ -34,7 +38,7 @@ def __init__(self, env, size=None, update_delay=1.0, display=True):
3438
self.start_time = None
3539
self.current_time = 0.0
3640
self.last_updated = 0.0
37-
self.update_delay = update_delay
41+
self.update_delay = update_delay # duration between each step (in secs)
3842

3943
self.display = display
4044
if self.display:
@@ -59,8 +63,14 @@ def __init__(self, env, size=None, update_delay=1.0, display=True):
5963
self.display = False
6064
print "Simulator.__init__(): Error initializing GUI objects; display disabled.\n{}: {}".format(e.__class__.__name__, e)
6165

66+
# Setup metrics to report
67+
self.live_plot = live_plot
68+
self.rep = Reporter(metrics=['net_reward', 'avg_net_reward', 'final_deadline', 'success'], live_plot=self.live_plot)
69+
self.avg_net_reward_window = 10
70+
6271
def run(self, n_trials=1):
6372
self.quit = False
73+
self.rep.reset()
6474
for trial in xrange(n_trials):
6575
print "Simulator.run(): Trial {}".format(trial) # [debug]
6676
self.env.reset()
@@ -90,6 +100,7 @@ def run(self, n_trials=1):
90100
# Update environment
91101
if self.current_time - self.last_updated >= self.update_delay:
92102
self.env.step()
103+
# TODO: Log step data
93104
self.last_updated = self.current_time
94105

95106
# Render GUI and sleep
@@ -105,6 +116,22 @@ def run(self, n_trials=1):
105116
if self.quit:
106117
break
107118

119+
# Collect/update metrics
120+
self.rep.collect('net_reward', trial, self.env.trial_data['net_reward']) # total reward obtained in this trial
121+
self.rep.collect('avg_net_reward', trial, np.mean(self.rep.metrics['net_reward'].ydata[-self.avg_net_reward_window:])) # rolling mean of reward
122+
self.rep.collect('final_deadline', trial, self.env.trial_data['final_deadline']) # final deadline value (time remaining)
123+
self.rep.collect('success', trial, self.env.trial_data['success'])
124+
if self.live_plot:
125+
self.rep.refresh_plot() # autoscales axes, draws stuff and flushes events
126+
127+
# Report final metrics
128+
if self.display:
129+
self.pygame.display.quit() # need to shutdown pygame before showing metrics plot
130+
# TODO: Figure out why having both game and plot displays makes things crash!
131+
132+
if self.live_plot:
133+
self.rep.show_plot() # holds till user closes plot window
134+
108135
def render(self):
109136
# Clear screen
110137
self.screen.fill(self.bg_color)

0 commit comments

Comments
 (0)