# Source code for edflow.iterators.model_iterator

import signal, sys, math
from tqdm import tqdm, trange

from edflow.custom_logging import get_logger
from edflow.util import walk


class ShutdownRequest(Exception):
    """Marker exception for a signal-triggered shutdown.

    The iterator's SIGTERM/SIGINT handler builds an instance of this class and
    hands it to every hook's ``at_exception`` method before the process exits,
    giving hooks a chance to perform final actions such as writing a last
    checkpoint.
    """
class PyHookedModelIterator(object):
    """Implements a similar interface as the :class:`HookedModelIterator` to
    train framework independent models."""

    def __init__(
        self,
        config,
        root,
        model,
        datasets,
        hook_freq=100,
        num_epochs=100,
        hooks=None,
        bar_position=0,
        nogpu=False,
        desc="",
    ):
        """Constructor.

        Parameters
        ----------
        config : dict
            Mapping of options. Keys read by this class: ``test_mode``,
            ``num_steps`` and ``max_batches_per_epoch``.
        root : object
            Stored as ``self.root``; not used by this class itself.
        model : object
            Model instance; passed as first argument to every fetch op.
        datasets : dict
            Must contain the keys ``"train"`` and ``"validation"``. They are also
            exposed as ``self.dataset`` / ``self.validation_dataset`` for
            backwards compatibility.
        hook_freq : int
            Step hooks are only evaluated when the global step is a multiple
            of this value.
        num_epochs : int
            Number of times to iterate over the data.
        hooks : list or None
            List containing :class:`Hook` instances. ``None`` (the default)
            creates a new, empty list for this iterator.
        bar_position : int
            Used by tqdm to place bars at the right position when using
            multiple Iterators in parallel. Each iterator occupies two rows.
        nogpu : bool
            Accepted for interface compatibility; unused by this class.
        desc : str
            Optional prefix for the progress bar descriptions.
        """
        self.config = config
        self.root = root
        self.model = model

        self.datasets = datasets
        # backwards compatibility
        self.dataset = datasets["train"]
        self.validation_dataset = datasets["validation"]

        self.num_epochs = num_epochs

        # A fresh list per instance; a shared ``[]`` default would leak hooks
        # between iterators if anyone appends to ``self.hooks``.
        self.hooks = hooks if hooks is not None else []
        self.epoch_hooks = list()

        self.hook_freq = hook_freq
        self.bar_pos = bar_position * 2
        self.desc = desc

        self.logger = get_logger(type(self).__name__)

        self._global_step = 0
        self._batch_step = 0
        self._epoch_step = 0
        self._split = None

        # Registered last: the handler reads ``self.hooks``, so a signal that
        # arrives while the instance is half-initialised must not reach it.
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        signal.signal(signal.SIGINT, self._handle_sigterm)

    def get_split(self, *args, **kwargs):
        """Get the current split that is processed."""
        return self._split

    def get_global_step(self, *args, **kwargs):
        """Get the global step. The global step corresponds to the number of
        steps the model was trained for. It is updated in each step during
        training but not during evaluation."""
        return self._global_step

    def set_global_step(self, step):
        """Set the global step. Should be done when restoring a model from a
        checkpoint."""
        self._global_step = step

    def get_batch_step(self, *args, **kwargs):
        """Batch index of current run."""
        return self._batch_step

    def get_epoch_step(self, *args, **kwargs):
        """Epoch index of current run."""
        return self._epoch_step

    def reset_global_step(self):
        """Set the global step back to ``0``."""
        self.set_global_step(0)

    def increment_global_step(self, *args, **kwargs):
        """Advance the global step by one, unless ``config["test_mode"]`` is
        truthy, and return the (possibly unchanged) global step."""
        if not self.config.get("test_mode", False):
            self._global_step += 1
        return self._global_step

    def make_feeds(self, batch):
        """Build the feed dict passed to the fetch ops from ``batch``."""
        # Intended as a copy of the batch: ``walk`` with an identity function
        # presumably rebuilds the nested containers -- TODO confirm against
        # ``edflow.util.walk``.
        feeds = walk(batch, lambda val: val)
        return feeds

    def _handle_sigterm(self, signum, frame):
        """Signal handler: notify hooks with a :class:`ShutdownRequest`, then
        exit the process with status 0 (for SIGTERM and SIGINT alike)."""
        e = ShutdownRequest()
        self._handle_exception(e)
        sys.exit(0)

    def _handle_exception(self, e):
        """Forward ``e`` to the ``at_exception`` method of every step hook."""
        for hook in self.hooks:
            hook.at_exception(e)

    def iterate(self, batches):
        """Iterates over the data supplied and feeds it to the model.

        Parameters
        ----------
        batches : dict
            Mapping from split name to a batch iterator. Each iterator must
            support ``len()``, ``next()``, ``reset()`` and an ``is_new_epoch``
            attribute; the ``"train"`` split drives the training steps.

        Raises
        ------
        Exception
            Any exception raised while iterating is first passed to every
            hook's ``at_exception`` method and then re-raised unchanged.
        """
        try:
            self._iterate(batches)
        except Exception as e:
            self._handle_exception(e)
            raise

    def _iterate(self, batches):
        """Run the training steps and the epoch hooks over ``batches``.

        Each epoch first runs training steps on ``batches["train"]`` (every
        step offers one lazy op per split to the step hooks) and then iterates
        every split once for the epoch hooks. Training resumes from the epoch
        and batch implied by the current global step. In test mode only a
        single epoch of epoch hooks is run.

        Parameters
        ----------
        batches : dict
            See :meth:`iterate`.

        Raises
        ------
        ValueError
            If not in test mode and the number of training batches per epoch
            (``len(batches["train"])``, capped by
            ``config["max_batches_per_epoch"]``) is not positive.
        """
        step_ops = self.step_ops()

        epoch_hooks_only = self.config.get("test_mode", False)

        pos = self.bar_pos
        base = self.desc + " - " if self.desc != "" else ""
        desc_epoch = base + "Epoch"
        desc_batch = base + "Batch"

        # TODO use val freq (config keys ``val_freq`` / ``log_freq``)

        batches_per_epoch = 0 if epoch_hooks_only else len(batches["train"])
        if "max_batches_per_epoch" in self.config:
            batches_per_epoch = min(
                batches_per_epoch, self.config["max_batches_per_epoch"]
            )

        num_epochs = 1 if epoch_hooks_only else self.num_epochs
        if epoch_hooks_only:
            start_epoch, start_step = 0, 0
        elif batches_per_epoch <= 0:
            # Previously surfaced as an opaque ZeroDivisionError.
            raise ValueError(
                "Number of training batches per epoch must be positive, got "
                "{} (check the size of the train split and "
                "max_batches_per_epoch).".format(batches_per_epoch)
            )
        else:
            start_epoch, start_step = divmod(
                self.get_global_step(), batches_per_epoch
            )

        for epoch_step in trange(
            start_epoch,
            num_epochs,
            initial=start_epoch,
            total=num_epochs,
            desc=desc_epoch,
            position=pos,
            dynamic_ncols=True,
            leave=False,
        ):
            self._epoch_step = epoch_step

            ############# run one batch on each split until new epoch or max steps
            batches["train"].reset()
            self.run_hooks(epoch_step, before=True)
            self._run_train_batches(
                batches, step_ops, start_step, batches_per_epoch, desc_batch, pos + 1
            )
            self.run_hooks(epoch_step, before=False)
            # Only the first (resumed) epoch starts mid-way.
            start_step = 0

            ############# run one epoch on each split
            self._run_epoch_hook_pass(batches, step_ops, epoch_step, pos + 1)

            if self.get_global_step() >= self.config.get("num_steps", float("inf")):
                break

    def _lazy_split_op(self, batches, split, step_ops, batch_step, epoch_hooks=False):
        """Return a zero-argument callable running one step on ``split``.

        Nothing is drawn from ``batches[split]`` until a hook calls the
        returned function. Its ``called`` attribute records whether that
        happened. ``batch_step`` is bound when the op is created.
        """

        def split_op():
            split_op.called = True
            self._split = split
            batch = next(batches[split])
            feeds = self.make_feeds(batch)
            fetches = step_ops
            self.run_hooks(
                batch_step, fetches, feeds, batch, before=True, epoch_hooks=epoch_hooks
            )
            return self.run(fetches, feed_dict=feeds)

        split_op.called = False
        return split_op

    def _run_train_batches(
        self, batches, step_ops, start_step, batches_per_epoch, desc, position
    ):
        """Run training steps ``start_step`` .. ``batches_per_epoch - 1`` of
        one epoch, stopping early once ``config["num_steps"]`` is reached."""
        for batch_step in trange(
            start_step,
            batches_per_epoch,
            initial=start_step,
            total=batches_per_epoch,
            desc=desc,
            position=position,
            dynamic_ncols=True,
            leave=False,
        ):
            self._batch_step = batch_step

            results = {"global_step": self.get_global_step()}
            for split in batches:
                results[split] = self._lazy_split_op(
                    batches, split, step_ops, batch_step
                )

            self.run_hooks(batch_step, results=results, before=False)
            del results

            self.increment_global_step()

            if self.get_global_step() >= self.config.get("num_steps", float("inf")):
                break

    def _run_epoch_hook_pass(self, batches, step_ops, epoch_step, position):
        """Iterate every split once for the epoch hooks.

        A split is only continued as long as some epoch hook retrieves the
        lazily computed results. It stops as soon as no hook called the op for
        a step, or the split reports ``is_new_epoch``.
        """
        for split in batches:
            batches[split].reset()
            self.run_hooks(epoch_step, before=True, epoch_hooks=True)

            tqdm_iterator = trange(
                len(batches[split]),
                desc=split,
                position=position,
                dynamic_ncols=True,
                leave=False,
            )
            for batch_step in tqdm_iterator:
                self._batch_step = batch_step

                split_op = self._lazy_split_op(
                    batches, split, step_ops, batch_step, epoch_hooks=True
                )
                results = {"global_step": self.get_global_step(), split: split_op}
                self.run_hooks(
                    batch_step, results=results, before=False, epoch_hooks=True
                )
                del results

                if batches[split].is_new_epoch or not split_op.called:
                    tqdm_iterator.update()
                    tqdm_iterator.close()
                    self.logger.info("Done with {}".format(split))
                    break
            self.run_hooks(epoch_step, before=False, epoch_hooks=True)

    def run(self, fetches, feed_dict):
        """Runs all fetch ops and stores the results.

        Parameters
        ----------
        fetches : dict
            name: Callable pairs. Each callable is invoked as
            ``fetch_fn(self.model, **feed_dict)``; ``walk`` presumably applies
            this to every leaf of a nested structure -- TODO confirm.
        feed_dict : dict
            Passed as kwargs to all fetch ops

        Returns
        -------
        dict
            name: results pairs.
        """

        def fn(fetch_fn):
            return fetch_fn(self.model, **feed_dict)

        results = walk(fetches, fn)

        return results

    def run_hooks(
        self,
        index,
        fetches=None,
        feeds=None,
        batch=None,
        results=None,
        before=True,
        epoch_hooks=False,
    ):
        """Run all hooks and manage their stuff. The passed arguments determine
        which method of the hooks is called.

        A call is treated as a *step* call when ``fetches`` and ``feeds`` are
        both given, or when ``results`` is given; otherwise it is an *epoch*
        call. Step calls only reach the hooks when the global step is a
        multiple of ``hook_freq``; epoch calls always do.

        Parameters
        ----------
        index : int
            Current epoch or batch index. This is not necessarily the global
            training step.
        fetches : list or dict
            Fetches for the next session.run call.
        feeds : dict
            Feeds for the next session.run call.
        batch : object
            The raw batch the feeds were built from (step calls only).
        results : same as fetches
            Results from the last session.run call.
        before : bool
            If not obvious determines if the before or after methods of the
            hooks should be called.
        epoch_hooks : bool
            Dispatch to ``self.epoch_hooks`` instead of ``self.hooks``.

        Returns
        -------
        None
        """
        is_step = fetches is not None and feeds is not None
        is_step = is_step or results is not None
        condition = self._global_step % self.hook_freq == 0 or not is_step
        hooks = self.hooks if not epoch_hooks else self.epoch_hooks
        if condition:
            for hook in hooks:
                if before:
                    if is_step:
                        hook.before_step(index, fetches, feeds, batch)
                    else:
                        hook.before_epoch(index)
                else:
                    if is_step:
                        hook.after_step(index, results)
                    else:
                        hook.after_epoch(index)

    def step_ops(self):
        """Defines ops that are called at each step.

        Returns
        -------
        The operation run at each step."""
        raise NotImplementedError()

    def initialize(self, checkpoint_path=None):
        """Hook for restoring state; a no-op in this base class."""
        pass