Commit 47e7807

committed: add benchmark and edit propagation. update colab
1 parent 182d4d7 commit 47e7807

16 files changed: +1476 -88 lines changed

.gitignore

Lines changed: 10 additions & 0 deletions
@@ -161,3 +161,13 @@ cython_debug/
 
 */.DS_Store
 .DS_Store
+
+guided-diffusion/
+davis_results_sd/
+davis_results_adm/
+superpoint-1k/
+hpatches_results/
+superpoint-1k.zip
+SPair-71k.tar.gz
+SPair-71k/
+./guided-diffusion/models/256x256_diffusion_uncond.pt

README.md

Lines changed: 171 additions & 6 deletions
@@ -1,7 +1,7 @@
 # Diffusion Features (DIFT)
-This repository contains code for paper "Emergent Correspondence from Image Diffusion".
+This repository contains code for our NeurIPS 2023 paper "Emergent Correspondence from Image Diffusion".
 
-### [Project Page](https://diffusionfeatures.github.io/) | [Paper](https://arxiv.org/abs/2306.03881) | [Colab Demo](https://colab.research.google.com/drive/1tUTJ3UJxbqnfvUMvYH5lxcqt0UdUjdq6?usp=sharing)
+### [Project Page](https://diffusionfeatures.github.io/) | [Paper](https://arxiv.org/abs/2306.03881) | [Colab Demo](https://colab.research.google.com/drive/1km6MGafhAvbPOouD3oo64aUXgLlWM6L1?usp=sharing)
 
 ![video](./assets/teaser.gif)
 

@@ -11,7 +11,7 @@ If you have a Linux machine, you could either set up the python environment usin
 conda env create -f environment.yml
 conda activate dift
 ```
-or create a new conda environment and install the packages manually using the
+or create a new conda environment and install the packages manually using the 
 shell commands in [setup_env.sh](setup_env.sh).
 
 ## Interactive Demo: Give it a Try!

@@ -24,7 +24,7 @@ We provide an interative jupyter notebook [demo.ipynb](demo.ipynb) to demonstrat
 </tr>
 </table>
 
-If you don't have a local GPU, you can also use the provided [Colab Demo](https://colab.research.google.com/drive/1tUTJ3UJxbqnfvUMvYH5lxcqt0UdUjdq6?usp=sharing).
+If you don't have a local GPU, you can also use the provided [Colab Demo](https://colab.research.google.com/drive/1km6MGafhAvbPOouD3oo64aUXgLlWM6L1?usp=sharing).
 
 ## Extract DIFT for a given image
 You could use the following [command](extract_dift.sh) to extract DIFT from a given image, and save it as a torch tensor. These arguments are set to the same as in the semantic correspondence tasks by default.

@@ -47,10 +47,175 @@ Here're the explanation for each argument:
 - `prompt`: the prompt used in the diffusion model.
 - `ensemble_size`: the number of repeated images in each batch used to get features. `ensemble_size=8` by default. You can reduce this value if encountering memory issue.
 
-The output DIFT tensor spatial size is determined by both `img_size` and `up_ft_index`. If `up_ft_index=0`, the output size would be 1/32 of `img_size`; if `up_ft_index=1`, it would be 1/16; if `up_ft_index=2 or 3`, it would be 1/8.
+The output DIFT tensor spatial size is determined by both `img_size` and `up_ft_index`. If `up_ft_index=0`, the output size would be 1/32 of `img_size`; if `up_ft_index=1`, it would be 1/16; if `up_ft_index=2 or 3`, it would be 1/8.
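As a quick check of the sizes above, here is a minimal sketch (not part of the repository; the helper name is made up for illustration) that computes the expected feature-map resolution from `img_size` and `up_ft_index`, assuming the 1/32, 1/16, 1/8 ratios just stated:

```python
# Illustrative helper (hypothetical, not in the repo): expected DIFT output spatial size.
def dift_feature_size(img_size, up_ft_index):
    """Return (h, w) of the output DIFT tensor, assuming the strides stated above."""
    stride = {0: 32, 1: 16, 2: 8, 3: 8}[up_ft_index]
    h, w = img_size
    return h // stride, w // stride

# Default semantic-correspondence setting: a 768x768 input with up_ft_index=2 gives 96x96 features.
print(dift_feature_size((768, 768), 2))  # (96, 96)
```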
 
 ## Application: Edit Propagation
 Using DIFT, we can propagate edits in one image to others that share semantic correspondences, even cross categories and domains:
 <img src="./assets/edit_cat.gif" alt="edit cat" style="width:90%;">
+More implementation details are in the notebook [edit_propagation.ipynb](edit_propagation.ipynb).
 
-Check out more videos and visualizations in the [project page](https://diffusionfeatures.github.io/).
+## Get Benchmark Evaluation Results
+First, run the following commands to enable the use of DIFT_adm:
+```
+git clone git@github.com:openai/guided-diffusion.git
+cd guided-diffusion && mkdir models && cd models
+wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt
+```
+
+### SPair-71k
+
+First, download the SPair-71k data:
+```
+wget https://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz
+tar -xzvf SPair-71k.tar.gz
+```
+Run the following script to get the PCK (both per point and per image) of DIFT_sd on SPair-71k:
+```
+# --save_path is a path to save the extracted features
+python eval_spair.py \
+    --dataset_path ./SPair-71k \
+    --save_path ./spair_ft \
+    --dift_model sd \
+    --img_size 768 768 \
+    --t 261 \
+    --up_ft_index 2 \
+    --ensemble_size 8
+```
+Run the following script to get the PCK (both per point and per image) of DIFT_adm on SPair-71k:
+```
+# --save_path is a path to save the extracted features
+python eval_spair.py \
+    --dataset_path ./SPair-71k \
+    --save_path ./spair_ft \
+    --dift_model adm \
+    --img_size 512 512 \
+    --t 101 \
+    --up_ft_index 4 \
+    --ensemble_size 8
+```
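As a reference for the metric itself: a transferred keypoint is usually counted as correct if it lands within `alpha * max(bbox_h, bbox_w)` of the ground truth, with `alpha = 0.1` being the common SPair-71k setting. The sketch below is a generic illustration of that computation, not the exact code in `eval_spair.py`:

```python
import numpy as np

def pck_per_point(pred_kpts, gt_kpts, bbox_hw, alpha=0.1):
    """Generic PCK@alpha sketch for one image pair (illustrative, not eval_spair.py itself).

    pred_kpts, gt_kpts: (N, 2) arrays of predicted / ground-truth (x, y) keypoints.
    bbox_hw: (height, width) of the target object's bounding box.
    Returns the per-image PCK and the per-point correctness mask.
    """
    pred_kpts = np.asarray(pred_kpts, dtype=float)
    gt_kpts = np.asarray(gt_kpts, dtype=float)
    threshold = alpha * max(bbox_hw)                      # threshold relative to bbox size
    dists = np.linalg.norm(pred_kpts - gt_kpts, axis=1)   # pixel error per keypoint
    correct = dists <= threshold
    return correct.mean(), correct

# Example: with a 100x120 bbox the threshold is 12 px, so only the first point counts.
print(pck_per_point([[10, 10], [50, 80]], [[12, 11], [90, 80]], bbox_hw=(100, 120)))
```

Per-image PCK averages correctness within each image first, while per-point PCK pools all keypoints across the dataset before averaging.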
+
+### HPatches
+
+First, prepare the HPatches data:
+```
+cd $HOME
+git clone git@github.com:mihaidusmanu/d2-net.git && cd d2-net/hpatches_sequences/
+chmod u+x download.sh
+./download.sh
+```
+
+Then, download the 1k SuperPoint keypoints:
+```
+wget "https://www.dropbox.com/scl/fi/1mxy3oycnz7m2acd92u2x/superpoint-1k.zip?rlkey=fic30gr2tlth3cmsyyywcg385&dl=1" -O superpoint-1k.zip
+unzip superpoint-1k.zip
+rm superpoint-1k.zip
+```
+
+Run the following script to get the homography estimation accuracy of DIFT_sd on HPatches:
+```
+python eval_hpatches.py \
+    --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
+    --kpts_path ./superpoint-1k \
+    --save_path ./hpatches_results \
+    --dift_model sd \
+    --img_size 768 768 \
+    --t 0 \
+    --up_ft_index 2 \
+    --ensemble_size 8
+
+python eval_homography.py \
+    --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
+    --save_path ./hpatches_results \
+    --feat dift_sd \
+    --metric cosine \
+    --mode lmeds
+```
+
+Run the following script to get the homography estimation accuracy of DIFT_adm on HPatches:
+```
+python eval_hpatches.py \
+    --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
+    --kpts_path ./superpoint-1k \
+    --save_path ./hpatches_results \
+    --dift_model adm \
+    --img_size 768 768 \
+    --t 41 \
+    --up_ft_index 11 \
+    --ensemble_size 4
+
+python eval_homography.py \
+    --hpatches_path ../d2-net/hpatches_sequences/hpatches-sequences-release \
+    --save_path ./hpatches_results \
+    --feat dift_adm \
+    --metric l2 \
+    --mode ransac
+```
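For intuition, `eval_hpatches.py` extracts DIFT descriptors at the pre-extracted SuperPoint keypoints, and `eval_homography.py` matches them (cosine distance for DIFT_sd, L2 for DIFT_adm above) and fits a homography with OpenCV (LMEDS or RANSAC). The sketch below is a generic mutual-nearest-neighbour matching plus `cv2.findHomography` pipeline, shown for illustration only; the exact logic lives in those two scripts:

```python
import cv2
import numpy as np

def match_and_estimate_homography(desc1, desc2, kpts1, kpts2, metric="cosine", mode="ransac"):
    """Illustrative sketch: mutual-NN matching of per-keypoint descriptors, then a homography fit.

    desc1, desc2: (N1, D), (N2, D) feature descriptors sampled at the keypoints.
    kpts1, kpts2: (N1, 2), (N2, 2) keypoint (x, y) locations in the two images.
    """
    if metric == "cosine":
        d1 = desc1 / np.linalg.norm(desc1, axis=1, keepdims=True)
        d2 = desc2 / np.linalg.norm(desc2, axis=1, keepdims=True)
        dist = 1.0 - d1 @ d2.T                                        # cosine distance matrix
    else:
        dist = np.linalg.norm(desc1[:, None] - desc2[None], axis=-1)  # L2 distance matrix

    nn12 = dist.argmin(axis=1)                                        # best match in image 2 for each kpt in image 1
    nn21 = dist.argmin(axis=0)                                        # best match in image 1 for each kpt in image 2
    mutual = np.where(nn21[nn12] == np.arange(len(nn12)))[0]          # keep only mutual nearest neighbours

    src = kpts1[mutual].astype(np.float32)
    dst = kpts2[nn12[mutual]].astype(np.float32)
    method = cv2.LMEDS if mode == "lmeds" else cv2.RANSAC
    H, inliers = cv2.findHomography(src, dst, method, 3.0)            # robust homography estimate
    return H, inliers
```

Accuracy on HPatches is then typically reported as the fraction of image pairs whose corner reprojection error under the estimated homography falls below a pixel threshold.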
+
+### DAVIS
+
+We follow the evaluation protocol as in DINO's [implementation](https://github.com/facebookresearch/dino#evaluation-davis-2017-video-object-segmentation).
+
+First, prepare the DAVIS 2017 data and evaluation tools:
+```
+cd $HOME
+git clone https://github.com/davisvideochallenge/davis-2017 && cd davis-2017
+./data/get_davis.sh
+cd $HOME
+git clone https://github.com/davisvideochallenge/davis2017-evaluation
+```
+
+Then, get the segmentation results using DIFT_sd:
+```
+python eval_davis.py \
+    --dift_model sd \
+    --t 51 \
+    --up_ft_index 2 \
+    --temperature 0.2 \
+    --topk 15 \
+    --n_last_frames 28 \
+    --ensemble_size 8 \
+    --size_mask_neighborhood 15 \
+    --data_path $HOME/davis-2017/DAVIS/ \
+    --output_dir ./davis_results_sd/
+```
+
+and the results using DIFT_adm:
+```
+python eval_davis.py \
+    --dift_model adm \
+    --t 51 \
+    --up_ft_index 7 \
+    --temperature 0.1 \
+    --topk 10 \
+    --n_last_frames 28 \
+    --ensemble_size 4 \
+    --size_mask_neighborhood 15 \
+    --data_path $HOME/davis-2017/DAVIS/ \
+    --output_dir ./davis_results_adm/
+```
+
+Finally, evaluate the results:
+```
+python $HOME/davis2017-evaluation/evaluation_method.py \
+    --task semi-supervised \
+    --results_path ./davis_results_sd/ \
+    --davis_path $HOME/davis-2017/DAVIS/
+
+python $HOME/davis2017-evaluation/evaluation_method.py \
+    --task semi-supervised \
+    --results_path ./davis_results_adm/ \
+    --davis_path $HOME/davis-2017/DAVIS/
+```
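For context on what `eval_davis.py` does: following the DINO protocol, the first-frame mask is propagated frame by frame. Each pixel of the current frame compares its DIFT feature against features from the first frame and the previous `n_last_frames` frames within a window of radius `size_mask_neighborhood`, keeps the `topk` most similar locations, and averages their labels with a softmax at the given `temperature`. The sketch below illustrates the core idea for a single reference frame, without the multi-frame queue and spatial restriction; it is not the repository's implementation:

```python
import torch
import torch.nn.functional as F

def propagate_labels(feat_ref, labels_ref, feat_cur, topk=15, temperature=0.2):
    """Single-reference-frame sketch of nearest-neighbour label propagation.

    feat_ref, feat_cur: (C, H, W) feature maps of the reference / current frame.
    labels_ref: (K, H, W) soft one-hot masks of the reference frame (K classes).
    Returns (K, H, W) propagated soft masks for the current frame.
    Illustrative only: the benchmark additionally restricts matches to a local
    neighbourhood and pools several past frames.
    """
    C, H, W = feat_ref.shape
    K = labels_ref.shape[0]
    ref = F.normalize(feat_ref.reshape(C, -1), dim=0)       # (C, HW_ref)
    cur = F.normalize(feat_cur.reshape(C, -1), dim=0)       # (C, HW_cur)
    sim = cur.t() @ ref                                      # (HW_cur, HW_ref) cosine similarity

    topk_sim, topk_idx = sim.topk(topk, dim=1)               # keep top-k reference pixels per query
    weights = F.softmax(topk_sim / temperature, dim=1)       # temperature-scaled attention weights

    ref_labels = labels_ref.reshape(K, -1).t()               # (HW_ref, K)
    out = (weights.unsqueeze(-1) * ref_labels[topk_idx]).sum(dim=1)  # weighted label average
    return out.t().reshape(K, H, W)
```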
+
+# Misc.
+If you find our code or paper useful to your research work, please consider citing our work using the following bibtex:
+```
+@inproceedings{
+tang2023emergent,
+title={Emergent Correspondence from Image Diffusion},
+author={Luming Tang and Menglin Jia and Qianqian Wang and Cheng Perng Phoo and Bharath Hariharan},
+booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
+year={2023},
+url={https://openreview.net/forum?id=ypOiXjdfnU}
+}
+```

assets/cartoon.png

19.3 KB

assets/guitar_cat.jpg

121 KB

assets/painting_cat.jpg

103 KB

demo.ipynb

Lines changed: 12 additions & 54 deletions
Large diffs are not rendered by default.

edit_propagation.ipynb

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3306ccce-4b17-41a9-831d-add6cccddc0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import gc\n",
    "import imageio\n",
    "from PIL import Image\n",
    "from torchvision.transforms import PILToTensor\n",
    "import os\n",
    "import json\n",
    "from PIL import Image, ImageDraw\n",
    "import torch.nn.functional as F\n",
    "import cv2\n",
    "import glob\n",
    "from torchvision.transforms import PILToTensor\n",
    "from src.models.dift_sd import SDFeaturizer4Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "081cd585-9d9d-4ffe-8c9b-6c6360d2e4ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coordinate helpers: gen_grid builds an [h, w, 2] grid of (x, y) pixel positions;\n",
    "# normalize_coords maps pixel coordinates into [-1, 1] for F.grid_sample.\n",
    "def gen_grid(h, w, device, normalize=False, homogeneous=False):\n",
    "    if normalize:\n",
    "        lin_y = torch.linspace(-1., 1., steps=h, device=device)\n",
    "        lin_x = torch.linspace(-1., 1., steps=w, device=device)\n",
    "    else:\n",
    "        lin_y = torch.arange(0, h, device=device)\n",
    "        lin_x = torch.arange(0, w, device=device)\n",
    "    grid_y, grid_x = torch.meshgrid((lin_y, lin_x))\n",
    "    grid = torch.stack((grid_x, grid_y), -1)\n",
    "    if homogeneous:\n",
    "        grid = torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)\n",
    "    return grid  # [h, w, 2 or 3]\n",
    "\n",
    "\n",
    "def normalize_coords(coords, h, w, no_shift=False):\n",
    "    assert coords.shape[-1] == 2\n",
    "    if no_shift:\n",
    "        return coords / torch.tensor([w-1., h-1.], device=coords.device) * 2\n",
    "    else:\n",
    "        return coords / torch.tensor([w-1., h-1.], device=coords.device) * 2 - 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a13b459-4698-4a9c-803f-d7ba8adb6962",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Stable Diffusion featurizer with the 'cat' prompt category\n",
    "cat = 'cat'\n",
    "dift = SDFeaturizer4Eval(cat_list=['cat'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0606e9dd-9e51-49ec-bf37-1f2bc9f78a84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the source / target images and the RGBA sticker edit drawn on the source image\n",
    "src_img = Image.open('./assets/guitar_cat.jpg').convert('RGB')\n",
    "trg_img = Image.open('./assets/painting_cat.jpg').convert('RGB')\n",
    "sticker = imageio.imread('./assets/cartoon.png')\n",
    "sticker_color, sticker_mask = sticker[..., :3], sticker[..., 3]\n",
    "\n",
    "assert np.array(src_img).shape[:2] == sticker.shape[:2]\n",
    "h_src, w_src = sticker.shape[:2]\n",
    "h_trg, w_trg = np.array(trg_img).shape[:2]\n",
    "\n",
    "# Extract DIFT features for both images and L2-normalize them along the channel dimension\n",
    "sd_feat_src = dift.forward(src_img, cat)\n",
    "sd_feat_trg = dift.forward(trg_img, cat)\n",
    "\n",
    "sd_feat_src = F.normalize(sd_feat_src.squeeze(), p=2, dim=0)\n",
    "sd_feat_trg = F.normalize(sd_feat_trg.squeeze(), p=2, dim=0)\n",
    "feat_dim = sd_feat_src.shape[0]\n",
    "\n",
    "# Sample up to 1000 source pixels inside the sticker mask and normalize their coordinates\n",
    "grid_src = gen_grid(h_src, w_src, device='cuda')\n",
    "grid_trg = gen_grid(h_trg, w_trg, device='cuda')\n",
    "\n",
    "coord_src = grid_src[sticker_mask > 0]\n",
    "coord_src = coord_src[torch.randperm(len(coord_src))][:1000]\n",
    "coord_src_normed = normalize_coords(coord_src, h_src, w_src)\n",
    "grid_trg_normed = normalize_coords(grid_trg, h_trg, w_trg)\n",
    "\n",
    "# For each sampled source feature, find its nearest neighbor in the target feature map\n",
    "feat_src = F.grid_sample(sd_feat_src[None], coord_src_normed[None, None], align_corners=True).squeeze().T\n",
    "feat_trg = F.grid_sample(sd_feat_trg[None], grid_trg_normed[None], align_corners=True).squeeze()\n",
    "feat_trg_flattened = feat_trg.permute(1, 2, 0).reshape(-1, feat_dim)\n",
    "\n",
    "distances = torch.cdist(feat_src, feat_trg_flattened)\n",
    "_, indices = torch.min(distances, dim=1)\n",
    "\n",
    "# Fit a homography from the matched points and warp the sticker onto the target image\n",
    "src_pts = coord_src.reshape(-1, 2).cpu().numpy()\n",
    "trg_pts = grid_trg.reshape(-1, 2)[indices].cpu().numpy()\n",
    "\n",
    "M, mask = cv2.findHomography(src_pts, trg_pts, cv2.RANSAC, 5.0)\n",
    "sticker_out = cv2.warpPerspective(sticker, M, (w_trg, h_trg))\n",
    "\n",
    "# Alpha-composite the original and warped stickers onto the source / target images\n",
    "sticker_out_alpha = sticker_out[..., 3:] / 255\n",
    "sticker_alpha = sticker[..., 3:] / 255\n",
    "\n",
    "trg_img_with_sticker = sticker_out_alpha * sticker_out[..., :3] + (1 - sticker_out_alpha) * trg_img\n",
    "src_img_with_sticker = sticker_alpha * sticker[..., :3] + (1 - sticker_alpha) * src_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88723600-c18f-4eb1-aec7-feb4112e2610",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the source / target images before and after the edit is propagated\n",
    "fig, axs = plt.subplots(2, 2, figsize=(10, 10))\n",
    "\n",
    "axs[0, 0].imshow(src_img)\n",
    "axs[0, 0].set_title(\"Source Image\")\n",
    "axs[0, 0].axis('off')\n",
    "\n",
    "axs[0, 1].imshow(src_img_with_sticker.astype(np.uint8))\n",
    "axs[0, 1].set_title(\"Source Image with Edits\")\n",
    "axs[0, 1].axis('off')\n",
    "\n",
    "axs[1, 0].imshow(trg_img)\n",
    "axs[1, 0].set_title(\"Target Image\")\n",
    "axs[1, 0].axis('off')\n",
    "\n",
    "axs[1, 1].imshow(trg_img_with_sticker.astype(np.uint8))\n",
    "axs[1, 1].set_title(\"Target Image with Propagated Edits\")\n",
    "axs[1, 1].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
