diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index 718e13b..c1535aa 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -444,8 +444,12 @@ impl CanonContext { } } - fn canonicalize_sum(&mut self, a: &Expr, _axis: Option) -> CanonExpr { + fn canonicalize_sum(&mut self, a: &Expr, axis: Option) -> CanonExpr { let ca = self.canonicalize_expr(a, false).as_linear().clone(); + if let Some(axis) = axis { + return self.canonicalize_sum_axis_lin(&ca, axis); + } + // Sum all elements: multiply by ones vector let size = ca.size(); let ones = DMatrix::from_element(1, size, 1.0); @@ -469,6 +473,50 @@ impl CanonContext { }) } + fn canonicalize_sum_axis_lin(&self, x: &LinExpr, axis: usize) -> CanonExpr { + if x.shape.ndim() <= 1 { + return CanonExpr::Linear(x.clone()); + } + + let rows = x.shape.rows(); + let cols = x.shape.cols(); + let (out_size, mut s_rows, mut s_cols, mut s_vals) = match axis { + 0 => (cols, Vec::new(), Vec::new(), Vec::new()), + 1 => (rows, Vec::new(), Vec::new(), Vec::new()), + _ => return CanonExpr::Linear(x.clone()), + }; + + for col in 0..cols { + for row in 0..rows { + let input_idx = row + col * rows; + let output_idx = if axis == 0 { col } else { row }; + s_rows.push(output_idx); + s_cols.push(input_idx); + s_vals.push(1.0); + } + } + + let selector = + crate::sparse::triplets_to_csc(out_size, x.size(), &s_rows, &s_cols, &s_vals); + + let mut new_coeffs = std::collections::HashMap::new(); + for (var_id, coeff) in &x.coeffs { + new_coeffs.insert(*var_id, crate::sparse::csc_matmul(&selector, coeff)); + } + + let flat_const = x + .constant + .clone() + .reshape_generic(nalgebra::Dyn(x.size()), nalgebra::Dyn(1)); + let new_const = csc_to_dense(&selector) * flat_const; + + CanonExpr::Linear(LinExpr { + coeffs: new_coeffs, + constant: new_const, + shape: Shape::vector(out_size), + }) + } + fn canonicalize_reshape(&mut self, a: &Expr, shape: &Shape) -> CanonExpr { let ca = self.canonicalize_expr(a, false).as_linear().clone(); // Reshape doesn't change the linear structure, just the shape interpretation @@ -942,36 +990,26 @@ impl CanonContext { }); } - // SOC reformulation: ||x||^2 <= t iff SOC(sqrt(t), x) - // Actually: introduce t, s, with t = s + 1, and SOC(s, x) - // Simpler: introduce t >= 0, with ||x||_2^2 <= t via rotated SOC - // Or: ||x||^2 = quad_over_lin(x, 1) let (_, t) = self.new_nonneg_aux_var(Shape::scalar()); - - // Rotated SOC: ||x||^2 <= 2 * t * 1 = 2t - // Standard form: || [2t - 1; 2x] ||_2 <= 2t + 1 - // Simplified: use SOC with proper reformulation - self.constraints.push(ConeConstraint::SOC { - t: t.clone(), - x: x.clone(), - }); + let one = LinExpr::scalar(1.0); + let soc_t = t.add(&one); + let soc_x = self.vstack_lin(&t.add(&one.neg()), &x.scale(2.0)); + self.constraints + .push(ConeConstraint::SOC { t: soc_t, x: soc_x }); CanonExpr::Linear(t) } fn canonicalize_quad_over_lin(&mut self, x: &Expr, y: &Expr) -> CanonExpr { - // ||x||_2^2 / y: Introduce t, rotated SOC constraint + // ||x||_2^2 / y: introduce t with ||[2x; t-y]||_2 <= t+y. let cx = self.canonicalize_expr(x, false).as_linear().clone(); - let _cy = self.canonicalize_expr(y, false).as_linear().clone(); + let cy = self.canonicalize_expr(y, false).as_linear().clone(); let (_, t) = self.new_nonneg_aux_var(Shape::scalar()); - // Rotated SOC: ||x||^2 <= t * y - // This requires proper rotated SOC support - // Simplified: add as SOC - self.constraints.push(ConeConstraint::SOC { - t: t.clone(), - x: cx, - }); + let soc_t = t.add(&cy); + let soc_x = self.vstack_lin(&t.add(&cy.neg()), &cx.scale(2.0)); + self.constraints + .push(ConeConstraint::SOC { t: soc_t, x: soc_x }); CanonExpr::Linear(t) } @@ -1125,11 +1163,6 @@ impl CanonContext { return CanonExpr::Linear(cx); } - if (p - 2.0).abs() < 1e-10 { - // x^2 use sum_squares approach (more efficient) - return self.canonicalize_sum_squares(&Expr::from(x), false); - } - // Create auxiliary variable t for the result let (t_var_id, t) = self.new_nonneg_aux_var(cx.shape.clone()); let _ = t_var_id; @@ -1247,8 +1280,8 @@ impl CanonContext { expr.constant[(idx, 0)] } else { // For matrix, compute flat index - let row = idx / expr.shape.cols(); - let col = idx % expr.shape.cols(); + let row = idx % expr.shape.rows(); + let col = idx / expr.shape.rows(); expr.constant[(row, col)] }; diff --git a/src/expr/expression.rs b/src/expr/expression.rs index aa21e04..1c8e93e 100644 --- a/src/expr/expression.rs +++ b/src/expr/expression.rs @@ -314,8 +314,10 @@ impl Expr { let dims = a.shape(); if dims.ndim() <= 1 { Shape::scalar() - } else { + } else if *axis == Some(0) { Shape::vector(dims.cols()) + } else { + Shape::vector(dims.rows()) } } else { Shape::scalar() @@ -570,4 +572,21 @@ mod tests { assert_eq!(c.shape(), Shape::matrix(3, 1)); assert!(c.is_constant()); } + + #[test] + fn test_sum_axis_shape() { + let x = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(2, 3), + name: None, + nonneg: false, + nonpos: false, + }); + + assert_eq!( + Expr::Sum(Arc::new(x.clone()), Some(0)).shape(), + Shape::vector(3) + ); + assert_eq!(Expr::Sum(Arc::new(x), Some(1)).shape(), Shape::vector(2)); + } } diff --git a/src/solver/clarabel.rs b/src/solver/clarabel.rs index 59615c1..6088243 100644 --- a/src/solver/clarabel.rs +++ b/src/solver/clarabel.rs @@ -10,6 +10,8 @@ use clarabel::solver::{ }; use super::stuffing::{ConeDims, StuffedProblem, VariableMap}; +use nalgebra::DMatrix; + use crate::expr::{Array, Evaluable, ExprId}; /// Solution status from the solver. @@ -345,7 +347,13 @@ fn unpack_primal(x: &[f64], var_map: &VariableMap) -> HashMap { for (&var_id, &(start, size)) in &var_map.id_to_col { let values: Vec = x[start..start + size].to_vec(); - let arr = if size == 1 { + let arr = if let Some(shape) = var_map.shape(var_id) { + if shape.is_scalar() { + Array::Scalar(values[0]) + } else { + Array::Dense(DMatrix::from_vec(shape.rows(), shape.cols(), values)) + } + } else if size == 1 { Array::Scalar(values[0]) } else { Array::from_vec(values) diff --git a/src/solver/stuffing.rs b/src/solver/stuffing.rs index d146d7c..2e7e6a7 100644 --- a/src/solver/stuffing.rs +++ b/src/solver/stuffing.rs @@ -42,6 +42,8 @@ impl ConeDims { pub struct VariableMap { /// Map from variable ID to (start_col, size). pub id_to_col: HashMap, + /// Original variable shapes, used when unpacking solver results. + pub id_to_shape: HashMap, /// Total number of optimization variables. pub total_vars: usize, } @@ -50,16 +52,19 @@ impl VariableMap { /// Create from a list of (variable_id, shape) pairs. pub fn from_vars(vars: &[(ExprId, Shape)]) -> Self { let mut id_to_col = HashMap::new(); + let mut id_to_shape = HashMap::new(); let mut offset = 0; for (var_id, shape) in vars { let size = shape.size(); id_to_col.insert(*var_id, (offset, size)); + id_to_shape.insert(*var_id, shape.clone()); offset += size; } VariableMap { id_to_col, + id_to_shape, total_vars: offset, } } @@ -68,6 +73,11 @@ impl VariableMap { pub fn get(&self, var_id: ExprId) -> Option<(usize, usize)> { self.id_to_col.get(&var_id).copied() } + + /// Get the original shape for a variable. + pub fn shape(&self, var_id: ExprId) -> Option<&Shape> { + self.id_to_shape.get(&var_id) + } } /// Stuffed problem ready for Clarabel. @@ -424,7 +434,7 @@ fn stuff_linear_expr( let size = expr.size(); for i in 0..expr.constant.nrows() { for j in 0..expr.constant.ncols() { - let idx = i * expr.constant.ncols() + j; + let idx = i + j * expr.constant.nrows(); if idx < size { let const_val = expr.constant[(i, j)]; // For Zero (negate=false): b = -constant diff --git a/tests/canon_tests.rs b/tests/canon_tests.rs new file mode 100644 index 0000000..0ddc596 --- /dev/null +++ b/tests/canon_tests.rs @@ -0,0 +1,254 @@ +use cvxrust::prelude::*; +use nalgebra::DMatrix; + +const TOL: f64 = 1e-4; + +/// sum_squares(Ax - b) must be canonicalized as ||r||^2, not ||r||_2. +/// The old bug used a plain SOC `||r||_2 <= t`, which minimized the L2 norm +/// instead of its square, causing solution.value to be off by a square root. + +#[test] +fn test_sum_squares_scalar_constrained() { + // minimize (x - 3)^2 s.t. x <= 1 + // True optimum: x* = 1, obj* = (1 - 3)^2 = 4 + // Old bug would report: |1 - 3| = 2 + let x = variable(()); + let residual = x.clone() - constant(3.0); + + let sol = Problem::minimize(sum_squares(&residual)) + .constraint(constraint!(x <= 1.0)) + .solve() + .unwrap(); + + let reported = sol.value.unwrap(); + let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap(); + let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap(); + + assert!( + (reported - 4.0).abs() < TOL, + "objective should be 4.0, got {reported}" + ); + assert!( + (reported - eval_sq).abs() < TOL, + "reported obj must equal sum_squares evaluated at solution" + ); + assert!((eval_norm - 2.0).abs() < TOL); + assert!( + (reported - eval_norm).abs() > 0.5, + "objective must not equal the L2 norm (old bug)" + ); +} + +#[test] +fn test_sum_squares_vector_constrained() { + // minimize ||x - [3, 4]||^2 s.t. x <= 2 + // True optimum: x* = [2, 2], obj* = (2-3)^2 + (2-4)^2 = 5 + // Old bug would report: sqrt(5) ~= 2.236 + let x = variable(2); + let residual = x.clone() - constant_vec(vec![3.0, 4.0]); + + let sol = Problem::minimize(sum_squares(&residual)) + .constraint(constraint!(x <= 2.0)) + .solve() + .unwrap(); + + let reported = sol.value.unwrap(); + let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap(); + let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap(); + + assert!( + (reported - 5.0).abs() < TOL, + "objective should be 5.0, got {reported}" + ); + assert!( + (reported - eval_sq).abs() < TOL, + "reported obj must equal sum_squares evaluated at solution" + ); + assert!((eval_norm - 5f64.sqrt()).abs() < TOL); + assert!( + (reported - eval_norm).abs() > 0.5, + "objective must not equal the L2 norm (old bug)" + ); +} + +#[test] +fn test_sum_squares_matmul_constrained() { + // minimize ||Ax - b||^2 s.t. x <= 1 + // A = [[1], [1]] (2x1), b = [2, 4], x scalar + // Unconstrained LS: x* = 3, obj* = (3-2)^2 + (3-4)^2 = 2 + // With x <= 1: x* = 1, residual = [-1, -3], obj* = 1 + 9 = 10 + // Old bug would report: sqrt(10) ~= 3.162 + let a = constant_matrix(vec![1.0, 1.0], 2, 1); + let b = constant_vec(vec![2.0, 4.0]); + let x = variable(()); + let residual = matmul(&a, &x) - &b; + + let sol = Problem::minimize(sum_squares(&residual)) + .constraint(constraint!(x <= 1.0)) + .solve() + .unwrap(); + + let reported = sol.value.unwrap(); + let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap(); + let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap(); + + assert!( + (reported - 10.0).abs() < TOL, + "objective should be 10.0, got {reported}" + ); + assert!( + (reported - eval_sq).abs() < TOL, + "reported obj must equal sum_squares evaluated at solution" + ); + assert!((eval_norm - 10f64.sqrt()).abs() < TOL); + assert!( + (reported - eval_norm).abs() > 0.5, + "objective must not equal the L2 norm (old bug)" + ); +} + +#[test] +fn test_sum_squares_constraint_uses_square_not_norm() { + // maximize x s.t. x^2 <= 4, x >= 0 + // True optimum: x* = 2. Old bug modeled |x| <= 4 and allowed x* = 4. + let x = variable(()); + + let sol = Problem::maximize(x.clone()) + .subject_to([sum_squares(&x).le(4.0), x.ge(0.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 2.0).abs() < TOL); + assert!((x.value(&sol).as_scalar().unwrap() - 2.0).abs() < TOL); +} + +#[test] +fn test_quad_over_lin_uses_denominator() { + let x = variable(()); + + let sol = Problem::maximize(x.clone()) + .subject_to([quad_over_lin(&x, &constant(0.5)).le(2.0), x.ge(0.0)]) + .solve() + .expect("problem should solve"); + + assert!((solution_value(&sol, &x) - 1.0).abs() < TOL); +} + +#[test] +fn test_quad_over_lin_uses_variable_denominator() { + let x = variable(()); + let y = variable(()); + + let sol = Problem::maximize(x.clone()) + .subject_to([quad_over_lin(&x, &y).le(2.0), x.ge(0.0), y.eq(0.5)]) + .solve() + .expect("problem should solve"); + + assert!((solution_value(&sol, &x) - 1.0).abs() < TOL); +} + +#[test] +fn test_quad_over_lin_objective_constant_denominator() { + let x = variable(()); + + let sol = Problem::minimize(quad_over_lin(&x, &constant(0.5))) + .subject_to([x.ge(2.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 8.0).abs() < TOL); + assert!((solution_value(&sol, &x) - 2.0).abs() < TOL); +} + +#[test] +fn test_quad_over_lin_objective_variable_denominator() { + let x = variable(()); + let y = variable(()); + + let sol = Problem::minimize(quad_over_lin(&x, &y)) + .subject_to([x.eq(2.0), y.eq(2.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 2.0).abs() < TOL); +} + +#[test] +fn test_power_two_is_elementwise() { + let x = variable(2); + + let sol = Problem::maximize(sum(&x)) + .subject_to([power(&x, 2.0).le(constant_vec(vec![1.0, 4.0])), x.ge(0.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 3.0).abs() < TOL); + + if let Array::Dense(x_vals) = x.value(&sol) { + assert!((x_vals[(0, 0)] - 1.0).abs() < TOL); + assert!((x_vals[(1, 0)] - 2.0).abs() < TOL); + } else { + panic!("expected dense vector solution"); + } +} + +#[test] +fn test_extract_element_matrix_constants_use_column_major_order() { + let x = variable((2, 2)); + let offset = constant_dmatrix(DMatrix::from_row_slice(2, 2, &[0.0, 10.0, 20.0, 30.0])); + let upper_sq = constant_dmatrix(DMatrix::from_row_slice(2, 2, &[1.0, 4.0, 25.0, 49.0])); + + let sol = Problem::maximize(sum(&x)) + .subject_to([power(&(&x + &offset), 2.0).le(upper_sq)]) + .solve() + .expect("problem should solve"); + + if let Array::Dense(x_vals) = x.value(&sol) { + assert!((x_vals[(0, 0)] - 1.0).abs() < TOL); + assert!((x_vals[(0, 1)] - -8.0).abs() < TOL); + assert!((x_vals[(1, 0)] - -15.0).abs() < TOL); + assert!((x_vals[(1, 1)] - -23.0).abs() < TOL); + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_sum_axis_constraints_are_not_total_sum() { + let x = variable((2, 2)); + + let sol = Problem::minimize(sum(&x)) + .subject_to([x.ge(0.0), sum_axis(&x, 0).eq(constant_vec(vec![1.0, 2.0]))]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + assert!(((m[(0, 0)] + m[(1, 0)]) - 1.0).abs() < TOL); + assert!(((m[(0, 1)] + m[(1, 1)]) - 2.0).abs() < TOL); + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_sum_axis_one_constraints_are_not_total_sum() { + let x = variable((2, 2)); + + let sol = Problem::minimize(sum(&x)) + .subject_to([x.ge(0.0), sum_axis(&x, 1).eq(constant_vec(vec![1.0, 2.0]))]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + assert!(((m[(0, 0)] + m[(0, 1)]) - 1.0).abs() < TOL); + assert!(((m[(1, 0)] + m[(1, 1)]) - 2.0).abs() < TOL); + } else { + panic!("expected dense matrix solution"); + } +} + +fn solution_value(sol: &Solution, expr: &Expr) -> f64 { + expr.value(sol).as_scalar().expect("expected scalar") +} diff --git a/tests/solver_tests.rs b/tests/solver_tests.rs new file mode 100644 index 0000000..9969555 --- /dev/null +++ b/tests/solver_tests.rs @@ -0,0 +1,25 @@ +use cvxrust::prelude::*; +use nalgebra::DMatrix; + +const TOL: f64 = 1e-4; + +#[test] +fn test_matrix_solution_recovery_preserves_shape_and_order() { + let x = variable((2, 2)); + let target = constant_dmatrix(DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0])); + + let sol = Problem::minimize(sum(&x)) + .subject_to([x.eq(target)]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + assert!((m[(0, 0)] - 1.0).abs() < TOL); + assert!((m[(0, 1)] - 2.0).abs() < TOL); + assert!((m[(1, 0)] - 3.0).abs() < TOL); + assert!((m[(1, 1)] - 4.0).abs() < TOL); + } else { + panic!("expected dense matrix solution"); + } +} diff --git a/tests/sum_squares_tests.rs b/tests/sum_squares_tests.rs deleted file mode 100644 index c1b59e1..0000000 --- a/tests/sum_squares_tests.rs +++ /dev/null @@ -1,107 +0,0 @@ -use cvxrust::prelude::*; - -/// sum_squares(Ax - b) must be canonicalized as ||r||^2, not ||r||_2. -/// The old bug used a plain SOC `||r||_2 <= t`, which minimized the L2 norm -/// instead of its square, causing solution.value to be off by a square root. - -#[test] -fn test_sum_squares_scalar_constrained() { - // minimize (x - 3)^2 s.t. x <= 1 - // True optimum: x* = 1, obj* = (1 - 3)^2 = 4 - // Old bug would report: |1 - 3| = 2 - let x = variable(()); - let residual = x.clone() - constant(3.0); - - let sol = Problem::minimize(sum_squares(&residual)) - .constraint(constraint!(x <= 1.0)) - .solve() - .unwrap(); - - let reported = sol.value.unwrap(); - let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap(); - let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap(); - - assert!( - (reported - 4.0).abs() < 1e-4, - "objective should be 4.0, got {reported}" - ); - assert!( - (reported - eval_sq).abs() < 1e-4, - "reported obj must equal sum_squares evaluated at solution" - ); - // Sanity check: the L2 norm (sqrt(4) = 2) is clearly different from the correct answer - assert!((eval_norm - 2.0).abs() < 1e-4); - assert!( - (reported - eval_norm).abs() > 0.5, - "objective must not equal the L2 norm (old bug)" - ); -} - -#[test] -fn test_sum_squares_vector_constrained() { - // minimize ||x - [3, 4]||^2 s.t. x <= 2 - // True optimum: x* = [2, 2], obj* = (2-3)^2 + (2-4)^2 = 5 - // Old bug would report: sqrt(5) ≈ 2.236 - let x = variable(2); - let residual = x.clone() - constant_vec(vec![3.0, 4.0]); - - let sol = Problem::minimize(sum_squares(&residual)) - .constraint(constraint!(x <= 2.0)) - .solve() - .unwrap(); - - let reported = sol.value.unwrap(); - let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap(); - let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap(); - - assert!( - (reported - 5.0).abs() < 1e-4, - "objective should be 5.0, got {reported}" - ); - assert!( - (reported - eval_sq).abs() < 1e-4, - "reported obj must equal sum_squares evaluated at solution" - ); - // Sanity check: the L2 norm (sqrt(5) ≈ 2.236) is clearly different - assert!((eval_norm - 5f64.sqrt()).abs() < 1e-4); - assert!( - (reported - eval_norm).abs() > 0.5, - "objective must not equal the L2 norm (old bug)" - ); -} - -#[test] -fn test_sum_squares_matmul_constrained() { - // minimize ||Ax - b||^2 s.t. x <= 1 - // A = [[1], [1]] (2x1), b = [2, 4], x scalar - // Unconstrained LS: x* = 3, obj* = (3-2)^2 + (3-4)^2 = 2 - // With x <= 1: x* = 1, residual = [-1, -3], obj* = 1 + 9 = 10 - // Old bug would report: sqrt(10) ≈ 3.162 - let a = constant_matrix(vec![1.0, 1.0], 2, 1); - let b = constant_vec(vec![2.0, 4.0]); - let x = variable(()); - let residual = matmul(&a, &x) - &b; - - let sol = Problem::minimize(sum_squares(&residual)) - .constraint(constraint!(x <= 1.0)) - .solve() - .unwrap(); - - let reported = sol.value.unwrap(); - let eval_sq = sum_squares(&residual).value(&sol).as_scalar().unwrap(); - let eval_norm = norm2(&residual).value(&sol).as_scalar().unwrap(); - - assert!( - (reported - 10.0).abs() < 1e-4, - "objective should be 10.0, got {reported}" - ); - assert!( - (reported - eval_sq).abs() < 1e-4, - "reported obj must equal sum_squares evaluated at solution" - ); - assert!((eval_norm - 10f64.sqrt()).abs() < 1e-4); - assert!( - (reported - eval_norm).abs() > 0.5, - "objective must not equal the L2 norm (old bug)" - ); -}