Have you ever wondered why we treat data selection differently during training versus before training?

This post explores the fundamental distinction between active learning and data filtering, two approaches that seem similar on the surface but operate quite differently in practice. By understanding their differences, we can make better choices in designing data-efficient learning pipelines.

Let’s start with a simple question: What makes these approaches distinct? The key lies in two aspects:

  1. when we make selection decisions (operational context); and
  2. how we use informativeness scores.

This distinction leads to some fascinating implications for how we design and use these methods in practice. After reading this post, you’ll understand why these approaches evolved differently and how their distinctions arise naturally from the underlying mathematics of information theory.

The Core Distinction: Timing and Direction

The operational context, that is, when we make selection decisions, fundamentally shapes how these methods work:

  • Active Learning/Sampling: Makes online/on-policy selections during training
  • Data Filtering: Makes offline/off-policy selections before training begins

This timing difference naturally leads to different ways of using informativeness scores:

  • Active Learning/Sampling: Selects the most informative samples
  • Data Filtering: Removes the least informative samples

While selecting or removing might seem equivalent mathematically (accepting 10% is the same as rejecting 90%), the operational context leads to different practical approaches and approximations.
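
To make the symmetry, and the practical asymmetry, concrete, here is a minimal NumPy sketch (the scores, pool size, and thresholds are purely illustrative, not taken from the experiments later in this post): with fully known, fixed scores, keeping the top 10% and dropping the bottom 90% yield the same set; the difference only appears once scores change during training.

import numpy as np

rng = np.random.default_rng(0)
scores = rng.random(1000)  # hypothetical informativeness scores for a pool of 1,000 samples

# Active-learning view: pick the k most informative samples.
k = 100
selected = np.argsort(scores)[-k:]

# Filtering view: drop everything below a score threshold, once, before training.
threshold = np.quantile(scores, 0.9)
kept = np.where(scores >= threshold)[0]

# With fixed, fully known scores, the two views pick out the same samples.
assert set(selected) == set(kept)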

But why did these methods evolve this way?

In short, the answer lies in how informativeness behaves as data accumulate: namely, it is (approximately) submodular.

$$\require{mathtools} \DeclareMathOperator{\opExpectation}{\mathbb{E}} \newcommand{\E}[2]{\opExpectation_{#1} \left [ #2 \right ]} \newcommand{\simpleE}[1]{\opExpectation_{#1}} \DeclareMathOperator{\opVariance}{Var} \newcommand{\Var}[2]{\opVariance_{#1} \left [ #2 \right ]} $$

Informativeness and Submodularity

To see why selection and removal differ, we first need to understand the properties of “informativeness”. Data selection approaches fundamentally rely on approximating the answer to one question: how much will training a model on a specific sample benefit its performance?

Crucially, informativeness often exhibits submodularity: the marginal gain of adding a new data point decreases as more data points are included. Each new data point we add to our training set provides less additional information than the previous one. This happens because new samples often contain information that overlaps with what we’ve already learned from previous samples. As the training set grows, new samples offer diminishing returns in terms of informativeness.
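
As a toy illustration of these diminishing returns (my own example, not part of the original experiments): in a conjugate Gaussian model with known observation noise, the expected information gain of the next observation shrinks monotonically as data accumulate, because the posterior over the unknown mean keeps contracting.

import numpy as np

prior_var = 1.0    # prior variance of the unknown mean (illustrative value)
noise_var = 0.25   # known observation noise variance (illustrative value)

def next_point_eig(n: int) -> float:
    """Expected information gain (in nats) of observation n + 1 after seeing n points."""
    posterior_var = 1.0 / (1.0 / prior_var + n / noise_var)
    # Mutual information between the unknown mean and the next observation.
    return 0.5 * np.log(1.0 + posterior_var / noise_var)

gains = [next_point_eig(n) for n in range(10)]
assert all(a > b for a, b in zip(gains, gains[1:]))  # strictly diminishing returns
print([round(g, 3) for g in gains])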


This isn’t just an intuitive concept; it’s mathematically well-grounded. Submodularity is a well-understood property (see e.g. Nemhauser et al., 1978), and the expected information gain, which provides a natural measure of informativeness, is submodular or at least approximately submodular in many cases (Das and Kempe, 2018). (In the case of the parameter-based expected information gain, it is always submodular.)

But why does this matter for active learning versus filtering? Let’s break it down:

  1. For Active Learning: Submodularity tells us we need to constantly re-evaluate which samples are most informative. What was highly informative previously might not be informative anymore, especially after we’ve trained on similar samples.

  2. For Data Filtering: Submodularity gives us confidence that samples deemed uninformative at the start will likely remain so throughout training. This makes early filtering a safe operation (see the sketch below).
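
The contrast can be sketched in a few lines of Python. This is schematic only: the score and training functions are caller-supplied placeholders, not the experiment code from the appendix.

from typing import Callable, Set

def active_learning(model, pool: Set[int], score: Callable, train_step: Callable,
                    rounds: int, k: int):
    """Online selection: re-score the remaining pool every round, because scores go stale."""
    for _ in range(rounds):
        scores = {i: score(model, i) for i in pool}   # recomputed at every round
        batch = set(sorted(pool, key=scores.get, reverse=True)[:k])  # most informative first
        model = train_step(model, batch)
        pool = pool - batch
    return model

def data_filtering(model, pool: Set[int], score: Callable, train: Callable,
                   keep_fraction: float):
    """Offline selection: score once up front, drop the least informative, then train as usual."""
    scores = {i: score(model, i) for i in pool}       # computed once, before training
    n_keep = max(1, int(keep_fraction * len(pool)))
    kept = set(sorted(pool, key=scores.get, reverse=True)[:n_keep])
    return train(model, kept)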

Active Learning: Iterative Selection

Active learning methods operate iteratively, computing informativeness scores in an online manner, typically at each training iteration. At each step, the samples that currently appear most informative for the model’s learning process are selected for training.

Some active learning strategies, like BatchBALD, take this a step further. They explicitly use submodularity to efficiently select batches of samples. But whether explicit or implicit, the key insight remains: we need to continuously reassess sample importance as our model learns and evolves.
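
BatchBALD’s actual objective is the joint mutual information of an entire candidate batch with the model parameters; as a loose sketch of the underlying idea only, greedy maximization of a submodular gain function looks like this (the gain callable is a stand-in, not BatchBALD’s estimator):

from typing import Callable, Iterable, List, Set

def greedy_batch(pool: Iterable[int], gain: Callable[[int, Set[int]], float],
                 batch_size: int) -> List[int]:
    """Greedily add the candidate with the largest marginal gain given the batch so far.
    For (approximately) submodular gains, this greedy scheme enjoys the classic
    (1 - 1/e) approximation guarantee (Nemhauser et al., 1978)."""
    batch: List[int] = []
    remaining = set(pool)
    for _ in range(batch_size):
        if not remaining:
            break
        best = max(remaining, key=lambda i: gain(i, set(batch)))
        batch.append(best)
        remaining.remove(best)
    return batch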

Figure 1: Active learning vs. data filtering: A conceptual comparison of how these approaches evaluate and use sample informativeness.

Data Filtering: Early Rejection

What makes data filtering fundamentally different? It’s all about timing.

Data filtering takes a different approach by evaluating informativeness offline, before training begins. The goal is to identify and remove less informative samples to streamline the training process. But how can we be confident in making such early decisions?

Again, submodularity provides the answer. With (approximate) submodularity, we can be reasonably confident that samples deemed uninformative at the start won’t suddenly become highly informative later. This property gives us the theoretical backing to make these early filtering decisions.

Figure 2: Visualization of how sample informativeness evolves during active learning.

The Bayesian Perspective: A Unifying Framework

How do these approaches connect to Bayesian methods? Let’s explore this connection to gain deeper insights.

Both active learning and data filtering can be understood through the lens of information theory. Active learning methods like BALD (Houlsby et al., 2011) select samples that maximize the mutual information between model parameters and predictions. Data filtering, meanwhile, can be viewed as removing samples that would contribute minimally to the posterior update.
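
Concretely, the BALD score of an input $x$ is the mutual information between the prediction $y$ and the model parameters $\omega$ given the observed data $\mathcal{D}$, which is exactly the decomposition used in the experiment code in the appendix:

$$ \mathrm{I}[y; \omega \mid x, \mathcal{D}] = \mathrm{H}\!\left[\E{p(\omega \mid \mathcal{D})}{p(y \mid x, \omega)}\right] - \E{p(\omega \mid \mathcal{D})}{\mathrm{H}\left[p(y \mid x, \omega)\right]}. $$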

But here’s what’s fascinating: many non-Bayesian methods can be interpreted as approximations of these Bayesian approaches (Kirsch et al., 2022). This suggests a deeper unity underlying these seemingly different methods.

A Practical Example: MNIST Active Learning

Let’s make this concrete with a real example. We conducted an active learning experiment on MNIST using a LeNet-5 model with Monte Carlo dropout. The experiment iteratively selects the most informative samples based on BALD scores and trains the model on the acquired samples.

The animation in Figure 3 shows the evolution of BALD scores on the training set, smoothed with an Exponential Moving Average (EMA) over iterations. Because the EMA smooths out noise in the informativeness estimates, the animation is only illustrative of the true underlying dynamics, but it is a reasonable practical choice for visualizing trends. In real applications, one would want to use better estimators of the EIG scores or many more samples.
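
For reference, the smoothing used in the animation code (see the appendix) is the standard EMA recurrence with decay $\gamma = 0.9$, where $s_t$ are the raw BALD scores at iteration $t$ and already-acquired samples are reset to zero:

$$ \tilde{s}_t = \gamma\, \tilde{s}_{t-1} + (1 - \gamma)\, s_t. $$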

Figure 3: Evolution of BALD scores over time, smoothed using Exponential Moving Average (EMA).

Conclusion

What have we learned about the relationship between active learning and data filtering? Let’s summarize the key insights:

  1. The timing of selection decisions fundamentally shapes how these methods work
  2. Submodularity provides theoretical backing for both approaches
  3. Real-world applications often require balancing theoretical ideals with practical constraints

While submodularity typically holds, cases where sample informativeness can increase under certain conditions remain somewhat unexplored for filtering (at least in a principled fashion, as far as I know). Identifying and analyzing these scenarios presents an exciting direction for future research, potentially refining both active learning and data filtering methodologies further.

By recognizing these differences and the underlying principle of (approximate) submodularity, we can make more informed choices about what approximations to use and what heuristics to employ.

There’s still much to explore about cases where sample informativeness behaves in unexpected ways. This presents exciting opportunities for future research, particularly in:

  • Understanding when and why informativeness patterns deviate from theory
  • Developing more robust methods that account for these deviations
  • Creating hybrid approaches that combine the strengths of both methods

Appendix

Conceptual Illustrations

The figures above illustrate the key concepts we’ve discussed. Let me explain how they were created:

Figure 1 shows a conceptual visualization of the informativeness curve, where samples are sorted by their informativeness (y-axis) along the x-axis. This curve demonstrates the diminishing returns property characteristic of submodularity. The visualization highlights:

  • The data filtering region (red shaded area) shows samples that would be filtered out because their informativeness falls below a threshold.
  • The active learning region (blue shaded area) represents samples that would be prioritized by active learning approaches because of their high informativeness.

Figure 2 is an animation that simulates how active learning works over time:

  1. We start with a distribution of samples with varying informativeness levels.
  2. At each step, the most informative sample (highest point) is selected for training.
  3. After selection, that sample’s informativeness drops to zero (it’s been “used”).
  4. The informativeness of other samples also decreases (with some random noise), reflecting how the value of remaining samples changes as the model learns.

This animation captures the dynamic, iterative nature of active learning, where the most informative sample is constantly changing as training progresses. It visually demonstrates why we need to re-evaluate informativeness at each step rather than making all selection decisions upfront.

Both visualizations were created using matplotlib with an XKCD-style aesthetic to make the concepts more approachable and intuitive.

Illustration Code
"""Generate an XKCD-style illustration of submodularity and how active learning vs.
   data filtering interact with the informativeness curve.

   The script produces both static PNG/SVG files and an animated GIF showing how
   informativeness changes over time as data gets trained on.

   Usage:
       python scripts/submodularity_curve.py               # default path
       python scripts/submodularity_curve.py --out /tmp/figure.png
"""

# %%
from __future__ import annotations

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter

# %%
# Synthetic submodular-like curve: diminishing returns.
x = np.linspace(1, 10000, 10000)
y_base = 1 / (0.9 / 99.9 * x + (1 - 0.9 / 99.9))

y_filtering_threshold = 0.05  # filtering threshold: keep samples with y >= threshold.
y_active_threshold = 0.2

filtering_color = "red"
active_color = "blue"

# First create the static plots
with plt.xkcd():
    fig, ax = plt.subplots(figsize=(8, 5))

    # Plot the diminishing-returns curve.
    ax.plot(x, y_base, color="black")

    # Shaded region for samples that would be *filtered out* (y below threshold).
    ax.fill_between(x, 0, y_base, where=y_base < y_filtering_threshold, color=filtering_color, alpha=0.1,
                    label="Data filtering region (offline)")

    # Horizontal dashed line showing the filtering threshold.
    ax.axhline(y_filtering_threshold, color=filtering_color, linestyle="--", linewidth=1)
    
    # Vertical dashed line showing the active learning region.
    ax.axvline(x[np.argmax(y_base <= y_active_threshold)], color=active_color, linestyle="--", linewidth=1)
    
    # Shaded vertical region for samples that would be considered for active learning.
    ax.fill_betweenx(y_base, 0, x, where=y_base > y_active_threshold, color=active_color, alpha=0.1,
                    label="Active learning region (online)")
    
    # Styling.
    ax.set_xlabel("Samples sorted by informativeness")
    ax.set_ylabel("Informativeness")
    ax.set_title("Active Selection vs. Data Filtering")
    ax.set_ylim(0, 1.05)
    ax.set_xlim(0, x.max())
    ax.legend()
    ax.grid(False)
    
    out_path = Path("active_vs_filtering_xkcd.png")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")

    out_path = Path("active_vs_filtering_xkcd.svg")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")

plt.show()
#%%
# Now create the animated version
n_frames = 40
top_k = 1

current_y = y_base.copy()

with plt.xkcd():
    fig_anim, ax_anim = plt.subplots(figsize=(8, 5))
    scatter = ax_anim.scatter([], [], color="black", s=1)
    selected_scatter = ax_anim.scatter([], [], color=active_color, s=100, alpha=0.1)
    
    # Horizontal dashed line showing the filtering threshold.
    ax_anim.axhline(y_filtering_threshold, color=filtering_color, linestyle="--", linewidth=1)
    
    # Styling
    ax_anim.set_xlabel("Samples sorted by informativeness")
    ax_anim.set_ylabel("Informativeness")
    ax_anim.set_title("Simulation of Top-1 Active Selection")
    ax_anim.set_ylim(0, 1.05)
    ax_anim.set_xlim(0, 1000)
    ax_anim.grid(False)

    def init():
        scatter.set_offsets(np.c_[x, y_base])
        return (scatter, selected_scatter)

    def animate(frame):
        print(frame)
        global current_y
        # Add noise to the decay
        if frame > 1:
            top_k_indices = np.argsort(current_y)[-top_k:]
            current_y[top_k_indices] = 0.0
            # Decay the other points
            noise = np.clip(np.random.gumbel(0, 0.1, len(current_y)), 0, 1)
            current_y = current_y * (1 - noise)
            
        # Re-identify the current top-k points (after the decay/zeroing above)
        top_k_indices = np.argsort(current_y)[-top_k:]
        # Highlight the currently selected top-k points
        selected_scatter.set_offsets(np.c_[x[top_k_indices].copy(), current_y[top_k_indices].copy()])

        scatter.set_offsets(np.c_[x, current_y])
        return (scatter, selected_scatter)

    anim = FuncAnimation(fig_anim, animate, init_func=init,
                        frames=n_frames, interval=200, blit=False, save_count=40)

fig_anim.tight_layout()
writer = PillowWriter(fps=5)
out_path = Path("active_selection_animation.gif")
out_path.parent.mkdir(parents=True, exist_ok=True)
anim.save(out_path, writer=writer)

# %%

MNIST Active Learning Experiment

In our MNIST active learning experiment, we used a LeNet-5 model with Monte Carlo dropout to iteratively select the most informative samples based on BALD scores. The experiment started with a small initial labeled set and, at each iteration, acquired the most informative unlabeled samples for training. The BALD scores were calculated using multiple Monte Carlo samples to estimate the mutual information between model parameters and predictions.

The animation in Figure 3 visualizes the evolution of BALD scores over iterations. It shows how the informativeness of samples changes as the model learns, with the most informative samples being selected and their scores dropping to zero after acquisition. This dynamic process highlights the importance of re-evaluating sample informativeness during training.

The experiment was implemented in mnist_al_experiment.py, which handles the model training, BALD score calculation, and sample acquisition. The animation was created using bald_animation.py, which retrieves data from Weights & Biases (wandb) and generates the visualization of BALD score evolution.

Experiment Code
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm

import wandb

# --- Configuration ---
DEVICE = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# MNIST Hyperparameters
N_CLASSES = 10
IMG_WIDTH = 28
IMG_HEIGHT = 28
N_CHANNELS = 1

# LeNet Model Definition (LeNet-5 architecture with MC Dropout)
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = nn.Dropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = nn.Dropout()
        self.fc2 = nn.Linear(128, N_CLASSES)

    def forward(self, input: torch.Tensor):
        input = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        input = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(input)), 2))
        input = input.view(-1, 1024)
        input = F.relu(self.fc1_drop(self.fc1(input)))
        input = self.fc2(input)
        return input


def bootstrap_sample(dataset, sample_size=None, replace=True, random_state=None):
    """
    Creates a bootstrap sample from the dataset.

    Args:
        dataset: PyTorch dataset to sample from
        sample_size: Size of the bootstrap sample (defaults to len(dataset))
        replace: Whether to sample with replacement (True for bootstrap)
        random_state: Random seed for reproducibility

    Returns:
        A PyTorch Subset containing the bootstrapped samples
    """
    if sample_size is None:
        sample_size = len(dataset)

    rng = np.random.default_rng(random_state)

    # Generate indices with replacement
    indices = rng.choice(len(dataset), size=sample_size, replace=replace)

    # Return a subset of the dataset with the selected indices
    return torch.utils.data.Subset(dataset, indices)


def create_bootstrap_loader(
    dataset,
    batch_size=64,
    sample_size=None,
    replace=True,
    random_state=None,
    num_workers=0,
    shuffle=True,
):
    """
    Creates a DataLoader with bootstrapped samples from the dataset.

    Args:
        dataset: PyTorch dataset to sample from
        batch_size: Batch size for the DataLoader
        sample_size: Size of the bootstrap sample (defaults to len(dataset))
        replace: Whether to sample with replacement (True for bootstrap)
        random_state: Random seed for reproducibility
        num_workers: Number of worker processes for data loading
        shuffle: Whether to shuffle the data

    Returns:
        A DataLoader containing bootstrapped samples
    """
    bootstrap_dataset = bootstrap_sample(dataset, sample_size, replace, random_state)

    return torch.utils.data.DataLoader(
        bootstrap_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


# EIG (BALD) Calculation
def get_scores(model, data_loader, n_mc_samples, criterion=None):
    """
    Calculates BALD scores for samples in data_loader.
    BALD = H(E[p(y|x,w)]) - E[H(p(y|x,w))]
    """
    model.train()  # Keep dropout active so each forward pass draws a fresh MC dropout sample.
    all_probs_mc = []  # To store [n_samples_dataset, n_mc_samples, n_classes]
    all_labels = []
        
    with torch.inference_mode():
        for data, labels in tqdm(data_loader, desc="Calculating BALD Scores"):
            data = data.to(DEVICE)
            batch_probs_mc = []  # Store MC samples for this batch [n_mc_samples, batch_size, n_classes]
            for _ in range(n_mc_samples):
                output = model(data)
                probs = F.softmax(output, dim=1)
                batch_probs_mc.append(probs.unsqueeze(0))

            batch_probs_mc = torch.cat(
                batch_probs_mc, dim=0
            )  # [n_mc_samples, batch_size, n_classes]
            all_probs_mc.append(
                batch_probs_mc.permute(1, 0, 2).cpu()
            )  # [batch_size, n_mc_samples, n_classes]
            all_labels.append(labels) # [batch_size]

    all_probs_mc_tensor = torch.cat(
        all_probs_mc, dim=0
    )  # [total_samples, n_mc_samples, n_classes]
    all_labels = torch.cat(all_labels, dim=0) # [total_samples]

    # Entropy of mean predictions: H(E[p(y|x,w)])
    mean_probs = torch.mean(all_probs_mc_tensor, dim=1)  # [total_samples, n_classes]
    ic_probs = mean_probs * torch.log(mean_probs)
    ic_probs[torch.isnan(ic_probs)] = 0.0
    entropy_of_mean = -torch.sum(ic_probs, dim=1)  # [total_samples]

    # Mean of entropy of predictions: E[H(p(y|x,w))]
    ic_probs_mc = all_probs_mc_tensor * torch.log(all_probs_mc_tensor)
    ic_probs_mc[torch.isnan(ic_probs_mc)] = 0.0
    entropy_per_mc_sample = -torch.sum(
        ic_probs_mc, dim=2
    )  # [total_samples, n_mc_samples]
    mean_of_entropy = torch.mean(entropy_per_mc_sample, dim=1)  # [total_samples]

    bald_scores = entropy_of_mean - mean_of_entropy
    
    acc = (mean_probs.argmax(dim=1) == all_labels).float().mean() * 100.0
    if criterion is not None:
        # mean_probs are already probabilities; use NLL on their log directly rather than
        # passing log-probabilities through CrossEntropyLoss (which would re-apply log_softmax).
        loss = F.nll_loss(mean_probs.log(), all_labels)
    else:
        loss = -1.0
    return bald_scores.numpy(), acc, loss


# Training function
def train_model(model, train_loader, optimizer, criterion, epochs):
    model.train()  # Set to train mode (enables dropout, batchnorm updates etc.)
    pbar = tqdm(range(epochs), desc="Training Epochs")
    for epoch in pbar:
        epoch_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
        pbar.set_postfix(loss=epoch_loss/len(train_loader), acc=100.*correct/total)


def plot_scores(scores):
    # Sort scores descending
    sorted_scores = -np.sort(-scores)
    # Plot the sorted scores
    plt.plot(sorted_scores)
    plt.show()


# --- Data Loading and Preparation ---
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load full training and test datasets
full_train_dataset = datasets.MNIST(
    "./data", train=True, download=True, transform=transform
)
# Subset the full training dataset to 10000 random samples
# Use numpy's default_rng with fixed seed for reproducibility
rng = np.random.default_rng(1)
indices = rng.permutation(len(full_train_dataset))[:10000]
full_train_dataset = torch.utils.data.Subset(full_train_dataset, indices)
all_train_loader_for_scores = DataLoader(
    full_train_dataset, batch_size=1024, shuffle=False
)

test_dataset = datasets.MNIST("./data", train=False, transform=transform)
test_loader = DataLoader(
    test_dataset, batch_size=1024, shuffle=False
)  # For final eval & test scores

# --- Active Learning Setup ---
N_INITIAL_LABELED = 60  # Number of initial labeled samples
N_ACQUIRE_PER_ITER = 1  # Number of samples to acquire each iteration
N_ACTIVE_LEARNING_ITERATIONS = 100
N_MC_SAMPLES_EIG = 32  # Number of MC samples for EIG
TRAIN_EPOCHS_PER_ITER = 3  # Epochs to train model at each AL iteration
LEARNING_RATE = 0.0005
BATCH_SIZE_TRAIN = 64  # Batch size for training
N_MIN_SAMPLES_PER_EPOCH = 10_000

# Initialize wandb
wandb.init(
    project="blog-active-learning-vs-filtering",
    config={
        "initial_labeled_samples": N_INITIAL_LABELED,
        "acquire_per_iter": N_ACQUIRE_PER_ITER,
        "al_iterations": N_ACTIVE_LEARNING_ITERATIONS,
        "mc_samples_eig": N_MC_SAMPLES_EIG,
        "epochs_per_iter": TRAIN_EPOCHS_PER_ITER,
        "learning_rate": LEARNING_RATE,
        "batch_size": BATCH_SIZE_TRAIN,
        "model": "LeNet-5 with MC Dropout",
        "dataset": "MNIST",
        "device": str(DEVICE),
    }
)

# Create initial labeled and unlabeled pools
num_train_samples = len(full_train_dataset)
all_indices = np.arange(num_train_samples)
np.random.shuffle(all_indices)

initial_labeled_indices = list(all_indices[:N_INITIAL_LABELED])
current_unlabeled_indices = list(all_indices[N_INITIAL_LABELED:])

# Store all acquisition scores
# Format: list of dicts, each dict is one AL iteration
# {'iteration': i, 'labeled_indices': [...], 'unlabeled_indices': [...],
#  'scores_for_all_train_samples': np.array([...]), 'scores_for_test_samples': np.array([...])}
all_scores_history = []


# --- Active Learning Loop ---
print("Starting Active Learning Loop...")
current_labeled_indices_set = set(initial_labeled_indices)

for al_iteration in tqdm(range(N_ACTIVE_LEARNING_ITERATIONS + 1), desc="Active Learning Iterations"):  # +1 for initial state and after last acquisition
    print(f"\n--- Active Learning Iteration: {al_iteration} ---")
    print(f"Currently labeled samples: {len(current_labeled_indices_set)}")

    # 1. Create current labeled dataset and loader
    labeled_subset = Subset(full_train_dataset, list(current_labeled_indices_set))
    labeled_loader = create_bootstrap_loader(
        labeled_subset,
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        random_state=RANDOM_SEED,
        sample_size=max(N_MIN_SAMPLES_PER_EPOCH, len(labeled_subset)),
        replace=N_MIN_SAMPLES_PER_EPOCH > len(labeled_subset),
    )

    # 2. Initialize or re-initialize model and optimizer
    model = LeNet().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    # 3. Train model on current labeled data (unless it's iteration 0 before any acquisition)
    if len(current_labeled_indices_set) > 0:  # Train if there's labeled data
        print("Training model...")
        train_model(
            model, labeled_loader, optimizer, criterion, epochs=TRAIN_EPOCHS_PER_ITER
        )
    
    # 4. Calculate EIG scores for all original training samples and test samples
    print("Calculating EIG scores for test samples...")
    scores_test_samples, test_acc, test_loss = get_scores(model, test_loader, N_MC_SAMPLES_EIG, criterion)
    print(f"Model trained. Test Acc: {test_acc:.2f}%, Test Loss: {test_loss:.4f}")
    
    print("Calculating EIG scores for all original training samples...")
    # Create a loader for ALL original training samples to get their scores
    scores_all_train_samples, _, _ = get_scores(
        model, all_train_loader_for_scores, N_MC_SAMPLES_EIG
    )
    
    
    # Plot the scores
    plt.figure(figsize=(10, 5))
    plt.title("BALD Scores for all original training samples")
    plt.ylabel("BALD Score")
    plot_scores(scores_all_train_samples)
    plt.show()
    
    plt.figure(figsize=(10, 5))
    plt.title("BALD Scores for all test samples")
    plt.ylabel("BALD Score")
    plot_scores(scores_test_samples)
    plt.show()
    
    # Store scores for this iteration
    iteration_scores_data = {
        "iteration": al_iteration,
        "current_labeled_indices": sorted(list(current_labeled_indices_set)),
        "current_unlabeled_indices": sorted(
            current_unlabeled_indices
        ),  # These are indices FROM the original full_train_dataset
        "scores_for_all_train_samples": scores_all_train_samples,  # Order matches full_train_dataset
        "scores_for_test_samples": scores_test_samples,  # Order matches test_dataset
    }
    all_scores_history.append(iteration_scores_data)
    print(f"Scores stored for iteration {al_iteration}.")
    
    # Log metrics to wandb
    wandb_log_data = {
        "iteration": al_iteration,
        "num_labeled_samples": len(current_labeled_indices_set),
        "test_accuracy": test_acc,
        "test_loss": test_loss,
        "avg_train_bald_score": np.mean(scores_all_train_samples),
        "avg_test_bald_score": np.mean(scores_test_samples),
        "max_train_bald_score": np.max(scores_all_train_samples),
        "max_test_bald_score": np.max(scores_test_samples),
    }
    
    # Log scalar metrics (commit deferred until the end of this iteration)
    wandb.log(wandb_log_data, commit=False)
    
    # Log full per-sample BALD score tables
    # Create a table with all the scores
    train_scores_table = wandb.Table(columns=["index", "bald_score"])
    for idx, score in enumerate(scores_all_train_samples):
        train_scores_table.add_data(idx, float(score))
    
    test_scores_table = wandb.Table(columns=["index", "bald_score"])
    for idx, score in enumerate(scores_test_samples):
        test_scores_table.add_data(idx, float(score))
        
    # Create plotly figures for BALD scores
    # Create sorted scores for better visualization
    sorted_train_scores = sorted(scores_all_train_samples, reverse=True)
    sorted_test_scores = sorted(scores_test_samples, reverse=True)
    
    # Create dataframes for plotly
    train_df = pd.DataFrame({
        "index": range(len(sorted_train_scores)),
        "bald_score": sorted_train_scores
    })
    
    test_df = pd.DataFrame({
        "index": range(len(sorted_test_scores)),
        "bald_score": sorted_test_scores
    })
    
    # Create plotly figures
    train_fig = px.line(train_df, x="index", y="bald_score", 
                        title=f"Sorted BALD Scores for Training Samples (Iteration {al_iteration})")
    train_fig.update_layout(xaxis_title="Sample Index (sorted)", yaxis_title="BALD Score")
    
    test_fig = px.line(test_df, x="index", y="bald_score", 
                       title=f"Sorted BALD Scores for Test Samples (Iteration {al_iteration})")
    test_fig.update_layout(xaxis_title="Sample Index (sorted)", yaxis_title="BALD Score")
    
    # Log plotly figures to wandb
    wandb.log({
        "train_bald_scores_plot": wandb.Plotly(train_fig),
        "test_bald_scores_plot": wandb.Plotly(test_fig),
        "train_bald_scores": train_scores_table,
        "test_bald_scores": test_scores_table,
    }, commit=False)

    # 5. If it's not the last iteration, perform acquisition
    if al_iteration < N_ACTIVE_LEARNING_ITERATIONS:
        if not current_unlabeled_indices:
            print("No more unlabeled samples to acquire. Stopping.")
            break

        print("Acquiring new samples...")
        # We need scores only for the *currently unlabeled* samples to decide which to pick
        # `scores_all_train_samples` contains scores for *all* original training samples.
        # We filter these down to only the unlabeled ones.

        unlabeled_scores_map = {
            idx: scores_all_train_samples[idx] for idx in current_unlabeled_indices
        }

        # Sort unlabeled samples by their EIG score (descending)
        sorted_unlabeled_by_score = sorted(
            unlabeled_scores_map.items(), key=lambda item: item[1], reverse=True
        )

        # Select top N_ACQUIRE_PER_ITER samples
        num_to_acquire = min(N_ACQUIRE_PER_ITER, len(sorted_unlabeled_by_score))
        acquired_indices_scores = sorted_unlabeled_by_score[:num_to_acquire]
        acquired_indices = [idx for idx, score in acquired_indices_scores]
        acquired_scores = [score for idx, score in acquired_indices_scores]

        if not acquired_indices:
            print(
                "Could not acquire any new samples (perhaps scores were uniform or list was empty). Stopping."
            )
            break

        print(f"Acquired {len(acquired_indices)} new samples ({np.mean(acquired_scores):.2f} avg score).")
        
        # Log acquisition information
        wandb.log({
            "acquisition_step": al_iteration,
            "acquired_indices": acquired_indices,
            "acquired_scores": acquired_scores,
            "avg_acquired_score": np.mean(acquired_scores) if acquired_scores else 0,
        }, commit=False)

        # Add to labeled set and remove from unlabeled pool
        for idx in acquired_indices:
            current_labeled_indices_set.add(idx)
            current_unlabeled_indices.remove(
                idx
            )  # Keep this list of original indices up to date
    else:
        print("Final iteration reached. No more acquisitions.")
        
    wandb.log({}, step=al_iteration, commit=True)

# Finish wandb run
wandb.finish()
Animation Code
#!/usr/bin/env python3
"""
Create an animation showing how BALD scores evolve during active learning.
This script retrieves data from wandb and creates a similar animation to the one in visualizations.py.
Uses BALD scores without sorting to show the actual distribution of uncertainty.
"""
#%%
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from pathlib import Path
import json
import pandas as pd
from tqdm.auto import tqdm
from wandb.apis.public import Api as WandbApi

# Initialize wandb API
api = WandbApi()

# Find the latest run in the project
runs = api.runs("blog-active-learning-vs-filtering", order="-created_at")
latest_run = runs[0]  # Most recent run

print(f"Analyzing run: {latest_run.name} ({latest_run.id})")

#%%
# Get data from the run - we need to extract BALD scores and acquisition info
scores_by_iter = {}
acquired_indices_by_iter = {}
train_subset_indices = {}  # To store which indices were already labeled

# Fetch the history data
print("Fetching history data...")
scores_key = "train_bald_scores"
history = latest_run.history(keys=["iteration", scores_key], pandas=False)
#%%
# Process the history data
for row in tqdm(history):
    if "iteration" not in row:
        continue
    
    iteration = row["iteration"]
    
    # Get BALD scores if available in this row
    if scores_key in row:
        try:
            # Try to get scores from the run table
            table_data = row[scores_key]
            # Load the artifact
            table_file = latest_run.file(table_data['path']).download(replace=True).read()
            table_data = json.loads(table_file)
            df = pd.DataFrame(table_data["data"], columns=table_data["columns"])
            scores_by_iter[iteration] = df["bald_score"].values
        except Exception as e:
            print(f"Error extracting {scores_key} for iteration {iteration}: {e}")
    
    # Get acquired indices if available
    if "acquired_indices" in row:
        acquired_indices_by_iter[iteration] = row["acquired_indices"]
    
    # Get currently labeled indices
    if "current_labeled_indices" in row:
        train_subset_indices[iteration] = row["current_labeled_indices"]

#%%
# If we still don't have data, exit
if not scores_by_iter:
    print(f"Failed to retrieve {scores_key} from the run. Exiting.")
    raise ValueError("Failed to retrieve BALD scores from the run.")

#%%
# Sort iterations and prepare animation data
sorted_iterations = sorted(scores_by_iter.keys())
all_scores = [scores_by_iter[i].copy() for i in sorted_iterations]

acquired_indices = []
for scores in all_scores:
    for i in acquired_indices:
        scores[i] = 0.
    # Get the argmax score
    max_idx = np.argmax(scores)
    acquired_indices.append(max_idx.item())
    
# Pop last acquired index
acquired_indices.pop()

assert np.all(all_scores[-1][acquired_indices] == 0.)

print(f"Acquired indices: {acquired_indices}")

print(f"Found {scores_key} for {len(all_scores)} iterations")
print(f"First iteration has {len(all_scores[0])} samples")

#%%
# Create the animation with XKCD style but using EMA of the scores
ema_decay = 0.9

all_scores_ema = []
current_scores = all_scores[0].copy()
for i in range(len(all_scores)):
    current_scores = ema_decay * current_scores + (1.0 - ema_decay) * all_scores[i]
    for j in range(i):
        current_scores[acquired_indices[j]] = 0.
    all_scores_ema.append(current_scores)

#%%
with plt.xkcd():
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Set up initial plot - we'll use scatter plot for non-sorted scores
    max_points = min(10000, max([len(scores) for scores in all_scores_ema]))
    scatter = ax.scatter([], [], color="black", s=2)
    
    # For highlighted points we'll use a different color
    highlight_scatter = ax.scatter([], [], color="blue", s=100, alpha=0.5)
    
    # Title and labels
    ax.set_title("BALD Score Evolution (Non-sorted w/ EMA)")
    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Informativeness (BALD score)")
    ax.set_xlim(0, max_points)
    
    # Find max BALD score for consistent y-axis scaling
    max_score = max([max(scores) for scores in all_scores_ema])
    ax.set_ylim(0, max_score * 1.1)
    
    # Text annotations
    iter_text = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=12)
    labeled_text = ax.text(0.02, 0.90, '', transform=ax.transAxes, fontsize=10)
    
    # Get the sorted indices of the first iteration
    sorted_indices = np.argsort(all_scores[0])[::-1]
    inv_sorted_indices = np.argsort(sorted_indices)
        
    def init():
        # Use the raw scores (no sorting)
        current_scores = np.array(all_scores_ema[0])
        x_data = np.arange(len(current_scores))
                
        scatter.set_offsets(np.c_[x_data, current_scores])
        
        # Highlight top uncertain points
        highlight_scatter.set_offsets(np.c_[[], []])
        
        # Set iteration text
        iter_text.set_text(f"Iteration: {sorted_iterations[0]}")
        
        return scatter, highlight_scatter, iter_text, labeled_text
    
    def animate(frame_idx):
        if frame_idx < len(sorted_iterations):
            iteration = sorted_iterations[frame_idx]
            # Compute an EMA of the past scores
            current_scores = np.array(all_scores_ema[frame_idx])
            
            # Use raw scores without sorting
            x_data = inv_sorted_indices[np.arange(len(current_scores))]
            
            # Update main scatter plot
            scatter.set_offsets(np.c_[x_data, current_scores.copy()])
            
            # Highlight top uncertain points (these will change each iteration)
            # top_indices = acquired_indices[:frame_idx]
            # highlight_scatter.set_offsets(np.c_[inv_sorted_indices[top_indices], current_scores[top_indices].copy()])
            # assert np.all(current_scores[top_indices[:1]] == 0.), current_scores[top_indices]
            # current_scores[top_indices] = 0.
            
            # Update text
            iter_text.set_text(f"Iteration: {iteration}")
            
        return scatter, highlight_scatter, iter_text, labeled_text

    # Create animation
    num_frames = len(sorted_iterations)
    anim = FuncAnimation(fig, animate, init_func=init,
                         frames=num_frames, interval=50, blit=True)

    # Add legend
    ax.legend()
    
    # Save animation
    writer = PillowWriter(fps=1000/50)
    output_path = Path("bald_scores_ema_animation.gif")
    print(f"Saving animation to {output_path}")
    anim.save(output_path, writer=writer)

print("Done!") 

# %%