edflow.iterators.model_iterator module

Summary

Exceptions:

ShutdownRequest

Raised when we receive a SIGTERM signal to shut down.

Classes:

PyHookedModelIterator

Implements an interface similar to that of the HookedModelIterator for training framework-independent models.

Reference

exception edflow.iterators.model_iterator.ShutdownRequest

Bases: Exception

Raised when we receive a SIGTERM signal to shut down. Allows hooks to perform final actions such as writing a last checkpoint.

class edflow.iterators.model_iterator.PyHookedModelIterator(config, root, model, datasets, hook_freq=100, num_epochs=100, hooks=[], bar_position=0, nogpu=False, desc='')

Bases: object

Implements an interface similar to that of the HookedModelIterator for training framework-independent models.

__init__(config, root, model, datasets, hook_freq=100, num_epochs=100, hooks=[], bar_position=0, nogpu=False, desc='')

Constructor.

Parameters
  • model (object) – Model class.

  • num_epochs (int) – Number of times to iterate over the data.

  • hooks (list) – List containing Hook instances.

  • hook_freq (int) – Frequency at which hooks are evaluated.

  • bar_position (int) – Used by tqdm to place bars at the right position when using multiple Iterators in parallel.

get_split(*args, **kwargs)

Get the current split that is processed.

get_global_step(*args, **kwargs)

Get the global step. The global step corresponds to the number of steps the model was trained for. It is updated in each step during training but not during evaluation.

set_global_step(step)

Set the global step. This should be done when restoring a model from a checkpoint.
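The global-step semantics described above can be illustrated with a small counter. `GlobalStepCounter`, its `step(is_train)` helper and the `checkpoint` dictionary are hypothetical stand-ins, not part of edflow; the sketch only mirrors the documented behaviour: the step advances during training, stays fixed during evaluation, and is set explicitly when restoring from a checkpoint.

```python
class GlobalStepCounter:
    """Hypothetical stand-in for an iterator's global-step bookkeeping."""

    def __init__(self):
        self._global_step = 0

    def get_global_step(self):
        return self._global_step

    def set_global_step(self, step):
        self._global_step = int(step)

    def reset_global_step(self):
        self._global_step = 0

    def increment_global_step(self):
        self._global_step += 1
        return self._global_step

    def step(self, is_train):
        # Only training steps advance the global step; evaluation leaves it unchanged.
        if is_train:
            self.increment_global_step()


counter = GlobalStepCounter()
for _ in range(5):
    counter.step(is_train=True)
counter.step(is_train=False)  # an evaluation step does not count
checkpoint = {"global_step": counter.get_global_step()}  # hypothetical checkpoint

# When restoring, set the step explicitly so training resumes where it stopped.
restored = GlobalStepCounter()
restored.set_global_step(checkpoint["global_step"])
```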

get_batch_step(*args, **kwargs)

Get the batch index of the current run.

get_epoch_step(*args, **kwargs)

Get the epoch index of the current run.

reset_global_step()
increment_global_step(*args, **kwargs)
make_feeds(batch)
iterate(batches)

Iterates over the data supplied and feeds it to the model.

Parameters
  • batch_iterator (Iterable) – Iterable returning training data.

  • batch_iterator_validation (Iterable) – Iterable returning validation data or None
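The iteration can be sketched as a nested loop over epochs and batches. This is a simplified illustration under assumptions: `toy_iterate` and `train_step` are hypothetical, hooks and validation are omitted, and `batches` is assumed to be re-iterable (for example a list) so it can be traversed once per epoch. It shows how the epoch index, the batch index and the global step relate: the batch index restarts every epoch, while the global step keeps counting.

```python
from typing import Callable, Iterable, List, Tuple


def toy_iterate(
    batches: Iterable, num_epochs: int, train_step: Callable[[object], None]
) -> List[Tuple[int, int, int]]:
    """Hypothetical loop feeding every batch to train_step for num_epochs epochs.

    Returns (epoch_step, batch_step, global_step) tuples recorded after each step.
    """
    global_step = 0
    log = []
    for epoch_step in range(num_epochs):
        # The batch index restarts at 0 in every epoch.
        for batch_step, batch in enumerate(batches):
            train_step(batch)
            global_step += 1  # advanced once per training step, never reset
            log.append((epoch_step, batch_step, global_step))
    return log


seen = []
log = toy_iterate(["a", "b", "c"], num_epochs=2, train_step=seen.append)
```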

run(fetches, feed_dict)

Runs all fetch ops and stores the results.

Parameters
  • fetches (dict) – name: Callable pairs.

  • feed_dict (dict) – Passed as kwargs to all fetch ops.

Returns

name: results pairs.

Return type

dict
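A minimal sketch of the documented contract of `run`: every value in `fetches` is called with `feed_dict` unpacked as keyword arguments, and the results are returned under the same names. The function name `run_fetches` and the fetch ops `loss` and `total` are hypothetical; the actual method may differ in details.

```python
from typing import Any, Callable, Dict


def run_fetches(
    fetches: Dict[str, Callable[..., Any]], feed_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """Call every fetch op with feed_dict as kwargs; return name: result pairs."""
    return {name: fn(**feed_dict) for name, fn in fetches.items()}


def loss(x, y):
    return (x - y) ** 2


def total(x, y):
    return x + y


results = run_fetches({"loss": loss, "sum": total}, {"x": 3, "y": 1})
```

Because every fetch op receives the same keyword arguments, each op has to accept all keys in `feed_dict` (or take `**kwargs`).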

run_hooks(index, fetches=None, feeds=None, batch=None, results=None, before=True, epoch_hooks=False)

Run all hooks and manage their inputs and outputs. The passed arguments determine which method of the hooks is called.

Parameters
  • index (int) – Current epoch or batch index. This is not necessarily the global training step.

  • fetches (list or dict) – Fetches for the next session.run call.

  • feeds (dict) – Feeds for the next session.run call.

  • results (same as fetches) – Results from the last session.run call.

  • before (bool) – Determines whether the before or after methods of the hooks are called.

Returns

  • fetches (same as fetches) – Updated fetches.

  • feeds (dict) – Updated feeds.
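The dispatch performed by `run_hooks` can be sketched as follows. The hook method names used here (`before_epoch`, `after_epoch`, `before_step`, `after_step`) and their signatures are assumptions modelled on a typical hook interface, not taken from this page; `dispatch_hooks` and `CountingHook` are hypothetical. `epoch_hooks` selects epoch-level versus step-level methods, `before` selects the before or after variant, and hooks may modify the fetches and feeds in place before they are returned.

```python
class CountingHook:
    """Hypothetical hook that records which of its methods were called."""

    def __init__(self):
        self.calls = []

    def before_epoch(self, epoch):
        self.calls.append(("before_epoch", epoch))

    def after_epoch(self, epoch):
        self.calls.append(("after_epoch", epoch))

    def before_step(self, step, fetches, feeds, batch):
        # A hook may add its own fetch op before the step is run.
        fetches["step_count"] = lambda **kwargs: step
        self.calls.append(("before_step", step))

    def after_step(self, step, results):
        self.calls.append(("after_step", step, results))


def dispatch_hooks(hooks, index, fetches=None, feeds=None, batch=None,
                   results=None, before=True, epoch_hooks=False):
    """Hypothetical dispatcher: call the method chosen by `before`/`epoch_hooks`."""
    for hook in hooks:
        if epoch_hooks:
            if before:
                hook.before_epoch(index)
            else:
                hook.after_epoch(index)
        elif before:
            hook.before_step(index, fetches, feeds, batch)
        else:
            hook.after_step(index, results)
    # Fetches and feeds may have been updated in place by the hooks.
    return fetches, feeds
```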

step_ops()

Defines ops that are called at each step.

Returns

The operation run at each step.
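A subclass would typically override `step_ops` to supply the operation executed at every step. The sketch below is hypothetical: `ToyIterator`, its `train_op`, and the convention that the returned callable receives the batch contents as keyword arguments are assumptions, chosen to match the `run(fetches, feed_dict)` description in which fetch ops receive the feeds as kwargs.

```python
class ToyIterator:
    """Hypothetical iterator whose step_ops returns a single training callable."""

    def __init__(self, learning_rate=0.1):
        self.weight = 0.0
        self.learning_rate = learning_rate

    def train_op(self, x, y):
        # One gradient-descent step on the squared error (weight * x - y) ** 2;
        # returns the loss measured before the update.
        error = self.weight * x - y
        self.weight -= self.learning_rate * 2 * error * x
        return error ** 2

    def step_ops(self):
        # The operation run at each step.
        return self.train_op


it = ToyIterator()
op = it.step_ops()
losses = [op(x=1.0, y=2.0) for _ in range(3)]
```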

initialize(checkpoint_path=None)