Skip to content

Commit 7c22626

Browse files
Align PT and Flax API - allow loading checkpoint from PyTorch configs (huggingface#827)
* up * finish * add more tests * up * up * finish
1 parent 78db11d commit 7c22626

File tree

6 files changed

+165
-31
lines changed

6 files changed

+165
-31
lines changed

src/diffusers/pipeline_flax_utils.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -111,24 +111,27 @@ def register_modules(self, **kwargs):
111111
from diffusers import pipelines
112112

113113
for name, module in kwargs.items():
114-
# retrieve library
115-
library = module.__module__.split(".")[0]
114+
if module is None:
115+
register_dict = {name: (None, None)}
116+
else:
117+
# retrieve library
118+
library = module.__module__.split(".")[0]
116119

117-
# check if the module is a pipeline module
118-
pipeline_dir = module.__module__.split(".")[-2]
119-
path = module.__module__.split(".")
120-
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
120+
# check if the module is a pipeline module
121+
pipeline_dir = module.__module__.split(".")[-2]
122+
path = module.__module__.split(".")
123+
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
121124

122-
# if library is not in LOADABLE_CLASSES, then it is a custom module.
123-
# Or if it's a pipeline module, then the module is inside the pipeline
124-
# folder so we set the library to module name.
125-
if library not in LOADABLE_CLASSES or is_pipeline_module:
126-
library = pipeline_dir
125+
# if library is not in LOADABLE_CLASSES, then it is a custom module.
126+
# Or if it's a pipeline module, then the module is inside the pipeline
127+
# folder so we set the library to module name.
128+
if library not in LOADABLE_CLASSES or is_pipeline_module:
129+
library = pipeline_dir
127130

128-
# retrieve class_name
129-
class_name = module.__class__.__name__
131+
# retrieve class_name
132+
class_name = module.__class__.__name__
130133

131-
register_dict = {name: (library, class_name)}
134+
register_dict = {name: (library, class_name)}
132135

133136
# save model index config
134137
self.register_to_config(**register_dict)
@@ -320,6 +323,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
320323
pipeline_class = cls
321324
else:
322325
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
326+
class_name = (
327+
config_dict["_class_name"]
328+
if config_dict["_class_name"].startswith("Flax")
329+
else "Flax" + config_dict["_class_name"]
330+
)
323331
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
324332

325333
# some modules can be passed directly to the init
@@ -342,6 +350,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
342350
for name, (library_name, class_name) in init_dict.items():
343351
is_pipeline_module = hasattr(pipelines, library_name)
344352
loaded_sub_model = None
353+
sub_model_should_be_defined = True
345354

346355
# if the model is in a pipeline module, then we load it from the pipeline
347356
if name in passed_class_obj:
@@ -362,6 +371,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
362371
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
363372
f" {expected_class_obj}"
364373
)
374+
elif passed_class_obj[name] is None:
375+
logger.warn(
376+
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
377+
f" that this might lead to problems when using {pipeline_class} and is not recommended."
378+
)
379+
sub_model_should_be_defined = False
365380
else:
366381
logger.warn(
367382
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -372,25 +387,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
372387
loaded_sub_model = passed_class_obj[name]
373388
elif is_pipeline_module:
374389
pipeline_module = getattr(pipelines, library_name)
375-
if from_pt:
376-
class_obj = import_flax_or_no_model(pipeline_module, class_name)
377-
else:
378-
class_obj = getattr(pipeline_module, class_name)
390+
class_obj = import_flax_or_no_model(pipeline_module, class_name)
379391

380392
importable_classes = ALL_IMPORTABLE_CLASSES
381393
class_candidates = {c: class_obj for c in importable_classes.keys()}
382394
else:
383395
# else we just import it from the library.
384396
library = importlib.import_module(library_name)
385-
if from_pt:
386-
class_obj = import_flax_or_no_model(library, class_name)
387-
else:
388-
class_obj = getattr(library, class_name)
397+
class_obj = import_flax_or_no_model(library, class_name)
389398

390399
importable_classes = LOADABLE_CLASSES[library_name]
391400
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
392401

393-
if loaded_sub_model is None:
402+
if loaded_sub_model is None and sub_model_should_be_defined:
394403
load_method_name = None
395404
for class_name, class_candidate in class_candidates.items():
396405
if issubclass(class_obj, class_candidate):

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
1515
from ...pipeline_flax_utils import FlaxDiffusionPipeline
1616
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
17+
from ...utils import logging
1718
from . import FlaxStableDiffusionPipelineOutput
1819
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
1920

2021

22+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23+
24+
2125
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
2226
r"""
2327
Pipeline for text-to-image generation using Stable Diffusion.
@@ -60,6 +64,16 @@ def __init__(
6064
super().__init__()
6165
self.dtype = dtype
6266

67+
if safety_checker is None:
68+
logger.warn(
69+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
70+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
71+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
72+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
73+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
74+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
75+
)
76+
6377
self.register_modules(
6478
vae=vae,
6579
text_encoder=text_encoder,
@@ -265,10 +279,23 @@ def __call__(
265279
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
266280
)
267281

268-
safety_params = params["safety_checker"]
269-
images = (images * 255).round().astype("uint8")
270-
images = np.asarray(images).reshape(-1, height, width, 3)
271-
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
282+
if self.safety_checker is not None:
283+
safety_params = params["safety_checker"]
284+
images_uint8_casted = (images * 255).round().astype("uint8")
285+
num_devices, batch_size = images.shape[:2]
286+
287+
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
288+
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
289+
images = np.asarray(images)
290+
291+
# block images
292+
if any(has_nsfw_concept):
293+
for i, is_nsfw in enumerate(has_nsfw_concept):
294+
images[i] = np.asarray(images_uint8_casted[i])
295+
296+
images = images.reshape(num_devices, batch_size, height, width, 3)
297+
else:
298+
has_nsfw_concept = False
272299

273300
if not return_dict:
274301
return (images, has_nsfw_concept)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373

7474
if safety_checker is None:
7575
logger.warn(
76-
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
76+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
7777
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
7878
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
7979
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585

8686
if safety_checker is None:
8787
logger.warn(
88-
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
88+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
8989
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
9090
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
9191
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(
100100

101101
if safety_checker is None:
102102
logger.warn(
103-
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
103+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
104104
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
105105
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
106106
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"

tests/test_pipelines_flax.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
if is_flax_available():
2525
import jax
26+
import jax.numpy as jnp
2627
from diffusers import FlaxStableDiffusionPipeline
2728
from flax.jax_utils import replicate
2829
from flax.training.common_utils import shard
@@ -34,7 +35,7 @@
3435
class FlaxPipelineTests(unittest.TestCase):
3536
def test_dummy_all_tpus(self):
3637
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
37-
"hf-internal-testing/tiny-stable-diffusion-pipe"
38+
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
3839
)
3940

4041
prompt = (
@@ -57,6 +58,103 @@ def test_dummy_all_tpus(self):
5758
prompt_ids = shard(prompt_ids)
5859

5960
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
61+
62+
assert images.shape == (8, 1, 64, 64, 3)
63+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3
64+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 1e-2
65+
6066
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
6167

6268
assert len(images_pil) == 8
69+
70+
def test_stable_diffusion_v1_4(self):
71+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
72+
"CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
73+
)
74+
75+
prompt = (
76+
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
77+
" field, close up, split lighting, cinematic"
78+
)
79+
80+
prng_seed = jax.random.PRNGKey(0)
81+
num_inference_steps = 50
82+
83+
num_samples = jax.device_count()
84+
prompt = num_samples * [prompt]
85+
prompt_ids = pipeline.prepare_inputs(prompt)
86+
87+
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
88+
89+
# shard inputs and rng
90+
params = replicate(params)
91+
prng_seed = jax.random.split(prng_seed, 8)
92+
prompt_ids = shard(prompt_ids)
93+
94+
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
95+
96+
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
97+
for i, image in enumerate(images_pil):
98+
image.save(f"/home/patrick/images/flax-test-{i}_fp32.png")
99+
100+
assert images.shape == (8, 1, 512, 512, 3)
101+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
102+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 1e-2
103+
104+
def test_stable_diffusion_v1_4_bfloat_16(self):
105+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
106+
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
107+
)
108+
109+
prompt = (
110+
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
111+
" field, close up, split lighting, cinematic"
112+
)
113+
114+
prng_seed = jax.random.PRNGKey(0)
115+
num_inference_steps = 50
116+
117+
num_samples = jax.device_count()
118+
prompt = num_samples * [prompt]
119+
prompt_ids = pipeline.prepare_inputs(prompt)
120+
121+
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
122+
123+
# shard inputs and rng
124+
params = replicate(params)
125+
prng_seed = jax.random.split(prng_seed, 8)
126+
prompt_ids = shard(prompt_ids)
127+
128+
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
129+
130+
assert images.shape == (8, 1, 512, 512, 3)
131+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
132+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2
133+
134+
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
135+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
136+
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
137+
)
138+
139+
prompt = (
140+
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
141+
" field, close up, split lighting, cinematic"
142+
)
143+
144+
prng_seed = jax.random.PRNGKey(0)
145+
num_inference_steps = 50
146+
147+
num_samples = jax.device_count()
148+
prompt = num_samples * [prompt]
149+
prompt_ids = pipeline.prepare_inputs(prompt)
150+
151+
# shard inputs and rng
152+
params = replicate(params)
153+
prng_seed = jax.random.split(prng_seed, 8)
154+
prompt_ids = shard(prompt_ids)
155+
156+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
157+
158+
assert images.shape == (8, 1, 512, 512, 3)
159+
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
160+
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2

0 commit comments

Comments
 (0)