Commit 05dcde9

Merge branch 'master' into quant-fix
2 parents 54d3c0d + 9fb9f47

File tree

1 file changed: +5 -5 lines changed

intermediate_source/reinforcement_q_learning.py

+5 -5
@@ -6,14 +6,14 @@
 
 
 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
-on the CartPole-v0 task from the `OpenAI Gym <https://www.gymlibrary.ml/>`__.
+on the CartPole-v0 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
 
 **Task**
 
 The agent has to decide between two actions - moving the cart left or
 right - so that the pole attached to it stays upright. You can find an
 official leaderboard with various algorithms and visualizations at the
-`Gym website <https://www.gymlibrary.ml/environments/classic_control/cart_pole>`__.
+`Gym website <https://www.gymlibrary.dev/environments/classic_control/cart_pole>`__.
 
 .. figure:: /_static/img/cartpole.gif
    :alt: cartpole
@@ -74,7 +74,7 @@
 import torchvision.transforms as T
 
 
-env = gym.make('CartPole-v0').unwrapped
+env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped
 
 # set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
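
Context for this change: in gym 0.25.x, passing new_step_api=True opts in to the five-value step() return, and render_mode fixes how render() behaves at construction time instead of per call. A minimal sketch of the reconfigured environment, assuming gym 0.25.x (later releases drop the flag, make the five-tuple the default, and fold 'single_rgb_array' into 'rgb_array'):

    # Minimal sketch, assuming gym 0.25.x (new_step_api exists only there)
    import gym

    env = gym.make('CartPole-v0', new_step_api=True,
                   render_mode='single_rgb_array').unwrapped
    env.reset()
    # With new_step_api=True, step() returns five values instead of four:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()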
@@ -254,7 +254,7 @@ def get_cart_location(screen_width):
 def get_screen():
     # Returned screen requested by gym is 400x600x3, but is sometimes larger
     # such as 800x1200x3. Transpose it into torch order (CHW).
-    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
+    screen = env.render().transpose((2, 0, 1))
     # Cart is in the lower half, so strip off the top and bottom of the screen
     _, screen_height, screen_width = screen.shape
     screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
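
With the render mode chosen once in gym.make, render() no longer accepts a mode argument; under 'single_rgb_array' it returns one HWC uint8 frame rather than the list of frames that 'rgb_array' would accumulate in that release. A short sketch of the call that get_screen() now relies on, again assuming gym 0.25.x:

    import gym

    env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array')
    env.reset()
    frame = env.render()               # no mode= argument; fixed in gym.make above
    print(frame.shape)                 # e.g. (400, 600, 3): one HWC uint8 RGB frame
    chw = frame.transpose((2, 0, 1))   # torch-style CHW order, as in get_screen()
    env.close()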
@@ -461,7 +461,7 @@ def optimize_model():
     for t in count():
         # Select and perform an action
         action = select_action(state)
-        _, reward, done, _ = env.step(action.item())
+        _, reward, done, _, _ = env.step(action.item())
         reward = torch.tensor([reward], device=device)
 
         # Observe new state
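
The five-value return splits the old done flag into terminated (a terminal state was reached) and truncated (the episode was cut off, e.g. by a time limit). The one-liner above keeps terminated under the old name done and discards truncated, which is likely safe here because .unwrapped strips the TimeLimit wrapper that would set it; code that keeps the wrappers usually combines both flags. A sketch, assuming gym 0.25.x:

    import gym

    env = gym.make('CartPole-v0', new_step_api=True)
    env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated   # old-style done is the OR of the two flags
    env.close()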
