@@ -402,41 +402,18 @@ def forward(
 
         # 1. Input
         if self.is_input_continuous:
-            batch, _, height, width = hidden_states.shape
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
-
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
             height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.pos_embed(hidden_states)
-
-            if self.adaln_single is not None:
-                if self.use_additional_conditions and added_cond_kwargs is None:
-                    raise ValueError(
-                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                    )
-                batch_size = hidden_states.shape[0]
-                timestep, embedded_timestep = self.adaln_single(
-                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-                )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )
 
         # 2. Blocks
-        if self.is_input_patches and self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
 
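For context (not part of this diff): the `def custom_forward(*inputs):` in the next hunk header comes from the gradient-checkpointing branch elided between the two hunks, which wraps each transformer block in a closure passed to `torch.utils.checkpoint.checkpoint` so activations are recomputed during the backward pass instead of stored. A minimal sketch of that pattern, with hypothetical names, assuming a recent PyTorch release:

import torch

def run_block_with_checkpointing(block, hidden_states, *block_args):
    # Closure so the checkpoint API sees a plain callable over tensor inputs.
    def custom_forward(*inputs):
        return block(*inputs)

    # Recompute the block's activations during backward; use_reentrant=False
    # selects the non-reentrant checkpoint variant in recent PyTorch versions.
    return torch.utils.checkpoint.checkpoint(
        custom_forward, hidden_states, *block_args, use_reentrant=False
    )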
@@ -474,51 +451,116 @@ def custom_forward(*inputs):
 
         # 3. Output
         if self.is_input_continuous:
-            if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
-            else:
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
-            output = hidden_states + residual
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
         elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
+
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        return hidden_states, inner_dim
 
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
 
-        if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                 )
-                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-                hidden_states = self.proj_out_2(hidden_states)
-            elif self.config.norm_type == "ada_norm_single":
-                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states)
-                # Modulation
-                hidden_states = hidden_states * (1 + scale) + shift
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.squeeze(1)
-
-            # unpatchify
-            if self.adaln_single is None:
-                height = width = int(hidden_states.shape[1] ** 0.5)
-            hidden_states = hidden_states.reshape(
-                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
             )
-            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
-            output = hidden_states.reshape(
-                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             )
 
-        if not return_dict:
-            return (output,)
+        output = hidden_states + residual
+        return output
 
-        return Transformer2DModelOutput(sample=output)
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+        elif self.config.norm_type == "ada_norm_single":
+            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states)
+            # Modulation
+            hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+        return output
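The public `forward` contract is unchanged by this refactor; the new private helpers only reorganize the existing continuous, vectorized, and patched branches. A minimal usage sketch of the continuous-input path, assuming the standard `Transformer2DModel` constructor and forward signature from diffusers (the concrete sizes below are illustrative, not taken from this diff):

import torch
from diffusers import Transformer2DModel

# Setting in_channels without a patch_size selects the is_input_continuous branch.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=16,  # inner_dim = num_attention_heads * attention_head_dim = 32
    in_channels=32,         # must be divisible by norm_num_groups (default 32)
    num_layers=1,
)

sample = torch.randn(1, 32, 16, 16)  # (batch, channels, height, width)

with torch.no_grad():
    out = model(sample)                           # Transformer2DModelOutput
    out_tuple = model(sample, return_dict=False)  # plain tuple: (output,)

# The residual connection in _get_output_for_continuous_inputs preserves the input shape.
print(out.sample.shape)    # torch.Size([1, 32, 16, 16])
print(out_tuple[0].shape)  # torch.Size([1, 32, 16, 16])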