
Commit e5d714d

Jonas Müller and qp-qp authored
Improved sampling (Stability-AI#69)
* New research features
* Add new model specs
* remove sd1.5 and change default refiner to 1.0
* remove asking second time for output
* adapt model names
* adjusted strength
* Correctly pass prompt

Co-authored-by: Dominik Lorenz <[email protected]>
1 parent f2fa96b commit e5d714d

2 files changed (+514, -93 lines)


scripts/demo/sampling.py

Lines changed: 72 additions & 57 deletions
@@ -1,14 +1,6 @@
-import numpy as np
 from pytorch_lightning import seed_everything

 from scripts.demo.streamlit_helpers import *
-from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
-from sgm.inference.helpers import (
-    do_img2img,
-    do_sample,
-    get_unique_embedder_keys_from_conditioner,
-    perform_save_locally,
-)

 SAVE_PATH = "outputs/demo/txt2img/"

@@ -42,27 +34,34 @@
 }

 VERSION2SPECS = {
-    "SD-XL base": {
+    "SDXL-base-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": False,
+        "config": "configs/inference/sd_xl_base.yaml",
+        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
+    },
+    "SDXL-base-0.9": {
         "H": 1024,
         "W": 1024,
         "C": 4,
         "f": 8,
         "is_legacy": False,
         "config": "configs/inference/sd_xl_base.yaml",
         "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
-        "is_guided": True,
     },
-    "sd-2.1": {
+    "SD-2.1": {
         "H": 512,
         "W": 512,
         "C": 4,
         "f": 8,
         "is_legacy": True,
         "config": "configs/inference/sd_2_1.yaml",
         "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
-        "is_guided": True,
     },
-    "sd-2.1-768": {
+    "SD-2.1-768": {
         "H": 768,
         "W": 768,
         "C": 4,
@@ -71,15 +70,23 @@
         "config": "configs/inference/sd_2_1_768.yaml",
         "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
     },
-    "SDXL-Refiner": {
+    "SDXL-refiner-0.9": {
         "H": 1024,
         "W": 1024,
         "C": 4,
         "f": 8,
         "is_legacy": True,
         "config": "configs/inference/sd_xl_refiner.yaml",
         "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
-        "is_guided": True,
+    },
+    "SDXL-refiner-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": True,
+        "config": "configs/inference/sd_xl_refiner.yaml",
+        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
     },
 }

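Note (not part of the diff): the spec fields above map directly onto the latent tensor the samplers operate on. "C" is the number of latent channels and "f" the autoencoder downsampling factor, so a spec's latents have spatial size H/f x W/f. A minimal sketch of that arithmetic; the latent_shape helper is hypothetical and only illustrates the relationship:

spec = {"H": 1024, "W": 1024, "C": 4, "f": 8}  # e.g. the SDXL-base-1.0 entry

def latent_shape(spec, num_samples=1):
    # a 1024 x 1024 image with f=8 and C=4 -> (num_samples, 4, 128, 128) latents
    return (num_samples, spec["C"], spec["H"] // spec["f"], spec["W"] // spec["f"])

print(latent_shape(spec))  # (1, 4, 128, 128)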

@@ -103,18 +110,19 @@ def load_img(display=True, key=None, device="cuda"):


 def run_txt2img(
-    state, version, version_dict, is_legacy=False, return_latents=False, filter=None
+    state,
+    version,
+    version_dict,
+    is_legacy=False,
+    return_latents=False,
+    filter=None,
+    stage2strength=None,
 ):
-    if version == "SD-XL base":
-        ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
-        W, H = SD_XL_BASE_RATIOS[ratio]
+    if version.startswith("SDXL-base"):
+        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
     else:
-        H = st.sidebar.number_input(
-            "H", value=version_dict["H"], min_value=64, max_value=2048
-        )
-        W = st.sidebar.number_input(
-            "W", value=version_dict["W"], min_value=64, max_value=2048
-        )
+        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
+        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
     C = version_dict["C"]
     F = version_dict["f"]

@@ -130,16 +138,11 @@ def run_txt2img(
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    num_rows, num_cols, sampler = init_sampling(
-        use_identity_guider=not version_dict["is_guided"]
-    )
-
+    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
     num_samples = num_rows * num_cols

     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
-        outputs = st.empty()
-        st.text("Sampling")
         out = do_sample(
             state["model"],
             sampler,
@@ -153,13 +156,16 @@ def run_txt2img(
             return_latents=return_latents,
             filter=filter,
         )
-        show_samples(out, outputs)
-
         return out


 def run_img2img(
-    state, version_dict, is_legacy=False, return_latents=False, filter=None
+    state,
+    version_dict,
+    is_legacy=False,
+    return_latents=False,
+    filter=None,
+    stage2strength=None,
 ):
     img = load_img()
     if img is None:
@@ -175,19 +181,19 @@ def run_img2img(
     value_dict = init_embedder_options(
         get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
         init_dict,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
     )
     strength = st.number_input(
-        "**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
+        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
     )
-    num_rows, num_cols, sampler = init_sampling(
+    sampler, num_rows, num_cols = init_sampling(
         img2img_strength=strength,
-        use_identity_guider=not version_dict["is_guided"],
+        stage2strength=stage2strength,
     )
     num_samples = num_rows * num_cols

     if st.button("Sample"):
-        outputs = st.empty()
-        st.text("Sampling")
         out = do_img2img(
             repeat(img, "1 ... -> n ...", n=num_samples),
             state["model"],
@@ -198,7 +204,6 @@ def run_img2img(
             return_latents=return_latents,
             filter=filter,
         )
-        show_samples(out, outputs)
         return out


@@ -210,6 +215,7 @@ def apply_refiner(
     prompt,
     negative_prompt,
     filter=None,
+    finish_denoising=False,
 ):
     init_dict = {
         "orig_width": input.shape[3] * 8,
@@ -237,6 +243,7 @@ def apply_refiner(
         num_samples,
         skip_encode=True,
         filter=filter,
+        add_noise=not finish_denoising,
     )

     return samples
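
Note (not part of the diff): the new finish_denoising flag pairs with the stage2strength plumbing added further down. When the checkbox is enabled, the base stage stops its schedule early and the refiner continues from those latents without injecting fresh noise (add_noise=not finish_denoising); when it is disabled, stage2strength is reset to None and the refiner re-noises the base latents to the chosen img2img strength before denoising them again. A small sketch of that switch, with the Streamlit inputs stubbed out as constants; the semantics are inferred from this diff, not from documented API:

finish_denoising = True   # "Finish denoising with refiner." checkbox in the demo
stage2strength = 0.15     # "Refinement strength" number input

add_noise = not finish_denoising  # forwarded to do_img2img inside apply_refiner
if not finish_denoising:
    stage2strength = None         # base stage then runs its full noise schedule
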
@@ -249,20 +256,22 @@ def apply_refiner(
 mode = st.radio("Mode", ("txt2img", "img2img"), 0)
 st.write("__________________________")

-if version == "SD-XL base":
-    add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
+set_lowvram_mode(st.checkbox("Low vram mode", True))
+
+if version.startswith("SDXL-base"):
+    add_pipeline = st.checkbox("Load SDXL-refiner?", False)
     st.write("__________________________")
 else:
     add_pipeline = False

-filter = DeepFloydDataFiltering(verbose=False)
-
 seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
 seed_everything(seed)

 save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))

-state = init_st(version_dict)
+state = init_st(version_dict, load_filter=True)
+if state["msg"]:
+    st.info(state["msg"])
 model = state["model"]

 is_legacy = version_dict["is_legacy"]
@@ -276,29 +285,34 @@ def apply_refiner(
 else:
     negative_prompt = ""  # which is unused

+stage2strength = None
+finish_denoising = False
+
 if add_pipeline:
     st.write("__________________________")
-
-    version2 = "SDXL-Refiner"
+    version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
     st.warning(
         f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
     )
     st.write("**Refiner Options:**")

     version_dict2 = VERSION2SPECS[version2]
-    state2 = init_st(version_dict2)
+    state2 = init_st(version_dict2, load_filter=False)
+    st.info(state2["msg"])

     stage2strength = st.number_input(
-        "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
+        "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
     )

-    sampler2 = init_sampling(
+    sampler2, *_ = init_sampling(
         key=2,
         img2img_strength=stage2strength,
-        use_identity_guider=not version_dict2["is_guided"],
-        get_num_samples=False,
+        specify_num_samples=False,
     )
     st.write("__________________________")
+    finish_denoising = st.checkbox("Finish denoising with refiner.", True)
+    if not finish_denoising:
+        stage2strength = None

 if mode == "txt2img":
     out = run_txt2img(
@@ -307,15 +321,17 @@ def apply_refiner(
         version_dict,
         is_legacy=is_legacy,
         return_latents=add_pipeline,
-        filter=filter,
+        filter=state.get("filter"),
+        stage2strength=stage2strength,
     )
 elif mode == "img2img":
     out = run_img2img(
         state,
         version_dict,
         is_legacy=is_legacy,
         return_latents=add_pipeline,
-        filter=filter,
+        filter=state.get("filter"),
+        stage2strength=stage2strength,
     )
 else:
     raise ValueError(f"unknown mode {mode}")
@@ -326,7 +342,6 @@ def apply_refiner(
     samples_z = None

 if add_pipeline and samples_z is not None:
-    outputs = st.empty()
     st.write("**Running Refinement Stage**")
     samples = apply_refiner(
         samples_z,
@@ -335,9 +350,9 @@ def apply_refiner(
         samples_z.shape[0],
         prompt=prompt,
         negative_prompt=negative_prompt if is_legacy else "",
-        filter=filter,
+        filter=state.get("filter"),
+        finish_denoising=finish_denoising,
     )
-    show_samples(samples, outputs)

 if save_locally and samples is not None:
     perform_save_locally(save_path, samples)
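
Note (not part of the diff): putting the pieces together, the base-plus-refiner control flow this commit wires up looks roughly like the following. Function names come from the diff above, but the argument lists are abbreviated and partly assumed (state2 and sampler2 stand in for the refiner state and sampler set up earlier, and the unpacking of out is an assumption), so treat this as pseudocode rather than a drop-in script:

out = run_txt2img(
    state,
    version,
    version_dict,
    return_latents=add_pipeline,        # keep latents around for the refiner
    filter=state.get("filter"),
    stage2strength=stage2strength,      # lets the base stage stop early
)
samples, samples_z = out if add_pipeline else (out, None)

if add_pipeline and samples_z is not None:
    samples = apply_refiner(
        samples_z,
        state2,
        sampler2,
        samples_z.shape[0],
        prompt=prompt,
        negative_prompt=negative_prompt if is_legacy else "",
        filter=state.get("filter"),
        finish_denoising=finish_denoising,  # continue denoising instead of re-noising
    )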
