11import oneflow as torch
22import time
3+ import os
4+ import shutil
5+
36from diffusers import (
47 OneFlowStableDiffusionPipeline as StableDiffusionPipeline ,
58 OneFlowEulerDiscreteScheduler as EulerDiscreteScheduler ,
69)
710from diffusers import utils
811
912model_id = "stabilityai/stable-diffusion-2"
13+ _graph_save_file = "./test_sd_save_graph"
14+ _sch_file_path = "./test_sd_sch"
15+ _pipe_file_path = "./test_sd_pipe"
1016
11- _offline_compile = False
17+ _online_mode = True
18+ _pipe_from_file = True
1219
1320total_start_t = time .time ()
1421start_t = time .time ()
15- # StableDiffusionPipeline 需要支持 unet 和 vae load graph, 此时无需创建 eager module
1622@utils .cost_cnt
1723def get_pipe ():
18- scheduler = EulerDiscreteScheduler .from_pretrained (model_id , subfolder = "scheduler" )
19- sd_pipe = StableDiffusionPipeline .from_pretrained (
20- model_id , scheduler = scheduler , revision = "fp16" , torch_dtype = torch .float16
21- )
24+ if _pipe_from_file :
25+ scheduler = EulerDiscreteScheduler .from_pretrained (_sch_file_path , subfolder = "scheduler" )
26+ sd_pipe = StableDiffusionPipeline .from_pretrained (
27+ _pipe_file_path , scheduler = scheduler , revision = "fp16" , torch_dtype = torch .float16
28+ )
29+ else :
30+ scheduler = EulerDiscreteScheduler .from_pretrained (model_id , subfolder = "scheduler" )
31+ sd_pipe = StableDiffusionPipeline .from_pretrained (
32+ model_id , scheduler = scheduler , revision = "fp16" , torch_dtype = torch .float16
33+ )
2234 torch ._oneflow_internal .eager .Sync ()
23- return sd_pipe
24- pipe = get_pipe ()
35+ return scheduler , sd_pipe
36+ sch , pipe = get_pipe ()
2537
2638@utils .cost_cnt
2739def pipe_to_cuda ():
@@ -37,13 +49,14 @@ def config_graph():
3749 torch ._oneflow_internal .eager .Sync ()
3850config_graph ()
3951
40- if _offline_compile :
52+ if not _online_mode :
4153 pipe .enable_save_graph ()
4254else :
4355 @utils .cost_cnt
4456 def load_graph ():
4557 pipe .enable_load_graph ()
46- pipe .load_graph ("./test_save_load" , compile_unet = True , compile_vae = False )
58+ assert (os .path .exists (_graph_save_file ) and os .path .isdir (_graph_save_file ))
59+ pipe .load_graph (_graph_save_file , compile_unet = True , compile_vae = False )
4760 torch ._oneflow_internal .eager .Sync ()
4861 load_graph ()
4962end_t = time .time ()
@@ -85,9 +98,19 @@ def text_to_image(prompt, image_size, num_images_per_prompt=1, prefix=""):
8598total_end_t = time .time ()
8699print ("st init and run time " , total_end_t - total_start_t , 's.' )
87100
101+ @utils .cost_cnt
102+ def save_pipe_sch ():
103+ pipe .save_pretrained (_pipe_file_path )
104+ sch .save_pretrained (_sch_file_path )
105+
88106@utils .cost_cnt
89107def save_graph ():
90- pipe .save_graph ("./test_save_load" )
108+ if os .path .exists (_graph_save_file ) and os .path .isdir (_graph_save_file ):
109+ shutil .rmtree (_graph_save_file )
110+ os .makedirs (_graph_save_file )
111+
112+ pipe .save_graph (_graph_save_file )
91113
92- if _offline_compile :
114+ if not _online_mode :
115+ save_pipe_sch ()
93116 save_graph ()
0 commit comments