2424from ...callbacks import MultiPipelineCallbacks , PipelineCallback
2525from ...image_processor import PixArtImageProcessor
2626from ...loaders import SanaLoraLoaderMixin
27- from ...models import AutoencoderDC , SanaTransformer2DModel
27+ from ...models import AutoencoderDC , AutoencoderKL , SanaTransformer2DModel
2828from ...schedulers import DPMSolverMultistepScheduler
2929from ...utils import (
3030 BACKENDS_MAPPING ,
@@ -150,7 +150,7 @@ def __init__(
150150 self ,
151151 tokenizer : AutoTokenizer ,
152152 text_encoder : AutoModelForCausalLM ,
153- vae : AutoencoderDC ,
153+ vae : Any [ AutoencoderDC , AutoencoderKL ] ,
154154 transformer : SanaTransformer2DModel ,
155155 scheduler : DPMSolverMultistepScheduler ,
156156 ):
@@ -162,8 +162,8 @@ def __init__(
162162
163163 self .vae_scale_factor = (
164164 2 ** (len (self .vae .config .encoder_block_out_channels ) - 1 )
165- if hasattr (self , "vae" ) and self .vae is not None
166- else 32
165+ if hasattr (self , "vae" ) and type ( self .vae ) is AutoencoderDC
166+ else 8
167167 )
168168 self .image_processor = PixArtImageProcessor (vae_scale_factor = self .vae_scale_factor )
169169
@@ -233,7 +233,7 @@ def encode_prompt(
233233
234234 self .tokenizer .padding_side = "right"
235235
236- # See Section 3.1. of the paper.
236+ # See Section 3.1. of the paper. (TODO: confirm this section reference)
237237 max_length = max_sequence_length
238238 select_index = [0 ] + list (range (- max_length + 1 , 0 ))
239239
0 commit comments