Skip to content

Commit ed6cf52

Browse files
authored
[train_dreambooth_lora_sdxl_advanced] Add LANCZOS as the default interpolation mode for image resizing (#11471)
1 parent e23705e commit ed6cf52

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,15 @@ def parse_args(input_args=None):
799799
default=False,
800800
help="Cache the VAE latents",
801801
)
802+
parser.add_argument(
803+
"--image_interpolation_mode",
804+
type=str,
805+
default="lanczos",
806+
choices=[
807+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
808+
],
809+
help="The image interpolation method to use for resizing images.",
810+
)
802811

803812
if input_args is not None:
804813
args = parser.parse_args(input_args)
@@ -1069,7 +1078,10 @@ def __init__(
10691078
self.original_sizes = []
10701079
self.crop_top_lefts = []
10711080
self.pixel_values = []
1072-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
1081+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
1082+
if interpolation is None:
1083+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
1084+
train_resize = transforms.Resize(size, interpolation=interpolation)
10731085
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
10741086
train_flip = transforms.RandomHorizontalFlip(p=1.0)
10751087
train_transforms = transforms.Compose(
@@ -1146,7 +1158,7 @@ def __init__(
11461158

11471159
self.image_transforms = transforms.Compose(
11481160
[
1149-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
1161+
transforms.Resize(size, interpolation=interpolation),
11501162
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
11511163
transforms.ToTensor(),
11521164
transforms.Normalize([0.5], [0.5]),

0 commit comments

Comments
 (0)