Skip to content
Open
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DynamicPPLExt = ["DynamicPPL", "ForwardDiff", "LogDensityProblems"]
DynamicPPLExt = ["DynamicPPL", "FlexiChains", "LogDensityProblems"]

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment in DynamicPPLExt

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also FlexiChains is a hard dep now so shouldn't need it here anyways

EnzymeExt = "Enzyme"
LogDensityProblemsExt = "LogDensityProblems"

Expand All @@ -34,10 +33,9 @@ CUDA_Runtime_jll = "0.21"
DifferentiationInterface = "0.7.13"
DynamicPPL = "0.40.6, 0.41"

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DynamicPPL = "0.40.6, 0.41"
DynamicPPL = "0.41.6, 0.42"

Enzyme = "0.13.146"
ForwardDiff = "1"

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should still declare a compat version for this, even if it is removed from the extension's triggers.

FlexiChains = "0.6.6"
LinearAlgebra = "1"
LogDensityProblems = "2"
MCMCChains = "7.7.0"
Mooncake = "0.5.26"
Random = "1"
Statistics = "1"
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ The approach and its scaling tricks (stochastic Hutchinson Jacobian estimators,
| [`MALASampler`](src/interface.jl) | Baseline — sequential MALA with a fixed step size |
| [`AdaptiveMALASampler`](src/interface.jl) | Baseline — sequential MALA with dual-averaging step-size adaptation |

All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`MCMCChains.Chains`](https://github.com/TuringLang/MCMCChains.jl) objects, so they slot into existing Turing.jl / AbstractMCMC workflows.
All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`FlexiChains`](https://pysm.dev/FlexiChains.jl) objects, so they slot into existing Turing.jl / AbstractMCMC workflows.


### Quick start
Expand All @@ -53,7 +53,7 @@ pkg> add ParallelMCMC
```

```julia
using ParallelMCMC, MCMCChains
using ParallelMCMC, FlexiChains
using ADTypes, Enzyme

logp(x) = -0.5 * sum(abs2, x) # 2-D standard normal
Expand All @@ -63,7 +63,7 @@ model = DensityModel(logp, grad_logp, 2; param_names=[:x1, :x2])
sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag,
backend=AutoEnzyme())

chain = sample(model, sampler, 500; chain_type=MCMCChains.Chains)
chain = sample(model, sampler, 500; chain_type=VNChain)
```

See the [Getting Started guide](docs/src/10-getting-started.md) for worked examples, GPU usage, Turing.jl integration, and step-size tuning.
Expand Down
56 changes: 47 additions & 9 deletions docs/src/10-getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ All samplers take a [`DensityModel`](@ref) as their first argument. A `DensityM
- `dim::Int` — dimension of the parameter space

```julia
using ParallelMCMC, MCMCChains
using ParallelMCMC, FlexiChains
using ADTypes, Enzyme

# Banana-shaped target in 2-D
Expand Down Expand Up @@ -40,7 +40,7 @@ sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag, damping=0.5,
backend=AutoEnzyme())

chain = sample(model, sampler, 500;
chain_type=MCMCChains.Chains, progress=true)
chain_type=VNChain, progress=true)
```

`sample` requests 500 total samples. Internally, DEER solves trajectories of length `T=64` and returns each column of the solved trajectory as a separate sample. When the trajectory is exhausted a new noise tape is drawn and DEER re-solves from the last state.
Expand Down Expand Up @@ -90,7 +90,7 @@ ParallelMCMC.jl integrates with Turing.jl models through the `DynamicPPL` and `L
Load `DynamicPPL` (part of Turing.jl) and a single-argument `DensityModel` constructor becomes available:

```julia
using Turing, ParallelMCMC, MCMCChains
using Turing, ParallelMCMC, FlexiChains

@model function normal_model(y)
μ ~ Normal(0.0, 1.0)
Expand All @@ -100,7 +100,7 @@ end
model = DensityModel(normal_model(1.5)) # param_names=[:μ] extracted automatically

chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()), 500;
chain_type=MCMCChains.Chains)
chain_type=VNChain)
```

Much like Turing's own samplers, the resulting chain will always have parameters in the original (possibly constrained) space, even though the MCMC sampling itself is performed in unconstrained space.
Expand All @@ -116,7 +116,7 @@ directly with DynamicPPL's `adtype` interface:

```julia
using Turing, LogDensityProblems, ADTypes
using ParallelMCMC, MCMCChains
using ParallelMCMC, FlexiChains

ld = DynamicPPL.LogDensityFunction(
normal_model(1.5),
Expand All @@ -140,10 +140,48 @@ All samplers support `MCMCThreads()`. Start Julia with multiple threads (e.g. `
```julia
chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()),
MCMCThreads(), 500, 4;
chain_type=MCMCChains.Chains)
chain_type=VNChain)
```

Comment on lines 144 to 145

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love ty

`MCMCChains` computes R-hat and ESS across chains automatically.
`FlexiChains` computes R-hat and ESS across chains automatically:

```julia
ess(chain)
```

To calculate intra-chain metrics, you can pass `dims=:iter` [as described in the FlexiChains docs](https://pysm.dev/FlexiChains.jl/stable/summarising/#Individual-statistics):

```julia
ess(chain; dims=:iter)
```

---

## [Specifying parameter names](@id parameter-names)

For manually constructed `DensityModel`s, you can optionally specify parameter names with the `param_names` keyword.
The resulting `FlexiChain` object will then have named entries for each parameter.

If you do not specify `param_names`, the chain will store a single vector-valued parameter called `x` of length `D`.

You can specify parameter names as a collection of either:

- `Symbol`s (e.g. `[:x1, :x2]`);
- `VarName`s (e.g. `[@varname(x[1]), @varname(x[2])]`) if `chain_type=VNChain`; or
- A tuple of one of the above *plus* a size (e.g. `(:y, (1, 2))`). In this case, a total of `prod(size)` entries will be allocated to the named parameter, and the results in the chain will be reshaped to that size.

The above can be mixed and matched as desired, as long as the total number of parameters matches the dimension of the model.
For example:

```julia
# Three scalar parameters, called `x1` through `x3`
model = DensityModel(...; param_names=[:x1, :x2, :x3])

# One scalar parameter called `x`, and a 1x2 matrix parameter called `y`
model = DensityModel(...; param_names=[:x, (:y, (1, 2))])
```

For Turing.jl models, parameter names are automatically derived from the model and do not need to be specified manually.

---

Expand All @@ -158,14 +196,14 @@ chain = sample(model, ParallelMALASampler(0.1; T=64, backend=AutoEnzyme()),
```julia
# Step 1: find a good step size with adaptive MALA
baseline = sample(model, AdaptiveMALASampler(0.1; n_warmup=500), 600;
chain_type=MCMCChains.Chains, discard_warmup=true)
chain_type=VNChain, discard_warmup=true)

# Read off the frozen step size from the last internal value
eps_tuned = baseline[end, :step_size, 1]

# Step 2: run DEER with the tuned step size
chain = sample(model, ParallelMALASampler(eps_tuned; T=64, backend=AutoEnzyme()), 2_000;
chain_type=MCMCChains.Chains)
chain_type=VNChain)
```

See the [`MALASampler`](@ref) and [`AdaptiveMALASampler`](@ref) reference pages for the full keyword listing.
8 changes: 4 additions & 4 deletions docs/src/15-gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ y_gpu = CUDA.CuVector(y_cpu)
### Mooncake backend (plain operators)

```julia
using ParallelMCMC, MCMCChains
using ParallelMCMC, FlexiChains
using ADTypes, Mooncake

softplus(z) = log1p(exp(-abs(z))) + max(z, zero(z))
Expand Down Expand Up @@ -151,7 +151,7 @@ sampler = ParallelMALASampler(0.005f0;

chain = sample(model, sampler, 1_600;
initial_params=CUDA.zeros(Float32, D),
chain_type=MCMCChains.Chains)
chain_type=VNChain)
```

Posterior mean recovery error `‖β_post − β_true‖ / ‖β_true‖` should land in the 0.1–0.2 range after a few hundred post-warmup samples.
Expand All @@ -161,7 +161,7 @@ Posterior mean recovery error `‖β_post − β_true‖ / ‖β_true‖` should
Same model, with the GPU-Enzyme restrictions applied: every `*` becomes `pmcmc_matmul`, every `dot` becomes `pmcmc_dot`, and every gradient broadcast is expanded into single-op stages:

```julia
using ParallelMCMC, MCMCChains
using ParallelMCMC, FlexiChains
using ADTypes, Enzyme

function logp(β)
Expand Down Expand Up @@ -222,7 +222,7 @@ sampler = ParallelMALASampler(0.005f0;

chain = sample(model, sampler, 1_600;
initial_params=CUDA.zeros(Float32, D),
chain_type=MCMCChains.Chains)
chain_type=VNChain)
```

The two snippets sample the same posterior; the difference is purely in what the AD backend can chew on.
Expand Down
2 changes: 1 addition & 1 deletion docs/src/95-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ParallelMALASampler

## Internal types

These types appear in `MCMCChains` internals and in the `AbstractMCMC` state/transition protocol. You generally do not need to construct them directly.
These types appear in the `AbstractMCMC` state/transition protocol. You generally do not need to construct them directly.

```@docs
MALATapeElement
Expand Down
18 changes: 12 additions & 6 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ The included [`MALASampler`](@ref) and [`AdaptiveMALASampler`](@ref) are sequent
| [`MALASampler`](@ref) | Baseline — sequential MALA with a fixed step size |
| [`AdaptiveMALASampler`](@ref) | Baseline — sequential MALA with dual-averaging step-size adaptation |

All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`MCMCChains.Chains`](https://github.com/TuringLang/MCMCChains.jl) objects.
All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`FlexiChains`](https://pysm.dev/FlexiChains.jl) objects.

## Installation

Expand All @@ -57,7 +57,7 @@ pkg> add ParallelMCMC
The simplest entry point is [`DensityModel`](@ref), which wraps a log-density and its gradient:

```julia
using ParallelMCMC, MCMCChains
using ParallelMCMC, FlexiChains
using ADTypes, Enzyme

# Example: 2-D standard normal
Expand All @@ -75,18 +75,22 @@ sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag,
backend=AutoEnzyme())

chain = sample(model, sampler, 500;
chain_type=MCMCChains.Chains)
chain_type=VNChain)
```

Each call to `sample` draws 500 samples by solving DEER trajectories of length `T=64` in parallel, re-solving from the last state when each trajectory is exhausted.

Specifying `chain_type=VNChain` returns a `FlexiChain{VarName}`, which has a parameter type of `VarName`.
This is intended for maximum ease of use; however, if you prefer a parameter type of `Symbol`, you can use `SymChain` instead.
See [the FlexiChains.jl docs](https://pysm.dev/FlexiChains.jl/) for more information about how to analyze and visualize chains.

### Sequential MALA baseline

```julia
sampler = AdaptiveMALASampler(0.1; n_warmup=500)

chain = sample(model, sampler, 2_000;
chain_type=MCMCChains.Chains,
chain_type=VNChain,
discard_warmup=true,
progress=true)
```
Expand All @@ -97,7 +101,7 @@ When `DynamicPPL` (part of Turing.jl) is loaded, a one-argument `DensityModel` c
Parameter names are automatically extracted, and values transformed back to the original model space:

```julia
using Turing, ParallelMCMC, MCMCChains
using Turing, ParallelMCMC, FlexiChains

@model function normal_model(y)
μ ~ Normal(0.0, 1.0)
Expand All @@ -108,12 +112,14 @@ model = DensityModel(normal_model(1.5))
sampler = AdaptiveMALASampler(0.3; n_warmup=500)

chain = sample(model, sampler, 2_000;
chain_type=MCMCChains.Chains,
chain_type=VNChain,
discard_warmup=true)
```

See [Getting Started](10-getting-started.md) for worked examples and guidance on choosing samplers, and [Algorithm Details](20-algorithms.md) for the mathematics behind DEER.

For Turing models, the chain type used must be `VNChain` (not `SymChain`), as `VarName` is the natural parameter type for Turing models.

## Contributors

```@raw html
Expand Down
39 changes: 16 additions & 23 deletions ext/DynamicPPLExt.jl
Comment thread
penelopeysm marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ParallelMCMC
using ADTypes: ADTypes
using DynamicPPL: DynamicPPL
using AbstractMCMC: AbstractMCMC
using MCMCChains: MCMCChains
using FlexiChains: FlexiChain, VarName, VNChain
using LogDensityProblems: LogDensityProblems

"""
Expand All @@ -14,13 +14,12 @@ Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a
`DensityModel`, automatically extracting parameter names and wiring up gradient
computation via DynamicPPL's `adtype` interface.

Requires `DynamicPPL`, `ForwardDiff`, and `LogDensityProblems` to be loaded (these are the
weak-dependency triggers for this extension; `ForwardDiff` is what backs the default
`AutoForwardDiff()` AD path).
Requires `DynamicPPL` and `LogDensityProblems` to be loaded (these are the weak-dependency
triggers for this extension), plus any AD backend that is used.

# Example
```julia
using Turing, ParallelMCMC, MCMCChains
using Turing, ParallelMCMC, FlexiChains
Comment thread
penelopeysm marked this conversation as resolved.

@model function mymodel(y)
μ ~ Normal(0, 1)
Expand All @@ -31,7 +30,7 @@ end
# and `using` the corresponding package (Enzyme, Mooncake).
model = DensityModel(mymodel(1.5))
chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000;
chain_type=MCMCChains.Chains, discard_warmup=true, progress=true)
chain_type=FlexiChains.VNChain, discard_warmup=true, progress=true)
```
"""
function ParallelMCMC.DensityModel(
Comment on lines 35 to 36

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So right now I have ForwardDiff as the default for this ext because of issue #25, where ForwardDiff worked as a workaround. I failed to document this, so that's my bad. That said, I'm not sure it makes total sense to have this as a default, given there are other backends I didn't try. So at the very least we need to either re-add ForwardDiff as a gate for this ext or just remove this.

Expand Down Expand Up @@ -106,40 +105,34 @@ for (Ttrans, Tspl, Tstate) in (
model::DensityModelLDF,
spl::$Tspl,
state::$Tstate,
chain_type::Type{MCMCChains.Chains};
chain_type::Type{VNChain},

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add some kind of runtime guard in case a Turing user accidentally passes a `SymChain`?

discard_warmup::Bool=false,

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this should still be a kwarg

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
discard_warmup::Bool=false,
chain_type::Type{VNChain};
discard_warmup::Bool=false,

kwargs...,
)
ts = discard_warmup ? filter(t -> !is_warmup(t), ts) : ts
return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model)
pwss = map(ts) do t
# Note: This assumes that there is always a field called t.x. This is currently true
# of all samplers in ParallelMCMC
DynamicPPL.ParamsWithStats(t.x, model.logdensity.ld, getstats(t))
end
return AbstractMCMC.from_samples(VNChain, hcat(pwss))
end
end
end

"""
    make_processed_dynamicppl_chain(Tchain, ts, model)

Build a chain of type `Tchain` from the transitions `ts` of a DynamicPPL-backed
`model`.

Each transition `t` is wrapped in a `DynamicPPL.ParamsWithStats` constructed from
`t.x`, `model.logdensity.ld`, and `getstats(t)`. The wrapped samples are then
concatenated with `hcat` and handed to `AbstractMCMC.from_samples(Tchain, ...)`,
whose result is returned.
"""
function make_processed_dynamicppl_chain(
    ::Type{Tchain}, ts::Vector{<:ParallelMCMCTransitionTypes}, model::DensityModelLDF
) where {Tchain}
    # Relies on every transition exposing a field named `x`; this currently holds
    # for all samplers in ParallelMCMC.
    samples = [
        DynamicPPL.ParamsWithStats(tr.x, model.logdensity.ld, getstats(tr)) for tr in ts
    ]
    return AbstractMCMC.from_samples(Tchain, hcat(samples))
end

function ParallelMCMC._construct_chain(
::Type{MCMCChains.Chains},
function ParallelMCMC._construct_flexichain(
::Type{VarName},
vals::AbstractMatrix{<:Real},
internals::AbstractMatrix{<:Real},
::Vector{Symbol},
::Any,
internal_names::Vector{Symbol},
model::DensityModelLDF,
)
pwss = map(zip(eachrow(vals), eachrow(internals))) do (val, internal)
stats = NamedTuple{Tuple(internal_names)}(internal)
DynamicPPL.ParamsWithStats(val, model.logdensity.ld, stats)
Comment thread
penelopeysm marked this conversation as resolved.
end
return AbstractMCMC.from_samples(MCMCChains.Chains, hcat(pwss))
return AbstractMCMC.from_samples(VNChain, hcat(pwss))
end

end # module
Loading
Loading