import os

import numpy as np
import scipy.stats as st
import polars as pl

import cmdstanpy
import arviz as az

import iqplot
import bebi103

import bokeh.io
import bokeh.plotting

bokeh.io.output_notebook()
After performing posterior predictive checks, we are unsatisfied with modeling spiking as a Poisson process and need to either update or rethink our model. One option that is prevalent in the literature is to model the ISIs as Inverse Gaussian distributed. Let us try that. Our likelihood is

\begin{align}
y_i \sim \text{InvGauss}(\mu, \lambda) \;\forall i,
\end{align}

where the Inverse Gaussian PDF is

\begin{align}
f(y;\mu, \lambda) = \sqrt{\frac{\lambda}{2\pi y^3}}\,\exp\left[-\frac{\lambda(y-\mu)^2}{2\mu^2 y}\right].
\end{align}
The parameter \(\mu\) is the mean ISI. We will choose a weakly informative prior for it, a Log Normal distribution with most of its probability mass between 1 and 1000 milliseconds. The variance of the Inverse Gaussian distribution is \(\mu^3/\lambda\). I am unsure about reasonable values of \(\lambda\), so I will also choose a Log Normal prior for it and allow it to vary over a few orders of magnitude. So, our model is

\begin{align}
&\log_{10} \mu \sim \text{Norm}(1.5, 1.5),\\[1em]
&\log_{10} \lambda \sim \text{Norm}(1.5, 1.5),\\[1em]
&y_i \sim \text{InvGauss}(\mu, \lambda) \;\forall i,
\end{align}

with \(\mu\) and \(\lambda\) in units of milliseconds.
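As a rough side calculation (not part of the original lesson) that such a prior behaves as advertised, we can compute the central ~68% interval of the prior for \(\mu\), using the location 1.5 and scale 1.5 on a \(\log_{10}\) scale that appear as mu_mu and sigma_mu in the Stan code below.

# Central ~68% prior interval for mu in ms, assuming log10(mu) ~ Normal(1.5, 1.5)
lower, upper = 10 ** st.norm.ppf([0.16, 0.84], loc=1.5, scale=1.5)

print(lower, upper)   # roughly 1 ms and 1000 ms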
Stan does not have the Inverse Gaussian among its built-in distributions. Therefore, we need to write our own functions, such as inv_gaussian_lpdf() to compute the log PDF of the distribution and inv_gaussian_rng() to generate random numbers from it. These functions are available from the Distribution Explorer, so we can copy and paste them into the functions block of our Stan code. (We do that here for pedagogical purposes. It is preferable instead to have a file called, e.g., inv_gaussian.stanfunctions and to then have #include inv_gaussian.stanfunctions in the functions block.)
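As a sketch of that preferred setup (assuming the functions live in a file named inv_gaussian.stanfunctions somewhere stanc can find it, e.g., by passing an include path to the compiler via cmdstanpy's stanc_options), the functions block would shrink to a single include:

functions {
  // Pull in inv_gaussian_lpdf, inv_gaussian_lcdf, inv_gaussian_lccdf,
  // and inv_gaussian_rng from a separate file
  #include inv_gaussian.stanfunctions
}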
Our new Stan code is shown below.
functions {
  /* Functions for inverse Gaussian distribution. */

  /**
   * Log PDF of Inverse Gaussian distribution
   *
   * @param real y
   * @param real mu
   * @param real lambda
   * @return log PDF of Inverse Gaussian distribution for scalar input
   */
  real inv_gaussian_lpdf(real y, real mu, real lambda) {
    if (y <= 0 || is_nan(y)) {
      reject("inverse_gaussian_lpdf: y must be greater than 0; found y = ", y);
    }
    if (mu <= 0 || is_nan(mu)) {
      reject("inverse_gaussian_lpdf: mu must be greater than 0; found mu = ", mu);
    }
    if (lambda <= 0 || is_nan(lambda)) {
      reject("inverse_gaussian_lpdf: lambda must be greater than 0; found lambda = ", lambda);
    }

    real logpdf = -log(2 * pi()) / 2.0
                  + (log(lambda) - 3.0 * log(y)) / 2.0
                  - lambda * (y - mu)^2 / (2.0 * mu^2 * y);

    return logpdf;
  }

  /**
   * Log PDF of Inverse Gaussian distribution
   *
   * @param array[] real y
   * @param real mu
   * @param real lambda
   * @return real log PDF of Inverse Gaussian distribution for array input
   */
  real inv_gaussian_lpdf(array[] real y, real mu, real lambda) {
    if (mu <= 0 || is_nan(mu)) {
      reject("inverse_gaussian_lpdf: mu must be greater than 0; found mu = ", mu);
    }
    if (lambda <= 0 || is_nan(lambda)) {
      reject("inverse_gaussian_lpdf: lambda must be greater than 0; found lambda = ", lambda);
    }

    real logpdf = -num_elements(y) * log(2 * pi()) / 2.0;
    for (yi in y) {
      if (yi <= 0 || is_nan(yi)) {
        reject("inverse_gaussian_lpdf: all elements of y must be greater than 0 and not NaN, got ", yi);
      }
      else {
        logpdf += (log(lambda) - 3.0 * log(yi)) / 2.0 - lambda * (yi - mu)^2 / (2.0 * mu^2 * yi);
      }
    }

    return logpdf;
  }

  /**
   * Log PDF of Inverse Gaussian distribution
   *
   * @param vector y
   * @param real mu
   * @param real lambda
   * @return real log PDF of Inverse Gaussian distribution for vector input
   */
  real inv_gaussian_lpdf(vector y, real mu, real lambda) {
    if (mu <= 0 || is_nan(mu)) {
      reject("inverse_gaussian_lpdf: mu must be greater than 0; found mu = ", mu);
    }
    if (lambda <= 0 || is_nan(lambda)) {
      reject("inverse_gaussian_lpdf: lambda must be greater than 0; found lambda = ", lambda);
    }
    for (yi in y) {
      if (yi <= 0 || is_nan(yi)) {
        reject("inverse_gaussian_lpdf: all elements of y must be greater than 0 and not NaN, got ", yi);
      }
    }

    real logpdf = num_elements(y) * (log(lambda) - log(2 * pi())) / 2.0
                  - 1.5 * sum(log(y))
                  - lambda * sum((y - mu).^2 ./ y) / (2.0 * mu^2);

    return logpdf;
  }

  /**
   * Log CDF of Inverse Gaussian distribution
   *
   * @param real y
   * @param real mu
   * @param real lambda
   * @return log CDF of Inverse Gaussian distribution for scalar input
   */
  real inv_gaussian_lcdf(real y, real mu, real lambda) {
    if (y <= 0 || is_nan(y)) {
      reject("inverse_gaussian_lcdf: y must be greater than 0; found y = ", y);
    }
    if (mu <= 0 || is_nan(mu)) {
      reject("inverse_gaussian_lcdf: mu must be greater than 0; found mu = ", mu);
    }
    if (lambda <= 0 || is_nan(lambda)) {
      reject("inverse_gaussian_lcdf: lambda must be greater than 0; found lambda = ", lambda);
    }

    real term1 = std_normal_lcdf(sqrt(lambda / y) * (y / mu - 1.0));
    real term2 = 2.0 * lambda / mu + std_normal_lcdf(-sqrt(lambda / y) * (y / mu + 1.0));

    return log_sum_exp(term1, term2);
  }

  /**
   * Log CCDF of Inverse Gaussian distribution
   *
   * @param real y
   * @param real mu
   * @param real lambda
   * @return log CCDF of Inverse Gaussian distribution
   */
  real inv_gaussian_lccdf(real y, real mu, real lambda) {
    if (y <= 0 || is_nan(y)) {
      reject("inverse_gaussian_lccdf: y must be greater than 0; found y = ", y);
    }
    if (mu <= 0 || is_nan(mu)) {
      reject("inverse_gaussian_lccdf: mu must be greater than 0; found mu = ", mu);
    }
    if (lambda <= 0 || is_nan(lambda)) {
      reject("inverse_gaussian_lccdf: lambda must be greater than 0; found lambda = ", lambda);
    }

    real term1 = std_normal_lccdf(sqrt(lambda / y) * (y / mu - 1.0));
    real term2 = 2.0 * lambda / mu + std_normal_lcdf(-sqrt(lambda / y) * (y / mu + 1.0));

    return log_diff_exp(term1, term2);
  }

  /**
   * Draw a random number from Inverse Gaussian distribution
   *
   * @param real mu
   * @param real lambda
   * @return A random number drawn from Inverse Gaussian distribution
   */
  real inv_gaussian_rng(real mu, real lambda) {
    if (mu <= 0 || is_nan(mu)) {
      reject("inverse_gaussian_rng: mu must be greater than 0; found mu = ", mu);
    }
    if (lambda <= 0 || is_nan(lambda)) {
      reject("inverse_gaussian_rng: lambda must be greater than 0; found lambda = ", lambda);
    }

    real y = std_normal_rng()^2;
    real mu2 = mu^2;
    real x = mu + (mu2 * y - mu * sqrt(4.0 * mu * lambda * y + mu2 * y^2)) / (2.0 * lambda);
    real z = uniform_rng(0, 1);

    real return_value;
    if (z <= mu / (mu + x)) {
      return_value = x;
    } else {
      return_value = mu2 / x;
    }

    return return_value;
  }
}


data {
  int<lower=2> n;
  array[n] real spike_times;
}


transformed data {
  // Parameters for the prior distributions
  real mu_mu = 1.5;
  real sigma_mu = 1.5;
  real mu_lambda = 1.5;
  real sigma_lambda = 1.5;

  // Sorted spike times
  array[n] real t = sort_asc(spike_times);

  // Interspike intervals
  array[n] real y;
  y[1] = t[1];
  for (i in 2:n) {
    y[i] = t[i] - t[i-1];
  }
}


parameters {
  real log_mu;
  real log_lambda;
}


transformed parameters {
  real mu = 10 ^ log_mu;
  real lambda = 10 ^ log_lambda;
}


model {
  log_mu ~ normal(mu_mu, sigma_mu);
  log_lambda ~ normal(mu_lambda, sigma_lambda);

  y ~ inv_gaussian(mu, lambda);
}


generated quantities {
  array[n] real y_ppc;

  for (i in 1:n) {
    y_ppc[i] = inv_gaussian_rng(mu, lambda);
  }
}
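The inv_gaussian_rng() function above uses a transformation-with-rejection scheme. As a quick sanity check (not part of the original lesson; the test values of 30 ms for μ and 50 ms for λ are arbitrary), we can re-implement the same steps with NumPy and confirm that the draws have mean ≈ μ and variance ≈ μ³/λ.

# Re-implementation of the steps in inv_gaussian_rng() with NumPy
rng = np.random.default_rng(seed=3252)
mu_test, lam_test = 30.0, 50.0

nu = rng.normal(size=100_000) ** 2
x = mu_test + (
    mu_test**2 * nu - mu_test * np.sqrt(4 * mu_test * lam_test * nu + mu_test**2 * nu**2)
) / (2 * lam_test)
z = rng.uniform(size=len(nu))
draws = np.where(z <= mu_test / (mu_test + x), x, mu_test**2 / x)

print(np.mean(draws), mu_test)              # mean should be ≈ mu
print(np.var(draws), mu_test**3 / lam_test) # variance should be ≈ mu^3 / lambda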
Let’s grab our samples, including the posterior predictive samples!
# Load in as dataframe
df = pl.read_csv(os.path.join(data_path, 'rgc_spike_times_1.csv'))

# Spike times in milliseconds for convenience
spike_times = df['spike time (ms)'].to_numpy()

# Interspike intervals
y = np.concatenate(((spike_times[0],), np.diff(spike_times)))

# Prep for Stan
data = dict(n=len(spike_times), spike_times=spike_times)

with bebi103.stan.disable_logging():
    sm = cmdstanpy.CmdStanModel(stan_file='rgc_spike_times_inv_gaussian.stan')
    samples = az.from_cmdstanpy(sm.sample(data=data), posterior_predictive='y_ppc')
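Before looking at the posterior, it is worth a quick pass over the sampler diagnostics and parameter summaries. This is a sketch, not part of the original lesson; bebi103.stan.check_all_diagnostics() is assumed to be available in the bebi103 package, and az.summary() comes from ArviZ.

# Check MCMC diagnostics (divergences, R-hat, ESS, tree depth, E-BFMI);
# check_all_diagnostics is assumed from the bebi103 package
bebi103.stan.check_all_diagnostics(samples)

# Summaries of the marginal posteriors for mu and lambda
az.summary(samples, var_names=["mu", "lambda"])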
31.2 Corner plots
We now have two parameters, so visualizing the posterior distribution is not as simple as plotting a single histogram or ECDF. We instead make a corner plot, which shows the posterior marginalized down to each single parameter and to each pair of parameters.
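As a sketch of how that might look with the tools imported above (the bebi103.viz.corner() call signature here is an assumption, so check the bebi103 documentation), we could display the corner plot for the two parameters like this:

# Corner plot of the marginal posteriors of mu and lambda
bokeh.io.show(
    bebi103.viz.corner(samples, parameters=["mu", "lambda"])
)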