Skip to content

Commit a6314a8

Browse files
authored
Add --dataloader_num_workers to the DDPM training example (huggingface#1027)
1 parent 939ec17 commit a6314a8

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -83,7 +83,16 @@ def parse_args():
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
+        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
     )
     parser.add_argument("--num_epochs", type=int, default=100)
     parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
@@ -249,7 +258,9 @@ def transforms(examples):
         return {"input": images}

     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+    )

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
```

0 commit comments

Comments
 (0)