Rename log_prob to unnormalized_log_prob in perturbations (#1214)#1614
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
6043124 to
e023130
Compare
|
Hey, thanks for the contribution, but this seems like an incomplete solution given the discussion in #1214 Can you argue for one of the more principled solutions discussed in the thread? It seems like a much better approach. |
) - Add function-based noise samplers: normal_noise_sampler, gumbel_noise_sampler - Add unnormalized log-prob functions: unnormalized_normal_log_prob, unnormalized_gumbel_log_prob - Deprecate Normal and Gumbel classes with DeprecationWarning - Refactor make_perturbed_fun to accept noise_sampler and noise_log_prob kwargs - Deprecate the noise parameter with DeprecationWarning - Maintain backwards compatibility for existing code using the noise parameter - Update tests to use the new function-based API - Export new functions from optax.perturbations
|
@rdyro Thanks for the feedback! I've updated the PR to implement the function-based API discussed in #1214. make_perturbed_fun now accepts noise_sampler and noise_log_prob as separate callable arguments instead of a noise object, following vroulet's preferred approach. Added standalone functions: normal_noise_sampler, unnormalized_normal_log_prob, gumbel_noise_sampler, unnormalized_gumbel_log_prob. The Normal and Gumbel classes and the noise parameter are deprecated with warnings but remain functional for backwards compatibility. All existing tests pass with the new API. |
Resolves #1214.
This PR renames the log_prob method to unnormalized_log_prob in the Normal and Gumbel distribution classes under optax.perturbations, as these methods return the logarithm of the unnormalized probability density. This follows the naming convention used in TensorFlow Probability and prevents user confusion with normalized log probabilities.
Changes