about summary refs log tree commit diff
diff options
context:
space:
mode:
authorVincent Ambo <mail@tazj.in>2021-01-14T20·16+0300
committertazjin <mail@tazj.in>2021-01-14T20·26+0000
commit740a9a3565b3c30d8b667ff159b15d9f455e94b9 (patch)
treeed144ed7adf28143b121d9cc14cc8df53cb8015b
parent39439d59e8e9ddb1e2b7802f3aff092d77de7acf (diff)
feat(tazjin/rlox): Implement support for closures r/2109
Change-Id: I0ffc810807a1a6ec90455a4f2d2bd977833005bd
Reviewed-on: https://cl.tvl.fyi/c/depot/+/2396
Reviewed-by: tazjin <mail@tazj.in>
Tested-by: BuildkiteCI
-rw-r--r--users/tazjin/rlox/src/interpreter.rs59
-rw-r--r--users/tazjin/rlox/src/interpreter/tests.rs22
2 files changed, 57 insertions, 24 deletions
diff --git a/users/tazjin/rlox/src/interpreter.rs b/users/tazjin/rlox/src/interpreter.rs
index 5fdde9adac..a1d246ba2e 100644
--- a/users/tazjin/rlox/src/interpreter.rs
+++ b/users/tazjin/rlox/src/interpreter.rs
@@ -18,14 +18,17 @@ mod tests;
 #[derive(Clone, Debug)]
 pub enum Callable {
     Builtin(&'static dyn builtins::Builtin),
-    Function(Rc<parser::Function>),
+    Function {
+        func: Rc<parser::Function>,
+        closure: Rc<RwLock<Environment>>,
+    },
 }
 
 impl Callable {
     fn arity(&self) -> usize {
         match self {
             Callable::Builtin(builtin) => builtin.arity(),
-            Callable::Function(func) => func.params.len(),
+            Callable::Function { func, .. } => func.params.len(),
         }
     }
 
@@ -33,14 +36,15 @@ impl Callable {
         match self {
             Callable::Builtin(builtin) => builtin.call(args),
 
-            Callable::Function(func) => {
+            Callable::Function { func, closure } => {
                 let mut fn_env: Environment = Default::default();
+                fn_env.enclosing = Some(closure.clone());
 
                 for (param, value) in func.params.iter().zip(args.into_iter()) {
                     fn_env.define(param, value)?;
                 }
 
-                let result = lox.interpret_block(Rc::new(RwLock::new(fn_env)), &func.body);
+                let result = lox.interpret_block(Some(Rc::new(RwLock::new(fn_env))), &func.body);
 
                 match result {
                     // extract returned values if applicable
@@ -90,7 +94,7 @@ impl Value {
 }
 
 #[derive(Debug, Default)]
-struct Environment {
+pub struct Environment {
     enclosing: Option<Rc<RwLock<Environment>>>,
     values: HashMap<String, Value>,
 }
@@ -200,13 +204,6 @@ impl Interpreter {
             .get(var)
     }
 
-    fn set_enclosing(&mut self, parent: Rc<RwLock<Environment>>) {
-        self.env
-            .write()
-            .expect("environment lock is poisoned")
-            .enclosing = Some(parent);
-    }
-
     // Interpreter itself
     pub fn interpret(&mut self, program: &Block) -> Result<Value, Error> {
         let mut value = Value::Literal(Literal::Nil);
@@ -228,7 +225,7 @@ impl Interpreter {
                 Value::Literal(Literal::String(output))
             }
             Statement::Var(var) => return self.interpret_var(var),
-            Statement::Block(block) => return self.interpret_block(Default::default(), block),
+            Statement::Block(block) => return self.interpret_block(None, block),
             Statement::If(if_stmt) => return self.interpret_if(if_stmt),
             Statement::While(while_stmt) => return self.interpret_while(while_stmt),
             Statement::Function(func) => return self.interpret_function(func.clone()),
@@ -253,19 +250,24 @@ impl Interpreter {
         Ok(value)
     }
 
+    /// Interpret the block in the supplied environment. If no
+    /// environment is supplied, a new one is created using the
+    /// current one as its parent.
     fn interpret_block(
         &mut self,
-        env: Rc<RwLock<Environment>>,
+        env: Option<Rc<RwLock<Environment>>>,
         block: &parser::Block,
     ) -> Result<Value, Error> {
-        // Initialise a new environment and point it at the parent
-        // (this is a bit tedious because we need to wrap it in and
-        // out of the Rc).
-        //
-        // TODO(tazjin): Refactor this to use Rc on the interpreter itself.
-        let previous = std::mem::replace(&mut self.env, env);
-        self.set_enclosing(previous.clone());
+        let env = match env {
+            Some(env) => env,
+            None => {
+                let env: Rc<RwLock<Environment>> = Default::default();
+                set_enclosing_env(&env, self.env.clone());
+                env
+            }
+        };
 
+        let previous = std::mem::replace(&mut self.env, env);
         let result = self.interpret(block);
 
         // Swap it back, discarding the child env.
@@ -295,9 +297,12 @@ impl Interpreter {
         Ok(value)
     }
 
-    fn interpret_function(&mut self, stmt: Rc<parser::Function>) -> Result<Value, Error> {
-        let name = stmt.name.clone();
-        let value = Value::Callable(Callable::Function(stmt));
+    fn interpret_function(&mut self, func: Rc<parser::Function>) -> Result<Value, Error> {
+        let name = func.name.clone();
+        let value = Value::Callable(Callable::Function {
+            func,
+            closure: self.env.clone(),
+        });
         self.define_var(&name, value.clone())?;
         Ok(value)
     }
@@ -442,3 +447,9 @@ fn eval_truthy(lit: &Value) -> bool {
         false
     }
 }
+
+fn set_enclosing_env(this: &RwLock<Environment>, parent: Rc<RwLock<Environment>>) {
+    this.write()
+        .expect("environment lock is poisoned")
+        .enclosing = Some(parent);
+}
diff --git a/users/tazjin/rlox/src/interpreter/tests.rs b/users/tazjin/rlox/src/interpreter/tests.rs
index 5bc9f0a0a4..875116593e 100644
--- a/users/tazjin/rlox/src/interpreter/tests.rs
+++ b/users/tazjin/rlox/src/interpreter/tests.rs
@@ -76,3 +76,25 @@ add(1, 2, 3);
 
     assert_eq!(Value::Literal(Literal::Number(6.0)), result);
 }
+
+#[test]
+fn test_closure() {
+    let result = parse_eval(
+        r#"
+fun makeCounter() {
+  var i = 0;
+  fun count() {
+    i = i + 1;
+  }
+
+  return count;
+}
+
+var counter = makeCounter();
+counter(); // "1".
+counter(); // "2".
+"#,
+    );
+
+    assert_eq!(Value::Literal(Literal::Number(2.0)), result);
+}