@@ -166,6 +166,7 @@ def __init__(
         self._chunk_size = None
         self._chunk_dim = 0
 
+    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         # Sets chunk feed-forward
         self._chunk_size = chunk_size
@@ -529,3 +530,45 @@ def forward(
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
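For context, here is a minimal sketch of how a transformer block typically consumes the `_chunk_size`/`_chunk_dim` stored by `set_chunk_feed_forward`. The names mirror the `_chunked_feed_forward` helper in `diffusers.models.attention`; treat the exact signature as an assumption, since that code is not part of this diff:

```python
import torch


def _chunked_feed_forward(ff: torch.nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # Run the feed-forward over slices of `hidden_states` along `chunk_dim` and
    # concatenate the results, trading a longer Python loop for lower peak memory.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"Chunk dim size {hidden_states.shape[chunk_dim]} must be divisible by chunk size {chunk_size}."
        )
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
```

Because the feed-forward acts on each position independently, the chunked computation is mathematically equivalent to the unchunked one; only peak activation memory changes.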
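And a hypothetical usage sketch for the new methods. `model` stands in for an instance of the transformer class this diff patches (the class name is not visible in the hunks shown):

```python
# Chunk the feed-forward over the sequence dimension to lower peak memory.
model.enable_forward_chunking(chunk_size=1, dim=1)
# ... run inference as usual ...
model.disable_forward_chunking()  # restore the default, unchunked behavior
```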