Improvements for GAIL #2296
Conversation
# Conflicts:
#   ml-agents/mlagents/trainers/models.py

…o develop-irl-ervin
# Conflicts:
#   ml-agents/mlagents/trainers/components/bc/model.py
#   ml-agents/mlagents/trainers/components/bc/module.py
#   ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
#   ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
#   ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
@@ -261,5 +292,8 @@ def create_loss(self, learning_rate: float) -> None:
            )
        else:
            self.loss = self.discriminator_loss

        self.loss = self.loss + self.gradient_penalty * self.compute_gradient_penalty()
I think we need better var names here. From the current names I would expect self.compute_gradient_penalty() to return the gradient penalty, but it returns the magnitude of the gradient.
Changed this to create_gradient_magnitude and the weight to gradient_penalty_weight.
        grad = tf.gradients(grad_estimate, [grad_input])[0]

        # Norm, like log, can return NaN. Use our own safe_norm
        safe_norm = tf.sqrt(tf.reduce_sum(grad ** 2, axis=-1) + EPSILON)
How does norm result in NaN? I could see that happening if there was overflow, but in that case adding an epsilon isn't going to help.
Not the norm itself, it's the gradient of the norm. At 0 the derivative of sqrt() is infinite (1/(2*sqrt(x)) blows up), so backprop through the norm produces NaN. I'll update the comment to reflect this.
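To make this concrete, here is a minimal standalone sketch (TF 1.x API; the tensors are illustrative, not the PR's actual graph) showing that the gradient of a plain L2 norm at exactly zero comes out as NaN, while the epsilon-stabilized version backpropagates a finite zero:

import tensorflow as tf

EPSILON = 1e-7

x = tf.constant([[0.0, 0.0]])                                  # all-zero input
plain_norm = tf.norm(x, axis=-1)                               # gradient is x / norm = 0/0 at x = 0
safe_norm = tf.sqrt(tf.reduce_sum(x ** 2, axis=-1) + EPSILON)  # gradient is x / sqrt(eps) = 0

grad_plain = tf.gradients(plain_norm, [x])[0]
grad_safe = tf.gradients(safe_norm, [x])[0]

with tf.Session() as sess:
    print(sess.run(grad_plain))  # [[nan nan]]
    print(sess.run(grad_safe))   # [[0. 0.]]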
-        self.intrinsic_reward = -tf.log(1.0 - self.discriminator_score + 1e-7)
+        self.intrinsic_reward = -tf.log(1.0 - self.discriminator_score + EPSILON)

    def compute_gradient_penalty(self) -> tf.Tensor:
What is the performance improvement from this? Faster/more stable convergence?
Faster convergence, especially for Crawler. I'm seeing about 25% fewer steps required with PPO. TBH, a large motivation for this is SAC: without GP the discriminator overfits very quickly, and GAIL + SAC doesn't work at all without it.
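For context, this is roughly the shape of the technique under discussion: a WGAN-GP-style penalty built from the discriminator's gradient magnitude on points interpolated between expert and policy batches. The names below (discriminator, expert_input, policy_input, gradient_magnitude) are illustrative placeholders, not the PR's actual attributes:

import tensorflow as tf

EPSILON = 1e-7

def gradient_magnitude(discriminator, expert_input, policy_input):
    # Blend expert and policy samples with a random interpolation factor.
    alpha = tf.random_uniform(tf.shape(expert_input), 0.0, 1.0)
    interp = alpha * expert_input + (1.0 - alpha) * policy_input

    # Gradient of the discriminator output w.r.t. the interpolated input.
    d_interp = discriminator(interp)
    grad = tf.gradients(d_interp, [interp])[0]

    # The gradient of a plain norm is NaN at 0, so use an epsilon-safe version.
    safe_norm = tf.sqrt(tf.reduce_sum(grad ** 2, axis=-1) + EPSILON)
    return tf.reduce_mean(safe_norm)

Scaled by a weight and added to the discriminator loss, this keeps the discriminator from becoming too sharp too quickly, which is what helps GAIL + SAC.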
The squiggly line does not lie.
Overall looks good. It would be nice to see some rough numbers showing better stability/performance with the gradient penalty as compared to without.
@@ -32,6 +32,9 @@ def __init__(self, policy: TFPolicy, strength: float, gamma: float):
        short_name = class_name.replace("RewardSignal", "")
        self.stat_name = f"Policy/{short_name} Reward"
        self.value_name = f"Policy/{short_name} Value Estimate"
        # Don't terminate discounted reward computation at Done. Useful for eliminating positive bias in rewards with
        # no natural end, e.g. GAIL or Curiosity
        self.ignore_terminal_states = False
nit (feel free to ignore): could you name this in terms of what it does instead of what it doesn't do? I find double-negatives like this a little harder to understand when trying to reason about the code.
It was confusing for me too while writing it - I reversed it
Cool, not sure if it's on any official naming scheme guides, but I generally try to avoid variables with negative names for this reason.
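To make the reversed flag's effect concrete, here is a minimal sketch of a discounted-return computation where a positively named flag controls whether a Done step cuts off the discounted sum. The name use_terminal_states is illustrative; the PR's actual attribute may differ:

import numpy as np

def discounted_returns(rewards, dones, gamma, value_next, use_terminal_states=True):
    # Discounted returns bootstrapped from value_next. When use_terminal_states
    # is False (e.g. for GAIL or Curiosity rewards with no natural end), a Done
    # step does not zero out the discounted tail.
    returns = np.zeros_like(rewards, dtype=np.float32)
    running = value_next
    for t in reversed(range(len(rewards))):
        if use_terminal_states and dones[t]:
            running = 0.0  # episode ends here: nothing beyond this step counts
        returns[t] = rewards[t] + gamma * running
        running = returns[t]
    return returns

# A Done in the middle only truncates the return when the flag is on.
r = np.array([1.0, 1.0, 1.0], dtype=np.float32)
d = np.array([False, True, False])
print(discounted_returns(r, d, 0.99, value_next=5.0, use_terminal_states=True))
print(discounted_returns(r, d, 0.99, value_next=5.0, use_terminal_states=False))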
Maybe don't use self.compute_gradient_penalty() if self.gradient_penalty == 0.0 (to allow turning it off if you really don't want it for some reason)?
Made it only happen if the penalty is greater than 0
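In other words, something along these lines (a fragment mirroring the diff above; the attribute names follow the renames mentioned earlier and are illustrative, not necessarily the PR's final code):

        # Only build and apply the penalty when it is actually enabled.
        if self.gradient_penalty_weight > 0.0:
            self.loss = (
                self.loss
                + self.gradient_penalty_weight * self.create_gradient_magnitude()
            )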
gail_config.yaml with GAIL examples
trainer_config.yaml and unnecessary gammas