Skip to content

Commit 970e306

Browse files
committed
Revert "[v0.4.0] Temporarily remove Flax modules from the public API (huggingface#755)"
This reverts commit 2e209c3.
1 parent c15cda0 commit 970e306

File tree

9 files changed

+100
-6
lines changed

9 files changed

+100
-6
lines changed

docs/source/api/models.mdx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
4545

4646
## AutoencoderKL
4747
[[autodoc]] AutoencoderKL
48+
49+
## FlaxModelMixin
50+
[[autodoc]] FlaxModelMixin
51+
52+
## FlaxUNet2DConditionOutput
53+
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
54+
55+
## FlaxUNet2DConditionModel
56+
[[autodoc]] FlaxUNet2DConditionModel
57+
58+
## FlaxDecoderOutput
59+
[[autodoc]] models.vae_flax.FlaxDecoderOutput
60+
61+
## FlaxAutoencoderKLOutput
62+
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
63+
64+
## FlaxAutoencoderKL
65+
[[autodoc]] FlaxAutoencoderKL

docs/source/api/schedulers.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
3636
To this end, the design of schedulers is such that:
3737

3838
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
39-
- Schedulers are currently by default in PyTorch.
39+
- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
4040

4141

4242
## API

setup.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,13 @@
8484
"datasets",
8585
"filelock",
8686
"flake8>=3.8.3",
87+
"flax>=0.4.1",
8788
"hf-doc-builder>=0.3.0",
8889
"huggingface-hub>=0.10.0",
8990
"importlib_metadata",
9091
"isort>=5.5.4",
92+
"jax>=0.2.8,!=0.3.2,<=0.3.6",
93+
"jaxlib>=0.1.65,<=0.3.6",
9194
"modelcards>=0.1.4",
9295
"numpy",
9396
"onnxruntime",
@@ -185,9 +188,15 @@ def run(self):
185188
"torchvision",
186189
"transformers"
187190
)
191+
extras["torch"] = deps_list("torch")
192+
193+
if os.name == "nt": # windows
194+
extras["flax"] = [] # jax is not supported on windows
195+
else:
196+
extras["flax"] = deps_list("jax", "jaxlib", "flax")
188197

189198
extras["dev"] = (
190-
extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
199+
extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
191200
)
192201

193202
install_requires = [
@@ -198,7 +207,6 @@ def run(self):
198207
deps["regex"],
199208
deps["requests"],
200209
deps["Pillow"],
201-
deps["torch"]
202210
]
203211

204212
setup(

src/diffusers/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .utils import (
2+
is_flax_available,
23
is_inflect_available,
34
is_onnx_available,
45
is_scipy_available,
@@ -60,3 +61,25 @@
6061
from .pipelines import StableDiffusionOnnxPipeline
6162
else:
6263
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
64+
65+
if is_flax_available():
66+
from .modeling_flax_utils import FlaxModelMixin
67+
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
68+
from .models.vae_flax import FlaxAutoencoderKL
69+
from .pipeline_flax_utils import FlaxDiffusionPipeline
70+
from .schedulers import (
71+
FlaxDDIMScheduler,
72+
FlaxDDPMScheduler,
73+
FlaxKarrasVeScheduler,
74+
FlaxLMSDiscreteScheduler,
75+
FlaxPNDMScheduler,
76+
FlaxSchedulerMixin,
77+
FlaxScoreSdeVeScheduler,
78+
)
79+
else:
80+
from .utils.dummy_flax_objects import * # noqa F403
81+
82+
if is_flax_available() and is_transformers_available():
83+
from .pipelines import FlaxStableDiffusionPipeline
84+
else:
85+
from .utils.dummy_flax_and_transformers_objects import * # noqa F403

src/diffusers/dependency_versions_table.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
"datasets": "datasets",
99
"filelock": "filelock",
1010
"flake8": "flake8>=3.8.3",
11+
"flax": "flax>=0.4.1",
1112
"hf-doc-builder": "hf-doc-builder>=0.3.0",
1213
"huggingface-hub": "huggingface-hub>=0.10.0",
1314
"importlib_metadata": "importlib_metadata",
1415
"isort": "isort>=5.5.4",
16+
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
17+
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
1518
"modelcards": "modelcards>=0.1.4",
1619
"numpy": "numpy",
1720
"onnxruntime": "onnxruntime",

src/diffusers/models/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from ..utils import is_torch_available
15+
from ..utils import is_flax_available, is_torch_available
1616

1717

1818
if is_torch_available():
1919
from .unet_2d import UNet2DModel
2020
from .unet_2d_condition import UNet2DConditionModel
2121
from .vae import AutoencoderKL, VQModel
22+
23+
if is_flax_available():
24+
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
25+
from .vae_flax import FlaxAutoencoderKL

src/diffusers/pipelines/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,6 @@
2121

2222
if is_transformers_available() and is_onnx_available():
2323
from .stable_diffusion import StableDiffusionOnnxPipeline
24+
25+
if is_transformers_available() and is_flax_available():
26+
from .stable_diffusion import FlaxStableDiffusionPipeline

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import PIL
77
from PIL import Image
88

9-
from ...utils import BaseOutput, is_onnx_available, is_torch_available, is_transformers_available
9+
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
1010

1111

1212
@dataclass
@@ -35,3 +35,27 @@ class StableDiffusionPipelineOutput(BaseOutput):
3535

3636
if is_transformers_available() and is_onnx_available():
3737
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
38+
39+
if is_transformers_available() and is_flax_available():
40+
import flax
41+
42+
@flax.struct.dataclass
43+
class FlaxStableDiffusionPipelineOutput(BaseOutput):
44+
"""
45+
Output class for Stable Diffusion pipelines.
46+
47+
Args:
48+
images (`List[PIL.Image.Image]` or `np.ndarray`)
49+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
50+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
51+
nsfw_content_detected (`List[bool]`)
52+
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
53+
(nsfw) content.
54+
"""
55+
56+
images: Union[List[PIL.Image.Image], np.ndarray]
57+
nsfw_content_detected: List[bool]
58+
59+
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
60+
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
61+
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

src/diffusers/schedulers/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515

16-
from ..utils import is_scipy_available, is_torch_available
16+
from ..utils import is_flax_available, is_scipy_available, is_torch_available
1717

1818

1919
if is_torch_available():
@@ -27,6 +27,17 @@
2727
else:
2828
from ..utils.dummy_pt_objects import * # noqa F403
2929

30+
if is_flax_available():
31+
from .scheduling_ddim_flax import FlaxDDIMScheduler
32+
from .scheduling_ddpm_flax import FlaxDDPMScheduler
33+
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
34+
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
35+
from .scheduling_pndm_flax import FlaxPNDMScheduler
36+
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
37+
from .scheduling_utils_flax import FlaxSchedulerMixin
38+
else:
39+
from ..utils.dummy_flax_objects import * # noqa F403
40+
3041

3142
if is_scipy_available() and is_torch_available():
3243
from .scheduling_lms_discrete import LMSDiscreteScheduler

0 commit comments

Comments
 (0)