Skip to content

Commit 47d3406

Browse files
committed
T2I-Adapter diffusers implementation
original: https://github.com/TencentARC/T2I-Adapter
1 parent 45572c2 commit 47d3406

File tree

7 files changed

+477
-3
lines changed

7 files changed

+477
-3
lines changed

scripts/t2i_adapter_tester.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import cv2
2+
import argparse
3+
#import importlib
4+
import random
5+
import copy
6+
import torch
7+
import sys
8+
import os
9+
root_path = os.getcwd()
10+
print(root_path )
11+
sys.path.append(f"{root_path}/src")
12+
import diffusers
13+
from PIL import Image
14+
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
15+
from extra.t2iadapter.adapter import Adapter
16+
from basicsr.utils import img2tensor, tensor2img, scandir, get_time_str, get_root_logger, get_env_info
17+
18+
diffusers.utils.logging.disable_progress_bar()
19+
20+
class DummySafetyChecker:
    """Stand-in safety checker that never flags anything."""

    def safety_checker(self, images, *args, **kwargs):
        """Return ``images`` untouched plus one ``False`` flag per image.

        Any extra positional or keyword arguments are accepted and ignored.
        """
        flags = [False] * len(images)
        return images, flags
23+
24+
def loadmodel(pipeline_name, model_path, **kwargs):
    """Load a ``StableDiffusionPipeline`` from ``model_path`` and move it to CUDA.

    The pipeline is always requested in ``torch.float16`` with the safety
    checker replaced by ``DummySafetyChecker().safety_checker``. When
    ``pipeline_name`` is ``"StableDiffusionPipeline"`` the ``"fp16"`` revision
    is requested as well. These forced options override any same-named
    entries in ``kwargs``; everything else in ``kwargs`` is forwarded to
    ``from_pretrained`` unchanged.
    """
    print("load pipeline")
    print("load model from:", pipeline_name, model_path)

    forced = {"torch_dtype": torch.float16}
    if pipeline_name == "StableDiffusionPipeline":
        forced["revision"] = "fp16"
    forced["safety_checker"] = DummySafetyChecker().safety_checker
    kwargs.update(forced)

    return StableDiffusionPipeline.from_pretrained(model_path, **kwargs).to("cuda")
37+
38+
39+
def generation(pipe, prompt, seed, features_adapter=None):
    """Run ``pipe`` once for ``prompt`` and return the ``.images`` of its result.

    The call uses a 512x512 output size, 50 inference steps, a CUDA
    ``torch.Generator`` seeded with ``seed``, and passes ``features_adapter``
    through with a fixed ``features_adapter_strength`` of 0.4.
    """
    seeded_generator = torch.Generator(device="cuda").manual_seed(seed)
    result = pipe(
        height=512,
        width=512,
        num_inference_steps=50,
        prompt=prompt,
        generator=seeded_generator,
        features_adapter=features_adapter,
        features_adapter_strength=0.4,
    )
    return result.images
52+
53+
54+
def main() -> int:
    """Generate one adapter-conditioned image and save it as ``output.png``.

    Command-line arguments (all required):
        -p / --pipeline:    Diffusers pipeline name.
        -m / --model_path:  model path handed to ``loadmodel``.
        -ad / --ckpt_ad:    path to the adapter checkpoint.
        -cond / --path_cond: path to the adapter condition image.

    Returns:
        0 on success, so ``sys.exit(main())`` reports a clean exit status.

    Raises:
        FileNotFoundError: if the condition image cannot be read.
    """
    parser = argparse.ArgumentParser(description="auto aiart generator")
    parser.add_argument(
        "-p", "--pipeline", help="Diffusers pipeline name", required=True
    )
    parser.add_argument("-m", "--model_path", help="model path", required=True)
    parser.add_argument("-ad", "--ckpt_ad", help="path to checkpoint of adapter", required=True)
    parser.add_argument("-cond", "--path_cond", help="path to adapter condition", required=True)
    args = parser.parse_args()
    kwargs = {}

    device = "cuda"
    model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device).half()
    # NOTE(review): torch.load unpickles the file; only pass checkpoints from a trusted source.
    model_ad.load_state_dict(torch.load(args.ckpt_ad))

    edge = cv2.imread(args.path_cond)
    if edge is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of inside cv2.resize.
        raise FileNotFoundError(f"could not read condition image: {args.path_cond}")
    edge = cv2.resize(edge, (512, 512))
    # presumably [0] keeps a single channel of the converted image — TODO confirm
    # against basicsr's img2tensor; the two unsqueeze calls add two leading axes.
    edge = img2tensor(edge)[0].unsqueeze(0).unsqueeze(0) / 255.
    # Binarise at 0.5, then cast to half precision to match the adapter weights.
    edge = (edge > 0.5).float().half()
    # Inference only: skip building an autograd graph for the adapter features.
    with torch.no_grad():
        features_adapter = model_ad(edge.to(device))

    pipe = loadmodel(args.pipeline, args.model_path, **kwargs)

    if args.pipeline == "StableDiffusionPipeline":
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    prompt = "A car with flying wings"
    outputimg = generation(pipe, prompt, 52, features_adapter)
    filename = "output.png"
    outputimg[0].save(filename)
    return 0
84+
85+
if __name__ == "__main__":
    # Exit the interpreter with whatever main() returns as the status.
    raise SystemExit(main())

src/diffusers/models/unet_2d_blocks.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,13 @@ def __init__(
808808
self.gradient_checkpointing = False
809809

810810
def forward(
811-
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
811+
self,
812+
hidden_states,
813+
temb=None,
814+
encoder_hidden_states=None,
815+
attention_mask=None,
816+
cross_attention_kwargs=None,
817+
features_adapter=None,
812818
):
813819
# TODO(Patrick, William) - attention mask is not used
814820
output_states = ()
@@ -842,6 +848,9 @@ def custom_forward(*inputs):
842848

843849
output_states += (hidden_states,)
844850

851+
if features_adapter is not None:
852+
hidden_states = hidden_states + features_adapter
853+
845854
if self.downsamplers is not None:
846855
for downsampler in self.downsamplers:
847856
hidden_states = downsampler(hidden_states)

src/diffusers/models/unet_2d_condition.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ def forward(
492492
timestep_cond: Optional[torch.Tensor] = None,
493493
attention_mask: Optional[torch.Tensor] = None,
494494
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
495+
features_adapter: list = None,
495496
return_dict: bool = True,
496497
) -> Union[UNet2DConditionOutput, Tuple]:
497498
r"""
@@ -574,21 +575,30 @@ def forward(
574575
sample = self.conv_in(sample)
575576

576577
# 3. down
578+
577579
down_block_res_samples = (sample,)
580+
feature_idx = 0
578581
for downsample_block in self.down_blocks:
579582
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
583+
fa = None
584+
if features_adapter is not None:
585+
fa = features_adapter[feature_idx]
580586
sample, res_samples = downsample_block(
581587
hidden_states=sample,
582588
temb=emb,
583589
encoder_hidden_states=encoder_hidden_states,
584590
attention_mask=attention_mask,
585591
cross_attention_kwargs=cross_attention_kwargs,
592+
features_adapter=fa,
586593
)
594+
feature_idx = feature_idx + 1
587595
else:
588596
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
589-
590597
down_block_res_samples += res_samples
591598

599+
if features_adapter is not None:
600+
sample = sample + features_adapter[feature_idx]
601+
592602
# 4. mid
593603
if self.mid_block is not None:
594604
sample = self.mid_block(

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,8 @@ def __call__(
500500
eta: float = 0.0,
501501
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
502502
latents: Optional[torch.FloatTensor] = None,
503+
features_adapter: list = None,
504+
features_adapter_strength: float = 0.4,
503505
prompt_embeds: Optional[torch.FloatTensor] = None,
504506
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
505507
output_type: Optional[str] = "pil",
@@ -638,13 +640,23 @@ def __call__(
638640
# expand the latents if we are doing classifier free guidance
639641
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
640642
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
641-
642643
# predict the noise residual
644+
645+
input_features_adapter = None
646+
#num_inference_steps
647+
if features_adapter_strength > 1:
648+
features_adapter_strength = 1
649+
elif features_adapter_strength < 0:
650+
features_adapter_strength = 0
651+
if i < int(num_inference_steps * features_adapter_strength):
652+
input_features_adapter = features_adapter
653+
643654
noise_pred = self.unet(
644655
latent_model_input,
645656
t,
646657
encoder_hidden_states=prompt_embeds,
647658
cross_attention_kwargs=cross_attention_kwargs,
659+
features_adapter=input_features_adapter
648660
).sample
649661

650662
# perform guidance

src/extra/t2iadapter/__init__.py

Whitespace-only changes.

src/extra/t2iadapter/adapter.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
#from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
5+
6+
def conv_nd(dims, *args, **kwargs):
    """
    Build an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` depending on ``dims``.

    All remaining arguments are forwarded to the selected constructor.
    Raises ``ValueError`` when ``dims`` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
17+
18+
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an ``nn.AvgPool1d``, ``nn.AvgPool2d`` or ``nn.AvgPool3d`` depending on ``dims``.

    All remaining arguments are forwarded to the selected constructor.
    Raises ``ValueError`` when ``dims`` is not 1, 2 or 3.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
29+
30+
class Downsample(nn.Module):
    """
    Downsampling layer: a strided convolution or plain average pooling.

    :param channels: number of channels expected in the input.
    :param use_conv: if true, downsample with a kernel-size-3 strided
                     convolution; otherwise use average pooling.
    :param dims: 1, 2 or 3 for 1D/2D/3D signals. For 3D the stride is
                 (1, 2, 2), so only the last two dimensions are reduced.
    :param out_channels: output channels; falls back to ``channels`` when
                         not given. Must equal ``channels`` when pooling.
    :param padding: padding passed to the convolution (unused when pooling).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels if out_channels else channels
        self.use_conv = use_conv
        self.dims = dims

        if dims == 3:
            stride = (1, 2, 2)
        else:
            stride = 2

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )

    def forward(self, x):
        """Apply the downsampling op after checking the channel axis of ``x``."""
        assert x.shape[1] == self.channels
        return self.op(x)
57+
58+
59+
class ResnetBlock(nn.Module):
    """Residual block of two convolutions with an optional leading downsample.

    :param in_c: number of channels in the input.
    :param out_c: number of channels in the output.
    :param down: if true, the input first goes through ``Downsample(in_c)``.
    :param ksize: kernel size of the input, second and skip convolutions
                  (the first body convolution always uses kernel size 3).
    :param sk: if true, the skip path is the identity and the input
               convolution is only created when ``in_c != out_c``; if false,
               both the input and the skip convolution are created.
    :param use_conv: forwarded to ``Downsample``.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        ps = ksize // 2
        if in_c != out_c or not sk:
            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
        else:
            self.in_conv = None
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        if not sk:
            # The skip conv runs on the output of in_conv, which already has
            # out_c channels, so its input width must be out_c (using in_c
            # here fails at forward time whenever in_c != out_c).
            self.skep = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        else:
            self.skep = None

        self.down = down
        if self.down:
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        """Return ``block2(act(block1(x))) + skip(x)`` after the optional
        downsample and input convolution have been applied to ``x``."""
        if self.down:
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        h = self.block1(x)
        h = self.act(h)
        h = self.block2(h)
        if self.skep is not None:
            return h + self.skep(x)
        return h + x
93+
94+
95+
class Adapter(nn.Module):
    """Stack of ``ResnetBlock`` stages that returns one feature map per stage.

    :param channels: output channel count of each stage, in order.
    :param nums_rb: number of ``ResnetBlock`` modules per stage.
    :param cin: input channels of the first convolution, i.e. the channel
                count after ``nn.PixelUnshuffle(8)`` has been applied.
    :param ksize: forwarded to every ``ResnetBlock``.
    :param sk: forwarded to every ``ResnetBlock``.
    :param use_conv: forwarded to every ``ResnetBlock``.
    """

    def __init__(self, channels=(320, 640, 1280, 1280), nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
        super().__init__()
        self.unshuffle = nn.PixelUnshuffle(8)
        # Keep a private copy so neither the default nor a caller-owned
        # sequence is shared with (or mutated through) this instance.
        self.channels = list(channels)
        self.nums_rb = nums_rb

        blocks = []
        for i, out_c in enumerate(self.channels):
            for j in range(nums_rb):
                if i != 0 and j == 0:
                    # First block of every stage but the first: downsample
                    # and move from the previous stage's channel count.
                    blocks.append(
                        ResnetBlock(self.channels[i - 1], out_c, down=True, ksize=ksize, sk=sk, use_conv=use_conv)
                    )
                else:
                    blocks.append(
                        ResnetBlock(out_c, out_c, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
                    )
        self.body = nn.ModuleList(blocks)
        self.conv_in = nn.Conv2d(cin, self.channels[0], 3, 1, 1)

    def forward(self, x):
        """Return a list with the output of each stage, in stage order."""
        # unshuffle
        x = self.unshuffle(x)
        # extract features
        features = []
        x = self.conv_in(x)
        for i in range(len(self.channels)):
            for j in range(self.nums_rb):
                x = self.body[i * self.nums_rb + j](x)
            features.append(x)

        return features

0 commit comments

Comments
 (0)