chore: rename n to num_samples #6

Merged: 1 commit, Jul 7, 2023
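
For callers, a parameter rename like this only matters where the argument is passed by keyword; positional calls keep working. The sketch below illustrates that with a stand-in function. `select_top` is hypothetical and not part of autora. The diff does import `deprecated_alias`, but how that helper is applied is not shown in this hunk, so the sketch assumes no alias handling at all.

```python
def select_top(scores, num_samples):
    """Hypothetical stand-in: return the indices of the largest scores."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:num_samples]


scores = [0.1, 0.7, 0.4]

# Positional calls are unaffected by the rename.
positional = select_top(scores, 2)

# Keyword calls must use the new name.
keyword = select_top(scores, num_samples=2)

# Without an alias, the old keyword name raises TypeError.
try:
    select_top(scores, n=2)
    old_keyword_ok = True
except TypeError:
    old_keyword_ok = False
```

Under that assumption, any call site written as `uncertainty_sample(X, model, n=...)` would need updating to `num_samples=...`.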
10 changes: 5 additions & 5 deletions src/autora/experimentalist/sampler/uncertainty/__init__.py
@@ -4,13 +4,13 @@

from autora.utils.deprecation import deprecated_alias

-def uncertainty_sample(X, model, n, measure="least_confident"):
+def uncertainty_sample(X, model, num_samples, measure="least_confident"):
"""

Args:
X: pool of IV conditions to evaluate uncertainty
model: Scikit-learn model, must have `predict_proba` method.
-n: number of samples to select
+num_samples: number of samples to select
measure: method to evaluate uncertainty. Options:

- `'least_confident'`: $x* = \\operatorname{argmax} \\left( 1-P(\\hat{y}|x) \\right)$,
@@ -36,21 +36,21 @@ class labels under the model, respectively.
# Calculate uncertainty of max probability class
a_uncertainty = 1 - a_prob.max(axis=1)
# Get index of largest uncertainties
-idx = np.flip(a_uncertainty.argsort()[-n:])
+idx = np.flip(a_uncertainty.argsort()[-num_samples:])

elif measure == "margin":
# Sort values by row descending
a_part = np.partition(-a_prob, 1, axis=1)
# Calculate difference between 2 largest probabilities
a_margin = -a_part[:, 0] + a_part[:, 1]
# Determine index of smallest margins
-idx = a_margin.argsort()[:n]
+idx = a_margin.argsort()[:num_samples]

elif measure == "entropy":
# Calculate entropy
a_entropy = entropy(a_prob.T)
# Get index of largest entropies
-idx = np.flip(a_entropy.argsort()[-n:])
+idx = np.flip(a_entropy.argsort()[-num_samples:])

else:
raise ValueError(
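
To see how `num_samples` drives the index selection in each branch above, here is a minimal, self-contained sketch that mirrors the three measures on a small probability matrix. It is not the project's function: `top_uncertain` is a hypothetical name, it takes the probabilities directly instead of calling `model.predict_proba`, and it computes Shannon entropy with NumPy rather than with the `entropy` function used in the diff (whose import is not shown in this hunk).

```python
import numpy as np


def top_uncertain(a_prob, num_samples, measure="least_confident"):
    """Hypothetical helper: indices of the num_samples most uncertain rows."""
    a_prob = np.asarray(a_prob, dtype=float)

    if measure == "least_confident":
        # Uncertainty is one minus the probability of the most likely class.
        a_uncertainty = 1 - a_prob.max(axis=1)
        return np.flip(a_uncertainty.argsort()[-num_samples:])

    if measure == "margin":
        # After partitioning the negated rows, columns 0 and 1 hold the
        # negated largest and second-largest probabilities.
        a_part = np.partition(-a_prob, 1, axis=1)
        a_margin = -a_part[:, 0] + a_part[:, 1]
        return a_margin.argsort()[:num_samples]

    if measure == "entropy":
        # Shannon entropy per row (natural log); zero entries contribute 0.
        with np.errstate(divide="ignore", invalid="ignore"):
            terms = np.where(a_prob > 0, a_prob * np.log(a_prob), 0.0)
        a_entropy = -terms.sum(axis=1)
        return np.flip(a_entropy.argsort()[-num_samples:])

    raise ValueError(f"Unknown measure: {measure!r}")


# Four conditions, three classes; row 3 is closest to uniform.
a_prob = [
    [0.90, 0.05, 0.05],
    [0.40, 0.35, 0.25],
    [0.60, 0.30, 0.10],
    [0.34, 0.33, 0.33],
]

least = top_uncertain(a_prob, num_samples=2, measure="least_confident")
margin = top_uncertain(a_prob, num_samples=2, measure="margin")
ent = top_uncertain(a_prob, num_samples=2, measure="entropy")
```

On this toy input all three measures pick rows 3 and 1, in that order. The measures need not agree in general; they happen to here because the near-uniform row is the most uncertain by every criterion.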