"""lm plot code."""
import warnings
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal
import arviz_stats as azs
import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
filter_aes,
get_group,
process_group_variables_coords,
set_wrap_layout,
)
from arviz_plots.visuals import fill_between_y, labelled_x, labelled_y, line_xy, scatter_xy
[docs]
def _as_dataset(data, fallback_name):
    """Wrap a DataArray into a single-variable Dataset; return anything else unchanged.

    The DataArray's own name is used as variable name, falling back to
    `fallback_name` for unnamed arrays (which ``to_dataset`` would otherwise reject).
    """
    if isinstance(data, xr.DataArray):
        return data.to_dataset(name=fallback_name if data.name is None else data.name)
    return data


def plot_lm(
    dt,
    y=None,
    x=None,
    y_pred=None,
    x_pred=None,
    filter_vars=None,
    group="posterior_predictive",
    coords=None,
    sample_dims=None,
    ci_kind=None,
    ci_prob=None,
    line_kind=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "ci_line",
            "central_line",
            "ci_fill",
            "scatter",
            "xlabel",
            "ylabel",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "ci_line",
            "central_line",
            "ci_fill",
            "scatter",
            "xlabel",
            "ylabel",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[
        Literal["credible_interval", "point_estimate"],
        Mapping[str, Any] | xr.Dataset,
    ] = None,
    **pc_kwargs,
):
    """Posterior predictive and mean plots for regression-like data.

    Parameters
    ----------
    dt : DataTree
        Input data
    y : str or DataArray, optional
        Target variable. If None (default), the first variable in "observed_data" is used.
    x : str or list of str or DataArray or Dataset, optional
        Independent variable(s). If None (default), all variables in "constant_data" are used.
    y_pred : str or DataArray, optional
        Predicted values.
        If None (default), the variable in the specified group with the same name as y is used.
    x_pred : str or list of str or DataArray or Dataset, optional
        Independent variable(s) for predictions.
        If None (default), and if group is "predictions", all variables corresponding to x data
        in "predictions_constant_data" group are used. If group is "posterior_predictive",
        x is used.
    filter_vars: {None, "like", "regex"}, default None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
        It is used for any of y, x, y_pred, and x_pred if they are strings or lists of strings.
    group : str, default "posterior_predictive"
        Group to use for plotting.
    coords : mapping, optional
        Coordinates to use for plotting.
    sample_dims : iterable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    ci_kind : {"hdi", "eti"}, optional
        Which credible interval to use. Defaults to ``rcParams["stats.ci_kind"]``
    ci_prob : float or array-like of float, optional
        Indicates the probabilities that should be contained within the plotted credible intervals.
        Defaults to ``rcParams["stats.ci_prob"]``
    line_kind : {"mean", "median", "mode"}, optional
        Which point estimate to use for the line. Defaults to ``rcParams["stats.point_estimate"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
        By default, there are no aesthetic mappings at all
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * ci_line -> passed to :func:`~.visuals.line_xy`. Defaults to False
        * central_line -> passed to :func:`~.visuals.line_xy`.
        * ci_fill -> passed to :func:`~.visuals.fill_between_y`.
        * scatter -> passed to :func:`~.visuals.scatter_xy`.
        * xlabel -> passed to :func:`~.visuals.labelled_x`.
        * ylabel -> passed to :func:`~.visuals.labelled_y`.

    stats : mapping, optional
        Valid keys are:

        * credible_interval -> passed to eti or hdi
        * point_estimate -> passed to mean, median or mode

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    TypeError
        If `y` or `y_pred` is a Dataset.
    ValueError
        If `line_kind` is not one of "mean", "median" or "mode".
    """
    # ---- resolve defaults -------------------------------------------------
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if ci_prob is None:
        ci_prob = rcParams["stats.ci_prob"]
    if ci_kind is None:
        ci_kind = rcParams["stats.ci_kind"]
    if line_kind is None:
        line_kind = rcParams["stats.point_estimate"]
    # Fail early with a clear message instead of an unbound local further down.
    if line_kind not in ("mean", "median", "mode"):
        raise ValueError(
            f"line_kind must be one of 'mean', 'median' or 'mode' but got {line_kind!r}"
        )
    if visuals is None:
        visuals = {}
    # Work on copies so the caller's mappings are never modified.
    aes_by_visuals = {} if aes_by_visuals is None else aes_by_visuals.copy()
    stats = {} if stats is None else stats.copy()
    if labeller is None:
        labeller = BaseLabeller()
    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    # ---- target variable --------------------------------------------------
    if y is None:
        # the group is only looked up when it is needed to pick the default
        y = list(get_group(dt, "observed_data").data_vars)[0]
    if isinstance(y, xr.Dataset):
        raise TypeError(
            "y can't be a dataset because multiple target variables are not supported yet."
        )
    if isinstance(y, xr.DataArray):
        # everything below indexes by variable name, so normalize to a Dataset
        y = _as_dataset(y, "y")
    else:
        y = process_group_variables_coords(
            dt, group="observed_data", var_names=y, filter_vars=filter_vars, coords=coords
        )

    # ---- independent variable(s) -----------------------------------------
    if x is None:
        x = list(get_group(dt, "constant_data").data_vars)
    if isinstance(x, xr.DataArray | xr.Dataset):
        x = _as_dataset(x, "x")
    else:
        x = process_group_variables_coords(
            dt, group="constant_data", var_names=x, filter_vars=filter_vars, coords=coords
        )

    (target_var,) = y.data_vars
    independent_var = list(x.data_vars)

    # ---- predictions ------------------------------------------------------
    if y_pred is None:
        y_pred = target_var
    if isinstance(y_pred, xr.Dataset):
        raise TypeError(
            "y_pred can't be a dataset because multiple target variables are not supported yet."
        )
    if isinstance(y_pred, xr.DataArray):
        # stored under the target name so the summaries can be indexed with `target_var`
        y_pred = y_pred.to_dataset(name=target_var)
    else:
        y_pred = process_group_variables_coords(
            dt,
            group=group,
            var_names=y_pred,
            filter_vars=filter_vars,
            coords=coords,
        )

    if x_pred is None:
        x_pred = independent_var
    if not isinstance(x_pred, xr.DataArray | xr.Dataset):
        if group == "predictions":
            x_pred = process_group_variables_coords(
                dt,
                group="predictions_constant_data",
                var_names=x_pred,
                filter_vars=filter_vars,
                coords=coords,
            )
        else:
            # in-sample predictions share the observed independent variables
            x_pred = x

    # Lists, tuples and NumPy arrays all request several intervals;
    # scalars (and strings) have zero dimensions.
    multi_prob = np.ndim(ci_prob) > 0
    x_with_prob = x.expand_dims(dim={"prob": ci_prob}) if multi_prob else x

    # ---- plot collection --------------------------------------------------
    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    if plot_collection is None:
        pc_kwargs.setdefault("cols", "__variable__")
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        if multi_prob:
            if "alpha" not in pc_kwargs["aes"]:
                pc_kwargs["aes"]["alpha"] = ["prob"]
                pc_kwargs["alpha"] = np.linspace(0.1, 0.5, len(ci_prob))
            else:
                alpha_dims = pc_kwargs["aes"]["alpha"]
                if isinstance(alpha_dims, str):
                    alpha_dims = [alpha_dims]
                # only warn when the user mapping does not already cover "prob"
                if not alpha_dims or "prob" not in alpha_dims:
                    warnings.warn(
                        "When multiple credible intervals are plotted, "
                        "it is recommended to map 'alpha' aesthetic to 'prob' "
                        "dimension to differentiate between intervals.",
                        stacklevel=2,
                    )
        pc_kwargs["aes"].setdefault("color", ["__variable__"])
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, x)
        plot_collection = PlotCollection.wrap(
            x_with_prob,
            backend=backend,
            **pc_kwargs,
        )

    # ---- default aesthetic mappings per visual ----------------------------
    aes_by_visuals.setdefault(
        "central_line", plot_collection.aes_set.difference({"alpha", "color"})
    )
    if multi_prob:
        aes_by_visuals.setdefault("ci_line", {"alpha"})
        aes_by_visuals.setdefault("ci_fill", {"color", "alpha"})
    else:
        aes_by_visuals.setdefault("ci_fill", {"color"})

    # ---- credible interval ------------------------------------------------
    ci_fun = azs.hdi if ci_kind == "hdi" else azs.eti
    ci_stats_kwargs = stats.get("credible_interval", {})
    ci_dims, _, fill_ignore = filter_aes(plot_collection, aes_by_visuals, "ci_fill", sample_dims)
    if multi_prob:
        # one interval per probability, stacked along a new "prob" dimension
        ci_data = xr.concat(
            [
                ci_fun(y_pred, dim=ci_dims, prob=p, **ci_stats_kwargs).expand_dims(prob=[p])
                for p in ci_prob
            ],
            dim="prob",
        )
    else:
        ci_data = ci_fun(y_pred, dim=ci_dims, prob=ci_prob, **ci_stats_kwargs)

    # ---- point estimate ---------------------------------------------------
    central_line_dims, central_line_aes, central_line_ignore = filter_aes(
        plot_collection, aes_by_visuals, "central_line", sample_dims
    )
    point_stats_kwargs = stats.get("point_estimate", {})
    if line_kind == "mean":
        line_data = y_pred.mean(dim=central_line_dims, **point_stats_kwargs)
    elif line_kind == "median":
        line_data = y_pred.median(dim=central_line_dims, **point_stats_kwargs)
    else:  # "mode", validated above
        line_data = azs.mode(y_pred, dim=central_line_dims, **point_stats_kwargs)

    lines = plot_bknd.get_default_aes("linestyle", 2, {})

    ci_lower = ci_data.sel(ci_bound="lower")
    ci_upper = ci_data.sel(ci_bound="upper")

    # upper and lower lines of credible interval (disabled unless requested)
    ci_line_kwargs = copy(visuals.get("ci_line", False))
    if ci_line_kwargs is not False:
        _, ci_line_aes, ci_line_ignore = filter_aes(
            plot_collection, aes_by_visuals, "ci_line", sample_dims
        )
        if "color" not in ci_line_aes:
            ci_line_kwargs.setdefault("color", "B2")
        if "linestyle" not in ci_line_aes:
            ci_line_kwargs.setdefault("linestyle", lines[1])
        for ci_bound_data in (ci_lower, ci_upper):
            plot_collection.map(
                line_xy,
                "ci_line",
                x=x_pred,
                y=ci_bound_data[target_var],
                ignore_aes=ci_line_ignore,
                **ci_line_kwargs,
            )

    # fill between lines of credible interval
    fill_kwargs = copy(visuals.get("ci_fill", {}))
    if fill_kwargs is not False:
        plot_collection.map(
            fill_between_y,
            "ci_fill",
            x=x_pred,
            y_bottom=ci_lower[target_var],
            y_top=ci_upper[target_var],
            ignore_aes=fill_ignore,
            **fill_kwargs,
        )

    # mean/median/mode line
    central_line_kwargs = copy(visuals.get("central_line", {}))
    if central_line_kwargs is not False:
        if "color" not in central_line_aes:
            central_line_kwargs.setdefault("color", "B1")
        if "alpha" not in central_line_aes:
            central_line_kwargs.setdefault("alpha", 0.6)
        plot_collection.map(
            line_xy,
            "central_line",
            x=x_pred,
            y=line_data[target_var],
            ignore_aes=central_line_ignore,
            **central_line_kwargs,
        )

    # scatter plot of the observed data
    original_scatter_kwargs = copy(visuals.get("scatter", {}))
    if original_scatter_kwargs is not False:
        _, scatter_aes, scatter_ignore = filter_aes(
            plot_collection, aes_by_visuals, "scatter", sample_dims
        )
        if "alpha" not in scatter_aes:
            original_scatter_kwargs.setdefault("alpha", 0.3)
        if "color" not in scatter_aes:
            original_scatter_kwargs.setdefault("color", "B2")
        if "width" not in scatter_aes:
            original_scatter_kwargs.setdefault("width", 0)
        plot_collection.map(
            scatter_xy,
            "scatter",
            x=x,
            y=y[target_var],
            ignore_aes=scatter_ignore,
            **original_scatter_kwargs,
        )

    # x-axis label
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        _, _, xlabel_ignore = filter_aes(plot_collection, aes_by_visuals, "xlabel", sample_dims)
        plot_collection.map(
            labelled_x,
            "xlabel",
            data=x,
            labeller=labeller,
            subset_info=True,
            ignore_aes=xlabel_ignore,
            **xlabel_kwargs,
        )

    # y-axis label
    ylabel_kwargs = copy(visuals.get("ylabel", {}))
    if ylabel_kwargs is not False:
        _, _, ylabel_ignore = filter_aes(plot_collection, aes_by_visuals, "ylabel", sample_dims)
        plot_collection.map(
            labelled_y,
            "ylabel",
            text=target_var,
            ignore_aes=ylabel_ignore,
            **ylabel_kwargs,
        )

    return plot_collection