diff options
Diffstat (limited to 'users/tazjin/german-string/src/lib.rs')
-rw-r--r-- | users/tazjin/german-string/src/lib.rs | 435 |
1 files changed, 435 insertions, 0 deletions
diff --git a/users/tazjin/german-string/src/lib.rs b/users/tazjin/german-string/src/lib.rs new file mode 100644 index 000000000000..328eca309f38 --- /dev/null +++ b/users/tazjin/german-string/src/lib.rs @@ -0,0 +1,435 @@ +use std::alloc::Layout; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; + +#[derive(Clone, Copy)] +#[repr(C)] +struct GSSmall { + len: u32, + data: [u8; 12], +} + +#[derive(Clone, Copy)] +#[repr(transparent)] +struct StorageClassPtr(usize); + +impl StorageClassPtr { + fn transient(ptr: *const u8) -> Self { + debug_assert!( + (ptr as usize & 0b1) == 0, + "pointer must be at least 2-byte aligned" + ); + Self(ptr as usize) + } + + fn persistent(ptr: *const u8) -> Self { + debug_assert!( + (ptr as usize & 0b1) == 0, + "pointer must be at least 2-byte aligned" + ); + Self((ptr as usize) | 0b1) + } + + fn as_ptr(&self) -> *const u8 { + (self.0 & !0b1) as *const u8 + } + + unsafe fn as_mut_ptr(&self) -> *mut u8 { + (self.0 & !0b1) as *mut u8 + } + + fn is_transient(&self) -> bool { + (self.0 & 0b1) == 0 + } +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct GSLarge { + len: u32, + prefix: [u8; 4], + data: StorageClassPtr, +} + +const _ASSERT_VARIANTS_SIZE: () = assert!( + std::mem::size_of::<GSSmall>() == std::mem::size_of::<GSLarge>(), + "German String variants must have the same size" +); + +union GSRepr { + small: GSSmall, + large: GSLarge, +} + +#[repr(transparent)] +pub struct GermanString(GSRepr); + +const _ASSERT_GSTRING_SIZE: () = assert!( + std::mem::size_of::<GermanString>() == 16, + "German String should be 16 bytes in size", +); + +impl GermanString { + /// Creates a new transient German String from the given slice, copying the + /// data in the process. + pub fn transient(bytes: &[u8]) -> GermanString { + if bytes.len() > u32::MAX as usize { + panic!("GermanString maximum length is {} bytes", u32::MAX); + } + + if bytes.len() <= 12 { + let mut s = GSSmall { + len: bytes.len() as u32, + data: [0u8; 12], + }; + s.data[..bytes.len()].copy_from_slice(bytes); + GermanString(GSRepr { small: s }) + } else { + let layout = Layout::array::<u8>(bytes.len()).unwrap(); + let mut large = GSLarge { + len: bytes.len() as u32, + prefix: [0u8; 4], + data: unsafe { + let ptr = std::alloc::alloc(layout); + if ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, bytes.len()); + StorageClassPtr::transient(ptr) + }, + }; + + large.prefix.copy_from_slice(&bytes[..4]); + + GermanString(GSRepr { large }) + } + } + + /// Creates a new transient German String from the given owned bytes. Short + /// strings will be copied into the string representation, long strings will + /// be moved out of the given vector without additional allocations. + pub fn transient_from_owned(bytes: Vec<u8>) -> GermanString { + if bytes.len() > u32::MAX as usize { + panic!("GermanString maximum length is {} bytes", u32::MAX); + } + + if bytes.len() <= 12 { + let mut s = GSSmall { + len: bytes.len() as u32, + data: [0u8; 12], + }; + + s.data[..bytes.len()].copy_from_slice(&bytes); + GermanString(GSRepr { small: s }) + } else { + let md = std::mem::ManuallyDrop::new(bytes); + let mut large = GSLarge { + len: md.len() as u32, + prefix: [0u8; 4], + data: StorageClassPtr::transient(md.as_ptr()), + }; + + large.prefix.copy_from_slice(&md[..4]); + GermanString(GSRepr { large }) + } + } + + /// Creates a persistent German String from a static data buffer. + pub fn persistent(bytes: &'static [u8]) -> GermanString { + if bytes.len() > u32::MAX as usize { + panic!("GermanString maximum length is {} bytes", u32::MAX); + } + + if bytes.len() <= 12 { + let mut s = GSSmall { + len: bytes.len() as u32, + data: [0u8; 12], + }; + + s.data[..bytes.len()].copy_from_slice(&bytes); + GermanString(GSRepr { small: s }) + } else { + let mut large = GSLarge { + len: bytes.len() as u32, + prefix: [0u8; 4], + data: StorageClassPtr::persistent(bytes.as_ptr()), + }; + + large.prefix.copy_from_slice(&bytes[..4]); + GermanString(GSRepr { large }) + } + } + + /// Creates a persistent German String by leaking the provided data. + pub fn persistent_leak(bytes: Vec<u8>) -> GermanString { + if bytes.len() > u32::MAX as usize { + panic!("GermanString maximum length is {} bytes", u32::MAX); + } + + if bytes.len() <= 12 { + let mut s = GSSmall { + len: bytes.len() as u32, + data: [0u8; 12], + }; + + s.data[..bytes.len()].copy_from_slice(&bytes); + GermanString(GSRepr { small: s }) + } else { + let md = std::mem::ManuallyDrop::new(bytes); + let mut large = GSLarge { + len: md.len() as u32, + prefix: [0u8; 4], + data: StorageClassPtr::persistent(md.as_ptr()), + }; + + large.prefix.copy_from_slice(&md[..4]); + GermanString(GSRepr { large }) + } + } + + /// Creates a persistent German String from a static data buffer. + pub fn persistent_from_str(s: &'static str) -> GermanString { + GermanString::persistent(s.as_bytes()) + } + + pub fn len(&self) -> usize { + // SAFETY: The length field is located in the same location for both + // variants, reading it from either is safe. + unsafe { self.0.small.len as usize } + } + + pub fn as_bytes(&self) -> &[u8] { + if self.len() > 12 { + unsafe { std::slice::from_raw_parts(self.0.large.data.as_ptr(), self.len()) } + } else { + unsafe { &self.0.small.data.as_ref()[..self.len()] } + } + } + + pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.as_bytes()) + } +} + +impl Drop for GermanString { + fn drop(&mut self) { + unsafe { + if self.len() > 12 && self.0.large.data.is_transient() { + let layout = Layout::array::<u8>(self.len()).unwrap(); + std::alloc::dealloc(self.0.large.data.as_mut_ptr(), layout); + } + } + } +} + +impl PartialEq for GermanString { + fn eq(&self, other: &GermanString) -> bool { + if self.len() != other.len() { + return false; + } + + unsafe { + if self.len() <= 12 { + return self.0.small.data[..self.len()] == other.0.small.data[..other.len()]; + } + return self.0.large.data.as_ptr() == other.0.large.data.as_ptr() + || (self.0.large.prefix == other.0.large.prefix + && self.as_bytes() == other.as_bytes()); + } + } +} + +impl Eq for GermanString {} + +impl Ord for GermanString { + fn cmp(&self, other: &GermanString) -> Ordering { + match (self.len().cmp(&12), other.len().cmp(&12)) { + // two small strings + (Ordering::Less | Ordering::Equal, Ordering::Less | Ordering::Equal) => unsafe { + self.0.small.data[..self.len()].cmp(&other.0.small.data[..other.len()]) + }, + // two large strings + (Ordering::Greater, Ordering::Greater) => unsafe { + match self.0.large.prefix.cmp(&other.0.large.prefix) { + Ordering::Equal => self.as_bytes().cmp(other.as_bytes()), + ordering => ordering, + } + }, + + // LHS large, RHS small + (Ordering::Greater, _) => { + let prefix_ordering = + unsafe { self.0.large.prefix.as_slice().cmp(&other.0.small.data[..4]) }; + + if prefix_ordering != Ordering::Equal { + return prefix_ordering; + } + + self.as_bytes().cmp(other.as_bytes()) + } + + // LHS small, RHS large + (_, Ordering::Greater) => { + let prefix_ordering = + unsafe { self.0.small.data[..4].cmp(other.0.large.prefix.as_slice()) }; + + if prefix_ordering != Ordering::Equal { + return prefix_ordering; + } + + self.as_bytes().cmp(other.as_bytes()) + } + } + } +} + +impl PartialOrd for GermanString { + fn partial_cmp(&self, other: &GermanString) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +impl Debug for GermanString { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + String::from_utf8_lossy(self.as_bytes()).fmt(f) + } +} + +impl Clone for GermanString { + fn clone(&self) -> Self { + unsafe { + if self.len() <= 12 { + return GermanString(GSRepr { + small: self.0.small.clone(), + }); + } + + if self.0.large.data.is_transient() { + return GermanString::transient(self.as_bytes()); + } + + return GermanString(GSRepr { + large: self.0.large.clone(), + }); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + impl Arbitrary for GermanString { + type Parameters = <String as Arbitrary>::Parameters; + type Strategy = BoxedStrategy<Self>; + + fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { + any_with::<String>(args) + .prop_map(|s| GermanString::transient(s.as_bytes())) + .boxed() + } + } + + #[test] + fn test_empty_string() { + let empty = GermanString::transient(b""); + + assert_eq!(empty.len(), 0, "empty string should be empty"); + assert_eq!(empty.as_bytes(), b"", "empty string should contain nothing"); + assert_eq!( + empty.as_str().expect("empty string is valid UTF-8"), + "", + "empty string should contain empty string" + ); + } + + #[test] + fn test_short_string() { + let short = GermanString::transient(b"meow"); + + assert_eq!(short.len(), 4, "'meow' is four characters"); + assert_eq!( + short.as_bytes(), + b"meow", + "short string returns correct bytes" + ); + assert_eq!( + short.as_str().expect("'meow' is valid UTF-8"), + "meow", + "short string returns correct string" + ); + } + + #[test] + fn test_long_string() { + let input: &str = "This code was written at https://signal.live"; + let long = GermanString::transient(input.as_bytes()); + + assert_eq!(long.len(), 44, "long string has correct length"); + assert_eq!( + long.as_bytes(), + input.as_bytes(), + "long string returns correct bytes" + ); + + assert_eq!( + long.as_str().expect("input is valid UTF-8"), + input, + "long string returns correct string" + ); + } + + proptest! { + #[test] + fn test_roundtrip_vec(input: Vec<u8>) { + let gs = GermanString::transient_from_owned(input.clone()); + assert_eq!(input.len(), gs.len(), "length should match"); + + let out = gs.as_bytes().to_owned(); + assert_eq!(input, out, "roundtrip should yield same bytes"); + } + + #[test] + fn test_roundtrip_string(input: String) { + let gs = GermanString::transient_from_owned(input.clone().into_bytes()); + assert_eq!(input.len(), gs.len(), "length should match"); + + let out = String::from_utf8(gs.as_bytes().to_owned()) + .expect("string should be valid after roundtrip"); + + assert_eq!(input, out, "roundtrip should yield same string"); + } + + // Test [`Eq`] implementation. + #[test] + fn test_eq(lhs: Vec<u8>, rhs: Vec<u8>) { + let lhs_gs = GermanString::transient(lhs.as_slice()); + let rhs_gs = GermanString::transient(rhs.as_slice()); + + assert_eq!( + (lhs == rhs), + (lhs_gs == rhs_gs), + "Eq should match between std::String and GermanString ({:?} == {:?})", + lhs, rhs, + ); + } + + #[test] + fn test_reflexivity(x: GermanString) { + prop_assert!(x == x); + } + + #[test] + fn test_symmetry(x: GermanString, y: GermanString) { + prop_assert_eq!(x == y, y == x); + } + + #[test] + fn test_transitivity(x: GermanString, y: GermanString, z: GermanString) { + if x == y && y == z { + assert!(x == z); + } + } + } +} |