diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 00000000..d85be5a7 --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +mapfile -d '' notebooks < <(git diff --cached --name-only -z --diff-filter=ACMR -- '*.ipynb') +if [[ ${#notebooks[@]} -eq 0 ]]; then + exit 0 +fi + +python3 scripts/sync_notebook_badges.py "${notebooks[@]}" +python3 scripts/strip_notebook_outputs.py "${notebooks[@]}" + +# Re-stage files in case the stripper changed them. +git add -- "${notebooks[@]}" diff --git a/.githooks/pre-push b/.githooks/pre-push new file mode 100755 index 00000000..9111bde9 --- /dev/null +++ b/.githooks/pre-push @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +mapfile -d '' notebooks < <(git ls-files -z '*.ipynb') +if [[ ${#notebooks[@]} -eq 0 ]]; then + exit 0 +fi + +python3 scripts/sync_notebook_badges.py "${notebooks[@]}" +python3 scripts/strip_notebook_outputs.py "${notebooks[@]}" + +if ! git diff --quiet -- "${notebooks[@]}"; then + git add -- "${notebooks[@]}" + echo "Notebook badges/outputs were normalized during pre-push. Commit the updated notebooks and push again." 
>&2 + exit 1 +fi diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f84ba9b4..5d0a8006 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -12,6 +12,21 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: + notebook-clean: + name: Notebook Output Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Verify notebooks are stripped + run: | + notebooks=$(git ls-files '*.ipynb') + if [ -z "$notebooks" ]; then + exit 0 + fi + python3 scripts/sync_notebook_badges.py $notebooks + python3 scripts/strip_notebook_outputs.py $notebooks + git diff --exit-code + test: name: ${{ matrix.pkg.name }} - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} runs-on: ${{ matrix.os }} @@ -108,4 +123,4 @@ jobs: using CUDA; using GeneralisedFilters; println("GeneralisedFilters with CUDA loaded successfully"); - println("CUDA functional: ", CUDA.functional())' \ No newline at end of file + println("CUDA functional: ", CUDA.functional())' diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml index f335df53..601388bb 100644 --- a/.github/workflows/Documentation.yml +++ b/.github/workflows/Documentation.yml @@ -36,7 +36,24 @@ jobs: - uses: julia-actions/setup-julia@v2 with: version: '1' - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 + - name: Cache notebook outputs + if: matrix.pkg.name == 'GeneralisedFilters' + uses: actions/cache@v4 + with: + path: GeneralisedFilters/docs/src/examples + key: notebooks-${{ runner.os }}-${{ hashFiles('GeneralisedFilters/examples/**/*.ipynb') }} + - name: Cache pip packages + if: matrix.pkg.name == 'GeneralisedFilters' + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-jupyter-nbconvert + - name: Install notebook tooling + if: matrix.pkg.name == 'GeneralisedFilters' + run: | + python3 -m pip install --upgrade pip + python3 -m pip install jupyter nbconvert - name: 
Install dependencies run: | julia --project=${{ matrix.pkg.dir }}/docs/ --color=yes -e ' diff --git a/GeneralisedFilters/docs/Project.toml b/GeneralisedFilters/docs/Project.toml index cf645b52..3878ce2f 100644 --- a/GeneralisedFilters/docs/Project.toml +++ b/GeneralisedFilters/docs/Project.toml @@ -1,3 +1,3 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" diff --git a/GeneralisedFilters/docs/literate.jl b/GeneralisedFilters/docs/literate.jl deleted file mode 100644 index feb1e47b..00000000 --- a/GeneralisedFilters/docs/literate.jl +++ /dev/null @@ -1,19 +0,0 @@ -# Retrieve name of example and output directory -if length(ARGS) != 2 - error("please specify the name of the example and the output directory") -end -const EXAMPLE = ARGS[1] -const OUTDIR = ARGS[2] - -# Activate environment -# Note that each example's Project.toml must include Literate as a dependency -using Pkg: Pkg -const EXAMPLEPATH = joinpath(@__DIR__, "..", "examples", EXAMPLE) -Pkg.activate(EXAMPLEPATH) -# Pkg.develop(joinpath(@__DIR__, "..", "..", "SSMProblems")) -Pkg.instantiate() -using Literate: Literate - -# Convert to markdown and notebook -const SCRIPTJL = joinpath(EXAMPLEPATH, "script.jl") -Literate.markdown(SCRIPTJL, OUTDIR; name=EXAMPLE, execute=true) diff --git a/GeneralisedFilters/docs/make.jl b/GeneralisedFilters/docs/make.jl index 117d6cff..c5a05d07 100644 --- a/GeneralisedFilters/docs/make.jl +++ b/GeneralisedFilters/docs/make.jl @@ -1,17 +1,20 @@ push!(LOAD_PATH, "../src/") +const REPO = "TuringLang/SSMProblems.jl" +const PKG_SUBDIR = "GeneralisedFilters" # # With minor changes from https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/docs # ### Process examples -# Always rerun examples +const EXAMPLES_ROOT = joinpath(@__DIR__, "..", "examples") const EXAMPLES_OUT = joinpath(@__DIR__, "src", "examples") -ispath(EXAMPLES_OUT) && rm(EXAMPLES_OUT; recursive=true) 
mkpath(EXAMPLES_OUT) +const EXAMPLE_ASSETS_OUT = joinpath(@__DIR__, "src", "assets", "examples") +mkpath(EXAMPLE_ASSETS_OUT) # Install and precompile all packages # Workaround for https://github.com/JuliaLang/Pkg.jl/issues/2219 -examples = filter!(isdir, readdir(joinpath(@__DIR__, "..", "examples"); join=true)) +examples = sort(filter!(isdir, readdir(EXAMPLES_ROOT; join=true))) above = joinpath(@__DIR__, "..") ssmproblems_path = joinpath(above, "..", "SSMProblems") let script = "using Pkg; Pkg.activate(ARGS[1]); Pkg.develop(path=\"$(above)\"); Pkg.develop(path=\"$(ssmproblems_path)\"); Pkg.instantiate()" @@ -26,11 +29,13 @@ let script = "using Pkg; Pkg.activate(ARGS[1]); Pkg.develop(path=\"$(above)\"); end end # Run examples asynchronously -processes = let literatejl = joinpath(@__DIR__, "literate.jl") +processes = let + notebookjl = joinpath(@__DIR__, "notebook.jl") + docs_project = abspath(@__DIR__) map(examples) do example return run( pipeline( - `$(Base.julia_cmd()) $literatejl $(basename(example)) $EXAMPLES_OUT`; + `$(Base.julia_cmd()) --project=$(docs_project) $notebookjl $(basename(example)) $EXAMPLES_OUT`; stdin=devnull, stdout=devnull, stderr=stderr, @@ -43,6 +48,108 @@ end # Check that all examples were run successfully isempty(processes) || success(processes) || error("some examples were not run successfully") +example_slug(markdown_filename::AbstractString) = splitext(markdown_filename)[1] + +function example_title(markdown_path::AbstractString, slug::AbstractString) + for line in eachline(markdown_path) + stripped = strip(line) + startswith(stripped, "# ") && return strip(stripped[3:end]) + end + return replace(slug, "-" => " ") +end + +function example_summary(markdown_path::AbstractString) + in_fence = false + for line in eachline(markdown_path) + stripped = strip(line) + + if startswith(stripped, "```") + in_fence = !in_fence + continue + end + + if in_fence || isempty(stripped) + continue + end + + if startswith(stripped, "# ") + continue + end + + 
if occursin("Open in Colab", stripped) || + occursin("Source notebook", stripped) || + startswith(stripped, "*This page was generated") + continue + end + + return replace(stripped, "|" => "\\|") + end + + return "Runnable example with executable source." +end + +function example_thumbnail(slug::AbstractString) + for ext in ("svg", "png", "jpg", "jpeg", "webp") + thumb = joinpath(EXAMPLE_ASSETS_OUT, string(slug, ".", ext)) + if isfile(thumb) + return "../assets/examples/$(slug).$(ext)" + end + end + return "../assets/examples/default.svg" +end + +function links_for_example(slug::AbstractString) + example_dir = joinpath(EXAMPLES_ROOT, slug) + notebook_name = string(slug, ".ipynb") + isfile(joinpath(example_dir, notebook_name)) || + error("example $(slug) must include $(notebook_name)") + + return join( + [ + "[Colab](https://colab.research.google.com/github/$(REPO)/blob/main/$(PKG_SUBDIR)/examples/$(slug)/$(notebook_name))", + "[Notebook](https://github.com/$(REPO)/blob/main/$(PKG_SUBDIR)/examples/$(slug)/$(notebook_name))", + ], + " · ", + ) +end + +function write_examples_index(example_markdowns::Vector{String}) + index_path = joinpath(EXAMPLES_OUT, "index.md") + open(index_path, "w") do io + println(io, "# Examples") + println(io) + println( + io, + "Executable examples for `GeneralisedFilters` with links to notebooks and source files.", + ) + println(io) + println(io, "| Example | Preview |") + println(io, "| :-- | :-- |") + for markdown in example_markdowns + slug = example_slug(markdown) + markdown_path = joinpath(EXAMPLES_OUT, markdown) + title = example_title(markdown_path, slug) + summary = example_summary(markdown_path) + links = links_for_example(slug) + thumbnail = example_thumbnail(slug) + + page_link = "[$(title)]($(markdown))" + left = string(page_link, "
", summary, "
", links) + right = "[![$(title)]($(thumbnail))]($(markdown))" + println(io, "| $(left) | $(right) |") + end + end + return nothing +end + +const EXAMPLE_MARKDOWNS = sort( + filter( + filename -> endswith(filename, ".md") && filename != "index.md", + readdir(EXAMPLES_OUT), + ), +) +write_examples_index(EXAMPLE_MARKDOWNS) + # Building Documenter using Documenter using GeneralisedFilters @@ -56,11 +163,9 @@ makedocs(; format=Documenter.HTML(; size_threshold=1000 * 2^11), # 1Mb per page pages=[ "Home" => "index.md", - "Examples" => [ - map( - (x) -> joinpath("examples", x), - filter!(filename -> endswith(filename, ".md"), readdir(EXAMPLES_OUT)), - )..., + "Examples" => Any[ + "examples/index.md", + map((x) -> joinpath("examples", x), EXAMPLE_MARKDOWNS)..., ], ], #strict=true, diff --git a/GeneralisedFilters/docs/notebook.jl b/GeneralisedFilters/docs/notebook.jl new file mode 100644 index 00000000..7f9cea1f --- /dev/null +++ b/GeneralisedFilters/docs/notebook.jl @@ -0,0 +1,129 @@ +# Build one example from notebook source into markdown for Documenter. +if length(ARGS) != 2 + error("please specify the name of the example and the output directory") +end + +const EXAMPLE = ARGS[1] +const OUTDIR = ARGS[2] +const REPO = "TuringLang/SSMProblems.jl" +const PKG_SUBDIR = "GeneralisedFilters" +const DOCS_ENV = abspath(@__DIR__) + +using Pkg: Pkg + +const EXAMPLEPATH = joinpath(@__DIR__, "..", "examples", EXAMPLE) +const NOTEBOOK_FILENAME = string(EXAMPLE, ".ipynb") +const NOTEBOOK = joinpath(EXAMPLEPATH, NOTEBOOK_FILENAME) +const MARKDOWN = joinpath(OUTDIR, string(EXAMPLE, ".md")) + +isfile(NOTEBOOK) || error("example $(EXAMPLE) must include $(NOTEBOOK_FILENAME)") + +if isfile(MARKDOWN) + @info "Skipping $(EXAMPLE): cached output found at $(MARKDOWN)" + exit(0) +end + +# Keep notebook execution in the example's own environment. 
+Pkg.activate(EXAMPLEPATH) +Pkg.instantiate() + +Pkg.activate(DOCS_ENV) +Pkg.instantiate() +using IJulia + +const KERNEL_NAME = "gf-docs-julia-$(VERSION.major).$(VERSION.minor)" +IJulia.installkernel(KERNEL_NAME, "--project=$(DOCS_ENV)"; specname=KERNEL_NAME) + +function run_nbconvert( + examplepath::AbstractString, + outdir::AbstractString, + name::AbstractString, + kernel::AbstractString, +) + jupyter = Sys.which("jupyter") + isnothing(jupyter) && error( + "jupyter executable not found. Install it (e.g. `pip install jupyter nbconvert`) before building docs.", + ) + + cmd = `$(jupyter) nbconvert --to markdown --execute --ExecutePreprocessor.timeout=3600 --ExecutePreprocessor.kernel_name=$(kernel) --output=$(name) --output-dir=$(outdir) $(NOTEBOOK_FILENAME)` + run(pipeline(Cmd(cmd; dir=examplepath); stdin=devnull, stdout=devnull, stderr=stderr)) + return nothing +end + +function inject_edit_url(markdown_path::AbstractString, example::AbstractString) + content = read(markdown_path, String) + meta_block = string( + "```@meta\n", + "EditURL = \"../../../examples/", + example, + "/", + NOTEBOOK_FILENAME, + "\"\n", + "```\n\n", + ) + + if occursin(r"(?m)^```@meta$", content) + content = replace(content, r"(?s)^```@meta.*?```\s*" => meta_block; count=1) + else + content = string(meta_block, content) + end + + write(markdown_path, content) + return nothing +end + +function inject_docs_badges(markdown_path::AbstractString, example::AbstractString) + content = read(markdown_path, String) + notebook_name = string(example, ".ipynb") + + colab_url = string( + "https://colab.research.google.com/github/", + REPO, + "/blob/main/", + PKG_SUBDIR, + "/examples/", + example, + "/", + notebook_name, + ) + source_url = string( + "https://github.com/", + REPO, + "/blob/main/", + PKG_SUBDIR, + "/examples/", + example, + "/", + notebook_name, + ) + badge_line = string( + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](", + colab_url, + ") [![View 
Source](https://img.shields.io/badge/View%20Source-GitHub-181717?logo=github)](", + source_url, + ")", + ) + + # Remove existing notebook badge lines, then inject docs-specific badges. + content = replace(content, r"(?m)^\[\!\[Open in Colab\].*\n?" => "";) + content = replace(content, r"\n{3,}" => "\n\n";) + + heading_match = match(r"(?m)^# .+$", content) + if isnothing(heading_match) + content = string(badge_line, "\n\n", content) + else + head_start = heading_match.offset + head_end = head_start + ncodeunits(heading_match.match) - 1 + head = content[1:head_end] + tail = content[(head_end + 1):end] + tail = replace(tail, r"^\n+" => "") + content = string(head, "\n\n", badge_line, "\n\n", tail) + end + + write(markdown_path, content) + return nothing +end + +run_nbconvert(EXAMPLEPATH, OUTDIR, EXAMPLE, KERNEL_NAME) +inject_docs_badges(MARKDOWN, EXAMPLE) +inject_edit_url(MARKDOWN, EXAMPLE) diff --git a/GeneralisedFilters/docs/src/assets/examples/default.svg b/GeneralisedFilters/docs/src/assets/examples/default.svg new file mode 100644 index 00000000..a75b498e --- /dev/null +++ b/GeneralisedFilters/docs/src/assets/examples/default.svg @@ -0,0 +1,12 @@ + + + + + + + + + + GeneralisedFilters + Notebook Example + diff --git a/GeneralisedFilters/docs/src/assets/examples/trend-inflation.svg b/GeneralisedFilters/docs/src/assets/examples/trend-inflation.svg new file mode 100644 index 00000000..c2506b9e --- /dev/null +++ b/GeneralisedFilters/docs/src/assets/examples/trend-inflation.svg @@ -0,0 +1,610 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/GeneralisedFilters/examples/trend-inflation/Project.toml b/GeneralisedFilters/examples/trend-inflation/Project.toml index dd1155bf..9eced912 100644 --- a/GeneralisedFilters/examples/trend-inflation/Project.toml +++ b/GeneralisedFilters/examples/trend-inflation/Project.toml @@ -5,7 +5,6 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7" -Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/GeneralisedFilters/examples/trend-inflation/script-alt.jl b/GeneralisedFilters/examples/trend-inflation/script-alt.jl deleted file mode 100644 index 223e5577..00000000 --- a/GeneralisedFilters/examples/trend-inflation/script-alt.jl +++ /dev/null @@ -1,154 +0,0 @@ -using GeneralisedFilters -using SSMProblems -using Distributions - -using Random -using StatsBase - -include("utilities.jl") - -## STATE PRIORS ############################################################################ - -struct LocalLevelTrendPrior{T<:Real} 
<: StatePrior end - -function SSMProblems.distribution(prior::LocalLevelTrendPrior{T}; kwargs...) where {T} - return product_distribution( - Normal(zero(T), T(5)), Normal(zero(T), T(1)), Normal(zero(T), T(1)) - ) -end - -struct OutlierAdjustedTrendPrior{T<:Real} <: StatePrior end - -function SSMProblems.distribution(prior::OutlierAdjustedTrendPrior{T}; kwargs...) where {T} - return product_distribution( - Normal(zero(T), T(5)), Normal(zero(T), T(1)), Normal(zero(T), T(1)), Dirac(one(T)) - ) -end - -## LATENT DYNAMICS ######################################################################### - -struct LocalLevelTrend{ΓT<:AbstractVector} <: LatentDynamics - γ::ΓT -end - -function SSMProblems.logdensity( - proc::LocalLevelTrend, ::Integer, prev_state, state, kwargs... -) - vol_prob = logpdf(MvNormal(prev_state[2:end], proc.γ), state[2:end]) - trend_prob = logpdf(Normal(prev_state[1], exp(prev_state[2] / 2)), state[1]) - return vol_prob + trend_prob -end - -function SSMProblems.simulate( - rng::AbstractRNG, - proc::LocalLevelTrend, - step::Integer, - state::AbstractVector{T}; - kwargs..., -) where {T<:Real} - new_state = deepcopy(state) - new_state[2:3] += proc.γ .* randn(rng, T, 2) - new_state[1] += exp(new_state[2] / 2) * randn(T) - return new_state -end - -struct OutlierAdjustedTrend{ΓT<:AbstractVector} <: LatentDynamics - trend::LocalLevelTrend{ΓT} - switch_dist::Bernoulli - outlier_dist::Uniform -end - -function SSMProblems.logdensity( - proc::OutlierAdjustedTrend, step::Integer, prev_state, state, kwargs... -) - return SSMProblems.logdensity(proc.trend, step, prev_state, state; kwargs...) -end - -function SSMProblems.simulate( - rng::AbstractRNG, - proc::OutlierAdjustedTrend, - step::Integer, - state::AbstractVector{T}; - kwargs..., -) where {T<:Real} - new_state = SSMProblems.simulate(rng, proc.trend, step, state; kwargs...) - new_state[4] = rand(rng, proc.switch_dist) ? 
rand(rng, proc.outlier_dist) : one(T) - return new_state -end - -## OBSERVATION PROCESS ##################################################################### - -struct OutlierAdjustedObservation <: ObservationProcess end - -function SSMProblems.distribution( - proc::OutlierAdjustedObservation, step::Integer, state::AbstractVector; kwargs... -) - return Normal(state[1], sqrt(state[4]) * exp(state[3] / 2)) -end - -struct SimpleObservation <: ObservationProcess end - -function SSMProblems.distribution( - proc::SimpleObservation, step::Integer, state::AbstractVector; kwargs... -) - return Normal(state[1], exp(state[3] / 2)) -end - -## MAIN #################################################################################### - -# include UCSV as a baseline -function UCSV(γ::T) where {T<:Real} - return StateSpaceModel( - LocalLevelTrendPrior{T}(), LocalLevelTrend(fill(γ, 2)), SimpleObservation() - ) -end - -# quick demo of the outlier-adjusted univariate UCSV model -function UCSVO(γ::T, prob::T) where {T<:Real} - trend = LocalLevelTrend(fill(γ, 2)) - return StateSpaceModel( - OutlierAdjustedTrendPrior{T}(), - OutlierAdjustedTrend(trend, Bernoulli(prob), Uniform{T}(2, 10)), - OutlierAdjustedObservation(), - ) -end - -# wrapper to plot and demo the model -function plot_ucsv(rng::AbstractRNG, model, data) - alg = BF(2^14; threshold=1.0, resampler=Systematic()) - sparse_ancestry = GeneralisedFilters.AncestorCallback(nothing) - states, ll = GeneralisedFilters.filter(rng, model, alg, data; callback=sparse_ancestry) - - fig = Figure(; size=(1200, 500), fontsize=16) - dateticks = date_format(fred.data.date) - - all_paths = map(x -> hcat(x...), GeneralisedFilters.get_ancestry(sparse_ancestry.tree)) - mean_paths = mean(all_paths, Weights(StatsBase.weights(states.log_weights))) - - ax = Axis( - fig[1:2, 1]; - limits=(nothing, (-14, 18)), - title="Trend Inflation", - xtickformat=dateticks, - ) - - lines!(fig[1:2, 1], vcat(0, data...); color=:red, linestyle=:dash) - lines!(ax, 
mean_paths[1, :]; color=:black) - - ax1 = Axis(fig[1, 2]; title="Volatility", xtickformat=dateticks) - lines!(ax1, exp.(0.5 * mean_paths[2, :]); color=:black, label="permanent") - axislegend(ax1; position=:rt) - - ax2 = Axis(fig[2, 2]; xtickformat=dateticks) - lines!(ax2, exp.(0.5 * mean_paths[3, :]); color=:black, label="transitory") - axislegend(ax2; position=:lt) - - display(fig) - return ll -end - -rng = MersenneTwister(1234); - -# plot both models side by side, notice the difference in volatility -plot_ucsv(rng, UCSV(0.2), fred.data.value); -plot_ucsv(rng, UCSVO(0.2, 0.05), fred.data.value); diff --git a/GeneralisedFilters/examples/trend-inflation/script.jl b/GeneralisedFilters/examples/trend-inflation/script.jl deleted file mode 100644 index 126db1db..00000000 --- a/GeneralisedFilters/examples/trend-inflation/script.jl +++ /dev/null @@ -1,253 +0,0 @@ -# # Trend Inflation -# -# This example is a replication of the univariate state space model suggested by (Stock & -# Watson, 2016) using GeneralisedFilters to define a heirarchical model for use in Rao- -# Blackwellised particle filtering. - -using GeneralisedFilters -using SSMProblems -using Distributions -using Random -using StatsBase -using LinearAlgebra -using PDMats - -const GF = GeneralisedFilters - -INFL_PATH = joinpath(@__DIR__, "..", "..", "..", "examples", "trend-inflation"); #hide -# INFL_PATH = joinpath(@__DIR__) -include(joinpath(INFL_PATH, "utilities.jl")); #hide - -# ## Model Definition - -# We begin by defining the local level trend model, a linear Gaussian model with a weakly -# stationary random walk component. The dynamics of which are as follows: - -# ```math -# \begin{aligned} -# y_{t} &= x_{t} + \eta_{t} \\ -# x_{t+1} &= x_{t} + \varepsilon_{t} -# \end{aligned} -# ``` - -# However, this model is not enough to capture trend dynamics when faced with structural -# breaks. 
(Stock & Watson, 2007) suggest adding a stochastic volatiltiy component, defined -# like so: - -# ```math -# \begin{aligned} -# \log \sigma_{\eta, t+1} = \log \sigma_{\eta, t} + \nu_{\eta, t} \\ -# \log \sigma_{\varepsilon, t+1} = \log \sigma_{\varepsilon, t} + \nu_{\varepsilon, t} -# \end{aligned} -# ``` - -# where $\nu_{z,t} \sim N(0, \gamma)$ for $z \in \{ \varepsilon, \eta \}$. - -# Using `GeneralisedFilters`, we can construct a heirarchical version of this model such -# that the local level trend component is conditionally linear Gaussian on the volatility -# draws. - -# #### Stochastic Volatility Process - -# We begin by defining the non-linear dynamics, which aren't conditioned contemporaneous -# states. Since these processes are traditionally non-linear/non-Gaussian we use the -# SSMProblems interface to define the stochastic volatility components. - -struct StochasticVolatilityPrior{T<:Real} <: StatePrior end - -# - -function SSMProblems.distribution(prior::StochasticVolatilityPrior{T}; kwargs...) where {T} - return product_distribution(Normal(zero(T), T(1)), Normal(zero(T), T(1))) -end - -# For the dynamics, instead of using the `SSMProblems.distribution` utility, we only define -# the `simulate` method, which is sufficient for the RBPF. - -struct StochasticVolatility{ΓT<:AbstractVector} <: LatentDynamics - γ::ΓT -end - -# - -function SSMProblems.simulate( - rng::AbstractRNG, - proc::StochasticVolatility, - step::Integer, - state::AbstractVector{T}; - kwargs..., -) where {T<:Real} - new_state = deepcopy(state) - new_state[1:2] += proc.γ .* randn(rng, T, 2) - return new_state -end - -# #### Local Level Trend Process -# -# For the conditionally linear and Gaussian components, we subtype the model and provide a -# keyword argument as the conditional element. In this case $A$ and $b$ remain constant, but -# $Q$ is conditional on the log variance, stored in `new_outer` (the nomenclature chosen for -# heirarchical modeling). 
- -struct LocalLevelTrend <: LinearGaussianLatentDynamics end - -# - -GF.calc_A(::LocalLevelTrend, ::Integer; kwargs...) = [1;;] -GF.calc_b(::LocalLevelTrend, ::Integer; kwargs...) = [0;] -function GF.calc_Q(::LocalLevelTrend, ::Integer; new_outer, kwargs...) - return PDMat([exp(new_outer[1]);;]) -end - -# Similarly, we define the observation process conditional on a separate log variance. - -struct SimpleObservation <: LinearGaussianObservationProcess end - -# - -GF.calc_H(::SimpleObservation, ::Integer; kwargs...) = [1;;] -GF.calc_c(::SimpleObservation, ::Integer; kwargs...) = [0;] -function GF.calc_R(::SimpleObservation, ::Integer; new_outer, kwargs...) - return PDMat([exp(new_outer[2]);;]) -end - -# ### Unobserved Components with Stochastic Volatility - -# The state space model suggested by (Stock & Watson, 2007) can be constructed with the -# following method: - -function UCSV(γ::T) where {T<:Real} - stoch_vol_prior = StochasticVolatilityPrior{T}() - stoch_vol_process = StochasticVolatility(fill(γ, 2)) - - local_level_model = StateSpaceModel( - GF.HomogeneousGaussianPrior(zeros(T, 1), PDMat([100.0;;])), - LocalLevelTrend(), - SimpleObservation(), - ) - - return HierarchicalSSM(stoch_vol_prior, stoch_vol_process, local_level_model) -end; - -# For plotting, we can extract the ancestry of the Rao Blackwellised particles using the -# callback system. For our inflation data, this reduces to the following: - -rng = MersenneTwister(1234); -sparse_ancestry = GF.AncestorCallback(nothing); -states, ll = GF.filter( - rng, - UCSV(0.2), - RBPF(BF(2^12), KalmanFilter()), - [[pce] for pce in fred_data.value]; - callback=sparse_ancestry, -); - -# The `sparse_ancestry` object stores a sparse ancestry tree which we can use to approximate -# the smoothed series without an additional backwards pass. We can convert this data -# structure to a human readable array by using `GeneralisedFilters.get_ancestry` and then -# take the mean path by passing a custom function. 
- -trends, volatilities = mean_path(GF.get_ancestry(sparse_ancestry.tree), states); -plot_ucsv(trends[1, :], eachrow(volatilities), fred_data) - -# #### Outlier Adjustments - -# For additional robustness, (Stock & Watson, 2016) account for one-time measurement shocks -# and suggest an alteration in the observation equation, where - -# ```math -# \eta_{t} \sim N(0, s_{t} \cdot \sigma_{\eta, t}^2) \quad \quad s_{t} \sim \begin{cases} -# U(0,2) & \text{ with probability } p \\ -# \delta(1) & \text{ with probability } 1 - p -# \end{cases} -# ``` - -# The prior is the same as before, but with additional state which we can assume will always -# be 1; using the `Distributions` interface this is just `Dirac(1)` - -struct OutlierAdjustedVolatilityPrior{T<:Real} <: StatePrior end - -# - -function SSMProblems.distribution( - prior::OutlierAdjustedVolatilityPrior{T}; kwargs... -) where {T} - return product_distribution(Normal(zero(T), T(1)), Normal(zero(T), T(1)), Dirac(one(T))) -end - -# In terms of the model definition, we can construct a separate `LatentDynamics` which -# contains the same volatility process as before, but with the respective draw in the third -# component. - -struct OutlierAdjustedVolatility{ΓT} <: LatentDynamics - volatility::StochasticVolatility{ΓT} - switch_dist::Bernoulli - outlier_dist::Uniform -end - -# The simulation then calls the volatility process, and computes the outlier term in the -# third state - -function SSMProblems.simulate( - rng::AbstractRNG, - proc::OutlierAdjustedVolatility, - step::Integer, - state::AbstractVector{T}; - kwargs..., -) where {T<:Real} - new_state = SSMProblems.simulate(rng, proc.volatility, step, state; kwargs...) - new_state[3] = rand(rng, proc.switch_dist) ? rand(rng, proc.outlier_dist) : one(T) - return new_state -end - -# For the observation process, we define a new object where $R$ is dependent on both the -# measurement volatility as well as this outlier adjustment coefficient. 
- -struct OutlierAdjustedObservation <: LinearGaussianObservationProcess end - -# - -GF.calc_H(::OutlierAdjustedObservation, ::Integer; kwargs...) = [1;;] -GF.calc_c(::OutlierAdjustedObservation, ::Integer; kwargs...) = [0;] -function GF.calc_R(::OutlierAdjustedObservation, ::Integer; new_outer, kwargs...) - return PDMat([new_outer[3] * exp(new_outer[2]);;]) -end - -# ### Outlier Adjusted UCSV - -# The state space model suggested by (Stock & Watson, 2007) can be constructed with the -# following method: - -function UCSVO(γ::T, prob::T) where {T<:Real} - stoch_vol_prior = OutlierAdjustedVolatilityPrior{T}() - stoch_vol_process = OutlierAdjustedVolatility( - StochasticVolatility(fill(γ, 2)), Bernoulli(prob), Uniform{T}(2, 10) - ) - - local_level_model = StateSpaceModel( - GF.HomogeneousGaussianPrior(zeros(T, 1), PDMat([100.0;;])), - LocalLevelTrend(), - OutlierAdjustedObservation(), - ) - - return HierarchicalSSM(stoch_vol_prior, stoch_vol_process, local_level_model) -end; - -# We then repeat the same experiment, this time with an outlier probability of $p = 0.05$ - -rng = MersenneTwister(1234); -sparse_ancestry = GF.AncestorCallback(nothing) -states, ll = GF.filter( - rng, - UCSVO(0.2, 0.05), - RBPF(BF(2^12), KalmanFilter()), - [[pce] for pce in fred_data.value]; - callback=sparse_ancestry, -); - -# this process is identical to the last, except with an additional `volatilities` state -# which captures the outlier distance. We omit this feature in the plots, but the impact is -# clear when comparing the maximum transitory noise around the GFC. 
- -trends, volatilities = mean_path(GF.get_ancestry(sparse_ancestry.tree), states); -plot_ucsv(trends[1, :], eachrow(volatilities), fred_data) diff --git a/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb b/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb new file mode 100644 index 00000000..8ae65267 --- /dev/null +++ b/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb @@ -0,0 +1,592 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Trend Inflation\n", + "\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TuringLang/SSMProblems.jl/blob/main/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb) [![View Source](https://img.shields.io/badge/View%20Source-GitHub-181717?logo=github)](https://github.com/TuringLang/SSMProblems.jl/blob/main/GeneralisedFilters/examples/trend-inflation/trend-inflation.ipynb) [![Example Page](https://img.shields.io/badge/Example%20Page-Docs-0A7F2E)](https://turinglang.org/SSMProblems.jl/GeneralisedFilters/dev/examples/trend-inflation/)\n", + "\n", + "This example is a replication of the univariate state space model suggested by (Stock &\n", + "Watson, 2016) using GeneralisedFilters to define a heirarchical model for use in Rao-Blackwellised particle filtering." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "remove-output" + ] + }, + "outputs": [], + "source": [ + "using Pkg\n", + "\n", + "Pkg.activate(\".\")\n", + "if isfile(\"Project.toml\")\n", + " Pkg.instantiate()\n", + "else\n", + " Pkg.add([\n", + " \"SSMProblems\",\n", + " \"GeneralisedFilters\",\n", + " \"CSV\",\n", + " \"CairoMakie\",\n", + " \"DataFrames\",\n", + " \"Distributions\",\n", + " \"LogExpFunctions\",\n", + " \"PDMats\",\n", + " \"StatsBase\",\n", + " ])\n", + "end;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "remove-output" + ] + }, + "outputs": [], + "source": [ + "using GeneralisedFilters\n", + "using SSMProblems\n", + "using Distributions\n", + "using Random\n", + "using StatsBase\n", + "using LinearAlgebra\n", + "using PDMats\n", + "\n", + "const GF = GeneralisedFilters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "using Downloads\n", + "\n", + "const UTILITIES_URL = \"https://raw.githubusercontent.com/TuringLang/SSMProblems.jl/main/GeneralisedFilters/examples/trend-inflation/utilities.jl\"\n", + "const UTILITIES_PATH = isfile(\"utilities.jl\") ? \"utilities.jl\" : Downloads.download(UTILITIES_URL)\n", + "include(UTILITIES_PATH);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Definition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We begin by defining the local level trend model, a linear Gaussian model with a weakly\n", + "stationary random walk component. 
The dynamics of which are as follows:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "$$\n",
+    "\\begin{aligned}\n",
+    "    y_{t} &= x_{t} + \\eta_{t} \\\\\n",
+    "    x_{t+1} &= x_{t} + \\varepsilon_{t}\n",
+    "\\end{aligned}\n",
+    "$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "However, this model is not enough to capture trend dynamics when faced with structural\n",
+    "breaks. (Stock & Watson, 2007) suggest adding a stochastic volatility component, defined\n",
+    "like so:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "$$\n",
+    "\\begin{aligned}\n",
+    "    \\log \\sigma_{\\eta, t+1} = \\log \\sigma_{\\eta, t} + \\nu_{\\eta, t} \\\\\n",
+    "    \\log \\sigma_{\\varepsilon, t+1} = \\log \\sigma_{\\varepsilon, t} + \\nu_{\\varepsilon, t}\n",
+    "\\end{aligned}\n",
+    "$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "where $\\nu_{z,t} \\sim N(0, \\gamma)$ for $z \\in \\{ \\varepsilon, \\eta \\}$."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Using `GeneralisedFilters`, we can construct a hierarchical version of this model such\n",
+    "that the local level trend component is conditionally linear Gaussian on the volatility\n",
+    "draws."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Stochastic Volatility Process"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We begin by defining the non-linear dynamics, which aren't conditioned on contemporaneous\n",
+    "states. Since these processes are traditionally non-linear/non-Gaussian we use the\n",
+    "SSMProblems interface to define the stochastic volatility components."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "struct StochasticVolatilityPrior{T<:Real} <: StatePrior end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "function SSMProblems.distribution(prior::StochasticVolatilityPrior{T}; kwargs...) where {T}\n", + " return product_distribution(Normal(zero(T), T(1)), Normal(zero(T), T(1)))\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the dynamics, instead of using the `SSMProblems.distribution` utility, we only define\n", + "the `simulate` method, which is sufficient for the RBPF." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "struct StochasticVolatility{ΓT<:AbstractVector} <: LatentDynamics\n", + " γ::ΓT\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "function SSMProblems.simulate(\n", + " rng::AbstractRNG,\n", + " proc::StochasticVolatility,\n", + " step::Integer,\n", + " state::AbstractVector{T};\n", + " kwargs...,\n", + ") where {T<:Real}\n", + " new_state = deepcopy(state)\n", + " new_state[1:2] += proc.γ .* randn(rng, T, 2)\n", + " return new_state\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Local Level Trend Process\n", + "\n", + "For the conditionally linear and Gaussian components, we subtype the model and provide a\n", + "keyword argument as the conditional element. In this case $A$ and $b$ remain constant, but\n", + "$Q$ is conditional on the log variance, stored in `new_outer` (the nomenclature chosen for\n", + "heirarchical modeling)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "struct LocalLevelTrend <: LinearGaussianLatentDynamics end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "GF.calc_A(::LocalLevelTrend, ::Integer; kwargs...) = [1;;]\n", + "GF.calc_b(::LocalLevelTrend, ::Integer; kwargs...) = [0;]\n", + "function GF.calc_Q(::LocalLevelTrend, ::Integer; new_outer, kwargs...)\n", + " return PDMat([exp(new_outer[1]);;])\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similarly, we define the observation process conditional on a separate log variance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "struct SimpleObservation <: LinearGaussianObservationProcess end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "GF.calc_H(::SimpleObservation, ::Integer; kwargs...) = [1;;]\n", + "GF.calc_c(::SimpleObservation, ::Integer; kwargs...) 
= [0;]\n", + "function GF.calc_R(::SimpleObservation, ::Integer; new_outer, kwargs...)\n", + " return PDMat([exp(new_outer[2]);;])\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Unobserved Components with Stochastic Volatility" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The state space model suggested by (Stock & Watson, 2007) can be constructed with the\n", + "following method:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "function UCSV(γ::T) where {T<:Real}\n", + " stoch_vol_prior = StochasticVolatilityPrior{T}()\n", + " stoch_vol_process = StochasticVolatility(fill(γ, 2))\n", + "\n", + " local_level_model = StateSpaceModel(\n", + " GF.HomogeneousGaussianPrior(zeros(T, 1), PDMat([100.0;;])),\n", + " LocalLevelTrend(),\n", + " SimpleObservation(),\n", + " )\n", + "\n", + " return HierarchicalSSM(stoch_vol_prior, stoch_vol_process, local_level_model)\n", + "end;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For plotting, we can extract the ancestry of the Rao Blackwellised particles using the\n", + "callback system. For our inflation data, this reduces to the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = MersenneTwister(1234);\n", + "sparse_ancestry = GF.AncestorCallback(nothing);\n", + "states, ll = GF.filter(\n", + " rng,\n", + " UCSV(0.2),\n", + " RBPF(BF(2^12), KalmanFilter()),\n", + " [[pce] for pce in fred_data.value];\n", + " callback=sparse_ancestry,\n", + ");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `sparse_ancestry` object stores a sparse ancestry tree which we can use to approximate\n", + "the smoothed series without an additional backwards pass. 
We can convert this data\n", + "structure to a human readable array by using `GeneralisedFilters.get_ancestry` and then\n", + "take the mean path by passing a custom function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trends, volatilities = mean_path(GF.get_ancestry(sparse_ancestry.tree), states);\n", + "plot_ucsv(trends[1, :], eachrow(volatilities), fred_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Outlier Adjustments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For additional robustness, (Stock & Watson, 2016) account for one-time measurement shocks\n", + "and suggest an alteration in the observation equation, where" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$$\n", + "\\eta_{t} \\sim N(0, s_{t} \\cdot \\sigma_{\\eta, t}^2) \\quad \\quad s_{t} \\sim \\begin{cases}\n", + "U(0,2) & \\text{ with probability } p \\\\\n", + "\\delta(1) & \\text{ with probability } 1 - p\n", + "\\end{cases}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The prior is the same as before, but with additional state which we can assume will always\n", + "be 1; using the `Distributions` interface this is just `Dirac(1)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "struct OutlierAdjustedVolatilityPrior{T<:Real} <: StatePrior end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "function SSMProblems.distribution(\n", + " prior::OutlierAdjustedVolatilityPrior{T}; kwargs...\n", + ") where {T}\n", + " return product_distribution(Normal(zero(T), T(1)), Normal(zero(T), T(1)), Dirac(one(T)))\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In terms of the model definition, we can construct a separate 
`LatentDynamics` which\n", + "contains the same volatility process as before, but with the respective draw in the third\n", + "component." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "struct OutlierAdjustedVolatility{ΓT} <: LatentDynamics\n", + " volatility::StochasticVolatility{ΓT}\n", + " switch_dist::Bernoulli\n", + " outlier_dist::Uniform\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The simulation then calls the volatility process, and computes the outlier term in the\n", + "third state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "function SSMProblems.simulate(\n", + " rng::AbstractRNG,\n", + " proc::OutlierAdjustedVolatility,\n", + " step::Integer,\n", + " state::AbstractVector{T};\n", + " kwargs...,\n", + ") where {T<:Real}\n", + " new_state = SSMProblems.simulate(rng, proc.volatility, step, state; kwargs...)\n", + " new_state[3] = rand(rng, proc.switch_dist) ? rand(rng, proc.outlier_dist) : one(T)\n", + " return new_state\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the observation process, we define a new object where $R$ is dependent on both the\n", + "measurement volatility as well as this outlier adjustment coefficient." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "struct OutlierAdjustedObservation <: LinearGaussianObservationProcess end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "GF.calc_H(::OutlierAdjustedObservation, ::Integer; kwargs...) = [1;;]\n", + "GF.calc_c(::OutlierAdjustedObservation, ::Integer; kwargs...) 
= [0;]\n", + "function GF.calc_R(::OutlierAdjustedObservation, ::Integer; new_outer, kwargs...)\n", + " return PDMat([new_outer[3] * exp(new_outer[2]);;])\n", + "end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Outlier Adjusted UCSV" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The state space model suggested by (Stock & Watson, 2007) can be constructed with the\n", + "following method:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "function UCSVO(γ::T, prob::T) where {T<:Real}\n", + " stoch_vol_prior = OutlierAdjustedVolatilityPrior{T}()\n", + " stoch_vol_process = OutlierAdjustedVolatility(\n", + " StochasticVolatility(fill(γ, 2)), Bernoulli(prob), Uniform{T}(2, 10)\n", + " )\n", + "\n", + " local_level_model = StateSpaceModel(\n", + " GF.HomogeneousGaussianPrior(zeros(T, 1), PDMat([100.0;;])),\n", + " LocalLevelTrend(),\n", + " OutlierAdjustedObservation(),\n", + " )\n", + "\n", + " return HierarchicalSSM(stoch_vol_prior, stoch_vol_process, local_level_model)\n", + "end;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We then repeat the same experiment, this time with an outlier probability of $p = 0.05$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = MersenneTwister(1234);\n", + "sparse_ancestry = GF.AncestorCallback(nothing)\n", + "states, ll = GF.filter(\n", + " rng,\n", + " UCSVO(0.2, 0.05),\n", + " RBPF(BF(2^12), KalmanFilter()),\n", + " [[pce] for pce in fred_data.value];\n", + " callback=sparse_ancestry,\n", + ");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "this process is identical to the last, except with an additional `volatilities` state\n", + "which captures the outlier distance. 
We omit this feature in the plots, but the impact is\n", + "clear when comparing the maximum transitory noise around the GFC." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trends, volatilities = mean_path(GF.get_ancestry(sparse_ancestry.tree), states);\n", + "p = plot_ucsv(trends[1, :], eachrow(volatilities), fred_data)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.12.3", + "language": "julia", + "name": "julia-1.12" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 3 +} diff --git a/GeneralisedFilters/examples/trend-inflation/utilities.jl b/GeneralisedFilters/examples/trend-inflation/utilities.jl index e3f9e0da..592da07c 100644 --- a/GeneralisedFilters/examples/trend-inflation/utilities.jl +++ b/GeneralisedFilters/examples/trend-inflation/utilities.jl @@ -1,9 +1,20 @@ using CSV, DataFrames using CairoMakie using Dates +using Downloads using LogExpFunctions -fred_data = CSV.read(joinpath(INFL_PATH, "data.csv"), DataFrame) +const FRED_DATA_URL = "https://raw.githubusercontent.com/TuringLang/SSMProblems.jl/main/GeneralisedFilters/examples/trend-inflation/data.csv" + +function load_fred_data() + for path in (joinpath(pwd(), "data.csv"), joinpath(@__DIR__, "data.csv")) + isfile(path) && return CSV.read(path, DataFrame) + end + + return CSV.read(Downloads.download(FRED_DATA_URL), DataFrame) +end + +fred_data = load_fred_data() ## PLOTTING UTILITIES ###################################################################### @@ -25,7 +36,7 @@ function mean_path(paths::Vector{Vector{T}}, states) where {T<:GeneralisedFilter end function plot_ucsv(trend, volatilities, fred_data) - fig = Figure(; size=(1200, 500), fontsize=16) + fig = Figure(; size=(1200, 675), fontsize=16) dateticks = date_format(fred_data.date) trend_ax = Axis( diff --git 
a/scripts/setup-git-hooks.sh b/scripts/setup-git-hooks.sh new file mode 100755 index 00000000..ebf41647 --- /dev/null +++ b/scripts/setup-git-hooks.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail + +git config core.hooksPath .githooks +echo "Configured git hooks path: $(git config --get core.hooksPath)" diff --git a/scripts/strip_notebook_outputs.py b/scripts/strip_notebook_outputs.py new file mode 100755 index 00000000..14f7ef18 --- /dev/null +++ b/scripts/strip_notebook_outputs.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +"""Strip execution state from Jupyter notebooks. + +Usage: + python scripts/strip_notebook_outputs.py [more.ipynb ...] +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + + +def strip_notebook(path: Path) -> bool: + original = path.read_text(encoding="utf-8") + data = json.loads(original) + changed = False + + for cell in data.get("cells", []): + if cell.get("cell_type") != "code": + continue + + if cell.get("execution_count") is not None: + cell["execution_count"] = None + changed = True + + if cell.get("outputs"): + cell["outputs"] = [] + changed = True + + metadata = data.get("metadata") + if isinstance(metadata, dict) and "widgets" in metadata: + del metadata["widgets"] + changed = True + + if not changed: + return False + + path.write_text( + json.dumps(data, indent=1, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + return True + + +def main(argv: list[str]) -> int: + if len(argv) < 2: + print("usage: strip_notebook_outputs.py [more.ipynb ...]", file=sys.stderr) + return 2 + + failed = False + changed_files: list[str] = [] + + for raw in argv[1:]: + path = Path(raw) + if not path.exists(): + continue + + try: + if strip_notebook(path): + changed_files.append(str(path)) + except Exception as exc: # pragma: no cover - defensive for hook usage + print(f"failed to strip {path}: {exc}", file=sys.stderr) + failed = True + + if changed_files: + print("stripped notebook outputs:") + for 
filename in changed_files:
+            # Bug fix: interpolate the actual path instead of printing the
+            # literal placeholder text "(unknown)".
+            print(f"  - {filename}")
+
+    return 1 if failed else 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main(sys.argv))
diff --git a/scripts/sync_notebook_badges.py b/scripts/sync_notebook_badges.py
new file mode 100755
index 00000000..10567dfb
--- /dev/null
+++ b/scripts/sync_notebook_badges.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+"""Normalize example notebook top-cell badges.
+
+Usage:
+    python scripts/sync_notebook_badges.py <notebook.ipynb> [more.ipynb ...]
+"""
+
+from __future__ import annotations
+
+import json
+import re
+import sys
+from pathlib import Path
+
+REPO = "TuringLang/SSMProblems.jl"
+
+
+def source_to_text(source: object) -> str:
+    """Join a notebook cell ``source`` (list of strings or a string) into one string."""
+    if isinstance(source, list):
+        return "".join(str(part) for part in source)
+    if isinstance(source, str):
+        return source
+    return ""
+
+
+def text_to_source(text: str) -> list[str]:
+    """Split text back into the nbformat source convention (trailing ``\\n`` on all but the last line)."""
+    lines = text.splitlines()
+    if not lines:
+        return []
+    return [f"{line}\n" for line in lines[:-1]] + [lines[-1]]
+
+
+def title_from_slug(slug: str) -> str:
+    """Turn a directory slug like ``trend-inflation`` into a title like ``Trend Inflation``."""
+    return " ".join(word.capitalize() for word in slug.replace("_", "-").split("-"))
+
+
+def is_badge_line(line: str) -> bool:
+    """Return True when *line* is one of the generated badge markdown lines."""
+    return (
+        "colab.research.google.com" in line
+        or "View%20Source-GitHub" in line
+        or "Example%20Page-Docs" in line
+    )
+
+
+def notebook_urls(path: Path) -> tuple[str, str, str]:
+    """Derive the (colab, source, docs) URLs for a notebook under ``<pkg>/examples/<slug>/``."""
+    parts = path.parts
+    try:
+        pkg_idx = parts.index("GeneralisedFilters")
+    except ValueError:
+        try:
+            pkg_idx = parts.index("SSMProblems")
+        except ValueError as exc:
+            raise ValueError(f"notebook path does not include known package dir: {path}") from exc
+
+    pkg = parts[pkg_idx]
+    rel_path = "/".join(parts[pkg_idx:])
+
+    if len(parts) < pkg_idx + 4 or parts[pkg_idx + 1] != "examples":
+        raise ValueError(f"notebook path is not under {pkg}/examples/: {path}")
+
+    slug = parts[pkg_idx + 2]
+    colab = f"https://colab.research.google.com/github/{REPO}/blob/main/{rel_path}"
+    source = f"https://github.com/{REPO}/blob/main/{rel_path}"
+    docs = 
f"https://turinglang.org/SSMProblems.jl/{pkg}/dev/examples/{slug}/" + return colab, source, docs + + +def build_badge_line(colab: str, source: str, docs: str) -> str: + return ( + f"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)]({colab}) " + f"[![View Source](https://img.shields.io/badge/View%20Source-GitHub-181717?logo=github)]({source}) " + f"[![Example Page](https://img.shields.io/badge/Example%20Page-Docs-0A7F2E)]({docs})" + ) + + +def normalize_notebook(path: Path) -> bool: + original = path.read_text(encoding="utf-8") + data = json.loads(original) + cells = data.get("cells", []) + if not isinstance(cells, list): + raise ValueError(f"notebook cells are invalid in {path}") + + colab, source, docs = notebook_urls(path) + badges = build_badge_line(colab, source, docs) + + if not cells or cells[0].get("cell_type") != "markdown": + slug = path.parent.name + title = f"# {title_from_slug(slug)}" + top_cell = { + "cell_type": "markdown", + "metadata": {}, + "source": text_to_source(f"{title}\n\n{badges}"), + } + cells.insert(0, top_cell) + data["cells"] = cells + path.write_text(json.dumps(data, indent=1, ensure_ascii=False) + "\n", encoding="utf-8") + return True + + first = cells[0] + source_text = source_to_text(first.get("source", [])) + lines = source_text.splitlines() + + title = None + for idx, line in enumerate(lines): + if re.match(r"^#\s+", line): + title = line.strip() + lines = lines[idx + 1 :] + break + + if title is None: + title = f"# {title_from_slug(path.parent.name)}" + + body = [line for line in lines if not is_badge_line(line)] + while body and not body[0].strip(): + body.pop(0) + while body and not body[-1].strip(): + body.pop() + + new_lines = [title, "", badges] + if body: + new_lines += [""] + body + + new_source = text_to_source("\n".join(new_lines)) + if first.get("source") == new_source: + return False + + first["source"] = new_source + path.write_text(json.dumps(data, indent=1, ensure_ascii=False) + "\n", 
encoding="utf-8")
+    return True
+
+
+def main(argv: list[str]) -> int:
+    """Normalize badges in each notebook given on the command line.
+
+    Missing paths are skipped; returns 0 on success, 1 if any notebook
+    failed to process, 2 on usage error.
+    """
+    if len(argv) < 2:
+        # Show the required positional argument in the usage text.
+        print("usage: sync_notebook_badges.py <notebook.ipynb> [more.ipynb ...]", file=sys.stderr)
+        return 2
+
+    changed: list[str] = []
+    failed = False
+
+    for raw in argv[1:]:
+        path = Path(raw)
+        if not path.exists():
+            continue
+
+        try:
+            if normalize_notebook(path):
+                changed.append(str(path))
+        except Exception as exc:
+            print(f"failed to sync badges for {path}: {exc}", file=sys.stderr)
+            failed = True
+
+    if changed:
+        print("synchronized notebook badges:")
+        for filename in changed:
+            # Bug fix: interpolate the actual path instead of printing the
+            # literal placeholder text "(unknown)".
+            print(f"  - {filename}")
+
+    return 1 if failed else 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main(sys.argv))