[Community Pipelines] Accelerate inference of stable diffusion by IPEX on CPU #3105

Merged
merged 15 commits on May 23, 2023
Changes from 1 commit
reformat
yingjie-han committed Apr 18, 2023
commit 2eab9c4daee8cfbfb0ac7f1314d6565eaf67c7f6
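
For context, the technique this community pipeline applies is Intel Extension for PyTorch (IPEX) operator and weight optimization followed by TorchScript tracing and freezing of the optimized modules. A minimal standalone sketch of that pattern is shown below; the tiny convolutional model and dummy input are illustrative placeholders, not part of this PR.

```python
# Minimal sketch of the IPEX optimize -> trace -> freeze pattern used in this pipeline.
# The model and example input below are placeholders for illustration only.
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval().to(memory_format=torch.channels_last)
example_input = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)

# IPEX applies operator fusion and weight prepacking for the target dtype.
model = ipex.optimize(model, dtype=torch.bfloat16, inplace=True, sample_input=example_input)

# Trace and freeze under CPU autocast, mirroring what prepare_for_ipex does
# for unet and vae.decoder in the diff below.
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16), torch.no_grad():
    traced = torch.jit.trace(model, example_input, check_trace=False, strict=False)
    traced = torch.jit.freeze(traced)
```
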
87 changes: 51 additions & 36 deletions examples/community/stable_diffusion_ipex.py
@@ -160,7 +160,6 @@ def __init__(
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)


self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -173,9 +172,7 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)



def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5,num_images_per_prompt =1):
def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):

prompt_embeds = None
negative_prompt_embeds = None
@@ -205,7 +202,6 @@ def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5,
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0


# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
@@ -217,7 +213,6 @@ def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5,
negative_prompt_embeds=negative_prompt_embeds,
)


# 5. Prepare latent variables
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
@@ -236,52 +231,78 @@ def get_input_example(self,prompt,height = None,width = None,guidance_scale=7.5,
unet_input_example = (latent_model_input, dummy, prompt_embeds)
vae_decoder_input_example = latents

return unet_input_example,vae_decoder_input_example


return unet_input_example, vae_decoder_input_example

def prepare_for_ipex(self,promt,infer_type = 'bf16',height = None,width = None,guidance_scale=7.5):
def prepare_for_ipex(self, promt, infer_type="bf16", height=None, width=None, guidance_scale=7.5):
self.unet = self.unet.to(memory_format=torch.channels_last)
self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)
self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)
if self.safety_checker != None:
self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last)

unet_input_example,vae_decoder_input_example = self.get_input_example(promt,height,width,guidance_scale)
unet_input_example, vae_decoder_input_example = self.get_input_example(promt, height, width, guidance_scale)

# optimize with ipex
if infer_type == 'bf16':
self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=unet_input_example)
if infer_type == "bf16":
self.unet = ipex.optimize(
self.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=unet_input_example
)
self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)
self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
if self.safety_checker != None:
self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
elif infer_type == 'fp32':
self.unet = ipex.optimize(self.unet.eval(), dtype=torch.float32, inplace=True, sample_input=unet_input_example, level="O1", weights_prepack=True, auto_kernel_selection=False)
self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False)
self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False)
elif infer_type == "fp32":
self.unet = ipex.optimize(
self.unet.eval(),
dtype=torch.float32,
inplace=True,
sample_input=unet_input_example,
level="O1",
weights_prepack=True,
auto_kernel_selection=False,
)
self.vae.decoder = ipex.optimize(
self.vae.decoder.eval(),
dtype=torch.float32,
inplace=True,
level="O1",
weights_prepack=True,
auto_kernel_selection=False,
)
self.text_encoder = ipex.optimize(
self.text_encoder.eval(),
dtype=torch.float32,
inplace=True,
level="O1",
weights_prepack=True,
auto_kernel_selection=False,
)
if self.safety_checker != None:
self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.float32, inplace=True, level="O1", weights_prepack=True, auto_kernel_selection=False)
self.safety_checker = ipex.optimize(
self.safety_checker.eval(),
dtype=torch.float32,
inplace=True,
level="O1",
weights_prepack=True,
auto_kernel_selection=False,
)
else:
raise ValueError(
f" The value of infer_type should be 'bf16' or 'fp32' !"
)
raise ValueError(f" The value of infer_type should be 'bf16' or 'fp32' !")

# trace unet model to get better performance on IPEX
with torch.cpu.amp.autocast(enabled=infer_type=='bf16'), torch.no_grad():
with torch.cpu.amp.autocast(enabled=infer_type == "bf16"), torch.no_grad():
unet_trace_model = torch.jit.trace(self.unet, unet_input_example, check_trace=False, strict=False)
unet_trace_model = torch.jit.freeze(unet_trace_model)
self.unet.forward = unet_trace_model.forward

# trace vae.decoder model to get better performance on IPEX
with torch.cpu.amp.autocast(enabled=infer_type=='bf16'), torch.no_grad():
ave_decoder_trace_model = torch.jit.trace(self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False)
with torch.cpu.amp.autocast(enabled=infer_type == "bf16"), torch.no_grad():
ave_decoder_trace_model = torch.jit.trace(
self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False
)
ave_decoder_trace_model = torch.jit.freeze(ave_decoder_trace_model)
self.vae.decoder.forward = ave_decoder_trace_model.forward




def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
@@ -459,7 +480,6 @@ def _encode_prompt(
)
prompt_embeds = prompt_embeds[0]


prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
@@ -767,7 +787,7 @@ def __call__(

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -777,11 +797,7 @@ def __call__(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds
)['sample']
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)["sample"]

# perform guidance
if do_classifier_free_guidance:
@@ -811,7 +827,7 @@ def __call__(
# 10. Convert to PIL
image = self.numpy_to_pil(image)
else:

# 8. Post-processing
image = self.decode_latents(latents)

@@ -826,4 +842,3 @@ def __call__(
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
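
A hedged usage sketch for this community pipeline follows; the model id, prompt, and image size are illustrative, and loading via `custom_pipeline` is the standard diffusers mechanism rather than anything introduced by this commit.

```python
# Illustrative only: load the community pipeline, prepare it for IPEX, then run bf16 inference.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",          # placeholder model id
    custom_pipeline="stable_diffusion_ipex",
)

prompt = "a photo of an astronaut riding a horse on mars"  # placeholder prompt

# One-time conversion: optimizes and traces unet / vae.decoder / text_encoder for CPU.
pipe.prepare_for_ipex(prompt, infer_type="bf16", height=512, width=512)

with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0]
image.save("astronaut.png")
```

Because prepare_for_ipex traces unet and vae.decoder against inputs built for the requested height and width, later calls should keep the same resolution so the frozen graphs are reused.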