Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update sphere.py -- add pooling and max_pooling #1596

Merged
merged 17 commits into from
Feb 20, 2024
31 changes: 31 additions & 0 deletions nevergrad/common/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,35 @@ def antithetic_order(n, shape, axis=-1, also_sym=False, conv=None):
x = x + [scx]
return x

def max_pooling(n, shape, budget, pooling=(1,8,8)):
old_latents = []
m = torch.nn.AvgPool3d(pooling)
x = []
for i in range(n):
latents = torch.randn((1, *shape),)
latents_pooling = m(latents)
if len(old_latents) != 0:
dist = torch.min(torch.stack([(latents_pooling - old_latents[r_n]).pow(2).sum().sqrt() for r_n in range(len(old_latents))]))
max_dist = dist
t0 = time.time()
while (time.time() - t0) < 0.01 * budget / n:
latents_new = torch.randn((n, *shape),)
latents_pooling_new = m(latents_new)
dist_new = torch.min(torch.stack(
[(latents_pooling_new - old_latents[r_n]).pow(2).sum().sqrt() for r_n in range(len(old_latents))]))
if dist_new > max_dist:
latents = latents_new
max_dist = dist_new
latents_pooling = latents_pooling_new
x.append(latents)
old_latents.append(latents_pooling)
x = torch.cat(x, 0).numpy()
x = normalize(x)
return x


def pooling(n, shape, budget, pooling=(1,1,1)):
return max_pooling(n, shape, budget, pooling)

def antithetic_order_and_sign(n, shape, axis=-1, conv=None):
return antithetic_order(n, shape, axis, also_sym=True)
Expand Down Expand Up @@ -826,6 +855,8 @@ def metric_pack_big_conv(x, budget=default_budget):
"Riesz_blursum_lowconv_loworder",
"Riesz_blursum_lowconv_midorder",
"Riesz_blursum_lowconv_highorder",
"max_pooling",
"pooling"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All methods in the list have the same signature, so that they can be used by quasi-randomize.
Please use the same signature.

]
list_metrics = [
"metric_half",
Expand Down