-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug whereby partial traces have fewer draws than would be available #4318
Conversation
Codecov Report
@@ Coverage Diff @@
## master #4318 +/- ##
==========================================
+ Coverage 87.54% 87.66% +0.11%
==========================================
Files 88 88
Lines 14272 14264 -8
==========================================
+ Hits 12495 12504 +9
+ Misses 1777 1760 -17
|
Wow, good catch! I think this is clever code, and maybe has an error in it. It could use a comment, in any case. The goal here is to "trim the traces", while keeping as many draws as possible. Recall that the traces are drawn using multiprocessing, so they will have various lengths when a keyboard interrupt hits. I think the intended algorithm is:
I think the correct implementation is use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)) |
Two side notes:
Given side note 1, I don't think spending the time to get side note 2 is worth it. |
Thanks @ColCarroll for your quick review! Your explanation of the goal to trim the traces makes sense to me, and I think your e.g. if we have traces of lengths 5, 2, 2, then the implementation on
However, if it had continued, it would've found
which would be the actual maximum. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great test!
I just requested a change on a docstring that is a bit misleading. Otherwise ✔️
Adding a quick test for
_choose_chains
, which is reached if there's a keyboard interrupt during parallel sampling.something I'm confused about is this part of
pymc3/sampling.py
:So, we iterate through the traces (from the longest one to the shortest one), and stop when the
total
((i + 1) * length
is smaller than it was for the previous trace.For example, if we have traces of length 10, 7, 3, then we would get:
i=0
,length=10
:total=10
,last_total=0
: astotal >= last_total
, we continuei=1
,length=7
:total=14
,last_total=10
: astotal >= last_total
, we continuei=2
,length=3
:total=9
,last_total=14
: astotal < last_total
, we breakI just don't see why we'd do that - what's the significance of
(i + 1) * length
? This function comes from #3011 (cc @aseyboldt in case you could offer any help understanding this)