|
| 1 | +from functools import partial |
| 2 | + |
1 | 3 | import numpy as np
|
2 | 4 | import pytest
|
3 | 5 |
|
4 |
| -from gym.spaces import Tuple |
| 6 | +from gym.spaces import Discrete, Tuple |
5 | 7 | from gym.vector.async_vector_env import AsyncVectorEnv
|
6 | 8 | from gym.vector.sync_vector_env import SyncVectorEnv
|
7 | 9 | from gym.vector.vector_env import VectorEnv
|
| 10 | +from tests.testing_env import GenericTestEnv |
8 | 11 | from tests.vector.utils import CustomSpace, make_env
|
9 | 12 |
|
10 | 13 |
|
@@ -58,3 +61,65 @@ def test_custom_space_vector_env():
|
58 | 61 |
|
59 | 62 | assert isinstance(env.single_action_space, CustomSpace)
|
60 | 63 | assert isinstance(env.action_space, Tuple)
|
| 64 | + |
| 65 | + |
@pytest.mark.parametrize(
    "vectoriser",
    (
        SyncVectorEnv,
        partial(AsyncVectorEnv, shared_memory=True),
        partial(AsyncVectorEnv, shared_memory=False),
    ),
    ids=["Sync", "Async with shared memory", "Async without shared memory"],
)
def test_final_obs_info(vectoriser):
    """Tests that the vector environments correctly return the final observation and info.

    A single sub-environment is vectorised. Its ``step`` echoes the action back
    as the observation for actions below 3, and for actions of 3 or more it
    returns observation ``0`` with ``terminated=True``. The test checks that a
    non-terminating step reports only the step info, and that a terminating
    step reports the reset info together with the ``final_observation`` and
    ``final_info`` entries.

    Args:
        vectoriser: Callable that builds the vector environment under test from
            a list of environment constructors.
    """

    def reset_fn(self, seed=None, options=None):
        # Marks the info dict so the test can tell a reset info from a step info.
        return 0, {"reset": True}

    def step_fn(self, action):
        # (observation, reward, terminated, truncated, info)
        terminated = action >= 3
        observation = 0 if terminated else action
        return observation, 0, terminated, False, {"action": action}

    def thunk():
        return GenericTestEnv(
            action_space=Discrete(4),
            observation_space=Discrete(4),
            reset_fn=reset_fn,
            step_fn=step_fn,
        )

    env = vectoriser([thunk])
    # Close the env even when an assertion fails, so that the async variants
    # do not leave worker processes behind.
    try:
        obs, info = env.reset()
        assert np.array_equal(obs, [0])
        assert set(info) == {"reset", "_reset"}
        assert np.array_equal(info["reset"], [True])
        assert np.array_equal(info["_reset"], [True])

        # Non-terminating steps: the observation echoes the action and the
        # info holds only the step entries.
        for action in (1, 2):
            obs, _, termination, _, info = env.step([action])
            assert np.array_equal(obs, [action])
            assert np.array_equal(termination, [False])
            assert set(info) == {"action", "_action"}
            assert np.array_equal(info["action"], [action])
            assert np.array_equal(info["_action"], [True])

        # Terminating step: the returned info must carry the reset marker plus
        # the observation and info produced by the terminating step itself.
        obs, _, termination, _, info = env.step([3])
        assert np.array_equal(obs, [0])
        assert np.array_equal(termination, [True])
        assert np.array_equal(info["reset"], [True])
        assert "final_observation" in info
        assert "final_info" in info
        assert np.array_equal(info["final_observation"], [0])
        assert info["final_info"][0] == {"action": 3}
    finally:
        env.close()
0 commit comments