Source code for edflow.iterators.torch_iterator

from edflow.iterators.model_iterator import PyHookedModelIterator
from edflow.hooks.pytorch_hooks import DataPrepHook


[docs]class TorchHookedModelIterator(PyHookedModelIterator): """ Iterator class for framework PyTorch, inherited from PyHookedModelIterator. Parameters ---------- transform : bool If the batches are to be transformed to pytorch tensors. Should be true even if your input is already pytorch tensors! """
[docs] def __init__(self, *args, transform=True, **kwargs): super().__init__(*args, **kwargs) # check if the data preparation hook is already supplied. check = transform and not any( [isinstance(hook, DataPrepHook) for hook in self.hooks] ) if check: self.hooks += [DataPrepHook()]