From 6bba47076ee5515c8d64bd940993308ee8b5660b Mon Sep 17 00:00:00 2001 From: ursk Date: Mon, 5 Dec 2022 09:45:47 -0800 Subject: [PATCH] Clean up JAX random seed handling in Time Series example notebook. PiperOrigin-RevId: 493035184 --- ...pheric_CO2_and_Electricity_Demand_JAX.ipynb | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb index 9285160ac9..f076e1efd2 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb @@ -477,19 +477,19 @@ "num_variational_steps = 200 # @param { isTemplate: true}\n", "num_variational_steps = int(num_variational_steps)\n", "\n", - "seed = tfp.random.sanitize_seed(jax.random.PRNGKey(42), salt='fit_stateless') \n", - "init_seed, fit_seed, sample_seed = tfp.random.split_seed(seed, n=3) \n", - "initial_parameters = init_fn(init_seed) \n", - "jd = co2_model.joint_distribution(co2_by_month_training_data) \n", + "seed = jax.random.PRNGKey(42)\n", + "init_seed, fit_seed, sample_seed = jax.random.split(seed, 3)\n", + "initial_parameters = init_fn(init_seed)\n", + "jd = co2_model.joint_distribution(co2_by_month_training_data)\n", "\n", "# Build and optimize the variational loss function.\n", - "optimized_parameters, elbo_loss_curve = tfp.vi.fit_surrogate_posterior_stateless( \n", - " target_log_prob_fn=jd.log_prob, \n", - " initial_parameters=initial_parameters, \n", - " build_surrogate_posterior_fn=build_surrogate_fn, \n", + "optimized_parameters, elbo_loss_curve = tfp.vi.fit_surrogate_posterior_stateless(\n", + " target_log_prob_fn=jd.log_prob,\n", + " initial_parameters=initial_parameters,\n", + " build_surrogate_posterior_fn=build_surrogate_fn,\n", " optimizer=optax.adam(0.1), \n", " num_steps=num_variational_steps,\n", - " seed=fit_seed) \n", + " seed=fit_seed)\n", "plt.plot(elbo_loss_curve)\n", "plt.show()\n", "\n",