edflow.hooks.checkpoint_hooks.tf_checkpoint_hook module

Summary

Classes:

CheckpointHook

Saves the tracked variables to a checkpoint, either every interval steps or after each epoch.

RestoreCurrentCheckpointHook

Restores a TensorFlow model from a checkpoint at each epoch.

RestoreModelHook

Restores a TensorFlow model from a checkpoint at each epoch.

RestoreTFModelHook

alias of edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.RestoreModelHook

RetrainHook

Resets the global step at the beginning of training.

WaitForManager

Waits until the number of checkpoints on disk no longer exceeds max_n.

Reference

class edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.RestoreModelHook(variables, checkpoint_path, filter_cond=<function RestoreModelHook.<lambda>>, global_step_setter=None)[source]

Bases: edflow.hooks.hook.Hook

Restores a TensorFlow model from a checkpoint at each epoch. Can also be used as a functor.

__init__(variables, checkpoint_path, filter_cond=<function RestoreModelHook.<lambda>>, global_step_setter=None)[source]
Parameters
  • variables (list) – List of tf.Variable objects to be loaded from the checkpoint.

  • checkpoint_path (str) – Directory in which the checkpoints are stored, or the path to an explicit checkpoint. Ignored if used as functor.

  • filter_cond (Callable) – A function used to filter files so that only the wanted checkpoints are considered. Ignored if used as functor.

  • global_step_setter (Callable) – Callback to set global_step.
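The role of filter_cond can be illustrated without TensorFlow. The following is a minimal sketch, not the edflow implementation: latest_checkpoint is a hypothetical helper, and the file naming scheme '<modelname>-<step>.ckpt' is an assumption made only for this example.

```python
import os
import tempfile


def latest_checkpoint(checkpoint_dir, filter_cond=lambda name: True):
    """Return the newest checkpoint file accepted by filter_cond, or None.

    Hypothetical helper: assumes files are named '<modelname>-<step>.ckpt'.
    """
    candidates = []
    for name in os.listdir(checkpoint_dir):
        if not name.endswith(".ckpt"):
            continue  # not a checkpoint file under the assumed naming scheme
        if not filter_cond(name):
            continue  # rejected by the user-supplied filter
        stem = name[: -len(".ckpt")]
        _, sep, step_text = stem.rpartition("-")
        if not sep or not step_text.isdigit():
            continue  # no trailing step number to order by
        candidates.append((int(step_text), name))
    if not candidates:
        return None
    # The tuple with the highest step number wins.
    return os.path.join(checkpoint_dir, max(candidates)[1])


def demo():
    """Create three empty files and pick the newest one starting with 'model'."""
    with tempfile.TemporaryDirectory() as root:
        for name in ("model-100.ckpt", "model-200.ckpt", "other-300.ckpt"):
            open(os.path.join(root, name), "w").close()
        chosen = latest_checkpoint(
            root, filter_cond=lambda n: n.startswith("model")
        )
        return os.path.basename(chosen)
```

In this sketch the filter runs on bare file names. Whether the real hook passes file names or full paths to filter_cond is not stated here and should be checked against the source.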

property session
before_epoch(ep)[source]
Parameters

ep

static parse_global_step(checkpoint)[source]
Parameters

checkpoint
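The reference gives no description for parse_global_step. As a rough illustration of what recovering a step from a checkpoint name could look like, here is a sketch under the assumption that the name ends in '-<step>' optionally followed by file extensions. The function parse_step_from_name is hypothetical and is not the library's method.

```python
import os
import re


def parse_step_from_name(checkpoint):
    """Extract the trailing step number from a checkpoint path.

    Hypothetical stand-in: assumes names such as 'model-4500.ckpt'.
    """
    name = os.path.basename(checkpoint)
    # A dash, the digits of the step, then any number of '.ext' suffixes.
    match = re.search(r"-(\d+)(?:\.\w+)*$", name)
    if match is None:
        raise ValueError("no step number found in {!r}".format(name))
    return int(match.group(1))
```

A value parsed this way is the kind of integer that a global_step_setter callback could receive after a restore.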

edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.RestoreTFModelHook

alias of edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.RestoreModelHook

class edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.CheckpointHook(root_path, variables, modelname='model', session=None, step=None, interval=None, max_to_keep=5)[source]

Bases: edflow.hooks.hook.Hook

Saves the tracked variables to a checkpoint, either every interval steps or after each epoch.

__init__(root_path, variables, modelname='model', session=None, step=None, interval=None, max_to_keep=5)[source]
Parameters
  • root_path (str) – Path to where the checkpoints are stored.

  • variables (list) – List of all variables to keep track of.

  • session (tf.Session) – Session instance for saver.

  • modelname (str) – Used to name the checkpoint.

  • step (tf.Tensor or callable) – Step op that can be evaluated, i.e. a tf.Tensor or a Python callable returning the step as an integer.

  • interval (int) – Number of iterations after which a checkpoint is saved. If None, a checkpoint is saved after each epoch.

  • max_to_keep (int) – Maximum number of checkpoints to keep on disk. Use 0 or None to never delete any checkpoints.
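How interval and max_to_keep interact can be sketched in plain Python. This is an illustration under stated assumptions, not the hook's code: CheckpointSchedule and simulate are hypothetical names, checkpoints are represented by strings instead of files, and saving is assumed to happen whenever the step is a multiple of interval.

```python
from collections import deque


class CheckpointSchedule:
    """Hypothetical model of the save and retention rules described above."""

    def __init__(self, interval=None, max_to_keep=5):
        self.interval = interval
        self.max_to_keep = max_to_keep
        self.kept = deque()  # checkpoints still on disk, oldest first
        self.deleted = []    # checkpoints removed by the retention rule

    def should_save_after_step(self, step):
        # With interval=None nothing is saved per step; the documentation
        # says saving then happens after each epoch, which is not modeled.
        return self.interval is not None and step % self.interval == 0

    def record(self, name):
        self.kept.append(name)
        # max_to_keep of 0 or None means: never delete anything.
        if self.max_to_keep:
            while len(self.kept) > self.max_to_keep:
                self.deleted.append(self.kept.popleft())


def simulate(num_steps, interval, max_to_keep):
    """Run steps 1..num_steps and return (kept, deleted) checkpoint names."""
    schedule = CheckpointSchedule(interval=interval, max_to_keep=max_to_keep)
    for step in range(1, num_steps + 1):
        if schedule.should_save_after_step(step):
            schedule.record("model-{}".format(step))
    return list(schedule.kept), schedule.deleted
```

For example, simulate(500, 100, 2) saves at steps 100 through 500 and retains only the two most recent checkpoints, while max_to_keep=None keeps all five.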

before_epoch(ep)[source]
Parameters

ep

after_epoch(epoch)[source]
Parameters

epoch

after_step(step, last_results)[source]
Parameters
  • step

  • last_results

at_exception(*args, **kwargs)[source]
Parameters
  • *args

  • **kwargs

save()[source]
global_step()[source]
class edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.RetrainHook(global_step=None)[source]

Bases: edflow.hooks.hook.Hook

Resets the global step at the beginning of training.

__init__(global_step=None)[source]
Parameters

global_step (tf.Variable) – Variable tracking the training step.

before_epoch(epoch)[source]
Parameters

epoch

before_step(batch_index, fetches, feeds, batch)[source]
Parameters
  • batch_index

  • fetches

  • feeds

  • batch

after_step(step, *args, **kwargs)[source]
Parameters
  • step

  • *args

  • **kwargs

class edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.WaitForManager(checkpoint_root, max_n, interval=5)[source]

Bases: edflow.hooks.hook.Hook

Waits until the number of checkpoints on disk no longer exceeds max_n.

__init__(checkpoint_root, max_n, interval=5)[source]
Parameters
  • checkpoint_root (str) – Path to look for checkpoints.

  • max_n (int) – Wait as long as there are more than max_n checkpoints.

  • interval (float) – Number of seconds to wait before checking the number of checkpoints again.

wait()[source]

Loop until the number of checkpoints has dropped to max_n or fewer.
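The waiting behaviour amounts to a polling loop. A minimal sketch, assuming the checkpoint count is available through a callable: wait_until_at_most and its count_checkpoints and sleep arguments are hypothetical and only serve to keep the example testable without touching the file system.

```python
import time


def wait_until_at_most(count_checkpoints, max_n, interval=5.0, sleep=time.sleep):
    """Block while count_checkpoints() reports more than max_n checkpoints.

    Returns the number of times the loop had to sleep.
    """
    polls = 0
    while count_checkpoints() > max_n:
        sleep(interval)  # wait before checking the count again
        polls += 1
    return polls


def demo():
    """Simulate another process deleting one checkpoint between polls."""
    counts = iter([5, 4, 3, 2])
    # A no-op sleep keeps the demonstration instantaneous.
    return wait_until_at_most(
        lambda: next(counts), max_n=2, interval=5.0, sleep=lambda s: None
    )
```

Passing the sleep function as an argument is a choice made for this sketch so that the loop can be exercised without real delays.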

before_epoch(ep)[source]
Parameters

ep

class edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.RestoreCurrentCheckpointHook(variables, checkpoint_path, filter_cond=<function RestoreModelHook.<lambda>>, global_step_setter=None)[source]

Bases: edflow.hooks.checkpoint_hooks.tf_checkpoint_hook.RestoreModelHook

Restores a TensorFlow model from a checkpoint at each epoch. Can also be used as a functor.

before_epoch(ep)[source]
Parameters

ep