-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compile the functions needed by SMC before the worker processes are started #7472
base: main
Are you sure you want to change the base?
Conversation
self.varlogp = self.model.varlogp | ||
self.datalogp = self.model.datalogp | ||
|
||
def initialize_rng(self, random_seed=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was it unnecessary to add this method? Didn't want to directly access SMC_KERNEL.rng
since it's not initialized by just direct assignment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couldn't follow, can you explain again?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a new method SMC_KERNEL.initialize_rng
which creates a new SMC_KERNEL.rng
with given seed. It's just a convenience method for seeding the rng
with a different seed in each worker process. Previously it wasn't necessary because the kernels were created in each process separately and rng
is seeded during that. This PR creates the kernel before creating the worker processes and seeding has to be done separately.
I was wondering if adding a new method was unnecessary. The method doesn't do much after all. I could just do smc.rng = np.random.default_rng(seed=random_seed)
in _sample_smc_int
instead but I didn't want to interact with SMC_KERNEL.rng
since I didn't see it used anywhere else outside of SMC_KERNEL
.
Nitpicky? I agree.
The test from #7241 is missing but I intend to add it later. |
One thing we'll have to be careful that didn't matter before is to update the shared variables of the logp functions that define RNGs. Model with minibatch/Simulator have a stochastic logp. For those we will probably have to copy the pytensor function using the |
Which shared variables are you referring to? I ran Could you kick off the tests? I don't know if I'm doing something wrong on my computer which leads to failures in the tests. |
Description
Currently
sample_smc
can fail due to aNotImplementedError
if it's used with a model defined usingCustomDist
. If aCustomDist
is used, the overloads for e.g._logprob
are registered only in the main process. The issue exists only on Windows because the worker processes are spawned. In other systems where the default option is forking, everything works.#7241 fixed the issue by registering the overloads manually. Although that would fix #7224, the approach might not be the best in the long run.
This PR moves some of the SMC kernel initialization (calculating
initial_point
and compiling*logp
functions) from worker processes to the main process. This way the overloads are not needed in the worker processes.Related Issue
pymc.sample_smc
fails withpymc.CustomDist
#7224, Register the overloads added by CustomDist in worker processes #7241, Add a test for SMC sampling from CustomDist in multiple processes #7337Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7472.org.readthedocs.build/en/7472/