@@ -178,16 +178,16 @@ def __init__(
178178 def enable_xformers_memory_efficient_attention (self , attention_op : Optional [Callable ] = None ):
179179 self .decoder_pipe .enable_xformers_memory_efficient_attention (attention_op )
180180
181- def enable_sequential_cpu_offload (self , gpu_id = 0 ):
181+ def enable_sequential_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [ torch . device , str ] = "cuda" ):
182182 r"""
183183 Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
184184 text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
185185 `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
186186 Note that offloading happens on a submodule basis. Memory savings are higher than with
187187 `enable_model_cpu_offload`, but performance is lower.
188188 """
189- self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
190- self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
189+ self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
190+ self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
191191
192192 def progress_bar (self , iterable = None , total = None ):
193193 self .prior_pipe .progress_bar (iterable = iterable , total = total )
@@ -405,26 +405,26 @@ def __init__(
405405 def enable_xformers_memory_efficient_attention (self , attention_op : Optional [Callable ] = None ):
406406 self .decoder_pipe .enable_xformers_memory_efficient_attention (attention_op )
407407
408- def enable_model_cpu_offload (self , gpu_id = 0 ):
408+ def enable_model_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [ torch . device , str ] = "cuda" ):
409409 r"""
410410 Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
411411 to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
412412 method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
413413 `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
414414 """
415- self .prior_pipe .enable_model_cpu_offload ()
416- self .decoder_pipe .enable_model_cpu_offload ()
415+ self .prior_pipe .enable_model_cpu_offload (gpu_id = gpu_id , device = device )
416+ self .decoder_pipe .enable_model_cpu_offload (gpu_id = gpu_id , device = device )
417417
418- def enable_sequential_cpu_offload (self , gpu_id = 0 ):
418+ def enable_sequential_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [ torch . device , str ] = "cuda" ):
419419 r"""
420420 Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
421421 text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
422422 `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
423423 Note that offloading happens on a submodule basis. Memory savings are higher than with
424424 `enable_model_cpu_offload`, but performance is lower.
425425 """
426- self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
427- self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
426+ self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
427+ self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
428428
429429 def progress_bar (self , iterable = None , total = None ):
430430 self .prior_pipe .progress_bar (iterable = iterable , total = total )
@@ -653,16 +653,16 @@ def __init__(
653653 def enable_xformers_memory_efficient_attention (self , attention_op : Optional [Callable ] = None ):
654654 self .decoder_pipe .enable_xformers_memory_efficient_attention (attention_op )
655655
656- def enable_sequential_cpu_offload (self , gpu_id = 0 ):
656+ def enable_sequential_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [ torch . device , str ] = "cuda" ):
657657 r"""
658658 Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
659659 text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
660660 `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
661661 Note that offloading happens on a submodule basis. Memory savings are higher than with
662662 `enable_model_cpu_offload`, but performance is lower.
663663 """
664- self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
665- self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
664+ self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
665+ self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
666666
667667 def progress_bar (self , iterable = None , total = None ):
668668 self .prior_pipe .progress_bar (iterable = iterable , total = total )
0 commit comments