@@ -31,7 +31,10 @@ class Environment(object):
31
31
valid_headings = [(1 , 0 ), (0 , - 1 ), (- 1 , 0 ), (0 , 1 )] # ENWS
32
32
hard_time_limit = - 100 # even if enforce_deadline is False, end trial when deadline reaches this value (to avoid deadlocks)
33
33
34
- def __init__ (self ):
34
+ def __init__ (self , num_dummies = 3 ):
35
+ self .num_dummies = num_dummies # no. of dummy agents
36
+
37
+ # Initialize simulation variables
35
38
self .done = False
36
39
self .t = 0
37
40
self .agent_states = OrderedDict ()
@@ -55,14 +58,30 @@ def __init__(self):
55
58
self .roads .append ((a , b ))
56
59
57
60
# Dummy agents
58
- self .num_dummies = 3 # no. of dummy agents
59
61
for i in xrange (self .num_dummies ):
60
62
self .create_agent (DummyAgent )
61
63
62
- # Primary agent
64
+ # Primary agent and associated parameters
63
65
self .primary_agent = None # to be set explicitly
64
66
self .enforce_deadline = False
65
67
68
+ # Step data (updated after each environment step)
69
+ self .step_data = {
70
+ 't' : 0 ,
71
+ 'deadline' : 0 ,
72
+ 'waypoint' : None ,
73
+ 'inputs' : None ,
74
+ 'action' : None ,
75
+ 'reward' : 0.0
76
+ }
77
+
78
+ # Trial data (updated at the end of each trial)
79
+ self .trial_data = {
80
+ 'net_reward' : 0.0 , # total reward earned in current trial
81
+ 'final_deadline' : None , # deadline value (time remaining)
82
+ 'success' : 0 # whether the agent reached the destination in time
83
+ }
84
+
66
85
def create_agent (self , agent_class , * args , ** kwargs ):
67
86
agent = agent_class (self , * args , ** kwargs )
68
87
self .agent_states [agent ] = {'location' : random .choice (self .intersections .keys ()), 'heading' : (0 , 1 )}
@@ -101,6 +120,11 @@ def reset(self):
101
120
'destination' : destination if agent is self .primary_agent else None ,
102
121
'deadline' : deadline if agent is self .primary_agent else None }
103
122
agent .reset (destination = (destination if agent is self .primary_agent else None ))
123
+ if agent is self .primary_agent :
124
+ # Reset metrics for this trial (step data will be set during the step)
125
+ self .trial_data ['net_reward' ] = 0.0
126
+ self .trial_data ['final_deadline' ] = deadline
127
+ self .trial_data ['success' ] = 0
104
128
105
129
def step (self ):
106
130
#print "Environment.step(): t = {}".format(self.t) # [debug]
@@ -113,7 +137,9 @@ def step(self):
113
137
for agent in self .agent_states .iterkeys ():
114
138
agent .update (self .t )
115
139
116
- self .t += 1
140
+ if self .done :
141
+ return # primary agent might have reached destination
142
+
117
143
if self .primary_agent is not None :
118
144
agent_deadline = self .agent_states [self .primary_agent ]['deadline' ]
119
145
if agent_deadline <= self .hard_time_limit :
@@ -124,6 +150,8 @@ def step(self):
124
150
print "Environment.step(): Primary agent ran out of time! Trial aborted."
125
151
self .agent_states [self .primary_agent ]['deadline' ] = agent_deadline - 1
126
152
153
+ self .t += 1
154
+
127
155
def sense (self , agent ):
128
156
assert agent in self .agent_states , "Unknown agent!"
129
157
@@ -150,7 +178,7 @@ def sense(self, agent):
150
178
if left != 'forward' : # we don't want to override left == 'forward'
151
179
left = other_heading
152
180
153
- return {'light' : light , 'oncoming' : oncoming , 'left' : left , 'right' : right } # TODO: make this a namedtuple
181
+ return {'light' : light , 'oncoming' : oncoming , 'left' : left , 'right' : right }
154
182
155
183
def get_deadline (self , agent ):
156
184
return self .agent_states [agent ]['deadline' ] if agent is self .primary_agent else None
@@ -163,7 +191,7 @@ def act(self, agent, action):
163
191
location = state ['location' ]
164
192
heading = state ['heading' ]
165
193
light = 'green' if (self .intersections [location ].state and heading [1 ] != 0 ) or ((not self .intersections [location ].state ) and heading [0 ] != 0 ) else 'red'
166
- sense = self .sense (agent )
194
+ inputs = self .sense (agent )
167
195
168
196
# Move agent if within bounds and obeys traffic rules
169
197
reward = 0 # reward/penalty
@@ -172,12 +200,12 @@ def act(self, agent, action):
172
200
if light != 'green' :
173
201
move_okay = False
174
202
elif action == 'left' :
175
- if light == 'green' and (sense ['oncoming' ] == None or sense ['oncoming' ] == 'left' ):
203
+ if light == 'green' and (inputs ['oncoming' ] == None or inputs ['oncoming' ] == 'left' ):
176
204
heading = (heading [1 ], - heading [0 ])
177
205
else :
178
206
move_okay = False
179
207
elif action == 'right' :
180
- if light == 'green' or sense [ ' left' ] != 'straight' :
208
+ if light == 'green' or ( inputs [ 'oncoming' ] != ' left' and inputs [ 'left' ] != 'forward' ) :
181
209
heading = (- heading [1 ], heading [0 ])
182
210
else :
183
211
move_okay = False
@@ -203,11 +231,22 @@ def act(self, agent, action):
203
231
if state ['location' ] == state ['destination' ]:
204
232
if state ['deadline' ] >= 0 :
205
233
reward += 10 # bonus
234
+ self .trial_data ['success' ] = 1
206
235
self .done = True
207
236
print "Environment.act(): Primary agent has reached destination!" # [debug]
208
237
self .status_text = "state: {}\n action: {}\n reward: {}" .format (agent .get_state (), action , reward )
209
238
#print "Environment.act() [POST]: location: {}, heading: {}, action: {}, reward: {}".format(location, heading, action, reward) # [debug]
210
239
240
+ # Update metrics
241
+ self .step_data ['t' ] = self .t
242
+ self .trial_data ['final_deadline' ] = self .step_data ['deadline' ] = state ['deadline' ]
243
+ self .step_data ['waypoint' ] = agent .get_next_waypoint ()
244
+ self .step_data ['inputs' ] = inputs
245
+ self .step_data ['action' ] = action
246
+ self .step_data ['reward' ] = reward
247
+ self .trial_data ['net_reward' ] += reward
248
+ print "Environment.act(): Step data: {}" .format (self .step_data ) # [debug]
249
+
211
250
return reward
212
251
213
252
def compute_dist (self , a , b ):
0 commit comments