From 32a5c0ff0fc58aa6721c1e0ad41950bde2d66744 Mon Sep 17 00:00:00 2001 From: Griffin Smith Date: Sat, 13 Mar 2021 21:57:27 -0500 Subject: Add the start of a hindley-milner typechecker The beginning of a parse-don't-validate-based hindley-milner typechecker, which returns on success an IR where every AST node trivially knows its own type, and using those types to determine LLVM types in codegen. --- Cargo.lock | 1 + Cargo.toml | 1 + ach/.gitignore | 3 + src/ast/hir.rs | 246 ++++++++++++++++++++++ src/ast/mod.rs | 3 + src/codegen/llvm.rs | 76 ++++--- src/codegen/mod.rs | 5 +- src/commands/check.rs | 39 ++++ src/commands/eval.rs | 6 +- src/commands/mod.rs | 2 + src/common/env.rs | 24 ++- src/common/error.rs | 11 +- src/compiler.rs | 4 +- src/interpreter/mod.rs | 70 ++++--- src/interpreter/value.rs | 5 +- src/main.rs | 3 + src/parser/expr.rs | 25 ++- src/parser/macros.rs | 1 + src/parser/mod.rs | 5 +- src/tc/mod.rs | 528 +++++++++++++++++++++++++++++++++++++++++++++++ 20 files changed, 980 insertions(+), 78 deletions(-) create mode 100644 src/ast/hir.rs create mode 100644 src/commands/check.rs create mode 100644 src/tc/mod.rs diff --git a/Cargo.lock b/Cargo.lock index d8eaedeca181..8ec5ad6cf952 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,6 +9,7 @@ dependencies = [ "derive_more", "inkwell", "itertools", + "lazy_static", "llvm-sys", "nom", "nom-trace", diff --git a/Cargo.toml b/Cargo.toml index c9796a821586..2ac7d2540961 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ clap = "3.0.0-beta.2" derive_more = "0.99.11" inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } itertools = "0.10.0" +lazy_static = "1.4.0" llvm-sys = "110.0.1" nom = "6.1.2" nom-trace = { git = "https://github.com/glittershark/nom-trace", branch = "nom-6" } diff --git a/ach/.gitignore b/ach/.gitignore index e8423ae351b8..683a53a01f6c 100644 --- a/ach/.gitignore +++ b/ach/.gitignore @@ -1,2 +1,5 @@ *.ll *.o + +functions +simple diff --git 
a/src/ast/hir.rs b/src/ast/hir.rs new file mode 100644 index 000000000000..151ddd529872 --- /dev/null +++ b/src/ast/hir.rs @@ -0,0 +1,246 @@ +use itertools::Itertools; + +use super::{BinaryOperator, Ident, Literal, UnaryOperator}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Binding<'a, T> { + pub ident: Ident<'a>, + pub type_: T, + pub body: Expr<'a, T>, +} + +impl<'a, T> Binding<'a, T> { + fn to_owned(&self) -> Binding<'static, T> + where + T: Clone, + { + Binding { + ident: self.ident.to_owned(), + type_: self.type_.clone(), + body: self.body.to_owned(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Expr<'a, T> { + Ident(Ident<'a>, T), + + Literal(Literal, T), + + UnaryOp { + op: UnaryOperator, + rhs: Box>, + type_: T, + }, + + BinaryOp { + lhs: Box>, + op: BinaryOperator, + rhs: Box>, + type_: T, + }, + + Let { + bindings: Vec>, + body: Box>, + type_: T, + }, + + If { + condition: Box>, + then: Box>, + else_: Box>, + type_: T, + }, + + Fun { + args: Vec<(Ident<'a>, T)>, + body: Box>, + type_: T, + }, + + Call { + fun: Box>, + args: Vec>, + type_: T, + }, +} + +impl<'a, T> Expr<'a, T> { + pub fn type_(&self) -> &T { + match self { + Expr::Ident(_, t) => t, + Expr::Literal(_, t) => t, + Expr::UnaryOp { type_, .. } => type_, + Expr::BinaryOp { type_, .. } => type_, + Expr::Let { type_, .. } => type_, + Expr::If { type_, .. } => type_, + Expr::Fun { type_, .. } => type_, + Expr::Call { type_, .. 
} => type_, + } + } + + pub fn traverse_type(self, f: F) -> Result, E> + where + F: Fn(T) -> Result + Clone, + { + match self { + Expr::Ident(id, t) => Ok(Expr::Ident(id, f(t)?)), + Expr::Literal(lit, t) => Ok(Expr::Literal(lit, f(t)?)), + Expr::UnaryOp { op, rhs, type_ } => Ok(Expr::UnaryOp { + op, + rhs: Box::new(rhs.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::BinaryOp { + lhs, + op, + rhs, + type_, + } => Ok(Expr::BinaryOp { + lhs: Box::new(lhs.traverse_type(f.clone())?), + op, + rhs: Box::new(rhs.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Let { + bindings, + body, + type_, + } => Ok(Expr::Let { + bindings: bindings + .into_iter() + .map(|Binding { ident, type_, body }| { + Ok(Binding { + ident, + type_: f(type_)?, + body: body.traverse_type(f.clone())?, + }) + }) + .collect::, E>>()?, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::If { + condition, + then, + else_, + type_, + } => Ok(Expr::If { + condition: Box::new(condition.traverse_type(f.clone())?), + then: Box::new(then.traverse_type(f.clone())?), + else_: Box::new(else_.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Fun { args, body, type_ } => Ok(Expr::Fun { + args: args + .into_iter() + .map(|(id, t)| Ok((id, f.clone()(t)?))) + .collect::, E>>()?, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Call { fun, args, type_ } => Ok(Expr::Call { + fun: Box::new(fun.traverse_type(f.clone())?), + args: args + .into_iter() + .map(|e| e.traverse_type(f.clone())) + .collect::, E>>()?, + type_: f(type_)?, + }), + } + } + + pub fn to_owned(&self) -> Expr<'static, T> + where + T: Clone, + { + match self { + Expr::Ident(id, t) => Expr::Ident(id.to_owned(), t.clone()), + Expr::Literal(lit, t) => Expr::Literal(lit.clone(), t.clone()), + Expr::UnaryOp { op, rhs, type_ } => Expr::UnaryOp { + op: *op, + rhs: Box::new((**rhs).to_owned()), + type_: type_.clone(), + }, + Expr::BinaryOp { + lhs, + op, + 
rhs, + type_, + } => Expr::BinaryOp { + lhs: Box::new((**lhs).to_owned()), + op: *op, + rhs: Box::new((**rhs).to_owned()), + type_: type_.clone(), + }, + Expr::Let { + bindings, + body, + type_, + } => Expr::Let { + bindings: bindings.into_iter().map(|b| b.to_owned()).collect(), + body: Box::new((**body).to_owned()), + type_: type_.clone(), + }, + Expr::If { + condition, + then, + else_, + type_, + } => Expr::If { + condition: Box::new((**condition).to_owned()), + then: Box::new((**then).to_owned()), + else_: Box::new((**else_).to_owned()), + type_: type_.clone(), + }, + Expr::Fun { args, body, type_ } => Expr::Fun { + args: args + .into_iter() + .map(|(id, t)| (id.to_owned(), t.clone())) + .collect(), + body: Box::new((**body).to_owned()), + type_: type_.clone(), + }, + Expr::Call { fun, args, type_ } => Expr::Call { + fun: Box::new((**fun).to_owned()), + args: args.into_iter().map(|e| e.to_owned()).collect(), + type_: type_.clone(), + }, + } + } +} + +pub enum Decl<'a, T> { + Fun { + name: Ident<'a>, + args: Vec<(Ident<'a>, T)>, + body: Box>, + type_: T, + }, +} + +impl<'a, T> Decl<'a, T> { + pub fn traverse_type(self, f: F) -> Result, E> + where + F: Fn(T) -> Result + Clone, + { + match self { + Decl::Fun { + name, + args, + body, + type_, + } => Ok(Decl::Fun { + name, + args: args + .into_iter() + .map(|(id, t)| Ok((id, f(t)?))) + .try_collect()?, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + } + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index dc22ac3cdb56..cef366d16e04 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,3 +1,5 @@ +pub(crate) mod hir; + use std::borrow::Cow; use std::convert::TryFrom; use std::fmt::{self, Display, Formatter}; @@ -107,6 +109,7 @@ pub enum UnaryOperator { #[derive(Debug, PartialEq, Eq, Clone)] pub enum Literal { Int(u64), + Bool(bool), } #[derive(Debug, PartialEq, Eq, Clone)] diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs index 1f4a457cd81b..5b5db90a1ad8 100644 --- 
a/src/codegen/llvm.rs +++ b/src/codegen/llvm.rs @@ -7,12 +7,13 @@ use inkwell::builder::Builder; pub use inkwell::context::Context; use inkwell::module::Module; use inkwell::support::LLVMString; -use inkwell::types::FunctionType; +use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType}; use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue}; use inkwell::IntPredicate; use thiserror::Error; -use crate::ast::{BinaryOperator, Binding, Decl, Expr, Fun, Ident, Literal, UnaryOperator}; +use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator}; use crate::common::env::Env; #[derive(Debug, PartialEq, Eq, Error)] @@ -36,7 +37,7 @@ pub struct Codegen<'ctx, 'ast> { context: &'ctx Context, pub module: Module<'ctx>, builder: Builder<'ctx>, - env: Env<'ast, AnyValueEnum<'ctx>>, + env: Env<&'ast Ident<'ast>, AnyValueEnum<'ctx>>, function_stack: Vec>, identifier_counter: u32, } @@ -77,18 +78,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { .append_basic_block(*self.function_stack.last().unwrap(), name) } - pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result> { + pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result> { match expr { - Expr::Ident(id) => self + Expr::Ident(id, _) => self .env .resolve(id) .cloned() .ok_or_else(|| Error::UndefinedVariable(id.to_owned())), - Expr::Literal(Literal::Int(i)) => { - let ty = self.context.i64_type(); - Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))) + Expr::Literal(lit, ty) => { + let ty = self.codegen_int_type(ty); + match lit { + Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))), + Literal::Bool(b) => Ok(AnyValueEnum::IntValue( + ty.const_int(if *b { 1 } else { 0 }, false), + )), + } } - Expr::UnaryOp { op, rhs } => { + Expr::UnaryOp { op, rhs, .. 
} => { let rhs = self.codegen_expr(rhs)?; match op { UnaryOperator::Not => unimplemented!(), @@ -97,7 +103,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { )), } } - Expr::BinaryOp { lhs, op, rhs } => { + Expr::BinaryOp { lhs, op, rhs, .. } => { let lhs = self.codegen_expr(lhs)?; let rhs = self.codegen_expr(rhs)?; match op { @@ -135,7 +141,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { BinaryOperator::Neq => todo!(), } } - Expr::Let { bindings, body } => { + Expr::Let { bindings, body, .. } => { self.env.push(); for Binding { ident, body, .. } in bindings { let val = self.codegen_expr(body)?; @@ -149,6 +155,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { condition, then, else_, + type_, } => { let then_block = self.append_basic_block("then"); let else_block = self.append_basic_block("else"); @@ -168,15 +175,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.builder.build_unconditional_branch(join_block); self.builder.position_at_end(join_block); - let phi = self.builder.build_phi(self.context.i64_type(), "join"); + let phi = self.builder.build_phi(self.codegen_type(type_), "join"); phi.add_incoming(&[ (&BasicValueEnum::try_from(then_res).unwrap(), then_block), (&BasicValueEnum::try_from(else_res).unwrap(), else_block), ]); Ok(phi.as_basic_value().into()) } - Expr::Call { fun, args } => { - if let Expr::Ident(id) = &**fun { + Expr::Call { fun, args, .. } => { + if let Expr::Ident(id, _) = &**fun { let function = self .module .get_function(id.into()) @@ -197,8 +204,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { todo!() } } - Expr::Fun(fun) => { - let Fun { args, body } = &**fun; + Expr::Fun { args, body, .. } => { let fname = self.fresh_ident("f"); let cur_block = self.builder.get_insert_block().unwrap(); let env = self.env.save(); // TODO: closures @@ -207,29 +213,27 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.env.restore(env); Ok(function.into()) } - Expr::Ascription { expr, .. 
} => self.codegen_expr(expr), } } pub fn codegen_function( &mut self, name: &str, - args: &'ast [Ident<'ast>], - body: &'ast Expr<'ast>, + args: &'ast [(Ident<'ast>, Type)], + body: &'ast Expr<'ast, Type>, ) -> Result> { - let i64_type = self.context.i64_type(); self.new_function( name, - i64_type.fn_type( + self.codegen_type(body.type_()).fn_type( args.iter() - .map(|_| i64_type.into()) + .map(|(_, at)| self.codegen_type(at)) .collect::>() .as_slice(), false, ), ); self.env.push(); - for (i, arg) in args.iter().enumerate() { + for (i, (arg, _)) in args.iter().enumerate() { self.env.set( arg, self.cur_function().get_nth_param(i as u32).unwrap().into(), @@ -240,11 +244,10 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { Ok(self.finish_function(&res)) } - pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> { + pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast, Type>) -> Result<()> { match decl { Decl::Fun { - name, - body: Fun { args, body }, + name, args, body, .. } => { self.codegen_function(name.into(), args, body)?; Ok(()) @@ -252,13 +255,28 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { } } - pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> { + pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> { self.new_function("main", self.context.i64_type().fn_type(&[], false)); let res = self.codegen_expr(expr)?.try_into().unwrap(); - self.finish_function(&res); + if *expr.type_() != Type::Int { + self.builder + .build_return(Some(&self.context.i64_type().const_int(0, false))); + } else { + self.finish_function(&res); + } Ok(()) } + fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> { + // TODO + self.context.i64_type().into() + } + + fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> { + // TODO + self.context.i64_type() + } + pub fn print_to_file

(&self, path: P) -> Result<()> where P: AsRef, @@ -299,6 +317,8 @@ mod tests { fn jit_eval(expr: &str) -> anyhow::Result { let expr = crate::parser::expr(expr).unwrap().1; + let expr = crate::tc::typecheck_expr(expr).unwrap(); + let context = Context::create(); let mut codegen = Codegen::new(&context, "test"); let execution_engine = codegen diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 6f95d90b45a1..8ef057dba04f 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -4,10 +4,11 @@ use inkwell::execution_engine::JitFunction; use inkwell::OptimizationLevel; pub use llvm::*; -use crate::ast::Expr; +use crate::ast::hir::Expr; +use crate::ast::Type; use crate::common::Result; -pub fn jit_eval(expr: &Expr) -> Result { +pub fn jit_eval(expr: &Expr) -> Result { let context = Context::create(); let mut codegen = Codegen::new(&context, "eval"); let execution_engine = codegen diff --git a/src/commands/check.rs b/src/commands/check.rs new file mode 100644 index 000000000000..40de288a282c --- /dev/null +++ b/src/commands/check.rs @@ -0,0 +1,39 @@ +use clap::Clap; +use std::path::PathBuf; + +use crate::ast::Type; +use crate::{parser, tc, Result}; + +/// Typecheck a file or expression +#[derive(Clap)] +pub struct Check { + /// File to check + path: Option, + + /// Expression to check + #[clap(long, short = 'e')] + expr: Option, +} + +fn run_expr(expr: String) -> Result { + let (_, parsed) = parser::expr(&expr)?; + let hir_expr = tc::typecheck_expr(parsed)?; + Ok(hir_expr.type_().clone()) +} + +fn run_path(path: PathBuf) -> Result { + todo!() +} + +impl Check { + pub fn run(self) -> Result<()> { + let type_ = match (self.path, self.expr) { + (None, None) => Err("Must specify either a file or expression to check".into()), + (Some(_), Some(_)) => Err("Cannot specify both a file and expression to check".into()), + (None, Some(expr)) => run_expr(expr), + (Some(path), None) => run_path(path), + }?; + println!("type: {}", type_); + Ok(()) + } +} diff --git 
a/src/commands/eval.rs b/src/commands/eval.rs index 112bee64625b..61a712c08a8e 100644 --- a/src/commands/eval.rs +++ b/src/commands/eval.rs @@ -3,6 +3,7 @@ use clap::Clap; use crate::codegen; use crate::interpreter; use crate::parser; +use crate::tc; use crate::Result; /// Evaluate an expression and print its result @@ -19,10 +20,11 @@ pub struct Eval { impl Eval { pub fn run(self) -> Result<()> { let (_, parsed) = parser::expr(&self.expr)?; + let hir = tc::typecheck_expr(parsed)?; let result = if self.jit { - codegen::jit_eval::(&parsed)?.into() + codegen::jit_eval::(&hir)?.into() } else { - interpreter::eval(&parsed)? + interpreter::eval(&hir)? }; println!("{}", result); Ok(()) diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 9c0038dabfb1..fd0a822708c2 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -1,5 +1,7 @@ +pub mod check; pub mod compile; pub mod eval; +pub use check::Check; pub use compile::Compile; pub use eval::Eval; diff --git a/src/common/env.rs b/src/common/env.rs index f499323639e3..59a5e46c466f 100644 --- a/src/common/env.rs +++ b/src/common/env.rs @@ -1,19 +1,25 @@ +use std::borrow::Borrow; use std::collections::HashMap; +use std::hash::Hash; use std::mem; -use crate::ast::Ident; - /// A lexical environment #[derive(Debug, PartialEq, Eq)] -pub struct Env<'ast, V>(Vec, V>>); +pub struct Env(Vec>); -impl<'ast, V> Default for Env<'ast, V> { +impl Default for Env +where + K: Eq + Hash, +{ fn default() -> Self { Self::new() } } -impl<'ast, V> Env<'ast, V> { +impl Env +where + K: Eq + Hash, +{ pub fn new() -> Self { Self(vec![Default::default()]) } @@ -34,11 +40,15 @@ impl<'ast, V> Env<'ast, V> { *self = saved; } - pub fn set(&mut self, k: &'ast Ident<'ast>, v: V) { + pub fn set(&mut self, k: K, v: V) { self.0.last_mut().unwrap().insert(k, v); } - pub fn resolve<'a>(&'a self, k: &'ast Ident<'ast>) -> Option<&'a V> { + pub fn resolve<'a, Q>(&'a self, k: &Q) -> Option<&'a V> + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { for 
ctx in self.0.iter().rev() { if let Some(res) = ctx.get(k) { return Some(res); diff --git a/src/common/error.rs b/src/common/error.rs index f3f3023ceaf8..51575a895e91 100644 --- a/src/common/error.rs +++ b/src/common/error.rs @@ -2,7 +2,7 @@ use std::{io, result}; use thiserror::Error; -use crate::{codegen, interpreter, parser}; +use crate::{codegen, interpreter, parser, tc}; #[derive(Error, Debug)] pub enum Error { @@ -18,6 +18,9 @@ pub enum Error { #[error("Compile error: {0}")] CodegenError(#[from] codegen::Error), + #[error("Type error: {0}")] + TypeError(#[from] tc::Error), + #[error("{0}")] Message(String), } @@ -28,6 +31,12 @@ impl From for Error { } } +impl<'a> From<&'a str> for Error { + fn from(s: &'a str) -> Self { + Self::Message(s.to_owned()) + } +} + impl<'a> From>> for Error { fn from(e: nom::Err>) -> Self { use nom::error::Error as NomError; diff --git a/src/compiler.rs b/src/compiler.rs index 5f8e1ef4fa03..f925b267df57 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -8,7 +8,7 @@ use test_strategy::Arbitrary; use crate::codegen::{self, Codegen}; use crate::common::Result; -use crate::parser; +use crate::{parser, tc}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)] pub enum OutputFormat { @@ -55,6 +55,8 @@ pub struct CompilerOptions { pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> { let src = fs::read_to_string(input)?; let (_, decls) = parser::toplevel(&src)?; // TODO: statements + let decls = tc::typecheck_toplevel(decls)?; + let context = codegen::Context::create(); let mut codegen = Codegen::new( &context, diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index 00421ee90dc8..85a8928cbf9a 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -3,14 +3,13 @@ mod value; pub use self::error::{Error, Result}; pub use self::value::{Function, Value}; -use crate::ast::{ - BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator, -}; +use 
crate::ast::hir::{Binding, Expr}; +use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator}; use crate::common::env::Env; #[derive(Debug, Default)] pub struct Interpreter<'a> { - env: Env<'a, Value<'a>>, + env: Env<&'a Ident<'a>, Value<'a>>, } impl<'a> Interpreter<'a> { @@ -25,18 +24,19 @@ impl<'a> Interpreter<'a> { .ok_or_else(|| Error::UndefinedVariable(var.to_owned())) } - pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result> { - match expr { - Expr::Ident(id) => self.resolve(id), - Expr::Literal(Literal::Int(i)) => Ok((*i).into()), - Expr::UnaryOp { op, rhs } => { + pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result> { + let res = match expr { + Expr::Ident(id, _) => self.resolve(id), + Expr::Literal(Literal::Int(i), _) => Ok((*i).into()), + Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()), + Expr::UnaryOp { op, rhs, .. } => { let rhs = self.eval(rhs)?; match op { UnaryOperator::Neg => -rhs, _ => unimplemented!(), } } - Expr::BinaryOp { lhs, op, rhs } => { + Expr::BinaryOp { lhs, op, rhs, .. } => { let lhs = self.eval(lhs)?; let rhs = self.eval(rhs)?; match op { @@ -49,7 +49,7 @@ impl<'a> Interpreter<'a> { BinaryOperator::Neq => todo!(), } } - Expr::Let { bindings, body } => { + Expr::Let { bindings, body, .. } => { self.env.push(); for Binding { ident, body, .. } in bindings { let val = self.eval(body)?; @@ -63,6 +63,7 @@ impl<'a> Interpreter<'a> { condition, then, else_, + .. } => { let condition = self.eval(condition)?; if *(condition.as_type::()?) { @@ -71,7 +72,7 @@ impl<'a> Interpreter<'a> { self.eval(else_) } } - Expr::Call { ref fun, args } => { + Expr::Call { ref fun, args, .. 
} => { let fun = self.eval(fun)?; let expected_type = FunctionType { args: args.iter().map(|_| Type::Int).collect(), @@ -94,21 +95,26 @@ impl<'a> Interpreter<'a> { } Ok(Value::from(*interpreter.eval(body)?.as_type::()?)) } - Expr::Fun(fun) => Ok(Value::from(value::Function { - // TODO - type_: FunctionType { - args: fun.args.iter().map(|_| Type::Int).collect(), - ret: Box::new(Type::Int), - }, - args: fun.args.iter().map(|arg| arg.to_owned()).collect(), - body: fun.body.to_owned(), - })), - Expr::Ascription { expr, .. } => self.eval(expr), - } + Expr::Fun { args, body, type_ } => { + let type_ = match type_ { + Type::Function(ft) => ft.clone(), + _ => unreachable!("Function expression without function type"), + }; + + Ok(Value::from(value::Function { + // TODO + type_, + args: args.iter().map(|(arg, _)| arg.to_owned()).collect(), + body: (**body).to_owned(), + })) + } + }?; + debug_assert_eq!(&res.type_(), expr.type_()); + Ok(res) } } -pub fn eval<'a>(expr: &'a Expr<'a>) -> Result { +pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result { let mut interpreter = Interpreter::new(); interpreter.eval(expr) } @@ -121,17 +127,18 @@ mod tests { use super::*; use BinaryOperator::*; - fn int_lit(i: u64) -> Box> { - Box::new(Expr::Literal(Literal::Int(i))) + fn int_lit(i: u64) -> Box> { + Box::new(Expr::Literal(Literal::Int(i), Type::Int)) } - fn parse_eval(src: &str) -> T + fn do_eval(src: &str) -> T where for<'a> &'a T: TryFrom<&'a Val<'a>>, T: Clone + TypeOf, { let expr = crate::parser::expr(src).unwrap().1; - let res = eval(&expr).unwrap(); + let hir = crate::tc::typecheck_expr(expr).unwrap(); + let res = eval(&hir).unwrap(); res.as_type::().unwrap().clone() } @@ -141,6 +148,7 @@ mod tests { lhs: int_lit(1), op: Mul, rhs: int_lit(2), + type_: Type::Int, }; let res = eval(&expr).unwrap(); assert_eq!(*res.as_type::().unwrap(), 2); @@ -148,19 +156,19 @@ mod tests { #[test] fn variable_shadowing() { - let res = parse_eval::("let x = 1 in (let x = 2 in x) + x"); + let res = 
do_eval::("let x = 1 in (let x = 2 in x) + x"); assert_eq!(res, 3); } #[test] fn conditional_with_equals() { - let res = parse_eval::("let x = 1 in if x == 1 then 2 else 4"); + let res = do_eval::("let x = 1 in if x == 1 then 2 else 4"); assert_eq!(res, 2); } #[test] fn function_call() { - let res = parse_eval::("let id = fn x = x in id 1"); + let res = do_eval::("let id = fn x = x in id 1"); assert_eq!(res, 1); } } diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs index 496e9c4230de..5e55825160cd 100644 --- a/src/interpreter/value.rs +++ b/src/interpreter/value.rs @@ -6,13 +6,14 @@ use std::rc::Rc; use derive_more::{Deref, From, TryInto}; use super::{Error, Result}; -use crate::ast::{Expr, FunctionType, Ident, Type}; +use crate::ast::hir::Expr; +use crate::ast::{FunctionType, Ident, Type}; #[derive(Debug, Clone)] pub struct Function<'a> { pub type_: FunctionType, pub args: Vec>, - pub body: Expr<'a>, + pub body: Expr<'a, Type>, } #[derive(From, TryInto)] diff --git a/src/main.rs b/src/main.rs index b539ebbb3d99..d5b00d6b6c46 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ pub mod compiler; pub mod interpreter; #[macro_use] pub mod parser; +pub mod tc; pub use common::{Error, Result}; @@ -21,6 +22,7 @@ struct Opts { enum Command { Eval(commands::Eval), Compile(commands::Compile), + Check(commands::Check), } fn main() -> anyhow::Result<()> { @@ -28,5 +30,6 @@ fn main() -> anyhow::Result<()> { match opts.subcommand { Command::Eval(eval) => Ok(eval.run()?), Command::Compile(compile) => Ok(compile.run()?), + Command::Check(check) => Ok(check.run()?), } } diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 2fda3e93fae9..73c873b5b304 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -156,7 +156,14 @@ where named!(int(&str) -> Literal, map!(flat_map!(digit1, parse_to!(u64)), Literal::Int)); -named!(literal(&str) -> Expr, map!(alt!(int), Expr::Literal)); +named!(bool_(&str) -> Literal, alt!( + tag!("true") => { |_| 
Literal::Bool(true) } | + tag!("false") => { |_| Literal::Bool(false) } +)); + +named!(literal(&str) -> Literal, alt!(int | bool_)); + +named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal)); named!(binding(&str) -> Binding, do_parse!( multispace0 @@ -262,7 +269,7 @@ named!(fun_expr(&str) -> Expr, do_parse!( named!(arg(&str) -> Expr, alt!( ident_expr | - literal | + literal_expr | paren_expr )); @@ -280,7 +287,7 @@ named!(simple_expr_unascripted(&str) -> Expr, alt!( let_ | if_ | fun_expr | - literal | + literal_expr | ident_expr )); @@ -399,6 +406,18 @@ pub(crate) mod tests { } } + #[test] + fn bools() { + assert_eq!( + test_parse!(expr, "true"), + Expr::Literal(Literal::Bool(true)) + ); + assert_eq!( + test_parse!(expr, "false"), + Expr::Literal(Literal::Bool(false)) + ); + } + #[test] fn let_complex() { let res = test_parse!(expr, "let x = 1; y = x * 7 in (x + y) * 4"); diff --git a/src/parser/macros.rs b/src/parser/macros.rs index 60db5133dc0f..406e5c0e699e 100644 --- a/src/parser/macros.rs +++ b/src/parser/macros.rs @@ -1,3 +1,4 @@ +#[cfg(test)] #[macro_use] macro_rules! 
test_parse { ($parser: ident, $src: expr) => {{ diff --git a/src/parser/mod.rs b/src/parser/mod.rs index af7dff6ff213..0251d02df464 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -14,7 +14,10 @@ pub use type_::type_; pub type Error = nom::Err>; pub(crate) fn is_reserved(s: &str) -> bool { - matches!(s, "if" | "then" | "else" | "let" | "in" | "fn") + matches!( + s, + "if" | "then" | "else" | "let" | "in" | "fn" | "int" | "float" | "bool" | "true" | "false" + ) } pub(crate) fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E> diff --git a/src/tc/mod.rs b/src/tc/mod.rs new file mode 100644 index 000000000000..b5acfac2b426 --- /dev/null +++ b/src/tc/mod.rs @@ -0,0 +1,528 @@ +use derive_more::From; +use itertools::Itertools; +use std::collections::HashMap; +use std::convert::{TryFrom, TryInto}; +use std::fmt::{self, Display}; +use std::result; +use thiserror::Error; + +use crate::ast::{self, hir, BinaryOperator, Ident, Literal}; +use crate::common::env::Env; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Undefined variable {0}")] + UndefinedVariable(Ident<'static>), + + #[error("Mismatched types: expected {expected}, but got {actual}")] + TypeMismatch { expected: Type, actual: Type }, + + #[error("Mismatched types, expected numeric type, but got {0}")] + NonNumeric(Type), + + #[error("Ambiguous type {0}")] + AmbiguousType(TyVar), +} + +pub type Result = result::Result; + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] +pub struct TyVar(u64); + +impl Display for TyVar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "t{}", self.0) + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct NullaryType(String); + +impl Display for NullaryType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum PrimType { + Int, + Float, + Bool, +} + +impl From for ast::Type { + fn from(pr: PrimType) -> Self { + match pr { + 
PrimType::Int => ast::Type::Int, + PrimType::Float => ast::Type::Float, + PrimType::Bool => ast::Type::Bool, + } + } +} + +impl Display for PrimType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PrimType::Int => f.write_str("int"), + PrimType::Float => f.write_str("float"), + PrimType::Bool => f.write_str("bool"), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, From)] +pub enum Type { + #[from(ignore)] + Univ(TyVar), + #[from(ignore)] + Exist(TyVar), + Nullary(NullaryType), + Prim(PrimType), + Fun { + args: Vec, + ret: Box, + }, +} + +impl PartialEq for Type { + fn eq(&self, other: &ast::Type) -> bool { + match (self, other) { + (Type::Univ(_), _) => todo!(), + (Type::Exist(_), _) => false, + (Type::Nullary(_), _) => todo!(), + (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, + (Type::Fun { args, ret }, ast::Type::Function(ft)) => { + *args == ft.args && (**ret).eq(&*ft.ret) + } + (Type::Fun { .. }, _) => false, + } + } +} + +impl TryFrom for ast::Type { + type Error = Type; + + fn try_from(value: Type) -> result::Result { + match value { + Type::Univ(_) => todo!(), + Type::Exist(_) => Err(value), + Type::Nullary(_) => todo!(), + Type::Prim(p) => Ok(p.into()), + Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType { + args: args + .clone() + .into_iter() + .map(Self::try_from) + .try_collect() + .map_err(|_| value.clone())?, + ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?), + })), + } + } +} + +const INT: Type = Type::Prim(PrimType::Int); +const FLOAT: Type = Type::Prim(PrimType::Float); +const BOOL: Type = Type::Prim(PrimType::Bool); + +impl Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Nullary(nt) => nt.fmt(f), + Type::Prim(p) => p.fmt(f), + Type::Univ(TyVar(n)) => write!(f, "∀{}", n), + Type::Exist(TyVar(n)) => write!(f, "∃{}", n), + Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret), + } + 
}
}

/// Converts a surface-syntax type into the typechecker's internal representation.
impl From<ast::Type> for Type {
    fn from(type_: ast::Type) -> Self {
        match type_ {
            ast::Type::Int => INT,
            ast::Type::Float => FLOAT,
            ast::Type::Bool => BOOL,
            ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
                args: args.into_iter().map(Self::from).collect(),
                ret: Box::new(Self::from(*ret)),
            },
        }
    }
}

/// Mutable state threaded through a single type-inference run.
struct Typechecker<'ast> {
    /// Last type-variable id handed out by [`Typechecker::fresh_tv`].
    ty_var_counter: u64,
    /// Bindings recorded by `unify` for existential type variables.
    ctx: HashMap<TyVar, Type>,
    /// Lexically scoped mapping from identifiers to their inferred types.
    env: Env<Ident<'ast>, Type>,
}

impl<'ast> Typechecker<'ast> {
    /// Creates a typechecker with no type variables allocated and an empty environment.
    fn new() -> Self {
        Self {
            ty_var_counter: 0,
            ctx: Default::default(),
            env: Default::default(),
        }
    }

    /// Infers the type of `expr`, returning the expression annotated with (possibly still
    /// existential) types.
    ///
    /// # Errors
    ///
    /// Returns `Error::UndefinedVariable` for identifiers missing from the environment, and
    /// propagates any error produced by `unify`.
    ///
    /// # Panics
    ///
    /// Unary operators, `Div` and `Pow` are not implemented yet and hit `todo!()`.
    pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
        match expr {
            ast::Expr::Ident(ident) => {
                let type_ = self
                    .env
                    .resolve(&ident)
                    .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))?
                    .clone();
                Ok(hir::Expr::Ident(ident, type_))
            }
            ast::Expr::Literal(lit) => {
                let type_ = match lit {
                    Literal::Int(_) => Type::Prim(PrimType::Int),
                    Literal::Bool(_) => Type::Prim(PrimType::Bool),
                };
                Ok(hir::Expr::Literal(lit, type_))
            }
            ast::Expr::UnaryOp { op, rhs } => todo!(),
            ast::Expr::BinaryOp { lhs, op, rhs } => {
                let lhs = self.tc_expr(*lhs)?;
                let rhs = self.tc_expr(*rhs)?;
                let type_ = match op {
                    BinaryOperator::Equ | BinaryOperator::Neq => {
                        // Both operands must agree, but a comparison always yields a bool.
                        self.unify(lhs.type_(), rhs.type_())?;
                        Type::Prim(PrimType::Bool)
                    }
                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
                        // TODO: reject non-numeric operands here (an `Error::NonNumeric`-style
                        // check); until then `true + false` typechecks as `bool`.
                        self.unify(lhs.type_(), rhs.type_())?
                    }
                    BinaryOperator::Div => todo!(),
                    BinaryOperator::Pow => todo!(),
                };
                Ok(hir::Expr::BinaryOp {
                    lhs: Box::new(lhs),
                    op,
                    rhs: Box::new(rhs),
                    type_,
                })
            }
            ast::Expr::Let { bindings, body } => {
                self.env.push();
                let bindings = bindings
                    .into_iter()
                    .map(
                        |ast::Binding { ident, type_, body }| -> Result<hir::Binding<'ast, Type>> {
                            let body = self.tc_expr(body)?;
                            // An explicit annotation must agree with the inferred type.
                            if let Some(type_) = type_ {
                                self.unify(body.type_(), &type_.into())?;
                            }
                            // Bind right away so that later bindings can see this one.
                            self.env.set(ident.clone(), body.type_().clone());
                            Ok(hir::Binding {
                                ident,
                                type_: body.type_().clone(),
                                body,
                            })
                        },
                    )
                    .collect::<Result<Vec<_>>>()?;
                let body = self.tc_expr(*body)?;
                self.env.pop();
                Ok(hir::Expr::Let {
                    bindings,
                    type_: body.type_().clone(),
                    body: Box::new(body),
                })
            }
            ast::Expr::If {
                condition,
                then,
                else_,
            } => {
                let condition = self.tc_expr(*condition)?;
                self.unify(&Type::Prim(PrimType::Bool), condition.type_())?;
                let then = self.tc_expr(*then)?;
                let else_ = self.tc_expr(*else_)?;
                // Both branches must have the same type, which is the type of the whole `if`.
                let type_ = self.unify(then.type_(), else_.type_())?;
                Ok(hir::Expr::If {
                    condition: Box::new(condition),
                    then: Box::new(then),
                    else_: Box::new(else_),
                    type_,
                })
            }
            ast::Expr::Fun(f) => {
                let ast::Fun { args, body } = *f;
                self.env.push();
                // Every parameter starts out as a fresh existential; checking the body refines it.
                let args: Vec<_> = args
                    .into_iter()
                    .map(|id| {
                        let ty = self.fresh_ex();
                        self.env.set(id.clone(), ty.clone());
                        (id, ty)
                    })
                    .collect();
                let body = self.tc_expr(body)?;
                self.env.pop();
                Ok(hir::Expr::Fun {
                    type_: Type::Fun {
                        args: args.iter().map(|(_, ty)| ty.clone()).collect(),
                        ret: Box::new(body.type_().clone()),
                    },
                    args,
                    body: Box::new(body),
                })
            }
            ast::Expr::Call { fun, args } => {
                // Build the function type this call site expects, then unify it with the callee.
                let ret_ty = self.fresh_ex();
                let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>();
                let ft = Type::Fun {
                    args: arg_tys.clone(),
                    ret: Box::new(ret_ty.clone()),
                };
                let fun = self.tc_expr(*fun)?;
                self.unify(&ft, fun.type_())?;
                let args = args
                    .into_iter()
                    .zip(arg_tys)
                    .map(|(arg, ty)| {
                        let arg = self.tc_expr(arg)?;
                        self.unify(&ty, arg.type_())?;
                        Ok(arg)
                    })
                    .try_collect()?;
                Ok(hir::Expr::Call {
                    fun: Box::new(fun),
                    args,
                    type_: ret_ty,
                })
            }
            ast::Expr::Ascription { expr, type_ } => {
                // The ascription only constrains the inner expression; no HIR node is emitted.
                let expr = self.tc_expr(*expr)?;
                self.unify(expr.type_(), &type_.into())?;
                Ok(expr)
            }
        }
    }

    /// Typechecks a top-level declaration and records its type in the environment.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Typechecker::tc_expr`].
    pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
        match decl {
            ast::Decl::Fun { name, body } => {
                let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
                let type_ = body.type_().clone();
                self.env.set(name.clone(), type_);
                match body {
                    hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
                        name,
                        args,
                        body,
                        type_,
                    }),
                    // `tc_expr` maps `ast::Expr::Fun` to `hir::Expr::Fun` and nothing else.
                    _ => unreachable!(),
                }
            }
        }
    }

    /// Allocates a type variable that has not been handed out before.
    fn fresh_tv(&mut self) -> TyVar {
        self.ty_var_counter += 1;
        TyVar(self.ty_var_counter)
    }

    /// Creates a fresh existential type variable.
    fn fresh_ex(&mut self) -> Type {
        Type::Exist(self.fresh_tv())
    }

    /// Creates a fresh universally quantified type variable.
    fn fresh_univ(&mut self) -> Type {
        // This used to build `Type::Exist`, making it indistinguishable from `fresh_ex`.
        Type::Univ(self.fresh_tv())
    }

    /// Placeholder for generalization; currently returns `expr` unchanged.
    fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
        // TODO
        expr
    }

    /// Returns `true` when `ty` is the existential `tv` itself, or an existential whose chain of
    /// bindings in `ctx` leads back to `tv`.
    fn is_alias_of(&self, tv: TyVar, ty: &Type) -> bool {
        let mut current = ty;
        // Each hop follows one binding in `ctx`, so `ctx.len() + 1` hops is enough to reach the
        // end of any chain; the bound also keeps the walk finite if `ctx` ever held a cycle.
        for _ in 0..=self.ctx.len() {
            match current {
                Type::Exist(next) if *next == tv => return true,
                Type::Exist(next) => match self.ctx.get(next) {
                    Some(bound) => current = bound,
                    None => return false,
                },
                _ => return false,
            }
        }
        false
    }

    /// Unifies `ty1` with `ty2`, recording bindings for existential variables in `ctx`, and
    /// returns the unified type.
    ///
    /// # Errors
    ///
    /// Returns `Error::TypeMismatch` when the two types cannot be made equal, including function
    /// types that take a different number of arguments.
    ///
    /// # Panics
    ///
    /// `Type::Nullary` is not handled yet and hits `todo!()`.
    fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
        match (ty1, ty2) {
            (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
                Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
                Some(existing_ty) => Err(Error::TypeMismatch {
                    expected: ty.clone(),
                    actual: existing_ty.into(),
                }),
                // `tv` and `ty` already denote the same unbound variable. Binding one to the
                // other would create a cycle in `ctx` and make `resolve_tv` loop forever.
                None if self.is_alias_of(*tv, ty) => Ok(ty.clone()),
                None => match self.ctx.insert(*tv, ty.clone()) {
                    // `tv` was bound to a still-unresolved type; that type must agree as well.
                    Some(existing) => self.unify(&existing, ty),
                    None => Ok(ty.clone()),
                },
            },
            (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
            (
                Type::Fun {
                    args: args1,
                    ret: ret1,
                },
                Type::Fun {
                    args: args2,
                    ret: ret2,
                },
            ) => {
                // `zip` stops at the shorter list, so an arity mismatch must be rejected here
                // rather than silently dropping the surplus arguments.
                if args1.len() != args2.len() {
                    return Err(Error::TypeMismatch {
                        expected: ty1.clone(),
                        actual: ty2.clone(),
                    });
                }
                let args = args1
                    .iter()
                    .zip(args2)
                    .map(|(t1, t2)| self.unify(t1, t2))
                    .try_collect()?;
                let ret = self.unify(ret1, ret2)?;
                Ok(Type::Fun {
                    args,
                    ret: Box::new(ret),
                })
            }
            (Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(),
            _ => Err(Error::TypeMismatch {
                expected: ty1.clone(),
                actual: ty2.clone(),
            }),
        }
    }

    /// Replaces every inferred type in `expr` with its final surface type.
    fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
        expr.traverse_type(|ty| self.finalize_type(ty))
    }

    /// Replaces every inferred type in `decl` with its final surface type.
    fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
        decl.traverse_type(|ty| self.finalize_type(ty))
    }

    /// Converts an inferred type into a surface type, resolving existential variables.
    ///
    /// # Errors
    ///
    /// Returns `Error::AmbiguousType` when an existential variable cannot be resolved.
    ///
    /// # Panics
    ///
    /// `Type::Univ` and `Type::Nullary` are not handled yet and hit `todo!()`.
    fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
        match ty {
            Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
            Type::Univ(tv) => todo!(),
            Type::Nullary(_) => todo!(),
            Type::Prim(pr) => Ok(pr.into()),
            Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
                args: args
                    .into_iter()
                    .map(|ty| self.finalize_type(ty))
                    .try_collect()?,
                ret: Box::new(self.finalize_type(*ret)?),
            })),
        }
    }

    /// Follows the bindings of `tv` in `ctx` until a primitive type is reached.
    ///
    /// Returns `None` when the chain ends in a variable that has no binding.
    ///
    /// # Panics
    ///
    /// Chains that reach `Type::Univ`, `Type::Nullary` or `Type::Fun` hit `todo!()`.
    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
        let mut res = &Type::Exist(tv);
        loop {
            match res {
                Type::Exist(tv) => {
                    res = self.ctx.get(tv)?;
                }
                Type::Univ(_) => todo!(),
                Type::Nullary(_) => todo!(),
                Type::Prim(pr) => break Some((*pr).into()),
                Type::Fun { args, ret } => todo!(),
            }
        }
    }
}

/// Typechecks a single expression with a fresh typechecker and finalizes its types.
///
/// # Errors
///
/// Returns any error raised during inference or finalization.
pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
    let mut typechecker = Typechecker::new();
    let typechecked = typechecker.tc_expr(expr)?;
    typechecker.finalize_expr(typechecked)
}

/// Typechecks a list of top-level declarations in order, sharing one typechecker.
///
/// # Errors
///
/// Stops at, and returns, the first error raised by any declaration.
pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
    let mut typechecker = Typechecker::new();
    decls
        .into_iter()
        .map(|decl| {
            let decl = typechecker.tc_decl(decl)?;
            typechecker.finalize_decl(decl)
        })
        .try_collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the expression source `$expr` typechecks to the type written as `$type`.
    macro_rules! assert_type {
        ($expr: expr, $type: expr) => {
            use crate::parser::{expr, type_};
            let parsed_expr = test_parse!(expr, $expr);
            let parsed_type = test_parse!(type_, $type);
            let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
            assert_eq!(res.type_(), &parsed_type);
        };
    }

    /// Asserts that the expression source `$expr` fails to typecheck.
    macro_rules! assert_type_error {
        ($expr: expr) => {
            use crate::parser::expr;
            let parsed_expr = test_parse!(expr, $expr);
            let res = typecheck_expr(parsed_expr);
            assert!(
                res.is_err(),
                "Expected type error, but got type: {}",
                res.unwrap().type_()
            );
        };
    }

    #[test]
    fn literal_int() {
        assert_type!("1", "int");
    }

    #[test]
    fn conditional() {
        assert_type!("if 1 == 2 then 3 else 4", "int");
    }

    #[test]
    #[ignore]
    fn add_bools() {
        assert_type_error!("true + false");
    }

    #[test]
    fn call_generic_function() {
        assert_type!("(fn x = x) 1", "int");
    }

    #[test]
    #[ignore]
    fn generic_function() {
        assert_type!("fn x = x", "fn x, y -> x");
    }

    #[test]
    #[ignore]
    fn let_generalization() {
        assert_type!("let id = fn x = x in if id true then id 1 else 2", "int");
    }

    #[test]
    fn concrete_function() {
        assert_type!("fn x = x + 1", "fn int -> int");
    }

    #[test]
    fn call_concrete_function() {
        assert_type!("(fn x = x + 1) 2", "int");
    }

    #[test]
    fn conditional_non_bool() {
        assert_type_error!("if 3 then true else false");
    }

    #[test]
    fn let_int() {
        assert_type!("let x = 1 in x", "int");
    }
}
// -- cgit 1.4.1