about summary refs log tree commit diff
path: root/tvix/nix-compat/src
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src')
-rw-r--r--tvix/nix-compat/src/nixbase32.rs64
1 files changed, 42 insertions, 22 deletions
diff --git a/tvix/nix-compat/src/nixbase32.rs b/tvix/nix-compat/src/nixbase32.rs
index d70e4c1a8290..558d7ae8f48e 100644
--- a/tvix/nix-compat/src/nixbase32.rs
+++ b/tvix/nix-compat/src/nixbase32.rs
@@ -51,17 +51,21 @@ pub fn encode(input: &[u8]) -> String {
 }
 
 /// This maps a nixbase32-encoded character to its binary representation, which
-/// is also the index of the character in the alphabet.
-fn decode_char(encoded_char: u8) -> Option<u8> {
-    Some(match encoded_char {
-        b'0'..=b'9' => encoded_char - b'0',
-        b'a'..=b'd' => encoded_char - b'a' + 10_u8,
-        b'f'..=b'n' => encoded_char - b'f' + 14_u8,
-        b'p'..=b's' => encoded_char - b'p' + 23_u8,
-        b'v'..=b'z' => encoded_char - b'v' + 27_u8,
-        _ => return None,
-    })
-}
+/// is also the index of the character in the alphabet. Invalid characters are
+/// mapped to 0xFF, which is itself an invalid value.
+const BASE32_ORD: [u8; 256] = {
+    let mut ord = [0xFF; 256];
+    let mut alphabet = ALPHABET.as_slice();
+    let mut i = 0;
+
+    while let &[c, ref tail @ ..] = alphabet {
+        ord[c as usize] = i;
+        alphabet = tail;
+        i += 1;
+    }
+
+    ord
+};
 
 /// Returns decoded input
 pub fn decode(input: &[u8]) -> Result<Vec<u8>, Nixbase32DecodeError> {
@@ -70,18 +74,23 @@ pub fn decode(input: &[u8]) -> Result<Vec<u8>, Nixbase32DecodeError> {
 
     // loop over all characters in reverse, and keep the iteration count in n.
     let mut carry = 0;
+    let mut mask = 0;
     for (n, &c) in input.iter().rev().enumerate() {
-        if let Some(digit) = decode_char(c) {
-            let b = n * 5;
-            let i = b / 8;
-            let j = b % 8;
-
-            let value = (digit as u16) << j;
-            output[i] |= value as u8 | carry;
-            carry = (value >> 8) as u8;
-        } else {
-            return Err(Nixbase32DecodeError::CharacterNotInAlphabet(c));
-        }
+        let b = n * 5;
+        let i = b / 8;
+        let j = b % 8;
+
+        let digit = BASE32_ORD[c as usize];
+        let value = (digit as u16) << j;
+        output[i] |= value as u8 | carry;
+        carry = (value >> 8) as u8;
+
+        mask |= digit;
+    }
+
+    if mask == 0xFF {
+        let c = find_invalid(input);
+        return Err(Nixbase32DecodeError::CharacterNotInAlphabet(c));
     }
 
     // if we're at the end, but have a nonzero carry, the encoding is invalid.
@@ -92,6 +101,17 @@ pub fn decode(input: &[u8]) -> Result<Vec<u8>, Nixbase32DecodeError> {
     Ok(output)
 }
 
+#[cold]
+fn find_invalid(input: &[u8]) -> u8 {
+    for &c in input {
+        if !ALPHABET.contains(&c) {
+            return c;
+        }
+    }
+
+    unreachable!()
+}
+
 /// Returns the decoded length of an input of length len.
 pub fn decode_len(len: usize) -> usize {
     (len * 5) / 8