In the first part of this two-part series on Function-Space Variational Inference (FSVI), we looked at the Data Processing Inequality (DPI). In this second part, we finally look at the relationship between FSVI, a method focusing on the Bayesian predictive posterior rather than the parameter space, and the DPI. We will recover several interesting results from the literature in a simple form, and hopefully build intuitions for the relationship between parameter and functional priors.
Most importantly, we consider how FSVI can measure a predictive divergence between the approximate and true posterior which is independent of parameter symmetries. By parameter symmetries, I mean different parameters that yield the same predictions, which is very common in over-parameterized neural networks: think of parameter symmetries like different paths leading to the same destination; they might look different but end up at the same predictions.
As a nice example and application, we relate FSVI to training with label entropy regularization: a potentially more meaningful prior than the ones usually used in Bayesian neural networks.
Variational inference is a powerful technique for approximating complex Bayesian posteriors with simpler distributions. In its usual form, it optimizes an approximate, variational distribution to match the Bayesian parameter posterior as closely as possible. This way, it transforms the problem of Bayesian inference into an optimization problem.
Especially for deep neural networks, obtaining a good approximation of the parameter posterior can be difficult. One reason is the sheer size of the parameter space. Additionally, the parameterization of a neural network often contains many symmetries (different parameter configurations can lead to the same predictions of the model), and these are not taken into account either.
Function-space variational inference (FSVI) side-steps some of these restrictions by only requiring that the variational distribution matches the Bayesian predictive posterior.
Data Processing Inequality
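The DPI states that processing data stochastically can only reduce information. More formally, for a stochastic mapping \(\opf\) that takes \(\W\) to \(Y\) and is shared by both distributions: \[ \Kale{\qof{\W}}{\pof{\W}} \ge \Kale{\qof{Y}}{\pof{Y}}. \]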
That is, the KL divergence (🥬 divergence) between \(\qof{Y}\) and \(\pof{Y}\) cannot be larger than between the original \(\qof{\W}\) and \(\pof{\W}\). Intuitively, the stochastic mapping \(\opf\) induces a bottleneck that reduces how well we can distinguish between \(\opp\) and \(\opq\). Finally we have equality when \(\Kale{\qof{\W \given Y}}{\pof{\W \given Y}} = 0\). See the first part of this two-part series for more details.
The paper “Understanding Variational Inference in Function-Space” by Burt et al. (2021) succinctly summarizes the DPI as follows:
The data processing inequality states that if two random variables are transformed in this way, they cannot become easier to tell apart.
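As a quick numerical illustration of the inequality, here is a minimal sketch with discrete distributions. The distributions `q_w` and `p_w` and the stochastic mapping `channel` are made-up numbers chosen only for illustration; they are not taken from any of the cited papers.

```python
import math

def kl(q, p):
    """KL(q || p) in nats between two discrete distributions given as lists."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

def push_forward(dist, channel):
    """Marginal over Y after passing W ~ dist through channel[w][y] = p(y | w)."""
    n_out = len(channel[0])
    return [sum(dist[w] * channel[w][y] for w in range(len(dist)))
            for y in range(n_out)]

# Two distributions over three values of W (illustrative numbers).
q_w = [0.7, 0.2, 0.1]
p_w = [0.2, 0.3, 0.5]

# A stochastic mapping shared by q and p: row w holds p(y | w) for two outcomes.
channel = [[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]]

q_y = push_forward(q_w, channel)
p_y = push_forward(p_w, channel)

kl_w = kl(q_w, p_w)
kl_y = kl(q_y, p_y)
print(f"KL over W: {kl_w:.4f}, KL over Y: {kl_y:.4f}")
```

The divergence after the mapping is never larger than the divergence before it, whatever numbers are plugged in, as long as both distributions pass through the same `channel`.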
Problem Setting & Notation
In the following, I assume a classification task with cross-entropy loss. This post uses the following notation:
- \(\y\) is the label,
- \(\x\) is the input,
- \(\qof{\y \given \x}\) is the predictive distribution we want to learn,
- \(\pdata{\y \given \x}\) is the data distribution, and
- \(C\) is the number of classes.
The probabilistic model is as usual: \[\pof{\y, \w \given \x} = \pof{\y \given \x, \w} \, \pof{\w}.\]
Note that I drop the conditioning on \(\x\) for simplicity. I use upper-case letters for random variables that I am taking an expectation over in some form (e.g. inside the 🥬 divergence), and lower-case letters for specific observations or values that could be substituted (with the exception of \(\Dany\)).
Chain Rule of the 🥬 Divergence & DPI
An important property of the 🥬 divergence is the chain rule:
\[ \Kale{\qof{\Y_n,...,\Y_1}}{\pof{\Y_n,...,\Y_1}} = \sum_{i=1}^n \Kale{\qof{\Y_i \given \Y_{i-1}, ..., \Y_1}}{\pof{\Y_i \given \Y_{i-1}, ..., \Y_1}}. \]
Every term in this sum is non-negative, so dropping terms can only decrease the divergence. Applying the chain rule in this way therefore yields a chain of inequalities for the DPI as well: \[ \begin{align} \Kale{\qof{\W}}{\pof{\W}} &\ge \Kale{\qof{\Y_n,...,\Y_1}}{\pof{\Y_n,...,\Y_1}}\\ &\ge \Kale{\qof{\Y_{n-1},...,\Y_1}}{\pof{\Y_{n-1},...,\Y_1}}\\ &\ge \dots \\ &\ge \Kale{\qof{\Y_1}}{\pof{\Y_1}}. \end{align} \]
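The chain rule and the resulting chain of inequalities can be checked on a toy model. The sketch below assumes a hypothetical setup that is not from the post: a parameter with three possible values, two inputs, and binary labels that are conditionally independent given the parameter. All numbers are illustrative.

```python
import math
from itertools import product

def kl(q, p):
    """KL(q || p) in nats for distributions stored as dicts over the same keys."""
    return sum(q[k] * math.log(q[k] / p[k]) for k in q if q[k] > 0)

# prob_one[i][w] = p(Y_i = 1 | x_i, w) for two inputs and three parameter values.
prob_one = [[0.9, 0.5, 0.1], [0.8, 0.3, 0.6]]
q_w = {0: 0.7, 1: 0.2, 2: 0.1}
p_w = {0: 0.2, 1: 0.3, 2: 0.5}

def joint(dist_w, n):
    """Distribution over (Y_1, ..., Y_n), marginalizing the shared parameter w."""
    out = {}
    for ys in product([0, 1], repeat=n):
        total = 0.0
        for w, pw in dist_w.items():
            lik = 1.0
            for i, y in enumerate(ys):
                lik *= prob_one[i][w] if y == 1 else 1.0 - prob_one[i][w]
            total += pw * lik
        out[ys] = total
    return out

q12, p12 = joint(q_w, 2), joint(p_w, 2)
q1, p1 = joint(q_w, 1), joint(p_w, 1)

kl_w = kl(q_w, p_w)
kl_y12 = kl(q12, p12)
kl_y1 = kl(q1, p1)

# Chain rule: KL(Y_1, Y_2) = KL(Y_1) + E_{q(y_1)}[KL(q(Y_2 | y_1) || p(Y_2 | y_1))].
cond_kl = 0.0
for y1 in (0, 1):
    q_cond = {y2: q12[(y1, y2)] / q1[(y1,)] for y2 in (0, 1)}
    p_cond = {y2: p12[(y1, y2)] / p1[(y1,)] for y2 in (0, 1)}
    cond_kl += q1[(y1,)] * kl(q_cond, p_cond)

print(f"KL(W)={kl_w:.4f}  KL(Y1,Y2)={kl_y12:.4f}  KL(Y1)={kl_y1:.4f}")
```

The conditional term `cond_kl` is exactly what is lost when going from the joint over two labels to the marginal over one.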
Function-Space Variational Inference
The DPI has an intriguing connection to FSVI. Let’s say we want to approximate a Bayesian posterior \(\pof{\w \given \Dany}\) with a variational distribution \(\qof{\w}\). In standard VI, we would minimize \(\Kale{\qof{\W}}{\pof{\W \given \Dany}}\) to match the variational distribution to the Bayesian posterior. Specifically:
\[ \begin{align} \Kale{\qof{\W}}{\pof{\W \given \Dany}} &= \underbrace{\E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\W}}{\pof{\W}}}_{\text{Evidence}\ \text{Bound}} + \log \pof{\Dany} \ge 0 \\ \iff \underbrace{-\log \pof{\Dany}}_{=\xHof{\pof{\Dany}}} &\le \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\W}}{\pof{\W}}. \end{align} \]
This yields an information-theoretic evidence upper bound on the information content \(-\log \pof{\Dany}\) of the data \(\Dany\), which holds for any variational distribution \(\qof{\w}\). In the more probability-theory-inspired literature, the negative of this bound is called the evidence lower bound (ELBO), which we maximize.
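To make the identity concrete, the following sketch checks it numerically for a parameter with three possible values. The prior, the likelihood values `lik` (standing in for \(\pof{\Dany \given \w}\)), and the variational distribution `q` are hypothetical numbers.

```python
import math

def kl(q, p):
    """KL(q || p) in nats between two discrete distributions given as lists."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

prior = [0.2, 0.3, 0.5]   # p(w)
lik = [0.05, 0.4, 0.1]    # p(D | w), hypothetical values
q = [0.1, 0.8, 0.1]       # variational distribution q(w)

# Exact Bayesian quantities for this tiny model.
evidence = sum(pw * l for pw, l in zip(prior, lik))
posterior = [pw * l / evidence for pw, l in zip(prior, lik)]

# Evidence bound: expected negative log-likelihood plus prior KL term.
expected_nll = sum(qw * -math.log(l) for qw, l in zip(q, lik))
evidence_bound = expected_nll + kl(q, prior)

information_content = -math.log(evidence)
gap = kl(q, posterior)  # the slack in the bound is the KL to the true posterior

print(f"bound={evidence_bound:.4f}  -log p(D)={information_content:.4f}  gap={gap:.4f}")
```

The bound is tight exactly when `q` equals the true posterior, since the slack is the divergence we set out to minimize.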
In FSVI (with a caveat I detail below), we apply the DPI to the prior 🥬 divergence term and obtain a functional version of the evidence bound:
\[ \begin{align} \Kale{\qof{\W}}{\pof{\W}} \ge \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}, \end{align} \]
where \(\Y... \given \x...\) are (finite or infinite) sets of samples. That is, we do not only optimize marginal distributions but also joint distributions.
The resulting objective
\[ \begin{align} \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}} \end{align} \]
is equal to the (negative) functional ELBO (fELBO) in “Functional variational Bayesian neural networks” by Sun et al. (2019), with caveats that we discuss below.
Choosing \(\x...\)
One important detail is the question of how to choose the \(\x...\):
Ideally, we want to choose them such that the DPI inequality is as tight as possible.
Given the chain inequality, the larger the set \(\x...\), the tighter the inequality will be. Hence, if we could choose an infinite set of points well, we might be able to get the tightest possible inequality. However, this is often not tractable in practice.
Some works take a supremum over finite subsets of a certain size, essentially building a core-set as an approximation (Rudner et al., 2022a/b); others take an expectation over finite sets of input samples (Sun et al., 2019), which does not necessarily yield the tightest inequality but provides an unbiased estimate; yet other works focus on finite datasets for which all points can be taken into account (Klarner et al., 2023).
We will discuss the tightness of the inequality and the implications in the data limit below.
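The different ways of choosing \(\x...\) can be compared on a toy model. This is only a sketch under assumed numbers: three candidate inputs, binary labels, and a parameter with three values. It does not reproduce the estimators used in the cited papers; it only contrasts a supremum over subsets, an average over subsets, and the full set.

```python
import math
from itertools import combinations, product

def kl(q, p):
    """KL(q || p) in nats for distributions stored as dicts over the same keys."""
    return sum(q[k] * math.log(q[k] / p[k]) for k in q if q[k] > 0)

# prob_one[i][w] = p(Y = 1 | x_i, w) for three candidate inputs (illustrative).
prob_one = [[0.9, 0.5, 0.1], [0.8, 0.3, 0.6], [0.5, 0.5, 0.4]]
q_w = {0: 0.7, 1: 0.2, 2: 0.1}
p_w = {0: 0.2, 1: 0.3, 2: 0.5}

def predictive(dist_w, inputs):
    """Joint predictive over the labels at the given input indices."""
    out = {}
    for ys in product([0, 1], repeat=len(inputs)):
        total = 0.0
        for w, pw in dist_w.items():
            lik = 1.0
            for i, y in zip(inputs, ys):
                lik *= prob_one[i][w] if y == 1 else 1.0 - prob_one[i][w]
            total += pw * lik
        out[ys] = total
    return out

def predictive_kl(inputs):
    return kl(predictive(q_w, inputs), predictive(p_w, inputs))

# All subsets of size two, as a stand-in for "finite subsets of a certain size".
pair_kls = [predictive_kl(pair) for pair in combinations(range(3), 2)]
kl_sup = max(pair_kls)                    # supremum over subsets
kl_mean = sum(pair_kls) / len(pair_kls)   # expectation over uniformly drawn subsets
kl_all = predictive_kl((0, 1, 2))         # all candidate inputs at once
kl_w = kl(q_w, p_w)                       # parameter-space divergence

print(f"mean={kl_mean:.4f}  sup={kl_sup:.4f}  all={kl_all:.4f}  params={kl_w:.4f}")
```

In this toy setting the values are ordered as mean, supremum, full set, parameter space, which matches the intuition that larger or better-chosen sets give a tighter bound.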
Focusing on the most important aspect of FSVI, we observe:
Continual Learning
When we directly optimize the 🥬 divergence on the finite input dataset, for example, we align \(\opq\) with the prior of \(\opp\) where it matters most: on the predictions of the observed data.
This is of particular interest in continual learning, where the prior for the next task is chosen to be the posterior from the previous task. In this case, the functional ELBO can be used to approximate the posterior of the previous model while incorporating new data.
For two great papers that are very readable and provide further insights, see “Continual learning via sequential function-space variational inference” and “Tractable function-space variational inference in Bayesian neural networks”, both by Rudner et al. (2022).
Comparison to FSVI in the literature
In practice, both works by Rudner et al. (2022), for example, linearize the logits (similar to a Laplace approximation) and use the DPI to show (in their notation): \[ \mathbb{D}_{\mathrm{KL}}\left(q_{f(\cdot ; \boldsymbol{\Theta})} \| p_{f(\cdot ; \boldsymbol{\Theta})}\right) \leq \mathbb{D}_{\mathrm{KL}}\left(q_{\Theta} \| p_{\Theta}\right) \] which in my notation, with \(\L...\) denoting the logits, is equivalent to the above: \[ \Kale{\qof{\L...\given \x...}}{\pof{\L...\given \x...}} \le \Kale{\qof{\W}}{\pof{\W}}. \] They maximize the fELBO objective: \[ \mathcal{F}\left(q_{\boldsymbol{\Theta}}\right)=\mathbb{E}_{q_{f\left(\mathbf{x}_{\mathcal{D}} ; \boldsymbol{\Theta}\right)}}\left[\log p_{\mathbf{y} \mid f(\mathbf{X} ; \boldsymbol{\Theta})}\left(\mathbf{y}_{\mathcal{D}} \mid f\left(\mathbf{X}_{\mathcal{D}} ; \boldsymbol{\theta}\right)\right)\right]-\sup _{\mathbf{X} \in \mathcal{X}_{\mathbb{N}}} \mathbb{D}_{\mathrm{KL}}\left(q_{f(\mathbf{X} ; \boldsymbol{\Theta})} \| p_{f(\mathbf{X} ; \boldsymbol{\Theta})}\right),\] which is equivalent to minimizing the information-theoretic objective: \[ \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\L... \given \x...}}{\pof{\L... \given \x...}}, \] if we choose the \(\x...\) to tighten the DPI inequality as much as possible (i.e. choosing the supremum).
Using the inequality chain from above, we can sandwich their objective between a regular (negative) ELBO and the (negative) functional ELBO that we have derived above: \[ \begin{aligned} &\E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\W}}{\pof{\W}} \\ &\quad \ge \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\L... \given \x...}}{\pof{\L... \given \x...}} \\ &\quad \ge \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}. \end{aligned} \]
In practice, using the probabilities instead of logits when performing linearization is often cumbersome due to the non-linearity of the softmax functions, which requires Monte-Carlo sampling of the logits to obtain an approximation of the final probabilities. I speculate that sampling the logits can be more benign given that we often use ReLUs in the underlying neural networks. (Don’t quote me too strongly on this, though.)
Conceptually, this explains the derivation of their ELBO objective and relates it to the ‘purer’ and simpler functional evidence bound derived above. It also raises the question of how these inequalities differ and what the gap between them tells us. We will address this question in the next section.
The Equality Case and Equivalence Classes
When do we have \(\Kale{\qof{\W}}{\pof{\W}} = \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}\)? And what does it tell us? As we have seen in the first part, we have equality in the DPI if and only if \(\Kale{\qof{\W \given \Y..., \x...}}{\pof{\W \given \Y..., \x...}}=0\).
Given that we are trying to approximate the Bayesian posterior \(\pof{\w \given \Y..., \x...}\) using \(\qof{\w}\), this equality condition tells us that we would have to find the exact posterior for equality. Hence, it is unlikely that we will have equality in practice. Thus, this raises the question of what this predictive prior term \(\Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}\) provides us with.
Another way to think about the gap between the two 🥬 divergences is that one is parameter-based and the other one is not: Deep neural networks have many parameter symmetries, and it is often possible to permute the weights of a neural network without changing the outputs—for example, in a convolutional neural network, we could swap channels.
The functional 🥬 divergences won’t be affected by this as they are parameter-free and do not take into account the parameters of the model but only the predictions. The regular parameter-based 🥬 divergence, however, would be affected by this—depending on the prior \(\pof{\w}\).
If the prior assigns different probability to otherwise equivalent parameters, this obviously changes the parameter posterior, while the outputs are invariant to these changes if the overall assigned probability remains the same.
Equivalence Classes
Unless there are other considerations, it makes sense to use priors that assign the same density to parameters that are equivalent. Hence, for a given function \(\fof{\x ; \w}\), which determines the likelihood \(\pof{\y \given \x, \w} \triangleq \pof{\y \given \fof{\x ; \w}}\), we can define an equivalence relation such that \(\w \sim \w'\) if and only if \(\fof{\x; \w} = \fof{\x; \w'}\) for all \(\x\). This equivalence relation partitions the parameter space into equivalence classes: \[[\w] \triangleq \{\w' : \fof{\x ; \w'} = \fof{\x ; \w} \quad \forall \x \}.\]
Any prior \(\pof{\w}\) induces a prior \(\hpof{[\w]}\) over the equivalence classes: \[\hpof{[\w]} \triangleq \sum_{\w' \in [\w]} \pof{\w'}.\] —or \(\int_{[\w]} \pof{\w'} \, d \w'\) for continuous \(\w\)—with the corresponding model: \[ \begin{aligned} \hpof{\y, [\w] \given \x} &\triangleq \hpof{\y \given [\w], \x} \, \hpof{[\w]} \\ &= \pof{\y \given \x, \w} \, \hpof{[\w]}. \end{aligned} \]
Consistency
Importantly, this definition is consistent with Bayesian inference: the posterior over equivalence classes is the parameter posterior summed over each class, \[ \hpof{[\w] \given \Dany} = \sum_{\w' \in [\w]} \pof{\w' \given \Dany}. \]
This is easy to show with the definition and application of Bayes’ rule: \[ \begin{aligned} \hpof{[\w] \given \Dany} &= \hpof{\Dany \given [\w]} \, \hpof{[\w]} / \hpof{\Dany} \\ &= \pof{\Dany \given \w} \sum_{\w' \in [\w]} \pof{\w'} / \hpof{\Dany} \\ &= \sum_{\w' \in [\w]} \pof{\Dany \given \w'} \, \pof{\w'} / \hpof{\Dany} \\ &= \sum_{\w' \in [\w]} \pof{\w' \given \Dany} \, \pof{\Dany} / \hpof{\Dany} \\ &= \sum_{\w' \in [\w]} \pof{\w' \given \Dany}. \end{aligned} \] The last step follows from \(\hpof{\Dany}=\pof{\Dany}\): \[ \begin{aligned} \hpof{\Dany} &= \sum_{[\w]} \hpof{\Dany, [\w]} \\ &= \sum_{[\w]} \sum_{\w' \in [\w]} \pof{\Dany, \w'} \\ &= \sum_{\w'} \pof{\Dany, \w'} \\ &= \pof{\Dany}. \end{aligned} \] This also tells us that, for any \(\x\) and \(\y\), \(\pof{\y... \given \x...} = \hpof{\y... \given \x...}\).
Given this consistency, we don’t have to differentiate between \(\hat\opp\) and \(\opp\) and can use \(\opp\) interchangeably. The same holds for \(\opq\).
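The consistency argument can be replayed numerically. The sketch below assumes a hypothetical discrete parameter with six values, a made-up function `f(w) = w % 3` that defines the equivalence classes, and illustrative likelihood values that depend on the parameter only through `f(w)`.

```python
# Six parameter values with a non-uniform prior (illustrative numbers, sums to 1).
prior = [0.05, 0.10, 0.15, 0.20, 0.25, 0.25]
n_params = len(prior)

def f(w):
    """Hypothetical prediction function: parameters with equal f(w) are equivalent."""
    return w % 3

# p(D | w) depends on w only through f(w), as in the definition of the classes.
class_lik = [0.7, 0.2, 0.05]
classes = sorted({f(w) for w in range(n_params)})

# Bayesian inference in parameter space, then summing the posterior per class.
evidence = sum(prior[w] * class_lik[f(w)] for w in range(n_params))
posterior = [prior[w] * class_lik[f(w)] / evidence for w in range(n_params)]
summed_posterior = {
    c: sum(posterior[w] for w in range(n_params) if f(w) == c) for c in classes
}

# Bayesian inference directly on the equivalence classes with the induced prior.
class_prior = {
    c: sum(prior[w] for w in range(n_params) if f(w) == c) for c in classes
}
class_evidence = sum(class_prior[c] * class_lik[c] for c in classes)
class_posterior = {
    c: class_prior[c] * class_lik[c] / class_evidence for c in classes
}

for c in classes:
    print(c, round(summed_posterior[c], 6), round(class_posterior[c], 6))
```

Both routes give the same class posterior and the same evidence, which is the statement derived above.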
Equality & Symmetries
We can view \([\w]\) as a projection from \(\w\) to its equivalence class \([\w]\). The DPI then gives us: \[ \Kale{\qof{\W}}{\pof{\W}} \ge \Kale{\qof{[\W]}}{\pof{[\W]}}. \]
What does the gap between the two terms tell us?
Let’s look at a few examples to get a better understanding of this.
- Trivial Constant Case
Let \(\fof{\x ; \w} = 0\) independent of any \(\w\). Then \([\w] = [\w']\) for any \(\w\), \(\w'\).
For any approximate distribution \(\qof{\w}\), the induced \(\Kale{\qof{[\W]}}{\pof{[\W]}}=0\), while \(\Kale{\qof{\W}}{\pof{\W}}\) also includes superfluous divergence.
- Unused Parameter
Let \(\y \given (\w_1, \w_2) = \w_1\) deterministic but independent of \(\w_2\). Then \([(\w_1, \w_2)] = [(\w_1, {\w'}_2)]\) for any \({\w'}_2\) and \([(\w_1,*)]\not=[({\w'}_1, *)]\) for any \(\w_1 \not= \w'_1\).
Then \(\Kale{\qof{[\W]}}{\pof{[\W]}}=\Kale{\qof{\W_1}}{\pof{\W_1}}\) and captures purely the meaningful divergence between approximate and true posterior, while \(\Kale{\qof{\W}}{\pof{\W}}\) includes any divergence due to parameter symmetries across all \(\w_2\) that has no effect on the predictions.
- Periodic Parameter Space
Finally, let’s assume that the predictions are periodic in some way. That is, e.g. \(\y = \sin \w\). Obviously, \([\w] = [\w + 2\pi]\) then.
Further, let \(\pof{\w} = \operatorname{U}(\w; [0,2\pi \, N))\) for some \(N\) that determines the number of periods. Then, if we introduce another random variable \(K\), that captures which period we are in, we can (again) use the chain rule to write: \[ \begin{aligned} \Kale{\qof{\W}}{\pof{\W}} &= \Kale{\qof{\W \given \W \in [K\,2\pi, (K+1)\,2\pi]}}{\pof{\W \given \W \in [K\,2\pi, (K+1)\,2\pi]}} \\ &\quad + \Kale{\qof{\W \in [K\,2\pi, (K+1)\,2\pi]}}{\pof{\W \in [K\,2\pi, (K+1)\,2\pi]}} \\ &= \Kale{\qof{[\W]}}{\pof{[\W]}} \\ &\quad + \Kale{\qof{\W \in [K\,2\pi, (K+1)\,2\pi]}}{\pof{\W \in [K\,2\pi, (K+1)\,2\pi]}}. \end{aligned} \] This follows from the setup of this specific example.
Finally, \(\Kale{\qof{\W \in [K\,2\pi, (K+1)\,2\pi]}}{\pof{\W \in [K\,2\pi, (K+1)\,2\pi]}} \le \log N\).
So, if \(\opq\) only had support in a single period for example, the difference between \(\Kale{\qof{\W}}{\pof{\W}}\) and \(\Kale{\qof{[\W]}}{\pof{[\W]}}\) would be \(\log N\), capturing the redundancy.
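The periodic case can be checked with a discretized stand-in. The sketch below is an assumption-laden simplification: each period is split into `M` bins, the equivalence class of a parameter is taken to be its bin position within the period, and the names `N`, `M`, and `phase` are chosen here for illustration.

```python
import math

def kl(q, p):
    """KL(q || p) in nats between two discrete distributions given as lists."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

N = 4  # number of periods covered by the uniform prior
M = 5  # bins per period; flat index k * M + m means period k, position m

# Uniform prior over all N * M bins.
p_w = [1.0 / (N * M)] * (N * M)

# q only has support in the first period, with some distribution over positions.
phase = [0.4, 0.3, 0.15, 0.1, 0.05]
q_w = phase + [0.0] * ((N - 1) * M)

def to_classes(dist):
    """Project onto equivalence classes by summing over periods."""
    return [sum(dist[k * M + m] for k in range(N)) for m in range(M)]

kl_params = kl(q_w, p_w)
kl_classes = kl(to_classes(q_w), to_classes(p_w))
gap = kl_params - kl_classes

print(f"gap={gap:.6f}  log N={math.log(N):.6f}")
```

Changing `phase` changes both divergences but leaves the gap at \(\log N\), as long as `q_w` stays within a single period.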
Predictive Prior
How does the predictive prior term fit into this? The DPI again yields the answer: \[ \Kale{\qof{\W}}{\pof{\W}} \ge \Kale{\qof{[\W]}}{\pof{[\W]}} \ge \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}. \]
This tells us that the predictive prior term can at best measure the 🥬 divergence between the equivalence classes of the parameters, and not between the parameters themselves. Luckily, this is the more meaningful divergence anyway.
For the equality cases, we note that:
- we need a 1:1 mapping between parameters and equivalence classes for the first bound to be tight, and
- we need \(\Kale{\qof{[\W] \given \Y_n,...,\Y_1, \x_n,...,\x_1}}{\pof{[\W] \given \Y_n,...,\Y_1, \x_n,...,\x_1}} \to 0\) for \(n \to \infty\) for the second bound to be tight.
For the second point, we know from the chain rule that \(\Kale{\qof{\Y_n,...,\Y_1\given\x_n,...,\x_1}}{\pof{\Y_n,...,\Y_1\given\x_n,...,\x_1}}\) is monotonically increasing in \(n\), and as it is bounded by \(\Kale{\qof{[\W]}}{\pof{[\W]}}\) from above, it must converge.
To give intuition, and without attempting to prove this formally, we can appeal to the Bernstein-von Mises theorem. It states that, as the number of samples goes to infinity, the posterior distribution of the parameters converges to a Gaussian distribution centered at the maximum likelihood estimate (MLE), with covariance given by the inverse Fisher information scaled by the inverse of the number of samples, as long as the model parameters are identifiable, that is, the true parameters we want to learn are unique:
In the space of equivalence classes given our definitions, the [MLE] will be unique by definition and thus the model identifiable, and as the MLE is prior-independent, both \(\opq\) and \(\opp\) will converge to the MLE. In other words, both \(\opq\) and \(\opp\) will converge to the same equivalence class, and \(\Kale{\qof{[\W]\given \Y..., \x...}}{\pof{[\W] \given \Y..., \x...}} \to 0\) for \(n \to \infty\). Thus, we have: \[ \begin{align} \Kale{\qof{[\W]}}{\pof{[\W]}} = \sup_{n\in \mathbb{N}} \Kale{\qof{\Y_n,...,\Y_1\given\x_n,...,\x_1}}{\pof{\Y_n,...,\Y_1\given\x_n,...,\x_1}}. \end{align} \]
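The monotone convergence can be illustrated with a toy model. The sketch assumes a parameter with three values that all give different predictions (so each parameter is its own equivalence class) and binary labels that are conditionally i.i.d. given the parameter; `theta`, `q_w`, and `p_w` are illustrative numbers.

```python
import math

def kl(q, p):
    """KL(q || p) in nats between two discrete distributions given as lists."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

theta = [0.1, 0.5, 0.9]  # p(Y = 1 | w); distinct values, so classes are identifiable
q_w = [0.7, 0.2, 0.1]
p_w = [0.2, 0.3, 0.5]

def predictive_kl(n):
    """KL between the joint predictives of q and p over n binary labels."""
    total = 0.0
    for k in range(n + 1):
        # Probability of one particular label sequence containing k ones.
        q_seq = sum(qw * t**k * (1 - t) ** (n - k) for qw, t in zip(q_w, theta))
        p_seq = sum(pw * t**k * (1 - t) ** (n - k) for pw, t in zip(p_w, theta))
        # There are comb(n, k) such sequences, all with the same probability.
        total += math.comb(n, k) * q_seq * math.log(q_seq / p_seq)
    return total

kl_classes = kl(q_w, p_w)
kls = [predictive_kl(n) for n in range(1, 61)]

for n in (1, 5, 20, 60):
    print(f"n={n:2d}  predictive KL={kls[n - 1]:.4f}  (bound {kl_classes:.4f})")
```

The predictive divergence grows with `n` and stays below the divergence between equivalence classes, which is the behaviour the supremum in the equation above describes.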
Parameter Priors vs Predictive Priors
What is the advantage of this all?
In Bayesian deep learning, we often use parameter priors that are not meaningful in the sense that they do not take parameter symmetries into account. For example, a unit Gaussian prior over the parameters of a neural network does not necessarily induce different predictions for different parameters. While this prior can be sensible from a parameter compression perspective (e.g. see Hinton and van Camp (1993)), this does not have to be the only consideration.
With function priors, that is predictive priors, we can induce more meaningful priors because we can focus on the predictions and ignore the parameters. This can connect Bayesian approaches to data augmentation and other regularization techniques.
Given that these priors are difficult to express explicitly though, using the DPI to obtain a functional ELBO can be an easier way to express and approximate them.
Label Entropy Regularization
Applied to the functional ELBO, we can gain a new perspective on label entropy regularization.
The functional evidence bound can be lower-bounded using the chain rule by: \[ \begin{align} \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}} \\ \ge \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \E{\pdata{\x}}{\Kale{\qof{\Y \given \x}}{\pof{\Y \given \x}}}, \end{align} \] where we can expand the term under the second expectation to: \[ \Kale{\qof{\Y \given \x}}{\pof{\Y \given \x}}=\CrossEntropy{\qof{\Y \given \x}}{\pof{\Y \given \x}} - \xHof{\qof{\Y \given \x}}. \] Assuming that our prior yields a uniform distribution over the labels, we can drop the cross entropy term because it is constant and obtain: \[ \E{\qof{\w}}{-\log \pof{\Dany \given \w}} - \E{\pdata{\x}}{\xHof{\qof{\Y \given \x}}}. \] This is the same as an MLE minimization objective with an additional entropy regularization term \(-\xHof{\qof{\Y \given \x}}\) for different \(\x\) that prevents the model from overfitting to the labels and collapsing to the one-hot encoding of the labels.
Thus, in the simplest approximation, the DPI and functional variational inference give us a new perspective on label entropy regularization (and also label noise as I will shortly discuss in the next post 🤞).
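A minimal sketch of the resulting training loss is given below. It assumes a single set of logits per input (a point estimate standing in for the predictive under \(\opq\)), and the function name `entropy_regularized_loss`, the `weight` parameter, and the example logits are all hypothetical choices made for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def entropy_regularized_loss(batch_logits, labels, context_logits, weight=1.0):
    """Mean negative log-likelihood minus the weighted mean predictive entropy."""
    nll = -sum(
        math.log(softmax(z)[y]) for z, y in zip(batch_logits, labels)
    ) / len(labels)
    mean_entropy = sum(entropy(softmax(z)) for z in context_logits) / len(context_logits)
    return nll - weight * mean_entropy

# With a uniform prior over C classes, KL(q || uniform) = log C - H(q).
probs = softmax([2.0, 0.5, -1.0])
num_classes = len(probs)
kl_to_uniform = sum(p * math.log(p * num_classes) for p in probs if p > 0)

batch_logits = [[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]]
labels = [0, 1]
context_logits = batch_logits + [[0.0, 0.0, 0.0]]

loss_plain = entropy_regularized_loss(batch_logits, labels, context_logits, weight=0.0)
loss_reg = entropy_regularized_loss(batch_logits, labels, context_logits, weight=1.0)
print(f"plain NLL={loss_plain:.4f}  entropy-regularized={loss_reg:.4f}")
```

The `weight` argument is an added knob that is not part of the derivation; setting it to one corresponds to the objective above, and the context inputs play the role of the samples from \(\pdata{\x}\).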
Knowledge Distillation
Assuming non-uniform prior predictions, \(\E{\pdata{\x}}{\Kale{\qof{\Y \given \x}}{\pof{\Y \given \x}}}\) can be related to knowledge distillation in deep neural networks as introduced by Hinton et al. (2015).
The main difference is that knowledge distillation uses the 🥬 divergence in the opposite direction: its soft-target cross-entropy places the teacher's predictions in the first argument, whereas here the predictions of the model we train, \(\qof{\Y \given \x}\), come first.
The other difference is that we are not distilling the knowledge from a teacher model but from the prior, which we down-weight while also training our model on the data itself.
Conclusion
In summary, the data processing inequality provides an elegant perspective on why optimizing the variational posterior in function-space can be an effective inference strategy. This basic tool from information theory has helped us with a very simple deduction of a modern Bayesian deep learning method.
While recovering the results in their full generality is beyond the scope of this post, we were able to derive several interesting results from the literature in simplified form (or at least provide strong intuitions and motivation for them).
Further, we gained intuition for the relationship between parameter and predictive priors and how they relate to the DPI. Predictive priors can be seen as more meaningful priors for Bayesian neural networks because of their ability to ignore parameter symmetries.
Most importantly, we have looked at equivalence classes based on the predictions and pointed out how the predictive prior divergences approximate this functional 🥬 divergence between equivalence classes and not the usual parameter-based 🥬 divergences.
Lastly, we have seen that the DPI can be used to derive a new perspective on label entropy regularization, which is a commonly used regularization technique in deep learning.
Acknowledgements. Many thanks to Freddie Bickford Smith for very helpful comments and feedback on this post and to Tim Rudner for additional pointers to relevant literature and feedback on the FSVI section in particular 🤗