about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAdam Joseph <adam@westernsemico.com>2022-11-22T04·04-0800
committerAdam Joseph <adam@westernsemico.com>2022-11-27T19·10+0000
commitdad8a7cbffbb0fac850e081e564eb09c67dd2fca (patch)
tree8b22a1affbe3c9175b68012caa5b849dd920c14e
parented9aa0c32a5c14768f887b9c237ab80d7d6be254 (diff)
feat(tvix/eval): non-recursive implementation of nix_eq() r/5345
This passes all the function/thunk-pointer-equality tests in
cl/7369.

Change-Id: Ib47535ba2fc77a4f1c2cc2fd23d3a879e21d8b4c
Signed-off-by: Adam Joseph <adam@westernsemico.com>
Reviewed-on: https://cl.tvl.fyi/c/depot/+/7358
Tested-by: BuildkiteCI
Reviewed-by: tazjin <tazjin@tvl.su>
-rw-r--r--tvix/eval/src/builtins/mod.rs2
-rw-r--r--tvix/eval/src/value/mod.rs30
-rw-r--r--tvix/eval/src/value/thunk.rs6
-rw-r--r--tvix/eval/src/vm.rs129
4 files changed, 133 insertions, 34 deletions
diff --git a/tvix/eval/src/builtins/mod.rs b/tvix/eval/src/builtins/mod.rs
index b2344a422e..0003bddbf2 100644
--- a/tvix/eval/src/builtins/mod.rs
+++ b/tvix/eval/src/builtins/mod.rs
@@ -250,7 +250,7 @@ mod pure_builtins {
     #[builtin("elem")]
     fn builtin_elem(vm: &mut VM, x: Value, xs: Value) -> Result<Value, ErrorKind> {
         for val in xs.to_list()? {
-            if val.nix_eq(&x, vm)? {
+            if vm.nix_eq(val, x.clone(), true)? {
                 return Ok(true.into());
             }
         }
diff --git a/tvix/eval/src/value/mod.rs b/tvix/eval/src/value/mod.rs
index d777816bb3..1763b716ee 100644
--- a/tvix/eval/src/value/mod.rs
+++ b/tvix/eval/src/value/mod.rs
@@ -313,7 +313,6 @@ impl Value {
             // Trivial comparisons
             (Value::Null, Value::Null) => Ok(true),
             (Value::Bool(b1), Value::Bool(b2)) => Ok(b1 == b2),
-            (Value::List(l1), Value::List(l2)) => l1.nix_eq(l2, vm),
             (Value::String(s1), Value::String(s2)) => Ok(s1 == s2),
             (Value::Path(p1), Value::Path(p2)) => Ok(p1 == p2),
 
@@ -323,31 +322,10 @@ impl Value {
             (Value::Float(f1), Value::Float(f2)) => Ok(f1 == f2),
             (Value::Float(f), Value::Integer(i)) => Ok(*i as f64 == *f),
 
-            // Optimised attribute set comparison
-            (Value::Attrs(a1), Value::Attrs(a2)) => Ok(Rc::ptr_eq(a1, a2) || a1.nix_eq(a2, vm)?),
-
-            // If either value is a thunk, the thunk should be forced, and then
-            // the resulting value must be compared instead.
-            (Value::Thunk(lhs), Value::Thunk(rhs)) => {
-                lhs.force(vm)?;
-                rhs.force(vm)?;
-
-                // TODO: this cloning is done because there is a potential issue
-                // with keeping borrows into both thunks around while recursing,
-                // as they might recurse themselves, leading to a borrow error
-                // when they are later being forced.
-                let lhs = lhs.value().clone();
-                let rhs = rhs.value().clone();
-                lhs.nix_eq(&rhs, vm)
-            }
-            (Value::Thunk(lhs), rhs) => {
-                lhs.force(vm)?;
-                lhs.value().nix_eq(rhs, vm)
-            }
-            (lhs, Value::Thunk(rhs)) => {
-                rhs.force(vm)?;
-                lhs.nix_eq(&*rhs.value(), vm)
-            }
+            (Value::Attrs(_), Value::Attrs(_))
+            | (Value::List(_), Value::List(_))
+            | (Value::Thunk(_), _)
+            | (_, Value::Thunk(_)) => Ok(vm.nix_eq(self.clone(), other.clone(), false)?),
 
             // Everything else is either incomparable (e.g. internal
             // types) or false.
diff --git a/tvix/eval/src/value/thunk.rs b/tvix/eval/src/value/thunk.rs
index 1be13bfe89..0d4c26bab4 100644
--- a/tvix/eval/src/value/thunk.rs
+++ b/tvix/eval/src/value/thunk.rs
@@ -190,6 +190,12 @@ impl Thunk {
             thunk => panic!("upvalues() on non-suspended thunk: {thunk:?}"),
         })
     }
+
+    /// Do not use this without first reading and understanding
+    /// `tvix/docs/value-pointer-equality.md`.
+    pub(crate) fn ptr_eq(&self, other: &Self) -> bool {
+        Rc::ptr_eq(&self.0, &other.0)
+    }
 }
 
 impl TotalDisplay for Thunk {
diff --git a/tvix/eval/src/vm.rs b/tvix/eval/src/vm.rs
index 41640a32b3..f7c9a9dd78 100644
--- a/tvix/eval/src/vm.rs
+++ b/tvix/eval/src/vm.rs
@@ -190,6 +190,10 @@ impl<'o> VM<'o> {
         self.stack.pop().expect("runtime stack empty")
     }
 
+    pub fn pop_then_drop(&mut self, num_items: usize) {
+        self.stack.truncate(self.stack.len() - num_items);
+    }
+
     pub fn push(&mut self, value: Value) {
         self.stack.push(value)
     }
@@ -392,6 +396,123 @@ impl<'o> VM<'o> {
         }
     }
 
+    pub(crate) fn nix_eq(
+        &mut self,
+        v1: Value,
+        v2: Value,
+        allow_top_level_pointer_equality_on_functions_and_thunks: bool,
+    ) -> EvalResult<bool> {
+        self.push(v1);
+        self.push(v2);
+        self.nix_op_eq(allow_top_level_pointer_equality_on_functions_and_thunks)?;
+        match self.pop() {
+            Value::Bool(b) => Ok(b),
+            v => panic!("run_op(OpEqual) left a non-boolean on the stack: {v:#?}"),
+        }
+    }
+
+    pub(crate) fn nix_op_eq(
+        &mut self,
+        allow_top_level_pointer_equality_on_functions_and_thunks: bool,
+    ) -> EvalResult<()> {
+        // This bit gets set to `true` (if it isn't already) as soon
+        // as we start comparing the contents of two
+        // {lists,attrsets} -- but *not* the contents of two thunks.
+        // See tvix/docs/value-pointer-equality.md for details.
+        let mut allow_pointer_equality_on_functions_and_thunks =
+            allow_top_level_pointer_equality_on_functions_and_thunks;
+
+        let mut numpairs: usize = 1;
+        let res = 'outer: loop {
+            if numpairs == 0 {
+                break true;
+            } else {
+                numpairs -= 1;
+            }
+            let v2 = self.pop();
+            let v1 = self.pop();
+            let v2 = match v2 {
+                Value::Thunk(thunk) => {
+                    if allow_top_level_pointer_equality_on_functions_and_thunks {
+                        if let Value::Thunk(t1) = &v1 {
+                            if t1.ptr_eq(&thunk) {
+                                continue;
+                            }
+                        }
+                    }
+                    fallible!(self, thunk.force(self));
+                    thunk.value().clone()
+                }
+                v => v,
+            };
+            let v1 = match v1 {
+                Value::Thunk(thunk) => {
+                    fallible!(self, thunk.force(self));
+                    thunk.value().clone()
+                }
+                v => v,
+            };
+            match (v1, v2) {
+                (Value::List(l1), Value::List(l2)) => {
+                    allow_pointer_equality_on_functions_and_thunks = true;
+                    if l1.ptr_eq(&l2) {
+                        continue;
+                    }
+                    if l1.len() != l2.len() {
+                        break false;
+                    }
+                    for (vi1, vi2) in l1.into_iter().zip(l2.into_iter()) {
+                        self.stack.push(vi1);
+                        self.stack.push(vi2);
+                        numpairs += 1;
+                    }
+                }
+                (_, Value::List(_)) => break false,
+                (Value::List(_), _) => break false,
+
+                (Value::Attrs(a1), Value::Attrs(a2)) => {
+                    if allow_pointer_equality_on_functions_and_thunks {
+                        if Rc::ptr_eq(&a1, &a2) {
+                            continue;
+                        }
+                    }
+                    allow_pointer_equality_on_functions_and_thunks = true;
+                    let iter1 = unwrap_or_clone_rc(a1).into_iter_sorted();
+                    let iter2 = unwrap_or_clone_rc(a2).into_iter_sorted();
+                    if iter1.len() != iter2.len() {
+                        break false;
+                    }
+                    for ((k1, v1), (k2, v2)) in iter1.zip(iter2) {
+                        if k1 != k2 {
+                            break 'outer false;
+                        }
+                        self.stack.push(v1);
+                        self.stack.push(v2);
+                        numpairs += 1;
+                    }
+                }
+                (Value::Attrs(_), _) => break false,
+                (_, Value::Attrs(_)) => break false,
+
+                (v1, v2) => {
+                    if allow_pointer_equality_on_functions_and_thunks {
+                        if let (Value::Closure(c1), Value::Closure(c2)) = (&v1, &v2) {
+                            if c1.ptr_eq(c2) {
+                                continue;
+                            }
+                        }
+                    }
+                    if !fallible!(self, v1.nix_eq(&v2, self)) {
+                        break false;
+                    }
+                }
+            }
+        };
+        self.pop_then_drop(numpairs * 2);
+        self.push(Value::Bool(res));
+        Ok(())
+    }
+
     fn run_op(&mut self, op: OpCode) -> EvalResult<()> {
         match op {
             OpCode::OpConstant(idx) => {
@@ -467,13 +588,7 @@ impl<'o> VM<'o> {
                 }
             },
 
-            OpCode::OpEqual => {
-                let v2 = self.pop();
-                let v1 = self.pop();
-                let res = fallible!(self, v1.nix_eq(&v2, self));
-
-                self.push(Value::Bool(res))
-            }
+            OpCode::OpEqual => return self.nix_op_eq(false),
 
             OpCode::OpLess => cmp_op!(self, <),
             OpCode::OpLessOrEq => cmp_op!(self, <=),