Skip to content

Commit 351b37e

Browse files
authored
Fix UniPC tests and remove some test warnings (huggingface#2396)
* Change solver_type to match the previous tests. * Prevent warnings about scale_model_inputs * Prevent console log about division by zero.
1 parent 2e0d489 commit 351b37e

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tests/test_scheduler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ def check_over_configs(self, time_step=0, **config):
287287
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
288288
kwargs["num_inference_steps"] = num_inference_steps
289289

290+
# Make sure `scale_model_input` is invoked to prevent a warning
291+
if scheduler_class != VQDiffusionScheduler:
292+
_ = scheduler.scale_model_input(sample, 0)
293+
_ = new_scheduler.scale_model_input(sample, 0)
294+
290295
# Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
291296
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
292297
kwargs["generator"] = torch.manual_seed(0)
@@ -597,7 +602,7 @@ def test_trained_betas(self):
597602
continue
598603

599604
scheduler_config = self.get_scheduler_config()
600-
scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
605+
scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.1, 0.3]))
601606

602607
with tempfile.TemporaryDirectory() as tmpdirname:
603608
scheduler.save_pretrained(tmpdirname)
@@ -2648,6 +2653,7 @@ def get_scheduler_config(self, **kwargs):
26482653
"beta_end": 0.02,
26492654
"beta_schedule": "linear",
26502655
"solver_order": 2,
2656+
"solver_type": "bh1",
26512657
}
26522658

26532659
config.update(**kwargs)

0 commit comments

Comments
 (0)