Skip to content
91 changes: 62 additions & 29 deletions src/canon/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,12 @@ impl CanonContext {
}
}

fn canonicalize_sum(&mut self, a: &Expr, _axis: Option<usize>) -> CanonExpr {
fn canonicalize_sum(&mut self, a: &Expr, axis: Option<usize>) -> CanonExpr {
let ca = self.canonicalize_expr(a, false).as_linear().clone();
if let Some(axis) = axis {
return self.canonicalize_sum_axis_lin(&ca, axis);
}

// Sum all elements: multiply by ones vector
let size = ca.size();
let ones = DMatrix::from_element(1, size, 1.0);
Expand All @@ -469,6 +473,50 @@ impl CanonContext {
})
}

fn canonicalize_sum_axis_lin(&self, x: &LinExpr, axis: usize) -> CanonExpr {
if x.shape.ndim() <= 1 {
return CanonExpr::Linear(x.clone());
}

let rows = x.shape.rows();
let cols = x.shape.cols();
let (out_size, mut s_rows, mut s_cols, mut s_vals) = match axis {
0 => (cols, Vec::new(), Vec::new(), Vec::new()),
1 => (rows, Vec::new(), Vec::new(), Vec::new()),
_ => return CanonExpr::Linear(x.clone()),
};

for col in 0..cols {
for row in 0..rows {
let input_idx = row + col * rows;
let output_idx = if axis == 0 { col } else { row };
s_rows.push(output_idx);
s_cols.push(input_idx);
s_vals.push(1.0);
}
}

let selector =
crate::sparse::triplets_to_csc(out_size, x.size(), &s_rows, &s_cols, &s_vals);

let mut new_coeffs = std::collections::HashMap::new();
for (var_id, coeff) in &x.coeffs {
new_coeffs.insert(*var_id, crate::sparse::csc_matmul(&selector, coeff));
}

let flat_const = x
.constant
.clone()
.reshape_generic(nalgebra::Dyn(x.size()), nalgebra::Dyn(1));
let new_const = csc_to_dense(&selector) * flat_const;

CanonExpr::Linear(LinExpr {
coeffs: new_coeffs,
constant: new_const,
shape: Shape::vector(out_size),
})
}

fn canonicalize_reshape(&mut self, a: &Expr, shape: &Shape) -> CanonExpr {
let ca = self.canonicalize_expr(a, false).as_linear().clone();
// Reshape doesn't change the linear structure, just the shape interpretation
Expand Down Expand Up @@ -942,36 +990,26 @@ impl CanonContext {
});
}

// SOC reformulation: ||x||^2 <= t iff SOC(sqrt(t), x)
// Actually: introduce t, s, with t = s + 1, and SOC(s, x)
// Simpler: introduce t >= 0, with ||x||_2^2 <= t via rotated SOC
// Or: ||x||^2 = quad_over_lin(x, 1)
let (_, t) = self.new_nonneg_aux_var(Shape::scalar());

// Rotated SOC: ||x||^2 <= 2 * t * 1 = 2t
// Standard form: || [2t - 1; 2x] ||_2 <= 2t + 1
// Simplified: use SOC with proper reformulation
self.constraints.push(ConeConstraint::SOC {
t: t.clone(),
x: x.clone(),
});
let one = LinExpr::scalar(1.0);
let soc_t = t.add(&one);
let soc_x = self.vstack_lin(&t.add(&one.neg()), &x.scale(2.0));
self.constraints
.push(ConeConstraint::SOC { t: soc_t, x: soc_x });

CanonExpr::Linear(t)
}

fn canonicalize_quad_over_lin(&mut self, x: &Expr, y: &Expr) -> CanonExpr {
// ||x||_2^2 / y: Introduce t, rotated SOC constraint
// ||x||_2^2 / y: introduce t with ||[2x; t-y]||_2 <= t+y.
let cx = self.canonicalize_expr(x, false).as_linear().clone();
let _cy = self.canonicalize_expr(y, false).as_linear().clone();
let cy = self.canonicalize_expr(y, false).as_linear().clone();
let (_, t) = self.new_nonneg_aux_var(Shape::scalar());

// Rotated SOC: ||x||^2 <= t * y
// This requires proper rotated SOC support
// Simplified: add as SOC
self.constraints.push(ConeConstraint::SOC {
t: t.clone(),
x: cx,
});
let soc_t = t.add(&cy);
let soc_x = self.vstack_lin(&t.add(&cy.neg()), &cx.scale(2.0));
self.constraints
.push(ConeConstraint::SOC { t: soc_t, x: soc_x });

CanonExpr::Linear(t)
}
Expand Down Expand Up @@ -1125,11 +1163,6 @@ impl CanonContext {
return CanonExpr::Linear(cx);
}

if (p - 2.0).abs() < 1e-10 {
// x^2 use sum_squares approach (more efficient)
return self.canonicalize_sum_squares(&Expr::from(x), false);
}

// Create auxiliary variable t for the result
let (t_var_id, t) = self.new_nonneg_aux_var(cx.shape.clone());
let _ = t_var_id;
Expand Down Expand Up @@ -1247,8 +1280,8 @@ impl CanonContext {
expr.constant[(idx, 0)]
} else {
// For matrix, compute flat index
let row = idx / expr.shape.cols();
let col = idx % expr.shape.cols();
let row = idx % expr.shape.rows();
let col = idx / expr.shape.rows();
expr.constant[(row, col)]
};

Expand Down
21 changes: 20 additions & 1 deletion src/expr/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,10 @@ impl Expr {
let dims = a.shape();
if dims.ndim() <= 1 {
Shape::scalar()
} else {
} else if *axis == Some(0) {
Shape::vector(dims.cols())
} else {
Shape::vector(dims.rows())
}
} else {
Shape::scalar()
Expand Down Expand Up @@ -570,4 +572,21 @@ mod tests {
assert_eq!(c.shape(), Shape::matrix(3, 1));
assert!(c.is_constant());
}

#[test]
fn test_sum_axis_shape() {
let x = Expr::Variable(VariableData {
id: ExprId::new(),
shape: Shape::matrix(2, 3),
name: None,
nonneg: false,
nonpos: false,
});

assert_eq!(
Expr::Sum(Arc::new(x.clone()), Some(0)).shape(),
Shape::vector(3)
);
assert_eq!(Expr::Sum(Arc::new(x), Some(1)).shape(), Shape::vector(2));
}
}
10 changes: 9 additions & 1 deletion src/solver/clarabel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use clarabel::solver::{
};

use super::stuffing::{ConeDims, StuffedProblem, VariableMap};
use nalgebra::DMatrix;

use crate::expr::{Array, Evaluable, ExprId};

/// Solution status from the solver.
Expand Down Expand Up @@ -345,7 +347,13 @@ fn unpack_primal(x: &[f64], var_map: &VariableMap) -> HashMap<ExprId, Array> {

for (&var_id, &(start, size)) in &var_map.id_to_col {
let values: Vec<f64> = x[start..start + size].to_vec();
let arr = if size == 1 {
let arr = if let Some(shape) = var_map.shape(var_id) {
if shape.is_scalar() {
Array::Scalar(values[0])
} else {
Array::Dense(DMatrix::from_vec(shape.rows(), shape.cols(), values))
}
} else if size == 1 {
Array::Scalar(values[0])
} else {
Array::from_vec(values)
Expand Down
12 changes: 11 additions & 1 deletion src/solver/stuffing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ impl ConeDims {
pub struct VariableMap {
/// Map from variable ID to (start_col, size).
pub id_to_col: HashMap<ExprId, (usize, usize)>,
/// Original variable shapes, used when unpacking solver results.
pub id_to_shape: HashMap<ExprId, Shape>,
/// Total number of optimization variables.
pub total_vars: usize,
}
Expand All @@ -50,16 +52,19 @@ impl VariableMap {
/// Create from a list of (variable_id, shape) pairs.
pub fn from_vars(vars: &[(ExprId, Shape)]) -> Self {
let mut id_to_col = HashMap::new();
let mut id_to_shape = HashMap::new();
let mut offset = 0;

for (var_id, shape) in vars {
let size = shape.size();
id_to_col.insert(*var_id, (offset, size));
id_to_shape.insert(*var_id, shape.clone());
offset += size;
}

VariableMap {
id_to_col,
id_to_shape,
total_vars: offset,
}
}
Expand All @@ -68,6 +73,11 @@ impl VariableMap {
pub fn get(&self, var_id: ExprId) -> Option<(usize, usize)> {
self.id_to_col.get(&var_id).copied()
}

/// Get the original shape for a variable.
pub fn shape(&self, var_id: ExprId) -> Option<&Shape> {
self.id_to_shape.get(&var_id)
}
}

/// Stuffed problem ready for Clarabel.
Expand Down Expand Up @@ -424,7 +434,7 @@ fn stuff_linear_expr(
let size = expr.size();
for i in 0..expr.constant.nrows() {
for j in 0..expr.constant.ncols() {
let idx = i * expr.constant.ncols() + j;
let idx = i + j * expr.constant.nrows();
if idx < size {
let const_val = expr.constant[(i, j)];
// For Zero (negate=false): b = -constant
Expand Down
Loading
Loading