@@ -64,11 +64,11 @@ def __call__(self, *args, **kwargs):
6464
6565
6666class LRUCache (object ):
67+ _cnt : int = 0
6768 def __init__ (self , cache_size ):
6869 self .cache_size = cache_size
6970 self .queue = deque ()
7071 self .hash_map = dict ()
71- self .cnt = 0
7272
7373 def front (self ):
7474 if self .is_empty ():
@@ -98,8 +98,8 @@ def set(self, key, value):
9898 pop_key = self .pop ()
9999
100100 self .queue .appendleft (key )
101- value ._oneflow_graph_cache_order = self . cnt
102- self . cnt += 1
101+ value ._oneflow_graph_cache_order = LRUCache . _cnt
102+ LRUCache . _cnt += 1
103103 self .hash_map [key ] = value
104104 return pop_key if pop_key is not None else key
105105
@@ -157,11 +157,13 @@ def load_graph(self, path):
157157 sub_folders = [ f .path for f in os .scandir (path ) if f .is_dir () ]
158158 graph_dict = dict ()
159159 for sub_folder in sub_folders :
160- state_dict = flow .load (sub_folder , map_location = "cuda" )
160+ state_dict = flow .load (sub_folder )
161161 cache_order = state_dict ["cache_order" ]
162+ print ("===> order" , cache_order )
162163 graph_dict [cache_order ] = state_dict
163164
164165 for order , state_dict in sorted (graph_dict .items ()):
166+ print ("===> load order" , order )
165167 graph_class_name = state_dict ["graph_class_name" ]
166168 cache_key = state_dict ["cache_key" ]
167169 if graph_class_name not in self .cache_bucket_ :
0 commit comments