Skip to content

Commit 013955b

Browse files
[Dit] Fix dit tests (huggingface#2034)
* [Dit] Fix dit tests * up
1 parent ed616bd commit 013955b

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

tests/pipelines/dit/test_dit.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3636
def get_dummy_components(self):
3737
torch.manual_seed(0)
3838
transformer = Transformer2DModel(
39-
sample_size=4,
39+
sample_size=16,
4040
num_layers=2,
41-
patch_size=2,
42-
attention_head_dim=2,
41+
patch_size=4,
42+
attention_head_dim=8,
4343
num_attention_heads=2,
4444
in_channels=4,
4545
out_channels=8,
@@ -79,10 +79,8 @@ def test_inference(self):
7979
image = pipe(**inputs).images
8080
image_slice = image[0, -3:, -3:, -1]
8181

82-
self.assertEqual(image.shape, (1, 4, 4, 3))
83-
expected_slice = np.array(
84-
[0.44405967, 0.33592293, 0.6093237, 0.48981372, 0.79098296, 0.7504172, 0.59413105, 0.49462673, 0.35190058]
85-
)
82+
self.assertEqual(image.shape, (1, 16, 16, 3))
83+
expected_slice = np.array([0.4380, 0.4141, 0.5159, 0.0000, 0.4282, 0.6680, 0.5485, 0.2545, 0.6719])
8684
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
8785
self.assertLessEqual(max_diff, 1e-3)
8886

0 commit comments

Comments
 (0)