Skip to content

Commit 2aa5a35

Browse files
committed
add load
1 parent 0cb6ded commit 2aa5a35

File tree

3 files changed

+63
-2
lines changed

3 files changed

+63
-2
lines changed

src/diffusers/oneflow_graph_compile_cache.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ def compile(self, *args, **kwargs):
3636

3737
self.is_compiled_ = True
3838

39+
def load_runtime_state_dict(self, state_dict):
    """Restore the wrapped graph from a previously saved runtime state dict.

    When the graph is already flagged as compiled nothing happens. Otherwise
    ``state_dict`` is forwarded to ``self.graph_.load_runtime_state_dict``,
    the elapsed time is logged, and ``is_compiled_`` is set to ``True``.

    Args:
        state_dict: Runtime state dict to hand to the underlying graph.
    """
    if not self.is_compiled_:
        global_class_name = self.graph_.__class__.__name__
        logger.info(
            f"[oneflow] loading {global_class_name} beforehand to make sure the progress bar is more accurate",
        )

        # Time only the actual load call so the log reflects load cost alone.
        started_at = timer()
        self.graph_.load_runtime_state_dict(state_dict)
        load_time = timer() - started_at
        logger.info(f"[oneflow] [elapsed(s)] [{global_class_name} loading] {load_time:.3f}")

        # A loaded graph is flagged the same way as a compiled one.
        self.is_compiled_ = True
54+
3955
def share_from(self, other_graph):
4056
self.graph_.share_from(other_graph.graph_)
4157
self.is_shared_from_ = True
@@ -52,6 +68,7 @@ def __init__(self, cache_size):
5268
self.cache_size = cache_size
5369
self.queue = deque()
5470
self.hash_map = dict()
71+
self.cnt = 0
5572

5673
def front(self):
5774
if self.is_empty():
@@ -81,6 +98,8 @@ def set(self, key, value):
8198
pop_key = self.pop()
8299

83100
self.queue.appendleft(key)
101+
value._oneflow_graph_cache_order = self.cnt
102+
self.cnt += 1
84103
self.hash_map[key] = value
85104
return pop_key if pop_key is not None else key
86105

@@ -120,11 +139,46 @@ def enable_share_mem(self, enabled=True):
120139
def enable_save_graph(self, enabled=True):
    """Turn graph saving on or off.

    Stores ``enabled`` in ``self.enable_save_graph_``, the flag that
    ``save_graph`` checks before writing anything.

    Args:
        enabled: New value of the flag; defaults to ``True``.
    """
    self.enable_save_graph_ = enabled
122141

142+
def enable_load_graph(self, enabled=True):
    """Turn graph loading on or off.

    Stores ``enabled`` in ``self.enable_load_graph_``, the flag that
    ``load_graph`` checks before reading anything.

    Args:
        enabled: New value of the flag; defaults to ``True``.
    """
    self.enable_load_graph_ = enabled
144+
123145
def save_graph(self, path):
    """Save the runtime state dict of every cached graph under ``path``.

    Does nothing unless ``enable_save_graph_`` is truthy. For each cached
    graph the runtime state dict is extended with three bookkeeping entries
    (``cache_order``, ``cache_key``, ``graph_class_name``) and written with
    ``flow.save`` to ``<path>/<graph_class_name>_<hash(key)>``.

    Args:
        path: Directory that receives one saved entry per cached graph.
    """
    if not self.enable_save_graph_:
        return

    for graph_class_name, cache in self.cache_bucket_.items():
        for key, graph in cache.pairs():
            runtime_state = graph.graph_.runtime_state_dict()
            # Bookkeeping entries stored next to the graph's own state.
            runtime_state["cache_order"] = graph._oneflow_graph_cache_order
            runtime_state["cache_key"] = key
            runtime_state["graph_class_name"] = graph_class_name
            target = os.path.join(path, graph_class_name + "_" + str(hash(key)))
            flow.save(runtime_state, target)
154+
155+
def load_graph(self, path):
    """Rebuild the graph cache from runtime state dicts saved under ``path``.

    Every sub-directory of ``path`` is loaded with ``flow.load`` and must
    contain the ``cache_order``, ``cache_key`` and ``graph_class_name``
    entries that ``save_graph`` stores. Graphs are re-inserted into their
    per-class ``LRUCache`` in ascending ``cache_order``.

    Does nothing unless loading was switched on via ``enable_load_graph``.

    Args:
        path: Directory whose sub-directories each hold one saved state dict.
    """
    # The flag only exists after enable_load_graph() has been called; treat
    # "never set" as disabled instead of failing with AttributeError.
    if not getattr(self, "enable_load_graph_", False):
        return

    with os.scandir(path) as entries:
        sub_folders = [entry.path for entry in entries if entry.is_dir()]

    # Keep every state dict in a list: cache_order comes from a counter that
    # is private to each LRUCache, so graphs of different classes can carry
    # the same value. Keying a dict by cache_order alone would silently drop
    # all but one of them.
    state_dicts = [flow.load(sub_folder, map_location="cuda") for sub_folder in sub_folders]
    # The class name is a tie-breaker so the load order does not depend on
    # directory listing order.
    state_dicts.sort(key=lambda sd: (sd["cache_order"], sd["graph_class_name"]))

    for state_dict in state_dicts:
        graph_class_name = state_dict["graph_class_name"]
        cache_key = state_dict["cache_key"]
        if graph_class_name not in self.cache_bucket_:
            self.cache_bucket_[graph_class_name] = LRUCache(self.cache_size_)
        # TODO(): release eager vae/unet module
        compile_cache = self.cache_bucket_[graph_class_name]
        graph = OneFlowGraph(flow.nn.Graph)
        if self.enable_share_mem_ is True:
            if graph_class_name in self.share_origin_:
                graph.share_from(self.share_origin_[graph_class_name])
            else:
                # First graph of this class becomes the share origin.
                # NOTE(review): enable_shared() is assumed to belong to this
                # branch only -- confirm against get_graph.
                self.share_origin_[graph_class_name] = graph
                graph.graph_.enable_shared()

        graph.load_runtime_state_dict(state_dict)
        ret = compile_cache.set(cache_key, graph)
        assert ret is not None
128182

129183
def get_graph(self, graph_class, cache_key, *args, **kwargs):
130184
graph_class_name = graph_class.__name__

src/diffusers/pipeline_oneflow_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ def enable_save_graph(self, enabled=True):
178178
def save_graph(self, path):
    """Delegate saving of compiled graphs to the graph compile cache.

    Args:
        path: Passed through unchanged to ``graph_compile_cache.save_graph``.
    """
    compile_cache = self.graph_compile_cache
    compile_cache.save_graph(path)
180180

181+
def enable_load_graph(self, enabled=True):
    """Forward the graph-loading switch to the graph compile cache.

    Args:
        enabled: Passed through unchanged to
            ``graph_compile_cache.enable_load_graph``; defaults to ``True``.
    """
    compile_cache = self.graph_compile_cache
    compile_cache.enable_load_graph(enabled)
183+
184+
def load_graph(self, path):
    """Delegate loading of saved graphs to the graph compile cache.

    Args:
        path: Passed through unchanged to ``graph_compile_cache.load_graph``.
    """
    compile_cache = self.graph_compile_cache
    compile_cache.load_graph(path)
186+
181187
def register_modules(self, **kwargs):
182188
# import it here to avoid circular import
183189
from diffusers import pipelines

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_oneflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,8 @@ def __call__(
650650
vae_post_process = VaePostProcess(self.vae)
651651
vae_post_process.eval()
652652
vae_post_process_graph = self.graph_compile_cache.get_graph(VaeGraph, cache_key, vae_post_process)
653-
vae_post_process_graph.compile(latents)
653+
if vae_post_process_graph.is_compiled is False:
654+
vae_post_process_graph.compile(latents)
654655

655656
# compile unet graph
656657
if compile_unet:

0 commit comments

Comments
 (0)