Have you ever wondered why we treat data selection differently during training versus before training?
This post explores the fundamental distinction between active learning and data filtering: two approaches that seem similar on the surface but operate quite differently in practice. By understanding their differences, we can make better choices when designing data-efficient learning pipelines.
Let’s start with a simple question: What makes these approaches distinct? The key lies in two aspects:
- when we make selection decisions (operational context); and
- how we use informativeness scores.
This distinction leads to some fascinating implications for how we design and use these methods in practice. After reading this post, you’ll understand why these approaches evolved differently and how their distinctions arise naturally from the underlying mathematics of information theory.
The Core Distinction: Timing and Direction
The operational context, that is when we make selection decisions, fundamentally shapes how these methods work:
- Active Learning/Sampling: Makes online/on-policy selections during training
- Data Filtering: Makes offline/off-policy selections before training begins
This timing difference naturally leads to different ways of using informativeness scores:
- Active Learning/Sampling: Selects the most informative samples
- Data Filtering: Removes the least informative samples
While selecting or removing might seem equivalent mathematically (accepting 10% is the same as rejecting 90%), the operational context leads to different practical approaches and approximations.
But why did these methods evolve this way?
In short, the answer lies in how informativeness behaves as data accumulate: namely, it is (approximately) submodular.
Informativeness and Submodularity
To see why selection and removal differ, we first need to understand the properties of “informativeness” because data selection approaches fundamentally rely on approximating the answer to the question: how much will training a model on a specific sample benefit its performance?
Crucially, informativeness often exhibits submodularity: the marginal gain of adding a new data point decreases as more data points are included. Each new data point we add to our training set provides less additional information than the previous one. This happens because new samples often contain information that overlaps with what we’ve already learned from previous samples. As the training set grows, new samples offer diminishing returns in terms of informativeness.
This isn’t just an intuitive concept; it’s mathematically well-grounded. Submodularity is a well-understood property (see e.g. Nemhauser et al., 1978), and the expected information gain, which provides a natural measure of informativeness, is submodular or at least approximately submodular in many cases (Das and Kempe, 2018). (In the case of the parameter-based expected information gain, it is always submodular.)
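To make the diminishing-returns behavior concrete, here is a tiny, self-contained toy sketch (illustrative only, not part of the experiments below). It uses a set-coverage function, a textbook example of a submodular function, as a stand-in for informativeness and greedily adds samples:

import numpy as np

rng = np.random.default_rng(0)

# Toy "informativeness": each sample covers a random subset of 100 concepts;
# the value of a set of samples is how many distinct concepts it covers.
samples = [set(rng.choice(100, size=20, replace=False)) for _ in range(50)]

def coverage(selected):
    return len(set().union(*selected)) if selected else 0

# Greedily add samples and record the marginal gain of each addition.
selected, gains = [], []
remaining = list(samples)
for _ in range(10):
    best = max(remaining, key=lambda s: coverage(selected + [s]) - coverage(selected))
    gains.append(coverage(selected + [best]) - coverage(selected))
    selected.append(best)
    remaining.remove(best)

print(gains)  # the marginal gains never increase as the selected set grows

The printed marginal gains are non-increasing, which is exactly the diminishing-returns property that both active learning and data filtering lean on.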
But why does this matter for active learning versus filtering? Let’s break it down:
For Active Learning: Submodularity tells us we need to constantly re-evaluate which samples are most informative. What was highly informative previously might not be informative anymore, especially after we’ve trained on similar samples.
For Data Filtering: Submodularity gives us confidence that samples deemed uninformative at the start will likely remain so throughout training. This makes early filtering a safe operation.
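Schematically, the two regimes look something like the following sketch. The helpers here (informativeness, train_step, and a proxy_model for offline scoring) are placeholders for whatever scoring and training routines you actually use, not real APIs from the code later in this post.

# Hypothetical helpers: informativeness(), train_step(), and the `pool` of candidate
# samples are placeholders for illustration, not functions from the experiment code below.

def active_learning_loop(model, pool, n_steps, k=1):
    """Online/on-policy: re-score the pool at every step and train on the current top-k."""
    for _ in range(n_steps):
        scores = {x: informativeness(model, x) for x in pool}   # re-evaluated each step
        batch = sorted(pool, key=scores.get, reverse=True)[:k]  # select the most informative
        for x in batch:
            pool.remove(x)
        model = train_step(model, batch)
    return model


def data_filtering(proxy_model, pool, threshold):
    """Offline/off-policy: score once up front and drop the least informative samples."""
    scores = {x: informativeness(proxy_model, x) for x in pool}  # scored once, before training
    return [x for x in pool if scores[x] >= threshold]           # training then runs on the kept set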
Active Learning: Iterative Selection
Active learning methods operate iteratively, computing informativeness scores in an online manner, typically at each training iteration. At each step, the samples that currently appear most informative for the model’s learning process are selected for training.
Some active learning strategies, like BatchBALD, take this a step further. They explicitly use submodularity to efficiently select batches of samples. But whether explicit or implicit, the key insight remains: we need to continuously reassess sample importance as our model learns and evolves.
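In sketch form, such a greedy batch construction looks roughly like the snippet below. Here joint_score is a placeholder for a batch-level acquisition score (for BatchBALD, the joint mutual information between the batch’s labels and the model parameters); the greedy loop is the standard way to approximately maximize a monotone submodular score, with the classic (1 − 1/e) guarantee from Nemhauser et al.

# `joint_score(batch)` is a placeholder for an (approximately) submodular batch score,
# e.g. the joint information gain of the candidate batch.
def greedy_batch(pool, batch_size, joint_score):
    batch = []
    for _ in range(batch_size):
        # Pick the candidate with the largest *marginal* gain given the batch so far.
        best = max(
            (x for x in pool if x not in batch),
            key=lambda x: joint_score(batch + [x]) - joint_score(batch),
        )
        batch.append(best)
    return batch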
Data Filtering: Early Rejection
What makes data filtering fundamentally different? It’s all about timing.
Data filtering takes a different approach by evaluating informativeness offline, before training begins. The goal is to identify and remove less informative samples to streamline the training process. But how can we be confident in making such early decisions?
Again, submodularity provides the answer. With (approximate) submodularity, we can be reasonably confident that samples deemed uninformative at the start won’t suddenly become highly informative later. This property gives us the theoretical backing to make these early filtering decisions.
The Bayesian Perspective: A Unifying Framework
How do these approaches connect to Bayesian methods? Let’s explore this connection to gain deeper insights.
Both active learning and data filtering can be understood through the lens of information theory. Active learning methods like BALD (Houlsby et al., 2011) select samples that maximize the mutual information between model parameters and predictions. Data filtering, meanwhile, can be viewed as removing samples that would contribute minimally to the posterior update.
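As a reference point, here is a small numpy sketch of the BALD estimator (a simplified, framework-free cousin of the get_scores function in the appendix): given Monte Carlo samples of the predictive distribution, BALD is the entropy of the mean prediction minus the mean entropy of the individual predictions.

import numpy as np

def bald_scores(probs_mc):
    """probs_mc: array of shape [n_samples, n_mc, n_classes] with MC-dropout predictive probabilities.

    BALD = H(E_w[p(y|x,w)]) - E_w[H(p(y|x,w))], i.e. the (estimated) mutual information
    between the label y and the model parameters w.
    """
    eps = 1e-12
    mean_probs = probs_mc.mean(axis=1)                                          # E_w[p(y|x,w)]
    entropy_of_mean = -np.sum(mean_probs * np.log(mean_probs + eps), axis=-1)   # H(E_w[p])
    mean_of_entropy = -np.sum(probs_mc * np.log(probs_mc + eps), axis=-1).mean(axis=1)  # E_w[H(p)]
    return entropy_of_mean - mean_of_entropy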
But here’s what’s fascinating: many non-Bayesian methods can be interpreted as approximations of these Bayesian approaches (Kirsch et al., 2022). This suggests a deeper unity underlying these seemingly different methods.
A Practical Example: MNIST Active Learning
Let’s make this concrete with a real example. We conducted an active learning experiment on MNIST using a LeNet-5 model with Monte Carlo dropout. The experiment iteratively selects the most informative samples based on BALD scores and trains the model on the acquired samples.
The animation in Figure 3 shows the evolution of BALD scores of the training set using an Exponential Moving Average (EMA) over iterations. While the EMA smooths out noise in the informativeness estimates and is thus only illustrative of the true underlying dynamics, it’s a reasonable practical choice for visualizing trends. In real applications, one would want to use better estimators of the EIG scores or many more samples.
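(For reference, the smoothing is just a standard exponential moving average over the per-iteration score arrays, along the lines of this small sketch; the decay of 0.9 matches what the animation code in the appendix uses.)

import numpy as np

def ema_over_iterations(score_history, decay=0.9):
    """score_history: list of per-sample score arrays, one array per training iteration."""
    smoothed = []
    ema = np.asarray(score_history[0], dtype=float).copy()
    for scores in score_history:
        ema = decay * ema + (1.0 - decay) * np.asarray(scores, dtype=float)
        smoothed.append(ema.copy())
    return smoothed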
Conclusion
What have we learned about the relationship between active learning and data filtering? Let’s summarize the key insights:
- The timing of selection decisions fundamentally shapes how these methods work
- Submodularity provides theoretical backing for both approaches
- Real-world applications often require balancing theoretical ideals with practical constraints
While submodularity typically holds, cases where a sample’s informativeness can increase under certain conditions remain somewhat unexplored for filtering (at least in a principled fashion, as far as I know). Identifying and analyzing these scenarios presents an exciting direction for future research, potentially refining both active learning and data filtering methodologies further.
By recognizing these differences and the underlying principle of (approximate) submodularity, we can make more informed choices about what approximations to use and what heuristics to employ.
There’s still much to explore about cases where sample informativeness behaves in unexpected ways. This presents exciting opportunities for future research, particularly in:
- Understanding when and why informativeness patterns deviate from theory
- Developing more robust methods that account for these deviations
- Creating hybrid approaches that combine the strengths of both methods
Appendix
Conceptual Illustrations
The figures above illustrate the key concepts we’ve discussed. Let me explain how they were created:
Figure 1 shows a conceptual visualization of the informativeness curve, where samples are sorted by their informativeness (y-axis) along the x-axis. This curve demonstrates the diminishing returns property characteristic of submodularity. The visualization highlights:
- The data filtering region (red shaded area) shows samples that would be filtered out because their informativeness falls below a threshold.
- The active learning region (blue shaded area) represents samples that would be prioritized by active learning approaches because of their high informativeness.
Figure 2 is an animation that simulates how active learning works over time:
- We start with a distribution of samples with varying informativeness levels.
- At each step, the most informative sample (highest point) is selected for training.
- After selection, that sample’s informativeness drops to zero (it’s been “used”).
- The informativeness of other samples also decreases (with some random noise), reflecting how the value of remaining samples changes as the model learns.
This animation captures the dynamic, iterative nature of active learning, where the most informative sample is constantly changing as training progresses. It visually demonstrates why we need to re-evaluate informativeness at each step rather than making all selection decisions upfront.
Both visualizations were created using matplotlib with an XKCD-style aesthetic to make the concepts more approachable and intuitive.
Illustration Code
"""Generate an XKCD-style illustration of submodularity and how active learning vs.
data filtering interact with the informativeness curve.
The script produces both static PNG/SVG files and an animated GIF showing how
informativeness changes over time as data gets trained on.
Usage:
python scripts/submodularity_curve.py # default path
python scripts/submodularity_curve.py --out /tmp/figure.png
"""
# %%
from __future__ import annotations
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter
# %%
# Synthetic submodular-like curve: diminishing returns.
x = np.linspace(1, 10000, 10000)
y_base = 1 / (0.9 / 99.9 * x + (1 - 0.9 / 99.9))

y_filtering_threshold = 0.05  # filtering threshold: keep samples with y >= threshold.
y_active_threshold = 0.2

filtering_color = "red"
active_color = "blue"
# First create the static plots
with plt.xkcd():
    fig, ax = plt.subplots(figsize=(8, 5))

    # Plot the diminishing-returns curve.
    ax.plot(x, y_base, color="black")

    # Shaded region for samples that would be *filtered out* (y below threshold).
    ax.fill_between(x, 0, y_base, where=y_base < y_filtering_threshold, color=filtering_color, alpha=0.1,
                    label="Data filtering region (offline)")

    # Horizontal dashed line showing the filtering threshold.
    ax.axhline(y_filtering_threshold, color=filtering_color, linestyle="--", linewidth=1)

    # Vertical dashed line showing the active learning region.
    ax.axvline(x[np.argmax(y_base <= y_active_threshold)], color=active_color, linestyle="--", linewidth=1)

    # Shaded vertical region for samples that would be considered for active learning.
    ax.fill_betweenx(y_base, 0, x, where=y_base > y_active_threshold, color=active_color, alpha=0.1,
                     label="Active learning region (online)")

    # Styling.
    ax.set_xlabel("Samples sorted by informativeness")
    ax.set_ylabel("Informativeness")
    ax.set_title("Active Selection vs. Data Filtering")
    ax.set_ylim(0, 1.05)
    ax.set_xlim(0, x.max())
    ax.legend()
    ax.grid(False)

    out_path = Path("active_vs_filtering_xkcd.png")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")

    out_path = Path("active_vs_filtering_xkcd.svg")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")

    plt.show()

# %%
# Now create the animated version
n_frames = 40
top_k = 1

current_y = y_base.copy()

with plt.xkcd():
    fig_anim, ax_anim = plt.subplots(figsize=(8, 5))
    scatter = ax_anim.scatter([], [], color="black", s=1)
    selected_scatter = ax_anim.scatter([], [], color=active_color, s=100, alpha=0.1)

    # Horizontal dashed line showing the filtering threshold.
    ax_anim.axhline(y_filtering_threshold, color=filtering_color, linestyle="--", linewidth=1)

    # Styling
    ax_anim.set_xlabel("Samples sorted by informativeness")
    ax_anim.set_ylabel("Informativeness")
    ax_anim.set_title("Simulation of Top-1 Active Selection")
    ax_anim.set_ylim(0, 1.05)
    ax_anim.set_xlim(0, 1000)
    ax_anim.grid(False)

    def init():
        scatter.set_offsets(np.c_[x, y_base])
        return (scatter, selected_scatter)

    def animate(frame):
        print(frame)
        global current_y
        # Add noise to the decay
        if frame > 1:
            top_k_indices = np.argsort(current_y)[-top_k:]
            current_y[top_k_indices] = 0.0
            # Decay the other points
            noise = np.clip(np.random.gumbel(0, 0.1, len(current_y)), 0, 1)
            current_y = current_y * (1 - noise)

        # Find the current top-k points
        top_k_indices = np.argsort(current_y)[-top_k:]
        # Highlight the current top-k points in the selected scatter
        selected_scatter.set_offsets(np.c_[x[top_k_indices].copy(), current_y[top_k_indices].copy()])
        scatter.set_offsets(np.c_[x, current_y])
        return (scatter, selected_scatter)

    anim = FuncAnimation(fig_anim, animate, init_func=init,
                         frames=n_frames, interval=200, blit=False, save_count=40)

    fig_anim.tight_layout()
    writer = PillowWriter(fps=5)
    out_path = Path("active_selection_animation.gif")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    anim.save(out_path, writer=writer)
# %%
MNIST Active Learning Experiment
In our MNIST active learning experiment, we used a LeNet-5 model with Monte Carlo dropout to iteratively select the most informative samples based on BALD scores. The experiment started with a small initial labeled set and, at each iteration, acquired the most informative unlabeled samples for training. The BALD scores were calculated using multiple Monte Carlo samples to estimate the mutual information between model parameters and predictions.
The animation in Figure 3 visualizes the evolution of BALD scores over iterations. It shows how the informativeness of samples changes as the model learns, with the most informative samples being selected and their scores dropping to zero after acquisition. This dynamic process highlights the importance of re-evaluating sample informativeness during training.
The experiment was implemented in mnist_al_experiment.py, which handles the model training, BALD score calculation, and sample acquisition. The animation was created using bald_animation.py, which retrieves data from Weights & Biases (wandb) and generates the visualization of BALD score evolution.
Experiment Code
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm
import wandb
# --- Configuration ---
DEVICE = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# MNIST Hyperparameters
N_CLASSES = 10
IMG_WIDTH = 28
IMG_HEIGHT = 28
N_CHANNELS = 1
# LeNet Model Definition (LeNet-5 architecture with MC Dropout)
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = nn.Dropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = nn.Dropout()
        self.fc2 = nn.Linear(128, N_CLASSES)

    def forward(self, input: torch.Tensor):
        input = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        input = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(input)), 2))
        input = input.view(-1, 1024)
        input = F.relu(self.fc1_drop(self.fc1(input)))
        input = self.fc2(input)
        return input
def bootstrap_sample(dataset, sample_size=None, replace=True, random_state=None):
    """
    Creates a bootstrap sample from the dataset.
    Args:
        dataset: PyTorch dataset to sample from
        sample_size: Size of the bootstrap sample (defaults to len(dataset))
        replace: Whether to sample with replacement (True for bootstrap)
        random_state: Random seed for reproducibility
    Returns:
        A PyTorch Subset containing the bootstrapped samples
    """
    if sample_size is None:
        sample_size = len(dataset)

    rng = np.random.default_rng(random_state)

    # Generate indices with replacement
    indices = rng.choice(len(dataset), size=sample_size, replace=replace)

    # Return a subset of the dataset with the selected indices
    return torch.utils.data.Subset(dataset, indices)
def create_bootstrap_loader(
    dataset,
    batch_size=64,
    sample_size=None,
    replace=True,
    random_state=None,
    num_workers=0,
    shuffle=True,
):
    """
    Creates a DataLoader with bootstrapped samples from the dataset.
    Args:
        dataset: PyTorch dataset to sample from
        batch_size: Batch size for the DataLoader
        sample_size: Size of the bootstrap sample (defaults to len(dataset))
        replace: Whether to sample with replacement (True for bootstrap)
        random_state: Random seed for reproducibility
        num_workers: Number of worker processes for data loading
        shuffle: Whether to shuffle the data
    Returns:
        A DataLoader containing bootstrapped samples
    """
    bootstrap_dataset = bootstrap_sample(dataset, sample_size, replace, random_state)

    return torch.utils.data.DataLoader(
        bootstrap_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
# EIG (BALD) Calculation
def get_scores(model, data_loader, n_mc_samples, criterion=None):
    """
    Calculates BALD scores for samples in data_loader.
    BALD = H(E[p(y|x,w)]) - E[H(p(y|x,w))]
    """
    model.train()
    all_probs_mc = []  # To store [n_samples_dataset, n_mc_samples, n_classes]
    all_labels = []

    with torch.inference_mode():
        for data, labels in tqdm(data_loader, desc="Calculating BALD Scores"):
            data = data.to(DEVICE)
            batch_probs_mc = []  # Store MC samples for this batch [n_mc_samples, batch_size, n_classes]
            for _ in range(n_mc_samples):
                output = model(data)
                probs = F.softmax(output, dim=1)
                batch_probs_mc.append(probs.unsqueeze(0))

            batch_probs_mc = torch.cat(
                batch_probs_mc, dim=0
            )  # [n_mc_samples, batch_size, n_classes]
            all_probs_mc.append(
                batch_probs_mc.permute(1, 0, 2).cpu()
            )  # [batch_size, n_mc_samples, n_classes]
            all_labels.append(labels)  # [batch_size]

        all_probs_mc_tensor = torch.cat(
            all_probs_mc, dim=0
        )  # [total_samples, n_mc_samples, n_classes]
        all_labels = torch.cat(all_labels, dim=0)  # [total_samples]

        # Entropy of mean predictions: H(E[p(y|x,w)])
        mean_probs = torch.mean(all_probs_mc_tensor, dim=1)  # [total_samples, n_classes]
        ic_probs = mean_probs * torch.log(mean_probs)
        ic_probs[torch.isnan(ic_probs)] = 0.0
        entropy_of_mean = -torch.sum(ic_probs, dim=1)  # [total_samples]

        # Mean of entropy of predictions: E[H(p(y|x,w))]
        ic_probs_mc = all_probs_mc_tensor * torch.log(all_probs_mc_tensor)
        ic_probs_mc[torch.isnan(ic_probs_mc)] = 0.0
        entropy_per_mc_sample = -torch.sum(
            ic_probs_mc, dim=2
        )  # [total_samples, n_mc_samples]
        mean_of_entropy = torch.mean(entropy_per_mc_sample, dim=1)  # [total_samples]

        bald_scores = entropy_of_mean - mean_of_entropy

        acc = (mean_probs.argmax(dim=1) == all_labels).float().mean() * 100.0
        if criterion is not None:
            loss = criterion(mean_probs.log(), all_labels)
        else:
            loss = -1.0
        return bald_scores.numpy(), acc, loss
# Training function
def train_model(model, train_loader, optimizer, criterion, epochs):
    # Set to train mode (enables dropout, batchnorm updates etc.)
    model.train()
    pbar = tqdm(range(epochs), desc="Training Epochs")
    for epoch in pbar:
        epoch_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(DEVICE), target.to(DEVICE)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

        pbar.set_postfix(loss=epoch_loss / len(train_loader), acc=100. * correct / total)


def plot_scores(scores):
    # Sort scores descending
    sorted_scores = -np.sort(-scores)
    # Plot the sorted scores
    plt.plot(sorted_scores)
    plt.show()
# --- Data Loading and Preparation ---
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load full training and test datasets
full_train_dataset = datasets.MNIST(
    "./data", train=True, download=True, transform=transform
)
# Subset the full training dataset to 10000 random samples
# Use numpy's default_rng with fixed seed for reproducibility
rng = np.random.default_rng(1)
indices = rng.permutation(len(full_train_dataset))[:10000]
full_train_dataset = torch.utils.data.Subset(full_train_dataset, indices)
all_train_loader_for_scores = DataLoader(
    full_train_dataset, batch_size=1024, shuffle=False
)

test_dataset = datasets.MNIST("./data", train=False, transform=transform)
test_loader = DataLoader(
    test_dataset, batch_size=1024, shuffle=False
)  # For final eval & test scores
# --- Active Learning Setup ---
N_INITIAL_LABELED = 60  # Number of initial labeled samples
N_ACQUIRE_PER_ITER = 1  # Number of samples to acquire each iteration
N_ACTIVE_LEARNING_ITERATIONS = 100
N_MC_SAMPLES_EIG = 32  # Number of MC samples for EIG
TRAIN_EPOCHS_PER_ITER = 3  # Epochs to train model at each AL iteration
LEARNING_RATE = 0.0005
BATCH_SIZE_TRAIN = 64  # Batch size for training
N_MIN_SAMPLES_PER_EPOCH = 10_000

# Initialize wandb
wandb.init(
    project="blog-active-learning-vs-filtering",
    config={
        "initial_labeled_samples": N_INITIAL_LABELED,
        "acquire_per_iter": N_ACQUIRE_PER_ITER,
        "al_iterations": N_ACTIVE_LEARNING_ITERATIONS,
        "mc_samples_eig": N_MC_SAMPLES_EIG,
        "epochs_per_iter": TRAIN_EPOCHS_PER_ITER,
        "learning_rate": LEARNING_RATE,
        "batch_size": BATCH_SIZE_TRAIN,
        "model": "LeNet-5 with MC Dropout",
        "dataset": "MNIST",
        "device": str(DEVICE),
    },
)

# Create initial labeled and unlabeled pools
num_train_samples = len(full_train_dataset)
all_indices = np.arange(num_train_samples)
np.random.shuffle(all_indices)

initial_labeled_indices = list(all_indices[:N_INITIAL_LABELED])
current_unlabeled_indices = list(all_indices[N_INITIAL_LABELED:])

# Store all acquisition scores
# Format: list of dicts, each dict is one AL iteration
# {'iteration': i, 'labeled_indices': [...], 'unlabeled_indices': [...],
#  'scores_for_all_train_samples': np.array([...]), 'scores_for_test_samples': np.array([...])}
all_scores_history = []
# --- Active Learning Loop ---
print("Starting Active Learning Loop...")
current_labeled_indices_set = set(initial_labeled_indices)

for al_iteration in tqdm(range(N_ACTIVE_LEARNING_ITERATIONS + 1), desc="Active Learning Iterations"):  # +1 for initial state and after last acquisition
    print(f"\n--- Active Learning Iteration: {al_iteration} ---")
    print(f"Currently labeled samples: {len(current_labeled_indices_set)}")

    # 1. Create current labeled dataset and loader
    labeled_subset = Subset(full_train_dataset, list(current_labeled_indices_set))
    labeled_loader = create_bootstrap_loader(
        labeled_subset,
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        random_state=RANDOM_SEED,
        sample_size=max(N_MIN_SAMPLES_PER_EPOCH, len(labeled_subset)),
        replace=N_MIN_SAMPLES_PER_EPOCH > len(labeled_subset),
    )

    # 2. Initialize or re-initialize model and optimizer
    model = LeNet().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    # 3. Train model on current labeled data (unless it's iteration 0 before any acquisition)
    if N_INITIAL_LABELED > 0:  # Train if there's data
        print("Training model...")
        train_model(
            model, labeled_loader, optimizer, criterion, epochs=TRAIN_EPOCHS_PER_ITER
        )
    # 4. Calculate EIG scores for all original training samples and test samples
    print("Calculating EIG scores for test samples...")
    scores_test_samples, test_acc, test_loss = get_scores(model, test_loader, N_MC_SAMPLES_EIG, criterion)
    print(f"Model trained. Test Acc: {test_acc:.2f}%, Test Loss: {test_loss:.4f}")

    print("Calculating EIG scores for all original training samples...")
    # Create a loader for ALL original training samples to get their scores
    scores_all_train_samples, _, _ = get_scores(
        model, all_train_loader_for_scores, N_MC_SAMPLES_EIG
    )

    # Plot the scores
    plt.figure(figsize=(10, 5))
    plt.title("BALD Scores for all original training samples")
    plt.ylabel("BALD Score")
    plot_scores(scores_all_train_samples)
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.title("BALD Scores for all test samples")
    plt.ylabel("BALD Score")
    plot_scores(scores_test_samples)
    plt.show()

    # Store scores for this iteration
    iteration_scores_data = {
        "iteration": al_iteration,
        "current_labeled_indices": sorted(list(current_labeled_indices_set)),
        "current_unlabeled_indices": sorted(
            current_unlabeled_indices
        ),  # These are indices FROM the original full_train_dataset
        "scores_for_all_train_samples": scores_all_train_samples,  # Order matches full_train_dataset
        "scores_for_test_samples": scores_test_samples,  # Order matches test_dataset
    }
    all_scores_history.append(iteration_scores_data)
    print(f"Scores stored for iteration {al_iteration}.")

    # Log metrics to wandb
    wandb_log_data = {
        "iteration": al_iteration,
        "num_labeled_samples": len(current_labeled_indices_set),
        "test_accuracy": test_acc,
        "test_loss": test_loss,
        "avg_train_bald_score": np.mean(scores_all_train_samples),
        "avg_test_bald_score": np.mean(scores_test_samples),
        "max_train_bald_score": np.max(scores_all_train_samples),
        "max_test_bald_score": np.max(scores_test_samples),
    }
    wandb.log(wandb_log_data, commit=False)
    # Log histograms of BALD scores (optional)
    # Create a table with all the scores
    train_scores_table = wandb.Table(columns=["index", "bald_score"])
    for idx, score in enumerate(scores_all_train_samples):
        train_scores_table.add_data(idx, float(score))

    test_scores_table = wandb.Table(columns=["index", "bald_score"])
    for idx, score in enumerate(scores_test_samples):
        test_scores_table.add_data(idx, float(score))

    # Create plotly figures for BALD scores
    # Create sorted scores for better visualization
    sorted_train_scores = sorted(scores_all_train_samples, reverse=True)
    sorted_test_scores = sorted(scores_test_samples, reverse=True)

    # Create dataframes for plotly
    train_df = pd.DataFrame({
        "index": range(len(sorted_train_scores)),
        "bald_score": sorted_train_scores
    })
    test_df = pd.DataFrame({
        "index": range(len(sorted_test_scores)),
        "bald_score": sorted_test_scores
    })

    # Create plotly figures
    train_fig = px.line(train_df, x="index", y="bald_score",
                        title=f"Sorted BALD Scores for Training Samples (Iteration {al_iteration})")
    train_fig.update_layout(xaxis_title="Sample Index (sorted)", yaxis_title="BALD Score")

    test_fig = px.line(test_df, x="index", y="bald_score",
                       title=f"Sorted BALD Scores for Test Samples (Iteration {al_iteration})")
    test_fig.update_layout(xaxis_title="Sample Index (sorted)", yaxis_title="BALD Score")

    # Log plotly figures to wandb
    wandb.log({"train_bald_scores_plot": wandb.Plotly(train_fig),
               "test_bald_scores_plot": wandb.Plotly(test_fig),
               "train_bald_scores": train_scores_table,
               "test_bald_scores": test_scores_table,
               }, commit=False)
    # 5. If it's not the last iteration, perform acquisition
    if al_iteration < N_ACTIVE_LEARNING_ITERATIONS:
        if not current_unlabeled_indices:
            print("No more unlabeled samples to acquire. Stopping.")
            break

        print("Acquiring new samples...")
        # We need scores only for the *currently unlabeled* samples to decide which to pick
        # `scores_all_train_samples` contains scores for *all* original training samples.
        # We filter these down to only the unlabeled ones.
        unlabeled_scores_map = {
            idx: scores_all_train_samples[idx]
            for idx in current_unlabeled_indices
        }

        # Sort unlabeled samples by their EIG score (descending)
        sorted_unlabeled_by_score = sorted(
            unlabeled_scores_map.items(), key=lambda item: item[1], reverse=True
        )

        # Select top N_ACQUIRE_PER_ITER samples
        num_to_acquire = min(N_ACQUIRE_PER_ITER, len(sorted_unlabeled_by_score))
        acquired_indices_scores = sorted_unlabeled_by_score[:num_to_acquire]
        acquired_indices = [idx for idx, score in acquired_indices_scores]
        acquired_scores = [score for idx, score in acquired_indices_scores]

        if not acquired_indices:
            print(
                "Could not acquire any new samples (perhaps scores were uniform or list was empty). Stopping."
            )
            break

        print(f"Acquired {len(acquired_indices)} new samples ({np.mean(acquired_scores):.2f} avg score).")

        # Log acquisition information
        wandb.log({"acquisition_step": al_iteration,
                   "acquired_indices": acquired_indices,
                   "acquired_scores": acquired_scores,
                   "avg_acquired_score": np.mean(acquired_scores) if acquired_scores else 0,
                   }, commit=False)

        # Add to labeled set and remove from unlabeled pool
        for idx in acquired_indices:
            current_labeled_indices_set.add(idx)
            current_unlabeled_indices.remove(
                idx
            )  # Keep this list of original indices up to date
    else:
        print("Final iteration reached. No more acquisitions.")

    wandb.log({}, step=al_iteration, commit=True)

# Finish wandb run
wandb.finish()
Animation Code
#!/usr/bin/env python3
"""
Create an animation showing how BALD scores evolve during active learning.
This script retrieves data from wandb and creates a similar animation to the one in visualizations.py.
Uses BALD scores without sorting to show the actual distribution of uncertainty.
"""
#%%
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from pathlib import Path
import json
import pandas as pd
from tqdm.auto import tqdm
from wandb.apis.public import Api as WandbApi
# Initialize wandb API
api = WandbApi()

# Find the latest run in the project
runs = api.runs("blog-active-learning-vs-filtering", order="-created_at")
latest_run = runs[0]  # Most recent run

print(f"Analyzing run: {latest_run.name} ({latest_run.id})")

#%%
# Get data from the run - we need to extract BALD scores and acquisition info
scores_by_iter = {}
acquired_indices_by_iter = {}
train_subset_indices = {}  # To store which indices were already labeled

# Fetch the history data
print("Fetching history data...")
scores_key = "train_bald_scores"
history = latest_run.history(keys=["iteration", scores_key], pandas=False)

#%%
# Process the history data
for row in tqdm(history):
    if "iteration" not in row:
        continue
    iteration = row["iteration"]

    # Get BALD scores if available in this row
    if scores_key in row:
        try:
            # Try to get scores from the run table
            table_data = row[scores_key]
            # Load the artifact
            table_file = latest_run.file(table_data['path']).download(replace=True).read()
            table_data = json.loads(table_file)
            df = pd.DataFrame(table_data["data"], columns=table_data["columns"])
            scores_by_iter[iteration] = df["bald_score"].values
        except Exception as e:
            print(f"Error extracting {scores_key} for iteration {iteration}: {e}")

    # Get acquired indices if available
    if "acquired_indices" in row:
        acquired_indices_by_iter[iteration] = row["acquired_indices"]

    # Get currently labeled indices
    if "current_labeled_indices" in row:
        train_subset_indices[iteration] = row["current_labeled_indices"]

#%%
# If we still don't have data, exit
if not scores_by_iter:
    print(f"Failed to retrieve {scores_key} from the run. Exiting.")
    raise ValueError("Failed to retrieve BALD scores from the run.")
#%%
# Sort iterations and prepare animation data
sorted_iterations = sorted(scores_by_iter.keys())
all_scores = [scores_by_iter[i].copy() for i in sorted_iterations]

acquired_indices = []
for scores in all_scores:
    for i in acquired_indices:
        scores[i] = 0.
    # Get the argmax score
    max_idx = np.argmax(scores)
    acquired_indices.append(max_idx.item())

# Pop last acquired index
acquired_indices.pop()
assert np.all(all_scores[-1][acquired_indices] == 0.)
print(f"Acquired indices: {acquired_indices}")

print(f"Found {scores_key} for {len(all_scores)} iterations")
print(f"First iteration has {len(all_scores[0])} samples")

#%%
# Create the animation with XKCD style but using EMA of the scores
ema_decay = 0.9

all_scores_ema = []
current_scores = all_scores[0].copy()
for i in range(len(all_scores)):
    current_scores = ema_decay * current_scores + (1.0 - ema_decay) * all_scores[i]
    for j in range(i):
        current_scores[acquired_indices[j]] = 0.
    all_scores_ema.append(current_scores)
#%%
with plt.xkcd():
    fig, ax = plt.subplots(figsize=(10, 6))

    # Set up initial plot - we'll use scatter plot for non-sorted scores
    max_points = min(10000, max([len(scores) for scores in all_scores_ema]))
    scatter = ax.scatter([], [], color="black", s=2)

    # For highlighted points we'll use a different color
    highlight_scatter = ax.scatter([], [], color="blue", s=100, alpha=0.5)

    # Title and labels
    ax.set_title("BALD Score Evolution (Non-sorted w/ EMA)")
    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Informativeness (BALD score)")
    ax.set_xlim(0, max_points)

    # Find max BALD score for consistent y-axis scaling
    max_score = max([max(scores) for scores in all_scores_ema])
    ax.set_ylim(0, max_score * 1.1)

    # Text annotations
    iter_text = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=12)
    labeled_text = ax.text(0.02, 0.90, '', transform=ax.transAxes, fontsize=10)

    # Get the sorted indices of the first iteration
    sorted_indices = np.argsort(all_scores[0])[::-1]
    inv_sorted_indices = np.argsort(sorted_indices)
    def init():
        # Use the raw scores (no sorting)
        current_scores = np.array(all_scores_ema[0])
        x_data = np.arange(len(current_scores))

        scatter.set_offsets(np.c_[x_data, current_scores])
        # Highlight top uncertain points
        highlight_scatter.set_offsets(np.c_[[], []])
        # Set iteration text
        iter_text.set_text(f"Iteration: {sorted_iterations[0]}")
        return scatter, highlight_scatter, iter_text, labeled_text

    def animate(frame_idx):
        if frame_idx < len(sorted_iterations):
            iteration = sorted_iterations[frame_idx]
            # Compute an EMA of the past scores
            current_scores = np.array(all_scores_ema[frame_idx])

            # Use raw scores without sorting
            x_data = inv_sorted_indices[np.arange(len(current_scores))]

            # Update main scatter plot
            scatter.set_offsets(np.c_[x_data, current_scores.copy()])

            # Highlight top uncertain points (these will change each iteration)
            # top_indices = acquired_indices[:frame_idx]
            # highlight_scatter.set_offsets(np.c_[inv_sorted_indices[top_indices], current_scores[top_indices].copy()])
            # assert np.all(current_scores[top_indices[:1]] == 0.), current_scores[top_indices]
            # current_scores[top_indices] = 0.

            # Update text
            iter_text.set_text(f"Iteration: {iteration}")
        return scatter, highlight_scatter, iter_text, labeled_text
    # Create animation
    num_frames = len(sorted_iterations)
    anim = FuncAnimation(fig, animate, init_func=init,
                         frames=num_frames, interval=50, blit=True)

    # Add legend
    ax.legend()

    # Save animation
    writer = PillowWriter(fps=1000/50)
    output_path = Path("bald_scores_ema_animation.gif")
    print(f"Saving animation to {output_path}")
    anim.save(output_path, writer=writer)

print("Done!")
# %%