Complete set_attn_processor for prior and vae #3796
Conversation
The documentation is not available anymore as the PR was closed or merged.
I love the tests!
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .attention_processor import AttentionProcessor, AttnProcessor
Should we maybe also implement this automatic discoverability of AttnProcessor2_0?
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
I don't fully follow here. Note that by default ATTN 2.0 is always chosen whenever someone uses the Attention class, which is the case for both the AutoEncoder and the Prior, see:
if processor is None:
and:
self.attn1 = Attention(
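For context, a simplified sketch of the selection logic quoted above (it mirrors the expression from diffusers' Attention constructor; the standalone helper pick_default_processor is just for illustration, not an actual diffusers function):

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0


def pick_default_processor(scale_qk: bool = True):
    # PyTorch >= 2.0 exposes F.scaled_dot_product_attention, so the fused
    # AttnProcessor2_0 can be picked automatically; otherwise fall back to
    # the plain AttnProcessor.
    if hasattr(F, "scaled_dot_product_attention") and scale_qk:
        return AttnProcessor2_0()
    return AttnProcessor()
```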
@@ -182,8 +184,9 @@ def test_output_pretrained(self):
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))


-class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
+class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
Curiosity: what does NCSNpp mean?
@@ -153,8 +154,9 @@ def test_unet_1d_maestro(self):
         assert (output_max - 0.0607).abs() < 4e-4


-class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
+class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
Curiosity: why is it called UNetRLModelTests?
It tests an RL model we've integrated.
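As an aside, both hunks above follow the same pattern: shared checks live in a UNetTesterMixin and each model test class mixes it in next to ModelTesterMixin. Below is a rough, illustrative sketch of that pattern; the mixin body and DummyModelTests are hypothetical stand-ins, not the actual diffusers test code:

```python
import unittest

import torch


class UNetTesterMixin:
    # Hypothetical stand-in: the real mixin's contents are not shown in this PR excerpt.
    def test_forward_shape(self):
        # Subclasses are expected to provide self.model in setUp (assumed contract).
        sample = torch.randn(1, 3, 8, 8)
        output = self.model(sample)
        self.assertEqual(output.shape, sample.shape)


class DummyModelTests(UNetTesterMixin, unittest.TestCase):
    def setUp(self):
        # An identity "model" keeps this sketch self-contained and runnable.
        self.model = torch.nn.Identity()
```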
Exceptionally clean!
* relax tolerance slightly
* Add more tests
* upload readme
* upload readme
* Apply suggestions from code review
* Improve API Autoencoder KL
* finalize
* finalize tests
* finalize tests
* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* up

---------

Co-authored-by: Sayak Paul <[email protected]>
This PR makes sure that:
- UNet2DConditionModel
- PriorTransformer
- AutoencoderKL
all have the same features (in particular set_attn_processor). It also cleans up the tests a bit and adds a whole test suite for the prior transformer.
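A minimal usage sketch of what this unifies, assuming a diffusers version that includes this PR (the VAE now exposes the same attention-processor hooks as the UNet):

```python
from diffusers import AutoencoderKL
from diffusers.models.attention_processor import AttnProcessor

# A small, randomly initialized VAE just for illustration.
vae = AutoencoderKL()

# Inspect the current processors (a dict keyed by attention module name) ...
print(vae.attn_processors)

# ... and swap all of them for the plain AttnProcessor in one call.
vae.set_attn_processor(AttnProcessor())
```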