Skip to content

Commit dba7d41

Browse files
authored
VRAM Improvement Options (deepfakes#671)
* Implement ping-pong training * Disable tensorboard for pingpong training * Implement Memory Saving Gradients
1 parent 8456097 commit dba7d41

File tree

8 files changed

+592
-36
lines changed

8 files changed

+592
-36
lines changed

lib/cli.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,25 @@ def get_argument_list():
879879
"help": "Disables TensorBoard logging. NB: Disabling logs means "
880880
"that you will not be able to use the graph or analysis "
881881
"for this session in the GUI."})
882+
argument_list.append({"opts": ("-pp", "--ping-pong"),
883+
"action": "store_true",
884+
"dest": "pingpong",
885+
"default": False,
886+
"help": "Enable ping pong training. Trains one side at a time, "
887+
"switching sides at each save iteration. Training will take "
888+
"2 to 4 times longer, with about a 30%%-50%% reduction in "
889+
"VRAM usage. NB: Preview won't show until both sides have "
890+
"been trained once."})
891+
argument_list.append({"opts": ("-msg", "--memory-saving-gradients"),
892+
"action": "store_true",
893+
"dest": "memory_saving_gradients",
894+
"default": False,
895+
"help": "Trades off VRAM usage against computation time. Can fit "
896+
"larger models into memory at a cost of slower training "
897+
"speed. 50%%-150%% batch size increase for 20%%-50%% longer "
898+
"training time. NB: Launch time will be significantly "
899+
"delayed. Switching sides using ping-pong training will "
900+
"take longer."})
882901
argument_list.append({"opts": ("-wl", "--warp-to-landmarks"),
883902
"action": "store_true",
884903
"dest": "warp_to_landmarks",

lib/gui/display_command.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def display_item_set(self):
193193
session = get_config().session
194194
if session.initialized and session.logging_disabled:
195195
logger.trace("Logs disabled. Hiding graph")
196-
self.set_info("Graph is disabled as 'no-logs' has been selected")
196+
self.set_info("Graph is disabled as 'no-logs' or 'pingpong' has been selected")
197197
self.display_item = None
198198
elif session.initialized:
199199
logger.trace("Loading graph")

lib/gui/stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def iterations(self):
136136
@property
137137
def logging_disabled(self):
138138
""" Return whether logging is disabled for this session """
139-
return self.session["no_logs"]
139+
return self.session["no_logs"] or self.session["pingpong"]
140140

141141
@property
142142
def loss(self):

lib/model/memory_saving_gradients.py

Lines changed: 439 additions & 0 deletions
Large diffs are not rendered by default.

plugins/train/model/_base.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import keras
1515
from keras import losses
1616
from keras import backend as K
17-
from keras.models import load_model
17+
from keras.models import load_model, Model
1818
from keras.optimizers import Adam
1919
from keras.utils import get_custom_objects, multi_gpu_model
2020

@@ -42,12 +42,16 @@ def __init__(self,
4242
input_shape=None,
4343
encoder_dim=None,
4444
trainer="original",
45+
pingpong=False,
46+
memory_saving_gradients=False,
4547
predict=False):
46-
logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, "
48+
logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, no_logs: %s, "
4749
"training_image_size: %s, alignments_paths: %s, preview_scale: %s, "
48-
"input_shape: %s, encoder_dim: %s)", self.__class__.__name__, model_dir, gpus,
49-
training_image_size, alignments_paths, preview_scale, input_shape,
50-
encoder_dim)
50+
"input_shape: %s, encoder_dim: %s, trainer: %s, pingpong: %s, "
51+
"memory_saving_gradients: %s, predict: %s)",
52+
self.__class__.__name__, model_dir, gpus, no_logs, training_image_size,
53+
alignments_paths, preview_scale, input_shape, encoder_dim, trainer,
54+
pingpong, memory_saving_gradients, predict)
5155

5256
self.predict = predict
5357
self.model_dir = model_dir
@@ -60,7 +64,7 @@ def __init__(self,
6064
self.encoder_dim = encoder_dim
6165
self.trainer = trainer
6266

63-
self.state = State(self.model_dir, self.name, no_logs, training_image_size)
67+
self.state = State(self.model_dir, self.name, no_logs, pingpong, training_image_size)
6468
self.is_legacy = False
6569
self.rename_legacy()
6670
self.load_state_info()
@@ -74,8 +78,10 @@ def __init__(self,
7478
self.training_opts = {"alignments": alignments_paths,
7579
"preview_scaling": preview_scale / 100,
7680
"warp_to_landmarks": warp_to_landmarks,
77-
"no_flip": no_flip}
81+
"no_flip": no_flip,
82+
"pingpong": pingpong}
7883

84+
self.set_gradient_type(memory_saving_gradients)
7985
self.build()
8086
self.set_training_data()
8187
logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)
@@ -105,6 +111,15 @@ def models_exist(self):
105111
logger.debug("Pre-existing models exist: %s", retval)
106112
return retval
107113

114+
@staticmethod
115+
def set_gradient_type(memory_saving_gradients):
116+
""" Monkeypatch Memory Saving Gradients if requested """
117+
if not memory_saving_gradients:
118+
return
119+
logger.info("Using Memory Saving Gradients")
120+
from lib.model import memory_saving_gradients
121+
K.__dict__["gradients"] = memory_saving_gradients.gradients_memory
122+
108123
def set_training_data(self):
109124
""" Override to set model specific training data.
110125
@@ -132,7 +147,7 @@ def build(self):
132147
self.load_models(swapped=False)
133148
self.build_autoencoders()
134149
self.log_summary()
135-
self.compile_predictors()
150+
self.compile_predictors(initialize=True)
136151

137152
def build_autoencoders(self):
138153
""" Override for Model Specific autoencoder builds
@@ -215,24 +230,42 @@ def set_output_shape(self, model):
215230
self.output_shape = tuple(out[0])
216231
logger.debug("Added output shape: %s", self.output_shape)
217232

218-
def compile_predictors(self):
233+
def reset_pingpong(self):
234+
""" Reset the models for pingpong training """
235+
logger.debug("Resetting models")
236+
237+
# Clear models and graph
238+
self.predictors = dict()
239+
K.clear_session()
240+
241+
# Load Models for current training run
242+
for model in self.networks.values():
243+
model.network = Model.from_config(model.config)
244+
model.network.set_weights(model.weights)
245+
246+
self.build_autoencoders()
247+
self.compile_predictors(initialize=False)
248+
logger.debug("Reset models")
249+
250+
def compile_predictors(self, initialize=True):
219251
""" Compile the predictors """
220252
logger.debug("Compiling Predictors")
221253
optimizer = self.get_optimizer(lr=5e-5, beta_1=0.5, beta_2=0.999)
222254

223255
for side, model in self.predictors.items():
224256
loss_names = ["loss"]
225-
loss_funcs = [self.loss_function(side)]
257+
loss_funcs = [self.loss_function(side, initialize)]
226258
mask = [inp for inp in model.inputs if inp.name.startswith("mask")]
227259
if mask:
228260
loss_names.insert(0, "mask_loss")
229-
loss_funcs.insert(0, self.mask_loss_function(mask[0], side))
261+
loss_funcs.insert(0, self.mask_loss_function(mask[0], side, initialize))
230262
model.compile(optimizer=optimizer, loss=loss_funcs)
231263

232264
if len(loss_names) > 1:
233265
loss_names.insert(0, "total_loss")
234-
self.state.add_session_loss_names(side, loss_names)
235-
self.history[side] = list()
266+
if initialize:
267+
self.state.add_session_loss_names(side, loss_names)
268+
self.history[side] = list()
236269
logger.debug("Compiled Predictors. Losses: %s", loss_names)
237270

238271
def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999): # pylint: disable=invalid-name
@@ -250,24 +283,24 @@ def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999): # pylint: disable=i
250283
logger.debug("Optimizer kwargs: %s", opt_kwargs)
251284
return Adam(**opt_kwargs)
252285

253-
def loss_function(self, side):
286+
def loss_function(self, side, initialize):
254287
""" Set the loss function """
255288
if self.config.get("dssim_loss", False):
256-
if side == "a" and not self.predict:
289+
if side == "a" and not self.predict and initialize:
257290
logger.verbose("Using DSSIM Loss")
258291
loss_func = DSSIMObjective()
259292
else:
260-
if side == "a" and not self.predict:
293+
if side == "a" and not self.predict and initialize:
261294
logger.verbose("Using Mean Absolute Error Loss")
262295
loss_func = losses.mean_absolute_error
263296
logger.debug(loss_func)
264297
return loss_func
265298

266-
def mask_loss_function(self, mask, side):
299+
def mask_loss_function(self, mask, side, initialize):
267300
""" Set the loss function for masks
268301
Side is input so we only log once """
269302
if self.config.get("dssim_mask_loss", False):
270-
if side == "a" and not self.predict:
303+
if side == "a" and not self.predict and initialize:
271304
logger.verbose("Using DSSIM Loss for mask")
272305
mask_loss_func = DSSIMObjective()
273306
else:
@@ -276,7 +309,7 @@ def mask_loss_function(self, mask, side):
276309
mask_loss_func = losses.mean_absolute_error
277310

278311
if self.config.get("penalized_mask_loss", False):
279-
if side == "a" and not self.predict:
312+
if side == "a" and not self.predict and initialize:
280313
logger.verbose("Using Penalized Loss for mask")
281314
mask_loss_func = PenalizedLoss(mask, mask_loss_func)
282315
logger.debug(mask_loss_func)
@@ -329,7 +362,7 @@ def load_models(self, swapped):
329362

330363
if not self.models_exist and not self.predict:
331364
logger.info("Creating new '%s' model in folder: '%s'", self.name, self.model_dir)
332-
return
365+
return None
333366
if not self.models_exist and self.predict:
334367
logger.error("Model could not be found in folder '%s'. Exiting", self.model_dir)
335368
exit(0)
@@ -495,6 +528,8 @@ def __init__(self, filename, network_type, side, network):
495528
self.name = self.set_name()
496529
self.network = network
497530
self.network.name = self.name
531+
self.config = network.get_config() # For pingpong restore
532+
self.weights = network.get_weights() # For pingpong restore
498533
logger.debug("Initialized %s", self.__class__.__name__)
499534

500535
def set_name(self):
@@ -521,6 +556,7 @@ def load(self, fullpath=None):
521556
logger.warning("Failed loading existing training data. Generating new models")
522557
logger.debug("Exception: %s", str(err))
523558
return False
559+
self.config = network.get_config()
524560
self.network = network # Update network with saved model
525561
self.network.name = self.type
526562
return True
@@ -531,6 +567,7 @@ def save(self, fullpath=None, should_backup=False):
531567
if should_backup:
532568
self.backup(fullpath=fullpath)
533569
logger.debug("Saving model: '%s'", fullpath)
570+
self.weights = self.network.get_weights()
534571
self.network.save(fullpath)
535572

536573
def backup(self, fullpath=None):
@@ -553,10 +590,10 @@ def convert_legacy_weights(self):
553590

554591
class State():
555592
""" Class to hold the model's current state and autoencoder structure """
556-
def __init__(self, model_dir, model_name, no_logs, training_image_size):
593+
def __init__(self, model_dir, model_name, no_logs, pingpong, training_image_size):
557594
logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, "
558-
"training_image_size: '%s'", self.__class__.__name__, model_dir,
559-
model_name, no_logs, training_image_size)
595+
"pingpong: %s, training_image_size: '%s'", self.__class__.__name__, model_dir,
596+
model_name, no_logs, pingpong, training_image_size)
560597
self.serializer = Serializer.get_serializer("json")
561598
filename = "{}_state.{}".format(model_name, self.serializer.ext)
562599
self.filename = str(model_dir / filename)
@@ -570,7 +607,7 @@ def __init__(self, model_dir, model_name, no_logs, training_image_size):
570607
self.config = dict()
571608
self.load()
572609
self.session_id = self.new_session_id()
573-
self.create_new_session(no_logs)
610+
self.create_new_session(no_logs, pingpong)
574611
logger.debug("Initialized %s:", self.__class__.__name__)
575612

576613
@property
@@ -602,11 +639,12 @@ def new_session_id(self):
602639
logger.debug(session_id)
603640
return session_id
604641

605-
def create_new_session(self, no_logs):
642+
def create_new_session(self, no_logs, pingpong):
606643
""" Create a new session """
607644
logger.debug("Creating new session. id: %s", self.session_id)
608645
self.sessions[self.session_id] = {"timestamp": time.time(),
609646
"no_logs": no_logs,
647+
"pingpong": pingpong,
610648
"loss_names": dict(),
611649
"batchsize": 0,
612650
"iterations": 0}

0 commit comments

Comments
 (0)