Conversation

@vincentpierre (Contributor)

Modifying the trainer class to be more flexible.
Added some TODO elements.
Added inline docs to the buffer.
Enables multi brain training.

Note that work still needs to be done for discrete actions and observations.
If you want to use recurrent state encoding, pass --use-recurrent and --sequence-length=<n> in the options of ppo.
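A hypothetical invocation (the script name, environment name, and sequence length here are placeholders, not prescribed by this PR):

python ppo.py <env-name> --use-recurrent --sequence-length=32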
@vincentpierre changed the title from "initial commit" to "Recurrent state encoder" on Jan 4, 2018
@vincentpierre changed the title from "Recurrent state encoder" to "Multi Brain Training and Recurrent state encoder" on Jan 9, 2018
improved the curriculum (reset now takes the lesson, not the progress, as input)
improved learn.py: it now uses a configuration file
the graph scope is now displayed after training
if there is only one brain, an empty graph scope is used
improved CoreInternalBrain so that recurrent_in and recurrent_out are now used
@romerocesar self-assigned this on Jan 17, 2018
@romerocesar left a comment

throughout the code use spaces instead of tabs

import json
from trainers.ppo_models import *
from trainers.ppo_trainer import Trainer
from unityagents import UnityEnvironment, UnityEnvironmentException

clean up imports according to pep8 for improved readability:

  • sort asciibetically
  • separate stdlib from 3rd party from our code

https://www.python.org/dev/peps/pep-0008/#imports
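Applied to the import block above, a possible PEP 8 grouping (the docopt import is an assumption, inferred from its use later in the file):

import json

from docopt import docopt

from trainers.ppo_models import *
from trainers.ppo_trainer import Trainer
from unityagents import UnityEnvironment, UnityEnvironmentException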

python/learn.py Outdated




remove empty lines

python/learn.py Outdated
'''

options = docopt(_USAGE)
print(options)

fixme
don't print, log. Besides, logging all options should probably be log.debug and not sent to stdout every time

all it takes is logging.basicConfig() to add a basic console output to the root logger:

import logging

logging.basicConfig(level=logging.DEBUG)

options = docopt(_USAGE)
logging.debug(options)

python/learn.py Outdated

env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file)
env.curriculum.set_lesson_number(lesson)
print(str(env))

fixme
log > print
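With the basicConfig shown in the previous comment, this becomes a one-line change (or logger.info if a module-level logger is used):

logging.info(str(env))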

python/learn.py Outdated
else:
    return None

try:

don't throw functions in between code - put script logic under if __name__ == '__main__' for readability
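A minimal sketch of that layout (the abbreviated _USAGE and the main name are illustrative, not the file's actual contents):

import logging

from docopt import docopt

_USAGE = '''Usage: learn.py [options]'''  # abbreviated for the sketch

def main():
    options = docopt(_USAGE)
    logging.debug(options)
    # ... the rest of the script logic moves in here ...

if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    main()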

python/learn.py Outdated
# use_recurrent = options['--use-recurrent']
# sequence_length = int(options['--sequence-length'])
# summary_freq = int(options['--summary-freq'])
# run_path = str(options['--run-path'])

fixme is this dead code that should be removed or relevant documentation that should move into the module's docstring?

    self[k].reset_field()
except:
    print(k)

def __getitem__(self, key):

if key not in self.keys():
    self[key] = self.AgentBufferField()
return super(Buffer.AgentBuffer, self).__getitem__(key)

def check_length(self, key_list):

        return False
    l = len(self[key])
return True

def shuffle(self, key_list=None):

for key in key_list:
    if key not in self.keys():
        return False
    if ((l != None) and (l != len(self[key]))):

you can omit the left side of the and - it's always safe to compare l to the output of len(self[key]) since that call can never return None
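If the intent is "all listed fields have the same length", a set-based sketch makes that explicit (a hypothetical rewrite, assuming the keys were already checked for presence):

lengths = {len(self[key]) for key in key_list}
if len(lengths) > 1:
    return False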

@romerocesar left a comment

pending ppo_trainer.py

# self[key] += l
# self[key] = Buffer.AgentBuffer.AgentBufferField([self[key][i] for i in s])
# self[key].reorder(s)
def __init__(self):

def __init__(self):
    self.global_buffer = self.AgentBuffer()
    super(Buffer, self).__init__()

def __str__(self):

def __str__(self):
    return "global buffer :\n\t{0}\nlocal_buffers :\n{1}".format(str(self.global_buffer),
        '\n'.join(['\tagent {0} :{1}'.format(k, str(self[k])) for k in self.keys()]))

def __getitem__(self, key):

if key not in self.keys():
    self[key] = self.AgentBuffer()
return super(Buffer, self).__getitem__(key)

def append_BrainInfo(self, info):

stick to snake_case and not UpperCamelCase in function names:

https://www.python.org/dev/peps/pep-0008/#function-names

also, add blank lines between definitions:

https://www.python.org/dev/peps/pep-0008/#blank-lines
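For instance, the first point applied to the method above:

def append_brain_info(self, info):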

raise BufferException("This method is not yet implemented")
# TODO: Find how useful this would be
# TODO: Implementation
def reset_global(self):

agent_ids = list(self.keys())
for k in agent_ids:
    self[k].reset_agent()

def append_global(self, agent_id, key_list=None, batch_size=None, training_length=None):

key_list = self[agent_id].keys()
if not self[agent_id].check_length(key_list):
    raise BufferException("The length of the fields {0} for agent {1} where not of comparable length"
                          .format(key_list, agent_id))

consider using f-strings in python 3.6+

raise BufferException(f'The length of the fields {key_list} for agent {agent_id} where not of comparable length')

https://www.python.org/dev/peps/pep-0498/#abstract

p.s. that message is ambiguous because lengths can always be compared. what does it mean for those to not be comparable as the exception claims?

self.global_buffer[field_key].extend(
    self[agent_id][field_key].get_batch(batch_size=batch_size, training_length=training_length)
)

def append_all_agent_batch_to_global(self, key_list=None, batch_size=None, training_length=None):

self.append_global(agent_id, key_list, batch_size, training_length)


#TODO: Put these functions into a utils class

actually, remove the TODO. Utils classes are rarely justified because they are abused and become a soup of ambiguous functions w/o a home. Find a proper home for these functions rather than creating the technical debt that utils modules typically are

self.new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='new_reward')
self.update_reward = tf.assign(self.last_reward, self.new_reward)

def create_recurrent_encoder(self, s_size, input_state):

fixme needs docstring

@vincentpierre (Author):
added docstring

@romerocesar left a comment

looks like there are many new functions w/o unit tests. We should invest in writing those tests to avoid creating technical debt


from trainers.buffer import *
from trainers.ppo_models import *
import logging

clean imports by pep8: https://www.python.org/dev/peps/pep-0008/#imports

should read:

import logging
import os

import numpy as np
import tensorflow as tf

from trainers.buffer import *
from trainers.ppo_models import *

from trainers.buffer import *
from trainers.ppo_models import *
import logging
logger = logging.getLogger("unityagents")

blank line between imports and first line of code




class Trainer(object):

class needs docstring

class Trainer(object):
    def __init__(self, sess, env, brain_name, trainer_parameters, training):
        """
        Responsible for collecting experiences and training PPO model.

should the class be called PPOTrainer if this is specific to PPO as the docstring suggests? otherwise fix docstring

self.use_states = (env.brains[brain_name].state_space_size > 0)
self.summary_path = trainer_parameters['summary_path']
if not os.path.exists(self.summary_path):
    os.makedirs(self.summary_path)

this can fail for permissions or filesystem issues; consider catching the exception and re-raising with more context for better troubleshooting
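A sketch of such a guard (the exception type re-raised here is a placeholder choice, not part of the diff):

try:
    os.makedirs(self.summary_path)
except OSError as e:
    raise RuntimeError(
        "Could not create the summary directory {0}: {1}".format(self.summary_path, e))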

@vincentpierre (Author):

exception caught

get_gae(
    rewards=self.training_buffer[agent_id]['rewards'].get_batch(),
    value_estimates=self.training_buffer[agent_id]['value_estimates'].get_batch(),
    value_next=value_next, gamma=self.trainer_parameters['gamma'], lambd=self.trainer_parameters['lambd'])

keep lines under 100cols for readability
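For the call above, one possible re-wrap with no behavior change:

get_gae(
    rewards=self.training_buffer[agent_id]['rewards'].get_batch(),
    value_estimates=self.training_buffer[agent_id]['value_estimates'].get_batch(),
    value_next=value_next,
    gamma=self.trainer_parameters['gamma'],
    lambd=self.trainer_parameters['lambd'])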

batch_size = self.trainer_parameters['batch_size']
total_v, total_p = 0, 0
advantages = self.training_buffer.global_buffer['advantages'].get_batch()
self.training_buffer.global_buffer['advantages'].set(

it's very brittle for this code to have to know so much about the internals (e.g. keys) of the dictionary used by the buffer class and specifically to have to use string literals copy+pasted throughout the code like 'num_epoch'. Instead, the buffer class should expose a clean API that allows the implementation to change w/o its client (i.e. this code) needing to change.

I don't expect that refactor to be part of this PR, but we should at least track it and add a TODO; otherwise we're just adding technical debt here
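As a sketch of the direction (everything here is hypothetical, not part of this PR): a small facade that owns the keys, so client code never repeats string literals:

class TrainingBufferAPI:
    """Hypothetical facade over the raw dict-of-lists buffer."""

    def __init__(self):
        self._global = {'advantages': []}

    @property
    def advantages(self):
        return list(self._global['advantages'])

    def set_advantages(self, values):
        self._global['advantages'] = list(values)

buf = TrainingBufferAPI()
buf.set_advantages([0.3, -0.1, 0.5])
print(buf.advantages)  # clients never touch the 'advantages' key directly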

Saves training statistics to Tensorboard.
:param lesson_number: The lesson the trainer is at.
"""
if self.get_step() % self.trainer_parameters['summary_freq'] == 0 and self.get_step() != 0 and self.is_training:

try to stay within 100cols for readability

steps = self.get_step()
if len(self.stats['cumulative_reward']) > 0:
    mean_reward = np.mean(self.stats['cumulative_reward'])
    print("{0} : Step: {1}. Mean Reward: {2}. Std of Reward: {3}."

logger.debug instead of print. You already initialized a logger above
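i.e. something like the following (the exact format arguments are assumed from the surrounding excerpt):

logger.debug("{0} : Step: {1}. Mean Reward: {2}. Std of Reward: {3}.".format(
    self.brain_name, steps, mean_reward, np.std(self.stats['cumulative_reward'])))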

self.summary_writer.add_summary(summary, steps)
self.summary_writer.flush()

def write_text(self, key, input_dict):

consider renaming to include tf in the name given how specific this function is

clean up imports
removing old comments and adding new ones
replacing print with log
use yaml instead of json for the trainer parameters
@awjuliani (Contributor) left a comment

This is ready to merge into dev-0.3.

@awjuliani merged commit 85f5e63 into development-0.3 on Jan 19, 2018
@awjuliani deleted the dev-trainer branch on January 19, 2018