0

I have a sample Pandas data frame as follows:

Action    Comedy    Crime    Thriller    SciFi    
1         0         1         1          0        
0         1         0         0          1        
0         1         0         1          0        
0         0         1         0          1        
1         1         0         0          0        

I would like to plot the data set using Python (preferably with matplotlib) in such a way that each of the columns is a separate axis. Hence, in this case, there will be 5 axes (Action, Comedy, Crime, ...) and 5 data points (since it has 5 rows). Is it possible to plot this kind of multi-axis data using Python and matplotlib? If it's not possible, what would be the best way to visualize this data?

2 Answers 2

3

RadarChart

Plotting several axes at once can be accomplished with a radar chart. You can adapt the Radar Chart example to your needs.

enter image description here

# Sample data as a whitespace-separated table: a header row of genre names
# followed by one row of 0/1 flags per item. It is parsed with pd.read_csv below.
u = u"""Action    Comedy    Crime    Thriller    SciFi    
1         0         1         1          0        
0         1         0         0          1        
0         1         0         1          0        
0         0         1         0          1        
1         1         0         0          0"""

import io
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.path import Path
from matplotlib.spines import Spine
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection


def radar_factory(num_vars, frame='circle'):
    """Register a 'radar' projection with `num_vars` axes and return its angles.

    A `RadarAxes` class (a `PolarAxes` subclass) is built around the computed
    angles and registered under the projection name 'radar', so it can be
    requested with ``projection='radar'``.

    Parameters
    ----------
    num_vars : int
        Number of variables, i.e. number of spokes of the chart.
    frame : {'circle', 'polygon'}
        Shape of the frame drawn around the axes.

    Returns
    -------
    numpy.ndarray
        `num_vars` evenly spaced angles in radians, shifted by pi/2.

    Raises
    ------
    KeyError
        If `frame` is not one of the supported frame names.
    """
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)
    theta += np.pi/2

    def unit_poly_verts(theta):
        """Return polygon vertices on a circle of radius 0.5 centred at (0.5, 0.5)."""
        x0, y0, r = [0.5] * 3
        verts = [(r*np.cos(t) + x0, r*np.sin(t) + y0) for t in theta]
        return verts

    def draw_poly_patch(self):
        """Build the polygonal background patch (axes coordinates)."""
        verts = unit_poly_verts(theta)
        return plt.Polygon(verts, closed=True, edgecolor='k')

    def draw_circle_patch(self):
        """Build the circular background patch (axes coordinates)."""
        return plt.Circle((0.5, 0.5), 0.5)

    patch_dict = {'polygon': draw_poly_patch, 'circle': draw_circle_patch}

    class RadarAxes(PolarAxes):
        """Polar axes whose lines and fills are closed by default."""

        name = 'radar'
        RESOLUTION = 1
        # Looked up once, at class creation; an unknown frame raises KeyError.
        draw_patch = patch_dict[frame]

        def fill(self, *args, **kwargs):
            """Override fill so that line is closed by default"""
            closed = kwargs.pop('closed', True)
            return super(RadarAxes, self).fill(*args, closed=closed, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super(RadarAxes, self).plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            # Hand the created lines back, as the base-class plot() does.
            return lines

        def _close_line(self, line):
            """Append the first point to the end of `line` if it is not closed."""
            x, y = line.get_data()
            # Guard against empty data before indexing the end points.
            if len(x) and x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels):
            """Label each spoke with the corresponding entry of `labels`."""
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            return self.draw_patch()

        def _gen_axes_spines(self):
            if frame == 'circle':
                return PolarAxes._gen_axes_spines(self)
            # The spine type stays 'circle'; the polygonal outline comes
            # from the explicit path passed to Spine below.
            spine_type = 'circle'
            verts = unit_poly_verts(theta)
            # close off polygon by repeating first vertex
            verts.append(verts[0])
            path = Path(verts)

            spine = Spine(self, spine_type, path)
            spine.set_transform(self.transAxes)
            return {'polar': spine}

    register_projection(RadarAxes)
    return theta


# Parse the whitespace-separated sample table. sep=r"\s+" is the supported
# spelling of the deprecated delim_whitespace=True option.
df = pd.read_csv(io.StringIO(u), sep=r"\s+")

# One radar axis per column of the table.
N = len(df.columns)
theta = radar_factory(N, frame='polygon')

fig, ax = plt.subplots(subplot_kw=dict(projection='radar'))

colors = ['b', 'r', 'g', 'm', 'y']
markers = ["s", "o", "P", "*", "^"]
ax.set_rgrids([1])

# `col` is the row's index label as yielded by iterrows().
for i, (col, row) in enumerate(df.iterrows()):
    # Cycle through the style lists so more than five rows do not raise.
    color = colors[i % len(colors)]
    marker = markers[i % len(markers)]
    ax.scatter(theta, row, c=color, marker=marker, label=col)
    ax.fill(theta, row, facecolor=color, alpha=0.25)
ax.set_varlabels(df.columns)

labels = ["Book {}".format(i+1) for i in range(len(df))]
# NOTE(review): the label list is doubled, presumably so that both the filled
# polygons and the scatter markers get a legend entry -- confirm.
ax.legend(labels*2, loc=(0.97, .1), labelspacing=0.1, fontsize='small')

plt.show()

heatmap

An easy and probably more readable way to visualize the data would be a heatmap.

enter image description here

# Sample data as a whitespace-separated table: a header row of genre names
# followed by one row of 0/1 flags per item. It is parsed with pd.read_csv below.
u = u"""Action    Comedy    Crime    Thriller    SciFi    
1         0         1         1          0        
0         1         0         0          1        
0         1         0         1          0        
0         0         1         0          1        
1         1         0         0          0"""

import io
import pandas as pd
import matplotlib.pyplot as plt

# Parse the whitespace-separated sample table. sep=r"\s+" is the supported
# spelling of the deprecated delim_whitespace=True option.
df = pd.read_csv(io.StringIO(u), sep=r"\s+")
# Function-call form works on both Python 2 and Python 3.
print(df)

# Draw the table as a grayscale matrix: one column per genre, one row per book.
plt.matshow(df, cmap="gray")
plt.xticks(range(len(df.columns)), df.columns)
# Rows are numbered starting at 1 on the y axis.
plt.yticks(range(len(df)), range(1, len(df) + 1))
plt.ylabel("Book number")
plt.show()
1

Here is a nice simple visualization that you can get with a bit of data manipulation and Seaborn.

import seaborn as sns

# Placeholder: df must be a Pandas DataFrame holding the genre flags, e.g.
#   Action    Comedy    Crime    Thriller    SciFi
#   1         0         1         1          0
#   0         1         0         0          1
#   0         1         0         1          0
#   0         0         1         0          1
#   1         1         0         0          0
df = ...

# Name both axes so the reshaped frame below gets readable column names
df.columns.name = "Genre"
df.index.name = "Index"

# Reshape to long form and keep only the (Genre, Index) pairs that are set
stacked = df.unstack()
is_set = stacked > 0
df2 = stacked[is_set].reset_index()

# One point per set flag: genre on the x axis, row index on the y axis
ax = sns.stripplot(data=df2, x="Genre", y="Index")
ax.set_yticks(df.index)

And you get:

Data plot

For fine tuning you can check the documentation of sns.stripplot.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.