// path: root/tvix/castore/src/nodes/symlink_target.rs
use bstr::ByteSlice;
use std::fmt::{self, Debug, Display};

/// A wrapper type for symlink targets.
/// Internally uses a [bytes::Bytes], but disallows empty targets, targets
/// longer than [MAX_TARGET_LEN] bytes, and targets containing null bytes.
///
/// The `inner` field is private, so code outside this module obtains values
/// through the fallible `TryFrom` conversions, which enforce these rules.
#[repr(transparent)]
#[derive(Clone, PartialEq, Eq)]
pub struct SymlinkTarget {
    inner: bytes::Bytes,
}

/// Upper bound, in bytes, on the length of a symlink target.
/// Chosen to match Linux, which accepts targets of up to 4095 bytes.
pub const MAX_TARGET_LEN: usize = 4096 - 1;

impl AsRef<[u8]> for SymlinkTarget {
    /// Borrows the raw, already-validated target bytes.
    fn as_ref(&self) -> &[u8] {
        &self.inner
    }
}

impl From<SymlinkTarget> for bytes::Bytes {
    fn from(value: SymlinkTarget) -> Self {
        value.inner
    }
}

/// Checks that `symlink_target` is acceptable as a symlink target and
/// returns it unchanged if so.
///
/// Rules are checked in this order, and the first violation wins:
/// 1. it must not be empty ([SymlinkTargetError::Empty]);
/// 2. it must not exceed [MAX_TARGET_LEN] bytes ([SymlinkTargetError::TooLong]);
/// 3. it must not contain a `0x00` byte ([SymlinkTargetError::Null]).
fn validate_symlink_target<B: AsRef<[u8]>>(symlink_target: B) -> Result<B, SymlinkTargetError> {
    let raw = symlink_target.as_ref();

    // The length check runs before the null-byte scan, so oversized inputs
    // are rejected without being scanned.
    let violation = if raw.is_empty() {
        Some(SymlinkTargetError::Empty)
    } else if raw.len() > MAX_TARGET_LEN {
        Some(SymlinkTargetError::TooLong)
    } else if raw.contains(&0x00) {
        Some(SymlinkTargetError::Null)
    } else {
        None
    };

    match violation {
        Some(error) => Err(error),
        None => Ok(symlink_target),
    }
}

impl TryFrom<bytes::Bytes> for SymlinkTarget {
    type Error = SymlinkTargetError;

    /// Validates `value` and takes ownership of it on success.
    ///
    /// On failure the rejected bytes are returned inside
    /// [SymlinkTargetError::Convert], together with the validation error.
    fn try_from(value: bytes::Bytes) -> Result<Self, Self::Error> {
        match validate_symlink_target(&value).err() {
            None => Ok(Self { inner: value }),
            Some(cause) => Err(SymlinkTargetError::Convert(value, Box::new(cause))),
        }
    }
}

impl TryFrom<&'static [u8]> for SymlinkTarget {
    type Error = SymlinkTargetError;

    /// Wraps the static slice with `Bytes::from_static` (no copy) and
    /// validates it through the `bytes::Bytes` conversion.
    fn try_from(value: &'static [u8]) -> Result<Self, Self::Error> {
        let wrapped = bytes::Bytes::from_static(value);
        Self::try_from(wrapped)
    }
}

impl TryFrom<&str> for SymlinkTarget {
    type Error = SymlinkTargetError;

    /// Copies the string's bytes into a fresh buffer and validates them
    /// through the `bytes::Bytes` conversion.
    ///
    /// On failure, that copy is what is reported in
    /// [SymlinkTargetError::Convert].
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let owned = bytes::Bytes::copy_from_slice(value.as_bytes());
        Self::try_from(owned)
    }
}

impl Debug for SymlinkTarget {
    /// Delegates to the `Debug` impl of the target viewed as a `BStr`,
    /// forwarding the formatter (and any flags it carries) unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let target = self.inner.as_bstr();
        fmt::Debug::fmt(target, f)
    }
}

impl Display for SymlinkTarget {
    /// Delegates to the `Display` impl of the target viewed as a `BStr`,
    /// forwarding the formatter (and any flags it carries) unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let target = self.inner.as_bstr();
        fmt::Display::fmt(target, f)
    }
}

/// Errors created when constructing / converting to [SymlinkTarget].
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[cfg_attr(test, derive(Clone))]
pub enum SymlinkTargetError {
    #[error("cannot be empty")]
    Empty,
    #[error("cannot contain null bytes")]
    Null,
    #[error("cannot be over {} bytes long", MAX_TARGET_LEN)]
    TooLong,
    #[error("unable to convert '{:?}", .0.as_bstr())]
    Convert(bytes::Bytes, Box<Self>),
}

#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;

    use super::validate_symlink_target;
    use super::{SymlinkTarget, SymlinkTargetError, MAX_TARGET_LEN};

    /// A target exactly `MAX_TARGET_LEN` bytes long; it must still be accepted.
    static LONGEST_VALID: [u8; MAX_TARGET_LEN] = [b'X'; MAX_TARGET_LEN];

    /// A target one byte over `MAX_TARGET_LEN`; it must be rejected by every
    /// conversion, not just by `validate_symlink_target`.
    static TOO_LONG: [u8; MAX_TARGET_LEN + 1] = [b'X'; MAX_TARGET_LEN + 1];

    #[rstest]
    #[case::empty(b"", SymlinkTargetError::Empty)]
    #[case::null(b"foo\0", SymlinkTargetError::Null)]
    #[case::toolong(&TOO_LONG, SymlinkTargetError::TooLong)]
    fn errors(#[case] v: &'static [u8], #[case] err: SymlinkTargetError) {
        {
            assert_eq!(
                Err(err.clone()),
                validate_symlink_target(v),
                "validate_symlink_target must fail as expected"
            );
        }

        let exp_err_v = Bytes::from_static(v);

        // Bytes
        {
            let v = Bytes::from_static(v);
            assert_eq!(
                Err(SymlinkTargetError::Convert(
                    exp_err_v.clone(),
                    Box::new(err.clone())
                )),
                SymlinkTarget::try_from(v),
                "conversion must fail as expected"
            );
        }
        // &[u8]
        {
            assert_eq!(
                Err(SymlinkTargetError::Convert(
                    exp_err_v.clone(),
                    Box::new(err.clone())
                )),
                SymlinkTarget::try_from(v),
                "conversion must fail as expected"
            );
        }
        // &str, if this is valid UTF-8
        {
            if let Ok(v) = std::str::from_utf8(v) {
                assert_eq!(
                    Err(SymlinkTargetError::Convert(
                        exp_err_v.clone(),
                        Box::new(err.clone())
                    )),
                    SymlinkTarget::try_from(v),
                    "conversion must fail as expected"
                );
            }
        }
    }

    #[test]
    fn error_toolong() {
        assert_eq!(
            Err(SymlinkTargetError::TooLong),
            validate_symlink_target("X".repeat(5000).into_bytes().as_slice())
        )
    }

    #[rstest]
    #[case::boring(b"aa")]
    #[case::dot(b".")]
    #[case::dotsandslashes(b"./..")]
    #[case::dotdot(b"..")]
    #[case::slashes(b"a/b")]
    #[case::slashes_and_absolute(b"/a/b")]
    #[case::invalid_utf8(b"\xc5\xc4\xd6")]
    #[case::max_len(&LONGEST_VALID)]
    fn success(#[case] v: &'static [u8]) {
        let exp = SymlinkTarget { inner: v.into() };

        // Bytes
        {
            let v: Bytes = v.into();
            assert_eq!(
                Ok(exp.clone()),
                SymlinkTarget::try_from(v),
                "conversion must succeed"
            )
        }

        // &[u8]
        {
            assert_eq!(
                Ok(exp.clone()),
                SymlinkTarget::try_from(v),
                "conversion must succeed"
            )
        }

        // &str, if this is valid UTF-8
        {
            if let Ok(v) = std::str::from_utf8(v) {
                assert_eq!(
                    Ok(exp.clone()),
                    SymlinkTarget::try_from(v),
                    "conversion must succeed"
                )
            }
        }
    }
}