about summary refs log tree commit diff
path: root/src/interpreter
diff options
context:
space:
mode:
Diffstat (limited to 'src/interpreter')
-rw-r--r--src/interpreter/error.rs16
-rw-r--r--src/interpreter/mod.rs125
-rw-r--r--src/interpreter/value.rs134
3 files changed, 275 insertions, 0 deletions
diff --git a/src/interpreter/error.rs b/src/interpreter/error.rs
new file mode 100644
index 000000000000..e0299d180553
--- /dev/null
+++ b/src/interpreter/error.rs
@@ -0,0 +1,16 @@
+use std::result;
+
+use thiserror::Error;
+
+use crate::ast::{Ident, Type};
+
+#[derive(Debug, PartialEq, Eq, Error)]
+pub enum Error {
+    #[error("Undefined variable {0}")]
+    UndefinedVariable(Ident<'static>),
+
+    #[error("Unexpected type {actual}, expected type {expected}")]
+    InvalidType { actual: Type, expected: Type },
+}
+
+pub type Result<T> = result::Result<T, Error>;
diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs
new file mode 100644
index 000000000000..adff3568c2c3
--- /dev/null
+++ b/src/interpreter/mod.rs
@@ -0,0 +1,125 @@
+mod error;
+mod value;
+
+pub use self::error::{Error, Result};
+pub use self::value::Value;
+use crate::ast::{BinaryOperator, Expr, Ident, Literal, UnaryOperator};
+use crate::common::env::Env;
+
+#[derive(Debug, Default)]
+pub struct Interpreter<'a> {
+    env: Env<'a, Value>,
+}
+
+impl<'a> Interpreter<'a> {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    fn resolve(&self, var: &'a Ident<'a>) -> Result<Value> {
+        self.env
+            .resolve(var)
+            .cloned()
+            .ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
+    }
+
+    pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value> {
+        match expr {
+            Expr::Ident(id) => self.resolve(id),
+            Expr::Literal(Literal::Int(i)) => Ok((*i).into()),
+            Expr::UnaryOp { op, rhs } => {
+                let rhs = self.eval(rhs)?;
+                match op {
+                    UnaryOperator::Neg => -rhs,
+                    _ => unimplemented!(),
+                }
+            }
+            Expr::BinaryOp { lhs, op, rhs } => {
+                let lhs = self.eval(lhs)?;
+                let rhs = self.eval(rhs)?;
+                match op {
+                    BinaryOperator::Add => lhs + rhs,
+                    BinaryOperator::Sub => lhs - rhs,
+                    BinaryOperator::Mul => lhs * rhs,
+                    BinaryOperator::Div => lhs / rhs,
+                    BinaryOperator::Pow => todo!(),
+                    BinaryOperator::Equ => Ok(lhs.eq(&rhs).into()),
+                    BinaryOperator::Neq => todo!(),
+                }
+            }
+            Expr::Let { bindings, body } => {
+                self.env.push();
+                for (id, val) in bindings {
+                    let val = self.eval(val)?;
+                    self.env.set(id, val);
+                }
+                let res = self.eval(body)?;
+                self.env.pop();
+                Ok(res)
+            }
+            Expr::If {
+                condition,
+                then,
+                else_,
+            } => {
+                let condition = self.eval(condition)?;
+                if *(condition.into_type::<bool>()?) {
+                    self.eval(then)
+                } else {
+                    self.eval(else_)
+                }
+            }
+        }
+    }
+}
+
+pub fn eval<'a>(expr: &'a Expr<'a>) -> Result<Value> {
+    let mut interpreter = Interpreter::new();
+    interpreter.eval(expr)
+}
+
+#[cfg(test)]
+mod tests {
+    use std::convert::TryFrom;
+
+    use super::value::{TypeOf, Val};
+    use super::*;
+    use BinaryOperator::*;
+
+    fn int_lit(i: u64) -> Box<Expr<'static>> {
+        Box::new(Expr::Literal(Literal::Int(i)))
+    }
+
+    fn parse_eval<T>(src: &str) -> T
+    where
+        for<'a> &'a T: TryFrom<&'a Val>,
+        T: Clone + TypeOf,
+    {
+        let expr = crate::parser::expr(src).unwrap().1;
+        let res = eval(&expr).unwrap();
+        res.into_type::<T>().unwrap().clone()
+    }
+
+    #[test]
+    fn simple_addition() {
+        let expr = Expr::BinaryOp {
+            lhs: int_lit(1),
+            op: Mul,
+            rhs: int_lit(2),
+        };
+        let res = eval(&expr).unwrap();
+        assert_eq!(*res.into_type::<i64>().unwrap(), 2);
+    }
+
+    #[test]
+    fn variable_shadowing() {
+        let res = parse_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
+        assert_eq!(res, 3);
+    }
+
+    #[test]
+    fn conditional_with_equals() {
+        let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
+        assert_eq!(res, 2);
+    }
+}
diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs
new file mode 100644
index 000000000000..69e4d4ffeb96
--- /dev/null
+++ b/src/interpreter/value.rs
@@ -0,0 +1,134 @@
+use std::convert::TryFrom;
+use std::fmt::{self, Display};
+use std::ops::{Add, Div, Mul, Neg, Sub};
+use std::rc::Rc;
+
+use derive_more::{Deref, From, TryInto};
+
+use super::{Error, Result};
+use crate::ast::Type;
+
+#[derive(Debug, PartialEq, From, TryInto)]
+#[try_into(owned, ref)]
+pub enum Val {
+    Int(i64),
+    Float(f64),
+    Bool(bool),
+}
+
+impl From<u64> for Val {
+    fn from(i: u64) -> Self {
+        Self::from(i as i64)
+    }
+}
+
+impl Display for Val {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Val::Int(x) => x.fmt(f),
+            Val::Float(x) => x.fmt(f),
+            Val::Bool(x) => x.fmt(f),
+        }
+    }
+}
+
+impl Val {
+    pub fn type_(&self) -> Type {
+        match self {
+            Val::Int(_) => Type::Int,
+            Val::Float(_) => Type::Float,
+            Val::Bool(_) => Type::Bool,
+        }
+    }
+
+    pub fn into_type<'a, T>(&'a self) -> Result<&'a T>
+    where
+        T: TypeOf + 'a + Clone,
+        &'a T: TryFrom<&'a Self>,
+    {
+        <&T>::try_from(self).map_err(|_| Error::InvalidType {
+            actual: self.type_(),
+            expected: <T as TypeOf>::type_of(),
+        })
+    }
+}
+
+#[derive(Debug, PartialEq, Clone, Deref)]
+pub struct Value(Rc<Val>);
+
+impl Display for Value {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
+impl<T> From<T> for Value
+where
+    Val: From<T>,
+{
+    fn from(x: T) -> Self {
+        Self(Rc::new(x.into()))
+    }
+}
+
+impl Neg for Value {
+    type Output = Result<Value>;
+
+    fn neg(self) -> Self::Output {
+        Ok((-self.into_type::<i64>()?).into())
+    }
+}
+
+impl Add for Value {
+    type Output = Result<Value>;
+
+    fn add(self, rhs: Self) -> Self::Output {
+        Ok((self.into_type::<i64>()? + rhs.into_type::<i64>()?).into())
+    }
+}
+
+impl Sub for Value {
+    type Output = Result<Value>;
+
+    fn sub(self, rhs: Self) -> Self::Output {
+        Ok((self.into_type::<i64>()? - rhs.into_type::<i64>()?).into())
+    }
+}
+
+impl Mul for Value {
+    type Output = Result<Value>;
+
+    fn mul(self, rhs: Self) -> Self::Output {
+        Ok((self.into_type::<i64>()? * rhs.into_type::<i64>()?).into())
+    }
+}
+
+impl Div for Value {
+    type Output = Result<Value>;
+
+    fn div(self, rhs: Self) -> Self::Output {
+        Ok((self.into_type::<f64>()? / rhs.into_type::<f64>()?).into())
+    }
+}
+
+pub trait TypeOf {
+    fn type_of() -> Type;
+}
+
+impl TypeOf for i64 {
+    fn type_of() -> Type {
+        Type::Int
+    }
+}
+
+impl TypeOf for bool {
+    fn type_of() -> Type {
+        Type::Bool
+    }
+}
+
+impl TypeOf for f64 {
+    fn type_of() -> Type {
+        Type::Float
+    }
+}