diff --git a/.githooks/pre-commit b/.githooks/pre-commit
new file mode 100755
index 00000000..d85be5a7
--- /dev/null
+++ b/.githooks/pre-commit
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+mapfile -d '' notebooks < <(git diff --cached --name-only -z --diff-filter=ACMR -- '*.ipynb')
+if [[ ${#notebooks[@]} -eq 0 ]]; then
+ exit 0
+fi
+
+python3 scripts/sync_notebook_badges.py "${notebooks[@]}"
+python3 scripts/strip_notebook_outputs.py "${notebooks[@]}"
+
+# Re-stage files in case the stripper changed them.
+git add -- "${notebooks[@]}"
diff --git a/.githooks/pre-push b/.githooks/pre-push
new file mode 100755
index 00000000..9111bde9
--- /dev/null
+++ b/.githooks/pre-push
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+mapfile -d '' notebooks < <(git ls-files -z '*.ipynb')
+if [[ ${#notebooks[@]} -eq 0 ]]; then
+ exit 0
+fi
+
+python3 scripts/sync_notebook_badges.py "${notebooks[@]}"
+python3 scripts/strip_notebook_outputs.py "${notebooks[@]}"
+
+if ! git diff --quiet -- "${notebooks[@]}"; then
+ git add -- "${notebooks[@]}"
+ echo "Notebook badges/outputs were normalized during pre-push. Commit the updated notebooks and push again." >&2
+ exit 1
+fi
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index f84ba9b4..5d0a8006 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -12,6 +12,21 @@ concurrency:
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
+ notebook-clean:
+ name: Notebook Output Check
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Verify notebooks are stripped
+ run: |
+ notebooks=$(git ls-files '*.ipynb')
+ if [ -z "$notebooks" ]; then
+ exit 0
+ fi
+ python3 scripts/sync_notebook_badges.py $notebooks
+ python3 scripts/strip_notebook_outputs.py $notebooks
+ git diff --exit-code
+
test:
name: ${{ matrix.pkg.name }} - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
runs-on: ${{ matrix.os }}
@@ -108,4 +123,4 @@ jobs:
using CUDA;
using GeneralisedFilters;
println("GeneralisedFilters with CUDA loaded successfully");
- println("CUDA functional: ", CUDA.functional())'
\ No newline at end of file
+ println("CUDA functional: ", CUDA.functional())'
diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml
index f335df53..601388bb 100644
--- a/.github/workflows/Documentation.yml
+++ b/.github/workflows/Documentation.yml
@@ -36,7 +36,24 @@ jobs:
- uses: julia-actions/setup-julia@v2
with:
version: '1'
- - uses: julia-actions/cache@v1
+ - uses: julia-actions/cache@v2
+ - name: Cache notebook outputs
+ if: matrix.pkg.name == 'GeneralisedFilters'
+ uses: actions/cache@v4
+ with:
+ path: GeneralisedFilters/docs/src/examples
+ key: notebooks-${{ runner.os }}-${{ hashFiles('GeneralisedFilters/examples/**/*.ipynb') }}
+ - name: Cache pip packages
+ if: matrix.pkg.name == 'GeneralisedFilters'
+ uses: actions/cache@v4
+ with:
+ path: ~/.cache/pip
+ key: ${{ runner.os }}-pip-jupyter-nbconvert
+ - name: Install notebook tooling
+ if: matrix.pkg.name == 'GeneralisedFilters'
+ run: |
+ python3 -m pip install --upgrade pip
+ python3 -m pip install jupyter nbconvert
- name: Install dependencies
run: |
julia --project=${{ matrix.pkg.dir }}/docs/ --color=yes -e '
diff --git a/GeneralisedFilters/docs/Project.toml b/GeneralisedFilters/docs/Project.toml
index cf645b52..3878ce2f 100644
--- a/GeneralisedFilters/docs/Project.toml
+++ b/GeneralisedFilters/docs/Project.toml
@@ -1,3 +1,3 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
diff --git a/GeneralisedFilters/docs/literate.jl b/GeneralisedFilters/docs/literate.jl
deleted file mode 100644
index feb1e47b..00000000
--- a/GeneralisedFilters/docs/literate.jl
+++ /dev/null
@@ -1,19 +0,0 @@
-# Retrieve name of example and output directory
-if length(ARGS) != 2
- error("please specify the name of the example and the output directory")
-end
-const EXAMPLE = ARGS[1]
-const OUTDIR = ARGS[2]
-
-# Activate environment
-# Note that each example's Project.toml must include Literate as a dependency
-using Pkg: Pkg
-const EXAMPLEPATH = joinpath(@__DIR__, "..", "examples", EXAMPLE)
-Pkg.activate(EXAMPLEPATH)
-# Pkg.develop(joinpath(@__DIR__, "..", "..", "SSMProblems"))
-Pkg.instantiate()
-using Literate: Literate
-
-# Convert to markdown and notebook
-const SCRIPTJL = joinpath(EXAMPLEPATH, "script.jl")
-Literate.markdown(SCRIPTJL, OUTDIR; name=EXAMPLE, execute=true)
diff --git a/GeneralisedFilters/docs/make.jl b/GeneralisedFilters/docs/make.jl
index 117d6cff..c5a05d07 100644
--- a/GeneralisedFilters/docs/make.jl
+++ b/GeneralisedFilters/docs/make.jl
@@ -1,17 +1,20 @@
push!(LOAD_PATH, "../src/")
+const REPO = "TuringLang/SSMProblems.jl"
+const PKG_SUBDIR = "GeneralisedFilters"
#
# With minor changes from https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/docs
#
### Process examples
-# Always rerun examples
+const EXAMPLES_ROOT = joinpath(@__DIR__, "..", "examples")
const EXAMPLES_OUT = joinpath(@__DIR__, "src", "examples")
-ispath(EXAMPLES_OUT) && rm(EXAMPLES_OUT; recursive=true)
mkpath(EXAMPLES_OUT)
+const EXAMPLE_ASSETS_OUT = joinpath(@__DIR__, "src", "assets", "examples")
+mkpath(EXAMPLE_ASSETS_OUT)
# Install and precompile all packages
# Workaround for https://github.com/JuliaLang/Pkg.jl/issues/2219
-examples = filter!(isdir, readdir(joinpath(@__DIR__, "..", "examples"); join=true))
+examples = sort(filter!(isdir, readdir(EXAMPLES_ROOT; join=true)))
above = joinpath(@__DIR__, "..")
ssmproblems_path = joinpath(above, "..", "SSMProblems")
let script = "using Pkg; Pkg.activate(ARGS[1]); Pkg.develop(path=\"$(above)\"); Pkg.develop(path=\"$(ssmproblems_path)\"); Pkg.instantiate()"
@@ -26,11 +29,13 @@ let script = "using Pkg; Pkg.activate(ARGS[1]); Pkg.develop(path=\"$(above)\");
end
end
# Run examples asynchronously
-processes = let literatejl = joinpath(@__DIR__, "literate.jl")
+processes = let
+ notebookjl = joinpath(@__DIR__, "notebook.jl")
+ docs_project = abspath(@__DIR__)
map(examples) do example
return run(
pipeline(
- `$(Base.julia_cmd()) $literatejl $(basename(example)) $EXAMPLES_OUT`;
+ `$(Base.julia_cmd()) --project=$(docs_project) $notebookjl $(basename(example)) $EXAMPLES_OUT`;
stdin=devnull,
stdout=devnull,
stderr=stderr,
@@ -43,6 +48,108 @@ end
# Check that all examples were run successfully
isempty(processes) || success(processes) || error("some examples were not run successfully")
+example_slug(markdown_filename::AbstractString) = splitext(markdown_filename)[1]
+
+function example_title(markdown_path::AbstractString, slug::AbstractString)
+ for line in eachline(markdown_path)
+ stripped = strip(line)
+ startswith(stripped, "# ") && return strip(stripped[3:end])
+ end
+ return replace(slug, "-" => " ")
+end
+
+function example_summary(markdown_path::AbstractString)
+ in_fence = false
+ for line in eachline(markdown_path)
+ stripped = strip(line)
+
+ if startswith(stripped, "```")
+ in_fence = !in_fence
+ continue
+ end
+
+ if in_fence || isempty(stripped)
+ continue
+ end
+
+ if startswith(stripped, "# ")
+ continue
+ end
+
+ if occursin("Open in Colab", stripped) ||
+ occursin("Source notebook", stripped) ||
+ startswith(stripped, "*This page was generated")
+ continue
+ end
+
+ return replace(stripped, "|" => "\\|")
+ end
+
+ return "Runnable example with executable source."
+end
+
+function example_thumbnail(slug::AbstractString)
+ for ext in ("svg", "png", "jpg", "jpeg", "webp")
+ thumb = joinpath(EXAMPLE_ASSETS_OUT, string(slug, ".", ext))
+ if isfile(thumb)
+ return "../assets/examples/$(slug).$(ext)"
+ end
+ end
+ return "../assets/examples/default.svg"
+end
+
+function links_for_example(slug::AbstractString)
+ example_dir = joinpath(EXAMPLES_ROOT, slug)
+ notebook_name = string(slug, ".ipynb")
+ isfile(joinpath(example_dir, notebook_name)) ||
+ error("example $(slug) must include $(notebook_name)")
+
+ return join(
+ [
+ "[Colab](https://colab.research.google.com/github/$(REPO)/blob/main/$(PKG_SUBDIR)/examples/$(slug)/$(notebook_name))",
+ "[Notebook](https://github.com/$(REPO)/blob/main/$(PKG_SUBDIR)/examples/$(slug)/$(notebook_name))",
+ ],
+ " · ",
+ )
+end
+
+function write_examples_index(example_markdowns::Vector{String})
+ index_path = joinpath(EXAMPLES_OUT, "index.md")
+ open(index_path, "w") do io
+ println(io, "# Examples")
+ println(io)
+ println(
+ io,
+ "Executable examples for `GeneralisedFilters` with links to notebooks and source files.",
+ )
+ println(io)
+ println(io, "| Example | Preview |")
+ println(io, "| :-- | :-- |")
+ for markdown in example_markdowns
+ slug = example_slug(markdown)
+ markdown_path = joinpath(EXAMPLES_OUT, markdown)
+ title = example_title(markdown_path, slug)
+ summary = example_summary(markdown_path)
+ links = links_for_example(slug)
+ thumbnail = example_thumbnail(slug)
+
+ page_link = "[$(title)]($(markdown))"
+ left = string(page_link, "
", summary, "
", links)
+ right = "[)]($(markdown))"
+ println(io, "| $(left) | $(right) |")
+ end
+ end
+ return nothing
+end
+
+const EXAMPLE_MARKDOWNS = sort(
+ filter(
+ filename -> endswith(filename, ".md") && filename != "index.md",
+ readdir(EXAMPLES_OUT),
+ ),
+)
+write_examples_index(EXAMPLE_MARKDOWNS)
+
# Building Documenter
using Documenter
using GeneralisedFilters
@@ -56,11 +163,9 @@ makedocs(;
format=Documenter.HTML(; size_threshold=1000 * 2^11), # 1Mb per page
pages=[
"Home" => "index.md",
- "Examples" => [
- map(
- (x) -> joinpath("examples", x),
- filter!(filename -> endswith(filename, ".md"), readdir(EXAMPLES_OUT)),
- )...,
+ "Examples" => Any[
+ "examples/index.md",
+ map((x) -> joinpath("examples", x), EXAMPLE_MARKDOWNS)...,
],
],
#strict=true,
diff --git a/GeneralisedFilters/docs/notebook.jl b/GeneralisedFilters/docs/notebook.jl
new file mode 100644
index 00000000..7f9cea1f
--- /dev/null
+++ b/GeneralisedFilters/docs/notebook.jl
@@ -0,0 +1,129 @@
+# Build one example from notebook source into markdown for Documenter.
+if length(ARGS) != 2
+ error("please specify the name of the example and the output directory")
+end
+
+const EXAMPLE = ARGS[1]
+const OUTDIR = ARGS[2]
+const REPO = "TuringLang/SSMProblems.jl"
+const PKG_SUBDIR = "GeneralisedFilters"
+const DOCS_ENV = abspath(@__DIR__)
+
+using Pkg: Pkg
+
+const EXAMPLEPATH = joinpath(@__DIR__, "..", "examples", EXAMPLE)
+const NOTEBOOK_FILENAME = string(EXAMPLE, ".ipynb")
+const NOTEBOOK = joinpath(EXAMPLEPATH, NOTEBOOK_FILENAME)
+const MARKDOWN = joinpath(OUTDIR, string(EXAMPLE, ".md"))
+
+isfile(NOTEBOOK) || error("example $(EXAMPLE) must include $(NOTEBOOK_FILENAME)")
+
+if isfile(MARKDOWN)
+ @info "Skipping $(EXAMPLE): cached output found at $(MARKDOWN)"
+ exit(0)
+end
+
+# Keep notebook execution in the example's own environment.
+Pkg.activate(EXAMPLEPATH)
+Pkg.instantiate()
+
+Pkg.activate(DOCS_ENV)
+Pkg.instantiate()
+using IJulia
+
+const KERNEL_NAME = "gf-docs-julia-$(VERSION.major).$(VERSION.minor)"
+IJulia.installkernel(KERNEL_NAME, "--project=$(DOCS_ENV)"; specname=KERNEL_NAME)
+
+function run_nbconvert(
+ examplepath::AbstractString,
+ outdir::AbstractString,
+ name::AbstractString,
+ kernel::AbstractString,
+)
+ jupyter = Sys.which("jupyter")
+ isnothing(jupyter) && error(
+ "jupyter executable not found. Install it (e.g. `pip install jupyter nbconvert`) before building docs.",
+ )
+
+ cmd = `$(jupyter) nbconvert --to markdown --execute --ExecutePreprocessor.timeout=3600 --ExecutePreprocessor.kernel_name=$(kernel) --output=$(name) --output-dir=$(outdir) $(NOTEBOOK_FILENAME)`
+ run(pipeline(Cmd(cmd; dir=examplepath); stdin=devnull, stdout=devnull, stderr=stderr))
+ return nothing
+end
+
+function inject_edit_url(markdown_path::AbstractString, example::AbstractString)
+ content = read(markdown_path, String)
+ meta_block = string(
+ "```@meta\n",
+ "EditURL = \"../../../examples/",
+ example,
+ "/",
+ NOTEBOOK_FILENAME,
+ "\"\n",
+ "```\n\n",
+ )
+
+ if occursin(r"(?m)^```@meta$", content)
+ content = replace(content, r"(?s)^```@meta.*?```\s*" => meta_block; count=1)
+ else
+ content = string(meta_block, content)
+ end
+
+ write(markdown_path, content)
+ return nothing
+end
+
+function inject_docs_badges(markdown_path::AbstractString, example::AbstractString)
+ content = read(markdown_path, String)
+ notebook_name = string(example, ".ipynb")
+
+ colab_url = string(
+ "https://colab.research.google.com/github/",
+ REPO,
+ "/blob/main/",
+ PKG_SUBDIR,
+ "/examples/",
+ example,
+ "/",
+ notebook_name,
+ )
+ source_url = string(
+ "https://github.com/",
+ REPO,
+ "/blob/main/",
+ PKG_SUBDIR,
+ "/examples/",
+ example,
+ "/",
+ notebook_name,
+ )
+ badge_line = string(
+ "[](",
+ colab_url,
+ ") [](",
+ source_url,
+ ")",
+ )
+
+ # Remove existing notebook badge lines, then inject docs-specific badges.
+ content = replace(content, r"(?m)^\[\!\[Open in Colab\].*\n?" => "";)
+ content = replace(content, r"\n{3,}" => "\n\n";)
+
+ heading_match = match(r"(?m)^# .+$", content)
+ if isnothing(heading_match)
+ content = string(badge_line, "\n\n", content)
+ else
+ head_start = heading_match.offset
+ head_end = head_start + ncodeunits(heading_match.match) - 1
+ head = content[1:head_end]
+ tail = content[(head_end + 1):end]
+ tail = replace(tail, r"^\n+" => "")
+ content = string(head, "\n\n", badge_line, "\n\n", tail)
+ end
+
+ write(markdown_path, content)
+ return nothing
+end
+
+run_nbconvert(EXAMPLEPATH, OUTDIR, EXAMPLE, KERNEL_NAME)
+inject_docs_badges(MARKDOWN, EXAMPLE)
+inject_edit_url(MARKDOWN, EXAMPLE)
diff --git a/GeneralisedFilters/docs/src/assets/examples/default.svg b/GeneralisedFilters/docs/src/assets/examples/default.svg
new file mode 100644
index 00000000..a75b498e
--- /dev/null
+++ b/GeneralisedFilters/docs/src/assets/examples/default.svg
@@ -0,0 +1,12 @@
+
diff --git a/GeneralisedFilters/docs/src/assets/examples/trend-inflation.svg b/GeneralisedFilters/docs/src/assets/examples/trend-inflation.svg
new file mode 100644
index 00000000..c2506b9e
--- /dev/null
+++ b/GeneralisedFilters/docs/src/assets/examples/trend-inflation.svg
@@ -0,0 +1,610 @@
+
+
diff --git a/GeneralisedFilters/examples/trend-inflation/Project.toml b/GeneralisedFilters/examples/trend-inflation/Project.toml
index dd1155bf..9eced912 100644
--- a/GeneralisedFilters/examples/trend-inflation/Project.toml
+++ b/GeneralisedFilters/examples/trend-inflation/Project.toml
@@ -5,7 +5,6 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7"
-Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/GeneralisedFilters/examples/trend-inflation/script-alt.jl b/GeneralisedFilters/examples/trend-inflation/script-alt.jl
deleted file mode 100644
index 223e5577..00000000
--- a/GeneralisedFilters/examples/trend-inflation/script-alt.jl
+++ /dev/null
@@ -1,154 +0,0 @@
-using GeneralisedFilters
-using SSMProblems
-using Distributions
-
-using Random
-using StatsBase
-
-include("utilities.jl")
-
-## STATE PRIORS ############################################################################
-
-struct LocalLevelTrendPrior{T<:Real} <: StatePrior end
-
-function SSMProblems.distribution(prior::LocalLevelTrendPrior{T}; kwargs...) where {T}
- return product_distribution(
- Normal(zero(T), T(5)), Normal(zero(T), T(1)), Normal(zero(T), T(1))
- )
-end
-
-struct OutlierAdjustedTrendPrior{T<:Real} <: StatePrior end
-
-function SSMProblems.distribution(prior::OutlierAdjustedTrendPrior{T}; kwargs...) where {T}
- return product_distribution(
- Normal(zero(T), T(5)), Normal(zero(T), T(1)), Normal(zero(T), T(1)), Dirac(one(T))
- )
-end
-
-## LATENT DYNAMICS #########################################################################
-
-struct LocalLevelTrend{ΓT<:AbstractVector} <: LatentDynamics
- γ::ΓT
-end
-
-function SSMProblems.logdensity(
- proc::LocalLevelTrend, ::Integer, prev_state, state, kwargs...
-)
- vol_prob = logpdf(MvNormal(prev_state[2:end], proc.γ), state[2:end])
- trend_prob = logpdf(Normal(prev_state[1], exp(prev_state[2] / 2)), state[1])
- return vol_prob + trend_prob
-end
-
-function SSMProblems.simulate(
- rng::AbstractRNG,
- proc::LocalLevelTrend,
- step::Integer,
- state::AbstractVector{T};
- kwargs...,
-) where {T<:Real}
- new_state = deepcopy(state)
- new_state[2:3] += proc.γ .* randn(rng, T, 2)
- new_state[1] += exp(new_state[2] / 2) * randn(T)
- return new_state
-end
-
-struct OutlierAdjustedTrend{ΓT<:AbstractVector} <: LatentDynamics
- trend::LocalLevelTrend{ΓT}
- switch_dist::Bernoulli
- outlier_dist::Uniform
-end
-
-function SSMProblems.logdensity(
- proc::OutlierAdjustedTrend, step::Integer, prev_state, state, kwargs...
-)
- return SSMProblems.logdensity(proc.trend, step, prev_state, state; kwargs...)
-end
-
-function SSMProblems.simulate(
- rng::AbstractRNG,
- proc::OutlierAdjustedTrend,
- step::Integer,
- state::AbstractVector{T};
- kwargs...,
-) where {T<:Real}
- new_state = SSMProblems.simulate(rng, proc.trend, step, state; kwargs...)
- new_state[4] = rand(rng, proc.switch_dist) ? rand(rng, proc.outlier_dist) : one(T)
- return new_state
-end
-
-## OBSERVATION PROCESS #####################################################################
-
-struct OutlierAdjustedObservation <: ObservationProcess end
-
-function SSMProblems.distribution(
- proc::OutlierAdjustedObservation, step::Integer, state::AbstractVector; kwargs...
-)
- return Normal(state[1], sqrt(state[4]) * exp(state[3] / 2))
-end
-
-struct SimpleObservation <: ObservationProcess end
-
-function SSMProblems.distribution(
- proc::SimpleObservation, step::Integer, state::AbstractVector; kwargs...
-)
- return Normal(state[1], exp(state[3] / 2))
-end
-
-## MAIN ####################################################################################
-
-# include UCSV as a baseline
-function UCSV(γ::T) where {T<:Real}
- return StateSpaceModel(
- LocalLevelTrendPrior{T}(), LocalLevelTrend(fill(γ, 2)), SimpleObservation()
- )
-end
-
-# quick demo of the outlier-adjusted univariate UCSV model
-function UCSVO(γ::T, prob::T) where {T<:Real}
- trend = LocalLevelTrend(fill(γ, 2))
- return StateSpaceModel(
- OutlierAdjustedTrendPrior{T}(),
- OutlierAdjustedTrend(trend, Bernoulli(prob), Uniform{T}(2, 10)),
- OutlierAdjustedObservation(),
- )
-end
-
-# wrapper to plot and demo the model
-function plot_ucsv(rng::AbstractRNG, model, data)
- alg = BF(2^14; threshold=1.0, resampler=Systematic())
- sparse_ancestry = GeneralisedFilters.AncestorCallback(nothing)
- states, ll = GeneralisedFilters.filter(rng, model, alg, data; callback=sparse_ancestry)
-
- fig = Figure(; size=(1200, 500), fontsize=16)
- dateticks = date_format(fred.data.date)
-
- all_paths = map(x -> hcat(x...), GeneralisedFilters.get_ancestry(sparse_ancestry.tree))
- mean_paths = mean(all_paths, Weights(StatsBase.weights(states.log_weights)))
-
- ax = Axis(
- fig[1:2, 1];
- limits=(nothing, (-14, 18)),
- title="Trend Inflation",
- xtickformat=dateticks,
- )
-
- lines!(fig[1:2, 1], vcat(0, data...); color=:red, linestyle=:dash)
- lines!(ax, mean_paths[1, :]; color=:black)
-
- ax1 = Axis(fig[1, 2]; title="Volatility", xtickformat=dateticks)
- lines!(ax1, exp.(0.5 * mean_paths[2, :]); color=:black, label="permanent")
- axislegend(ax1; position=:rt)
-
- ax2 = Axis(fig[2, 2]; xtickformat=dateticks)
- lines!(ax2, exp.(0.5 * mean_paths[3, :]); color=:black, label="transitory")
- axislegend(ax2; position=:lt)
-
- display(fig)
- return ll
-end
-
-rng = MersenneTwister(1234);
-
-# plot both models side by side, notice the difference in volatility
-plot_ucsv(rng, UCSV(0.2), fred.data.value);
-plot_ucsv(rng, UCSVO(0.2, 0.05), fred.data.value);
diff --git a/GeneralisedFilters/examples/trend-inflation/script.jl b/GeneralisedFilters/examples/trend-inflation/script.jl
deleted file mode 100644
index 126db1db..00000000
--- a/GeneralisedFilters/examples/trend-inflation/script.jl
+++ /dev/null
@@ -1,253 +0,0 @@
-# # Trend Inflation
-#
-# This example is a replication of the univariate state space model suggested by (Stock &
-# Watson, 2016) using GeneralisedFilters to define a heirarchical model for use in Rao-
-# Blackwellised particle filtering.
-
-using GeneralisedFilters
-using SSMProblems
-using Distributions
-using Random
-using StatsBase
-using LinearAlgebra
-using PDMats
-
-const GF = GeneralisedFilters
-
-INFL_PATH = joinpath(@__DIR__, "..", "..", "..", "examples", "trend-inflation"); #hide
-# INFL_PATH = joinpath(@__DIR__)
-include(joinpath(INFL_PATH, "utilities.jl")); #hide
-
-# ## Model Definition
-
-# We begin by defining the local level trend model, a linear Gaussian model with a weakly
-# stationary random walk component. The dynamics of which are as follows:
-
-# ```math
-# \begin{aligned}
-# y_{t} &= x_{t} + \eta_{t} \\
-# x_{t+1} &= x_{t} + \varepsilon_{t}
-# \end{aligned}
-# ```
-
-# However, this model is not enough to capture trend dynamics when faced with structural
-# breaks. (Stock & Watson, 2007) suggest adding a stochastic volatiltiy component, defined
-# like so:
-
-# ```math
-# \begin{aligned}
-# \log \sigma_{\eta, t+1} = \log \sigma_{\eta, t} + \nu_{\eta, t} \\
-# \log \sigma_{\varepsilon, t+1} = \log \sigma_{\varepsilon, t} + \nu_{\varepsilon, t}
-# \end{aligned}
-# ```
-
-# where $\nu_{z,t} \sim N(0, \gamma)$ for $z \in \{ \varepsilon, \eta \}$.
-
-# Using `GeneralisedFilters`, we can construct a heirarchical version of this model such
-# that the local level trend component is conditionally linear Gaussian on the volatility
-# draws.
-
-# #### Stochastic Volatility Process
-
-# We begin by defining the non-linear dynamics, which aren't conditioned contemporaneous
-# states. Since these processes are traditionally non-linear/non-Gaussian we use the
-# SSMProblems interface to define the stochastic volatility components.
-
-struct StochasticVolatilityPrior{T<:Real} <: StatePrior end
-
-#
-
-function SSMProblems.distribution(prior::StochasticVolatilityPrior{T}; kwargs...) where {T}
- return product_distribution(Normal(zero(T), T(1)), Normal(zero(T), T(1)))
-end
-
-# For the dynamics, instead of using the `SSMProblems.distribution` utility, we only define
-# the `simulate` method, which is sufficient for the RBPF.
-
-struct StochasticVolatility{ΓT<:AbstractVector} <: LatentDynamics
- γ::ΓT
-end
-
-#
-
-function SSMProblems.simulate(
- rng::AbstractRNG,
- proc::StochasticVolatility,
- step::Integer,
- state::AbstractVector{T};
- kwargs...,
-) where {T<:Real}
- new_state = deepcopy(state)
- new_state[1:2] += proc.γ .* randn(rng, T, 2)
- return new_state
-end
-
-# #### Local Level Trend Process
-#
-# For the conditionally linear and Gaussian components, we subtype the model and provide a
-# keyword argument as the conditional element. In this case $A$ and $b$ remain constant, but
-# $Q$ is conditional on the log variance, stored in `new_outer` (the nomenclature chosen for
-# heirarchical modeling).
-
-struct LocalLevelTrend <: LinearGaussianLatentDynamics end
-
-#
-
-GF.calc_A(::LocalLevelTrend, ::Integer; kwargs...) = [1;;]
-GF.calc_b(::LocalLevelTrend, ::Integer; kwargs...) = [0;]
-function GF.calc_Q(::LocalLevelTrend, ::Integer; new_outer, kwargs...)
- return PDMat([exp(new_outer[1]);;])
-end
-
-# Similarly, we define the observation process conditional on a separate log variance.
-
-struct SimpleObservation <: LinearGaussianObservationProcess end
-
-#
-
-GF.calc_H(::SimpleObservation, ::Integer; kwargs...) = [1;;]
-GF.calc_c(::SimpleObservation, ::Integer; kwargs...) = [0;]
-function GF.calc_R(::SimpleObservation, ::Integer; new_outer, kwargs...)
- return PDMat([exp(new_outer[2]);;])
-end
-
-# ### Unobserved Components with Stochastic Volatility
-
-# The state space model suggested by (Stock & Watson, 2007) can be constructed with the
-# following method:
-
-function UCSV(γ::T) where {T<:Real}
- stoch_vol_prior = StochasticVolatilityPrior{T}()
- stoch_vol_process = StochasticVolatility(fill(γ, 2))
-
- local_level_model = StateSpaceModel(
- GF.HomogeneousGaussianPrior(zeros(T, 1), PDMat([100.0;;])),
- LocalLevelTrend(),
- SimpleObservation(),
- )
-
- return HierarchicalSSM(stoch_vol_prior, stoch_vol_process, local_level_model)
-end;
-
-# For plotting, we can extract the ancestry of the Rao Blackwellised particles using the
-# callback system. For our inflation data, this reduces to the following:
-
-rng = MersenneTwister(1234);
-sparse_ancestry = GF.AncestorCallback(nothing);
-states, ll = GF.filter(
- rng,
- UCSV(0.2),
- RBPF(BF(2^12), KalmanFilter()),
- [[pce] for pce in fred_data.value];
- callback=sparse_ancestry,
-);
-
-# The `sparse_ancestry` object stores a sparse ancestry tree which we can use to approximate
-# the smoothed series without an additional backwards pass. We can convert this data
-# structure to a human readable array by using `GeneralisedFilters.get_ancestry` and then
-# take the mean path by passing a custom function.
-
-trends, volatilities = mean_path(GF.get_ancestry(sparse_ancestry.tree), states);
-plot_ucsv(trends[1, :], eachrow(volatilities), fred_data)
-
-# #### Outlier Adjustments
-
-# For additional robustness, (Stock & Watson, 2016) account for one-time measurement shocks
-# and suggest an alteration in the observation equation, where
-
-# ```math
-# \eta_{t} \sim N(0, s_{t} \cdot \sigma_{\eta, t}^2) \quad \quad s_{t} \sim \begin{cases}
-# U(0,2) & \text{ with probability } p \\
-# \delta(1) & \text{ with probability } 1 - p
-# \end{cases}
-# ```
-
-# The prior is the same as before, but with additional state which we can assume will always
-# be 1; using the `Distributions` interface this is just `Dirac(1)`
-
-struct OutlierAdjustedVolatilityPrior{T<:Real} <: StatePrior end
-
-#
-
-function SSMProblems.distribution(
- prior::OutlierAdjustedVolatilityPrior{T}; kwargs...
-) where {T}
- return product_distribution(Normal(zero(T), T(1)), Normal(zero(T), T(1)), Dirac(one(T)))
-end
-
-# In terms of the model definition, we can construct a separate `LatentDynamics` which
-# contains the same volatility process as before, but with the respective draw in the third
-# component.
-
-struct OutlierAdjustedVolatility{ΓT} <: LatentDynamics
- volatility::StochasticVolatility{ΓT}
- switch_dist::Bernoulli
- outlier_dist::Uniform
-end
-
-# The simulation then calls the volatility process, and computes the outlier term in the
-# third state
-
-function SSMProblems.simulate(
- rng::AbstractRNG,
- proc::OutlierAdjustedVolatility,
- step::Integer,
- state::AbstractVector{T};
- kwargs...,
-) where {T<:Real}
- new_state = SSMProblems.simulate(rng, proc.volatility, step, state; kwargs...)
- new_state[3] = rand(rng, proc.switch_dist) ? rand(rng, proc.outlier_dist) : one(T)
- return new_state
-end
-
-# For the observation process, we define a new object where $R$ is dependent on both the
-# measurement volatility as well as this outlier adjustment coefficient.
-
-struct OutlierAdjustedObservation <: LinearGaussianObservationProcess end
-
-#
-
-GF.calc_H(::OutlierAdjustedObservation, ::Integer; kwargs...) = [1;;]
-GF.calc_c(::OutlierAdjustedObservation, ::Integer; kwargs...) = [0;]
-function GF.calc_R(::OutlierAdjustedObservation, ::Integer; new_outer, kwargs...)
- return PDMat([new_outer[3] * exp(new_outer[2]);;])
-end
-
-# ### Outlier Adjusted UCSV
-
-# The state space model suggested by (Stock & Watson, 2007) can be constructed with the
-# following method:
-
-function UCSVO(γ::T, prob::T) where {T<:Real}
- stoch_vol_prior = OutlierAdjustedVolatilityPrior{T}()
- stoch_vol_process = OutlierAdjustedVolatility(
- StochasticVolatility(fill(γ, 2)), Bernoulli(prob), Uniform{T}(2, 10)
- )
-
- local_level_model = StateSpaceModel(
- GF.HomogeneousGaussianPrior(zeros(T, 1), PDMat([100.0;;])),
- LocalLevelTrend(),
- OutlierAdjustedObservation(),
- )
-
- return HierarchicalSSM(stoch_vol_prior, stoch_vol_process, local_level_model)
-end;
-
-# We then repeat the same experiment, this time with an outlier probability of $p = 0.05$
-
-rng = MersenneTwister(1234);
-sparse_ancestry = GF.AncestorCallback(nothing)
-states, ll = GF.filter(
- rng,
- UCSVO(0.2, 0.05),
- RBPF(BF(2^12), KalmanFilter()),
- [[pce] for pce in fred_data.value];
- callback=sparse_ancestry,
-);
-
-# this process is identical to the last, except with an additional `volatilities` state
-# which captures the outlier distance. We omit this feature in the plots, but the impact is
-# clear when comparing the maximum transitory noise around the GFC.
-
-trends, volatilities = mean_path(GF.get_ancestry(sparse_ancestry.tree), states);
-plot_ucsv(trends[1, :], eachrow(volatilities), fred_data)
diff --git a/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb b/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb
new file mode 100644
index 00000000..8ae65267
--- /dev/null
+++ b/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb
@@ -0,0 +1,592 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Trend Inflation\n",
+ "\n",
+ "[](https://colab.research.google.com/github/TuringLang/SSMProblems.jl/blob/main/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb) [](https://github.com/TuringLang/SSMProblems.jl/blob/main/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb) [](https://turinglang.org/SSMProblems.jl/GeneralisedFilters/dev/examples/trend-inflation/)\n",
+ "\n",
+ "This example is a replication of the univariate state space model suggested by (Stock &\n",
+ "Watson, 2016) using GeneralisedFilters to define a heirarchical model for use in Rao-Blackwellised particle filtering."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": [
+ "remove-output"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "using Pkg\n",
+ "\n",
+ "Pkg.activate(\".\")\n",
+ "if isfile(\"Project.toml\")\n",
+ " Pkg.instantiate()\n",
+ "else\n",
+ " Pkg.add([\n",
+ " \"SSMProblems\",\n",
+ " \"GeneralisedFilters\",\n",
+ " \"CSV\",\n",
+ " \"CairoMakie\",\n",
+ " \"DataFrames\",\n",
+ " \"Distributions\",\n",
+ " \"LogExpFunctions\",\n",
+ " \"PDMats\",\n",
+ " \"StatsBase\",\n",
+ " ])\n",
+ "end;"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": [
+ "remove-output"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "using GeneralisedFilters\n",
+ "using SSMProblems\n",
+ "using Distributions\n",
+ "using Random\n",
+ "using StatsBase\n",
+ "using LinearAlgebra\n",
+ "using PDMats\n",
+ "\n",
+ "const GF = GeneralisedFilters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "using Downloads\n",
+ "\n",
+ "const UTILITIES_URL = \"https://raw.githubusercontent.com/TuringLang/SSMProblems.jl/main/GeneralisedFilters/examples/trend-inflation/utilities.jl\"\n",
+ "const UTILITIES_PATH = isfile(\"utilities.jl\") ? \"utilities.jl\" : Downloads.download(UTILITIES_URL)\n",
+ "include(UTILITIES_PATH);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model Definition"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We begin by defining the local level trend model, a linear Gaussian model with a weakly\n",
+ "stationary random walk component. The dynamics of which are as follows:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\begin{aligned}\n",
+ " y_{t} &= x_{t} + \\eta_{t} \\\\\n",
+ " x_{t+1} &= x_{t} + \\varepsilon_{t}\n",
+ "\\end{aligned}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "However, this model is not enough to capture trend dynamics when faced with structural\n",
+ "breaks. (Stock & Watson, 2007) suggest adding a stochastic volatiltiy component, defined\n",
+ "like so:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\begin{aligned}\n",
+ " \\log \\sigma_{\\eta, t+1} = \\log \\sigma_{\\eta, t} + \\nu_{\\eta, t} \\\\\n",
+ " \\log \\sigma_{\\varepsilon, t+1} = \\log \\sigma_{\\varepsilon, t} + \\nu_{\\varepsilon, t}\n",
+ "\\end{aligned}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "where $\\nu_{z,t} \\sim N(0, \\gamma)$ for $z \\in \\{ \\varepsilon, \\eta \\}$."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Using `GeneralisedFilters`, we can construct a heirarchical version of this model such\n",
+ "that the local level trend component is conditionally linear Gaussian on the volatility\n",
+ "draws."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Stochastic Volatility Process"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We begin by defining the non-linear dynamics, which aren't conditioned contemporaneous\n",
+ "states. Since these processes are traditionally non-linear/non-Gaussian we use the\n",
+ "SSMProblems interface to define the stochastic volatility components."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "struct StochasticVolatilityPrior{T<:Real} <: StatePrior end"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "function SSMProblems.distribution(prior::StochasticVolatilityPrior{T}; kwargs...) where {T}\n",
+ " return product_distribution(Normal(zero(T), T(1)), Normal(zero(T), T(1)))\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For the dynamics, instead of using the `SSMProblems.distribution` utility, we only define\n",
+ "the `simulate` method, which is sufficient for the RBPF."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "struct StochasticVolatility{ΓT<:AbstractVector} <: LatentDynamics\n",
+ " γ::ΓT\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "function SSMProblems.simulate(\n",
+ " rng::AbstractRNG,\n",
+ " proc::StochasticVolatility,\n",
+ " step::Integer,\n",
+ " state::AbstractVector{T};\n",
+ " kwargs...,\n",
+ ") where {T<:Real}\n",
+ " new_state = deepcopy(state)\n",
+ " new_state[1:2] += proc.γ .* randn(rng, T, 2)\n",
+ " return new_state\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Local Level Trend Process\n",
+ "\n",
+ "For the conditionally linear and Gaussian components, we subtype the model and provide a\n",
+ "keyword argument as the conditional element. In this case $A$ and $b$ remain constant, but\n",
+ "$Q$ is conditional on the log variance, stored in `new_outer` (the nomenclature chosen for\n",
+ "heirarchical modeling)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "struct LocalLevelTrend <: LinearGaussianLatentDynamics end"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "GF.calc_A(::LocalLevelTrend, ::Integer; kwargs...) = [1;;]\n",
+ "GF.calc_b(::LocalLevelTrend, ::Integer; kwargs...) = [0;]\n",
+ "function GF.calc_Q(::LocalLevelTrend, ::Integer; new_outer, kwargs...)\n",
+ " return PDMat([exp(new_outer[1]);;])\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Similarly, we define the observation process conditional on a separate log variance."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "struct SimpleObservation <: LinearGaussianObservationProcess end"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "GF.calc_H(::SimpleObservation, ::Integer; kwargs...) = [1;;]\n",
+ "GF.calc_c(::SimpleObservation, ::Integer; kwargs...) = [0;]\n",
+ "function GF.calc_R(::SimpleObservation, ::Integer; new_outer, kwargs...)\n",
+ " return PDMat([exp(new_outer[2]);;])\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Unobserved Components with Stochastic Volatility"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The state space model suggested by (Stock & Watson, 2007) can be constructed with the\n",
+ "following method:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "function UCSV(γ::T) where {T<:Real}\n",
+ " stoch_vol_prior = StochasticVolatilityPrior{T}()\n",
+ " stoch_vol_process = StochasticVolatility(fill(γ, 2))\n",
+ "\n",
+ " local_level_model = StateSpaceModel(\n",
+ " GF.HomogeneousGaussianPrior(zeros(T, 1), PDMat([100.0;;])),\n",
+ " LocalLevelTrend(),\n",
+ " SimpleObservation(),\n",
+ " )\n",
+ "\n",
+ " return HierarchicalSSM(stoch_vol_prior, stoch_vol_process, local_level_model)\n",
+ "end;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For plotting, we can extract the ancestry of the Rao Blackwellised particles using the\n",
+ "callback system. For our inflation data, this reduces to the following:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rng = MersenneTwister(1234);\n",
+ "sparse_ancestry = GF.AncestorCallback(nothing);\n",
+ "states, ll = GF.filter(\n",
+ " rng,\n",
+ " UCSV(0.2),\n",
+ " RBPF(BF(2^12), KalmanFilter()),\n",
+ " [[pce] for pce in fred_data.value];\n",
+ " callback=sparse_ancestry,\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `sparse_ancestry` object stores a sparse ancestry tree which we can use to approximate\n",
+ "the smoothed series without an additional backwards pass. We can convert this data\n",
+ "structure to a human readable array by using `GeneralisedFilters.get_ancestry` and then\n",
+ "take the mean path by passing a custom function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trends, volatilities = mean_path(GF.get_ancestry(sparse_ancestry.tree), states);\n",
+ "plot_ucsv(trends[1, :], eachrow(volatilities), fred_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Outlier Adjustments"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For additional robustness, (Stock & Watson, 2016) account for one-time measurement shocks\n",
+ "and suggest an alteration in the observation equation, where"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\eta_{t} \\sim N(0, s_{t} \\cdot \\sigma_{\\eta, t}^2) \\quad \\quad s_{t} \\sim \\begin{cases}\n",
+ "U(0,2) & \\text{ with probability } p \\\\\n",
+ "\\delta(1) & \\text{ with probability } 1 - p\n",
+ "\\end{cases}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The prior is the same as before, but with additional state which we can assume will always\n",
+ "be 1; using the `Distributions` interface this is just `Dirac(1)`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "struct OutlierAdjustedVolatilityPrior{T<:Real} <: StatePrior end"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "function SSMProblems.distribution(\n",
+ " prior::OutlierAdjustedVolatilityPrior{T}; kwargs...\n",
+ ") where {T}\n",
+ " return product_distribution(Normal(zero(T), T(1)), Normal(zero(T), T(1)), Dirac(one(T)))\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In terms of the model definition, we can construct a separate `LatentDynamics` which\n",
+ "contains the same volatility process as before, but with the respective draw in the third\n",
+ "component."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "struct OutlierAdjustedVolatility{ΓT} <: LatentDynamics\n",
+ " volatility::StochasticVolatility{ΓT}\n",
+ " switch_dist::Bernoulli\n",
+ " outlier_dist::Uniform\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The simulation then calls the volatility process, and computes the outlier term in the\n",
+ "third state"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "function SSMProblems.simulate(\n",
+ " rng::AbstractRNG,\n",
+ " proc::OutlierAdjustedVolatility,\n",
+ " step::Integer,\n",
+ " state::AbstractVector{T};\n",
+ " kwargs...,\n",
+ ") where {T<:Real}\n",
+ " new_state = SSMProblems.simulate(rng, proc.volatility, step, state; kwargs...)\n",
+ " new_state[3] = rand(rng, proc.switch_dist) ? rand(rng, proc.outlier_dist) : one(T)\n",
+ " return new_state\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For the observation process, we define a new object where $R$ is dependent on both the\n",
+ "measurement volatility as well as this outlier adjustment coefficient."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "struct OutlierAdjustedObservation <: LinearGaussianObservationProcess end"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "GF.calc_H(::OutlierAdjustedObservation, ::Integer; kwargs...) = [1;;]\n",
+ "GF.calc_c(::OutlierAdjustedObservation, ::Integer; kwargs...) = [0;]\n",
+ "function GF.calc_R(::OutlierAdjustedObservation, ::Integer; new_outer, kwargs...)\n",
+ " return PDMat([new_outer[3] * exp(new_outer[2]);;])\n",
+ "end"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Outlier Adjusted UCSV"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The state space model suggested by (Stock & Watson, 2007) can be constructed with the\n",
+ "following method:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "function UCSVO(γ::T, prob::T) where {T<:Real}\n",
+ " stoch_vol_prior = OutlierAdjustedVolatilityPrior{T}()\n",
+ " stoch_vol_process = OutlierAdjustedVolatility(\n",
+ " StochasticVolatility(fill(γ, 2)), Bernoulli(prob), Uniform{T}(2, 10)\n",
+ " )\n",
+ "\n",
+ " local_level_model = StateSpaceModel(\n",
+ " GF.HomogeneousGaussianPrior(zeros(T, 1), PDMat([100.0;;])),\n",
+ " LocalLevelTrend(),\n",
+ " OutlierAdjustedObservation(),\n",
+ " )\n",
+ "\n",
+ " return HierarchicalSSM(stoch_vol_prior, stoch_vol_process, local_level_model)\n",
+ "end;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We then repeat the same experiment, this time with an outlier probability of $p = 0.05$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rng = MersenneTwister(1234);\n",
+ "sparse_ancestry = GF.AncestorCallback(nothing)\n",
+ "states, ll = GF.filter(\n",
+ " rng,\n",
+ " UCSVO(0.2, 0.05),\n",
+ " RBPF(BF(2^12), KalmanFilter()),\n",
+ " [[pce] for pce in fred_data.value];\n",
+ " callback=sparse_ancestry,\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "this process is identical to the last, except with an additional `volatilities` state\n",
+ "which captures the outlier distance. We omit this feature in the plots, but the impact is\n",
+ "clear when comparing the maximum transitory noise around the GFC."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trends, volatilities = mean_path(GF.get_ancestry(sparse_ancestry.tree), states);\n",
+ "p = plot_ucsv(trends[1, :], eachrow(volatilities), fred_data)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Julia 1.12.3",
+ "language": "julia",
+ "name": "julia-1.12"
+ },
+ "language_info": {
+ "file_extension": ".jl",
+ "mimetype": "application/julia",
+ "name": "julia",
+ "version": "1.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 3
+}
diff --git a/GeneralisedFilters/examples/trend-inflation/utilities.jl b/GeneralisedFilters/examples/trend-inflation/utilities.jl
index e3f9e0da..592da07c 100644
--- a/GeneralisedFilters/examples/trend-inflation/utilities.jl
+++ b/GeneralisedFilters/examples/trend-inflation/utilities.jl
@@ -1,9 +1,20 @@
using CSV, DataFrames
using CairoMakie
using Dates
+using Downloads
using LogExpFunctions
-fred_data = CSV.read(joinpath(INFL_PATH, "data.csv"), DataFrame)
+const FRED_DATA_URL = "https://raw.githubusercontent.com/TuringLang/SSMProblems.jl/main/GeneralisedFilters/examples/trend-inflation/data.csv"
+
+function load_fred_data()
+ for path in (joinpath(pwd(), "data.csv"), joinpath(@__DIR__, "data.csv"))
+ isfile(path) && return CSV.read(path, DataFrame)
+ end
+
+ return CSV.read(Downloads.download(FRED_DATA_URL), DataFrame)
+end
+
+fred_data = load_fred_data()
## PLOTTING UTILITIES ######################################################################
@@ -25,7 +36,7 @@ function mean_path(paths::Vector{Vector{T}}, states) where {T<:GeneralisedFilter
end
function plot_ucsv(trend, volatilities, fred_data)
- fig = Figure(; size=(1200, 500), fontsize=16)
+ fig = Figure(; size=(1200, 675), fontsize=16)
dateticks = date_format(fred_data.date)
trend_ax = Axis(
diff --git a/scripts/setup-git-hooks.sh b/scripts/setup-git-hooks.sh
new file mode 100755
index 00000000..ebf41647
--- /dev/null
+++ b/scripts/setup-git-hooks.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+git config core.hooksPath .githooks
+echo "Configured git hooks path: $(git config --get core.hooksPath)"
diff --git a/scripts/strip_notebook_outputs.py b/scripts/strip_notebook_outputs.py
new file mode 100755
index 00000000..14f7ef18
--- /dev/null
+++ b/scripts/strip_notebook_outputs.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+"""Strip execution state from Jupyter notebooks.
+
+Usage:
+ python scripts/strip_notebook_outputs.py [more.ipynb ...]
+"""
+
+from __future__ import annotations
+
+import json
+import sys
+from pathlib import Path
+
+
+def strip_notebook(path: Path) -> bool:
+ original = path.read_text(encoding="utf-8")
+ data = json.loads(original)
+ changed = False
+
+ for cell in data.get("cells", []):
+ if cell.get("cell_type") != "code":
+ continue
+
+ if cell.get("execution_count") is not None:
+ cell["execution_count"] = None
+ changed = True
+
+ if cell.get("outputs"):
+ cell["outputs"] = []
+ changed = True
+
+ metadata = data.get("metadata")
+ if isinstance(metadata, dict) and "widgets" in metadata:
+ del metadata["widgets"]
+ changed = True
+
+ if not changed:
+ return False
+
+ path.write_text(
+ json.dumps(data, indent=1, ensure_ascii=False) + "\n",
+ encoding="utf-8",
+ )
+ return True
+
+
+def main(argv: list[str]) -> int:
+ if len(argv) < 2:
+ print("usage: strip_notebook_outputs.py [more.ipynb ...]", file=sys.stderr)
+ return 2
+
+ failed = False
+ changed_files: list[str] = []
+
+ for raw in argv[1:]:
+ path = Path(raw)
+ if not path.exists():
+ continue
+
+ try:
+ if strip_notebook(path):
+ changed_files.append(str(path))
+ except Exception as exc: # pragma: no cover - defensive for hook usage
+ print(f"failed to strip {path}: {exc}", file=sys.stderr)
+ failed = True
+
+ if changed_files:
+ print("stripped notebook outputs:")
+ for filename in changed_files:
+ print(f" - {filename}")
+
+ return 1 if failed else 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main(sys.argv))
diff --git a/scripts/sync_notebook_badges.py b/scripts/sync_notebook_badges.py
new file mode 100755
index 00000000..10567dfb
--- /dev/null
+++ b/scripts/sync_notebook_badges.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+"""Normalize example notebook top-cell badges.
+
+Usage:
+ python scripts/sync_notebook_badges.py [more.ipynb ...]
+"""
+
+from __future__ import annotations
+
+import json
+import re
+import sys
+from pathlib import Path
+
+REPO = "TuringLang/SSMProblems.jl"
+
+
+def source_to_text(source: object) -> str:
+ if isinstance(source, list):
+ return "".join(str(part) for part in source)
+ if isinstance(source, str):
+ return source
+ return ""
+
+
+def text_to_source(text: str) -> list[str]:
+ lines = text.splitlines()
+ if not lines:
+ return []
+ return [f"{line}\n" for line in lines[:-1]] + [lines[-1]]
+
+
+def title_from_slug(slug: str) -> str:
+ return " ".join(word.capitalize() for word in slug.replace("_", "-").split("-"))
+
+
+def is_badge_line(line: str) -> bool:
+ return (
+ "colab.research.google.com" in line
+ or "View%20Source-GitHub" in line
+ or "Example%20Page-Docs" in line
+ )
+
+
+def notebook_urls(path: Path) -> tuple[str, str, str]:
+ parts = path.parts
+ try:
+ pkg_idx = parts.index("GeneralisedFilters")
+ except ValueError:
+ try:
+ pkg_idx = parts.index("SSMProblems")
+ except ValueError as exc:
+ raise ValueError(f"notebook path does not include known package dir: {path}") from exc
+
+ pkg = parts[pkg_idx]
+ rel_path = "/".join(parts[pkg_idx:])
+
+ if len(parts) < pkg_idx + 4 or parts[pkg_idx + 1] != "examples":
+ raise ValueError(f"notebook path is not under {pkg}/examples/: {path}")
+
+ slug = parts[pkg_idx + 2]
+ colab = f"https://colab.research.google.com/github/{REPO}/blob/main/{rel_path}"
+ source = f"https://github.com/{REPO}/blob/main/{rel_path}"
+ docs = f"https://turinglang.org/SSMProblems.jl/{pkg}/dev/examples/{slug}/"
+ return colab, source, docs
+
+
+def build_badge_line(colab: str, source: str, docs: str) -> str:
+ return (
+ f"[]({colab}) "
+ f"[]({source}) "
+ f"[]({docs})"
+ )
+
+
+def normalize_notebook(path: Path) -> bool:
+ original = path.read_text(encoding="utf-8")
+ data = json.loads(original)
+ cells = data.get("cells", [])
+ if not isinstance(cells, list):
+ raise ValueError(f"notebook cells are invalid in {path}")
+
+ colab, source, docs = notebook_urls(path)
+ badges = build_badge_line(colab, source, docs)
+
+ if not cells or cells[0].get("cell_type") != "markdown":
+ slug = path.parent.name
+ title = f"# {title_from_slug(slug)}"
+ top_cell = {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": text_to_source(f"{title}\n\n{badges}"),
+ }
+ cells.insert(0, top_cell)
+ data["cells"] = cells
+ path.write_text(json.dumps(data, indent=1, ensure_ascii=False) + "\n", encoding="utf-8")
+ return True
+
+ first = cells[0]
+ source_text = source_to_text(first.get("source", []))
+ lines = source_text.splitlines()
+
+ title = None
+ for idx, line in enumerate(lines):
+ if re.match(r"^#\s+", line):
+ title = line.strip()
+ lines = lines[idx + 1 :]
+ break
+
+ if title is None:
+ title = f"# {title_from_slug(path.parent.name)}"
+
+ body = [line for line in lines if not is_badge_line(line)]
+ while body and not body[0].strip():
+ body.pop(0)
+ while body and not body[-1].strip():
+ body.pop()
+
+ new_lines = [title, "", badges]
+ if body:
+ new_lines += [""] + body
+
+ new_source = text_to_source("\n".join(new_lines))
+ if first.get("source") == new_source:
+ return False
+
+ first["source"] = new_source
+ path.write_text(json.dumps(data, indent=1, ensure_ascii=False) + "\n", encoding="utf-8")
+ return True
+
+
+def main(argv: list[str]) -> int:
+ if len(argv) < 2:
+ print("usage: sync_notebook_badges.py [more.ipynb ...]", file=sys.stderr)
+ return 2
+
+ changed: list[str] = []
+ failed = False
+
+ for raw in argv[1:]:
+ path = Path(raw)
+ if not path.exists():
+ continue
+
+ try:
+ if normalize_notebook(path):
+ changed.append(str(path))
+ except Exception as exc:
+ print(f"failed to sync badges for {path}: {exc}", file=sys.stderr)
+ failed = True
+
+ if changed:
+ print("synchronized notebook badges:")
+ for filename in changed:
+ print(f" - {filename}")
+
+ return 1 if failed else 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main(sys.argv))