21
21
# download the data if it doesn't exist
22
22
if not os .path .exists ("simpsons" ):
23
23
print ("Downloading Simpsons dataset..." )
24
- subprocess .check_output ("curl https://storage.googleapis.com/wandb-production.appspot.com/mlclass/simpsons.tar.gz | tar xvz" , shell = True )
24
+ subprocess .check_output (
25
+ "curl https://storage.googleapis.com/wandb-production.appspot.com/mlclass/simpsons.tar.gz | tar xvz" , shell = True )
25
26
26
27
# this is the augmentation configuration we will use for training
27
28
# see: https://keras.io/preprocessing/image/#imagedatagenerator-class
36
37
# batches of augmented image data
37
38
train_generator = train_datagen .flow_from_directory (
38
39
'simpsons/train' , # this is the target directory
39
- target_size = (config .img_size ,config .img_size ),
40
+ target_size = (config .img_size , config .img_size ),
40
41
batch_size = config .batch_size )
41
42
42
43
# this is a similar generator, for validation data
43
44
test_generator = test_datagen .flow_from_directory (
44
- 'simpsons/test' ,
45
- target_size = (config .img_size ,config .img_size ),
46
- batch_size = config .batch_size )
45
+ 'simpsons/test' ,
46
+ target_size = (config .img_size , config .img_size ),
47
+ batch_size = config .batch_size )
47
48
48
49
labels = list (test_generator .class_indices .keys ())
49
50
50
51
model = Sequential ()
51
- model .add (Conv2D (16 , (3 ,3 ), input_shape = (config .img_size , config .img_size , 3 ), activation = "relu" ))
52
+ model .add (Conv2D (16 , (3 , 3 ), input_shape = (
53
+ config .img_size , config .img_size , 3 ), activation = "relu" ))
52
54
model .add (MaxPooling2D ())
53
55
model .add (Flatten ())
54
56
model .add (Dropout (0.4 ))
57
59
loss = 'categorical_crossentropy' , metrics = ['accuracy' ])
58
60
59
61
model .fit_generator (
60
- train_generator ,
61
- steps_per_epoch = len (train_generator ),
62
- epochs = config .epochs ,
63
- workers = 4 ,
64
- validation_data = test_generator ,
65
- callbacks = [WandbCallback (data_type = "image" , labels = labels , generator = test_generator , save_model = False )],
66
- validation_steps = len (test_generator ))
62
+ train_generator ,
63
+ steps_per_epoch = len (train_generator ),
64
+ epochs = config .epochs ,
65
+ workers = 4 ,
66
+ validation_data = test_generator ,
67
+ callbacks = [WandbCallback (
68
+ data_type = "image" , labels = labels , generator = test_generator , save_model = False )],
69
+ validation_steps = len (test_generator ))
0 commit comments