Source code for edflow.iterators.tf_batches

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
import tensorflow as tf

import math


def tf_batch_to_canvas(X, cols: int = None):
    """Reshape a batch of images into a grid canvas to form a single image.

    Grid cells that are not covered by an image of the batch are filled with
    all-ones tiles of the same dtype as ``X``.

    Parameters
    ----------
    X : Tensor
        Batch of images to format. [N, H, W, C]-shaped. The shape entries are
        used in plain Python arithmetic, so they have to be statically known.
    cols : int
        Number of columns of the grid. Values smaller than 1 are clamped to 1.
        If ``None``, a square grid with ``ceil(sqrt(N))`` rows and columns is
        used. (Default value = None)

    Returns
    -------
    image_grid : Tensor
        Tensor representing the image grid. [1, HH, WW, C]-shaped, where
        ``HH = rows * H`` and ``WW = cols * W``.

    Raises
    ------
    ValueError
        If ``X`` does not have exactly 4 dimensions.

    Examples
    --------
    ::

        x = np.ones((9, 100, 100, 3))
        x = tf.convert_to_tensor(x)
        canvas = batches.tf_batch_to_canvas(x)
        assert canvas.shape == (1, 300, 300, 3)
        canvas = batches.tf_batch_to_canvas(x, cols=5)
        assert canvas.shape == (1, 200, 500, 3)
    """
    shape = X.shape.as_list()
    if len(shape) > 4:
        raise ValueError("input tensor has more than 4 dimensions.")
    if len(shape) < 4:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" error.
        raise ValueError("input tensor has fewer than 4 dimensions.")
    N, H, W, C = shape

    if cols is None:
        # No layout requested: use the smallest square grid holding N images.
        rows = cols = math.ceil(math.sqrt(N))
    else:
        cols = max(1, cols)
        rows = math.ceil(N / cols)

    # Pad the batch so that it fills the grid completely.
    n_white_tiles = cols * rows - N
    if n_white_tiles > 0:
        white_tiles = tf.ones((n_white_tiles, H, W, C), X.dtype)
        X = tf.concat([X, white_tiles], 0)

    return image_grid(X, (rows, cols), (H, W), C)
# TODO(joelshor): Make this a special case of `image_reshaper`.
# Shamelessly copied from https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py#L34-L80
def image_grid(input_tensor, grid_shape, image_shape=(32, 32), num_channels=3):
    """Tile a minibatch of images into one large image.

    Parameters
    ----------
    input_tensor : Tensor
        Minibatch of images to format, either 4D
        ([batch size, height, width, num_channels]) or flattened
        ([batch size, height * width * num_channels]).
    grid_shape : Sequence of int
        The shape of the image grid, formatted as [grid_height, grid_width].
        ``grid_height * grid_width`` has to equal the batch size.
    image_shape : Sequence of int
        The shape of a single image, formatted as
        [image_height, image_width]. (Default value = (32, 32))
    num_channels : int
        Number of channels of each image. (Default value = 3)

    Returns
    -------
    Tensor
        A single image in which the input images have been arranged into a
        grid. [1, grid_height * image_height, grid_width * image_width,
        num_channels]-shaped.

    Raises
    ------
    ValueError
        The grid shape and minibatch size don't match, or the image shape and
        number of channels are incompatible with the input tensor.
    """
    n_cells = grid_shape[0] * grid_shape[1]
    batch_size = int(input_tensor.shape[0])
    if n_cells != batch_size:
        raise ValueError(
            "Grid shape %s incompatible with minibatch size %i."
            % (grid_shape, batch_size)
        )

    rank = len(input_tensor.shape)
    if rank == 2:
        # Flattened images: only the total feature count can be verified.
        flat_size = image_shape[0] * image_shape[1] * num_channels
        compatible = int(input_tensor.shape[1]) == flat_size
    elif rank == 4:
        compatible = (
            int(input_tensor.shape[1]) == image_shape[0]
            and int(input_tensor.shape[2]) == image_shape[1]
            and int(input_tensor.shape[3]) == num_channels
        )
    else:
        raise ValueError("Unrecognized input tensor format.")
    if not compatible:
        raise ValueError(
            "Image shape and number of channels incompatible with " "input tensor."
        )

    canvas_height = grid_shape[0] * image_shape[0]
    canvas_width = grid_shape[1] * image_shape[1]

    # [grid_h, grid_w, img_h, img_w, C]
    tiles = array_ops.reshape(
        input_tensor, tuple(grid_shape) + tuple(image_shape) + (num_channels,)
    )
    # Swap the two image axes -> [grid_h, grid_w, img_w, img_h, C] so that
    # grid_w and img_w are adjacent and can be merged into the canvas width.
    tiles = array_ops.transpose(tiles, [0, 1, 3, 2, 4])
    # [grid_h, canvas_width, img_h, C]
    strips = array_ops.reshape(
        tiles, [grid_shape[0], canvas_width, image_shape[0], num_channels]
    )
    # Bring img_h next to grid_h -> [grid_h, img_h, canvas_width, C].
    strips = array_ops.transpose(strips, [0, 2, 1, 3])
    # Merge grid_h and img_h and add a leading batch axis of size one.
    return array_ops.reshape(strips, [1, canvas_height, canvas_width, num_channels])