top of page
  • Oktay Sahinoglu

Generating Animated GIF with Python

Updated: Oct 5, 2020

Visualization is often very helpful for expressing concepts clearly. Animated visualization takes this a step further and makes it possible to explain even more complex concepts. Especially in data analytics, animated charts can lift the veil and reveal what is happening behind the scenes. Below, I share Python code that creates an animated GIF showing how the weights of a neural network are learned during training.

The functions used to plot and annotate the heatmap are adapted from the matplotlib documentation (the "Creating annotated heatmaps" example). "animation_data.obj" is a file written with pickle that contains a list of 2D NumPy arrays; each array holds the weights of a neural network layer at one point during training, and each array becomes one frame of the animation.

import pickle
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib
import io

# Load the list of 2D weight arrays (one per animation frame).
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only load animation_data.obj if it comes from a trusted source.
with open("animation_data.obj", "rb") as data_file:
    heatmap_frames_data = pickle.load(data_file)

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbar_label="", cbar_visible=True, **kwargs):
    """Create a heatmap from a numpy array and two lists of labels.

    Adapted from the matplotlib "Creating annotated heatmaps" example.

    Parameters
    ----------
    data
        A 2D numpy array of shape (N, M).
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
        not provided, use current axes or create a new one.  Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`.  Optional.
    cbar_label
        The label for the colorbar.  Optional.
    cbar_visible
        If False, no colorbar is drawn and `cbar_kw` / `cbar_label` are
        ignored.  Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    The `AxesImage` returned by `ax.imshow`.
    """
    if ax is None:
        ax = plt.gca()
    if cbar_kw is None:
        # A fresh dict per call instead of a shared mutable default.
        cbar_kw = {}

    # np.shape also works for plain nested lists, which imshow accepts.
    n_rows, n_cols = np.shape(data)[:2]

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    if cbar_visible:
        cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
        cbar.ax.set_ylabel(cbar_label, rotation=-90, va="bottom")

    # We want to show all ticks...
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    # ... and label them with the respective list entries.
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create white grid.
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Minor ticks at half-integer positions fall on the cell borders, so a
    # grid drawn on them separates the cells.
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im

def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """A function to annotate a heatmap.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate.  If None, the image's data is used.  Optional.
    valfmt
        The format of the annotations inside the heatmap.  This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter`.  Optional.
    textcolors
        A sequence of two color specifications.  The first is used for
        values below a threshold, the second for those above.  Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied.  If None (the default), half of the normalized maximum of
        the data is used as the separation.  Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels.

    Returns
    -------
    A list with one `Text` instance per annotated cell.
    """
    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()
    elif isinstance(data, list):
        # Lists have no .shape or [i, j] indexing; convert for the loop below.
        data = np.asarray(data)

    # Normalize the threshold to the images color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts

def plot_heatmap_frame(data, figsize=(15, 10), *, cmap="YlGn", vmin=0,
                       vmax=0.05, valfmt="{x:.9f}"):
    """Render one weight matrix as an annotated heatmap and return it as a PIL image.

    Rows are labeled "embed_dim_<i>" and columns "orig_dim_<j>". No colorbar
    is drawn.

    Parameters
    ----------
    data
        2D array whose shape determines the number of row/column labels.
    figsize
        Matplotlib figure size in inches.
    cmap, vmin, vmax
        Forwarded to `imshow`; the defaults keep the color scale fixed
        across frames so colors are comparable over time.
    valfmt
        Format string for the per-cell annotations.

    Returns
    -------
    A fully loaded `PIL.Image.Image` copy of the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=figsize)
    row_labels = ["embed_dim_" + str(dim_idx) for dim_idx in range(data.shape[0])]
    col_labels = ["orig_dim_" + str(dim_idx) for dim_idx in range(data.shape[1])]

    # Write the plot to memory so PIL can read it, then close the figure.
    # pyplot keeps every figure alive until it is closed, so without
    # plt.close each frame would leak a figure.
    buf = io.BytesIO()
    try:
        im = heatmap(data, row_labels, col_labels, ax=ax, cmap=cmap,
                     cbar_visible=False, vmin=vmin, vmax=vmax)
        annotate_heatmap(im, valfmt=valfmt)
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)

    # Close the PIL image correctly; copy() loads the pixel data so the
    # returned frame stays usable after the `with` block closes the source.
    with Image.open(buf) as png_image:
        frame = png_image.copy()
    return frame

# Create the first frame separately: Pillow writes a multi-frame GIF by
# saving one image and appending the remaining frames to it.
if len(heatmap_frames_data) == 0:
    raise ValueError("animation_data.obj contains no frames to animate")
im = plot_heatmap_frame(heatmap_frames_data[0])
following_frames = [plot_heatmap_frame(frame_data)
                    for frame_data in heatmap_frames_data[1:]]

# Append the following frames and save as a GIF that loops forever (loop=0).
# NOTE(review): Pillow takes `duration` in milliseconds, but GIF stores frame
# delays in 1/100 s, so 8 ms is likely written as 0. Many viewers then use
# their own default delay. Confirm the intended frame rate.
im.save("training_weights.gif", save_all=True,
        append_images=following_frames, optimize=False, duration=8, loop=0)


I hope you like it! See you in another post.

475 views · 0 comments

Recent Posts

See All
bottom of page