@@ -44,7 +44,7 @@ def get_dummy_components(self):
4444 torch .manual_seed (0 )
4545 unet = UNet2DConditionModel (
4646 block_out_channels = (32 , 64 ),
47- layers_per_block = 2 ,
47+ layers_per_block = 1 ,
4848 sample_size = 32 ,
4949 in_channels = 4 ,
5050 out_channels = 4 ,
@@ -111,7 +111,7 @@ def get_dummy_inputs(self, device, seed=0):
111111 "prompt" : "a cat and a frog" ,
112112 "token_indices" : [2 , 5 ],
113113 "generator" : generator ,
114- "num_inference_steps" : 2 ,
114+ "num_inference_steps" : 1 ,
115115 "guidance_scale" : 6.0 ,
116116 "output_type" : "numpy" ,
117117 "max_iter_to_alter" : 2 ,
@@ -132,13 +132,18 @@ def test_inference(self):
132132 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
133133
134134 self .assertEqual (image .shape , (1 , 64 , 64 , 3 ))
135- expected_slice = np .array ([0.5743 , 0.6081 , 0.4975 , 0.5021 , 0.5441 , 0.4699 , 0.4988 , 0.4841 , 0.4851 ])
135+ expected_slice = np .array (
136+ [0.63905364 , 0.62897307 , 0.48599017 , 0.5133624 , 0.5550048 , 0.45769516 , 0.50326973 , 0.5023139 , 0.45384496 ]
137+ )
136138 max_diff = np .abs (image_slice .flatten () - expected_slice ).max ()
137139 self .assertLessEqual (max_diff , 1e-3 )
138140
139141 def test_inference_batch_consistent (self ):
140142 # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
141- self ._test_inference_batch_consistent (batch_sizes = [2 , 4 ])
143+ self ._test_inference_batch_consistent (batch_sizes = [1 , 2 ])
144+
145+ def test_inference_batch_single_identical (self ):
146+ self ._test_inference_batch_single_identical (batch_size = 2 )
142147
143148
144149@require_torch_gpu
0 commit comments