Skip to content

Commit

Permalink
Clean up JAX random seed handling in Time Series example notebook.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 493035184
  • Loading branch information
ursk authored and tensorflower-gardener committed Dec 5, 2022
1 parent e816859 commit 6bba470
Showing 1 changed file with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -477,19 +477,19 @@
"num_variational_steps = 200 # @param { isTemplate: true}\n",
"num_variational_steps = int(num_variational_steps)\n",
"\n",
"seed = tfp.random.sanitize_seed(jax.random.PRNGKey(42), salt='fit_stateless') \n",
"init_seed, fit_seed, sample_seed = tfp.random.split_seed(seed, n=3) \n",
"initial_parameters = init_fn(init_seed) \n",
"jd = co2_model.joint_distribution(co2_by_month_training_data) \n",
"seed = jax.random.PRNGKey(42)\n",
"init_seed, fit_seed, sample_seed = jax.random.split(seed, 3)\n",
"initial_parameters = init_fn(init_seed)\n",
"jd = co2_model.joint_distribution(co2_by_month_training_data)\n",
"\n",
"# Build and optimize the variational loss function.\n",
"optimized_parameters, elbo_loss_curve = tfp.vi.fit_surrogate_posterior_stateless( \n",
" target_log_prob_fn=jd.log_prob, \n",
" initial_parameters=initial_parameters, \n",
" build_surrogate_posterior_fn=build_surrogate_fn, \n",
"optimized_parameters, elbo_loss_curve = tfp.vi.fit_surrogate_posterior_stateless(\n",
" target_log_prob_fn=jd.log_prob,\n",
" initial_parameters=initial_parameters,\n",
" build_surrogate_posterior_fn=build_surrogate_fn,\n",
" optimizer=optax.adam(0.1), \n",
" num_steps=num_variational_steps,\n",
" seed=fit_seed) \n",
" seed=fit_seed)\n",
"plt.plot(elbo_loss_curve)\n",
"plt.show()\n",
"\n",
Expand Down

0 comments on commit 6bba470

Please sign in to comment.