@@ -65,6 +65,9 @@ def test_full_loop_no_noise(self):
6565 if torch_device in ["mps" ]:
6666 assert abs (result_sum .item () - 167.47821044921875 ) < 1e-2
6767 assert abs (result_mean .item () - 0.2178705964565277 ) < 1e-3
68+ elif torch_device in ["cuda" ]:
69+ assert abs (result_sum .item () - 171.59352111816406 ) < 1e-2
70+ assert abs (result_mean .item () - 0.22342906892299652 ) < 1e-3
6871 else :
6972 assert abs (result_sum .item () - 162.52383422851562 ) < 1e-2
7073 assert abs (result_mean .item () - 0.211619570851326 ) < 1e-3
@@ -94,6 +97,9 @@ def test_full_loop_with_v_prediction(self):
9497 if torch_device in ["mps" ]:
9598 assert abs (result_sum .item () - 124.77149200439453 ) < 1e-2
9699 assert abs (result_mean .item () - 0.16226289014816284 ) < 1e-3
100+ elif torch_device in ["cuda" ]:
101+ assert abs (result_sum .item () - 128.1663360595703 ) < 1e-2
102+ assert abs (result_mean .item () - 0.16688326001167297 ) < 1e-3
97103 else :
98104 assert abs (result_sum .item () - 119.8487548828125 ) < 1e-2
99105 assert abs (result_mean .item () - 0.1560530662536621 ) < 1e-3
@@ -122,6 +128,9 @@ def test_full_loop_device(self):
122128 if torch_device in ["mps" ]:
123129 assert abs (result_sum .item () - 167.46957397460938 ) < 1e-2
124130 assert abs (result_mean .item () - 0.21805934607982635 ) < 1e-3
131+ elif torch_device in ["cuda" ]:
132+ assert abs (result_sum .item () - 171.59353637695312 ) < 1e-2
133+ assert abs (result_mean .item () - 0.22342908382415771 ) < 1e-3
125134 else :
126135 assert abs (result_sum .item () - 162.52383422851562 ) < 1e-2
127136 assert abs (result_mean .item () - 0.211619570851326 ) < 1e-3
@@ -151,6 +160,9 @@ def test_full_loop_device_karras_sigmas(self):
151160 if torch_device in ["mps" ]:
152161 assert abs (result_sum .item () - 176.66974135742188 ) < 1e-2
153162 assert abs (result_mean .item () - 0.23003872730981811 ) < 1e-2
163+ elif torch_device in ["cuda" ]:
164+ assert abs (result_sum .item () - 177.63653564453125 ) < 1e-2
165+ assert abs (result_mean .item () - 0.23003872730981811 ) < 1e-2
154166 else :
155167 assert abs (result_sum .item () - 170.3135223388672 ) < 1e-2
156168 assert abs (result_mean .item () - 0.23003872730981811 ) < 1e-2
0 commit comments