Multi Brain Training and Recurrent state encoder #166
Conversation
Note that more work needs to be done for discrete actions and for observations. If you want to use recurrent state encoding, pass in --use-recurrent and --sequence-length=<n> in the options of ppo
…d of **python ppo.py**
- improved the curriculum (reset now takes a lesson, not progress, as input)
- improved learn.py: it now uses a configuration file
- the graph scope is now displayed after training; if there is only one brain, an empty graph scope is used
- improved CoreInternalBrain so that recurrent_in is now used
romerocesar
left a comment
throughout the code, use spaces instead of tabs
```python
import json
from trainers.ppo_models import *
from trainers.ppo_trainer import Trainer
from unityagents import UnityEnvironment, UnityEnvironmentException
```
clean up imports according to PEP 8 for improved readability:
- sort asciibetically
- separate stdlib from third-party from our code
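As a hedged sketch of the grouping this comment asks for (the third-party and project imports are left as comments so the snippet stays self-contained; module names mirror the ones quoted above):

```python
# Standard library, sorted asciibetically
import json
import logging
import os

# Third-party packages (illustrative; commented out to stay self-contained)
# import numpy as np
# import tensorflow as tf

# Our code
# from trainers.ppo_models import *
# from trainers.ppo_trainer import Trainer

# quick sanity check that the stdlib imports resolve
print(json.dumps({"ok": True}))
```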
python/learn.py
Outdated
*(quoted diff: several consecutive blank lines)*
remove empty lines
python/learn.py
Outdated
```python
'''

options = docopt(_USAGE)
print(options)
```
fixme
don't print, log. Besides, logging all options should probably be log.debug and not sent to stdout every time.
all it takes is logging.basicConfig() to add a basic console output to the root logger:

```python
import logging
logging.basicConfig(level=logging.DEBUG)

options = docopt(_USAGE)
logging.debug(options)
```
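For a self-contained illustration of this suggestion, here is a runnable sketch with a stand-in dict in place of `docopt(_USAGE)` (the option keys are placeholders, not the script's real usage string):

```python
import logging

# Configure a basic console handler on the root logger once, at startup.
logging.basicConfig(level=logging.DEBUG)

# Stand-in for options = docopt(_USAGE); real keys come from the usage string.
options = {'--use-recurrent': False, '--sequence-length': '32'}

# DEBUG rather than print, so normal runs stay quiet on stdout.
logging.debug(options)
```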
python/learn.py
Outdated
```python
env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file)
env.curriculum.set_lesson_number(lesson)
print(str(env))
```
fixme
log > print
python/learn.py
Outdated
```python
else:
    return None


try:
```
don't throw functions in between code; put script logic under `if __name__ == '__main__'` for readability
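A minimal sketch of the layout this comment asks for; the function and option names here are hypothetical, not the PR's:

```python
def parse_options():
    # helper functions stay at module level, above the entry point
    return {'--worker-id': 0}

def main():
    options = parse_options()
    return options

# script logic lives behind the guard so importing the module has no side effects
if __name__ == '__main__':
    main()
```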
python/learn.py
Outdated
```python
# use_recurrent = options['--use-recurrent']
# sequence_length = int(options['--sequence-length'])
# summary_freq = int(options['--summary-freq'])
# run_path = str(options['--run-path'])
```
fixme: is this dead code that should be removed, or relevant documentation that should move into the module's docstring?
```python
        self[k].reset_field()
    except:
        print(k)
def __getitem__(self, key):
```
add empty line between functions
```python
    if key not in self.keys():
        self[key] = self.AgentBufferField()
    return super(Buffer.AgentBuffer, self).__getitem__(key)
def check_length(self, key_list):
```
```python
        return False
    l = len(self[key])
    return True
def shuffle(self, key_list = None):
```
```python
for key in key_list:
    if key not in self.keys():
        return False
    if ((l != None) and (l != len(self[key]))):
```
you can omit the left side of the `and`: it's always safe to compare `l` to the output of `len(self[key])`, since that call can never return `None`
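Another hedged way to sidestep the `None` bookkeeping entirely is to compare all field lengths at once; this is a sketch of the idea, not the PR's implementation, and `check_length` here takes the buffer explicitly rather than being a method:

```python
def check_length(buffer, key_list):
    # all listed keys must exist in the buffer...
    if any(key not in buffer for key in key_list):
        return False
    # ...and hold fields of a single common length
    return len({len(buffer[key]) for key in key_list}) <= 1
```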
romerocesar
left a comment
pending ppo_trainer.py
```python
# self[key] += l
# self[key] = Buffer.AgentBuffer.AgentBufferField([self[key][i] for i in s])
# self[key].reorder(s)
def __init__(self):
```
```python
def __init__(self):
    self.global_buffer = self.AgentBuffer()
    super(Buffer, self).__init__()
def __str__(self):
```
```python
def __str__(self):
    return "global buffer :\n\t{0}\nlocal_buffers :\n{1}".format(str(self.global_buffer),
        '\n'.join(['\tagent {0} :{1}'.format(k, str(self[k])) for k in self.keys()]))
def __getitem__(self, key):
```
python/trainers/buffer.py
Outdated
```python
    if key not in self.keys():
        self[key] = self.AgentBuffer()
    return super(Buffer, self).__getitem__(key)
def append_BrainInfo(self, info):
```
stick to snake_case and not UpperCamelCase in function names
https://www.python.org/dev/peps/pep-0008/#function-names
blank lines
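As a quick illustration of the rename; the class body here is a hypothetical stand-in, not the PR's Buffer:

```python
class Buffer(dict):
    # snake_case per PEP 8; was append_BrainInfo in the PR
    def append_brain_info(self, info):
        self['last_info'] = info

buf = Buffer()
buf.append_brain_info({'agents': 2})
```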
python/trainers/buffer.py
Outdated
```python
    raise BufferException("This method is not yet implemented")
    # TODO: Find how useful this would be
    # TODO: Implementation
def reset_global(self):
```
python/trainers/buffer.py
Outdated
```python
    agent_ids = list(self.keys())
    for k in agent_ids:
        self[k].reset_agent()
def append_global(self, agent_id, key_list=None, batch_size=None, training_length=None):
```
```python
key_list = self[agent_id].keys()
if not self[agent_id].check_length(key_list):
    raise BufferException("The length of the fields {0} for agent {1} where not of comparable length"
                          .format(key_list, agent_id))
```
consider using f-strings in Python 3.6+:

```python
raise BufferException(f'The length of the fields {key_list} for agent {agent_id} where not of comparable length')
```

https://www.python.org/dev/peps/pep-0498/#abstract

p.s. that message is ambiguous because lengths can always be compared. what does it mean for those to not be comparable as the exception claims?
python/trainers/buffer.py
Outdated
```python
    self.global_buffer[field_key].extend(
        self[agent_id][field_key].get_batch(batch_size=batch_size, training_length=training_length)
    )
def append_all_agent_batch_to_global(self, key_list=None, batch_size=None, training_length=None):
```
python/trainers/buffer.py
Outdated
```python
self.append_global(agent_id, key_list, batch_size, training_length)


# TODO: Put these functions into a utils class
```
actually, remove the TODO. Utils classes are rarely justified because they get abused and become a soup of ambiguous functions without a home. Find a proper home for these functions rather than creating the technical debt that utils modules typically become.
```python
self.new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='new_reward')
self.update_reward = tf.assign(self.last_reward, self.new_reward)

def create_recurrent_encoder(self, s_size, input_state):
```
fixme needs docstring
added docstring
romerocesar
left a comment
looks like there are many new functions w/o unit tests. We should invest in writing those tests to avoid creating technical debt
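A sketch of what such a test could look like, written against a hypothetical stand-in class since the real `AgentBuffer` API in `trainers/buffer.py` may differ:

```python
import unittest

class FakeAgentBuffer(dict):
    # hypothetical stand-in for the per-agent buffer in trainers/buffer.py
    def append(self, key, value):
        self.setdefault(key, []).append(value)

    def check_length(self, key_list):
        # fields pass the check when they all share one length
        return len({len(self.get(k, [])) for k in key_list}) <= 1

class TestAgentBuffer(unittest.TestCase):
    def test_fields_of_equal_length_pass(self):
        buf = FakeAgentBuffer()
        buf.append('rewards', 1.0)
        buf.append('actions', 0)
        self.assertTrue(buf.check_length(['rewards', 'actions']))

    def test_mismatched_lengths_fail(self):
        buf = FakeAgentBuffer()
        buf.append('rewards', 1.0)
        self.assertFalse(buf.check_length(['rewards', 'actions']))
```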
python/trainers/ppo_trainer.py
Outdated
```python
from trainers.buffer import *
from trainers.ppo_models import *
import logging
```
clean imports by PEP 8: https://www.python.org/dev/peps/pep-0008/#imports
should read:

```python
import logging
import os

import numpy as np
import tensorflow as tf

from trainers.buffer import *
from trainers.ppo_models import *
```

```python
from trainers.buffer import *
from trainers.ppo_models import *
import logging
logger = logging.getLogger("unityagents")
```
blank line between imports and first line of code
python/trainers/ppo_trainer.py
Outdated
```python



class Trainer(object):
```
class needs docstring
```python
class Trainer(object):
    def __init__(self, sess, env, brain_name, trainer_parameters, training):
        """
        Responsible for collecting experiences and training PPO model.
```
should the class be called PPOTrainer if this is specific to PPO as the docstring suggests? otherwise fix docstring
```python
self.use_states = (env.brains[brain_name].state_space_size > 0)
self.summary_path = trainer_parameters['summary_path']
if not os.path.exists(self.summary_path):
    os.makedirs(self.summary_path)
```
this can fail for permissions or filesystem issues; consider catching the exception and re-raising with more context for better troubleshooting
exception caught
python/trainers/ppo_trainer.py
Outdated
```python
get_gae(
    rewards=self.training_buffer[agent_id]['rewards'].get_batch(),
    value_estimates=self.training_buffer[agent_id]['value_estimates'].get_batch(),
    value_next=value_next, gamma=self.trainer_parameters['gamma'], lambd=self.trainer_parameters['lambd'])
```
keep lines under 100 columns for readability
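For illustration, a long keyword call can be wrapped with one argument per line inside the parentheses; `get_gae` here is a toy stand-in, not the PR's advantage computation:

```python
def get_gae(rewards, value_estimates, value_next, gamma, lambd):
    # toy stand-in returning one value per reward; signature mirrors the quoted call
    return [r + gamma * v for r, v in zip(rewards, value_estimates)]

# one keyword argument per line keeps every line well under 100 columns
advantages = get_gae(
    rewards=[1.0, 0.5],
    value_estimates=[0.2, 0.1],
    value_next=0.0,
    gamma=0.99,
    lambd=0.95,
)
```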
python/trainers/ppo_trainer.py
Outdated
```python
batch_size = self.trainer_parameters['batch_size']
total_v, total_p = 0, 0
advantages = self.training_buffer.global_buffer['advantages'].get_batch()
self.training_buffer.global_buffer['advantages'].set(
```
it's very brittle for this code to have to know so much about the internals (e.g. keys) of the dictionary used by the buffer class and specifically to have to use string literals copy+pasted throughout the code like 'num_epoch'. Instead, the buffer class should expose a clean API that allows the implementation to change w/o its client (i.e. this code) needing to change.
I don't expect that refactor to be part of this PR, but we should at least track it and add a TODO; otherwise we're just adding technical debt here
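A sketch of the kind of API this comment has in mind; the class and method names are hypothetical, not a proposal for the PR:

```python
class TrainingBuffer:
    """Hypothetical facade that hides the raw dict keys from trainer code."""
    def __init__(self):
        self._fields = {'advantages': []}

    def advantages(self):
        # clients never touch the 'advantages' string literal directly
        return list(self._fields['advantages'])

    def set_advantages(self, values):
        self._fields['advantages'] = list(values)

buf = TrainingBuffer()
buf.set_advantages([0.1, -0.2])
```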
python/trainers/ppo_trainer.py
Outdated
```python
Saves training statistics to Tensorboard.
:param lesson_number: The lesson the trainer is at.
"""
if self.get_step() % self.trainer_parameters['summary_freq'] == 0 and self.get_step() != 0 and self.is_training:
```
try to stay within 100 columns for readability
python/trainers/ppo_trainer.py
Outdated
```python
steps = self.get_step()
if len(self.stats['cumulative_reward']) > 0:
    mean_reward = np.mean(self.stats['cumulative_reward'])
    print("{0} : Step: {1}. Mean Reward: {2}. Std of Reward: {3}."
```
logger.debug instead of print. You already initialized a logger above
python/trainers/ppo_trainer.py
Outdated
```python
self.summary_writer.add_summary(summary, steps)
self.summary_writer.flush()

def write_text(self, key, input_dict):
```
consider renaming to include tf in the name given how specific this function is
…iner branch. Thanks asolano.
…rn.py script to put them in the config file.
awjuliani
left a comment
This is ready to merge into dev-0.3.
- Modifies the trainer class to be more flexible.
- Adds some TODO elements.
- Adds inline docs to the buffer.
- Enables multi brain training.