"""Compare plot code."""
from collections.abc import Mapping
from importlib import import_module
from typing import Any, Literal
import numpy as np
from arviz_base import rcParams
from xarray import Dataset, DataTree
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import get_visual_kwargs
def plot_compare(
    cmp_df,
    relative_scale=False,
    rotated=False,
    hide_top_model=False,
    backend=None,
    visuals: Mapping[
        Literal[
            "point_estimate",
            "error_bar",
            "ref_line",
            "ref_band",
            "similar_line",
            "labels",
            "title",
            "ticklabels",
        ],
        Mapping[str, Any] | bool,
    ] = None,
    **pc_kwargs,
):
    r"""Summary plot for model comparison.

    Models are compared based on their expected log pointwise predictive density (ELPD),
    or some transformation of it, such as the mean log predictive density (MLPD)
    or the geometric mean predictive density (GMPD).
    Higher ELPD values indicate better predictive performance.

    The ELPD is estimated by Pareto smoothed importance sampling leave-one-out
    cross-validation (LOO). Details are presented in [1]_ and [2]_.
    The ELPD can only be interpreted in relative terms. But differences in ELPD less than 4
    are considered negligible [3]_.

    Parameters
    ----------
    cmp_df : pandas.DataFrame
        Usually this will be the result of the :func:`arviz_stats.compare` function.
        It is assumed that the DataFrame has two columns one named `elpd`, `mlpd`, or `gmpd`,
        the other named `se`, and the index is the model names. Additionally,
        it is assumed that the first row of the DataFrame is the top model.
    relative_scale : bool, optional
        If True, the `stats` values are scaled relative to the best model.
        Defaults to False.
    rotated : bool, optional
        If True, the plot is rotated, with models on the x-axis and ELPD on the y-axis.
        Defaults to False.
    hide_top_model : bool, optional
        If True, the top model (first row of `cmp_df`) will not appear as a point with error bars
        or in the axis labels. Its performance can still be accessed by the visuals `ref_line`
        and/or `ref_band`. Defaults to False.
    backend : {"bokeh", "matplotlib", "plotly"}
        Select plotting backend. Defaults to rcParams["plot.backend"].
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * point_estimate -> passed to :func:`~arviz_plots.backend.none.scatter`
        * error_bar -> passed to :func:`~arviz_plots.backend.none.line`
        * ref_line -> passed to :func:`~arviz_plots.backend.none.hline` or
          :func:`~arviz_plots.backend.none.vline` depending on the
          ``rotated`` parameter.
        * ref_band -> passed to :func:`~arviz_plots.backend.none.hspan` or
          :func:`~arviz_plots.backend.none.vspan` depending on the
          ``rotated`` parameter. Defaults to ``False``.
        * similar_line -> passed to :func:`~arviz_plots.backend.none.hline` or
          :func:`~arviz_plots.backend.none.vline` depending on the
          ``rotated`` parameter. Defaults to ``False``.
        * labels -> passed to :func:`~arviz_plots.backend.none.xlabel` and
          :func:`~arviz_plots.backend.none.ylabel`
        * title -> passed to :func:`~arviz_plots.backend.none.title`
        * ticklabels -> passed to :func:`~arviz_plots.backend.none.yticks` or
          :func:`~arviz_plots.backend.none.xticks` depending on the
          ``rotated`` parameter.
    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `cmp_df` has none of the columns ``elpd``, ``mlpd``, ``gmpd``,
        or lacks a ``se`` column.

    See Also
    --------
    :func:`arviz_stats.compare` : Compare models based on their ELPD.
    :func:`arviz_stats.loo` : Compute the ELPD using Pareto smoothed importance sampling
        Leave-one-out cross-validation method.

    References
    ----------
    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
        and WAIC*. Statistics and Computing. 27(5) (2017).
        https://doi.org/10.1007/s11222-016-9696-4. arXiv preprint https://arxiv.org/abs/1507.04544.
    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    .. [3] Sivula et al. *Uncertainty in Bayesian Leave-One-Out Cross-Validation Based Model
        Comparison*. (2025). https://doi.org/10.48550/arXiv.2008.10296
    """
    # Set default backend
    if backend is None:
        backend = rcParams["plot.backend"]
    if visuals is None:
        visuals = {}
    # Check we have the required columns; the first match (in elpd/mlpd/gmpd order)
    # is the statistic that gets plotted.
    valid_stats = [col for col in ("elpd", "mlpd", "gmpd") if col in cmp_df.columns]
    if not valid_stats:
        raise ValueError(
            "The DataFrame must contain one of the following columns: 'elpd', 'mlpd', or 'gmpd'."
        )
    stats = valid_stats[0]
    if "se" not in cmp_df.columns:
        raise ValueError("The DataFrame must contain a 'se' column for standard errors.")
    # Get plotting backend
    p_be = import_module(f"arviz_plots.backend.{backend}")
    # Get figure params and create figure and axis
    figure_kwargs = pc_kwargs.pop("figure_kwargs", {}).copy()
    figsize = figure_kwargs.pop("figsize", None)
    figsize_units = figure_kwargs.pop("figsize_units", None)
    figsize = p_be.scale_fig_size(
        figsize,
        rows=int(len(cmp_df) ** 0.5),
        cols=2,
        figsize_units=figsize_units,
    )
    # scale_fig_size returns a size in dots regardless of the input units
    figsize_units = "dots"
    figure, target = p_be.create_plotting_grid(
        1, figsize=figsize, figsize_units=figsize_units, **figure_kwargs
    )
    # Create plot collection wrapping the single-plot figure
    plot_collection = PlotCollection(
        Dataset({}),
        viz_dt=DataTree.from_dict(
            {"/": Dataset({"figure": np.array(figure, dtype=object), "plot": target})}
        ),
        backend=backend,
        **pc_kwargs,
    )
    if isinstance(target, np.ndarray):
        target = target.tolist()
    perf_stats = cmp_df[stats].values
    ses = cmp_df["se"].values
    # Set scale relative to the best model (first row is assumed to be the top model)
    if relative_scale:
        perf_stats = perf_stats - perf_stats[0]
        label_score = f"{stats.upper()} (relative)"
    else:
        label_score = stats.upper()
    # Number of leading rows to skip when the top model is hidden
    skip = int(hide_top_model)
    # Create labels for the models
    label_models = cmp_df.index[skip:]
    # Compute positions of yticks (top model at the top of the axis)
    yticks_pos = list(range(len(cmp_df) - skip, 0, -1))
    # Compute positions of the reference line and band (always from the top model,
    # even when it is hidden as a point)
    pos_ref_line = perf_stats[0]
    pos_ref_band = (perf_stats[0] - ses[0], perf_stats[0] + ses[0])
    # Compute values for standard error bars (lower, upper) per model
    se_list = list(
        zip(
            (perf_stats[skip:] - ses[skip:]),
            (perf_stats[skip:] + ses[skip:]),
        )
    )
    # Compute positions for mean elpd estimates
    if rotated:
        scatter_x = yticks_pos
        scatter_y = perf_stats[skip:]
    else:
        scatter_x = perf_stats[skip:]
        scatter_y = yticks_pos
    # Plot ELPD standard error bars
    error_kwargs = get_visual_kwargs(visuals, "error_bar")
    if error_kwargs is not False:
        error_kwargs.setdefault("color", "B1")
        for se_vals, ytick in zip(se_list, yticks_pos):
            if rotated:
                p_be.line((ytick, ytick), se_vals, target, **error_kwargs)
            else:
                p_be.line(se_vals, (ytick, ytick), target, **error_kwargs)
    # Add reference line for the best model
    ref_l_kwargs = get_visual_kwargs(visuals, "ref_line")
    if ref_l_kwargs is not False:
        ref_l_kwargs.setdefault("color", "B2")
        ref_l_kwargs.setdefault("linestyle", p_be.get_default_aes("linestyle", 2, {})[-1])
        if rotated:
            p_be.hline(pos_ref_line, target, **ref_l_kwargs)
        else:
            p_be.vline(pos_ref_line, target, **ref_l_kwargs)
    # Add reference band for the best model (off by default)
    ref_b_kwargs = get_visual_kwargs(visuals, "ref_band", False)
    if ref_b_kwargs is not False:
        ref_b_kwargs.setdefault("color", "B2")
        ref_b_kwargs.setdefault("alpha", 0.1)
        if rotated:
            p_be.hspan(*pos_ref_band, target=target, **ref_b_kwargs)
        else:
            p_be.vspan(*pos_ref_band, target=target, **ref_b_kwargs)
    # Plot ELPD point estimates
    pe_kwargs = get_visual_kwargs(visuals, "point_estimate")
    if pe_kwargs is not False:
        pe_kwargs.setdefault("color", "B1")
        p_be.scatter(scatter_x, scatter_y, target, **pe_kwargs)
    # Add line for statistically undistinguishable models (off by default).
    # The offset of 4 is the negligible-difference threshold from reference [3].
    similar_l_kwargs = get_visual_kwargs(visuals, "similar_line", False)
    if similar_l_kwargs is not False:
        similar_l_kwargs.setdefault("color", "B2")
        similar_l_kwargs.setdefault("linestyle", p_be.get_default_aes("linestyle", 3, {})[-1])
        if rotated:
            p_be.hline(perf_stats[0] - 4, target, **similar_l_kwargs)
        else:
            p_be.vline(perf_stats[0] - 4, target, **similar_l_kwargs)
    # Add title and labels
    title_kwargs = get_visual_kwargs(visuals, "title")
    if title_kwargs is not False:
        p_be.title(
            "Model comparison\nhigher is better",
            target,
            **title_kwargs,
        )
    labels_kwargs = get_visual_kwargs(visuals, "labels")
    if labels_kwargs is not False:
        if rotated:
            p_be.ylabel(label_score, target, **labels_kwargs)
        else:
            p_be.xlabel(label_score, target, **labels_kwargs)
    ticklabels_kwargs = get_visual_kwargs(visuals, "ticklabels")
    if ticklabels_kwargs is not False:
        if rotated:
            p_be.xticks(yticks_pos, label_models, target, **ticklabels_kwargs)
        else:
            p_be.yticks(yticks_pos, label_models, target, **ticklabels_kwargs)
    return plot_collection