# Source code for easy_mpl._ridge


__all__ = ["ridge"]

import random
from typing import Union, List

import numpy as np
import matplotlib.pyplot as plt

from .utils import kde
from .utils import is_dataframe
from .utils import make_cols_from_cmap
from .utils import despine_axes

# todo, single array with labels argument

# Names of the matplotlib colormaps from which ``ridge`` picks one at random
# when the caller does not specify a color.
RIDGE_CMAPS = [
    "afmhot",
    "afmhot_r",
    "Blues",
    "bone",
    "BrBG",
    "BuGn",
    "coolwarm",
    "cubehelix",
    "gist_earth",
    "GnBu",
    "Greens",
    "magma",
    "ocean",
    "Pastel1",
    "pink",
    "PuBu",
    "PuBuGn",
    "RdBu",
    "Spectral",
]


def _ridge_arrays(data, labels):
    """Split ``data`` into a list of arrays and derive default labels.

    Returns ``(arrays, labels)``. ``labels`` is returned unchanged when the
    caller supplied it; otherwise it is taken from the Series name or the
    DataFrame columns, or left as ``None`` for plain arrays/lists.
    """
    if isinstance(data, np.ndarray):
        # an array with a single value per row is treated as one ridge
        if data.ndim > 0 and len(data) == data.size:
            data = data.reshape(-1, 1)
        if data.ndim != 2:
            raise ValueError(
                f"data must be a 1D or 2D array, got {data.ndim} dimensions"
            )
        return [data[:, i] for i in range(data.shape[1])], labels

    # Series-like object. The ``is_dataframe`` guard keeps a DataFrame that
    # merely has a column called "name" out of this branch.
    if (hasattr(data, 'name') and hasattr(data, 'values')
            and not is_dataframe(data)):
        if labels is None:
            labels = [data.name]
        return [data.values], labels

    if is_dataframe(data):
        if labels is None:
            labels = data.columns.tolist()
        values = data.values
        return [values[:, i] for i in range(data.shape[1])], labels

    if isinstance(data, tuple):
        data = list(data)
    if not isinstance(data, list):
        raise TypeError(
            "data must be a numpy array, a list of arrays or a pandas "
            f"DataFrame/Series, got {type(data).__name__}"
        )
    return data, labels


def _ridge_colors(color, n):
    """Resolve the ``color`` argument of ``ridge`` into per-ridge colors."""
    if color is None:
        # +2 because we want to ignore first and last in most cases
        # NOTE(review): the caller indexes entries 0..n-1 of the result, so
        # whether the first color is really skipped depends on
        # make_cols_from_cmap -- confirm.
        return make_cols_from_cmap(random.choice(RIDGE_CMAPS), n + 2)

    if isinstance(color, str):
        if color in plt.colormaps():
            return make_cols_from_cmap(color, n + 2)
        return [color for _ in range(n)]

    if isinstance(color, np.ndarray):
        # a flat array is one RGB/RGBA color shared by every ridge
        if color.ndim == 1 or (len(color) == 3 and len(color) == color.size):
            return [color for _ in range(n)]
        # otherwise every column of the array is the color of one ridge
        return [color[:, i] for i in range(color.shape[1])]

    # lists (and anything else indexable) are used as given
    return color


def ridge(
        data: Union[np.ndarray, List[np.ndarray]],
        bw_method="scott",
        cut: float = 0.5,
        color: Union[str, List[str], np.ndarray, List[np.ndarray]] = None,
        fill_kws: dict = None,
        line_width: Union[int, List[int], float, List[float]] = 1.0,
        line_color: Union[str, List[str]] = "black",
        plot_kws: dict = None,
        xlabel: str = None,
        title: str = None,
        figsize: tuple = None,
        hspace: float = -0.7,
        labels: Union[str, List[str]] = None,
        share_axes: bool = False,
        ax: plt.Axes = None,
        show=True,
) -> List[plt.Axes]:
    """
    plots distribution of features/columns/arrays in data as ridge.

    A kernel density estimate is computed for every array and the resulting
    curves are drawn from the one with the highest peak to the one with the
    lowest peak, either on stacked axes or on a single shared axes.

    Parameters
    ----------
        data : array, DataFrame
            array or list of arrays or pandas DataFrame/Series. Every column
            of a 2D array/DataFrame becomes one ridge.
        bw_method : str, optional
            forwarded unchanged to the kernel density estimation helper.
        cut : float (default=0.5)
            forwarded unchanged to the kernel density estimation helper.
        color : str, optional
            color to fill the ridges. It can be any valid matplotlib color or
            color name or cmap name or a list of colors for each ridge. A 1D
            array is used as a single color for all ridges whereas the
            columns of a 2D array are used as one color per ridge. If not
            given, a colormap is chosen randomly from ``RIDGE_CMAPS``.
        fill_kws : dict, (default=None)
            keyword arguments that will go to :obj:`matplotlib.axes.Axes.fill_between`.
            The dictionary is not modified.
        line_width : int (default=1.0)
            width of line of ridges. A list gives one width per ridge.
        line_color : str (default="black")
            color or colors of lines of ridges.
        plot_kws : dict optional
            any keyword arguments that will go to :obj:`matplotlib.axes.Axes.plot`
            during plot of kde line. The dictionary is not modified.
        xlabel : str, optional
            xlabel for the figure
        title : str, optional
            title of the figure
        figsize : tuple, optional
            size of figure, (10, 6) is used when not given
        hspace : float, optional (default=-0.7)
            amount of distance between plots
        labels : list/str (default=None)
            names of the ridges, one per array. They are written next to
            each ridge on stacked axes and used as legend entries. Defaults
            to the Series name, the DataFrame columns or ``F0``, ``F1``, ...
        share_axes : bool, optional (default=False)
            whether to draw all ridges on same axes or separate axes
        ax : plt.Axes, optional (default=None)
            matplotlib axes object :obj:`matplotlib.axes` on which to
            draw the ridges. If given all ridges will be drawn on this axes.
        show : bool, optional
            whether to show the plot or not. The legend of shared axes is
            only drawn when this is True.

    Returns
    -------
    list
        a list of :obj:`matplotlib.axes`

    Raises
    ------
    TypeError
        if ``data`` is of an unsupported type.
    ValueError
        if ``data`` is empty or not 1D/2D, or if the number of ``labels``
        differs from the number of arrays.

    Examples
    --------
        >>> import numpy as np
        >>> from easy_mpl import ridge
        >>> data_ = np.random.random((100, 3))
        >>> ridge(data_)
        ... # specifying colormap
        >>> ridge(data_, color="Blues")
        ... # using pandas DataFrame
        >>> import pandas as pd
        >>> ridge(pd.DataFrame(data_))

    See :ref:`sphx_glr_auto_examples_ridge.py` for more examples
    """
    data, labels = _ridge_arrays(data, labels)

    n = len(data)
    if n == 0:
        raise ValueError("data must contain at least one array")

    if labels is None:
        labels = [f"F{i}" for i in range(n)]
    elif isinstance(labels, str):
        # a bare string names a single ridge, it is not a sequence of names
        labels = [labels]
    else:
        labels = list(labels)

    if len(labels) != n:
        raise ValueError(
            f"number of labels ({len(labels)}) does not match "
            f"number of arrays ({n})"
        )

    colors = _ridge_colors(color, n)

    if plot_kws is None:
        plot_kws = dict()
    if fill_kws is None:
        fill_kws = dict()

    # Results are kept per position, not per label, so that duplicate
    # labels do not overwrite each other.
    curves = [kde(values, bw_method=bw_method, cut=cut) for values in data]
    peaks = [np.max(density) for _, density in curves]

    # Highest peak first. A stable ascending sort followed by a reversal
    # keeps the ordering of equal peaks identical to reversing a sorted dict.
    order = sorted(range(n), key=peaks.__getitem__)
    order.reverse()
    y_top = max(peaks)

    if ax is None:
        nrows = 1 if share_axes else n
        _, axes = plt.subplots(nrows, ncols=1,
                               figsize=figsize or (10, 6),
                               gridspec_kw={"hspace": hspace})
    else:
        # a user supplied axes always hosts all the ridges
        share_axes = True
        nrows = 1
        axes = ax

    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    ax_objs = axes.tolist()

    if isinstance(line_color, str):
        line_color = [line_color for _ in range(n)]

    if isinstance(line_width, (int, float)):
        line_width = [line_width for _ in range(n)]

    for position, data_idx in enumerate(order):
        label = labels[data_idx]
        x, y = curves[data_idx]
        axis = ax_objs[0] if share_axes else ax_objs[position]

        fill_color = colors[position]
        # a white fill is invisible in the legend, so label the line instead
        is_white = isinstance(fill_color, str) and fill_color == "white"

        # work on copies so the caller's dictionaries are never modified and
        # a label never leaks from one ridge to the next
        line_kws = dict(plot_kws)
        area_kws = {"alpha": 1.0, 'where': np.ones(len(x), dtype=bool)}
        area_kws.update(fill_kws)
        if is_white:
            line_kws['label'] = label
        else:
            area_kws['label'] = label

        axis.plot(x, y, color=line_color[position],
                  linewidth=line_width[position], **line_kws)
        axis.fill_between(x, y, color=fill_color, **area_kws)

        # setting uniform y lims
        axis.set_ylim(0, y_top)

        # make background transparent
        axis.patch.set_alpha(0)

        if position == n - 1:
            if xlabel:
                axis.set_xlabel(xlabel)
            axis.tick_params(axis="x")
        elif nrows > 1:
            axis.set_xticklabels([])
            axis.set_xticks([])

        # remove borders, axis ticks, and labels
        if nrows > 1:
            axis.set_yticklabels([])
            axis.set_yticks([])
        else:
            axis.tick_params(axis="y")

        despine_axes(axis)

        if not share_axes and nrows > 1:
            axis.text(x[0], 0.2, label, ha="right")

    # act on the figure/axes that were drawn on, which is not necessarily
    # pyplot's current figure when the caller supplied ``ax``
    figure = ax_objs[0].figure

    if not share_axes:
        figure.subplots_adjust(hspace=hspace)

    if title:
        figure.suptitle(title)

    if show:
        if share_axes:
            ax_objs[0].legend()
        plt.show()

    return ax_objs