Skip to content

Commit b4816e3

Browse files
committed
Fix vae_scale_factor error
1 parent cd991d1 commit b4816e3

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2525
from ...image_processor import PixArtImageProcessor
2626
from ...loaders import SanaLoraLoaderMixin
27-
from ...models import AutoencoderDC, SanaTransformer2DModel
27+
from ...models import AutoencoderDC, AutoencoderKL, SanaTransformer2DModel
2828
from ...schedulers import DPMSolverMultistepScheduler
2929
from ...utils import (
3030
BACKENDS_MAPPING,
@@ -150,7 +150,7 @@ def __init__(
150150
self,
151151
tokenizer: AutoTokenizer,
152152
text_encoder: AutoModelForCausalLM,
153-
vae: AutoencoderDC,
153+
vae: Any[AutoencoderDC,AutoencoderKL],
154154
transformer: SanaTransformer2DModel,
155155
scheduler: DPMSolverMultistepScheduler,
156156
):
@@ -162,8 +162,8 @@ def __init__(
162162

163163
self.vae_scale_factor = (
164164
2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
165-
if hasattr(self, "vae") and self.vae is not None
166-
else 32
165+
if hasattr(self, "vae") and type(self.vae) is AutoencoderDC
166+
else 8
167167
)
168168
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
169169

@@ -233,7 +233,7 @@ def encode_prompt(
233233

234234
self.tokenizer.padding_side = "right"
235235

236-
# See Section 3.1. of the paper.
236+
# See Section 3.1. of the paper. (???)
237237
max_length = max_sequence_length
238238
select_index = [0] + list(range(-max_length + 1, 0))
239239

0 commit comments

Comments
 (0)