Skip to content

Commit cf4664e

Browse files
fix tests
1 parent 7222a8e commit cf4664e

File tree

4 files changed

+4
-23
lines changed

4 files changed

+4
-23
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def forward(
301301
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
302302
# This would be a good case for the `match` statement (Python 3.10+)
303303
is_mps = sample.device.type == "mps"
304-
if torch.is_floating_point(timesteps):
304+
if isinstance(timestep, float):
305305
dtype = torch.float32 if is_mps else torch.float64
306306
else:
307307
dtype = torch.int32 if is_mps else torch.int64

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def forward(
379379
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
380380
# This would be a good case for the `match` statement (Python 3.10+)
381381
is_mps = sample.device.type == "mps"
382-
if torch.is_floating_point(timesteps):
382+
if isinstance(timestep, float):
383383
dtype = torch.float32 if is_mps else torch.float64
384384
else:
385385
dtype = torch.int32 if is_mps else torch.int64

src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarra
117117
Returns:
118118
`jnp.ndarray`: scaled input sample
119119
"""
120-
(step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
121-
sigma = scheduler_state.sigmas[step_index]
120+
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
121+
sigma = state.sigmas[step_index]
122122
sample = sample / ((sigma**2 + 1) ** 0.5)
123123
return sample
124124

tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import gc
1717
import tempfile
18-
import time
1918
import unittest
2019

2120
import numpy as np
@@ -694,24 +693,6 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
694693
assert test_callback_fn.has_been_called
695694
assert number_of_steps == 20
696695

697-
def test_stable_diffusion_low_cpu_mem_usage(self):
698-
pipeline_id = "stabilityai/stable-diffusion-2-base"
699-
700-
start_time = time.time()
701-
pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
702-
pipeline_id, revision="fp16", torch_dtype=torch.float16
703-
)
704-
pipeline_low_cpu_mem_usage.to(torch_device)
705-
low_cpu_mem_usage_time = time.time() - start_time
706-
707-
start_time = time.time()
708-
_ = StableDiffusionPipeline.from_pretrained(
709-
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
710-
)
711-
normal_load_time = time.time() - start_time
712-
713-
assert 2 * low_cpu_mem_usage_time < normal_load_time
714-
715696
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
716697
torch.cuda.empty_cache()
717698
torch.cuda.reset_max_memory_allocated()

0 commit comments

Comments
 (0)