about summary refs log tree commit diff
path: root/src/interpreter
diff options
context:
space:
mode:
authorGriffin Smith <root@gws.fyi>2021-03-14T02·57-0500
committerGriffin Smith <root@gws.fyi>2021-03-14T03·07-0500
commit32a5c0ff0fc58aa6721c1e0ad41950bde2d66744 (patch)
treeef5dcf5234c2a86607ee2f8f30db73bad016e075 /src/interpreter
parentf8beda81fbe8d04883aee71ff4ea078f897c6de4 (diff)
Add the start of a hindley-milner typechecker
The beginning of a parse-don't-validate-based hindley-milner
typechecker, which returns on success an IR where every AST node
trivially knows its own type, and using those types to determine LLVM
types in codegen.
Diffstat (limited to 'src/interpreter')
-rw-r--r--src/interpreter/mod.rs70
-rw-r--r--src/interpreter/value.rs5
2 files changed, 42 insertions, 33 deletions
diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs
index 00421ee90dc8..85a8928cbf9a 100644
--- a/src/interpreter/mod.rs
+++ b/src/interpreter/mod.rs
@@ -3,14 +3,13 @@ mod value;
 
 pub use self::error::{Error, Result};
 pub use self::value::{Function, Value};
-use crate::ast::{
-    BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator,
-};
+use crate::ast::hir::{Binding, Expr};
+use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, Default)]
 pub struct Interpreter<'a> {
-    env: Env<'a, Value<'a>>,
+    env: Env<&'a Ident<'a>, Value<'a>>,
 }
 
 impl<'a> Interpreter<'a> {
@@ -25,18 +24,19 @@ impl<'a> Interpreter<'a> {
             .ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
     }
 
-    pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> {
-        match expr {
-            Expr::Ident(id) => self.resolve(id),
-            Expr::Literal(Literal::Int(i)) => Ok((*i).into()),
-            Expr::UnaryOp { op, rhs } => {
+    pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> {
+        let res = match expr {
+            Expr::Ident(id, _) => self.resolve(id),
+            Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
+            Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
+            Expr::UnaryOp { op, rhs, .. } => {
                 let rhs = self.eval(rhs)?;
                 match op {
                     UnaryOperator::Neg => -rhs,
                     _ => unimplemented!(),
                 }
             }
-            Expr::BinaryOp { lhs, op, rhs } => {
+            Expr::BinaryOp { lhs, op, rhs, .. } => {
                 let lhs = self.eval(lhs)?;
                 let rhs = self.eval(rhs)?;
                 match op {
@@ -49,7 +49,7 @@ impl<'a> Interpreter<'a> {
                     BinaryOperator::Neq => todo!(),
                 }
             }
-            Expr::Let { bindings, body } => {
+            Expr::Let { bindings, body, .. } => {
                 self.env.push();
                 for Binding { ident, body, .. } in bindings {
                     let val = self.eval(body)?;
@@ -63,6 +63,7 @@ impl<'a> Interpreter<'a> {
                 condition,
                 then,
                 else_,
+                ..
             } => {
                 let condition = self.eval(condition)?;
                 if *(condition.as_type::<bool>()?) {
@@ -71,7 +72,7 @@ impl<'a> Interpreter<'a> {
                     self.eval(else_)
                 }
             }
-            Expr::Call { ref fun, args } => {
+            Expr::Call { ref fun, args, .. } => {
                 let fun = self.eval(fun)?;
                 let expected_type = FunctionType {
                     args: args.iter().map(|_| Type::Int).collect(),
@@ -94,21 +95,26 @@ impl<'a> Interpreter<'a> {
                 }
                 Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?))
             }
-            Expr::Fun(fun) => Ok(Value::from(value::Function {
-                // TODO
-                type_: FunctionType {
-                    args: fun.args.iter().map(|_| Type::Int).collect(),
-                    ret: Box::new(Type::Int),
-                },
-                args: fun.args.iter().map(|arg| arg.to_owned()).collect(),
-                body: fun.body.to_owned(),
-            })),
-            Expr::Ascription { expr, .. } => self.eval(expr),
-        }
+            Expr::Fun { args, body, type_ } => {
+                let type_ = match type_ {
+                    Type::Function(ft) => ft.clone(),
+                    _ => unreachable!("Function expression without function type"),
+                };
+
+                Ok(Value::from(value::Function {
+                    // TODO
+                    type_,
+                    args: args.iter().map(|(arg, _)| arg.to_owned()).collect(),
+                    body: (**body).to_owned(),
+                }))
+            }
+        }?;
+        debug_assert_eq!(&res.type_(), expr.type_());
+        Ok(res)
     }
 }
 
-pub fn eval<'a>(expr: &'a Expr<'a>) -> Result<Value> {
+pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value> {
     let mut interpreter = Interpreter::new();
     interpreter.eval(expr)
 }
@@ -121,17 +127,18 @@ mod tests {
     use super::*;
     use BinaryOperator::*;
 
-    fn int_lit(i: u64) -> Box<Expr<'static>> {
-        Box::new(Expr::Literal(Literal::Int(i)))
+    fn int_lit(i: u64) -> Box<Expr<'static, Type>> {
+        Box::new(Expr::Literal(Literal::Int(i), Type::Int))
     }
 
-    fn parse_eval<T>(src: &str) -> T
+    fn do_eval<T>(src: &str) -> T
     where
         for<'a> &'a T: TryFrom<&'a Val<'a>>,
         T: Clone + TypeOf,
     {
         let expr = crate::parser::expr(src).unwrap().1;
-        let res = eval(&expr).unwrap();
+        let hir = crate::tc::typecheck_expr(expr).unwrap();
+        let res = eval(&hir).unwrap();
         res.as_type::<T>().unwrap().clone()
     }
 
@@ -141,6 +148,7 @@ mod tests {
             lhs: int_lit(1),
             op: Mul,
             rhs: int_lit(2),
+            type_: Type::Int,
         };
         let res = eval(&expr).unwrap();
         assert_eq!(*res.as_type::<i64>().unwrap(), 2);
@@ -148,19 +156,19 @@ mod tests {
 
     #[test]
     fn variable_shadowing() {
-        let res = parse_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
+        let res = do_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
         assert_eq!(res, 3);
     }
 
     #[test]
     fn conditional_with_equals() {
-        let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
+        let res = do_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
         assert_eq!(res, 2);
     }
 
     #[test]
     fn function_call() {
-        let res = parse_eval::<i64>("let id = fn x = x in id 1");
+        let res = do_eval::<i64>("let id = fn x = x in id 1");
         assert_eq!(res, 1);
     }
 }
diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs
index 496e9c4230de..5e55825160cd 100644
--- a/src/interpreter/value.rs
+++ b/src/interpreter/value.rs
@@ -6,13 +6,14 @@ use std::rc::Rc;
 use derive_more::{Deref, From, TryInto};
 
 use super::{Error, Result};
-use crate::ast::{Expr, FunctionType, Ident, Type};
+use crate::ast::hir::Expr;
+use crate::ast::{FunctionType, Ident, Type};
 
 #[derive(Debug, Clone)]
 pub struct Function<'a> {
     pub type_: FunctionType,
     pub args: Vec<Ident<'a>>,
-    pub body: Expr<'a>,
+    pub body: Expr<'a, Type>,
 }
 
 #[derive(From, TryInto)]