Skip to content

Commit 397d94a

Browse files
authored
always use oneflow.load (huggingface#84)
2 parents d53868a + 3139667 commit 397d94a

File tree

1 file changed

+1
-14
lines changed

1 file changed

+1
-14
lines changed

src/diffusers/modeling_oneflow_utils.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,20 +92,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
9292
Reads a checkpoint file, returning properly formatted errors if they arise.
9393
"""
9494
try:
95-
# this is oneflow saved model, a dir
96-
if os.path.isdir(checkpoint_file):
97-
return torch.load(checkpoint_file, map_location="cpu")
98-
elif os.path.basename(checkpoint_file) == WEIGHTS_NAME:
99-
import torch as og_torch
100-
101-
torch_parameters = og_torch.load(checkpoint_file, map_location="cpu")
102-
oneflow_parameters = dict()
103-
for key, value in torch_parameters.items():
104-
if value.is_cuda:
105-
raise ValueError(f"torch model is not on cpu, it is on {value.device}")
106-
val = value.detach().cpu().numpy()
107-
oneflow_parameters[key] = torch.from_numpy(val)
108-
return oneflow_parameters
95+
return torch.load(checkpoint_file, map_location="cpu")
10996
except Exception as e:
11097
try:
11198
with open(checkpoint_file) as f:

0 commit comments

Comments
 (0)