Visualization often helps express concepts more clearly. Animated visualization takes this a step further and can help explain even more complex concepts. In data analytics especially, animated charts can lift the veil and reveal what happens behind the scenes. As an example, I will share the following Python code, which creates an animated GIF showing how the weights of a neural network change as it learns.
The parts used to plot and annotate the heatmap are taken from the matplotlib documentation. "animation_data.obj" is a file written with pickle. It contains a list of 2D NumPy arrays, each holding the weights of a neural network layer at one point during training; each array becomes one frame of the animation.
import pickle
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib
import io
# Load the recorded weight snapshots: one 2D array per animation frame.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# load pickle files that come from a trusted source.
with open("animation_data.obj", mode="rb") as pickled_frames:
    heatmap_frames_data = pickle.load(pickled_frames)
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbar_label="", cbar_visible=True, **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Adapted from the matplotlib documentation's annotated-heatmap example.

    Parameters
    ----------
    data
        A 2D numpy array of shape (N, M).
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, use current axes or create a new one. Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
    cbar_label
        The label for the colorbar. Optional.
    cbar_visible
        If True (the default), draw a colorbar next to the heatmap.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    The `AxesImage` returned by `imshow`.
    """
    if ax is None:
        ax = plt.gca()

    # Plot the heatmap.
    im = ax.imshow(data, **kwargs)

    # Create the colorbar only when requested. `cbar_kw` defaults to None
    # rather than {} so that no mutable dict is shared between calls.
    if cbar_visible:
        cbar = ax.figure.colorbar(im, ax=ax, **(cbar_kw or {}))
        cbar.ax.set_ylabel(cbar_label, rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_yticks(np.arange(data.shape[0]))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and draw a white grid between cells using minor ticks
    # placed on the cell boundaries (offset by half a cell).
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
    ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Annotate each cell of a heatmap with its value.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate: a numpy array or a nested list. If None (or
        any other type), the image's data is used. Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter`. Optional.
    textcolors
        A sequence of two color specifications. The first is used for values
        whose normalized value is at or below the threshold, the second for
        those above. Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default), half of the normalized maximum of
        `data`, i.e. ``im.norm(data.max()) / 2``, is used. Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to
        create the text labels.

    Returns
    -------
    list
        The `Text` objects created, in row-major order.
    """
    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()
    elif isinstance(data, list):
        # Lists have no .shape or .max(); convert so the code below works.
        data = np.asarray(data)

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied.
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel", choosing the
    # text color by comparing the cell's normalized value to the threshold.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts
def plot_heatmap_frame(data, figsize=(15, 10), cmap="YlGn", vmin=0, vmax=0.05,
                       valfmt="{x:.9f}"):
    """
    Render one weight matrix as an annotated heatmap and return it as an image.

    Parameters
    ----------
    data
        A 2D numpy array. Rows are labelled "embed_dim_<i>" and columns
        "orig_dim_<j>".
    figsize
        Figure size in inches, passed to `plt.subplots`.
    cmap, vmin, vmax
        Colormap and color limits forwarded to `imshow`. Using the same
        limits for every frame keeps colors comparable across the animation.
    valfmt
        Format string for the per-cell value annotations.

    Returns
    -------
    PIL.Image.Image
        An in-memory copy of the rendered PNG that does not depend on the
        buffer it was read from.
    """
    row_labels = [f"embed_dim_{i}" for i in range(data.shape[0])]
    col_labels = [f"orig_dim_{j}" for j in range(data.shape[1])]

    fig, ax = plt.subplots(figsize=figsize)
    try:
        im = heatmap(data, row_labels, col_labels, ax=ax, cmap=cmap,
                     cbar_visible=False, vmin=vmin, vmax=vmax)
        annotate_heatmap(im, valfmt=valfmt)
        fig.tight_layout()

        # Write the figure to an in-memory PNG so PIL can read it back
        # without a temporary file.
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            buf.seek(0)
            # copy() loads the pixel data, so the returned image stays valid
            # after both the PIL file handle and the buffer are closed.
            with Image.open(buf) as rendered:
                frame = rendered.copy()
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        # Without this, one figure per frame accumulates in memory.
        plt.close(fig)

    return frame
# PIL writes an animated GIF by saving the first image with save_all=True
# and appending the remaining frames, so the first frame is rendered on its
# own.
if len(heatmap_frames_data) == 0:
    raise ValueError("animation_data.obj contains no frames to animate")

first_frame = plot_heatmap_frame(heatmap_frames_data[0])
remaining_frames = [plot_heatmap_frame(frame_data)
                    for frame_data in heatmap_frames_data[1:]]

with open("training_weights.gif", "wb") as gif_file:
    # NOTE(review): Pillow's `duration` is in milliseconds, but GIF stores
    # frame delays in 1/100 s, so 8 ms cannot be represented exactly. Many
    # viewers also replace very small delays with a slower default. Confirm
    # the intended playback speed.
    first_frame.save(gif_file, format="GIF", save_all=True,
                     append_images=remaining_frames, optimize=False,
                     duration=8, loop=0)
The result:
I hope you like it! See you in the next post.
Comments