# Source code for easy_mpl._imshow

__all__ = ["imshow"]

from typing import Union

import numpy as np
import matplotlib.pyplot as plt

from .utils import despine_axes
from .utils import add_cbar
from .utils import process_axes
from .utils import annotate_imshow

def imshow(
        values,
        yticklabels=None,
        xticklabels=None,
        annotate: bool = False,
        annotate_kws: dict = None,
        colorbar: bool = False,
        grid_params: dict = None,
        mask: Union[bool, str, np.ndarray] = None,
        cbar_params: dict = None,
        ax: plt.Axes = None,
        ax_kws: dict = None,
        show: bool = True,
        **kwargs
):
    """
    One stop shop for matplotlib's imshow function

    Parameters
    ----------
    values : 2d array
        the image/data to show. It must be 2 dimensional. It can also be a
        pandas DataFrame, in which case every column is converted to numbers
        with :func:`pandas.to_numeric`.
    yticklabels : list, optional
        tick labels for y-axis. For DataFrames, index is used by default
    xticklabels : list, optional
        tick labels for x-axis. For DataFrames, column names are used by
        default. When more than 5 labels are given they are rotated by 70
        degrees.
    annotate : bool, optional
        whether to annotate the heatmap or not
    annotate_kws : dict, optional
        a dictionary with following possible keys

            - ha : horizontal alignment (default="center")
            - va : vertical alignment (default="center")
            - fmt : format (default='%.2f')
            - textcolors : colors for axes.text (default=("black", "white"))
            - threshold : threshold to be used for annotation (default=None)
            - **kws : any other keyword argument for axes.text
    colorbar : bool, optional
        whether to draw colorbar or not
    grid_params : dict, optional (default=None)
        parameters to process grid. Allowed keys in the dictionary are following

            - ``border``, bool
            - ``linestyle``
            - ``linewidth``
            - ``color``
    mask : bool, str or np.ndarray, optional
        This argument can be used to hide part of heatmap from being displayed.

            - True or ``upper`` : will only show the lower half (diagonal kept)
            - ``lower`` : will only show the upper half (diagonal kept)
            - False or None : nothing is hidden
            - boolean array of the same shape as ``values`` : cells where it
              is True are hidden
    cbar_params : dict, optional
        keyword arguments passed to :py:func:`easy_mpl.utils.add_cbar` for the
        colorbar. For example ``pad`` or ``orientation``
    ax : plt.Axes, optional
        if not given, current available axes will be used
    ax_kws : dict, optional (default=None)
        any keyword arguments for :py:func:`easy_mpl.utils.process_axes`
        function as dictionary. A ``figsize`` key, if present, is used to
        resize the figure instead. The caller's dictionary is not modified.
    show : bool, optional
        whether to show the plot (by calling ``plt.show()``) or not
    **kwargs : optional
        any further keyword arguments for :obj:`matplotlib.axes.Axes.imshow`.
        A ``ticks`` keyword is accepted but currently discarded.

    Returns
    -------
    matplotlib.image.AxesImage
        a :obj:`matplotlib.image.AxesImage`

    Raises
    ------
    ValueError
        if ``mask`` is a string other than ``"upper"`` or ``"lower"``
    TypeError
        if ``annotate_kws`` is not a dictionary

    Examples
    --------
    >>> import numpy as np
    >>> from easy_mpl import imshow
    >>> x = np.random.random((10, 5))
    >>> imshow(x, annotate=True)
    ... # show colorbar
    >>> imshow(x, colorbar=True)
    ... # setting white grid lines and annotation
    >>> data = np.random.random((4, 10))
    >>> imshow(data, cmap="YlGn",
    ...        xticklabels=[f"Feature {i}" for i in range(data.shape[1])],
    ...        grid_params={'border': True, 'color': 'w', 'linewidth': 2},
    ...        annotate=True, colorbar=True)

    See the easy_mpl example gallery for more examples.
    """
    # Work on a copy so popping ``figsize`` never mutates the caller's dict.
    ax_kws = {} if ax_kws is None else dict(ax_kws)

    if ax is None:
        ax = plt.gca()

    if 'figsize' in ax_kws:
        ax.figure.set_size_inches(ax_kws.pop('figsize'))

    if hasattr(values, "values") and hasattr(values, "columns"):
        import pandas as pd  # don't make whole project dependent upon pandas

        # explicit None/empty checks: ``not labels`` raises for arrays/Index
        if xticklabels is None or len(xticklabels) == 0:
            xticklabels = values.columns.to_list()
        if yticklabels is None or len(yticklabels) == 0:
            yticklabels = values.index.tolist()

        # when data in dataframe is object type, it causes error in plotting
        # the best way to convert series in df to number is to use to_numeric
        values = np.column_stack(
            [pd.to_numeric(values.iloc[:, i]) for i in range(values.shape[1])]
        )

    values, to_keep = _apply_mask(values, mask)

    # ``ticks`` is not an AxesImage property, so it must not reach ax.imshow.
    # NOTE(review): it used to be collected into an unused dict; confirm
    # whether it was meant to be forwarded to the colorbar.
    kwargs.pop('ticks', None)

    im = ax.imshow(values, **kwargs)

    if to_keep:
        despine_axes(ax, keep=to_keep)

    if annotate_kws is None:
        annotate_kws = {}
    if not isinstance(annotate_kws, dict):
        raise TypeError(
            f"annotate_kws must be a dict, got {type(annotate_kws).__name__}")

    if annotate:
        _annotate_kws = {
            'ha': "center",
            "va": "center",
            "fmt": '%.2f',
            "textcolors": ("black", "white"),
            "threshold": None,
            **annotate_kws,  # user values override the defaults above
        }
        annotate_imshow(im, values, **_annotate_kws)

    if yticklabels is not None:
        ax.set_yticks(np.arange(len(yticklabels)))
        ax.set_yticklabels(yticklabels)

    if xticklabels is not None:
        ax.set_xticks(np.arange(len(xticklabels)))
        if len(xticklabels) > 5:
            # rotate long label sets so neighbouring labels do not overlap
            ax.set_xticklabels(xticklabels, rotation=70)
        else:
            ax.set_xticklabels(xticklabels)

    if ax_kws:
        process_axes(ax, **ax_kws)

    if grid_params:
        # process_grid reads ``.shape``, so make sure it gets an array
        process_grid(ax, np.asanyarray(values), **grid_params)

    if colorbar:
        add_cbar(ax, im, **(cbar_params or {}))

    if show:
        plt.show()

    return im


def _apply_mask(values, mask):
    """Hide part of ``values`` according to the ``mask`` argument of imshow.

    Returns a tuple ``(values, to_keep)``: ``values`` possibly wrapped in a
    numpy masked array, and ``to_keep`` the spines to keep visible (None when
    no spines should be removed).
    """
    # np.bool_ is not a subclass of bool; map both to the documented meaning.
    if isinstance(mask, (bool, np.bool_)):
        mask = "upper" if mask else None

    if mask is None:
        return values, None

    if isinstance(mask, str):
        data = np.asanyarray(values)
        n_rows, n_cols = data.shape[0], data.shape[1]
        if mask == "lower":
            # hide cells strictly below the diagonal -> upper half shown
            hidden = np.tri(n_rows, n_cols, k=-1, dtype=bool)
            return np.ma.array(data, mask=hidden), ['right', 'top']
        if mask == "upper":
            # hide cells strictly above the diagonal -> lower half shown.
            # The data is not transposed, so every cell keeps its position.
            hidden = ~np.tri(n_rows, n_cols, k=0, dtype=bool)
            return np.ma.array(data, mask=hidden), ['left', 'bottom']
        raise ValueError(
            f"mask must be 'upper', 'lower', a bool or an array, got {mask!r}")

    # explicit array: True marks cells to hide (numpy masked-array convention)
    return np.ma.array(values, mask=np.asarray(mask, dtype=bool)), None
def process_grid(
        ax: plt.Axes,
        data: np.ndarray,
        border: bool = False,
        color: str = "w",
        linewidth: Union[int, float] = 3,
        linestyle: str = '-'
):
    """Draw grid lines along the cell boundaries of an image on ``ax``.

    Parameters
    ----------
    ax : plt.Axes
        axes on which the image was drawn
    data : np.ndarray
        the plotted array; only ``data.shape[0]`` (rows) and ``data.shape[1]``
        (columns) are used, to put one minor tick on every cell edge
    border : bool, optional
        when False (default) every spine of ``ax`` is hidden
    color, linewidth, linestyle : optional
        style of the minor grid lines, forwarded to ``ax.grid``
    """
    if not border:
        # ``ax.spines`` is a mapping (a dict, or matplotlib's Spines
        # container), so hiding each of its values hides every spine.
        for spine in ax.spines.values():
            spine.set_visible(False)

    n_rows, n_cols = data.shape[0], data.shape[1]

    # minor ticks at the half-integer positions, i.e. between cells
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)

    ax.grid(which="minor", color=color, linestyle=linestyle, linewidth=linewidth)
    # the grid replaces the minor tick marks themselves
    ax.tick_params(which="minor", bottom=False, left=False)