Add AudioLDM 2 #4549
Conversation
The documentation is not available anymore as the PR was closed or merged.
If possible, maybe consider doing a pruned version of the model, so that it can run on 6 GB of VRAM?
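On the VRAM question: the commit list below mentions "enable cpu offload", which keeps each sub-model on the CPU and moves it to the accelerator only while it runs, so only one sub-model occupies GPU memory at a time. A minimal sketch of that idea follows; the function name and the two-stage "pipeline" are hypothetical, illustrative stand-ins, not the actual diffusers implementation.

```python
import torch
import torch.nn as nn

def run_with_offload(stages, x, device="cuda"):
    """Run each stage on `device`, then move it back to CPU, so only one
    sub-model occupies accelerator memory at a time (the idea behind
    model CPU offload in diffusers pipelines)."""
    if not torch.cuda.is_available():
        device = "cpu"  # fallback so the sketch also runs without a GPU
    for stage in stages:
        stage.to(device)          # load this sub-model onto the accelerator
        x = stage(x.to(device))   # run it
        stage.to("cpu")           # evict it before the next stage loads
    return x.to("cpu")

# hypothetical two-stage "pipeline"
stages = [nn.Linear(8, 16), nn.Linear(16, 4)]
out = run_with_offload(stages, torch.randn(2, 8))
```

The trade-off is extra host-to-device transfer time per stage in exchange for a much lower peak VRAM footprint.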
It's just the slow integration tests and docs to go here! In the interest of time, would you be able to do a first pass of this @sayakpaul @williamberman to confirm that you're happy with the
```python
class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
```
I think this is good if we need to make a separate class because this is a one off pipeline. I think it's possible @patrickvonplaten might have a different opinion but I don't think it's worth blocking because we can merge into the existing unet after the fact if we need to.
So let's just move forward here and I'll file an issue asking patrick to double check he's ok with it being a separate class when he's back from vacation :)
Follow-up issue: #4658
Few small requests but looks good!
Alright, this is working nicely for the base, large, and music variants of AudioLDM 2! The TTS checkpoints require a VITS text encoder, which will be merged as part of huggingface/transformers#24085. Waiting on this before converting the TTS checkpoints. I've written the AudioLDM 2 pipeline in such a way that the TTS models should be compatible directly (possibly with some minor updates)!
Great PR and great reviews here - nice!
* from audioldm
* unet down + mid
* vae, clap, flan-t5
* start sequence audio mae
* iterate on audioldm encoder
* finish encoder
* finish weight conversion
* text pre-processing
* gpt2 pre-processing
* fix projection model
* working
* unet equivalence
* finish in base
* add unet cond
* finish unet
* finish custom unet
* start clean-up
* revert base unet changes
* refactor pre-processing
* tests: from audioldm
* fix some tests
* more fixes
* iterate on tests
* make fix copies
* harden fast tests
* slow integration tests
* finish tests
* update checkpoint
* update copyright
* docs
* remove outdated method
* add docstring
* make style
* remove decode latents
* enable cpu offload
* (text_encoder_1, tokenizer_1) -> (text_encoder, tokenizer)
* more clean up
* more refactor
* build pr docs
* Update docs/source/en/api/pipelines/audioldm2.md (Co-authored-by: Sayak Paul <[email protected]>)
* small clean
* tidy conversion
* update for large checkpoint
* generate -> generate_language_model
* full clap model
* shrink clap-audio in tests
* fix large integration test
* fix fast tests
* use generation config
* make style
* update docs
* finish docs
* finish doc
* update tests
* fix last test
* syntax
* finalise tests
* refactor projection model in prep for TTS
* fix fast tests
* style

Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
Adds AudioLDM 2 from the paper *AudioLDM 2: A General Framework for Audio, Music, and Speech Generation*.
Architecture
Steps 5-7 are the same as in AudioLDM; the remaining steps are new.
UNet
The vanilla UNet 2D cross-attention layer looks as follows:

1. Self-attention: `attn(hidden_states)`
2. Cross-attention: `attn(hidden_states, encoder_hidden_states)`

=> this is the architecture that is used in diffusers when we run the UNet forward with arguments of `(hidden_states, encoder_hidden_states)`.

AudioLDM 2 extends the vanilla UNet architecture to use an additional self-attention layer and two cross-attention layers:

1. Self-attention: `attn(hidden_states)`
2. Self-attention: `attn(hidden_states)`
3. Cross-attention: `attn(hidden_states, encoder_hidden_states_1)`
4. Cross-attention: `attn(hidden_states, encoder_hidden_states_2)`

=> here we use a different set of encoder hidden-states for the two cross-attention blocks. The first hidden-states are those obtained from the T5 model; the second are those generated from the language model. Also, we don't want to pass either of these encoder hidden-states to the self-attention layers, since the block uses double self-attention.
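The layer ordering above can be sketched as a toy PyTorch module. This is an illustrative sketch only, not the actual `AudioLDM2UNet2DConditionModel` implementation: the class name, the residual connections, and the use of `nn.MultiheadAttention` are assumptions made for demonstration.

```python
import torch
import torch.nn as nn

class AudioLDM2BlockSketch(nn.Module):
    """Toy sketch of the AudioLDM 2 attention ordering: double
    self-attention followed by two cross-attention layers, each
    conditioned on a different set of encoder hidden-states."""

    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.self_attn_1 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.self_attn_2 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn_1 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn_2 = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, hidden_states, encoder_hidden_states_1, encoder_hidden_states_2):
        # 1. + 2. double self-attention: query, key, and value all come from
        # hidden_states; the encoder hidden-states are NOT passed here
        hidden_states = hidden_states + self.self_attn_1(
            hidden_states, hidden_states, hidden_states)[0]
        hidden_states = hidden_states + self.self_attn_2(
            hidden_states, hidden_states, hidden_states)[0]
        # 3. cross-attention over the first conditioning (T5 hidden-states)
        hidden_states = hidden_states + self.cross_attn_1(
            hidden_states, encoder_hidden_states_1, encoder_hidden_states_1)[0]
        # 4. cross-attention over the second conditioning (language-model outputs)
        hidden_states = hidden_states + self.cross_attn_2(
            hidden_states, encoder_hidden_states_2, encoder_hidden_states_2)[0]
        return hidden_states
```

Note that the two conditioning sequences may have different lengths, since only the query (`hidden_states`) determines the output shape of each cross-attention.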
Checklist
- Weight conversion
- Forward pass
- Pipeline