Source code for edflow.hooks.logging_hooks.tf_logging_hook

import tensorflow as tf
import os
import time

from edflow.hooks.hook import Hook
from edflow.hooks.checkpoint_hooks.common import get_checkpoint_files
from edflow.custom_logging import get_logger
from edflow.iterators.batches import plot_batch, batch_to_canvas

import sys
from PIL import Image, ImageDraw, ImageFont
import numpy as np

"""TensorFlow hooks useful during training."""


class LoggingHook(Hook):
    """Supply and evaluate logging ops at an interval of training steps."""

    def __init__(
        self,
        scalars=None,
        histograms=None,
        images=None,
        logs=None,
        graph=None,
        interval=100,
        root_path="logs",
        log_images_to_tensorboard=False,
    ):
        """
        Parameters
        ----------
        scalars : dict
            Scalar ops, keyed by summary name. ``None`` means no scalars.
        histograms : dict
            Histogram ops, keyed by summary name. ``None`` means no
            histograms.
        images : dict
            Image ops, keyed by name. Unless ``log_images_to_tensorboard``
            is set, these are not sent to tensorboard but written to
            ``root_path`` with :func:`plot_batch`.
        logs : dict
            Values that are fetched and written to the logger.
        graph : tf.Graph
            Current graph. If ``None``, the default graph is looked up when
            the summary writer is created.
        interval : int
            Interval of training steps between two logging events.
        root_path : str
            Path at which the logs are stored.
        log_images_to_tensorboard : bool
            If ``True``, ``images`` are added as image summaries instead of
            being saved to disk by this hook.
        """
        # Use a fresh dict per instance instead of a shared mutable default:
        # ``logs`` and ``images`` are stored by reference in ``fetch_dict``.
        scalars = {} if scalars is None else scalars
        histograms = {} if histograms is None else histograms
        images = {} if images is None else images
        logs = {} if logs is None else logs

        scalar_summaries = [tf.summary.scalar(n, s) for n, s in scalars.items()]
        histogram_summaries = [
            tf.summary.histogram(n, h) for n, h in histograms.items()
        ]

        self.log_images_to_tensorboard = log_images_to_tensorboard
        if log_images_to_tensorboard:
            image_summaries = [tf.summary.image(n, i) for n, i in images.items()]
        else:
            image_summaries = []

        all_summaries = scalar_summaries + histogram_summaries + image_summaries
        self._has_summary = len(all_summaries) > 0
        if self._has_summary:
            summary_op = tf.summary.merge(all_summaries)
        else:
            # Nothing to merge; keep the key in the fetch dict anyway.
            summary_op = tf.no_op()

        self.fetch_dict = {"summaries": summary_op, "logs": logs, "images": images}

        self.interval = interval
        self.graph = graph
        self.root = root_path
        self.logger = get_logger(self)

        # Created in before_epoch(0), or lazily in after_step if that call
        # never happened.
        self.writer = None

    def _create_writer(self):
        """Create the summary writer, falling back to the default graph."""
        if self.graph is None:
            self.graph = tf.get_default_graph()
        self.writer = tf.summary.FileWriter(self.root, self.graph)

    def before_epoch(self, ep):
        """Set up the summary writer at the start of epoch 0."""
        if ep == 0:
            self._create_writer()

    def before_step(self, batch_index, fetches, feeds, batch):
        """Request the logging ops on every ``interval``-th step."""
        if batch_index % self.interval == 0:
            fetches["logging"] = self.fetch_dict

    def after_step(self, batch_index, last_results):
        """Write summaries, log messages and images on logging steps.

        Reads ``last_results["global_step"]`` and ``last_results["logging"]``
        (the evaluated ``fetch_dict``).
        """
        if batch_index % self.interval != 0:
            return

        step = last_results["global_step"]
        results = last_results["logging"]

        if self._has_summary:
            # Previously an AttributeError if before_epoch(0) was never
            # called on this hook; create the writer on demand instead.
            if getattr(self, "writer", None) is None:
                self._create_writer()
            self.writer.add_summary(results["summaries"], step)

        logs = results["logs"]
        for name in sorted(logs.keys()):
            self.logger.info("{}: {}".format(name, logs[name]))

        if not self.log_images_to_tensorboard:
            # Images were not added as summaries, so save them as PNG files.
            for name, image_batch in results["images"].items():
                full_name = name + "_{:07}.png".format(step)
                save_path = os.path.join(self.root, full_name)
                plot_batch(image_batch, save_path)

        self.logger.info("project root: {}".format(self.root))
class ImageOverviewHook(Hook):
    """Log an overview of all image outputs at an interval of training steps.

    NOTE(review): as shown here this class defines no ``before_step``, so
    ``after_step`` relies on something else having put the evaluated images
    under ``last_results["logging"]["images"]`` - confirm against the caller.
    """

    def __init__(self, images=None, interval=100, root_path="logs"):
        """
        Parameters
        ----------
        images : dict
            Image ops, keyed by name. Stored in ``self.fetch_dict``.
            ``None`` means an empty dict.
        interval : int
            Interval of training steps between two overviews.
        root_path : str
            Path at which the overview images are stored.
        """
        # Fresh dict per instance instead of a shared mutable default.
        images = {} if images is None else images

        summary_op = tf.no_op()

        # TODO: actually implement logging the overview to tensorboard
        # (a ``log_images_to_tensorboard`` option).

        self.fetch_dict = {"summaries": summary_op, "images": images}

        self.interval = interval
        self.root = root_path
        self.logger = get_logger(self)

    @staticmethod
    def _load_font(size):
        """Return the preferred TrueType font, or PIL's default if missing."""
        # TODO: fix hard-coded font type
        try:
            return ImageFont.truetype("LiberationMono-Regular.ttf", size)
        except OSError:
            return ImageFont.load_default()

    def after_step(self, batch_index, last_results):
        """Render every image batch as a labelled tile and save the overview.

        Writes ``overview_<step>.png`` into ``self.root`` on every
        ``interval``-th step. Does nothing if there are no images.
        """
        if batch_index % self.interval != 0:
            return

        step = last_results["global_step"]
        images = last_results["logging"]["images"]
        if not images:
            # Nothing to draw; previously this crashed on ``batches[0]``.
            return

        # TODO: add option to log overview to tensorboard
        label_font = self._load_font(20)
        tiles = []
        for name, im_batch in sorted(images.items()):
            # presumably tiles the batch into a single image - see
            # batch_to_canvas
            canvas = batch_to_canvas(im_batch)
            # Map values from [-1, 1] to [0, 255], clipping the rest.
            canvas = (canvas + 1.0) / 2.0
            canvas = np.clip(255 * canvas, 0, 255)
            canvas = np.array(canvas, dtype="uint8")

            tile = Image.fromarray(canvas)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
            # filter under its current name.
            tile.thumbnail((512, 512), Image.LANCZOS)
            ImageDraw.Draw(tile).text(
                (10, 10), name, fill=(255, 0, 0), font=label_font
            )
            tiles.append(tile)

        # Extra tile showing the step, sized like the first image tile.
        # NOTE(review): the label says "epoch" but the value is the global
        # step; the text is left unchanged here.
        info_tile = Image.new("RGB", tiles[0].size, color=(0, 0, 0))
        ImageDraw.Draw(info_tile).text(
            (10, 10),
            "epoch\n{:07d}".format(step),
            fill=(255, 0, 0),
            font=self._load_font(50),
        )
        tiles.append(info_tile)

        # Back from [0, 255] to [-1, 1] before handing over to plot_batch.
        batch = np.stack(tiles, axis=0) / 255.0 * 2 - 1.0
        out_path = os.path.join(self.root, "overview_{:07d}.png".format(step))
        plot_batch(batch, out_path)