FlexiChains by default #44
penelopeysm
left a comment
@rsenne Thought I'd get your opinions on a couple of points before going further.
|
Thanks for this! Here are my thoughts:
Also re your questions to ponder:
I agree
If you agree with my above thoughts, none! :) |
|
You're totally correct, there's no real need to support MCMCChains. That makes life a lot easier! |
|
Okay, I got all the tests to pass locally. Apart from switching over to FlexiChains, I made a couple of design calls (which I am hoping are positive 🙂), which I'll discuss in the comments.
Pull request overview
This PR migrates ParallelMCMC’s chain output and tests from MCMCChains.Chains to FlexiChains, making FlexiChains chain types (e.g. VNChain / SymChain) the primary output target across samplers and extensions.
Changes:
- Replaced MCMCChains with FlexiChains in core module dependencies and in the AbstractMCMC integration (bundle_samples, mcmcsample paths).
- Updated Turing/DynamicPPL and interface tests to assert on FlexiChains APIs (parameters, extras, niters, stack=true indexing).
- Updated package/test dependencies to drop MCMCChains and add FlexiChains.
Reviewed changes
Copilot reviewed 12 out of 13 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| test/test-Turing-Integration.jl | Switches Turing integration expectations from MCMCChains names/indexing to FlexiChains parameter/extras APIs. |
| test/test-GPU-Performance.jl | Removes unused MCMCChains test dependency. |
| test/test-GPU-AD-HVP.jl | Removes unused MCMCChains test dependency. |
| test/test-DEER-Turing-Logistic.jl | Updates chain type + indexing to FlexiChains (VNChain/SymChain, stack=true). |
| test/test-DEER-Interface.jl | Migrates interface tests from names(..., :internals)/size to FlexiChains extras/niters/nchains. |
| test/test-Adaptive-MALA.jl | Migrates assertions to FlexiChains chain types and extras. |
| test/test-AbstractMCMC-Interface.jl | Expands test coverage for symbol vs varname keys and vector-valued params under FlexiChains. |
| test/Project.toml | Adds FlexiChains and removes MCMCChains from the test environment deps. |
| src/ParallelMCMC.jl | Replaces using MCMCChains with using FlexiChains and re-exports sample. |
| src/interface.jl | Implements FlexiChains-based chain construction for all samplers; changes default param naming to a single vector-valued :x. |
| Project.toml | Adds FlexiChains, removes MCMCChains, and updates extension trigger list. |
| ext/LogDensityProblemsExt.jl | Updates docs to mention FlexiChains (but still contains outdated default-name text). |
| ext/DynamicPPLExt.jl | Switches DynamicPPL integration to FlexiChains, adds VNChain-only bundling, and overrides construction for VarName chains. |
# Re-exports for convenience
import AbstractMCMC: sample
export sample
Oh, yes, I forgot to mention this. Honestly, you probably do want to re-export sample. I don't think there's a single soul out there who wants to use ParallelMCMC without using sample.
MCMCChains used to reexport sample so prior to this PR the basic README example would have worked fine, but FlexiChains doesn't reexport it so if you don't, the example will break.
(Also, it's all one single sample function: AbstractMCMC, MCMCChains, Turing, etc. all re-export the same StatsBase.sample, so it won't cause conflicts.)
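The re-export pattern discussed here can be sketched with two toy modules. `Upstream` and `Downstream` are hypothetical stand-ins (for the package that owns `sample` and a package that re-exports it); nothing below uses the real AbstractMCMC or StatsBase code.

```julia
# Hypothetical stand-ins: `Upstream` owns `sample`, `Downstream` re-exports it.
module Upstream
export sample
sample(n::Integer) = collect(1:n)
end

module Downstream
import ..Upstream: sample   # bring the existing function into scope
export sample               # make it visible to anyone doing `using .Downstream`
end

using .Downstream

# Both names are bound to the same function object, so re-exporting it from
# several packages does not create a conflict.
same_function = sample === Upstream.sample
```

Because `import` binds the existing function rather than defining a new one, a user who loads several re-exporting packages still sees a single `sample`.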
Codecov Report
❌ Your patch check has failed because the patch coverage (73.13%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #44      +/-   ##
==========================================
- Coverage   80.06%   79.48%   -0.59%
==========================================
  Files           8        8
  Lines        1164     1165       +1
==========================================
- Hits          932      926       -6
- Misses        232      239       +7
# Wrap parameter names in the format that FlexiChains expects. Ref:
# https://pysm.dev/FlexiChains.jl/stable/arrays/#api-fromarray
if param_names === nothing
    param_names = model.param_names
end
wrapped_param_names = if param_names === nothing
    # Wasn't defined either as `param_names` or `model.param_names`, so make some up
    (to_parameter(:x) => (size(vals, 2),),)
else
    map(param_names) do n
        if n isa Pair
            to_parameter(n.first) => n.second
        elseif n isa TKey || n isa Symbol
            to_parameter(n)
        else
            throw(ArgumentError("param_names must be a collection of Pairs or $TKey, got $(typeof(n))"))
        end
    end
end
The main change is to do with how param_names are handled. This is actually not really even a big change, but I think it accounts for most of the real code changes in this PR (the rest is boilerplate), so I may as well explain.
Previously you could pass a vector of Symbol, and then that gets stored in MCMCChains, nice and simple. Now I've 'upgraded' it to support grouping different parameters together. Let's say you have a model with dimension 3, but conceptually you really want the first parameter to be a scalar and the other two to be grouped together as a vector. FlexiChains has a constructor that lets you do that (https://pysm.dev/FlexiChains.jl/stable/arrays/#api-fromarray).
It's a faff for the user to write all that out, but this code block lets people strip away all the FlexiChains.Parameter stuff and just write something like param_names = (:x, :y => (2,)). On top of that, this also allows us to naturally 'upgrade' a Symbol chain to a VarName chain: if the user requests
chain_type=FlexiChain{VarName}, param_names=(:x, :y => (2,))
then the :x and :y will just be 'upgraded' to @varname(x) and @varname(y). (VarNames have some benefits over Symbols in that it's easier to index into part of them, say to get y[1] out.)
Finally, the default param name has been changed: instead of scalar-valued x[i] for i in 1:D, it's now just one vector-valued parameter. That makes it easier to access the entire thing as a 3D array, while still allowing access to individual elements.
Given all this, I'd suggest that the docs should probably push users towards using VNChain as the go-to chain type, and say something like 'SymChain also works if you prefer it'? I haven't yet updated the docs, but if you're in agreement then I'll do it in the final round of this PR!
Yes this all LGTM and will be a big improvement imo
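To make the grouping idea concrete, here is a minimal sketch of how a spec like `(:x, :y => (2,))` carves up the columns of a samples matrix. `split_columns` is a hypothetical helper written for illustration only; it is not the FlexiChains constructor and not code from this PR.

```julia
# Hypothetical helper, not part of ParallelMCMC or FlexiChains.
# Splits an (iterations × columns) matrix into named blocks: a bare name
# takes one column, a `name => shape` pair takes prod(shape) columns.
function split_columns(vals::AbstractMatrix, spec)
    out = Dict{Symbol,Any}()
    col = 1
    for n in spec
        name = n isa Pair ? n.first : n
        shape = n isa Pair ? n.second : ()
        width = n isa Pair ? prod(shape) : 1
        block = vals[:, col:(col + width - 1)]
        # Scalars become a vector over iterations; shaped parameters keep their shape.
        out[name] = n isa Pair ? reshape(block, size(vals, 1), shape...) : vec(block)
        col += width
    end
    col - 1 == size(vals, 2) ||
        throw(DimensionMismatch("spec covers $(col - 1) columns, matrix has $(size(vals, 2))"))
    return out
end

vals = [1.0 2.0 3.0; 4.0 5.0 6.0]          # 2 iterations, dimension 3
groups = split_columns(vals, (:x, :y => (2,)))
```

Here `groups[:x]` holds the first column as a scalar parameter and `groups[:y]` holds the remaining two columns as one vector-valued parameter.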
)
N = length(samples)
D = model.dim
for TKey in (Symbol, VarName)
The cost of supporting both Symbol and VarName chains is having to do this loop over both keys (to avoid ambiguities), but since it's self-contained in this file, I think it's fine. The problem I was worried about last time was more about having to have a separate but identical implementation in an MCMCChainsExt.
makes sense--loops are okay w/me but see below!
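The loop-plus-`@eval` pattern referred to above can be sketched as follows. `Box` and `describe` are hypothetical names, and `String` stands in for `VarName` so the snippet runs without any packages.

```julia
# Hypothetical stand-in for a chain type parameterised by its key type.
struct Box{K} end

# Generate one concrete method per key type instead of a single
# `::Type{Box{K}} where {K}` method. Each generated method has a fully
# concrete signature for this argument.
for TKey in (Symbol, String)
    @eval function describe(::Type{Box{$TKey}}, n::Int)
        return string("Box keyed by ", $TKey, " with ", n, " entries")
    end
end

symbol_description = describe(Box{Symbol}, 3)
string_description = describe(Box{String}, 1)
```

The trade-off is a small amount of metaprogramming in exchange for methods that are as specific as possible on the chain-type argument.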
    @test all(a -> a == 0.0 || a == 1.0, acc)
end

@testset "sample() with custom param_names" begin
You can see some examples of how the param_names change works in this testset.
s = samples[i]
vals[i, :] .= s.x
internals[i, 1] = s.logp
internals[i, 2] = s.accepted ? 1.0 : 0.0
Last question: is there a need to store accepted as a Float64, or is it just so that internals is concrete? If it's the latter, maybe we could think about making internals a NamedTuple like (logp = [logp_matrix...], accepted = [accepted_matrix...])? That would let us store it as the natural Bool. Not one for this PR though, it's a separate thing!
Good catch. Nope no reason. I like your suggestion. I'll open a new issue about it
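A minimal sketch of the NamedTuple idea, assuming sample records with `logp` and `accepted` fields as in the snippet above. The sample values are made up for illustration, and this uses plain vectors rather than matrices.

```julia
# Hypothetical sample records with the same field names as in the diff.
samples = [
    (x = [0.1, 0.2], logp = -1.5, accepted = true),
    (x = [0.3, 0.4], logp = -0.7, accepted = false),
]

# One container per statistic, each keeping its natural element type,
# instead of a single Matrix{Float64} that forces Bool -> Float64.
internals = (
    logp = [s.logp for s in samples],
    accepted = [s.accepted for s in samples],
)

n_accepted = count(internals.accepted)
```

With this layout `internals.accepted` is a `Vector{Bool}`, so the `? 1.0 : 0.0` conversion is no longer needed.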
Okay so I was reading over the code and I had a thought (I'm mildly sleep deprived and this idea is only half-baked so feel free to say I'm not making any sense). But I was looking at the functionality you put in And so I guess IIUC the
# helper with same functionality as _construct_flexichain
function parameter_keys(spec, ::Type{TKey}) where {TKey}
    to_parameter(vn::VarName) = Parameter(vn)
    to_parameter(s::Symbol) = Parameter(TKey <: VarName ? VarName{s}() : s)
    map(spec) do n
        n isa Pair ? (to_parameter(n.first) => n.second) :
        (n isa TKey || n isa Symbol) ? to_parameter(n) :
        throw(ArgumentError("param_names must be Symbols, $TKey, or name=>shape Pairs"))
    end
end
and then add a
function bundle_samples(transitions, m, s, state, ::Type{FlexiChain{Symbol}};
                        param_names=nothing, stats=missing, …)
    if param_names === nothing
        # unchanged
    else
        keys = parameter_keys(param_names, Symbol)  # Piece 1
        arr = stack_parameter_values(transitions)   # iters × Σdims, from the hook values
        extras = stack_stats(transitions)
        return FlexiChain{Symbol}(hcat(arr, extras), (keys..., Extra.(...)...))
    end
end
Conditioned on me making a semblance of sense, the PR as is can still proceed and this can be addressed later as these changes would need to be implemented themselves. Otherwise PR is looking good thus far!
Yes, that and also the fact that you want to generate all the samples at once. I think the current AbstractMCMC interface kind of assumes that you're generating samples iteratively. The
I think maybe. I have a couple of mild objections:
Basically, let's not overengineer things yet. (In fact I feel like |
ah yes great point
a very reasonable assumption until this package :)
makes sense/sounds good to me |
|
Okay, docs updated too, so this should be reviewable & mergeable! |
@@ -34,10 +33,9 @@ CUDA_Runtime_jll = "0.21"
DifferentiationInterface = "0.7.13"
DynamicPPL = "0.40.6, 0.41"
Suggested change:
- DynamicPPL = "0.40.6, 0.41"
+ DynamicPPL = "0.41.6, 0.42"
[extensions]
- DynamicPPLExt = ["DynamicPPL", "ForwardDiff", "LogDensityProblems"]
+ DynamicPPLExt = ["DynamicPPL", "FlexiChains", "LogDensityProblems"]
Also, FlexiChains is a hard dep now, so it shouldn't need to be listed here anyway.
DifferentiationInterface = "0.7.13"
DynamicPPL = "0.40.6, 0.41"
Enzyme = "0.13.146"
ForwardDiff = "1"
We should still declare a compat version for this even if it's removed from the ext.
| """ | ||
function ParallelMCMC.DensityModel(
So right now I have ForwardDiff as the default for this ext because of issue #25, where ForwardDiff worked as a workaround. I failed to document this, so that's my bad. That said, I'm not sure it makes total sense to have this as a default, given there are other backends I didn't try. So at the very least we need to either re-add ForwardDiff as a gate for this ext or just remove this.
| ``` | ||
  state::$Tstate,
- chain_type::Type{MCMCChains.Chains};
+ chain_type::Type{VNChain},
  discard_warmup::Bool=false,
I think this should still be a kwarg.
Suggested change:
- discard_warmup::Bool=false,
+ chain_type::Type{VNChain};
+ discard_warmup::Bool=false,
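The difference the semicolon makes can be shown with two toy functions. `positional` and `keyword` are hypothetical names used only for this illustration.

```julia
# With a comma, `discard_warmup` is an optional *positional* argument.
positional(chain_type, discard_warmup::Bool=false) = discard_warmup

# With a semicolon, `discard_warmup` is a *keyword* argument.
keyword(chain_type; discard_warmup::Bool=false) = discard_warmup

# Callers that pass `discard_warmup=true` by name only work with the second form.
kw_result = keyword(Int; discard_warmup=true)

# The positional form rejects the same call with a MethodError.
positional_error = try
    positional(Int; discard_warmup=true)
    nothing
catch err
    err
end
```

So replacing `;` with `,` in the signature silently changes the calling convention for every option that follows it.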
  spl::$Tspl,
  state::$Tstate,
- chain_type::Type{MCMCChains.Chains};
+ chain_type::Type{VNChain},
Should we add some type of runtime guard in case a Turing user accidentally passes a SymChain?
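One possible shape for such a guard, sketched with stand-in types: a generic fallback method that throws a readable error for any key type other than the supported one. `KeyedChain` and `bundle` are hypothetical, and `String` plays the role of `VarName` so the snippet runs without any packages.

```julia
# Hypothetical stand-in for a chain type parameterised by its key type.
struct KeyedChain{K} end

# Supported case: the specific method wins dispatch.
bundle(::Type{KeyedChain{String}}) = "bundled"

# Fallback for every other key type: fail early with a clear message
# rather than a bare MethodError.
function bundle(::Type{KeyedChain{K}}) where {K}
    throw(ArgumentError("this model requires String keys, got KeyedChain{$K}"))
end

guard_error = try
    bundle(KeyedChain{Symbol})
    nothing
catch err
    err
end
```

Whether a fallback like this would interact badly with the per-key-type methods generated elsewhere in the PR is something that would need checking against the real method table.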
  names = _default_param_names(model, D, param_names)

- internal_names = [:logp]
+ internal_names = [:logp]

- vals = Matrix{Float64}(undef, N, D)
- internals = Matrix{Float64}(undef, N, 1)
+ vals = Matrix{Float64}(undef, N, D)
+ internals = Matrix{Float64}(undef, N, 1)

- for i in 1:N
-     vals[i, :] .= samples[i].x
-     internals[i, 1] = samples[i].logp
- end
+ for i in 1:N
+     vals[i, :] .= samples[i].x
+     internals[i, 1] = samples[i].logp
+ end

- return MCMCChains.Chains(
-     hcat(vals, internals),
-     vcat(names, internal_names),
-     Dict(:parameters => names, :internals => internal_names),
- )
+ return _construct_flexichain($TKey, vals, internals, names, internal_names, model)
Would this work?
Suggested change:
internal_names = [:logp]
vals = Matrix{Float64}(undef, N, D)
internals = Matrix{Float64}(undef, N, 1)
for i in 1:N
    vals[i, :] .= samples[i].x
    internals[i, 1] = samples[i].logp
end
return _construct_flexichain($TKey, vals, internals, param_names, internal_names, model)
Some questions worth pondering:
FlexiChains.FlexiChain{Symbol} is a massive faff. When I started writing FlexiChains it was very much with the Turing use case in mind, which is why FlexiChain{VarName} has a convenient alias VNChain. I think it probably makes sense to also make an alias for FlexiChain{Symbol}.
Add SymChain alias penelopeysm/FlexiChains.jl#248