Skip to content

Commit 71ba8ae

Browse files
pcuencapatrickvonplatenanton-l
authored
Pipeline to device (huggingface#210)
* Implement `pipeline.to(device)` * DiffusionPipeline.to() decides best device on None. * Breaking change: torch_device removed from __call__ `pipeline.to()` now has PyTorch semantics. * Use kwargs and deprecation notice Co-authored-by: Patrick von Platen <[email protected]> * Apply torch_device compatibility to all pipelines. * style Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: anton-l <[email protected]>
1 parent 89e9521 commit 71ba8ae

File tree

9 files changed

+147
-60
lines changed

9 files changed

+147
-60
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import os
2020
from typing import Optional, Union
2121

22+
import torch
23+
2224
from huggingface_hub import snapshot_download
2325
from PIL import Image
2426

@@ -113,6 +115,26 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
113115
save_method = getattr(sub_model, save_method_name)
114116
save_method(os.path.join(save_directory, pipeline_component_name))
115117

118+
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
119+
if torch_device is None:
120+
return self
121+
122+
module_names, _ = self.extract_init_dict(dict(self.config))
123+
for name in module_names.keys():
124+
module = getattr(self, name)
125+
if isinstance(module, torch.nn.Module):
126+
module.to(torch_device)
127+
return self
128+
129+
@property
130+
def device(self) -> torch.device:
131+
module_names, _ = self.extract_init_dict(dict(self.config))
132+
for name in module_names.keys():
133+
module = getattr(self, name)
134+
if isinstance(module, torch.nn.Module):
135+
return module.device
136+
return torch.device("cpu")
137+
116138
@classmethod
117139
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
118140
r"""

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515

1616

17+
import warnings
18+
1719
import torch
1820

1921
from tqdm.auto import tqdm
@@ -28,21 +30,28 @@ def __init__(self, unet, scheduler):
2830
self.register_modules(unet=unet, scheduler=scheduler)
2931

3032
@torch.no_grad()
31-
def __call__(
32-
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
33-
):
34-
# eta corresponds to η in paper and should be between [0, 1]
35-
if torch_device is None:
36-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
33+
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
34+
35+
if "torch_device" in kwargs:
36+
device = kwargs.pop("torch_device")
37+
warnings.warn(
38+
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
39+
" Consider using `pipe.to(torch_device)` instead."
40+
)
3741

38-
self.unet.to(torch_device)
42+
# Set device as before (to be removed in 0.3.0)
43+
if device is None:
44+
device = "cuda" if torch.cuda.is_available() else "cpu"
45+
self.to(device)
46+
47+
# eta corresponds to η in paper and should be between [0, 1]
3948

4049
# Sample gaussian noise to begin loop
4150
image = torch.randn(
4251
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
4352
generator=generator,
4453
)
45-
image = image.to(torch_device)
54+
image = image.to(self.device)
4655

4756
# set step values
4857
self.scheduler.set_timesteps(num_inference_steps)

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515

1616

17+
import warnings
18+
1719
import torch
1820

1921
from tqdm.auto import tqdm
@@ -28,18 +30,25 @@ def __init__(self, unet, scheduler):
2830
self.register_modules(unet=unet, scheduler=scheduler)
2931

3032
@torch.no_grad()
31-
def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
32-
if torch_device is None:
33-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
33+
def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
34+
if "torch_device" in kwargs:
35+
device = kwargs.pop("torch_device")
36+
warnings.warn(
37+
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
38+
" Consider using `pipe.to(torch_device)` instead."
39+
)
3440

35-
self.unet.to(torch_device)
41+
# Set device as before (to be removed in 0.3.0)
42+
if device is None:
43+
device = "cuda" if torch.cuda.is_available() else "cpu"
44+
self.to(device)
3645

3746
# Sample gaussian noise to begin loop
3847
image = torch.randn(
3948
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
4049
generator=generator,
4150
)
42-
image = image.to(torch_device)
51+
image = image.to(self.device)
4352

4453
# set step values
4554
self.scheduler.set_timesteps(1000)

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import warnings
23
from typing import List, Optional, Tuple, Union
34

45
import torch
@@ -31,13 +32,22 @@ def __call__(
3132
guidance_scale: Optional[float] = 1.0,
3233
eta: Optional[float] = 0.0,
3334
generator: Optional[torch.Generator] = None,
34-
torch_device: Optional[Union[str, torch.device]] = None,
3535
output_type: Optional[str] = "pil",
36+
**kwargs,
3637
):
3738
# eta corresponds to η in paper and should be between [0, 1]
3839

39-
if torch_device is None:
40-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
40+
if "torch_device" in kwargs:
41+
device = kwargs.pop("torch_device")
42+
warnings.warn(
43+
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
44+
" Consider using `pipe.to(torch_device)` instead."
45+
)
46+
47+
# Set device as before (to be removed in 0.3.0)
48+
if device is None:
49+
device = "cuda" if torch.cuda.is_available() else "cpu"
50+
self.to(device)
4151

4252
if isinstance(prompt, str):
4353
batch_size = 1
@@ -49,24 +59,20 @@ def __call__(
4959
if height % 8 != 0 or width % 8 != 0:
5060
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
5161

52-
self.unet.to(torch_device)
53-
self.vqvae.to(torch_device)
54-
self.bert.to(torch_device)
55-
5662
# get unconditional embeddings for classifier free guidance
5763
if guidance_scale != 1.0:
5864
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
59-
uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]
65+
uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
6066

6167
# get prompt text embeddings
6268
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
63-
text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
69+
text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
6470

6571
latents = torch.randn(
6672
(batch_size, self.unet.in_channels, height // 8, width // 8),
6773
generator=generator,
6874
)
69-
latents = latents.to(torch_device)
75+
latents = latents.to(self.device)
7076

7177
self.scheduler.set_timesteps(num_inference_steps)
7278

src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import warnings
23

34
import torch
45

@@ -14,22 +15,26 @@ def __init__(self, vqvae, unet, scheduler):
1415
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
1516

1617
@torch.no_grad()
17-
def __call__(
18-
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
19-
):
18+
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
2019
# eta corresponds to η in paper and should be between [0, 1]
2120

22-
if torch_device is None:
23-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
21+
if "torch_device" in kwargs:
22+
device = kwargs.pop("torch_device")
23+
warnings.warn(
24+
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
25+
" Consider using `pipe.to(torch_device)` instead."
26+
)
2427

25-
self.unet.to(torch_device)
26-
self.vqvae.to(torch_device)
28+
# Set device as before (to be removed in 0.3.0)
29+
if device is None:
30+
device = "cuda" if torch.cuda.is_available() else "cpu"
31+
self.to(device)
2732

2833
latents = torch.randn(
2934
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
3035
generator=generator,
3136
)
32-
latents = latents.to(torch_device)
37+
latents = latents.to(self.device)
3338

3439
self.scheduler.set_timesteps(num_inference_steps)
3540

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515

1616

17+
import warnings
18+
1719
import torch
1820

1921
from tqdm.auto import tqdm
@@ -28,20 +30,28 @@ def __init__(self, unet, scheduler):
2830
self.register_modules(unet=unet, scheduler=scheduler)
2931

3032
@torch.no_grad()
31-
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
33+
def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs):
3234
# For more information on the sampling method you can take a look at Algorithm 2 of
3335
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
34-
if torch_device is None:
35-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
3636

37-
self.unet.to(torch_device)
37+
if "torch_device" in kwargs:
38+
device = kwargs.pop("torch_device")
39+
warnings.warn(
40+
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
41+
" Consider using `pipe.to(torch_device)` instead."
42+
)
43+
44+
# Set device as before (to be removed in 0.3.0)
45+
if device is None:
46+
device = "cuda" if torch.cuda.is_available() else "cpu"
47+
self.to(device)
3848

3949
# Sample gaussian noise to begin loop
4050
image = torch.randn(
4151
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
4252
generator=generator,
4353
)
44-
image = image.to(torch_device)
54+
image = image.to(self.device)
4555

4656
self.scheduler.set_timesteps(num_inference_steps)
4757
for t in tqdm(self.scheduler.timesteps):

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#!/usr/bin/env python3
2+
import warnings
3+
24
import torch
35

46
from diffusers import DiffusionPipeline
@@ -11,24 +13,32 @@ def __init__(self, unet, scheduler):
1113
self.register_modules(unet=unet, scheduler=scheduler)
1214

1315
@torch.no_grad()
14-
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
15-
16-
if torch_device is None:
17-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
16+
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs):
17+
if "torch_device" in kwargs:
18+
device = kwargs.pop("torch_device")
19+
warnings.warn(
20+
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
21+
" Consider using `pipe.to(torch_device)` instead."
22+
)
23+
24+
# Set device as before (to be removed in 0.3.0)
25+
if device is None:
26+
device = "cuda" if torch.cuda.is_available() else "cpu"
27+
self.to(device)
1828

1929
img_size = self.unet.config.sample_size
2030
shape = (batch_size, 3, img_size, img_size)
2131

22-
model = self.unet.to(torch_device)
32+
model = self.unet
2333

2434
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
25-
sample = sample.to(torch_device)
35+
sample = sample.to(self.device)
2636

2737
self.scheduler.set_timesteps(num_inference_steps)
2838
self.scheduler.set_sigmas(num_inference_steps)
2939

3040
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
31-
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
41+
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
3242

3343
# correction step
3444
for _ in range(self.scheduler.correct_steps):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import warnings
23
from typing import List, Optional, Union
34

45
import torch
@@ -45,11 +46,20 @@ def __call__(
4546
guidance_scale: Optional[float] = 7.5,
4647
eta: Optional[float] = 0.0,
4748
generator: Optional[torch.Generator] = None,
48-
torch_device: Optional[Union[str, torch.device]] = None,
4949
output_type: Optional[str] = "pil",
50+
**kwargs,
5051
):
51-
if torch_device is None:
52-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
52+
if "torch_device" in kwargs:
53+
device = kwargs.pop("torch_device")
54+
warnings.warn(
55+
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
56+
" Consider using `pipe.to(torch_device)` instead."
57+
)
58+
59+
# Set device as before (to be removed in 0.3.0)
60+
if device is None:
61+
device = "cuda" if torch.cuda.is_available() else "cpu"
62+
self.to(device)
5363

5464
if isinstance(prompt, str):
5565
batch_size = 1
@@ -61,11 +71,6 @@ def __call__(
6171
if height % 8 != 0 or width % 8 != 0:
6272
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
6373

64-
self.unet.to(torch_device)
65-
self.vae.to(torch_device)
66-
self.text_encoder.to(torch_device)
67-
self.safety_checker.to(torch_device)
68-
6974
# get prompt text embeddings
7075
text_input = self.tokenizer(
7176
prompt,
@@ -74,7 +79,7 @@ def __call__(
7479
truncation=True,
7580
return_tensors="pt",
7681
)
77-
text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
82+
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
7883

7984
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
8085
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -86,7 +91,7 @@ def __call__(
8691
uncond_input = self.tokenizer(
8792
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
8893
)
89-
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]
94+
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
9095

9196
# For classifier free guidance, we need to do two forward passes.
9297
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -97,7 +102,7 @@ def __call__(
97102
latents = torch.randn(
98103
(batch_size, self.unet.in_channels, height // 8, width // 8),
99104
generator=generator,
100-
device=torch_device,
105+
device=self.device,
101106
)
102107

103108
# set timesteps
@@ -150,7 +155,7 @@ def __call__(
150155
image = image.cpu().permute(0, 2, 3, 1).numpy()
151156

152157
# run safety checker
153-
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
158+
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
154159
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
155160

156161
if output_type == "pil":

0 commit comments

Comments
 (0)