|
24 | 24 | import torch |
25 | 25 | from requests.exceptions import HTTPError |
26 | 26 |
|
27 | | -from diffusers.models import ModelMixin, UNet2DConditionModel |
| 27 | +from diffusers.models import UNet2DConditionModel |
28 | 28 | from diffusers.models.attention_processor import AttnProcessor |
29 | 29 | from diffusers.training_utils import EMAModel |
30 | 30 | from diffusers.utils import torch_device |
@@ -119,11 +119,6 @@ def test_from_save_pretrained(self): |
119 | 119 | new_model.to(torch_device) |
120 | 120 |
|
121 | 121 | with torch.no_grad(): |
122 | | - # Warmup pass when using mps (see #372) |
123 | | - if torch_device == "mps" and isinstance(model, ModelMixin): |
124 | | - _ = model(**self.dummy_input) |
125 | | - _ = new_model(**self.dummy_input) |
126 | | - |
127 | 122 | image = model(**inputs_dict) |
128 | 123 | if isinstance(image, dict): |
129 | 124 | image = image.sample |
@@ -161,11 +156,6 @@ def test_from_save_pretrained_variant(self): |
161 | 156 | new_model.to(torch_device) |
162 | 157 |
|
163 | 158 | with torch.no_grad(): |
164 | | - # Warmup pass when using mps (see #372) |
165 | | - if torch_device == "mps" and isinstance(model, ModelMixin): |
166 | | - _ = model(**self.dummy_input) |
167 | | - _ = new_model(**self.dummy_input) |
168 | | - |
169 | 159 | image = model(**inputs_dict) |
170 | 160 | if isinstance(image, dict): |
171 | 161 | image = image.sample |
@@ -203,10 +193,6 @@ def test_determinism(self): |
203 | 193 | model.eval() |
204 | 194 |
|
205 | 195 | with torch.no_grad(): |
206 | | - # Warmup pass when using mps (see #372) |
207 | | - if torch_device == "mps" and isinstance(model, ModelMixin): |
208 | | - model(**self.dummy_input) |
209 | | - |
210 | 196 | first = model(**inputs_dict) |
211 | 197 | if isinstance(first, dict): |
212 | 198 | first = first.sample |
@@ -377,10 +363,6 @@ def recursive_check(tuple_object, dict_object): |
377 | 363 | model.eval() |
378 | 364 |
|
379 | 365 | with torch.no_grad(): |
380 | | - # Warmup pass when using mps (see #372) |
381 | | - if torch_device == "mps" and isinstance(model, ModelMixin): |
382 | | - model(**self.dummy_input) |
383 | | - |
384 | 366 | outputs_dict = model(**inputs_dict) |
385 | 367 | outputs_tuple = model(**inputs_dict, return_dict=False) |
386 | 368 |
|
|
0 commit comments