diff options
Diffstat (limited to 'tvix/nix-compat/src/nix_daemon/protocol_version.rs')
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/protocol_version.rs | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/protocol_version.rs b/tvix/nix-compat/src/nix_daemon/protocol_version.rs new file mode 100644 index 000000000000..8fd2b085c962 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/protocol_version.rs @@ -0,0 +1,123 @@ +/// Protocol versions are represented as a u16. +/// The upper 8 bits are the major version, the lower bits the minor. +/// This is not aware of any endianness, use [crate::wire::read_u64] to get an +/// u64 first, and the try_from() impl from here if you're receiving over the +/// Nix Worker protocol. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct ProtocolVersion(u16); + +impl ProtocolVersion { + pub const fn from_parts(major: u8, minor: u8) -> Self { + Self(((major as u16) << 8) | minor as u16) + } + + pub fn major(&self) -> u8 { + ((self.0 & 0xff00) >> 8) as u8 + } + + pub fn minor(&self) -> u8 { + (self.0 & 0x00ff) as u8 + } +} + +impl PartialOrd for ProtocolVersion { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for ProtocolVersion { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match self.major().cmp(&other.major()) { + std::cmp::Ordering::Less => std::cmp::Ordering::Less, + std::cmp::Ordering::Greater => std::cmp::Ordering::Greater, + std::cmp::Ordering::Equal => { + // same major, compare minor + self.minor().cmp(&other.minor()) + } + } + } +} + +impl From<u16> for ProtocolVersion { + fn from(value: u16) -> Self { + Self::from_parts(((value & 0xff00) >> 8) as u8, (value & 0x00ff) as u8) + } +} + +impl TryFrom<u64> for ProtocolVersion { + type Error = &'static str; + + fn try_from(value: u64) -> Result<Self, Self::Error> { + if value & !0xffff != 0 { + return Err("only two least significant bits might be populated"); + } + + Ok((value as u16).into()) + } +} + +impl From<ProtocolVersion> for u16 { + fn from(value: ProtocolVersion) -> Self { + value.0 + } +} + +impl From<ProtocolVersion> for u64 { + fn from(value: ProtocolVersion) -> Self { + value.0 as u64 + } +} + +impl std::fmt::Display for ProtocolVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}", self.major(), self.minor()) + } +} + +#[cfg(test)] +mod tests { + use super::ProtocolVersion; + + #[test] + fn from_parts() { + let version = ProtocolVersion::from_parts(1, 37); + assert_eq!(version.major(), 1, "correct major"); + assert_eq!(version.minor(), 37, "correct minor"); + assert_eq!("1.37", &version.to_string(), "to_string"); + + assert_eq!(0x0125, Into::<u16>::into(version)); + assert_eq!(0x0125, Into::<u64>::into(version)); + } + + #[test] + fn from_u16() { + let version = ProtocolVersion::from(0x0125_u16); + assert_eq!("1.37", &version.to_string()); + } + + #[test] + fn from_u64() { + let version = ProtocolVersion::try_from(0x0125_u64).expect("must succeed"); + assert_eq!("1.37", &version.to_string()); + } + + /// This contains data in higher bits, which should fail. + #[test] + fn from_u64_fail() { + ProtocolVersion::try_from(0xaa0125_u64).expect_err("must fail"); + } + + #[test] + fn ord() { + let v0_37 = ProtocolVersion::from_parts(0, 37); + let v1_37 = ProtocolVersion::from_parts(1, 37); + let v1_40 = ProtocolVersion::from_parts(1, 40); + + assert!(v0_37 < v1_37); + assert!(v1_37 > v0_37); + assert!(v1_37 < v1_40); + assert!(v1_40 > v1_37); + assert!(v1_40 <= v1_40); + } +} |