14
14
import keras
15
15
from keras import losses
16
16
from keras import backend as K
17
- from keras .models import load_model
17
+ from keras .models import load_model , Model
18
18
from keras .optimizers import Adam
19
19
from keras .utils import get_custom_objects , multi_gpu_model
20
20
@@ -42,12 +42,16 @@ def __init__(self,
42
42
input_shape = None ,
43
43
encoder_dim = None ,
44
44
trainer = "original" ,
45
+ pingpong = False ,
46
+ memory_saving_gradients = False ,
45
47
predict = False ):
46
- logger .debug ("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, "
48
+ logger .debug ("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, no_logs: %s "
47
49
"training_image_size, %s, alignments_paths: %s, preview_scale: %s, "
48
- "input_shape: %s, encoder_dim: %s)" , self .__class__ .__name__ , model_dir , gpus ,
49
- training_image_size , alignments_paths , preview_scale , input_shape ,
50
- encoder_dim )
50
+ "input_shape: %s, encoder_dim: %s, trainer: %s, pingpong: %s, "
51
+ "memory_saving_gradients: %s, predict: %s)" ,
52
+ self .__class__ .__name__ , model_dir , gpus , no_logs , training_image_size ,
53
+ alignments_paths , preview_scale , input_shape , encoder_dim , trainer ,
54
+ pingpong , memory_saving_gradients , predict )
51
55
52
56
self .predict = predict
53
57
self .model_dir = model_dir
@@ -60,7 +64,7 @@ def __init__(self,
60
64
self .encoder_dim = encoder_dim
61
65
self .trainer = trainer
62
66
63
- self .state = State (self .model_dir , self .name , no_logs , training_image_size )
67
+ self .state = State (self .model_dir , self .name , no_logs , pingpong , training_image_size )
64
68
self .is_legacy = False
65
69
self .rename_legacy ()
66
70
self .load_state_info ()
@@ -74,8 +78,10 @@ def __init__(self,
74
78
self .training_opts = {"alignments" : alignments_paths ,
75
79
"preview_scaling" : preview_scale / 100 ,
76
80
"warp_to_landmarks" : warp_to_landmarks ,
77
- "no_flip" : no_flip }
81
+ "no_flip" : no_flip ,
82
+ "pingpong" : pingpong }
78
83
84
+ self .set_gradient_type (memory_saving_gradients )
79
85
self .build ()
80
86
self .set_training_data ()
81
87
logger .debug ("Initialized ModelBase (%s)" , self .__class__ .__name__ )
@@ -105,6 +111,15 @@ def models_exist(self):
105
111
logger .debug ("Pre-existing models exist: %s" , retval )
106
112
return retval
107
113
114
@staticmethod
def set_gradient_type(memory_saving_gradients):
    """ Monkeypatch Memory Saving Gradients if requested

    Parameters
    ----------
    memory_saving_gradients: bool
        When falsy nothing is changed. When truthy, the ``gradients`` entry of the Keras
        backend module (``K``) is replaced with ``gradients_memory`` from
        ``lib.model.memory_saving_gradients``.
    """
    if not memory_saving_gradients:
        return
    logger.info("Using Memory Saving Gradients")
    # Imported lazily so the module is only loaded when the option is enabled, and under an
    # alias so that the import does not rebind the ``memory_saving_gradients`` flag argument.
    from lib.model import memory_saving_gradients as msg_module
    # Written through the module dict, exactly as before, so anything that looks up
    # ``K.gradients`` after this point receives the memory saving implementation.
    K.__dict__["gradients"] = msg_module.gradients_memory
122
+
108
123
def set_training_data (self ):
109
124
""" Override to set model specific training data.
110
125
@@ -132,7 +147,7 @@ def build(self):
132
147
self .load_models (swapped = False )
133
148
self .build_autoencoders ()
134
149
self .log_summary ()
135
- self .compile_predictors ()
150
+ self .compile_predictors (initialize = True )
136
151
137
152
def build_autoencoders (self ):
138
153
""" Override for Model Specific autoencoder builds
@@ -215,24 +230,42 @@ def set_output_shape(self, model):
215
230
self .output_shape = tuple (out [0 ])
216
231
logger .debug ("Added output shape: %s" , self .output_shape )
217
232
218
- def compile_predictors (self ):
233
def reset_pingpong(self):
    """ Reset the models for pingpong training

    Discards the current predictors, clears the Keras backend session, recreates every
    stored network from its cached config and weights, then rebuilds the autoencoders and
    recompiles the predictors with ``initialize=False``.
    """
    logger.debug("Resetting models")

    # Drop the old predictors, then wipe the backend graph
    self.predictors = {}
    K.clear_session()

    # Recreate each network in the fresh session from its cached config and weights
    for nn_meta in self.networks.values():
        nn_meta.network = Model.from_config(nn_meta.config)
        nn_meta.network.set_weights(nn_meta.weights)

    self.build_autoencoders()
    self.compile_predictors(initialize=False)
    logger.debug("Reset models")
249
+
250
def compile_predictors(self, initialize=True):
    """ Compile the predictors

    Parameters
    ----------
    initialize: bool, optional
        Forwarded to the loss function selectors. When ``True`` the loss names for each
        side are also registered with the session state and the loss history for that side
        is reset to an empty list. Default: ``True``
    """
    logger.debug("Compiling Predictors")
    # A single optimizer instance is created and handed to every predictor
    optimizer = self.get_optimizer(lr=5e-5, beta_1=0.5, beta_2=0.999)

    for side, predictor in self.predictors.items():
        loss_names = ["loss"]
        loss_funcs = [self.loss_function(side, initialize)]

        # A predictor has a mask loss when one of its inputs is named "mask..."
        mask_inputs = [inp for inp in predictor.inputs if inp.name.startswith("mask")]
        if mask_inputs:
            # The mask loss is placed ahead of the face loss
            loss_names = ["mask_loss"] + loss_names
            loss_funcs = [self.mask_loss_function(mask_inputs[0],
                                                  side,
                                                  initialize)] + loss_funcs
        predictor.compile(optimizer=optimizer, loss=loss_funcs)

        if len(loss_names) > 1:
            loss_names = ["total_loss"] + loss_names
        if initialize:
            self.state.add_session_loss_names(side, loss_names)
            self.history[side] = []
    logger.debug("Compiled Predictors. Losses: %s", loss_names)
237
270
238
271
def get_optimizer (self , lr = 5e-5 , beta_1 = 0.5 , beta_2 = 0.999 ): # pylint: disable=invalid-name
@@ -250,24 +283,24 @@ def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999): # pylint: disable=i
250
283
logger .debug ("Optimizer kwargs: %s" , opt_kwargs )
251
284
return Adam (** opt_kwargs )
252
285
253
- def loss_function (self , side ):
286
def loss_function(self, side, initialize):
    """ Set the loss function

    Parameters
    ----------
    side: str
        The side being compiled. The choice of loss is only announced for side "a".
    initialize: bool
        The choice of loss is only announced when this is ``True``.

    Returns
    -------
    A ``DSSIMObjective`` instance when the "dssim_loss" config option is set, otherwise
    ``losses.mean_absolute_error``.
    """
    use_dssim = self.config.get("dssim_loss", False)
    # Announce the selection once only: for side "a", outside of predict mode, on
    # the initial compile
    announce = side == "a" and not self.predict and initialize

    if use_dssim:
        if announce:
            logger.verbose("Using DSSIM Loss")
        loss_func = DSSIMObjective()
    else:
        if announce:
            logger.verbose("Using Mean Absolute Error Loss")
        loss_func = losses.mean_absolute_error

    logger.debug(loss_func)
    return loss_func
265
298
266
- def mask_loss_function (self , mask , side ):
299
+ def mask_loss_function (self , mask , side , initialize ):
267
300
""" Set the loss function for masks
268
301
Side is input so we only log once """
269
302
if self .config .get ("dssim_mask_loss" , False ):
270
- if side == "a" and not self .predict :
303
+ if side == "a" and not self .predict and initialize :
271
304
logger .verbose ("Using DSSIM Loss for mask" )
272
305
mask_loss_func = DSSIMObjective ()
273
306
else :
@@ -276,7 +309,7 @@ def mask_loss_function(self, mask, side):
276
309
mask_loss_func = losses .mean_absolute_error
277
310
278
311
if self .config .get ("penalized_mask_loss" , False ):
279
- if side == "a" and not self .predict :
312
+ if side == "a" and not self .predict and initialize :
280
313
logger .verbose ("Using Penalized Loss for mask" )
281
314
mask_loss_func = PenalizedLoss (mask , mask_loss_func )
282
315
logger .debug (mask_loss_func )
@@ -329,7 +362,7 @@ def load_models(self, swapped):
329
362
330
363
if not self .models_exist and not self .predict :
331
364
logger .info ("Creating new '%s' model in folder: '%s'" , self .name , self .model_dir )
332
- return
365
+ return None
333
366
if not self .models_exist and self .predict :
334
367
logger .error ("Model could not be found in folder '%s'. Exiting" , self .model_dir )
335
368
exit (0 )
@@ -495,6 +528,8 @@ def __init__(self, filename, network_type, side, network):
495
528
self .name = self .set_name ()
496
529
self .network = network
497
530
self .network .name = self .name
531
+ self .config = network .get_config () # For pingpong restore
532
+ self .weights = network .get_weights () # For pingpong restore
498
533
logger .debug ("Initialized %s" , self .__class__ .__name__ )
499
534
500
535
def set_name (self ):
@@ -521,6 +556,7 @@ def load(self, fullpath=None):
521
556
logger .warning ("Failed loading existing training data. Generating new models" )
522
557
logger .debug ("Exception: %s" , str (err ))
523
558
return False
559
+ self .config = network .get_config ()
524
560
self .network = network # Update network with saved model
525
561
self .network .name = self .type
526
562
return True
@@ -531,6 +567,7 @@ def save(self, fullpath=None, should_backup=False):
531
567
if should_backup :
532
568
self .backup (fullpath = fullpath )
533
569
logger .debug ("Saving model: '%s'" , fullpath )
570
+ self .weights = self .network .get_weights ()
534
571
self .network .save (fullpath )
535
572
536
573
def backup (self , fullpath = None ):
@@ -553,10 +590,10 @@ def convert_legacy_weights(self):
553
590
554
591
class State ():
555
592
""" Class to hold the model's current state and autoencoder structure """
556
- def __init__ (self , model_dir , model_name , no_logs , training_image_size ):
593
+ def __init__ (self , model_dir , model_name , no_logs , pingpong , training_image_size ):
557
594
logger .debug ("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, "
558
- "training_image_size: '%s'" , self .__class__ .__name__ , model_dir ,
559
- model_name , no_logs , training_image_size )
595
+ "pingpong: %s, training_image_size: '%s'" , self .__class__ .__name__ , model_dir ,
596
+ model_name , no_logs , pingpong , training_image_size )
560
597
self .serializer = Serializer .get_serializer ("json" )
561
598
filename = "{}_state.{}" .format (model_name , self .serializer .ext )
562
599
self .filename = str (model_dir / filename )
@@ -570,7 +607,7 @@ def __init__(self, model_dir, model_name, no_logs, training_image_size):
570
607
self .config = dict ()
571
608
self .load ()
572
609
self .session_id = self .new_session_id ()
573
- self .create_new_session (no_logs )
610
+ self .create_new_session (no_logs , pingpong )
574
611
logger .debug ("Initialized %s:" , self .__class__ .__name__ )
575
612
576
613
@property
@@ -602,11 +639,12 @@ def new_session_id(self):
602
639
logger .debug (session_id )
603
640
return session_id
604
641
605
- def create_new_session (self , no_logs ):
642
+ def create_new_session (self , no_logs , pingpong ):
606
643
""" Create a new session """
607
644
logger .debug ("Creating new session. id: %s" , self .session_id )
608
645
self .sessions [self .session_id ] = {"timestamp" : time .time (),
609
646
"no_logs" : no_logs ,
647
+ "pingpong" : pingpong ,
610
648
"loss_names" : dict (),
611
649
"batchsize" : 0 ,
612
650
"iterations" : 0 }
0 commit comments