Source code for easy_mpl._box


# Public API of this module: only ``boxplot`` is exported by ``import *``.
__all__ = ["boxplot"]

import warnings
from itertools import zip_longest
from typing import Union, List, Tuple

import matplotlib
import numpy as np
import matplotlib.pyplot as plt

from .utils import is_rgb
from .utils import is_series
from .utils import is_dataframe
from .utils import process_axes
from .utils import create_subplots
from .utils import make_cols_from_cmap


def _version_tuple(version: str) -> tuple:
    """Return the leading numeric components of ``version`` as a tuple of ints.

    ``"3.10.0"`` becomes ``(3, 10, 0)`` and ``"3.3.0rc1"`` becomes ``(3, 3, 0)``.
    Tuples compare numerically, whereas the raw strings compare character by
    character (``"3.10.0" < "3.3.0"``).
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            # a suffix such as "rc1" or "+local" ends the numeric part
            break
    return tuple(parts)


def boxplot(
        data: Union[np.ndarray, List[np.ndarray], List[list]],
        line_color: Union[str, List[str]] = None,
        line_width=None,
        fill_color: Union[str, List[str]] = None,
        labels: Union[str, List[str]] = None,
        share_axes: bool = True,
        figsize: tuple = None,
        ax: plt.Axes = None,
        ax_kws: dict = None,
        show: bool = True,
        **box_kws,
) -> Tuple[Union[plt.Axes, List[plt.Axes]], Union[List[dict], dict]]:
    """
    Draws the box and whisker plot

    parameters
    ----------
    data :
        array like (list, numpy array, pandas dataframe/series) or list of
        array likes. If list of array likes, the length of arrays in the list
        can be different.
    line_color :
        name of color/colors/cmap for the whisker, cap and median lines.
    line_width :
        width (a single number) of the box, whisker, cap and median lines.
    fill_color :
        name of color/colors/cmap to fill the boxes. It can be any valid
        matplotlib color or cmap. The face color is applied when the boxes
        are drawn as patches (``patch_artist=True``).
    labels : str/list (default=None)
        used for ticklabels of x-axes (y-axes when ``vert=False``)
    share_axes : bool (default=True)
        whether to draw all the boxes on one axes or not
    figsize : tuple (default=None)
        figure size as tuple (width, height)
    ax : plt.Axes, optional (default=None)
        matplotlib axes on which to draw the plot. If not given, the current
        axes (``plt.gca()``) is used.
    ax_kws : dict (default=None)
        keyword arguments of :py:func:`easy_mpl.utils.process_axes`
    show : bool (default=True)
        whether to show the plot or not
    **box_kws :
        any additional keyword argument for :obj:`matplotlib.axes.Axes.boxplot`

    Returns
    -------
    tuple
        a tuple of two

        - plt.Axes or list of :obj:`matplotlib.axes`
        - a dictionary or list of dictionaries which consists of boxes,
          medians, whiskers, fliers

    Examples
    ---------
    >>> import numpy as np
    >>> from easy_mpl import boxplot
    >>> boxplot(np.random.random((100, 5)))

    we can also provide arrays of different lengths

    >>> boxplot([np.random.random(100), np.random.random(90)])

    the color can be given as either color name or colormap

    >>> boxplot(np.random.random((100, 3)), fill_color=['pink', 'lightblue', 'lightgreen'])
    >>> boxplot(np.random.random((100, 3)), fill_color="viridis")

    See :ref:`sphx_glr_auto_examples_boxplot.py` for more examples
    """
    if ax is None:
        ax = plt.gca()

    data, labels = _unpack_data(data, labels, share_axes)

    nplots = 1 if share_axes else len(labels)

    # **box_kws is always a dict; work on a copy of it
    _box_kws = dict(box_kws)

    _, axes = create_subplots(nplots, ax=ax, figsize=figsize)

    if isinstance(axes, np.ndarray):
        axes = axes.flatten().tolist()
    elif isinstance(axes, plt.Axes):
        axes = [axes]

    nboxes = len(labels)
    if nboxes == 1:
        # A single entry is either the list of names of all boxes drawn on one
        # axes, or one scalar name (str/int/None) for a single box. Only the
        # former has a meaningful length.
        only = labels[0]
        if isinstance(only, str) or not hasattr(only, "__len__"):
            nboxes = 1
        else:
            nboxes = len(only)

    fill_colors = _unpack_colors(fill_color, nboxes, share_axes)
    line_colors = _unpack_colors(line_color, nboxes, share_axes)
    line_widths = _unpack_linewidth(line_width, nboxes, share_axes)

    # Author's note: in version 3.2 giving a DataFrame to ax.boxplot makes
    # boxes for each row, so on old matplotlib the frame is replaced by its
    # values. The versions are compared numerically, since comparing the
    # strings would wrongly treat e.g. "3.10.0" as older than "3.3.0".
    old_matplotlib = _version_tuple(matplotlib.__version__) <= (3, 3, 0)

    box_outs = []
    for idx, (name, x, axis) in enumerate(zip(labels, data, axes)):

        if old_matplotlib and is_dataframe(x):
            x = x.values

        box_out = axis.boxplot(x, **_box_kws)
        box_outs.append(box_out)

        _set_box_props(fill_colors[idx], line_colors[idx],
                       line_widths[idx], box_out)

        _set_ticklabels(axis, share_axes, name, _box_kws)

        if ax_kws:
            process_axes(axis, **ax_kws)

    if show:
        plt.show()

    # unwrap single results so that the common one-axes case stays simple
    if len(box_outs) == 1:
        box_outs = box_outs[0]

    if len(axes) == 1:
        axes = axes[0]

    return axes, box_outs
def _set_ticklabels(ax, share_axes, name, box_kws):
    """Put ``name`` on the category axis of ``ax``.

    ``name`` is the list of box names when all boxes share one axes, otherwise
    a single name (str/int) or a list of names for this axes. Nothing is done
    when ``name`` is None. The x-axis is labelled unless ``box_kws`` contains
    ``vert=False``, in which case the y-axis is labelled.
    """
    if name is None:
        return

    vertical = box_kws.get('vert', True)

    if share_axes:
        _maybe_set_ticks(ax, name, "x" if vertical else "y")
        # many labels on one axes: rotate the x tick labels
        if len(name) > 7:
            ax.xaxis.set_tick_params(rotation=90)
        return

    if isinstance(name, (str, int)):
        ticklabels = [name]
    elif isinstance(name, list):
        ticklabels = name
    else:
        return

    if vertical:
        ax.set_xticklabels(ticklabels)
    else:
        ax.set_yticklabels(ticklabels, rotation=90, va='center')
    return


def _maybe_set_ticks(axes: "plt.Axes", ticklabels, which="x"):
    """Set ``ticklabels`` on the ``which`` axis if their number matches the
    number of existing ticks, otherwise emit a warning and change nothing."""
    ticks = getattr(axes, f"get_{which}ticks")()
    if len(ticklabels) == len(ticks):
        getattr(axes, f"set_{which}ticklabels")(ticklabels)
    else:
        warnings.warn(f""" {which}ticks ({len(ticks)}) and {which}ticklabels ({len(ticklabels)}) dont match""")
    return


def _unpack_linewidth(line_width, nboxes, share_axes):
    """Expand a single line width into the nested per-axes/per-box layout.

    Returns ``nboxes`` one-element lists (one per axes), or a single list of
    ``nboxes`` entries when ``share_axes`` is true.

    Raises
    ------
    ValueError
        if ``line_width`` is neither a number nor None.
    """
    if line_width is not None and not isinstance(line_width, (float, int)):
        raise ValueError(
            f"line_width must be a number or None, got {type(line_width)}")

    line_widths = [[line_width] for _ in range(nboxes)]

    if share_axes:
        line_widths = [[width[0] for width in line_widths]]

    return line_widths


def _unpack_colors(color, nboxes, share_axes) -> list:
    """Expand a color specification into the nested per-axes/per-box layout.

    ``color`` may be the name of a colormap registered in matplotlib (colors
    are then obtained from ``make_cols_from_cmap``), any other string, a value
    accepted by ``is_rgb``, a sized collection holding one color per box, or
    None. The layout of the result is the same as in ``_unpack_linewidth``.
    """
    if isinstance(color, str):
        if color in plt.colormaps():
            colors = [[clr] for clr in make_cols_from_cmap(color, nboxes)]
        else:
            colors = [[color] for _ in range(nboxes)]
    elif is_rgb(color):
        colors = [[color] for _ in range(nboxes)]
    elif hasattr(color, '__len__'):
        assert len(color) == nboxes, f"{len(color)} colors for {nboxes} boxes?"
        colors = [[clr] for clr in color]
    elif color is None:
        colors = [[None] for _ in range(nboxes)]
    else:
        raise ValueError(f"{color} is not recognized as valid color")

    if share_axes:
        colors = [[clr[0] for clr in colors]]

    return colors


def _set_box_props(fill_color: list, line_color: list, line_width: list, box_out):
    """Apply fill color, line color and line width to the artists of one
    ``Axes.boxplot`` call.

    The three lists hold one entry per box; a None entry leaves the respective
    property untouched. ``box_out`` is the dictionary returned by
    ``Axes.boxplot``. Means and fliers are not modified.
    """
    nboxes = len(line_color)

    def _per_box(key):
        # one row of artists per box (e.g. two whiskers and two caps per box)
        artists = box_out[key]
        if len(artists) > 0:
            return np.array(artists).reshape(nboxes, -1)
        return artists

    groups = zip_longest(
        _per_box('boxes'), _per_box('whiskers'),
        _per_box('caps'), _per_box('medians'))

    for idx, (patch, whisker, cap, median) in enumerate(groups):
        # zip_longest pads artist kinds that were not drawn (e.g.
        # showcaps=False) with None; those are skipped below.

        fill = fill_color[idx]
        if fill is not None and patch is not None:
            if isinstance(patch[0], matplotlib.lines.Line2D):
                plt.setp(patch, markerfacecolor=fill)
            elif isinstance(patch[0], matplotlib.patches.PathPatch):
                plt.setp(patch, facecolor=fill)

        color = line_color[idx]
        if color is not None:
            # the box outline deliberately keeps its own color
            for artists in (whisker, cap, median):
                if artists is not None:
                    plt.setp(artists, color=color)

        width = line_width[idx]
        if width is not None:
            for artists in (patch, whisker, cap, median):
                if artists is not None:
                    plt.setp(artists, linewidth=width)

    return


def _unpack_data(x, labels, share_axes: bool) -> Tuple[list, list]:
    """Split the user data into one entry per axes, together with the names.

    Returns
    -------
    tuple
        ``(X, names)``. With ``share_axes`` both lists have a single entry:
        everything that is drawn on the one axes and the list of box names.
        Otherwise there is one entry per axes. User supplied ``labels``
        replace the derived names.

    Raises
    ------
    ValueError
        if the type of ``x`` is not supported.
    """
    if isinstance(x, np.ndarray):
        if len(x) == x.size:
            # one dimensional data (or a single column): just one box
            X = [x]
            names = [[None]]
        elif share_axes:
            X = [x]
            names = [[f"{i}" for i in range(x.shape[1])]]
        else:
            X = [x[:, i] for i in range(x.shape[1])]
            names = [f"{i}" for i in range(x.shape[1])]

    elif isinstance(x, (list, tuple)):
        first = x[0]
        if isinstance(first, (list, tuple, np.ndarray)):
            assert all([len(array) == np.array(array).size for array in x]), f""" All arrays must be one dimensional."""
            X = [np.array(x_).reshape(-1,) for x_ in x]
            names = [None] * len(X)
            if share_axes:
                X = [X]
                names = [names]
        elif not np.isscalar(first) and is_series(first):
            # list of series
            X = [x] if share_axes else x
            names = [x_.name for x_ in x]
        elif np.isscalar(first) or not is_dataframe(first):
            # flat sequence of values: a single unnamed box, in the same
            # nested layout as used for a one dimensional array
            X = [x]
            names = [[None]]
        else:
            raise ValueError(f"unrecognized type of x {type(x)}")

    elif is_dataframe(x):
        if share_axes:
            names = [x.columns.tolist()]
            X = [x]
        else:
            X = [x[col].values for col in x.columns]
            names = x.columns.tolist()

    elif is_series(x):
        X = [x.values]
        names = [[x.name]]

    else:
        raise ValueError(f"unrecognized type of x {type(x)}")

    if labels is not None:
        if isinstance(labels, str):
            labels = [labels]

        if share_axes:
            labels = [labels]

        names = labels

    return X, names