Skip to content

Commit faeb1df

Browse files
committed
fix pip load
1 parent 4e5ca34 commit faeb1df

File tree

4 files changed

+41
-30
lines changed

4 files changed

+41
-30
lines changed

src/diffusers/modeling_oneflow_utils.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,20 +92,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
9292
Reads a checkpoint file, returning properly formatted errors if they arise.
9393
"""
9494
try:
95-
# this is oneflow saved model, a dir
96-
if os.path.isdir(checkpoint_file):
97-
return torch.load(checkpoint_file, map_location="cpu")
98-
elif os.path.basename(checkpoint_file) == WEIGHTS_NAME:
99-
import torch as og_torch
100-
101-
torch_parameters = og_torch.load(checkpoint_file, map_location="cpu")
102-
oneflow_parameters = dict()
103-
for key, value in torch_parameters.items():
104-
if value.is_cuda:
105-
raise ValueError(f"torch model is not on cpu, it is on {value.device}")
106-
val = value.detach().cpu().numpy()
107-
oneflow_parameters[key] = torch.from_numpy(val)
108-
return oneflow_parameters
95+
return torch.load(checkpoint_file, map_location="cpu")
10996
except Exception as e:
11097
try:
11198
with open(checkpoint_file) as f:

src/diffusers/oneflow_graph_compile_cache.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ def save_graph(self, path):
154154

155155
def load_graph(self, path, graph_class2init_args=None):
156156
if self.enable_load_graph_:
157-
sub_folders = [ f.path for f in os.scandir(path) if f.is_dir() ]
157+
sub_files = [ f.path for f in os.scandir(path) if f.is_file() ]
158158
graph_dict = dict()
159-
for sub_folder in sub_folders:
160-
state_dict = flow.load(sub_folder)
159+
for sub_file in sub_files:
160+
state_dict = flow.load(sub_file)
161161
cache_order = state_dict["cache_order"]
162162
graph_dict[cache_order] = state_dict
163163

src/diffusers/pipeline_oneflow_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,8 @@ def load_module(name, value):
675675
# 3. Load each module in the pipeline
676676
for name, (library_name, class_name) in init_dict.items():
677677
if name in ["scheduler", "unet", "vae", "text_encoder", "safety_checker"]:
678-
class_name = "OneFlow" + class_name
678+
if "OneFlow" not in class_name:
679+
class_name = "OneFlow" + class_name
679680
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
680681
if class_name.startswith("Flax"):
681682
class_name = class_name[4:]

tests/test_pipelines_oneflow_graph_load.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,39 @@
11
import oneflow as torch
22
import time
3+
import os
4+
import shutil
5+
36
from diffusers import (
47
OneFlowStableDiffusionPipeline as StableDiffusionPipeline,
58
OneFlowEulerDiscreteScheduler as EulerDiscreteScheduler,
69
)
710
from diffusers import utils
811

912
model_id = "stabilityai/stable-diffusion-2"
13+
_graph_save_file = "./test_sd_save_graph"
14+
_sch_file_path = "./test_sd_sch"
15+
_pipe_file_path = "./test_sd_pipe"
1016

11-
_offline_compile = False
17+
_online_mode = True
18+
_pipe_from_file = True
1219

1320
total_start_t = time.time()
1421
start_t = time.time()
15-
# StableDiffusionPipeline 需要支持 unet 和 vae load graph, 此时无需创建 eager module
1622
@utils.cost_cnt
1723
def get_pipe():
18-
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
19-
sd_pipe = StableDiffusionPipeline.from_pretrained(
20-
model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
21-
)
24+
if _pipe_from_file:
25+
scheduler = EulerDiscreteScheduler.from_pretrained(_sch_file_path, subfolder="scheduler")
26+
sd_pipe = StableDiffusionPipeline.from_pretrained(
27+
_pipe_file_path, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
28+
)
29+
else:
30+
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
31+
sd_pipe = StableDiffusionPipeline.from_pretrained(
32+
model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
33+
)
2234
torch._oneflow_internal.eager.Sync()
23-
return sd_pipe
24-
pipe = get_pipe()
35+
return scheduler, sd_pipe
36+
sch, pipe = get_pipe()
2537

2638
@utils.cost_cnt
2739
def pipe_to_cuda():
@@ -37,13 +49,14 @@ def config_graph():
3749
torch._oneflow_internal.eager.Sync()
3850
config_graph()
3951

40-
if _offline_compile:
52+
if not _online_mode:
4153
pipe.enable_save_graph()
4254
else:
4355
@utils.cost_cnt
4456
def load_graph():
4557
pipe.enable_load_graph()
46-
pipe.load_graph("./test_save_load", compile_unet=True, compile_vae=False)
58+
assert (os.path.exists(_graph_save_file) and os.path.isdir(_graph_save_file))
59+
pipe.load_graph(_graph_save_file, compile_unet=True, compile_vae=False)
4760
torch._oneflow_internal.eager.Sync()
4861
load_graph()
4962
end_t = time.time()
@@ -85,9 +98,19 @@ def text_to_image(prompt, image_size, num_images_per_prompt=1, prefix=""):
8598
total_end_t = time.time()
8699
print("st init and run time ", total_end_t - total_start_t, 's.')
87100

101+
@utils.cost_cnt
102+
def save_pipe_sch():
103+
pipe.save_pretrained(_pipe_file_path)
104+
sch.save_pretrained(_sch_file_path)
105+
88106
@utils.cost_cnt
89107
def save_graph():
90-
pipe.save_graph("./test_save_load")
108+
if os.path.exists(_graph_save_file) and os.path.isdir(_graph_save_file):
109+
shutil.rmtree(_graph_save_file)
110+
os.makedirs(_graph_save_file)
111+
112+
pipe.save_graph(_graph_save_file)
91113

92-
if _offline_compile:
114+
if not _online_mode:
115+
save_pipe_sch()
93116
save_graph()

0 commit comments

Comments
 (0)