diff options
Diffstat (limited to 'tvix/castore/src/nodes')
-rw-r--r-- | tvix/castore/src/nodes/mod.rs | 2 | ||||
-rw-r--r-- | tvix/castore/src/nodes/symlink_target.rs | 171 |
2 files changed, 157 insertions, 16 deletions
diff --git a/tvix/castore/src/nodes/mod.rs b/tvix/castore/src/nodes/mod.rs index 684e65f89b25..ac7aa1e666df 100644 --- a/tvix/castore/src/nodes/mod.rs +++ b/tvix/castore/src/nodes/mod.rs @@ -4,7 +4,7 @@ mod symlink_target; use crate::B3Digest; pub use directory::Directory; -pub use symlink_target::SymlinkTarget; +pub use symlink_target::{SymlinkTarget, SymlinkTargetError}; /// A Node is either a [DirectoryNode], [FileNode] or [SymlinkNode]. /// Nodes themselves don't have names, what gives them names is either them diff --git a/tvix/castore/src/nodes/symlink_target.rs b/tvix/castore/src/nodes/symlink_target.rs index 838cdaaeda5b..e9a1a0bd05c2 100644 --- a/tvix/castore/src/nodes/symlink_target.rs +++ b/tvix/castore/src/nodes/symlink_target.rs @@ -1,6 +1,3 @@ -// TODO: split out this error -use crate::ValidateNodeError; - use bstr::ByteSlice; use std::fmt::{self, Debug, Display}; @@ -13,6 +10,10 @@ pub struct SymlinkTarget { inner: bytes::Bytes, } +/// The maximum length a symlink target can have. +/// Linux allows 4095 bytes here. +pub const MAX_TARGET_LEN: usize = 4095; + impl AsRef<[u8]> for SymlinkTarget { fn as_ref(&self) -> &[u8] { self.inner.as_ref() @@ -25,12 +26,28 @@ impl From<SymlinkTarget> for bytes::Bytes { } } +fn validate_symlink_target<B: AsRef<[u8]>>(symlink_target: B) -> Result<B, SymlinkTargetError> { + let v = symlink_target.as_ref(); + + if v.is_empty() { + return Err(SymlinkTargetError::Empty); + } + if v.len() > MAX_TARGET_LEN { + return Err(SymlinkTargetError::TooLong); + } + if v.contains(&0x00) { + return Err(SymlinkTargetError::Null); + } + + Ok(symlink_target) +} + impl TryFrom<bytes::Bytes> for SymlinkTarget { - type Error = ValidateNodeError; + type Error = SymlinkTargetError; fn try_from(value: bytes::Bytes) -> Result<Self, Self::Error> { - if value.is_empty() || value.contains(&b'\0') { - return Err(ValidateNodeError::InvalidSymlinkTarget(value)); + if let Err(e) = validate_symlink_target(&value) { + return Err(SymlinkTargetError::Convert(value, Box::new(e))); } Ok(Self { inner: value }) @@ -38,13 +55,11 @@ impl TryFrom<bytes::Bytes> for SymlinkTarget { } impl TryFrom<&'static [u8]> for SymlinkTarget { - type Error = ValidateNodeError; + type Error = SymlinkTargetError; fn try_from(value: &'static [u8]) -> Result<Self, Self::Error> { - if value.is_empty() || value.contains(&b'\0') { - return Err(ValidateNodeError::InvalidSymlinkTarget( - bytes::Bytes::from_static(value), - )); + if let Err(e) = validate_symlink_target(&value) { + return Err(SymlinkTargetError::Convert(value.into(), Box::new(e))); } Ok(Self { @@ -54,12 +69,13 @@ impl TryFrom<&'static [u8]> for SymlinkTarget { } impl TryFrom<&str> for SymlinkTarget { - type Error = ValidateNodeError; + type Error = SymlinkTargetError; fn try_from(value: &str) -> Result<Self, Self::Error> { - if value.is_empty() { - return Err(ValidateNodeError::InvalidSymlinkTarget( - bytes::Bytes::copy_from_slice(value.as_bytes()), + if let Err(e) = validate_symlink_target(value) { + return Err(SymlinkTargetError::Convert( + value.to_owned().into(), + Box::new(e), )); } @@ -80,3 +96,128 @@ impl Display for SymlinkTarget { Display::fmt(self.inner.as_bstr(), f) } } + +/// Errors created when constructing / converting to [SymlinkTarget]. +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +#[cfg_attr(test, derive(Clone))] +pub enum SymlinkTargetError { + #[error("cannot be empty")] + Empty, + #[error("cannot contain null bytes")] + Null, + #[error("cannot be over {} bytes long", MAX_TARGET_LEN)] + TooLong, + #[error("unable to convert '{:?}", .0.as_bstr())] + Convert(bytes::Bytes, Box<Self>), +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use rstest::rstest; + + use super::validate_symlink_target; + use super::{SymlinkTarget, SymlinkTargetError}; + + #[rstest] + #[case::empty(b"", SymlinkTargetError::Empty)] + #[case::null(b"foo\0", SymlinkTargetError::Null)] + fn errors(#[case] v: &'static [u8], #[case] err: SymlinkTargetError) { + { + assert_eq!( + Err(err.clone()), + validate_symlink_target(v), + "validate_symlink_target must fail as expected" + ); + } + + let exp_err_v = Bytes::from_static(v); + + // Bytes + { + let v = Bytes::from_static(v); + assert_eq!( + Err(SymlinkTargetError::Convert( + exp_err_v.clone(), + Box::new(err.clone()) + )), + SymlinkTarget::try_from(v), + "conversion must fail as expected" + ); + } + // &[u8] + { + assert_eq!( + Err(SymlinkTargetError::Convert( + exp_err_v.clone(), + Box::new(err.clone()) + )), + SymlinkTarget::try_from(v), + "conversion must fail as expected" + ); + } + // &str, if this is valid UTF-8 + { + if let Ok(v) = std::str::from_utf8(v) { + assert_eq!( + Err(SymlinkTargetError::Convert( + exp_err_v.clone(), + Box::new(err.clone()) + )), + SymlinkTarget::try_from(v), + "conversion must fail as expected" + ); + } + } + } + + #[test] + fn error_toolong() { + assert_eq!( + Err(SymlinkTargetError::TooLong), + validate_symlink_target("X".repeat(5000).into_bytes().as_slice()) + ) + } + + #[rstest] + #[case::boring(b"aa")] + #[case::dot(b".")] + #[case::dotsandslashes(b"./..")] + #[case::dotdot(b"..")] + #[case::slashes(b"a/b")] + #[case::slashes_and_absolute(b"/a/b")] + #[case::invalid_utf8(b"\xc5\xc4\xd6")] + fn success(#[case] v: &'static [u8]) { + let exp = SymlinkTarget { inner: v.into() }; + + // Bytes + { + let v: Bytes = v.into(); + assert_eq!( + Ok(exp.clone()), + SymlinkTarget::try_from(v), + "conversion must succeed" + ) + } + + // &[u8] + { + assert_eq!( + Ok(exp.clone()), + SymlinkTarget::try_from(v), + "conversion must succeed" + ) + } + + // &str, if this is valid UTF-8 + { + if let Ok(v) = std::str::from_utf8(v) { + assert_eq!( + Ok(exp.clone()), + SymlinkTarget::try_from(v), + "conversion must succeed" + ) + } + } + } +} |