about summary refs log tree commit diff
path: root/src/interpreter
diff options
context:
space:
mode:
Diffstat (limited to 'src/interpreter')
-rw-r--r--src/interpreter/mod.rs56
-rw-r--r--src/interpreter/value.rs101
2 files changed, 120 insertions, 37 deletions
diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs
index adff3568c2c3..fc1556d1c292 100644
--- a/src/interpreter/mod.rs
+++ b/src/interpreter/mod.rs
@@ -2,13 +2,13 @@ mod error;
 mod value;
 
 pub use self::error::{Error, Result};
-pub use self::value::Value;
-use crate::ast::{BinaryOperator, Expr, Ident, Literal, UnaryOperator};
+pub use self::value::{Function, Value};
+use crate::ast::{BinaryOperator, Expr, FunctionType, Ident, Literal, Type, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, Default)]
 pub struct Interpreter<'a> {
-    env: Env<'a, Value>,
+    env: Env<'a, Value<'a>>,
 }
 
 impl<'a> Interpreter<'a> {
@@ -16,14 +16,14 @@ impl<'a> Interpreter<'a> {
         Self::default()
     }
 
-    fn resolve(&self, var: &'a Ident<'a>) -> Result<Value> {
+    fn resolve(&self, var: &'a Ident<'a>) -> Result<Value<'a>> {
         self.env
             .resolve(var)
             .cloned()
             .ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
     }
 
-    pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value> {
+    pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> {
         match expr {
             Expr::Ident(id) => self.resolve(id),
             Expr::Literal(Literal::Int(i)) => Ok((*i).into()),
@@ -63,12 +63,44 @@ impl<'a> Interpreter<'a> {
                 else_,
             } => {
                 let condition = self.eval(condition)?;
-                if *(condition.into_type::<bool>()?) {
+                if *(condition.as_type::<bool>()?) {
                     self.eval(then)
                 } else {
                     self.eval(else_)
                 }
             }
+            Expr::Call { ref fun, args } => {
+                let fun = self.eval(fun)?;
+                let expected_type = FunctionType {
+                    args: args.iter().map(|_| Type::Int).collect(),
+                    ret: Box::new(Type::Int),
+                };
+
+                let Function {
+                    args: function_args,
+                    body,
+                    ..
+                } = fun.as_function(expected_type)?;
+                let arg_values = function_args.iter().zip(
+                    args.iter()
+                        .map(|v| self.eval(v))
+                        .collect::<Result<Vec<_>>>()?,
+                );
+                let mut interpreter = Interpreter::new();
+                for (arg_name, arg_value) in arg_values {
+                    interpreter.env.set(arg_name, arg_value);
+                }
+                Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?))
+            }
+            Expr::Fun(fun) => Ok(Value::from(value::Function {
+                // TODO
+                type_: FunctionType {
+                    args: fun.args.iter().map(|_| Type::Int).collect(),
+                    ret: Box::new(Type::Int),
+                },
+                args: fun.args.iter().map(|arg| arg.to_owned()).collect(),
+                body: fun.body.to_owned(),
+            })),
         }
     }
 }
@@ -92,12 +124,12 @@ mod tests {
 
     fn parse_eval<T>(src: &str) -> T
     where
-        for<'a> &'a T: TryFrom<&'a Val>,
+        for<'a> &'a T: TryFrom<&'a Val<'a>>,
         T: Clone + TypeOf,
     {
         let expr = crate::parser::expr(src).unwrap().1;
         let res = eval(&expr).unwrap();
-        res.into_type::<T>().unwrap().clone()
+        res.as_type::<T>().unwrap().clone()
     }
 
     #[test]
@@ -108,7 +140,7 @@ mod tests {
             rhs: int_lit(2),
         };
         let res = eval(&expr).unwrap();
-        assert_eq!(*res.into_type::<i64>().unwrap(), 2);
+        assert_eq!(*res.as_type::<i64>().unwrap(), 2);
     }
 
     #[test]
@@ -122,4 +154,10 @@ mod tests {
         let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
         assert_eq!(res, 2);
     }
+
+    #[test]
+    fn function_call() {
+        let res = parse_eval::<i64>("let id = fn x = x in id 1");
+        assert_eq!(res, 1);
+    }
 }
diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs
index 69e4d4ffeb96..496e9c4230de 100644
--- a/src/interpreter/value.rs
+++ b/src/interpreter/value.rs
@@ -6,108 +6,153 @@ use std::rc::Rc;
 use derive_more::{Deref, From, TryInto};
 
 use super::{Error, Result};
-use crate::ast::Type;
+use crate::ast::{Expr, FunctionType, Ident, Type};
 
-#[derive(Debug, PartialEq, From, TryInto)]
+#[derive(Debug, Clone)]
+pub struct Function<'a> {
+    pub type_: FunctionType,
+    pub args: Vec<Ident<'a>>,
+    pub body: Expr<'a>,
+}
+
+#[derive(From, TryInto)]
 #[try_into(owned, ref)]
-pub enum Val {
+pub enum Val<'a> {
     Int(i64),
     Float(f64),
     Bool(bool),
+    Function(Function<'a>),
+}
+
+impl<'a> fmt::Debug for Val<'a> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Val::Int(x) => f.debug_tuple("Int").field(x).finish(),
+            Val::Float(x) => f.debug_tuple("Float").field(x).finish(),
+            Val::Bool(x) => f.debug_tuple("Bool").field(x).finish(),
+            Val::Function(Function { type_, .. }) => {
+                f.debug_struct("Function").field("type_", type_).finish()
+            }
+        }
+    }
 }
 
-impl From<u64> for Val {
+impl<'a> PartialEq for Val<'a> {
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Val::Int(x), Val::Int(y)) => x == y,
+            (Val::Float(x), Val::Float(y)) => x == y,
+            (Val::Bool(x), Val::Bool(y)) => x == y,
+            (Val::Function(_), Val::Function(_)) => false,
+            (_, _) => false,
+        }
+    }
+}
+
+impl<'a> From<u64> for Val<'a> {
     fn from(i: u64) -> Self {
         Self::from(i as i64)
     }
 }
 
-impl Display for Val {
+impl<'a> Display for Val<'a> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
             Val::Int(x) => x.fmt(f),
             Val::Float(x) => x.fmt(f),
             Val::Bool(x) => x.fmt(f),
+            Val::Function(Function { type_, .. }) => write!(f, "<{}>", type_),
         }
     }
 }
 
-impl Val {
+impl<'a> Val<'a> {
     pub fn type_(&self) -> Type {
         match self {
             Val::Int(_) => Type::Int,
             Val::Float(_) => Type::Float,
             Val::Bool(_) => Type::Bool,
+            Val::Function(Function { type_, .. }) => Type::Function(type_.clone()),
         }
     }
 
-    pub fn into_type<'a, T>(&'a self) -> Result<&'a T>
+    pub fn as_type<'b, T>(&'b self) -> Result<&'b T>
     where
-        T: TypeOf + 'a + Clone,
-        &'a T: TryFrom<&'a Self>,
+        T: TypeOf + 'b + Clone,
+        &'b T: TryFrom<&'b Self>,
     {
         <&T>::try_from(self).map_err(|_| Error::InvalidType {
             actual: self.type_(),
             expected: <T as TypeOf>::type_of(),
         })
     }
+
+    pub fn as_function<'b>(&'b self, function_type: FunctionType) -> Result<&'b Function<'a>> {
+        match self {
+            Val::Function(f) if f.type_ == function_type => Ok(&f),
+            _ => Err(Error::InvalidType {
+                actual: self.type_(),
+                expected: Type::Function(function_type),
+            }),
+        }
+    }
 }
 
 #[derive(Debug, PartialEq, Clone, Deref)]
-pub struct Value(Rc<Val>);
+pub struct Value<'a>(Rc<Val<'a>>);
 
-impl Display for Value {
+impl<'a> Display for Value<'a> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         self.0.fmt(f)
     }
 }
 
-impl<T> From<T> for Value
+impl<'a, T> From<T> for Value<'a>
 where
-    Val: From<T>,
+    Val<'a>: From<T>,
 {
     fn from(x: T) -> Self {
         Self(Rc::new(x.into()))
     }
 }
 
-impl Neg for Value {
-    type Output = Result<Value>;
+impl<'a> Neg for Value<'a> {
+    type Output = Result<Value<'a>>;
 
     fn neg(self) -> Self::Output {
-        Ok((-self.into_type::<i64>()?).into())
+        Ok((-self.as_type::<i64>()?).into())
     }
 }
 
-impl Add for Value {
-    type Output = Result<Value>;
+impl<'a> Add for Value<'a> {
+    type Output = Result<Value<'a>>;
 
     fn add(self, rhs: Self) -> Self::Output {
-        Ok((self.into_type::<i64>()? + rhs.into_type::<i64>()?).into())
+        Ok((self.as_type::<i64>()? + rhs.as_type::<i64>()?).into())
     }
 }
 
-impl Sub for Value {
-    type Output = Result<Value>;
+impl<'a> Sub for Value<'a> {
+    type Output = Result<Value<'a>>;
 
     fn sub(self, rhs: Self) -> Self::Output {
-        Ok((self.into_type::<i64>()? - rhs.into_type::<i64>()?).into())
+        Ok((self.as_type::<i64>()? - rhs.as_type::<i64>()?).into())
     }
 }
 
-impl Mul for Value {
-    type Output = Result<Value>;
+impl<'a> Mul for Value<'a> {
+    type Output = Result<Value<'a>>;
 
     fn mul(self, rhs: Self) -> Self::Output {
-        Ok((self.into_type::<i64>()? * rhs.into_type::<i64>()?).into())
+        Ok((self.as_type::<i64>()? * rhs.as_type::<i64>()?).into())
     }
 }
 
-impl Div for Value {
-    type Output = Result<Value>;
+impl<'a> Div for Value<'a> {
+    type Output = Result<Value<'a>>;
 
     fn div(self, rhs: Self) -> Self::Output {
-        Ok((self.into_type::<f64>()? / rhs.into_type::<f64>()?).into())
+        Ok((self.as_type::<f64>()? / rhs.as_type::<f64>()?).into())
     }
 }