Remove EMA model from Diffusion Policy #134
Conversation
Genius! So happy we removed EMA!
# (
#     "pusht",
#     "diffusion",
#     ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
# ),
# ("aloha", "act", ["policy.n_action_steps=10"]),
Should we remove?
I did, and updated instructions.
tests/test_policies.py (Outdated)
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], | ||
), | ||
("aloha", "act", ["policy.n_action_steps=10"]), | ||
# ("aloha", "act", ["policy.n_action_steps=10"]), |
Should we remove?
This should be uncommented actually. I reverted.
Was about to say the same, nice!
Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
weights.
Could you add a note about EMA, saying that we tested with and without it, got results that were as good or better without EMA, and so decided to remove it for the sake of simplicity?
Okay, but I added the note in the yaml config instead, as this detail is more relevant to the outer scope. PTAL.
LGTM
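For readers landing on this thread, here is a minimal sketch of the kind of exponential-moving-average weight tracking this PR removes, assuming a generic PyTorch module. The function names and the decay value are illustrative assumptions, not the actual lerobot implementation:

```python
import copy
import torch

def make_ema_copy(model: torch.nn.Module) -> torch.nn.Module:
    """Create a frozen copy of the model to hold the EMA weights."""
    ema_model = copy.deepcopy(model)
    for param in ema_model.parameters():
        param.requires_grad_(False)
    return ema_model

@torch.no_grad()
def update_ema(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999) -> None:
    """After each optimizer step, nudge the EMA weights toward the current training weights."""
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(decay).add_(param, alpha=1 - decay)
```

The docstring quoted above refers to this pattern: the EMA copy would be used for evaluation (self.training == False) while the raw weights keep training.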
""" | ||
NOTE: If this test does not pass, and you have intentionally changed something in the policy: | ||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should | ||
include a report on what changed and how that affected the outputs. | ||
2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and | ||
comment in the policies you want to update the test artifacts for. | ||
3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated. | ||
4. Check that this test now passes. | ||
5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state. | ||
6. Remember to stage and commit the resulting changes to `tests/data`. | ||
""" |
I should have done that, that's really helpful, thanks!
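For context, a rough sketch of the backwards-compatibility pattern the docstring above describes: compare the policy's current outputs against reference outputs stored as safetensors artifacts. The path, keys, and helper name below are assumptions, not the actual lerobot test code:

```python
import torch
from safetensors.torch import load_file

def assert_outputs_unchanged(policy_outputs: dict[str, torch.Tensor], artifact_path: str) -> None:
    # Load the reference tensors saved by the artifact-generation script (hypothetical path under tests/data).
    reference = load_file(artifact_path)
    for key, ref in reference.items():
        assert torch.allclose(policy_outputs[key], ref, atol=1e-6), f"Policy output '{key}' changed"
```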
What does this PR do?
As the title suggests. Also updates the test artifacts for the diffusion policy backwards-compatibility check.
Side change:
How was it tested?
I have a pretrained diffusion policy that reaches SOTA eval metrics. I evaluated it for 500 episodes with the EMA weights and with the non-EMA weights.
The mean "avg_max_reward" is higher for non-EMA (without considering error bars). For the success rate, we can take a uniform prior and compute the posterior Beta distribution to get the mean, plus upper and lower confidence bounds placed 34.1% of the posterior mass above and below the mean (analogous to a Gaussian ±1 sigma interval).
EMA results:
Mean: 0.6434262948207171
Lower: 0.6220813179507085
Upper: 0.6647718949623826
Non-EMA results:
Mean: 0.6374501992031872
Lower: 0.6160270432185413
Upper: 0.6588739530055447
The means are not significantly different.
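A sketch of the interval computation described above, assuming a uniform Beta(1, 1) prior over the per-episode success probability. The success count below is a hypothetical value for illustration, not taken from the actual eval logs:

```python
from scipy.stats import beta

n_episodes = 500
n_successes = 322  # hypothetical count, for illustration only

# Uniform prior + binomial likelihood -> posterior Beta(successes + 1, failures + 1)
posterior = beta(n_successes + 1, n_episodes - n_successes + 1)

mean = posterior.mean()
lower = posterior.ppf(posterior.cdf(mean) - 0.341)  # 34.1% of posterior mass below the mean
upper = posterior.ppf(posterior.cdf(mean) + 0.341)  # 34.1% of posterior mass above the mean

print(f"Mean: {mean:.4f}  Lower: {lower:.4f}  Upper: {upper:.4f}")
```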
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.