
Commit 45f6d52

yiyixuxu, patrickvonplaten, sayakpaul, and pcuenca authored
Add Shap-E (huggingface#3742)
* refactor prior_transformer adding conversion script add pipeline add step_index from pipeline, + remove permute add zero pad token remove copy from statement for betas_for_alpha_bar function * add * add * update conversion script for renderer model * refactor camera a little bit * clean up * style * fix copies * Update src/diffusers/schedulers/scheduling_heun_discrete.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * alpha_transform_type * remove step_index argument * remove get_sigmas_karras * remove _yiyi_sigma_to_t * move the rescale prompt_embeds from prior_transformer to pipeline * replace baddbmm with einsum to match origial repo * Revert "replace baddbmm with einsum to match origial repo" This reverts commit 3f6b435. * add step_index to scale_model_input * Revert "move the rescale prompt_embeds from prior_transformer to pipeline" This reverts commit 5b5a8e6. 
* move rescale from prior_transformer to pipeline * correct step_index in scale_model_input * remove print lines * refactor prior - reduce arguments * make style * add prior_image * arg embedding_proj_norm -> norm_embedding_proj * add pre-norm for proj_embedding * move rescale prompt from pipeline to _encode_prompt * add img2img pipeline * style * copies * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add arg: encoder_hid_proj Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add new config: norm_in_type Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py add new config: added_emb_type Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py rename out_dim -> clip_embed_dim Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py rename config: out_dim -> clip_embed_dim Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <[email protected]> * finish refactor prior_tranformer * make style * refactor renderer * fix * make style * refactor img2img * remove params_proj * add test * add upcast_softmax to prior_transformer * enable num_images_per_prompt, add save_gif utility * add * add fast test * make style * add slow test * style * add test for img2img * refactor * enable batching * style * refactor scheduler * update test * style * attempt to solve batch 
related tests timeout * add doc * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py Co-authored-by: Patrick von Platen <[email protected]> * hardcode rendering related config * update betas_for_alpha_bar on ddpm_scheduler * fix copies * fix * export_to_gif * style * second attempt to speed up batching tests * add doc page to index * Remove intermediate clipping * 3rd attempt to speed up batching tests * Remvoe time index * simplify scheduler * Fix more * Fix more * fix more * make style * fix schedulers * fix some more tests * finish * add one more test * Apply suggestions from code review Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> * style * apply feedbacks * style * fix copies * add one example * style * add example for img2img * fix doc * fix more doc strings * size -> frame_size * style * update doc * style * fix on doc * update repo name * improve the usage example in shap-e img2img * add usage examples in the shap-e docs. * consolidate examples. * minor fix. * update doc * Apply suggestions from code review * Apply suggestions from code review * remove upcast * Make sure background is white * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py * Apply suggestions from code review * Finish * Apply suggestions from code review * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py * Make style --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 7462156 commit 45f6d52

37 files changed: +3534 -116 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -226,6 +226,8 @@
       title: Self-Attention Guidance
     - local: api/pipelines/semantic_stable_diffusion
       title: Semantic Guidance
+    - local: api/pipelines/shap_e
+      title: Shap-E
     - local: api/pipelines/spectrogram_diffusion
       title: Spectrogram Diffusion
   - sections:
Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Shap-E

## Overview

The Shap-E model was proposed in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463) by Alex Nichol and Heewoo Jun from [OpenAI](https://github.com/openai).

The abstract of the paper is the following:

*We present Shap-E, a conditional generative model for 3D assets. Unlike recent work on 3D generative models which produce a single output representation, Shap-E directly generates the parameters of implicit functions that can be rendered as both textured meshes and neural radiance fields. We train Shap-E in two stages: first, we train an encoder that deterministically maps 3D assets into the parameters of an implicit function; second, we train a conditional diffusion model on outputs of the encoder. When trained on a large dataset of paired 3D and text data, our resulting models are capable of generating complex and diverse 3D assets in a matter of seconds. When compared to Point-E, an explicit generative model over point clouds, Shap-E converges faster and reaches comparable or better sample quality despite modeling a higher-dimensional, multi-representation output space.*

The original codebase can be found [here](https://github.com/openai/shap-e).

## Available Pipelines

| Pipeline | Tasks |
|---|---|
| [pipeline_shap_e.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e.py) | *Text-to-3D Generation* |
| [pipeline_shap_e_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py) | *Image-to-3D Generation* |

## Available checkpoints

* [`openai/shap-e`](https://huggingface.co/openai/shap-e)
* [`openai/shap-e-img2img`](https://huggingface.co/openai/shap-e-img2img)

## Usage Examples

The following examples show how to use the Shap-E pipelines to create 3D objects and export them as GIFs.

### Text-to-3D image generation

We can use [`ShapEPipeline`] to create a 3D object from a text prompt. In this example, we will make a birthday cupcake for the 🧨 Diffusers library's first birthday. The Shap-E text-to-3D pipeline is used the same way as the other text-to-image pipelines in diffusers.

```python
import torch

from diffusers import DiffusionPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is intended for GPU inference; fall back to float32 on CPU.
dtype = torch.float16 if device.type == "cuda" else torch.float32

repo = "openai/shap-e"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=dtype)
pipe = pipe.to(device)

guidance_scale = 15.0
# Passing a list of prompts generates one 3D object per prompt.
prompt = ["A firecracker", "A birthday cupcake"]

images = pipe(
    prompt,
    guidance_scale=guidance_scale,
    num_inference_steps=64,
    frame_size=256,
).images
```

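The call above passes `guidance_scale=15.0` without saying what the value controls. In diffusers pipelines this argument is normally the classifier-free guidance weight, which pushes each denoising prediction away from the unconditional output and toward the prompt-conditioned one. The sketch below illustrates that arithmetic on plain Python lists, assuming Shap-E follows the standard formulation; `apply_guidance` is a hypothetical helper for illustration, not a diffusers function.

```python
from typing import List, Sequence


def apply_guidance(
    uncond: Sequence[float], cond: Sequence[float], guidance_scale: float
) -> List[float]:
    """Combine unconditional and conditional predictions element-wise.

    Assumes the standard classifier-free guidance rule:
        guided = uncond + guidance_scale * (cond - uncond)
    """
    if len(uncond) != len(cond):
        raise ValueError("uncond and cond must have the same length")
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]


# A scale of 1.0 reproduces the conditional prediction unchanged;
# larger scales extrapolate further in the direction of the condition.
print(apply_guidance([0.0, 1.0], [1.0, 3.0], 1.0))
print(apply_guidance([0.0, 1.0], [1.0, 3.0], 3.0))
```

Under this assumption, higher values follow the prompt more closely, usually at some cost in sample diversity.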
The output of [`ShapEPipeline`] is a list of lists of image frames, one inner list per generated object. Each inner list holds the frames for a single 3D object and can be assembled into an animation. Let's use the `export_to_gif` utility function in diffusers to make a 3D cupcake!

```python
from diffusers.utils import export_to_gif

export_to_gif(images[0], "firecracker_3d.gif")
export_to_gif(images[1], "cake_3d.gif")
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/firecracker_out.gif)
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/cake_out.gif)

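With more than a couple of prompts, indexing `images` by hand gets tedious. The sketch below pairs each prompt with its list of frames and derives a file name from the prompt text. It assumes one output per prompt (that is, `num_images_per_prompt` is left at 1). `gif_name` and `export_all` are hypothetical helpers, and the export function is passed in so the sketch runs without loading a model.

```python
import re
from typing import Callable, List, Sequence


def gif_name(prompt: str) -> str:
    """Turn a prompt into a file name, e.g. 'A firecracker' -> 'a_firecracker_3d.gif'."""
    slug = re.sub(r"[^a-z0-9]+", "_", prompt.lower()).strip("_")
    return f"{slug or 'object'}_3d.gif"


def export_all(
    prompts: Sequence[str],
    images: Sequence[Sequence[object]],
    export: Callable[[Sequence[object], str], object],
) -> List[str]:
    """Call export(frames, path) once per prompt and return the paths used."""
    if len(prompts) != len(images):
        raise ValueError("expected exactly one list of frames per prompt")
    paths = []
    for prompt, frames in zip(prompts, images):
        path = gif_name(prompt)
        export(frames, path)
        paths.append(path)
    return paths
```

With the variables from the example above, this would be called as `export_all(prompt, images, export_to_gif)`.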
### Image-to-Image generation

You can use [`ShapEImg2ImgPipeline`] along with other text-to-image pipelines in diffusers and turn your 2D generation into 3D.

In this example, we will first generate a cheeseburger with a simple prompt, "A cheeseburger, white background":

```python
import torch

from diffusers import DiffusionPipeline

# Kandinsky 2.1 runs in two stages: the prior maps the prompt to image embeddings...
pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16)
pipe_prior.to("cuda")

# ...and the text-to-image pipeline decodes those embeddings into an image.
t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
t2i_pipe.to("cuda")

prompt = "A cheeseburger, white background"

image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
image = t2i_pipe(
    prompt,
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
).images[0]

# Save the 2D result so the Shap-E image-to-image pipeline can load it.
image.save("burger.png")
```

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_in.png)

We will then use the Shap-E image-to-image pipeline to turn it into a 3D cheeseburger :)

```python
import torch
from PIL import Image

from diffusers import DiffusionPipeline
from diffusers.utils import export_to_gif

repo = "openai/shap-e-img2img"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

guidance_scale = 3.0
# Load the 2D image generated in the previous step.
image = Image.open("burger.png").resize((256, 256))

images = pipe(
    image,
    guidance_scale=guidance_scale,
    num_inference_steps=64,
    frame_size=256,
).images

# images[0] holds the frames for the single input image.
gif_path = export_to_gif(images[0], "burger_3d.gif")
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_out.gif)

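The example above resizes `burger.png` straight to 256x256, which is harmless for a square input but would stretch a non-square one. The sketch below computes the resized dimensions and the paste offset needed to center an image on a square canvas while keeping its aspect ratio. `fit_within_square` is a hypothetical helper, and padding the border (for instance with white, to match the prompt's white background) is an assumption, not something the pipeline requires.

```python
from typing import Tuple


def fit_within_square(
    width: int, height: int, side: int = 256
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return (new_size, paste_offset) for centering an image on a side x side canvas."""
    if width <= 0 or height <= 0 or side <= 0:
        raise ValueError("width, height and side must be positive")
    longest = max(width, height)
    # Scale so the longer edge equals `side`; keep at least one pixel on the short edge.
    new_w = max(1, round(width * side / longest))
    new_h = max(1, round(height * side / longest))
    offset = ((side - new_w) // 2, (side - new_h) // 2)
    return (new_w, new_h), offset


# A landscape 800x400 image becomes 256x128, pasted 64 pixels from the top.
print(fit_within_square(800, 400))
```

With Pillow, the result could be applied by creating a canvas with `Image.new("RGB", (256, 256), "white")`, resizing the input with `image.resize(new_size)`, and placing it with `canvas.paste(resized, offset)`.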
## ShapEPipeline
[[autodoc]] ShapEPipeline
	- all
	- __call__

## ShapEImg2ImgPipeline
[[autodoc]] ShapEImg2ImgPipeline
	- all
	- __call__
