about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAdam Joseph <adam@westernsemico.com>2022-12-09T14·27+0300
committertazjin <tazjin@tvl.su>2022-12-25T18·17+0000
commit67d508f2ece710714ce8abf6f7deba1fd2440487 (patch)
treea41e2838e5d740ce7577f367fdca81ded2a3470e
parent4cda236c0c4513e4be9668ede727a8aac5ba1223 (diff)
refactor(tvix/eval): non-recursive thunk forcing r/5485
Introduces continuation-passing-based trampolining of thunk forcing to
avoid recursing when forcing deeply nested expressions.

This is required for evaluating large expressions.

This change was extracted out of cl/7362.

Co-authored-by: Vincent Ambo <tazjin@tvl.su>
Co-authored-by: Griffin Smith <grfn@gws.fyi>
Change-Id: Ifc1747e712663684b2fff53095de62b8459a47f3
Reviewed-on: https://cl.tvl.fyi/c/depot/+/7551
Reviewed-by: grfn <grfn@gws.fyi>
Tested-by: BuildkiteCI
Reviewed-by: tazjin <tazjin@tvl.su>
-rw-r--r--tvix/eval/src/value/thunk.rs144
-rw-r--r--tvix/eval/src/vm.rs166
2 files changed, 243 insertions, 67 deletions
diff --git a/tvix/eval/src/value/thunk.rs b/tvix/eval/src/value/thunk.rs
index 680cc9b1fb..c7cdfa2441 100644
--- a/tvix/eval/src/value/thunk.rs
+++ b/tvix/eval/src/value/thunk.rs
@@ -30,7 +30,7 @@ use crate::{
     spans::LightSpan,
     upvalues::Upvalues,
     value::{Builtin, Closure},
-    vm::VM,
+    vm::{Trampoline, TrampolineAction, VM},
     Value,
 };
 
@@ -84,10 +84,6 @@ impl Thunk {
         })))
     }
 
-    /// Create a new thunk from suspended Rust code.
-    ///
-    /// The suspended code will be executed and expected to return a
-    /// value whenever the thunk is forced like any other thunk.
     pub fn new_suspended_native(
         native: Rc<Box<dyn Fn(&mut VM) -> Result<Value, ErrorKind>>>,
     ) -> Self {
@@ -103,7 +99,7 @@ impl Thunk {
             None,
             move |v: Vec<Value>, vm: &mut VM| {
                 // sanity check that only the dummy argument was popped
-                assert_eq!(v.len(), 1);
+                assert!(v.len() == 1);
                 assert!(matches!(v[0], Value::Null));
                 native(vm)
             },
@@ -127,41 +123,103 @@ impl Thunk {
         })))
     }
 
+    /// Force a thunk from a context that can't handle trampoline
+    /// continuations, eg outside the VM's normal execution loop.  Calling
+    /// `force_trampoline()` instead should be preferred whenever possible.
+    pub fn force(&self, vm: &mut VM) -> Result<(), ErrorKind> {
+        if self.is_forced() {
+            return Ok(());
+        }
+        vm.push(Value::Thunk(self.clone()));
+        let mut trampoline = Self::force_trampoline(vm)?;
+        loop {
+            match trampoline.action {
+                None => (),
+                Some(TrampolineAction::EnterFrame {
+                    lambda,
+                    upvalues,
+                    arg_count,
+                    light_span: _,
+                }) => vm.enter_frame(lambda, upvalues, arg_count)?,
+            }
+            match trampoline.continuation {
+                None => break (),
+                Some(cont) => {
+                    trampoline = cont(vm)?;
+                    continue;
+                }
+            }
+        }
+        vm.pop();
+        Ok(())
+    }
+
     /// Evaluate the content of a thunk, potentially repeatedly, until a
     /// non-thunk value is returned.
     ///
     /// This will change the existing thunk (and thus all references to it,
     /// providing memoization) through interior mutability. In case of nested
     /// thunks, the intermediate thunk representations are replaced.
-    pub fn force(&self, vm: &mut VM) -> Result<(), ErrorKind> {
+    ///
+    /// The thunk to be forced should be at the top of the VM stack,
+    /// and will be left there (but possibly partially forced) when
+    /// this function returns.
+    pub fn force_trampoline(vm: &mut VM) -> Result<Trampoline, ErrorKind> {
+        match vm.pop() {
+            Value::Thunk(thunk) => thunk.force_trampoline_self(vm),
+            v => {
+                vm.push(v);
+                Ok(Trampoline::default())
+            }
+        }
+    }
+
+    fn force_trampoline_self(&self, vm: &mut VM) -> Result<Trampoline, ErrorKind> {
         loop {
-            let mut thunk_mut = self.0.borrow_mut();
+            if !self.is_suspended() {
+                let thunk = self.0.borrow();
+                match *thunk {
+                    ThunkRepr::Evaluated(Value::Thunk(ref inner_thunk)) => {
+                        let inner_repr = inner_thunk.0.borrow().clone();
+                        drop(thunk);
+                        self.0.replace(inner_repr);
+                    }
 
-            match *thunk_mut {
-                ThunkRepr::Evaluated(Value::Thunk(ref inner_thunk)) => {
-                    let inner_repr = inner_thunk.0.borrow().clone();
-                    *thunk_mut = inner_repr;
+                    ThunkRepr::Evaluated(ref v) => {
+                        vm.push(v.clone());
+                        return Ok(Trampoline::default());
+                    }
+                    ThunkRepr::Blackhole => return Err(ErrorKind::InfiniteRecursion),
+                    _ => panic!("impossible"),
                 }
-
-                ThunkRepr::Evaluated(_) => return Ok(()),
-                ThunkRepr::Blackhole => return Err(ErrorKind::InfiniteRecursion),
-
-                ThunkRepr::Suspended { .. } => {
-                    if let ThunkRepr::Suspended {
+            } else {
+                match self.0.replace(ThunkRepr::Blackhole) {
+                    ThunkRepr::Suspended {
                         lambda,
                         upvalues,
                         light_span,
-                    } = std::mem::replace(&mut *thunk_mut, ThunkRepr::Blackhole)
-                    {
-                        drop(thunk_mut);
-                        vm.enter_frame(lambda, upvalues, 0).map_err(|e| {
-                            ErrorKind::ThunkForce(Box::new(Error {
-                                span: light_span.span(),
-                                ..e
-                            }))
-                        })?;
-                        (*self.0.borrow_mut()) = ThunkRepr::Evaluated(vm.pop())
+                    } => {
+                        let self_clone = self.clone();
+                        return Ok(Trampoline {
+                            action: Some(TrampolineAction::EnterFrame {
+                                lambda,
+                                upvalues: upvalues.clone(),
+                                arg_count: 0,
+                                light_span: light_span.clone(),
+                            }),
+                            continuation: Some(Box::new(move |vm| {
+                                let should_be_blackhole =
+                                    self_clone.0.replace(ThunkRepr::Evaluated(vm.pop()));
+                                assert!(matches!(should_be_blackhole, ThunkRepr::Blackhole));
+                                vm.push(Value::Thunk(self_clone));
+                                return Self::force_trampoline(vm).map_err(|kind| Error {
+                                    kind,
+                                    span: light_span.span(),
+                                });
+                            })),
+                        });
                     }
+                    _ => panic!("impossible"),
                 }
             }
         }
@@ -175,6 +233,20 @@ impl Thunk {
         matches!(*self.0.borrow(), ThunkRepr::Evaluated(_))
     }
 
+    pub fn is_suspended(&self) -> bool {
+        matches!(*self.0.borrow(), ThunkRepr::Suspended { .. })
+    }
+
+    /// Returns true if forcing this thunk will not change it.
+    pub fn is_forced(&self) -> bool {
+        match *self.0.borrow() {
+            ThunkRepr::Blackhole => panic!("is_forced() called on a blackholed thunk"),
+            ThunkRepr::Evaluated(Value::Thunk(_)) => false,
+            ThunkRepr::Evaluated(_) => true,
+            _ => false,
+        }
+    }
+
     /// Returns a reference to the inner evaluated value of a thunk.
     /// It is an error to call this on a thunk that has not been
     /// forced, or is not otherwise known to be fully evaluated.
@@ -183,7 +255,21 @@ impl Thunk {
     // API too much.
     pub fn value(&self) -> Ref<Value> {
         Ref::map(self.0.borrow(), |thunk| match thunk {
-            ThunkRepr::Evaluated(value) => value,
+            ThunkRepr::Evaluated(value) => {
+                /*
+                #[cfg(debug_assertions)]
+                if matches!(
+                    value,
+                    Value::Closure(Closure {
+                        is_finalised: false,
+                        ..
+                    })
+                ) {
+                    panic!("Thunk::value called on an unfinalised closure");
+                }
+                */
+                return value;
+            }
             ThunkRepr::Blackhole => panic!("Thunk::value called on a black-holed thunk"),
             ThunkRepr::Suspended { .. } => panic!("Thunk::value called on a suspended thunk"),
         })
diff --git a/tvix/eval/src/vm.rs b/tvix/eval/src/vm.rs
index b074bd4224..6c0d1157ec 100644
--- a/tvix/eval/src/vm.rs
+++ b/tvix/eval/src/vm.rs
@@ -18,6 +18,52 @@ use crate::{
     warnings::{EvalWarning, WarningKind},
 };
 
+/// Representation of a VM continuation;
+/// see: https://en.wikipedia.org/wiki/Continuation-passing_style#CPS_in_Haskell
+type Continuation = Box<dyn FnOnce(&mut VM) -> EvalResult<Trampoline>>;
+
+/// A description of how to continue evaluation of a thunk when returned to by the VM
+///
+/// This struct is used when forcing thunks to avoid stack-based recursion, which for deeply nested
+/// evaluation can easily overflow the stack.
+#[must_use = "this `Trampoline` may be a continuation request, which should be handled"]
+#[derive(Default)]
+pub struct Trampoline {
+    /// The action to perform upon return to the trampoline
+    pub action: Option<TrampolineAction>,
+
+    /// The continuation to execute after the action has completed
+    pub continuation: Option<Continuation>,
+}
+
+impl Trampoline {
+    /// Add the execution of a new [`Continuation`] to the existing continuation
+    /// of this `Trampoline`, returning the resulting `Trampoline`.
+    pub fn append_to_continuation(self, f: Continuation) -> Self {
+        Trampoline {
+            action: self.action,
+            continuation: match self.continuation {
+                None => Some(f),
+                Some(f0) => Some(Box::new(move |vm| {
+                    let trampoline = f0(vm)?;
+                    Ok(trampoline.append_to_continuation(f))
+                })),
+            },
+        }
+    }
+}
+
+/// Description of an action to perform upon return to a [`Trampoline`] by the VM
+pub enum TrampolineAction {
+    /// Enter a new stack frame
+    EnterFrame {
+        lambda: Rc<Lambda>,
+        upvalues: Rc<Upvalues>,
+        light_span: LightSpan,
+        arg_count: usize,
+    },
+}
+
 struct CallFrame {
     /// The lambda currently being executed.
     lambda: Rc<Lambda>,
@@ -32,6 +78,8 @@ struct CallFrame {
 
     /// Stack offset, i.e. the frames "view" into the VM's full stack.
     stack_offset: usize,
+
+    continuation: Option<Continuation>,
 }
 
 impl CallFrame {
@@ -324,7 +372,6 @@ impl<'o> VM<'o> {
         Ok(res)
     }
 
-    #[inline(always)]
     fn tail_call_value(&mut self, callable: Value) -> EvalResult<()> {
         match callable {
             Value::Builtin(builtin) => self.call_builtin(builtin),
@@ -362,8 +409,8 @@ impl<'o> VM<'o> {
         }
     }
 
-    /// Execute the given lambda in this VM's context, returning its
-    /// value after its stack frame completes.
+    /// Execute the given lambda in this VM's context, leaving the
+    /// computed value on its stack after the frame completes.
     pub fn enter_frame(
         &mut self,
         lambda: Rc<Lambda>,
@@ -378,10 +425,33 @@ impl<'o> VM<'o> {
             upvalues,
             ip: CodeIdx(0),
             stack_offset: self.stack.len() - arg_count,
+            continuation: None,
         };
 
+        let starting_frames_depth = self.frames.len();
         self.frames.push(frame);
-        let result = self.run();
+
+        let result = loop {
+            let op = self.inc_ip();
+
+            self.observer
+                .observe_execute_op(self.frame().ip, &op, &self.stack);
+
+            let res = self.run_op(op);
+
+            let mut retrampoline: Option<Continuation> = None;
+
+            // we need to pop the frame before checking `res` for an
+            // error in order to implement `tryEval` correctly.
+            if self.frame().ip.0 == self.chunk().code.len() {
+                let frame = self.frames.pop();
+                retrampoline = frame.and_then(|frame| frame.continuation);
+            }
+            self.trampoline_loop(res?, retrampoline)?;
+            if self.frames.len() == starting_frames_depth {
+                break Ok(());
+            }
+        };
 
         self.observer
             .observe_exit_frame(self.frames.len() + 1, &self.stack);
@@ -389,35 +459,53 @@ impl<'o> VM<'o> {
         result
     }
 
-    /// Run the VM's current call frame to completion.
-    ///
-    /// On successful return, the top of the stack is the value that
-    /// the frame evaluated to. The frame itself is popped off. It is
-    /// up to the caller to consume the value.
-    fn run(&mut self) -> EvalResult<()> {
+    fn trampoline_loop(
+        &mut self,
+        mut trampoline: Trampoline,
+        mut retrampoline: Option<Continuation>,
+    ) -> EvalResult<()> {
         loop {
-            // Break the loop if this call frame has already run to
-            // completion, pop it off, and return the value to the
-            // caller.
-            if self.frame().ip.0 == self.chunk().code.len() {
-                self.frames.pop();
-                return Ok(());
+            if let Some(TrampolineAction::EnterFrame {
+                lambda,
+                upvalues,
+                arg_count,
+                light_span: _,
+            }) = trampoline.action
+            {
+                let frame = CallFrame {
+                    lambda,
+                    upvalues,
+                    ip: CodeIdx(0),
+                    stack_offset: self.stack.len() - arg_count,
+                    continuation: match retrampoline {
+                        None => trampoline.continuation,
+                        Some(retrampoline) => match trampoline.continuation {
+                            None => None,
+                            Some(cont) => Some(Box::new(|vm| {
+                                Ok(cont(vm)?.append_to_continuation(retrampoline))
+                            })),
+                        },
+                    },
+                };
+                self.frames.push(frame);
+                break;
             }
 
-            let op = self.inc_ip();
-
-            self.observer
-                .observe_execute_op(self.frame().ip, &op, &self.stack);
-
-            let res = self.run_op(op);
-
-            if self.frame().ip.0 == self.chunk().code.len() {
-                self.frames.pop();
-                return res;
-            } else {
-                res?;
+            match trampoline.continuation {
+                None => {
+                    if let Some(cont) = retrampoline.take() {
+                        trampoline = cont(self)?;
+                    } else {
+                        break;
+                    }
+                }
+                Some(cont) => {
+                    trampoline = cont(self)?;
+                    continue;
+                }
             }
         }
+        Ok(())
     }
 
     pub(crate) fn nix_eq(
@@ -428,7 +516,8 @@ impl<'o> VM<'o> {
     ) -> EvalResult<bool> {
         self.push(v1);
         self.push(v2);
-        self.nix_op_eq(allow_top_level_pointer_equality_on_functions_and_thunks)?;
+        let res = self.nix_op_eq(allow_top_level_pointer_equality_on_functions_and_thunks);
+        self.trampoline_loop(res?, None)?;
         match self.pop() {
             Value::Bool(b) => Ok(b),
             v => panic!("run_op(OpEqual) left a non-boolean on the stack: {v:#?}"),
@@ -438,7 +527,7 @@ impl<'o> VM<'o> {
     pub(crate) fn nix_op_eq(
         &mut self,
         allow_top_level_pointer_equality_on_functions_and_thunks: bool,
-    ) -> EvalResult<()> {
+    ) -> EvalResult<Trampoline> {
         // This bit gets set to `true` (if it isn't already) as soon
         // as we start comparing the contents of two
         // {lists,attrsets} -- but *not* the contents of two thunks.
@@ -566,10 +655,10 @@ impl<'o> VM<'o> {
         };
         self.pop_then_drop(numpairs * 2);
         self.push(Value::Bool(res));
-        Ok(())
+        Ok(Trampoline::default())
     }
 
-    fn run_op(&mut self, op: OpCode) -> EvalResult<()> {
+    pub(crate) fn run_op(&mut self, op: OpCode) -> EvalResult<Trampoline> {
         match op {
             OpCode::OpConstant(idx) => {
                 let c = self.chunk()[idx].clone();
@@ -918,14 +1007,15 @@ impl<'o> VM<'o> {
             }
 
             OpCode::OpForce => {
-                let mut value = self.pop();
+                let value = self.pop();
 
                 if let Value::Thunk(thunk) = value {
-                    fallible!(self, thunk.force(self));
-                    value = thunk.value().clone();
+                    self.push(Value::Thunk(thunk));
+                    let trampoline = fallible!(self, Thunk::force_trampoline(self));
+                    return Ok(trampoline);
+                } else {
+                    self.push(value);
                 }
-
-                self.push(value);
             }
 
             OpCode::OpFinalise(StackIdx(idx)) => {
@@ -953,7 +1043,7 @@ impl<'o> VM<'o> {
             }
         }
 
-        Ok(())
+        Ok(Trampoline::default())
     }
 
     fn run_attrset(&mut self, count: usize) -> EvalResult<()> {