FlexiChains by default #44

Open
penelopeysm wants to merge 16 commits into rsenne:main from penelopeysm:main

penelopeysm commented Jun 20, 2026

Contributor

Some questions worth pondering:

  • FlexiChains.FlexiChain{Symbol} is a massive faff. When I started writing FlexiChains it was very much with the Turing use case in mind, which is why FlexiChain{VarName} has a convenient alias VNChain. I think it probably makes sense to also make an alias for FlexiChain{Symbol}.
    Add SymChain alias penelopeysm/FlexiChains.jl#248
  • Presumably we want to update the entire test suite to use FlexiChains instead. The question would then be what is the minimal subset of MCMCChains tests that are worth retaining separately?

penelopeysm force-pushed the main branch 2 times, most recently from 28047fa to ebd73cc on June 20, 2026 23:24

penelopeysm left a comment

Contributor Author

@rsenne Thought I'd get your opinions on a couple of points before going further.

Comment thread ext/DynamicPPLExt.jl
Comment thread src/interface.jl Outdated
rsenne commented Jun 21, 2026

Owner

Thanks for this!

Here are my thoughts:

  1. Given that MCMCChains is no longer the default chains library for Turing, and soon won't be for this package either, I'm curious if we even need to support it. The cases where someone needs it could be for ArviZ, MCMCDiagnosticTools, etc.? And in any event, FlexiChains already has MCMCChains conversions, yes? So in practice, a user who really wanted to could just convert to an MCMCChains.Chains type via FlexiChains?
  2. If we do drop MCMCChains support, we could retype the function from Type{Tchn} to ::Type{FlexiChain{Symbol}}. Both methods would still be applicable, but since ParallelMCMC's method is more specific for the first four args and tied on the last, I think it would remove the ambiguity.

Also re your questions to ponder:

I think it probably makes sense to also make an alias

I agree

minimal subset of MCMCChains tests

If you agree with my above thoughts, none! :)

@penelopeysm

Contributor Author

You're totally correct, there's no real need to support MCMCChains. That makes life a lot easier!

penelopeysm marked this pull request as ready for review June 21, 2026 19:03
Copilot AI review requested due to automatic review settings June 21, 2026 19:03
@penelopeysm

Contributor Author

Okay, I got all the tests to pass locally. Apart from switching over to FlexiChains, I made a couple of design calls (which I hope are positive 🙂), which I'll discuss in the comments.

Copilot AI left a comment

Contributor

Pull request overview

This PR migrates ParallelMCMC’s chain output and tests from MCMCChains.Chains to FlexiChains, making FlexiChains chain types (e.g. VNChain / SymChain) the primary output target across samplers and extensions.

Changes:

  • Replaced MCMCChains with FlexiChains in core module dependencies and in the AbstractMCMC integration (bundle_samples, mcmcsample paths).
  • Updated Turing/DynamicPPL and interface tests to assert on FlexiChains APIs (parameters, extras, niters, stack=true indexing).
  • Updated package/test dependencies to drop MCMCChains and add FlexiChains.

Reviewed changes

Copilot reviewed 12 out of 13 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| test/test-Turing-Integration.jl | Switches Turing integration expectations from MCMCChains names/indexing to FlexiChains parameter/extras APIs. |
| test/test-GPU-Performance.jl | Removes unused MCMCChains test dependency. |
| test/test-GPU-AD-HVP.jl | Removes unused MCMCChains test dependency. |
| test/test-DEER-Turing-Logistic.jl | Updates chain type + indexing to FlexiChains (VNChain/SymChain, stack=true). |
| test/test-DEER-Interface.jl | Migrates interface tests from names(..., :internals)/size to FlexiChains extras/niters/nchains. |
| test/test-Adaptive-MALA.jl | Migrates assertions to FlexiChains chain types and extras. |
| test/test-AbstractMCMC-Interface.jl | Expands test coverage for symbol vs varname keys and vector-valued params under FlexiChains. |
| test/Project.toml | Adds FlexiChains and removes MCMCChains from the test environment deps. |
| src/ParallelMCMC.jl | Replaces using MCMCChains with using FlexiChains and re-exports sample. |
| src/interface.jl | Implements FlexiChains-based chain construction for all samplers; changes default param naming to a single vector-valued :x. |
| Project.toml | Adds FlexiChains, removes MCMCChains, and updates extension trigger list. |
| ext/LogDensityProblemsExt.jl | Updates docs to mention FlexiChains (but still contains outdated default-name text). |
| ext/DynamicPPLExt.jl | Switches DynamicPPL integration to FlexiChains, adds VNChain-only bundling, and overrides construction for VarName chains. |

Comment thread ext/DynamicPPLExt.jl
Comment thread ext/DynamicPPLExt.jl
Comment thread ext/LogDensityProblemsExt.jl Outdated
Comment thread src/interface.jl
Comment thread src/ParallelMCMC.jl
Comment on lines +38 to +40
# Re-exports for convenience
import AbstractMCMC: sample
export sample

Contributor Author

Oh, yes, I forgot to mention this. Honestly you probably do want to re-export sample... I don't think there's a single soul out there who wants to use ParallelMCMC without using sample.

MCMCChains used to re-export sample, so prior to this PR the basic README example would have worked fine. FlexiChains doesn't re-export it, so if you don't, the example will break.

Contributor Author

(also it's all one single sample function -- AbstractMCMC, MCMCChains, Turing, etc. etc all reexport the same StatsBase.sample so it won't cause conflicts)

Owner

yes let's re-export

codecov Bot commented Jun 21, 2026

Codecov Report

❌ Patch coverage is 73.13433% with 18 lines in your changes missing coverage. Please review.
✅ Project coverage is 79.48%. Comparing base (85f4b8f) to head (0b4c98f).
⚠️ Report is 2 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/interface.jl | 75.80% | 15 Missing ⚠️ |
| ext/DynamicPPLExt.jl | 40.00% | 3 Missing ⚠️ |

❌ Your patch check has failed because the patch coverage (73.13%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (79.48%) is below the target coverage (90.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #44      +/-   ##
==========================================
- Coverage   80.06%   79.48%   -0.59%     
==========================================
  Files           8        8              
  Lines        1164     1165       +1     
==========================================
- Hits          932      926       -6     
- Misses        232      239       +7     

Comment thread src/interface.jl
Comment on lines +633 to +651
# Wrap parameter names in the format that FlexiChains expects. Ref:
# https://pysm.dev/FlexiChains.jl/stable/arrays/#api-fromarray
if param_names === nothing
    param_names = model.param_names
end
wrapped_param_names = if param_names === nothing
    # Wasn't defined either as `param_names` or `model.param_names`, so make some up
    (to_parameter(:x) => (size(vals, 2),),)
else
    map(param_names) do n
        if n isa Pair
            to_parameter(n.first) => n.second
        elseif n isa TKey || n isa Symbol
            to_parameter(n)
        else
            throw(ArgumentError("param_names must be a collection of Pairs or $TKey, got $(typeof(n))"))
        end
    end
end

penelopeysm Jun 21, 2026

Contributor Author

The main change is to do with how param_names are handled. This is actually not really even a big change, but I think it accounts for most of the real code changes in this PR (the rest is boilerplate), so I may as well explain.

Previously you could pass a vector of Symbol, and then that gets stored in MCMCChains, nice and simple. Now I've 'upgraded' it to support grouping different parameters together. Let's say you have a model with dimension 3, but conceptually you really want the first parameter to be a scalar and the other two to be grouped together as a vector. FlexiChains has a constructor that lets you do that (https://pysm.dev/FlexiChains.jl/stable/arrays/#api-fromarray).

It's a faff for the user to write all that out, but this code block lets people strip away all the FlexiChains.Parameter stuff and just write something like param_names = (:x, :y => (2,)). On top of that, this also allows us to naturally 'upgrade' a Symbol chain to a VarName chain: if the user requests

chain_type=FlexiChain{VarName}, param_names=(:x, :y => (2,))

then the :x and :y will just be 'upgraded' to @varname(x) and @varname(y). (VarNames have some benefits over Symbols in that it's easier to index into part of them, say to get y[1] out.)

Finally, the default param name has been changed: instead of scalar-valued x[i] for i in 1:D, it's now just one vector-valued parameter, :x. That makes it easier to access the entire thing as a 3D array, while still allowing access to individual elements.

Given all this, I'd suggest that the docs should probably push users towards using VNChain as the go-to chain type, and say something like 'SymChain also works if you prefer it'? I haven't yet updated the docs, but if you're in agreement then I'll do it in the final round of this PR!
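
A minimal, self-contained sketch of the param_names normalisation described in this comment. It does not use FlexiChains: `Param`, `normalise_param_names`, and `ncols` are hypothetical stand-ins introduced only for illustration, whereas the PR wraps names through FlexiChains' own types. It only shows how a spec like (:x, :y => (2,)) could map onto the columns of a dimension-3 sample matrix.

```julia
# Hypothetical stand-in for a chain key wrapper; NOT the FlexiChains type.
struct Param{T}
    name::T
end

# Normalise a user-facing spec such as (:x, :y => (2,)) into
# (wrapped key => shape) pairs. A bare name is treated as a scalar, shape ().
function normalise_param_names(spec)
    return map(spec) do n
        if n isa Pair
            Param(n.first) => n.second
        elseif n isa Symbol
            Param(n) => ()
        else
            throw(ArgumentError("expected a Symbol or a name => shape Pair, got $(typeof(n))"))
        end
    end
end

# Number of columns of the flat sample matrix that one parameter occupies.
ncols(shape::Tuple) = isempty(shape) ? 1 : prod(shape)

spec = (:x, :y => (2,))
wrapped = normalise_param_names(spec)

# One scalar plus one length-2 vector should account for all 3 model dimensions.
total_dim = sum(ncols(p.second) for p in wrapped)
```

In this sketch a bare Symbol gets an empty shape; whether scalars are represented that way in the actual constructor is an assumption.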

Owner

Yes this all LGTM and will be a big improvement imo

Comment thread src/interface.jl
)
N = length(samples)
D = model.dim
for TKey in (Symbol, VarName)

Contributor Author

The cost of supporting both Symbol and VarName chains is having to do this loop over both keys (to avoid ambiguities), but since it's self-contained in this file, I think it's fine. The problem I was worried about last time was more about having to keep a separate but identical implementation in an MCMCChainsExt.
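
A minimal sketch of why looping over concrete key types sidesteps the ambiguity, assuming a simplified two-argument setting. `AbstractSampler`, `MySampler`, `Chain`, and `bundle` are hypothetical stand-ins (with String in place of VarName), not the real AbstractMCMC, ParallelMCMC, or FlexiChains definitions.

```julia
# Hypothetical stand-ins; none of these names come from the packages in this PR.
abstract type AbstractSampler end
struct MySampler <: AbstractSampler end
struct Chain{TKey} end

# "Upstream" methods: generic sampler, concrete chain type.
bundle(::AbstractSampler, ::Type{Chain{Symbol}}) = :upstream
bundle(::AbstractSampler, ::Type{Chain{String}}) = :upstream

# A "downstream" method written as
#     bundle(::MySampler, ::Type{Chain{T}}) where {T}
# would be more specific in argument 1 but less specific in argument 2,
# so bundle(MySampler(), Chain{Symbol}) would be ambiguous.
# Generating one method per concrete key type keeps argument 2 tied,
# so the downstream method is strictly more specific.
for TKey in (Symbol, String)
    @eval bundle(::MySampler, ::Type{Chain{$TKey}}) = (:downstream, $TKey)
end

result = bundle(MySampler(), Chain{Symbol})
```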

Owner

makes sense--loops are okay w/me but see below!

@test all(a -> a == 0.0 || a == 1.0, acc)
end

@testset "sample() with custom param_names" begin

Contributor Author

You can see some examples of how the param_names change works in this testset.

Comment thread src/interface.jl
s = samples[i]
vals[i, :] .= s.x
internals[i, 1] = s.logp
internals[i, 2] = s.accepted ? 1.0 : 0.0

penelopeysm Jun 21, 2026

Contributor Author

Last question: is there a need to store accepted as an F64, or is it just so that internals is concrete? If it's the latter, maybe we could think about making internals a NamedTuple like (logp = [logp_matrix...], accepted = [accepted_matrix...])? That would let us store it as the natural Bool. Not one for this PR though, it's a separate thing!
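
A rough sketch of the NamedTuple idea, assuming a single chain so that each statistic is a vector rather than a matrix. `Transition` and `collect_internals` are hypothetical names used only for illustration; they are not part of ParallelMCMC.

```julia
# Hypothetical per-iteration output type; NOT the package's actual struct.
struct Transition
    x::Vector{Float64}
    logp::Float64
    accepted::Bool
end

# Collect statistics into a NamedTuple of typed vectors, instead of one
# Matrix{Float64} in which `accepted` is encoded as 0.0 / 1.0.
function collect_internals(samples::Vector{Transition})
    return (
        logp = [s.logp for s in samples],
        accepted = [s.accepted for s in samples],
    )
end

samples = [
    Transition([0.1, 0.2], -1.5, true),
    Transition([0.1, 0.2], -1.5, false),
]
internals = collect_internals(samples)
```

Each field keeps its natural element type, and the NamedTuple as a whole is still concretely typed.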

Owner

Good catch. Nope no reason. I like your suggestion. I'll open a new issue about it

rsenne commented Jun 23, 2026

Owner

Okay so I was reading over the code and I had a thought (I'm mildly sleep deprived and this idea is only half-baked, so feel free to say I'm not making any sense). But I was looking at the functionality you put in _construct_flexichain, and this ability to rename params feels generically useful?

And so, if I understand the FlexiChains API correctly (see above for why I may not), the reason we don't hook into to_nt_and_stats/to_vnt_and_stats is because we need the functionality of our current bundle_samples to handle all the parameter names etc. This function then becomes ambiguous unless we use concrete types for arg 5 (hence the loops). But since the functionality in _construct_flexichain feels broadly useful, I'm curious if it should maybe live upstream? I.e., what if a pattern like this existed in FlexiChains:

# helper with same functionality as _construct_flexichain
function parameter_keys(spec, ::Type{TKey}) where {TKey}
    to_parameter(vn::VarName) = Parameter(vn)
    to_parameter(s::Symbol)   = Parameter(TKey <: VarName ? VarName{s}() : s)
    map(spec) do n
        n isa Pair             ? (to_parameter(n.first) => n.second) :
        (n isa TKey || n isa Symbol) ? to_parameter(n) :
        throw(ArgumentError("param_names must be Symbols, $TKey, or name=>shape Pairs"))
    end
end

and then add a param_names kwarg

function bundle_samples(transitions, m, s, state, ::Type{FlexiChain{Symbol}};
                        param_names=nothing, stats=missing, )
    if param_names === nothing
        #  unchanged
    else
        keys = parameter_keys(param_names, Symbol)           # Piece 1
        arr  = stack_parameter_values(transitions)           # iters × Σdims, from the hook values
        extras = stack_stats(transitions)
        return FlexiChain{Symbol}(hcat(arr, extras), (keys..., Extra.(...)...))
    end
end

Conditioned on me making a semblance of sense, the PR as is can still proceed and this can be addressed later as these changes would need to be implemented themselves. Otherwise PR is looking good thus far!

@penelopeysm

Contributor Author

the reason we don't hook into to_nt_and_stats/to_vnt_and_stats is because we need the functionality of our current bundle_samples to handle all the parameter names etc

Yes, that and also the fact that you want to generate all the samples at once. I think the current AbstractMCMC interface kind of assumes that you're generating samples iteratively. The to_nt_and_stats etc is designed to work with AbstractMCMC.step which you aren't using.

I'm curious if it should maybe live upstream?

I think maybe. I have a couple of mild objections:

  1. I'm a bit wary about introducing more keyword arguments to bundle_samples. None of these kwargs are defined in AbstractMCMC and so if FlexiChains methods introduce their own keyword arguments it creates an undocumented interface (a problem that's described here https://invenia.github.io/blog/2020/11/06/interfacetesting/). This is not a very strong objection because I think unfortunately FlexiChains is the only 'surviving' chains type in the Julia ecosystem, so the interface is de facto defined by FlexiChains. That's a bit of a stupid situation, but well.

  2. It doesn't actually fix the method ambiguities. So the main benefit of upstreaming it would be to allow for reuse of the pattern. But right now there's only one library doing this. I think this might potentially change in the future but I'd rather wait until there are at least two different consumers doing this before introducing this code.

Basically, let's not overengineer things yet. (In fact I feel like to_nt_and_stats is already a bit overengineered)

rsenne commented Jun 24, 2026

Owner

you want to generate all the samples at once

ah yes great point

assumes that you're generating samples iteratively

a very reasonable assumption until this package :)

let's not overengineer things yet

makes sense/sounds good to me

@penelopeysm

Contributor Author

Okay, docs updated too, so this should be reviewable & mergeable!

penelopeysm requested a review from rsenne June 25, 2026 18:08
Comment thread Project.toml
@@ -34,10 +33,9 @@ CUDA_Runtime_jll = "0.21"
DifferentiationInterface = "0.7.13"
DynamicPPL = "0.40.6, 0.41"

Owner

Suggested change
- DynamicPPL = "0.40.6, 0.41"
+ DynamicPPL = "0.41.6, 0.42"

Comment thread Project.toml

  [extensions]
- DynamicPPLExt = ["DynamicPPL", "ForwardDiff", "LogDensityProblems"]
+ DynamicPPLExt = ["DynamicPPL", "FlexiChains", "LogDensityProblems"]

Owner

see comment in DynamicPPLExt

Owner

also FlexiChains is a hard dep now so shouldn't need it here anyways

Comment thread Project.toml
DifferentiationInterface = "0.7.13"
DynamicPPL = "0.40.6, 0.41"
Enzyme = "0.13.146"
ForwardDiff = "1"

Owner

We should still declare a compat v for this even if removed from the Ext

Comment thread ext/DynamicPPLExt.jl
Comment on lines 35 to 36
"""
function ParallelMCMC.DensityModel(

Owner

So right now I have ForwardDiff as the default for this ext because of issue #25 and ForwardDiff worked as a workaround. I failed to document this, so that's my bad. Though I'm not sure it makes total sense to have this as a default, given there are other backends I didn't try. So at the very least we need to either re-add ForwardDiff as a gate for this ext or just remove this.

Comment on lines 144 to 145
```

Owner

love ty

Comment thread ext/DynamicPPLExt.jl
  state::$Tstate,
- chain_type::Type{MCMCChains.Chains};
+ chain_type::Type{VNChain},
  discard_warmup::Bool=false,

Owner

i think this should still be a kwarg

Owner

Suggested change
- discard_warmup::Bool=false,
+ chain_type::Type{VNChain};
+ discard_warmup::Bool=false,

Comment thread ext/DynamicPPLExt.jl
  spl::$Tspl,
  state::$Tstate,
- chain_type::Type{MCMCChains.Chains};
+ chain_type::Type{VNChain},

Owner

should we add some type of runtime guard for if a Turing user accidentally passes a SymChain?

Comment thread src/interface.jl
Comment on lines +885 to +897
  names = _default_param_names(model, D, param_names)

- internal_names = [:logp]
+ internal_names = [:logp]

- vals = Matrix{Float64}(undef, N, D)
- internals = Matrix{Float64}(undef, N, 1)
+ vals = Matrix{Float64}(undef, N, D)
+ internals = Matrix{Float64}(undef, N, 1)

- for i in 1:N
-     vals[i, :] .= samples[i].x
-     internals[i, 1] = samples[i].logp
- end
+ for i in 1:N
+     vals[i, :] .= samples[i].x
+     internals[i, 1] = samples[i].logp
+ end

- return MCMCChains.Chains(
-     hcat(vals, internals),
-     vcat(names, internal_names),
-     Dict(:parameters => names, :internals => internal_names),
- )
+ return _construct_flexichain($TKey, vals, internals, names, internal_names, model)

Owner

would this work?

Suggested change
- names = _default_param_names(model, D, param_names)
- internal_names = [:logp]
- vals = Matrix{Float64}(undef, N, D)
- internals = Matrix{Float64}(undef, N, 1)
- for i in 1:N
-     vals[i, :] .= samples[i].x
-     internals[i, 1] = samples[i].logp
- end
- return _construct_flexichain($TKey, vals, internals, names, internal_names, model)
+ internal_names = [:logp]
+ vals = Matrix{Float64}(undef, N, D)
+ internals = Matrix{Float64}(undef, N, 1)
+ for i in 1:N
+     vals[i, :] .= samples[i].x
+     internals[i, 1] = samples[i].logp
+ end
+ return _construct_flexichain($TKey, vals, internals, param_names, internal_names, model)

3 participants