|
22 | 22 | from typing import Callable, Dict, List, Optional, Union |
23 | 23 |
|
24 | 24 | import requests |
| 25 | +import safetensors |
25 | 26 | import torch |
26 | 27 | import torch.nn.functional as F |
27 | 28 | from huggingface_hub import hf_hub_download |
|
34 | 35 | deprecate, |
35 | 36 | is_accelerate_available, |
36 | 37 | is_omegaconf_available, |
37 | | - is_safetensors_available, |
38 | 38 | is_transformers_available, |
39 | 39 | logging, |
40 | 40 | ) |
41 | 41 | from .utils.import_utils import BACKENDS_MAPPING |
42 | 42 |
|
43 | 43 |
|
44 | | -if is_safetensors_available(): |
45 | | - import safetensors |
46 | | - |
47 | 44 | if is_transformers_available(): |
48 | 45 | from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer |
49 | 46 |
|
@@ -261,14 +258,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict |
261 | 258 | network_alphas = kwargs.pop("network_alphas", None) |
262 | 259 | is_network_alphas_none = network_alphas is None |
263 | 260 |
|
264 | | - if use_safetensors and not is_safetensors_available(): |
265 | | - raise ValueError( |
266 | | - "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" |
267 | | - ) |
268 | | - |
269 | 261 | allow_pickle = False |
| 262 | + |
270 | 263 | if use_safetensors is None: |
271 | | - use_safetensors = is_safetensors_available() |
| 264 | + use_safetensors = True |
272 | 265 | allow_pickle = True |
273 | 266 |
|
274 | 267 | user_agent = { |
@@ -757,14 +750,9 @@ def load_textual_inversion( |
757 | 750 | weight_name = kwargs.pop("weight_name", None) |
758 | 751 | use_safetensors = kwargs.pop("use_safetensors", None) |
759 | 752 |
|
760 | | - if use_safetensors and not is_safetensors_available(): |
761 | | - raise ValueError( |
762 | | - "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" |
763 | | - ) |
764 | | - |
765 | 753 | allow_pickle = False |
766 | 754 | if use_safetensors is None: |
767 | | - use_safetensors = is_safetensors_available() |
| 755 | + use_safetensors = True |
768 | 756 | allow_pickle = True |
769 | 757 |
|
770 | 758 | user_agent = { |
@@ -1014,14 +1002,9 @@ def lora_state_dict( |
1014 | 1002 | unet_config = kwargs.pop("unet_config", None) |
1015 | 1003 | use_safetensors = kwargs.pop("use_safetensors", None) |
1016 | 1004 |
|
1017 | | - if use_safetensors and not is_safetensors_available(): |
1018 | | - raise ValueError( |
1019 | | - "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" |
1020 | | - ) |
1021 | | - |
1022 | 1005 | allow_pickle = False |
1023 | 1006 | if use_safetensors is None: |
1024 | | - use_safetensors = is_safetensors_available() |
| 1007 | + use_safetensors = True |
1025 | 1008 | allow_pickle = True |
1026 | 1009 |
|
1027 | 1010 | user_agent = { |
@@ -1853,7 +1836,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): |
1853 | 1836 |
|
1854 | 1837 | torch_dtype = kwargs.pop("torch_dtype", None) |
1855 | 1838 |
|
1856 | | - use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) |
| 1839 | + use_safetensors = kwargs.pop("use_safetensors", None) |
1857 | 1840 |
|
1858 | 1841 | pipeline_name = cls.__name__ |
1859 | 1842 | file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] |
@@ -2050,7 +2033,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): |
2050 | 2033 |
|
2051 | 2034 | torch_dtype = kwargs.pop("torch_dtype", None) |
2052 | 2035 |
|
2053 | | - use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) |
| 2036 | + use_safetensors = kwargs.pop("use_safetensors", None) |
2054 | 2037 |
|
2055 | 2038 | file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] |
2056 | 2039 | from_safetensors = file_extension == "safetensors" |
@@ -2223,7 +2206,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): |
2223 | 2206 |
|
2224 | 2207 | torch_dtype = kwargs.pop("torch_dtype", None) |
2225 | 2208 |
|
2226 | | - use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) |
| 2209 | + use_safetensors = kwargs.pop("use_safetensors", None) |
2227 | 2210 |
|
2228 | 2211 | file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] |
2229 | 2212 | from_safetensors = file_extension == "safetensors" |
|
0 commit comments