about summary refs log tree commit diff
path: root/src/interpreter/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/interpreter/mod.rs')
-rw-r--r--src/interpreter/mod.rs70
1 files changed, 39 insertions, 31 deletions
diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs
index 00421ee90dc8..85a8928cbf9a 100644
--- a/src/interpreter/mod.rs
+++ b/src/interpreter/mod.rs
@@ -3,14 +3,13 @@ mod value;
 
 pub use self::error::{Error, Result};
 pub use self::value::{Function, Value};
-use crate::ast::{
-    BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator,
-};
+use crate::ast::hir::{Binding, Expr};
+use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, Default)]
 pub struct Interpreter<'a> {
-    env: Env<'a, Value<'a>>,
+    env: Env<&'a Ident<'a>, Value<'a>>,
 }
 
 impl<'a> Interpreter<'a> {
@@ -25,18 +24,19 @@ impl<'a> Interpreter<'a> {
             .ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
     }
 
-    pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> {
-        match expr {
-            Expr::Ident(id) => self.resolve(id),
-            Expr::Literal(Literal::Int(i)) => Ok((*i).into()),
-            Expr::UnaryOp { op, rhs } => {
+    pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> {
+        let res = match expr {
+            Expr::Ident(id, _) => self.resolve(id),
+            Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
+            Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
+            Expr::UnaryOp { op, rhs, .. } => {
                 let rhs = self.eval(rhs)?;
                 match op {
                     UnaryOperator::Neg => -rhs,
                     _ => unimplemented!(),
                 }
             }
-            Expr::BinaryOp { lhs, op, rhs } => {
+            Expr::BinaryOp { lhs, op, rhs, .. } => {
                 let lhs = self.eval(lhs)?;
                 let rhs = self.eval(rhs)?;
                 match op {
@@ -49,7 +49,7 @@ impl<'a> Interpreter<'a> {
                     BinaryOperator::Neq => todo!(),
                 }
             }
-            Expr::Let { bindings, body } => {
+            Expr::Let { bindings, body, .. } => {
                 self.env.push();
                 for Binding { ident, body, .. } in bindings {
                     let val = self.eval(body)?;
@@ -63,6 +63,7 @@ impl<'a> Interpreter<'a> {
                 condition,
                 then,
                 else_,
+                ..
             } => {
                 let condition = self.eval(condition)?;
                 if *(condition.as_type::<bool>()?) {
@@ -71,7 +72,7 @@ impl<'a> Interpreter<'a> {
                     self.eval(else_)
                 }
             }
-            Expr::Call { ref fun, args } => {
+            Expr::Call { ref fun, args, .. } => {
                 let fun = self.eval(fun)?;
                 let expected_type = FunctionType {
                     args: args.iter().map(|_| Type::Int).collect(),
@@ -94,21 +95,26 @@ impl<'a> Interpreter<'a> {
                 }
                 Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?))
             }
-            Expr::Fun(fun) => Ok(Value::from(value::Function {
-                // TODO
-                type_: FunctionType {
-                    args: fun.args.iter().map(|_| Type::Int).collect(),
-                    ret: Box::new(Type::Int),
-                },
-                args: fun.args.iter().map(|arg| arg.to_owned()).collect(),
-                body: fun.body.to_owned(),
-            })),
-            Expr::Ascription { expr, .. } => self.eval(expr),
-        }
+            Expr::Fun { args, body, type_ } => {
+                let type_ = match type_ {
+                    Type::Function(ft) => ft.clone(),
+                    _ => unreachable!("Function expression without function type"),
+                };
+
+                Ok(Value::from(value::Function {
+                    // TODO
+                    type_,
+                    args: args.iter().map(|(arg, _)| arg.to_owned()).collect(),
+                    body: (**body).to_owned(),
+                }))
+            }
+        }?;
+        debug_assert_eq!(&res.type_(), expr.type_());
+        Ok(res)
     }
 }
 
-pub fn eval<'a>(expr: &'a Expr<'a>) -> Result<Value> {
+pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value> {
     let mut interpreter = Interpreter::new();
     interpreter.eval(expr)
 }
@@ -121,17 +127,18 @@ mod tests {
     use super::*;
     use BinaryOperator::*;
 
-    fn int_lit(i: u64) -> Box<Expr<'static>> {
-        Box::new(Expr::Literal(Literal::Int(i)))
+    fn int_lit(i: u64) -> Box<Expr<'static, Type>> {
+        Box::new(Expr::Literal(Literal::Int(i), Type::Int))
     }
 
-    fn parse_eval<T>(src: &str) -> T
+    fn do_eval<T>(src: &str) -> T
     where
         for<'a> &'a T: TryFrom<&'a Val<'a>>,
         T: Clone + TypeOf,
     {
         let expr = crate::parser::expr(src).unwrap().1;
-        let res = eval(&expr).unwrap();
+        let hir = crate::tc::typecheck_expr(expr).unwrap();
+        let res = eval(&hir).unwrap();
         res.as_type::<T>().unwrap().clone()
     }
 
@@ -141,6 +148,7 @@ mod tests {
             lhs: int_lit(1),
             op: Mul,
             rhs: int_lit(2),
+            type_: Type::Int,
         };
         let res = eval(&expr).unwrap();
         assert_eq!(*res.as_type::<i64>().unwrap(), 2);
@@ -148,19 +156,19 @@ mod tests {
 
     #[test]
     fn variable_shadowing() {
-        let res = parse_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
+        let res = do_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
         assert_eq!(res, 3);
     }
 
     #[test]
     fn conditional_with_equals() {
-        let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
+        let res = do_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
         assert_eq!(res, 2);
     }
 
     #[test]
     fn function_call() {
-        let res = parse_eval::<i64>("let id = fn x = x in id 1");
+        let res = do_eval::<i64>("let id = fn x = x in id 1");
         assert_eq!(res, 1);
     }
 }