Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3541164
initial commit
vincentpierre Dec 9, 2017
7cc6d2c
multi-training fixes
vincentpierre Dec 11, 2017
2d38f4b
fixes for python3
vincentpierre Dec 11, 2017
eea344f
tracking possible memory leak
vincentpierre Dec 12, 2017
04fac45
improvement on the buffer class
vincentpierre Dec 13, 2017
46f54b3
added missing doc
vincentpierre Dec 13, 2017
4b71ddb
bug fixes on ppo
vincentpierre Dec 14, 2017
63bb9d5
implemented recurrent state encoding for continuous state
vincentpierre Jan 4, 2018
fec4dfc
merge development-0.3 into dev-trainer
vincentpierre Jan 4, 2018
51c069d
The bytes file will be generated if the user interrupts the learning …
vincentpierre Jan 8, 2018
8bde510
changed the names of the files. Please use **python learn.py** instea…
vincentpierre Jan 8, 2018
a35583a
As Hunter suggested, the available graphscopes are now displayed afte…
vincentpierre Jan 8, 2018
097ea19
Merge remote-tracking branch 'origin/development-0.3' into dev-trainer
vincentpierre Jan 9, 2018
68a6ff3
modifications to support observations and discrete input
vincentpierre Jan 9, 2018
436b7fc
reverted commit on curriculum
vincentpierre Jan 10, 2018
6b6af71
Internal Brain will not complain if the graph scope does not end with /
vincentpierre Jan 10, 2018
b746d1b
fix for tensorflow r1.4
vincentpierre Jan 12, 2018
c346f54
fix from issue 249
vincentpierre Jan 16, 2018
384651a
removed old comments
vincentpierre Jan 17, 2018
30e7c0a
fixes:
vincentpierre Jan 17, 2018
508ad7c
put buffer back into trainers
vincentpierre Jan 17, 2018
9bd01f8
add pyyaml to the requirements.txt
vincentpierre Jan 18, 2018
5d5d8dc
Implemented the PR of asolano for multi observations with the dev-tra…
vincentpierre Jan 18, 2018
4530370
removed PPO notebook, removed default trainer parameters from the lea…
vincentpierre Jan 18, 2018
407cb95
bug fix on the multi camera
vincentpierre Jan 18, 2018
f113573
added tests to make sure the trainers receive the right parameters
vincentpierre Jan 18, 2018
3b2fb9c
removed max_step from the learn.py parameters
vincentpierre Jan 18, 2018
7c87485
imitation trainer initial commit
vincentpierre Jan 19, 2018
9bc6d19
made a trainer abstract class
vincentpierre Jan 19, 2018
73ce025
indentation fix
vincentpierre Jan 19, 2018
9888af2
added some more tests
vincentpierre Jan 19, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
imitation trainer initial commit
  • Loading branch information
vincentpierre committed Jan 19, 2018
commit 7c87485ff554674aa24e25dcaae769c8c7e7814a
61 changes: 35 additions & 26 deletions python/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from trainers.ghost_trainer import GhostTrainer
from trainers.ppo_models import *
from trainers.ppo_trainer import PPOTrainer
from trainers.imitation_trainer import ImitationTrainer
from unityagents import UnityEnvironment, UnityEnvironmentException

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clean up imports according to pep8 for improved readability:

  • sort asciibetically
  • separate stdlib from 3rd party from our code

https://www.python.org/dev/peps/pep-0008/#imports


def get_progress():
Expand Down Expand Up @@ -112,6 +113,8 @@ def get_progress():
for brain_name in env.external_brain_names:
if 'is_ghost' not in trainer_parameters_dict[brain_name]:
trainer_parameters_dict[brain_name]['is_ghost'] = False
if 'is_imitation' not in trainer_parameters_dict[brain_name]:
trainer_parameters_dict[brain_name]['is_imitation'] = False
if trainer_parameters_dict[brain_name]['is_ghost']:
if trainer_parameters_dict[brain_name]['brain_to_copy'] not in env.external_brain_names:
raise UnityEnvironmentException("The external brain {0} could not be found in the environment "
Expand All @@ -120,6 +123,8 @@ def get_progress():
trainer_parameters_dict[brain_name]['original_brain_parameters'] = trainer_parameters_dict[
trainer_parameters_dict[brain_name]['brain_to_copy']]
trainers[brain_name] = GhostTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)
elif trainer_parameters_dict[brain_name]['is_imitation']:
trainers[brain_name] = ImitationTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)
else:
trainers[brain_name] = PPOTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)

Expand Down Expand Up @@ -171,7 +176,7 @@ def get_progress():
trainer.update_model()
# Write training statistics to tensorboard.
trainer.write_summary(env.curriculum.lesson_number)
if train_model:
if train_model and trainer.get_step <= trainer.get_max_steps:
trainer.increment_step()
trainer.update_last_reward()
if train_model and trainer.get_step <= trainer.get_max_steps:
Expand All @@ -184,31 +189,35 @@ def get_progress():
if global_step != 0 and train_model:
save_model(sess, model_path=model_path, steps=global_step, saver=saver)
except KeyboardInterrupt:
logger.info("Learning was interupted. Please wait while the graph is generated.")
save_model(sess, model_path=model_path, steps=global_step, saver=saver)
if train_model:
logger.info("Learning was interupted. Please wait while the graph is generated.")
save_model(sess, model_path=model_path, steps=global_step, saver=saver)
pass
env.close()
graph_name = (env_name.strip()
.replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', ''))
graph_name = os.path.basename(os.path.normpath(graph_name))
nodes = []
scopes = []
for brain_name in trainers.keys():
if trainers[brain_name].graph_scope is not None:
scope = trainers[brain_name].graph_scope + '/'
if scope == '/':
scope = ''
scopes += [scope]
if not trainers[brain_name].parameters["use_recurrent"]:
nodes +=[scope + x for x in ["action","value_estimate","action_probs"]]
else:
nodes +=[scope + x for x in ["action","value_estimate","action_probs","recurrent_out"]]
export_graph(model_path, graph_name, target_nodes=','.join(nodes))
if len(scopes) > 1:
logger.info("List of available scopes :")
for scope in scopes:
logger.info("\t" + scope )
logger.info("List of nodes exported :")
for n in nodes:
logger.info("\t" + n)
if train_model:
graph_name = (env_name.strip()
.replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', ''))
graph_name = os.path.basename(os.path.normpath(graph_name))
nodes = []
scopes = []
for brain_name in trainers.keys():
if trainers[brain_name].graph_scope is not None:
scope = trainers[brain_name].graph_scope + '/'
if scope == '/':
scope = ''
scopes += [scope]
if trainers[brain_name].parameters["is_imitation"]:
nodes +=[scope + x for x in ["action"]]
elif not trainers[brain_name].parameters["use_recurrent"]:
nodes +=[scope + x for x in ["action","value_estimate","action_probs"]]
else:
nodes +=[scope + x for x in ["action","value_estimate","action_probs","recurrent_out"]]
export_graph(model_path, graph_name, target_nodes=','.join(nodes))
if len(scopes) > 1:
logger.info("List of available scopes :")
for scope in scopes:
logger.info("\t" + scope )
logger.info("List of nodes exported :")
for n in nodes:
logger.info("\t" + n)

2 changes: 1 addition & 1 deletion python/trainers/ghost_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, sess, env, brain_name, trainer_parameters, training):
self.trainer_parameters = trainer_parameters

def __str__(self):
return '''Hypermarameters for {0}: \n{1}'''.format(
return '''Hypermarameters for the Ghost Trainer of brain {0}: \n{1}'''.format(
self.brain_name, '\n'.join(['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys]))


Expand Down
Loading