diff options
author | Adam Joseph <adam@westernsemico.com> | 2022-11-22T04·04-0800 |
---|---|---|
committer | Adam Joseph <adam@westernsemico.com> | 2022-11-27T19·10+0000 |
commit | dad8a7cbffbb0fac850e081e564eb09c67dd2fca (patch) | |
tree | 8b22a1affbe3c9175b68012caa5b849dd920c14e /tvix | |
parent | ed9aa0c32a5c14768f887b9c237ab80d7d6be254 (diff) |
feat(tvix/eval): non-recursive implementation of nix_eq() r/5345
This passes all the function/thunk-pointer-equality tests in cl/7369. Change-Id: Ib47535ba2fc77a4f1c2cc2fd23d3a879e21d8b4c Signed-off-by: Adam Joseph <adam@westernsemico.com> Reviewed-on: https://cl.tvl.fyi/c/depot/+/7358 Tested-by: BuildkiteCI Reviewed-by: tazjin <tazjin@tvl.su>
Diffstat (limited to 'tvix')
-rw-r--r-- | tvix/eval/src/builtins/mod.rs | 2 | ||||
-rw-r--r-- | tvix/eval/src/value/mod.rs | 30 | ||||
-rw-r--r-- | tvix/eval/src/value/thunk.rs | 6 | ||||
-rw-r--r-- | tvix/eval/src/vm.rs | 129 |
4 files changed, 133 insertions, 34 deletions
diff --git a/tvix/eval/src/builtins/mod.rs b/tvix/eval/src/builtins/mod.rs index b2344a422e63..0003bddbf224 100644 --- a/tvix/eval/src/builtins/mod.rs +++ b/tvix/eval/src/builtins/mod.rs @@ -250,7 +250,7 @@ mod pure_builtins { #[builtin("elem")] fn builtin_elem(vm: &mut VM, x: Value, xs: Value) -> Result<Value, ErrorKind> { for val in xs.to_list()? { - if val.nix_eq(&x, vm)? { + if vm.nix_eq(val, x.clone(), true)? { return Ok(true.into()); } } diff --git a/tvix/eval/src/value/mod.rs b/tvix/eval/src/value/mod.rs index d777816bb31f..1763b716ee45 100644 --- a/tvix/eval/src/value/mod.rs +++ b/tvix/eval/src/value/mod.rs @@ -313,7 +313,6 @@ impl Value { // Trivial comparisons (Value::Null, Value::Null) => Ok(true), (Value::Bool(b1), Value::Bool(b2)) => Ok(b1 == b2), - (Value::List(l1), Value::List(l2)) => l1.nix_eq(l2, vm), (Value::String(s1), Value::String(s2)) => Ok(s1 == s2), (Value::Path(p1), Value::Path(p2)) => Ok(p1 == p2), @@ -323,31 +322,10 @@ impl Value { (Value::Float(f1), Value::Float(f2)) => Ok(f1 == f2), (Value::Float(f), Value::Integer(i)) => Ok(*i as f64 == *f), - // Optimised attribute set comparison - (Value::Attrs(a1), Value::Attrs(a2)) => Ok(Rc::ptr_eq(a1, a2) || a1.nix_eq(a2, vm)?), - - // If either value is a thunk, the thunk should be forced, and then - // the resulting value must be compared instead. - (Value::Thunk(lhs), Value::Thunk(rhs)) => { - lhs.force(vm)?; - rhs.force(vm)?; - - // TODO: this cloning is done because there is a potential issue - // with keeping borrows into both thunks around while recursing, - // as they might recurse themselves, leading to a borrow error - // when they are later being forced. - let lhs = lhs.value().clone(); - let rhs = rhs.value().clone(); - lhs.nix_eq(&rhs, vm) - } - (Value::Thunk(lhs), rhs) => { - lhs.force(vm)?; - lhs.value().nix_eq(rhs, vm) - } - (lhs, Value::Thunk(rhs)) => { - rhs.force(vm)?; - lhs.nix_eq(&*rhs.value(), vm) - } + (Value::Attrs(_), Value::Attrs(_)) + | (Value::List(_), Value::List(_)) + | (Value::Thunk(_), _) + | (_, Value::Thunk(_)) => Ok(vm.nix_eq(self.clone(), other.clone(), false)?), // Everything else is either incomparable (e.g. internal // types) or false. diff --git a/tvix/eval/src/value/thunk.rs b/tvix/eval/src/value/thunk.rs index 1be13bfe893f..0d4c26bab492 100644 --- a/tvix/eval/src/value/thunk.rs +++ b/tvix/eval/src/value/thunk.rs @@ -190,6 +190,12 @@ impl Thunk { thunk => panic!("upvalues() on non-suspended thunk: {thunk:?}"), }) } + + /// Do not use this without first reading and understanding + /// `tvix/docs/value-pointer-equality.md`. + pub(crate) fn ptr_eq(&self, other: &Self) -> bool { + Rc::ptr_eq(&self.0, &other.0) + } } impl TotalDisplay for Thunk { diff --git a/tvix/eval/src/vm.rs b/tvix/eval/src/vm.rs index 41640a32b3c1..f7c9a9dd7871 100644 --- a/tvix/eval/src/vm.rs +++ b/tvix/eval/src/vm.rs @@ -190,6 +190,10 @@ impl<'o> VM<'o> { self.stack.pop().expect("runtime stack empty") } + pub fn pop_then_drop(&mut self, num_items: usize) { + self.stack.truncate(self.stack.len() - num_items); + } + pub fn push(&mut self, value: Value) { self.stack.push(value) } @@ -392,6 +396,123 @@ impl<'o> VM<'o> { } } + pub(crate) fn nix_eq( + &mut self, + v1: Value, + v2: Value, + allow_top_level_pointer_equality_on_functions_and_thunks: bool, + ) -> EvalResult<bool> { + self.push(v1); + self.push(v2); + self.nix_op_eq(allow_top_level_pointer_equality_on_functions_and_thunks)?; + match self.pop() { + Value::Bool(b) => Ok(b), + v => panic!("run_op(OpEqual) left a non-boolean on the stack: {v:#?}"), + } + } + + pub(crate) fn nix_op_eq( + &mut self, + allow_top_level_pointer_equality_on_functions_and_thunks: bool, + ) -> EvalResult<()> { + // This bit gets set to `true` (if it isn't already) as soon + // as we start comparing the contents of two + // {lists,attrsets} -- but *not* the contents of two thunks. + // See tvix/docs/value-pointer-equality.md for details. + let mut allow_pointer_equality_on_functions_and_thunks = + allow_top_level_pointer_equality_on_functions_and_thunks; + + let mut numpairs: usize = 1; + let res = 'outer: loop { + if numpairs == 0 { + break true; + } else { + numpairs -= 1; + } + let v2 = self.pop(); + let v1 = self.pop(); + let v2 = match v2 { + Value::Thunk(thunk) => { + if allow_top_level_pointer_equality_on_functions_and_thunks { + if let Value::Thunk(t1) = &v1 { + if t1.ptr_eq(&thunk) { + continue; + } + } + } + fallible!(self, thunk.force(self)); + thunk.value().clone() + } + v => v, + }; + let v1 = match v1 { + Value::Thunk(thunk) => { + fallible!(self, thunk.force(self)); + thunk.value().clone() + } + v => v, + }; + match (v1, v2) { + (Value::List(l1), Value::List(l2)) => { + allow_pointer_equality_on_functions_and_thunks = true; + if l1.ptr_eq(&l2) { + continue; + } + if l1.len() != l2.len() { + break false; + } + for (vi1, vi2) in l1.into_iter().zip(l2.into_iter()) { + self.stack.push(vi1); + self.stack.push(vi2); + numpairs += 1; + } + } + (_, Value::List(_)) => break false, + (Value::List(_), _) => break false, + + (Value::Attrs(a1), Value::Attrs(a2)) => { + if allow_pointer_equality_on_functions_and_thunks { + if Rc::ptr_eq(&a1, &a2) { + continue; + } + } + allow_pointer_equality_on_functions_and_thunks = true; + let iter1 = unwrap_or_clone_rc(a1).into_iter_sorted(); + let iter2 = unwrap_or_clone_rc(a2).into_iter_sorted(); + if iter1.len() != iter2.len() { + break false; + } + for ((k1, v1), (k2, v2)) in iter1.zip(iter2) { + if k1 != k2 { + break 'outer false; + } + self.stack.push(v1); + self.stack.push(v2); + numpairs += 1; + } + } + (Value::Attrs(_), _) => break false, + (_, Value::Attrs(_)) => break false, + + (v1, v2) => { + if allow_pointer_equality_on_functions_and_thunks { + if let (Value::Closure(c1), Value::Closure(c2)) = (&v1, &v2) { + if c1.ptr_eq(c2) { + continue; + } + } + } + if !fallible!(self, v1.nix_eq(&v2, self)) { + break false; + } + } + } + }; + self.pop_then_drop(numpairs * 2); + self.push(Value::Bool(res)); + Ok(()) + } + fn run_op(&mut self, op: OpCode) -> EvalResult<()> { match op { OpCode::OpConstant(idx) => { @@ -467,13 +588,7 @@ impl<'o> VM<'o> { } }, - OpCode::OpEqual => { - let v2 = self.pop(); - let v1 = self.pop(); - let res = fallible!(self, v1.nix_eq(&v2, self)); - - self.push(Value::Bool(res)) - } + OpCode::OpEqual => return self.nix_op_eq(false), OpCode::OpLess => cmp_op!(self, <), OpCode::OpLessOrEq => cmp_op!(self, <=), |