@@ -127,6 +127,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
             An offset added to the inference steps, as required by some model families.
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
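A quick sketch of how the new option is chosen at construction time (assuming this patch is applied; both string values come from the docstring above):

```python
from diffusers import UniPCMultistepScheduler

# Default: force the trailing sigma to 0 so the last step lands on the clean sample.
scheduler = UniPCMultistepScheduler(final_sigmas_type="zero")

# Alternative: stop at the smallest sigma of the training schedule instead.
scheduler = UniPCMultistepScheduler(final_sigmas_type="sigma_min")
```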
@@ -153,6 +156,7 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
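For an existing pipeline, the option can also be overridden when the scheduler is rebuilt from a serialized config; a sketch, where `pipe` stands in for any already-loaded `DiffusionPipeline`:

```python
from diffusers import UniPCMultistepScheduler

# `pipe` is assumed to be an already-loaded DiffusionPipeline.
# Keyword arguments passed to `from_config` override the stored config values,
# so only `final_sigmas_type` changes; the rest of the schedule is preserved.
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, final_sigmas_type="zero"
)
```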
@@ -265,10 +269,25 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

         self.sigmas = torch.from_numpy(sigmas)
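The `sigma_last` arithmetic in the non-Karras branch can be sanity-checked in isolation; a self-contained sketch with a made-up linear beta schedule standing in for the scheduler's trained one:

```python
import numpy as np

# Hypothetical 1000-step linear beta schedule (not the scheduler's actual one).
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

for final_sigmas_type in ("sigma_min", "zero"):
    if final_sigmas_type == "sigma_min":
        # Smallest training-schedule sigma: sqrt((1 - alpha_bar_0) / alpha_bar_0).
        sigma_last = ((1 - alphas_cumprod[0]) / alphas_cumprod[0]) ** 0.5
    else:
        sigma_last = 0
    print(final_sigmas_type, sigma_last)  # ~0.01 for "sigma_min", 0 for "zero"
```

With `"zero"` the final UniPC step integrates all the way to the fully denoised sample, while `"sigma_min"` keeps the previous behavior of stopping at the smallest training sigma.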