Creating custom plots
While ArviZ includes many plotting functions for visualizing the data stored in InferenceData objects, you will often need to construct custom plots, or you may want to tweak some of our plots in your favorite plotting package.
In this tutorial, we will show you a few useful techniques you can use to construct these plots using Julia's plotting packages. For demonstration purposes, we'll use Makie.jl and AlgebraOfGraphics.jl, which can consume Dataset objects directly because Datasets implement the Tables interface. However, we could just as easily have used StatsPlots.jl.
begin
using ArviZ, ArviZExampleData, DimensionalData, DataFrames, Statistics
using AlgebraOfGraphics, CairoMakie
using AlgebraOfGraphics: density
set_aog_theme!()
end;
We'll start by loading some draws from an implementation of the centered parameterization of the 8 schools model. In this parameterization, the model has some sampling issues.
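For reference, here is a sketch of the standard 8 schools formulation, assuming the usual setup with known per-school standard errors σ_j (the exact priors on μ and τ used to generate the example data are not shown here, and τ is a standard deviation). In the centered form, each θ_j is drawn directly from a normal centered at μ with scale τ, so when τ is small the θ_j are squeezed toward μ. That produces a funnel-shaped posterior that samplers such as NUTS struggle to explore. The non-centered form samples standardized offsets η_j instead, which removes that dependence from the sampled parameters.

$$
\begin{aligned}
\text{centered:}\quad & \theta_j \sim \mathcal{N}(\mu, \tau), && y_j \sim \mathcal{N}(\theta_j, \sigma_j), \\
\text{non-centered:}\quad & \eta_j \sim \mathcal{N}(0, 1),\; \theta_j = \mu + \tau\,\eta_j, && y_j \sim \mathcal{N}(\theta_j, \sigma_j),
\end{aligned}
\qquad j = 1, \dots, 8.
$$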
idata = load_example_data("centered_eight")
posterior
╭─────────────────╮
│ 500×4×8 Dataset │
├─────────────────┴────────────────────────────────────────────────────── dims ┐
↓ draw Sampled{Int64} [0, 1, …, 498, 499] ForwardOrdered Irregular Points,
→ chain Sampled{Int64} [0, 1, 2, 3] ForwardOrdered Irregular Points,
↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
:mu eltype: Float64 dims: draw, chain size: 500×4
:theta eltype: Float64 dims: school, draw, chain size: 8×500×4
:tau eltype: Float64 dims: draw, chain size: 500×4
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 6 entries:
"created_at" => "2022-10-13T14:37:37.315398"
"inference_library_version" => "4.2.2"
"sampling_time" => 7.48011
"tuning_steps" => 1000
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
posterior_predictive
╭─────────────────╮
│ 8×500×4 Dataset │
├─────────────────┴────────────────────────────────────────────────────── dims ┐
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered,
→ draw Sampled{Int64} [0, 1, …, 498, 499] ForwardOrdered Irregular Points,
↗ chain Sampled{Int64} [0, 1, 2, 3] ForwardOrdered Irregular Points
├────────────────────────────────────────────────────────────────────── layers ┤
:obs eltype: Float64 dims: school, draw, chain size: 8×500×4
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 4 entries:
"created_at" => "2022-10-13T14:37:41.460544"
"inference_library_version" => "4.2.2"
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
log_likelihood
╭─────────────────╮
│ 8×500×4 Dataset │
├─────────────────┴────────────────────────────────────────────────────── dims ┐
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered,
→ draw Sampled{Int64} [0, 1, …, 498, 499] ForwardOrdered Irregular Points,
↗ chain Sampled{Int64} [0, 1, 2, 3] ForwardOrdered Irregular Points
├────────────────────────────────────────────────────────────────────── layers ┤
:obs eltype: Float64 dims: school, draw, chain size: 8×500×4
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 4 entries:
"created_at" => "2022-10-13T14:37:37.487399"
"inference_library_version" => "4.2.2"
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
sample_stats
╭───────────────╮
│ 500×4 Dataset │
├───────────────┴─────────────────────────────────────────────────────── dims ┐
↓ draw Sampled{Int64} [0, 1, …, 498, 499] ForwardOrdered Irregular Points,
→ chain Sampled{Int64} [0, 1, 2, 3] ForwardOrdered Irregular Points
├─────────────────────────────────────────────────────────────────────────────┴ layers ┐
:max_energy_error eltype: Float64 dims: draw, chain size: 500×4
:energy_error eltype: Float64 dims: draw, chain size: 500×4
:lp eltype: Float64 dims: draw, chain size: 500×4
:index_in_trajectory eltype: Int64 dims: draw, chain size: 500×4
:acceptance_rate eltype: Float64 dims: draw, chain size: 500×4
:diverging eltype: Bool dims: draw, chain size: 500×4
:process_time_diff eltype: Float64 dims: draw, chain size: 500×4
:n_steps eltype: Float64 dims: draw, chain size: 500×4
:perf_counter_start eltype: Float64 dims: draw, chain size: 500×4
:largest_eigval eltype: Union{Missing, Float64} dims: draw, chain size: 500×4
:smallest_eigval eltype: Union{Missing, Float64} dims: draw, chain size: 500×4
:step_size_bar eltype: Float64 dims: draw, chain size: 500×4
:step_size eltype: Float64 dims: draw, chain size: 500×4
:energy eltype: Float64 dims: draw, chain size: 500×4
:tree_depth eltype: Int64 dims: draw, chain size: 500×4
:perf_counter_diff eltype: Float64 dims: draw, chain size: 500×4
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 6 entries:
"created_at" => "2022-10-13T14:37:37.324929"
"inference_library_version" => "4.2.2"
"sampling_time" => 7.48011
"tuning_steps" => 1000
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
prior
╭─────────────────╮
│ 500×1×8 Dataset │
├─────────────────┴────────────────────────────────────────────────────── dims ┐
↓ draw Sampled{Int64} [0, 1, …, 498, 499] ForwardOrdered Irregular Points,
→ chain Sampled{Int64} [0] ForwardOrdered Irregular Points,
↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
:tau eltype: Float64 dims: draw, chain size: 500×1
:theta eltype: Float64 dims: school, draw, chain size: 8×500×1
:mu eltype: Float64 dims: draw, chain size: 500×1
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 4 entries:
"created_at" => "2022-10-13T14:37:26.602116"
"inference_library_version" => "4.2.2"
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
prior_predictive
╭─────────────────╮
│ 8×500×1 Dataset │
├─────────────────┴────────────────────────────────────────────────────── dims ┐
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered,
→ draw Sampled{Int64} [0, 1, …, 498, 499] ForwardOrdered Irregular Points,
↗ chain Sampled{Int64} [0] ForwardOrdered Irregular Points
├────────────────────────────────────────────────────────────────────── layers ┤
:obs eltype: Float64 dims: school, draw, chain size: 8×500×1
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 4 entries:
"created_at" => "2022-10-13T14:37:26.604969"
"inference_library_version" => "4.2.2"
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
observed_data
╭───────────────────╮
│ 8-element Dataset │
├───────────────────┴──────────────────────────────────────────────────── dims ┐
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
:obs eltype: Float64 dims: school size: 8
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 4 entries:
"created_at" => "2022-10-13T14:37:26.606375"
"inference_library_version" => "4.2.2"
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
constant_data
╭───────────────────╮
│ 8-element Dataset │
├───────────────────┴──────────────────────────────────────────────────── dims ┐
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
:scores eltype: Float64 dims: school size: 8
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 4 entries:
"created_at" => "2022-10-13T14:37:26.607471"
"inference_library_version" => "4.2.2"
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
idata.posterior
╭─────────────────╮
│ 500×4×8 Dataset │
├─────────────────┴────────────────────────────────────────────────────── dims ┐
↓ draw Sampled{Int64} [0, 1, …, 498, 499] ForwardOrdered Irregular Points,
→ chain Sampled{Int64} [0, 1, 2, 3] ForwardOrdered Irregular Points,
↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
:mu eltype: Float64 dims: draw, chain size: 500×4
:theta eltype: Float64 dims: school, draw, chain size: 8×500×4
:tau eltype: Float64 dims: draw, chain size: 500×4
├──────────────────────────────────────────────────────────────────── metadata ┤
Dict{String, Any} with 6 entries:
"created_at" => "2022-10-13T14:37:37.315398"
"inference_library_version" => "4.2.2"
"sampling_time" => 7.48011
"tuning_steps" => 1000
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
The plotting functions we'll be using interact with a tabular view of a Dataset. Let's see what that view looks like for the posterior:
df = DataFrame(idata.posterior)
| | draw | chain | school | mu | theta | tau |
|---|---|---|---|---|---|---|
| 1 | 0 | 0 | "Choate" | 7.8718 | 12.3207 | 4.72574 |
| 2 | 1 | 0 | "Choate" | 3.38455 | 11.2856 | 3.90899 |
| 3 | 2 | 0 | "Choate" | 9.10048 | 5.70851 | 4.84403 |
| 4 | 3 | 0 | "Choate" | 7.30429 | 10.0373 | 1.8567 |
| 5 | 4 | 0 | "Choate" | 9.87968 | 9.14915 | 4.74841 |
| 6 | 5 | 0 | "Choate" | 7.04203 | 14.7359 | 3.51387 |
| 7 | 6 | 0 | "Choate" | 10.3785 | 14.304 | 4.20898 |
| 8 | 7 | 0 | "Choate" | 10.06 | 13.3298 | 2.6834 |
| 9 | 8 | 0 | "Choate" | 10.4253 | 10.4498 | 1.16889 |
| 10 | 9 | 0 | "Choate" | 10.8108 | 11.4731 | 1.21052 |
| ... | ... | ... | ... | ... | ... | ... |
| 16000 | 499 | 3 | "Mt. Hermon" | 3.40446 | 1.29505 | 4.46125 |
The tabular view includes dimensions and variables as columns.
When variables with different dimensions are flattened into a tabular form, there's always some duplication of values. As a simple case, note that chain, draw, and school all have repeated values in the above table.
In this case, theta has the school dimension, but tau doesn't, so the values of tau will be repeated in the table for each value of school.
df[df.school .== Ref("Choate"), :].tau == df[df.school .== Ref("Deerfield"), :].tau
true
In our first example, this will be important.
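To see where the duplication comes from without any plotting packages, here is a minimal Base Julia sketch that flattens a toy theta (dims school, draw, chain) and a toy tau (dims draw, chain) into one row per (school, draw, chain), roughly the way the tabular view above does. The sizes and values are made up for illustration, and the row order is not necessarily the one DataFrame(idata.posterior) uses.

```julia
# Hypothetical toy sizes standing in for the real 8×500×4 posterior.
nschool, ndraw, nchain = 2, 3, 2

# theta has dims (school, draw, chain); tau has dims (draw, chain) only.
theta = reshape(collect(1.0:(nschool * ndraw * nchain)), nschool, ndraw, nchain)
tau = reshape(collect(100.0:(100.0 + ndraw * nchain - 1)), ndraw, nchain)

# Flatten to one row per (school, draw, chain). Because tau has no school
# dimension, each tau[d, c] is copied into the row of every school.
rows = [
    (school=s, draw=d, chain=c, theta=theta[s, d, c], tau=tau[d, c])
    for s in 1:nschool for d in 1:ndraw for c in 1:nchain
]

# Extract the tau column for a single school.
tau_for(school) = [r.tau for r in rows if r.school == school]

length(rows)                         # 12 rows (2 × 3 × 2)
tau_for(1) == tau_for(2)             # true: identical tau values for both schools
length(unique(r.tau for r in rows))  # 6: only ndraw × nchain distinct tau values
```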
Here, let's construct a trace plot. Besides idata, all functions and types in the following cell are defined in AlgebraOfGraphics or Makie:
- data(...) indicates that the wrapped object implements the Tables interface.
- mapping indicates how the data should be used. The symbols are all column names in the table, which for us are our variable names and dimensions.
- visual specifies how the data should be converted to a plot.
- Lines is a plot type defined in Makie.
- draw takes this combination and plots it.
draw(
data(idata.posterior.mu) *
mapping(:draw, :mu; color=:chain => nonnumeric) *
visual(Lines; alpha=0.8),
)
Note the expression idata.posterior.mu. If we had just used idata.posterior, the plot would have looked more or less the same, but there would be artifacts due to mu being copied once for every school. By selecting mu directly, only its own dimensions (draw and chain) are kept, so each value of mu appears in the plot exactly once.
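If you are starting from the already-flattened table instead, one alternative (a sketch, not what the plotting code above does) is to drop the duplicated rows yourself with DataFrames' unique, keeping only the columns mu depends on. The toy DataFrame below is hypothetical and only mimics the layout of df.

```julia
using DataFrames

# Hypothetical toy table mimicking the flattened posterior: mu has no school
# dimension, so each (draw, chain) value of mu is repeated once per school.
df = DataFrame(
    draw=repeat(0:2; outer=2),
    chain=zeros(Int, 6),
    school=repeat(["Choate", "Deerfield"]; inner=3),
    mu=repeat([1.5, 2.5, 3.5]; outer=2),
)

# Keep only the columns mu actually depends on, then drop the repeated rows.
mu_df = unique(df[:, [:draw, :chain, :mu]])

nrow(mu_df)  # 3: one row per (draw, chain)
```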
When examining an MCMC trace plot, we want to see a "fuzzy caterpillar". Instead, we see a few places where the Markov chains froze. We can do the same for theta as well, but it's more useful here to separate these draws by school.
draw(
data(idata.posterior) *
mapping(:draw, :theta; layout=:school, color=:chain => nonnumeric) *
visual(Lines; alpha=0.8),
)
Suppose we want to compare tau with theta for two different schools. To do so, we use InferenceData's indexing syntax to subset the data.
draw(
data(idata[:posterior, school=At(["Choate", "Deerfield"])]) *
mapping(:theta, :tau; color=:school) *
density() *
visual(Contour; levels=10),
)
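The selector At comes from DimensionalData.jl and picks entries by their lookup value rather than their position; passing a vector selects several values at once. A minimal sketch on a hypothetical DimArray (the values and the draw dimension here are made up):

```julia
using DimensionalData

# Hypothetical 3×4 array of draws with a categorical school dimension.
schools = ["Choate", "Deerfield", "Phillips Andover"]
A = DimArray(reshape(collect(1.0:12.0), 3, 4), (Dim{:school}(schools), Dim{:draw}(0:3)))

# At selects by lookup value (not by position); a vector selects several values.
sub = A[school=At(["Choate", "Deerfield"])]

size(sub)                                # (2, 4)
A[school=At("Deerfield"), draw=At(2)]    # 8.0, the same element as A[2, 3]
```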
We can also compare the density plots constructed from each chain for different schools.
draw(
data(idata.posterior) *
mapping(:theta; layout=:school, color=:chain => nonnumeric) *
density(),
)
If we want to compare many schools in a single plot, an ECDF plot is more convenient.
draw(
data(idata.posterior) * mapping(:theta; color=:school => nonnumeric) * visual(ECDFPlot);
axis=(; ylabel="probability"),
)
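The ECDF at a value t is just the fraction of draws less than or equal to t. Unlike a density estimate, it needs no bandwidth, so many curves can be overlaid cleanly. A Base Julia sketch of the quantity ECDFPlot draws as a step function, using hypothetical samples:

```julia
# Empirical CDF: the fraction of samples less than or equal to t.
ecdf_at(samples, t) = count(<=(t), samples) / length(samples)

samples = [3.0, 1.0, 4.0, 1.0, 5.0]  # hypothetical draws

ecdf_at(samples, 0.5)  # 0.0
ecdf_at(samples, 1.0)  # 0.4
ecdf_at(samples, 4.0)  # 0.8
ecdf_at(samples, 9.0)  # 1.0
```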
So far we've just plotted data from one group, but we often want to combine data from multiple groups in one plot. The simplest way to do this is to create the plot out of multiple layers. Here we use this approach to plot the observations over the posterior predictive distribution.
draw(
(data(idata.posterior_predictive) * mapping(:obs; layout=:school) * density()) +
(data(idata.observed_data) * mapping(:obs, :obs => zero => ""; layout=:school)),
)
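The mapping :obs => zero => "" may look cryptic. In AlgebraOfGraphics, a pair of the form column => function => label transforms a column and relabels it; because Julia's => is right-associative, the expression nests as shown below. Applying zero to each observation places every point at y = 0 beneath the densities, and the empty string removes the axis label. This Base Julia snippet only inspects the pair structure:

```julia
# `=>` is right-associative, so this parses as :obs => (zero => "").
spec = :obs => zero => ""

spec.first          # :obs (the source column)
spec.second.first   # zero (the function applied to each value)
spec.second.second  # ""   (the new axis label)

zero(12.7)          # 0.0: every observation is placed at y = 0
```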
Another option is to combine the groups into a single dataset.
Here we compare the prior and posterior. Since the prior has 1 chain and the posterior has 4 chains, combining them into a table would require a ragged structure, which is not currently supported.
We can either plot the two distributions separately as we did before, or compare a single chain from each group. We'll do the latter here. To concatenate the two groups, we introduce a new named dimension using DimensionalData.Dim.
draw(
data(
cat(
idata.posterior[chain=[1]], idata.prior; dims=Dim{:group}([:posterior, :prior])
)[:mu],
) *
mapping(:mu; color=:group) *
histogram(; bins=20, normalization=:probability) *
visual(; alpha=0.8);
axis=(; ylabel="probability"),
)
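The ragged-structure limitation described above is the same constraint plain arrays have: stacking along a new dimension requires every other dimension to match. A Base Julia sketch with hypothetical arrays shaped like mu (draw × chain), where 4 posterior chains and 1 prior chain cannot be stacked until a single chain is selected:

```julia
# Hypothetical stand-ins with the same shapes as mu: draw × chain.
posterior_mu = randn(500, 4)
prior_mu = randn(500, 1)

# Stacking along a new third ("group") dimension needs matching sizes in
# every other dimension, so 4 chains vs 1 chain is rejected.
mismatch = try
    cat(posterior_mu, prior_mu; dims=3)
    false
catch err
    err isa DimensionMismatch || rethrow()
    true
end

# Keeping a single posterior chain makes the shapes agree.
stacked = cat(posterior_mu[:, 1:1], prior_mu; dims=3)
size(stacked)  # (500, 1, 2)
```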
From the trace plots, we suspected the geometry of this posterior was bad. Let's highlight divergent transitions. To do so, we merge posterior and sample_stats, which we can do with merge since they share no common variable names.
draw(
data(merge(idata.posterior, idata.sample_stats)) * mapping(
:theta,
:tau;
layout=:school,
color=:diverging,
markersize=:diverging => (d -> d ? 5 : 2),
),
)
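In markersize=:diverging => (d -> d ? 5 : 2), AlgebraOfGraphics applies the anonymous function to each value of the diverging column, so divergent transitions are drawn with larger markers. A Base Julia sketch applying the same function to a hypothetical vector of flags:

```julia
# Same anonymous function as in the mapping: divergent draws get size 5.
marker_size = d -> d ? 5 : 2

diverging = [false, true, false, false, true]  # hypothetical flags
sizes = map(marker_size, diverging)            # [2, 5, 2, 2, 5]
count(diverging)                               # 2 divergent transitions
```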
When we try building more complex plots, we may need to build new Datasets from our existing ones.
One example of this is the corner plot. To build this plot, we need a copy of theta whose school dimension is renamed (here to school2), so that each school can be paired with every other school.
let
theta = idata.posterior.theta[school=1:4]
theta2 = rebuild(set(theta; school=:school2); name=:theta2)
plot_data = Dataset(theta, theta2, idata.sample_stats.diverging)
draw(
data(plot_data) * mapping(
:theta,
:theta2 => "theta";
col=:school,
row=:school2,
color=:diverging,
markersize=:diverging => (d -> d ? 3 : 1),
);
figure=(; size=(600, 600)),
axis=(; aspect=1),
)
end
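The two DimensionalData calls in the corner-plot cell do the renaming: set(theta; school=:school2) renames the school dimension, and rebuild(...; name=:theta2) renames the array, leaving the values untouched. A minimal sketch on a hypothetical two-school array, assuming the same keyword form of set that the cell above uses:

```julia
using DimensionalData

# Hypothetical 2 schools × 3 draws stand-in for theta.
theta = DimArray(
    reshape(collect(1.0:6.0), 2, 3),
    (Dim{:school}(["Choate", "Deerfield"]), Dim{:draw}(0:2));
    name=:theta,
)

# Rename the school dimension and the array itself, keeping the values.
theta2 = rebuild(set(theta; school=:school2); name=:theta2)

hasdim(theta2, :school2)          # true
hasdim(theta2, :school)           # false
parent(theta2) == parent(theta)   # true
```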
Environment
using Pkg, InteractiveUtils
using PlutoUI
with_terminal(Pkg.status; color=false)
Status `~/work/ArviZ.jl/ArviZ.jl/docs/Project.toml`
⌅ [cbdf2221] AlgebraOfGraphics v0.6.20
  [131c737c] ArviZ v0.12.0 `~/work/ArviZ.jl/ArviZ.jl`
  [2f96bb34] ArviZExampleData v0.1.11
  [4a6e88f0] ArviZPythonPlots v0.1.7
⌅ [13f3f980] CairoMakie v0.11.11
  [a93c6f00] DataFrames v1.6.1
  [0703355e] DimensionalData v0.27.6
  [31c24e10] Distributions v0.25.110
  [e30172f5] Documenter v1.5.0
  [f6006082] EvoTrees v0.16.7
  [b5cf5a8d] InferenceObjects v0.4.2
  [be115224] MCMCDiagnosticTools v0.3.10
  [a7f614a8] MLJBase v1.7.0
  [614be32b] MLJIteration v0.6.2
  [ce719bf2] PSIS v0.9.5
  [359b1769] PlutoStaticHTML v6.0.27
  [7f904dfe] PlutoUI v0.7.59
  [7f36be82] PosteriorStats v0.2.5
  [c1514b29] StanSample v7.10.1
  [a19d573c] StatisticalMeasures v0.1.6
⌅ [2913bbd2] StatsBase v0.33.21
⌅ [fce5fe82] Turing v0.30.9
  [f43a241f] Downloads v1.6.0
  [37e2e46d] LinearAlgebra
  [10745b16] Statistics v1.10.0
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`
with_terminal(versioninfo)
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 2 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_PKG_SERVER_REGISTRY_PREFERENCE = eager
  JULIA_NUM_THREADS = 2
  JULIA_REVISE_WORKER_ONLY = 1