Commit dfcce3c

[Research folder] Add SDXL example (huggingface#5275)
* [SDXL Flax] Add research folder
* Add co-author

Co-authored-by: Juan Acevedo <[email protected]>
1 parent 2457599 commit dfcce3c

3 files changed (+492, −0 lines)
Lines changed: 243 additions & 0 deletions
# Stable Diffusion XL for JAX + TPUv5e

[TPU v5e](https://cloud.google.com/blog/products/compute/how-cloud-tpu-v5e-accelerates-large-scale-ai-inference) is a new generation of TPUs from Google Cloud. It is the most cost-effective, versatile, and scalable Cloud TPU to date, which makes it ideal for serving and scaling large diffusion models.

[JAX](https://github.com/google/jax) is a high-performance numerical computation library that is well-suited to develop and deploy diffusion models:

- **High performance**. All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/), the Accelerated Linear Algebra compiler.

- **Compilation**. JAX uses just-in-time (jit) compilation of JAX Python functions so they can be executed efficiently in XLA. To get the best performance, we must use static shapes for jitted functions; this is because JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type. When a new shape is introduced to an already compiled function, compilation is retriggered for the new shape, which can greatly reduce performance (see the short sketch after this list). **Note**: JIT compilation is particularly well-suited for text-to-image generation because all inputs and outputs (image input / output sizes) are static.

- **Parallelization**. Workloads can be scaled across multiple devices using JAX's [pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html), which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function compiles it with XLA, then executes it in parallel on XLA devices. For text-to-image generation workloads this means that increasing the number of images rendered simultaneously is straightforward to implement and doesn't compromise performance.

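As a minimal illustration of the static-shape point above (this snippet is not part of the SDXL example), the first call of a jitted function for each new input shape triggers tracing and compilation, while later calls with the same shape reuse the cached executable:

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale(x):
    # the compiled executable is specialized to the shape and dtype of x
    return x * 2.0


scale(jnp.ones((4, 4)))  # traced and compiled for shape (4, 4)
scale(jnp.ones((4, 4)))  # same shape: reuses the cached executable, runs fast
scale(jnp.ones((8, 8)))  # new shape: triggers a fresh compilation
```
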
👉 Try it out for yourself:

[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/google/sdxl)

## Stable Diffusion XL pipeline in JAX

Once you have access to a TPU VM (TPU version 4 or higher), you should first install
a TPU-compatible version of JAX:

```
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

Next, we can install [flax](https://github.com/google/flax) and the diffusers library:

```
pip install flax diffusers transformers
```

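As a quick sanity check (this step is our addition and is not part of the example scripts), you can confirm that JAX actually sees the TPU devices before moving on:

```python
import jax

print(jax.devices())       # should list the available TPU devices
print(jax.device_count())  # number of devices visible to this host
```
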
In [sdxl_single.py](./sdxl_single.py) we give a simple example of how to write a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).

Let's explain it step-by-step:

**Imports and Setup**

```python
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline

from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("/tmp/sdxl_cache")
import time

NUM_DEVICES = jax.device_count()
```

First, we import the necessary libraries:
- `jax` provides the primitives for TPU operations
- `flax.jax_utils` contains some useful utility functions for `Flax`, a neural network library built on top of JAX
- `diffusers` has all the code that is relevant for SDXL.
- We also initialize a cache to speed up the JAX model compilation.
- We automatically determine the number of available TPU devices.

**1. Downloading Model and Loading Pipeline**

```python
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
)
```

Here, the pre-trained model `stable-diffusion-xl-base-1.0` from the namespace `stabilityai` is loaded. `from_pretrained` returns both the pipeline used for inference and the model parameters.

**2. Casting Parameter Types**

```python
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state
```

This section adjusts the data types of the model parameters.
We convert all parameters to `bfloat16` to speed up the computation with model weights.
**Note** that the scheduler parameters are **not** converted to `bfloat16` as the loss
in precision would degrade the pipeline's performance too significantly.

**3. Define Inputs to Pipeline**

```python
default_prompt = ...
default_neg_prompt = ...
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25
```

Here, various default inputs for the pipeline are set, including the prompt, negative prompt, random seed, guidance scale, and the number of inference steps.

**4. Tokenizing Inputs**

```python
def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids
```

This function tokenizes the given prompts. It's essential because the text encoders of SDXL don't understand raw text; they operate on numbers, so tokenization converts the text into arrays of token ids.

**5. Parallelization and Replication**

```python
p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    ...
```

To utilize JAX's parallel capabilities, the parameters and input tensors are duplicated across devices. The `replicate_all` function also ensures that every device produces a different image by creating a unique random seed for each device.

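For reference, the full helper as defined in [sdxl_single.py](./sdxl_single.py) replicates the token ids across devices and splits the RNG so that each device receives its own key:

```python
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng
```
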
**6. Putting Everything Together**

```python
def generate(...):
    ...
```

This function integrates all the steps to produce the desired outputs from the model. It takes in prompts, tokenizes them, replicates them across devices, runs them through the pipeline, and converts the images to a format that's more interpretable (PIL format).

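For completeness, this is the full `generate` function from [sdxl_single.py](./sdxl_single.py); passing `jit=True` lets the pipeline take care of the parallel (pmapped) execution for us:

```python
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))
```
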
**7. Compilation Step**

```python
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")
```

The initial run of the `generate` function will be slow because JAX compiles the function during this call. By running it once here, subsequent calls will be much faster. This section measures and prints the compilation time.

**8. Fast Inference**

```python
start = time.time()
prompt = ...
neg_prompt = ...
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")
```

Now that the function is compiled, this section shows how to use it for fast inference. It measures and prints the inference time.

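The returned images are PIL objects, so they can be saved straight to disk, as the script does at the end:

```python
for i, image in enumerate(images):
    image.save(f"castle_{i}.png")
```
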
In summary, the code demonstrates how to load a pre-trained model using Flax and JAX, prepare it for inference, and run it efficiently using JAX's capabilities.

## Ahead of Time (AOT) Compilation

`FlaxStableDiffusionXLPipeline` takes care of parallelization across multiple devices using jit. Now let's build parallelization ourselves.

For this we will be using a JAX feature called [Ahead of Time](https://jax.readthedocs.io/en/latest/aot.html) (AOT) lowering and compilation. AOT allows us to fully compile prior to execution time and gives control over different parts of the compilation process.

In [sdxl_single_aot.py](./sdxl_single_aot.py) we give a simple example of how to write our own parallelization logic for a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).

We add an `aot_compile` function that compiles the `pipeline._generate` function,
telling JAX which input arguments are static, that is, arguments that
are known at compile time and won't change. In our case, these are `num_inference_steps`,
`height`, `width` and `return_latents`.

Once the function is compiled, these parameters are omitted from future calls and
cannot be changed without modifying the code and recompiling.

```python
def aot_compile(
    prompt=default_prompt,
    negative_prompt=default_neg_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]

    return pmap(
        pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9]
    ).lower(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps,  # num_inference_steps
        height,  # height
        width,  # width
        g,
        None,
        neg_prompt_ids,
        False  # return_latents
    ).compile()
```

Next, we can compile the generate function by executing `aot_compile`.

```python
start = time.time()
print("Compiling ...")
p_generate = aot_compile()
print(f"Compiled in {time.time() - start}")
```

And again we put everything together in a `generate` function.

```python
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]
    images = p_generate(
        prompt_ids,
        p_params,
        rng,
        g,
        None,
        neg_prompt_ids)

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))
```

The first forward pass after AOT compilation still takes a while longer than
subsequent passes. This is because on the first pass, JAX uses Python dispatch, which
fills the C++ dispatch cache.
When using jit, this extra step is done automatically, but when using AOT compilation,
it doesn't happen until the function call is made.

```python
start = time.time()
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
neg_prompt = "cartoon, illustration, animation. face. male, female"
images = generate(prompt, neg_prompt)
print(f"First inference in {time.time() - start}")
```

From this point forward, any call to `generate` should result in a faster inference
time, and that time won't change between calls.

```python
start = time.time()
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
neg_prompt = "cartoon, illustration, animation. face. male, female"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")
```

Lines changed: 106 additions & 0 deletions
# Show best practices for SDXL JAX
import time

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate

# Let's cache the model compilation, so that it doesn't take as long the next time around.
from jax.experimental.compilation_cache import compilation_cache as cc

from diffusers import FlaxStableDiffusionXLPipeline


cc.initialize_cache("/tmp/sdxl_cache")


NUM_DEVICES = jax.device_count()

# 1. Let's start by downloading the model and loading it into our pipeline class
# Adhering to JAX's functional approach, the model's parameters are returned separately and
# will have to be passed to the pipeline during inference
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
)

# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
# float32 to keep maximal precision
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state

# 3. Next, we define the different inputs to the pipeline
default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
default_neg_prompt = "fog, grainy, purple"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25


# 4. In order to be able to compile the pipeline
# all inputs have to be tensors or strings
# Let's tokenize the prompt and negative prompt
def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids


# 5. To make full use of JAX's parallelization capabilities
# the parameters and input tensors are duplicated across devices
# To make sure every device generates a different image, we create
# different seeds for each image. The model parameters won't change
# during inference so we do not wrap them into a function
p_params = replicate(params)


def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng


# 6. Let's now put it all together in a generate function
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))


# 7. Remember that the first call will compile the function and hence be very slow. Let's run generate once
# so that the pipeline call is compiled
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

# 8. Now the model forward pass will run very quickly, let's try it again
start = time.time()
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
neg_prompt = "cartoon, illustration, animation. face. male, female"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")

for i, image in enumerate(images):
    image.save(f"castle_{i}.png")
