File tree Expand file tree Collapse file tree 1 file changed +8
-0
lines changed Expand file tree Collapse file tree 1 file changed +8
-0
lines changed Original file line number Diff line number Diff line change @@ -577,6 +577,11 @@ def parse_args(input_args=None):
577577 choices = ["DPMSolverMultistepScheduler" , "DDPMScheduler" ],
578578 help = "Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF." ,
579579 )
580+ parser .add_argument (
581+ "--disable_flash_sdp" ,
582+ action = "store_true" ,
583+ required = False , help = "Set to disable flash sdp in torch's cuda backend"
584+ )
580585
581586 if input_args is not None :
582587 args = parser .parse_args (input_args )
@@ -602,6 +607,9 @@ def parse_args(input_args=None):
602607 if args .train_text_encoder and args .pre_compute_text_embeddings :
603608 raise ValueError ("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`" )
604609
610+ if args .disable_flash_sdp :
611+ torch .backends .cuda .enable_flash_sdp (False )
612+
605613 return args
606614
607615
You can’t perform that action at this time.
0 commit comments