@@ -1,14 +1,6 @@
-import numpy as np
 from pytorch_lightning import seed_everything
 
 from scripts.demo.streamlit_helpers import *
-from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
-from sgm.inference.helpers import (
-    do_img2img,
-    do_sample,
-    get_unique_embedder_keys_from_conditioner,
-    perform_save_locally,
-)
 
 SAVE_PATH = "outputs/demo/txt2img/"
 
@@ -42,44 +34,59 @@
 }
 
 VERSION2SPECS = {
-    "SD-XL base": {
+    "SDXL-base-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": False,
+        "config": "configs/inference/sd_xl_base.yaml",
+        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
+    },
+    "SDXL-base-0.9": {
         "H": 1024,
         "W": 1024,
         "C": 4,
         "f": 8,
         "is_legacy": False,
         "config": "configs/inference/sd_xl_base.yaml",
         "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
-        "is_guided": True,
     },
-    "sd-2.1": {
+    "SD-2.1": {
         "H": 512,
         "W": 512,
         "C": 4,
         "f": 8,
         "is_legacy": True,
         "config": "configs/inference/sd_2_1.yaml",
         "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
-        "is_guided": True,
     },
-    "sd-2.1-768": {
+    "SD-2.1-768": {
         "H": 768,
         "W": 768,
         "C": 4,
         "f": 8,
         "is_legacy": True,
         "config": "configs/inference/sd_2_1_768.yaml",
         "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
     },
-    "SDXL-Refiner": {
+    "SDXL-refiner-0.9": {
         "H": 1024,
         "W": 1024,
         "C": 4,
         "f": 8,
         "is_legacy": True,
         "config": "configs/inference/sd_xl_refiner.yaml",
         "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
-        "is_guided": True,
+    },
+    "SDXL-refiner-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": True,
+        "config": "configs/inference/sd_xl_refiner.yaml",
+        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
     },
 }
 
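
For orientation, a VERSION2SPECS entry is consumed by the demo roughly as below. This is an illustrative sketch, not an excerpt from the commit; init_st and st come from the star import of scripts.demo.streamlit_helpers, and the "model"/"msg"/"filter" keys are exactly the ones this diff reads from its return value.

# Illustrative sketch only: how an entry of VERSION2SPECS is typically used.
version = "SDXL-base-1.0"
version_dict = VERSION2SPECS[version]            # H/W/C/f, config path, checkpoint path
state = init_st(version_dict, load_filter=True)  # loads config + ckpt, optionally the image filter
if state["msg"]:
    st.info(state["msg"])                        # surface any loading message in the UI
model = state["model"]
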
@@ -103,18 +110,19 @@ def load_img(display=True, key=None, device="cuda"):
 
 
 def run_txt2img(
-    state, version, version_dict, is_legacy=False, return_latents=False, filter=None
+    state,
+    version,
+    version_dict,
+    is_legacy=False,
+    return_latents=False,
+    filter=None,
+    stage2strength=None,
 ):
-    if version == "SD-XL base":
-        ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
-        W, H = SD_XL_BASE_RATIOS[ratio]
+    if version.startswith("SDXL-base"):
+        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
     else:
-        H = st.sidebar.number_input(
-            "H", value=version_dict["H"], min_value=64, max_value=2048
-        )
-        W = st.sidebar.number_input(
-            "W", value=version_dict["W"], min_value=64, max_value=2048
-        )
+        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
+        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
     C = version_dict["C"]
     F = version_dict["f"]
 
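
The new resolution picker selects a (W, H) tuple directly from SD_XL_BASE_RATIOS (defined in scripts/demo/streamlit_helpers.py) rather than selecting a ratio key and looking it up; the trailing 10 is the default index into that list. A minimal sketch of the pattern follows, with the table contents abbreviated to assumed placeholder values.

# Sketch of the selectbox-over-values pattern; the real SD_XL_BASE_RATIOS has many more entries.
SD_XL_BASE_RATIOS = {"1.0": (1024, 1024), "1.29": (1152, 896)}  # placeholder subset, values assumed
W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()))  # demo passes 10 as the default index
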
@@ -130,16 +138,11 @@ def run_txt2img(
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    num_rows, num_cols, sampler = init_sampling(
-        use_identity_guider=not version_dict["is_guided"]
-    )
-
+    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
     num_samples = num_rows * num_cols
 
     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
-        outputs = st.empty()
-        st.text("Sampling")
         out = do_sample(
             state["model"],
             sampler,
@@ -153,13 +156,16 @@ def run_txt2img(
             return_latents=return_latents,
             filter=filter,
         )
-        show_samples(out, outputs)
-
         return out
 
 
 def run_img2img(
-    state, version_dict, is_legacy=False, return_latents=False, filter=None
+    state,
+    version_dict,
+    is_legacy=False,
+    return_latents=False,
+    filter=None,
+    stage2strength=None,
 ):
     img = load_img()
     if img is None:
@@ -175,19 +181,19 @@ def run_img2img(
     value_dict = init_embedder_options(
         get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
         init_dict,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
     )
     strength = st.number_input(
-        "**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
+        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
     )
-    num_rows, num_cols, sampler = init_sampling(
+    sampler, num_rows, num_cols = init_sampling(
         img2img_strength=strength,
-        use_identity_guider=not version_dict["is_guided"],
+        stage2strength=stage2strength,
     )
     num_samples = num_rows * num_cols
 
     if st.button("Sample"):
-        outputs = st.empty()
-        st.text("Sampling")
         out = do_img2img(
             repeat(img, "1 ... -> n ...", n=num_samples),
             state["model"],
@@ -198,7 +204,6 @@ def run_img2img(
             return_latents=return_latents,
             filter=filter,
         )
-        show_samples(out, outputs)
         return out
 
 
@@ -210,6 +215,7 @@ def apply_refiner(
     prompt,
     negative_prompt,
     filter=None,
+    finish_denoising=False,
 ):
     init_dict = {
         "orig_width": input.shape[3] * 8,
@@ -237,6 +243,7 @@ def apply_refiner(
         num_samples,
         skip_encode=True,
         filter=filter,
+        add_noise=not finish_denoising,
     )
 
     return samples
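
finish_denoising and stage2strength work as a pair elsewhere in this diff: when the "Finish denoising with refiner." box is ticked, stage2strength is handed to the first-stage sampler (init_sampling(stage2strength=...)) so the base model stops before the latent is fully denoised, and add_noise=not finish_denoising lets the refiner continue from that latent without re-noising it; when the box is unticked, stage2strength is reset to None and the refiner acts as an ordinary img2img pass at its own strength. Below is a rough, hypothetical sketch of the resulting two-stage call pattern, not an excerpt; argument lists are abbreviated and names (state, state2, sampler2, prompt) refer to objects built elsewhere in the script.

# Rough sketch of the two-stage wiring introduced by this commit (illustrative only).
stage2strength = 0.15 if finish_denoising else None           # refinement strength from the sidebar
out = run_txt2img(state, version, version_dict,
                  return_latents=True,                        # keep latents for the refiner
                  filter=state.get("filter"),
                  stage2strength=stage2strength)              # base stage may stop denoising early
samples, samples_z = out                                      # latents are returned alongside images here
samples = apply_refiner(samples_z, state2, sampler2, samples_z.shape[0],
                        prompt=prompt, negative_prompt="",
                        filter=state.get("filter"),
                        finish_denoising=finish_denoising)    # -> add_noise=not finish_denoising
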
@@ -249,20 +256,22 @@ def apply_refiner(
     mode = st.radio("Mode", ("txt2img", "img2img"), 0)
     st.write("__________________________")
 
-    if version == "SD-XL base":
-        add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
+    set_lowvram_mode(st.checkbox("Low vram mode", True))
+
+    if version.startswith("SDXL-base"):
+        add_pipeline = st.checkbox("Load SDXL-refiner?", False)
         st.write("__________________________")
     else:
         add_pipeline = False
 
-    filter = DeepFloydDataFiltering(verbose=False)
-
     seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
     seed_everything(seed)
 
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
 
-    state = init_st(version_dict)
+    state = init_st(version_dict, load_filter=True)
+    if state["msg"]:
+        st.info(state["msg"])
     model = state["model"]
 
     is_legacy = version_dict["is_legacy"]
@@ -276,29 +285,34 @@ def apply_refiner(
     else:
         negative_prompt = ""  # which is unused
 
+    stage2strength = None
+    finish_denoising = False
+
     if add_pipeline:
         st.write("__________________________")
-
-        version2 = "SDXL-Refiner"
+        version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
         st.warning(
             f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
         )
         st.write("**Refiner Options:**")
 
         version_dict2 = VERSION2SPECS[version2]
-        state2 = init_st(version_dict2)
+        state2 = init_st(version_dict2, load_filter=False)
+        st.info(state2["msg"])
 
         stage2strength = st.number_input(
-            "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
+            "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
         )
 
-        sampler2 = init_sampling(
+        sampler2, *_ = init_sampling(
            key=2,
            img2img_strength=stage2strength,
-            use_identity_guider=not version_dict2["is_guided"],
-            get_num_samples=False,
+            specify_num_samples=False,
        )
        st.write("__________________________")
+        finish_denoising = st.checkbox("Finish denoising with refiner.", True)
+        if not finish_denoising:
+            stage2strength = None
 
     if mode == "txt2img":
         out = run_txt2img(
@@ -307,15 +321,17 @@ def apply_refiner(
             version_dict,
             is_legacy=is_legacy,
             return_latents=add_pipeline,
-            filter=filter,
+            filter=state.get("filter"),
+            stage2strength=stage2strength,
         )
     elif mode == "img2img":
         out = run_img2img(
             state,
             version_dict,
             is_legacy=is_legacy,
             return_latents=add_pipeline,
-            filter=filter,
+            filter=state.get("filter"),
+            stage2strength=stage2strength,
         )
     else:
         raise ValueError(f"unknown mode {mode}")
@@ -326,7 +342,6 @@ def apply_refiner(
         samples_z = None
 
     if add_pipeline and samples_z is not None:
-        outputs = st.empty()
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
             samples_z,
@@ -335,9 +350,9 @@ def apply_refiner(
             samples_z.shape[0],
             prompt=prompt,
             negative_prompt=negative_prompt if is_legacy else "",
-            filter=filter,
+            filter=state.get("filter"),
+            finish_denoising=finish_denoising,
         )
-        show_samples(samples, outputs)
 
     if save_locally and samples is not None:
         perform_save_locally(save_path, samples)