Commit 5f77a4d

Add Xception network
1 parent 7a7e328 commit 5f77a4d

File tree

1 file changed: +226 -0 lines changed


xception.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
# -*- coding: utf-8 -*-
'''Xception V1 model for Keras.

On ImageNet, this model gets to a top-1 validation accuracy of 0.790,
and a top-5 validation accuracy of 0.945.

Do note that the input image format for this model is different from that of
the VGG16 and ResNet models (299x299 instead of 224x224),
and that the input preprocessing function
is also different (same as Inception V3).

Also do note that this model is only available for the TensorFlow backend,
due to its reliance on `SeparableConvolution` layers.

# Reference:

- [Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357)

'''
from __future__ import print_function
from __future__ import absolute_import

import warnings
import numpy as np

from keras.models import Model
from keras.layers import Dense, Input, BatchNormalization, Activation, merge
from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D
from keras.preprocessing import image
from keras.utils.data_utils import get_file
from keras import backend as K
from imagenet_utils import decode_predictions


TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels.h5'
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels_notop.h5'

def Xception(include_top=True, weights='imagenet',
             input_tensor=None):
    '''Instantiate the Xception architecture,
    optionally loading weights pre-trained
    on ImageNet. This model is available for TensorFlow only,
    and can only be used with inputs following the TensorFlow
    dimension ordering `(width, height, channels)`.
    You should set `image_dim_ordering="tf"` in your Keras config
    located at ~/.keras/keras.json.

    Note that the default input image size for this model is 299x299.

    # Arguments
        include_top: whether to include the fully-connected
            layer at the top of the network.
        weights: one of `None` (random initialization)
            or "imagenet" (pre-training on ImageNet).
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.

    # Returns
        A Keras model instance.
    '''
    if weights not in {'imagenet', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `imagenet` '
                         '(pre-training on ImageNet).')
    if K.backend() != 'tensorflow':
        raise Exception('The Xception model is only available with '
                        'the TensorFlow backend.')
    if K.image_dim_ordering() != 'tf':
        warnings.warn('The Xception model is only available for the '
                      'input dimension ordering "tf" '
                      '(width, height, channels). '
                      'However your settings specify the default '
                      'dimension ordering "th" (channels, width, height). '
                      'You should set `image_dim_ordering="tf"` in your Keras '
                      'config located at ~/.keras/keras.json. '
                      'The model being returned right now will expect inputs '
                      'to follow the "tf" dimension ordering.')
        K.set_image_dim_ordering('tf')
        old_dim_ordering = 'th'
    else:
        old_dim_ordering = None

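    # Note: if the Keras config specified "th" ordering, it has been switched
    # to "tf" above and is restored from `old_dim_ordering` just before the
    # model is returned.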
    # Determine proper input shape
    if include_top:
        input_shape = (299, 299, 3)
    else:
        input_shape = (None, None, 3)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

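    # Entry flow: two standard convolutions (block1) followed by three
    # separable-convolution blocks (block2-block4), each with a strided
    # 1x1 convolution as residual shortcut.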
    x = Conv2D(32, 3, 3, subsample=(2, 2), bias=False, name='block1_conv1')(img_input)
    x = BatchNormalization(name='block1_conv1_bn')(x)
    x = Activation('relu', name='block1_conv1_act')(x)
    x = Conv2D(64, 3, 3, bias=False, name='block1_conv2')(x)
    x = BatchNormalization(name='block1_conv2_bn')(x)
    x = Activation('relu', name='block1_conv2_act')(x)

    residual = Conv2D(128, 1, 1, subsample=(2, 2),
                      border_mode='same', bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(128, 3, 3, border_mode='same', bias=False, name='block2_sepconv1')(x)
    x = BatchNormalization(name='block2_sepconv1_bn')(x)
    x = Activation('relu', name='block2_sepconv2_act')(x)
    x = SeparableConv2D(128, 3, 3, border_mode='same', bias=False, name='block2_sepconv2')(x)
    x = BatchNormalization(name='block2_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), border_mode='same', name='block2_pool')(x)
    x = merge([x, residual], mode='sum')

    residual = Conv2D(256, 1, 1, subsample=(2, 2),
                      border_mode='same', bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block3_sepconv1_act')(x)
    x = SeparableConv2D(256, 3, 3, border_mode='same', bias=False, name='block3_sepconv1')(x)
    x = BatchNormalization(name='block3_sepconv1_bn')(x)
    x = Activation('relu', name='block3_sepconv2_act')(x)
    x = SeparableConv2D(256, 3, 3, border_mode='same', bias=False, name='block3_sepconv2')(x)
    x = BatchNormalization(name='block3_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), border_mode='same', name='block3_pool')(x)
    x = merge([x, residual], mode='sum')

    residual = Conv2D(728, 1, 1, subsample=(2, 2),
                      border_mode='same', bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block4_sepconv1_act')(x)
    x = SeparableConv2D(728, 3, 3, border_mode='same', bias=False, name='block4_sepconv1')(x)
    x = BatchNormalization(name='block4_sepconv1_bn')(x)
    x = Activation('relu', name='block4_sepconv2_act')(x)
    x = SeparableConv2D(728, 3, 3, border_mode='same', bias=False, name='block4_sepconv2')(x)
    x = BatchNormalization(name='block4_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), border_mode='same', name='block4_pool')(x)
    x = merge([x, residual], mode='sum')

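    # Middle flow: 8 identical blocks (block5-block12), each consisting of
    # three 728-channel separable convolutions with an identity shortcut.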
    for i in range(8):
        residual = x
        prefix = 'block' + str(i + 5)

        x = Activation('relu', name=prefix + '_sepconv1_act')(x)
        x = SeparableConv2D(728, 3, 3, border_mode='same', bias=False, name=prefix + '_sepconv1')(x)
        x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
        x = Activation('relu', name=prefix + '_sepconv2_act')(x)
        x = SeparableConv2D(728, 3, 3, border_mode='same', bias=False, name=prefix + '_sepconv2')(x)
        x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
        x = Activation('relu', name=prefix + '_sepconv3_act')(x)
        x = SeparableConv2D(728, 3, 3, border_mode='same', bias=False, name=prefix + '_sepconv3')(x)
        x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)

        x = merge([x, residual], mode='sum')

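    # Exit flow: block13 with a strided residual shortcut, then the final
    # 1536- and 2048-channel separable convolutions of block14.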
    residual = Conv2D(1024, 1, 1, subsample=(2, 2),
                      border_mode='same', bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block13_sepconv1_act')(x)
    x = SeparableConv2D(728, 3, 3, border_mode='same', bias=False, name='block13_sepconv1')(x)
    x = BatchNormalization(name='block13_sepconv1_bn')(x)
    x = Activation('relu', name='block13_sepconv2_act')(x)
    x = SeparableConv2D(1024, 3, 3, border_mode='same', bias=False, name='block13_sepconv2')(x)
    x = BatchNormalization(name='block13_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), border_mode='same', name='block13_pool')(x)
    x = merge([x, residual], mode='sum')

    x = SeparableConv2D(1536, 3, 3, border_mode='same', bias=False, name='block14_sepconv1')(x)
    x = BatchNormalization(name='block14_sepconv1_bn')(x)
    x = Activation('relu', name='block14_sepconv1_act')(x)

    x = SeparableConv2D(2048, 3, 3, border_mode='same', bias=False, name='block14_sepconv2')(x)
    x = BatchNormalization(name='block14_sepconv2_bn')(x)
    x = Activation('relu', name='block14_sepconv2_act')(x)

    if include_top:
        x = GlobalAveragePooling2D(name='avg_pool')(x)
        x = Dense(1000, activation='softmax', name='predictions')(x)

    # Create model
    model = Model(img_input, x)

    # load weights
    if weights == 'imagenet':
        if include_top:
            weights_path = get_file('xception_weights_tf_dim_ordering_tf_kernels.h5',
                                    TF_WEIGHTS_PATH,
                                    cache_subdir='models')
        else:
            weights_path = get_file('xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                    TF_WEIGHTS_PATH_NO_TOP,
                                    cache_subdir='models')
        model.load_weights(weights_path)

    if old_dim_ordering:
        K.set_image_dim_ordering(old_dim_ordering)
    return model

def preprocess_input(x):
    x /= 255.
    x -= 0.5
    x *= 2.
    return x
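
# The preprocessing above rescales raw pixel values from [0, 255] to [-1, 1]
# (e.g. 0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0), the same scheme used for Inception V3.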


if __name__ == '__main__':
    model = Xception(include_top=True, weights='imagenet')

    img_path = 'elephant.jpg'
    img = image.load_img(img_path, target_size=(299, 299))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    print('Input image shape:', x.shape)

    preds = model.predict(x)
    print('Predicted:', decode_predictions(preds, 1))
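
As a usage note beyond the `__main__` demo above, here is a minimal sketch (not part of this commit) of how the no-top variant could serve as a feature extractor; it assumes the same Keras 1-era API, that this file is importable as `xception`, and an `elephant.jpg` sample image:

from keras.preprocessing import image
import numpy as np

from xception import Xception, preprocess_input

# include_top=False builds the model with input shape (None, None, 3) and
# returns the last 2048-channel convolutional feature map instead of class scores.
base_model = Xception(include_top=False, weights='imagenet')

img = image.load_img('elephant.jpg', target_size=(299, 299))
x = np.expand_dims(image.img_to_array(img), axis=0)
x = preprocess_input(x)

features = base_model.predict(x)      # roughly (1, 10, 10, 2048) for 299x299 input
pooled = features.mean(axis=(1, 2))   # simple global average pooling -> (1, 2048)
print('Feature vector shape:', pooled.shape)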
