about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tvix/nix-compat/src/nixbase32.rs35
1 files changed, 34 insertions, 1 deletions
diff --git a/tvix/nix-compat/src/nixbase32.rs b/tvix/nix-compat/src/nixbase32.rs
index fae80cd4a042..d9821bc495bb 100644
--- a/tvix/nix-compat/src/nixbase32.rs
+++ b/tvix/nix-compat/src/nixbase32.rs
@@ -20,6 +20,8 @@ pub enum Nixbase32DecodeError {
     CharacterNotInAlphabet(u8),
     #[error("nonzero carry")]
     NonzeroCarry(),
+    #[error("invalid length")]
+    InvalidLength,
 }
 
 /// Returns encoded input
@@ -74,6 +76,25 @@ pub fn decode(input: impl AsRef<[u8]>) -> Result<Vec<u8>, Nixbase32DecodeError>
     let output_len = decode_len(input.len());
     let mut output: Vec<u8> = vec![0x00; output_len];
 
+    decode_inner(input, &mut output)?;
+    Ok(output)
+}
+
+pub fn decode_fixed<const K: usize>(
+    input: impl AsRef<[u8]>,
+) -> Result<[u8; K], Nixbase32DecodeError> {
+    let input = input.as_ref();
+
+    if input.len() != encode_len(K) {
+        return Err(Nixbase32DecodeError::InvalidLength);
+    }
+
+    let mut output = [0; K];
+    decode_inner(input, &mut output)?;
+    Ok(output)
+}
+
+fn decode_inner(input: &[u8], output: &mut [u8]) -> Result<(), Nixbase32DecodeError> {
     // loop over all characters in reverse, and keep the iteration count in n.
     let mut carry = 0;
     let mut mask = 0;
@@ -100,7 +121,7 @@ pub fn decode(input: impl AsRef<[u8]>) -> Result<Vec<u8>, Nixbase32DecodeError>
         return Err(Nixbase32DecodeError::NonzeroCarry());
     }
 
-    Ok(output)
+    Ok(())
 }
 
 #[cold]
@@ -161,6 +182,18 @@ mod tests {
     }
 
     #[test]
+    fn decode_fixed() {
+        assert_eq!(
+            super::decode_fixed("00bgd045z0d4icpbc2yyz4gx48ak44la").unwrap(),
+            hex!("8a12321522fd91efbd60ebb2481af88580f61600")
+        );
+        assert_eq!(
+            super::decode_fixed::<32>("00").unwrap_err(),
+            super::Nixbase32DecodeError::InvalidLength
+        );
+    }
+
+    #[test]
     fn encode_len() {
         assert_eq!(super::encode_len(0), 0);
         assert_eq!(super::encode_len(20), 32);