Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix param randomization, generating repeating values #785

Merged
merged 2 commits into from
Sep 20, 2023

Conversation

khurram-ghani
Copy link
Collaborator

@khurram-ghani khurram-ghani commented Sep 18, 2023

Related issue(s)/PRs: None

Summary

This PR workarounds an issue where randomize_hyperparameters generated same repeating values for model hyperparameters when the global seed was set. The issue only occurred when tf.function compilation was enabled.

The issue seems to be related to the following documented behaviour of tensorflow:

Note that tf.function acts like a re-run of a program in this case. When the global seed is set but operation seeds are not set, the sequence of random numbers are the same for each tf.function.

When the function being compiled has a dynamic conditional (i.e. tf.cond) and the branches contain randomization calls, it seems internally tensorflow acts like "... re-run of a program". This is likely related to the fact that AutoGraph executes both branches during tracing. This could potentially be a tensorflow bug, but requires more investigation.

This PR simply removes the tf.Tensor condition expression (which is converted to tf.cond via AutoGraph) to a static python expression. Also added a unit test to catch the issue, which fails on previous version of the code.

Fully backwards compatible: yes

PR checklist

  • The quality checks are all passing
  • The bug case / new feature is covered by tests
  • Any new features are well-documented (in docstrings or notebooks)

@uri-granta uri-granta self-requested a review September 19, 2023 14:16
Copy link
Collaborator

@uri-granta uri-granta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@@ -68,7 +68,7 @@ def randomize_hyperparameters(object: gpflow.Module) -> None:
param.assign(sample)
elif param.prior is not None:
# handle constant priors for multi-dimensional parameters
if param.prior.batch_shape == param.prior.event_shape == [] and tf.rank(param) == 1:
if param.prior.batch_shape == param.prior.event_shape == [] and len(param.shape) == 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a comment about why this should not be changed to tf.rank? (though now that there's a test at least we'd catch it)

@khurram-ghani khurram-ghani merged commit 3279d31 into develop Sep 20, 2023
12 checks passed
@khurram-ghani khurram-ghani deleted the khurram/rand_params_repeat_fix branch September 20, 2023 08:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants