Source code for edflow.iterators.tf_evaluator

import tensorflow as tf
import os

from edflow.iterators.tf_iterator import TFHookedModelIterator
from edflow.hooks.checkpoint_hooks.common import WaitForCheckpointHook
from edflow.hooks.checkpoint_hooks.tf_checkpoint_hook import (
    RestoreModelHook,
    RestoreTFModelHook,
    RestoreCurrentCheckpointHook,
)
from edflow.project_manager import ProjectManager

# Module-level ProjectManager instance, created at import time.
# NOTE(review): ``P`` is not referenced in the visible part of this module;
# the evaluator below reads ``ProjectManager.checkpoints`` from the class
# itself. Presumably the instantiation is kept for its side effects --
# confirm against ``edflow.project_manager`` before removing it.
P = ProjectManager()


class TFBaseEvaluator(TFHookedModelIterator):
    """Base evaluator that registers checkpoint-restoring hooks.

    When :meth:`initialize` receives a checkpoint path, a
    :class:`RestoreCurrentCheckpointHook` for that path is registered.
    Otherwise a :class:`WaitForCheckpointHook` on
    ``ProjectManager.checkpoints`` is registered, with a
    :class:`RestoreTFModelHook` as its callback.
    """

    def __init__(self, *args, desc="Eval", hook_freq=1, num_epochs=1, **kwargs):
        """
        New Base evaluator restores given checkpoint path if provided,
        else scans checkpoint directory for latest checkpoint and uses that

        Parameters
        ----------
        desc : str
            a description for the evaluator. This description will be used
            during the logging.
        hook_freq : int
            Frequency at which hooks are evaluated.
        num_epochs : int
            Number of times to iterate over the data.
        """
        # Forward the evaluator-specific settings to the parent constructor
        # as ordinary keyword arguments.
        kwargs.update(desc=desc, hook_freq=hook_freq, num_epochs=num_epochs)
        super().__init__(*args, **kwargs)
        self.define_graph()

    def initialize(self, checkpoint_path=None):
        """Register the hook responsible for restoring model variables.

        Parameters
        ----------
        checkpoint_path : str, optional
            If truthy, a :class:`RestoreCurrentCheckpointHook` for this path
            is added to ``self.hooks``. If omitted or falsy, a
            :class:`WaitForCheckpointHook` rooted at
            ``ProjectManager.checkpoints`` is added instead, using a
            :class:`RestoreTFModelHook` as its callback.
        """
        self.restore_variables = self.model.variables

        if checkpoint_path:
            # An explicit checkpoint was requested: restore from exactly it.
            self.hooks += [
                RestoreCurrentCheckpointHook(
                    variables=self.restore_variables,
                    checkpoint_path=checkpoint_path,
                    global_step_setter=self.set_global_step,
                )
            ]
            return

        # No explicit checkpoint: wait for new checkpoint and restore.
        on_checkpoint = RestoreTFModelHook(
            variables=self.restore_variables,
            checkpoint_path=ProjectManager.checkpoints,
            global_step_setter=self.set_global_step,
        )
        checkpoint_root = ProjectManager.checkpoints
        evaluate_all = self.config.get("eval_all", False)
        # NOTE(review): the "fcond" config entry is a string of Python source
        # that is passed to eval(); only use this with trusted configurations.
        checkpoint_filter = eval(self.config.get("fcond", "lambda c: True"))
        self.hooks += [
            WaitForCheckpointHook(
                checkpoint_root=checkpoint_root,
                callback=on_checkpoint,
                eval_all=evaluate_all,
                filter_cond=checkpoint_filter,
            )
        ]

    def define_graph(self):
        """Do nothing here; called at the end of ``__init__``.

        Subclasses may override this method.
        """

    def step_ops(self):
        """Return ``self.model.outputs``."""
        return self.model.outputs