Skip to content

Commit 795f53c

Browse files
committed
fix device and order
1 parent 2aa5a35 commit 795f53c

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/diffusers/oneflow_graph_compile_cache.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ def __call__(self, *args, **kwargs):
6464

6565

6666
class LRUCache(object):
67+
_cnt: int = 0
6768
def __init__(self, cache_size):
6869
self.cache_size = cache_size
6970
self.queue = deque()
7071
self.hash_map = dict()
71-
self.cnt = 0
7272

7373
def front(self):
7474
if self.is_empty():
@@ -98,8 +98,8 @@ def set(self, key, value):
9898
pop_key = self.pop()
9999

100100
self.queue.appendleft(key)
101-
value._oneflow_graph_cache_order = self.cnt
102-
self.cnt += 1
101+
value._oneflow_graph_cache_order = LRUCache._cnt
102+
LRUCache._cnt += 1
103103
self.hash_map[key] = value
104104
return pop_key if pop_key is not None else key
105105

@@ -157,11 +157,13 @@ def load_graph(self, path):
157157
sub_folders = [ f.path for f in os.scandir(path) if f.is_dir() ]
158158
graph_dict = dict()
159159
for sub_folder in sub_folders:
160-
state_dict = flow.load(sub_folder, map_location="cuda")
160+
state_dict = flow.load(sub_folder)
161161
cache_order = state_dict["cache_order"]
162+
print("===> order", cache_order)
162163
graph_dict[cache_order] = state_dict
163164

164165
for order, state_dict in sorted(graph_dict.items()):
166+
print("===> load order", order)
165167
graph_class_name = state_dict["graph_class_name"]
166168
cache_key = state_dict["cache_key"]
167169
if graph_class_name not in self.cache_bucket_:

src/diffusers/utils/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import os
17+
import time
1718

1819
from .deprecation_utils import deprecate
1920
from .import_utils import (
@@ -87,3 +88,19 @@
8788
"EulerAncestralDiscreteScheduler",
8889
"DPMSolverMultistepScheduler",
8990
]
91+
92+
def cost_cnt(fn):
93+
import oneflow as flow
94+
def new_fn(*args, **kwargs):
95+
print("==>", fn.__name__, " try to run")
96+
before_used = flow._oneflow_internal.GetCUDAMemoryUsed()
97+
start_time = time.time()
98+
out = fn(*args, **kwargs)
99+
end_time = time.time()
100+
after_used = flow._oneflow_internal.GetCUDAMemoryUsed()
101+
print(fn.__name__, " run time ", end_time - start_time)
102+
print(fn.__name__, " cuda mem", after_used - before_used)
103+
print("<==", fn.__name__, " finish run")
104+
return out
105+
106+
return new_fn

0 commit comments

Comments
 (0)