import time
import os
import re
import pickle
import numpy as np
from collections import OrderedDict, namedtuple
from edflow.hooks.hook import Hook
from edflow.custom_logging import get_logger
from edflow.project_manager import ProjectManager
from edflow.util import retrieve
# Values storable as npz.
# NOTE: ``np.float`` used to be listed here as well. It was merely an alias for
# the builtin ``float`` (already part of this tuple) and has been removed from
# recent NumPy releases, where referencing it raises an AttributeError at
# import time.
SAVABLES = (np.ndarray, np.int64, int, float)
# Module-level ProjectManager instance. NOTE(review): nothing in the visible
# part of this file reads ``P`` - confirm whether other modules import it.
P = ProjectManager()
def get_latest_checkpoint(checkpoint_root, filter_cond=lambda c: True):
    """Find the most recently modified checkpoint in ``checkpoint_root``.

    Files ending in ``.ckpt`` are taken as they are. Files ending in
    ``.index`` are reported without that suffix.

    Parameters
    ----------
    checkpoint_root : str
        Directory that is scanned for checkpoint files.
    filter_cond : Callable
        Predicate called with the full checkpoint path. Only checkpoints
        for which it returns a truthy value are considered.

    Returns
    -------
    str or None
        Path of the checkpoint with the newest modification time, or
        ``None`` if no checkpoint passes the filter. For ``.index`` files
        the returned path has the suffix stripped and therefore need not
        exist as a file itself.
    """
    index_suffix = ".index"

    # Pair each file on disk with the checkpoint name it stands for.
    candidates = []
    for fname in sorted(os.listdir(checkpoint_root)):
        if fname.endswith(".ckpt"):
            candidates.append((fname, fname))
        elif fname.endswith(index_suffix):
            candidates.append((fname, fname[: -len(index_suffix)]))

    # Attach the modification time of the underlying file to every name.
    stamped = []
    for fname, ckpt_name in candidates:
        try:
            mtime = os.path.getmtime(os.path.join(checkpoint_root, fname))
        except FileNotFoundError:
            # The file vanished in the meantime: treat it as infinitely old.
            mtime = float("-inf")
        stamped.append((os.path.join(checkpoint_root, ckpt_name), mtime))

    accepted = [entry for entry in stamped if filter_cond(entry[0])]
    if not accepted:
        return None
    # max() returns the first maximal entry, i.e. ties are resolved in
    # favour of the alphabetically first file name.
    return max(accepted, key=lambda entry: entry[1])[0]
class WaitForCheckpointHook(Hook):
    """Blocks before every epoch until a not yet seen checkpoint shows up."""

    def __init__(
        self,
        checkpoint_root,
        filter_cond=lambda c: True,
        interval=5,
        add_sec=5,
        callback=None,
        eval_all=False,
    ):
        """
        Parameters
        ----------
        checkpoint_root : str
            Directory that is polled for checkpoints.
        filter_cond : Callable
            Predicate on the checkpoint path; only checkpoints for which
            it is truthy are considered.
        interval : float
            Seconds to sleep between two polls of ``checkpoint_root``.
        add_sec : float
            Extra seconds to sleep once a checkpoint has been found, so
            that a checkpoint which is still being written can finish
            before it is read.
        callback : Callable
            If not ``None``, called with the path of the found checkpoint.
        eval_all : bool
            If set, already known checkpoints are filtered out, so that
            every checkpoint is accepted once instead of only the latest.
        """
        self.logger = get_logger(self)
        self.root = checkpoint_root
        self._fcond = filter_cond
        self.callback = callback
        self.eval_all = eval_all
        self.sleep_interval = interval
        self.additional_wait = add_sec
        # Paths of all checkpoints that were already handed out.
        self.known_checkpoints = set()

    def fcond(self, c):
        """Apply the user filter and, with ``eval_all``, skip known paths."""
        accepted = self._fcond(c)
        if not self.eval_all:
            return accepted
        return accepted and c not in self.known_checkpoints

    def look(self):
        """Poll the checkpoint directory until a new checkpoint is found."""
        self.logger.info("Waiting for new checkpoint.")
        while True:
            candidate = get_latest_checkpoint(self.root, self.fcond)
            is_new = (
                candidate is not None and candidate not in self.known_checkpoints
            )
            if not is_new:
                time.sleep(self.sleep_interval)
                continue

            self.known_checkpoints.add(candidate)
            # Give a writer that is still busy some time to finish.
            time.sleep(self.additional_wait)
            self.logger.info("Found new checkpoint: {}".format(candidate))
            if self.callback is not None:
                self.callback(candidate)
            return

    def before_epoch(self, ep):
        """Wait for a new checkpoint before the epoch starts."""
        self.look()
def strenumerate(*args, **kwargs):
    """Like :func:`enumerate`, but the index is yielded as a string.

    All arguments are forwarded to :func:`enumerate` unchanged.
    """
    numbered = enumerate(*args, **kwargs)
    yield from ((str(position), item) for position, item in numbered)
def make_iterator(list_or_dict):
    """Return an iterable of ``(key, value)`` pairs for a dict or sequence.

    Dicts contribute their items. Lists and tuples contribute their
    elements keyed by the stringified position.

    Raises
    ------
    ValueError
        If ``list_or_dict`` is neither a dict nor a list or tuple.
    """
    if isinstance(list_or_dict, (dict, OrderedDict)):
        return list_or_dict.items()
    if isinstance(list_or_dict, (list, tuple)):
        return strenumerate(list_or_dict)
    raise ValueError(
        "results must be list or dict but is {} ".format(type(list_or_dict))
    )
def dict_repr(some_dict, pre="", level=0):
    """Render a nested dict as a tree, showing the type of every leaf.

    Parameters
    ----------
    some_dict : dict
        Possibly nested dictionary to render.
    pre : str
        Prefix put in front of every produced line. Used by the recursion
        to draw the branches of the enclosing levels.
    level : int
        Nesting depth. It is only passed on (incremented) to the recursive
        calls and does not influence the output.

    Returns
    -------
    str
        One line per key, each terminated by a newline. Empty string for
        an empty dict.
    """
    lines = []
    last_index = len(some_dict) - 1
    for index, (k, v) in enumerate(some_dict.items()):
        is_last = index == last_index
        corner = "└╴ " if is_last else "├╴ "
        # The continuation prefix must be exactly as wide as the corner
        # glyphs, otherwise children of the last key end up in a different
        # column than children of the other keys.
        straight = "   " if is_last else "│  "
        if isinstance(v, dict):
            lines.append("{}{}{}\n".format(pre, corner, k))
            lines.append(dict_repr(v, pre + straight, level + 1))
        else:
            lines.append("{}{}{}: {}\n".format(pre, corner, k, type(v)))
    return "".join(lines)
class CollectorHook(Hook):
    """Collects data. Supposed to be used as base class."""

    def __init__(self):
        # Nested dict mirroring the structure of the step results. Leaves
        # are arrays that grow along the first axis with every step.
        self.collected_data = {}
        self.logger = get_logger(self, "latest_eval")

    def after_step(self, step, results):
        """Add the results of the current step to the collected data."""
        self.stack_results(results, self.collected_data)

    def stack_results(self, new_data, all_data):
        """Given the current collected data append the new results along the
        batch dimension.

        Values that are instances of ``SAVABLES`` are treated as leaves:
        scalars (including plain Python ``int`` and ``float``) are turned
        into arrays of shape ``(1,)`` and everything is concatenated along
        the first axis. All other values are descended into recursively.

        Parameters
        ----------
        new_data : list or dict
            data to append.
        all_data : list or dict
            data to append to. Modified in place.
        """
        for key, value in make_iterator(new_data):
            if isinstance(value, SAVABLES):
                # Leaf. Python ints and floats have no ``shape`` attribute,
                # so normalise through numpy instead of inspecting
                # ``value.shape``. Arrays with at least one dimension are
                # passed through untouched.
                value = np.atleast_1d(value)
                if key in all_data:
                    all_data[key] = np.concatenate([all_data[key], value])
                else:
                    all_data[key] = value
            else:
                # Branch: make sure the container exists, then recurse.
                if key not in all_data:
                    all_data[key] = {}
                self.stack_results(value, all_data[key])
class StoreArraysHook(CollectorHook):
    """Gathers the step outputs and writes them to one npz file per epoch."""

    def __init__(self, save_root):
        """
        Parameters
        ----------
        save_root : str
            Directory in which the npz files are stored.
        """
        super().__init__()
        self.root = save_root

    def after_epoch(self, epoch):
        """Flatten the collected data and store it as a compressed npz.

        The file name is built from the first entry of the collected
        ``"global_step"`` values.
        """
        collected = self.collected_data
        self.logger.info("Collected Data:\n" + dict_repr(collected))
        step = collected["global_step"][0]

        # npz archives are flat, so nested keys are joined into one name.
        self.flat_dict = {}
        self.flatten_results(collected, "", self.flat_dict)
        self.logger.info("Stored Data:\n" + dict_repr(self.flat_dict))

        target = os.path.join(self.root, "{:0>6d}_results".format(step))
        np.savez_compressed(target, **self.flat_dict)

    def flatten_results(self, results, prefix, store_dict):
        """Recursively copy all leaves of ``results`` into ``store_dict``.

        Parameters
        ----------
        results : dict or list
            Possibly nested results.
        prefix : str
            Name of the enclosing level. Joined with the key by ``_``
            unless it is empty.
        store_dict : dict
            Flat dictionary that receives the leaves. Modified in place.
        """
        for name, value in make_iterator(results):
            save_name = name if prefix == "" else "{}_{}".format(prefix, name)
            if not isinstance(value, SAVABLES):
                self.flatten_results(value, save_name, store_dict)
                continue
            store_dict[save_name] = value
# Bundles a metric callable with its name and its input/output name mappings.
MetricTuple = namedtuple(
    "MetricTuple", ["input_names", "output_names", "metric", "name"]
)
[docs]def test_valid_metrictuple(metric_tuple):
"""Checks if all inputs are correct."""
in_names = metric_tuple.input_names
out_names = metric_tuple.output_names
if not isinstance(in_names, dict):
raise ValueError("input_names must be a dict")
if not isinstance(out_names, dict):
raise ValueError("output_names must be a dict")
if not callable(metric_tuple.metric):
raise ValueError("metric must be callable")
if not isinstance(metric_tuple.name, str):
raise ValueError("name must be a string")
if not all([isinstance(i, str) for i in in_names.values()]):
raise ValueError("All entries in input_names must be strings")
if not all([isinstance(o, str) for o in out_names.values()]):
raise ValueError("All entries in output_names must be strings")
identical_names = set(in_names.values()) & set(out_names.values())
if len(identical_names) > 0:
raise ValueError(
"All names must be unique. " "Found {}".format(identical_names)
)
# enough checking already :)
def torch_parse_global_step(checkpoint):
    """Extract the global step from the file name of a ``.ckpt`` checkpoint.

    Only the part of the base name before the first ``.`` is looked at. If
    it contains a ``-``, the step is read from the second ``-`` separated
    field, otherwise from the whole stem. In both cases anything after the
    first ``_`` of that field is ignored.

    The field in front of the ``-`` is not part of the result and is
    therefore not parsed, so it does not have to be numeric.

    Parameters
    ----------
    checkpoint : str
        Path or file name of the checkpoint.

    Returns
    -------
    int
        The parsed global step.

    Raises
    ------
    ValueError
        If the step field is not an integer.
    """
    stem = os.path.basename(checkpoint).split(".")[0]
    fields = stem.split("-")
    step_field = fields[1] if len(fields) > 1 else fields[0]
    return int(step_field.split("_")[0])
def tf_parse_global_step(checkpoint):
    """Return the integer that follows the last ``-`` in ``checkpoint``."""
    pieces = checkpoint.rsplit("-", 1)
    return int(pieces[1])
def get_checkpoint_files(checkpoint_root):
    """Return {global_step: [files,...]}.

    Every directory entry whose *file name* contains ``.ckpt`` is
    considered. Files whose extension is exactly ``.ckpt`` are parsed with
    ``torch_parse_global_step``. For all others the last extension is
    stripped and the remainder is parsed with ``tf_parse_global_step``.

    Parameters
    ----------
    checkpoint_root : str
        Path to where the checkpoints live.

    Returns
    -------
    dict
        Maps each global step to the list of full paths of the files that
        belong to it. Keys are produced by ``np.unique``.
    """
    files = []
    global_steps = []
    for fname in os.listdir(checkpoint_root):
        # Test the file name only. Testing the joined path would make every
        # file match as soon as the root directory contains ".ckpt".
        if ".ckpt" not in fname:
            continue
        path = os.path.join(checkpoint_root, fname)
        stem, ext = os.path.splitext(path)
        if ext == ".ckpt":
            global_step = torch_parse_global_step(path)
        else:
            global_step = tf_parse_global_step(stem)
        files.append(path)
        global_steps.append(global_step)

    stepmap = {step: [] for step in np.unique(global_steps)}
    for step, file_ in zip(global_steps, files):
        stepmap[step].append(file_)
    return stepmap
class KeepBestCheckpoints(Hook):
    """Tries to find a metric for all checkpoints and keeps the n_keep best
    checkpoints and the latest checkpoint.

    Checkpoints for which no metric file exists are never removed.
    """

    def __init__(
        self,
        checkpoint_root,
        metric_template,
        metric_key,
        n_keep=5,
        lower_is_better=True,
    ):
        """
        Parameters
        ----------
        checkpoint_root : str
            Path to look for checkpoints.
        metric_template : str
            Format string to find metric file. It is formatted with the
            global step of the checkpoint.
        metric_key : str
            Key to use from metric file.
        n_keep : int
            Maximum number of checkpoints to keep.
        lower_is_better : bool
            If ``False`` the metric is negated before ranking, so that
            larger metric values are preferred.
        """
        self.root = checkpoint_root
        self.metric_template = metric_template
        self.metric_key = metric_key
        self.n_keep = n_keep
        self.lower_is_better = lower_is_better
        self.logger = get_logger(self)

    def get_loss(self, step):
        """Load the ranking value for the checkpoint at ``step``.

        Returns
        -------
        The first entry stored under ``metric_key`` in the metric file,
        negated if ``lower_is_better`` is ``False``, or ``None`` if the
        metric file does not exist.
        """
        path = self.metric_template.format(step)
        try:
            if path.endswith(".npz"):
                # Use the archive as context manager so that the underlying
                # file handle is closed again.
                with np.load(path) as data:
                    loss = data[self.metric_key][0]
            else:
                # NOTE(security): unpickling executes arbitrary code. Only
                # point ``metric_template`` at files from a trusted source.
                with open(path, "rb") as f:
                    loss = pickle.load(f)[self.metric_key][0]
            if not self.lower_is_better:
                loss = -1.0 * loss
        except FileNotFoundError:
            self.logger.debug("Could not find {}".format(path))
            loss = None
        return loss

    def after_epoch(self, ep):
        """Remove all but the ``n_keep`` best and the latest checkpoint."""
        checkpoint_files = get_checkpoint_files(self.root)

        # Only checkpoints with an available metric take part in the ranking.
        scored = []
        for step in sorted(checkpoint_files.keys()):
            loss = self.get_loss(step)
            if loss is not None:
                scored.append((loss, step))
        if not scored:
            # Nothing has been evaluated yet: there is nothing to rank and
            # max() / indexing below would fail on the empty list.
            self.logger.info("No checkpoint with a metric found.")
            return

        latest_step = max(step for _, step in scored)
        loss_steps = sorted(scored, key=lambda x: x[0])
        steps = [s for _, s in loss_steps]

        remove_steps = steps[self.n_keep :]
        remove_steps = [step for step in remove_steps if not step == latest_step]
        remove_files = list()
        for step in remove_steps:
            remove_files += checkpoint_files[step]
        self.logger.info("Removing files:")
        self.logger.info(remove_files)
        for file_ in remove_files:
            os.remove(file_)

        best_loss, best_step = loss_steps[0]
        # get_loss negates the metric when higher is better; undo that so
        # the log shows the value as stored in the metric file.
        best_value = best_loss if self.lower_is_better else -1.0 * best_loss
        self.logger.info(
            "Current best: {} = {} @ global step {}".format(
                self.metric_key, best_value, best_step
            )
        )
        no_improvement_since = latest_step - best_step
        if no_improvement_since > 0:
            self.logger.info(
                "No improvement since {} global steps.".format(no_improvement_since)
            )