28  Parameter estimation with Markov chain Monte Carlo I

# Colab setup ------------------
import os, shutil, sys, subprocess, urllib.request
if "google.colab" in sys.modules:
    cmd = "pip install --upgrade polars iqplot colorcet bebi103 arviz cmdstanpy watermark"
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    from cmdstanpy.install_cmdstan import latest_version
    cmdstan_version = latest_version()
    cmdstan_url = f"https://github.com/stan-dev/cmdstan/releases/download/v{cmdstan_version}/"
    fname = f"colab-cmdstan-{cmdstan_version}.tgz"
    urllib.request.urlretrieve(cmdstan_url + fname, fname)
    shutil.unpack_archive(fname)
    os.environ["CMDSTAN"] = f"./cmdstan-{cmdstan_version}"
    data_path = "https://s3.amazonaws.com/bebi103.caltech.edu/data/"
else:
    data_path = "../data/"
# ------------------------------
import numpy as np
import scipy.stats as st
import polars as pl

import cmdstanpy
import arviz as az

import iqplot
import bebi103

import bokeh.io
import bokeh.plotting
bokeh.io.output_notebook()

In this lesson, we will learn how to use Markov chain Monte Carlo (MCMC) to do parameter estimation. To get the basic idea behind MCMC, imagine for a moment that we can draw samples out of the posterior distribution. This means that the probability of drawing a given set of parameter values is proportional to the posterior probability of that set of values. If we drew many such samples, we could reconstruct the posterior from the samples, e.g., by making histograms. That’s a big thing to imagine: that we can draw properly weighted samples. But it turns out that we can! That is what MCMC allows us to do.
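As a minimal sketch of this idea, suppose the posterior were a known Gamma distribution (an arbitrary stand-in chosen for illustration, not the posterior for the data in this lesson). Drawing many samples with NumPy and building a normalized histogram recovers the PDF to within sampling and binning error.

```python
import numpy as np
import scipy.stats as st

rng = np.random.default_rng(3252)

# Stand-in "posterior": Gamma with shape 5 and rate 2 (arbitrary choice)
shape, rate = 5.0, 2.0
draws = rng.gamma(shape, 1 / rate, size=100_000)

# Normalized histogram of the draws approximates the PDF
heights, edges = np.histogram(draws, bins=50, density=True)
centers = (edges[:-1] + edges[1:]) / 2
pdf = st.gamma.pdf(centers, shape, loc=0, scale=1 / rate)

# Largest discrepancy between the histogram and the true PDF
max_abs_err = np.max(np.abs(heights - pdf))
print(max_abs_err)
```

Here the draws come from NumPy's direct Gamma sampler. MCMC is what lets us obtain such draws when no direct sampler is available.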

We already discussed some theory behind this seemingly miraculous capability. For this lesson, we will just use the fact that we can do the sampling to learn about posterior distributions in the context of parameter estimation.

28.1 The data set

The data set we will use is synthetic, but it will serve to illustrate parameter estimation techniques. The synthetic experiment is as described in Section 27.2. In the experiment, we take a sample of retinal tissue and expose it to a constant light source. We measure the spiking activity of a single retinal ganglion cell (RGC) for one minute, recording \(n\) spikes.

# Load in as dataframe
df = pl.read_csv(os.path.join(data_path, 'rgc_spike_times_1.csv'))

# Spike times in milliseconds for convenience
spike_times = df['spike time (ms)'].to_numpy()

# Interspike intervals
y = np.concatenate(((spike_times[0],), np.diff(spike_times)))

# Make ECDF
bokeh.io.show(
    iqplot.ecdf(y, 'interspike interval (ms)')
)

Just glancing at this, we see that we have plenty of interspike intervals and they appear to be Exponentially distributed.
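To make the interval calculation concrete, here is a small sketch with made-up spike times (hypothetical values, not from the data set). Taking time zero as the start of the recording, the first interval is just the first spike time. Calling `np.diff` with `prepend=0.0` is an equivalent way to write the concatenation used above.

```python
import numpy as np

# Hypothetical spike times in ms, already sorted
toy_spike_times = np.array([12.0, 40.0, 45.0, 90.0])

# First interval is measured from time zero; the rest are differences
y_concat = np.concatenate(((toy_spike_times[0],), np.diff(toy_spike_times)))

# Equivalent: prepend a zero before differencing
y_prepend = np.diff(toy_spike_times, prepend=0.0)

print(y_concat)
print(y_prepend)
```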

28.2 Statistical model, take 1

We will now formulate a generative model. We will choose an Exponential likelihood, assuming that spike arrivals in these constant conditions are best modeled as a Poisson process. Let \(t_i\) be the time at which spike \(i\) arrives (with \(t_0 = 0\) by definition), and define interspike interval \(i\) as \(y_i = t_i - t_{i-1}\). Then, our likelihood is

\[\begin{align} y_i \sim \text{Expon}(\beta)\;\forall i. \end{align} \]
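The link between a Poisson process and Exponentially distributed intervals can be checked numerically. The sketch below simulates a homogeneous Poisson process with an assumed rate of 0.025 spikes per ms (a made-up value) over one minute and compares the resulting intervals to the Exponential distribution.

```python
import numpy as np
import scipy.stats as st

rng = np.random.default_rng(3252)

rate = 0.025      # assumed spiking rate in 1/ms (hypothetical)
T = 60_000.0      # duration of the recording in ms

# For a Poisson process, the number of arrivals in [0, T] is Poisson
# distributed and, given that number, arrival times are uniform on [0, T]
n_spikes = rng.poisson(rate * T)
t = np.sort(rng.uniform(0, T, size=n_spikes))

# Interspike intervals, with the first measured from time zero
isi = np.diff(t, prepend=0.0)

# Mean interval should be close to 1 / rate = 40 ms
mean_isi = isi.mean()

# Kolmogorov-Smirnov distance to an Exponential with the assumed rate
ks_stat = st.kstest(isi, "expon", args=(0, 1 / rate)).statistic
print(mean_isi, ks_stat)
```

A small Kolmogorov-Smirnov distance indicates that the empirical distribution of the simulated intervals lies close to the Exponential CDF.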

We need to specify a prior for the rate parameter \(\beta\). The rate parameter must be positive, so we will choose a distribution defined on positive numbers. We will choose to use a Gamma distribution. I do not know much about spiking, so I will choose a weakly informative prior with 95% of the probability mass lying between a spiking rate of 1 and 1000 Hz. Using the Distribution Explorer’s quantile setting tool, I find that I should have Gamma parameters \(a = 0.65\) and \(b = 2.9\times 10^{-3}\) Hz\(^{-1}\). (The parameter \(b\) multiplies \(\beta\) in the exponent of the Gamma PDF, so its units are the inverse of those of \(\beta\).) That is fine if I am going to work in units of seconds, but my data are in units of milliseconds, so I should express \(\beta\) in units of kHz to be consistent. By the change of variables formula, I should use \(b = 2.9\) kHz\(^{-1}\). So, my generative model, with the understanding that all units are consistent with the interspike intervals being in units of milliseconds, is

\[\begin{align} &\beta \sim \text{Gamma}(0.65, 2.9),\\[1em] &y_i \sim \text{Expon}(\beta)\;\forall i. \end{align} \]
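The quoted prior parameters can be checked with `scipy.stats`, which parametrizes the Gamma distribution by a scale equal to \(1/b\). The sketch below computes the central 95% interval of the prior with \(\beta\) in Hz and confirms that multiplying \(b\) by 1000 gives the same prior with \(\beta\) expressed in kHz. The variable names are chosen here for illustration.

```python
import numpy as np
import scipy.stats as st

a = 0.65
b_Hz = 2.9e-3   # Gamma rate parameter when beta is in Hz
b_kHz = 2.9     # Gamma rate parameter when beta is in kHz

# Central 95% interval of the prior, beta in Hz
lo_Hz, hi_Hz = st.gamma.ppf([0.025, 0.975], a, loc=0, scale=1 / b_Hz)

# Same interval, beta in kHz
lo_kHz, hi_kHz = st.gamma.ppf([0.025, 0.975], a, loc=0, scale=1 / b_kHz)

# Interval in Hz, then the kHz interval converted back to Hz
print(lo_Hz, hi_Hz)
print(lo_kHz * 1000, hi_kHz * 1000)
```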

We can code this model up in Stan as follows.

data {
    int<lower=2> n;
    array[n] real spike_times;
}


transformed data {
    // Parameters for the prior distribution for beta
    real a = 0.65;
    real b = 2.9;

    // Sorted spike times
    array[n] real t = sort_asc(spike_times);

    // Interspike intervals
    array[n] real y;
    y[1] = t[1];
    for (i in 2:n) {
        y[i] = t[i] - t[i-1];
    }
}


parameters {
    real<lower=0> beta_;
}


model {
    beta_ ~ gamma(a, b);
    y ~ exponential(beta_);
}

Let’s run the model and see what we get for a posterior!

data = dict(n=len(spike_times), spike_times=spike_times)

with bebi103.stan.disable_logging():
    sm = cmdstanpy.CmdStanModel(stan_file='rgc_spike_times_expon.stan')
    samples = az.from_cmdstanpy(sm.sample(data=data))

We can plot our MCMC samples as an ECDF or histogram. The binning bias of a histogram is less of an issue with MCMC samples because we can take arbitrarily many of them. So, let’s just plot a histogram.

p = iqplot.histogram(
    samples.posterior.beta_.values.ravel(),
    rug=False,
    density=True,
    frame_width=400,
    frame_height=200,
    x_axis_label="β (1/ms)",
    y_axis_label="g(β | y)",
)

bokeh.io.show(p)

We have estimated our rate as about 0.024 kHz, or about 24 Hz. Conveniently, for our choice of prior, we may write the posterior PDF down analytically, because the Gamma distribution is conjugate to the Exponential likelihood. You will prove this in an exercise. Given that our prior is Gamma\((a, b)\), the posterior is

\[\begin{align} g(\beta\mid y) = \frac{(b+n\bar{y})^{a+n}}{\Gamma(a+n)}\,\beta^{a + n-1}\,\mathrm{e}^{-(b+n\bar{y})\beta}, \end{align} \]

where

\[\begin{align} \bar{y} = \frac{1}{n}\sum_{i=1}^n y_i \end{align} \]

is the mean interspike interval. Just to verify that we are sampling properly, I can overlay this PDF.

# Analytical posterior
a = 0.65
b = 2.9
beta_theor = np.linspace(0.02, 0.03, 400)
posterior = st.gamma.pdf(beta_theor, len(y) + a, loc=0, scale=1/(np.sum(y) + b))

# Add to plot
p.line(beta_theor, posterior, line_width=2, color='tomato')

bokeh.io.show(p)

Indeed, the PDFs match!
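As an independent check on the analytical expression itself, we can multiply prior and likelihood on a grid of \(\beta\) values, normalize numerically, and compare with the Gamma\((a + n, b + n\bar{y})\) PDF. The sketch below uses synthetic intervals drawn with an assumed mean of 40 ms (made-up data, not the data set of this lesson).

```python
import numpy as np
import scipy.stats as st

rng = np.random.default_rng(3252)

# Prior parameters from the model; synthetic intervals in ms (made up)
a, b = 0.65, 2.9
y_syn = rng.exponential(scale=40.0, size=200)
n = len(y_syn)

# Uniform grid of beta values covering essentially all posterior mass
beta = np.linspace(0.01, 0.05, 4001)

# Unnormalized log posterior: log prior + Exponential log likelihood
log_prior = st.gamma.logpdf(beta, a, loc=0, scale=1 / b)
log_like = n * np.log(beta) - beta * y_syn.sum()
log_post = log_prior + log_like

# Normalize numerically with a Riemann sum on the uniform grid
post_grid = np.exp(log_post - log_post.max())
post_grid /= post_grid.sum() * (beta[1] - beta[0])

# Analytical posterior: Gamma(a + n, b + sum of y)
post_exact = st.gamma.pdf(beta, a + n, loc=0, scale=1 / (b + y_syn.sum()))

# Largest discrepancy relative to the peak of the PDF
max_rel_err = np.max(np.abs(post_grid - post_exact)) / post_exact.max()
print(max_rel_err)
```

Working with the log posterior and subtracting its maximum before exponentiating avoids numerical underflow, since the likelihood is a product of many small numbers.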

28.3 Summarizing the posterior

In the next chapter (Chapter 29), we discuss how to concisely report summaries of the posterior. In this case, the posterior is simple, and the above plot works very well. We could alternatively report the median and central 95% credible interval.

med = np.median(samples.posterior.beta_)
cred_int = np.percentile(samples.posterior.beta_, [2.5, 97.5])

print('[{0:.4f}, {1:.4f}, {2:.4f}] kHz'.format(cred_int[0], med, cred_int[1]))
[0.0227, 0.0238, 0.0251] kHz
bebi103.stan.clean_cmdstan()
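The median-and-interval summary can be wrapped in a small helper. The function name `median_and_cred_int` is hypothetical, and the draws below are stand-ins generated directly from a Gamma distribution with arbitrarily chosen parameters rather than actual MCMC output. The sketch compares the sample-based summary with the exact quantiles of that distribution.

```python
import numpy as np
import scipy.stats as st


def median_and_cred_int(draws, mass=0.95):
    """Median and central credible interval containing `mass` of the draws."""
    tail = 100 * (1 - mass) / 2
    low, med, high = np.percentile(np.ravel(draws), [tail, 50, 100 - tail])
    return med, (low, high)


# Stand-in draws (not MCMC output), shaped like 4 chains of 1000 draws
rng = np.random.default_rng(3252)
shape, rate = 1500.0, 63000.0
draws = rng.gamma(shape, 1 / rate, size=(4, 1000))

med, (low, high) = median_and_cred_int(draws)

# Exact 2.5th, 50th, and 97.5th percentiles for comparison
exact = st.gamma.ppf([0.025, 0.5, 0.975], shape, loc=0, scale=1 / rate)

print(f"[{low:.4f}, {med:.4f}, {high:.4f}]")
print(exact)
```

With the samples from this lesson, the analogous call would be `median_and_cred_int(samples.posterior.beta_.values)`.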

28.4 Computing environment

%load_ext watermark
%watermark -v -p numpy,scipy,polars,cmdstanpy,arviz,bokeh,iqplot,bebi103,jupyterlab
print("cmdstan   :", bebi103.stan.cmdstan_version())
Python implementation: CPython
Python version       : 3.12.11
IPython version      : 9.1.0

numpy     : 2.1.3
scipy     : 1.15.3
polars    : 1.30.0
cmdstanpy : 1.2.5
arviz     : 0.21.0
bokeh     : 3.6.2
iqplot    : 0.3.7
bebi103   : 0.1.27
jupyterlab: 4.3.7

cmdstan   : 2.36.0