Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 161 additions & 5 deletions src/atoms/affine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use crate::expr::{Expr, Shape, constant};
use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant};

// ============================================================================
// Operator overloading for Expr
Expand Down Expand Up @@ -242,14 +242,91 @@ pub fn dot(a: &Expr, b: &Expr) -> Expr {

/// Index into an expression.
pub fn index(expr: &Expr, idx: usize) -> Expr {
use crate::expr::IndexSpec;
Expr::Index(Arc::new(expr.clone()), IndexSpec::element(vec![idx]))
select(expr, AxisIndex::Index(idx), AxisIndex::All)
}

/// Slice a range from an expression.
pub fn slice(expr: &Expr, start: usize, stop: usize) -> Expr {
use crate::expr::IndexSpec;
Expr::Index(Arc::new(expr.clone()), IndexSpec::range(start, stop))
select(expr, AxisIndex::Slice(start, stop), AxisIndex::All)
}

/// Index a matrix column.
pub fn indexc(expr: &Expr, idx: usize) -> Expr {
select(expr, AxisIndex::All, AxisIndex::Index(idx))
}

/// Slice a range of matrix columns.
pub fn slicec(expr: &Expr, start: usize, stop: usize) -> Expr {
select(expr, AxisIndex::All, AxisIndex::Slice(start, stop))
}

/// Select rows and columns from an expression.
pub fn select(expr: &Expr, rows: AxisIndex, cols: AxisIndex) -> Expr {
let shape = expr.shape();
assert!(!shape.is_scalar(), "cannot select from a scalar expression");

let spec = if shape.is_vector() {
assert!(
cols == AxisIndex::All,
"vector selection only supports AxisIndex::All for columns"
);
IndexSpec {
ranges: vec![axis_index_to_range(rows, shape.rows(), "first")],
drop_axes: vec![axis_index_drops_axis(rows)],
}
} else if shape.is_matrix() {
IndexSpec {
ranges: vec![
axis_index_to_range(rows, shape.rows(), "row"),
axis_index_to_range(cols, shape.cols(), "column"),
],
drop_axes: vec![axis_index_drops_axis(rows), axis_index_drops_axis(cols)],
}
} else {
panic!("select only supports vector and matrix expressions");
};
Expr::Index(Arc::new(expr.clone()), spec)
}

fn axis_index_to_range(
selector: AxisIndex,
axis_len: usize,
axis_name: &str,
) -> Option<(usize, usize, usize)> {
match selector {
AxisIndex::Index(idx) => {
assert!(
idx < axis_len,
"{} index {} out of bounds for axis with length {}",
axis_name,
idx,
axis_len
);
Some((idx, idx + 1, 1))
}
AxisIndex::Slice(start, stop) => {
assert!(
start <= stop,
"{} slice start {} must be less than or equal to stop {}",
axis_name,
start,
stop
);
assert!(
stop <= axis_len,
"{} slice stop {} out of bounds for axis with length {}",
axis_name,
stop,
axis_len
);
Some((start, stop, 1))
}
AxisIndex::All => None,
}
}

fn axis_index_drops_axis(selector: AxisIndex) -> bool {
matches!(selector, AxisIndex::Index(_))
}

/// Cumulative sum along an axis.
Expand Down Expand Up @@ -327,6 +404,85 @@ mod tests {
assert_eq!(b.shape(), Shape::vector(3));
}

#[test]
fn test_index_and_slice_shapes() {
let x = variable(10);
assert_eq!(index(&x, 1).shape(), Shape::scalar());
assert_eq!(slice(&x, 0, 5).shape(), Shape::vector(5));
assert_eq!(
select(&x, AxisIndex::Slice(1, 3), AxisIndex::All).shape(),
Shape::vector(2)
);

let x = variable((10, 10));
assert_eq!(index(&x, 1).shape(), Shape::vector(10));
assert_eq!(slice(&x, 0, 5).shape(), Shape::matrix(5, 10));
assert_eq!(slice(&x, 1, 2).shape(), Shape::matrix(1, 10));
assert_eq!(
select(&x, AxisIndex::All, AxisIndex::Index(2)).shape(),
Shape::vector(10)
);
assert_eq!(indexc(&x, 2).shape(), Shape::vector(10));
assert_eq!(slicec(&x, 1, 3).shape(), Shape::matrix(10, 2));
assert_eq!(
select(&x, AxisIndex::Index(1), AxisIndex::Index(2)).shape(),
Shape::scalar()
);
assert_eq!(
select(&x, AxisIndex::Slice(0, 5), AxisIndex::Slice(1, 3)).shape(),
Shape::matrix(5, 2)
);
}

#[test]
#[should_panic(expected = "row index 10 out of bounds")]
fn test_matrix_index_out_of_bounds_panics() {
let x = variable((10, 10));
let _ = index(&x, 10);
}

#[test]
#[should_panic(expected = "row slice stop 11 out of bounds")]
fn test_matrix_slice_stop_out_of_bounds_panics() {
let x = variable((10, 10));
let _ = slice(&x, 0, 11);
}

#[test]
#[should_panic(expected = "row slice start 5 must be less than or equal to stop 3")]
fn test_slice_start_after_stop_panics() {
let x = variable((10, 10));
let _ = slice(&x, 5, 3);
}

#[test]
#[should_panic(expected = "cannot select from a scalar expression")]
fn test_scalar_index_panics() {
let x = variable(());
let _ = index(&x, 0);
}

#[test]
#[should_panic(expected = "vector selection only supports AxisIndex::All for columns")]
fn test_vector_column_select_panics() {
let x = variable(10);
let _ = select(&x, AxisIndex::All, AxisIndex::Index(0));
}

#[test]
#[should_panic(expected = "column index 10 out of bounds")]
fn test_matrix_column_index_out_of_bounds_panics() {
let x = variable((10, 10));
let _ = select(&x, AxisIndex::All, AxisIndex::Index(10));
}

#[test]
#[should_panic(expected = "column slice stop 11 out of bounds")]
fn test_matrix_column_slice_stop_out_of_bounds_panics() {
let x = variable((10, 10));
let _ = select(&x, AxisIndex::All, AxisIndex::Slice(0, 11));
}

#[test]
fn test_vstack() {
let x = variable((2, 3));
Expand Down
4 changes: 2 additions & 2 deletions src/atoms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ pub mod nonlinear;

// Re-export affine operations
pub use affine::{
cumsum, diag, dot, flatten, hstack, index, matmul, reshape, slice, sum, sum_axis, trace,
transpose, vstack,
cumsum, diag, dot, flatten, hstack, index, indexc, matmul, reshape, select, slice, slicec, sum,
sum_axis, trace, transpose, vstack,
};

// Re-export nonlinear atoms
Expand Down
118 changes: 58 additions & 60 deletions src/canon/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,72 +531,70 @@ impl CanonContext {

fn canonicalize_index(&mut self, a: &Expr, spec: &IndexSpec) -> CanonExpr {
let ca = self.canonicalize_expr(a, false).as_linear().clone();

// For now, handle the common case: 1D vector indexing with a single range
// The LinExpr stores data in column-major (flattened) order
if spec.ranges.len() == 1 {
if let Some((start, stop, step)) = spec.ranges[0] {
// Simple range indexing on a vector
let input_size = ca.shape.size();

// Compute output indices
let output_indices: Vec<usize> = (start..stop).step_by(step).collect();
let output_size = output_indices.len();

// Build selection matrix: S[i, output_indices[i]] = 1
// S has shape (output_size, input_size)
let mut s_rows = Vec::new();
let mut s_cols = Vec::new();
let mut s_vals = Vec::new();
for (out_idx, &in_idx) in output_indices.iter().enumerate() {
if in_idx < input_size {
s_rows.push(out_idx);
s_cols.push(in_idx);
s_vals.push(1.0);
let new_shape = spec.output_shape(&ca.shape);
let input_rows = ca.shape.rows();
let input_cols = ca.shape.cols();

let input_indices = match spec.ranges.as_slice() {
[Some((start, stop, step))] => (*start..*stop).step_by(*step).collect(),
[None] => (0..ca.shape.size()).collect(),
[row_spec, col_spec] => {
let row_range: Vec<usize> = match row_spec {
Some((start, stop, step)) => (*start..*stop).step_by(*step).collect(),
None => (0..input_rows).collect(),
};
let col_range: Vec<usize> = match col_spec {
Some((start, stop, step)) => (*start..*stop).step_by(*step).collect(),
None => (0..input_cols).collect(),
};
let mut indices = Vec::with_capacity(row_range.len() * col_range.len());
for col in col_range {
for &row in &row_range {
indices.push(row + col * input_rows);
}
}
let s_mat = crate::sparse::triplets_to_csc(
output_size,
input_size,
&s_rows,
&s_cols,
&s_vals,
);

// Apply selection to each coefficient: new_A = S @ A
let mut new_coeffs = std::collections::HashMap::new();
for (var_id, coeff) in &ca.coeffs {
let new_coeff = crate::sparse::csc_matmul(&s_mat, coeff);
new_coeffs.insert(*var_id, new_coeff);
}

// Apply selection to constant: flatten, select rows, reshape
let const_flat: Vec<f64> = ca.constant.iter().cloned().collect();
let new_const_vals: Vec<f64> = output_indices
.iter()
.map(|&i| {
if i < const_flat.len() {
const_flat[i]
} else {
0.0
}
})
.collect();
let new_const = DMatrix::from_vec(output_size, 1, new_const_vals);

let new_shape = Shape::vector(output_size);
indices
}
_ => (0..ca.shape.size()).collect(),
};

return CanonExpr::Linear(LinExpr {
coeffs: new_coeffs,
constant: new_const,
shape: new_shape,
});
let output_size = input_indices.len();
let mut s_rows = Vec::new();
let mut s_cols = Vec::new();
let mut s_vals = Vec::new();
for (out_idx, &in_idx) in input_indices.iter().enumerate() {
if in_idx < ca.shape.size() {
s_rows.push(out_idx);
s_cols.push(in_idx);
s_vals.push(1.0);
}
}
let s_mat =
crate::sparse::triplets_to_csc(output_size, ca.shape.size(), &s_rows, &s_cols, &s_vals);

// Fallback for None ranges (take all) - return unchanged
// Note: 2D matrix indexing could be added in future versions
CanonExpr::Linear(ca)
let mut new_coeffs = std::collections::HashMap::new();
for (var_id, coeff) in &ca.coeffs {
new_coeffs.insert(*var_id, crate::sparse::csc_matmul(&s_mat, coeff));
}

let const_flat: Vec<f64> = ca.constant.iter().cloned().collect();
let new_const_vals: Vec<f64> = input_indices
.iter()
.map(|&i| {
if i < const_flat.len() {
const_flat[i]
} else {
0.0
}
})
.collect();
let new_const = DMatrix::from_vec(new_shape.rows(), new_shape.cols(), new_const_vals);

CanonExpr::Linear(LinExpr {
coeffs: new_coeffs,
constant: new_const,
shape: new_shape,
})
}

fn canonicalize_vstack(&mut self, exprs: &[Arc<Expr>]) -> CanonExpr {
Expand Down
Loading
Loading