import os
import tensorflow as tf
from edflow.hooks.hook import Hook
from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint
from edflow.custom_logging import get_logger
class RestoreModelHook(Hook):
    """Restores a TensorFlow model from a checkpoint at each epoch.

    Before every epoch the latest checkpoint found under ``checkpoint_path``
    (as selected by ``get_latest_checkpoint`` and ``filter_cond``) is
    restored. The hook can also be used as a functor: calling the instance
    with a checkpoint path restores exactly that checkpoint.
    """

    def __init__(
        self,
        variables,
        checkpoint_path,
        filter_cond=lambda c: True,
        global_step_setter=None,
    ):
        """
        Parameters
        ----------
        variables : list
            tf.Variable to be loaded from the checkpoint.
        checkpoint_path : str
            Directory in which the checkpoints are
            stored or explicit checkpoint. Ignored if used as functor.
        filter_cond : Callable
            A function used to filter files, to only get the checkpoints that
            are wanted. Ignored if used as functor.
        global_step_setter : Callable
            Callback to set global_step. Called with the step parsed from
            the checkpoint name after each restore; skipped if ``None``.
        """
        self.root = checkpoint_path
        self.fcond = filter_cond
        self.setstep = global_step_setter

        self.logger = get_logger(self)

        self.saver = tf.train.Saver(variables)

    @property
    def session(self):
        """Session used for restoring, resolved lazily.

        The default session is looked up on access and cached only once a
        real session has been found. Caching a ``None`` result (no default
        session active yet) would leave the hook without a session for the
        rest of its lifetime.
        """
        session = getattr(self, "_session", None)
        if session is None:
            session = tf.get_default_session()
            if session is not None:
                self._session = session
        return session

    def before_epoch(self, ep):
        """Restore the latest checkpoint found under ``self.root``.

        Parameters
        ----------
        ep :
            Epoch index passed by the caller; not used here.
        """
        checkpoint = get_latest_checkpoint(self.root, self.fcond)
        self(checkpoint)

    def __call__(self, checkpoint):
        """Restore ``checkpoint`` and propagate its global step.

        Parameters
        ----------
        checkpoint : str
            Checkpoint path handed to ``tf.train.Saver.restore``. The text
            after its last ``-`` must be the integer global step.
        """
        self.saver.restore(self.session, checkpoint)
        self.logger.info("Restored model from {}".format(checkpoint))

        global_step = self.parse_global_step(checkpoint)
        self.logger.info("Global step: {}".format(global_step))
        if self.setstep is not None:
            self.setstep(global_step)

    @staticmethod
    def parse_global_step(checkpoint):
        """Extract the global step from a checkpoint path.

        Parameters
        ----------
        checkpoint : str
            Checkpoint path whose suffix after the last ``-`` is the step,
            e.g. ``"model.ckpt-1000"``.

        Returns
        -------
        int
            The parsed global step.

        Raises
        ------
        IndexError
            If ``checkpoint`` contains no ``-``.
        ValueError
            If the text after the last ``-`` is not an integer.
        """
        global_step = int(checkpoint.rsplit("-", 1)[1])
        return global_step
# Alias of RestoreModelHook under a TensorFlow-specific name, kept for
# naming consistency.
# TODO: Make the restore op part of the model (issue #2):
# https://bitbucket.org/jhaux/edflow/issues/2/make-a-general-model-restore-hook
RestoreTFModelHook = RestoreModelHook
class CheckpointHook(Hook):
    """Saves the tracked variables to a checkpoint.

    A checkpoint is written after every epoch when ``interval`` is ``None``,
    otherwise whenever the global step is a multiple of ``interval``. A
    checkpoint is additionally written when an exception is reported via
    :meth:`at_exception`.
    """

    def __init__(
        self,
        root_path,
        variables,
        modelname="model",
        session=None,
        step=None,
        interval=None,
        max_to_keep=5,
    ):
        """
        Parameters
        ----------
        root_path : str
            Path to where the checkpoints are stored. Created if missing.
        variables : list
            List of all variables to keep track of.
        session : tf.Session
            Session instance for saver. If ``None``, the default session is
            looked up when it is first needed.
        modelname : str
            Used to name the checkpoint.
        step : tf.Tensor or callable
            Step op, that can be evaluated: i.e. a tf.Tensor or a python
            callable returning the step as an integer. Defaults to
            ``tf.train.get_global_step()``.
        interval : int
            Number of iterations after which a checkpoint is
            saved. If None, a checkpoint is saved after each epoch.
        max_to_keep : int
            Maximum number of checkpoints to keep on
            disk. Use 0 or None to never delete any checkpoints.
        """
        self.root = root_path
        self.interval = interval
        self.step = step if step is not None else tf.train.get_global_step()
        self.saver = tf.train.Saver(variables, max_to_keep=max_to_keep)
        self.logger = get_logger(self)

        os.makedirs(root_path, exist_ok=True)
        self.savename = os.path.join(root_path, "{}.ckpt".format(modelname))

        self.session = session

    def _ensure_session(self):
        """Return the session, falling back to the default session if unset."""
        if self.session is None:
            self.session = tf.get_default_session()
        return self.session

    def _step_value(self):
        """Return the current global step as a concrete number.

        A symbolic step (tf.Tensor / tf.Variable) is evaluated in the
        session; a callable step already yields the value.
        """
        step = self.global_step()
        if isinstance(step, (tf.Tensor, tf.Variable)):
            session = self._ensure_session()
            if session is None:
                raise RuntimeError(
                    "CheckpointHook needs a session to evaluate the global step."
                )
            return session.run(step)
        return step

    def before_epoch(self, ep):
        """Make sure a session is available before the epoch starts.

        Parameters
        ----------
        ep :
            Epoch index passed by the caller; not used here.
        """
        self._ensure_session()

    def after_epoch(self, epoch):
        """Save a checkpoint after the epoch if no step interval is set.

        Parameters
        ----------
        epoch :
            Epoch index passed by the caller; not used here.
        """
        if self.interval is None:
            self.save()

    def after_step(self, step, last_results):
        """Save a checkpoint every ``interval`` global steps.

        Parameters
        ----------
        step :
            Step index passed by the caller; the hook's own ``step`` source
            is used for the interval check instead.
        last_results :
            Results of the step; not used here.
        """
        if self.interval is None:
            return
        # A symbolic step cannot be compared against the interval directly
        # (in TF 1.x graph mode ``tensor == 0`` is not a value comparison),
        # so the check is done on the evaluated step.
        if self._step_value() % self.interval == 0:
            self.save()

    def at_exception(self, *args, **kwargs):
        """Save a checkpoint when an exception is reported.

        If no session is available (e.g. the failure happened before any
        session was created) nothing can be saved; this is logged instead of
        raising a second error on top of the original one.

        Parameters
        ----------
        *args :
            Ignored.
        **kwargs :
            Ignored.
        """
        if self._ensure_session() is None:
            self.logger.warning("No session available, skipping checkpoint.")
            return
        self.save()

    def save(self):
        """Write a checkpoint named after the current global step."""
        global_step = self.global_step()
        self.saver.save(
            self._ensure_session(), self.savename, global_step=global_step
        )
        self.logger.info("Saved model to {}".format(self.savename))

    def global_step(self):
        """Return the step source in the form accepted by the saver.

        Returns
        -------
        tf.Tensor, tf.Variable or int
            ``self.step`` itself when it is a tensor or variable, otherwise
            the result of calling ``self.step()``.
        """
        if isinstance(self.step, (tf.Tensor, tf.Variable)):
            return self.step
        return self.step()
class RetrainHook(Hook):
    """Resets the global step at the beginning of training."""

    def __init__(self, global_step=None):
        """
        Parameters
        ----------
        global_step : tf.Variable
            Variable tracking the training step.
        """
        self.global_step = global_step
        self.logger = get_logger(self)

    def before_epoch(self, epoch):
        """Remember the current epoch for the per-step checks.

        Parameters
        ----------
        epoch :
            Epoch index; stored as ``self.epoch``.
        """
        self.epoch = epoch

    def before_step(self, batch_index, fetches, feeds, batch):
        """Add the reset op to the fetches on the very first batch.

        Parameters
        ----------
        batch_index :
            Index of the batch about to be processed.
        fetches :
            Mapping of fetches; receives the entry ``"reset_step"`` on the
            first batch of epoch 0.
        feeds :
            Not used here.
        batch :
            Not used here.
        """
        is_first_batch = self.epoch == 0 and batch_index == 0
        if not is_first_batch:
            return

        # Assigning 0 to the step variable is what performs the reset.
        reset_op = tf.assign(self.global_step, 0)
        fetches["reset_step"] = reset_op

    def after_step(self, step, *args, **kwargs):
        """Log the reset once the first step of epoch 0 has finished.

        Parameters
        ----------
        step :
            Index of the step that just finished.
        *args :
            Ignored.
        **kwargs :
            Ignored.
        """
        if not (step == 0 and self.epoch == 0):
            return
        self.logger.info("Reset global_step")
class RestoreCurrentCheckpointHook(RestoreModelHook):
    """Restores a TensorFlow model from one given checkpoint at each epoch.

    Unlike the parent class, no lookup of the latest checkpoint takes
    place: the stored ``self.root`` is used as the checkpoint path itself.
    Can also be used as a functor.
    """

    def before_epoch(self, ep):
        """Restore the checkpoint stored in ``self.root``.

        Parameters
        ----------
        ep :
            Epoch index passed by the caller; not used here.
        """
        self(self.root)