Types and utility functions for summarizing Markov chain Monte Carlo simulations




Implementation of Julia types for summarizing MCMC simulations and utility functions for diagnostics and visualizations.


The following simple example illustrates how to use Chains to visually summarize an MCMC simulation:

using MCMCChains
using StatsPlots


# Define the experiment
n_iter = 500
n_name = 3
n_chain = 2

# experiment results
val = randn(n_iter, n_name, n_chain) .+ [1, 2, 3]'
val = hcat(val, rand(1:2, n_iter, 1, n_chain))

# construct a Chains object
chn = Chains(val)

# visualize the MCMC simulation results
p1 = plot(chn)
p2 = plot(chn, colordim = :parameter)

This code produces the visualizations shown below. Note that plot accepts the additional keyword arguments described in the Plots.jl package.

Figure: Summarize parameters, plot(chn; colordim = :chain) (p1)
Figure: Summarize chains, plot(chn; colordim = :parameter) (p2)


Chains type

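# construction of a Chains object with no names
# (val is an iterations x parameters x chains array)
Chains(val;
    evidence = 0.0,
    # ... other keyword arguments
)

# construction of a Chains object with parameter names
Chains(val, parameter_names,
    name_map = (parameters = parameter_names,);
    evidence = 0.0,
    # ... other keyword arguments
)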

# Indexing a Chains object
chn = Chains(...)
chn_param1 = chn[:,2,:] # returns a new Chains object for parameter 2
chn[:,2,:] = ... # set values for parameter 2

Parameter Names

Chains can be constructed with parameter names, like so:

# 500 samples, 5 parameters, two chains.
val = rand(500, 5, 2)

chn = Chains(val, ["a", "b", "c", "d", "e"])

By default, parameters will be given the name :param_i, where i is the parameter number.

Rename Parameters

Parameter names can be changed with the function replacenames, which accepts a Chains object and pairs of old and new parameter names. Note that replacenames creates a new Chains object that shares the same underlying data.

chn = Chains(
    rand(100, 5, 5),
    ["one", "two", "three", "four", "five"],
    Dict(:internals => ["four", "five"])
)

# Set "one" and "five" to uppercase.
chn2 = replacenames(chn, "one" => "ONE", "five" => "FIVE")

# Alternatively you can provide a dictionary.
chn3 = replacenames(chn, Dict("two" => "TWO", "four" => "FOUR"))
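The renaming rule itself is simple: each name that appears as an old name is swapped for its new name, and every other name is kept. The sketch below expresses that rule on a plain vector of strings. rename_params is a hypothetical helper for illustration, not part of MCMCChains, and it does not model how replacenames handles sections or shared data.

```julia
# Hypothetical helper (not MCMCChains API): apply old => new pairs to a
# vector of parameter names; names without a mapping are left unchanged.
function rename_params(pnames::Vector{String}, pairs::Pair{String,String}...)
    mapping = Dict(pairs)
    return [get(mapping, n, n) for n in pnames]
end

# Dictionary form, mirroring the two call styles of replacenames.
rename_params(pnames::Vector{String}, mapping::Dict{String,String}) =
    rename_params(pnames, mapping...)

pnames = ["one", "two", "three", "four", "five"]
renamed = rename_params(pnames, "one" => "ONE", "five" => "FIVE")
renamed_dict = rename_params(pnames, Dict("two" => "TWO", "four" => "FOUR"))
```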


Sections

Chains parameters are sorted into sections that represent groups of parameters. By default, every chain contains a :parameters section, to which all unassigned parameters belong. A name map can be passed during construction:

chn = Chains(val,
  ["a", "b", "c", "d", "e"],
  Dict(:internals => ["d", "e"]))

Alternatively, the set_section function assigns sections to an existing chain and returns a new Chains object:

chn2 = set_section(chn, Dict(:internals => ["d", "e"]))

Any parameters not assigned will be placed into :parameters.
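The section rule can be sketched on plain vectors: names listed in the name map go to their section, and every other name falls into :parameters. assign_sections is a hypothetical helper for illustration, not part of MCMCChains.

```julia
# Hypothetical helper (not MCMCChains API): names listed in name_map go to
# their section; every remaining name is placed in :parameters.
function assign_sections(pnames::Vector{String}, name_map::Dict{Symbol,Vector{String}})
    sections = Dict{Symbol,Vector{String}}()
    assigned = Set{String}()
    for (section, members) in name_map
        sections[section] = copy(members)
        union!(assigned, members)
    end
    rest = [n for n in pnames if !(n in assigned)]
    sections[:parameters] = vcat(get(sections, :parameters, String[]), rest)
    return sections
end

secs = assign_sections(["a", "b", "c", "d", "e"], Dict(:internals => ["d", "e"]))
```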

Calling show(chn) provides the following output:

Log evidence      = 0.0
Iterations        = 1:500
Thinning interval = 1
Chains            = 1, 2, 3
Samples per chain = 500
parameters        = c, b, a

Empirical Posterior Estimates
   Mean    SD   Naive SE  MCSE  ESS
a 0.5169 0.2920   0.0075 0.0066 500
b 0.4891 0.2929   0.0076 0.0070 500
c 0.5102 0.2840   0.0073 0.0068 500

   2.5%   25.0%  50.0%  75.0%  97.5%
a 0.0001 0.2620 0.5314 0.7774 0.9978
b 0.0001 0.2290 0.4972 0.7365 0.9998
c 0.0004 0.2739 0.5137 0.7498 0.9997

Note that only a, b, and c are shown. You can explicitly show the :internals section by calling describe(chn; sections = :internals), or all variables with describe(chn; sections = nothing). Many functions, such as plot or gelmandiag, support the sections keyword argument.
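The Mean, SD and quantile columns are ordinary statistics of the pooled draws, and the Naive SE values in the output above are consistent with SD / sqrt(N) for N = 3 * 500 = 1500 draws (for a, 0.2920 / sqrt(1500) ≈ 0.0075). The sketch below computes those columns for one parameter with the Statistics standard library. It is not MCMCChains' code, and it omits MCSE and ESS, which correct for autocorrelation.

```julia
using Random, Statistics

# One parameter's draws: 500 iterations x 3 chains, pooled for the summary.
Random.seed!(1)
draws = rand(500, 3)
pooled = vec(draws)

m = mean(pooled)
s = std(pooled)
naive_se = s / sqrt(length(pooled))   # ignores autocorrelation between draws
qs = quantile(pooled, [0.025, 0.25, 0.5, 0.75, 0.975])
```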

Groups of parameters

By convention, MCMCChains assumes that parameters with names of the form "name[index]" belong to one group of parameters called :name. You can access the names of all parameters in a chain that belong to the group :name by running

namesingroup(chain, :name)

If the chain contains a parameter named :name, it is returned as well.

The function group(chain, :name) returns the subset of chain that contains all parameters in the group :name.
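The naming convention can be checked with simple string tests. names_in_group below is a hypothetical stand-in for namesingroup, shown only to make the matching rule concrete; the package's own matching may differ in edge cases.

```julia
# Hypothetical helper (not MCMCChains API): return the names that equal the
# group name or have the form "group[index]".
function names_in_group(pnames::Vector{String}, group::Symbol)
    g = string(group)
    is_member(n) = n == g || (startswith(n, g * "[") && endswith(n, "]"))
    return filter(is_member, pnames)
end

pnames = ["P[1]", "P[2]", "P[3]", "D", "E", "PD"]
p_group = names_in_group(pnames, :P)   # "PD" is not part of the group
d_group = names_in_group(["D", "D[1]", "E"], :D)
```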

The get Function

MCMCChains provides a get function designed to make it easier to access parameters: get(chn, :P) returns a NamedTuple, which is easy to work with.


val = rand(500, 5, 1)
chn = Chains(val, ["P[1]", "P[2]", "P[3]", "D", "E"]);

x = get(chn, :P)

Here's what x looks like:

(P = (Union{Missing, Float64}[0.349592; 0.671365; … ; 0.319421; 0.298899], Union{Missing, Float64}[0.757884; 0.720212; … ; 0.471339; 0.5381], Union{Missing, Float64}[0.240626; 0.987814; … ; 0.980652; 0.149805]),)

You can access each of the P[...] variables by indexing: x.P[1], x.P[2], or x.P[3].

get also accepts a vector of names to retrieve, so you can call x = get(chn, [:P, :D]). The result looks like this:

(P = (Union{Missing, Float64}[0.349592; 0.671365; … ; 0.319421; 0.298899], Union{Missing, Float64}[0.757884; 0.720212; … ; 0.471339; 0.5381], Union{Missing, Float64}[0.240626; 0.987814; … ; 0.980652; 0.149805]),
 D = Union{Missing, Float64}[0.648963; 0.0419232; … ; 0.54666; 0.746028])

Note that x.P is a tuple which has to be indexed by the relevant index, while x.D is just a vector.
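The shape of get's result can be reproduced on a plain matrix: a scalar name maps to a vector of draws and a group name to a tuple of vectors. get_like is a hypothetical helper; unlike the real get, it works on Float64 columns rather than Union{Missing, Float64} vectors and ignores chains.

```julia
# Hypothetical helper (not MCMCChains API): build a get-style NamedTuple from
# a matrix whose columns hold the draws of the parameters in pnames.
function get_like(mat::AbstractMatrix, pnames::Vector{String}, wanted::Vector{Symbol})
    column(n) = mat[:, findfirst(==(n), pnames)]
    entries = map(wanted) do k
        ks = string(k)
        if ks in pnames
            column(ks)                                 # single parameter: a vector
        else
            members = filter(n -> startswith(n, ks * "["), pnames)
            Tuple(column(n) for n in members)          # group: a tuple of vectors
        end
    end
    return NamedTuple{Tuple(wanted)}(Tuple(entries))
end

vals = rand(500, 5)
pnames = ["P[1]", "P[2]", "P[3]", "D", "E"]
x = get_like(vals, pnames, [:P, :D])
```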

Convergence Diagnostics functions

Discrete Diagnostic

Options for method are [:weiss, :hangartner, :DARBOOT, :MCBOOT, :billingsley, :billingsleyBOOT].

discretediag(c::Chains; frac=0.3, method=:weiss, nsim=1000)

Gelman, Rubin, and Brooks Diagnostics

gelmandiag(c::Chains; alpha=0.05, mpsrf=false, transform=false)
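gelmandiag is based on the Gelman-Rubin potential scale reduction factor, which compares within-chain and between-chain variance; values near 1 suggest the chains agree. The sketch below computes the basic, untransformed statistic for one parameter. It omits the degrees-of-freedom correction, the upper confidence limit controlled by alpha, and the multivariate version requested by mpsrf, so its numbers need not match gelmandiag exactly.

```julia
using Random, Statistics

# Simplified potential scale reduction factor (R-hat) for one parameter.
# x is an n x m matrix: n iterations (rows) by m chains (columns).
function psrf(x::AbstractMatrix)
    n = size(x, 1)
    chain_means = vec(mean(x; dims=1))
    W = mean(vec(var(x; dims=1)))        # mean within-chain variance
    B = n * var(chain_means)             # between-chain variance
    var_plus = (n - 1) / n * W + B / n   # pooled estimate of the posterior variance
    return sqrt(var_plus / W)
end

Random.seed!(42)
mixed = randn(1000, 4)                 # four chains from the same distribution
stuck = randn(1000, 4) .+ [0 0 0 5]    # fourth chain sits in a different region
r_mixed = psrf(mixed)
r_stuck = psrf(stuck)
```

The stuck case shows why the statistic is useful: one chain stuck elsewhere inflates the between-chain variance and pushes the ratio well above 1.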

Geweke Diagnostic

gewekediag(c::Chains; first=0.1, last=0.5, etype=:imse)
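gewekediag compares the mean of an early window of each chain (first = 0.1 of the draws) with the mean of a late window (last = 0.5) and reports a z-score. The sketch below uses naive variances, which assume independent draws; the actual diagnostic estimates the standard errors with an autocorrelation-aware method selected by etype, so treat this as a simplified illustration.

```julia
using Random, Statistics

# Simplified Geweke z-score for one chain (naive variances, not the
# etype-based standard errors a full implementation would use).
function geweke_z(x::AbstractVector; first=0.1, last=0.5)
    n = length(x)
    a = x[1:floor(Int, first * n)]                 # early window
    b = x[(n - floor(Int, last * n) + 1):n]        # late window
    se2 = var(a) / length(a) + var(b) / length(b)  # assumes independent draws
    return (mean(a) - mean(b)) / sqrt(se2)
end

Random.seed!(7)
stationary = randn(10_000)
drifting = randn(10_000) .+ LinRange(5.0, 0.0, 10_000)  # chain still trending
z_stationary = geweke_z(stationary)
z_drifting = geweke_z(drifting)
```

A |z| much larger than about 2 suggests the early and late parts of the chain have different means, that is, the chain has not settled.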

Heidelberger and Welch Diagnostics

heideldiag(c::Chains; alpha=0.05, eps=0.1, etype=:imse)

Raftery and Lewis Diagnostic

rafterydiag(c::Chains; q=0.025, r=0.005, s=0.95, eps=0.001)

Model Selection

Deviance Information Criterion (DIC)

chn = ... # sampling results
function lpfun(chain::Chains) # compute the logpdf value of every draw
    niter, nparams, nchains = size(chain)
    lp = zeros(niter * nchains) # resulting logpdf values, one per draw
    for i = 1:nparams
        lp += map(p -> logpdf( ... , p), vec(Array(chain[:,i,:])))
    end
    return lp
end
DIC, pD = dic(chn, lpfun)
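dic takes the chain and a function returning a log density for every draw. Assuming it follows the usual definition, with deviance D(θ) = -2 log p(y | θ), DIC = D̄ + pD and pD = D̄ - D(θ̄), where D̄ is the posterior mean deviance and θ̄ the posterior mean, the sketch below computes both quantities for a one-parameter Normal model with known unit variance. The model, the data and dic_sketch are hypothetical; for a single well-identified parameter, pD should come out close to 1.

```julia
using Random, Statistics

# Log-likelihood and deviance of data y under Normal(mu, 1).
loglik(mu, y) = sum(-0.5 * log(2pi) .- 0.5 .* (y .- mu) .^ 2)
deviance(mu, y) = -2 * loglik(mu, y)

function dic_sketch(mu_draws::AbstractVector, y::AbstractVector)
    Dbar = mean(deviance(mu, y) for mu in mu_draws)  # posterior mean deviance
    Dhat = deviance(mean(mu_draws), y)              # deviance at posterior mean
    pD = Dbar - Dhat                                # effective number of parameters
    return Dbar + pD, pD
end

Random.seed!(3)
y = randn(50) .+ 2.0
# Posterior draws for mu under a flat prior: Normal(mean(y), 1 / sqrt(50)).
mu_draws = mean(y) .+ randn(4000) ./ sqrt(length(y))
DIC, pD = dic_sketch(mu_draws, y)
```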


Plotting

# construct a plot
plot(c::Chains, seriestype = (:traceplot, :mixeddensity))

# construct trace plots
plot(c::Chains, seriestype = :traceplot)

# or for all seriestypes use the alternative shorthand syntax
traceplot(c::Chains)

# construct running average plots
meanplot(c::Chains)

# construct density plots
density(c::Chains)

# construct histogram plots
histogram(c::Chains)

# construct mixed density plots
mixeddensity(c::Chains)

# construct autocorrelation plots
autocorplot(c::Chains)

# make a cornerplot (requires StatsPlots) of parameters in a Chain:
corner(c::Chains, [:A, :B])
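The running-average and autocorrelation plots visualize two simple quantities of each chain. The sketch below computes them for a plain vector with the standard library; the plotting recipes' exact estimators (for example, how many lags they show) are not specified in this README, so treat it as illustrative.

```julia
using Random, Statistics

# Running average after each iteration (what a running-average plot traces).
running_mean(x::AbstractVector) = cumsum(x) ./ (1:length(x))

# Sample autocorrelation at lags 0, 1, ..., maxlag.
function autocorrelation(x::AbstractVector, maxlag::Integer)
    n = length(x)
    xc = x .- mean(x)
    denom = sum(abs2, xc)
    return [sum(xc[1:n-k] .* xc[1+k:n]) / denom for k in 0:maxlag]
end

# A strongly autocorrelated AR(1) chain (coefficient 0.9) as test input.
Random.seed!(11)
x = zeros(5_000)
for t in 2:length(x)
    x[t] = 0.9 * x[t-1] + randn()
end
rmean = running_mean(x)
ac = autocorrelation(x, 20)
```

Slowly decaying autocorrelation, as in this AR(1) chain, means consecutive draws carry little new information, which is also what separates the MCSE from the naive standard error.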

Saving and Loading Chains

Chains objects can be serialized and deserialized using read and write.

# Save a chain.
write("chain-file.jls", chn)

# Read a chain.
chn2 = read("chain-file.jls", Chains)
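The .jls extension is the usual one for files produced by Julia's Serialization standard library. Assuming (the README does not say so) that read and write for Chains rely on that serializer, the round trip below shows the underlying mechanism on a plain array of draws, using an in-memory buffer instead of a file.

```julia
using Serialization

# Round-trip a 3-D array of draws (iterations x parameters x chains)
# through Julia's built-in serializer.
draws = rand(500, 5, 2)
io = IOBuffer()
serialize(io, draws)
seekstart(io)
restored = deserialize(io)
```

Serialized data is not guaranteed to be readable by a different Julia version, so this format suits caching more than long-term archiving.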

Exporting Chains

A few utility export functions are provided to convert Chains objects to either an Array or a DataFrame:

# Several examples of creating an Array object:
Array(chns, [:parameters])
Array(chns, [:parameters, :internals])

# By default chains are appended. This can be disabled
# using the append_chains keyword argument:
Array(chns, append_chains=false)

# This will return an `Array{Array, 1}` object containing
# an Array for each chain.

# A final option is:
Array(chns, remove_missing_union=false)

# This will not convert the Array columns from a
# `Union{Missing, Real}` to a `Vector{Real}`.

Similarly, for DataFrames:

DataFrame(chns, [:parameters])
DataFrame(chns, [:parameters, :internals])
DataFrame(chns, append_chains=false)
DataFrame(chns, remove_missing_union=false)

See also ?DataFrame and ?Array for more help.
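With draws stored as iterations x parameters x chains, as in the examples above, append_chains decides whether chains are stacked into one matrix or returned separately. The plain-array sketch below illustrates that reshaping, plus the element-type narrowing that remove_missing_union refers to; the layout is an assumption, and this is not MCMCChains' code.

```julia
# Assumed layout: iterations x parameters x chains.
draws = reshape(collect(1.0:24.0), 4, 3, 2)   # 4 iterations, 3 parameters, 2 chains

# Like append_chains=true: stack the chains vertically into one matrix.
appended = reduce(vcat, [draws[:, :, c] for c in 1:size(draws, 3)])

# Like append_chains=false: one matrix per chain.
per_chain = [draws[:, :, c] for c in 1:size(draws, 3)]

# Like remove_missing_union=true: narrow Union{Missing, Float64} to Float64
# when no value is actually missing.
col = Union{Missing, Float64}[0.1, 0.2, 0.3]
narrowed = identity.(col)
```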

Sampling Chains

MCMCChains overloads several sample methods as defined in StatsBase:

# Sampling `n` samples from the chain `a`. Optionally
# weighting the samples using `wv`.
sample([rng], a, [wv::AbstractWeights], n::Integer)

# As above, but supports replacing and ordering.
sample([rng], a, [wv::AbstractWeights], n::Integer; replace=true, ordered=false)

See also ?sample for additional help. Alternatively, you can construct and sample from a kernel density estimator using the KernelDensity package:

using KernelDensity
using StatsBase   # provides sample and Weights

# Construct a kernel density estimator from the draws of parameter :s
c = kde(vec(Array(chn[:s])))

# Generate 100000 weighted samples from the grid points
chn_weighted_sample = sample(c.x, Weights(c.density), 100000)
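The weighted sample above draws grid points with probability proportional to the estimated density. One simple way to do this, inverse-CDF sampling, is sketched below with the standard library; it is not StatsBase's implementation, weighted_grid_sample is a hypothetical name, and an unnormalized standard normal stands in for the KDE output.

```julia
using Random, Statistics

# Draw n values from grid points x with probabilities proportional to w,
# using the inverse-CDF method (weighted sampling with replacement).
function weighted_grid_sample(rng::AbstractRNG, x::AbstractVector, w::AbstractVector, n::Integer)
    cdf = cumsum(w) ./ sum(w)
    # min guards against cdf[end] falling a rounding error below 1.
    return [x[min(searchsortedfirst(cdf, rand(rng)), length(x))] for _ in 1:n]
end

rng = MersenneTwister(5)
grid = collect(LinRange(-3.0, 3.0, 201))
dens = exp.(-grid .^ 2 ./ 2)            # unnormalized standard normal density
samples = weighted_grid_sample(rng, grid, dens, 100_000)
m, s = mean(samples), std(samples)
```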

License Notice

Note that this package heavily uses and adapts code from the Mamba.jl package, licensed under the MIT License; see License.md.
