File tree Expand file tree Collapse file tree 1 file changed +3
-1
lines changed
src/diffusers/pipelines/stable_cascade Expand file tree Collapse file tree 1 file changed +3
-1
lines changed Original file line number Diff line number Diff line change 1919
2020from ...models import StableCascadeUNet
2121from ...schedulers import DDPMWuerstchenScheduler
22- from ...utils import logging , replace_example_docstring
22+ from ...utils import is_torch_version , logging , replace_example_docstring
2323from ...utils .torch_utils import randn_tensor
2424from ..pipeline_utils import DiffusionPipeline , ImagePipelineOutput
2525from ..wuerstchen .modeling_paella_vq_model import PaellaVQModel
@@ -361,6 +361,8 @@ def __call__(
361361 device = self ._execution_device
362362 dtype = self .decoder .dtype
363363 self ._guidance_scale = guidance_scale
364+ if is_torch_version ("<" , "2.2.0" ) and dtype == torch .bfloat16 :
365+ raise ValueError ("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype." )
364366
365367 # 1. Check inputs. Raise error if not correct
366368 self .check_inputs (
You can’t perform that action at this time.
0 commit comments