diff options
-rw-r--r-- | tvix/nix-compat/src/nixbase32.rs | 64 |
1 files changed, 42 insertions, 22 deletions
diff --git a/tvix/nix-compat/src/nixbase32.rs b/tvix/nix-compat/src/nixbase32.rs index d70e4c1a8290..558d7ae8f48e 100644 --- a/tvix/nix-compat/src/nixbase32.rs +++ b/tvix/nix-compat/src/nixbase32.rs @@ -51,17 +51,21 @@ pub fn encode(input: &[u8]) -> String { } /// This maps a nixbase32-encoded character to its binary representation, which -/// is also the index of the character in the alphabet. -fn decode_char(encoded_char: u8) -> Option<u8> { - Some(match encoded_char { - b'0'..=b'9' => encoded_char - b'0', - b'a'..=b'd' => encoded_char - b'a' + 10_u8, - b'f'..=b'n' => encoded_char - b'f' + 14_u8, - b'p'..=b's' => encoded_char - b'p' + 23_u8, - b'v'..=b'z' => encoded_char - b'v' + 27_u8, - _ => return None, - }) -} +/// is also the index of the character in the alphabet. Invalid characters are +/// mapped to 0xFF, which is itself an invalid value. +const BASE32_ORD: [u8; 256] = { + let mut ord = [0xFF; 256]; + let mut alphabet = ALPHABET.as_slice(); + let mut i = 0; + + while let &[c, ref tail @ ..] = alphabet { + ord[c as usize] = i; + alphabet = tail; + i += 1; + } + + ord +}; /// Returns decoded input pub fn decode(input: &[u8]) -> Result<Vec<u8>, Nixbase32DecodeError> { @@ -70,18 +74,23 @@ pub fn decode(input: &[u8]) -> Result<Vec<u8>, Nixbase32DecodeError> { // loop over all characters in reverse, and keep the iteration count in n. let mut carry = 0; + let mut mask = 0; for (n, &c) in input.iter().rev().enumerate() { - if let Some(digit) = decode_char(c) { - let b = n * 5; - let i = b / 8; - let j = b % 8; - - let value = (digit as u16) << j; - output[i] |= value as u8 | carry; - carry = (value >> 8) as u8; - } else { - return Err(Nixbase32DecodeError::CharacterNotInAlphabet(c)); - } + let b = n * 5; + let i = b / 8; + let j = b % 8; + + let digit = BASE32_ORD[c as usize]; + let value = (digit as u16) << j; + output[i] |= value as u8 | carry; + carry = (value >> 8) as u8; + + mask |= digit; + } + + if mask == 0xFF { + let c = find_invalid(input); + return Err(Nixbase32DecodeError::CharacterNotInAlphabet(c)); } // if we're at the end, but have a nonzero carry, the encoding is invalid. @@ -92,6 +101,17 @@ pub fn decode(input: &[u8]) -> Result<Vec<u8>, Nixbase32DecodeError> { Ok(output) } +#[cold] +fn find_invalid(input: &[u8]) -> u8 { + for &c in input { + if !ALPHABET.contains(&c) { + return c; + } + } + + unreachable!() +} + /// Returns the decoded length of an input of length len. pub fn decode_len(len: usize) -> usize { (len * 5) / 8 |