
Commit 8ed08e4

patrickvonplaten, anton-l, and pcuenca authored
[Deterministic torch randn] Allow tensors to be generated on CPU (huggingface#1902)
* [Deterministic torch randn] Allow tensors to be generated on CPU
* fix more
* up
* fix more
* up
* Update src/diffusers/utils/torch_utils.py

Co-authored-by: Anton Lozhkov <[email protected]>

* Apply suggestions from code review
* up
* up
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Anton Lozhkov <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 0df83c7 commit 8ed08e4

13 files changed: +82 -27 lines changed


examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py

Lines changed: 1 addition & 1 deletion
@@ -336,7 +336,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
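Note: the two spellings unpack identically; the parenthesized target list is just an unambiguous form that auto-formatters preserve. A standalone sanity check (illustrative, not part of the commit):

    # Both spellings bind h and w the same way; the trailing comma in a bare
    # target list is legal Python, merely easy to misread.
    shape = (480, 640, 3)
    h, w, = (shape[0], shape[1])      # old spelling
    (h2, w2,) = (shape[0], shape[1])  # new spelling
    assert (h, w) == (h2, w2) == (480, 640)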

examples/textual_inversion/textual_inversion.py

Lines changed: 1 addition & 1 deletion
@@ -381,7 +381,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )

examples/textual_inversion/textual_inversion_flax.py

Lines changed: 1 addition & 1 deletion
@@ -306,7 +306,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )

scripts/convert_kakao_brain_unclip_to_diffusers.py

Lines changed: 1 addition & 0 deletions
@@ -564,6 +564,7 @@ def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model,
 
 # unet utils
 
+
 # <original>.time_embed -> <diffusers>.time_embedding
 def unet_time_embeddings(checkpoint, original_unet_prefix):
     diffusers_checkpoint = {}

src/diffusers/models/modeling_flax_pytorch_utils.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ def rename_key(key):
 # PyTorch => Flax #
 #####################
 
+
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
 def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):

src/diffusers/pipelines/unclip/pipeline_unclip.py

Lines changed: 2 additions & 6 deletions
@@ -24,7 +24,7 @@
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel
 
 
@@ -105,11 +105,7 @@ def __init__(
 
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py

Lines changed: 2 additions & 6 deletions
@@ -29,7 +29,7 @@
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel
 
 
@@ -113,11 +113,7 @@ def __init__(
     # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

src/diffusers/schedulers/scheduling_unclip.py

Lines changed: 4 additions & 10 deletions
@@ -20,7 +20,7 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, torch_randn
 from .scheduling_utils import SchedulerMixin
 
 
@@ -273,15 +273,9 @@ def step(
         # 6. Add noise
         variance = 0
         if t > 0:
-            device = model_output.device
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-                variance_noise = variance_noise.to(device)
-            else:
-                variance_noise = torch.randn(
-                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                )
+            variance_noise = torch_randn(
+                model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
+            )
 
             variance = self._get_variance(
                 t,
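Unlike the pipelines, the scheduler derives the noise shape, dtype, and device from model_output rather than taking them as explicit arguments; the refactored call preserves that. A hedged standalone check (assumes a CUDA device; not part of the commit):

    import torch

    from diffusers.utils import torch_randn

    model_output = torch.randn(2, 4, 8, 8, dtype=torch.float16, device="cuda")
    gen = torch.Generator(device="cpu").manual_seed(0)  # CPU generator keeps the draw deterministic
    variance_noise = torch_randn(
        model_output.shape, dtype=model_output.dtype, generator=gen, device=model_output.device
    )
    assert variance_noise.dtype == model_output.dtype
    assert variance_noise.device == model_output.device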

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
+from .torch_utils import torch_randn
 
 
 if is_torch_available():

src/diffusers/utils/torch_utils.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch utilities: Utilities related to PyTorch
+"""
+from typing import List, Optional, Tuple, Union
+
+from . import logging
+from .import_utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def torch_randn(
+    shape: Union[Tuple, List],
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional["torch.device"] = None,
+    dtype: Optional["torch.dtype"] = None,
+):
+    """Helper to create random tensors on the desired `device` with the desired `dtype`. A list of generators
+    seeds each batch entry individually. CPU generators force the tensor to be created on CPU, then moved to `device`.
+    """
+    # the device on which the tensor is created defaults to the requested `device`
+    rand_device = device
+    batch_size = shape[0]
+
+    if generator is not None:
+        gen_device_type = generator[0].device.type if isinstance(generator, list) else generator.device.type
+        if gen_device_type != device.type and gen_device_type == "cpu":
+            rand_device = "cpu"
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device_type != device.type and gen_device_type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+    if isinstance(generator, list):
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
+    else:
+        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+
+    return latents
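A minimal usage sketch of the per-sample seeding path (illustrative, not part of the commit): a list of generators seeds each batch entry independently, so entry i of a batched draw matches a batch-size-1 draw with the same seed.

    import torch

    from diffusers.utils import torch_randn

    shape = (2, 4, 8, 8)
    # one generator per batch entry
    gens = [torch.Generator(device="cpu").manual_seed(i) for i in range(shape[0])]
    batched = torch_randn(shape, generator=gens, device=torch.device("cpu"))

    single_gen = torch.Generator(device="cpu").manual_seed(1)
    single = torch_randn((1,) + shape[1:], generator=single_gen, device=torch.device("cpu"))
    assert torch.allclose(batched[1:2], single)  # entry 1 reproduces the seed-1 draw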

0 commit comments
