Hotfix 0.6.0a to develop #1589

Merged
merged 9 commits on Jan 11, 2019
Fixed the Windows Ctrl-C bug (#1558)
* Documentation tweaks and updates (#1479)

* Add blurb about using the --load flag in the intro guide, and typo fix.

* Add section in tutorial to create multiple area learning environment.

* Add mention of Done() method in agent design

* fixed the Windows Ctrl-C bug

* fixed typo

* removed some unnecessary printing

* nothing

* make the import of the win api conditional

* removed the duplicate code

* added the ability to use the Python debugger on ml-agents

* added a newline at the end, changed the import to use the complete path

* moved the info.log call into policy.export_model, changed the sys.platform check to use startswith

* fixed a bug

* remove the printing of the path

* tweaked the info message to notify the user about the expected error message

* removed some logging according to comments

* removed the sys import

* Revert "Documentation tweaks and updates (#1479)"

This reverts commit 84ef07a.

* resolved the model path comment
xiaomaogy authored Jan 8, 2019
commit 5510062b93b1249cace6028381ffb2fc8ff8a22a
8 changes: 6 additions & 2 deletions ml-agents/mlagents/trainers/learn.py
@@ -6,8 +6,8 @@
import numpy as np
from docopt import docopt

-from .trainer_controller import TrainerController
-from .exception import TrainerError
+from mlagents.trainers.trainer_controller import TrainerController
+from mlagents.trainers.exception import TrainerError


def run_training(sub_id, run_seed, run_options, process_queue):
@@ -117,3 +117,7 @@ def main():
# Wait for signal that environment has successfully launched
while process_queue.get() is not True:
continue
+
+# For python debugger to directly run this script
+if __name__ == "__main__":
+    main()
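
The guard above makes learn.py runnable as a plain script, which is what enables attaching the Python debugger. As a minimal sketch (not part of this diff, and assuming the mlagents package is installed so the absolute imports resolve), the same entry point can be stepped through like this:

import pdb

from mlagents.trainers.learn import main

if __name__ == "__main__":
    # Break before main() so the docopt parsing and trainer setup can be stepped through.
    pdb.set_trace()
    main()

Equivalently, once the guard is in place the file can be launched with python -m pdb ml-agents/mlagents/trainers/learn.py followed by the usual training options.
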
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/policy.py
@@ -179,6 +179,7 @@ def export_model(self):
clear_devices=True, initializer_nodes='', input_saver='',
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0')
+logger.info('Exported ' + self.model_path + '.bytes file')

def _process_graph(self):
"""
46 changes: 33 additions & 13 deletions ml-agents/mlagents/trainers/trainer_controller.py
@@ -6,6 +6,10 @@
import glob
import logging
import shutil
+import sys
+if sys.platform.startswith('win'):
+    import win32api
+    import win32con

import yaml
import re
@@ -103,6 +107,7 @@ def __init__(self, env_path, run_id, save_freq, curriculum_folder,
self.keep_checkpoints = keep_checkpoints
self.trainers = {}
self.seed = seed
+self.global_step = 0
np.random.seed(self.seed)
tf.set_random_seed(self.seed)
self.env = UnityEnvironment(file_name=env_path,
@@ -181,6 +186,23 @@ def _save_model(self,steps=0):
self.trainers[brain_name].save_model()
self.logger.info('Saved Model')

+def _save_model_when_interrupted(self, steps=0):
+    self.logger.info('Learning was interrupted. Please wait '
+                     'while the graph is generated.')
+    self._save_model(steps)
+
+def _win_handler(self, event):
+    """
+    This function gets triggered after ctrl-c or ctrl-break is pressed
+    under Windows platform.
+    """
+    if event in (win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT):
+        self._save_model_when_interrupted(self.global_step)
+        self._export_graph()
+        sys.exit()
+        return True
+    return False

def _export_graph(self):
"""
Exports latest saved models to .bytes format for Unity embedding.
@@ -288,12 +310,14 @@ def start_learning(self):
self._initialize_trainers(trainer_config)
for _, t in self.trainers.items():
self.logger.info(t)
-global_step = 0 # This is only for saving the model
curr_info = self._reset_env()
if self.train_model:
for brain_name, trainer in self.trainers.items():
trainer.write_tensorboard_text('Hyperparameters',
trainer.parameters)
+if sys.platform.startswith('win'):
+    # Add the _win_handler function to the windows console's handler function list
+    win32api.SetConsoleCtrlHandler(self._win_handler, True)
try:
while any([t.get_step <= t.get_max_steps \
for k, t in self.trainers.items()]) \
@@ -353,31 +377,27 @@ def start_learning(self):
# Write training statistics to Tensorboard.
if self.meta_curriculum is not None:
trainer.write_summary(
-global_step,
+self.global_step,
lesson_num=self.meta_curriculum
.brains_to_curriculums[brain_name]
.lesson_num)
else:
-trainer.write_summary(global_step)
+trainer.write_summary(self.global_step)
if self.train_model \
and trainer.get_step <= trainer.get_max_steps:
trainer.increment_step_and_update_last_reward()
-global_step += 1
-if global_step % self.save_freq == 0 and global_step != 0 \
+self.global_step += 1
+if self.global_step % self.save_freq == 0 and self.global_step != 0 \
and self.train_model:
# Save Tensorflow model
-self._save_model(steps=global_step)
+self._save_model(steps=self.global_step)
curr_info = new_info
# Final save Tensorflow model
-if global_step != 0 and self.train_model:
-self._save_model(steps=global_step)
+if self.global_step != 0 and self.train_model:
+self._save_model(steps=self.global_step)
except KeyboardInterrupt:
-print('--------------------------Now saving model--------------'
-'-----------')
if self.train_model:
-self.logger.info('Learning was interrupted. Please wait '
-'while the graph is generated.')
-self._save_model(steps=global_step)
+self._save_model_when_interrupted(steps=self.global_step)
pass
self.env.close()
if self.train_model:
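
For reference, the interrupt handling added in trainer_controller.py follows the standard pywin32 console-control-handler pattern: register a callback with win32api.SetConsoleCtrlHandler and save before exiting. A minimal standalone sketch of that pattern (a print stands in for _save_model_when_interrupted and _export_graph; the names here are illustrative, not the PR's code):

import sys
import time

if sys.platform.startswith('win'):
    import win32api
    import win32con

    def _on_console_event(event):
        # Windows calls this when Ctrl-C (CTRL_C_EVENT) or Ctrl-Break (CTRL_BREAK_EVENT) is pressed.
        if event in (win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT):
            print('Interrupted: saving before exit...')  # stand-in for the model save / graph export
            sys.exit()
        return False

    # Passing True appends the callback to the console's handler list,
    # mirroring the SetConsoleCtrlHandler call in start_learning above.
    win32api.SetConsoleCtrlHandler(_on_console_event, True)

while True:
    # Placeholder work loop; on non-Windows platforms Ctrl-C still raises KeyboardInterrupt.
    time.sleep(1)

On Windows the console handler and the KeyboardInterrupt branch both funnel into the same save logic, which is why _save_model_when_interrupted was factored out in this commit.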