Skip to content

Commit 48caa13

Browse files
committed
Add parallelism flag
1 parent 1aafade commit 48caa13

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

PolicyGradient/a3c/train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
tf.flags.DEFINE_integer("max_global_steps", None, "Stop after this many steps in the environment")
2828
tf.flags.DEFINE_integer("eval_every", 300, "Evaluate the policy ever [eval_every] seconds")
2929
tf.flags.DEFINE_boolean("reset", False, "If true, delete the existing model directory")
30+
tf.flags.DEFINE_integer("parallelism", None, "Number of threads to run. If not given we run [num_cpu_cores] threads.")
3031

3132
FLAGS = tf.flags.FLAGS
3233

@@ -35,6 +36,10 @@ def make_env():
3536

3637
VALID_ACTIONS = [0, 1, 2, 3]
3738
NUM_WORKERS = multiprocessing.cpu_count()
39+
40+
if FLAGS.parallelism:
41+
NUM_WORKERS = FLAGS.parallelism
42+
3843
MODEL_DIR = FLAGS.model_dir
3944
CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints")
4045

0 commit comments

Comments
 (0)