about summary refs log tree commit diff
path: root/tvix
diff options
context:
space:
mode:
Diffstat (limited to 'tvix')
-rw-r--r--tvix/eval/src/value/attrs.rs9
-rw-r--r--tvix/eval/src/value/attrs/tests.rs22
-rw-r--r--tvix/eval/src/value/list.rs5
-rw-r--r--tvix/eval/src/value/mod.rs37
-rw-r--r--tvix/eval/src/vm.rs2
5 files changed, 55 insertions, 20 deletions
diff --git a/tvix/eval/src/value/attrs.rs b/tvix/eval/src/value/attrs.rs
index 6802d195e8e2..5eb258cc9e47 100644
--- a/tvix/eval/src/value/attrs.rs
+++ b/tvix/eval/src/value/attrs.rs
@@ -11,6 +11,7 @@ use std::fmt::Display;
 use std::rc::Rc;
 
 use crate::errors::ErrorKind;
+use crate::vm::VM;
 
 use super::string::NixString;
 use super::Value;
@@ -287,7 +288,7 @@ impl NixAttrs {
     }
 
     /// Compare `self` against `other` for equality using Nix equality semantics
-    pub fn nix_eq(&self, other: &Self) -> Result<bool, ErrorKind> {
+    pub fn nix_eq(&self, other: &Self, vm: &mut VM) -> Result<bool, ErrorKind> {
         match (&self.0, &other.0) {
             (AttrsRep::Empty, AttrsRep::Empty) => Ok(true),
 
@@ -316,7 +317,7 @@ impl NixAttrs {
                     name: n2,
                     value: v2,
                 },
-            ) => Ok(n1.nix_eq(n2)? && v1.nix_eq(v2)?),
+            ) => Ok(n1.nix_eq(n2, vm)? && v1.nix_eq(v2, vm)?),
 
             (AttrsRep::Map(map), AttrsRep::KV { name, value })
             | (AttrsRep::KV { name, value }, AttrsRep::Map(map)) => {
@@ -327,7 +328,7 @@ impl NixAttrs {
                 if let (Some(m_name), Some(m_value)) =
                     (map.get(&NixString::NAME), map.get(&NixString::VALUE))
                 {
-                    return Ok(name.nix_eq(m_name)? && value.nix_eq(m_value)?);
+                    return Ok(name.nix_eq(m_name, vm)? && value.nix_eq(m_value, vm)?);
                 }
 
                 Ok(false)
@@ -340,7 +341,7 @@ impl NixAttrs {
 
                 for (k, v1) in m1 {
                     if let Some(v2) = m2.get(k) {
-                        if !v1.nix_eq(v2)? {
+                        if !v1.nix_eq(v2, vm)? {
                             return Ok(false);
                         }
                     } else {
diff --git a/tvix/eval/src/value/attrs/tests.rs b/tvix/eval/src/value/attrs/tests.rs
index e96d50855b84..539a07e00f79 100644
--- a/tvix/eval/src/value/attrs/tests.rs
+++ b/tvix/eval/src/value/attrs/tests.rs
@@ -1,24 +1,38 @@
 use super::*;
 
 mod nix_eq {
+    use crate::observer::NoOpObserver;
+
     use super::*;
     use proptest::prelude::ProptestConfig;
     use test_strategy::proptest;
 
     #[proptest(ProptestConfig { cases: 5, ..Default::default() })]
     fn reflexive(x: NixAttrs) {
-        assert!(x.nix_eq(&x).unwrap())
+        let mut observer = NoOpObserver {};
+        let mut vm = VM::new(&mut observer);
+
+        assert!(x.nix_eq(&x, &mut vm).unwrap())
     }
 
     #[proptest(ProptestConfig { cases: 5, ..Default::default() })]
     fn symmetric(x: NixAttrs, y: NixAttrs) {
-        assert_eq!(x.nix_eq(&y).unwrap(), y.nix_eq(&x).unwrap())
+        let mut observer = NoOpObserver {};
+        let mut vm = VM::new(&mut observer);
+
+        assert_eq!(
+            x.nix_eq(&y, &mut vm).unwrap(),
+            y.nix_eq(&x, &mut vm).unwrap()
+        )
     }
 
     #[proptest(ProptestConfig { cases: 5, ..Default::default() })]
     fn transitive(x: NixAttrs, y: NixAttrs, z: NixAttrs) {
-        if x.nix_eq(&y).unwrap() && y.nix_eq(&z).unwrap() {
-            assert!(x.nix_eq(&z).unwrap())
+        let mut observer = NoOpObserver {};
+        let mut vm = VM::new(&mut observer);
+
+        if x.nix_eq(&y, &mut vm).unwrap() && y.nix_eq(&z, &mut vm).unwrap() {
+            assert!(x.nix_eq(&z, &mut vm).unwrap())
         }
     }
 }
diff --git a/tvix/eval/src/value/list.rs b/tvix/eval/src/value/list.rs
index 8563c2256f07..81abb3c07e2f 100644
--- a/tvix/eval/src/value/list.rs
+++ b/tvix/eval/src/value/list.rs
@@ -2,6 +2,7 @@
 use std::fmt::Display;
 
 use crate::errors::ErrorKind;
+use crate::vm::VM;
 
 use super::Value;
 
@@ -83,13 +84,13 @@ impl NixList {
     }
 
     /// Compare `self` against `other` for equality using Nix equality semantics
-    pub fn nix_eq(&self, other: &Self) -> Result<bool, ErrorKind> {
+    pub fn nix_eq(&self, other: &Self, vm: &mut VM) -> Result<bool, ErrorKind> {
         if self.len() != other.len() {
             return Ok(false);
         }
 
         for (v1, v2) in self.iter().zip(other.iter()) {
-            if !v1.nix_eq(v2)? {
+            if !v1.nix_eq(v2, vm)? {
                 return Ok(false);
             }
         }
diff --git a/tvix/eval/src/value/mod.rs b/tvix/eval/src/value/mod.rs
index 34dfbddfa167..8a2ab19961cf 100644
--- a/tvix/eval/src/value/mod.rs
+++ b/tvix/eval/src/value/mod.rs
@@ -254,13 +254,15 @@ impl Value {
     gen_is!(is_number, Value::Integer(_) | Value::Float(_));
     gen_is!(is_bool, Value::Bool(_));
 
-    /// Compare `self` against `other` for equality using Nix equality semantics
-    pub fn nix_eq(&self, other: &Self) -> Result<bool, ErrorKind> {
+    /// Compare `self` against `other` for equality using Nix equality semantics.
+    ///
+    /// Takes a reference to the `VM` to allow forcing thunks during comparison
+    pub fn nix_eq(&self, other: &Self, vm: &mut VM) -> Result<bool, ErrorKind> {
         match (self, other) {
             // Trivial comparisons
             (Value::Null, Value::Null) => Ok(true),
             (Value::Bool(b1), Value::Bool(b2)) => Ok(b1 == b2),
-            (Value::List(l1), Value::List(l2)) => l1.nix_eq(l2),
+            (Value::List(l1), Value::List(l2)) => l1.nix_eq(l2, vm),
             (Value::String(s1), Value::String(s2)) => Ok(s1 == s2),
             (Value::Path(p1), Value::Path(p2)) => Ok(p1 == p2),
 
@@ -271,7 +273,7 @@ impl Value {
             (Value::Float(f), Value::Integer(i)) => Ok(*i as f64 == *f),
 
             // Optimised attribute set comparison
-            (Value::Attrs(a1), Value::Attrs(a2)) => Ok(Rc::ptr_eq(a1, a2) || a1.nix_eq(a2)?),
+            (Value::Attrs(a1), Value::Attrs(a2)) => Ok(Rc::ptr_eq(a1, a2) || a1.nix_eq(a2, vm)?),
 
             // If either value is a thunk, the inner value must be
             // compared instead. The compiler should ensure that
@@ -340,33 +342,50 @@ mod tests {
     fn test_name() {}
 
     mod nix_eq {
+        use crate::observer::NoOpObserver;
+
         use super::*;
         use proptest::prelude::ProptestConfig;
         use test_strategy::proptest;
 
         #[proptest(ProptestConfig { cases: 5, ..Default::default() })]
         fn reflexive(x: Value) {
-            assert!(x.nix_eq(&x).unwrap())
+            let mut observer = NoOpObserver {};
+            let mut vm = VM::new(&mut observer);
+
+            assert!(x.nix_eq(&x, &mut vm).unwrap())
         }
 
         #[proptest(ProptestConfig { cases: 5, ..Default::default() })]
         fn symmetric(x: Value, y: Value) {
-            assert_eq!(x.nix_eq(&y).unwrap(), y.nix_eq(&x).unwrap())
+            let mut observer = NoOpObserver {};
+            let mut vm = VM::new(&mut observer);
+
+            assert_eq!(
+                x.nix_eq(&y, &mut vm).unwrap(),
+                y.nix_eq(&x, &mut vm).unwrap()
+            )
         }
 
         #[proptest(ProptestConfig { cases: 5, ..Default::default() })]
         fn transitive(x: Value, y: Value, z: Value) {
-            if x.nix_eq(&y).unwrap() && y.nix_eq(&z).unwrap() {
-                assert!(x.nix_eq(&z).unwrap())
+            let mut observer = NoOpObserver {};
+            let mut vm = VM::new(&mut observer);
+
+            if x.nix_eq(&y, &mut vm).unwrap() && y.nix_eq(&z, &mut vm).unwrap() {
+                assert!(x.nix_eq(&z, &mut vm).unwrap())
             }
         }
 
         #[test]
         fn list_int_float_fungibility() {
+            let mut observer = NoOpObserver {};
+            let mut vm = VM::new(&mut observer);
+
             let v1 = Value::List(NixList::from(vec![Value::Integer(1)]));
             let v2 = Value::List(NixList::from(vec![Value::Float(1.0)]));
 
-            assert!(v1.nix_eq(&v2).unwrap())
+            assert!(v1.nix_eq(&v2, &mut vm).unwrap())
         }
     }
 }
diff --git a/tvix/eval/src/vm.rs b/tvix/eval/src/vm.rs
index 52819a6e210b..909e219bcd78 100644
--- a/tvix/eval/src/vm.rs
+++ b/tvix/eval/src/vm.rs
@@ -265,7 +265,7 @@ impl<'o> VM<'o> {
                 OpCode::OpEqual => {
                     let v2 = self.pop();
                     let v1 = self.pop();
-                    let res = fallible!(self, v1.nix_eq(&v2));
+                    let res = fallible!(self, v1.nix_eq(&v2, self));
 
                     self.push(Value::Bool(res))
                 }