@@ -133,11 +133,15 @@ def test_freeu_enabled(self):
133133
134134 inputs = self .get_dummy_inputs (torch_device )
135135 inputs ["return_dict" ] = False
136+ inputs ["output_type" ] = "np"
137+
136138 output = pipe (** inputs )[0 ]
137139
138140 pipe .enable_freeu (s1 = 0.9 , s2 = 0.2 , b1 = 1.2 , b2 = 1.4 )
139141 inputs = self .get_dummy_inputs (torch_device )
140142 inputs ["return_dict" ] = False
143+ inputs ["output_type" ] = "np"
144+
141145 output_freeu = pipe (** inputs )[0 ]
142146
143147 assert not np .allclose (
@@ -152,6 +156,8 @@ def test_freeu_disabled(self):
152156
153157 inputs = self .get_dummy_inputs (torch_device )
154158 inputs ["return_dict" ] = False
159+ inputs ["output_type" ] = "np"
160+
155161 output = pipe (** inputs )[0 ]
156162
157163 pipe .enable_freeu (s1 = 0.9 , s2 = 0.2 , b1 = 1.2 , b2 = 1.4 )
@@ -164,6 +170,8 @@ def test_freeu_disabled(self):
164170
165171 inputs = self .get_dummy_inputs (torch_device )
166172 inputs ["return_dict" ] = False
173+ inputs ["output_type" ] = "np"
174+
167175 output_no_freeu = pipe (** inputs )[0 ]
168176 assert np .allclose (
169177 output , output_no_freeu , atol = 1e-2
0 commit comments