about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tvix/castore/src/refscan.rs69
1 files changed, 30 insertions, 39 deletions
diff --git a/tvix/castore/src/refscan.rs b/tvix/castore/src/refscan.rs
index 80a126349746..0b8af296bb50 100644
--- a/tvix/castore/src/refscan.rs
+++ b/tvix/castore/src/refscan.rs
@@ -9,6 +9,7 @@
 use pin_project::pin_project;
 use std::collections::BTreeSet;
 use std::pin::Pin;
+use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use std::task::{ready, Poll};
 use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
@@ -74,7 +75,7 @@ where
 /// of bytes patterns to scan for.
 pub struct ReferenceScanner<P> {
     pattern: ReferencePattern<P>,
-    matches: Vec<bool>,
+    matches: Vec<AtomicBool>,
 }
 
 impl<P: AsRef<[u8]>> ReferenceScanner<P> {
@@ -82,20 +83,23 @@ impl<P: AsRef<[u8]>> ReferenceScanner<P> {
     /// candidate bytes patterns.
     pub fn new<IP: Into<ReferencePattern<P>>>(pattern: IP) -> Self {
         let pattern = pattern.into();
-        let matches = vec![false; pattern.candidates().len()];
+        let mut matches = Vec::new();
+        for _ in 0..pattern.candidates().len() {
+            matches.push(AtomicBool::new(false));
+        }
         ReferenceScanner { pattern, matches }
     }
 
     /// Scan the given buffer for all non-overlapping matches and collect them
     /// in the scanner.
-    pub fn scan<S: AsRef<[u8]>>(&mut self, haystack: S) {
+    pub fn scan<S: AsRef<[u8]>>(&self, haystack: S) {
         if haystack.as_ref().len() < self.pattern.longest_candidate() {
             return;
         }
 
         if let Some(searcher) = &self.pattern.inner.searcher {
             for m in searcher.find(haystack) {
-                self.matches[m.pat_idx] = true;
+                self.matches[m.pat_idx].store(true, Ordering::Release);
             }
         }
     }
@@ -104,14 +108,17 @@ impl<P: AsRef<[u8]>> ReferenceScanner<P> {
         &self.pattern
     }
 
-    pub fn matches(&self) -> &[bool] {
-        &self.matches
+    pub fn matches(&self) -> Vec<bool> {
+        self.matches
+            .iter()
+            .map(|m| m.load(Ordering::Acquire))
+            .collect()
     }
 
     pub fn candidate_matches(&self) -> impl Iterator<Item = &P> {
         let candidates = self.pattern.candidates();
         self.matches.iter().enumerate().filter_map(|(idx, found)| {
-            if *found {
+            if found.load(Ordering::Acquire) {
                 Some(&candidates[idx])
             } else {
                 None
@@ -130,52 +137,35 @@ impl<P: Clone + Ord + AsRef<[u8]>> ReferenceScanner<P> {
 const DEFAULT_BUF_SIZE: usize = 8 * 1024;
 
 #[pin_project]
-pub struct ReferenceReader<P, R> {
-    scanner: ReferenceScanner<P>,
+pub struct ReferenceReader<'a, P, R> {
+    scanner: &'a ReferenceScanner<P>,
     buffer: Vec<u8>,
     consumed: usize,
     #[pin]
     reader: R,
 }
 
-impl<P, R> ReferenceReader<P, R>
+impl<'a, P, R> ReferenceReader<'a, P, R>
 where
     P: AsRef<[u8]>,
 {
-    pub fn new(pattern: ReferencePattern<P>, reader: R) -> ReferenceReader<P, R> {
-        Self::with_capacity(DEFAULT_BUF_SIZE, pattern, reader)
+    pub fn new(scanner: &'a ReferenceScanner<P>, reader: R) -> Self {
+        Self::with_capacity(DEFAULT_BUF_SIZE, scanner, reader)
     }
 
-    pub fn with_capacity(
-        capacity: usize,
-        pattern: ReferencePattern<P>,
-        reader: R,
-    ) -> ReferenceReader<P, R> {
+    pub fn with_capacity(capacity: usize, scanner: &'a ReferenceScanner<P>, reader: R) -> Self {
         // If capacity is not at least as long as longest_candidate we can't do a scan
-        let capacity = capacity.max(pattern.longest_candidate());
+        let capacity = capacity.max(scanner.pattern().longest_candidate());
         ReferenceReader {
-            scanner: ReferenceScanner::new(pattern),
+            scanner,
             buffer: Vec::with_capacity(capacity),
             consumed: 0,
             reader,
         }
     }
-
-    pub fn scanner(&self) -> &ReferenceScanner<P> {
-        &self.scanner
-    }
-}
-
-impl<P, R> ReferenceReader<P, R>
-where
-    P: Clone + Ord + AsRef<[u8]>,
-{
-    pub fn finalise(self) -> BTreeSet<P> {
-        self.scanner.finalise()
-    }
 }
 
-impl<P, R> AsyncRead for ReferenceReader<P, R>
+impl<'a, P, R> AsyncRead for ReferenceReader<'a, P, R>
 where
     R: AsyncRead,
     P: AsRef<[u8]>,
@@ -193,7 +183,7 @@ where
     }
 }
 
-impl<P, R> AsyncBufRead for ReferenceReader<P, R>
+impl<'a, P, R> AsyncBufRead for ReferenceReader<'a, P, R>
 where
     R: AsyncRead,
     P: AsRef<[u8]>,
@@ -257,7 +247,7 @@ mod tests {
 
     #[test]
     fn test_no_patterns() {
-        let mut scanner: ReferenceScanner<String> = ReferenceScanner::new(vec![]);
+        let scanner: ReferenceScanner<String> = ReferenceScanner::new(vec![]);
 
         scanner.scan(HELLO_DRV);
 
@@ -268,7 +258,7 @@ mod tests {
 
     #[test]
     fn test_single_match() {
-        let mut scanner = ReferenceScanner::new(vec![
+        let scanner = ReferenceScanner::new(vec![
             "/nix/store/4xw8n979xpivdc46a9ndcvyhwgif00hz-bash-5.1-p16".to_string(),
         ]);
         scanner.scan(HELLO_DRV);
@@ -290,7 +280,7 @@ mod tests {
             "/nix/store/fn7zvafq26f0c8b17brs7s95s10ibfzs-emacs-28.2.drv".to_string(),
         ];
 
-        let mut scanner = ReferenceScanner::new(candidates.clone());
+        let scanner = ReferenceScanner::new(candidates.clone());
         scanner.scan(HELLO_DRV);
 
         let result = scanner.finalise();
@@ -317,17 +307,18 @@ mod tests {
             "fn7zvafq26f0c8b17brs7s95s10ibfzs",
         ];
         let pattern = ReferencePattern::new(candidates.clone());
+        let scanner = ReferenceScanner::new(pattern);
         let mut mock = Builder::new();
         for c in HELLO_DRV.as_bytes().chunks(chunk_size) {
             mock.read(c);
         }
         let mock = mock.build();
-        let mut reader = ReferenceReader::with_capacity(capacity, pattern, mock);
+        let mut reader = ReferenceReader::with_capacity(capacity, &scanner, mock);
         let mut s = String::new();
         reader.read_to_string(&mut s).await.unwrap();
         assert_eq!(s, HELLO_DRV);
 
-        let result = reader.finalise();
+        let result = scanner.finalise();
         assert_eq!(result.len(), 3);
 
         for c in candidates[..3].iter() {