about summary refs log tree commit diff
path: root/src/tc/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tc/mod.rs')
-rw-r--r--src/tc/mod.rs274
1 files changed, 202 insertions, 72 deletions
diff --git a/src/tc/mod.rs b/src/tc/mod.rs
index 2c40a02bf7c6..4c088c885749 100644
--- a/src/tc/mod.rs
+++ b/src/tc/mod.rs
@@ -1,13 +1,16 @@
+use bimap::BiMap;
 use derive_more::From;
 use itertools::Itertools;
+use std::cell::RefCell;
 use std::collections::HashMap;
 use std::convert::{TryFrom, TryInto};
 use std::fmt::{self, Display};
-use std::result;
+use std::{mem, result};
 use thiserror::Error;
 
-use crate::ast::{self, hir, BinaryOperator, Ident, Literal};
+use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal};
 use crate::common::env::Env;
+use crate::common::{Namer, NamerOf};
 
 #[derive(Debug, Error)]
 pub enum Error {
@@ -52,7 +55,7 @@ pub enum PrimType {
     CString,
 }
 
-impl From<PrimType> for ast::Type {
+impl<'a> From<PrimType> for ast::Type<'a> {
     fn from(pr: PrimType) -> Self {
         match pr {
             PrimType::Int => ast::Type::Int,
@@ -88,22 +91,7 @@ pub enum Type {
     },
 }
 
-impl PartialEq<ast::Type> for Type {
-    fn eq(&self, other: &ast::Type) -> bool {
-        match (self, other) {
-            (Type::Univ(_), _) => todo!(),
-            (Type::Exist(_), _) => false,
-            (Type::Nullary(_), _) => todo!(),
-            (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
-            (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
-                *args == ft.args && (**ret).eq(&*ft.ret)
-            }
-            (Type::Fun { .. }, _) => false,
-        }
-    }
-}
-
-impl TryFrom<Type> for ast::Type {
+impl<'a> TryFrom<Type> for ast::Type<'a> {
     type Error = Type;
 
     fn try_from(value: Type) -> result::Result<Self, Self::Error> {
@@ -142,33 +130,29 @@ impl Display for Type {
     }
 }
 
-impl From<ast::Type> for Type {
-    fn from(type_: ast::Type) -> Self {
-        match type_ {
-            ast::Type::Int => INT,
-            ast::Type::Float => FLOAT,
-            ast::Type::Bool => BOOL,
-            ast::Type::CString => CSTRING,
-            ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
-                args: args.into_iter().map(Self::from).collect(),
-                ret: Box::new(Self::from(*ret)),
-            },
-        }
-    }
-}
-
 struct Typechecker<'ast> {
-    ty_var_counter: u64,
+    ty_var_namer: NamerOf<TyVar>,
     ctx: HashMap<TyVar, Type>,
     env: Env<Ident<'ast>, Type>,
+
+    /// AST type var -> type
+    instantiations: Env<Ident<'ast>, Type>,
+
+    /// AST type-var -> universal TyVar
+    type_vars: RefCell<(BiMap<Ident<'ast>, TyVar>, NamerOf<Ident<'static>>)>,
 }
 
 impl<'ast> Typechecker<'ast> {
     fn new() -> Self {
         Self {
-            ty_var_counter: 0,
+            ty_var_namer: Namer::new(TyVar).boxed(),
+            type_vars: RefCell::new((
+                Default::default(),
+                Namer::alphabetic().map(|n| Ident::try_from(n).unwrap()),
+            )),
             ctx: Default::default(),
             env: Default::default(),
+            instantiations: Default::default(),
         }
     }
 
@@ -224,7 +208,8 @@ impl<'ast> Typechecker<'ast> {
                         |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
                             let body = self.tc_expr(body)?;
                             if let Some(type_) = type_ {
-                                self.unify(body.type_(), &type_.into())?;
+                                let type_ = self.type_from_ast_type(type_);
+                                self.unify(body.type_(), &type_)?;
                             }
                             self.env.set(ident.clone(), body.type_().clone());
                             Ok(hir::Binding {
@@ -265,19 +250,22 @@ impl<'ast> Typechecker<'ast> {
                 self.env.push();
                 let args: Vec<_> = args
                     .into_iter()
-                    .map(|id| {
-                        let ty = self.fresh_ex();
-                        self.env.set(id.clone(), ty.clone());
-                        (id, ty)
+                    .map(|Arg { ident, type_ }| {
+                        let ty = match type_ {
+                            Some(t) => self.type_from_ast_type(t),
+                            None => self.fresh_ex(),
+                        };
+                        self.env.set(ident.clone(), ty.clone());
+                        (ident, ty)
                     })
                     .collect();
                 let body = self.tc_expr(body)?;
                 self.env.pop();
                 Ok(hir::Expr::Fun {
-                    type_: Type::Fun {
-                        args: args.iter().map(|(_, ty)| ty.clone()).collect(),
-                        ret: Box::new(body.type_().clone()),
-                    },
+                    type_: self.universalize(
+                        args.iter().map(|(_, ty)| ty.clone()).collect(),
+                        body.type_().clone(),
+                    ),
                     args,
                     body: Box::new(body),
                 })
@@ -290,6 +278,7 @@ impl<'ast> Typechecker<'ast> {
                     ret: Box::new(ret_ty.clone()),
                 };
                 let fun = self.tc_expr(*fun)?;
+                self.instantiations.push();
                 self.unify(&ft, fun.type_())?;
                 let args = args
                     .into_iter()
@@ -300,6 +289,7 @@ impl<'ast> Typechecker<'ast> {
                         Ok(arg)
                     })
                     .try_collect()?;
+                self.commit_instantiations();
                 Ok(hir::Expr::Call {
                     fun: Box::new(fun),
                     args,
@@ -308,7 +298,8 @@ impl<'ast> Typechecker<'ast> {
             }
             ast::Expr::Ascription { expr, type_ } => {
                 let expr = self.tc_expr(*expr)?;
-                self.unify(expr.type_(), &type_.into())?;
+                let type_ = self.type_from_ast_type(type_);
+                self.unify(expr.type_(), &type_)?;
                 Ok(expr)
             }
         }
@@ -334,8 +325,7 @@ impl<'ast> Typechecker<'ast> {
     }
 
     fn fresh_tv(&mut self) -> TyVar {
-        self.ty_var_counter += 1;
-        TyVar(self.ty_var_counter)
+        self.ty_var_namer.make_name()
     }
 
     fn fresh_ex(&mut self) -> Type {
@@ -343,29 +333,69 @@ impl<'ast> Typechecker<'ast> {
     }
 
     fn fresh_univ(&mut self) -> Type {
-        Type::Exist(self.fresh_tv())
-    }
+        Type::Univ(self.fresh_tv())
+    }
+
+    #[allow(clippy::redundant_closure)] // https://github.com/rust-lang/rust-clippy/issues/6903
+    fn universalize(&mut self, args: Vec<Type>, ret: Type) -> Type {
+        let mut vars = HashMap::new();
+        let mut universalize_type = move |ty| match ty {
+            Type::Exist(tv) if self.resolve_tv(tv).is_none() => vars
+                .entry(tv)
+                .or_insert_with_key(|tv| {
+                    let ty = self.fresh_univ();
+                    self.ctx.insert(*tv, ty.clone());
+                    ty
+                })
+                .clone(),
+            _ => ty,
+        };
 
-    fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
-        // TODO
-        expr
+        Type::Fun {
+            args: args.into_iter().map(|t| universalize_type(t)).collect(),
+            ret: Box::new(universalize_type(ret)),
+        }
     }
 
     fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
         match (ty1, ty2) {
-            (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
             (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
-                Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
-                Some(existing_ty) => Err(Error::TypeMismatch {
-                    expected: ty.clone(),
-                    actual: existing_ty.into(),
-                }),
+                Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
+                Some(var @ ast::Type::Var(_)) => {
+                    let var = self.type_from_ast_type(var);
+                    self.unify(&var, ty)
+                }
+                Some(existing_ty) => match ty {
+                    Type::Exist(_) => {
+                        let rhs = self.type_from_ast_type(existing_ty);
+                        self.unify(ty, &rhs)
+                    }
+                    _ => Err(Error::TypeMismatch {
+                        expected: ty.clone(),
+                        actual: self.type_from_ast_type(existing_ty),
+                    }),
+                },
                 None => match self.ctx.insert(*tv, ty.clone()) {
                     Some(existing) => self.unify(&existing, ty),
                     None => Ok(ty.clone()),
                 },
             },
             (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
+            (Type::Univ(u), ty) | (ty, Type::Univ(u)) => {
+                let ident = self.name_univ(*u);
+                match self.instantiations.resolve(&ident) {
+                    Some(existing_ty) if ty == existing_ty => Ok(ty.clone()),
+                    Some(existing_ty) => Err(Error::TypeMismatch {
+                        expected: ty.clone(),
+                        actual: existing_ty.clone(),
+                    }),
+                    None => {
+                        self.instantiations.set(ident, ty.clone());
+                        Ok(ty.clone())
+                    }
+                }
+            }
+            (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
             (
                 Type::Fun {
                     args: args1,
@@ -395,18 +425,24 @@ impl<'ast> Typechecker<'ast> {
         }
     }
 
-    fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
+    fn finalize_expr(
+        &self,
+        expr: hir::Expr<'ast, Type>,
+    ) -> Result<hir::Expr<'ast, ast::Type<'ast>>> {
         expr.traverse_type(|ty| self.finalize_type(ty))
     }
 
-    fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
+    fn finalize_decl(
+        &self,
+        decl: hir::Decl<'ast, Type>,
+    ) -> Result<hir::Decl<'ast, ast::Type<'ast>>> {
         decl.traverse_type(|ty| self.finalize_type(ty))
     }
 
-    fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
-        match ty {
+    fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> {
+        let ret = match ty {
             Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
-            Type::Univ(tv) => todo!(),
+            Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
             Type::Nullary(_) => todo!(),
             Type::Prim(pr) => Ok(pr.into()),
             Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
@@ -416,23 +452,105 @@ impl<'ast> Typechecker<'ast> {
                     .try_collect()?,
                 ret: Box::new(self.finalize_type(*ret)?),
             })),
-        }
+        };
+        ret
     }
 
-    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
+    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> {
         let mut res = &Type::Exist(tv);
         loop {
             match res {
                 Type::Exist(tv) => {
                     res = self.ctx.get(tv)?;
                 }
-                Type::Univ(_) => todo!(),
+                Type::Univ(tv) => {
+                    let ident = self.name_univ(*tv);
+                    if let Some(r) = self.instantiations.resolve(&ident) {
+                        res = r;
+                    } else {
+                        break Some(ast::Type::Var(ident));
+                    }
+                }
                 Type::Nullary(_) => todo!(),
                 Type::Prim(pr) => break Some((*pr).into()),
                 Type::Fun { args, ret } => todo!(),
             }
         }
     }
+
+    fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
+        match ast_type {
+            ast::Type::Int => INT,
+            ast::Type::Float => FLOAT,
+            ast::Type::Bool => BOOL,
+            ast::Type::CString => CSTRING,
+            ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
+                args: args
+                    .into_iter()
+                    .map(|t| self.type_from_ast_type(t))
+                    .collect(),
+                ret: Box::new(self.type_from_ast_type(*ret)),
+            },
+            ast::Type::Var(id) => Type::Univ({
+                let opt_tv = { self.type_vars.borrow_mut().0.get_by_left(&id).copied() };
+                opt_tv.unwrap_or_else(|| {
+                    let tv = self.fresh_tv();
+                    self.type_vars
+                        .borrow_mut()
+                        .0
+                        .insert_no_overwrite(id, tv)
+                        .unwrap();
+                    tv
+                })
+            }),
+        }
+    }
+
+    fn name_univ(&self, tv: TyVar) -> Ident<'static> {
+        let mut vars = self.type_vars.borrow_mut();
+        vars.0
+            .get_by_right(&tv)
+            .map(Ident::to_owned)
+            .unwrap_or_else(|| {
+                let name = vars.1.make_name();
+                vars.0.insert_no_overwrite(name.clone(), tv).unwrap();
+                name
+            })
+    }
+
+    fn commit_instantiations(&mut self) {
+        let mut ctx = mem::take(&mut self.ctx);
+        for (_, v) in ctx.iter_mut() {
+            if let Type::Univ(tv) = v {
+                if let Some(concrete) = self.instantiations.resolve(&self.name_univ(*tv)) {
+                    *v = concrete.clone();
+                }
+            }
+        }
+        self.ctx = ctx;
+        self.instantiations.pop();
+    }
+
+    fn types_match(&self, type_: &Type, ast_type: &ast::Type<'ast>) -> bool {
+        match (type_, ast_type) {
+            (Type::Univ(u), ast::Type::Var(v)) => {
+                Some(u) == self.type_vars.borrow().0.get_by_left(v)
+            }
+            (Type::Univ(_), _) => false,
+            (Type::Exist(_), _) => false,
+            (Type::Nullary(_), _) => todo!(),
+            (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
+            (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
+                args.len() == ft.args.len()
+                    && args
+                        .iter()
+                        .zip(&ft.args)
+                        .all(|(a1, a2)| self.types_match(a1, &a2))
+                    && self.types_match(&*ret, &*ft.ret)
+            }
+            (Type::Fun { .. }, _) => false,
+        }
+    }
 }
 
 pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
@@ -446,8 +564,10 @@ pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Ty
     decls
         .into_iter()
         .map(|decl| {
-            let decl = typechecker.tc_decl(decl)?;
-            typechecker.finalize_decl(decl)
+            let hir_decl = typechecker.tc_decl(decl)?;
+            let res = typechecker.finalize_decl(hir_decl)?;
+            typechecker.ctx.clear();
+            Ok(res)
         })
         .try_collect()
 }
@@ -462,7 +582,13 @@ mod tests {
             let parsed_expr = test_parse!(expr, $expr);
             let parsed_type = test_parse!(type_, $type);
             let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
-            assert_eq!(res.type_(), &parsed_type);
+            assert!(
+                res.type_().alpha_equiv(&parsed_type),
+                "{} inferred type {}, but expected {}",
+                $expr,
+                res.type_(),
+                $type
+            );
         };
     }
 
@@ -501,9 +627,8 @@ mod tests {
     }
 
     #[test]
-    #[ignore]
     fn generic_function() {
-        assert_type!("fn x = x", "fn x, y -> x");
+        assert_type!("fn x = x", "fn x -> x");
     }
 
     #[test]
@@ -518,6 +643,11 @@ mod tests {
     }
 
     #[test]
+    fn arg_ascriptions() {
+        assert_type!("fn (x: int) = x", "fn int -> int");
+    }
+
+    #[test]
     fn call_concrete_function() {
         assert_type!("(fn x = x + 1) 2", "int");
     }