diff --git a/src/parse.rs b/src/parse.rs index a6eebd18..f74f4086 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -406,6 +406,142 @@ impl Match { impl_eq_hash!(Match; scrutinee, left, right); +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +enum ParsedMatchCtor { + Left, + Right, + Some, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +enum ParsedMatchPattern { + False, + True, + None, + Ctor(ParsedMatchCtor, Box), + Bind(Pattern, AliasedType), +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct ParsedMatchArm { + pattern: ParsedMatchPattern, + expression: Arc, + span: Span, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +enum MatchFamily { + Either, + Option, + Bool, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +enum PatternCtorTag { + Left, + Right, + Some, + None, + False, + True, +} + +impl PatternCtorTag { + fn family(self) -> MatchFamily { + match self { + PatternCtorTag::Left | PatternCtorTag::Right => MatchFamily::Either, + PatternCtorTag::Some | PatternCtorTag::None => MatchFamily::Option, + PatternCtorTag::False | PatternCtorTag::True => MatchFamily::Bool, + } + } + + fn is_payload(self) -> bool { + matches!( + self, + PatternCtorTag::Left | PatternCtorTag::Right | PatternCtorTag::Some + ) + } + + fn as_payload_tag(self) -> &'static str { + match self { + PatternCtorTag::Left => "left", + PatternCtorTag::Right => "right", + PatternCtorTag::Some => "some", + PatternCtorTag::None | PatternCtorTag::False | PatternCtorTag::True => { + unreachable!("Nullary constructors do not carry payloads") + } + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +enum PatternIr { + Wildcard(AliasedType), + Bind(Pattern, AliasedType), + Constructor(PatternCtorTag, Option>), +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +enum TypedPatternIr { + Wildcard(AliasedType), + Bind(Pattern, AliasedType), + Constructor { + family: MatchFamily, + tag: PatternCtorTag, + inner: Option>, + }, +} + +impl TypedPatternIr { + fn family(&self) -> Option { + match self { + TypedPatternIr::Constructor { family, .. } => Some(*family), + TypedPatternIr::Wildcard(_) | TypedPatternIr::Bind(_, _) => None, + } + } +} + +#[derive(Clone, Debug)] +struct TypedMatchArm { + pattern: TypedPatternIr, + expression: Arc, + span: Span, +} + +#[derive(Clone, Debug)] +enum DecisionTree { + Switch { + family: MatchFamily, + scrutinee: Arc, + branches: Vec<(PatternCtorTag, DecisionBranchKind)>, + span: Span, + }, +} + +#[derive(Clone, Debug)] +enum DecisionBranchKind { + Nullary { + expression: Arc, + }, + UnaryBind { + pattern: Pattern, + ty: AliasedType, + expression: Arc, + }, + UnarySwitch { + payload_pattern: Pattern, + payload_expr: Arc, + payload_ty: AliasedType, + tree: Box, + }, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +struct ConstructorSpec { + tag: PatternCtorTag, + has_payload: bool, +} + /// Arm of a match expression. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct MatchArm { @@ -454,7 +590,7 @@ impl MatchPattern { } } - /// Access the pattern and the type of a match pattern that binds a variables. + /// Access the pattern and the type of match pattern that binds a variables. pub fn as_typed_pattern(&self) -> Option<(&Pattern, &AliasedType)> { match self { MatchPattern::Left(i, ty) | MatchPattern::Right(i, ty) | MatchPattern::Some(i, ty) => { @@ -889,7 +1025,7 @@ pub trait ChumskyParse: Sized { type ParseError<'src> = extra::Err; /// This implementation only returns first encountered error. -impl ParseFromStr for A { +impl ParseFromStr for A { fn parse_from_str(s: &str) -> Result { let (tokens, mut lex_errs) = crate::lexer::lex(s); @@ -917,7 +1053,7 @@ impl ParseFromStr for A { } } -impl ParseFromStrWithErrors for A { +impl ParseFromStrWithErrors for A { fn parse_from_str_with_errors(s: &str, handler: &mut ErrorCollector) -> Option { let (tokens, lex_errs) = crate::lexer::lex(s); @@ -1625,41 +1761,42 @@ impl SingleExpression { } } -impl ChumskyParse for MatchPattern { +impl ChumskyParse for ParsedMatchPattern { fn parser<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone where I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, { - let wrapper = |name: &'static str, ctor: fn(Pattern, AliasedType) -> Self| { - select! { Token::Ident(i) if i == name => i } - .ignore_then(delimited_with_recovery( - Pattern::parser() - .then_ignore(just(Token::Colon)) - .then(AliasedType::parser()), - Token::LParen, - Token::RParen, - |_| { - ( - Pattern::Ignore, - AliasedType::alias(AliasName::from_str_unchecked("error")), - ) - }, - )) - .map(move |(id, ty)| ctor(id, ty)) - }; + recursive(|pat| { + let bind = Pattern::parser() + .then_ignore(just(Token::Colon)) + .then(AliasedType::parser()) + .map(|(pattern, ty)| ParsedMatchPattern::Bind(pattern, ty)); + + let ctor = move |name: &'static str, ctor: ParsedMatchCtor| { + select! { Token::Ident(i) if i == name => i } + .ignore_then(delimited_with_recovery( + pat.clone(), + Token::LParen, + Token::RParen, + |_| ParsedMatchPattern::Bind(Pattern::Ignore, error_aliased_type()), + )) + .map(move |inner| ParsedMatchPattern::Ctor(ctor, Box::new(inner))) + }; - choice(( - wrapper("Left", MatchPattern::Left), - wrapper("Right", MatchPattern::Right), - wrapper("Some", MatchPattern::Some), - select! { Token::Ident("None") => MatchPattern::None }, - select! { Token::Bool(true) => MatchPattern::True }, - select! { Token::Bool(false) => MatchPattern::False }, - )) + choice(( + ctor("Left", ParsedMatchCtor::Left), + ctor("Right", ParsedMatchCtor::Right), + ctor("Some", ParsedMatchCtor::Some), + select! { Token::Ident("None") => ParsedMatchPattern::None }, + select! { Token::Bool(true) => ParsedMatchPattern::True }, + select! { Token::Bool(false) => ParsedMatchPattern::False }, + bind, + )) + }) } } -impl MatchArm { +impl ParsedMatchArm { fn parser<'tokens, 'src: 'tokens, I, E>( expr: E, ) -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone @@ -1667,7 +1804,7 @@ impl MatchArm { I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, { - MatchPattern::parser() + ParsedMatchPattern::parser() .then_ignore(just(Token::FatArrow)) .then(expr.map(Arc::new)) .then(just(Token::Comma).or_not()) @@ -1686,11 +1823,759 @@ impl MatchArm { Self { pattern, expression, + span: e.span(), + } + }) + } +} + +fn lower_pattern_for_error(pattern: &ParsedMatchPattern) -> MatchPattern { + match pattern { + ParsedMatchPattern::Ctor(ParsedMatchCtor::Left, _) => { + MatchPattern::Left(Pattern::Ignore, error_aliased_type()) + } + ParsedMatchPattern::Ctor(ParsedMatchCtor::Right, _) => { + MatchPattern::Right(Pattern::Ignore, error_aliased_type()) + } + ParsedMatchPattern::Ctor(ParsedMatchCtor::Some, _) => { + MatchPattern::Some(Pattern::Ignore, error_aliased_type()) + } + ParsedMatchPattern::None => MatchPattern::None, + ParsedMatchPattern::False => MatchPattern::False, + ParsedMatchPattern::True => MatchPattern::True, + ParsedMatchPattern::Bind(_, ty) => MatchPattern::Left(Pattern::Ignore, ty.clone()), + } +} + +fn payload_identifier(span: Span, branch: &str) -> Identifier { + Identifier::from_str_unchecked(&format!( + "__match_payload_{}_{}_{}", + branch, span.start, span.end + )) +} + +fn payload_expression(identifier: &Identifier, span: Span) -> Arc { + Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Variable(identifier.clone()), + span, + }), + span, + }) +} + +fn match_arm(pattern: MatchPattern, expression: Arc) -> MatchArm { + MatchArm { + pattern, + expression, + } +} + +fn match_node(scrutinee: Arc, left: MatchArm, right: MatchArm, span: Span) -> Match { + Match { + scrutinee, + left, + right, + span, + } +} + +fn error_aliased_type() -> AliasedType { + AliasedType::alias(AliasName::from_str_unchecked("error")) +} + +fn wrap_nested_match( + scrutinee: Arc, + left: MatchArm, + right: MatchArm, + span: Span, +) -> Arc { + Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Match(match_node(scrutinee, left, right, span)), + span, + }), + span, + }) +} + +fn pattern_ctor_tag_from_parsed_ctor(ctor: ParsedMatchCtor) -> PatternCtorTag { + match ctor { + ParsedMatchCtor::Left => PatternCtorTag::Left, + ParsedMatchCtor::Right => PatternCtorTag::Right, + ParsedMatchCtor::Some => PatternCtorTag::Some, + } +} + +fn constructor_specs(family: MatchFamily) -> &'static [ConstructorSpec] { + match family { + MatchFamily::Either => &[ + ConstructorSpec { + tag: PatternCtorTag::Left, + has_payload: true, + }, + ConstructorSpec { + tag: PatternCtorTag::Right, + has_payload: true, + }, + ], + MatchFamily::Option => &[ + ConstructorSpec { + tag: PatternCtorTag::None, + has_payload: false, + }, + ConstructorSpec { + tag: PatternCtorTag::Some, + has_payload: true, + }, + ], + MatchFamily::Bool => &[ + ConstructorSpec { + tag: PatternCtorTag::False, + has_payload: false, + }, + ConstructorSpec { + tag: PatternCtorTag::True, + has_payload: false, + }, + ], + } +} + +fn pattern_ir_from_parsed(pattern: ParsedMatchPattern) -> PatternIr { + match pattern { + ParsedMatchPattern::False => PatternIr::Constructor(PatternCtorTag::False, None), + ParsedMatchPattern::True => PatternIr::Constructor(PatternCtorTag::True, None), + ParsedMatchPattern::None => PatternIr::Constructor(PatternCtorTag::None, None), + ParsedMatchPattern::Ctor(ctor, inner) => PatternIr::Constructor( + pattern_ctor_tag_from_parsed_ctor(ctor), + Some(Box::new(pattern_ir_from_parsed(*inner))), + ), + ParsedMatchPattern::Bind(Pattern::Ignore, ty) => PatternIr::Wildcard(ty), + ParsedMatchPattern::Bind(pattern, ty) => PatternIr::Bind(pattern, ty), + } +} + +fn root_family_from_arm(arm: &ParsedMatchArm) -> Result { + match &arm.pattern { + ParsedMatchPattern::Ctor(ctor, _) => Ok(pattern_ctor_tag_from_parsed_ctor(*ctor).family()), + ParsedMatchPattern::None => Ok(MatchFamily::Option), + ParsedMatchPattern::False | ParsedMatchPattern::True => Ok(MatchFamily::Bool), + ParsedMatchPattern::Bind(_, _) => Err(Error::Grammar( + "Top-level match arms must start with Left(..), Right(..), Some(..), None, false, or true" + .to_string(), + )), + } +} + +fn type_pattern_ir( + pattern: PatternIr, + top_level_family: MatchFamily, + nested: bool, +) -> Result { + match pattern { + PatternIr::Wildcard(ty) => { + if nested { + Ok(TypedPatternIr::Wildcard(ty)) + } else { + Err(Error::Grammar( + "Top-level match arms must start with Left(..), Right(..), Some(..), None, false, or true" + .to_string(), + )) + } + } + PatternIr::Bind(pattern, ty) => { + if nested { + Ok(TypedPatternIr::Bind(pattern, ty)) + } else { + Err(Error::Grammar( + "Top-level match arms must start with Left(..), Right(..), Some(..), None, false, or true" + .to_string(), + )) + } + } + PatternIr::Constructor(tag, inner) => { + let family = tag.family(); + if !nested && family != top_level_family { + return Err(Error::Grammar("Mixed match families".to_string())); + } + + if nested && !tag.is_payload() { + return Err(Error::Grammar( + "Unexpected terminal pattern under constructor".to_string(), + )); + } + + let typed_inner = if tag.is_payload() { + let inner = inner.ok_or_else(|| { + Error::Grammar("Payload constructor missing inner pattern".to_string()) + })?; + Some(Box::new(type_pattern_ir(*inner, top_level_family, true)?)) + } else { + None + }; + + Ok(TypedPatternIr::Constructor { + family, + tag, + inner: typed_inner, + }) + } + } +} + +fn typed_arms_from_parsed( + arms: &[ParsedMatchArm], +) -> Result<(MatchFamily, Vec), Error> { + let first = arms + .first() + .ok_or_else(|| Error::Grammar("Match requires at least two arms".to_string()))?; + let top_level_family = root_family_from_arm(first)?; + + let mut typed = Vec::with_capacity(arms.len()); + for arm in arms { + let family = root_family_from_arm(arm)?; + if family != top_level_family { + return Err(Error::IncompatibleMatchArms( + lower_pattern_for_error(&first.pattern), + lower_pattern_for_error(&arm.pattern), + )); + } + + typed.push(TypedMatchArm { + pattern: type_pattern_ir( + pattern_ir_from_parsed(arm.pattern.clone()), + top_level_family, + false, + )?, + expression: arm.expression.clone(), + span: arm.span, + }); + } + + Ok((top_level_family, typed)) +} + +fn is_typed_wildcard(pattern: &TypedPatternIr) -> bool { + matches!( + pattern, + TypedPatternIr::Wildcard(_) | TypedPatternIr::Bind(_, _) + ) +} + +fn infer_family_from_patterns(patterns: &[TypedPatternIr]) -> Result, Error> { + let mut family: Option = None; + + for pattern in patterns { + if let Some(current) = pattern.family() { + if let Some(existing) = family { + if existing != current { + return Err(Error::Grammar("Mixed match families".to_string())); + } + } else { + family = Some(current); + } + } + } + + Ok(family) +} + +fn specialize_pattern_for_tag( + pattern: &TypedPatternIr, + tag: PatternCtorTag, +) -> Option { + match pattern { + TypedPatternIr::Wildcard(ty) => Some(TypedPatternIr::Wildcard(ty.clone())), + TypedPatternIr::Bind(pattern, ty) => { + Some(TypedPatternIr::Bind(pattern.clone(), ty.clone())) + } + TypedPatternIr::Constructor { + tag: pattern_tag, + inner, + .. + } if *pattern_tag == tag => { + if tag.is_payload() { + inner.as_ref().map(|inner| (**inner).clone()) + } else { + Some(pattern.clone()) + } + } + TypedPatternIr::Constructor { .. } => None, + } +} + +fn branch_patterns_for_tag( + patterns: &[TypedPatternIr], + tag: PatternCtorTag, +) -> Vec { + patterns + .iter() + .filter_map(|pattern| specialize_pattern_for_tag(pattern, tag)) + .collect() +} + +fn typed_pattern_same_shape(a: &TypedPatternIr, b: &TypedPatternIr) -> bool { + match (a, b) { + (TypedPatternIr::Wildcard(ty_a), TypedPatternIr::Wildcard(ty_b)) => ty_a == ty_b, + (TypedPatternIr::Bind(_, ty_a), TypedPatternIr::Bind(_, ty_b)) => ty_a == ty_b, + (TypedPatternIr::Wildcard(ty_a), TypedPatternIr::Bind(_, ty_b)) + | (TypedPatternIr::Bind(_, ty_a), TypedPatternIr::Wildcard(ty_b)) => ty_a == ty_b, + ( + TypedPatternIr::Constructor { + family: family_a, + tag: tag_a, + inner: inner_a, + }, + TypedPatternIr::Constructor { + family: family_b, + tag: tag_b, + inner: inner_b, + }, + ) if family_a == family_b && tag_a == tag_b => match (inner_a, inner_b) { + (Some(inner_a), Some(inner_b)) => typed_pattern_same_shape(inner_a, inner_b), + (None, None) => true, + _ => false, + }, + _ => false, + } +} + +fn typed_pattern_covers(cover: &TypedPatternIr, candidate: &TypedPatternIr) -> bool { + match (cover, candidate) { + (TypedPatternIr::Wildcard(_), _) | (TypedPatternIr::Bind(_, _), _) => true, + ( + TypedPatternIr::Constructor { + family: family_a, + tag: tag_a, + inner: inner_a, + }, + TypedPatternIr::Constructor { + family: family_b, + tag: tag_b, + inner: inner_b, + }, + ) if family_a == family_b && tag_a == tag_b => match (inner_a, inner_b) { + (Some(inner_a), Some(inner_b)) => typed_pattern_covers(inner_a, inner_b), + (None, None) => true, + _ => false, + }, + _ => false, + } +} + +fn validate_duplicate_patterns(arms: &[TypedMatchArm]) -> Result<(), Error> { + for i in 0..arms.len() { + for j in (i + 1)..arms.len() { + if typed_pattern_same_shape(&arms[i].pattern, &arms[j].pattern) { + return Err(Error::Grammar("Duplicate match arm".to_string()) + .with_span(arms[j].span) + .into()); + } + } + } + + Ok(()) +} + +fn validate_constructor_overlap_patterns( + family: MatchFamily, + patterns: &[TypedPatternIr], + spans: &[Span], +) -> Result<(), Error> { + for spec in constructor_specs(family) { + if !spec.has_payload { + continue; + } + + let mut child_patterns = Vec::new(); + let mut child_spans = Vec::new(); + let mut saw_nested = false; + + for (pattern, span) in patterns.iter().zip(spans.iter().copied()) { + if let Some(child) = specialize_pattern_for_tag(pattern, spec.tag) { + if is_typed_wildcard(&child) { + if saw_nested { + return Err(Error::Grammar( + "Overlapping match arms in constructor branch".to_string(), + ) + .with_span(span) + .into()); + } + } else if matches!(child, TypedPatternIr::Constructor { .. }) { + saw_nested = true; } + + child_patterns.push(child); + child_spans.push(span); + } + } + + if child_patterns.is_empty() { + continue; + } + + let nested_children: Vec<_> = child_patterns + .iter() + .filter(|pattern| matches!(pattern, TypedPatternIr::Constructor { .. })) + .cloned() + .collect(); + let nested_spans: Vec<_> = child_patterns + .iter() + .zip(child_spans.iter().copied()) + .filter_map(|(pattern, span)| { + matches!(pattern, TypedPatternIr::Constructor { .. }).then_some(span) }) + .collect(); + + if !nested_children.is_empty() { + let nested_family = infer_family_from_patterns(&nested_children)? + .ok_or_else(|| Error::Grammar("Mixed match families".to_string()))?; + validate_constructor_overlap_patterns(nested_family, &nested_children, &nested_spans)?; + } + } + + Ok(()) +} + +fn is_pattern_useful(previous: &[TypedPatternIr], new_pattern: &TypedPatternIr) -> bool { + !previous + .iter() + .any(|previous_pattern| typed_pattern_covers(previous_pattern, new_pattern)) +} + +fn validate_pattern_usefulness(arms: &[TypedMatchArm]) -> Result<(), Error> { + let mut previous = Vec::new(); + + for arm in arms { + if !is_pattern_useful(&previous, &arm.pattern) { + return Err(Error::Grammar( + "This match arm is unreachable because it is covered by a previous arm".to_string(), + ) + .with_span(arm.span) + .into()); + } + + previous.push(arm.pattern.clone()); + } + + Ok(()) +} + +fn missing_branch_message(tag: PatternCtorTag) -> &'static str { + match tag { + PatternCtorTag::Left => "Non-exhaustive Either match: missing Left branch", + PatternCtorTag::Right => "Non-exhaustive Either match: missing Right branch", + PatternCtorTag::Some => "Non-exhaustive Option match: missing Some", + PatternCtorTag::None => "Non-exhaustive Option match: missing None", + PatternCtorTag::False => "Non-exhaustive bool match: missing false", + PatternCtorTag::True => "Non-exhaustive bool match: missing true", + } +} + +fn first_missing_pattern_message( + family: MatchFamily, + patterns: &[TypedPatternIr], +) -> Result, Error> { + if patterns.iter().any(is_typed_wildcard) { + return Ok(None); + } + + for spec in constructor_specs(family) { + let branch_patterns = branch_patterns_for_tag(patterns, spec.tag); + if branch_patterns.is_empty() { + return Ok(Some(missing_branch_message(spec.tag).to_string())); + } + + if spec.has_payload { + if branch_patterns.iter().any(is_typed_wildcard) { + continue; + } + + let nested_family = infer_family_from_patterns(&branch_patterns)? + .ok_or_else(|| Error::Grammar("Mixed match families".to_string()))?; + + if let Some(message) = first_missing_pattern_message(nested_family, &branch_patterns)? { + return Ok(Some(message)); + } + } + } + + Ok(None) +} + +fn validate_pattern_exhaustiveness( + family: MatchFamily, + arms: &[TypedMatchArm], +) -> Result<(), Error> { + let patterns: Vec<_> = arms.iter().map(|arm| arm.pattern.clone()).collect(); + + if let Some(message) = first_missing_pattern_message(family, &patterns)? { + return Err(Error::Grammar(message)); + } + + Ok(()) +} + +fn make_branch_typed_arm(pattern: TypedPatternIr, arm: &TypedMatchArm) -> TypedMatchArm { + TypedMatchArm { + pattern, + expression: arm.expression.clone(), + span: arm.span, + } +} + +fn compile_decision_branch( + spec: ConstructorSpec, + arms: &[TypedMatchArm], + span: Span, +) -> Result { + let mut applicable = Vec::new(); + for arm in arms { + if let Some(pattern) = specialize_pattern_for_tag(&arm.pattern, spec.tag) { + applicable.push(make_branch_typed_arm(pattern, arm)); + } + } + + debug_assert!(!applicable.is_empty()); + + if !spec.has_payload { + return Ok(DecisionBranchKind::Nullary { + expression: applicable[0].expression.clone(), + }); + } + + let first = &applicable[0]; + match &first.pattern { + TypedPatternIr::Wildcard(ty) => Ok(DecisionBranchKind::UnaryBind { + pattern: Pattern::Ignore, + ty: ty.clone(), + expression: first.expression.clone(), + }), + TypedPatternIr::Bind(pattern, ty) => Ok(DecisionBranchKind::UnaryBind { + pattern: pattern.clone(), + ty: ty.clone(), + expression: first.expression.clone(), + }), + TypedPatternIr::Constructor { .. } => { + let payload_id = payload_identifier(span, spec.tag.as_payload_tag()); + let payload_pattern = Pattern::Identifier(payload_id.clone()); + let payload_expr = payload_expression(&payload_id, span); + let nested_family = infer_family_from_patterns( + &applicable + .iter() + .map(|arm| arm.pattern.clone()) + .collect::>(), + )? + .ok_or_else(|| Error::Grammar("Mixed match families".to_string()))?; + let nested_tree = + compile_decision_tree(payload_expr.clone(), nested_family, &applicable, span)?; + let payload_ty = decision_tree_scrutinee_type(&nested_tree); + + Ok(DecisionBranchKind::UnarySwitch { + payload_pattern, + payload_expr, + payload_ty, + tree: Box::new(nested_tree), + }) + } + } +} + +fn compile_decision_tree( + scrutinee: Arc, + family: MatchFamily, + arms: &[TypedMatchArm], + span: Span, +) -> Result { + let mut branches = Vec::new(); + + for spec in constructor_specs(family) { + branches.push((spec.tag, compile_decision_branch(*spec, arms, span)?)); + } + + Ok(DecisionTree::Switch { + family, + scrutinee, + branches, + span, + }) +} + +fn branch_payload_ty(branch: &DecisionBranchKind) -> Option { + match branch { + DecisionBranchKind::Nullary { .. } => None, + DecisionBranchKind::UnaryBind { ty, .. } => Some(ty.clone()), + DecisionBranchKind::UnarySwitch { payload_ty, .. } => Some(payload_ty.clone()), + } +} + +fn decision_tree_scrutinee_type(tree: &DecisionTree) -> AliasedType { + match tree { + DecisionTree::Switch { + family, branches, .. + } => match family { + MatchFamily::Bool => AliasedType::boolean(), + MatchFamily::Option => { + let some_ty = branches + .iter() + .find(|(tag, _)| *tag == PatternCtorTag::Some) + .and_then(|(_, branch)| branch_payload_ty(branch)) + .unwrap_or_else(error_aliased_type); + AliasedType::option(some_ty) + } + MatchFamily::Either => { + let left_ty = branches + .iter() + .find(|(tag, _)| *tag == PatternCtorTag::Left) + .and_then(|(_, branch)| branch_payload_ty(branch)) + .unwrap_or_else(error_aliased_type); + let right_ty = branches + .iter() + .find(|(tag, _)| *tag == PatternCtorTag::Right) + .and_then(|(_, branch)| branch_payload_ty(branch)) + .unwrap_or_else(error_aliased_type); + AliasedType::either(left_ty, right_ty) + } + }, + } +} + +fn lower_branch_kind_to_match_arm( + tag: PatternCtorTag, + branch: DecisionBranchKind, + span: Span, +) -> Result { + match branch { + DecisionBranchKind::Nullary { expression } => { + let pattern = match tag { + PatternCtorTag::None => MatchPattern::None, + PatternCtorTag::False => MatchPattern::False, + PatternCtorTag::True => MatchPattern::True, + PatternCtorTag::Left | PatternCtorTag::Right | PatternCtorTag::Some => { + return Err(Error::Grammar("Invalid nullary branch".to_string())) + } + }; + Ok(match_arm(pattern, expression)) + } + DecisionBranchKind::UnaryBind { + pattern, + ty, + expression, + } => { + let match_pattern = match tag { + PatternCtorTag::Left => MatchPattern::Left(pattern, ty), + PatternCtorTag::Right => MatchPattern::Right(pattern, ty), + PatternCtorTag::Some => MatchPattern::Some(pattern, ty), + PatternCtorTag::None | PatternCtorTag::False | PatternCtorTag::True => { + return Err(Error::Grammar("Invalid payload branch".to_string())) + } + }; + Ok(match_arm(match_pattern, expression)) + } + DecisionBranchKind::UnarySwitch { + payload_pattern, + payload_expr, + payload_ty, + tree, + } => { + let nested = lower_decision_tree_to_match(*tree)?; + let match_pattern = match tag { + PatternCtorTag::Left => MatchPattern::Left(payload_pattern, payload_ty), + PatternCtorTag::Right => MatchPattern::Right(payload_pattern, payload_ty), + PatternCtorTag::Some => MatchPattern::Some(payload_pattern, payload_ty), + PatternCtorTag::None | PatternCtorTag::False | PatternCtorTag::True => { + return Err(Error::Grammar("Invalid payload branch".to_string())) + } + }; + + Ok(match_arm( + match_pattern, + wrap_nested_match(payload_expr, nested.left, nested.right, span), + )) + } } } +fn lower_decision_tree_to_match(tree: DecisionTree) -> Result { + match tree { + DecisionTree::Switch { + family, + scrutinee, + branches, + span, + } => { + let mut branch_map = std::collections::BTreeMap::new(); + for (tag, branch) in branches { + branch_map.insert(tag, branch); + } + + match family { + MatchFamily::Either => Ok(match_node( + scrutinee, + lower_branch_kind_to_match_arm( + PatternCtorTag::Left, + branch_map.remove(&PatternCtorTag::Left).unwrap(), + span, + )?, + lower_branch_kind_to_match_arm( + PatternCtorTag::Right, + branch_map.remove(&PatternCtorTag::Right).unwrap(), + span, + )?, + span, + )), + MatchFamily::Option => Ok(match_node( + scrutinee, + lower_branch_kind_to_match_arm( + PatternCtorTag::None, + branch_map.remove(&PatternCtorTag::None).unwrap(), + span, + )?, + lower_branch_kind_to_match_arm( + PatternCtorTag::Some, + branch_map.remove(&PatternCtorTag::Some).unwrap(), + span, + )?, + span, + )), + MatchFamily::Bool => Ok(match_node( + scrutinee, + lower_branch_kind_to_match_arm( + PatternCtorTag::False, + branch_map.remove(&PatternCtorTag::False).unwrap(), + span, + )?, + lower_branch_kind_to_match_arm( + PatternCtorTag::True, + branch_map.remove(&PatternCtorTag::True).unwrap(), + span, + )?, + span, + )), + } + } + } +} + +fn lower_match_arms( + scrutinee: Arc, + arms: Vec, + span: Span, +) -> Result { + let (family, typed_arms) = typed_arms_from_parsed(&arms)?; + let patterns: Vec<_> = typed_arms.iter().map(|arm| arm.pattern.clone()).collect(); + let spans: Vec<_> = typed_arms.iter().map(|arm| arm.span).collect(); + + validate_duplicate_patterns(&typed_arms)?; + validate_constructor_overlap_patterns(family, &patterns, &spans)?; + validate_pattern_usefulness(&typed_arms)?; + validate_pattern_exhaustiveness(family, &typed_arms)?; + + let tree = compile_decision_tree(scrutinee, family, &typed_arms, span)?; + lower_decision_tree_to_match(tree) +} + impl Match { fn parser<'tokens, 'src: 'tokens, I, E>( expr: E, @@ -1701,79 +2586,29 @@ impl Match { { let scrutinee = expr.clone().map(Arc::new); - let arm_recovery = any() - .filter(|t| !matches!(t, Token::Comma | Token::RBrace)) - .ignored() - .or(nested_delimiters( - Token::LBrace, - Token::RBrace, - [ - (Token::LParen, Token::RParen), - (Token::LBracket, Token::RBracket), - ], - |_| (), - ) - .ignored()) - .repeated() - .map_with(|(), _| None); - - let arm_parser = MatchArm::parser(expr.clone()) - .map(Some) - .recover_with(via_parser(arm_recovery.clone())); - let arms = delimited_with_recovery( - arm_parser.clone().then(arm_parser.clone()), + ParsedMatchArm::parser(expr.clone()) + .repeated() + .collect::>(), Token::LBrace, Token::RBrace, - |_| (None, None), + |_| Vec::new(), ); just(Token::Match) .ignore_then(scrutinee) .then(arms) - .validate(|(scrutinee, arms), e, emit| match arms { - (Some(first), Some(second)) => { - let (left, right) = match (&first.pattern, &second.pattern) { - (MatchPattern::Left(..), MatchPattern::Right(..)) => (first, second), - (MatchPattern::Right(..), MatchPattern::Left(..)) => (second, first), - - (MatchPattern::None, MatchPattern::Some(..)) => (first, second), - (MatchPattern::Some(..), MatchPattern::None) => (second, first), - - (MatchPattern::False, MatchPattern::True) => (first, second), - (MatchPattern::True, MatchPattern::False) => (second, first), - - (p1, p2) => { - emit.emit( - Error::IncompatibleMatchArms(p1.clone(), p2.clone()) - .with_span(e.span()), - ); - (first, second) - } - }; - - Self { - scrutinee, - left, - right, - span: e.span(), - } - } - _ => { - let match_arm_fallback = MatchArm { - expression: Arc::new(Expression::empty(Span::new(0, 0))), - pattern: MatchPattern::False, - }; - - let (left, right) = ( - arms.0.unwrap_or(match_arm_fallback.clone()), - arms.1.unwrap_or(match_arm_fallback.clone()), - ); - Self { - scrutinee, - left, - right, - span: e.span(), + .validate(|(scrutinee, arms), e, emit| { + match lower_match_arms(scrutinee.clone(), arms, e.span()) { + Ok(match_) => match_, + Err(err) => { + emit.emit(err.with_span(e.span())); + match_node( + scrutinee, + match_arm(MatchPattern::False, Arc::new(Expression::empty(e.span()))), + match_arm(MatchPattern::True, Arc::new(Expression::empty(e.span()))), + e.span(), + ) } } }) @@ -2172,37 +3007,4 @@ impl crate::ArbitraryRec for Match { } #[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_reject_redefined_builtin_type() { - let ty = TypeAlias::parse_from_str("type Ctx8 = u32") - .expect_err("Redifining built-in alias should be rejected"); - - assert_eq!( - ty.error(), - &Error::RedefinedAliasAsBuiltin(AliasName::from_str_unchecked("Ctx8")) - ); - } - - #[test] - fn test_double_colon() { - let input = "fn main() { let ab: u8 = <(u4, u4)> : :into((0b1011, 0b1101)); }"; - let mut error_handler = ErrorCollector::new(Arc::from(input)); - let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); - - assert!(parse_program.is_none()); - assert!(ErrorCollector::to_string(&error_handler).contains("Expected '::', found ':'")); - } - - #[test] - fn test_double_double_colon() { - let input = "fn main() { let pk: Pubkey = witnes::::PK; }"; - let mut error_handler = ErrorCollector::new(Arc::from(input)); - let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); - - assert!(parse_program.is_none()); - assert!(ErrorCollector::to_string(&error_handler).contains("Expected ';', found '::'")); - } -} +mod test; diff --git a/src/parse/test.rs b/src/parse/test.rs new file mode 100644 index 00000000..9bb06a70 --- /dev/null +++ b/src/parse/test.rs @@ -0,0 +1,1181 @@ +use super::*; + +mod helpers { + use crate::error::ErrorCollector; + use crate::parse::{ParseFromStrWithErrors, Program}; + use std::sync::Arc; + + pub fn collect_parse_errors(input: &str) -> String { + let mut error_handler = ErrorCollector::new(Arc::from(input)); + let parsed = Program::parse_from_str_with_errors(input, &mut error_handler); + assert!(parsed.is_none(), "program unexpectedly parsed successfully"); + ErrorCollector::to_string(&error_handler) + } + + pub fn normalize_generated_payload_names(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + let bytes = input.as_bytes(); + let mut i = 0; + let needle = b"__match_payload_"; + + while i < bytes.len() { + if bytes[i..].starts_with(needle) { + out.push_str("__match_payload_"); + i += needle.len(); + while i < bytes.len() { + let ch = bytes[i] as char; + if ch.is_ascii_alphanumeric() || ch == '_' { + i += 1; + } else { + break; + } + } + } else { + out.push(bytes[i] as char); + i += 1; + } + } + + out + } +} + +#[test] +fn test_reject_redefined_builtin_type() { + let ty = TypeAlias::parse_from_str("type Ctx8 = u32") + .expect_err("Redifining built-in alias should be rejected"); + + assert_eq!( + ty.error(), + &Error::RedefinedAliasAsBuiltin(AliasName::from_str_unchecked("Ctx8")) + ); +} + +#[test] +fn test_double_colon() { + let input = "fn main() { let ab: u8 = <(u4, u4)> : :into((0b1011, 0b1101)); }"; + let mut error_handler = ErrorCollector::new(Arc::from(input)); + let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); + + assert!(parse_program.is_none()); + assert!(ErrorCollector::to_string(&error_handler).contains("Expected '::', found ':'")); +} + +#[test] +fn test_double_double_colon() { + let input = "fn main() { let pk: Pubkey = witnes::::PK; }"; + let mut error_handler = ErrorCollector::new(Arc::from(input)); + let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); + + assert!(parse_program.is_none()); + assert!(ErrorCollector::to_string(&error_handler).contains("Expected ';', found '::'")); +} + +mod parsing { + use crate::parse::{ParseFromStr, Program}; + + #[test] + fn test_parse_two_arm_either_match_still_works() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(x: u8) => x, + Right(y: u16) => 0, + } + } + "#; + + Program::parse_from_str(input).expect("two-arm either match should still parse"); + } + + #[test] + fn test_parse_option_match_still_works() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(x: u8) => x, + } + } + "#; + + Program::parse_from_str(input).expect("option match should still parse"); + } + + #[test] + fn test_parse_bool_match_still_works() { + let input = r#" + fn main() -> u8 { + match true { + false => 0, + true => 1, + } + } + "#; + + Program::parse_from_str(input).expect("bool match should still parse"); + } +} + +mod binary_tree_match { + use crate::parse::{ + ExpressionInner, Item, MatchPattern, ParseFromStr, Program, SingleExpressionInner, + }; + + #[test] + fn test_parse_simple_two_arm_match_shape() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(x: u8) => x, + Right(y: u8) => y, + } + } + "#; + + let program = Program::parse_from_str(input).expect("program should parse"); + + let Item::Function(function) = &program.items()[0] else { + panic!("expected function"); + }; + + let ExpressionInner::Block(_, Some(body)) = function.body().inner() else { + panic!("expected block"); + }; + + let ExpressionInner::Single(single) = body.inner() else { + panic!("expected single expression"); + }; + + let SingleExpressionInner::Match(outer) = single.inner() else { + panic!("expected outer match"); + }; + + assert!(matches!(outer.left().pattern(), MatchPattern::Left(_, _))); + assert!(matches!(outer.right().pattern(), MatchPattern::Right(_, _))); + } + + #[test] + fn test_nested_four_arm_match_to_binary_tree() { + let input = r#" + fn main() -> u8 { + match witness::thingy { + Left(Left(x: u8)) => x, + Left(Right(y: u8)) => y, + Right(Left((a, b): (u8, u256))) => a, + Right(Right(s: Signature)) => 0, + } + } + "#; + + let program = Program::parse_from_str(input).expect("program should parse"); + + let Item::Function(function) = &program.items()[0] else { + panic!("expected function"); + }; + + let ExpressionInner::Block(_, Some(body)) = function.body().inner() else { + panic!("expected block body"); + }; + + let ExpressionInner::Single(single) = body.inner() else { + panic!("expected single expression body"); + }; + + let SingleExpressionInner::Match(outer) = single.inner() else { + panic!("expected outer match"); + }; + + assert!(matches!(outer.left().pattern(), MatchPattern::Left(_, _))); + assert!(matches!(outer.right().pattern(), MatchPattern::Right(_, _))); + + match outer.left().expression().inner() { + ExpressionInner::Single(left_single) => { + assert!(matches!( + left_single.inner(), + SingleExpressionInner::Match(_) + )); + } + _ => panic!("expected nested left match"), + } + + match outer.right().expression().inner() { + ExpressionInner::Single(right_single) => { + assert!(matches!( + right_single.inner(), + SingleExpressionInner::Match(_) + )); + } + _ => panic!("expected nested right match"), + } + } + + #[test] + fn test_parse_nested_option_either_match_to_binary_tree() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(Left(a: u8)) => a, + Some(Right(b: u8)) => b, + } + } + "#; + + let program = Program::parse_from_str(input).expect("program should parse"); + + let Item::Function(function) = &program.items()[0] else { + panic!("expected function"); + }; + + let ExpressionInner::Block(_, Some(body)) = function.body().inner() else { + panic!("expected block body"); + }; + + let ExpressionInner::Single(single) = body.inner() else { + panic!("expected single expression body"); + }; + + let SingleExpressionInner::Match(outer) = single.inner() else { + panic!("expected outer match"); + }; + + assert!(matches!(outer.left().pattern(), MatchPattern::None)); + assert!(matches!(outer.right().pattern(), MatchPattern::Some(_, _))); + + match outer.right().expression().inner() { + ExpressionInner::Single(right_single) => { + assert!(matches!( + right_single.inner(), + SingleExpressionInner::Match(_) + )); + } + _ => panic!("expected nested some match"), + } + } + + #[test] + fn test_parse_three_level_nested_match_recursively() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Left(Left(a: u8))) => a, + Left(Left(Right(b: u8))) => b, + Left(Right(c: u8)) => c, + Right(d: u8) => d, + } + } + "#; + + let program = Program::parse_from_str(input).expect("program should parse"); + + let Item::Function(function) = &program.items()[0] else { + panic!("expected function"); + }; + let ExpressionInner::Block(_, Some(body)) = function.body().inner() else { + panic!("expected block"); + }; + let ExpressionInner::Single(single) = body.inner() else { + panic!("expected single"); + }; + let SingleExpressionInner::Match(outer) = single.inner() else { + panic!("expected outer match"); + }; + + match outer.left().expression().inner() { + ExpressionInner::Single(left_single) => { + let SingleExpressionInner::Match(inner) = left_single.inner() else { + panic!("expected first nested match"); + }; + match inner.left().expression().inner() { + ExpressionInner::Single(inner_left_single) => { + assert!(matches!( + inner_left_single.inner(), + SingleExpressionInner::Match(_) + )); + } + _ => panic!("expected second nested match"), + } + } + _ => panic!("expected nested match on outer left"), + } + } + + #[test] + fn test_nested_match_scrutinee_types_are_preserved() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(Left(a: u8)) => a, + Some(Right(b: u16)) => 0, + } + } + "#; + + let program = Program::parse_from_str(input).expect("program should parse"); + let Item::Function(function) = &program.items()[0] else { + panic!("expected function"); + }; + let ExpressionInner::Block(_, Some(body)) = function.body().inner() else { + panic!("expected block"); + }; + let ExpressionInner::Single(single) = body.inner() else { + panic!("expected single"); + }; + let SingleExpressionInner::Match(outer) = single.inner() else { + panic!("expected outer match"); + }; + + assert_eq!( + outer.scrutinee_type().to_string(), + "Option>" + ); + + let ExpressionInner::Single(right_single) = outer.right().expression().inner() else { + panic!("expected nested some match"); + }; + let SingleExpressionInner::Match(inner) = right_single.inner() else { + panic!("expected inner match"); + }; + assert_eq!(inner.scrutinee_type().to_string(), "Either"); + } +} + +mod duplicates { + use crate::parse::test::helpers::collect_parse_errors; + + #[test] + fn test_reject_duplicate_true_arm() { + let input = r#" + fn main() -> u8 { + match true { + false => 0, + true => 1, + true => 2, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm") || err.contains("Duplicate true arm")); + } + + #[test] + fn test_reject_duplicate_nested_path() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Right(a: u8)) => a, + Left(Right(b: u8)) => b, + Right(c: u8) => c, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + } + + #[test] + fn test_reject_duplicate_nested_path_later_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Right(a: u8)) => a, + Left(Right(b: u8)) => b, + Right(c: u8) => c, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + assert!(err.contains("Left(Right(b: u8)) => b")); + } + + #[test] + fn test_reject_duplicate_left_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => a, + Left(b: u8) => b, + Right(c: u8) => c, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + assert!(err.contains("Left(b: u8) => b")); + } + + #[test] + fn test_reject_duplicate_right_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => a, + Right(b: u8) => b, + Right(c: u8) => c, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + assert!(err.contains("Right(c: u8) => c")); + } + + #[test] + fn test_reject_duplicate_some_bind_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(a: u8) => a, + Some(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + assert!(err.contains("Some(b: u8) => b")); + } + #[test] + fn test_reject_duplicate_none_later_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + None => 1, + Some(x: u8) => x, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + assert!(err.contains("None => 1")); + } + + #[test] + fn test_reject_duplicate_true_later_arm() { + let input = r#" + fn main() -> u8 { + match true { + false => 0, + true => 1, + true => 2, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + assert!(err.contains("true => 2")); + } + + #[test] + fn test_reject_duplicate_left_bind_and_ignore_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(_: u8) => 0, + Left(a: u8) => a, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm") || err.contains("unreachable")); + assert!(err.contains("Left(a: u8) => a")); + } + + #[test] + fn test_reject_duplicate_some_ignore_and_bind_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(_: u8) => 1, + Some(a: u8) => a, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm") || err.contains("unreachable")); + assert!(err.contains("Some(a: u8) => a")); + } + + #[test] + fn test_reject_duplicate_false_arm_later_arm() { + let input = r#" + fn main() -> u8 { + match true { + false => 0, + false => 1, + true => 2, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + assert!(err.contains("false => 1")); + } + + #[test] + fn test_reject_duplicate_three_level_path() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Left(Left(a: u8))) => a, + Left(Left(Left(b: u8))) => b, + Left(Right(c: u8)) => c, + Right(d: u8) => d, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm")); + assert!(err.contains("Left(Left(Left(b: u8))) => b")); + } + + #[test] + fn test_reject_duplicate_bool_arm() { + let input = r#" + fn main() -> u8 { + match true { + false => 0, + false => 1, + true => 2, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm") || err.contains("Duplicate false arm")); + } + + #[test] + fn test_reject_duplicate_none_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + None => 1, + Some(x: u8) => x, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Duplicate match arm") || err.contains("Duplicate None arm")); + } +} + +mod syntax { + use crate::parse::test::helpers::collect_parse_errors; + use crate::parse::{ParseFromStr, Program}; + + #[test] + fn test_block_arms_do_not_require_trailing_commas() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => { a } + Right(b: u8) => { b } + } + } + "#; + + Program::parse_from_str(input) + .expect("block match arms should parse without trailing commas"); + } + + #[test] + fn test_non_block_arm_requires_trailing_comma() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => a + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Missing ',' after a match arm that isn't block expression")); + } + + #[test] + fn test_block_first_arm_without_comma_and_non_block_second_arm_with_comma() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => { a } + Right(b: u8) => b, + } + } + "#; + + Program::parse_from_str(input).expect("mixed block/non-block arm formatting should parse"); + } +} + +mod constructor_overlap { + use crate::parse::test::helpers::collect_parse_errors; + #[test] + fn test_reject_specific_constructor_overlap() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Left(a: u8)) => a, + Left(v: Either) => 0, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Overlapping match arms in constructor branch")); + } + + #[test] + fn test_reject_specific_some_constructor_overlap() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(v: Either) => 1, + Some(Left(a: u8)) => a, + Some(Right(b: u8)) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("unreachable")); + assert!(err.contains("covered by a previous arm")); + } + #[test] + fn test_reject_specific_constructor_overlap_on_right_branch() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => a, + Right(Left(b: u8)) => b, + Right(v: Either) => 0, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Overlapping match arms in constructor branch")); + } + + #[test] + fn test_reject_specific_constructor_overlap_on_right_branch_alternate() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => a, + Right(v: Either) => 0, + Right(Left(b: u8)) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("unreachable")); + assert!(err.contains("Right(Left(b: u8)) => b")); + } +} + +mod exhaustiveness { + use crate::parse::test::helpers::collect_parse_errors; + + #[test] + fn test_reject_non_exhaustive_either_match() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => a, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Non-exhaustive Either match: missing Right branch")); + } + + #[test] + fn test_reject_non_exhaustive_nested_either_match() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Left(a: u8)) => a, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Non-exhaustive Either match: missing Right branch")); + } + + #[test] + fn test_reject_non_exhaustive_option_match() { + let input = r#" + fn main() -> u8 { + match witness::x { + Some(a: u8) => a, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Non-exhaustive Option match: missing None")); + } + + #[test] + fn test_reject_non_exhaustive_bool_match() { + let input = r#" + fn main() -> u8 { + match true { + true => 1, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Non-exhaustive bool match: missing false")); + } + #[test] + fn test_reject_non_exhaustive_nested_option_match() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(Left(a: u8)) => a, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Non-exhaustive Either match: missing Right branch")); + } + + #[test] + fn test_reject_non_exhaustive_nested_option_some() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(Some(a: u8)) => a, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Non-exhaustive Option match: missing None")); + } + + #[test] + fn test_reject_non_exhaustive_nested_option_left_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Some(a: u8)) => a, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Non-exhaustive Option match: missing None")); + } + #[test] + fn test_reject_non_exhaustive_three_level_nested_either() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Left(Left(a: u8))) => a, + Left(Left(Right(b: u8))) => b, + Right(c: u8) => c, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Non-exhaustive Either match: missing Right branch")); + } +} + +mod usefulness { + use crate::parse::test::helpers::collect_parse_errors; + #[test] + fn test_reject_specific_either_unreachable() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(v: Either) => 0, + Left(Left(a: u8)) => a, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("unreachable")); + assert!(err.contains("Left(Left(a: u8)) => a")); + } + #[test] + fn test_reject_overlapping_nested_and_terminal_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(v: Either) => 0, + Left(Right(a: u8)) => a, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("unreachable")); + assert!(err.contains("covered by a previous arm")); + } + + #[test] + fn test_reject_overlapping_nested_and_terminal_arm_later_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(v: Either) => 0, + Left(Right(a: u8)) => a, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("unreachable")); + assert!(err.contains("Left(Right(a: u8)) => a")); + } + #[test] + fn test_reject_some_wildcard_with_specific_nested_some_arms() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(v: Either) => 1, + Some(Left(a: u8)) => a, + Some(Right(b: u8)) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!( + err.contains("unreachable") + || err.contains("Overlapping match arms in constructor branch") + ); + } + + #[test] + fn test_reject_some_wildcard_with_specific_nested_some_arms_later_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(v: Either) => 1, + Some(Left(a: u8)) => a, + Some(Right(b: u8)) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Some(Left(a: u8)) => a") || err.contains("Some(Right(b: u8)) => b")); + } +} + +mod formatting { + use crate::parse::test::helpers::normalize_generated_payload_names; + use crate::parse::{ParseFromStr, Program}; + + #[test] + fn test_formatting_snapshot_for_nested_match() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(Left(a: u8)) => a, + Some(Right(b: u8)) => b, + } + } + "#; + + let program = Program::parse_from_str(input).expect("program should parse"); + let rendered = normalize_generated_payload_names(&program.to_string()); + + assert!(rendered.contains("match witness::x")); + assert!(rendered.contains("Some(__match_payload_: Either)")); + assert!(rendered.contains("match __match_payload_")); + } + #[test] + fn test_formatting_snapshot_for_four_arm_either_match() { + let input = r#" + fn main() -> u8 { + match witness::thingy { + Left(Left(x: u8)) => x, + Left(Right(y: u8)) => y, + Right(Left(a: u8)) => a, + Right(Right(b: u8)) => b, + } + } + "#; + + let program = Program::parse_from_str(input).expect("program should parse"); + let rendered = normalize_generated_payload_names(&program.to_string()); + + assert!(rendered.contains("match witness::thingy")); + assert!(rendered.contains("Left(__match_payload_: Either)")); + assert!(rendered.contains("Right(__match_payload_: Either)")); + assert!(rendered.matches("match __match_payload_").count() >= 2); + } + + #[test] + fn test_formatting_snapshot_for_three_level_match() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Left(Left(a: u8))) => a, + Left(Left(Right(b: u8))) => b, + Left(Right(c: u8)) => c, + Right(d: u8) => d, + } + } + "#; + + let program = Program::parse_from_str(input).expect("program should parse"); + let rendered = normalize_generated_payload_names(&program.to_string()); + + assert!(rendered.contains("match witness::x")); + assert!(rendered.matches("match __match_payload_").count() >= 2); + } +} + +mod wildcards { + use crate::parse::{ParseFromStr, Program}; + + #[test] + fn test_parse_nested_wildcard_on_both_outer_either_branches() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(v: Either) => 0, + Right(w: Either) => 1, + } + } + "#; + + Program::parse_from_str(input).expect("wildcard outer either branches should parse"); + } + + #[test] + fn test_parse_nested_ignore_patterns() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(Left(_: u8)) => 0, + Left(Right(_: u8)) => 1, + Right(_: u8) => 2, + } + } + "#; + + Program::parse_from_str(input).expect("nested ignore patterns should parse"); + } +} + +mod invalid_patterns { + use crate::parse::test::helpers::collect_parse_errors; + + #[test] + fn test_reject_mixed_match_families() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(a: u8) => a, + Some(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!( + err.contains("Incompatible") + || err.contains("incompatible") + || err.contains("Left") + || err.contains("Some") + ); + } + + #[test] + fn test_reject_invalid_top_level_bare_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + a: u8 => a, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Top-level match arms must start with")); + } + + #[test] + fn test_reject_unexpected_terminal_pattern_under_constructor() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(None) => 1, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Unexpected terminal pattern under constructor")); + } + + #[test] + fn test_reject_malformed_terminal_under_constructor() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(Left(None)) => 1, + Some(Right(v: Option)) => 2, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Unexpected terminal pattern under constructor")); + } + + #[test] + fn test_reject_nested_bool_as_unexpected_terminal_left_arm() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(true) => 1, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Unexpected terminal pattern under constructor")); + } + + #[test] + fn test_reject_nested_bool_as_unexpected_terminal_left_arm_alternate() { + let input = r#" + fn main() -> u8 { + match witness::x { + Left(false) => 0, + Left(true) => 1, + Right(b: u8) => b, + } + } + "#; + + let err = collect_parse_errors(input); + assert!(err.contains("Unexpected terminal pattern under constructor")); + } + + #[test] + fn test_reject_constructor_family_nested_mismatch() { + let input = r#" + fn main() -> u8 { + match witness::x { + None => 0, + Some(Left(a: u8)) => a, + Some(true) => 1, + } + } + "#; + + let err = collect_parse_errors(input); + assert!( + err.contains("Mixed match families") + || err.contains("Incompatible") + || err.contains("Unexpected terminal pattern under constructor") + ); + } +} + +mod example_cases { + use crate::error::ErrorCollector; + use crate::parse::{ParseFromStrWithErrors, Program}; + use std::sync::Arc; + + #[test] + fn full_test_case() { + let input = r#" + fn main() { + match witness::PATH { + Left(left_or_right: Either<(u64, u256, u256, u256, u256, u256, u256, u256, u256), Either<(bool, u64, u64, u64), (bool, u64, u64)>>) => match left_or_right { + Left(params: (u64, u256, u256, u256, u256, u256, u256, u256, u256)) => { + let (expected_asset_amount, input_option_abf, input_option_vbf, input_grantor_abf, input_grantor_vbf, output_option_abf, output_option_vbf, output_grantor_abf, output_grantor_vbf): (u64, u256, u256, u256, u256, u256, u256, u256, u256) = params; + funding_path( + expected_asset_amount, + input_option_abf, input_option_vbf, + input_grantor_abf, input_grantor_vbf, + output_option_abf, output_option_vbf, + output_grantor_abf, output_grantor_vbf + ); + }, + Right(exercise_or_settlement: Either<(bool, u64, u64, u64), (bool, u64, u64)>) => match exercise_or_settlement { + Left(params: (bool, u64, u64, u64)) => { + let (is_change_needed, amount_to_burn, collateral_amount, asset_amount): (bool, u64, u64, u64) = dbg!(params); + exercise_path(amount_to_burn, collateral_amount, asset_amount, is_change_needed) + }, + Right(params: (bool, u64, u64)) => { + let (is_change_needed, amount_to_burn, asset_amount): (bool, u64, u64) = dbg!(params); + settlement_path(amount_to_burn, asset_amount, is_change_needed) + }, + }, + }, + Right(left_or_right: Either<(bool, u64, u64), (bool, u64, u64)>) => match left_or_right { + Left(params: (bool, u64, u64)) => { + let (is_change_needed, grantor_token_amount_to_burn, collateral_amount): (bool, u64, u64) = params; + expiry_path(grantor_token_amount_to_burn, collateral_amount, is_change_needed) + }, + Right(params: (bool, u64, u64)) => { + let (is_change_needed, amount_to_burn, collateral_amount): (bool, u64, u64) = params; + cancellation_path(amount_to_burn, collateral_amount, is_change_needed) + }, + }, + } + } + "#; + let mut error_handler = ErrorCollector::new(Arc::from(input)); + let parsed = Program::parse_from_str_with_errors(input, &mut error_handler); + + assert!(parsed.is_some()); + assert!(error_handler.is_empty()); + } +}