edflow.iterators.torch_iterator module

Summary

Classes:

TorchHookedModelIterator

Iterator class for the PyTorch framework, derived from PyHookedModelIterator.

Reference

class edflow.iterators.torch_iterator.TorchHookedModelIterator(*args, transform=True, **kwargs)

Bases: edflow.iterators.model_iterator.PyHookedModelIterator

Iterator class for the PyTorch framework, derived from PyHookedModelIterator.

Parameters

transform (bool) – Whether batches are converted to PyTorch tensors. Leave this set to True even if your inputs are already PyTorch tensors.
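The conversion that transform switches on can be pictured as a recursive walk over a nested batch. This is a minimal sketch, not edflow's implementation: the helper name to_tensors, the choice to pass strings through unchanged, and the injectable convert callable are assumptions made for illustration. With PyTorch installed, convert would typically be torch.as_tensor.

```python
from typing import Any, Callable, Mapping


def to_tensors(batch: Any, convert: Callable[[Any], Any]) -> Any:
    """Hypothetical helper: apply `convert` to every leaf of a nested batch.

    Mappings, lists and tuples are walked recursively and keep their
    container type (mappings become plain dicts). Strings such as labels
    or file paths are returned unchanged; every other leaf is passed to
    `convert`.
    """
    if isinstance(batch, Mapping):
        return {key: to_tensors(value, convert) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        converted = [to_tensors(value, convert) for value in batch]
        return converted if isinstance(batch, list) else tuple(converted)
    if isinstance(batch, str):
        return batch  # text cannot be turned into a tensor
    return convert(batch)
```

For example, to_tensors(batch, torch.as_tensor) would turn NumPy arrays and Python numbers inside a nested dict into tensors. Because torch.as_tensor returns an input that is already a tensor of matching dtype and device without copying it, leaving the conversion on costs little for tensor inputs in this sketch.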

__init__(*args, transform=True, **kwargs)

Constructor.

Parameters
  • model (object) – Model class.

  • num_epochs (int) – Number of times to iterate over the data.

  • hooks (list) – List containing Hook instances.

  • hook_freq (int) – Frequency at which hooks are evaluated.

  • bar_position (int) – Position passed to tqdm so that progress bars are placed correctly when several iterators run in parallel.
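To make num_epochs, hooks and hook_freq concrete, here is a minimal stand-alone loop. It is a hypothetical sketch, not edflow's PyHookedModelIterator: the Hook class with before_step/after_step callbacks, the iterate function, and the rule that hooks fire only when the global step is divisible by hook_freq are assumptions for illustration, and model is simply any callable applied to a batch.

```python
from typing import Any, List, Sequence


class Hook:
    """Hypothetical hook interface; subclasses override what they need."""

    def before_step(self, step: int, batch: Any) -> None:
        pass

    def after_step(self, step: int, result: Any) -> None:
        pass


def iterate(model, data: Sequence[Any], num_epochs: int,
            hooks: List[Hook], hook_freq: int) -> int:
    """Run `model` over `data` `num_epochs` times; return the number of steps.

    Assumption for this sketch: hooks are only called on global steps
    that are a multiple of `hook_freq`.
    """
    step = 0
    for _epoch in range(num_epochs):
        for batch in data:
            active = step % hook_freq == 0
            if active:
                for hook in hooks:
                    hook.before_step(step, batch)
            result = model(batch)
            if active:
                for hook in hooks:
                    hook.after_step(step, result)
            step += 1
    return step


class RecordingHook(Hook):
    """Example hook that records the global steps at which it ran."""

    def __init__(self) -> None:
        self.steps: List[int] = []

    def after_step(self, step: int, result: Any) -> None:
        self.steps.append(step)


if __name__ == "__main__":
    hook = RecordingHook()
    total = iterate(lambda b: b * 2, [1, 2, 3], num_epochs=2,
                    hooks=[hook], hook_freq=2)
    print(total, hook.steps)  # 6 [0, 2, 4]
```

With three batches and two epochs the loop runs six steps, and with hook_freq=2 the hook fires on steps 0, 2 and 4. Raising hook_freq trades monitoring granularity for less overhead per step.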