edflow.hooks.pytorch_hooks module

Summary

Classes:

DataPrepHook

The hook is needed in order to convert the input appropriately.

PyCheckpointHook

Periodically saves the model to checkpoint files: after each epoch and, optionally, also after a fixed number of training steps.

PyLoggingHook

Supplies and evaluates logging ops at a fixed interval of training steps.

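ToFromTorchHook

Combines ToTorchHook and ToNumpyHook: converts numpy arrays in the batch to torch tensors and converts torch tensors in the results back to numpy arrays.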

ToNumpyHook

Converts all PyTorch Variables and Tensors in the results to numpy arrays and leaves the rest as is.

ToTorchHook

Converts all numpy arrays in the batch to torch.Tensor objects and leaves the rest as is.

Reference

class edflow.hooks.pytorch_hooks.PyCheckpointHook(root_path, model, modelname='model', interval=None)[source]

Bases: edflow.hooks.hook.Hook

Periodically saves the model to checkpoint files: after each epoch and, optionally, also after a fixed number of training steps.

__init__(root_path, model, modelname='model', interval=None)[source]
Parameters
  • root_path (str) – Path to where the checkpoints are stored.

  • model (nn.Module) – Model to checkpoint.

  • modelname (str) – Prefix for checkpoint files.

  • interval (int) – Number of iterations after which a checkpoint is saved. In any case, a checkpoint is saved after each epoch.

before_epoch(epoch)[source]

Called before each epoch.

Parameters

epoch (int) – Index of epoch that just started.

after_epoch(epoch)[source]

Called after each epoch.

Parameters

epoch (int) – Index of epoch that just ended.

after_step(step, last_results)[source]

Called after each step.

Parameters
  • step (int) – Current training step.

  • last_results (list) – Results from last time this hook was called.
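The save schedule described above (every interval steps when interval is set, and always after an epoch) can be sketched in plain Python. This is a minimal sketch, not the edflow implementation: SketchCheckpointHook is a hypothetical name, a plain dict stands in for model.state_dict(), pickle stands in for torch.save, and the file name pattern "<modelname>-<step>.ckpt" is an assumption.

```python
import os
import pickle
import tempfile


class SketchCheckpointHook:
    """Hypothetical stand-in for PyCheckpointHook's save schedule."""

    def __init__(self, root_path, state, modelname="model", interval=None):
        self.root_path = root_path
        self.state = state  # stands in for model.state_dict()
        self.modelname = modelname
        self.interval = interval
        self.step = 0
        os.makedirs(root_path, exist_ok=True)

    def after_step(self, step, last_results):
        self.step = step
        # Save every `interval` steps; skipping step 0 is a choice of this sketch.
        if self.interval is not None and step > 0 and step % self.interval == 0:
            self.save()

    def after_epoch(self, epoch):
        # Regardless of `interval`, a checkpoint is written at the end of each epoch.
        self.save()

    def save(self):
        path = os.path.join(self.root_path, f"{self.modelname}-{self.step}.ckpt")
        with open(path, "wb") as f:
            pickle.dump(self.state, f)  # torch.save(model.state_dict(), path) in PyTorch
        return path


root = tempfile.mkdtemp()
hook = SketchCheckpointHook(root, {"w": [1.0, 2.0]}, interval=2)
for step in range(5):
    hook.after_step(step, last_results=[])
hook.after_epoch(0)  # overwrites model-4.ckpt, since the last step was 4
print(sorted(os.listdir(root)))  # ['model-2.ckpt', 'model-4.ckpt']
```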

at_exception(*args, **kwargs)[source]

Called when an exception is raised.

Parameters

exception – The exception that was raised.

Raises

The exception is raised again after it has been handled.
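The at_exception contract (handle the exception, then raise it again) can be illustrated with a small sketch. SketchExceptionHook and run_training are hypothetical names, and saving a checkpoint as the handling step is an assumption suggested by the hook's purpose rather than something stated in the reference above.

```python
class SketchExceptionHook:
    """Hypothetical hook that saves state before letting an exception propagate."""

    def __init__(self):
        self.saved = False

    def save(self):
        self.saved = True  # stands in for writing a checkpoint

    def at_exception(self, exception):
        self.save()      # handle the exception first ...
        raise exception  # ... then raise it again for the caller


def run_training(hook):
    try:
        raise RuntimeError("simulated failure during training")
    except Exception as e:
        hook.at_exception(e)


hook = SketchExceptionHook()
try:
    run_training(hook)
except RuntimeError as e:
    print(f"re-raised: {e}; checkpoint saved: {hook.saved}")
```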

save()[source]

Saves a checkpoint of the model.

class edflow.hooks.pytorch_hooks.PyLoggingHook(log_ops=[], scalar_keys=[], histogram_keys=[], image_keys=[], log_keys=[], graph=None, interval=100, root_path='logs')[source]

Bases: edflow.hooks.hook.Hook

Supplies and evaluates logging ops at a fixed interval of training steps.

__init__(log_ops=[], scalar_keys=[], histogram_keys=[], image_keys=[], log_keys=[], graph=None, interval=100, root_path='logs')[source]
Parameters
  • log_ops (list) – Ops to run at logging time.

  • scalar_keys (list) – Keys of the scalar values to log.

  • histogram_keys (list) – Keys of the values to log as histograms.

  • image_keys (list) – Keys of the images to log. Note that images are not logged via TensorBoard but with a custom image saver.

  • log_keys (list) – Keys of the values to log to stdout via the logger.

  • graph (tf.Graph) – Current graph.

  • interval (int) – Number of training steps between logging calls.

  • root_path (str) – Path at which the logs are stored.

before_step(batch_index, fetches, feeds, batch)[source]

Called before each step. Can update any feeds and fetches.

Parameters
  • batch_index (int) – Current training step.

  • fetches (list or dict) – Fetches for the next session.run call.

  • feeds (dict) – Data used at this step.

  • batch (list or dict) – All data available at this step.

after_step(batch_index, last_results)[source]

Called after each step.

Parameters
  • batch_index (int) – Current training step.

  • last_results (list) – Results from last time this hook was called.
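The interplay of before_step and after_step can be sketched as follows: logging ops are only requested on every interval-th step, and the listed keys are read back from the results of that step. SketchLoggingHook and run_step are hypothetical names, plain callables stand in for the logging ops, and the "logging" fetch key and the merging of op outputs into one dict are assumptions of this sketch.

```python
class SketchLoggingHook:
    """Hypothetical stand-in for PyLoggingHook's interval logic."""

    def __init__(self, log_ops, scalar_keys, interval=100):
        self.log_ops = log_ops          # callables mapping a batch to a dict of values
        self.scalar_keys = scalar_keys  # keys to pick out of the merged op outputs
        self.interval = interval
        self.logged = []                # (batch_index, {key: value}) records

    def before_step(self, batch_index, fetches, feeds, batch):
        # Only request the logging ops on every `interval`-th step.
        if batch_index % self.interval == 0:
            fetches["logging"] = list(self.log_ops)

    def after_step(self, batch_index, last_results):
        if batch_index % self.interval != 0:
            return
        merged = {}
        for values in last_results["logging"]:
            merged.update(values)
        self.logged.append((batch_index, {k: merged[k] for k in self.scalar_keys}))


def run_step(hook, batch_index, batch):
    fetches = {}
    hook.before_step(batch_index, fetches, feeds={}, batch=batch)
    # "Evaluate" every fetched op on the batch, mimicking a session.run call.
    results = {name: [op(batch) for op in ops] for name, ops in fetches.items()}
    hook.after_step(batch_index, results)


hook = SketchLoggingHook(
    log_ops=[lambda b: {"loss": sum(b) / len(b)}], scalar_keys=["loss"], interval=2
)
for i in range(5):
    run_step(hook, i, [i, i + 2])
print(hook.logged)  # [(0, {'loss': 1.0}), (2, {'loss': 3.0}), (4, {'loss': 5.0})]
```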

class edflow.hooks.pytorch_hooks.ToNumpyHook[source]

Bases: edflow.hooks.hook.Hook

Converts all PyTorch Variables and Tensors in the results to numpy arrays and leaves the rest as is.

after_step(step, results)[source]

Called after each step.

Parameters
  • step (int) – Current training step.

  • results (list) – Results from last time this hook was called.
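The recursive conversion described above can be sketched without importing torch by treating any object with a detach method as a tensor. The function name to_numpy and the duck-typing test are assumptions of this sketch; for a torch.Tensor, obj.detach().cpu().numpy() is the usual way to obtain a numpy array.

```python
def to_numpy(obj):
    """Recursively convert tensor-like leaves to numpy arrays, keeping the rest.

    Any object with a `detach` method (such as torch.Tensor) is treated as a
    tensor; dicts, lists and tuples are traversed, everything else is returned as is.
    """
    if isinstance(obj, dict):
        return {key: to_numpy(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_numpy(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(to_numpy(value) for value in obj)
    if hasattr(obj, "detach"):
        # Drop autograd history, move to host memory, then convert to numpy.
        return obj.detach().cpu().numpy()
    return obj
```

With PyTorch installed, to_numpy({"loss": torch.tensor(0.5), "step": 3}) returns a dict whose "loss" entry is a 0-d numpy array while "step" stays the int 3.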

class edflow.hooks.pytorch_hooks.ToTorchHook(push_to_gpu=True, dtype=torch.float)[source]

Bases: edflow.hooks.hook.Hook

Converts all numpy arrays in the batch to torch.Tensor objects and leaves the rest as is.

__init__(push_to_gpu=True, dtype=torch.float)[source]

Parameters
  • push_to_gpu (bool) – Whether to move the converted tensors to the GPU.

  • dtype (torch.dtype) – Data type of the created tensors.

before_step(step, fetches, feeds, batch)[source]

Called before each step. Can update any feeds and fetches.

Parameters
  • step (int) – Current training step.

  • fetches (list or dict) – Fetches for the next session.run call.

  • feeds (dict) – Data used at this step.

  • batch (list or dict) – All data available at this step.
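A sketch of the batch conversion: numpy arrays anywhere in a nested batch are converted, other values pass through unchanged. The names to_torch and convert are hypothetical; the default converter uses torch.as_tensor with the requested dtype and, as an assumption of this sketch, only moves data to the GPU when push_to_gpu is set and CUDA is available. torch is imported lazily so the traversal can be exercised with a custom convert function.

```python
import numpy as np


def to_torch(batch, push_to_gpu=True, dtype=None, convert=None):
    """Recursively replace numpy arrays in `batch` using `convert`."""
    if convert is None:
        import torch  # lazy import: only needed for the default converter

        target_dtype = torch.float if dtype is None else dtype
        device = "cuda" if push_to_gpu and torch.cuda.is_available() else "cpu"

        def convert(array):
            return torch.as_tensor(array, dtype=target_dtype, device=device)

    def walk(obj):
        if isinstance(obj, np.ndarray):
            return convert(obj)
        if isinstance(obj, dict):
            return {key: walk(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [walk(value) for value in obj]
        if isinstance(obj, tuple):
            return tuple(walk(value) for value in obj)
        return obj  # strings, ints and other values are left as is

    return walk(batch)
```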

class edflow.hooks.pytorch_hooks.ToFromTorchHook(*args, **kwargs)[source]

Bases: edflow.hooks.pytorch_hooks.ToNumpyHook, edflow.hooks.pytorch_hooks.ToTorchHook

__init__(*args, **kwargs)[source]

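Because ToFromTorchHook lists ToNumpyHook before ToTorchHook among its bases, Python's method resolution order takes after_step from ToNumpyHook and before_step from ToTorchHook, so a single hook converts inputs to torch and results back to numpy. The stub classes below (BaseHook, ToTorchSketch, ToNumpySketch, ToFromSketch) are hypothetical stand-ins that only record which conversion ran; they illustrate the inheritance pattern, not the real conversions.

```python
class BaseHook:
    """Minimal base with no-op callbacks (stand-in for edflow.hooks.hook.Hook)."""

    def before_step(self, step, fetches, feeds, batch):
        pass

    def after_step(self, step, results):
        pass


class ToTorchSketch(BaseHook):
    def before_step(self, step, fetches, feeds, batch):
        batch["converted_to"] = "torch"  # placeholder for the numpy -> torch conversion


class ToNumpySketch(BaseHook):
    def after_step(self, step, results):
        results["converted_to"] = "numpy"  # placeholder for the torch -> numpy conversion


class ToFromSketch(ToNumpySketch, ToTorchSketch):
    """MRO: ToFromSketch -> ToNumpySketch -> ToTorchSketch -> BaseHook -> object."""


hook = ToFromSketch()
batch, results = {}, {}
hook.before_step(0, {}, {}, batch)  # resolved to ToTorchSketch.before_step
hook.after_step(0, results)         # resolved to ToNumpySketch.after_step
print(batch, results)
```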

class edflow.hooks.pytorch_hooks.DataPrepHook(*args, **kwargs)[source]

Bases: edflow.hooks.pytorch_hooks.ToFromTorchHook

The hook is needed to convert the input appropriately. Here, the input is reshaped by appending 1 to its shape (for the number of channels of the image). It also converts the data to PyTorch tensors and back.

before_step(step, fetches, feeds, batch)[source]

Steps taken before the training step.

Parameters
  • step – Training step.

  • fetches – Fetches for the next session.run call.

  • feeds – Feeds for the next session.run call.

  • batch – The batch to be iterated over.

after_step(step, results)[source]

Steps taken after the training step.

Parameters
  • step – Training step.

  • results – Result of the session.
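The reshape performed by DataPrepHook amounts to appending a size-1 channel axis. A numpy sketch: add_channel_axis and prepare_batch are hypothetical helpers, and the batch key "image" is an assumed name, not taken from edflow.

```python
import numpy as np


def add_channel_axis(images):
    """Append a trailing axis of size 1, e.g. (N, H, W) -> (N, H, W, 1)."""
    return np.expand_dims(np.asarray(images), axis=-1)


def prepare_batch(batch, key="image"):
    """Return a copy of `batch` whose `key` entry has a channel axis appended."""
    prepared = dict(batch)
    prepared[key] = add_channel_axis(prepared[key])
    return prepared


images = np.zeros((8, 28, 28))
print(prepare_batch({"image": images, "label": 3})["image"].shape)  # (8, 28, 28, 1)
```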