about summary refs log tree commit diff
path: root/tvix/eval/src/vm
diff options
context:
space:
mode:
authoredef <edef@edef.eu>2023-12-30T00·52+0000
committeredef <edef@edef.eu>2024-01-16T19·20+0000
commitb624edb2ae0df67f9d70a031563b83a33bb282ef (patch)
treebbf4a6e05f56ee593f6fa70565a3eecdc3b76942 /tvix/eval/src/vm
parent5ac9450395bf88c0716914dd407a5d13e23b49ed (diff)
fix(tvix/eval): lift VM ops over Catchable r/7395
We want to handle bottoms in a consistent fashion. Previously this was
handled by repetitive is_catchable checks, which were not consistently
present.

Change-Id: I9614c479cc6297d1f64efba22b620a26e2a96802
Reviewed-on: https://cl.tvl.fyi/c/depot/+/10485
Reviewed-by: tazjin <tazjin@tvl.su>
Tested-by: BuildkiteCI
Diffstat (limited to 'tvix/eval/src/vm')
-rw-r--r--tvix/eval/src/vm/macros.rs49
-rw-r--r--tvix/eval/src/vm/mod.rs168
2 files changed, 102 insertions, 115 deletions
diff --git a/tvix/eval/src/vm/macros.rs b/tvix/eval/src/vm/macros.rs
index eb386df92fa3..fdb812961172 100644
--- a/tvix/eval/src/vm/macros.rs
+++ b/tvix/eval/src/vm/macros.rs
@@ -36,24 +36,25 @@ macro_rules! arithmetic_op {
 #[macro_export]
 macro_rules! cmp_op {
     ( $vm:ident, $frame:ident, $span:ident, $op:tt ) => {{
-        let b = $vm.stack_pop();
-        let a = $vm.stack_pop();
+        lifted_pop! {
+            $vm(b, a) => {
+                async fn compare(a: Value, b: Value, co: GenCo) -> Result<Value, ErrorKind> {
+                    let a = generators::request_force(&co, a).await;
+                    let b = generators::request_force(&co, b).await;
+                    let span = generators::request_span(&co).await;
+                    let ordering = a.nix_cmp_ordering(b, co, span).await?;
+                    match ordering {
+                        Err(cek) => Ok(Value::Catchable(cek)),
+                        Ok(ordering) => Ok(Value::Bool(cmp_op!(@order $op ordering))),
+                    }
+                }
 
-        async fn compare(a: Value, b: Value, co: GenCo) -> Result<Value, ErrorKind> {
-            let a = generators::request_force(&co, a).await;
-            let b = generators::request_force(&co, b).await;
-            let span = generators::request_span(&co).await;
-            let ordering = a.nix_cmp_ordering(b, co, span).await?;
-            match ordering {
-                Err(cek) => Ok(Value::Catchable(cek)),
-                Ok(ordering) => Ok(Value::Bool(cmp_op!(@order $op ordering))),
+                let gen_span = $frame.current_light_span();
+                $vm.push_call_frame($span, $frame);
+                $vm.enqueue_generator("compare", gen_span, |co| compare(a, b, co));
+                return Ok(false);
             }
         }
-
-        let gen_span = $frame.current_light_span();
-        $vm.push_call_frame($span, $frame);
-        $vm.enqueue_generator("compare", gen_span, |co| compare(a, b, co));
-        return Ok(false);
     }};
 
     (@order < $ordering:expr) => {
@@ -72,3 +73,21 @@ macro_rules! cmp_op {
         matches!($ordering, Ordering::Equal | Ordering::Greater)
     };
 }
+
+#[macro_export]
+macro_rules! lifted_pop {
+    ($vm:ident ($($bind:ident),+) => $body:expr) => {
+        {
+            $(
+                let $bind = $vm.stack_pop();
+            )+
+            $(
+                if $bind.is_catchable() {
+                    $vm.stack.push($bind);
+                    continue;
+                }
+            )+
+            $body
+        }
+    }
+}
diff --git a/tvix/eval/src/vm/mod.rs b/tvix/eval/src/vm/mod.rs
index d9a8a2c411df..398a02cc78c8 100644
--- a/tvix/eval/src/vm/mod.rs
+++ b/tvix/eval/src/vm/mod.rs
@@ -23,6 +23,7 @@ use crate::{
     compiler::GlobalsMap,
     errors::{CatchableErrorKind, Error, ErrorKind, EvalResult},
     io::EvalIO,
+    lifted_pop,
     nix_search_path::NixSearchPath,
     observer::RuntimeObserver,
     opcode::{CodeIdx, Count, JumpOffset, OpCode, StackIdx, UpvalueIdx},
@@ -541,14 +542,8 @@ impl<'o> VM<'o> {
                         ))));
                 }
 
-                OpCode::OpAttrsSelect => {
-                    let key = self.stack_pop();
-                    let attrs = self.stack_pop();
-                    if key.is_catchable() {
-                        self.stack.push(key);
-                    } else if attrs.is_catchable() {
-                        self.stack.push(attrs);
-                    } else {
+                OpCode::OpAttrsSelect => lifted_pop! {
+                    self(key, attrs) => {
                         let key = key.to_str().with_span(&frame, self)?;
                         let attrs = attrs.to_attrs().with_span(&frame, self)?;
 
@@ -565,7 +560,7 @@ impl<'o> VM<'o> {
                             }
                         }
                     }
-                }
+                },
 
                 OpCode::OpJumpIfFalse(JumpOffset(offset)) => {
                     debug_assert!(offset != 0);
@@ -629,16 +624,16 @@ impl<'o> VM<'o> {
                     frame.ip += offset;
                 }
 
-                OpCode::OpEqual => {
-                    let b = self.stack_pop();
-                    let a = self.stack_pop();
-                    let gen_span = frame.current_light_span();
-                    self.push_call_frame(span, frame);
-                    self.enqueue_generator("nix_eq", gen_span.clone(), |co| {
-                        a.nix_eq_owned_genco(b, co, PointerEquality::ForbidAll, gen_span)
-                    });
-                    return Ok(false);
-                }
+                OpCode::OpEqual => lifted_pop! {
+                    self(b, a) => {
+                        let gen_span = frame.current_light_span();
+                        self.push_call_frame(span, frame);
+                        self.enqueue_generator("nix_eq", gen_span.clone(), |co| {
+                            a.nix_eq_owned_genco(b, co, PointerEquality::ForbidAll, gen_span)
+                        });
+                        return Ok(false);
+                    }
+                },
 
                 // These assertion operations error out if the stack
                 // top is not of the expected type. This is necessary
@@ -646,6 +641,7 @@ impl<'o> VM<'o> {
                 // exactly.
                 OpCode::OpAssertBool => {
                     let val = self.stack_peek(0);
+                    // TODO(edef): propagate this into is_bool, since bottom values *are* values of any type
                     if !val.is_catchable() && !val.is_bool() {
                         return frame.error(
                             self,
@@ -659,7 +655,8 @@ impl<'o> VM<'o> {
 
                 OpCode::OpAssertAttrs => {
                     let val = self.stack_peek(0);
-                    if !val.is_attrs() {
+                    // TODO(edef): propagate this into is_attrs, since bottom values *are* values of any type
+                    if !val.is_catchable() && !val.is_attrs() {
                         return frame.error(
                             self,
                             ErrorKind::TypeError {
@@ -672,29 +669,20 @@ impl<'o> VM<'o> {
 
                 OpCode::OpAttrs(Count(count)) => self.run_attrset(&frame, count)?,
 
-                OpCode::OpAttrsUpdate => {
-                    let rhs = self.stack_pop();
-                    let lhs = self.stack_pop();
-                    if lhs.is_catchable() {
-                        self.stack.push(lhs);
-                    } else if rhs.is_catchable() {
-                        self.stack.push(rhs);
-                    } else {
+                OpCode::OpAttrsUpdate => lifted_pop! {
+                    self(rhs, lhs) => {
                         let rhs = rhs.to_attrs().with_span(&frame, self)?;
                         let lhs = lhs.to_attrs().with_span(&frame, self)?;
                         self.stack.push(Value::attrs(lhs.update(*rhs)))
                     }
-                }
+                },
 
-                OpCode::OpInvert => {
-                    let v = self.stack_pop();
-                    if v.is_catchable() {
-                        self.stack.push(v);
-                    } else {
+                OpCode::OpInvert => lifted_pop! {
+                    self(v) => {
                         let v = v.as_bool().with_span(&frame, self)?;
                         self.stack.push(Value::Bool(!v));
                     }
-                }
+                },
 
                 OpCode::OpList(Count(count)) => {
                     let list =
@@ -710,14 +698,8 @@ impl<'o> VM<'o> {
                     }
                 }
 
-                OpCode::OpHasAttr => {
-                    let key = self.stack_pop();
-                    let attrs = self.stack_pop();
-                    if key.is_catchable() {
-                        self.stack.push(key);
-                    } else if attrs.is_catchable() {
-                        self.stack.push(attrs);
-                    } else {
+                OpCode::OpHasAttr => lifted_pop! {
+                    self(key, attrs) => {
                         let key = key.to_str().with_span(&frame, self)?;
                         let result = match attrs {
                             Value::Attrs(attrs) => attrs.contains(key.as_str()),
@@ -729,21 +711,15 @@ impl<'o> VM<'o> {
 
                         self.stack.push(Value::Bool(result));
                     }
-                }
+                },
 
-                OpCode::OpConcat => {
-                    // right hand side, left hand side
-                    match (self.stack_pop(), self.stack_pop()) {
-                        (Value::Catchable(cek), _) | (_, Value::Catchable(cek)) => {
-                            self.stack.push(Value::Catchable(cek));
-                        }
-                        (rhs, lhs) => {
-                            let rhs = rhs.to_list().with_span(&frame, self)?.into_inner();
-                            let lhs = lhs.to_list().with_span(&frame, self)?.into_inner();
-                            self.stack.push(Value::List(NixList::from(lhs + rhs)))
-                        }
+                OpCode::OpConcat => lifted_pop! {
+                    self(rhs, lhs) => {
+                        let rhs = rhs.to_list().with_span(&frame, self)?.into_inner();
+                        let lhs = lhs.to_list().with_span(&frame, self)?.into_inner();
+                        self.stack.push(Value::List(NixList::from(lhs + rhs)))
                     }
-                }
+                },
 
                 OpCode::OpResolveWith => {
                     let ident = self.stack_pop().to_str().with_span(&frame, self)?;
@@ -816,60 +792,52 @@ impl<'o> VM<'o> {
                     }
                 }
 
-                OpCode::OpAdd => {
-                    match (self.stack_pop(), self.stack_pop()) {
-                        (Value::Catchable(cek), _) | (_, Value::Catchable(cek)) => {
-                            self.stack.push(Value::Catchable(cek));
-                        }
+                OpCode::OpAdd => lifted_pop! {
+                    self(b, a) => {
+                        let gen_span = frame.current_light_span();
+                        self.push_call_frame(span, frame);
 
-                        (b, a) => {
-                            let gen_span = frame.current_light_span();
-                            self.push_call_frame(span, frame);
-
-                            // OpAdd can add not just numbers, but also string-like
-                            // things, which requires more VM logic. This operation is
-                            // evaluated in a generator frame.
-                            self.enqueue_generator("add_values", gen_span, |co| {
-                                add_values(co, a, b)
-                            });
-                            return Ok(false);
-                        }
+                        // OpAdd can add not just numbers, but also string-like
+                        // things, which requires more VM logic. This operation is
+                        // evaluated in a generator frame.
+                        self.enqueue_generator("add_values", gen_span, |co| add_values(co, a, b));
+                        return Ok(false);
                     }
-                }
-
-                OpCode::OpSub => {
-                    let b = self.stack_pop();
-                    let a = self.stack_pop();
-                    let result = arithmetic_op!(&a, &b, -).with_span(&frame, self)?;
-                    self.stack.push(result);
-                }
+                },
 
-                OpCode::OpMul => {
-                    let b = self.stack_pop();
-                    let a = self.stack_pop();
-                    let result = arithmetic_op!(&a, &b, *).with_span(&frame, self)?;
-                    self.stack.push(result);
-                }
+                OpCode::OpSub => lifted_pop! {
+                    self(b, a) => {
+                        let result = arithmetic_op!(&a, &b, -).with_span(&frame, self)?;
+                        self.stack.push(result);
+                    }
+                },
 
-                OpCode::OpDiv => {
-                    let b = self.stack_pop();
+                OpCode::OpMul => lifted_pop! {
+                    self(b, a) => {
+                        let result = arithmetic_op!(&a, &b, *).with_span(&frame, self)?;
+                        self.stack.push(result);
+                    }
+                },
 
-                    match b {
-                        Value::Integer(0) => return frame.error(self, ErrorKind::DivisionByZero),
-                        Value::Float(b) if b == 0.0_f64 => {
-                            return frame.error(self, ErrorKind::DivisionByZero)
-                        }
-                        _ => {}
-                    };
+                OpCode::OpDiv => lifted_pop! {
+                    self(b, a) => {
+                        match b {
+                            Value::Integer(0) => return frame.error(self, ErrorKind::DivisionByZero),
+                            Value::Float(b) if b == 0.0_f64 => {
+                                return frame.error(self, ErrorKind::DivisionByZero)
+                            }
+                            _ => {}
+                        };
 
-                    let a = self.stack_pop();
-                    let result = arithmetic_op!(&a, &b, /).with_span(&frame, self)?;
-                    self.stack.push(result);
-                }
+                        let result = arithmetic_op!(&a, &b, /).with_span(&frame, self)?;
+                        self.stack.push(result);
+                    }
+                },
 
                 OpCode::OpNegate => match self.stack_pop() {
                     Value::Integer(i) => self.stack.push(Value::Integer(-i)),
                     Value::Float(f) => self.stack.push(Value::Float(-f)),
+                    Value::Catchable(cex) => self.stack.push(Value::Catchable(cex)),
                     v => {
                         return frame.error(
                             self,