Skip to content

Commit 5e84353

Browse files
Refactor progress bar (huggingface#242)
* Refactor progress bar of pipeline __call__ * Make any tqdm configs available * remove init * add some tests * remove file * finish * make style * improve progress bar test Co-authored-by: Patrick von Platen <[email protected]>
1 parent efa773a commit 5e84353

File tree

10 files changed

+45
-21
lines changed

10 files changed

+45
-21
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from huggingface_hub import snapshot_download
2525
from PIL import Image
26+
from tqdm.auto import tqdm
2627

2728
from .configuration_utils import ConfigMixin
2829
from .utils import DIFFUSERS_CACHE, logging
@@ -266,3 +267,16 @@ def numpy_to_pil(images):
266267
pil_images = [Image.fromarray(image) for image in images]
267268

268269
return pil_images
270+
271+
def progress_bar(self, iterable):
272+
if not hasattr(self, "_progress_bar_config"):
273+
self._progress_bar_config = {}
274+
elif not isinstance(self._progress_bar_config, dict):
275+
raise ValueError(
276+
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
277+
)
278+
279+
return tqdm(iterable, **self._progress_bar_config)
280+
281+
def set_progress_bar_config(self, **kwargs):
282+
self._progress_bar_config = kwargs

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
import torch
2020

21-
from tqdm.auto import tqdm
22-
2321
from ...pipeline_utils import DiffusionPipeline
2422

2523

@@ -56,7 +54,7 @@ def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50
5654
# set step values
5755
self.scheduler.set_timesteps(num_inference_steps)
5856

59-
for t in tqdm(self.scheduler.timesteps):
57+
for t in self.progress_bar(self.scheduler.timesteps):
6058
# 1. predict noise model_output
6159
model_output = self.unet(image, t)["sample"]
6260

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
import torch
2020

21-
from tqdm.auto import tqdm
22-
2321
from ...pipeline_utils import DiffusionPipeline
2422

2523

@@ -53,7 +51,7 @@ def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
5351
# set step values
5452
self.scheduler.set_timesteps(1000)
5553

56-
for t in tqdm(self.scheduler.timesteps):
54+
for t in self.progress_bar(self.scheduler.timesteps):
5755
# 1. predict noise model_output
5856
model_output = self.unet(image, t)["sample"]
5957

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import torch.nn as nn
77
import torch.utils.checkpoint
88

9-
from tqdm.auto import tqdm
109
from transformers.activations import ACT2FN
1110
from transformers.configuration_utils import PretrainedConfig
1211
from transformers.modeling_outputs import BaseModelOutput
@@ -83,7 +82,7 @@ def __call__(
8382
if accepts_eta:
8483
extra_kwargs["eta"] = eta
8584

86-
for t in tqdm(self.scheduler.timesteps):
85+
for t in self.progress_bar(self.scheduler.timesteps):
8786
if guidance_scale == 1.0:
8887
# guidance_scale of 1 means no guidance
8988
latents_input = latents

src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33

44
import torch
55

6-
from tqdm.auto import tqdm
7-
86
from ...pipeline_utils import DiffusionPipeline
97

108

@@ -45,7 +43,7 @@ def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50
4543
if accepts_eta:
4644
extra_kwargs["eta"] = eta
4745

48-
for t in tqdm(self.scheduler.timesteps):
46+
for t in self.progress_bar(self.scheduler.timesteps):
4947
# predict the noise residual
5048
noise_prediction = self.unet(latents, t)["sample"]
5149
# compute the previous noisy sample x_t -> x_t-1

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
import torch
2020

21-
from tqdm.auto import tqdm
22-
2321
from ...pipeline_utils import DiffusionPipeline
2422

2523

@@ -54,7 +52,7 @@ def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_
5452
image = image.to(self.device)
5553

5654
self.scheduler.set_timesteps(num_inference_steps)
57-
for t in tqdm(self.scheduler.timesteps):
55+
for t in self.progress_bar(self.scheduler.timesteps):
5856
model_output = self.unet(image, t)["sample"]
5957

6058
image = self.scheduler.step(model_output, t, image)["prev_sample"]

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch
55

66
from diffusers import DiffusionPipeline
7-
from tqdm.auto import tqdm
87

98

109
class ScoreSdeVePipeline(DiffusionPipeline):
@@ -37,7 +36,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
3736
self.scheduler.set_timesteps(num_inference_steps)
3837
self.scheduler.set_sigmas(num_inference_steps)
3938

40-
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
39+
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
4140
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
4241

4342
# correction step

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import torch
66

7-
from tqdm.auto import tqdm
87
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
98

109
from ...models import AutoencoderKL, UNet2DConditionModel
@@ -133,7 +132,7 @@ def __call__(
133132
if accepts_eta:
134133
extra_step_kwargs["eta"] = eta
135134

136-
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
135+
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
137136
# expand the latents if we are doing classifier free guidance
138137
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
139138
if isinstance(self.scheduler, LMSDiscreteScheduler):

src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33

44
import torch
55

6-
from tqdm.auto import tqdm
7-
86
from ...models import UNet2DModel
97
from ...pipeline_utils import DiffusionPipeline
108
from ...schedulers import KarrasVeScheduler
@@ -53,7 +51,7 @@ def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_
5351

5452
self.scheduler.set_timesteps(num_inference_steps)
5553

56-
for t in tqdm(self.scheduler.timesteps):
54+
for t in self.progress_bar(self.scheduler.timesteps):
5755
# here sigma_t == t_i from the paper
5856
sigma = self.scheduler.schedule[t]
5957
sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0

tests/test_pipelines.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,29 @@
4444
torch.backends.cuda.matmul.allow_tf32 = False
4545

4646

47+
def test_progress_bar(capsys):
48+
model = UNet2DModel(
49+
block_out_channels=(32, 64),
50+
layers_per_block=2,
51+
sample_size=32,
52+
in_channels=3,
53+
out_channels=3,
54+
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
55+
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
56+
)
57+
scheduler = DDPMScheduler(num_train_timesteps=10)
58+
59+
ddpm = DDPMPipeline(model, scheduler).to(torch_device)
60+
ddpm(output_type="numpy")["sample"]
61+
captured = capsys.readouterr()
62+
assert "10/10" in captured.err, "Progress bar has to be displayed"
63+
64+
ddpm.set_progress_bar_config(disable=True)
65+
ddpm(output_type="numpy")["sample"]
66+
captured = capsys.readouterr()
67+
assert captured.err == "", "Progress bar should be disabled"
68+
69+
4770
class PipelineTesterMixin(unittest.TestCase):
4871
def test_from_pretrained_save_pretrained(self):
4972
# 1. Load models

0 commit comments

Comments
 (0)