@@ -68,7 +68,7 @@ def calculate_shift(
68
68
return mu
69
69
70
70
71
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion .retrieve_timesteps
71
+ # Copied from diffusers.pipelines.cogview4.pipeline_cogview4 .retrieve_timesteps
72
72
def retrieve_timesteps (
73
73
scheduler ,
74
74
num_inference_steps : Optional [int ] = None ,
@@ -100,10 +100,19 @@ def retrieve_timesteps(
100
100
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
101
101
second element is the number of inference steps.
102
102
"""
103
+ accepts_timesteps = "timesteps" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
104
+ accepts_sigmas = "sigmas" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
105
+
103
106
if timesteps is not None and sigmas is not None :
104
- raise ValueError ("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" )
105
- if timesteps is not None :
106
- accepts_timesteps = "timesteps" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
107
+ if not accepts_timesteps and not accepts_sigmas :
108
+ raise ValueError (
109
+ f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
110
+ f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
111
+ )
112
+ scheduler .set_timesteps (timesteps = timesteps , sigmas = sigmas , device = device , ** kwargs )
113
+ timesteps = scheduler .timesteps
114
+ num_inference_steps = len (timesteps )
115
+ elif timesteps is not None and sigmas is None :
107
116
if not accepts_timesteps :
108
117
raise ValueError (
109
118
f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
@@ -112,9 +121,8 @@ def retrieve_timesteps(
112
121
scheduler .set_timesteps (timesteps = timesteps , device = device , ** kwargs )
113
122
timesteps = scheduler .timesteps
114
123
num_inference_steps = len (timesteps )
115
- elif sigmas is not None :
116
- accept_sigmas = "sigmas" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
117
- if not accept_sigmas :
124
+ elif timesteps is None and sigmas is not None :
125
+ if not accepts_sigmas :
118
126
raise ValueError (
119
127
f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
120
128
f" sigmas schedules. Please check whether you are using the correct scheduler."
@@ -515,8 +523,8 @@ def __call__(
515
523
The output format of the generate image. Choose between
516
524
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
517
525
return_dict (`bool`, *optional*, defaults to `True`):
518
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput `] instead
519
- of a plain tuple.
526
+ Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput `] instead of a plain
527
+ tuple.
520
528
attention_kwargs (`dict`, *optional*):
521
529
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
522
530
`self.processor` in
@@ -532,7 +540,6 @@ def __call__(
532
540
`._callback_tensor_inputs` attribute of your pipeline class.
533
541
max_sequence_length (`int`, defaults to `224`):
534
542
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
535
-
536
543
Examples:
537
544
538
545
Returns:
0 commit comments