Source code for easy_mpl._main


__all__ = ["plot"]

import numpy as np
import matplotlib.pyplot as plt

from .utils import process_axes, is_dataframe, is_series, create_subplots

# todo add share_axes argument

def plot(
        *args,
        share_axes: bool = True,
        show: bool = True,
        ax: plt.Axes = None,
        ax_kws: dict = None,
        **kwargs
) -> plt.Axes:
    """
    One liner plot function. Its use is not more complex than
    :obj:`matplotlib.axes.Axes.plot` or :obj:`matplotlib.pyplot.plot` .
    However it accomplishes all in one line what requires multiple
    lines in matplotlib. args and kwargs can be anything which goes into
    :obj:`matplotlib.pyplot.plot` or :obj:`matplotlib.axes.Axes.plot`.

    Parameters
    ----------
    *args :
        either a single array or x and y arrays or anything which can go to
        :obj:`matplotlib.axes.Axes.plot` or :obj:`matplotlib.pyplot.plot` .
        A string given after the array(s) is used as the format/marker string.
    share_axes : bool (default=True)
        whether to draw all the plots on same axes or not. Only relevant for
        2 dimensional data.
    show : bool, optional (default=True)
        If set to True, plt.show() is called.
    ax : :obj:`matplotlib.axes`
        matplotlib axes object on which plot is to be drawn. It is forwarded
        to ``create_subplots`` together with the number of required plots.
    ax_kws : dict
        keyword arguments for :func:`easy_mpl.utils.process_axes`. A
        ``figsize`` entry, if present, is used for creating the figure.
        The dictionary given by the caller is not modified.
    **kwargs : optional
        Any keyword argument for :obj:`matplotlib.pyplot.plot` or
        :obj:`matplotlib.axes.Axes.plot`

    Returns
    -------
    :obj:`matplotlib.axes`
        :obj:`matplotlib.axes` on which the plot is drawn. If ``show`` is
        False, this axes can be used for further processing. When multi-column
        data is drawn with ``share_axes=False``, a list with one axes per
        column is returned instead.

    Example
    --------
        >>> from easy_mpl import plot
        >>> import numpy as np
        >>> plot(np.random.random(100))
        ... # use x and y
        >>> plot(np.arange(100), np.random.random(100))
        ... # use x and y with a marker style
        >>> plot(np.arange(100), np.random.random(100), '.')
        ... # string after arrays represent marker style
        >>> plot(np.random.random(100), '.')
        ... # use custom marker
        >>> plot(np.random.random(100), '--*')
        ... # using label keyword
        >>> plot(np.random.random(100), '--*', label='label')
        ... # log transform y-axis
        >>> plot(np.random.random(100), '--*', ax_kws={'logy':True}, label='label')

    See :ref:`sphx_glr_auto_examples_plot.py` for more examples
    """
    # Work on a private copy: 'figsize' is popped and default labels/ticks are
    # filled in below, none of which should leak into the caller's dictionary.
    ax_kws = {} if ax_kws is None else dict(ax_kws)

    plot_args = []
    marker = None

    if len(args) == 1:
        data = [args[0]]
    elif len(args) == 2 and isinstance(args[1], str):
        data, marker = [args[0]], args[1]
    elif len(args) == 3:
        *data, marker = args
        # only a string can be a format/marker spec; anything else is more data
        if not isinstance(marker, str):
            data.append(marker)
            marker = None
    else:
        # always a list (never the args tuple) so that it can be edited later
        data = list(args)

    if marker:
        plot_args.append(marker)
        assert 'marker' not in kwargs  # we have already got marker

    s = data[0]

    # one subplot per column, but only for genuinely 2-dimensional input;
    # checking the rank first avoids IndexError for 1-dimensional arrays/Series
    nplots = 1
    if (not share_axes and hasattr(s, 'shape')
            and len(s.shape) > 1 and s.shape[1] > 1):
        nplots = s.shape[1]

    figsize = ax_kws.pop('figsize', None)

    f, ax = create_subplots(nplots, ax=ax, figsize=figsize, ncols=1, sharex="all")

    # normalise to a flat list of axes so that ax[idx] works in every branch
    if isinstance(ax, np.ndarray):
        ax = ax.flatten().tolist()
    elif isinstance(ax, plt.Axes):
        ax = [ax]

    if is_series(s):
        ax_kws['min_xticks'] = ax_kws.get('min_xticks', 3)
        ax_kws['max_xticks'] = ax_kws.get('max_xticks', 5)
        ax_kws['xlabel'] = ax_kws.get('xlabel', s.index.name)
        ax_kws['ylabel'] = ax_kws.get('ylabel', s.name)

    elif is_dataframe(s):
        ax_kws['min_xticks'] = ax_kws.get('min_xticks', 3)
        ax_kws['max_xticks'] = ax_kws.get('max_xticks', 5)

        if s.shape[1] == 1:
            # single column: falls through to the generic plotting at the end
            ax_kws['xlabel'] = ax_kws.get('xlabel', s.index.name)
            ax_kws['ylabel'] = ax_kws.get('ylabel', s.columns.tolist()[0])
        else:
            if share_axes:
                ax = _plot_df_cols(ax[0], data, s, ax_kws, plot_args, kwargs)
            else:
                for idx, col in enumerate(s.columns):
                    # the column name always serves as label of its subplot
                    kwargs['label'] = col
                    ax_kws['label'] = col
                    ax[idx].plot(s[col], *plot_args, **kwargs)
                    _process_axis(ax[idx], False, ax_kws)

            if show:
                plt.show()
            return ax

    elif (hasattr(s, 'shape') and len(s.shape) > 1 and s.shape[1] > 1
          and not share_axes):
        for idx in range(s.shape[1]):
            # forward the format string and keyword arguments as well,
            # otherwise they would be silently ignored for 2d arrays
            ax[idx].plot(s[:, idx], *plot_args, **kwargs)
            _process_axis(ax[idx], False, ax_kws)

        if show:
            plt.show()
        return ax

    if 'label' in kwargs:
        ax_kws['label'] = kwargs['label']

    ax[0].plot(*data, *plot_args, **kwargs)
    return _process_axis(ax[0], show, ax_kws)


def _plot_df_cols(ax, data, s, ax_kws, plot_args, kwargs):
    """
    Draws every column of the DataFrame ``s`` as a separate line on ``ax``.

    Parameters
    ----------
    ax :
        axes on which all the columns are drawn
    data : list or tuple
        positional data arguments; the first entry is replaced by each
        column in turn while the remaining entries are passed on unchanged
    s :
        the DataFrame whose columns are plotted
    ax_kws : dict
        keyword arguments for ``process_axes``. ``xlabel`` defaults to the name
        of the index of ``s`` and ``label`` to the list of column names.
    plot_args : list
        extra positional arguments (e.g. format string) for ``ax.plot``
    kwargs : dict
        keyword arguments for ``ax.plot``; ``label`` is overwritten by the
        column name for every line

    Returns
    -------
    whatever ``process_axes`` returns for ``ax``
    """
    # local copies: neither the caller's dictionaries nor ``data`` are touched,
    # which also makes a tuple acceptable as ``data``
    axes_options = dict(ax_kws)
    axes_options.setdefault('xlabel', s.index.name)
    axes_options.setdefault('label', s.columns.tolist())

    trailing = list(data[1:])

    for col in s.columns:
        line_options = {**kwargs, 'label': col}
        ax.plot(s[col], *trailing, *plot_args, **line_options)

    return process_axes(ax, **axes_options)


def _process_axis(ax, show, kwargs):
    """
    Applies ``process_axes`` on ``ax`` and optionally saves/shows the figure.

    Parameters
    ----------
    ax :
        the axes to process
    show : bool
        whether to call ``plt.show()``
    kwargs : dict
        keyword arguments for ``process_axes``. It is only called when
        this dictionary is not empty. If ``save`` is true in it, then the
        current figure is saved under ``name`` (default ``fig.png``).

    Returns
    -------
    the output of ``process_axes`` if it was called, otherwise ``ax`` itself
    """
    # tolerate None so that the ``.get`` lookups below can not fail
    kwargs = kwargs or {}

    if kwargs:
        ax = process_axes(ax=ax, **kwargs)

    if kwargs.get('save', False):
        plt.savefig(f"{kwargs.get('name', 'fig.png')}")

    if show:
        plt.show()

    return ax