about summary refs log tree commit diff
path: root/tvix/eval/src
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/eval/src')
-rw-r--r--tvix/eval/src/value/mod.rs182
-rw-r--r--tvix/eval/src/vm/generators.rs6
2 files changed, 182 insertions, 6 deletions
diff --git a/tvix/eval/src/value/mod.rs b/tvix/eval/src/value/mod.rs
index 07ce83020613..a869c0167524 100644
--- a/tvix/eval/src/value/mod.rs
+++ b/tvix/eval/src/value/mod.rs
@@ -20,8 +20,9 @@ mod path;
 mod string;
 mod thunk;
 
-use crate::errors::ErrorKind;
+use crate::errors::{AddContext, ErrorKind};
 use crate::opcode::StackIdx;
+use crate::vm::generators::{self, GenCo};
 use crate::vm::VM;
 pub use attrs::NixAttrs;
 pub use builtin::{Builtin, BuiltinArgument};
@@ -187,6 +188,21 @@ impl Value {
     }
 }
 
+/// Controls what kind of by-pointer equality comparison is allowed.
+///
+/// See `//tvix/docs/value-pointer-equality.md` for details.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum PointerEquality {
+    /// Pointer equality not allowed at all.
+    ForbidAll,
+
+    /// Pointer equality comparisons only allowed for nested values.
+    AllowNested,
+
+    /// Pointer equality comparisons are allowed in all contexts.
+    AllowAll,
+}
+
 impl Value {
     /// Coerce a `Value` to a string. See `CoercionKind` for a rundown of what
     /// input types are accepted under what circumstances.
@@ -307,6 +323,170 @@ impl Value {
         }
     }
 
+    /// Compare two Nix values for equality, forcing nested parts of the structure
+    /// as needed.
+    ///
+    /// This comparison needs to be invoked for nested values (e.g. in lists and
+    /// attribute sets) as well, which is done by suspending and asking the VM to
+    /// perform the nested comparison.
+    ///
+    /// The `top_level` parameter controls whether this invocation is the top-level
+    /// comparison, or a nested value comparison. See
+    /// `//tvix/docs/value-pointer-equality.md`
+    pub(crate) async fn neo_nix_eq(
+        self,
+        other: Value,
+        co: GenCo,
+        ptr_eq: PointerEquality,
+    ) -> Result<Value, ErrorKind> {
+        let a = match self {
+            Value::Thunk(ref thunk) => {
+                // If both values are thunks, and thunk comparisons are allowed by
+                // pointer, do that and move on.
+                if ptr_eq == PointerEquality::AllowAll {
+                    if let Value::Thunk(t1) = &other {
+                        if t1.ptr_eq(thunk) {
+                            return Ok(Value::Bool(true));
+                        }
+                    }
+                };
+
+                generators::request_force(&co, self).await
+            }
+
+            _ => self,
+        };
+
+        let b = match other {
+            Value::Thunk(_) => generators::request_force(&co, other).await,
+            _ => other,
+        };
+
+        debug_assert!(!matches!(a, Value::Thunk(_)));
+        debug_assert!(!matches!(b, Value::Thunk(_)));
+
+        let result = match (a, b) {
+            // Trivial comparisons
+            (Value::Null, Value::Null) => true,
+            (Value::Bool(b1), Value::Bool(b2)) => b1 == b2,
+            (Value::String(s1), Value::String(s2)) => s1 == s2,
+            (Value::Path(p1), Value::Path(p2)) => p1 == p2,
+
+            // Numerical comparisons (they work between float & int)
+            (Value::Integer(i1), Value::Integer(i2)) => i1 == i2,
+            (Value::Integer(i), Value::Float(f)) => i as f64 == f,
+            (Value::Float(f1), Value::Float(f2)) => f1 == f2,
+            (Value::Float(f), Value::Integer(i)) => i as f64 == f,
+
+            // List comparisons
+            (Value::List(l1), Value::List(l2)) => {
+                if ptr_eq >= PointerEquality::AllowNested && l1.ptr_eq(&l2) {
+                    return Ok(Value::Bool(true));
+                }
+
+                if l1.len() != l2.len() {
+                    return Ok(Value::Bool(false));
+                }
+
+                for (vi1, vi2) in l1.into_iter().zip(l2.into_iter()) {
+                    if !generators::check_equality(
+                        &co,
+                        vi1,
+                        vi2,
+                        std::cmp::max(ptr_eq, PointerEquality::AllowNested),
+                    )
+                    .await?
+                    {
+                        return Ok(Value::Bool(false));
+                    }
+                }
+
+                true
+            }
+
+            (_, Value::List(_)) | (Value::List(_), _) => false,
+
+            // Attribute set comparisons
+            (Value::Attrs(a1), Value::Attrs(a2)) => {
+                if ptr_eq >= PointerEquality::AllowNested && a1.ptr_eq(&a2) {
+                    return Ok(Value::Bool(true));
+                }
+
+                // Special-case for derivation comparisons: If both attribute sets
+                // have `type = derivation`, compare them by `outPath`.
+                match (a1.select("type"), a2.select("type")) {
+                    (Some(v1), Some(v2)) => {
+                        let s1 = generators::request_force(&co, v1.clone()).await.to_str();
+                        let s2 = generators::request_force(&co, v2.clone()).await.to_str();
+
+                        if let (Ok(s1), Ok(s2)) = (s1, s2) {
+                            if s1.as_str() == "derivation" && s2.as_str() == "derivation" {
+                                // TODO(tazjin): are the outPaths really required,
+                                // or should it fall through?
+                                let out1 = a1
+                                    .select_required("outPath")
+                                    .context("comparing derivations")?
+                                    .clone();
+
+                                let out2 = a2
+                                    .select_required("outPath")
+                                    .context("comparing derivations")?
+                                    .clone();
+
+                                let result = generators::request_force(&co, out1.clone())
+                                    .await
+                                    .to_str()?
+                                    == generators::request_force(&co, out2.clone())
+                                        .await
+                                        .to_str()?;
+                                return Ok(Value::Bool(result));
+                            }
+                        }
+                    }
+                    _ => {}
+                };
+
+                if a1.len() != a2.len() {
+                    return Ok(Value::Bool(false));
+                }
+
+                let iter1 = a1.into_iter_sorted();
+                let iter2 = a2.into_iter_sorted();
+
+                for ((k1, v1), (k2, v2)) in iter1.zip(iter2) {
+                    if k1 != k2 {
+                        return Ok(Value::Bool(false));
+                    }
+
+                    if !generators::check_equality(
+                        &co,
+                        v1,
+                        v2,
+                        std::cmp::max(ptr_eq, PointerEquality::AllowNested),
+                    )
+                    .await?
+                    {
+                        return Ok(Value::Bool(false));
+                    }
+                }
+
+                true
+            }
+
+            (Value::Attrs(_), _) | (_, Value::Attrs(_)) => false,
+
+            (Value::Closure(c1), Value::Closure(c2)) if ptr_eq >= PointerEquality::AllowNested => {
+                Rc::ptr_eq(&c1, &c2)
+            }
+
+            // Everything else is either incomparable (e.g. internal types) or
+            // false.
+            _ => false,
+        };
+
+        Ok(Value::Bool(result))
+    }
+
     pub fn type_of(&self) -> &'static str {
         match self {
             Value::Null => "null",
diff --git a/tvix/eval/src/vm/generators.rs b/tvix/eval/src/vm/generators.rs
index 3b822b086346..2a6a8fa730d8 100644
--- a/tvix/eval/src/vm/generators.rs
+++ b/tvix/eval/src/vm/generators.rs
@@ -14,17 +14,13 @@ use smol_str::SmolStr;
 use std::fmt::Display;
 use std::future::Future;
 
-use crate::value::SharedThunkSet;
+use crate::value::{PointerEquality, SharedThunkSet};
 use crate::warnings::WarningKind;
 use crate::FileType;
 use crate::NixString;
 
 use super::*;
 
-/// Dummy type, before the actual implementation is in place.
-#[derive(Debug)]
-pub struct PointerEquality {}
-
 // -- Implementation of generic generator logic.
 
 /// States that a generator can be in while being driven by the VM.