@@ -115,6 +115,45 @@ def test_full_loop_no_noise_multistep(self):
115115 assert abs (result_sum .item () - 347.6357 ) < 1e-2
116116 assert abs (result_mean .item () - 0.4527 ) < 1e-3
117117
118+ def test_full_loop_with_noise (self ):
119+ scheduler_class = self .scheduler_classes [0 ]
120+ scheduler_config = self .get_scheduler_config ()
121+ scheduler = scheduler_class (** scheduler_config )
122+
123+ num_inference_steps = 10
124+ t_start = 8
125+
126+ scheduler .set_timesteps (num_inference_steps )
127+ timesteps = scheduler .timesteps
128+
129+ generator = torch .manual_seed (0 )
130+
131+ model = self .dummy_model ()
132+ sample = self .dummy_sample_deter * scheduler .init_noise_sigma
133+
134+ noise = self .dummy_noise_deter
135+ timesteps = scheduler .timesteps [t_start * scheduler .order :]
136+
137+ sample = scheduler .add_noise (sample , noise , timesteps [:1 ])
138+
139+ for t in timesteps :
140+ # 1. scale model input
141+ scaled_sample = scheduler .scale_model_input (sample , t )
142+
143+ # 2. predict noise residual
144+ residual = model (scaled_sample , t )
145+
146+ # 3. predict previous sample x_t-1
147+ pred_prev_sample = scheduler .step (residual , t , sample , generator = generator ).prev_sample
148+
149+ sample = pred_prev_sample
150+
151+ result_sum = torch .sum (torch .abs (sample ))
152+ result_mean = torch .mean (torch .abs (sample ))
153+
154+ assert abs (result_sum .item () - 763.9186 ) < 1e-2 , f" expected result sum 763.9186, but get { result_sum } "
155+ assert abs (result_mean .item () - 0.9947 ) < 1e-3 , f" expected result mean 0.9947, but get { result_mean } "
156+
118157 def test_custom_timesteps_increasing_order (self ):
119158 scheduler_class = self .scheduler_classes [0 ]
120159 scheduler_config = self .get_scheduler_config ()
0 commit comments