@@ -269,3 +269,113 @@ def test_full_loop_with_noise(self):
269269
270270 assert abs (result_sum .item () - 315.5757 ) < 1e-2 , f" expected result sum 315.5757, but get { result_sum } "
271271 assert abs (result_mean .item () - 0.4109 ) < 1e-3 , f" expected result mean 0.4109, but get { result_mean } "
272+
273+
274+ class UniPCMultistepScheduler1DTest (UniPCMultistepSchedulerTest ):
275+ @property
276+ def dummy_sample (self ):
277+ batch_size = 4
278+ num_channels = 3
279+ width = 8
280+
281+ sample = torch .rand ((batch_size , num_channels , width ))
282+
283+ return sample
284+
285+ @property
286+ def dummy_noise_deter (self ):
287+ batch_size = 4
288+ num_channels = 3
289+ width = 8
290+
291+ num_elems = batch_size * num_channels * width
292+ sample = torch .arange (num_elems ).flip (- 1 )
293+ sample = sample .reshape (num_channels , width , batch_size )
294+ sample = sample / num_elems
295+ sample = sample .permute (2 , 0 , 1 )
296+
297+ return sample
298+
299+ @property
300+ def dummy_sample_deter (self ):
301+ batch_size = 4
302+ num_channels = 3
303+ width = 8
304+
305+ num_elems = batch_size * num_channels * width
306+ sample = torch .arange (num_elems )
307+ sample = sample .reshape (num_channels , width , batch_size )
308+ sample = sample / num_elems
309+ sample = sample .permute (2 , 0 , 1 )
310+
311+ return sample
312+
313+ def test_switch (self ):
314+ # make sure that iterating over schedulers with same config names gives same results
315+ # for defaults
316+ scheduler = UniPCMultistepScheduler (** self .get_scheduler_config ())
317+ sample = self .full_loop (scheduler = scheduler )
318+ result_mean = torch .mean (torch .abs (sample ))
319+
320+ assert abs (result_mean .item () - 0.2441 ) < 1e-3
321+
322+ scheduler = DPMSolverSinglestepScheduler .from_config (scheduler .config )
323+ scheduler = DEISMultistepScheduler .from_config (scheduler .config )
324+ scheduler = DPMSolverMultistepScheduler .from_config (scheduler .config )
325+ scheduler = UniPCMultistepScheduler .from_config (scheduler .config )
326+
327+ sample = self .full_loop (scheduler = scheduler )
328+ result_mean = torch .mean (torch .abs (sample ))
329+
330+ assert abs (result_mean .item () - 0.2441 ) < 1e-3
331+
332+ def test_full_loop_no_noise (self ):
333+ sample = self .full_loop ()
334+ result_mean = torch .mean (torch .abs (sample ))
335+
336+ assert abs (result_mean .item () - 0.2441 ) < 1e-3
337+
338+ def test_full_loop_with_karras (self ):
339+ sample = self .full_loop (use_karras_sigmas = True )
340+ result_mean = torch .mean (torch .abs (sample ))
341+
342+ assert abs (result_mean .item () - 0.2898 ) < 1e-3
343+
344+ def test_full_loop_with_v_prediction (self ):
345+ sample = self .full_loop (prediction_type = "v_prediction" )
346+ result_mean = torch .mean (torch .abs (sample ))
347+
348+ assert abs (result_mean .item () - 0.1014 ) < 1e-3
349+
350+ def test_full_loop_with_karras_and_v_prediction (self ):
351+ sample = self .full_loop (prediction_type = "v_prediction" , use_karras_sigmas = True )
352+ result_mean = torch .mean (torch .abs (sample ))
353+
354+ assert abs (result_mean .item () - 0.1944 ) < 1e-3
355+
356+ def test_full_loop_with_noise (self ):
357+ scheduler_class = self .scheduler_classes [0 ]
358+ scheduler_config = self .get_scheduler_config ()
359+ scheduler = scheduler_class (** scheduler_config )
360+
361+ num_inference_steps = 10
362+ t_start = 8
363+
364+ model = self .dummy_model ()
365+ sample = self .dummy_sample_deter
366+ scheduler .set_timesteps (num_inference_steps )
367+
368+ # add noise
369+ noise = self .dummy_noise_deter
370+ timesteps = scheduler .timesteps [t_start * scheduler .order :]
371+ sample = scheduler .add_noise (sample , noise , timesteps [:1 ])
372+
373+ for i , t in enumerate (timesteps ):
374+ residual = model (sample , t )
375+ sample = scheduler .step (residual , t , sample ).prev_sample
376+
377+ result_sum = torch .sum (torch .abs (sample ))
378+ result_mean = torch .mean (torch .abs (sample ))
379+
380+ assert abs (result_sum .item () - 39.0870 ) < 1e-2 , f" expected result sum 39.0870, but get { result_sum } "
381+ assert abs (result_mean .item () - 0.4072 ) < 1e-3 , f" expected result mean 0.4072, but get { result_mean } "
0 commit comments