@@ -507,7 +507,7 @@ def register_modules(self, **kwargs):
             setattr(self, name, module)
 
     def __setattr__(self, name: str, value: Any):
-        if hasattr(self, name) and hasattr(self.config, name):
+        if name in self.__dict__ and hasattr(self.config, name):
             # We need to overwrite the config if name exists in config
             if isinstance(getattr(self.config, name), (tuple, list)):
                 if value is not None and self.config[name][0] is not None:
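A note on the `__setattr__` change above: `hasattr(self, name)` is also true for class-level attributes and properties (and it actually invokes property getters), whereas `name in self.__dict__` only matches attributes that have already been stored on the instance itself, so the config-sync branch no longer fires for names the instance never set. A minimal sketch of the difference, using a plain dict in place of the pipeline's config object; the class and attribute names below are illustrative, not the diffusers ones:

class Pipe:
    # class-level default: reachable through attribute lookup on every instance,
    # but never present in the instance __dict__ until assigned directly
    progress_bar_config = {}

    def __init__(self):
        self.config = {"width": 512}

    def __setattr__(self, name, value):
        # old check: hasattr(self, name) is True for class attributes and properties,
        # so an assignment like `self.progress_bar_config = ...` would also hit this branch
        # new check: only attributes already stored on the instance are mirrored
        if name in self.__dict__ and name in self.config:
            self.config[name] = value
        object.__setattr__(self, name, value)


p = Pipe()
print(hasattr(p, "progress_bar_config"))    # True  -- the class attribute is visible
print("progress_bar_config" in p.__dict__)  # False -- nothing set on the instance yet
p.width = 1024                              # "width" was never set on the instance either,
print(p.config["width"])                    # so the config stays at 512 on first assignment
p.width = 768                               # now "width" is in __dict__, so the config is synced
print(p.config["width"])                    # 768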
@@ -635,26 +635,25 @@ def module_is_offloaded(module):
             )
 
         module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]
 
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                module.to(torch_device, torch_dtype)
-                if (
-                    module.dtype == torch.float16
-                    and str(torch_device) in ["cpu"]
-                    and not silence_dtype_warnings
-                    and not is_offloaded
-                ):
-                    logger.warning(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
-                        " is not recommended to move them to `cpu` as running them will fail. Please make"
-                        " sure to use an accelerator to run the pipeline in inference, due to the lack of"
-                        " support for`float16` operations on this device in PyTorch. Please, remove the"
-                        " `torch_dtype=torch.float16` argument, or use another device for inference."
-                    )
+        for module in modules:
+            module.to(torch_device, torch_dtype)
+            if (
+                module.dtype == torch.float16
+                and str(torch_device) in ["cpu"]
+                and not silence_dtype_warnings
+                and not is_offloaded
+            ):
+                logger.warning(
+                    "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
+                    " is not recommended to move them to `cpu` as running them will fail. Please make"
+                    " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+                    " support for`float16` operations on this device in PyTorch. Please, remove the"
+                    " `torch_dtype=torch.float16` argument, or use another device for inference."
+                )
         return self
 
     @property
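This hunk and the ones that follow all switch to the same two-step collection: build the candidate list with `getattr(self, n, None)` so that a name from the signature with no corresponding attribute yields `None` instead of raising an `AttributeError`, then keep only the entries that really are `torch.nn.Module` instances, dropping tokenizers, schedulers and `None` placeholders. A small standalone sketch of that pattern; the helper name and the commented usage lines are hypothetical, not part of the diff:

import torch

def collect_nn_modules(pipeline, module_names):
    """Hypothetical helper mirroring the collection pattern in the diff above."""
    # getattr with a None default: a name from the __init__ signature that was
    # never set on the pipeline simply yields None instead of raising
    candidates = [getattr(pipeline, name, None) for name in module_names]
    # keep only real torch modules; tokenizers, schedulers and None entries are dropped
    return [m for m in candidates if isinstance(m, torch.nn.Module)]

# hypothetical usage against a loaded pipeline object:
# module_names, _ = pipe._get_signature_keys(pipe)
# for module in collect_nn_modules(pipe, module_names):
#     module.to("cuda", torch.float16)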
@@ -664,12 +663,12 @@ def device(self) -> torch.device:
             `torch.device`: The torch device on which the pipeline is located.
         """
         module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+        for module in modules:
+            return module.device
 
-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                return module.device
         return torch.device("cpu")
 
     @classmethod
@@ -1438,13 +1437,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             for child in module.children():
                 fn_recursive_set_mem_eff(child)
 
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        module_names = [m for m in module_names if hasattr(self, m)]
+        module_names, _ = self._get_signature_keys(self)
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]
 
-        for module_name in module_names:
-            module = getattr(self, module_name)
-            if isinstance(module, torch.nn.Module):
-                fn_recursive_set_mem_eff(module)
+        for module in modules:
+            fn_recursive_set_mem_eff(module)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
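The last two hunks also change where the component names come from: `extract_init_dict(dict(self.config))` returned three values derived from the stored config, while `_get_signature_keys(self)` returns two, which is why the unpacking shrinks from `_, _` to a single `_`. As the name suggests, the new helper appears to derive the names from the pipeline's `__init__` signature; a hedged sketch of what such a lookup can look like (an illustration, not the actual diffusers helper):

import inspect

def signature_keys(obj):
    # inspect the class __init__ rather than the saved config
    params = inspect.signature(obj.__class__.__init__).parameters
    required = {name for name, p in params.items() if p.default is inspect.Parameter.empty} - {"self"}
    optional = {name for name, p in params.items() if p.default is not inspect.Parameter.empty}
    return required, optional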
@@ -1471,10 +1469,9 @@ def disable_attention_slicing(self):
         self.enable_attention_slicing(None)
 
     def set_attention_slice(self, slice_size: Optional[int]):
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        module_names = [m for m in module_names if hasattr(self, m)]
+        module_names, _ = self._get_signature_keys(self)
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]
 
-        for module_name in module_names:
-            module = getattr(self, module_name)
-            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
-                module.set_attention_slice(slice_size)
+        for module in modules:
+            module.set_attention_slice(slice_size)