This post introduces a family of much less expensive approximations to BatchBALD that might work well wherever BatchBALD works. You might have noticed that BatchBALD can be very, very slow. We can approximate BatchBALD using pairwise mutual information terms, leading to a new approximation we call 2-BALD, or more generally, following the inclusion-exclusion principle, using up to k-wise mutual information terms, leading to what we call the k-BALD family of approximations to BatchBALD. Importantly, we can dynamically choose the acquisition batch size by estimating the quality of our approximation.
For example, computing BatchBALD with acquisition batch size 5 on 60,000 MNIST samples takes 1 min on my machine; 2-BALD also takes 1 min. However, at acquisition batch size 10, BatchBALD already takes >30 min, while 2-BALD only takes 2 min. And 2-BALD still performs as well as BatchBALD (at least in the proof-of-concept experiment on MNIST).
If you like this idea and want to run experiments/write a paper/make it your own, please let me know/add me as a co-author.
Background
Active learning depends on identifying the most informative points to label. Specifically, active learning consists of determining the most informative points to label from an unlabeled pool set using a model trained on the existing labeled training set; the newly labeled points are then added to the training set, and the model is updated; this is repeated until we are happy with the performance of the model.
From an information-theoretic point of view, this means finding the unlabeled points in the pool set with the highest expected information gain, which is also referred to as the BALD score (Houlsby et al., 2011; Gal et al., 2017) when using Bayesian neural networks and is often used to capture the epistemic uncertainty of the model for a given point. When using Bayesian neural networks, in practice, the BALD scores measure the disagreement between (Monte-Carlo) parameter samples, similar to "Query by Committee" (Seung et al., 1992). To be more specific, the BALD score looks as follows:

$$I[y; \omega \mid x, \mathcal{D}] = H[y \mid x, \mathcal{D}] - \mathbb{E}_{p(\omega \mid \mathcal{D})}\big[H[y \mid x, \omega]\big],$$

where $\omega$ denotes the model parameters, $\mathcal{D}$ the current training set, and $y$ the prediction for the candidate input $x$. BatchBALD replaces this with the joint expected information gain of a whole acquisition batch, $I[y_1, \ldots, y_n; \omega \mid x_1, \ldots, x_n, \mathcal{D}]$, which is what makes it expensive to compute.
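As a concrete illustration (a minimal sketch of my own, not the original BatchBALD code), the BALD scores can be computed from Monte-Carlo predictions `probs` of shape `[K, N, C]`, with K posterior samples, N pool points, and C classes:

```python
import torch

def bald_scores(probs: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """BALD score per pool point: H[y | x, D] - E_w[H[y | x, w]].

    probs: [K, N, C] softmax outputs for K Monte-Carlo parameter samples.
    """
    mean_probs = probs.mean(dim=0)                                      # p(y | x, D), shape [N, C]
    entropy_of_mean = -(mean_probs * (mean_probs + eps).log()).sum(-1)  # H[y | x, D]
    mean_entropy = -(probs * (probs + eps).log()).sum(-1).mean(dim=0)   # E_w[H[y | x, w]]
    return entropy_of_mean - mean_entropy                               # shape [N]

# Naive top-K BALD acquisition would simply take the K highest-scoring pool points:
# acquisition_indices = bald_scores(probs).topk(k=10).indices
```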
Research Idea: Applying the Inclusion-Exclusion Principle
Unlike other less principled methods that work amazingly well in practice (Stochastic BALD cough), we can come up with a more principled approximation of BatchBALD by applying the inclusion-exclusion principle known from set theory to our joint mutual information:
Inclusion-Exclusion Principle. For sets $A_1, \ldots, A_n$:

$$\Big| \bigcup_{i=1}^n A_i \Big| = \sum_i |A_i| - \sum_{i < j} |A_i \cap A_j| + \sum_{i < j < k} |A_i \cap A_j \cap A_k| - \ldots + (-1)^{n+1} |A_1 \cap \ldots \cap A_n|.$$
Following R. Yeung’s “A New Outlook on Shannon’s Information Measures”, which connects set operations with information quantities (see also “Better intuition for information theory” by yours truly for a very simple introduction), we can apply the same to information-theoretic quantities, which leads to the following statement (dropping the conditioning on the inputs and $\mathcal{D}$ for brevity):

$$I[y_1, \ldots, y_n; \omega] = \sum_i I[y_i; \omega] - \sum_{i < j} I[y_i; y_j; \omega] + \sum_{i < j < k} I[y_i; y_j; y_k; \omega] - \ldots,$$

where the terms with more than two variables are multivariate interaction information (co-information) terms, mirroring the set intersections above.
In particular, we can define the following approximations to BatchBALD:
1-BALD: $I_1[\{y_i\}; \omega] := \sum_i I[y_i; \omega]$, i.e. the sum of the individual BALD scores.
2-BALD: $I_2[\{y_i\}; \omega] := \sum_i I[y_i; \omega] - \sum_{i < j} I[y_i; y_j; \omega]$, i.e. we additionally subtract the pairwise interaction terms.
k-BALD: $I_k[\{y_i\}; \omega] := \sum_{m=1}^{k} (-1)^{m+1} \sum_{i_1 < \ldots < i_m} I[y_{i_1}; \ldots; y_{i_m}; \omega]$, i.e. we truncate the inclusion-exclusion expansion after the $k$-wise terms.
With 1-BALD, we simply recover the well-known top-K BALD, when we maximize over the possible candidates, as also explained in “BatchBALD”.
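To make the 2-BALD terms concrete, here is a minimal sketch (not a reference implementation) under the usual assumption that the $y_i$ are independent given $\omega$, so that $I[y_i; y_j; \omega] = I[y_i; y_j]$ can be estimated directly from the same Monte-Carlo predictions used for BALD:

```python
import torch
from itertools import combinations

def pairwise_mi(probs_i: torch.Tensor, probs_j: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """I[y_i; y_j] from MC predictions probs_i, probs_j of shape [K, C].

    With y_i, y_j independent given w, this equals the interaction term I[y_i; y_j; w].
    """
    joint = torch.einsum('ka,kb->ab', probs_i, probs_j) / probs_i.shape[0]  # p(y_i, y_j), [C, C]
    entropy = lambda p: -(p * (p + eps).log()).sum()
    return entropy(joint.sum(dim=1)) + entropy(joint.sum(dim=0)) - entropy(joint)

def two_bald(batch, probs, bald):
    """2-BALD score of a candidate batch (list of pool indices).

    probs: [K, N, C] MC predictions; bald: [N] per-point BALD scores (e.g. from the sketch above).
    """
    score = bald[batch].sum()
    for i, j in combinations(batch, 2):
        score = score - pairwise_mi(probs[:, i], probs[:, j])
    return score
```

In a greedy acquisition loop, this means adding the candidate $x$ that maximizes $I[y_x; \omega] - \sum_{b \in \text{batch}} I[y_x; y_b]$, which only requires pairwise terms against the points acquired so far.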
Initial Results, Challenges, and Questions
Initial Results
For 2-BALD, we have the following result on MNIST with acquisition batch size 5 for BatchBALD (which performs as well as BALD with individual acquisition) and 10 for 2-BALD:
It matches BatchBALD in both settings while being much, much cheaper to compute!
Questions & Challenges
2-BALD does not work well for larger acquisition batch sizes. For larger acquisition batch sizes, there is bad news for 2-BALD: it performs pretty terribly, and we have examined why:
As we add additional samples and subtract the pairwise interactions, we subtract too much, even pushing the scores to become negative.
Originally higher-scoring points become more negative than less informative ones. This spells trouble! At this point, we start to acquire uninformative samples whose scores do not change much but which also do not improve model accuracy. Note that a point's contribution to the true BatchBALD score is always upper-bounded by its individual BALD score: uninformative points with scores close to 0 will always remain uninformative. 2-BALD will thus end up acquiring uninformative points, ensuring that it eventually performs worse than even random acquisition!
Hence, while 2-BALD might be a good replacement for BatchBALD when using comparable acquisition batch sizes, it does not allow us to scale up the acquisition batch size just yet. (Though on larger datasets with more classes, we might naturally get away with larger acquisition batch sizes.)
Dynamic batch acquisition size. We could compute both 2-BALD and 3-BALD and stop the batch acquisition once the scores of 2- and 3-BALD diverge too much.
In Kirsch et al., 2021, we empirically observed that the total correlation between samples decreases as the model converges and performance improves. Specifically, the total correlation between samples goes towards 0 later in training:
It is straightforward to show the following relationship between BatchBALD and the BALD scores (using that the $y_i$ are independent given $\omega$):

$$I[y_1, \ldots, y_n; \omega] = \sum_i I[y_i; \omega] - \mathrm{TC}[y_1, \ldots, y_n],$$

where $\mathrm{TC}[y_1, \ldots, y_n] = \sum_i H[y_i] - H[y_1, \ldots, y_n]$ is the total correlation of the predictions.
From this, we can hope that later in training, the batch size will automatically increase more and more, without loss in label efficiency. Crucially, this depends on the total correlation indeed decreasing further along in training.
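A rough sketch of the dynamic stopping rule described above (purely illustrative; `greedy_next_point` and `k_bald_score` are assumed helpers, e.g. built on top of 2-BALD/3-BALD estimates, not an existing API):

```python
def acquire_dynamic_batch(pool_indices, probs, max_batch_size=50, rel_tol=0.1):
    """Grow the acquisition batch greedily and stop once 2-BALD and 3-BALD diverge."""
    batch = []
    while len(batch) < max_batch_size:
        # Hypothetical helper: pick the next point by its greedy 2-BALD gain.
        batch.append(greedy_next_point(batch, pool_indices, probs))
        # Hypothetical helper: k-BALD score of the current batch for k = 2, 3.
        score_2 = k_bald_score(batch, probs, k=2)
        score_3 = k_bald_score(batch, probs, k=3)
        # Stop as soon as the second- and third-order approximations disagree too much.
        if abs(score_2 - score_3) > rel_tol * abs(score_2):
            break
    return batch
```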
Bounding the approximation error. Unlike set measures, (interaction) mutual information terms between more than two variables can be negative, and thus it is not straightforward to bound the approximation error.
Focus on safe acquisitions. We could also play it safe instead of maximizing 2-BALD through greedy acquisition. This means we could greedily acquire the highest BALD scorer whose pairwise mutual information terms with the already acquired points are (close to) zero.
The problem with this approach is that only uninformative samples might remain after a few such steps, so we might want to threshold the minimum acceptable BALD score and end the acquisition process dynamically once there are no acceptable samples left.
Could this work? It is definitely a more conservative method.
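Here is a rough sketch of this conservative variant (my interpretation, not an existing implementation; the thresholds are made-up placeholders, and `pairwise_mi` refers to the pairwise helper sketched earlier):

```python
def safe_acquisition(pool_indices, probs, bald, max_batch_size=20,
                     min_bald=0.1, max_pairwise_mi=0.05):
    """Greedily take high-BALD points that barely interact with the batch so far."""
    batch = []
    # Consider candidates in decreasing order of their individual BALD score.
    for i in sorted(pool_indices, key=lambda idx: float(bald[idx]), reverse=True):
        if len(batch) >= max_batch_size:
            break
        if bald[i] < min_bald:
            break  # only uninformative candidates are left: end the acquisition early
        # Only accept the candidate if it is (nearly) independent of the batch so far.
        if all(pairwise_mi(probs[:, i], probs[:, j]) < max_pairwise_mi for j in batch):
            batch.append(i)
    return batch
```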
Is our posterior good enough? To predict many points into the future, we need to both have a good posterior approximation and sample diverse predictions from the posterior. It is not clear whether this is the case, and this might present yet another challenge to the scalability of this approach. See also our recent paper “Marginal and Joint Cross-Entropies & Predictives for Online Bayesian Inference, Active Learning, and Active Sampling”, which provides pointers in that direction.
Conclusion
k-BALD is a family of approximations to BatchBALD that allows us to lower its computational cost.
The text above already raises some interesting additional research questions. Additionally, it might be worth computing BatchBALD scores and 2-BALD scores as the joint batches increase in size and comparing their Spearman rank correlation. This could show how good or bad 2-BALD is in the regime where BatchBALD can still be computed.
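For instance, one could score random candidate batches of increasing size with both methods and check how well the two rankings agree (a sketch with assumed helpers `batchbald_score` and `two_bald_score`, not existing functions):

```python
import random
from scipy.stats import spearmanr

def rank_agreement(pool_indices, probs, batch_sizes=(2, 4, 6, 8), n_batches=200):
    """Spearman rank correlation between exact BatchBALD and 2-BALD over random batches."""
    for size in batch_sizes:
        candidate_batches = [random.sample(list(pool_indices), size) for _ in range(n_batches)]
        exact = [batchbald_score(batch, probs) for batch in candidate_batches]   # assumed helper
        approx = [two_bald_score(batch, probs) for batch in candidate_batches]   # assumed helper
        rho, _ = spearmanr(exact, approx)
        print(f"batch size {size}: Spearman rank correlation {rho:.3f}")
```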
Checking for divergence of 3-BALD (not computed so far) from 2-BALD might be a good way to catch approximation issues and dynamically set the acquisition batch size.