diff --git a/src/R2DH.jl b/src/R2DH.jl index d3b7aaf4..ea8c84ba 100644 --- a/src/R2DH.jl +++ b/src/R2DH.jl @@ -81,6 +81,12 @@ function R2DHSolver( ) end +function SolverCore.reset!(solver::R2DHSolver) + LinearOperators.reset!(solver.D) +end + +SolverCore.reset!(solver::R2DHSolver, model) = SolverCore.reset!(solver) + """ R2DH(reg_nlp; kwargs…) diff --git a/src/R2N.jl b/src/R2N.jl index 74c96f7b..8b4eb014 100644 --- a/src/R2N.jl +++ b/src/R2N.jl @@ -96,6 +96,14 @@ function R2NSolver( ) end +function SolverCore.reset!(solver::R2NSolver) + _reset_power_method!(solver.v0) + B = solver.subpb.model.B + isa(B, AbstractLinearOperator) && LinearOperators.reset!(B) +end + +SolverCore.reset!(solver::R2NSolver, model) = SolverCore.reset!(solver) + """ R2N(reg_nlp; kwargs…) diff --git a/src/TRDH_alg.jl b/src/TRDH_alg.jl index 032e72fc..75127487 100644 --- a/src/TRDH_alg.jl +++ b/src/TRDH_alg.jl @@ -90,6 +90,12 @@ function TRDHSolver( ) end +function SolverCore.reset!(solver::TRDHSolver) + LinearOperators.reset!(solver.D) +end + +SolverCore.reset!(solver::TRDHSolver, model) = SolverCore.reset!(solver) + """ TRDH(reg_nlp; kwargs…) TRDH(nlp, h, χ, options; kwargs...) diff --git a/src/TR_alg.jl b/src/TR_alg.jl index dbab9163..484a6a37 100644 --- a/src/TR_alg.jl +++ b/src/TR_alg.jl @@ -95,6 +95,15 @@ function TRSolver( ) end +function SolverCore.reset!(solver::TRSolver) + _reset_power_method!(solver.v0) + reset_data!(solver.subpb.model) + LinearOperators.reset!(solver.subpb.model) +end + +SolverCore.reset!(solver::TRSolver, model) = SolverCore.reset!(solver) + + """ TR(reg_nlp; kwargs…) TR(nlp, h, χ, options; kwargs...) diff --git a/src/utils.jl b/src/utils.jl index ee257198..15c53d9e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,11 @@ export RegularizedExecutionStats import SolverCore.GenericExecutionStats +# Reset the initial vector for the power method to [1/sqrt(n), -1/sqrt(n), 1/sqrt(n), ...]. +function _reset_power_method!(v0::AbstractVector) + v0 .= (isodd.(eachindex(v0)) .* 2 .- 1) ./ sqrt(length(v0)) +end + function power_method!(B::M, v₀::S, v₁::S, max_iter::Int = 1) where {M, S} @assert max_iter >= 1 "max_iter must be at least 1." mul!(v₁, B, v₀)