diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index a519659..a6439a0 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -9,7 +9,7 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; use std::sync::Arc; -use crate::expr::{Expr, Shape, constant}; +use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant}; // ============================================================================ // Operator overloading for Expr @@ -242,14 +242,91 @@ pub fn dot(a: &Expr, b: &Expr) -> Expr { /// Index into an expression. pub fn index(expr: &Expr, idx: usize) -> Expr { - use crate::expr::IndexSpec; - Expr::Index(Arc::new(expr.clone()), IndexSpec::element(vec![idx])) + select(expr, AxisIndex::Index(idx), AxisIndex::All) } /// Slice a range from an expression. pub fn slice(expr: &Expr, start: usize, stop: usize) -> Expr { - use crate::expr::IndexSpec; - Expr::Index(Arc::new(expr.clone()), IndexSpec::range(start, stop)) + select(expr, AxisIndex::Slice(start, stop), AxisIndex::All) +} + +/// Index a matrix column. +pub fn indexc(expr: &Expr, idx: usize) -> Expr { + select(expr, AxisIndex::All, AxisIndex::Index(idx)) +} + +/// Slice a range of matrix columns. +pub fn slicec(expr: &Expr, start: usize, stop: usize) -> Expr { + select(expr, AxisIndex::All, AxisIndex::Slice(start, stop)) +} + +/// Select rows and columns from an expression. +pub fn select(expr: &Expr, rows: AxisIndex, cols: AxisIndex) -> Expr { + let shape = expr.shape(); + assert!(!shape.is_scalar(), "cannot select from a scalar expression"); + + let spec = if shape.is_vector() { + assert!( + cols == AxisIndex::All, + "vector selection only supports AxisIndex::All for columns" + ); + IndexSpec { + ranges: vec![axis_index_to_range(rows, shape.rows(), "first")], + drop_axes: vec![axis_index_drops_axis(rows)], + } + } else if shape.is_matrix() { + IndexSpec { + ranges: vec![ + axis_index_to_range(rows, shape.rows(), "row"), + axis_index_to_range(cols, shape.cols(), "column"), + ], + drop_axes: vec![axis_index_drops_axis(rows), axis_index_drops_axis(cols)], + } + } else { + panic!("select only supports vector and matrix expressions"); + }; + Expr::Index(Arc::new(expr.clone()), spec) +} + +fn axis_index_to_range( + selector: AxisIndex, + axis_len: usize, + axis_name: &str, +) -> Option<(usize, usize, usize)> { + match selector { + AxisIndex::Index(idx) => { + assert!( + idx < axis_len, + "{} index {} out of bounds for axis with length {}", + axis_name, + idx, + axis_len + ); + Some((idx, idx + 1, 1)) + } + AxisIndex::Slice(start, stop) => { + assert!( + start <= stop, + "{} slice start {} must be less than or equal to stop {}", + axis_name, + start, + stop + ); + assert!( + stop <= axis_len, + "{} slice stop {} out of bounds for axis with length {}", + axis_name, + stop, + axis_len + ); + Some((start, stop, 1)) + } + AxisIndex::All => None, + } +} + +fn axis_index_drops_axis(selector: AxisIndex) -> bool { + matches!(selector, AxisIndex::Index(_)) } /// Cumulative sum along an axis. @@ -327,6 +404,85 @@ mod tests { assert_eq!(b.shape(), Shape::vector(3)); } + #[test] + fn test_index_and_slice_shapes() { + let x = variable(10); + assert_eq!(index(&x, 1).shape(), Shape::scalar()); + assert_eq!(slice(&x, 0, 5).shape(), Shape::vector(5)); + assert_eq!( + select(&x, AxisIndex::Slice(1, 3), AxisIndex::All).shape(), + Shape::vector(2) + ); + + let x = variable((10, 10)); + assert_eq!(index(&x, 1).shape(), Shape::vector(10)); + assert_eq!(slice(&x, 0, 5).shape(), Shape::matrix(5, 10)); + assert_eq!(slice(&x, 1, 2).shape(), Shape::matrix(1, 10)); + assert_eq!( + select(&x, AxisIndex::All, AxisIndex::Index(2)).shape(), + Shape::vector(10) + ); + assert_eq!(indexc(&x, 2).shape(), Shape::vector(10)); + assert_eq!(slicec(&x, 1, 3).shape(), Shape::matrix(10, 2)); + assert_eq!( + select(&x, AxisIndex::Index(1), AxisIndex::Index(2)).shape(), + Shape::scalar() + ); + assert_eq!( + select(&x, AxisIndex::Slice(0, 5), AxisIndex::Slice(1, 3)).shape(), + Shape::matrix(5, 2) + ); + } + + #[test] + #[should_panic(expected = "row index 10 out of bounds")] + fn test_matrix_index_out_of_bounds_panics() { + let x = variable((10, 10)); + let _ = index(&x, 10); + } + + #[test] + #[should_panic(expected = "row slice stop 11 out of bounds")] + fn test_matrix_slice_stop_out_of_bounds_panics() { + let x = variable((10, 10)); + let _ = slice(&x, 0, 11); + } + + #[test] + #[should_panic(expected = "row slice start 5 must be less than or equal to stop 3")] + fn test_slice_start_after_stop_panics() { + let x = variable((10, 10)); + let _ = slice(&x, 5, 3); + } + + #[test] + #[should_panic(expected = "cannot select from a scalar expression")] + fn test_scalar_index_panics() { + let x = variable(()); + let _ = index(&x, 0); + } + + #[test] + #[should_panic(expected = "vector selection only supports AxisIndex::All for columns")] + fn test_vector_column_select_panics() { + let x = variable(10); + let _ = select(&x, AxisIndex::All, AxisIndex::Index(0)); + } + + #[test] + #[should_panic(expected = "column index 10 out of bounds")] + fn test_matrix_column_index_out_of_bounds_panics() { + let x = variable((10, 10)); + let _ = select(&x, AxisIndex::All, AxisIndex::Index(10)); + } + + #[test] + #[should_panic(expected = "column slice stop 11 out of bounds")] + fn test_matrix_column_slice_stop_out_of_bounds_panics() { + let x = variable((10, 10)); + let _ = select(&x, AxisIndex::All, AxisIndex::Slice(0, 11)); + } + #[test] fn test_vstack() { let x = variable((2, 3)); diff --git a/src/atoms/mod.rs b/src/atoms/mod.rs index 8d086c0..4dac29c 100644 --- a/src/atoms/mod.rs +++ b/src/atoms/mod.rs @@ -10,8 +10,8 @@ pub mod nonlinear; // Re-export affine operations pub use affine::{ - cumsum, diag, dot, flatten, hstack, index, matmul, reshape, slice, sum, sum_axis, trace, - transpose, vstack, + cumsum, diag, dot, flatten, hstack, index, indexc, matmul, reshape, select, slice, slicec, sum, + sum_axis, trace, transpose, vstack, }; // Re-export nonlinear atoms diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index c1535aa..44f4563 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -531,72 +531,70 @@ impl CanonContext { fn canonicalize_index(&mut self, a: &Expr, spec: &IndexSpec) -> CanonExpr { let ca = self.canonicalize_expr(a, false).as_linear().clone(); - - // For now, handle the common case: 1D vector indexing with a single range - // The LinExpr stores data in column-major (flattened) order - if spec.ranges.len() == 1 { - if let Some((start, stop, step)) = spec.ranges[0] { - // Simple range indexing on a vector - let input_size = ca.shape.size(); - - // Compute output indices - let output_indices: Vec = (start..stop).step_by(step).collect(); - let output_size = output_indices.len(); - - // Build selection matrix: S[i, output_indices[i]] = 1 - // S has shape (output_size, input_size) - let mut s_rows = Vec::new(); - let mut s_cols = Vec::new(); - let mut s_vals = Vec::new(); - for (out_idx, &in_idx) in output_indices.iter().enumerate() { - if in_idx < input_size { - s_rows.push(out_idx); - s_cols.push(in_idx); - s_vals.push(1.0); + let new_shape = spec.output_shape(&ca.shape); + let input_rows = ca.shape.rows(); + let input_cols = ca.shape.cols(); + + let input_indices = match spec.ranges.as_slice() { + [Some((start, stop, step))] => (*start..*stop).step_by(*step).collect(), + [None] => (0..ca.shape.size()).collect(), + [row_spec, col_spec] => { + let row_range: Vec = match row_spec { + Some((start, stop, step)) => (*start..*stop).step_by(*step).collect(), + None => (0..input_rows).collect(), + }; + let col_range: Vec = match col_spec { + Some((start, stop, step)) => (*start..*stop).step_by(*step).collect(), + None => (0..input_cols).collect(), + }; + let mut indices = Vec::with_capacity(row_range.len() * col_range.len()); + for col in col_range { + for &row in &row_range { + indices.push(row + col * input_rows); } } - let s_mat = crate::sparse::triplets_to_csc( - output_size, - input_size, - &s_rows, - &s_cols, - &s_vals, - ); - - // Apply selection to each coefficient: new_A = S @ A - let mut new_coeffs = std::collections::HashMap::new(); - for (var_id, coeff) in &ca.coeffs { - let new_coeff = crate::sparse::csc_matmul(&s_mat, coeff); - new_coeffs.insert(*var_id, new_coeff); - } - - // Apply selection to constant: flatten, select rows, reshape - let const_flat: Vec = ca.constant.iter().cloned().collect(); - let new_const_vals: Vec = output_indices - .iter() - .map(|&i| { - if i < const_flat.len() { - const_flat[i] - } else { - 0.0 - } - }) - .collect(); - let new_const = DMatrix::from_vec(output_size, 1, new_const_vals); - - let new_shape = Shape::vector(output_size); + indices + } + _ => (0..ca.shape.size()).collect(), + }; - return CanonExpr::Linear(LinExpr { - coeffs: new_coeffs, - constant: new_const, - shape: new_shape, - }); + let output_size = input_indices.len(); + let mut s_rows = Vec::new(); + let mut s_cols = Vec::new(); + let mut s_vals = Vec::new(); + for (out_idx, &in_idx) in input_indices.iter().enumerate() { + if in_idx < ca.shape.size() { + s_rows.push(out_idx); + s_cols.push(in_idx); + s_vals.push(1.0); } } + let s_mat = + crate::sparse::triplets_to_csc(output_size, ca.shape.size(), &s_rows, &s_cols, &s_vals); - // Fallback for None ranges (take all) - return unchanged - // Note: 2D matrix indexing could be added in future versions - CanonExpr::Linear(ca) + let mut new_coeffs = std::collections::HashMap::new(); + for (var_id, coeff) in &ca.coeffs { + new_coeffs.insert(*var_id, crate::sparse::csc_matmul(&s_mat, coeff)); + } + + let const_flat: Vec = ca.constant.iter().cloned().collect(); + let new_const_vals: Vec = input_indices + .iter() + .map(|&i| { + if i < const_flat.len() { + const_flat[i] + } else { + 0.0 + } + }) + .collect(); + let new_const = DMatrix::from_vec(new_shape.rows(), new_shape.cols(), new_const_vals); + + CanonExpr::Linear(LinExpr { + coeffs: new_coeffs, + constant: new_const, + shape: new_shape, + }) } fn canonicalize_vstack(&mut self, exprs: &[Arc]) -> CanonExpr { diff --git a/src/expr/eval.rs b/src/expr/eval.rs index 998cdc4..dcf6f68 100644 --- a/src/expr/eval.rs +++ b/src/expr/eval.rs @@ -295,15 +295,17 @@ fn eval_index(a: Array, spec: &IndexSpec) -> crate::Result { match spec.ranges.as_slice() { [Some((start, stop, step))] => { + if *stop > nrows * ncols { + return Err(crate::CvxError::InvalidProblem( + "Index out of bounds in eval".into(), + )); + } let indices: Vec = (*start..*stop).step_by(*step).collect(); // For column vectors, index rows; otherwise index flat (column-major) let data: Vec = if ncols == 1 { indices.iter().map(|&i| m[(i, 0)]).collect() } else { - indices - .iter() - .map(|&i| *m.iter().nth(i).unwrap_or(&0.0)) - .collect() + indices.iter().map(|&i| m.as_slice()[i]).collect() }; if data.len() == 1 { Ok(Array::Scalar(data[0])) @@ -321,6 +323,21 @@ fn eval_index(a: Array, spec: &IndexSpec) -> crate::Result { Some((s, e, step)) => (*s..*e).step_by(*step).collect(), None => (0..ncols).collect(), }; + if row_range.iter().any(|&i| i >= nrows) || col_range.iter().any(|&i| i >= ncols) { + return Err(crate::CvxError::InvalidProblem( + "Index out of bounds in eval".into(), + )); + } + let row_drop = spec.drop_axes.first().copied().unwrap_or(false); + let col_drop = spec.drop_axes.get(1).copied().unwrap_or(false); + if row_drop && !col_drop { + let data: Vec = col_range.iter().map(|&j| m[(row_range[0], j)]).collect(); + return Ok(Array::Dense(DMatrix::from_vec(data.len(), 1, data))); + } + if !row_drop && col_drop { + let data: Vec = row_range.iter().map(|&i| m[(i, col_range[0])]).collect(); + return Ok(Array::Dense(DMatrix::from_vec(data.len(), 1, data))); + } let result = DMatrix::from_fn(row_range.len(), col_range.len(), |i, j| { m[(row_range[i], col_range[j])] }); @@ -531,6 +548,7 @@ fn eval_diag(a: Array) -> Array { #[cfg(test)] mod tests { use super::*; + use crate::atoms::{index, indexc, select, slice, slicec}; use crate::expr::{constant, variable}; use crate::prelude::*; use std::collections::HashMap; @@ -622,6 +640,142 @@ mod tests { assert!((v - 6.0).abs() < 1e-10); } + #[test] + fn test_eval_matrix_first_axis_slice() { + let x = constant_dmatrix(DMatrix::from_row_slice( + 3, + 4, + &[ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + )); + let (_, ctx) = make_var_scalar(0.0); + let value = slice(&x, 0, 2).value(&ctx); + + if let Array::Dense(m) = value { + assert_eq!(m.nrows(), 2); + assert_eq!(m.ncols(), 4); + assert_eq!(m[(0, 0)], 1.0); + assert_eq!(m[(1, 0)], 5.0); + assert_eq!(m[(0, 3)], 4.0); + assert_eq!(m[(1, 3)], 8.0); + } else { + panic!("expected dense matrix"); + } + } + + #[test] + fn test_eval_matrix_first_axis_index() { + let x = constant_dmatrix(DMatrix::from_row_slice( + 3, + 4, + &[ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + )); + let (_, ctx) = make_var_scalar(0.0); + let value = index(&x, 1).value(&ctx); + + if let Array::Dense(m) = value { + assert_eq!(m.nrows(), 4); + assert_eq!(m.ncols(), 1); + assert_eq!(m[(0, 0)], 5.0); + assert_eq!(m[(1, 0)], 6.0); + assert_eq!(m[(2, 0)], 7.0); + assert_eq!(m[(3, 0)], 8.0); + } else { + panic!("expected dense vector"); + } + } + + #[test] + fn test_eval_matrix_column_index() { + let x = constant_dmatrix(DMatrix::from_row_slice( + 3, + 4, + &[ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + )); + let (_, ctx) = make_var_scalar(0.0); + let value = indexc(&x, 2).value(&ctx); + + if let Array::Dense(m) = value { + assert_eq!(m.nrows(), 3); + assert_eq!(m.ncols(), 1); + assert_eq!(m[(0, 0)], 3.0); + assert_eq!(m[(1, 0)], 7.0); + assert_eq!(m[(2, 0)], 11.0); + } else { + panic!("expected dense vector"); + } + } + + #[test] + fn test_eval_matrix_scalar_select() { + let x = constant_dmatrix(DMatrix::from_row_slice( + 3, + 4, + &[ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + )); + let (_, ctx) = make_var_scalar(0.0); + let value = select(&x, AxisIndex::Index(1), AxisIndex::Index(2)).value(&ctx); + + assert_eq!(value.as_scalar(), Some(7.0)); + } + + #[test] + fn test_eval_matrix_rectangular_select() { + let x = constant_dmatrix(DMatrix::from_row_slice( + 3, + 4, + &[ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + )); + let (_, ctx) = make_var_scalar(0.0); + let value = select(&x, AxisIndex::Slice(0, 2), AxisIndex::Slice(1, 3)).value(&ctx); + + if let Array::Dense(m) = value { + assert_eq!(m.nrows(), 2); + assert_eq!(m.ncols(), 2); + assert_eq!(m[(0, 0)], 2.0); + assert_eq!(m[(1, 0)], 6.0); + assert_eq!(m[(0, 1)], 3.0); + assert_eq!(m[(1, 1)], 7.0); + } else { + panic!("expected dense matrix"); + } + } + + #[test] + fn test_eval_matrix_column_slice_alias() { + let x = constant_dmatrix(DMatrix::from_row_slice( + 3, + 4, + &[ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + )); + let (_, ctx) = make_var_scalar(0.0); + let value = slicec(&x, 1, 3).value(&ctx); + + if let Array::Dense(m) = value { + assert_eq!(m.nrows(), 3); + assert_eq!(m.ncols(), 2); + assert_eq!(m[(0, 0)], 2.0); + assert_eq!(m[(1, 0)], 6.0); + assert_eq!(m[(2, 0)], 10.0); + assert_eq!(m[(0, 1)], 3.0); + assert_eq!(m[(1, 1)], 7.0); + assert_eq!(m[(2, 1)], 11.0); + } else { + panic!("expected dense matrix"); + } + } + #[test] fn test_eval_sum_squares() { let (x, ctx) = make_var_vec(vec![1.0, 2.0, 3.0]); diff --git a/src/expr/expression.rs b/src/expr/expression.rs index 1c8e93e..8559796 100644 --- a/src/expr/expression.rs +++ b/src/expr/expression.rs @@ -187,19 +187,34 @@ impl ConstantData { } } +/// Selector for one axis of an indexing operation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AxisIndex { + /// Select a single index and drop this axis. + Index(usize), + /// Select a half-open range and preserve this axis. + Slice(usize, usize), + /// Select the entire axis and preserve it. + All, +} + /// Specification for indexing operations. #[derive(Debug, Clone)] pub struct IndexSpec { /// Ranges for each dimension: (start, stop, step). /// None means take the whole dimension. pub ranges: Vec>, + /// Whether each indexed dimension was selected by an integer index. + pub drop_axes: Vec, } impl IndexSpec { /// Create an index spec for a single element. pub fn element(indices: Vec) -> Self { + let len = indices.len(); IndexSpec { ranges: indices.into_iter().map(|i| Some((i, i + 1, 1))).collect(), + drop_axes: vec![true; len], } } @@ -207,12 +222,40 @@ impl IndexSpec { pub fn range(start: usize, stop: usize) -> Self { IndexSpec { ranges: vec![Some((start, stop, 1))], + drop_axes: vec![false], } } /// Create an index spec that takes everything. pub fn all() -> Self { - IndexSpec { ranges: vec![None] } + IndexSpec { + ranges: vec![None], + drop_axes: vec![false], + } + } + + /// Compute the output shape when this spec is applied to a base shape. + pub fn output_shape(&self, base: &Shape) -> Shape { + let mut new_dims = Vec::new(); + for i in 0..base.ndim() { + let range = self.ranges.get(i).copied().flatten(); + let drop_axis = self.drop_axes.get(i).copied().unwrap_or(false); + match range { + Some((start, stop, step)) => { + let size = (stop - start).div_ceil(step); + if !drop_axis { + new_dims.push(size); + } + } + None => new_dims.push(base.dims()[i]), + } + } + + if new_dims.is_empty() { + Shape::scalar() + } else { + Shape::from_dims(new_dims) + } } } @@ -324,31 +367,7 @@ impl Expr { } } Expr::Reshape(_, shape) => shape.clone(), - Expr::Index(a, spec) => { - // Simplified: compute resulting shape from index spec - let base = a.shape(); - let mut new_dims = Vec::new(); - for (i, r) in spec.ranges.iter().enumerate() { - match r { - Some((start, stop, step)) => { - let size = (stop - start).div_ceil(*step); - if size > 1 { - new_dims.push(size); - } - } - None => { - if i < base.ndim() { - new_dims.push(base.dims()[i]); - } - } - } - } - if new_dims.is_empty() { - Shape::scalar() - } else { - Shape::from_dims(new_dims) - } - } + Expr::Index(a, spec) => spec.output_shape(&a.shape()), Expr::VStack(exprs) => { if exprs.is_empty() { return Shape::scalar(); diff --git a/src/expr/mod.rs b/src/expr/mod.rs index ed76b2d..6d416e2 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -18,7 +18,7 @@ pub use constant::{ ones, zeros, }; pub use eval::Evaluable; -pub use expression::{Array, ConstantData, Expr, ExprId, IndexSpec, VariableData}; +pub use expression::{Array, AxisIndex, ConstantData, Expr, ExprId, IndexSpec, VariableData}; pub use shape::Shape; pub use variable::{ VariableBuilder, VariableExt, matrix_var, named_variable, nonneg_variable, nonpos_variable, diff --git a/src/lib.rs b/src/lib.rs index 80c895e..7ffc100 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,16 +79,17 @@ pub mod sparse; pub mod prelude { // Expression types pub use crate::expr::{ - Array, Evaluable, Expr, ExprId, IntoConstant, Shape, VariableBuilder, VariableExt, - constant, constant_dmatrix, constant_matrix, constant_sparse, constant_vec, eye, ones, - variable, zeros, + Array, AxisIndex, Evaluable, Expr, ExprId, IntoConstant, Shape, VariableBuilder, + VariableExt, constant, constant_dmatrix, constant_matrix, constant_sparse, constant_vec, + eye, ones, variable, zeros, }; // Atoms pub use crate::atoms::{ - abs, cumsum, diag, dot, entropy, exp, flatten, hstack, log, matmul, max2, maximum, min2, - minimum, neg_part, norm, norm_inf, norm1, norm2, pos, power, quad_form, quad_over_lin, - reshape, sqrt, sum, sum_axis, sum_squares, trace, transpose, try_norm, vstack, + abs, cumsum, diag, dot, entropy, exp, flatten, hstack, indexc, log, matmul, max2, maximum, + min2, minimum, neg_part, norm, norm_inf, norm1, norm2, pos, power, quad_form, + quad_over_lin, reshape, select, slicec, sqrt, sum, sum_axis, sum_squares, trace, transpose, + try_norm, vstack, }; // Constraints diff --git a/tests/solver_tests.rs b/tests/solver_tests.rs index 9969555..a6e5b48 100644 --- a/tests/solver_tests.rs +++ b/tests/solver_tests.rs @@ -1,3 +1,4 @@ +use cvxrust::atoms::{index, indexc, select, slice, slicec}; use cvxrust::prelude::*; use nalgebra::DMatrix; @@ -23,3 +24,143 @@ fn test_matrix_solution_recovery_preserves_shape_and_order() { panic!("expected dense matrix solution"); } } + +#[test] +fn test_matrix_first_axis_slice_constraints_affect_selected_rows() { + let x = variable((3, 4)); + + let sol = Problem::minimize(sum(&x)) + .subject_to([x.ge(0.0), slice(&x, 0, 2).ge(1.0)]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + for col in 0..4 { + assert!((m[(0, col)] - 1.0).abs() < TOL); + assert!((m[(1, col)] - 1.0).abs() < TOL); + assert!(m[(2, col)].abs() < TOL); + } + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_matrix_first_axis_index_constraints_affect_selected_row() { + let x = variable((3, 4)); + + let sol = Problem::minimize(sum(&x)) + .subject_to([x.ge(0.0), index(&x, 1).ge(2.0)]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + for col in 0..4 { + assert!(m[(0, col)].abs() < TOL); + assert!((m[(1, col)] - 2.0).abs() < TOL); + assert!(m[(2, col)].abs() < TOL); + } + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_matrix_column_select_constraints_affect_selected_column() { + let x = variable((3, 4)); + + let sol = Problem::minimize(sum(&x)) + .subject_to([x.ge(0.0), indexc(&x, 2).ge(3.0)]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + for row in 0..3 { + assert!(m[(row, 0)].abs() < TOL); + assert!(m[(row, 1)].abs() < TOL); + assert!((m[(row, 2)] - 3.0).abs() < TOL); + assert!(m[(row, 3)].abs() < TOL); + } + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_matrix_column_slice_alias_constraints_affect_selected_columns() { + let x = variable((3, 4)); + + let sol = Problem::minimize(sum(&x)) + .subject_to([x.ge(0.0), slicec(&x, 1, 3).ge(6.0)]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + for row in 0..3 { + assert!(m[(row, 0)].abs() < TOL); + assert!((m[(row, 1)] - 6.0).abs() < TOL); + assert!((m[(row, 2)] - 6.0).abs() < TOL); + assert!(m[(row, 3)].abs() < TOL); + } + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_matrix_scalar_select_constraints_affect_selected_entry() { + let x = variable((3, 4)); + + let sol = Problem::minimize(sum(&x)) + .subject_to([ + x.ge(0.0), + select(&x, AxisIndex::Index(1), AxisIndex::Index(2)).ge(4.0), + ]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + for row in 0..3 { + for col in 0..4 { + let expected = if row == 1 && col == 2 { 4.0 } else { 0.0 }; + assert!((m[(row, col)] - expected).abs() < TOL); + } + } + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_matrix_rectangular_select_constraints_affect_selected_block() { + let x = variable((3, 4)); + + let sol = Problem::minimize(sum(&x)) + .subject_to([ + x.ge(0.0), + select(&x, AxisIndex::Slice(0, 2), AxisIndex::Slice(1, 3)).ge(5.0), + ]) + .solve() + .expect("problem should solve"); + + let vals = x.value(&sol); + if let Array::Dense(m) = vals { + for row in 0..3 { + for col in 0..4 { + let expected = if row < 2 && (1..3).contains(&col) { + 5.0 + } else { + 0.0 + }; + assert!((m[(row, col)] - expected).abs() < TOL); + } + } + } else { + panic!("expected dense matrix solution"); + } +}