Skip to content

Commit

Permalink
feat(sdk): add exponential decay sampling utility for line_plot (#6228)
Browse files Browse the repository at this point in the history
Co-authored-by: sweep-ai[bot] <128439645+sweep-ai[bot]@users.noreply.github.com>
Co-authored-by: William Zeng <[email protected]>
  • Loading branch information
3 people authored Sep 5, 2023
1 parent 11e9e26 commit b7da7e4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
11 changes: 11 additions & 0 deletions tests/pytest_tests/unit_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,3 +760,14 @@ def test_make_docker_image_name_safe():
== "abc.123__def-456"
)
assert util.make_docker_image_name_safe("......") == "image"


def test_sampling_weights():
    """Exponentially decaying weights should bias samples toward low indices."""
    n_points = 100
    xs = np.arange(n_points)
    ys = np.arange(n_points, 2 * n_points)
    result = util.sample_with_exponential_decay_weights(xs, ys, sample_size=1000)
    sampled_xs = result[0]
    # Earlier entries carry larger weights, so the sampled mean must fall
    # below the mean of the full (uniformly spaced) x range.
    assert np.mean(sampled_xs) < np.mean(xs)
1 change: 0 additions & 1 deletion wandb/plot/line_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def line_series(

if keys is not None:
assert len(keys) == len(ys), "Number of keys and y-lines must match"

data = [
[x, f"key_{i}" if keys is None else keys[i], y]
for i, (xx, yy) in enumerate(zip(xs, ys))
Expand Down
23 changes: 23 additions & 0 deletions wandb/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,3 +1813,26 @@ def random_string(length: int = 12) -> str:
return "".join(
secrets.choice(string.ascii_lowercase + string.digits) for _ in range(length)
)


def sample_with_exponential_decay_weights(
    xs: Union[Iterable, Iterable[Iterable]],
    ys: Iterable[Iterable],
    keys: Optional[Iterable] = None,
    sample_size: int = 1500,
) -> Tuple[List, List, Optional[List]]:
    """Sample from a list of lists with weights that decay exponentially.

    May be used with the wandb.plot.line_series function.

    Index ``i`` (out of ``n = len(xs)``) is drawn with probability
    proportional to ``exp(-i / n)``, so earlier entries are favoured.
    Indices are drawn with ``np.random.choice`` using its default settings,
    and the same drawn indices are applied to ``xs``, ``ys`` and ``keys``
    along their first axis. Results are returned in draw order.

    NOTE(review): the drawn indices range over ``len(xs)`` but are also used
    to index the first axis of ``ys`` and ``keys``; this presumably expects
    all three to share the same first-axis length -- confirm against callers
    that pass one x-series with several y-series.

    Args:
        xs: Values to sample from; its length defines the index range.
        ys: Values sampled with the same indices as ``xs``.
        keys: Optional values sampled with the same indices. ``None`` or an
            empty sequence means "no keys".
        sample_size: Number of indices to draw.

    Returns:
        A tuple ``(sampled_xs, sampled_ys, sampled_keys)`` of plain lists;
        ``sampled_keys`` is ``None`` when no keys were supplied.
    """
    xs_array = np.asarray(xs)
    ys_array = np.asarray(ys)

    # Test for None explicitly: evaluating the truthiness of a multi-element
    # NumPy array raises, and a one-element array such as np.array([0]) would
    # otherwise be mistaken for "no keys".
    keys_array = None if keys is None else np.asarray(keys)
    if keys_array is not None and keys_array.size == 0:
        # Preserve the historical behaviour of treating empty keys as absent.
        keys_array = None

    n_points = len(xs_array)
    weights = np.exp(-np.arange(n_points) / n_points)
    weights /= np.sum(weights)  # normalise so the weights form a distribution
    sampled_indices = np.random.choice(n_points, size=sample_size, p=weights)

    sampled_xs = xs_array[sampled_indices].tolist()
    sampled_ys = ys_array[sampled_indices].tolist()
    sampled_keys = (
        keys_array[sampled_indices].tolist() if keys_array is not None else None
    )

    return sampled_xs, sampled_ys, sampled_keys

0 comments on commit b7da7e4

Please sign in to comment.