# Only the plotting function is exported by ``from <module> import *``;
# the helper functions defined at the bottom of this module are not listed.
__all__ = ["parallel_coordinates"]
from typing import Union, Any
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib.path import Path
import matplotlib.patches as patches
from .utils import _rescale
def parallel_coordinates(
        data: Union[np.ndarray, Any],
        categories: Union[np.ndarray, list] = None,
        names: list = None,
        cmap: str = None,
        linestyle: str = "bezier",
        coord_title_kws: dict = None,
        title: str = 'Parallel Coordinates Plot',
        figsize: tuple = None,
        ticklabel_kws: dict = None,
        show: bool = True
) -> plt.Axes:
    """
    parallel coordinates plot

    modified after https://stackoverflow.com/a/60401570/5982232

    Parameters
    ----------
    data : array, DataFrame
        a two dimensional array with the shape (rows, columns). It can also
        be a pandas DataFrame. Columns with non-numeric values are treated as
        categorical: each distinct value (in sorted order) gets its own
        tick on the corresponding axis.
    categories : list, array
        1 dimensional array which contains the class label of each row in
        data. It can be either categorical or continuous numerical values.
        If not given, colorbar will not be drawn. The length of categories
        array must be equal to the number of rows in data.
    names : list, optional
        Labels for columns in data. Its length should be equal to number of
        columns in data.
    cmap : str, optional
        colormap to be used. Default is "Blues". Any value accepted by
        :func:`matplotlib.pyplot.get_cmap` (a colormap name or a Colormap).
    coord_title_kws : dict, optional
        keyword arguments for coordinate titles. All of these arguments will
        go to :obj:`matplotlib.axes.Axes.set_xticklabels`
    linestyle : str, optional
        either "straight" or "bezier". Default is "bezier".
    title : str, optional
        title for the Figure
    figsize : tuple, optional
        figure size
    ticklabel_kws : dict, optional
        keyword arguments for ticklabels on y-axis
    show : bool, optional
        whether to show the plot or not

    Returns
    -------
    matplotlib Axes
        the host (left-most) axes of the plot

    Raises
    ------
    ImportError
        if pandas is not installed
    ValueError
        if the length of ``names`` does not match the number of columns in
        ``data``, or if no rows are left after removing NaNs

    Examples
    --------
    >>> import random
    >>> import numpy as np
    >>> import pandas as pd
    >>> from easy_mpl import parallel_coordinates
    ...
    >>> ynames = ['P1', 'P2', 'P3', 'P4', 'P5'] # feature/column names
    >>> N1, N2, N3 = 10, 5, 8
    >>> N = N1 + N2 + N3
    >>> categories_ = ['a', 'b', 'c', 'd', 'e', 'f']
    >>> y1 = np.random.uniform(0, 10, N) + 7
    >>> y2 = np.sin(np.random.uniform(0, np.pi, N))
    >>> y3 = np.random.binomial(300, 1 / 10, N)
    >>> y4 = np.random.binomial(200, 1 / 3, N)
    >>> y5 = np.random.uniform(0, 800, N)
    ... # combine all arrays into a pandas DataFrame
    >>> data_np = np.column_stack((y1, y2, y3, y4, y5))
    >>> data_df = pd.DataFrame(data_np, columns=ynames)
    ... # using a DataFrame to draw parallel coordinates
    >>> parallel_coordinates(data_df, names=ynames)
    ... # using continuous values for categories
    >>> parallel_coordinates(data_df, names=ynames, categories=np.random.randint(0, 5, N))
    ... # using categorical classes
    >>> parallel_coordinates(data_df, names=ynames, categories=random.choices(categories_, k=N))
    ... # using numpy array instead of DataFrame
    >>> parallel_coordinates(data_df.values, names=ynames)
    ... # with customized tick labels
    >>> parallel_coordinates(data_df.values, ticklabel_kws={"fontsize": 8, "color": "red"})
    ... # using straight lines instead of bezier
    >>> parallel_coordinates(data_df, linestyle="straight")
    ... # with categorical class labels
    >>> data_df['P5'] = random.choices(categories_, k=N)
    >>> parallel_coordinates(data_df, names=ynames)
    ... # with categorical class labels and customized ticklabels
    >>> data_df['P5'] = random.choices(categories_, k=N)
    >>> parallel_coordinates(data_df, ticklabel_kws={"fontsize": 8, "color": "red"})

    See :ref:`sphx_glr_auto_examples_parallel_coordinates.py` for more examples

    Note
    ----
    If nans are present in data or categories, all the corresponding
    entries/rows will be removed.
    """
    try:
        import pandas as pd
    except ImportError as err:  # ModuleNotFoundError is a subclass
        raise ImportError("You must install pandas to draw parallel plot") from err

    if cmap is None:
        cmap = "Blues"
    # Resolve the colormap once (instead of once per line); this also works
    # for registered custom colormaps and Colormap instances.
    colormap = plt.get_cmap(cmap)

    # ---- normalise the input into a DataFrame --------------------------
    if isinstance(data, np.ndarray):
        assert data.ndim == 2, f"{data.ndim} dimensional data not allowed. It must be 2d"
        if names is None:
            names = [f"Feat_{i}" for i in range(data.shape[1])]
        data = pd.DataFrame(data, columns=names)

    if names is None or len(names) == 0:
        names = data.columns.tolist()
    names = list(names)
    if len(names) != data.shape[1]:
        raise ValueError(f"""
        provided names have length {len(names)} but data has {data.shape[1]} columns""")

    show_colorbar = categories is not None
    if categories is None:
        # no classes given: colour lines by their row position
        categories = np.linspace(0, 1, len(data))
    categories = np.array(categories)
    assert len(categories) == len(data), (
        f"categories has length {len(categories)} but data has {len(data)} rows")

    # ---- drop rows with a NaN in data OR a missing category -------------
    # pd.isna works for numeric and object arrays alike, so missing values
    # in categorical class labels are removed as well.
    keep = ~data.isna().any(axis=1).to_numpy() & ~pd.isna(categories)
    if not keep.all():
        data = data[keep]
        categories = categories[keep]
    if len(data) == 0:
        raise ValueError("no rows left to plot after removing rows with NaN values")

    # Encode class labels only AFTER filtering so that the colours stay
    # aligned with the remaining rows.
    _is_categorical = not np.issubdtype(categories.dtype, np.number)
    cat_encoded = label_encoder(categories) if _is_categorical else categories

    num_cols = data.shape[1]
    num_lines = len(data)

    # ---- encode categorical data columns as integer codes ---------------
    encoded_columns = []
    cat_labels = {}  # column position -> sorted distinct original labels
    for idx in range(num_cols):
        values = data.iloc[:, idx].values
        if is_categorical(values):
            encoded_columns.append(label_encoder(values))
            # label_encoder assigns codes 0..k-1 in np.unique order
            cat_labels[idx] = np.unique(values)
        else:
            encoded_columns.append(values)
    ys = np.column_stack(encoded_columns).astype(float)

    # ---- per-axis limits ------------------------------------------------
    ymins = ys.min(axis=0)
    ymaxs = ys.max(axis=0)
    dys = ymaxs - ymins
    # a constant column has zero range, which would divide by zero below and
    # give identical y-limits; give such columns a unit range instead
    dys[dys == 0] = 1.0
    ymins -= dys * 0.05  # add 5% padding below and above
    ymaxs += dys * 0.05
    dys = ymaxs - ymins

    # transform all data to be compatible with the main (host) axis
    zs = np.empty_like(ys)
    zs[:, 0] = ys[:, 0]
    zs[:, 1:] = (ys[:, 1:] - ymins[1:]) / dys[1:] * dys[0] + ymins[0]

    # ---- one y-axis per column ------------------------------------------
    fig, host = plt.subplots(figsize=figsize)
    axes = [host] + [host.twinx() for _ in range(num_cols - 1)]
    for i, ax in enumerate(axes):
        ax.set_ylim(ymins[i], ymaxs[i])
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        if ax is not host:
            ax.spines['left'].set_visible(False)
            ax.yaxis.set_ticks_position('right')
            ax.spines["right"].set_position(("axes", i / (num_cols - 1)))

        if i in cat_labels:
            # one tick per code, labelled with the original category
            ax.set_yticks(np.arange(len(cat_labels[i]), dtype="float32"))
            ax.set_yticklabels(cat_labels[i])

        if ticklabel_kws:
            if i in cat_labels:
                tick_labels = [t.get_text() for t in ax.get_yticklabels()]
            else:
                tick_labels = ax.get_yticks().tolist()
            ax.set_yticks(ax.get_yticks().tolist())
            ax.set_yticklabels([label_format(x) for x in tick_labels], **ticklabel_kws)

    _coord_title_kws = {'fontsize': 14, **(coord_title_kws or {})}
    host.set_xlim(0, num_cols - 1)
    host.set_xticks(range(num_cols))
    host.set_xticklabels(names, **_coord_title_kws)
    host.tick_params(axis='x', which='major', pad=7)
    host.spines['right'].set_visible(False)
    if title:
        host.set_title(title, fontsize=18)

    # ---- draw one line per row ------------------------------------------
    # category values rescaled (lower bound 0.2) to pick colours from the map
    cat_norm = _rescale(cat_encoded, 0.2)
    for j in range(num_lines):
        color = colormap(cat_norm[j])
        if linestyle == "straight":
            host.plot(range(num_cols), zs[j, :], c=color)
        else:
            patch = patches.PathPatch(_bezier_path(zs[j, :]),
                                      facecolor='none', lw=1, edgecolor=color)
            host.add_patch(patch)

    # ---- colorbar for the class labels ----------------------------------
    if show_colorbar:
        norm = cm.colors.Normalize(np.min(cat_encoded), np.max(cat_encoded))
        mappable = cm.ScalarMappable(norm, cmap=colormap)
        # twinned axes share their position, so shrinking the last one for
        # the colorbar moves all coordinate axes together
        cbar = fig.colorbar(mappable, orientation="vertical", pad=0.1, ax=axes[-1])
        if _is_categorical:
            class_labels = np.unique(categories)
            # codes produced by label_encoder are exactly 0..k-1
            cbar.set_ticks(np.arange(len(class_labels)))
            cbar.set_ticklabels(class_labels)
        # hide the colorbar frame
        for spine in cbar.ax.spines.values():
            spine.set_visible(False)

    plt.tight_layout()
    if show:
        plt.show()
    return host


def _bezier_path(y_values) -> Path:
    """Build a cubic Bezier path passing through ``(i, y_values[i])`` for each axis i.

    Every axis gets a control vertex at the point itself and one a third of
    the way towards each neighbouring axis (the first and last axis have only
    one neighbour). Because those control vertices share the point's y value,
    the curve leaves and enters every axis horizontally.
    """
    n_axes = len(y_values)
    # x positions: each integer axis position plus two points in between
    x_coords = np.linspace(0, n_axes - 1, 3 * n_axes - 2)
    # y positions: every value three times, first and last only twice
    y_coords = np.repeat(y_values, 3)[1:-1]
    verts = list(zip(x_coords, y_coords))
    codes = [Path.MOVETO] + [Path.CURVE4] * (len(verts) - 1)
    return Path(verts, codes)
def label_format(x):
    """Round a float tick value to 3 decimals; return any other value unchanged."""
    if not isinstance(x, float):
        return x
    return round(x, 3)
def is_categorical(array) -> bool:
    """Return True when ``array`` does not have a numeric dtype.

    Strings, objects and booleans all count as categorical because none of
    them is a sub-dtype of ``np.number``.
    """
    has_numeric_dtype = np.issubdtype(array.dtype, np.number)
    return not has_numeric_dtype
def label_encoder(arr):
    """Replace each value of ``arr`` by its index in ``np.unique(arr)``.

    The resulting integer codes run from 0 to (number of distinct values - 1)
    following the sorted order of the distinct values.
    """
    _, codes = np.unique(arr, return_inverse=True)
    return codes