about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorGriffin Smith <root@gws.fyi>2021-03-13T18·12-0500
committerGriffin Smith <root@gws.fyi>2021-03-13T18·12-0500
commitf8beda81fbe8d04883aee71ff4ea078f897c6de4 (patch)
treead61046d7e86c8a71381ee6b936fcd46ec3a89ac /src
parent3dff189499af1ddd60d8fc128b794d15f1cb19ae (diff)
Allow exprs+bindings to optionally be ascripted
Diffstat (limited to 'src')
-rw-r--r--src/ast/mod.rs33
-rw-r--r--src/codegen/llvm.rs9
-rw-r--r--src/interpreter/mod.rs11
-rw-r--r--src/parser/expr.rs143
-rw-r--r--src/parser/mod.rs2
-rw-r--r--src/parser/type_.rs104
6 files changed, 264 insertions, 38 deletions
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 7f3543271db8..dc22ac3cdb56 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -110,6 +110,23 @@ pub enum Literal {
 }
 
 #[derive(Debug, PartialEq, Eq, Clone)]
+pub struct Binding<'a> {
+    pub ident: Ident<'a>,
+    pub type_: Option<Type>,
+    pub body: Expr<'a>,
+}
+
+impl<'a> Binding<'a> {
+    fn to_owned(&self) -> Binding<'static> {
+        Binding {
+            ident: self.ident.to_owned(),
+            type_: self.type_.clone(),
+            body: self.body.to_owned(),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Expr<'a> {
     Ident(Ident<'a>),
 
@@ -127,7 +144,7 @@ pub enum Expr<'a> {
     },
 
     Let {
-        bindings: Vec<(Ident<'a>, Expr<'a>)>,
+        bindings: Vec<Binding<'a>>,
         body: Box<Expr<'a>>,
     },
 
@@ -143,6 +160,11 @@ pub enum Expr<'a> {
         fun: Box<Expr<'a>>,
         args: Vec<Expr<'a>>,
     },
+
+    Ascription {
+        expr: Box<Expr<'a>>,
+        type_: Type,
+    },
 }
 
 impl<'a> Expr<'a> {
@@ -160,10 +182,7 @@ impl<'a> Expr<'a> {
                 rhs: Box::new((**rhs).to_owned()),
             },
             Expr::Let { bindings, body } => Expr::Let {
-                bindings: bindings
-                    .iter()
-                    .map(|(id, expr)| (id.to_owned(), expr.to_owned()))
-                    .collect(),
+                bindings: bindings.iter().map(|binding| binding.to_owned()).collect(),
                 body: Box::new((**body).to_owned()),
             },
             Expr::If {
@@ -180,6 +199,10 @@ impl<'a> Expr<'a> {
                 fun: Box::new((**fun).to_owned()),
                 args: args.iter().map(|arg| arg.to_owned()).collect(),
             },
+            Expr::Ascription { expr, type_ } => Expr::Ascription {
+                expr: Box::new((**expr).to_owned()),
+                type_: type_.clone(),
+            },
         }
     }
 }
diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs
index 1d1c742a9434..1f4a457cd81b 100644
--- a/src/codegen/llvm.rs
+++ b/src/codegen/llvm.rs
@@ -12,7 +12,7 @@ use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
 use inkwell::IntPredicate;
 use thiserror::Error;
 
-use crate::ast::{BinaryOperator, Decl, Expr, Fun, Ident, Literal, UnaryOperator};
+use crate::ast::{BinaryOperator, Binding, Decl, Expr, Fun, Ident, Literal, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, PartialEq, Eq, Error)]
@@ -137,9 +137,9 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
             }
             Expr::Let { bindings, body } => {
                 self.env.push();
-                for (id, val) in bindings {
-                    let val = self.codegen_expr(val)?;
-                    self.env.set(id, val);
+                for Binding { ident, body, .. } in bindings {
+                    let val = self.codegen_expr(body)?;
+                    self.env.set(ident, val);
                 }
                 let res = self.codegen_expr(body);
                 self.env.pop();
@@ -207,6 +207,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.env.restore(env);
                 Ok(function.into())
             }
+            Expr::Ascription { expr, .. } => self.codegen_expr(expr),
         }
     }
 
diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs
index fc1556d1c292..00421ee90dc8 100644
--- a/src/interpreter/mod.rs
+++ b/src/interpreter/mod.rs
@@ -3,7 +3,9 @@ mod value;
 
 pub use self::error::{Error, Result};
 pub use self::value::{Function, Value};
-use crate::ast::{BinaryOperator, Expr, FunctionType, Ident, Literal, Type, UnaryOperator};
+use crate::ast::{
+    BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator,
+};
 use crate::common::env::Env;
 
 #[derive(Debug, Default)]
@@ -49,9 +51,9 @@ impl<'a> Interpreter<'a> {
             }
             Expr::Let { bindings, body } => {
                 self.env.push();
-                for (id, val) in bindings {
-                    let val = self.eval(val)?;
-                    self.env.set(id, val);
+                for Binding { ident, body, .. } in bindings {
+                    let val = self.eval(body)?;
+                    self.env.set(ident, val);
                 }
                 let res = self.eval(body)?;
                 self.env.pop();
@@ -101,6 +103,7 @@ impl<'a> Interpreter<'a> {
                 args: fun.args.iter().map(|arg| arg.to_owned()).collect(),
                 body: fun.body.to_owned(),
             })),
+            Expr::Ascription { expr, .. } => self.eval(expr),
         }
     }
 }
diff --git a/src/parser/expr.rs b/src/parser/expr.rs
index a42c7a6c765e..2fda3e93fae9 100644
--- a/src/parser/expr.rs
+++ b/src/parser/expr.rs
@@ -1,12 +1,12 @@
 use nom::character::complete::{digit1, multispace0, multispace1};
 use nom::{
-    alt, char, complete, delimited, do_parse, flat_map, many0, map, named, parse_to,
-    separated_list0, separated_list1, tag, tuple,
+    alt, call, char, complete, delimited, do_parse, flat_map, many0, map, named, opt, parse_to,
+    preceded, separated_list0, separated_list1, tag, tuple,
 };
 use pratt::{Affix, Associativity, PrattParser, Precedence};
 
-use crate::ast::{BinaryOperator, Expr, Fun, Ident, Literal, UnaryOperator};
-use crate::parser::ident;
+use crate::ast::{BinaryOperator, Binding, Expr, Fun, Ident, Literal, UnaryOperator};
+use crate::parser::{ident, type_};
 
 #[derive(Debug)]
 enum TokenTree<'a> {
@@ -158,14 +158,20 @@ named!(int(&str) -> Literal, map!(flat_map!(digit1, parse_to!(u64)), Literal::In
 
 named!(literal(&str) -> Expr, map!(alt!(int), Expr::Literal));
 
-named!(binding(&str) -> (Ident, Expr), do_parse!(
+named!(binding(&str) -> Binding, do_parse!(
     multispace0
         >> ident: ident
         >> multispace0
+        >> type_: opt!(preceded!(tuple!(tag!(":"), multispace0), type_))
+        >> multispace0
         >> char!('=')
         >> multispace0
-        >> expr: expr
-        >> (ident, expr)
+        >> body: expr
+        >> (Binding {
+            ident,
+            type_,
+            body
+        })
 ));
 
 named!(let_(&str) -> Expr, do_parse!(
@@ -203,6 +209,25 @@ named!(if_(&str) -> Expr, do_parse! (
 
 named!(ident_expr(&str) -> Expr, map!(ident, Expr::Ident));
 
+fn ascripted<'a>(
+    p: impl Fn(&'a str) -> nom::IResult<&'a str, Expr, nom::error::Error<&'a str>> + 'a,
+) -> impl Fn(&'a str) -> nom::IResult<&str, Expr, nom::error::Error<&'a str>> {
+    move |i| {
+        do_parse!(
+            i,
+            expr: p
+                >> multispace0
+                >> complete!(tag!(":"))
+                >> multispace0
+                >> type_: type_
+                >> (Expr::Ascription {
+                    expr: Box::new(expr),
+                    type_
+                })
+        )
+    }
+}
+
 named!(paren_expr(&str) -> Expr,
        delimited!(complete!(tag!("(")), expr, complete!(tag!(")"))));
 
@@ -251,7 +276,7 @@ named!(call_with_args(&str) -> Expr, do_parse!(
         })
 ));
 
-named!(simple_expr(&str) -> Expr, alt!(
+named!(simple_expr_unascripted(&str) -> Expr, alt!(
     let_ |
     if_ |
     fun_expr |
@@ -259,17 +284,24 @@ named!(simple_expr(&str) -> Expr, alt!(
     ident_expr
 ));
 
+named!(simple_expr(&str) -> Expr, alt!(
+    call!(ascripted(simple_expr_unascripted)) |
+    simple_expr_unascripted
+));
+
 named!(pub expr(&str) -> Expr, alt!(
     no_arg_call |
     call_with_args |
     map!(token_tree, |tt| {
         ExprParser.parse(&mut tt.into_iter()).unwrap()
     }) |
-    simple_expr));
+    simple_expr
+));
 
 #[cfg(test)]
 pub(crate) mod tests {
     use super::*;
+    use crate::ast::Type;
     use std::convert::TryFrom;
     use BinaryOperator::*;
     use Expr::{BinaryOp, If, Let, UnaryOp};
@@ -374,18 +406,20 @@ pub(crate) mod tests {
             res,
             Let {
                 bindings: vec![
-                    (
-                        Ident::try_from("x").unwrap(),
-                        Expr::Literal(Literal::Int(1))
-                    ),
-                    (
-                        Ident::try_from("y").unwrap(),
-                        Expr::BinaryOp {
+                    Binding {
+                        ident: Ident::try_from("x").unwrap(),
+                        type_: None,
+                        body: Expr::Literal(Literal::Int(1))
+                    },
+                    Binding {
+                        ident: Ident::try_from("y").unwrap(),
+                        type_: None,
+                        body: Expr::BinaryOp {
                             lhs: ident_expr("x"),
                             op: Mul,
                             rhs: Box::new(Expr::Literal(Literal::Int(7)))
                         }
-                    )
+                    }
                 ],
                 body: Box::new(Expr::BinaryOp {
                     lhs: Box::new(Expr::BinaryOp {
@@ -448,10 +482,11 @@ pub(crate) mod tests {
             res,
             Expr::Call {
                 fun: Box::new(Expr::Let {
-                    bindings: vec![(
-                        Ident::try_from("x").unwrap(),
-                        Expr::Literal(Literal::Int(1))
-                    )],
+                    bindings: vec![Binding {
+                        ident: Ident::try_from("x").unwrap(),
+                        type_: None,
+                        body: Expr::Literal(Literal::Int(1))
+                    }],
                     body: ident_expr("x")
                 }),
                 args: vec![Expr::Literal(Literal::Int(2))]
@@ -465,13 +500,14 @@ pub(crate) mod tests {
         assert_eq!(
             res,
             Expr::Let {
-                bindings: vec![(
-                    Ident::try_from("id").unwrap(),
-                    Expr::Fun(Box::new(Fun {
+                bindings: vec![Binding {
+                    ident: Ident::try_from("id").unwrap(),
+                    type_: None,
+                    body: Expr::Fun(Box::new(Fun {
                         args: vec![Ident::try_from("x").unwrap()],
                         body: *ident_expr("x")
                     }))
-                )],
+                }],
                 body: Box::new(Expr::Call {
                     fun: ident_expr("id"),
                     args: vec![Expr::Literal(Literal::Int(1))],
@@ -479,4 +515,61 @@ pub(crate) mod tests {
             }
         );
     }
+
+    mod ascriptions {
+        use super::*;
+
+        #[test]
+        fn bare_ascription() {
+            let res = test_parse!(expr, "1: float");
+            assert_eq!(
+                res,
+                Expr::Ascription {
+                    expr: Box::new(Expr::Literal(Literal::Int(1))),
+                    type_: Type::Float
+                }
+            )
+        }
+
+        #[test]
+        fn fn_body_ascription() {
+            let res = test_parse!(expr, "let const_1 = fn x = 1: int in const_1 2");
+            assert_eq!(
+                res,
+                Expr::Let {
+                    bindings: vec![Binding {
+                        ident: Ident::try_from("const_1").unwrap(),
+                        type_: None,
+                        body: Expr::Fun(Box::new(Fun {
+                            args: vec![Ident::try_from("x").unwrap()],
+                            body: Expr::Ascription {
+                                expr: Box::new(Expr::Literal(Literal::Int(1))),
+                                type_: Type::Int,
+                            }
+                        }))
+                    }],
+                    body: Box::new(Expr::Call {
+                        fun: ident_expr("const_1"),
+                        args: vec![Expr::Literal(Literal::Int(2))]
+                    })
+                }
+            )
+        }
+
+        #[test]
+        fn let_binding_ascripted() {
+            let res = test_parse!(expr, "let x: int = 1 in x");
+            assert_eq!(
+                res,
+                Expr::Let {
+                    bindings: vec![Binding {
+                        ident: Ident::try_from("x").unwrap(),
+                        type_: Some(Type::Int),
+                        body: Expr::Literal(Literal::Int(1))
+                    }],
+                    body: ident_expr("x")
+                }
+            )
+        }
+    }
 }
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 3e162d449320..af7dff6ff213 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -5,9 +5,11 @@ use nom::{alt, char, complete, do_parse, many0, named, separated_list0, tag};
 #[macro_use]
 mod macros;
 mod expr;
+mod type_;
 
 use crate::ast::{Decl, Fun, Ident};
 pub use expr::expr;
+pub use type_::type_;
 
 pub type Error = nom::Err<nom::error::Error<String>>;
 
diff --git a/src/parser/type_.rs b/src/parser/type_.rs
new file mode 100644
index 000000000000..076df7d6bd55
--- /dev/null
+++ b/src/parser/type_.rs
@@ -0,0 +1,104 @@
+use nom::character::complete::{multispace0, multispace1};
+use nom::{alt, delimited, do_parse, map, named, opt, separated_list0, tag, terminated, tuple};
+
+use crate::ast::{FunctionType, Type};
+
+named!(function_type(&str) -> Type, do_parse!(
+    tag!("fn")
+        >> multispace1
+        >> args: map!(opt!(terminated!(separated_list0!(
+            tuple!(
+                multispace0,
+                tag!(","),
+                multispace0
+            ),
+            type_
+        ), multispace1)), |args| args.unwrap_or_default())
+        >> tag!("->")
+        >> multispace1
+        >> ret: type_
+        >> (Type::Function(FunctionType {
+            args,
+            ret: Box::new(ret)
+        }))
+));
+
+named!(pub type_(&str) -> Type, alt!(
+    tag!("int") => { |_| Type::Int } |
+    tag!("float") => { |_| Type::Float } |
+    tag!("bool") => { |_| Type::Bool } |
+    function_type |
+    delimited!(
+        tuple!(tag!("("), multispace0),
+        type_,
+        tuple!(tag!(")"), multispace0)
+    )
+));
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn simple_types() {
+        assert_eq!(test_parse!(type_, "int"), Type::Int);
+        assert_eq!(test_parse!(type_, "float"), Type::Float);
+        assert_eq!(test_parse!(type_, "bool"), Type::Bool);
+    }
+
+    #[test]
+    fn no_arg_fn_type() {
+        assert_eq!(
+            test_parse!(type_, "fn -> int"),
+            Type::Function(FunctionType {
+                args: vec![],
+                ret: Box::new(Type::Int)
+            })
+        );
+    }
+
+    #[test]
+    fn fn_type_with_args() {
+        assert_eq!(
+            test_parse!(type_, "fn int, bool -> int"),
+            Type::Function(FunctionType {
+                args: vec![Type::Int, Type::Bool],
+                ret: Box::new(Type::Int)
+            })
+        );
+    }
+
+    #[test]
+    fn fn_taking_fn() {
+        assert_eq!(
+            test_parse!(type_, "fn fn int, bool -> bool, float -> float"),
+            Type::Function(FunctionType {
+                args: vec![
+                    Type::Function(FunctionType {
+                        args: vec![Type::Int, Type::Bool],
+                        ret: Box::new(Type::Bool)
+                    }),
+                    Type::Float
+                ],
+                ret: Box::new(Type::Float)
+            })
+        )
+    }
+
+    #[test]
+    fn parenthesized() {
+        assert_eq!(
+            test_parse!(type_, "fn (fn int, bool -> bool), float -> float"),
+            Type::Function(FunctionType {
+                args: vec![
+                    Type::Function(FunctionType {
+                        args: vec![Type::Int, Type::Bool],
+                        ret: Box::new(Type::Bool)
+                    }),
+                    Type::Float
+                ],
+                ret: Box::new(Type::Float)
+            })
+        )
+    }
+}