Skip to content

Commit b93fe08

Browse files
[Loading] Make sure loading edge cases work (huggingface#1192)
* [Loading] Make edge cases work * up * finish * up
1 parent 3f7edc5 commit b93fe08

File tree

3 files changed

+61
-12
lines changed

3 files changed

+61
-12
lines changed

src/diffusers/pipeline_flax_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
5656
"FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
5757
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
58+
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
59+
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
5860
},
5961
}
6062

@@ -172,8 +174,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union
172174
for library_name, library_classes in LOADABLE_CLASSES.items():
173175
library = importlib.import_module(library_name)
174176
for base_class, save_load_methods in library_classes.items():
175-
class_candidate = getattr(library, base_class)
176-
if issubclass(model_cls, class_candidate):
177+
class_candidate = getattr(library, base_class, None)
178+
if class_candidate is not None and issubclass(model_cls, class_candidate):
177179
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
178180
save_method_name = save_load_methods[0]
179181
break
@@ -387,11 +389,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
387389
library = importlib.import_module(library_name)
388390
class_obj = getattr(library, class_name)
389391
importable_classes = LOADABLE_CLASSES[library_name]
390-
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
392+
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
391393

392394
expected_class_obj = None
393395
for class_name, class_candidate in class_candidates.items():
394-
if issubclass(class_obj, class_candidate):
396+
if class_candidate is not None and issubclass(class_obj, class_candidate):
395397
expected_class_obj = class_candidate
396398

397399
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
@@ -425,12 +427,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
425427
class_obj = import_flax_or_no_model(library, class_name)
426428

427429
importable_classes = LOADABLE_CLASSES[library_name]
428-
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
430+
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
429431

430432
if loaded_sub_model is None and sub_model_should_be_defined:
431433
load_method_name = None
432434
for class_name, class_candidate in class_candidates.items():
433-
if issubclass(class_obj, class_candidate):
435+
if class_candidate is not None and issubclass(class_obj, class_candidate):
434436
load_method_name = importable_classes[class_name][1]
435437

436438
load_method = getattr(class_obj, load_method_name)

src/diffusers/pipeline_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@
7474
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
7575
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
7676
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
77+
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
78+
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
7779
},
7880
}
7981

@@ -190,8 +192,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
190192
for library_name, library_classes in LOADABLE_CLASSES.items():
191193
library = importlib.import_module(library_name)
192194
for base_class, save_load_methods in library_classes.items():
193-
class_candidate = getattr(library, base_class)
194-
if issubclass(model_cls, class_candidate):
195+
class_candidate = getattr(library, base_class, None)
196+
if class_candidate is not None and issubclass(model_cls, class_candidate):
195197
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
196198
save_method_name = save_load_methods[0]
197199
break
@@ -543,11 +545,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
543545
library = importlib.import_module(library_name)
544546
class_obj = getattr(library, class_name)
545547
importable_classes = LOADABLE_CLASSES[library_name]
546-
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
548+
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
547549

548550
expected_class_obj = None
549551
for class_name, class_candidate in class_candidates.items():
550-
if issubclass(class_obj, class_candidate):
552+
if class_candidate is not None and issubclass(class_obj, class_candidate):
551553
expected_class_obj = class_candidate
552554

553555
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
@@ -577,14 +579,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
577579
else:
578580
# else we just import it from the library.
579581
library = importlib.import_module(library_name)
582+
580583
class_obj = getattr(library, class_name)
581584
importable_classes = LOADABLE_CLASSES[library_name]
582-
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
585+
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
583586

584587
if loaded_sub_model is None and sub_model_should_be_defined:
585588
load_method_name = None
586589
for class_name, class_candidate in class_candidates.items():
587-
if issubclass(class_obj, class_candidate):
590+
if class_candidate is not None and issubclass(class_obj, class_candidate):
588591
load_method_name = importable_classes[class_name][1]
589592

590593
if load_method_name is None:

tests/test_pipelines.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,50 @@ def test_download_only_pytorch(self):
8888
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
8989
assert not any(f.endswith(".msgpack") for f in files)
9090

91+
def test_download_no_safety_checker(self):
92+
prompt = "hello"
93+
pipe = StableDiffusionPipeline.from_pretrained(
94+
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
95+
)
96+
generator = torch.Generator(device=torch_device).manual_seed(0)
97+
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
98+
99+
pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
100+
generator_2 = torch.Generator(device=torch_device).manual_seed(0)
101+
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
102+
103+
assert np.max(np.abs(out - out_2)) < 1e-3
104+
105+
def test_load_no_safety_checker_explicit_locally(self):
106+
prompt = "hello"
107+
pipe = StableDiffusionPipeline.from_pretrained(
108+
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
109+
)
110+
generator = torch.Generator(device=torch_device).manual_seed(0)
111+
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
112+
113+
with tempfile.TemporaryDirectory() as tmpdirname:
114+
pipe.save_pretrained(tmpdirname)
115+
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
116+
generator_2 = torch.Generator(device=torch_device).manual_seed(0)
117+
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
118+
119+
assert np.max(np.abs(out - out_2)) < 1e-3
120+
121+
def test_load_no_safety_checker_default_locally(self):
122+
prompt = "hello"
123+
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
124+
generator = torch.Generator(device=torch_device).manual_seed(0)
125+
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
126+
127+
with tempfile.TemporaryDirectory() as tmpdirname:
128+
pipe.save_pretrained(tmpdirname)
129+
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
130+
generator_2 = torch.Generator(device=torch_device).manual_seed(0)
131+
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
132+
133+
assert np.max(np.abs(out - out_2)) < 1e-3
134+
91135

92136
class CustomPipelineTests(unittest.TestCase):
93137
def test_load_custom_pipeline(self):

0 commit comments

Comments
 (0)