This post introduces a family of much less expensive approximations for BatchBALD that might work well where BatchBALD works. You might have noticed that BatchBALD can be very, very slow. We can approximate BatchBALD using pairwise mutual information terms, leading to a new approximation that we call 2-BALD, or, more generally, following the inclusion-exclusion principle, using up to k-wise mutual information terms¹, leading to what we call the k-BALD family of approximations for BatchBALD. Importantly, we might also be able to choose the acquisition batch size dynamically by estimating the quality of our approximation.

For example, computing BatchBALD with acquisition batch size 5 on 60000 MNIST samples takes 1 min on my machine, and 2-BALD also takes 1 min. At acquisition batch size 10, however, BatchBALD already takes more than 30 min, while 2-BALD only takes 2 min. And 2-BALD still performs as well as BatchBALD (at least in the proof-of-concept experiment on MNIST).

If you like this idea and want to run experiments, write a paper, or make it your own, let me know, and please add me as a co-author 🤗


Background

Active learning depends on identifying the most informative points to label. Specifically, we use a model trained on an existing labeled training set to determine the most informative points in an unlabeled pool set; these points are labeled and added to the training set, the model is updated, and the process repeats until we are happy with the performance of the model.

From an information-theoretic point of view, this means finding the unlabeled points in the pool set with the highest expected information gain, which is also referred to as the BALD score (Houlsby et al, 2011; Gal et al, 2017) when using Bayesian neural networks and is often taken to capture the epistemic uncertainty of the model for a given point. In practice, with Bayesian neural networks, the BALD scores measure the disagreement between (Monte-Carlo) parameter samples, similar to “Query by Committee” (Seung et al, 1992). Concretely, where \(p(\omega)\) is the empirical distribution of the parameters of the ensemble members, or an approximate parameter distribution, e.g., using Monte-Carlo dropout, the BALD score is: \[ I[\Omega; Y \mid x] = H[Y \mid x] - H[Y \mid x, \Omega] = \mathbb{E}_{p(\omega)} \, D_{\text{KL}}(p(Y \mid x, \omega) \, \Vert \, p(Y \mid x)). \]

In the BatchBALD paper, we show that we can perform batch active learning in a principled way using an extension of BALD to batches, that is, to joints of samples, which we called BatchBALD. Specifically, we acquire the samples which maximize the joint mutual information of the batch candidates \(x_1, \ldots, x_B\) with the model parameters: \[ I[\Omega; Y_1, \ldots, Y_B \mid x_1, \ldots, x_B]. \]

However, BatchBALD is really slow to compute. Estimating the joint entropy \(H[Y_1, \ldots, Y_B \mid x_1, \ldots, x_B]\) forces a trade-off between a combinatorial explosion (with \(C\) classes, there are \(C^B\) joint label configurations) and Monte-Carlo sampling of configurations, which impacts active learning performance.
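To make these quantities concrete, here is a minimal NumPy sketch. It assumes a hypothetical array `probs` of shape `(N, K, C)` holding \(p(y \mid x_n, \omega_k)\) for \(N\) pool points, \(K\) parameter samples (ensemble members or MC-dropout masks), and \(C\) classes; the layout and function names are illustrative, not taken from any particular codebase. The BALD score is the entropy of the averaged prediction minus the average entropy, while exact BatchBALD enumerates all \(C^B\) joint label configurations, which is exactly where the combinatorial explosion comes from:

```python
import numpy as np


def entropy(p, axis=-1):
    """Shannon entropy in nats along `axis`, treating 0 * log 0 as 0."""
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def bald_scores(probs):
    """I[Omega; Y | x] = H[Y | x] - E_omega H[Y | x, omega] for every pool point.

    probs: (N, K, C) array with p(y | x_n, omega_k).
    """
    mean_probs = probs.mean(axis=1)  # p(y | x_n), shape (N, C)
    return entropy(mean_probs) - entropy(probs).mean(axis=1)


def batchbald(probs, batch):
    """Exact I[Omega; Y_1..Y_B | x_1..x_B], enumerating all C**B label configurations."""
    K = probs.shape[1]
    joint = np.ones((K, 1))  # p(y_1..y_b | omega_k); grows to shape (K, C**B)
    for i in batch:
        joint = (joint[:, :, None] * probs[i][:, None, :]).reshape(K, -1)
    joint_entropy = entropy(joint.mean(axis=0))  # H[Y_1..Y_B | x_1..x_B]
    # The Y_i are conditionally independent given Omega, so H[Y_1..Y_B | x, Omega] splits.
    cond_entropy = sum(entropy(probs[i]).mean() for i in batch)
    return joint_entropy - cond_entropy


rng = np.random.default_rng(0)
logits = rng.normal(size=(100, 20, 10))  # toy pool: N=100 points, K=20 samples, C=10 classes
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
scores = bald_scores(probs)
print(np.argsort(-scores)[:5])       # indices of the top-5 BALD points
print(batchbald(probs, [0, 1, 2]))   # already 10**3 joint label configurations
```

The `joint` array has \(K \cdot C^B\) entries, so with ten classes every additional batch member multiplies memory and compute by ten.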

Research Idea: Applying the Inclusion-Exclusion Principle

Unlike other less principled methods that work amazingly well in practice (Stochastic BALD cough), we can come up with a more principled approximation of BatchBALD by applying the inclusion-exclusion principle known from set theory to our joint mutual information:

Inclusion-Exclusion Principle. For finite sets \(S_i\), we have: \[ \begin{aligned} | \bigcup_i S_i | = \sum_i | S_i | - \sum_{i<j} | S_i \cap S_j | + \sum_{i<j<k} | S_i \cap S_j \cap S_k | - \cdots \end{aligned} \]

Following R. Yeung’s “A New Outlook on Shannon’s Information Measures”, which connects set operations with information quantities (see also “Better intuition for information theory” by yours truly for a very simple introduction), we can apply the same principle to information-theoretic quantities, which leads to the following statement: \[ \begin{aligned} & H[Y_1, \ldots, Y_B \mid x_1, \ldots, x_B] \\ & \quad = \sum_i H[Y_i \mid x_i] - \sum_{i < j} I[Y_i ; Y_j \mid x_i, x_j] + \sum_{i<j<k} I[Y_i ; Y_j ; Y_k \mid x_i, x_j, x_k] - \cdots \end{aligned} \] Because the predictions \(Y_i\) are conditionally independent given the parameters \(\Omega\), the conditional joint entropy splits into a sum, \(H[Y_1, \ldots, Y_B \mid x_1, \ldots, x_B, \Omega] = \sum_i H[Y_i \mid x_i, \Omega]\), and we obtain: \[ \begin{aligned} & I[Y_1, \ldots, Y_B ; \Omega \mid x_1, \ldots, x_B] \\ & \quad = H[Y_1, \ldots, Y_B \mid x_1, \ldots, x_B] - \sum_i H[Y_i \mid x_i, \Omega] \\ & \quad = \sum_i H[Y_i \mid x_i] - \sum_{i < j} I[Y_i ; Y_j \mid x_i, x_j] + \sum_{i<j<k} I[Y_i ; Y_j ; Y_k \mid x_i, x_j, x_k] - \cdots - \sum_i H[Y_i \mid x_i, \Omega] \\ & \quad = \sum_i I[Y_i ; \Omega \mid x_i] - \sum_{i < j} I[Y_i ; Y_j \mid x_i, x_j] + \sum_{i<j<k} I[Y_i ; Y_j ; Y_k \mid x_i, x_j, x_k] - \cdots \end{aligned} \]
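The expansion is easy to check numerically on a tiny toy pool. The sketch below (same hypothetical `(N, K, C)` layout for `probs`; all function names are illustrative) computes each \(m\)-wise term from joint entropies of subsets, \(I[Y_S] = \sum_{\emptyset \neq T \subseteq S} (-1)^{|T|+1} H[Y_T]\), and confirms that keeping all orders up to \(B\) recovers BatchBALD exactly. It enumerates \(C^{|T|}\) configurations per subset, so it is only meant for very small batches:

```python
import itertools

import numpy as np


def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def joint_entropy(probs, subset):
    """H[Y_subset | x_subset] with p(y_subset) = E_omega prod_i p(y_i | x_i, omega)."""
    K = probs.shape[1]
    joint = np.ones((K, 1))
    for i in subset:
        joint = (joint[:, :, None] * probs[i][:, None, :]).reshape(K, -1)
    return entropy(joint.mean(axis=0))


def interaction(probs, subset):
    """Multivariate mutual information I[Y_i1; ...; Y_im] via inclusion-exclusion."""
    return sum((-1) ** (r + 1) * joint_entropy(probs, T)
               for r in range(1, len(subset) + 1)
               for T in itertools.combinations(subset, r))


def k_bald(probs, batch, k):
    """Truncate the expansion after the k-wise terms (k=1: sum of BALD scores)."""
    cond = entropy(probs).mean(axis=1)  # H[Y_n | x_n, Omega], shape (N,)
    score = sum(joint_entropy(probs, [i]) - cond[i] for i in batch)  # sum of BALD scores
    for m in range(2, k + 1):
        score += (-1) ** (m + 1) * sum(interaction(probs, S)
                                       for S in itertools.combinations(batch, m))
    return score


def batchbald(probs, batch):
    cond = entropy(probs).mean(axis=1)
    return joint_entropy(probs, batch) - sum(cond[i] for i in batch)


rng = np.random.default_rng(0)
logits = rng.normal(size=(10, 8, 3))  # toy pool: N=10, K=8, C=3
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
batch = [0, 1, 2, 3]
for k in range(1, 5):
    print(k, k_bald(probs, batch, k), batchbald(probs, batch))
```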

Truncating this expansion after the k-wise terms, we can define the following approximations to BatchBALD:

1-BALD. \(\sum_i I[Y_i; \Omega \mid x_i ]\)

2-BALD. \(\sum_i I[Y_i; \Omega \mid x_i ] - \sum_{i < j} I[Y_i ; Y_j \mid x_i, x_j]\)

k-BALD. \(\sum_i I[Y_i ; \Omega \mid x_i] - \sum_{i < j} I[Y_i ; Y_j \mid x_i, x_j] + \sum_{i<j<l} I[Y_i ; Y_j ; Y_l \mid x_i, x_j, x_l] - \cdots + (-1)^{k+1} \sum_{i_1 < \cdots < i_k} I[Y_{i_1} ; \ldots ; Y_{i_k} \mid x_{i_1}, \ldots, x_{i_k}]\)
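2-BALD only needs the pairwise joints \(p(y_i, y_j \mid x_i, x_j) = \mathbb{E}_{p(\omega)}[p(y_i \mid x_i, \omega)\, p(y_j \mid x_j, \omega)]\), i.e. \(C^2\) entries per pair instead of \(C^B\) for the whole batch. Here is a minimal NumPy sketch for scoring a fixed candidate batch, again assuming the hypothetical `(N, K, C)` layout for `probs`. For \(B = 2\) the expansion has no higher-order terms, so 2-BALD should coincide with exact BatchBALD, which the usage lines check:

```python
import numpy as np


def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def pairwise_mi(probs):
    """I[Y_i; Y_j | x_i, x_j] for all pairs; probs: (B, K, C) -> (B, B)."""
    K = probs.shape[1]
    pair_joint = np.einsum("ikc,jkd->ijcd", probs, probs) / K  # p(y_i, y_j), (B, B, C, C)
    h_marg = entropy(probs.mean(axis=1))  # H[Y_i | x_i], shape (B,)
    h_pair = entropy(pair_joint.reshape(*pair_joint.shape[:2], -1))  # H[Y_i, Y_j], (B, B)
    return h_marg[:, None] + h_marg[None, :] - h_pair


def two_bald(probs, batch):
    """Sum of BALD scores minus the pairwise terms I[Y_i; Y_j] for i < j."""
    sub = probs[batch]
    bald = entropy(sub.mean(axis=1)) - entropy(sub).mean(axis=1)
    # Only i < j pairs enter the expansion; the diagonal is not a term of it.
    return bald.sum() - np.triu(pairwise_mi(sub), k=1).sum()


def batchbald_exact(probs, batch):
    """Reference: exact BatchBALD via enumeration of all C**B configurations."""
    K = probs.shape[1]
    joint = np.ones((K, 1))
    for i in batch:
        joint = (joint[:, :, None] * probs[i][:, None, :]).reshape(K, -1)
    return entropy(joint.mean(axis=0)) - sum(entropy(probs[i]).mean() for i in batch)


rng = np.random.default_rng(1)
logits = rng.normal(size=(50, 16, 10))  # toy pool: N=50, K=16, C=10
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
print(two_bald(probs, [3, 7]), batchbald_exact(probs, [3, 7]))  # equal up to rounding
print(two_bald(probs, [3, 7, 11, 19]), batchbald_exact(probs, [3, 7, 11, 19]))  # approximate
```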

Maximizing 1-BALD over the possible candidates simply recovers the well-known top-K BALD acquisition, as also explained in “BatchBALD”.
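For 2-BALD, a natural way to maximize over candidates, mirroring the greedy acquisition used for BatchBALD, is to add points one at a time: adding \(x\) to the current batch changes 2-BALD by \(I[Y; \Omega \mid x] - \sum_{b \in \text{batch}} I[Y_b; Y \mid x_b, x]\). So each step only needs the mutual information between the newest batch member and the pool, which keeps memory linear in the pool size. The sketch below uses hypothetical helper names and a toy pool (\(K = 4\) parameter samples, two classes) in which points 0 and 1 are exact duplicates; top-K BALD takes both duplicates, while greedy 2-BALD skips the second one:

```python
import numpy as np


def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def bald_scores(probs):
    return entropy(probs.mean(axis=1)) - entropy(probs).mean(axis=1)


def mi_with(probs, b):
    """I[Y_b; Y_n | x_b, x_n] between pool point b and every pool point n -> (N,)."""
    K = probs.shape[1]
    pair_joint = np.einsum("kc,nkd->ncd", probs[b], probs) / K  # (N, C, C)
    h_marg = entropy(probs.mean(axis=1))
    return h_marg[b] + h_marg - entropy(pair_joint.reshape(len(probs), -1))


def top_k_bald(probs, batch_size):
    """1-BALD: the batch_size points with the highest individual BALD scores."""
    return [int(i) for i in np.argsort(-bald_scores(probs))[:batch_size]]


def greedy_two_bald(probs, batch_size):
    gains = bald_scores(probs)  # marginal 2-BALD gain of each candidate
    batch = []
    for _ in range(batch_size):
        masked = gains.copy()
        masked[np.array(batch, dtype=int)] = -np.inf  # never pick a point twice
        best = int(np.argmax(masked))
        batch.append(best)
        gains = gains - mi_with(probs, best)  # subtract the new pairwise interactions
    return batch


# Points 0 and 1 are duplicates; point 2 is less informative but varies independently.
a = np.array([[0.99, 0.01], [0.01, 0.99], [0.99, 0.01], [0.01, 0.99]])
c = np.array([[0.9, 0.1], [0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
probs = np.stack([a, a, c])
print(top_k_bald(probs, 2))       # the two duplicates
print(greedy_two_bald(probs, 2))  # point 0, then point 2
```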

Initial Results, Challenges, and Questions

Initial Results

For 2-BALD, we have the following result on MNIST with acquisition batch size 5 for BatchBALD (which performs as well as BALD with individual acquisition) and 10 for 2-BALD:

2-BALD performance on MNIST. 2-BALD performs as well as BatchBALD for a fraction of the computational cost.

It matches BatchBALD on both while being much, much cheaper to compute!

Questions & challenges

2-BALD does not work for larger acquisition batch sizes. When we increase the acquisition batch size further, there is bad news for 2-BALD:

2-BALD deteriorates with larger acquisition batch sizes.

It performs pretty terribly, and we have examined why:

2-BALD scores eventually become negative. Subtracting interactions between pairs of samples eventually makes the originally most informative points the most negative. Hence, 2-BALD prefers uninformative points later in the acquisition batch.

As we add more samples and subtract the pairwise interactions, we subtract too much, eventually even pushing the scores below zero.

Originally higher-scoring points become more negative than less informative ones. This spells trouble! At this point, we start to acquire uninformative samples whose scores do not change much but which also do not improve model accuracy. Note that the true BatchBALD score of a candidate \(x\) given the points already in the batch, \(I[\Omega; Y \mid x, Y_1, \ldots, Y_b, x_1, \ldots, x_b]\), is always non-negative and upper-bounded by its BALD score \(I[\Omega; Y \mid x]\). Uninformative points with BALD scores close to 0 thus always remain uninformative under BatchBALD, whereas 2-BALD can push the scores of informative points below theirs. 2-BALD will thus end up acquiring uninformative points, so that it eventually performs worse than even random acquisition!
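The failure mode can be reproduced on a toy pool of exact duplicates (hypothetical numbers: \(K = 4\) parameter samples, two classes). Each copy has a BALD score of about 0.64 nats, and any two copies share about 0.60 nats of mutual information. Adding a third copy to a batch of two therefore changes 2-BALD by roughly \(0.64 - 2 \times 0.60 < 0\), while the exact BatchBALD gain stays between 0 and the BALD score. A minimal sketch with illustrative function names:

```python
import numpy as np


def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def batchbald(probs, batch):
    """Exact BatchBALD by enumerating all joint label configurations."""
    K = probs.shape[1]
    joint = np.ones((K, 1))
    for i in batch:
        joint = (joint[:, :, None] * probs[i][:, None, :]).reshape(K, -1)
    return entropy(joint.mean(axis=0)) - sum(entropy(probs[i]).mean() for i in batch)


def two_bald(probs, batch):
    bald = {i: batchbald(probs, [i]) for i in batch}
    # I[Y_i; Y_j] = BALD_i + BALD_j - BatchBALD({i, j}), since Y_i, Y_j are independent given Omega.
    pairs = sum(bald[i] + bald[j] - batchbald(probs, [i, j])
                for n, i in enumerate(batch) for j in batch[n + 1:])
    return sum(bald.values()) - pairs


a = np.array([[0.99, 0.01], [0.01, 0.99], [0.99, 0.01], [0.01, 0.99]])
probs = np.stack([a] * 4)  # four identical, individually informative points
bald = batchbald(probs, [2])
gain_2bald = two_bald(probs, [0, 1, 2]) - two_bald(probs, [0, 1])
gain_exact = batchbald(probs, [0, 1, 2]) - batchbald(probs, [0, 1])
print(f"BALD {bald:.3f}, 2-BALD gain {gain_2bald:.3f}, BatchBALD gain {gain_exact:.3f}")
print([round(float(two_bald(probs, list(range(b)))), 3) for b in range(1, 5)])
```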

Hence, while 2-BALD might be a good replacement for BatchBALD at comparable acquisition batch sizes, it does not yet allow us to scale up the acquisition batch size. (On larger datasets with more classes, though, we might naturally get away with larger acquisition batch sizes.)

Dynamic batch acquisition size. We could compute both 2-BALD and 3-BALD and stop the batch acquisition once the scores of 2- and 3-BALD diverge too much.
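As a sketch only: grow the batch greedily under 2-BALD, recompute 3-BALD for the current batch after each addition, and stop before the first point that makes the two disagree by more than a tolerance `tol`. The tolerance, the absolute-difference criterion, and all function names are assumptions for illustration. The 3-wise terms need \(C^3\) configurations per triple, so this is only cheap for small batches and few classes:

```python
import itertools

import numpy as np


def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def joint_entropy(probs, subset):
    """H[Y_subset | x_subset], enumerating C**len(subset) label configurations."""
    K = probs.shape[1]
    joint = np.ones((K, 1))
    for i in subset:
        joint = (joint[:, :, None] * probs[i][:, None, :]).reshape(K, -1)
    return entropy(joint.mean(axis=0))


def two_and_three_bald(probs, batch):
    """Return (2-BALD, 3-BALD) of a batch from joint entropies of subsets of size <= 3."""
    H = {S: joint_entropy(probs, S) for r in (1, 2, 3) for S in itertools.combinations(batch, r)}
    cond = entropy(probs).mean(axis=1)  # H[Y_n | x_n, Omega]
    bald = sum(H[(i,)] - cond[i] for i in batch)
    pairs = sum(H[(i,)] + H[(j,)] - H[(i, j)] for i, j in itertools.combinations(batch, 2))
    triples = sum(H[(i,)] + H[(j,)] + H[(l,)] - H[(i, j)] - H[(i, l)] - H[(j, l)] + H[(i, j, l)]
                  for i, j, l in itertools.combinations(batch, 3))
    two = bald - pairs
    return two, two + triples


def acquire_dynamic(probs, max_batch, tol):
    """Greedy 2-BALD acquisition that stops once |3-BALD - 2-BALD| exceeds tol (hypothetical rule)."""
    N = len(probs)
    cond = entropy(probs).mean(axis=1)
    h1 = np.array([joint_entropy(probs, [n]) for n in range(N)])
    gains = h1 - cond  # BALD scores = 2-BALD gains for an empty batch
    batch = []
    while len(batch) < min(max_batch, N):
        masked = gains.copy()
        masked[np.array(batch, dtype=int)] = -np.inf
        best = int(np.argmax(masked))
        two, three = two_and_three_bald(probs, batch + [best])
        if abs(three - two) > tol:  # approximation no longer trustworthy: stop here
            break
        batch.append(best)
        # The 2-BALD gain of n drops by I[Y_best; Y_n] = H[Y_best] + H[Y_n] - H[Y_best, Y_n].
        gains = gains - np.array([h1[best] + h1[n] - joint_entropy(probs, [best, n])
                                  for n in range(N)])
    return batch


rng = np.random.default_rng(0)
logits = 2.0 * rng.normal(size=(30, 8, 3))  # toy pool: N=30, K=8, C=3
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
print(acquire_dynamic(probs, max_batch=10, tol=0.05))
```

For batches of one or two points, 2-BALD and 3-BALD are identical by construction, so the check can only trigger from the third point onward.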

In Kirsch et al, 2021, we empirically observed that the total correlation between samples decreases as the model converges and performance improves. Specifically, the total correlation between samples goes towards 0 later in training: \[TC[Y_1 ; \ldots ; Y_B \mid x_1, \ldots, x_B] \to 0,\] where the total correlation is defined as: \[TC[Y_1 ; \ldots ; Y_B \mid x_1, \ldots, x_B] = \sum_i H[Y_i \mid x_i] - H[Y_1, \ldots, Y_B \mid x_1, \ldots, x_B].\] The total correlation measures the dependence of the predictions on each other. When it is \(0\), the random variables are independent. Indeed, we expect that in the infinite training data limit, the posterior over the model parameters concentrates, and the predictions become independent.

It is straightforward to show the following relationship between BatchBALD and the BALD scores: \[I[\Omega; Y_1, \ldots, Y_B \mid x_1, \ldots, x_B] = \sum_i I[\Omega; Y_i \mid x_i] - TC[Y_1 ; \ldots ; Y_B \mid x_1, \ldots, x_B].\] BatchBALD takes the total correlation into account. Since \(TC \geq 0\), 1-BALD always upper-bounds BatchBALD, and as the total correlation decreases, 1-BALD becomes closer and closer to BatchBALD.
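Both statements can be sanity-checked numerically. The identity follows from writing BatchBALD as \(H[Y_1, \ldots, Y_B \mid x_1, \ldots, x_B] - \sum_i H[Y_i \mid x_i, \Omega]\) and adding and subtracting \(\sum_i H[Y_i \mid x_i]\). The sketch below (hypothetical `probs` of shape `(N, K, C)`, illustrative names) computes the total correlation of a small batch, checks the identity, and shows that the total correlation vanishes when all parameter samples agree, a crude stand-in for a fully concentrated posterior:

```python
import numpy as np


def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def joint_entropy(probs, batch):
    K = probs.shape[1]
    joint = np.ones((K, 1))
    for i in batch:
        joint = (joint[:, :, None] * probs[i][:, None, :]).reshape(K, -1)
    return entropy(joint.mean(axis=0))


def total_correlation(probs, batch):
    """TC[Y_1; ...; Y_B | x_1, ..., x_B] = sum_i H[Y_i | x_i] - H[Y_1, ..., Y_B | x_1, ..., x_B]."""
    return sum(joint_entropy(probs, [i]) for i in batch) - joint_entropy(probs, batch)


rng = np.random.default_rng(0)
logits = rng.normal(size=(20, 32, 4))  # toy pool: N=20, K=32, C=4
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
batch = [0, 1, 2, 3]
cond = entropy(probs).mean(axis=1)  # H[Y_n | x_n, Omega]
bald = np.array([joint_entropy(probs, [i]) - cond[i] for i in batch])
batchbald = joint_entropy(probs, batch) - cond[batch].sum()
tc = total_correlation(probs, batch)
print(batchbald, bald.sum() - tc, tc)

# If every parameter sample makes the same prediction, the Y_i become independent.
converged = np.repeat(probs[:, :1], 32, axis=1)
print(total_correlation(converged, batch))
```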

From this, we can hope that later in training, a dynamically chosen acquisition batch size will grow automatically without any loss in label efficiency. Crucially, this depends on the total correlation actually decreasing as training progresses.

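Bounding the approximation error. For sets, truncating the inclusion-exclusion sum alternately yields upper and lower bounds (the Bonferroni inequalities), which relies on all intersection sizes being non-negative. Mutual information terms in more than two variables, however, can be negative, and thus it is not straightforward to bound the approximation error.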
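A standard example of a negative three-way term (independent of the active learning setup) is the XOR construction: let \(X\) and \(Y\) be independent fair bits and \(Z = X \oplus Y\). Then \(I[X; Y] = 0\) but \(I[X; Y \mid Z] = 1\) bit, so \(I[X; Y; Z] = I[X; Y] - I[X; Y \mid Z] = -1\) bit. The sketch below recovers the same value from joint entropies via inclusion-exclusion, with the same sign convention as the expansion above; the function names are illustrative:

```python
import itertools

import numpy as np


def entropy_bits(p):
    """Shannon entropy in bits of a probability vector."""
    p = p[p > 0]
    return float(-(p * np.log2(p)).sum())


def interaction_information(joint):
    """I[X_1; ...; X_n] = sum over non-empty subsets T of (-1)**(|T|+1) * H[X_T]."""
    n = joint.ndim
    total = 0.0
    for r in range(1, n + 1):
        for T in itertools.combinations(range(n), r):
            other = tuple(a for a in range(n) if a not in T)
            marginal = joint.sum(axis=other) if other else joint
            total += (-1) ** (r + 1) * entropy_bits(marginal.ravel())
    return total


# X, Y independent fair bits and Z = X XOR Y.
joint = np.zeros((2, 2, 2))
for x, y in itertools.product([0, 1], repeat=2):
    joint[x, y, x ^ y] = 0.25

print(interaction_information(joint.sum(axis=2)))  # I[X; Y] -> 0.0
print(interaction_information(joint))               # I[X; Y; Z] -> -1.0
```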

Focus on safe acquisitions. We could also play it safe instead of maximizing 2-BALD through greedy acquisition. That is, we greedily acquire the highest BALD scorer \(x_{b_1}\), remove from consideration all pool samples whose total correlation with it (for two variables, simply the mutual information) \(I[Y_{b_1}; Y \mid x_{b_1}, x]\) exceeds \(\varepsilon\), then pick the next highest BALD scorer \(x_{b_2}\), and repeat the process.

The problem with this approach is that only uninformative samples might remain after a few such steps, so we might want to threshold the minimum acceptable BALD score and end the acquisition process dynamically once there are no acceptable samples left.
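A sketch of this conservative variant, where `eps` (the interaction threshold), `min_bald` (the minimum acceptable BALD score), and the helper names are hypothetical. On the toy pool with two duplicated points, point 1 is dropped because it shares about 0.60 nats with point 0, while point 2 survives because it shares none; raising `min_bald` above point 2's BALD score (about 0.37 nats) ends the acquisition after a single point:

```python
import numpy as np


def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def bald_scores(probs):
    return entropy(probs.mean(axis=1)) - entropy(probs).mean(axis=1)


def mi_with(probs, b):
    """I[Y_b; Y_n | x_b, x_n] for every pool point n."""
    K = probs.shape[1]
    pair_joint = np.einsum("kc,nkd->ncd", probs[b], probs) / K
    h_marg = entropy(probs.mean(axis=1))
    return h_marg[b] + h_marg - entropy(pair_joint.reshape(len(probs), -1))


def safe_acquire(probs, eps, min_bald, max_batch=None):
    bald = bald_scores(probs)
    available = bald >= min_bald  # candidates that are still acceptable
    batch = []
    while available.any() and (max_batch is None or len(batch) < max_batch):
        best = int(np.argmax(np.where(available, bald, -np.inf)))
        batch.append(best)
        available &= mi_with(probs, best) <= eps  # drop everything interacting with `best`
        available[best] = False
    return batch


a = np.array([[0.99, 0.01], [0.01, 0.99], [0.99, 0.01], [0.01, 0.99]])
c = np.array([[0.9, 0.1], [0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
probs = np.stack([a, a, c])
print(safe_acquire(probs, eps=0.1, min_bald=0.1))
print(safe_acquire(probs, eps=0.1, min_bald=0.5))
```

The batch size is not fixed here: acquisition simply ends when no acceptable candidate is left, which matches the dynamic stopping described above.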

Could this work? It is definitely a more conservative method.

Is our posterior good enough? To predict many points into the future, we need both a good posterior approximation and diverse predictions sampled from the posterior. It is not clear whether this is the case, and this might present yet another challenge to scaling up this approach. See also our recent paper “Marginal and Joint Cross-Entropies & Predictives for Online Bayesian Inference, Active Learning, and Active Sampling”, which provides pointers in that direction.

Conclusion

k-BALD is a family of BatchBALD approximations which allows us to lower the computational cost of BatchBALD.

The text already raises some interesting additional research questions. Additionally, it might be worth computing BatchBALD and 2-BALD scores for joints of increasing size and comparing their Spearman rank correlation. This could show how good or bad 2-BALD is wherever BatchBALD can still be computed.
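A sketch of that experiment on a small synthetic pool (the pool sizes, the random candidate batches, and the helper names are all assumptions): score many distinct random candidate batches of size \(B\) with both exact BatchBALD and 2-BALD, and compute the Spearman rank correlation between the two score lists. For \(B = 2\) the two coincide, so the correlation should be 1; for larger \(B\) it indicates how faithfully 2-BALD ranks batches. Ranks come from a double `argsort`, which assumes there are no ties:

```python
import itertools

import numpy as np


def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)


def batchbald(probs, batch):
    K = probs.shape[1]
    joint = np.ones((K, 1))
    for i in batch:
        joint = (joint[:, :, None] * probs[i][:, None, :]).reshape(K, -1)
    return entropy(joint.mean(axis=0)) - sum(entropy(probs[i]).mean() for i in batch)


def two_bald(probs, batch):
    bald = {i: batchbald(probs, [i]) for i in batch}
    # I[Y_i; Y_j] = BALD_i + BALD_j - BatchBALD({i, j})
    pairs = sum(bald[i] + bald[j] - batchbald(probs, [i, j])
                for i, j in itertools.combinations(batch, 2))
    return sum(bald.values()) - pairs


def spearman(x, y):
    rank = lambda v: np.argsort(np.argsort(v))  # assumes no ties
    return np.corrcoef(rank(x), rank(y))[0, 1]


rng = np.random.default_rng(0)
logits = 2.0 * rng.normal(size=(40, 16, 4))  # toy pool: N=40, K=16, C=4
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
rho = {}
for B in (2, 3, 4, 5):
    # Distinct candidate batches only, so that identical batches cannot create ties.
    unique = {tuple(sorted(rng.choice(len(probs), size=B, replace=False))) for _ in range(200)}
    batches = [list(b) for b in sorted(unique)]
    exact = [batchbald(probs, b) for b in batches]
    approx = [two_bald(probs, b) for b in batches]
    rho[B] = spearman(exact, approx)
    print(B, len(batches), round(float(rho[B]), 3))
```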

Checking for divergence of 3-BALD (not computed so far) from 2-BALD might be a good way to catch approximation issues and dynamically set the acquisition batch size.


  1. Note: the mutual information of more than two variables is also sometimes referred to as interaction information. However, I see this as the canonical extension of the pairwise mutual information and hence keep the term mutual information. See also “A Practical & Unified Notation for Information-Theoretic Quantities in ML”.