
Commit ea0f68a

Merge branch 'facebookresearch:main' into main
2 parents 7408767 + efeab72

File tree: 7 files changed, +87 −51 lines

README.md

Lines changed: 65 additions & 9 deletions
@@ -4,7 +4,7 @@
[Alexander Kirillov](https://alexander-kirillov.github.io/), [Eric Mintun](https://ericmintun.github.io/), [Nikhila Ravi](https://nikhilaravi.com/), [Hanzi Mao](https://hanzimao.me/), Chloe Rolland, Laura Gustafson, [Tete Xiao](https://tetexiao.com), [Spencer Whitehead](https://www.spencerwhitehead.com/), Alex Berg, Wan-Yen Lo, [Piotr Dollar](https://pdollar.github.io/), [Ross Girshick](https://www.rossgirshick.info/)

-[[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)]
+[[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)] [[`BibTeX`](#citing-segment-anything)]

![SAM design](assets/model_diagram.png?raw=true)

@@ -43,24 +43,26 @@ pip install opencv-python pycocotools matplotlib onnxruntime onnx
First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt:

```
-from segment_anything import build_sam, SamPredictor
-predictor = SamPredictor(build_sam(checkpoint="</path/to/model.pth>"))
+from segment_anything import SamPredictor, sam_model_registry
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
+predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
```
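For illustration only, here is one way the registry-based predictor API shown above might be used end to end. The checkpoint filename, image path, and prompt coordinates below are placeholders, not part of this commit:

```python
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# Hypothetical checkpoint and image paths; substitute your own.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

# The predictor expects an HxWxC uint8 image in RGB order.
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# A single foreground point prompt (label 1 = foreground, 0 = background).
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
```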

or generate masks for an entire image:

```
-from segment_anything import build_sam, SamAutomaticMaskGenerator
-mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="</path/to/model.pth>"))
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
+mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(<your_image>)
```
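As a similar sketch for the automatic path, again with placeholder filenames, each returned entry is a dict describing one mask:

```python
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Hypothetical checkpoint and image paths; substitute your own.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)

# Each entry carries the binary mask under 'segmentation' plus metadata
# such as 'area', 'bbox', 'predicted_iou', and 'stability_score'.
largest = max(masks, key=lambda m: m["area"])
print(len(masks), largest["bbox"])
```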

Additionally, masks can be generated for images from the command line:

```
-python scripts/amg.py --checkpoint <path/to/sam/checkpoint> --input <image_or_folder> --output <output_directory>
+python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output>
```

See the example notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details.
@@ -75,7 +77,7 @@ See the examples notebooks on [using SAM with prompts](/notebooks/predictor_exam
SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with

```
-python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --output <path/to/output>
+python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>
```

See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export.
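A quick sanity check that an exported decoder loads in ONNX Runtime might look like the following sketch; the output filename is an assumption, not something the script mandates:

```python
import onnxruntime as ort

# "sam_decoder.onnx" is a hypothetical path passed to --output above.
session = ort.InferenceSession("sam_decoder.onnx")
for inp in session.get_inputs():
    print(inp.name, inp.shape)
```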
@@ -85,14 +87,55 @@ See the [example notebook](https://github.com/facebookresearch/segment-anything/
Three model versions are available with different backbone sizes. These models can be instantiated by running

```
from segment_anything import sam_model_registry
-sam = sam_model_registry["<name>"](checkpoint="<path/to/checkpoint>")
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
```

-Click the links below to download the checkpoint for the corresponding model name. The default model in bold can also be instantiated with `build_sam`, as in the examples in [Getting Started](#getting-started).
+Click the links below to download the checkpoint for the corresponding model type.

* **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)**
* `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
* `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
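For example, after downloading the ViT-B checkpoint above, it could be instantiated with the registry like this (a minimal sketch assuming the file keeps its default name):

```python
import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
```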

+## Dataset
+See [here](https://ai.facebook.com/datasets/segment-anything/) for an overview of the dataset. The dataset can be downloaded [here](https://ai.facebook.com/datasets/segment-anything-downloads/). By downloading the dataset you agree that you have read and accepted the terms of the SA-1B Dataset Research License.
+
+We save masks per image as a json file. It can be loaded as a dictionary in Python in the format below.
+
+```python
+{
+    "image"            : image_info,
+    "annotations"      : [annotation],
+}
+
+image_info {
+    "image_id"         : int,              # Image id
+    "width"            : int,              # Image width
+    "height"           : int,              # Image height
+    "file_name"        : str,              # Image filename
+}
+
+annotation {
+    "id"               : int,              # Annotation id
+    "segmentation"     : dict,             # Mask saved in COCO RLE format.
+    "bbox"             : [x, y, w, h],     # The box around the mask, in XYWH format
+    "area"             : int,              # The area in pixels of the mask
+    "predicted_iou"    : float,            # The model's own prediction of the mask's quality
+    "stability_score"  : float,            # A measure of the mask's quality
+    "crop_box"         : [x, y, w, h],     # The crop of the image used to generate the mask, in XYWH format
+    "point_coords"     : [[x, y]],         # The point coordinates input to the model to generate the mask
+}
+```
+
+Image ids can be found in sa_images_ids.txt which can be downloaded using the above [link](https://ai.facebook.com/datasets/segment-anything-downloads/) as well.
+
+To decode a mask in COCO RLE format into binary:
+```
+from pycocotools import mask as mask_utils
+mask = mask_utils.decode(annotation["segmentation"])
+```
See [here](https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/mask.py) for more instructions to manipulate masks stored in RLE format.
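Tying the format above together, a sketch of reading one per-image annotation file and decoding its masks; the filename `sa_1.json` is hypothetical:

```python
import json
from pycocotools import mask as mask_utils

# Hypothetical per-image annotation file from SA-1B.
with open("sa_1.json") as f:
    data = json.load(f)

h, w = data["image"]["height"], data["image"]["width"]
for ann in data["annotations"]:
    binary_mask = mask_utils.decode(ann["segmentation"])  # uint8 array of shape (h, w)
    assert binary_mask.shape == (h, w)
```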
## License
The model is licensed under the [Apache 2.0 license](LICENSE).

@@ -105,3 +148,16 @@ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md
The Segment Anything project was made possible with the help of many contributors (alphabetical):

Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom

+## Citing Segment Anything
+
+If you use SAM or SA-1B in your research, please use the following BibTeX entry.
+
+```
+@article{kirillov2023segany,
+  title={Segment Anything},
+  author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
+  journal={arXiv:2304.02643},
+  year={2023}
+}
+```

notebooks/automatic_mask_generator_example.ipynb

Lines changed: 6 additions & 14 deletions
@@ -214,19 +214,6 @@
"To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended."
]
},
-{
-"cell_type": "code",
-"execution_count": 9,
-"id": "17ade22d",
-"metadata": {},
-"outputs": [],
-"source": [
-"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
-"\n",
-"device = \"cuda\"\n",
-"model_type = \"default\""
-]
-},
{
"cell_type": "code",
"execution_count": 10,
@@ -238,6 +225,11 @@
"sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n",
"\n",
+"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
+"model_type = \"vit_h\"\n",
+"\n",
+"device = \"cuda\"\n",
+"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n",
"\n",
@@ -446,7 +438,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.10.10"
+"version": "3.8.0"
}
},
"nbformat": 4,

notebooks/onnx_model_example.ipynb

Lines changed: 2 additions & 2 deletions
@@ -192,7 +192,7 @@
"outputs": [],
"source": [
"checkpoint = \"sam_vit_h_4b8939.pth\"\n",
-"model_type = \"default\""
+"model_type = \"vit_h\""
]
},
{
@@ -766,7 +766,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.10.10"
+"version": "3.8.0"
}
},
"nbformat": 4,

notebooks/predictor_example.ipynb

Lines changed: 6 additions & 13 deletions
@@ -229,18 +229,6 @@
"First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
]
},
-{
-"cell_type": "code",
-"execution_count": 9,
-"id": "17ccff22",
-"metadata": {},
-"outputs": [],
-"source": [
-"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
-"device = \"cuda\"\n",
-"model_type = \"default\""
-]
-},
{
"cell_type": "code",
"execution_count": 10,
@@ -252,6 +240,11 @@
"sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamPredictor\n",
"\n",
+"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
+"model_type = \"vit_h\"\n",
+"\n",
+"device = \"cuda\"\n",
+"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n",
"\n",
@@ -1015,7 +1008,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.10.10"
+"version": "3.8.0"
}
},
"nbformat": 4,

scripts/amg.py

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@
parser.add_argument(
    "--model-type",
    type=str,
-    default="default",
-    help="The type of model to load, in ['default', 'vit_l', 'vit_b']",
+    required=True,
+    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
)

parser.add_argument(

scripts/export_onnx_model.py

Lines changed: 4 additions & 9 deletions
@@ -6,7 +6,7 @@

import torch

-from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
+from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

import argparse
@@ -34,8 +34,8 @@
parser.add_argument(
    "--model-type",
    type=str,
-    default="default",
-    help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
+    required=True,
+    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)

parser.add_argument(
@@ -105,12 +105,7 @@ def run_export(
    return_extra_metrics=False,
):
    print("Loading model...")
-    if model_type == "vit_b":
-        sam = build_sam_vit_b(checkpoint)
-    elif model_type == "vit_l":
-        sam = build_sam_vit_l(checkpoint)
-    else:
-        sam = build_sam(checkpoint)
+    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,

segment_anything/build_sam.py

Lines changed: 2 additions & 2 deletions
@@ -45,8 +45,8 @@ def build_sam_vit_b(checkpoint=None):

sam_model_registry = {
-    "default": build_sam,
-    "vit_h": build_sam,
+    "default": build_sam_vit_h,
+    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}
