Skip to content

Commit 07ef485

Browse files
authored
[Community, Enhancement] Add reference tricks in README (huggingface#3589)
add reference tricks
1 parent 6cbddf5 commit 07ef485

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

examples/community/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,8 @@ image.save('tensorrt_img2img_new_zealand_hills.png')
13261326

13271327
This pipeline uses the Reference Control. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236)[sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280).
13281328

1329+
Based on [this issue](https://github.com/huggingface/diffusers/issues/3566),
1330+
- `EulerAncestralDiscreteScheduler` got poor results.
13291331

13301332
```py
13311333
import torch
@@ -1369,6 +1371,9 @@ Output Image of `reference_attn=True` and `reference_adain=True`
13691371

13701372
This pipeline uses the Reference Control with ControlNet. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236)[sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280).
13711373

1374+
Based on [this issue](https://github.com/huggingface/diffusers/issues/3566),
1375+
- `EulerAncestralDiscreteScheduler` got poor results.
1376+
- `guess_mode=True` works well for ControlNet v1.1
13721377

13731378
```py
13741379
import cv2

examples/community/stable_diffusion_controlnet_reference.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,8 @@ def hack_CrossAttnDownBlock2D_forward(
505505
if MODE == "write":
506506
if gn_auto_machine_weight >= self.gn_weight:
507507
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
508-
self.mean_bank.append(mean)
509-
self.var_bank.append(var)
508+
self.mean_bank.append([mean])
509+
self.var_bank.append([var])
510510
if MODE == "read":
511511
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
512512
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
@@ -545,8 +545,8 @@ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
545545
if MODE == "write":
546546
if gn_auto_machine_weight >= self.gn_weight:
547547
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
548-
self.mean_bank.append(mean)
549-
self.var_bank.append(var)
548+
self.mean_bank.append([mean])
549+
self.var_bank.append([var])
550550
if MODE == "read":
551551
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
552552
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
@@ -605,8 +605,8 @@ def hacked_CrossAttnUpBlock2D_forward(
605605
if MODE == "write":
606606
if gn_auto_machine_weight >= self.gn_weight:
607607
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
608-
self.mean_bank.append(mean)
609-
self.var_bank.append(var)
608+
self.mean_bank.append([mean])
609+
self.var_bank.append([var])
610610
if MODE == "read":
611611
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
612612
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
@@ -642,8 +642,8 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
642642
if MODE == "write":
643643
if gn_auto_machine_weight >= self.gn_weight:
644644
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
645-
self.mean_bank.append(mean)
646-
self.var_bank.append(var)
645+
self.mean_bank.append([mean])
646+
self.var_bank.append([var])
647647
if MODE == "read":
648648
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
649649
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)

examples/community/stable_diffusion_reference.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,8 @@ def hack_CrossAttnDownBlock2D_forward(
499499
if MODE == "write":
500500
if gn_auto_machine_weight >= self.gn_weight:
501501
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
502-
self.mean_bank.append(mean)
503-
self.var_bank.append(var)
502+
self.mean_bank.append([mean])
503+
self.var_bank.append([var])
504504
if MODE == "read":
505505
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
506506
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
@@ -539,8 +539,8 @@ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
539539
if MODE == "write":
540540
if gn_auto_machine_weight >= self.gn_weight:
541541
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
542-
self.mean_bank.append(mean)
543-
self.var_bank.append(var)
542+
self.mean_bank.append([mean])
543+
self.var_bank.append([var])
544544
if MODE == "read":
545545
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
546546
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
@@ -599,8 +599,8 @@ def hacked_CrossAttnUpBlock2D_forward(
599599
if MODE == "write":
600600
if gn_auto_machine_weight >= self.gn_weight:
601601
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
602-
self.mean_bank.append(mean)
603-
self.var_bank.append(var)
602+
self.mean_bank.append([mean])
603+
self.var_bank.append([var])
604604
if MODE == "read":
605605
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
606606
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
@@ -636,8 +636,8 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
636636
if MODE == "write":
637637
if gn_auto_machine_weight >= self.gn_weight:
638638
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
639-
self.mean_bank.append(mean)
640-
self.var_bank.append(var)
639+
self.mean_bank.append([mean])
640+
self.var_bank.append([var])
641641
if MODE == "read":
642642
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
643643
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)

0 commit comments

Comments
 (0)