Add basic AL example. Fix indexing in strategies. Allow only tensor indices.
patel-zeel committed Oct 28, 2023
1 parent 7d0d3e8 commit 42a8cbc
Showing 6 changed files with 789 additions and 29 deletions.
21 changes: 13 additions & 8 deletions astra/torch/al/strategies/diversity.py
@@ -1,19 +1,18 @@
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from astra.torch.al import Strategy

from typing import Sequence, Dict, Union, List
from typing import Sequence, Dict


class DiversityStrategy(Strategy):
def query(
self,
net: nn.Module,
pool_indices: Union[List[int], np.ndarray, torch.Tensor],
context_indices: Union[List[int], np.ndarray, torch.Tensor] = None,
pool_indices: torch.Tensor,
context_indices: torch.Tensor = None,
n_query_samples: int = 1,
n_mc_samples: int = None,
batch_size: int = None,
@@ -31,6 +30,11 @@ def query(
Returns:
best_indices: A dictionary of acquisition names and the corresponding best indices.
"""
assert isinstance(pool_indices, torch.Tensor), f"pool_indices must be a torch.Tensor, got {type(pool_indices)}"
assert isinstance(
context_indices, torch.Tensor
), f"context_indices must be a torch.Tensor, got {type(context_indices)}"

if batch_size is None:
batch_size = len(pool_indices)

@@ -57,10 +61,11 @@ def query(
# TODO: We can make this loop faster by computing scores only for updated indices. There can be a method in acquisition to update the scores.
for _ in range(n_query_samples):
scores = acquisition.acquire_scores(features, pool_indices, context_indices)
best_index = torch.argmax(scores)
selected_indices.append(best_index)
pool_indices.pop(best_index)
context_indices.append(best_index)
index = torch.argmax(scores)
selected_index = pool_indices[index]
selected_indices.append(selected_index)
pool_indices = torch.cat([pool_indices[:index], pool_indices[index + 1 :]])
context_indices = torch.cat([context_indices, selected_index])
selected_indices = torch.tensor(selected_indices, device=self.device)
best_indices[acq_name] = selected_indices

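The key fix above is that torch.argmax(scores) returns a position within the current pool, not a dataset index, so the strategy now maps it through pool_indices before updating its bookkeeping. A minimal standalone sketch of that pattern follows (synthetic scores, no astra classes; the .reshape(1) is there so the selected scalar can be concatenated):

import torch

pool_indices = torch.tensor([3, 7, 11, 42, 56])    # dataset indices still unlabeled
context_indices = torch.tensor([0, 1, 2])           # dataset indices already labeled

selected_indices = []
for _ in range(2):                                   # greedily pick two samples
    scores = torch.rand(len(pool_indices))           # stand-in for acquisition scores
    index = torch.argmax(scores)                     # position inside the pool
    selected_index = pool_indices[index]             # corresponding dataset index
    selected_indices.append(selected_index)

    # drop the chosen element from the pool and append it to the context
    pool_indices = torch.cat([pool_indices[:index], pool_indices[index + 1:]])
    context_indices = torch.cat([context_indices, selected_index.reshape(1)])

print(torch.stack(selected_indices), pool_indices, context_indices)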
11 changes: 7 additions & 4 deletions astra/torch/al/strategies/ensemble.py
@@ -6,14 +6,14 @@

from astra.torch.al import Strategy

from typing import Sequence, Dict, List, Union
from typing import Sequence, Dict


class EnsembleStrategy(Strategy):
def query(
self,
net: Union[List[int], np.ndarray, torch.Tensor],
pool_indices: Union[List[int], np.ndarray, torch.Tensor],
net: torch.Tensor,
pool_indices: torch.Tensor,
context_indices: Sequence[int] = None,
n_query_samples: int = 1,
n_mc_samples: int = None,
@@ -32,6 +32,8 @@ def query(
Returns:
best_indices: A dictionary of acquisition names and the corresponding best indices.
"""
assert isinstance(pool_indices, torch.Tensor), f"pool_indices must be a torch.Tensor, got {type(pool_indices)}"

if not isinstance(net, Sequence):
raise ValueError(f"net must be a sequence of nets, got {type(net)}")

@@ -54,7 +56,8 @@ def query(
best_indices = {}
for acq_name, acquisition in self.acquisitions.items():
scores = acquisition.acquire_scores(logits)
selected_indices = torch.topk(scores, n_query_samples).indices
index = torch.topk(scores, n_query_samples).indices
selected_indices = pool_indices[index]
best_indices[acq_name] = selected_indices

return best_indices
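The same mapping applies here: torch.topk(scores, n_query_samples).indices are positions into the pool-sized score vector, and the new line converts them to dataset indices via pool_indices[index]. A rough standalone sketch of the ensemble flow, using a toy ensemble and a stand-in entropy score rather than astra's acquisition objects:

import torch
import torch.nn as nn

n_pool, n_classes, n_query_samples = 100, 10, 5
pool_x = torch.randn(n_pool, 8)                      # hypothetical pool features
pool_indices = torch.arange(200, 200 + n_pool)       # dataset indices of the pool
ensemble = [nn.Linear(8, n_classes) for _ in range(3)]

with torch.no_grad():
    # logits shape: (n_members, pool_dim, n_classes)
    logits = torch.stack([member(pool_x) for member in ensemble])

# stand-in for acquisition.acquire_scores(logits): entropy of the mean prediction
probs = logits.softmax(dim=-1).mean(dim=0)
scores = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)

index = torch.topk(scores, n_query_samples).indices   # positions in the pool
selected_indices = pool_indices[index]                 # dataset indices to label next
print(selected_indices)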
11 changes: 7 additions & 4 deletions astra/torch/al/strategies/mc.py
@@ -6,15 +6,15 @@

from astra.torch.al import Strategy

from typing import Sequence, Dict, Union, List
from typing import Sequence, Dict


class MCStrategy(Strategy):
def query(
self,
net: nn.Module,
pool_indices: Union[List[int], np.ndarray, torch.Tensor],
context_indices: Union[List[int], np.ndarray, torch.Tensor] = None,
pool_indices: torch.Tensor,
context_indices: torch.Tensor = None,
n_query_samples: int = 1,
n_mc_samples: int = 10,
batch_size: int = None,
@@ -32,6 +32,8 @@ def query(
Returns:
best_indices: A dictionary of acquisition names and the corresponding best indices.
"""
assert isinstance(pool_indices, torch.Tensor), f"pool_indices must be a torch.Tensor, got {type(pool_indices)}"

if batch_size is None:
batch_size = len(pool_indices)

@@ -55,7 +57,8 @@ def query(
best_indices = {}
for acq_name, acquisition in self.acquisitions.items():
scores = acquisition.acquire_scores(logits)
selected_indices = torch.topk(scores, n_query_samples).indices
index = torch.topk(scores, n_query_samples).indices
selected_indices = pool_indices[index]
best_indices[acq_name] = selected_indices

return best_indices
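In the MC strategy the logits are expected to have shape (n_mc_samples, pool_dim, n_classes), i.e. one stochastic forward pass per MC sample. A sketch of producing such logits with dropout kept active and mapping the top-k positions back to pool indices; the network and the variance-based score are illustrative only, not astra's API:

import torch
import torch.nn as nn

n_pool, n_classes, n_mc_samples, n_query_samples = 50, 4, 10, 3
pool_x = torch.randn(n_pool, 16)
pool_indices = torch.arange(1000, 1000 + n_pool)     # hypothetical dataset indices

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(0.5), nn.Linear(32, n_classes))
net.train()  # keep dropout stochastic so the MC samples differ

with torch.no_grad():
    # logits shape: (n_mc_samples, pool_dim, n_classes)
    logits = torch.stack([net(pool_x) for _ in range(n_mc_samples)])

# stand-in acquisition score: disagreement across MC samples
probs = logits.softmax(dim=-1)
scores = probs.var(dim=0).mean(dim=-1)

index = torch.topk(scores, n_query_samples).indices   # positions in the pool
selected_indices = pool_indices[index]                 # dataset indices
print(selected_indices)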
11 changes: 7 additions & 4 deletions astra/torch/al/strategies/random.py
@@ -3,15 +3,15 @@
import torch.nn as nn
from astra.torch.al import Strategy

from typing import Dict, Union, List
from typing import Dict


class RandomStrategy(Strategy):
def query(
self,
net: nn.Module,
pool_indices: Union[List[int], np.ndarray, torch.Tensor],
context_indices: Union[List[int], np.ndarray, torch.Tensor] = None,
pool_indices: torch.Tensor,
context_indices: torch.Tensor = None,
n_query_samples: int = 1,
n_mc_samples: int = 10,
batch_size: int = None,
@@ -29,12 +29,15 @@ def query(
Returns:
best_indices: A dictionary of acquisition names and the corresponding best indices.
"""
assert isinstance(pool_indices, torch.Tensor), f"pool_indices must be a torch.Tensor, got {type(pool_indices)}"

# logits shape (n_mc_samples, pool_dim, n_classes)
logits = torch.rand(n_mc_samples, len(pool_indices), self.n_classes)
best_indices = {}
for acq_name, acquisition in self.acquisitions.items():
scores = acquisition.acquire_scores(logits)
selected_indices = torch.topk(scores, n_query_samples).indices
indices = torch.topk(scores, n_query_samples).indices
selected_indices = pool_indices[indices]
best_indices[acq_name] = selected_indices

return best_indices
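Since the strategies now assert torch.Tensor indices ("Allow only tensor indices" in the commit message), callers holding Python lists or NumPy arrays must convert before calling query. A tiny sketch of that conversion; the variable names are illustrative:

import numpy as np
import torch

pool_indices_np = np.arange(100)          # e.g. indices kept as a NumPy array
context_indices_list = [0, 1, 2, 3]       # e.g. indices kept as a Python list

pool_indices = torch.from_numpy(pool_indices_np).long()
context_indices = torch.as_tensor(context_indices_list, dtype=torch.long)

assert isinstance(pool_indices, torch.Tensor)
assert isinstance(context_indices, torch.Tensor)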
41 changes: 32 additions & 9 deletions astra/torch/utils.py
@@ -18,12 +18,13 @@ def train_fn(model, inputs, output, loss_fn, lr, n_epochs, batch_size=None, enab
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

iter_losses = []
epochs_losses = []
epoch_losses = []
outer_loop = range(n_epochs)
if enable_tqdm:
loop = tqdm(range(n_epochs))
else:
loop = range(n_epochs)
for _ in loop:
pbar = tqdm(total=len(data_loader) * n_epochs)
n_processed = 0

for _ in outer_loop:
loss_value = 0.0
for x, y in data_loader:
optimizer.zero_grad()
@@ -33,11 +34,33 @@ def train_fn(model, inputs, output, loss_fn, lr, n_epochs, batch_size=None, enab
optimizer.step()
iter_losses.append(loss.item())
loss_value += loss.item()
epochs_losses.append(loss_value / len(data_loader))
if enable_tqdm:
loop.set_description(f"Loss: {loss.item():.6f}")
if enable_tqdm:
n_processed += len(x)
pbar.update(1)
pbar.set_description(f"Loss: {loss.item():.6f}")

epoch_losses.append(loss_value / len(data_loader))

return {"iter_losses": iter_losses, "epoch_losses": epoch_losses}


def predict_class(model, inputs, batch_size=None):
"""Generic predict function for classification models.
Note that we assume that the model predicts the logits of size `n_classes` even for the binary classification case.
"""
if batch_size is None:
batch_size = len(inputs)

data_loader = DataLoader(TensorDataset(inputs), batch_size=batch_size, shuffle=False)

return {"iter_losses": iter_losses, "epochs_losses": epochs_losses}
model.eval()
with torch.no_grad():
preds = []
for x in tqdm(data_loader):
pred = model(x[0])
preds.append(pred)
pred = torch.cat(preds)
return pred.argmax(dim=1)


def ravel_pytree(pytree):
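Assuming the functions above are importable as from astra.torch.utils import train_fn, predict_class, a short usage sketch of the new return keys and of predict_class, which expects n_classes logits even for a binary problem and returns the argmax class per sample. The toy model and data below are made up for illustration:

import torch
import torch.nn as nn

from astra.torch.utils import train_fn, predict_class  # assuming this import path

model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))  # 2 logits even for binary
inputs = torch.randn(256, 4)
output = (inputs[:, 0] > 0).long()

losses = train_fn(model, inputs, output, loss_fn=nn.CrossEntropyLoss(),
                  lr=1e-3, n_epochs=5, batch_size=64)
print(losses["iter_losses"][-1], losses["epoch_losses"][-1])

pred_classes = predict_class(model, inputs, batch_size=64)  # shape (256,), values in {0, 1}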
723 changes: 723 additions & 0 deletions notebooks/al/basic_example.ipynb

Large diffs are not rendered by default.
