Skip to content

Commit ebce6eb

Browse files
committed
added flag and logic to disable flash_spd in the torch cuda backend
1 parent 3e8b632 commit ebce6eb

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,11 @@ def parse_args(input_args=None):
577577
choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
578578
help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
579579
)
580+
parser.add_argument(
581+
"--disable_flash_sdp",
582+
action="store_true",
583+
required=False, help="Set to disable flash sdp in torch's cuda backend"
584+
)
580585

581586
if input_args is not None:
582587
args = parser.parse_args(input_args)
@@ -602,6 +607,9 @@ def parse_args(input_args=None):
602607
if args.train_text_encoder and args.pre_compute_text_embeddings:
603608
raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
604609

610+
if args.disable_flash_sdp:
611+
torch.backends.cuda.enable_flash_sdp(False)
612+
605613
return args
606614

607615

0 commit comments

Comments
 (0)