File tree Expand file tree Collapse file tree 2 files changed +10
-2
lines changed
pipelines/versatile_diffusion Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -659,7 +659,7 @@ def forward(
659659
660660 t_emb = self .time_proj (timesteps )
661661
662- # timesteps does not contain any weights and will always return f32 tensors
662+ # `Timesteps` does not contain any weights and will always return f32 tensors
663663 # but time_embedding might actually be running in fp16. so we need to cast here.
664664 # there might be better ways to encapsulate this.
665665 t_emb = t_emb .to (dtype = self .dtype )
@@ -673,6 +673,10 @@ def forward(
673673 if self .config .class_embed_type == "timestep" :
674674 class_labels = self .time_proj (class_labels )
675675
676+ # `Timesteps` does not contain any weights and will always return f32 tensors
677+ # there might be better ways to encapsulate this.
678+ class_labels = class_labels .to (dtype = sample .dtype )
679+
676680 class_emb = self .class_embedding (class_labels ).to (dtype = self .dtype )
677681
678682 if self .config .class_embeddings_concat :
Original file line number Diff line number Diff line change @@ -756,7 +756,7 @@ def forward(
756756
757757 t_emb = self .time_proj (timesteps )
758758
759- # timesteps does not contain any weights and will always return f32 tensors
759+ # `Timesteps` does not contain any weights and will always return f32 tensors
760760 # but time_embedding might actually be running in fp16. so we need to cast here.
761761 # there might be better ways to encapsulate this.
762762 t_emb = t_emb .to (dtype = self .dtype )
@@ -770,6 +770,10 @@ def forward(
770770 if self .config .class_embed_type == "timestep" :
771771 class_labels = self .time_proj (class_labels )
772772
773+ # `Timesteps` does not contain any weights and will always return f32 tensors
774+ # there might be better ways to encapsulate this.
775+ class_labels = class_labels .to (dtype = sample .dtype )
776+
773777 class_emb = self .class_embedding (class_labels ).to (dtype = self .dtype )
774778
775779 if self .config .class_embeddings_concat :
You can’t perform that action at this time.
0 commit comments