Source code for easy_mpl._regplot


__all__ = ["regplot"]

import random
from typing import Union, List

import numpy as np
import matplotlib.pyplot as plt
from easy_mpl.utils import process_axes
from easy_mpl.utils import AddMarginalPlots


REGC_COMBS = [
    ['cadetblue', 'slateblue', 'darkslateblue'],
    ['cadetblue', 'mediumblue', 'mediumblue'],
    ['cornflowerblue', 'dodgerblue', 'darkblue'],
    ['cornflowerblue', 'dodgerblue', 'steelblue'],
    ['cornflowerblue', 'mediumblue', 'dodgerblue'],
    ['cornflowerblue', 'steelblue', 'mediumblue'],
    ['darkslateblue', 'aliceblue', 'mediumblue'],
    ['darkslateblue', 'blue', 'royalblue'],
    ['darkslateblue', 'blueviolet', 'royalblue'],
    ['darkslateblue', 'darkblue', 'midnightblue'],
    ['darkslateblue', 'mediumblue', 'darkslateblue'],
    ['darkslateblue', 'midnightblue', 'mediumblue'],
    ['seagreen', 'darkslateblue', 'cadetblue'],
    ['cadetblue', 'darkblue', 'midnightblue'],
    ['cadetblue', 'deepskyblue', 'cadetblue']
]


[docs]def regplot( x: Union[np.ndarray, list], y: Union[np.ndarray, list], label:str = None, marker_size: Union[int, float] = 20, marker_color=None, scatter_kws: dict = None, line_style='-', line_color=None, line_kws: dict = None, ci: Union[int, None] = 95, fill_color=None, figsize: tuple = None, marginals: bool = False, marginal_ax_size: Union[float, List[float]] = 0.7, marginal_ax_pad: Union[float, List[float]] = 0.25, ridge_line_kws: Union[dict, List[dict]] = None, fill_kws: Union[dict, List[dict]] = None, hist: bool = True, hist_kws: Union[dict, List[dict]] = None, ax: plt.Axes = None, ax_kws:dict = None, show: bool = True, ) -> plt.Axes: """ Regpression plot with regression line and confidence interval Parameters ---------- x : array like, optional the 'x' value. It can be numpy array, pandas DataFram/Series or a list y : array like, optional It can be numpy array, pandas DataFram/Series or a list ci : int, optional confidence interval. Set to None if not required. label : str, optional Label to use as legend. The value to annotate with. marker_size : int, optional size of marker marker_color: optional color of marker scatter_kws : dict keyword arguments for :obj:`matplotlib.axes.scatter` line_style : str (default='-') line style will be used as ax.plot(x,y,line_style) Set this to None if you don't want to plot line line_color : optional color of line line_kws : dict keyword arguments for axes.plot for line plot fill_color : optional only relevent if ci is not None. figsize : tuple, optional figure size (width, height) marginals : bool (default=False) whether to draw the marginal plots or not. If yes, marginal plots will be drawn using :py:class:`easy_mpl.utils.AddMarginalPlots` marginal_ax_size : size of marginal axes. If given as list, the first argument is taken for horizontal marginal axes and second argument for vertical merginal axees. It is only valid if ``marginals`` is set to True. marginal_ax_pad : pad value for marginal axes. If given as list, the first argument is taken for horizontal marginal axes and second argument for vertical merginal axees. It is only valid if ``marginals`` is set to True. ridge_line_kws : keyword arguments for a:obj:`matplotlib.axes.Axes.plot to draw the kde. It can be dictionary or list of two dictionaries, where first dictionary is for hoziontal marginal axes and second dictionary is for vertical marginal axes fill_kws : dict/List[dict] (default=None) keyword arguments for :obj:`matplotlib.axes.Axes.fill_between`. It can be dictionary or list of two dictionaries, where first dictionary is for hoziontal marginal axes and second dictionary is for vertical marginal axes hist : bool whether to draw the histogram or not on marginal axes hist_kws : keyword arguments for :obj:`matplotlib.axes.Axes.hist`. It can be dictionary or list of two dictionaries, where first dictionary is for hoziontal marginal axes and second dictionary is for vertical marginal axes ax : plt.Axes, optional matplotlib axes :obj:`matplotlib.axes` to draw plot on. If not given, current avaialable will be used. ax_kws : dict (default=None) keyword arguments for :py:func:`easy_mpl.utils.process_axes` show : bool, optional whether to show the plot or not Returns -------- matplotlib.pyplot.Axes the matplotlib Axes :obj:`matplotlib.axes` on which regression plot is drawn. Examples -------- >>> import numpy as np >>> from easy_mpl import regplot >>> x_, y_ = np.random.random(100), np.random.random(100) >>> regplot(x_, y_) See :ref:`sphx_glr_auto_examples_reg_plot.py` for more examples Note ---- If nans are present in x or y, they will be removed. """ x = np.array(x).reshape(-1,) y = np.array(y).reshape(-1,) # removing nans based upon nans in x x_nan_idx = np.isnan(x) if x_nan_idx.sum() > 0: x = x[~x_nan_idx] y = y[~x_nan_idx] # removing nans based upon nans in y y_nan_idx = np.isnan(y) if y_nan_idx.sum() > 0: x = x[~y_nan_idx] y = y[~y_nan_idx] assert len(x) > 1, f""" length of x is smaller than 1 {x} {len(y_nan_idx)} nans found in y and {len(x_nan_idx)} nans found in x""" assert len(x) == len(y), f""" x and y must be same length. Got {len(x)} and {len(y)}""" mc, lc, fc = random.choice(REGC_COMBS) _metric_names = {'r2': '$R^2$'} if ax is None: _, ax = plt.subplots(figsize=figsize or (6, 5)) _scatter_kws = dict() if label: _scatter_kws['label'] = label if scatter_kws is not None: _scatter_kws.update(scatter_kws) if marker_color is None: marker_color = mc ax.scatter(x, y, c=marker_color, s=marker_size, **_scatter_kws) if label: ax.legend() if line_kws is None: line_kws = dict() if line_color is None: line_color = lc if fill_color is None: fill_color = fc _regplot(x, y, ax=ax, ci=ci, line_style=line_style, line_color=line_color, fill_color=fill_color, **line_kws) if marginals: AddMarginalPlots(ax, pad=marginal_ax_pad, size=marginal_ax_size, hist=hist, hist_kws=hist_kws, ridge_line_kws=ridge_line_kws, fill_kws=fill_kws)(x, y) _ax_kws = {'xlabel': 'Observed', 'xlabel_kws': {'fontsize': 14}, 'ylabel': 'Prediction', 'ylabel_kws': {'fontsize': 14}, } if ax_kws is not None: _ax_kws.update(ax_kws) process_axes(ax, **_ax_kws) if show: plt.show() return ax
def bootdist(f, args, n_boot=1000, **func_kwargs): n = len(args[0]) integers = np.random.randint boot_dist = [] for i in range(int(n_boot)): resampler = integers(0, n, n, dtype=np.intp) # intp is indexing dtype sample = [a.take(resampler, axis=0) for a in args] boot_dist.append(f(*sample, **func_kwargs)) return np.array(boot_dist) def _regplot_paras(x, y, ci: int = None): """prepares parameters for regplot""" grid = np.linspace(np.min(x), np.max(x), 100) x = np.c_[np.ones(len(x)), x] grid = np.c_[np.ones(len(grid)), grid] yhat = grid.dot(reg_func(x, y)) err_bands = None if ci: boots = bootdist(reg_func, args=[x, y], n_boot=1000).T yhat_boots = grid.dot(boots).T err_bands = _ci(yhat_boots, ci, axis=0) return grid, yhat, err_bands def _ci(a, which=95, axis=None): """Return a percentile range from an array of values.""" p = 50 - which / 2, 50 + which / 2 return np.nanpercentile(a, p, axis) def reg_func(_x, _y): return np.linalg.pinv(_x).dot(_y) def _regplot(x, y, ax, ci=None, line_style="-", line_color=None, fill_color=None, **kwargs): grid, yhat, err_bands = _regplot_paras(x, y, ci) if line_style is not None: ax.plot(grid[:, 1], yhat, line_style, color=line_color, **kwargs) if ci: ax.fill_between(grid[:, 1], *err_bands, facecolor=fill_color, alpha=.15, where=np.array([True for _ in range(len(grid))]) ) return ax