
TraceEnum_ELBO: Subsample local variables that depend on a global model-enumerated variable #1572

Open
wants to merge 6 commits into master

Conversation

ordabayevy
Member

One feature not supported by TraceEnum_ELBO is subsampling a local variable that depends on a global variable enumerated in the model, because all enumerated sites are required to share a common scale:

@config_enumerate
def model(data):
    # Global variables.
    locs = jnp.array([1., 10.])
    a = pyro.sample('a', dist.Categorical(jnp.ones(2)))
    with pyro.plate('data', len(data), subsample_size=2) as ind:  # cannot subsample here
        # Local variables.
        pyro.sample('b', dist.Normal(locs[a], 1.), obs=data[ind])

def guide(data):
    pass
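
For context, a minimal sketch of how this model would be fit, assuming the NumPyro API (numpyro.infer.SVI and numpyro.infer.TraceEnum_ELBO) and the model and guide above with the corresponding imports (e.g. import numpyro as pyro and from numpyro.contrib.funsor import config_enumerate); the toy data, optimizer settings, and number of steps are illustrative:

import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import SVI, TraceEnum_ELBO

# Illustrative toy data near the two modes.
data = jnp.concatenate([jnp.ones(50), 10.0 * jnp.ones(50)])

svi = SVI(model, guide, numpyro.optim.Adam(0.01), loss=TraceEnum_ELBO())
# Without this PR, combining the subsampled plate 'data' with the enumerated
# site 'a' is rejected because the two do not share a common scale.
svi_result = svi.run(jax.random.PRNGKey(0), 2000, data)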

This has been asked on the forum as well: https://forum.pyro.ai/t/enumeration-and-subsampling-expected-all-enumerated-sample-sites-to-share-common-poutine-scale/4938

The solution proposed here is to scale the log factors as follows ($N$ is the total size, $M$ is the subsample size):
$\log \sum_a p(a) {\prod_i}^{N} p(b_i | a) \approx \frac{N}{M}\log \sum_a p(a) {\prod_i}^{M} p(b_i | a)$

Expectation of the left hand side:
$\mathbb{E} [ \log \sum_a p(a) {\prod_i}^{N} p(b_i | a) ] = \mathbb{E} [ \log {\prod_i}^{N} \sum_a p(a) p(b_i | a) ]= \mathbb{E} [ \log {\prod_i}^{N} p(b_i) ]$
$= \mathbb{E} [{\sum_i}^N \log p(b_i) ] = {\sum_i}^N \mathbb{E} [ \log p(b_i) ]$
$= N \mathbb{E} [ \log p(b_i) ]$

Expectation of the right hand side:
$\mathbb{E} [ \frac{N}{M} \log \sum_a p(a) {\prod_i}^{M} p(b_i | a) ] = \frac{N}{M} \mathbb{E} [ \log {\prod_i}^{M} \sum_a p(a) p(b_i | a) ] = \frac{N}{M} \mathbb{E} [ \log {\prod_i}^{M} p(b_i) ]$
$= \frac{N}{M} \mathbb{E} [{\sum_i}^M \log p(b_i) ] = \frac{N}{M} {\sum_i}^M \mathbb{E} [ \log p(b_i) ]$
$= N \mathbb{E} [ \log p(b_i) ]$

@fehiepsi
Member

fehiepsi commented Sep 1, 2023

Hi @ordabayevy, I don't understand how you can move prod and sum around. In particular, I'm not sure if your first equation makes sense: $\mathbb{E} [ \log \sum_a p(a) {\prod_i}^{N} p(b_i | a) ] = \mathbb{E} [ \log {\prod_i}^{N} \sum_a p(a) p(b_i | a) ]$ - could you clarify?
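
For concreteness, a quick numeric check (with made-up probability values) that the sum over $a$ and the product over $i$ cannot be interchanged in general:

import torch

p_a = torch.tensor([0.3, 0.7])                 # p(a)
p_b = torch.tensor([[0.9, 0.2],                # p(b_i | a): rows index i, columns index a
                    [0.4, 0.5]])

lhs = torch.log((p_a * p_b.prod(0)).sum())     # log sum_a p(a) prod_i p(b_i | a)
rhs = torch.log((p_b * p_a).sum(1)).sum()      # log prod_i sum_a p(a) p(b_i | a)
print(lhs, rhs)                                # approx -1.73 vs -1.65: not equal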

@ordabayevy
Member Author

I think you are right, @fehiepsi. Let me think more about this.

@ordabayevy
Member Author

So the actual equation should be (same in the code):
$\log \sum_a p(a) {\prod_i}^{N} p(b_i | a) \approx \log \sum_a p(a) \left ( {\prod_{i \in I_M}} p(b_i | a) \right ) ^ {\frac{N}{M}}$

This seems intuitive to me: subsample within the plate and then scale the product before summing over the enumerated variable. I did some tests and it seems to be unbiased; however, I can't figure out how to prove unbiasedness mathematically.

@ordabayevy
Member Author

Code I used to check unbiasedness:

import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

a = torch.tensor([0., 1.])  # candidate means, one per value of a
logits_a = torch.log(torch.tensor([0.3, 0.7]))

# data: first half uniform in [0, 1), second half uniform in [1, 2)
b = torch.rand(1000)
b[500:] += 1

d = dist.Normal(a, 1)
log_b = d.log_prob(b.reshape(-1,1))

expected = torch.logsumexp(log_b.sum(0) + logits_a, 0)

results = []
for _ in range(50000):
    idx = torch.randperm(1000)[:100] # subsample 100 samples
    scale = 10  # 1000 / 100
    results.append(torch.logsumexp(d.log_prob(b[idx].reshape(-1,1)).sum(0) * scale + logits_a, 0))

print(expected)
print(torch.stack(results).mean())

plt.plot(results)
plt.hlines(expected, 0, 50000, "C1")
plt.show()

>>> tensor(-1087.9426)
>>> tensor(-1087.8069)

[Plot: the 50,000 per-subsample estimates, with the expected value shown as a horizontal line]

@fehiepsi
Member

fehiepsi commented Sep 7, 2023

I think it's easier to see the issue with a smaller number of data points (e.g. just 2). Assume we are using subsampling (subsample size $M=1$ out of $N=2$) to estimate $\log \sum_a p(a) p(x|a)p(y|a)=\log \sqrt{\sum_{a,b} p(a)p(b)p(x|a)p(x|b)p(y|a)p(y|b)}$. Averaging the scaled estimator over the two possible subsamples $\{x\}$ and $\{y\}$, we get
$\frac{1}{2} \left( \log \sum_a p(a) p(x|a)^2 + \log \sum_a p(a) p(y|a)^2 \right) = \log \sqrt{\sum_{a,b} p(a)p(b) p(x|a)^2 p(y|b)^2}.$
It seems clear to me that the two terms $\sum_{a,b} p(a)p(b)p(x|a)p(x|b)p(y|a)p(y|b)$ and $\sum_{a,b} p(a)p(b) p(x|a)^2 p(y|b)^2$ are not equal, so the estimator is biased.
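
A quick numeric check of this (the probabilities below are made-up values for the two-data-point setup above) shows that the two quantities indeed differ:

import torch

log_p_a = torch.log(torch.tensor([0.3, 0.7]))  # log p(a)
log_p_x = torch.tensor([-0.2, -1.5])           # log p(x | a) for a = 0, 1
log_p_y = torch.tensor([-1.0, -0.3])           # log p(y | a) for a = 0, 1

# Exact quantity: log sum_a p(a) p(x|a) p(y|a)
exact = torch.logsumexp(log_p_a + log_p_x + log_p_y, 0)

# Expected value of the scaled estimator: each subsample {x} or {y} is chosen
# with probability 1/2 and its log factor is scaled by N / M = 2 before the sum over a.
est_x = torch.logsumexp(log_p_a + 2 * log_p_x, 0)
est_y = torch.logsumexp(log_p_a + 2 * log_p_y, 0)
print(exact, 0.5 * (est_x + est_y))            # approx -1.58 vs -1.15: biased

This may also explain the small gap between the two values printed in the check above.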
