edflow.hooks.checkpoint_hooks.common module

Summary

Classes:

CollectorHook

Collects data.

KeepBestCheckpoints

Tries to find a metric for all checkpoints and keeps the n_keep best checkpoints and the latest checkpoint.

MetricTuple

StoreArraysHook

Collects lots of data, stacks them and then stores them.

WaitForCheckpointHook

Waits until a new checkpoint is created, then lets the Iterator continue.

Functions:

dict_repr

Makes a nice representation of a nested dict.

get_checkpoint_files

Return {global_step: [files,…]}.

get_latest_checkpoint

Return the path of the latest checkpoint in the checkpoint_root directory.

make_iterator

Make an iterator that yields key value pairs.

strenumerate

Same as enumerate, but yields str(index).

test_valid_metrictuple

Checks if all inputs are correct.

tf_parse_global_step

torch_parse_global_step

Reference

edflow.hooks.checkpoint_hooks.common.get_latest_checkpoint(checkpoint_root, filter_cond=<function <lambda>>)[source]

Return the path of the latest checkpoint in the checkpoint_root directory.

Parameters
  • checkpoint_root (str) – Path to where the checkpoints live.

  • filter_cond (Callable) – A function used to filter files, to only get the checkpoints that are wanted.

Returns

Path of the latest checkpoint. Note that for TensorFlow checkpoints this is not an existing file, but the files path{.index,.meta,.data*} should exist.

Return type

str
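As an illustration only, a lookup of this kind could be sketched as follows. The helper name latest_file_sketch is hypothetical and not part of edflow, and the sketch assumes that "latest" means the most recently modified regular file directly inside checkpoint_root. The real function may select and name checkpoints differently.

```python
import os


def latest_file_sketch(checkpoint_root, filter_cond=lambda path: True):
    """Hypothetical helper: return the newest file in checkpoint_root, or None."""
    candidates = [
        os.path.join(checkpoint_root, name)
        for name in os.listdir(checkpoint_root)
    ]
    # Keep only regular files that pass the user-supplied filter.
    candidates = [p for p in candidates if os.path.isfile(p) and filter_cond(p)]
    if not candidates:
        return None
    # Assumption: the latest checkpoint is the one modified most recently.
    return max(candidates, key=os.path.getmtime)
```

Passing a filter_cond such as `lambda p: p.endswith(".ckpt")` would restrict the search to files with that suffix.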

class edflow.hooks.checkpoint_hooks.common.WaitForCheckpointHook(checkpoint_root, filter_cond=<function WaitForCheckpointHook.<lambda>>, interval=5, add_sec=5, callback=None, eval_all=False)[source]

Bases: edflow.hooks.hook.Hook

Waits until a new checkpoint is created, then lets the Iterator continue.

__init__(checkpoint_root, filter_cond=<function WaitForCheckpointHook.<lambda>>, interval=5, add_sec=5, callback=None, eval_all=False)[source]
Parameters
  • checkpoint_root (str) – Path to look for checkpoints.

  • filter_cond (Callable) – A function used to filter files, to only get the checkpoints that are wanted.

  • interval (float) – Number of seconds after which to check for a new checkpoint again.

  • add_sec (float) – Number of seconds to wait after a checkpoint is found. This avoids a race condition in which the checkpoint is still being written when it is meant to be read.

  • callback (Callable) – Callback called with path of found checkpoint.

  • eval_all (bool) – Accept all checkpoints instead of only the latest one.

fcond(c)[source]
look()[source]

Loop until a new checkpoint is found.
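The polling behaviour described by interval and add_sec might look roughly like the sketch below. It is not the edflow implementation: wait_for_new_file_sketch and the seen set are hypothetical, and picking the last new path in name order is an assumption standing in for "latest".

```python
import os
import time


def wait_for_new_file_sketch(checkpoint_root, seen, interval=5, add_sec=5,
                             filter_cond=lambda path: True):
    """Hypothetical helper: block until an unseen file appears, then return it."""
    while True:
        paths = sorted(
            os.path.join(checkpoint_root, name)
            for name in os.listdir(checkpoint_root)
        )
        new = [p for p in paths if p not in seen and filter_cond(p)]
        if new:
            # Give the writer time to finish before the file is read.
            time.sleep(add_sec)
            # Assumption: the last path in name order is treated as the newest.
            seen.add(new[-1])
            return new[-1]
        # Nothing new yet: check again after `interval` seconds.
        time.sleep(interval)
```

The caller keeps the seen set between calls so that a checkpoint is reported only once.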

before_epoch(ep)[source]

Called before each epoch.

Parameters

ep (int) – Index of epoch that just started.

edflow.hooks.checkpoint_hooks.common.strenumerate(*args, **kwargs)[source]

Same as enumerate, but yields str(index).
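A minimal sketch of the described behaviour, written as a standalone generator. The name strenumerate_sketch is hypothetical, and the real function may differ in detail.

```python
def strenumerate_sketch(*args, **kwargs):
    """Hypothetical helper: like enumerate, but the index is a string."""
    for index, value in enumerate(*args, **kwargs):
        yield str(index), value


pairs = list(strenumerate_sketch(["a", "b"]))
# pairs == [("0", "a"), ("1", "b")]
```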

edflow.hooks.checkpoint_hooks.common.make_iterator(list_or_dict)[source]

Make an iterator that yields key value pairs.
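One way such an iterator could be built is sketched below. The name make_iterator_sketch is hypothetical, and using integer list positions as keys is an assumption; the real function may yield different key types.

```python
def make_iterator_sketch(list_or_dict):
    """Hypothetical helper: iterate (key, value) pairs of a list or a dict."""
    if isinstance(list_or_dict, dict):
        return iter(list_or_dict.items())
    # Assumption: for lists, the position of each element acts as its key.
    return enumerate(list_or_dict)


for key, value in make_iterator_sketch({"loss": 0.5}):
    print(key, value)
```

This lets code that walks nested results treat lists and dicts uniformly.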

edflow.hooks.checkpoint_hooks.common.dict_repr(some_dict, pre='', level=0)[source]

Makes a nice representation of a nested dict.
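The output format of dict_repr is not specified here. As a sketch of what a representation of a nested dict could look like, the hypothetical helper below indents each nesting level by two spaces; both the name and the format are assumptions.

```python
def dict_repr_sketch(some_dict, pre="", level=0):
    """Hypothetical helper: render a nested dict as indented lines."""
    lines = []
    for key, value in some_dict.items():
        indent = pre + "  " * level
        if isinstance(value, dict):
            # Nested dicts get a header line and a deeper indentation level.
            lines.append(f"{indent}{key}:")
            lines.append(dict_repr_sketch(value, pre, level + 1))
        else:
            lines.append(f"{indent}{key}: {value}")
    return "\n".join(lines)


text = dict_repr_sketch({"a": 1, "b": {"c": 2}})
```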

class edflow.hooks.checkpoint_hooks.common.CollectorHook[source]

Bases: edflow.hooks.hook.Hook

Collects data. Intended to be used as a base class.

__init__()[source]


after_step(step, results)[source]

Called after each step.

Parameters
  • step (int) – Current training step.

  • results (list) – Results from last time this hook was called.

stack_results(new_data, all_data)[source]

Given the data collected so far, append the new results along the batch dimension.

Parameters
  • new_data (list or dict) – data to append.

  • all_data (list or dict) – data to append to.
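Appending along the batch dimension can be sketched with NumPy as follows. This is an illustration under assumptions: stack_results_sketch is a hypothetical name, both arguments are assumed to share the same nesting, and the batch is assumed to be the first axis of every leaf array.

```python
import numpy as np


def stack_results_sketch(new_data, all_data):
    """Hypothetical helper: append new_data to all_data along axis 0."""
    if isinstance(new_data, dict):
        return {k: stack_results_sketch(v, all_data[k]) for k, v in new_data.items()}
    if isinstance(new_data, (list, tuple)):
        return [stack_results_sketch(n, a) for n, a in zip(new_data, all_data)]
    # Leaf: assumed to be array-like with the batch on the first axis.
    return np.concatenate([np.asarray(all_data), np.asarray(new_data)], axis=0)


collected = {"x": np.zeros((2, 3))}
collected = stack_results_sketch({"x": np.ones((4, 3))}, collected)
# collected["x"] now has shape (6, 3)
```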

class edflow.hooks.checkpoint_hooks.common.StoreArraysHook(save_root)[source]

Bases: edflow.hooks.checkpoint_hooks.common.CollectorHook

Collects lots of data, stacks them and then stores them.

__init__(save_root)[source]

Collect all outputs of the step op and store them as an npz file.

after_epoch(epoch)[source]

Called after each epoch.

Parameters

epoch (int) – Index of epoch that just ended.

flatten_results(results, prefix, store_dict)[source]

Recursively walk over the results dictionary and stack the data.

Parameters
  • results (dict or list) – Containing results.

  • prefix (str) – Prepended to name when storing.

  • store_dict (dict) – Flat storage dictionary.
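The recursive walk can be illustrated with the sketch below, which only covers the flattening part and leaves out any stacking. The name flatten_results_sketch is hypothetical, and joining nested keys with an underscore is an assumption rather than documented behaviour.

```python
def flatten_results_sketch(results, prefix, store_dict):
    """Hypothetical helper: copy nested leaves into store_dict under flat names."""
    if isinstance(results, dict):
        items = results.items()
    elif isinstance(results, (list, tuple)):
        items = enumerate(results)
    else:
        # Leaf value: store it under the accumulated name.
        store_dict[prefix] = results
        return store_dict
    for key, value in items:
        # Assumption: nested keys are joined with an underscore.
        name = f"{prefix}_{key}" if prefix else str(key)
        flatten_results_sketch(value, name, store_dict)
    return store_dict


flat = flatten_results_sketch({"a": {"b": 1}, "c": [2, 3]}, "out", {})
# flat == {"out_a_b": 1, "out_c_0": 2, "out_c_1": 3}
```

A flat dictionary of this shape maps directly onto the named arrays of an npz archive.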

class edflow.hooks.checkpoint_hooks.common.MetricTuple(input_names, output_names, metric, name)

Bases: tuple

input_names

Alias for field number 0

metric

Alias for field number 2

name

Alias for field number 3

output_names

Alias for field number 1
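The field aliases above correspond to a named tuple with four fields. The sketch below rebuilds an equivalent type with collections.namedtuple so that it runs without edflow. The metric function and the example field values are made up for illustration, since the meaning of input_names and output_names is not documented here.

```python
from collections import namedtuple

# Field order follows the "Alias for field number" entries above.
MetricTupleSketch = namedtuple(
    "MetricTupleSketch", ["input_names", "output_names", "metric", "name"]
)


def mean_abs_error(target, prediction):
    """Illustrative metric: mean absolute difference of two sequences."""
    return sum(abs(t - p) for t, p in zip(target, prediction)) / len(target)


mt = MetricTupleSketch(
    input_names=["target"],        # hypothetical value
    output_names=["prediction"],   # hypothetical value
    metric=mean_abs_error,
    name="mae",
)
# Fields are reachable by name or by position: mt.metric is mt[2].
```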

edflow.hooks.checkpoint_hooks.common.test_valid_metrictuple(metric_tuple)[source]

Checks if all inputs are correct.

edflow.hooks.checkpoint_hooks.common.torch_parse_global_step(checkpoint)[source]
edflow.hooks.checkpoint_hooks.common.tf_parse_global_step(checkpoint)[source]
edflow.hooks.checkpoint_hooks.common.get_checkpoint_files(checkpoint_root)[source]

Return {global_step: [files,…]}.

Parameters

checkpoint_root (str) – Path to where the checkpoints live.
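The returned mapping groups files by training step. A rough sketch of such a grouping is given below. The helper group_by_step_sketch is hypothetical, and the file naming it expects is an assumption, not the parsing done by tf_parse_global_step or torch_parse_global_step.

```python
import os
import re
from collections import defaultdict


def group_by_step_sketch(checkpoint_root):
    """Hypothetical helper: map each parsed step to the files that carry it."""
    # Assumption: the step is an integer that follows a "-" and precedes a "."
    # or the end of the name, as in "model.ckpt-100.index" or "model-100.ckpt".
    pattern = re.compile(r"-(\d+)(?:\.|$)")
    files_by_step = defaultdict(list)
    for name in sorted(os.listdir(checkpoint_root)):
        match = pattern.search(name)
        if match is None:
            continue  # Not a checkpoint file under the assumed naming.
        step = int(match.group(1))
        files_by_step[step].append(os.path.join(checkpoint_root, name))
    return dict(files_by_step)
```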

class edflow.hooks.checkpoint_hooks.common.KeepBestCheckpoints(checkpoint_root, metric_template, metric_key, n_keep=5, lower_is_better=True)[source]

Bases: edflow.hooks.hook.Hook

Tries to find a metric for all checkpoints and keeps the n_keep best checkpoints and the latest checkpoint.

__init__(checkpoint_root, metric_template, metric_key, n_keep=5, lower_is_better=True)[source]
Parameters
  • checkpoint_root (str) – Path to look for checkpoints.

  • metric_template (str) – Format string to find metric file.

  • metric_key (str) – Key to use from metric file.

  • n_keep (int) – Maximum number of checkpoints to keep.

  • lower_is_better (bool) – Whether a lower value of the metric counts as better.

get_loss(step)[source]
after_epoch(ep)[source]

Called after each epoch.

Parameters

ep (int) – Index of epoch that just ended.
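The rule stated for KeepBestCheckpoints (keep the n_keep best checkpoints plus the latest one) can be sketched as a pure function over a step-to-metric mapping. The helper steps_to_keep_sketch is hypothetical. How the hook actually reads metric files and deletes checkpoints is not shown here.

```python
def steps_to_keep_sketch(metric_by_step, n_keep=5, lower_is_better=True):
    """Hypothetical helper: return the steps whose checkpoints should survive."""
    # Rank steps from best to worst metric value.
    ranked = sorted(
        metric_by_step, key=metric_by_step.get, reverse=not lower_is_better
    )
    keep = set(ranked[:n_keep])
    if metric_by_step:
        # Assumption: the highest step is the latest checkpoint; it is always kept.
        keep.add(max(metric_by_step))
    return sorted(keep)


metrics = {100: 0.9, 200: 0.5, 300: 0.7, 400: 0.8}
kept = steps_to_keep_sketch(metrics, n_keep=2)
# kept == [200, 300, 400]
```

Checkpoints at any step not in the returned list would be candidates for removal.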