Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/corner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
# -*- coding: utf-8 -*-

__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"]
__all__ = [
"corner",
"hist2d",
"quantile",
"overplot_lines",
"overplot_spans",
"overplot_points",
]

from corner.core import hist2d, overplot_lines, overplot_points, quantile
from corner.core import (
hist2d,
overplot_lines,
overplot_points,
overplot_spans,
quantile,
)
from corner.corner import corner
from corner.version import version as __version__
8 changes: 8 additions & 0 deletions src/corner/arviz_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def arviz_corner(
title_fmt=".2f",
title_kwargs=None,
truths=None,
truth_uncertainties=None,
truth_color="#4682b4",
scale_hist=False,
quantiles=None,
Expand All @@ -68,6 +69,7 @@ def arviz_corner(
use_math_text=False,
reverse=False,
labelpad=0.0,
truth_uncertainties_kwargs=None,
hist_kwargs=None,
# Arviz parameters
group="posterior",
Expand Down Expand Up @@ -126,6 +128,10 @@ def arviz_corner(
truths = np.concatenate(
[np.asarray(truths[k]).flatten() for k in var_names]
)
if isinstance(truth_uncertainties, Mapping):
truth_uncertainties = np.concatenate(
[np.asarray(truth_uncertainties[k]).flatten() for k in var_names]
)
if isinstance(titles, Mapping):
titles = np.concatenate(
[np.asarray(titles[k]).flatten() for k in var_names]
Expand All @@ -150,6 +156,7 @@ def arviz_corner(
title_fmt=title_fmt,
title_kwargs=title_kwargs,
truths=truths,
truth_uncertainties=truth_uncertainties,
truth_color=truth_color,
scale_hist=scale_hist,
quantiles=quantiles,
Expand All @@ -160,6 +167,7 @@ def arviz_corner(
use_math_text=use_math_text,
reverse=reverse,
labelpad=labelpad,
truth_uncertainties_kwargs=truth_uncertainties_kwargs,
hist_kwargs=hist_kwargs,
**hist2d_kwargs,
)
Expand Down
116 changes: 116 additions & 0 deletions src/corner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def corner_impl(
title_fmt=".2f",
title_kwargs=None,
truths=None,
truth_uncertainties=None,
truth_color="#4682b4",
scale_hist=False,
quantiles=None,
Expand All @@ -57,6 +58,7 @@ def corner_impl(
use_math_text=False,
reverse=False,
labelpad=0.0,
truth_uncertainties_kwargs=None,
hist_kwargs=None,
**hist2d_kwargs,
):
Expand Down Expand Up @@ -463,6 +465,24 @@ def corner_impl(
color=truth_color,
)

if truth_uncertainties is not None:
lower_list, upper_list = _parse_truth_uncertainties(
truths, truth_uncertainties
)
if upper_list is not None and lower_list is not None:
if truth_uncertainties_kwargs is None:
# Use default settings.
truth_uncertainties_kwargs = dict(
alpha=0.15, fc=truth_color, ec=truth_color, zorder=0
)
overplot_spans(
fig,
lower_list,
upper_list,
reverse=reverse,
**truth_uncertainties_kwargs,
)

return fig


Expand Down Expand Up @@ -853,6 +873,67 @@ def overplot_lines(fig, xs, reverse=False, **kwargs):
axes[k2, k1].axhline(xs[k2], **kwargs)


def overplot_spans(fig, x1s, x2s, reverse=False, **kwargs):
"""
Overplot spans on a figure generated by ``corner.corner``

Parameters
----------
fig : Figure
The figure generated by a call to :func:`corner.corner`.

x1s : array_like[ndim]
The start value of each span that will be plotted. This must have ``ndim``
entries, where ``ndim`` is compatible with the :func:`corner.corner`
call that originally generated the figure. The entries can optionally
be ``None`` to omit the line in that axis.

x2s : array_like[ndim]
The end value of each span that will be plotted. This must have ``ndim``
entries, where ``ndim`` is compatible with the :func:`corner.corner`
call that originally generated the figure. The entries can optionally
be ``None`` to omit the line in that axis.

reverse: bool
A boolean flag that should be set to 'True' if the corner plot itself
was plotted with 'reverse=True'.

**kwargs
Any remaining keyword arguments are passed to the ``ax.axhspan``
method.

"""
K = len(x1s)
if K != len(x2s):
raise ValueError("`x1s` and `x2s` arrays must be the same length.")

axes, _ = _get_fig_axes(fig, K)
if reverse:
for k1 in range(K):
if x1s[k1] is not None:
axes[K - k1 - 1, K - k1 - 1].axvspan(
x1s[k1], x2s[k1], **kwargs
)
for k2 in range(k1 + 1, K):
if x1s[k1] is not None:
axes[K - k2 - 1, K - k1 - 1].axvspan(
x1s[k1], x2s[k1], **kwargs
)
if x1s[k2] is not None:
axes[K - k2 - 1, K - k1 - 1].axhspan(
x1s[k2], x2s[k2], **kwargs
)
else:
for k1 in range(K):
if x1s[k1] is not None:
axes[k1, k1].axvspan(x1s[k1], x2s[k1], **kwargs)
for k2 in range(k1 + 1, K):
if x1s[k1] is not None:
axes[k2, k1].axvspan(x1s[k1], x2s[k1], **kwargs)
if x1s[k2] is not None:
axes[k2, k1].axhspan(x1s[k2], x2s[k2], **kwargs)


def overplot_points(fig, xs, reverse=False, **kwargs):
"""
Overplot points on a figure generated by ``corner.corner``
Expand Down Expand Up @@ -892,6 +973,41 @@ def overplot_points(fig, xs, reverse=False, **kwargs):
axes[k2, k1].plot(xs[k1], xs[k2], **kwargs)


def _parse_truth_uncertainties(truths, truth_uncertainties):

if truth_uncertainties is None or truths is None:
return None, None

lowers = list()
uppers = list()
for i, current_uncert in enumerate(truth_uncertainties):
lower_uncert = None
upper_uncert = None
if current_uncert is None or truths[i] is None:
# Skip
lower_uncert = None
upper_uncert = None
elif isinstance(current_uncert, (float, np.floating)):
# Single uncertainty provided.
lower_uncert = truths[i] - current_uncert
upper_uncert = truths[i] + current_uncert
elif len(current_uncert) == 1:
# Still a single uncertainty provided but its a in a iterable.
lower_uncert = truths[i] - current_uncert[0]
upper_uncert = truths[i] + current_uncert[0]
elif len(current_uncert) == 2:
lower_uncert = truths[i] - current_uncert[0]
upper_uncert = truths[i] + current_uncert[1]
else:
raise ValueError(
f"Unexpected number of truth uncertainties provided at index {i}."
)
lowers.append(lower_uncert)
uppers.append(upper_uncert)

return lowers, uppers


def _parse_input(xs):
xs = np.atleast_1d(xs)
if len(xs.shape) == 1:
Expand Down
17 changes: 17 additions & 0 deletions src/corner/corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def corner(
title_fmt=".2f",
title_kwargs=None,
truths=None,
truth_uncertainties=None,
truth_color="#4682b4",
scale_hist=False,
quantiles=None,
Expand All @@ -44,6 +45,7 @@ def corner(
use_math_text=False,
reverse=False,
labelpad=0.0,
truth_uncertainties_kwargs=None,
hist_kwargs=None,
# Arviz parameters
group="posterior",
Expand Down Expand Up @@ -171,6 +173,13 @@ def corner(
A list of reference values to indicate on the plots. Individual
values can be omitted by using ``None``.

truth_uncertainties : iterable (ndim, udim = 1 or 2)
A list of uncertainties corresponding to `truths`.
If udim is 1 then that uncertainty will be used for both the
lower and upper bounds. If udim is 2 then the first value will be used
as the lower bound and the second as the upper. Individual
values can be omitted by using ``None``.

truth_color : str
A ``matplotlib`` style color for the ``truths`` makers.

Expand Down Expand Up @@ -211,6 +220,10 @@ def corner(
axes yet, or ``ndim * ndim`` axes already present. If not set, the
plot will be drawn on a newly created figure.

truth_uncertainties_kwargs : dict
Any extra keyword arguments to send to the axvspan used to create truth
uncertainty bands.

hist_kwargs : dict
Any extra keyword arguments to send to the 1-D histogram plots.

Expand Down Expand Up @@ -263,6 +276,7 @@ def corner(
title_fmt=title_fmt,
title_kwargs=title_kwargs,
truths=truths,
truth_uncertainties=truth_uncertainties,
truth_color=truth_color,
scale_hist=scale_hist,
quantiles=quantiles,
Expand All @@ -273,6 +287,7 @@ def corner(
use_math_text=use_math_text,
reverse=reverse,
labelpad=labelpad,
truth_uncertainties_kwargs=truth_uncertainties_kwargs,
hist_kwargs=hist_kwargs,
**hist2d_kwargs,
)
Expand All @@ -295,6 +310,7 @@ def corner(
title_fmt=title_fmt,
title_kwargs=title_kwargs,
truths=truths,
truth_uncertainties=truth_uncertainties,
truth_color=truth_color,
scale_hist=scale_hist,
quantiles=quantiles,
Expand All @@ -305,6 +321,7 @@ def corner(
use_math_text=use_math_text,
reverse=reverse,
labelpad=labelpad,
truth_uncertainties_kwargs=truth_uncertainties_kwargs,
hist_kwargs=hist_kwargs,
group=group,
var_names=var_names,
Expand Down