@@ -36,6 +36,22 @@ def compile(self, *args, **kwargs):
3636
3737 self .is_compiled_ = True
3838
def load_runtime_state_dict(self, state_dict):
    """Load a saved runtime state dict into the wrapped graph.

    Nothing happens when the wrapper is already flagged as compiled.
    Otherwise ``state_dict`` is forwarded to ``self.graph_``, the elapsed
    wall time is logged, and the wrapper is flagged as compiled so the
    load is not repeated.
    """
    if not self.is_compiled_:
        global_class_name = self.graph_.__class__.__name__
        logger.info(
            f"[oneflow] loading { global_class_name } beforehand to make sure the progress bar is more accurate",
        )

        # Time only the actual load call.
        started_at = timer()
        self.graph_.load_runtime_state_dict(state_dict)
        load_time = timer() - started_at
        logger.info(f"[oneflow] [elapsed(s)] [{ global_class_name } loading] { load_time :.3f} ")

        self.is_compiled_ = True
54+
3955 def share_from (self , other_graph ):
4056 self .graph_ .share_from (other_graph .graph_ )
4157 self .is_shared_from_ = True
@@ -52,6 +68,7 @@ def __init__(self, cache_size):
5268 self .cache_size = cache_size
5369 self .queue = deque ()
5470 self .hash_map = dict ()
71+ self .cnt = 0
5572
5673 def front (self ):
5774 if self .is_empty ():
@@ -81,6 +98,8 @@ def set(self, key, value):
8198 pop_key = self .pop ()
8299
83100 self .queue .appendleft (key )
101+ value ._oneflow_graph_cache_order = self .cnt
102+ self .cnt += 1
84103 self .hash_map [key ] = value
85104 return pop_key if pop_key is not None else key
86105
@@ -120,11 +139,46 @@ def enable_share_mem(self, enabled=True):
def enable_save_graph(self, enabled=True):
    """Switch graph saving on or off.

    The value is stored as-is in ``enable_save_graph_``; ``save_graph``
    writes nothing while it is falsy.
    """
    self.enable_save_graph_ = enabled
122141
def enable_load_graph(self, enabled=True):
    """Switch graph loading on or off.

    The value is stored as-is in ``enable_load_graph_``; ``load_graph``
    does nothing while it is falsy.
    """
    self.enable_load_graph_ = enabled
144+
def save_graph(self, path):
    """Write every cached graph's runtime state under ``path``.

    Does nothing unless saving has been enabled. For each cached graph the
    runtime state dict is extended with the graph's cache order, its cache
    key and its graph class name, then saved with ``flow.save`` to
    ``<path>/<graph_class_name>_<hash(key)>``.
    """
    if not self.enable_save_graph_:
        return

    for graph_class_name, cache in self.cache_bucket_.items():
        for key, graph in cache.pairs():
            state_dict = graph.graph_.runtime_state_dict()
            # Bookkeeping stored next to the graph state so the cache entry
            # can be re-created on load.
            state_dict["cache_order"] = graph._oneflow_graph_cache_order
            state_dict["cache_key"] = key
            state_dict["graph_class_name"] = graph_class_name

            # NOTE(review): hash() of keys containing str is only stable
            # across interpreter runs with a fixed PYTHONHASHSEED, so these
            # names may differ between runs - confirm this is acceptable.
            target = os.path.join(path, f"{graph_class_name}_{hash(key)}")
            flow.save(state_dict, target)
154+
def load_graph(self, path):
    """Rebuild the graph caches from state dicts stored under ``path``.

    Does nothing unless loading has been enabled. Every immediate
    sub-directory of ``path`` is read with ``flow.load`` and must contain
    the ``cache_order``, ``cache_key`` and ``graph_class_name`` entries
    that ``save_graph`` adds. Graphs are re-inserted into the cache for
    their class name in ascending ``cache_order``.
    """
    if not self.enable_load_graph_:
        return

    # Close the directory handle deterministically.
    with os.scandir(path) as entries:
        sub_folders = [entry.path for entry in entries if entry.is_dir()]

    # Keep every loaded state dict in a list. ``cache_order`` is assigned
    # from a counter held by each cache object, and there is one cache per
    # graph class name, so graphs of different classes can share the same
    # order value. Keying a dict by ``cache_order`` alone would silently
    # drop all but one of them.
    state_dicts = [
        flow.load(sub_folder, map_location="cuda") for sub_folder in sub_folders
    ]
    # Stable sort on the order only; the dicts themselves are never compared.
    state_dicts.sort(key=lambda loaded: loaded["cache_order"])

    for state_dict in state_dicts:
        graph_class_name = state_dict["graph_class_name"]
        cache_key = state_dict["cache_key"]
        if graph_class_name not in self.cache_bucket_:
            self.cache_bucket_[graph_class_name] = LRUCache(self.cache_size_)
        # TODO(): release eager vae/unet module
        compile_cache = self.cache_bucket_[graph_class_name]
        graph = OneFlowGraph(flow.nn.Graph)
        if self.enable_share_mem_ is True:
            if graph_class_name in self.share_origin_:
                # A graph of this class was loaded earlier: share from it.
                graph.share_from(self.share_origin_[graph_class_name])
            else:
                # First graph of this class becomes the share origin.
                self.share_origin_[graph_class_name] = graph
                graph.graph_.enable_shared()

        graph.load_runtime_state_dict(state_dict)
        ret = compile_cache.set(cache_key, graph)
        assert ret is not None
128182
129183 def get_graph (self , graph_class , cache_key , * args , ** kwargs ):
130184 graph_class_name = graph_class .__name__
0 commit comments