# Source code for edflow.hooks.checkpoint_hooks.lambda_checkpoint_hook

import os

from edflow.hooks.hook import Hook
from edflow.custom_logging import get_logger
from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint


class LambdaCheckpointHook(Hook):
    """Checkpoint hook that delegates saving and restoring to callables.

    The hook itself only decides *when* a checkpoint is written and under
    which file name; the actual work is done by the ``save`` and ``restore``
    callables handed to the constructor. Checkpoint files are named
    ``<modelname>-<global_step>.ckpt`` inside ``root_path``.

    Saving is suppressed until the hook has been activated (see
    :meth:`save`), which happens after the first epoch or after a step
    whose ``step`` argument is greater than zero.
    """

    def __init__(
        self,
        root_path,
        global_step_getter,
        global_step_setter,
        save,
        restore,
        interval=None,
        ckpt_zero=False,
        modelname="model",
    ):
        """
        Parameters
        ----------
        root_path : str
            Directory the checkpoints are written to. It is created if it
            does not exist yet.
        global_step_getter : callable
            Called without arguments; its return value is used as the
            current global step (for file names and the interval check).
        global_step_setter : callable or None
            Called with the step parsed from a checkpoint name when a
            checkpoint is restored. Skipped if ``None``.
        save : callable
            Called with the target checkpoint path to write a checkpoint.
        restore : callable
            Called with a checkpoint path to load a checkpoint.
        interval : int, optional
            If ``None``, a checkpoint is saved after every epoch. Otherwise
            a checkpoint is saved after every step at which the global step
            is a multiple of ``interval``.
        ckpt_zero : bool
            If true, a checkpoint is also saved before an epoch whenever
            the global step is still ``0``.
        modelname : str
            Prefix of the checkpoint file names.
        """
        self.root = root_path
        self.logger = get_logger(self)

        # Callables for reading and writing the global step.
        self.global_step_getter = global_step_getter
        self.global_step_setter = global_step_setter

        # Callables doing the actual (de)serialisation.
        self._save = save
        self._restore = restore

        # Scheduling options.
        self.interval = interval
        self.ckpt_zero = ckpt_zero

        os.makedirs(root_path, exist_ok=True)

        # The doubled braces survive this first ``format`` call as a single
        # ``{}`` placeholder, which ``save`` later fills with the step.
        filename_template = "{}-{{}}.ckpt".format(modelname)
        self.savename = os.path.join(root_path, filename_template)

        # Nothing is saved (unless forced) before the hook is activated.
        self._active = False

    def before_epoch(self, epoch):
        """Save an initial checkpoint if requested and still at step zero.

        Parameters
        ----------
        epoch :
            Unused by this hook.
        """
        if not self.ckpt_zero:
            return
        if self.global_step_getter() == 0:
            # The hook is not active yet at this point, so force the save.
            self.save(force_active=True)

    def after_epoch(self, epoch):
        """Activate the hook and, in per-epoch mode, save a checkpoint.

        Parameters
        ----------
        epoch :
            Unused by this hook.
        """
        self._active = True
        if self.interval is None:
            self.save()

    def after_step(self, step, last_results):
        """Activate the hook and, in interval mode, save on matching steps.

        Parameters
        ----------
        step : int
            Step value passed in by the caller; only used to decide whether
            the hook becomes active (``step > 0``). The interval check uses
            the value returned by ``global_step_getter`` instead.
        last_results :
            Unused by this hook.
        """
        if step > 0:
            self._active = True

        global_step = self.global_step_getter()
        if self.interval is None:
            return
        if global_step % self.interval == 0:
            self.save()

    def at_exception(self, *args, **kwargs):
        """Try to save a checkpoint when an exception is reported.

        Parameters
        ----------
        *args :
            Ignored.
        **kwargs :
            Ignored.
        """
        self.save()

    def save(self, force_active=False):
        """Write a checkpoint named after the current global step.

        Does nothing unless the hook is active or ``force_active`` is true.

        Parameters
        ----------
        force_active : bool
            Save even if the hook has not been activated yet.
        """
        if not (self._active or force_active):
            return

        target = self.savename.format(self.global_step_getter())
        self._save(target)
        self.logger.info("Saved model to {}".format(target))

    def __call__(self, checkpoint):
        """Load checkpoint and set global step."""
        self._restore(checkpoint)
        self.logger.info("Restored model from {}".format(checkpoint))

        restored_step = self.parse_global_step(checkpoint)
        if self.global_step_setter is not None:
            self.global_step_setter(restored_step)
        self.logger.info("Global step: {}".format(restored_step))

    @staticmethod
    def parse_global_step(checkpoint):
        """Extract the global step from a checkpoint path.

        Parameters
        ----------
        checkpoint : str
            Checkpoint path of the form ``<prefix>-<step>.ckpt``. Anything
            from the last ``.ckpt`` onwards is ignored.

        Returns
        -------
        int
            The number following the last ``-`` before ``.ckpt``.
        """
        # Cut off the last ".ckpt" and whatever follows it.
        stem = checkpoint.rsplit(".ckpt", 1)[0]
        # The step is whatever follows the last dash of the remainder.
        step_text = stem.rsplit("-", 1)[1]
        return int(step_text)