edflow.hooks.checkpoint_hooks.torch_checkpoint_hook module

Summary

Classes:

RestorePytorchModelHook

Restores a PyTorch model from a checkpoint at each epoch.

Reference

class edflow.hooks.checkpoint_hooks.torch_checkpoint_hook.RestorePytorchModelHook(model, checkpoint_path, filter_cond=<function RestorePytorchModelHook.<lambda>>, global_step_setter=None)

Bases: edflow.hooks.hook.Hook

Restores a PyTorch model from a checkpoint at each epoch. It can also be used as a functor.

__init__(model, checkpoint_path, filter_cond=<function RestorePytorchModelHook.<lambda>>, global_step_setter=None)
Parameters
  • model (torch.nn.Module) – Model to initialize.

  • checkpoint_path (str) – Directory in which the checkpoints are stored, or the path to an explicit checkpoint. Ignored when used as a functor.

  • filter_cond (Callable) – Function used to filter files so that only the wanted checkpoints are considered. Ignored when used as a functor.

  • global_step_setter (Callable) – Function to which the retrieved global step is passed.
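The interplay of checkpoint_path and filter_cond can be illustrated with a small, self-contained sketch. The helper select_checkpoint below is hypothetical and not part of edflow; it assumes that an explicit file path is used as-is, and that for a directory the "latest" matching checkpoint is the most recently modified file that passes filter_cond. edflow's own lookup logic may differ.

```python
import os
from typing import Callable, Optional


def select_checkpoint(
    checkpoint_path: str,
    filter_cond: Callable[[str], bool] = lambda path: True,
) -> Optional[str]:
    """Return the checkpoint to restore (hypothetical helper, not edflow API).

    An explicit file path is returned unchanged. For a directory, all files
    are filtered with ``filter_cond`` and the most recently modified match
    is returned, or None if nothing matches.
    """
    if os.path.isfile(checkpoint_path):
        # Explicit checkpoint: no filtering is applied.
        return checkpoint_path
    candidates = [
        os.path.join(checkpoint_path, name)
        for name in os.listdir(checkpoint_path)
    ]
    # Keep only regular files that the filter accepts.
    candidates = [
        path for path in candidates
        if os.path.isfile(path) and filter_cond(path)
    ]
    if not candidates:
        return None
    # Assumption: "latest" means the newest modification time.
    return max(candidates, key=os.path.getmtime)
```

For example, passing filter_cond=lambda p: p.endswith(".ckpt") would ignore unrelated files such as logs stored in the same directory.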

before_epoch(ep)

Called before each epoch.

Parameters

ep (int) – Index of the epoch that just started.
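The relationship between the hook usage (before_epoch) and the functor usage can be sketched with a minimal stand-in class. RestoreModelSketch and its find_checkpoint, load and parse_step parameters are hypothetical and injected only so the control flow can be shown without torch or edflow; with real PyTorch, load would typically be torch.load, and model.load_state_dict is the standard torch.nn.Module method. This is not the edflow implementation.

```python
from typing import Any, Callable, Optional


class RestoreModelSketch:
    """Hypothetical stand-in illustrating the hook/functor pattern."""

    def __init__(
        self,
        model: Any,
        find_checkpoint: Callable[[], str],
        load: Callable[[str], Any],
        parse_step: Callable[[str], int],
        global_step_setter: Optional[Callable[[int], None]] = None,
    ):
        self.model = model
        self.find_checkpoint = find_checkpoint  # e.g. a directory lookup
        self.load = load                        # e.g. torch.load with real PyTorch
        self.parse_step = parse_step            # extracts the global step
        self.global_step_setter = global_step_setter

    def before_epoch(self, ep: int) -> None:
        # Hook usage: locate a checkpoint, then reuse the functor path.
        self(self.find_checkpoint())

    def __call__(self, checkpoint: str) -> None:
        # Functor usage: restore from the given checkpoint path.
        self.model.load_state_dict(self.load(checkpoint))
        if self.global_step_setter is not None:
            self.global_step_setter(self.parse_step(checkpoint))
```

Routing before_epoch through __call__ keeps the restore logic in one place, which is consistent with checkpoint_path and filter_cond being ignored when the object is called directly.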

static parse_global_step(checkpoint)
static parse_checkpoint(checkpoint)
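This reference does not document the checkpoint naming format that these static methods expect. As an illustration only, the sketch below parses a global step from a filename under an assumed scheme "<prefix>-<global_step>.<extension>" (for example model-1500.ckpt). Both the scheme and the name parse_global_step_sketch are hypothetical and may not match edflow's actual format.

```python
import os
import re

# Assumed (hypothetical) naming scheme: "<prefix>-<global_step>.<extension>",
# e.g. "model-1500.ckpt". edflow's actual format may differ.
_STEP_PATTERN = re.compile(r"-(\d+)\.[^.]+$")


def parse_global_step_sketch(checkpoint: str) -> int:
    """Extract the global step from a checkpoint path under the assumed scheme."""
    # Only the file name matters; directory components are ignored.
    match = _STEP_PATTERN.search(os.path.basename(checkpoint))
    if match is None:
        raise ValueError(f"cannot parse global step from {checkpoint!r}")
    return int(match.group(1))
```

A parser like this could serve as the value handed to global_step_setter after a restore, so that training resumes from the step encoded in the checkpoint name.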