Commit 9faf2c6

Kaixhin authored and soumith committed
Update RL examples to use torch.distributions instead of reinforce
1 parent e0d33a6 commit 9faf2c6

File tree

2 files changed: +24 additions, -21 deletions
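
Both files make the same change: instead of calling probs.multinomial() and later action.reinforce(reward) on a stochastic Variable, the policy output is wrapped in a torch.distributions object, an action is sampled from it, and its log-probability is stored so the surrogate loss can be written out explicitly. A minimal sketch of that pattern follows; policy_net is a stand-in for the examples' model, and Categorical is the current name of the class this commit imports as Multinomial.

    from torch.distributions import Categorical  # current name for the class imported here as Multinomial

    def select_action(policy_net, state):
        # Forward pass produces a vector of action probabilities.
        probs = policy_net(state)
        # Wrap the probabilities in a distribution object.
        m = Categorical(probs)
        # Sample a stochastic action and keep its log-probability for the loss.
        action = m.sample()
        log_prob = m.log_prob(action)
        return action, log_prob

    # The stored log-probability later replaces action.reinforce(reward):
    #     loss = -(log_prob * discounted_return)
    #     loss.backward()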

reinforcement_learning/actor_critic.py

Lines changed: 14 additions & 13 deletions
@@ -8,8 +8,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import torch.autograd as autograd
 from torch.autograd import Variable
+from torch.distributions import Multinomial


 parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
@@ -29,7 +29,9 @@
 torch.manual_seed(args.seed)


-SavedAction = namedtuple('SavedAction', ['action', 'value'])
+SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
+
+
 class Policy(nn.Module):
     def __init__(self):
         super(Policy, self).__init__()
@@ -54,29 +56,28 @@ def forward(self, x):
 def select_action(state):
     state = torch.from_numpy(state).float().unsqueeze(0)
     probs, state_value = model(Variable(state))
-    action = probs.multinomial()
-    model.saved_actions.append(SavedAction(action, state_value))
+    m = Multinomial(probs)
+    action = m.sample()
+    model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
     return action.data


 def finish_episode():
     R = 0
     saved_actions = model.saved_actions
-    value_loss = 0
+    policy_loss, value_loss = 0, 0
     rewards = []
     for r in model.rewards[::-1]:
         R = r + args.gamma * R
         rewards.insert(0, R)
     rewards = torch.Tensor(rewards)
     rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
-    for (action, value), r in zip(saved_actions, rewards):
-        reward = r - value.data[0,0]
-        action.reinforce(reward)
+    for (log_prob, value), r in zip(saved_actions, rewards):
+        reward = r - value.data[0, 0]
+        policy_loss -= (log_prob * reward).sum()
         value_loss += F.smooth_l1_loss(value, Variable(torch.Tensor([r])))
     optimizer.zero_grad()
-    final_nodes = [value_loss] + list(map(lambda p: p.action, saved_actions))
-    gradients = [torch.ones(1)] + [None] * len(saved_actions)
-    autograd.backward(final_nodes, gradients)
+    (policy_loss + value_loss).backward()
     optimizer.step()
     del model.rewards[:]
     del model.saved_actions[:]
@@ -85,9 +86,9 @@ def finish_episode():
 running_reward = 10
 for i_episode in count(1):
     state = env.reset()
-    for t in range(10000): # Don't infinite loop while learning
+    for t in range(10000):  # Don't infinite loop while learning
         action = select_action(state)
-        state, reward, done, _ = env.step(action[0,0])
+        state, reward, done, _ = env.step(action[0, 0])
         if args.render:
             env.render()
         model.rewards.append(reward)

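In actor_critic.py the new finish_episode() builds both losses by hand: the critic's value estimate serves as a baseline, the policy loss accumulates -log_prob * advantage, the value loss accumulates a smooth L1 term against the discounted return, and a single backward() on their sum replaces the old autograd.backward(final_nodes, gradients) call. The sketch below restates that update outside the diff; the function name and arguments are illustrative, and it assumes each saved value is a 1-element tensor and returns is a list of floats.

    import torch
    import torch.nn.functional as F

    def actor_critic_loss(saved, returns):
        # saved:   list of (log_prob, value) pairs collected during the episode
        # returns: list of normalized discounted returns, one float per step
        policy_loss, value_loss = 0.0, 0.0
        for (log_prob, value), R in zip(saved, returns):
            advantage = R - value.item()                      # critic output used as a baseline
            policy_loss = policy_loss - log_prob * advantage  # actor (REINFORCE) term
            value_loss = value_loss + F.smooth_l1_loss(value.squeeze(), torch.tensor(R))
        # A single backward() on the sum drives both the actor and the critic.
        return policy_loss + value_loss
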
reinforcement_learning/reinforce.py

Lines changed: 10 additions & 8 deletions
@@ -7,8 +7,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import torch.autograd as autograd
 from torch.autograd import Variable
+from torch.distributions import Multinomial


 parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
@@ -50,23 +50,25 @@ def forward(self, x):
 def select_action(state):
     state = torch.from_numpy(state).float().unsqueeze(0)
     probs = policy(Variable(state))
-    action = probs.multinomial()
-    policy.saved_actions.append(action)
+    m = Multinomial(probs)
+    action = m.sample()
+    policy.saved_actions.append(m.log_prob(action))
     return action.data


 def finish_episode():
     R = 0
+    policy_loss = 0
     rewards = []
     for r in policy.rewards[::-1]:
         R = r + args.gamma * R
         rewards.insert(0, R)
     rewards = torch.Tensor(rewards)
     rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
-    for action, r in zip(policy.saved_actions, rewards):
-        action.reinforce(r)
+    for log_prob, r in zip(policy.saved_actions, rewards):
+        policy_loss -= (log_prob * reward).sum()
     optimizer.zero_grad()
-    autograd.backward(policy.saved_actions, [None for _ in policy.saved_actions])
+    policy_loss.backward()
     optimizer.step()
     del policy.rewards[:]
     del policy.saved_actions[:]
@@ -75,9 +77,9 @@ def finish_episode():
 running_reward = 10
 for i_episode in count(1):
     state = env.reset()
-    for t in range(10000): # Don't infinite loop while learning
+    for t in range(10000):  # Don't infinite loop while learning
         action = select_action(state)
-        state, reward, done, _ = env.step(action[0,0])
+        state, reward, done, _ = env.step(action[0, 0])
         if args.render:
             env.render()
         policy.rewards.append(reward)

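Both scripts share the same return computation, which the diff leaves untouched: discounted returns are accumulated backwards through the episode and then normalized before being multiplied by the stored log-probabilities. A small, self-contained version of that step is sketched below; discounted_returns and the gamma default are illustrative names, not identifiers from the files.

    import numpy as np
    import torch

    def discounted_returns(rewards, gamma=0.99):
        # Walk the episode backwards: R_t = r_t + gamma * R_{t+1}
        R, returns = 0.0, []
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        # Normalizing keeps the scale of the policy gradient stable across episodes.
        return (returns - returns.mean()) / (returns.std() + np.finfo(np.float32).eps)

    # The stored log-probabilities then turn these returns into a differentiable loss:
    #     policy_loss = -sum(lp * R for lp, R in zip(saved_log_probs, returns))
    #     policy_loss.backward()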