Skip to content

Commit

Permalink
make device a property.
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Oct 28, 2023
1 parent d14d4aa commit 921690f
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
1 change: 1 addition & 0 deletions astra/torch/al/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
acquisitions = [acquisitions]
self.acquisitions = {acq.__class__.__name__: acq for acq in acquisitions}

@property
def device(self):
return self.dummy_param.device

Expand Down
7 changes: 6 additions & 1 deletion astra/torch/al/strategies/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,16 @@ def query(
# Get the features for the context
context_features_list = []
for x, _ in context_data_loader:
context_features = net(x)
context_features = net(x.to(self.device))
context_features_list.append(context_features)
context_features = torch.cat(context_features_list, dim=0) # (context_dim, feature_dim)

best_indices = {}
# TODO: Fix this for loop to do the following:
# - Get the max score. Get corresponding index.
# - Add that index to selected_indices and also to pool indices.
# - Remove that index from pool indices.
# - Repeat until len(selected_indices) == n_query_samples.
for acq_name, acquisition in self.acquisitions.items():
scores = acquisition.acquire_scores(pool_features, context_features)
selected_indices = torch.topk(scores, n_query_samples).indices
Expand Down
2 changes: 1 addition & 1 deletion astra/torch/al/strategies/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def query(
for x, _ in data_loader:
net_logits_list = []
for model in net:
net_logits = model(x.to(self.device()))[np.newaxis, ...]
net_logits = model(x.to(self.device))[np.newaxis, ...]
net_logits_list.append(net_logits)
logits = torch.cat(net_logits_list, dim=0) # (n_nets, batch_dim, n_classes)
logits_list.append(logits)
Expand Down
2 changes: 1 addition & 1 deletion astra/torch/al/strategies/mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def query(
with torch.no_grad():
logits_list = []
for x, _ in data_loader:
vx = x[np.newaxis, ...].repeat(*repeats).to(self.device())
vx = x[np.newaxis, ...].repeat(*repeats).to(self.device)
logits = torch.vmap(net, randomness="different")(vx)
logits_list.append(logits)
logits = torch.cat(logits_list, dim=1) # (n_mc_samples, pool_dim, n_classes)
Expand Down

0 comments on commit 921690f

Please sign in to comment.