about summary refs log tree commit diff
path: root/src/tc
diff options
context:
space:
mode:
Diffstat (limited to 'src/tc')
-rw-r--r--src/tc/mod.rs42
1 files changed, 29 insertions, 13 deletions
diff --git a/src/tc/mod.rs b/src/tc/mod.rs
index 4c088c885749..559ac993cc9b 100644
--- a/src/tc/mod.rs
+++ b/src/tc/mod.rs
@@ -305,22 +305,38 @@ impl<'ast> Typechecker<'ast> {
         }
     }
 
-    pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
+    pub(crate) fn tc_decl(
+        &mut self,
+        decl: ast::Decl<'ast>,
+    ) -> Result<Option<hir::Decl<'ast, Type>>> {
         match decl {
             ast::Decl::Fun { name, body } => {
-                let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
+                let mut expr = ast::Expr::Fun(Box::new(body));
+                if let Some(type_) = self.env.resolve(&name) {
+                    expr = ast::Expr::Ascription {
+                        expr: Box::new(expr),
+                        type_: self.finalize_type(type_.clone())?,
+                    };
+                }
+
+                let body = self.tc_expr(expr)?;
                 let type_ = body.type_().clone();
                 self.env.set(name.clone(), type_);
                 match body {
-                    hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
+                    hir::Expr::Fun { args, body, type_ } => Ok(Some(hir::Decl::Fun {
                         name,
                         args,
                         body,
                         type_,
-                    }),
+                    })),
                     _ => unreachable!(),
                 }
             }
+            ast::Decl::Ascription { name, type_ } => {
+                let type_ = self.type_from_ast_type(type_);
+                self.env.set(name.clone(), type_);
+                Ok(None)
+            }
         }
     }
 
@@ -561,15 +577,15 @@ pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
 
 pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
     let mut typechecker = Typechecker::new();
-    decls
-        .into_iter()
-        .map(|decl| {
-            let hir_decl = typechecker.tc_decl(decl)?;
-            let res = typechecker.finalize_decl(hir_decl)?;
-            typechecker.ctx.clear();
-            Ok(res)
-        })
-        .try_collect()
+    let mut res = Vec::with_capacity(decls.len());
+    for decl in decls {
+        if let Some(hir_decl) = typechecker.tc_decl(decl)? {
+            let hir_decl = typechecker.finalize_decl(hir_decl)?;
+            res.push(hir_decl);
+        }
+        typechecker.ctx.clear();
+    }
+    Ok(res)
 }
 
 #[cfg(test)]