edflow.hooks.checkpoint_hooks.common module
Summary
Classes:
CollectorHook | Collects data.
KeepBestCheckpoints | Tries to find a metric for all checkpoints and keeps the n_keep best checkpoints as well as the latest checkpoint.
StoreArraysHook | Collects data over many steps, stacks it and then stores it.
WaitForCheckpointHook | Waits until a new checkpoint is created, then lets the Iterator continue.
Functions:
dict_repr | Makes a human-readable representation of a nested dict.
get_checkpoint_files | Return {global_step: [files,…]}.
get_latest_checkpoint | Return the path of the latest checkpoint in the checkpoint_root directory.
make_iterator | Make an iterator that yields key-value pairs.
strenumerate | Same as enumerate, but yields str(index).
test_valid_metrictuple | Checks whether all inputs of a MetricTuple are correct.
Reference
edflow.hooks.checkpoint_hooks.common.get_latest_checkpoint(checkpoint_root, filter_cond=<function <lambda>>)
Return the path of the latest checkpoint in the checkpoint_root directory.
Parameters
checkpoint_root (str) – Path to the directory where the checkpoints live.
filter_cond (Callable) – A function used to filter files so that only the wanted checkpoints are considered.
Returns
Path of the latest checkpoint. Note that for TensorFlow checkpoints this path is not an existing file itself; instead, the files path{.index,.meta,.data*} should exist.
Return type
str
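The selection logic described above can be illustrated with a minimal, standalone sketch. It is not edflow's implementation: the function name latest_checkpoint_sketch is hypothetical, "latest" is assumed to mean most recently modified, PyTorch-style checkpoints are assumed to end in ".ckpt", and TensorFlow checkpoints are recognised by their ".index" file, whose suffix is stripped so that the returned path is the common prefix (which, as noted, is not a file itself).

```python
import os
from typing import Callable, Optional


def latest_checkpoint_sketch(
    checkpoint_root: str,
    filter_cond: Callable[[str], bool] = lambda path: True,
) -> Optional[str]:
    """Hypothetical sketch: return the newest checkpoint path by mtime, or None."""
    candidates = []
    for name in os.listdir(checkpoint_root):
        path = os.path.join(checkpoint_root, name)
        if name.endswith(".index"):
            # Assumed TensorFlow layout: report the prefix shared by
            # prefix.index, prefix.meta and prefix.data*.
            prefix = path[: -len(".index")]
        elif name.endswith(".ckpt"):
            # Assumed single-file checkpoint format.
            prefix = path
        else:
            continue  # not a checkpoint under the assumed naming scheme
        if filter_cond(prefix):
            candidates.append((os.path.getmtime(path), prefix))
    if not candidates:
        return None
    # Tuples compare by modification time first.
    return max(candidates)[1]
```

Passing a filter_cond such as lambda p: "best" in p would restrict the search to matching checkpoints, mirroring the filter_cond parameter documented above.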
class edflow.hooks.checkpoint_hooks.common.WaitForCheckpointHook(checkpoint_root, filter_cond=<function WaitForCheckpointHook.<lambda>>, interval=5, add_sec=5, callback=None, eval_all=False)
Bases: edflow.hooks.hook.Hook
Waits until a new checkpoint is created, then lets the Iterator continue.
__init__(checkpoint_root, filter_cond=<function WaitForCheckpointHook.<lambda>>, interval=5, add_sec=5, callback=None, eval_all=False)
Parameters
checkpoint_root (str) – Path of the directory to search for checkpoints.
filter_cond (Callable) – A function used to filter files so that only the wanted checkpoints are considered.
interval (float) – Number of seconds to wait between checks for a new checkpoint.
add_sec (float) – Number of seconds to wait after a checkpoint is found, to avoid a race condition in which the checkpoint is read while it is still being written.
callback (Callable) – Callback called with the path of the found checkpoint.
eval_all (bool) – Accept every new checkpoint instead of only the latest one.
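The roles of interval, add_sec and callback can be seen in a small polling loop. This is a hedged sketch rather than the hook's actual code: wait_for_new_checkpoint is a hypothetical name, get_latest stands in for something like get_latest_checkpoint with a bound filter_cond, eval_all is not modelled, and sleep is injectable only so the loop can be exercised without real waiting.

```python
import time
from typing import Callable, Optional


def wait_for_new_checkpoint(
    get_latest: Callable[[], Optional[str]],
    last_seen: Optional[str],
    interval: float = 5,
    add_sec: float = 5,
    callback: Optional[Callable[[str], None]] = None,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Hypothetical sketch: block until a checkpoint other than last_seen appears."""
    while True:
        latest = get_latest()
        if latest is not None and latest != last_seen:
            # Give the writer add_sec extra seconds so we do not read a
            # checkpoint that is still being written.
            sleep(add_sec)
            if callback is not None:
                callback(latest)
            return latest
        # Nothing new yet: poll again after `interval` seconds.
        sleep(interval)
```

Polling keeps the sketch free of platform-specific file-watching APIs; the cost is a detection delay of up to interval seconds.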
edflow.hooks.checkpoint_hooks.common.strenumerate(*args, **kwargs)
Same as enumerate, but yields str(index).
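A minimal sketch consistent with the one-line description, forwarding all arguments (including start) to the built-in enumerate. The name strenumerate_sketch is hypothetical, and the real function may differ in detail.

```python
def strenumerate_sketch(*args, **kwargs):
    """Like enumerate, but the index is yielded as a string."""
    for index, value in enumerate(*args, **kwargs):
        yield str(index), value
```

String indices let list positions be used interchangeably with dict keys, for example when building names for nested values.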
edflow.hooks.checkpoint_hooks.common.make_iterator(list_or_dict)
Make an iterator that yields key-value pairs.
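One way to get key-value pairs from either a list or a dict is sketched below. It assumes, without confirmation from the source, that list elements are keyed by their index as a string (as strenumerate does); make_iterator_sketch is a hypothetical name.

```python
def make_iterator_sketch(list_or_dict):
    """Yield (key, value) pairs from a dict, or (str(index), value) from a list."""
    if isinstance(list_or_dict, dict):
        return iter(list_or_dict.items())
    # Assumption: sequences are keyed by their position, as a string.
    return ((str(index), value) for index, value in enumerate(list_or_dict))
```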
edflow.hooks.checkpoint_hooks.common.dict_repr(some_dict, pre='', level=0)
Makes a human-readable representation of a nested dict.
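The pre and level parameters suggest a recursive, indented rendering. The sketch below is a guess at that idea, not the library's output format: dict_repr_sketch is hypothetical, pre is assumed to prefix every line, and each nesting level is assumed to add two spaces of indentation.

```python
def dict_repr_sketch(some_dict, pre="", level=0):
    """Return a multi-line, indented string for a nested dict (assumed format)."""
    lines = []
    indent = pre + "  " * level
    for key, value in some_dict.items():
        if isinstance(value, dict):
            # Print the key, then recurse one level deeper for its children.
            lines.append(f"{indent}{key}:")
            nested = dict_repr_sketch(value, pre, level + 1)
            if nested:
                lines.append(nested)
        else:
            lines.append(f"{indent}{key}: {value}")
    return "\n".join(lines)
```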
class edflow.hooks.checkpoint_hooks.common.CollectorHook
Bases: edflow.hooks.hook.Hook
Collects data. Intended to be used as a base class.
class edflow.hooks.checkpoint_hooks.common.StoreArraysHook(save_root)
Bases: edflow.hooks.checkpoint_hooks.common.CollectorHook
Collects data over many steps, stacks it and then stores it.
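The collect, stack and store pattern can be sketched with NumPy. ArrayStoreSketch and its collect and store methods are hypothetical names that do not mirror the hook interface; the sketch also assumes one .npy file per key under save_root and stacks along a new leading axis (np.concatenate would be the alternative if every step already yields a batch).

```python
import os

import numpy as np


class ArrayStoreSketch:
    """Hypothetical stand-in for collecting per-step arrays and saving them."""

    def __init__(self, save_root):
        self.save_root = save_root
        self.collected = {}  # name -> list of per-step arrays

    def collect(self, results):
        # Append this step's values under their names.
        for name, value in results.items():
            self.collected.setdefault(name, []).append(np.asarray(value))

    def store(self):
        # Stack along a new leading axis (one entry per step), save as .npy.
        os.makedirs(self.save_root, exist_ok=True)
        paths = {}
        for name, values in self.collected.items():
            path = os.path.join(self.save_root, name + ".npy")
            np.save(path, np.stack(values))
            paths[name] = path
        return paths
```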
class edflow.hooks.checkpoint_hooks.common.MetricTuple(input_names, output_names, metric, name)
Bases: tuple
input_names – Alias for field number 0
output_names – Alias for field number 1
metric – Alias for field number 2
name – Alias for field number 3
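The "Alias for field number N" entries are what collections.namedtuple generates, so an equivalent definition is likely a namedtuple with these four fields in this order. The example values are illustrative only: mean_abs_error is a hypothetical metric, and the dict layout of input_names and output_names is an assumption.

```python
from collections import namedtuple

# Field order matches the documented aliases (field numbers 0 to 3).
MetricTuple = namedtuple("MetricTuple", ["input_names", "output_names", "metric", "name"])


def mean_abs_error(prediction, target):
    # Hypothetical metric used only for illustration.
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)


mt = MetricTuple(
    input_names={"prediction": "model_output", "target": "labels"},  # assumed layout
    output_names={"mae": "mae"},  # assumed layout
    metric=mean_abs_error,
    name="mae",
)
```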
edflow.hooks.checkpoint_hooks.common.test_valid_metrictuple(metric_tuple)
Checks whether all inputs of the given MetricTuple are correct.
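The source does not say which checks are performed. A plausible sketch validates the type of each field and raises ValueError on the first mismatch. validate_metric_tuple is a hypothetical name, and the expected types (dicts of strings for the name mappings, a callable metric, a str name) are assumptions.

```python
from collections import namedtuple

MetricTuple = namedtuple("MetricTuple", ["input_names", "output_names", "metric", "name"])


def validate_metric_tuple(metric_tuple):
    """Raise ValueError if a field does not have the assumed type; else return True."""
    for field in ("input_names", "output_names"):
        mapping = getattr(metric_tuple, field)
        if not isinstance(mapping, dict):
            raise ValueError(f"{field} must be a dict")
        if not all(isinstance(v, str) for v in mapping.values()):
            raise ValueError(f"all values of {field} must be strings")
    if not callable(metric_tuple.metric):
        raise ValueError("metric must be callable")
    if not isinstance(metric_tuple.name, str):
        raise ValueError("name must be a str")
    return True
```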
edflow.hooks.checkpoint_hooks.common.get_checkpoint_files(checkpoint_root)
Return a dict mapping each global step to the files of the checkpoint saved at that step, i.e. {global_step: [files,…]}.
Parameters
checkpoint_root (str) – Path to the directory where the checkpoints live.
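Grouping files by global step requires parsing the step from each file name. The sketch below assumes a naming scheme in which the step follows a "-" and is followed by "." or the end of the name (for example "model-1000.ckpt" or TensorFlow's "model.ckpt-1000.index"). Both this scheme and the name checkpoint_files_by_step are hypothetical.

```python
import os
import re
from collections import defaultdict

# Assumed naming scheme: "-<step>" followed by "." or the end of the name.
STEP_PATTERN = re.compile(r"-(\d+)(?:\.|$)")


def checkpoint_files_by_step(checkpoint_root):
    """Return {global_step: [files, ...]} for files whose names carry a step."""
    files = defaultdict(list)
    for name in sorted(os.listdir(checkpoint_root)):
        match = STEP_PATTERN.search(name)
        if match is None:
            continue  # not a checkpoint file under the assumed naming scheme
        files[int(match.group(1))].append(os.path.join(checkpoint_root, name))
    return dict(files)
```

Under this scheme the three files of one TensorFlow checkpoint (.index, .meta, .data-00000-of-00001) end up in the same list, because the search stops at the first "-<digits>." match.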
class edflow.hooks.checkpoint_hooks.common.KeepBestCheckpoints(checkpoint_root, metric_template, metric_key, n_keep=5, lower_is_better=True)
Bases: edflow.hooks.hook.Hook
Tries to find a metric for all checkpoints and keeps the n_keep best checkpoints as well as the latest checkpoint.
__init__(checkpoint_root, metric_template, metric_key, n_keep=5, lower_is_better=True)
Parameters
checkpoint_root (str) – Path of the directory to search for checkpoints.
metric_template (str) – Format string used to locate the metric file.
metric_key (str) – Key whose value is read from the metric file.
n_keep (int) – Maximum number of best checkpoints to keep; the latest checkpoint is kept in addition.
lower_is_better (bool) – Whether lower metric values count as better.
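The retention rule (keep the n_keep best checkpoints plus the latest one) reduces to a ranking over steps. The sketch below only decides which steps would be deleted. It is hypothetical: checkpoints_to_remove is not part of the module, it assumes metrics have already been read into a {global_step: value} dict, and it ignores checkpoints for which no metric could be found.

```python
def checkpoints_to_remove(metrics, n_keep=5, lower_is_better=True):
    """Return the sorted steps that fall outside the n_keep best and the latest."""
    if not metrics:
        return []
    # Best first: ascending if lower is better, descending otherwise.
    ranked = sorted(metrics, key=metrics.get, reverse=not lower_is_better)
    keep = set(ranked[:n_keep])
    keep.add(max(metrics))  # always keep the latest checkpoint
    return sorted(step for step in metrics if step not in keep)
```

For example, with metrics {100: 0.9, 200: 0.3, 300: 0.5, 400: 0.7}, n_keep=2 and lower_is_better=True, steps 200 and 300 are the best, 400 is the latest, and only step 100 would be removed.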