import os
import sys
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from edflow.hooks.hook import Hook
from edflow.custom_logging import get_logger
from edflow.util import retrieve
from edflow.util import walk
from edflow.iterators.batches import plot_batch
"""PyTorch hooks useful during training."""
class PyCheckpointHook(Hook):
    """Write the model's ``state_dict`` to a checkpoint file.

    A checkpoint is saved from :meth:`after_epoch`, from :meth:`at_exception`
    and, if ``interval`` is set, from :meth:`after_step` whenever the step
    passed to it is a multiple of ``interval``.
    """

    def __init__(self, root_path, model, modelname="model", interval=None):
        """
        Parameters
        ----------
        root_path : str
            Path to where the checkpoints are stored. The directory is
            created if it does not exist yet.
        model : nn.Module
            Model to checkpoint.
        modelname : str
            Suffix of the checkpoint file name, which has the form
            ``<epoch>-<step>_<modelname>.ckpt``.
        interval : int
            Number of iterations after which a checkpoint is saved.
            ``None`` or ``0`` disables step based checkpoints. In any case a
            checkpoint is saved after each epoch.
        """
        self.root = root_path
        self.interval = interval
        self.model = model
        self.logger = get_logger(self)

        os.makedirs(root_path, exist_ok=True)

        # ``savename`` is a template that is later filled with (epoch, step).
        # Literal braces coming from the directory or the model name have to
        # be doubled, otherwise ``str.format`` would read them as additional
        # placeholders when the template is filled in ``save``.
        directory = self._escape_braces(str(root_path))
        suffix = self._escape_braces("{}".format(modelname))
        self.savename = os.path.join(directory, "{}-{}_" + suffix + ".ckpt")

        # Init to save even before first step... More of a debug statement
        self.step = 0
        self.epoch = 0

    @staticmethod
    def _escape_braces(text):
        """Double all curly braces so ``str.format`` keeps them literally."""
        return text.replace("{", "{{").replace("}", "}}")

    def before_epoch(self, epoch):
        """Remember the current epoch for the checkpoint file name."""
        self.epoch = epoch

    def after_epoch(self, epoch):
        """Save a checkpoint at the end of every epoch."""
        self.save()

    def after_step(self, step, last_results):
        """Track the global step and save every ``interval`` steps.

        Parameters
        ----------
        step : int
            Step counter handed to the hook; compared against ``interval``.
        last_results : dict
            Results of the step. Its ``"global_step"`` entry is used in the
            checkpoint file name.
        """
        self.step = retrieve(last_results, "global_step")
        # A falsy interval (None or 0) switches step based saving off; this
        # also prevents a modulo-by-zero for ``interval=0``.
        if self.interval and step % self.interval == 0:
            self.save()

    def at_exception(self, *args, **kwargs):
        """Save a checkpoint when the hook is notified of an exception."""
        self.save()

    def save(self):
        """Store ``self.model.state_dict()`` under the current epoch and step."""
        savename = self.savename.format(self.epoch, self.step)
        torch.save(self.model.state_dict(), savename)
        self.logger.info("Saved model to {}".format(savename))
class PyLoggingHook(Hook):
    """Supply and evaluate logging ops at an interval of training steps."""

    def __init__(
        self,
        log_ops=None,
        scalar_keys=None,
        histogram_keys=None,
        image_keys=None,
        log_keys=None,
        graph=None,
        interval=100,
        root_path="logs",
    ):
        """
        Parameters
        ----------
        log_ops : list
            Ops to run at logging time. They are added to the fetches under
            the key ``"logging"``. Defaults to an empty list.
        scalar_keys : list
            Keys into the step results whose values are written as
            tensorboard scalars. Defaults to an empty list.
        histogram_keys : list
            Keys into the step results whose values are written as
            tensorboard histograms. Defaults to an empty list.
        image_keys : list
            Keys into the step results whose values are images. Note that
            for these no tensorboard logging is used but ``plot_batch``,
            which is given a ``.png`` path below ``root_path``.
            Defaults to an empty list.
        log_keys : list
            Keys into the step results whose values are logged via the
            logger. Defaults to an empty list.
        graph
            Stored as ``self.graph``; not used by the methods of this class.
        interval : int
            Interval of training steps before logging.
        root_path : str
            Path at which the logs are stored.
        """
        # ``None`` stands in for "empty list" so that every hook instance gets
        # its own list instead of sharing one mutable default object.
        self.log_ops = [] if log_ops is None else log_ops
        self.scalar_keys = [] if scalar_keys is None else scalar_keys
        self.histogram_keys = [] if histogram_keys is None else histogram_keys
        self.image_keys = [] if image_keys is None else image_keys
        self.log_keys = [] if log_keys is None else log_keys

        self.interval = interval
        self.tb_logger = SummaryWriter(root_path)
        self.graph = graph
        self.root = root_path
        self.logger = get_logger(self)

    def before_step(self, batch_index, fetches, feeds, batch):
        """Request the logging ops on every ``interval``-th batch."""
        if batch_index % self.interval == 0:
            fetches["logging"] = self.log_ops

    def after_step(self, batch_index, last_results):
        """Write scalars, histograms, images and log lines.

        Nothing happens unless ``batch_index`` is a multiple of ``interval``.

        Parameters
        ----------
        batch_index : int
            Index of the batch that was just processed.
        last_results : dict
            Results of the step; must contain ``"global_step"`` as well as
            every configured key.
        """
        if batch_index % self.interval != 0:
            return

        step = last_results["global_step"]

        for key in self.scalar_keys:
            value = retrieve(last_results, key)
            self.tb_logger.add_scalar(key, value, step)

        for key in self.histogram_keys:
            value = retrieve(last_results, key)
            self.tb_logger.add_histogram(key, value, step)

        for key in self.image_keys:
            value = retrieve(last_results, key)
            # Only the last path component of the key names the image file.
            name = key.split("/")[-1]
            full_name = name + "_{:07}.png".format(step)
            save_path = os.path.join(self.root, full_name)
            plot_batch(value, save_path)

        for key in self.log_keys:
            value = retrieve(last_results, key)
            self.logger.info("{}: {}".format(key, value))
class ToNumpyHook(Hook):
    """Turn every pytorch Variable or Tensor found in the results into a
    numpy array; all other entries stay untouched."""

    def after_step(self, step, results):
        """Convert the entries of ``results`` in place after a step."""

        def to_numpy(item):
            # Whatever offers a ``cpu`` method gets it called first.
            if hasattr(item, "cpu"):
                item = item.cpu()

            if isinstance(item, torch.autograd.Variable):
                return item.data.numpy()
            if isinstance(item, torch.Tensor):
                return item.numpy()
            return item

        walk(results, to_numpy, inplace=True)
class ToTorchHook(Hook):
    """Converts all numpy arrays in the batch to torch.Tensor
    arrays and leaves the rest as is."""

    def __init__(self, push_to_gpu=True, dtype=torch.float):
        """
        Parameters
        ----------
        push_to_gpu : bool
            If ``True``, ``.cuda()`` is called on every converted tensor.
        dtype : torch.dtype
            Data type every converted tensor is cast to.
        """
        self.use_gpu = push_to_gpu
        self.dtype = dtype
        self.logger = get_logger(self)

    def before_step(self, step, fetches, feeds, batch):
        """Convert the numpy arrays in ``feeds`` in place before a step.

        The conversion is best effort: if any stage fails, the object is
        handed back in the state reached so far (the untouched array, or a
        tensor that was not cast / not moved yet) and the failure is logged
        at debug level instead of being dropped silently.
        """

        def convert(obj):
            if not isinstance(obj, np.ndarray):
                return obj
            try:
                obj = torch.tensor(obj)
                obj = obj.to(self.dtype)
                if self.use_gpu:
                    obj = obj.cuda()
            except Exception as error:
                # Deliberately not re-raised: the step should still run with
                # whatever could be converted.
                self.logger.debug(
                    "Could not fully convert array to torch tensor: {}".format(
                        error
                    )
                )
            return obj

        walk(feeds, convert, inplace=True)
class ToFromTorchHook(ToNumpyHook, ToTorchHook):
    """Combination of :class:`ToNumpyHook` and :class:`ToTorchHook`."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the constructor of ``ToTorchHook``."""
        ToTorchHook.__init__(self, *args, **kwargs)
class DataPrepHook(ToFromTorchHook):
    """
    Hook that brings image arrays into the layout needed for the step and
    back again. Before a step, 3-d arrays get a trailing axis of size 1
    appended (for the number of channels of the image) and 4-d arrays have
    their last axis moved to position 1. After a step, 4-d arrays have axis 1
    moved back to the end. The conversion to Pytorch tensors and back is done
    by the parent hooks.
    """

    def before_step(self, step, fetches, feeds, batch):
        """
        Steps taken before the training step.

        Parameters
        ----------
        step
            Training step.
        fetches
            Fetches for the upcoming step; handed on to the parent hook.
        feeds
            Feeds for the upcoming step; their arrays are rearranged in place.
        batch
            The batch to be iterated over; handed on to the parent hook.
        """

        def channels_to_front(item):
            if not isinstance(item, np.ndarray):
                return item
            if item.ndim == 3:
                n_batch, n_rows, n_cols = item.shape
                item = item.reshape(n_batch, n_rows, n_cols, 1)
            if item.ndim == 4:
                return item.transpose(0, 3, 1, 2)
            return item

        walk(feeds, channels_to_front, inplace=True)
        super().before_step(step, fetches, feeds, batch)

    def after_step(self, step, results):
        """
        Steps taken after the training step.

        Parameters
        ----------
        step
            Training step.
        results
            Results of the step; 4-d arrays are rearranged in place, except
            for entries whose key contains ``"weights"``.
        """
        super().after_step(step, results)

        def channels_to_back(key, item):
            if (
                "weights" not in key
                and isinstance(item, np.ndarray)
                and item.ndim == 4
            ):
                return item.transpose(0, 2, 3, 1)
            return item

        walk(results, channels_to_back, inplace=True, pass_key=True)