ArviZ Quickstart

Note

This tutorial is adapted from ArviZ's quickstart.

Setup

Here we add the necessary packages for this notebook and load a few we will use throughout.

using ArviZ, ArviZPythonPlots, Distributions, LinearAlgebra, Random, StanSample, Turing
# ArviZPythonPlots ships with style sheets!
use_style("arviz-darkgrid")

Get started with plotting

To plot with ArviZ, we need to load the ArviZPythonPlots package. ArviZ is designed to be used with libraries like Stan, Turing.jl, and Soss.jl but works fine with raw arrays.

rng1 = Random.MersenneTwister(37772);
begin
    plot_posterior(randn(rng1, 100_000))
    gcf()
end

When plotting a dictionary (or, as below, a NamedTuple) of arrays, ArviZ interprets each key as the name of a different random variable. The first dimension of each array is treated as the draws and the second as independent chains. Below, we have 10 chains of 50 draws each for four different distributions.

let
    s = (50, 10)
    plot_forest((
        normal=randn(rng1, s),
        gumbel=rand(rng1, Gumbel(), s),
        student_t=rand(rng1, TDist(6), s),
        exponential=rand(rng1, Exponential(), s),
    ))
    gcf()
end
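Upstream arviz's plot_forest also accepts a combined keyword, which merges the per-chain intervals into a single marginal per variable; a minimal sketch, assuming ArviZPythonPlots forwards keywords as it does elsewhere:

let
    s = (50, 10)
    # combined=true collapses the 10 chains into one interval per variable
    plot_forest((; normal=randn(rng1, s), exponential=rand(rng1, Exponential(), s)); combined=true)
    gcf()
end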

Plotting with MCMCChains.jl's Chains objects produced by Turing.jl

ArviZ is designed to work well with high dimensional, labelled data. Consider the eight schools model, which roughly tries to measure the effectiveness of SAT classes at eight different schools. To show off ArviZ's labelling, we give the schools the names of eight actual schools.

This model is small enough to write down, is hierarchical, and uses labelling. Additionally, the centered parameterization causes divergences, which are interesting for illustration.
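Written out, the centered parameterization we will implement below is

$$
\begin{aligned}
\mu &\sim \mathrm{Normal}(0, 5), \\
\tau &\sim \mathrm{HalfCauchy}(5), \\
\theta_j &\sim \mathrm{Normal}(\mu, \tau), \\
y_j &\sim \mathrm{Normal}(\theta_j, \sigma_j), \qquad j = 1, \ldots, J.
\end{aligned}
$$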

First we create our data and set some sampling parameters.

begin
    J = 8
    y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
    σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
    schools = [
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "Phillips Exeter",
        "Hotchkiss",
        "Lawrenceville",
        "St. Paul's",
        "Mt. Hermon",
    ]
    ndraws = 1_000
    ndraws_warmup = 1_000
    nchains = 4
end;

Now we write and run the model using Turing:

Turing.@model function model_turing(y, σ, J=length(y))
    μ ~ Normal(0, 5)                     # population mean
    τ ~ truncated(Cauchy(0, 5), 0, Inf)  # population scale (half-Cauchy)
    θ ~ filldist(Normal(μ, τ), J)        # per-school effects (centered)
    for i in 1:J
        y[i] ~ Normal(θ[i], σ[i])
    end
end
model_turing (generic function with 4 methods)
rng2 = Random.MersenneTwister(16653);
begin
    param_mod_turing = model_turing(y, σ)
    sampler = NUTS(ndraws_warmup, 0.8)

    turing_chns = Turing.sample(
        rng2, param_mod_turing, sampler, MCMCThreads(), ndraws, nchains
    )
end;

Most ArviZ functions work fine with Chains objects from Turing:

begin
    plot_autocorr(turing_chns; var_names=(:μ, :τ))
    gcf()
end
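For example, a trace plot can be drawn directly from the Chains object with the same var_names filtering (a sketch; the conversion is assumed to happen internally, as with plot_autocorr above):

begin
    plot_trace(turing_chns; var_names=(:μ, :τ))
    gcf()
end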

Convert to InferenceData

For much more powerful querying, analysis, and plotting, we can use built-in ArviZ utilities to convert Chains objects to multidimensional data structures with named dimensions and indices. Note that the Chains object carries no information about these dimensions, so we need to provide it ourselves.

ArviZ is built to work with InferenceData, and the more groups it has access to, the more powerful analyses it can perform.

idata_turing_post = from_mcmcchains(
    turing_chns;
    coords=(; school=schools),
    dims=NamedTuple(k => (:school,) for k in (:y, :σ, :θ)),
    library="Turing",
)
InferenceData
posterior
╭──────────────────╮
│ 1000×4×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :μ eltype: Float64 dims: draw, chain size: 1000×4
  :τ eltype: Float64 dims: draw, chain size: 1000×4
  :θ eltype: Float64 dims: draw, chain, school size: 1000×4×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 2 entries:
  "created_at" => "2024-08-10T23:18:45.082"
  "inference_library" => "Turing"
sample_stats
╭────────────────╮
│ 1000×4 Dataset │
├────────────────┴ dims ┐
  ↓ draw, → chain
├─────────────────┴───────────────────────────────────────── layers ┐
  :energy           eltype: Float64 dims: draw, chain size: 1000×4
  :n_steps          eltype: Int64 dims: draw, chain size: 1000×4
  :diverging        eltype: Bool dims: draw, chain size: 1000×4
  :max_energy_error eltype: Float64 dims: draw, chain size: 1000×4
  :energy_error     eltype: Float64 dims: draw, chain size: 1000×4
  :is_accept        eltype: Bool dims: draw, chain size: 1000×4
  :log_density      eltype: Float64 dims: draw, chain size: 1000×4
  :tree_depth       eltype: Int64 dims: draw, chain size: 1000×4
  :step_size        eltype: Float64 dims: draw, chain size: 1000×4
  :acceptance_rate  eltype: Float64 dims: draw, chain size: 1000×4
  :lp               eltype: Float64 dims: draw, chain size: 1000×4
  :step_size_nom    eltype: Float64 dims: draw, chain size: 1000×4
├───────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 2 entries:
  "created_at" => "2024-08-10T23:18:45.024"
  "inference_library" => "Turing"

Each group is an ArviZ.Dataset, a DimensionalData.AbstractDimStack that can be used identically to a DimensionalData.DimStack. We can view a summary of the dataset:

idata_turing_post.posterior
╭──────────────────╮
│ 1000×4×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────────────── layers ┤
  :μ eltype: Float64 dims: draw, chain size: 1000×4
  :τ eltype: Float64 dims: draw, chain size: 1000×4
  :θ eltype: Float64 dims: draw, chain, school size: 1000×4×8
├──────────────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 2 entries:
  "created_at"        => "2024-08-10T23:18:45.082"
  "inference_library" => "Turing"

Here is a plot of the trace. Note the intelligent labels.

begin
    plot_trace(idata_turing_post)
    gcf()
end

We can also generate summary statistics:

summarystats(idata_turing_post)
SummaryStats
                      mean  std  hdi_3%  hdi_97%  mcse_mean  mcse_std  ess_tail  ess_bulk  rhat
 μ                     4.3  3.3   -1.81     10.5       0.11     0.062      1192       845  1.01
 τ                     4.4  3.3   0.673     10.4       0.20      0.12       109       115  1.05
 θ[Choate]             6.6  6.1   -4.01     18.0       0.21      0.19      1627       750  1.01
 θ[Deerfield]          5.0  5.0   -4.81     14.2       0.14      0.14      1952      1277  1.01
 θ[Phillips Andover]   3.7  5.7   -7.07     14.6       0.14      0.16      1979      1429  1.01
 θ[Phillips Exeter]    4.8  5.1   -4.60     14.4       0.14      0.14      2064      1178  1.00
 θ[Hotchkiss]          3.3  4.9   -6.08     12.6       0.15      0.11      1804      1098  1.01
 θ[Lawrenceville]      3.8  5.2   -6.12     13.3       0.13      0.14      1965      1331  1.00
 θ[St. Paul's]         6.6  5.4   -2.68     17.4       0.18      0.14      1842       853  1.01
 θ[Mt. Hermon]         4.9  5.6   -5.71     14.8       0.14      0.19      1794      1393  1.00
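If we only care about a handful of parameters, one option is to subset the posterior's layers before summarizing; a sketch, assuming summarystats accepts a Dataset and that DimStack tuple indexing selects layers:

summarystats(idata_turing_post.posterior[(:μ, :τ)])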

We can likewise examine the energy distribution of the Hamiltonian sampler:

begin
    plot_energy(idata_turing_post)
    gcf()
end
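We can also pull diagnostics straight out of the sample_stats group; for example, counting the divergent transitions via the diverging layer shown in the summary above:

# total number of divergent transitions across all draws and chains
sum(idata_turing_post.sample_stats.diverging)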

Additional information in Turing.jl

With a few more steps, we can use Turing to compute additional useful groups to add to the InferenceData.

To sample from the prior, one simply calls sample but with the Prior sampler:

prior = Turing.sample(rng2, param_mod_turing, Prior(), ndraws);

To draw from the prior and posterior predictive distributions, we can instantiate a "predictive model", i.e. a Turing model with the observations set to missing, and then call predict on the predictive model and the previously drawn samples:

begin
    # Instantiate the predictive model
    param_mod_predict = model_turing(similar(y, Missing), σ)
    # and then sample!
    prior_predictive = Turing.predict(rng2, param_mod_predict, prior)
    posterior_predictive = Turing.predict(rng2, param_mod_predict, turing_chns)
end;

And to extract the pointwise log-likelihoods, which are useful for computing metrics such as loo, we use Turing.pointwise_loglikelihoods:

log_likelihood = let
    log_likelihood = Turing.pointwise_loglikelihoods(
        param_mod_turing, MCMCChains.get_sections(turing_chns, :parameters)
    )
    # Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive`
    ynames = string.(keys(posterior_predictive))
    log_likelihood_y = getindex.(Ref(log_likelihood), ynames)
    (; y=cat(log_likelihood_y...; dims=3))
end;
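As a quick sanity check, the assembled array should have one slice per observation, i.e. shape (ndraws, nchains, J):

size(log_likelihood.y)  # expected: (1000, 4, 8)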

This can then be included in the from_mcmcchains call from above:

idata_turing = from_mcmcchains(
    turing_chns;
    posterior_predictive,
    log_likelihood,
    prior,
    prior_predictive,
    observed_data=(; y),
    coords=(; school=schools),
    dims=NamedTuple(k => (:school,) for k in (:y, :σ, :θ)),
    library=Turing,
)
InferenceData
posterior
╭──────────────────╮
│ 1000×4×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :μ eltype: Float64 dims: draw, chain size: 1000×4
  :τ eltype: Float64 dims: draw, chain size: 1000×4
  :θ eltype: Float64 dims: draw, chain, school size: 1000×4×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 3 entries:
  "created_at" => "2024-08-10T23:19:13.991"
  "inference_library_version" => "0.30.9"
  "inference_library" => "Turing"
posterior_predictive
╭──────────────────╮
│ 1000×4×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :y eltype: Float64 dims: draw, chain, school size: 1000×4×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 3 entries:
  "created_at" => "2024-08-10T23:19:13.717"
  "inference_library_version" => "0.30.9"
  "inference_library" => "Turing"
log_likelihood
╭──────────────────╮
│ 1000×4×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :y eltype: Float64 dims: draw, chain, school size: 1000×4×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 3 entries:
  "created_at" => "2024-08-10T23:19:13.86"
  "inference_library_version" => "0.30.9"
  "inference_library" => "Turing"
sample_stats
╭────────────────╮
│ 1000×4 Dataset │
├────────────────┴ dims ┐
  ↓ draw, → chain
├─────────────────┴───────────────────────────────────────── layers ┐
  :energy           eltype: Float64 dims: draw, chain size: 1000×4
  :n_steps          eltype: Int64 dims: draw, chain size: 1000×4
  :diverging        eltype: Bool dims: draw, chain size: 1000×4
  :max_energy_error eltype: Float64 dims: draw, chain size: 1000×4
  :energy_error     eltype: Float64 dims: draw, chain size: 1000×4
  :is_accept        eltype: Bool dims: draw, chain size: 1000×4
  :log_density      eltype: Float64 dims: draw, chain size: 1000×4
  :tree_depth       eltype: Int64 dims: draw, chain size: 1000×4
  :step_size        eltype: Float64 dims: draw, chain size: 1000×4
  :acceptance_rate  eltype: Float64 dims: draw, chain size: 1000×4
  :lp               eltype: Float64 dims: draw, chain size: 1000×4
  :step_size_nom    eltype: Float64 dims: draw, chain size: 1000×4
├───────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 3 entries:
  "created_at" => "2024-08-10T23:19:13.991"
  "inference_library_version" => "0.30.9"
  "inference_library" => "Turing"
prior
╭──────────────────╮
│ 1000×1×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :μ eltype: Float64 dims: draw, chain size: 1000×1
  :τ eltype: Float64 dims: draw, chain size: 1000×1
  :θ eltype: Float64 dims: draw, chain, school size: 1000×1×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 3 entries:
  "created_at" => "2024-08-10T23:19:14.468"
  "inference_library_version" => "0.30.9"
  "inference_library" => "Turing"
prior_predictive
╭──────────────────╮
│ 1000×1×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :y eltype: Float64 dims: draw, chain, school size: 1000×1×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 3 entries:
  "created_at" => "2024-08-10T23:19:14.329"
  "inference_library_version" => "0.30.9"
  "inference_library" => "Turing"
sample_stats_prior
╭────────────────╮
│ 1000×1 Dataset │
├────────────────┴ dims ┐
  ↓ draw, → chain
├─────────────────┴─────────────────────────── layers ┐
  :lp eltype: Float64 dims: draw, chain size: 1000×1
├─────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 3 entries:
  "created_at" => "2024-08-10T23:19:14.419"
  "inference_library_version" => "0.30.9"
  "inference_library" => "Turing"
observed_data
╭───────────────────╮
│ 8-element Dataset │
├───────────────────┴──────────────────────────────────────────────────── dims ┐
  ↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :y eltype: Float64 dims: school size: 8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 3 entries:
  "created_at" => "2024-08-10T23:19:14.681"
  "inference_library_version" => "0.30.9"
  "inference_library" => "Turing"

Then we can, for example, compute the expected leave-one-out (LOO) predictive density, which is an estimate of the out-of-sample predictive fit of the model:

loo(idata_turing) # higher ELPD is better
PSISLOOResult with estimates
 elpd  elpd_mcse    p  p_mcse
  -31        1.4  1.0    0.33

and PSISResult with 1000 draws, 4 chains, and 8 parameters
Pareto shape (k) diagnostic values:
                    Count      Min. ESS
 (-Inf, 0.5]  good  5 (62.5%)  404
  (0.5, 0.7]  okay  3 (37.5%)  788

If the model is well-calibrated, i.e. it replicates the true generative process well, the pointwise LOO-PIT values should be approximately uniformly distributed. This can be inspected visually:

begin
    plot_loo_pit(idata_turing; y=:y, ecdf=true)
    gcf()
end

Plotting with Stan.jl outputs

StanSample.jl comes with built-in support for producing InferenceData outputs.

Here is the same centered eight schools model in Stan:

begin
    schools_code = """
    data {
      int<lower=0> J;
      array[J] real y;
      array[J] real<lower=0> sigma;
    }

    parameters {
      real mu;
      real<lower=0> tau;
      array[J] real theta;
    }

    model {
      mu ~ normal(0, 5);
      tau ~ cauchy(0, 5);
      theta ~ normal(mu, tau);
      y ~ normal(theta, sigma);
    }

    generated quantities {
        vector[J] log_lik;
        vector[J] y_hat;
        for (j in 1:J) {
            log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
            y_hat[j] = normal_rng(theta[j], sigma[j]);
        }
    }
    """

    schools_data = Dict("J" => J, "y" => y, "sigma" => σ)
    idata_stan = mktempdir() do path
        stan_model = SampleModel("schools", schools_code, path)
        _ = stan_sample(
            stan_model;
            data=schools_data,
            num_chains=nchains,
            num_warmups=ndraws_warmup,
            num_samples=ndraws,
            seed=28983,
            summary=false,
        )
        return StanSample.inferencedata(
            stan_model;
            posterior_predictive_var=:y_hat,
            observed_data=(; y),
            log_likelihood_var=:log_lik,
            coords=(; school=schools),
            dims=NamedTuple(
                k => (:school,) for k in (:y, :sigma, :theta, :log_lik, :y_hat)
            ),
        )
    end
end
InferenceData
posterior
╭──────────────────╮
│ 1000×4×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :mu    eltype: Float64 dims: draw, chain size: 1000×4
  :tau   eltype: Float64 dims: draw, chain size: 1000×4
  :theta eltype: Float64 dims: draw, chain, school size: 1000×4×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 1 entry:
  "created_at" => "2024-08-10T23:19:56.037"
posterior_predictive
╭──────────────────╮
│ 1000×4×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :y_hat eltype: Float64 dims: draw, chain, school size: 1000×4×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 1 entry:
  "created_at" => "2024-08-10T23:19:55.621"
log_likelihood
╭──────────────────╮
│ 1000×4×8 Dataset │
├──────────────────┴───────────────────────────────────────────────────── dims ┐
  ↓ draw  ,
  → chain ,
  ↗ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :log_lik eltype: Float64 dims: draw, chain, school size: 1000×4×8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 1 entry:
  "created_at" => "2024-08-10T23:19:55.959"
sample_stats
╭────────────────╮
│ 1000×4 Dataset │
├────────────────┴ dims ┐
  ↓ draw, → chain
├─────────────────┴──────────────────────────────────────── layers ┐
  :tree_depth      eltype: Int64 dims: draw, chain size: 1000×4
  :energy          eltype: Float64 dims: draw, chain size: 1000×4
  :diverging       eltype: Bool dims: draw, chain size: 1000×4
  :acceptance_rate eltype: Float64 dims: draw, chain size: 1000×4
  :n_steps         eltype: Int64 dims: draw, chain size: 1000×4
  :lp              eltype: Float64 dims: draw, chain size: 1000×4
  :step_size       eltype: Float64 dims: draw, chain size: 1000×4
├──────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 1 entry:
  "created_at" => "2024-08-10T23:19:55.72"
observed_data
╭───────────────────╮
│ 8-element Dataset │
├───────────────────┴──────────────────────────────────────────────────── dims ┐
  ↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
├────────────────────────────────────────────────────────────────────── layers ┤
  :y eltype: Float64 dims: school size: 8
├──────────────────────────────────────────────────────────────────── metadata ┤
  Dict{String, Any} with 1 entry:
  "created_at" => "2024-08-10T23:19:56.084"
Here, for example, is a plot of the posterior densities of mu and tau:

begin
    plot_density(idata_stan; var_names=(:mu, :tau))
    gcf()
end

Here is a plot showing where the Hamiltonian sampler had divergences:

begin
    plot_pair(
        idata_stan;
        coords=Dict(:school => ["Choate", "Deerfield", "Phillips Andover"]),
        divergences=true,
    )
    gcf()
end
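Upstream arviz also provides plot_parallel, another way to see where in parameter space the divergences concentrate; a sketch, assuming ArviZPythonPlots forwards it like the plot functions used above:

begin
    plot_parallel(idata_stan; var_names=(:theta,))
    gcf()
end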
Environment

using PlutoUI
using Pkg, InteractiveUtils
with_terminal(Pkg.status; color=false)
Status `~/work/ArviZ.jl/ArviZ.jl/docs/Project.toml`
⌅ [cbdf2221] AlgebraOfGraphics v0.6.20
  [131c737c] ArviZ v0.12.0 `~/work/ArviZ.jl/ArviZ.jl`
  [2f96bb34] ArviZExampleData v0.1.11
  [4a6e88f0] ArviZPythonPlots v0.1.7
⌅ [13f3f980] CairoMakie v0.11.11
  [a93c6f00] DataFrames v1.6.1
  [0703355e] DimensionalData v0.27.6
  [31c24e10] Distributions v0.25.110
  [e30172f5] Documenter v1.5.0
  [f6006082] EvoTrees v0.16.7
  [b5cf5a8d] InferenceObjects v0.4.2
  [be115224] MCMCDiagnosticTools v0.3.10
  [a7f614a8] MLJBase v1.7.0
  [614be32b] MLJIteration v0.6.2
  [ce719bf2] PSIS v0.9.5
  [359b1769] PlutoStaticHTML v6.0.27
  [7f904dfe] PlutoUI v0.7.59
  [7f36be82] PosteriorStats v0.2.5
  [c1514b29] StanSample v7.10.1
  [a19d573c] StatisticalMeasures v0.1.6
⌅ [2913bbd2] StatsBase v0.33.21
⌅ [fce5fe82] Turing v0.30.9
  [f43a241f] Downloads v1.6.0
  [37e2e46d] LinearAlgebra
  [10745b16] Statistics v1.10.0
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`
with_terminal(versioninfo)
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 2 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_PKG_SERVER_REGISTRY_PREFERENCE = eager
  JULIA_NUM_THREADS = 2
  JULIA_REVISE_WORKER_ONLY = 1
  JULIA_PYTHONCALL_EXE = /home/runner/work/ArviZ.jl/ArviZ.jl/docs/.CondaPkg/env/bin/python