Commit fbcc383

Deprecate init_git_repo, refactor train_unconditional.py (huggingface#1022)
Deprecate `init_git_repo` and `push_to_hub`, refactor `train_unconditional.py`
1 parent 90f91ad commit fbcc383

File tree: 2 files changed, +186 -55 lines

examples/unconditional_image_generation/train_unconditional.py
src/diffusers/hub_utils.py
examples/unconditional_image_generation/train_unconditional.py

Lines changed: 174 additions & 54 deletions
@@ -1,6 +1,8 @@
 import argparse
 import math
 import os
+from pathlib import Path
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -9,9 +11,9 @@
 from accelerate.logging import get_logger
 from datasets import load_dataset
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.hub_utils import init_git_repo
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
+from huggingface_hub import HfFolder, Repository, whoami
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -27,6 +29,160 @@
 logger = get_logger(__name__)
 
 
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that HF Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. Folder contents must follow the structure described in"
+            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="ddpm-model-64",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument("--overwrite_output_dir", action="store_true")
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=64,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
+    )
+    parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+    parser.add_argument(
+        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="cosine",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument(
+        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+    )
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use Exponential Moving Average for the final model weights.",
+    )
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose"
+            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
+            " and an Nvidia Ampere GPU."
+        ),
+    )
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+    return args
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+    if token is None:
+        token = HfFolder.get_token()
+    if organization is None:
+        username = whoami(token)["name"]
+        return f"{username}/{model_id}"
+    else:
+        return f"{organization}/{model_id}"
+
+
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
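For context on the new helper: `get_full_repo_name` resolves a bare model id to a namespaced Hub repo id, falling back to the token cached by `huggingface-cli login` when none is passed. A minimal self-contained sketch of that behavior (the model and organization names below are placeholders, not part of this commit):

from typing import Optional

from huggingface_hub import HfFolder, whoami


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    # Same logic as the helper added above: prefix the model id with a namespace.
    if token is None:
        token = HfFolder.get_token()  # token cached by `huggingface-cli login`
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


print(get_full_repo_name("my-ddpm-model"))                         # -> "<username>/my-ddpm-model"
print(get_full_repo_name("my-ddpm-model", organization="my-org"))  # -> "my-org/my-ddpm-model"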
@@ -110,8 +266,22 @@ def transforms(examples):
 
     ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
-    if args.push_to_hub:
-        repo = init_git_repo(args, at_init=True)
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.push_to_hub:
+            if args.hub_model_id is None:
+                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+            else:
+                repo_name = args.hub_model_id
+            repo = Repository(args.output_dir, clone_from=repo_name)
+
+            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+                if "step_*" not in gitignore:
+                    gitignore.write("step_*\n")
+                if "epoch_*" not in gitignore:
+                    gitignore.write("epoch_*\n")
+        elif args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
 
     if accelerator.is_main_process:
         run = os.path.split(__file__)[-1].split(".")[0]
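For readers unfamiliar with `huggingface_hub.Repository`, the block above follows a clone-then-push pattern: the output directory becomes a local git clone of the Hub repo, and later pushes commit whatever training wrote into it. A minimal sketch with placeholder names (assuming the repo exists and a token is cached):

from huggingface_hub import Repository

# Clone (or reuse) the Hub repo into the local output directory.
# "my-username/ddpm-model-64" is a placeholder repo id.
repo = Repository("ddpm-model-64", clone_from="my-username/ddpm-model-64")

# ... training writes checkpoints and pipeline files into "ddpm-model-64" ...

# Commit and push the working tree; blocking=False returns immediately
# while the upload continues in the background.
repo.push_to_hub(commit_message="Epoch 0", blocking=False)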
@@ -193,55 +363,5 @@ def transforms(examples):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset_name", type=str, default=None)
-    parser.add_argument("--dataset_config_name", type=str, default=None)
-    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
-    parser.add_argument("--overwrite_output_dir", action="store_true")
-    parser.add_argument("--cache_dir", type=str, default=None)
-    parser.add_argument("--resolution", type=int, default=64)
-    parser.add_argument("--train_batch_size", type=int, default=16)
-    parser.add_argument("--eval_batch_size", type=int, default=16)
-    parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--save_images_epochs", type=int, default=10)
-    parser.add_argument("--save_model_epochs", type=int, default=10)
-    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--learning_rate", type=float, default=1e-4)
-    parser.add_argument("--lr_scheduler", type=str, default="cosine")
-    parser.add_argument("--lr_warmup_steps", type=int, default=500)
-    parser.add_argument("--adam_beta1", type=float, default=0.95)
-    parser.add_argument("--adam_beta2", type=float, default=0.999)
-    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
-    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
-    parser.add_argument("--use_ema", action="store_true", default=True)
-    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
-    parser.add_argument("--ema_power", type=float, default=3 / 4)
-    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
-    parser.add_argument("--push_to_hub", action="store_true")
-    parser.add_argument("--hub_token", type=str, default=None)
-    parser.add_argument("--hub_model_id", type=str, default=None)
-    parser.add_argument("--hub_private_repo", action="store_true")
-    parser.add_argument("--logging_dir", type=str, default="logs")
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="no",
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
-        ),
-    )
-
-    args = parser.parse_args()
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
-    if args.dataset_name is None and args.train_data_dir is None:
-        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
-
+    args = parse_args()
     main(args)
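As a usage note, the refactored script is launched as before, with every flag now defined in `parse_args`. A sample invocation (the dataset and output directory are placeholders, not part of this commit):

accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --train_batch_size=16 \
  --num_epochs=100 \
  --push_to_hub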

src/diffusers/hub_utils.py

Lines changed: 12 additions & 1 deletion
@@ -22,7 +22,7 @@
 from huggingface_hub import HfFolder, Repository, whoami
 
 from .pipeline_utils import DiffusionPipeline
-from .utils import is_modelcards_available, logging
+from .utils import deprecate, is_modelcards_available, logging
 
 
 if is_modelcards_available():
@@ -53,6 +53,12 @@ def init_git_repo(args, at_init: bool = False):
             Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
             and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
     """
+    deprecation_message = (
+        "Please use `huggingface_hub.Repository`. "
+        "See `examples/unconditional_image_generation/train_unconditional.py` for an example."
+    )
+    deprecate("init_git_repo()", "0.10.0", deprecation_message)
+
     if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
         return
     hub_token = args.hub_token if hasattr(args, "hub_token") else None
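As a rough illustration of what the `deprecate(...)` call above does, here is a simplified stand-in (not diffusers' actual implementation, which lives in `diffusers.utils` and is also expected to fail loudly once the removal version ships):

import warnings


def deprecate(name: str, removal_version: str, message: str):
    # Simplified stand-in: warn callers that `name` disappears in `removal_version`.
    warnings.warn(
        f"`{name}` is deprecated and will be removed in version {removal_version}. {message}",
        FutureWarning,
    )


deprecate(
    "init_git_repo()",
    "0.10.0",
    "Please use `huggingface_hub.Repository`. "
    "See `examples/unconditional_image_generation/train_unconditional.py` for an example.",
)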
@@ -114,6 +120,11 @@ def push_to_hub(
             The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
             commit and an object to track the progress of the commit if `blocking=True`
     """
+    deprecation_message = (
+        "Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
+        "See `examples/unconditional_image_generation/train_unconditional.py` for an example."
+    )
+    deprecate("push_to_hub()", "0.10.0", deprecation_message)
 
     if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
         model_name = Path(args.output_dir).name
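Both deprecation messages point at the same replacement flow. A compact before/after sketch (the old call shape is abbreviated; `args` and `pipeline` stand for the parsed arguments and a trained `DiffusionPipeline`, and the repo id is a placeholder):

from huggingface_hub import Repository

# Before (deprecated, slated for removal in diffusers 0.10.0):
#   repo = init_git_repo(args, at_init=True)
#   push_to_hub(args, pipeline, repo)

# After: manage the repo clone directly with huggingface_hub.
repo = Repository(args.output_dir, clone_from="my-username/my-model")
pipeline.save_pretrained(args.output_dir)  # write the pipeline files into the clone
repo.push_to_hub(commit_message="End of training", blocking=True)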
