about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast/mod.rs1
-rw-r--r--src/parser/mod.rs30
-rw-r--r--src/tc/mod.rs42
3 files changed, 59 insertions, 14 deletions
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 1884ba69f43c..3a2261aeda23 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -266,6 +266,7 @@ impl<'a> Fun<'a> {
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Decl<'a> {
     Fun { name: Ident<'a>, body: Fun<'a> },
+    Ascription { name: Ident<'a>, type_: Type<'a> },
 }
 
 ////
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 8599ccabfc23..dd7874aff853 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -98,7 +98,20 @@ named!(fun_decl(&str) -> Decl, do_parse!(
         })
 ));
 
+named!(ascription_decl(&str) -> Decl, do_parse!(
+    name: ident
+        >> multispace0
+        >> complete!(char!(':'))
+        >> multispace0
+        >> type_: type_
+        >> (Decl::Ascription {
+            name,
+            type_
+        })
+));
+
 named!(pub decl(&str) -> Decl, alt!(
+    ascription_decl |
     fun_decl
 ));
 
@@ -108,7 +121,7 @@ named!(pub toplevel(&str) -> Vec<Decl>, terminated!(many0!(decl), multispace0));
 mod tests {
     use std::convert::TryInto;
 
-    use crate::ast::{BinaryOperator, Expr, Literal, Type};
+    use crate::ast::{BinaryOperator, Expr, FunctionType, Literal, Type};
 
     use super::*;
     use expr::tests::ident_expr;
@@ -166,4 +179,19 @@ mod tests {
         );
         assert_eq!(res.len(), 3);
     }
+
+    #[test]
+    fn top_level_ascription() {
+        let res = test_parse!(toplevel, "id : fn a -> a");
+        assert_eq!(
+            res,
+            vec![Decl::Ascription {
+                name: "id".try_into().unwrap(),
+                type_: Type::Function(FunctionType {
+                    args: vec![Type::Var("a".try_into().unwrap())],
+                    ret: Box::new(Type::Var("a".try_into().unwrap()))
+                })
+            }]
+        )
+    }
 }
diff --git a/src/tc/mod.rs b/src/tc/mod.rs
index 4c088c885749..559ac993cc9b 100644
--- a/src/tc/mod.rs
+++ b/src/tc/mod.rs
@@ -305,22 +305,38 @@ impl<'ast> Typechecker<'ast> {
         }
     }
 
-    pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
+    pub(crate) fn tc_decl(
+        &mut self,
+        decl: ast::Decl<'ast>,
+    ) -> Result<Option<hir::Decl<'ast, Type>>> {
         match decl {
             ast::Decl::Fun { name, body } => {
-                let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
+                let mut expr = ast::Expr::Fun(Box::new(body));
+                if let Some(type_) = self.env.resolve(&name) {
+                    expr = ast::Expr::Ascription {
+                        expr: Box::new(expr),
+                        type_: self.finalize_type(type_.clone())?,
+                    };
+                }
+
+                let body = self.tc_expr(expr)?;
                 let type_ = body.type_().clone();
                 self.env.set(name.clone(), type_);
                 match body {
-                    hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
+                    hir::Expr::Fun { args, body, type_ } => Ok(Some(hir::Decl::Fun {
                         name,
                         args,
                         body,
                         type_,
-                    }),
+                    })),
                     _ => unreachable!(),
                 }
             }
+            ast::Decl::Ascription { name, type_ } => {
+                let type_ = self.type_from_ast_type(type_);
+                self.env.set(name.clone(), type_);
+                Ok(None)
+            }
         }
     }
 
@@ -561,15 +577,15 @@ pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
 
 pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
     let mut typechecker = Typechecker::new();
-    decls
-        .into_iter()
-        .map(|decl| {
-            let hir_decl = typechecker.tc_decl(decl)?;
-            let res = typechecker.finalize_decl(hir_decl)?;
-            typechecker.ctx.clear();
-            Ok(res)
-        })
-        .try_collect()
+    let mut res = Vec::with_capacity(decls.len());
+    for decl in decls {
+        if let Some(hir_decl) = typechecker.tc_decl(decl)? {
+            let hir_decl = typechecker.finalize_decl(hir_decl)?;
+            res.push(hir_decl);
+        }
+        typechecker.ctx.clear();
+    }
+    Ok(res)
 }
 
 #[cfg(test)]