diff --git a/Project.toml b/Project.toml index d37e9cb..165857f 100644 --- a/Project.toml +++ b/Project.toml @@ -8,21 +8,20 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -DynamicPPLExt = ["DynamicPPL", "ForwardDiff", "LogDensityProblems"] +DynamicPPLExt = ["DynamicPPL", "FlexiChains", "LogDensityProblems"] EnzymeExt = "Enzyme" LogDensityProblemsExt = "LogDensityProblems" @@ -34,10 +33,9 @@ CUDA_Runtime_jll = "0.21" DifferentiationInterface = "0.7.13" DynamicPPL = "0.40.6, 0.41" Enzyme = "0.13.146" -ForwardDiff = "1" +FlexiChains = "0.6.6" LinearAlgebra = "1" LogDensityProblems = "2" -MCMCChains = "7.7.0" Mooncake = "0.5.26" Random = "1" Statistics = "1" diff --git a/README.md b/README.md index 76ccce2..6398069 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ The approach and its scaling tricks (stochastic Hutchinson Jacobian estimators, | [`MALASampler`](src/interface.jl) | Baseline — sequential MALA with a fixed step size | | [`AdaptiveMALASampler`](src/interface.jl) | Baseline — sequential MALA with dual-averaging step-size adaptation | -All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`MCMCChains.Chains`](https://github.com/TuringLang/MCMCChains.jl) objects, so they slot into existing Turing.jl / AbstractMCMC workflows. +All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`FlexiChains`](https://pysm.dev/FlexiChains.jl) objects, so they slot into existing Turing.jl / AbstractMCMC workflows. ### Quick start @@ -53,7 +53,7 @@ pkg> add ParallelMCMC ``` ```julia -using ParallelMCMC, MCMCChains +using ParallelMCMC, FlexiChains using ADTypes, Enzyme logp(x) = -0.5 * sum(abs2, x) # 2-D standard normal @@ -63,7 +63,7 @@ model = DensityModel(logp, grad_logp, 2; param_names=[:x1, :x2]) sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag, backend=AutoEnzyme()) -chain = sample(model, sampler, 500; chain_type=MCMCChains.Chains) +chain = sample(model, sampler, 500; chain_type=VNChain) ``` See the [Getting Started guide](docs/src/10-getting-started.md) for worked examples, GPU usage, Turing.jl integration, and step-size tuning. diff --git a/docs/src/10-getting-started.md b/docs/src/10-getting-started.md index 6eafa87..962f5cb 100644 --- a/docs/src/10-getting-started.md +++ b/docs/src/10-getting-started.md @@ -12,7 +12,7 @@ All samplers take a [`DensityModel`](@ref) as their first argument. A `DensityM - `dim::Int` — dimension of the parameter space ```julia -using ParallelMCMC, MCMCChains +using ParallelMCMC, FlexiChains using ADTypes, Enzyme # Banana-shaped target in 2-D @@ -40,7 +40,7 @@ sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag, damping=0.5, backend=AutoEnzyme()) chain = sample(model, sampler, 500; - chain_type=MCMCChains.Chains, progress=true) + chain_type=VNChain, progress=true) ``` `sample` requests 500 total samples. Internally, DEER solves trajectories of length `T=64` and returns each column of the solved trajectory as a separate sample. When the trajectory is exhausted a new noise tape is drawn and DEER re-solves from the last state. @@ -90,7 +90,7 @@ ParallelMCMC.jl integrates with Turing.jl models through the `DynamicPPL` and `L Load `DynamicPPL` (part of Turing.jl) and a single-argument `DensityModel` constructor becomes available: ```julia -using Turing, ParallelMCMC, MCMCChains +using Turing, ParallelMCMC, FlexiChains @model function normal_model(y) μ ~ Normal(0.0, 1.0) @@ -100,7 +100,7 @@ end model = DensityModel(normal_model(1.5)) # param_names=[:μ] extracted automatically chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()), 500; - chain_type=MCMCChains.Chains) + chain_type=VNChain) ``` Much like Turing's own samplers, the resulting chain will always have parameters in the original (possibly constrained) space, even though the MCMC sampling itself is performed in unconstrained space. @@ -116,7 +116,7 @@ directly with DynamicPPL's `adtype` interface: ```julia using Turing, LogDensityProblems, ADTypes -using ParallelMCMC, MCMCChains +using ParallelMCMC, FlexiChains ld = DynamicPPL.LogDensityFunction( normal_model(1.5), @@ -140,10 +140,48 @@ All samplers support `MCMCThreads()`. Start Julia with multiple threads (e.g. ` ```julia chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()), MCMCThreads(), 500, 4; - chain_type=MCMCChains.Chains) + chain_type=VNChain) ``` -`MCMCChains` computes R-hat and ESS across chains automatically. +`FlexiChains` computes R-hat and ESS across chains automatically: + +```julia +ess(chain) +``` + +To calculate intra-chain metrics, you can pass `dims=:iter` [as described in the FlexiChains docs](https://pysm.dev/FlexiChains.jl/stable/summarising/#Individual-statistics): + +```julia +ess(chain; dims=:iter) +``` + +--- + +## [Specifying parameter names](@id parameter-names) + +For manually constructed `DensityModel`s, you can optionally specify parameter names with the `param_names` keyword. +The resulting `FlexiChain` object will then have named entries for each parameter. + +If you do not specify `param_names`, the chain will store a single vector-valued parameter called `x` of length `D`. + +You can specify parameter names as a collection of either: + +- `Symbol`s (e.g. `[:x1, :x2]`); +- `VarName`s (e.g. `[@varname(x[1]), @varname(x[2])]`) if `chain_type=VNChain`; or +- A tuple of the above, *plus* a size. In this case, a total of `prod(size)` entries will be allocated to the named parameter, and the results in the chain will be reshaped to that size. + +The above can be mixed and matched as desired, as long as the total number of parameters matches the dimension of the model. +For example: + +```julia +# Three scalar parameters, called `x1` through `x3` +model = DensityModel(...; param_names=[:x1, :x2, :x3]) + +# One scalar parameter called `x`, and a 1x2 matrix parameter called `y` +model = DensityModel(...; param_names=[:x, (:y, (1, 2))]) +``` + +For Turing.jl models, parameter names are automatically derived from the model and do not need to be specified manually. --- @@ -158,14 +196,14 @@ chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()), ```julia # Step 1: find a good step size with adaptive MALA baseline = sample(model, AdaptiveMALASampler(0.1; n_warmup=500), 600; - chain_type=MCMCChains.Chains, discard_warmup=true) + chain_type=VNChain, discard_warmup=true) # Read off the frozen step size from the last internal value eps_tuned = baseline[end, :step_size, 1] # Step 2: run DEER with the tuned step size chain = sample(model, ParallelMALASampler(eps_tuned; T=64, backend=AutoEnzyme()), 2_000; - chain_type=MCMCChains.Chains) + chain_type=VNChain) ``` See the [`MALASampler`](@ref) and [`AdaptiveMALASampler`](@ref) reference pages for the full keyword listing. diff --git a/docs/src/15-gpu.md b/docs/src/15-gpu.md index 56fc6dc..2f29a1d 100644 --- a/docs/src/15-gpu.md +++ b/docs/src/15-gpu.md @@ -119,7 +119,7 @@ y_gpu = CUDA.CuVector(y_cpu) ### Mooncake backend (plain operators) ```julia -using ParallelMCMC, MCMCChains +using ParallelMCMC, FlexiChains using ADTypes, Mooncake softplus(z) = log1p(exp(-abs(z))) + max(z, zero(z)) @@ -151,7 +151,7 @@ sampler = ParallelMALASampler(0.005f0; chain = sample(model, sampler, 1_600; initial_params=CUDA.zeros(Float32, D), - chain_type=MCMCChains.Chains) + chain_type=VNChain) ``` Posterior mean recovery error `‖β_post − β_true‖ / ‖β_true‖` should land in the 0.1–0.2 range after a few hundred post-warmup samples. @@ -161,7 +161,7 @@ Posterior mean recovery error `‖β_post − β_true‖ / ‖β_true‖` should Same model, with the GPU-Enzyme restrictions applied: every `*` becomes `pmcmc_matmul`, every `dot` becomes `pmcmc_dot`, and every gradient broadcast is expanded into single-op stages: ```julia -using ParallelMCMC, MCMCChains +using ParallelMCMC, FlexiChains using ADTypes, Enzyme function logp(β) @@ -222,7 +222,7 @@ sampler = ParallelMALASampler(0.005f0; chain = sample(model, sampler, 1_600; initial_params=CUDA.zeros(Float32, D), - chain_type=MCMCChains.Chains) + chain_type=VNChain) ``` The two snippets sample the same posterior; the difference is purely in what the AD backend can chew on. diff --git a/docs/src/95-reference.md b/docs/src/95-reference.md index 9624633..b7f2fb9 100644 --- a/docs/src/95-reference.md +++ b/docs/src/95-reference.md @@ -31,7 +31,7 @@ ParallelMALASampler ## Internal types -These types appear in `MCMCChains` internals and in the `AbstractMCMC` state/transition protocol. You generally do not need to construct them directly. +These types appear in the `AbstractMCMC` state/transition protocol. You generally do not need to construct them directly. ```@docs MALATapeElement diff --git a/docs/src/index.md b/docs/src/index.md index 92bc7ad..68aec97 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -40,7 +40,7 @@ The included [`MALASampler`](@ref) and [`AdaptiveMALASampler`](@ref) are sequent | [`MALASampler`](@ref) | Baseline — sequential MALA with a fixed step size | | [`AdaptiveMALASampler`](@ref) | Baseline — sequential MALA with dual-averaging step-size adaptation | -All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`MCMCChains.Chains`](https://github.com/TuringLang/MCMCChains.jl) objects. +All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`FlexiChains`](https://pysm.dev/FlexiChains.jl) objects. ## Installation @@ -57,7 +57,7 @@ pkg> add ParallelMCMC The simplest entry point is [`DensityModel`](@ref), which wraps a log-density and its gradient: ```julia -using ParallelMCMC, MCMCChains +using ParallelMCMC, FlexiChains using ADTypes, Enzyme # Example: 2-D standard normal @@ -75,18 +75,22 @@ sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag, backend=AutoEnzyme()) chain = sample(model, sampler, 500; - chain_type=MCMCChains.Chains) + chain_type=VNChain) ``` Each call to `sample` draws 500 samples by solving DEER trajectories of length `T=64` in parallel, re-solving from the last state when each trajectory is exhausted. +Specifying `chain_type=VNChain` returns a `FlexiChain{VarName}`, which has a parameter type of `VarName`. +This is intended for maximum ease of use; however, if you prefer parameter type of `Symbol` you can use `SymChain` instead. +See [the FlexiChains.jl docs](https://pysm.dev/FlexiChains.jl/) for more information about how to analyze and visualize chains. + ### Sequential MALA baseline ```julia sampler = AdaptiveMALASampler(0.1; n_warmup=500) chain = sample(model, sampler, 2_000; - chain_type=MCMCChains.Chains, + chain_type=VNChain, discard_warmup=true, progress=true) ``` @@ -97,7 +101,7 @@ When `DynamicPPL` (part of Turing.jl) is loaded, a one-argument `DensityModel` c Parameter names are automatically extracted, and values transformed back to the original model space: ```julia -using Turing, ParallelMCMC, MCMCChains +using Turing, ParallelMCMC, FlexiChains @model function normal_model(y) μ ~ Normal(0.0, 1.0) @@ -108,12 +112,14 @@ model = DensityModel(normal_model(1.5)) sampler = AdaptiveMALASampler(0.3; n_warmup=500) chain = sample(model, sampler, 2_000; - chain_type=MCMCChains.Chains, + chain_type=VNChain, discard_warmup=true) ``` See [Getting Started](10-getting-started.md) for worked examples and guidance on choosing samplers, and [Algorithm Details](20-algorithms.md) for the mathematics behind DEER. +For Turing models the chain type used must be `VNChain` (not `SymChain`), as that is the natural parameter type for Turing models. + ## Contributors ```@raw html diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index 94b19e2..4ff4969 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -4,7 +4,7 @@ using ParallelMCMC using ADTypes: ADTypes using DynamicPPL: DynamicPPL using AbstractMCMC: AbstractMCMC -using MCMCChains: MCMCChains +using FlexiChains: FlexiChain, VarName, VNChain using LogDensityProblems: LogDensityProblems """ @@ -14,13 +14,12 @@ Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a `DensityModel`, automatically extracting parameter names and wiring up gradient computation via DynamicPPL's `adtype` interface. -Requires `DynamicPPL`, `ForwardDiff`, and `LogDensityProblems` to be loaded (these are the -weak-dependency triggers for this extension; `ForwardDiff` is what backs the default -`AutoForwardDiff()` AD path). +Requires `DynamicPPL` and `LogDensityProblems` to be loaded (these are the weak-dependency +triggers for this extension), plus any AD backend that is used. # Example ```julia -using Turing, ParallelMCMC, MCMCChains +using Turing, ParallelMCMC, FlexiChains @model function mymodel(y) μ ~ Normal(0, 1) @@ -31,7 +30,7 @@ end # and `using` the corresponding package (Enzyme, Mooncake). model = DensityModel(mymodel(1.5)) chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; - chain_type=MCMCChains.Chains, discard_warmup=true, progress=true) + chain_type=FlexiChains.VNChain, discard_warmup=true, progress=true) ``` """ function ParallelMCMC.DensityModel( @@ -106,32 +105,26 @@ for (Ttrans, Tspl, Tstate) in ( model::DensityModelLDF, spl::$Tspl, state::$Tstate, - chain_type::Type{MCMCChains.Chains}; + chain_type::Type{VNChain}, discard_warmup::Bool=false, kwargs..., ) ts = discard_warmup ? filter(t -> !is_warmup(t), ts) : ts - return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) + pwss = map(ts) do t + # Note: This assumes that there is always a field called t.x. This is currently true + # of all samplers in ParallelMCMC + DynamicPPL.ParamsWithStats(t.x, model.logdensity.ld, getstats(t)) + end + return AbstractMCMC.from_samples(VNChain, hcat(pwss)) end end end -function make_processed_dynamicppl_chain( - ::Type{Tchain}, ts::Vector{<:ParallelMCMCTransitionTypes}, model::DensityModelLDF -) where {Tchain} - pwss = map(ts) do t - # Note: This assumes that there is always a field called t.x. This is currently true - # of all samplers in ParallelMCMC - DynamicPPL.ParamsWithStats(t.x, model.logdensity.ld, getstats(t)) - end - return AbstractMCMC.from_samples(Tchain, hcat(pwss)) -end - -function ParallelMCMC._construct_chain( - ::Type{MCMCChains.Chains}, +function ParallelMCMC._construct_flexichain( + ::Type{VarName}, vals::AbstractMatrix{<:Real}, internals::AbstractMatrix{<:Real}, - ::Vector{Symbol}, + ::Any, internal_names::Vector{Symbol}, model::DensityModelLDF, ) @@ -139,7 +132,7 @@ function ParallelMCMC._construct_chain( stats = NamedTuple{Tuple(internal_names)}(internal) DynamicPPL.ParamsWithStats(val, model.logdensity.ld, stats) end - return AbstractMCMC.from_samples(MCMCChains.Chains, hcat(pwss)) + return AbstractMCMC.from_samples(VNChain, hcat(pwss)) end end # module diff --git a/ext/LogDensityProblemsExt.jl b/ext/LogDensityProblemsExt.jl index 366f737..a1c3885 100644 --- a/ext/LogDensityProblemsExt.jl +++ b/ext/LogDensityProblemsExt.jl @@ -15,16 +15,15 @@ Construct a `DensityModel` from any object implementing the - `LogDensityProblems.dimension(ld)` -> `Int` - `LogDensityProblems.logdensity_and_gradient(ld, x)` -> `(logp, grad)` -The optional `param_names` keyword accepts a `Vector{Symbol}` of parameter names -that will be used for the columns of the returned `MCMCChains.Chains` object. -If omitted, names default to `x[1], x[2], ...` unless you also pass `param_names` -to `sample(...)`. +The optional `param_names` keyword accepts a collection of parameter names that will be used +for the columns of the returned `FlexiChain` object. If omitted, a single vector-valued +parameter named `:x` will be chosen, unless you also pass `param_names` to `sample(...)`. The `hvp` keyword argument is forwarded to the main `DensityModel` constructor. # Turing.jl / DynamicPPL example ```julia -using Turing, LogDensityProblems, ADTypes, Enzyme, ParallelMCMC, MCMCChains +using Turing, LogDensityProblems, ADTypes, Enzyme, ParallelMCMC, FlexiChains @model function mymodel(y) μ ~ Normal(0, 1) @@ -41,7 +40,7 @@ ld = DynamicPPL.LogDensityFunction( model = DensityModel(ld; param_names=[:μ]) chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; - chain_type=MCMCChains.Chains, discard_warmup=true, progress=true) + chain_type=VNChain, discard_warmup=true, progress=true) ``` If DynamicPPL is loaded, the simpler one-step constructor `DensityModel(mymodel(obs))` diff --git a/src/ParallelMCMC.jl b/src/ParallelMCMC.jl index 806543c..7a9a557 100644 --- a/src/ParallelMCMC.jl +++ b/src/ParallelMCMC.jl @@ -2,7 +2,7 @@ module ParallelMCMC using AbstractMCMC using CUDA -using MCMCChains +using FlexiChains using LinearAlgebra using Random using Statistics @@ -35,4 +35,8 @@ export ParallelMALASampler, ParallelMALATransition, ParallelMALAState export MALA, DEER export pmcmc_matmul, pmcmc_dot, pmcmc_dotsum +# Re-exports for convenience +import AbstractMCMC: sample +export sample + end diff --git a/src/interface.jl b/src/interface.jl index 1bccf12..892bcb6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -25,13 +25,15 @@ product helpers for use with ParallelMCMC samplers. batched Hessian-vector product over columns. If omitted, the sampler can differentiate `grad_logdensity_batch` to keep the batched DEER path enabled. - `dim::Int` — dimensionality of the parameter space -- `param_names` — optional `Vector{Symbol}` of parameter names used in `MCMCChains` output. - If `nothing` (the default), names fall back to `x[1], x[2], ...`. +- `param_names` — optional collection of parameter names used in `FlexiChains` output. If + `nothing` (the default), uses a single vector-valued parameter `:x` with shape `(dim,)`. + See the [`Parameter names`](@ref parameter-names) section of the docs for more + information. When `logdensity_batch` and `grad_logdensity_batch` are provided, `ParallelMALASampler` enables the batched DEER update path. """ -struct DensityModel{F,G,H,FB,GB,HB} <: AbstractMCMC.AbstractModel +struct DensityModel{F,G,H,FB,GB,HB,PN} <: AbstractMCMC.AbstractModel logdensity::F grad_logdensity::G hvp::H @@ -39,7 +41,7 @@ struct DensityModel{F,G,H,FB,GB,HB} <: AbstractMCMC.AbstractModel grad_logdensity_batch::GB hvp_batch::HB dim::Int - param_names::Union{Nothing,Vector{Symbol}} + param_names::PN end """ @@ -191,43 +193,33 @@ function AbstractMCMC.step( return t, s end -function AbstractMCMC.bundle_samples( - samples::Vector{<:MALATransition}, - model::DensityModel, - sampler::MALASampler, - state::MALAState, - ::Type{MCMCChains.Chains}; - param_names=nothing, - kwargs..., -) - N = length(samples) - D = model.dim +for TKey in (Symbol, VarName) + @eval function AbstractMCMC.bundle_samples( + samples::Vector{<:MALATransition}, + model::DensityModel, + sampler::MALASampler, + state::MALAState, + ::Type{FlexiChains.FlexiChain{$TKey}}; + param_names=nothing, + kwargs..., + ) + N = length(samples) + D = model.dim - names = if param_names !== nothing - param_names - elseif model.param_names !== nothing - model.param_names - else - [Symbol("x[$i]") for i in 1:D] - end + internal_names = [:logp, :accepted] - internal_names = [:logp, :accepted] + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 2) - vals = Matrix{Float64}(undef, N, D) - internals = Matrix{Float64}(undef, N, 2) + for i in 1:N + s = samples[i] + vals[i, :] .= s.x + internals[i, 1] = s.logp + internals[i, 2] = s.accepted ? 1.0 : 0.0 + end - for i in 1:N - s = samples[i] - vals[i, :] .= s.x - internals[i, 1] = s.logp - internals[i, 2] = s.accepted ? 1.0 : 0.0 + return _construct_flexichain($TKey, vals, internals, param_names, internal_names, model) end - - return MCMCChains.Chains( - hcat(vals, internals), - vcat(names, internal_names), - Dict(:parameters => names, :internals => internal_names), - ) end """ @@ -528,13 +520,14 @@ function _trajectory_logps(model::DensityModel, S::AbstractMatrix) return [model.logdensity(S[:, t]) for t in 1:T] end -function _parallel_mala_param_names(model::DensityModel, D::Int, param_names) +function _default_param_names(model::DensityModel, D::Int, param_names) if param_names !== nothing return param_names elseif model.param_names !== nothing return model.param_names else - return [Symbol("x[$i]") for i in 1:D] + # One vector-valued parameter `:x` with shape `(D,)` + return (:x => (D,),) end end @@ -574,14 +567,13 @@ function _sample_parallel_mala_chain( model::DensityModel, sampler::ParallelMALASampler, N::Int, - ::Type{Tchn}; + ::Type{FlexiChains.FlexiChain{TKey}}; initial_params=nothing, param_names=nothing, progress=AbstractMCMC.PROGRESS[], progressname="Sampling", -) where {Tchn} +) where {TKey} D = model.dim - names = _parallel_mala_param_names(model, D, param_names) internal_names = [:logp] vals = Matrix{Float64}(undef, N, D) @@ -618,22 +610,53 @@ function _sample_parallel_mala_chain( end end - return _construct_chain(Tchn, vals, internals, names, internal_names, model) + return _construct_flexichain(TKey, vals, internals, param_names, internal_names, model) end -function _construct_chain( - ::Type{MCMCChains.Chains}, +function _construct_flexichain( + ::Type{TKey}, vals::AbstractMatrix{<:Real}, internals::AbstractMatrix{<:Real}, - names::Vector{Symbol}, + param_names::Any, internal_names::Vector{Symbol}, model::DensityModel, -) - return MCMCChains.Chains( - hcat(vals, internals), - vcat(names, internal_names), - Dict(:parameters => names, :internals => internal_names), +) where {TKey} + # Wrap user-supplied names in `Parameter`. This allows people to specify, e.g., + # `param_names=(:x, :y, :z=>(2,))` without faffing with `Parameter` themselves. Also + # 'upgrade' symbol parameter names to VarNames if the user requested a VNChain. + to_parameter(vn::VarName) = FlexiChains.Parameter(vn) + to_parameter(s::Symbol) = FlexiChains.Parameter(TKey <: VarName ? VarName{s}() : s) + + # vals and internals are both `iters x params` + # FlexiChains expects `iters x chains x params` + arr = hcat(vals, internals) + arr = reshape(arr, size(arr, 1), 1, size(arr, 2)) + + # Wrap parameter names in the format that FlexiChains expects. Ref: + # https://pysm.dev/FlexiChains.jl/stable/arrays/#api-fromarray + if param_names === nothing + param_names = model.param_names + end + wrapped_param_names = if param_names === nothing + # Wasn't defined either as `param_names` or `model.param_names`, so make some up + (to_parameter(:x) => (size(vals, 2),),) + else + map(param_names) do n + if n isa Pair + to_parameter(n.first) => n.second + elseif n isa TKey || n isa Symbol + to_parameter(n) + else + throw(ArgumentError("param_names must be a collection of Pairs, Symbols, or $TKey, got $(typeof(n))")) + end + end + end + + all_names = ( + wrapped_param_names..., + FlexiChains.Extra.(internal_names)... ) + return FlexiChains.FlexiChain{TKey}(arr, all_names) end function _sample_parallel_mala_blocks( @@ -755,7 +778,7 @@ function AbstractMCMC.mcmcsample( N > 0 || error("the number of samples must be ≥ 1") N_int = Int(N) - if chain_type === MCMCChains.Chains + if chain_type <: FlexiChains.FlexiChain return _sample_parallel_mala_chain( rng, model, @@ -846,41 +869,33 @@ function AbstractMCMC.step( end end -function AbstractMCMC.bundle_samples( - samples::Vector{<:ParallelMALATransition}, - model::DensityModel, - sampler::ParallelMALASampler, - state::ParallelMALAState, - ::Type{MCMCChains.Chains}; - param_names=nothing, - kwargs..., -) - N = length(samples) - D = model.dim +for TKey in (Symbol, VarName) + @eval function AbstractMCMC.bundle_samples( + samples::Vector{<:ParallelMALATransition}, + model::DensityModel, + sampler::ParallelMALASampler, + state::ParallelMALAState, + ::Type{FlexiChains.FlexiChain{$TKey}}; + param_names=nothing, + kwargs..., + ) + N = length(samples) + D = model.dim - names = if param_names !== nothing - param_names - elseif model.param_names !== nothing - model.param_names - else - [Symbol("x[$i]") for i in 1:D] - end + names = _default_param_names(model, D, param_names) - internal_names = [:logp] + internal_names = [:logp] - vals = Matrix{Float64}(undef, N, D) - internals = Matrix{Float64}(undef, N, 1) + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 1) - for i in 1:N - vals[i, :] .= samples[i].x - internals[i, 1] = samples[i].logp - end + for i in 1:N + vals[i, :] .= samples[i].x + internals[i, 1] = samples[i].logp + end - return MCMCChains.Chains( - hcat(vals, internals), - vcat(names, internal_names), - Dict(:parameters => names, :internals => internal_names), - ) + return _construct_flexichain($TKey, vals, internals, names, internal_names, model) + end end """ @@ -1059,45 +1074,35 @@ function AbstractMCMC.step( return trans, new_state end -function AbstractMCMC.bundle_samples( - samples::Vector{<:AdaptiveMALATransition}, - model::DensityModel, - sampler::AdaptiveMALASampler, - state::AdaptiveMALAState, - ::Type{MCMCChains.Chains}; - param_names=nothing, - discard_warmup=false, - kwargs..., -) - filtered = discard_warmup ? filter(s -> !s.is_warmup, samples) : samples - N = length(filtered) - D = model.dim - - names = if param_names !== nothing - param_names - elseif model.param_names !== nothing - model.param_names - else - [Symbol("x[$i]") for i in 1:D] - end - - internal_names = [:logp, :accepted, :step_size, :is_warmup] +for TKey in (Symbol, VarName) + @eval function AbstractMCMC.bundle_samples( + samples::Vector{<:AdaptiveMALATransition}, + model::DensityModel, + sampler::AdaptiveMALASampler, + state::AdaptiveMALAState, + ::Type{FlexiChains.FlexiChain{$TKey}}; + param_names=nothing, + discard_warmup=false, + kwargs..., + ) + filtered = discard_warmup ? filter(s -> !s.is_warmup, samples) : samples + N = length(filtered) + D = model.dim + + internal_names = [:logp, :accepted, :step_size, :is_warmup] + + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 4) + + for i in 1:N + s = filtered[i] + vals[i, :] .= s.x + internals[i, 1] = s.logp + internals[i, 2] = s.accepted ? 1.0 : 0.0 + internals[i, 3] = s.step_size + internals[i, 4] = s.is_warmup ? 1.0 : 0.0 + end - vals = Matrix{Float64}(undef, N, D) - internals = Matrix{Float64}(undef, N, 4) - - for i in 1:N - s = filtered[i] - vals[i, :] .= s.x - internals[i, 1] = s.logp - internals[i, 2] = s.accepted ? 1.0 : 0.0 - internals[i, 3] = s.step_size - internals[i, 4] = s.is_warmup ? 1.0 : 0.0 + return _construct_flexichain($TKey, vals, internals, param_names, internal_names, model) end - - return MCMCChains.Chains( - hcat(vals, internals), - vcat(names, internal_names), - Dict(:parameters => names, :internals => internal_names), - ) end diff --git a/test/Project.toml b/test/Project.toml index 09c274c..347b20f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,12 +5,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ParallelMCMC = "1a970f40-4406-51c9-a967-cb3143c111e8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/test-AbstractMCMC-Interface.jl b/test/test-AbstractMCMC-Interface.jl index b01f19f..fe8db04 100644 --- a/test/test-AbstractMCMC-Interface.jl +++ b/test/test-AbstractMCMC-Interface.jl @@ -2,7 +2,7 @@ using Test using Random using LinearAlgebra using Statistics -using MCMCChains +using FlexiChains using ParallelMCMC const MALA_iface = ParallelMCMC.MALA @@ -138,19 +138,16 @@ gradlogp_iface(x) = -x model = DensityModel(logp_iface, gradlogp_iface, 2) sampler = MALASampler(0.15) - chain = sample(model, sampler, 200; chain_type=MCMCChains.Chains, progress=false) + chain = sample(model, sampler, 200; chain_type=VNChain, progress=false) - @test chain isa MCMCChains.Chains - @test size(chain, 1) == 200 + @test chain isa VNChain + @test FlexiChains.niters(chain) == 200 # Parameter columns present - param_names = names(chain, :parameters) - @test length(param_names) == 2 + @test only(FlexiChains.parameters(chain)) == @varname(x) # Internal columns present - internal_names = names(chain, :internals) - @test :logp in internal_names - @test :accepted in internal_names + @test Set(FlexiChains.extras(chain)) == Set(FlexiChains.Extra.([:logp, :accepted])) # logp values should be finite @test all(isfinite, chain[:logp]) @@ -164,19 +161,63 @@ gradlogp_iface(x) = -x model = DensityModel(logp_iface, gradlogp_iface, 2) sampler = MALASampler(0.15) - chain = sample( - model, - sampler, - 50; - chain_type=MCMCChains.Chains, - progress=false, - param_names=[:mu, :sigma], - ) + @testset "with scalar-valued symbols" begin + chain = sample( + model, + sampler, + 50; + chain_type=SymChain, + progress=false, + param_names=[:mu, :sigma], + ) + + @test chain isa SymChain + @test FlexiChains.parameters(chain) == [:mu, :sigma] + end - @test chain isa MCMCChains.Chains - param_names = names(chain, :parameters) - @test :mu in param_names - @test :sigma in param_names + @testset "symbol param names and VarName chain" begin + chain = sample( + model, + sampler, + 50; + chain_type=VNChain, + progress=false, + param_names=[:mu, :sigma], + ) + + @test chain isa VNChain + @test FlexiChains.parameters(chain) == [@varname(mu), @varname(sigma)] + end + + @testset "with vector-valued symbols" begin + chain = sample( + model, + sampler, + 50; + chain_type=SymChain, + progress=false, + param_names=[:param => (2,)], + ) + + @test chain isa SymChain + @test FlexiChains.parameters(chain) == [:param] + @test size(chain[:param, stack=true]) == (50, 1, 2) + end + + @testset "with vector-valued varnames" begin + chain = sample( + model, + sampler, + 50; + chain_type=VNChain, + progress=false, + param_names=[@varname(param) => (2,)], + ) + + @test chain isa VNChain + @test FlexiChains.parameters(chain) == [@varname(param)] + @test size(chain[@varname(param), stack=true]) == (50, 1, 2) + end end @testset "stationary distribution via sample()" begin @@ -189,12 +230,12 @@ gradlogp_iface(x) = -x model, sampler, 20_000; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) burn = 3_000 - post = Array(chain[burn:end, :, :]) # (N-burn) × D + post = Array(chain)[burn:end, :, :] # (N-burn) × D mu = vec(mean(post; dims=1)) @test maximum(abs.(mu)) < 0.1 diff --git a/test/test-Adaptive-MALA.jl b/test/test-Adaptive-MALA.jl index b10bb3e..d4c5a61 100644 --- a/test/test-Adaptive-MALA.jl +++ b/test/test-Adaptive-MALA.jl @@ -2,7 +2,7 @@ using Test using Random using LinearAlgebra using Statistics -using MCMCChains +using FlexiChains using ParallelMCMC const MALA = ParallelMCMC.MALA @@ -187,18 +187,18 @@ end model, sampler, 150; - chain_type=MCMCChains.Chains, + chain_type=SymChain, progress=false, ) - @test chain isa MCMCChains.Chains - @test size(chain, 1) == 150 + @test chain isa SymChain + @test FlexiChains.niters(chain) == 150 - internals = names(chain, :internals) - @test :logp in internals - @test :accepted in internals - @test :step_size in internals - @test :is_warmup in internals + extras_names = FlexiChains.extras(chain) + @test FlexiChains.Extra(:logp) in extras_names + @test FlexiChains.Extra(:accepted) in extras_names + @test FlexiChains.Extra(:step_size) in extras_names + @test FlexiChains.Extra(:is_warmup) in extras_names @test all(isfinite, chain[:logp]) @test all(x -> x == 0.0 || x == 1.0, chain[:accepted]) @@ -220,7 +220,7 @@ end model, sampler, n_w + 50; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) @@ -245,12 +245,12 @@ end model, sampler, n_w + 5_000; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) # Discard warmup - post = Array(chain[(n_w + 1):end, :, :]) # 5000 × D + post = Array(chain)[(n_w + 1):end, :, :] # 5000 × D mu = vec(mean(post; dims=1)) vars = vec(var(post; dims=1)) @@ -270,11 +270,11 @@ end ParallelMCMC.AbstractMCMC.MCMCThreads(), 60, 2; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) - @test chains isa MCMCChains.Chains - @test size(chains, 1) == 60 - @test size(chains, 3) == 2 + @test chains isa VNChain + @test FlexiChains.niters(chains) == 60 + @test FlexiChains.nchains(chains) == 2 end diff --git a/test/test-DEER-Interface.jl b/test/test-DEER-Interface.jl index 64bfe52..f01f2dc 100644 --- a/test/test-DEER-Interface.jl +++ b/test/test-DEER-Interface.jl @@ -2,7 +2,7 @@ using Test using Random using LinearAlgebra using Statistics -using MCMCChains +using FlexiChains using ParallelMCMC using ADTypes @@ -216,17 +216,20 @@ end model, sampler, 100; - chain_type=MCMCChains.Chains, + chain_type=SymChain, progress=false, ) - @test chain isa MCMCChains.Chains - @test size(chain, 1) == 100 - @test :logp in names(chain, :internals) + @test chain isa SymChain + @test FlexiChains.niters(chain) == 100 + @test FlexiChains.Extra(:logp) in FlexiChains.extras(chain) @test all(isfinite, chain[:logp]) - param_names = names(chain, :parameters) - @test length(param_names) == 2 + # Single vector-valued parameter, each sample having length 2 + param_names = FlexiChains.parameters(chain) + @test length(param_names) == 1 + name = only(param_names) + @test size(chain[name, stack=true], 3) == 2 end @testset "ParallelMALASampler sample() with custom param_names" begin @@ -238,13 +241,13 @@ end model, sampler, 40; - chain_type=MCMCChains.Chains, + chain_type=SymChain, progress=false, param_names=[:mu, :sigma], ) - @test :mu in names(chain, :parameters) - @test :sigma in names(chain, :parameters) + @test :mu in FlexiChains.parameters(chain) + @test :sigma in FlexiChains.parameters(chain) end @testset "ParallelMALASampler stationary distribution" begin @@ -257,12 +260,12 @@ end model, sampler, 5_000; - chain_type=MCMCChains.Chains, + chain_type=SymChain, progress=false, ) burn = 500 - post = Array(chain[burn:end, :, :]) # (N-burn) × D + post = Array(chain)[burn:end, :, :] # (N-burn) × D mu = vec(mean(post; dims=1)) vars = vec(var(post; dims=1)) @@ -282,11 +285,11 @@ end ParallelMCMC.AbstractMCMC.MCMCThreads(), 40, 2; - chain_type=MCMCChains.Chains, + chain_type=SymChain, progress=false, ) - @test chains isa MCMCChains.Chains - @test size(chains, 1) == 40 # samples per chain - @test size(chains, 3) == 2 # number of chains + @test chains isa SymChain + @test FlexiChains.niters(chains) == 40 + @test FlexiChains.nchains(chains) == 2 end diff --git a/test/test-DEER-Turing-Logistic.jl b/test/test-DEER-Turing-Logistic.jl index 06a4d46..fcddcb0 100644 --- a/test/test-DEER-Turing-Logistic.jl +++ b/test/test-DEER-Turing-Logistic.jl @@ -2,7 +2,7 @@ using Test using Random using LinearAlgebra using Statistics -using MCMCChains +using FlexiChains using ParallelMCMC @@ -101,15 +101,14 @@ end sampler, 400; initial_params=zeros(_LR_D), - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) - @test chain isa MCMCChains.Chains - @test size(chain, 1) == 400 - @test Symbol("β[1]") in names(chain, :parameters) - @test Symbol("β[2]") in names(chain, :parameters) - @test all(isfinite, Array(chain[:, [Symbol("β[1]"), Symbol("β[2]")], :])) + @test chain isa VNChain + @test FlexiChains.niters(chain) == 400 + @test @varname(β) in FlexiChains.parameters(chain) + @test all(isfinite, chain[@varname(β), stack=true]) end @testset "ParallelMALASampler Turing logistic: posterior sign correct" begin @@ -131,11 +130,11 @@ end sampler, 800; initial_params=zeros(_LR_D), - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) - post = Array(chain[201:end, [Symbol("β[1]"), Symbol("β[2]")], 1]) + post = chain[@varname(β), stack=true][201:end, :, :] β_mean = vec(mean(post; dims=1)) @test sign(β_mean[1]) == sign(_LR_β_true[1]) @@ -160,11 +159,11 @@ end model, AdaptiveMALASampler(0.1; n_warmup=1000), 5000; - chain_type=MCMCChains.Chains, + chain_type=SymChain, progress=false, discard_warmup=true, ) - β_mala = vec(mean(Array(mala_chain[:, [Symbol("x[1]"), Symbol("x[2]")], 1]); dims=1)) + β_mala = vec(mean(mala_chain[:x, stack=true]; dims=1)) deer_chain = sample( MersenneTwister(42), @@ -179,12 +178,10 @@ end backend=ADTypes.AutoEnzyme(), ), 800; - chain_type=MCMCChains.Chains, + chain_type=SymChain, progress=false, ) - β_deer = vec( - mean(Array(deer_chain[201:end, [Symbol("x[1]"), Symbol("x[2]")], 1]); dims=1) - ) + β_deer = vec(mean(deer_chain[:x, stack=true][201:end, :, :]; dims=1)) @test abs(β_deer[1] - β_mala[1]) < 0.25 @test abs(β_deer[2] - β_mala[2]) < 0.25 diff --git a/test/test-GPU-AD-HVP.jl b/test/test-GPU-AD-HVP.jl index 74f773c..7bd7df3 100644 --- a/test/test-GPU-AD-HVP.jl +++ b/test/test-GPU-AD-HVP.jl @@ -2,7 +2,6 @@ using Test using Random using LinearAlgebra using Statistics -using MCMCChains using ParallelMCMC using ADTypes: ADTypes diff --git a/test/test-GPU-Performance.jl b/test/test-GPU-Performance.jl index 2ef1217..a57ff38 100644 --- a/test/test-GPU-Performance.jl +++ b/test/test-GPU-Performance.jl @@ -6,7 +6,6 @@ using Statistics using ParallelMCMC using ADTypes: ADTypes using CUDA: CUDA -using MCMCChains #= On a problem large enough to amortize CUDA kernel-launch overhead, diff --git a/test/test-Turing-Integration.jl b/test/test-Turing-Integration.jl index e79af9d..726beb3 100644 --- a/test/test-Turing-Integration.jl +++ b/test/test-Turing-Integration.jl @@ -2,7 +2,7 @@ using Test using Random using LinearAlgebra using Statistics -using MCMCChains +using FlexiChains using ParallelMCMC @@ -63,10 +63,10 @@ end model, AdaptiveMALASampler(0.3; n_warmup=200), 600; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) - @test only(names(chain, :parameters)) == :μ + @test only(FlexiChains.parameters(chain)) == @varname(μ) end @testset "DynamicPPLExt: convenience constructor" begin @@ -81,10 +81,10 @@ end model, AdaptiveMALASampler(0.3; n_warmup=200), 600; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) - @test only(names(chain, :parameters)) == :μ + @test only(FlexiChains.parameters(chain)) == @varname(μ) end @testset "DynamicPPLExt: convenience constructor uses linked space for constrained models" begin @@ -139,13 +139,14 @@ end sampler, 800; initial_params=zeros(2), - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) - samples = Array(chain) + samples = chain[@varname(x), stack=true] @test all(isfinite, samples) # Standard normal in 2-D: posterior mean should be near zero. - @test maximum(abs, vec(mean(samples; dims=1))) < 0.25 + posterior_means = mean(samples; dims=1) + @test maximum(abs, posterior_means) < 0.25 end @testset "DynamicPPLExt: Dirichlet(ones(3)) runs with ParallelMALA (linked space)" begin @@ -168,10 +169,10 @@ end sampler, 800; initial_params=zeros(2), - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) - @test chain isa MCMCChains.Chains + @test chain isa VNChain @test all(isfinite, Array(chain)) end @@ -183,13 +184,14 @@ end model, AdaptiveMALASampler(0.3; n_warmup=200), 600; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) - @test chain isa MCMCChains.Chains - @test :μ in names(chain, :parameters) - @test !(Symbol("x[1]") in names(chain, :parameters)) + @test chain isa VNChain + params = FlexiChains.parameters(chain) + @test @varname(μ) in params + @test !(@varname(x) in params) end @testset "discard_warmup=true removes warmup samples" begin @@ -203,7 +205,7 @@ end model, sampler, n_total; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) chain_trimmed = sample( @@ -211,14 +213,14 @@ end model, sampler, n_total; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, discard_warmup=true, ) - @test size(chain_full, 1) == n_total - @test size(chain_trimmed, 1) == n_total - n_warmup - 1 - @test all(==(0.0), vec(chain_trimmed[:is_warmup])) + @test FlexiChains.niters(chain_full) == n_total + @test FlexiChains.niters(chain_trimmed) == n_total - n_warmup - 1 + @test all(==(false), chain_trimmed[:is_warmup]) end @testset "posterior mean and variance match analytic solution" begin @@ -232,12 +234,12 @@ end model, sampler, n_warmup + n_draw; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, discard_warmup=true, ) - mu_samples = vec(Array(chain[:μ])) + mu_samples = vec(chain[:μ]) @test abs(mean(mu_samples) - TRUE_MU_POST) < 0.05 @test abs(var(mu_samples) - TRUE_VAR_POST) < 0.05 @@ -255,16 +257,12 @@ end model, AdaptiveMALASampler(0.2; n_warmup=100), 300; - chain_type=MCMCChains.Chains, + chain_type=VNChain, progress=false, ) # Check that the chain contains parameters in original space. # The Dirichlet parameter should have length 3. - @test Set(names(chain, :parameters)) == - Set(Symbol.(["c[1]", "c[2]", "c[3]", "μ.a", "μ.b"])) - for i in 1:3 - # Dirichlet samples should be non-negative - @test all(chain[Symbol("c[$i]")] .>= 0.0) - end + @test Set(FlexiChains.parameters(chain)) == Set([@varname(c), @varname(μ)]) + @test all(chain[@varname(c), stack=true] .>= 0.0) # Dirichlet samples should be non-negative end