use bstr::ByteSlice; use std::fmt::{self, Debug, Display}; /// A wrapper type for symlink targets. /// Internally uses a [bytes::Bytes], but disallows empty targets and those /// containing null bytes. #[repr(transparent)] #[derive(Clone, PartialEq, Eq)] pub struct SymlinkTarget { inner: bytes::Bytes, } /// The maximum length a symlink target can have. /// Linux allows 4095 bytes here. pub const MAX_TARGET_LEN: usize = 4095; impl AsRef<[u8]> for SymlinkTarget { fn as_ref(&self) -> &[u8] { self.inner.as_ref() } } impl From for bytes::Bytes { fn from(value: SymlinkTarget) -> Self { value.inner } } fn validate_symlink_target>(symlink_target: B) -> Result { let v = symlink_target.as_ref(); if v.is_empty() { return Err(SymlinkTargetError::Empty); } if v.len() > MAX_TARGET_LEN { return Err(SymlinkTargetError::TooLong); } if v.contains(&0x00) { return Err(SymlinkTargetError::Null); } Ok(symlink_target) } impl TryFrom for SymlinkTarget { type Error = SymlinkTargetError; fn try_from(value: bytes::Bytes) -> Result { if let Err(e) = validate_symlink_target(&value) { return Err(SymlinkTargetError::Convert(value, Box::new(e))); } Ok(Self { inner: value }) } } impl TryFrom<&'static [u8]> for SymlinkTarget { type Error = SymlinkTargetError; fn try_from(value: &'static [u8]) -> Result { if let Err(e) = validate_symlink_target(&value) { return Err(SymlinkTargetError::Convert(value.into(), Box::new(e))); } Ok(Self { inner: bytes::Bytes::from_static(value), }) } } impl TryFrom<&str> for SymlinkTarget { type Error = SymlinkTargetError; fn try_from(value: &str) -> Result { if let Err(e) = validate_symlink_target(value) { return Err(SymlinkTargetError::Convert( value.to_owned().into(), Box::new(e), )); } Ok(Self { inner: bytes::Bytes::copy_from_slice(value.as_bytes()), }) } } impl Debug for SymlinkTarget { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Debug::fmt(self.inner.as_bstr(), f) } } impl Display for SymlinkTarget { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Display::fmt(self.inner.as_bstr(), f) } } /// Errors created when constructing / converting to [SymlinkTarget]. #[derive(Debug, PartialEq, Eq, thiserror::Error)] #[cfg_attr(test, derive(Clone))] pub enum SymlinkTargetError { #[error("cannot be empty")] Empty, #[error("cannot contain null bytes")] Null, #[error("cannot be over {} bytes long", MAX_TARGET_LEN)] TooLong, #[error("unable to convert '{:?}", .0.as_bstr())] Convert(bytes::Bytes, Box), } #[cfg(test)] mod tests { use bytes::Bytes; use rstest::rstest; use super::validate_symlink_target; use super::{SymlinkTarget, SymlinkTargetError}; #[rstest] #[case::empty(b"", SymlinkTargetError::Empty)] #[case::null(b"foo\0", SymlinkTargetError::Null)] fn errors(#[case] v: &'static [u8], #[case] err: SymlinkTargetError) { { assert_eq!( Err(err.clone()), validate_symlink_target(v), "validate_symlink_target must fail as expected" ); } let exp_err_v = Bytes::from_static(v); // Bytes { let v = Bytes::from_static(v); assert_eq!( Err(SymlinkTargetError::Convert( exp_err_v.clone(), Box::new(err.clone()) )), SymlinkTarget::try_from(v), "conversion must fail as expected" ); } // &[u8] { assert_eq!( Err(SymlinkTargetError::Convert( exp_err_v.clone(), Box::new(err.clone()) )), SymlinkTarget::try_from(v), "conversion must fail as expected" ); } // &str, if this is valid UTF-8 { if let Ok(v) = std::str::from_utf8(v) { assert_eq!( Err(SymlinkTargetError::Convert( exp_err_v.clone(), Box::new(err.clone()) )), SymlinkTarget::try_from(v), "conversion must fail as expected" ); } } } #[test] fn error_toolong() { assert_eq!( Err(SymlinkTargetError::TooLong), validate_symlink_target("X".repeat(5000).into_bytes().as_slice()) ) } #[rstest] #[case::boring(b"aa")] #[case::dot(b".")] #[case::dotsandslashes(b"./..")] #[case::dotdot(b"..")] #[case::slashes(b"a/b")] #[case::slashes_and_absolute(b"/a/b")] #[case::invalid_utf8(b"\xc5\xc4\xd6")] fn success(#[case] v: &'static [u8]) { let exp = SymlinkTarget { inner: v.into() }; // Bytes { let v: Bytes = v.into(); assert_eq!( Ok(exp.clone()), SymlinkTarget::try_from(v), "conversion must succeed" ) } // &[u8] { assert_eq!( Ok(exp.clone()), SymlinkTarget::try_from(v), "conversion must succeed" ) } // &str, if this is valid UTF-8 { if let Ok(v) = std::str::from_utf8(v) { assert_eq!( Ok(exp.clone()), SymlinkTarget::try_from(v), "conversion must succeed" ) } } } }