Commit 15782fd

sayakpaul, patrickvonplaten, Wauplin, and ydshieh authored
[Pipeline utils] feat: implement push_to_hub for standalone models, schedulers as well as pipelines (huggingface#4128)
* feat: implement push_to_hub for standalone models.
* address PR feedback.
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
* remove max_shard_size.
* add: support for scheduler push_to_hub
* enable push_to_hub support for flax schedulers.
* enable push_to_hub for pipelines.
* Apply suggestions from code review
  Co-authored-by: Lucain <[email protected]>
* reflect pr feedback.
* address another round of feedback.
* better handling of kwargs.
* add: tests
* Apply suggestions from code review
  Co-authored-by: Lucain <[email protected]>
* setting hub staging to False for now.
* incorporate staging test as a separate job.
  Co-authored-by: ydshieh <[email protected]>
* fix: tokenizer loading.
* fix: json dumping.
* move is_staging_test to a better location.
* better treatment to tokens.
* define repo_id to better handle concurrency
* style
* explicitly set token
* Empty-Commit
* move USER, TOKEN to test
* collate org_repo_id
* delete repo

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Lucain <[email protected]>
Co-authored-by: ydshieh <[email protected]>
1 parent d93ca26 commit 15782fd

15 files changed, with 647 additions and 20 deletions

.github/workflows/pr_tests.yml

Lines changed: 57 additions & 0 deletions
@@ -113,3 +113,60 @@ jobs:
       with:
         name: pr_${{ matrix.config.report }}_test_reports
         path: reports
+
+  run_staging_tests:
+    strategy:
+      fail-fast: false
+      matrix:
+        config:
+          - name: Hub tests for models, schedulers, and pipelines
+            framework: hub_tests_pytorch
+            runner: docker-cpu
+            image: diffusers/diffusers-pytorch-cpu
+            report: torch_hub
+
+    name: ${{ matrix.config.name }}
+
+    runs-on: ${{ matrix.config.runner }}
+
+    container:
+      image: ${{ matrix.config.image }}
+      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+    defaults:
+      run:
+        shell: bash
+
+    steps:
+    - name: Checkout diffusers
+      uses: actions/checkout@v3
+      with:
+        fetch-depth: 2
+
+    - name: Install dependencies
+      run: |
+        apt-get update && apt-get install libsndfile1-dev libgl1 -y
+        python -m pip install -e .[quality,test]
+
+    - name: Environment
+      run: |
+        python utils/print_env.py
+
+    - name: Run Hub tests for models, schedulers, and pipelines on a staging env
+      if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
+      run: |
+        HUGGINGFACE_CO_STAGING=true python -m pytest \
+          -m "is_staging_test" \
+          --make-reports=tests_${{ matrix.config.report }} \
+          tests
+
+    - name: Failure short reports
+      if: ${{ failure() }}
+      run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+    - name: Test suite reports artifacts
+      if: ${{ always() }}
+      uses: actions/upload-artifact@v2
+      with:
+        name: pr_${{ matrix.config.report }}_test_reports
+        path: reports
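
The staging job relies on two switches: the `HUGGINGFACE_CO_STAGING=true` environment variable and the `is_staging_test` pytest marker selected with `-m`. As a rough illustration of the first switch only, the sketch below parses such a flag and picks an endpoint from it. The helpers `env_flag` and `pick_endpoint` and the staging URL are hypothetical placeholders, not the actual `huggingface_hub` implementation.

```python
import os

PRODUCTION_ENDPOINT = "https://huggingface.co"
# Placeholder only (assumption); the real staging endpoint is defined by huggingface_hub.
STAGING_ENDPOINT = "https://staging.example.invalid"


def env_flag(name, environ=None):
    """Return True when the variable is set to a common truthy spelling."""
    environ = os.environ if environ is None else environ
    return environ.get(name, "").strip().upper() in {"1", "ON", "YES", "TRUE"}


def pick_endpoint(environ=None):
    """Choose the staging endpoint only when the staging flag is truthy."""
    if env_flag("HUGGINGFACE_CO_STAGING", environ):
        return STAGING_ENDPOINT
    return PRODUCTION_ENDPOINT


print(pick_endpoint({"HUGGINGFACE_CO_STAGING": "true"}))  # https://staging.example.invalid
print(pick_endpoint({}))                                  # https://huggingface.co
```

Passing the environment as a plain dictionary keeps the sketch easy to exercise without touching the real process environment.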

docs/source/en/api/models/overview.md

Lines changed: 5 additions & 1 deletion
@@ -9,4 +9,8 @@ All models are built from the base [`ModelMixin`] class which is a [`torch.nn.mo
 
 ## FlaxModelMixin
 
-[[autodoc]] FlaxModelMixin
+[[autodoc]] FlaxModelMixin
+
+## Pushing to the Hub
+
+[[autodoc]] utils.PushToHubMixin

src/diffusers/configuration_utils.py

Lines changed: 23 additions & 1 deletion
@@ -26,7 +26,7 @@
 from typing import Any, Dict, Tuple, Union
 
 import numpy as np
-from huggingface_hub import hf_hub_download
+from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
@@ -144,6 +144,12 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
         Args:
             save_directory (`str` or `os.PathLike`):
                 Directory where the configuration JSON file is saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -156,6 +162,22 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
         self.to_json_file(output_config_file)
         logger.info(f"Configuration saved in {output_config_file}")
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
         r"""
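
The default `repo_id` in `save_config` comes from `save_directory.split(os.path.sep)[-1]`. The sketch below isolates that expression to show what it yields for a few inputs. Both helper names are hypothetical and not part of diffusers; `safer_default_repo_id` is only one possible way to tolerate a trailing separator or an `os.PathLike` argument, not something this commit does.

```python
import os
from pathlib import Path


def default_repo_id(save_directory: str) -> str:
    """Mirror the expression used in the diff: the last path component."""
    return save_directory.split(os.path.sep)[-1]


def safer_default_repo_id(save_directory) -> str:
    """Hypothetical variant: accept os.PathLike and ignore a trailing separator."""
    path = os.fspath(save_directory)
    return os.path.basename(os.path.normpath(path))


plain = os.path.join("outputs", "my-unet")
trailing = plain + os.path.sep

print(default_repo_id(plain))              # my-unet
print(repr(default_repo_id(trailing)))     # '' (nothing follows the last separator)
print(safer_default_repo_id(trailing))     # my-unet
print(safer_default_repo_id(Path(plain)))  # my-unet
```

With the expression from the diff, a trailing separator leaves an empty default, so passing `repo_id` explicitly avoids depending on how the directory path was written.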

src/diffusers/models/modeling_flax_utils.py

Lines changed: 28 additions & 2 deletions
@@ -23,7 +23,7 @@
 from flax.core.frozen_dict import FrozenDict, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
-from huggingface_hub import hf_hub_download
+from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
@@ -34,6 +34,7 @@
     FLAX_WEIGHTS_NAME,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     WEIGHTS_NAME,
+    PushToHubMixin,
     logging,
 )
 from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
@@ -42,7 +43,7 @@
 logger = logging.get_logger(__name__)
 
 
-class FlaxModelMixin:
+class FlaxModelMixin(PushToHubMixin):
     r"""
     Base class for all Flax models.
 
@@ -497,6 +498,8 @@ def save_pretrained(
         save_directory: Union[str, os.PathLike],
         params: Union[Dict, FrozenDict],
         is_main_process: bool = True,
+        push_to_hub: bool = False,
+        **kwargs,
     ):
         """
         Save a model and its configuration file to a directory so that it can be reloaded using the
@@ -511,13 +514,27 @@ def save_pretrained(
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                 process to avoid race conditions.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
         os.makedirs(save_directory, exist_ok=True)
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
         model_to_save = self
 
         # Attach architecture to the config
@@ -532,3 +549,12 @@ def save_pretrained(
             f.write(model_bytes)
 
         logger.info(f"Model weights saved in {output_model_file}")
+
+        if push_to_hub:
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )

src/diffusers/models/modeling_utils.py

Lines changed: 29 additions & 1 deletion
@@ -23,6 +23,7 @@
 
 import safetensors
 import torch
+from huggingface_hub import create_repo
 from torch import Tensor, device, nn
 
 from .. import __version__
@@ -40,6 +41,7 @@
     is_torch_version,
     logging,
 )
+from ..utils.hub_utils import PushToHubMixin
 
 
 logger = logging.get_logger(__name__)
@@ -147,7 +149,7 @@ def load(module: torch.nn.Module, prefix=""):
     return error_msgs
 
 
-class ModelMixin(torch.nn.Module):
+class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.
 
@@ -272,6 +274,8 @@ def save_pretrained(
         save_function: Callable = None,
         safe_serialization: bool = False,
         variant: Optional[str] = None,
+        push_to_hub: bool = False,
+        **kwargs,
     ):
         """
         Save a model and its configuration file to a directory so that it can be reloaded using the
@@ -292,13 +296,28 @@ def save_pretrained(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
         os.makedirs(save_directory, exist_ok=True)
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+        # Only save the model itself if we are using distributed training
         model_to_save = self
 
         # Attach architecture to the config
@@ -322,6 +341,15 @@ def save_pretrained(
 
         logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
 
+        if push_to_hub:
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
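
In `ModelMixin.save_pretrained` the Hub options are popped from `kwargs` and the repository is created before any weights are written, while the upload happens only after the files are saved. The sketch below reproduces that ordering with stand-ins so it runs offline. `FakeHub`, `save_and_push`, and the `weights.bin` file name are hypothetical; they merely take the place of `create_repo`, `_upload_folder`, and the real serialization.

```python
import os
import tempfile


class FakeHub:
    """Offline stand-in for the Hub calls; it only records what was requested."""

    def __init__(self):
        self.calls = []

    def create_repo(self, repo_id, private=False, token=None):
        self.calls.append(("create_repo", repo_id, private))
        return repo_id

    def upload_folder(self, folder, repo_id, commit_message=None, create_pr=False):
        self.calls.append(("upload_folder", repo_id, sorted(os.listdir(folder))))


def save_and_push(hub, save_directory, push_to_hub=False, **kwargs):
    os.makedirs(save_directory, exist_ok=True)

    if push_to_hub:
        # Same pops and defaults as in the diff.
        commit_message = kwargs.pop("commit_message", None)
        private = kwargs.pop("private", False)
        create_pr = kwargs.pop("create_pr", False)
        token = kwargs.pop("token", None)
        repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
        # The repository is created first, before anything is written to disk.
        repo_id = hub.create_repo(repo_id, private=private, token=token)

    # Placeholder for writing the config and weights.
    with open(os.path.join(save_directory, "weights.bin"), "wb") as f:
        f.write(b"\x00")

    if push_to_hub:
        # The upload runs last, so it sees every file saved above.
        hub.upload_folder(save_directory, repo_id, commit_message=commit_message, create_pr=create_pr)


hub = FakeHub()
with tempfile.TemporaryDirectory() as tmp:
    save_and_push(hub, os.path.join(tmp, "my-unet"), push_to_hub=True, private=True)

for call in hub.calls:
    print(call)
```

One consequence of this ordering is that a problem with the repository name or token surfaces before time is spent serializing weights.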

src/diffusers/pipelines/pipeline_flax_utils.py

Lines changed: 41 additions & 4 deletions
@@ -23,14 +23,22 @@
 import numpy as np
 import PIL
 from flax.core.frozen_dict import FrozenDict
-from huggingface_hub import snapshot_download
+from huggingface_hub import create_repo, snapshot_download
 from PIL import Image
 from tqdm.auto import tqdm
 
 from ..configuration_utils import ConfigMixin
 from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
 from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
-from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging
+from ..utils import (
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    BaseOutput,
+    PushToHubMixin,
+    http_user_agent,
+    is_transformers_available,
+    logging,
+)
 
 
 if is_transformers_available():
@@ -90,7 +98,7 @@ class FlaxImagePipelineOutput(BaseOutput):
     images: Union[List[PIL.Image.Image], np.ndarray]
 
 
-class FlaxDiffusionPipeline(ConfigMixin):
+class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
     r"""
     Base class for Flax-based pipelines.
 
@@ -139,7 +147,13 @@ def register_modules(self, **kwargs):
             # set models
             setattr(self, name, module)
 
-    def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        params: Union[Dict, FrozenDict],
+        push_to_hub: bool = False,
+        **kwargs,
+    ):
         # TODO: handle inference_state
         """
         Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
@@ -149,6 +163,12 @@ class implements both a save and loading method. The pipeline is easily reloaded
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         self.save_config(save_directory)
 
@@ -157,6 +177,14 @@ class implements both a save and loading method. The pipeline is easily reloaded
         model_index_dict.pop("_diffusers_version")
         model_index_dict.pop("_module", None)
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
         for pipeline_component_name in model_index_dict.keys():
             sub_model = getattr(self, pipeline_component_name)
             if sub_model is None:
@@ -188,6 +216,15 @@ class implements both a save and loading method. The pipeline is easily reloaded
             else:
                 save_method(os.path.join(save_directory, pipeline_component_name))
 
+        if push_to_hub:
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
