@@ -411,7 +411,9 @@ def test_spatial_transformer_cross_attention_dim(self):
411411
412412 assert attention_scores .shape == (1 , 64 , 64 , 64 )
413413 output_slice = attention_scores [0 , - 1 , - 3 :, - 3 :]
414- expected_slice = torch .tensor ([0.0143 , - 0.6909 , - 2.1547 , - 1.8893 , 1.4097 , 0.1359 , - 0.2521 , - 1.3359 , 0.2598 ])
414+ expected_slice = torch .tensor (
415+ [0.0143 , - 0.6909 , - 2.1547 , - 1.8893 , 1.4097 , 0.1359 , - 0.2521 , - 1.3359 , 0.2598 ], device = torch_device
416+ )
415417 assert torch .allclose (output_slice .flatten (), expected_slice , atol = 1e-3 )
416418
417419 def test_spatial_transformer_timestep (self ):
@@ -442,9 +444,11 @@ def test_spatial_transformer_timestep(self):
442444 output_slice_1 = attention_scores_1 [0 , - 1 , - 3 :, - 3 :]
443445 output_slice_2 = attention_scores_2 [0 , - 1 , - 3 :, - 3 :]
444446
445- expected_slice = torch .tensor ([- 0.3923 , - 1.0923 , - 1.7144 , - 1.5570 , 1.4154 , 0.1738 , - 0.1157 , - 1.2998 , - 0.1703 ])
447+ expected_slice = torch .tensor (
448+ [- 0.3923 , - 1.0923 , - 1.7144 , - 1.5570 , 1.4154 , 0.1738 , - 0.1157 , - 1.2998 , - 0.1703 ], device = torch_device
449+ )
446450 expected_slice_2 = torch .tensor (
447- [- 0.4311 , - 1.1376 , - 1.7732 , - 1.5997 , 1.3450 , 0.0964 , - 0.1569 , - 1.3590 , - 0.2348 ]
451+ [- 0.4311 , - 1.1376 , - 1.7732 , - 1.5997 , 1.3450 , 0.0964 , - 0.1569 , - 1.3590 , - 0.2348 ], device = torch_device
448452 )
449453
450454 assert torch .allclose (output_slice_1 .flatten (), expected_slice , atol = 1e-3 )
0 commit comments