Skip to content

Commit ca914d0

Browse files
authored
Merge branch 'main' into lora
2 parents bc3d43f + 06beeca commit ca914d0

File tree

5 files changed

+91
-25
lines changed

5 files changed

+91
-25
lines changed

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,25 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
134134

135135
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
136136
validation_image = Image.open(validation_image).convert("RGB")
137-
validation_image = validation_image.resize((args.resolution, args.resolution))
137+
138+
try:
139+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
140+
except (AttributeError, KeyError):
141+
supported_interpolation_modes = [
142+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
143+
]
144+
raise ValueError(
145+
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
146+
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
147+
)
148+
149+
transform = transforms.Compose(
150+
[
151+
transforms.Resize(args.resolution, interpolation=interpolation),
152+
transforms.CenterCrop(args.resolution),
153+
]
154+
)
155+
validation_image = transform(validation_image)
138156

139157
images = []
140158

@@ -587,6 +605,15 @@ def parse_args(input_args=None):
587605
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
588606
),
589607
)
608+
parser.add_argument(
609+
"--image_interpolation_mode",
610+
type=str,
611+
default="lanczos",
612+
choices=[
613+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
614+
],
615+
help="The image interpolation method to use for resizing images.",
616+
)
590617

591618
if input_args is not None:
592619
args = parser.parse_args(input_args)
@@ -732,9 +759,20 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
732759

733760

734761
def prepare_train_dataset(dataset, accelerator):
762+
try:
763+
interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
764+
except (AttributeError, KeyError):
765+
supported_interpolation_modes = [
766+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
767+
]
768+
raise ValueError(
769+
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
770+
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
771+
)
772+
735773
image_transforms = transforms.Compose(
736774
[
737-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
775+
transforms.Resize(args.resolution, interpolation=interpolation_mode),
738776
transforms.CenterCrop(args.resolution),
739777
transforms.ToTensor(),
740778
transforms.Normalize([0.5], [0.5]),
@@ -743,7 +781,7 @@ def prepare_train_dataset(dataset, accelerator):
743781

744782
conditioning_image_transforms = transforms.Compose(
745783
[
746-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
784+
transforms.Resize(args.resolution, interpolation=interpolation_mode),
747785
transforms.CenterCrop(args.resolution),
748786
transforms.ToTensor(),
749787
]

src/diffusers/hooks/group_offloading.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from contextlib import contextmanager, nullcontext
16-
from typing import Dict, List, Optional, Set, Tuple
16+
from typing import Dict, List, Optional, Set, Tuple, Union
1717

1818
import torch
1919

@@ -55,7 +55,7 @@ def __init__(
5555
parameters: Optional[List[torch.nn.Parameter]] = None,
5656
buffers: Optional[List[torch.Tensor]] = None,
5757
non_blocking: bool = False,
58-
stream: Optional[torch.cuda.Stream] = None,
58+
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
5959
record_stream: Optional[bool] = False,
6060
low_cpu_mem_usage: bool = False,
6161
onload_self: bool = True,
@@ -115,8 +115,13 @@ def _pinned_memory_tensors(self):
115115

116116
def onload_(self):
117117
r"""Onloads the group of modules to the onload_device."""
118-
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
119-
current_stream = torch.cuda.current_stream() if self.record_stream else None
118+
torch_accelerator_module = (
119+
getattr(torch, torch.accelerator.current_accelerator().type)
120+
if hasattr(torch, "accelerator")
121+
else torch.cuda
122+
)
123+
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
124+
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
120125

121126
if self.stream is not None:
122127
# Wait for previous Host->Device transfer to complete
@@ -162,9 +167,15 @@ def onload_(self):
162167

163168
def offload_(self):
164169
r"""Offloads the group of modules to the offload_device."""
170+
171+
torch_accelerator_module = (
172+
getattr(torch, torch.accelerator.current_accelerator().type)
173+
if hasattr(torch, "accelerator")
174+
else torch.cuda
175+
)
165176
if self.stream is not None:
166177
if not self.record_stream:
167-
torch.cuda.current_stream().synchronize()
178+
torch_accelerator_module.current_stream().synchronize()
168179
for group_module in self.modules:
169180
for param in group_module.parameters():
170181
param.data = self.cpu_param_dict[param]
@@ -429,8 +440,10 @@ def apply_group_offloading(
429440
if use_stream:
430441
if torch.cuda.is_available():
431442
stream = torch.cuda.Stream()
443+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
444+
stream = torch.Stream()
432445
else:
433-
raise ValueError("Using streams for data transfer requires a CUDA device.")
446+
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
434447

435448
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
436449

@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
468481
offload_device: torch.device,
469482
onload_device: torch.device,
470483
non_blocking: bool,
471-
stream: Optional[torch.cuda.Stream] = None,
484+
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
472485
record_stream: Optional[bool] = False,
473486
low_cpu_mem_usage: bool = False,
474487
) -> None:
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
486499
non_blocking (`bool`):
487500
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
488501
and data transfer.
489-
stream (`torch.cuda.Stream`, *optional*):
502+
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
490503
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
491504
for overlapping computation and data transfer.
492505
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -572,7 +585,7 @@ def _apply_group_offloading_leaf_level(
572585
offload_device: torch.device,
573586
onload_device: torch.device,
574587
non_blocking: bool,
575-
stream: Optional[torch.cuda.Stream] = None,
588+
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
576589
record_stream: Optional[bool] = False,
577590
low_cpu_mem_usage: bool = False,
578591
) -> None:
@@ -592,7 +605,7 @@ def _apply_group_offloading_leaf_level(
592605
non_blocking (`bool`):
593606
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
594607
and data transfer.
595-
stream (`torch.cuda.Stream`, *optional*):
608+
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
596609
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
597610
for overlapping computation and data transfer.
598611
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor

tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from diffusers import AsymmetricAutoencoderKL
2323
from diffusers.utils.import_utils import is_xformers_available
2424
from diffusers.utils.testing_utils import (
25+
Expectations,
2526
backend_empty_cache,
2627
enable_full_determinism,
2728
floats_tensor,
@@ -134,18 +135,32 @@ def get_generator(self, seed=0):
134135
# fmt: off
135136
[
136137
33,
137-
[-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
138-
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
138+
Expectations(
139+
{
140+
("xpu", 3): torch.tensor([-0.0343, 0.2873, 0.1680, -0.0140, -0.3459, 0.3522, -0.1336, 0.1075]),
141+
("cuda", 7): torch.tensor([-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205]),
142+
("mps", None): torch.tensor(
143+
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]
144+
),
145+
}
146+
),
139147
],
140148
[
141149
47,
142-
[0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
143-
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
150+
Expectations(
151+
{
152+
("xpu", 3): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
153+
("cuda", 7): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
154+
("mps", None): torch.tensor(
155+
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]
156+
),
157+
}
158+
),
144159
],
145160
# fmt: on
146161
]
147162
)
148-
def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
163+
def test_stable_diffusion(self, seed, expected_slices):
149164
model = self.get_sd_vae_model()
150165
image = self.get_sd_image(seed)
151166
generator = self.get_generator(seed)
@@ -156,9 +171,9 @@ def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
156171
assert sample.shape == image.shape
157172

158173
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
159-
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
160174

161-
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
175+
expected_slice = expected_slices.get_expectation()
176+
assert torch_all_close(output_slice, expected_slice, atol=5e-3)
162177

163178
@parameterized.expand(
164179
[

tests/pipelines/controlnet_flux/test_controlnet_flux.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
enable_full_determinism,
3636
nightly,
3737
numpy_cosine_similarity_distance,
38-
require_big_gpu_with_torch_cuda,
38+
require_big_accelerator,
3939
torch_device,
4040
)
4141
from diffusers.utils.torch_utils import randn_tensor
@@ -210,8 +210,8 @@ def test_flux_image_output_shape(self):
210210

211211

212212
@nightly
213-
@require_big_gpu_with_torch_cuda
214-
@pytest.mark.big_gpu_with_torch_cuda
213+
@require_big_accelerator
214+
@pytest.mark.big_accelerator
215215
class FluxControlNetPipelineSlowTests(unittest.TestCase):
216216
pipeline_class = FluxControlNetPipeline
217217

tests/single_file/test_model_wan_transformer3d_single_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from diffusers.utils.testing_utils import (
2525
backend_empty_cache,
2626
enable_full_determinism,
27-
require_big_gpu_with_torch_cuda,
27+
require_big_accelerator,
2828
require_torch_accelerator,
2929
torch_device,
3030
)
@@ -62,7 +62,7 @@ def test_single_file_components(self):
6262
)
6363

6464

65-
@require_big_gpu_with_torch_cuda
65+
@require_big_accelerator
6666
@require_torch_accelerator
6767
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
6868
model_class = WanTransformer3DModel

0 commit comments

Comments
 (0)