Gaussian processes
Thus far, our MLMs have grouped observations within essentially nominal categories – even when we use numbers, we’re treating these categories as unordered and discrete. A participant id is just a label for a unique “thing”. In these models, partial pooling does double duty: it improves accuracy by borrowing information from other groups, and it estimates the variation across groups.
But consider other types of grouping that are not so distinct – McElreath shows examples using spatial distance, but time and age work the same way. Individuals of the same age share some exposure (cultural trends, political and historical events, even climate); they also share exposure with people of similar ages. And even though we assign labels to generations, the similarity between two people is probably better captured by the difference in their ages than by whether they share a generational label.
In this case, it wouldn’t make as much sense to fit a separate intercept for individuals of the same age, because the model would borrow equally from all other groups. Instead, we would rather the model borrow information in proportion to the closeness of other groups (e.g., intercepts for 27-year-olds should be more informed by data from 28- and 26-year-olds than 67-year-olds).
The general approach to this is known as GAUSSIAN PROCESS REGRESSION.
Gaussian processes can be useful for answering questions in which observations that are closer together – in age, in time, or in space – should be more similar to one another.
From the brms manual:
A GP is a stochastic process, which describes the relation between one or more predictors $x = (x_1, \ldots, x_d)$ and a response $f(x)$, where $d$ is the number of predictors. A GP is the generalization of the multivariate normal distribution to an infinite number of dimensions. Thus, it can be interpreted as a prior over functions. The values of $f()$ at any finite set of locations are jointly multivariate normal, with a covariance matrix defined by the covariance kernel $k_p(x_i, x_j)$, where $p$ is the vector of parameters of the GP:
$$f(x_1), \ldots, f(x_n) \sim \text{MVN}\left(0, \left(k_p(x_i, x_j)\right)_{i,j=1}^{n}\right)$$
The smoothness and general behavior of the function $f$ depend only on the choice of the covariance kernel. In plain English, a kernel is like a “relationship measurer” between any two points in your data. It answers the question: “If I know the value at point A, how much does that tell me about the value at point B?”
The kernel determines how influence spreads across your data. If two points are close together according to the kernel, their values will be similar. If two points are far apart, their values will be essentially uncorrelated (though not necessarily dissimilar).
Different kernels create different patterns of relationships. For example:
The mathematical formula of the kernel determines exactly how this similarity decays with distance, time, or whatever dimension you’re working with.
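To make the “prior over functions” idea concrete, here’s a small sketch of my own (not from the course materials) that draws a few functions from a GP prior over age, using the exponentiated quadratic kernel described below; the magnitude and length scale (1 and 5) are arbitrary choices.

library(MASS)       # for mvrnorm()
library(tidyverse)

# exponentiated quadratic kernel: sigma^2 * exp(-(xi - xj)^2 / (2 * l^2))
exp_quad_K <- function(x, sigma = 1, l = 5) {
  sigma^2 * exp(-outer(x, x, "-")^2 / (2 * l^2))
}

ages <- 18:72
set.seed(1)
# three functions drawn from the GP prior: their values at any finite set of ages are MVN
draws <- mvrnorm(3, mu = rep(0, length(ages)),
                 Sigma = exp_quad_K(ages) + diag(1e-8, length(ages)))  # small jitter for numerical stability

tibble(draw = rep(1:3, each = length(ages)),
       age  = rep(ages, times = 3),
       f    = as.vector(t(draws))) %>%
  ggplot(aes(x = age, y = f, group = draw)) +
  geom_line(color = "#1c5253") +
  labs(subtitle = "Functions drawn from a GP prior (exp_quad, l = 5)")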
Currently there are four available kernels in brms (in order from smoothest to roughest):
* exp_quad: exponentiated quadratic
* matern52: Matern 5/2
* matern32: Matern 3/2
* exponential: exponential

(Details about the kernels are in the Stan manual.)
With magnitude $\sigma$ and length scale $l$, the exponentiated quadratic kernel is:
$$k(x_i, x_j) = \sigma^2 \exp\left(-\frac{|x_i - x_j|^2}{2l^2}\right)$$

* Think of it as creating very gentle hills and valleys – no sudden jumps or sharp corners
* Influence between points drops off quickly (like a bell curve)
* If you know the value at age 25, it gives you strong information about age 24 and 26, moderate information about age 22 and 28, and very little about age 15 or 35 (see the quick numeric check below)
* Best for processes you believe are inherently smooth and continuous
* Example: Physical growth patterns in children might follow this kind of smooth curve
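As a quick numeric check of that age example (my own illustration, with an arbitrary length scale of $l = 5$), the correlations implied by the exponentiated quadratic kernel at distances of 1, 3, and 10 years are:

l <- 5                # arbitrary length scale for illustration
dist <- c(1, 3, 10)   # e.g., age 25 vs. ages 24/26, 22/28, and 15/35
round(exp(-dist^2 / (2 * l^2)), 2)
# [1] 0.98 0.84 0.14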
This creates somewhat rougher patterns than the exponentiated quadratic.
With magnitude $\sigma$ and length scale $l$, the Matern 3/2 kernel is:
$$k(x_i, x_j) = \sigma^2 \left(1 + \frac{\sqrt{3}\,|x_i - x_j|}{l}\right) \exp\left(-\frac{\sqrt{3}\,|x_i - x_j|}{l}\right)$$
This sits between the very smooth exp_quad and the rougher matern32.
With magnitude $\sigma$ and length scale $l$, the Matern 5/2 kernel is:
$$k(x_i, x_j) = \sigma^2 \left(1 + \frac{\sqrt{5}\,|x_i - x_j|}{l} + \frac{5\,|x_i - x_j|^2}{3l^2}\right) \exp\left(-\frac{\sqrt{5}\,|x_i - x_j|}{l}\right)$$
It is more flexible than exp_quad while still maintaining good continuity.

This is the roughest of the four kernels.
With magnitude $\sigma$ and length scale $l$, the exponential kernel is:
$$k(x_i, x_j) = \sigma^2 \exp\left(-\frac{|x_i - x_j|}{l}\right)$$
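To see the smoothest-to-roughest ordering in numbers, here’s a small comparison of my own (with $\sigma = 1$, $l = 5$, and a distance of 3, all arbitrary choices) of the correlation each kernel implies:

l <- 5; d_ij <- 3   # arbitrary length scale and distance for illustration
round(c(
  exp_quad    = exp(-d_ij^2 / (2 * l^2)),
  matern52    = (1 + sqrt(5) * d_ij / l + 5 * d_ij^2 / (3 * l^2)) * exp(-sqrt(5) * d_ij / l),
  matern32    = (1 + sqrt(3) * d_ij / l) * exp(-sqrt(3) * d_ij / l),
  exponential = exp(-d_ij / l)
), 2)
#    exp_quad    matern52    matern32 exponential
#        0.84        0.77        0.72        0.55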
Data come from a 26-wave (every 2 weeks) study of political attitudes (Brandt et al., 2021).
Should federal spending on defense be increased, decreased, or kept the same?
d %>%
count(def) %>%
mutate(
col = ifelse(def >7, "1","2"),
def = factor(def,
levels=c(1:9),
labels = c("1\nGreatly decrease", "2", "3", "4\nKeep the same", "5", "6", "7\nGreatly increase", "8\n Don't Know", "9\n Haven't thought"))) %>%
ggplot(aes(x=def, y=n)) +
geom_col(aes(fill=col)) +
scale_x_discrete(labels = label_wrap(10)) +
labs(title = "Should federal spending on defense be changed?",
x=NULL,
y = "count") +
  theme(legend.position = "none")

Let's remove the values of this variable that are not part of our Likert scale.
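The cleaning step itself isn't shown here; a minimal sketch, assuming we simply drop the 8 ("Don't Know") and 9 ("Haven't thought") responses:

# keep only the 1-7 Likert responses (8 = "Don't Know", 9 = "Haven't thought")
d <- d %>%
  filter(def <= 7)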
We might suspect that people of similar ages have similar opinions on a question like this. But would we say that age has a linear effect?
Let’s say we’re interested in studying how time and age impact the responses to this question. We’ll build up our models from most simple to most complex.
$$\begin{aligned}
R_i &\sim \text{Categorical}(p) \\
\text{logit}(p) &= \alpha - \phi \\
\phi &= 0 \\
\alpha &\sim \text{Normal}(0, 1.5)
\end{aligned}$$
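The fitting code for this intercept-only model isn't shown; here's a minimal brms sketch consistent with the output below (the seed and cores values are my guesses):

m1 <- brm(
  data = d,
  family = cumulative,
  def ~ 1,
  prior = prior(normal(0, 1.5), class = Intercept),
  iter = 2000, warmup = 1000, seed = 3, cores = 4
)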
Family: cumulative
Links: mu = logit; disc = identity
Formula: def ~ 1
Data: d (Number of observations: 529)
Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup draws = 4000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept[1] -1.59 0.11 -1.82 -1.37 1.00 3109 2689
Intercept[2] -0.86 0.09 -1.04 -0.68 1.00 4144 3091
Intercept[3] -0.03 0.09 -0.20 0.14 1.00 4492 3409
Intercept[4] 1.13 0.10 0.94 1.34 1.00 4642 3428
Intercept[5] 1.99 0.13 1.73 2.25 1.00 5056 3572
Intercept[6] 2.93 0.19 2.57 3.33 1.00 5240 3116
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
disc 1.00 0.00 1.00 1.00 NA NA NA
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
$$\begin{aligned}
R_{ij} &\sim \text{Categorical}(p) \\
\text{logit}(p) &= \alpha - \phi \\
\phi &= \beta_1 \text{age}_i + \beta_2 \text{wave}_i \\
\alpha &\sim \text{Normal}(0, 1.5) \\
\beta_1 &\sim \text{Normal}(0, 1) \\
\beta_2 &\sim \text{Normal}(0, 1)
\end{aligned}$$
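Again, the fitting code isn't shown; here's a sketch consistent with the output below. Note that the fitted model includes only age, even though the notation above also lists a wave term:

m2 <- brm(
  data = d,
  family = cumulative,
  def ~ 1 + age,
  prior = c( prior(normal(0, 1.5), class = Intercept),
             prior(normal(0, 1), class = b) ),
  iter = 5000, warmup = 1000, seed = 3, cores = 4
)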
Family: cumulative
Links: mu = logit; disc = identity
Formula: def ~ 1 + age
Data: d (Number of observations: 529)
Draws: 4 chains, each with iter = 5000; warmup = 1000; thin = 1;
total post-warmup draws = 16000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept[1] -0.54 0.24 -1.00 -0.06 1.00 21384 12833
Intercept[2] 0.22 0.23 -0.24 0.67 1.00 23937 12552
Intercept[3] 1.08 0.24 0.62 1.55 1.00 21984 12201
Intercept[4] 2.28 0.25 1.79 2.79 1.00 20214 12964
Intercept[5] 3.17 0.27 2.64 3.71 1.00 19651 12984
Intercept[6] 4.12 0.31 3.51 4.74 1.00 20335 13038
age 0.03 0.01 0.02 0.04 1.00 21477 12271
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
disc 1.00 0.00 1.00 1.00 NA NA NA
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
m3 <- brm(
data = d,
family = cumulative,
def ~ 1 + gp(age, cov = "exp_quad", scale = F),
prior = c( prior(normal(0, 1.5), class=Intercept),
prior(inv_gamma(2.5, 3), class=lscale, coef = gpage),
prior(exponential(1), class=sdgp, coef = gpage)),
iter=5000, warmup=1000, seed=3, cores=4,
file = here("files/models/m91.3")
)

 Family: cumulative
Links: mu = logit; disc = identity
Formula: def ~ 1 + gp(age, cov = "exp_quad", scale = F)
Data: d (Number of observations: 529)
Draws: 4 chains, each with iter = 5000; warmup = 1000; thin = 1;
total post-warmup draws = 16000
Gaussian Process Hyperparameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sdgp(gpage) 0.63 0.31 0.26 1.51 1.00 4794 3120
lscale(gpage) 11.99 8.84 2.38 34.55 1.00 1913 3434
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept[1] -2.01 0.36 -2.84 -1.37 1.00 5063 3825
Intercept[2] -1.26 0.35 -2.07 -0.63 1.00 5438 3817
Intercept[3] -0.40 0.35 -1.20 0.23 1.00 5349 3817
Intercept[4] 0.81 0.35 0.01 1.44 1.00 5782 3874
Intercept[5] 1.69 0.36 0.89 2.35 1.00 5838 3838
Intercept[6] 2.65 0.39 1.81 3.36 1.00 6088 3897
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
disc 1.00 0.00 1.00 1.00 NA NA NA
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
As a reminder, the kernel formula for our GP is
$$k(x_i, x_j) = \sigma^2 \exp\left(-\frac{|x_i - x_j|^2}{2l^2}\right)$$
The value $|x_i - x_j|^2$ is the squared distance between two values. That means the two parameters estimated in our model are $\sigma$ and $l$. What does each of these mean?
Let's start with $l$: the value $1/(2l^2)$ is equal to $\rho^2$ (the rhosq computed in the code below). So we can rewrite the equation as
$$k(x_i, x_j) = \sigma^2 \exp\left(-\rho^2 |x_i - x_j|^2\right)$$

In other words, the covariance between two values declines exponentially with the squared distance between them.
The remaining piece, $\sigma^2$, is the maximum covariance between any two values.
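To put rough numbers on this, using the posterior means from m3 above ($\sigma \approx 0.63$, $l \approx 12$), the implied covariance between two people one year apart versus twenty years apart is roughly:

sigma <- 0.63; l <- 12                 # approximate posterior means of sdgp and lscale from m3
rho_sq <- 1 / (2 * l^2)
sigma^2 * exp(-rho_sq * c(1, 20)^2)    # covariance at distances of 1 and 20 years
# roughly 0.40 and 0.10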
I used an inverse-gamma distribution as my prior for $l$ and an exponential distribution as the prior for $\sigma$. Here's what that looks like:
set.seed(11)
nsim = 50
sample_l = invgamma::rinvgamma(nsim, 2.5, 3)
sample_sig = rexp(nsim, 1)
# wrangle into functions
p1 = tibble(
.draw = 1:nsim,
l = sample_l,
sig = sample_sig) %>%
mutate(sigsq = sig^2,
rhosq = 1 / (2 * l^2)) %>%
expand_grid(x = seq(from = 0, to = 10, by = .05)) %>%
mutate(covariance = sigsq * exp(-rhosq * x^2),
correlation = exp(-rhosq * x^2)) %>%
# plot
ggplot(aes(x = x, y = correlation)) +
geom_line(aes(group = .draw),
linewidth = 1/4, alpha = 1/4, color = "#1c5253") +
scale_x_continuous("distance (age)", expand = c(0, 0),
breaks = 0:5 * 2) +
labs(subtitle = "Gaussian process prior")
p1

# A tibble: 16,000 × 5
.chain .iteration .draw sdgp_gpage lscale_gpage
<int> <int> <int> <dbl> <dbl>
1 1 1 1 0.351 5.65
2 1 2 2 0.404 6.20
3 1 3 3 0.538 5.90
4 1 4 4 0.586 12.8
5 1 5 5 0.443 2.47
6 1 6 6 0.387 2.28
7 1 7 7 0.481 3.33
8 1 8 8 0.874 4.21
9 1 9 9 0.799 3.90
10 1 10 10 0.504 3.22
# ℹ 15,990 more rows
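The extraction code for this table isn't shown; one way to pull those hyperparameter draws (a sketch using the column names visible in the output above):

as_draws_df(m3) %>%
  select(.chain, .iteration, .draw, sdgp_gpage, lscale_gpage)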
Here's our posterior distribution of the Gaussian process:
post <- as_draws_df(m3)
# for `slice_sample()`
set.seed(14)
# wrangle
p2 <-
post %>%
mutate(sigsq = sdgp_gpage^2,
rhosq = 1 / (2 * lscale_gpage^2)) %>%
slice_sample(n = 50) %>%
expand_grid(x = seq(from = 0, to = 10, by = .05)) %>%
mutate(covariance = sigsq * exp(-rhosq * x^2),
correlation = exp(-rhosq * x^2)) %>%
# plot
ggplot(aes(x = x, y = correlation)) +
geom_line(aes(group = .draw),
linewidth = 1/4, alpha = 1/4, color = "#1c5253") +
stat_function(fun = function(x)
    exp(-(1 / (2 * mean(post$lscale_gpage)^2)) * x^2),
color = "#0f393a", linewidth = 1) +
scale_x_continuous("distance (age)", expand = c(0, 0),
breaks = 0:5 * 2) +
labs(subtitle = "Gaussian process posterior")
p1 + p2

Let's plot the model-predicted values as a function of age for models 2 and 3 to see how they compare.
nd = d %>% distinct(age)
pred_m2 = nd %>%
add_epred_draws(m2) %>%
mutate(model = "m2")
pred_m3 = nd %>%
add_epred_draws(m3) %>%
mutate(model = "m3")
full_join(pred_m2, pred_m3) %>%
group_by(model, age, .category) %>%
mean_qi(.epred) %>%
ggplot( aes( x=age, y=.epred, color=model ) ) +
geom_line() +
scale_color_manual(values = c("#1c5253" , "#e07a5f")) +
facet_wrap(~.category) +
  theme(legend.position = "bottom")

If you zoom in on any of these response options, you'll see that model 2, which modeled age as a linear predictor, only allows for gradual change in our outcome, whereas model 3 allows for some more "wiggliness".
ppd_m2 = nd %>%
add_predicted_draws(m2) %>%
mutate(model = "m2")
ppd_m3 = nd %>%
add_predicted_draws(m3) %>%
mutate(model = "m3")
full_join(ppd_m2, ppd_m3) |>
ungroup() |>
with_groups( c(model, age), summarise, avg = mean( as.numeric(.prediction) ) ) |>
ggplot( aes( x=age, y=avg, color=model ) ) +
geom_line() +
scale_color_manual(values = c("#1c5253" , "#e07a5f")) +
  labs( x="age", y="average response" )

nd = expand.grid(
age=18:72,
iter=1:50)
plotdata = m3 %>%
predicted_draws(newdata = d) |>
filter(.draw <= 50)
plotdata |>
mutate(.prediction = as.numeric(.prediction)) |>
with_groups(
c(age, .draw), summarise, m = mean(.prediction, na.rm=T)
) |>
ggplot( aes( x=age, y=m, group=.draw) ) +
geom_line(alpha=.4, color = "#1c5253") +
  geom_vline(xintercept = 47, color = "red")

m4 <- brm(
data = d,
family = cumulative,
def ~ 1 + gp(age, cov = "exponential", scale = F),
prior = c( prior(normal(0, 1.5), class=Intercept),
prior(inv_gamma(2.5, 3), class=lscale, coef = gpage),
prior(exponential(1), class=sdgp, coef = gpage)),
iter=5000, warmup=1000, seed=3, cores=4,
file = here("files/models/m91.4")
)

 Family: cumulative
Links: mu = logit; disc = identity
Formula: def ~ 1 + gp(age, cov = "exponential", scale = F)
Data: d (Number of observations: 529)
Draws: 4 chains, each with iter = 5000; warmup = 1000; thin = 1;
total post-warmup draws = 16000
Gaussian Process Hyperparameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sdgp(gpage) 0.47 0.16 0.22 0.84 1.00 7175 5164
lscale(gpage) 11.93 12.34 1.79 43.70 1.00 7382 5700
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept[1] -1.93 0.29 -2.56 -1.43 1.00 7301 6968
Intercept[2] -1.18 0.28 -1.79 -0.70 1.00 7166 6991
Intercept[3] -0.32 0.27 -0.92 0.15 1.00 7355 5429
Intercept[4] 0.89 0.27 0.29 1.36 1.00 7716 5569
Intercept[5] 1.77 0.29 1.16 2.29 1.00 8373 7331
Intercept[6] 2.73 0.32 2.06 3.32 1.00 9640 6593
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
disc 1.00 0.00 1.00 1.00 NA NA NA
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
post4 <- as_draws_df(m4)
# for `slice_sample()`
set.seed(14)
# wrangle
p3 <-
post4 %>%
mutate(sigsq = sdgp_gpage^2,
rhosq = 1 / (2 * lscale_gpage^2)) %>%
slice_sample(n = 50) %>%
expand_grid(x = seq(from = 0, to = 10, by = .05)) %>%
mutate(covariance = sigsq * exp(-rhosq * x^2),
correlation = exp(-rhosq * x^2)) %>%
# plot
ggplot(aes(x = x, y = correlation)) +
geom_line(aes(group = .draw),
linewidth = 1/4, alpha = 1/4, color = "#1c5253") +
stat_function(fun = function(x)
    exp(-(1 / (2 * mean(post4$lscale_gpage)^2)) * x^2),
color = "#0f393a", linewidth = 1) +
scale_x_continuous("distance (age)", expand = c(0, 0),
breaks = 0:5 * 2) +
  labs(subtitle = "Gaussian process posterior (exponential kernel)")
p2 + p3

ppd_m4 = nd %>%
add_predicted_draws(m4) %>%
mutate(model = "m4")
full_join(ppd_m3, ppd_m4) |>
ungroup() |>
with_groups( c(model, age), summarise,
avg = mean( as.numeric(.prediction) ) ) |>
ggplot( aes( x=age, y=avg, color=model ) ) +
geom_line() +
scale_color_manual(values = c("#1c5253" , "#e07a5f")) +
labs( x="age", y="average response" ) +
  scale_x_continuous(breaks=seq(20,75,5))

data_path = "https://raw.githubusercontent.com/sjweston/uobayes/refs/heads/main/files/data/external_data/williams.csv"
d <- read.csv(data_path)
rethinking::precis(d)

                  mean         sd       5.5%      94.5%   histogram
id 98.61029694 63.7493291 10.0000000 207.000000 ▇▇▇▃▅▅▅▃▃▃▂▂
female 0.57016803 0.4950710 0.0000000 1.000000 ▅▁▁▁▁▁▁▁▁▇
PA.std 0.01438236 1.0241384 -1.6812971 1.751466 ▁▁▃▇▇▃▁▁
day 44.17962096 27.6985612 6.0000000 92.000000 ▇▇▇▇▅▅▅▃▃▃
PA_lag 0.01992587 1.0183833 -1.7204351 1.717036 ▁▂▃▅▇▇▅▃▂▁
NA_lag -0.04575229 0.9833161 -0.8750718 1.990468 ▇▃▂▁▁▁▁▁▁▁▁▁▁▁
steps.pm 0.05424387 0.6298941 -1.0258068 1.011356 ▁▂▇▇▃▁▁▁
steps.pmd 0.02839160 0.7575395 -1.1235951 1.229974 ▁▁▇▇▁▁▁▁▁▁▁▁▁▁
NA.std -0.07545093 0.9495660 -1.0484293 1.811061 ▁▁▁▇▂▁▁▁▁▁
m5 <- brm(
data = d,
family = gaussian,
PA.std ~ 1 + gp(day, cov = "exp_quad", scale = F) + (1|id),
prior = c( prior(normal(0, .1), class=Intercept),
prior(inv_gamma(2.5, 3), class = lscale, coef = gpday),
prior(exponential(1), class = sdgp, coef = gpday)),
iter=10000, warmup=1000, seed=3, cores=4,
file = here("files/models/m91.5")
)

 Family: gaussian
Links: mu = identity; sigma = identity
Formula: PA.std ~ 1 + gp(day, cov = "exp_quad", scale = F) + (1 | id)
Data: d (Number of observations: 13033)
Draws: 4 chains, each with iter = 10000; warmup = 1000; thin = 1;
total post-warmup draws = 36000
Gaussian Process Hyperparameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sdgp(gpday) 0.27 0.21 0.07 0.84 1.09 26 22
lscale(gpday) 31.03 11.85 8.61 52.75 1.06 44 285
Multilevel Hyperparameters:
~id (Number of levels: 193)
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept) 0.80 0.05 0.72 0.89 1.08 35 52
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept -0.02 0.12 -0.32 0.16 1.08 30 24
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 0.60 0.00 0.59 0.60 1.01 202 465
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
post <- as_draws_df(m5)
# for `slice_sample()`
set.seed(14)
# wrangle
post %>%
mutate(sigsq = sdgp_gpday^2,
rhosq = 1 / (2 * lscale_gpday^2)) %>%
slice_sample(n = 50) %>%
expand_grid(x = seq(from = 0, to = 100, by = .05)) %>%
mutate(covariance = sigsq * exp(-rhosq * x^2),
correlation = exp(-rhosq * x^2)) %>%
# plot
ggplot(aes(x = x, y = correlation)) +
geom_line(aes(group = .draw),
linewidth = 1/4, alpha = 1/4, color = "#1c5253") +
stat_function(fun = function(x) exp(-(1 / (2 * mean(post$lscale_gpday)^2)) * x^2),
color = "#0f393a", linewidth = 1) +
scale_x_continuous("distance (day)") +
labs(subtitle = "Gaussian process posterior")