about summary refs log tree commit diff
path: root/src/tc/mod.rs
diff options
context:
space:
mode:
authorGriffin Smith <root@gws.fyi>2021-03-14T02·57-0500
committerGriffin Smith <root@gws.fyi>2021-03-14T03·07-0500
commit32a5c0ff0fc58aa6721c1e0ad41950bde2d66744 (patch)
treeef5dcf5234c2a86607ee2f8f30db73bad016e075 /src/tc/mod.rs
parentf8beda81fbe8d04883aee71ff4ea078f897c6de4 (diff)
Add the start of a hindley-milner typechecker
The beginning of a parse-don't-validate-based hindley-milner
typechecker, which returns on success an IR where every AST node
trivially knows its own type, and using those types to determine LLVM
types in codegen.
Diffstat (limited to '')
-rw-r--r--src/tc/mod.rs528
1 files changed, 528 insertions, 0 deletions
diff --git a/src/tc/mod.rs b/src/tc/mod.rs
new file mode 100644
index 0000000000..b5acfac2b4
--- /dev/null
+++ b/src/tc/mod.rs
@@ -0,0 +1,528 @@
+use derive_more::From;
+use itertools::Itertools;
+use std::collections::HashMap;
+use std::convert::{TryFrom, TryInto};
+use std::fmt::{self, Display};
+use std::result;
+use thiserror::Error;
+
+use crate::ast::{self, hir, BinaryOperator, Ident, Literal};
+use crate::common::env::Env;
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("Undefined variable {0}")]
+    UndefinedVariable(Ident<'static>),
+
+    #[error("Mismatched types: expected {expected}, but got {actual}")]
+    TypeMismatch { expected: Type, actual: Type },
+
+    #[error("Mismatched types, expected numeric type, but got {0}")]
+    NonNumeric(Type),
+
+    #[error("Ambiguous type {0}")]
+    AmbiguousType(TyVar),
+}
+
+pub type Result<T> = result::Result<T, Error>;
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
+pub struct TyVar(u64);
+
+impl Display for TyVar {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "t{}", self.0)
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+pub struct NullaryType(String);
+
+impl Display for NullaryType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum PrimType {
+    Int,
+    Float,
+    Bool,
+}
+
+impl From<PrimType> for ast::Type {
+    fn from(pr: PrimType) -> Self {
+        match pr {
+            PrimType::Int => ast::Type::Int,
+            PrimType::Float => ast::Type::Float,
+            PrimType::Bool => ast::Type::Bool,
+        }
+    }
+}
+
+impl Display for PrimType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            PrimType::Int => f.write_str("int"),
+            PrimType::Float => f.write_str("float"),
+            PrimType::Bool => f.write_str("bool"),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, From)]
+pub enum Type {
+    #[from(ignore)]
+    Univ(TyVar),
+    #[from(ignore)]
+    Exist(TyVar),
+    Nullary(NullaryType),
+    Prim(PrimType),
+    Fun {
+        args: Vec<Type>,
+        ret: Box<Type>,
+    },
+}
+
+impl PartialEq<ast::Type> for Type {
+    fn eq(&self, other: &ast::Type) -> bool {
+        match (self, other) {
+            (Type::Univ(_), _) => todo!(),
+            (Type::Exist(_), _) => false,
+            (Type::Nullary(_), _) => todo!(),
+            (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
+            (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
+                *args == ft.args && (**ret).eq(&*ft.ret)
+            }
+            (Type::Fun { .. }, _) => false,
+        }
+    }
+}
+
+impl TryFrom<Type> for ast::Type {
+    type Error = Type;
+
+    fn try_from(value: Type) -> result::Result<Self, Self::Error> {
+        match value {
+            Type::Univ(_) => todo!(),
+            Type::Exist(_) => Err(value),
+            Type::Nullary(_) => todo!(),
+            Type::Prim(p) => Ok(p.into()),
+            Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
+                args: args
+                    .clone()
+                    .into_iter()
+                    .map(Self::try_from)
+                    .try_collect()
+                    .map_err(|_| value.clone())?,
+                ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?),
+            })),
+        }
+    }
+}
+
+const INT: Type = Type::Prim(PrimType::Int);
+const FLOAT: Type = Type::Prim(PrimType::Float);
+const BOOL: Type = Type::Prim(PrimType::Bool);
+
+impl Display for Type {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Type::Nullary(nt) => nt.fmt(f),
+            Type::Prim(p) => p.fmt(f),
+            Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
+            Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
+            Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
+        }
+    }
+}
+
+impl From<ast::Type> for Type {
+    fn from(type_: ast::Type) -> Self {
+        match type_ {
+            ast::Type::Int => INT,
+            ast::Type::Float => FLOAT,
+            ast::Type::Bool => BOOL,
+            ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
+                args: args.into_iter().map(Self::from).collect(),
+                ret: Box::new(Self::from(*ret)),
+            },
+        }
+    }
+}
+
+struct Typechecker<'ast> {
+    ty_var_counter: u64,
+    ctx: HashMap<TyVar, Type>,
+    env: Env<Ident<'ast>, Type>,
+}
+
+impl<'ast> Typechecker<'ast> {
+    fn new() -> Self {
+        Self {
+            ty_var_counter: 0,
+            ctx: Default::default(),
+            env: Default::default(),
+        }
+    }
+
+    pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
+        match expr {
+            ast::Expr::Ident(ident) => {
+                let type_ = self
+                    .env
+                    .resolve(&ident)
+                    .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))?
+                    .clone();
+                Ok(hir::Expr::Ident(ident, type_))
+            }
+            ast::Expr::Literal(lit) => {
+                let type_ = match lit {
+                    Literal::Int(_) => Type::Prim(PrimType::Int),
+                    Literal::Bool(_) => Type::Prim(PrimType::Bool),
+                };
+                Ok(hir::Expr::Literal(lit, type_))
+            }
+            ast::Expr::UnaryOp { op, rhs } => todo!(),
+            ast::Expr::BinaryOp { lhs, op, rhs } => {
+                let lhs = self.tc_expr(*lhs)?;
+                let rhs = self.tc_expr(*rhs)?;
+                let type_ = match op {
+                    BinaryOperator::Equ | BinaryOperator::Neq => {
+                        self.unify(lhs.type_(), rhs.type_())?;
+                        Type::Prim(PrimType::Bool)
+                    }
+                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
+                        let ty = self.unify(lhs.type_(), rhs.type_())?;
+                        // if !matches!(ty, Type::Int | Type::Float) {
+                        //     return Err(Error::NonNumeric(ty));
+                        // }
+                        ty
+                    }
+                    BinaryOperator::Div => todo!(),
+                    BinaryOperator::Pow => todo!(),
+                };
+                Ok(hir::Expr::BinaryOp {
+                    lhs: Box::new(lhs),
+                    op,
+                    rhs: Box::new(rhs),
+                    type_,
+                })
+            }
+            ast::Expr::Let { bindings, body } => {
+                self.env.push();
+                let bindings = bindings
+                    .into_iter()
+                    .map(
+                        |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
+                            let body = self.tc_expr(body)?;
+                            if let Some(type_) = type_ {
+                                self.unify(body.type_(), &type_.into())?;
+                            }
+                            self.env.set(ident.clone(), body.type_().clone());
+                            Ok(hir::Binding {
+                                ident,
+                                type_: body.type_().clone(),
+                                body,
+                            })
+                        },
+                    )
+                    .collect::<Result<Vec<hir::Binding<Type>>>>()?;
+                let body = self.tc_expr(*body)?;
+                self.env.pop();
+                Ok(hir::Expr::Let {
+                    bindings,
+                    type_: body.type_().clone(),
+                    body: Box::new(body),
+                })
+            }
+            ast::Expr::If {
+                condition,
+                then,
+                else_,
+            } => {
+                let condition = self.tc_expr(*condition)?;
+                self.unify(&Type::Prim(PrimType::Bool), condition.type_())?;
+                let then = self.tc_expr(*then)?;
+                let else_ = self.tc_expr(*else_)?;
+                let type_ = self.unify(then.type_(), else_.type_())?;
+                Ok(hir::Expr::If {
+                    condition: Box::new(condition),
+                    then: Box::new(then),
+                    else_: Box::new(else_),
+                    type_,
+                })
+            }
+            ast::Expr::Fun(f) => {
+                let ast::Fun { args, body } = *f;
+                self.env.push();
+                let args: Vec<_> = args
+                    .into_iter()
+                    .map(|id| {
+                        let ty = self.fresh_ex();
+                        self.env.set(id.clone(), ty.clone());
+                        (id, ty)
+                    })
+                    .collect();
+                let body = self.tc_expr(body)?;
+                self.env.pop();
+                Ok(hir::Expr::Fun {
+                    type_: Type::Fun {
+                        args: args.iter().map(|(_, ty)| ty.clone()).collect(),
+                        ret: Box::new(body.type_().clone()),
+                    },
+                    args,
+                    body: Box::new(body),
+                })
+            }
+            ast::Expr::Call { fun, args } => {
+                let ret_ty = self.fresh_ex();
+                let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>();
+                let ft = Type::Fun {
+                    args: arg_tys.clone(),
+                    ret: Box::new(ret_ty.clone()),
+                };
+                let fun = self.tc_expr(*fun)?;
+                self.unify(&ft, fun.type_())?;
+                let args = args
+                    .into_iter()
+                    .zip(arg_tys)
+                    .map(|(arg, ty)| {
+                        let arg = self.tc_expr(arg)?;
+                        self.unify(&ty, arg.type_())?;
+                        Ok(arg)
+                    })
+                    .try_collect()?;
+                Ok(hir::Expr::Call {
+                    fun: Box::new(fun),
+                    args,
+                    type_: ret_ty,
+                })
+            }
+            ast::Expr::Ascription { expr, type_ } => {
+                let expr = self.tc_expr(*expr)?;
+                self.unify(expr.type_(), &type_.into())?;
+                Ok(expr)
+            }
+        }
+    }
+
+    pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
+        match decl {
+            ast::Decl::Fun { name, body } => {
+                let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
+                let type_ = body.type_().clone();
+                self.env.set(name.clone(), type_);
+                match body {
+                    hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
+                        name,
+                        args,
+                        body,
+                        type_,
+                    }),
+                    _ => unreachable!(),
+                }
+            }
+        }
+    }
+
+    fn fresh_tv(&mut self) -> TyVar {
+        self.ty_var_counter += 1;
+        TyVar(self.ty_var_counter)
+    }
+
+    fn fresh_ex(&mut self) -> Type {
+        Type::Exist(self.fresh_tv())
+    }
+
+    fn fresh_univ(&mut self) -> Type {
+        Type::Exist(self.fresh_tv())
+    }
+
+    fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
+        // TODO
+        expr
+    }
+
+    fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
+        match (ty1, ty2) {
+            (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
+            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
+                Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
+                Some(existing_ty) => Err(Error::TypeMismatch {
+                    expected: ty.clone(),
+                    actual: existing_ty.into(),
+                }),
+                None => match self.ctx.insert(*tv, ty.clone()) {
+                    Some(existing) => self.unify(&existing, ty),
+                    None => Ok(ty.clone()),
+                },
+            },
+            (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
+            (
+                Type::Fun {
+                    args: args1,
+                    ret: ret1,
+                },
+                Type::Fun {
+                    args: args2,
+                    ret: ret2,
+                },
+            ) => {
+                let args = args1
+                    .iter()
+                    .zip(args2)
+                    .map(|(t1, t2)| self.unify(t1, t2))
+                    .try_collect()?;
+                let ret = self.unify(ret1, ret2)?;
+                Ok(Type::Fun {
+                    args,
+                    ret: Box::new(ret),
+                })
+            }
+            (Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(),
+            _ => Err(Error::TypeMismatch {
+                expected: ty1.clone(),
+                actual: ty2.clone(),
+            }),
+        }
+    }
+
+    fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
+        expr.traverse_type(|ty| self.finalize_type(ty))
+    }
+
+    fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
+        decl.traverse_type(|ty| self.finalize_type(ty))
+    }
+
+    fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
+        match ty {
+            Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
+            Type::Univ(tv) => todo!(),
+            Type::Nullary(_) => todo!(),
+            Type::Prim(pr) => Ok(pr.into()),
+            Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
+                args: args
+                    .into_iter()
+                    .map(|ty| self.finalize_type(ty))
+                    .try_collect()?,
+                ret: Box::new(self.finalize_type(*ret)?),
+            })),
+        }
+    }
+
+    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
+        let mut res = &Type::Exist(tv);
+        loop {
+            match res {
+                Type::Exist(tv) => {
+                    res = self.ctx.get(tv)?;
+                }
+                Type::Univ(_) => todo!(),
+                Type::Nullary(_) => todo!(),
+                Type::Prim(pr) => break Some((*pr).into()),
+                Type::Fun { args, ret } => todo!(),
+            }
+        }
+    }
+}
+
+pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
+    let mut typechecker = Typechecker::new();
+    let typechecked = typechecker.tc_expr(expr)?;
+    typechecker.finalize_expr(typechecked)
+}
+
+pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
+    let mut typechecker = Typechecker::new();
+    decls
+        .into_iter()
+        .map(|decl| {
+            let decl = typechecker.tc_decl(decl)?;
+            typechecker.finalize_decl(decl)
+        })
+        .try_collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    macro_rules! assert_type {
+        ($expr: expr, $type: expr) => {
+            use crate::parser::{expr, type_};
+            let parsed_expr = test_parse!(expr, $expr);
+            let parsed_type = test_parse!(type_, $type);
+            let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
+            assert_eq!(res.type_(), &parsed_type);
+        };
+    }
+
+    macro_rules! assert_type_error {
+        ($expr: expr) => {
+            use crate::parser::expr;
+            let parsed_expr = test_parse!(expr, $expr);
+            let res = typecheck_expr(parsed_expr);
+            assert!(
+                res.is_err(),
+                "Expected type error, but got type: {}",
+                res.unwrap().type_()
+            );
+        };
+    }
+
+    #[test]
+    fn literal_int() {
+        assert_type!("1", "int");
+    }
+
+    #[test]
+    fn conditional() {
+        assert_type!("if 1 == 2 then 3 else 4", "int");
+    }
+
+    #[test]
+    #[ignore]
+    fn add_bools() {
+        assert_type_error!("true + false");
+    }
+
+    #[test]
+    fn call_generic_function() {
+        assert_type!("(fn x = x) 1", "int");
+    }
+
+    #[test]
+    #[ignore]
+    fn generic_function() {
+        assert_type!("fn x = x", "fn x, y -> x");
+    }
+
+    #[test]
+    #[ignore]
+    fn let_generalization() {
+        assert_type!("let id = fn x = x in if id true then id 1 else 2", "int");
+    }
+
+    #[test]
+    fn concrete_function() {
+        assert_type!("fn x = x + 1", "fn int -> int");
+    }
+
+    #[test]
+    fn call_concrete_function() {
+        assert_type!("(fn x = x + 1) 2", "int");
+    }
+
+    #[test]
+    fn conditional_non_bool() {
+        assert_type_error!("if 3 then true else false");
+    }
+
+    #[test]
+    fn let_int() {
+        assert_type!("let x = 1 in x", "int");
+    }
+}