from edflow.hooks.hook import Hook
from edflow.util import retrieve
class ExpandHook(Hook):
    """Periodically look up a set of keypaths in the results of a step.

    On every step that is a multiple of ``interval``, each keypath in
    ``paths`` is passed to :func:`edflow.util.retrieve` together with the
    results of the last step. The looked-up values are thrown away, so this
    hook is only useful for whatever side effects ``retrieve`` has on
    ``last_results`` (presumably expanding lazily-evaluated entries in
    place -- confirm against ``retrieve``).
    """

    def __init__(self, paths, interval, default=None):
        """
        Parameters
        ----------
        paths : list
            Keypaths to expand.
        interval : int
            Expansion is performed on every step that is a multiple of
            this value.
        default : object
            Passed unchanged to ``retrieve`` as its ``default`` argument.
        """
        self.paths, self.interval, self.default = paths, interval, default

    def after_step(self, step, last_results):
        """Called after each step; expands ``paths`` on selected steps."""
        if step % self.interval != 0:
            return
        for keypath in self.paths:
            retrieve(last_results, keypath, default=self.default)
class IntervalHook(Hook):
    """This hook manages a set of hooks, which it will run each time its
    interval flag is set to True.

    The managed hooks receive ``before_step``/``after_step`` calls only on
    steps ``s`` with ``start < s <= stop`` and ``s % interval == 0``. Their
    ``before_epoch``/``after_epoch`` methods are forwarded on every call.
    The interval can be changed over time via ``modifier``.
    """

    def __init__(
        self,
        hooks,
        interval,
        start=None,
        stop=None,
        modify_each=None,
        modifier=lambda interval: 2 * interval,
        max_interval=None,
        get_step=None,
    ):
        """
        Parameters
        ----------
        hooks : list of Hook
            The set of managed hooks. Each must implement
            the methods of a :class:`Hook`.
        interval : int
            The number of steps after which the managed hooks
            are run, i.e. they run on every step that is a multiple of
            ``interval``.
        start : int
            If `start` is not None, the hooks only run on steps strictly
            greater than `start`. Defaults to ``-1``, so step ``0`` is
            eligible.
        stop : int
            If given, this hook is not evaluated anymore after
            `stop` steps.
        modify_each : int
            If given, `modifier` is called on the interval
            after every `modify_each`-th execution of this hook (executions
            are counted in ``before_step``, the modification is applied in
            ``after_step``). If `None` it is set to :attr:`interval`. In
            case you do not want any modification you can either set
            :attr:`max_interval` to :attr:`interval`, choose the modifier
            to be `lambda x: x`, or set :attr:`modify_each` to
            `float("inf")`.
        modifier : Callable
            Maps the current interval to the new one; defaults to doubling
            it. See `modify_each`.
        max_interval : int
            If given, the interval is capped at this number of steps after
            each modification.
        get_step : Callable
            If given, its return value is used instead of the batch index
            passed as ``step`` to determine the run condition, e.g. to run
            based on a global step.
        """
        self.hooks = hooks

        self.base_interval = interval

        inf = float("inf")
        self.start = start if start is not None else -1
        self.stop = stop if stop is not None else inf
        self.modival = modify_each if modify_each is not None else interval
        self.modifier = modifier
        self.max_interval = max_interval if max_interval is not None else inf
        self.get_step = get_step

        # Number of executions counted so far (incremented in before_step).
        self.counter = 0

    def run_condition(self, step, is_before=False):
        """Return whether the managed hooks should run at ``step``.

        If ``get_step`` was given, its return value replaces ``step``. When
        ``is_before`` is True and the condition holds, the execution counter
        consulted by :meth:`maybe_modify` is incremented.
        """
        if self.get_step is not None:
            step = self.get_step()
        if not (self.start < step <= self.stop):
            return False
        if step % self.base_interval != 0:
            return False
        if is_before:
            # before_step and after_step both evaluate this condition for the
            # same step; count the execution only once.
            self.counter += 1
        return True

    def maybe_modify(self, step):
        """Apply ``modifier`` to the interval every ``modify_each`` executions.

        The new interval is capped at ``max_interval``. ``step`` is
        currently unused.
        """
        # No execution has been counted yet (e.g. after_step ran without a
        # matching before_step). Without this guard, 0 % modival == 0 would
        # modify the interval on every run -- even with
        # modify_each=float("inf"), which is documented to disable it.
        if self.counter == 0:
            return
        if self.counter % self.modival == 0:
            new_interval = self.modifier(self.base_interval)
            self.base_interval = min(self.max_interval, new_interval)

    def before_epoch(self, *args, **kwargs):
        """Called before each epoch."""
        for hook in self.hooks:
            hook.before_epoch(*args, **kwargs)

    def before_step(self, step, *args, **kwargs):
        """Called before each step. Can update any feeds and fetches."""
        if self.run_condition(step, True):
            for hook in self.hooks:
                hook.before_step(step, *args, **kwargs)

    def after_step(self, step, *args, **kwargs):
        """Called after each step. Runs the managed hooks on selected steps
        and then possibly modifies the interval."""
        if self.run_condition(step, False):
            for hook in self.hooks:
                hook.after_step(step, *args, **kwargs)
            self.maybe_modify(step)

    def after_epoch(self, *args, **kwargs):
        """Called after each epoch."""
        for hook in self.hooks:
            hook.after_epoch(*args, **kwargs)