edflow.iterators.model_iterator module
Summary
Exceptions:
ShutdownRequest – Raised when we receive a SIGTERM signal to shut down.
Classes:
PyHookedModelIterator – Implements a similar interface as the HookedModelIterator to train framework-independent models.
Reference
exception edflow.iterators.model_iterator.ShutdownRequest
Bases: Exception
Raised when we receive a SIGTERM signal to shut down. This allows hooks to perform final actions, such as writing a last checkpoint.
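The shutdown pattern this exception supports can be sketched as follows. The sketch assumes a SIGTERM handler that converts the signal into ShutdownRequest, so the training loop can run a final action (such as a last checkpoint) before it exits. The names _sigterm_handler, train, on_shutdown and stop_at are hypothetical, and ShutdownRequest is re-declared locally. This is not edflow's own code. signal.raise_signal requires Python 3.8 or later, and like signal.signal it must be called from the main thread.

```python
import signal


class ShutdownRequest(Exception):
    """Local stand-in for edflow's ShutdownRequest (re-declared for this sketch)."""


def _sigterm_handler(signum, frame):
    # Hypothetical handler: turn the SIGTERM signal into an exception
    # that unwinds the training loop in the main thread.
    raise ShutdownRequest()


def train(num_steps, on_shutdown, stop_at=None):
    """Hypothetical training loop; calls on_shutdown(step) if SIGTERM arrives."""
    previous = signal.signal(signal.SIGTERM, _sigterm_handler)
    step = 0
    try:
        for step in range(num_steps):
            if step == stop_at:
                # Simulate an external `kill -TERM <pid>` arriving mid-training.
                signal.raise_signal(signal.SIGTERM)
        return "finished"
    except ShutdownRequest:
        # A hook would write its last checkpoint here.
        on_shutdown(step)
        return "shutdown"
    finally:
        # Restore whatever SIGTERM handler was installed before.
        signal.signal(signal.SIGTERM, previous)


saved_at = []
print(train(10, saved_at.append, stop_at=3), saved_at)  # shutdown [3]
```

Catching the exception at the loop level, rather than doing the work inside the signal handler, keeps the handler trivial. The final action then runs in ordinary code.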
class edflow.iterators.model_iterator.PyHookedModelIterator(config, root, model, datasets, hook_freq=100, num_epochs=100, hooks=[], bar_position=0, nogpu=False, desc='')
Bases: object
Implements a similar interface as the HookedModelIterator to train framework-independent models.
__init__(config, root, model, datasets, hook_freq=100, num_epochs=100, hooks=[], bar_position=0, nogpu=False, desc='')
Constructor.
Parameters
model (object) – Model class.
num_epochs (int) – Number of times to iterate over the data.
hooks (list) – List containing Hook instances.
hook_freq (int) – Frequency at which hooks are evaluated.
bar_position (int) – Used by tqdm to place progress bars at the right position when using multiple iterators in parallel.
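The sketch below shows how hook_freq and num_epochs together could determine when step hooks run. It assumes that a hook fires whenever the step index is a multiple of hook_freq. This page does not confirm that rule. The names hook_steps and batches_per_epoch are hypothetical.

```python
def hook_steps(num_epochs, batches_per_epoch, hook_freq=100):
    """Return the step indices at which step hooks would be evaluated.

    Assumption for illustration only: a hook fires whenever the step
    index is a multiple of hook_freq.
    """
    total_steps = num_epochs * batches_per_epoch
    return [step for step in range(total_steps) if step % hook_freq == 0]


# Hypothetical setup: 2 epochs of 5 batches, hooks every 3 steps.
print(hook_steps(num_epochs=2, batches_per_epoch=5, hook_freq=3))  # [0, 3, 6, 9]
```

With the default hook_freq=100, a short run of a few hundred steps evaluates hooks only a handful of times. Lowering hook_freq trades speed for more frequent logging or checkpointing.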
get_global_step(*args, **kwargs)
Get the global step. The global step corresponds to the number of steps the model was trained for. It is updated in each step during training but not during evaluation.
set_global_step(step)
Set the global step. This should be done when restoring a model from a checkpoint.
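The global-step contract described above can be sketched as follows. Training steps advance the counter, evaluation leaves it unchanged, and restoring from a checkpoint sets it explicitly. GlobalStepCounter, train_step, eval_step and the checkpoint dict are hypothetical stand-ins. They are not part of edflow.

```python
class GlobalStepCounter:
    """Hypothetical stand-in mirroring get_global_step / set_global_step."""

    def __init__(self):
        self._global_step = 0

    def get_global_step(self, *args, **kwargs):
        # Extra arguments are accepted and ignored, as the documented
        # signature get_global_step(*args, **kwargs) allows.
        return self._global_step

    def set_global_step(self, step):
        self._global_step = int(step)

    def train_step(self):
        # One training step: the global step advances.
        self._global_step += 1

    def eval_step(self):
        # Evaluation leaves the global step unchanged.
        return self._global_step


counter = GlobalStepCounter()
counter.train_step()
counter.train_step()
counter.eval_step()
checkpoint = {"global_step": counter.get_global_step()}  # saved at step 2

# Restoring from the hypothetical checkpoint dict.
restored = GlobalStepCounter()
restored.set_global_step(checkpoint["global_step"])
restored.train_step()
print(restored.get_global_step())  # 3
```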
iterate(batches)
Iterates over the supplied data and feeds it to the model.
Parameters
batches (Iterable) – Iterable returning training data.
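A framework-independent iteration loop can be sketched as a standalone function. The sketch assumes the model is any Python callable that takes one batch. iterate here is a hypothetical illustration, not the method's actual implementation. num_epochs mirrors the constructor parameter of the same name.

```python
def iterate(model, batches, num_epochs=1):
    """Hypothetical sketch: feed each batch to a callable model, epoch by epoch.

    batches must be re-iterable (e.g. a list) so that every epoch sees the
    data again; a one-shot generator would be empty after the first epoch.
    """
    outputs = []
    global_step = 0
    for _epoch in range(num_epochs):
        for batch in batches:
            outputs.append(model(batch))
            global_step += 1  # advanced during training only
    return outputs, global_step


# The built-in sum stands in for a model.
outputs, steps = iterate(sum, [[1, 2], [3, 4]], num_epochs=2)
print(outputs, steps)  # [3, 7, 3, 7] 4
```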
run(fetches, feed_dict)
Runs all fetch ops and stores the results.
Parameters
fetches (dict) – name: Callable pairs.
feed_dict (dict) – Passed as kwargs to all fetch ops.
Returns
name: result pairs.
Return type
dict
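The documented contract is that fetches are name: Callable pairs, feed_dict is passed as kwargs to every op, and the result is a name: result dict. A minimal sketch under those assumptions follows. The fetch op names are hypothetical, and this is not presented as edflow's exact implementation.

```python
def run(fetches, feed_dict):
    """Call every fetch op with feed_dict unpacked as keyword arguments.

    fetches maps names to callables; the returned dict maps the same names
    to whatever each callable produced.
    """
    return {name: fetch_op(**feed_dict) for name, fetch_op in fetches.items()}


# Hypothetical fetch ops. Every op receives *all* feed_dict entries, so
# ops that only need some of them absorb the rest with **kwargs.
fetches = {
    "squared_error": lambda prediction, target: (prediction - target) ** 2,
    "prediction": lambda prediction, **kwargs: prediction,
}
print(run(fetches, {"prediction": 3.0, "target": 1.0}))
# {'squared_error': 4.0, 'prediction': 3.0}
```

Because the whole feed_dict is passed to every op, an op that does not accept a given key (and has no **kwargs) raises a TypeError.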
run_hooks(index, fetches=None, feeds=None, batch=None, results=None, before=True, epoch_hooks=False)
Run all hooks. The passed arguments determine which method of the hooks is called.
Parameters
index (int) – Current epoch or batch index. This is not necessarily the global training step.
fetches (list or dict) – Fetches for the next session.run call.
feeds (dict) – Feeds for the next session.run call.
results (same as fetches) – Results from the last session.run call.
before (bool) – Determines whether the before or after methods of the hooks are called.
Returns
fetches (same as fetches) – Updated fetches.
feeds (dict) – Updated feeds.
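The dispatch described above can be sketched as follows. before selects between the before and after methods, and epoch_hooks selects between epoch-level and step-level methods. The hook method names (before_epoch, after_epoch, before_step, after_step) are assumptions modeled on a typical hook interface. The rule that step hooks fire only when index is a multiple of hook_freq is also an assumption. RecordingHook and this standalone run_hooks are hypothetical.

```python
class RecordingHook:
    """Hypothetical hook that records which of its methods were called."""

    def __init__(self):
        self.calls = []

    def before_epoch(self, epoch):
        self.calls.append(("before_epoch", epoch))

    def after_epoch(self, epoch):
        self.calls.append(("after_epoch", epoch))

    def before_step(self, step, fetches, feeds, batch):
        self.calls.append(("before_step", step))

    def after_step(self, step, results):
        self.calls.append(("after_step", step))


def run_hooks(hooks, index, hook_freq=100, fetches=None, feeds=None,
              batch=None, results=None, before=True, epoch_hooks=False):
    """Call one method on every hook, chosen by `before` and `epoch_hooks`."""
    if epoch_hooks:
        # Epoch hooks run on every call.
        for hook in hooks:
            if before:
                hook.before_epoch(index)
            else:
                hook.after_epoch(index)
    elif index % hook_freq == 0:
        # Assumption: step hooks only run every hook_freq steps.
        for hook in hooks:
            if before:
                hook.before_step(index, fetches, feeds, batch)
            else:
                hook.after_step(index, results)
    # Hooks may have modified fetches/feeds in place; hand them back.
    return fetches, feeds


hook = RecordingHook()
run_hooks([hook], 0, epoch_hooks=True)
run_hooks([hook], 0, hook_freq=2, fetches={}, feeds={})
run_hooks([hook], 1, hook_freq=2, before=False, results={})  # skipped: 1 % 2 != 0
run_hooks([hook], 2, hook_freq=2, before=False, results={})
run_hooks([hook], 0, epoch_hooks=True, before=False)
print(hook.calls)
# [('before_epoch', 0), ('before_step', 0), ('after_step', 2), ('after_epoch', 0)]
```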