Skip to content

Commit 1997614

Browse files
prathikrrootpatrickvonplatenPrathik Rao
authored
avoid upcasting by assigning dtype to noise tensor (huggingface#3713)
* avoid upcasting by assigning dtype to noise tensor * make style * Update train_unconditional.py * Update train_unconditional.py * make style * add unit test for pickle * revert change --------- Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
1 parent 4e89856 commit 1997614

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,9 @@ def transform_images(examples):
568568

569569
clean_images = batch["input"]
570570
# Sample noise that we'll add to the images
571-
noise = torch.randn(clean_images.shape).to(clean_images.device)
571+
noise = torch.randn(
572+
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
573+
).to(clean_images.device)
572574
bsz = clean_images.shape[0]
573575
# Sample a random timestep for each image
574576
timesteps = torch.randint(

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,9 @@ def transform_images(examples):
557557

558558
clean_images = batch["input"]
559559
# Sample noise that we'll add to the images
560-
noise = torch.randn(clean_images.shape).to(clean_images.device)
560+
noise = torch.randn(
561+
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
562+
).to(clean_images.device)
561563
bsz = clean_images.shape[0]
562564
# Sample a random timestep for each image
563565
timesteps = torch.randint(

0 commit comments

Comments
 (0)