From f4abb555d3b675f8b9e76fed49f2ffa091608da0 Mon Sep 17 00:00:00 2001 From: Peter Duchovni Date: Tue, 20 May 2025 15:19:34 +1000 Subject: [PATCH 01/15] Add doodle-rec crate for recursion model experimentation --- Cargo.lock | 12 +- Cargo.toml | 2 +- experiments/doodle-rec/Cargo.toml | 12 + experiments/doodle-rec/src/lib.rs | 497 +++++++++++++++++++++++++++++ experiments/doodle-rec/src/main.rs | 3 + src/valuetype.rs | 4 +- 6 files changed, 525 insertions(+), 5 deletions(-) create mode 100644 experiments/doodle-rec/Cargo.toml create mode 100644 experiments/doodle-rec/src/lib.rs create mode 100644 experiments/doodle-rec/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index d34974e3..3d17e7fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,9 +59,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.79" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "autocfg" @@ -388,6 +388,14 @@ dependencies = [ "serde_json", ] +[[package]] +name = "doodle-rec" +version = "0.1.0" +dependencies = [ + "anyhow", + "doodle", +] + [[package]] name = "doodle_gencode" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index dc36b8c2..76067621 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = [".", "generated/", "doodle-formats/"] +members = [".", "generated/", "doodle-formats/", "experiments/doodle-rec"] [package] name = "doodle" diff --git a/experiments/doodle-rec/Cargo.toml b/experiments/doodle-rec/Cargo.toml new file mode 100644 index 00000000..1e98e2ae --- /dev/null +++ b/experiments/doodle-rec/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "doodle-rec" +version = "0.1.0" +edition = "2024" + +[lib] +path = "src/lib.rs" +bench = false + +[dependencies] +anyhow = "1.0.98" +doodle = { path = "../../" } diff --git a/experiments/doodle-rec/src/lib.rs b/experiments/doodle-rec/src/lib.rs new file mode 100644 index 00000000..227e1a70 --- /dev/null +++ b/experiments/doodle-rec/src/lib.rs @@ -0,0 +1,497 @@ +use std::{borrow::Cow, cell::OnceCell, collections::{BTreeMap, HashSet}, ops::Range, rc::Rc}; +use doodle::byte_set::ByteSet; +use anyhow::{anyhow, Result as AResult}; + + +pub type Label = Cow<'static, str>; + +pub type FormatId = usize; + +pub type RecId = usize; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct FormatRef(FormatId); + +impl FormatRef { + pub const fn get_level(&self) -> usize { + self.0 + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct Span { + pub start: Idx, + pub end: Idx, +} + +impl Span { + pub const fn new(start: Idx, end: Idx) -> Self { + Self { start, end } + } +} + +impl From> for Span { + fn from(value: Range) -> Self { + Self { start: value.start, end: value.end } + } +} + +#[derive(Debug, Clone)] +pub struct FormatDecl { + format: Format, + fmt_id: FormatId, + f_type: Rc>, + batch: Option>, +} + +impl FormatDecl { + pub fn solve_type(&self, module: &FormatModule) -> AResult<&FormatType> { + let mut visited = HashSet::new(); + self.solve_type_with(module, &mut visited) + } + + pub(crate) fn solve_type_with(&self, module: &FormatModule, visited: &mut HashSet) -> AResult<&FormatType> { + match self.f_type.get() { + None => { + visited.insert(self.fmt_id); + let f_type = self.format.infer_type(visited, module, self.batch)?; + let Ok(_) = self.f_type.set(f_type) else { unreachable!("synchronous TOCTOU!?") }; + Ok(self.f_type.get().unwrap()) + } + Some(f_type) => Ok(f_type), + } + } +} + + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BaseType { + Bool, + U8, + U16, + U32, + U64, + Char, +} + +impl BaseType { + pub fn is_numeric(&self) -> bool { + matches!(self, BaseType::U8 | BaseType::U16 | BaseType::U32 | BaseType::U64) + } +} + +#[derive(Debug, Clone)] +pub enum FormatType { + Any, + Void, + Unit, + Base(BaseType), + Ref(FormatId), + Shape(TypeShape), +} + +impl FormatType { + pub fn is_numeric(&self) -> bool { + match self { + FormatType::Base(base) => base.is_numeric(), + _ => false, + } + } + + fn unify(&self, other: &FormatType) -> AResult { + match (self, other) { + (FormatType::Any, _) => Ok(other.clone()), + (_, FormatType::Any) => Ok(self.clone()), + (FormatType::Void, _) | (_, FormatType::Void) => Ok(FormatType::Void), + (FormatType::Unit, FormatType::Unit) => Ok(FormatType::Unit), + (FormatType::Unit, FormatType::Shape(TypeShape::Tuple(empty))) + | (FormatType::Shape(TypeShape::Tuple(empty)), FormatType::Unit) if empty.is_empty() => Ok(FormatType::Unit), + (FormatType::Base(b1), FormatType::Base(b2)) if b1 == b2 => Ok(FormatType::Base(*b1)), + (FormatType::Shape(s1), FormatType::Shape(s2)) => { + let s = s1.unify(s2)?; + Ok(FormatType::Shape(s)) + } + _ => Err(anyhow!("cannot unify incompatible types: {self:?}, {other:?}")), + } + } +} + +#[derive(Debug, Clone)] +pub enum TypeShape { + Tuple(Vec), + Seq(Box), + Option(Box), + Union(BTreeMap), +} + +impl TypeShape { + fn unify(&self, other: &Self) -> AResult { + match (self, other) { + (TypeShape::Tuple(t1), TypeShape::Tuple(t2)) => { + if t1.len() != t2.len() { + return Err(anyhow!("cannot unify tuples of different arity: {t1:?}, {t2:?}")); + } + let mut unified = Vec::with_capacity(t1.len()); + for (t1, t2) in t1.iter().zip(t2.iter()) { + unified.push(t1.unify(t2)?); + } + Ok(TypeShape::Tuple(unified)) + } + (TypeShape::Seq(t1), TypeShape::Seq(t2)) => Ok(TypeShape::Seq(Box::new(t1.unify(t2)?))), + (TypeShape::Option(t1), TypeShape::Option(t2)) => Ok(TypeShape::Option(Box::new(t1.unify(t2)?))), + (TypeShape::Union(bs1), TypeShape::Union(bs2)) => { + let mut bs = BTreeMap::new(); + + let keys1 = bs1.keys().collect::>(); + let keys2 = bs2.keys().collect::>(); + + let all_keys = HashSet::union(&keys1, &keys2).cloned(); + + for key in all_keys.into_iter() { + match (bs1.get(key), bs2.get(key)) { + (Some(t1), Some(t2)) => { + let t = t1.unify(t2)?; + bs.insert(key.clone(), t); + } + (Some(t), None) | (None, Some(t)) => { + bs.insert(key.clone(), t.clone()); + } + (None, None) => unreachable!("key must appear in at least one operand"), + } + } + Ok(TypeShape::Union(bs)) + } + _ => Err(anyhow!("cannot unify shapes: {self:?}, {other:?}")), + } + } +} + +#[derive(Debug, Clone)] +pub enum Format { + // References to other formats + ItemVar(FormatId), + RecVar(RecId), + + // Basic Primitives + FailWith(Label), + EndOfInput, + Byte(ByteSet), + Compute(Box), + + // Union-Based + Variant(Label, Box), + Union(Vec), + + // Sequential + Repeat(Box), + Seq(Vec), + + // Higher-Order + Tuple(Vec), + Maybe(Box, Box), +} + +impl Format { + fn infer_type<'ctx>(&'ctx self, visited: &mut HashSet, module: &'ctx FormatModule, batch: Option>) -> AResult { + match self { + Format::ItemVar(level) => { + if visited.contains(level) { + Ok(FormatType::Ref(*level)) + } else { + let decl = &module.decls[*level]; + Ok(decl.solve_type_with(module, visited)?.clone()) + } + } + Format::RecVar(batch_ix) => { + match batch { + None => Err(anyhow!("Recursion without a batch")), + Some(range) => { + let level = range.start + batch_ix; + if level > range.end { + return Err(anyhow!("batch index out of range")) + } + if visited.contains(&level) { + Ok(FormatType::Ref(level)) + } else { + let decl = &module.decls[level]; + Ok(decl.solve_type_with(module, visited)?.clone()) + } + } + } + } + Format::FailWith(_msg) => Ok(FormatType::Void), + Format::EndOfInput => Ok(FormatType::Unit), + Format::Byte(bs) if bs.is_empty() => Ok(FormatType::Void), + Format::Byte(_) => Ok(FormatType::Base(BaseType::U8)), + Format::Compute(expr) => expr.as_ref().infer_type(), + Format::Variant(label, inner) => { + let inner_type = inner.infer_type(visited, module, batch)?; + Ok(FormatType::Shape(TypeShape::Union(BTreeMap::from([(label.clone(), inner_type)])))) + } + Format::Union(branches) => { + let mut t = FormatType::Any; + for f in branches { + t = t.unify(&f.infer_type(visited, module, batch)?)?; + } + Ok(t) + } + Format::Repeat(inner) => { + let t = inner.infer_type(visited, module, batch)?; + Ok(FormatType::Shape(TypeShape::Seq(Box::new(t)))) + } + Format::Seq(elts) => { + let mut elem_type = FormatType::Any; + for elt in elts { + elem_type = elem_type.unify(&elt.infer_type(visited, module, batch)?)?; + } + Ok(FormatType::Shape(TypeShape::Seq(Box::new(elem_type)))) + } + Format::Tuple(elts) => { + let mut types = Vec::with_capacity(elts.len()); + for elt in elts { + types.push(elt.infer_type(visited, module, batch)?); + } + Ok(FormatType::Shape(TypeShape::Tuple(types))) + } + Format::Maybe(expr, format) => match expr.infer_type()? { + FormatType::Base(BaseType::Bool) => { + let t = format.infer_type(visited, module, batch)?; + Ok(FormatType::Shape(TypeShape::Option(Box::new(t)))) + } + other => Err(anyhow!("maybe expression type was inferred to be non-bool: {other:?}")), + } + } + } + + +} + +#[derive(Debug, Clone)] +pub enum Expr { + // Primitive Values + U8(u8), + U16(u16), + U32(u32), + U64(u64), + Bool(bool), + + // Primitive Value Casts + AsChar(Box), + AsU8(Box), + AsU16(Box), + AsU32(Box), + AsU64(Box), + + // Higher-Order Exprs + Seq(Vec), + Tuple(Vec), + LiftMaybe(Option>), + Variant(Label, Box), + + // Operational + IntRel(IntRel, Box, Box), + Arith(Arith, Box, Box), + Unary(Unary, Box), +} + +impl Expr { + fn infer_type(&self) -> AResult { + match self { + Expr::U8(_) => Ok(FormatType::Base(BaseType::U8)), + Expr::U16(_) => Ok(FormatType::Base(BaseType::U16)), + Expr::U32(_) => Ok(FormatType::Base(BaseType::U32)), + Expr::U64(_) => Ok(FormatType::Base(BaseType::U64)), + Expr::Bool(_) => Ok(FormatType::Base(BaseType::Bool)), + Expr::AsChar(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::Char)) + } else { + Err(anyhow!("invalid char type conversion from {expr_type:?}")) + } + } + Expr::AsU8(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::U8)) + } else { + Err(anyhow!("invalid u8 type conversion from {expr_type:?}")) + } + } + Expr::AsU16(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::U16)) + } else { + Err(anyhow!("invalid u16 type conversion from {expr_type:?}")) + } + } + Expr::AsU32(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::U32)) + } else { + Err(anyhow!("invalid u32 type conversion from {expr_type:?}")) + } + } + Expr::AsU64(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::U64)) + } else { + Err(anyhow!("invalid u64 type conversion from {expr_type:?}")) + } + } + Expr::Seq(exprs) => { + let mut elem_type = FormatType::Any; + for expr in exprs { + elem_type = expr.infer_type()?.unify(&elem_type)?; + } + Ok(FormatType::Shape(TypeShape::Seq(Box::new(elem_type)))) + } + Expr::Tuple(exprs) => { + let mut elem_types = Vec::with_capacity(exprs.len()); + for expr in exprs { + elem_types.push(expr.infer_type()?); + } + Ok(FormatType::Shape(TypeShape::Tuple(elem_types))) + } + Expr::LiftMaybe(None) => Ok(FormatType::Shape(TypeShape::Option(Box::new(FormatType::Any)))), + Expr::LiftMaybe(Some(expr)) => { + let expr_type = expr.infer_type()?; + Ok(FormatType::Shape(TypeShape::Option(Box::new(expr_type)))) + } + Expr::Variant(lab, expr) => { + let expr_type = expr.infer_type()?; + Ok(FormatType::Shape(TypeShape::Union(BTreeMap::from([(lab.clone(), expr_type)])))) + } + Expr::IntRel(_rel, lhs, rhs) => { + let lhs_type = lhs.infer_type()?; + let rhs_type = rhs.infer_type()?; + match (lhs_type, rhs_type) { + (FormatType::Base(b1), FormatType::Base(b2)) if b1 == b2 && b1.is_numeric() => Ok(FormatType::Base(BaseType::Bool)), + (lhs_type, rhs_type) => Err(anyhow!("invalid integer relation between {lhs_type:?} and {rhs_type:?}")), + } + } + Expr::Arith(_arith, lhs, rhs) => { + let lhs_type = lhs.infer_type()?; + let rhs_type = rhs.infer_type()?; + match (lhs_type, rhs_type) { + (FormatType::Base(b1), FormatType::Base(b2)) if b1 == b2 && b1.is_numeric() => Ok(FormatType::Base(b1)), + (lhs_type, rhs_type) => Err(anyhow!("invalid arithmetic operation between {lhs_type:?} and {rhs_type:?}")), + } + } + Expr::Unary(Unary::BoolNot, expr) => { + let expr_type = expr.infer_type()?; + if matches!(expr_type, FormatType::Base(BaseType::Bool)) { + Ok(FormatType::Base(BaseType::Bool)) + } else { + Err(anyhow!("invalid bool-not on {expr_type:?}")) + } + } + } + } +} + +#[derive(Debug, Clone, Copy)] +pub enum IntRel { + Eq, Neq, + Gt, Gte, + Lt, Lte, +} + +#[derive(Debug, Clone, Copy)] +pub enum Arith { + Add, Sub, + Mul, Div, Rem, + Shl, Shr, + BitOr, BitAnd, +} + +#[derive(Debug, Clone, Copy)] +pub enum Unary { + BoolNot, +} + +pub struct FormatModule { + names: Vec