Commit 9fd930c

Fixed RL examples to work with new gym API
1 parent 32f0c1e

2 files changed: +10 −9
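
For context, the "new gym API" referenced in the commit message is the gym 0.26 interface: seeding moves from env.seed() into an optional seed argument on reset(), reset() returns an (observation, info) pair instead of a bare observation, and step() returns five values, splitting the old done flag into terminated and truncated. A minimal before/after sketch (assuming gym >= 0.26 is installed):

    import gym

    env = gym.make('CartPole-v1')

    # old API: env.seed(0); state = env.reset()
    state, info = env.reset(seed=0)

    action = env.action_space.sample()

    # old API: state, reward, done, info = env.step(action)
    state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated

    env.close()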

reinforcement_learning/actor_critic.py

Lines changed: 7 additions & 7 deletions
@@ -25,7 +25,7 @@


 env = gym.make('CartPole-v1')
-env.seed(args.seed)
+env.reset(seed=args.seed)
 torch.manual_seed(args.seed)


@@ -56,7 +56,7 @@ def forward(self, x):
         """
         x = F.relu(self.affine1(x))

-        # actor: choses action to take from state s_t
+        # actor: choses action to take from state s_t
         # by returning probability of each action
         action_prob = F.softmax(self.action_head(x), dim=-1)

@@ -65,7 +65,7 @@ def forward(self, x):

         # return values for both actor and critic as a tuple of 2 values:
         # 1. a list with the probability of each action over the action space
-        # 2. the value from state s_t
+        # 2. the value from state s_t
         return action_prob, state_values


@@ -113,7 +113,7 @@ def finish_episode():
     for (log_prob, value), R in zip(saved_actions, returns):
         advantage = R - value.item()

-        # calculate actor (policy) loss
+        # calculate actor (policy) loss
         policy_losses.append(-log_prob * advantage)

         # calculate critic (value) loss using L1 smooth loss
@@ -141,18 +141,18 @@ def main():
     for i_episode in count(1):

         # reset environment and episode reward
-        state = env.reset()
+        state, _ = env.reset()
         ep_reward = 0

-        # for each episode, only run 9999 steps so that we don't
+        # for each episode, only run 9999 steps so that we don't
         # infinite loop while learning
         for t in range(1, 10000):

             # select action from policy
             action = select_action(state)

             # take the action
-            state, reward, done, _ = env.step(action)
+            state, reward, done, _, _ = env.step(action)

             if args.render:
                 env.render()
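
A design note on the step() change above: the unpacking state, reward, done, _, _ binds done to the new terminated flag and discards truncated, so an episode cut off by CartPole-v1's built-in time limit is not flagged as finished here. A hedged alternative, not what this commit does, that ends the episode on either signal:

    import gym

    env = gym.make('CartPole-v1')
    state, _ = env.reset(seed=0)
    done = False
    while not done:
        action = env.action_space.sample()  # stand-in for select_action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated      # also stop on a time-limit cutoff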

reinforcement_learning/reinforce.py

Lines changed: 3 additions & 2 deletions
@@ -81,10 +81,11 @@ def finish_episode():
 def main():
     running_reward = 10
     for i_episode in count(1):
-        state, ep_reward = env.reset(), 0
+        state, _ = env.reset()
+        ep_reward = 0
         for t in range(1, 10000):  # Don't infinite loop while learning
             action = select_action(state)
-            state, reward, done, _ = env.step(action)
+            state, reward, done, _, _ = env.step(action)
             if args.render:
                 env.render()
             policy.rewards.append(reward)
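
The reinforce.py fix also has to split the old one-liner: under the new API, state, ep_reward = env.reset(), 0 would stuff the whole (observation, info) pair into state, because reset() itself now returns a tuple. A tiny sketch of the failure mode and the fix (assuming gym >= 0.26):

    import gym

    env = gym.make('CartPole-v1')

    # old one-liner under the new API: state ends up as (obs, info), not obs
    state, ep_reward = env.reset(), 0
    assert isinstance(state, tuple)

    # fixed form: unpack the pair, then initialize the episode reward separately
    state, _ = env.reset()
    ep_reward = 0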
