Neural Networks as Gaussian Processes
Theory and Insights from Guo (2021) and Lee et al. (2018)
Jongeun Choi
School of Mechanical Engineering, Yonsei University
References
- Radford M. Neal. Priors for infinite networks. Technical Report CRG-TR-94-1, University of Toronto, 1994.
- Christopher K. I. Williams. Computing with infinite networks. In Advances in Neural Information Processing Systems, pp. 295–301, 1997.
- Mengwu Guo. A brief note on understanding neural networks as Gaussian processes. arXiv preprint arXiv:2107.11892, 2021.
- Jaehoon Lee, Yasaman Bahri, Roman Novak, Samuel S. Schoenholz, Jeffrey Pennington, and Jascha Sohl-Dickstein. Deep neural networks as Gaussian processes. In International Conference on Learning Representations (ICLR), 2018.
Fully Connected Deep Neural Network
Consider an \(L\)-layer fully-connected neural network:
\[
\begin{aligned}
& z_i^{[\ell]}(x) = \sum_{j=1}^{N_{\ell-1}} W_{ij}^{[\ell]} x_j^{[\ell-1]}(x), \\
& x_i^{[\ell]}(x) = \phi(z_i^{[\ell]}(x) + b_i^{[\ell]}), \\
& y_i(x) = z_i^{[L+1]}(x) = \sum_{j=1}^{N_L} W_{ij}^{[L+1]} x_j^{[L]}(x),
\end{aligned}
\]
where \(x^{[0]} = x\) is the input, and \(y(x)\) is the output.
Figure: fully-connected neural network architecture.
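A minimal NumPy sketch of this forward pass (a hypothetical helper; the activation \(\phi\), layer widths, and parameter values are placeholders to be supplied):

```python
import numpy as np

def forward(x, Ws, bs, phi=np.tanh):
    """Forward pass of the L-layer network defined above.

    Ws : weight matrices [W^[1], ..., W^[L+1]]
    bs : bias vectors    [b^[1], ..., b^[L]]   (the linear output layer has no bias)
    phi: elementwise activation function.
    """
    h = x                          # x^[0] = x
    for W, b in zip(Ws[:-1], bs):
        z = W @ h                  # z^[ell] = W^[ell] x^[ell-1]
        h = phi(z + b)             # x^[ell] = phi(z^[ell] + b^[ell])
    return Ws[-1] @ h              # y(x) = z^[L+1]
```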
Classical NN Training Objective
A multi-layer neural network is commonly trained by minimizing the following loss:
\[
(W, b) = \arg\min_{W,b} \left\{ \frac{1}{M} \sum_{m=1}^M \left\| y^{(m)} - f^{\mathrm{NN}}(x^{(m)}; W, b) \right\|_2^2 + \lambda \| W \|_2^2 \right\}
\]
- First term: Mean Squared Error over \(M\) data points.
- Second term: \(L_2\) regularization (weight decay) with penalty \(\lambda \ge 0\).
In contrast, an infinite-width network (NNGP) requires no explicit training: Bayesian inference is performed directly with its equivalent Gaussian process.
Random Initialization
Assume:
- First-layer weights and biases: \((w_i^{[1]}, b_i^{[1]}) \sim \pi\) are independent and identically distributed (i.i.d.) across \(i\); the common distribution \(\pi\) need not be Gaussian.
- For \(\ell \ge 2\), all remaining weights and biases are drawn independently:
\[
W_{ij}^{[\ell]} \overset{\text{i.i.d.}}{\sim} \mathcal{N}\left(\frac{\mu_w^{[\ell]}}{N_{\ell-1}}, \frac{\sigma_w^{2[\ell]}}{N_{\ell-1}}\right), \quad b_i^{[\ell]} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(\mu_b^{[\ell]}, \sigma_b^{2[\ell]})
\]
- Infinite-width: \(N_{\ell} \to \infty\)
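Before passing to the limit, one finite-width parameter set can be sampled directly. A minimal sketch under these assumptions (a standard normal stands in for \(\pi\) purely for illustration):

```python
import numpy as np

def init_params(widths, mu_w=0.0, sigma_w2=1.0, mu_b=0.0, sigma_b2=1.0, rng=None):
    """Sample weights/biases for layers ell = 1..L+1.

    widths = [N_0, N_1, ..., N_L, N_{L+1}]; layer ell maps N_{ell-1} -> N_ell.
    For ell >= 2: W_ij ~ N(mu_w/N_{ell-1}, sigma_w2/N_{ell-1}), b_i ~ N(mu_b, sigma_b2).
    First layer: drawn from pi (standard normal here, for illustration only).
    """
    rng = np.random.default_rng() if rng is None else rng
    Ws, bs = [], []
    for ell in range(1, len(widths)):
        n_in, n_out = widths[ell - 1], widths[ell]
        if ell == 1:
            W = rng.standard_normal((n_out, n_in))
            b = rng.standard_normal(n_out)
        else:
            W = rng.normal(mu_w / n_in, np.sqrt(sigma_w2 / n_in), size=(n_out, n_in))
            b = rng.normal(mu_b, np.sqrt(sigma_b2), size=n_out)
        Ws.append(W)
        if ell < len(widths) - 1:      # no bias on the linear output layer
            bs.append(b)
    return Ws, bs
```

Together with the forward-pass sketch above, this yields one random draw from the finite-width prior over functions.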
From Neural Networks to Gaussian Processes
- At infinite width, each pre-activation \(z_i^{[\ell]}(x)\) is Gaussian by the Central Limit Theorem.
- Inductively, each layer's output is governed by a Gaussian process:
\[
z_i^{[\ell]}(x) \sim \mathcal{GP}\left(h^{[\ell]}(x), k^{[\ell]}(x, x')\right)
\]
- The covariance \(k^{[\ell]}\) is defined recursively.
Central Limit Theorem (CLT) Reminder
Central Limit Theorem:
If we sum many independent and identically distributed (i.i.d.) random variables:
\[
S_n = \sum_{i=1}^n X_i
\]
where each \(X_i\) has finite variance, then as \(n \to \infty\):
\[
\frac{S_n - \mathbb{E}[S_n]}{\sqrt{\text{Var}(S_n)}} \overset{d}{\longrightarrow} \mathcal{N}(0,1)
\]
Interpretation:
- The sum (or average) becomes approximately Gaussian, regardless of the original distribution of \(X_i\).
In Neural Networks: Each pre-activation is a sum of many i.i.d. terms:
\[
z_i^{[\ell]}(x) = \sum_{j} W_{ij}^{[\ell]} x_j^{[\ell-1]}(x)
\]
which becomes Gaussian when the layer width is large.
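A quick Monte-Carlo illustration of this effect (with illustrative choices of \(\pi\), activation, and variances): sample many independent parameter draws at a fixed input and inspect the empirical distribution of a second-layer pre-activation as the width grows; its excess kurtosis should approach the Gaussian value of zero.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(5)                    # a fixed input, d = 5

for width in (10, 100, 1000):
    z = np.empty(5000)
    for s in range(5000):                     # independent parameter draws
        W1 = rng.standard_normal((width, x.size))            # first layer ~ pi (illustrative)
        b1 = rng.standard_normal(width)
        h = np.tanh(W1 @ x + b1)                              # x^[1](x)
        w2 = rng.normal(0.0, np.sqrt(1.0 / width), width)     # one row of W^[2]
        z[s] = w2 @ h                                         # z^[2](x): a sum of `width` terms
    kurt = np.mean((z - z.mean()) ** 4) / z.var() ** 2 - 3.0  # excess kurtosis, 0 for a Gaussian
    print(f"width={width:5d}  mean={z.mean():+.3f}  var={z.var():.3f}  excess kurtosis={kurt:+.3f}")
```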
Recursive Kernel Construction
\[
k^{[\ell]}(x,x') = \sigma_b^2 + \sigma_w^2 \cdot \mathbb{E}\left[ \phi(z^{[\ell-1]}(x)) \phi(z^{[\ell-1]}(x')) \right]
\]
- This defines the kernel of the induced GP
- The network's depth influences the kernel via recursion
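As a sketch, the recursion can be evaluated numerically for a pair of inputs by propagating the \(2 \times 2\) covariance of \(\big(z^{[\ell]}(x), z^{[\ell]}(x')\big)\) layer by layer, with the expectation estimated by Monte Carlo. The base case (a scaled linear kernel on the inputs) and the \(\tanh\) activation are illustrative choices; Lee et al. (2018) use closed-form expectations where available.

```python
import numpy as np

def kernel_recursion(x, xp, depth, sigma_b2=0.1, sigma_w2=1.6,
                     phi=np.tanh, n_mc=200_000, rng=None):
    """Evaluate k^[ell] on the pair {x, x'} for ell = 1..depth via the recursion above."""
    rng = np.random.default_rng(0) if rng is None else rng
    d = x.size
    # base case: K^[0](x, x') = sigma_b2 + sigma_w2 * (x . x') / d
    K = sigma_b2 + sigma_w2 * np.array([[x @ x,  x @ xp],
                                        [xp @ x, xp @ xp]]) / d
    for _ in range(depth):
        L = np.linalg.cholesky(K + 1e-12 * np.eye(2))
        u = L @ rng.standard_normal((2, n_mc))     # samples of (z^[ell-1](x), z^[ell-1](x'))
        E = phi(u) @ phi(u).T / n_mc               # Monte-Carlo estimate of E[phi(z) phi(z')]
        K = sigma_b2 + sigma_w2 * E                # k^[ell] on the pair {x, x'}
    return K
```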
NN-Induced Gaussian Process Prior
Given the network structure, the outputs of layer \(L+1\) follow a Gaussian process:
\[
y(x) \sim \mathcal{GP}(h_{\text{NN}}(x), k_{\text{NN}}(x, x'))=\mathcal{GP}\left( h^{[L+1]}(x),\, k^{[L+1]}(x,x') \right)
\]
with:
\[
\begin{aligned}
h^{[\ell]}(x) &= \mu_w^{[\ell]} \mathbb{E}\left[\phi(z^{[\ell-1]}(x) + b^{[\ell-1]})\right] \\
k^{[\ell]}(x, x') &= \sigma_w^{2[\ell]} \mathbb{E}\left[\phi(z^{[\ell-1]}(x) + b^{[\ell-1]}) \phi(z^{[\ell-1]}(x') + b^{[\ell-1]})\right]
\end{aligned}
\]
This recursive formula induces:
\[
y(x) \sim \mathcal{GP}(h_{\text{NN}}(x), k_{\text{NN}}(x, x'))
\]
Interpretation: The NN prior is automatically converted into a GP prior with a recursively defined mean and kernel.
Regression with NN-Induced Gaussian Processes
Observation model: $y(x) = f(x) + \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, \sigma_\varepsilon^2)$, with the NN-induced prior \(f \sim \mathcal{GP}(h_{\text{NN}}, k_{\text{NN}})\).
\[
y^*(x) \mid \mathbf{X}, \mathbf{y} \sim \mathcal{GP}(h^*_{\text{NN}}(x), k^*_{\text{NN}}(x,x'))
\]
\[
\begin{aligned}
h^*_{\text{NN}}(x) &= h_{\text{NN}}(x) + k_{\text{NN}}(X, x)^\top \left[ K + \sigma_\varepsilon^2 I \right]^{-1} \left( \mathbf{y} - h_{\text{NN}}(X) \right) \\
k^*_{\text{NN}}(x, x') &= k_{\text{NN}}(x, x') - k_{\text{NN}}(X, x)^\top \left[ K + \sigma_\varepsilon^2 I \right]^{-1} k_{\text{NN}}(X, x')
\end{aligned}
\]
This is equivalent to standard GP regression but with the NN-induced kernel.
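A generic sketch of these two formulas; any kernel routine can be plugged in for \(k_{\text{NN}}\), and a zero prior mean \(h_{\text{NN}} \equiv 0\) is assumed for brevity (the RBF kernel in the usage lines is only a stand-in for the NN-induced kernel):

```python
import numpy as np

def gp_posterior(kernel, X, y, X_star, noise_var):
    """GP posterior mean and covariance at test points X_star (zero prior mean).

    kernel(A, B) must return the |A| x |B| matrix of kernel values.
    """
    K = kernel(X, X)                               # K = k(X, X)
    Ks = kernel(X, X_star)                         # k(X, x*)
    Kss = kernel(X_star, X_star)                   # k(x*, x*')
    A = K + noise_var * np.eye(len(X))
    mean = Ks.T @ np.linalg.solve(A, y)            # k(X,x*)^T (K + s^2 I)^{-1} y
    cov = Kss - Ks.T @ np.linalg.solve(A, Ks)      # k(x*,x*') - k(X,x*)^T (K + s^2 I)^{-1} k(X,x*')
    return mean, cov

# usage with an ordinary RBF kernel standing in for k_NN
rbf = lambda A, B: np.exp(-0.5 * np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1))
X = np.linspace(-3, 3, 20)[:, None]
y = np.sin(X[:, 0]) + 0.1 * np.random.default_rng(0).standard_normal(20)
mean, cov = gp_posterior(rbf, X, y, np.linspace(-3, 3, 100)[:, None], noise_var=0.01)
```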
NNGP Hyperparameter Estimation
The NNGP hyperparameters are $\theta = \left\{ \mu_w^{[\ell]}, \sigma_w^{2[\ell]}, \mu_b^{[\ell]}, \sigma_b^{2[\ell]}, \sigma_\varepsilon^2 \right\}$.
They are learned by maximizing the log marginal likelihood:
\[
\begin{split}
\log p(\mathbf{y} \mid \mathbf{X}, \theta) = &-\frac{1}{2} (\mathbf{y} - h_{\text{NN}})^\top (K + \sigma_\varepsilon^2 I)^{-1} (\mathbf{y} - h_{\text{NN}})\\
&- \frac{1}{2} \log |K + \sigma_\varepsilon^2 I| - \frac{M}{2} \log(2\pi)
\end{split}
\]
- Kernel parameters: the initialization statistics \(\mu_w^{[\ell]}, \sigma_w^{2[\ell]}, \mu_b^{[\ell]}, \sigma_b^{2[\ell]}\)
- Noise variance: \(\sigma_\varepsilon^2\)
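A numerically stable sketch of this objective using a Cholesky factorization (zero prior mean assumed for brevity; in the general case, replace `y` by `y - h_NN(X)`). The maximization over \(\theta\) itself would be done by any standard optimizer.

```python
import numpy as np

def log_marginal_likelihood(K, y, noise_var):
    """log p(y | X, theta) for a GP with Gram matrix K = k_NN(X, X) and zero prior mean."""
    M = len(y)
    A = K + noise_var * np.eye(M)
    L = np.linalg.cholesky(A)                        # A = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return (-0.5 * y @ alpha                         # data-fit term
            - np.sum(np.log(np.diag(L)))             # -(1/2) log|K + noise_var I|
            - 0.5 * M * np.log(2.0 * np.pi))         # normalization constant
```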
Posterior Inference
\[
y^*(x) | X, y \sim \mathcal{GP}(h^*(x), k^*(x,x'))
\]
\[
\begin{aligned}
h^*(x) &= h(x) + k(X,x)^\top (K + \sigma_\varepsilon^2 I)^{-1}(y - h(X)) \\
k^*(x,x') &= k(x,x') - k(X,x)^\top (K + \sigma_\varepsilon^2 I)^{-1}k(X,x')
\end{aligned}
\]
These are the standard GP posterior equations: prediction is exact Bayesian inference under the GP prior, with no gradient-based training.
Key Insights from Lee et al. (2018)
Main Finding: Deep, infinitely wide neural networks are equivalent to Gaussian processes.
$K^{[\ell]}(x, x') = \sigma_b^2 + \sigma_w^2 \, \mathbb{E}\left[ \phi(z^{[\ell-1]}(x)) \phi(z^{[\ell-1]}(x')) \right]$
This enables exact Bayesian inference without training.
Neural Network vs NNGP vs GP Overview
| | NN | NNGP | GP |
|---|---|---|---|
| Type | Finite-width | Infinite-width NN | Kernel machine |
| Prediction | Optimized \(W, b\) | GP regression | GP regression |
| Kernel | Implicit via \(W, b\) | NN-induced kernel | Chosen kernel |
| Hyperparams | Learning rate, \(\sigma_w^2\), \(\sigma_b^2\) | \(\sigma_w^2\), \(\sigma_b^2\), \(\sigma_\varepsilon^2\) | Kernel params, \(\sigma_\varepsilon^2\) |
| Uncertainty | Extra techniques | GP posterior variance | GP posterior variance |
Comment: NNGP is a special case of GP where the kernel comes from NN structure.
NNGP Insight: Acts as a GP with an NN-induced kernel. Hyperparameters relate to NN initialization statistics.
What is a Reproducing Kernel Hilbert Space?
Reproducing Kernel Hilbert Space (RKHS):
- A Hilbert space is a complete vector space equipped with an inner product, so lengths and angles between its elements are well defined.
- An RKHS is a Hilbert space of functions \(f: \Omega \to \mathbb{R}\) in which the evaluation functionals \(f \mapsto f(x)\) are continuous.
- It is associated with a positive-definite kernel \(k(x, x')\).
Reproducing Property:
\[
f(x) = \langle f, k(x, \cdot) \rangle_{\mathcal{H}_k}
\]
Interpretation:
- The value of \(f\) at \(x\) is just the inner product of \(f\) with the function \(k(x, \cdot)\).
- Functions in RKHS are "spanned" by the kernel.
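As a one-line check, for a finite kernel expansion the reproducing property follows from bilinearity of the inner product together with \(\langle k(x^{(m)}, \cdot), k(x, \cdot) \rangle_{\mathcal{H}_k} = k(x^{(m)}, x)\):
\[
f = \sum_{m} \beta_m k(x^{(m)}, \cdot)
\;\Longrightarrow\;
\langle f, k(x, \cdot) \rangle_{\mathcal{H}_k} = \sum_{m} \beta_m \, k(x^{(m)}, x) = f(x).
\]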
RKHS Intuition: Functions as Vectors
Key Idea: Think of functions in \(\mathcal{H}_k\) as vectors in a high-dimensional space.
Inner Product: \(\langle f, g \rangle_{\mathcal{H}_k}\) defines angles and lengths between functions.
Reproducing Property:
\[
f(x) = \langle f, k(x, \cdot) \rangle_{\mathcal{H}_k}
\]
Interpretation:
- Each \(k(x, \cdot)\) acts like a feature vector centered at \(x\).
- The function value \(f(x)\) is the projection of \(f\) onto \(k(x, \cdot)\).
Analogy: think of \(\mathcal{H}_k\) as a space of vectors, with the functions \(k(x, \cdot)\) playing the role of (generally non-orthogonal) basis elements.
Kernel Ridge Regression (KRR)
Learning in RKHS:
Given data \(\{(x^{(m)}, y^{(m)})\}_{m=1}^M\), the goal is to find \(f \in \mathcal{H}_k\) minimizing:
\[
\min_{f \in \mathcal{H}_k} \sum_{m=1}^M (y^{(m)} - f(x^{(m)}))^2 + \lambda \|f\|_{\mathcal{H}_k}^2
\]
Interpretation:
- First term: Fit the data.
- Second term: Penalize complexity via the RKHS norm.
Solution:
\[
f(x) = \sum_{m=1}^M \beta_m k(x^{(m)}, x)
\]
where \(\beta = (K + \lambda I)^{-1} y\) and \(K_{mm'} = k(x^{(m)}, x^{(m')})\) is the Gram matrix.
Connection to GP Regression: the GP posterior mean is exactly the KRR solution with \(\lambda = \sigma_\varepsilon^2\).
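A minimal sketch of this solution (the kernel choice and data are placeholders to be supplied):

```python
import numpy as np

def krr_fit_predict(kernel, X, y, X_star, lam):
    """Kernel ridge regression: beta = (K + lam I)^{-1} y, f(x) = sum_m beta_m k(x^(m), x)."""
    K = kernel(X, X)                                     # Gram matrix on the training inputs
    beta = np.linalg.solve(K + lam * np.eye(len(X)), y)  # ridge-regularized coefficients
    return kernel(X_star, X) @ beta                      # predictions at X_star
```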
KRR as Projection in RKHS
KRR Objective:
\[
\min_{f \in \mathcal{H}_k} \sum_{m=1}^M (y^{(m)} - f(x^{(m)}))^2 + \lambda \|f\|_{\mathcal{H}_k}^2
\]
Interpretation:
- Finds the function \(f^* \in \mathcal{H}_k\) that comes closest, in the regularized sense, to fitting the data points.
- Balances between:
- Fitting the data exactly.
- Not being too "complex" in \(\mathcal{H}_k\) norm.
Geometrically: it projects the information in the labels \(\{y^{(m)}\}\) onto the subspace of \(\mathcal{H}_k\) spanned by \(\{k(x^{(m)}, \cdot)\}_{m=1}^M\).
This is why GP posterior mean and KRR solution coincide!
RKHS Projection: Intuition
Geometric view:
- The relevant subspace of the RKHS is the span of \(\{ k(x^{(m)}, \cdot) \}_{m=1}^M\).
- The (regularized) projection onto this subspace gives the posterior mean \(f^*\).
- Kernel Ridge Regression and Gaussian Process regression both compute this projection.
KRR and GP Regression Connection
Gaussian Process Regression:
Given:
\[
f \sim \mathcal{GP}(0, k(x,x')),
\quad y^{(m)} = f(x^{(m)}) + \varepsilon^{(m)}, \, \varepsilon^{(m)} \sim \mathcal{N}(0, \sigma_\varepsilon^2)
\]
The posterior mean is:
\[
f^*(x) = k(X,x)^\top (K + \sigma_\varepsilon^2 I)^{-1} y
\]
Kernel Ridge Regression solution:
\[
f_{\text{KRR}}(x) = k(X,x)^\top (K + \lambda I)^{-1} y
\]
Key Insight:
\[
\boxed{\text{GP Posterior Mean} = \text{KRR Solution} \text{ with } \lambda = \sigma_\varepsilon^2}
\]
Probabilistic vs Optimization:
- GP regression = Bayesian posterior mean
- KRR = solution of penalized least-squares in RKHS
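A quick numeric sanity check of this identity, with an ordinary RBF kernel and synthetic data standing in for the NN-induced setting: KRR is solved as an explicit penalized least-squares problem over \(\beta\) and compared with the GP posterior mean formula.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(1)
rbf = lambda A, B: np.exp(-0.5 * np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1))

X = rng.uniform(-2.0, 2.0, size=(15, 1))
y = np.sin(2.0 * X[:, 0]) + 0.1 * rng.standard_normal(15)
noise_var = 0.05
K = rbf(X, X)

# KRR: minimize the penalized least-squares objective over beta, with f(X) = K beta
obj  = lambda b: np.sum((y - K @ b) ** 2) + noise_var * (b @ K @ b)
grad = lambda b: -2.0 * K @ (y - K @ b) + 2.0 * noise_var * (K @ b)
beta = minimize(obj, np.zeros(15), jac=grad, method="L-BFGS-B",
                options={"ftol": 1e-12, "gtol": 1e-10}).x

# GP posterior mean at the training inputs, with lambda = sigma_eps^2
gp_mean = K @ np.linalg.solve(K + noise_var * np.eye(15), y)

print(np.allclose(K @ beta, gp_mean, atol=1e-3))   # True, up to optimizer tolerance
```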
Why does KRR need more data?
Key difference:
- GP regression automatically incorporates uncertainty.
- When data is scarce, GP posterior mean stays close to the prior, preventing overfitting.
- KRR lacks uncertainty modeling; without enough data, it can underfit or overfit depending on \(\lambda\).
With sufficient data:
- KRR and GP posterior mean become similar.
- The fit becomes smoother and more stable as additional data points constrain it.
Summary: GP regression is naturally better suited for small-data regimes.
KRR vs GP Regression: Overfitting, Underfitting, and Uncertainty
Key Takeaways:
- Small Data:
- KRR may overfit (\(\lambda\) too small) or underfit (\(\lambda\) too large).
- The GP prediction stays smooth and controlled because the prior dominates when data are scarce, and the posterior variance reports the remaining uncertainty.
- Large Data:
- Both GP and properly regularized KRR perform well.
- GP variance shrinks as data increases.
RKHS induced by the NN Kernel
NN-induced RKHS:
The kernel \(k_{\mathrm{NN}}\) induces a Reproducing Kernel Hilbert Space (RKHS):
\[
\mathcal{H}_{k_{\mathrm{NN}}}(\Omega)
\]
The corresponding function space:
\[
\mathcal{C}_{k_\mathrm{NN}} = \left\{ f(x) = \sum_{m=1}^M \beta_m k_{\mathrm{NN}}(x^{(m)}, x) \,\middle|\, M \in \mathbb{N}^+, x^{(m)}\in \Omega, \beta_m \in \mathbb{R} \right\}
\]
Meaning:
- \(\mathcal{C}_{k_\mathrm{NN}}\) collects finite linear combinations of kernel functions; its completion under the RKHS inner product is \(\mathcal{H}_{k_{\mathrm{NN}}}\).
- By the representer theorem, the regularized fit to the data can be written as a linear combination of kernel functions centered at the training inputs \(\{x^{(m)}\}_{m=1}^M\).
RKHS Structure and Correction Term
Inner Product in \(\mathcal{H}_{k_{\mathrm{NN}}}\): for \(f = \sum_{m=1}^M \beta_m k_{\mathrm{NN}}(x^{(m)}, \cdot)\) and \(g = \sum_{m'=1}^{M'} \beta'_{m'} k_{\mathrm{NN}}(x'^{(m')}, \cdot)\),
\[
\left\langle f, g \right\rangle_{\mathcal{H}_{k_{\mathrm{NN}}}} = \sum_{m=1}^M \sum_{m'=1}^{M'} \beta_m \beta'_{m'} k_{\mathrm{NN}}(x^{(m)}, x'^{(m')})
\]
Posterior Mean as RKHS Element:
\[
h^*_{\mathrm{NN}}(x) = h_{\mathrm{NN}}(x) + \Delta_{\mathrm{NN}}(x)
\]
where:
\[
\Delta_{\mathrm{NN}}(x) = \sum_{m=1}^M \beta_m k_{\mathrm{NN}}(x^{(m)}, x) \in \mathcal{H}_{k_{\mathrm{NN}}}
\]
Insight:
- Posterior mean is prior mean \(h_{\mathrm{NN}}(x)\) plus an RKHS correction.
- The correction \(\Delta_{\mathrm{NN}}\) exactly reconstructs the deviation needed to fit the data.
Kernel Ridge Regression Viewpoint
The coefficient \(\beta\) is computed as:
\[
\beta = \left[ K + \sigma_\varepsilon^2 I \right]^{-1} \left( y - h_{\mathrm{NN}}(X) \right)
\]
It solves:
\[
\beta = \arg\min_{\beta \in \mathbb{R}^M} \left\{ \| y - h_{\mathrm{NN}}(X) - \Delta_{\mathrm{NN}}(X) \|_2^2 + \sigma_\varepsilon^2 \| \Delta_{\mathrm{NN}} \|_{\mathcal{H}_{k_{\mathrm{NN}}}}^2 \right\}
\]
where
\[
\| \Delta_{\mathrm{NN}} \|_{\mathcal{H}_{k_{\mathrm{NN}}}}^2 = \beta^\top K \beta
\]
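Setting the gradient of this objective with respect to \(\beta\) to zero (using \(\Delta_{\mathrm{NN}}(X) = K\beta\) and writing \(r := y - h_{\mathrm{NN}}(X)\)) recovers the stated coefficient, assuming \(K\) is invertible:
\[
\nabla_\beta \left\{ \| r - K\beta \|_2^2 + \sigma_\varepsilon^2\, \beta^\top K \beta \right\}
= -2K\left(r - K\beta\right) + 2\sigma_\varepsilon^2 K\beta = 0
\;\Longrightarrow\;
\beta = \left[ K + \sigma_\varepsilon^2 I \right]^{-1} r.
\]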
Interpretation:
- Equivalent to a classical kernel ridge regression (KRR)
- Regularized by RKHS norm
- Automatically arises from the Bayesian GP posterior
Viewpoint: Bayesian Inference as Kernel Ridge Regression
Probabilistic Interpretation:
\[
p(y|f) \propto \exp\left(-\frac{1}{2\sigma_\varepsilon^2} \| y - f(X) \|_2^2\right), \quad f \sim \mathcal{GP}(h_{\mathrm{NN}}, k_{\mathrm{NN}})
\]
Posterior Mean:
\[
f^*(x) = h_{\mathrm{NN}}(x) + \arg\min_{\Delta_{\mathrm{NN}} \in \mathcal{H}_{k_{\mathrm{NN}}}} \left\{ \| y - h_{\mathrm{NN}}(X) - \Delta_{\mathrm{NN}}(X) \|_2^2 + \sigma_\varepsilon^2 \| \Delta_{\mathrm{NN}} \|^2_{\mathcal{H}_{k_{\mathrm{NN}}}} \right\}
\]
Summary:
- Posterior mean = solution of a penalized least squares problem in RKHS
- The Bayesian posterior mean is the same as the solution of kernel ridge regression
Prior \(\rightarrow\) Posterior as RKHS Projection
Interpretation:
- The correction \(\Delta_{\mathrm{NN}}\) pulls the prior mean toward the observed data.
- The posterior mean is the prior mean plus this data-driven RKHS correction.
Empirical Findings (Lee et al.)
Empirical Findings (Lee et al.):
- NNGPs achieve accuracy comparable to, or better than, that of wide trained NNs.
- The GP posterior uncertainty correlates strongly with prediction error.
- Wider NNs behave more like GPs.
Summary
Summary:
- Infinite-width neural networks induce Gaussian processes.
- Recursive kernels define the GP structure layer-by-layer.
- NNGP shows good predictive performance and uncertainty quantification.