about summary refs log tree commit diff
path: root/tvix/eval/src/vm.rs
diff options
context:
space:
mode:
authorAdam Joseph <adam@westernsemico.com>2022-11-22T04·04-0800
committerAdam Joseph <adam@westernsemico.com>2022-11-27T19·10+0000
commitdad8a7cbffbb0fac850e081e564eb09c67dd2fca (patch)
tree8b22a1affbe3c9175b68012caa5b849dd920c14e /tvix/eval/src/vm.rs
parented9aa0c32a5c14768f887b9c237ab80d7d6be254 (diff)
feat(tvix/eval): non-recursive implementation of nix_eq() r/5345
This passes all the function/thunk-pointer-equality tests in
cl/7369.

Change-Id: Ib47535ba2fc77a4f1c2cc2fd23d3a879e21d8b4c
Signed-off-by: Adam Joseph <adam@westernsemico.com>
Reviewed-on: https://cl.tvl.fyi/c/depot/+/7358
Tested-by: BuildkiteCI
Reviewed-by: tazjin <tazjin@tvl.su>
Diffstat (limited to 'tvix/eval/src/vm.rs')
-rw-r--r--tvix/eval/src/vm.rs129
1 files changed, 122 insertions, 7 deletions
diff --git a/tvix/eval/src/vm.rs b/tvix/eval/src/vm.rs
index 41640a32b3c1..f7c9a9dd7871 100644
--- a/tvix/eval/src/vm.rs
+++ b/tvix/eval/src/vm.rs
@@ -190,6 +190,10 @@ impl<'o> VM<'o> {
         self.stack.pop().expect("runtime stack empty")
     }
 
+    pub fn pop_then_drop(&mut self, num_items: usize) {
+        self.stack.truncate(self.stack.len() - num_items);
+    }
+
     pub fn push(&mut self, value: Value) {
         self.stack.push(value)
     }
@@ -392,6 +396,123 @@ impl<'o> VM<'o> {
         }
     }
 
+    pub(crate) fn nix_eq(
+        &mut self,
+        v1: Value,
+        v2: Value,
+        allow_top_level_pointer_equality_on_functions_and_thunks: bool,
+    ) -> EvalResult<bool> {
+        self.push(v1);
+        self.push(v2);
+        self.nix_op_eq(allow_top_level_pointer_equality_on_functions_and_thunks)?;
+        match self.pop() {
+            Value::Bool(b) => Ok(b),
+            v => panic!("run_op(OpEqual) left a non-boolean on the stack: {v:#?}"),
+        }
+    }
+
+    pub(crate) fn nix_op_eq(
+        &mut self,
+        allow_top_level_pointer_equality_on_functions_and_thunks: bool,
+    ) -> EvalResult<()> {
+        // This bit gets set to `true` (if it isn't already) as soon
+        // as we start comparing the contents of two
+        // {lists,attrsets} -- but *not* the contents of two thunks.
+        // See tvix/docs/value-pointer-equality.md for details.
+        let mut allow_pointer_equality_on_functions_and_thunks =
+            allow_top_level_pointer_equality_on_functions_and_thunks;
+
+        let mut numpairs: usize = 1;
+        let res = 'outer: loop {
+            if numpairs == 0 {
+                break true;
+            } else {
+                numpairs -= 1;
+            }
+            let v2 = self.pop();
+            let v1 = self.pop();
+            let v2 = match v2 {
+                Value::Thunk(thunk) => {
+                    if allow_top_level_pointer_equality_on_functions_and_thunks {
+                        if let Value::Thunk(t1) = &v1 {
+                            if t1.ptr_eq(&thunk) {
+                                continue;
+                            }
+                        }
+                    }
+                    fallible!(self, thunk.force(self));
+                    thunk.value().clone()
+                }
+                v => v,
+            };
+            let v1 = match v1 {
+                Value::Thunk(thunk) => {
+                    fallible!(self, thunk.force(self));
+                    thunk.value().clone()
+                }
+                v => v,
+            };
+            match (v1, v2) {
+                (Value::List(l1), Value::List(l2)) => {
+                    allow_pointer_equality_on_functions_and_thunks = true;
+                    if l1.ptr_eq(&l2) {
+                        continue;
+                    }
+                    if l1.len() != l2.len() {
+                        break false;
+                    }
+                    for (vi1, vi2) in l1.into_iter().zip(l2.into_iter()) {
+                        self.stack.push(vi1);
+                        self.stack.push(vi2);
+                        numpairs += 1;
+                    }
+                }
+                (_, Value::List(_)) => break false,
+                (Value::List(_), _) => break false,
+
+                (Value::Attrs(a1), Value::Attrs(a2)) => {
+                    if allow_pointer_equality_on_functions_and_thunks {
+                        if Rc::ptr_eq(&a1, &a2) {
+                            continue;
+                        }
+                    }
+                    allow_pointer_equality_on_functions_and_thunks = true;
+                    let iter1 = unwrap_or_clone_rc(a1).into_iter_sorted();
+                    let iter2 = unwrap_or_clone_rc(a2).into_iter_sorted();
+                    if iter1.len() != iter2.len() {
+                        break false;
+                    }
+                    for ((k1, v1), (k2, v2)) in iter1.zip(iter2) {
+                        if k1 != k2 {
+                            break 'outer false;
+                        }
+                        self.stack.push(v1);
+                        self.stack.push(v2);
+                        numpairs += 1;
+                    }
+                }
+                (Value::Attrs(_), _) => break false,
+                (_, Value::Attrs(_)) => break false,
+
+                (v1, v2) => {
+                    if allow_pointer_equality_on_functions_and_thunks {
+                        if let (Value::Closure(c1), Value::Closure(c2)) = (&v1, &v2) {
+                            if c1.ptr_eq(c2) {
+                                continue;
+                            }
+                        }
+                    }
+                    if !fallible!(self, v1.nix_eq(&v2, self)) {
+                        break false;
+                    }
+                }
+            }
+        };
+        self.pop_then_drop(numpairs * 2);
+        self.push(Value::Bool(res));
+        Ok(())
+    }
+
     fn run_op(&mut self, op: OpCode) -> EvalResult<()> {
         match op {
             OpCode::OpConstant(idx) => {
@@ -467,13 +588,7 @@ impl<'o> VM<'o> {
                 }
             },
 
-            OpCode::OpEqual => {
-                let v2 = self.pop();
-                let v1 = self.pop();
-                let res = fallible!(self, v1.nix_eq(&v2, self));
-
-                self.push(Value::Bool(res))
-            }
+            OpCode::OpEqual => return self.nix_op_eq(false),
 
             OpCode::OpLess => cmp_op!(self, <),
             OpCode::OpLessOrEq => cmp_op!(self, <=),