Skip to content

Commit 03cd625

Browse files
authored
feat: add ip adapter benchmark (huggingface#6936)
* feat: add ip adapter benchmark * sdxl support too. * Empty-Commit
1 parent 001b140 commit 03cd625

File tree

3 files changed

+62
-1
lines changed

3 files changed

+62
-1
lines changed

benchmarks/base_classes.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,35 @@ def run_inference(self, pipe, args):
236236
)
237237

238238

239+
class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
240+
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
241+
image = load_image(url)
242+
243+
def __init__(self, args):
244+
pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda")
245+
pipe.load_ip_adapter(
246+
args.ip_adapter_id[0],
247+
subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models",
248+
weight_name=args.ip_adapter_id[1],
249+
)
250+
251+
if args.run_compile:
252+
pipe.unet.to(memory_format=torch.channels_last)
253+
print("Run torch compile")
254+
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
255+
256+
pipe.set_progress_bar_config(disable=True)
257+
self.pipe = pipe
258+
259+
def run_inference(self, pipe, args):
260+
_ = pipe(
261+
prompt=PROMPT,
262+
ip_adapter_image=self.image,
263+
num_inference_steps=args.num_inference_steps,
264+
num_images_per_prompt=args.batch_size,
265+
)
266+
267+
239268
class ControlNetBenchmark(TextToImageBenchmark):
240269
pipeline_class = StableDiffusionControlNetPipeline
241270
aux_network_class = ControlNetModel
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import argparse
2+
import sys
3+
4+
5+
sys.path.append(".")
6+
from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
7+
8+
9+
IP_ADAPTER_CKPTS = {
10+
"runwayml/stable-diffusion-v1-5": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
11+
"stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
12+
}
13+
14+
15+
if __name__ == "__main__":
16+
parser = argparse.ArgumentParser()
17+
parser.add_argument(
18+
"--ckpt",
19+
type=str,
20+
default="runwayml/stable-diffusion-v1-5",
21+
choices=list(IP_ADAPTER_CKPTS.keys()),
22+
)
23+
parser.add_argument("--batch_size", type=int, default=1)
24+
parser.add_argument("--num_inference_steps", type=int, default=50)
25+
parser.add_argument("--model_cpu_offload", action="store_true")
26+
parser.add_argument("--run_compile", action="store_true")
27+
args = parser.parse_args()
28+
29+
args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt]
30+
benchmark_pipe = IPAdapterTextToImageBenchmark(args)
31+
args.ckpt = f"{args.ckpt} (IP-Adapter)"
32+
benchmark_pipe.benchmark(args)

benchmarks/run_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def main():
7272
command += " --run_compile"
7373
run_command(command.split())
7474

75-
elif file == "benchmark_sd_inpainting.py":
75+
elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
7676
sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
7777
command = f"python {file} --ckpt {sdxl_ckpt}"
7878
run_command(command.split())

0 commit comments

Comments
 (0)