Segmentation Models is a Python library with neural networks for image segmentation, based on PyTorch.
The main features of this library are:
- High-level API (just two lines to create a neural network)
- 4 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 30 available encoders for each architecture
- All encoders have pre-trained weights for faster and better convergence
Since the library is built on the PyTorch framework, a created segmentation model is just a PyTorch `nn.Module`, which can be created as easily as:

    import segmentation_models_pytorch as smp
    model = smp.Unet()

Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters and use pretrained weights to initialize it:

    model = smp.Unet('resnet34', encoder_weights='imagenet')

Change the number of output classes in the model:

    model = smp.Unet('resnet34', classes=3, activation='softmax')

All models have pretrained encoders, so you have to prepare your data the same way as during weights pretraining:

    from segmentation_models_pytorch.encoders import get_preprocessing_fn
    preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')

- Training model for cars segmentation on CamVid dataset - here.
- Training model with Catalyst (high-level framework for PyTorch) - here.
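As a rough illustration of what "preparing data the same way as during pretraining" usually means for ImageNet weights, the sketch below scales 8-bit RGB values to [0, 1] and standardizes each channel. The mean and standard deviation are the commonly published ImageNet statistics, and the helper names `normalize_pixel` and `normalize_image` are hypothetical. Both are assumptions for illustration: the function returned by `get_preprocessing_fn` remains the authoritative source, and some encoders may expect a different input range or different statistics.

```python
# Commonly used ImageNet channel statistics (RGB order), assumed here for
# illustration; the values actually applied depend on the chosen encoder.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale one 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))


def normalize_image(image):
    """Apply normalize_pixel to every pixel of an H x W list of RGB tuples."""
    return [[normalize_pixel(pixel) for pixel in row] for row in image]


# A hypothetical 1 x 2 image: one black pixel and one white pixel.
image = [[(0, 0, 0), (255, 255, 255)]]
normalized = normalize_image(image)
```

In practice the same arithmetic would be applied to whole arrays at once; the nested lists here only keep the example dependency-free.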
| Type | Encoder names |
|---|---|
| VGG | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn |
| DenseNet | densenet121, densenet169, densenet201, densenet161 |
| DPN | dpn68, dpn68b, dpn92, dpn98, dpn107, dpn131 |
| Inception | inceptionresnetv2 |
| ResNet | resnet18, resnet34, resnet50, resnet101, resnet152 |
| ResNeXt | resnext50_32x4d, resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
| SE-ResNet | se_resnet50, se_resnet101, se_resnet152 |
| SE-ResNeXt | se_resnext50_32x4d, se_resnext101_32x4d |
| SENet | senet154 |

| Weights name | Encoder names |
|---|---|
| imagenet+5k | dpn68b, dpn92, dpn107 |
| imagenet | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, inceptionresnetv2, resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154 |
| | resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
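The weights table can be read as a lookup from encoder name to the weight sets listed for it, which is presumably what the `encoder_weights` argument expects. The sketch below transcribes the two rows that have a visible weights name into a plain dictionary. `PRETRAINED_WEIGHTS` and `available_weights` are hypothetical names for illustration and are not part of the library.

```python
# Weight sets and the encoders they cover, transcribed from the table above.
# Only the rows with a visible weights name are included.
PRETRAINED_WEIGHTS = {
    "imagenet+5k": set("dpn68b dpn92 dpn107".split()),
    "imagenet": set(
        """
        vgg11 vgg13 vgg16 vgg19 vgg11bn vgg13bn vgg16bn vgg19bn
        densenet121 densenet169 densenet201 densenet161
        dpn68 dpn98 dpn131 inceptionresnetv2
        resnet18 resnet34 resnet50 resnet101 resnet152
        resnext50_32x4d resnext101_32x8d
        se_resnet50 se_resnet101 se_resnet152
        se_resnext50_32x4d se_resnext101_32x4d senet154
        """.split()
    ),
}


def available_weights(encoder_name):
    """Return the sorted weight-set names the table lists for an encoder."""
    return sorted(
        weights
        for weights, encoders in PRETRAINED_WEIGHTS.items()
        if encoder_name in encoders
    )


resnet34_weights = available_weights("resnet34")
dpn92_weights = available_weights("dpn92")
```

An empty list from `available_weights` means the encoder does not appear in the transcribed rows, not necessarily that no weights exist for it.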
- `model.encoder` - pretrained backbone to extract features of different spatial resolution
- `model.decoder` - segmentation head, depends on the model architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
- `model.activation` - output activation function, one of `sigmoid`, `softmax`
- `model.forward(x)` - sequentially passes `x` through the model's encoder and decoder (returns logits!)
- `model.predict(x)` - inference method: switches the model to `.eval()` mode, calls `.forward(x)`, and applies the activation function under `torch.no_grad()`
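Because `model.forward(x)` returns raw logits while `model.predict(x)` applies the activation, it can help to see what that activation does to a per-pixel score. The sketch below uses only the Python standard library and hypothetical logit values. It illustrates the `sigmoid` and `softmax` functions themselves and is not the library's implementation.

```python
import math


def sigmoid(logit):
    """Map one logit to a probability in (0, 1); typical for binary masks."""
    return 1.0 / (1.0 + math.exp(-logit))


def softmax(logits):
    """Map per-class logits for one pixel to probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


# Hypothetical logits for a single pixel of a 3-class model.
pixel_logits = [2.0, 0.5, -1.0]
pixel_probs = softmax(pixel_logits)
predicted_class = pixel_probs.index(max(pixel_probs))

# Hypothetical logit for a single pixel of a binary model.
foreground_prob = sigmoid(0.0)
```

Taking the highest-probability class per pixel (multi-class) or thresholding the sigmoid output (binary) then yields the final mask.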
PyPI version:

    $ pip install segmentation-models-pytorch

Latest version from source:

    $ pip install git+https://github.com/qubvel/segmentation_models.pytorch

The project is distributed under the MIT License.

Run the tests in a Docker container:

    $ docker build -f docker/Dockerfile.dev -t smp:dev .
    $ docker run --rm smp:dev pytest -p no:cacheprovider