Skip to content
103 changes: 103 additions & 0 deletions py_qubed/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use ::qubed::Coordinates;
use ::qubed::Qube;
use ::qubed::select::SelectMode;
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyModule};
Expand Down Expand Up @@ -69,6 +71,96 @@ impl PyQube {
}
}

pub fn select(
&self,
request: Bound<'_, PyDict>,
mode: Option<String>,
_consume: Option<bool>,
) -> PyResult<PyQube> {
// Collect selection data with owned Strings and Coordinates
let mut selection_data: Vec<(String, Coordinates)> = Vec::new();

for (k, v) in request.iter() {
let key: String =
k.extract().map_err(|_| PyTypeError::new_err("select keys must be strings"))?;

let coords = if v.is_instance_of::<PyList>() {
let lst =
v.downcast::<PyList>().map_err(|e| PyTypeError::new_err(e.to_string()))?;
let joined = join_pylist_as_path(lst)?;
Coordinates::from_string(&joined)
} else {
// Convert any value to string representation (handles int, float, str)
let py_str = v.str()?;
let s: String = py_str.extract()?;
Coordinates::from_string(&s)
};
Comment thread
mathleur marked this conversation as resolved.

selection_data.push((key, coords));
}

let select_mode = match mode.as_deref() {
Some(m) if m.eq_ignore_ascii_case("prune") => SelectMode::Prune,
_ => SelectMode::Default,
};

// Convert to references for the select call
let pairs: Vec<(&str, Coordinates)> =
selection_data.iter().map(|(k, c)| (k.as_str(), c.clone())).collect();

match self.inner.select(&pairs, select_mode) {
Ok(q) => Ok(PyQube { inner: q }),
Err(e) => Err(PyTypeError::new_err(e)),
}
}

pub fn all_unique_dim_coords(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
let dim_coords = self.inner.all_unique_dim_coords();
let py_dict = PyDict::new(py);

for (dimension, coordinates) in dim_coords {
let coord_str = coordinates.to_string();
// Split on slash if present, otherwise treat as single value
let values: Vec<&str> = if coord_str.is_empty() {
vec![]
} else if coord_str.contains('/') {
coord_str.split('/').collect()
} else {
vec![&coord_str]
};

let py_list = PyList::empty(py);
for value in values {
py_list.append(value)?;
}

py_dict.set_item(dimension, py_list)?;
}

Ok(py_dict.into_any().unbind())
}

pub fn compress(&mut self) -> PyResult<()> {
self.inner.compress();
Ok(())
}

pub fn drop(&mut self, dims: &Bound<'_, PyList>) -> PyResult<()> {
let to_drop: Vec<String> = dims
.iter()
.map(|item| {
item.str()
.and_then(|s| s.extract::<String>())
.map_err(|_| PyTypeError::new_err("drop: dimension names must be strings"))
})
.collect::<PyResult<_>>()?;
self.inner.drop(to_drop).map_err(PyTypeError::new_err)
}

pub fn squeeze(&mut self) -> PyResult<()> {
self.inner.squeeze().map_err(PyTypeError::new_err)
}

pub fn append(&mut self, other: &Bound<'_, PyQube>) -> PyResult<()> {
let mut other_mut = other.borrow_mut();
self.inner.append(&mut other_mut.inner);
Expand Down Expand Up @@ -98,6 +190,17 @@ impl PyQube {
}
}

pub(crate) fn join_pylist_as_path(lst: &Bound<'_, PyList>) -> PyResult<String> {
let mut parts: Vec<String> = Vec::with_capacity(lst.len());
for item in lst.iter() {
// Convert any value to string representation (handles int, float, str)
let py_str = item.str()?;
let s: String = py_str.extract()?;
parts.push(s);
}
Ok(parts.join("/"))
}

#[pymodule]
#[pyo3(name = "qubed")]
fn py_qubed_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand Down
10 changes: 7 additions & 3 deletions py_qubed/tests/test_qubed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,13 @@ def test_to_from_arena_json_roundtrip() -> None:
import json

parsed = json.loads(arena_json)
assert isinstance(parsed, list)
# expect at least one node entry with dim and coords
assert any(isinstance(item, dict) and "dim" in item and "coords" in item for item in parsed)
assert isinstance(parsed, dict)
assert "qube" in parsed
assert "version" in parsed
# expect qube to be a list with node entries containing dim and coords
qube_list = parsed["qube"]
assert isinstance(qube_list, list)
assert any(isinstance(item, dict) and "dim" in item and "coords" in item for item in qube_list)

# Reconstruct and verify ascii equality
reconstructed = PyQube.from_arena_json(arena_json)
Expand Down
Loading
Loading