14
14
15
15
run = wandb .init ()
16
16
config = run .config
17
- #fixed size for DenseNet121
17
+ # fixed size for DenseNet121
18
18
config .img_width = 224
19
19
config .img_height = 224
20
20
config .epochs = 10
21
- config .batch_size = 128
21
+ config .batch_size = 32
22
+
22
23
23
24
def setup_to_transfer_learn (model , base_model ):
24
25
"""Freeze all layers and compile the model"""
@@ -30,7 +31,9 @@ def setup_to_transfer_learn(model, base_model):
30
31
epsilon = None ,
31
32
decay = 0.0 )
32
33
sgd = SGD (lr = 0.001 , momentum = 0.9 )
33
- model .compile (optimizer = sgd , loss = 'categorical_crossentropy' , metrics = ['accuracy' ])
34
+ model .compile (optimizer = sgd , loss = 'categorical_crossentropy' ,
35
+ metrics = ['accuracy' ])
36
+
34
37
35
38
def add_new_last_layer (base_model , nb_classes , activation = 'softmax' ):
36
39
"""Add last layer to the convnet
@@ -43,13 +46,15 @@ def add_new_last_layer(base_model, nb_classes, activation='softmax'):
43
46
predictions = Dense (nb_classes , activation = activation )(base_model .output )
44
47
return Model (inputs = base_model .input , outputs = predictions )
45
48
49
+
46
50
train_dir = "dogcat-data/train"
47
51
val_dir = "dogcat-data/validation"
48
52
nb_train_samples = get_nb_files (train_dir )
49
53
nb_classes = len (glob .glob (train_dir + "/*" ))
50
54
nb_val_samples = get_nb_files (val_dir )
51
55
52
- train_generator , validation_generator = generators (preprocess_input , config .img_width , config .img_height , config .batch_size )
56
+ train_generator , validation_generator = generators (
57
+ preprocess_input , config .img_width , config .img_height , config .batch_size )
53
58
54
59
# setup model
55
60
base_model = DenseNet121 (input_shape = (config .img_width , config .img_height , 3 ),
@@ -69,7 +74,8 @@ def add_new_last_layer(base_model, nb_classes, activation='softmax'):
69
74
steps_per_epoch = nb_train_samples * 2 / config .batch_size ,
70
75
validation_data = validation_generator ,
71
76
validation_steps = nb_train_samples / config .batch_size ,
72
- callbacks = [WandbCallback (data_type = "image" , generator = validation_generator , labels = ['cat' , 'dog' ], save_model = False )],
77
+ callbacks = [WandbCallback (data_type = "image" , generator = validation_generator , labels = [
78
+ 'cat' , 'dog' ], save_model = False )],
73
79
class_weight = 'auto' )
74
80
75
81
model .save ("transfered.h5" )
0 commit comments