diff options
-rw-r--r-- | tvix/castore/src/refscan.rs | 69 |
1 files changed, 30 insertions, 39 deletions
diff --git a/tvix/castore/src/refscan.rs b/tvix/castore/src/refscan.rs index 80a126349746..0b8af296bb50 100644 --- a/tvix/castore/src/refscan.rs +++ b/tvix/castore/src/refscan.rs @@ -9,6 +9,7 @@ use pin_project::pin_project; use std::collections::BTreeSet; use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::{ready, Poll}; use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; @@ -74,7 +75,7 @@ where /// of bytes patterns to scan for. pub struct ReferenceScanner<P> { pattern: ReferencePattern<P>, - matches: Vec<bool>, + matches: Vec<AtomicBool>, } impl<P: AsRef<[u8]>> ReferenceScanner<P> { @@ -82,20 +83,23 @@ impl<P: AsRef<[u8]>> ReferenceScanner<P> { /// candidate bytes patterns. pub fn new<IP: Into<ReferencePattern<P>>>(pattern: IP) -> Self { let pattern = pattern.into(); - let matches = vec![false; pattern.candidates().len()]; + let mut matches = Vec::new(); + for _ in 0..pattern.candidates().len() { + matches.push(AtomicBool::new(false)); + } ReferenceScanner { pattern, matches } } /// Scan the given buffer for all non-overlapping matches and collect them /// in the scanner. - pub fn scan<S: AsRef<[u8]>>(&mut self, haystack: S) { + pub fn scan<S: AsRef<[u8]>>(&self, haystack: S) { if haystack.as_ref().len() < self.pattern.longest_candidate() { return; } if let Some(searcher) = &self.pattern.inner.searcher { for m in searcher.find(haystack) { - self.matches[m.pat_idx] = true; + self.matches[m.pat_idx].store(true, Ordering::Release); } } } @@ -104,14 +108,17 @@ impl<P: AsRef<[u8]>> ReferenceScanner<P> { &self.pattern } - pub fn matches(&self) -> &[bool] { - &self.matches + pub fn matches(&self) -> Vec<bool> { + self.matches + .iter() + .map(|m| m.load(Ordering::Acquire)) + .collect() } pub fn candidate_matches(&self) -> impl Iterator<Item = &P> { let candidates = self.pattern.candidates(); self.matches.iter().enumerate().filter_map(|(idx, found)| { - if *found { + if found.load(Ordering::Acquire) { Some(&candidates[idx]) } else { None @@ -130,52 +137,35 @@ impl<P: Clone + Ord + AsRef<[u8]>> ReferenceScanner<P> { const DEFAULT_BUF_SIZE: usize = 8 * 1024; #[pin_project] -pub struct ReferenceReader<P, R> { - scanner: ReferenceScanner<P>, +pub struct ReferenceReader<'a, P, R> { + scanner: &'a ReferenceScanner<P>, buffer: Vec<u8>, consumed: usize, #[pin] reader: R, } -impl<P, R> ReferenceReader<P, R> +impl<'a, P, R> ReferenceReader<'a, P, R> where P: AsRef<[u8]>, { - pub fn new(pattern: ReferencePattern<P>, reader: R) -> ReferenceReader<P, R> { - Self::with_capacity(DEFAULT_BUF_SIZE, pattern, reader) + pub fn new(scanner: &'a ReferenceScanner<P>, reader: R) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, scanner, reader) } - pub fn with_capacity( - capacity: usize, - pattern: ReferencePattern<P>, - reader: R, - ) -> ReferenceReader<P, R> { + pub fn with_capacity(capacity: usize, scanner: &'a ReferenceScanner<P>, reader: R) -> Self { // If capacity is not at least as long as longest_candidate we can't do a scan - let capacity = capacity.max(pattern.longest_candidate()); + let capacity = capacity.max(scanner.pattern().longest_candidate()); ReferenceReader { - scanner: ReferenceScanner::new(pattern), + scanner, buffer: Vec::with_capacity(capacity), consumed: 0, reader, } } - - pub fn scanner(&self) -> &ReferenceScanner<P> { - &self.scanner - } -} - -impl<P, R> ReferenceReader<P, R> -where - P: Clone + Ord + AsRef<[u8]>, -{ - pub fn finalise(self) -> BTreeSet<P> { - self.scanner.finalise() - } } -impl<P, R> AsyncRead for ReferenceReader<P, R> +impl<'a, P, R> AsyncRead for ReferenceReader<'a, P, R> where R: AsyncRead, P: AsRef<[u8]>, @@ -193,7 +183,7 @@ where } } -impl<P, R> AsyncBufRead for ReferenceReader<P, R> +impl<'a, P, R> AsyncBufRead for ReferenceReader<'a, P, R> where R: AsyncRead, P: AsRef<[u8]>, @@ -257,7 +247,7 @@ mod tests { #[test] fn test_no_patterns() { - let mut scanner: ReferenceScanner<String> = ReferenceScanner::new(vec![]); + let scanner: ReferenceScanner<String> = ReferenceScanner::new(vec![]); scanner.scan(HELLO_DRV); @@ -268,7 +258,7 @@ mod tests { #[test] fn test_single_match() { - let mut scanner = ReferenceScanner::new(vec![ + let scanner = ReferenceScanner::new(vec![ "/nix/store/4xw8n979xpivdc46a9ndcvyhwgif00hz-bash-5.1-p16".to_string(), ]); scanner.scan(HELLO_DRV); @@ -290,7 +280,7 @@ mod tests { "/nix/store/fn7zvafq26f0c8b17brs7s95s10ibfzs-emacs-28.2.drv".to_string(), ]; - let mut scanner = ReferenceScanner::new(candidates.clone()); + let scanner = ReferenceScanner::new(candidates.clone()); scanner.scan(HELLO_DRV); let result = scanner.finalise(); @@ -317,17 +307,18 @@ mod tests { "fn7zvafq26f0c8b17brs7s95s10ibfzs", ]; let pattern = ReferencePattern::new(candidates.clone()); + let scanner = ReferenceScanner::new(pattern); let mut mock = Builder::new(); for c in HELLO_DRV.as_bytes().chunks(chunk_size) { mock.read(c); } let mock = mock.build(); - let mut reader = ReferenceReader::with_capacity(capacity, pattern, mock); + let mut reader = ReferenceReader::with_capacity(capacity, &scanner, mock); let mut s = String::new(); reader.read_to_string(&mut s).await.unwrap(); assert_eq!(s, HELLO_DRV); - let result = reader.finalise(); + let result = scanner.finalise(); assert_eq!(result.len(), 3); for c in candidates[..3].iter() { |