Here we add the necessary packages for this notebook and load a few we will use throughout.
begin
using ArviZ, CmdStan, Distributions, LinearAlgebra, PyPlot, Random, Soss, Turing
using Soss.MeasureTheory: HalfCauchy
using SampleChainsDynamicHMC: getchains, dynamichmc
end
# ArviZ ships with style sheets!
ArviZ.use_style("arviz-darkgrid")
ArviZ.jl is designed to be used with libraries like CmdStan, Turing.jl, and Soss.jl, but it also works fine with raw arrays.
rng1 = Random.MersenneTwister(37772);
begin
plot_posterior(randn(rng1, 100_000))
gcf()
end
When plotting a dictionary of arrays, ArviZ.jl interprets each key as the name of a different random variable. Each row of an array is treated as an independent series of draws from the variable, called a chain. Below, we have 10 chains of 50 draws each for four different distributions.
let
s = (10, 50)
plot_forest(
Dict(
"normal" => randn(rng1, s),
"gumbel" => rand(rng1, Gumbel(), s),
"student t" => rand(rng1, TDist(6), s),
"exponential" => rand(rng1, Exponential(), s),
),
)
gcf()
end
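To make the row-per-chain convention concrete, here is a minimal sketch that uses only the Julia standard library. The seed and the names `draws` and `chain_means` are hypothetical and are not part of the notebook.

```julia
using Random, Statistics

rng = MersenneTwister(0)              # hypothetical seed, unrelated to rng1 above
nchains, ndraws = 10, 50
draws = randn(rng, nchains, ndraws)   # row i holds the draws of chain i

# Reducing over the second dimension gives one summary value per chain.
chain_means = vec(mean(draws; dims=2))
@show size(draws) length(chain_means)
```

Each of the 10 entries of `chain_means` summarizes one row, which is the same layout that the dictionary values passed to plot_forest use.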
ArviZ is designed to work well with high-dimensional, labelled data. Consider the eight schools model, which roughly tries to measure the effectiveness of SAT classes at eight different schools. To show off ArviZ's labelling, I give the schools the names of a different eight schools.
This model is small enough to write down, is hierarchical, and uses labelling. Additionally, a centered parameterization causes divergences (which are interesting for illustration).
First we create our data and set some sampling parameters.
begin
J = 8
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
schools = [
"Choate",
"Deerfield",
"Phillips Andover",
"Phillips Exeter",
"Hotchkiss",
"Lawrenceville",
"St. Paul's",
"Mt. Hermon",
]
ndraws = 1_000
ndraws_warmup = 1_000
nchains = 4
end;
Now we write and run the model using Turing:
Turing.@model function model_turing(y, σ, J=length(y))
μ ~ Normal(0, 5)
τ ~ truncated(Cauchy(0, 5), 0, Inf)
θ ~ filldist(Normal(μ, τ), J)
for i in 1:J
y[i] ~ Normal(θ[i], σ[i])
end
end
model_turing (generic function with 3 methods)
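Written out, the Turing model above corresponds to the hierarchy below. As in Distributions.jl, the second argument of each normal distribution is a standard deviation, and truncating Cauchy(0, 5) to the non-negative reals gives a half-Cauchy distribution with scale 5.

```latex
\begin{aligned}
\mu &\sim \mathcal{N}(0, 5) \\
\tau &\sim \text{Half-Cauchy}(0, 5) \\
\theta_j &\sim \mathcal{N}(\mu, \tau), \quad j = 1, \dots, J \\
y_j &\sim \mathcal{N}(\theta_j, \sigma_j), \quad j = 1, \dots, J
\end{aligned}
```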
rng2 = Random.MersenneTwister(16653);
begin
param_mod_turing = model_turing(y, σ)
sampler = NUTS(ndraws_warmup, 0.8)
turing_chns = Turing.sample(
rng2, param_mod_turing, sampler, MCMCThreads(), ndraws, nchains
)
end;
Most ArviZ functions work fine with Chains objects from Turing:
begin
plot_autocorr(turing_chns; var_names=["μ", "τ"])
gcf()
end
Convert to InferenceData
For much more powerful querying, analysis and plotting, we can use built-in ArviZ utilities to convert Chains objects to xarray datasets. Note we are also giving some information about labelling.
ArviZ is built to work with InferenceData (a NetCDF datastore that loads data into xarray datasets), and the more groups it has access to, the more powerful analyses it can perform.
idata_turing_post = from_mcmcchains(
turing_chns;
coords=Dict("school" => schools),
dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
library="Turing",
)
posterior
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
μ (chain, draw) float64 11.28 5.035 7.561 2.633 ... 11.14 10.39 10.35
τ (chain, draw) float64 2.042 9.241 4.15 ... 0.7697 0.6475 0.8526
θ (chain, draw, school) float64 8.807 10.73 5.582 ... 10.66 9.537
Attributes:
created_at: 2022-05-11T19:01:37.445336
arviz_version: 0.12.0
start_time: [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
stop_time: [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
inference_library: Turing
sample_stats
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 ... 995 996 997 998 999 1000
Data variables:
energy (chain, draw) float64 64.18 67.03 66.54 ... 50.29 47.94
energy_error (chain, draw) float64 0.3701 -0.3943 ... 0.2559 -0.4336
tree_depth (chain, draw) int64 3 5 4 4 4 5 5 4 5 ... 4 5 5 3 3 3 3 2
diverging (chain, draw) bool False False False ... False False False
step_size_nom (chain, draw) float64 0.1954 0.1553 ... 0.1067 0.1067
acceptance_rate (chain, draw) float64 0.9558 1.0 0.9922 ... 0.9011 0.7422
log_density (chain, draw) float64 -60.71 -64.98 ... -45.81 -44.75
max_energy_error (chain, draw) float64 -0.998 -0.7635 ... -0.5656 11.73
is_accept (chain, draw) bool True True True True ... True True True
lp (chain, draw) float64 -60.71 -64.98 ... -45.81 -44.75
step_size (chain, draw) float64 0.1954 0.1553 ... 0.1067 0.1067
n_steps (chain, draw) int64 7 31 31 31 31 31 31 ... 47 15 7 7 15 7
Attributes:
created_at: 2022-05-11T19:01:37.483477
arviz_version: 0.12.0
start_time: [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
stop_time: [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
inference_library: Turing
Each group is an ArviZ.Dataset (a thinly wrapped xarray.Dataset). We can view a summary of the dataset.
idata_turing_post.posterior
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
μ (chain, draw) float64 11.28 5.035 7.561 2.633 ... 11.14 10.39 10.35
τ (chain, draw) float64 2.042 9.241 4.15 ... 0.7697 0.6475 0.8526
θ (chain, draw, school) float64 8.807 10.73 5.582 ... 10.66 9.537
Attributes:
created_at: 2022-05-11T19:01:37.445336
arviz_version: 0.12.0
start_time: [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
stop_time: [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
inference_library: Turing
Here is a plot of the trace. Note the intelligent labels.
begin
plot_trace(idata_turing_post)
gcf()
end
We can also generate summary stats...
summarystats(idata_turing_post)
variable               mean   sd     hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
"μ"                    3.096  3.869  -2.11   10.196   1.064      0.77     14.0      6.0       1.21
"τ"                    3.284  3.132  0.3     8.972    0.688      0.493    10.0      11.0      1.31
"θ[Choate]"            4.735  6.211  -3.834  16.574   1.458      1.048    14.0      7.0       1.2
"θ[Deerfield]"         3.486  5.162  -3.685  13.78    1.271      0.915    16.0      7.0       1.17
"θ[Phillips Andover]"  2.713  5.29   -4.945  13.575   0.969      0.692    29.0      1042.0    1.09
"θ[Phillips Exeter]"   3.359  5.081  -4.326  13.469   1.033      0.739    23.0      1209.0    1.11
"θ[Hotchkiss]"         2.349  4.728  -4.477  12.34    0.84       0.6      30.0      981.0     1.09
"θ[Lawrenceville]"     2.752  5.049  -5.125  12.819   0.923      0.659    27.0      937.0     1.1
"θ[St. Paul's]"        4.815  5.935  -2.904  16.305   1.516      1.093    13.0      6.0       1.22
"θ[Mt. Hermon]"        3.406  5.53   -4.473  14.895   1.085      0.776    23.0      1153.0    1.11
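The r_hat column compares variation between chains with variation within chains, and values noticeably above 1 (as in the summary here) suggest the chains have not mixed. The sketch below computes the classic split-R̂ on a `(nchains, ndraws)` matrix using only the standard library. It is a simplified illustration: as far as I know ArviZ reports a rank-normalized variant, so the numbers will not match exactly. The function name `split_rhat` and the toy arrays are hypothetical.

```julia
using Random, Statistics

# Classic (not rank-normalized) split-R̂ for a (nchains, ndraws) matrix.
function split_rhat(draws::AbstractMatrix{<:Real})
    ndraws = size(draws, 2)
    half = ndraws ÷ 2
    # Treat the first and second half of every chain as separate chains.
    halves = vcat(draws[:, 1:half], draws[:, (ndraws - half + 1):ndraws])
    n = size(halves, 2)
    chain_means = vec(mean(halves; dims=2))
    chain_vars = vec(var(halves; dims=2))
    B = n * var(chain_means)           # between-chain variance
    W = mean(chain_vars)               # within-chain variance
    var_plus = (n - 1) / n * W + B / n # pooled variance estimate
    return sqrt(var_plus / W)
end

well_mixed = randn(MersenneTwister(1), 4, 1000)
stuck = well_mixed .+ [0.0, 0.0, 3.0, 3.0]  # two chains shifted away from the others
@show split_rhat(well_mixed) split_rhat(stuck)
```

The well-mixed draws give a value close to 1, while the shifted chains give a value well above it.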
...and examine the energy distribution of the Hamiltonian sampler.
begin
plot_energy(idata_turing_post)
gcf()
end
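One number often read alongside the energy plot is the energy Bayesian fraction of missing information (BFMI), which compares the size of energy changes between consecutive draws with the overall spread of the energy. The sketch below is a standard-library illustration on synthetic energies, not on the `energy` values in sample_stats. The helper name `bfmi` and both toy series are hypothetical.

```julia
using Random, Statistics

# BFMI for the energies of a single chain:
# mean squared change between consecutive draws, divided by the variance.
bfmi(energy::AbstractVector{<:Real}) =
    mean(abs2, diff(energy)) / var(energy; corrected=false)

rng = MersenneTwister(0)
iid_energy = randn(rng, 1_000)             # independent energies
sticky_energy = cumsum(randn(rng, 1_000))  # slowly wandering energies (random walk)
@show bfmi(iid_energy) bfmi(sticky_energy)
```

Low values (a commonly quoted rule of thumb is below roughly 0.3) indicate that the sampler explores the energy distribution slowly.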
Additional information in Turing.jl
With a few more steps, we can use Turing to compute additional useful groups to add to the InferenceData.
To sample from the prior, one simply calls sample but with the Prior sampler:
prior = Turing.sample(rng2, param_mod_turing, Prior(), ndraws);
To draw from the prior and posterior predictive distributions, we can instantiate a "predictive model", i.e. a Turing model with the observations set to missing, and then call predict on the predictive model and the previously drawn samples:
begin
# Instantiate the predictive model
param_mod_predict = model_turing(similar(y, Missing), σ)
# and then sample!
prior_predictive = Turing.predict(rng2, param_mod_predict, prior)
posterior_predictive = Turing.predict(rng2, param_mod_predict, turing_chns)
end;
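The call `similar(y, Missing)` builds a vector with the same shape as the observations but with element type `Missing`, so every entry is `missing`. A minimal standard-library sketch, where `y_obs` is a hypothetical stand-in for the observed vector:

```julia
y_obs = [28.0, 8.0, -3.0]            # hypothetical stand-in for the observed `y`
y_missing = similar(y_obs, Missing)  # same shape, element type Missing

@show typeof(y_missing)
@show all(ismissing, y_missing)
```

Because nothing is observed, the predictive model samples `y` rather than conditioning on it.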
We can also extract the pointwise log-likelihoods, which are useful for computing metrics such as loo:
log_likelihood = let
log_likelihood = Turing.pointwise_loglikelihoods(
param_mod_turing, MCMCChains.get_sections(turing_chns, :parameters)
)
# Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive`
ynames = string.(keys(posterior_predictive))
log_likelihood_y = getindex.(Ref(log_likelihood), ynames)
# Reshape into `(nchains, ndraws, size(y)...)`
Dict("y" => permutedims(cat(log_likelihood_y...; dims=3), (2, 1, 3)))
end;
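The reshaping step is easier to follow on toy arrays. The sketch below assumes, as the `permutedims` call and its comment imply, that each per-observation entry is an `(ndraws, nchains)` matrix. All names and sizes are hypothetical.

```julia
ndraws_toy, nchains_toy, J_toy = 5, 2, 3  # hypothetical toy sizes

# One (ndraws, nchains) matrix per observation, mimicking `log_likelihood_y`.
per_obs = [fill(Float64(j), ndraws_toy, nchains_toy) for j in 1:J_toy]

stacked = cat(per_obs...; dims=3)            # (ndraws, nchains, J)
reordered = permutedims(stacked, (2, 1, 3))  # (nchains, ndraws, J)

@show size(stacked) size(reordered)
```

After the permutation, the leading dimensions are chain and draw, and the trailing dimension indexes the observations.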
idata_turing = from_mcmcchains(
turing_chns;
posterior_predictive,
log_likelihood,
prior,
prior_predictive,
observed_data=Dict("y" => y),
coords=Dict("school" => schools),
dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
library="Turing",
)
posterior
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
μ (chain, draw) float64 11.28 5.035 7.561 2.633 ... 11.14 10.39 10.35
τ (chain, draw) float64 2.042 9.241 4.15 ... 0.7697 0.6475 0.8526
θ (chain, draw, school) float64 8.807 10.73 5.582 ... 10.66 9.537
Attributes:
created_at: 2022-05-11T19:02:12.404184
arviz_version: 0.12.0
start_time: [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
stop_time: [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
inference_library: Turing
posterior_predictive
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
y (chain, draw, school) float64 27.75 7.488 22.88 ... 6.854 -10.97
Attributes:
start_time: [None, None, None, None]
created_at: 2022-05-11T19:02:11.710614
stop_time: [None, None, None, None]
arviz_version: 0.12.0
inference_library: Turing
log_likelihood
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
y (chain, draw, school) float64 -4.446 -3.259 ... -3.491 -3.819
Attributes:
created_at: 2022-05-11T19:02:12.315381
arviz_version: 0.12.0
inference_library: Turing
sample_stats
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 ... 995 996 997 998 999 1000
Data variables:
energy (chain, draw) float64 64.18 67.03 66.54 ... 50.29 47.94
energy_error (chain, draw) float64 0.3701 -0.3943 ... 0.2559 -0.4336
tree_depth (chain, draw) int64 3 5 4 4 4 5 5 4 5 ... 4 5 5 3 3 3 3 2
diverging (chain, draw) bool False False False ... False False False
step_size_nom (chain, draw) float64 0.1954 0.1553 ... 0.1067 0.1067
acceptance_rate (chain, draw) float64 0.9558 1.0 0.9922 ... 0.9011 0.7422
log_density (chain, draw) float64 -60.71 -64.98 ... -45.81 -44.75
max_energy_error (chain, draw) float64 -0.998 -0.7635 ... -0.5656 11.73
is_accept (chain, draw) bool True True True True ... True True True
lp (chain, draw) float64 -60.71 -64.98 ... -45.81 -44.75
step_size (chain, draw) float64 0.1954 0.1553 ... 0.1067 0.1067
n_steps (chain, draw) int64 7 31 31 31 31 31 31 ... 47 15 7 7 15 7
Attributes:
created_at: 2022-05-11T19:02:12.409023
arviz_version: 0.12.0
start_time: [1.65229565e+09 1.65229568e+09 1.65229565e+09 1.65229...
stop_time: [1.65229568e+09 1.65229568e+09 1.65229567e+09 1.65229...
inference_library: Turing
prior
Dataset (xarray.Dataset)
Dimensions: (chain: 1, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
μ (chain, draw) float64 8.628 -1.015 -3.058 ... 0.9541 6.471 8.363
τ (chain, draw) float64 4.816 1.663 1.182 19.69 ... 11.82 1.349 10.5
θ (chain, draw, school) float64 4.891 5.633 11.72 ... 11.28 11.92
Attributes:
created_at: 2022-05-11T19:02:14.119934
arviz_version: 0.12.0
start_time: 1652295706.171233
stop_time: 1652295715.782233
inference_library: Turing
3.27683384e+00, 2.59537756e+01, 7.36326338e-01, 1.21042832e+00,
2.20693451e+00, 1.18247139e+01, 1.34872429e+00, 1.04991290e+01]]) θ
(chain, draw, school)
float64
4.891 5.633 11.72 ... 11.28 11.92
array([[[ 4.89118403, 5.63279044, 11.72307872, ..., 3.58463328,
1.34299062, 12.34018133],
[ 1.76361596, 0.69992405, -2.6805847 , ..., -3.43031676,
-0.1978057 , 0.33155076],
[ -3.45086955, -2.19304403, -4.35137581, ..., -2.32592113,
-3.63689139, -3.35476538],
...,
[-22.73968927, -7.10589447, 4.01249395, ..., 6.29312483,
-9.72558173, 18.99869921],
[ 7.96437005, 5.8342077 , 4.71117318, ..., 6.03884512,
6.45013941, 4.96234412],
[ 9.51433273, 11.44490164, -2.09516888, ..., 9.57047535,
11.27609811, 11.9192567 ]]]) Attributes: (5)
created_at : 2022-05-11T19:02:14.119934 arviz_version : 0.12.0 start_time : 1652295706.171233 stop_time : 1652295715.782233 inference_library : Turing
prior_predictive
Dataset (xarray.Dataset)
Dimensions: (chain: 1, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
y (chain, draw, school) float64 -1.721 12.34 29.75 ... -10.13 18.5
Attributes:
created_at: 2022-05-11T19:02:13.938919
arviz_version: 0.12.0
inference_library: Turing
sample_stats_prior
Dataset (xarray.Dataset)
Dimensions: (chain: 1, draw: 1000)
Coordinates:
* chain (chain) int64 1
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
Data variables:
lp (chain, draw) float64 -61.87 -57.45 -48.72 ... -78.14 -50.08 -63.96
Attributes:
created_at: 2022-05-11T19:02:14.155328
arviz_version: 0.12.0
start_time: 1652295706.171233
stop_time: 1652295715.782233
inference_library: Turing
observed_data
Dataset (xarray.Dataset)
Dimensions: (school: 8)
Coordinates:
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
y (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes:
created_at: 2022-05-11T19:02:15.581249
arviz_version: 0.12.0
inference_library: Turing
Then we can, for example, compute the expected log pointwise predictive density using leave-one-out (LOO) cross-validation, which estimates the out-of-sample predictive fit of the model:
loo(idata_turing; pointwise=false) # higher is better
loo            -31.2509
loo_se         1.59635
p_loo          1.09271
n_samples      4000
n_data_points  8
warning        true
loo_scale      "log"
If the model is well calibrated, i.e. it replicates the true generative process well, the LOO probability integral transform (PIT) values, which are the leave-one-out predictive CDFs evaluated at each observation, should be approximately uniformly distributed. This can be inspected visually:
begin
    plot_loo_pit(idata_turing; y="y", ecdf=true)
    gcf()
end
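To make the two quantities above more concrete, here is a minimal sketch of a raw (unsmoothed) importance-sampling version of LOO and LOO-PIT for a single observation. It is not how ArviZ computes them: ArviZ additionally applies Pareto smoothing to the importance weights, and the function names and toy inputs below are made up for illustration.

```julia
# Numerically stable log(sum(exp.(x))).
function logsumexp(x)
    m = maximum(x)
    return m + log(sum(exp.(x .- m)))
end

# Raw importance weights for leaving out one observation.
# `log_lik` holds log p(y_i | θ_s) for each posterior draw s.
function naive_loo_weights(log_lik::AbstractVector)
    log_w = -log_lik                          # w_s ∝ 1 / p(y_i | θ_s)
    return exp.(log_w .- logsumexp(log_w))    # normalized to sum to 1
end

# Pointwise estimate: elpd_i = log(1 / mean_s(1 / p(y_i | θ_s))).
naive_elpd_loo(log_lik::AbstractVector) = log(length(log_lik)) - logsumexp(-log_lik)

# LOO-PIT_i = Σ_s w_s · 1[y_rep_s ≤ y_i].
function naive_loo_pit(y_obs::Real, y_rep::AbstractVector, log_lik::AbstractVector)
    w = naive_loo_weights(log_lik)
    return sum(w .* (y_rep .<= y_obs))
end

# Toy inputs (hypothetical, not taken from the model above).
log_lik = fill(-2.0, 4)          # identical log-likelihood in every draw
y_rep = [1.0, 2.0, 3.0, 4.0]     # posterior predictive draws for one observation
elpd = naive_elpd_loo(log_lik)
pit = naive_loo_pit(2.5, y_rep, log_lik)
println((elpd=elpd, pit=pit))
```

Summing the pointwise values over all observations gives the total estimate. If the model is well calibrated, the PIT values collected over all observations should look roughly uniform on [0, 1].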
CmdStan.jl and StanSample.jl also default to producing Chains outputs, and we can easily plot these chains.
Here is the same centered eight schools model:
begin
    schools_code = """
    data {
      int<lower=0> J;
      real y[J];
      real<lower=0> sigma[J];
    }
    parameters {
      real mu;
      real<lower=0> tau;
      real theta[J];
    }
    model {
      mu ~ normal(0, 5);
      tau ~ cauchy(0, 5);
      theta ~ normal(mu, tau);
      y ~ normal(theta, sigma);
    }
    generated quantities {
      vector[J] log_lik;
      vector[J] y_hat;
      for (j in 1:J) {
        log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
        y_hat[j] = normal_rng(theta[j], sigma[j]);
      }
    }
    """
    schools_data = Dict("J" => J, "y" => y, "sigma" => σ)
    stan_chns = mktempdir() do path
        stan_model = Stanmodel(;
            model=schools_code,
            name="schools",
            nchains,
            num_warmup=ndraws_warmup,
            num_samples=ndraws,
            output_format=:mcmcchains,
            random=CmdStan.Random(28983),
            tmpdir=path,
        )
        _, chns, _ = stan(stan_model, schools_data; summary=false)
        return chns
    end
end;
begin
    plot_density(stan_chns; var_names=["mu", "tau"])
    gcf()
end
Again, converting to InferenceData, we can get much richer labelling and mixing of data. Note that we're using the same from_cmdstan function used by ArviZ to process CmdStan output files, but through the power of dispatch in Julia, if we pass a Chains object, it instead uses ArviZ.jl's overloads, which forward to from_mcmcchains.
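The dispatch mechanism mentioned here can be sketched with a toy example. The type ToyChains and the function describe_source are hypothetical names invented for this illustration; they are not part of ArviZ.jl or MCMCChains, and the sketch only shows how one function name can take different code paths depending on the argument type.

```julia
# Hypothetical stand-in for an in-memory chains object (not MCMCChains.Chains).
struct ToyChains
    draws::Matrix{Float64}
end

# Method 1: a string argument is treated as a path to an output file.
describe_source(path::AbstractString) = "would parse output file at $path"

# Method 2: an in-memory chains object takes a different code path.
describe_source(chains::ToyChains) = "would convert $(size(chains.draws, 1)) in-memory draws"

println(describe_source("output.csv"))
println(describe_source(ToyChains(zeros(10, 2))))
```

Julia selects the method from the runtime type of the argument, so callers use a single function name for both kinds of input.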
idata_stan = from_cmdstan(
    stan_chns;
    posterior_predictive="y_hat",
    observed_data=Dict("y" => schools_data["y"]),
    log_likelihood="log_lik",
    coords=Dict("school" => schools),
    dims=Dict(
        "y" => ["school"],
        "sigma" => ["school"],
        "theta" => ["school"],
        "log_lik" => ["school"],
        "y_hat" => ["school"],
    ),
)
posterior
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
theta (chain, draw, school) float64 6.541 4.165 1.673 ... 7.588 7.146
tau (chain, draw) float64 2.026 3.44 2.047 ... 0.5002 0.5771 0.5771
mu (chain, draw) float64 5.864 4.849 3.906 5.24 ... 7.369 7.617 7.617
Attributes:
created_at: 2022-05-11T19:02:53.685083
arviz_version: 0.12.0
inference_library: CmdStan
posterior_predictive
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
y_hat (chain, draw, school) float64 -10.71 23.66 -9.425 ... 9.078 20.88
Attributes:
created_at: 2022-05-11T19:02:53.680765
arviz_version: 0.12.0
inference_library: CmdStan
log_likelihood
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 8 9 ... 993 994 995 996 997 998 999 1000
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
log_lik (chain, draw, school) float64 -4.65 -3.295 -3.734 ... -3.764 -3.846
Attributes:
created_at: 2022-05-11T19:02:53.682821
arviz_version: 0.12.0
inference_library: CmdStan
sample_stats
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000)
Coordinates:
* chain (chain) int64 1 2 3 4
* draw (draw) int64 1 2 3 4 5 6 7 ... 994 995 996 997 998 999 1000
Data variables:
tree_depth (chain, draw) int64 4 2 4 3 3 3 4 4 4 ... 1 1 2 1 2 1 5 3 2
diverging (chain, draw) bool False False False ... False True False
energy (chain, draw) float64 21.85 17.36 18.12 ... 11.01 16.54
lp (chain, draw) float64 -11.6 -14.53 -12.58 ... -6.09 -6.09
step_size (chain, draw) float64 0.2012 0.2012 ... 0.1493 0.1493
acceptance_rate (chain, draw) float64 0.9582 0.8976 ... 0.04874 1.587e-05
n_steps (chain, draw) int64 15 7 15 15 7 15 15 ... 5 1 5 3 31 9 3
Attributes:
created_at: 2022-05-11T19:02:53.688805
arviz_version: 0.12.0
inference_library: CmdStan
observed_data
Dataset (xarray.Dataset)
Dimensions: (school: 8)
Coordinates:
* school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
y (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes:
created_at: 2022-05-11T19:02:53.692047
arviz_version: 0.12.0
inference_library: CmdStan
Here is a plot showing where the Hamiltonian sampler had divergences:
begin
    plot_pair(
        idata_stan;
        coords=Dict("school" => ["Choate", "Deerfield", "Phillips Andover"]),
        divergences=true,
    )
    gcf()
end
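Alongside the plot, it can help to count how many divergent transitions each chain produced. The sketch below works on a plain Bool matrix with the same (chain, draw) layout as the diverging variable in sample_stats; the values are made up for illustration, and extracting the real array from idata_stan is not shown here.

```julia
# Toy (chain × draw) divergence flags, made up for illustration.
diverging = Bool[
    0 0 1 0 0
    0 1 1 0 0
    0 0 0 0 0
    1 0 0 0 1
]

# Number of divergent transitions in each chain (sum over the draw dimension).
per_chain = vec(sum(diverging; dims=2))

# Overall fraction of draws that diverged.
fraction = sum(diverging) / length(diverging)

println("divergences per chain: ", per_chain)
println("fraction divergent: ", fraction)
```

A chain with many more divergences than the others often points to a region of the posterior that the sampler struggles to explore, which is typical of the centered parameterization used here.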