
Commit c18f55c

Author: Chris Elion

Improve test_simple.py and check discrete actions (#2345)

* discrete action coverage
* undo change
* rename test
* move test file
* Revert "move test file" (this reverts commit 2e72b2d)
* move files post merge

1 parent bcc2d1c · commit c18f55c

File tree

2 files changed: +26 -12 lines changed


ml-agents/mlagents/trainers/tests/test_environments/__init__.py

Whitespace-only changes.

ml-agents/mlagents/trainers/tests/test_environments/test_simple.py renamed to ml-agents/mlagents/trainers/tests/test_simple_rl.py

Lines changed: 26 additions & 12 deletions
@@ -1,6 +1,8 @@
-import yaml
 import math
+import random
 import tempfile
+import pytest
+import yaml
 from typing import Any, Dict


@@ -31,21 +33,25 @@ class Simple1DEnvironment(BaseUnityEnvironment):
     it reaches -1. The position is incremented by the action amount (clamped to [-step_size, step_size]).
     """

-    def __init__(self):
+    def __init__(self, use_discrete):
+        super().__init__()
+        self.discrete = use_discrete
         self._brains: Dict[str, BrainParameters] = {}
         self._brains[BRAIN_NAME] = BrainParameters(
             brain_name=BRAIN_NAME,
             vector_observation_space_size=OBS_SIZE,
             num_stacked_vector_observations=1,
             camera_resolutions=[],
-            vector_action_space_size=[1],
+            vector_action_space_size=[2] if use_discrete else [1],
             vector_action_descriptions=["moveDirection"],
-            vector_action_space_type=1,  # "continuous"
+            vector_action_space_type=0 if use_discrete else 1,
         )

         # state
         self.position = 0.0
         self.step_count = 0
+        self.random = random.Random(str(self._brains))
+        self.goal = random.choice([-1, 1])

     def step(
         self,
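Editor's note on the hunk above: the old line `vector_action_space_type=1,  # "continuous"` together with the new `0 if use_discrete else 1` shows that in this version of the API space type 0 is discrete and 1 is continuous, so the constructor now switches between a single discrete branch with two options and a one-dimensional continuous action. Also visible in the hunk: self.random is created as a seeded generator, but goal is drawn with the module-level random.choice. A minimal standalone sketch of the parameter mapping (the dict keys mirror the BrainParameters arguments above; it is illustrative only and not part of the commit):

def action_space_params(use_discrete: bool) -> dict:
    """Mirror of the branch encoded in __init__ above; purely illustrative."""
    if use_discrete:
        # One discrete branch with two choices: index 0 -> move left, index 1 -> move right.
        return {"vector_action_space_type": 0, "vector_action_space_size": [2]}
    # A single continuous action: the signed move amount.
    return {"vector_action_space_type": 1, "vector_action_space_size": [1]}

assert action_space_params(True)["vector_action_space_size"] == [2]
assert action_space_params(False)["vector_action_space_type"] == 1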
@@ -56,21 +62,23 @@ def step(
     ) -> AllBrainInfo:
         assert vector_action is not None

-        delta = vector_action[BRAIN_NAME][0][0]
+        if self.discrete:
+            act = vector_action[BRAIN_NAME][0][0]
+            delta = 1 if act else -1
+        else:
+            delta = vector_action[BRAIN_NAME][0][0]
         delta = clamp(delta, -STEP_SIZE, STEP_SIZE)
         self.position += delta
         self.position = clamp(self.position, -1, 1)
         self.step_count += 1
         done = self.position >= 1.0 or self.position <= -1.0
         if done:
-            reward = SUCCESS_REWARD * self.position
+            reward = SUCCESS_REWARD * self.position * self.goal
         else:
             reward = -TIME_PENALTY

         agent_info = AgentInfoProto(
-            stacked_vector_observation=[self.position] * OBS_SIZE,
-            reward=reward,
-            done=done,
+            stacked_vector_observation=[self.goal] * OBS_SIZE, reward=reward, done=done
         )

         if done:
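The step() changes are the core of the discrete-action coverage: a discrete action index becomes a full step left or right, a continuous action is used directly, and the terminal reward is positive only when the agent finishes on the goal side. A self-contained sketch of that arithmetic with assumed constant values (STEP_SIZE, SUCCESS_REWARD and TIME_PENALTY are defined elsewhere in the test file and do not appear in this diff):

STEP_SIZE = 0.2        # assumed value
SUCCESS_REWARD = 1.0   # assumed value
TIME_PENALTY = 0.01    # assumed value

def clamp(x, lo, hi):
    return min(hi, max(lo, x))

def step_once(position, action, goal, use_discrete):
    # Discrete: action index 0/1 becomes a full step left/right.
    # Continuous: the action value is the signed step itself.
    delta = (1 if action else -1) if use_discrete else action
    delta = clamp(delta, -STEP_SIZE, STEP_SIZE)
    position = clamp(position + delta, -1, 1)
    done = position >= 1.0 or position <= -1.0
    # Positive reward only when the episode ends on the goal side of the line.
    reward = SUCCESS_REWARD * position * goal if done else -TIME_PENALTY
    return position, reward, done

# Repeatedly stepping right toward a goal at +1 eventually earns a positive reward.
pos, rew, done = 0.0, 0.0, False
while not done:
    pos, rew, done = step_once(pos, 1, goal=1, use_discrete=True)
assert rew > 0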
@@ -85,6 +93,7 @@ def step(
     def _reset_agent(self):
         self.position = 0.0
         self.step_count = 0
+        self.goal = random.choice([-1, 1])

     def reset(
         self,
@@ -95,7 +104,7 @@ def reset(
         self._reset_agent()

         agent_info = AgentInfoProto(
-            stacked_vector_observation=[self.position] * OBS_SIZE,
+            stacked_vector_observation=[self.goal] * OBS_SIZE,
             done=False,
             max_step_reached=False,
         )
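With this hunk (and the matching change in step() above) the agent observes the goal direction instead of its own position, so the policy has to read the observation to learn which way to move; under the old reward of SUCCESS_REWARD * position, "always move right" would succeed without using the observation at all. A tiny illustration of the observation vector, with an assumed OBS_SIZE (the real constant is defined elsewhere in the test file):

OBS_SIZE = 1          # assumed size
goal = -1             # on reset, goal is drawn from random.choice([-1, 1])
observation = [goal] * OBS_SIZE
assert observation == [-1]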
@@ -121,7 +130,7 @@ def close(self):
         pass


-def test_simple():
+def _check_environment_trains(env):
     config = """
         default:
             trainer: ppo
@@ -167,11 +176,16 @@ def test_simple():
     )

     # Begin training
-    env = Simple1DEnvironment()
     env_manager = SimpleEnvManager(env)
     trainer_config = yaml.safe_load(config)
     tc.start_learning(env_manager, trainer_config)

     for brain_name, mean_reward in tc._get_measure_vals().items():
         assert not math.isnan(mean_reward)
         assert mean_reward > 0.99
+
+
+@pytest.mark.parametrize("use_discrete", [True, False])
+def test_simple_rl(use_discrete):
+    env = Simple1DEnvironment(use_discrete=use_discrete)
+    _check_environment_trains(env)
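The old test_simple() body becomes a reusable _check_environment_trains(env) helper, and the new parametrized entry point runs it once per action-space flavor. With standard pytest behavior the decorator expands into two independent test cases (test_simple_rl[True] and test_simple_rl[False]); a minimal sketch of the same pattern, with an illustrative test name that is not from the commit:

import pytest

@pytest.mark.parametrize("use_discrete", [True, False])
def test_flavors(use_discrete):
    # pytest collects this as test_flavors[True] and test_flavors[False],
    # each run as its own test case.
    assert isinstance(use_discrete, bool)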
