# Source code for edflow.edsetup_files.iterator

from edflow.iterators.template_iterator import TemplateIterator


class Iterator(TemplateIterator):
    """Clean iterator skeleton for initialization.

    ``save`` and ``restore`` are empty placeholders and ``step_op`` contains
    a minimal example; fill them in when adapting the skeleton to a project.
    """

    def __init__(self, *args, **kwargs):
        """Forward all positional and keyword arguments to the base class."""
        super().__init__(*args, **kwargs)

    def save(self, checkpoint_path):
        """Function for saving the model at a given state.

        This skeleton implementation does nothing and returns ``None``.

        Parameters
        ----------
        checkpoint_path :
            The path where the saved checkpoint should lie.
        """

    def restore(self, checkpoint_path):
        """Function for model restoration from a given checkpoint.

        This skeleton implementation does nothing and returns ``None``.
        A concrete implementation is expected to load the model state found
        at ``checkpoint_path``.

        Parameters
        ----------
        checkpoint_path :
            The path where the checkpoint for restoring lies.
        """

    def step_op(self, model, **kwargs):
        """The main method to be called for training by the iterator.

        Runs ``model`` on the batch inputs once and packages the training,
        logging and evaluation logic as three closures.

        Parameters
        ----------
        model :
            The given model. It is called as ``model(inputs)``.
        **kwargs :
            Must contain the keys ``"inputs"`` and ``"labels"``.

        Returns
        -------
        dict
            A dictionary mapping ``"train_op"``, ``"log_op"`` and
            ``"eval_op"`` to the corresponding callables (the functions
            themselves, not their return values).

        Raises
        ------
        KeyError
            If ``"inputs"`` or ``"labels"`` is missing from ``kwargs``.
        """
        inputs, labels = kwargs["inputs"], kwargs["labels"]

        # The forward pass is evaluated here, once per call; the closures
        # below only capture its result.
        outputs = model(inputs)

        def train_op():
            """Takes care of the training process (placeholder, no-op)."""

        def log_op():
            """Takes care of the logging process.

            Returns
            -------
            dict
                A dictionary whose values are to be logged.
            """
            return {"inputs": inputs, "labels": labels}

        def eval_op():
            """Takes care of the evaluation.

            Returns
            -------
            dict
                A dictionary with values to be evaluated.
            """
            return {"outputs": outputs}

        return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op}