edflow.hooks.hook module

Summary

Classes:

Hook

Base Hook to be inherited from.

Reference

class edflow.hooks.hook.Hook

Bases: object

Base Hook to be inherited from. Hooks can be passed to HookedModelIterator and will be called at fixed intervals.

An inheriting class only needs to override the methods below that are of interest to it.

In principle, a hook can do anything when it is called. It is intended as a mechanism for updating the standard fetches and feeds that are passed to the session (managed, e.g., by a HookedModelIterator) and for working with the results of the session's run call.

Assuming a single hook is passed to a HookedModelIterator, its methods are called in the following order:

for epoch in epochs:
    hook.before_epoch(epoch)
    for i, batch in enumerate(batches):
        fetches, feeds = some_function(batch)
        hook.before_step(i, fetches, feeds, batch)  # may modify fetches and feeds

        results = session.run(fetches, feed_dict=feeds)

        hook.after_step(i, results)
    hook.after_epoch(epoch)
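
To see the call order in practice, the loop above can be turned into a small runnable simulation. This is a sketch, not edflow's HookedModelIterator: `run_hooked_loop`, `fake_run` (a stand-in for `session.run`) and `RecordingHook` are hypothetical names, and the base class falls back to `object` when edflow is not installed.

```python
try:
    from edflow.hooks.hook import Hook  # the base class documented here
except ImportError:
    Hook = object  # fallback so the sketch runs without edflow installed


class RecordingHook(Hook):
    """Hypothetical hook that records every call so the order can be inspected."""

    def __init__(self):
        self.calls = []

    def before_epoch(self, epoch):
        self.calls.append(("before_epoch", epoch))

    def before_step(self, step, fetches, feeds, batch):
        self.calls.append(("before_step", step))

    def after_step(self, step, last_results):
        self.calls.append(("after_step", step))

    def after_epoch(self, epoch):
        self.calls.append(("after_epoch", epoch))


def fake_run(fetches, feed_dict):
    # Hypothetical stand-in for session.run: evaluate each fetch on the feeds.
    return {name: fn(feed_dict) for name, fn in fetches.items()}


def run_hooked_loop(hooks, num_epochs, batches):
    # Hypothetical driver mirroring the loop shown above.
    for epoch in range(num_epochs):
        for hook in hooks:
            hook.before_epoch(epoch)
        for i, batch in enumerate(batches):
            fetches = {"total": lambda feeds: sum(feeds["x"])}
            feeds = {"x": batch}
            for hook in hooks:
                hook.before_step(i, fetches, feeds, batch)
            results = fake_run(fetches, feed_dict=feeds)
            for hook in hooks:
                hook.after_step(i, results)
        for hook in hooks:
            hook.after_epoch(epoch)


hook = RecordingHook()
run_hooked_loop([hook], num_epochs=1, batches=[[1, 2], [3, 4]])
print(hook.calls)
# [('before_epoch', 0), ('before_step', 0), ('after_step', 0),
#  ('before_step', 1), ('after_step', 1), ('after_epoch', 0)]
```

Every hook receives the same `fetches` and `feeds` objects in `before_step`, which is what allows a hook to change them before the run call.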

before_epoch(epoch)

Called before each epoch.

Parameters

epoch (int) – Index of the epoch that is about to start.

before_step(step, fetches, feeds, batch)

Called before each step. Can modify the fetches and feeds for the upcoming session.run call.

Parameters
  • step (int) – Current training step.

  • fetches (list or dict) – Fetches for the next session.run call.

  • feeds (dict) – Data used at this step.

  • batch (list or dict) – All data available at this step.
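
As an illustration of before_step, the hypothetical hook below computes a step-decayed learning rate, writes it into `feeds` and requests an extra entry in `fetches`. It assumes both are dicts that the caller passes on unchanged to `session.run`, as in the loop above; the key names are placeholders for what would be graph tensors in a real TensorFlow model.

```python
try:
    from edflow.hooks.hook import Hook  # the base class documented here
except ImportError:
    Hook = object  # fallback so the sketch runs without edflow installed


class LearningRateHook(Hook):
    """Hypothetical hook: feeds a step-decayed learning rate each step."""

    def __init__(self, base_lr=0.1, decay=0.5, every=100):
        self.base_lr = base_lr
        self.decay = decay
        self.every = every

    def before_step(self, step, fetches, feeds, batch):
        # Multiply the rate by `decay` once every `every` steps.
        lr = self.base_lr * self.decay ** (step // self.every)
        # Mutate the dicts in place so the caller's run call sees the change.
        # "learning_rate" stands in for a placeholder tensor in a real graph.
        feeds["learning_rate"] = lr
        # Also ask the run call to return the value under the key "lr".
        fetches["lr"] = "learning_rate"


hook = LearningRateHook()
fetches, feeds = {}, {}
hook.before_step(250, fetches, feeds, batch=None)
print(feeds)    # {'learning_rate': 0.025}
print(fetches)  # {'lr': 'learning_rate'}
```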

after_step(step, last_results)

Called after each step.

Parameters
  • step (int) – Current training step.

  • last_results (list or dict) – Results of the session.run call made in this step; they have the same structure as the fetches.
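
A minimal sketch of an after_step hook that averages one entry of the step results. It assumes `last_results` is a dict holding a scalar under a hypothetical `"loss"` key; with list fetches the lookup would be by index instead.

```python
try:
    from edflow.hooks.hook import Hook  # the base class documented here
except ImportError:
    Hook = object  # fallback so the sketch runs without edflow installed


class RunningMeanHook(Hook):
    """Hypothetical hook: tracks the mean of one entry of the step results."""

    def __init__(self, key="loss"):
        self.key = key  # assumed key in the results dict
        self.total = 0.0
        self.count = 0

    def after_step(self, step, last_results):
        # Accumulate the scalar stored under self.key.
        self.total += float(last_results[self.key])
        self.count += 1

    @property
    def mean(self):
        # NaN until at least one step has been seen.
        return self.total / self.count if self.count else float("nan")


hook = RunningMeanHook()
for step, loss in enumerate([4.0, 2.0, 3.0]):
    hook.after_step(step, {"loss": loss})
print(hook.mean)  # 3.0
```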

after_epoch(epoch)

Called after each epoch.

Parameters

epoch (int) – Index of the epoch that just ended.
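
The epoch callbacks combine naturally with after_step. The hypothetical hook below resets a counter in `before_epoch`, increments it after every step and stores the total in `after_epoch`; the small driver loop at the end only imitates the call order described above.

```python
try:
    from edflow.hooks.hook import Hook  # the base class documented here
except ImportError:
    Hook = object  # fallback so the sketch runs without edflow installed


class StepsPerEpochHook(Hook):
    """Hypothetical hook: counts how many steps ran in each epoch."""

    def __init__(self):
        self.steps_per_epoch = {}
        self._count = 0

    def before_epoch(self, epoch):
        # Reset the counter at the start of every epoch.
        self._count = 0

    def after_step(self, step, last_results):
        self._count += 1

    def after_epoch(self, epoch):
        # Store the count under the index of the epoch that just ended.
        self.steps_per_epoch[epoch] = self._count


hook = StepsPerEpochHook()
for epoch, n_steps in enumerate([3, 2]):
    hook.before_epoch(epoch)
    for step in range(n_steps):
        hook.after_step(step, last_results={})
    hook.after_epoch(epoch)
print(hook.steps_per_epoch)  # {0: 3, 1: 2}
```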

at_exception(exception)

Called when an exception is raised.

Parameters

exception (Exception) – The exception that was raised.

Raises

exception – The same exception is raised again after all hooks have handled it.
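
The re-raise behaviour can be sketched as a driver that notifies every hook and then lets the exception propagate. `run_with_hooks` and `ExceptionLogHook` are hypothetical; edflow's own driver may differ in details such as which exception types it catches.

```python
try:
    from edflow.hooks.hook import Hook  # the base class documented here
except ImportError:
    Hook = object  # fallback so the sketch runs without edflow installed


class ExceptionLogHook(Hook):
    """Hypothetical hook: remembers which exceptions it was notified about."""

    def __init__(self):
        self.seen = []

    def at_exception(self, exception):
        # E.g. save a checkpoint or flush logs here; do not swallow the error.
        self.seen.append(type(exception).__name__)


def run_with_hooks(hooks, step_fn):
    # Hypothetical driver: notify every hook, then re-raise the exception.
    try:
        return step_fn()
    except Exception as e:
        for hook in hooks:
            hook.at_exception(e)
        raise  # bare raise re-raises the original exception unchanged


hooks = [ExceptionLogHook(), ExceptionLogHook()]
try:
    run_with_hooks(hooks, lambda: 1 / 0)
except ZeroDivisionError:
    print([h.seen for h in hooks])  # [['ZeroDivisionError'], ['ZeroDivisionError']]
```

Because the driver uses a bare `raise`, the hooks only observe the exception; the caller still receives the original exception object and traceback.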