diff options
Diffstat (limited to 'src/interpreter/mod.rs')
-rw-r--r-- | src/interpreter/mod.rs | 70 |
1 files changed, 39 insertions, 31 deletions
diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index 00421ee90dc8..85a8928cbf9a 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -3,14 +3,13 @@ mod value; pub use self::error::{Error, Result}; pub use self::value::{Function, Value}; -use crate::ast::{ - BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator, -}; +use crate::ast::hir::{Binding, Expr}; +use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator}; use crate::common::env::Env; #[derive(Debug, Default)] pub struct Interpreter<'a> { - env: Env<'a, Value<'a>>, + env: Env<&'a Ident<'a>, Value<'a>>, } impl<'a> Interpreter<'a> { @@ -25,18 +24,19 @@ impl<'a> Interpreter<'a> { .ok_or_else(|| Error::UndefinedVariable(var.to_owned())) } - pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> { - match expr { - Expr::Ident(id) => self.resolve(id), - Expr::Literal(Literal::Int(i)) => Ok((*i).into()), - Expr::UnaryOp { op, rhs } => { + pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> { + let res = match expr { + Expr::Ident(id, _) => self.resolve(id), + Expr::Literal(Literal::Int(i), _) => Ok((*i).into()), + Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()), + Expr::UnaryOp { op, rhs, .. } => { let rhs = self.eval(rhs)?; match op { UnaryOperator::Neg => -rhs, _ => unimplemented!(), } } - Expr::BinaryOp { lhs, op, rhs } => { + Expr::BinaryOp { lhs, op, rhs, .. } => { let lhs = self.eval(lhs)?; let rhs = self.eval(rhs)?; match op { @@ -49,7 +49,7 @@ impl<'a> Interpreter<'a> { BinaryOperator::Neq => todo!(), } } - Expr::Let { bindings, body } => { + Expr::Let { bindings, body, .. } => { self.env.push(); for Binding { ident, body, .. } in bindings { let val = self.eval(body)?; @@ -63,6 +63,7 @@ impl<'a> Interpreter<'a> { condition, then, else_, + .. } => { let condition = self.eval(condition)?; if *(condition.as_type::<bool>()?) { @@ -71,7 +72,7 @@ impl<'a> Interpreter<'a> { self.eval(else_) } } - Expr::Call { ref fun, args } => { + Expr::Call { ref fun, args, .. } => { let fun = self.eval(fun)?; let expected_type = FunctionType { args: args.iter().map(|_| Type::Int).collect(), @@ -94,21 +95,26 @@ impl<'a> Interpreter<'a> { } Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?)) } - Expr::Fun(fun) => Ok(Value::from(value::Function { - // TODO - type_: FunctionType { - args: fun.args.iter().map(|_| Type::Int).collect(), - ret: Box::new(Type::Int), - }, - args: fun.args.iter().map(|arg| arg.to_owned()).collect(), - body: fun.body.to_owned(), - })), - Expr::Ascription { expr, .. } => self.eval(expr), - } + Expr::Fun { args, body, type_ } => { + let type_ = match type_ { + Type::Function(ft) => ft.clone(), + _ => unreachable!("Function expression without function type"), + }; + + Ok(Value::from(value::Function { + // TODO + type_, + args: args.iter().map(|(arg, _)| arg.to_owned()).collect(), + body: (**body).to_owned(), + })) + } + }?; + debug_assert_eq!(&res.type_(), expr.type_()); + Ok(res) } } -pub fn eval<'a>(expr: &'a Expr<'a>) -> Result<Value> { +pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value> { let mut interpreter = Interpreter::new(); interpreter.eval(expr) } @@ -121,17 +127,18 @@ mod tests { use super::*; use BinaryOperator::*; - fn int_lit(i: u64) -> Box<Expr<'static>> { - Box::new(Expr::Literal(Literal::Int(i))) + fn int_lit(i: u64) -> Box<Expr<'static, Type>> { + Box::new(Expr::Literal(Literal::Int(i), Type::Int)) } - fn parse_eval<T>(src: &str) -> T + fn do_eval<T>(src: &str) -> T where for<'a> &'a T: TryFrom<&'a Val<'a>>, T: Clone + TypeOf, { let expr = crate::parser::expr(src).unwrap().1; - let res = eval(&expr).unwrap(); + let hir = crate::tc::typecheck_expr(expr).unwrap(); + let res = eval(&hir).unwrap(); res.as_type::<T>().unwrap().clone() } @@ -141,6 +148,7 @@ mod tests { lhs: int_lit(1), op: Mul, rhs: int_lit(2), + type_: Type::Int, }; let res = eval(&expr).unwrap(); assert_eq!(*res.as_type::<i64>().unwrap(), 2); @@ -148,19 +156,19 @@ mod tests { #[test] fn variable_shadowing() { - let res = parse_eval::<i64>("let x = 1 in (let x = 2 in x) + x"); + let res = do_eval::<i64>("let x = 1 in (let x = 2 in x) + x"); assert_eq!(res, 3); } #[test] fn conditional_with_equals() { - let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4"); + let res = do_eval::<i64>("let x = 1 in if x == 1 then 2 else 4"); assert_eq!(res, 2); } #[test] fn function_call() { - let res = parse_eval::<i64>("let id = fn x = x in id 1"); + let res = do_eval::<i64>("let id = fn x = x in id 1"); assert_eq!(res, 1); } } |