about summary refs log tree commit diff
path: root/users/glittershark/achilles/src
diff options
context:
space:
mode:
Diffstat (limited to 'users/glittershark/achilles/src')
-rw-r--r--users/glittershark/achilles/src/ast/hir.rs7
-rw-r--r--users/glittershark/achilles/src/tc/mod.rs90
2 files changed, 67 insertions, 30 deletions
diff --git a/users/glittershark/achilles/src/ast/hir.rs b/users/glittershark/achilles/src/ast/hir.rs
index 0212b3dbcdbb..691b9607e7e6 100644
--- a/users/glittershark/achilles/src/ast/hir.rs
+++ b/users/glittershark/achilles/src/ast/hir.rs
@@ -228,6 +228,13 @@ pub enum Decl<'a, T> {
 }
 
 impl<'a, T> Decl<'a, T> {
+    pub fn name(&self) -> &Ident<'a> {
+        match self {
+            Decl::Fun { name, .. } => name,
+            Decl::Extern { name, .. } => name,
+        }
+    }
+
     pub fn type_(&self) -> Option<&T> {
         match self {
             Decl::Fun { type_, .. } => Some(type_),
diff --git a/users/glittershark/achilles/src/tc/mod.rs b/users/glittershark/achilles/src/tc/mod.rs
index 4bca52733bc8..f3cb40a8b010 100644
--- a/users/glittershark/achilles/src/tc/mod.rs
+++ b/users/glittershark/achilles/src/tc/mod.rs
@@ -262,10 +262,10 @@ impl<'ast> Typechecker<'ast> {
                 let body = self.tc_expr(body)?;
                 self.env.pop();
                 Ok(hir::Expr::Fun {
-                    type_: self.universalize(
-                        args.iter().map(|(_, ty)| ty.clone()).collect(),
-                        body.type_().clone(),
-                    ),
+                    type_: Type::Fun {
+                        args: args.iter().map(|(_, ty)| ty.clone()).collect(),
+                        ret: Box::new(body.type_().clone()),
+                    },
                     args,
                     body: Box::new(body),
                 })
@@ -319,9 +319,11 @@ impl<'ast> Typechecker<'ast> {
                     };
                 }
 
+                self.env.push();
                 let body = self.tc_expr(expr)?;
                 let type_ = body.type_().clone();
                 self.env.set(name.clone(), type_);
+                self.env.pop();
                 match body {
                     hir::Expr::Fun { args, body, type_ } => Ok(Some(hir::Decl::Fun {
                         name,
@@ -365,27 +367,6 @@ impl<'ast> Typechecker<'ast> {
         Type::Univ(self.fresh_tv())
     }
 
-    #[allow(clippy::redundant_closure)] // https://github.com/rust-lang/rust-clippy/issues/6903
-    fn universalize(&mut self, args: Vec<Type>, ret: Type) -> Type {
-        let mut vars = HashMap::new();
-        let mut universalize_type = move |ty| match ty {
-            Type::Exist(tv) if self.resolve_tv(tv).is_none() => vars
-                .entry(tv)
-                .or_insert_with(|| {
-                    let ty = self.fresh_univ();
-                    self.ctx.insert(tv, ty.clone());
-                    ty
-                })
-                .clone(),
-            _ => ty,
-        };
-
-        Type::Fun {
-            args: args.into_iter().map(|t| universalize_type(t)).collect(),
-            ret: Box::new(universalize_type(ret)),
-        }
-    }
-
     fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
         match (ty1, ty2) {
             (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
@@ -462,10 +443,15 @@ impl<'ast> Typechecker<'ast> {
     }
 
     fn finalize_decl(
-        &self,
+        &mut self,
         decl: hir::Decl<'ast, Type>,
     ) -> Result<hir::Decl<'ast, ast::Type<'ast>>> {
-        decl.traverse_type(|ty| self.finalize_type(ty))
+        let res = decl.traverse_type(|ty| self.finalize_type(ty))?;
+        if let Some(type_) = res.type_() {
+            let ty = self.type_from_ast_type(type_.clone());
+            self.env.set(res.name().clone(), ty);
+        }
+        Ok(res)
     }
 
     fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> {
@@ -541,7 +527,12 @@ impl<'ast> Typechecker<'ast> {
             .get_by_right(&tv)
             .map(Ident::to_owned)
             .unwrap_or_else(|| {
-                let name = vars.1.make_name();
+                let name = loop {
+                    let name = vars.1.make_name();
+                    if !vars.0.contains_left(&name) {
+                        break name;
+                    }
+                };
                 vars.0.insert_no_overwrite(name.clone(), tv).unwrap();
                 name
             })
@@ -619,6 +610,26 @@ mod tests {
                 $type
             );
         };
+
+        (toplevel($program: expr), $($decl: ident => $type: expr),+ $(,)?) => {{
+            use crate::parser::{toplevel, type_};
+            let program = test_parse!(toplevel, $program);
+            let res = typecheck_toplevel(program).unwrap_or_else(|e| panic!("{}", e));
+            $(
+            let parsed_type = test_parse!(type_, $type);
+            let ident = Ident::try_from(::std::stringify!($decl)).unwrap();
+            let decl = res.iter().find(|decl| {
+                matches!(decl, crate::ast::hir::Decl::Fun { name, .. } if name == &ident)
+            }).unwrap_or_else(|| panic!("Could not find declaration for {}", ident));
+            assert!(
+                decl.type_().unwrap().alpha_equiv(&parsed_type),
+                "inferred type {} for {}, but expected {}",
+                decl.type_().unwrap(),
+                ident,
+                $type
+            );
+            )+
+        }};
     }
 
     macro_rules! assert_type_error {
@@ -656,8 +667,27 @@ mod tests {
     }
 
     #[test]
-    fn generic_function() {
-        assert_type!("fn x = x", "fn x -> x");
+    fn call_let_bound_generic() {
+        assert_type!("let id = fn x = x in id 1", "int");
+    }
+
+    #[test]
+    fn universal_ascripted_let() {
+        assert_type!("let id: fn a -> a = fn x = x in id 1", "int");
+    }
+
+    #[test]
+    fn call_generic_function_toplevel() {
+        assert_type!(
+            toplevel(
+                "ty id : fn a -> a
+                 fn id x = x
+
+                 fn main = id 0"
+            ),
+            main => "fn -> int",
+            id => "fn a -> a",
+        );
     }
 
     #[test]