__all__ = ["dumbbell_plot"]
from typing import Tuple
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from .utils import is_rgb
from .utils import make_clrs_from_cmap
from .utils import to_1d_array, process_axes
from ._scatter import scatter
def dumbbell_plot(
        start,
        end,
        labels=None,
        line_color=None,
        start_marker_color=None,
        end_marker_color=None,
        start_kws: dict = None,
        end_kws: dict = None,
        line_kws: dict = None,
        sort_start: str = None,
        sort_end: str = None,
        ax: plt.Axes = None,
        ax_kws: dict = None,
        show: bool = True
) -> Tuple[plt.Axes, mpl.collections.PathCollection, mpl.collections.PathCollection]:
    """
    Dumbbell plot which indicates the variation of several variables
    from start to end.

    Parameters
    ----------
    start : list, array, series
        an array consisting of starting values
    end : list, array, series
        an array consisting of end values. Must have the same length as
        ``start``.
    labels : list, array, series, optional
        names of values in start/end arrays. It is used to label the
        ticklabels on the y-axis. If not given, integer positions are used.
    line_color : optional
        color for lines. This can be a color name, rgb value, array of rgb
        values for each line or a colormap name. This can be used to have a
        separate color for each line. If not given, the ``color`` entry of
        ``line_kws`` (default ``'skyblue'``) is used for all lines.
    start_marker_color : optional
        color for starting markers. This can be a color name, rgb value, array
        of rgb values for each marker or a colormap name. This can be used to
        have a separate color for each marker. If not given, the ``color``
        entry of ``start_kws`` (default ``'#a3c4dc'``) is used.
    end_marker_color : optional
        color for end markers. This can be a color name, rgb value, array of
        rgb values for each marker or a colormap name. This can be used to
        have a separate color for each marker. If not given, the ``color``
        entry of ``end_kws`` (default ``'#0e668b'``) is used.
    start_kws : dict, optional
        any additional keyword arguments for :py:func:`easy_mpl.utils.scatter`
        to modify start markers such as ``color``, ``label`` etc
    end_kws : dict, optional
        any additional keyword arguments for :py:func:`easy_mpl.utils.scatter`
        to modify end markers such as ``color``, ``label`` etc
    line_kws : dict, optional
        any additional keyword arguments for `lines.Line2D`_ to modify line
        style/color which connects dumbbells.
    sort_start : str (default=None)
        either "ascend" or "descend"; sorts the dumbbells by their start values.
    sort_end : str (default=None)
        either "ascend" or "descend"; sorts the dumbbells by their end values.
        Ignored when ``sort_start`` is given.
    ax : plt.Axes, optional
        matplotlib axes object to work with. If not given then the currently
        available axes will be used.
    ax_kws : dict, optional
        any keyword arguments for :py:func:`easy_mpl.utils.process_axes`.
        A ``figsize`` entry is used to resize the figure. The dictionary
        passed by the caller is not modified.
    show : bool, optional
        whether to show the plot or not

    Returns
    -------
    axes :
        :obj:`matplotlib.axes` matplotlib axes object on which dumbbells are drawn
    st_pc :
        :obj:`matplotlib.collections.PathCollection` of the start markers
    en_pc :
        :obj:`matplotlib.collections.PathCollection` of the end markers

    Examples
    --------
    >>> import numpy as np
    >>> from easy_mpl import dumbbell_plot
    >>> st = np.random.randint(1, 5, 10)
    >>> en = np.random.randint(11, 20, 10)
    >>> dumbbell_plot(st, en)
    ... # modify line color
    >>> dumbbell_plot(st, en, line_kws={'color':"black"})

    See :ref:`sphx_glr_auto_examples_dumbell.py` for more examples

    .. _lines.Line2D:
        https://matplotlib.org/stable/api/_as_gen/matplotlib.lines.Line2D.html
    """
    # Work on a copy so that popping ``figsize`` below does not mutate the
    # dictionary owned by the caller.
    ax_kws = {} if ax_kws is None else dict(ax_kws)

    if ax is None:
        ax = plt.gca()

    if 'figsize' in ax_kws:
        ax.figure.set_size_inches(ax_kws.pop('figsize'))

    # convert starting and ending values to 1d arrays
    start = to_1d_array(start)
    end = to_1d_array(end)
    assert len(start) == len(end), (
        f"start and end must have the same length, got {len(start)} and {len(end)}")

    # one row on the y-axis per dumbbell
    index = np.arange(len(start))

    if labels is None:
        labels = np.arange(len(index))

    # line style; ``_get_color`` removes the 'color' entry from the dict so it
    # is not passed twice to Line2D below.
    _line_kws = {'color': 'skyblue'}
    if line_kws is not None:
        _line_kws.update(line_kws)
    line_colors = _get_color(line_color, _line_kws, len(start))

    # marker styles; again 'color' is consumed by ``_get_color``
    _start_kws = {'color': '#a3c4dc', "label": "Start"}
    if start_kws:
        _start_kws.update(start_kws)
    _end_kws = {'color': '#0e668b', "label": "End"}
    if end_kws:
        _end_kws.update(end_kws)
    st_mc_colors = _get_color(start_marker_color, _start_kws, len(start))
    en_mc_colors = _get_color(end_marker_color, _end_kws, len(start))

    # sorting reorders values, labels and colors together; sort_start wins
    if sort_start:
        start, end, labels, line_colors, st_mc_colors, en_mc_colors = _handle_sort(
            sort_start, start, start, end, labels, line_colors, st_mc_colors, en_mc_colors)
    elif sort_end:
        start, end, labels, line_colors, st_mc_colors, en_mc_colors = _handle_sort(
            sort_end, end, start, end, labels, line_colors, st_mc_colors, en_mc_colors)

    # join each start/end pair with a horizontal line segment at its row
    for row, (x_end, x_start) in enumerate(zip(end, start)):
        ax.add_line(mlines.Line2D([x_end, x_start], [row, row],
                                  color=line_colors[row], **_line_kws))

    # circles are plotted after lines so that lines don't enter inside the circles
    ax, st_paths = scatter(y=index, x=start, show=False, ax=ax,
                           color=st_mc_colors, **_start_kws)
    ax, en_paths = scatter(y=index, x=end, ax=ax, show=False,
                           color=en_mc_colors, **_end_kws)
    ax.legend()

    # label each row
    ax.set_yticks(index)
    ax.set_yticklabels(labels)

    if ax_kws:
        process_axes(ax=ax, **ax_kws)

    if show:
        plt.show()

    return ax, st_paths, en_paths
def _get_color(sugg_clr, kws, n) -> list:
    """
    Turn a user color specification into one color per dumbbell.

    Parameters
    ----------
    sugg_clr :
        ``None`` (fall back to ``kws['color']``), a colormap name, a single
        color (name or rgb value), or a sequence holding exactly ``n`` colors.
    kws : dict
        keyword-argument dictionary with a ``'color'`` entry. That entry is
        deleted in place so the caller can pass the resolved colors through a
        separate ``color=`` argument without a duplicate keyword.
    n : int
        number of colors required.

    Returns
    -------
    the resolved colors (a list, or whatever ``make_clrs_from_cmap`` /
    the user-supplied sequence is).
    """
    if sugg_clr is None:
        resolved = [kws['color']] * n
    elif isinstance(sugg_clr, str) and sugg_clr in plt.colormaps():
        # todo
        # this will result in wrong colorbar if these colors are used
        # in for plot/scatter
        resolved = make_clrs_from_cmap(sugg_clr, n, 0.1, 0.9)
    elif isinstance(sugg_clr, str) or is_rgb(sugg_clr):
        # a single color name such as 'k', or a single rgb value: repeat it
        resolved = [sugg_clr] * n
    else:
        assert hasattr(sugg_clr, '__len__') and len(sugg_clr) == n, f"Invalid color {sugg_clr}"
        resolved = sugg_clr
    del kws['color']
    return resolved
def _reorder(values, order):
    """Return ``values`` rearranged by the integer positions in ``order``.

    Lists and tuples are reordered element by element, so heterogeneous
    color specifications such as ``['red', (0, 0, 1)]`` keep working and keep
    their original element types. Anything else (numpy arrays, pandas
    Series, ...) is converted with ``np.asarray`` and positionally indexed.
    """
    if isinstance(values, (list, tuple)):
        return [values[i] for i in order]
    return np.asarray(values)[order]


def _handle_sort(sort_type, sort_wrt, start, end, labels,
                 line_colors, st_mc_clr, en_mc_clr):
    """
    Reorder the dumbbell data according to ``sort_wrt``.

    Parameters
    ----------
    sort_type : str
        either "ascend" or "descend".
    sort_wrt :
        the values whose order determines the new arrangement
        (``start`` or ``end`` in ``dumbbell_plot``).
    start, end, labels :
        per-dumbbell values, all reordered with the same permutation.
    line_colors, st_mc_clr, en_mc_clr :
        per-dumbbell colors. A single string color is returned unchanged
        because it applies to every dumbbell.

    Returns
    -------
    tuple
        ``(start, end, labels, line_colors, st_mc_clr, en_mc_clr)`` reordered.
    """
    assert sort_type in ["ascend", "descend"], (
        f"sort_type must be 'ascend' or 'descend', got {sort_type!r}")

    sort_idx = np.argsort(sort_wrt)
    if sort_type == "descend":
        sort_idx = np.flip(sort_idx)

    start = _reorder(start, sort_idx)
    end = _reorder(end, sort_idx)
    labels = _reorder(labels, sort_idx)

    if not isinstance(line_colors, str):
        line_colors = _reorder(line_colors, sort_idx)
    if not isinstance(st_mc_clr, str):
        st_mc_clr = _reorder(st_mc_clr, sort_idx)
    if not isinstance(en_mc_clr, str):
        en_mc_clr = _reorder(en_mc_clr, sort_idx)

    return start, end, labels, line_colors, st_mc_clr, en_mc_clr