ArviZ.jl Quickstart

Note

This tutorial is adapted from ArviZ's quickstart.

Setup

Here we add the necessary packages for this notebook and load a few we will use throughout.

begin
    using ArviZ, CmdStan, Distributions, LinearAlgebra, PyPlot, Random, Soss, Turing
    using Soss.MeasureTheory: HalfCauchy
    using SampleChainsDynamicHMC: getchains, dynamichmc
end
# ArviZ ships with style sheets!
ArviZ.use_style("arviz-darkgrid")

Get started with plotting

ArviZ.jl is designed to be used with libraries like CmdStan, Turing.jl, and Soss.jl but works fine with raw arrays.

rng1 = Random.MersenneTwister(37772);
begin
    plot_posterior(randn(rng1, 100_000))
    gcf()
end

When plotting a dictionary of arrays, ArviZ.jl interprets each key as the name of a different random variable. Each row of an array is treated as an independent series of draws from that variable, called a chain. Below, we have 10 chains of 50 draws each for four different distributions.

let
    s = (10, 50)
    plot_forest(
        Dict(
            "normal" => randn(rng1, s),
            "gumbel" => rand(rng1, Gumbel(), s),
            "student t" => rand(rng1, TDist(6), s),
            "exponential" => rand(rng1, Exponential(), s),
        ),
    )
    gcf()
end
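
The same shape convention applies when converting raw draws for reuse across ArviZ functions. Here is a minimal sketch (the variable name idata_raw is hypothetical; convert_to_inference_data is re-exported by ArviZ.jl) that turns such a dictionary into an InferenceData object directly:

begin
    # Hypothetical example: each value is a (nchains, ndraws) array, so this
    # yields an InferenceData whose posterior group has 10 chains of 50 draws.
    idata_raw = convert_to_inference_data(Dict("x" => randn(rng1, (10, 50))))
end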

Plotting with MCMCChains.jl's Chains objects produced by Turing.jl

ArviZ is designed to work well with high-dimensional, labelled data. Consider the eight schools model, which roughly tries to measure the effectiveness of SAT classes at eight different schools. To show off ArviZ's labelling, we give the schools the names of eight different schools.

This model is small enough to write down, is hierarchical, and uses labelling. Additionally, a centered parameterization causes divergences (which are interesting for illustration).

First we create our data and set some sampling parameters.

begin
    J = 8
    y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
    σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
    schools = [
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "Phillips Exeter",
        "Hotchkiss",
        "Lawrenceville",
        "St. Paul's",
        "Mt. Hermon",
    ]
    ndraws = 1_000
    ndraws_warmup = 1_000
    nchains = 4
end;

Now we write and run the model using Turing:

Turing.@model function model_turing(y, σ, J=length(y))
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5), 0, Inf)
    θ ~ filldist(Normal(μ, τ), J)
    for i in 1:J
        y[i] ~ Normal(θ[i], σ[i])
    end
end
model_turing (generic function with 3 methods)
rng2 = Random.MersenneTwister(16653);
begin
    param_mod_turing = model_turing(y, σ)
    sampler = NUTS(ndraws_warmup, 0.8)

    turing_chns = Turing.sample(
        rng2, param_mod_turing, sampler, MCMCThreads(), ndraws, nchains
    )
end;

Most ArviZ functions work fine with Chains objects from Turing:

begin
    plot_autocorr(turing_chns; var_names=["μ", "τ"])
    gcf()
end
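
Other diagnostics likewise accept the Chains object without any conversion. A hedged sketch using plot_rank, another plotting function ArviZ.jl re-exports:

begin
    # Hypothetical example: rank plots are a convergence diagnostic that
    # also works directly on the Chains object.
    plot_rank(turing_chns; var_names=["μ", "τ"])
    gcf()
end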

Convert to InferenceData

For much more powerful querying, analysis, and plotting, we can use built-in ArviZ utilities to convert Chains objects to xarray datasets. Note that we are also providing some labelling information.

ArviZ is built to work with InferenceData (a netcdf datastore that loads data into xarray datasets), and the more groups it has access to, the more powerful analyses it can perform.

idata_turing_post = from_mcmcchains(
    turing_chns;
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library="Turing",
)
InferenceData
    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1 2 3 4
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          μ        (chain, draw) float64 11.28 5.035 7.561 2.633 ... 11.14 10.39 10.35
          τ        (chain, draw) float64 2.042 9.241 4.15 ... 0.7697 0.6475 0.8526
          θ        (chain, draw, school) float64 8.807 10.73 5.582 ... 10.66 9.537
      Attributes:
          created_at:         2022-05-11T19:01:37.445336
          arviz_version:      0.12.0
          start_time:         [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
          stop_time:          [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
          inference_library:  Turing

    • Dataset (xarray.Dataset)
      Dimensions:           (chain: 4, draw: 1000)
      Coordinates:
        * chain             (chain) int64 1 2 3 4
        * draw              (draw) int64 1 2 3 4 5 6 7 ... 995 996 997 998 999 1000
      Data variables:
          energy            (chain, draw) float64 64.18 67.03 66.54 ... 50.29 47.94
          energy_error      (chain, draw) float64 0.3701 -0.3943 ... 0.2559 -0.4336
          tree_depth        (chain, draw) int64 3 5 4 4 4 5 5 4 5 ... 4 5 5 3 3 3 3 2
          diverging         (chain, draw) bool False False False ... False False False
          step_size_nom     (chain, draw) float64 0.1954 0.1553 ... 0.1067 0.1067
          acceptance_rate   (chain, draw) float64 0.9558 1.0 0.9922 ... 0.9011 0.7422
          log_density       (chain, draw) float64 -60.71 -64.98 ... -45.81 -44.75
          max_energy_error  (chain, draw) float64 -0.998 -0.7635 ... -0.5656 11.73
          is_accept         (chain, draw) bool True True True True ... True True True
          lp                (chain, draw) float64 -60.71 -64.98 ... -45.81 -44.75
          step_size         (chain, draw) float64 0.1954 0.1553 ... 0.1067 0.1067
          n_steps           (chain, draw) int64 7 31 31 31 31 31 31 ... 47 15 7 7 15 7
      Attributes:
          created_at:         2022-05-11T19:01:37.483477
          arviz_version:      0.12.0
          start_time:         [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
          stop_time:          [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
          inference_library:  Turing

Each group is an ArviZ.Dataset (a thinly wrapped xarray.Dataset). We can view a summary of the dataset.

idata_turing_post.posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 1 2 3 4
  * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 11.28 5.035 7.561 2.633 ... 11.14 10.39 10.35
    τ        (chain, draw) float64 2.042 9.241 4.15 ... 0.7697 0.6475 0.8526
    θ        (chain, draw, school) float64 8.807 10.73 5.582 ... 10.66 9.537
Attributes:
    created_at:         2022-05-11T19:01:37.445336
    arviz_version:      0.12.0
    start_time:         [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
    stop_time:          [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
    inference_library:  Turing

Here is a plot of the trace. Note the intelligent labels.

begin
    plot_trace(idata_turing_post)
    gcf()
end
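
The labels also make it easy to subset plots. As a hedged sketch (var_names and coords are keyword arguments forwarded to the underlying ArviZ functions), we can plot the posterior for a single school:

begin
    # Hypothetical example: select θ for one school by its coordinate label
    plot_posterior(idata_turing_post; var_names=["θ"], coords=Dict("school" => ["Choate"]))
    gcf()
end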

We can also generate summary stats...

summarystats(idata_turing_post)
variable               mean   sd     hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
"μ"                    3.096  3.869  -2.11   10.196   1.064      0.77     14.0      6.0       1.21
"τ"                    3.284  3.132  0.3     8.972    0.688      0.493    10.0      11.0      1.31
"θ[Choate]"            4.735  6.211  -3.834  16.574   1.458      1.048    14.0      7.0       1.2
"θ[Deerfield]"         3.486  5.162  -3.685  13.78    1.271      0.915    16.0      7.0       1.17
"θ[Phillips Andover]"  2.713  5.29   -4.945  13.575   0.969      0.692    29.0      1042.0    1.09
"θ[Phillips Exeter]"   3.359  5.081  -4.326  13.469   1.033      0.739    23.0      1209.0    1.11
"θ[Hotchkiss]"         2.349  4.728  -4.477  12.34    0.84       0.6      30.0      981.0     1.09
"θ[Lawrenceville]"     2.752  5.049  -5.125  12.819   0.923      0.659    27.0      937.0     1.1
"θ[St. Paul's]"        4.815  5.935  -2.904  16.305   1.516      1.093    13.0      6.0       1.22
"θ[Mt. Hermon]"        3.406  5.53   -4.473  14.895   1.085      0.776    23.0      1153.0    1.11

...and examine the energy distribution of the Hamiltonian sampler.

begin
    plot_energy(idata_turing_post)
    gcf()
end

Additional information in Turing.jl

With a few more steps, we can use Turing to compute additional useful groups to add to the InferenceData.

To sample from the prior, one simply calls sample but with the Prior sampler:

prior = Turing.sample(rng2, param_mod_turing, Prior(), ndraws);

To draw from the prior and posterior predictive distributions, we can instantiate a "predictive model", i.e. a Turing model with the observations set to missing, and then call predict on the predictive model with the previously drawn samples:

begin
    # Instantiate the predictive model
    param_mod_predict = model_turing(similar(y, Missing), σ)
    # and then sample!
    prior_predictive = Turing.predict(rng2, param_mod_predict, prior)
    posterior_predictive = Turing.predict(rng2, param_mod_predict, turing_chns)
end;
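
The result of predict is another Chains object whose variables are the predicted observations. As a hypothetical quick check (not part of the original workflow):

# The predictive chains should contain one variable per observation: y[1], ..., y[8]
names(posterior_predictive)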

And to extract the pointwise log-likelihoods, which are useful if you want to compute metrics such as loo:

log_likelihood = let
    log_likelihood = Turing.pointwise_loglikelihoods(
        param_mod_turing, MCMCChains.get_sections(turing_chns, :parameters)
    )
    # Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive`
    ynames = string.(keys(posterior_predictive))
    log_likelihood_y = getindex.(Ref(log_likelihood), ynames)
    # Reshape into `(nchains, ndraws, size(y)...)`
    Dict("y" => permutedims(cat(log_likelihood_y...; dims=3), (2, 1, 3)))
end;
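
As a hypothetical sanity check (not in the original workflow), the reshaped array should have the (nchains, ndraws, J) dimensions noted in the comment above:

# Hypothetical assertion: verify the shape matches what from_mcmcchains expects
@assert size(log_likelihood["y"]) == (nchains, ndraws, J)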

This can then be included in the from_mcmcchains call from above:

idata_turing = from_mcmcchains(
    turing_chns;
    posterior_predictive,
    log_likelihood,
    prior,
    prior_predictive,
    observed_data=Dict("y" => y),
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library="Turing",
)
InferenceData
    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1 2 3 4
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          μ        (chain, draw) float64 11.28 5.035 7.561 2.633 ... 11.14 10.39 10.35
          τ        (chain, draw) float64 2.042 9.241 4.15 ... 0.7697 0.6475 0.8526
          θ        (chain, draw, school) float64 8.807 10.73 5.582 ... 10.66 9.537
      Attributes:
          created_at:         2022-05-11T19:02:12.404184
          arviz_version:      0.12.0
          start_time:         [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
          stop_time:          [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
          inference_library:  Turing

    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1 2 3 4
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 27.75 7.488 22.88 ... 6.854 -10.97
      Attributes:
          start_time:         [None, None, None, None]
          created_at:         2022-05-11T19:02:11.710614
          stop_time:          [None, None, None, None]
          arviz_version:      0.12.0
          inference_library:  Turing

    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1 2 3 4
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 -4.446 -3.259 ... -3.491 -3.819
      Attributes:
          created_at:         2022-05-11T19:02:12.315381
          arviz_version:      0.12.0
          inference_library:  Turing

    • Dataset (xarray.Dataset)
      Dimensions:           (chain: 4, draw: 1000)
      Coordinates:
        * chain             (chain) int64 1 2 3 4
        * draw              (draw) int64 1 2 3 4 5 6 7 ... 995 996 997 998 999 1000
      Data variables:
          energy            (chain, draw) float64 64.18 67.03 66.54 ... 50.29 47.94
          energy_error      (chain, draw) float64 0.3701 -0.3943 ... 0.2559 -0.4336
          tree_depth        (chain, draw) int64 3 5 4 4 4 5 5 4 5 ... 4 5 5 3 3 3 3 2
          diverging         (chain, draw) bool False False False ... False False False
          step_size_nom     (chain, draw) float64 0.1954 0.1553 ... 0.1067 0.1067
          acceptance_rate   (chain, draw) float64 0.9558 1.0 0.9922 ... 0.9011 0.7422
          log_density       (chain, draw) float64 -60.71 -64.98 ... -45.81 -44.75
          max_energy_error  (chain, draw) float64 -0.998 -0.7635 ... -0.5656 11.73
          is_accept         (chain, draw) bool True True True True ... True True True
          lp                (chain, draw) float64 -60.71 -64.98 ... -45.81 -44.75
          step_size         (chain, draw) float64 0.1954 0.1553 ... 0.1067 0.1067
          n_steps           (chain, draw) int64 7 31 31 31 31 31 31 ... 47 15 7 7 15 7
      Attributes:
          created_at:         2022-05-11T19:02:12.409023
          arviz_version:      0.12.0
          start_time:         [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
          stop_time:          [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
          inference_library:  Turing

    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 1, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          μ        (chain, draw) float64 8.628 -1.015 -3.058 ... 0.9541 6.471 8.363
          τ        (chain, draw) float64 4.816 1.663 1.182 19.69 ... 11.82 1.349 10.5
          θ        (chain, draw, school) float64 4.891 5.633 11.72 ... 11.28 11.92
      Attributes:
          created_at:         2022-05-11T19:02:14.119934
          arviz_version:      0.12.0
          start_time:         1652295706.171233
          stop_time:          1652295715.782233
          inference_library:  Turing

    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 1, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 -1.721 12.34 29.75 ... -10.13 18.5
      Attributes:
          created_at:         2022-05-11T19:02:13.938919
          arviz_version:      0.12.0
          inference_library:  Turing

    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 1, draw: 1000)
      Coordinates:
        * chain    (chain) int64 1
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
      Data variables:
          lp       (chain, draw) float64 -61.87 -57.45 -48.72 ... -78.14 -50.08 -63.96
      Attributes:
          created_at:         2022-05-11T19:02:14.155328
          arviz_version:      0.12.0
          start_time:         1652295706.171233
          stop_time:          1652295715.782233
          inference_library:  Turing

    • Dataset (xarray.Dataset)
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:         2022-05-11T19:02:15.581249
          arviz_version:      0.12.0
          inference_library:  Turing

Then we can, for example, compute the expected leave-one-out (LOO) predictive density, which is an estimate of the model's out-of-sample predictive fit:

loo(idata_turing; pointwise=false) # higher is better
loo       loo_se   p_loo    n_samples  n_data_points  warning  loo_scale
-31.2509  1.59635  1.09271  4000       8              true     "log"

If the model is well calibrated, i.e. it replicates the true generative process well, the pointwise LOO-PIT values should be approximately uniformly distributed; their empirical CDF should therefore be close to that of a uniform distribution. This can be inspected visually:

begin
    plot_loo_pit(idata_turing; y="y", ecdf=true)
    gcf()
end
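
Since idata_turing also contains posterior_predictive and observed_data groups, we can compare the predictive draws against the observations directly. A hedged sketch (plot_ppc is another re-exported ArviZ function; data_pairs maps observed to predicted variable names):

begin
    # Hypothetical example: overlay posterior predictive draws on the observed data
    plot_ppc(idata_turing; data_pairs=Dict("y" => "y"), num_pp_samples=100)
    gcf()
end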

Plotting with CmdStan.jl outputs

CmdStan.jl and StanSample.jl also default to producing Chains outputs, and we can easily plot these chains.

Here is the same centered eight schools model:

begin
    schools_code = """
    data {
      int<lower=0> J;
      real y[J];
      real<lower=0> sigma[J];
    }

    parameters {
      real mu;
      real<lower=0> tau;
      real theta[J];
    }

    model {
      mu ~ normal(0, 5);
      tau ~ cauchy(0, 5);
      theta ~ normal(mu, tau);
      y ~ normal(theta, sigma);
    }

    generated quantities {
        vector[J] log_lik;
        vector[J] y_hat;
        for (j in 1:J) {
            log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
            y_hat[j] = normal_rng(theta[j], sigma[j]);
        }
    }
    """

    schools_data = Dict("J" => J, "y" => y, "sigma" => σ)
    stan_chns = mktempdir() do path
        stan_model = Stanmodel(;
            model=schools_code,
            name="schools",
            nchains,
            num_warmup=ndraws_warmup,
            num_samples=ndraws,
            output_format=:mcmcchains,
            random=CmdStan.Random(28983),
            tmpdir=path,
        )
        _, chns, _ = stan(stan_model, schools_data; summary=false)
        return chns
    end
end;
begin
    plot_density(stan_chns; var_names=["mu", "tau"])
    gcf()
end

Again, converting to InferenceData gives us much richer labelling and mixing of data. Note that we're using the same from_cmdstan function ArviZ uses to process cmdstan output files, but thanks to Julia's multiple dispatch, passing a Chains object instead invokes ArviZ.jl's overloads, which forward to from_mcmcchains.

idata_stan = from_cmdstan(
    stan_chns;
    posterior_predictive="y_hat",
    observed_data=Dict("y" => schools_data["y"]),
    log_likelihood="log_lik",
    coords=Dict("school" => schools),
    dims=Dict(
        "y" => ["school"],
        "sigma" => ["school"],
        "theta" => ["school"],
        "log_lik" => ["school"],
        "y_hat" => ["school"],
    ),
)
InferenceData
    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1 2 3 4
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          theta    (chain, draw, school) float64 6.541 4.165 1.673 ... 7.588 7.146
          tau      (chain, draw) float64 2.026 3.44 2.047 ... 0.5002 0.5771 0.5771
          mu       (chain, draw) float64 5.864 4.849 3.906 5.24 ... 7.369 7.617 7.617
      Attributes:
          created_at:         2022-05-11T19:02:53.685083
          arviz_version:      0.12.0
          inference_library:  CmdStan

    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1 2 3 4
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y_hat    (chain, draw, school) float64 -10.71 23.66 -9.425 ... 9.078 20.88
      Attributes:
          created_at:         2022-05-11T19:02:53.680765
          arviz_version:      0.12.0
          inference_library:  CmdStan

    • Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 1 2 3 4
        * draw     (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          log_lik  (chain, draw, school) float64 -4.65 -3.295 -3.734 ... -3.764 -3.846
      Attributes:
          created_at:         2022-05-11T19:02:53.682821
          arviz_version:      0.12.0
          inference_library:  CmdStan

    • Dataset (xarray.Dataset)
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 1 2 3 4
        * draw             (draw) int64 1 2 3 4 5 6 7 ... 994 995 996 997 998 999 1000
      Data variables:
          tree_depth       (chain, draw) int64 4 2 4 3 3 3 4 4 4 ... 1 1 2 1 2 1 5 3 2
          diverging        (chain, draw) bool False False False ... False True False
          energy           (chain, draw) float64 21.85 17.36 18.12 ... 11.01 16.54
          lp               (chain, draw) float64 -11.6 -14.53 -12.58 ... -6.09 -6.09
          step_size        (chain, draw) float64 0.2012 0.2012 ... 0.1493 0.1493
          acceptance_rate  (chain, draw) float64 0.9582 0.8976 ... 0.04874 1.587e-05
          n_steps          (chain, draw) int64 15 7 15 15 7 15 15 ... 5 1 5 3 31 9 3
      Attributes:
          created_at:         2022-05-11T19:02:53.688805
          arviz_version:      0.12.0
          inference_library:  CmdStan

    • Dataset (xarray.Dataset)
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:         2022-05-11T19:02:53.692047
          arviz_version:      0.12.0
          inference_library:  CmdStan

Here is a plot showing where the Hamiltonian sampler had divergences:

begin
    plot_pair(
        idata_stan;
        coords=Dict("school" => ["Choate", "Deerfield", "Phillips Andover"]),
        divergences=true,
    )
    gcf()
end
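
Finally, since both idata_turing and idata_stan carry log_likelihood groups, the two fits could be ranked against each other. A hedged closing sketch (compare is ArviZ's model-comparison function, re-exported by ArviZ.jl; the Dict keys are arbitrary labels):

# Hypothetical example: rank the models by their estimated predictive accuracy (LOO by default)
compare(Dict("turing" => idata_turing, "stan" => idata_stan))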