Skip to content

Commit 6b68afd

Browse files
do not automatically enable xformers (huggingface#1640)

* do not automatically enable xformers
* uP
1 parent 63c4944 commit 6b68afd

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
1818
from diffusers.optimization import get_scheduler
1919
from diffusers.utils import check_min_version
20+
from diffusers.utils.import_utils import is_xformers_available
2021
from huggingface_hub import HfFolder, Repository, whoami
2122
from PIL import Image
2223
from torchvision import transforms
@@ -488,6 +489,15 @@ def main(args):
488489
revision=args.revision,
489490
)
490491

492+
if is_xformers_available():
493+
try:
494+
unet.enable_xformers_memory_efficient_attention(True)
495+
except Exception as e:
496+
logger.warning(
497+
"Could not enable memory efficient attention. Make sure xformers is installed"
498+
f" correctly and a GPU is available: {e}"
499+
)
500+
491501
vae.requires_grad_(False)
492502
if not args.train_text_encoder:
493503
text_encoder.requires_grad_(False)

examples/text_to_image/train_text_to_image.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
1919
from diffusers.optimization import get_scheduler
2020
from diffusers.utils import check_min_version
21+
from diffusers.utils.import_utils import is_xformers_available
2122
from huggingface_hub import HfFolder, Repository, whoami
2223
from torchvision import transforms
2324
from tqdm.auto import tqdm
@@ -364,6 +365,15 @@ def main():
364365
revision=args.revision,
365366
)
366367

368+
if is_xformers_available():
369+
try:
370+
unet.enable_xformers_memory_efficient_attention(True)
371+
except Exception as e:
372+
logger.warning(
373+
"Could not enable memory efficient attention. Make sure xformers is installed"
374+
f" correctly and a GPU is available: {e}"
375+
)
376+
367377
# Freeze vae and text_encoder
368378
vae.requires_grad_(False)
369379
text_encoder.requires_grad_(False)

examples/textual_inversion/textual_inversion.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from diffusers.optimization import get_scheduler
2121
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
2222
from diffusers.utils import check_min_version
23+
from diffusers.utils.import_utils import is_xformers_available
2324
from huggingface_hub import HfFolder, Repository, whoami
2425

2526
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -439,6 +440,15 @@ def main():
439440
revision=args.revision,
440441
)
441442

443+
if is_xformers_available():
444+
try:
445+
unet.enable_xformers_memory_efficient_attention(True)
446+
except Exception as e:
447+
logger.warning(
448+
"Could not enable memory efficient attention. Make sure xformers is installed"
449+
f" correctly and a GPU is available: {e}"
450+
)
451+
442452
# Resize the token embeddings as we are adding new special tokens to the tokenizer
443453
text_encoder.resize_token_embeddings(len(tokenizer))
444454

src/diffusers/models/attention.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import math
15-
import warnings
1615
from dataclasses import dataclass
1716
from typing import Optional
1817

@@ -447,16 +446,6 @@ def __init__(
447446
# 3. Feed-forward
448447
self.norm3 = nn.LayerNorm(dim)
449448

450-
# if xformers is installed try to use memory_efficient_attention by default
451-
if is_xformers_available():
452-
try:
453-
self.set_use_memory_efficient_attention_xformers(True)
454-
except Exception as e:
455-
warnings.warn(
456-
"Could not enable memory efficient attention. Make sure xformers is installed"
457-
f" correctly and a GPU is available: {e}"
458-
)
459-
460449
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
461450
if not is_xformers_available():
462451
print("Here is how to install it")

0 commit comments

Comments (0)