@@ -799,6 +799,15 @@ def parse_args(input_args=None):
799
799
default = False ,
800
800
help = "Cache the VAE latents" ,
801
801
)
802
+ parser .add_argument (
803
+ "--image_interpolation_mode" ,
804
+ type = str ,
805
+ default = "lanczos" ,
806
+ choices = [
807
+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
808
+ ],
809
+ help = "The image interpolation method to use for resizing images." ,
810
+ )
802
811
803
812
if input_args is not None :
804
813
args = parser .parse_args (input_args )
@@ -1069,7 +1078,10 @@ def __init__(
1069
1078
self .original_sizes = []
1070
1079
self .crop_top_lefts = []
1071
1080
self .pixel_values = []
1072
- train_resize = transforms .Resize (size , interpolation = transforms .InterpolationMode .BILINEAR )
1081
+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
1082
+ if interpolation is None :
1083
+ raise ValueError (f"Unsupported interpolation mode { interpolation = } ." )
1084
+ train_resize = transforms .Resize (size , interpolation = interpolation )
1073
1085
train_crop = transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size )
1074
1086
train_flip = transforms .RandomHorizontalFlip (p = 1.0 )
1075
1087
train_transforms = transforms .Compose (
@@ -1146,7 +1158,7 @@ def __init__(
1146
1158
1147
1159
self .image_transforms = transforms .Compose (
1148
1160
[
1149
- transforms .Resize (size , interpolation = transforms . InterpolationMode . BILINEAR ),
1161
+ transforms .Resize (size , interpolation = interpolation ),
1150
1162
transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size ),
1151
1163
transforms .ToTensor (),
1152
1164
transforms .Normalize ([0.5 ], [0.5 ]),
0 commit comments