Skip to content

Rename log_prob to unnormalized_log_prob in perturbations (#1214)#1614

Open
jinukuntlaakhilakumargoud-web wants to merge 2 commits intogoogle-deepmind:mainfrom
jinukuntlaakhilakumargoud-web:rename-log-prob-1214
Open

Rename log_prob to unnormalized_log_prob in perturbations (#1214)#1614
jinukuntlaakhilakumargoud-web wants to merge 2 commits intogoogle-deepmind:mainfrom
jinukuntlaakhilakumargoud-web:rename-log-prob-1214

Conversation

@jinukuntlaakhilakumargoud-web

Resolves #1214.

This PR renames the log_prob method to unnormalized_log_prob in the Normal and Gumbel distribution classes under optax.perturbations, as these methods return the logarithm of the unnormalized probability density. This follows the naming convention used in TensorFlow Probability and prevents user confusion with normalized log probabilities.

Changes

  • Renamed Normal.log_prob -> Normal.unnormalized_log_prob
    • Renamed Gumbel.log_prob -> Gumbel.unnormalized_log_prob
    • Updated docstring reference in make_perturbed_fun
    • Updated internal usage in make_perturbed_fun

@google-cla
Copy link

google-cla bot commented Mar 5, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@rdyro
Copy link
Collaborator

rdyro commented Mar 5, 2026

Hey, thanks for the contribution, but this seems like an incomplete solution given the discussion in #1214

Can you argue for one of the more principled solutions discussed in the thread? It seems like a much better approach.

)

- Add function-based noise samplers: normal_noise_sampler, gumbel_noise_sampler
- Add unnormalized log-prob functions: unnormalized_normal_log_prob, unnormalized_gumbel_log_prob
- Deprecate Normal and Gumbel classes with DeprecationWarning
- Refactor make_perturbed_fun to accept noise_sampler and noise_log_prob kwargs
- Deprecate the noise parameter with DeprecationWarning
- Maintain backwards compatibility for existing code using the noise parameter
- Update tests to use the new function-based API
- Export new functions from optax.perturbations
@jinukuntlaakhilakumargoud-web
Copy link
Author

@rdyro Thanks for the feedback! I've updated the PR to implement the function-based API discussed in #1214. make_perturbed_fun now accepts noise_sampler and noise_log_prob as separate callable arguments instead of a noise object, following vroulet's preferred approach. Added standalone functions: normal_noise_sampler, unnormalized_normal_log_prob, gumbel_noise_sampler, unnormalized_gumbel_log_prob. The Normal and Gumbel classes and the noise parameter are deprecated with warnings but remain functional for backwards compatibility. All existing tests pass with the new API.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Rename optax.perturbations distribution method log_prob to unnormalized_log_prob

2 participants