Skip to content

Commit 111228c

Browse files
wfng92patrickvonplatenpcuenca
authored
Fix torchvision.transforms and transforms function naming clash (huggingface#2274)
* Fix torchvision.transforms and transforms function naming clash * Update unconditional script for onnx * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent bbb46ad commit 111228c

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,13 +386,13 @@ def main(args):
386386
]
387387
)
388388

389-
def transforms(examples):
389+
def transform_images(examples):
390390
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
391391
return {"input": images}
392392

393393
logger.info(f"Dataset size: {len(dataset)}")
394394

395-
dataset.set_transform(transforms)
395+
dataset.set_transform(transform_images)
396396
train_dataloader = torch.utils.data.DataLoader(
397397
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
398398
)

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,13 +386,13 @@ def main(args):
386386
]
387387
)
388388

389-
def transforms(examples):
389+
def transform_images(examples):
390390
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
391391
return {"input": images}
392392

393393
logger.info(f"Dataset size: {len(dataset)}")
394394

395-
dataset.set_transform(transforms)
395+
dataset.set_transform(transform_images)
396396
train_dataloader = torch.utils.data.DataLoader(
397397
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
398398
)

0 commit comments

Comments
 (0)