FlexiChains by default #44
base: main
Changes from all commits
b0cd534
801aa3a
bab1b19
e9c1371
0fefe34
966bb63
02634b6
3471cdd
9dfadb8
414b497
8531ba2
84ae244
07132ae
e9e64fb
588966a
0b4c98f
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -8,21 +8,20 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | |||||
| AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" | ||||||
| CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||||||
| DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | ||||||
| FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7" | ||||||
| LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||||||
| MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" | ||||||
| Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||||||
| Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||||||
|
|
||||||
| [weakdeps] | ||||||
| DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" | ||||||
| Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" | ||||||
| ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||||||
| LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" | ||||||
| Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||||||
| Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||||||
|
|
||||||
| [extensions] | ||||||
| DynamicPPLExt = ["DynamicPPL", "ForwardDiff", "LogDensityProblems"] | ||||||
| DynamicPPLExt = ["DynamicPPL", "FlexiChains", "LogDensityProblems"] | ||||||
| EnzymeExt = "Enzyme" | ||||||
| LogDensityProblemsExt = "LogDensityProblems" | ||||||
|
|
||||||
|
|
@@ -34,10 +33,9 @@ CUDA_Runtime_jll = "0.21" | |||||
| DifferentiationInterface = "0.7.13" | ||||||
| DynamicPPL = "0.40.6, 0.41" | ||||||
|
Owner
Suggested change
|
||||||
| Enzyme = "0.13.146" | ||||||
| ForwardDiff = "1" | ||||||
|
Owner
We should still declare a compat version for this, even if it is removed from the extension.
|
||||||
| FlexiChains = "0.6.6" | ||||||
| LinearAlgebra = "1" | ||||||
| LogDensityProblems = "2" | ||||||
| MCMCChains = "7.7.0" | ||||||
| Mooncake = "0.5.26" | ||||||
| Random = "1" | ||||||
| Statistics = "1" | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,7 @@ All samplers take a [`DensityModel`](@ref) as their first argument. A `DensityM | |
| - `dim::Int` — dimension of the parameter space | ||
|
|
||
| ```julia | ||
| using ParallelMCMC, MCMCChains | ||
| using ParallelMCMC, FlexiChains | ||
| using ADTypes, Enzyme | ||
|
|
||
| # Banana-shaped target in 2-D | ||
|
|
@@ -40,7 +40,7 @@ sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag, damping=0.5, | |
| backend=AutoEnzyme()) | ||
|
|
||
| chain = sample(model, sampler, 500; | ||
| chain_type=MCMCChains.Chains, progress=true) | ||
| chain_type=VNChain, progress=true) | ||
| ``` | ||
|
|
||
| `sample` requests 500 total samples. Internally, DEER solves trajectories of length `T=64` and returns each column of the solved trajectory as a separate sample. When the trajectory is exhausted a new noise tape is drawn and DEER re-solves from the last state. | ||
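As a rough illustration of the batching described above, the arithmetic can be sketched as follows. This assumes that every DEER solve yields up to `T` samples and that unused columns of the final solve are discarded; neither detail is confirmed here. `n_requested`, `n_solves` and `n_from_last` are illustrative names, not part of the ParallelMCMC API.

```julia
# Illustrative arithmetic only; this is not ParallelMCMC's internal logic.
n_requested = 500   # total samples passed to `sample`
T = 64              # trajectory length per DEER solve

# Assumption: each solve contributes up to T samples, so the number of
# solves is the ceiling of n_requested / T.
n_solves = cld(n_requested, T)

# Samples taken from the final solve under the same assumption.
n_from_last = n_requested - (n_solves - 1) * T

println("solves: ", n_solves, ", samples used from last solve: ", n_from_last)
```

Under these assumptions, a larger `T` means fewer re-solves for the same number of requested samples, at the cost of a longer trajectory per solve.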
|
|
@@ -90,7 +90,7 @@ ParallelMCMC.jl integrates with Turing.jl models through the `DynamicPPL` and `L | |
| Load `DynamicPPL` (part of Turing.jl) and a single-argument `DensityModel` constructor becomes available: | ||
|
|
||
| ```julia | ||
| using Turing, ParallelMCMC, MCMCChains | ||
| using Turing, ParallelMCMC, FlexiChains | ||
|
|
||
| @model function normal_model(y) | ||
| μ ~ Normal(0.0, 1.0) | ||
|
|
@@ -100,7 +100,7 @@ end | |
| model = DensityModel(normal_model(1.5)) # param_names=[:μ] extracted automatically | ||
|
|
||
| chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()), 500; | ||
| chain_type=MCMCChains.Chains) | ||
| chain_type=VNChain) | ||
| ``` | ||
|
|
||
| Much like Turing's own samplers, the resulting chain will always have parameters in the original (possibly constrained) space, even though the MCMC sampling itself is performed in unconstrained space. | ||
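To make the constrained versus unconstrained distinction concrete, here is a minimal sketch for a single positive parameter. It assumes a log/exp transform, which is a common choice for positive variables; the transform actually applied to a given Turing model is not shown in this diff. All function names below are hypothetical.

```julia
# Minimal sketch: sample in unconstrained space, report in constrained space.
# Assumption: a positive parameter σ is unconstrained via z = log(σ).
to_constrained(z) = exp(z)

# Log-density of σ ~ Exponential(1) in constrained space (σ > 0).
logp_constrained(σ) = -σ

# Log-density in unconstrained space: add the log-Jacobian log|dσ/dz| = z.
logp_unconstrained(z) = logp_constrained(to_constrained(z)) + z

z_draws = [-1.0, 0.0, 0.5]            # values the sampler would move through
σ_draws = to_constrained.(z_draws)    # values the chain would report

println(σ_draws)
```

The sampler only ever sees `logp_unconstrained`, while the stored draws are mapped back through `to_constrained`, so every reported value is positive.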
|
|
@@ -116,7 +116,7 @@ directly with DynamicPPL's `adtype` interface: | |
|
|
||
| ```julia | ||
| using Turing, LogDensityProblems, ADTypes | ||
| using ParallelMCMC, MCMCChains | ||
| using ParallelMCMC, FlexiChains | ||
|
|
||
| ld = DynamicPPL.LogDensityFunction( | ||
| normal_model(1.5), | ||
|
|
@@ -140,10 +140,48 @@ All samplers support `MCMCThreads()`. Start Julia with multiple threads (e.g. ` | |
| ```julia | ||
| chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()), | ||
| MCMCThreads(), 500, 4; | ||
| chain_type=MCMCChains.Chains) | ||
| chain_type=VNChain) | ||
| ``` | ||
|
|
||
|
Comment on lines 144 to 145
Owner
Love it, thank you.
|
||
| `MCMCChains` computes R-hat and ESS across chains automatically. | ||
| `FlexiChains` computes R-hat and ESS across chains automatically: | ||
|
|
||
| ```julia | ||
| ess(chain) | ||
| ``` | ||
|
|
||
| To calculate intra-chain metrics, you can pass `dims=:iter` [as described in the FlexiChains docs](https://pysm.dev/FlexiChains.jl/stable/summarising/#Individual-statistics): | ||
|
|
||
| ```julia | ||
| ess(chain; dims=:iter) | ||
| ``` | ||
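For readers unfamiliar with what a cross-chain diagnostic measures, the textbook Gelman-Rubin R-hat can be sketched in a few lines. This is the classic non-split, non-rank-normalised formula and is only an illustration; it is not claimed to be the estimator FlexiChains uses. `basic_rhat` is a hypothetical name.

```julia
using Statistics

# Classic Gelman-Rubin R-hat for one parameter.
# `draws` has one row per iteration and one column per chain.
function basic_rhat(draws::AbstractMatrix{<:Real})
    n = size(draws, 1)
    chain_means = vec(mean(draws; dims=1))
    chain_vars = vec(var(draws; dims=1))
    W = mean(chain_vars)            # average within-chain variance
    B_over_n = var(chain_means)     # variance of the per-chain means
    var_plus = (n - 1) / n * W + B_over_n
    return sqrt(var_plus / W)
end

# Two chains with identical draws, so the between-chain term vanishes.
draws = [1.0 1.0; 2.0 2.0; 3.0 3.0; 4.0 4.0]
println(basic_rhat(draws))
```

Values close to 1 suggest the chains agree; the between-chain term grows when chains settle in different regions.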
|
|
||
| --- | ||
|
|
||
| ## [Specifying parameter names](@id parameter-names) | ||
|
|
||
| For manually constructed `DensityModel`s, you can optionally specify parameter names with the `param_names` keyword. | ||
| The resulting `FlexiChain` object will then have named entries for each parameter. | ||
|
|
||
| If you do not specify `param_names`, the chain will store a single vector-valued parameter called `x` of length `D`. | ||
|
|
||
| You can specify parameter names as a collection of either: | ||
|
|
||
| - `Symbol`s (e.g. `[:x1, :x2]`); | ||
| - `VarName`s (e.g. `[@varname(x[1]), @varname(x[2])]`) if `chain_type=VNChain`; or | ||
| - A tuple of the above, *plus* a size. In this case, a total of `prod(size)` entries will be allocated to the named parameter, and the results in the chain will be reshaped to that size. | ||
|
|
||
| The above can be mixed and matched as desired, as long as the total number of parameters matches the dimension of the model. | ||
| For example: | ||
|
|
||
| ```julia | ||
| # Three scalar parameters, called `x1` through `x3` | ||
| model = DensityModel(...; param_names=[:x1, :x2, :x3]) | ||
|
|
||
| # One scalar parameter called `x`, and a 1x2 matrix parameter called `y` | ||
| model = DensityModel(...; param_names=[:x, (:y, (1, 2))]) | ||
| ``` | ||
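The rule that the names must add up to the model dimension can be checked with a small counting helper. This is a sketch under the assumption that each plain name covers one entry and each `(name, size)` tuple covers `prod(size)` entries, as described above. `entry_count` and `total_entries` are hypothetical helpers, not part of ParallelMCMC, and `VarName` entries are left out so that the example runs without extra packages.

```julia
# Hypothetical helpers: count the scalar entries a `param_names`
# specification covers. Not part of ParallelMCMC's API.
entry_count(name::Symbol) = 1
entry_count(spec::Tuple{Any,Tuple}) = prod(spec[2])   # (name, size) pair

total_entries(param_names) = sum(entry_count, param_names)

names_scalar = [:x1, :x2, :x3]        # three scalars
names_mixed = [:x, (:y, (1, 2))]      # one scalar plus a 1x2 matrix

println(total_entries(names_scalar))
println(total_entries(names_mixed))
```

Both specifications describe a three-dimensional model, so either would be valid for a `DensityModel` with `dim = 3`.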
|
|
||
| For Turing.jl models, parameter names are automatically derived from the model and do not need to be specified manually. | ||
|
|
||
| --- | ||
|
|
||
|
|
@@ -158,14 +196,14 @@ chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()), | |
| ```julia | ||
| # Step 1: find a good step size with adaptive MALA | ||
| baseline = sample(model, AdaptiveMALASampler(0.1; n_warmup=500), 600; | ||
| chain_type=MCMCChains.Chains, discard_warmup=true) | ||
| chain_type=VNChain, discard_warmup=true) | ||
|
|
||
| # Read off the frozen step size from the last internal value | ||
| eps_tuned = baseline[end, :step_size, 1] | ||
|
|
||
| # Step 2: run DEER with the tuned step size | ||
| chain = sample(model, ParallelMALASampler(eps_tuned; T=64, backend=AutoEnzyme()), 2_000; | ||
| chain_type=MCMCChains.Chains) | ||
| chain_type=VNChain) | ||
| ``` | ||
|
|
||
| See the [`MALASampler`](@ref) and [`AdaptiveMALASampler`](@ref) reference pages for the full keyword listing. | ||
|
penelopeysm marked this conversation as resolved.
|
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,7 +4,7 @@ using ParallelMCMC | |||||||
| using ADTypes: ADTypes | ||||||||
| using DynamicPPL: DynamicPPL | ||||||||
| using AbstractMCMC: AbstractMCMC | ||||||||
| using MCMCChains: MCMCChains | ||||||||
| using FlexiChains: FlexiChain, VarName, VNChain | ||||||||
| using LogDensityProblems: LogDensityProblems | ||||||||
|
|
||||||||
| """ | ||||||||
|
|
@@ -14,13 +14,12 @@ Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a | |||||||
| `DensityModel`, automatically extracting parameter names and wiring up gradient | ||||||||
| computation via DynamicPPL's `adtype` interface. | ||||||||
|
|
||||||||
| Requires `DynamicPPL`, `ForwardDiff`, and `LogDensityProblems` to be loaded (these are the | ||||||||
| weak-dependency triggers for this extension; `ForwardDiff` is what backs the default | ||||||||
| `AutoForwardDiff()` AD path). | ||||||||
| Requires `DynamicPPL` and `LogDensityProblems` to be loaded (these are the weak-dependency | ||||||||
| triggers for this extension), plus any AD backend that is used. | ||||||||
|
|
||||||||
| # Example | ||||||||
| ```julia | ||||||||
| using Turing, ParallelMCMC, MCMCChains | ||||||||
| using Turing, ParallelMCMC, FlexiChains | ||||||||
|
penelopeysm marked this conversation as resolved.
|
||||||||
|
|
||||||||
| @model function mymodel(y) | ||||||||
| μ ~ Normal(0, 1) | ||||||||
|
|
@@ -31,7 +30,7 @@ end | |||||||
| # and `using` the corresponding package (Enzyme, Mooncake). | ||||||||
| model = DensityModel(mymodel(1.5)) | ||||||||
| chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; | ||||||||
| chain_type=MCMCChains.Chains, discard_warmup=true, progress=true) | ||||||||
| chain_type=FlexiChains.VNChain, discard_warmup=true, progress=true) | ||||||||
| ``` | ||||||||
| """ | ||||||||
| function ParallelMCMC.DensityModel( | ||||||||
|
Comment on lines 35 to 36
Owner
Right now I have ForwardDiff as the default for this extension because of issue #25, where ForwardDiff worked as a workaround. I failed to document this, so that's my bad. That said, I'm not sure it makes sense to keep it as the default, given that there are other backends I didn't try. At the very least, we need to either re-add ForwardDiff as a gate for this extension or just remove this.
|
||||||||
|
|
@@ -106,40 +105,34 @@ for (Ttrans, Tspl, Tstate) in ( | |||||||
| model::DensityModelLDF, | ||||||||
| spl::$Tspl, | ||||||||
| state::$Tstate, | ||||||||
| chain_type::Type{MCMCChains.Chains}; | ||||||||
| chain_type::Type{VNChain}, | ||||||||
|
Owner
Should we add some kind of runtime guard in case a Turing user accidentally passes a `SymChain`?
|
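One possible shape for such a guard is a fallback method that throws a descriptive error. The sketch below uses stand-in types so that it runs without FlexiChains loaded; it assumes the two chain aliases differ only in their key type, which is not confirmed by this diff. `check_chain_type` and every `Stub*` name are hypothetical.

```julia
# Stand-in types so the sketch runs without FlexiChains.
# In the extension these would be FlexiChains.VNChain and FlexiChains.SymChain.
struct StubVarName end
struct StubFlexiChain{K} end
const StubVNChain = StubFlexiChain{StubVarName}
const StubSymChain = StubFlexiChain{Symbol}

# Hypothetical guard: accept the VarName-keyed chain and reject anything
# else with an explicit message instead of a bare MethodError.
check_chain_type(::Type{StubVNChain}) = nothing
function check_chain_type(::Type{T}) where {T}
    throw(ArgumentError("Turing-backed models need chain_type=VNChain, got $T"))
end

check_chain_type(StubVNChain)    # passes silently
try
    check_chain_type(StubSymChain)
catch err
    println("rejected: ", err.msg)
end
```

Because the check relies on dispatch, it adds no cost on the valid `VNChain` path.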
||||||||
| discard_warmup::Bool=false, | ||||||||
|
Owner
I think this should still be a kwarg.
Owner
Suggested change
|
||||||||
| kwargs..., | ||||||||
| ) | ||||||||
| ts = discard_warmup ? filter(t -> !is_warmup(t), ts) : ts | ||||||||
| return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) | ||||||||
| pwss = map(ts) do t | ||||||||
| # Note: This assumes that there is always a field called t.x. This is currently true | ||||||||
| # of all samplers in ParallelMCMC | ||||||||
| DynamicPPL.ParamsWithStats(t.x, model.logdensity.ld, getstats(t)) | ||||||||
| end | ||||||||
| return AbstractMCMC.from_samples(VNChain, hcat(pwss)) | ||||||||
| end | ||||||||
| end | ||||||||
| end | ||||||||
|
|
||||||||
| function make_processed_dynamicppl_chain( | ||||||||
| ::Type{Tchain}, ts::Vector{<:ParallelMCMCTransitionTypes}, model::DensityModelLDF | ||||||||
| ) where {Tchain} | ||||||||
| pwss = map(ts) do t | ||||||||
| # Note: This assumes that there is always a field called t.x. This is currently true | ||||||||
| # of all samplers in ParallelMCMC | ||||||||
| DynamicPPL.ParamsWithStats(t.x, model.logdensity.ld, getstats(t)) | ||||||||
| end | ||||||||
| return AbstractMCMC.from_samples(Tchain, hcat(pwss)) | ||||||||
| end | ||||||||
|
|
||||||||
| function ParallelMCMC._construct_chain( | ||||||||
| ::Type{MCMCChains.Chains}, | ||||||||
| function ParallelMCMC._construct_flexichain( | ||||||||
| ::Type{VarName}, | ||||||||
| vals::AbstractMatrix{<:Real}, | ||||||||
| internals::AbstractMatrix{<:Real}, | ||||||||
| ::Vector{Symbol}, | ||||||||
| ::Any, | ||||||||
| internal_names::Vector{Symbol}, | ||||||||
| model::DensityModelLDF, | ||||||||
| ) | ||||||||
| pwss = map(zip(eachrow(vals), eachrow(internals))) do (val, internal) | ||||||||
| stats = NamedTuple{Tuple(internal_names)}(internal) | ||||||||
| DynamicPPL.ParamsWithStats(val, model.logdensity.ld, stats) | ||||||||
|
penelopeysm marked this conversation as resolved.
|
||||||||
| end | ||||||||
| return AbstractMCMC.from_samples(MCMCChains.Chains, hcat(pwss)) | ||||||||
| return AbstractMCMC.from_samples(VNChain, hcat(pwss)) | ||||||||
| end | ||||||||
|
|
||||||||
| end # module | ||||||||
See the comment in DynamicPPLExt.
Also, FlexiChains is a hard dependency now, so it shouldn't be needed here anyway.