about summary refs log tree commit diff
path: root/src/tc/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tc/mod.rs')
-rw-r--r--src/tc/mod.rs528
1 files changed, 528 insertions, 0 deletions
diff --git a/src/tc/mod.rs b/src/tc/mod.rs
new file mode 100644
index 000000000000..b5acfac2b426
--- /dev/null
+++ b/src/tc/mod.rs
@@ -0,0 +1,528 @@
+use derive_more::From;
+use itertools::Itertools;
+use std::collections::HashMap;
+use std::convert::{TryFrom, TryInto};
+use std::fmt::{self, Display};
+use std::result;
+use thiserror::Error;
+
+use crate::ast::{self, hir, BinaryOperator, Ident, Literal};
+use crate::common::env::Env;
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("Undefined variable {0}")]
+    UndefinedVariable(Ident<'static>),
+
+    #[error("Mismatched types: expected {expected}, but got {actual}")]
+    TypeMismatch { expected: Type, actual: Type },
+
+    #[error("Mismatched types, expected numeric type, but got {0}")]
+    NonNumeric(Type),
+
+    #[error("Ambiguous type {0}")]
+    AmbiguousType(TyVar),
+}
+
+pub type Result<T> = result::Result<T, Error>;
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
+pub struct TyVar(u64);
+
+impl Display for TyVar {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "t{}", self.0)
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+pub struct NullaryType(String);
+
+impl Display for NullaryType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum PrimType {
+    Int,
+    Float,
+    Bool,
+}
+
+impl From<PrimType> for ast::Type {
+    fn from(pr: PrimType) -> Self {
+        match pr {
+            PrimType::Int => ast::Type::Int,
+            PrimType::Float => ast::Type::Float,
+            PrimType::Bool => ast::Type::Bool,
+        }
+    }
+}
+
+impl Display for PrimType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            PrimType::Int => f.write_str("int"),
+            PrimType::Float => f.write_str("float"),
+            PrimType::Bool => f.write_str("bool"),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, From)]
+pub enum Type {
+    #[from(ignore)]
+    Univ(TyVar),
+    #[from(ignore)]
+    Exist(TyVar),
+    Nullary(NullaryType),
+    Prim(PrimType),
+    Fun {
+        args: Vec<Type>,
+        ret: Box<Type>,
+    },
+}
+
+impl PartialEq<ast::Type> for Type {
+    fn eq(&self, other: &ast::Type) -> bool {
+        match (self, other) {
+            (Type::Univ(_), _) => todo!(),
+            (Type::Exist(_), _) => false,
+            (Type::Nullary(_), _) => todo!(),
+            (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
+            (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
+                *args == ft.args && (**ret).eq(&*ft.ret)
+            }
+            (Type::Fun { .. }, _) => false,
+        }
+    }
+}
+
+impl TryFrom<Type> for ast::Type {
+    type Error = Type;
+
+    fn try_from(value: Type) -> result::Result<Self, Self::Error> {
+        match value {
+            Type::Univ(_) => todo!(),
+            Type::Exist(_) => Err(value),
+            Type::Nullary(_) => todo!(),
+            Type::Prim(p) => Ok(p.into()),
+            Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
+                args: args
+                    .clone()
+                    .into_iter()
+                    .map(Self::try_from)
+                    .try_collect()
+                    .map_err(|_| value.clone())?,
+                ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?),
+            })),
+        }
+    }
+}
+
+const INT: Type = Type::Prim(PrimType::Int);
+const FLOAT: Type = Type::Prim(PrimType::Float);
+const BOOL: Type = Type::Prim(PrimType::Bool);
+
+impl Display for Type {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Type::Nullary(nt) => nt.fmt(f),
+            Type::Prim(p) => p.fmt(f),
+            Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
+            Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
+            Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
+        }
+    }
+}
+
+impl From<ast::Type> for Type {
+    fn from(type_: ast::Type) -> Self {
+        match type_ {
+            ast::Type::Int => INT,
+            ast::Type::Float => FLOAT,
+            ast::Type::Bool => BOOL,
+            ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
+                args: args.into_iter().map(Self::from).collect(),
+                ret: Box::new(Self::from(*ret)),
+            },
+        }
+    }
+}
+
+struct Typechecker<'ast> {
+    ty_var_counter: u64,
+    ctx: HashMap<TyVar, Type>,
+    env: Env<Ident<'ast>, Type>,
+}
+
+impl<'ast> Typechecker<'ast> {
+    fn new() -> Self {
+        Self {
+            ty_var_counter: 0,
+            ctx: Default::default(),
+            env: Default::default(),
+        }
+    }
+
+    pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
+        match expr {
+            ast::Expr::Ident(ident) => {
+                let type_ = self
+                    .env
+                    .resolve(&ident)
+                    .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))?
+                    .clone();
+                Ok(hir::Expr::Ident(ident, type_))
+            }
+            ast::Expr::Literal(lit) => {
+                let type_ = match lit {
+                    Literal::Int(_) => Type::Prim(PrimType::Int),
+                    Literal::Bool(_) => Type::Prim(PrimType::Bool),
+                };
+                Ok(hir::Expr::Literal(lit, type_))
+            }
+            ast::Expr::UnaryOp { op, rhs } => todo!(),
+            ast::Expr::BinaryOp { lhs, op, rhs } => {
+                let lhs = self.tc_expr(*lhs)?;
+                let rhs = self.tc_expr(*rhs)?;
+                let type_ = match op {
+                    BinaryOperator::Equ | BinaryOperator::Neq => {
+                        self.unify(lhs.type_(), rhs.type_())?;
+                        Type::Prim(PrimType::Bool)
+                    }
+                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
+                        let ty = self.unify(lhs.type_(), rhs.type_())?;
+                        // if !matches!(ty, Type::Int | Type::Float) {
+                        //     return Err(Error::NonNumeric(ty));
+                        // }
+                        ty
+                    }
+                    BinaryOperator::Div => todo!(),
+                    BinaryOperator::Pow => todo!(),
+                };
+                Ok(hir::Expr::BinaryOp {
+                    lhs: Box::new(lhs),
+                    op,
+                    rhs: Box::new(rhs),
+                    type_,
+                })
+            }
+            ast::Expr::Let { bindings, body } => {
+                self.env.push();
+                let bindings = bindings
+                    .into_iter()
+                    .map(
+                        |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
+                            let body = self.tc_expr(body)?;
+                            if let Some(type_) = type_ {
+                                self.unify(body.type_(), &type_.into())?;
+                            }
+                            self.env.set(ident.clone(), body.type_().clone());
+                            Ok(hir::Binding {
+                                ident,
+                                type_: body.type_().clone(),
+                                body,
+                            })
+                        },
+                    )
+                    .collect::<Result<Vec<hir::Binding<Type>>>>()?;
+                let body = self.tc_expr(*body)?;
+                self.env.pop();
+                Ok(hir::Expr::Let {
+                    bindings,
+                    type_: body.type_().clone(),
+                    body: Box::new(body),
+                })
+            }
+            ast::Expr::If {
+                condition,
+                then,
+                else_,
+            } => {
+                let condition = self.tc_expr(*condition)?;
+                self.unify(&Type::Prim(PrimType::Bool), condition.type_())?;
+                let then = self.tc_expr(*then)?;
+                let else_ = self.tc_expr(*else_)?;
+                let type_ = self.unify(then.type_(), else_.type_())?;
+                Ok(hir::Expr::If {
+                    condition: Box::new(condition),
+                    then: Box::new(then),
+                    else_: Box::new(else_),
+                    type_,
+                })
+            }
+            ast::Expr::Fun(f) => {
+                let ast::Fun { args, body } = *f;
+                self.env.push();
+                let args: Vec<_> = args
+                    .into_iter()
+                    .map(|id| {
+                        let ty = self.fresh_ex();
+                        self.env.set(id.clone(), ty.clone());
+                        (id, ty)
+                    })
+                    .collect();
+                let body = self.tc_expr(body)?;
+                self.env.pop();
+                Ok(hir::Expr::Fun {
+                    type_: Type::Fun {
+                        args: args.iter().map(|(_, ty)| ty.clone()).collect(),
+                        ret: Box::new(body.type_().clone()),
+                    },
+                    args,
+                    body: Box::new(body),
+                })
+            }
+            ast::Expr::Call { fun, args } => {
+                let ret_ty = self.fresh_ex();
+                let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>();
+                let ft = Type::Fun {
+                    args: arg_tys.clone(),
+                    ret: Box::new(ret_ty.clone()),
+                };
+                let fun = self.tc_expr(*fun)?;
+                self.unify(&ft, fun.type_())?;
+                let args = args
+                    .into_iter()
+                    .zip(arg_tys)
+                    .map(|(arg, ty)| {
+                        let arg = self.tc_expr(arg)?;
+                        self.unify(&ty, arg.type_())?;
+                        Ok(arg)
+                    })
+                    .try_collect()?;
+                Ok(hir::Expr::Call {
+                    fun: Box::new(fun),
+                    args,
+                    type_: ret_ty,
+                })
+            }
+            ast::Expr::Ascription { expr, type_ } => {
+                let expr = self.tc_expr(*expr)?;
+                self.unify(expr.type_(), &type_.into())?;
+                Ok(expr)
+            }
+        }
+    }
+
+    pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
+        match decl {
+            ast::Decl::Fun { name, body } => {
+                let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
+                let type_ = body.type_().clone();
+                self.env.set(name.clone(), type_);
+                match body {
+                    hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
+                        name,
+                        args,
+                        body,
+                        type_,
+                    }),
+                    _ => unreachable!(),
+                }
+            }
+        }
+    }
+
+    fn fresh_tv(&mut self) -> TyVar {
+        self.ty_var_counter += 1;
+        TyVar(self.ty_var_counter)
+    }
+
+    fn fresh_ex(&mut self) -> Type {
+        Type::Exist(self.fresh_tv())
+    }
+
+    fn fresh_univ(&mut self) -> Type {
+        Type::Exist(self.fresh_tv())
+    }
+
+    fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
+        // TODO
+        expr
+    }
+
+    fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
+        match (ty1, ty2) {
+            (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
+            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
+                Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
+                Some(existing_ty) => Err(Error::TypeMismatch {
+                    expected: ty.clone(),
+                    actual: existing_ty.into(),
+                }),
+                None => match self.ctx.insert(*tv, ty.clone()) {
+                    Some(existing) => self.unify(&existing, ty),
+                    None => Ok(ty.clone()),
+                },
+            },
+            (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
+            (
+                Type::Fun {
+                    args: args1,
+                    ret: ret1,
+                },
+                Type::Fun {
+                    args: args2,
+                    ret: ret2,
+                },
+            ) => {
+                let args = args1
+                    .iter()
+                    .zip(args2)
+                    .map(|(t1, t2)| self.unify(t1, t2))
+                    .try_collect()?;
+                let ret = self.unify(ret1, ret2)?;
+                Ok(Type::Fun {
+                    args,
+                    ret: Box::new(ret),
+                })
+            }
+            (Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(),
+            _ => Err(Error::TypeMismatch {
+                expected: ty1.clone(),
+                actual: ty2.clone(),
+            }),
+        }
+    }
+
+    fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
+        expr.traverse_type(|ty| self.finalize_type(ty))
+    }
+
+    fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
+        decl.traverse_type(|ty| self.finalize_type(ty))
+    }
+
+    fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
+        match ty {
+            Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
+            Type::Univ(tv) => todo!(),
+            Type::Nullary(_) => todo!(),
+            Type::Prim(pr) => Ok(pr.into()),
+            Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
+                args: args
+                    .into_iter()
+                    .map(|ty| self.finalize_type(ty))
+                    .try_collect()?,
+                ret: Box::new(self.finalize_type(*ret)?),
+            })),
+        }
+    }
+
+    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
+        let mut res = &Type::Exist(tv);
+        loop {
+            match res {
+                Type::Exist(tv) => {
+                    res = self.ctx.get(tv)?;
+                }
+                Type::Univ(_) => todo!(),
+                Type::Nullary(_) => todo!(),
+                Type::Prim(pr) => break Some((*pr).into()),
+                Type::Fun { args, ret } => todo!(),
+            }
+        }
+    }
+}
+
+pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
+    let mut typechecker = Typechecker::new();
+    let typechecked = typechecker.tc_expr(expr)?;
+    typechecker.finalize_expr(typechecked)
+}
+
+pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
+    let mut typechecker = Typechecker::new();
+    decls
+        .into_iter()
+        .map(|decl| {
+            let decl = typechecker.tc_decl(decl)?;
+            typechecker.finalize_decl(decl)
+        })
+        .try_collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    macro_rules! assert_type {
+        ($expr: expr, $type: expr) => {
+            use crate::parser::{expr, type_};
+            let parsed_expr = test_parse!(expr, $expr);
+            let parsed_type = test_parse!(type_, $type);
+            let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
+            assert_eq!(res.type_(), &parsed_type);
+        };
+    }
+
+    macro_rules! assert_type_error {
+        ($expr: expr) => {
+            use crate::parser::expr;
+            let parsed_expr = test_parse!(expr, $expr);
+            let res = typecheck_expr(parsed_expr);
+            assert!(
+                res.is_err(),
+                "Expected type error, but got type: {}",
+                res.unwrap().type_()
+            );
+        };
+    }
+
+    #[test]
+    fn literal_int() {
+        assert_type!("1", "int");
+    }
+
+    #[test]
+    fn conditional() {
+        assert_type!("if 1 == 2 then 3 else 4", "int");
+    }
+
+    #[test]
+    #[ignore]
+    fn add_bools() {
+        assert_type_error!("true + false");
+    }
+
+    #[test]
+    fn call_generic_function() {
+        assert_type!("(fn x = x) 1", "int");
+    }
+
+    #[test]
+    #[ignore]
+    fn generic_function() {
+        assert_type!("fn x = x", "fn x, y -> x");
+    }
+
+    #[test]
+    #[ignore]
+    fn let_generalization() {
+        assert_type!("let id = fn x = x in if id true then id 1 else 2", "int");
+    }
+
+    #[test]
+    fn concrete_function() {
+        assert_type!("fn x = x + 1", "fn int -> int");
+    }
+
+    #[test]
+    fn call_concrete_function() {
+        assert_type!("(fn x = x + 1) 2", "int");
+    }
+
+    #[test]
+    fn conditional_non_bool() {
+        assert_type_error!("if 3 then true else false");
+    }
+
+    #[test]
+    fn let_int() {
+        assert_type!("let x = 1 in x", "int");
+    }
+}