Chapter 7. Gaussian Processes

import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns

import jax.numpy as jnp
from jax import random, vmap, local_device_count, pmap, lax, tree_map
from jax import nn as jnn
from jax.scipy import stats, special

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim

from numpyro.infer import MCMC, NUTS, HMC, Predictive
from numpyro.diagnostics import hpdi, print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation

# Base seed passed to random.PRNGKey(seed) when running MCMC in this notebook.
seed=1234

# Render inline figures as SVG when the SVG environment variable is set.
if "SVG" in os.environ:
    %config InlineBackend.figure_formats = ["svg"]
# Format warnings compactly as "<CategoryName>: <message>".
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu") # or "gpu", "tpu" depending on system
# NOTE(review): local_device_count() queries JAX for the devices it already sees,
# so this presumably just re-asserts the current count — confirm that is the intent
# (pass an explicit number to request several host devices for parallel chains).
numpyro.set_host_device_count(local_device_count())
# import pymc3 as pm
# import numpy as np
# import pandas as pd
# from scipy import stats
# from scipy.special import expit as logistic

# import matplotlib.pyplot as plt
# import arviz as az
# az.style.use('arviz-darkgrid')

Modeling functions

x = jnp.linspace(0, 1, 10)

# Derive independent sub-keys from a single root key. Hand-built keys such as
# PRNGKey(0) / PRNGKey(i*i) are easy to reuse by accident, and a reused key
# reproduces the same random stream.
key_iid, key_walk = random.split(random.PRNGKey(0))

# First curve: independent standard-normal draws, one per x.
y = dist.Normal(0, 1).sample(key_iid, (len(x),))
plt.plot(x, y, 'o-', label='the first one')

# Second curve: every point is drawn from Normal(previous point, 1), starting
# from 0. That is a Gaussian random walk, i.e. the cumulative sum of
# standard-normal steps, so no Python loop of `.at[i].set(...)` updates is needed.
y = jnp.cumsum(dist.Normal(0, 1).sample(key_walk, (len(x),)))
plt.plot(x, y, 'o-', label='the second one')

plt.legend();
<matplotlib.legend.Legend at 0x7fac6159f460>
../_images/07-gaussian-process_5_1.png

Covariance functions and kernels

def exp_quad_kernel(x, knots, ℓ=1):
    """Exponentiated quadratic kernel between ``x`` and every knot.

    Parameters
    ----------
    x : array
        Points at which the kernel is evaluated.
    knots : iterable
        Kernel centres; one output row is produced per knot.
    ℓ : float, default 1
        Length-scale; larger values give slower decay with distance.

    Returns
    -------
    array
        ``exp(-(x - k)**2 / (2 * ℓ**2))`` stacked over ``k in knots``; for 1-D
        ``x`` the result has shape ``(len(knots), len(x))``.
    """
    return jnp.array([jnp.exp(-(x-k)**2 / (2*ℓ**2)) for k in knots])
#  def linear_kernel(x, knots):
#      """ linear kernel """
#     return np.array([(x - 2) * (k - 2) for k in knots])
# Four inputs and the kernel matrix between every pair of them (ℓ = 1).
data = jnp.array([-1, 0, 1, 2])
cov = exp_quad_kernel(data, data, 1)

_, ax = plt.subplots(1, 2, figsize=(12, 5))
ax = list(ax.flat)

# Left panel: the inputs on a line, each labelled with its position in `data`.
ax[0].plot(data, jnp.zeros_like(data), 'ko')
ax[0].set_yticks([])
for idx, i in enumerate(data):
    ax[0].text(i, 0.005, idx)
ax[0].set_xticks(data)
ax[0].set_xticklabels(jnp.round(data, 2))

# Right panel: the kernel matrix as an image with each entry printed in its
# cell; the text colour flips from white to black once the normalised value
# exceeds 0.5.
ax[1].grid(False)
im = ax[1].imshow(cov)
colors = ['w', 'k']
for i, j in ((row, col) for row in range(len(cov)) for col in range(len(cov))):
    ax[1].text(
        j, i, round(cov[i, j], 2),
        color=colors[int(im.norm(cov[i, j]) > 0.5)],
        ha='center', va='center', fontdict={'size': 16},
    )
ax[1].set_xticks(range(len(data)))
ax[1].set_yticks(range(len(data)))
ax[1].xaxis.tick_top()
../_images/07-gaussian-process_9_0.png
# Display the kernel matrix currently bound to `cov`.
# NOTE(review): the recorded output is elided with "..." and is therefore larger
# than the 4x4 matrix built from `data` — presumably it was produced after a
# later cell rebound `cov`; re-run from a fresh kernel to confirm.
cov
DeviceArray([[1.        , 0.99998736, 0.9999495 , ..., 0.6126262 ,
              0.6095785 , 0.60653067],
             [0.99998736, 1.        , 0.99998736, ..., 0.6156736 ,
              0.6126262 , 0.6095785 ],
             [0.9999495 , 0.99998736, 1.        , ..., 0.6187206 ,
              0.61567366, 0.61262625],
             ...,
             [0.6126262 , 0.6156736 , 0.6187206 , ..., 1.        ,
              0.99998736, 0.9999495 ],
             [0.6095785 , 0.6126262 , 0.61567366, ..., 0.99998736,
              1.        , 0.99998736],
             [0.60653067, 0.6095785 , 0.61262625, ..., 0.9999495 ,
              0.99998736, 1.        ]], dtype=float32)
# Attempt a Cholesky factorisation of `cov` with plain NumPy (bound here as `onp`).
# The recorded run raised "LinAlgError: Matrix is not positive definite".
import numpy as onp
onp.linalg.cholesky(cov)
---------------------------------------------------------------------------
LinAlgError                               Traceback (most recent call last)
/var/folders/9y/6kx7fns90pn84gtycx7dyl680000gn/T/ipykernel_69498/231636669.py in <module>
      1 import numpy as onp
----> 2 onp.linalg.cholesky(cov)

<__array_function__ internals> in cholesky(*args, **kwargs)

/usr/local/anaconda3/envs/bap-numpyro/lib/python3.8/site-packages/numpy/linalg/linalg.py in cholesky(a)
    762     t, result_t = _commonType(a)
    763     signature = 'D->D' if isComplexType(t) else 'd->d'
--> 764     r = gufunc(a, signature=signature, extobj=extobj)
    765     return wrap(r.astype(result_t, copy=False))
    766 

/usr/local/anaconda3/envs/bap-numpyro/lib/python3.8/site-packages/numpy/linalg/linalg.py in _raise_linalgerror_nonposdef(err, flag)
     89 
     90 def _raise_linalgerror_nonposdef(err, flag):
---> 91     raise LinAlgError("Matrix is not positive definite")
     92 
     93 def _raise_linalgerror_eigenvalues_nonconvergence(err, flag):

LinAlgError: Matrix is not positive definite
# Draw two samples from a multivariate normal with covariance `cov` via SciPy;
# the transpose puts one sample per column. The recorded run warned that the
# covariance is not positive-semidefinite.
scipy.stats.multivariate_normal.rvs(cov=cov, size=2).T
RuntimeWarning: covariance is not positive-semidefinite.
array([[1.22834997, 1.26860861],
       [1.22305191, 1.2751146 ],
       [1.21725241, 1.28201152],
       [1.21087143, 1.28862307],
       [1.20532334, 1.29586049],
       [1.19978672, 1.30182685],
       [1.19356774, 1.30673356],
       [1.18807364, 1.31364538],
       [1.18212254, 1.32064075],
       [1.17558817, 1.32581598],
       [1.16981591, 1.3327004 ],
       [1.16421756, 1.33831306],
       [1.15860885, 1.34451221],
       [1.15313842, 1.34923182],
       [1.14646254, 1.355225  ],
       [1.14124387, 1.36029765],
       [1.13549547, 1.36516422],
       [1.12895057, 1.3713957 ],
       [1.12422938, 1.37655859],
       [1.11764509, 1.38207624],
       [1.11204476, 1.38639213],
       [1.10619711, 1.39099539],
       [1.10072925, 1.39708865],
       [1.0950571 , 1.40125147],
       [1.09013638, 1.40544347],
       [1.08381375, 1.4103632 ],
       [1.07895087, 1.41464009],
       [1.07328475, 1.4184839 ],
       [1.06750644, 1.42329084],
       [1.06183023, 1.42742417],
       [1.05645092, 1.43089956],
       [1.05126399, 1.43428785],
       [1.04542456, 1.43898383],
       [1.03978415, 1.44227499],
       [1.03456501, 1.44626591],
       [1.02953554, 1.44804436],
       [1.02360813, 1.45142245],
       [1.01920607, 1.45513867],
       [1.01357327, 1.45950914],
       [1.00798219, 1.46200108],
       [1.00284421, 1.46440118],
       [0.99651242, 1.46800336],
       [0.99099775, 1.47005644],
       [0.98627839, 1.47276092],
       [0.98221156, 1.47506495],
       [0.97691585, 1.47802344],
       [0.97247569, 1.48019942],
       [0.9661773 , 1.48111496],
       [0.96126503, 1.48390362],
       [0.95606151, 1.48574872],
       [0.95172718, 1.4883004 ],
       [0.94610328, 1.48874033],
       [0.94158849, 1.49187074],
       [0.93730434, 1.49234079],
       [0.93092552, 1.49377998],
       [0.92741322, 1.49487654],
       [0.92216119, 1.49596871],
       [0.91849701, 1.49708688],
       [0.91310846, 1.49883881],
       [0.90863707, 1.49967159],
       [0.90398875, 1.49973405],
       [0.89940298, 1.50073857],
       [0.89484153, 1.50131931],
       [0.89034191, 1.50122569],
       [0.88606359, 1.50265256],
       [0.88157595, 1.50225347],
       [0.87653267, 1.50285578],
       [0.87258591, 1.50291047],
       [0.86782887, 1.50250208],
       [0.86452653, 1.50279476],
       [0.86065832, 1.50239505],
       [0.85564012, 1.50151861],
       [0.85210079, 1.50201755],
       [0.8481202 , 1.50108929],
       [0.84366669, 1.5002522 ],
       [0.83962756, 1.49808148],
       [0.8356202 , 1.49849782],
       [0.83197736, 1.49880421],
       [0.82797261, 1.49686268],
       [0.82347324, 1.49549258],
       [0.8204014 , 1.4942432 ],
       [0.81720691, 1.49270749],
       [0.8130361 , 1.49200511],
       [0.80912126, 1.49004347],
       [0.80549373, 1.48851091],
       [0.80225647, 1.48658018],
       [0.79809758, 1.48493699],
       [0.79511329, 1.4817913 ],
       [0.79192946, 1.48038088],
       [0.7891514 , 1.47810695],
       [0.7854522 , 1.47645541],
       [0.78212484, 1.47345521],
       [0.77865185, 1.47120853],
       [0.77571075, 1.46861671],
       [0.77150938, 1.46693716],
       [0.7695443 , 1.46372656],
       [0.76581067, 1.46127417],
       [0.7637027 , 1.45819984],
       [0.76024889, 1.45498951],
       [0.75696391, 1.45247269],
       [0.7546983 , 1.45085776],
       [0.7517832 , 1.44627637],
       [0.74901317, 1.44389102],
       [0.74644898, 1.43970353],
       [0.74398602, 1.43642019],
       [0.74105686, 1.43273793],
       [0.73920703, 1.42944461],
       [0.73657213, 1.42556891],
       [0.73358376, 1.42228537],
       [0.73108861, 1.41828222],
       [0.72876998, 1.41394623],
       [0.72596376, 1.40999813],
       [0.72300718, 1.40632597],
       [0.72251101, 1.40201102],
       [0.71857912, 1.39822693],
       [0.71784158, 1.39355986],
       [0.71530644, 1.38924087],
       [0.71339197, 1.38507457],
       [0.71041813, 1.38049428],
       [0.70887513, 1.37690037],
       [0.70689073, 1.37126061],
       [0.70548095, 1.36748257],
       [0.70313452, 1.36287306],
       [0.70137117, 1.35748969],
       [0.69989588, 1.35292971],
       [0.69938697, 1.34837913],
       [0.69676902, 1.34261635],
       [0.69468667, 1.33793432],
       [0.69259705, 1.33352073],
       [0.69194959, 1.3278962 ],
       [0.6903998 , 1.32271078],
       [0.68877001, 1.31702569],
       [0.6871487 , 1.31253962],
       [0.68630734, 1.30723765],
       [0.68402401, 1.30190062],
       [0.68343193, 1.29688219],
       [0.68257924, 1.29108299],
       [0.68067166, 1.28559347],
       [0.67953412, 1.27911588],
       [0.67983449, 1.27432699],
       [0.67758668, 1.26895524],
       [0.67628052, 1.26263415],
       [0.67525752, 1.25754894],
       [0.67407117, 1.25119887],
       [0.67317311, 1.24573331],
       [0.67200599, 1.23941705],
       [0.67195932, 1.23436487],
       [0.67071368, 1.22802535],
       [0.66989234, 1.22134134],
       [0.67050067, 1.21576575],
       [0.66888637, 1.20948851],
       [0.66870897, 1.20360157],
       [0.66679842, 1.19817939],
       [0.66666204, 1.19185133],
       [0.66639067, 1.18604894],
       [0.6652358 , 1.17979469],
       [0.66495332, 1.17354564],
       [0.6642215 , 1.16695682],
       [0.66393073, 1.16135227],
       [0.66378381, 1.15480639],
       [0.6629655 , 1.14799279],
       [0.66309825, 1.14197151],
       [0.66175358, 1.13561118],
       [0.66227839, 1.12984325],
       [0.66170638, 1.12354541],
       [0.66109193, 1.11635925],
       [0.66162106, 1.11025977],
       [0.6608557 , 1.10447617],
       [0.66106632, 1.09680616],
       [0.66101192, 1.09108545],
       [0.66084843, 1.08395282],
       [0.66052339, 1.0783072 ],
       [0.66045684, 1.07232264],
       [0.66042845, 1.06546908],
       [0.66104575, 1.0590404 ],
       [0.66038399, 1.05307127],
       [0.66134973, 1.04614853],
       [0.66002179, 1.04034875],
       [0.66079018, 1.03380879],
       [0.66114864, 1.0270071 ],
       [0.66114755, 1.02073483],
       [0.6603978 , 1.01477346],
       [0.66063908, 1.00847714],
       [0.66144295, 1.00211527],
       [0.6616524 , 0.99586983],
       [0.66124895, 0.9887918 ],
       [0.66215697, 0.9833618 ],
       [0.66251425, 0.97621597],
       [0.66142867, 0.96940604],
       [0.66257405, 0.9638015 ],
       [0.66264867, 0.95793698],
       [0.66424105, 0.95048408],
       [0.6633457 , 0.9445846 ],
       [0.6636494 , 0.93852713],
       [0.66376808, 0.93158544],
       [0.66425539, 0.92544706],
       [0.66517844, 0.91938353],
       [0.66529752, 0.91325503],
       [0.66454473, 0.90701574],
       [0.66593941, 0.90155903]])
# Same two-sample draw through NumPyro's MultivariateNormal.
# The recorded output is entirely NaN for this `cov` — consistent with the failed
# Cholesky factorisation recorded earlier in the notebook.
dist.MultivariateNormal(covariance_matrix=cov).sample(random.PRNGKey(0), (2,))
DeviceArray([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan],
             [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan]], dtype=float32)
# Shape check for a zero vector with one entry per test point.
# NOTE(review): `test_points` is only assigned in the cell that follows, so on a
# fresh kernel this line raises NameError unless that cell has been run first.
jnp.zeros(test_points.shape[0]).shape
(200,)
test_points = jnp.linspace(0, 10, 200)
fig, ax = plt.subplots(2, 2, figsize=(12, 6), sharex=True,
                       sharey=True, constrained_layout=True)
ax = list(ax.flat)

# The kernel matrix on a dense grid is numerically singular in float32, which is
# why the un-jittered MultivariateNormal draw recorded above returns NaN. A small
# diagonal "jitter" keeps the Cholesky factorisation well defined.
# NOTE(review): 1e-3 is assumed to be enough for ℓ up to 10 — raise it if NaNs
# persist, or lower it if the curves look visibly noisy.
jitter = 1e-3 * jnp.eye(test_points.shape[0])

# One panel per length-scale, each showing two draws from the zero-mean GP prior.
for idx, ℓ in enumerate((0.2, 1, 2, 10)):
    cov = exp_quad_kernel(test_points, test_points, ℓ)
    prior = dist.MultivariateNormal(loc=jnp.zeros(test_points.shape[0]),
                                    covariance_matrix=cov + jitter)
    ax[idx].plot(test_points, prior.sample(random.PRNGKey(0), (2,)).T)
    ax[idx].set_title(f'ℓ ={ℓ}')
fig.text(0.51, -0.03, 'x', fontsize=16)
fig.text(-0.03, 0.5, 'f(x)', fontsize=16);
Text(-0.03, 0.5, 'f(x)')
../_images/07-gaussian-process_15_1.png

Gaussian Process regression

# np.random.seed(42)
x = dist.Uniform(low=0, high=10).sample(random.PRNGKey(1), (15,))
y = dist.Normal(loc=jnp.sin(x), scale=0.1).sample(random.PRNGKey(2))
plt.plot(x, y, 'o')
true_x = jnp.linspace(0, 10, 200)
plt.plot(true_x, jnp.sin(true_x), 'k--')
plt.xlabel('x')
plt.ylabel('f(x)', rotation=0)
Text(0, 0.5, 'f(x)')
../_images/07-gaussian-process_17_1.png
# A one dimensional column vector of inputs.
X = x[:, None]
# NOTE(review): the dangling draft statement `cov = gp.likelihoods.` was removed
# because it is not a complete expression. The recorded traceback also shows that
# `pyro.contrib.gp` has no `cov` attribute, so this import cannot supply a
# PyMC3-style `gp.cov.ExpQuad`; confirm whether the import is still wanted.
import pyro.contrib.gp as gp
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/var/folders/9y/6kx7fns90pn84gtycx7dyl680000gn/T/ipykernel_69498/672023081.py in <module>
      1 import pyro.contrib.gp as gp
----> 2 cov = gp.cov.ExpQuad(1, ls=)

AttributeError: module 'pyro.contrib.gp' has no attribute 'cov'
def model(obs=None):
    """NumPyro GP regression model with an exponentiated quadratic kernel.

    The latent function is integrated out analytically: the observations follow
    a zero-mean multivariate normal whose covariance is the kernel matrix of the
    module-level inputs ``X`` plus Gaussian observation noise on the diagonal.

    Parameters
    ----------
    obs : array of shape ``(X.shape[0],)``, optional
        Observed targets. When ``None`` the ``'y_pred'`` site is sampled
        instead of conditioned on.
    """
    n = X.shape[0]
    # hyperprior for lengthscale kernel parameter
    ℓ = numpyro.sample('ℓ', dist.Gamma(concentration=2, rate=0.5))
    # prior for the standard deviation of the observation noise
    ϵ = numpyro.sample('ϵ', dist.HalfNormal(25))
    # exponentiated quadratic covariance between every pair of inputs;
    # X is a column vector, so X - X.T broadcasts to all pairwise differences
    cov = jnp.exp(-(X - X.T) ** 2 / (2 * ℓ**2))
    # noise variance on the diagonal, plus a small jitter so the matrix stays
    # positive definite when ϵ is close to zero
    cov = cov + (ϵ**2 + 1e-6) * jnp.eye(n)
    # marginal likelihood
    numpyro.sample('y_pred',
                   dist.MultivariateNormal(loc=jnp.zeros(n), covariance_matrix=cov),
                   obs=obs)


kernel = NUTS(model, target_accept_prob=0.85)
mcmc3 = MCMC(kernel, num_warmup=50, num_samples=50, num_chains=2, chain_method='sequential')
# Condition on the regression targets `y` generated above (the previous draft
# passed `cs_exp`, which is not defined anywhere in this notebook).
mcmc3.run(random.PRNGKey(seed), obs=y)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/var/folders/9y/6kx7fns90pn84gtycx7dyl680000gn/T/ipykernel_69498/1453198497.py in <module>
     14 kernel = NUTS(model, target_accept_prob=0.85)
     15 mcmc3 = MCMC(kernel, num_warmup=50, num_samples=50, num_chains=2, chain_method='sequential')
---> 16 mcmc3.run(random.PRNGKey(seed), obs=jnp.expand_dims(jnp.asarray(cs_exp.values), axis=1))

NameError: name 'cs_exp' is not defined
# PyMC3 reference version of the GP regression model.
with pm.Model() as model_reg:
    # hyperprior for lengthscale kernel parameter
    ℓ = pm.Gamma('ℓ', 2, 0.5)
    # instantiate a covariance function
    cov = pm.gp.cov.ExpQuad(1, ls=ℓ)
    # instantiate a GP prior
    gp = pm.gp.Marginal(cov_func=cov)
    # prior on the observation noise
    ϵ = pm.HalfNormal('ϵ', 25)
    # likelihood
    y_pred = gp.marginal_likelihood('y_pred', X=X, y=y, noise=ϵ)
    trace_reg = pm.sample(2000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [ϵ, ℓ]
Sampling 2 chains: 100%|██████████| 5000/5000 [00:15<00:00, 317.05draws/s]
# Trace plot of the samples in `trace_reg`, saved to disk as a PNG.
az.plot_trace(trace_reg)
plt.savefig('B11197_07_05.png')
../_images/07-gaussian-process_22_0.png
# Column vector of 100 evenly spaced inputs from floor(min(x)) to ceil(max(x)).
X_new = np.linspace(np.floor(x.min()), np.ceil(x.max()), 100)[:,None]

with model_reg:
    #del marginal_gp_model.named_vars['f_pred']
    #marginal_gp_model.vars.remove(f_pred)
    # Add the GP conditional at the new inputs to the existing model.
    f_pred = gp.conditional('f_pred', X_new)
with model_reg:
    # Draw 82 posterior-predictive samples of `f_pred` from `trace_reg`.
    pred_samples = pm.sample_posterior_predictive(trace_reg, vars=[f_pred], samples=82)
100%|██████████| 82/82 [00:03<00:00,  2.75it/s]
# Overlay the sampled functions (faint orange lines) on the observed data (black dots).
_, ax = plt.subplots(figsize=(12,5))
ax.plot(X_new, pred_samples['f_pred'].T, 'C1-', alpha=0.3)
ax.plot(X, y, 'ko')
ax.set_xlabel('X')
plt.savefig('B11197_07_06.png')
../_images/07-gaussian-process_25_0.png
# Same predictions drawn with PyMC3's plot_gp_dist helper (individual samples hidden).
_, ax = plt.subplots(figsize=(12,5))

pm.gp.util.plot_gp_dist(ax, pred_samples['f_pred'], X_new, palette='viridis', plot_samples=False);

ax.plot(X, y, 'ko')
ax.set_xlabel('x')
ax.set_ylabel('f(x)', rotation=0, labelpad=15)
plt.savefig('B11197_07_07.png')
../_images/07-gaussian-process_26_0.png
# plot the results
_, ax = plt.subplots(figsize=(12,5))

# predict at the posterior means of ℓ and ϵ; diag=True requests only the
# per-point variance rather than a full covariance matrix
point = {'ℓ': trace_reg['ℓ'].mean(), 'ϵ': trace_reg['ϵ'].mean()}
mu, var = gp.predict(X_new, point=point, diag=True)
sd = var**0.5

# plot mean and 1σ and 2σ intervals
ax.plot(X_new, mu, 'C1')
ax.fill_between(X_new.flatten(),
                 mu - sd, mu + sd,
                 color="C1",
                 alpha=0.3)

ax.fill_between(X_new.flatten(),
                 mu - 2*sd, mu + 2*sd,
                 color="C1",
                 alpha=0.3)

ax.plot(X, y, 'ko')
ax.set_xlabel('X')
plt.savefig('B11197_07_08.png')
../_images/07-gaussian-process_27_0.png

Regression with spatial autocorrelation

# Pairwise distance table between islands; the first CSV column is the row index.
islands_dist = pd.read_csv('../data/islands_dist.csv',
                           sep=',', index_col=0)
islands_dist.round(1)
Ml Ti SC Ya Fi Tr Ch Mn To Ha
Malekula 0.0 0.5 0.6 4.4 1.2 2.0 3.2 2.8 1.9 5.7
Tikopia 0.5 0.0 0.3 4.2 1.2 2.0 2.9 2.7 2.0 5.3
Santa Cruz 0.6 0.3 0.0 3.9 1.6 1.7 2.6 2.4 2.3 5.4
Yap 4.4 4.2 3.9 0.0 5.4 2.5 1.6 1.6 6.1 7.2
Lau Fiji 1.2 1.2 1.6 5.4 0.0 3.2 4.0 3.9 0.8 4.9
Trobriand 2.0 2.0 1.7 2.5 3.2 0.0 1.8 0.8 3.9 6.7
Chuuk 3.2 2.9 2.6 1.6 4.0 1.8 0.0 1.2 4.8 5.8
Manus 2.8 2.7 2.4 1.6 3.9 0.8 1.2 0.0 4.6 6.7
Tonga 1.9 2.0 2.3 6.1 0.8 3.9 4.8 4.6 0.0 5.0
Hawaii 5.7 5.3 5.4 7.2 4.9 6.7 5.8 6.7 5.0 0.0
# Per-island covariates (population, contact, tool counts, coordinates, ...).
islands = pd.read_csv('../data/islands.csv', sep=',')
islands.head().round(1)
culture population contact total_tools mean_TU lat lon lon2 logpop
0 Malekula 1100 low 13 3.2 -16.3 167.5 -12.5 7.0
1 Tikopia 1500 low 22 4.7 -12.3 168.8 -11.2 7.3
2 Santa Cruz 3600 low 24 4.0 -10.7 166.0 -14.0 8.2
3 Yap 4791 high 43 5.0 9.5 138.1 -41.9 8.5
4 Lau Fiji 7400 high 33 5.0 -17.7 178.1 -1.9 8.9
# Inputs extracted from the two island tables.
islands_dist_sqr = islands_dist.values**2
culture_labels = islands.culture.values
index = islands.index.values
log_pop = islands.logpop
total_tools = islands.total_tools
x_data = [islands.lat.values[:, None], islands.lon.values[:, None]]
# Poisson regression of tool counts on log-population with a latent GP term `f`
# per island, built from the squared distance table.
with pm.Model() as model_islands:
    η = pm.HalfCauchy('η', 1)
    ℓ = pm.HalfCauchy('ℓ', 1)

    cov = η * pm.gp.cov.ExpQuad(1, ls=ℓ)
    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior('f', X=islands_dist_sqr)

    α = pm.Normal('α', 0, 10)
    β = pm.Normal('β', 0, 1)
    # log link: the expected count is exp(linear predictor)
    μ = pm.math.exp(α + f[index] + β * log_pop)
    tt_pred = pm.Poisson('tt_pred', μ, observed=total_tools)
    trace_islands = pm.sample(1000, tune=1000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [β, α, f_rotated_, ℓ, η]
Sampling 2 chains: 100%|██████████| 4000/4000 [02:35<00:00, 24.91draws/s]
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
There were 6 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 25% for some parameters.
#az.summary(trace_islands, var_names=['α', 'β', 'η', 'ℓ', 'f'])
# Posterior draws of the two kernel parameters.
trace_η = trace_islands['η']
trace_ℓ = trace_islands['ℓ']

_, ax = plt.subplots(1, 1, figsize=(8, 5))
# Distances from 0 up to the largest entry of the distance table.
xrange = np.linspace(0, islands_dist.values.max(), 100)

# Thick line: covariance-vs-distance curve at the posterior medians of η and ℓ.
ax.plot(xrange, np.median(trace_η) *
        np.exp(-np.median(trace_ℓ) * xrange**2), lw=3)

# Faint lines: the same curve for every 20th posterior draw.
ax.plot(xrange, (trace_η[::20][:, None] * np.exp(- trace_ℓ[::20][:, None] * xrange**2)).T,
        'C0', alpha=.1)

ax.set_ylim(0, 1)
ax.set_xlabel('distance (thousand kilometers)')
ax.set_ylabel('covariance')
plt.savefig('B11197_07_09.png')
../_images/07-gaussian-process_34_0.png
# compute posterior median covariance among societies
Σ = np.median(trace_η) * (np.exp(-np.median(trace_ℓ) * islands_dist_sqr))
# convert to correlation matrix: D Σ D with D = diag(1 / sqrt(diag(Σ)))
Σ_post = np.diag(np.diag(Σ)**-0.5)
ρ = Σ_post @  Σ @ Σ_post
# label rows and columns with the island abbreviations from the distance table
ρ = pd.DataFrame(ρ, index=islands_dist.columns, columns=islands_dist.columns)
ρ.round(2)
Ml Ti SC Ya Fi Tr Ch Mn To Ha
Ml 1.00 0.90 0.84 0.00 0.50 0.15 0.01 0.03 0.21 0.0
Ti 0.90 1.00 0.96 0.00 0.50 0.16 0.02 0.04 0.17 0.0
SC 0.84 0.96 1.00 0.00 0.34 0.27 0.05 0.08 0.10 0.0
Ya 0.00 0.00 0.00 1.00 0.00 0.06 0.33 0.31 0.00 0.0
Fi 0.50 0.50 0.34 0.00 1.00 0.01 0.00 0.00 0.77 0.0
Tr 0.15 0.16 0.27 0.06 0.01 1.00 0.23 0.72 0.00 0.0
Ch 0.01 0.02 0.05 0.33 0.00 0.23 1.00 0.51 0.00 0.0
Mn 0.03 0.04 0.08 0.31 0.00 0.72 0.51 1.00 0.00 0.0
To 0.21 0.17 0.10 0.00 0.77 0.00 0.00 0.00 1.00 0.0
Ha 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 1.0
# scale point size to logpop
logpop = np.copy(log_pop)
logpop /= logpop.max()
psize = np.exp(logpop*5.5)
# Expected tool count over a grid of log-population values, one row per
# posterior draw of α and β (the GP term f is not included here).
log_pop_seq = np.linspace(6, 14, 100)
lambda_post = np.exp(trace_islands['α'][:, None] +
                     trace_islands['β'][:, None] * log_pop_seq)

_, ax = plt.subplots(1, 2, figsize=(12, 6))

# Left: islands by longitude/latitude; right: tools against log-population.
ax[0].scatter(islands.lon2, islands.lat, psize, zorder=3)
ax[1].scatter(islands.logpop, islands.total_tools, psize, zorder=3)

for i, itext in enumerate(culture_labels):
    ax[0].text(islands.lon2[i]+1, islands.lat[i]+1, itext)
    ax[1].text(islands.logpop[i]+.1, islands.total_tools[i]-2.5, itext)


ax[1].plot(log_pop_seq, np.median(lambda_post, axis=0), 'k--')

# NOTE(review): no `ax=` is passed, so the interval is drawn on whatever axes is
# current — presumably the right-hand panel; confirm.
az.plot_hpd(log_pop_seq, lambda_post, fill_kwargs={'alpha':0},
            plot_kwargs={'color':'k', 'ls':'--', 'alpha':1})


# Connect every pair of islands with a line whose opacity is the squared
# posterior-median correlation between them.
for i in range(10):
    for j in np.arange(i+1, 10):
        ax[0].plot((islands.lon2[i], islands.lon2[j]),
                   (islands.lat[i], islands.lat[j]), 'C1-',
                   alpha=ρ.iloc[i, j]**2, lw=4)
        ax[1].plot((islands.logpop[i], islands.logpop[j]),
                   (islands.total_tools[i], islands.total_tools[j]), 'C1-',
                   alpha=ρ.iloc[i, j]**2, lw=4)
ax[0].set_xlabel('longitude')
ax[0].set_ylabel('latitude')


ax[1].set_xlabel('log-population')
ax[1].set_ylabel('total tools')
ax[1].set_xlim(6.8, 12.8)
ax[1].set_ylim(10, 73)
plt.savefig('B11197_07_10.png')
../_images/07-gaussian-process_36_0.png

Gaussian process classification

# Iris measurements with a `species` label column.
iris = pd.read_csv('../data/iris.csv')
iris.head()
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
# Two-class problem: keep setosa and versicolor, encode species as 0/1, and use
# sepal length (as a column vector) as the only predictor.
df = iris.query("species == ('setosa', 'versicolor')")
y = pd.Categorical(df['species']).codes
x_1 = df['sepal_length'].values
X_1 = x_1[:, None]
with pm.Model() as model_iris:
    #ℓ = pm.HalfCauchy("ℓ", 1)
    ℓ = pm.Gamma('ℓ', 2, 0.5)
    cov = pm.gp.cov.ExpQuad(1, ℓ)
    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior("f", X=X_1)
    # logistic inverse link function and Bernoulli likelihood
    y_ = pm.Bernoulli("y", p=pm.math.sigmoid(f), observed=y)
    trace_iris = pm.sample(1000, chains=1, compute_convergence_checks=False)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [f_rotated_, ℓ]
100%|██████████| 1500/1500 [00:42<00:00, 35.10it/s]
# Column vector of 200 evenly spaced sepal lengths from floor(min) to ceil(max).
X_new = np.linspace(np.floor(x_1.min()), np.ceil(x_1.max()), 200)[:, None]

with model_iris:
    # GP conditional at the new inputs, then 1000 posterior-predictive draws of it.
    f_pred = gp.conditional('f_pred', X_new)
    pred_samples = pm.sample_posterior_predictive(
        trace_iris, vars=[f_pred], samples=1000)
100%|██████████| 1000/1000 [00:22<00:00, 44.10it/s]
def find_midpoint(array1, array2, value):
    """Return the ``array2`` value midway across the point where ``array1`` meets ``value``.

    The entry of ``array1`` closest to ``value`` is located, its neighbour on
    the side where ``value`` lies is selected, and the two corresponding
    entries of ``array2`` are averaged.

    Parameters
    ----------
    array1 : array-like
        Curve values; the neighbour selection assumes they increase with index.
    array2 : sequence
        Positions aligned with ``array1`` (same length).
    value : float
        Level to look for in ``array1``.

    Returns
    -------
    Average of ``array2`` at the closest index and at its neighbour. When the
    closest index sits on either end of the array and the neighbour would fall
    outside it, the neighbour is clamped to that end, so ``array2`` at the end
    point is returned.
    """
    array1 = np.asarray(array1)
    idx0 = np.argsort(np.abs(array1 - value))[0]
    idx1 = idx0 - 1 if array1[idx0] > value else idx0 + 1
    # Clamp to the valid index range. Without the lower bound, idx1 == -1 would
    # silently wrap around and average with the *last* element.
    idx1 = min(max(idx1, 0), len(array1) - 1)
    return (array2[idx0] + array2[idx1]) / 2
_, ax = plt.subplots(figsize=(10, 6))

# Map the sampled latent functions to probabilities and average over draws.
# NOTE(review): `logistic` is not bound in the visible notebook — its import
# (scipy.special.expit) is commented out in the setup cell; confirm it is in scope.
fp = logistic(pred_samples['f_pred'])
fp_mean = np.mean(fp, 0)

ax.plot(X_new[:, 0], fp_mean)
# plot the data (with some jitter) and the true latent function
ax.scatter(x_1, np.random.normal(y, 0.02),
           marker='.', color=[f'C{x}' for x in y])

az.plot_hpd(X_new[:, 0], fp, color='C2')

# Decision boundary: for each draw, the input at which the probability crosses
# 0.5; shown as its mean (vertical line) and HPD interval (shaded band).
db = np.array([find_midpoint(f, X_new[:, 0], 0.5) for f in fp])
db_mean = db.mean()
db_hpd = az.hpd(db)
ax.vlines(db_mean, 0, 1, color='k')
ax.fill_betweenx([0, 1], db_hpd[0], db_hpd[1], color='k', alpha=0.5)
ax.set_xlabel('sepal_length')
ax.set_ylabel('θ', rotation=0)
plt.savefig('B11197_07_11.png')
../_images/07-gaussian-process_43_0.png
# Same classifier with a richer kernel: exponentiated quadratic plus a scaled
# linear term, plus a small white-noise term on the diagonal.
with pm.Model() as model_iris2:
    #ℓ = pm.HalfCauchy("ℓ", 1)
    ℓ = pm.Gamma('ℓ', 2, 0.5)
    c = pm.Normal('c', x_1.min())
    τ = pm.HalfNormal('τ', 5)
    cov = (pm.gp.cov.ExpQuad(1, ℓ) +
           τ * pm.gp.cov.Linear(1, c) +
           pm.gp.cov.WhiteNoise(1E-5))
    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior("f", X=X_1)
    # logistic inverse link function and Bernoulli likelihood
    y_ = pm.Bernoulli("y", p=pm.math.sigmoid(f), observed=y)
    trace_iris2 = pm.sample(1000, chains=1, compute_convergence_checks=False)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [f_rotated_, τ, c, ℓ]
100%|██████████| 1500/1500 [00:44<00:00, 33.30it/s]
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
with model_iris2:
    # GP conditional at `X_new`, then 1000 posterior-predictive draws of it.
    f_pred = gp.conditional('f_pred', X_new)
    pred_samples = pm.sample_posterior_predictive(trace_iris2,
                                                  vars=[f_pred],
                                                  samples=1000)
100%|██████████| 1000/1000 [00:23<00:00, 41.86it/s]
_, ax = plt.subplots(figsize=(10,6))

# Map the sampled latent functions to probabilities and average over draws.
# NOTE(review): `logistic` is not bound in the visible notebook — its import
# (scipy.special.expit) is commented out in the setup cell; confirm it is in scope.
fp = logistic(pred_samples['f_pred'])
fp_mean = np.mean(fp, 0)

# Observed classes with a little vertical jitter, coloured by class.
ax.scatter(x_1, np.random.normal(y, 0.02), marker='.', color=[f'C{ci}' for ci in y])

# Decision boundary: per-draw input where the probability crosses 0.5, shown as
# its mean (vertical line) and HPD interval (shaded band).
db = np.array([find_midpoint(f, X_new[:,0], 0.5) for f in fp])
db_mean = db.mean()
db_hpd = az.hpd(db)
ax.vlines(db_mean, 0, 1, color='k')
ax.fill_betweenx([0, 1], db_hpd[0], db_hpd[1], color='k', alpha=0.5)

ax.plot(X_new[:,0], fp_mean, 'C2', lw=3)
az.plot_hpd(X_new[:,0], fp, color='C2')

ax.set_xlabel('sepal_length')
ax.set_ylabel('θ', rotation=0)
plt.savefig('B11197_07_12.png')
../_images/07-gaussian-process_46_0.png
# Space-flu data: age (as a column vector) and a 0/1 `space_flu` indicator.
df_sf = pd.read_csv('../data/space_flu.csv')
age = df_sf.age.values[:, None]
space_flu = df_sf.space_flu

# Scatter of the indicator against age, with the 0/1 ticks relabelled.
ax = df_sf.plot.scatter('age', 'space_flu', figsize=(8, 5))
ax.set_yticks([0, 1])
ax.set_yticklabels(['healthy', 'sick'])
plt.savefig('B11197_07_13.png', bbox_inches='tight')
../_images/07-gaussian-process_47_0.png
# Latent-GP classifier for the space-flu indicator as a function of age.
with pm.Model() as model_space_flu:
    ℓ = pm.HalfCauchy('ℓ', 1)
    cov = pm.gp.cov.ExpQuad(1, ℓ) + pm.gp.cov.WhiteNoise(1E-5)
    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior('f', X=age)
    # logistic inverse link function and Bernoulli likelihood
    y_ = pm.Bernoulli('y', p=pm.math.sigmoid(f), observed=space_flu)
    trace_space_flu = pm.sample(
        1000, chains=1, compute_convergence_checks=False)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [f_rotated_, ℓ]
100%|██████████| 1500/1500 [00:11<00:00, 127.88it/s]
# Column vector of 200 evenly spaced ages between 0 and 80.
X_new = np.linspace(0, 80, 200)[:, None]

with model_space_flu:
    # GP conditional at the new ages, then 1000 posterior-predictive draws of it.
    f_pred = gp.conditional('f_pred', X_new)
    pred_samples = pm.sample_posterior_predictive(trace_space_flu,
                                                  vars=[f_pred],
                                                  samples=1000)
100%|██████████| 1000/1000 [00:17<00:00, 56.40it/s]
_, ax = plt.subplots(figsize=(10, 6))

# Map the sampled latent functions to probabilities; the mean over draws
# ignores NaN entries (nanmean).
# NOTE(review): `logistic` is not bound in the visible notebook — its import
# (scipy.special.expit) is commented out in the setup cell; confirm it is in scope.
fp = logistic(pred_samples['f_pred'])
fp_mean = np.nanmean(fp, 0)

# Observed outcomes with a little vertical jitter, coloured by outcome.
ax.scatter(age, np.random.normal(space_flu, 0.02),
           marker='.', color=[f'C{ci}' for ci in space_flu])

ax.plot(X_new[:, 0], fp_mean, 'C2', lw=3)

az.plot_hpd(X_new[:, 0], fp, color='C2')
ax.set_yticks([0, 1])
ax.set_yticklabels(['healthy', 'sick'])
ax.set_xlabel('age')
plt.savefig('B11197_07_14.png')
../_images/07-gaussian-process_50_0.png

The coal-mining disasters

# Single-column table with no header row; each value is a decimal year
# (presumably the date of one disaster — confirm against the data source).
coal_df = pd.read_csv('../data/coal.csv', header=None)
coal_df.head()
0
0 1851.2026
1 1851.6324
2 1851.9692
3 1851.9747
4 1852.3142
# discretize data into bins roughly four years wide
years = int(coal_df.max().values - coal_df.min().values)
bins = years // 4
hist, x_edges = np.histogram(coal_df, bins=bins)
# compute the location of the centers of the discretized data
x_centers = x_edges[:-1] + (x_edges[1] - x_edges[0]) / 2
# arrange xdata into proper shape for GP
x_data = x_centers[:, None]
# express the data as a rate: number of disasters per year
y_data = hist / 4
# Poisson model whose log-rate is a latent GP over time.
with pm.Model() as model_coal:
    ℓ = pm.HalfNormal('ℓ', x_data.std())
    cov = pm.gp.cov.ExpQuad(1, ls=ℓ) + pm.gp.cov.WhiteNoise(1E-5)
    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior('f', X=x_data)

    y_pred = pm.Poisson('y_pred', mu=pm.math.exp(f), observed=y_data)
    trace_coal = pm.sample(1000, chains=1)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [f_rotated_, ℓ]
100%|██████████| 1500/1500 [00:13<00:00, 127.51it/s]
Only one chain was sampled, this makes it impossible to run some convergence checks
_, ax = plt.subplots(figsize=(10, 6))

# Exponentiate the sampled latent function to get the rate, then take the
# median over draws.
f_trace = np.exp(trace_coal['f'])
rate_median = np.median(f_trace, axis=0)

ax.plot(x_centers, rate_median, 'w', lw=3)
az.plot_hpd(x_centers, f_trace)

# Second, narrower (50%) interval; its outline is made fully transparent.
az.plot_hpd(x_centers, f_trace, credible_interval=0.5,
            plot_kwargs={'alpha': 0})

# Tick marks below the curve at each raw data value.
ax.plot(coal_df, np.zeros_like(coal_df)-0.5, 'k|')
ax.set_xlabel('years')
ax.set_ylabel('rate')
plt.savefig('B11197_07_15.png')
../_images/07-gaussian-process_55_0.png

The redwood data

# Two-column table with no header row; columns 0 and 1 are plotted as the two
# spatial coordinates of each point.
rw_df = pd.read_csv('../data/redwood.csv', header=None)
_, ax = plt.subplots(figsize=(8, 8))
ax.plot(rw_df[0], rw_df[1], 'C0.')
ax.set_xlabel('x1 coordinate')
ax.set_ylabel('x2 coordinate')
plt.savefig('B11197_07_16.png')
../_images/07-gaussian-process_57_0.png
# discretize spatial data into a bins x bins grid of counts
bins = 20
hist, x1_edges, x2_edges = np.histogram2d(
    rw_df[1].values, rw_df[0].values, bins=bins)
# compute the location of the centers of the discretized data
x1_centers = x1_edges[:-1] + (x1_edges[1] - x1_edges[0]) / 2
x2_centers = x2_edges[:-1] + (x2_edges[1] - x2_edges[0]) / 2
# arrange xdata into proper shape for GP
x_data = [x1_centers[:, None], x2_centers[:, None]]
# arrange ydata into proper shape for GP
y_data = hist.flatten()
# Poisson model whose log-rate is a latent GP with one kernel, and one
# length-scale, per spatial axis.
with pm.Model() as model_rw:
    ℓ = pm.HalfNormal('ℓ',  rw_df.std().values, shape=2)
    cov_func1 = pm.gp.cov.ExpQuad(1, ls=ℓ[0])
    cov_func2 = pm.gp.cov.ExpQuad(1, ls=ℓ[1])

    gp = pm.gp.LatentKron(cov_funcs=[cov_func1, cov_func2])
    f = gp.prior('f', Xs=x_data)

    y = pm.Poisson('y', mu=pm.math.exp(f), observed=y_data)
    trace_rw = pm.sample(1000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [f_rotated_, ℓ]
Sampling 2 chains: 100%|██████████| 3000/3000 [00:58<00:00, 50.95draws/s]
The estimated number of effective samples is smaller than 200 for some parameters.
# Posterior summary of the two length-scales.
az.summary(trace_rw, var_names=['ℓ'])
mean sd mc error hpd 3% hpd 97% eff_n r_hat
ℓ[0] 0.13 0.04 0.0 0.08 0.19 153.0 1.0
ℓ[1] 0.09 0.03 0.0 0.05 0.14 217.0 1.0
# Mean of the sampled latent function over draws, reshaped back to the
# bins x bins grid and exponentiated to give a rate.
rate = np.exp(np.mean(trace_rw['f'], axis=0).reshape((bins, -1)))
fig, ax = plt.subplots(figsize=(6, 6))
ims = ax.imshow(rate, origin='lower')
ax.grid(False)
# Six ticks per axis, relabelled from grid indices to 0..1.
ticks_loc = np.linspace(0, bins-1, 6)
ticks_lab = np.linspace(0, 1, 6).round(1)
ax.set_xticks(ticks_loc)
ax.set_yticks(ticks_loc)
ax.set_xticklabels(ticks_lab)
ax.set_yticklabels(ticks_lab)
cbar = fig.colorbar(ims, fraction=0.046, pad=0.04)
plt.savefig('B11197_07_17.png')
../_images/07-gaussian-process_61_0.png