use std::alloc::Layout; use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; #[derive(Clone, Copy)] #[repr(C)] struct GSSmall { len: u32, data: [u8; 12], } #[derive(Clone, Copy)] #[repr(transparent)] struct StorageClassPtr(usize); impl StorageClassPtr { fn transient(ptr: *const u8) -> Self { debug_assert!( (ptr as usize & 0b1) == 0, "pointer must be at least 2-byte aligned" ); Self(ptr as usize) } fn persistent(ptr: *const u8) -> Self { debug_assert!( (ptr as usize & 0b1) == 0, "pointer must be at least 2-byte aligned" ); Self((ptr as usize) | 0b1) } fn as_ptr(&self) -> *const u8 { (self.0 & !0b1) as *const u8 } unsafe fn as_mut_ptr(&self) -> *mut u8 { (self.0 & !0b1) as *mut u8 } fn is_transient(&self) -> bool { (self.0 & 0b1) == 0 } } #[derive(Clone, Copy)] #[repr(C)] struct GSLarge { len: u32, prefix: [u8; 4], data: StorageClassPtr, } const _ASSERT_VARIANTS_SIZE: () = assert!( std::mem::size_of::() == std::mem::size_of::(), "German String variants must have the same size" ); union GSRepr { small: GSSmall, large: GSLarge, } #[repr(transparent)] pub struct GermanString(GSRepr); const _ASSERT_GSTRING_SIZE: () = assert!( std::mem::size_of::() == 16, "German String should be 16 bytes in size", ); impl GermanString { /// Creates a new transient German String from the given slice, copying the /// data in the process. pub fn transient(bytes: &[u8]) -> GermanString { if bytes.len() > u32::MAX as usize { panic!("GermanString maximum length is {} bytes", u32::MAX); } if bytes.len() <= 12 { let mut s = GSSmall { len: bytes.len() as u32, data: [0u8; 12], }; s.data[..bytes.len()].copy_from_slice(bytes); GermanString(GSRepr { small: s }) } else { let layout = Layout::array::(bytes.len()).unwrap(); let mut large = GSLarge { len: bytes.len() as u32, prefix: [0u8; 4], data: unsafe { let ptr = std::alloc::alloc(layout); if ptr.is_null() { std::alloc::handle_alloc_error(layout); } std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, bytes.len()); StorageClassPtr::transient(ptr) }, }; large.prefix.copy_from_slice(&bytes[..4]); GermanString(GSRepr { large }) } } /// Creates a new transient German String from the given owned bytes. Short /// strings will be copied into the string representation, long strings will /// be moved out of the given vector without additional allocations. pub fn transient_from_owned(bytes: Vec) -> GermanString { if bytes.len() > u32::MAX as usize { panic!("GermanString maximum length is {} bytes", u32::MAX); } if bytes.len() <= 12 { let mut s = GSSmall { len: bytes.len() as u32, data: [0u8; 12], }; s.data[..bytes.len()].copy_from_slice(&bytes); GermanString(GSRepr { small: s }) } else { let md = std::mem::ManuallyDrop::new(bytes); let mut large = GSLarge { len: md.len() as u32, prefix: [0u8; 4], data: StorageClassPtr::transient(md.as_ptr()), }; large.prefix.copy_from_slice(&md[..4]); GermanString(GSRepr { large }) } } /// Creates a persistent German String from a static data buffer. pub fn persistent(bytes: &'static [u8]) -> GermanString { if bytes.len() > u32::MAX as usize { panic!("GermanString maximum length is {} bytes", u32::MAX); } if bytes.len() <= 12 { let mut s = GSSmall { len: bytes.len() as u32, data: [0u8; 12], }; s.data[..bytes.len()].copy_from_slice(&bytes); GermanString(GSRepr { small: s }) } else { let mut large = GSLarge { len: bytes.len() as u32, prefix: [0u8; 4], data: StorageClassPtr::persistent(bytes.as_ptr()), }; large.prefix.copy_from_slice(&bytes[..4]); GermanString(GSRepr { large }) } } /// Creates a persistent German String by leaking the provided data. pub fn persistent_leak(bytes: Vec) -> GermanString { if bytes.len() > u32::MAX as usize { panic!("GermanString maximum length is {} bytes", u32::MAX); } if bytes.len() <= 12 { let mut s = GSSmall { len: bytes.len() as u32, data: [0u8; 12], }; s.data[..bytes.len()].copy_from_slice(&bytes); GermanString(GSRepr { small: s }) } else { let md = std::mem::ManuallyDrop::new(bytes); let mut large = GSLarge { len: md.len() as u32, prefix: [0u8; 4], data: StorageClassPtr::persistent(md.as_ptr()), }; large.prefix.copy_from_slice(&md[..4]); GermanString(GSRepr { large }) } } /// Creates a persistent German String from a static data buffer. pub fn persistent_from_str(s: &'static str) -> GermanString { GermanString::persistent(s.as_bytes()) } pub fn len(&self) -> usize { // SAFETY: The length field is located in the same location for both // variants, reading it from either is safe. unsafe { self.0.small.len as usize } } pub fn as_bytes(&self) -> &[u8] { if self.len() > 12 { unsafe { std::slice::from_raw_parts(self.0.large.data.as_ptr(), self.len()) } } else { unsafe { &self.0.small.data.as_ref()[..self.len()] } } } pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> { std::str::from_utf8(self.as_bytes()) } } impl Drop for GermanString { fn drop(&mut self) { unsafe { if self.len() > 12 && self.0.large.data.is_transient() { let layout = Layout::array::(self.len()).unwrap(); std::alloc::dealloc(self.0.large.data.as_mut_ptr(), layout); } } } } impl PartialEq for GermanString { fn eq(&self, other: &GermanString) -> bool { if self.len() != other.len() { return false; } unsafe { if self.len() <= 12 { return self.0.small.data[..self.len()] == other.0.small.data[..other.len()]; } return self.0.large.data.as_ptr() == other.0.large.data.as_ptr() || (self.0.large.prefix == other.0.large.prefix && self.as_bytes() == other.as_bytes()); } } } impl Eq for GermanString {} impl Ord for GermanString { fn cmp(&self, other: &GermanString) -> Ordering { match (self.len().cmp(&12), other.len().cmp(&12)) { // two small strings (Ordering::Less | Ordering::Equal, Ordering::Less | Ordering::Equal) => unsafe { self.0.small.data[..self.len()].cmp(&other.0.small.data[..other.len()]) }, // two large strings (Ordering::Greater, Ordering::Greater) => unsafe { match self.0.large.prefix.cmp(&other.0.large.prefix) { Ordering::Equal => self.as_bytes().cmp(other.as_bytes()), ordering => ordering, } }, // LHS large, RHS small (Ordering::Greater, _) => { let prefix_ordering = unsafe { self.0.large.prefix.as_slice().cmp(&other.0.small.data[..4]) }; if prefix_ordering != Ordering::Equal { return prefix_ordering; } self.as_bytes().cmp(other.as_bytes()) } // LHS small, RHS large (_, Ordering::Greater) => { let prefix_ordering = unsafe { self.0.small.data[..4].cmp(other.0.large.prefix.as_slice()) }; if prefix_ordering != Ordering::Equal { return prefix_ordering; } self.as_bytes().cmp(other.as_bytes()) } } } } impl PartialOrd for GermanString { fn partial_cmp(&self, other: &GermanString) -> Option { Some(self.cmp(other)) } } impl Debug for GermanString { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { String::from_utf8_lossy(self.as_bytes()).fmt(f) } } #[cfg(test)] mod tests { use super::*; use proptest::prelude::*; impl Arbitrary for GermanString { type Parameters = ::Parameters; type Strategy = BoxedStrategy; fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { any_with::(args) .prop_map(|s| GermanString::transient(s.as_bytes())) .boxed() } } #[test] fn test_empty_string() { let empty = GermanString::transient(b""); assert_eq!(empty.len(), 0, "empty string should be empty"); assert_eq!(empty.as_bytes(), b"", "empty string should contain nothing"); assert_eq!( empty.as_str().expect("empty string is valid UTF-8"), "", "empty string should contain empty string" ); } #[test] fn test_short_string() { let short = GermanString::transient(b"meow"); assert_eq!(short.len(), 4, "'meow' is four characters"); assert_eq!( short.as_bytes(), b"meow", "short string returns correct bytes" ); assert_eq!( short.as_str().expect("'meow' is valid UTF-8"), "meow", "short string returns correct string" ); } #[test] fn test_long_string() { let input: &str = "This code was written at https://signal.live"; let long = GermanString::transient(input.as_bytes()); assert_eq!(long.len(), 44, "long string has correct length"); assert_eq!( long.as_bytes(), input.as_bytes(), "long string returns correct bytes" ); assert_eq!( long.as_str().expect("input is valid UTF-8"), input, "long string returns correct string" ); } proptest! { #[test] fn test_roundtrip_vec(input: Vec) { let gs = GermanString::transient_from_owned(input.clone()); assert_eq!(input.len(), gs.len(), "length should match"); let out = gs.as_bytes().to_owned(); assert_eq!(input, out, "roundtrip should yield same bytes"); } #[test] fn test_roundtrip_string(input: String) { let gs = GermanString::transient_from_owned(input.clone().into_bytes()); assert_eq!(input.len(), gs.len(), "length should match"); let out = String::from_utf8(gs.as_bytes().to_owned()) .expect("string should be valid after roundtrip"); assert_eq!(input, out, "roundtrip should yield same string"); } // Test [`Eq`] implementation. #[test] fn test_eq(lhs: Vec, rhs: Vec) { let lhs_gs = GermanString::transient(lhs.as_slice()); let rhs_gs = GermanString::transient(rhs.as_slice()); assert_eq!( (lhs == rhs), (lhs_gs == rhs_gs), "Eq should match between std::String and GermanString ({:?} == {:?})", lhs, rhs, ); } #[test] fn test_reflexivity(x: GermanString) { prop_assert!(x == x); } #[test] fn test_symmetry(x: GermanString, y: GermanString) { prop_assert_eq!(x == y, y == x); } #[test] fn test_transitivity(x: GermanString, y: GermanString, z: GermanString) { if x == y && y == z { assert!(x == z); } } } }