2323import numpy as np
2424import PIL
2525from flax .core .frozen_dict import FrozenDict
26- from huggingface_hub import snapshot_download
26+ from huggingface_hub import create_repo , snapshot_download
2727from PIL import Image
2828from tqdm .auto import tqdm
2929
3030from ..configuration_utils import ConfigMixin
3131from ..models .modeling_flax_utils import FLAX_WEIGHTS_NAME , FlaxModelMixin
3232from ..schedulers .scheduling_utils_flax import SCHEDULER_CONFIG_NAME , FlaxSchedulerMixin
33- from ..utils import CONFIG_NAME , DIFFUSERS_CACHE , BaseOutput , http_user_agent , is_transformers_available , logging
33+ from ..utils import (
34+ CONFIG_NAME ,
35+ DIFFUSERS_CACHE ,
36+ BaseOutput ,
37+ PushToHubMixin ,
38+ http_user_agent ,
39+ is_transformers_available ,
40+ logging ,
41+ )
3442
3543
3644if is_transformers_available ():
@@ -90,7 +98,7 @@ class FlaxImagePipelineOutput(BaseOutput):
9098 images : Union [List [PIL .Image .Image ], np .ndarray ]
9199
92100
93- class FlaxDiffusionPipeline (ConfigMixin ):
101+ class FlaxDiffusionPipeline (ConfigMixin , PushToHubMixin ):
94102 r"""
95103 Base class for Flax-based pipelines.
96104
@@ -139,7 +147,13 @@ def register_modules(self, **kwargs):
139147 # set models
140148 setattr (self , name , module )
141149
142- def save_pretrained (self , save_directory : Union [str , os .PathLike ], params : Union [Dict , FrozenDict ]):
150+ def save_pretrained (
151+ self ,
152+ save_directory : Union [str , os .PathLike ],
153+ params : Union [Dict , FrozenDict ],
154+ push_to_hub : bool = False ,
155+ ** kwargs ,
156+ ):
143157 # TODO: handle inference_state
144158 """
145159 Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
@@ -149,6 +163,12 @@ class implements both a save and loading method. The pipeline is easily reloaded
149163 Arguments:
150164 save_directory (`str` or `os.PathLike`):
151165 Directory to which to save. Will be created if it doesn't exist.
166+ push_to_hub (`bool`, *optional*, defaults to `False`):
167+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
168+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
169+ namespace).
170+ kwargs (`Dict[str, Any]`, *optional*):
171+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
152172 """
153173 self .save_config (save_directory )
154174
@@ -157,6 +177,14 @@ class implements both a save and loading method. The pipeline is easily reloaded
157177 model_index_dict .pop ("_diffusers_version" )
158178 model_index_dict .pop ("_module" , None )
159179
180+ if push_to_hub :
181+ commit_message = kwargs .pop ("commit_message" , None )
182+ private = kwargs .pop ("private" , False )
183+ create_pr = kwargs .pop ("create_pr" , False )
184+ token = kwargs .pop ("token" , None )
185+ repo_id = kwargs .pop ("repo_id" , save_directory .split (os .path .sep )[- 1 ])
186+ repo_id = create_repo (repo_id , exist_ok = True , private = private , token = token ).repo_id
187+
160188 for pipeline_component_name in model_index_dict .keys ():
161189 sub_model = getattr (self , pipeline_component_name )
162190 if sub_model is None :
@@ -188,6 +216,15 @@ class implements both a save and loading method. The pipeline is easily reloaded
188216 else :
189217 save_method (os .path .join (save_directory , pipeline_component_name ))
190218
219+ if push_to_hub :
220+ self ._upload_folder (
221+ save_directory ,
222+ repo_id ,
223+ token = token ,
224+ commit_message = commit_message ,
225+ create_pr = create_pr ,
226+ )
227+
191228 @classmethod
192229 def from_pretrained (cls , pretrained_model_name_or_path : Optional [Union [str , os .PathLike ]], ** kwargs ):
193230 r"""
0 commit comments