From 9c516fdccbd7e36055612c178f1c0d31d9e27d9e Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 11:03:51 -0700 Subject: [PATCH 01/74] Project.toml: support Symbolics v7 & Utils v4 --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index b445f3db..eabc5b36 100644 --- a/Project.toml +++ b/Project.toml @@ -40,8 +40,8 @@ Optim = "1" PrettyTables = "3" ProximalAlgorithms = "0.7" StatsBase = "0.33, 0.34" -Symbolics = "4, 5, 6" -SymbolicUtils = "1.4 - 1.5, 1.7, 2, 3" +Symbolics = "4, 5, 6, 7" +SymbolicUtils = "1.4 - 1.5, 1.7, 2, 3, 4" StatsAPI = "1" [extras] From 6e1ffaaa76f21bfd44fe8ee506fa370f9d4e22f4 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 11:03:51 -0700 Subject: [PATCH 02/74] prepare_start_params(): tighten type check --- src/optimizer/abstract.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/optimizer/abstract.jl b/src/optimizer/abstract.jl index e9a8c47b..0c7913c4 100644 --- a/src/optimizer/abstract.jl +++ b/src/optimizer/abstract.jl @@ -154,11 +154,24 @@ function prepare_start_params(start_val::AbstractVector, model::AbstractSem; kwa "The length of `start_val` vector ($(length(start_val))) does not match the number of model parameters ($(nparams(model))).", ), ) + (eltype(start_val) <: Number) || throw( + TypeError( + :prepare_start_params, + "start_val elements must be numeric", + Number, + eltype(start_val), + ), + ) return start_val end function prepare_start_params(start_val::AbstractDict, model::AbstractSem; kwargs...) 
- return [start_val[param] for param in params(model)] # convert to a vector + # convert to a vector + return prepare_start_params( + [start_val[param] for param in params(model)], + model; + kwargs..., + ) end # get from the ParameterTable (potentially from a different model with match param names) From 32068dee990fd5474387aeaef008186f08e2e627 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 11:03:51 -0700 Subject: [PATCH 03/74] SemImplied/SemLossFun: drop meanstructure kwarg - for SemImplied require spec::SemSpec as positional - for SemLossFunction require implied argument --- src/frontend/specification/Sem.jl | 10 +++++- src/implied/RAM/generic.jl | 24 +++++-------- src/implied/RAM/symbolic.jl | 32 ++++++----------- src/implied/abstract.jl | 14 -------- src/loss/ML/FIML.jl | 11 +++--- src/loss/ML/ML.jl | 3 +- src/loss/WLS/WLS.jl | 27 ++++++++------ .../recover_parameters_twofact.jl | 2 +- test/unit_tests/model.jl | 35 ++++--------------- 9 files changed, 59 insertions(+), 99 deletions(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index 53858abd..a47bad4b 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -109,7 +109,15 @@ function get_fields!(kwargs, specification, observed, implied, loss) # implied if !isa(implied, SemImplied) - implied = implied(; specification, kwargs...) + # FIXME remove this implicit logic + # SemWLS only accepts vech-ed implied covariance + if isa(loss, Type) && (loss <: SemWLS) && !haskey(kwargs, :vech) + implied_kwargs = copy(kwargs) + implied_kwargs[:vech] = true + else + implied_kwargs = kwargs + end + implied = implied(specification; implied_kwargs...) end kwargs[:implied] = implied diff --git a/src/implied/RAM/generic.jl b/src/implied/RAM/generic.jl index 4c1fa323..d57500a3 100644 --- a/src/implied/RAM/generic.jl +++ b/src/implied/RAM/generic.jl @@ -6,14 +6,10 @@ Model implied covariance and means via RAM notation. 
# Constructor - RAM(;specification, - meanstructure = false, - gradient = true, - kwargs...) + RAM(; specification, gradient = true, kwargs...) # Arguments - `specification`: either a `RAMMatrices` or `ParameterTable` object -- `meanstructure::Bool`: does the model have a meanstructure? - `gradient::Bool`: is gradient-based optimization used # Extended help @@ -53,9 +49,9 @@ Vector of indices of each parameter in the respective RAM matrix: - `ram.M_indices` Additional interfaces -- `ram.F⨉I_A⁻¹` -> ``F(I-A)^{-1}`` -- `ram.F⨉I_A⁻¹S` -> ``F(I-A)^{-1}S`` -- `ram.I_A` -> ``I-A`` +- `F⨉I_A⁻¹(::RAM)` -> ``F(I-A)^{-1}`` +- `F⨉I_A⁻¹S(::RAM)` -> ``F(I-A)^{-1}S`` +- `I_A(::RAM)` -> ``I-A`` Only available in gradient! calls: - `ram.I_A⁻¹` -> ``(I-A)^{-1}`` @@ -90,15 +86,13 @@ end ### Constructors ############################################################################################ -function RAM(; - specification::SemSpecification, +function RAM( + spec::SemSpecification; + #vech = false, gradient_required = true, - meanstructure = false, kwargs..., ) - ram_matrices = convert(RAMMatrices, specification) - - check_meanstructure_specification(meanstructure, ram_matrices) + ram_matrices = convert(RAMMatrices, spec) # get dimensions of the model n_par = nparams(ram_matrices) @@ -126,7 +120,7 @@ function RAM(; end # μ - if meanstructure + if !isnothing(ram_matrices.M) MS = HasMeanStruct M_pre = materialize(ram_matrices.M, rand_params) ∇M = gradient_required ? sparse_gradient(ram_matrices.M) : nothing diff --git a/src/implied/RAM/symbolic.jl b/src/implied/RAM/symbolic.jl index df7c497a..0f7868da 100644 --- a/src/implied/RAM/symbolic.jl +++ b/src/implied/RAM/symbolic.jl @@ -12,12 +12,10 @@ Subtype of `SemImplied` that implements the RAM notation with symbolic precomput gradient = true, hessian = false, approximate_hessian = false, - meanstructure = false, kwargs...) 
# Arguments - `specification`: either a `RAMMatrices` or `ParameterTable` object -- `meanstructure::Bool`: does the model have a meanstructure? - `gradient::Bool`: is gradient-based optimization used - `hessian::Bool`: is hessian-based optimization used - `approximate_hessian::Bool`: for hessian based optimization: should the hessian be approximated @@ -79,20 +77,16 @@ end ### Constructors ############################################################################################ -function RAMSymbolic(; - specification::SemSpecification, - loss_types = nothing, - vech = false, - simplify_symbolics = false, - gradient = true, - hessian = false, - meanstructure = false, - approximate_hessian = false, +function RAMSymbolic( + spec::SemSpecification; + vech::Bool = false, + simplify_symbolics::Bool = false, + gradient::Bool = true, + hessian::Bool = false, + approximate_hessian::Bool = false, kwargs..., ) - ram_matrices = convert(RAMMatrices, specification) - - check_meanstructure_specification(meanstructure, ram_matrices) + ram_matrices = convert(RAMMatrices, spec) n_par = nparams(ram_matrices) par = (Symbolics.@variables θ[1:n_par])[1] @@ -102,10 +96,6 @@ function RAMSymbolic(; M = !isnothing(ram_matrices.M) ? materialize(Num, ram_matrices.M, par) : nothing F = ram_matrices.F - if !isnothing(loss_types) && any(T -> T <: SemWLS, loss_types) - vech = true - end - I_A⁻¹ = neumann_series(A) # Σ @@ -146,7 +136,7 @@ function RAMSymbolic(; end # μ - if meanstructure + if !isnothing(ram_matrices.M) MS = HasMeanStruct μ_sym = eval_μ_symbolic(M, I_A⁻¹, F; simplify = simplify_symbolics) μ_eval! 
= Symbolics.build_function(μ_sym, par, expression = Val{false})[2] @@ -230,10 +220,10 @@ end ############################################################################################ # expected covariations of observed vars -function eval_Σ_symbolic(S, I_A⁻¹, F; vech = false, simplify = false) +function eval_Σ_symbolic(S, I_A⁻¹, F; vech::Bool = false, simplify::Bool = false) Σ = F * I_A⁻¹ * S * permutedims(I_A⁻¹) * permutedims(F) Σ = Array(Σ) - vech && (Σ = Σ[tril(trues(size(F, 1), size(F, 1)))]) + vech && (Σ = SEM.vech(Σ)) if simplify Threads.@threads for i in eachindex(Σ) Σ[i] = Symbolics.simplify(Σ[i]) diff --git a/src/implied/abstract.jl b/src/implied/abstract.jl index 6d298f65..d4868d74 100644 --- a/src/implied/abstract.jl +++ b/src/implied/abstract.jl @@ -31,17 +31,3 @@ function check_acyclic(A::AbstractMatrix; verbose::Bool = false) return A end end - -# Verify that the `meanstructure` argument aligns with the model specification. -function check_meanstructure_specification(meanstructure, ram_matrices) - if meanstructure & isnothing(ram_matrices.M) - throw(ArgumentError( - "You set `meanstructure = true`, but your model specification contains no mean parameters." - )) - end - if !meanstructure & !isnothing(ram_matrices.M) - throw(ArgumentError( - "If your model specification contains mean parameters, you have to set `Sem(..., meanstructure = true)`." - )) - end -end \ No newline at end of file diff --git a/src/loss/ML/FIML.jl b/src/loss/ML/FIML.jl index da5ccb7c..8572b15a 100644 --- a/src/loss/ML/FIML.jl +++ b/src/loss/ML/FIML.jl @@ -75,15 +75,16 @@ Can handle observed data with missing values. # Constructor - SemFIML(; observed::SemObservedMissing, specification, kwargs...) + SemFIML(; observed::SemObservedMissing, implied::SemImplied, kwargs...) 
# Arguments -- `observed`: the observed data with missing values (see [`SemObservedMissing`](@ref)) -- `specification`: [`SemSpecification`](@ref) object +- `observed::SemObservedMissing`: the observed part of the model + (see [`SemObservedMissing`](@ref)) +- `implied::SemImplied`: the implied part of the model # Examples ```julia -my_fiml = SemFIML(observed = my_observed, specification = my_parameter_table) +my_fiml = SemFIML(observed = my_observed, implied = my_implied) ``` # Interfaces @@ -118,7 +119,7 @@ function SemFIML(; observed::SemObservedMissing, implied, specification, kwargs. ExactHessian(), [SemFIMLPattern(pat) for pat in observed.patterns], zeros(nobserved_vars(observed), nobserved_vars(observed)), - CommutationMatrix(nvars(specification)), + CommutationMatrix(nvars(implied)), nothing, ) end diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index ce77ea9c..aae1dada 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -8,11 +8,10 @@ Maximum likelihood estimation. # Constructor - SemML(;observed, meanstructure = false, approximate_hessian = false, kwargs...) + SemML(; observed, approximate_hessian = false, kwargs...) # Arguments - `observed::SemObserved`: the observed part of the model -- `meanstructure::Bool`: does the model have a meanstructure? - `approximate_hessian::Bool`: if hessian-based optimization is used, should the hessian be swapped for an approximation # Examples diff --git a/src/loss/WLS/WLS.jl b/src/loss/WLS/WLS.jl index b7f66d55..9de011f6 100644 --- a/src/loss/WLS/WLS.jl +++ b/src/loss/WLS/WLS.jl @@ -10,8 +10,7 @@ At the moment only available with the `RAMSymbolic` implied type. # Constructor SemWLS(; - observed, - meanstructure = false, + observed, implied, wls_weight_matrix = nothing, wls_weight_matrix_mean = nothing, approximate_hessian = false, @@ -19,7 +18,7 @@ At the moment only available with the `RAMSymbolic` implied type. 
# Arguments - `observed`: the `SemObserved` part of the model -- `meanstructure::Bool`: does the model have a meanstructure? +- `implied::SemImplied`: the implied part of the model - `approximate_hessian::Bool`: should the hessian be swapped for an approximation - `wls_weight_matrix`: the weight matrix for weighted least squares. Defaults to GLS estimation (``0.5*(D^T*kron(S,S)*D)`` where D is the duplication matrix @@ -29,7 +28,7 @@ At the moment only available with the `RAMSymbolic` implied type. # Examples ```julia -my_wls = SemWLS(observed = my_observed) +my_wls = SemWLS(observed = my_observed, implied = my_implied) ``` # Interfaces @@ -50,12 +49,11 @@ SemWLS{HE}(args...) where {HE <: HessianEval} = SemWLS{HE, map(typeof, args)...}(HE(), args...) function SemWLS(; - observed, - implied, - wls_weight_matrix = nothing, - wls_weight_matrix_mean = nothing, - approximate_hessian = false, - meanstructure = false, + observed::SemObserved, + implied::SemImplied, + wls_weight_matrix::Union{AbstractMatrix, Nothing} = nothing, + wls_weight_matrix_mean::Union{AbstractMatrix, Nothing} = nothing, + approximate_hessian::Bool = false, kwargs..., ) if observed isa SemObservedMissing @@ -81,6 +79,10 @@ function SemWLS(; nobs_vars = nobserved_vars(observed) tril_ind = filter(x -> (x[1] >= x[2]), CartesianIndices(obs_cov(observed))) s = obs_cov(observed)[tril_ind] + size(s) == size(implied.Σ) || + throw(DimensionMismatch("SemWLS requires implied covariance to be in vech-ed form " * + "(vectorized lower triangular part of Σ matrix): $(size(s)) expected, $(size(implied.Σ)) found.\n" * + "$(nameof(typeof(implied))) must be constructed with vech=true.")) # compute V here if isnothing(wls_weight_matrix) @@ -94,9 +96,12 @@ function SemWLS(; "wls_weight_matrix has to be of size $(length(tril_ind))×$(length(tril_ind))", ) end + size(wls_weight_matrix) == (length(s), length(s)) || + DimensionMismatch("wls_weight_matrix has to be of size $(length(s))×$(length(s))") - if meanstructure + if 
MeanStruct(implied) == HasMeanStruct if isnothing(wls_weight_matrix_mean) + @warn "Computing WLS weight matrix for the meanstructure using obs_cov()" wls_weight_matrix_mean = inv(obs_cov(observed)) else size(wls_weight_matrix_mean) == (nobs_vars, nobs_vars) || DimensionMismatch( diff --git a/test/examples/recover_parameters/recover_parameters_twofact.jl b/test/examples/recover_parameters/recover_parameters_twofact.jl index 9f9503af..a4bd7d5f 100644 --- a/test/examples/recover_parameters/recover_parameters_twofact.jl +++ b/test/examples/recover_parameters/recover_parameters_twofact.jl @@ -53,7 +53,7 @@ start = [ repeat([0.5], 4) ] -implied_ml = RAMSymbolic(; specification = ram_matrices, start_val = start) +implied_ml = RAMSymbolic(ram_matrices; start_val = start) implied_ml.Σ_eval!(implied_ml.Σ, true_val) diff --git a/test/unit_tests/model.jl b/test/unit_tests/model.jl index d9f9254b..fbe2a937 100644 --- a/test/unit_tests/model.jl +++ b/test/unit_tests/model.jl @@ -46,13 +46,16 @@ function test_params_api(semobj, spec::SemSpecification) @test @inferred(param_labels(semobj)) == param_labels(spec) end -@testset "Sem(implied=$impliedtype, loss=SemML)" for impliedtype in (RAM, RAMSymbolic) - +@testset "Sem(implied=$impliedtype, loss=$losstype)" for (impliedtype, losstype) in [ + (RAM, SemML), + (RAMSymbolic, SemML), + (RAMSymbolic, SemWLS), +] model = Sem( specification = ram_matrices, observed = obs, implied = impliedtype, - loss = SemML, + loss = losstype, ) @test model isa Sem @@ -71,29 +74,3 @@ end @test @inferred(nsamples(model)) == nsamples(obs) end - -@testset "Sem(implied=RAMSymbolic, loss=SemWLS)" begin - - model = Sem( - specification = ram_matrices, - observed = obs, - implied = RAMSymbolic, - loss = SemWLS, - ) - - @test model isa Sem - @test @inferred(implied(model)) isa RAMSymbolic - @test @inferred(observed(model)) isa SemObserved - - test_vars_api(model, ram_matrices) - test_params_api(model, ram_matrices) - - test_vars_api(implied(model), ram_matrices) 
- test_params_api(implied(model), ram_matrices) - - @test @inferred(loss(model)) isa SemLoss - semloss = loss(model).functions[1] - @test semloss isa SemWLS - - @test @inferred(nsamples(model)) == nsamples(obs) -end \ No newline at end of file From e81cec008cb99ef492841bd1c4e0503ac45cca38 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 11:17:30 -0700 Subject: [PATCH 04/74] refactor Sem, SemEnsemble, SemLoss --- src/StructuralEquationModels.jl | 17 +- src/additional_functions/simulation.jl | 87 ++-- src/additional_functions/start_val/common.jl | 17 + .../start_val/start_fabin3.jl | 16 +- .../start_val/start_simple.jl | 34 +- src/frontend/finite_diff.jl | 35 ++ src/frontend/fit/fitmeasures/RMSEA.jl | 32 +- src/frontend/fit/fitmeasures/chi2.jl | 73 ++- src/frontend/fit/fitmeasures/dof.jl | 11 +- src/frontend/fit/fitmeasures/minus2ll.jl | 34 +- src/frontend/fit/standard_errors/hessian.jl | 25 +- src/frontend/pretty_printing.jl | 8 +- src/frontend/specification/Sem.jl | 439 +++++++++++++----- src/implied/RAM/generic.jl | 13 +- src/implied/RAM/symbolic.jl | 12 +- src/implied/empty.jl | 6 +- src/loss/ML/FIML.jl | 86 ++-- src/loss/ML/ML.jl | 90 ++-- src/loss/WLS/WLS.jl | 94 ++-- src/loss/abstract.jl | 42 ++ src/loss/constant/constant.jl | 28 +- src/loss/regularization/ridge.jl | 9 +- src/objective_gradient_hessian.jl | 251 +++++----- src/optimizer/abstract.jl | 14 +- src/types.jl | 243 +++------- test/examples/multigroup/build_models.jl | 222 +++------ test/examples/political_democracy/by_parts.jl | 284 +++++------ .../political_democracy/constraints.jl | 4 +- .../political_democracy/constructor.jl | 169 +++---- .../recover_parameters_twofact.jl | 24 +- test/unit_tests/model.jl | 5 +- 31 files changed, 1203 insertions(+), 1221 deletions(-) create mode 100644 src/additional_functions/start_val/common.jl create mode 100644 src/frontend/finite_diff.jl create mode 100644 src/loss/abstract.jl diff --git a/src/StructuralEquationModels.jl 
b/src/StructuralEquationModels.jl index 19dd6f43..d98e7925 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -44,6 +44,7 @@ include("frontend/specification/EnsembleParameterTable.jl") include("frontend/specification/StenoGraphs.jl") include("frontend/fit/summary.jl") include("frontend/StatsAPI.jl") +include("frontend/finite_diff.jl") # pretty printing include("frontend/pretty_printing.jl") # observed @@ -53,26 +54,28 @@ include("observed/covariance.jl") include("observed/missing_pattern.jl") include("observed/missing.jl") include("observed/EM.jl") -# constructor -include("frontend/specification/Sem.jl") -include("frontend/specification/documentation.jl") # implied include("implied/abstract.jl") include("implied/RAM/symbolic.jl") include("implied/RAM/generic.jl") include("implied/empty.jl") # loss +include("loss/abstract.jl") include("loss/ML/ML.jl") include("loss/ML/FIML.jl") include("loss/regularization/ridge.jl") include("loss/WLS/WLS.jl") include("loss/constant/constant.jl") +# constructor +include("frontend/specification/Sem.jl") +include("frontend/specification/documentation.jl") # optimizer include("optimizer/abstract.jl") include("optimizer/Empty.jl") include("optimizer/optim.jl") # helper functions include("additional_functions/helper.jl") +include("additional_functions/start_val/common.jl") include("additional_functions/start_val/start_fabin3.jl") include("additional_functions/start_val/start_simple.jl") include("additional_functions/artifacts.jl") @@ -94,14 +97,11 @@ include("frontend/fit/standard_errors/z_test.jl") include("frontend/fit/standard_errors/confidence_intervals.jl") export AbstractSem, - AbstractSemSingle, - AbstractSemCollection, coef, coefnames, coeftable, Sem, SemFiniteDiff, - SemEnsemble, MeanStruct, NoMeanStruct, HasMeanStruct, @@ -116,8 +116,8 @@ export AbstractSem, start_val, start_fabin3, start_simple, + AbstractLoss, SemLoss, - SemLossFunction, SemML, SemFIML, em_mvn, @@ -125,6 +125,9 @@ export 
AbstractSem, SemConstant, SemWLS, loss, + nsem_terms, + sem_terms, + sem_term, SemOptimizer, optimizer, optimizer_engine, diff --git a/src/additional_functions/simulation.jl b/src/additional_functions/simulation.jl index 4839bc27..6d694c97 100644 --- a/src/additional_functions/simulation.jl +++ b/src/additional_functions/simulation.jl @@ -43,36 +43,42 @@ function update_observed end # change observed (data) without reconstructing the whole model ############################################################################################ +# don't change non-SEM terms +replace_observed(loss::AbstractLoss; kwargs...) = loss + # use the same observed type as before -replace_observed(model::AbstractSemSingle; kwargs...) = - replace_observed(model, typeof(observed(model)).name.wrapper; kwargs...) +replace_observed(loss::SemLoss; kwargs...) = + replace_observed(loss, typeof(SEM.observed(loss)).name.wrapper; kwargs...) + +# construct a new observed type +replace_observed(loss::SemLoss, observed_type; kwargs...) = + replace_observed(loss, observed_type(; kwargs...); kwargs...) -function replace_observed(model::AbstractSemSingle, observed_type; kwargs...) - new_observed = observed_type(; kwargs...) +function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...) kwargs = Dict{Symbol, Any}(kwargs...) + old_observed = SEM.observed(loss) + implied = SEM.implied(loss) # get field types kwargs[:observed_type] = typeof(new_observed) - kwargs[:old_observed_type] = typeof(model.observed) - kwargs[:implied_type] = typeof(model.implied) - kwargs[:loss_types] = [typeof(lossfun) for lossfun in model.loss.functions] + kwargs[:old_observed_type] = typeof(old_observed) # update implied - new_implied = update_observed(model.implied, new_observed; kwargs...) + new_implied = update_observed(implied, new_observed; kwargs...) 
kwargs[:implied] = new_implied + kwargs[:implied_type] = typeof(new_implied) kwargs[:nparams] = nparams(new_implied) # update loss - new_loss = update_observed(model.loss, new_observed; kwargs...) - - return Sem(new_observed, new_implied, new_loss) + return update_observed(loss, new_observed; kwargs...) end -function update_observed(loss::SemLoss, new_observed; kwargs...) - new_functions = Tuple( - update_observed(lossfun, new_observed; kwargs...) for lossfun in loss.functions - ) - return SemLoss(new_functions, loss.weights) +replace_observed(loss::LossTerm; kwargs...) = + LossTerm(replace_observed(loss.loss; kwargs...), loss.id, loss.weight) + +function replace_observed(sem::Sem; kwargs...) + updated_terms = Tuple(replace_observed(term; kwargs...) for term in loss_terms(sem)) + return Sem(updated_terms...) end function replace_observed( @@ -111,39 +117,38 @@ end # simulate data ############################################################################################ """ - (1) rand(model::AbstractSemSingle, params, n) - - (2) rand(model::AbstractSemSingle, n) + rand(sem::Union{Sem, SemLoss, SemImplied}, [params], n) -Sample normally distributed data from the model-implied covariance matrix and mean vector. +Sample from the multivariate normal distribution implied by the SEM model. # Arguments -- `model::AbstractSemSingle`: model to simulate from. -- `params`: parameter values to simulate from. -- `n::Integer`: Number of samples. +- `sem`: SEM model to use. Ensemble models with multiple SEM terms are not supported. +- `params`: optional SEM model parameters to simulate from, otherwise uses the + current state of implied covariances and means. +- `n::Integer`: Number of samples to draw. 
# Examples ```julia rand(model, start_simple(model), 100) ``` """ -function Distributions.rand( - model::AbstractSemSingle{O, I, L}, - params, - n::Integer, -) where {O, I <: Union{RAM, RAMSymbolic}, L} - update!(EvaluationTargets{true, false, false}(), model.implied, model, params) - return rand(model, n) -end - -function Distributions.rand( - model::AbstractSemSingle{O, I, L}, - n::Integer, -) where {O, I <: Union{RAM, RAMSymbolic}, L} - if MeanStruct(model.implied) === NoMeanStruct - data = permutedims(rand(MvNormal(Symmetric(model.implied.Σ)), n)) - elseif MeanStruct(model.implied) === HasMeanStruct - data = permutedims(rand(MvNormal(model.implied.μ, Symmetric(model.implied.Σ)), n)) +function Distributions.rand(implied::SemImplied, params, n::Integer) + if !isnothing(params) + # update the implied covariances with the new model params + update!(EvaluationTargets{true, false, false}(), implied, params) + end + Σ = Symmetric(implied.Σ) + if MeanStruct(implied) === NoMeanStruct + return permutedims(rand(MvNormal(Σ), n)) + elseif MeanStruct(implied) === HasMeanStruct + return permutedims(rand(MvNormal(implied.μ, Σ), n)) end - return data end + +Distributions.rand(loss::SemLoss, params, n::Integer) = rand(SEM.implied(loss), params, n) + +Distributions.rand(model::Sem, params, n::Integer) = rand(sem_term(model), params, n) + +# rand() overloads without SEM params +Distributions.rand(implied::Union{SemImplied, SemLoss, Sem}, n::Integer) = + Distributions.rand(implied, nothing, n) diff --git a/src/additional_functions/start_val/common.jl b/src/additional_functions/start_val/common.jl new file mode 100644 index 00000000..92c85d6f --- /dev/null +++ b/src/additional_functions/start_val/common.jl @@ -0,0 +1,17 @@ + +# start values for SEM Models (including ensembles) +function start_values(f, model::AbstractSem; kwargs...) 
+ start_vals = fill(0.0, nparams(model)) + + # initialize parameters using the SEM loss terms + # (first SEM loss term that sets given parameter to nonzero value) + for term in loss_terms(model) + issemloss(term) || continue + term_start_vals = f(loss(term); kwargs...) + for (i, val) in enumerate(term_start_vals) + iszero(val) || (start_vals[i] = val) + end + end + + return start_vals +end diff --git a/src/additional_functions/start_val/start_fabin3.jl b/src/additional_functions/start_val/start_fabin3.jl index 53d3442a..54337028 100644 --- a/src/additional_functions/start_val/start_fabin3.jl +++ b/src/additional_functions/start_val/start_fabin3.jl @@ -7,12 +7,17 @@ Not available for ensemble models. function start_fabin3 end # splice model and loss functions -function start_fabin3(model::AbstractSemSingle; kwargs...) - return start_fabin3(model.observed, model.implied, model.loss.functions..., kwargs...) +function start_fabin3(model::SemLoss; kwargs...) + return start_fabin3(model.observed, model.implied; kwargs...) end -function start_fabin3(observed::SemObserved, implied::SemImplied, args...; kwargs...) - return start_fabin3(implied.ram_matrices, obs_cov(observed), obs_mean(observed)) +function start_fabin3(observed::SemObserved, implied::SemImplied; kwargs...) + return start_fabin3( + implied.ram_matrices, + obs_cov(observed), + # ignore observed means if no meansturcture + !isnothing(implied.ram_matrices.M) ? obs_mean(observed) : nothing, + ) end function start_fabin3( @@ -161,3 +166,6 @@ end function is_in_Λ(ind_vec, F_ind) return any(ind -> !(ind[2] ∈ F_ind) && (ind[1] ∈ F_ind), ind_vec) end + +# ensembles +start_fabin3(model::AbstractSem; kwargs...) = start_values(start_fabin3, model; kwargs...) 
diff --git a/src/additional_functions/start_val/start_simple.jl b/src/additional_functions/start_val/start_simple.jl index 4fbc8719..afdbf92e 100644 --- a/src/additional_functions/start_val/start_simple.jl +++ b/src/additional_functions/start_val/start_simple.jl @@ -15,34 +15,11 @@ Return a vector of simple starting values. """ function start_simple end -# Single Models ---------------------------------------------------------------------------- -function start_simple(model::AbstractSemSingle; kwargs...) - return start_simple(model.observed, model.implied, model.loss.functions...; kwargs...) -end - -function start_simple(observed, implied, args...; kwargs...) - return start_simple(implied.ram_matrices; kwargs...) -end - -# Ensemble Models -------------------------------------------------------------------------- -function start_simple(model::SemEnsemble; kwargs...) - start_vals = [] - - for sem in model.sems - push!(start_vals, start_simple(sem; kwargs...)) - end - - has_start_val = [.!iszero.(start_val) for start_val in start_vals] +start_simple(model::SemLoss; kwargs...) = + start_simple(observed(model), implied(model); kwargs...) - start_val = similar(start_vals[1]) - start_val .= 0.0 - - for (j, indices) in enumerate(has_start_val) - start_val[indices] .= start_vals[j][indices] - end - - return start_val -end +start_simple(observed::SemObserved, implied::SemImplied; kwargs...) = + start_simple(implied.ram_matrices; kwargs...) function start_simple( ram_matrices::RAMMatrices; @@ -103,3 +80,6 @@ function start_simple( end return start_val end + +# multigroup models +start_simple(model::AbstractSem; kwargs...) = start_values(start_simple, model; kwargs...) 
diff --git a/src/frontend/finite_diff.jl b/src/frontend/finite_diff.jl new file mode 100644 index 00000000..ee0a9bf9 --- /dev/null +++ b/src/frontend/finite_diff.jl @@ -0,0 +1,35 @@ +_unwrap(wrapper::SemFiniteDiff) = wrapper.model +params(wrapper::SemFiniteDiff) = params(wrapper.model) +loss_terms(wrapper::SemFiniteDiff) = loss_terms(wrapper.model) + +FiniteDiffLossWrappers = Union{LossFiniteDiff, SemLossFiniteDiff} + +_unwrap(term::AbstractLoss) = term +_unwrap(wrapper::FiniteDiffLossWrappers) = wrapper.loss +implied(wrapper::FiniteDiffLossWrappers) = implied(_unwrap(wrapper)) +observed(wrapper::FiniteDiffLossWrappers) = observed(_unwrap(wrapper)) + +FiniteDiffWrapper(model::AbstractSem) = SemFiniteDiff(model) +FiniteDiffWrapper(loss::AbstractLoss) = LossFiniteDiff(loss) +FiniteDiffWrapper(loss::SemLoss) = SemLossFiniteDiff(loss) + +function evaluate!( + objective, + gradient, + hessian, + sem::Union{SemFiniteDiff, FiniteDiffLossWrappers}, + params, +) + wrapped = _unwrap(sem) + obj(p) = _evaluate!( + objective_zero(objective, gradient, hessian), + nothing, + nothing, + wrapped, + p, + ) + isnothing(gradient) || FiniteDiff.finite_difference_gradient!(gradient, obj, params) + isnothing(hessian) || FiniteDiff.finite_difference_hessian!(hessian, obj, params) + # FIXME if objective is not calculated, SemLoss implied states may not correspond to params + return !isnothing(objective) ? obj(params) : nothing +end diff --git a/src/frontend/fit/fitmeasures/RMSEA.jl b/src/frontend/fit/fitmeasures/RMSEA.jl index 8539896f..7406b74c 100644 --- a/src/frontend/fit/fitmeasures/RMSEA.jl +++ b/src/frontend/fit/fitmeasures/RMSEA.jl @@ -19,28 +19,18 @@ for the SEM model. For multigroup models, the correction proposed by J.H. Steiger is applied (see [Steiger, J. H. (1998). *A note on multiple sample extensions of the RMSEA fit index*](https://doi.org/10.1080/10705519809540115)). 
""" -function RMSEA end - RMSEA(fit::SemFit) = RMSEA(fit, fit.model) -function RMSEA(fit::SemFit, model::AbstractSemSingle) - check_single_lossfun(model; throw_error = true) - return RMSEA(dof(fit), χ²(fit), nsamples(fit)+rmsea_correction(model.loss.functions[1])) -end - -function RMSEA(fit::SemFit, model::SemEnsemble) - check_single_lossfun(model; throw_error = true) - n = nsamples(fit)+model.n*rmsea_correction(model.sems[1].loss.functions[1]) - return sqrt(length(model.sems)) * RMSEA(dof(fit), χ²(fit), n) -end - -function RMSEA(dof, chi2, N⁻) - rmsea = (chi2 - dof) / (N⁻ * dof) - rmsea = rmsea > 0 ? rmsea : 0 - return sqrt(rmsea) +# scaling corrections +RMSEA_corr_scale(::Type{<:SemFIML}) = 0 +RMSEA_corr_scale(::Type{<:SemML}) = -1 +RMSEA_corr_scale(::Type{<:SemWLS}) = -1 + +function RMSEA(fit::SemFit, model::AbstractSem) + term_type = check_single_lossfun(model; throw_error = true) + n = nsamples(fit) + nsem_terms(model) * RMSEA_corr_scale(term_type) + sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n) end -# scaling corrections -rmsea_correction(::SemFIML) = 0 -rmsea_correction(::SemML) = -1 -rmsea_correction(::SemWLS) = -1 +RMSEA(dof::Number, chi2::Number, nsamples::Number) = + sqrt(max((chi2 - dof) / (nsamples * dof), 0.0)) diff --git a/src/frontend/fit/fitmeasures/chi2.jl b/src/frontend/fit/fitmeasures/chi2.jl index 8ce5f079..22d6c2e2 100644 --- a/src/frontend/fit/fitmeasures/chi2.jl +++ b/src/frontend/fit/fitmeasures/chi2.jl @@ -12,57 +12,42 @@ with the *observed* covariance matrix. 
""" χ²(fit::SemFit) = χ²(fit, fit.model) -############################################################################################ -# Single Models -############################################################################################ +function χ²(fit::SemFit, model::AbstractSem) + terms = sem_terms(model) + isempty(terms) && return 0.0 + + term1 = _unwrap(loss(terms[1])) + L = typeof(term1).name + + # check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams + for (i, term) in enumerate(terms) + lossterm = _unwrap(loss(term)) + @assert lossterm isa SemLoss + if typeof(_unwrap(lossterm)).name != L + @error "SemLoss term #$i is $(typeof(_unwrap(lossterm)).name), expected $L. Heterogeneous loss functions are not supported" + end + end -function χ²(fit::SemFit, model::AbstractSemSingle) - check_single_lossfun(model; throw_error = true) - return χ²(model.loss.functions[1], fit::SemFit, model::AbstractSemSingle) + return χ²(typeof(term1), fit, model) end -χ²(::SemML, fit::SemFit, model::AbstractSemSingle) = - (nsamples(fit) - 1) * - (fit.minimum - logdet(obs_cov(observed(model))) - nobserved_vars(model)) - # bollen, p. 
115, only correct for GLS weight matrix -χ²(::SemWLS, fit::SemFit, model::AbstractSemSingle) = - (nsamples(fit) - 1) * fit.minimum - -# FIML -function χ²(::SemFIML, fit::SemFit, model::AbstractSemSingle) - ll_H0 = minus2ll(fit) - ll_H1 = minus2ll(observed(model)) - return ll_H0 - ll_H1 -end - -############################################################################################ -# Collections -############################################################################################ - -function χ²(fit::SemFit, model::SemEnsemble) - check_single_lossfun(model; throw_error = true) - lossfun = model.sems[1].loss.functions[1] - return χ²(lossfun, fit, model) -end - -function χ²(::SemWLS, fit::SemFit, models::SemEnsemble) - return (nsamples(models) - models.n) * fit.minimum -end - -function χ²(::SemML, fit::SemFit, models::SemEnsemble) - F = 0 - for model in models.sems - Fᵢ = objective(model, fit.solution) - Fᵢ -= logdet(obs_cov(observed(model))) + nobserved_vars(model) - Fᵢ *= nsamples(model) - 1 - F += Fᵢ +χ²(::Type{<:SemWLS}, fit::SemFit, model::AbstractSem) = (nsamples(model) - 1) * fit.minimum + +function χ²(::Type{<:SemML}, fit::SemFit, model::AbstractSem) + G = sum(loss_terms(model)) do term + if issemloss(term) + data = observed(term) + something(weight(term), 1.0) * (logdet(obs_cov(data)) + nobserved_vars(data)) + else + return 0.0 + end end - return F + return (nsamples(model) - 1) * (fit.minimum - G) end -function χ²(::SemFIML, fit::SemFit, models::SemEnsemble) +function χ²(::Type{<:SemFIML}, fit::SemFit, model::AbstractSem) ll_H0 = minus2ll(fit) - ll_H1 = sum(minus2ll ∘ observed, models.sems) + ll_H1 = sum(minus2ll ∘ observed, sem_terms(model)) return ll_H0 - ll_H1 end diff --git a/src/frontend/fit/fitmeasures/dof.jl b/src/frontend/fit/fitmeasures/dof.jl index 0e051d02..49b7febf 100644 --- a/src/frontend/fit/fitmeasures/dof.jl +++ b/src/frontend/fit/fitmeasures/dof.jl @@ -18,13 +18,16 @@ dof(fit::SemFit) = dof(fit.model) dof(model::AbstractSem) = 
n_dp(model) - nparams(model) -function n_dp(model::AbstractSemSingle) - nvars = nobserved_vars(model) +# length of Σ and μ (if present) +function n_dp(implied::SemImplied) + nvars = nobserved_vars(implied) ndp = 0.5(nvars^2 + nvars) - if !isnothing(model.implied.μ) + if !isnothing(implied.μ) ndp += nvars end return ndp end -n_dp(model::SemEnsemble) = sum(n_dp.(model.sems)) +n_dp(term::SemLoss) = n_dp(implied(term)) + +n_dp(model::AbstractSem) = sum(n_dp ∘ loss, sem_terms(model)) diff --git a/src/frontend/fit/fitmeasures/minus2ll.jl b/src/frontend/fit/fitmeasures/minus2ll.jl index c6a954ef..3b353f5c 100644 --- a/src/frontend/fit/fitmeasures/minus2ll.jl +++ b/src/frontend/fit/fitmeasures/minus2ll.jl @@ -6,31 +6,27 @@ Calculate the *-2⋅log(likelihood(fit))*. # See also [`fit_measures`](@ref) """ -minus2ll(fit::SemFit) = minus2ll(fit, fit.model) +minus2ll(fit::SemFit) = minus2ll(fit.model, fit) ############################################################################################ -# Single Models +# Single SEM Terms Models ############################################################################################ -function minus2ll(fit::SemFit, model::AbstractSemSingle) - check_single_lossfun(model; throw_error = true) - F = objective(model, fit.solution) - return minus2ll(model.loss.functions[1], F, model) +function minus2ll(term::SemLoss, fit::SemFit) + minimum = objective(term, fit.solution) + return minus2ll(term, minimum) end -# SemML ------------------------------------------------------------------------------------ -function minus2ll(::SemML, F, model::AbstractSemSingle) - return nsamples(model) * (F + log(2π) * nobserved_vars(model)) -end +minus2ll(term::SemML, minimum::Number) = + nsamples(term) * (minimum + log(2π) * nobserved_vars(term)) -# WLS -------------------------------------------------------------------------------------- -minus2ll(::SemWLS, F, ::AbstractSemSingle) = missing +minus2ll(term::SemWLS, minimum::Number) = missing # compute 
likelihood for missing data - H0 ------------------------------------------------- -# -2ll = (∑ log(2π)*(nᵢ*mᵢ)) + F*n -function minus2ll(::SemFIML, F, model::AbstractSemSingle) - obs = observed(model)::SemObservedMissing - F *= nsamples(obs) +# -2ll = (∑ log(2π)*(nᵢ + mᵢ)) + F*n +function minus2ll(term::SemFIML, minimum::Number) + obs = observed(term)::SemObservedMissing + F = minimum * nsamples(obs) F += log(2π) * sum(pat -> nsamples(pat) * nmeasured_vars(pat), obs.patterns) return F end @@ -62,10 +58,10 @@ function minus2ll(observed::SemObservedMissing) end ############################################################################################ -# Collection +# Multi-group ############################################################################################ -function minus2ll(fit::SemFit, model::SemEnsemble) +function minus2ll(model::AbstractSem, fit::SemFit) check_single_lossfun(model; throw_error = true) - return sum(Base.Fix1(minus2ll, fit), model.sems) + sum(Base.Fix2(minus2ll, fit) ∘ _unwrap ∘ loss, sem_terms(model)) end diff --git a/src/frontend/fit/standard_errors/hessian.jl b/src/frontend/fit/standard_errors/hessian.jl index 6ae53407..80b96d33 100644 --- a/src/frontend/fit/standard_errors/hessian.jl +++ b/src/frontend/fit/standard_errors/hessian.jl @@ -35,20 +35,21 @@ function se_hessian(fit::SemFit; method = :finitediff) end # Addition functions ------------------------------------------------------------- -function H_scaling(model::AbstractSemSingle) - if length(model.loss.functions) > 1 - @warn "Hessian scaling for multiple loss functions is not implemented yet" - end - return H_scaling(model.loss.functions[1], model) -end - -H_scaling(lossfun::SemML, model::AbstractSemSingle) = 2 / (nsamples(model) - 1) +H_scaling(loss::SemML) = 2 / (nsamples(loss) - 1) -function H_scaling(lossfun::SemWLS, model::AbstractSemSingle) +function H_scaling(loss::SemWLS) @warn "Standard errors for WLS are only correct if a GLS weight matrix (the default) is 
used." - return 2 / (nsamples(model) - 1) + return 2 / (nsamples(loss) - 1) end -H_scaling(lossfun::SemFIML, model::AbstractSemSingle) = 2 / nsamples(model) +H_scaling(loss::SemFIML) = 2 / nsamples(loss) -H_scaling(model::SemEnsemble) = 2 / nsamples(model) +function H_scaling(model::AbstractSem) + semterms = SEM.sem_terms(model) + if length(semterms) > 1 + #@warn "Hessian scaling for multiple loss functions is not implemented yet" + return 2 / nsamples(model) + else + return length(semterms) >= 1 ? H_scaling(loss(semterms[1])) : 1.0 + end +end diff --git a/src/frontend/pretty_printing.jl b/src/frontend/pretty_printing.jl index 2fa970f2..7b6975f6 100644 --- a/src/frontend/pretty_printing.jl +++ b/src/frontend/pretty_printing.jl @@ -32,9 +32,11 @@ end # Loss Function, Implied, Observed, Optimizer ############################################################## -function Base.show(io::IO, struct_inst::SemLossFunction) - print_type_name(io, struct_inst) - print_field_types(io, struct_inst) +function Base.show(io::IO, sem::SemLoss) + println(io, "Structural Equation Model Loss ($(nameof(typeof(sem))))") + println(io, "- Observed: $(nameof(typeof(observed(sem)))) ($(nsamples(sem)) samples)") + println(io, "- Implied: $(nameof(typeof(implied(sem)))) ($(nparams(sem)) parameters)") + println(io, "- Variables: $(nobserved_vars(sem)) observed, $(nlatent_vars(sem)) latent") end function Base.show(io::IO, struct_inst::SemImplied) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index a47bad4b..684cfa62 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -1,7 +1,167 @@ +losstype(::Type{<:LossTerm{L, W, I}}) where {L, W, I} = L +losstype(term::LossTerm) = losstype(typeof(term)) +loss(term::LossTerm) = term.loss +weight(term::LossTerm) = term.weight +id(term::LossTerm) = term.id + +""" + issemloss(term::LossTerm) -> Bool + +Check if a SEM model term is a SEM loss function ([`SemLoss`](@ref)). 
+""" +issemloss(term::LossTerm) = isa(loss(term), SemLoss) + +for f in ( + :implied, + :observed, + :nsamples, + :observed_vars, + :nobserved_vars, + :vars, + :nvars, + :latent_vars, + :nlatent_vars, + :params, + :nparams, +) + @eval $f(term::LossTerm) = $f(loss(term)) +end + +function Base.show(io::IO, term::LossTerm) + if !isnothing(id(term)) + print(io, ":$(id(term)): ") + end + print(io, nameof(losstype(term))) + if issemloss(term) + print( + io, + " ($(nsamples(term)) samples, $(nobserved_vars(term)) observed, $(nlatent_vars(term)) latent variables)", + ) + end + if !isnothing(weight(term)) + print(io, " w=$(round(weight(term), digits=3))") + else + print(io, " w=1") + end +end + ############################################################################################ # constructor for Sem types ############################################################################################ +function multigroup_weights(models, n) + nsamples_total = sum(nsamples, models) + uniform_lossfun = check_single_lossfun(models...; throw_error = false) + if !uniform_lossfun + @info """ + Your ensemble model contains heterogeneous loss functions. + Default weights of (#samples per group/#total samples) will be used + """ + return [(nsamples(model)) / (nsamples_total) for model in models] + end + lossfun = models[1].loss.functions[1] + if !applicable(mg_correction, lossfun) + @info """ + We don't know how to choose group weights for the specified loss function. 
+ Default weights of (#samples per group/#total samples) will be used + """ + return [(nsamples(model)) / (nsamples_total) for model in models] + end + c = mg_correction(lossfun) + return [(nsamples(model)+c) / (nsamples_total+n*c) for model in models] +end + +function Sem( + loss_terms...; + params::Union{Vector{Symbol}, Nothing} = nothing, + default_sem_weights = :nsamples, +) + default_sem_weights ∈ [:nsamples, :uniform, :one] || + throw(ArgumentError("Unsupported default_sem_weights=:$default_sem_weights")) + # assemble a list of weighted losses and check params equality + terms = Vector{LossTerm}() + params = !isnothing(params) ? copy(params) : params + has_sem_weights = false + nsems = 0 + for inp_term in loss_terms + if inp_term isa AbstractLoss # term + term = inp_term + term_w = nothing + term_id = nothing + elseif inp_term isa Pair + if inp_term[1] isa AbstractLoss # term => weight + term, term_w = inp_term + term_id = nothing + elseif inp_term[2] isa AbstractLoss # id => term + term_id, term = inp_term + term_w = nothing + elseif inp_term[2] isa Pair # id => term => weight + term_id, (term, term_w) = inp_term + isa(term, AbstractLoss) || throw( + ArgumentError( + "AbstractLoss expected as a second argument of a loss term double pair (id => loss => weight), $(nameof(typeof(term))) found", + ), + ) + end + elseif inp_term isa LossTerm + term_id = id(inp_term) + term = loss(inp_term) + term_w = weight(inp_term) + else + "[id =>] AbstractLoss [=> weight] expected as a loss term, $(nameof(typeof(inp_term))) found" |> + ArgumentError |> + throw + end + + if term isa SemLoss + nsems += 1 + has_sem_weights |= !isnothing(term_w) + # check integrity + if isnothing(params) + params = SEM.params(term) + elseif params != SEM.params(term) + # FIXME the suggestion might no longer be relevant, since ParTable also stores params order + """ + The parameters of your SEM models do not match. + Maybe you tried to specify models of an ensemble via ParameterTables. 
+ In that case, you may use RAMMatrices instead. + """ |> error + end + check_observed_vars(term) + elseif !(term isa AbstractLoss) + "AbstractLoss term expected at $(length(terms)+1) position, $(nameof(typeof(term))) found" |> + ArgumentError |> + throw + end + push!(terms, LossTerm(term, term_id, term_w)) + end + isnothing(params) && throw(ErrorException("No SEM models provided.")) + + if !has_sem_weights && nsems > 1 + # set the weights of SEMs in the ensemble + if default_sem_weights == :nsamples + # weight SEM by the number of samples + nsamples_total = sum(nsamples(term) for term in terms if issemloss(term)) + for (i, term) in enumerate(terms) + if issemloss(term) + terms[i] = + LossTerm(loss(term), id(term), nsamples(term) / nsamples_total) + end + end + elseif default_sem_weights == :uniform # uniform weights + for (i, term) in enumerate(terms) + if issemloss(term) + terms[i] = LossTerm(loss(term), id(term), 1 / nsems) + end + end + elseif default_sem_weights == :one # do nothing + end + end + + terms_tuple = Tuple(terms) + return Sem{typeof(terms_tuple)}(terms_tuple, params) +end + function Sem(; specification = ParameterTable, observed::O = SemObservedData, @@ -13,99 +173,127 @@ function Sem(; set_field_type_kwargs!(kwdict, observed, implied, loss, O, I) - observed, implied, loss = get_fields!(kwdict, specification, observed, implied, loss) - - sem = Sem(observed, implied, loss) + loss = get_fields!(kwdict, specification, observed, implied, loss) - return sem + return Sem(loss...) end +############################################################################################ +# functions +############################################################################################ + +params(model::AbstractSem) = model.params + """ - implied(model::AbstractSemSingle) -> SemImplied + loss_terms(model::AbstractSem) -Returns the [*implied*](@ref SemImplied) part of a model. +Returns a tuple of all [`LossTerm`](@ref) weighted terms in the SEM model. 
+ +See also [`sem_terms`](@ref), [`loss_term`](@ref). """ -implied(model::AbstractSemSingle) = model.implied +loss_terms(model::AbstractSem) = model.loss_terms +nloss_terms(model::AbstractSem) = length(loss_terms(model)) -nvars(model::AbstractSemSingle) = nvars(implied(model)) -nobserved_vars(model::AbstractSemSingle) = nobserved_vars(implied(model)) -nlatent_vars(model::AbstractSemSingle) = nlatent_vars(implied(model)) +""" + sem_terms(model::AbstractSem) -vars(model::AbstractSemSingle) = vars(implied(model)) -observed_vars(model::AbstractSemSingle) = observed_vars(implied(model)) -latent_vars(model::AbstractSemSingle) = latent_vars(implied(model)) +Returns a tuple of all weighted SEM terms in the SEM model. -param_labels(model::AbstractSemSingle) = param_labels(implied(model)) -nparams(model::AbstractSemSingle) = nparams(implied(model)) +In comparison to [`loss_terms`](@ref) that returns all model terms, including e.g. +regularization terms, this function returns only the [`SemLoss`] terms. +See also [`loss_terms`](@ref), [`sem_term`](@ref). """ - observed(model::AbstractSemSingle) -> SemObserved +sem_terms(model::AbstractSem) = Tuple(term for term in loss_terms(model) if issemloss(term)) +nsem_terms(model::AbstractSem) = sum(issemloss, loss_terms(model)) + +nsamples(model::AbstractSem) = + sum(term -> issemloss(term) ? nsamples(term) : 0, loss_terms(model)) -Returns the [*observed*](@ref SemObserved) part of a model. """ -observed(model::AbstractSemSingle) = model.observed + loss_term(model::AbstractSem, id::Any) -> AbstractLoss -nsamples(model::AbstractSemSingle) = nsamples(observed(model)) +Returns the loss term with the specified `id` from the `model`. +Throws an error if the model has no term with the specified `id`. +See also [`loss_terms`](@ref). 
""" - loss(model::AbstractSemSingle) -> SemLoss +function loss_term(model::AbstractSem, id::Any) + for term in loss_terms(model) + if SEM.id(term) == id + return loss(term) + end + end + error("No loss term with id=$id found") +end -Returns the [*loss*](@ref SemLoss) function of a model. """ -loss(model::AbstractSemSingle) = model.loss - -# sum of samples in all sub-models -nsamples(ensemble::SemEnsemble) = sum(nsamples, ensemble.sems) - -function SemFiniteDiff(; - specification = ParameterTable, - observed::O = SemObservedData, - implied::I = RAM, - loss::L = SemML, - kwargs..., -) where {O, I, L} - kwdict = Dict{Symbol, Any}(kwargs...) + sem_term(model::AbstractSem, [id]) -> SemLoss - set_field_type_kwargs!(kwdict, observed, implied, loss, O, I) +Returns the SEM loss term with the specified `id` from the `model`. +Throws an error if the model has no term with the specified `id` or +if it is not of a [`SemLoss`](@ref) type. - observed, implied, loss = get_fields!(kwdict, specification, observed, implied, loss) +If no `id` is specified and the model contains only one SEM term, the term is returned. +Throws an error if the model contains multiple SEM terms. - sem = SemFiniteDiff(observed, implied, loss) +See also [`loss_term`](@ref), [`sem_terms`](@ref). 
+""" +function sem_term(model::AbstractSem, id::Any) + term = loss_term(model, id) + issemloss(term) || error("Loss term with id=$id ($(typeof(term))) is not a SEM term") + return term +end - return sem +function sem_term(model::AbstractSem, _::Nothing = nothing) + if nsem_terms(model) != 1 + error( + "Model contains $(nsem_terms(model)) SEM terms, you have to specify a specific term", + ) + end + for term in loss_terms(model) + issemloss(term) && return loss(term) + end + error("Unreachable reached") end -############################################################################################ -# functions -############################################################################################ +# wrappers arounds a single SemLoss term +observed(model::AbstractSem, id::Nothing = nothing) = observed(sem_term(model, id)) +implied(model::AbstractSem, id::Nothing = nothing) = implied(sem_term(model, id)) +vars(model::AbstractSem, id::Nothing = nothing) = vars(implied(model, id)) +observed_vars(model::AbstractSem, id::Nothing = nothing) = observed_vars(implied(model, id)) +latent_vars(model::AbstractSem, id::Nothing = nothing) = latent_vars(implied(model, id)) function set_field_type_kwargs!(kwargs, observed, implied, loss, O, I) kwargs[:observed_type] = O <: Type ? observed : typeof(observed) kwargs[:implied_type] = I <: Type ? implied : typeof(implied) if loss isa SemLoss - kwargs[:loss_types] = [ - lossfun isa SemLossFunction ? typeof(lossfun) : lossfun for - lossfun in loss.functions - ] - elseif applicable(iterate, loss) kwargs[:loss_types] = - [lossfun isa SemLossFunction ? typeof(lossfun) : lossfun for lossfun in loss] + [aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss.functions] + elseif applicable(iterate, loss) + kwargs[:loss_types] = [aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss] else - kwargs[:loss_types] = [loss isa SemLossFunction ? typeof(loss) : loss] + kwargs[:loss_types] = [loss isa SemLoss ? 
typeof(loss) : loss] end end # construct Sem fields -function get_fields!(kwargs, specification, observed, implied, loss) - if !isa(specification, SemSpecification) - specification = specification(; kwargs...) +function get_fields!(kwargs, spec, observed, implied, loss) + if !isa(spec, SemSpecification) + spec = spec(; kwargs...) end # observed if !isa(observed, SemObserved) - observed = observed(; specification, kwargs...) + observed = if spec isa EnsembleParameterTable + Dict( + term_id => observed(; specification = term_spec, kwargs...) for + (term_id, term_spec) in pairs(spec.tables) + ) + else + observed(; specification = spec, kwargs...) + end end - kwargs[:observed] = observed # implied if !isa(implied, SemImplied) @@ -117,95 +305,98 @@ function get_fields!(kwargs, specification, observed, implied, loss) else implied_kwargs = kwargs end - implied = implied(specification; implied_kwargs...) + implied = if spec isa EnsembleParameterTable + Dict( + term_id => implied(term_spec; implied_kwargs...) for + (term_id, term_spec) in pairs(spec.tables) + ) + else + implied(spec; implied_kwargs...) + end end - kwargs[:implied] = implied - kwargs[:nparams] = nparams(implied) - # loss - loss = get_SemLoss(loss; specification, kwargs...) - kwargs[:loss] = loss + loss_kwargs = copy(kwargs) + loss_kwargs[:nparams] = nparams(spec) + loss = build_SemTerms(loss, observed, implied; loss_kwargs...) - return observed, implied, loss + return loss end # construct loss field -function get_SemLoss(loss; kwargs...) +function build_SemTerms(loss, observed, implied; kwargs...) + function build_SemLoss(aloss, observed, implied) + if loss isa AbstractLoss + return loss + elseif aloss <: SemLoss{O, I} where {O, I} + return aloss(observed, implied; kwargs...) + else + return aloss(; kwargs...) 
+ end + end + if loss isa SemLoss - nothing + return loss elseif applicable(iterate, loss) - loss_out = [] - for lossfun in loss - if isa(lossfun, SemLossFunction) - push!(loss_out, lossfun) - else - lossfun = lossfun(; kwargs...) - push!(loss_out, lossfun) - end - end - loss = SemLoss(loss_out...; kwargs...) + return [build_SemLoss(aloss, observed, implied) for aloss in loss] else - if !isa(loss, SemLossFunction) - loss = SemLoss(loss(; kwargs...); kwargs...) + if isa(observed, AbstractDict) && isa(implied, AbstractDict) + observed_ids = Set(keys(observed)) + implied_ids = Set(keys(implied)) + if observed_ids != implied_ids + """" + The term ids of the observed and the implied data do not match. + Observed term ids: $(observed_ids), implied term ids: $(implied_ids) + """ |> + ArgumentError |> + throw + end + loss_out = [ + begin + term_implied = implied[term_id] + if observed_vars(term_observed) != observed_vars(term_implied) + "observed_vars differ between the observed and the implied for the term $term_id" |> + ArgumentError |> + throw + end + LossTerm( + build_SemLoss(loss, term_observed, term_implied), + term_id, + nothing, + ) + end for (term_id, term_observed) in pairs(observed) + ] + return loss_out else - loss = SemLoss(loss; kwargs...) + if observed_vars(observed) != observed_vars(implied) + "observed_vars differ between the observed and the implied" |> + ArgumentError |> + throw + end + return (build_SemLoss(loss, observed, implied),) end end - return loss +end + +function update_observed(sem::Sem, new_observed; kwargs...) + new_terms = Tuple( + update_observed(lossterm.loss, new_observed; kwargs...) for + lossterm in loss_terms(sem) + ) + return Sem(new_terms...) end ############################################################## # pretty printing ############################################################## -function Base.show(io::IO, sem::Sem{O, I, L}) where {O, I, L} - lossfuntypes = @. 
string(nameof(typeof(sem.loss.functions))) - lossfuntypes = " " .* lossfuntypes .* ("\n") - print(io, "Structural Equation Model \n") - print(io, "- Loss Functions \n") - print(io, lossfuntypes...) - print(io, "- Fields \n") - print(io, " observed: $(nameof(O)) \n") - print(io, " implied: $(nameof(I)) \n") -end - -function Base.show(io::IO, sem::SemFiniteDiff{O, I, L}) where {O, I, L} - lossfuntypes = @. string(nameof(typeof(sem.loss.functions))) - lossfuntypes = " " .* lossfuntypes .* ("\n") - print(io, "Structural Equation Model : Finite Diff Approximation\n") - print(io, "- Loss Functions \n") - print(io, lossfuntypes...) - print(io, "- Fields \n") - print(io, " observed: $(nameof(O)) \n") - print(io, " implied: $(nameof(I)) \n") -end - -function Base.show(io::IO, loss::SemLoss) - lossfuntypes = @. string(nameof(typeof(loss.functions))) - lossfuntypes = " " .* lossfuntypes .* ("\n") - print(io, "SemLoss \n") - print(io, "- Loss Functions \n") - print(io, lossfuntypes...) - print(io, "- Weights \n") - for weight in loss.weights - if isnothing(weight.w) - print(io, " one \n") - else - print(io, "$(round.(weight.w, digits = 2)) \n") - end - end -end - -function Base.show(io::IO, models::SemEnsemble) - print(io, "SemEnsemble \n") - print(io, "- Number of Models: $(models.n) \n") - print(io, "- Weights: $(round.(models.weights, digits = 2)) \n") - - print(io, "\n", "Models: \n") - print(io, "===============================================", "\n") - for (model, i) in zip(models.sems, 1:models.n) - print(io, "---------------------- ", i, " ----------------------", "\n") - print(io, model) +function Base.show(io::IO, sem::AbstractSem) + println(io, "Structural Equation Model ($(nameof(typeof(sem))))") + println(io, "- $(nparams(sem)) parameters") + println(io, "- Loss terms:") + for term in loss_terms(sem) + print(io, " - ") + print(io, term) + println(io) end end diff --git a/src/implied/RAM/generic.jl b/src/implied/RAM/generic.jl index d57500a3..f1c1e08d 100644 --- 
a/src/implied/RAM/generic.jl +++ b/src/implied/RAM/generic.jl @@ -154,16 +154,11 @@ end ### methods ############################################################################################ -function update!( - targets::EvaluationTargets, - implied::RAM, - model::AbstractSemSingle, - param_labels, -) - materialize!(implied.A, implied.ram_matrices.A, param_labels) - materialize!(implied.S, implied.ram_matrices.S, param_labels) +function update!(targets::EvaluationTargets, implied::RAM, params) + materialize!(implied.A, implied.ram_matrices.A, params) + materialize!(implied.S, implied.ram_matrices.S, params) if !isnothing(implied.M) - materialize!(implied.M, implied.ram_matrices.M, param_labels) + materialize!(implied.M, implied.ram_matrices.M, params) end parent(implied.I_A) .= .-implied.A diff --git a/src/implied/RAM/symbolic.jl b/src/implied/RAM/symbolic.jl index 0f7868da..4c9bda91 100644 --- a/src/implied/RAM/symbolic.jl +++ b/src/implied/RAM/symbolic.jl @@ -176,12 +176,7 @@ end ### objective, gradient, hessian ############################################################################################ -function update!( - targets::EvaluationTargets, - implied::RAMSymbolic, - model::AbstractSemSingle, - par, -) +function update!(targets::EvaluationTargets, implied::RAMSymbolic, par) implied.Σ_eval!(implied.Σ, par) if MeanStruct(implied) === HasMeanStruct implied.μ_eval!(implied.μ, par) @@ -223,7 +218,10 @@ end function eval_Σ_symbolic(S, I_A⁻¹, F; vech::Bool = false, simplify::Bool = false) Σ = F * I_A⁻¹ * S * permutedims(I_A⁻¹) * permutedims(F) Σ = Array(Σ) - vech && (Σ = SEM.vech(Σ)) + if vech + n = size(Σ, 1) + Σ = [Σ[i, j] for j in 1:n for i in j:n] + end if simplify Threads.@threads for i in eachindex(Σ) Σ[i] = Symbolics.simplify(Σ[i]) diff --git a/src/implied/empty.jl b/src/implied/empty.jl index 82a6c946..a327ee13 100644 --- a/src/implied/empty.jl +++ b/src/implied/empty.jl @@ -13,8 +13,8 @@ Empty placeholder for models that don't need an implied part. 
- `specification`: either a `RAMMatrices` or `ParameterTable` object # Examples -A multigroup model with ridge regularization could be specified as a `SemEnsemble` with one -model per group and an additional model with `ImpliedEmpty` and `SemRidge` for the regularization part. +A multigroup model with ridge regularization could be specified as a `Sem` with one +SEM term (`SemLoss`) per group and an additional `SemRidge` regularization term. # Extended help @@ -45,7 +45,7 @@ end ### methods ############################################################################################ -update!(targets::EvaluationTargets, implied::ImpliedEmpty, par, model) = nothing +update!(targets::EvaluationTargets, implied::ImpliedEmpty, par) = nothing ############################################################################################ ### Recommended methods diff --git a/src/loss/ML/FIML.jl b/src/loss/ML/FIML.jl index 8572b15a..fdedf398 100644 --- a/src/loss/ML/FIML.jl +++ b/src/loss/ML/FIML.jl @@ -75,23 +75,27 @@ Can handle observed data with missing values. # Constructor - SemFIML(; observed::SemObservedMissing, implied::SemImplied, kwargs...) + SemFIML(observed::SemObservedMissing, implied::SemImplied) # Arguments - `observed::SemObservedMissing`: the observed part of the model (see [`SemObservedMissing`](@ref)) - `implied::SemImplied`: the implied part of the model + (see [`SemImplied`](@ref)) # Examples ```julia -my_fiml = SemFIML(observed = my_observed, implied = my_implied) +my_fiml = SemFIML(my_observed, my_implied) ``` # Interfaces Analytic gradients are available. 
""" -struct SemFIML{T, W} <: SemLossFunction +struct SemFIML{O, I, T, W} <: SemLoss{O, I} hessianeval::ExactHessian + + observed::O + implied::I patterns::Vector{SemFIMLPattern{T}} imp_inv::Matrix{T} # implied inverse @@ -105,7 +109,7 @@ end ### Constructors ############################################################################################ -function SemFIML(; observed::SemObservedMissing, implied, specification, kwargs...) +function SemFIML(observed::SemObservedMissing, implied::SemImplied; kwargs...) if MeanStruct(implied) === NoMeanStruct """ Full information maximum likelihood (FIML) can only be used with a meanstructure. @@ -117,6 +121,8 @@ function SemFIML(; observed::SemObservedMissing, implied, specification, kwargs. return SemFIML( ExactHessian(), + observed, + implied, [SemFIMLPattern(pat) for pat in observed.patterns], zeros(nobserved_vars(observed), nobserved_vars(observed)), CommutationMatrix(nvars(implied)), @@ -128,30 +134,28 @@ end ### methods ############################################################################################ -function evaluate!( - objective, - gradient, - hessian, - loss::SemFIML, - implied::SemImplied, - model::AbstractSemSingle, - params, -) +function evaluate!(objective, gradient, hessian, loss::SemFIML, params) isnothing(hessian) || error("Hessian not implemented for FIML") - if !check(loss, model) + implied = SEM.implied(loss) + observed = SEM.observed(loss) + + copyto!(loss.imp_inv, implied.Σ) + Σ_chol = cholesky!(Symmetric(loss.imp_inv); check = false) + + if !isposdef(Σ_chol) isnothing(objective) || (objective = non_posdef_return(params)) isnothing(gradient) || fill!(gradient, 1) return objective end - prepare!(loss, model) + @inbounds for (patloss, pat) in zip(loss.patterns, observed.patterns) + prepare!(patloss, pat, implied) + end - scale = inv(nsamples(observed(model))) - isnothing(objective) || - (objective = scale * F_FIML(eltype(params), loss, observed(model), model)) - isnothing(gradient) || - 
(∇F_FIML!(gradient, loss, observed(model), model); gradient .*= scale) + scale = inv(nsamples(observed)) + isnothing(objective) || (objective = scale * F_FIML(eltype(params), loss)) + isnothing(gradient) || (∇F_FIML!(gradient, loss); gradient .*= scale) return objective end @@ -167,27 +171,14 @@ update_observed(loss::SemFIML, observed::SemObserved; kwargs...) = ### additional functions ############################################################################################ -function prepare!(loss::SemFIML, observed::SemObservedMissing, implied::SemImplied) - @inbounds for (patloss, pat) in zip(loss.patterns, observed.patterns) - prepare!(patloss, pat, implied.Σ, implied.μ) - end -end - -prepare!(loss::SemFIML, model::AbstractSemSingle) = - prepare!(loss, observed(model), implied(model)) - -function check(loss::SemFIML, model::AbstractSemSingle) - copyto!(loss.imp_inv, implied(model).Σ) - a = cholesky!(Symmetric(loss.imp_inv); check = false) - return isposdef(a) +function ∇F_fiml_outer!(G, JΣ, Jμ, loss::SemFIML{O, I}) where {O, I <: SemImpliedSymbolic} + mul!(G, loss.implied.∇Σ', JΣ) # should be transposed + mul!(G, loss.implied.∇μ', Jμ, -1, 1) end -function ∇F_fiml_outer!(G, JΣ, Jμ, loss::SemFIML, implied::SemImpliedSymbolic, model) - mul!(G, implied.∇Σ', JΣ) # should be transposed - mul!(G, implied.∇μ', Jμ, -1, 1) -end +function ∇F_fiml_outer!(G, JΣ, Jμ, loss::SemFIML) + implied = loss.implied -function ∇F_fiml_outer!(G, JΣ, Jμ, loss::SemFIML, implied, model) Iₙ = sparse(1.0I, size(implied.A)...) 
P = kron(implied.F⨉I_A⁻¹, implied.F⨉I_A⁻¹) Q = kron(implied.S * implied.I_A⁻¹', Iₙ) @@ -203,25 +194,20 @@ function ∇F_fiml_outer!(G, JΣ, Jμ, loss::SemFIML, implied, model) mul!(G, ∇μ', Jμ, -1, 1) end -function F_FIML( - ::Type{T}, - loss::SemFIML, - observed::SemObservedMissing, - model::AbstractSemSingle, -) where {T} +function F_FIML(::Type{T}, loss::SemFIML) where {T} F = zero(T) - for (patloss, pat) in zip(loss.patterns, observed.patterns) + for (patloss, pat) in zip(loss.patterns, loss.observed.patterns) F += objective(patloss, pat) end return F end -function ∇F_FIML!(G, loss::SemFIML, observed::SemObservedMissing, model::AbstractSemSingle) - Jμ = zeros(nobserved_vars(model)) - JΣ = zeros(nobserved_vars(model)^2) +function ∇F_FIML!(G, loss::SemFIML) + Jμ = zeros(nobserved_vars(loss)) + JΣ = zeros(nobserved_vars(loss)^2) - for (patloss, pat) in zip(loss.patterns, observed.patterns) + for (patloss, pat) in zip(loss.patterns, loss.observed.patterns) gradient!(JΣ, Jμ, patloss, pat) end - ∇F_fiml_outer!(G, JΣ, Jμ, loss, implied(model), model) + ∇F_fiml_outer!(G, JΣ, Jμ, loss) end diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index aae1dada..2d449d73 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -8,36 +8,41 @@ Maximum likelihood estimation. # Constructor - SemML(; observed, approximate_hessian = false, kwargs...) + SemML(observed, implied; approximate_hessian = false) # Arguments - `observed::SemObserved`: the observed part of the model +- `implied::SemImplied`: [`SemImplied`](@ref) instance - `approximate_hessian::Bool`: if hessian-based optimization is used, should the hessian be swapped for an approximation # Examples ```julia -my_ml = SemML(observed = my_observed) +my_ml = SemML(my_observed, my_implied) ``` # Interfaces Analytic gradients are available, and for models without a meanstructure -and RAMSymbolic implied type, also analytic hessians. +and `RAMSymbolic` implied type, also analytic hessians. 
""" -struct SemML{HE <: HessianEval, INV, M, M2} <: SemLossFunction +struct SemML{O, I, HE <: HessianEval, INV, M, M2} <: SemLoss{O, I} + observed::O + implied::I hessianeval::HE Σ⁻¹::INV Σ⁻¹Σₒ::M meandiff::M2 - - SemML{HE}(args...) where {HE <: HessianEval} = - new{HE, map(typeof, args)...}(HE(), args...) end ############################################################################################ ### Constructors ############################################################################################ -function SemML(; observed::SemObserved, approximate_hessian::Bool = false, kwargs...) +function SemML( + observed::SemObserved, + implied::SemImplied; + approximate_hessian::Bool = false, + kwargs..., +) if observed isa SemObservedMissing @warn """ ML estimation with `SemObservedMissing` will use an approximate covariance and mean estimated with EM algorithm. @@ -51,12 +56,25 @@ function SemML(; observed::SemObserved, approximate_hessian::Bool = false, kwarg ) """ end + # check integrity + check_observed_vars(observed, implied) + he = approximate_hessian ? ApproxHessian() : ExactHessian() obsmean = obs_mean(observed) obscov = obs_cov(observed) meandiff = isnothing(obsmean) ? nothing : copy(obsmean) - return SemML{approximate_hessian ? 
ApproxHessian : ExactHessian}( + return SemML{ + typeof(observed), + typeof(implied), + typeof(he), + typeof(obscov), + typeof(obscov), + typeof(meandiff), + }( + observed, + implied, + he, similar(obscov), similar(obscov), meandiff, @@ -74,20 +92,20 @@ function evaluate!( objective, gradient, hessian, - semml::SemML, - implied::SemImpliedSymbolic, - model::AbstractSemSingle, + loss::SemML{<:Any, <:SemImpliedSymbolic}, par, ) + implied = SEM.implied(loss) + if !isnothing(hessian) (MeanStruct(implied) === HasMeanStruct) && throw(DomainError(H, "hessian of ML + meanstructure is not available")) end Σ = implied.Σ - Σₒ = obs_cov(observed(model)) - Σ⁻¹Σₒ = semml.Σ⁻¹Σₒ - Σ⁻¹ = semml.Σ⁻¹ + Σₒ = obs_cov(observed(loss)) + Σ⁻¹Σₒ = loss.Σ⁻¹Σₒ + Σ⁻¹ = loss.Σ⁻¹ copyto!(Σ⁻¹, Σ) Σ_chol = cholesky!(Symmetric(Σ⁻¹); check = false) @@ -105,7 +123,7 @@ function evaluate!( if MeanStruct(implied) === HasMeanStruct μ = implied.μ - μₒ = obs_mean(observed(model)) + μₒ = obs_mean(observed(loss)) μ₋ = μₒ - μ isnothing(objective) || (objective += dot(μ₋, Σ⁻¹, μ₋)) @@ -124,7 +142,7 @@ function evaluate!( mul!(gradient, ∇Σ', J') end if !isnothing(hessian) - if HessianEval(semml) === ApproxHessian + if HessianEval(loss) === ApproxHessian mul!(hessian, ∇Σ' * kron(Σ⁻¹, Σ⁻¹), ∇Σ, 2, 0) else ∇²Σ = implied.∇²Σ @@ -143,24 +161,17 @@ end ############################################################################################ ### Non-Symbolic Implied Types -function evaluate!( - objective, - gradient, - hessian, - semml::SemML, - implied::RAM, - model::AbstractSemSingle, - par, -) +function evaluate!(objective, gradient, hessian, loss::SemML, par) if !isnothing(hessian) error("hessian of ML + non-symbolic implied type is not available") end - Σ = implied.Σ - Σₒ = obs_cov(observed(model)) - Σ⁻¹Σₒ = semml.Σ⁻¹Σₒ - Σ⁻¹ = semml.Σ⁻¹ + implied = SEM.implied(loss) + Σ = implied.Σ + Σₒ = obs_cov(observed(loss)) + Σ⁻¹Σₒ = loss.Σ⁻¹Σₒ + Σ⁻¹ = loss.Σ⁻¹ copyto!(Σ⁻¹, Σ) Σ_chol = cholesky!(Symmetric(Σ⁻¹); check = 
false) if !isposdef(Σ_chol) @@ -179,7 +190,7 @@ function evaluate!( if MeanStruct(implied) === HasMeanStruct μ = implied.μ - μₒ = obs_mean(observed(model)) + μₒ = obs_mean(observed(loss)) μ₋ = μₒ - μ objective += dot(μ₋, Σ⁻¹, μ₋) end @@ -198,7 +209,7 @@ function evaluate!( if MeanStruct(implied) === HasMeanStruct μ = implied.μ - μₒ = obs_mean(observed(model)) + μₒ = obs_mean(observed(loss)) ∇M = implied.∇M M = implied.M μ₋ = μₒ - μ @@ -229,16 +240,17 @@ end ### recommended methods ############################################################################################ -update_observed(lossfun::SemML, observed::SemObservedMissing; kwargs...) = +update_observed(loss::SemML, observed::SemObservedMissing; kwargs...) = error("ML estimation does not work with missing data - use FIML instead") -function update_observed(lossfun::SemML, observed::SemObserved; kwargs...) - if size(lossfun.Σ⁻¹) == size(obs_cov(observed)) - return lossfun +function update_observed(loss::SemML, observed::SemObserved; kwargs...) + if (obs_cov(loss) == obs_cov(observed)) && (obs_mean(loss) == obs_mean(observed)) + return loss # no change else - return SemML(; - observed = observed, - approximate_hessian = HessianEval(lossfun) == ApproxHessian, + return SemML( + observed, + loss.implied; + approximate_hessian = HessianEval(loss) == ApproxHessian, kwargs..., ) end diff --git a/src/loss/WLS/WLS.jl b/src/loss/WLS/WLS.jl index 9de011f6..5c4cb252 100644 --- a/src/loss/WLS/WLS.jl +++ b/src/loss/WLS/WLS.jl @@ -9,8 +9,8 @@ At the moment only available with the `RAMSymbolic` implied type. # Constructor - SemWLS(; - observed, implied, + SemWLS( + observed::SemObserved, implied::SemImplied; wls_weight_matrix = nothing, wls_weight_matrix_mean = nothing, approximate_hessian = false, @@ -18,7 +18,7 @@ At the moment only available with the `RAMSymbolic` implied type. 
# Arguments - `observed`: the `SemObserved` part of the model -- `implied::SemImplied`: the implied part of the model +- `implied`: the `SemImplied` part of the model - `approximate_hessian::Bool`: should the hessian be swapped for an approximation - `wls_weight_matrix`: the weight matrix for weighted least squares. Defaults to GLS estimation (``0.5*(D^T*kron(S,S)*D)`` where D is the duplication matrix @@ -28,29 +28,37 @@ At the moment only available with the `RAMSymbolic` implied type. # Examples ```julia -my_wls = SemWLS(observed = my_observed, implied = my_implied) +my_wls = SemWLS(my_observed, my_implied) ``` # Interfaces Analytic gradients are available, and for models without a meanstructure also analytic hessians. """ -struct SemWLS{HE <: HessianEval, Vt, St, C} <: SemLossFunction +struct SemWLS{O, I, HE <: HessianEval, Vt, St, C} <: SemLoss{O, I} + observed::O + implied::I + hessianeval::HE V::Vt σₒ::St V_μ::C + + SemWLS(observed, implied, ::Type{HE}, args...) where {HE <: HessianEval} = + new{typeof(observed), typeof(implied), HE, map(typeof, args)...}( + observed, + implied, + HE(), + args..., + ) end ############################################################################################ ### Constructors ############################################################################################ -SemWLS{HE}(args...) where {HE <: HessianEval} = - SemWLS{HE, map(typeof, args)...}(HE(), args...) 
- -function SemWLS(; +function SemWLS( observed::SemObserved, - implied::SemImplied, + implied::SemImplied; wls_weight_matrix::Union{AbstractMatrix, Nothing} = nothing, wls_weight_matrix_mean::Union{AbstractMatrix, Nothing} = nothing, approximate_hessian::Bool = false, @@ -75,14 +83,19 @@ function SemWLS(; ArgumentError |> throw end + # check integrity + check_observed_vars(observed, implied) nobs_vars = nobserved_vars(observed) tril_ind = filter(x -> (x[1] >= x[2]), CartesianIndices(obs_cov(observed))) s = obs_cov(observed)[tril_ind] - size(s) == size(implied.Σ) || - throw(DimensionMismatch("SemWLS requires implied covariance to be in vech-ed form " * - "(vectorized lower triangular part of Σ matrix): $(size(s)) expected, $(size(implied.Σ)) found.\n" * - "$(nameof(typeof(implied))) must be constructed with vech=true.")) + size(s) == size(implied.Σ) || throw( + DimensionMismatch( + "SemWLS requires implied covariance to be in vech-ed form " * + "(vectorized lower triangular part of Σ matrix): $(size(s)) expected, $(size(implied.Σ)) found.\n" * + "$(nameof(typeof(implied))) must be constructed with vech=true.", + ), + ) # compute V here if isnothing(wls_weight_matrix) @@ -101,13 +114,12 @@ function SemWLS(; if MeanStruct(implied) == HasMeanStruct if isnothing(wls_weight_matrix_mean) - @warn "Computing WLS weight matrix for the meanstructure using obs_cov()" + @info "Computing WLS weight matrix for the meanstructure using obs_cov()" wls_weight_matrix_mean = inv(obs_cov(observed)) - else - size(wls_weight_matrix_mean) == (nobs_vars, nobs_vars) || DimensionMismatch( - "wls_weight_matrix_mean has to be of size $(nobs_vars)×$(nobs_vars)", - ) end + size(wls_weight_matrix_mean) == (nobs_vars, nobs_vars) || DimensionMismatch( + "wls_weight_matrix_mean has to be of size $(nobs_vars)×$(nobs_vars)", + ) else isnothing(wls_weight_matrix_mean) || @warn "Ignoring wls_weight_matrix_mean since meanstructure is disabled" @@ -115,31 +127,25 @@ function SemWLS(; end HE = 
approximate_hessian ? ApproxHessian : ExactHessian - return SemWLS{HE}(wls_weight_matrix, s, wls_weight_matrix_mean) + return SemWLS(observed, implied, HE, wls_weight_matrix, s, wls_weight_matrix_mean) end ############################################################################ ### methods ############################################################################ -function evaluate!( - objective, - gradient, - hessian, - semwls::SemWLS, - implied::SemImpliedSymbolic, - model::AbstractSemSingle, - par, -) +function evaluate!(objective, gradient, hessian, loss::SemWLS, par) + implied = SEM.implied(loss) + if !isnothing(hessian) && (MeanStruct(implied) === HasMeanStruct) error("hessian of WLS with meanstructure is not available") end - V = semwls.V + V = loss.V ∇σ = implied.∇Σ σ = implied.Σ - σₒ = semwls.σₒ + σₒ = loss.σₒ σ₋ = σₒ - σ isnothing(objective) || (objective = dot(σ₋, V, σ₋)) @@ -152,17 +158,17 @@ function evaluate!( gradient .*= -2 end isnothing(hessian) || (mul!(hessian, ∇σ' * V, ∇σ, 2, 0)) - if !isnothing(hessian) && (HessianEval(semwls) === ExactHessian) + if !isnothing(hessian) && (HessianEval(loss) === ExactHessian) ∇²Σ = implied.∇²Σ - J = -2 * (σ₋' * semwls.V)' + J = -2 * (σ₋' * loss.V)' implied.∇²Σ_eval!(∇²Σ, J, par) hessian .+= ∇²Σ end if MeanStruct(implied) === HasMeanStruct μ = implied.μ - μₒ = obs_mean(observed(model)) + μₒ = obs_mean(observed(loss)) μ₋ = μₒ - μ - V_μ = semwls.V_μ + V_μ = loss.V_μ if !isnothing(objective) objective += dot(μ₋, V_μ, μ₋) end @@ -179,23 +185,19 @@ end ############################################################################################ function update_observed( - lossfun::SemWLS, + loss::SemWLS, observed::SemObserved; recompute_V = true, kwargs..., ) if recompute_V - return SemWLS(; - observed = observed, - meanstructure = MeanStruct(kwargs[:implied]) == HasMeanStruct, - kwargs..., - ) + return SemWLS(observed, loss.implied; kwargs...) 
else - return SemWLS(; - observed = observed, - wls_weight_matrix = lossfun.V, - wls_weight_matrix_mean = lossfun.V_μ, - meanstructure = MeanStruct(kwargs[:implied]) == HasMeanStruct, + return SemWLS( + observed, + loss.implied; + wls_weight_matrix = loss.V, + wls_weight_matrix_mean = loss.V_μ, kwargs..., ) end diff --git a/src/loss/abstract.jl b/src/loss/abstract.jl new file mode 100644 index 00000000..bf8585d6 --- /dev/null +++ b/src/loss/abstract.jl @@ -0,0 +1,42 @@ +""" + observed(loss::SemLoss) -> SemObserved + +Returns the [*observed*](@ref SemObserved) part of a model. +""" +observed(loss::SemLoss) = loss.observed + +""" + implied(loss::SemLoss) -> SemImplied + +Returns the [*implied*](@ref SemImplied) part of a model. +""" +implied(loss::SemLoss) = loss.implied + +for f in (:nsamples, :obs_cov, :obs_mean) + @eval $f(loss::SemLoss) = $f(observed(loss)) +end + +for f in ( + :vars, + :nvars, + :latent_vars, + :nlatent_vars, + :observed_vars, + :nobserved_vars, + :params, + :nparams, +) + @eval $f(loss::SemLoss) = $f(implied(loss)) +end + +function check_observed_vars(observed::SemObserved, implied::SemImplied) + isnothing(observed_vars(implied)) || + observed_vars(observed) == observed_vars(implied) || + throw( + ArgumentError( + "Observed variables defined for \"observed\" and \"implied\" do not match.", + ), + ) +end + +check_observed_vars(sem::SemLoss) = check_observed_vars(observed(sem), implied(sem)) diff --git a/src/loss/constant/constant.jl b/src/loss/constant/constant.jl index 3aed5e27..023076cc 100644 --- a/src/loss/constant/constant.jl +++ b/src/loss/constant/constant.jl @@ -4,6 +4,8 @@ ### Types ############################################################################################ """ + SemConstant{C <: Number} <: AbstractLoss + Constant loss term. Can be used for comparability to other packages. # Constructor @@ -15,37 +17,27 @@ Constant loss term. Can be used for comparability to other packages. 
# Examples ```julia - my_constant = SemConstant(constant_loss = 42.0) + my_constant = SemConstant(42.0) ``` # Interfaces Analytic gradients and hessians are available. """ -struct SemConstant{C} <: SemLossFunction +struct SemConstant{C <: Number} <: AbstractLoss hessianeval::ExactHessian c::C -end -############################################################################################ -### Constructors -############################################################################################ - -function SemConstant(; constant_loss, kwargs...) - return SemConstant(ExactHessian(), constant_loss) + SemConstant(c::Number) = new{typeof(c)}(ExactHessian(), c) end -############################################################################################ -### methods -############################################################################################ +SemConstant(; constant_loss::Number, kwargs...) = SemConstant(constant_loss) -objective(constant::SemConstant, model::AbstractSem, par) = constant.c -gradient(constant::SemConstant, model::AbstractSem, par) = zero(par) -hessian(constant::SemConstant, model::AbstractSem, par) = - zeros(eltype(par), length(par), length(par)) +objective(loss::SemConstant, par) = convert(eltype(par), loss.c) +gradient(loss::SemConstant, par) = zero(par) +hessian(loss::SemConstant, par) = zeros(eltype(par), length(par), length(par)) ############################################################################################ ### Recommended methods ############################################################################################ -update_observed(loss_function::SemConstant, observed::SemObserved; kwargs...) = - loss_function +update_observed(loss::SemConstant, observed::SemObserved; kwargs...) 
= loss diff --git a/src/loss/regularization/ridge.jl b/src/loss/regularization/ridge.jl index 90cbcc23..3e2cfbff 100644 --- a/src/loss/regularization/ridge.jl +++ b/src/loss/regularization/ridge.jl @@ -25,7 +25,7 @@ my_ridge = SemRidge(;α_ridge = 0.02, which_ridge = [:λ₁, :λ₂, :ω₂₃], # Interfaces Analytic gradients and hessians are available. """ -struct SemRidge{P, W1, W2, GT, HT} <: SemLossFunction +struct SemRidge{P, W1, W2, GT, HT} <: AbstractLoss hessianeval::ExactHessian α::P which::W1 @@ -74,15 +74,14 @@ end ### methods ############################################################################################ -objective(ridge::SemRidge, model::AbstractSem, par) = - @views ridge.α * sum(abs2, par[ridge.which]) +objective(ridge::SemRidge, par) = @views ridge.α * sum(abs2, par[ridge.which]) -function gradient(ridge::SemRidge, model::AbstractSem, par) +function gradient(ridge::SemRidge, par) @views ridge.gradient[ridge.which] .= (2 * ridge.α) * par[ridge.which] return ridge.gradient end -function hessian(ridge::SemRidge, model::AbstractSem, par) +function hessian(ridge::SemRidge, par) @views @. ridge.hessian[ridge.which_H] .= 2 * ridge.α return ridge.hessian end diff --git a/src/objective_gradient_hessian.jl b/src/objective_gradient_hessian.jl index 69915ffa..23cef4e6 100644 --- a/src/objective_gradient_hessian.jl +++ b/src/objective_gradient_hessian.jl @@ -24,68 +24,61 @@ is_hessian_required(::EvaluationTargets{<:Any, <:Any, H}) where {H} = H (targets::EvaluationTargets)(arg_tuple::Tuple) = targets(arg_tuple...) """ - evaluate!(objective, gradient, hessian [, lossfun], model, params) + evaluate!(objective, gradient, hessian, loss::AbstractLoss, params) + evaluate!(objective, gradient, hessian, model::AbstractSem, params) Evaluates the objective, gradient, and/or Hessian at the given parameter vector. -If a loss function is passed, only this specific loss function is evaluated, otherwise, -the sum of all loss functions in the model is evaluated. 
+ +If a single loss term (`loss`) is passed, only this specific term is evaluated, +otherwise, if the entire SEM `model` is passed, the weighted sum of all loss terms +in the model is evaluated. If objective, gradient or hessian are `nothing`, they are not evaluated. For example, since many numerical optimization algorithms don't require a Hessian, -the computation will be turned off by setting `hessian` to `nothing`. +its computation will be turned off by setting `hessian` to `nothing`. + +During the evaluation, the internal state of the loss term or of the model +could be modified. # Arguments - `objective`: a Number if the objective should be evaluated, otherwise `nothing` - `gradient`: a pre-allocated vector the gradient should be written to, otherwise `nothing` - `hessian`: a pre-allocated matrix the Hessian should be written to, otherwise `nothing` -- `lossfun::SemLossFunction`: loss function to evaluate +- `loss::AbstractLoss`: loss function to evaluate - `model::AbstractSem`: model to evaluate - `params`: vector of parameters # Implementing a new loss function -To implement a new loss function, a new method for `evaluate!` has to be defined. +To implement a new loss (subtype of `SemLoss` for SEM terms, or of `AbstractLoss` for +regularization terms), a new method for `evaluate!` has to be defined. This is explained in the online documentation on [Custom loss functions](@ref). """ function evaluate! 
end -# dispatch on SemImplied -evaluate!(objective, gradient, hessian, loss::SemLossFunction, model::AbstractSem, params) = - evaluate!(objective, gradient, hessian, loss, implied(model), model, params) - # fallback method -function evaluate!( - obj, - grad, - hess, - loss::SemLossFunction, - implied::SemImplied, - model, - params, -) - isnothing(obj) || (obj = objective(loss, implied, model, params)) - isnothing(grad) || copyto!(grad, gradient(loss, implied, model, params)) - isnothing(hess) || copyto!(hess, hessian(loss, implied, model, params)) +function evaluate!(obj, grad, hess, loss::AbstractLoss, params) + isnothing(obj) || (obj = objective(loss, params)) + isnothing(grad) || copyto!(grad, gradient(loss, params)) + isnothing(hess) || copyto!(hess, hessian(loss, params)) return obj end -# fallback methods -objective(f::SemLossFunction, implied::SemImplied, model, params) = - objective(f, model, params) -gradient(f::SemLossFunction, implied::SemImplied, model, params) = - gradient(f, model, params) -hessian(f::SemLossFunction, implied::SemImplied, model, params) = hessian(f, model, params) +evaluate!(obj, grad, hess, term::LossTerm, params) = + evaluate!(obj, grad, hess, loss(term), params) # fallback method for SemImplied that calls update_xxx!() methods -function update!(targets::EvaluationTargets, implied::SemImplied, model, params) - is_objective_required(targets) && update_objective!(implied, model, params) - is_gradient_required(targets) && update_gradient!(implied, model, params) - is_hessian_required(targets) && update_hessian!(implied, model, params) +function update!(targets::EvaluationTargets, implied::SemImplied, params) + is_objective_required(targets) && update_objective!(implied, params) + is_gradient_required(targets) && update_gradient!(implied, params) + is_hessian_required(targets) && update_hessian!(implied, params) end +const AbstractSemOrLoss = Union{AbstractSem, AbstractLoss} + # guess objective type -objective_type(model::AbstractSem, 
params::Any) = Float64 -objective_type(model::AbstractSem, params::AbstractVector{T}) where {T <: Number} = T -objective_zero(model::AbstractSem, params::Any) = zero(objective_type(model, params)) +objective_type(model::AbstractSemOrLoss, params::Any) = Float64 +objective_type(model::AbstractSemOrLoss, params::AbstractVector{T}) where {T <: Number} = T +objective_zero(model::AbstractSemOrLoss, params::Any) = zero(objective_type(model, params)) objective_type(objective::T, gradient, hessian) where {T <: Number} = T objective_type( @@ -101,145 +94,151 @@ objective_type( objective_zero(objective, gradient, hessian) = zero(objective_type(objective, gradient, hessian)) +evaluate!(objective, gradient, hessian, model::AbstractSem, params) = + error("evaluate!() for $(typeof(model)) is not implemented") + ############################################################################################ -# methods for AbstractSem +# methods for Sem ############################################################################################ -function evaluate!(objective, gradient, hessian, model::AbstractSemSingle, params) - targets = EvaluationTargets(objective, gradient, hessian) - # update implied state, its gradient and hessian (if required) - update!(targets, implied(model), model, params) - return evaluate!( - !isnothing(objective) ? 
zero(objective) : nothing, - gradient, - hessian, - loss(model), - model, - params, - ) -end +function evaluate!(objective, gradient, hessian, model::Sem, params) + # reset output + isnothing(objective) || (objective = objective_zero(objective, gradient, hessian)) + isnothing(gradient) || fill!(gradient, zero(eltype(gradient))) + isnothing(hessian) || fill!(hessian, zero(eltype(hessian))) -############################################################################################ -# methods for SemFiniteDiff -# (approximate gradient and hessian with finite differences of objective) -############################################################################################ + # gradient and hessian for individual terms + t_grad = isnothing(gradient) ? nothing : similar(gradient) + t_hess = isnothing(hessian) ? nothing : similar(hessian) + + # update implied states of all SemLoss terms before term calculation loop + # to make sure all terms use updated implied states + targets = EvaluationTargets(objective, gradient, hessian) + for term in loss_terms(model) + issemloss(term) && update!(targets, implied(term), params) + end -function evaluate!(objective, gradient, hessian, model::SemFiniteDiff, params) - function obj(p) - # recalculate implied state for p - update!(EvaluationTargets{true, false, false}(), implied(model), model, p) - evaluate!( - objective_zero(objective, gradient, hessian), - nothing, - nothing, - loss(model), - model, - p, + for term in loss_terms(model) + t_obj = evaluate!(objective, t_grad, t_hess, term, params) + #@show nameof(typeof(term)) t_obj + objective = accumulate_loss!( + objective, + gradient, + hessian, + weight(term), + t_obj, + t_grad, + t_hess, ) end - isnothing(gradient) || FiniteDiff.finite_difference_gradient!(gradient, obj, params) - isnothing(hessian) || FiniteDiff.finite_difference_hessian!(hessian, obj, params) - return !isnothing(objective) ? 
obj(params) : nothing + return objective end -objective(model::AbstractSem, params) = - evaluate!(objective_zero(model, params), nothing, nothing, model, params) - -############################################################################################ -# methods for SemLoss (weighted sum of individual SemLossFunctions) -############################################################################################ +# internal function to accumulate loss objective, gradient and hessian +function accumulate_loss!( + total_objective, + total_gradient, + total_hessian, + weight::Nothing, + objective, + gradient, + hessian, +) + isnothing(total_gradient) || (total_gradient .+= gradient) + isnothing(total_hessian) || (total_hessian .+= hessian) + return isnothing(total_objective) ? total_objective : (total_objective + objective) +end -function evaluate!(objective, gradient, hessian, loss::SemLoss, model::AbstractSem, params) - isnothing(objective) || (objective = zero(objective)) - isnothing(gradient) || fill!(gradient, zero(eltype(gradient))) - isnothing(hessian) || fill!(hessian, zero(eltype(hessian))) - f_grad = isnothing(gradient) ? nothing : similar(gradient) - f_hess = isnothing(hessian) ? nothing : similar(hessian) - for (f, weight) in zip(loss.functions, loss.weights) - f_obj = evaluate!(objective, f_grad, f_hess, f, model, params) - isnothing(objective) || (objective += weight * f_obj) - isnothing(gradient) || (gradient .+= weight * f_grad) - isnothing(hessian) || (hessian .+= weight * f_hess) - end - return objective +function accumulate_loss!( + total_objective, + total_gradient, + total_hessian, + weight::Number, + objective, + gradient, + hessian, +) + isnothing(total_gradient) || axpy!(weight, gradient, total_gradient) + isnothing(total_hessian) || axpy!(weight, hessian, total_hessian) + return isnothing(total_objective) ? 
total_objective : + (total_objective + weight * objective) end ############################################################################################ -# methods for SemEnsemble (weighted sum of individual AbstractSemSingle models) +# methods for SemFiniteDiff +# (approximate gradient and hessian with finite differences of objective) ############################################################################################ -function evaluate!(objective, gradient, hessian, ensemble::SemEnsemble, params) - isnothing(objective) || (objective = zero(objective)) - isnothing(gradient) || fill!(gradient, zero(eltype(gradient))) - isnothing(hessian) || fill!(hessian, zero(eltype(hessian))) - sem_grad = isnothing(gradient) ? nothing : similar(gradient) - sem_hess = isnothing(hessian) ? nothing : similar(hessian) - for (sem, weight) in zip(ensemble.sems, ensemble.weights) - sem_obj = evaluate!(objective, sem_grad, sem_hess, sem, params) - isnothing(objective) || (objective += weight * sem_obj) - isnothing(gradient) || (gradient .+= weight * sem_grad) - isnothing(hessian) || (hessian .+= weight * sem_hess) - end - return objective +# evaluate!() wrapper that does some housekeeping, if necessary +_evaluate!(args...) = evaluate!(args...) + +# update implied state, its gradient and hessian +function _evaluate!(objective, gradient, hessian, loss::SemLoss, params) + # note that any other Sem loss terms that are dependent on implied + # should be enumerated after the SemLoss term + # otherwise they would be using outdated implied state + update!(EvaluationTargets(objective, gradient, hessian), implied(loss), params) + return evaluate!(objective, gradient, hessian, loss, params) end +objective(model::AbstractSemOrLoss, params) = + _evaluate!(objective_zero(model, params), nothing, nothing, model, params) + +# throw an error by default if gradient! and hessian! 
are not implemented + +#= gradient!(model::AbstractSemOrLoss, par, model) = + throw(ArgumentError("gradient for $(nameof(typeof(model))) is not available")) + +hessian!(model::AbstractSemOrLoss, par, model) = + throw(ArgumentError("hessian for $(nameof(typeof(model))) is not available")) =# + ############################################################################################ # Documentation ############################################################################################ """ objective!(model::AbstractSem, params) -Returns the objective value at `params`. -The model object can be modified. +Calculates the objective value at `params`. -# Implementation -To implement a new `SemImplied` or `SemLossFunction` subtype, you need to add a method for - objective!(newtype::MyNewType, params, model::AbstractSemSingle) +The model object can be modified during calculation. -To implement a new `AbstractSem` subtype, you need to add a method for - objective!(model::MyNewType, params) +See also [`evaluate!`](@ref). """ function objective! end """ gradient!(gradient, model::AbstractSem, params) -Writes the gradient value at `params` to `gradient`. +Calculates the model's gradient at `params` and writes it to `gradient`. -# Implementation -To implement a new `SemImplied` or `SemLossFunction` type, you can add a method for - gradient!(newtype::MyNewType, params, model::AbstractSemSingle) +The model object can be modified during calculation. -To implement a new `AbstractSem` subtype, you can add a method for - gradient!(gradient, model::MyNewType, params) +See also [`evaluate!`](@ref). """ function gradient! end """ hessian!(hessian, model::AbstractSem, params) -Writes the hessian value at `params` to `hessian`. +Calculates the model's hessian at `params` and writes it to `hessian`. 
-# Implementation -To implement a new `SemImplied` or `SemLossFunction` type, you can add a method for - hessian!(newtype::MyNewType, params, model::AbstractSemSingle) +The model object can be modified during calculation. -To implement a new `AbstractSem` subtype, you can add a method for - hessian!(hessian, model::MyNewType, params) +See also [`evaluate!`](@ref). """ function hessian! end objective!(model::AbstractSem, params) = - evaluate!(objective_zero(model, params), nothing, nothing, model, params) + _evaluate!(objective_zero(model, params), nothing, nothing, model, params) gradient!(gradient, model::AbstractSem, params) = - evaluate!(nothing, gradient, nothing, model, params) + _evaluate!(nothing, gradient, nothing, model, params) hessian!(hessian, model::AbstractSem, params) = - evaluate!(nothing, nothing, hessian, model, params) + _evaluate!(nothing, nothing, hessian, model, params) objective_gradient!(gradient, model::AbstractSem, params) = - evaluate!(objective_zero(model, params), gradient, nothing, model, params) + _evaluate!(objective_zero(model, params), gradient, nothing, model, params) objective_hessian!(hessian, model::AbstractSem, params) = - evaluate!(objective_zero(model, params), nothing, hessian, model, params) + _evaluate!(objective_zero(model, params), nothing, hessian, model, params) gradient_hessian!(gradient, hessian, model::AbstractSem, params) = - evaluate!(nothing, gradient, hessian, model, params) + _evaluate!(nothing, gradient, hessian, model, params) objective_gradient_hessian!(gradient, hessian, model::AbstractSem, params) = - evaluate!(objective_zero(model, params), gradient, hessian, model, params) + _evaluate!(objective_zero(model, params), gradient, hessian, model, params) diff --git a/src/optimizer/abstract.jl b/src/optimizer/abstract.jl index 0c7913c4..6774e549 100644 --- a/src/optimizer/abstract.jl +++ b/src/optimizer/abstract.jl @@ -137,13 +137,13 @@ fit(model::AbstractSem; engine::Symbol = :Optim, start_val = nothing, 
kwargs...) fit(optim::SemOptimizer, model::AbstractSem, start_params; kwargs...) = error("Optimizer $(optim) support not implemented.") -# FABIN3 is the default method for single models -prepare_start_params(start_val::Nothing, model::AbstractSemSingle; kwargs...) = - start_fabin3(model; kwargs...) - -# simple algorithm is the default method for ensembles -prepare_start_params(start_val::Nothing, model::AbstractSem; kwargs...) = - start_simple(model; kwargs...) +# defaults when no starting parameters are specified +function prepare_start_params(start_val::Nothing, model::AbstractSem; kwargs...) + sems = sem_terms(model) + # FABIN3 for single models, simple algorithm for ensembles + return length(sems) == 1 ? start_fabin3(loss(sems[1]); kwargs...) : + start_simple(model; kwargs...) +end # first argument is a function prepare_start_params(start_val, model::AbstractSem; kwargs...) = start_val(model; kwargs...) diff --git a/src/types.jl b/src/types.jl index 3a6b5fdf..87b733cf 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,14 +1,6 @@ ############################################################################################ # Define the basic type system ############################################################################################ -"Most abstract supertype for all SEMs" -abstract type AbstractSem end - -"Supertype for all single SEMs, e.g. SEMs that have at least the fields `observed`, `implied`, `loss`" -abstract type AbstractSemSingle{O, I, L} <: AbstractSem end - -"Supertype for all collections of multiple SEMs" -abstract type AbstractSemCollection <: AbstractSem end "Meanstructure trait for `SemImplied` subtypes" abstract type MeanStruct end @@ -36,48 +28,8 @@ HessianEval(::Type{T}) where {T} = HessianEval(semobj) = HessianEval(typeof(semobj)) -"Supertype for all loss functions of SEMs. If you want to implement a custom loss function, it should be a subtype of `SemLossFunction`." 
-abstract type SemLossFunction end - -""" - SemLoss(args...; loss_weights = nothing, ...) - -Constructs the loss field of a SEM. Can contain multiple `SemLossFunction`s, the model is optimized over their sum. -See also [`SemLossFunction`](@ref). - -# Arguments -- `args...`: Multiple `SemLossFunction`s. -- `loss_weights::Vector`: Weights for each loss function. Defaults to unweighted optimization. - -# Examples -```julia -my_ml_loss = SemML(...) -my_ridge_loss = SemRidge(...) -my_loss = SemLoss(SemML, SemRidge; loss_weights = [1.0, 2.0]) -``` -""" -mutable struct SemLoss{F <: Tuple, T} - functions::F - weights::T -end - -function SemLoss(functions...; loss_weights = nothing, kwargs...) - if !isnothing(loss_weights) - loss_weights = SemWeight.(loss_weights) - else - loss_weights = Tuple(SemWeight(nothing) for _ in 1:length(functions)) - end - - return SemLoss(functions, loss_weights) -end - -# weights for loss functions or models. If the weight is nothing, multiplication returns the second argument -struct SemWeight{T} - w::T -end - -Base.:*(x::SemWeight{Nothing}, y) = y -Base.:*(x::SemWeight, y) = x.w * y +"Supertype for all loss functions of SEMs. If you want to implement a custom loss function, it should be a subtype of `AbstractLoss`." +abstract type AbstractLoss end abstract type SemOptimizer{E} end @@ -85,6 +37,8 @@ abstract type SemOptimizer{E} end abstract type SemOptimizerResult{O <: SemOptimizer} end """ + abstract type SemObserved + Supertype of all objects that can serve as the observed field of a SEM. Pre-processes data and computes sufficient statistics for example. If you have a special kind of data, e.g. ordinal data, you should implement a subtype of SemObserved. @@ -103,169 +57,90 @@ abstract type SemImplied end abstract type SemImpliedSymbolic <: SemImplied end """ - Sem(;observed = SemObservedData, implied = RAM, loss = SemML, kwargs...) 
+ abstract type SemLoss{O <: SemObserved, I <: SemImplied} <: AbstractLoss -Constructor for the basic `Sem` type. -All additional kwargs are passed down to the constructors for the observed, implied, and loss fields. +The base type for calculating the loss of the implied SEM model when explaining the observed data. -# Arguments -- `observed`: object of subtype `SemObserved` or a constructor. -- `implied`: object of subtype `SemImplied` or a constructor. -- `loss`: object of subtype `SemLossFunction`s or constructor; or a tuple of such. - -Returns a Sem with fields -- `observed::SemObserved`: Stores observed data, sample statistics, etc. See also [`SemObserved`](@ref). -- `implied::SemImplied`: Computes model implied statistics, like Σ, μ, etc. See also [`SemImplied`](@ref). -- `loss::SemLoss`: Computes the objective and gradient of a sum of loss functions. See also [`SemLoss`](@ref). +All subtypes of `SemLoss` should have the following fields: +- `observed::O`: object of subtype [`SemObserved`](@ref). +- `implied::I`: object of subtype [`SemImplied`](@ref). """ -mutable struct Sem{O <: SemObserved, I <: SemImplied, L <: SemLoss} <: - AbstractSemSingle{O, I, L} - observed::O - implied::I - loss::L -end +abstract type SemLoss{O <: SemObserved, I <: SemImplied} <: AbstractLoss end -############################################################################################ -# automatic differentiation -############################################################################################ """ - SemFiniteDiff(;observed = SemObservedData, implied = RAM, loss = SemML, kwargs...) + abstract type AbstractSem -A wrapper around [`Sem`](@ref) that substitutes dedicated evaluation of gradient and hessian with -finite difference approximation. +The base type for all SEMs. +""" +abstract type AbstractSem end -# Arguments -- `observed`: object of subtype `SemObserved` or a constructor. -- `implied`: object of subtype `SemImplied` or a constructor. 
-- `loss`: object of subtype `SemLossFunction`s or constructor; or a tuple of such. - -Returns a Sem with fields -- `observed::SemObserved`: Stores observed data, sample statistics, etc. See also [`SemObserved`](@ref). -- `implied::SemImplied`: Computes model implied statistics, like Σ, μ, etc. See also [`SemImplied`](@ref). -- `loss::SemLoss`: Computes the objective and gradient of a sum of loss functions. See also [`SemLoss`](@ref). """ -struct SemFiniteDiff{O <: SemObserved, I <: SemImplied, L <: SemLoss} <: - AbstractSemSingle{O, I, L} - observed::O - implied::I + struct LossTerm{L, I, W} + +A term of a [`Sem`](@ref) model that wraps [`AbstractLoss`](@ref) loss function of type `L`. +Loss term can have an optional *id* of type `I` and *weight* of numeric type `W`. +""" +struct LossTerm{L <: AbstractLoss, I <: Union{Symbol, Nothing}, W <: Union{Number, Nothing}} loss::L + id::I + weight::W end -############################################################################################ -# ensemble models -############################################################################################ """ - (1) SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...) + Sem(loss_terms...; [params], kwargs...) - (2) SemEnsemble(;specification, data, groups, column = :group, kwargs...) +SEM model (including multi-group SEMs) that combines all the data, implied SEM structure +and regularization terms. -Constructor for ensemble models. (2) can be used to conveniently specify multigroup models. +All terms of the `Sem` object share the same set of parameters. +`Sem` implements the calculation of the weighted sum of its terms (the *objective* +function), as well as the gradient and Hessian of this sum. # Arguments -- `models...`: `AbstractSem`s. -- `weights::Vector`: Weights for each model. Defaults to the number of observed data points. -- `specification::EnsembleParameterTable`: Model specification. -- `data::DataFrame`: Observed data. 
Must contain a `column` of type `Vector{Symbol}` that contains the group. -- `groups::Vector{Symbol}`: Group names. -- `column::Symbol`: Name of the column in `data` that contains the group. - -All additional kwargs are passed down to the model parts. - -Returns a SemEnsemble with fields -- `n::Int`: Number of models. -- `sems::Tuple`: `AbstractSem`s. -- `weights::Vector`: Weights for each model. -- `param_labels::Vector`: Stores parameter labels and their position. - -For instructions on multigroup models, see the online documentation. +- `loss_terms...`: [`AbstractLoss`](@ref) objects, including SEM losses ([`SemLoss`](@ref)), + optionally can be a pair of a loss object and its numeric weight + +# Fields +- `loss_terms::Tuple`: a tuple of all loss functions and their weights +- `params::Vector{Symbol}`: the vector of parameter ids shared by all loss functions. """ -struct SemEnsemble{N, T <: Tuple, V <: AbstractVector, I, G <: Vector{Symbol}} <: - AbstractSemCollection - n::N - sems::T - weights::V - param_labels::I - groups::G +struct Sem{L <: Tuple} <: AbstractSem + loss_terms::L + params::Vector{Symbol} end -# constructor from multiple models -function SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...) - n = length(models) - # default weights - weights = isnothing(weights) ? multigroup_weights(models, n) : weights - # default group labels - groups = isnothing(groups) ? Symbol.(:g, 1:n) : groups - # check parameters equality - param_labels = SEM.param_labels(models[1]) - for model in models - if param_labels != SEM.param_labels(model) - throw(ErrorException("The parameters of your models do not match. \n - Maybe you tried to specify models of an ensemble via ParameterTables. 
\n - In that case, you may use RAMMatrices instead.")) - end - end - - return SemEnsemble(n, models, weights, param_labels, groups) -end +############################################################################################ +# automatic differentiation +############################################################################################ -# constructor from EnsembleParameterTable and data set -function SemEnsemble(; specification, data, groups, column = :group, kwargs...) - if specification isa EnsembleParameterTable - specification = convert(Dict{Symbol, RAMMatrices}, specification) - end - models = [] - for group in groups - ram_matrices = specification[group] - data_group = select(filter(r -> r[column] == group, data), Not(column)) - if iszero(nrow(data_group)) - error("Your data does not contain any observations from group `$(group)`.") - end - model = Sem(; specification = ram_matrices, data = data_group, kwargs...) - push!(models, model) - end - return SemEnsemble(models...; groups = groups, kwargs...) -end +""" + SemFiniteDiff(model::AbstractSem) -function multigroup_weights(models, n) - nsamples_total = sum(nsamples, models) - uniform_lossfun = check_single_lossfun(models...; throw_error = false) - if !uniform_lossfun - @info "Your ensemble model contains heterogeneous loss functions. - Default weights of (#samples per group/#total samples) will be used." - return [(nsamples(model)) / (nsamples_total) for model in models] - end - lossfun = models[1].loss.functions[1] - if !applicable(mg_correction, lossfun) - @info "We don't know how to choose group weights for the specified loss function. - Default weights of (#samples per group/#total samples) will be used." 
- return [(nsamples(model)) / (nsamples_total) for model in models] - end - c = mg_correction(lossfun) - return [(nsamples(model)+c) / (nsamples_total+n*c) for model in models] -end +A wrapper around [`AbstractSem`](@ref) that substitutes dedicated evaluation of gradient and +hessian with finite difference approximation. -param_labels(ensemble::SemEnsemble) = ensemble.param_labels +`SemFiniteDiff` could be used to enable gradient-based optimization of the SEM models +when the dedicated calculation of gradient and hessian are not available. +For approximation, it uses the *FiniteDiff.jl* package. +# Arguments +- `model::Sem`: the SEM model to wrap """ - n_models(ensemble::SemEnsemble) -> Integer +struct SemFiniteDiff{S <: AbstractSem} <: AbstractSem + model::S +end -Returns the number of models in an ensemble model. -""" -n_models(ensemble::SemEnsemble) = ensemble.n -""" - models(ensemble::SemEnsemble) -> Tuple{AbstractSem} +struct LossFiniteDiff{L <: AbstractLoss} <: AbstractLoss + loss::L +end -Returns the models in an ensemble model. -""" -models(ensemble::SemEnsemble) = ensemble.sems -""" - weights(ensemble::SemEnsemble) -> Vector +struct SemLossFiniteDiff{O, I, L <: SemLoss{O, I}} <: SemLoss{O, I} + loss::L +end -Returns the weights of an ensemble model. """ -weights(ensemble::SemEnsemble) = ensemble.weights + abstract type SemSpecification end -""" Base type for all SEM specifications. 
""" abstract type SemSpecification end diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index bb7db3b5..48723fbe 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -4,34 +4,36 @@ const SEM = StructuralEquationModels # ML estimation ############################################################################################ -model_g1 = Sem(specification = specification_g1, data = dat_g1, implied = RAMSymbolic) +obs_g1 = SemObservedData(data = dat_g1, observed_vars = SEM.observed_vars(specification_g1)) +obs_g2 = SemObservedData(data = dat_g2, observed_vars = SEM.observed_vars(specification_g2)) -model_g2 = Sem(specification = specification_g2, data = dat_g2, implied = RAM) +model_ml_multigroup = Sem( + :Pasteur => SemML(obs_g1, RAMSymbolic(specification_g1)), + :Grant_White => SemML(obs_g2, RAM(specification_g2)), +) -@test SEM.param_labels(model_g1.implied.ram_matrices) == - SEM.param_labels(model_g2.implied.ram_matrices) +@testset "Sem API" begin + @test SEM.nsamples(model_ml_multigroup) == nsamples(obs_g1) + nsamples(obs_g2) + @test SEM.nsem_terms(model_ml_multigroup) == 2 + @test length(SEM.sem_terms(model_ml_multigroup)) == 2 +end -# test the different constructors -model_ml_multigroup = SemEnsemble(model_g1, model_g2; groups = [:Pasteur, :Grant_White]) -model_ml_multigroup2 = SemEnsemble( - specification = partable, - data = dat, - column = :school, - groups = [:Pasteur, :Grant_White], - loss = SemML, +# replace observed using Dict of data matrices +model_ml_multigroup3 = replace_observed( + model_ml_multigroup, + Dict(:Pasteur => dat_g1, :Grant_White => dat_g2), ) -model_ml_multigroup3 = replace_observed( - model_ml_multigroup2, - column = :school, - specification = partable, - data = dat, +# replace observed using DataFrame with group column +model_ml_multigroup4 = replace_observed( + model_ml_multigroup, + dat; + semterm_column = :school, ) # gradients 
@testset "ml_gradients_multigroup" begin test_gradient(model_ml_multigroup, start_test; atol = 1e-9) - test_gradient(model_ml_multigroup2, start_test; atol = 1e-9) end # fit @@ -44,50 +46,18 @@ end atol = 1e-4, lav_groups = Dict(:Pasteur => 1, :Grant_White => 2), ) - solution = fit(semoptimizer, model_ml_multigroup2) - update_estimate!(partable, solution) - test_estimates( - partable, - solution_lav[:parameter_estimates_ml]; - atol = 1e-4, - lav_groups = Dict(:Pasteur => 1, :Grant_White => 2), - ) end @testset "replace_observed_multigroup" begin sem_fit_1 = fit(semoptimizer, model_ml_multigroup) - sem_fit_2 = fit(semoptimizer, model_ml_multigroup3) - @test sem_fit_1.solution ≈ sem_fit_2.solution + sem_fit_3 = fit(semoptimizer, model_ml_multigroup3) + @test sem_fit_1.solution ≈ sem_fit_3.solution + sem_fit_4 = fit(semoptimizer, model_ml_multigroup4) + @test sem_fit_1.solution ≈ sem_fit_4.solution end @testset "fitmeasures/se_ml" begin - solution_ml = fit(model_ml_multigroup) - test_fitmeasures( - fit_measures(solution_ml), - solution_lav[:fitmeasures_ml]; - rtol = 1e-2, - atol = 1e-7, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ml)), - solution_lav[:fitmeasures_ml]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) - - update_se_hessian!(partable, solution_ml) - test_estimates( - partable, - solution_lav[:parameter_estimates_ml]; - atol = 1e-3, - col = :se, - lav_col = :se, - lav_groups = Dict(:Pasteur => 1, :Grant_White => 2), - ) - - test_bootstrap(solution_ml, partable; rtol_hessian = 0.3) - smoketest_CI_z(solution_ml, partable) - - solution_ml = fit(model_ml_multigroup2) + solution_ml = fit(semoptimizer, model_ml_multigroup) test_fitmeasures( fit_measures(solution_ml), solution_lav[:fitmeasures_ml]; @@ -118,15 +88,19 @@ end partable_s = sort_vars(partable) specification_s = convert(Dict{Symbol, RAMMatrices}, partable_s) +obs_g1_s = SemObservedData( + data = dat_g1, + observed_vars = SEM.observed_vars(specification_s[:Pasteur]), +) +obs_g2_s = 
SemObservedData( + data = dat_g2, + observed_vars = SEM.observed_vars(specification_s[:Grant_White]), +) -specification_g1_s = specification_s[:Pasteur] -specification_g2_s = specification_s[:Grant_White] - -model_g1 = Sem(specification = specification_g1_s, data = dat_g1, implied = RAMSymbolic) - -model_g2 = Sem(specification = specification_g2_s, data = dat_g2, implied = RAM) - -model_ml_multigroup = SemEnsemble(model_g1, model_g2; optimizer = semoptimizer) +model_ml_multigroup = Sem( + SemML(obs_g1_s, RAMSymbolic(specification_s[:Pasteur])), + SemML(obs_g2_s, RAM(specification_s[:Grant_White])), +) # gradients @testset "ml_gradients_multigroup | sorted" begin @@ -142,7 +116,7 @@ grad_fd = FiniteDiff.finite_difference_gradient( # fit @testset "ml_solution_multigroup | sorted" begin - solution = fit(model_ml_multigroup) + solution = fit(semoptimizer, model_ml_multigroup) update_estimate!(partable_s, solution) test_estimates( partable_s, @@ -153,7 +127,7 @@ grad_fd = FiniteDiff.finite_difference_gradient( end @testset "fitmeasures/se_ml | sorted" begin - solution_ml = fit(model_ml_multigroup) + solution_ml = fit(semoptimizer, model_ml_multigroup) test_fitmeasures( fit_measures(solution_ml), solution_lav[:fitmeasures_ml]; @@ -178,28 +152,26 @@ end end @testset "sorted | LowerTriangular A" begin - @test implied(model_ml_multigroup.sems[2]).A isa LowerTriangular + @test implied(SEM.sem_terms(model_ml_multigroup)[2]).A isa LowerTriangular end ############################################################################################ # ML estimation - user defined loss function ############################################################################################ -struct UserSemML <: SemLossFunction +struct UserSemML{O, I} <: SemLoss{O, I} hessianeval::ExactHessian - UserSemML() = new(ExactHessian()) -end - -############################################################################################ -### functors 
-############################################################################################ + observed::O + implied::I -using LinearAlgebra: isposdef, logdet, tr, inv + UserSemML(observed::SemObserved, implied::SemImplied) = + new{typeof(observed), typeof(implied)}(ExactHessian(), observed, implied) +end -function SEM.objective(ml::UserSemML, model::AbstractSem, params) - Σ = implied(model).Σ - Σₒ = SEM.obs_cov(observed(model)) +function SEM.objective(ml::UserSemML, params) + Σ = implied(ml).Σ + Σₒ = SEM.obs_cov(observed(ml)) if !isposdef(Σ) return Inf else @@ -208,24 +180,18 @@ function SEM.objective(ml::UserSemML, model::AbstractSem, params) end # models -model_g1 = Sem(specification = specification_g1, data = dat_g1, implied = RAMSymbolic) - -model_g2 = SemFiniteDiff( - specification = specification_g2, - data = dat_g2, - implied = RAMSymbolic, - loss = UserSemML(), +model_ml_multigroup = Sem( + SemML(obs_g1, RAMSymbolic(specification_g1)), + SEM.FiniteDiffWrapper(UserSemML(obs_g2, RAMSymbolic(specification_g2))), ) -model_ml_multigroup = SemEnsemble(model_g1, model_g2; optimizer = semoptimizer) - @testset "gradients_user_defined_loss" begin test_gradient(model_ml_multigroup, start_test; atol = 1e-9) end # fit @testset "solution_user_defined_loss" begin - solution = fit(model_ml_multigroup) + solution = fit(semoptimizer, model_ml_multigroup) update_estimate!(partable, solution) test_estimates( partable, @@ -239,25 +205,9 @@ end # GLS estimation ############################################################################################ -model_ls_g1 = Sem( - specification = specification_g1, - data = dat_g1, - implied = RAMSymbolic, - loss = SemWLS, -) - -model_ls_g2 = Sem( - specification = specification_g2, - data = dat_g2, - implied = RAMSymbolic, - loss = SemWLS, -) - -model_ls_multigroup = SemEnsemble( - model_ls_g1, - model_ls_g2; - groups = [:Pasteur, :Grant_White], - optimizer = semoptimizer, +model_ls_multigroup = Sem( + SemWLS(obs_g1, 
RAMSymbolic(specification_g1, vech = true)), + SemWLS(obs_g2, RAMSymbolic(specification_g2, vech = true)), ) @testset "ls_gradients_multigroup" begin @@ -265,7 +215,7 @@ model_ls_multigroup = SemEnsemble( end @testset "ls_solution_multigroup" begin - solution = fit(model_ls_multigroup) + solution = fit(semoptimizer, model_ls_multigroup) update_estimate!(partable, solution) test_estimates( partable, @@ -276,7 +226,7 @@ end end @testset "fitmeasures/se_ls" begin - solution_ls = fit(model_ls_multigroup) + solution_ls = fit(semoptimizer, model_ls_multigroup) test_fitmeasures( fit_measures(solution_ls), solution_lav[:fitmeasures_ls]; @@ -308,40 +258,27 @@ end ############################################################################################ if !isnothing(specification_miss_g1) - model_g1 = Sem( - specification = specification_miss_g1, - observed = SemObservedMissing, - loss = SemFIML, - data = dat_miss_g1, - implied = RAM, - meanstructure = true, - ) - - model_g2 = Sem( - specification = specification_miss_g2, - observed = SemObservedMissing, - loss = SemFIML, - data = dat_miss_g2, - implied = RAM, - meanstructure = true, - ) - - model_ml_multigroup = SemEnsemble(model_g1, model_g2) - model_ml_multigroup2 = SemEnsemble( - specification = partable_miss, - data = dat_missing, - column = :school, - groups = [:Pasteur, :Grant_White], - loss = SemFIML, - observed = SemObservedMissing, - meanstructure = true, + model_ml_multigroup = Sem( + SemFIML( + SemObservedMissing( + data = dat_miss_g1, + observed_vars = SEM.observed_vars(specification_miss_g1), + ), + RAM(specification_miss_g1), + ), + SemFIML( + SemObservedMissing( + data = dat_miss_g2, + observed_vars = SEM.observed_vars(specification_miss_g2), + ), + RAM(specification_miss_g2), + ), ) - model_ml_varonly = SemEnsemble( + model_ml_varonly = Sem( specification = partable_varonly, data = dat_missing, - column = :school, - groups = [:Pasteur, :Grant_White], + semterm_column = :school, loss = SemFIML, observed = 
SemObservedMissing, meanstructure = true, @@ -373,7 +310,6 @@ if !isnothing(specification_miss_g1) @testset "fiml_gradients_multigroup" begin test_gradient(model_ml_multigroup, start_test; atol = 1e-7) - test_gradient(model_ml_multigroup2, start_test; atol = 1e-7) end @testset "fiml_solution_multigroup" begin @@ -385,14 +321,6 @@ if !isnothing(specification_miss_g1) atol = 1e-4, lav_groups = Dict(:Pasteur => 1, :Grant_White => 2), ) - solution = fit(semoptimizer, model_ml_multigroup2) - update_estimate!(partable_miss, solution) - test_estimates( - partable_miss, - solution_lav[:parameter_estimates_fiml]; - atol = 1e-4, - lav_groups = Dict(:Pasteur => 1, :Grant_White => 2), - ) end @testset "fitmeasures/se_fiml" begin diff --git a/test/examples/political_democracy/by_parts.jl b/test/examples/political_democracy/by_parts.jl index d2d468a9..6866eead 100644 --- a/test/examples/political_democracy/by_parts.jl +++ b/test/examples/political_democracy/by_parts.jl @@ -2,103 +2,83 @@ ### models w.o. 
meanstructure ############################################################################################ -# observed --------------------------------------------------------------------------------- -observed = SemObservedData(specification = spec, data = dat) +semoptimizer = SemOptimizer(engine = opt_engine) -# implied -implied_ram = RAM(specification = spec) +model_ml = Sem(specification = spec, data = dat) +@test SEM.params(model_ml) == SEM.params(spec) -implied_ram_sym = RAMSymbolic(specification = spec) +model_ls_sym = + Sem(specification = spec, data = dat, implied = RAMSymbolic, vech = true, loss = SemWLS) -# loss functions --------------------------------------------------------------------------- -ml = SemML(observed = observed) +model_ml_sym = Sem(specification = spec, data = dat, implied = RAMSymbolic) -wls = SemWLS(observed = observed) - -ridge = SemRidge(α_ridge = 0.001, which_ridge = 16:20, nparams = 31) - -constant = SemConstant(constant_loss = 3.465) - -# loss ------------------------------------------------------------------------------------- -loss_ml = SemLoss(ml) - -loss_wls = SemLoss(wls) - -# optimizer ------------------------------------------------------------------------------------- -optimizer_obj = SemOptimizer(engine = opt_engine) - -# models ----------------------------------------------------------------------------------- - -model_ml = Sem(observed, implied_ram, loss_ml) - -model_ls_sym = Sem(observed, RAMSymbolic(specification = spec, vech = true), loss_wls) - -model_ml_sym = Sem(observed, implied_ram_sym, loss_ml) - -model_ridge = Sem(observed, implied_ram, SemLoss(ml, ridge)) +model_ml_ridge = Sem( + specification = spec, + data = dat, + loss = (SemML, SemRidge), + α_ridge = 0.001, + which_ridge = 16:20, +) -model_constant = Sem(observed, implied_ram, SemLoss(ml, constant)) +model_ml_const = Sem( + specification = spec, + data = dat, + loss = (SemML, SemConstant), + constant_loss = 3.465, +) -model_ml_weighted = - Sem(observed, 
implied_ram, SemLoss(ml; loss_weights = [nsamples(model_ml)])) +model_ml_weighted = Sem(SemML(SemObservedData(data = dat), RAM(spec)) => nsamples(model_ml)) ############################################################################################ ### test gradients ############################################################################################ -models = - [model_ml, model_ls_sym, model_ridge, model_constant, model_ml_sym, model_ml_weighted] -model_names = ["ml", "ls_sym", "ridge", "constant", "ml_sym", "ml_weighted"] +models = Dict( + "ml" => model_ml, + "ls_sym" => model_ls_sym, + "ml_ridge" => model_ml_ridge, + "ml_const" => model_ml_const, + "ml_sym" => model_ml_sym, + "ml_weighted" => model_ml_weighted, +) -for (model, name) in zip(models, model_names) - try - @testset "$(name)_gradient" begin - test_gradient(model, start_test; rtol = 1e-9) - end - catch - end +@testset "$(id)_gradient" for (id, model) in pairs(models) + test_gradient(model, start_test; rtol = 1e-9) end ############################################################################################ ### test solution ############################################################################################ -models = [model_ml, model_ls_sym, model_ml_sym, model_constant] -model_names = ["ml", "ls_sym", "ml_sym", "constant"] -solution_names = Symbol.("parameter_estimates_" .* ["ml", "ls", "ml", "ml"]) - -for (model, name, solution_name) in zip(models, model_names, solution_names) - try - @testset "$(name)_solution" begin - solution = fit(optimizer_obj, model) - update_estimate!(partable, solution) - test_estimates(partable, solution_lav[solution_name]; atol = 1e-2) - end - catch - end +@testset "$(id)_solution" for id in ["ml", "ls_sym", "ml_sym", "ml_const"] + model = models[id] + solution = fit(semoptimizer, model) + sol_name = Symbol("parameter_estimates_", replace(id, r"_.+$" => "")) + update_estimate!(partable, solution) + test_estimates(partable, solution_lav[sol_name]; atol = 
1e-2) end @testset "ridge_solution" begin - solution_ridge = fit(optimizer_obj, model_ridge) - solution_ml = fit(optimizer_obj, model_ml) - # solution_ridge_id = fit(optimizer_obj, model_ridge_id) - @test solution_ridge.minimum < solution_ml.minimum + 1 + solution_ridge = fit(semoptimizer, model_ml_ridge) + solution_ml = fit(semoptimizer, model_ml) + # solution_ridge_id = fit(model_ridge_id) + @test abs(solution_ridge.minimum - solution_ml.minimum) < 1 end # test constant objective value @testset "constant_objective_and_gradient" begin - @test (objective!(model_constant, start_test) - 3.465) ≈ + @test (objective!(model_ml_const, start_test) - 3.465) ≈ objective!(model_ml, start_test) grad = similar(start_test) grad2 = similar(start_test) - gradient!(grad, model_constant, start_test) + gradient!(grad, model_ml_const, start_test) gradient!(grad2, model_ml, start_test) @test grad ≈ grad2 end @testset "ml_solution_weighted" begin - solution_ml = fit(optimizer_obj, model_ml) - solution_ml_weighted = fit(optimizer_obj, model_ml_weighted) + solution_ml = fit(semoptimizer, model_ml) + solution_ml_weighted = fit(semoptimizer, model_ml_weighted) @test solution(solution_ml) ≈ solution(solution_ml_weighted) rtol = 1e-3 @test nsamples(model_ml) * StructuralEquationModels.minimum(solution_ml) ≈ StructuralEquationModels.minimum(solution_ml_weighted) rtol = 1e-6 @@ -109,7 +89,7 @@ end ############################################################################################ @testset "fitmeasures/se_ml" begin - solution_ml = fit(optimizer_obj, model_ml) + solution_ml = fit(semoptimizer, model_ml) test_fitmeasures(fit_measures(solution_ml), solution_lav[:fitmeasures_ml]; atol = 1e-3) test_fitmeasures( Dict(:CFI => CFI(solution_ml)), @@ -128,20 +108,15 @@ end end @testset "fitmeasures/se_ls" begin - solution_ls = fit(optimizer_obj, model_ls_sym) + solution_ls = fit(semoptimizer, model_ls_sym) fm = fit_measures(solution_ls) test_fitmeasures( - fm, + merge(fm, Dict(:CFI => 
CFI(solution_ls))), solution_lav[:fitmeasures_ls]; atol = 1e-3, - fitmeasure_names = fitmeasure_names_ls, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ls)), - solution_lav[:fitmeasures_ls]; - fitmeasure_names = Dict(:CFI => "cfi"), + fitmeasure_names = merge(fitmeasure_names_ls, Dict(:CFI => "cfi")) ) - @test (fm[:AIC] === missing) & (fm[:BIC] === missing) & (fm[:minus2ll] === missing) + @test ismissing(fm[:AIC]) && ismissing(fm[:BIC]) && ismissing(fm[:minus2ll]) @suppress update_se_hessian!(partable, solution_ls) test_estimates( @@ -160,22 +135,22 @@ end if opt_engine == :Optim using Optim, LineSearches - optimizer_obj = SemOptimizer( - engine = opt_engine, - algorithm = Newton(; - linesearch = BackTracking(order = 3), - alphaguess = InitialHagerZhang(), - ), + model_ls = Sem( + data = dat, + specification = spec, + implied = RAMSymbolic, + loss = SemWLS, + vech = true, + hessian = true, ) - implied_sym_hessian_vech = - RAMSymbolic(specification = spec, vech = true, hessian = true) - - implied_sym_hessian = RAMSymbolic(specification = spec, hessian = true) - - model_ls = Sem(observed, implied_sym_hessian_vech, loss_wls) - - model_ml = Sem(observed, implied_sym_hessian, loss_ml) + model_ml = Sem( + data = dat, + specification = spec, + implied = RAMSymbolic, + loss = SemML, + hessian = true, + ) @testset "ml_hessians" begin test_hessian(model_ml, start_test; atol = 1e-4) @@ -186,13 +161,23 @@ if opt_engine == :Optim end @testset "ml_solution_hessian" begin - solution = fit(optimizer_obj, model_ml) + solution = fit(SemOptimizer(engine = :Optim, algorithm = Newton()), model_ml) + update_estimate!(partable, solution) test_estimates(partable, solution_lav[:parameter_estimates_ml]; atol = 1e-2) end @testset "ls_solution_hessian" begin - solution = fit(optimizer_obj, model_ls) + solution = fit( + SemOptimizer( + engine = :Optim, + algorithm = Newton( + linesearch = BackTracking(order = 3), + alphaguess = InitialHagerZhang(), + ), + ), + model_ls, + ) 
update_estimate!(partable, solution) test_estimates( partable, @@ -207,69 +192,47 @@ end ### meanstructure ############################################################################################ -# observed --------------------------------------------------------------------------------- -observed = SemObservedData(specification = spec_mean, data = dat, meanstructure = true) - -# implied -implied_ram = RAM(specification = spec_mean, meanstructure = true) - -implied_ram_sym = RAMSymbolic(specification = spec_mean, meanstructure = true) - -# loss functions --------------------------------------------------------------------------- -ml = SemML(observed = observed, meanstructure = true) - -wls = SemWLS(observed = observed, meanstructure = true) - -# loss ------------------------------------------------------------------------------------- -loss_ml = SemLoss(ml) - -loss_wls = SemLoss(wls) - -# optimizer ------------------------------------------------------------------------------------- -optimizer_obj = SemOptimizer(engine = opt_engine) +# models +model_ls = Sem( + data = dat, + specification = spec_mean, + implied = RAMSymbolic, + loss = SemWLS, + vech = true, +) -# models ----------------------------------------------------------------------------------- -model_ml = Sem(observed, implied_ram, loss_ml) +model_ml = Sem(data = dat, specification = spec_mean, implied = RAM, loss = SemML) -model_ls = Sem( - observed, - RAMSymbolic(specification = spec_mean, meanstructure = true, vech = true), - loss_wls, +model_ml_cov = Sem( + specification = spec, + observed = SemObservedCovariance, + obs_cov = cov(Matrix(dat)), + observed_vars = Symbol.(names(dat)), + nsamples = 75, ) -model_ml_sym = Sem(observed, implied_ram_sym, loss_ml) +model_ml_sym = + Sem(data = dat, specification = spec_mean, implied = RAMSymbolic, loss = SemML) ############################################################################################ ### test gradients 
############################################################################################ -models = [model_ml, model_ls, model_ml_sym] -model_names = ["ml", "ls_sym", "ml_sym"] +models = Dict("ml" => model_ml, "ls_sym" => model_ls, "ml_sym" => model_ml_sym) -for (model, name) in zip(models, model_names) - try - @testset "$(name)_gradient_mean" begin - test_gradient(model, start_test_mean; rtol = 1e-9) - end - catch - end +@testset "$(id)_gradient_mean" for (id, model) in pairs(models) + test_gradient(model, start_test_mean; rtol = 1e-9) end ############################################################################################ ### test solution ############################################################################################ -solution_names = Symbol.("parameter_estimates_" .* ["ml", "ls", "ml"] .* "_mean") - -for (model, name, solution_name) in zip(models, model_names, solution_names) - try - @testset "$(name)_solution_mean" begin - solution = fit(optimizer_obj, model) - update_estimate!(partable_mean, solution) - test_estimates(partable_mean, solution_lav[solution_name]; atol = 1e-2) - end - catch - end +@testset "$(id)_solution_mean" for (id, model) in pairs(models) + solution = fit(semoptimizer, model, start_val = start_test_mean) + update_estimate!(partable_mean, solution) + sol_name = Symbol("parameter_estimates_", replace(id, r"_.+$" => ""), "_mean") + test_estimates(partable_mean, solution_lav[sol_name]; atol = 1e-2) end ############################################################################################ @@ -277,7 +240,7 @@ end ############################################################################################ @testset "fitmeasures/se_ml_mean" begin - solution_ml = fit(optimizer_obj, model_ml) + solution_ml = fit(semoptimizer, model_ml) test_fitmeasures( fit_measures(solution_ml), solution_lav[:fitmeasures_ml_mean]; @@ -300,20 +263,15 @@ end end @testset "fitmeasures/se_ls_mean" begin - solution_ls = fit(optimizer_obj, 
model_ls) + solution_ls = fit(semoptimizer, model_ls) fm = fit_measures(solution_ls) test_fitmeasures( - fm, + merge(fm, Dict(:CFI => CFI(solution_ls))), solution_lav[:fitmeasures_ls_mean]; atol = 1e-3, - fitmeasure_names = fitmeasure_names_ls, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ls)), - solution_lav[:fitmeasures_ls_mean]; - fitmeasure_names = Dict(:CFI => "cfi"), + fitmeasure_names = merge(fitmeasure_names_ls, Dict(:CFI => "cfi")), ) - @test (fm[:AIC] === missing) & (fm[:BIC] === missing) & (fm[:minus2ll] === missing) + @test ismissing(fm[:AIC]) && ismissing(fm[:BIC]) && ismissing(fm[:minus2ll]) @suppress update_se_hessian!(partable_mean, solution_ls) test_estimates( @@ -329,16 +287,22 @@ end ### fiml ############################################################################################ -observed = - SemObservedMissing(specification = spec_mean, data = dat_missing, rtol_em = 1e-10) - -fiml = SemFIML(observed = observed, specification = spec_mean) - -loss_fiml = SemLoss(fiml) - -model_ml = Sem(observed, implied_ram, loss_fiml) +# models +model_ml = Sem( + data = dat_missing, + observed = SemObservedMissing, + specification = spec_mean, + implied = RAM, + loss = SemFIML, +) -model_ml_sym = Sem(observed, implied_ram_sym, loss_fiml) +model_ml_sym = Sem( + data = dat_missing, + observed = SemObservedMissing, + specification = spec_mean, + implied = RAMSymbolic, + loss = SemFIML, +) ############################################################################################ ### test gradients @@ -357,13 +321,13 @@ end ############################################################################################ @testset "fiml_solution" begin - solution = fit(optimizer_obj, model_ml) + solution = fit(semoptimizer, model_ml) update_estimate!(partable_mean, solution) test_estimates(partable_mean, solution_lav[:parameter_estimates_fiml]; atol = 1e-2) end @testset "fiml_solution_symbolic" begin - solution = fit(optimizer_obj, model_ml_sym) + solution = 
fit(semoptimizer, model_ml_sym, start_val = start_test_mean) update_estimate!(partable_mean, solution) test_estimates(partable_mean, solution_lav[:parameter_estimates_fiml]; atol = 1e-2) end @@ -373,7 +337,7 @@ end ############################################################################################ @testset "fitmeasures/se_fiml" begin - solution_ml = fit(optimizer_obj, model_ml) + solution_ml = fit(semoptimizer, model_ml) test_fitmeasures( fit_measures(solution_ml), solution_lav[:fitmeasures_fiml]; @@ -384,7 +348,7 @@ end test_estimates( partable_mean, solution_lav[:parameter_estimates_fiml]; - atol = 1e-3, + atol = 0.002, col = :se, lav_col = :se, ) diff --git a/test/examples/political_democracy/constraints.jl b/test/examples/political_democracy/constraints.jl index 7a6670fa..0291e7ea 100644 --- a/test/examples/political_democracy/constraints.jl +++ b/test/examples/political_democracy/constraints.jl @@ -50,7 +50,7 @@ end @test solution_constrained.solution[31] * solution_constrained.solution[30] >= (0.6 - 1e-8) - @test all(abs.(solution_constrained.solution) .< 10) - @test solution_constrained.optimization_result.result[3] == :FTOL_REACHED + @test all(p -> abs(p) < 10, solution_constrained.solution) + @test solution_constrained.optimization_result.result[3] == :FTOL_REACHED skip = true @test solution_constrained.minimum <= 21.21 + 0.01 end diff --git a/test/examples/political_democracy/constructor.jl b/test/examples/political_democracy/constructor.jl index 1c1c42e5..25a6da91 100644 --- a/test/examples/political_democracy/constructor.jl +++ b/test/examples/political_democracy/constructor.jl @@ -5,21 +5,22 @@ semoptimizer = SemOptimizer(engine = opt_engine) model_ml = Sem(specification = spec, data = dat) -@test SEM.param_labels(model_ml.implied.ram_matrices) == SEM.param_labels(spec) +@test SEM.param_labels(model_ml) == SEM.param_labels(spec) model_ml_cov = Sem( specification = spec, observed = SemObservedCovariance, obs_cov = cov(Matrix(dat)), - 
obs_colnames = Symbol.(names(dat)), + observed_vars = Symbol.(names(dat)), nsamples = 75, ) -model_ls_sym = Sem(specification = spec, data = dat, implied = RAMSymbolic, loss = SemWLS) +model_ls_sym = + Sem(specification = spec, data = dat, implied = RAMSymbolic, vech = true, loss = SemWLS) model_ml_sym = Sem(specification = spec, data = dat, implied = RAMSymbolic) -model_ridge = Sem( +model_ml_ridge = Sem( specification = spec, data = dat, loss = (SemML, SemRidge), @@ -27,7 +28,7 @@ model_ridge = Sem( which_ridge = 16:20, ) -model_constant = Sem( +model_ml_const = Sem( specification = spec, data = dat, loss = (SemML, SemConstant), @@ -35,65 +36,52 @@ model_constant = Sem( ) model_ml_weighted = - Sem(specification = partable, data = dat, loss_weights = (nsamples(model_ml),)) + Sem(SemML(SemObservedData(data = dat), RAMSymbolic(spec)) => nsamples(model_ml)) ############################################################################################ ### test gradients ############################################################################################ -models = [ - model_ml, - model_ml_cov, - model_ls_sym, - model_ridge, - model_constant, - model_ml_sym, - model_ml_weighted, -] -model_names = ["ml", "ml_cov", "ls_sym", "ridge", "constant", "ml_sym", "ml_weighted"] - -for (model, name) in zip(models, model_names) - try - @testset "$(name)_gradient" begin - test_gradient(model, start_test; rtol = 1e-9) - end - catch - end +models = Dict( + "ml" => model_ml, + "ml_cov" => model_ml_cov, + "ls_sym" => model_ls_sym, + "ridge" => model_ml_ridge, + "ml_const" => model_ml_const, + "ml_sym" => model_ml_sym, + "ml_weighted" => model_ml_weighted, +) + +@testset "$(id)_gradient" for (id, model) in pairs(models) + test_gradient(model, start_test; rtol = 1e-9) end ############################################################################################ ### test solution ############################################################################################ -models = 
[model_ml, model_ml_cov, model_ls_sym, model_ml_sym, model_constant] -model_names = ["ml", "ml_cov", "ls_sym", "ml_sym", "constant"] -solution_names = Symbol.("parameter_estimates_" .* ["ml", "ml", "ls", "ml", "ml"]) - -for (model, name, solution_name) in zip(models, model_names, solution_names) - try - @testset "$(name)_solution" begin - solution = fit(semoptimizer, model) - update_estimate!(partable, solution) - test_estimates(partable, solution_lav[solution_name]; atol = 1e-2) - end - catch - end +@testset "$(id)_solution" for id in ["ml", "ml_cov", "ls_sym", "ml_sym", "ml_const"] + model = models[id] + solution = fit(semoptimizer, model) + sol_name = Symbol("parameter_estimates_", replace(id, r"_.+$" => "")) + update_estimate!(partable, solution) + test_estimates(partable, solution_lav[sol_name]; atol = 1e-2) end @testset "ridge_solution" begin - solution_ridge = fit(semoptimizer, model_ridge) + solution_ridge = fit(semoptimizer, model_ml_ridge) solution_ml = fit(semoptimizer, model_ml) - # solution_ridge_id = fit(semoptimizer, model_ridge_id) + # solution_ridge_id = fit(model_ridge_id) @test abs(solution_ridge.minimum - solution_ml.minimum) < 1 end # test constant objective value @testset "constant_objective_and_gradient" begin - @test (objective!(model_constant, start_test) - 3.465) ≈ + @test (objective!(model_ml_const, start_test) - 3.465) ≈ objective!(model_ml, start_test) grad = similar(start_test) grad2 = similar(start_test) - gradient!(grad, model_constant, start_test) + gradient!(grad, model_ml_const, start_test) gradient!(grad2, model_ml, start_test) @test grad ≈ grad2 end @@ -101,12 +89,9 @@ end @testset "ml_solution_weighted" begin solution_ml = fit(semoptimizer, model_ml) solution_ml_weighted = fit(semoptimizer, model_ml_weighted) - @test isapprox(solution(solution_ml), solution(solution_ml_weighted), rtol = 1e-3) - @test isapprox( - nsamples(model_ml) * StructuralEquationModels.minimum(solution_ml), - 
StructuralEquationModels.minimum(solution_ml_weighted), - rtol = 1e-6, - ) + @test solution(solution_ml) ≈ solution(solution_ml_weighted) rtol = 1e-3 + @test nsamples(model_ml) * StructuralEquationModels.minimum(solution_ml) ≈ + StructuralEquationModels.minimum(solution_ml_weighted) rtol = 1e-6 end ############################################################################################ @@ -181,19 +166,14 @@ end ) # set seed for simulation Random.seed!(83472834) - colnames = Symbol.(names(example_data("political_democracy"))) # simulate data model_ml_new = replace_observed( model_ml, - data = rand(model_ml, params, 1_000_000), - specification = spec, - obs_colnames = colnames, + rand(model_ml, params, 1_000_000), ) model_ml_sym_new = replace_observed( model_ml_sym, - data = rand(model_ml_sym, params, 1_000_000), - specification = spec, - obs_colnames = colnames, + rand(model_ml_sym, params, 1_000_000), ) # fit models sol_ml = solution(fit(semoptimizer, model_ml_new)) @@ -211,23 +191,19 @@ if opt_engine == :Optim using Optim, LineSearches model_ls = Sem( - specification = spec, data = dat, - implied = RAMSymbolic, + specification = spec, + observed = SemObservedData, + implied = RAMSymbolic(spec, vech = true, hessian = true), loss = SemWLS, - hessian = true, - algorithm = Newton(; - linesearch = BackTracking(order = 3), - alphaguess = InitialHagerZhang(), - ), ) model_ml = Sem( - specification = spec, data = dat, - implied = RAMSymbolic, - hessian = true, - algorithm = Newton(), + specification = spec, + observed = SemObservedData, + implied = RAMSymbolic(spec, hessian = true), + loss = SemML, ) @testset "ml_hessians" begin @@ -239,13 +215,23 @@ if opt_engine == :Optim end @testset "ml_solution_hessian" begin - solution = fit(semoptimizer, model_ml) + solution = fit(SemOptimizer(engine = :Optim, algorithm = Newton()), model_ml) + update_estimate!(partable, solution) test_estimates(partable, solution_lav[:parameter_estimates_ml]; atol = 1e-2) end @testset 
"ls_solution_hessian" begin - solution = fit(semoptimizer, model_ls) + solution = fit( + SemOptimizer( + engine = :Optim, + algorithm = Newton( + linesearch = BackTracking(order = 3), + alphaguess = InitialHagerZhang(), + ), + ), + model_ls, + ) update_estimate!(partable, solution) test_estimates( partable, @@ -266,6 +252,7 @@ model_ls = Sem( specification = spec_mean, data = dat, implied = RAMSymbolic, + vech = true, loss = SemWLS, meanstructure = true, ) @@ -277,7 +264,7 @@ model_ml_cov = Sem( observed = SemObservedCovariance, obs_cov = cov(Matrix(dat)), obs_mean = vcat(mean(Matrix(dat), dims = 1)...), - obs_colnames = Symbol.(names(dat)), + observed_vars = Symbol.(names(dat)), meanstructure = true, nsamples = 75, ) @@ -289,33 +276,26 @@ model_ml_sym = ### test gradients ############################################################################################ -models = [model_ml, model_ml_cov, model_ls, model_ml_sym] -model_names = ["ml", "ml_cov", "ls_sym", "ml_sym"] +models = Dict( + "ml" => model_ml, + "ml_cov" => model_ml_cov, + "ls_sym" => model_ls, + "ml_sym" => model_ml_sym, +) -for (model, name) in zip(models, model_names) - try - @testset "$(name)_gradient_mean" begin - test_gradient(model, start_test_mean; rtol = 1e-9) - end - catch - end +@testset "$(id)_gradient_mean" for (id, model) in pairs(models) + test_gradient(model, start_test_mean; rtol = 1e-9) end ############################################################################################ ### test solution ############################################################################################ -solution_names = Symbol.("parameter_estimates_" .* ["ml", "ml", "ls", "ml"] .* "_mean") - -for (model, name, solution_name) in zip(models, model_names, solution_names) - try - @testset "$(name)_solution_mean" begin - solution = fit(semoptimizer, model) - update_estimate!(partable_mean, solution) - test_estimates(partable_mean, solution_lav[solution_name]; atol = 1e-2) - end - catch - end 
+@testset "$(id)_solution_mean" for (id, model) in pairs(models) + solution = fit(semoptimizer, model, start_val = start_test_mean) + update_estimate!(partable_mean, solution) + sol_name = Symbol("parameter_estimates_", replace(id, r"_.+$" => ""), "_mean") + test_estimates(partable_mean, solution_lav[sol_name]; atol = 1e-2) end ############################################################################################ @@ -395,21 +375,14 @@ end ) # set seed for simulation Random.seed!(83472834) - colnames = Symbol.(names(example_data("political_democracy"))) # simulate data model_ml_new = replace_observed( model_ml, - data = rand(model_ml, params, 1_000_000), - specification = spec, - obs_colnames = colnames, - meanstructure = true, + rand(model_ml, params, 1_000_000), ) model_ml_sym_new = replace_observed( model_ml_sym, - data = rand(model_ml_sym, params, 1_000_000), - specification = spec, - obs_colnames = colnames, - meanstructure = true, + rand(model_ml_sym, params, 1_000_000), ) # fit models sol_ml = solution(fit(semoptimizer, model_ml_new)) @@ -474,7 +447,7 @@ end end @testset "fiml_solution_symbolic" begin - solution = fit(semoptimizer, model_ml_sym) + solution = fit(semoptimizer, model_ml_sym, start_val = start_test_mean) update_estimate!(partable_mean, solution) test_estimates(partable_mean, solution_lav[:parameter_estimates_fiml]; atol = 1e-2) end diff --git a/test/examples/recover_parameters/recover_parameters_twofact.jl b/test/examples/recover_parameters/recover_parameters_twofact.jl index a4bd7d5f..ebaaae83 100644 --- a/test/examples/recover_parameters/recover_parameters_twofact.jl +++ b/test/examples/recover_parameters/recover_parameters_twofact.jl @@ -1,5 +1,7 @@ using StructuralEquationModels, Distributions, Random, Optim, LineSearches +SEM = StructuralEquationModels + include( joinpath( chop(dirname(pathof(StructuralEquationModels)), tail = 3), @@ -7,7 +9,7 @@ include( ), ) -x = Symbol.("x", 1:13) +pars = Symbol.("x", 1:13) S = [ :x1 0 0 0 0 0 0 0 
@@ -40,7 +42,7 @@ A = [ 0 0 0 0 0 0 0 0 ] -ram_matrices = RAMMatrices(; A = A, S = S, F = F, param_labels = x, vars = nothing) +ram_matrices = RAMMatrices(; A = A, S = S, F = F, param_labels = pars, vars = nothing) true_val = [ repeat([1], 8) @@ -53,19 +55,19 @@ start = [ repeat([0.5], 4) ] -implied_ml = RAMSymbolic(ram_matrices; start_val = start) +implied_sym = RAMSymbolic(ram_matrices) -implied_ml.Σ_eval!(implied_ml.Σ, true_val) +implied_sym.Σ_eval!(implied_sym.Σ, true_val) -true_dist = MultivariateNormal(implied_ml.Σ) +true_dist = MultivariateNormal(implied_sym.Σ) Random.seed!(1234) -x = transpose(rand(true_dist, 100_000)) -semobserved = SemObservedData(data = x, specification = nothing) +x = permutedims(rand(true_dist, 10^5), (2, 1)) + +observed = SemObservedData(data = x, specification = ram_matrices) -loss_ml = SemLoss(SemML(; observed = semobserved, nparams = length(start))) +model_ml = Sem(SemML(observed, implied_sym)) -model_ml = Sem(semobserved, implied_ml, loss_ml) objective!(model_ml, true_val) optimizer = SemOptimizer( @@ -73,6 +75,6 @@ optimizer = SemOptimizer( Optim.Options(; f_reltol = 1e-10, x_abstol = 1.5e-8), ) -solution_ml = fit(optimizer, model_ml) +solution_ml = fit(optimizer, model_ml, start_val = start) -@test true_val ≈ solution(solution_ml) atol = 0.05 +@test solution(solution_ml) ≈ true_val atol = 0.05 diff --git a/test/unit_tests/model.jl b/test/unit_tests/model.jl index fbe2a937..87812fba 100644 --- a/test/unit_tests/model.jl +++ b/test/unit_tests/model.jl @@ -68,9 +68,8 @@ end test_vars_api(implied(model), ram_matrices) test_params_api(implied(model), ram_matrices) - @test @inferred(loss(model)) isa SemLoss - semloss = loss(model).functions[1] - @test semloss isa SemML + @test @inferred(sem_term(model)) isa SemLoss + @test sem_term(model) isa losstype @test @inferred(nsamples(model)) == nsamples(obs) end From bab1317c939a3d33bfc44b5e74f7f3fb54a9c697 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 11:17:31 -0700 
Subject: [PATCH 05/74] params/param_labels(): use both as synonyms for now --- src/frontend/specification/Sem.jl | 1 + src/implied/abstract.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index 684cfa62..d89606b6 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -183,6 +183,7 @@ end ############################################################################################ params(model::AbstractSem) = model.params +param_labels(model::AbstractSem) = params(model) # alias """ loss_terms(model::AbstractSem) diff --git a/src/implied/abstract.jl b/src/implied/abstract.jl index d4868d74..e41e79f6 100644 --- a/src/implied/abstract.jl +++ b/src/implied/abstract.jl @@ -8,6 +8,7 @@ nobserved_vars(implied::SemImplied) = nobserved_vars(implied.ram_matrices) nlatent_vars(implied::SemImplied) = nlatent_vars(implied.ram_matrices) param_labels(implied::SemImplied) = param_labels(implied.ram_matrices) +params(implied::SemImplied) = param_labels(implied) nparams(implied::SemImplied) = nparams(implied.ram_matrices) # checks if the A matrix is acyclic From f7f74520ea765ebea72d5b54098e1ab31f7bda7d Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 11:17:31 -0700 Subject: [PATCH 06/74] check_same_semterm_type(): refactor check_single_lossfun() --- src/additional_functions/helper.jl | 29 --------------- src/frontend/fit/fitmeasures/RMSEA.jl | 2 +- src/frontend/fit/fitmeasures/chi2.jl | 17 ++------- src/frontend/fit/fitmeasures/minus2ll.jl | 2 +- src/frontend/specification/Sem.jl | 47 +++++++++++++++++++++++- 5 files changed, 50 insertions(+), 47 deletions(-) diff --git a/src/additional_functions/helper.jl b/src/additional_functions/helper.jl index b3f5212b..8f2342c3 100644 --- a/src/additional_functions/helper.jl +++ b/src/additional_functions/helper.jl @@ -116,35 +116,6 @@ function nonunique(values::AbstractVector) return res end -# check that a model 
only has a single lossfun -function check_single_lossfun(model::AbstractSemSingle; throw_error) - if (length(model.loss.functions) > 1) & throw_error - @error "The model has $(length(sem.loss.functions)) loss functions. - Only a single loss function is supported." - end - return isone(length(model.loss.functions)) -end - -# check that all models use the same single loss function -function check_single_lossfun(models::AbstractSemSingle...; throw_error) - uniform = true - lossfun = models[1].loss.functions[1] - L = typeof(lossfun) - for (i, model) in enumerate(models) - uniform &= check_single_lossfun(model; throw_error = throw_error) - cur_lossfun = model.loss.functions[1] - if !isa(cur_lossfun, L) & throw_error - @error "Loss function for group #$i model is $(typeof(cur_lossfun)), expected $L. - Heterogeneous loss functions are not supported." - end - uniform &= isa(cur_lossfun, L) - end - return uniform -end - -check_single_lossfun(model::SemEnsemble; throw_error) = - check_single_lossfun(model.sems...; throw_error) - # scaling corrections for multigroup models mg_correction(::SemFIML) = 0 mg_correction(::SemML) = 0 diff --git a/src/frontend/fit/fitmeasures/RMSEA.jl b/src/frontend/fit/fitmeasures/RMSEA.jl index 7406b74c..ac2d890d 100644 --- a/src/frontend/fit/fitmeasures/RMSEA.jl +++ b/src/frontend/fit/fitmeasures/RMSEA.jl @@ -27,7 +27,7 @@ RMSEA_corr_scale(::Type{<:SemML}) = -1 RMSEA_corr_scale(::Type{<:SemWLS}) = -1 function RMSEA(fit::SemFit, model::AbstractSem) - term_type = check_single_lossfun(model; throw_error = true) + term_type = check_same_semterm_type(model; throw_error = true) n = nsamples(fit) + nsem_terms(model) * RMSEA_corr_scale(term_type) sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n) end diff --git a/src/frontend/fit/fitmeasures/chi2.jl b/src/frontend/fit/fitmeasures/chi2.jl index 22d6c2e2..c56b9a2a 100644 --- a/src/frontend/fit/fitmeasures/chi2.jl +++ b/src/frontend/fit/fitmeasures/chi2.jl @@ -14,21 +14,10 @@ with the *observed* 
covariance matrix. function χ²(fit::SemFit, model::AbstractSem) terms = sem_terms(model) - isempty(terms) && return 0.0 + @assert !isempty(terms) - term1 = _unwrap(loss(terms[1])) - L = typeof(term1).name - - # check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams - for (i, term) in enumerate(terms) - lossterm = _unwrap(loss(term)) - @assert lossterm isa SemLoss - if typeof(_unwrap(lossterm)).name != L - @error "SemLoss term #$i is $(typeof(_unwrap(lossterm)).name), expected $L. Heterogeneous loss functions are not supported" - end - end - - return χ²(typeof(term1), fit, model) + L = check_same_semterm_type(model; throw_error = true) + return χ²(L, fit, model) end # bollen, p. 115, only correct for GLS weight matrix diff --git a/src/frontend/fit/fitmeasures/minus2ll.jl b/src/frontend/fit/fitmeasures/minus2ll.jl index 3b353f5c..1cdf5c07 100644 --- a/src/frontend/fit/fitmeasures/minus2ll.jl +++ b/src/frontend/fit/fitmeasures/minus2ll.jl @@ -62,6 +62,6 @@ end ############################################################################################ function minus2ll(model::AbstractSem, fit::SemFit) - check_single_lossfun(model; throw_error = true) + check_same_semterm_type(model; throw_error = true) sum(Base.Fix2(minus2ll, fit) ∘ _unwrap ∘ loss, sem_terms(model)) end diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index d89606b6..c3e4bd2b 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -51,8 +51,8 @@ end function multigroup_weights(models, n) nsamples_total = sum(nsamples, models) - uniform_lossfun = check_single_lossfun(models...; throw_error = false) - if !uniform_lossfun + semloss_type = check_same_semterm_type(semterms; throw_error = false) + if isnothing(semloss_type) @info """ Your ensemble model contains heterogeneous loss functions. 
Default weights of (#samples per group/#total samples) will be used @@ -258,6 +258,49 @@ function sem_term(model::AbstractSem, _::Nothing = nothing) error("Unreachable reached") end +# check that all models use the same single loss function +# returns the type of the single SEM loss function, SemLoss if there are multiple different SEM losses, +# nothing if there are no SEM terms. +# If throw_error=true, throws an error if there are multiple different SEM loss functions +check_same_semterm_type(model::AbstractSem; throw_error::Bool = true) = + check_same_semterm_type(sem_terms(model); throw_error = throw_error) + +# check that all models use the same single loss function +# returns the type of the single SEM loss function, +# nothing if there are multiple different SEM losses or no SEM terms. +# If throw_error=true, throws an error if there are multiple different SEM loss functions +function check_same_semterm_type(terms::Tuple; throw_error::Bool = true) + isempty(terms) && return nothing + + _semloss(term::SemLoss) = _unwrap(term) + _semloss(term::LossTerm) = _semloss(loss(term)) + _semloss(term) = throw(ArgumentError("SemLoss term expected, $(typeof(term)) found")) + _semloss_label(i::Integer, _::Union{SemLoss, LossTerm{<:SemLoss, Nothing}}) = "#$i" + _semloss_label(i::Integer, term::LossTerm{<:SemLoss, Symbol}) = "#$i ($(SEM.id(term)))" + + term1 = _semloss(terms[1]) + L = typeof(term1).name + + # check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams + for (i, term) in enumerate(terms) + lossterm = _semloss(term) + @assert lossterm isa SemLoss + if typeof(lossterm).name != L + if throw_error + error( + "SemLoss term $(_semloss_label(i, term)) is $(typeof(lossterm).name), expected $L. 
Heterogeneous loss functions are not supported", + ) + else + return nothing + end + end + end + + # return the type of the first SEM term + # note that type params of the SEM terms might be different + return typeof(term1) +end + # wrappers arounds a single SemLoss term observed(model::AbstractSem, id::Nothing = nothing) = observed(sem_term(model, id)) implied(model::AbstractSem, id::Nothing = nothing) = implied(sem_term(model, id)) From 961a3c8931941c7c5e0c150b95a74d05119da537 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 11:17:31 -0700 Subject: [PATCH 07/74] update multi-group correction deduplicate the correction scale methods and move to Sem.jl --- src/additional_functions/helper.jl | 5 --- src/frontend/fit/fitmeasures/RMSEA.jl | 7 +--- src/frontend/specification/Sem.jl | 46 +++++++++++++++++---------- 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/src/additional_functions/helper.jl b/src/additional_functions/helper.jl index 8f2342c3..5442357f 100644 --- a/src/additional_functions/helper.jl +++ b/src/additional_functions/helper.jl @@ -115,8 +115,3 @@ function nonunique(values::AbstractVector) end return res end - -# scaling corrections for multigroup models -mg_correction(::SemFIML) = 0 -mg_correction(::SemML) = 0 -mg_correction(::SemWLS) = -1 diff --git a/src/frontend/fit/fitmeasures/RMSEA.jl b/src/frontend/fit/fitmeasures/RMSEA.jl index ac2d890d..9d33e47e 100644 --- a/src/frontend/fit/fitmeasures/RMSEA.jl +++ b/src/frontend/fit/fitmeasures/RMSEA.jl @@ -21,14 +21,9 @@ For multigroup models, the correction proposed by J.H. 
Steiger is applied """ RMSEA(fit::SemFit) = RMSEA(fit, fit.model) -# scaling corrections -RMSEA_corr_scale(::Type{<:SemFIML}) = 0 -RMSEA_corr_scale(::Type{<:SemML}) = -1 -RMSEA_corr_scale(::Type{<:SemWLS}) = -1 - function RMSEA(fit::SemFit, model::AbstractSem) term_type = check_same_semterm_type(model; throw_error = true) - n = nsamples(fit) + nsem_terms(model) * RMSEA_corr_scale(term_type) + n = nsamples(fit) + nsem_terms(model) * multigroup_correction_scale(term_type) sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n) end diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index c3e4bd2b..d8696e82 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -45,32 +45,46 @@ function Base.show(io::IO, term::LossTerm) end end -############################################################################################ -# constructor for Sem types -############################################################################################ +# scaling corrections for multigroup models + +# fallback method for non-standard SemLoss type +multigroup_correction_scale(::Type{<:SemLoss}) = nothing -function multigroup_weights(models, n) - nsamples_total = sum(nsamples, models) +multigroup_correction_scale(::Type{<:SemFIML}) = 0 +multigroup_correction_scale(::Type{<:SemML}) = 0 +multigroup_correction_scale(::Type{<:SemWLS}) = -1 + +multigroup_correction_scale(loss::SemLoss) = multigroup_correction_scale(typeof(loss)) + +# calculate sem term weights for multigroup models +# correcting for the number of samples and the loss type +function multigroup_weights(semterms...) + n = length(semterms) + nsamples_total = sum(nsamples, semterms) semloss_type = check_same_semterm_type(semterms; throw_error = false) if isnothing(semloss_type) @info """ Your ensemble model contains heterogeneous loss functions. 
Default weights of (#samples per group/#total samples) will be used """ - return [(nsamples(model)) / (nsamples_total) for model in models] - end - lossfun = models[1].loss.functions[1] - if !applicable(mg_correction, lossfun) - @info """ - We don't know how to choose group weights for the specified loss function. - Default weights of (#samples per group/#total samples) will be used - """ - return [(nsamples(model)) / (nsamples_total) for model in models] + c = 0 + else + c = multigroup_correction_scale(semloss_type) + if isnothing(c) + @info """ + We don't know how to choose group weights for the specified loss function. + Default weights of (#samples per group/#total samples) will be used + """ + c = 0 + end end - c = mg_correction(lossfun) - return [(nsamples(model)+c) / (nsamples_total+n*c) for model in models] + return [(nsamples(term)+c) / (nsamples_total+n*c) for term in semterms] end +############################################################################################ +# constructor for Sem types +############################################################################################ + function Sem( loss_terms...; params::Union{Vector{Symbol}, Nothing} = nothing, From a9ee00bdafcbd63df536eeaa929ebb4a218fe6f3 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 11:18:19 -0700 Subject: [PATCH 08/74] replace_observed(): simplify & refactor remove update_observed!() --- ext/SEMNLOptExt/NLopt.jl | 7 -- ext/SEMProximalOptExt/ProximalAlgorithms.jl | 7 -- src/StructuralEquationModels.jl | 1 - src/additional_functions/simulation.jl | 115 ------------------- src/frontend/specification/Sem.jl | 116 +++++++++++++++++++- src/implied/RAM/generic.jl | 17 --- src/implied/RAM/symbolic.jl | 20 ---- src/implied/empty.jl | 6 - src/loss/ML/FIML.jl | 8 -- src/loss/ML/ML.jl | 20 ---- src/loss/WLS/WLS.jl | 23 ---- src/loss/abstract.jl | 29 +++++ src/loss/constant/constant.jl | 6 - src/loss/regularization/ridge.jl | 6 - src/optimizer/Empty.jl | 6 - 
src/optimizer/optim.jl | 6 - 16 files changed, 140 insertions(+), 253 deletions(-) diff --git a/ext/SEMNLOptExt/NLopt.jl b/ext/SEMNLOptExt/NLopt.jl index 90004b90..87601030 100644 --- a/ext/SEMNLOptExt/NLopt.jl +++ b/ext/SEMNLOptExt/NLopt.jl @@ -107,13 +107,6 @@ function SemOptimizerNLopt(; ) end -############################################################################################ -### Recommended methods -############################################################################################ - -SEM.update_observed(optimizer::SemOptimizerNLopt, observed::SemObserved; kwargs...) = - optimizer - ############################################################################################ ### additional methods ############################################################################################ diff --git a/ext/SEMProximalOptExt/ProximalAlgorithms.jl b/ext/SEMProximalOptExt/ProximalAlgorithms.jl index 3ec32453..0937ee04 100644 --- a/ext/SEMProximalOptExt/ProximalAlgorithms.jl +++ b/ext/SEMProximalOptExt/ProximalAlgorithms.jl @@ -34,13 +34,6 @@ SemOptimizerProximal(; SEM.sem_optimizer_subtype(::Val{:Proximal}) = SemOptimizerProximal -############################################################################################ -### Recommended methods -############################################################################################ - -SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwargs...) 
= - optimizer - ############################################################################ ### Model fitting ############################################################################ diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index d98e7925..0dbcd16a 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -205,7 +205,6 @@ export AbstractSem, z_test!, example_data, replace_observed, - update_observed, @StenoGraph, →, ←, diff --git a/src/additional_functions/simulation.jl b/src/additional_functions/simulation.jl index 6d694c97..e85e9d5c 100644 --- a/src/additional_functions/simulation.jl +++ b/src/additional_functions/simulation.jl @@ -1,118 +1,3 @@ -""" - (1) replace_observed(model::AbstractSemSingle; kwargs...) - - (2) replace_observed(model::AbstractSemSingle, observed; kwargs...) - - (3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...) - -Return a new model with swaped observed part. - -# Arguments -- `model::AbstractSemSingle`: model to swap the observed part of. -- `kwargs`: additional keyword arguments; typically includes `data` and `specification` -- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved` - -# For SemEnsemble models: -- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group? -- `weights`: how to weight the different sub-models, - defaults to number of samples per group in the new data -- `kwargs`: has to be a dict with keys equal to the group names. - For `data` can also be a DataFrame with `column` containing the group information, - and for `specification` can also be an `EnsembleParameterTable`. - -# Examples -See the online documentation on [Replace observed data](@ref). -""" -function replace_observed end - -""" - update_observed(to_update, observed::SemObserved; kwargs...) - -Update a `SemImplied`, `SemLossFunction` or `SemOptimizer` object to use a `SemObserved` object. 
- -# Examples -See the online documentation on [Replace observed data](@ref). - -# Implementation -You can provide a method for this function when defining a new type, for more information -on this see the online developer documentation on [Update observed data](@ref). -""" -function update_observed end - -############################################################################################ -# change observed (data) without reconstructing the whole model -############################################################################################ - -# don't change non-SEM terms -replace_observed(loss::AbstractLoss; kwargs...) = loss - -# use the same observed type as before -replace_observed(loss::SemLoss; kwargs...) = - replace_observed(loss, typeof(SEM.observed(loss)).name.wrapper; kwargs...) - -# construct a new observed type -replace_observed(loss::SemLoss, observed_type; kwargs...) = - replace_observed(loss, observed_type(; kwargs...); kwargs...) - -function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...) - kwargs = Dict{Symbol, Any}(kwargs...) - old_observed = SEM.observed(loss) - implied = SEM.implied(loss) - - # get field types - kwargs[:observed_type] = typeof(new_observed) - kwargs[:old_observed_type] = typeof(old_observed) - - # update implied - new_implied = update_observed(implied, new_observed; kwargs...) - kwargs[:implied] = new_implied - kwargs[:implied_type] = typeof(new_implied) - kwargs[:nparams] = nparams(new_implied) - - # update loss - return update_observed(loss, new_observed; kwargs...) -end - -replace_observed(loss::LossTerm; kwargs...) = - LossTerm(replace_observed(loss.loss; kwargs...), loss.id, loss.weight) - -function replace_observed(sem::Sem; kwargs...) - updated_terms = Tuple(replace_observed(term; kwargs...) for term in loss_terms(sem)) - return Sem(updated_terms...) 
-end - -function replace_observed( - emodel::SemEnsemble; - column = :group, - weights = nothing, - kwargs..., -) - kwargs = Dict{Symbol, Any}(kwargs...) - # allow for EnsembleParameterTable to be passed as specification - if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable) - kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification]) - end - # allow for DataFrame with group variable "column" to be passed as new data - if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame) - kwargs[:data] = Dict( - group => - select(filter(r -> r[column] == group, kwargs[:data]), Not(column)) for - group in emodel.groups - ) - end - # update each model for new data - models = emodel.sems - new_models = Tuple( - replace_observed(m; group_kwargs(g, kwargs)...) for - (m, g) in zip(models, emodel.groups) - ) - return SemEnsemble(new_models...; weights = weights, groups = emodel.groups) -end - -function group_kwargs(g, kwargs) - return Dict(k => kwargs[k][g] for k in keys(kwargs)) -end - ############################################################################################ # simulate data ############################################################################################ diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index d8696e82..42ff2d3e 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -436,12 +436,118 @@ function build_SemTerms(loss, observed, implied; kwargs...) end end -function update_observed(sem::Sem, new_observed; kwargs...) - new_terms = Tuple( - update_observed(lossterm.loss, new_observed; kwargs...) 
for - lossterm in loss_terms(sem) +############################################################## +# replace_observed: Sem level +############################################################## + +""" + replace_observed(model::Sem, observed::SemObserved) + replace_observed(model::Sem, data::AbstractDict{Symbol}) + replace_observed(model::Sem, data::AbstractDataFrame; [semterm_column]) + replace_observed(loss::SemLoss, observed::SemObserved) + replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame}) + +Construct a new SEM model or SEM loss with replaced observed data. + +The SEM structure (implied covariance, loss type) is preserved; +only the observed data is swapped. + +# Single-term models + +Pass a `SemObserved` object, a data matrix, or a `DataFrame`: +```julia +replace_observed(model, new_data_matrix) +replace_observed(model, new_sem_observed) +replace_observed(model, new_df) +``` + +# Multi-term models + +Pass a `Dict{Symbol}` mapping term ids to data or `SemObserved` objects: +```julia +replace_observed(model, Dict(:g1 => data1, :g2 => data2)) +``` + +Or pass a `DataFrame` with a `semterm_column` identifying the group: +```julia +replace_observed(model, new_df; semterm_column = :group) +``` +""" +function replace_observed end + +function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix}) + nsem_terms(sem) > 1 && throw( + ArgumentError( + "Model contains $(nsem_terms(sem)) SEM terms. " * + "Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data.", + ), + ) + updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem)) + return Sem(updated_terms...) +end + +function replace_observed(sem::Sem, data::AbstractDict{Symbol}) + term_ids = Set( + if !isnothing(id(term)) + id(term) + else + "Multigroup replace_observed(sem, data::Dict) requires all SEM terms to have ids." 
|> + ArgumentError |> + throw + end for term in loss_terms(sem) if issemloss(term) + ) + # check for extra ids + extra_term_ids = setdiff(keys(data), term_ids) + isempty(extra_term_ids) || + @warn "Ignoring data with ids=$(collect(extra_term_ids)): no such SEM terms exist in the model" + + updated_terms = map(loss_terms(sem)) do term + issemloss(term) || return term + tid = id(term) + term_data = get(data, tid, nothing) + isnothing(term_data) && + throw(ArgumentError("No data provided for SEM term :$tid")) + return replace_observed(term, term_data) + end + return Sem(Tuple(updated_terms)...) +end + +function replace_observed(sem::Sem, data::AbstractVector) + nsem = nsem_terms(sem) + nsem == length(data) || throw( + ArgumentError( + "Length of data ($(length(data))) does not match number of SEM terms ($nsem)", + ), + ) + updated_terms = map(enumerate(loss_terms(sem))) do (i, term) + issemloss(term) ? replace_observed(term, data[i]) : term + end + return Sem(Tuple(updated_terms)...) +end + +function replace_observed( + sem::Sem, + data::AbstractDataFrame; + semterm_column::Union{Symbol, Nothing} = nothing, +) + if isnothing(semterm_column) + # single-term shortcut + nsem_terms(sem) > 1 && throw( + ArgumentError( + "Model contains $(nsem_terms(sem)) SEM terms. " * + "Provide `semterm_column` to specify which DataFrame column identifies the groups.", + ), + ) + updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem)) + return Sem(updated_terms...) + end + + # multi-term: split DataFrame by semterm_column + terms_data = Dict( + g[semterm_column] => group_data for + (g, group_data) in pairs(groupby(data, semterm_column)) ) - return Sem(new_terms...) 
+ return replace_observed(sem, terms_data) end ############################################################## diff --git a/src/implied/RAM/generic.jl b/src/implied/RAM/generic.jl index f1c1e08d..1569b341 100644 --- a/src/implied/RAM/generic.jl +++ b/src/implied/RAM/generic.jl @@ -179,20 +179,3 @@ function update!(targets::EvaluationTargets, implied::RAM, params) mul!(implied.μ, implied.F⨉I_A⁻¹, implied.M) end end - -############################################################################################ -### Recommended methods -############################################################################################ - -function update_observed(implied::RAM, observed::SemObserved; kwargs...) - if nobserved_vars(observed) == nobserved_vars(implied) - return implied - else - return RAM(; - observed = observed, - gradient_required = !isnothing(implied.∇A), - meanstructure = MeanStruct(implied) == HasMeanStruct, - kwargs..., - ) - end -end diff --git a/src/implied/RAM/symbolic.jl b/src/implied/RAM/symbolic.jl index 4c9bda91..52a192e6 100644 --- a/src/implied/RAM/symbolic.jl +++ b/src/implied/RAM/symbolic.jl @@ -190,26 +190,6 @@ function update!(targets::EvaluationTargets, implied::RAMSymbolic, par) end end -############################################################################################ -### Recommended methods -############################################################################################ - -function update_observed(implied::RAMSymbolic, observed::SemObserved; kwargs...) 
- if nobserved_vars(observed) == nobserved_vars(implied) - return implied - else - return RAMSymbolic(; - observed = observed, - vech = implied.Σ isa Vector, - gradient = !isnothing(implied.∇Σ), - hessian = !isnothing(implied.∇²Σ), - meanstructure = MeanStruct(implied) == HasMeanStruct, - approximate_hessian = isnothing(implied.∇²Σ), - kwargs..., - ) - end -end - ############################################################################################ ### additional functions ############################################################################################ diff --git a/src/implied/empty.jl b/src/implied/empty.jl index a327ee13..a650a07a 100644 --- a/src/implied/empty.jl +++ b/src/implied/empty.jl @@ -46,9 +46,3 @@ end ############################################################################################ update!(targets::EvaluationTargets, implied::ImpliedEmpty, par) = nothing - -############################################################################################ -### Recommended methods -############################################################################################ - -update_observed(implied::ImpliedEmpty, observed::SemObserved; kwargs...) = implied diff --git a/src/loss/ML/FIML.jl b/src/loss/ML/FIML.jl index fdedf398..15081e20 100644 --- a/src/loss/ML/FIML.jl +++ b/src/loss/ML/FIML.jl @@ -159,14 +159,6 @@ function evaluate!(objective, gradient, hessian, loss::SemFIML, params) return objective end - -############################################################################################ -### Recommended methods -############################################################################################ - -update_observed(loss::SemFIML, observed::SemObserved; kwargs...) = - SemFIML(; observed = observed, kwargs...) 
- ############################################################################################ ### additional functions ############################################################################################ diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index 2d449d73..9f327544 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -235,23 +235,3 @@ function non_posdef_return(par) return typemax(eltype(par)) end end - -############################################################################################ -### recommended methods -############################################################################################ - -update_observed(loss::SemML, observed::SemObservedMissing; kwargs...) = - error("ML estimation does not work with missing data - use FIML instead") - -function update_observed(loss::SemML, observed::SemObserved; kwargs...) - if (obs_cov(loss) == obs_cov(observed)) && (obs_mean(loss) == obs_mean(observed)) - return loss # no change - else - return SemML( - observed, - loss.implied; - approximate_hessian = HessianEval(loss) == ApproxHessian, - kwargs..., - ) - end -end diff --git a/src/loss/WLS/WLS.jl b/src/loss/WLS/WLS.jl index 5c4cb252..8f4a109c 100644 --- a/src/loss/WLS/WLS.jl +++ b/src/loss/WLS/WLS.jl @@ -179,26 +179,3 @@ function evaluate!(objective, gradient, hessian, loss::SemWLS, par) return objective end - -############################################################################################ -### Recommended methods -############################################################################################ - -function update_observed( - loss::SemWLS, - observed::SemObserved; - recompute_V = true, - kwargs..., -) - if recompute_V - return SemWLS(observed, loss.implied; kwargs...) 
- else - return SemWLS( - observed, - loss.implied; - wls_weight_matrix = loss.V, - wls_weight_matrix_mean = loss.V_μ, - kwargs..., - ) - end -end diff --git a/src/loss/abstract.jl b/src/loss/abstract.jl index bf8585d6..bcd6d62b 100644 --- a/src/loss/abstract.jl +++ b/src/loss/abstract.jl @@ -40,3 +40,32 @@ function check_observed_vars(observed::SemObserved, implied::SemImplied) end check_observed_vars(sem::SemLoss) = check_observed_vars(observed(sem), implied(sem)) + +############################################################################################ +# replace_observed: SemLoss, AbstractLoss, LossTerm +############################################################################################ + +function replace_observed(loss::SemLoss, new_observed::SemObserved) + old_obs = SEM.observed(loss) + observed_vars(old_obs) == observed_vars(new_observed) || throw( + ArgumentError( + "observed_vars of the new data do not match the model: " * + "expected $(observed_vars(old_obs)), got $(observed_vars(new_observed))", + ), + ) + return typeof(loss).name.wrapper(new_observed, SEM.implied(loss)) +end + +function replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame}) + old_obs = SEM.observed(loss) + new_observed = + typeof(old_obs).name.wrapper(data = data, observed_vars = observed_vars(old_obs)) + return replace_observed(loss, new_observed) +end + +# non-SEM loss terms are unchanged +replace_observed(loss::AbstractLoss, ::Any) = loss + +# LossTerm: delegate to inner loss +replace_observed(term::LossTerm, data) = + LossTerm(replace_observed(loss(term), data), id(term), weight(term)) diff --git a/src/loss/constant/constant.jl b/src/loss/constant/constant.jl index 023076cc..2aff0156 100644 --- a/src/loss/constant/constant.jl +++ b/src/loss/constant/constant.jl @@ -35,9 +35,3 @@ SemConstant(; constant_loss::Number, kwargs...) 
= SemConstant(constant_loss) objective(loss::SemConstant, par) = convert(eltype(par), loss.c) gradient(loss::SemConstant, par) = zero(par) hessian(loss::SemConstant, par) = zeros(eltype(par), length(par), length(par)) - -############################################################################################ -### Recommended methods -############################################################################################ - -update_observed(loss::SemConstant, observed::SemObserved; kwargs...) = loss diff --git a/src/loss/regularization/ridge.jl b/src/loss/regularization/ridge.jl index 3e2cfbff..813aff11 100644 --- a/src/loss/regularization/ridge.jl +++ b/src/loss/regularization/ridge.jl @@ -85,9 +85,3 @@ function hessian(ridge::SemRidge, par) @views @. ridge.hessian[ridge.which_H] .= 2 * ridge.α return ridge.hessian end - -############################################################################################ -### Recommended methods -############################################################################################ - -update_observed(loss::SemRidge, observed::SemObserved; kwargs...) = loss diff --git a/src/optimizer/Empty.jl b/src/optimizer/Empty.jl index f95c067c..fd36acb5 100644 --- a/src/optimizer/Empty.jl +++ b/src/optimizer/Empty.jl @@ -11,12 +11,6 @@ struct SemOptimizerEmpty <: SemOptimizer{:Empty} end sem_optimizer_subtype(::Val{:Empty}) = SemOptimizerEmpty -############################################################################################ -### Recommended methods -############################################################################################ - -update_observed(optimizer::SemOptimizerEmpty, observed::SemObserved; kwargs...) 
= optimizer - ############################################################################################ ### Pretty Printing ############################################################################################ diff --git a/src/optimizer/optim.jl b/src/optimizer/optim.jl index 70413193..a0aae22a 100644 --- a/src/optimizer/optim.jl +++ b/src/optimizer/optim.jl @@ -57,12 +57,6 @@ SemOptimizerOptim(; sem_optimizer_subtype(::Val{:Optim}) = SemOptimizerOptim -############################################################################################ -### Recommended methods -############################################################################################ - -update_observed(optimizer::SemOptimizerOptim, observed::SemObserved; kwargs...) = optimizer - ############################################################################################ ### additional methods ############################################################################################ From 84c66530d59da0c645b9fed8aa1578f1747f0795 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 17:56:25 -0700 Subject: [PATCH 09/74] bootstrap: sync with Sem updates --- src/frontend/fit/standard_errors/bootstrap.jl | 348 +++++++++--------- test/examples/helper.jl | 27 +- test/examples/multigroup/build_models.jl | 4 +- .../political_democracy/constructor.jl | 14 +- 4 files changed, 199 insertions(+), 194 deletions(-) diff --git a/src/frontend/fit/standard_errors/bootstrap.jl b/src/frontend/fit/standard_errors/bootstrap.jl index 845a209e..ce84e923 100644 --- a/src/frontend/fit/standard_errors/bootstrap.jl +++ b/src/frontend/fit/standard_errors/bootstrap.jl @@ -1,35 +1,118 @@ +# base type for accumulators of intermediate bootstrap results +abstract type BootstrapAccumulator end + +# internal function to run bootstrap +function bootstrap!( + acc::BootstrapAccumulator, + fitted::SemFit; + data = nothing, + engine = :Optim, + parallel = false, + fit_kwargs = Dict(), +) + sem = 
model(fitted) + data = isnothing(data) ? _bootstrap_data(sem) : data + start = solution(fitted) + + n_boot = n_bootstrap(acc) + + # fit to bootstrap samples + if !parallel + for i in 1:n_boot + new_fit = _fit_bootstrap_sample(sem, data, start; engine, fit_kwargs) + update!(acc, i, new_fit, nothing) + end + else + n_threads = Threads.nthreads() + # Pre-create one independent model copy per thread via deepcopy. + model_pool = Channel(n_threads) + for _ in 1:n_threads + put!(model_pool, deepcopy(sem)) + end + lk = ReentrantLock() + Threads.@threads for i in 1:n_boot + thread_model = take!(model_pool) + new_fit = _fit_bootstrap_sample(thread_model, data, start; engine, fit_kwargs) + update!(acc, i, new_fit, lk) + put!(model_pool, thread_model) + end + end + + return acc +end + +# a simple accumulator that just stores the statistic for each sample and whether it converged +struct SimpleBootstrapAccumulator{F} <: BootstrapAccumulator + statistic::F + samples::Vector{Any} + converged_mask::Vector{Bool} +end + +SimpleBootstrapAccumulator(statistic, n_boot::Integer) = + SimpleBootstrapAccumulator(statistic, Vector{Any}(undef, n_boot), fill(false, n_boot)) + +n_bootstrap(acc::SimpleBootstrapAccumulator) = length(acc.samples) + +function update!(acc::SimpleBootstrapAccumulator, i::Integer, fit::SemFit, _) + acc.samples[i] = acc.statistic(fit) + acc.converged_mask[i] = converged(fit) +end + +""" + struct BootstrapResult{T} + +Stores the output of a [`bootstrap`](@ref) call. 
+""" +struct BootstrapResult{T} + samples::Vector{T} + converged_mask::BitVector + n_boot::Int + n_converged::Int +end + +function Base.show(io::IO, result::BootstrapResult{T}) where {T} + println( + io, + "BootstrapResult{$(T)} with $(result.n_converged)/$(result.n_boot) converged samples", + ) +end + """ bootstrap( - fitted::SemFit, - specification::SemSpecification; + fitted::SemFit; statistic = solution, n_boot = 3000, data = nothing, engine = :Optim, parallel = false, - fit_kwargs = Dict(), - replace_kwargs = Dict()) + fit_kwargs = Dict() + ) -> BootstrapResult + +Bootstrap the samples and apply `statistic` function to each. -Return bootstrap samples for `statistic`. +Returns a [`BootstrapResult`](@ref) object containing the results of `statistic` +applied to each bootstrapped sample. + +Supports both single-group and multi-group models. +For multi-group models, each group is resampled independently. # Arguments - `fitted`: a fitted SEM. -- `specification`: a `ParameterTable` or `RAMMatrices` object passed to `replace_observed`. - `statistic`: any function that can be called on a `SemFit` object. The output will be returned as the bootstrap sample. - `n_boot`: number of boostrap samples -- `data`: data to sample from. Only needed if different than the data from `sem_fit` +- `data`: data to sample from. Only needed if different than the fitted model. + For multi-group models, pass a `Dict{Symbol}` mapping term ids to data matrices. - `engine`: optimizer engine, passed to `fit`. - `parallel`: if `true`, run bootstrap samples in parallel on all available threads. The number of threads is controlled by the `JULIA_NUM_THREADS` environment variable or the `--threads` flag when starting Julia. 
- `fit_kwargs` : a `Dict` controlling model fitting for each bootstrap sample, - passed to `fit` -- `replace_kwargs`: a `Dict` passed to `replace_observed` + passed to [`fit`](@ref) # Example ```julia -# 1000 boostrap samples of the minimum, fitted with :Optim +# 1000 bootstrap samples of the minimum, fitted with :Optim bootstrap( fitted; statistic = StructuralEquationModels.minimum, @@ -40,95 +123,74 @@ bootstrap( """ function bootstrap( fitted::SemFit, - specification::SemSpecification; - statistic = solution, + statistic = solution; n_boot = 3000, data = nothing, engine = :Optim, parallel = false, fit_kwargs = Dict(), - replace_kwargs = Dict(), ) - # access data and convert to matrix - data = prepare_data_bootstrap(data, fitted.model) - start = solution(fitted) - # pre-allocations - out = Vector{Any}(nothing, n_boot) - conv = fill(false, n_boot) - # fit to bootstrap samples - if !parallel - for i in 1:n_boot - sample_data = bootstrap_sample(data) - new_model = replace_observed( - fitted.model; - data = sample_data, - specification = specification, - replace_kwargs..., - ) - new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...) - sample = statistic(new_fit) - c = converged(new_fit) - out[i] = sample - conv[i] = c - end - else - n_threads = Threads.nthreads() - # Pre-create one independent model copy per thread via deepcopy. - model_pool = Channel(n_threads) - for _ in 1:n_threads - put!(model_pool, deepcopy(fitted.model)) - end - # fit models in parallel - lk = ReentrantLock() - Threads.@threads for i in 1:n_boot - thread_model = take!(model_pool) - sample_data = bootstrap_sample(data) - new_model = replace_observed( - thread_model; - data = sample_data, - specification = specification, - replace_kwargs..., - ) - new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...) 
- sample = statistic(new_fit) - c = converged(new_fit) - out[i] = sample - conv[i] = c - put!(model_pool, thread_model) - end - end - return Dict( - :samples => collect(a for a in out), - :n_boot => n_boot, - :n_converged => sum(conv), - :converged => conv, + acc = SimpleBootstrapAccumulator(statistic, n_boot) + bootstrap!(acc, fitted; data, engine, parallel, fit_kwargs) + return BootstrapResult( + [s for s in acc.samples], + convert(BitVector, acc.converged_mask), + n_bootstrap(acc), + sum(acc.converged_mask), ) end +# bootstrap accumulator for se_bootstrap() +# accumulates per-parameter sum and sum of squares across bootstrap samples +struct StdErrBootstrapAccumulator <: BootstrapAccumulator + n_boot::Int + sum::Vector{Float64} + squared_sum::Vector{Float64} + n_converged::Ref{Int} +end + +n_bootstrap(acc::StdErrBootstrapAccumulator) = acc.n_boot + +StdErrBootstrapAccumulator(n_params::Integer, n_boot::Integer) = + StdErrBootstrapAccumulator(n_boot, zeros(n_params), zeros(n_params), Ref(0)) + +function update!( + acc::StdErrBootstrapAccumulator, + i::Integer, + fit::SemFit, + lk::Union{Base.AbstractLock, Nothing}, +) + conv = converged(fit) + if conv + sol = solution(fit) + isnothing(lk) || lock(lk) + acc.n_converged[] += 1 + @. acc.sum += sol + @. acc.squared_sum += abs2(sol) + isnothing(lk) || unlock(lk) + end +end + """ - se_bootstrap( - fitted::SemFit, - specification::SemSpecification; - n_boot = 3000, - data = nothing, - parallel = false, - fit_kwargs = Dict(), - replace_kwargs = Dict()) + se_bootstrap(fitted::SemFit; n_boot = 3000, kwargs...) -Return bootstrap standard errors. +Calculate standard errors using bootstrap approach. + +Supports both single-group and multi-group models. +For multi-group models, each group is resampled independently. # Arguments - `fitted`: a fitted SEM. -- `specification`: a `ParameterTable` or `RAMMatrices` object passed to `replace_observed`. - `n_boot`: number of boostrap samples -- `data`: data to sample from. 
Only needed if different than the data from `sem_fit` +- `data`: data to sample from. Only needed if different than the fitted model. + For multi-group models, pass a `Dict{Symbol}` mapping term ids to data matrices. - `engine`: optimizer engine, passed to `fit`. - `parallel`: if `true`, run bootstrap samples in parallel on all available threads. The number of threads is controlled by the `JULIA_NUM_THREADS` environment variable or the `--threads` flag when starting Julia. -- `fit_kwargs` : a `Dict` controlling model fitting for each bootstrap sample, - passed to `sem_fit` -- `replace_kwargs`: a `Dict` passed to `replace_observed` +- `fit_kwargs` : a `Dict` controlling model fitting for each bootstrap sample, + passed to [`fit`](@ref) + # Example ```julia @@ -142,109 +204,53 @@ se_bootstrap( ) ``` """ -function se_bootstrap( - fitted::SemFit, - specification::SemSpecification; - n_boot = 3000, - data = nothing, - engine = :Optim, - parallel = false, - fit_kwargs = Dict(), - replace_kwargs = Dict(), -) - # access data and convert to matrix - data = prepare_data_bootstrap(data, fitted.model) - start = solution(fitted) - # pre-allocations - total_sum = zero(start) - total_squared_sum = zero(start) - n_conv = Ref(0) - # fit to bootstrap samples - if !parallel - for _ in 1:n_boot - sample_data = bootstrap_sample(data) - new_model = replace_observed( - fitted.model; - data = sample_data, - specification = specification, - replace_kwargs..., - ) - new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...) - sol = solution(new_fit) - conv = converged(new_fit) - if conv - n_conv[] += 1 - @. total_sum += sol - @. total_squared_sum += sol^2 - end - end +function se_bootstrap(fitted::SemFit; n_boot = 3000, kwargs...) + acc = StdErrBootstrapAccumulator(nparams(fitted), n_boot) + bootstrap!(acc, fitted; kwargs...) + n_conv = acc.n_converged[] + @info "$n_conv models converged" + + if n_conv == 0 + @warn "No bootstrap samples converged. Returning NaN." 
+ return fill(NaN, length(acc.sum)) else - n_threads = Threads.nthreads() - # Pre-create one independent model copy per thread via deepcopy. - model_pool = Channel(n_threads) - for _ in 1:n_threads - put!(model_pool, deepcopy(fitted.model)) - end - # fit models in parallel - lk = ReentrantLock() - Threads.@threads for _ in 1:n_boot - thread_model = take!(model_pool) - sample_data = bootstrap_sample(data) - new_model = replace_observed( - thread_model; - data = sample_data, - specification = specification, - replace_kwargs..., - ) - new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...) - sol = solution(new_fit) - conv = converged(new_fit) - if conv - lock(lk) do - n_conv[] += 1 - @. total_sum += sol - @. total_squared_sum += sol^2 - end - end - put!(model_pool, thread_model) - end + return sqrt.(acc.squared_sum ./ n_conv - abs2.(acc.sum / n_conv)) end - # compute parameters - n_conv = n_conv[] - sd = sqrt.(total_squared_sum / n_conv - (total_sum / n_conv) .^ 2) - @info string(n_conv)*" models converged" - return sd end ############################################################################################ ### Helper Functions ############################################################################################ -function bootstrap_sample(data::Matrix) - nobs = size(data, 1) - index_new = rand(1:nobs, nobs) - data_new = data[index_new, :] - return data_new -end +""" + resample_with_replacement(data::AbstractMatrix) + resample_with_replacement(data::AbstractVector{<:AbstractMatrix}) -bootstrap_sample(data::Dict) = Dict(k => bootstrap_sample(data[k]) for k in keys(data)) +Resample rows of a data matrix with replacement (bootstrap sample). +For a vector of matrices (multi-group models), independently resamples each matrix. 
+""" +function resample_with_replacement(data::AbstractMatrix) + n = size(data, 1) + return data[rand(1:n, n), :] +end -function prepare_data_bootstrap(data, model::AbstractSemSingle) - if isnothing(data) - data = samples(observed(model)) - end - data = Matrix(data) - return data +function resample_with_replacement(data::AbstractVector{<:AbstractMatrix}) + return [resample_with_replacement(term_data) for term_data in data] end -function prepare_data_bootstrap(data, model::SemEnsemble) - sems = model.sems - groups = model.groups - if isnothing(data) - data = Dict(g => samples(observed(m)) for (g, m) in zip(groups, sems)) +# Extract data from a model for bootstrap resampling. +function _bootstrap_data(sem::AbstractSem) + terms = sem_terms(sem) + if length(terms) == 1 + return samples(observed(loss(terms[1]))) + else + return [samples(observed(loss(term))) for term in terms] end - data = Dict(k => Matrix(data[k]) for k in keys(data)) - return data end - +# Fit one bootstrap replicate: resample, replace observed data, fit. +function _fit_bootstrap_sample(sem_model, data, start; engine, fit_kwargs) + boot_data = resample_with_replacement(data) + boot_model = replace_observed(sem_model, boot_data) + return fit(boot_model; start_val = start, engine = engine, fit_kwargs...) 
+end diff --git a/test/examples/helper.jl b/test/examples/helper.jl index c4191fdb..f14fec62 100644 --- a/test/examples/helper.jl +++ b/test/examples/helper.jl @@ -138,17 +138,17 @@ function test_estimates( end function test_bootstrap( - model_fit, - spec; + model_fit::SemFit; compare_hessian = true, rtol_hessian = 0.2, compare_bs = true, rtol_bs = 0.1, n_boot = 500, + seed = 32432, ) - @testset rng = Random.seed!(32432) "bootstrap" begin - se_bs = @suppress se_bootstrap(model_fit, spec; n_boot = n_boot) - # hessian and bootstrap se are close + @testset rng = Random.seed!(seed) "bootstrap" begin + se_bs = @suppress se_bootstrap(model_fit; n_boot = n_boot) + # hessian-based and bootstrap-based std.errors are close if compare_hessian se_he = @suppress se_hessian(model_fit) #println(maximum(abs.(se_he - se_bs))) @@ -156,10 +156,9 @@ function test_bootstrap( end # se_bootstrap and bootstrap |> se are close if compare_bs - bs_samples = bootstrap(model_fit, spec; n_boot = n_boot) - @test bs_samples[:n_converged] >= 0.95*n_boot - bs_samples = - cat(bs_samples[:samples][BitVector(bs_samples[:converged])]..., dims = 2) + bs_samples = bootstrap(model_fit; n_boot = n_boot) + @test bs_samples.n_converged >= 0.95*n_boot + bs_samples = reduce(hcat, bs_samples.samples[bs_samples.converged_mask]) se_bs_2 = sqrt.(var(bs_samples, corrected = false, dims = 2)) #println(maximum(abs.(se_bs_2 - se_bs))) @test isapprox(se_bs_2, se_bs, rtol = rtol_bs) @@ -167,14 +166,14 @@ function test_bootstrap( end end -function smoketest_bootstrap(model_fit, spec; n_boot = 5) - # hessian and bootstrap se are close - se_bs = se_bootstrap(model_fit, spec; n_boot = n_boot) - bs_samples = bootstrap(model_fit, spec; n_boot = n_boot) +function smoketest_bootstrap(model_fit::SemFit; n_boot = 5) + # just test that both methods succeed + se_bs = se_bootstrap(model_fit; n_boot = n_boot) + bs_samples = bootstrap(model_fit; n_boot = n_boot) return se_bs, bs_samples end -function smoketest_CI_z(model_fit, 
partable) +function smoketest_CI_z(model_fit::SemFit, partable) se_he = @suppress se_hessian(model_fit) normal_CI!(partable, model_fit, se_he) z_test!(partable, model_fit, se_he) diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index 48723fbe..462deab6 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -249,7 +249,7 @@ end lav_col = :se, lav_groups = Dict(:Pasteur => 1, :Grant_White => 2), ) - # test_bootstrap(solution_ls, partable; compare_bs = false, rtol_hessian = 0.3) + # test_bootstrap(solution_ls; compare_bs = false, rtol_hessian = 0.3) smoketest_CI_z(solution_ls, partable) end @@ -360,7 +360,7 @@ if !isnothing(specification_miss_g1) fitmeasure_names = Dict(:CFI => "cfi"), ) - test_bootstrap(solution, partable_miss; compare_bs = false, rtol_hessian = 0.5) + test_bootstrap(solution; compare_bs = false, rtol_hessian = 0.5) smoketest_CI_z(solution, partable_miss) update_se_hessian!(partable_miss, solution) diff --git a/test/examples/political_democracy/constructor.jl b/test/examples/political_democracy/constructor.jl index 25a6da91..759875b2 100644 --- a/test/examples/political_democracy/constructor.jl +++ b/test/examples/political_democracy/constructor.jl @@ -116,7 +116,7 @@ end lav_col = :se, ) - test_bootstrap(solution_ml, partable) + test_bootstrap(solution_ml) smoketest_CI_z(solution_ml, partable) end @@ -146,7 +146,7 @@ end lav_col = :se, ) - test_bootstrap(solution_ls, partable; compare_bs = false) + test_bootstrap(solution_ls; compare_bs = false) smoketest_CI_z(solution_ls, partable) end @@ -324,7 +324,7 @@ end lav_col = :se, ) - test_bootstrap(solution_ml, partable_mean) + test_bootstrap(solution_ml) smoketest_CI_z(solution_ml, partable_mean) end @@ -353,8 +353,8 @@ end lav_col = :se, ) - test_bootstrap(solution_ls, partable_mean, compare_bs = false) - # smoketest_bootstrap(solution_ls, partable_mean) + test_bootstrap(solution_ls, compare_bs = false) + # 
smoketest_bootstrap(solution_ls) smoketest_CI_z(solution_ls, partable_mean) end @@ -481,7 +481,7 @@ end lav_col = :se, ) - # test_bootstrap(solution_ml, partable_mean) # too much compute - smoketest_bootstrap(solution_ml, partable_mean) + # test_bootstrap(solution_ml) # too much compute + smoketest_bootstrap(solution_ml) smoketest_CI_z(solution_ml, partable_mean) end From 24261d52055b8400400b8271774a7db483bd6906 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 17:56:25 -0700 Subject: [PATCH 10/74] CFI: sync with Sem refactor --- src/frontend/fit/fitmeasures/CFI.jl | 47 ++++++++++++++--------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/src/frontend/fit/fitmeasures/CFI.jl b/src/frontend/fit/fitmeasures/CFI.jl index 9f3c5a2d..e2bcb8c7 100644 --- a/src/frontend/fit/fitmeasures/CFI.jl +++ b/src/frontend/fit/fitmeasures/CFI.jl @@ -5,11 +5,11 @@ Calculate the Comparative Fit Index (CFI). -The CFI ranges from 0-1 and measures how much better the model +The CFI ranges from 0-1 and measures how much better the model fits the data compared to a baseline model. If no baseline model is provided, a model with unconstrained variances (and means) is compaired against. -For multigroup models, variances (and means) per group are free +For multigroup models, variances (and means) per group are free without any equality constraints between groups. 
""" function CFI end @@ -35,34 +35,31 @@ function CFI(χ², dof, χ²₀, dof₀) end ### -function χ²_varonly(model::AbstractSemSingle) - check_single_lossfun(model; throw_error = true) - return χ²_varonly(model.loss.functions[1], model) -end - -function χ²_varonly(model::SemEnsemble) - check_single_lossfun(model; throw_error = true) - return sum(χ²_varonly, model.sems) +function χ²_varonly(model::AbstractSem) + check_same_semterm_type(model; throw_error = true) + return sum(sem_terms(model)) do semterm + χ²_varonly(_unwrap(loss(semterm))) + end end -function χ²_varonly(::SemML, model::AbstractSemSingle) - N⁻ = (nsamples(model) - 1) - S = obs_cov(observed(model)) +function χ²_varonly(loss::SemML) + N⁻ = (nsamples(loss) - 1) + S = obs_cov(observed(loss)) Σ₀ = Diagonal(S) - p = nobserved_vars(model) + p = nobserved_vars(loss) return N⁻*(logdet(Σ₀) + tr(inv(Σ₀)*S) - logdet(S) - p) end # for the optimal variance only model, we have to solve 1/2 tr((I-XS⁻¹)^2) with X diagonal -function χ²_varonly(::SemWLS, model) - N⁻ = (nsamples(model) - 1) - S⁻¹ = inv((obs_cov(observed(model)))) +function χ²_varonly(loss::SemWLS) + N⁻ = (nsamples(loss) - 1) + S⁻¹ = inv((obs_cov(observed(loss)))) Σ₀ = Diagonal(inv(S⁻¹ .* S⁻¹)*diag(S⁻¹)) return N⁻*0.5*tr((I - Σ₀*S⁻¹)^2) end # For FIML, an explicit bl model has to be passed -function χ²_varonly(::SemFIML, model) +function χ²_varonly(loss::SemFIML) """ Computing the CFI with FIML requires explicitely passing a fitted baseline model as CFI(fit::SemFit, fit_baseline::SemFit) @@ -71,12 +68,12 @@ function χ²_varonly(::SemFIML, model) throw end -function dof_varonly(model::AbstractSemSingle) - nparams_varonly = nobserved_vars(model) - if MeanStruct(model.implied) === HasMeanStruct - nparams_varonly *= 2 +function dof_varonly(model::AbstractSem) + return sum(sem_terms(model)) do semterm + nparams_varonly = nobserved_vars(semterm) + if MeanStruct(implied(semterm)) === HasMeanStruct + nparams_varonly *= 2 + end + return n_dp(loss(semterm)) - 
nparams_varonly end - return n_dp(model) - nparams_varonly end - -dof_varonly(model::SemEnsemble) = sum(dof_varonly, model.sems) From e4d38e582e19d0ee0b6a8d13e407928a99305f24 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 17:57:46 -0700 Subject: [PATCH 11/74] test/build_models: remove redundant model --- test/examples/multigroup/build_models.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index 462deab6..329c5502 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -347,19 +347,6 @@ if !isnothing(specification_miss_g1) lav_groups = Dict(:Pasteur => 1, :Grant_White => 2), ) - solution = fit(semoptimizer, model_ml_multigroup2) - test_fitmeasures( - fit_measures(solution), - solution_lav[:fitmeasures_fiml]; - rtol = 1e-3, - atol = 0, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution, solution_varonly)), - solution_lav[:fitmeasures_fiml]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) - test_bootstrap(solution; compare_bs = false, rtol_hessian = 0.5) smoketest_CI_z(solution, partable_miss) From cb9b1e772392ed6697f4c641c8bff1830fae9ce3 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 17:57:46 -0700 Subject: [PATCH 12/74] revert using --- test/examples/multigroup/multigroup.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/examples/multigroup/multigroup.jl b/test/examples/multigroup/multigroup.jl index dd654731..c8ae8c1f 100644 --- a/test/examples/multigroup/multigroup.jl +++ b/test/examples/multigroup/multigroup.jl @@ -1,5 +1,5 @@ using StructuralEquationModels, Test, FiniteDiff, Suppressor -using LinearAlgebra: diagind, LowerTriangular +using LinearAlgebra: diagind, isposdef, logdet, tr, LowerTriangular using Statistics: var using Random From afac0b4b56dcbfa3599b6d6ba43f1490729eb6bb Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 
2026 17:57:46 -0700 Subject: [PATCH 13/74] WLS: verbose option to suppress info about inv(obs_cov) --- src/loss/WLS/WLS.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/loss/WLS/WLS.jl b/src/loss/WLS/WLS.jl index 8f4a109c..d04bc346 100644 --- a/src/loss/WLS/WLS.jl +++ b/src/loss/WLS/WLS.jl @@ -62,6 +62,7 @@ function SemWLS( wls_weight_matrix::Union{AbstractMatrix, Nothing} = nothing, wls_weight_matrix_mean::Union{AbstractMatrix, Nothing} = nothing, approximate_hessian::Bool = false, + verbose::Bool = false, kwargs..., ) if observed isa SemObservedMissing @@ -114,7 +115,8 @@ function SemWLS( if MeanStruct(implied) == HasMeanStruct if isnothing(wls_weight_matrix_mean) - @info "Computing WLS weight matrix for the meanstructure using obs_cov()" + verbose && + @info "Computing WLS weight matrix for the meanstructure using obs_cov()" wls_weight_matrix_mean = inv(obs_cov(observed)) end size(wls_weight_matrix_mean) == (nobs_vars, nobs_vars) || DimensionMismatch( From 53a615a3b602e4485a300987c24f9f3035249ff2 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sat, 21 Mar 2026 17:57:46 -0700 Subject: [PATCH 14/74] docs: sync with Sem refactor --- docs/src/developer/implied.md | 30 ++++------- docs/src/developer/loss.md | 21 ++------ docs/src/developer/optimizer.md | 8 --- docs/src/developer/sem.md | 17 +++---- docs/src/internals/types.md | 20 +++++--- docs/src/performance/mixed_differentiation.md | 18 +++---- docs/src/performance/simulation.md | 17 +------ docs/src/tutorials/collection/collection.md | 51 +++++++++++++------ docs/src/tutorials/collection/multigroup.md | 17 ++++--- 9 files changed, 91 insertions(+), 108 deletions(-) diff --git a/docs/src/developer/implied.md b/docs/src/developer/implied.md index 056cd663..6321decb 100644 --- a/docs/src/developer/implied.md +++ b/docs/src/developer/implied.md @@ -13,9 +13,9 @@ end and a method to update!: ```julia -import StructuralEquationModels: objective! +import StructuralEquationModels: update! 
-function update!(targets::EvaluationTargets, implied::MyImplied, model::AbstractSemSingle, params) +function update!(targets::EvaluationTargets, implied::MyImplied, params) if is_objective_required(targets) ... @@ -31,11 +31,9 @@ function update!(targets::EvaluationTargets, implied::MyImplied, model::Abstract end ``` -As you can see, `update` gets passed as a first argument `targets`, which is telling us whether the objective value, gradient, and/or hessian are needed. +As you can see, `update!` gets passed as a first argument `targets`, which is telling us whether the objective value, gradient, and/or hessian are needed. We can then use the functions `is_..._required` and conditional on what the optimizer needs, we can compute and store things we want to make available to the loss functions. For example, as we have seen in [Second example - maximum likelihood](@ref), the `RAM` implied type computes the model-implied covariance matrix and makes it available via `implied.Σ`. - - Just as described in [Custom loss functions](@ref), you may define a constructor. Typically, this will depend on the `specification = ...` argument that can be a `ParameterTable` or a `RAMMatrices` object. We implement an `ImpliedEmpty` type in our package that does nothing but serving as an `implied` field in case you are using a loss function that does not need any implied type at all. You may use it as a template for defining your own implied type, as it also shows how to handle the specification objects: @@ -56,7 +54,7 @@ Empty placeholder for models that don't need an implied part. - `specification`: either a `RAMMatrices` or `ParameterTable` object # Examples -A multigroup model with ridge regularization could be specified as a `SemEnsemble` with one +A multigroup model with ridge regularization could be specified as a `Sem` with one model per group and an additional model with `ImpliedEmpty` and `SemRidge` for the regularization part. 
# Extended help @@ -75,26 +73,20 @@ end ### Constructors ############################################################################################ -function ImpliedEmpty(; - specification, - meanstruct = NoMeanStruct(), - hessianeval = ExactHessian(), +function ImpliedEmpty( + spec::SemSpecification; + hessianeval::HessianApprox = ExactHessian(), kwargs..., ) - return ImpliedEmpty(hessianeval, meanstruct, convert(RAMMatrices, specification)) + ram_matrices = convert(RAMMatrices, spec) + return ImpliedEmpty(hessianeval, MeanStruct(ram_matrices), ram_matrices) end ############################################################################################ ### methods ############################################################################################ -update!(targets::EvaluationTargets, implied::ImpliedEmpty, par, model) = nothing - -############################################################################################ -### Recommended methods -############################################################################################ - -update_observed(implied::ImpliedEmpty, observed::SemObserved; kwargs...) = implied +update!(targets::EvaluationTargets, implied::ImpliedEmpty, par) = nothing ``` -As you see, similar to [Custom loss functions](@ref) we implement a method for `update_observed`. \ No newline at end of file +As you see, similar to [Custom loss functions](@ref) we implement a constructor. 
\ No newline at end of file diff --git a/docs/src/developer/loss.md b/docs/src/developer/loss.md index d6949842..aa6a1e17 100644 --- a/docs/src/developer/loss.md +++ b/docs/src/developer/loss.md @@ -11,9 +11,9 @@ Since we allow for the optimization of sums of loss functions, and the maximum l using StructuralEquationModels ``` -To define a new loss function, you have to define a new type that is a subtype of `SemLossFunction`: +To define a new loss function, you have to define a new type that is a subtype of `AbstractLoss`: ```@example loss -struct Ridge <: SemLossFunction +struct MyRidge <: AbstractLoss α I end @@ -25,8 +25,8 @@ Additionaly, we need to define a *method* of the function `evaluate!` to compute ```@example loss import StructuralEquationModels: evaluate! -evaluate!(objective::Number, gradient::Nothing, hessian::Nothing, ridge::Ridge, model::AbstractSem, par) = - ridge.α * sum(i -> par[i]^2, ridge.I) +evaluate!(objective::Number, gradient::Nothing, hessian::Nothing, ridge::MyRidge, par) = + ridge.α * sum(i -> abs2(par[i]), ridge.I) ``` The function `evaluate!` recognizes by the types of the arguments `objective`, `gradient` and `hessian` whether it should compute the objective value, gradient or hessian of the model w.r.t. the parameters. @@ -98,7 +98,7 @@ function evaluate!(objective, gradient, hessian::Nothing, ridge::Ridge, model::A gradient[ridge.I] .= 2 * ridge.α * par[ridge.I] end # compute objective - if !isnothing(objective) + if !isnothing(objective) return ridge.α * sum(i -> par[i]^2, ridge.I) end end @@ -166,17 +166,6 @@ end ## Additional functionality -### Update observed data - -If you are planing a simulation study where you have to fit the **same model** to many **different datasets**, it is computationally beneficial to not build the whole model completely new everytime you change your data. -Therefore, we provide a function to update the data of your model, `replace_observed(model(semfit); data = new_data)`. 
However, we can not know beforehand in what way your loss function depends on the specific datasets. The solution is to provide a method for `update_observed`. Since `Ridge` does not depend on the data at all, this is quite easy: - -```julia -import StructuralEquationModels: update_observed - -update_observed(ridge::Ridge, observed::SemObserved; kwargs...) = ridge -``` - ### Access additional information If you want to provide a way to query information about loss functions of your type, you can provide functions for that: diff --git a/docs/src/developer/optimizer.md b/docs/src/developer/optimizer.md index b5c9a6e0..4659ba5d 100644 --- a/docs/src/developer/optimizer.md +++ b/docs/src/developer/optimizer.md @@ -25,12 +25,6 @@ struct MyoptResult{O <: SemOptimizerMyopt} <: SEM.SemOptimizerResult{O} ... end -############################################################################################ -### Recommended methods -############################################################################################ - -update_observed(optimizer::SemOptimizerMyopt, observed::SemObserved; kwargs...) = optimizer - ############################################################################################ ### additional methods ############################################################################################ @@ -43,8 +37,6 @@ and `SEM.sem_optimizer_subtype(::Val{:Myopt})` returns `SemOptimizerMyopt`. This instructs *SEM.jl* to use `SemOptimizerMyopt` when `:Myopt` is specified as the engine for model fitting: `fit(..., engine = :Myopt)`. -A method for `update_observed` and additional methods might be usefull, but are not necessary. - Now comes the essential part: we need to provide the [`fit`](@ref) method with `SemOptimizerMyopt` as the first positional argument. 
diff --git a/docs/src/developer/sem.md b/docs/src/developer/sem.md index c54ff26a..bb077f43 100644 --- a/docs/src/developer/sem.md +++ b/docs/src/developer/sem.md @@ -1,13 +1,14 @@ # Custom model types -The abstract supertype for all models is `AbstractSem`, which has two subtypes, `AbstractSemSingle{O, I, L}` and `AbstractSemCollection`. Currently, there are 2 subtypes of `AbstractSemSingle`: `Sem`, `SemFiniteDiff`. All subtypes of `AbstractSemSingle` should have at least observed, implied, loss and optimizer fields, and share their types (`{O, I, L}`) with the parametric abstract supertype. For example, the `SemFiniteDiff` type is implemented as +The abstract supertype for all models is [`AbstractSem`](@ref). Currently, there are 2 concrete subtypes: +`Sem{L <: Tuple}` and `SemFiniteDiff{S <: AbstractSem}`. +A `Sem` model holds a tuple of `LossTerm`s (each wrapping an `AbstractLoss`) and a vector of parameter labels. Both single-group and multigroup models are represented as `Sem`. + +`SemFiniteDiff` wraps any `AbstractSem` and substitutes dedicated gradient/hessian evaluation with finite difference approximation: ```julia -struct SemFiniteDiff{O <: SemObserved, I <: SemImplied, L <: SemLoss} <: - AbstractSemSingle{O, I, L} - observed::O - implied::I - loss::L +struct SemFiniteDiff{S <: AbstractSem} <: AbstractSem + model::S end ``` @@ -17,6 +18,4 @@ Additionally, you can change how objective/gradient/hessian values are computed evaluate!(objective, gradient, hessian, model::SemFiniteDiff, params) = ... ``` -Additionally, we can define constructors like the one in `"src/frontend/specification/Sem.jl"`. - -It is also possible to add new subtypes for `AbstractSemCollection`. \ No newline at end of file +Additionally, we can define constructors like the one in `"src/frontend/specification/Sem.jl"`. 
\ No newline at end of file diff --git a/docs/src/internals/types.md b/docs/src/internals/types.md index e70a52ca..4b4cd4fa 100644 --- a/docs/src/internals/types.md +++ b/docs/src/internals/types.md @@ -2,12 +2,16 @@ The type hierarchy is implemented in `"src/types.jl"`. -`AbstractSem`: the most abstract type in our package -- `AbstractSemSingle{O, I, L} <: AbstractSem` is an abstract parametric type that is a supertype of all single models - - `Sem`: models that do not need automatic differentiation or finite difference approximation - - `SemFiniteDiff`: models whose gradients and/or hessians should be computed via finite difference approximation -- `AbstractSemCollection <: AbstractSem` is an abstract supertype of all models that contain multiple `AbstractSem` submodels +[`AbstractLoss`](@ref): is the base abstract type for all loss functions: +- `SemLoss{O <: SemObserved, I <: SemImplied}`: is the subtype of `AbstractLoss`, which is the + base for all SEM-specific loss functions ([`SemML`](@ref), [`SemWLS`](@ref) etc) that + evaluate how closely the implied covariation structure (represented by the object of type `I`) + matches the observed one (contained in the object of type `O`); +- regularizing terms (e.g. [`SemRidge`](@ref)) are implemented as subtypes of `AbstractLoss`. -Every `AbstractSemSingle` has to have `SemObserved`, `SemImplied`, and `SemLoss` fields (and can have additional fields). - -`SemLoss` is a container for multiple `SemLossFunctions`. \ No newline at end of file +[`AbstractSem`](@ref) is the base abstract type for all SEM models. It has two concrete subtypes: +- `Sem{L <: Tuple} <: AbstractSem`: the main SEM model type that implements a list of weighted +loss terms (using [`LossTerm`](@ref) wrapper around `AbstractLoss`) and allows modeling both single +and multi-group SEMs and combining them with regularization terms. 
+- `SemFiniteDiff{S <: AbstractSem} <: AbstractSem`: a wrapper around any `AbstractSem` that + substitutes dedicated gradient/hessian evaluation with finite difference approximation. diff --git a/docs/src/performance/mixed_differentiation.md b/docs/src/performance/mixed_differentiation.md index b7ae333b..f33fa6ab 100644 --- a/docs/src/performance/mixed_differentiation.md +++ b/docs/src/performance/mixed_differentiation.md @@ -2,22 +2,20 @@ This way of specifying our model is not ideal, however, because now also the maximum likelihood loss function lives inside a `SemFiniteDiff` model, and this means even though we have defined analytical gradients for it, we do not make use of them. -A more efficient way is therefore to specify our model as an ensemble model: +A more efficient way is therefore to specify our model as a combined model with multiple loss terms: ```julia -model_ml = Sem( - specification = partable, - data = data, - loss = SemML +ml_term = SemML( + SemObservedData(data = data, specification = partable), + RAMSymbolic(partable) ) -model_ridge = SemFiniteDiff( - specification = partable, - data = data, - loss = myridge +ridge_term = SemRidge( + α_ridge = 0.01, + which_ridge = params(ml_term) ) -model_ml_ridge = SemEnsemble(model_ml, model_ridge) +model_ml_ridge = Sem(ml_term, ridge_term) model_ml_ridge_fit = fit(model_ml_ridge) ``` diff --git a/docs/src/performance/simulation.md b/docs/src/performance/simulation.md index 85a0c0a0..61a9d5ad 100644 --- a/docs/src/performance/simulation.md +++ b/docs/src/performance/simulation.md @@ -57,19 +57,7 @@ model = Sem( data = data_1 ) -model_updated = replace_observed(model; data = data_2, specification = partable) -``` - -If you are building your models by parts, you can also update each part seperately with the function `update_observed`. 
-For example, - -```@example replace_observed - -new_observed = SemObservedData(;data = data_2, specification = partable) - -my_optimizer = SemOptimizer() - -new_optimizer = update_observed(my_optimizer, new_observed) +model_updated = replace_observed(model, data_2) ``` ## Multithreading @@ -90,7 +78,7 @@ model1 = Sem( data = data_1 ) -model2 = deepcopy(replace_observed(model; data = data_2, specification = partable)) +model2 = deepcopy(replace_observed(model, data_2)) models = [model1, model2] fits = Vector{SemFit}(undef, 2) @@ -104,5 +92,4 @@ end ```@docs replace_observed -update_observed ``` \ No newline at end of file diff --git a/docs/src/tutorials/collection/collection.md b/docs/src/tutorials/collection/collection.md index f60b7312..2a8ea92c 100644 --- a/docs/src/tutorials/collection/collection.md +++ b/docs/src/tutorials/collection/collection.md @@ -1,31 +1,52 @@ # Collections -With StructuralEquationModels.jl, you can fit weighted sums of structural equation models. -The most common use case for this are [Multigroup models](@ref). -Another use case may be optimizing the sum of loss functions for some of which you do know the analytic gradient, but not for others. -In this case, you can optimize the sum of a `Sem` and a `SemFiniteDiff` (or any other differentiation method). +With *StructuralEquationModels.jl*, you can fit weighted sums of structural equation models. +The most common use case for this are [Multigroup models](@ref). +Another use case may be optimizing the sum of loss functions for some of which you do know the analytic gradient, but not for others. +In this case, [`FiniteDiffWrapper`](@ref) can generate a wrapper around the specific `SemLoss` term. The wrapper loss term will +only use the objective of the original term to calculate its gradient using finite difference approximation. 
-To use this feature, you have to construct a `SemEnsemble` model, which is actually quite easy: +```julia +loss_1 = SemML(observed_1, implied_1) +loss_2 = SemML(observed_2, implied_2) +loss_2_findiff = FiniteDiffWrapper(loss_2) +``` + +To construct `Sem` from the the individual `SemLoss` (or other `AbstractLoss`) terms, they are +just passed to the `Sem` constructor: ```julia -# models -model_1 = Sem(...) +model = Sem(loss_1, loss_2) +model_findiff = Sem(loss_1, loss_2_findiff) +``` + +It is also possible to use finite difference for the entire `Sem` model: -model_2 = SemFiniteDiff(...) +```julia +model_findiff2 = FiniteDiffWrapper(model) +``` -model_3 = Sem(...) +The weighting scheme of the SEM loss terms is specified using `default_set_weights` argument of the `Sem` constructor. +The `:nsamples` scheme (the default) weights SEM terms by ``N_{term}/N_{total}``, i.e. each term is weighted by the number +of observations in its data (which matches the formula for multigroup models). +The weights for the loss terms (both SEM and regularization) can be explicitly specified the pair syntax `loss => weight`: -model_ensemble = SemEnsemble(model_1, model_2, model_3) +```julia +model_weighted = Sem(loss_1 => 0.5, loss_2 => 1.0) ``` -So you just construct the individual models (however you like) and pass them to `SemEnsemble`. -You may also pass a vector of weigths to `SemEnsemble`. By default, those are set to ``N_{model}/N_{total}``, i.e. each model is weighted by the number of observations in it's data (which matches the formula for multigroup models). +`Sem` support assigning unique identifier to each loss term, which is essential for complex multi-term model. +The syntax is `id => loss`, or `id => loss => weight`: -Multigroup models can also be specified via the graph interface; for an example, see [Multigroup models](@ref). 
+```julia +model2 = Sem(:main => loss_1, :alt => loss_2) +model2_weighted = Sem(:main => loss_1 => 0.5, :alt => loss_2 => 1.0) +``` # API - collections ```@docs -SemEnsemble -AbstractSemCollection +Sem +LossTerm +FiniteDiffWrapper ``` \ No newline at end of file diff --git a/docs/src/tutorials/collection/multigroup.md b/docs/src/tutorials/collection/multigroup.md index 16d3dcd7..04f1893d 100644 --- a/docs/src/tutorials/collection/multigroup.md +++ b/docs/src/tutorials/collection/multigroup.md @@ -4,19 +4,20 @@ using StructuralEquationModels ``` -As an example, we will fit the model from [the `lavaan` tutorial](https://lavaan.ugent.be/tutorial/groups.html) with loadings constrained to equality across groups. +As an example, we will fit the model from [the `lavaan` tutorial](https://lavaan.ugent.be/tutorial/groups.html) +with loadings constrained to equality across groups. -We first load the example data. +We first load the example data. We have to make sure that the column indicating the group (here called `school`) is a vector of `Symbol`s, not strings - so we convert it. 
```@setup mg dat = example_data("holzinger_swineford") -dat.school = ifelse.(dat.school .== "Pasteur", :Pasteur, :Grant_White) +dat.school = Symbol.(replace.(dat.school, "-" => "_")) ``` ```julia dat = example_data("holzinger_swineford") -dat.school = ifelse.(dat.school .== "Pasteur", :Pasteur, :Grant_White) +dat.school = Symbol.(replace.(dat.school, "-" => "_")) ``` We then specify our model via the graph interface: @@ -59,19 +60,19 @@ You can then use the resulting graph to specify an `EnsembleParameterTable` groups = [:Pasteur, :Grant_White] partable = EnsembleParameterTable( - graph, + graph, observed_vars = observed_vars, latent_vars = latent_vars, groups = groups) ``` -The parameter table can be used to create a `SemEnsemble` model: +The parameter table can be used to create a multigroup `Sem` model: ```@example mg; ansicolor = true -model_ml_multigroup = SemEnsemble( +model_ml_multigroup = Sem( specification = partable, data = dat, - column = :school, + semterm_column = :school, groups = groups) ``` From 240e3cd42058aa3ccecb2c44dcc521de0bfe116c Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 22 Mar 2026 13:19:46 -0700 Subject: [PATCH 15/74] test: fix formatting --- test/examples/multigroup/build_models.jl | 12 +++-------- .../political_democracy/constructor.jl | 20 ++++--------------- 2 files changed, 7 insertions(+), 25 deletions(-) diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index 329c5502..6811cb40 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -19,17 +19,11 @@ model_ml_multigroup = Sem( end # replace observed using Dict of data matrices -model_ml_multigroup3 = replace_observed( - model_ml_multigroup, - Dict(:Pasteur => dat_g1, :Grant_White => dat_g2), -) +model_ml_multigroup3 = + replace_observed(model_ml_multigroup, Dict(:Pasteur => dat_g1, :Grant_White => dat_g2)) # replace observed using DataFrame with group column 
-model_ml_multigroup4 = replace_observed( - model_ml_multigroup, - dat; - semterm_column = :school, -) +model_ml_multigroup4 = replace_observed(model_ml_multigroup, dat; semterm_column = :school) # gradients @testset "ml_gradients_multigroup" begin diff --git a/test/examples/political_democracy/constructor.jl b/test/examples/political_democracy/constructor.jl index 759875b2..48ba1b96 100644 --- a/test/examples/political_democracy/constructor.jl +++ b/test/examples/political_democracy/constructor.jl @@ -167,14 +167,8 @@ end # set seed for simulation Random.seed!(83472834) # simulate data - model_ml_new = replace_observed( - model_ml, - rand(model_ml, params, 1_000_000), - ) - model_ml_sym_new = replace_observed( - model_ml_sym, - rand(model_ml_sym, params, 1_000_000), - ) + model_ml_new = replace_observed(model_ml, rand(model_ml, params, 1_000_000)) + model_ml_sym_new = replace_observed(model_ml_sym, rand(model_ml_sym, params, 1_000_000)) # fit models sol_ml = solution(fit(semoptimizer, model_ml_new)) sol_ml_sym = solution(fit(semoptimizer, model_ml_sym_new)) @@ -376,14 +370,8 @@ end # set seed for simulation Random.seed!(83472834) # simulate data - model_ml_new = replace_observed( - model_ml, - rand(model_ml, params, 1_000_000), - ) - model_ml_sym_new = replace_observed( - model_ml_sym, - rand(model_ml_sym, params, 1_000_000), - ) + model_ml_new = replace_observed(model_ml, rand(model_ml, params, 1_000_000)) + model_ml_sym_new = replace_observed(model_ml_sym, rand(model_ml_sym, params, 1_000_000)) # fit models sol_ml = solution(fit(semoptimizer, model_ml_new)) sol_ml_sym = solution(fit(semoptimizer, model_ml_sym_new)) From a277cb0e87a77e9b6772996aaac10dc0b588e78a Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 22 Mar 2026 20:54:23 -0700 Subject: [PATCH 16/74] fit_measures(): support vectors of funcs also add CFI to the list --- src/frontend/fit/fitmeasures/fit_measures.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git 
a/src/frontend/fit/fitmeasures/fit_measures.jl b/src/frontend/fit/fitmeasures/fit_measures.jl index 185b348c..7fabc950 100644 --- a/src/frontend/fit/fitmeasures/fit_measures.jl +++ b/src/frontend/fit/fitmeasures/fit_measures.jl @@ -1,6 +1,8 @@ -fit_measures(fit) = fit_measures(fit, nparams, dof, AIC, BIC, RMSEA, χ², p_value, minus2ll) +const DEFAULT_FIT_MEASURES = [AIC, BIC, dof, χ², p_value, nparams, RMSEA, CFI] -fit_measures(fit, measures...) = Dict(Symbol(fn) => fn(fit) for fn in measures) +fit_measures(fit, measures::AbstractVector) = Dict(Symbol(fn) => fn(fit) for fn in measures) +fit_measures(fit, measures...) = fit_measures(fit, measures) +fit_measures(fit) = fit_measures(fit, DEFAULT_FIT_MEASURES) """ fit_measures(fit::SemFit, measures...) -> Dict{Symbol} @@ -20,6 +22,7 @@ fit_measures(semfit, nparams, dof, p_value) ``` # See also -[`AIC`](@ref), [`BIC`](@ref), [`RMSEA`](@ref), [`χ²`](@ref), [`p_value`](@ref), [`minus2ll`](@ref) +[`AIC`](@ref), [`BIC`](@ref), [`RMSEA`](@ref), [`χ²`](@ref), [`p_value`](@ref), +[`minus2ll`](@ref), [`CFI`](@ref) """ fit_measures From 60dbdc7a250a9e289d928e2847823f8eede82834 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 22 Mar 2026 20:58:00 -0700 Subject: [PATCH 17/74] test_fitmeasures(): refactor/simplify --- test/examples/helper.jl | 34 ++++++---- test/examples/multigroup/build_models.jl | 47 ++------------ test/examples/political_democracy/by_parts.jl | 36 ++--------- .../political_democracy/constructor.jl | 64 ++++--------------- 4 files changed, 42 insertions(+), 139 deletions(-) diff --git a/test/examples/helper.jl b/test/examples/helper.jl index f14fec62..fed95f3c 100644 --- a/test/examples/helper.jl +++ b/test/examples/helper.jl @@ -49,7 +49,8 @@ function test_hessian(model, params; rtol = 1e-4, atol = 0) @test hessian ≈ true_hessian rtol = rtol atol = atol end -fitmeasure_names_ml = Dict( +# map from the SEM.jl name of the fit measure to the lavaan's one +fitmeasure_semjl_to_lavaan = Dict( :AIC => "aic", 
:BIC => "bic", :dof => "df", @@ -57,26 +58,31 @@ fitmeasure_names_ml = Dict( :p_value => "pvalue", :nparams => "npar", :RMSEA => "rmsea", -) - -fitmeasure_names_ls = Dict( - :dof => "df", - :χ² => "chisq", - :p_value => "pvalue", - :nparams => "npar", - :RMSEA => "rmsea", + :CFI => "cfi", ) function test_fitmeasures( - measures, + fitted::SemFit, measures_lav; + fitmeasures::AbstractVector = SEM.DEFAULT_FIT_MEASURES, + fitted_baseline::Union{SemFit, Nothing} = nothing, rtol = 1e-4, atol = 0, - fitmeasure_names = fitmeasure_names_ml, ) - @testset "$name" for (key, name) in pairs(fitmeasure_names) - measure_lav = measures_lav.x[findfirst(==(name), measures_lav[!, 1])] - @test measures[key] ≈ measure_lav rtol = rtol atol = atol + @testset "$fn" for fn in fitmeasures + name = Symbol(fn) + # FIML CFI requires the baseline model + measure = + fn != CFI || isnothing(fitted_baseline) ? fn(fitted) : + fn(fitted, fitted_baseline) + lav_name = fitmeasure_semjl_to_lavaan[name] + lav_ix = findfirst(==(lav_name), measures_lav[!, 1]) + if isnothing(lav_ix) + @test ismissing(measure) + else + measure_lav = measures_lav.x[lav_ix] + @test measure ≈ measure_lav rtol = rtol atol = atol + end end end diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index 6811cb40..71128760 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -52,17 +52,7 @@ end @testset "fitmeasures/se_ml" begin solution_ml = fit(semoptimizer, model_ml_multigroup) - test_fitmeasures( - fit_measures(solution_ml), - solution_lav[:fitmeasures_ml]; - rtol = 1e-2, - atol = 1e-7, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ml)), - solution_lav[:fitmeasures_ml]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) + test_fitmeasures(solution_ml, solution_lav[:fitmeasures_ml]; rtol = 1e-2, atol = 1e-7) update_se_hessian!(partable, solution_ml) test_estimates( @@ -122,17 +112,7 @@ end @testset "fitmeasures/se_ml | sorted" begin 
solution_ml = fit(semoptimizer, model_ml_multigroup) - test_fitmeasures( - fit_measures(solution_ml), - solution_lav[:fitmeasures_ml]; - rtol = 1e-2, - atol = 1e-7, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ml)), - solution_lav[:fitmeasures_ml]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) + test_fitmeasures(solution_ml, solution_lav[:fitmeasures_ml]; rtol = 1e-2, atol = 1e-7) update_se_hessian!(partable_s, solution_ml) test_estimates( @@ -221,18 +201,7 @@ end @testset "fitmeasures/se_ls" begin solution_ls = fit(semoptimizer, model_ls_multigroup) - test_fitmeasures( - fit_measures(solution_ls), - solution_lav[:fitmeasures_ls]; - fitmeasure_names = fitmeasure_names_ls, - rtol = 1e-2, - atol = 1e-5, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ls)), - solution_lav[:fitmeasures_ls]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) + test_fitmeasures(solution_ls, solution_lav[:fitmeasures_ls]; rtol = 1e-2, atol = 1e-5) @suppress update_se_hessian!(partable, solution_ls) test_estimates( @@ -319,18 +288,14 @@ if !isnothing(specification_miss_g1) @testset "fitmeasures/se_fiml" begin solution = fit(semoptimizer, model_ml_multigroup) + solution_varonly = fit(semoptimizer, model_ml_varonly) test_fitmeasures( - fit_measures(solution), + solution, solution_lav[:fitmeasures_fiml]; + fitted_baseline = solution_varonly, rtol = 1e-3, atol = 0, ) - solution_varonly = fit(semoptimizer, model_ml_varonly) - test_fitmeasures( - Dict(:CFI => CFI(solution, solution_varonly)), - solution_lav[:fitmeasures_fiml]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) update_se_hessian!(partable_miss, solution) test_estimates( partable_miss, diff --git a/test/examples/political_democracy/by_parts.jl b/test/examples/political_democracy/by_parts.jl index 6866eead..ef634a59 100644 --- a/test/examples/political_democracy/by_parts.jl +++ b/test/examples/political_democracy/by_parts.jl @@ -90,12 +90,7 @@ end @testset "fitmeasures/se_ml" begin solution_ml = fit(semoptimizer, model_ml) - 
test_fitmeasures(fit_measures(solution_ml), solution_lav[:fitmeasures_ml]; atol = 1e-3) - test_fitmeasures( - Dict(:CFI => CFI(solution_ml)), - solution_lav[:fitmeasures_ml]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) + test_fitmeasures(solution_ml, solution_lav[:fitmeasures_ml]; atol = 1e-3) update_se_hessian!(partable, solution_ml) test_estimates( @@ -109,14 +104,7 @@ end @testset "fitmeasures/se_ls" begin solution_ls = fit(semoptimizer, model_ls_sym) - fm = fit_measures(solution_ls) - test_fitmeasures( - merge(fm, Dict(:CFI => CFI(solution_ls))), - solution_lav[:fitmeasures_ls]; - atol = 1e-3, - fitmeasure_names = merge(fitmeasure_names_ls, Dict(:CFI => "cfi")) - ) - @test ismissing(fm[:AIC]) && ismissing(fm[:BIC]) && ismissing(fm[:minus2ll]) + test_fitmeasures(solution_ls, solution_lav[:fitmeasures_ls]; atol = 1e-3) @suppress update_se_hessian!(partable, solution_ls) test_estimates( @@ -241,16 +229,7 @@ end @testset "fitmeasures/se_ml_mean" begin solution_ml = fit(semoptimizer, model_ml) - test_fitmeasures( - fit_measures(solution_ml), - solution_lav[:fitmeasures_ml_mean]; - atol = 1e-3, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ml)), - solution_lav[:fitmeasures_ml_mean]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) + test_fitmeasures(solution_ml, solution_lav[:fitmeasures_ml_mean]; atol = 1e-3) update_se_hessian!(partable_mean, solution_ml) test_estimates( @@ -264,14 +243,7 @@ end @testset "fitmeasures/se_ls_mean" begin solution_ls = fit(semoptimizer, model_ls) - fm = fit_measures(solution_ls) - test_fitmeasures( - merge(fm, Dict(:CFI => CFI(solution_ls))), - solution_lav[:fitmeasures_ls_mean]; - atol = 1e-3, - fitmeasure_names = merge(fitmeasure_names_ls, Dict(:CFI => "cfi")), - ) - @test ismissing(fm[:AIC]) && ismissing(fm[:BIC]) && ismissing(fm[:minus2ll]) + test_fitmeasures(solution_ls, solution_lav[:fitmeasures_ls_mean]; atol = 1e-3) @suppress update_se_hessian!(partable_mean, solution_ls) test_estimates( diff --git 
a/test/examples/political_democracy/constructor.jl b/test/examples/political_democracy/constructor.jl index 48ba1b96..2efa5abe 100644 --- a/test/examples/political_democracy/constructor.jl +++ b/test/examples/political_democracy/constructor.jl @@ -100,12 +100,7 @@ end @testset "fitmeasures/se_ml" begin solution_ml = fit(semoptimizer, model_ml) - test_fitmeasures(fit_measures(solution_ml), solution_lav[:fitmeasures_ml]; atol = 1e-3) - test_fitmeasures( - Dict(:CFI => CFI(solution_ml)), - solution_lav[:fitmeasures_ml]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) + test_fitmeasures(solution_ml, solution_lav[:fitmeasures_ml]; atol = 1e-3) update_se_hessian!(partable, solution_ml) test_estimates( @@ -122,20 +117,7 @@ end @testset "fitmeasures/se_ls" begin solution_ls = fit(semoptimizer, model_ls_sym) - fm = fit_measures(solution_ls) - test_fitmeasures( - fm, - solution_lav[:fitmeasures_ls]; - atol = 1e-3, - fitmeasure_names = fitmeasure_names_ls, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ls)), - solution_lav[:fitmeasures_ls]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) - - @test ismissing(fm[:AIC]) && ismissing(fm[:BIC]) && ismissing(fm[:minus2ll]) + test_fitmeasures(solution_ls, solution_lav[:fitmeasures_ls]; atol = 1e-3) @suppress update_se_hessian!(partable, solution_ls) test_estimates( @@ -298,16 +280,7 @@ end @testset "fitmeasures/se_ml_mean" begin solution_ml = fit(semoptimizer, model_ml) - test_fitmeasures( - fit_measures(solution_ml), - solution_lav[:fitmeasures_ml_mean]; - atol = 0.002, - ) - test_fitmeasures( - Dict(:CFI => CFI(solution_ml)), - solution_lav[:fitmeasures_ml_mean]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) + test_fitmeasures(solution_ml, solution_lav[:fitmeasures_ml_mean]; atol = 0.002) update_se_hessian!(partable_mean, solution_ml) test_estimates( @@ -324,19 +297,7 @@ end @testset "fitmeasures/se_ls_mean" begin solution_ls = fit(semoptimizer, model_ls) - fm = fit_measures(solution_ls) - test_fitmeasures( - fm, - 
solution_lav[:fitmeasures_ls_mean]; - atol = 1e-3, - fitmeasure_names = fitmeasure_names_ls, - ) - @test ismissing(fm[:AIC]) && ismissing(fm[:BIC]) && ismissing(fm[:minus2ll]) - test_fitmeasures( - Dict(:CFI => CFI(solution_ls)), - solution_lav[:fitmeasures_ls_mean]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) + test_fitmeasures(solution_ls, solution_lav[:fitmeasures_ls_mean]; atol = 1e-3) @suppress update_se_hessian!(partable_mean, solution_ls) test_estimates( @@ -410,6 +371,8 @@ if !ismissing(spec_varonly) loss = SemFIML, meanstructure = true, ) +else + model_varonly = nothing end ############################################################################################ @@ -446,19 +409,16 @@ end @testset "fitmeasures/se_fiml" begin solution_ml = fit(semoptimizer, model_ml) + solution_varonly = + !isnothing(model_varonly) ? fit(semoptimizer, model_varonly) : nothing test_fitmeasures( - fit_measures(solution_ml), + solution_ml, solution_lav[:fitmeasures_fiml]; + fitted_baseline = solution_varonly, + fitmeasures = !isnothing(solution_varonly) ? 
SEM.DEFAULT_FIT_MEASURES : + filter(!=(CFI), SEM.DEFAULT_FIT_MEASURES), atol = 1e-3, ) - if !ismissing(spec_varonly) - solution_varonly = fit(semoptimizer, model_varonly) - test_fitmeasures( - Dict(:CFI => CFI(solution_ml, solution_varonly)), - solution_lav[:fitmeasures_fiml]; - fitmeasure_names = Dict(:CFI => "cfi"), - ) - end update_se_hessian!(partable_mean, solution_ml) test_estimates( From 05abcd9c7f2170d8fe89cb3ffd1edc3d2f5035da Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 22 Mar 2026 23:51:27 -0700 Subject: [PATCH 18/74] test/multigroup: small tweaks --- test/examples/multigroup/build_models.jl | 2 +- test/examples/multigroup/multigroup.jl | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index 71128760..6c22a453 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -240,7 +240,7 @@ if !isnothing(specification_miss_g1) model_ml_varonly = Sem( specification = partable_varonly, - data = dat_missing, + data = dat_miss, semterm_column = :school, loss = SemFIML, observed = SemObservedMissing, diff --git a/test/examples/multigroup/multigroup.jl b/test/examples/multigroup/multigroup.jl index c8ae8c1f..35fe20e6 100644 --- a/test/examples/multigroup/multigroup.jl +++ b/test/examples/multigroup/multigroup.jl @@ -10,17 +10,18 @@ const SEM = StructuralEquationModels include(joinpath(chop(dirname(pathof(SEM)), tail = 3), "test/examples/helper.jl")) dat = example_data("holzinger_swineford") -dat_missing = example_data("holzinger_swineford_missing") -solution_lav = example_data("holzinger_swineford_solution") +dat.school = Symbol.(replace.(dat.school, "-" => "_")) + +dat_miss = example_data("holzinger_swineford_missing") +dat_miss.school = Symbol.(replace.(dat_miss.school, "-" => "_")) -dat_g1 = dat[dat.school .== "Pasteur", :] -dat_g2 = dat[dat.school .== "Grant-White", :] +solution_lav = 
example_data("holzinger_swineford_solution") -dat_miss_g1 = dat_missing[dat_missing.school .== "Pasteur", :] -dat_miss_g2 = dat_missing[dat_missing.school .== "Grant-White", :] +dat_g1 = dat[dat.school .== :Pasteur, :] +dat_g2 = dat[dat.school .== :Grant_White, :] -dat.school = ifelse.(dat.school .== "Pasteur", :Pasteur, :Grant_White) -dat_missing.school = ifelse.(dat_missing.school .== "Pasteur", :Pasteur, :Grant_White) +dat_miss_g1 = dat_miss[dat_miss.school .== :Pasteur, :] +dat_miss_g2 = dat_miss[dat_miss.school .== :Grant_White, :] ############################################################################################ ### specification - RAMMatrices From 91d6f4774acbae0f7ed4cb5b40a39fdf26a259bf Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 30 Mar 2026 15:51:19 -0700 Subject: [PATCH 19/74] finite_diff: replace_observed() calls replace_observed() for the underlying term --- src/frontend/finite_diff.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/frontend/finite_diff.jl b/src/frontend/finite_diff.jl index ee0a9bf9..bbf2cc68 100644 --- a/src/frontend/finite_diff.jl +++ b/src/frontend/finite_diff.jl @@ -2,6 +2,9 @@ _unwrap(wrapper::SemFiniteDiff) = wrapper.model params(wrapper::SemFiniteDiff) = params(wrapper.model) loss_terms(wrapper::SemFiniteDiff) = loss_terms(wrapper.model) +replace_observed(wrapper::SemFiniteDiff, data) = + SemFiniteDiff(replace_observed(wrapper.model, data)) + FiniteDiffLossWrappers = Union{LossFiniteDiff, SemLossFiniteDiff} _unwrap(term::AbstractLoss) = term @@ -9,6 +12,17 @@ _unwrap(wrapper::FiniteDiffLossWrappers) = wrapper.loss implied(wrapper::FiniteDiffLossWrappers) = implied(_unwrap(wrapper)) observed(wrapper::FiniteDiffLossWrappers) = observed(_unwrap(wrapper)) +replace_observed(wrapper::LossFiniteDiff, data) = + LossFiniteDiff(replace_observed(_unwrap(wrapper), data)) + +replace_observed(wrapper::SemLossFiniteDiff, new_observed::SemObserved) = + 
SemLossFiniteDiff(replace_observed(_unwrap(wrapper), new_observed)) + +replace_observed( + wrapper::SemLossFiniteDiff, + data::Union{AbstractMatrix, DataFrame}, +) = SemLossFiniteDiff(replace_observed(_unwrap(wrapper), data)) + FiniteDiffWrapper(model::AbstractSem) = SemFiniteDiff(model) FiniteDiffWrapper(loss::AbstractLoss) = LossFiniteDiff(loss) FiniteDiffWrapper(loss::SemLoss) = SemLossFiniteDiff(loss) From bfd32b4de8fd74c07c878da8433b1397409b7d93 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 30 Mar 2026 15:52:25 -0700 Subject: [PATCH 20/74] replace_observed(): support kwargs --- src/frontend/finite_diff.jl | 17 +++++++++-------- src/frontend/specification/Sem.jl | 19 +++++++++++-------- src/loss/abstract.jl | 13 +++++++------ 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/frontend/finite_diff.jl b/src/frontend/finite_diff.jl index bbf2cc68..0ecd4865 100644 --- a/src/frontend/finite_diff.jl +++ b/src/frontend/finite_diff.jl @@ -2,8 +2,8 @@ _unwrap(wrapper::SemFiniteDiff) = wrapper.model params(wrapper::SemFiniteDiff) = params(wrapper.model) loss_terms(wrapper::SemFiniteDiff) = loss_terms(wrapper.model) -replace_observed(wrapper::SemFiniteDiff, data) = - SemFiniteDiff(replace_observed(wrapper.model, data)) +replace_observed(wrapper::SemFiniteDiff, data; kwargs...) = + SemFiniteDiff(replace_observed(wrapper.model, data; kwargs...)) FiniteDiffLossWrappers = Union{LossFiniteDiff, SemLossFiniteDiff} @@ -12,16 +12,17 @@ _unwrap(wrapper::FiniteDiffLossWrappers) = wrapper.loss implied(wrapper::FiniteDiffLossWrappers) = implied(_unwrap(wrapper)) observed(wrapper::FiniteDiffLossWrappers) = observed(_unwrap(wrapper)) -replace_observed(wrapper::LossFiniteDiff, data) = - LossFiniteDiff(replace_observed(_unwrap(wrapper), data)) +replace_observed(wrapper::LossFiniteDiff, data; kwargs...) 
= + LossFiniteDiff(replace_observed(_unwrap(wrapper), data; kwargs...)) -replace_observed(wrapper::SemLossFiniteDiff, new_observed::SemObserved) = - SemLossFiniteDiff(replace_observed(_unwrap(wrapper), new_observed)) +replace_observed(wrapper::SemLossFiniteDiff, new_observed::SemObserved; kwargs...) = + SemLossFiniteDiff(replace_observed(_unwrap(wrapper), new_observed; kwargs...)) replace_observed( wrapper::SemLossFiniteDiff, - data::Union{AbstractMatrix, DataFrame}, -) = SemLossFiniteDiff(replace_observed(_unwrap(wrapper), data)) + data::Union{AbstractMatrix, DataFrame}; + kwargs..., +) = SemLossFiniteDiff(replace_observed(_unwrap(wrapper), data; kwargs...)) FiniteDiffWrapper(model::AbstractSem) = SemFiniteDiff(model) FiniteDiffWrapper(loss::AbstractLoss) = LossFiniteDiff(loss) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index 42ff2d3e..01f5013c 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -475,18 +475,19 @@ replace_observed(model, new_df; semterm_column = :group) """ function replace_observed end -function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix}) +function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix}; kwargs...) nsem_terms(sem) > 1 && throw( ArgumentError( "Model contains $(nsem_terms(sem)) SEM terms. " * "Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data.", ), ) - updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem)) + updated_terms = + Tuple(replace_observed(term, data; kwargs...) for term in loss_terms(sem)) return Sem(updated_terms...) end -function replace_observed(sem::Sem, data::AbstractDict{Symbol}) +function replace_observed(sem::Sem, data::AbstractDict{Symbol}; kwargs...) 
term_ids = Set( if !isnothing(id(term)) id(term) @@ -507,12 +508,12 @@ function replace_observed(sem::Sem, data::AbstractDict{Symbol}) term_data = get(data, tid, nothing) isnothing(term_data) && throw(ArgumentError("No data provided for SEM term :$tid")) - return replace_observed(term, term_data) + return replace_observed(term, term_data; kwargs...) end return Sem(Tuple(updated_terms)...) end -function replace_observed(sem::Sem, data::AbstractVector) +function replace_observed(sem::Sem, data::AbstractVector; kwargs...) nsem = nsem_terms(sem) nsem == length(data) || throw( ArgumentError( @@ -520,7 +521,7 @@ function replace_observed(sem::Sem, data::AbstractVector) ), ) updated_terms = map(enumerate(loss_terms(sem))) do (i, term) - issemloss(term) ? replace_observed(term, data[i]) : term + issemloss(term) ? replace_observed(term, data[i]; kwargs...) : term end return Sem(Tuple(updated_terms)...) end @@ -529,6 +530,7 @@ function replace_observed( sem::Sem, data::AbstractDataFrame; semterm_column::Union{Symbol, Nothing} = nothing, + kwargs..., ) if isnothing(semterm_column) # single-term shortcut @@ -538,7 +540,8 @@ function replace_observed( "Provide `semterm_column` to specify which DataFrame column identifies the groups.", ), ) - updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem)) + updated_terms = + Tuple(replace_observed(term, data; kwargs...) for term in loss_terms(sem)) return Sem(updated_terms...) end @@ -547,7 +550,7 @@ function replace_observed( g[semterm_column] => group_data for (g, group_data) in pairs(groupby(data, semterm_column)) ) - return replace_observed(sem, terms_data) + return replace_observed(sem, terms_data; kwargs...) 
end ############################################################## diff --git a/src/loss/abstract.jl b/src/loss/abstract.jl index bcd6d62b..2cd9f35d 100644 --- a/src/loss/abstract.jl +++ b/src/loss/abstract.jl @@ -45,7 +45,7 @@ check_observed_vars(sem::SemLoss) = check_observed_vars(observed(sem), implied(s # replace_observed: SemLoss, AbstractLoss, LossTerm ############################################################################################ -function replace_observed(loss::SemLoss, new_observed::SemObserved) +function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...) old_obs = SEM.observed(loss) observed_vars(old_obs) == observed_vars(new_observed) || throw( ArgumentError( @@ -53,19 +53,20 @@ function replace_observed(loss::SemLoss, new_observed::SemObserved) "expected $(observed_vars(old_obs)), got $(observed_vars(new_observed))", ), ) + # the default replace_observed() does not pass through kwargs to the ctor return typeof(loss).name.wrapper(new_observed, SEM.implied(loss)) end -function replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame}) +function replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame}; kwargs...) old_obs = SEM.observed(loss) new_observed = typeof(old_obs).name.wrapper(data = data, observed_vars = observed_vars(old_obs)) - return replace_observed(loss, new_observed) + return replace_observed(loss, new_observed; kwargs...) end # non-SEM loss terms are unchanged -replace_observed(loss::AbstractLoss, ::Any) = loss +replace_observed(loss::AbstractLoss, ::Any; kwargs...) = loss # LossTerm: delegate to inner loss -replace_observed(term::LossTerm, data) = - LossTerm(replace_observed(loss(term), data), id(term), weight(term)) +replace_observed(term::LossTerm, data; kwargs...) 
= + LossTerm(replace_observed(loss(term), data; kwargs...), id(term), weight(term)) From 690d24821a99fb84389ddf7ac5c7b830c880d4fd Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 30 Mar 2026 15:53:29 -0700 Subject: [PATCH 21/74] replace_observed(SemWLS, ...; update_internal_state) the kwarg specifies whether to recalculate weights --- src/loss/WLS/WLS.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/loss/WLS/WLS.jl b/src/loss/WLS/WLS.jl index d04bc346..6c36aadd 100644 --- a/src/loss/WLS/WLS.jl +++ b/src/loss/WLS/WLS.jl @@ -181,3 +181,17 @@ function evaluate!(objective, gradient, hessian, loss::SemWLS, par) return objective end + +function replace_observed( + loss::SemWLS, + new_observed::SemObserved; + update_internal_state::Bool = true, +) + # recompute weight matrices only if update_internal_state=true + return SemWLS( + new_observed, + SEM.implied(loss); + wls_weight_matrix = update_internal_state ? nothing : loss.V, + wls_weight_matrix_mean = update_internal_state ? 
nothing : loss.V_μ, + ) +end From b41e75b758291ad3ee84cdf16222485480c4c739 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 30 Mar 2026 15:54:09 -0700 Subject: [PATCH 22/74] tests/model: replace_observed() kwargs passing --- test/unit_tests/model.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/unit_tests/model.jl b/test/unit_tests/model.jl index 87812fba..e7f229aa 100644 --- a/test/unit_tests/model.jl +++ b/test/unit_tests/model.jl @@ -1,5 +1,7 @@ using StructuralEquationModels, Test, Statistics +const SEM = StructuralEquationModels + dat = example_data("political_democracy") dat_missing = example_data("political_democracy_missing")[:, names(dat)] @@ -73,3 +75,34 @@ end @test @inferred(nsamples(model)) == nsamples(obs) end + +@testset "replace_observed() preserves WLS state through finite-diff wrappers" begin + model = Sem( + specification = ram_matrices, + observed = obs, + implied = RAMSymbolic, + loss = SemWLS, + ) + wls_loss = sem_term(model) + findiff_model = Sem(SEM.FiniteDiffWrapper(wls_loss)) + + new_data = randn(nsamples(obs), nobserved_vars(obs)) + + findiff_model_oldstate = + replace_observed(findiff_model, new_data; update_internal_state = false) + findiff_model_newstate = + replace_observed(findiff_model, new_data; update_internal_state = true) + + loss_orig = SEM._unwrap(sem_term(findiff_model)) + loss_oldstate = SEM._unwrap(sem_term(findiff_model_oldstate)) + loss_newstate = SEM._unwrap(sem_term(findiff_model_newstate)) + + @test loss_orig isa SemWLS + @test loss_oldstate isa SemWLS + @test loss_newstate isa SemWLS + @test loss_orig !== loss_oldstate + @test loss_orig !== loss_newstate + @test loss_oldstate.V === loss_orig.V + @test loss_newstate.V !== loss_orig.V + @test observed_vars(loss_oldstate) == observed_vars(loss_orig) +end From b5e920ac445cb5b602620fefac896d8c9dc58506 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 30 Mar 2026 15:54:29 -0700 Subject: [PATCH 23/74] 
replace_observed(...; recompute_obs_state=true) --- src/frontend/fit/standard_errors/bootstrap.jl | 4 +++- src/loss/WLS/WLS.jl | 8 ++++---- test/unit_tests/model.jl | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/frontend/fit/standard_errors/bootstrap.jl b/src/frontend/fit/standard_errors/bootstrap.jl index ce84e923..0a1f39b4 100644 --- a/src/frontend/fit/standard_errors/bootstrap.jl +++ b/src/frontend/fit/standard_errors/bootstrap.jl @@ -251,6 +251,8 @@ end # Fit one bootstrap replicate: resample, replace observed data, fit. function _fit_bootstrap_sample(sem_model, data, start; engine, fit_kwargs) boot_data = resample_with_replacement(data) - boot_model = replace_observed(sem_model, boot_data) + # we replace the observed data with the bootstrapped one, + # but preserve any internal state that is associated with the original data + boot_model = replace_observed(sem_model, boot_data; recompute_observed_state = true) return fit(boot_model; start_val = start, engine = engine, fit_kwargs...) end diff --git a/src/loss/WLS/WLS.jl b/src/loss/WLS/WLS.jl index 6c36aadd..9acb7de0 100644 --- a/src/loss/WLS/WLS.jl +++ b/src/loss/WLS/WLS.jl @@ -185,13 +185,13 @@ end function replace_observed( loss::SemWLS, new_observed::SemObserved; - update_internal_state::Bool = true, + recompute_observed_state::Bool = true, ) - # recompute weight matrices only if update_internal_state=true + # recompute weight matrices only if recompute_observed_state=true return SemWLS( new_observed, SEM.implied(loss); - wls_weight_matrix = update_internal_state ? nothing : loss.V, - wls_weight_matrix_mean = update_internal_state ? nothing : loss.V_μ, + wls_weight_matrix = recompute_observed_state ? nothing : loss.V, + wls_weight_matrix_mean = recompute_observed_state ? 
nothing : loss.V_μ, ) end diff --git a/test/unit_tests/model.jl b/test/unit_tests/model.jl index e7f229aa..dd1136f4 100644 --- a/test/unit_tests/model.jl +++ b/test/unit_tests/model.jl @@ -89,9 +89,9 @@ end new_data = randn(nsamples(obs), nobserved_vars(obs)) findiff_model_oldstate = - replace_observed(findiff_model, new_data; update_internal_state = false) + replace_observed(findiff_model, new_data; recompute_observed_state = false) findiff_model_newstate = - replace_observed(findiff_model, new_data; update_internal_state = true) + replace_observed(findiff_model, new_data; recompute_observed_state = true) loss_orig = SEM._unwrap(sem_term(findiff_model)) loss_oldstate = SEM._unwrap(sem_term(findiff_model_oldstate)) From 293c88bd71a89f94af29e0817b9282acabd8cab4 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 30 Mar 2026 17:57:10 -0700 Subject: [PATCH 24/74] tests/model: test multi-group data ctor --- src/frontend/specification/Sem.jl | 55 ++++++++++++++++++++++++++++--- test/unit_tests/model.jl | 35 ++++++++++++++++++++ 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index 01f5013c..2d246166 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -335,6 +335,56 @@ function set_field_type_kwargs!(kwargs, observed, implied, loss, O, I) end end +# build ensemble/multi-group observed from the specification and Sem(...) kwargs +# used by Sem(...) and replace_observed() +function build_ensemble_observed(observed_type, spec::EnsembleParameterTable, kwargs) + if !haskey(kwargs, :data) + @warn """ + No data provided for ensemble SEM model. Each SEM term will be constructed with empty data. 
+ To provide data for each term, pass a DataFrame with a column identifying the term groups or a Dict mapping term ids to data + """ + semterms_data = nothing + else + kwdata = kwargs[:data] + if isa(kwdata, AbstractDataFrame) + semterm_col = get(kwargs, :semterm_column, nothing) + isnothing(semterm_col) && + throw(ArgumentError("No semterm_column specified for ensemble data.")) + semterms_data = Dict( + g[semterm_col] => group_data for + (g, group_data) in pairs(groupby(kwdata, semterm_col)) + ) + elseif isa(kwdata, AbstractDict) + semterms_data = kwdata + else + """ + Unsupported data type for ensemble SEM model: $(typeof(kwdata)). + Provide a DataFrame with a column identifying the term groups or a Dict mapping term ids to data. + """ |> + ArgumentError |> + throw + end + unused_term_ids = setdiff(keys(semterms_data), keys(spec.tables)) + isempty(unused_term_ids) || + @warn "Ignoring data with ids=$(collect(unused_term_ids)): no such SEM terms exist" + end + + # construct SemObserved for each term + return Dict( + term_id => begin + term_kwargs = copy(kwargs) + if !isnothing(semterms_data) + term_data = get(semterms_data, term_id, nothing) + isnothing(term_data) && + throw(ArgumentError("No data provided for SEM term :$term_id")) + term_kwargs[:data] = term_data + delete!(term_kwargs, :semterm_column) + end + observed_type(; specification = term_spec, term_kwargs...) + end for (term_id, term_spec) in pairs(spec.tables) + ) +end + # construct Sem fields function get_fields!(kwargs, spec, observed, implied, loss) if !isa(spec, SemSpecification) @@ -344,10 +394,7 @@ function get_fields!(kwargs, spec, observed, implied, loss) # observed if !isa(observed, SemObserved) observed = if spec isa EnsembleParameterTable - Dict( - term_id => observed(; specification = term_spec, kwargs...) for - (term_id, term_spec) in pairs(spec.tables) - ) + build_ensemble_observed(observed, spec, kwargs) else observed(; specification = spec, kwargs...) 
end diff --git a/test/unit_tests/model.jl b/test/unit_tests/model.jl index dd1136f4..93ba5e80 100644 --- a/test/unit_tests/model.jl +++ b/test/unit_tests/model.jl @@ -106,3 +106,38 @@ end @test loss_newstate.V !== loss_orig.V @test observed_vars(loss_oldstate) == observed_vars(loss_orig) end + +@testset "Sem(...; semterm_column=...) splits ensemble data by group" begin + dat_grouped = copy(dat[:, [:x1, :x2]]) + n_g1 = size(dat_grouped, 1) ÷ 2 + dat_grouped.group = [fill(:g1, n_g1); fill(:g2, size(dat_grouped, 1) - n_g1)] + + group_graph = @StenoGraph begin + f1 → fixed(1.0, 1.0) * x1 + label(:λ₂, :λ₂) * x2 + _(Symbol[:x1, :x2]) ↔ _(Symbol[:x1, :x2]) + _(Symbol[:f1]) ↔ _(Symbol[:f1]) + end + + grouped_partable = EnsembleParameterTable( + group_graph; + observed_vars = [:x1, :x2], + latent_vars = [:f1], + groups = [:g1, :g2], + ) + + grouped_model = Sem( + specification = grouped_partable, + data = dat_grouped, + semterm_column = :group, + observed = SemObservedData, + implied = RAM, + loss = SemML, + ) + + term_g1 = only(filter(term -> SEM.id(term) == :g1, SEM.loss_terms(grouped_model))) + term_g2 = only(filter(term -> SEM.id(term) == :g2, SEM.loss_terms(grouped_model))) + + @test nsamples(observed(term_g1)) == n_g1 + @test nsamples(observed(term_g2)) == size(dat_grouped, 1) - n_g1 + @test nsamples(grouped_model) == size(dat_grouped, 1) +end From 7466a239ce40de7e4659f9152ea08403677227a5 Mon Sep 17 00:00:00 2001 From: Maximilian Ernst Date: Wed, 25 Mar 2026 14:41:30 +0100 Subject: [PATCH 25/74] SemFiniteDiff constructor to keep same Syntax --- src/types.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/types.jl b/src/types.jl index 87b733cf..eb251a3b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -138,6 +138,9 @@ struct SemLossFiniteDiff{O, I, L <: SemLoss{O, I}} <: SemLoss{O, I} loss::L end +SemFiniteDiff(args...; kwargs...) 
= + SemFiniteDiff(Sem(args...; gradient = false, hessian = false, kwargs...)) + """ abstract type SemSpecification end From 5cdcc63539f32f5d2daf3bf783d70ace9cd2cdfe Mon Sep 17 00:00:00 2001 From: Maximilian Ernst Date: Wed, 25 Mar 2026 22:40:37 +0100 Subject: [PATCH 26/74] Sem print methods --- src/frontend/specification/Sem.jl | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index 2d246166..db76acfa 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -28,20 +28,17 @@ for f in ( end function Base.show(io::IO, term::LossTerm) + print(io, nameof(losstype(term))) + print(io, "\n") if !isnothing(id(term)) - print(io, ":$(id(term)): ") + print(io, " - id: $(id(term)) \n") end - print(io, nameof(losstype(term))) if issemloss(term) - print( - io, - " ($(nsamples(term)) samples, $(nobserved_vars(term)) observed, $(nlatent_vars(term)) latent variables)", - ) + print(io, " - observed: $(nameof(typeof(observed(loss(term))))) \n") + print(io, " - implied: $(nameof(typeof(implied(loss(term))))) \n") end if !isnothing(weight(term)) - print(io, " w=$(round(weight(term), digits=3))") - else - print(io, " w=1") + print(io, " - weight: $(round(weight(term), digits=3))") end end @@ -604,12 +601,17 @@ end # pretty printing ############################################################## +_subtype_info(::Sem) = nothing +_subtype_info(::SemFiniteDiff) = "Finite Difference Approximation" + function Base.show(io::IO, sem::AbstractSem) - println(io, "Structural Equation Model ($(nameof(typeof(sem))))") - println(io, "- $(nparams(sem)) parameters") - println(io, "- Loss terms:") + print(io, "Structural Equation Model") + si = _subtype_info(sem) + isnothing(si) || print(io, " : "*si) + print("\n") + print(io, "- Loss Functions \n") for term in loss_terms(sem) - print(io, " - ") + print(io, " > ") print(io, term) println(io) end From 
6d803cbb44d0d251a895304d0529349cb66171a4 Mon Sep 17 00:00:00 2001 From: Maximilian Ernst Date: Fri, 10 Apr 2026 18:36:06 +0200 Subject: [PATCH 27/74] add details method for AbstractSem --- src/frontend/fit/summary.jl | 43 ++++++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/src/frontend/fit/summary.jl b/src/frontend/fit/summary.jl index c8495b79..fe7ea930 100644 --- a/src/frontend/fit/summary.jl +++ b/src/frontend/fit/summary.jl @@ -1,3 +1,38 @@ +function details(sem::AbstractSem) + print("Structural Equation Model") + print(_subtype_info(sem)) + print("\n") + print("- Loss Functions \n") + for term in loss_terms(sem) + print(" > ") + details(term) + println() + end +end + +function details(term::LossTerm) + if !issemloss(term) + print(term.loss) + else + println("Structural Equation Model Loss ($(nameof(typeof(term.loss))))") + if !isnothing(id(term)) + print(" - id: $(id(term)) \n") + end + println( + " - Observed: $(nameof(typeof(observed(term)))) ($(nsamples(term)) samples)", + ) + println( + " - Implied: $(nameof(typeof(implied(term)))) ($(nparams(term)) parameters)", + ) + println( + " - Variables: $(nobserved_vars(term)) observed, $(nlatent_vars(term)) latent", + ) + if !isnothing(weight(term)) + print(" - weight: $(round(weight(term), digits=3))") + end + end +end + function details(sem_fit::SemFit; show_fitmeasures = false, color = :light_cyan, digits = 2) print("\n") println("Fitted Structural Equation Model") @@ -325,11 +360,13 @@ function Base.findall(fun::Function, partable::ParameterTable) end """ - (1) details(sem_fit::SemFit; show_fitmeasures = false) + (1) details(model::AbstractSem) + + (2) details(sem_fit::SemFit; show_fitmeasures = false) - (2) details(partable::AbstractParameterTable; ...) + (3) details(partable::AbstractParameterTable; ...) -Print information about (1) a fitted SEM or (2) a parameter table to stdout. 
+Print information about (1) a SEM, (2) a fitted SEM or (3) a parameter table to stdout. # Extended help ## Addition keyword arguments From 1874351be1392d49f9d0ddf68910a3b3bb0d4bf8 Mon Sep 17 00:00:00 2001 From: Maximilian Ernst Date: Sat, 11 Apr 2026 23:12:27 +0200 Subject: [PATCH 28/74] shorten model --- src/frontend/specification/Sem.jl | 41 ++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index db76acfa..12b8da15 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -28,17 +28,35 @@ for f in ( end function Base.show(io::IO, term::LossTerm) - print(io, nameof(losstype(term))) - print(io, "\n") - if !isnothing(id(term)) - print(io, " - id: $(id(term)) \n") - end - if issemloss(term) - print(io, " - observed: $(nameof(typeof(observed(loss(term))))) \n") - print(io, " - implied: $(nameof(typeof(implied(loss(term))))) \n") - end - if !isnothing(weight(term)) - print(io, " - weight: $(round(weight(term), digits=3))") + if (:compact => true) in io + if !isnothing(id(term)) + print(io, ":$(id(term)): ") + end + print(io, nameof(losstype(term))) + if issemloss(term) + print( + io, + " ($(nsamples(term)) samples, $(nobserved_vars(term)) observed, $(nlatent_vars(term)) latent variables)", + ) + end + if !isnothing(weight(term)) + print(io, " w=$(round(weight(term), digits=3))") + else + print(io, " w=1") + end + else + print(io, nameof(losstype(term))) + print(io, "\n") + if !isnothing(id(term)) + print(io, " - id: $(id(term)) \n") + end + if issemloss(term) + print(io, " - observed: $(nameof(typeof(observed(loss(term))))) \n") + print(io, " - implied: $(nameof(typeof(implied(loss(term))))) \n") + end + if !isnothing(weight(term)) + print(io, " - weight: $(round(weight(term), digits=3))") + end end end @@ -610,6 +628,7 @@ function Base.show(io::IO, sem::AbstractSem) isnothing(si) || print(io, " : "*si) print("\n") 
print(io, "- Loss Functions \n") + io = length(loss_terms(sem)) >= 5 ? IOContext(io, :compact => true) : io for term in loss_terms(sem) print(io, " > ") print(io, term) From 7696e8f94d4ce5281c1278836c940abad3d03333 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 12 Apr 2026 21:23:11 -0700 Subject: [PATCH 29/74] Sem(): remove SemWLS kw check logic --- src/frontend/specification/Sem.jl | 12 ++---------- test/unit_tests/model.jl | 2 ++ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index 12b8da15..a3ccec58 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -417,21 +417,13 @@ function get_fields!(kwargs, spec, observed, implied, loss) # implied if !isa(implied, SemImplied) - # FIXME remove this implicit logic - # SemWLS only accepts vech-ed implied covariance - if isa(loss, Type) && (loss <: SemWLS) && !haskey(kwargs, :vech) - implied_kwargs = copy(kwargs) - implied_kwargs[:vech] = true - else - implied_kwargs = kwargs - end implied = if spec isa EnsembleParameterTable Dict( - term_id => implied(term_spec; implied_kwargs...) for + term_id => implied(term_spec; kwargs...) for (term_id, term_spec) in pairs(spec.tables) ) else - implied(spec; implied_kwargs...) + implied(spec; kwargs...) 
end end diff --git a/test/unit_tests/model.jl b/test/unit_tests/model.jl index 93ba5e80..c80a0c1b 100644 --- a/test/unit_tests/model.jl +++ b/test/unit_tests/model.jl @@ -58,6 +58,7 @@ end observed = obs, implied = impliedtype, loss = losstype, + vech = losstype <: SemWLS && impliedtype <: RAMSymbolic ) @test model isa Sem @@ -82,6 +83,7 @@ end observed = obs, implied = RAMSymbolic, loss = SemWLS, + vech = true ) wls_loss = sem_term(model) findiff_model = Sem(SEM.FiniteDiffWrapper(wls_loss)) From 9c5e4461843a384bc01f6d6a572ddca68b06f4fc Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 12 Apr 2026 21:32:48 -0700 Subject: [PATCH 30/74] Sem(): cleanup constructor * rename get_fields!() into build_sem_terms() for clarity * move set_field_type!() code into Sem() ctor since its not used outside --- src/frontend/specification/Sem.jl | 37 ++++++++++++++----------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index a3ccec58..1c0c5f3f 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -200,9 +200,19 @@ function Sem(; ) where {O, I, L} kwdict = Dict{Symbol, Any}(kwargs...) - set_field_type_kwargs!(kwdict, observed, implied, loss, O, I) + # add kwargs with type information + kwdict[:observed_type] = O <: Type ? observed : typeof(observed) + kwdict[:implied_type] = I <: Type ? implied : typeof(implied) + if loss isa SemLoss + kwdict[:loss_types] = + [aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss.functions] + elseif applicable(iterate, loss) + kwdict[:loss_types] = [aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss] + else + kwdict[:loss_types] = [loss isa SemLoss ? typeof(loss) : loss] + end - loss = get_fields!(kwdict, specification, observed, implied, loss) + loss = build_sem_terms(kwdict, specification, observed, implied, loss) return Sem(loss...) 
end @@ -337,19 +347,6 @@ vars(model::AbstractSem, id::Nothing = nothing) = vars(implied(model, id)) observed_vars(model::AbstractSem, id::Nothing = nothing) = observed_vars(implied(model, id)) latent_vars(model::AbstractSem, id::Nothing = nothing) = latent_vars(implied(model, id)) -function set_field_type_kwargs!(kwargs, observed, implied, loss, O, I) - kwargs[:observed_type] = O <: Type ? observed : typeof(observed) - kwargs[:implied_type] = I <: Type ? implied : typeof(implied) - if loss isa SemLoss - kwargs[:loss_types] = - [aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss.functions] - elseif applicable(iterate, loss) - kwargs[:loss_types] = [aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss] - else - kwargs[:loss_types] = [loss isa SemLoss ? typeof(loss) : loss] - end -end - # build ensemble/multi-group observed from the specification and Sem(...) kwargs # used by Sem(...) and replace_observed() function build_ensemble_observed(observed_type, spec::EnsembleParameterTable, kwargs) @@ -400,8 +397,8 @@ function build_ensemble_observed(observed_type, spec::EnsembleParameterTable, kw ) end -# construct Sem fields -function get_fields!(kwargs, spec, observed, implied, loss) +# called by Sem() ctor to construct its loss terms +function build_sem_terms(kwargs::AbstractDict, spec, observed, implied, loss) if !isa(spec, SemSpecification) spec = spec(; kwargs...) end @@ -430,13 +427,13 @@ function get_fields!(kwargs, spec, observed, implied, loss) # loss loss_kwargs = copy(kwargs) loss_kwargs[:nparams] = nparams(spec) - loss = build_SemTerms(loss, observed, implied; loss_kwargs...) + loss = build_sem_terms(loss, observed, implied; loss_kwargs...) return loss end -# construct loss field -function build_SemTerms(loss, observed, implied; kwargs...) +# construct loss terms for the given observed and implied +function build_sem_terms(loss, observed, implied; kwargs...) 
function build_SemLoss(aloss, observed, implied) if loss isa AbstractLoss return loss From 3aee9f42a61acc3695b371457e783d433070ba9e Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 12 Apr 2026 22:59:47 -0700 Subject: [PATCH 31/74] show(::Sem): respect existing :compact key --- src/frontend/specification/Sem.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index 1c0c5f3f..be9c4fc3 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -612,12 +612,15 @@ _subtype_info(::Sem) = nothing _subtype_info(::SemFiniteDiff) = "Finite Difference Approximation" function Base.show(io::IO, sem::AbstractSem) + # if not specified, use compact printing for larger models + if !haskey(io, :compact) && length(loss_terms(sem)) >= 5 + io = IOContext(io, :compact => true) + end print(io, "Structural Equation Model") si = _subtype_info(sem) isnothing(si) || print(io, " : "*si) print("\n") print(io, "- Loss Functions \n") - io = length(loss_terms(sem)) >= 5 ? 
IOContext(io, :compact => true) : io for term in loss_terms(sem) print(io, " > ") print(io, term) From 0e948d7888c538aabc16e7c4c1ab6450ab4eaf69 Mon Sep 17 00:00:00 2001 From: Maximilian Ernst Date: Sat, 25 Apr 2026 11:04:55 +0200 Subject: [PATCH 32/74] add show method for AbstractSem --- src/frontend/specification/Sem.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index be9c4fc3..2ecac9e7 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -612,6 +612,19 @@ _subtype_info(::Sem) = nothing _subtype_info(::SemFiniteDiff) = "Finite Difference Approximation" function Base.show(io::IO, sem::AbstractSem) + io = IOContext(io, :compact => true) + println(io, "Structural Equation Model ($(nameof(typeof(sem))))") + println(io, "- $(nparams(sem)) parameters") + println(io, "- Loss Functions:") + for term in loss_terms(sem) + print(io, " - ") + print(io, term) + println(io) + end +end + +# pretty prenting for console +function Base.show(io::IO, ::MIME"text/plain", sem::AbstractSem) # if not specified, use compact printing for larger models if !haskey(io, :compact) && length(loss_terms(sem)) >= 5 io = IOContext(io, :compact => true) From d3ee1c932cff142b942160470e278e7404e307cb Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 11:36:48 -0700 Subject: [PATCH 33/74] simulation.md: whitespace fixes --- docs/src/performance/simulation.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/performance/simulation.md b/docs/src/performance/simulation.md index 61a9d5ad..f5199f62 100644 --- a/docs/src/performance/simulation.md +++ b/docs/src/performance/simulation.md @@ -40,7 +40,7 @@ end partable = ParameterTable( graph, - latent_vars = latent_vars, + latent_vars = latent_vars, observed_vars = observed_vars ) ``` @@ -63,10 +63,10 @@ model_updated = replace_observed(model, data_2) ## Multithreading !!! 
danger "Thread safety" *This is only relevant when you are planning to fit updated models in parallel* - - Models generated by `replace_observed` may share the same objects in memory (e.g. some parts of + + Models generated by `replace_observed` may share the same objects in memory (e.g. some parts of `model` and `model_updated` are the same objects in memory.) - Therefore, fitting both of these models in parallel will lead to **race conditions**, + Therefore, fitting both of these models in parallel will lead to **race conditions**, possibly crashing your computer. To avoid these problems, you should copy `model` before updating it. From 57ea987eff13fa72d254a5abeea585501c3a28cf Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 12:32:32 -0700 Subject: [PATCH 34/74] replace_obs(sem): make sure Sem type is preserved Co-authored-by: Copilot --- src/frontend/specification/Sem.jl | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index 2ecac9e7..bbf7dab8 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -491,6 +491,14 @@ end # replace_observed: Sem level ############################################################## +# internal function to create a copy of Sem with the loss term replaced +# used by the replace_observed() +_replace_loss_terms(sem::Sem, new_terms::Tuple) = + Sem{typeof(new_terms)}(new_terms, copy(params(sem))) + +_replace_loss_terms(sem::Sem, new_terms::AbstractVector) = + _replace_loss_terms(sem, Tuple(new_terms)) + """ replace_observed(model::Sem, observed::SemObserved) replace_observed(model::Sem, data::AbstractDict{Symbol}) @@ -533,9 +541,10 @@ function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix}; kw "Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data.", ), ) - updated_terms = - Tuple(replace_observed(term, data; kwargs...) 
for term in loss_terms(sem)) - return Sem(updated_terms...) + updated_terms = map(loss_terms(sem)) do term + replace_observed(term, data; kwargs...) + end + return _replace_loss_terms(sem, updated_terms) end function replace_observed(sem::Sem, data::AbstractDict{Symbol}; kwargs...) @@ -561,7 +570,7 @@ function replace_observed(sem::Sem, data::AbstractDict{Symbol}; kwargs...) throw(ArgumentError("No data provided for SEM term :$tid")) return replace_observed(term, term_data; kwargs...) end - return Sem(Tuple(updated_terms)...) + return _replace_loss_terms(sem, updated_terms) end function replace_observed(sem::Sem, data::AbstractVector; kwargs...) @@ -574,7 +583,7 @@ function replace_observed(sem::Sem, data::AbstractVector; kwargs...) updated_terms = map(enumerate(loss_terms(sem))) do (i, term) issemloss(term) ? replace_observed(term, data[i]; kwargs...) : term end - return Sem(Tuple(updated_terms)...) + return _replace_loss_terms(sem, updated_terms) end function replace_observed( @@ -591,9 +600,10 @@ function replace_observed( "Provide `semterm_column` to specify which DataFrame column identifies the groups.", ), ) - updated_terms = - Tuple(replace_observed(term, data; kwargs...) for term in loss_terms(sem)) - return Sem(updated_terms...) + updated_terms = map(loss_terms(sem)) do term + replace_observed(term, data; kwargs...) 
+ end + return _replace_loss_terms(sem, updated_terms) end # multi-term: split DataFrame by semterm_column From 5cd11cfb477eb1e764a1a5b513388928ac557fd4 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 12:34:19 -0700 Subject: [PATCH 35/74] replace_obs(sem): update docstring --- src/frontend/specification/Sem.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/frontend/specification/Sem.jl b/src/frontend/specification/Sem.jl index bbf7dab8..52981990 100644 --- a/src/frontend/specification/Sem.jl +++ b/src/frontend/specification/Sem.jl @@ -509,7 +509,12 @@ _replace_loss_terms(sem::Sem, new_terms::AbstractVector) = Construct a new SEM model or SEM loss with replaced observed data. The SEM structure (implied covariance, loss type) is preserved; -only the observed data is swapped. +only the observed data is swapped. The new loss terms preserve the configuration +and share the implied state with the loss terms of the original SEM model. + +Keyword arguments: +- `recompute_observed_state::Bool = true`: loss terms should recompute observed-dependent + caches. Losses without such caches ignore this argument. 
# Single-term models From 42d4e648f639f57151ea81426df2de5d5b13a550 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 12:35:42 -0700 Subject: [PATCH 36/74] replace_obs(loss): extract check_obs_vars() method Co-authored-by: Copilot --- src/loss/abstract.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/loss/abstract.jl b/src/loss/abstract.jl index 2cd9f35d..2b4940ad 100644 --- a/src/loss/abstract.jl +++ b/src/loss/abstract.jl @@ -41,18 +41,20 @@ end check_observed_vars(sem::SemLoss) = check_observed_vars(observed(sem), implied(sem)) +function check_observed_vars(loss::SemLoss, new_observed::SemObserved) + observed_vars(new_observed) == observed_vars(SEM.observed(loss)) || throw( + ArgumentError( + "Observed variables of the loss term do not match the ones of the new observed data", + ), + ) +end + ############################################################################################ # replace_observed: SemLoss, AbstractLoss, LossTerm ############################################################################################ function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...) 
- old_obs = SEM.observed(loss) - observed_vars(old_obs) == observed_vars(new_observed) || throw( - ArgumentError( - "observed_vars of the new data do not match the model: " * - "expected $(observed_vars(old_obs)), got $(observed_vars(new_observed))", - ), - ) + check_observed_vars(loss, new_observed) # the default replace_observed() does not pass through kwargs to the ctor return typeof(loss).name.wrapper(new_observed, SEM.implied(loss)) end From 77597216ff4f04d9c566e2a365ef9e91427dd66e Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 20:17:19 -0700 Subject: [PATCH 37/74] test/multigroup: avoid clash with observed_vars() method --- test/examples/multigroup/multigroup.jl | 34 +++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/test/examples/multigroup/multigroup.jl b/test/examples/multigroup/multigroup.jl index 35fe20e6..d095cf11 100644 --- a/test/examples/multigroup/multigroup.jl +++ b/test/examples/multigroup/multigroup.jl @@ -103,8 +103,8 @@ end # w.o. meanstructure ----------------------------------------------------------------------- -latent_vars = [:visual, :textual, :speed] -observed_vars = Symbol.(:x, 1:9) +lat_vars = [:visual, :textual, :speed] +obs_vars = Symbol.(:x, 1:9) graph = @StenoGraph begin # measurement model @@ -112,14 +112,14 @@ graph = @StenoGraph begin textual → fixed(1.0, 1.0) * x4 + label(:λ₅, :λ₅) * x5 + label(:λ₆, :λ₆) * x6 speed → fixed(1.0, 1.0) * x7 + label(:λ₈, :λ₈) * x8 + label(:λ₉, :λ₉) * x9 # variances and covariances - _(observed_vars) ↔ _(observed_vars) - _(latent_vars) ⇔ _(latent_vars) + _(obs_vars) ↔ _(obs_vars) + _(lat_vars) ⇔ _(lat_vars) end partable = EnsembleParameterTable( graph; - observed_vars = observed_vars, - latent_vars = latent_vars, + observed_vars = obs_vars, + latent_vars = lat_vars, groups = [:Pasteur, :Grant_White], ) @@ -130,8 +130,8 @@ specification_g2 = specification[:Grant_White] # w. 
meanstructure (fiml) ------------------------------------------------------------------ -latent_vars = [:visual, :textual, :speed] -observed_vars = Symbol.(:x, 1:9) +lat_vars = [:visual, :textual, :speed] +obs_vars = Symbol.(:x, 1:9) graph = @StenoGraph begin # measurement model @@ -139,16 +139,16 @@ graph = @StenoGraph begin textual → fixed(1.0, 1.0) * x4 + label(:λ₅, :λ₅) * x5 + label(:λ₆, :λ₆) * x6 speed → fixed(1.0, 1.0) * x7 + label(:λ₈, :λ₈) * x8 + label(:λ₉, :λ₉) * x9 # variances and covariances - _(observed_vars) ↔ _(observed_vars) - _(latent_vars) ⇔ _(latent_vars) + _(obs_vars) ↔ _(obs_vars) + _(lat_vars) ⇔ _(lat_vars) - Symbol(1) → _(observed_vars) + Symbol(1) → _(obs_vars) end partable_miss = EnsembleParameterTable( graph; - observed_vars = observed_vars, - latent_vars = latent_vars, + observed_vars = obs_vars, + latent_vars = lat_vars, groups = [:Pasteur, :Grant_White], ) @@ -159,14 +159,14 @@ specification_miss_g2 = specification_miss[:Grant_White] # CFI baseline model graph_varonly = @StenoGraph begin - _(observed_vars) ↔ _(observed_vars) - Symbol(1) → _(observed_vars) + _(obs_vars) ↔ _(obs_vars) + Symbol(1) → _(obs_vars) end partable_varonly = EnsembleParameterTable( graph_varonly; - observed_vars = observed_vars, - latent_vars = latent_vars, + observed_vars = obs_vars, + latent_vars = lat_vars, groups = [:Pasteur, :Grant_White], ) From 13e118da6d7c3490f98f61008f57072ee63f605c Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 20:26:30 -0700 Subject: [PATCH 38/74] SemLoss(observed, implied, refloss; kwarg...) ctor to allow replicating the whole refloss state, e.g. 
for replace_observed() Co-authored-by: Copilot --- docs/src/developer/loss.md | 38 ++++++++++-------------- src/loss/ML/FIML.jl | 16 ++++++++-- src/loss/ML/ML.jl | 11 +++++-- src/loss/WLS/WLS.jl | 20 +++++++++---- test/examples/multigroup/build_models.jl | 7 +++-- 5 files changed, 56 insertions(+), 36 deletions(-) diff --git a/docs/src/developer/loss.md b/docs/src/developer/loss.md index aa6a1e17..8cdf2150 100644 --- a/docs/src/developer/loss.md +++ b/docs/src/developer/loss.md @@ -136,31 +136,25 @@ Additionally, you may provide analytic hessians by writing a respective method f ## Convenient -To be able to build the model with the [Outer Constructor](@ref), you need to add a constructor for your loss function that only takes keyword arguments and allows for passing optional additional kewyword arguments. A constructor is just a function that creates a new instance of your type: +To be able to build the loss term, it needs a constructor. +Every `SemLoss` subtype should provide a constructor with 3 positional arguments: + * `observed::SemObserved`: the observed part of the model + * `implied::SemImplied`: the implied part of the model + * `refloss::Union{MyLoss, Nothing} = nothing`: optional loss term of the same type + to use as a reference for any loss-specific configuration. + +Any additional loss configuration details should be passed as optional keyword arguments. +If both `refloss` and the keyword arguments are provided, the keyword arguments take +precedence. This constructor is used internally by the functions like [`replace_observed`](@ref) +to rebuild the loss term with new observed data while preserving the implied state. ```julia -function MyLoss(;arg1 = ..., arg2, kwargs...) +function MyLoss( + observed::SemObserved, implied::SemImplied, refloss::Union{MyLoss, Nothing} = nothing; + kwarg1 = ..., kwarg2 = ..., kwargs... +) ... - return MyLoss(...) 
-end -``` - -All keyword arguments that a user passes to the Sem constructor are passed to your loss function. In addition, all previously constructed parts of the model (implied and observed part) are passed as keyword arguments as well as the number of parameters `n_par = ...`, so your constructor may depend on those. For example, the constructor for `SemML` in our package depends on the additional argument `meanstructure` as well as the observed part of the model to pre-allocate arrays of the same size as the observed covariance matrix and the observed mean vector: - -```julia -function SemML(;observed, meanstructure = false, approx_H = false, kwargs...) - - isnothing(obs_mean(observed)) ? - meandiff = nothing : - meandiff = copy(obs_mean(observed)) - - return SemML( - similar(obs_cov(observed)), - similar(obs_cov(observed)), - meandiff, - approx_H, - Val(meanstructure) - ) + return MyLoss(...) # internal MyLoss constructor end ``` diff --git a/src/loss/ML/FIML.jl b/src/loss/ML/FIML.jl index 15081e20..74d5edfb 100644 --- a/src/loss/ML/FIML.jl +++ b/src/loss/ML/FIML.jl @@ -75,13 +75,16 @@ Can handle observed data with missing values. # Constructor - SemFIML(observed::SemObservedMissing, implied::SemImplied) + SemFIML(observed::SemObservedMissing, implied::SemImplied, refloss = nothing) # Arguments - `observed::SemObservedMissing`: the observed part of the model (see [`SemObservedMissing`](@ref)) - `implied::SemImplied`: the implied part of the model (see [`SemImplied`](@ref)) +- `refloss::Union{SemFIML, Nothing}`: optional reference loss used to preserve + loss-specific configuration and share the internal state when rebuilding + a loss term, e.g. in [`replace_observed`](@ref) # Examples ```julia @@ -109,7 +112,12 @@ end ### Constructors ############################################################################################ -function SemFIML(observed::SemObservedMissing, implied::SemImplied; kwargs...) 
+function SemFIML( + observed::SemObservedMissing, + implied::SemImplied, + refloss::Union{SemFIML, Nothing} = nothing; + kwargs..., +) if MeanStruct(implied) === NoMeanStruct """ Full information maximum likelihood (FIML) can only be used with a meanstructure. @@ -124,8 +132,10 @@ function SemFIML(observed::SemObservedMissing, implied::SemImplied; kwargs...) observed, implied, [SemFIMLPattern(pat) for pat in observed.patterns], + # share the internal state with the refloss + !isnothing(refloss) ? refloss.imp_inv : zeros(nobserved_vars(observed), nobserved_vars(observed)), - CommutationMatrix(nvars(implied)), + !isnothing(refloss) ? refloss.commutator : CommutationMatrix(nvars(implied)), nothing, ) end diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index 9f327544..cf119832 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -8,11 +8,14 @@ Maximum likelihood estimation. # Constructor - SemML(observed, implied; approximate_hessian = false) + SemML(observed, implied, refloss = nothing; approximate_hessian = false) # Arguments - `observed::SemObserved`: the observed part of the model - `implied::SemImplied`: [`SemImplied`](@ref) instance +- `refloss::Union{SemML, Nothing}`: optional reference loss used to preserve + loss-specific configuration and share the internal state when rebuilding + a loss term, e.g. in [`replace_observed`](@ref) - `approximate_hessian::Bool`: if hessian-based optimization is used, should the hessian be swapped for an approximation # Examples @@ -39,8 +42,10 @@ end function SemML( observed::SemObserved, - implied::SemImplied; - approximate_hessian::Bool = false, + implied::SemImplied, + refloss::Union{SemML, Nothing} = nothing; + approximate_hessian::Bool = !isnothing(refloss) ? 
+ HessianEval(refloss) === ApproxHessian : false, kwargs..., ) if observed isa SemObservedMissing diff --git a/src/loss/WLS/WLS.jl b/src/loss/WLS/WLS.jl index 9acb7de0..d067e346 100644 --- a/src/loss/WLS/WLS.jl +++ b/src/loss/WLS/WLS.jl @@ -10,7 +10,7 @@ At the moment only available with the `RAMSymbolic` implied type. # Constructor SemWLS( - observed::SemObserved, implied::SemImplied; + observed::SemObserved, implied::SemImplied, refloss = nothing; wls_weight_matrix = nothing, wls_weight_matrix_mean = nothing, approximate_hessian = false, @@ -19,6 +19,9 @@ At the moment only available with the `RAMSymbolic` implied type. # Arguments - `observed`: the `SemObserved` part of the model - `implied`: the `SemImplied` part of the model +- `refloss::Union{SemWLS, Nothing}`: optional reference loss used to preserve + loss-specific configuration and share the internal state when rebuilding + a loss term, e.g. in [`replace_observed`](@ref) - `approximate_hessian::Bool`: should the hessian be swapped for an approximation - `wls_weight_matrix`: the weight matrix for weighted least squares. Defaults to GLS estimation (``0.5*(D^T*kron(S,S)*D)`` where D is the duplication matrix @@ -58,10 +61,14 @@ end function SemWLS( observed::SemObserved, - implied::SemImplied; - wls_weight_matrix::Union{AbstractMatrix, Nothing} = nothing, - wls_weight_matrix_mean::Union{AbstractMatrix, Nothing} = nothing, - approximate_hessian::Bool = false, + implied::SemImplied, + refloss::Union{SemWLS, Nothing} = nothing; + wls_weight_matrix::Union{AbstractMatrix, Nothing} = !isnothing(refloss) ? refloss.V : + nothing, + wls_weight_matrix_mean::Union{AbstractMatrix, Nothing} = !isnothing(refloss) ? + refloss.V_μ : nothing, + approximate_hessian::Bool = !isnothing(refloss) ? 
+ HessianEval(refloss) === ApproxHessian : false, verbose::Bool = false, kwargs..., ) @@ -190,7 +197,8 @@ function replace_observed( # recompute weight matrices only if recompute_observed_state=true return SemWLS( new_observed, - SEM.implied(loss); + SEM.implied(loss), + loss; wls_weight_matrix = recompute_observed_state ? nothing : loss.V, wls_weight_matrix_mean = recompute_observed_state ? nothing : loss.V_μ, ) diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index 6c22a453..b538a4f4 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -139,8 +139,11 @@ struct UserSemML{O, I} <: SemLoss{O, I} observed::O implied::I - UserSemML(observed::SemObserved, implied::SemImplied) = - new{typeof(observed), typeof(implied)}(ExactHessian(), observed, implied) + UserSemML( + observed::SemObserved, + implied::SemImplied, + refloss::Union{UserSemML, Nothing} = nothing, + ) = new{typeof(observed), typeof(implied)}(ExactHessian(), observed, implied) end function SEM.objective(ml::UserSemML, params) From d5ff15b945e6adfa13706ff4c134bd0bc77fc8da Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 20:36:28 -0700 Subject: [PATCH 39/74] replace_observed(): use 3-arg SemLoss ctor Co-authored-by: Copilot --- docs/src/performance/simulation.md | 9 +++++++-- src/loss/abstract.jl | 12 ++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/src/performance/simulation.md b/docs/src/performance/simulation.md index f5199f62..3061c656 100644 --- a/docs/src/performance/simulation.md +++ b/docs/src/performance/simulation.md @@ -3,8 +3,13 @@ ## Replace observed data In simulation studies, a common task is fitting the same model to many different datasets. It would be a waste of resources to reconstruct the complete model for each dataset. 
-We therefore provide the function `replace_observed` to change the `observed` part of a model, -without necessarily reconstructing the other parts. +We therefore provide the function [`replace_observed`](@ref) to change the `observed` part +of a model, without necessarily reconstructing the other parts. + +For `SemLoss` terms, `replace_observed()` constructs the new loss by passing the new observed +data, the current implied state, and the current loss (as `refloss`) to the appropriate loss +constructor. The new loss term therefore shares the implied state with the original one, as well +as loss-specific settings and, potentially, the internal state. For the [A first model](@ref), you would use it as diff --git a/src/loss/abstract.jl b/src/loss/abstract.jl index 2b4940ad..56a3af58 100644 --- a/src/loss/abstract.jl +++ b/src/loss/abstract.jl @@ -55,14 +55,18 @@ end function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...) check_observed_vars(loss, new_observed) - # the default replace_observed() does not pass through kwargs to the ctor - return typeof(loss).name.wrapper(new_observed, SEM.implied(loss)) + # construct the new loss: + # 1) replace the observed + # 2) share the implied and its internal state with the original loss + # 3) replicate the current loss configuration/share its internal state + loss_ctor = typeof(loss).name.wrapper # get the loss constructor + return loss_ctor(new_observed, SEM.implied(loss), loss) end function replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame}; kwargs...) old_obs = SEM.observed(loss) - new_observed = - typeof(old_obs).name.wrapper(data = data, observed_vars = observed_vars(old_obs)) + obs_ctor = typeof(old_obs).name.wrapper + new_observed = obs_ctor(data = data, observed_vars = observed_vars(old_obs)) return replace_observed(loss, new_observed; kwargs...) 
end From bc6ef3bd561f1069834b4b075df8c68e360be0cf Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 20:36:46 -0700 Subject: [PATCH 40/74] boostrap!(): deepcopy the sem Co-authored-by: Copilot --- src/frontend/fit/standard_errors/bootstrap.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/frontend/fit/standard_errors/bootstrap.jl b/src/frontend/fit/standard_errors/bootstrap.jl index 0a1f39b4..a8c1d393 100644 --- a/src/frontend/fit/standard_errors/bootstrap.jl +++ b/src/frontend/fit/standard_errors/bootstrap.jl @@ -18,8 +18,9 @@ function bootstrap!( # fit to bootstrap samples if !parallel + bs_sem = deepcopy(sem) # avoid mutating the original model for i in 1:n_boot - new_fit = _fit_bootstrap_sample(sem, data, start; engine, fit_kwargs) + new_fit = _fit_bootstrap_sample(bs_sem, data, start; engine, fit_kwargs) update!(acc, i, new_fit, nothing) end else From 6f09822a22c5b260a2f2cb80c191fba732243639 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 20:37:40 -0700 Subject: [PATCH 41/74] unit_tests/model: more config-preserving tests Co-authored-by: Copilot --- test/unit_tests/model.jl | 74 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 3 deletions(-) diff --git a/test/unit_tests/model.jl b/test/unit_tests/model.jl index c80a0c1b..b62f056d 100644 --- a/test/unit_tests/model.jl +++ b/test/unit_tests/model.jl @@ -58,7 +58,7 @@ end observed = obs, implied = impliedtype, loss = losstype, - vech = losstype <: SemWLS && impliedtype <: RAMSymbolic + vech = losstype <: SemWLS && impliedtype <: RAMSymbolic, ) @test model isa Sem @@ -77,13 +77,20 @@ end @test @inferred(nsamples(model)) == nsamples(obs) end -@testset "replace_observed() preserves WLS state through finite-diff wrappers" begin +@testset "replace_observed() preserves WLS and approx_hessian=$(approx_hessian) state through finite-diff wrappers" for approx_hessian in + ( + false, + true, +) + expected_hessianeval = approx_hessian ? 
SEM.ApproxHessian : SEM.ExactHessian + model = Sem( specification = ram_matrices, observed = obs, implied = RAMSymbolic, loss = SemWLS, - vech = true + vech = true, + approximate_hessian = approx_hessian, ) wls_loss = sem_term(model) findiff_model = Sem(SEM.FiniteDiffWrapper(wls_loss)) @@ -104,11 +111,72 @@ end @test loss_newstate isa SemWLS @test loss_orig !== loss_oldstate @test loss_orig !== loss_newstate + @test SEM.HessianEval(loss_orig) === expected_hessianeval + @test SEM.HessianEval(loss_oldstate) === expected_hessianeval + @test SEM.HessianEval(loss_newstate) === expected_hessianeval @test loss_oldstate.V === loss_orig.V @test loss_newstate.V !== loss_orig.V @test observed_vars(loss_oldstate) == observed_vars(loss_orig) end +@testset "replace_observed() shares implied unless model is deepcopied and approx_hessian=$(approx_hessian)" for approx_hessian in + ( + false, + true, +) + expected_hessianeval = approx_hessian ? SEM.ApproxHessian : SEM.ExactHessian + + model = Sem( + specification = ram_matrices, + observed = obs, + implied = RAMSymbolic, + loss = SemML, + approximate_hessian = approx_hessian, + ) + + data_new = randn(nsamples(obs), nobserved_vars(obs)) + + model_new = replace_observed(model, data_new) + model_deepcopy = replace_observed(deepcopy(model), data_new) + + loss_orig = sem_term(model) + loss_new = sem_term(model_new) + loss_deepcopy = sem_term(model_deepcopy) + + @test SEM.HessianEval(loss_orig) === expected_hessianeval + @test SEM.HessianEval(loss_new) === expected_hessianeval + @test SEM.HessianEval(loss_deepcopy) === expected_hessianeval + @test implied(loss_new) === implied(loss_orig) + @test implied(loss_deepcopy) !== implied(loss_orig) +end + +@testset "replace_observed() preserves Sem container defaults" begin + data_g1 = dat[1:40, :] + data_g2 = dat[41:end, :] + + sem_multigroup = Sem( + :g1 => SemML( + SemObservedData(specification = ram_matrices, data = data_g1), + RAM(ram_matrices), + ), + :g2 => SemML( + 
SemObservedData(specification = ram_matrices, data = data_g2), + RAM(ram_matrices), + ); + default_sem_weights = :one, + ) + + sem_newobs = replace_observed( + sem_multigroup, + Dict(:g1 => randn(10, nobserved_vars(obs)), :g2 => randn(25, nobserved_vars(obs))), + ) + + @test all(isnothing, map(SEM.weight, SEM.loss_terms(sem_multigroup))) + @test all(isnothing, map(SEM.weight, SEM.loss_terms(sem_newobs))) + @test params(sem_newobs) == params(sem_multigroup) + @test params(sem_newobs) !== params(sem_multigroup) +end + @testset "Sem(...; semterm_column=...) splits ensemble data by group" begin dat_grouped = copy(dat[:, [:x1, :x2]]) n_g1 = size(dat_grouped, 1) ÷ 2 From 32ad928e3e8fed1d9a8ec77cc7998810cc9f0732 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 20:39:35 -0700 Subject: [PATCH 42/74] tests: replace_observed(UserSemML) --- test/examples/multigroup/build_models.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index b538a4f4..47bdea22 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -178,6 +178,16 @@ end ) end +@testset "replace_observed_user_defined_loss" begin + wrapped_loss = SEM.FiniteDiffWrapper(UserSemML(obs_g2, RAMSymbolic(specification_g2))) + new_data = randn(nsamples(obs_g2), nobserved_vars(obs_g2)) + replaced_loss = SEM._unwrap(replace_observed(wrapped_loss, new_data)) + + @test replaced_loss isa UserSemML + @test observed_vars(replaced_loss) == observed_vars(obs_g2) + @test implied(replaced_loss) === implied(SEM._unwrap(wrapped_loss)) +end + ############################################################################################ # GLS estimation ############################################################################################ From 25dd38c68c0c23c1b90759981a3036d4acb49139 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Sun, 3 May 2026 20:56:00 -0700 Subject: 
[PATCH 43/74] show(ParTable): fix formatting --- src/frontend/specification/ParameterTable.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/frontend/specification/ParameterTable.jl b/src/frontend/specification/ParameterTable.jl index ecff7c4e..c9b9dc24 100644 --- a/src/frontend/specification/ParameterTable.jl +++ b/src/frontend/specification/ParameterTable.jl @@ -114,7 +114,10 @@ function Base.show(io::IO, partable::ParameterTable) pretty_table( io, as_matrix, - column_labels = [shown_columns, [eltype(partable.columns[col]) for col in shown_columns]], + column_labels = [ + shown_columns, + [eltype(partable.columns[col]) for col in shown_columns], + ], table_format = TextTableFormat(borders = text_table_borders__compact), # TODO switch to `missing` as non-specified values and suppress printing of `missing` instead formatters = [(v, i, j) -> isa(v, Number) && isnan(v) ? "" : v], From ee91629e414ae20033400e6c65785919c9b3773a Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 14:53:17 -0700 Subject: [PATCH 44/74] WIP SemImpliedState --- src/types.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/types.jl b/src/types.jl index eb251a3b..f1ed3d12 100644 --- a/src/types.jl +++ b/src/types.jl @@ -46,6 +46,8 @@ If you have a special kind of data, e.g. ordinal data, you should implement a su abstract type SemObserved end """ + abstract type SemImplied + Supertype of all objects that can serve as the implied field of a SEM. Computes model-implied values that should be compared with the observed data to find parameter estimates, e. g. the model implied covariance or mean. @@ -56,6 +58,22 @@ abstract type SemImplied end "Subtype of SemImplied for all objects that can serve as the implied field of a SEM and use some form of symbolic precomputation." 
abstract type SemImpliedSymbolic <: SemImplied end +""" + abstract type SemImpliedState + +State of [`SemImplied`](@ref) that corresponds to the specific SEM parameter values. + +Contains the necessary vectors and matrices for calculating the SEM +objective, gradient and hessian (whichever is requested). +""" +abstract type SemImpliedState{I <: SemImplied} end + +impliedtype(::Type{SemImpliedState{I}}) where {I} = I +impliedtype(state::SemImpliedState) = impliedtype(typeof(state)) +implied(state::SemImpliedState) = state.implied +MeanStruct(::Type{S}) where {S <: SemImpliedState} = MeanStruct(impliedtype(state)) +HessianEval(::Type{S}) where {S <: SemImpliedState} = HessianEval(impliedtype(state)) + """ abstract type SemLoss{O <: SemObserved, I <: SemImplied} <: AbstractLoss From 0df5515c1ed579b1ef63a62e1a1195e1c1ccb6d2 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 14:57:25 -0700 Subject: [PATCH 45/74] declare cov matrices symmetric --- src/frontend/fit/fitmeasures/minus2ll.jl | 2 +- src/observed/covariance.jl | 4 ++-- src/observed/data.jl | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/frontend/fit/fitmeasures/minus2ll.jl b/src/frontend/fit/fitmeasures/minus2ll.jl index 1cdf5c07..661c1ee0 100644 --- a/src/frontend/fit/fitmeasures/minus2ll.jl +++ b/src/frontend/fit/fitmeasures/minus2ll.jl @@ -39,7 +39,7 @@ function minus2ll(observed::SemObservedMissing) # FIXME: this code is duplicate to objective(fiml, ...) 
F = sum(observed.patterns) do pat # implied covariance/mean - Σᵢ = Σ[pat.measured_mask, pat.measured_mask] + Σᵢ = Symmetric(Σ[pat.measured_mask, pat.measured_mask]) Σᵢ_chol = cholesky!(Σᵢ) ld = logdet(Σᵢ_chol) Σᵢ⁻¹ = LinearAlgebra.inv!(Σᵢ_chol) diff --git a/src/observed/covariance.jl b/src/observed/covariance.jl index f81fe8e5..43960f62 100644 --- a/src/observed/covariance.jl +++ b/src/observed/covariance.jl @@ -3,7 +3,7 @@ Type alias for [`SemObservedData`](@ref) that has mean and covariance, but no ac For instances of `SemObservedCovariance` [`samples`](@ref) returns `nothing`. """ -const SemObservedCovariance{S} = SemObservedData{Nothing, S} +const SemObservedCovariance{C, S} = SemObservedData{Nothing, C, S} """ SemObservedCovariance(; @@ -76,5 +76,5 @@ function SemObservedCovariance(; obs_mean = obs_mean[obs_vars_perm] end - return SemObservedData(nothing, obs_vars, obs_cov, obs_mean, nsamples) + return SemObservedData(nothing, obs_vars, Symmetric(obs_cov), obs_mean, nsamples) end diff --git a/src/observed/data.jl b/src/observed/data.jl index 30d433e0..39eebe30 100644 --- a/src/observed/data.jl +++ b/src/observed/data.jl @@ -23,10 +23,10 @@ For observed data without missings. 
- `obs_cov(::SemObservedData)` -> observed covariance matrix - `obs_mean(::SemObservedData)` -> observed mean vector """ -struct SemObservedData{D <: Union{Nothing, AbstractMatrix}, S <: Number} <: SemObserved +struct SemObservedData{D <: Union{Nothing, AbstractMatrix}, C, S <: Number} <: SemObserved data::D observed_vars::Vector{Symbol} - obs_cov::Matrix{S} + obs_cov::C obs_mean::Vector{S} nsamples::Int end @@ -58,7 +58,7 @@ function SemObservedData(; throw end - return SemObservedData(data, obs_vars, obs_cov, vec(obs_mean), size(data, 1)) + return SemObservedData(data, obs_vars, Symmetric(obs_cov), vec(obs_mean), size(data, 1)) end ############################################################################################ From 85363fc01c2758f02825a208e3cb6fa4d7c357c7 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 14:57:50 -0700 Subject: [PATCH 46/74] RAM: reuse sigma array --- src/loss/ML/ML.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index cf119832..af6425dd 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -208,7 +208,12 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) ∇A = implied.∇A ∇S = implied.∇S - C = F⨉I_A⁻¹' * (I - Σ⁻¹Σₒ) * Σ⁻¹ * F⨉I_A⁻¹ + # reuse Σ⁻¹Σₒ to calculate I-Σ⁻¹Σₒ + one_Σ⁻¹Σₒ = Σ⁻¹Σₒ + one_Σ⁻¹Σₒ .*= -1 + one_Σ⁻¹Σₒ[diagind(one_Σ⁻¹Σₒ)] .+= 1 + + C = F⨉I_A⁻¹' * one_Σ⁻¹Σₒ * Σ⁻¹ * F⨉I_A⁻¹ mul!(gradient, ∇A', vec(C * S * I_A⁻¹'), 2, 0) mul!(gradient, ∇S', vec(C), 1, 1) From 4c3d14b85a2f554c6f18a250a2020c6206d9fdc9 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 14:58:39 -0700 Subject: [PATCH 47/74] RAM: optional sparse Sigma matrix --- src/implied/RAM/generic.jl | 5 ++++- src/loss/ML/ML.jl | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/implied/RAM/generic.jl b/src/implied/RAM/generic.jl index 1569b341..58b689f8 100644 --- a/src/implied/RAM/generic.jl +++ b/src/implied/RAM/generic.jl @@ 
-90,6 +90,7 @@ function RAM( spec::SemSpecification; #vech = false, gradient_required = true, + sparse_S::Bool = true, kwargs..., ) ram_matrices = convert(RAMMatrices, spec) @@ -102,7 +103,9 @@ function RAM( #preallocate arrays rand_params = randn(Float64, n_par) A_pre = check_acyclic(materialize(ram_matrices.A, rand_params)) - S_pre = materialize(ram_matrices.S, rand_params) + S_pre = Symmetric( + (sparse_S ? sparse_materialize : materialize)(ram_matrices.S, rand_params), + ) F = copy(ram_matrices.F) # pre-allocate some matrices diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index af6425dd..534801ff 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -214,7 +214,7 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) one_Σ⁻¹Σₒ[diagind(one_Σ⁻¹Σₒ)] .+= 1 C = F⨉I_A⁻¹' * one_Σ⁻¹Σₒ * Σ⁻¹ * F⨉I_A⁻¹ - mul!(gradient, ∇A', vec(C * S * I_A⁻¹'), 2, 0) + mul!(gradient, ∇A', vec(C * mul!(similar(C), S, I_A⁻¹')), 2, 0) mul!(gradient, ∇S', vec(C), 1, 1) if MeanStruct(implied) === HasMeanStruct From 62a1a799da7b2fc24e98e4d64af318e7db78dc2f Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:12:45 -0700 Subject: [PATCH 48/74] ML: refactor to minimize allocs * preallocate matrices for intermediate gradient calculation * call mul!() with these preallocating matrices * annotate mul!() arguments as triangular/symmetric to use faster routines --- src/loss/ML/ML.jl | 95 ++++++++++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 34 deletions(-) diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index 534801ff..5052aa8d 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -27,13 +27,19 @@ my_ml = SemML(my_observed, my_implied) Analytic gradients are available, and for models without a meanstructure and `RAMSymbolic` implied type, also analytic hessians. 
""" -struct SemML{O, I, HE <: HessianEval, INV, M, M2} <: SemLoss{O, I} +struct SemML{O, I, HE <: HessianEval, M} <: SemLoss{O, I} observed::O implied::I hessianeval::HE - Σ⁻¹::INV - Σ⁻¹Σₒ::M - meandiff::M2 + + # pre-allocated arrays to store intermediate results in evaluate!() + obsXobs_1::M + obsXobs_2::M + obsXobs_3::M + obsXvar_1::M + varXvar_1::M + varXvar_2::M + varXvar_3::M end ############################################################################################ @@ -63,26 +69,27 @@ function SemML( end # check integrity check_observed_vars(observed, implied) + @assert isnothing(refloss) || ( + nobserved_vars(refloss.observed) == nobserved_vars(observed) && + nvars(refloss.implied) == nvars(implied) + ) he = approximate_hessian ? ApproxHessian() : ExactHessian() - obsmean = obs_mean(observed) - obscov = obs_cov(observed) - meandiff = isnothing(obsmean) ? nothing : copy(obsmean) - - return SemML{ - typeof(observed), - typeof(implied), - typeof(he), - typeof(obscov), - typeof(obscov), - typeof(meandiff), - }( + obsXobs = parent(obs_cov(observed)) + nobs = nobserved_vars(observed) + nvar = nvars(implied) + + return SemML{typeof(observed), typeof(implied), typeof(he), typeof(obsXobs)}( observed, implied, he, - similar(obscov), - similar(obscov), - meandiff, + isnothing(refloss) ? similar(obsXobs) : refloss.obsXobs_1, + isnothing(refloss) ? similar(obsXobs) : refloss.obsXobs_2, + isnothing(refloss) ? similar(obsXobs) : refloss.obsXobs_3, + isnothing(refloss) ? similar(obsXobs, (nobs, nvar)) : refloss.obsXvar_1, + isnothing(refloss) ? similar(obsXobs, (nvar, nvar)) : refloss.varXvar_1, + isnothing(refloss) ? similar(obsXobs, (nvar, nvar)) : refloss.varXvar_2, + isnothing(refloss) ? 
similar(obsXobs, (nvar, nvar)) : refloss.varXvar_3, ) end @@ -109,10 +116,8 @@ function evaluate!( Σ = implied.Σ Σₒ = obs_cov(observed(loss)) - Σ⁻¹Σₒ = loss.Σ⁻¹Σₒ - Σ⁻¹ = loss.Σ⁻¹ - copyto!(Σ⁻¹, Σ) + Σ⁻¹ = copy!(loss.obsXobs_1, Σ) Σ_chol = cholesky!(Symmetric(Σ⁻¹); check = false) if !isposdef(Σ_chol) #@warn "∑⁻¹ is not positive definite" @@ -123,7 +128,7 @@ function evaluate!( end ld = logdet(Σ_chol) Σ⁻¹ = LinearAlgebra.inv!(Σ_chol) - mul!(Σ⁻¹Σₒ, Σ⁻¹, Σₒ) + Σ⁻¹Σₒ = mul!(loss.obsXobs_2, Σ⁻¹, Σₒ) isnothing(objective) || (objective = ld + tr(Σ⁻¹Σₒ)) if MeanStruct(implied) === HasMeanStruct @@ -136,12 +141,15 @@ function evaluate!( ∇Σ = implied.∇Σ ∇μ = implied.∇μ μ₋ᵀΣ⁻¹ = μ₋' * Σ⁻¹ - mul!(gradient, ∇Σ', vec(Σ⁻¹ - Σ⁻¹Σₒ * Σ⁻¹ - μ₋ᵀΣ⁻¹'μ₋ᵀΣ⁻¹)) + J = copyto!(loss.obsXobs_3, Σ⁻¹) + mul!(J, Σ⁻¹Σₒ, Σ⁻¹, -1, 1) + mul!(J, μ₋ᵀΣ⁻¹', μ₋ᵀΣ⁻¹, -1, 1) + mul!(gradient, ∇Σ', vec(J)) mul!(gradient, ∇μ', μ₋ᵀΣ⁻¹', -2, 1) end elseif !isnothing(gradient) || !isnothing(hessian) ∇Σ = implied.∇Σ - Σ⁻¹ΣₒΣ⁻¹ = Σ⁻¹Σₒ * Σ⁻¹ + Σ⁻¹ΣₒΣ⁻¹ = mul!(loss.obsXobs_3, Σ⁻¹Σₒ, Σ⁻¹) J = vec(Σ⁻¹ - Σ⁻¹ΣₒΣ⁻¹)' if !isnothing(gradient) mul!(gradient, ∇Σ', J') @@ -175,9 +183,8 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) Σ = implied.Σ Σₒ = obs_cov(observed(loss)) - Σ⁻¹Σₒ = loss.Σ⁻¹Σₒ - Σ⁻¹ = loss.Σ⁻¹ - copyto!(Σ⁻¹, Σ) + + Σ⁻¹ = copy!(loss.obsXobs_1, Σ) Σ_chol = cholesky!(Symmetric(Σ⁻¹); check = false) if !isposdef(Σ_chol) #@warn "Σ⁻¹ is not positive definite" @@ -188,7 +195,7 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) end ld = logdet(Σ_chol) Σ⁻¹ = LinearAlgebra.inv!(Σ_chol) - mul!(Σ⁻¹Σₒ, Σ⁻¹, Σₒ) + Σ⁻¹Σₒ = mul!(loss.obsXobs_2, Σ⁻¹, Σₒ) if !isnothing(objective) objective = ld + tr(Σ⁻¹Σₒ) @@ -210,11 +217,25 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) # reuse Σ⁻¹Σₒ to calculate I-Σ⁻¹Σₒ one_Σ⁻¹Σₒ = Σ⁻¹Σₒ - one_Σ⁻¹Σₒ .*= -1 + lmul!(-1, one_Σ⁻¹Σₒ) one_Σ⁻¹Σₒ[diagind(one_Σ⁻¹Σₒ)] .+= 1 - C = F⨉I_A⁻¹' * one_Σ⁻¹Σₒ * Σ⁻¹ * F⨉I_A⁻¹ - mul!(gradient, 
∇A', vec(C * mul!(similar(C), S, I_A⁻¹')), 2, 0) + C = mul!( + loss.varXvar_1, + F⨉I_A⁻¹', + mul!( + loss.obsXvar_1, + Symmetric(mul!(loss.obsXobs_3, one_Σ⁻¹Σₒ, Σ⁻¹)), + F⨉I_A⁻¹, + ), + ) + mul!( + gradient, + ∇A', + vec(mul!(loss.varXvar_3, Symmetric(C), mul!(loss.varXvar_2, S, I_A⁻¹'))), + 2, + 0, + ) mul!(gradient, ∇S', vec(C), 1, 1) if MeanStruct(implied) === HasMeanStruct @@ -226,8 +247,14 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) μ₋ᵀΣ⁻¹ = μ₋' * Σ⁻¹ k = μ₋ᵀΣ⁻¹ * F⨉I_A⁻¹ mul!(gradient, ∇M', k', -2, 1) - mul!(gradient, ∇A', vec(k' * (I_A⁻¹ * (M + S * k'))'), -2, 1) - mul!(gradient, ∇S', vec(k'k), -1, 1) + mul!( + gradient, + ∇A', + vec(mul!(loss.varXvar_1, k', (I_A⁻¹ * (M + S * k'))')), + -2, + 1, + ) + mul!(gradient, ∇S', vec(mul!(loss.varXvar_2, k', k)), -1, 1) end end From 1f225ccc30affce3b2045de4b5b644d28d9bd567 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:12:45 -0700 Subject: [PATCH 49/74] add PackageExtensionCompat --- Project.toml | 1 + src/StructuralEquationModels.jl | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index eabc5b36..f56b4ed7 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index 0dbcd16a..5c3b6030 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -15,7 +15,8 @@ using LinearAlgebra, LazyArtifacts, DelimitedFiles, DataFrames, - ProgressMeter + ProgressMeter, + PackageExtensionCompat 
import StatsAPI: params, coef, coefnames, dof, fit, nobs, coeftable @@ -210,4 +211,9 @@ export AbstractSem, ←, ↔, ⇔ + +function __init__() + @require_extensions +end + end From 708345fa943c192fe48e4236b1c28d5fcfa7e9c2 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:12:45 -0700 Subject: [PATCH 50/74] variance_params(SEMSpec) --- src/frontend/specification/ParameterTable.jl | 12 ++++++++++++ src/frontend/specification/RAMMatrices.jl | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/frontend/specification/ParameterTable.jl b/src/frontend/specification/ParameterTable.jl index c9b9dc24..102f81b2 100644 --- a/src/frontend/specification/ParameterTable.jl +++ b/src/frontend/specification/ParameterTable.jl @@ -405,6 +405,18 @@ function update_se_hessian!( return update_partable!(partable, :se, param_labels(fit), se) end +function variance_params(partable::ParameterTable) + res = [ + param for (param, rel, from, to) in zip( + partable.columns.param, + partable.columns.relation, + partable.columns.from, + partable.columns.to, + ) if (rel == :↔) && (from == to) + ] + unique!(res) +end + """ lavaan_params!(out::AbstractVector, partable_lav, partable::ParameterTable, diff --git a/src/frontend/specification/RAMMatrices.jl b/src/frontend/specification/RAMMatrices.jl index d430e9c0..ca77041d 100644 --- a/src/frontend/specification/RAMMatrices.jl +++ b/src/frontend/specification/RAMMatrices.jl @@ -62,6 +62,17 @@ function latent_vars(ram::RAMMatrices) end end +function variance_params(ram::RAMMatrices) + S_diaginds = Set(diagind(ram.S)) + varparams = Vector{Symbol}() + for (i, param) in enumerate(ram.params) + if any(∈(S_diaginds), param_occurences(ram.S, i)) + push!(varparams, param) + end + end + return unique!(varparams) +end + ############################################################################################ ### Constructor ############################################################################################ From 
70ac792e36e2ca3019953069482f1dad0f7f9ee4 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 51/74] predict_latent_vars() --- src/StructuralEquationModels.jl | 1 + src/frontend/predict.jl | 119 ++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 src/frontend/predict.jl diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index 5c3b6030..45a12dae 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -46,6 +46,7 @@ include("frontend/specification/StenoGraphs.jl") include("frontend/fit/summary.jl") include("frontend/StatsAPI.jl") include("frontend/finite_diff.jl") +include("frontend/predict.jl") # pretty printing include("frontend/pretty_printing.jl") # observed diff --git a/src/frontend/predict.jl b/src/frontend/predict.jl new file mode 100644 index 00000000..61eecaf8 --- /dev/null +++ b/src/frontend/predict.jl @@ -0,0 +1,119 @@ +abstract type SemScoresPredictMethod end + +struct SemRegressionScores <: SemScoresPredictMethod end +struct SemBartlettScores <: SemScoresPredictMethod end +struct SemAndersonRubinScores <: SemScoresPredictMethod end + +function SemScoresPredictMethod(method::Symbol) + if method == :regression + return SemRegressionScores() + elseif method == :Bartlett + return SemBartlettScores() + elseif method == :AndersonRubin + return SemAndersonRubinScores() + else + throw(ArgumentError("Unsupported prediction method: $method")) + end +end + +predict_latent_scores( + fit::SemFit, + data::SemObserved = observed(sem_term(fit.model)); + method::Symbol = :regression, +) = predict_latent_scores(SemScoresPredictMethod(method), fit, data) + +predict_latent_scores( + method::SemScoresPredictMethod, + fit::SemFit, + data::SemObserved = observed(sem_term(fit.model)), +) = predict_latent_scores(method, loss(sem_term(fit.model)), fit.solution, data) + +function inv_cov!(A::AbstractMatrix) + if istril(A) + A = 
LowerTriangular(A) + elseif istriu(A) + A = UpperTriangular(A) + else + end + A_chol = Cholesky(A) + return inv!(A_chol) +end + +function latent_scores_operator( + ::SemRegressionScores, + model::SemLoss, + params::AbstractVector, +) + implied = SEM.implied(model) + ram = implied.ram_matrices + lv_inds = latent_var_indices(ram) + + A = materialize(ram.A, params) + lv_FA = ram.F * A[:, lv_inds] + lv_I_A⁻¹ = inv(I - A)[lv_inds, :] + + S = materialize(ram.S, params) + + cov_lv = lv_I_A⁻¹ * S * lv_I_A⁻¹' + Σ = implied.Σ + Σ⁻¹ = inv(Σ) + return cov_lv * lv_FA' * Σ⁻¹ +end + +function latent_scores_operator(::SemBartlettScores, model::SemLoss, params::AbstractVector) + implied = SEM.implied(model) + ram = implied.ram_matrices + lv_inds = latent_var_indices(ram) + A = materialize(ram.A, params) + lv_FA = ram.F * A[:, lv_inds] + + S = materialize(ram.S, params) + obs_inds = observed_var_indices(ram) + ov_S⁻¹ = inv(S[obs_inds, obs_inds]) + + return inv(lv_FA' * ov_S⁻¹ * lv_FA) * lv_FA' * ov_S⁻¹ +end + +function predict_latent_scores( + method::SemScoresPredictMethod, + model::SemLoss, + params::AbstractVector, + data::SemObserved = observed(model), +) + nobserved_vars(data) == nobserved_vars(model) || throw( + DimensionMismatch( + "Number of variables in data ($(nsamples(data))) does not match the number of observed variables in the model ($(nobserved_vars(model)))", + ), + ) + length(params) == nparams(model) || throw( + DimensionMismatch( + "The length of parameters vector ($(length(params))) does not match the number of parameters in the model ($(nparams(model)))", + ), + ) + + implied = SEM.implied(model) + update!(EvaluationTargets(0.0, nothing, nothing), implied, params) + ram = implied.ram_matrices + lv_inds = latent_var_indices(ram) + A = materialize(ram.A, params) + lv_I_A⁻¹ = inv(I - A)[lv_inds, :] + + lv_scores_op = latent_scores_operator(method, model, params) + + data = + data.data .- (isnothing(data.obs_mean) ? 
mean(data.data, dims = 1) : data.obs_mean') + lv_scores = data * lv_scores_op' + if MeanStruct(implied) === HasMeanStruct + M = materialize(ram.M, params) + lv_scores .+= (lv_I_A⁻¹ * M)' + end + + return lv_scores +end + +predict_latent_scores( + model::SemLoss, + params::AbstractVector, + data::SemObserved = observed(model); + method::Symbol = :regression, +) = predict_latent_scores(SemScoresPredictMethod(method), model, params, data) From 429723e8bedbbf1c82eeefe1acba8ec08db07e06 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 52/74] fixup docstring --- src/frontend/specification/ParameterTable.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/frontend/specification/ParameterTable.jl b/src/frontend/specification/ParameterTable.jl index 102f81b2..daca3dee 100644 --- a/src/frontend/specification/ParameterTable.jl +++ b/src/frontend/specification/ParameterTable.jl @@ -293,11 +293,11 @@ function update_partable!( end """ - (1) update_partable!(partable::AbstractParameterTable, column, fitted:SemFit, params, default = nothing) - - (2) update_partable!(partable::AbstractParameterTable, column, param_labels::Vector{Symbol}, params, default = nothing) + update_partable!(partable::AbstractParameterTable, column, fitted:SemFit, params, default = nothing) -Add a new column to a parameter table. + update_partable!(partable::AbstractParameterTable, column, param_labels::Vector{Symbol}, params, default = nothing) + +Add a new column to a parameter table. `column` is the name of the column, `params` contains the values of the new column, and `fitted` or `param_labels` is used to match the values to the correct parameter labels. The `default` value is used if a parameter in `partable` does not occur in `param_labels`. 
From 255017ef56ce630827495b169a332a3fbdf0bc27 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 53/74] lavaan_model() --- src/frontend/specification/ParameterTable.jl | 75 ++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/src/frontend/specification/ParameterTable.jl b/src/frontend/specification/ParameterTable.jl index daca3dee..739ad571 100644 --- a/src/frontend/specification/ParameterTable.jl +++ b/src/frontend/specification/ParameterTable.jl @@ -552,3 +552,78 @@ lavaan_params( lav_col::Symbol = :est, lav_group = nothing, ) = lavaan_params!(fill(NaN, nparams(partable)), partable_lav, partable, lav_col, lav_group) + +""" + lavaan_model(partable::ParameterTable) + +Generate lavaan model definition from a `partable`. +""" +function lavaan_model(partable::ParameterTable) + latent_vars = Set(partable.variables.latent) + observed_vars = Set(partable.variables.observed) + + variance_defs = Dict{Symbol, IOBuffer}() + latent_dep_defs = Dict{Symbol, IOBuffer}() + latent_regr_defs = Dict{Symbol, IOBuffer}() + observed_regr_defs = Dict{Symbol, IOBuffer}() + + model = IOBuffer() + for (from, to, rel, param, value, free) in zip( + partable.columns.from, + partable.columns.to, + partable.columns.relation, + partable.columns.param, + partable.columns.value_fixed, + partable.columns.free, + ) + function append_param(io) + if free + @assert param != :const + param == Symbol("") || write(io, "$param * ") + else + write(io, "$value * ") + end + end + function append_rhs(io) + if position(io) > 0 + write(io, " + ") + end + append_param(io) + write(io, "$to") + end + + if from == Symbol("1") + write(model, "$to ~ ") + append_param(model) + write(model, "1\n") + else + if rel == :↔ + variance_def = get!(() -> IOBuffer(), variance_defs, from) + append_rhs(variance_def) + elseif rel == :→ + if (from ∈ latent_vars) && (to ∈ observed_vars) + latent_dep_def = get!(() -> IOBuffer(), latent_dep_defs, from) + 
append_rhs(latent_dep_def) + elseif (from ∈ latent_vars) && (to ∈ latent_vars) + latent_regr_def = get!(() -> IOBuffer(), latent_regr_defs, from) + append_rhs(latent_regr_def) + else + observed_regr_def = get!(() -> IOBuffer(), observed_regr_defs, from) + append_rhs(observed_regr_def) + end + end + end + end + function write_rules(io, defs, relation) + vars = sort!(collect(keys(defs))) + for var in vars + write(io, String(var), " ", relation, " ") + write(io, String(take!(defs[var])), "\n") + end + end + write_rules(model, latent_dep_defs, "=~") + write_rules(model, latent_regr_defs, "~") + write_rules(model, observed_regr_defs, "~") + write_rules(model, variance_defs, "~~") + return String(take!(model)) +end From 29edb0805a6e316045f702d77141634f89b71a90 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 54/74] test_grad/hess(): check that alt calls give same results --- test/examples/helper.jl | 56 ++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/test/examples/helper.jl b/test/examples/helper.jl index fed95f3c..daf34bb4 100644 --- a/test/examples/helper.jl +++ b/test/examples/helper.jl @@ -9,44 +9,38 @@ function test_gradient(model, params; rtol = 1e-10, atol = 0) @test nparams(model) == length(params) true_grad = FiniteDiff.finite_difference_gradient(Base.Fix1(objective!, model), params) - gradient = similar(params) - # F and G - fill!(gradient, NaN) - gradient!(gradient, model, params) - @test gradient ≈ true_grad rtol = rtol atol = atol + gradient_G = fill!(similar(params), NaN) + gradient!(gradient_G, model, params) + gradient_FG = fill!(similar(params), NaN) + objective_gradient!(gradient_FG, model, params) - # only G - fill!(gradient, NaN) - objective_gradient!(gradient, model, params) - @test gradient ≈ true_grad rtol = rtol atol = atol + @test gradient_G == gradient_FG + + #@info "G norm = $(norm(gradient_G - true_grad, Inf))" + @test gradient_G ≈ 
true_grad rtol = rtol atol = atol end function test_hessian(model, params; rtol = 1e-4, atol = 0) true_hessian = FiniteDiff.finite_difference_hessian(Base.Fix1(objective!, model), params) - hessian = similar(params, size(true_hessian)) - gradient = similar(params) - - # H - fill!(hessian, NaN) - hessian!(hessian, model, params) - @test hessian ≈ true_hessian rtol = rtol atol = atol - - # F and H - fill!(hessian, NaN) - objective_hessian!(hessian, model, params) - @test hessian ≈ true_hessian rtol = rtol atol = atol - - # G and H - fill!(hessian, NaN) - gradient_hessian!(gradient, hessian, model, params) - @test hessian ≈ true_hessian rtol = rtol atol = atol - - # F, G and H - fill!(hessian, NaN) - objective_gradient_hessian!(gradient, hessian, model, params) - @test hessian ≈ true_hessian rtol = rtol atol = atol + gradient = fill!(similar(params), NaN) + + hessian_H = fill!(similar(parent(true_hessian)), NaN) + hessian!(hessian_H, model, params) + + hessian_FH = fill!(similar(hessian_H), NaN) + objective_hessian!(hessian_FH, model, params) + + hessian_GH = fill!(similar(hessian_H), NaN) + gradient_hessian!(gradient, hessian_GH, model, params) + + hessian_FGH = fill!(similar(hessian_H), NaN) + objective_gradient_hessian!(gradient, hessian_FGH, model, params) + + @test hessian_H == hessian_FH == hessian_GH == hessian_FGH + + @test hessian_H ≈ true_hessian rtol = rtol atol = atol end # map from the SEM.jl name of the fit measure to the lavaan's one From 3945afa4ea0dec63b2787d56380d90911f4f2114 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 55/74] start_simple(): code cleanup also warn if params are overwritten --- .../start_val/start_simple.jl | 73 ++++++++++--------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/src/additional_functions/start_val/start_simple.jl b/src/additional_functions/start_val/start_simple.jl index afdbf92e..4f9d7063 100644 --- 
a/src/additional_functions/start_val/start_simple.jl +++ b/src/additional_functions/start_val/start_simple.jl @@ -33,52 +33,59 @@ function start_simple( start_means = 0.0, kwargs..., ) - A, S, F_ind, M, n_par = ram_matrices.A, - ram_matrices.S, - observed_var_indices(ram_matrices), - ram_matrices.M, - nparams(ram_matrices) + A, S, M = ram_matrices.A, ram_matrices.S, ram_matrices.M + obs_inds = Set(observed_var_indices(ram_matrices)) + C_indices = CartesianIndices(size(A)) - start_val = zeros(n_par) - n_var = nvars(ram_matrices) + start_vals = Vector{Float64}(undef, nparams(ram_matrices)) + for i in eachindex(start_vals) + par = 0.0 - C_indices = CartesianIndices((n_var, n_var)) - - for i in 1:n_par Si_ind = param_occurences(S, i) - Ai_ind = param_occurences(A, i) if length(Si_ind) != 0 # use the first occurence of the parameter to determine starting value c_ind = C_indices[Si_ind[1]] if c_ind[1] == c_ind[2] - if c_ind[1] ∈ F_ind - start_val[i] = start_variances_observed - else - start_val[i] = start_variances_latent - end + par = ifelse( + c_ind[1] ∈ obs_inds, + start_variances_observed, + start_variances_latent, + ) else - o1 = c_ind[1] ∈ F_ind - o2 = c_ind[2] ∈ F_ind - if o1 & o2 - start_val[i] = start_covariances_observed - elseif !o1 & !o2 - start_val[i] = start_covariances_latent - else - start_val[i] = start_covariances_obs_lat - end + o1 = c_ind[1] ∈ obs_inds + o2 = c_ind[2] ∈ obs_inds + par = ifelse( + o1 && o2, + start_covariances_observed, + ifelse(!o1 && !o2, start_covariances_latent, start_covariances_obs_lat), + ) end - elseif length(Ai_ind) != 0 + end + + Ai_ind = param_occurences(A, i) + if length(Ai_ind) != 0 + iszero(par) || + @warn "param[$i]=$(params(ram_matrices, i)) is already set to $par" c_ind = C_indices[Ai_ind[1]] - if (c_ind[1] ∈ F_ind) & !(c_ind[2] ∈ F_ind) - start_val[i] = start_loadings - else - start_val[i] = start_regressions + par = ifelse( + (c_ind[1] ∈ obs_inds) && !(c_ind[2] ∈ obs_inds), + start_loadings, + start_regressions, + ) 
+ end + + if !isnothing(M) + Mi_inds = param_occurences(M, i) + if length(Mi_inds) != 0 + iszero(par) || + @warn "param[$i]=$(params(ram_matrices, i)) is already set to $par" + par = start_means end - elseif !isnothing(M) && (length(param_occurences(M, i)) != 0) - start_val[i] = start_means end + + start_vals[i] = par end - return start_val + return start_vals end # multigroup models From f7b3176f2081030189fb8637d8bff98d95d0d762 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 56/74] start_simple(): start vals for lat and obs means --- src/additional_functions/start_val/start_simple.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/additional_functions/start_val/start_simple.jl b/src/additional_functions/start_val/start_simple.jl index 4f9d7063..2228bc92 100644 --- a/src/additional_functions/start_val/start_simple.jl +++ b/src/additional_functions/start_val/start_simple.jl @@ -30,7 +30,8 @@ function start_simple( start_covariances_observed = 0.0, start_covariances_latent = 0.0, start_covariances_obs_lat = 0.0, - start_means = 0.0, + start_mean_latent = 0.0, + start_mean_observed = 0.0, kwargs..., ) A, S, M = ram_matrices.A, ram_matrices.S, ram_matrices.M @@ -79,7 +80,7 @@ function start_simple( if length(Mi_inds) != 0 iszero(par) || @warn "param[$i]=$(params(ram_matrices, i)) is already set to $par" - par = start_means + par = ifelse(Mi_inds[1] ∈ obs_inds, start_mean_observed, start_mean_latent) end end From 79f14e3dcc4ef1a204da787d7ff3b93b937c39d7 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 57/74] observed_vars(RAMMatrices; order): rows/cols order --- src/frontend/specification/RAMMatrices.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/frontend/specification/RAMMatrices.jl b/src/frontend/specification/RAMMatrices.jl index ca77041d..866c9c12 100644 --- a/src/frontend/specification/RAMMatrices.jl 
+++ b/src/frontend/specification/RAMMatrices.jl @@ -36,17 +36,22 @@ end latent_var_indices(ram::RAMMatrices) = [i for i in axes(ram.F, 2) if islatent_var(ram, i)] -# observed variables in the order as they appear in ram.F rows -function observed_vars(ram::RAMMatrices) +# observed variables, if order=:rows, the order is as they appear in ram.F rows +# if order=:columns, the order is as they appear in the comined variables list (ram.F columns) +function observed_vars(ram::RAMMatrices; order::Symbol = :rows) + order ∈ [:rows, :columns] || + throw(ArgumentError("order kwarg should be :rows or :columns")) if isnothing(ram.vars) @warn "Your RAMMatrices do not contain variable names. Please make sure the order of variables in your data is correct!" return nothing else + nobs = 0 obs_vars = Vector{Symbol}(undef, nobserved_vars(ram)) @inbounds for (i, v) in enumerate(vars(ram)) colptr = ram.F.colptr[i] if ram.F.colptr[i+1] > colptr # is observed - obs_vars[ram.F.rowval[colptr]] = v + nobs += 1 + obs_vars[order == :rows ? 
ram.F.rowval[colptr] : nobs] = v end end return obs_vars From 01910421cf95c3d994e71130f0e2ffc9eda98c0f Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 58/74] observed_var_indices(::RAMMatrices; order=:columns) --- src/frontend/specification/RAMMatrices.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/frontend/specification/RAMMatrices.jl b/src/frontend/specification/RAMMatrices.jl index 866c9c12..e20f83d5 100644 --- a/src/frontend/specification/RAMMatrices.jl +++ b/src/frontend/specification/RAMMatrices.jl @@ -22,13 +22,16 @@ vars(ram::RAMMatrices) = ram.vars isobserved_var(ram::RAMMatrices, i::Integer) = ram.F.colptr[i+1] > ram.F.colptr[i] islatent_var(ram::RAMMatrices, i::Integer) = ram.F.colptr[i+1] == ram.F.colptr[i] -# indices of observed variables in the order as they appear in ram.F rows -function observed_var_indices(ram::RAMMatrices) +# indices of observed variables, for order=:rows (default), the order is as they appear in ram.F rows +# if order=:columns, the order is as they appear in the comined variables list (ram.F columns) +function observed_var_indices(ram::RAMMatrices; order::Symbol = :rows) obs_inds = Vector{Int}(undef, nobserved_vars(ram)) + nobs = 0 @inbounds for i in 1:nvars(ram) colptr = ram.F.colptr[i] if ram.F.colptr[i+1] > colptr # is observed - obs_inds[ram.F.rowval[colptr]] = i + nobs += 1 + obs_inds[order == :rows ? 
ram.F.rowval[colptr] : nobs] = i end end return obs_inds From e08fc5e58ef753be0fb48d1329071a709eb0c0ab Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 59/74] move sparse mtx utils to new file use sparse utils in RAMMatrices --- src/StructuralEquationModels.jl | 1 + src/additional_functions/sparse_utils.jl | 83 +++++++++++++++++++++++ src/frontend/specification/RAMMatrices.jl | 33 ++------- 3 files changed, 88 insertions(+), 29 deletions(-) create mode 100644 src/additional_functions/sparse_utils.jl diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index 45a12dae..fb833660 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -32,6 +32,7 @@ include("objective_gradient_hessian.jl") # helper objects and functions include("additional_functions/commutation_matrix.jl") +include("additional_functions/sparse_utils.jl") include("additional_functions/params_array.jl") # fitted objects diff --git a/src/additional_functions/sparse_utils.jl b/src/additional_functions/sparse_utils.jl new file mode 100644 index 00000000..76381c47 --- /dev/null +++ b/src/additional_functions/sparse_utils.jl @@ -0,0 +1,83 @@ +# generate sparse matrix with 1 in each row +function eachrow_to_col( + ::Type{T}, + column_indices::AbstractVector{Int}, + ncolumns::Integer, +) where {T} + nrows = length(column_indices) + (nrows > ncolumns) && throw( + DimensionMismatch( + "The number of rows ($nrows) cannot exceed the number of columns ($ncolumns)", + ), + ) + all(i -> 1 <= i <= ncolumns, column_indices) || + throw(ArgumentError("All column indices must be between 1 and $ncolumns")) + + sparse(eachindex(column_indices), column_indices, ones(T, nrows), nrows, ncolumns) +end + +eachrow_to_col(column_indices::AbstractVector{Int}, ncolumns::Integer) = + eachrow_to_col(Float64, column_indices, ncolumns) + +# generate sparse matrix with 1 in each column +function eachcol_to_row( + ::Type{T}, + 
row_indices::AbstractVector{Int}, + nrows::Integer, +) where {T} + ncols = length(row_indices) + (ncols > nrows) && throw( + DimensionMismatch( + "The number of columns ($ncols) cannot exceed the number of rows ($nrows)", + ), + ) + all(i -> 1 <= i <= nrows, row_indices) || + throw(ArgumentError("All row indices must be between 1 and $nrows")) + + sparse(row_indices, eachindex(row_indices), ones(T, ncols), nrows, ncols) +end + +eachcol_to_row(row_indices::AbstractVector{Int}, nrows::Integer) = + eachcol_to_row(Float64, row_indices, nrows) + +# non-zero indices of columns in matrix A generated by eachrow_to_col() +# if order == :rows, then the indices are in the order of the rows, +# if order == :columns, the indices are in the order of the columns +function nzcols_eachrow_to_col(F, A::SparseMatrixCSC; order::Symbol = :rows) + order ∈ [:rows, :columns] || throw(ArgumentError("order must be :rows or :columns")) + T = typeof(F(1)) + res = Vector{T}(undef, size(A, 1)) + n = 0 + for i in 1:size(A, 2) + colptr = A.colptr[i] + next_colptr = A.colptr[i+1] + if next_colptr > colptr # non-zero + @assert next_colptr - colptr == 1 + n += 1 + res[order == :rows ? 
A.rowval[colptr] : n] = F(i) + end + end + @assert n == size(A, 1) + return res +end + +nzcols_eachrow_to_col(A::SparseMatrixCSC; order::Symbol = :rows) = + nzcols_eachrow_to_col(identity, A, order = order) + +# same as nzcols_eachrow_to_col() +# but without assumption that each row cooresponds to exactly one column +# the order is always columns order +function nzcols(F, A::SparseMatrixCSC) + T = typeof(F(1)) + res = Vector{T}() + for i in 1:size(A, 2) + colptr = A.colptr[i] + next_colptr = A.colptr[i+1] + if next_colptr > colptr # non-zero + push!(res, F(i)) + end + end + return res +end + +nzcols(A::SparseMatrixCSC) = nzcols(identity, A) diff --git a/src/frontend/specification/RAMMatrices.jl b/src/frontend/specification/RAMMatrices.jl index e20f83d5..7997cc32 100644 --- a/src/frontend/specification/RAMMatrices.jl +++ b/src/frontend/specification/RAMMatrices.jl @@ -24,18 +24,8 @@ islatent_var(ram::RAMMatrices, i::Integer) = ram.F.colptr[i+1] == ram.F.colptr[i # indices of observed variables, for order=:rows (default), the order is as they appear in ram.F rows # if order=:columns, the order is as they appear in the comined variables list (ram.F columns) -function observed_var_indices(ram::RAMMatrices; order::Symbol = :rows) - obs_inds = Vector{Int}(undef, nobserved_vars(ram)) - nobs = 0 - @inbounds for i in 1:nvars(ram) - colptr = ram.F.colptr[i] - if ram.F.colptr[i+1] > colptr # is observed - nobs += 1 - obs_inds[order == :rows ? ram.F.rowval[colptr] : nobs] = i - end - end - return obs_inds -end +observed_var_indices(ram::RAMMatrices; order::Symbol = :rows) = + nzcols_eachrow_to_col(ram.F; order) latent_var_indices(ram::RAMMatrices) = [i for i in axes(ram.F, 2) if islatent_var(ram, i)] @@ -48,16 +38,7 @@ function observed_vars(ram::RAMMatrices; order::Symbol = :rows) @warn "Your RAMMatrices do not contain variable names. Please make sure the order of variables in your data is correct!" 
return nothing else - nobs = 0 - obs_vars = Vector{Symbol}(undef, nobserved_vars(ram)) - @inbounds for (i, v) in enumerate(vars(ram)) - colptr = ram.F.colptr[i] - if ram.F.colptr[i+1] > colptr # is observed - nobs += 1 - obs_vars[order == :rows ? ram.F.rowval[colptr] : nobs] = v - end - end - return obs_vars + return nzcols_eachrow_to_col(Base.Fix1(getindex, vars(ram)), ram.F; order = order) end end @@ -240,13 +221,7 @@ function RAMMatrices( return RAMMatrices( ParamsMatrix{T}(A_inds, A_consts, (n_vars, n_vars)), ParamsMatrix{T}(S_inds, S_consts, (n_vars, n_vars)), - sparse( - 1:n_observed, - [vars_index[var] for var in partable.observed_vars], - ones(T, n_observed), - n_observed, - n_vars, - ), + eachrow_to_col(T, [vars_index[var] for var in partable.observed_vars], n_vars), !isnothing(M_inds) ? ParamsVector{T}(M_inds, M_consts, (n_vars,)) : nothing, param_labels, vars_sorted, From b4d738c31bf762dc300d3f4e076f106e2d4eaf46 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:13:14 -0700 Subject: [PATCH 60/74] reorder_observed_vars!(spec) method --- src/frontend/specification/ParameterTable.jl | 12 ++++++++++++ src/frontend/specification/RAMMatrices.jl | 13 +++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/frontend/specification/ParameterTable.jl b/src/frontend/specification/ParameterTable.jl index 739ad571..a29deae7 100644 --- a/src/frontend/specification/ParameterTable.jl +++ b/src/frontend/specification/ParameterTable.jl @@ -247,6 +247,18 @@ See [sort_vars!](@ref) for in-place version. 
""" sort_vars(partable::ParameterTable) = sort_vars!(deepcopy(partable)) +function reorder_observed_vars!(partable::ParameterTable, new_order::AbstractVector{Symbol}) + # just check that it's 1-to-1 + source_to_dest_perm( + partable.observed_vars, + new_order, + one_to_one = true, + entities = "observed_vars", + ) + copy!(partable.observed_vars, new_order) + return partable +end + # add a row -------------------------------------------------------------------------------- function Base.push!(partable::ParameterTable, d::Union{AbstractDict{Symbol}, NamedTuple}) diff --git a/src/frontend/specification/RAMMatrices.jl b/src/frontend/specification/RAMMatrices.jl index 7997cc32..023e98b9 100644 --- a/src/frontend/specification/RAMMatrices.jl +++ b/src/frontend/specification/RAMMatrices.jl @@ -234,6 +234,19 @@ Base.convert( param_labels::Union{AbstractVector{Symbol}, Nothing} = nothing, ) = RAMMatrices(partable; param_labels) +# reorders the observed variables in the RAMMatrices, i.e. the order of the rows in F +function reorder_observed_vars!(ram::RAMMatrices, new_order::AbstractVector{Symbol}) + # just check that it's 1-to-1 + src2dest = source_to_dest_perm( + observed_vars(ram), + new_order, + one_to_one = true, + entities = "observed_vars", + ) + copy!(ram.F, ram.F[src2dest, :]) + return ram +end + ############################################################################################ ### get parameter table from RAMMatrices ############################################################################################ From 7b099fc9f36e1586bcef1f6ef00d6eff4a6d3d0d Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:14:28 -0700 Subject: [PATCH 61/74] vech() and vechinds() functions --- src/additional_functions/helper.jl | 33 ++++++++++++++++++++++++++++++ src/implied/RAM/symbolic.jl | 5 +---- src/loss/WLS/WLS.jl | 3 +-- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/additional_functions/helper.jl 
b/src/additional_functions/helper.jl index 5442357f..d7c7d810 100644 --- a/src/additional_functions/helper.jl +++ b/src/additional_functions/helper.jl @@ -69,6 +69,39 @@ function elimination_matrix(n::Integer) return L end +# vector of lower-triangular values of a square matrix +function vech(A::AbstractMatrix{T}) where {T} + size(A, 1) == size(A, 2) || + throw(ArgumentError("Matrix must be square, $(size(A)) given")) + n = size(A, 1) + v = Vector{T}(undef, (n * (n + 1)) >> 1) + k = 0 + for (j, Aj) in enumerate(eachcol(A)), i in j:n + @inbounds v[k+=1] = Aj[i] + end + @assert k == length(v) + return v +end + +# vector of lower-triangular linear indices of a nXn square matrix +function vechinds(n::Integer) + A_lininds = LinearIndices((n, n)) + v = Vector{Int}(undef, (n * (n + 1)) >> 1) + k = 0 + for j in 1:n, i in j:n + @inbounds v[k+=1] = A_lininds[i, j] + end + @assert k == length(v) + return v +end + +# vector of lower-triangular linear indices of a square matrix +function vechinds(A::AbstractMatrix) + size(A, 1) == size(A, 2) || + throw(ArgumentError("Matrix must be square, $(size(A)) given")) + return vechinds(size(A, 1)) +end + # truncate eigenvalues of a symmetric matrix and return the result function trunc_eigvals( mtx::AbstractMatrix{T}, diff --git a/src/implied/RAM/symbolic.jl b/src/implied/RAM/symbolic.jl index 52a192e6..647b5b41 100644 --- a/src/implied/RAM/symbolic.jl +++ b/src/implied/RAM/symbolic.jl @@ -198,10 +198,7 @@ end function eval_Σ_symbolic(S, I_A⁻¹, F; vech::Bool = false, simplify::Bool = false) Σ = F * I_A⁻¹ * S * permutedims(I_A⁻¹) * permutedims(F) Σ = Array(Σ) - if vech - n = size(Σ, 1) - Σ = [Σ[i, j] for j in 1:n for i in j:n] - end + vech && (Σ = SEM.vech(Σ)) if simplify Threads.@threads for i in eachindex(Σ) Σ[i] = Symbolics.simplify(Σ[i]) diff --git a/src/loss/WLS/WLS.jl b/src/loss/WLS/WLS.jl index d067e346..2ccef15b 100644 --- a/src/loss/WLS/WLS.jl +++ b/src/loss/WLS/WLS.jl @@ -95,8 +95,7 @@ function SemWLS( 
check_observed_vars(observed, implied) nobs_vars = nobserved_vars(observed) - tril_ind = filter(x -> (x[1] >= x[2]), CartesianIndices(obs_cov(observed))) - s = obs_cov(observed)[tril_ind] + s = vech(obs_cov(observed)) size(s) == size(implied.Σ) || throw( DimensionMismatch( "SemWLS requires implied covariance to be in vech-ed form " * From 7cc82bc8d11b3e226b22642ede7df3b87b9504be Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:18:27 -0700 Subject: [PATCH 62/74] RAMMatrices(): ctor to replace params --- src/frontend/specification/RAMMatrices.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/frontend/specification/RAMMatrices.jl b/src/frontend/specification/RAMMatrices.jl index 023e98b9..357ab408 100644 --- a/src/frontend/specification/RAMMatrices.jl +++ b/src/frontend/specification/RAMMatrices.jl @@ -113,6 +113,17 @@ function RAMMatrices(; return RAMMatrices(A, S, F, M, copy(param_labels), vars) end +# copy RAMMatrices replacing the parameters vector +# (e.g. when reordering parameters or adding new parameters to the ensemble model) +RAMMatrices(ram::RAMMatrices; params::AbstractVector{Symbol}) = RAMMatrices(; + A = materialize(ram.A, SEM.params(ram)), + S = materialize(ram.S, SEM.params(ram)), + F = copy(ram.F), + M = !isnothing(ram.M) ? 
materialize(ram.M, SEM.params(ram)) : nothing, + params, + vars = ram.vars, +) + ############################################################################################ ### get RAMMatrices from parameter table ############################################################################################ From 3ae9f01840dd428d6922edffe5413159d4b290c6 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:25:18 -0700 Subject: [PATCH 63/74] use `@printf` to limit signif digits printed --- Project.toml | 1 + ext/SEMNLOptExt/NLopt.jl | 2 +- ext/SEMNLOptExt/SEMNLOptExt.jl | 2 +- ext/SEMProximalOptExt/ProximalAlgorithms.jl | 2 +- ext/SEMProximalOptExt/SEMProximalOptExt.jl | 1 + src/StructuralEquationModels.jl | 1 + src/frontend/fit/SemFit.jl | 2 +- src/frontend/fit/summary.jl | 3 +-- 8 files changed, 8 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index f56b4ed7..cdc2f5a7 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/ext/SEMNLOptExt/NLopt.jl b/ext/SEMNLOptExt/NLopt.jl index 87601030..3d430201 100644 --- a/ext/SEMNLOptExt/NLopt.jl +++ b/ext/SEMNLOptExt/NLopt.jl @@ -189,7 +189,7 @@ end function Base.show(io::IO, result::NLoptResult) print(io, "Optimizer status: $(result.result[3]) \n") - print(io, "Minimum: $(round(result.result[1]; digits = 2)) \n") + @printf(io, "Minimum: %.4g\n", result.result[1]) print(io, "Algorithm: $(result.problem.algorithm) \n") print(io, "No. 
evaluations: $(result.problem.numevals) \n") end diff --git a/ext/SEMNLOptExt/SEMNLOptExt.jl b/ext/SEMNLOptExt/SEMNLOptExt.jl index 61c41338..24439459 100644 --- a/ext/SEMNLOptExt/SEMNLOptExt.jl +++ b/ext/SEMNLOptExt/SEMNLOptExt.jl @@ -1,6 +1,6 @@ module SEMNLOptExt -using StructuralEquationModels, NLopt +using StructuralEquationModels, NLopt, Printf SEM = StructuralEquationModels diff --git a/ext/SEMProximalOptExt/ProximalAlgorithms.jl b/ext/SEMProximalOptExt/ProximalAlgorithms.jl index 0937ee04..15c8215f 100644 --- a/ext/SEMProximalOptExt/ProximalAlgorithms.jl +++ b/ext/SEMProximalOptExt/ProximalAlgorithms.jl @@ -99,7 +99,7 @@ function Base.show(io::IO, struct_inst::SemOptimizerProximal) end function Base.show(io::IO, result::ProximalResult) - print(io, "Minimum: $(round(result.minimum; digits = 2)) \n") + @printf(io, "Minimum: %.4g\n", result.minimum) print(io, "No. evaluations: $(result.n_iterations) \n") print(io, "Operator: $(nameof(typeof(result.optimizer.operator_g))) \n") op_h = result.optimizer.operator_h diff --git a/ext/SEMProximalOptExt/SEMProximalOptExt.jl b/ext/SEMProximalOptExt/SEMProximalOptExt.jl index bedf1920..565207f3 100644 --- a/ext/SEMProximalOptExt/SEMProximalOptExt.jl +++ b/ext/SEMProximalOptExt/SEMProximalOptExt.jl @@ -3,6 +3,7 @@ module SEMProximalOptExt using StructuralEquationModels using StructuralEquationModels: print_type_name, print_field_types using ProximalAlgorithms +using Printf export SemOptimizerProximal diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index fb833660..3ba19fba 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -1,6 +1,7 @@ module StructuralEquationModels using LinearAlgebra, + Printf, Optim, NLSolversBase, Statistics, diff --git a/src/frontend/fit/SemFit.jl b/src/frontend/fit/SemFit.jl index 1d2e82a6..f6fbb28b 100644 --- a/src/frontend/fit/SemFit.jl +++ b/src/frontend/fit/SemFit.jl @@ -37,7 +37,7 @@ function Base.show(io::IO, semfit::SemFit) 
print(io, "\n") print(io, semfit.model) print(io, "\n") - #print(io, "Objective value: $(round(semfit.minimum, digits = 4)) \n") + #@printf(io, "Objective value: %.4g\n", semfit.minimum) print(io, "------------- Optimization result ------------- \n") print(io, "\n") print(io, "engine: ") diff --git a/src/frontend/fit/summary.jl b/src/frontend/fit/summary.jl index fe7ea930..3364f7f5 100644 --- a/src/frontend/fit/summary.jl +++ b/src/frontend/fit/summary.jl @@ -69,8 +69,7 @@ function details(sem_fit::SemFit; show_fitmeasures = false, color = :light_cyan, key_length = length(string(k)) print(k) print(repeat(" ", goal_length - key_length)) - print(round(a[k]; digits = 2)) - print("\n") + @printf("%.3g\n", a[k]) end end print("\n") From f15d7fb34488c657879e4406cb549ecb86124678 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:30:16 -0700 Subject: [PATCH 64/74] ML/FIML: workaround generic_matmul issue --- src/loss/ML/FIML.jl | 12 +++++++----- src/loss/ML/ML.jl | 14 +++++--------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/loss/ML/FIML.jl b/src/loss/ML/FIML.jl index 74d5edfb..0a9f66b6 100644 --- a/src/loss/ML/FIML.jl +++ b/src/loss/ML/FIML.jl @@ -181,16 +181,18 @@ end function ∇F_fiml_outer!(G, JΣ, Jμ, loss::SemFIML) implied = loss.implied + I_A⁻¹ = parent(implied.I_A⁻¹) + F⨉I_A⁻¹ = parent(implied.F⨉I_A⁻¹) + S = parent(implied.S) + Iₙ = sparse(1.0I, size(implied.A)...) 
- P = kron(implied.F⨉I_A⁻¹, implied.F⨉I_A⁻¹) - Q = kron(implied.S * implied.I_A⁻¹', Iₙ) + P = kron(F⨉I_A⁻¹, F⨉I_A⁻¹) + Q = kron(S * I_A⁻¹', Iₙ) Q .+= loss.commutator * Q ∇Σ = P * (implied.∇S + Q * implied.∇A) - ∇μ = - implied.F⨉I_A⁻¹ * implied.∇M + - kron((implied.I_A⁻¹ * implied.M)', implied.F⨉I_A⁻¹) * implied.∇A + ∇μ = F⨉I_A⁻¹ * implied.∇M + kron((I_A⁻¹ * implied.M)', F⨉I_A⁻¹) * implied.∇A mul!(G, ∇Σ', JΣ) # actually transposed mul!(G, ∇μ', Jμ, -1, 1) diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index 5052aa8d..58ecbe9f 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -209,9 +209,9 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) end if !isnothing(gradient) - S = implied.S - F⨉I_A⁻¹ = implied.F⨉I_A⁻¹ - I_A⁻¹ = implied.I_A⁻¹ + S = parent(implied.S) + F⨉I_A⁻¹ = parent(implied.F⨉I_A⁻¹) + I_A⁻¹ = parent(implied.I_A⁻¹) ∇A = implied.∇A ∇S = implied.∇S @@ -223,16 +223,12 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) C = mul!( loss.varXvar_1, F⨉I_A⁻¹', - mul!( - loss.obsXvar_1, - Symmetric(mul!(loss.obsXobs_3, one_Σ⁻¹Σₒ, Σ⁻¹)), - F⨉I_A⁻¹, - ), + mul!(loss.obsXvar_1, mul!(loss.obsXobs_3, one_Σ⁻¹Σₒ, Σ⁻¹), F⨉I_A⁻¹), ) mul!( gradient, ∇A', - vec(mul!(loss.varXvar_3, Symmetric(C), mul!(loss.varXvar_2, S, I_A⁻¹'))), + vec(mul!(loss.varXvar_3, C, mul!(loss.varXvar_2, S, I_A⁻¹'))), 2, 0, ) From 5dd1ae58f201d224dfcfba3d1b7933a43178d730 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:38:12 -0700 Subject: [PATCH 65/74] BlackBoxOptim.jl backend support --- Project.toml | 3 + ext/SEMBlackBoxOptimExt/AdamMutation.jl | 49 +++++ ext/SEMBlackBoxOptimExt/BlackBoxOptim.jl | 89 ++++++++ ext/SEMBlackBoxOptimExt/DiffEvoFactory.jl | 196 ++++++++++++++++++ .../SEMBlackBoxOptimExt.jl | 13 ++ .../SemOptimizerBlackBoxOptim.jl | 91 ++++++++ 6 files changed, 441 insertions(+) create mode 100644 ext/SEMBlackBoxOptimExt/AdamMutation.jl create mode 100644 ext/SEMBlackBoxOptimExt/BlackBoxOptim.jl create mode 100644 
ext/SEMBlackBoxOptimExt/DiffEvoFactory.jl create mode 100644 ext/SEMBlackBoxOptimExt/SEMBlackBoxOptimExt.jl create mode 100644 ext/SEMBlackBoxOptimExt/SemOptimizerBlackBoxOptim.jl diff --git a/Project.toml b/Project.toml index cdc2f5a7..d72009b8 100644 --- a/Project.toml +++ b/Project.toml @@ -53,9 +53,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" test = ["Test"] [weakdeps] +BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209" NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" [extensions] SEMNLOptExt = "NLopt" SEMProximalOptExt = "ProximalAlgorithms" +SEMBlackBoxOptimExt = ["BlackBoxOptim", "Optimisers"] diff --git a/ext/SEMBlackBoxOptimExt/AdamMutation.jl b/ext/SEMBlackBoxOptimExt/AdamMutation.jl new file mode 100644 index 00000000..4f1a80e3 --- /dev/null +++ b/ext/SEMBlackBoxOptimExt/AdamMutation.jl @@ -0,0 +1,49 @@ +# mutate by moving in the gradient direction +mutable struct AdamMutation{M <: AbstractSem, O, S} <: MutationOperator + model::M + optim::O + opt_state::S + params_fraction::Float64 + + function AdamMutation(model::AbstractSem, params::AbstractDict) + optim = RAdam(params[:AdamMutation_eta], params[:AdamMutation_beta]) + params_fraction = params[:AdamMutation_params_fraction] + opt_state = Optimisers.init(optim, Vector{Float64}(undef, nparams(model))) + + new{typeof(model), typeof(optim), typeof(opt_state)}( + model, + optim, + opt_state, + params_fraction, + ) + end +end + +Base.show(io::IO, op::AdamMutation) = + print(io, "AdamMutation(", op.optim, " state[3]=", op.opt_state[3], ")") + +""" +Default parameters for `AdamMutation`. 
+""" +const AdamMutation_DefaultOptions = ParamsDict( + :AdamMutation_eta => 1E-1, + :AdamMutation_beta => (0.99, 0.999), + :AdamMutation_params_fraction => 0.25, +) + +function BlackBoxOptim.apply!(m::AdamMutation, v::AbstractVector{<:Real}, target_index::Int) + grad = similar(v) + obj = SEM.evaluate!(0.0, grad, nothing, m.model, v) + @inbounds for i in eachindex(grad) + (rand() > m.params_fraction) && (grad[i] = 0.0) + end + + m.opt_state, dv = Optimisers.apply!(m.optim, m.opt_state, v, grad) + if (m.opt_state[3][1] <= 1E-20) || !isfinite(obj) || any(!isfinite, dv) + m.opt_state = Optimisers.init(m.optim, v) + else + v .-= dv + end + + return v +end diff --git a/ext/SEMBlackBoxOptimExt/BlackBoxOptim.jl b/ext/SEMBlackBoxOptimExt/BlackBoxOptim.jl new file mode 100644 index 00000000..f9d67e06 --- /dev/null +++ b/ext/SEMBlackBoxOptimExt/BlackBoxOptim.jl @@ -0,0 +1,89 @@ +############################################################################################ +### connect to BlackBoxOptim.jl as backend +############################################################################################ + +""" +""" +struct SemOptimizerBlackBoxOptim <: SemOptimizer{:BlackBoxOptim} + lower_bound::Float64 # default lower bound + variance_lower_bound::Float64 # default variance lower bound + lower_bounds::Union{Dict{Symbol, Float64}, Nothing} + + upper_bound::Float64 # default upper bound + upper_bounds::Union{Dict{Symbol, Float64}, Nothing} +end + +function SemOptimizerBlackBoxOptim(; + lower_bound::Float64 = -1000.0, + lower_bounds::Union{AbstractDict{Symbol, Float64}, Nothing} = nothing, + variance_lower_bound::Float64 = 0.001, + upper_bound::Float64 = 1000.0, + upper_bounds::Union{AbstractDict{Symbol, Float64}, Nothing} = nothing, + kwargs..., +) + if variance_lower_bound < 0.0 + throw(ArgumentError("variance_lower_bound must be non-negative")) + end + return SemOptimizerBlackBoxOptim( + lower_bound, + variance_lower_bound, + lower_bounds, + upper_bound, + upper_bounds, + 
) +end + +SEM.SemOptimizer{:BlackBoxOptim}(args...; kwargs...) = + SemOptimizerBlackBoxOptim(args...; kwargs...) + +SEM.algorithm(optimizer::SemOptimizerBlackBoxOptim) = optimizer.algorithm +SEM.options(optimizer::SemOptimizerBlackBoxOptim) = optimizer.options + +struct SemModelBlackBoxOptimProblem{M <: AbstractSem} <: + OptimizationProblem{ScalarFitnessScheme{true}} + model::M + fitness_scheme::ScalarFitnessScheme{true} + search_space::ContinuousRectSearchSpace +end + +function BlackBoxOptim.search_space(model::AbstractSem) + optim = model.optimizer::SemOptimizerBlackBoxOptim + varparams = Set(SEM.variance_params(model.implied.ram_matrices)) + return ContinuousRectSearchSpace( + [ + begin + def = in(p, varparams) ? optim.variance_lower_bound : optim.lower_bound + isnothing(optim.lower_bounds) ? def : get(optim.lower_bounds, p, def) + end for p in SEM.params(model) + ], + [ + begin + def = optim.upper_bound + isnothing(optim.upper_bounds) ? def : get(optim.upper_bounds, p, def) + end for p in SEM.params(model) + ], + ) +end + +function SemModelBlackBoxOptimProblem( + model::AbstractSem, + optimizer::SemOptimizerBlackBoxOptim, +) + SemModelBlackBoxOptimProblem(model, ScalarFitnessScheme{true}(), search_space(model)) +end + +BlackBoxOptim.fitness(params::AbstractVector, wrapper::SemModelBlackBoxOptimProblem) = + return SEM.evaluate!(0.0, nothing, nothing, wrapper.model, params) + +# sem_fit method +function SEM.sem_fit( + optimizer::SemOptimizerBlackBoxOptim, + model::AbstractSem, + start_params::AbstractVector; + MaxSteps::Integer = 50000, + kwargs..., +) + problem = SemModelBlackBoxOptimProblem(model, optimizer) + res = bboptimize(problem; MaxSteps, kwargs...) 
+ return SemFit(best_fitness(res), best_candidate(res), nothing, model, res) +end diff --git a/ext/SEMBlackBoxOptimExt/DiffEvoFactory.jl b/ext/SEMBlackBoxOptimExt/DiffEvoFactory.jl new file mode 100644 index 00000000..75080541 --- /dev/null +++ b/ext/SEMBlackBoxOptimExt/DiffEvoFactory.jl @@ -0,0 +1,196 @@ +""" +Base class for factories of optimizers for a specific problem. +""" +abstract type OptimizerFactory{P <: OptimizationProblem} end + +problem(factory::OptimizerFactory) = factory.problem + +const OptController_DefaultParameters = ParamsDict( + :MaxTime => 60.0, + :MaxSteps => 10^8, + :TraceMode => :compact, + :TraceInterval => 5.0, + :RecoverResults => false, + :SaveTrace => false, +) + +function generate_opt_controller(alg::Optimizer, optim_factory::OptimizerFactory, params) + return BlackBoxOptim.OptController( + alg, + problem(optim_factory), + BlackBoxOptim.chain( + BlackBoxOptim.DefaultParameters, + OptController_DefaultParameters, + params, + ), + ) +end + +function check_population( + factory::OptimizerFactory, + popmatrix::BlackBoxOptim.PopulationMatrix, +) + ssp = factory |> problem |> search_space + for i in 1:popsize(popmatrix) + @assert popmatrix[:, i] ∈ ssp "Individual $i is out of space: $(popmatrix[:,i])" # fitness: $(fitness(population, i))" + end +end + +initial_search_space(factory::OptimizerFactory, id::Int) = search_space(factory.problem) + +function initial_population_matrix(factory::OptimizerFactory, id::Int) + #@info "Standard initial_population_matrix()" + ini_ss = initial_search_space(factory, id) + if !isempty(factory.initial_population) + numdims(factory.initial_population) == numdims(factory.problem) || throw( + DimensionMismatch( + "Dimensions of :Population ($(numdims(factory.initial_population))) " * + "are different from the problem dimensions ($(numdims(factory.problem)))", + ), + ) + res = factory.initial_population[ + :, + StatsBase.sample( + 1:popsize(factory.initial_population), + factory.population_size, + ), + ] + else + 
res = rand_individuals(ini_ss, factory.population_size, method = :latin_hypercube) + end + prj = RandomBound(ini_ss) + if size(res, 2) > 1 + apply!(prj, view(res, :, 1), SEM.start_fabin3(factory.problem.model)) + end + if size(res, 2) > 2 + apply!(prj, view(res, :, 2), SEM.start_simple(factory.problem.model)) + end + return res +end + +# convert individuals in the archive into population matrix +population_matrix(archive::Any) = population_matrix!( + Matrix{Float64}(undef, length(BlackBoxOptim.params(first(archive))), length(archive)), + archive, +) + +function population_matrix!(pop::AbstractMatrix{<:Real}, archive::Any) + npars = length(BlackBoxOptim.params(first(archive))) + size(pop, 1) == npars || throw( + DimensionMismatch( + "Matrix rows count ($(size(pop, 1))) doesn't match the number of problem dimensions ($(npars))", + ), + ) + @inbounds for (i, indi) in enumerate(archive) + (i <= size(pop, 2)) || break + pop[:, i] .= BlackBoxOptim.params(indi) + end + if size(pop, 2) > length(archive) + @warn "Matrix columns count ($(size(pop, 2))) is bigger than population size ($(length(archive))), last columns not set" + end + return pop +end + +generate_embedder(factory::OptimizerFactory, id::Int, problem::OptimizationProblem) = + RandomBound(search_space(problem)) + +abstract type DiffEvoFactory{P <: OptimizationProblem} <: OptimizerFactory{P} end + +generate_selector( + factory::DiffEvoFactory, + id::Int, + problem::OptimizationProblem, + population, +) = RadiusLimitedSelector(get(factory.params, :selector_radius, popsize(population) ÷ 5)) + +function generate_modifier(factory::DiffEvoFactory, id::Int, problem::OptimizationProblem) + ops = GeneticOperator[ + MutationClock(UniformMutation(search_space(problem)), 1 / numdims(problem)), + BlackBoxOptim.AdaptiveDiffEvoRandBin1( + BlackBoxOptim.AdaptiveDiffEvoParameters( + factory.params[:fdistr], + factory.params[:crdistr], + ), + ), + SimplexCrossover{3}(1.05), + SimplexCrossover{2}(1.1), + 
#SimulatedBinaryCrossover(0.05, 16.0), + #SimulatedBinaryCrossover(0.05, 3.0), + #SimulatedBinaryCrossover(0.1, 5.0), + #SimulatedBinaryCrossover(0.2, 16.0), + UnimodalNormalDistributionCrossover{2}( + chain(BlackBoxOptim.UNDX_DefaultOptions, factory.params), + ), + UnimodalNormalDistributionCrossover{3}( + chain(BlackBoxOptim.UNDX_DefaultOptions, factory.params), + ), + ParentCentricCrossover{2}(chain(BlackBoxOptim.PCX_DefaultOptions, factory.params)), + ParentCentricCrossover{3}(chain(BlackBoxOptim.PCX_DefaultOptions, factory.params)), + ] + if problem isa SemModelBlackBoxOptimProblem + push!( + ops, + AdamMutation(problem.model, chain(AdamMutation_DefaultOptions, factory.params)), + ) + end + FAGeneticOperatorsMixture(ops) +end + +function generate_optimizer( + factory::DiffEvoFactory, + id::Int, + problem::OptimizationProblem, + popmatrix, +) + population = FitPopulation(popmatrix, nafitness(fitness_scheme(problem))) + BlackBoxOptim.DiffEvoOpt( + "AdaptiveDE/rand/1/bin/gradient", + population, + generate_selector(factory, id, problem, population), + generate_modifier(factory, id, problem), + generate_embedder(factory, id, problem), + ) +end + +const Population_DefaultParameters = ParamsDict( + :Population => BlackBoxOptim.PopulationMatrix(undef, 0, 0), + :PopulationSize => 100, +) + +const DE_DefaultParameters = chain( + ParamsDict( + :SelectorRadius => 0, + :fdistr => + BlackBoxOptim.BimodalCauchy(0.65, 0.1, 1.0, 0.1, clampBelow0 = false), + :crdistr => + BlackBoxOptim.BimodalCauchy(0.1, 0.1, 0.95, 0.1, clampBelow0 = false), + ), + Population_DefaultParameters, +) + +struct DefaultDiffEvoFactory{P <: OptimizationProblem} <: DiffEvoFactory{P} + problem::P + initial_population::BlackBoxOptim.PopulationMatrix + population_size::Int + params::ParamsDictChain +end + +DefaultDiffEvoFactory(problem::OptimizationProblem; kwargs...) 
= + DefaultDiffEvoFactory(problem, BlackBoxOptim.kwargs2dict(kwargs)) + +function DefaultDiffEvoFactory(problem::OptimizationProblem, params::AbstractDict) + params = chain(DE_DefaultParameters, params) + DefaultDiffEvoFactory{typeof(problem)}( + problem, + params[:Population], + params[:PopulationSize], + params, + ) +end + +function BlackBoxOptim.bbsetup(factory::OptimizerFactory; kwargs...) + popmatrix = initial_population_matrix(factory, 1) + check_population(factory, popmatrix) + alg = generate_optimizer(factory, 1, problem(factory), popmatrix) + return generate_opt_controller(alg, factory, BlackBoxOptim.kwargs2dict(kwargs)) +end diff --git a/ext/SEMBlackBoxOptimExt/SEMBlackBoxOptimExt.jl b/ext/SEMBlackBoxOptimExt/SEMBlackBoxOptimExt.jl new file mode 100644 index 00000000..9cbdac4d --- /dev/null +++ b/ext/SEMBlackBoxOptimExt/SEMBlackBoxOptimExt.jl @@ -0,0 +1,13 @@ +module SEMBlackBoxOptimExt + +using StructuralEquationModels, BlackBoxOptim, Optimisers + +SEM = StructuralEquationModels + +export SemOptimizerBlackBoxOptim + +include("AdamMutation.jl") +include("DiffEvoFactory.jl") +include("SemOptimizerBlackBoxOptim.jl") + +end diff --git a/ext/SEMBlackBoxOptimExt/SemOptimizerBlackBoxOptim.jl b/ext/SEMBlackBoxOptimExt/SemOptimizerBlackBoxOptim.jl new file mode 100644 index 00000000..219f409e --- /dev/null +++ b/ext/SEMBlackBoxOptimExt/SemOptimizerBlackBoxOptim.jl @@ -0,0 +1,91 @@ +############################################################################################ +### connect to BlackBoxOptim.jl as backend +############################################################################################ + +""" +""" +struct SemOptimizerBlackBoxOptim <: SemOptimizer{:BlackBoxOptim} + lower_bound::Float64 # default lower bound + variance_lower_bound::Float64 # default variance lower bound + lower_bounds::Union{Dict{Symbol, Float64}, Nothing} + + upper_bound::Float64 # default upper bound + upper_bounds::Union{Dict{Symbol, Float64}, Nothing} +end + +function 
SemOptimizerBlackBoxOptim(; + lower_bound::Float64 = -1000.0, + lower_bounds::Union{AbstractDict{Symbol, Float64}, Nothing} = nothing, + variance_lower_bound::Float64 = 0.001, + upper_bound::Float64 = 1000.0, + upper_bounds::Union{AbstractDict{Symbol, Float64}, Nothing} = nothing, + kwargs..., +) + if variance_lower_bound < 0.0 + throw(ArgumentError("variance_lower_bound must be non-negative")) + end + return SemOptimizerBlackBoxOptim( + lower_bound, + variance_lower_bound, + lower_bounds, + upper_bound, + upper_bounds, + ) +end + +SEM.SemOptimizer{:BlackBoxOptim}(args...; kwargs...) = + SemOptimizerBlackBoxOptim(args...; kwargs...) + +SEM.algorithm(optimizer::SemOptimizerBlackBoxOptim) = optimizer.algorithm +SEM.options(optimizer::SemOptimizerBlackBoxOptim) = optimizer.options + +struct SemModelBlackBoxOptimProblem{M <: AbstractSem} <: + OptimizationProblem{ScalarFitnessScheme{true}} + model::M + fitness_scheme::ScalarFitnessScheme{true} + search_space::ContinuousRectSearchSpace +end + +function BlackBoxOptim.search_space(model::AbstractSem) + optim = model.optimizer::SemOptimizerBlackBoxOptim + return ContinuousRectSearchSpace( + SEM.lower_bounds( + optim.lower_bounds, + model, + default = optim.lower_bound, + variance_default = optim.variance_lower_bound, + ), + SEM.upper_bounds(optim.upper_bounds, model, default = optim.upper_bound), + ) +end + +function SemModelBlackBoxOptimProblem( + model::AbstractSem, + optimizer::SemOptimizerBlackBoxOptim, +) + SemModelBlackBoxOptimProblem(model, ScalarFitnessScheme{true}(), search_space(model)) +end + +BlackBoxOptim.fitness(params::AbstractVector, wrapper::SemModelBlackBoxOptimProblem) = + return SEM.evaluate!(0.0, nothing, nothing, wrapper.model, params) + +# sem_fit method +function SEM.sem_fit( + optimizer::SemOptimizerBlackBoxOptim, + model::AbstractSem, + start_params::AbstractVector; + Method::Symbol = :adaptive_de_rand_1_bin_with_gradient, + MaxSteps::Integer = 50000, + kwargs..., +) + problem = 
SemModelBlackBoxOptimProblem(model, optimizer) + if Method == :adaptive_de_rand_1_bin_with_gradient + # custom adaptive differential evolution with mutation that moves along the gradient + bbopt_factory = DefaultDiffEvoFactory(problem; kwargs...) + bbopt = bbsetup(bbopt_factory; MaxSteps, kwargs...) + else + bbopt = bbsetup(problem; Method, MaxSteps, kwargs...) + end + res = bboptimize(bbopt) + return SemFit(best_fitness(res), best_candidate(res), nothing, model, res) +end From 20125efbaa5457f570d1b270f677a8df335408de Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:40:37 -0700 Subject: [PATCH 66/74] non_posdef_return(v) -> non_posdef_objective(v) and move to abstract.jl --- src/loss/ML/FIML.jl | 2 +- src/loss/ML/ML.jl | 16 ++-------------- src/loss/abstract.jl | 9 +++++++++ 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/loss/ML/FIML.jl b/src/loss/ML/FIML.jl index 0a9f66b6..47e3c1ed 100644 --- a/src/loss/ML/FIML.jl +++ b/src/loss/ML/FIML.jl @@ -154,7 +154,7 @@ function evaluate!(objective, gradient, hessian, loss::SemFIML, params) Σ_chol = cholesky!(Symmetric(loss.imp_inv); check = false) if !isposdef(Σ_chol) - isnothing(objective) || (objective = non_posdef_return(params)) + isnothing(objective) || (objective = non_posdef_objective(params)) isnothing(gradient) || fill!(gradient, 1) return objective end diff --git a/src/loss/ML/ML.jl b/src/loss/ML/ML.jl index 58ecbe9f..8295d73f 100644 --- a/src/loss/ML/ML.jl +++ b/src/loss/ML/ML.jl @@ -121,7 +121,7 @@ function evaluate!( Σ_chol = cholesky!(Symmetric(Σ⁻¹); check = false) if !isposdef(Σ_chol) #@warn "∑⁻¹ is not positive definite" - isnothing(objective) || (objective = non_posdef_return(par)) + isnothing(objective) || (objective = non_posdef_objective(par)) isnothing(gradient) || fill!(gradient, 1) isnothing(hessian) || copyto!(hessian, I) return objective @@ -188,7 +188,7 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) Σ_chol = 
cholesky!(Symmetric(Σ⁻¹); check = false) if !isposdef(Σ_chol) #@warn "Σ⁻¹ is not positive definite" - isnothing(objective) || (objective = non_posdef_return(par)) + isnothing(objective) || (objective = non_posdef_objective(par)) isnothing(gradient) || fill!(gradient, 1) isnothing(hessian) || copyto!(hessian, I) return objective @@ -256,15 +256,3 @@ function evaluate!(objective, gradient, hessian, loss::SemML, par) return objective end - -############################################################################################ -### additional functions -############################################################################################ - -function non_posdef_return(par) - if eltype(par) <: AbstractFloat - return floatmax(eltype(par)) - else - return typemax(eltype(par)) - end -end diff --git a/src/loss/abstract.jl b/src/loss/abstract.jl index 56a3af58..8456f0ae 100644 --- a/src/loss/abstract.jl +++ b/src/loss/abstract.jl @@ -76,3 +76,12 @@ replace_observed(loss::AbstractLoss, ::Any; kwargs...) = loss # LossTerm: delegate to inner loss replace_observed(term::LossTerm, data; kwargs...) 
= LossTerm(replace_observed(loss(term), data; kwargs...), id(term), weight(term)) + +# returned objective if the implied Σ(par) matrix is not positive definite +function non_posdef_objective(par::AbstractVector) + if eltype(par) <: AbstractFloat + return floatmax(eltype(par)) + else + return typemax(eltype(par)) + end +end From 57fce3317956f457227e1c7933679afd75940470 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:42:29 -0700 Subject: [PATCH 67/74] MeanStruct(ram) --- src/frontend/specification/RAMMatrices.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/frontend/specification/RAMMatrices.jl b/src/frontend/specification/RAMMatrices.jl index 357ab408..35ea1225 100644 --- a/src/frontend/specification/RAMMatrices.jl +++ b/src/frontend/specification/RAMMatrices.jl @@ -12,6 +12,8 @@ struct RAMMatrices <: SemSpecification vars::Union{Vector{Symbol}, Nothing} end +MeanStruct(ram::RAMMatrices) = isnothing(ram.M) ? NoMeanStruct() : HasMeanStruct() + nparams(ram::RAMMatrices) = nparams(ram.A) nvars(ram::RAMMatrices) = size(ram.F, 2) nobserved_vars(ram::RAMMatrices) = size(ram.F, 1) From c993486aed25f69bbc0e43b4d7370224fdcaddb4 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:42:29 -0700 Subject: [PATCH 68/74] SemObserved: fix mean_and_cov() call it requires that data is Matrix --- src/observed/data.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/observed/data.jl b/src/observed/data.jl index 39eebe30..7b8a2baa 100644 --- a/src/observed/data.jl +++ b/src/observed/data.jl @@ -40,7 +40,7 @@ function SemObservedData(; ) data, obs_vars, _ = prepare_data(data, observed_vars, specification; observed_var_prefix) - obs_mean, obs_cov = mean_and_cov(data, 1) + obs_mean, obs_cov = mean_and_cov(convert(Matrix, data), 1) if any(ismissing, data) """ From 55219e482a56b39cabffa46ad8ce32a83bb96fd4 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:42:29 -0700 Subject: [PATCH 69/74] 
filter_used_params() not used currently --- src/additional_functions/params_array.jl | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/additional_functions/params_array.jl b/src/additional_functions/params_array.jl index 8cd89032..bbbd29d7 100644 --- a/src/additional_functions/params_array.jl +++ b/src/additional_functions/params_array.jl @@ -265,3 +265,30 @@ function params_range(arr::ParamsArray; allow_gaps::Bool = false) return first_i:last_i end + +""" + filter_used_params([linearindex_test], arr::ParamsArray) + +Filter the parameters that are referenced in the `arr`, and +the linear indices of the corresponding parameters pass the +optional `linearindex_test`. + +Returns the indices of the used parameters. +""" +function filter_used_params(linearindex_test, arr::ParamsArray) + inds = Vector{Int}() + for i in 1:nparams(arr) + par_range = SEM.param_occurences_range(arr, i) + isempty(par_range) && continue # not relevant + @inbounds for j in par_range + lin_ind = arr.linear_indices[j] + if isnothing(linearindex_test) || linearindex_test(lin_ind) + push!(inds, i) + break + end + end + end + return inds +end + +filter_used_params(arr::ParamsArray) = filter_used_params(nothing, arr) From 87796f744735092ee85a36142c0051f008b9682e Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:42:29 -0700 Subject: [PATCH 70/74] param_indices(spec) method --- src/frontend/specification/documentation.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/frontend/specification/documentation.jl b/src/frontend/specification/documentation.jl index a3a8d265..9aa6d932 100644 --- a/src/frontend/specification/documentation.jl +++ b/src/frontend/specification/documentation.jl @@ -37,6 +37,16 @@ Return the vector of parameter labels (in the same order as [`params`](@ref)). 
""" param_labels(spec::SemSpecification) = spec.param_labels +""" + param_indices(spec::SemSpecification, params::AbstractVector{Symbol}) -> Vector{Int} + +Convert parameter labels to their indices in the SEM specification. +""" +function param_indices(spec::SemSpecification, params::AbstractVector{Symbol}) + param_map = Dict(param => i for (i, param) in enumerate(SEM.params(spec))) + return [param_map[param] for param in params] +end + """ `ParameterTable`s contain the specification of a structural equation model. From 8679990fbc3e1c8d386ec4dd6e21b77527f48339 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:44:41 -0700 Subject: [PATCH 71/74] SemNorm: generalize SemRidge - allow l^alpha for any alpha - specific support for alpha=1 (SemLasso) and alpha=2 (SemRidge) - allow affine transform of parameters before regularization - SemSpec rather then SemImplied in ctor - convert calculation to evaluate!() --- src/StructuralEquationModels.jl | 8 +- src/loss/regularization/norm.jl | 190 +++++++++++++++++++++++++++++++ src/loss/regularization/ridge.jl | 87 -------------- 3 files changed, 195 insertions(+), 90 deletions(-) create mode 100644 src/loss/regularization/norm.jl delete mode 100644 src/loss/regularization/ridge.jl diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index 3ba19fba..b6ba8219 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -67,9 +67,9 @@ include("implied/empty.jl") include("loss/abstract.jl") include("loss/ML/ML.jl") include("loss/ML/FIML.jl") -include("loss/regularization/ridge.jl") include("loss/WLS/WLS.jl") include("loss/constant/constant.jl") +include("loss/regularization/norm.jl") # constructor include("frontend/specification/Sem.jl") include("frontend/specification/documentation.jl") @@ -125,9 +125,11 @@ export AbstractSem, SemML, SemFIML, em_mvn, - SemRidge, - SemConstant, SemWLS, + SemConstant, + SemLasso, + SemNorm, + SemRidge, loss, nsem_terms, 
sem_terms,
diff --git a/src/loss/regularization/norm.jl b/src/loss/regularization/norm.jl
new file mode 100644
index 00000000..82b05719
--- /dev/null
+++ b/src/loss/regularization/norm.jl
@@ -0,0 +1,190 @@
+# l^α regularization
+
+############################################################################################
+### Types
+############################################################################################
+"""
+    struct SemNorm{α, TP, TA, TB, TH} <: AbstractLoss{ExactHessian}
+
+Regularization term that provides *Lᵅ* regularization of SEM parameters.
+The term implements ``\\sum_{i=1\\ldots n} \\left| p_i \\right|^{\\alpha}``,
+where *p_i*, *i = 1..n* is the vector of selected SEM parameter values.
+For `α = 1` it implements the *LASSO* (`SemLasso` alias type), and for
+`α = 2` it implements the *Ridge* regularization (`SemRidge` alias type).
+The term also allows specifying an optional affine transform (*A × p + b*)
+to apply to the parameters before the regularization.
+
+# Constructors
+    SemNorm(A::SparseMatrixCSC, b::Union{AbstractVector, Nothing} = nothing)
+    SemNorm(spec::SemSpecification, params::AbstractVector,
+            [A::AbstractMatrix = nothing],
+            [b::AbstractVector = nothing];
+            α::Real)
+
+# Arguments
+- `spec`: SEM model specification.
+- `params::Vector`: optional IDs (Symbols) or indices of parameters to regularize.
+- `A`: optional transformation matrix that defines how to transform the vector of parameter values
+  before the regularization. If `params` is not specified, the transformation is applied
+  to the entire parameters vector.
+- `b`: optional vector of intercepts to add to the transformed parameters.
+- `α`: regularization parameter, any positive real number is supported + +# Examples +```julia +my_lasso = SemLasso(spec, [:λ₁, :λ₂, :ω₂₃]) +my_trans_ridge = SemRidge(spec, [:λ₁, :λ₂, :ω₂₃], [1.0 1.0 0.0; 0.0 0.0 1.0], [-2.0, 0.0]) +``` +""" +struct SemNorm{α, TP, TA, TB, TH} <: AbstractLoss{ExactHessian} + param_inds::TP # indices of parameters to regularize + A::TA # transformation/subsetting of the parameters + At::TA # Aᵀ + b::TB # optional transformed parameter intercepts + H_inds::Vector{Int} # non-zero linear indices of Hessian + H_vals::TH # non-zero values of Hessian +end + +const SemRidge{TP, TA, TB, TH} = SemNorm{2, TP, TA, TB, TH} +const SemLasso{TP, TA, TB, TH} = SemNorm{1, TP, TA, TB, TH} + +############################################################################ +### Constructors +############################################################################ + +function SemNorm{α}( + param_inds::Union{AbstractVector, Nothing} = nothing, + A::Union{AbstractMatrix, Nothing} = nothing, + b::Union{AbstractVector, Nothing} = nothing, +) where {α} + isnothing(A) || + isnothing(param_inds) || + size(A, 2) == length(param_inds) || + throw( + DimensionMismatch( + "The transformation matrix columns ($(size(A, 2))) should match " * + "the number of parameters to regularize ($(length(param_inds)))", + ), + ) + isnothing(b) || + (isnothing(A) && isnothing(param_inds)) || + (length(b) == (isnothing(A) ? length(param_inds) : size(A, 1))) || + throw( + DimensionMismatch( + "The intercept length ($(length(b))) should match the rows of " * + "the transformation matrix ($(isnothing(A) ? "not specified" : size(A, 1)))" * + " or the number of parameters to regularize ($(isnothing(param_inds) ? "not specified" : length(param_inds)))", + ), + ) + + At = !isnothing(A) ? convert(typeof(A), A') : nothing + H = !isnothing(A) ? 
α * At * A : nothing # FIXME + if isnothing(H) + H_inds = Vector{Int}() + H_v = nothing + else + H_i, H_j, H_v = findnz(H) + H_indmtx = LinearIndices(H) + H_inds = [H_indmtx[i, j] for (i, j) in zip(H_i, H_j)] + H_v = copy(H_v) + end + return SemNorm{α, typeof(param_inds), typeof(A), typeof(b), typeof(H_v)}( + param_inds, + A, + At, + b, + H_inds, + H_v, + ) +end + +function SemNorm{α}( + spec::SemSpecification, + params::AbstractVector, + A::Union{AbstractMatrix, Nothing} = nothing, + b::Union{AbstractVector, Nothing} = nothing, +) where {α} + param_inds = eltype(params) <: Symbol ? param_indices(spec, params) : params + + isnothing(A) || + size(A, 2) == length(param_inds) || + throw( + DimensionMismatch( + "The transformation matrix columns ($(size(A, 2))) should match " * + "the parameters to regularize ($(length(param_inds)))", + ), + ) + + sel_params_mtx = eachrow_to_col(Float64, param_inds, nparams(spec)) + if !isnothing(A) + if A isa SparseMatrixCSC + # for sparse matrices do parameters selection and multiplication in one step + A = convert(typeof(A), A * sel_params_mtx) + param_inds = nothing + end + else # if no matrix, just use selection matrix + A = sel_params_mtx + param_inds = nothing + end + return SemNorm{α}(param_inds, A, b) +end + +SemNorm(args...; α::Real) = SemNorm{α}(args...) + +nparams(f::SemNorm) = size(f.A, 2) + +############################################################################################ +### methods +############################################################################################ + +elnorm(_::Val{α}) where {α} = x -> abs(x)^α +elnorm(_::Val{1}) = abs +elnorm(_::Val{2}) = abs2 + +elnorm(_::SemNorm{α}) where {α} = elnorm(Val(α)) + +# not multiplied by α, handled by mul! 
+elnormgrad(_::Val{α}) where {α} = x -> abs(x)^(α - 1) * sign(x) +elnormgrad(_::Val{2}) = identity +elnormgrad(_::Val{1}) = sign + +elnormgrad(::SemNorm{α}) where {α} = elnormgrad(Val(α)) + +function evaluate!(objective, gradient, hessian, norm::SemNorm{α}, params) where {α} + if !isnothing(norm.param_inds) + params = params[norm.param_inds] + end + if !isnothing(norm.A) + trf_params = norm.A * params + end + if !isnothing(norm.b) + trf_params .+= norm.b + end + + obj = NaN + isnothing(objective) || (obj = sum(elnorm(norm), trf_params)) + + if !isnothing(gradient) + elgrad_trf_params = elnormgrad(norm).(trf_params) + if !isnothing(norm.param_inds) + mul!(params, norm.At, elgrad_trf_params, α, 0) + fill!(gradient, 0) + @inbounds gradient[norm.param_inds] .= params + else + mul!(gradient, norm.At, elgrad_trf_params, α, 0) + end + end + + if !isnothing(hessian) + fill!(hessian, 0) + if α === 1 + # do nothing, hessian is zero + elseif α === 2 + @inbounds hessian[norm.H_inds] .= norm.H_vals + else + error("Hessian not implemented for α ≠ 1, 2") + # TODO: Implement Hessian for other values of α + end + end + return obj +end diff --git a/src/loss/regularization/ridge.jl b/src/loss/regularization/ridge.jl deleted file mode 100644 index 813aff11..00000000 --- a/src/loss/regularization/ridge.jl +++ /dev/null @@ -1,87 +0,0 @@ -# (Ridge) regularization - -############################################################################################ -### Types -############################################################################################ -""" -Ridge regularization. - -# Constructor - - SemRidge(;α_ridge, which_ridge, nparams, parameter_type = Float64, implied = nothing, kwargs...) - -# Arguments -- `α_ridge`: hyperparameter for penalty term -- `which_ridge::Vector`: Vector of parameter labels (Symbols) or indices that indicate which parameters should be regularized. 
-- `nparams::Int`: number of parameters of the model -- `implied::SemImplied`: implied part of the model -- `parameter_type`: type of the parameters - -# Examples -```julia -my_ridge = SemRidge(;α_ridge = 0.02, which_ridge = [:λ₁, :λ₂, :ω₂₃], nparams = 30, implied = my_implied) -``` - -# Interfaces -Analytic gradients and hessians are available. -""" -struct SemRidge{P, W1, W2, GT, HT} <: AbstractLoss - hessianeval::ExactHessian - α::P - which::W1 - which_H::W2 - - gradient::GT - hessian::HT -end - -############################################################################ -### Constructors -############################################################################ - -function SemRidge(; - α_ridge, - which_ridge, - nparams, - parameter_type = Float64, - implied = nothing, - kwargs..., -) - if eltype(which_ridge) <: Symbol - if isnothing(implied) - throw( - ArgumentError( - "When referring to parameters by label, `implied = ...` has to be specified", - ), - ) - else - par2ind = param_indices(implied) - which_ridge = getindex.(Ref(par2ind), which_ridge) - end - end - which_H = [CartesianIndex(x, x) for x in which_ridge] - return SemRidge( - ExactHessian(), - α_ridge, - which_ridge, - which_H, - zeros(parameter_type, nparams), - zeros(parameter_type, nparams, nparams), - ) -end - -############################################################################################ -### methods -############################################################################################ - -objective(ridge::SemRidge, par) = @views ridge.α * sum(abs2, par[ridge.which]) - -function gradient(ridge::SemRidge, par) - @views ridge.gradient[ridge.which] .= (2 * ridge.α) * par[ridge.which] - return ridge.gradient -end - -function hessian(ridge::SemRidge, par) - @views @. 
ridge.hessian[ridge.which_H] .= 2 * ridge.α - return ridge.hessian -end From bd0b62dfc2ecc1746e08974bbc47563e39ff2450 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:45:17 -0700 Subject: [PATCH 72/74] add SemHinge --- src/StructuralEquationModels.jl | 3 ++ src/loss/regularization/hinge.jl | 80 ++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 src/loss/regularization/hinge.jl diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index b6ba8219..d74806ab 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -69,6 +69,8 @@ include("loss/ML/ML.jl") include("loss/ML/FIML.jl") include("loss/WLS/WLS.jl") include("loss/constant/constant.jl") +# regularization +include("loss/regularization/hinge.jl") include("loss/regularization/norm.jl") # constructor include("frontend/specification/Sem.jl") @@ -127,6 +129,7 @@ export AbstractSem, em_mvn, SemWLS, SemConstant, + SemHinge, SemLasso, SemNorm, SemRidge, diff --git a/src/loss/regularization/hinge.jl b/src/loss/regularization/hinge.jl new file mode 100644 index 00000000..b968068e --- /dev/null +++ b/src/loss/regularization/hinge.jl @@ -0,0 +1,80 @@ +""" + SemHinge{B} <: SemLossFunction{ExactHessian} + +Hinge regularization. + +Implements *hinge* a.k.a *rectified linear unit* (*ReLU*) loss function: +```math +f_{\\alpha, t}(x) = \\begin{cases} 0 & \\text{if}\\ x \\leq t \\\\ + \\alpha (x - t) & \\text{if } x > t. 
+ \\end{cases} +``` +""" +struct SemHinge{B} <: SemLossFunction{ExactHessian} + threshold::Float64 + α::Float64 + param_inds::Vector{Int} # indices of parameters to regularize +end + +""" + SemHinge(spec::SemSpecification; + bound = 'l', threshold = 0.0, α, params) + +# Arguments +- `spec`: SEM model specification +- `threshold`: hyperparameter for penalty term +- `α_hinge`: hyperparameter for penalty term +- `which_hinge::AbstractVector`: Vector of parameter labels (Symbols) + or indices that indicate which parameters should be regularized. + +# Examples +```julia +my_hinge = SemHinge(spec; bound = 'u', α = 0.02, params = [:λ₁, :λ₂, :ω₂₃]) +``` +""" +function SemHinge( + spec::SemSpecification; + bound::Char = 'l', + threshold::Number = 0.0, + α::Number, + params::AbstractVector, +) + bound ∈ ('l', 'u') || + throw(ArgumentError("bound must be either 'l' or 'u', $bound given")) + + param_inds = eltype(params) <: Symbol ? param_indices(spec, params) : params + return SemHinge{bound}(threshold, α, param_inds) +end + +(hinge::SemHinge{'l'})(val::Number) = max(val - hinge.threshold, 0.0) +(hinge::SemHinge{'u'})(val::Number) = max(hinge.threshold - val, 0.0) + +function evaluate!( + objective, + gradient, + hessian, + hinge::SemHinge{B}, + imply::SemImply, + model, + params, +) where {B} + obj = NaN + if !isnothing(objective) + @inbounds obj = hinge.α * sum(i -> hinge(params[i]), hinge.param_inds) + end + if !isnothing(gradient) + fill!(gradient, 0) + @inbounds for i in hinge.param_inds + par = params[i] + if B == 'l' && par > hinge.threshold + gradient[i] = hinge.α + elseif B == 'u' && par < hinge.threshold + gradient[i] = -hinge.α + end + end + end + if !isnothing(hessian) + fill!(hessian, 0) + end + return obj +end From 6b3e69b64ead86a7a1e1a5150e0a8861600588f2 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 4 May 2026 15:45:45 -0700 Subject: [PATCH 73/74] add SemSquaredHinge --- src/StructuralEquationModels.jl | 2 + src/loss/regularization/squared_hinge.jl 
| 93 ++++++++++++++++++++
 2 files changed, 95 insertions(+)
 create mode 100644 src/loss/regularization/squared_hinge.jl

diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl
index d74806ab..530305a0 100644
--- a/src/StructuralEquationModels.jl
+++ b/src/StructuralEquationModels.jl
@@ -72,6 +72,7 @@ include("loss/constant/constant.jl")
 # regularization
 include("loss/regularization/hinge.jl")
 include("loss/regularization/norm.jl")
+include("loss/regularization/squared_hinge.jl")
 # constructor
 include("frontend/specification/Sem.jl")
 include("frontend/specification/documentation.jl")
@@ -133,6 +134,7 @@ export AbstractSem,
     SemLasso,
     SemNorm,
     SemRidge,
+    SemSquaredHinge,
     loss,
     nsem_terms,
     sem_terms,
diff --git a/src/loss/regularization/squared_hinge.jl b/src/loss/regularization/squared_hinge.jl
new file mode 100644
index 00000000..c49b897b
--- /dev/null
+++ b/src/loss/regularization/squared_hinge.jl
@@ -0,0 +1,93 @@
+"""
+    SemSquaredHinge{B} <: SemLossFunction{ExactHessian}
+
+Squared hinge regularization.
+
+Implements the squared *hinge* a.k.a *rectified linear unit* (*ReLU*) loss function:
+```math
+f_{\\alpha, t}(x) = \\begin{cases} 0 & \\text{if}\\ x \\leq t \\\\
+    \\alpha (x - t)^2 & \\text{if } x > t.
+    \\end{cases}
+```
+"""
+struct SemSquaredHinge{B} <: SemLossFunction{ExactHessian}
+    threshold::Float64
+    α::Float64
+    param_inds::Vector{Int} # indices of parameters to regularize
+    H_diag_inds::Vector{Int} # indices of Hessian diagonal elements to regularize
+end
+
+"""
+    SemSquaredHinge(spec::SemSpecification;
+        bound = 'l', threshold = 0.0, α, params)
+
+# Arguments
+- `spec`: SEM model specification
+- `threshold`: hyperparameter for penalty term
+- `α`: hyperparameter for penalty term
+- `params::AbstractVector`: Vector of parameter labels (Symbols)
+  or indices that indicate which parameters should be regularized.
+
+# Examples
+```julia
+my_sqhinge = SemSquaredHinge(spec; bound = 'u', α = 0.02, params = [:λ₁, :λ₂, :ω₂₃])
+```
+"""
+function SemSquaredHinge(
+    spec::SemSpecification;
+    bound::Char = 'l',
+    threshold::Number = 0.0,
+    α::Number,
+    params::AbstractVector,
+)
+    bound ∈ ('l', 'u') ||
+        throw(ArgumentError("bound must be either 'l' or 'u', $bound given"))
+
+    param_inds = eltype(params) <: Symbol ? param_indices(spec, params) : params
+    H_linind = LinearIndices((nparams(spec), nparams(spec)))
+    return SemSquaredHinge{bound}(
+        threshold,
+        α,
+        param_inds,
+        [H_linind[i, i] for i in param_inds],
+    )
+end
+
+(sqhinge::SemSquaredHinge{'l'})(val::Number) = abs2(max(val - sqhinge.threshold, 0.0))
+(sqhinge::SemSquaredHinge{'u'})(val::Number) = abs2(max(sqhinge.threshold - val, 0.0))
+
+function evaluate!(
+    objective,
+    gradient,
+    hessian,
+    sqhinge::SemSquaredHinge{B},
+    imply::SemImply,
+    model,
+    params,
+) where {B}
+    obj = NaN
+    if !isnothing(objective)
+        @inbounds obj = sqhinge.α * sum(i -> sqhinge(params[i]), sqhinge.param_inds)
+    end
+    if !isnothing(gradient)
+        fill!(gradient, 0)
+        @inbounds for i in sqhinge.param_inds
+            par = params[i]
+            if (B == 'l' && par > sqhinge.threshold) ||
+               (B == 'u' && par < sqhinge.threshold)
+                gradient[i] = 2 * sqhinge.α * (par - sqhinge.threshold)
+            end
+        end
+    end
+    if !isnothing(hessian)
+        fill!(hessian, 0)
+        @inbounds for (par_i, H_i) in zip(sqhinge.param_inds, sqhinge.H_diag_inds)
+            par = params[par_i]
+            if (B == 'l' && par > sqhinge.threshold) ||
+               (B == 'u' && par < sqhinge.threshold)
+                hessian[H_i] = 2 * sqhinge.α
+            end
+        end
+    end
+    return obj
+end
From 7bc2c880fe374075437c212e831afcb32bc2b776 Mon Sep 17 00:00:00 2001
From: Alexey Stukalov
Date: Sat, 31 Aug 2024 15:30:13 -0700
Subject: [PATCH 74/74] quad.jl: optimized methods for X*A*X', X*X' etc

---
 src/StructuralEquationModels.jl  |   1 +
 src/additional_functions/quad.jl | 192 +++++++++++++++++++++++++++++++
 2 files changed, 193 insertions(+)
 create mode 100644 
src/additional_functions/quad.jl diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index 530305a0..6f27f757 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -35,6 +35,7 @@ include("objective_gradient_hessian.jl") include("additional_functions/commutation_matrix.jl") include("additional_functions/sparse_utils.jl") include("additional_functions/params_array.jl") +include("additional_functions/quad.jl") # fitted objects include("frontend/fit/SemFit.jl") diff --git a/src/additional_functions/quad.jl b/src/additional_functions/quad.jl new file mode 100644 index 00000000..11287607 --- /dev/null +++ b/src/additional_functions/quad.jl @@ -0,0 +1,192 @@ +_unwrap_symmetric(res::AbstractMatrix) = res +_unwrap_symmetric(res::Symmetric) = parent(res) + +# faster version of copytri!() that uses blascopy!() +function blascopytri!(A::StridedMatrix, uplo::AbstractChar) + n = LinearAlgebra.checksquare(A) + if uplo == 'L' + for (i, di) in enumerate(diagind(A)) + (i < n) || continue + BLAS.blascopy!( + n - i, + pointer(A, di + 1), + stride(A, 1), + pointer(A, di + size(A, 2)), + stride(A, 2), + ) + end + elseif uplo == 'U' + for (i, di) in enumerate(diagind(A)) + (i < n) || continue + BLAS.blascopy!( + n - i, + pointer(A, di + size(A, 2)), + stride(A, 2), + pointer(A, di + 1), + stride(A, 1), + ) + end + else + lazy"uplo argument must be 'U' (upper) or 'L' (lower), got $uplo" |> + ArgumentError |> + throw + end + return A +end + +# faster copytri!() that uses @simd, @inbounds and drops elementwise conjugation +@inline function fastcopytri!(A::AbstractMatrix, uplo::AbstractChar) + n = LinearAlgebra.checksquare(A) + if uplo == 'U' + @inbounds for i in axes(A, 1) + @simd for j in (i+1):n + A[j, i] = A[i, j] + end + end + elseif uplo == 'L' + @inbounds for i in axes(A, 1) + @simd for j in (i+1):n + A[i, j] = A[j, i] + end + end + else + lazy"uplo argument must be 'U' (upper) or 'L' (lower), got $uplo" |> + ArgumentError |> 
+ throw + end + A +end + +# faster version that drops issymmetric checks +# and switches to gemm mode for large matrices +@inline function syrk_wrapper!( + res::AbstractMatrix, + trans::Char, + X::Union{StridedMatrix, StridedVector}, + alpha::Real = 1, + beta::Real = 0; + check::Bool = true, + # big matrices are multiplied in gemm mode to avoid long copytri!() + mode::Symbol = size(res, 1) >= 1000 ? :gemm : :syrk, +) + T = eltype(X) + if mode == :syrk && (iszero(beta) || (!check || issymmetric(res))) + BLAS.syrk!('U', trans, T(alpha), X, T(beta), _unwrap_symmetric(res)) + fastcopytri!(_unwrap_symmetric(res), 'U') + elseif mode == :gemm # generic + LinearAlgebra.gemm_wrapper!( + _unwrap_symmetric(res), + trans, + trans == 'N' ? 'T' : 'N', + X, + X, + LinearAlgebra.MulAddMul(alpha, beta), + ) + else + throw(ArgumentError(lazy"mode must be :syrk or :gemm, $mode given")) + end + return res +end + +# calculate Xᵀ⋅X +Xt_X!(res::AbstractMatrix, X::AbstractMatrix, alpha::Real = 1, beta::Real = 0) = + mul!(_unwrap_symmetric(res), X', X, alpha, beta) + +Xt_X!(res::AbstractMatrix, X::StridedMatrix, alpha::Real = 1, beta::Real = 0; kwargs...) = + syrk_wrapper!(res, 'T', X, alpha, beta; kwargs...) + +X_Xt!( + res::AbstractMatrix, + X::Union{AbstractMatrix, AbstractVector}, + alpha::Real = 1, + beta::Real = 0, +) = mul!(_unwrap_symmetric(res), X, X', alpha, beta) + +X_Xt!( + res::AbstractMatrix, + X::Union{StridedMatrix, StridedVector}, + alpha::Real = 1, + beta::Real = 0; + kwargs..., +) = syrk_wrapper!(res, 'N', X, alpha, beta; kwargs...) 
+ +Xt_X(X::AbstractMatrix) = Xt_X!(Matrix{eltype(X)}(undef, size(X, 2), size(X, 2)), X) + +X_Xt(X::Union{AbstractMatrix, AbstractVector}) = + X_Xt!(Matrix{eltype(X)}(undef, size(X, 1), size(X, 1)), X) + +# calculate Xᵀ⋅A⋅X +# FIXME: use PDMats.jl when its sparse matrix support is refactored +# see https://github.com/JuliaStats/PDMats.jl/pull/188 +function Xt_A_X!( + res::AbstractMatrix, + A::AbstractMatrix, + X::AbstractMatrix, + alpha::Real = 1, + beta::Real = 0; + Xt_A_buf::Union{AbstractMatrix, Nothing} = nothing, +) + Xt_A = !isnothing(Xt_A_buf) ? mul!(Xt_A_buf, X', A) : X'A + return mul!(_unwrap_symmetric(res), Xt_A, X, alpha, beta) +end + +# special handling of symmetric to make sure it is the first argument in * +function Xt_A_X!( + res::AbstractMatrix, + A::Symmetric{<:Any, M}, + X::AbstractMatrix, + alpha::Real = 1, + beta::Real = 0; + Xt_A_buf::Union{AbstractMatrix, Nothing} = nothing, +) where {M <: StridedMatrix} + A_X = !isnothing(Xt_A_buf) ? mul!(reshape(Xt_A_buf, size(X)), A, X) : A*X + return mul!(_unwrap_symmetric(res), X', A_X, alpha, beta) +end + +Xt_A_X(A::AbstractMatrix, X::AbstractMatrix, alpha::Real = 1; kwargs...) = Xt_A_X!( + Matrix{promote_type(eltype(A), eltype(X))}(undef, size(X, 2), size(X, 2)), + A, + X, + alpha, + 0; + kwargs..., +) + +function X_A_Xt!( + res::AbstractMatrix, + A::AbstractMatrix, + X::AbstractMatrix, + alpha::Real = 1, + beta::Real = 0; + X_A_buf::Union{AbstractMatrix, Nothing} = nothing, +) + X_A = !isnothing(X_A_buf) ? 
mul!(X_A_buf, X, A) : X * A + return mul!(_unwrap_symmetric(res), X_A, X', alpha, beta) +end + +# special handling of symmetric to make sure it is the first argument in * +function X_A_Xt!( + res::AbstractMatrix, + A::Symmetric{<:Any, M}, + X::AbstractMatrix, + alpha::Real = 1, + beta::Real = 0; + X_A_buf::Union{AbstractMatrix, Nothing} = nothing, +) where {M <: StridedMatrix} + # FIXME in principle no need to unwrap A, but with symmetric A and transposed X + # julia's generic_matmatmul() falls back into non-BLAS implementation (looks like Julia's bug) + A_Xt = + !isnothing(X_A_buf) ? + mul!(reshape(X_A_buf, size(X, 2), size(X, 1)), _unwrap_symmetric(A), X') : + _unwrap_symmetric(A) * X' + return mul!(_unwrap_symmetric(res), X, A_Xt, alpha, beta) +end + +X_A_Xt(A::AbstractMatrix, X::AbstractMatrix, alpha::Real = 1; kwargs...) = X_A_Xt!( + Matrix{promote_type(eltype(A), eltype(X))}(undef, size(X, 1), size(X, 1)), + A, + X, + alpha, + 0; + kwargs..., +)