diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..39d3382 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,69 @@ +name: Rust + +on: + - push + - pull_request + +permissions: read-all + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + + - name: Cache cargo index + uses: actions/cache@v4 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} + + - name: Cache cargo build + uses: actions/cache@v4 + with: + path: target + key: ${{ runner.os }}-cargo-build-target-${{ hashFiles('**/Cargo.lock') }} + + - name: Build + run: cargo build --verbose + + - name: Run tests + run: cargo test --verbose + + fmt: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt -- --check + + clippy: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Run Clippy + run: cargo clippy -- -D warnings diff --git a/.gitignore b/.gitignore index 296fc8c..2fa1c6f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,7 @@ .direnv _build + + +# Added by cargo + +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..77a7ca4 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "hm_inference_example" +version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8099e1b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "hm_inference_example" +version = "0.1.0" +edition = "2024" + +[dependencies] diff --git a/README.md b/README.md index 8ba461f..a651fe2 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,73 @@ # Hindley-Milner Type Inference Example with Algorithm W This repo contains an interpreter for a toy language based on simply-typed lambda calculus. -It is designed to be a practical illustration of how Algorithm W works in OCaml code. +It is designed to be a practical illustration of how Algorithm W works, now blazingly fast in Rust! 🚀🦀 -Most of the code (e.g. parser, type checker) is contained in `/lib`, with `/lib/type_check.ml` being the type inference/checking code, `/lib/parser.mly` containing parser definitions, etc. +The code includes: +- `src/ast.rs`: Abstract syntax tree definitions for expressions +- `src/lexer.rs`: Lexer for tokenizing input +- `src/parser.rs`: Parser for building ASTs from tokens +- `src/type_check.rs`: Type inference/checking code implementing Algorithm W +- `src/eval.rs`: Evaluator for executing expressions +- `src/main.rs`: CLI interface ## Instructions to Build and Run -This project uses [Dune](https://dune.build/) to build and run tests. +This project uses [Cargo](https://doc.rust-lang.org/cargo/) to build and run. + +### Prerequisites + +Make sure you have Rust and Cargo installed. You can install them from [rustup.rs](https://rustup.rs/). + +### Building + +```bash +cargo build --release +``` + +### Running + +You can run the interpreter with an expression as a command-line argument: + +```bash +cargo run -- "let id = x -> x in id 1" +``` + +Or run it in REPL mode: + +```bash +cargo run +``` + +### Running Tests + +```bash +cargo test +``` + +## Example Expressions + +Here are some example expressions you can try: + +``` +1 + 2 * 3 +let x = 5 in x + 10 +x -> x +let id = x -> x in id 1 +let twice = f -> x -> f (f x) in twice +if true then 1 else 2 +``` + +## Original OCaml Version + +The original OCaml implementation can still be found in the `/lib` and `/bin` directories. To build and run the OCaml version: + +This project originally used [Dune](https://dune.build/) to build and run tests. There are a couple of options to get Dune: - If you already have [Nix](https://nixos.org), you can run `nix develop` in the root directory of this repo. That will install Dune with required OPAM packages. - You can install Dune and the required OPAM packages manually, following the normal Dune documentation. -## Instructions to Run the Web Version +To run the web version (OCaml compiled to JavaScript): +Run `python3 -m http.server --bind 127.0.0.1` in the root directory and open `http://localhost:8000` in your browser. -Run `python3 -m http.serever --bind 127.0.0.1` in the root directory and open `http://localhost:8000` in your browser. diff --git a/src/ast.rs b/src/ast.rs new file mode 100644 index 0000000..cdc6205 --- /dev/null +++ b/src/ast.rs @@ -0,0 +1,54 @@ +use std::fmt; + +#[derive(Debug, Clone, PartialEq)] +pub enum UnaryOp { + Not, + Neg, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum BinaryOp { + Equal, + Add, + Mul, + And, + Or, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Expr { + Var(String), + Int(i32), + Bool(bool), + OpUnary(UnaryOp, Box), + OpBinary(BinaryOp, Box, Box), + Closure(String, Box), + Application(Box, Box), + Let(String, Box, Box), + If(Box, Box, Box), +} + +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Expr::Var(x) => write!(f, "{}", x), + Expr::Int(n) => write!(f, "{}", n), + Expr::Bool(b) => write!(f, "{}", b), + Expr::OpUnary(op, e) => match op { + UnaryOp::Not => write!(f, "not {}", e), + UnaryOp::Neg => write!(f, "-{}", e), + }, + Expr::OpBinary(op, e1, e2) => match op { + BinaryOp::Equal => write!(f, "{} = {}", e1, e2), + BinaryOp::Add => write!(f, "{} + {}", e1, e2), + BinaryOp::Mul => write!(f, "{} * {}", e1, e2), + BinaryOp::And => write!(f, "{} && {}", e1, e2), + BinaryOp::Or => write!(f, "{} || {}", e1, e2), + }, + Expr::Closure(x, e) => write!(f, "{} -> {}", x, e), + Expr::Application(e1, e2) => write!(f, "({}) ({})", e1, e2), + Expr::Let(x, e1, e2) => write!(f, "let {} = {} in {}", x, e1, e2), + Expr::If(e1, e2, e3) => write!(f, "if {} then {} else {}", e1, e2, e3), + } + } +} diff --git a/src/eval.rs b/src/eval.rs new file mode 100644 index 0000000..b5e59c5 --- /dev/null +++ b/src/eval.rs @@ -0,0 +1,99 @@ +use crate::ast::{BinaryOp, Expr, UnaryOp}; + +fn subst(e: &Expr, v: &Expr, x: &str) -> Expr { + match e { + Expr::Int(_) | Expr::Bool(_) => e.clone(), + Expr::Var(y) => { + if y == x { + v.clone() + } else { + e.clone() + } + } + Expr::OpUnary(uop, e1) => Expr::OpUnary(uop.clone(), Box::new(subst(e1, v, x))), + Expr::OpBinary(bop, e1, e2) => Expr::OpBinary( + bop.clone(), + Box::new(subst(e1, v, x)), + Box::new(subst(e2, v, x)), + ), + Expr::If(cond, e1, e2) => Expr::If( + Box::new(subst(cond, v, x)), + Box::new(subst(e1, v, x)), + Box::new(subst(e2, v, x)), + ), + Expr::Closure(y, e1) => { + if y == x { + e.clone() + } else { + Expr::Closure(y.clone(), Box::new(subst(e1, v, x))) + } + } + Expr::Application(e1, e2) => { + Expr::Application(Box::new(subst(e1, v, x)), Box::new(subst(e2, v, x))) + } + Expr::Let(y, e1, e2) => { + let e1_prime = subst(e1, v, x); + if y == x { + Expr::Let(y.clone(), Box::new(e1_prime), e2.clone()) + } else { + Expr::Let(y.clone(), Box::new(e1_prime), Box::new(subst(e2, v, x))) + } + } + } +} + +pub fn eval(e: &Expr) -> Result { + match e { + Expr::Int(_) | Expr::Bool(_) | Expr::Closure(_, _) => Ok(e.clone()), + Expr::Var(x) => Err(format!("unbound variable {} while evaluating {}", x, e)), + Expr::OpUnary(uop, e1) => eval_uop(uop, e1), + Expr::OpBinary(bop, e1, e2) => eval_bop(bop, e1, e2), + Expr::Application(e1, e2) => { + let e1_val = eval(e1)?; + match e1_val { + Expr::Closure(x, body) => { + let substituted = subst(&body, e2, &x); + eval(&substituted) + } + _ => Err("application of non-closure".to_string()), + } + } + Expr::Let(x, e1, e2) => { + let e1_val = eval(e1)?; + let substituted = subst(e2, &e1_val, x); + eval(&substituted) + } + Expr::If(cond, e_then, e_else) => { + let cond_val = eval(cond)?; + match cond_val { + Expr::Bool(true) => eval(e_then), + Expr::Bool(false) => eval(e_else), + _ => Err("guard must be bool".to_string()), + } + } + } +} + +fn eval_uop(uop: &UnaryOp, e1: &Expr) -> Result { + let e1_val = eval(e1)?; + match (uop, e1_val) { + (UnaryOp::Not, Expr::Bool(true)) => Ok(Expr::Bool(false)), + (UnaryOp::Not, Expr::Bool(false)) => Ok(Expr::Bool(true)), + (UnaryOp::Neg, Expr::Int(a)) => Ok(Expr::Int(-a)), + _ => Err("unary op with given operand not defined".to_string()), + } +} + +fn eval_bop(bop: &BinaryOp, e1: &Expr, e2: &Expr) -> Result { + let e1_val = eval(e1)?; + let e2_val = eval(e2)?; + match (bop, e1_val, e2_val) { + (BinaryOp::Add, Expr::Int(a), Expr::Int(b)) => Ok(Expr::Int(a + b)), + (BinaryOp::Mul, Expr::Int(a), Expr::Int(b)) => Ok(Expr::Int(a * b)), + (BinaryOp::And, Expr::Bool(a), Expr::Bool(b)) => Ok(Expr::Bool(a && b)), + (BinaryOp::Or, Expr::Bool(a), Expr::Bool(b)) => Ok(Expr::Bool(a || b)), + (BinaryOp::Equal, Expr::Int(a), Expr::Int(b)) => Ok(Expr::Bool(a == b)), + (BinaryOp::Equal, Expr::Bool(a), Expr::Bool(b)) => Ok(Expr::Bool(a == b)), + _ => Err("binary op with given operands not defined".to_string()), + } +} diff --git a/src/lexer.rs b/src/lexer.rs new file mode 100644 index 0000000..b1ff063 --- /dev/null +++ b/src/lexer.rs @@ -0,0 +1,213 @@ +use std::fmt; + +#[derive(Debug, Clone, PartialEq)] +pub enum Token { + Int(i32), + IdValue(String), + True, + False, + Plus, + Minus, + Times, + Equals, + Not, + And, + Or, + Let, + In, + If, + Then, + Else, + LParen, + RParen, + RArrow, + Eof, +} + +impl fmt::Display for Token { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Token::Int(n) => write!(f, "{}", n), + Token::IdValue(s) => write!(f, "{}", s), + Token::True => write!(f, "true"), + Token::False => write!(f, "false"), + Token::Plus => write!(f, "+"), + Token::Minus => write!(f, "-"), + Token::Times => write!(f, "*"), + Token::Equals => write!(f, "="), + Token::Not => write!(f, "!"), + Token::And => write!(f, "&&"), + Token::Or => write!(f, "||"), + Token::Let => write!(f, "let"), + Token::In => write!(f, "in"), + Token::If => write!(f, "if"), + Token::Then => write!(f, "then"), + Token::Else => write!(f, "else"), + Token::LParen => write!(f, "("), + Token::RParen => write!(f, ")"), + Token::RArrow => write!(f, "->"), + Token::Eof => write!(f, "EOF"), + } + } +} + +pub struct Lexer { + input: Vec, + pos: usize, +} + +impl Lexer { + pub fn new(input: &str) -> Self { + Lexer { + input: input.chars().collect(), + pos: 0, + } + } + + fn current_char(&self) -> Option { + if self.pos < self.input.len() { + Some(self.input[self.pos]) + } else { + None + } + } + + fn peek_char(&self, offset: usize) -> Option { + if self.pos + offset < self.input.len() { + Some(self.input[self.pos + offset]) + } else { + None + } + } + + fn advance(&mut self) { + self.pos += 1; + } + + fn skip_whitespace(&mut self) { + while let Some(ch) = self.current_char() { + if ch.is_whitespace() { + self.advance(); + } else { + break; + } + } + } + + fn read_number(&mut self) -> i32 { + let mut num_str = String::new(); + + // Handle negative sign + if self.current_char() == Some('-') { + num_str.push('-'); + self.advance(); + } + + while let Some(ch) = self.current_char() { + if ch.is_ascii_digit() { + num_str.push(ch); + self.advance(); + } else { + break; + } + } + + num_str.parse().unwrap_or(0) + } + + fn read_identifier(&mut self) -> String { + let mut id = String::new(); + + while let Some(ch) = self.current_char() { + if ch.is_alphanumeric() || ch == '_' || ch == '-' || ch == '\'' { + id.push(ch); + self.advance(); + } else { + break; + } + } + + id + } + + pub fn next_token(&mut self) -> Token { + self.skip_whitespace(); + + match self.current_char() { + None => Token::Eof, + Some(ch) => { + if ch.is_ascii_digit() { + return Token::Int(self.read_number()); + } + + if ch == '-' { + if let Some(next) = self.peek_char(1) { + if next == '>' { + self.advance(); + self.advance(); + return Token::RArrow; + } else if next.is_ascii_digit() { + return Token::Int(self.read_number()); + } + } + self.advance(); + return Token::Minus; + } + + if ch.is_alphabetic() { + let id = self.read_identifier(); + return match id.as_str() { + "true" => Token::True, + "false" => Token::False, + "let" => Token::Let, + "in" => Token::In, + "if" => Token::If, + "then" => Token::Then, + "else" => Token::Else, + _ => Token::IdValue(id), + }; + } + + self.advance(); + match ch { + '!' => Token::Not, + '+' => Token::Plus, + '*' => Token::Times, + '=' => Token::Equals, + '(' => Token::LParen, + ')' => Token::RParen, + '&' => { + if self.current_char() == Some('&') { + self.advance(); + Token::And + } else { + panic!("Unexpected character: {}", ch); + } + } + '|' => { + if self.current_char() == Some('|') { + self.advance(); + Token::Or + } else { + panic!("Unexpected character: {}", ch); + } + } + _ => panic!("Unexpected character: {}", ch), + } + } + } + } + + pub fn tokenize(&mut self) -> Vec { + let mut tokens = Vec::new(); + loop { + let token = self.next_token(); + if token == Token::Eof { + tokens.push(token); + break; + } + tokens.push(token); + } + tokens + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..7890035 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,93 @@ +mod ast; +mod eval; +mod lexer; +mod parser; +mod type_check; + +use std::env; +use std::io::{self, Write}; + +fn main() { + let args: Vec = env::args().collect(); + + if args.len() > 1 { + // Run with command line argument + let input = &args[1]; + run_program(input); + } else { + // Interactive REPL mode + repl(); + } +} + +fn repl() { + println!("Hindley-Milner Type Inference Example (Rust Edition)"); + println!("Type expressions to evaluate and infer types. Press Ctrl+C to exit."); + println!(); + + loop { + print!("> "); + io::stdout().flush().unwrap(); + + let mut input = String::new(); + if io::stdin().read_line(&mut input).is_err() { + break; + } + + let input = input.trim(); + if input.is_empty() { + continue; + } + + run_program(input); + println!(); + } +} + +fn run_program(input: &str) { + // Tokenize + let mut lexer = lexer::Lexer::new(input); + let tokens = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| lexer.tokenize())) { + Ok(tokens) => tokens, + Err(_) => { + println!("Error: Failed to tokenize input"); + return; + } + }; + + // Parse + let mut parser = parser::Parser::new(tokens); + let ast = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| parser.parse())) { + Ok(ast) => ast, + Err(_) => { + println!("Error: Failed to parse input"); + return; + } + }; + + // Type check + type_check::reset_gensym(); + let env = type_check::Env::new(); + let result_type = match type_check::infer(&ast, &env) { + Ok((t, _)) => type_check::lower(&t), + Err(e) => { + println!("Type error: {}", e); + return; + } + }; + + // Evaluate + let result = match eval::eval(&ast) { + Ok(r) => r, + Err(e) => { + println!("Evaluation error: {}", e); + return; + } + }; + + println!("Result: {}", result); + println!("Type: {}", result_type); +} + +#[cfg(test)] +mod tests; diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..8fa2221 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,204 @@ +use crate::ast::{BinaryOp, Expr, UnaryOp}; +use crate::lexer::Token; + +pub struct Parser { + tokens: Vec, + pos: usize, +} + +impl Parser { + pub fn new(tokens: Vec) -> Self { + Parser { tokens, pos: 0 } + } + + fn current_token(&self) -> &Token { + if self.pos < self.tokens.len() { + &self.tokens[self.pos] + } else { + &Token::Eof + } + } + + fn advance(&mut self) { + self.pos += 1; + } + + fn expect(&mut self, expected: Token) { + if self.current_token() != &expected { + panic!("Expected {:?}, got {:?}", expected, self.current_token()); + } + self.advance(); + } + + pub fn parse(&mut self) -> Expr { + let expr = self.parse_expr(); + self.expect(Token::Eof); + expr + } + + fn parse_expr(&mut self) -> Expr { + self.parse_let_or_if() + } + + fn parse_let_or_if(&mut self) -> Expr { + match self.current_token() { + Token::Let => { + self.advance(); + let Token::IdValue(name) = self.current_token().clone() else { + panic!("Expected identifier after 'let'"); + }; + self.advance(); + self.expect(Token::Equals); + let e1 = self.parse_expr(); + self.expect(Token::In); + let e2 = self.parse_expr(); + Expr::Let(name, Box::new(e1), Box::new(e2)) + } + Token::If => { + self.advance(); + let cond = self.parse_expr(); + self.expect(Token::Then); + let then_branch = self.parse_expr(); + self.expect(Token::Else); + let else_branch = self.parse_expr(); + Expr::If(Box::new(cond), Box::new(then_branch), Box::new(else_branch)) + } + _ => self.parse_closure(), + } + } + + fn parse_closure(&mut self) -> Expr { + let start_pos = self.pos; + + // Try to parse as closure: x -> expr + if let Token::IdValue(param) = self.current_token().clone() { + self.advance(); + if let Token::RArrow = self.current_token() { + self.advance(); + let body = self.parse_expr(); + return Expr::Closure(param, Box::new(body)); + } + // Not a closure, backtrack + self.pos = start_pos; + } + + self.parse_or() + } + + fn parse_or(&mut self) -> Expr { + let mut left = self.parse_and(); + + while let Token::Or = self.current_token() { + self.advance(); + let right = self.parse_and(); + left = Expr::OpBinary(BinaryOp::Or, Box::new(left), Box::new(right)); + } + + left + } + + fn parse_and(&mut self) -> Expr { + let mut left = self.parse_equality(); + + while let Token::And = self.current_token() { + self.advance(); + let right = self.parse_equality(); + left = Expr::OpBinary(BinaryOp::And, Box::new(left), Box::new(right)); + } + + left + } + + fn parse_equality(&mut self) -> Expr { + let mut left = self.parse_additive(); + + while let Token::Equals = self.current_token() { + self.advance(); + let right = self.parse_additive(); + left = Expr::OpBinary(BinaryOp::Equal, Box::new(left), Box::new(right)); + } + + left + } + + fn parse_additive(&mut self) -> Expr { + let mut left = self.parse_multiplicative(); + + while let Token::Plus = self.current_token() { + self.advance(); + let right = self.parse_multiplicative(); + left = Expr::OpBinary(BinaryOp::Add, Box::new(left), Box::new(right)); + } + + left + } + + fn parse_multiplicative(&mut self) -> Expr { + let mut left = self.parse_unary(); + + while let Token::Times = self.current_token() { + self.advance(); + let right = self.parse_unary(); + left = Expr::OpBinary(BinaryOp::Mul, Box::new(left), Box::new(right)); + } + + left + } + + fn parse_unary(&mut self) -> Expr { + match self.current_token() { + Token::Not => { + self.advance(); + let expr = self.parse_unary(); + Expr::OpUnary(UnaryOp::Not, Box::new(expr)) + } + Token::Minus => { + self.advance(); + let expr = self.parse_unary(); + Expr::OpUnary(UnaryOp::Neg, Box::new(expr)) + } + _ => self.parse_application(), + } + } + + fn parse_application(&mut self) -> Expr { + let mut left = self.parse_primary(); + + while let Token::Int(_) | Token::True | Token::False | Token::IdValue(_) | Token::LParen = + self.current_token() + { + let right = self.parse_primary(); + left = Expr::Application(Box::new(left), Box::new(right)); + } + + left + } + + fn parse_primary(&mut self) -> Expr { + match self.current_token().clone() { + Token::Int(n) => { + self.advance(); + Expr::Int(n) + } + Token::True => { + self.advance(); + Expr::Bool(true) + } + Token::False => { + self.advance(); + Expr::Bool(false) + } + Token::IdValue(name) => { + self.advance(); + Expr::Var(name) + } + Token::LParen => { + self.advance(); + let expr = self.parse_expr(); + self.expect(Token::RParen); + expr + } + _ => panic!("Unexpected token: {:?}", self.current_token()), + } + } +} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..b914f19 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,241 @@ +use crate::ast::Expr; +use crate::eval; +use crate::lexer::Lexer; +use crate::parser::Parser; +use crate::type_check::{self, Env, Typ}; +use std::sync::Mutex; + +// Use a mutex to ensure tests run sequentially +static TEST_LOCK: Mutex<()> = Mutex::new(()); + +fn parse(input: &str) -> Expr { + let mut lexer = Lexer::new(input); + let tokens = lexer.tokenize(); + let mut parser = Parser::new(tokens); + parser.parse() +} + +fn test_eval(src: &str, expected: Expr) { + let _lock = TEST_LOCK.lock().unwrap(); + let ast = parse(src); + let result = eval::eval(&ast).expect(&format!("test_eval failed for: {}", src)); + assert_eq!( + result, expected, + "test_eval failed for: {}\nExpected: {:?}\nGot: {:?}", + src, expected, result + ); + println!("test_eval OK {}", src); +} + +fn test_infer(src: &str, expected: Typ) { + let _lock = TEST_LOCK.lock().unwrap(); + type_check::reset_gensym(); + let ast = parse(src); + let env = Env::new(); + let (result, _) = + type_check::infer(&ast, &env).expect(&format!("test_infer failed for: {}", src)); + let result = type_check::lower(&result); + assert_eq!( + result, expected, + "test_infer failed for: {}\nExpected: {}\nGot: {}", + src, expected, result + ); + println!("test_infer OK {}", result); +} + +#[test] +fn test_eval_literals() { + test_eval("1", Expr::Int(1)); + test_eval("true", Expr::Bool(true)); +} + +#[test] +fn test_eval_arithmetic() { + test_eval("0 * 1 + 2 * 4", Expr::Int(8)); +} + +#[test] +fn test_eval_let() { + test_eval("let x = 1 in x", Expr::Int(1)); + test_eval("let x = 1 in x + 2", Expr::Int(3)); + test_eval("let x = 1 in x + (let y = 2 in y)", Expr::Int(3)); + test_eval("let x = 1 in x + (let x = 2 in x)", Expr::Int(3)); + test_eval("let x = 1 in x + (let y = 2 in y + x)", Expr::Int(4)); +} + +#[test] +fn test_eval_closure() { + test_eval( + "x -> x", + Expr::Closure("x".to_string(), Box::new(Expr::Var("x".to_string()))), + ); +} + +#[test] +fn test_eval_application() { + test_eval("let first = x -> y -> x in first 1 2", Expr::Int(1)); + test_eval("let second = x -> y -> y in second 1 2", Expr::Int(2)); + test_eval("(x -> x) 1", Expr::Int(1)); + test_eval("(x -> x) (y -> y) 1", Expr::Int(1)); + test_eval("(x -> x) (x -> x) 1", Expr::Int(1)); + test_eval("let identity = x -> x in identity 1", Expr::Int(1)); + test_eval( + "let identity = x -> x in let x = 2 in identity x", + Expr::Int(2), + ); + test_eval("let f = x -> y -> x + y in f 1 2", Expr::Int(3)); +} + +#[test] +fn test_eval_compose() { + test_eval( + "let add1 = x -> (x + 1) in let compose-twice = (f -> x -> f (f x)) in compose-twice add1 1", + Expr::Int(3), + ); +} + +#[test] +fn test_infer_literals() { + test_infer("true", Typ::TBool); + test_infer("1", Typ::TInt); +} + +#[test] +fn test_infer_unary_ops() { + test_infer("!true", Typ::TBool); + test_infer("let x = 1 in -x", Typ::TInt); +} + +#[test] +fn test_infer_let() { + test_infer("let x = 1 in x", Typ::TInt); +} + +#[test] +fn test_infer_identity() { + test_infer("let id = x -> x in id 1", Typ::TInt); + test_infer( + "x -> x", + Typ::TClosure( + Box::new(Typ::TVar("$1".to_string())), + Box::new(Typ::TVar("$1".to_string())), + ), + ); + test_infer( + "let id = x -> x in id", + Typ::TClosure( + Box::new(Typ::TVar("$1".to_string())), + Box::new(Typ::TVar("$1".to_string())), + ), + ); + test_infer( + "let id = x -> x in id id", + Typ::TClosure( + Box::new(Typ::TVar("$1".to_string())), + Box::new(Typ::TVar("$1".to_string())), + ), + ); + test_infer( + "let id = x -> x in (id id) (id id)", + Typ::TClosure( + Box::new(Typ::TVar("$1".to_string())), + Box::new(Typ::TVar("$1".to_string())), + ), + ); +} + +#[test] +fn test_infer_equality() { + test_infer( + "let f = x -> x = 1 in f", + Typ::TClosure(Box::new(Typ::TInt), Box::new(Typ::TBool)), + ); + test_infer( + "let f = x -> x = true in f", + Typ::TClosure(Box::new(Typ::TBool), Box::new(Typ::TBool)), + ); + test_infer( + "let id = x -> x in let f = x -> x = id in f", + Typ::TClosure( + Box::new(Typ::TClosure( + Box::new(Typ::TVar("$1".to_string())), + Box::new(Typ::TVar("$1".to_string())), + )), + Box::new(Typ::TBool), + ), + ); +} + +#[test] +fn test_infer_arithmetic_ops() { + test_infer( + "let negate = x -> -x in negate", + Typ::TClosure(Box::new(Typ::TInt), Box::new(Typ::TInt)), + ); + test_infer( + "let id = x -> 1 * x in id", + Typ::TClosure(Box::new(Typ::TInt), Box::new(Typ::TInt)), + ); + test_infer( + "let id = x -> 0 + x in id", + Typ::TClosure(Box::new(Typ::TInt), Box::new(Typ::TInt)), + ); +} + +#[test] +fn test_infer_boolean_ops() { + test_infer( + "let flip = x -> !x in flip", + Typ::TClosure(Box::new(Typ::TBool), Box::new(Typ::TBool)), + ); + test_infer( + "let id = x -> true && x in id", + Typ::TClosure(Box::new(Typ::TBool), Box::new(Typ::TBool)), + ); + test_infer( + "let id = x -> false || x in id", + Typ::TClosure(Box::new(Typ::TBool), Box::new(Typ::TBool)), + ); + test_infer("true && false", Typ::TBool); + test_infer("true || false", Typ::TBool); +} + +#[test] +fn test_infer_twice() { + test_infer( + "let twice = f -> x -> f (f x) in twice", + Typ::TClosure( + Box::new(Typ::TClosure( + Box::new(Typ::TVar("$1".to_string())), + Box::new(Typ::TVar("$1".to_string())), + )), + Box::new(Typ::TClosure( + Box::new(Typ::TVar("$1".to_string())), + Box::new(Typ::TVar("$1".to_string())), + )), + ), + ); +} + +#[test] +fn test_infer_application() { + test_infer( + "f -> f (f 1)", + Typ::TClosure( + Box::new(Typ::TClosure(Box::new(Typ::TInt), Box::new(Typ::TInt))), + Box::new(Typ::TInt), + ), + ); +} + +#[test] +fn test_infer_if() { + test_infer( + "x -> if x then 1 else 2", + Typ::TClosure(Box::new(Typ::TBool), Box::new(Typ::TInt)), + ); + test_infer( + "x -> if true then 1 else x", + Typ::TClosure(Box::new(Typ::TInt), Box::new(Typ::TInt)), + ); +} diff --git a/src/type_check.rs b/src/type_check.rs new file mode 100644 index 0000000..a2575be --- /dev/null +++ b/src/type_check.rs @@ -0,0 +1,342 @@ +use crate::ast::{BinaryOp, Expr, UnaryOp}; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::sync::atomic::{AtomicUsize, Ordering}; + +static GENSYM_COUNTER: AtomicUsize = AtomicUsize::new(0); + +fn gensym() -> String { + let id = GENSYM_COUNTER.fetch_add(1, Ordering::SeqCst); + format!("${}", id + 1) +} + +pub fn reset_gensym() { + GENSYM_COUNTER.store(0, Ordering::SeqCst); +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone, PartialEq)] +pub enum Typ { + TVar(String), + TClosure(Box, Box), + TInt, + TBool, +} + +impl fmt::Display for Typ { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Typ::TVar(x) => write!(f, "{}", x), + Typ::TClosure(t1, t2) => write!(f, "({} -> {})", t1, t2), + Typ::TInt => write!(f, "int"), + Typ::TBool => write!(f, "bool"), + } + } +} + +pub type Subst = HashMap; +pub type TypeScheme = (Vec, Typ); + +fn vars_in(t: &Typ) -> Vec { + let mut vars = HashSet::new(); + fn collect(t: &Typ, vars: &mut HashSet) { + match t { + Typ::TVar(x) => { + vars.insert(x.clone()); + } + Typ::TClosure(t1, t2) => { + collect(t1, vars); + collect(t2, vars); + } + _ => {} + } + } + collect(t, &mut vars); + let mut result: Vec<_> = vars.into_iter().collect(); + result.sort(); + result +} + +fn not_contains(a: &str, t: &Typ) -> bool { + match t { + Typ::TVar(b) if b == a => false, + Typ::TClosure(t1, t2) => not_contains(a, t1) && not_contains(a, t2), + _ => true, + } +} + +fn apply_typ(s: &Subst, t: &Typ) -> Typ { + match t { + Typ::TVar(x) => s.get(x).cloned().unwrap_or_else(|| t.clone()), + Typ::TClosure(t1, t2) => { + Typ::TClosure(Box::new(apply_typ(s, t1)), Box::new(apply_typ(s, t2))) + } + _ => t.clone(), + } +} + +fn compose(s1: &Subst, s2: &Subst) -> Subst { + let mut result = HashMap::new(); + + // Apply s2 to each type in s1 + for (x, t) in s1 { + result.insert(x.clone(), apply_typ(s2, t)); + } + + // Add substitutions from s2 that are not in s1 + for (x, t) in s2 { + if !result.contains_key(x) { + result.insert(x.clone(), t.clone()); + } + } + + result +} + +pub fn lower(t: &Typ) -> Typ { + let vars = vars_in(t); + let mut s = HashMap::new(); + + let mut counter = 0; + for var in vars { + counter += 1; + s.insert(var, Typ::TVar(format!("${}", counter))); + } + + apply_typ(&s, t) +} + +pub struct Env { + bindings: HashMap, +} + +impl Env { + pub fn new() -> Self { + Env { + bindings: HashMap::new(), + } + } + + pub fn lookup(&self, x: &str) -> Result { + self.bindings + .get(x) + .cloned() + .ok_or_else(|| format!("unbound variable {}", x)) + } + + pub fn extend(&self, x: String, ts: TypeScheme) -> Self { + let mut new_bindings = self.bindings.clone(); + new_bindings.insert(x, ts); + Env { + bindings: new_bindings, + } + } +} + +fn apply_env(s: &Subst, env: &Env) -> Env { + let mut new_bindings = HashMap::new(); + for (x, (vars, t)) in &env.bindings { + new_bindings.insert(x.clone(), (vars.clone(), apply_typ(s, t))); + } + Env { + bindings: new_bindings, + } +} + +fn instantiate(ts: &TypeScheme) -> Typ { + let (vars, t) = ts; + let mut s = HashMap::new(); + for var in vars { + s.insert(var.clone(), Typ::TVar(gensym())); + } + apply_typ(&s, t) +} + +fn generalize(env: &Env, t: &Typ) -> TypeScheme { + let vars_in_t: HashSet<_> = vars_in(t).into_iter().collect(); + let mut vars_in_env = HashSet::new(); + for (vars, _) in env.bindings.values() { + for var in vars { + vars_in_env.insert(var.clone()); + } + } + + let mut vars_not_in: Vec<_> = vars_in_t + .into_iter() + .filter(|x| !vars_in_env.contains(x)) + .collect(); + vars_not_in.sort(); + + (vars_not_in, t.clone()) +} + +fn unify(t1: &Typ, t2: &Typ) -> Result { + match (t1, t2) { + (Typ::TInt, Typ::TInt) => Ok(HashMap::new()), + (Typ::TBool, Typ::TBool) => Ok(HashMap::new()), + (Typ::TVar(x), Typ::TVar(y)) if x == y => Ok(HashMap::new()), + (Typ::TVar(x), t) if not_contains(x, t) => { + let mut s = HashMap::new(); + s.insert(x.clone(), t.clone()); + Ok(s) + } + (t, Typ::TVar(x)) if not_contains(x, t) => { + let mut s = HashMap::new(); + s.insert(x.clone(), t.clone()); + Ok(s) + } + (Typ::TClosure(lhs1, rhs1), Typ::TClosure(lhs2, rhs2)) => { + let s1 = unify(lhs1, lhs2)?; + let rhs1_applied = apply_typ(&s1, rhs1); + let rhs2_applied = apply_typ(&s1, rhs2); + let s2 = unify(&rhs1_applied, &rhs2_applied)?; + Ok(compose(&s1, &s2)) + } + _ => Err(format!("unification failed between {} and {}", t1, t2)), + } +} + +pub fn infer(e: &Expr, env: &Env) -> Result<(Typ, Subst), String> { + match e { + Expr::Var(x) => { + let ts = env.lookup(x)?; + let t = instantiate(&ts); + Ok((t, HashMap::new())) + } + Expr::Int(_) => Ok((Typ::TInt, HashMap::new())), + Expr::Bool(_) => Ok((Typ::TBool, HashMap::new())), + Expr::OpUnary(op, e) => { + let (t, s) = infer(e, env)?; + match op { + UnaryOp::Not => { + let expected = Typ::TBool; + match &t { + Typ::TVar(x) => { + let mut s_new = HashMap::new(); + s_new.insert(x.clone(), expected.clone()); + Ok((expected, compose(&s, &s_new))) + } + t if t == &expected => Ok((expected, s)), + _ => Err("must be bool".to_string()), + } + } + UnaryOp::Neg => { + let expected = Typ::TInt; + match &t { + Typ::TVar(x) => { + let mut s_new = HashMap::new(); + s_new.insert(x.clone(), expected.clone()); + Ok((expected, compose(&s, &s_new))) + } + t if t == &expected => Ok((expected, s)), + _ => Err("must be int".to_string()), + } + } + } + } + Expr::OpBinary(op, e1, e2) => match op { + BinaryOp::Equal => { + let (t1, s1) = infer(e1, env)?; + let (t2, s2) = infer(e2, env)?; + let s3 = unify(&t1, &t2)?; + let s4 = compose(&compose(&s1, &s2), &s3); + Ok((Typ::TBool, s4)) + } + BinaryOp::Add | BinaryOp::Mul => { + let expected = Typ::TInt; + let (t1, s1) = infer(e1, env)?; + let (t2, s2) = infer(e2, env)?; + let s3 = compose(&s1, &s2); + + let mut s_result = s3.clone(); + match (&t1, &t2) { + (Typ::TVar(x), Typ::TVar(y)) if x == y => { + s_result.insert(x.clone(), expected.clone()); + } + (Typ::TVar(x), Typ::TVar(y)) => { + s_result.insert(x.clone(), expected.clone()); + s_result.insert(y.clone(), expected.clone()); + } + (Typ::TVar(x), t2) if t2 == &expected => { + s_result.insert(x.clone(), expected.clone()); + } + (t1, Typ::TVar(y)) if t1 == &expected => { + s_result.insert(y.clone(), expected.clone()); + } + (t1, t2) if t1 == &expected && t2 == &expected => {} + _ => { + return Err(format!( + "both sides must be int but lhs is {} and rhs is {}", + t1, t2 + )); + } + } + Ok((expected, s_result)) + } + BinaryOp::And | BinaryOp::Or => { + let expected = Typ::TBool; + let (t1, s1) = infer(e1, env)?; + let (t2, s2) = infer(e2, env)?; + let s3 = compose(&s1, &s2); + + let mut s_result = s3.clone(); + match (&t1, &t2) { + (Typ::TVar(x), Typ::TVar(y)) if x == y => { + s_result.insert(x.clone(), expected.clone()); + } + (Typ::TVar(x), Typ::TVar(y)) => { + s_result.insert(x.clone(), expected.clone()); + s_result.insert(y.clone(), expected.clone()); + } + (Typ::TVar(x), t2) if t2 == &expected => { + s_result.insert(x.clone(), expected.clone()); + } + (t1, Typ::TVar(y)) if t1 == &expected => { + s_result.insert(y.clone(), expected.clone()); + } + (t1, t2) if t1 == &expected && t2 == &expected => {} + _ => { + return Err(format!( + "both sides must be bool but lhs is {} and rhs is {}", + t1, t2 + )); + } + } + Ok((expected, s_result)) + } + }, + Expr::Closure(x, e) => { + let a = Typ::TVar(gensym()); + let env_prime = env.extend(x.clone(), (vec![], a.clone())); + let (t, s) = infer(e, &env_prime)?; + let result_type = Typ::TClosure(Box::new(apply_typ(&s, &a)), Box::new(t)); + Ok((result_type, s)) + } + Expr::Application(e1, e2) => { + let a = Typ::TVar(gensym()); + let (t1, s1) = infer(e1, env)?; + let env1 = apply_env(&s1, env); + let (t2, s2) = infer(e2, &env1)?; + let t1_prime = apply_typ(&s2, &t1); + let t3 = Typ::TClosure(Box::new(t2), Box::new(a.clone())); + let s3 = unify(&t1_prime, &t3)?; + let s4 = compose(&compose(&s1, &s2), &s3); + Ok((apply_typ(&s4, &a), s4)) + } + Expr::Let(x, e1, e2) => { + let (t1, s1) = infer(e1, env)?; + let ts = generalize(env, &apply_typ(&s1, &t1)); + let env_prime = apply_env(&s1, env).extend(x.clone(), ts); + infer(e2, &env_prime) + } + Expr::If(e1, e2, e3) => { + let (t1, s1) = infer(e1, env)?; + let (t2, s2) = infer(e2, env)?; + let (t3, s3) = infer(e3, env)?; + let s4 = unify(&t1, &Typ::TBool)?; + let s5 = unify(&t2, &t3)?; + let s6 = compose(&compose(&compose(&compose(&s1, &s2), &s3), &s4), &s5); + Ok((apply_typ(&s6, &t2), s6)) + } + } +}