@@ -6,14 +6,14 @@
 
 
 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
-on the CartPole-v0 task from the `OpenAI Gym <https://www.gymlibrary.ml/>`__.
+on the CartPole-v0 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
 
 **Task**
 
 The agent has to decide between two actions - moving the cart left or
 right - so that the pole attached to it stays upright. You can find an
 official leaderboard with various algorithms and visualizations at the
-`Gym website <https://www.gymlibrary.ml/environments/classic_control/cart_pole>`__.
+`Gym website <https://www.gymlibrary.dev/environments/classic_control/cart_pole>`__.
 
 .. figure:: /_static/img/cartpole.gif
    :alt: cartpole
@@ -74,7 +74,7 @@
 import torchvision.transforms as T
 
 
-env = gym.make('CartPole-v0').unwrapped
+env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped
 
 # set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
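
With the render mode chosen once in gym.make, later render calls no longer take a mode argument. A minimal sketch of the construction step, assuming a gym release around 0.25 where both keyword arguments exist:

    import gym

    # gym ~0.25 (assumed version): render_mode is fixed at construction and
    # new_step_api switches step() to the five-value return signature.
    env = gym.make('CartPole-v0', new_step_api=True,
                   render_mode='single_rgb_array').unwrapped
    env.reset()
    frame = env.render()  # RGB array of the current frame, no mode= needed
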
@@ -254,7 +254,7 @@ def get_cart_location(screen_width):
 def get_screen():
     # Returned screen requested by gym is 400x600x3, but is sometimes larger
     # such as 800x1200x3. Transpose it into torch order (CHW).
-    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
+    screen = env.render().transpose((2, 0, 1))
     # Cart is in the lower half, so strip off the top and bottom of the screen
     _, screen_height, screen_width = screen.shape
     screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
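
Because the render mode now lives on the environment, env.render() hands back the RGB frame directly. A small illustrative check of the shapes this function works with (the concrete numbers are only an example; env is the environment created above):

    # Frame comes back as a (height, width, 3) array; the transpose puts it
    # in torch's channel-first (C, H, W) order.
    frame = env.render()                 # e.g. shape (400, 600, 3)
    screen = frame.transpose((2, 0, 1))  # e.g. shape (3, 400, 600)
    assert screen.shape[0] == 3
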
@@ -461,7 +461,7 @@ def optimize_model():
     for t in count():
         # Select and perform an action
         action = select_action(state)
-        _, reward, done, _ = env.step(action.item())
+        _, reward, done, _, _ = env.step(action.item())
         reward = torch.tensor([reward], device=device)
 
         # Observe new state
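
Under the new step API, env.step returns five values, and the done variable kept above corresponds to the terminated flag. A sketch of the full unpacking, assuming the gym ~0.25 signature (env and action come from the tutorial's training loop):

    # New step API (assumed gym ~0.25): observation, reward, terminated,
    # truncated, info. The diff keeps only reward and the terminated flag.
    obs, reward, terminated, truncated, info = env.step(action.item())
    done = terminated or truncated  # treat either ending condition as episode end
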