Keras: reshaping tensors

Posted on Sat 06 March 2021 in Recipes

In order to create a Keras model, you need to specify the shape of the input tensor. That input can be an array of values, a grayscale image, an RGB image, or anything else. Suppose we have 8x8 grayscale images (i.e., a single channel). Our task is to reshape each image from (8, 8) to (8, 8, 1) so that the shape tuple includes the number of channels. There are many ways to do this (and many more ways to do it incorrectly!). Let's discuss a few of them.
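In Keras, the shape passed to an Input layer excludes the batch dimension, so a whole dataset of 8x8 grayscale images usually has shape (N, 8, 8) and must become (N, 8, 8, 1). The following is a minimal NumPy sketch of that case, assuming a hypothetical batch of 10 random images. Appending the channel axis at the end (axis=-1) works no matter how many leading dimensions there are:

```python
import numpy as np

# Hypothetical data: a batch of 10 grayscale images, each 8x8, with no channel axis yet
images = np.random.randint(0, 256, (10, 8, 8))

# Append a channel axis at the end: (10, 8, 8) -> (10, 8, 8, 1)
images_with_channel = np.expand_dims(images, axis=-1)

# A single image: (8, 8) -> (8, 8, 1), the per-sample shape given to Keras
single = images[0][..., np.newaxis]

print(images_with_channel.shape)  # (10, 8, 8, 1)
print(single.shape)               # (8, 8, 1)
```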

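The following code provides a simple framework for testing methods of reshaping a tensor. To keep the output short, it uses a 4x4 array instead of an 8x8 one, so the goal becomes reshaping from (4, 4) to (4, 4, 1). The same reasoning applies to any image size.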

# reshape_arrays.py
# This script compares methods for reshaping a tensor
# Our goal is to reshape a tensor from (8,8) to (8,8,1)
# This is useful when we want to specify the shape of a grayscale image
# in Keras, for example:
#
#       input_tensor = layers.Input(shape=(8, 8, 1)) 
# or:
# 
#       model.add(layers.Dense(32, activation="relu", input_shape=(8, 8, 1)))

import numpy as np

if __name__ == "__main__":

    # create an input tensor of shape (4, 4)
    # think of it as a 4x4 grayscale image (a smaller stand-in for 8x8)
    a = np.random.randint(0, 256, (4, 4))

    # reshape the tensor from (4, 4) to (4, 4, 1)
    # using several methods
    a1 = a.reshape(a.shape + (1, ))     # method 1
    a2 = np.expand_dims(a, axis=0)      # method 2 fails: incorrect axis
    a3 = np.expand_dims(a, axis=1)      # method 3 fails: incorrect axis
    a4 = np.expand_dims(a, axis=2)      # method 4
    a5 = a[None, :, :]                  # method 5 fails: None is in the incorrect dimension
    a6 = a[:, :, None]                  # method 6
    a7 = a.reshape(-1, 1)               # method 7 fails: -1 is in the incorrect dimension
    a8 = a.reshape(4, 4, -1)            # method 8

    tensors = (
        (a1, a2, "method 2"),
        (a1, a3, "method 3"),
        (a1, a4, "method 4"),
        (a1, a5, "method 5"),
        (a1, a6, "method 6"),
        (a1, a7, "method 7"),
        (a1, a8, "method 8"),
    )

    for ta, tb, msg in tensors:
        try:
            np.testing.assert_array_equal(ta, tb, err_msg=msg)
        except AssertionError:
            print("Error on: %s, expected shape: %s, result: %s" % (msg, ta.shape, tb.shape))

Output:

Error on: method 2, expected shape: (4, 4, 1), result: (1, 4, 4)
Error on: method 3, expected shape: (4, 4, 1), result: (4, 1, 4)
Error on: method 5, expected shape: (4, 4, 1), result: (1, 4, 4)
Error on: method 7, expected shape: (4, 4, 1), result: (16, 1)

Method 1 serves as the reference result. As the output shows, four of the seven remaining methods fail the test: 2, 3, 5, and 7.

| Method | Description | Does it work? |
|---|---|---|
| a.reshape(a.shape + (1, )) | We take the current shape of the tensor and append a new dimension. This works because both shapes are tuples: (4, 4) + (1,) is (4, 4, 1). We then reshape the tensor a to (4, 4, 1). | Yes |
| np.expand_dims(a, axis=...) | This one is a bit tricky because we need to choose the value of axis carefully. If that parameter is specified incorrectly, the reshape does not return the expected result. In our example, the test passed only with axis=2 (axis=-1 would also work, since it always appends the new dimension at the end). With axis=0 or axis=1, the shape of the tensor changes, but not as expected: for instance, axis=0 yields (1, 4, 4) rather than (4, 4, 1). | Yes, but be careful with axis. |
| a[None, :, :] | Like np.expand_dims, this is a bit tricky. It turns out that we can use None (the same object as np.newaxis) as a placeholder for a new dimension. Here, [None, :, :] inserts the new dimension first, changing the shape of the tensor from (4, 4) to (1, 4, 4). | No |
| a[:, :, None] | To make indexing with None work, we need to place None in the correct (last) dimension. | Yes |
| a.reshape(-1, 1) | This is an interesting one. At first, I believed that -1 stood for a.shape, so that a.reshape(-1, 1) would mean "take the current shape of a and add one dimension." This is wrong! -1 stands for a single dimension whose size NumPy infers, so a.reshape(-1, 1) returns a (16, 1) array. The exact meaning of -1 is described below. | No |
| a.reshape(4, 4, -1) | This works because we put -1 in the correct dimension: with 16 elements and the first two dimensions fixed at 4 and 4, the inferred size of the last dimension is 1. | Yes |
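The claims in the table can be checked directly. Here is a minimal sketch that uses np.arange instead of random values so the array contents are predictable (an assumption for illustration only). It confirms that None and np.newaxis are the same object and shows how the axis argument of np.expand_dims decides where the new dimension lands:

```python
import numpy as np

# Assumption: a deterministic 4x4 "image" (values 0..15) instead of random pixels
a = np.arange(16).reshape(4, 4)

# Shape tuples concatenate: (4, 4) + (1,) gives (4, 4, 1)
new_shape = a.shape + (1,)
via_reshape = a.reshape(new_shape)

# None is literally np.newaxis, so both indexing forms add the same new axis
assert np.newaxis is None
via_none = a[:, :, None]
via_newaxis = a[:, :, np.newaxis]

# axis=-1 appends the new axis last, so it matches axis=2 for a 2-D input
via_expand = np.expand_dims(a, axis=-1)

# Where the new axis is inserted decides the resulting shape
for axis in (0, 1, 2, -1):
    print(axis, np.expand_dims(a, axis=axis).shape)
# 0 (1, 4, 4)
# 1 (4, 1, 4)
# 2 (4, 4, 1)
# -1 (4, 4, 1)
```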

From the documentation of reshape:

The new shape should be compatible with the original shape. If an integer, then the result will be a 1-D array of that length. One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions.
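To make the quoted rule concrete, here is a short sketch on the same 16-element 4x4 array. It shows how NumPy infers the size of the -1 dimension from the remaining ones (which explains the results of methods 7 and 8) and that incompatible requests raise a ValueError:

```python
import numpy as np

a = np.arange(16).reshape(4, 4)  # 16 elements in total

# An integer gives a 1-D array of that length
print(a.reshape(16).shape)        # (16,)

# One dimension may be -1; its size is inferred from the total length (16)
print(a.reshape(-1).shape)        # (16,)
print(a.reshape(-1, 1).shape)     # (16, 1)    16 / 1 = 16 rows, as in method 7
print(a.reshape(2, -1).shape)     # (2, 8)     16 / 2 = 8 columns
print(a.reshape(4, 4, -1).shape)  # (4, 4, 1)  16 / (4 * 4) = 1, as in method 8

# Incompatible requests raise ValueError: two unknown dimensions,
# or known sizes that do not divide the total length
for bad_shape in [(-1, -1, 1), (3, -1)]:
    try:
        a.reshape(bad_shape)
    except ValueError as err:
        print(bad_shape, "->", err)
```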