Chapter 5. Model Comparison

import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

import jax.numpy as jnp
from jax import random, vmap, local_device_count, pmap, lax, tree_map
from jax import nn as jnn
from jax.scipy import stats, special

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim

from numpyro.infer import MCMC, NUTS, HMC, Predictive, SVI, Trace_ELBO, init_to_value
from numpyro.diagnostics import hpdi, print_summary
from numpyro.infer.autoguide import AutoLaplaceApproximation

seed=1234

if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu") # or "gpu", "tpu" depending on system
numpyro.set_host_device_count(local_device_count())
dummy_data = pd.read_csv('../data/dummy.csv', delimiter=' ', header=None, dtype=float)
x_1 = jnp.asarray(dummy_data)[:, 0]
y_1 = jnp.asarray(dummy_data)[:, 1]

order = 2
x_1p = jnp.vstack([x_1**i for i in range(1, order+1)])
x_1s = (x_1p - x_1p.mean(axis=1, keepdims=True)) / x_1p.std(axis=1, keepdims=True)
y_1s = (y_1 - y_1.mean()) / y_1.std()
plt.scatter(x_1s[0], y_1s)
plt.xlabel('x')
plt.ylabel('y')
[Figure: scatter plot of the standardized data, x vs y]
def model_l(obs=None):
    α = numpyro.sample('α', dist.Normal(loc=0, scale=1))
    β = numpyro.sample('β', dist.Normal(loc=0, scale=10))
    ϵ = numpyro.sample('ϵ', dist.HalfNormal(scale=5))

    μ = α + β * x_1s[0]

    y_pred = numpyro.sample('y_pred', dist.Normal(loc=μ, scale=ϵ), obs=obs)
    
    
kernel = NUTS(model_l)
mcmc_l = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=2, chain_method='sequential')
mcmc_l.run(random.PRNGKey(seed), obs=y_1s) 

def model_p(obs=None):
    α = numpyro.sample('α', dist.Normal(loc=0, scale=1))
    β = numpyro.sample('β', dist.Normal(loc=0, scale=10), sample_shape=(order,))
    ϵ = numpyro.sample('ϵ', dist.HalfNormal(scale=5))

    μ = α + jnp.dot(β, x_1s)

    y_pred = numpyro.sample('y_pred', dist.Normal(loc=μ, scale=ϵ), obs=obs)
    
    
kernel = NUTS(model_p)
mcmc_p = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=2, chain_method='sequential')
mcmc_p.run(random.PRNGKey(seed), obs=y_1s) 
x_new = jnp.linspace(x_1s[0].min(), x_1s[0].max(), 100)

α_l_post = mcmc_l.get_samples()['α'].mean()
β_l_post = mcmc_l.get_samples()['β'].mean(axis=0)
y_l_post = α_l_post + β_l_post *  x_new

plt.plot(x_new, y_l_post, 'C1', label='linear model')

α_p_post = mcmc_p.get_samples()['α'].mean()
β_p_post = mcmc_p.get_samples()['β'].mean(axis=0)
idx = jnp.argsort(x_1s[0])
y_p_post = α_p_post + jnp.dot(β_p_post, x_1s)

plt.plot(x_1s[0][idx], y_p_post[idx], 'C2', label=f'model order {order}')

#α_p_post = trace_p['α'].mean()
#β_p_post = trace_p['β'].mean(axis=0)
#x_new_p = np.vstack([x_new**i for i in range(1, order+1)])
#y_p_post = α_p_post + np.dot(β_p_post, x_new_p) 

plt.scatter(x_1s[0], y_1s, c='C0', marker='.')
plt.legend()
[Figure: posterior-mean fits of the linear model and the order-2 model over the standardized data]
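The commented-out lines above show how the original PyMC3 version evaluated the order-2 posterior mean on a fresh grid. The following is a sketch of the same idea adapted to the standardized design matrix used here (variable names are new and this plot was not part of the original run):

x_new_raw = jnp.linspace(x_1.min(), x_1.max(), 100)
x_new_p = jnp.vstack([x_new_raw**i for i in range(1, order + 1)])
# reuse the training means/stds so the grid lives on the same standardized scale as x_1s
x_new_s = (x_new_p - x_1p.mean(axis=1, keepdims=True)) / x_1p.std(axis=1, keepdims=True)
y_p_new = α_p_post + jnp.dot(β_p_post, x_new_s)

plt.scatter(x_1s[0], y_1s, c='C0', marker='.')
plt.plot(x_new_s[0], y_p_new, 'C2', label=f'order {order} (evaluated on a grid)')
plt.legend()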

Posterior predictive checks

prior = Predictive(mcmc_l.sampler.model, num_samples=10)
prior_p = prior(random.PRNGKey(seed))  # prior predictive: do not condition on the observed data

pred = Predictive(model=mcmc_l.sampler.model, posterior_samples=mcmc_l.get_samples(), return_sites=['y_pred'])
post_p = pred(random.PRNGKey(seed))
y_l = post_p['y_pred']

samples = az.from_numpyro(mcmc_l, posterior_predictive=post_p)
az.plot_ppc(samples, mean=True, observed=True, color='C0', alpha=0.01)
[Figure: posterior predictive check (az.plot_ppc) for the linear model]
prior = Predictive(mcmc_p.sampler.model, num_samples=10)
prior_p = prior(random.PRNGKey(seed))  # prior predictive: do not condition on the observed data

pred = Predictive(model=mcmc_p.sampler.model, posterior_samples=mcmc_p.get_samples(), return_sites=['y_pred'])
post_p = pred(random.PRNGKey(seed))
y_p = post_p['y_pred']

samples = az.from_numpyro(mcmc_p, posterior_predictive=post_p)
az.plot_ppc(samples, mean=True, observed=True, color='C0', alpha=0.01)
[Figure: posterior predictive check (az.plot_ppc) for the order-2 model]
plt.figure(figsize=(8, 3))
data = [y_1s, y_l, y_p]
labels = ['data', 'linear model', 'order 2']
for i, d in enumerate(data):
    mean = d.mean()
    err = jnp.percentile(d, [25, 75])
    plt.errorbar(mean, -i, xerr=[[mean - err[0]], [err[1] - mean]], fmt='o')  # IQR as distances from the mean
    plt.text(mean, -i+0.2, labels[i], ha='center', fontsize=14)
plt.ylim([-i-0.5, 0.5])
plt.yticks([])
[Figure: means and interquartile ranges of the data and the two posterior predictive distributions]
fig, ax = plt.subplots(1, 2, figsize=(10, 3), constrained_layout=True)


def iqr(x, a=0):
    return jnp.subtract(*jnp.percentile(x, [75, 25], axis=a))


for idx, func in enumerate([jnp.mean, iqr]):
    T_obs = func(y_1s)
    ax[idx].axvline(T_obs, 0, 1, color='k', ls='--')
    for d_sim, c in zip([y_l, y_p], ['C1', 'C2']):
        T_sim = func(d_sim, 1)
        p_value = jnp.mean(T_sim >= T_obs)
        az.plot_kde(T_sim, plot_kwargs={'color': c}, label=f'p-value {p_value:.2f}', ax=ax[idx])
    ax[idx].set_title(func.__name__)
    ax[idx].set_yticks([])
    ax[idx].legend()
[Figure: KDEs of the simulated test statistics (mean and iqr) for both models, with Bayesian p-values]
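ArviZ can also produce these Bayesian p-value plots directly. A quick alternative (not part of the original run), using the InferenceData built above for the order-2 model, which is what `samples` currently refers to:

az.plot_bpv(samples, kind='t_stat', t_stat='mean')  # p-value with the mean as test statistic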

Occam’s razor – simplicity and accuracy

x = jnp.array([4., 5., 6., 9., 12, 14.])
y = jnp.array([4.2, 6., 6., 9., 10, 10.])

plt.figure(figsize=(10, 5))
orders = [0, 1, 2, 5]
plt.plot(x, y, 'o')
for i in orders:
    x_n = jnp.linspace(x.min(), x.max(), 100)
    coeffs = jnp.polyfit(x, y, deg=i)
    ffit = jnp.polyval(coeffs, x_n)

#     p = onp.poly1d(coeffs)
    yhat = jnp.polyval(coeffs, x)
#     yhat = p(x)
    ybar = jnp.mean(y)
    ssreg = jnp.sum((yhat-ybar)**2)
    sstot = jnp.sum((y - ybar)**2)
    r2 = ssreg / sstot

    plt.plot(x_n, ffit, label=f'order {i}, $R^2$= {r2:.2f}')

plt.legend(loc=2)
plt.xlabel('x')
plt.ylabel('y', rotation=0)
# plt.savefig('B11197_05_05.png', dpi=300)
plt.plot([10, 7], [9, 7], 'ks')
# plt.savefig('B11197_05_06.png', dpi=300)
[Figure: polynomial fits of order 0, 1, 2 and 5 with their R^2 values, plus two additional data points (black squares)]

Computing information criteria with ArviZ

waic_l = az.waic(mcmc_l)
waic_l
Computed from 4000 by 33 log-likelihood matrix

          Estimate       SE
elpd_waic   -14.38     2.67
p_waic        2.44        -
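As a sanity check on what az.waic reports, elpd_waic can be sketched by hand from the pointwise log-likelihood matrix using the standard formulas (lppd minus the effective-parameter penalty p_waic). This is an illustrative sketch; it assumes numpyro.infer.util.log_likelihood for extracting the matrix.

from numpyro.infer.util import log_likelihood

log_lik = log_likelihood(model_l, mcmc_l.get_samples(), obs=y_1s)['y_pred']  # shape (draws, observations)
n_draws = log_lik.shape[0]
lppd = jnp.sum(special.logsumexp(log_lik, axis=0) - jnp.log(n_draws))  # log pointwise predictive density
p_waic = jnp.sum(jnp.var(log_lik, axis=0))                             # variance-based penalty
lppd - p_waic, p_waic  # should be close to the elpd_waic and p_waic reported above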
cmp_df = az.compare({'model_l':mcmc_l, 'model_p':mcmc_p},
                    method='BB-pseudo-BMA')
cmp_df
UserWarning: The default method used to estimate the weights for each model has changed from BB-pseudo-BMA to stacking

         rank        loo     p_loo      d_loo    weight        se       dse  warning loo_scale
model_p     0  -4.579088  2.623271   0.000000  0.998345  2.625487  0.000000    False       log
model_l     1 -14.409419  2.472847   9.830331  0.001655  2.363713  2.655246    False       log
az.plot_compare(cmp_df)
[Figure: az.plot_compare of model_l and model_p (ELPD, log scale)]

Model averaging

# TODO: figure out how to do this in numpyro (one possible approach is sketched below)
# w = 0.5
# y_lp = pm.sample_posterior_predictive_w([trace_l, trace_p],
#                                         samples=1000,
#                                         models=[model_l, model_p],
#                                         weights=[w, 1-w])
# _, ax = plt.subplots(figsize=(10, 6))
# az.plot_kde(y_l, plot_kwargs={'color': 'C1'}, label='linear model', ax=ax)
# az.plot_kde(y_p, plot_kwargs={'color': 'C2'}, label='order 2 model', ax=ax)
# az.plot_kde(y_lp['y_pred'], plot_kwargs={'color': 'C3'},
#            label='weighted model', ax=ax)

# plt.plot(y_1s, jnp.zeros_like(y_1s), '|', label='observed data')
# plt.yticks([])
# plt.legend()
# plt.savefig('B11197_05_09.png', dpi=300)
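NumPyro has no built-in counterpart to pm.sample_posterior_predictive_w, but the same idea can be sketched by mixing the posterior predictive draws already stored in y_l and y_p, picking the model for each draw at random. The 50/50 weight and the PRNGKey(0) below are purely illustrative; the stacking weights in cmp_df['weight'] would be the principled choice.

w = 0.5
n_draws = min(len(y_l), len(y_p))
pick_l = dist.Bernoulli(probs=w).sample(random.PRNGKey(0), (n_draws,)).astype(bool)
y_lp = jnp.where(pick_l[:, None], y_l[:n_draws], y_p[:n_draws])  # per-draw model choice

_, ax = plt.subplots(figsize=(10, 6))
az.plot_kde(y_l, plot_kwargs={'color': 'C1'}, label='linear model', ax=ax)
az.plot_kde(y_p, plot_kwargs={'color': 'C2'}, label='order 2 model', ax=ax)
az.plot_kde(y_lp, plot_kwargs={'color': 'C3'}, label='weighted model', ax=ax)
plt.plot(y_1s, jnp.zeros_like(y_1s), '|', label='observed data')
plt.yticks([])
plt.legend()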

Bayes factors

coins = 30 # 300
heads = 9 # 90
y_d = jnp.repeat(jnp.array([0, 1]), jnp.array([coins-heads, heads]))
p = jnp.array([0.5, 0.5])
print(p)
model_index = dist.Categorical(probs=p).sample(key=random.PRNGKey(2), sample_shape=(10,))

m_0 = (4, 8)
m_1 = (8, 4)
checkeq = lax.eq(model_index, 0)
print(checkeq.astype(int))
# lax.switch expects callable branches and a scalar index, so PyMC3's
# pm.math.switch one-liner has no direct lax.switch translation:
# m = lax.switch(checkeq, m_0, m_1)
[0.5 0.5]
[1 0 0 1 1 1 1 0 0 1]
# def model(obs=None):
#     p = jnp.array([0.5, 0.5])
#     model_index = numpyro.sample('model_index', dist.Categorical(probs=p))

#     m_0 = (4, 8)
#     m_1 = (8, 4)
#     checkeq = lax.eq(model_index, 0).astype(int)
#     m = lax.switch(checkeq, m_0, m_1)
# #     m = pm.math.switch(pm.math.eq(model_index, 0), m_0, m_1)

#     # a priori
#     θ = numpyro.sample('θ', dist.Beta(concentration1=m[0], concentration0=m[1]))
#     # likelihood
#     y = numpyro.sample('y', dist.Bernoulli(probs=θ), obs=obs)
    
    
# kernel = NUTS(model)
# mcmc_BF = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=2, chain_method='sequential')
# mcmc_BF.run(random.PRNGKey(seed), obs=y_d) 

# az.plot_trace(mcmc_BF)
# with pm.Model() as model_BF:
#     p = np.array([0.5, 0.5])
#     model_index = pm.Categorical('model_index', p=p)

#     m_0 = (4, 8)
#     m_1 = (8, 4)
#     m = pm.math.switch(pm.math.eq(model_index, 0), m_0, m_1)

#     # a priori
#     θ = pm.Beta('θ', m[0], m[1])
#     # likelihood
#     y = pm.Bernoulli('y', θ, observed=y_d)

#     trace_BF = pm.sample(5000)
# az.plot_trace(trace_BF)
# plt.savefig('B11197_05_11.png', dpi=300)
# pM1 = mcmc_BF.get_samples()['model_index'].mean()
# pM0 = 1 - pM1
# BF = (pM0 / pM1) * (p[1] / p[0])
# BF
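One way to make the commented-out model-index approach runnable in NumPyro is to handle the discrete site with a Gibbs-within-HMC kernel. This is a sketch only (it assumes numpyro.infer.DiscreteHMCGibbs, and mixing for a model-indicator variable can be poor), which is why the two models are also fitted separately below.

from numpyro.infer import DiscreteHMCGibbs

def model_BF(obs=None):
    p = jnp.array([0.5, 0.5])
    model_index = numpyro.sample('model_index', dist.Categorical(probs=p))
    # select the Beta prior's parameters according to the sampled model index
    m = jnp.where(model_index == 0, jnp.array([4.0, 8.0]), jnp.array([8.0, 4.0]))
    θ = numpyro.sample('θ', dist.Beta(concentration1=m[0], concentration0=m[1]))
    numpyro.sample('y', dist.Bernoulli(probs=θ), obs=obs)

kernel_BF = DiscreteHMCGibbs(NUTS(model_BF))
mcmc_BF = MCMC(kernel_BF, num_warmup=500, num_samples=2000, num_chains=2, chain_method='sequential')
mcmc_BF.run(random.PRNGKey(seed), obs=y_d)

pM1 = mcmc_BF.get_samples()['model_index'].mean()
pM0 = 1 - pM1
BF = (pM0 / pM1) * (p[1] / p[0])  # posterior odds divided by prior odds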
def model_BF_0(obs=None):
    θ = numpyro.sample('θ', dist.Beta(concentration1=4, concentration0=8))
    y = numpyro.sample('y', dist.Bernoulli(probs=θ), obs=obs)

kernel = NUTS(model_BF_0)
mcmc_BF_0 = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=2, chain_method='sequential')
mcmc_BF_0.run(random.PRNGKey(seed), obs=y_d) 
                       
def model_BF_1(obs=None):
    θ = numpyro.sample('θ', dist.Beta(concentration1=8, concentration0=4))
    y = numpyro.sample('y', dist.Bernoulli(probs=θ), obs=obs)
                       
kernel = NUTS(model_BF_1)
mcmc_BF_1 = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=2, chain_method='sequential')
mcmc_BF_1.run(random.PRNGKey(seed), obs=y_d) 
# NumPyro's MCMC does not return a marginal-likelihood estimate the way PyMC3's SMC
# sampler does, so there is no direct equivalent of the PyMC3 one-liner:
# https://forum.pyro.ai/t/mcmc-and-marginal-likelihood/2025
# model_BF_0.marginal_likelihood / model_BF_1.marginal_likelihood
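Because both models are conjugate Beta-Bernoulli, the marginal likelihood has a closed form, p(y | a, b) = B(a + h, b + N - h) / B(a, b), where B is the beta function and h is the number of heads in N tosses (the binomial coefficient cancels in the ratio). A sketch of the analytic Bayes factor using jax.scipy.special.betaln:

def log_marginal(h, n, a, b):
    # log marginal likelihood of h heads in n Bernoulli trials under a Beta(a, b) prior
    return special.betaln(a + h, b + n - h) - special.betaln(a, b)

log_BF = log_marginal(heads, coins, 4.0, 8.0) - log_marginal(heads, coins, 8.0, 4.0)
jnp.exp(log_BF)  # Bayes factor in favour of model 0 (the Beta(4, 8) prior)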

Bayes factors and information criteria

traces = []
waics = []

for coins, heads in [(30, 9), (300, 90)]:
    y_d = jnp.repeat(jnp.array([0, 1]), jnp.array([coins-heads, heads]))
    for priors in [(4, 8), (8, 4)]:
        def model(obs=None): # with pm.Model() as model:
            θ = numpyro.sample('θ', dist.Beta(*priors))
            y = numpyro.sample('y', dist.Bernoulli(probs=θ), obs=obs)
#             trace = pm.sample(2000)
        kernel = NUTS(model)
        trace = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=2, chain_method='sequential')
        trace.run(random.PRNGKey(seed), obs=y_d) 
        traces.append(trace)
        waics.append(az.waic(trace))
traces
[<numpyro.infer.mcmc.MCMC at 0x13bc96fa0>,
 <numpyro.infer.mcmc.MCMC at 0x13c9900a0>,
 <numpyro.infer.mcmc.MCMC at 0x13c3ea0a0>,
 <numpyro.infer.mcmc.MCMC at 0x13cdf9580>]
model_names = ['Model_0 (30-9)', 'Model_1 (30-9)',
               'Model_0 (300-90)', 'Model_1 (300-90)']

az.plot_forest(traces, model_names=model_names)
[Figure: az.plot_forest of θ for the four models (94% HDI)]
waics
[Computed from 4000 by 30 log-likelihood matrix
 
           Estimate       SE
 elpd_waic   -19.01     2.14
 p_waic        0.68        -,
 Computed from 4000 by 30 log-likelihood matrix
 
           Estimate       SE
 elpd_waic   -19.59     1.08
 p_waic        0.63        -,
 Computed from 4000 by 300 log-likelihood matrix
 
           Estimate       SE
 elpd_waic  -184.16     6.80
 p_waic        0.90        -,
 Computed from 4000 by 300 log-likelihood matrix
 
           Estimate       SE
 elpd_waic  -184.26     6.26
 p_waic        0.88        -]
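Information criteria are only comparable between models fit to the same data, so az.compare should be applied per dataset. A quick sketch for the 30-coin pair (not part of the original run):

az.compare({'Model_0 (30-9)': traces[0], 'Model_1 (30-9)': traces[1]})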
fig, ax = plt.subplots(1, 2, sharey=True)

labels = model_names
indices = [0, 0, 1, 1]
for i, (ind, d) in enumerate(zip(indices, waics)):
    mean = d.waic  # newer ArviZ versions expose this as d.elpd_waic and the SE as d.se
    ax[ind].errorbar(mean, -i, xerr=d.waic_se, fmt='o')
    ax[ind].text(mean, -i+0.2, labels[i], ha='center')

# ax[0].set_xlim(30, 50)
# ax[1].set_xlim(330, 400)
# plt.ylim([-i-0.5, 0.5])
# plt.yticks([])
# plt.subplots_adjust(wspace=0.05)
# fig.text(0.5, 0, 'Deviance', ha='center', fontsize=14)
[Figure: WAIC estimates with standard errors for the four models, one panel per dataset]

Regularizing priors

plt.figure(figsize=(8, 6))
x_values = jnp.linspace(-10, 10, 1000)
for b in [1, 2, 5, 15]:
    distri = dist.Laplace(scale=b)  # stats.laplace(scale=b)
    x_pdf = jnp.exp(distri.log_prob(x_values))
    plt.plot(x_values, x_pdf, label=f'b = {b}')

x_pdf = jnp.exp(dist.Normal().log_prob(x_values))  # stats.norm.pdf(x_values)
plt.plot(x_values, x_pdf, label='Gaussian')
plt.xlabel('x')
plt.yticks([])
plt.legend()
plt.xlim(-7, 7)
[Figure: Laplace pdfs for b = 1, 2, 5, 15 together with a standard Gaussian]
# x = range(0, 10)
x = jnp.arange(0, 10)
q = dist.Binomial(total_count=10, probs=0.75)
# r = random.randint(key=random.PRNGKey(0), shape=x, minval=0, maxval=10,)
r = dist.DiscreteUniform(low=0, high=9)  # discrete uniform on {0, ..., 9}, like scipy.stats.randint(0, 10)
# data = stats.bernoulli.rvs(p=theta_real, size=trials)
# data = dist.Bernoulli(probs=theta_real).sample(random.PRNGKey(1), (trials,))

# draw 200 samples from q once and use the empirical frequencies (as in the scipy version below)
draws_q = q.sample(random.PRNGKey(912), (200,))
true_distribution = jnp.array([jnp.mean(draws_q == i) for i in x])

q_pmf = jnp.exp(q.log_prob(x))
r_pmf = jnp.exp(r.log_prob(x))

_, ax = plt.subplots(1, 3, figsize=(12, 4), sharey=True,
                     constrained_layout=True)

# name the loop variable something other than `dist`, which would shadow numpyro.distributions
for idx, (pmf, label) in enumerate(zip([true_distribution, q_pmf, r_pmf], ['true_distribution', 'q', 'r'])):
    h = jnp.sum(special.entr(jnp.array(pmf)))  # entropy: entr(p) = -p*log(p), with entr(0) = 0
    ax[idx].vlines(x, 0, pmf, label=f'entropy = {h:.2f}')
    ax[idx].set_title(label)
    ax[idx].set_xticks(x)
    ax[idx].legend(loc=2, handlelength=0)
[Figure: pmfs of the empirical true_distribution, q and r, annotated with their entropies]
# import numpy as np
# np.random.seed(912)
# x = range(0, 10)
# q = scipy.stats.binom(10, 0.75)
# r = scipy.stats.randint(0, 10)

# true_distribution = [list(q.rvs(200)).count(i) / 200 for i in x]

# q_pmf = q.pmf(x)
# r_pmf = r.pmf(x)

# _, ax = plt.subplots(1, 3, figsize=(12, 4), sharey=True,
#                      constrained_layout=True)

# for idx, (distri, label) in enumerate(zip([true_distribution, q_pmf, r_pmf], ['true_distribution', 'q', 'r'])):
#     ax[idx].vlines(x, 0, distri, label=f'entropy = {scipy.stats.entropy(distri):.2f}')
#     ax[idx].set_title(label)
#     ax[idx].set_xticks(x)
#     ax[idx].legend(loc=2, handlelength=0)

# plt.savefig('B11197_05_15.png', dpi=300)
# If qk is not None, then compute the Kullback-Leibler divergence S = sum(pk * log(pk / qk), axis=axis).
# H = -jnp.sum(jnp.array(dist) * jnp.log(jnp.array(dist))) # Entropy

def KLD(pk, qk=None, base=None, axis=0):

    if base is not None and base <= 0:
        raise ValueError("`base` must be a positive number or `None`.")

    pk = jnp.asarray(pk)
    pk = 1.0*pk / jnp.sum(pk, axis=axis, keepdims=True)
    if qk is None:
        vec = special.entr(pk)
    else:
        qk = jnp.asarray(qk)
        pk, qk = jnp.broadcast_arrays(pk, qk)
        qk = 1.0*qk / jnp.sum(qk, axis=axis, keepdims=True)
        # rel_entr(p, q) = p * log(p / q); recent jax.scipy.special versions provide it
        # (scipy.special.rel_entr is a drop-in replacement otherwise)
        vec = special.rel_entr(pk, qk)
    S = jnp.sum(vec, axis=axis)
    if base is not None:
        S /= jnp.log(base)
    return S

KLD(q_pmf, r_pmf)
DeviceArray(0.71549076, dtype=float32)
import scipy.stats
scipy.stats.entropy(true_distribution, q_pmf), scipy.stats.entropy(true_distribution, r_pmf)
(0.0074029481416101415, 0.7618786232356769)
scipy.stats.entropy(r_pmf, q_pmf), scipy.stats.entropy(q_pmf, r_pmf)
(2.7569299, 0.71549076)
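As a quick consistency check (not part of the original run), the jnp-based KLD defined above should reproduce the scipy values:

KLD(true_distribution, q_pmf), KLD(true_distribution, r_pmf)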