Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/ModelingToolkitTearing/src/reassemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ function get_linear_scc_linsol(state::TearingState, alg_eqs::Vector{Int},
N = length(b)
A = collect(A)::Matrix{Num}

if N <= analytical_linear_scc_limit && _check_allow_symbolic_parameter(
if N == 1 || N <= analytical_linear_scc_limit && _check_allow_symbolic_parameter(
state, A, allow_symbolic, allow_parameter
)
lu = try
Expand Down
150 changes: 128 additions & 22 deletions lib/ModelingToolkitTearing/src/tearingstate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,22 +382,6 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS
neqs = length(eqs)
symbolic_incidence = symbolic_incidence[eqs_to_retain]

if sort_eqs
# sort equations lexicographically to reduce simplification issues
# depending on order due to NP-completeness of tearing. Sort on a
# bounded prefix of the printed form: the full `string` of an equation
# is exponential in the sharing depth of hash-consed expressions, and
# ties on the first 4096 bytes keep their original (deterministic)
# relative order since the default sort is stable.
sortidxs = Base.sortperm(map(Base.Fix2(bounded_string, 4096), eqs))
eqs = eqs[sortidxs]
original_eqs = original_eqs[sortidxs]
symbolic_incidence = symbolic_incidence[sortidxs]
if !isempty(sources)
sources = sources[sortidxs]
end
end

dervaridxs = OrderedSet{Int}()
add_intermediate_derivatives!(fullvars, dervaridxs, addvar!)
# Handle shifts - find lowest shift and add intermediates with derivative edges
Expand All @@ -414,12 +398,27 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS
# build `var_to_diff`
var_to_diff = build_var_to_diff(fullvars, ndervars, var2idx, iv)

# build incidence graph
graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx)

state_priorities = build_state_priorities(sys, fullvars, var_to_diff)
canonical_ranks = build_canonical_ranks(fullvars)

if sort_eqs
sortkeys = Vector{EquationSortKeyT}(undef, length(eqs))
cache = Base.IdDict{SymbolicT, EquationSortKeyT}()
for (i, eq) in enumerate(eqs)
sortkeys[i] = get_equation_sort_key!(cache, eq, var2idx, canonical_ranks)
end
sortidxs = Base.sortperm(sortkeys)
eqs = eqs[sortidxs]
original_eqs = original_eqs[sortidxs]
symbolic_incidence = symbolic_incidence[sortidxs]
if !isempty(sources)
sources = sources[sortidxs]
end
end

# build incidence graph
graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx)

# Identify unknowns that do not appear in any equations and are thus not present in
# `fullvars`. The bindings and initial conditions for these variables should be removed.
for v in fullvars
Expand Down Expand Up @@ -449,6 +448,110 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS
typeof(sys)[], sources, nothing)
end

"""
Key used for sorting equations. Each element is a tuple of
(`canonical_rank` of variable, constant coefficient, exponent).
"""
const EquationSortKeyT = Vector{Tuple{Int, Float64, Float64}}

function get_equation_sort_key!(
cache::Base.IdDict{SymbolicT, EquationSortKeyT}, eq::Equation,
var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int}
)
return get_expression_sort_key!(cache, eq.rhs, var2idx, canonical_ranks)
end

function get_expression_sort_key!(
cache::Base.IdDict{SymbolicT, EquationSortKeyT}, expr::SymbolicT,
var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int}
)
val = get(cache, expr, nothing)
val === nothing || return val
val = __get_expression_sort_key!(cache, expr, var2idx, canonical_ranks)
cache[expr] = val
return val
end

function __get_expression_sort_key!(
cache::Base.IdDict{SymbolicT, EquationSortKeyT}, expr::SymbolicT,
var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int}
)
@match expr begin
# Use `0` as the rank for constants
BSImpl.Const(; val) => if val isa Real
return [(0, convert(Float64, val), 1.0)]
else
return eltype(EquationSortKeyT)[]
end
BSImpl.Sym(;) => begin
idx = get(var2idx, expr, nothing)
idx === nothing && return eltype(EquationSortKeyT)[]
return [(canonical_ranks[idx], 1.0, 1.0)]
end
BSImpl.AddMul(; coeff, dict, variant) => begin
result = eltype(EquationSortKeyT)[]
ks = collect(keys(dict))
sort!(ks; lt = SU.:(<ₑ))
if variant == SU.AddMulVariant.ADD
for k in ks
arg_k = get_expression_sort_key!(cache, k, var2idx, canonical_ranks)
v = dict[k]
if !(v isa Real)
append!(result, arg_k)
continue
end
v = convert(Float64, v)
for t in arg_k
push!(result, (t[1], t[2] * v, t[3]))
end
end
else
if coeff isa Real
cf = convert(Float64, coeff)
else
cf = 1.0
end
for k in ks
arg_k = get_expression_sort_key!(cache, k, var2idx, canonical_ranks)
v = dict[k]
if !(v isa Real)
append!(result, arg_k)
continue
end
v = convert(Float64, v)
for t in arg_k
push!(result, (t[1], t[2] ^ v * cf, t[3] + v))
end
end
end
return result
end
_ => begin
idx = get(var2idx, expr, nothing)
idx === nothing || return [(canonical_ranks[idx], 1.0, 1.0)]
f = operation(expr)
args = arguments(expr)
if f === (^)
base_key = get_expression_sort_key!(cache, args[1], var2idx, canonical_ranks)
@match args[2] begin
BSImpl.Const(; val) => if val isa Real
v = convert(Float64, val)
return map(k -> (k[1], k[2] ^ v, k[3] + v), base_key)
else
return vcat(base_key, get_expression_sort_key!(cache, args[2], var2idx, canonical_ranks))
end
_ => return vcat(base_key, get_expression_sort_key!(cache, args[2], var2idx, canonical_ranks))
end
return base_key
end
result = eltype(EquationSortKeyT)[]
for arg in args
append!(result, get_expression_sort_key!(cache, arg, var2idx, canonical_ranks))
end
return result
end
end
end
"""
$TYPEDSIGNATURES

Expand Down Expand Up @@ -527,7 +630,7 @@ function canonical_sort_key(v::SymbolicT)
end
_ => nothing
end
return (canonical_name(x), idxs, opsig)
return (opsig, canonical_name(x), idxs)
end

"""
Expand Down Expand Up @@ -584,6 +687,10 @@ function sort_fullvars(fullvars::Vector{SymbolicT}, dervaridxs::Vector{Int}, var
return fullvars, var_types
end
iv = iv::SymbolicT
# Canonicalize the within-group order by `canonical_sort_key`
ckcache = Dict{SymbolicT, Any}() # TODO: Fix Any type
ckey(v) = get!(() -> canonical_sort_key(v), ckcache, v)
dervaridxs = sort(dervaridxs; by = i -> ckey(fullvars[i]))
sorted_fullvars = OrderedSet{SymbolicT}(fullvars[dervaridxs])
var_to_old_var = Dict{SymbolicT, SymbolicT}(zip(fullvars, fullvars))
for dervaridx in dervaridxs
Expand All @@ -593,7 +700,7 @@ function sort_fullvars(fullvars::Vector{SymbolicT}, dervaridxs::Vector{Int}, var
push!(sorted_fullvars, diffvar)
end
end
for v in fullvars
for v in sort(fullvars; by = ckey)
if !(v in sorted_fullvars)
push!(sorted_fullvars, v)
end
Expand Down Expand Up @@ -819,4 +926,3 @@ function shift_discrete_system(ts::TearingState)
@set! ts.fullvars = fullvars
return ts
end

7 changes: 6 additions & 1 deletion src/partial_state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ function dummy_derivative_graph!(
diff_to_eq = invview(eq_to_diff)
diff_to_var = invview(var_to_diff)
invgraph = invview(graph)
cranks = has_canonical_ranks(structure) ? get_canonical_ranks(structure) : nothing
extended_sp = let state_priority = state_priority, var_to_diff = var_to_diff,
diff_to_var = diff_to_var

Expand Down Expand Up @@ -278,7 +279,11 @@ function dummy_derivative_graph!(
if state_priority !== nothing && isfirst
sp = extended_sp.(vars)
resize!(var_perm, length(sp))
sortperm!(var_perm, sp)
if cranks === nothing
sortperm!(var_perm, sp)
else
sortperm!(var_perm, collect(zip(sp, @view cranks[vars])))
end
permute!(vars, var_perm)
permute!(sp, var_perm)
# keep the Jacobian columns aligned with the permuted variable
Expand Down
11 changes: 10 additions & 1 deletion src/singularity_removal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,16 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
end
end
solvable_variables = findall(is_linear_variables)
var_priorities = has_state_priorities(structure) ? get_state_priorities(structure) : nothing
sp = has_state_priorities(structure) ? get_state_priorities(structure) : nothing
cr = has_canonical_ranks(structure) ? get_canonical_ranks(structure) : nothing
var_priorities = if cr === nothing
sp
elseif sp === nothing
cr
else
big = maximum(cr; init = 0) + 1
Int[sp[i] * big + cr[i] for i in eachindex(sp)]
end

bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities)

Expand Down
Loading