Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug whereby partial traces have fewer draws than would be available #4318

Merged
merged 9 commits into from
Dec 12, 2020

Conversation

MarcoGorelli
Copy link
Contributor

Adding a quick test for _choose_chains, which is reached if there's a keyboard interrupt during parallel sampling.

something I'm confused about is this part of pymc3/sampling.py :

    final_length = l_sort[0]
    last_total = 0
    for i, length in enumerate(l_sort):
        total = (i + 1) * length
        if total < last_total:
            use_until = i
            break
        last_total = total
        final_length = length
    else:
        use_until = len(lengths)

So, we iterate through the traces (from the longest one to the shortest one), and stop when the total ((i + 1) * length is smaller than it was for the previous trace.

For example, if we have traces of length 10, 7, 3, then we would get:

  • with i=0, length=10: total=10, last_total=0 : as total >= last_total, we continue
  • with i=1, length=7: total=14, last_total=10: as total >= last_total, we continue
  • with i=2, length=3: total=9, last_total=14: as total < last_total, we break

I just don't see why we'd do that - what's the significance of (i + 1) * length? This function comes from #3011 (cc @aseyboldt in case you could offer any help understanding this)

@codecov
Copy link

codecov bot commented Dec 9, 2020

Codecov Report

Merging #4318 (79c7e87) into master (2a38198) will increase coverage by 0.11%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4318      +/-   ##
==========================================
+ Coverage   87.54%   87.66%   +0.11%     
==========================================
  Files          88       88              
  Lines       14272    14264       -8     
==========================================
+ Hits        12495    12504       +9     
+ Misses       1777     1760      -17     
Impacted Files Coverage Δ
pymc3/sampling.py 89.66% <100.00%> (+1.99%) ⬆️

@ColCarroll
Copy link
Member

Wow, good catch! I think this is clever code, and maybe has an error in it. It could use a comment, in any case.

The goal here is to "trim the traces", while keeping as many draws as possible. Recall that the traces are drawn using multiprocessing, so they will have various lengths when a keyboard interrupt hits. I think the intended algorithm is:

  1. Sort lengths in a descending manner, so the longest is first.
  2. for j=0...len(lengths)
    we can trim the first j traces, and have a total of (j + 1) * lengths[j] draws
    choose j to maximize this total

I think the correct implementation is

use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1))

@ColCarroll
Copy link
Member

Two side notes:

  1. This code path is not used overly often, since no one has spotted this bug.
  2. In our heart of hearts we want to maximize the effective sample size (not the number of draws), and would prefer to keep the number of chains specified by the user.

Given side note 1, I don't think spending the time to get side note 2 is worth it.

@MarcoGorelli MarcoGorelli changed the title Add test for uncovered _choose_chains Fix bug whereby partial traces have fewer draws than would be available Dec 9, 2020
@MarcoGorelli
Copy link
Contributor Author

MarcoGorelli commented Dec 9, 2020

Thanks @ColCarroll for your quick review!

Your explanation of the goal to trim the traces makes sense to me, and I think your np.argmax implementation is clearer. It's also better, as it passes the first tests case I added (which failed on master).

e.g. if we have traces of lengths 5, 2, 2, then the implementation on master would've done:

  • with i=0, length=5: total=5, last_total=0 : as total >= last_total, we continue
  • with i=1, length=2: total=4, last_total=5: as total < last_total, we break

However, if it had continued, it would've found

  • with i=2, length=2: total=6

which would be the actual maximum. np.argmax would find this

Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great test!
I just requested a change on a docstring that is a bit misleading. Otherwise ✔️

pymc3/sampling.py Outdated Show resolved Hide resolved
pymc3/sampling.py Outdated Show resolved Hide resolved
@twiecki twiecki merged commit 6f15cbb into pymc-devs:master Dec 12, 2020
@MarcoGorelli MarcoGorelli deleted the cover-choose-chains branch December 12, 2020 18:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants