diff options
Diffstat (limited to 'absl')
121 files changed, 22807 insertions, 305 deletions
diff --git a/absl/container/BUILD.bazel b/absl/container/BUILD.bazel index 998294c07cab..17d725c18ba9 100644 --- a/absl/container/BUILD.bazel +++ b/absl/container/BUILD.bazel @@ -127,6 +127,7 @@ cc_library( "//absl/base:core_headers", "//absl/memory", "//absl/meta:type_traits", + "//absl/types:span", ], ) diff --git a/absl/container/CMakeLists.txt b/absl/container/CMakeLists.txt index 5196e50346a8..6df331e17165 100644 --- a/absl/container/CMakeLists.txt +++ b/absl/container/CMakeLists.txt @@ -126,6 +126,7 @@ absl_cc_library( absl::compressed_tuple absl::core_headers absl::memory + absl::span absl::type_traits PUBLIC ) diff --git a/absl/container/inlined_vector.h b/absl/container/inlined_vector.h index 2c96cc37b47a..10881b22a918 100644 --- a/absl/container/inlined_vector.h +++ b/absl/container/inlined_vector.h @@ -166,7 +166,7 @@ class InlinedVector { InlinedVector(const InlinedVector& other, const allocator_type& alloc) : storage_(alloc) { if (IsMemcpyOk::value && !other.storage_.GetIsAllocated()) { - storage_.MemcpyContents(other.storage_); + storage_.MemcpyFrom(other.storage_); } else { storage_.Initialize(IteratorValueAdapter<const_pointer>(other.data()), other.size()); @@ -193,7 +193,7 @@ class InlinedVector { std::is_nothrow_move_constructible<value_type>::value) : storage_(*other.storage_.GetAllocPtr()) { if (IsMemcpyOk::value) { - storage_.MemcpyContents(other.storage_); + storage_.MemcpyFrom(other.storage_); other.storage_.SetInlinedSize(0); } else if (other.storage_.GetIsAllocated()) { storage_.SetAllocatedData(other.storage_.GetAllocatedData(), @@ -227,7 +227,7 @@ class InlinedVector { absl::allocator_is_nothrow<allocator_type>::value) : storage_(alloc) { if (IsMemcpyOk::value) { - storage_.MemcpyContents(other.storage_); + storage_.MemcpyFrom(other.storage_); other.storage_.SetInlinedSize(0); } else if ((*storage_.GetAllocPtr() == *other.storage_.GetAllocPtr()) && other.storage_.GetIsAllocated()) { @@ -464,26 +464,22 @@ class InlinedVector { InlinedVector& operator=(InlinedVector&& other) { if (ABSL_PREDICT_FALSE(this == std::addressof(other))) return *this; - if (other.storage_.GetIsAllocated()) { - clear(); - storage_.SetAllocatedSize(other.size()); - storage_.SetAllocatedData(other.storage_.GetAllocatedData(), - other.storage_.GetAllocatedCapacity()); + if (IsMemcpyOk::value || other.storage_.GetIsAllocated()) { + inlined_vector_internal::DestroyElements(storage_.GetAllocPtr(), data(), + size()); + if (storage_.GetIsAllocated()) { + AllocatorTraits::deallocate(*storage_.GetAllocPtr(), + storage_.GetAllocatedData(), + storage_.GetAllocatedCapacity()); + } + storage_.MemcpyFrom(other.storage_); other.storage_.SetInlinedSize(0); } else { - if (storage_.GetIsAllocated()) clear(); - // Both are inlined now. - if (size() < other.size()) { - auto mid = std::make_move_iterator(other.begin() + size()); - std::copy(std::make_move_iterator(other.begin()), mid, begin()); - UninitializedCopy(mid, std::make_move_iterator(other.end()), end()); - } else { - auto new_end = std::copy(std::make_move_iterator(other.begin()), - std::make_move_iterator(other.end()), begin()); - Destroy(new_end, end()); - } - storage_.SetInlinedSize(other.size()); + storage_.Assign(IteratorValueAdapter<MoveIterator>( + MoveIterator(other.storage_.GetInlinedData())), + other.size()); } + return *this; } @@ -491,23 +487,7 @@ class InlinedVector { // // Replaces the contents of the inlined vector with `n` copies of `v`. void assign(size_type n, const_reference v) { - if (n <= size()) { // Possibly shrink - std::fill_n(begin(), n, v); - erase(begin() + n, end()); - return; - } - // Grow - reserve(n); - std::fill_n(begin(), size(), v); - if (storage_.GetIsAllocated()) { - UninitializedFill(storage_.GetAllocatedData() + size(), - storage_.GetAllocatedData() + n, v); - storage_.SetAllocatedSize(n); - } else { - UninitializedFill(storage_.GetInlinedData() + size(), - storage_.GetInlinedData() + n, v); - storage_.SetInlinedSize(n); - } + storage_.Assign(CopyValueAdapter(v), n); } // Overload of `InlinedVector::assign()` to replace the contents of the @@ -522,24 +502,8 @@ class InlinedVector { template <typename ForwardIterator, EnableIfAtLeastForwardIterator<ForwardIterator>* = nullptr> void assign(ForwardIterator first, ForwardIterator last) { - auto length = std::distance(first, last); - - // Prefer reassignment to copy construction for elements. - if (static_cast<size_type>(length) <= size()) { - erase(std::copy(first, last, begin()), end()); - return; - } - - reserve(length); - iterator out = begin(); - for (; out != end(); ++first, ++out) *out = *first; - if (storage_.GetIsAllocated()) { - UninitializedCopy(first, last, out); - storage_.SetAllocatedSize(length); - } else { - UninitializedCopy(first, last, out); - storage_.SetInlinedSize(length); - } + storage_.Assign(IteratorValueAdapter<ForwardIterator>(first), + std::distance(first, last)); } // Overload of `InlinedVector::assign()` to replace the contents of the @@ -624,7 +588,15 @@ class InlinedVector { // of `v` starting at `pos`. Returns an `iterator` pointing to the first of // the newly inserted elements. iterator insert(const_iterator pos, size_type n, const_reference v) { - return InsertWithCount(pos, n, v); + assert(pos >= begin() && pos <= end()); + if (ABSL_PREDICT_FALSE(n == 0)) { + return const_cast<iterator>(pos); + } + value_type copy = v; + std::pair<iterator, iterator> it_pair = ShiftRight(pos, n); + std::fill(it_pair.first, it_pair.second, copy); + UninitializedFill(it_pair.second, it_pair.first + n, copy); + return it_pair.first; } // Overload of `InlinedVector::insert()` for copying the contents of the @@ -644,7 +616,17 @@ class InlinedVector { EnableIfAtLeastForwardIterator<ForwardIterator>* = nullptr> iterator insert(const_iterator pos, ForwardIterator first, ForwardIterator last) { - return InsertWithForwardRange(pos, first, last); + assert(pos >= begin() && pos <= end()); + if (ABSL_PREDICT_FALSE(first == last)) { + return const_cast<iterator>(pos); + } + auto n = std::distance(first, last); + std::pair<iterator, iterator> it_pair = ShiftRight(pos, n); + size_type used_spots = it_pair.second - it_pair.first; + auto open_spot = std::next(first, used_spots); + std::copy(first, open_spot, it_pair.first); + UninitializedCopy(open_spot, last, it_pair.second); + return it_pair.first; } // Overload of `InlinedVector::insert()` for inserting elements constructed @@ -696,17 +678,26 @@ class InlinedVector { reference emplace_back(Args&&... args) { size_type s = size(); if (ABSL_PREDICT_FALSE(s == capacity())) { - return GrowAndEmplaceBack(std::forward<Args>(args)...); - } - pointer space; - if (storage_.GetIsAllocated()) { - storage_.SetAllocatedSize(s + 1); - space = storage_.GetAllocatedData(); + size_type new_capacity = 2 * capacity(); + pointer new_data = + AllocatorTraits::allocate(*storage_.GetAllocPtr(), new_capacity); + reference new_element = + Construct(new_data + s, std::forward<Args>(args)...); + UninitializedCopy(std::make_move_iterator(data()), + std::make_move_iterator(data() + s), new_data); + ResetAllocation(new_data, new_capacity, s + 1); + return new_element; } else { - storage_.SetInlinedSize(s + 1); - space = storage_.GetInlinedData(); + pointer space; + if (storage_.GetIsAllocated()) { + storage_.SetAllocatedSize(s + 1); + space = storage_.GetAllocatedData(); + } else { + storage_.SetInlinedSize(s + 1); + space = storage_.GetInlinedData(); + } + return Construct(space + s, std::forward<Args>(args)...); } - return Construct(space + s, std::forward<Args>(args)...); } // `InlinedVector::push_back()` @@ -727,7 +718,7 @@ class InlinedVector { void pop_back() noexcept { assert(!empty()); AllocatorTraits::destroy(*storage_.GetAllocPtr(), data() + (size() - 1)); - storage_.AddSize(-1); + storage_.SubtractSize(1); } // `InlinedVector::erase()` @@ -794,10 +785,20 @@ class InlinedVector { // effects. Otherwise, `reserve()` will reallocate, performing an n-time // element-wise move of everything contained. void reserve(size_type n) { - if (n > capacity()) { - // Make room for new elements - EnlargeBy(n - size()); + if (n <= capacity()) { + return; + } + const size_type s = size(); + size_type target = (std::max)(static_cast<size_type>(N), n); + size_type new_capacity = capacity(); + while (new_capacity < target) { + new_capacity <<= 1; } + pointer new_data = + AllocatorTraits::allocate(*storage_.GetAllocPtr(), new_capacity); + UninitializedCopy(std::make_move_iterator(data()), + std::make_move_iterator(data() + s), new_data); + ResetAllocation(new_data, new_capacity, s); } // `InlinedVector::shrink_to_fit()` @@ -812,37 +813,105 @@ class InlinedVector { // If `size() > N` and `size() < capacity()` the elements will be moved to a // smaller heap allocation. void shrink_to_fit() { - const auto s = size(); - if (ABSL_PREDICT_FALSE(!storage_.GetIsAllocated() || s == capacity())) - return; - - if (s <= N) { - // Move the elements to the inlined storage. - // We have to do this using a temporary, because `inlined_storage` and - // `allocation_storage` are in a union field. - auto temp = std::move(*this); - assign(std::make_move_iterator(temp.begin()), - std::make_move_iterator(temp.end())); - return; + if (storage_.GetIsAllocated()) { + storage_.ShrinkToFit(); } - - // Reallocate storage and move elements. - // We can't simply use the same approach as above, because `assign()` would - // call into `reserve()` internally and reserve larger capacity than we need - pointer new_data = AllocatorTraits::allocate(*storage_.GetAllocPtr(), s); - UninitializedCopy(std::make_move_iterator(storage_.GetAllocatedData()), - std::make_move_iterator(storage_.GetAllocatedData() + s), - new_data); - ResetAllocation(new_data, s, s); } // `InlinedVector::swap()` // // Swaps the contents of this inlined vector with the contents of `other`. void swap(InlinedVector& other) { - if (ABSL_PREDICT_FALSE(this == std::addressof(other))) return; + using std::swap; + + if (ABSL_PREDICT_FALSE(this == std::addressof(other))) { + return; + } + + bool is_allocated = storage_.GetIsAllocated(); + bool other_is_allocated = other.storage_.GetIsAllocated(); + + if (is_allocated && other_is_allocated) { + // Both out of line, so just swap the tag, allocation, and allocator. + storage_.SwapSizeAndIsAllocated(std::addressof(other.storage_)); + storage_.SwapAllocatedSizeAndCapacity(std::addressof(other.storage_)); + swap(*storage_.GetAllocPtr(), *other.storage_.GetAllocPtr()); + + return; + } + + if (!is_allocated && !other_is_allocated) { + // Both inlined: swap up to smaller size, then move remaining elements. + InlinedVector* a = this; + InlinedVector* b = std::addressof(other); + if (size() < other.size()) { + swap(a, b); + } + + const size_type a_size = a->size(); + const size_type b_size = b->size(); + assert(a_size >= b_size); + // `a` is larger. Swap the elements up to the smaller array size. + std::swap_ranges(a->storage_.GetInlinedData(), + a->storage_.GetInlinedData() + b_size, + b->storage_.GetInlinedData()); + + // Move the remaining elements: + // [`b_size`, `a_size`) from `a` -> [`b_size`, `a_size`) from `b` + b->UninitializedCopy(a->storage_.GetInlinedData() + b_size, + a->storage_.GetInlinedData() + a_size, + b->storage_.GetInlinedData() + b_size); + a->Destroy(a->storage_.GetInlinedData() + b_size, + a->storage_.GetInlinedData() + a_size); + + storage_.SwapSizeAndIsAllocated(std::addressof(other.storage_)); + swap(*storage_.GetAllocPtr(), *other.storage_.GetAllocPtr()); + + assert(b->size() == a_size); + assert(a->size() == b_size); + return; + } + + // One is out of line, one is inline. + // We first move the elements from the inlined vector into the + // inlined space in the other vector. We then put the other vector's + // pointer/capacity into the originally inlined vector and swap + // the tags. + InlinedVector* a = this; + InlinedVector* b = std::addressof(other); + if (a->storage_.GetIsAllocated()) { + swap(a, b); + } + + assert(!a->storage_.GetIsAllocated()); + assert(b->storage_.GetIsAllocated()); + + const size_type a_size = a->size(); + const size_type b_size = b->size(); + // In an optimized build, `b_size` would be unused. + static_cast<void>(b_size); + + // Made Local copies of `size()`, these can now be swapped + a->storage_.SwapSizeAndIsAllocated(std::addressof(b->storage_)); + + // Copy out before `b`'s union gets clobbered by `inline_space` + pointer b_data = b->storage_.GetAllocatedData(); + size_type b_capacity = b->storage_.GetAllocatedCapacity(); + + b->UninitializedCopy(a->storage_.GetInlinedData(), + a->storage_.GetInlinedData() + a_size, + b->storage_.GetInlinedData()); + a->Destroy(a->storage_.GetInlinedData(), + a->storage_.GetInlinedData() + a_size); + + a->storage_.SetAllocatedData(b_data, b_capacity); - SwapImpl(other); + if (*a->storage_.GetAllocPtr() != *b->storage_.GetAllocPtr()) { + swap(*a->storage_.GetAllocPtr(), *b->storage_.GetAllocPtr()); + } + + assert(b->size() == a_size); + assert(a->size() == b_size); } private: @@ -900,31 +969,6 @@ class InlinedVector { #endif // !defined(NDEBUG) } - // Enlarge the underlying representation so we can store `size_ + delta` elems - // in allocated space. The size is not changed, and any newly added memory is - // not initialized. - void EnlargeBy(size_type delta) { - const size_type s = size(); - assert(s <= capacity()); - - size_type target = (std::max)(static_cast<size_type>(N), s + delta); - - // Compute new capacity by repeatedly doubling current capacity - // TODO(psrc): Check and avoid overflow? - size_type new_capacity = capacity(); - while (new_capacity < target) { - new_capacity <<= 1; - } - - pointer new_data = - AllocatorTraits::allocate(*storage_.GetAllocPtr(), new_capacity); - - UninitializedCopy(std::make_move_iterator(data()), - std::make_move_iterator(data() + s), new_data); - - ResetAllocation(new_data, new_capacity, s); - } - // Shift all elements from `position` to `end()` by `n` places to the right. // If the vector needs to be enlarged, memory will be allocated. // Returns `iterator`s pointing to the start of the previously-initialized @@ -991,147 +1035,6 @@ class InlinedVector { return std::make_pair(start_used, start_raw); } - template <typename... Args> - reference GrowAndEmplaceBack(Args&&... args) { - assert(size() == capacity()); - const size_type s = size(); - - size_type new_capacity = 2 * capacity(); - pointer new_data = - AllocatorTraits::allocate(*storage_.GetAllocPtr(), new_capacity); - - reference new_element = - Construct(new_data + s, std::forward<Args>(args)...); - UninitializedCopy(std::make_move_iterator(data()), - std::make_move_iterator(data() + s), new_data); - - ResetAllocation(new_data, new_capacity, s + 1); - - return new_element; - } - - iterator InsertWithCount(const_iterator position, size_type n, - const_reference v) { - assert(position >= begin() && position <= end()); - if (ABSL_PREDICT_FALSE(n == 0)) return const_cast<iterator>(position); - - value_type copy = v; - std::pair<iterator, iterator> it_pair = ShiftRight(position, n); - std::fill(it_pair.first, it_pair.second, copy); - UninitializedFill(it_pair.second, it_pair.first + n, copy); - - return it_pair.first; - } - - template <typename ForwardIt> - iterator InsertWithForwardRange(const_iterator position, ForwardIt first, - ForwardIt last) { - static_assert(absl::inlined_vector_internal::IsAtLeastForwardIterator< - ForwardIt>::value, - ""); - assert(position >= begin() && position <= end()); - - if (ABSL_PREDICT_FALSE(first == last)) - return const_cast<iterator>(position); - - auto n = std::distance(first, last); - std::pair<iterator, iterator> it_pair = ShiftRight(position, n); - size_type used_spots = it_pair.second - it_pair.first; - auto open_spot = std::next(first, used_spots); - std::copy(first, open_spot, it_pair.first); - UninitializedCopy(open_spot, last, it_pair.second); - return it_pair.first; - } - - void SwapImpl(InlinedVector& other) { - using std::swap; - - bool is_allocated = storage_.GetIsAllocated(); - bool other_is_allocated = other.storage_.GetIsAllocated(); - - if (is_allocated && other_is_allocated) { - // Both out of line, so just swap the tag, allocation, and allocator. - storage_.SwapSizeAndIsAllocated(std::addressof(other.storage_)); - storage_.SwapAllocatedSizeAndCapacity(std::addressof(other.storage_)); - swap(*storage_.GetAllocPtr(), *other.storage_.GetAllocPtr()); - - return; - } - - if (!is_allocated && !other_is_allocated) { - // Both inlined: swap up to smaller size, then move remaining elements. - InlinedVector* a = this; - InlinedVector* b = std::addressof(other); - if (size() < other.size()) { - swap(a, b); - } - - const size_type a_size = a->size(); - const size_type b_size = b->size(); - assert(a_size >= b_size); - // `a` is larger. Swap the elements up to the smaller array size. - std::swap_ranges(a->storage_.GetInlinedData(), - a->storage_.GetInlinedData() + b_size, - b->storage_.GetInlinedData()); - - // Move the remaining elements: - // [`b_size`, `a_size`) from `a` -> [`b_size`, `a_size`) from `b` - b->UninitializedCopy(a->storage_.GetInlinedData() + b_size, - a->storage_.GetInlinedData() + a_size, - b->storage_.GetInlinedData() + b_size); - a->Destroy(a->storage_.GetInlinedData() + b_size, - a->storage_.GetInlinedData() + a_size); - - storage_.SwapSizeAndIsAllocated(std::addressof(other.storage_)); - swap(*storage_.GetAllocPtr(), *other.storage_.GetAllocPtr()); - - assert(b->size() == a_size); - assert(a->size() == b_size); - return; - } - - // One is out of line, one is inline. - // We first move the elements from the inlined vector into the - // inlined space in the other vector. We then put the other vector's - // pointer/capacity into the originally inlined vector and swap - // the tags. - InlinedVector* a = this; - InlinedVector* b = std::addressof(other); - if (a->storage_.GetIsAllocated()) { - swap(a, b); - } - - assert(!a->storage_.GetIsAllocated()); - assert(b->storage_.GetIsAllocated()); - - const size_type a_size = a->size(); - const size_type b_size = b->size(); - // In an optimized build, `b_size` would be unused. - static_cast<void>(b_size); - - // Made Local copies of `size()`, these can now be swapped - a->storage_.SwapSizeAndIsAllocated(std::addressof(b->storage_)); - - // Copy out before `b`'s union gets clobbered by `inline_space` - pointer b_data = b->storage_.GetAllocatedData(); - size_type b_capacity = b->storage_.GetAllocatedCapacity(); - - b->UninitializedCopy(a->storage_.GetInlinedData(), - a->storage_.GetInlinedData() + a_size, - b->storage_.GetInlinedData()); - a->Destroy(a->storage_.GetInlinedData(), - a->storage_.GetInlinedData() + a_size); - - a->storage_.SetAllocatedData(b_data, b_capacity); - - if (*a->storage_.GetAllocPtr() != *b->storage_.GetAllocPtr()) { - swap(*a->storage_.GetAllocPtr(), *b->storage_.GetAllocPtr()); - } - - assert(b->size() == a_size); - assert(a->size() == b_size); - } - Storage storage_; }; diff --git a/absl/container/inlined_vector_benchmark.cc b/absl/container/inlined_vector_benchmark.cc index df4d3ce5a2d1..b99bbd62544f 100644 --- a/absl/container/inlined_vector_benchmark.cc +++ b/absl/container/inlined_vector_benchmark.cc @@ -599,6 +599,146 @@ void BM_AssignFromMove(benchmark::State& state) { ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_AssignFromMove, TrivialType); ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_AssignFromMove, NontrivialType); +template <typename T, size_t FromSize, size_t ToSize> +void BM_ResizeSize(benchmark::State& state) { + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [](InlVec<T>* vec, size_t) { vec->resize(ToSize); }); +} +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_ResizeSize, TrivialType); +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_ResizeSize, NontrivialType); + +template <typename T, size_t FromSize, size_t ToSize> +void BM_ResizeSizeRef(benchmark::State& state) { + auto t = T(); + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [&](InlVec<T>* vec, size_t) { + benchmark::DoNotOptimize(t); + vec->resize(ToSize, t); + }); +} +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_ResizeSizeRef, TrivialType); +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_ResizeSizeRef, NontrivialType); + +template <typename T, size_t FromSize, size_t ToSize> +void BM_InsertSizeRef(benchmark::State& state) { + auto t = T(); + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [&](InlVec<T>* vec, size_t) { + benchmark::DoNotOptimize(t); + auto* pos = vec->data() + (vec->size() / 2); + vec->insert(pos, t); + }); +} +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_InsertSizeRef, TrivialType); +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_InsertSizeRef, NontrivialType); + +template <typename T, size_t FromSize, size_t ToSize> +void BM_InsertRange(benchmark::State& state) { + InlVec<T> other_vec(ToSize); + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [&](InlVec<T>* vec, size_t) { + benchmark::DoNotOptimize(other_vec); + auto* pos = vec->data() + (vec->size() / 2); + vec->insert(pos, other_vec.begin(), other_vec.end()); + }); +} +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_InsertRange, TrivialType); +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_InsertRange, NontrivialType); + +template <typename T, size_t FromSize> +void BM_EmplaceBack(benchmark::State& state) { + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [](InlVec<T>* vec, size_t) { vec->emplace_back(); }); +} +ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_EmplaceBack, TrivialType); +ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_EmplaceBack, NontrivialType); + +template <typename T, size_t FromSize> +void BM_PopBack(benchmark::State& state) { + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [](InlVec<T>* vec, size_t) { vec->pop_back(); }); +} +ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_PopBack, TrivialType); +ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_PopBack, NontrivialType); + +template <typename T, size_t FromSize> +void BM_EraseOne(benchmark::State& state) { + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [](InlVec<T>* vec, size_t) { + auto* pos = vec->data() + (vec->size() / 2); + vec->erase(pos); + }); +} +ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_EraseOne, TrivialType); +ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_EraseOne, NontrivialType); + +template <typename T, size_t FromSize> +void BM_EraseRange(benchmark::State& state) { + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [](InlVec<T>* vec, size_t) { + auto* pos = vec->data() + (vec->size() / 2); + vec->erase(pos, pos + 1); + }); +} +ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_EraseRange, TrivialType); +ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_EraseRange, NontrivialType); + template <typename T, size_t FromSize> void BM_Clear(benchmark::State& state) { BatchedBenchmark<T>( @@ -609,4 +749,56 @@ void BM_Clear(benchmark::State& state) { ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_Clear, TrivialType); ABSL_INTERNAL_BENCHMARK_ONE_SIZE(BM_Clear, NontrivialType); +template <typename T, size_t FromSize, size_t ToCapacity> +void BM_Reserve(benchmark::State& state) { + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(FromSize); + }, + /* test_vec = */ + [](InlVec<T>* vec, size_t) { vec->reserve(ToCapacity); }); +} +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_Reserve, TrivialType); +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_Reserve, NontrivialType); + +template <typename T, size_t FromCapacity, size_t ToCapacity> +void BM_ShrinkToFit(benchmark::State& state) { + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [](InlVec<T>* vec, size_t) { + vec->clear(); + vec->resize(ToCapacity); + vec->reserve(FromCapacity); + }, + /* test_vec = */ [](InlVec<T>* vec, size_t) { vec->shrink_to_fit(); }); +} +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_ShrinkToFit, TrivialType); +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_ShrinkToFit, NontrivialType); + +template <typename T, size_t FromSize, size_t ToSize> +void BM_Swap(benchmark::State& state) { + using VecT = InlVec<T>; + std::array<VecT, kBatchSize> vector_batch{}; + BatchedBenchmark<T>( + state, + /* prepare_vec = */ + [&](InlVec<T>* vec, size_t i) { + vector_batch[i].clear(); + vector_batch[i].resize(ToSize); + vec->resize(FromSize); + }, + /* test_vec = */ + [&](InlVec<T>* vec, size_t i) { + using std::swap; + benchmark::DoNotOptimize(vector_batch[i]); + swap(*vec, vector_batch[i]); + }); +} +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_Swap, TrivialType); +ABSL_INTERNAL_BENCHMARK_TWO_SIZE(BM_Swap, NontrivialType); + } // namespace diff --git a/absl/container/inlined_vector_exception_safety_test.cc b/absl/container/inlined_vector_exception_safety_test.cc index 0a9649250302..e7c47127df9f 100644 --- a/absl/container/inlined_vector_exception_safety_test.cc +++ b/absl/container/inlined_vector_exception_safety_test.cc @@ -12,7 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <array> +#include <initializer_list> +#include <iterator> #include <memory> +#include <utility> #include "gtest/gtest.h" #include "absl/base/internal/exception_safety_testing.h" @@ -81,6 +85,24 @@ using OneSizeTestParams = TestParams<ThrowAllocMovableThrowerVec, kLargeSize>, TestParams<ThrowAllocMovableThrowerVec, kSmallSize>>; +using TwoSizeTestParams = ::testing::Types< + TestParams<ThrowerVec, kLargeSize, kLargeSize>, + TestParams<ThrowerVec, kLargeSize, kSmallSize>, + TestParams<ThrowerVec, kSmallSize, kLargeSize>, + TestParams<ThrowerVec, kSmallSize, kSmallSize>, + TestParams<MovableThrowerVec, kLargeSize, kLargeSize>, + TestParams<MovableThrowerVec, kLargeSize, kSmallSize>, + TestParams<MovableThrowerVec, kSmallSize, kLargeSize>, + TestParams<MovableThrowerVec, kSmallSize, kSmallSize>, + TestParams<ThrowAllocThrowerVec, kLargeSize, kLargeSize>, + TestParams<ThrowAllocThrowerVec, kLargeSize, kSmallSize>, + TestParams<ThrowAllocThrowerVec, kSmallSize, kLargeSize>, + TestParams<ThrowAllocThrowerVec, kSmallSize, kSmallSize>, + TestParams<ThrowAllocMovableThrowerVec, kLargeSize, kLargeSize>, + TestParams<ThrowAllocMovableThrowerVec, kLargeSize, kSmallSize>, + TestParams<ThrowAllocMovableThrowerVec, kSmallSize, kLargeSize>, + TestParams<ThrowAllocMovableThrowerVec, kSmallSize, kSmallSize>>; + template <typename> struct NoSizeTest : ::testing::Test {}; TYPED_TEST_SUITE(NoSizeTest, NoSizeTestParams); @@ -89,6 +111,25 @@ template <typename> struct OneSizeTest : ::testing::Test {}; TYPED_TEST_SUITE(OneSizeTest, OneSizeTestParams); +template <typename> +struct TwoSizeTest : ::testing::Test {}; +TYPED_TEST_SUITE(TwoSizeTest, TwoSizeTestParams); + +template <typename VecT> +bool InlinedVectorInvariants(VecT* vec) { + if (*vec != *vec) return false; + if (vec->size() > vec->capacity()) return false; + if (vec->size() > vec->max_size()) return false; + if (vec->capacity() > vec->max_size()) return false; + if (vec->data() != std::addressof(vec->at(0))) return false; + if (vec->data() != vec->begin()) return false; + if (*vec->data() != *vec->begin()) return false; + if (vec->begin() > vec->end()) return false; + if ((vec->end() - vec->begin()) != vec->size()) return false; + if (std::distance(vec->begin(), vec->end()) != vec->size()) return false; + return true; +} + // Function that always returns false is correct, but refactoring is required // for clarity. It's needed to express that, as a contract, certain operations // should not throw at all. Execution of this function means an exception was @@ -179,6 +220,45 @@ TYPED_TEST(OneSizeTest, MoveConstructor) { } } +TYPED_TEST(TwoSizeTest, Assign) { + using VecT = typename TypeParam::VecT; + using value_type = typename VecT::value_type; + constexpr static auto from_size = TypeParam::GetSizeAt(0); + constexpr static auto to_size = TypeParam::GetSizeAt(1); + + auto tester = testing::MakeExceptionSafetyTester() + .WithInitialValue(VecT{from_size}) + .WithContracts(InlinedVectorInvariants<VecT>); + + EXPECT_TRUE(tester.Test([](VecT* vec) { + *vec = ABSL_INTERNAL_MAKE_INIT_LIST(value_type, to_size); + })); + + EXPECT_TRUE(tester.Test([](VecT* vec) { + VecT other_vec{to_size}; + *vec = other_vec; + })); + + EXPECT_TRUE(tester.Test([](VecT* vec) { + VecT other_vec{to_size}; + *vec = std::move(other_vec); + })); + + EXPECT_TRUE(tester.Test([](VecT* vec) { + value_type val{}; + vec->assign(to_size, val); + })); + + EXPECT_TRUE(tester.Test([](VecT* vec) { + vec->assign(ABSL_INTERNAL_MAKE_INIT_LIST(value_type, to_size)); + })); + + EXPECT_TRUE(tester.Test([](VecT* vec) { + std::array<value_type, to_size> arr{}; + vec->assign(arr.begin(), arr.end()); + })); +} + TYPED_TEST(OneSizeTest, PopBack) { using VecT = typename TypeParam::VecT; constexpr static auto size = TypeParam::GetSizeAt(0); @@ -205,4 +285,17 @@ TYPED_TEST(OneSizeTest, Clear) { })); } +TYPED_TEST(OneSizeTest, ShrinkToFit) { + using VecT = typename TypeParam::VecT; + constexpr static auto size = TypeParam::GetSizeAt(0); + + auto tester = testing::MakeExceptionSafetyTester() + .WithInitialValue(VecT{size}) + .WithContracts(InlinedVectorInvariants<VecT>); + + EXPECT_TRUE(tester.Test([](VecT* vec) { + vec->shrink_to_fit(); // + })); +} + } // namespace diff --git a/absl/container/inlined_vector_test.cc b/absl/container/inlined_vector_test.cc index 6037001a0473..60fe89b28aee 100644 --- a/absl/container/inlined_vector_test.cc +++ b/absl/container/inlined_vector_test.cc @@ -190,6 +190,12 @@ TEST(IntVec, SimpleOps) { } } +TEST(IntVec, PopBackNoOverflow) { + IntVec v = {1}; + v.pop_back(); + EXPECT_EQ(v.size(), 0); +} + TEST(IntVec, AtThrows) { IntVec v = {1, 2, 3}; EXPECT_EQ(v.at(2), 3); diff --git a/absl/container/internal/compressed_tuple.h b/absl/container/internal/compressed_tuple.h index bb3471f5d747..1713ad6862c5 100644 --- a/absl/container/internal/compressed_tuple.h +++ b/absl/container/internal/compressed_tuple.h @@ -32,6 +32,7 @@ #ifndef ABSL_CONTAINER_INTERNAL_COMPRESSED_TUPLE_H_ #define ABSL_CONTAINER_INTERNAL_COMPRESSED_TUPLE_H_ +#include <initializer_list> #include <tuple> #include <type_traits> #include <utility> @@ -75,17 +76,30 @@ constexpr bool IsFinal() { #endif } +// We can't use EBCO on other CompressedTuples because that would mean that we +// derive from multiple Storage<> instantiations with the same I parameter, +// and potentially from multiple identical Storage<> instantiations. So anytime +// we use type inheritance rather than encapsulation, we mark +// CompressedTupleImpl, to make this easy to detect. +struct uses_inheritance {}; + template <typename T> constexpr bool ShouldUseBase() { - return std::is_class<T>::value && std::is_empty<T>::value && !IsFinal<T>(); + return std::is_class<T>::value && std::is_empty<T>::value && !IsFinal<T>() && + !std::is_base_of<uses_inheritance, T>::value; } // The storage class provides two specializations: // - For empty classes, it stores T as a base class. // - For everything else, it stores T as a member. -template <typename D, size_t I, bool = ShouldUseBase<ElemT<D, I>>()> +template <typename T, size_t I, +#if defined(_MSC_VER) + bool UseBase = + ShouldUseBase<typename std::enable_if<true, T>::type>()> +#else + bool UseBase = ShouldUseBase<T>()> +#endif struct Storage { - using T = ElemT<D, I>; T value; constexpr Storage() = default; explicit constexpr Storage(T&& v) : value(absl::forward<T>(v)) {} @@ -95,10 +109,8 @@ struct Storage { T&& get() && { return std::move(*this).value; } }; -template <typename D, size_t I> -struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC Storage<D, I, true> - : ElemT<D, I> { - using T = internal_compressed_tuple::ElemT<D, I>; +template <typename T, size_t I> +struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC Storage<T, I, true> : T { constexpr Storage() = default; explicit constexpr Storage(T&& v) : T(absl::forward<T>(v)) {} constexpr const T& get() const& { return *this; } @@ -107,29 +119,54 @@ struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC Storage<D, I, true> T&& get() && { return std::move(*this); } }; -template <typename D, typename I> +template <typename D, typename I, bool ShouldAnyUseBase> struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTupleImpl; -template <typename... Ts, size_t... I> -struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC - CompressedTupleImpl<CompressedTuple<Ts...>, absl::index_sequence<I...>> +template <typename... Ts, size_t... I, bool ShouldAnyUseBase> +struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTupleImpl< + CompressedTuple<Ts...>, absl::index_sequence<I...>, ShouldAnyUseBase> // We use the dummy identity function through std::integral_constant to // convince MSVC of accepting and expanding I in that context. Without it // you would get: // error C3548: 'I': parameter pack cannot be used in this context - : Storage<CompressedTuple<Ts...>, - std::integral_constant<size_t, I>::value>... { + : uses_inheritance, + Storage<Ts, std::integral_constant<size_t, I>::value>... { + constexpr CompressedTupleImpl() = default; + explicit constexpr CompressedTupleImpl(Ts&&... args) + : Storage<Ts, I>(absl::forward<Ts>(args))... {} + friend CompressedTuple<Ts...>; +}; + +template <typename... Ts, size_t... I> +struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTupleImpl< + CompressedTuple<Ts...>, absl::index_sequence<I...>, false> + // We use the dummy identity function as above... + : Storage<Ts, std::integral_constant<size_t, I>::value, false>... { constexpr CompressedTupleImpl() = default; explicit constexpr CompressedTupleImpl(Ts&&... args) - : Storage<CompressedTuple<Ts...>, I>(absl::forward<Ts>(args))... {} + : Storage<Ts, I, false>(absl::forward<Ts>(args))... {} + friend CompressedTuple<Ts...>; }; +std::false_type Or(std::initializer_list<std::false_type>); +std::true_type Or(std::initializer_list<bool>); + +// MSVC requires this to be done separately rather than within the declaration +// of CompressedTuple below. +template <typename... Ts> +constexpr bool ShouldAnyUseBase() { + return decltype( + Or({std::integral_constant<bool, ShouldUseBase<Ts>()>()...})){}; +} + } // namespace internal_compressed_tuple // Helper class to perform the Empty Base Class Optimization. // Ts can contain classes and non-classes, empty or not. For the ones that // are empty classes, we perform the CompressedTuple. If all types in Ts are -// empty classes, then CompressedTuple<Ts...> is itself an empty class. +// empty classes, then CompressedTuple<Ts...> is itself an empty class. (This +// does not apply when one or more of those empty classes is itself an empty +// CompressedTuple.) // // To access the members, use member .get<N>() function. // @@ -145,7 +182,8 @@ struct ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC template <typename... Ts> class ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTuple : private internal_compressed_tuple::CompressedTupleImpl< - CompressedTuple<Ts...>, absl::index_sequence_for<Ts...>> { + CompressedTuple<Ts...>, absl::index_sequence_for<Ts...>, + internal_compressed_tuple::ShouldAnyUseBase<Ts...>()> { private: template <int I> using ElemT = internal_compressed_tuple::ElemT<CompressedTuple, I>; @@ -157,24 +195,24 @@ class ABSL_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTuple template <int I> ElemT<I>& get() & { - return internal_compressed_tuple::Storage<CompressedTuple, I>::get(); + return internal_compressed_tuple::Storage<ElemT<I>, I>::get(); } template <int I> constexpr const ElemT<I>& get() const& { - return internal_compressed_tuple::Storage<CompressedTuple, I>::get(); + return internal_compressed_tuple::Storage<ElemT<I>, I>::get(); } template <int I> ElemT<I>&& get() && { return std::move(*this) - .internal_compressed_tuple::template Storage<CompressedTuple, I>::get(); + .internal_compressed_tuple::template Storage<ElemT<I>, I>::get(); } template <int I> constexpr const ElemT<I>&& get() const&& { return absl::move(*this) - .internal_compressed_tuple::template Storage<CompressedTuple, I>::get(); + .internal_compressed_tuple::template Storage<ElemT<I>, I>::get(); } }; diff --git a/absl/container/internal/compressed_tuple_test.cc b/absl/container/internal/compressed_tuple_test.cc index 28e7741c9a45..3b0ec4555abf 100644 --- a/absl/container/internal/compressed_tuple_test.cc +++ b/absl/container/internal/compressed_tuple_test.cc @@ -22,10 +22,8 @@ #include "absl/memory/memory.h" #include "absl/utility/utility.h" -namespace absl { -namespace container_internal { -namespace { - +// These are declared at global scope purely so that error messages +// are smaller and easier to understand. enum class CallType { kConstRef, kConstMove }; template <int> @@ -45,6 +43,10 @@ struct TwoValues { U value2; }; +namespace absl { +namespace container_internal { +namespace { + TEST(CompressedTupleTest, Sizeof) { EXPECT_EQ(sizeof(int), sizeof(CompressedTuple<int>)); EXPECT_EQ(sizeof(int), sizeof(CompressedTuple<int, Empty<0>>)); @@ -120,9 +122,14 @@ TEST(CompressedTupleTest, Nested) { EXPECT_EQ(4 * sizeof(char), sizeof(CompressedTuple<CompressedTuple<char, char>, CompressedTuple<char, char>>)); - EXPECT_TRUE( - (std::is_empty<CompressedTuple<CompressedTuple<Empty<0>>, - CompressedTuple<Empty<1>>>>::value)); + EXPECT_TRUE((std::is_empty<CompressedTuple<Empty<0>, Empty<1>>>::value)); + + // Make sure everything still works when things are nested. + struct CT_Empty : CompressedTuple<Empty<0>> {}; + CompressedTuple<Empty<0>, CT_Empty> nested_empty; + auto contained = nested_empty.get<0>(); + auto nested = nested_empty.get<1>().get<0>(); + EXPECT_TRUE((std::is_same<decltype(contained), decltype(nested)>::value)); } TEST(CompressedTupleTest, Reference) { diff --git a/absl/container/internal/inlined_vector.h b/absl/container/internal/inlined_vector.h index 92c21ab96571..f117ee0c75de 100644 --- a/absl/container/internal/inlined_vector.h +++ b/absl/container/internal/inlined_vector.h @@ -25,6 +25,7 @@ #include "absl/container/internal/compressed_tuple.h" #include "absl/memory/memory.h" #include "absl/meta/type_traits.h" +#include "absl/types/span.h" namespace absl { namespace inlined_vector_internal { @@ -78,6 +79,14 @@ void ConstructElements(AllocatorType* alloc_ptr, ValueType* construct_first, } } +template <typename ValueType, typename ValueAdapter, typename SizeType> +void AssignElements(ValueType* assign_first, ValueAdapter* values_ptr, + SizeType assign_size) { + for (SizeType i = 0; i < assign_size; ++i) { + values_ptr->AssignNext(assign_first + i); + } +} + template <typename AllocatorType> struct StorageView { using pointer = typename AllocatorType::pointer; @@ -101,6 +110,11 @@ class IteratorValueAdapter { ++it_; } + void AssignNext(pointer assign_at) { + *assign_at = *it_; + ++it_; + } + private: Iterator it_; }; @@ -119,6 +133,8 @@ class CopyValueAdapter { AllocatorTraits::construct(*alloc_ptr, construct_at, *ptr_); } + void AssignNext(pointer assign_at) { *assign_at = *ptr_; } + private: const_pointer ptr_; }; @@ -135,6 +151,44 @@ class DefaultValueAdapter { void ConstructNext(AllocatorType* alloc_ptr, pointer construct_at) { AllocatorTraits::construct(*alloc_ptr, construct_at); } + + void AssignNext(pointer assign_at) { *assign_at = value_type(); } +}; + +template <typename AllocatorType> +class AllocationTransaction { + using value_type = typename AllocatorType::value_type; + using pointer = typename AllocatorType::pointer; + using size_type = typename AllocatorType::size_type; + using AllocatorTraits = absl::allocator_traits<AllocatorType>; + + public: + explicit AllocationTransaction(AllocatorType* alloc_ptr) + : alloc_data_(*alloc_ptr, nullptr) {} + + AllocationTransaction(const AllocationTransaction&) = delete; + void operator=(const AllocationTransaction&) = delete; + + AllocatorType& GetAllocator() { return alloc_data_.template get<0>(); } + pointer& GetData() { return alloc_data_.template get<1>(); } + size_type& GetCapacity() { return capacity_; } + + bool DidAllocate() { return GetData() != nullptr; } + pointer Allocate(size_type capacity) { + GetData() = AllocatorTraits::allocate(GetAllocator(), capacity); + GetCapacity() = capacity; + return GetData(); + } + + ~AllocationTransaction() { + if (DidAllocate()) { + AllocatorTraits::deallocate(GetAllocator(), GetData(), GetCapacity()); + } + } + + private: + container_internal::CompressedTuple<AllocatorType, pointer> alloc_data_; + size_type capacity_ = 0; }; template <typename T, size_t N, typename A> @@ -167,6 +221,9 @@ class Storage { using DefaultValueAdapter = inlined_vector_internal::DefaultValueAdapter<allocator_type>; + using AllocationTransaction = + inlined_vector_internal::AllocationTransaction<allocator_type>; + Storage() : metadata_() {} explicit Storage(const allocator_type& alloc) @@ -215,19 +272,48 @@ class Storage { void SetIsAllocated() { GetSizeAndIsAllocated() |= 1; } + void UnsetIsAllocated() { + SetIsAllocated(); + GetSizeAndIsAllocated() -= 1; + } + void SetAllocatedSize(size_type size) { GetSizeAndIsAllocated() = (size << 1) | static_cast<size_type>(1); } void SetInlinedSize(size_type size) { GetSizeAndIsAllocated() = size << 1; } + void SetSize(size_type size) { + GetSizeAndIsAllocated() = + (size << 1) | static_cast<size_type>(GetIsAllocated()); + } + void AddSize(size_type count) { GetSizeAndIsAllocated() += count << 1; } + void SubtractSize(size_type count) { + assert(count <= GetSize()); + GetSizeAndIsAllocated() -= count << 1; + } + void SetAllocatedData(pointer data, size_type capacity) { data_.allocated.allocated_data = data; data_.allocated.allocated_capacity = capacity; } + void DeallocateIfAllocated() { + if (GetIsAllocated()) { + AllocatorTraits::deallocate(*GetAllocPtr(), GetAllocatedData(), + GetAllocatedCapacity()); + } + } + + void AcquireAllocation(AllocationTransaction* allocation_tx_ptr) { + SetAllocatedData(allocation_tx_ptr->GetData(), + allocation_tx_ptr->GetCapacity()); + allocation_tx_ptr->GetData() = nullptr; + allocation_tx_ptr->GetCapacity() = 0; + } + void SwapSizeAndIsAllocated(Storage* other) { using std::swap; swap(GetSizeAndIsAllocated(), other->GetSizeAndIsAllocated()); @@ -238,11 +324,11 @@ class Storage { swap(data_.allocated, other->data_.allocated); } - void MemcpyContents(const Storage& other) { - assert(IsMemcpyOk::value); + void MemcpyFrom(const Storage& other_storage) { + assert(IsMemcpyOk::value || other_storage.GetIsAllocated()); - GetSizeAndIsAllocated() = other.GetSizeAndIsAllocated(); - data_ = other.data_; + GetSizeAndIsAllocated() = other_storage.GetSizeAndIsAllocated(); + data_ = other_storage.data_; } void DestroyAndDeallocate(); @@ -250,6 +336,11 @@ class Storage { template <typename ValueAdapter> void Initialize(ValueAdapter values, size_type new_size); + template <typename ValueAdapter> + void Assign(ValueAdapter values, size_type new_size); + + void ShrinkToFit(); + private: size_type& GetSizeAndIsAllocated() { return metadata_.template get<1>(); } @@ -282,15 +373,10 @@ class Storage { template <typename T, size_t N, typename A> void Storage<T, N, A>::DestroyAndDeallocate() { - StorageView storage_view = MakeStorageView(); - - inlined_vector_internal::DestroyElements(GetAllocPtr(), storage_view.data, - storage_view.size); - - if (GetIsAllocated()) { - AllocatorTraits::deallocate(*GetAllocPtr(), storage_view.data, - storage_view.capacity); - } + inlined_vector_internal::DestroyElements( + GetAllocPtr(), (GetIsAllocated() ? GetAllocatedData() : GetInlinedData()), + GetSize()); + DeallocateIfAllocated(); } template <typename T, size_t N, typename A> @@ -323,6 +409,91 @@ auto Storage<T, N, A>::Initialize(ValueAdapter values, size_type new_size) AddSize(new_size); } +template <typename T, size_t N, typename A> +template <typename ValueAdapter> +auto Storage<T, N, A>::Assign(ValueAdapter values, size_type new_size) -> void { + StorageView storage_view = MakeStorageView(); + + AllocationTransaction allocation_tx(GetAllocPtr()); + + absl::Span<value_type> assign_loop; + absl::Span<value_type> construct_loop; + absl::Span<value_type> destroy_loop; + + if (new_size > storage_view.capacity) { + construct_loop = {allocation_tx.Allocate(new_size), new_size}; + destroy_loop = {storage_view.data, storage_view.size}; + } else if (new_size > storage_view.size) { + assign_loop = {storage_view.data, storage_view.size}; + construct_loop = {storage_view.data + storage_view.size, + new_size - storage_view.size}; + } else { + assign_loop = {storage_view.data, new_size}; + destroy_loop = {storage_view.data + new_size, storage_view.size - new_size}; + } + + inlined_vector_internal::AssignElements(assign_loop.data(), &values, + assign_loop.size()); + inlined_vector_internal::ConstructElements( + GetAllocPtr(), construct_loop.data(), &values, construct_loop.size()); + inlined_vector_internal::DestroyElements(GetAllocPtr(), destroy_loop.data(), + destroy_loop.size()); + + if (allocation_tx.DidAllocate()) { + DeallocateIfAllocated(); + AcquireAllocation(&allocation_tx); + SetIsAllocated(); + } + + SetSize(new_size); +} + +template <typename T, size_t N, typename A> +auto Storage<T, N, A>::ShrinkToFit() -> void { + // May only be called on allocated instances! + assert(GetIsAllocated()); + + StorageView storage_view = {GetAllocatedData(), GetSize(), + GetAllocatedCapacity()}; + + AllocationTransaction allocation_tx(GetAllocPtr()); + + IteratorValueAdapter<MoveIterator> move_values( + MoveIterator(storage_view.data)); + + pointer construct_data; + + if (storage_view.size <= static_cast<size_type>(N)) { + construct_data = GetInlinedData(); + } else if (storage_view.size < GetAllocatedCapacity()) { + construct_data = allocation_tx.Allocate(storage_view.size); + } else { + return; + } + + ABSL_INTERNAL_TRY { + inlined_vector_internal::ConstructElements(GetAllocPtr(), construct_data, + &move_values, storage_view.size); + } + ABSL_INTERNAL_CATCH_ANY { + // Writing to inlined data will trample on the existing state, thus it needs + // to be restored when a construction fails. + SetAllocatedData(storage_view.data, storage_view.capacity); + ABSL_INTERNAL_RETHROW; + } + + inlined_vector_internal::DestroyElements(GetAllocPtr(), storage_view.data, + storage_view.size); + AllocatorTraits::deallocate(*GetAllocPtr(), storage_view.data, + storage_view.capacity); + + if (allocation_tx.DidAllocate()) { + AcquireAllocation(&allocation_tx); + } else { + UnsetIsAllocated(); + } +} + } // namespace inlined_vector_internal } // namespace absl diff --git a/absl/copts/GENERATED_AbseilCopts.cmake b/absl/copts/GENERATED_AbseilCopts.cmake index c4948d42d495..88400e98b505 100644 --- a/absl/copts/GENERATED_AbseilCopts.cmake +++ b/absl/copts/GENERATED_AbseilCopts.cmake @@ -218,3 +218,21 @@ list(APPEND ABSL_MSVC_TEST_FLAGS "/wd4996" "/DNOMINMAX" ) + +list(APPEND ABSL_RANDOM_HWAES_ARM32_FLAGS + "-mfpu=neon" +) + +list(APPEND ABSL_RANDOM_HWAES_ARM64_FLAGS + "-march=armv8-a+crypto" +) + +list(APPEND ABSL_RANDOM_HWAES_MSVC_X64_FLAGS + "/O2" + "/Ob2" +) + +list(APPEND ABSL_RANDOM_HWAES_X64_FLAGS + "-maes" + "-msse4.1" +) diff --git a/absl/copts/GENERATED_copts.bzl b/absl/copts/GENERATED_copts.bzl index 422b3a9b1c3a..d7edc93673da 100644 --- a/absl/copts/GENERATED_copts.bzl +++ b/absl/copts/GENERATED_copts.bzl @@ -219,3 +219,21 @@ ABSL_MSVC_TEST_FLAGS = [ "/wd4996", "/DNOMINMAX", ] + +ABSL_RANDOM_HWAES_ARM32_FLAGS = [ + "-mfpu=neon", +] + +ABSL_RANDOM_HWAES_ARM64_FLAGS = [ + "-march=armv8-a+crypto", +] + +ABSL_RANDOM_HWAES_MSVC_X64_FLAGS = [ + "/O2", + "/Ob2", +] + +ABSL_RANDOM_HWAES_X64_FLAGS = [ + "-maes", + "-msse4.1", +] diff --git a/absl/copts/configure_copts.bzl b/absl/copts/configure_copts.bzl index 0015931762a4..8c4efe77c21a 100644 --- a/absl/copts/configure_copts.bzl +++ b/absl/copts/configure_copts.bzl @@ -16,6 +16,10 @@ load( "ABSL_MSVC_FLAGS", "ABSL_MSVC_LINKOPTS", "ABSL_MSVC_TEST_FLAGS", + "ABSL_RANDOM_HWAES_ARM32_FLAGS", + "ABSL_RANDOM_HWAES_ARM64_FLAGS", + "ABSL_RANDOM_HWAES_MSVC_X64_FLAGS", + "ABSL_RANDOM_HWAES_X64_FLAGS", ) ABSL_DEFAULT_COPTS = select({ @@ -46,3 +50,40 @@ ABSL_DEFAULT_LINKOPTS = select({ "//absl:windows": ABSL_MSVC_LINKOPTS, "//conditions:default": [], }) + +# ABSL_RANDOM_RANDEN_COPTS blaze copts flags which are required by each +# environment to build an accelerated RandenHwAes library. +ABSL_RANDOM_RANDEN_COPTS = select({ + # APPLE + ":cpu_darwin_x86_64": ABSL_RANDOM_HWAES_X64_FLAGS, + ":cpu_darwin": ABSL_RANDOM_HWAES_X64_FLAGS, + ":cpu_x64_windows_msvc": ABSL_RANDOM_HWAES_MSVC_X64_FLAGS, + ":cpu_x64_windows": ABSL_RANDOM_HWAES_MSVC_X64_FLAGS, + ":cpu_haswell": ABSL_RANDOM_HWAES_X64_FLAGS, + ":cpu_ppc": ["-mcrypto"], + + # Supported by default or unsupported. + "//conditions:default": [], +}) + +# absl_random_randen_copts_init: +# Initialize the config targets based on cpu, os, etc. used to select +# the required values for ABSL_RANDOM_RANDEN_COPTS +def absl_random_randen_copts_init(): + """Initialize the config_settings used by ABSL_RANDOM_RANDEN_COPTS.""" + + # CPU configs. + # These configs have consistent flags to enable HWAES intsructions. + cpu_configs = [ + "ppc", + "haswell", + "darwin_x86_64", + "darwin", + "x64_windows_msvc", + "x64_windows", + ] + for cpu in cpu_configs: + native.config_setting( + name = "cpu_%s" % cpu, + values = {"cpu": cpu}, + ) diff --git a/absl/copts/copts.py b/absl/copts/copts.py index 5bede34cb8e7..d850bb4f0553 100644 --- a/absl/copts/copts.py +++ b/absl/copts/copts.py @@ -199,4 +199,18 @@ COPT_VARS = { # Object file doesn't export any previously undefined symbols "-ignore:4221", ], + # "HWAES" is an abbreviation for "hardware AES" (AES - Advanced Encryption + # Standard). These flags are used for detecting whether or not the target + # architecture has hardware support for AES instructions which can be used + # to improve performance of some random bit generators. + "ABSL_RANDOM_HWAES_ARM64_FLAGS": ["-march=armv8-a+crypto"], + "ABSL_RANDOM_HWAES_ARM32_FLAGS": ["-mfpu=neon"], + "ABSL_RANDOM_HWAES_X64_FLAGS": [ + "-maes", + "-msse4.1", + ], + "ABSL_RANDOM_HWAES_MSVC_X64_FLAGS": [ + "/O2", # Maximize speed + "/Ob2", # Aggressive inlining + ], } diff --git a/absl/flags/flag.h b/absl/flags/flag.h index 2cbbb45ceacb..185bd8aba2a2 100644 --- a/absl/flags/flag.h +++ b/absl/flags/flag.h @@ -155,7 +155,7 @@ void SetFlag(absl::Flag<T>* flag, const V& v) { // // where: // -// * `T` is a supported flag type (See below), +// * `T` is a supported flag type (see `marshalling.h`), // * `name` designates the name of the flag (as a global variable // `FLAGS_name`), // * `default_value` is an expression holding the default value for this flag diff --git a/absl/random/BUILD.bazel b/absl/random/BUILD.bazel new file mode 100644 index 000000000000..00d42c9d839b --- /dev/null +++ b/absl/random/BUILD.bazel @@ -0,0 +1,390 @@ +# ABSL random-number generation libraries. + +load( + "//absl:copts/configure_copts.bzl", + "ABSL_DEFAULT_COPTS", + "ABSL_DEFAULT_LINKOPTS", + "ABSL_EXCEPTIONS_FLAG", + "ABSL_EXCEPTIONS_FLAG_LINKOPTS", + "ABSL_TEST_COPTS", +) + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "random", + hdrs = ["random.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":seed_sequences", + "//absl/random/internal:nonsecure_base", + "//absl/random/internal:pcg_engine", + "//absl/random/internal:pool_urbg", + "//absl/random/internal:randen_engine", + ], +) + +cc_library( + name = "distributions", + srcs = [ + "discrete_distribution.cc", + "gaussian_distribution.cc", + ], + hdrs = [ + "bernoulli_distribution.h", + "beta_distribution.h", + "discrete_distribution.h", + "distribution_format_traits.h", + "distributions.h", + "exponential_distribution.h", + "gaussian_distribution.h", + "log_uniform_int_distribution.h", + "poisson_distribution.h", + "uniform_int_distribution.h", + "uniform_real_distribution.h", + "zipf_distribution.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + "//absl/base:base_internal", + "//absl/base:core_headers", + "//absl/meta:type_traits", + "//absl/random/internal:distribution_impl", + "//absl/random/internal:distributions", + "//absl/random/internal:fast_uniform_bits", + "//absl/random/internal:fastmath", + "//absl/random/internal:iostream_state_saver", + "//absl/random/internal:traits", + "//absl/random/internal:uniform_helper", + "//absl/strings", + "//absl/types:span", + ], +) + +cc_library( + name = "seed_gen_exception", + srcs = ["seed_gen_exception.cc"], + hdrs = ["seed_gen_exception.h"], + copts = ABSL_DEFAULT_COPTS + ABSL_EXCEPTIONS_FLAG, + linkopts = ABSL_EXCEPTIONS_FLAG_LINKOPTS + ABSL_DEFAULT_LINKOPTS, + deps = ["//absl/base:config"], +) + +cc_library( + name = "seed_sequences", + srcs = ["seed_sequences.cc"], + hdrs = [ + "seed_sequences.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":seed_gen_exception", + "//absl/container:inlined_vector", + "//absl/random/internal:nonsecure_base", + "//absl/random/internal:pool_urbg", + "//absl/random/internal:salted_seed_seq", + "//absl/random/internal:seed_material", + "//absl/types:span", + ], +) + +cc_test( + name = "bernoulli_distribution_test", + size = "small", + timeout = "eternal", # Android can take a very long time + srcs = ["bernoulli_distribution_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/random/internal:sequence_urbg", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "beta_distribution_test", + size = "small", + timeout = "eternal", # Android can take a very long time + srcs = ["beta_distribution_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "distributions_test", + size = "small", + srcs = [ + "distributions_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/random/internal:distribution_test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "generators_test", + size = "small", + srcs = ["generators_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "log_uniform_int_distribution_test", + size = "medium", + srcs = [ + "log_uniform_int_distribution_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/base:core_headers", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "discrete_distribution_test", + size = "medium", + srcs = [ + "discrete_distribution_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "poisson_distribution_test", + size = "small", + timeout = "eternal", # Android can take a very long time + srcs = [ + "poisson_distribution_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + tags = [ + # Too Slow. + "no_test_android_arm", + "no_test_loonix", + ], + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/base:core_headers", + "//absl/container:flat_hash_map", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "exponential_distribution_test", + size = "small", + srcs = ["exponential_distribution_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/base:core_headers", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "gaussian_distribution_test", + size = "small", + timeout = "eternal", # Android can take a very long time + srcs = [ + "gaussian_distribution_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/base:core_headers", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "uniform_int_distribution_test", + size = "medium", + timeout = "long", + srcs = [ + "uniform_int_distribution_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "uniform_real_distribution_test", + size = "medium", + srcs = [ + "uniform_real_distribution_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + tags = [ + "no_test_android_arm", + "no_test_android_arm64", + "no_test_android_x86", + ], + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "zipf_distribution_test", + size = "medium", + srcs = [ + "zipf_distribution_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distributions", + ":random", + "//absl/base", + "//absl/random/internal:distribution_test_util", + "//absl/random/internal:sequence_urbg", + "//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "examples_test", + size = "small", + srcs = ["examples_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":random", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "seed_sequences_test", + size = "small", + srcs = ["seed_sequences_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":random", + ":seed_sequences", + "//absl/random/internal:nonsecure_base", + "@com_google_googletest//:gtest_main", + ], +) + +BENCHMARK_TAGS = [ + "benchmark", + "no_test_android_arm", + "no_test_android_arm64", + "no_test_android_x86", + "no_test_darwin_x86_64", + "no_test_ios_x86_64", + "no_test_loonix", + "no_test_msvc_x64", + "no_test_wasm", +] + +# Benchmarks for various methods / test utilities +cc_test( + name = "benchmarks", + size = "small", + srcs = [ + "benchmarks.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + tags = BENCHMARK_TAGS, + deps = [ + ":distributions", + ":random", + ":seed_sequences", + "//absl/base:core_headers", + "//absl/meta:type_traits", + "//absl/random/internal:fast_uniform_bits", + "//absl/random/internal:randen_engine", + "@com_github_google_benchmark//:benchmark_main", + ], +) diff --git a/absl/random/benchmarks.cc b/absl/random/benchmarks.cc new file mode 100644 index 000000000000..8e6d889e72ae --- /dev/null +++ b/absl/random/benchmarks.cc @@ -0,0 +1,383 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Benchmarks for absl random distributions as well as a selection of the +// C++ standard library random distributions. + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <initializer_list> +#include <iterator> +#include <limits> +#include <random> +#include <type_traits> +#include <vector> + +#include "benchmark/benchmark.h" +#include "absl/base/macros.h" +#include "absl/meta/type_traits.h" +#include "absl/random/bernoulli_distribution.h" +#include "absl/random/beta_distribution.h" +#include "absl/random/exponential_distribution.h" +#include "absl/random/gaussian_distribution.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/randen_engine.h" +#include "absl/random/log_uniform_int_distribution.h" +#include "absl/random/poisson_distribution.h" +#include "absl/random/random.h" +#include "absl/random/uniform_int_distribution.h" +#include "absl/random/uniform_real_distribution.h" +#include "absl/random/zipf_distribution.h" + +namespace { + +// Seed data to avoid reading random_device() for benchmarks. +uint32_t kSeedData[] = { + 0x1B510052, 0x9A532915, 0xD60F573F, 0xBC9BC6E4, 0x2B60A476, 0x81E67400, + 0x08BA6FB5, 0x571BE91F, 0xF296EC6B, 0x2A0DD915, 0xB6636521, 0xE7B9F9B6, + 0xFF34052E, 0xC5855664, 0x53B02D5D, 0xA99F8FA1, 0x08BA4799, 0x6E85076A, + 0x4B7A70E9, 0xB5B32944, 0xDB75092E, 0xC4192623, 0xAD6EA6B0, 0x49A7DF7D, + 0x9CEE60B8, 0x8FEDB266, 0xECAA8C71, 0x699A18FF, 0x5664526C, 0xC2B19EE1, + 0x193602A5, 0x75094C29, 0xA0591340, 0xE4183A3E, 0x3F54989A, 0x5B429D65, + 0x6B8FE4D6, 0x99F73FD6, 0xA1D29C07, 0xEFE830F5, 0x4D2D38E6, 0xF0255DC1, + 0x4CDD2086, 0x8470EB26, 0x6382E9C6, 0x021ECC5E, 0x09686B3F, 0x3EBAEFC9, + 0x3C971814, 0x6B6A70A1, 0x687F3584, 0x52A0E286, 0x13198A2E, 0x03707344, +}; + +// PrecompiledSeedSeq provides kSeedData to a conforming +// random engine to speed initialization in the benchmarks. +class PrecompiledSeedSeq { + public: + using result_type = uint32_t; + + PrecompiledSeedSeq() {} + + template <typename Iterator> + PrecompiledSeedSeq(Iterator begin, Iterator end) {} + + template <typename T> + PrecompiledSeedSeq(std::initializer_list<T> il) {} + + template <typename OutIterator> + void generate(OutIterator begin, OutIterator end) { + static size_t idx = 0; + for (; begin != end; begin++) { + *begin = kSeedData[idx++]; + if (idx >= ABSL_ARRAYSIZE(kSeedData)) { + idx = 0; + } + } + } + + size_t size() const { return ABSL_ARRAYSIZE(kSeedData); } + + template <typename OutIterator> + void param(OutIterator out) const { + std::copy(std::begin(kSeedData), std::end(kSeedData), out); + } +}; + +// use_default_initialization<T> indicates whether the random engine +// T must be default initialized, or whether we may initialize it using +// a seed sequence. This is used because some engines do not accept seed +// sequence-based initialization. +template <typename E> +using use_default_initialization = std::false_type; + +// make_engine<T, SSeq> returns a random_engine which is initialized, +// either via the default constructor, when use_default_initialization<T> +// is true, or via the indicated seed sequence, SSeq. +template <typename Engine, typename SSeq = PrecompiledSeedSeq> +typename absl::enable_if_t<!use_default_initialization<Engine>::value, Engine> +make_engine() { + // Initialize the random engine using the seed sequence SSeq, which + // is constructed from the precompiled seed data. + SSeq seq(std::begin(kSeedData), std::end(kSeedData)); + return Engine(seq); +} + +template <typename Engine, typename SSeq = PrecompiledSeedSeq> +typename absl::enable_if_t<use_default_initialization<Engine>::value, Engine> +make_engine() { + // Initialize the random engine using the default constructor. + return Engine(); +} + +template <typename Engine, typename SSeq> +void BM_Construct(benchmark::State& state) { + for (auto _ : state) { + auto rng = make_engine<Engine, SSeq>(); + benchmark::DoNotOptimize(rng()); + } +} + +template <typename Engine> +void BM_Direct(benchmark::State& state) { + using value_type = typename Engine::result_type; + // Direct use of the URBG. + auto rng = make_engine<Engine>(); + for (auto _ : state) { + benchmark::DoNotOptimize(rng()); + } + state.SetBytesProcessed(sizeof(value_type) * state.iterations()); +} + +template <typename Engine> +void BM_Generate(benchmark::State& state) { + // std::generate makes a copy of the RNG; thus this tests the + // copy-constructor efficiency. + using value_type = typename Engine::result_type; + std::vector<value_type> v(64); + auto rng = make_engine<Engine>(); + while (state.KeepRunningBatch(64)) { + std::generate(std::begin(v), std::end(v), rng); + } +} + +template <typename Engine, size_t elems> +void BM_Shuffle(benchmark::State& state) { + // Direct use of the Engine. + std::vector<uint32_t> v(elems); + while (state.KeepRunningBatch(elems)) { + auto rng = make_engine<Engine>(); + std::shuffle(std::begin(v), std::end(v), rng); + } +} + +template <typename Engine, size_t elems> +void BM_ShuffleReuse(benchmark::State& state) { + // Direct use of the Engine. + std::vector<uint32_t> v(elems); + auto rng = make_engine<Engine>(); + while (state.KeepRunningBatch(elems)) { + std::shuffle(std::begin(v), std::end(v), rng); + } +} + +template <typename Engine, typename Dist, typename... Args> +void BM_Dist(benchmark::State& state, Args&&... args) { + using value_type = typename Dist::result_type; + auto rng = make_engine<Engine>(); + Dist dis{std::forward<Args>(args)...}; + // Compare the following loop performance: + for (auto _ : state) { + benchmark::DoNotOptimize(dis(rng)); + } + state.SetBytesProcessed(sizeof(value_type) * state.iterations()); +} + +template <typename Engine, typename Dist> +void BM_Large(benchmark::State& state) { + using value_type = typename Dist::result_type; + volatile value_type kMin = 0; + volatile value_type kMax = std::numeric_limits<value_type>::max() / 2 + 1; + BM_Dist<Engine, Dist>(state, kMin, kMax); +} + +template <typename Engine, typename Dist> +void BM_Small(benchmark::State& state) { + using value_type = typename Dist::result_type; + volatile value_type kMin = 0; + volatile value_type kMax = std::numeric_limits<value_type>::max() / 64 + 1; + BM_Dist<Engine, Dist>(state, kMin, kMax); +} + +template <typename Engine, typename Dist, int A> +void BM_Bernoulli(benchmark::State& state) { + volatile double a = static_cast<double>(A) / 1000000; + BM_Dist<Engine, Dist>(state, a); +} + +template <typename Engine, typename Dist, int A, int B> +void BM_Beta(benchmark::State& state) { + using value_type = typename Dist::result_type; + volatile value_type a = static_cast<value_type>(A) / 100; + volatile value_type b = static_cast<value_type>(B) / 100; + BM_Dist<Engine, Dist>(state, a, b); +} + +template <typename Engine, typename Dist, int A> +void BM_Gamma(benchmark::State& state) { + using value_type = typename Dist::result_type; + volatile value_type a = static_cast<value_type>(A) / 100; + BM_Dist<Engine, Dist>(state, a); +} + +template <typename Engine, typename Dist, int A = 100> +void BM_Poisson(benchmark::State& state) { + volatile double a = static_cast<double>(A) / 100; + BM_Dist<Engine, Dist>(state, a); +} + +template <typename Engine, typename Dist, int V = 1, int Q = 2> +void BM_Zipf(benchmark::State& state) { + using value_type = typename Dist::result_type; + volatile double v = V; + volatile double q = Q; + BM_Dist<Engine, Dist>(state, std::numeric_limits<value_type>::max(), v, q); +} + +template <typename Engine, typename Dist> +void BM_Thread(benchmark::State& state) { + using value_type = typename Dist::result_type; + auto rng = make_engine<Engine>(); + Dist dis{}; + for (auto _ : state) { + benchmark::DoNotOptimize(dis(rng)); + } + state.SetBytesProcessed(sizeof(value_type) * state.iterations()); +} + +// NOTES: +// +// std::geometric_distribution is similar to the zipf distributions. +// The algorithm for the geometric_distribution is, basically, +// floor(log(1-X) / log(1-p)) + +// Normal benchmark suite +#define BM_BASIC(Engine) \ + BENCHMARK_TEMPLATE(BM_Construct, Engine, PrecompiledSeedSeq); \ + BENCHMARK_TEMPLATE(BM_Construct, Engine, std::seed_seq); \ + BENCHMARK_TEMPLATE(BM_Direct, Engine); \ + BENCHMARK_TEMPLATE(BM_Shuffle, Engine, 10); \ + BENCHMARK_TEMPLATE(BM_Shuffle, Engine, 100); \ + BENCHMARK_TEMPLATE(BM_Shuffle, Engine, 1000); \ + BENCHMARK_TEMPLATE(BM_ShuffleReuse, Engine, 100); \ + BENCHMARK_TEMPLATE(BM_ShuffleReuse, Engine, 1000); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, \ + absl::random_internal::FastUniformBits<uint32_t, 32>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, \ + absl::random_internal::FastUniformBits<uint64_t, 64>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, std::uniform_int_distribution<int32_t>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, std::uniform_int_distribution<int64_t>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, \ + absl::uniform_int_distribution<int32_t>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, \ + absl::uniform_int_distribution<int64_t>); \ + BENCHMARK_TEMPLATE(BM_Large, Engine, \ + std::uniform_int_distribution<int32_t>); \ + BENCHMARK_TEMPLATE(BM_Large, Engine, \ + std::uniform_int_distribution<int64_t>); \ + BENCHMARK_TEMPLATE(BM_Large, Engine, \ + absl::uniform_int_distribution<int32_t>); \ + BENCHMARK_TEMPLATE(BM_Large, Engine, \ + absl::uniform_int_distribution<int64_t>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, std::uniform_real_distribution<float>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, std::uniform_real_distribution<double>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, absl::uniform_real_distribution<float>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, absl::uniform_real_distribution<double>) + +#define BM_COPY(Engine) BENCHMARK_TEMPLATE(BM_Generate, Engine) + +#define BM_THREAD(Engine) \ + BENCHMARK_TEMPLATE(BM_Thread, Engine, \ + absl::uniform_int_distribution<int64_t>) \ + ->ThreadPerCpu(); \ + BENCHMARK_TEMPLATE(BM_Thread, Engine, \ + absl::uniform_real_distribution<double>) \ + ->ThreadPerCpu(); \ + BENCHMARK_TEMPLATE(BM_Shuffle, Engine, 100)->ThreadPerCpu(); \ + BENCHMARK_TEMPLATE(BM_Shuffle, Engine, 1000)->ThreadPerCpu(); \ + BENCHMARK_TEMPLATE(BM_ShuffleReuse, Engine, 100)->ThreadPerCpu(); \ + BENCHMARK_TEMPLATE(BM_ShuffleReuse, Engine, 1000)->ThreadPerCpu(); + +#define BM_EXTENDED(Engine) \ + /* -------------- Extended Uniform -----------------------*/ \ + BENCHMARK_TEMPLATE(BM_Small, Engine, \ + std::uniform_int_distribution<int32_t>); \ + BENCHMARK_TEMPLATE(BM_Small, Engine, \ + std::uniform_int_distribution<int64_t>); \ + BENCHMARK_TEMPLATE(BM_Small, Engine, \ + absl::uniform_int_distribution<int32_t>); \ + BENCHMARK_TEMPLATE(BM_Small, Engine, \ + absl::uniform_int_distribution<int64_t>); \ + BENCHMARK_TEMPLATE(BM_Small, Engine, std::uniform_real_distribution<float>); \ + BENCHMARK_TEMPLATE(BM_Small, Engine, \ + std::uniform_real_distribution<double>); \ + BENCHMARK_TEMPLATE(BM_Small, Engine, \ + absl::uniform_real_distribution<float>); \ + BENCHMARK_TEMPLATE(BM_Small, Engine, \ + absl::uniform_real_distribution<double>); \ + /* -------------- Other -----------------------*/ \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, std::normal_distribution<double>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, absl::gaussian_distribution<double>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, std::exponential_distribution<double>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, absl::exponential_distribution<double>); \ + BENCHMARK_TEMPLATE(BM_Poisson, Engine, std::poisson_distribution<int64_t>, \ + 100); \ + BENCHMARK_TEMPLATE(BM_Poisson, Engine, absl::poisson_distribution<int64_t>, \ + 100); \ + BENCHMARK_TEMPLATE(BM_Poisson, Engine, std::poisson_distribution<int64_t>, \ + 10 * 100); \ + BENCHMARK_TEMPLATE(BM_Poisson, Engine, absl::poisson_distribution<int64_t>, \ + 10 * 100); \ + BENCHMARK_TEMPLATE(BM_Poisson, Engine, std::poisson_distribution<int64_t>, \ + 13 * 100); \ + BENCHMARK_TEMPLATE(BM_Poisson, Engine, absl::poisson_distribution<int64_t>, \ + 13 * 100); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, \ + absl::log_uniform_int_distribution<int32_t>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, \ + absl::log_uniform_int_distribution<int64_t>); \ + BENCHMARK_TEMPLATE(BM_Dist, Engine, std::geometric_distribution<int64_t>); \ + BENCHMARK_TEMPLATE(BM_Zipf, Engine, absl::zipf_distribution<uint64_t>); \ + BENCHMARK_TEMPLATE(BM_Zipf, Engine, absl::zipf_distribution<uint64_t>, 3, \ + 2); \ + BENCHMARK_TEMPLATE(BM_Bernoulli, Engine, std::bernoulli_distribution, \ + 257305); \ + BENCHMARK_TEMPLATE(BM_Bernoulli, Engine, absl::bernoulli_distribution, \ + 257305); \ + BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<double>, 65, \ + 41); \ + BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<double>, 99, \ + 330); \ + BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<double>, 150, \ + 150); \ + BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<double>, 410, \ + 580); \ + BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<float>, 65, 41); \ + BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<float>, 99, \ + 330); \ + BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<float>, 150, \ + 150); \ + BENCHMARK_TEMPLATE(BM_Beta, Engine, absl::beta_distribution<float>, 410, \ + 580); \ + BENCHMARK_TEMPLATE(BM_Gamma, Engine, std::gamma_distribution<float>, 199); \ + BENCHMARK_TEMPLATE(BM_Gamma, Engine, std::gamma_distribution<double>, 199); + +// ABSL Recommended interfaces. +BM_BASIC(absl::InsecureBitGen); // === pcg64_2018_engine +BM_BASIC(absl::BitGen); // === randen_engine<uint64_t>. +BM_THREAD(absl::BitGen); +BM_EXTENDED(absl::BitGen); + +// Instantiate benchmarks for multiple engines. +using randen_engine_64 = absl::random_internal::randen_engine<uint64_t>; +using randen_engine_32 = absl::random_internal::randen_engine<uint32_t>; + +// Comparison interfaces. +BM_BASIC(std::mt19937_64); +BM_COPY(std::mt19937_64); +BM_EXTENDED(std::mt19937_64); +BM_BASIC(randen_engine_64); +BM_COPY(randen_engine_64); +BM_EXTENDED(randen_engine_64); + +BM_BASIC(std::mt19937); +BM_COPY(std::mt19937); +BM_BASIC(randen_engine_32); +BM_COPY(randen_engine_32); + +} // namespace diff --git a/absl/random/bernoulli_distribution.h b/absl/random/bernoulli_distribution.h new file mode 100644 index 000000000000..326fcb6e012d --- /dev/null +++ b/absl/random/bernoulli_distribution.h @@ -0,0 +1,198 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_ +#define ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_ + +#include <cstdint> +#include <istream> +#include <limits> + +#include "absl/base/optimization.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/iostream_state_saver.h" + +namespace absl { + +// absl::bernoulli_distribution is a drop in replacement for +// std::bernoulli_distribution. It guarantees that (given a perfect +// UniformRandomBitGenerator) the acceptance probability is *exactly* equal to +// the given double. +// +// The implementation assumes that double is IEEE754 +class bernoulli_distribution { + public: + using result_type = bool; + + class param_type { + public: + using distribution_type = bernoulli_distribution; + + explicit param_type(double p = 0.5) : prob_(p) { + assert(p >= 0.0 && p <= 1.0); + } + + double p() const { return prob_; } + + friend bool operator==(const param_type& p1, const param_type& p2) { + return p1.p() == p2.p(); + } + friend bool operator!=(const param_type& p1, const param_type& p2) { + return p1.p() != p2.p(); + } + + private: + double prob_; + }; + + bernoulli_distribution() : bernoulli_distribution(0.5) {} + + explicit bernoulli_distribution(double p) : param_(p) {} + + explicit bernoulli_distribution(param_type p) : param_(p) {} + + // no-op + void reset() {} + + template <typename URBG> + bool operator()(URBG& g) { // NOLINT(runtime/references) + return Generate(param_.p(), g); + } + + template <typename URBG> + bool operator()(URBG& g, // NOLINT(runtime/references) + const param_type& param) { + return Generate(param.p(), g); + } + + param_type param() const { return param_; } + void param(const param_type& param) { param_ = param; } + + double p() const { return param_.p(); } + + result_type(min)() const { return false; } + result_type(max)() const { return true; } + + friend bool operator==(const bernoulli_distribution& d1, + const bernoulli_distribution& d2) { + return d1.param_ == d2.param_; + } + + friend bool operator!=(const bernoulli_distribution& d1, + const bernoulli_distribution& d2) { + return d1.param_ != d2.param_; + } + + private: + static constexpr uint64_t kP32 = static_cast<uint64_t>(1) << 32; + + template <typename URBG> + static bool Generate(double p, URBG& g); // NOLINT(runtime/references) + + param_type param_; +}; + +template <typename CharT, typename Traits> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const bernoulli_distribution& x) { + auto saver = random_internal::make_ostream_state_saver(os); + os.precision(random_internal::stream_precision_helper<double>::kPrecision); + os << x.p(); + return os; +} + +template <typename CharT, typename Traits> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + bernoulli_distribution& x) { // NOLINT(runtime/references) + auto saver = random_internal::make_istream_state_saver(is); + auto p = random_internal::read_floating_point<double>(is); + if (!is.fail()) { + x.param(bernoulli_distribution::param_type(p)); + } + return is; +} + +template <typename URBG> +bool bernoulli_distribution::Generate(double p, + URBG& g) { // NOLINT(runtime/references) + random_internal::FastUniformBits<uint32_t> fast_u32; + + while (true) { + // There are two aspects of the definition of `c` below that are worth + // commenting on. First, because `p` is in the range [0, 1], `c` is in the + // range [0, 2^32] which does not fit in a uint32_t and therefore requires + // 64 bits. + // + // Second, `c` is constructed by first casting explicitly to a signed + // integer and then converting implicitly to an unsigned integer of the same + // size. This is done because the hardware conversion instructions produce + // signed integers from double; if taken as a uint64_t the conversion would + // be wrong for doubles greater than 2^63 (not relevant in this use-case). + // If converted directly to an unsigned integer, the compiler would end up + // emitting code to handle such large values that are not relevant due to + // the known bounds on `c`. To avoid these extra instructions this + // implementation converts first to the signed type and then use the + // implicit conversion to unsigned (which is a no-op). + const uint64_t c = static_cast<int64_t>(p * kP32); + const uint32_t v = fast_u32(g); + // FAST PATH: this path fails with probability 1/2^32. Note that simply + // returning v <= c would approximate P very well (up to an absolute error + // of 1/2^32); the slow path (taken in that range of possible error, in the + // case of equality) eliminates the remaining error. + if (ABSL_PREDICT_TRUE(v != c)) return v < c; + + // It is guaranteed that `q` is strictly less than 1, because if `q` were + // greater than or equal to 1, the same would be true for `p`. Certainly `p` + // cannot be greater than 1, and if `p == 1`, then the fast path would + // necessary have been taken already. + const double q = static_cast<double>(c) / kP32; + + // The probability of acceptance on the fast path is `q` and so the + // probability of acceptance here should be `p - q`. + // + // Note that `q` is obtained from `p` via some shifts and conversions, the + // upshot of which is that `q` is simply `p` with some of the + // least-significant bits of its mantissa set to zero. This means that the + // difference `p - q` will not have any rounding errors. To see why, pretend + // that double has 10 bits of resolution and q is obtained from `p` in such + // a way that the 4 least-significant bits of its mantissa are set to zero. + // For example: + // p = 1.1100111011 * 2^-1 + // q = 1.1100110000 * 2^-1 + // p - q = 1.011 * 2^-8 + // The difference `p - q` has exactly the nonzero mantissa bits that were + // "lost" in `q` producing a number which is certainly representable in a + // double. + const double left = p - q; + + // By construction, the probability of being on this slow path is 1/2^32, so + // P(accept in slow path) = P(accept| in slow path) * P(slow path), + // which means the probability of acceptance here is `1 / (left * kP32)`: + const double here = left * kP32; + + // The simplest way to compute the result of this trial is to repeat the + // whole algorithm with the new probability. This terminates because even + // given arbitrarily unfriendly "random" bits, each iteration either + // multiplies a tiny probability by 2^32 (if c == 0) or strips off some + // number of nonzero mantissa bits. That process is bounded. + if (here == 0) return false; + p = here; + } +} + +} // namespace absl + +#endif // ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_ diff --git a/absl/random/bernoulli_distribution_test.cc b/absl/random/bernoulli_distribution_test.cc new file mode 100644 index 000000000000..f2c3b99cd4ff --- /dev/null +++ b/absl/random/bernoulli_distribution_test.cc @@ -0,0 +1,213 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/bernoulli_distribution.h" + +#include <cmath> +#include <cstddef> +#include <random> +#include <sstream> +#include <utility> + +#include "gtest/gtest.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" + +namespace { + +class BernoulliTest : public testing::TestWithParam<std::pair<double, size_t>> { +}; + +TEST_P(BernoulliTest, Serialize) { + const double d = GetParam().first; + absl::bernoulli_distribution before(d); + + { + absl::bernoulli_distribution via_param{ + absl::bernoulli_distribution::param_type(d)}; + EXPECT_EQ(via_param, before); + } + + std::stringstream ss; + ss << before; + absl::bernoulli_distribution after(0.6789); + + EXPECT_NE(before.p(), after.p()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + + EXPECT_EQ(before.p(), after.p()); + EXPECT_EQ(before.param(), after.param()); + EXPECT_EQ(before, after); +} + +TEST_P(BernoulliTest, Accuracy) { + // Sadly, the claim to fame for this implementation is precise accuracy, which + // is very, very hard to measure, the improvements come as trials approach the + // limit of double accuracy; thus the outcome differs from the + // std::bernoulli_distribution with a probability of approximately 1 in 2^-53. + const std::pair<double, size_t> para = GetParam(); + size_t trials = para.second; + double p = para.first; + + absl::InsecureBitGen rng; + + size_t yes = 0; + absl::bernoulli_distribution dist(p); + for (size_t i = 0; i < trials; ++i) { + if (dist(rng)) yes++; + } + + // Compute the distribution parameters for a binomial test, using a normal + // approximation for the confidence interval, as there are a sufficiently + // large number of trials that the central limit theorem applies. + const double stddev_p = std::sqrt((p * (1.0 - p)) / trials); + const double expected = trials * p; + const double stddev = trials * stddev_p; + + // 5 sigma, approved by Richard Feynman + EXPECT_NEAR(yes, expected, 5 * stddev) + << "@" << p << ", " + << std::abs(static_cast<double>(yes) - expected) / stddev << " stddev"; +} + +// There must be many more trials to make the mean approximately normal for `p` +// closes to 0 or 1. +INSTANTIATE_TEST_SUITE_P( + All, BernoulliTest, + ::testing::Values( + // Typical values. + std::make_pair(0, 30000), std::make_pair(1e-3, 30000000), + std::make_pair(0.1, 3000000), std::make_pair(0.5, 3000000), + std::make_pair(0.9, 30000000), std::make_pair(0.999, 30000000), + std::make_pair(1, 30000), + // Boundary cases. + std::make_pair(std::nextafter(1.0, 0.0), 1), // ~1 - epsilon + std::make_pair(std::numeric_limits<double>::epsilon(), 1), + std::make_pair(std::nextafter(std::numeric_limits<double>::min(), + 1.0), // min + epsilon + 1), + std::make_pair(std::numeric_limits<double>::min(), // smallest normal + 1), + std::make_pair( + std::numeric_limits<double>::denorm_min(), // smallest denorm + 1), + std::make_pair(std::numeric_limits<double>::min() / 2, 1), // denorm + std::make_pair(std::nextafter(std::numeric_limits<double>::min(), + 0.0), // denorm_max + 1))); + +// NOTE: absl::bernoulli_distribution is not guaranteed to be stable. +TEST(BernoulliTest, StabilityTest) { + // absl::bernoulli_distribution stability relies on FastUniformBits and + // integer arithmetic. + absl::random_internal::sequence_urbg urbg({ + 0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull, + 0x4864f22c059bf29eull, 0x247856d8b862665cull, 0xe46e86e9a1337e10ull, + 0xd8c8541f3519b133ull, 0xe75b5162c567b9e4ull, 0xf732e5ded7009c5bull, + 0xb170b98353121eacull, 0x1ec2e8986d2362caull, 0x814c8e35fe9a961aull, + 0x0c3cd59c9b638a02ull, 0xcb3bb6478a07715cull, 0x1224e62c978bbc7full, + 0x671ef2cb04e81f6eull, 0x3c1cbd811eaf1808ull, 0x1bbc23cfa8fac721ull, + 0xa4c2cda65e596a51ull, 0xb77216fad37adf91ull, 0x836d794457c08849ull, + 0xe083df03475f49d7ull, 0xbc9feb512e6b0d6cull, 0xb12d74fdd718c8c5ull, + 0x12ff09653bfbe4caull, 0x8dd03a105bc4ee7eull, 0x5738341045ba0d85ull, + 0xe3fd722dc65ad09eull, 0x5a14fd21ea2a5705ull, 0x14e6ea4d6edb0c73ull, + 0x275b0dc7e0a18acfull, 0x36cebe0d2653682eull, 0x0361e9b23861596bull, + }); + + // Generate a std::string of '0' and '1' for the distribution output. + auto generate = [&urbg](absl::bernoulli_distribution& dist) { + std::string output; + output.reserve(36); + urbg.reset(); + for (int i = 0; i < 35; i++) { + output.append(dist(urbg) ? "1" : "0"); + } + return output; + }; + + const double kP = 0.0331289862362; + { + absl::bernoulli_distribution dist(kP); + auto v = generate(dist); + EXPECT_EQ(35, urbg.invocations()); + EXPECT_EQ(v, "00000000000010000000000010000000000") << dist; + } + { + absl::bernoulli_distribution dist(kP * 10.0); + auto v = generate(dist); + EXPECT_EQ(35, urbg.invocations()); + EXPECT_EQ(v, "00000100010010010010000011000011010") << dist; + } + { + absl::bernoulli_distribution dist(kP * 20.0); + auto v = generate(dist); + EXPECT_EQ(35, urbg.invocations()); + EXPECT_EQ(v, "00011110010110110011011111110111011") << dist; + } + { + absl::bernoulli_distribution dist(1.0 - kP); + auto v = generate(dist); + EXPECT_EQ(35, urbg.invocations()); + EXPECT_EQ(v, "11111111111111111111011111111111111") << dist; + } +} + +TEST(BernoulliTest, StabilityTest2) { + absl::random_internal::sequence_urbg urbg( + {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull}); + + // Generate a std::string of '0' and '1' for the distribution output. + auto generate = [&urbg](absl::bernoulli_distribution& dist) { + std::string output; + output.reserve(13); + urbg.reset(); + for (int i = 0; i < 12; i++) { + output.append(dist(urbg) ? "1" : "0"); + } + return output; + }; + + constexpr double b0 = 1.0 / 13.0 / 0.2; + constexpr double b1 = 2.0 / 13.0 / 0.2; + constexpr double b3 = (5.0 / 13.0 / 0.2) - ((1 - b0) + (1 - b1) + (1 - b1)); + { + absl::bernoulli_distribution dist(b0); + auto v = generate(dist); + EXPECT_EQ(12, urbg.invocations()); + EXPECT_EQ(v, "000011100101") << dist; + } + { + absl::bernoulli_distribution dist(b1); + auto v = generate(dist); + EXPECT_EQ(12, urbg.invocations()); + EXPECT_EQ(v, "001111101101") << dist; + } + { + absl::bernoulli_distribution dist(b3); + auto v = generate(dist); + EXPECT_EQ(12, urbg.invocations()); + EXPECT_EQ(v, "001111101111") << dist; + } +} + +} // namespace diff --git a/absl/random/beta_distribution.h b/absl/random/beta_distribution.h new file mode 100644 index 000000000000..d7afd61c7fb1 --- /dev/null +++ b/absl/random/beta_distribution.h @@ -0,0 +1,414 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_BETA_DISTRIBUTION_H_ +#define ABSL_RANDOM_BETA_DISTRIBUTION_H_ + +#include <cassert> +#include <cmath> +#include <istream> +#include <limits> +#include <ostream> +#include <type_traits> + +#include "absl/random/internal/distribution_impl.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/fastmath.h" +#include "absl/random/internal/iostream_state_saver.h" + +namespace absl { + +// absl::beta_distribution: +// Generate a floating-point variate conforming to a Beta distribution: +// pdf(x) \propto x^(alpha-1) * (1-x)^(beta-1), +// where the params alpha and beta are both strictly positive real values. +// +// The support is the open interval (0, 1), but the return value might be equal +// to 0 or 1, due to numerical errors when alpha and beta are very different. +// +// Usage note: One usage is that alpha and beta are counts of number of +// successes and failures. When the total number of trials are large, consider +// approximating a beta distribution with a Gaussian distribution with the same +// mean and variance. One could use the skewness, which depends only on the +// smaller of alpha and beta when the number of trials are sufficiently large, +// to quantify how far a beta distribution is from the normal distribution. +template <typename RealType = double> +class beta_distribution { + public: + using result_type = RealType; + + class param_type { + public: + using distribution_type = beta_distribution; + + explicit param_type(result_type alpha, result_type beta) + : alpha_(alpha), beta_(beta) { + assert(alpha >= 0); + assert(beta >= 0); + assert(alpha <= (std::numeric_limits<result_type>::max)()); + assert(beta <= (std::numeric_limits<result_type>::max)()); + if (alpha == 0 || beta == 0) { + method_ = DEGENERATE_SMALL; + x_ = (alpha >= beta) ? 1 : 0; + return; + } + // a_ = min(beta, alpha), b_ = max(beta, alpha). + if (beta < alpha) { + inverted_ = true; + a_ = beta; + b_ = alpha; + } else { + inverted_ = false; + a_ = alpha; + b_ = beta; + } + if (a_ <= 1 && b_ >= ThresholdForLargeA()) { + method_ = DEGENERATE_SMALL; + x_ = inverted_ ? result_type(1) : result_type(0); + return; + } + // For threshold values, see also: + // Evaluation of Beta Generation Algorithms, Ying-Chao Hung, et. al. + // February, 2009. + if ((b_ < 1.0 && a_ + b_ <= 1.2) || a_ <= ThresholdForSmallA()) { + // Choose Joehnk over Cheng when it's faster or when Cheng encounters + // numerical issues. + method_ = JOEHNK; + a_ = result_type(1) / alpha_; + b_ = result_type(1) / beta_; + if (std::isinf(a_) || std::isinf(b_)) { + method_ = DEGENERATE_SMALL; + x_ = inverted_ ? result_type(1) : result_type(0); + } + return; + } + if (a_ >= ThresholdForLargeA()) { + method_ = DEGENERATE_LARGE; + // Note: on PPC for long double, evaluating + // `std::numeric_limits::max() / ThresholdForLargeA` results in NaN. + result_type r = a_ / b_; + x_ = (inverted_ ? result_type(1) : r) / (1 + r); + return; + } + x_ = a_ + b_; + log_x_ = std::log(x_); + if (a_ <= 1) { + method_ = CHENG_BA; + y_ = result_type(1) / a_; + gamma_ = a_ + a_; + return; + } + method_ = CHENG_BB; + result_type r = (a_ - 1) / (b_ - 1); + y_ = std::sqrt((1 + r) / (b_ * r * 2 - r + 1)); + gamma_ = a_ + result_type(1) / y_; + } + + result_type alpha() const { return alpha_; } + result_type beta() const { return beta_; } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.alpha_ == b.alpha_ && a.beta_ == b.beta_; + } + + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + friend class beta_distribution; + +#ifdef COMPILER_MSVC + // MSVC does not have constexpr implementations for std::log and std::exp + // so they are computed at runtime. +#define ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR +#else +#define ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR constexpr +#endif + + // The threshold for whether std::exp(1/a) is finite. + // Note that this value is quite large, and a smaller a_ is NOT abnormal. + static ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR result_type + ThresholdForSmallA() { + return result_type(1) / + std::log((std::numeric_limits<result_type>::max)()); + } + + // The threshold for whether a * std::log(a) is finite. + static ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR result_type + ThresholdForLargeA() { + return std::exp( + std::log((std::numeric_limits<result_type>::max)()) - + std::log(std::log((std::numeric_limits<result_type>::max)())) - + ThresholdPadding()); + } + +#undef ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR + + // Pad the threshold for large A for long double on PPC. This is done via a + // template specialization below. + static constexpr result_type ThresholdPadding() { return 0; } + + enum Method { + JOEHNK, // Uses algorithm Joehnk + CHENG_BA, // Uses algorithm BA in Cheng + CHENG_BB, // Uses algorithm BB in Cheng + + // Note: See also: + // Hung et al. Evaluation of beta generation algorithms. Communications + // in Statistics-Simulation and Computation 38.4 (2009): 750-770. + // especially: + // Zechner, Heinz, and Ernst Stadlober. Generating beta variates via + // patchwork rejection. Computing 50.1 (1993): 1-18. + + DEGENERATE_SMALL, // a_ is abnormally small. + DEGENERATE_LARGE, // a_ is abnormally large. + }; + + result_type alpha_; + result_type beta_; + + result_type a_; // the smaller of {alpha, beta}, or 1.0/alpha_ in JOEHNK + result_type b_; // the larger of {alpha, beta}, or 1.0/beta_ in JOEHNK + result_type x_; // alpha + beta, or the result in degenerate cases + result_type log_x_; // log(x_) + result_type y_; // "beta" in Cheng + result_type gamma_; // "gamma" in Cheng + + Method method_; + + // Placing this last for optimal alignment. + // Whether alpha_ != a_, i.e. true iff alpha_ > beta_. + bool inverted_; + + static_assert(std::is_floating_point<RealType>::value, + "Class-template absl::beta_distribution<> must be " + "parameterized using a floating-point type."); + }; + + beta_distribution() : beta_distribution(1) {} + + explicit beta_distribution(result_type alpha, result_type beta = 1) + : param_(alpha, beta) {} + + explicit beta_distribution(const param_type& p) : param_(p) {} + + void reset() {} + + // Generating functions + template <typename URBG> + result_type operator()(URBG& g) { // NOLINT(runtime/references) + return (*this)(g, param_); + } + + template <typename URBG> + result_type operator()(URBG& g, // NOLINT(runtime/references) + const param_type& p); + + param_type param() const { return param_; } + void param(const param_type& p) { param_ = p; } + + result_type(min)() const { return 0; } + result_type(max)() const { return 1; } + + result_type alpha() const { return param_.alpha(); } + result_type beta() const { return param_.beta(); } + + friend bool operator==(const beta_distribution& a, + const beta_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const beta_distribution& a, + const beta_distribution& b) { + return a.param_ != b.param_; + } + + private: + template <typename URBG> + result_type AlgorithmJoehnk(URBG& g, // NOLINT(runtime/references) + const param_type& p); + + template <typename URBG> + result_type AlgorithmCheng(URBG& g, // NOLINT(runtime/references) + const param_type& p); + + template <typename URBG> + result_type DegenerateCase(URBG& g, // NOLINT(runtime/references) + const param_type& p) { + if (p.method_ == param_type::DEGENERATE_SMALL && p.alpha_ == p.beta_) { + // Returns 0 or 1 with equal probability. + random_internal::FastUniformBits<uint8_t> fast_u8; + return static_cast<result_type>((fast_u8(g) & 0x10) != + 0); // pick any single bit. + } + return p.x_; + } + + param_type param_; + random_internal::FastUniformBits<uint64_t> fast_u64_; +}; + +#if defined(__powerpc64__) || defined(__PPC64__) || defined(__powerpc__) || \ + defined(__ppc__) || defined(__PPC__) +// PPC needs a more stringent boundary for long double. +template <> +constexpr long double +beta_distribution<long double>::param_type::ThresholdPadding() { + return 10; +} +#endif + +template <typename RealType> +template <typename URBG> +typename beta_distribution<RealType>::result_type +beta_distribution<RealType>::AlgorithmJoehnk( + URBG& g, // NOLINT(runtime/references) + const param_type& p) { + // Based on Joehnk, M. D. Erzeugung von betaverteilten und gammaverteilten + // Zufallszahlen. Metrika 8.1 (1964): 5-15. + // This method is described in Knuth, Vol 2 (Third Edition), pp 134. + using RandU64ToReal = typename random_internal::RandU64ToReal<result_type>; + using random_internal::PositiveValueT; + result_type u, v, x, y, z; + for (;;) { + u = RandU64ToReal::template Value<PositiveValueT, false>(fast_u64_(g)); + v = RandU64ToReal::template Value<PositiveValueT, false>(fast_u64_(g)); + + // Direct method. std::pow is slow for float, so rely on the optimizer to + // remove the std::pow() path for that case. + if (!std::is_same<float, result_type>::value) { + x = std::pow(u, p.a_); + y = std::pow(v, p.b_); + z = x + y; + if (z > 1) { + // Reject if and only if `x + y > 1.0` + continue; + } + if (z > 0) { + // When both alpha and beta are small, x and y are both close to 0, so + // divide by (x+y) directly may result in nan. + return x / z; + } + } + + // Log transform. + // x = log( pow(u, p.a_) ), y = log( pow(v, p.b_) ) + // since u, v <= 1.0, x, y < 0. + x = std::log(u) * p.a_; + y = std::log(v) * p.b_; + if (!std::isfinite(x) || !std::isfinite(y)) { + continue; + } + // z = log( pow(u, a) + pow(v, b) ) + z = x > y ? (x + std::log(1 + std::exp(y - x))) + : (y + std::log(1 + std::exp(x - y))); + // Reject iff log(x+y) > 0. + if (z > 0) { + continue; + } + return std::exp(x - z); + } +} + +template <typename RealType> +template <typename URBG> +typename beta_distribution<RealType>::result_type +beta_distribution<RealType>::AlgorithmCheng( + URBG& g, // NOLINT(runtime/references) + const param_type& p) { + // Based on Cheng, Russell CH. Generating beta variates with nonintegral + // shape parameters. Communications of the ACM 21.4 (1978): 317-322. + // (https://dl.acm.org/citation.cfm?id=359482). + using RandU64ToReal = typename random_internal::RandU64ToReal<result_type>; + using random_internal::PositiveValueT; + + static constexpr result_type kLogFour = + result_type(1.3862943611198906188344642429163531361); // log(4) + static constexpr result_type kS = + result_type(2.6094379124341003746007593332261876); // 1+log(5) + + const bool use_algorithm_ba = (p.method_ == param_type::CHENG_BA); + result_type u1, u2, v, w, z, r, s, t, bw_inv, lhs; + for (;;) { + u1 = RandU64ToReal::template Value<PositiveValueT, false>(fast_u64_(g)); + u2 = RandU64ToReal::template Value<PositiveValueT, false>(fast_u64_(g)); + v = p.y_ * std::log(u1 / (1 - u1)); + w = p.a_ * std::exp(v); + bw_inv = result_type(1) / (p.b_ + w); + r = p.gamma_ * v - kLogFour; + s = p.a_ + r - w; + z = u1 * u1 * u2; + if (!use_algorithm_ba && s + kS >= 5 * z) { + break; + } + t = std::log(z); + if (!use_algorithm_ba && s >= t) { + break; + } + lhs = p.x_ * (p.log_x_ + std::log(bw_inv)) + r; + if (lhs >= t) { + break; + } + } + return p.inverted_ ? (1 - w * bw_inv) : w * bw_inv; +} + +template <typename RealType> +template <typename URBG> +typename beta_distribution<RealType>::result_type +beta_distribution<RealType>::operator()(URBG& g, // NOLINT(runtime/references) + const param_type& p) { + switch (p.method_) { + case param_type::JOEHNK: + return AlgorithmJoehnk(g, p); + case param_type::CHENG_BA: + ABSL_FALLTHROUGH_INTENDED; + case param_type::CHENG_BB: + return AlgorithmCheng(g, p); + default: + return DegenerateCase(g, p); + } +} + +template <typename CharT, typename Traits, typename RealType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const beta_distribution<RealType>& x) { + auto saver = random_internal::make_ostream_state_saver(os); + os.precision(random_internal::stream_precision_helper<RealType>::kPrecision); + os << x.alpha() << os.fill() << x.beta(); + return os; +} + +template <typename CharT, typename Traits, typename RealType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + beta_distribution<RealType>& x) { // NOLINT(runtime/references) + using result_type = typename beta_distribution<RealType>::result_type; + using param_type = typename beta_distribution<RealType>::param_type; + result_type alpha, beta; + + auto saver = random_internal::make_istream_state_saver(is); + alpha = random_internal::read_floating_point<result_type>(is); + if (is.fail()) return is; + beta = random_internal::read_floating_point<result_type>(is); + if (!is.fail()) { + x.param(param_type(alpha, beta)); + } + return is; +} + +} // namespace absl + +#endif // ABSL_RANDOM_BETA_DISTRIBUTION_H_ diff --git a/absl/random/beta_distribution_test.cc b/absl/random/beta_distribution_test.cc new file mode 100644 index 000000000000..966ad08b5c91 --- /dev/null +++ b/absl/random/beta_distribution_test.cc @@ -0,0 +1,614 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/beta_distribution.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <random> +#include <sstream> +#include <string> +#include <unordered_map> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/strip.h" + +namespace { + +template <typename IntType> +class BetaDistributionInterfaceTest : public ::testing::Test {}; + +using RealTypes = ::testing::Types<float, double, long double>; +TYPED_TEST_CASE(BetaDistributionInterfaceTest, RealTypes); + +TYPED_TEST(BetaDistributionInterfaceTest, SerializeTest) { + // The threshold for whether std::exp(1/a) is finite. + const TypeParam kSmallA = + 1.0f / std::log((std::numeric_limits<TypeParam>::max)()); + // The threshold for whether a * std::log(a) is finite. + const TypeParam kLargeA = + std::exp(std::log((std::numeric_limits<TypeParam>::max)()) - + std::log(std::log((std::numeric_limits<TypeParam>::max)()))); + const TypeParam kLargeAPPC = std::exp( + std::log((std::numeric_limits<TypeParam>::max)()) - + std::log(std::log((std::numeric_limits<TypeParam>::max)())) - 10.0f); + using param_type = typename absl::beta_distribution<TypeParam>::param_type; + + constexpr int kCount = 1000; + absl::InsecureBitGen gen; + const TypeParam kValues[] = { + TypeParam(1e-20), TypeParam(1e-12), TypeParam(1e-8), TypeParam(1e-4), + TypeParam(1e-3), TypeParam(0.1), TypeParam(0.25), + std::nextafter(TypeParam(0.5), TypeParam(0)), // 0.5 - epsilon + std::nextafter(TypeParam(0.5), TypeParam(1)), // 0.5 + epsilon + TypeParam(0.5), TypeParam(1.0), // + std::nextafter(TypeParam(1), TypeParam(0)), // 1 - epsilon + std::nextafter(TypeParam(1), TypeParam(2)), // 1 + epsilon + TypeParam(12.5), TypeParam(1e2), TypeParam(1e8), TypeParam(1e12), + TypeParam(1e20), // + kSmallA, // + std::nextafter(kSmallA, TypeParam(0)), // + std::nextafter(kSmallA, TypeParam(1)), // + kLargeA, // + std::nextafter(kLargeA, TypeParam(0)), // + std::nextafter(kLargeA, std::numeric_limits<TypeParam>::max()), + kLargeAPPC, // + std::nextafter(kLargeAPPC, TypeParam(0)), + std::nextafter(kLargeAPPC, std::numeric_limits<TypeParam>::max()), + // Boundary cases. + std::numeric_limits<TypeParam>::max(), + std::numeric_limits<TypeParam>::epsilon(), + std::nextafter(std::numeric_limits<TypeParam>::min(), + TypeParam(1)), // min + epsilon + std::numeric_limits<TypeParam>::min(), // smallest normal + std::numeric_limits<TypeParam>::denorm_min(), // smallest denorm + std::numeric_limits<TypeParam>::min() / 2, // denorm + std::nextafter(std::numeric_limits<TypeParam>::min(), + TypeParam(0)), // denorm_max + }; + for (TypeParam alpha : kValues) { + for (TypeParam beta : kValues) { + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat("Smoke test for Beta(%f, %f)", alpha, beta)); + + param_type param(alpha, beta); + absl::beta_distribution<TypeParam> before(alpha, beta); + EXPECT_EQ(before.alpha(), param.alpha()); + EXPECT_EQ(before.beta(), param.beta()); + + { + absl::beta_distribution<TypeParam> via_param(param); + EXPECT_EQ(via_param, before); + EXPECT_EQ(via_param.param(), before.param()); + } + + // Smoke test. + for (int i = 0; i < kCount; ++i) { + auto sample = before(gen); + EXPECT_TRUE(std::isfinite(sample)); + EXPECT_GE(sample, before.min()); + EXPECT_LE(sample, before.max()); + } + + // Validate stream serialization. + std::stringstream ss; + ss << before; + absl::beta_distribution<TypeParam> after(3.8f, 1.43f); + EXPECT_NE(before.alpha(), after.alpha()); + EXPECT_NE(before.beta(), after.beta()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + +#if defined(__powerpc64__) || defined(__PPC64__) || defined(__powerpc__) || \ + defined(__ppc__) || defined(__PPC__) + if (std::is_same<TypeParam, long double>::value) { + // Roundtripping floating point values requires sufficient precision + // to reconstruct the exact value. It turns out that long double + // has some errors doing this on ppc. + if (alpha <= std::numeric_limits<double>::max() && + alpha >= std::numeric_limits<double>::lowest()) { + EXPECT_EQ(static_cast<double>(before.alpha()), + static_cast<double>(after.alpha())) + << ss.str(); + } + if (beta <= std::numeric_limits<double>::max() && + beta >= std::numeric_limits<double>::lowest()) { + EXPECT_EQ(static_cast<double>(before.beta()), + static_cast<double>(after.beta())) + << ss.str(); + } + continue; + } +#endif + + EXPECT_EQ(before.alpha(), after.alpha()); + EXPECT_EQ(before.beta(), after.beta()); + EXPECT_EQ(before, after) // + << ss.str() << " " // + << (ss.good() ? "good " : "") // + << (ss.bad() ? "bad " : "") // + << (ss.eof() ? "eof " : "") // + << (ss.fail() ? "fail " : ""); + } + } +} + +TYPED_TEST(BetaDistributionInterfaceTest, DegenerateCases) { + // Extreme cases when the params are abnormal. + absl::InsecureBitGen gen; + constexpr int kCount = 1000; + const TypeParam kSmallValues[] = { + std::numeric_limits<TypeParam>::min(), + std::numeric_limits<TypeParam>::denorm_min(), + std::nextafter(std::numeric_limits<TypeParam>::min(), + TypeParam(0)), // denorm_max + std::numeric_limits<TypeParam>::epsilon(), + }; + const TypeParam kLargeValues[] = { + std::numeric_limits<TypeParam>::max() * static_cast<TypeParam>(0.9999), + std::numeric_limits<TypeParam>::max() - 1, + std::numeric_limits<TypeParam>::max(), + }; + { + // Small alpha and beta. + // Useful WolframAlpha plots: + // * plot InverseBetaRegularized[x, 0.0001, 0.0001] from 0.495 to 0.505 + // * Beta[1.0, 0.0000001, 0.0000001] + // * Beta[0.9999, 0.0000001, 0.0000001] + for (TypeParam alpha : kSmallValues) { + for (TypeParam beta : kSmallValues) { + int zeros = 0; + int ones = 0; + absl::beta_distribution<TypeParam> d(alpha, beta); + for (int i = 0; i < kCount; ++i) { + TypeParam x = d(gen); + if (x == 0.0) { + zeros++; + } else if (x == 1.0) { + ones++; + } + } + EXPECT_EQ(ones + zeros, kCount); + if (alpha == beta) { + EXPECT_NE(ones, 0); + EXPECT_NE(zeros, 0); + } + } + } + } + { + // Small alpha, large beta. + // Useful WolframAlpha plots: + // * plot InverseBetaRegularized[x, 0.0001, 10000] from 0.995 to 1 + // * Beta[0, 0.0000001, 1000000] + // * Beta[0.001, 0.0000001, 1000000] + // * Beta[1, 0.0000001, 1000000] + for (TypeParam alpha : kSmallValues) { + for (TypeParam beta : kLargeValues) { + absl::beta_distribution<TypeParam> d(alpha, beta); + for (int i = 0; i < kCount; ++i) { + EXPECT_EQ(d(gen), 0.0); + } + } + } + } + { + // Large alpha, small beta. + // Useful WolframAlpha plots: + // * plot InverseBetaRegularized[x, 10000, 0.0001] from 0 to 0.001 + // * Beta[0.99, 1000000, 0.0000001] + // * Beta[1, 1000000, 0.0000001] + for (TypeParam alpha : kLargeValues) { + for (TypeParam beta : kSmallValues) { + absl::beta_distribution<TypeParam> d(alpha, beta); + for (int i = 0; i < kCount; ++i) { + EXPECT_EQ(d(gen), 1.0); + } + } + } + } + { + // Large alpha and beta. + absl::beta_distribution<TypeParam> d(std::numeric_limits<TypeParam>::max(), + std::numeric_limits<TypeParam>::max()); + for (int i = 0; i < kCount; ++i) { + EXPECT_EQ(d(gen), 0.5); + } + } + { + // Large alpha and beta but unequal. + absl::beta_distribution<TypeParam> d( + std::numeric_limits<TypeParam>::max(), + std::numeric_limits<TypeParam>::max() * 0.9999); + for (int i = 0; i < kCount; ++i) { + TypeParam x = d(gen); + EXPECT_NE(x, 0.5f); + EXPECT_FLOAT_EQ(x, 0.500025f); + } + } +} + +class BetaDistributionModel { + public: + explicit BetaDistributionModel(::testing::tuple<double, double> p) + : alpha_(::testing::get<0>(p)), beta_(::testing::get<1>(p)) {} + + double Mean() const { return alpha_ / (alpha_ + beta_); } + + double Variance() const { + return alpha_ * beta_ / (alpha_ + beta_ + 1) / (alpha_ + beta_) / + (alpha_ + beta_); + } + + double Kurtosis() const { + return 3 + 6 * + ((alpha_ - beta_) * (alpha_ - beta_) * (alpha_ + beta_ + 1) - + alpha_ * beta_ * (2 + alpha_ + beta_)) / + alpha_ / beta_ / (alpha_ + beta_ + 2) / (alpha_ + beta_ + 3); + } + + protected: + const double alpha_; + const double beta_; +}; + +class BetaDistributionTest + : public ::testing::TestWithParam<::testing::tuple<double, double>>, + public BetaDistributionModel { + public: + BetaDistributionTest() : BetaDistributionModel(GetParam()) {} + + protected: + template <class D> + bool SingleZTestOnMeanAndVariance(double p, size_t samples); + + template <class D> + bool SingleChiSquaredTest(double p, size_t samples, size_t buckets); + + absl::InsecureBitGen rng_; +}; + +template <class D> +bool BetaDistributionTest::SingleZTestOnMeanAndVariance(double p, + size_t samples) { + D dis(alpha_, beta_); + + std::vector<double> data; + data.reserve(samples); + for (size_t i = 0; i < samples; i++) { + const double variate = dis(rng_); + EXPECT_FALSE(std::isnan(variate)); + // Note that equality is allowed on both sides. + EXPECT_GE(variate, 0.0); + EXPECT_LE(variate, 1.0); + data.push_back(variate); + } + + // We validate that the sample mean and sample variance are indeed from a + // Beta distribution with the given shape parameters. + const auto m = absl::random_internal::ComputeDistributionMoments(data); + + // The variance of the sample mean is variance / n. + const double mean_stddev = std::sqrt(Variance() / static_cast<double>(m.n)); + + // The variance of the sample variance is (approximately): + // (kurtosis - 1) * variance^2 / n + const double variance_stddev = std::sqrt( + (Kurtosis() - 1) * Variance() * Variance() / static_cast<double>(m.n)); + // z score for the sample variance. + const double z_variance = (m.variance - Variance()) / variance_stddev; + + const double max_err = absl::random_internal::MaxErrorTolerance(p); + const double z_mean = absl::random_internal::ZScore(Mean(), m); + const bool pass = + absl::random_internal::Near("z", z_mean, 0.0, max_err) && + absl::random_internal::Near("z_variance", z_variance, 0.0, max_err); + if (!pass) { + ABSL_INTERNAL_LOG( + INFO, + absl::StrFormat( + "Beta(%f, %f), " + "mean: sample %f, expect %f, which is %f stddevs away, " + "variance: sample %f, expect %f, which is %f stddevs away.", + alpha_, beta_, m.mean, Mean(), + std::abs(m.mean - Mean()) / mean_stddev, m.variance, Variance(), + std::abs(m.variance - Variance()) / variance_stddev)); + } + return pass; +} + +template <class D> +bool BetaDistributionTest::SingleChiSquaredTest(double p, size_t samples, + size_t buckets) { + constexpr double kErr = 1e-7; + std::vector<double> cutoffs, expected; + const double bucket_width = 1.0 / static_cast<double>(buckets); + int i = 1; + int unmerged_buckets = 0; + for (; i < buckets; ++i) { + const double p = bucket_width * static_cast<double>(i); + const double boundary = + absl::random_internal::BetaIncompleteInv(alpha_, beta_, p); + // The intention is to add `boundary` to the list of `cutoffs`. It becomes + // problematic, however, when the boundary values are not monotone, due to + // numerical issues when computing the inverse regularized incomplete + // Beta function. In these cases, we merge that bucket with its previous + // neighbor and merge their expected counts. + if ((cutoffs.empty() && boundary < kErr) || + (!cutoffs.empty() && boundary <= cutoffs.back())) { + unmerged_buckets++; + continue; + } + if (boundary >= 1.0 - 1e-10) { + break; + } + cutoffs.push_back(boundary); + expected.push_back(static_cast<double>(1 + unmerged_buckets) * + bucket_width * static_cast<double>(samples)); + unmerged_buckets = 0; + } + cutoffs.push_back(std::numeric_limits<double>::infinity()); + // Merge all remaining buckets. + expected.push_back(static_cast<double>(buckets - i + 1) * bucket_width * + static_cast<double>(samples)); + // Make sure that we don't merge all the buckets, making this test + // meaningless. + EXPECT_GE(cutoffs.size(), 3) << alpha_ << ", " << beta_; + + D dis(alpha_, beta_); + + std::vector<int32_t> counts(cutoffs.size(), 0); + for (int i = 0; i < samples; i++) { + const double x = dis(rng_); + auto it = std::upper_bound(cutoffs.begin(), cutoffs.end(), x); + counts[std::distance(cutoffs.begin(), it)]++; + } + + // Null-hypothesis is that the distribution is beta distributed with the + // provided alpha, beta params (not estimated from the data). + const int dof = cutoffs.size() - 1; + + const double chi_square = absl::random_internal::ChiSquare( + counts.begin(), counts.end(), expected.begin(), expected.end()); + const bool pass = + (absl::random_internal::ChiSquarePValue(chi_square, dof) >= p); + if (!pass) { + for (int i = 0; i < cutoffs.size(); i++) { + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat("cutoff[%d] = %f, actual count %d, expected %d", + i, cutoffs[i], counts[i], + static_cast<int>(expected[i]))); + } + + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat( + "Beta(%f, %f) %s %f, p = %f", alpha_, beta_, + absl::random_internal::kChiSquared, chi_square, + absl::random_internal::ChiSquarePValue(chi_square, dof))); + } + return pass; +} + +TEST_P(BetaDistributionTest, TestSampleStatistics) { + static constexpr int kRuns = 20; + static constexpr double kPFail = 0.02; + const double p = + absl::random_internal::RequiredSuccessProbability(kPFail, kRuns); + static constexpr int kSampleCount = 10000; + static constexpr int kBucketCount = 100; + int failed = 0; + for (int i = 0; i < kRuns; ++i) { + if (!SingleZTestOnMeanAndVariance<absl::beta_distribution<double>>( + p, kSampleCount)) { + failed++; + } + if (!SingleChiSquaredTest<absl::beta_distribution<double>>( + 0.005, kSampleCount, kBucketCount)) { + failed++; + } + } + // Set so that the test is not flaky at --runs_per_test=10000 + EXPECT_LE(failed, 5); +} + +std::string ParamName( + const ::testing::TestParamInfo<::testing::tuple<double, double>>& info) { + std::string name = absl::StrCat("alpha_", ::testing::get<0>(info.param), + "__beta_", ::testing::get<1>(info.param)); + return absl::StrReplaceAll(name, {{"+", "_"}, {"-", "_"}, {".", "_"}}); +} + +INSTANTIATE_TEST_CASE_P( + TestSampleStatisticsCombinations, BetaDistributionTest, + ::testing::Combine(::testing::Values(0.1, 0.2, 0.9, 1.1, 2.5, 10.0, 123.4), + ::testing::Values(0.1, 0.2, 0.9, 1.1, 2.5, 10.0, 123.4)), + ParamName); + +INSTANTIATE_TEST_CASE_P( + TestSampleStatistics_SelectedPairs, BetaDistributionTest, + ::testing::Values(std::make_pair(0.5, 1000), std::make_pair(1000, 0.5), + std::make_pair(900, 1000), std::make_pair(10000, 20000), + std::make_pair(4e5, 2e7), std::make_pair(1e7, 1e5)), + ParamName); + +// NOTE: absl::beta_distribution is not guaranteed to be stable. +TEST(BetaDistributionTest, StabilityTest) { + // absl::beta_distribution stability relies on the stability of + // absl::random_interna::RandU64ToDouble, std::exp, std::log, std::pow, + // and std::sqrt. + // + // This test also depends on the stability of std::frexp. + using testing::ElementsAre; + absl::random_internal::sequence_urbg urbg({ + 0xffff00000000e6c8ull, 0xffff0000000006c8ull, 0x800003766295CFA9ull, + 0x11C819684E734A41ull, 0x832603766295CFA9ull, 0x7fbe76c8b4395800ull, + 0xB3472DCA7B14A94Aull, 0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, + 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, 0x00035C904C70A239ull, + 0x00009E0BCBAADE14ull, 0x0000000000622CA7ull, 0x4864f22c059bf29eull, + 0x247856d8b862665cull, 0xe46e86e9a1337e10ull, 0xd8c8541f3519b133ull, + 0xffe75b52c567b9e4ull, 0xfffff732e5709c5bull, 0xff1f7f0b983532acull, + 0x1ec2e8986d2362caull, 0xC332DDEFBE6C5AA5ull, 0x6558218568AB9702ull, + 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, 0xECDD4775619F1510ull, + 0x814c8e35fe9a961aull, 0x0c3cd59c9b638a02ull, 0xcb3bb6478a07715cull, + 0x1224e62c978bbc7full, 0x671ef2cb04e81f6eull, 0x3c1cbd811eaf1808ull, + 0x1bbc23cfa8fac721ull, 0xa4c2cda65e596a51ull, 0xb77216fad37adf91ull, + 0x836d794457c08849ull, 0xe083df03475f49d7ull, 0xbc9feb512e6b0d6cull, + 0xb12d74fdd718c8c5ull, 0x12ff09653bfbe4caull, 0x8dd03a105bc4ee7eull, + 0x5738341045ba0d85ull, 0xf3fd722dc65ad09eull, 0xfa14fd21ea2a5705ull, + 0xffe6ea4d6edb0c73ull, 0xD07E9EFE2BF11FB4ull, 0x95DBDA4DAE909198ull, + 0xEAAD8E716B93D5A0ull, 0xD08ED1D0AFC725E0ull, 0x8E3C5B2F8E7594B7ull, + 0x8FF6E2FBF2122B64ull, 0x8888B812900DF01Cull, 0x4FAD5EA0688FC31Cull, + 0xD1CFF191B3A8C1ADull, 0x2F2F2218BE0E1777ull, 0xEA752DFE8B021FA1ull, + }); + + // Convert the real-valued result into a unit64 where we compare + // 5 (float) or 10 (double) decimal digits plus the base-2 exponent. + auto float_to_u64 = [](float d) { + int exp = 0; + auto f = std::frexp(d, &exp); + return (static_cast<uint64_t>(1e5 * f) * 10000) + std::abs(exp); + }; + auto double_to_u64 = [](double d) { + int exp = 0; + auto f = std::frexp(d, &exp); + return (static_cast<uint64_t>(1e10 * f) * 10000) + std::abs(exp); + }; + + std::vector<uint64_t> output(20); + { + // Algorithm Joehnk (float) + absl::beta_distribution<float> dist(0.1f, 0.2f); + std::generate(std::begin(output), std::end(output), + [&] { return float_to_u64(dist(urbg)); }); + EXPECT_EQ(44, urbg.invocations()); + EXPECT_THAT(output, // + testing::ElementsAre( + 998340000, 619030004, 500000001, 999990000, 996280000, + 500000001, 844740004, 847210001, 999970000, 872320000, + 585480007, 933280000, 869080042, 647670031, 528240004, + 969980004, 626050008, 915930002, 833440033, 878040015)); + } + + urbg.reset(); + { + // Algorithm Joehnk (double) + absl::beta_distribution<double> dist(0.1, 0.2); + std::generate(std::begin(output), std::end(output), + [&] { return double_to_u64(dist(urbg)); }); + EXPECT_EQ(44, urbg.invocations()); + EXPECT_THAT( + output, // + testing::ElementsAre( + 99834713000000, 61903356870004, 50000000000001, 99999721170000, + 99628374770000, 99999999990000, 84474397860004, 84721276240001, + 99997407490000, 87232528120000, 58548364780007, 93328932910000, + 86908237770042, 64767917930031, 52824581970004, 96998544140004, + 62605946270008, 91593604380002, 83345031740033, 87804397230015)); + } + + urbg.reset(); + { + // Algorithm Cheng 1 + absl::beta_distribution<double> dist(0.9, 2.0); + std::generate(std::begin(output), std::end(output), + [&] { return double_to_u64(dist(urbg)); }); + EXPECT_EQ(62, urbg.invocations()); + EXPECT_THAT( + output, // + testing::ElementsAre( + 62069004780001, 64433204450001, 53607416560000, 89644295430008, + 61434586310019, 55172615890002, 62187161490000, 56433684810003, + 80454622050005, 86418558710003, 92920514700001, 64645184680001, + 58549183380000, 84881283650005, 71078728590002, 69949694970000, + 73157461710001, 68592191300001, 70747623900000, 78584696930005)); + } + + urbg.reset(); + { + // Algorithm Cheng 2 + absl::beta_distribution<double> dist(1.5, 2.5); + std::generate(std::begin(output), std::end(output), + [&] { return double_to_u64(dist(urbg)); }); + EXPECT_EQ(54, urbg.invocations()); + EXPECT_THAT( + output, // + testing::ElementsAre( + 75000029250001, 76751482860001, 53264575220000, 69193133650005, + 78028324470013, 91573587560002, 59167523770000, 60658618560002, + 80075870540000, 94141320460004, 63196592770003, 78883906300002, + 96797992590001, 76907587800001, 56645167560000, 65408302280003, + 53401156320001, 64731238570000, 83065573750001, 79788333820001)); + } +} + +// This is an implementation-specific test. If any part of the implementation +// changes, then it is likely that this test will change as well. Also, if +// dependencies of the distribution change, such as RandU64ToDouble, then this +// is also likely to change. +TEST(BetaDistributionTest, AlgorithmBounds) { + { + absl::random_internal::sequence_urbg urbg( + {0x7fbe76c8b4395800ull, 0x8000000000000000ull}); + // u=0.499, v=0.5 + absl::beta_distribution<double> dist(1e-4, 1e-4); + double a = dist(urbg); + EXPECT_EQ(a, 2.0202860861567108529e-09); + EXPECT_EQ(2, urbg.invocations()); + } + + // Test that both the float & double algorithms appropriately reject the + // initial draw. + { + // 1/alpha = 1/beta = 2. + absl::beta_distribution<float> dist(0.5, 0.5); + + // first two outputs are close to 1.0 - epsilon, + // thus: (u ^ 2 + v ^ 2) > 1.0 + absl::random_internal::sequence_urbg urbg( + {0xffff00000006e6c8ull, 0xffff00000007c7c8ull, 0x800003766295CFA9ull, + 0x11C819684E734A41ull}); + { + double y = absl::beta_distribution<double>(0.5, 0.5)(urbg); + EXPECT_EQ(4, urbg.invocations()); + EXPECT_EQ(y, 0.9810668952633862) << y; + } + + // ...and: log(u) * a ~= log(v) * b ~= -0.02 + // thus z ~= -0.02 + log(1 + e(~0)) + // ~= -0.02 + 0.69 + // thus z > 0 + urbg.reset(); + { + float x = absl::beta_distribution<float>(0.5, 0.5)(urbg); + EXPECT_EQ(4, urbg.invocations()); + EXPECT_NEAR(0.98106688261032104, x, 0.0000005) << x << "f"; + } + } +} + +} // namespace diff --git a/absl/random/discrete_distribution.cc b/absl/random/discrete_distribution.cc new file mode 100644 index 000000000000..e6c09c5180b3 --- /dev/null +++ b/absl/random/discrete_distribution.cc @@ -0,0 +1,96 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/discrete_distribution.h" + +namespace absl { +namespace random_internal { + +// Initializes the distribution table for Walker's Aliasing algorithm, described +// in Knuth, Vol 2. as well as in https://en.wikipedia.org/wiki/Alias_method +std::vector<std::pair<double, size_t>> InitDiscreteDistribution( + std::vector<double>* probabilities) { + // The empty-case should already be handled by the constructor. + assert(probabilities); + assert(!probabilities->empty()); + + // Step 1. Normalize the input probabilities to 1.0. + double sum = std::accumulate(std::begin(*probabilities), + std::end(*probabilities), 0.0); + if (std::fabs(sum - 1.0) > 1e-6) { + // Scale `probabilities` only when the sum is too far from 1.0. Scaling + // unconditionally will alter the probabilities slightly. + for (double& item : *probabilities) { + item = item / sum; + } + } + + // Step 2. At this point `probabilities` is set to the conditional + // probabilities of each element which sum to 1.0, to within reasonable error. + // These values are used to construct the proportional probability tables for + // the selection phases of Walker's Aliasing algorithm. + // + // To construct the table, pick an element which is under-full (i.e., an + // element for which `(*probabilities)[i] < 1.0/n`), and pair it with an + // element which is over-full (i.e., an element for which + // `(*probabilities)[i] > 1.0/n`). The smaller value can always be retired. + // The larger may still be greater than 1.0/n, or may now be less than 1.0/n, + // and put back onto the appropriate collection. + const size_t n = probabilities->size(); + std::vector<std::pair<double, size_t>> q; + q.reserve(n); + + std::vector<size_t> over; + std::vector<size_t> under; + size_t idx = 0; + for (const double item : *probabilities) { + assert(item >= 0); + const double v = item * n; + q.emplace_back(v, 0); + if (v < 1.0) { + under.push_back(idx++); + } else { + over.push_back(idx++); + } + } + while (!over.empty() && !under.empty()) { + auto lo = under.back(); + under.pop_back(); + auto hi = over.back(); + over.pop_back(); + + q[lo].second = hi; + const double r = q[hi].first - (1.0 - q[lo].first); + q[hi].first = r; + if (r < 1.0) { + under.push_back(hi); + } else { + over.push_back(hi); + } + } + + // Due to rounding errors, there may be un-paired elements in either + // collection; these should all be values near 1.0. For these values, set `q` + // to 1.0 and set the alternate to the identity. + for (auto i : over) { + q[i] = {1.0, i}; + } + for (auto i : under) { + q[i] = {1.0, i}; + } + return q; +} + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/discrete_distribution.h b/absl/random/discrete_distribution.h new file mode 100644 index 000000000000..1560f03c5fdd --- /dev/null +++ b/absl/random/discrete_distribution.h @@ -0,0 +1,245 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_DISCRETE_DISTRIBUTION_H_ +#define ABSL_RANDOM_DISCRETE_DISTRIBUTION_H_ + +#include <cassert> +#include <cmath> +#include <istream> +#include <limits> +#include <numeric> +#include <type_traits> +#include <utility> +#include <vector> + +#include "absl/random/bernoulli_distribution.h" +#include "absl/random/internal/iostream_state_saver.h" +#include "absl/random/uniform_int_distribution.h" + +namespace absl { + +// absl::discrete_distribution +// +// A discrete distribution produces random integers i, where 0 <= i < n +// distributed according to the discrete probability function: +// +// P(i|p0,...,pnโ1)=pi +// +// This class is an implementation of discrete_distribution (see +// [rand.dist.samp.discrete]). +// +// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2. +// absl::discrete_distribution takes O(N) time to precompute the probabilities +// (where N is the number of possible outcomes in the distribution) at +// construction, and then takes O(1) time for each variate generation. Many +// other implementations also take O(N) time to construct an ordered sequence of +// partial sums, plus O(log N) time per variate to binary search. +// +template <typename IntType = int> +class discrete_distribution { + public: + using result_type = IntType; + + class param_type { + public: + using distribution_type = discrete_distribution; + + param_type() { init(); } + + template <typename InputIterator> + explicit param_type(InputIterator begin, InputIterator end) + : p_(begin, end) { + init(); + } + + explicit param_type(std::initializer_list<double> weights) : p_(weights) { + init(); + } + + template <class UnaryOperation> + explicit param_type(size_t nw, double xmin, double xmax, + UnaryOperation fw) { + if (nw > 0) { + p_.reserve(nw); + double delta = (xmax - xmin) / static_cast<double>(nw); + assert(delta > 0); + double t = delta * 0.5; + for (size_t i = 0; i < nw; ++i) { + p_.push_back(fw(xmin + i * delta + t)); + } + } + init(); + } + + const std::vector<double>& probabilities() const { return p_; } + size_t n() const { return p_.size() - 1; } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.probabilities() == b.probabilities(); + } + + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + friend class discrete_distribution; + + void init(); + + std::vector<double> p_; // normalized probabilities + std::vector<std::pair<double, size_t>> q_; // (acceptance, alternate) pairs + + static_assert(std::is_integral<result_type>::value, + "Class-template absl::discrete_distribution<> must be " + "parameterized using an integral type."); + }; + + discrete_distribution() : param_() {} + + explicit discrete_distribution(const param_type& p) : param_(p) {} + + template <typename InputIterator> + explicit discrete_distribution(InputIterator begin, InputIterator end) + : param_(begin, end) {} + + explicit discrete_distribution(std::initializer_list<double> weights) + : param_(weights) {} + + template <class UnaryOperation> + explicit discrete_distribution(size_t nw, double xmin, double xmax, + UnaryOperation fw) + : param_(nw, xmin, xmax, std::move(fw)) {} + + void reset() {} + + // generating functions + template <typename URBG> + result_type operator()(URBG& g) { // NOLINT(runtime/references) + return (*this)(g, param_); + } + + template <typename URBG> + result_type operator()(URBG& g, // NOLINT(runtime/references) + const param_type& p); + + const param_type& param() const { return param_; } + void param(const param_type& p) { param_ = p; } + + result_type(min)() const { return 0; } + result_type(max)() const { + return static_cast<result_type>(param_.n()); + } // inclusive + + // NOTE [rand.dist.sample.discrete] returns a std::vector<double> not a + // const std::vector<double>&. + const std::vector<double>& probabilities() const { + return param_.probabilities(); + } + + friend bool operator==(const discrete_distribution& a, + const discrete_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const discrete_distribution& a, + const discrete_distribution& b) { + return a.param_ != b.param_; + } + + private: + param_type param_; +}; + +// -------------------------------------------------------------------------- +// Implementation details only below +// -------------------------------------------------------------------------- + +namespace random_internal { + +// Using the vector `*probabilities`, whose values are the weights or +// probabilities of an element being selected, constructs the proportional +// probabilities used by the discrete distribution. `*probabilities` will be +// scaled, if necessary, so that its entries sum to a value sufficiently close +// to 1.0. +std::vector<std::pair<double, size_t>> InitDiscreteDistribution( + std::vector<double>* probabilities); + +} // namespace random_internal + +template <typename IntType> +void discrete_distribution<IntType>::param_type::init() { + if (p_.empty()) { + p_.push_back(1.0); + q_.emplace_back(1.0, 0); + } else { + assert(n() <= (std::numeric_limits<IntType>::max)()); + q_ = random_internal::InitDiscreteDistribution(&p_); + } +} + +template <typename IntType> +template <typename URBG> +typename discrete_distribution<IntType>::result_type +discrete_distribution<IntType>::operator()( + URBG& g, // NOLINT(runtime/references) + const param_type& p) { + const auto idx = absl::uniform_int_distribution<result_type>(0, p.n())(g); + const auto& q = p.q_[idx]; + const bool selected = absl::bernoulli_distribution(q.first)(g); + return selected ? idx : static_cast<result_type>(q.second); +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const discrete_distribution<IntType>& x) { + auto saver = random_internal::make_ostream_state_saver(os); + const auto& probabilities = x.param().probabilities(); + os << probabilities.size(); + + os.precision(random_internal::stream_precision_helper<double>::kPrecision); + for (const auto& p : probabilities) { + os << os.fill() << p; + } + return os; +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + discrete_distribution<IntType>& x) { // NOLINT(runtime/references) + using param_type = typename discrete_distribution<IntType>::param_type; + auto saver = random_internal::make_istream_state_saver(is); + + size_t n; + std::vector<double> p; + + is >> n; + if (is.fail()) return is; + if (n > 0) { + p.reserve(n); + for (IntType i = 0; i < n && !is.fail(); ++i) { + auto tmp = random_internal::read_floating_point<double>(is); + if (is.fail()) return is; + p.push_back(tmp); + } + } + x.param(param_type(p.begin(), p.end())); + return is; +} + +} // namespace absl + +#endif // ABSL_RANDOM_DISCRETE_DISTRIBUTION_H_ diff --git a/absl/random/discrete_distribution_test.cc b/absl/random/discrete_distribution_test.cc new file mode 100644 index 000000000000..7296f0ac226a --- /dev/null +++ b/absl/random/discrete_distribution_test.cc @@ -0,0 +1,246 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/discrete_distribution.h" + +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <numeric> +#include <random> +#include <sstream> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" + +namespace { + +template <typename IntType> +class DiscreteDistributionTypeTest : public ::testing::Test {}; + +using IntTypes = ::testing::Types<int8_t, uint8_t, int16_t, uint16_t, int32_t, + uint32_t, int64_t, uint64_t>; +TYPED_TEST_SUITE(DiscreteDistributionTypeTest, IntTypes); + +TYPED_TEST(DiscreteDistributionTypeTest, ParamSerializeTest) { + using param_type = + typename absl::discrete_distribution<TypeParam>::param_type; + + absl::discrete_distribution<TypeParam> empty; + EXPECT_THAT(empty.probabilities(), testing::ElementsAre(1.0)); + + absl::discrete_distribution<TypeParam> before({1.0, 2.0, 1.0}); + + // Validate that the probabilities sum to 1.0. We picked values which + // can be represented exactly to avoid floating-point roundoff error. + double s = 0; + for (const auto& x : before.probabilities()) { + s += x; + } + EXPECT_EQ(s, 1.0); + EXPECT_THAT(before.probabilities(), testing::ElementsAre(0.25, 0.5, 0.25)); + + // Validate the same data via an initializer list. + { + std::vector<double> data({1.0, 2.0, 1.0}); + + absl::discrete_distribution<TypeParam> via_param{ + param_type(std::begin(data), std::end(data))}; + + EXPECT_EQ(via_param, before); + } + + std::stringstream ss; + ss << before; + absl::discrete_distribution<TypeParam> after; + + EXPECT_NE(before, after); + + ss >> after; + + EXPECT_EQ(before, after); +} + +TYPED_TEST(DiscreteDistributionTypeTest, Constructor) { + auto fn = [](double x) { return x; }; + { + absl::discrete_distribution<int> unary(0, 1.0, 9.0, fn); + EXPECT_THAT(unary.probabilities(), testing::ElementsAre(1.0)); + } + + { + absl::discrete_distribution<int> unary(2, 1.0, 9.0, fn); + // => fn(1.0 + 0 * 4 + 2) => 3 + // => fn(1.0 + 1 * 4 + 2) => 7 + EXPECT_THAT(unary.probabilities(), testing::ElementsAre(0.3, 0.7)); + } +} + +TEST(DiscreteDistributionTest, InitDiscreteDistribution) { + using testing::Pair; + + { + std::vector<double> p({1.0, 2.0, 3.0}); + std::vector<std::pair<double, size_t>> q = + absl::random_internal::InitDiscreteDistribution(&p); + + EXPECT_THAT(p, testing::ElementsAre(1 / 6.0, 2 / 6.0, 3 / 6.0)); + + // Each bucket is p=1/3, so bucket 0 will send half it's traffic + // to bucket 2, while the rest will retain all of their traffic. + EXPECT_THAT(q, testing::ElementsAre(Pair(0.5, 2), // + Pair(1.0, 1), // + Pair(1.0, 2))); + } + + { + std::vector<double> p({1.0, 2.0, 3.0, 5.0, 2.0}); + + std::vector<std::pair<double, size_t>> q = + absl::random_internal::InitDiscreteDistribution(&p); + + EXPECT_THAT(p, testing::ElementsAre(1 / 13.0, 2 / 13.0, 3 / 13.0, 5 / 13.0, + 2 / 13.0)); + + // A more complex bucketing solution: Each bucket has p=0.2 + // So buckets 0, 1, 4 will send their alternate traffic elsewhere, which + // happens to be bucket 3. + // However, summing up that alternate traffic gives bucket 3 too much + // traffic, so it will send some traffic to bucket 2. + constexpr double b0 = 1.0 / 13.0 / 0.2; + constexpr double b1 = 2.0 / 13.0 / 0.2; + constexpr double b3 = (5.0 / 13.0 / 0.2) - ((1 - b0) + (1 - b1) + (1 - b1)); + + EXPECT_THAT(q, testing::ElementsAre(Pair(b0, 3), // + Pair(b1, 3), // + Pair(1.0, 2), // + Pair(b3, 2), // + Pair(b1, 3))); + } +} + +TEST(DiscreteDistributionTest, ChiSquaredTest50) { + using absl::random_internal::kChiSquared; + + constexpr size_t kTrials = 10000; + constexpr int kBuckets = 50; // inclusive, so actally +1 + + // 1-in-100000 threshold, but remember, there are about 8 tests + // in this file. And the test could fail for other reasons. + // Empirically validated with --runs_per_test=10000. + const int kThreshold = + absl::random_internal::ChiSquareValue(kBuckets, 0.99999); + + std::vector<double> weights(kBuckets, 0); + std::iota(std::begin(weights), std::end(weights), 1); + absl::discrete_distribution<int> dist(std::begin(weights), std::end(weights)); + + absl::InsecureBitGen rng; + + std::vector<int32_t> counts(kBuckets, 0); + for (size_t i = 0; i < kTrials; i++) { + auto x = dist(rng); + counts[x]++; + } + + // Scale weights. + double sum = 0; + for (double x : weights) { + sum += x; + } + for (double& x : weights) { + x = kTrials * (x / sum); + } + + double chi_square = + absl::random_internal::ChiSquare(std::begin(counts), std::end(counts), + std::begin(weights), std::end(weights)); + + if (chi_square > kThreshold) { + double p_value = + absl::random_internal::ChiSquarePValue(chi_square, kBuckets); + + // Chi-squared test failed. Output does not appear to be uniform. + std::string msg; + for (size_t i = 0; i < counts.size(); i++) { + absl::StrAppend(&msg, i, ": ", counts[i], " vs ", weights[i], "\n"); + } + absl::StrAppend(&msg, kChiSquared, " p-value ", p_value, "\n"); + absl::StrAppend(&msg, "High ", kChiSquared, " value: ", chi_square, " > ", + kThreshold); + ABSL_RAW_LOG(INFO, "%s", msg.c_str()); + FAIL() << msg; + } +} + +TEST(DiscreteDistributionTest, StabilityTest) { + // absl::discrete_distribution stabilitiy relies on + // absl::uniform_int_distribution and absl::bernoulli_distribution. + absl::random_internal::sequence_urbg urbg( + {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull}); + + std::vector<int> output(6); + + { + absl::discrete_distribution<int32_t> dist({1.0, 2.0, 3.0, 5.0, 2.0}); + EXPECT_EQ(0, dist.min()); + EXPECT_EQ(4, dist.max()); + for (auto& v : output) { + v = dist(urbg); + } + EXPECT_EQ(12, urbg.invocations()); + } + + // With 12 calls to urbg, each call into discrete_distribution consumes + // precisely 2 values: one for the uniform call, and a second for the + // bernoulli. + // + // Given the alt mapping: 0=>3, 1=>3, 2=>2, 3=>2, 4=>3, we can + // + // uniform: 443210143131 + // bernoulli: b0 000011100101 + // bernoulli: b1 001111101101 + // bernoulli: b2 111111111111 + // bernoulli: b3 001111101111 + // bernoulli: b4 001111101101 + // ... + EXPECT_THAT(output, testing::ElementsAre(3, 3, 1, 3, 3, 3)); + + { + urbg.reset(); + absl::discrete_distribution<int64_t> dist({1.0, 2.0, 3.0, 5.0, 2.0}); + EXPECT_EQ(0, dist.min()); + EXPECT_EQ(4, dist.max()); + for (auto& v : output) { + v = dist(urbg); + } + EXPECT_EQ(12, urbg.invocations()); + } + EXPECT_THAT(output, testing::ElementsAre(3, 3, 0, 3, 0, 4)); +} + +} // namespace diff --git a/absl/random/distribution_format_traits.h b/absl/random/distribution_format_traits.h new file mode 100644 index 000000000000..3f28c9069453 --- /dev/null +++ b/absl/random/distribution_format_traits.h @@ -0,0 +1,249 @@ +// +// Copyright 2018 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#ifndef ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_ +#define ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_ + +#include <string> +#include <tuple> +#include <typeinfo> + +#include "absl/meta/type_traits.h" +#include "absl/random/bernoulli_distribution.h" +#include "absl/random/beta_distribution.h" +#include "absl/random/exponential_distribution.h" +#include "absl/random/gaussian_distribution.h" +#include "absl/random/log_uniform_int_distribution.h" +#include "absl/random/poisson_distribution.h" +#include "absl/random/uniform_int_distribution.h" +#include "absl/random/uniform_real_distribution.h" +#include "absl/random/zipf_distribution.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace absl { +namespace random_internal { + +// ScalarTypeName defines a preferred hierarchy of preferred type names for +// scalars, and is evaluated at compile time for the specific type +// specialization. +template <typename T> +constexpr const char* ScalarTypeName() { + static_assert(std::is_integral<T>() || std::is_floating_point<T>(), ""); + // clang-format off + return + std::is_same<T, float>::value ? "float" : + std::is_same<T, double>::value ? "double" : + std::is_same<T, long double>::value ? "long double" : + std::is_same<T, bool>::value ? "bool" : + std::is_signed<T>::value && sizeof(T) == 1 ? "int8_t" : + std::is_signed<T>::value && sizeof(T) == 2 ? "int16_t" : + std::is_signed<T>::value && sizeof(T) == 4 ? "int32_t" : + std::is_signed<T>::value && sizeof(T) == 8 ? "int64_t" : + std::is_unsigned<T>::value && sizeof(T) == 1 ? "uint8_t" : + std::is_unsigned<T>::value && sizeof(T) == 2 ? "uint16_t" : + std::is_unsigned<T>::value && sizeof(T) == 4 ? "uint32_t" : + std::is_unsigned<T>::value && sizeof(T) == 8 ? "uint64_t" : + "undefined"; + // clang-format on + + // NOTE: It would be nice to use typeid(T).name(), but that's an + // implementation-defined attribute which does not necessarily + // correspond to a name. We could potentially demangle it + // using, e.g. abi::__cxa_demangle. +} + +// Distribution traits used by DistributionCaller and internal implementation +// details of the mocking framework. +/* +struct DistributionFormatTraits { + // Returns the parameterized name of the distribution function. + static constexpr const char* FunctionName() + // Format DistrT parameters. + static std::string FormatArgs(DistrT& dist); + // Format DistrT::result_type results. + static std::string FormatResults(DistrT& dist); +}; +*/ +template <typename DistrT> +struct DistributionFormatTraits; + +template <typename R> +struct DistributionFormatTraits<absl::uniform_int_distribution<R>> { + using distribution_t = absl::uniform_int_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Uniform"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat("absl::IntervalClosedClosed, ", (d.min)(), ", ", + (d.max)()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::uniform_real_distribution<R>> { + using distribution_t = absl::uniform_real_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Uniform"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat((d.min)(), ", ", (d.max)()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::exponential_distribution<R>> { + using distribution_t = absl::exponential_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Exponential"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.lambda()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::poisson_distribution<R>> { + using distribution_t = absl::poisson_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Poisson"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.mean()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <> +struct DistributionFormatTraits<absl::bernoulli_distribution> { + using distribution_t = absl::bernoulli_distribution; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Bernoulli"; } + + static constexpr const char* FunctionName() { return Name(); } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.p()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::beta_distribution<R>> { + using distribution_t = absl::beta_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Beta"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.alpha(), ", ", d.beta()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::zipf_distribution<R>> { + using distribution_t = absl::zipf_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Zipf"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.k(), ", ", d.v(), ", ", d.q()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::gaussian_distribution<R>> { + using distribution_t = absl::gaussian_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Gaussian"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrJoin(std::make_tuple(d.mean(), d.stddev()), ", "); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::log_uniform_int_distribution<R>> { + using distribution_t = absl::log_uniform_int_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "LogUniform"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrJoin(std::make_tuple((d.min)(), (d.max)(), d.base()), ", "); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_ diff --git a/absl/random/distributions.h b/absl/random/distributions.h new file mode 100644 index 000000000000..c37b7347fd6e --- /dev/null +++ b/absl/random/distributions.h @@ -0,0 +1,442 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: distributions.h +// ----------------------------------------------------------------------------- +// +// This header defines functions representing distributions, which you use in +// combination with an Abseil random bit generator to produce random values +// according to the rules of that distribution. +// +// The Abseil random library defines the following distributions within this +// file: +// +// * `absl::Uniform` for uniform (constant) distributions having constant +// probability +// * `absl::Bernoulli` for discrete distributions having exactly two outcomes +// * `absl::Beta` for continuous distributions parameterized through two +// free parameters +// * `absl::Exponential` for discrete distributions of events occurring +// continuously and independently at a constant average rate +// * `absl::Gaussian` (also known as "normal distributions") for continuous +// distributions using an associated quadratic function +// * `absl::LogUniform` for continuous uniform distributions where the log +// to the given base of all values is uniform +// * `absl::Poisson` for discrete probability distributions that express the +// probability of a given number of events occurring within a fixed interval +// * `absl::Zipf` for discrete probability distributions commonly used for +// modelling of rare events +// +// Prefer use of these distribution function classes over manual construction of +// your own distribution classes, as it allows library maintainers greater +// flexibility to change the underlying implementation in the future. + +#ifndef ABSL_RANDOM_DISTRIBUTIONS_H_ +#define ABSL_RANDOM_DISTRIBUTIONS_H_ + +#include <algorithm> +#include <cmath> +#include <limits> +#include <random> +#include <type_traits> + +#include "absl/base/internal/inline_variable.h" +#include "absl/random/bernoulli_distribution.h" +#include "absl/random/beta_distribution.h" +#include "absl/random/distribution_format_traits.h" +#include "absl/random/exponential_distribution.h" +#include "absl/random/gaussian_distribution.h" +#include "absl/random/internal/distributions.h" // IWYU pragma: export +#include "absl/random/internal/uniform_helper.h" // IWYU pragma: export +#include "absl/random/log_uniform_int_distribution.h" +#include "absl/random/poisson_distribution.h" +#include "absl/random/uniform_int_distribution.h" +#include "absl/random/uniform_real_distribution.h" +#include "absl/random/zipf_distribution.h" + +namespace absl { + +ABSL_INTERNAL_INLINE_CONSTEXPR(random_internal::IntervalClosedClosedT, + IntervalClosedClosed, {}); +ABSL_INTERNAL_INLINE_CONSTEXPR(random_internal::IntervalClosedClosedT, + IntervalClosed, {}); +ABSL_INTERNAL_INLINE_CONSTEXPR(random_internal::IntervalClosedOpenT, + IntervalClosedOpen, {}); +ABSL_INTERNAL_INLINE_CONSTEXPR(random_internal::IntervalOpenOpenT, + IntervalOpenOpen, {}); +ABSL_INTERNAL_INLINE_CONSTEXPR(random_internal::IntervalOpenOpenT, + IntervalOpen, {}); +ABSL_INTERNAL_INLINE_CONSTEXPR(random_internal::IntervalOpenClosedT, + IntervalOpenClosed, {}); + +// ----------------------------------------------------------------------------- +// absl::Uniform<T>(tag, bitgen, lo, hi) +// ----------------------------------------------------------------------------- +// +// `absl::Uniform()` produces random values of type `T` uniformly distributed in +// a defined interval {lo, hi}. The interval `tag` defines the type of interval +// which should be one of the following possible values: +// +// * `absl::IntervalOpenOpen` +// * `absl::IntervalOpenClosed` +// * `absl::IntervalClosedOpen` +// * `absl::IntervalClosedClosed` +// +// where "open" refers to an exclusive value (excluded) from the output, while +// "closed" refers to an inclusive value (included) from the output. +// +// In the absence of an explicit return type `T`, `absl::Uniform()` will deduce +// the return type based on the provided endpoint arguments {A lo, B hi}. +// Given these endpoints, one of {A, B} will be chosen as the return type, if +// a type can be implicitly converted into the other in a lossless way. The +// lack of any such implcit conversion between {A, B} will produce a +// compile-time error +// +// See https://en.wikipedia.org/wiki/Uniform_distribution_(continuous) +// +// Example: +// +// absl::BitGen bitgen; +// +// // Produce a random float value between 0.0 and 1.0, inclusive +// auto x = absl::Uniform(absl::IntervalClosedClosed, bitgen, 0.0f, 1.0f); +// +// // The most common interval of `absl::IntervalClosedOpen` is available by +// // default: +// +// auto x = absl::Uniform(bitgen, 0.0f, 1.0f); +// +// // Return-types are typically inferred from the arguments, however callers +// // can optionally provide an explicit return-type to the template. +// +// auto x = absl::Uniform<float>(bitgen, 0, 1); +// +template <typename R = void, typename TagType, typename URBG> +typename absl::enable_if_t<!std::is_same<R, void>::value, R> // +Uniform(TagType tag, + URBG&& urbg, // NOLINT(runtime/references) + R lo, R hi) { + using gen_t = absl::decay_t<URBG>; + return random_internal::UniformImpl<R, TagType, gen_t>(tag, urbg, lo, hi); +} + +// absl::Uniform<T>(bitgen, lo, hi) +// +// Overload of `Uniform()` using the default closed-open interval of [lo, hi), +// and returning values of type `T` +template <typename R = void, typename URBG> +typename absl::enable_if_t<!std::is_same<R, void>::value, R> // +Uniform(URBG&& urbg, // NOLINT(runtime/references) + R lo, R hi) { + constexpr auto tag = absl::IntervalClosedOpen; + using tag_t = decltype(tag); + using gen_t = absl::decay_t<URBG>; + + return random_internal::UniformImpl<R, tag_t, gen_t>(tag, urbg, lo, hi); +} + +// absl::Uniform(tag, bitgen, lo, hi) +// +// Overload of `Uniform()` using different (but compatible) lo, hi types. Note +// that a compile-error will result if the return type cannot be deduced +// correctly from the passed types. +template <typename R = void, typename TagType, typename URBG, typename A, + typename B> +typename absl::enable_if_t<std::is_same<R, void>::value, + random_internal::uniform_inferred_return_t<A, B>> +Uniform(TagType tag, + URBG&& urbg, // NOLINT(runtime/references) + A lo, B hi) { + using gen_t = absl::decay_t<URBG>; + using return_t = typename random_internal::uniform_inferred_return_t<A, B>; + + return random_internal::UniformImpl<return_t, TagType, gen_t>(tag, urbg, lo, + hi); +} + +// absl::Uniform(bitgen, lo, hi) +// +// Overload of `Uniform()` using different (but compatible) lo, hi types and the +// default closed-open interval of [lo, hi). Note that a compile-error will +// result if the return type cannot be deduced correctly from the passed types. +template <typename R = void, typename URBG, typename A, typename B> +typename absl::enable_if_t<std::is_same<R, void>::value, + random_internal::uniform_inferred_return_t<A, B>> +Uniform(URBG&& urbg, // NOLINT(runtime/references) + A lo, B hi) { + constexpr auto tag = absl::IntervalClosedOpen; + using tag_t = decltype(tag); + using gen_t = absl::decay_t<URBG>; + using return_t = typename random_internal::uniform_inferred_return_t<A, B>; + + return random_internal::UniformImpl<return_t, tag_t, gen_t>(tag, urbg, lo, + hi); +} + +// absl::Uniform<unsigned T>(bitgen) +// +// Overload of Uniform() using the minimum and maximum values of a given type +// `T` (which must be unsigned), returning a value of type `unsigned T` +template <typename R, typename URBG> +typename absl::enable_if_t<!std::is_signed<R>::value, R> // +Uniform(URBG&& urbg) { // NOLINT(runtime/references) + constexpr auto tag = absl::IntervalClosedClosed; + constexpr auto lo = std::numeric_limits<R>::lowest(); + constexpr auto hi = (std::numeric_limits<R>::max)(); + using tag_t = decltype(tag); + using gen_t = absl::decay_t<URBG>; + + return random_internal::UniformImpl<R, tag_t, gen_t>(tag, urbg, lo, hi); +} + +// ----------------------------------------------------------------------------- +// absl::Bernoulli(bitgen, p) +// ----------------------------------------------------------------------------- +// +// `absl::Bernoulli` produces a random boolean value, with probability `p` +// (where 0.0 <= p <= 1.0) equaling `true`. +// +// Prefer `absl::Bernoulli` to produce boolean values over other alternatives +// such as comparing an `absl::Uniform()` value to a specific output. +// +// See https://en.wikipedia.org/wiki/Bernoulli_distribution +// +// Example: +// +// absl::BitGen bitgen; +// ... +// if (absl::Bernoulli(bitgen, 1.0/3721.0)) { +// std::cout << "Asteroid field navigation successful."; +// } +// +template <typename URBG> +bool Bernoulli(URBG&& urbg, // NOLINT(runtime/references) + double p) { + using gen_t = absl::decay_t<URBG>; + using distribution_t = absl::bernoulli_distribution; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, p); +} + +// ----------------------------------------------------------------------------- +// absl::Beta<T>(bitgen, alpha, beta) +// ----------------------------------------------------------------------------- +// +// `absl::Beta` produces a floating point number distributed in the closed +// interval [0,1] and parameterized by two values `alpha` and `beta` as per a +// Beta distribution. `T` must be a floating point type, but may be inferred +// from the types of `alpha` and `beta`. +// +// See https://en.wikipedia.org/wiki/Beta_distribution. +// +// Example: +// +// absl::BitGen bitgen; +// ... +// double sample = absl::Beta(bitgen, 3.0, 2.0); +// +template <typename RealType, typename URBG> +RealType Beta(URBG&& urbg, // NOLINT(runtime/references) + RealType alpha, RealType beta) { + static_assert( + std::is_floating_point<RealType>::value, + "Template-argument 'RealType' must be a floating-point type, in " + "absl::Beta<RealType, URBG>(...)"); + + using gen_t = absl::decay_t<URBG>; + using distribution_t = typename absl::beta_distribution<RealType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, alpha, beta); +} + +// ----------------------------------------------------------------------------- +// absl::Exponential<T>(bitgen, lambda = 1) +// ----------------------------------------------------------------------------- +// +// `absl::Exponential` produces a floating point number for discrete +// distributions of events occurring continuously and independently at a +// constant average rate. `T` must be a floating point type, but may be inferred +// from the type of `lambda`. +// +// See https://en.wikipedia.org/wiki/Exponential_distribution. +// +// Example: +// +// absl::BitGen bitgen; +// ... +// double call_length = absl::Exponential(bitgen, 7.0); +// +template <typename RealType, typename URBG> +RealType Exponential(URBG&& urbg, // NOLINT(runtime/references) + RealType lambda = 1) { + static_assert( + std::is_floating_point<RealType>::value, + "Template-argument 'RealType' must be a floating-point type, in " + "absl::Exponential<RealType, URBG>(...)"); + + using gen_t = absl::decay_t<URBG>; + using distribution_t = typename absl::exponential_distribution<RealType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, lambda); +} + +// ----------------------------------------------------------------------------- +// absl::Gaussian<T>(bitgen, mean = 0, stddev = 1) +// ----------------------------------------------------------------------------- +// +// `absl::Gaussian` produces a floating point number selected from the Gaussian +// (ie. "Normal") distribution. `T` must be a floating point type, but may be +// inferred from the types of `mean` and `stddev`. +// +// See https://en.wikipedia.org/wiki/Normal_distribution +// +// Example: +// +// absl::BitGen bitgen; +// ... +// double giraffe_height = absl::Gaussian(bitgen, 16.3, 3.3); +// +template <typename RealType, typename URBG> +RealType Gaussian(URBG&& urbg, // NOLINT(runtime/references) + RealType mean = 0, RealType stddev = 1) { + static_assert( + std::is_floating_point<RealType>::value, + "Template-argument 'RealType' must be a floating-point type, in " + "absl::Gaussian<RealType, URBG>(...)"); + + using gen_t = absl::decay_t<URBG>; + using distribution_t = typename absl::gaussian_distribution<RealType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, mean, stddev); +} + +// ----------------------------------------------------------------------------- +// absl::LogUniform<T>(bitgen, lo, hi, base = 2) +// ----------------------------------------------------------------------------- +// +// `absl::LogUniform` produces random values distributed where the log to a +// given base of all values is uniform in a closed interval [lo, hi]. `T` must +// be an integral type, but may be inferred from the types of `lo` and `hi`. +// +// I.e., `LogUniform(0, n, b)` is uniformly distributed across buckets +// [0], [1, b-1], [b, b^2-1] .. [b^(k-1), (b^k)-1] .. [b^floor(log(n, b)), n] +// and is uniformly distributed within each bucket. +// +// The resulting probability density is inversely related to bucket size, though +// values in the final bucket may be more likely than previous values. (In the +// extreme case where n = b^i the final value will be tied with zero as the most +// probable result. +// +// If `lo` is nonzero then this distribution is shifted to the desired interval, +// so LogUniform(lo, hi, b) is equivalent to LogUniform(0, hi-lo, b)+lo. +// +// See http://ecolego.facilia.se/ecolego/show/Log-Uniform%20Distribution +// +// Example: +// +// absl::BitGen bitgen; +// ... +// int v = absl::LogUniform(bitgen, 0, 1000); +// +template <typename IntType, typename URBG> +IntType LogUniform(URBG&& urbg, // NOLINT(runtime/references) + IntType lo, IntType hi, IntType base = 2) { + static_assert(std::is_integral<IntType>::value, + "Template-argument 'IntType' must be an integral type, in " + "absl::LogUniform<IntType, URBG>(...)"); + + using gen_t = absl::decay_t<URBG>; + using distribution_t = typename absl::log_uniform_int_distribution<IntType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, lo, hi, base); +} + +// ----------------------------------------------------------------------------- +// absl::Poisson<T>(bitgen, mean = 1) +// ----------------------------------------------------------------------------- +// +// `absl::Poisson` produces discrete probabilities for a given number of events +// occurring within a fixed interval within the closed interval [0, max]. `T` +// must be an integral type. +// +// See https://en.wikipedia.org/wiki/Poisson_distribution +// +// Example: +// +// absl::BitGen bitgen; +// ... +// int requests_per_minute = absl::Poisson<int>(bitgen, 3.2); +// +template <typename IntType, typename URBG> +IntType Poisson(URBG&& urbg, // NOLINT(runtime/references) + double mean = 1.0) { + static_assert(std::is_integral<IntType>::value, + "Template-argument 'IntType' must be an integral type, in " + "absl::Poisson<IntType, URBG>(...)"); + + using gen_t = absl::decay_t<URBG>; + using distribution_t = typename absl::poisson_distribution<IntType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, mean); +} + +// ----------------------------------------------------------------------------- +// absl::Zipf<T>(bitgen, hi = max, q = 2, v = 1) +// ----------------------------------------------------------------------------- +// +// `absl::Zipf` produces discrete probabilities commonly used for modelling of +// rare events over the closed interval [0, hi]. The parameters `v` and `q` +// determine the skew of the distribution. `T` must be an integral type, but +// may be inferred from the type of `hi`. +// +// See http://mathworld.wolfram.com/ZipfDistribution.html +// +// Example: +// +// absl::BitGen bitgen; +// ... +// int term_rank = absl::Zipf<int>(bitgen); +// +template <typename IntType, typename URBG> +IntType Zipf(URBG&& urbg, // NOLINT(runtime/references) + IntType hi = (std::numeric_limits<IntType>::max)(), double q = 2.0, + double v = 1.0) { + static_assert(std::is_integral<IntType>::value, + "Template-argument 'IntType' must be an integral type, in " + "absl::Zipf<IntType, URBG>(...)"); + + using gen_t = absl::decay_t<URBG>; + using distribution_t = typename absl::zipf_distribution<IntType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, hi, q, v); +} + +} // namespace absl. + +#endif // ABSL_RANDOM_DISTRIBUTIONS_H_ diff --git a/absl/random/distributions_test.cc b/absl/random/distributions_test.cc new file mode 100644 index 000000000000..eb82868d5320 --- /dev/null +++ b/absl/random/distributions_test.cc @@ -0,0 +1,494 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/distributions.h" + +#include <cmath> +#include <cstdint> +#include <random> +#include <vector> + +#include "gtest/gtest.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/random.h" + +namespace { + +constexpr int kSize = 400000; + +class RandomDistributionsTest : public testing::Test {}; + +TEST_F(RandomDistributionsTest, UniformBoundFunctions) { + using absl::IntervalClosedClosed; + using absl::IntervalClosedOpen; + using absl::IntervalOpenClosed; + using absl::IntervalOpenOpen; + using absl::random_internal::uniform_lower_bound; + using absl::random_internal::uniform_upper_bound; + + // absl::uniform_int_distribution natively assumes IntervalClosedClosed + // absl::uniform_real_distribution natively assumes IntervalClosedOpen + + EXPECT_EQ(uniform_lower_bound(IntervalOpenClosed, 0, 100), 1); + EXPECT_EQ(uniform_lower_bound(IntervalOpenOpen, 0, 100), 1); + EXPECT_GT(uniform_lower_bound<float>(IntervalOpenClosed, 0, 1.0), 0); + EXPECT_GT(uniform_lower_bound<float>(IntervalOpenOpen, 0, 1.0), 0); + EXPECT_GT(uniform_lower_bound<double>(IntervalOpenClosed, 0, 1.0), 0); + EXPECT_GT(uniform_lower_bound<double>(IntervalOpenOpen, 0, 1.0), 0); + + EXPECT_EQ(uniform_lower_bound(IntervalClosedClosed, 0, 100), 0); + EXPECT_EQ(uniform_lower_bound(IntervalClosedOpen, 0, 100), 0); + EXPECT_EQ(uniform_lower_bound<float>(IntervalClosedClosed, 0, 1.0), 0); + EXPECT_EQ(uniform_lower_bound<float>(IntervalClosedOpen, 0, 1.0), 0); + EXPECT_EQ(uniform_lower_bound<double>(IntervalClosedClosed, 0, 1.0), 0); + EXPECT_EQ(uniform_lower_bound<double>(IntervalClosedOpen, 0, 1.0), 0); + + EXPECT_EQ(uniform_upper_bound(IntervalOpenOpen, 0, 100), 99); + EXPECT_EQ(uniform_upper_bound(IntervalClosedOpen, 0, 100), 99); + EXPECT_EQ(uniform_upper_bound<float>(IntervalOpenOpen, 0, 1.0), 1.0); + EXPECT_EQ(uniform_upper_bound<float>(IntervalClosedOpen, 0, 1.0), 1.0); + EXPECT_EQ(uniform_upper_bound<double>(IntervalOpenOpen, 0, 1.0), 1.0); + EXPECT_EQ(uniform_upper_bound<double>(IntervalClosedOpen, 0, 1.0), 1.0); + + EXPECT_EQ(uniform_upper_bound(IntervalOpenClosed, 0, 100), 100); + EXPECT_EQ(uniform_upper_bound(IntervalClosedClosed, 0, 100), 100); + EXPECT_GT(uniform_upper_bound<float>(IntervalOpenClosed, 0, 1.0), 1.0); + EXPECT_GT(uniform_upper_bound<float>(IntervalClosedClosed, 0, 1.0), 1.0); + EXPECT_GT(uniform_upper_bound<double>(IntervalOpenClosed, 0, 1.0), 1.0); + EXPECT_GT(uniform_upper_bound<double>(IntervalClosedClosed, 0, 1.0), 1.0); + + // Negative value tests + EXPECT_EQ(uniform_lower_bound(IntervalOpenClosed, -100, -1), -99); + EXPECT_EQ(uniform_lower_bound(IntervalOpenOpen, -100, -1), -99); + EXPECT_GT(uniform_lower_bound<float>(IntervalOpenClosed, -2.0, -1.0), -2.0); + EXPECT_GT(uniform_lower_bound<float>(IntervalOpenOpen, -2.0, -1.0), -2.0); + EXPECT_GT(uniform_lower_bound<double>(IntervalOpenClosed, -2.0, -1.0), -2.0); + EXPECT_GT(uniform_lower_bound<double>(IntervalOpenOpen, -2.0, -1.0), -2.0); + + EXPECT_EQ(uniform_lower_bound(IntervalClosedClosed, -100, -1), -100); + EXPECT_EQ(uniform_lower_bound(IntervalClosedOpen, -100, -1), -100); + EXPECT_EQ(uniform_lower_bound<float>(IntervalClosedClosed, -2.0, -1.0), -2.0); + EXPECT_EQ(uniform_lower_bound<float>(IntervalClosedOpen, -2.0, -1.0), -2.0); + EXPECT_EQ(uniform_lower_bound<double>(IntervalClosedClosed, -2.0, -1.0), + -2.0); + EXPECT_EQ(uniform_lower_bound<double>(IntervalClosedOpen, -2.0, -1.0), -2.0); + + EXPECT_EQ(uniform_upper_bound(IntervalOpenOpen, -100, -1), -2); + EXPECT_EQ(uniform_upper_bound(IntervalClosedOpen, -100, -1), -2); + EXPECT_EQ(uniform_upper_bound<float>(IntervalOpenOpen, -2.0, -1.0), -1.0); + EXPECT_EQ(uniform_upper_bound<float>(IntervalClosedOpen, -2.0, -1.0), -1.0); + EXPECT_EQ(uniform_upper_bound<double>(IntervalOpenOpen, -2.0, -1.0), -1.0); + EXPECT_EQ(uniform_upper_bound<double>(IntervalClosedOpen, -2.0, -1.0), -1.0); + + EXPECT_EQ(uniform_upper_bound(IntervalOpenClosed, -100, -1), -1); + EXPECT_EQ(uniform_upper_bound(IntervalClosedClosed, -100, -1), -1); + EXPECT_GT(uniform_upper_bound<float>(IntervalOpenClosed, -2.0, -1.0), -1.0); + EXPECT_GT(uniform_upper_bound<float>(IntervalClosedClosed, -2.0, -1.0), -1.0); + EXPECT_GT(uniform_upper_bound<double>(IntervalOpenClosed, -2.0, -1.0), -1.0); + EXPECT_GT(uniform_upper_bound<double>(IntervalClosedClosed, -2.0, -1.0), + -1.0); + + // Edge cases: the next value toward itself is itself. + const double d = 1.0; + const float f = 1.0; + EXPECT_EQ(uniform_lower_bound(IntervalOpenClosed, d, d), d); + EXPECT_EQ(uniform_lower_bound(IntervalOpenClosed, f, f), f); + + EXPECT_GT(uniform_lower_bound(IntervalOpenClosed, 1.0, 2.0), 1.0); + EXPECT_LT(uniform_lower_bound(IntervalOpenClosed, 1.0, +0.0), 1.0); + EXPECT_LT(uniform_lower_bound(IntervalOpenClosed, 1.0, -0.0), 1.0); + EXPECT_LT(uniform_lower_bound(IntervalOpenClosed, 1.0, -1.0), 1.0); + + EXPECT_EQ(uniform_upper_bound(IntervalClosedClosed, 0.0f, + std::numeric_limits<float>::max()), + std::numeric_limits<float>::max()); + EXPECT_EQ(uniform_upper_bound(IntervalClosedClosed, 0.0, + std::numeric_limits<double>::max()), + std::numeric_limits<double>::max()); +} + +struct Invalid {}; + +template <typename A, typename B> +auto InferredUniformReturnT(int) + -> decltype(absl::Uniform(std::declval<absl::InsecureBitGen&>(), + std::declval<A>(), std::declval<B>())); + +template <typename, typename> +Invalid InferredUniformReturnT(...); + +template <typename TagType, typename A, typename B> +auto InferredTaggedUniformReturnT(int) + -> decltype(absl::Uniform(std::declval<TagType>(), + std::declval<absl::InsecureBitGen&>(), + std::declval<A>(), std::declval<B>())); + +template <typename, typename, typename> +Invalid InferredTaggedUniformReturnT(...); + +// Given types <A, B, Expect>, CheckArgsInferType() verifies that +// +// absl::Uniform(gen, A{}, B{}) +// +// returns the type "Expect". +// +// This interface can also be used to assert that a given absl::Uniform() +// overload does not exist / will not compile. Given types <A, B>, the +// expression +// +// decltype(absl::Uniform(..., std::declval<A>(), std::declval<B>())) +// +// will not compile, leaving the definition of InferredUniformReturnT<A, B> to +// resolve (via SFINAE) to the overload which returns type "Invalid". This +// allows tests to assert that an invocation such as +// +// absl::Uniform(gen, 1.23f, std::numeric_limits<int>::max() - 1) +// +// should not compile, since neither type, float nor int, can precisely +// represent both endpoint-values. Writing: +// +// CheckArgsInferType<float, int, Invalid>() +// +// will assert that this overload does not exist. +template <typename A, typename B, typename Expect> +void CheckArgsInferType() { + static_assert( + absl::conjunction< + std::is_same<Expect, decltype(InferredUniformReturnT<A, B>(0))>, + std::is_same<Expect, + decltype(InferredUniformReturnT<B, A>(0))>>::value, + ""); + static_assert( + absl::conjunction< + std::is_same<Expect, + decltype(InferredTaggedUniformReturnT< + absl::random_internal::IntervalOpenOpenT, A, B>( + 0))>, + std::is_same<Expect, + decltype(InferredTaggedUniformReturnT< + absl::random_internal::IntervalOpenOpenT, B, A>( + 0))>>::value, + ""); +} + +template <typename A, typename B, typename ExplicitRet> +auto ExplicitUniformReturnT(int) -> decltype( + absl::Uniform<ExplicitRet>(*std::declval<absl::InsecureBitGen*>(), + std::declval<A>(), std::declval<B>())); + +template <typename, typename, typename ExplicitRet> +Invalid ExplicitUniformReturnT(...); + +template <typename TagType, typename A, typename B, typename ExplicitRet> +auto ExplicitTaggedUniformReturnT(int) -> decltype(absl::Uniform<ExplicitRet>( + std::declval<TagType>(), *std::declval<absl::InsecureBitGen*>(), + std::declval<A>(), std::declval<B>())); + +template <typename, typename, typename, typename ExplicitRet> +Invalid ExplicitTaggedUniformReturnT(...); + +// Given types <A, B, Expect>, CheckArgsReturnExpectedType() verifies that +// +// absl::Uniform<Expect>(gen, A{}, B{}) +// +// returns the type "Expect", and that the function-overload has the signature +// +// Expect(URBG&, Expect, Expect) +template <typename A, typename B, typename Expect> +void CheckArgsReturnExpectedType() { + static_assert( + absl::conjunction< + std::is_same<Expect, + decltype(ExplicitUniformReturnT<A, B, Expect>(0))>, + std::is_same<Expect, decltype(ExplicitUniformReturnT<B, A, Expect>( + 0))>>::value, + ""); + static_assert( + absl::conjunction< + std::is_same<Expect, + decltype(ExplicitTaggedUniformReturnT< + absl::random_internal::IntervalOpenOpenT, A, B, + Expect>(0))>, + std::is_same<Expect, + decltype(ExplicitTaggedUniformReturnT< + absl::random_internal::IntervalOpenOpenT, B, A, + Expect>(0))>>::value, + ""); +} + +TEST_F(RandomDistributionsTest, UniformTypeInference) { + // Infers common types. + CheckArgsInferType<uint16_t, uint16_t, uint16_t>(); + CheckArgsInferType<uint32_t, uint32_t, uint32_t>(); + CheckArgsInferType<uint64_t, uint64_t, uint64_t>(); + CheckArgsInferType<int16_t, int16_t, int16_t>(); + CheckArgsInferType<int32_t, int32_t, int32_t>(); + CheckArgsInferType<int64_t, int64_t, int64_t>(); + CheckArgsInferType<float, float, float>(); + CheckArgsInferType<double, double, double>(); + + // Explicitly-specified return-values override inferences. + CheckArgsReturnExpectedType<int16_t, int16_t, int32_t>(); + CheckArgsReturnExpectedType<uint16_t, uint16_t, int32_t>(); + CheckArgsReturnExpectedType<int16_t, int16_t, int64_t>(); + CheckArgsReturnExpectedType<int16_t, int32_t, int64_t>(); + CheckArgsReturnExpectedType<int16_t, int32_t, double>(); + CheckArgsReturnExpectedType<float, float, double>(); + CheckArgsReturnExpectedType<int, int, int16_t>(); + + // Properly promotes uint16_t. + CheckArgsInferType<uint16_t, uint32_t, uint32_t>(); + CheckArgsInferType<uint16_t, uint64_t, uint64_t>(); + CheckArgsInferType<uint16_t, int32_t, int32_t>(); + CheckArgsInferType<uint16_t, int64_t, int64_t>(); + CheckArgsInferType<uint16_t, float, float>(); + CheckArgsInferType<uint16_t, double, double>(); + + // Properly promotes int16_t. + CheckArgsInferType<int16_t, int32_t, int32_t>(); + CheckArgsInferType<int16_t, int64_t, int64_t>(); + CheckArgsInferType<int16_t, float, float>(); + CheckArgsInferType<int16_t, double, double>(); + + // Invalid (u)int16_t-pairings do not compile. + // See "CheckArgsInferType" comments above, for how this is achieved. + CheckArgsInferType<uint16_t, int16_t, Invalid>(); + CheckArgsInferType<int16_t, uint32_t, Invalid>(); + CheckArgsInferType<int16_t, uint64_t, Invalid>(); + + // Properly promotes uint32_t. + CheckArgsInferType<uint32_t, uint64_t, uint64_t>(); + CheckArgsInferType<uint32_t, int64_t, int64_t>(); + CheckArgsInferType<uint32_t, double, double>(); + + // Properly promotes int32_t. + CheckArgsInferType<int32_t, int64_t, int64_t>(); + CheckArgsInferType<int32_t, double, double>(); + + // Invalid (u)int32_t-pairings do not compile. + CheckArgsInferType<uint32_t, int32_t, Invalid>(); + CheckArgsInferType<int32_t, uint64_t, Invalid>(); + CheckArgsInferType<int32_t, float, Invalid>(); + CheckArgsInferType<uint32_t, float, Invalid>(); + + // Invalid (u)int64_t-pairings do not compile. + CheckArgsInferType<uint64_t, int64_t, Invalid>(); + CheckArgsInferType<int64_t, float, Invalid>(); + CheckArgsInferType<int64_t, double, Invalid>(); + + // Properly promotes float. + CheckArgsInferType<float, double, double>(); + + // Examples. + absl::InsecureBitGen gen; + EXPECT_NE(1, absl::Uniform(gen, static_cast<uint16_t>(0), 1.0f)); + EXPECT_NE(1, absl::Uniform(gen, 0, 1.0)); + EXPECT_NE(1, absl::Uniform(absl::IntervalOpenOpen, gen, + static_cast<uint16_t>(0), 1.0f)); + EXPECT_NE(1, absl::Uniform(absl::IntervalOpenOpen, gen, 0, 1.0)); + EXPECT_NE(1, absl::Uniform(absl::IntervalOpenOpen, gen, -1, 1.0)); + EXPECT_NE(1, absl::Uniform<double>(absl::IntervalOpenOpen, gen, -1, 1)); + EXPECT_NE(1, absl::Uniform<float>(absl::IntervalOpenOpen, gen, 0, 1)); + EXPECT_NE(1, absl::Uniform<float>(gen, 0, 1)); +} + +TEST_F(RandomDistributionsTest, UniformNoBounds) { + absl::InsecureBitGen gen; + + absl::Uniform<uint8_t>(gen); + absl::Uniform<uint16_t>(gen); + absl::Uniform<uint32_t>(gen); + absl::Uniform<uint64_t>(gen); +} + +// TODO(lar): Validate properties of non-default interval-semantics. +TEST_F(RandomDistributionsTest, UniformReal) { + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::Uniform(gen, 0, 1.0); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(0.5, moments.mean, 0.02); + EXPECT_NEAR(1 / 12.0, moments.variance, 0.02); + EXPECT_NEAR(0.0, moments.skewness, 0.02); + EXPECT_NEAR(9 / 5.0, moments.kurtosis, 0.02); +} + +TEST_F(RandomDistributionsTest, UniformInt) { + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + const int64_t kMax = 1000000000000ll; + int64_t j = absl::Uniform(absl::IntervalClosedClosed, gen, 0, kMax); + // convert to double. + values[i] = static_cast<double>(j) / static_cast<double>(kMax); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(0.5, moments.mean, 0.02); + EXPECT_NEAR(1 / 12.0, moments.variance, 0.02); + EXPECT_NEAR(0.0, moments.skewness, 0.02); + EXPECT_NEAR(9 / 5.0, moments.kurtosis, 0.02); + + /* + // NOTE: These are not supported by absl::Uniform, which is specialized + // on integer and real valued types. + + enum E { E0, E1 }; // enum + enum S : int { S0, S1 }; // signed enum + enum U : unsigned int { U0, U1 }; // unsigned enum + + absl::Uniform(gen, E0, E1); + absl::Uniform(gen, S0, S1); + absl::Uniform(gen, U0, U1); + */ +} + +TEST_F(RandomDistributionsTest, Exponential) { + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::Exponential<double>(gen); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(1.0, moments.mean, 0.02); + EXPECT_NEAR(1.0, moments.variance, 0.025); + EXPECT_NEAR(2.0, moments.skewness, 0.1); + EXPECT_LT(5.0, moments.kurtosis); +} + +TEST_F(RandomDistributionsTest, PoissonDefault) { + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::Poisson<int64_t>(gen); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(1.0, moments.mean, 0.02); + EXPECT_NEAR(1.0, moments.variance, 0.02); + EXPECT_NEAR(1.0, moments.skewness, 0.025); + EXPECT_LT(2.0, moments.kurtosis); +} + +TEST_F(RandomDistributionsTest, PoissonLarge) { + constexpr double kMean = 100000000.0; + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::Poisson<int64_t>(gen, kMean); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(kMean, moments.mean, kMean * 0.015); + EXPECT_NEAR(kMean, moments.variance, kMean * 0.015); + EXPECT_NEAR(std::sqrt(kMean), moments.skewness, kMean * 0.02); + EXPECT_LT(2.0, moments.kurtosis); +} + +TEST_F(RandomDistributionsTest, Bernoulli) { + constexpr double kP = 0.5151515151; + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::Bernoulli(gen, kP); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(kP, moments.mean, 0.01); +} + +TEST_F(RandomDistributionsTest, Beta) { + constexpr double kAlpha = 2.0; + constexpr double kBeta = 3.0; + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::Beta(gen, kAlpha, kBeta); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(0.4, moments.mean, 0.01); +} + +TEST_F(RandomDistributionsTest, Zipf) { + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::Zipf<int64_t>(gen, 100); + } + + // The mean of a zipf distribution is: H(N, s-1) / H(N,s). + // Given the parameter v = 1, this gives the following function: + // (Hn(100, 1) - Hn(1,1)) / (Hn(100,2) - Hn(1,2)) = 6.5944 + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(6.5944, moments.mean, 2000) << moments; +} + +TEST_F(RandomDistributionsTest, Gaussian) { + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::Gaussian<double>(gen); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(0.0, moments.mean, 0.02); + EXPECT_NEAR(1.0, moments.variance, 0.04); + EXPECT_NEAR(0, moments.skewness, 0.2); + EXPECT_NEAR(3.0, moments.kurtosis, 0.5); +} + +TEST_F(RandomDistributionsTest, LogUniform) { + std::vector<double> values(kSize); + + absl::InsecureBitGen gen; + for (int i = 0; i < kSize; i++) { + values[i] = absl::LogUniform<int64_t>(gen, 0, (1 << 10) - 1); + } + + // The mean is the sum of the fractional means of the uniform distributions: + // [0..0][1..1][2..3][4..7][8..15][16..31][32..63] + // [64..127][128..255][256..511][512..1023] + const double mean = (0 + 1 + 1 + 2 + 3 + 4 + 7 + 8 + 15 + 16 + 31 + 32 + 63 + + 64 + 127 + 128 + 255 + 256 + 511 + 512 + 1023) / + (2.0 * 11.0); + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(mean, moments.mean, 2) << moments; +} + +} // namespace diff --git a/absl/random/examples_test.cc b/absl/random/examples_test.cc new file mode 100644 index 000000000000..1dcb5146f3d0 --- /dev/null +++ b/absl/random/examples_test.cc @@ -0,0 +1,99 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cinttypes> +#include <random> +#include <sstream> +#include <vector> + +#include "gtest/gtest.h" +#include "absl/random/random.h" + +template <typename T> +void Use(T) {} + +TEST(Examples, Basic) { + absl::BitGen gen; + std::vector<int> objs = {10, 20, 30, 40, 50}; + + // Choose an element from a set. + auto elem = objs[absl::Uniform(gen, 0u, objs.size())]; + Use(elem); + + // Generate a uniform value between 1 and 6. + auto dice_roll = absl::Uniform<int>(absl::IntervalClosedClosed, gen, 1, 6); + Use(dice_roll); + + // Generate a random byte. + auto byte = absl::Uniform<uint8_t>(gen); + Use(byte); + + // Generate a fractional value from [0f, 1f). + auto fraction = absl::Uniform<float>(gen, 0, 1); + Use(fraction); + + // Toss a fair coin; 50/50 probability. + bool coin_toss = absl::Bernoulli(gen, 0.5); + Use(coin_toss); + + // Select a file size between 1k and 10MB, biased towards smaller file sizes. + auto file_size = absl::LogUniform<size_t>(gen, 1000, 10 * 1000 * 1000); + Use(file_size); + + // Randomize (shuffle) a collection. + std::shuffle(std::begin(objs), std::end(objs), gen); +} + +TEST(Examples, CreateingCorrelatedVariateSequences) { + // Unexpected PRNG correlation is often a source of bugs, + // so when using absl::BitGen it must be an intentional choice. + // NOTE: All of these only exhibit process-level stability. + + // Create a correlated sequence from system entropy. + { + auto my_seed = absl::MakeSeedSeq(); + + absl::BitGen gen_1(my_seed); + absl::BitGen gen_2(my_seed); // Produces same variates as gen_1. + + EXPECT_EQ(absl::Bernoulli(gen_1, 0.5), absl::Bernoulli(gen_2, 0.5)); + EXPECT_EQ(absl::Uniform<uint32_t>(gen_1), absl::Uniform<uint32_t>(gen_2)); + } + + // Create a correlated sequence from an existing URBG. + { + absl::BitGen gen; + + auto my_seed = absl::CreateSeedSeqFrom(&gen); + absl::BitGen gen_1(my_seed); + absl::BitGen gen_2(my_seed); + + EXPECT_EQ(absl::Bernoulli(gen_1, 0.5), absl::Bernoulli(gen_2, 0.5)); + EXPECT_EQ(absl::Uniform<uint32_t>(gen_1), absl::Uniform<uint32_t>(gen_2)); + } + + // An alternate construction which uses user-supplied data + // instead of a random seed. + { + const char kData[] = "A simple seed string"; + std::seed_seq my_seed(std::begin(kData), std::end(kData)); + + absl::BitGen gen_1(my_seed); + absl::BitGen gen_2(my_seed); + + EXPECT_EQ(absl::Bernoulli(gen_1, 0.5), absl::Bernoulli(gen_2, 0.5)); + EXPECT_EQ(absl::Uniform<uint32_t>(gen_1), absl::Uniform<uint32_t>(gen_2)); + } +} + diff --git a/absl/random/exponential_distribution.h b/absl/random/exponential_distribution.h new file mode 100644 index 000000000000..c8af197575c1 --- /dev/null +++ b/absl/random/exponential_distribution.h @@ -0,0 +1,157 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_EXPONENTIAL_DISTRIBUTION_H_ +#define ABSL_RANDOM_EXPONENTIAL_DISTRIBUTION_H_ + +#include <cassert> +#include <cmath> +#include <istream> +#include <limits> +#include <type_traits> + +#include "absl/random/internal/distribution_impl.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/iostream_state_saver.h" + +namespace absl { + +// absl::exponential_distribution: +// Generates a number conforming to an exponential distribution and is +// equivalent to the standard [rand.dist.pois.exp] distribution. +template <typename RealType = double> +class exponential_distribution { + public: + using result_type = RealType; + + class param_type { + public: + using distribution_type = exponential_distribution; + + explicit param_type(result_type lambda = 1) : lambda_(lambda) { + assert(lambda > 0); + neg_inv_lambda_ = -result_type(1) / lambda_; + } + + result_type lambda() const { return lambda_; } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.lambda_ == b.lambda_; + } + + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + friend class exponential_distribution; + + result_type lambda_; + result_type neg_inv_lambda_; + + static_assert( + std::is_floating_point<RealType>::value, + "Class-template absl::exponential_distribution<> must be parameterized " + "using a floating-point type."); + }; + + exponential_distribution() : exponential_distribution(1) {} + + explicit exponential_distribution(result_type lambda) : param_(lambda) {} + + explicit exponential_distribution(const param_type& p) : param_(p) {} + + void reset() {} + + // Generating functions + template <typename URBG> + result_type operator()(URBG& g) { // NOLINT(runtime/references) + return (*this)(g, param_); + } + + template <typename URBG> + result_type operator()(URBG& g, // NOLINT(runtime/references) + const param_type& p); + + param_type param() const { return param_; } + void param(const param_type& p) { param_ = p; } + + result_type(min)() const { return 0; } + result_type(max)() const { + return std::numeric_limits<result_type>::infinity(); + } + + result_type lambda() const { return param_.lambda(); } + + friend bool operator==(const exponential_distribution& a, + const exponential_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const exponential_distribution& a, + const exponential_distribution& b) { + return a.param_ != b.param_; + } + + private: + param_type param_; + random_internal::FastUniformBits<uint64_t> fast_u64_; +}; + +// -------------------------------------------------------------------------- +// Implementation details follow +// -------------------------------------------------------------------------- + +template <typename RealType> +template <typename URBG> +typename exponential_distribution<RealType>::result_type +exponential_distribution<RealType>::operator()( + URBG& g, // NOLINT(runtime/references) + const param_type& p) { + using random_internal::NegativeValueT; + const result_type u = random_internal::RandU64ToReal< + result_type>::template Value<NegativeValueT, false>(fast_u64_(g)); + // log1p(-x) is mathematically equivalent to log(1 - x) but has more + // accuracy for x near zero. + return p.neg_inv_lambda_ * std::log1p(u); +} + +template <typename CharT, typename Traits, typename RealType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const exponential_distribution<RealType>& x) { + auto saver = random_internal::make_ostream_state_saver(os); + os.precision(random_internal::stream_precision_helper<RealType>::kPrecision); + os << x.lambda(); + return os; +} + +template <typename CharT, typename Traits, typename RealType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + exponential_distribution<RealType>& x) { // NOLINT(runtime/references) + using result_type = typename exponential_distribution<RealType>::result_type; + using param_type = typename exponential_distribution<RealType>::param_type; + result_type lambda; + + auto saver = random_internal::make_istream_state_saver(is); + lambda = random_internal::read_floating_point<result_type>(is); + if (!is.fail()) { + x.param(param_type(lambda)); + } + return is; +} + +} // namespace absl + +#endif // ABSL_RANDOM_EXPONENTIAL_DISTRIBUTION_H_ diff --git a/absl/random/exponential_distribution_test.cc b/absl/random/exponential_distribution_test.cc new file mode 100644 index 000000000000..6f8865c2599f --- /dev/null +++ b/absl/random/exponential_distribution_test.cc @@ -0,0 +1,422 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/exponential_distribution.h" + +#include <algorithm> +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <limits> +#include <random> +#include <sstream> +#include <string> +#include <type_traits> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/base/macros.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/strip.h" + +namespace { + +using absl::random_internal::kChiSquared; + +template <typename RealType> +class ExponentialDistributionTypedTest : public ::testing::Test {}; + +using RealTypes = ::testing::Types<float, double, long double>; +TYPED_TEST_CASE(ExponentialDistributionTypedTest, RealTypes); + +TYPED_TEST(ExponentialDistributionTypedTest, SerializeTest) { + using param_type = + typename absl::exponential_distribution<TypeParam>::param_type; + + const TypeParam kParams[] = { + // Cases around 1. + 1, // + std::nextafter(TypeParam(1), TypeParam(0)), // 1 - epsilon + std::nextafter(TypeParam(1), TypeParam(2)), // 1 + epsilon + // Typical cases. + TypeParam(1e-8), TypeParam(1e-4), TypeParam(1), TypeParam(2), + TypeParam(1e4), TypeParam(1e8), TypeParam(1e20), TypeParam(2.5), + // Boundary cases. + std::numeric_limits<TypeParam>::max(), + std::numeric_limits<TypeParam>::epsilon(), + std::nextafter(std::numeric_limits<TypeParam>::min(), + TypeParam(1)), // min + epsilon + std::numeric_limits<TypeParam>::min(), // smallest normal + // There are some errors dealing with denorms on apple platforms. + std::numeric_limits<TypeParam>::denorm_min(), // smallest denorm + std::numeric_limits<TypeParam>::min() / 2, // denorm + std::nextafter(std::numeric_limits<TypeParam>::min(), + TypeParam(0)), // denorm_max + }; + + constexpr int kCount = 1000; + absl::InsecureBitGen gen; + + for (const TypeParam lambda : kParams) { + // Some values may be invalid; skip those. + if (!std::isfinite(lambda)) continue; + ABSL_ASSERT(lambda > 0); + + const param_type param(lambda); + + absl::exponential_distribution<TypeParam> before(lambda); + EXPECT_EQ(before.lambda(), param.lambda()); + + { + absl::exponential_distribution<TypeParam> via_param(param); + EXPECT_EQ(via_param, before); + EXPECT_EQ(via_param.param(), before.param()); + } + + // Smoke test. + auto sample_min = before.max(); + auto sample_max = before.min(); + for (int i = 0; i < kCount; i++) { + auto sample = before(gen); + EXPECT_GE(sample, before.min()) << before; + EXPECT_LE(sample, before.max()) << before; + if (sample > sample_max) sample_max = sample; + if (sample < sample_min) sample_min = sample; + } + if (!std::is_same<TypeParam, long double>::value) { + ABSL_INTERNAL_LOG(INFO, + absl::StrFormat("Range {%f}: %f, %f, lambda=%f", lambda, + sample_min, sample_max, lambda)); + } + + std::stringstream ss; + ss << before; + + if (!std::isfinite(lambda)) { + // Streams do not deserialize inf/nan correctly. + continue; + } + // Validate stream serialization. + absl::exponential_distribution<TypeParam> after(34.56f); + + EXPECT_NE(before.lambda(), after.lambda()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + +#if defined(__powerpc64__) || defined(__PPC64__) || defined(__powerpc__) || \ + defined(__ppc__) || defined(__PPC__) + if (std::is_same<TypeParam, long double>::value) { + // Roundtripping floating point values requires sufficient precision to + // reconstruct the exact value. It turns out that long double has some + // errors doing this on ppc, particularly for values + // near {1.0 +/- epsilon}. + if (lambda <= std::numeric_limits<double>::max() && + lambda >= std::numeric_limits<double>::lowest()) { + EXPECT_EQ(static_cast<double>(before.lambda()), + static_cast<double>(after.lambda())) + << ss.str(); + } + continue; + } +#endif + + EXPECT_EQ(before.lambda(), after.lambda()) // + << ss.str() << " " // + << (ss.good() ? "good " : "") // + << (ss.bad() ? "bad " : "") // + << (ss.eof() ? "eof " : "") // + << (ss.fail() ? "fail " : ""); + } +} + +// http://www.itl.nist.gov/div898/handbook/eda/section3/eda3667.htm + +class ExponentialModel { + public: + explicit ExponentialModel(double lambda) + : lambda_(lambda), beta_(1.0 / lambda) {} + + double lambda() const { return lambda_; } + + double mean() const { return beta_; } + double variance() const { return beta_ * beta_; } + double stddev() const { return std::sqrt(variance()); } + double skew() const { return 2; } + double kurtosis() const { return 6.0; } + + double CDF(double x) { return 1.0 - std::exp(-lambda_ * x); } + + // The inverse CDF, or PercentPoint function of the distribution + double InverseCDF(double p) { + ABSL_ASSERT(p >= 0.0); + ABSL_ASSERT(p < 1.0); + return -beta_ * std::log(1.0 - p); + } + + private: + const double lambda_; + const double beta_; +}; + +struct Param { + double lambda; + double p_fail; + int trials; +}; + +class ExponentialDistributionTests : public testing::TestWithParam<Param>, + public ExponentialModel { + public: + ExponentialDistributionTests() : ExponentialModel(GetParam().lambda) {} + + // SingleZTest provides a basic z-squared test of the mean vs. expected + // mean for data generated by the poisson distribution. + template <typename D> + bool SingleZTest(const double p, const size_t samples); + + // SingleChiSquaredTest provides a basic chi-squared test of the normal + // distribution. + template <typename D> + double SingleChiSquaredTest(); + + absl::InsecureBitGen rng_; +}; + +template <typename D> +bool ExponentialDistributionTests::SingleZTest(const double p, + const size_t samples) { + D dis(lambda()); + + std::vector<double> data; + data.reserve(samples); + for (size_t i = 0; i < samples; i++) { + const double x = dis(rng_); + data.push_back(x); + } + + const auto m = absl::random_internal::ComputeDistributionMoments(data); + const double max_err = absl::random_internal::MaxErrorTolerance(p); + const double z = absl::random_internal::ZScore(mean(), m); + const bool pass = absl::random_internal::Near("z", z, 0.0, max_err); + + if (!pass) { + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat("p=%f max_err=%f\n" + " lambda=%f\n" + " mean=%f vs. %f\n" + " stddev=%f vs. %f\n" + " skewness=%f vs. %f\n" + " kurtosis=%f vs. %f\n" + " z=%f vs. 0", + p, max_err, lambda(), m.mean, mean(), + std::sqrt(m.variance), stddev(), m.skewness, + skew(), m.kurtosis, kurtosis(), z)); + } + return pass; +} + +template <typename D> +double ExponentialDistributionTests::SingleChiSquaredTest() { + const size_t kSamples = 10000; + const int kBuckets = 50; + + // The InverseCDF is the percent point function of the distribution, and can + // be used to assign buckets roughly uniformly. + std::vector<double> cutoffs; + const double kInc = 1.0 / static_cast<double>(kBuckets); + for (double p = kInc; p < 1.0; p += kInc) { + cutoffs.push_back(InverseCDF(p)); + } + if (cutoffs.back() != std::numeric_limits<double>::infinity()) { + cutoffs.push_back(std::numeric_limits<double>::infinity()); + } + + D dis(lambda()); + + std::vector<int32_t> counts(cutoffs.size(), 0); + for (int j = 0; j < kSamples; j++) { + const double x = dis(rng_); + auto it = std::upper_bound(cutoffs.begin(), cutoffs.end(), x); + counts[std::distance(cutoffs.begin(), it)]++; + } + + // Null-hypothesis is that the distribution is exponentially distributed + // with the provided lambda (not estimated from the data). + const int dof = static_cast<int>(counts.size()) - 1; + + // Our threshold for logging is 1-in-50. + const double threshold = absl::random_internal::ChiSquareValue(dof, 0.98); + + const double expected = + static_cast<double>(kSamples) / static_cast<double>(counts.size()); + + double chi_square = absl::random_internal::ChiSquareWithExpected( + std::begin(counts), std::end(counts), expected); + double p = absl::random_internal::ChiSquarePValue(chi_square, dof); + + if (chi_square > threshold) { + for (int i = 0; i < cutoffs.size(); i++) { + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat("%d : (%f) = %d", i, cutoffs[i], counts[i])); + } + + ABSL_INTERNAL_LOG(INFO, + absl::StrCat("lambda ", lambda(), "\n", // + " expected ", expected, "\n", // + kChiSquared, " ", chi_square, " (", p, ")\n", + kChiSquared, " @ 0.98 = ", threshold)); + } + return p; +} + +TEST_P(ExponentialDistributionTests, ZTest) { + const size_t kSamples = 10000; + const auto& param = GetParam(); + const int expected_failures = + std::max(1, static_cast<int>(std::ceil(param.trials * param.p_fail))); + const double p = absl::random_internal::RequiredSuccessProbability( + param.p_fail, param.trials); + + int failures = 0; + for (int i = 0; i < param.trials; i++) { + failures += SingleZTest<absl::exponential_distribution<double>>(p, kSamples) + ? 0 + : 1; + } + EXPECT_LE(failures, expected_failures); +} + +TEST_P(ExponentialDistributionTests, ChiSquaredTest) { + const int kTrials = 20; + int failures = 0; + + for (int i = 0; i < kTrials; i++) { + double p_value = + SingleChiSquaredTest<absl::exponential_distribution<double>>(); + if (p_value < 0.005) { // 1/200 + failures++; + } + } + + // There is a 0.10% chance of producing at least one failure, so raise the + // failure threshold high enough to allow for a flake rate < 10,000. + EXPECT_LE(failures, 4); +} + +std::vector<Param> GenParams() { + return { + Param{1.0, 0.02, 100}, + Param{2.5, 0.02, 100}, + Param{10, 0.02, 100}, + // large + Param{1e4, 0.02, 100}, + Param{1e9, 0.02, 100}, + // small + Param{0.1, 0.02, 100}, + Param{1e-3, 0.02, 100}, + Param{1e-5, 0.02, 100}, + }; +} + +std::string ParamName(const ::testing::TestParamInfo<Param>& info) { + const auto& p = info.param; + std::string name = absl::StrCat("lambda_", absl::SixDigits(p.lambda)); + return absl::StrReplaceAll(name, {{"+", "_"}, {"-", "_"}, {".", "_"}}); +} + +INSTANTIATE_TEST_CASE_P(, ExponentialDistributionTests, + ::testing::ValuesIn(GenParams()), ParamName); + +// NOTE: absl::exponential_distribution is not guaranteed to be stable. +TEST(ExponentialDistributionTest, StabilityTest) { + // absl::exponential_distribution stability relies on std::log1p and + // absl::uniform_real_distribution. + absl::random_internal::sequence_urbg urbg( + {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull}); + + std::vector<int> output(14); + + { + absl::exponential_distribution<double> dist; + std::generate(std::begin(output), std::end(output), + [&] { return static_cast<int>(10000.0 * dist(urbg)); }); + + EXPECT_EQ(14, urbg.invocations()); + EXPECT_THAT(output, + testing::ElementsAre(0, 71913, 14375, 5039, 1835, 861, 25936, + 804, 126, 12337, 17984, 27002, 0, 71913)); + } + + urbg.reset(); + { + absl::exponential_distribution<float> dist; + std::generate(std::begin(output), std::end(output), + [&] { return static_cast<int>(10000.0f * dist(urbg)); }); + + EXPECT_EQ(14, urbg.invocations()); + EXPECT_THAT(output, + testing::ElementsAre(0, 71913, 14375, 5039, 1835, 861, 25936, + 804, 126, 12337, 17984, 27002, 0, 71913)); + } +} + +TEST(ExponentialDistributionTest, AlgorithmBounds) { + // Relies on absl::uniform_real_distribution, so some of these comments + // reference that. + absl::exponential_distribution<double> dist; + + { + // This returns the smallest value >0 from absl::uniform_real_distribution. + absl::random_internal::sequence_urbg urbg({0x0000000000000001ull}); + double a = dist(urbg); + EXPECT_EQ(a, 5.42101086242752217004e-20); + } + + { + // This returns a value very near 0.5 from absl::uniform_real_distribution. + absl::random_internal::sequence_urbg urbg({0x7fffffffffffffefull}); + double a = dist(urbg); + EXPECT_EQ(a, 0.693147180559945175204); + } + + { + // This returns the largest value <1 from absl::uniform_real_distribution. + // WolframAlpha: ~39.1439465808987766283058547296341915292187253 + absl::random_internal::sequence_urbg urbg({0xFFFFFFFFFFFFFFeFull}); + double a = dist(urbg); + EXPECT_EQ(a, 36.7368005696771007251); + } + { + // This *ALSO* returns the largest value <1. + absl::random_internal::sequence_urbg urbg({0xFFFFFFFFFFFFFFFFull}); + double a = dist(urbg); + EXPECT_EQ(a, 36.7368005696771007251); + } +} + +} // namespace diff --git a/absl/random/gaussian_distribution.cc b/absl/random/gaussian_distribution.cc new file mode 100644 index 000000000000..5dd846194141 --- /dev/null +++ b/absl/random/gaussian_distribution.cc @@ -0,0 +1,102 @@ +// BEGIN GENERATED CODE; DO NOT EDIT +// clang-format off + +#include "absl/random/gaussian_distribution.h" + +namespace absl { +namespace random_internal { + +const gaussian_distribution_base::Tables + gaussian_distribution_base::zg_ = { + {3.7130862467425505, 3.442619855899000214, 3.223084984581141565, + 3.083228858216868318, 2.978696252647779819, 2.894344007021528942, + 2.82312535054891045, 2.761169372387176857, 2.706113573121819549, + 2.656406411261359679, 2.610972248431847387, 2.56903362592493778, + 2.530009672388827457, 2.493454522095372106, 2.459018177411830486, + 2.426420645533749809, 2.395434278011062457, 2.365871370117638595, + 2.337575241339236776, 2.310413683698762988, 2.284274059677471769, + 2.25905957386919809, 2.234686395590979036, 2.21108140887870297, + 2.188180432076048731, 2.165926793748921497, 2.144270182360394905, + 2.123165708673976138, 2.102573135189237608, 2.082456237992015957, + 2.062782274508307978, 2.043521536655067194, 2.02464697337738464, + 2.006133869963471206, 1.987959574127619033, 1.970103260854325633, + 1.952545729553555764, 1.935269228296621957, 1.918257300864508963, + 1.901494653105150423, 1.884967035707758143, 1.868661140994487768, + 1.852564511728090002, 1.836665460258444904, 1.820952996596124418, + 1.805416764219227366, 1.790046982599857506, 1.77483439558606837, + 1.759770224899592339, 1.744846128113799244, 1.730054160563729182, + 1.71538674071366648, 1.700836618569915748, 1.686396846779167014, + 1.6720607540975998, 1.657821920954023254, 1.643674156862867441, + 1.629611479470633562, 1.615628095043159629, 1.601718380221376581, + 1.587876864890574558, 1.574098216022999264, 1.560377222366167382, + 1.546708779859908844, 1.533087877674041755, 1.519509584765938559, + 1.505969036863201937, 1.492461423781352714, 1.478981976989922842, + 1.465525957342709296, 1.452088642889222792, 1.438665316684561546, + 1.425251254514058319, 1.411841712447055919, 1.398431914131003539, + 1.385017037732650058, 1.371592202427340812, 1.358152454330141534, + 1.34469275175354519, 1.331207949665625279, 1.317692783209412299, + 1.304141850128615054, 1.290549591926194894, 1.27691027356015363, + 1.263217961454619287, 1.249466499573066436, 1.23564948326336066, + 1.221760230539994385, 1.207791750415947662, 1.193736707833126465, + 1.17958738466398616, 1.165335636164750222, 1.150972842148865416, + 1.136489852013158774, 1.121876922582540237, 1.107123647534034028, + 1.092218876907275371, 1.077150624892893482, 1.061905963694822042, + 1.046470900764042922, 1.030830236068192907, 1.014967395251327842, + 0.9988642334929808131, 0.9825008035154263464, 0.9658550794011470098, + 0.9489026255113034436, 0.9316161966151479401, 0.9139652510230292792, + 0.8959153525809346874, 0.8774274291129204872, 0.8584568431938099931, + 0.8389522142975741614, 0.8188539067003538507, 0.7980920606440534693, + 0.7765839878947563557, 0.7542306644540520688, 0.7309119106424850631, + 0.7064796113354325779, 0.6807479186691505202, 0.6534786387399710295, + 0.6243585973360461505, 0.5929629424714434327, 0.5586921784081798625, + 0.5206560387620546848, 0.4774378372966830431, 0.4265479863554152429, + 0.3628714310970211909, 0.2723208648139477384, 0}, + {0.001014352564120377413, 0.002669629083880922793, 0.005548995220771345792, + 0.008624484412859888607, 0.01183947865788486861, 0.01516729801054656976, + 0.01859210273701129151, 0.02210330461592709475, 0.02569329193593428151, + 0.02935631744000685023, 0.03308788614622575758, 0.03688438878665621645, + 0.04074286807444417458, 0.04466086220049143157, 0.04863629585986780496, + 0.05266740190305100461, 0.05675266348104984759, 0.06089077034804041277, + 0.06508058521306804567, 0.06932111739357792179, 0.07361150188411341722, + 0.07795098251397346301, 0.08233889824223575293, 0.08677467189478028919, + 0.09125780082683036809, 0.095787849121731522, 0.1003644410286559929, + 0.1049872554094214289, 0.1096560210148404546, 0.1143705124488661323, + 0.1191305467076509556, 0.1239359802028679736, 0.1287867061959434012, + 0.1336826525834396151, 0.1386237799845948804, 0.1436100800906280339, + 0.1486415742423425057, 0.1537183122081819397, 0.1588403711394795748, + 0.1640078546834206341, 0.1692208922373653057, 0.1744796383307898324, + 0.1797842721232958407, 0.1851349970089926078, 0.1905320403191375633, + 0.1959756531162781534, 0.2014661100743140865, 0.2070037094399269362, + 0.2125887730717307134, 0.2182216465543058426, 0.2239026993850088965, + 0.229632325232116602, 0.2354109422634795556, 0.2412389935454402889, + 0.2471169475123218551, 0.2530452985073261551, 0.2590245673962052742, + 0.2650553022555897087, 0.271138079138385224, 0.2772735029191887857, + 0.2834622082232336471, 0.2897048604429605656, 0.2960021568469337061, + 0.3023548277864842593, 0.3087636380061818397, 0.3152293880650116065, + 0.3217529158759855901, 0.3283350983728509642, 0.3349768533135899506, + 0.3416791412315512977, 0.3484429675463274756, 0.355269384847918035, + 0.3621594953693184626, 0.3691144536644731522, 0.376135469510563536, + 0.3832238110559021416, 0.3903808082373155797, 0.3976078564938743676, + 0.404906420807223999, 0.4122780401026620578, 0.4197243320495753771, + 0.4272469983049970721, 0.4348478302499918513, 0.4425287152754694975, + 0.4502916436820402768, 0.458138716267873114, 0.4660721526894572309, + 0.4740943006930180559, 0.4822076463294863724, 0.4904148252838453348, + 0.4987186354709807201, 0.5071220510755701794, 0.5156282382440030565, + 0.5242405726729852944, 0.5329626593838373561, 0.5417983550254266145, + 0.5507517931146057588, 0.5598274127040882009, 0.5690299910679523787, + 0.5783646811197646898, 0.5878370544347081283, 0.5974531509445183408, + 0.6072195366251219584, 0.6171433708188825973, 0.6272324852499290282, + 0.6374954773350440806, 0.6479418211102242475, 0.6585820000500898219, + 0.6694276673488921414, 0.6804918409973358395, 0.6917891434366769676, + 0.7033360990161600101, 0.7151515074105005976, 0.7272569183441868201, + 0.7396772436726493094, 0.7524415591746134169, 0.7655841738977066102, + 0.7791460859296898134, 0.7931770117713072832, 0.8077382946829627652, + 0.8229072113814113187, 0.8387836052959920519, 0.8555006078694531446, + 0.873243048910072206, 0.8922816507840289901, 0.9130436479717434217, + 0.9362826816850632339, 0.9635996931270905952, 1}}; + +} // namespace random_internal +} // namespace absl + +// clang-format on +// END GENERATED CODE diff --git a/absl/random/gaussian_distribution.h b/absl/random/gaussian_distribution.h new file mode 100644 index 000000000000..1d1347bce0a0 --- /dev/null +++ b/absl/random/gaussian_distribution.h @@ -0,0 +1,260 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_GAUSSIAN_DISTRIBUTION_H_ +#define ABSL_RANDOM_GAUSSIAN_DISTRIBUTION_H_ + +// absl::gaussian_distribution implements the Ziggurat algorithm +// for generating random gaussian numbers. +// +// Implementation based on "The Ziggurat Method for Generating Random Variables" +// by George Marsaglia and Wai Wan Tsang: http://www.jstatsoft.org/v05/i08/ +// + +#include <cmath> +#include <cstdint> +#include <istream> +#include <limits> +#include <type_traits> + +#include "absl/random/internal/distribution_impl.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/iostream_state_saver.h" + +namespace absl { +namespace random_internal { + +// absl::gaussian_distribution_base implements the underlying ziggurat algorithm +// using the ziggurat tables generated by the gaussian_distribution_gentables +// binary. +// +// The specific algorithm has some of the improvements suggested by the +// 2005 paper, "An Improved Ziggurat Method to Generate Normal Random Samples", +// Jurgen A Doornik. (https://www.doornik.com/research/ziggurat.pdf) +class gaussian_distribution_base { + public: + template <typename URBG> + inline double zignor(URBG& g); // NOLINT(runtime/references) + + private: + friend class TableGenerator; + + template <typename URBG> + inline double zignor_fallback(URBG& g, // NOLINT(runtime/references) + bool neg); + + // Constants used for the gaussian distribution. + static constexpr double kR = 3.442619855899; // Start of the tail. + static constexpr double kRInv = 0.29047645161474317; // ~= (1.0 / kR) . + static constexpr double kV = 9.91256303526217e-3; + static constexpr uint64_t kMask = 0x07f; + + // The ziggurat tables store the pdf(f) and inverse-pdf(x) for equal-area + // points on one-half of the normal distribution, where the pdf function, + // pdf = e ^ (-1/2 *x^2), assumes that the mean = 0 & stddev = 1. + // + // These tables are just over 2kb in size; larger tables might improve the + // distributions, but also lead to more cache pollution. + // + // x = {3.71308, 3.44261, 3.22308, ..., 0} + // f = {0.00101, 0.00266, 0.00554, ..., 1} + struct Tables { + double x[kMask + 2]; + double f[kMask + 2]; + }; + static const Tables zg_; + random_internal::FastUniformBits<uint64_t> fast_u64_; +}; + +} // namespace random_internal + +// absl::gaussian_distribution: +// Generates a number conforming to a Gaussian distribution. +template <typename RealType = double> +class gaussian_distribution : random_internal::gaussian_distribution_base { + public: + using result_type = RealType; + + class param_type { + public: + using distribution_type = gaussian_distribution; + + explicit param_type(result_type mean = 0, result_type stddev = 1) + : mean_(mean), stddev_(stddev) {} + + // Returns the mean distribution parameter. The mean specifies the location + // of the peak. The default value is 0.0. + result_type mean() const { return mean_; } + + // Returns the deviation distribution parameter. The default value is 1.0. + result_type stddev() const { return stddev_; } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.mean_ == b.mean_ && a.stddev_ == b.stddev_; + } + + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + result_type mean_; + result_type stddev_; + + static_assert( + std::is_floating_point<RealType>::value, + "Class-template absl::gaussian_distribution<> must be parameterized " + "using a floating-point type."); + }; + + gaussian_distribution() : gaussian_distribution(0) {} + + explicit gaussian_distribution(result_type mean, result_type stddev = 1) + : param_(mean, stddev) {} + + explicit gaussian_distribution(const param_type& p) : param_(p) {} + + void reset() {} + + // Generating functions + template <typename URBG> + result_type operator()(URBG& g) { // NOLINT(runtime/references) + return (*this)(g, param_); + } + + template <typename URBG> + result_type operator()(URBG& g, // NOLINT(runtime/references) + const param_type& p); + + param_type param() const { return param_; } + void param(const param_type& p) { param_ = p; } + + result_type(min)() const { + return -std::numeric_limits<result_type>::infinity(); + } + result_type(max)() const { + return std::numeric_limits<result_type>::infinity(); + } + + result_type mean() const { return param_.mean(); } + result_type stddev() const { return param_.stddev(); } + + friend bool operator==(const gaussian_distribution& a, + const gaussian_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const gaussian_distribution& a, + const gaussian_distribution& b) { + return a.param_ != b.param_; + } + + private: + param_type param_; +}; + +// -------------------------------------------------------------------------- +// Implementation details only below +// -------------------------------------------------------------------------- + +template <typename RealType> +template <typename URBG> +typename gaussian_distribution<RealType>::result_type +gaussian_distribution<RealType>::operator()( + URBG& g, // NOLINT(runtime/references) + const param_type& p) { + return p.mean() + p.stddev() * static_cast<result_type>(zignor(g)); +} + +template <typename CharT, typename Traits, typename RealType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const gaussian_distribution<RealType>& x) { + auto saver = random_internal::make_ostream_state_saver(os); + os.precision(random_internal::stream_precision_helper<RealType>::kPrecision); + os << x.mean() << os.fill() << x.stddev(); + return os; +} + +template <typename CharT, typename Traits, typename RealType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + gaussian_distribution<RealType>& x) { // NOLINT(runtime/references) + using result_type = typename gaussian_distribution<RealType>::result_type; + using param_type = typename gaussian_distribution<RealType>::param_type; + + auto saver = random_internal::make_istream_state_saver(is); + auto mean = random_internal::read_floating_point<result_type>(is); + if (is.fail()) return is; + auto stddev = random_internal::read_floating_point<result_type>(is); + if (!is.fail()) { + x.param(param_type(mean, stddev)); + } + return is; +} + +namespace random_internal { + +template <typename URBG> +inline double gaussian_distribution_base::zignor_fallback(URBG& g, bool neg) { + // This fallback path happens approximately 0.05% of the time. + double x, y; + do { + // kRInv = 1/r, U(0, 1) + x = kRInv * std::log(RandU64ToDouble<PositiveValueT, false>(fast_u64_(g))); + y = -std::log(RandU64ToDouble<PositiveValueT, false>(fast_u64_(g))); + } while ((y + y) < (x * x)); + return neg ? (x - kR) : (kR - x); +} + +template <typename URBG> +inline double gaussian_distribution_base::zignor( + URBG& g) { // NOLINT(runtime/references) + while (true) { + // We use a single uint64_t to generate both a double and a strip. + // These bits are unused when the generated double is > 1/2^5. + // This may introduce some bias from the duplicated low bits of small + // values (those smaller than 1/2^5, which all end up on the left tail). + uint64_t bits = fast_u64_(g); + int i = static_cast<int>(bits & kMask); // pick a random strip + double j = RandU64ToDouble<SignedValueT, false>(bits); // U(-1, 1) + const double x = j * zg_.x[i]; + + // Retangular box. Handles >97% of all cases. + // For any given box, this handles between 75% and 99% of values. + // Equivalent to U(01) < (x[i+1] / x[i]), and when i == 0, ~93.5% + if (std::abs(x) < zg_.x[i + 1]) { + return x; + } + + // i == 0: Base box. Sample using a ratio of uniforms. + if (i == 0) { + // This path happens about 0.05% of the time. + return zignor_fallback(g, j < 0); + } + + // i > 0: Wedge samples using precomputed values. + double v = RandU64ToDouble<PositiveValueT, false>(fast_u64_(g)); // U(0, 1) + if ((zg_.f[i + 1] + v * (zg_.f[i] - zg_.f[i + 1])) < + std::exp(-0.5 * x * x)) { + return x; + } + + // The wedge was missed; reject the value and try again. + } +} + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_GAUSSIAN_DISTRIBUTION_H_ diff --git a/absl/random/gaussian_distribution_test.cc b/absl/random/gaussian_distribution_test.cc new file mode 100644 index 000000000000..47c2989d340b --- /dev/null +++ b/absl/random/gaussian_distribution_test.cc @@ -0,0 +1,573 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/gaussian_distribution.h" + +#include <algorithm> +#include <cmath> +#include <cstddef> +#include <ios> +#include <iterator> +#include <random> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/base/macros.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/strip.h" + +namespace { + +using absl::random_internal::kChiSquared; + +template <typename RealType> +class GaussianDistributionInterfaceTest : public ::testing::Test {}; + +using RealTypes = ::testing::Types<float, double, long double>; +TYPED_TEST_CASE(GaussianDistributionInterfaceTest, RealTypes); + +TYPED_TEST(GaussianDistributionInterfaceTest, SerializeTest) { + using param_type = + typename absl::gaussian_distribution<TypeParam>::param_type; + + const TypeParam kParams[] = { + // Cases around 1. + 1, // + std::nextafter(TypeParam(1), TypeParam(0)), // 1 - epsilon + std::nextafter(TypeParam(1), TypeParam(2)), // 1 + epsilon + // Arbitrary values. + TypeParam(1e-8), TypeParam(1e-4), TypeParam(2), TypeParam(1e4), + TypeParam(1e8), TypeParam(1e20), TypeParam(2.5), + // Boundary cases. + std::numeric_limits<TypeParam>::infinity(), + std::numeric_limits<TypeParam>::max(), + std::numeric_limits<TypeParam>::epsilon(), + std::nextafter(std::numeric_limits<TypeParam>::min(), + TypeParam(1)), // min + epsilon + std::numeric_limits<TypeParam>::min(), // smallest normal + // There are some errors dealing with denorms on apple platforms. + std::numeric_limits<TypeParam>::denorm_min(), // smallest denorm + std::numeric_limits<TypeParam>::min() / 2, + std::nextafter(std::numeric_limits<TypeParam>::min(), + TypeParam(0)), // denorm_max + }; + + constexpr int kCount = 1000; + absl::InsecureBitGen gen; + + // Use a loop to generate the combinations of {+/-x, +/-y}, and assign x, y to + // all values in kParams, + for (const auto mod : {0, 1, 2, 3}) { + for (const auto x : kParams) { + if (!std::isfinite(x)) continue; + for (const auto y : kParams) { + const TypeParam mean = (mod & 0x1) ? -x : x; + const TypeParam stddev = (mod & 0x2) ? -y : y; + const param_type param(mean, stddev); + + absl::gaussian_distribution<TypeParam> before(mean, stddev); + EXPECT_EQ(before.mean(), param.mean()); + EXPECT_EQ(before.stddev(), param.stddev()); + + { + absl::gaussian_distribution<TypeParam> via_param(param); + EXPECT_EQ(via_param, before); + EXPECT_EQ(via_param.param(), before.param()); + } + + // Smoke test. + auto sample_min = before.max(); + auto sample_max = before.min(); + for (int i = 0; i < kCount; i++) { + auto sample = before(gen); + if (sample > sample_max) sample_max = sample; + if (sample < sample_min) sample_min = sample; + EXPECT_GE(sample, before.min()) << before; + EXPECT_LE(sample, before.max()) << before; + } + if (!std::is_same<TypeParam, long double>::value) { + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat("Range{%f, %f}: %f, %f", mean, stddev, + sample_min, sample_max)); + } + + std::stringstream ss; + ss << before; + + if (!std::isfinite(mean) || !std::isfinite(stddev)) { + // Streams do not parse inf/nan. + continue; + } + + // Validate stream serialization. + absl::gaussian_distribution<TypeParam> after(-0.53f, 2.3456f); + + EXPECT_NE(before.mean(), after.mean()); + EXPECT_NE(before.stddev(), after.stddev()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + +#if defined(__powerpc64__) || defined(__PPC64__) || defined(__powerpc__) || \ + defined(__ppc__) || defined(__PPC__) + if (std::is_same<TypeParam, long double>::value) { + // Roundtripping floating point values requires sufficient precision + // to reconstruct the exact value. It turns out that long double + // has some errors doing this on ppc, particularly for values + // near {1.0 +/- epsilon}. + if (mean <= std::numeric_limits<double>::max() && + mean >= std::numeric_limits<double>::lowest()) { + EXPECT_EQ(static_cast<double>(before.mean()), + static_cast<double>(after.mean())) + << ss.str(); + } + if (stddev <= std::numeric_limits<double>::max() && + stddev >= std::numeric_limits<double>::lowest()) { + EXPECT_EQ(static_cast<double>(before.stddev()), + static_cast<double>(after.stddev())) + << ss.str(); + } + continue; + } +#endif + + EXPECT_EQ(before.mean(), after.mean()); + EXPECT_EQ(before.stddev(), after.stddev()) // + << ss.str() << " " // + << (ss.good() ? "good " : "") // + << (ss.bad() ? "bad " : "") // + << (ss.eof() ? "eof " : "") // + << (ss.fail() ? "fail " : ""); + } + } + } +} + +// http://www.itl.nist.gov/div898/handbook/eda/section3/eda3661.htm + +class GaussianModel { + public: + GaussianModel(double mean, double stddev) : mean_(mean), stddev_(stddev) {} + + double mean() const { return mean_; } + double variance() const { return stddev() * stddev(); } + double stddev() const { return stddev_; } + double skew() const { return 0; } + double kurtosis() const { return 3.0; } + + // The inverse CDF, or PercentPoint function. + double InverseCDF(double p) { + ABSL_ASSERT(p >= 0.0); + ABSL_ASSERT(p < 1.0); + return mean() + stddev() * -absl::random_internal::InverseNormalSurvival(p); + } + + private: + const double mean_; + const double stddev_; +}; + +struct Param { + double mean; + double stddev; + double p_fail; // Z-Test probability of failure. + int trials; // Z-Test trials. +}; + +// GaussianDistributionTests implements a z-test for the gaussian +// distribution. +class GaussianDistributionTests : public testing::TestWithParam<Param>, + public GaussianModel { + public: + GaussianDistributionTests() + : GaussianModel(GetParam().mean, GetParam().stddev) {} + + // SingleZTest provides a basic z-squared test of the mean vs. expected + // mean for data generated by the poisson distribution. + template <typename D> + bool SingleZTest(const double p, const size_t samples); + + // SingleChiSquaredTest provides a basic chi-squared test of the normal + // distribution. + template <typename D> + double SingleChiSquaredTest(); + + absl::InsecureBitGen rng_; +}; + +template <typename D> +bool GaussianDistributionTests::SingleZTest(const double p, + const size_t samples) { + D dis(mean(), stddev()); + + std::vector<double> data; + data.reserve(samples); + for (size_t i = 0; i < samples; i++) { + const double x = dis(rng_); + data.push_back(x); + } + + const double max_err = absl::random_internal::MaxErrorTolerance(p); + const auto m = absl::random_internal::ComputeDistributionMoments(data); + const double z = absl::random_internal::ZScore(mean(), m); + const bool pass = absl::random_internal::Near("z", z, 0.0, max_err); + + // NOTE: Informational statistical test: + // + // Compute the Jarque-Bera test statistic given the excess skewness + // and kurtosis. The statistic is drawn from a chi-square(2) distribution. + // https://en.wikipedia.org/wiki/Jarque%E2%80%93Bera_test + // + // The null-hypothesis (normal distribution) is rejected when + // (p = 0.05 => jb > 5.99) + // (p = 0.01 => jb > 9.21) + // NOTE: JB has a large type-I error rate, so it will reject the + // null-hypothesis even when it is true more often than the z-test. + // + const double jb = + static_cast<double>(m.n) / 6.0 * + (std::pow(m.skewness, 2.0) + std::pow(m.kurtosis - 3.0, 2.0) / 4.0); + + if (!pass || jb > 9.21) { + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat("p=%f max_err=%f\n" + " mean=%f vs. %f\n" + " stddev=%f vs. %f\n" + " skewness=%f vs. %f\n" + " kurtosis=%f vs. %f\n" + " z=%f vs. 0\n" + " jb=%f vs. 9.21", + p, max_err, m.mean, mean(), std::sqrt(m.variance), + stddev(), m.skewness, skew(), m.kurtosis, + kurtosis(), z, jb)); + } + return pass; +} + +template <typename D> +double GaussianDistributionTests::SingleChiSquaredTest() { + const size_t kSamples = 10000; + const int kBuckets = 50; + + // The InverseCDF is the percent point function of the + // distribution, and can be used to assign buckets + // roughly uniformly. + std::vector<double> cutoffs; + const double kInc = 1.0 / static_cast<double>(kBuckets); + for (double p = kInc; p < 1.0; p += kInc) { + cutoffs.push_back(InverseCDF(p)); + } + if (cutoffs.back() != std::numeric_limits<double>::infinity()) { + cutoffs.push_back(std::numeric_limits<double>::infinity()); + } + + D dis(mean(), stddev()); + + std::vector<int32_t> counts(cutoffs.size(), 0); + for (int j = 0; j < kSamples; j++) { + const double x = dis(rng_); + auto it = std::upper_bound(cutoffs.begin(), cutoffs.end(), x); + counts[std::distance(cutoffs.begin(), it)]++; + } + + // Null-hypothesis is that the distribution is a gaussian distribution + // with the provided mean and stddev (not estimated from the data). + const int dof = static_cast<int>(counts.size()) - 1; + + // Our threshold for logging is 1-in-50. + const double threshold = absl::random_internal::ChiSquareValue(dof, 0.98); + + const double expected = + static_cast<double>(kSamples) / static_cast<double>(counts.size()); + + double chi_square = absl::random_internal::ChiSquareWithExpected( + std::begin(counts), std::end(counts), expected); + double p = absl::random_internal::ChiSquarePValue(chi_square, dof); + + // Log if the chi_square value is above the threshold. + if (chi_square > threshold) { + for (int i = 0; i < cutoffs.size(); i++) { + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat("%d : (%f) = %d", i, cutoffs[i], counts[i])); + } + + ABSL_INTERNAL_LOG( + INFO, absl::StrCat("mean=", mean(), " stddev=", stddev(), "\n", // + " expected ", expected, "\n", // + kChiSquared, " ", chi_square, " (", p, ")\n", // + kChiSquared, " @ 0.98 = ", threshold)); + } + return p; +} + +TEST_P(GaussianDistributionTests, ZTest) { + // TODO(absl-team): Run these tests against std::normal_distribution<double> + // to validate outcomes are similar. + const size_t kSamples = 10000; + const auto& param = GetParam(); + const int expected_failures = + std::max(1, static_cast<int>(std::ceil(param.trials * param.p_fail))); + const double p = absl::random_internal::RequiredSuccessProbability( + param.p_fail, param.trials); + + int failures = 0; + for (int i = 0; i < param.trials; i++) { + failures += + SingleZTest<absl::gaussian_distribution<double>>(p, kSamples) ? 0 : 1; + } + EXPECT_LE(failures, expected_failures); +} + +TEST_P(GaussianDistributionTests, ChiSquaredTest) { + const int kTrials = 20; + int failures = 0; + + for (int i = 0; i < kTrials; i++) { + double p_value = + SingleChiSquaredTest<absl::gaussian_distribution<double>>(); + if (p_value < 0.0025) { // 1/400 + failures++; + } + } + // There is a 0.05% chance of producing at least one failure, so raise the + // failure threshold high enough to allow for a flake rate of less than one in + // 10,000. + EXPECT_LE(failures, 4); +} + +std::vector<Param> GenParams() { + return { + // Mean around 0. + Param{0.0, 1.0, 0.01, 100}, + Param{0.0, 1e2, 0.01, 100}, + Param{0.0, 1e4, 0.01, 100}, + Param{0.0, 1e8, 0.01, 100}, + Param{0.0, 1e16, 0.01, 100}, + Param{0.0, 1e-3, 0.01, 100}, + Param{0.0, 1e-5, 0.01, 100}, + Param{0.0, 1e-9, 0.01, 100}, + Param{0.0, 1e-17, 0.01, 100}, + + // Mean around 1. + Param{1.0, 1.0, 0.01, 100}, + Param{1.0, 1e2, 0.01, 100}, + Param{1.0, 1e-2, 0.01, 100}, + + // Mean around 100 / -100 + Param{1e2, 1.0, 0.01, 100}, + Param{-1e2, 1.0, 0.01, 100}, + Param{1e2, 1e6, 0.01, 100}, + Param{-1e2, 1e6, 0.01, 100}, + + // More extreme + Param{1e4, 1e4, 0.01, 100}, + Param{1e8, 1e4, 0.01, 100}, + Param{1e12, 1e4, 0.01, 100}, + }; +} + +std::string ParamName(const ::testing::TestParamInfo<Param>& info) { + const auto& p = info.param; + std::string name = absl::StrCat("mean_", absl::SixDigits(p.mean), "__stddev_", + absl::SixDigits(p.stddev)); + return absl::StrReplaceAll(name, {{"+", "_"}, {"-", "_"}, {".", "_"}}); +} + +INSTANTIATE_TEST_SUITE_P(, GaussianDistributionTests, + ::testing::ValuesIn(GenParams()), ParamName); + +// NOTE: absl::gaussian_distribution is not guaranteed to be stable. +TEST(GaussianDistributionTest, StabilityTest) { + // absl::gaussian_distribution stability relies on the underlying zignor + // data, absl::random_interna::RandU64ToDouble, std::exp, std::log, and + // std::abs. + absl::random_internal::sequence_urbg urbg( + {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull}); + + std::vector<int> output(11); + + { + absl::gaussian_distribution<double> dist; + std::generate(std::begin(output), std::end(output), + [&] { return static_cast<int>(10000000.0 * dist(urbg)); }); + + EXPECT_EQ(13, urbg.invocations()); + EXPECT_THAT(output, // + testing::ElementsAre(1494, 25518841, 9991550, 1351856, + -20373238, 3456682, 333530, -6804981, + -15279580, -16459654, 1494)); + } + + urbg.reset(); + { + absl::gaussian_distribution<float> dist; + std::generate(std::begin(output), std::end(output), + [&] { return static_cast<int>(1000000.0f * dist(urbg)); }); + + EXPECT_EQ(13, urbg.invocations()); + EXPECT_THAT( + output, // + testing::ElementsAre(149, 2551884, 999155, 135185, -2037323, 345668, + 33353, -680498, -1527958, -1645965, 149)); + } +} + +// This is an implementation-specific test. If any part of the implementation +// changes, then it is likely that this test will change as well. +// Also, if dependencies of the distribution change, such as RandU64ToDouble, +// then this is also likely to change. +TEST(GaussianDistributionTest, AlgorithmBounds) { + absl::gaussian_distribution<double> dist; + + // In ~95% of cases, a single value is used to generate the output. + // for all inputs where |x| < 0.750461021389 this should be the case. + // + // The exact constraints are based on the ziggurat tables, and any + // changes to the ziggurat tables may require adjusting these bounds. + // + // for i in range(0, len(X)-1): + // print i, X[i+1]/X[i], (X[i+1]/X[i] > 0.984375) + // + // 0.125 <= |values| <= 0.75 + const uint64_t kValues[] = { + 0x1000000000000100ull, 0x2000000000000100ull, 0x3000000000000100ull, + 0x4000000000000100ull, 0x5000000000000100ull, 0x6000000000000100ull, + // negative values + 0x9000000000000100ull, 0xa000000000000100ull, 0xb000000000000100ull, + 0xc000000000000100ull, 0xd000000000000100ull, 0xe000000000000100ull}; + + // 0.875 <= |values| <= 0.984375 + const uint64_t kExtraValues[] = { + 0x7000000000000100ull, 0x7800000000000100ull, // + 0x7c00000000000100ull, 0x7e00000000000100ull, // + // negative values + 0xf000000000000100ull, 0xf800000000000100ull, // + 0xfc00000000000100ull, 0xfe00000000000100ull}; + + auto make_box = [](uint64_t v, uint64_t box) { + return (v & 0xffffffffffffff80ull) | box; + }; + + // The box is the lower 7 bits of the value. When the box == 0, then + // the algorithm uses an escape hatch to select the result for large + // outputs. + for (uint64_t box = 0; box < 0x7f; box++) { + for (const uint64_t v : kValues) { + // Extra values are added to the sequence to attempt to avoid + // infinite loops from rejection sampling on bugs/errors. + absl::random_internal::sequence_urbg urbg( + {make_box(v, box), 0x0003eb76f6f7f755ull, 0x5FCEA50FDB2F953Bull}); + + auto a = dist(urbg); + EXPECT_EQ(1, urbg.invocations()) << box << " " << std::hex << v; + if (v & 0x8000000000000000ull) { + EXPECT_LT(a, 0.0) << box << " " << std::hex << v; + } else { + EXPECT_GT(a, 0.0) << box << " " << std::hex << v; + } + } + if (box > 10 && box < 100) { + // The center boxes use the fast algorithm for more + // than 98.4375% of values. + for (const uint64_t v : kExtraValues) { + absl::random_internal::sequence_urbg urbg( + {make_box(v, box), 0x0003eb76f6f7f755ull, 0x5FCEA50FDB2F953Bull}); + + auto a = dist(urbg); + EXPECT_EQ(1, urbg.invocations()) << box << " " << std::hex << v; + if (v & 0x8000000000000000ull) { + EXPECT_LT(a, 0.0) << box << " " << std::hex << v; + } else { + EXPECT_GT(a, 0.0) << box << " " << std::hex << v; + } + } + } + } + + // When the box == 0, the fallback algorithm uses a ratio of uniforms, + // which consumes 2 additional values from the urbg. + // Fallback also requires that the initial value be > 0.9271586026096681. + auto make_fallback = [](uint64_t v) { return (v & 0xffffffffffffff80ull); }; + + double tail[2]; + { + // 0.9375 + absl::random_internal::sequence_urbg urbg( + {make_fallback(0x7800000000000000ull), 0x13CCA830EB61BD96ull, + 0x00000076f6f7f755ull}); + tail[0] = dist(urbg); + EXPECT_EQ(3, urbg.invocations()); + EXPECT_GT(tail[0], 0); + } + { + // -0.9375 + absl::random_internal::sequence_urbg urbg( + {make_fallback(0xf800000000000000ull), 0x13CCA830EB61BD96ull, + 0x00000076f6f7f755ull}); + tail[1] = dist(urbg); + EXPECT_EQ(3, urbg.invocations()); + EXPECT_LT(tail[1], 0); + } + EXPECT_EQ(tail[0], -tail[1]); + EXPECT_EQ(418610, static_cast<int64_t>(tail[0] * 100000.0)); + + // When the box != 0, the fallback algorithm computes a wedge function. + // Depending on the box, the threshold for varies as high as + // 0.991522480228. + { + // 0.9921875, 0.875 + absl::random_internal::sequence_urbg urbg( + {make_box(0x7f00000000000000ull, 120), 0xe000000000000001ull, + 0x13CCA830EB61BD96ull}); + tail[0] = dist(urbg); + EXPECT_EQ(2, urbg.invocations()); + EXPECT_GT(tail[0], 0); + } + { + // -0.9921875, 0.875 + absl::random_internal::sequence_urbg urbg( + {make_box(0xff00000000000000ull, 120), 0xe000000000000001ull, + 0x13CCA830EB61BD96ull}); + tail[1] = dist(urbg); + EXPECT_EQ(2, urbg.invocations()); + EXPECT_LT(tail[1], 0); + } + EXPECT_EQ(tail[0], -tail[1]); + EXPECT_EQ(61948, static_cast<int64_t>(tail[0] * 100000.0)); + + // Fallback rejected, try again. + { + // -0.9921875, 0.0625 + absl::random_internal::sequence_urbg urbg( + {make_box(0xff00000000000000ull, 120), 0x1000000000000001, + make_box(0x1000000000000100ull, 50), 0x13CCA830EB61BD96ull}); + dist(urbg); + EXPECT_EQ(3, urbg.invocations()); + } +} + +} // namespace diff --git a/absl/random/generators_test.cc b/absl/random/generators_test.cc new file mode 100644 index 000000000000..41725f139cd3 --- /dev/null +++ b/absl/random/generators_test.cc @@ -0,0 +1,179 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cstddef> +#include <cstdint> +#include <random> +#include <vector> + +#include "gtest/gtest.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" + +namespace { + +template <typename URBG> +void TestUniform(URBG* gen) { + // [a, b) default-semantics, inferred types. + absl::Uniform(*gen, 0, 100); // int + absl::Uniform(*gen, 0, 1.0); // Promoted to double + absl::Uniform(*gen, 0.0f, 1.0); // Promoted to double + absl::Uniform(*gen, 0.0, 1.0); // double + absl::Uniform(*gen, -1, 1L); // Promoted to long + + // Roll a die. + absl::Uniform(absl::IntervalClosedClosed, *gen, 1, 6); + + // Get a fraction. + absl::Uniform(absl::IntervalOpenOpen, *gen, 0.0, 1.0); + + // Assign a value to a random element. + std::vector<int> elems = {10, 20, 30, 40, 50}; + elems[absl::Uniform(*gen, 0u, elems.size())] = 5; + elems[absl::Uniform<size_t>(*gen, 0, elems.size())] = 3; + + // Choose some epsilon around zero. + absl::Uniform(absl::IntervalOpenOpen, *gen, -1.0, 1.0); + + // (a, b) semantics, inferred types. + absl::Uniform(absl::IntervalOpenOpen, *gen, 0, 1.0); // Promoted to double + + // Explict overriding of types. + absl::Uniform<int>(*gen, 0, 100); + absl::Uniform<int8_t>(*gen, 0, 100); + absl::Uniform<int16_t>(*gen, 0, 100); + absl::Uniform<uint16_t>(*gen, 0, 100); + absl::Uniform<int32_t>(*gen, 0, 1 << 10); + absl::Uniform<uint32_t>(*gen, 0, 1 << 10); + absl::Uniform<int64_t>(*gen, 0, 1 << 10); + absl::Uniform<uint64_t>(*gen, 0, 1 << 10); + + absl::Uniform<float>(*gen, 0.0, 1.0); + absl::Uniform<float>(*gen, 0, 1); + absl::Uniform<float>(*gen, -1, 1); + absl::Uniform<double>(*gen, 0.0, 1.0); + + absl::Uniform<float>(*gen, -1.0, 0); + absl::Uniform<double>(*gen, -1.0, 0); + + // Tagged + absl::Uniform<double>(absl::IntervalClosedClosed, *gen, 0, 1); + absl::Uniform<double>(absl::IntervalClosedOpen, *gen, 0, 1); + absl::Uniform<double>(absl::IntervalOpenOpen, *gen, 0, 1); + absl::Uniform<double>(absl::IntervalOpenClosed, *gen, 0, 1); + absl::Uniform<double>(absl::IntervalClosedClosed, *gen, 0, 1); + absl::Uniform<double>(absl::IntervalOpenOpen, *gen, 0, 1); + + absl::Uniform<int>(absl::IntervalClosedClosed, *gen, 0, 100); + absl::Uniform<int>(absl::IntervalClosedOpen, *gen, 0, 100); + absl::Uniform<int>(absl::IntervalOpenOpen, *gen, 0, 100); + absl::Uniform<int>(absl::IntervalOpenClosed, *gen, 0, 100); + absl::Uniform<int>(absl::IntervalClosedClosed, *gen, 0, 100); + absl::Uniform<int>(absl::IntervalOpenOpen, *gen, 0, 100); + + // With *generator as an R-value reference. + absl::Uniform<int>(URBG(), 0, 100); + absl::Uniform<double>(URBG(), 0.0, 1.0); +} + +template <typename URBG> +void TestExponential(URBG* gen) { + absl::Exponential<float>(*gen); + absl::Exponential<double>(*gen); + absl::Exponential<double>(URBG()); +} + +template <typename URBG> +void TestPoisson(URBG* gen) { + // [rand.dist.pois] Indicates that the std::poisson_distribution + // is parameterized by IntType, however MSVC does not allow 8-bit + // types. + absl::Poisson<int>(*gen); + absl::Poisson<int16_t>(*gen); + absl::Poisson<uint16_t>(*gen); + absl::Poisson<int32_t>(*gen); + absl::Poisson<uint32_t>(*gen); + absl::Poisson<int64_t>(*gen); + absl::Poisson<uint64_t>(*gen); + absl::Poisson<uint64_t>(URBG()); +} + +template <typename URBG> +void TestBernoulli(URBG* gen) { + absl::Bernoulli(*gen, 0.5); + absl::Bernoulli(*gen, 0.5); +} + +template <typename URBG> +void TestZipf(URBG* gen) { + absl::Zipf<int>(*gen, 100); + absl::Zipf<int8_t>(*gen, 100); + absl::Zipf<int16_t>(*gen, 100); + absl::Zipf<uint16_t>(*gen, 100); + absl::Zipf<int32_t>(*gen, 1 << 10); + absl::Zipf<uint32_t>(*gen, 1 << 10); + absl::Zipf<int64_t>(*gen, 1 << 10); + absl::Zipf<uint64_t>(*gen, 1 << 10); + absl::Zipf<uint64_t>(URBG(), 1 << 10); +} + +template <typename URBG> +void TestGaussian(URBG* gen) { + absl::Gaussian<float>(*gen, 1.0, 1.0); + absl::Gaussian<double>(*gen, 1.0, 1.0); + absl::Gaussian<double>(URBG(), 1.0, 1.0); +} + +template <typename URBG> +void TestLogNormal(URBG* gen) { + absl::LogUniform<int>(*gen, 0, 100); + absl::LogUniform<int8_t>(*gen, 0, 100); + absl::LogUniform<int16_t>(*gen, 0, 100); + absl::LogUniform<uint16_t>(*gen, 0, 100); + absl::LogUniform<int32_t>(*gen, 0, 1 << 10); + absl::LogUniform<uint32_t>(*gen, 0, 1 << 10); + absl::LogUniform<int64_t>(*gen, 0, 1 << 10); + absl::LogUniform<uint64_t>(*gen, 0, 1 << 10); + absl::LogUniform<uint64_t>(URBG(), 0, 1 << 10); +} + +template <typename URBG> +void CompatibilityTest() { + URBG gen; + + TestUniform(&gen); + TestExponential(&gen); + TestPoisson(&gen); + TestBernoulli(&gen); + TestZipf(&gen); + TestGaussian(&gen); + TestLogNormal(&gen); +} + +TEST(std_mt19937_64, Compatibility) { + // Validate with std::mt19937_64 + CompatibilityTest<std::mt19937_64>(); +} + +TEST(BitGen, Compatibility) { + // Validate with absl::BitGen + CompatibilityTest<absl::BitGen>(); +} + +TEST(InsecureBitGen, Compatibility) { + // Validate with absl::InsecureBitGen + CompatibilityTest<absl::InsecureBitGen>(); +} + +} // namespace diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel new file mode 100644 index 000000000000..50360acb973b --- /dev/null +++ b/absl/random/internal/BUILD.bazel @@ -0,0 +1,656 @@ +# Internal-only implementation classes for Abseil Random +load( + "//absl:copts/configure_copts.bzl", + "ABSL_DEFAULT_COPTS", + "ABSL_DEFAULT_LINKOPTS", + "ABSL_RANDOM_RANDEN_COPTS", + "ABSL_TEST_COPTS", + "absl_random_randen_copts_init", +) + +package(default_visibility = ["//absl/random:__pkg__"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "traits", + hdrs = ["traits.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + visibility = [ + "//absl/random:__pkg__", + ], + deps = ["//absl/base:config"], +) + +cc_library( + name = "distribution_caller", + hdrs = ["distribution_caller.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + visibility = [ + "//absl/random:__pkg__", + ], +) + +cc_library( + name = "distributions", + hdrs = [ + "distributions.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distribution_caller", + ":fast_uniform_bits", + ":fastmath", + ":traits", + ":uniform_helper", + "//absl/meta:type_traits", + "//absl/strings", + "//absl/types:span", + ], +) + +cc_library( + name = "fast_uniform_bits", + hdrs = [ + "fast_uniform_bits.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + visibility = [ + "//absl/random:__pkg__", + ], +) + +cc_library( + name = "seed_material", + srcs = [ + "seed_material.cc", + ], + hdrs = [ + "seed_material.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":fast_uniform_bits", + "//absl/base", + "//absl/base:core_headers", + "//absl/strings", + "//absl/types:optional", + "//absl/types:span", + ], +) + +cc_library( + name = "pool_urbg", + srcs = [ + "pool_urbg.cc", + ], + hdrs = [ + "pool_urbg.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = select({ + "//absl:windows": [], + "//conditions:default": ["-pthread"], + }) + ABSL_DEFAULT_LINKOPTS, + deps = [ + ":randen", + ":seed_material", + ":traits", + "//absl/base", + "//absl/base:config", + "//absl/base:core_headers", + "//absl/base:endian", + "//absl/random:seed_gen_exception", + "//absl/types:span", + ], +) + +cc_library( + name = "explicit_seed_seq", + testonly = 1, + hdrs = [ + "explicit_seed_seq.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, +) + +cc_library( + name = "sequence_urbg", + testonly = 1, + hdrs = [ + "sequence_urbg.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, +) + +cc_library( + name = "salted_seed_seq", + hdrs = [ + "salted_seed_seq.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":seed_material", + "//absl/container:inlined_vector", + "//absl/meta:type_traits", + "//absl/types:optional", + "//absl/types:span", + ], +) + +cc_library( + name = "iostream_state_saver", + hdrs = ["iostream_state_saver.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + "//absl/meta:type_traits", + "//absl/numeric:int128", + ], +) + +cc_library( + name = "distribution_impl", + hdrs = [ + "distribution_impl.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":fastmath", + ":traits", + "//absl/base:bits", + "//absl/base:config", + "//absl/numeric:int128", + ], +) + +cc_library( + name = "fastmath", + hdrs = [ + "fastmath.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = ["//absl/base:bits"], +) + +cc_library( + name = "nonsecure_base", + hdrs = ["nonsecure_base.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":pool_urbg", + ":salted_seed_seq", + ":seed_material", + "//absl/base:core_headers", + "//absl/meta:type_traits", + "//absl/strings", + "//absl/types:optional", + "//absl/types:span", + ], +) + +cc_library( + name = "pcg_engine", + hdrs = ["pcg_engine.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":fastmath", + ":iostream_state_saver", + "//absl/base:config", + "//absl/meta:type_traits", + "//absl/numeric:int128", + ], +) + +cc_library( + name = "randen_engine", + hdrs = ["randen_engine.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":iostream_state_saver", + ":randen", + "//absl/meta:type_traits", + ], +) + +cc_library( + name = "platform", + hdrs = [ + "randen_traits.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + textual_hdrs = [ + "randen-keys.inc", + "platform.h", + ], +) + +cc_library( + name = "randen", + srcs = [ + "randen.cc", + ], + hdrs = [ + "randen.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":platform", + ":randen_hwaes", + ":randen_slow", + "//absl/base", + ], +) + +cc_library( + name = "randen_slow", + srcs = ["randen_slow.cc"], + hdrs = ["randen_slow.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":platform", + ], +) + +absl_random_randen_copts_init() + +cc_library( + name = "randen_hwaes", + srcs = [ + "randen_detect.cc", + ], + hdrs = [ + "randen_detect.h", + "randen_hwaes.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":platform", + ":randen_hwaes_impl", + ], +) + +# build with --save_temps to see assembly language output. +cc_library( + name = "randen_hwaes_impl", + srcs = [ + "randen_hwaes.cc", + "randen_hwaes.h", + ], + copts = ABSL_DEFAULT_COPTS + ABSL_RANDOM_RANDEN_COPTS + select({ + "//absl:windows": [], + "//conditions:default": ["-Wno-pass-failed"], + }), + # copts in RANDEN_HWAES_COPTS can make this target unusable as a module + # leading to a Clang diagnostic. Furthermore, it only has a private header + # anyway and thus there wouldn't be any gain from using it as a module. + features = ["-header_modules"], + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [":platform"], +) + +cc_binary( + name = "gaussian_distribution_gentables", + srcs = [ + "gaussian_distribution_gentables.cc", + ], + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + "//absl/base:core_headers", + "//absl/random:distributions", + ], +) + +cc_library( + name = "distribution_test_util", + testonly = 1, + srcs = [ + "chi_square.cc", + "distribution_test_util.cc", + ], + hdrs = [ + "chi_square.h", + "distribution_test_util.h", + ], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + "//absl/base", + "//absl/base:core_headers", + "//absl/strings", + "//absl/strings:str_format", + "//absl/types:span", + ], +) + +# Common tags for tests, etc. +ABSL_RANDOM_NONPORTABLE_TAGS = [ + "no_test_android_arm", + "no_test_android_arm64", + "no_test_android_x86", + "no_test_darwin_x86_64", + "no_test_ios_x86_64", + "no_test_loonix", + "no_test_msvc_x64", + "no_test_wasm", +] + +cc_test( + name = "traits_test", + size = "small", + srcs = ["traits_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":traits", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "distribution_impl_test", + size = "small", + srcs = ["distribution_impl_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distribution_impl", + "//absl/base:bits", + "//absl/flags:flag", + "//absl/numeric:int128", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "distribution_test_util_test", + size = "small", + srcs = ["distribution_test_util_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distribution_test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fastmath_test", + size = "small", + srcs = ["fastmath_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":fastmath", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "explicit_seed_seq_test", + size = "small", + srcs = ["explicit_seed_seq_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":explicit_seed_seq", + "//absl/random:seed_sequences", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "salted_seed_seq_test", + size = "small", + srcs = ["salted_seed_seq_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":salted_seed_seq", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "chi_square_test", + size = "small", + srcs = [ + "chi_square_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":distribution_test_util", + "//absl/base:core_headers", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fast_uniform_bits_test", + size = "small", + srcs = [ + "fast_uniform_bits_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":fast_uniform_bits", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "nonsecure_base_test", + size = "small", + srcs = [ + "nonsecure_base_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":nonsecure_base", + "//absl/random", + "//absl/random:distributions", + "//absl/random:seed_sequences", + "//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "seed_material_test", + size = "small", + srcs = ["seed_material_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":seed_material", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "pool_urbg_test", + size = "small", + srcs = [ + "pool_urbg_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":pool_urbg", + "//absl/meta:type_traits", + "//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "pcg_engine_test", + size = "medium", # Trying to measure accuracy. + srcs = ["pcg_engine_test.cc"], + copts = ABSL_TEST_COPTS, + flaky = 1, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":explicit_seed_seq", + ":pcg_engine", + "//absl/time", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "randen_engine_test", + size = "small", + srcs = [ + "randen_engine_test.cc", + ], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":explicit_seed_seq", + ":randen_engine", + "//absl/base", + "//absl/strings", + "//absl/time", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "randen_test", + size = "small", + srcs = ["randen_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":randen", + "//absl/meta:type_traits", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "randen_slow_test", + size = "small", + srcs = ["randen_slow_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":randen_slow", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "randen_hwaes_test", + size = "small", + srcs = ["randen_hwaes_test.cc"], + copts = ABSL_TEST_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + tags = ABSL_RANDOM_NONPORTABLE_TAGS, + deps = [ + ":platform", + ":randen_hwaes", + ":randen_hwaes_impl", # build_cleaner: keep + "//absl/base", + "//absl/strings:str_format", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "nanobenchmark", + srcs = ["nanobenchmark.cc"], + linkopts = ABSL_DEFAULT_LINKOPTS, + textual_hdrs = ["nanobenchmark.h"], + deps = [ + ":platform", + ":randen_engine", + "//absl/base", + ], +) + +cc_library( + name = "uniform_helper", + hdrs = ["uniform_helper.h"], + copts = ABSL_DEFAULT_COPTS, + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + "//absl/base:core_headers", + "//absl/meta:type_traits", + "//absl/random/internal:distribution_impl", + "//absl/random/internal:fast_uniform_bits", + "//absl/random/internal:iostream_state_saver", + "//absl/random/internal:traits", + ], +) + +cc_test( + name = "nanobenchmark_test", + size = "small", + srcs = ["nanobenchmark_test.cc"], + flaky = 1, + linkopts = ABSL_DEFAULT_LINKOPTS, + tags = [ + "benchmark", + "no_test_ios_x86_64", + "no_test_loonix", # Crashing. + ], + deps = [ + ":nanobenchmark", + "//absl/base", + "//absl/strings", + ], +) + +cc_test( + name = "randen_benchmarks", + size = "medium", + srcs = ["randen_benchmarks.cc"], + copts = ABSL_TEST_COPTS + ABSL_RANDOM_RANDEN_COPTS, + flaky = 1, + linkopts = ABSL_DEFAULT_LINKOPTS, + tags = ABSL_RANDOM_NONPORTABLE_TAGS + ["benchmark"], + deps = [ + ":nanobenchmark", + ":platform", + ":randen", + ":randen_engine", + ":randen_hwaes", + ":randen_hwaes_impl", + ":randen_slow", + "//absl/base", + "//absl/strings", + ], +) + +cc_test( + name = "iostream_state_saver_test", + size = "small", + srcs = ["iostream_state_saver_test.cc"], + linkopts = ABSL_DEFAULT_LINKOPTS, + deps = [ + ":iostream_state_saver", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/absl/random/internal/chi_square.cc b/absl/random/internal/chi_square.cc new file mode 100644 index 000000000000..c0acc9477883 --- /dev/null +++ b/absl/random/internal/chi_square.cc @@ -0,0 +1,230 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/chi_square.h" + +#include <cmath> + +#include "absl/random/internal/distribution_test_util.h" + +namespace absl { +namespace random_internal { +namespace { + +#if defined(__EMSCRIPTEN__) +// Workaround __EMSCRIPTEN__ error: llvm_fma_f64 not found. +inline double fma(double x, double y, double z) { + return (x * y) + z; +} +#endif + +// Use Horner's method to evaluate a polynomial. +template <typename T, unsigned N> +inline T EvaluatePolynomial(T x, const T (&poly)[N]) { +#if !defined(__EMSCRIPTEN__) + using std::fma; +#endif + T p = poly[N - 1]; + for (unsigned i = 2; i <= N; i++) { + p = fma(p, x, poly[N - i]); + } + return p; +} + +static constexpr int kLargeDOF = 150; + +// Returns the probability of a normal z-value. +// +// Adapted from the POZ function in: +// Ibbetson D, Algorithm 209 +// Collected Algorithms of the CACM 1963 p. 616 +// +double POZ(double z) { + static constexpr double kP1[] = { + 0.797884560593, -0.531923007300, 0.319152932694, + -0.151968751364, 0.059054035642, -0.019198292004, + 0.005198775019, -0.001075204047, 0.000124818987, + }; + static constexpr double kP2[] = { + 0.999936657524, 0.000535310849, -0.002141268741, 0.005353579108, + -0.009279453341, 0.011630447319, -0.010557625006, 0.006549791214, + -0.002034254874, -0.000794620820, 0.001390604284, -0.000676904986, + -0.000019538132, 0.000152529290, -0.000045255659, + }; + + const double kZMax = 6.0; // Maximum meaningful z-value. + if (z == 0.0) { + return 0.5; + } + double x; + double y = 0.5 * std::fabs(z); + if (y >= (kZMax * 0.5)) { + x = 1.0; + } else if (y < 1.0) { + double w = y * y; + x = EvaluatePolynomial(w, kP1) * y * 2.0; + } else { + y -= 2.0; + x = EvaluatePolynomial(y, kP2); + } + return z > 0.0 ? ((x + 1.0) * 0.5) : ((1.0 - x) * 0.5); +} + +// Approximates the survival function of the normal distribution. +// +// Algorithm 26.2.18, from: +// [Abramowitz and Stegun, Handbook of Mathematical Functions,p.932] +// http://people.math.sfu.ca/~cbm/aands/abramowitz_and_stegun.pdf +// +double normal_survival(double z) { + // Maybe replace with the alternate formulation. + // 0.5 * erfc((x - mean)/(sqrt(2) * sigma)) + static constexpr double kR[] = { + 1.0, 0.196854, 0.115194, 0.000344, 0.019527, + }; + double r = EvaluatePolynomial(z, kR); + r *= r; + return 0.5 / (r * r); +} + +} // namespace + +// Calculates the critical chi-square value given degrees-of-freedom and a +// p-value, usually using bisection. Also known by the name CRITCHI. +double ChiSquareValue(int dof, double p) { + static constexpr double kChiEpsilon = + 0.000001; // Accuracy of the approximation. + static constexpr double kChiMax = + 99999.0; // Maximum chi-squared value. + + const double p_value = 1.0 - p; + if (dof < 1 || p_value > 1.0) { + return 0.0; + } + + if (dof > kLargeDOF) { + // For large degrees of freedom, use the normal approximation by + // Wilson, E. B. and Hilferty, M. M. (1931) + // chi^2 - mean + // Z = -------------- + // stddev + const double z = InverseNormalSurvival(p_value); + const double mean = 1 - 2.0 / (9 * dof); + const double variance = 2.0 / (9 * dof); + // Cannot use this method if the variance is 0. + if (variance != 0) { + return std::pow(z * std::sqrt(variance) + mean, 3.0) * dof; + } + } + + if (p_value <= 0.0) return kChiMax; + + // Otherwise search for the p value by bisection + double min_chisq = 0.0; + double max_chisq = kChiMax; + double current = dof / std::sqrt(p_value); + while ((max_chisq - min_chisq) > kChiEpsilon) { + if (ChiSquarePValue(current, dof) < p_value) { + max_chisq = current; + } else { + min_chisq = current; + } + current = (max_chisq + min_chisq) * 0.5; + } + return current; +} + +// Calculates the p-value (probability) of a given chi-square value +// and degrees of freedom. +// +// Adapted from the POCHISQ function from: +// Hill, I. D. and Pike, M. C. Algorithm 299 +// Collected Algorithms of the CACM 1963 p. 243 +// +double ChiSquarePValue(double chi_square, int dof) { + static constexpr double kLogSqrtPi = + 0.5723649429247000870717135; // Log[Sqrt[Pi]] + static constexpr double kInverseSqrtPi = + 0.5641895835477562869480795; // 1/(Sqrt[Pi]) + + // For large degrees of freedom, use the normal approximation by + // Wilson, E. B. and Hilferty, M. M. (1931) + // Via Wikipedia: + // By the Central Limit Theorem, because the chi-square distribution is the + // sum of k independent random variables with finite mean and variance, it + // converges to a normal distribution for large k. + if (dof > kLargeDOF) { + // Re-scale everything. + const double chi_square_scaled = std::pow(chi_square / dof, 1.0 / 3); + const double mean = 1 - 2.0 / (9 * dof); + const double variance = 2.0 / (9 * dof); + // If variance is 0, this method cannot be used. + if (variance != 0) { + const double z = (chi_square_scaled - mean) / std::sqrt(variance); + if (z > 0) { + return normal_survival(z); + } else if (z < 0) { + return 1.0 - normal_survival(-z); + } else { + return 0.5; + } + } + } + + // The chi square function is >= 0 for any degrees of freedom. + // In other words, probability that the chi square function >= 0 is 1. + if (chi_square <= 0.0) return 1.0; + + // If the degrees of freedom is zero, the chi square function is always 0 by + // definition. In other words, the probability that the chi square function + // is > 0 is zero (chi square values <= 0 have been filtered above). + if (dof < 1) return 0; + + auto capped_exp = [](double x) { return x < -20 ? 0.0 : std::exp(x); }; + static constexpr double kBigX = 20; + + double a = 0.5 * chi_square; + const bool even = !(dof & 1); // True if dof is an even number. + const double y = capped_exp(-a); + double s = even ? y : (2.0 * POZ(-std::sqrt(chi_square))); + + if (dof <= 2) { + return s; + } + + chi_square = 0.5 * (dof - 1.0); + double z = (even ? 1.0 : 0.5); + if (a > kBigX) { + double e = (even ? 0.0 : kLogSqrtPi); + double c = std::log(a); + while (z <= chi_square) { + e = std::log(z) + e; + s += capped_exp(c * z - a - e); + z += 1.0; + } + return s; + } + + double e = (even ? 1.0 : (kInverseSqrtPi / std::sqrt(a))); + double c = 0.0; + while (z <= chi_square) { + e = e * (a / z); + c = c + e; + z += 1.0; + } + return c * y + s; +} + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/internal/chi_square.h b/absl/random/internal/chi_square.h new file mode 100644 index 000000000000..fa8646f274cd --- /dev/null +++ b/absl/random/internal/chi_square.h @@ -0,0 +1,85 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_CHI_SQUARE_H_ +#define ABSL_RANDOM_INTERNAL_CHI_SQUARE_H_ + +// The chi-square statistic. +// +// Useful for evaluating if `D` independent random variables are behaving as +// expected, or if two distributions are similar. (`D` is the degrees of +// freedom). +// +// Each bucket should have an expected count of 10 or more for the chi square to +// be meaningful. + +#include <cassert> + +namespace absl { +namespace random_internal { + +constexpr const char kChiSquared[] = "chi-squared"; + +// Returns the measured chi square value, using a single expected value. This +// assumes that the values in [begin, end) are uniformly distributed. +template <typename Iterator> +double ChiSquareWithExpected(Iterator begin, Iterator end, double expected) { + // Compute the sum and the number of buckets. + assert(expected >= 10); // require at least 10 samples per bucket. + double chi_square = 0; + for (auto it = begin; it != end; it++) { + double d = static_cast<double>(*it) - expected; + chi_square += d * d; + } + chi_square = chi_square / expected; + return chi_square; +} + +// Returns the measured chi square value, taking the actual value of each bucket +// from the first set of iterators, and the expected value of each bucket from +// the second set of iterators. +template <typename Iterator, typename Expected> +double ChiSquare(Iterator it, Iterator end, Expected eit, Expected eend) { + double chi_square = 0; + for (; it != end && eit != eend; ++it, ++eit) { + if (*it > 0) { + assert(*eit > 0); + } + double e = static_cast<double>(*eit); + double d = static_cast<double>(*it - *eit); + if (d != 0) { + assert(e > 0); + chi_square += (d * d) / e; + } + } + assert(it == end && eit == eend); + return chi_square; +} + +// ====================================================================== +// The following methods can be used for an arbitrary significance level. +// + +// Calculates critical chi-square values to produce the given p-value using a +// bisection search for a value within epsilon, relying on the monotonicity of +// ChiSquarePValue(). +double ChiSquareValue(int dof, double p); + +// Calculates the p-value (probability) of a given chi-square value. +double ChiSquarePValue(double chi_square, int dof); + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_CHI_SQUARE_H_ diff --git a/absl/random/internal/chi_square_test.cc b/absl/random/internal/chi_square_test.cc new file mode 100644 index 000000000000..5025defac12c --- /dev/null +++ b/absl/random/internal/chi_square_test.cc @@ -0,0 +1,365 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/chi_square.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <numeric> +#include <vector> + +#include "gtest/gtest.h" +#include "absl/base/macros.h" + +using absl::random_internal::ChiSquare; +using absl::random_internal::ChiSquarePValue; +using absl::random_internal::ChiSquareValue; +using absl::random_internal::ChiSquareWithExpected; + +namespace { + +TEST(ChiSquare, Value) { + struct { + int line; + double chi_square; + int df; + double confidence; + } const specs[] = { + // Testing lookup at 1% confidence + {__LINE__, 0, 0, 0.01}, + {__LINE__, 0.00016, 1, 0.01}, + {__LINE__, 1.64650, 8, 0.01}, + {__LINE__, 5.81221, 16, 0.01}, + {__LINE__, 156.4319, 200, 0.01}, + {__LINE__, 1121.3784, 1234, 0.01}, + {__LINE__, 53557.1629, 54321, 0.01}, + {__LINE__, 651662.6647, 654321, 0.01}, + + // Testing lookup at 99% confidence + {__LINE__, 0, 0, 0.99}, + {__LINE__, 6.635, 1, 0.99}, + {__LINE__, 20.090, 8, 0.99}, + {__LINE__, 32.000, 16, 0.99}, + {__LINE__, 249.4456, 200, 0.99}, + {__LINE__, 1131.1573, 1023, 0.99}, + {__LINE__, 1352.5038, 1234, 0.99}, + {__LINE__, 55090.7356, 54321, 0.99}, + {__LINE__, 656985.1514, 654321, 0.99}, + + // Testing lookup at 99.9% confidence + {__LINE__, 16.2659, 3, 0.999}, + {__LINE__, 22.4580, 6, 0.999}, + {__LINE__, 267.5409, 200, 0.999}, + {__LINE__, 1168.5033, 1023, 0.999}, + {__LINE__, 55345.1741, 54321, 0.999}, + {__LINE__, 657861.7284, 654321, 0.999}, + {__LINE__, 51.1772, 24, 0.999}, + {__LINE__, 59.7003, 30, 0.999}, + {__LINE__, 37.6984, 15, 0.999}, + {__LINE__, 29.5898, 10, 0.999}, + {__LINE__, 27.8776, 9, 0.999}, + + // Testing lookup at random confidences + {__LINE__, 0.000157088, 1, 0.01}, + {__LINE__, 5.31852, 2, 0.93}, + {__LINE__, 1.92256, 4, 0.25}, + {__LINE__, 10.7709, 13, 0.37}, + {__LINE__, 26.2514, 17, 0.93}, + {__LINE__, 36.4799, 29, 0.84}, + {__LINE__, 25.818, 31, 0.27}, + {__LINE__, 63.3346, 64, 0.50}, + {__LINE__, 196.211, 128, 0.9999}, + {__LINE__, 215.21, 243, 0.10}, + {__LINE__, 285.393, 256, 0.90}, + {__LINE__, 984.504, 1024, 0.1923}, + {__LINE__, 2043.85, 2048, 0.4783}, + {__LINE__, 48004.6, 48273, 0.194}, + }; + for (const auto& spec : specs) { + SCOPED_TRACE(spec.line); + // Verify all values are have at most a 1% relative error. + const double val = ChiSquareValue(spec.df, spec.confidence); + const double err = std::max(5e-6, spec.chi_square / 5e3); // 1 part in 5000 + EXPECT_NEAR(spec.chi_square, val, err) << spec.line; + } + + // Relaxed test for extreme values, from + // http://www.ciphersbyritter.com/JAVASCRP/NORMCHIK.HTM#ChiSquare + EXPECT_NEAR(49.2680, ChiSquareValue(100, 1e-6), 5); // 0.000'005 mark + EXPECT_NEAR(123.499, ChiSquareValue(200, 1e-6), 5); // 0.000'005 mark + + EXPECT_NEAR(149.449, ChiSquareValue(100, 0.999), 0.01); + EXPECT_NEAR(161.318, ChiSquareValue(100, 0.9999), 0.01); + EXPECT_NEAR(172.098, ChiSquareValue(100, 0.99999), 0.01); + + EXPECT_NEAR(381.426, ChiSquareValue(300, 0.999), 0.05); + EXPECT_NEAR(399.756, ChiSquareValue(300, 0.9999), 0.1); + EXPECT_NEAR(416.126, ChiSquareValue(300, 0.99999), 0.2); +} + +TEST(ChiSquareTest, PValue) { + struct { + int line; + double pval; + double chi_square; + int df; + } static const specs[] = { + {__LINE__, 1, 0, 0}, + {__LINE__, 0, 0.001, 0}, + {__LINE__, 1.000, 0, 453}, + {__LINE__, 0.134471, 7972.52, 7834}, + {__LINE__, 0.203922, 28.32, 23}, + {__LINE__, 0.737171, 48274, 48472}, + {__LINE__, 0.444146, 583.1234, 579}, + {__LINE__, 0.294814, 138.2, 130}, + {__LINE__, 0.0816532, 12.63, 7}, + {__LINE__, 0, 682.32, 67}, + {__LINE__, 0.49405, 999, 999}, + {__LINE__, 1.000, 0, 9999}, + {__LINE__, 0.997477, 0.00001, 1}, + {__LINE__, 0, 5823.21, 5040}, + }; + for (const auto& spec : specs) { + SCOPED_TRACE(spec.line); + const double pval = ChiSquarePValue(spec.chi_square, spec.df); + EXPECT_NEAR(spec.pval, pval, 1e-3); + } +} + +TEST(ChiSquareTest, CalcChiSquare) { + struct { + int line; + std::vector<int> expected; + std::vector<int> actual; + } const specs[] = { + {__LINE__, + {56, 234, 76, 1, 546, 1, 87, 345, 1, 234}, + {2, 132, 4, 43, 234, 8, 345, 8, 236, 56}}, + {__LINE__, + {123, 36, 234, 367, 345, 2, 456, 567, 234, 567}, + {123, 56, 2345, 8, 345, 8, 2345, 23, 48, 267}}, + {__LINE__, + {123, 234, 345, 456, 567, 678, 789, 890, 98, 76}, + {123, 234, 345, 456, 567, 678, 789, 890, 98, 76}}, + {__LINE__, {3, 675, 23, 86, 2, 8, 2}, {456, 675, 23, 86, 23, 65, 2}}, + {__LINE__, {1}, {23}}, + }; + for (const auto& spec : specs) { + SCOPED_TRACE(spec.line); + double chi_square = 0; + for (int i = 0; i < spec.expected.size(); ++i) { + const double diff = spec.actual[i] - spec.expected[i]; + chi_square += (diff * diff) / spec.expected[i]; + } + EXPECT_NEAR(chi_square, + ChiSquare(std::begin(spec.actual), std::end(spec.actual), + std::begin(spec.expected), std::end(spec.expected)), + 1e-5); + } +} + +TEST(ChiSquareTest, CalcChiSquareInt64) { + const int64_t data[3] = {910293487, 910292491, 910216780}; + // $ python -c "import scipy.stats + // > print scipy.stats.chisquare([910293487, 910292491, 910216780])[0]" + // 4.25410123524 + double sum = std::accumulate(std::begin(data), std::end(data), double{0}); + size_t n = std::distance(std::begin(data), std::end(data)); + double a = ChiSquareWithExpected(std::begin(data), std::end(data), sum / n); + EXPECT_NEAR(4.254101, a, 1e-6); + + // ... Or with known values. + double b = + ChiSquareWithExpected(std::begin(data), std::end(data), 910267586.0); + EXPECT_NEAR(4.254101, b, 1e-6); +} + +TEST(ChiSquareTest, TableData) { + // Test data from + // http://www.itl.nist.gov/div898/handbook/eda/section3/eda3674.htm + // 0.90 0.95 0.975 0.99 0.999 + const double data[100][5] = { + /* 1*/ {2.706, 3.841, 5.024, 6.635, 10.828}, + /* 2*/ {4.605, 5.991, 7.378, 9.210, 13.816}, + /* 3*/ {6.251, 7.815, 9.348, 11.345, 16.266}, + /* 4*/ {7.779, 9.488, 11.143, 13.277, 18.467}, + /* 5*/ {9.236, 11.070, 12.833, 15.086, 20.515}, + /* 6*/ {10.645, 12.592, 14.449, 16.812, 22.458}, + /* 7*/ {12.017, 14.067, 16.013, 18.475, 24.322}, + /* 8*/ {13.362, 15.507, 17.535, 20.090, 26.125}, + /* 9*/ {14.684, 16.919, 19.023, 21.666, 27.877}, + /*10*/ {15.987, 18.307, 20.483, 23.209, 29.588}, + /*11*/ {17.275, 19.675, 21.920, 24.725, 31.264}, + /*12*/ {18.549, 21.026, 23.337, 26.217, 32.910}, + /*13*/ {19.812, 22.362, 24.736, 27.688, 34.528}, + /*14*/ {21.064, 23.685, 26.119, 29.141, 36.123}, + /*15*/ {22.307, 24.996, 27.488, 30.578, 37.697}, + /*16*/ {23.542, 26.296, 28.845, 32.000, 39.252}, + /*17*/ {24.769, 27.587, 30.191, 33.409, 40.790}, + /*18*/ {25.989, 28.869, 31.526, 34.805, 42.312}, + /*19*/ {27.204, 30.144, 32.852, 36.191, 43.820}, + /*20*/ {28.412, 31.410, 34.170, 37.566, 45.315}, + /*21*/ {29.615, 32.671, 35.479, 38.932, 46.797}, + /*22*/ {30.813, 33.924, 36.781, 40.289, 48.268}, + /*23*/ {32.007, 35.172, 38.076, 41.638, 49.728}, + /*24*/ {33.196, 36.415, 39.364, 42.980, 51.179}, + /*25*/ {34.382, 37.652, 40.646, 44.314, 52.620}, + /*26*/ {35.563, 38.885, 41.923, 45.642, 54.052}, + /*27*/ {36.741, 40.113, 43.195, 46.963, 55.476}, + /*28*/ {37.916, 41.337, 44.461, 48.278, 56.892}, + /*29*/ {39.087, 42.557, 45.722, 49.588, 58.301}, + /*30*/ {40.256, 43.773, 46.979, 50.892, 59.703}, + /*31*/ {41.422, 44.985, 48.232, 52.191, 61.098}, + /*32*/ {42.585, 46.194, 49.480, 53.486, 62.487}, + /*33*/ {43.745, 47.400, 50.725, 54.776, 63.870}, + /*34*/ {44.903, 48.602, 51.966, 56.061, 65.247}, + /*35*/ {46.059, 49.802, 53.203, 57.342, 66.619}, + /*36*/ {47.212, 50.998, 54.437, 58.619, 67.985}, + /*37*/ {48.363, 52.192, 55.668, 59.893, 69.347}, + /*38*/ {49.513, 53.384, 56.896, 61.162, 70.703}, + /*39*/ {50.660, 54.572, 58.120, 62.428, 72.055}, + /*40*/ {51.805, 55.758, 59.342, 63.691, 73.402}, + /*41*/ {52.949, 56.942, 60.561, 64.950, 74.745}, + /*42*/ {54.090, 58.124, 61.777, 66.206, 76.084}, + /*43*/ {55.230, 59.304, 62.990, 67.459, 77.419}, + /*44*/ {56.369, 60.481, 64.201, 68.710, 78.750}, + /*45*/ {57.505, 61.656, 65.410, 69.957, 80.077}, + /*46*/ {58.641, 62.830, 66.617, 71.201, 81.400}, + /*47*/ {59.774, 64.001, 67.821, 72.443, 82.720}, + /*48*/ {60.907, 65.171, 69.023, 73.683, 84.037}, + /*49*/ {62.038, 66.339, 70.222, 74.919, 85.351}, + /*50*/ {63.167, 67.505, 71.420, 76.154, 86.661}, + /*51*/ {64.295, 68.669, 72.616, 77.386, 87.968}, + /*52*/ {65.422, 69.832, 73.810, 78.616, 89.272}, + /*53*/ {66.548, 70.993, 75.002, 79.843, 90.573}, + /*54*/ {67.673, 72.153, 76.192, 81.069, 91.872}, + /*55*/ {68.796, 73.311, 77.380, 82.292, 93.168}, + /*56*/ {69.919, 74.468, 78.567, 83.513, 94.461}, + /*57*/ {71.040, 75.624, 79.752, 84.733, 95.751}, + /*58*/ {72.160, 76.778, 80.936, 85.950, 97.039}, + /*59*/ {73.279, 77.931, 82.117, 87.166, 98.324}, + /*60*/ {74.397, 79.082, 83.298, 88.379, 99.607}, + /*61*/ {75.514, 80.232, 84.476, 89.591, 100.888}, + /*62*/ {76.630, 81.381, 85.654, 90.802, 102.166}, + /*63*/ {77.745, 82.529, 86.830, 92.010, 103.442}, + /*64*/ {78.860, 83.675, 88.004, 93.217, 104.716}, + /*65*/ {79.973, 84.821, 89.177, 94.422, 105.988}, + /*66*/ {81.085, 85.965, 90.349, 95.626, 107.258}, + /*67*/ {82.197, 87.108, 91.519, 96.828, 108.526}, + /*68*/ {83.308, 88.250, 92.689, 98.028, 109.791}, + /*69*/ {84.418, 89.391, 93.856, 99.228, 111.055}, + /*70*/ {85.527, 90.531, 95.023, 100.425, 112.317}, + /*71*/ {86.635, 91.670, 96.189, 101.621, 113.577}, + /*72*/ {87.743, 92.808, 97.353, 102.816, 114.835}, + /*73*/ {88.850, 93.945, 98.516, 104.010, 116.092}, + /*74*/ {89.956, 95.081, 99.678, 105.202, 117.346}, + /*75*/ {91.061, 96.217, 100.839, 106.393, 118.599}, + /*76*/ {92.166, 97.351, 101.999, 107.583, 119.850}, + /*77*/ {93.270, 98.484, 103.158, 108.771, 121.100}, + /*78*/ {94.374, 99.617, 104.316, 109.958, 122.348}, + /*79*/ {95.476, 100.749, 105.473, 111.144, 123.594}, + /*80*/ {96.578, 101.879, 106.629, 112.329, 124.839}, + /*81*/ {97.680, 103.010, 107.783, 113.512, 126.083}, + /*82*/ {98.780, 104.139, 108.937, 114.695, 127.324}, + /*83*/ {99.880, 105.267, 110.090, 115.876, 128.565}, + /*84*/ {100.980, 106.395, 111.242, 117.057, 129.804}, + /*85*/ {102.079, 107.522, 112.393, 118.236, 131.041}, + /*86*/ {103.177, 108.648, 113.544, 119.414, 132.277}, + /*87*/ {104.275, 109.773, 114.693, 120.591, 133.512}, + /*88*/ {105.372, 110.898, 115.841, 121.767, 134.746}, + /*89*/ {106.469, 112.022, 116.989, 122.942, 135.978}, + /*90*/ {107.565, 113.145, 118.136, 124.116, 137.208}, + /*91*/ {108.661, 114.268, 119.282, 125.289, 138.438}, + /*92*/ {109.756, 115.390, 120.427, 126.462, 139.666}, + /*93*/ {110.850, 116.511, 121.571, 127.633, 140.893}, + /*94*/ {111.944, 117.632, 122.715, 128.803, 142.119}, + /*95*/ {113.038, 118.752, 123.858, 129.973, 143.344}, + /*96*/ {114.131, 119.871, 125.000, 131.141, 144.567}, + /*97*/ {115.223, 120.990, 126.141, 132.309, 145.789}, + /*98*/ {116.315, 122.108, 127.282, 133.476, 147.010}, + /*99*/ {117.407, 123.225, 128.422, 134.642, 148.230}, + /*100*/ {118.498, 124.342, 129.561, 135.807, 149.449} + /**/}; + + // 0.90 0.95 0.975 0.99 0.999 + for (int i = 0; i < ABSL_ARRAYSIZE(data); i++) { + const double E = 0.0001; + EXPECT_NEAR(ChiSquarePValue(data[i][0], i + 1), 0.10, E) + << i << " " << data[i][0]; + EXPECT_NEAR(ChiSquarePValue(data[i][1], i + 1), 0.05, E) + << i << " " << data[i][1]; + EXPECT_NEAR(ChiSquarePValue(data[i][2], i + 1), 0.025, E) + << i << " " << data[i][2]; + EXPECT_NEAR(ChiSquarePValue(data[i][3], i + 1), 0.01, E) + << i << " " << data[i][3]; + EXPECT_NEAR(ChiSquarePValue(data[i][4], i + 1), 0.001, E) + << i << " " << data[i][4]; + + const double F = 0.1; + EXPECT_NEAR(ChiSquareValue(i + 1, 0.90), data[i][0], F) << i; + EXPECT_NEAR(ChiSquareValue(i + 1, 0.95), data[i][1], F) << i; + EXPECT_NEAR(ChiSquareValue(i + 1, 0.975), data[i][2], F) << i; + EXPECT_NEAR(ChiSquareValue(i + 1, 0.99), data[i][3], F) << i; + EXPECT_NEAR(ChiSquareValue(i + 1, 0.999), data[i][4], F) << i; + } +} + +TEST(ChiSquareTest, ChiSquareTwoIterator) { + // Test data from http://www.stat.yale.edu/Courses/1997-98/101/chigf.htm + // Null-hypothesis: This data is normally distributed. + const int counts[10] = {6, 6, 18, 33, 38, 38, 28, 21, 9, 3}; + const double expected[10] = {4.6, 8.8, 18.4, 30.0, 38.2, + 38.2, 30.0, 18.4, 8.8, 4.6}; + double chi_square = ChiSquare(std::begin(counts), std::end(counts), + std::begin(expected), std::end(expected)); + EXPECT_NEAR(chi_square, 2.69, 0.001); + + // Degrees of freedom: 10 bins. two estimated parameters. = 10 - 2 - 1. + const int dof = 7; + // The critical value of 7, 95% => 14.067 (see above test) + double p_value_05 = ChiSquarePValue(14.067, dof); + EXPECT_NEAR(p_value_05, 0.05, 0.001); // 95%-ile p-value + + double p_actual = ChiSquarePValue(chi_square, dof); + EXPECT_GT(p_actual, 0.05); // Accept the null hypothesis. +} + +TEST(ChiSquareTest, DiceRolls) { + // Assume we are testing 102 fair dice rolls. + // Null-hypothesis: This data is fairly distributed. + // + // The dof value of 4, @95% = 9.488 (see above test) + // The dof value of 5, @95% = 11.070 + const int rolls[6] = {22, 11, 17, 14, 20, 18}; + double sum = std::accumulate(std::begin(rolls), std::end(rolls), double{0}); + size_t n = std::distance(std::begin(rolls), std::end(rolls)); + + double a = ChiSquareWithExpected(std::begin(rolls), std::end(rolls), sum / n); + EXPECT_NEAR(a, 4.70588, 1e-5); + EXPECT_LT(a, ChiSquareValue(4, 0.95)); + + double p_a = ChiSquarePValue(a, 4); + EXPECT_NEAR(p_a, 0.318828, 1e-5); // Accept the null hypothesis. + + double b = ChiSquareWithExpected(std::begin(rolls), std::end(rolls), 17.0); + EXPECT_NEAR(b, 4.70588, 1e-5); + EXPECT_LT(b, ChiSquareValue(5, 0.95)); + + double p_b = ChiSquarePValue(b, 5); + EXPECT_NEAR(p_b, 0.4528180, 1e-5); // Accept the null hypothesis. +} + +} // namespace diff --git a/absl/random/internal/distribution_caller.h b/absl/random/internal/distribution_caller.h new file mode 100644 index 000000000000..0318e1f8072e --- /dev/null +++ b/absl/random/internal/distribution_caller.h @@ -0,0 +1,56 @@ +// +// Copyright 2018 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef ABSL_RANDOM_INTERNAL_DISTRIBUTION_CALLER_H_ +#define ABSL_RANDOM_INTERNAL_DISTRIBUTION_CALLER_H_ + +#include <utility> + +namespace absl { +namespace random_internal { + +// DistributionCaller provides an opportunity to overload the general +// mechanism for calling a distribution, allowing for mock-RNG classes +// to intercept such calls. +template <typename URBG> +struct DistributionCaller { + // Call the provided distribution type. The parameters are expected + // to be explicitly specified. + // DistrT is the distribution type. + // FormatT is the formatter type: + // + // struct FormatT { + // using result_type = distribution_t::result_type; + // static std::string FormatCall( + // const distribution_t& distr, + // absl::Span<const result_type>); + // + // static std::string FormatExpectation( + // absl::string_view match_args, + // absl::Span<const result_t> results); + // } + // + template <typename DistrT, typename FormatT, typename... Args> + static typename DistrT::result_type Call(URBG* urbg, Args&&... args) { + DistrT dist(std::forward<Args>(args)...); + return dist(*urbg); + } +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_DISTRIBUTION_CALLER_H_ diff --git a/absl/random/internal/distribution_impl.h b/absl/random/internal/distribution_impl.h new file mode 100644 index 000000000000..9b6ffb0fb504 --- /dev/null +++ b/absl/random/internal/distribution_impl.h @@ -0,0 +1,260 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_DISTRIBUTION_IMPL_H_ +#define ABSL_RANDOM_INTERNAL_DISTRIBUTION_IMPL_H_ + +// This file contains some implementation details which are used by one or more +// of the absl random number distributions. + +#include <cfloat> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <limits> +#include <type_traits> + +#if (defined(_WIN32) || defined(_WIN64)) && defined(_M_IA64) +#include <intrin.h> // NOLINT(build/include_order) +#pragma intrinsic(_umul128) +#define ABSL_INTERNAL_USE_UMUL128 1 +#endif + +#include "absl/base/config.h" +#include "absl/base/internal/bits.h" +#include "absl/numeric/int128.h" +#include "absl/random/internal/fastmath.h" +#include "absl/random/internal/traits.h" + +namespace absl { +namespace random_internal { + +// Creates a double from `bits`, with the template fields controlling the +// output. +// +// RandU64To is both more efficient and generates more unique values in the +// result interval than known implementations of std::generate_canonical(). +// +// The `Signed` parameter controls whether positive, negative, or both are +// returned (thus affecting the output interval). +// When Signed == SignedValueT, range is U(-1, 1) +// When Signed == NegativeValueT, range is U(-1, 0) +// When Signed == PositiveValueT, range is U(0, 1) +// +// When the `IncludeZero` parameter is true, the function may return 0 for some +// inputs, otherwise it never returns 0. +// +// The `ExponentBias` parameter determines the scale of the output range by +// adjusting the exponent. +// +// When a value in U(0,1) is required, use: +// RandU64ToDouble<PositiveValueT, true, 0>(); +// +// When a value in U(-1,1) is required, use: +// RandU64ToDouble<SignedValueT, false, 0>() => U(-1, 1) +// This generates more distinct values than the mathematically equivalent +// expression `U(0, 1) * 2.0 - 1.0`, and is preferable. +// +// Scaling the result by powers of 2 (and avoiding a multiply) is also possible: +// RandU64ToDouble<PositiveValueT, false, 1>(); => U(0, 2) +// RandU64ToDouble<PositiveValueT, false, -1>(); => U(0, 0.5) +// + +// Tristate types controlling the output. +struct PositiveValueT {}; +struct NegativeValueT {}; +struct SignedValueT {}; + +// RandU64ToDouble is the double-result variant of RandU64To, described above. +template <typename Signed, bool IncludeZero, int ExponentBias = 0> +inline double RandU64ToDouble(uint64_t bits) { + static_assert(std::is_same<Signed, PositiveValueT>::value || + std::is_same<Signed, NegativeValueT>::value || + std::is_same<Signed, SignedValueT>::value, + ""); + + // Maybe use the left-most bit for a sign bit. + uint64_t sign = std::is_same<Signed, NegativeValueT>::value + ? 0x8000000000000000ull + : 0; // Sign bits. + + if (std::is_same<Signed, SignedValueT>::value) { + sign = bits & 0x8000000000000000ull; + bits = bits & 0x7FFFFFFFFFFFFFFFull; + } + if (IncludeZero) { + if (bits == 0u) return 0; + } + + // Number of leading zeros is mapped to the exponent: 2^-clz + int clz = base_internal::CountLeadingZeros64(bits); + // Shift number left to erase leading zeros. + bits <<= IncludeZero ? clz : (clz & 63); + + // Shift number right to remove bits that overflow double mantissa. The + // direction of the shift depends on `clz`. + bits >>= (64 - DBL_MANT_DIG); + + // Compute IEEE 754 double exponent. + // In the Signed case, bits is a 63-bit number with a 0 msb. Adjust the + // exponent to account for that. + const uint64_t exp = + (std::is_same<Signed, SignedValueT>::value ? 1023U : 1022U) + + static_cast<uint64_t>(ExponentBias - clz); + constexpr int kExp = DBL_MANT_DIG - 1; + // Construct IEEE 754 double from exponent and mantissa. + const uint64_t val = sign | (exp << kExp) | (bits & ((1ULL << kExp) - 1U)); + + double res; + static_assert(sizeof(res) == sizeof(val), "double is not 64 bit"); + // Memcpy value from "val" to "res" to avoid aliasing problems. Assumes that + // endian-ness is same for double and uint64_t. + std::memcpy(&res, &val, sizeof(res)); + + return res; +} + +// RandU64ToFloat is the float-result variant of RandU64To, described above. +template <typename Signed, bool IncludeZero, int ExponentBias = 0> +inline float RandU64ToFloat(uint64_t bits) { + static_assert(std::is_same<Signed, PositiveValueT>::value || + std::is_same<Signed, NegativeValueT>::value || + std::is_same<Signed, SignedValueT>::value, + ""); + + // Maybe use the left-most bit for a sign bit. + uint64_t sign = std::is_same<Signed, NegativeValueT>::value + ? 0x80000000ul + : 0; // Sign bits. + + if (std::is_same<Signed, SignedValueT>::value) { + uint64_t a = bits & 0x8000000000000000ull; + sign = static_cast<uint32_t>(a >> 32); + bits = bits & 0x7FFFFFFFFFFFFFFFull; + } + if (IncludeZero) { + if (bits == 0u) return 0; + } + + // Number of leading zeros is mapped to the exponent: 2^-clz + int clz = base_internal::CountLeadingZeros64(bits); + // Shift number left to erase leading zeros. + bits <<= IncludeZero ? clz : (clz & 63); + // Shift number right to remove bits that overflow double mantissa. The + // direction of the shift depends on `clz`. + bits >>= (64 - FLT_MANT_DIG); + + // Construct IEEE 754 float exponent. + // In the Signed case, bits is a 63-bit number with a 0 msb. Adjust the + // exponent to account for that. + const uint32_t exp = + (std::is_same<Signed, SignedValueT>::value ? 127U : 126U) + + static_cast<uint32_t>(ExponentBias - clz); + constexpr int kExp = FLT_MANT_DIG - 1; + const uint32_t val = sign | (exp << kExp) | (bits & ((1U << kExp) - 1U)); + + float res; + static_assert(sizeof(res) == sizeof(val), "float is not 32 bit"); + // Assumes that endian-ness is same for float and uint32_t. + std::memcpy(&res, &val, sizeof(res)); + + return res; +} + +template <typename Result> +struct RandU64ToReal { + template <typename Signed, bool IncludeZero, int ExponentBias = 0> + static inline Result Value(uint64_t bits) { + return RandU64ToDouble<Signed, IncludeZero, ExponentBias>(bits); + } +}; + +template <> +struct RandU64ToReal<float> { + template <typename Signed, bool IncludeZero, int ExponentBias = 0> + static inline float Value(uint64_t bits) { + return RandU64ToFloat<Signed, IncludeZero, ExponentBias>(bits); + } +}; + +inline uint128 MultiplyU64ToU128(uint64_t a, uint64_t b) { +#if defined(ABSL_HAVE_INTRINSIC_INT128) + return uint128(static_cast<__uint128_t>(a) * b); +#elif defined(ABSL_INTERNAL_USE_UMUL128) + // uint64_t * uint64_t => uint128 multiply using imul intrinsic on MSVC. + uint64_t high = 0; + const uint64_t low = _umul128(a, b, &high); + return absl::MakeUint128(high, low); +#else + // uint128(a) * uint128(b) in emulated mode computes a full 128-bit x 128-bit + // multiply. However there are many cases where that is not necessary, and it + // is only necessary to support a 64-bit x 64-bit = 128-bit multiply. This is + // for those cases. + const uint64_t a00 = static_cast<uint32_t>(a); + const uint64_t a32 = a >> 32; + const uint64_t b00 = static_cast<uint32_t>(b); + const uint64_t b32 = b >> 32; + + const uint64_t c00 = a00 * b00; + const uint64_t c32a = a00 * b32; + const uint64_t c32b = a32 * b00; + const uint64_t c64 = a32 * b32; + + const uint32_t carry = + static_cast<uint32_t>(((c00 >> 32) + static_cast<uint32_t>(c32a) + + static_cast<uint32_t>(c32b)) >> + 32); + + return absl::MakeUint128(c64 + (c32a >> 32) + (c32b >> 32) + carry, + c00 + (c32a << 32) + (c32b << 32)); +#endif +} + +// wide_multiply<T> multiplies two N-bit values to a 2N-bit result. +template <typename UIntType> +struct wide_multiply { + static constexpr size_t kN = std::numeric_limits<UIntType>::digits; + using input_type = UIntType; + using result_type = typename random_internal::unsigned_bits<kN * 2>::type; + + static result_type multiply(input_type a, input_type b) { + return static_cast<result_type>(a) * b; + } + + static input_type hi(result_type r) { return r >> kN; } + static input_type lo(result_type r) { return r; } + + static_assert(std::is_unsigned<UIntType>::value, + "Class-template wide_multiply<> argument must be unsigned."); +}; + +#ifndef ABSL_HAVE_INTRINSIC_INT128 +template <> +struct wide_multiply<uint64_t> { + using input_type = uint64_t; + using result_type = uint128; + + static result_type multiply(uint64_t a, uint64_t b) { + return MultiplyU64ToU128(a, b); + } + + static uint64_t hi(result_type r) { return Uint128High64(r); } + static uint64_t lo(result_type r) { return Uint128Low64(r); } +}; +#endif + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_DISTRIBUTION_IMPL_H_ diff --git a/absl/random/internal/distribution_impl_test.cc b/absl/random/internal/distribution_impl_test.cc new file mode 100644 index 000000000000..09e7a318ccc5 --- /dev/null +++ b/absl/random/internal/distribution_impl_test.cc @@ -0,0 +1,506 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/distribution_impl.h" + +#include "gtest/gtest.h" +#include "absl/base/internal/bits.h" +#include "absl/flags/flag.h" +#include "absl/numeric/int128.h" + +ABSL_FLAG(int64_t, absl_random_test_trials, 50000, + "Number of trials for the probability tests."); + +using absl::random_internal::NegativeValueT; +using absl::random_internal::PositiveValueT; +using absl::random_internal::RandU64ToDouble; +using absl::random_internal::RandU64ToFloat; +using absl::random_internal::SignedValueT; + +namespace { + +TEST(DistributionImplTest, U64ToFloat_Positive_NoZero_Test) { + auto ToFloat = [](uint64_t a) { + return RandU64ToFloat<PositiveValueT, false>(a); + }; + EXPECT_EQ(ToFloat(0x0000000000000000), 2.710505431e-20f); + EXPECT_EQ(ToFloat(0x0000000000000001), 5.421010862e-20f); + EXPECT_EQ(ToFloat(0x8000000000000000), 0.5); + EXPECT_EQ(ToFloat(0xFFFFFFFFFFFFFFFF), 0.9999999404f); +} + +TEST(DistributionImplTest, U64ToFloat_Positive_Zero_Test) { + auto ToFloat = [](uint64_t a) { + return RandU64ToFloat<PositiveValueT, true>(a); + }; + EXPECT_EQ(ToFloat(0x0000000000000000), 0.0); + EXPECT_EQ(ToFloat(0x0000000000000001), 5.421010862e-20f); + EXPECT_EQ(ToFloat(0x8000000000000000), 0.5); + EXPECT_EQ(ToFloat(0xFFFFFFFFFFFFFFFF), 0.9999999404f); +} + +TEST(DistributionImplTest, U64ToFloat_Negative_NoZero_Test) { + auto ToFloat = [](uint64_t a) { + return RandU64ToFloat<NegativeValueT, false>(a); + }; + EXPECT_EQ(ToFloat(0x0000000000000000), -2.710505431e-20f); + EXPECT_EQ(ToFloat(0x0000000000000001), -5.421010862e-20f); + EXPECT_EQ(ToFloat(0x8000000000000000), -0.5); + EXPECT_EQ(ToFloat(0xFFFFFFFFFFFFFFFF), -0.9999999404f); +} + +TEST(DistributionImplTest, U64ToFloat_Signed_NoZero_Test) { + auto ToFloat = [](uint64_t a) { + return RandU64ToFloat<SignedValueT, false>(a); + }; + EXPECT_EQ(ToFloat(0x0000000000000000), 5.421010862e-20f); + EXPECT_EQ(ToFloat(0x0000000000000001), 1.084202172e-19f); + EXPECT_EQ(ToFloat(0x7FFFFFFFFFFFFFFF), 0.9999999404f); + EXPECT_EQ(ToFloat(0x8000000000000000), -5.421010862e-20f); + EXPECT_EQ(ToFloat(0x8000000000000001), -1.084202172e-19f); + EXPECT_EQ(ToFloat(0xFFFFFFFFFFFFFFFF), -0.9999999404f); +} + +TEST(DistributionImplTest, U64ToFloat_Signed_Zero_Test) { + auto ToFloat = [](uint64_t a) { + return RandU64ToFloat<SignedValueT, true>(a); + }; + EXPECT_EQ(ToFloat(0x0000000000000000), 0); + EXPECT_EQ(ToFloat(0x0000000000000001), 1.084202172e-19f); + EXPECT_EQ(ToFloat(0x7FFFFFFFFFFFFFFF), 0.9999999404f); + EXPECT_EQ(ToFloat(0x8000000000000000), 0); + EXPECT_EQ(ToFloat(0x8000000000000001), -1.084202172e-19f); + EXPECT_EQ(ToFloat(0xFFFFFFFFFFFFFFFF), -0.9999999404f); +} + +TEST(DistributionImplTest, U64ToFloat_Signed_Bias_Test) { + auto ToFloat = [](uint64_t a) { + return RandU64ToFloat<SignedValueT, true, 1>(a); + }; + EXPECT_EQ(ToFloat(0x0000000000000000), 0); + EXPECT_EQ(ToFloat(0x0000000000000001), 2 * 1.084202172e-19f); + EXPECT_EQ(ToFloat(0x7FFFFFFFFFFFFFFF), 2 * 0.9999999404f); + EXPECT_EQ(ToFloat(0x8000000000000000), 0); + EXPECT_EQ(ToFloat(0x8000000000000001), 2 * -1.084202172e-19f); + EXPECT_EQ(ToFloat(0xFFFFFFFFFFFFFFFF), 2 * -0.9999999404f); +} + +TEST(DistributionImplTest, U64ToFloatTest) { + auto ToFloat = [](uint64_t a) -> float { + return RandU64ToFloat<PositiveValueT, true>(a); + }; + + EXPECT_EQ(ToFloat(0x0000000000000000), 0.0f); + + EXPECT_EQ(ToFloat(0x8000000000000000), 0.5f); + EXPECT_EQ(ToFloat(0x8000000000000001), 0.5f); + EXPECT_EQ(ToFloat(0x800000FFFFFFFFFF), 0.5f); + EXPECT_EQ(ToFloat(0xFFFFFFFFFFFFFFFF), 0.9999999404f); + + EXPECT_GT(ToFloat(0x0000000000000001), 0.0f); + + EXPECT_NE(ToFloat(0x7FFFFF0000000000), ToFloat(0x7FFFFEFFFFFFFFFF)); + + EXPECT_LT(ToFloat(0xFFFFFFFFFFFFFFFF), 1.0f); + int32_t two_to_24 = 1 << 24; + EXPECT_EQ(static_cast<int32_t>(ToFloat(0xFFFFFFFFFFFFFFFF) * two_to_24), + two_to_24 - 1); + EXPECT_NE(static_cast<int32_t>(ToFloat(0xFFFFFFFFFFFFFFFF) * two_to_24 * 2), + two_to_24 * 2 - 1); + EXPECT_EQ(ToFloat(0xFFFFFFFFFFFFFFFF), ToFloat(0xFFFFFF0000000000)); + EXPECT_NE(ToFloat(0xFFFFFFFFFFFFFFFF), ToFloat(0xFFFFFEFFFFFFFFFF)); + EXPECT_EQ(ToFloat(0x7FFFFFFFFFFFFFFF), ToFloat(0x7FFFFF8000000000)); + EXPECT_NE(ToFloat(0x7FFFFFFFFFFFFFFF), ToFloat(0x7FFFFF7FFFFFFFFF)); + EXPECT_EQ(ToFloat(0x3FFFFFFFFFFFFFFF), ToFloat(0x3FFFFFC000000000)); + EXPECT_NE(ToFloat(0x3FFFFFFFFFFFFFFF), ToFloat(0x3FFFFFBFFFFFFFFF)); + + // For values where every bit counts, the values scale as multiples of the + // input. + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(i * ToFloat(0x0000000000000001), ToFloat(i)); + } + + // For each i: value generated from (1 << i). + float exp_values[64]; + exp_values[63] = 0.5f; + for (int i = 62; i >= 0; --i) exp_values[i] = 0.5f * exp_values[i + 1]; + constexpr uint64_t one = 1; + for (int i = 0; i < 64; ++i) { + EXPECT_EQ(ToFloat(one << i), exp_values[i]); + for (int j = 1; j < FLT_MANT_DIG && i - j >= 0; ++j) { + EXPECT_NE(exp_values[i] + exp_values[i - j], exp_values[i]); + EXPECT_EQ(ToFloat((one << i) + (one << (i - j))), + exp_values[i] + exp_values[i - j]); + } + for (int j = FLT_MANT_DIG; i - j >= 0; ++j) { + EXPECT_EQ(exp_values[i] + exp_values[i - j], exp_values[i]); + EXPECT_EQ(ToFloat((one << i) + (one << (i - j))), exp_values[i]); + } + } +} + +TEST(DistributionImplTest, U64ToDouble_Positive_NoZero_Test) { + auto ToDouble = [](uint64_t a) { + return RandU64ToDouble<PositiveValueT, false>(a); + }; + + EXPECT_EQ(ToDouble(0x0000000000000000), 2.710505431213761085e-20); + EXPECT_EQ(ToDouble(0x0000000000000001), 5.42101086242752217004e-20); + EXPECT_EQ(ToDouble(0x0000000000000002), 1.084202172485504434e-19); + EXPECT_EQ(ToDouble(0x8000000000000000), 0.5); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFF), 0.999999999999999888978); +} + +TEST(DistributionImplTest, U64ToDouble_Positive_Zero_Test) { + auto ToDouble = [](uint64_t a) { + return RandU64ToDouble<PositiveValueT, true>(a); + }; + + EXPECT_EQ(ToDouble(0x0000000000000000), 0.0); + EXPECT_EQ(ToDouble(0x0000000000000001), 5.42101086242752217004e-20); + EXPECT_EQ(ToDouble(0x8000000000000000), 0.5); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFF), 0.999999999999999888978); +} + +TEST(DistributionImplTest, U64ToDouble_Negative_NoZero_Test) { + auto ToDouble = [](uint64_t a) { + return RandU64ToDouble<NegativeValueT, false>(a); + }; + + EXPECT_EQ(ToDouble(0x0000000000000000), -2.710505431213761085e-20); + EXPECT_EQ(ToDouble(0x0000000000000001), -5.42101086242752217004e-20); + EXPECT_EQ(ToDouble(0x0000000000000002), -1.084202172485504434e-19); + EXPECT_EQ(ToDouble(0x8000000000000000), -0.5); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFF), -0.999999999999999888978); +} + +TEST(DistributionImplTest, U64ToDouble_Signed_NoZero_Test) { + auto ToDouble = [](uint64_t a) { + return RandU64ToDouble<SignedValueT, false>(a); + }; + + EXPECT_EQ(ToDouble(0x0000000000000000), 5.42101086242752217004e-20); + EXPECT_EQ(ToDouble(0x0000000000000001), 1.084202172485504434e-19); + EXPECT_EQ(ToDouble(0x7FFFFFFFFFFFFFFF), 0.999999999999999888978); + EXPECT_EQ(ToDouble(0x8000000000000000), -5.42101086242752217004e-20); + EXPECT_EQ(ToDouble(0x8000000000000001), -1.084202172485504434e-19); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFF), -0.999999999999999888978); +} + +TEST(DistributionImplTest, U64ToDouble_Signed_Zero_Test) { + auto ToDouble = [](uint64_t a) { + return RandU64ToDouble<SignedValueT, true>(a); + }; + EXPECT_EQ(ToDouble(0x0000000000000000), 0); + EXPECT_EQ(ToDouble(0x0000000000000001), 1.084202172485504434e-19); + EXPECT_EQ(ToDouble(0x7FFFFFFFFFFFFFFF), 0.999999999999999888978); + EXPECT_EQ(ToDouble(0x8000000000000000), 0); + EXPECT_EQ(ToDouble(0x8000000000000001), -1.084202172485504434e-19); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFF), -0.999999999999999888978); +} + +TEST(DistributionImplTest, U64ToDouble_Signed_Bias_Test) { + auto ToDouble = [](uint64_t a) { + return RandU64ToDouble<SignedValueT, true, -1>(a); + }; + EXPECT_EQ(ToDouble(0x0000000000000000), 0); + EXPECT_EQ(ToDouble(0x0000000000000001), 1.084202172485504434e-19 / 2); + EXPECT_EQ(ToDouble(0x7FFFFFFFFFFFFFFF), 0.999999999999999888978 / 2); + EXPECT_EQ(ToDouble(0x8000000000000000), 0); + EXPECT_EQ(ToDouble(0x8000000000000001), -1.084202172485504434e-19 / 2); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFF), -0.999999999999999888978 / 2); +} + +TEST(DistributionImplTest, U64ToDoubleTest) { + auto ToDouble = [](uint64_t a) { + return RandU64ToDouble<PositiveValueT, true>(a); + }; + + EXPECT_EQ(ToDouble(0x0000000000000000), 0.0); + EXPECT_EQ(ToDouble(0x0000000000000000), 0.0); + + EXPECT_EQ(ToDouble(0x0000000000000001), 5.42101086242752217004e-20); + EXPECT_EQ(ToDouble(0x7fffffffffffffef), 0.499999999999999944489); + EXPECT_EQ(ToDouble(0x8000000000000000), 0.5); + + // For values > 0.5, RandU64ToDouble discards up to 11 bits. (64-53). + EXPECT_EQ(ToDouble(0x8000000000000001), 0.5); + EXPECT_EQ(ToDouble(0x80000000000007FF), 0.5); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFF), 0.999999999999999888978); + EXPECT_NE(ToDouble(0x7FFFFFFFFFFFF800), ToDouble(0x7FFFFFFFFFFFF7FF)); + + EXPECT_LT(ToDouble(0xFFFFFFFFFFFFFFFF), 1.0); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFF), ToDouble(0xFFFFFFFFFFFFF800)); + EXPECT_NE(ToDouble(0xFFFFFFFFFFFFFFFF), ToDouble(0xFFFFFFFFFFFFF7FF)); + EXPECT_EQ(ToDouble(0x7FFFFFFFFFFFFFFF), ToDouble(0x7FFFFFFFFFFFFC00)); + EXPECT_NE(ToDouble(0x7FFFFFFFFFFFFFFF), ToDouble(0x7FFFFFFFFFFFFBFF)); + EXPECT_EQ(ToDouble(0x3FFFFFFFFFFFFFFF), ToDouble(0x3FFFFFFFFFFFFE00)); + EXPECT_NE(ToDouble(0x3FFFFFFFFFFFFFFF), ToDouble(0x3FFFFFFFFFFFFDFF)); + + EXPECT_EQ(ToDouble(0x1000000000000001), 0.0625); + EXPECT_EQ(ToDouble(0x2000000000000001), 0.125); + EXPECT_EQ(ToDouble(0x3000000000000001), 0.1875); + EXPECT_EQ(ToDouble(0x4000000000000001), 0.25); + EXPECT_EQ(ToDouble(0x5000000000000001), 0.3125); + EXPECT_EQ(ToDouble(0x6000000000000001), 0.375); + EXPECT_EQ(ToDouble(0x7000000000000001), 0.4375); + EXPECT_EQ(ToDouble(0x8000000000000001), 0.5); + EXPECT_EQ(ToDouble(0x9000000000000001), 0.5625); + EXPECT_EQ(ToDouble(0xa000000000000001), 0.625); + EXPECT_EQ(ToDouble(0xb000000000000001), 0.6875); + EXPECT_EQ(ToDouble(0xc000000000000001), 0.75); + EXPECT_EQ(ToDouble(0xd000000000000001), 0.8125); + EXPECT_EQ(ToDouble(0xe000000000000001), 0.875); + EXPECT_EQ(ToDouble(0xf000000000000001), 0.9375); + + // Large powers of 2. + int64_t two_to_53 = int64_t{1} << 53; + EXPECT_EQ(static_cast<int64_t>(ToDouble(0xFFFFFFFFFFFFFFFF) * two_to_53), + two_to_53 - 1); + EXPECT_NE(static_cast<int64_t>(ToDouble(0xFFFFFFFFFFFFFFFF) * two_to_53 * 2), + two_to_53 * 2 - 1); + + // For values where every bit counts, the values scale as multiples of the + // input. + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(i * ToDouble(0x0000000000000001), ToDouble(i)); + } + + // For each i: value generated from (1 << i). + double exp_values[64]; + exp_values[63] = 0.5; + for (int i = 62; i >= 0; --i) exp_values[i] = 0.5 * exp_values[i + 1]; + constexpr uint64_t one = 1; + for (int i = 0; i < 64; ++i) { + EXPECT_EQ(ToDouble(one << i), exp_values[i]); + for (int j = 1; j < DBL_MANT_DIG && i - j >= 0; ++j) { + EXPECT_NE(exp_values[i] + exp_values[i - j], exp_values[i]); + EXPECT_EQ(ToDouble((one << i) + (one << (i - j))), + exp_values[i] + exp_values[i - j]); + } + for (int j = DBL_MANT_DIG; i - j >= 0; ++j) { + EXPECT_EQ(exp_values[i] + exp_values[i - j], exp_values[i]); + EXPECT_EQ(ToDouble((one << i) + (one << (i - j))), exp_values[i]); + } + } +} + +TEST(DistributionImplTest, U64ToDoubleSignedTest) { + auto ToDouble = [](uint64_t a) { + return RandU64ToDouble<SignedValueT, false>(a); + }; + + EXPECT_EQ(ToDouble(0x0000000000000000), 5.42101086242752217004e-20); + EXPECT_EQ(ToDouble(0x0000000000000001), 1.084202172485504434e-19); + + EXPECT_EQ(ToDouble(0x8000000000000000), -5.42101086242752217004e-20); + EXPECT_EQ(ToDouble(0x8000000000000001), -1.084202172485504434e-19); + + const double e_plus = ToDouble(0x0000000000000001); + const double e_minus = ToDouble(0x8000000000000001); + EXPECT_EQ(e_plus, 1.084202172485504434e-19); + EXPECT_EQ(e_minus, -1.084202172485504434e-19); + + EXPECT_EQ(ToDouble(0x3fffffffffffffef), 0.499999999999999944489); + EXPECT_EQ(ToDouble(0xbfffffffffffffef), -0.499999999999999944489); + + // For values > 0.5, RandU64ToDouble discards up to 10 bits. (63-53). + EXPECT_EQ(ToDouble(0x4000000000000000), 0.5); + EXPECT_EQ(ToDouble(0x4000000000000001), 0.5); + EXPECT_EQ(ToDouble(0x40000000000003FF), 0.5); + + EXPECT_EQ(ToDouble(0xC000000000000000), -0.5); + EXPECT_EQ(ToDouble(0xC000000000000001), -0.5); + EXPECT_EQ(ToDouble(0xC0000000000003FF), -0.5); + + EXPECT_EQ(ToDouble(0x7FFFFFFFFFFFFFFe), 0.999999999999999888978); + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFe), -0.999999999999999888978); + + EXPECT_NE(ToDouble(0x7FFFFFFFFFFFF800), ToDouble(0x7FFFFFFFFFFFF7FF)); + + EXPECT_LT(ToDouble(0x7FFFFFFFFFFFFFFF), 1.0); + EXPECT_GT(ToDouble(0x7FFFFFFFFFFFFFFF), 0.9999999999); + + EXPECT_GT(ToDouble(0xFFFFFFFFFFFFFFFe), -1.0); + EXPECT_LT(ToDouble(0xFFFFFFFFFFFFFFFe), -0.999999999); + + EXPECT_EQ(ToDouble(0xFFFFFFFFFFFFFFFe), ToDouble(0xFFFFFFFFFFFFFC00)); + EXPECT_EQ(ToDouble(0x7FFFFFFFFFFFFFFF), ToDouble(0x7FFFFFFFFFFFFC00)); + EXPECT_NE(ToDouble(0xFFFFFFFFFFFFFFFe), ToDouble(0xFFFFFFFFFFFFF3FF)); + EXPECT_NE(ToDouble(0x7FFFFFFFFFFFFFFF), ToDouble(0x7FFFFFFFFFFFF3FF)); + + EXPECT_EQ(ToDouble(0x1000000000000001), 0.125); + EXPECT_EQ(ToDouble(0x2000000000000001), 0.25); + EXPECT_EQ(ToDouble(0x3000000000000001), 0.375); + EXPECT_EQ(ToDouble(0x4000000000000001), 0.5); + EXPECT_EQ(ToDouble(0x5000000000000001), 0.625); + EXPECT_EQ(ToDouble(0x6000000000000001), 0.75); + EXPECT_EQ(ToDouble(0x7000000000000001), 0.875); + EXPECT_EQ(ToDouble(0x7800000000000001), 0.9375); + EXPECT_EQ(ToDouble(0x7c00000000000001), 0.96875); + EXPECT_EQ(ToDouble(0x7e00000000000001), 0.984375); + EXPECT_EQ(ToDouble(0x7f00000000000001), 0.9921875); + + // 0x8000000000000000 ~= 0 + EXPECT_EQ(ToDouble(0x9000000000000001), -0.125); + EXPECT_EQ(ToDouble(0xa000000000000001), -0.25); + EXPECT_EQ(ToDouble(0xb000000000000001), -0.375); + EXPECT_EQ(ToDouble(0xc000000000000001), -0.5); + EXPECT_EQ(ToDouble(0xd000000000000001), -0.625); + EXPECT_EQ(ToDouble(0xe000000000000001), -0.75); + EXPECT_EQ(ToDouble(0xf000000000000001), -0.875); + + // Large powers of 2. + int64_t two_to_53 = int64_t{1} << 53; + EXPECT_EQ(static_cast<int64_t>(ToDouble(0x7FFFFFFFFFFFFFFF) * two_to_53), + two_to_53 - 1); + EXPECT_EQ(static_cast<int64_t>(ToDouble(0xFFFFFFFFFFFFFFFF) * two_to_53), + -(two_to_53 - 1)); + + EXPECT_NE(static_cast<int64_t>(ToDouble(0x7FFFFFFFFFFFFFFF) * two_to_53 * 2), + two_to_53 * 2 - 1); + + // For values where every bit counts, the values scale as multiples of the + // input. + for (int i = 1; i < 100; ++i) { + EXPECT_EQ(i * e_plus, ToDouble(i)) << i; + EXPECT_EQ(i * e_minus, ToDouble(0x8000000000000000 | i)) << i; + } +} + +TEST(DistributionImplTest, ExhaustiveFloat) { + using absl::base_internal::CountLeadingZeros64; + auto ToFloat = [](uint64_t a) { + return RandU64ToFloat<PositiveValueT, true>(a); + }; + + // Rely on RandU64ToFloat generating values from greatest to least when + // supplied with uint64_t values from greatest (0xfff...) to least (0x0). Thus, + // this algorithm stores the previous value, and if the new value is at + // greater than or equal to the previous value, then there is a collision in + // the generation algorithm. + // + // Use the computation below to convert the random value into a result: + // double res = a() * (1.0f - sample) + b() * sample; + float last_f = 1.0, last_g = 2.0; + uint64_t f_collisions = 0, g_collisions = 0; + uint64_t f_unique = 0, g_unique = 0; + uint64_t total = 0; + auto count = [&](const float r) { + total++; + // `f` is mapped to the range [0, 1) (default) + const float f = 0.0f * (1.0f - r) + 1.0f * r; + if (f >= last_f) { + f_collisions++; + } else { + f_unique++; + last_f = f; + } + // `g` is mapped to the range [1, 2) + const float g = 1.0f * (1.0f - r) + 2.0f * r; + if (g >= last_g) { + g_collisions++; + } else { + g_unique++; + last_g = g; + } + }; + + size_t limit = absl::GetFlag(FLAGS_absl_random_test_trials); + + // Generate all uint64_t which have unique floating point values. + // Counting down from 0xFFFFFFFFFFFFFFFFu ... 0x0u + uint64_t x = ~uint64_t(0); + for (; x != 0 && limit > 0;) { + constexpr int kDig = (64 - FLT_MANT_DIG); + // Set a decrement value & the next point at which to change + // the decrement value. By default these are 1, 0. + uint64_t dec = 1; + uint64_t chk = 0; + + // Adjust decrement and check value based on how many leading 0 + // bits are set in the current value. + const int clz = CountLeadingZeros64(x); + if (clz < kDig) { + dec <<= (kDig - clz); + chk = (~uint64_t(0)) >> (clz + 1); + } + for (; x > chk && limit > 0; x -= dec) { + count(ToFloat(x)); + --limit; + } + } + + static_assert(FLT_MANT_DIG == 24, + "The float type is expected to have a 24 bit mantissa."); + + if (limit != 0) { + // There are between 2^28 and 2^29 unique values in the range [0, 1). For + // the low values of x, there are 2^24 -1 unique values. Once x > 2^24, + // there are 40 * 2^24 unique values. Thus: + // (2 + 4 + 8 ... + 2^23) + 40 * 2^23 + EXPECT_LT(1 << 28, f_unique); + EXPECT_EQ((1 << 24) + 40 * (1 << 23) - 1, f_unique); + EXPECT_EQ(total, f_unique); + EXPECT_EQ(0, f_collisions); + + // Expect at least 2^23 unique values for the range [1, 2) + EXPECT_LE(1 << 23, g_unique); + EXPECT_EQ(total - g_unique, g_collisions); + } +} + +TEST(DistributionImplTest, MultiplyU64ToU128Test) { + using absl::random_internal::MultiplyU64ToU128; + constexpr uint64_t k1 = 1; + constexpr uint64_t kMax = ~static_cast<uint64_t>(0); + + EXPECT_EQ(absl::uint128(0), MultiplyU64ToU128(0, 0)); + + // Max uint64 + EXPECT_EQ(MultiplyU64ToU128(kMax, kMax), + absl::MakeUint128(0xfffffffffffffffe, 0x0000000000000001)); + EXPECT_EQ(absl::MakeUint128(0, kMax), MultiplyU64ToU128(kMax, 1)); + EXPECT_EQ(absl::MakeUint128(0, kMax), MultiplyU64ToU128(1, kMax)); + for (int i = 0; i < 64; ++i) { + EXPECT_EQ(absl::MakeUint128(0, kMax) << i, + MultiplyU64ToU128(kMax, k1 << i)); + EXPECT_EQ(absl::MakeUint128(0, kMax) << i, + MultiplyU64ToU128(k1 << i, kMax)); + } + + // 1-bit x 1-bit. + for (int i = 0; i < 64; ++i) { + for (int j = 0; j < 64; ++j) { + EXPECT_EQ(absl::MakeUint128(0, 1) << (i + j), + MultiplyU64ToU128(k1 << i, k1 << j)); + EXPECT_EQ(absl::MakeUint128(0, 1) << (i + j), + MultiplyU64ToU128(k1 << i, k1 << j)); + } + } + + // Verified multiplies + EXPECT_EQ(MultiplyU64ToU128(0xffffeeeeddddcccc, 0xbbbbaaaa99998888), + absl::MakeUint128(0xbbbb9e2692c5dddc, 0xc28f7531048d2c60)); + EXPECT_EQ(MultiplyU64ToU128(0x0123456789abcdef, 0xfedcba9876543210), + absl::MakeUint128(0x0121fa00ad77d742, 0x2236d88fe5618cf0)); + EXPECT_EQ(MultiplyU64ToU128(0x0123456789abcdef, 0xfdb97531eca86420), + absl::MakeUint128(0x0120ae99d26725fc, 0xce197f0ecac319e0)); + EXPECT_EQ(MultiplyU64ToU128(0x97a87f4f261ba3f2, 0xfedcba9876543210), + absl::MakeUint128(0x96fbf1a8ae78d0ba, 0x5a6dd4b71f278320)); + EXPECT_EQ(MultiplyU64ToU128(0xfedcba9876543210, 0xfdb97531eca86420), + absl::MakeUint128(0xfc98c6981a413e22, 0x342d0bbf48948200)); +} + +} // namespace diff --git a/absl/random/internal/distribution_test_util.cc b/absl/random/internal/distribution_test_util.cc new file mode 100644 index 000000000000..85c8d596ebf8 --- /dev/null +++ b/absl/random/internal/distribution_test_util.cc @@ -0,0 +1,416 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/distribution_test_util.h" + +#include <cassert> +#include <cmath> +#include <string> +#include <vector> + +#include "absl/base/internal/raw_logging.h" +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +namespace absl { +namespace random_internal { +namespace { + +#if defined(__EMSCRIPTEN__) +// Workaround __EMSCRIPTEN__ error: llvm_fma_f64 not found. +inline double fma(double x, double y, double z) { return (x * y) + z; } +#endif + +} // namespace + +DistributionMoments ComputeDistributionMoments( + absl::Span<const double> data_points) { + DistributionMoments result; + + // Compute m1 + for (double x : data_points) { + result.n++; + result.mean += x; + } + result.mean /= static_cast<double>(result.n); + + // Compute m2, m3, m4 + for (double x : data_points) { + double v = x - result.mean; + result.variance += v * v; + result.skewness += v * v * v; + result.kurtosis += v * v * v * v; + } + result.variance /= static_cast<double>(result.n - 1); + + result.skewness /= static_cast<double>(result.n); + result.skewness /= std::pow(result.variance, 1.5); + + result.kurtosis /= static_cast<double>(result.n); + result.kurtosis /= std::pow(result.variance, 2.0); + return result; + + // When validating the min/max count, the following confidence intervals may + // be of use: + // 3.291 * stddev = 99.9% CI + // 2.576 * stddev = 99% CI + // 1.96 * stddev = 95% CI + // 1.65 * stddev = 90% CI +} + +std::ostream& operator<<(std::ostream& os, const DistributionMoments& moments) { + return os << absl::StrFormat("mean=%f, stddev=%f, skewness=%f, kurtosis=%f", + moments.mean, std::sqrt(moments.variance), + moments.skewness, moments.kurtosis); +} + +double InverseNormalSurvival(double x) { + // inv_sf(u) = -sqrt(2) * erfinv(2u-1) + static constexpr double kSqrt2 = 1.4142135623730950488; + return -kSqrt2 * absl::random_internal::erfinv(2 * x - 1.0); +} + +bool Near(absl::string_view msg, double actual, double expected, double bound) { + assert(bound > 0.0); + double delta = fabs(expected - actual); + if (delta < bound) { + return true; + } + + std::string formatted = absl::StrCat( + msg, " actual=", actual, " expected=", expected, " err=", delta / bound); + ABSL_RAW_LOG(INFO, "%s", formatted.c_str()); + return false; +} + +// TODO(absl-team): Replace with an "ABSL_HAVE_SPECIAL_MATH" and try +// to use std::beta(). As of this writing P0226R1 is not implemented +// in libc++: http://libcxx.llvm.org/cxx1z_status.html +double beta(double p, double q) { + // Beta(x, y) = Gamma(x) * Gamma(y) / Gamma(x+y) + double lbeta = std::lgamma(p) + std::lgamma(q) - std::lgamma(p + q); + return std::exp(lbeta); +} + +// Approximation to inverse of the Error Function in double precision. +// (http://people.maths.ox.ac.uk/gilesm/files/gems_erfinv.pdf) +double erfinv(double x) { +#if !defined(__EMSCRIPTEN__) + using std::fma; +#endif + + double w = 0.0; + double p = 0.0; + w = -std::log((1.0 - x) * (1.0 + x)); + if (w < 6.250000) { + w = w - 3.125000; + p = -3.6444120640178196996e-21; + p = fma(p, w, -1.685059138182016589e-19); + p = fma(p, w, 1.2858480715256400167e-18); + p = fma(p, w, 1.115787767802518096e-17); + p = fma(p, w, -1.333171662854620906e-16); + p = fma(p, w, 2.0972767875968561637e-17); + p = fma(p, w, 6.6376381343583238325e-15); + p = fma(p, w, -4.0545662729752068639e-14); + p = fma(p, w, -8.1519341976054721522e-14); + p = fma(p, w, 2.6335093153082322977e-12); + p = fma(p, w, -1.2975133253453532498e-11); + p = fma(p, w, -5.4154120542946279317e-11); + p = fma(p, w, 1.051212273321532285e-09); + p = fma(p, w, -4.1126339803469836976e-09); + p = fma(p, w, -2.9070369957882005086e-08); + p = fma(p, w, 4.2347877827932403518e-07); + p = fma(p, w, -1.3654692000834678645e-06); + p = fma(p, w, -1.3882523362786468719e-05); + p = fma(p, w, 0.0001867342080340571352); + p = fma(p, w, -0.00074070253416626697512); + p = fma(p, w, -0.0060336708714301490533); + p = fma(p, w, 0.24015818242558961693); + p = fma(p, w, 1.6536545626831027356); + } else if (w < 16.000000) { + w = std::sqrt(w) - 3.250000; + p = 2.2137376921775787049e-09; + p = fma(p, w, 9.0756561938885390979e-08); + p = fma(p, w, -2.7517406297064545428e-07); + p = fma(p, w, 1.8239629214389227755e-08); + p = fma(p, w, 1.5027403968909827627e-06); + p = fma(p, w, -4.013867526981545969e-06); + p = fma(p, w, 2.9234449089955446044e-06); + p = fma(p, w, 1.2475304481671778723e-05); + p = fma(p, w, -4.7318229009055733981e-05); + p = fma(p, w, 6.8284851459573175448e-05); + p = fma(p, w, 2.4031110387097893999e-05); + p = fma(p, w, -0.0003550375203628474796); + p = fma(p, w, 0.00095328937973738049703); + p = fma(p, w, -0.0016882755560235047313); + p = fma(p, w, 0.0024914420961078508066); + p = fma(p, w, -0.0037512085075692412107); + p = fma(p, w, 0.005370914553590063617); + p = fma(p, w, 1.0052589676941592334); + p = fma(p, w, 3.0838856104922207635); + } else { + w = std::sqrt(w) - 5.000000; + p = -2.7109920616438573243e-11; + p = fma(p, w, -2.5556418169965252055e-10); + p = fma(p, w, 1.5076572693500548083e-09); + p = fma(p, w, -3.7894654401267369937e-09); + p = fma(p, w, 7.6157012080783393804e-09); + p = fma(p, w, -1.4960026627149240478e-08); + p = fma(p, w, 2.9147953450901080826e-08); + p = fma(p, w, -6.7711997758452339498e-08); + p = fma(p, w, 2.2900482228026654717e-07); + p = fma(p, w, -9.9298272942317002539e-07); + p = fma(p, w, 4.5260625972231537039e-06); + p = fma(p, w, -1.9681778105531670567e-05); + p = fma(p, w, 7.5995277030017761139e-05); + p = fma(p, w, -0.00021503011930044477347); + p = fma(p, w, -0.00013871931833623122026); + p = fma(p, w, 1.0103004648645343977); + p = fma(p, w, 4.8499064014085844221); + } + return p * x; +} + +namespace { + +// Direct implementation of AS63, BETAIN() +// https://www.jstor.org/stable/2346797?seq=3#page_scan_tab_contents. +// +// BETAIN(x, p, q, beta) +// x: the value of the upper limit x. +// p: the value of the parameter p. +// q: the value of the parameter q. +// beta: the value of ln B(p, q) +// +double BetaIncompleteImpl(const double x, const double p, const double q, + const double beta) { + if (p < (p + q) * x) { + // Incomplete beta function is symmetrical, so return the complement. + return 1. - BetaIncompleteImpl(1.0 - x, q, p, beta); + } + + double psq = p + q; + const double kErr = 1e-14; + const double xc = 1. - x; + const double pre = + std::exp(p * std::log(x) + (q - 1.) * std::log(xc) - beta) / p; + + double term = 1.; + double ai = 1.; + double result = 1.; + int ns = static_cast<int>(q + xc * psq); + + // Use the soper reduction forumla. + double rx = (ns == 0) ? x : x / xc; + double temp = q - ai; + for (;;) { + term = term * temp * rx / (p + ai); + result = result + term; + temp = std::fabs(term); + if (temp < kErr && temp < kErr * result) { + return result * pre; + } + ai = ai + 1.; + --ns; + if (ns >= 0) { + temp = q - ai; + if (ns == 0) { + rx = x; + } + } else { + temp = psq; + psq = psq + 1.; + } + } + + // NOTE: See also TOMS Alogrithm 708. + // http://www.netlib.org/toms/index.html + // + // NOTE: The NWSC library also includes BRATIO / ISUBX (p87) + // https://archive.org/details/DTIC_ADA261511/page/n75 +} + +// Direct implementation of AS109, XINBTA(p, q, beta, alpha) +// https://www.jstor.org/stable/2346798?read-now=1&seq=4#page_scan_tab_contents +// https://www.jstor.org/stable/2346887?seq=1#page_scan_tab_contents +// +// XINBTA(p, q, beta, alhpa) +// p: the value of the parameter p. +// q: the value of the parameter q. +// beta: the value of ln B(p, q) +// alpha: the value of the lower tail area. +// +double BetaIncompleteInvImpl(const double p, const double q, const double beta, + const double alpha) { + if (alpha < 0.5) { + // Inverse Incomplete beta function is symmetrical, return the complement. + return 1. - BetaIncompleteInvImpl(q, p, beta, 1. - alpha); + } + const double kErr = 1e-14; + double value = kErr; + + // Compute the initial estimate. + { + double r = std::sqrt(-std::log(alpha * alpha)); + double y = + r - fma(r, 0.27061, 2.30753) / fma(r, fma(r, 0.04481, 0.99229), 1.0); + if (p > 1. && q > 1.) { + r = (y * y - 3.) / 6.; + double s = 1. / (p + p - 1.); + double t = 1. / (q + q - 1.); + double h = 2. / s + t; + double w = + y * std::sqrt(h + r) / h - (t - s) * (r + 5. / 6. - t / (3. * h)); + value = p / (p + q * std::exp(w + w)); + } else { + r = q + q; + double t = 1.0 / (9. * q); + double u = 1.0 - t + y * std::sqrt(t); + t = r * (u * u * u); + if (t <= 0) { + value = 1.0 - std::exp((std::log((1.0 - alpha) * q) + beta) / q); + } else { + t = (4.0 * p + r - 2.0) / t; + if (t <= 1) { + value = std::exp((std::log(alpha * p) + beta) / p); + } else { + value = 1.0 - 2.0 / (t + 1.0); + } + } + } + } + + // Solve for x using a modified newton-raphson method using the function + // BetaIncomplete. + { + value = std::max(value, kErr); + value = std::min(value, 1.0 - kErr); + + const double r = 1.0 - p; + const double t = 1.0 - q; + double y; + double yprev = 0; + double sq = 1; + double prev = 1; + for (;;) { + if (value < 0 || value > 1.0) { + // Error case; value went infinite. + return std::numeric_limits<double>::infinity(); + } else if (value == 0 || value == 1) { + y = value; + } else { + y = BetaIncompleteImpl(value, p, q, beta); + if (!std::isfinite(y)) { + return y; + } + } + y = (y - alpha) * + std::exp(beta + r * std::log(value) + t * std::log(1.0 - value)); + if (y * yprev <= 0) { + prev = std::max(sq, std::numeric_limits<double>::min()); + } + double g = 1.0; + for (;;) { + const double adj = g * y; + const double adj_sq = adj * adj; + if (adj_sq >= prev) { + g = g / 3.0; + continue; + } + const double tx = value - adj; + if (tx < 0 || tx > 1) { + g = g / 3.0; + continue; + } + if (prev < kErr) { + return value; + } + if (y * y < kErr) { + return value; + } + if (tx == value) { + return value; + } + if (tx == 0 || tx == 1) { + g = g / 3.0; + continue; + } + value = tx; + yprev = y; + break; + } + } + } + + // NOTES: See also: Asymptotic inversion of the incomplete beta function. + // https://core.ac.uk/download/pdf/82140723.pdf + // + // NOTE: See the Boost library documentation as well: + // https://www.boost.org/doc/libs/1_52_0/libs/math/doc/sf_and_dist/html/math_toolkit/special/sf_beta/ibeta_function.html +} + +} // namespace + +double BetaIncomplete(const double x, const double p, const double q) { + // Error cases. + if (p < 0 || q < 0 || x < 0 || x > 1.0) { + return std::numeric_limits<double>::infinity(); + } + if (x == 0 || x == 1) { + return x; + } + // ln(Beta(p, q)) + double beta = std::lgamma(p) + std::lgamma(q) - std::lgamma(p + q); + return BetaIncompleteImpl(x, p, q, beta); +} + +double BetaIncompleteInv(const double p, const double q, const double alpha) { + // Error cases. + if (p < 0 || q < 0 || alpha < 0 || alpha > 1.0) { + return std::numeric_limits<double>::infinity(); + } + if (alpha == 0 || alpha == 1) { + return alpha; + } + // ln(Beta(p, q)) + double beta = std::lgamma(p) + std::lgamma(q) - std::lgamma(p + q); + return BetaIncompleteInvImpl(p, q, beta, alpha); +} + +// Given `num_trials` trials each with probability `p` of success, the +// probability of no failures is `p^k`. To ensure the probability of a failure +// is no more than `p_fail`, it must be that `p^k == 1 - p_fail`. This function +// computes `p` from that equation. +double RequiredSuccessProbability(const double p_fail, const int num_trials) { + double p = std::exp(std::log(1.0 - p_fail) / static_cast<double>(num_trials)); + ABSL_ASSERT(p > 0); + return p; +} + +double ZScore(double expected_mean, const DistributionMoments& moments) { + return (moments.mean - expected_mean) / + (std::sqrt(moments.variance) / + std::sqrt(static_cast<double>(moments.n))); +} + +double MaxErrorTolerance(double acceptance_probability) { + double one_sided_pvalue = 0.5 * (1.0 - acceptance_probability); + const double max_err = InverseNormalSurvival(one_sided_pvalue); + ABSL_ASSERT(max_err > 0); + return max_err; +} + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/internal/distribution_test_util.h b/absl/random/internal/distribution_test_util.h new file mode 100644 index 000000000000..b5ba49fa7e1e --- /dev/null +++ b/absl/random/internal/distribution_test_util.h @@ -0,0 +1,111 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_DISTRIBUTION_TEST_UTIL_H_ +#define ABSL_RANDOM_INTERNAL_DISTRIBUTION_TEST_UTIL_H_ + +#include <cstddef> +#include <iostream> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +// NOTE: The functions in this file are test only, and are should not be used in +// non-test code. + +namespace absl { +namespace random_internal { + +// http://webspace.ship.edu/pgmarr/Geo441/Lectures/Lec%205%20-%20Normality%20Testing.pdf + +// Compute the 1st to 4th standard moments: +// mean, variance, skewness, and kurtosis. +// http://www.itl.nist.gov/div898/handbook/eda/section3/eda35b.htm +struct DistributionMoments { + size_t n = 0; + double mean = 0.0; + double variance = 0.0; + double skewness = 0.0; + double kurtosis = 0.0; +}; +DistributionMoments ComputeDistributionMoments( + absl::Span<const double> data_points); + +std::ostream& operator<<(std::ostream& os, const DistributionMoments& moments); + +// Computes the Z-score for a set of data with the given distribution moments +// compared against `expected_mean`. +double ZScore(double expected_mean, const DistributionMoments& moments); + +// Returns the probability of success required for a single trial to ensure that +// after `num_trials` trials, the probability of at least one failure is no more +// than `p_fail`. +double RequiredSuccessProbability(double p_fail, int num_trials); + +// Computes the maximum distance from the mean tolerable, for Z-Tests that are +// expected to pass with `acceptance_probability`. Will terminate if the +// resulting tolerance is zero (due to passing in 0.0 for +// `acceptance_probability` or rounding errors). +// +// For example, +// MaxErrorTolerance(0.001) = 0.0 +// MaxErrorTolerance(0.5) = ~0.47 +// MaxErrorTolerance(1.0) = inf +double MaxErrorTolerance(double acceptance_probability); + +// Approximation to inverse of the Error Function in double precision. +// (http://people.maths.ox.ac.uk/gilesm/files/gems_erfinv.pdf) +double erfinv(double x); + +// Beta(p, q) = Gamma(p) * Gamma(q) / Gamma(p+q) +double beta(double p, double q); + +// The inverse of the normal survival function. +double InverseNormalSurvival(double x); + +// Returns whether actual is "near" expected, based on the bound. +bool Near(absl::string_view msg, double actual, double expected, double bound); + +// Implements the incomplete regularized beta function, AS63, BETAIN. +// https://www.jstor.org/stable/2346797 +// +// BetaIncomplete(x, p, q), where +// `x` is the value of the upper limit +// `p` is beta parameter p, `q` is beta parameter q. +// +// NOTE: This is a test-only function which is only accurate to within, at most, +// 1e-13 of the actual value. +// +double BetaIncomplete(double x, double p, double q); + +// Implements the inverse of the incomplete regularized beta function, AS109, +// XINBTA. +// https://www.jstor.org/stable/2346798 +// https://www.jstor.org/stable/2346887 +// +// BetaIncompleteInv(p, q, beta, alhpa) +// `p` is beta parameter p, `q` is beta parameter q. +// `alpha` is the value of the lower tail area. +// +// NOTE: This is a test-only function and, when successful, is only accurate to +// within ~1e-6 of the actual value; there are some cases where it diverges from +// the actual value by much more than that. The function uses Newton's method, +// and thus the runtime is highly variable. +double BetaIncompleteInv(double p, double q, double alpha); + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_DISTRIBUTION_TEST_UTIL_H_ diff --git a/absl/random/internal/distribution_test_util_test.cc b/absl/random/internal/distribution_test_util_test.cc new file mode 100644 index 000000000000..c49d44fb4796 --- /dev/null +++ b/absl/random/internal/distribution_test_util_test.cc @@ -0,0 +1,193 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/distribution_test_util.h" + +#include "gtest/gtest.h" + +namespace { + +TEST(TestUtil, InverseErf) { + const struct { + const double z; + const double value; + } kErfInvTable[] = { + {0.0000001, 8.86227e-8}, + {0.00001, 8.86227e-6}, + {0.5, 0.4769362762044}, + {0.6, 0.5951160814499}, + {0.99999, 3.1234132743}, + {0.9999999, 3.7665625816}, + {0.999999944, 3.8403850690566985}, // = log((1-x) * (1+x)) =~ 16.004 + {0.999999999, 4.3200053849134452}, + }; + + for (const auto& data : kErfInvTable) { + auto value = absl::random_internal::erfinv(data.z); + + // Log using the Wolfram-alpha function name & parameters. + EXPECT_NEAR(value, data.value, 1e-8) + << " InverseErf[" << data.z << "] (expected=" << data.value << ") -> " + << value; + } +} + +const struct { + const double p; + const double q; + const double x; + const double alpha; +} kBetaTable[] = { + {0.5, 0.5, 0.01, 0.06376856085851985}, + {0.5, 0.5, 0.1, 0.2048327646991335}, + {0.5, 0.5, 1, 1}, + {1, 0.5, 0, 0}, + {1, 0.5, 0.01, 0.005012562893380045}, + {1, 0.5, 0.1, 0.0513167019494862}, + {1, 0.5, 0.5, 0.2928932188134525}, + {1, 1, 0.5, 0.5}, + {2, 2, 0.1, 0.028}, + {2, 2, 0.2, 0.104}, + {2, 2, 0.3, 0.216}, + {2, 2, 0.4, 0.352}, + {2, 2, 0.5, 0.5}, + {2, 2, 0.6, 0.648}, + {2, 2, 0.7, 0.784}, + {2, 2, 0.8, 0.896}, + {2, 2, 0.9, 0.972}, + {5.5, 5, 0.5, 0.4361908850559777}, + {10, 0.5, 0.9, 0.1516409096346979}, + {10, 5, 0.5, 0.08978271484375}, + {10, 5, 1, 1}, + {10, 10, 0.5, 0.5}, + {20, 5, 0.8, 0.4598773297575791}, + {20, 10, 0.6, 0.2146816102371739}, + {20, 10, 0.8, 0.9507364826957875}, + {20, 20, 0.5, 0.5}, + {20, 20, 0.6, 0.8979413687105918}, + {30, 10, 0.7, 0.2241297491808366}, + {30, 10, 0.8, 0.7586405487192086}, + {40, 20, 0.7, 0.7001783247477069}, + {1, 0.5, 0.1, 0.0513167019494862}, + {1, 0.5, 0.2, 0.1055728090000841}, + {1, 0.5, 0.3, 0.1633399734659245}, + {1, 0.5, 0.4, 0.2254033307585166}, + {1, 2, 0.2, 0.36}, + {1, 3, 0.2, 0.488}, + {1, 4, 0.2, 0.5904}, + {1, 5, 0.2, 0.67232}, + {2, 2, 0.3, 0.216}, + {3, 2, 0.3, 0.0837}, + {4, 2, 0.3, 0.03078}, + {5, 2, 0.3, 0.010935}, + + // These values test small & large points along the range of the Beta + // function. + // + // When selecting test points, remember that if BetaIncomplete(x, p, q) + // returns the same value to within the limits of precision over a large + // domain of the input, x, then BetaIncompleteInv(alpha, p, q) may return an + // essentially arbitrary value where BetaIncomplete(x, p, q) =~ alpha. + + // BetaRegularized[x, 0.00001, 0.00001], + // For x in {~0.001 ... ~0.999}, => ~0.5 + {1e-5, 1e-5, 1e-5, 0.4999424388184638311}, + {1e-5, 1e-5, (1.0 - 1e-8), 0.5000920948389232964}, + + // BetaRegularized[x, 0.00001, 10000]. + // For x in {~epsilon ... 1.0}, => ~1 + {1e-5, 1e5, 1e-6, 0.9999817708130066936}, + {1e-5, 1e5, (1.0 - 1e-7), 1.0}, + + // BetaRegularized[x, 10000, 0.00001]. + // For x in {0 .. 1-epsilon}, => ~0 + {1e5, 1e-5, 1e-6, 0}, + {1e5, 1e-5, (1.0 - 1e-6), 1.8229186993306369e-5}, +}; + +TEST(BetaTest, BetaIncomplete) { + for (const auto& data : kBetaTable) { + auto value = absl::random_internal::BetaIncomplete(data.x, data.p, data.q); + + // Log using the Wolfram-alpha function name & parameters. + EXPECT_NEAR(value, data.alpha, 1e-12) + << " BetaRegularized[" << data.x << ", " << data.p << ", " << data.q + << "] (expected=" << data.alpha << ") -> " << value; + } +} + +TEST(BetaTest, BetaIncompleteInv) { + for (const auto& data : kBetaTable) { + auto value = + absl::random_internal::BetaIncompleteInv(data.p, data.q, data.alpha); + + // Log using the Wolfram-alpha function name & parameters. + EXPECT_NEAR(value, data.x, 1e-6) + << " InverseBetaRegularized[" << data.alpha << ", " << data.p << ", " + << data.q << "] (expected=" << data.x << ") -> " << value; + } +} + +TEST(MaxErrorTolerance, MaxErrorTolerance) { + std::vector<std::pair<double, double>> cases = { + {0.0000001, 8.86227e-8 * 1.41421356237}, + {0.00001, 8.86227e-6 * 1.41421356237}, + {0.5, 0.4769362762044 * 1.41421356237}, + {0.6, 0.5951160814499 * 1.41421356237}, + {0.99999, 3.1234132743 * 1.41421356237}, + {0.9999999, 3.7665625816 * 1.41421356237}, + {0.999999944, 3.8403850690566985 * 1.41421356237}, + {0.999999999, 4.3200053849134452 * 1.41421356237}}; + for (auto entry : cases) { + EXPECT_NEAR(absl::random_internal::MaxErrorTolerance(entry.first), + entry.second, 1e-8); + } +} + +TEST(ZScore, WithSameMean) { + absl::random_internal::DistributionMoments m; + m.n = 100; + m.mean = 5; + m.variance = 1; + EXPECT_NEAR(absl::random_internal::ZScore(5, m), 0, 1e-12); + + m.n = 1; + m.mean = 0; + m.variance = 1; + EXPECT_NEAR(absl::random_internal::ZScore(0, m), 0, 1e-12); + + m.n = 10000; + m.mean = -5; + m.variance = 100; + EXPECT_NEAR(absl::random_internal::ZScore(-5, m), 0, 1e-12); +} + +TEST(ZScore, DifferentMean) { + absl::random_internal::DistributionMoments m; + m.n = 100; + m.mean = 5; + m.variance = 1; + EXPECT_NEAR(absl::random_internal::ZScore(4, m), 10, 1e-12); + + m.n = 1; + m.mean = 0; + m.variance = 1; + EXPECT_NEAR(absl::random_internal::ZScore(-1, m), 1, 1e-12); + + m.n = 10000; + m.mean = -5; + m.variance = 100; + EXPECT_NEAR(absl::random_internal::ZScore(-4, m), -10, 1e-12); +} +} // namespace diff --git a/absl/random/internal/distributions.h b/absl/random/internal/distributions.h new file mode 100644 index 000000000000..34db3b326f4c --- /dev/null +++ b/absl/random/internal/distributions.h @@ -0,0 +1,82 @@ +// Copyright 2019 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_DISTRIBUTIONS_H_ +#define ABSL_RANDOM_INTERNAL_DISTRIBUTIONS_H_ + +#include <type_traits> + +#include "absl/meta/type_traits.h" +#include "absl/random/internal/distribution_caller.h" +#include "absl/random/internal/traits.h" +#include "absl/random/internal/uniform_helper.h" + +namespace absl { +namespace random_internal { +template <typename D> +struct DistributionFormatTraits; + +// UniformImpl implements the core logic of the Uniform<T> call, which is to +// select the correct distribution type, compute the bounds based on the +// interval tag, and then generate a value. +template <typename NumType, typename TagType, typename URBG> +NumType UniformImpl(TagType tag, + URBG& urbg, // NOLINT(runtime/references) + NumType lo, NumType hi) { + static_assert( + std::is_arithmetic<NumType>::value, + "absl::Uniform<T>() must use an integer or real parameter type."); + + using distribution_t = + typename std::conditional<std::is_integral<NumType>::value, + absl::uniform_int_distribution<NumType>, + absl::uniform_real_distribution<NumType>>::type; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + auto a = random_internal::uniform_lower_bound<NumType>(tag, lo, hi); + auto b = random_internal::uniform_upper_bound<NumType>(tag, lo, hi); + // TODO(lar): it doesn't make a lot of sense to ask for a random number in an + // empty range. Right now we just return a boundary--even though that + // boundary is not an acceptable value! Is there something better we can do + // here? + + using gen_t = absl::decay_t<URBG>; + if (a > b) return a; + return DistributionCaller<gen_t>::template Call<distribution_t, format_t>( + &urbg, a, b); +} + +// In the absence of an explicitly provided return-type, the template +// "uniform_inferred_return_t<A, B>" is used to derive a suitable type, based on +// the data-types of the endpoint-arguments {A lo, B hi}. +// +// Given endpoints {A lo, B hi}, one of {A, B} will be chosen as the +// return-type, if one type can be implicitly converted into the other, in a +// lossless way. The template "is_widening_convertible" implements the +// compile-time logic for deciding if such a conversion is possible. +// +// If no such conversion between {A, B} exists, then the overload for +// absl::Uniform() will be discarded, and the call will be ill-formed. +// Return-type for absl::Uniform() when the return-type is inferred. +template <typename A, typename B> +using uniform_inferred_return_t = + absl::enable_if_t<absl::disjunction<is_widening_convertible<A, B>, + is_widening_convertible<B, A>>::value, + typename std::conditional< + is_widening_convertible<A, B>::value, B, A>::type>; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_DISTRIBUTIONS_H_ diff --git a/absl/random/internal/explicit_seed_seq.h b/absl/random/internal/explicit_seed_seq.h new file mode 100644 index 000000000000..b660ece50017 --- /dev/null +++ b/absl/random/internal/explicit_seed_seq.h @@ -0,0 +1,87 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_EXPLICIT_SEED_SEQ_H_ +#define ABSL_RANDOM_INTERNAL_EXPLICIT_SEED_SEQ_H_ + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <initializer_list> +#include <iterator> +#include <vector> + +namespace absl { +namespace random_internal { + +// This class conforms to the C++ Standard "Seed Sequence" concept +// [rand.req.seedseq]. +// +// An "ExplicitSeedSeq" is meant to provide a conformant interface for +// forwarding pre-computed seed material to the constructor of a class +// conforming to the "Uniform Random Bit Generator" concept. This class makes no +// attempt to mutate the state provided by its constructor, and returns it +// directly via ExplicitSeedSeq::generate(). +// +// If this class is asked to generate more seed material than was provided to +// the constructor, then the remaining bytes will be filled with deterministic, +// nonrandom data. +class ExplicitSeedSeq { + public: + using result_type = uint32_t; + + ExplicitSeedSeq() : state_() {} + + // Copy and move both allowed. + ExplicitSeedSeq(const ExplicitSeedSeq& other) = default; + ExplicitSeedSeq& operator=(const ExplicitSeedSeq& other) = default; + ExplicitSeedSeq(ExplicitSeedSeq&& other) = default; + ExplicitSeedSeq& operator=(ExplicitSeedSeq&& other) = default; + + template <typename Iterator> + ExplicitSeedSeq(Iterator begin, Iterator end) { + for (auto it = begin; it != end; it++) { + state_.push_back(*it & 0xffffffff); + } + } + + template <typename T> + ExplicitSeedSeq(std::initializer_list<T> il) + : ExplicitSeedSeq(il.begin(), il.end()) {} + + size_t size() const { return state_.size(); } + + template <typename OutIterator> + void param(OutIterator out) const { + std::copy(std::begin(state_), std::end(state_), out); + } + + template <typename OutIterator> + void generate(OutIterator begin, OutIterator end) { + for (size_t index = 0; begin != end; begin++) { + *begin = state_.empty() ? 0 : state_[index++]; + if (index >= state_.size()) { + index = 0; + } + } + } + + protected: + std::vector<uint32_t> state_; +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_EXPLICIT_SEED_SEQ_H_ diff --git a/absl/random/internal/explicit_seed_seq_test.cc b/absl/random/internal/explicit_seed_seq_test.cc new file mode 100644 index 000000000000..a55ad73948b9 --- /dev/null +++ b/absl/random/internal/explicit_seed_seq_test.cc @@ -0,0 +1,204 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/explicit_seed_seq.h" + +#include <iterator> +#include <random> +#include <utility> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/random/seed_sequences.h" + +namespace { + +template <typename Sseq> +bool ConformsToInterface() { + // Check that the SeedSequence can be default-constructed. + { Sseq default_constructed_seq; } + // Check that the SeedSequence can be constructed with two iterators. + { + uint32_t init_array[] = {1, 3, 5, 7, 9}; + Sseq iterator_constructed_seq(init_array, &init_array[5]); + } + // Check that the SeedSequence can be std::initializer_list-constructed. + { Sseq list_constructed_seq = {1, 3, 5, 7, 9, 11, 13}; } + // Check that param() and size() return state provided to constructor. + { + uint32_t init_array[] = {1, 2, 3, 4, 5}; + Sseq seq(init_array, &init_array[ABSL_ARRAYSIZE(init_array)]); + EXPECT_EQ(seq.size(), ABSL_ARRAYSIZE(init_array)); + + uint32_t state_array[ABSL_ARRAYSIZE(init_array)]; + seq.param(state_array); + + for (int i = 0; i < ABSL_ARRAYSIZE(state_array); i++) { + EXPECT_EQ(state_array[i], i + 1); + } + } + // Check for presence of generate() method. + { + Sseq seq; + uint32_t seeds[5]; + + seq.generate(seeds, &seeds[ABSL_ARRAYSIZE(seeds)]); + } + return true; +} +} // namespace + +TEST(SeedSequences, CheckInterfaces) { + // Control case + EXPECT_TRUE(ConformsToInterface<std::seed_seq>()); + + // Abseil classes + EXPECT_TRUE(ConformsToInterface<absl::random_internal::ExplicitSeedSeq>()); +} + +TEST(ExplicitSeedSeq, DefaultConstructorGeneratesZeros) { + const size_t kNumBlocks = 128; + + uint32_t outputs[kNumBlocks]; + absl::random_internal::ExplicitSeedSeq seq; + seq.generate(outputs, &outputs[kNumBlocks]); + + for (uint32_t& seed : outputs) { + EXPECT_EQ(seed, 0); + } +} + +TEST(ExplicitSeeqSeq, SeedMaterialIsForwardedIdentically) { + const size_t kNumBlocks = 128; + + uint32_t seed_material[kNumBlocks]; + std::random_device urandom{"/dev/urandom"}; + for (uint32_t& seed : seed_material) { + seed = urandom(); + } + absl::random_internal::ExplicitSeedSeq seq(seed_material, + &seed_material[kNumBlocks]); + + // Check that output is same as seed-material provided to constructor. + { + const size_t kNumGenerated = kNumBlocks / 2; + uint32_t outputs[kNumGenerated]; + seq.generate(outputs, &outputs[kNumGenerated]); + for (size_t i = 0; i < kNumGenerated; i++) { + EXPECT_EQ(outputs[i], seed_material[i]); + } + } + // Check that SeedSequence is stateless between invocations: Despite the last + // invocation of generate() only consuming half of the input-entropy, the same + // entropy will be recycled for the next invocation. + { + const size_t kNumGenerated = kNumBlocks; + uint32_t outputs[kNumGenerated]; + seq.generate(outputs, &outputs[kNumGenerated]); + for (size_t i = 0; i < kNumGenerated; i++) { + EXPECT_EQ(outputs[i], seed_material[i]); + } + } + // Check that when more seed-material is asked for than is provided, nonzero + // values are still written. + { + const size_t kNumGenerated = kNumBlocks * 2; + uint32_t outputs[kNumGenerated]; + seq.generate(outputs, &outputs[kNumGenerated]); + for (size_t i = 0; i < kNumGenerated; i++) { + EXPECT_EQ(outputs[i], seed_material[i % kNumBlocks]); + } + } +} + +TEST(ExplicitSeedSeq, CopyAndMoveConstructors) { + using testing::Each; + using testing::Eq; + using testing::Not; + using testing::Pointwise; + + uint32_t entropy[4]; + std::random_device urandom("/dev/urandom"); + for (uint32_t& entry : entropy) { + entry = urandom(); + } + absl::random_internal::ExplicitSeedSeq seq_from_entropy(std::begin(entropy), + std::end(entropy)); + // Copy constructor. + { + absl::random_internal::ExplicitSeedSeq seq_copy(seq_from_entropy); + EXPECT_EQ(seq_copy.size(), seq_from_entropy.size()); + + std::vector<uint32_t> seeds_1; + seeds_1.resize(1000, 0); + std::vector<uint32_t> seeds_2; + seeds_2.resize(1000, 1); + + seq_from_entropy.generate(seeds_1.begin(), seeds_1.end()); + seq_copy.generate(seeds_2.begin(), seeds_2.end()); + + EXPECT_THAT(seeds_1, Pointwise(Eq(), seeds_2)); + } + // Assignment operator. + { + for (uint32_t& entry : entropy) { + entry = urandom(); + } + absl::random_internal::ExplicitSeedSeq another_seq(std::begin(entropy), + std::end(entropy)); + + std::vector<uint32_t> seeds_1; + seeds_1.resize(1000, 0); + std::vector<uint32_t> seeds_2; + seeds_2.resize(1000, 0); + + seq_from_entropy.generate(seeds_1.begin(), seeds_1.end()); + another_seq.generate(seeds_2.begin(), seeds_2.end()); + + // Assert precondition: Sequences generated by seed-sequences are not equal. + EXPECT_THAT(seeds_1, Not(Pointwise(Eq(), seeds_2))); + + // Apply the assignment-operator. + another_seq = seq_from_entropy; + + // Re-generate seeds. + seq_from_entropy.generate(seeds_1.begin(), seeds_1.end()); + another_seq.generate(seeds_2.begin(), seeds_2.end()); + + // Seeds generated by seed-sequences should now be equal. + EXPECT_THAT(seeds_1, Pointwise(Eq(), seeds_2)); + } + // Move constructor. + { + // Get seeds from seed-sequence constructed from entropy. + std::vector<uint32_t> seeds_1; + seeds_1.resize(1000, 0); + seq_from_entropy.generate(seeds_1.begin(), seeds_1.end()); + + // Apply move-constructor move the sequence to another instance. + absl::random_internal::ExplicitSeedSeq moved_seq( + std::move(seq_from_entropy)); + std::vector<uint32_t> seeds_2; + seeds_2.resize(1000, 1); + moved_seq.generate(seeds_2.begin(), seeds_2.end()); + // Verify that seeds produced by moved-instance are the same as original. + EXPECT_THAT(seeds_1, Pointwise(Eq(), seeds_2)); + + // Verify that the moved-from instance now behaves like a + // default-constructed instance. + EXPECT_EQ(seq_from_entropy.size(), 0); + seq_from_entropy.generate(seeds_1.begin(), seeds_1.end()); + EXPECT_THAT(seeds_1, Each(Eq(0))); + } +} diff --git a/absl/random/internal/fast_uniform_bits.h b/absl/random/internal/fast_uniform_bits.h new file mode 100644 index 000000000000..23eabbc8444c --- /dev/null +++ b/absl/random/internal/fast_uniform_bits.h @@ -0,0 +1,299 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_ +#define ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_ + +#include <cstddef> +#include <cstdint> +#include <limits> +#include <type_traits> + +namespace absl { +namespace random_internal { +// Computes the length of the range of values producible by the URBG, or returns +// zero if that would encompass the entire range of representable values in +// URBG::result_type. +template <typename URBG> +constexpr typename URBG::result_type constexpr_range() { + using result_type = typename URBG::result_type; + return ((URBG::max)() == (std::numeric_limits<result_type>::max)() && + (URBG::min)() == std::numeric_limits<result_type>::lowest()) + ? result_type{0} + : (URBG::max)() - (URBG::min)() + result_type{1}; +} + +// FastUniformBits implements a fast path to acquire uniform independent bits +// from a type which conforms to the [rand.req.urbg] concept. +// Parameterized by: +// `UIntType`: the result (output) type +// `Width`: binary output width +// +// The std::independent_bits_engine [rand.adapt.ibits] adaptor can be +// instantiated from an existing generator through a copy or a move. It does +// not, however, facilitate the production of pseudorandom bits from an un-owned +// generator that will outlive the std::independent_bits_engine instance. +template <typename UIntType = uint64_t, + size_t Width = std::numeric_limits<UIntType>::digits> +class FastUniformBits { + static_assert(std::is_unsigned<UIntType>::value, + "Class-template FastUniformBits<> must be parameterized using " + "an unsigned type."); + + // `kWidth` is the width, in binary digits, of the output. By default it is + // the number of binary digits in the `result_type`. + static constexpr size_t kWidth = Width; + static_assert(kWidth > 0, + "Class-template FastUniformBits<> Width argument must be > 0"); + + static_assert(kWidth <= std::numeric_limits<UIntType>::digits, + "Class-template FastUniformBits<> Width argument must be <= " + "width of UIntType."); + + static constexpr bool kIsMaxWidth = + (kWidth >= std::numeric_limits<UIntType>::digits); + + // Computes a mask of `n` bits for the `UIntType`. + static constexpr UIntType constexpr_mask(size_t n) { + return (UIntType(1) << n) - 1; + } + + public: + using result_type = UIntType; + + static constexpr result_type(min)() { return 0; } + static constexpr result_type(max)() { + return kIsMaxWidth ? (std::numeric_limits<result_type>::max)() + : constexpr_mask(kWidth); + } + + template <typename URBG> + result_type operator()(URBG& g); // NOLINT(runtime/references) + + private: + // Variate() generates a single random variate, always returning a value + // in the closed interval [0 ... FastUniformBitsURBGConstants::kRangeMask] + // (kRangeMask+1 is a power of 2). + template <typename URBG> + typename URBG::result_type Variate(URBG& g); // NOLINT(runtime/references) + + // generate() generates a random value, dispatched on whether + // the underlying URNG must loop over multiple calls or not. + template <typename URBG> + result_type Generate(URBG& g, // NOLINT(runtime/references) + std::true_type /* avoid_looping */); + + template <typename URBG> + result_type Generate(URBG& g, // NOLINT(runtime/references) + std::false_type /* avoid_looping */); +}; + +// FastUniformBitsURBGConstants computes the URBG-derived constants used +// by FastUniformBits::Generate and FastUniformBits::Variate. +// Parameterized by the FastUniformBits parameter: +// `URBG`: The underlying UniformRandomNumberGenerator. +// +// The values here indicate the URBG range as well as providing an indicator +// whether the URBG output is a power of 2, and kRangeMask, which allows masking +// the generated output to kRangeBits. +template <typename URBG> +class FastUniformBitsURBGConstants { + // Computes the floor of the log. (i.e., std::floor(std::log2(N)); + static constexpr size_t constexpr_log2(size_t n) { + return (n <= 1) ? 0 : 1 + constexpr_log2(n / 2); + } + + // Computes a mask of n bits for the URBG::result_type. + static constexpr typename URBG::result_type constexpr_mask(size_t n) { + return (typename URBG::result_type(1) << n) - 1; + } + + public: + using result_type = typename URBG::result_type; + + // The range of the URNG, max - min + 1, or zero if that result would cause + // overflow. + static constexpr result_type kRange = constexpr_range<URBG>(); + + static constexpr bool kPowerOfTwo = + (kRange == 0) || ((kRange & (kRange - 1)) == 0); + + // kRangeBits describes the number number of bits suitable to mask off of URNG + // variate, which is: + // kRangeBits = floor(log2(kRange)) + static constexpr size_t kRangeBits = + kRange == 0 ? std::numeric_limits<result_type>::digits + : constexpr_log2(kRange); + + // kRangeMask is the mask used when sampling variates from the URNG when the + // width of the URNG range is not a power of 2. + // Y = (2 ^ kRange) - 1 + static constexpr result_type kRangeMask = + kRange == 0 ? (std::numeric_limits<result_type>::max)() + : constexpr_mask(kRangeBits); + + static_assert((URBG::max)() != (URBG::min)(), + "Class-template FastUniformBitsURBGConstants<> " + "URBG::max and URBG::min may not be equal."); + + static_assert(std::is_unsigned<result_type>::value, + "Class-template FastUniformBitsURBGConstants<> " + "URBG::result_type must be unsigned."); + + static_assert(kRangeMask > 0, + "Class-template FastUniformBitsURBGConstants<> " + "URBG does not generate sufficient random bits."); + + static_assert(kRange == 0 || + kRangeBits < std::numeric_limits<result_type>::digits, + "Class-template FastUniformBitsURBGConstants<> " + "URBG range computation error."); +}; + +// FastUniformBitsLoopingConstants computes the looping constants used +// by FastUniformBits::Generate. These constants indicate how multiple +// URBG::result_type values are combined into an output_value. +// Parameterized by the FastUniformBits parameters: +// `UIntType`: output type. +// `Width`: binary output width, +// `URNG`: The underlying UniformRandomNumberGenerator. +// +// The looping constants describe the sets of loop counters and mask values +// which control how individual variates are combined the final output. The +// algorithm ensures that the number of bits used by any individual call differs +// by at-most one bit from any other call. This is simplified into constants +// which describe two loops, with the second loop parameters providing one extra +// bit per variate. +// +// See [rand.adapt.ibits] for more details on the use of these constants. +template <typename UIntType, size_t Width, typename URBG> +class FastUniformBitsLoopingConstants { + private: + static constexpr size_t kWidth = Width; + using urbg_result_type = typename URBG::result_type; + using uint_result_type = UIntType; + + public: + using result_type = + typename std::conditional<(sizeof(urbg_result_type) <= + sizeof(uint_result_type)), + uint_result_type, urbg_result_type>::type; + + private: + // Estimate N as ceil(width / urng width), and W0 as (width / N). + static constexpr size_t kRangeBits = + FastUniformBitsURBGConstants<URBG>::kRangeBits; + + // The range of the URNG, max - min + 1, or zero if that result would cause + // overflow. + static constexpr result_type kRange = constexpr_range<URBG>(); + static constexpr size_t kEstimateN = + kWidth / kRangeBits + (kWidth % kRangeBits != 0); + static constexpr size_t kEstimateW0 = kWidth / kEstimateN; + static constexpr result_type kEstimateY0 = (kRange >> kEstimateW0) + << kEstimateW0; + + public: + // Parameters for the two loops: + // kN0, kN1 are the number of underlying calls required for each loop. + // KW0, kW1 are shift widths for each loop. + // + static constexpr size_t kN1 = (kRange - kEstimateY0) > + (kEstimateY0 / kEstimateN) + ? kEstimateN + 1 + : kEstimateN; + static constexpr size_t kN0 = kN1 - (kWidth % kN1); + static constexpr size_t kW0 = kWidth / kN1; + static constexpr size_t kW1 = kW0 + 1; + + static constexpr result_type kM0 = (result_type(1) << kW0) - 1; + static constexpr result_type kM1 = (result_type(1) << kW1) - 1; + + static_assert( + kW0 <= kRangeBits, + "Class-template FastUniformBitsLoopingConstants::kW0 too large."); + + static_assert( + kW0 > 0, + "Class-template FastUniformBitsLoopingConstants::kW0 too small."); +}; + +template <typename UIntType, size_t Width> +template <typename URBG> +typename FastUniformBits<UIntType, Width>::result_type +FastUniformBits<UIntType, Width>::operator()( + URBG& g) { // NOLINT(runtime/references) + using constants = FastUniformBitsURBGConstants<URBG>; + return Generate( + g, std::integral_constant<bool, constants::kRangeMask >= (max)()>{}); +} + +template <typename UIntType, size_t Width> +template <typename URBG> +typename URBG::result_type FastUniformBits<UIntType, Width>::Variate( + URBG& g) { // NOLINT(runtime/references) + using constants = FastUniformBitsURBGConstants<URBG>; + if (constants::kPowerOfTwo) { + return g() - (URBG::min)(); + } + + // Use rejection sampling to ensure uniformity across the range. + typename URBG::result_type u; + do { + u = g() - (URBG::min)(); + } while (u > constants::kRangeMask); + return u; +} + +template <typename UIntType, size_t Width> +template <typename URBG> +typename FastUniformBits<UIntType, Width>::result_type +FastUniformBits<UIntType, Width>::Generate( + URBG& g, // NOLINT(runtime/references) + std::true_type /* avoid_looping */) { + // The width of the result_type is less than than the width of the random bits + // provided by URNG. Thus, generate a single value and then simply mask off + // the required bits. + return Variate(g) & (max)(); +} + +template <typename UIntType, size_t Width> +template <typename URBG> +typename FastUniformBits<UIntType, Width>::result_type +FastUniformBits<UIntType, Width>::Generate( + URBG& g, // NOLINT(runtime/references) + std::false_type /* avoid_looping */) { + // The width of the result_type is wider than the number of random bits + // provided by URNG. Thus we merge several variates of URNG into the result + // using a shift and mask. The constants type generates the parameters used + // ensure that the bits are distributed across all the invocations of the + // underlying URNG. + using constants = FastUniformBitsLoopingConstants<UIntType, Width, URBG>; + + result_type s = 0; + for (size_t n = 0; n < constants::kN0; ++n) { + auto u = Variate(g); + s = (s << constants::kW0) + (u & constants::kM0); + } + for (size_t n = constants::kN0; n < constants::kN1; ++n) { + auto u = Variate(g); + s = (s << constants::kW1) + (u & constants::kM1); + } + return s; +} + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_ diff --git a/absl/random/internal/fast_uniform_bits_test.cc b/absl/random/internal/fast_uniform_bits_test.cc new file mode 100644 index 000000000000..f4b9cd5fcc6d --- /dev/null +++ b/absl/random/internal/fast_uniform_bits_test.cc @@ -0,0 +1,290 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/fast_uniform_bits.h" + +#include <random> + +#include "gtest/gtest.h" + +namespace { + +template <typename IntType> +class FastUniformBitsTypedTest : public ::testing::Test {}; + +using IntTypes = ::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t>; + +TYPED_TEST_SUITE(FastUniformBitsTypedTest, IntTypes); + +TYPED_TEST(FastUniformBitsTypedTest, BasicTest) { + using Limits = std::numeric_limits<TypeParam>; + using FastBits = absl::random_internal::FastUniformBits<TypeParam>; + + EXPECT_EQ(0, FastBits::min()); + EXPECT_EQ(Limits::max(), FastBits::max()); + + constexpr int kIters = 10000; + std::random_device rd; + std::mt19937 gen(rd()); + FastBits fast; + for (int i = 0; i < kIters; i++) { + const auto v = fast(gen); + EXPECT_LE(v, FastBits::max()); + EXPECT_GE(v, FastBits::min()); + } +} + +TEST(FastUniformBitsTest, TypeBoundaries32) { + // Tests that FastUniformBits can adapt to 32-bit boundaries. + absl::random_internal::FastUniformBits<uint32_t, 1> a; + absl::random_internal::FastUniformBits<uint32_t, 31> b; + absl::random_internal::FastUniformBits<uint32_t, 32> c; + + { + std::mt19937 gen; // 32-bit + a(gen); + b(gen); + c(gen); + } + + { + std::mt19937_64 gen; // 64-bit + a(gen); + b(gen); + c(gen); + } +} + +TEST(FastUniformBitsTest, TypeBoundaries64) { + // Tests that FastUniformBits can adapt to 64-bit boundaries. + absl::random_internal::FastUniformBits<uint64_t, 1> a; + absl::random_internal::FastUniformBits<uint64_t, 31> b; + absl::random_internal::FastUniformBits<uint64_t, 32> c; + absl::random_internal::FastUniformBits<uint64_t, 33> d; + absl::random_internal::FastUniformBits<uint64_t, 63> e; + absl::random_internal::FastUniformBits<uint64_t, 64> f; + + { + std::mt19937 gen; // 32-bit + a(gen); + b(gen); + c(gen); + d(gen); + e(gen); + f(gen); + } + + { + std::mt19937_64 gen; // 64-bit + a(gen); + b(gen); + c(gen); + d(gen); + e(gen); + f(gen); + } +} + +class UrngOddbits { + public: + using result_type = uint8_t; + static constexpr result_type min() { return 1; } + static constexpr result_type max() { return 0xfe; } + result_type operator()() { return 2; } +}; + +class Urng4bits { + public: + using result_type = uint8_t; + static constexpr result_type min() { return 1; } + static constexpr result_type max() { return 0xf + 1; } + result_type operator()() { return 2; } +}; + +class Urng32bits { + public: + using result_type = uint32_t; + static constexpr result_type min() { return 0; } + static constexpr result_type max() { return 0xffffffff; } + result_type operator()() { return 1; } +}; + +// Compile-time test to validate the helper classes used by FastUniformBits +TEST(FastUniformBitsTest, FastUniformBitsDetails) { + using absl::random_internal::FastUniformBitsLoopingConstants; + using absl::random_internal::FastUniformBitsURBGConstants; + + // 4-bit URBG + { + using constants = FastUniformBitsURBGConstants<Urng4bits>; + static_assert(constants::kPowerOfTwo == true, + "constants::kPowerOfTwo == false"); + static_assert(constants::kRange == 16, "constants::kRange == false"); + static_assert(constants::kRangeBits == 4, "constants::kRangeBits == false"); + static_assert(constants::kRangeMask == 0x0f, + "constants::kRangeMask == false"); + } + { + using looping = FastUniformBitsLoopingConstants<uint32_t, 31, Urng4bits>; + // To get 31 bits from a 4-bit generator, issue 8 calls and extract 4 bits + // per call on all except the first. + static_assert(looping::kN0 == 1, "looping::kN0"); + static_assert(looping::kW0 == 3, "looping::kW0"); + static_assert(looping::kM0 == 0x7, "looping::kM0"); + // (The second set of calls, kN1, will not do anything.) + static_assert(looping::kN1 == 8, "looping::kN1"); + static_assert(looping::kW1 == 4, "looping::kW1"); + static_assert(looping::kM1 == 0xf, "looping::kM1"); + } + + // ~7-bit URBG + { + using constants = FastUniformBitsURBGConstants<UrngOddbits>; + static_assert(constants::kPowerOfTwo == false, + "constants::kPowerOfTwo == false"); + static_assert(constants::kRange == 0xfe, "constants::kRange == 0xfe"); + static_assert(constants::kRangeBits == 7, "constants::kRangeBits == 7"); + static_assert(constants::kRangeMask == 0x7f, + "constants::kRangeMask == 0x7f"); + } + { + using looping = FastUniformBitsLoopingConstants<uint64_t, 60, UrngOddbits>; + // To get 60 bits from a 7-bit generator, issue 10 calls and extract 6 bits + // per call, discarding the excess entropy. + static_assert(looping::kN0 == 10, "looping::kN0"); + static_assert(looping::kW0 == 6, "looping::kW0"); + static_assert(looping::kM0 == 0x3f, "looping::kM0"); + // (The second set of calls, kN1, will not do anything.) + static_assert(looping::kN1 == 10, "looping::kN1"); + static_assert(looping::kW1 == 7, "looping::kW1"); + static_assert(looping::kM1 == 0x7f, "looping::kM1"); + } + { + using looping = FastUniformBitsLoopingConstants<uint64_t, 63, UrngOddbits>; + // To get 63 bits from a 7-bit generator, issue 10 calls--the same as we + // would issue for 60 bits--however this time we use two groups. The first + // group (kN0) will issue 7 calls, extracting 6 bits per call. + static_assert(looping::kN0 == 7, "looping::kN0"); + static_assert(looping::kW0 == 6, "looping::kW0"); + static_assert(looping::kM0 == 0x3f, "looping::kM0"); + // The second group (kN1) will issue 3 calls, extracting 7 bits per call. + static_assert(looping::kN1 == 10, "looping::kN1"); + static_assert(looping::kW1 == 7, "looping::kW1"); + static_assert(looping::kM1 == 0x7f, "looping::kM1"); + } +} + +TEST(FastUniformBitsTest, Urng4_VariousOutputs) { + // Tests that how values are composed; the single-bit deltas should be spread + // across each invocation. + Urng4bits urng4; + Urng32bits urng32; + + // 8-bit types + { + absl::random_internal::FastUniformBits<uint8_t, 1> fast1; + EXPECT_EQ(0x1, fast1(urng4)); + EXPECT_EQ(0x1, fast1(urng32)); + } + { + absl::random_internal::FastUniformBits<uint8_t, 2> fast2; + EXPECT_EQ(0x1, fast2(urng4)); + EXPECT_EQ(0x1, fast2(urng32)); + } + + { + absl::random_internal::FastUniformBits<uint8_t, 4> fast4; + EXPECT_EQ(0x1, fast4(urng4)); + EXPECT_EQ(0x1, fast4(urng32)); + } + { + absl::random_internal::FastUniformBits<uint8_t, 6> fast6; + EXPECT_EQ(0x9, fast6(urng4)); // b001001 (2x3) + EXPECT_EQ(0x1, fast6(urng32)); + } + { + absl::random_internal::FastUniformBits<uint8_t, 6> fast7; + EXPECT_EQ(0x9, fast7(urng4)); // b00001001 (1x4 + 1x3) + EXPECT_EQ(0x1, fast7(urng32)); + } + + { + absl::random_internal::FastUniformBits<uint8_t> fast8; + EXPECT_EQ(0x11, fast8(urng4)); + EXPECT_EQ(0x1, fast8(urng32)); + } + + // 16-bit types + { + absl::random_internal::FastUniformBits<uint16_t, 10> fast10; + EXPECT_EQ(0x91, fast10(urng4)); // b 0010010001 (2x3 + 1x4) + EXPECT_EQ(0x1, fast10(urng32)); + } + { + absl::random_internal::FastUniformBits<uint16_t, 11> fast11; + EXPECT_EQ(0x111, fast11(urng4)); + EXPECT_EQ(0x1, fast11(urng32)); + } + { + absl::random_internal::FastUniformBits<uint16_t, 12> fast12; + EXPECT_EQ(0x111, fast12(urng4)); + EXPECT_EQ(0x1, fast12(urng32)); + } + + { + absl::random_internal::FastUniformBits<uint16_t> fast16; + EXPECT_EQ(0x1111, fast16(urng4)); + EXPECT_EQ(0x1, fast16(urng32)); + } + + // 32-bit types + { + absl::random_internal::FastUniformBits<uint32_t, 21> fast21; + EXPECT_EQ(0x49111, fast21(urng4)); // b 001001001 000100010001 (3x3 + 3x4) + EXPECT_EQ(0x1, fast21(urng32)); + } + { + absl::random_internal::FastUniformBits<uint32_t, 24> fast24; + EXPECT_EQ(0x111111, fast24(urng4)); + EXPECT_EQ(0x1, fast24(urng32)); + } + + { + absl::random_internal::FastUniformBits<uint32_t> fast32; + EXPECT_EQ(0x11111111, fast32(urng4)); + EXPECT_EQ(0x1, fast32(urng32)); + } + + // 64-bit types + { + absl::random_internal::FastUniformBits<uint64_t, 5> fast5; + EXPECT_EQ(0x9, fast5(urng4)); + EXPECT_EQ(0x1, fast5(urng32)); + } + + { + absl::random_internal::FastUniformBits<uint64_t, 48> fast48; + EXPECT_EQ(0x111111111111, fast48(urng4)); + // computes in 2 steps, should be 24 << 24 + EXPECT_EQ(0x000001000001, fast48(urng32)); + } + + { + absl::random_internal::FastUniformBits<uint64_t> fast64; + EXPECT_EQ(0x1111111111111111, fast64(urng4)); + EXPECT_EQ(0x0000000100000001, fast64(urng32)); + } +} + +} // namespace diff --git a/absl/random/internal/fastmath.h b/absl/random/internal/fastmath.h new file mode 100644 index 000000000000..4bd184105399 --- /dev/null +++ b/absl/random/internal/fastmath.h @@ -0,0 +1,72 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_FASTMATH_H_ +#define ABSL_RANDOM_INTERNAL_FASTMATH_H_ + +// This file contains fast math functions (bitwise ops as well as some others) +// which are implementation details of various absl random number distributions. + +#include <cassert> +#include <cmath> +#include <cstdint> + +#include "absl/base/internal/bits.h" + +namespace absl { +namespace random_internal { + +// Returns the position of the first bit set. +inline int LeadingSetBit(uint64_t n) { + return 64 - base_internal::CountLeadingZeros64(n); +} + +// Compute log2(n) using integer operations. +// While std::log2 is more accurate than std::log(n) / std::log(2), for +// very large numbers--those close to std::numeric_limits<uint64_t>::max() - 2, +// for instance--std::log2 rounds up rather than down, which introduces +// definite skew in the results. +inline int IntLog2Floor(uint64_t n) { + return (n <= 1) ? 0 : (63 - base_internal::CountLeadingZeros64(n)); +} +inline int IntLog2Ceil(uint64_t n) { + return (n <= 1) ? 0 : (64 - base_internal::CountLeadingZeros64(n - 1)); +} + +inline double StirlingLogFactorial(double n) { + assert(n >= 1); + // Using Stirling's approximation. + constexpr double kLog2PI = 1.83787706640934548356; + const double logn = std::log(n); + const double ninv = 1.0 / static_cast<double>(n); + return n * logn - n + 0.5 * (kLog2PI + logn) + (1.0 / 12.0) * ninv - + (1.0 / 360.0) * ninv * ninv * ninv; +} + +// Rotate value right. +// +// We only implement the uint32_t / uint64_t versions because +// 1) those are the only ones we use, and +// 2) those are the only ones where clang detects the rotate idiom correctly. +inline constexpr uint32_t rotr(uint32_t value, uint8_t bits) { + return (value >> (bits & 31)) | (value << ((-bits) & 31)); +} +inline constexpr uint64_t rotr(uint64_t value, uint8_t bits) { + return (value >> (bits & 63)) | (value << ((-bits) & 63)); +} + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_FASTMATH_H_ diff --git a/absl/random/internal/fastmath_test.cc b/absl/random/internal/fastmath_test.cc new file mode 100644 index 000000000000..65859c25de4c --- /dev/null +++ b/absl/random/internal/fastmath_test.cc @@ -0,0 +1,110 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/fastmath.h" + +#include "gtest/gtest.h" + +#if defined(__native_client__) || defined(__EMSCRIPTEN__) +// NACL has a less accurate implementation of std::log2 than most of +// the other platforms. For some values which should have integral results, +// sometimes NACL returns slightly larger values. +// +// The MUSL libc used by emscripten also has a similar bug. +#define ABSL_RANDOM_INACCURATE_LOG2 +#endif + +namespace { + +TEST(DistributionImplTest, LeadingSetBit) { + using absl::random_internal::LeadingSetBit; + constexpr uint64_t kZero = 0; + EXPECT_EQ(0, LeadingSetBit(kZero)); + EXPECT_EQ(64, LeadingSetBit(~kZero)); + + for (int index = 0; index < 64; index++) { + uint64_t x = static_cast<uint64_t>(1) << index; + EXPECT_EQ(index + 1, LeadingSetBit(x)) << index; + EXPECT_EQ(index + 1, LeadingSetBit(x + x - 1)) << index; + } +} + +TEST(FastMathTest, IntLog2FloorTest) { + using absl::random_internal::IntLog2Floor; + constexpr uint64_t kZero = 0; + EXPECT_EQ(0, IntLog2Floor(0)); // boundary. return 0. + EXPECT_EQ(0, IntLog2Floor(1)); + EXPECT_EQ(1, IntLog2Floor(2)); + EXPECT_EQ(63, IntLog2Floor(~kZero)); + + // A boundary case: Converting 0xffffffffffffffff requires > 53 + // bits of precision, so the conversion to double rounds up, + // and the result of std::log2(x) > IntLog2Floor(x). + EXPECT_LT(IntLog2Floor(~kZero), static_cast<int>(std::log2(~kZero))); + + for (int i = 0; i < 64; i++) { + const uint64_t i_pow_2 = static_cast<uint64_t>(1) << i; + EXPECT_EQ(i, IntLog2Floor(i_pow_2)); + EXPECT_EQ(i, static_cast<int>(std::log2(i_pow_2))); + + uint64_t y = i_pow_2; + for (int j = i - 1; j > 0; --j) { + y = y | (i_pow_2 >> j); + EXPECT_EQ(i, IntLog2Floor(y)); + } + } +} + +TEST(FastMathTest, IntLog2CeilTest) { + using absl::random_internal::IntLog2Ceil; + constexpr uint64_t kZero = 0; + EXPECT_EQ(0, IntLog2Ceil(0)); // boundary. return 0. + EXPECT_EQ(0, IntLog2Ceil(1)); + EXPECT_EQ(1, IntLog2Ceil(2)); + EXPECT_EQ(64, IntLog2Ceil(~kZero)); + + // A boundary case: Converting 0xffffffffffffffff requires > 53 + // bits of precision, so the conversion to double rounds up, + // and the result of std::log2(x) > IntLog2Floor(x). + EXPECT_LE(IntLog2Ceil(~kZero), static_cast<int>(std::log2(~kZero))); + + for (int i = 0; i < 64; i++) { + const uint64_t i_pow_2 = static_cast<uint64_t>(1) << i; + EXPECT_EQ(i, IntLog2Ceil(i_pow_2)); +#ifndef ABSL_RANDOM_INACCURATE_LOG2 + EXPECT_EQ(i, static_cast<int>(std::ceil(std::log2(i_pow_2)))); +#endif + + uint64_t y = i_pow_2; + for (int j = i - 1; j > 0; --j) { + y = y | (i_pow_2 >> j); + EXPECT_EQ(i + 1, IntLog2Ceil(y)); + } + } +} + +TEST(FastMathTest, StirlingLogFactorial) { + using absl::random_internal::StirlingLogFactorial; + + EXPECT_NEAR(StirlingLogFactorial(1.0), 0, 1e-3); + EXPECT_NEAR(StirlingLogFactorial(1.50), 0.284683, 1e-3); + EXPECT_NEAR(StirlingLogFactorial(2.0), 0.69314718056, 1e-4); + + for (int i = 2; i < 50; i++) { + double d = static_cast<double>(i); + EXPECT_NEAR(StirlingLogFactorial(d), std::lgamma(d + 1), 3e-5); + } +} + +} // namespace diff --git a/absl/random/internal/gaussian_distribution_gentables.cc b/absl/random/internal/gaussian_distribution_gentables.cc new file mode 100644 index 000000000000..85247966f51b --- /dev/null +++ b/absl/random/internal/gaussian_distribution_gentables.cc @@ -0,0 +1,139 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Generates gaussian_distribution.cc +// +// $ blaze run :gaussian_distribution_gentables > gaussian_distribution.cc +// +#include "absl/random/gaussian_distribution.h" + +#include <cmath> +#include <cstddef> +#include <iostream> +#include <limits> +#include <string> + +#include "absl/base/macros.h" + +namespace absl { +namespace random_internal { +namespace { + +template <typename T, size_t N> +void FormatArrayContents(std::ostream* os, T (&data)[N]) { + if (!std::numeric_limits<T>::is_exact) { + // Note: T is either an integer or a float. + // float requires higher precision to ensure that values are + // reproduced exactly. + // Trivia: C99 has hexadecimal floating point literals, but C++11 does not. + // Using them would remove all concern of precision loss. + os->precision(std::numeric_limits<T>::max_digits10 + 2); + } + *os << " {"; + std::string separator = ""; + for (size_t i = 0; i < N; ++i) { + *os << separator << data[i]; + if ((i + 1) % 3 != 0) { + separator = ", "; + } else { + separator = ",\n "; + } + } + *os << "}"; +} + +} // namespace + +class TableGenerator : public gaussian_distribution_base { + public: + TableGenerator(); + void Print(std::ostream* os); + + using gaussian_distribution_base::kMask; + using gaussian_distribution_base::kR; + using gaussian_distribution_base::kV; + + private: + Tables tables_; +}; + +// Ziggurat gaussian initialization. For an explanation of the algorithm, see +// the Marsaglia paper, "The Ziggurat Method for Generating Random Variables". +// http://www.jstatsoft.org/v05/i08/ +// +// Further details are available in the Doornik paper +// https://www.doornik.com/research/ziggurat.pdf +// +TableGenerator::TableGenerator() { + // The constants here should match the values in gaussian_distribution.h + static constexpr int kC = kMask + 1; + + static_assert((ABSL_ARRAYSIZE(tables_.x) == kC + 1), + "xArray must be length kMask + 2"); + + static_assert((ABSL_ARRAYSIZE(tables_.x) == ABSL_ARRAYSIZE(tables_.f)), + "fx and x arrays must be identical length"); + + auto f = [](double x) { return std::exp(-0.5 * x * x); }; + auto f_inv = [](double x) { return std::sqrt(-2.0 * std::log(x)); }; + + tables_.x[0] = kV / f(kR); + tables_.f[0] = f(tables_.x[0]); + + tables_.x[1] = kR; + tables_.f[1] = f(tables_.x[1]); + + tables_.x[kC] = 0.0; + tables_.f[kC] = f(tables_.x[kC]); // 1.0 + + for (int i = 2; i < kC; i++) { + double v = (kV / tables_.x[i - 1]) + tables_.f[i - 1]; + tables_.x[i] = f_inv(v); + tables_.f[i] = v; + } +} + +void TableGenerator::Print(std::ostream* os) { + *os << "// BEGIN GENERATED CODE; DO NOT EDIT\n" + "// clang-format off\n" + "\n" + "#include \"absl/random/gaussian_distribution.h\"\n" + "\n" + "namespace absl {\n" + "namespace random_internal {\n" + "\n" + "const gaussian_distribution_base::Tables\n" + " gaussian_distribution_base::zg_ = {\n"; + FormatArrayContents(os, tables_.x); + *os << ",\n"; + FormatArrayContents(os, tables_.f); + *os << "};\n" + "\n" + "} // namespace random_internal\n" + "} // namespace absl\n" + "\n" + "// clang-format on\n" + "// END GENERATED CODE"; + *os << std::endl; +} + +} // namespace random_internal +} // namespace absl + +int main(int, char**) { + std::cerr << "\nCopy the output to gaussian_distribution.cc" << std::endl; + absl::random_internal::TableGenerator generator; + generator.Print(&std::cout); + return 0; +} diff --git a/absl/random/internal/iostream_state_saver.h b/absl/random/internal/iostream_state_saver.h new file mode 100644 index 000000000000..df88fa76bcde --- /dev/null +++ b/absl/random/internal/iostream_state_saver.h @@ -0,0 +1,243 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_IOSTREAM_STATE_SAVER_H_ +#define ABSL_RANDOM_INTERNAL_IOSTREAM_STATE_SAVER_H_ + +#include <cmath> +#include <iostream> +#include <limits> +#include <type_traits> + +#include "absl/meta/type_traits.h" +#include "absl/numeric/int128.h" + +namespace absl { +namespace random_internal { + +// The null_state_saver does nothing. +template <typename T> +class null_state_saver { + public: + using stream_type = T; + using flags_type = std::ios_base::fmtflags; + + null_state_saver(T&, flags_type) {} + ~null_state_saver() {} +}; + +// ostream_state_saver is a RAII object to save and restore the common +// basic_ostream flags used when implementing `operator <<()` on any of +// the absl random distributions. +template <typename OStream> +class ostream_state_saver { + public: + using ostream_type = OStream; + using flags_type = std::ios_base::fmtflags; + using fill_type = typename ostream_type::char_type; + using precision_type = std::streamsize; + + ostream_state_saver(ostream_type& os, // NOLINT(runtime/references) + flags_type flags, fill_type fill) + : os_(os), + flags_(os.flags(flags)), + fill_(os.fill(fill)), + precision_(os.precision()) { + // Save state in initialized variables. + } + + ~ostream_state_saver() { + // Restore saved state. + os_.precision(precision_); + os_.fill(fill_); + os_.flags(flags_); + } + + private: + ostream_type& os_; + const flags_type flags_; + const fill_type fill_; + const precision_type precision_; +}; + +#if defined(__NDK_MAJOR__) && __NDK_MAJOR__ < 16 +#define ABSL_RANDOM_INTERNAL_IOSTREAM_HEXFLOAT 1 +#else +#define ABSL_RANDOM_INTERNAL_IOSTREAM_HEXFLOAT 0 +#endif + +template <typename CharT, typename Traits> +ostream_state_saver<std::basic_ostream<CharT, Traits>> make_ostream_state_saver( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + std::ios_base::fmtflags flags = std::ios_base::dec | std::ios_base::left | +#if ABSL_RANDOM_INTERNAL_IOSTREAM_HEXFLOAT + std::ios_base::fixed | +#endif + std::ios_base::scientific) { + using result_type = ostream_state_saver<std::basic_ostream<CharT, Traits>>; + return result_type(os, flags, os.widen(' ')); +} + +template <typename T> +typename absl::enable_if_t<!std::is_base_of<std::ios_base, T>::value, + null_state_saver<T>> +make_ostream_state_saver(T& is, // NOLINT(runtime/references) + std::ios_base::fmtflags flags = std::ios_base::dec) { + std::cerr << "null_state_saver"; + using result_type = null_state_saver<T>; + return result_type(is, flags); +} + +// stream_precision_helper<type>::kPrecision returns the base 10 precision +// required to stream and reconstruct a real type exact binary value through +// a binary->decimal->binary transition. +template <typename T> +struct stream_precision_helper { + // max_digits10 may be 0 on MSVC; if so, use digits10 + 3. + static constexpr int kPrecision = + (std::numeric_limits<T>::max_digits10 > std::numeric_limits<T>::digits10) + ? std::numeric_limits<T>::max_digits10 + : (std::numeric_limits<T>::digits10 + 3); +}; + +template <> +struct stream_precision_helper<float> { + static constexpr int kPrecision = 9; +}; +template <> +struct stream_precision_helper<double> { + static constexpr int kPrecision = 17; +}; +template <> +struct stream_precision_helper<long double> { + static constexpr int kPrecision = 36; // assuming fp128 +}; + +// istream_state_saver is a RAII object to save and restore the common +// std::basic_istream<> flags used when implementing `operator >>()` on any of +// the absl random distributions. +template <typename IStream> +class istream_state_saver { + public: + using istream_type = IStream; + using flags_type = std::ios_base::fmtflags; + + istream_state_saver(istream_type& is, // NOLINT(runtime/references) + flags_type flags) + : is_(is), flags_(is.flags(flags)) {} + + ~istream_state_saver() { is_.flags(flags_); } + + private: + istream_type& is_; + flags_type flags_; +}; + +template <typename CharT, typename Traits> +istream_state_saver<std::basic_istream<CharT, Traits>> make_istream_state_saver( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + std::ios_base::fmtflags flags = std::ios_base::dec | + std::ios_base::scientific | + std::ios_base::skipws) { + using result_type = istream_state_saver<std::basic_istream<CharT, Traits>>; + return result_type(is, flags); +} + +template <typename T> +typename absl::enable_if_t<!std::is_base_of<std::ios_base, T>::value, + null_state_saver<T>> +make_istream_state_saver(T& is, // NOLINT(runtime/references) + std::ios_base::fmtflags flags = std::ios_base::dec) { + using result_type = null_state_saver<T>; + return result_type(is, flags); +} + +// stream_format_type<T> is a helper struct to convert types which +// basic_iostream cannot output as decimal numbers into types which +// basic_iostream can output as decimal numbers. Specifically: +// * signed/unsigned char-width types are converted to int. +// * TODO(lar): __int128 => uint128, except there is no operator << yet. +// +template <typename T> +struct stream_format_type + : public std::conditional<(sizeof(T) == sizeof(char)), int, T> {}; + +// stream_u128_helper allows us to write out either absl::uint128 or +// __uint128_t types in the same way, which enables their use as internal +// state of PRNG engines. +template <typename T> +struct stream_u128_helper; + +template <> +struct stream_u128_helper<absl::uint128> { + template <typename IStream> + inline absl::uint128 read(IStream& in) { + uint64_t h = 0; + uint64_t l = 0; + in >> h >> l; + return absl::MakeUint128(h, l); + } + + template <typename OStream> + inline void write(absl::uint128 val, OStream& out) { + uint64_t h = Uint128High64(val); + uint64_t l = Uint128Low64(val); + out << h << out.fill() << l; + } +}; + +#ifdef ABSL_HAVE_INTRINSIC_INT128 +template <> +struct stream_u128_helper<__uint128_t> { + template <typename IStream> + inline __uint128_t read(IStream& in) { + uint64_t h = 0; + uint64_t l = 0; + in >> h >> l; + return (static_cast<__uint128_t>(h) << 64) | l; + } + + template <typename OStream> + inline void write(__uint128_t val, OStream& out) { + uint64_t h = static_cast<uint64_t>(val >> 64u); + uint64_t l = static_cast<uint64_t>(val); + out << h << out.fill() << l; + } +}; +#endif + +template <typename FloatType, typename IStream> +inline FloatType read_floating_point(IStream& is) { + static_assert(std::is_floating_point<FloatType>::value, ""); + FloatType dest; + is >> dest; + // Parsing a double value may report a subnormal value as an error + // despite being able to represent it. + // See https://stackoverflow.com/q/52410931/3286653 + // It may also report an underflow when parsing DOUBLE_MIN as an + // ERANGE error, as the parsed value may be smaller than DOUBLE_MIN + // and rounded up. + // See: https://stackoverflow.com/q/42005462 + if (is.fail() && + (std::fabs(dest) == (std::numeric_limits<FloatType>::min)() || + std::fpclassify(dest) == FP_SUBNORMAL)) { + is.clear(is.rdstate() & (~std::ios_base::failbit)); + } + return dest; +} + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_IOSTREAM_STATE_SAVER_H_ diff --git a/absl/random/internal/iostream_state_saver_test.cc b/absl/random/internal/iostream_state_saver_test.cc new file mode 100644 index 000000000000..2ecbaac143dc --- /dev/null +++ b/absl/random/internal/iostream_state_saver_test.cc @@ -0,0 +1,369 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/iostream_state_saver.h" + +#include <sstream> +#include <string> + +#include "gtest/gtest.h" + +namespace { + +using absl::random_internal::make_istream_state_saver; +using absl::random_internal::make_ostream_state_saver; +using absl::random_internal::stream_precision_helper; + +template <typename T> +typename absl::enable_if_t<std::is_integral<T>::value, T> // +StreamRoundTrip(T t) { + std::stringstream ss; + { + auto saver = make_ostream_state_saver(ss); + ss.precision(stream_precision_helper<T>::kPrecision); + ss << t; + } + T result = 0; + { + auto saver = make_istream_state_saver(ss); + ss >> result; + } + EXPECT_FALSE(ss.fail()) // + << ss.str() << " " // + << (ss.good() ? "good " : "") // + << (ss.bad() ? "bad " : "") // + << (ss.eof() ? "eof " : "") // + << (ss.fail() ? "fail " : ""); + + return result; +} + +template <typename T> +typename absl::enable_if_t<std::is_floating_point<T>::value, T> // +StreamRoundTrip(T t) { + std::stringstream ss; + { + auto saver = make_ostream_state_saver(ss); + ss.precision(stream_precision_helper<T>::kPrecision); + ss << t; + } + T result = 0; + { + auto saver = make_istream_state_saver(ss); + result = absl::random_internal::read_floating_point<T>(ss); + } + EXPECT_FALSE(ss.fail()) // + << ss.str() << " " // + << (ss.good() ? "good " : "") // + << (ss.bad() ? "bad " : "") // + << (ss.eof() ? "eof " : "") // + << (ss.fail() ? "fail " : ""); + + return result; +} + +TEST(IOStreamStateSaver, BasicSaverState) { + std::stringstream ss; + ss.precision(2); + ss.fill('x'); + ss.flags(std::ios_base::dec | std::ios_base::right); + + { + auto saver = make_ostream_state_saver(ss); + ss.precision(10); + EXPECT_NE('x', ss.fill()); + EXPECT_EQ(10, ss.precision()); + EXPECT_NE(std::ios_base::dec | std::ios_base::right, ss.flags()); + + ss << 1.23; + } + + EXPECT_EQ('x', ss.fill()); + EXPECT_EQ(2, ss.precision()); + EXPECT_EQ(std::ios_base::dec | std::ios_base::right, ss.flags()); +} + +TEST(IOStreamStateSaver, RoundTripInts) { + const uint64_t kUintValues[] = { + 0, + 1, + static_cast<uint64_t>(-1), + 2, + static_cast<uint64_t>(-2), + + 1 << 7, + 1 << 8, + 1 << 16, + 1ull << 32, + 1ull << 50, + 1ull << 62, + 1ull << 63, + + (1 << 7) - 1, + (1 << 8) - 1, + (1 << 16) - 1, + (1ull << 32) - 1, + (1ull << 50) - 1, + (1ull << 62) - 1, + (1ull << 63) - 1, + + static_cast<uint64_t>(-(1 << 8)), + static_cast<uint64_t>(-(1 << 16)), + static_cast<uint64_t>(-(1ll << 32)), + static_cast<uint64_t>(-(1ll << 50)), + static_cast<uint64_t>(-(1ll << 62)), + + static_cast<uint64_t>(-(1 << 8) - 1), + static_cast<uint64_t>(-(1 << 16) - 1), + static_cast<uint64_t>(-(1ll << 32) - 1), + static_cast<uint64_t>(-(1ll << 50) - 1), + static_cast<uint64_t>(-(1ll << 62) - 1), + }; + + for (const uint64_t u : kUintValues) { + EXPECT_EQ(u, StreamRoundTrip<uint64_t>(u)); + + int64_t x = static_cast<int64_t>(u); + EXPECT_EQ(x, StreamRoundTrip<int64_t>(x)); + + double d = static_cast<double>(x); + EXPECT_EQ(d, StreamRoundTrip<double>(d)); + + float f = d; + EXPECT_EQ(f, StreamRoundTrip<float>(f)); + } +} + +TEST(IOStreamStateSaver, RoundTripFloats) { + static_assert( + stream_precision_helper<float>::kPrecision >= 9, + "stream_precision_helper<float>::kPrecision should be at least 9"); + + const float kValues[] = { + 1, + std::nextafter(1.0f, 0.0f), // 1 - epsilon + std::nextafter(1.0f, 2.0f), // 1 + epsilon + + 1.0e+1f, + 1.0e-1f, + 1.0e+2f, + 1.0e-2f, + 1.0e+10f, + 1.0e-10f, + + 0.00000051110000111311111111f, + -0.00000051110000111211111111f, + + 1.234678912345678912345e+6f, + 1.234678912345678912345e-6f, + 1.234678912345678912345e+30f, + 1.234678912345678912345e-30f, + 1.234678912345678912345e+38f, + 1.0234678912345678912345e-38f, + + // Boundary cases. + std::numeric_limits<float>::max(), + std::numeric_limits<float>::lowest(), + std::numeric_limits<float>::epsilon(), + std::nextafter(std::numeric_limits<float>::min(), + 1.0f), // min + epsilon + std::numeric_limits<float>::min(), // smallest normal + // There are some errors dealing with denorms on apple platforms. + std::numeric_limits<float>::denorm_min(), // smallest denorm + std::numeric_limits<float>::min() / 2, + std::nextafter(std::numeric_limits<float>::min(), + 0.0f), // denorm_max + std::nextafter(std::numeric_limits<float>::denorm_min(), 1.0f), + }; + + for (const float f : kValues) { + EXPECT_EQ(f, StreamRoundTrip<float>(f)); + EXPECT_EQ(-f, StreamRoundTrip<float>(-f)); + + double d = f; + EXPECT_EQ(d, StreamRoundTrip<double>(d)); + EXPECT_EQ(-d, StreamRoundTrip<double>(-d)); + + // Avoid undefined behavior (overflow/underflow). + if (d <= std::numeric_limits<int64_t>::max() && + d >= std::numeric_limits<int64_t>::lowest()) { + int64_t x = static_cast<int64_t>(f); + EXPECT_EQ(x, StreamRoundTrip<int64_t>(x)); + } + } +} + +TEST(IOStreamStateSaver, RoundTripDoubles) { + static_assert( + stream_precision_helper<double>::kPrecision >= 17, + "stream_precision_helper<double>::kPrecision should be at least 17"); + + const double kValues[] = { + 1, + std::nextafter(1.0, 0.0), // 1 - epsilon + std::nextafter(1.0, 2.0), // 1 + epsilon + + 1.0e+1, + 1.0e-1, + 1.0e+2, + 1.0e-2, + 1.0e+10, + 1.0e-10, + + 0.00000051110000111311111111, + -0.00000051110000111211111111, + + 1.234678912345678912345e+6, + 1.234678912345678912345e-6, + 1.234678912345678912345e+30, + 1.234678912345678912345e-30, + 1.234678912345678912345e+38, + 1.0234678912345678912345e-38, + + 1.0e+100, + 1.0e-100, + 1.234678912345678912345e+308, + 1.0234678912345678912345e-308, + 2.22507385850720138e-308, + + // Boundary cases. + std::numeric_limits<double>::max(), + std::numeric_limits<double>::lowest(), + std::numeric_limits<double>::epsilon(), + std::nextafter(std::numeric_limits<double>::min(), + 1.0), // min + epsilon + std::numeric_limits<double>::min(), // smallest normal + // There are some errors dealing with denorms on apple platforms. + std::numeric_limits<double>::denorm_min(), // smallest denorm + std::numeric_limits<double>::min() / 2, + std::nextafter(std::numeric_limits<double>::min(), + 0.0), // denorm_max + std::nextafter(std::numeric_limits<double>::denorm_min(), 1.0f), + }; + + for (const double d : kValues) { + EXPECT_EQ(d, StreamRoundTrip<double>(d)); + EXPECT_EQ(-d, StreamRoundTrip<double>(-d)); + + // Avoid undefined behavior (overflow/underflow). + if (d <= std::numeric_limits<float>::max() && + d >= std::numeric_limits<float>::lowest()) { + float f = static_cast<float>(d); + EXPECT_EQ(f, StreamRoundTrip<float>(f)); + } + + // Avoid undefined behavior (overflow/underflow). + if (d <= std::numeric_limits<int64_t>::max() && + d >= std::numeric_limits<int64_t>::lowest()) { + int64_t x = static_cast<int64_t>(d); + EXPECT_EQ(x, StreamRoundTrip<int64_t>(x)); + } + } +} + +TEST(IOStreamStateSaver, RoundTripLongDoubles) { + // Technically, C++ only guarantees that long double is at least as large as a + // double. Practically it varies from 64-bits to 128-bits. + // + // So it is best to consider long double a best-effort extended precision + // type. + + static_assert( + stream_precision_helper<long double>::kPrecision >= 36, + "stream_precision_helper<long double>::kPrecision should be at least 36"); + + using real_type = long double; + const real_type kValues[] = { + 1, + std::nextafter(1.0, 0.0), // 1 - epsilon + std::nextafter(1.0, 2.0), // 1 + epsilon + + 1.0e+1, + 1.0e-1, + 1.0e+2, + 1.0e-2, + 1.0e+10, + 1.0e-10, + + 0.00000051110000111311111111, + -0.00000051110000111211111111, + + 1.2346789123456789123456789123456789e+6, + 1.2346789123456789123456789123456789e-6, + 1.2346789123456789123456789123456789e+30, + 1.2346789123456789123456789123456789e-30, + 1.2346789123456789123456789123456789e+38, + 1.2346789123456789123456789123456789e-38, + 1.2346789123456789123456789123456789e+308, + 1.2346789123456789123456789123456789e-308, + + 1.0e+100, + 1.0e-100, + 1.234678912345678912345e+308, + 1.0234678912345678912345e-308, + + // Boundary cases. + std::numeric_limits<real_type>::max(), + std::numeric_limits<real_type>::lowest(), + std::numeric_limits<real_type>::epsilon(), + std::nextafter(std::numeric_limits<real_type>::min(), + real_type(1)), // min + epsilon + std::numeric_limits<real_type>::min(), // smallest normal + // There are some errors dealing with denorms on apple platforms. + std::numeric_limits<real_type>::denorm_min(), // smallest denorm + std::numeric_limits<real_type>::min() / 2, + std::nextafter(std::numeric_limits<real_type>::min(), + 0.0), // denorm_max + std::nextafter(std::numeric_limits<real_type>::denorm_min(), 1.0f), + }; + + int index = -1; + for (const long double dd : kValues) { + index++; + EXPECT_EQ(dd, StreamRoundTrip<real_type>(dd)) << index; + EXPECT_EQ(-dd, StreamRoundTrip<real_type>(-dd)) << index; + + // Avoid undefined behavior (overflow/underflow). + if (dd <= std::numeric_limits<double>::max() && + dd >= std::numeric_limits<double>::lowest()) { + double d = static_cast<double>(dd); + EXPECT_EQ(d, StreamRoundTrip<double>(d)); + } + + // Avoid undefined behavior (overflow/underflow). + if (dd <= std::numeric_limits<int64_t>::max() && + dd >= std::numeric_limits<int64_t>::lowest()) { + int64_t x = static_cast<int64_t>(dd); + EXPECT_EQ(x, StreamRoundTrip<int64_t>(x)); + } + } +} + +TEST(StrToDTest, DoubleMin) { + const char kV[] = "2.22507385850720138e-308"; + char* end; + double x = std::strtod(kV, &end); + EXPECT_EQ(std::numeric_limits<double>::min(), x); + // errno may equal ERANGE. +} + +TEST(StrToDTest, DoubleDenormMin) { + const char kV[] = "4.94065645841246544e-324"; + char* end; + double x = std::strtod(kV, &end); + EXPECT_EQ(std::numeric_limits<double>::denorm_min(), x); + // errno may equal ERANGE. +} + +} // namespace diff --git a/absl/random/internal/named_generator.cc b/absl/random/internal/named_generator.cc new file mode 100644 index 000000000000..b168a25be2e7 --- /dev/null +++ b/absl/random/internal/named_generator.cc @@ -0,0 +1,30 @@ +// Copyright 2018 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cstddef> +#include <iostream> + +#include "absl/random/random.h" + +// This program is used in integration tests. + +int main() { + auto seed_seq = absl::MakeTaggedSeedSeq("TEST_GENERATOR", std::cerr); + absl::BitGen rng(seed_seq); + constexpr size_t kSequenceLength = 8; + for (size_t i = 0; i < kSequenceLength; i++) { + std::cout << rng() << "\n"; + } + return 0; +} diff --git a/absl/random/internal/nanobenchmark.cc b/absl/random/internal/nanobenchmark.cc new file mode 100644 index 000000000000..5a8b1ed1dba1 --- /dev/null +++ b/absl/random/internal/nanobenchmark.cc @@ -0,0 +1,792 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/nanobenchmark.h" + +#include <sys/types.h> + +#include <algorithm> // sort +#include <atomic> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <cstring> // memcpy +#include <limits> +#include <string> +#include <utility> +#include <vector> + +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/platform.h" +#include "absl/random/internal/randen_engine.h" + +// OS +#if defined(_WIN32) || defined(_WIN64) +#define ABSL_OS_WIN +#include <windows.h> // NOLINT + +#elif defined(__ANDROID__) +#define ABSL_OS_ANDROID + +#elif defined(__linux__) +#define ABSL_OS_LINUX +#include <sched.h> // NOLINT +#include <sys/syscall.h> // NOLINT +#endif + +#if defined(ABSL_ARCH_X86_64) && !defined(ABSL_OS_WIN) +#include <cpuid.h> // NOLINT +#endif + +// __ppc_get_timebase_freq +#if defined(ABSL_ARCH_PPC) +#include <sys/platform/ppc.h> // NOLINT +#endif + +// clock_gettime +#if defined(ABSL_ARCH_ARM) || defined(ABSL_ARCH_AARCH64) +#include <time.h> // NOLINT +#endif + +namespace absl { +namespace random_internal_nanobenchmark { +namespace { + +// For code folding. +namespace platform { +#if defined(ABSL_ARCH_X86_64) + +// TODO(janwas): Merge with the one in randen_hwaes.cc? +void Cpuid(const uint32_t level, const uint32_t count, + uint32_t* ABSL_RANDOM_INTERNAL_RESTRICT abcd) { +#if defined(ABSL_OS_WIN) + int regs[4]; + __cpuidex(regs, level, count); + for (int i = 0; i < 4; ++i) { + abcd[i] = regs[i]; + } +#else + uint32_t a, b, c, d; + __cpuid_count(level, count, a, b, c, d); + abcd[0] = a; + abcd[1] = b; + abcd[2] = c; + abcd[3] = d; +#endif +} + +std::string BrandString() { + char brand_string[49]; + uint32_t abcd[4]; + + // Check if brand std::string is supported (it is on all reasonable Intel/AMD) + Cpuid(0x80000000U, 0, abcd); + if (abcd[0] < 0x80000004U) { + return std::string(); + } + + for (int i = 0; i < 3; ++i) { + Cpuid(0x80000002U + i, 0, abcd); + memcpy(brand_string + i * 16, &abcd, sizeof(abcd)); + } + brand_string[48] = 0; + return brand_string; +} + +// Returns the frequency quoted inside the brand string. This does not +// account for throttling nor Turbo Boost. +double NominalClockRate() { + const std::string& brand_string = BrandString(); + // Brand strings include the maximum configured frequency. These prefixes are + // defined by Intel CPUID documentation. + const char* prefixes[3] = {"MHz", "GHz", "THz"}; + const double multipliers[3] = {1E6, 1E9, 1E12}; + for (size_t i = 0; i < 3; ++i) { + const size_t pos_prefix = brand_string.find(prefixes[i]); + if (pos_prefix != std::string::npos) { + const size_t pos_space = brand_string.rfind(' ', pos_prefix - 1); + if (pos_space != std::string::npos) { + const std::string digits = + brand_string.substr(pos_space + 1, pos_prefix - pos_space - 1); + return std::stod(digits) * multipliers[i]; + } + } + } + + return 0.0; +} + +#endif // ABSL_ARCH_X86_64 +} // namespace platform + +// Prevents the compiler from eliding the computations that led to "output". +template <class T> +inline void PreventElision(T&& output) { +#ifndef ABSL_OS_WIN + // Works by indicating to the compiler that "output" is being read and + // modified. The +r constraint avoids unnecessary writes to memory, but only + // works for built-in types (typically FuncOutput). + asm volatile("" : "+r"(output) : : "memory"); +#else + // MSVC does not support inline assembly anymore (and never supported GCC's + // RTL constraints). Self-assignment with #pragma optimize("off") might be + // expected to prevent elision, but it does not with MSVC 2015. Type-punning + // with volatile pointers generates inefficient code on MSVC 2017. + static std::atomic<T> dummy(T{}); + dummy.store(output, std::memory_order_relaxed); +#endif +} + +namespace timer { + +// Start/Stop return absolute timestamps and must be placed immediately before +// and after the region to measure. We provide separate Start/Stop functions +// because they use different fences. +// +// Background: RDTSC is not 'serializing'; earlier instructions may complete +// after it, and/or later instructions may complete before it. 'Fences' ensure +// regions' elapsed times are independent of such reordering. The only +// documented unprivileged serializing instruction is CPUID, which acts as a +// full fence (no reordering across it in either direction). Unfortunately +// the latency of CPUID varies wildly (perhaps made worse by not initializing +// its EAX input). Because it cannot reliably be deducted from the region's +// elapsed time, it must not be included in the region to measure (i.e. +// between the two RDTSC). +// +// The newer RDTSCP is sometimes described as serializing, but it actually +// only serves as a half-fence with release semantics. Although all +// instructions in the region will complete before the final timestamp is +// captured, subsequent instructions may leak into the region and increase the +// elapsed time. Inserting another fence after the final RDTSCP would prevent +// such reordering without affecting the measured region. +// +// Fortunately, such a fence exists. The LFENCE instruction is only documented +// to delay later loads until earlier loads are visible. However, Intel's +// reference manual says it acts as a full fence (waiting until all earlier +// instructions have completed, and delaying later instructions until it +// completes). AMD assigns the same behavior to MFENCE. +// +// We need a fence before the initial RDTSC to prevent earlier instructions +// from leaking into the region, and arguably another after RDTSC to avoid +// region instructions from completing before the timestamp is recorded. +// When surrounded by fences, the additional RDTSCP half-fence provides no +// benefit, so the initial timestamp can be recorded via RDTSC, which has +// lower overhead than RDTSCP because it does not read TSC_AUX. In summary, +// we define Start = LFENCE/RDTSC/LFENCE; Stop = RDTSCP/LFENCE. +// +// Using Start+Start leads to higher variance and overhead than Stop+Stop. +// However, Stop+Stop includes an LFENCE in the region measurements, which +// adds a delay dependent on earlier loads. The combination of Start+Stop +// is faster than Start+Start and more consistent than Stop+Stop because +// the first LFENCE already delayed subsequent loads before the measured +// region. This combination seems not to have been considered in prior work: +// http://akaros.cs.berkeley.edu/lxr/akaros/kern/arch/x86/rdtsc_test.c +// +// Note: performance counters can measure 'exact' instructions-retired or +// (unhalted) cycle counts. The RDPMC instruction is not serializing and also +// requires fences. Unfortunately, it is not accessible on all OSes and we +// prefer to avoid kernel-mode drivers. Performance counters are also affected +// by several under/over-count errata, so we use the TSC instead. + +// Returns a 64-bit timestamp in unit of 'ticks'; to convert to seconds, +// divide by InvariantTicksPerSecond. +inline uint64_t Start64() { + uint64_t t; +#if defined(ABSL_ARCH_PPC) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif defined(ABSL_ARCH_X86_64) +#if defined(ABSL_OS_WIN) + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); + t = __rdtsc(); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#else + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rdx", "memory", "cc"); +#endif +#else + // Fall back to OS - unsure how to reliably query cntvct_el0 frequency. + timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + t = ts.tv_sec * 1000000000LL + ts.tv_nsec; +#endif + return t; +} + +inline uint64_t Stop64() { + uint64_t t; +#if defined(ABSL_ARCH_X86_64) +#if defined(ABSL_OS_WIN) + _ReadWriteBarrier(); + unsigned aux; + t = __rdtscp(&aux); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#else + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). + asm volatile( + "rdtscp\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rcx", "rdx", "memory", "cc"); +#endif +#else + t = Start64(); +#endif + return t; +} + +// Returns a 32-bit timestamp with about 4 cycles less overhead than +// Start64. Only suitable for measuring very short regions because the +// timestamp overflows about once a second. +inline uint32_t Start32() { + uint32_t t; +#if defined(ABSL_ARCH_X86_64) +#if defined(ABSL_OS_WIN) + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); + t = static_cast<uint32_t>(__rdtsc()); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#else + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + : "rdx", "memory"); +#endif +#else + t = static_cast<uint32_t>(Start64()); +#endif + return t; +} + +inline uint32_t Stop32() { + uint32_t t; +#if defined(ABSL_ARCH_X86_64) +#if defined(ABSL_OS_WIN) + _ReadWriteBarrier(); + unsigned aux; + t = static_cast<uint32_t>(__rdtscp(&aux)); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#else + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). + asm volatile( + "rdtscp\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + : "rcx", "rdx", "memory"); +#endif +#else + t = static_cast<uint32_t>(Stop64()); +#endif + return t; +} + +} // namespace timer + +namespace robust_statistics { + +// Sorts integral values in ascending order (e.g. for Mode). About 3x faster +// than std::sort for input distributions with very few unique values. +template <class T> +void CountingSort(T* values, size_t num_values) { + // Unique values and their frequency (similar to flat_map). + using Unique = std::pair<T, int>; + std::vector<Unique> unique; + for (size_t i = 0; i < num_values; ++i) { + const T value = values[i]; + const auto pos = + std::find_if(unique.begin(), unique.end(), + [value](const Unique u) { return u.first == value; }); + if (pos == unique.end()) { + unique.push_back(std::make_pair(value, 1)); + } else { + ++pos->second; + } + } + + // Sort in ascending order of value (pair.first). + std::sort(unique.begin(), unique.end()); + + // Write that many copies of each unique value to the array. + T* ABSL_RANDOM_INTERNAL_RESTRICT p = values; + for (const auto& value_count : unique) { + std::fill(p, p + value_count.second, value_count.first); + p += value_count.second; + } + ABSL_RAW_CHECK(p == values + num_values, "Did not produce enough output"); +} + +// @return i in [idx_begin, idx_begin + half_count) that minimizes +// sorted[i + half_count] - sorted[i]. +template <typename T> +size_t MinRange(const T* const ABSL_RANDOM_INTERNAL_RESTRICT sorted, + const size_t idx_begin, const size_t half_count) { + T min_range = (std::numeric_limits<T>::max)(); + size_t min_idx = 0; + + for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) { + ABSL_RAW_CHECK(sorted[idx] <= sorted[idx + half_count], "Not sorted"); + const T range = sorted[idx + half_count] - sorted[idx]; + if (range < min_range) { + min_range = range; + min_idx = idx; + } + } + + return min_idx; +} + +// Returns an estimate of the mode by calling MinRange on successively +// halved intervals. "sorted" must be in ascending order. This is the +// Half Sample Mode estimator proposed by Bickel in "On a fast, robust +// estimator of the mode", with complexity O(N log N). The mode is less +// affected by outliers in highly-skewed distributions than the median. +// The averaging operation below assumes "T" is an unsigned integer type. +template <typename T> +T ModeOfSorted(const T* const ABSL_RANDOM_INTERNAL_RESTRICT sorted, + const size_t num_values) { + size_t idx_begin = 0; + size_t half_count = num_values / 2; + while (half_count > 1) { + idx_begin = MinRange(sorted, idx_begin, half_count); + half_count >>= 1; + } + + const T x = sorted[idx_begin + 0]; + if (half_count == 0) { + return x; + } + ABSL_RAW_CHECK(half_count == 1, "Should stop at half_count=1"); + const T average = (x + sorted[idx_begin + 1] + 1) / 2; + return average; +} + +// Returns the mode. Side effect: sorts "values". +template <typename T> +T Mode(T* values, const size_t num_values) { + CountingSort(values, num_values); + return ModeOfSorted(values, num_values); +} + +template <typename T, size_t N> +T Mode(T (&values)[N]) { + return Mode(&values[0], N); +} + +// Returns the median value. Side effect: sorts "values". +template <typename T> +T Median(T* values, const size_t num_values) { + ABSL_RAW_CHECK(num_values != 0, "Empty input"); + std::sort(values, values + num_values); + const size_t half = num_values / 2; + // Odd count: return middle + if (num_values % 2) { + return values[half]; + } + // Even count: return average of middle two. + return (values[half] + values[half - 1] + 1) / 2; +} + +// Returns a robust measure of variability. +template <typename T> +T MedianAbsoluteDeviation(const T* values, const size_t num_values, + const T median) { + ABSL_RAW_CHECK(num_values != 0, "Empty input"); + std::vector<T> abs_deviations; + abs_deviations.reserve(num_values); + for (size_t i = 0; i < num_values; ++i) { + const int64_t abs = std::abs(int64_t(values[i]) - int64_t(median)); + abs_deviations.push_back(static_cast<T>(abs)); + } + return Median(abs_deviations.data(), num_values); +} + +} // namespace robust_statistics + +// Ticks := platform-specific timer values (CPU cycles on x86). Must be +// unsigned to guarantee wraparound on overflow. 32 bit timers are faster to +// read than 64 bit. +using Ticks = uint32_t; + +// Returns timer overhead / minimum measurable difference. +Ticks TimerResolution() { + // Nested loop avoids exceeding stack/L1 capacity. + Ticks repetitions[Params::kTimerSamples]; + for (size_t rep = 0; rep < Params::kTimerSamples; ++rep) { + Ticks samples[Params::kTimerSamples]; + for (size_t i = 0; i < Params::kTimerSamples; ++i) { + const Ticks t0 = timer::Start32(); + const Ticks t1 = timer::Stop32(); + samples[i] = t1 - t0; + } + repetitions[rep] = robust_statistics::Mode(samples); + } + return robust_statistics::Mode(repetitions); +} + +static const Ticks timer_resolution = TimerResolution(); + +// Estimates the expected value of "lambda" values with a variable number of +// samples until the variability "rel_mad" is less than "max_rel_mad". +template <class Lambda> +Ticks SampleUntilStable(const double max_rel_mad, double* rel_mad, + const Params& p, const Lambda& lambda) { + auto measure_duration = [&lambda]() -> Ticks { + const Ticks t0 = timer::Start32(); + lambda(); + const Ticks t1 = timer::Stop32(); + return t1 - t0; + }; + + // Choose initial samples_per_eval based on a single estimated duration. + Ticks est = measure_duration(); + static const double ticks_per_second = InvariantTicksPerSecond(); + const size_t ticks_per_eval = ticks_per_second * p.seconds_per_eval; + size_t samples_per_eval = ticks_per_eval / est; + samples_per_eval = (std::max)(samples_per_eval, p.min_samples_per_eval); + + std::vector<Ticks> samples; + samples.reserve(1 + samples_per_eval); + samples.push_back(est); + + // Percentage is too strict for tiny differences, so also allow a small + // absolute "median absolute deviation". + const Ticks max_abs_mad = (timer_resolution + 99) / 100; + *rel_mad = 0.0; // ensure initialized + + for (size_t eval = 0; eval < p.max_evals; ++eval, samples_per_eval *= 2) { + samples.reserve(samples.size() + samples_per_eval); + for (size_t i = 0; i < samples_per_eval; ++i) { + const Ticks r = measure_duration(); + samples.push_back(r); + } + + if (samples.size() >= p.min_mode_samples) { + est = robust_statistics::Mode(samples.data(), samples.size()); + } else { + // For "few" (depends also on the variance) samples, Median is safer. + est = robust_statistics::Median(samples.data(), samples.size()); + } + ABSL_RAW_CHECK(est != 0, "Estimator returned zero duration"); + + // Median absolute deviation (mad) is a robust measure of 'variability'. + const Ticks abs_mad = robust_statistics::MedianAbsoluteDeviation( + samples.data(), samples.size(), est); + *rel_mad = static_cast<double>(static_cast<int>(abs_mad)) / est; + + if (*rel_mad <= max_rel_mad || abs_mad <= max_abs_mad) { + if (p.verbose) { + ABSL_RAW_LOG(INFO, + "%6zu samples => %5u (abs_mad=%4u, rel_mad=%4.2f%%)\n", + samples.size(), est, abs_mad, *rel_mad * 100.0); + } + return est; + } + } + + if (p.verbose) { + ABSL_RAW_LOG(WARNING, + "rel_mad=%4.2f%% still exceeds %4.2f%% after %6zu samples.\n", + *rel_mad * 100.0, max_rel_mad * 100.0, samples.size()); + } + return est; +} + +using InputVec = std::vector<FuncInput>; + +// Returns vector of unique input values. +InputVec UniqueInputs(const FuncInput* inputs, const size_t num_inputs) { + InputVec unique(inputs, inputs + num_inputs); + std::sort(unique.begin(), unique.end()); + unique.erase(std::unique(unique.begin(), unique.end()), unique.end()); + return unique; +} + +// Returns how often we need to call func for sufficient precision, or zero +// on failure (e.g. the elapsed time is too long for a 32-bit tick count). +size_t NumSkip(const Func func, const void* arg, const InputVec& unique, + const Params& p) { + // Min elapsed ticks for any input. + Ticks min_duration = ~0u; + + for (const FuncInput input : unique) { + // Make sure a 32-bit timer is sufficient. + const uint64_t t0 = timer::Start64(); + PreventElision(func(arg, input)); + const uint64_t t1 = timer::Stop64(); + const uint64_t elapsed = t1 - t0; + if (elapsed >= (1ULL << 30)) { + ABSL_RAW_LOG(WARNING, + "Measurement failed: need 64-bit timer for input=%zu\n", + static_cast<size_t>(input)); + return 0; + } + + double rel_mad; + const Ticks total = SampleUntilStable( + p.target_rel_mad, &rel_mad, p, + [func, arg, input]() { PreventElision(func(arg, input)); }); + min_duration = (std::min)(min_duration, total - timer_resolution); + } + + // Number of repetitions required to reach the target resolution. + const size_t max_skip = p.precision_divisor; + // Number of repetitions given the estimated duration. + const size_t num_skip = + min_duration == 0 ? 0 : (max_skip + min_duration - 1) / min_duration; + if (p.verbose) { + ABSL_RAW_LOG(INFO, "res=%u max_skip=%zu min_dur=%u num_skip=%zu\n", + timer_resolution, max_skip, min_duration, num_skip); + } + return num_skip; +} + +// Replicates inputs until we can omit "num_skip" occurrences of an input. +InputVec ReplicateInputs(const FuncInput* inputs, const size_t num_inputs, + const size_t num_unique, const size_t num_skip, + const Params& p) { + InputVec full; + if (num_unique == 1) { + full.assign(p.subset_ratio * num_skip, inputs[0]); + return full; + } + + full.reserve(p.subset_ratio * num_skip * num_inputs); + for (size_t i = 0; i < p.subset_ratio * num_skip; ++i) { + full.insert(full.end(), inputs, inputs + num_inputs); + } + absl::random_internal::randen_engine<uint32_t> rng; + std::shuffle(full.begin(), full.end(), rng); + return full; +} + +// Copies the "full" to "subset" in the same order, but with "num_skip" +// randomly selected occurrences of "input_to_skip" removed. +void FillSubset(const InputVec& full, const FuncInput input_to_skip, + const size_t num_skip, InputVec* subset) { + const size_t count = std::count(full.begin(), full.end(), input_to_skip); + // Generate num_skip random indices: which occurrence to skip. + std::vector<uint32_t> omit; + // Replacement for std::iota, not yet available in MSVC builds. + omit.reserve(count); + for (size_t i = 0; i < count; ++i) { + omit.push_back(i); + } + // omit[] is the same on every call, but that's OK because they identify the + // Nth instance of input_to_skip, so the position within full[] differs. + absl::random_internal::randen_engine<uint32_t> rng; + std::shuffle(omit.begin(), omit.end(), rng); + omit.resize(num_skip); + std::sort(omit.begin(), omit.end()); + + uint32_t occurrence = ~0u; // 0 after preincrement + size_t idx_omit = 0; // cursor within omit[] + size_t idx_subset = 0; // cursor within *subset + for (const FuncInput next : full) { + if (next == input_to_skip) { + ++occurrence; + // Haven't removed enough already + if (idx_omit < num_skip) { + // This one is up for removal + if (occurrence == omit[idx_omit]) { + ++idx_omit; + continue; + } + } + } + if (idx_subset < subset->size()) { + (*subset)[idx_subset++] = next; + } + } + ABSL_RAW_CHECK(idx_subset == subset->size(), "idx_subset not at end"); + ABSL_RAW_CHECK(idx_omit == omit.size(), "idx_omit not at end"); + ABSL_RAW_CHECK(occurrence == count - 1, "occurrence not at end"); +} + +// Returns total ticks elapsed for all inputs. +Ticks TotalDuration(const Func func, const void* arg, const InputVec* inputs, + const Params& p, double* max_rel_mad) { + double rel_mad; + const Ticks duration = + SampleUntilStable(p.target_rel_mad, &rel_mad, p, [func, arg, inputs]() { + for (const FuncInput input : *inputs) { + PreventElision(func(arg, input)); + } + }); + *max_rel_mad = (std::max)(*max_rel_mad, rel_mad); + return duration; +} + +// (Nearly) empty Func for measuring timer overhead/resolution. +ABSL_ATTRIBUTE_NEVER_INLINE FuncOutput EmptyFunc(const void* arg, + const FuncInput input) { + return input; +} + +// Returns overhead of accessing inputs[] and calling a function; this will +// be deducted from future TotalDuration return values. +Ticks Overhead(const void* arg, const InputVec* inputs, const Params& p) { + double rel_mad; + // Zero tolerance because repeatability is crucial and EmptyFunc is fast. + return SampleUntilStable(0.0, &rel_mad, p, [arg, inputs]() { + for (const FuncInput input : *inputs) { + PreventElision(EmptyFunc(arg, input)); + } + }); +} + +} // namespace + +void PinThreadToCPU(int cpu) { + // We might migrate to another CPU before pinning below, but at least cpu + // will be one of the CPUs on which this thread ran. +#if defined(ABSL_OS_WIN) + if (cpu < 0) { + cpu = static_cast<int>(GetCurrentProcessorNumber()); + ABSL_RAW_CHECK(cpu >= 0, "PinThreadToCPU detect failed"); + if (cpu >= 64) { + // NOTE: On wine, at least, GetCurrentProcessorNumber() sometimes returns + // a value > 64, which is out of range. When this happens, log a message + // and don't set a cpu affinity. + ABSL_RAW_LOG(ERROR, "Invalid CPU number: %d", cpu); + return; + } + } else if (cpu >= 64) { + // User specified an explicit CPU affinity > the valid range. + ABSL_RAW_LOG(FATAL, "Invalid CPU number: %d", cpu); + } + const DWORD_PTR prev = SetThreadAffinityMask(GetCurrentThread(), 1ULL << cpu); + ABSL_RAW_CHECK(prev != 0, "SetAffinity failed"); +#elif defined(ABSL_OS_LINUX) && !defined(ABSL_OS_ANDROID) + if (cpu < 0) { + cpu = sched_getcpu(); + ABSL_RAW_CHECK(cpu >= 0, "PinThreadToCPU detect failed"); + } + const pid_t pid = 0; // current thread + cpu_set_t set; + CPU_ZERO(&set); + CPU_SET(cpu, &set); + const int err = sched_setaffinity(pid, sizeof(set), &set); + ABSL_RAW_CHECK(err == 0, "SetAffinity failed"); +#endif +} + +// Returns tick rate. Invariant means the tick counter frequency is independent +// of CPU throttling or sleep. May be expensive, caller should cache the result. +double InvariantTicksPerSecond() { +#if defined(ABSL_ARCH_PPC) + return __ppc_get_timebase_freq(); +#elif defined(ABSL_ARCH_X86_64) + // We assume the TSC is invariant; it is on all recent Intel/AMD CPUs. + return platform::NominalClockRate(); +#else + // Fall back to clock_gettime nanoseconds. + return 1E9; +#endif +} + +size_t MeasureImpl(const Func func, const void* arg, const size_t num_skip, + const InputVec& unique, const InputVec& full, + const Params& p, Result* results) { + const float mul = 1.0f / static_cast<int>(num_skip); + + InputVec subset(full.size() - num_skip); + const Ticks overhead = Overhead(arg, &full, p); + const Ticks overhead_skip = Overhead(arg, &subset, p); + if (overhead < overhead_skip) { + ABSL_RAW_LOG(WARNING, "Measurement failed: overhead %u < %u\n", overhead, + overhead_skip); + return 0; + } + + if (p.verbose) { + ABSL_RAW_LOG(INFO, "#inputs=%5zu,%5zu overhead=%5u,%5u\n", full.size(), + subset.size(), overhead, overhead_skip); + } + + double max_rel_mad = 0.0; + const Ticks total = TotalDuration(func, arg, &full, p, &max_rel_mad); + + for (size_t i = 0; i < unique.size(); ++i) { + FillSubset(full, unique[i], num_skip, &subset); + const Ticks total_skip = TotalDuration(func, arg, &subset, p, &max_rel_mad); + + if (total < total_skip) { + ABSL_RAW_LOG(WARNING, "Measurement failed: total %u < %u\n", total, + total_skip); + return 0; + } + + const Ticks duration = (total - overhead) - (total_skip - overhead_skip); + results[i].input = unique[i]; + results[i].ticks = duration * mul; + results[i].variability = max_rel_mad; + } + + return unique.size(); +} + +size_t Measure(const Func func, const void* arg, const FuncInput* inputs, + const size_t num_inputs, Result* results, const Params& p) { + ABSL_RAW_CHECK(num_inputs != 0, "No inputs"); + + const InputVec unique = UniqueInputs(inputs, num_inputs); + const size_t num_skip = NumSkip(func, arg, unique, p); // never 0 + if (num_skip == 0) return 0; // NumSkip already printed error message + + const InputVec full = + ReplicateInputs(inputs, num_inputs, unique.size(), num_skip, p); + + // MeasureImpl may fail up to p.max_measure_retries times. + for (size_t i = 0; i < p.max_measure_retries; i++) { + auto result = MeasureImpl(func, arg, num_skip, unique, full, p, results); + if (result != 0) { + return result; + } + } + // All retries failed. (Unusual) + return 0; +} + +} // namespace random_internal_nanobenchmark +} // namespace absl diff --git a/absl/random/internal/nanobenchmark.h b/absl/random/internal/nanobenchmark.h new file mode 100644 index 000000000000..c2b650d191b7 --- /dev/null +++ b/absl/random/internal/nanobenchmark.h @@ -0,0 +1,168 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_NANOBENCHMARK_H_ +#define ABSL_RANDOM_INTERNAL_NANOBENCHMARK_H_ + +// Benchmarks functions of a single integer argument with realistic branch +// prediction hit rates. Uses a robust estimator to summarize the measurements. +// The precision is about 0.2%. +// +// Examples: see nanobenchmark_test.cc. +// +// Background: Microbenchmarks such as http://github.com/google/benchmark +// can measure elapsed times on the order of a microsecond. Shorter functions +// are typically measured by repeating them thousands of times and dividing +// the total elapsed time by this count. Unfortunately, repetition (especially +// with the same input parameter!) influences the runtime. In time-critical +// code, it is reasonable to expect warm instruction/data caches and TLBs, +// but a perfect record of which branches will be taken is unrealistic. +// Unless the application also repeatedly invokes the measured function with +// the same parameter, the benchmark is measuring something very different - +// a best-case result, almost as if the parameter were made a compile-time +// constant. This may lead to erroneous conclusions about branch-heavy +// algorithms outperforming branch-free alternatives. +// +// Our approach differs in three ways. Adding fences to the timer functions +// reduces variability due to instruction reordering, improving the timer +// resolution to about 40 CPU cycles. However, shorter functions must still +// be invoked repeatedly. For more realistic branch prediction performance, +// we vary the input parameter according to a user-specified distribution. +// Thus, instead of VaryInputs(Measure(Repeat(func))), we change the +// loop nesting to Measure(Repeat(VaryInputs(func))). We also estimate the +// central tendency of the measurement samples with the "half sample mode", +// which is more robust to outliers and skewed data than the mean or median. + +// NOTE: for compatibility with multiple translation units compiled with +// distinct flags, avoid #including headers that define functions. + +#include <stddef.h> +#include <stdint.h> + +namespace absl { +namespace random_internal_nanobenchmark { + +// Input influencing the function being measured (e.g. number of bytes to copy). +using FuncInput = size_t; + +// "Proof of work" returned by Func to ensure the compiler does not elide it. +using FuncOutput = uint64_t; + +// Function to measure: either 1) a captureless lambda or function with two +// arguments or 2) a lambda with capture, in which case the first argument +// is reserved for use by MeasureClosure. +using Func = FuncOutput (*)(const void*, FuncInput); + +// Internal parameters that determine precision/resolution/measuring time. +struct Params { + // For measuring timer overhead/resolution. Used in a nested loop => + // quadratic time, acceptable because we know timer overhead is "low". + // constexpr because this is used to define array bounds. + static constexpr size_t kTimerSamples = 256; + + // Best-case precision, expressed as a divisor of the timer resolution. + // Larger => more calls to Func and higher precision. + size_t precision_divisor = 1024; + + // Ratio between full and subset input distribution sizes. Cannot be less + // than 2; larger values increase measurement time but more faithfully + // model the given input distribution. + size_t subset_ratio = 2; + + // Together with the estimated Func duration, determines how many times to + // call Func before checking the sample variability. Larger values increase + // measurement time, memory/cache use and precision. + double seconds_per_eval = 4E-3; + + // The minimum number of samples before estimating the central tendency. + size_t min_samples_per_eval = 7; + + // The mode is better than median for estimating the central tendency of + // skewed/fat-tailed distributions, but it requires sufficient samples + // relative to the width of half-ranges. + size_t min_mode_samples = 64; + + // Maximum permissible variability (= median absolute deviation / center). + double target_rel_mad = 0.002; + + // Abort after this many evals without reaching target_rel_mad. This + // prevents infinite loops. + size_t max_evals = 9; + + // Retry the measure loop up to this many times. + size_t max_measure_retries = 2; + + // Whether to print additional statistics to stdout. + bool verbose = true; +}; + +// Measurement result for each unique input. +struct Result { + FuncInput input; + + // Robust estimate (mode or median) of duration. + float ticks; + + // Measure of variability (median absolute deviation relative to "ticks"). + float variability; +}; + +// Ensures the thread is running on the specified cpu, and no others. +// Reduces noise due to desynchronized socket RDTSC and context switches. +// If "cpu" is negative, pin to the currently running core. +void PinThreadToCPU(const int cpu = -1); + +// Returns tick rate, useful for converting measurements to seconds. Invariant +// means the tick counter frequency is independent of CPU throttling or sleep. +// This call may be expensive, callers should cache the result. +double InvariantTicksPerSecond(); + +// Precisely measures the number of ticks elapsed when calling "func" with the +// given inputs, shuffled to ensure realistic branch prediction hit rates. +// +// "func" returns a 'proof of work' to ensure its computations are not elided. +// "arg" is passed to Func, or reserved for internal use by MeasureClosure. +// "inputs" is an array of "num_inputs" (not necessarily unique) arguments to +// "func". The values should be chosen to maximize coverage of "func". This +// represents a distribution, so a value's frequency should reflect its +// probability in the real application. Order does not matter; for example, a +// uniform distribution over [0, 4) could be represented as {3,0,2,1}. +// Returns how many Result were written to "results": one per unique input, or +// zero if the measurement failed (an error message goes to stderr). +size_t Measure(const Func func, const void* arg, const FuncInput* inputs, + const size_t num_inputs, Result* results, + const Params& p = Params()); + +// Calls operator() of the given closure (lambda function). +template <class Closure> +static FuncOutput CallClosure(const void* f, const FuncInput input) { + return (*reinterpret_cast<const Closure*>(f))(input); +} + +// Same as Measure, except "closure" is typically a lambda function of +// FuncInput -> FuncOutput with a capture list. +template <class Closure> +static inline size_t MeasureClosure(const Closure& closure, + const FuncInput* inputs, + const size_t num_inputs, Result* results, + const Params& p = Params()) { + return Measure(reinterpret_cast<Func>(&CallClosure<Closure>), + reinterpret_cast<const void*>(&closure), inputs, num_inputs, + results, p); +} + +} // namespace random_internal_nanobenchmark +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_NANOBENCHMARK_H_ diff --git a/absl/random/internal/nanobenchmark_test.cc b/absl/random/internal/nanobenchmark_test.cc new file mode 100644 index 000000000000..383345a8b474 --- /dev/null +++ b/absl/random/internal/nanobenchmark_test.cc @@ -0,0 +1,75 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/nanobenchmark.h" + +#include "absl/base/internal/raw_logging.h" +#include "absl/strings/numbers.h" + +namespace absl { +namespace random_internal_nanobenchmark { +namespace { + +uint64_t Div(const void*, FuncInput in) { + // Here we're measuring the throughput because benchmark invocations are + // independent. + const int64_t d1 = 0xFFFFFFFFFFll / int64_t(in); // IDIV + return d1; +} + +template <size_t N> +void MeasureDiv(const FuncInput (&inputs)[N]) { + Result results[N]; + Params params; + params.max_evals = 6; // avoid test timeout + const size_t num_results = Measure(&Div, nullptr, inputs, N, results, params); + if (num_results == 0) { + ABSL_RAW_LOG( + WARNING, + "WARNING: Measurement failed, should not happen when using " + "PinThreadToCPU unless the region to measure takes > 1 second.\n"); + return; + } + for (size_t i = 0; i < num_results; ++i) { + ABSL_RAW_LOG(INFO, "%5zu: %6.2f ticks; MAD=%4.2f%%\n", results[i].input, + results[i].ticks, results[i].variability * 100.0); + ABSL_RAW_CHECK(results[i].ticks != 0.0f, "Zero duration"); + } +} + +void RunAll(const int argc, char* argv[]) { + // Avoid migrating between cores - important on multi-socket systems. + int cpu = -1; + if (argc == 2) { + if (!SimpleAtoi(argv[1], &cpu)) { + ABSL_RAW_LOG(FATAL, "The optional argument must be a CPU number >= 0.\n"); + } + } + PinThreadToCPU(cpu); + + // unpredictable == 1 but the compiler doesn't know that. + const FuncInput unpredictable = argc != 999; + static const FuncInput inputs[] = {unpredictable * 10, unpredictable * 100}; + + MeasureDiv(inputs); +} + +} // namespace +} // namespace random_internal_nanobenchmark +} // namespace absl + +int main(int argc, char* argv[]) { + absl::random_internal_nanobenchmark::RunAll(argc, argv); + return 0; +} diff --git a/absl/random/internal/nonsecure_base.h b/absl/random/internal/nonsecure_base.h new file mode 100644 index 000000000000..8847e74bef6b --- /dev/null +++ b/absl/random/internal/nonsecure_base.h @@ -0,0 +1,148 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_NONSECURE_BASE_H_ +#define ABSL_RANDOM_INTERNAL_NONSECURE_BASE_H_ + +#include <algorithm> +#include <cstdint> +#include <iostream> +#include <iterator> +#include <random> +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/base/macros.h" +#include "absl/meta/type_traits.h" +#include "absl/random/internal/pool_urbg.h" +#include "absl/random/internal/salted_seed_seq.h" +#include "absl/random/internal/seed_material.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" + +namespace absl { +namespace random_internal { + +// Each instance of NonsecureURBGBase<URBG> will be seeded by variates produced +// by a thread-unique URBG-instance. +template <typename URBG> +class NonsecureURBGBase { + public: + using result_type = typename URBG::result_type; + + // Default constructor + NonsecureURBGBase() : urbg_(ConstructURBG()) {} + + // Copy disallowed, move allowed. + NonsecureURBGBase(const NonsecureURBGBase&) = delete; + NonsecureURBGBase& operator=(const NonsecureURBGBase&) = delete; + NonsecureURBGBase(NonsecureURBGBase&&) = default; + NonsecureURBGBase& operator=(NonsecureURBGBase&&) = default; + + // Constructor using a seed + template <class SSeq, typename = typename absl::enable_if_t< + !std::is_same<SSeq, NonsecureURBGBase>::value>> + explicit NonsecureURBGBase(SSeq&& seq) + : urbg_(ConstructURBG(std::forward<SSeq>(seq))) {} + + // Note: on MSVC, min() or max() can be interpreted as MIN() or MAX(), so we + // enclose min() or max() in parens as (min)() and (max)(). + // Additionally, clang-format requires no space before this construction. + + // NonsecureURBGBase::min() + static constexpr result_type(min)() { return (URBG::min)(); } + + // NonsecureURBGBase::max() + static constexpr result_type(max)() { return (URBG::max)(); } + + // NonsecureURBGBase::operator()() + result_type operator()() { return urbg_(); } + + // NonsecureURBGBase::discard() + void discard(unsigned long long values) { // NOLINT(runtime/int) + urbg_.discard(values); + } + + bool operator==(const NonsecureURBGBase& other) const { + return urbg_ == other.urbg_; + } + + bool operator!=(const NonsecureURBGBase& other) const { + return !(urbg_ == other.urbg_); + } + + private: + // Seeder is a custom seed sequence type where generate() fills the provided + // buffer via the RandenPool entropy source. + struct Seeder { + using result_type = uint32_t; + + size_t size() { return 0; } + + template <typename OutIterator> + void param(OutIterator) const {} + + template <typename RandomAccessIterator> + void generate(RandomAccessIterator begin, RandomAccessIterator end) { + if (begin != end) { + // begin, end must be random access iterators assignable from uint32_t. + generate_impl( + std::integral_constant<bool, sizeof(*begin) == sizeof(uint32_t)>{}, + begin, end); + } + } + + // Commonly, generate is invoked with a pointer to a buffer which + // can be cast to a uint32_t. + template <typename RandomAccessIterator> + void generate_impl(std::integral_constant<bool, true>, + RandomAccessIterator begin, RandomAccessIterator end) { + auto buffer = absl::MakeSpan(begin, end); + auto target = absl::MakeSpan(reinterpret_cast<uint32_t*>(buffer.data()), + buffer.size()); + RandenPool<uint32_t>::Fill(target); + } + + // The non-uint32_t case should be uncommon, and involves an extra copy, + // filling the uint32_t buffer and then mixing into the output. + template <typename RandomAccessIterator> + void generate_impl(std::integral_constant<bool, false>, + RandomAccessIterator begin, RandomAccessIterator end) { + const size_t n = std::distance(begin, end); + absl::InlinedVector<uint32_t, 8> data(n, 0); + RandenPool<uint32_t>::Fill(absl::MakeSpan(data.begin(), data.end())); + std::copy(std::begin(data), std::end(data), begin); + } + }; + + static URBG ConstructURBG() { + Seeder seeder; + return URBG(seeder); + } + + template <typename SSeq> + static URBG ConstructURBG(SSeq&& seq) { // NOLINT(runtime/references) + auto salted_seq = + random_internal::MakeSaltedSeedSeq(std::forward<SSeq>(seq)); + return URBG(salted_seq); + } + + URBG urbg_; +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_NONSECURE_BASE_H_ diff --git a/absl/random/internal/nonsecure_base_test.cc b/absl/random/internal/nonsecure_base_test.cc new file mode 100644 index 000000000000..d9de99019ce0 --- /dev/null +++ b/absl/random/internal/nonsecure_base_test.cc @@ -0,0 +1,244 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/nonsecure_base.h" + +#include <algorithm> +#include <iostream> +#include <memory> +#include <random> +#include <sstream> + +#include "gtest/gtest.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" + +namespace { + +using ExampleNonsecureURBG = + absl::random_internal::NonsecureURBGBase<std::mt19937>; + +template <typename T> +void Use(const T&) {} + +} // namespace + +TEST(NonsecureURBGBase, DefaultConstructorIsValid) { + ExampleNonsecureURBG urbg; +} + +// Ensure that the recommended template-instantiations are valid. +TEST(RecommendedTemplates, CanBeConstructed) { + absl::BitGen default_generator; + absl::InsecureBitGen insecure_generator; +} + +TEST(RecommendedTemplates, CanDiscardValues) { + absl::BitGen default_generator; + absl::InsecureBitGen insecure_generator; + + default_generator.discard(5); + insecure_generator.discard(5); +} + +TEST(NonsecureURBGBase, StandardInterface) { + // Names after definition of [rand.req.urbg] in C++ standard. + // e us a value of E + // v is a lvalue of E + // x, y are possibly const values of E + // s is a value of T + // q is a value satisfying requirements of seed_sequence + // z is a value of type unsigned long long + // os is a some specialization of basic_ostream + // is is a some specialization of basic_istream + + using E = absl::random_internal::NonsecureURBGBase<std::minstd_rand>; + + using T = typename E::result_type; + + static_assert(!std::is_copy_constructible<E>::value, + "NonsecureURBGBase should not be copy constructible"); + + static_assert(!absl::is_copy_assignable<E>::value, + "NonsecureURBGBase should not be copy assignable"); + + static_assert(std::is_move_constructible<E>::value, + "NonsecureURBGBase should be move constructible"); + + static_assert(absl::is_move_assignable<E>::value, + "NonsecureURBGBase should be move assignable"); + + static_assert(std::is_same<decltype(std::declval<E>()()), T>::value, + "return type of operator() must be result_type"); + + { + const E x, y; + Use(x); + Use(y); + + static_assert(std::is_same<decltype(x == y), bool>::value, + "return type of operator== must be bool"); + + static_assert(std::is_same<decltype(x != y), bool>::value, + "return type of operator== must be bool"); + } + + E e; + std::seed_seq q{1, 2, 3}; + + E{}; + E{q}; + + // Copy constructor not supported. + // E{x}; + + // result_type seed constructor not supported. + // E{T{1}}; + + // Move constructors are supported. + { + E tmp(q); + E m = std::move(tmp); + E n(std::move(m)); + EXPECT_TRUE(e != n); + } + + // Comparisons work. + { + // MSVC emits error 2718 when using EXPECT_EQ(e, x) + // * actual parameter with __declspec(align('#')) won't be aligned + E a(q); + E b(q); + + EXPECT_TRUE(a != e); + EXPECT_TRUE(a == b); + + a(); + EXPECT_TRUE(a != b); + } + + // e.seed(s) not supported. + + // [rand.req.eng] specifies the parameter as 'unsigned long long' + // e.discard(unsigned long long) is supported. + unsigned long long z = 1; // NOLINT(runtime/int) + e.discard(z); +} + +TEST(NonsecureURBGBase, SeedSeqConstructorIsValid) { + std::seed_seq seq; + ExampleNonsecureURBG rbg(seq); +} + +TEST(NonsecureURBGBase, CompatibleWithDistributionUtils) { + ExampleNonsecureURBG rbg; + + absl::Uniform(rbg, 0, 100); + absl::Uniform(rbg, 0.5, 0.7); + absl::Poisson<uint32_t>(rbg); + absl::Exponential<float>(rbg); +} + +TEST(NonsecureURBGBase, CompatibleWithStdDistributions) { + ExampleNonsecureURBG rbg; + + std::uniform_int_distribution<uint32_t>(0, 100)(rbg); + std::uniform_real_distribution<float>()(rbg); + std::bernoulli_distribution(0.2)(rbg); +} + +TEST(NonsecureURBGBase, ConsecutiveDefaultInstancesYieldUniqueVariates) { + const size_t kNumSamples = 128; + + ExampleNonsecureURBG rbg1; + ExampleNonsecureURBG rbg2; + + for (size_t i = 0; i < kNumSamples; i++) { + EXPECT_NE(rbg1(), rbg2()); + } +} + +TEST(NonsecureURBGBase, EqualSeedSequencesYieldEqualVariates) { + std::seed_seq seq; + + ExampleNonsecureURBG rbg1(seq); + ExampleNonsecureURBG rbg2(seq); + + // ExampleNonsecureURBG rbg3({1, 2, 3}); // Should not compile. + + for (uint32_t i = 0; i < 1000; i++) { + EXPECT_EQ(rbg1(), rbg2()); + } + + rbg1.discard(100); + rbg2.discard(100); + + // The sequences should continue after discarding + for (uint32_t i = 0; i < 1000; i++) { + EXPECT_EQ(rbg1(), rbg2()); + } +} + +// This is a PRNG-compatible type specifically designed to test +// that NonsecureURBGBase::Seeder can correctly handle iterators +// to arbitrary non-uint32_t size types. +template <typename T> +struct SeederTestEngine { + using result_type = T; + + static constexpr result_type(min)() { + return (std::numeric_limits<result_type>::min)(); + } + static constexpr result_type(max)() { + return (std::numeric_limits<result_type>::max)(); + } + + template <class SeedSequence, + typename = typename absl::enable_if_t< + !std::is_same<SeedSequence, SeederTestEngine>::value>> + explicit SeederTestEngine(SeedSequence&& seq) { + seed(seq); + } + + SeederTestEngine(const SeederTestEngine&) = default; + SeederTestEngine& operator=(const SeederTestEngine&) = default; + SeederTestEngine(SeederTestEngine&&) = default; + SeederTestEngine& operator=(SeederTestEngine&&) = default; + + result_type operator()() { return state[0]; } + + template <class SeedSequence> + void seed(SeedSequence&& seq) { + std::fill(std::begin(state), std::end(state), T(0)); + seq.generate(std::begin(state), std::end(state)); + } + + T state[2]; +}; + +TEST(NonsecureURBGBase, SeederWorksForU32) { + using U32 = + absl::random_internal::NonsecureURBGBase<SeederTestEngine<uint32_t>>; + U32 x; + EXPECT_NE(0, x()); +} + +TEST(NonsecureURBGBase, SeederWorksForU64) { + using U64 = + absl::random_internal::NonsecureURBGBase<SeederTestEngine<uint64_t>>; + + U64 x; + EXPECT_NE(0, x()); +} diff --git a/absl/random/internal/pcg_engine.h b/absl/random/internal/pcg_engine.h new file mode 100644 index 000000000000..33fea0b9689f --- /dev/null +++ b/absl/random/internal/pcg_engine.h @@ -0,0 +1,305 @@ +// Copyright 2018 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_PCG_ENGINE_H_ +#define ABSL_RANDOM_PCG_ENGINE_H_ + +#include <type_traits> + +#include "absl/base/config.h" +#include "absl/meta/type_traits.h" +#include "absl/numeric/int128.h" +#include "absl/random/internal/fastmath.h" +#include "absl/random/internal/iostream_state_saver.h" + +namespace absl { +namespace random_internal { + +// pcg_engine is a simplified implementation of Melissa O'Neil's PCG engine in +// C++. PCG combines a linear congruential generator (LCG) with output state +// mixing functions to generate each random variate. pcg_engine supports only a +// single sequence (oneseq), and does not support streams. +// +// pcg_engine is parameterized by two types: +// Params, which provides the multiplier and increment values; +// Mix, which mixes the state into the result. +// +template <typename Params, typename Mix> +class pcg_engine { + static_assert(std::is_same<typename Params::state_type, + typename Mix::state_type>::value, + "Class-template absl::pcg_engine must be parameterized by " + "Params and Mix with identical state_type"); + + static_assert(std::is_unsigned<typename Mix::result_type>::value, + "Class-template absl::pcg_engine must be parameterized by " + "an unsigned Mix::result_type"); + + using params_type = Params; + using mix_type = Mix; + using state_type = typename Mix::state_type; + + public: + // C++11 URBG interface: + using result_type = typename Mix::result_type; + + static constexpr result_type(min)() { + return (std::numeric_limits<result_type>::min)(); + } + + static constexpr result_type(max)() { + return (std::numeric_limits<result_type>::max)(); + } + + explicit pcg_engine(uint64_t seed_value = 0) { seed(seed_value); } + + template <class SeedSequence, + typename = typename absl::enable_if_t< + !std::is_same<SeedSequence, pcg_engine>::value>> + explicit pcg_engine(SeedSequence&& seq) { + seed(seq); + } + + pcg_engine(const pcg_engine&) = default; + pcg_engine& operator=(const pcg_engine&) = default; + pcg_engine(pcg_engine&&) = default; + pcg_engine& operator=(pcg_engine&&) = default; + + result_type operator()() { + // Advance the LCG state, always using the new value to generate the output. + state_ = lcg(state_); + return Mix{}(state_); + } + + void seed(uint64_t seed_value = 0) { + state_type tmp = seed_value; + state_ = lcg(tmp + Params::increment()); + } + + template <class SeedSequence> + typename absl::enable_if_t< + !std::is_convertible<SeedSequence, uint64_t>::value, void> + seed(SeedSequence&& seq) { + reseed(seq); + } + + void discard(uint64_t count) { state_ = advance(state_, count); } + + bool operator==(const pcg_engine& other) const { + return state_ == other.state_; + } + + bool operator!=(const pcg_engine& other) const { return !(*this == other); } + + template <class CharT, class Traits> + friend typename absl::enable_if_t<(sizeof(state_type) == 16), + std::basic_ostream<CharT, Traits>&> + operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const pcg_engine& engine) { + auto saver = random_internal::make_ostream_state_saver(os); + random_internal::stream_u128_helper<state_type> helper; + helper.write(pcg_engine::params_type::multiplier(), os); + os << os.fill(); + helper.write(pcg_engine::params_type::increment(), os); + os << os.fill(); + helper.write(engine.state_, os); + return os; + } + + template <class CharT, class Traits> + friend typename absl::enable_if_t<(sizeof(state_type) <= 8), + std::basic_ostream<CharT, Traits>&> + operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const pcg_engine& engine) { + auto saver = random_internal::make_ostream_state_saver(os); + os << pcg_engine::params_type::multiplier() << os.fill(); + os << pcg_engine::params_type::increment() << os.fill(); + os << engine.state_; + return os; + } + + template <class CharT, class Traits> + friend typename absl::enable_if_t<(sizeof(state_type) == 16), + std::basic_istream<CharT, Traits>&> + operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + pcg_engine& engine) { // NOLINT(runtime/references) + random_internal::stream_u128_helper<state_type> helper; + auto mult = helper.read(is); + auto inc = helper.read(is); + auto tmp = helper.read(is); + if (mult != pcg_engine::params_type::multiplier() || + inc != pcg_engine::params_type::increment()) { + // signal failure by setting the failbit. + is.setstate(is.rdstate() | std::ios_base::failbit); + } + if (!is.fail()) { + engine.state_ = tmp; + } + return is; + } + + template <class CharT, class Traits> + friend typename absl::enable_if_t<(sizeof(state_type) <= 8), + std::basic_istream<CharT, Traits>&> + operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + pcg_engine& engine) { // NOLINT(runtime/references) + state_type mult{}, inc{}, tmp{}; + is >> mult >> inc >> tmp; + if (mult != pcg_engine::params_type::multiplier() || + inc != pcg_engine::params_type::increment()) { + // signal failure by setting the failbit. + is.setstate(is.rdstate() | std::ios_base::failbit); + } + if (!is.fail()) { + engine.state_ = tmp; + } + return is; + } + + private: + state_type state_; + + // Returns the linear-congruential generator next state. + static inline constexpr state_type lcg(state_type s) { + return s * Params::multiplier() + Params::increment(); + } + + // Returns the linear-congruential arbitrary seek state. + inline state_type advance(state_type s, uint64_t n) const { + state_type mult = Params::multiplier(); + state_type inc = Params::increment(); + state_type m = 1; + state_type i = 0; + while (n > 0) { + if (n & 1) { + m *= mult; + i = i * mult + inc; + } + inc = (mult + 1) * inc; + mult *= mult; + n >>= 1; + } + return m * s + i; + } + + template <class SeedSequence> + void reseed(SeedSequence& seq) { + using sequence_result_type = typename SeedSequence::result_type; + constexpr size_t kBufferSize = + sizeof(state_type) / sizeof(sequence_result_type); + sequence_result_type buffer[kBufferSize]; + seq.generate(std::begin(buffer), std::end(buffer)); + // Convert the seed output to a single state value. + state_type tmp = buffer[0]; + for (size_t i = 1; i < kBufferSize; i++) { + tmp <<= (sizeof(sequence_result_type) * 8); + tmp |= buffer[i]; + } + state_ = lcg(tmp + params_type::increment()); + } +}; + +// Parameterized implementation of the PCG 128-bit oneseq state. +// This provides state_type, multiplier, and increment for pcg_engine. +template <uint64_t kMultA, uint64_t kMultB, uint64_t kIncA, uint64_t kIncB> +class pcg128_params { + public: +#if ABSL_HAVE_INTRINSIC_INT128 + using state_type = __uint128_t; + static inline constexpr state_type make_u128(uint64_t a, uint64_t b) { + return (static_cast<__uint128_t>(a) << 64) | b; + } +#else + using state_type = absl::uint128; + static inline constexpr state_type make_u128(uint64_t a, uint64_t b) { + return absl::MakeUint128(a, b); + } +#endif + + static inline constexpr state_type multiplier() { + return make_u128(kMultA, kMultB); + } + static inline constexpr state_type increment() { + return make_u128(kIncA, kIncB); + } +}; + +// Implementation of the PCG xsl_rr_128_64 128-bit mixing function, which +// accepts an input of state_type and mixes it into an output of result_type. +struct pcg_xsl_rr_128_64 { +#if ABSL_HAVE_INTRINSIC_INT128 + using state_type = __uint128_t; +#else + using state_type = absl::uint128; +#endif + using result_type = uint64_t; + + inline uint64_t operator()(state_type state) { + // This is equivalent to the xsl_rr_128_64 mixing function. +#if ABSL_HAVE_INTRINSIC_INT128 + uint64_t rotate = static_cast<uint64_t>(state >> 122u); + state ^= state >> 64; + uint64_t s = static_cast<uint64_t>(state); +#else + uint64_t h = Uint128High64(state); + uint64_t rotate = h >> 58u; + uint64_t s = Uint128Low64(state) ^ h; +#endif + return random_internal::rotr(s, rotate); + } +}; + +// Parameterized implementation of the PCG 64-bit oneseq state. +// This provides state_type, multiplier, and increment for pcg_engine. +template <uint64_t kMult, uint64_t kInc> +class pcg64_params { + public: + using state_type = uint64_t; + static inline constexpr state_type multiplier() { return kMult; } + static inline constexpr state_type increment() { return kInc; } +}; + +// Implementation of the PCG xsh_rr_64_32 64-bit mixing function, which accepts +// an input of state_type and mixes it into an output of result_type. +struct pcg_xsh_rr_64_32 { + using state_type = uint64_t; + using result_type = uint32_t; + inline uint32_t operator()(uint64_t state) { + return random_internal::rotr( + static_cast<uint32_t>(((state >> 18) ^ state) >> 27), state >> 59); + } +}; + +// Stable pcg_engine implementations: +// This is a 64-bit generator using 128-bits of state. +// The output sequence is equivalent to Melissa O'Neil's pcg64_oneseq. +using pcg64_2018_engine = pcg_engine< + random_internal::pcg128_params<0x2360ed051fc65da4ull, 0x4385df649fccf645ull, + 0x5851f42d4c957f2d, 0x14057b7ef767814f>, + random_internal::pcg_xsl_rr_128_64>; + +// This is a 32-bit generator using 64-bits of state. +// This is equivalent to Melissa O'Neil's pcg32_oneseq. +using pcg32_2018_engine = pcg_engine< + random_internal::pcg64_params<0x5851f42d4c957f2dull, 0x14057b7ef767814full>, + random_internal::pcg_xsh_rr_64_32>; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_PCG2018_ENGINE_H_ diff --git a/absl/random/internal/pcg_engine_test.cc b/absl/random/internal/pcg_engine_test.cc new file mode 100644 index 000000000000..4d763e89eb58 --- /dev/null +++ b/absl/random/internal/pcg_engine_test.cc @@ -0,0 +1,638 @@ +// Copyright 2018 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/pcg_engine.h" + +#include <algorithm> +#include <bitset> +#include <random> +#include <sstream> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/random/internal/explicit_seed_seq.h" +#include "absl/time/clock.h" + +#define UPDATE_GOLDEN 0 + +namespace { + +using absl::random_internal::ExplicitSeedSeq; +using absl::random_internal::pcg32_2018_engine; +using absl::random_internal::pcg64_2018_engine; + +template <typename EngineType> +class PCGEngineTest : public ::testing::Test {}; + +using EngineTypes = ::testing::Types<pcg64_2018_engine, pcg32_2018_engine>; + +TYPED_TEST_SUITE(PCGEngineTest, EngineTypes); + +TYPED_TEST(PCGEngineTest, VerifyReseedChangesAllValues) { + using engine_type = TypeParam; + using result_type = typename engine_type::result_type; + + const size_t kNumOutputs = 16; + engine_type engine; + + // MSVC emits error 2719 without the use of std::ref below. + // * formal parameter with __declspec(align('#')) won't be aligned + + { + std::seed_seq seq1{1, 2, 3, 4, 5, 6, 7}; + engine.seed(seq1); + } + result_type a[kNumOutputs]; + std::generate(std::begin(a), std::end(a), std::ref(engine)); + + { + std::random_device rd; + std::seed_seq seq2{rd(), rd(), rd()}; + engine.seed(seq2); + } + result_type b[kNumOutputs]; + std::generate(std::begin(b), std::end(b), std::ref(engine)); + + // Verify that two uncorrelated values have ~50% of there bits in common. Use + // a 10% margin-of-error to reduce flakiness. + size_t changed_bits = 0; + size_t unchanged_bits = 0; + size_t total_set = 0; + size_t total_bits = 0; + size_t equal_count = 0; + for (size_t i = 0; i < kNumOutputs; ++i) { + equal_count += (a[i] == b[i]) ? 1 : 0; + std::bitset<sizeof(result_type) * 8> bitset(a[i] ^ b[i]); + changed_bits += bitset.count(); + unchanged_bits += bitset.size() - bitset.count(); + + std::bitset<sizeof(result_type) * 8> a_set(a[i]); + std::bitset<sizeof(result_type) * 8> b_set(b[i]); + total_set += a_set.count() + b_set.count(); + total_bits += 2 * 8 * sizeof(result_type); + } + // On average, half the bits are changed between two calls. + EXPECT_LE(changed_bits, 0.60 * (changed_bits + unchanged_bits)); + EXPECT_GE(changed_bits, 0.40 * (changed_bits + unchanged_bits)); + + // verify using a quick normal-approximation to the binomial. + EXPECT_NEAR(total_set, total_bits * 0.5, 4 * std::sqrt(total_bits)) + << "@" << total_set / static_cast<double>(total_bits); + + // Also, A[i] == B[i] with probability (1/range) * N. + // Give this a pretty wide latitude, though. + const double kExpected = kNumOutputs / (1.0 * sizeof(result_type) * 8); + EXPECT_LE(equal_count, 1.0 + kExpected); +} + +// Number of values that needs to be consumed to clean two sizes of buffer +// and trigger third refresh. (slightly overestimates the actual state size). +constexpr size_t kTwoBufferValues = 16; + +TYPED_TEST(PCGEngineTest, VerifyDiscard) { + using engine_type = TypeParam; + + for (size_t num_used = 0; num_used < kTwoBufferValues; ++num_used) { + engine_type engine_used; + for (size_t i = 0; i < num_used; ++i) { + engine_used(); + } + + for (size_t num_discard = 0; num_discard < kTwoBufferValues; + ++num_discard) { + engine_type engine1 = engine_used; + engine_type engine2 = engine_used; + for (size_t i = 0; i < num_discard; ++i) { + engine1(); + } + engine2.discard(num_discard); + for (size_t i = 0; i < kTwoBufferValues; ++i) { + const auto r1 = engine1(); + const auto r2 = engine2(); + ASSERT_EQ(r1, r2) << "used=" << num_used << " discard=" << num_discard; + } + } + } +} + +TYPED_TEST(PCGEngineTest, StreamOperatorsResult) { + using engine_type = TypeParam; + + std::wostringstream os; + std::wistringstream is; + engine_type engine; + + EXPECT_EQ(&(os << engine), &os); + EXPECT_EQ(&(is >> engine), &is); +} + +TYPED_TEST(PCGEngineTest, StreamSerialization) { + using engine_type = TypeParam; + + for (size_t discard = 0; discard < kTwoBufferValues; ++discard) { + ExplicitSeedSeq seed_sequence{12, 34, 56}; + engine_type engine(seed_sequence); + engine.discard(discard); + + std::stringstream stream; + stream << engine; + + engine_type new_engine; + stream >> new_engine; + for (size_t i = 0; i < 64; ++i) { + EXPECT_EQ(engine(), new_engine()) << " " << i; + } + } +} + +constexpr size_t kNumGoldenOutputs = 127; + +// This test is checking if randen_engine is meets interface requirements +// defined in [rand.req.urbg]. +TYPED_TEST(PCGEngineTest, RandomNumberEngineInterface) { + using engine_type = TypeParam; + + using E = engine_type; + using T = typename E::result_type; + + static_assert(std::is_copy_constructible<E>::value, + "engine_type must be copy constructible"); + + static_assert(absl::is_copy_assignable<E>::value, + "engine_type must be copy assignable"); + + static_assert(std::is_move_constructible<E>::value, + "engine_type must be move constructible"); + + static_assert(absl::is_move_assignable<E>::value, + "engine_type must be move assignable"); + + static_assert(std::is_same<decltype(std::declval<E>()()), T>::value, + "return type of operator() must be result_type"); + + // Names after definition of [rand.req.urbg] in C++ standard. + // e us a value of E + // v is a lvalue of E + // x, y are possibly const values of E + // s is a value of T + // q is a value satisfying requirements of seed_sequence + // z is a value of type unsigned long long + // os is a some specialization of basic_ostream + // is is a some specialization of basic_istream + + E e, v; + const E x, y; + T s = 1; + std::seed_seq q{1, 2, 3}; + unsigned long long z = 1; // NOLINT(runtime/int) + std::wostringstream os; + std::wistringstream is; + + E{}; + E{x}; + E{s}; + E{q}; + + e.seed(); + + // MSVC emits error 2718 when using EXPECT_EQ(e, x) + // * actual parameter with __declspec(align('#')) won't be aligned + EXPECT_TRUE(e == x); + + e.seed(q); + { + E tmp(q); + EXPECT_TRUE(e == tmp); + } + + e(); + { + E tmp(q); + EXPECT_TRUE(e != tmp); + } + + e.discard(z); + + static_assert(std::is_same<decltype(x == y), bool>::value, + "return type of operator== must be bool"); + + static_assert(std::is_same<decltype(x != y), bool>::value, + "return type of operator== must be bool"); +} + +TYPED_TEST(PCGEngineTest, RandenEngineSFINAETest) { + using engine_type = TypeParam; + using result_type = typename engine_type::result_type; + + { + engine_type engine(result_type(1)); + engine.seed(result_type(1)); + } + + { + result_type n = 1; + engine_type engine(n); + engine.seed(n); + } + + { + engine_type engine(1); + engine.seed(1); + } + + { + int n = 1; + engine_type engine(n); + engine.seed(n); + } + + { + std::seed_seq seed_seq; + engine_type engine(seed_seq); + engine.seed(seed_seq); + } + + { + engine_type engine{std::seed_seq()}; + engine.seed(std::seed_seq()); + } +} + +// ------------------------------------------------------------------ +// Stability tests for pcg64_2018_engine +// ------------------------------------------------------------------ +TEST(PCG642018EngineTest, VerifyGolden) { + constexpr uint64_t kGolden[kNumGoldenOutputs] = { + 0x01070196e695f8f1, 0x703ec840c59f4493, 0xe54954914b3a44fa, + 0x96130ff204b9285e, 0x7d9fdef535ceb21a, 0x666feed42e1219a0, + 0x981f685721c8326f, 0xad80710d6eab4dda, 0xe202c480b037a029, + 0x5d3390eaedd907e2, 0x0756befb39c6b8aa, 0x1fb44ba6634d62a3, + 0x8d20423662426642, 0x34ea910167a39fb4, 0x93010b43a80d0ab6, + 0x663db08a98fc568a, 0x720b0a1335956fae, 0x2c35483e31e1d3ba, + 0x429f39776337409d, 0xb46d99e638687344, 0x105370b96aedcaee, + 0x3999e92f811cff71, 0xd230f8bcb591cfc9, 0x0dce3db2ba7bdea5, + 0xcf2f52c91eec99af, 0x2bc7c24a8b998a39, 0xbd8af1b0d599a19c, + 0x56bc45abc66059f5, 0x170a46dc170f7f1e, 0xc25daf5277b85fad, + 0xe629c2e0c948eadb, 0x1720a796915542ed, 0x22fb0caa4f909951, + 0x7e0c0f4175acd83d, 0xd9fcab37ff2a860c, 0xab2280fb2054bad1, + 0x58e8a06f37fa9e99, 0xc3a52a30b06528c7, 0x0175f773a13fc1bd, + 0x731cfc584b00e840, 0x404cc7b2648069cb, 0x5bc29153b0b7f783, + 0x771310a38cc999d1, 0x766a572f0a71a916, 0x90f450fb4fc48348, + 0xf080ea3e1c7b1a0d, 0x15471a4507d66a44, 0x7d58e55a78f3df69, + 0x0130a094576ac99c, 0x46669cb2d04b1d87, 0x17ab5bed20191840, + 0x95b177d260adff3e, 0x025fb624b6ee4c07, 0xb35de4330154a95f, + 0xe8510fff67e24c79, 0x132c3cbcd76ed2d3, 0x35e7cc145a093904, + 0x9f5b5b5f81583b79, 0x3ee749a533966233, 0x4af85886cdeda8cd, + 0x0ca5380ecb3ef3aa, 0x4f674eb7661d3192, 0x88a29aad00cd7733, + 0x70b627ca045ffac6, 0x5912b43ea887623d, 0x95dc9fc6f62cf221, + 0x926081a12a5c905b, 0x9c57d4cd7dfce651, 0x85ab2cbf23e3bb5d, + 0xc5cd669f63023152, 0x3067be0fad5d898e, 0x12b56f444cb53d05, + 0xbc2e5a640c3434fc, 0x9280bff0e4613fe1, 0x98819094c528743e, + 0x999d1c98d829df33, 0x9ff82a012dc89242, 0xf99183ed39c8be94, + 0xf0f59161cd421c55, 0x3c705730c2f6c48d, 0x66ad85c6e9278a61, + 0x2a3428e4a428d5d0, 0x79207d68fd04940d, 0xea7f2b402edc8430, + 0xa06b419ac857f63b, 0xcb1dd0e6fbc47e1c, 0x4f55229200ada6a4, + 0x9647b5e6359c927f, 0x30bf8f9197c7efe5, 0xa79519529cc384d0, + 0xbb22c4f339ad6497, 0xd7b9782f59d14175, 0x0dff12fff2ec0118, + 0xa331ad8305343a7c, 0x48dad7e3f17e0862, 0x324c6fb3fd3c9665, + 0xf0e4350e7933dfc4, 0x7ccda2f30b8b03b6, 0xa0afc6179005de40, + 0xee65da6d063b3a30, 0xb9506f42f2bfe87a, 0xc9a2e26b0ef5baa0, + 0x39fa9d4f495011d6, 0xbecc21a45d023948, 0x6bf484c6593f737f, + 0x8065e0070cadc3b7, 0x9ef617ed8d419799, 0xac692cf8c233dd15, + 0xd2ed87583c4ebb98, 0xad95ba1bebfedc62, 0x9b60b160a8264e43, + 0x0bc8c45f71fcf25b, 0x4a78035cdf1c9931, 0x4602dc106667e029, + 0xb335a3c250498ac8, 0x0256ebc4df20cab8, 0x0c61efd153f0c8d9, + 0xe5d0150a4f806f88, 0x99d6521d351e7d87, 0x8d4888c9f80f4325, + 0x106c5735c1ba868d, 0x73414881b880a878, 0x808a9a58a3064751, + 0x339a29f3746de3d5, 0x5410d7fa4f873896, 0xd84623c81d7b8a03, + 0x1f7c7e7a7f47f462, + }; + + pcg64_2018_engine engine(0); +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%016lx, ", engine()); + if (i % 3 == 2) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } + engine.seed(); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(PCG642018EngineTest, VerifyGoldenSeeded) { + constexpr uint64_t kGolden[kNumGoldenOutputs] = { + 0xb03988f1e39691ee, 0xbd2a1eb5ac31e97a, 0x8f00d6d433634d02, + 0x1823c28d483d5776, 0x000c3ee3e1aeb74a, 0xfa82ef27a4f3df9c, + 0xc6f382308654e454, 0x414afb1a238996c2, 0x4703a4bc252eb411, + 0x99d64f62c8f7f654, 0xbb07ebe11a34fa44, 0x79eb06a363c06131, + 0xf66ad3756f1c6b21, 0x130c01d5e869f457, 0x5ca2b9963aecbc81, + 0xfef7bebc1de27e6c, 0x1d174faa5ed2cdbf, 0xd75b7a773f2bb889, + 0xc35c872327a170a5, 0x46da6d88646a42fe, 0x4622985e0442dae2, + 0xbe3cbd67297f1f9b, 0xe7c37b4a4798bfd1, 0x173d5dfad15a25c3, + 0x0eb6849ba2961522, 0xb0ff7246e6700d73, 0x88cb9c42d3afa577, + 0xb609731dbd94d917, 0xd3941cda04b40081, 0x28d140f7409bea3a, + 0x3c96699a920a124a, 0xdb28be521958b2fd, 0x0a3f44db3d4c5124, + 0x7ac8e60ba13b70d2, 0x75f03a41ded5195a, 0xaed10ac7c4e4825d, + 0xb92a3b18aadb7adc, 0xda45e0081f2bca46, 0x74d39ab3753143fc, + 0xb686038018fac9ca, 0x4cc309fe99542dbb, 0xf3e1a4fcb311097c, + 0x58763d6fa698d69d, 0xd11c365dbecd8d60, 0x2c15d55725b1dee7, + 0x89805f254d85658c, 0x2374c44dfc62158b, 0x9a8350fa7995328d, + 0x198f838970cf91da, 0x96aff569562c0e53, 0xd76c8c52b7ec6e3f, + 0x23a01cd9ae4baa81, 0x3adb366b6d02a893, 0xb3313e2a4c5b333f, + 0x04c11230b96a5425, 0x1f7f7af04787d571, 0xaddb019365275ec7, + 0x5c960468ccb09f42, 0x8438db698c69a44a, 0x492be1e46111637e, + 0x9c6c01e18100c610, 0xbfe48e75b7d0aceb, 0xb5e0b89ec1ce6a00, + 0x9d280ecbc2fe8997, 0x290d9e991ba5fcab, 0xeec5bec7d9d2a4f0, + 0x726e81488f19150e, 0x1a6df7955a7e462c, 0x37a12d174ba46bb5, + 0x3cdcdffd96b1b5c5, 0x2c5d5ac10661a26e, 0xa742ed18f22e50c4, + 0x00e0ed88ff0d8a35, 0x3d3c1718cb1efc0b, 0x1d70c51ffbccbf11, + 0xfbbb895132a4092f, 0x619d27f2fb095f24, 0x69af68200985e5c4, + 0xbee4885f57373f8d, 0x10b7a6bfe0587e40, 0xa885e6cf2f7e5f0a, + 0x59f879464f767550, 0x24e805d69056990d, 0x860970b911095891, + 0xca3189954f84170d, 0x6652a5edd4590134, 0x5e1008cef76174bf, + 0xcbd417881f2bcfe5, 0xfd49fc9d706ecd17, 0xeebf540221ebd066, + 0x46af7679464504cb, 0xd4028486946956f1, 0xd4f41864b86c2103, + 0x7af090e751583372, 0x98cdaa09278cb642, 0xffd42b921215602f, + 0x1d05bec8466b1740, 0xf036fa78a0132044, 0x787880589d1ecc78, + 0x5644552cfef33230, 0x0a97e275fe06884b, 0x96d1b13333d470b5, + 0xc8b3cdad52d3b034, 0x091357b9db7376fd, 0xa5fe4232555edf8c, + 0x3371bc3b6ada76b5, 0x7deeb2300477c995, 0x6fc6d4244f2849c1, + 0x750e8cc797ca340a, 0x81728613cd79899f, 0x3467f4ee6f9aeb93, + 0x5ef0a905f58c640f, 0x432db85e5101c98a, 0x6488e96f46ac80c2, + 0x22fddb282625048c, 0x15b287a0bc2d4c5d, 0xa7e2343ef1f28bce, + 0xc87ee1aa89bed09e, 0x220610107812c5e9, 0xcbdab6fcd640f586, + 0x8d41047970928784, 0x1aa431509ec1ade0, 0xac3f0be53f518ddc, + 0x16f4428ad81d0cbb, 0x675b13c2736fc4bb, 0x6db073afdd87e32d, + 0x572f3ca2f1a078c6, + }; + + ExplicitSeedSeq seed_sequence{12, 34, 56}; + pcg64_2018_engine engine(seed_sequence); +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%016lx, ", engine()); + if (i % 3 == 2) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } + engine.seed(seed_sequence); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(PCG642018EngineTest, VerifyGoldenFromDeserializedEngine) { + constexpr uint64_t kGolden[kNumGoldenOutputs] = { + 0xdd425b47b4113dea, 0x1b07176479d444b0, 0x6b391027586f2e42, + 0xa166f2b15f4a2143, 0xffb6dbd7a179ee97, 0xb2c00035365bf0b1, + 0x8fbb518b45855521, 0xfc789a55ddf87c3b, 0x429531f0f17ff355, + 0xbe708560d603d283, 0x5bff415175c5cb6b, 0xe813491f4ad45394, + 0xa853f4506d55880d, 0x7e538453e568172e, 0xe101f1e098ddd0ec, + 0x6ee31266ee4c766d, 0xa8786d92d66b39d7, 0xfee622a2acf5e5b0, + 0x5fe8e82c102fa7b3, 0x01f10be4cdb53c9d, 0xbe0545366f857022, + 0x12e74f010a339bca, 0xb10d85ca40d5ce34, 0xe80d6feba5054875, + 0x2b7c1ee6d567d4ee, 0x2a9cd043bfd03b66, 0x5cfc531bd239f3f1, + 0x1c4734e4647d70f5, 0x85a8f60f006b5760, 0x6a4239ce76dca387, + 0x8da0f86d7339335c, 0xf055b0468551374d, 0x486e8567e9bea9a0, + 0x4cb531b8405192dd, 0xf813b1ee3157110b, 0x214c2a664a875d8e, + 0x74531237b29b35f7, 0xa6f0267bb77a771e, 0x64b552bff54184a4, + 0xa2d6f7af2d75b6fc, 0x460a10018e03b5ab, 0x76fd1fdcb81d0800, + 0x76f5f81805070d9d, 0x1fb75cb1a70b289a, 0x9dfd25a022c4b27f, + 0x9a31a14a80528e9e, 0x910dc565ddc25820, 0xd6aef8e2b0936c10, + 0xe1773c507fe70225, 0xe027fd7aadd632bc, 0xc1fecb427089c8b8, + 0xb5c74c69fa9dbf26, 0x71bf9b0e4670227d, 0x25f48fad205dcfdd, + 0x905248ec4d689c56, 0x5c2b7631b0de5c9d, 0x9f2ee0f8f485036c, + 0xfd6ce4ebb90bf7ea, 0xd435d20046085574, 0x6b7eadcb0625f986, + 0x679d7d44b48be89e, 0x49683b8e1cdc49de, 0x4366cf76e9a2f4ca, + 0x54026ec1cdad7bed, 0xa9a04385207f28d3, 0xc8e66de4eba074b2, + 0x40b08c42de0f4cc0, 0x1d4c5e0e93c5bbc0, 0x19b80792e470ae2d, + 0x6fcaaeaa4c2a5bd9, 0xa92cb07c4238438e, 0x8bb5c918a007e298, + 0x7cd671e944874cf4, 0x88166470b1ba3cac, 0xd013d476eaeeade6, + 0xcee416947189b3c3, 0x5d7c16ab0dce6088, 0xd3578a5c32b13d27, + 0x3875db5adc9cc973, 0xfbdaba01c5b5dc56, 0xffc4fdd391b231c3, + 0x2334520ecb164fec, 0x361c115e7b6de1fa, 0xeee58106cc3563d7, + 0x8b7f35a8db25ebb8, 0xb29d00211e2cafa6, 0x22a39fe4614b646b, + 0x92ca6de8b998506d, 0x40922fe3d388d1db, 0x9da47f1e540f802a, + 0x811dceebf16a25db, 0xf6524ae22e0e53a9, 0x52d9e780a16eb99d, + 0x4f504286bb830207, 0xf6654d4786bd5cc3, 0x00bd98316003a7e1, + 0xefda054a6ab8f5f3, 0x46cfb0f4c1872827, 0xc22b316965c0f3b2, + 0xd1a28087c7e7562a, 0xaa4f6a094b7f5cff, 0xfe2bc853a041f7da, + 0xe9d531402a83c3ba, 0xe545d8663d3ce4dd, 0xfa2dcd7d91a13fa8, + 0xda1a080e52a127b8, 0x19c98f1f809c3d84, 0x2cef109af4678c88, + 0x53462accab3b9132, 0x176b13a80415394e, 0xea70047ef6bc178b, + 0x57bca80506d6dcdf, 0xd853ba09ff09f5c4, 0x75f4df3a7ddd4775, + 0x209c367ade62f4fe, 0xa9a0bbc74d5f4682, 0x5dfe34bada86c21a, + 0xc2c05bbcd38566d1, 0x6de8088e348c916a, 0x6a7001c6000c2196, + 0xd9fb51865fc4a367, 0x12f320e444ece8ff, 0x6d56f7f793d65035, + 0x138f31b7a865f8aa, 0x58fc68b4026b9adf, 0xcd48954b79fb6436, + 0x27dfce4a0232af87, + }; + +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + std::seed_seq seed_sequence{1, 2, 3}; + pcg64_2018_engine engine(seed_sequence); + std::ostringstream stream; + stream << engine; + auto str = stream.str(); + printf("%s\n\n", str.c_str()); + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%016lx, ", engine()); + if (i % 3 == 2) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + pcg64_2018_engine engine; + std::istringstream stream( + "2549297995355413924 4865540595714422341 6364136223846793005 " + "1442695040888963407 18088519957565336995 4845369368158826708"); + stream >> engine; + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +// ------------------------------------------------------------------ +// Stability tests for pcg32_2018_engine +// ------------------------------------------------------------------ +TEST(PCG322018EngineTest, VerifyGolden) { + constexpr uint32_t kGolden[kNumGoldenOutputs] = { + 0x7a7ecbd9, 0x89fd6c06, 0xae646aa8, 0xcd3cf945, 0x6204b303, 0x198c8585, + 0x49fce611, 0xd1e9297a, 0x142d9440, 0xee75f56b, 0x473a9117, 0xe3a45903, + 0xbce807a1, 0xe54e5f4d, 0x497d6c51, 0x61829166, 0xa740474b, 0x031912a8, + 0x9de3defa, 0xd266dbf1, 0x0f38bebb, 0xec3c4f65, 0x07c5057d, 0xbbce03c8, + 0xfd2ac7a8, 0xffcf4773, 0x5b10affb, 0xede1c842, 0xe22b01b7, 0xda133c8c, + 0xaf89b0f4, 0x25d1b8bc, 0x9f625482, 0x7bfd6882, 0x2e2210c0, 0x2c8fb9a6, + 0x42cb3b83, 0x40ce0dab, 0x644a3510, 0x36230ef2, 0xe2cb6d43, 0x1012b343, + 0x746c6c9f, 0x36714cf8, 0xed1f5026, 0x8bbbf83e, 0xe98710f4, 0x8a2afa36, + 0x09035349, 0x6dc1a487, 0x682b634b, 0xc106794f, 0x7dd78beb, 0x628c262b, + 0x852fb232, 0xb153ac4c, 0x4f169d1b, 0xa69ab774, 0x4bd4b6f2, 0xdc351dd3, + 0x93ff3c8c, 0xa30819ab, 0xff07758c, 0x5ab13c62, 0xd16d7fb5, 0xc4950ffa, + 0xd309ae49, 0xb9677a87, 0x4464e317, 0x90dc44f1, 0xc694c1d4, 0x1d5e1168, + 0xadf37a2d, 0xda38990d, 0x1ec4bd33, 0x36ca25ce, 0xfa0dc76a, 0x968a9d43, + 0x6950ac39, 0xdd3276bc, 0x06d5a71e, 0x1f6f282d, 0x5c626c62, 0xdde3fc31, + 0x152194ce, 0xc35ed14c, 0xb1f7224e, 0x47f76bb8, 0xb34fdd08, 0x7011395e, + 0x162d2a49, 0x0d1bf09f, 0x9428a952, 0x03c5c344, 0xd3525616, 0x7816fff3, + 0x6bceb8a8, 0x8345a081, 0x366420fd, 0x182abeda, 0x70f82745, 0xaf15ded8, + 0xc7f52ca2, 0xa98db9c5, 0x919d99ba, 0x9c376c1c, 0xed8d34c2, 0x716ae9f5, + 0xef062fa5, 0xee3b6c56, 0x52325658, 0x61afa9c3, 0xfdaf02f0, 0x961cf3ab, + 0x9f291565, 0x4fbf3045, 0x0590c899, 0xde901385, 0x45005ffb, 0x509db162, + 0x262fa941, 0x4c421653, 0x4b17c21e, 0xea0d1530, 0xde803845, 0x61bfd515, + 0x438523ef, + }; + + pcg32_2018_engine engine(0); +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%08x, ", engine()); + if (i % 6 == 5) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } + engine.seed(); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(PCG322018EngineTest, VerifyGoldenSeeded) { + constexpr uint32_t kGolden[kNumGoldenOutputs] = { + 0x60b5a64c, 0x978502f9, 0x80a75f60, 0x241f1158, 0xa4cd1dbb, 0xe7284017, + 0x3b678da5, 0x5223ec99, 0xe4bdd5d9, 0x72190e6d, 0xe6e702c9, 0xff80c768, + 0xcf126ed3, 0x1fbd20ab, 0x60980489, 0xbc72bf89, 0x407ac6c0, 0x00bf3c51, + 0xf9087897, 0x172e4eb6, 0xe9e4f443, 0x1a6098bf, 0xbf44f8c2, 0xdd84a0e5, + 0xd9a52364, 0xc0e2e786, 0x061ae2ba, 0x9facb8e3, 0x6109432d, 0xd4e0a013, + 0xbd8eb9a6, 0x7e86c3b6, 0x629c0e68, 0x05337430, 0xb495b9f4, 0x11ccd65d, + 0xb578db25, 0x66f1246d, 0x6ef20a7f, 0x5e429812, 0x11772130, 0xb944b5c2, + 0x01624128, 0xa2385ab7, 0xd3e10d35, 0xbe570ec3, 0xc951656f, 0xbe8944a0, + 0x7be41062, 0x5709f919, 0xd745feda, 0x9870b9ae, 0xb44b8168, 0x19e7683b, + 0xded8017f, 0xc6e4d544, 0x91ae4225, 0xd6745fba, 0xb992f284, 0x65b12b33, + 0xa9d5fdb4, 0xf105ce1a, 0x35ca1a6e, 0x2ff70dd0, 0xd8335e49, 0xfb71ddf2, + 0xcaeabb89, 0x5c6f5f84, 0x9a811a7d, 0xbcecbbd1, 0x0f661ba0, 0x9ad93b9d, + 0xedd23e0b, 0x42062f48, 0xd38dd7e4, 0x6cd63c9c, 0x640b98ae, 0x4bff5653, + 0x12626371, 0x13266017, 0xe7a698d8, 0x39c74667, 0xe8fdf2e3, 0x52803bf8, + 0x2af6895b, 0x91335b7b, 0x699e4961, 0x00a40fff, 0x253ff2b6, 0x4a6cf672, + 0x9584e85f, 0xf2a5000c, 0x4d58aba8, 0xb8513e6a, 0x767fad65, 0x8e326f9e, + 0x182f15a1, 0x163dab52, 0xdf99c780, 0x047282a1, 0xee4f90dd, 0xd50394ae, + 0x6c9fd5f0, 0xb06a9194, 0x387e3840, 0x04a9487b, 0xf678a4c2, 0xd0a78810, + 0xd502c97e, 0xd6a9b12a, 0x4accc5dc, 0x416ed53e, 0x50411536, 0xeeb89c24, + 0x813a7902, 0x034ebca6, 0xffa52e7c, 0x7ecd3d0e, 0xfa37a0d2, 0xb1fbe2c1, + 0xb7efc6d1, 0xefa4ccee, 0xf6f80424, 0x2283f3d9, 0x68732284, 0x94f3b5c8, + 0xbbdeceb9, + }; + + ExplicitSeedSeq seed_sequence{12, 34, 56}; + pcg32_2018_engine engine(seed_sequence); +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%08x, ", engine()); + if (i % 6 == 5) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } + engine.seed(seed_sequence); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(PCG322018EngineTest, VerifyGoldenFromDeserializedEngine) { + constexpr uint64_t kGolden[kNumGoldenOutputs] = { + 0x780f7042, 0xba137215, 0x43ab6f22, 0x0cb55f46, 0x44b2627d, 0x835597af, + 0xea973ea1, 0x0d2abd35, 0x4fdd601c, 0xac4342fe, 0x7db7e93c, 0xe56ebcaf, + 0x3596470a, 0x7770a9ad, 0x9b893320, 0x57db3415, 0xb432de54, 0xa02baf71, + 0xa256aadb, 0x88921fc7, 0xa35fa6b3, 0xde3eca46, 0x605739a7, 0xa890b82b, + 0xe457b7ad, 0x335fb903, 0xeb06790c, 0xb3c54bf6, 0x6141e442, 0xa599a482, + 0xb78987cc, 0xc61dfe9d, 0x0f1d6ace, 0x17460594, 0x8f6a5061, 0x083dc354, + 0xe9c337fb, 0xcfd105f7, 0x926764b6, 0x638d24dc, 0xeaac650a, 0x67d2cb9c, + 0xd807733c, 0x205fc52e, 0xf5399e2e, 0x6c46ddcc, 0xb603e875, 0xce113a25, + 0x3c8d4813, 0xfb584db8, 0xf6d255ff, 0xea80954f, 0x42e8be85, 0xb2feee72, + 0x62bd8d16, 0x1be4a142, 0x97dca1a4, 0xdd6e7333, 0xb2caa20e, 0xa12b1588, + 0xeb3a5a1a, 0x6fa5ba89, 0x077ea931, 0x8ddb1713, 0x0dd03079, 0x2c2ba965, + 0xa77fac17, 0xc8325742, 0x8bb893bf, 0xc2315741, 0xeaceee92, 0x81dd2ee2, + 0xe5214216, 0x1b9b8fb2, 0x01646d03, 0x24facc25, 0xd8c0e0bb, 0xa33fe106, + 0xf34fe976, 0xb3b4b44e, 0x65618fed, 0x032c6192, 0xa9dd72ce, 0xf391887b, + 0xf41c6a6e, 0x05c4bd6d, 0x37fa260e, 0x46b05659, 0xb5f6348a, 0x62d26d89, + 0x39f6452d, 0xb17b30a2, 0xbdd82743, 0x38ecae3b, 0xfe90f0a2, 0xcb2d226d, + 0xcf8a0b1c, 0x0eed3d4d, 0xa1f69cfc, 0xd7ac3ba5, 0xce9d9a6b, 0x121deb4c, + 0x4a0d03f3, 0xc1821ed1, 0x59c249ac, 0xc0abb474, 0x28149985, 0xfd9a82ba, + 0x5960c3b2, 0xeff00cba, 0x6073aa17, 0x25dc0919, 0x9976626e, 0xdd2ccc33, + 0x39ecb6ec, 0xc6e15d13, 0xfac94cfd, 0x28cfd34f, 0xf2d2c32d, 0x51c23d08, + 0x4fdb2f48, 0x97baa807, 0xf2c1004c, 0xc4ae8136, 0x71f31c94, 0x8c92d601, + 0x36caf5cd, + }; + +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + std::seed_seq seed_sequence{1, 2, 3}; + pcg32_2018_engine engine(seed_sequence); + std::ostringstream stream; + stream << engine; + auto str = stream.str(); + printf("%s\n\n", str.c_str()); + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%08x, ", engine()); + if (i % 6 == 5) { + printf("\n"); + } + } + printf("\n\n\n"); + + EXPECT_FALSE(true); +#else + pcg32_2018_engine engine; + std::istringstream stream( + "6364136223846793005 1442695040888963407 6537028157270659894"); + stream >> engine; + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +} // namespace diff --git a/absl/random/internal/platform.h b/absl/random/internal/platform.h new file mode 100644 index 000000000000..5edab3448bbf --- /dev/null +++ b/absl/random/internal/platform.h @@ -0,0 +1,212 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_PLATFORM_H_ +#define ABSL_RANDOM_INTERNAL_PLATFORM_H_ + +// HERMETIC NOTE: The randen_hwaes target must not introduce duplicate +// symbols from arbitrary system and other headers, since it may be built +// with different flags from other targets, using different levels of +// optimization, potentially introducing ODR violations. + +// ----------------------------------------------------------------------------- +// Platform Feature Checks +// ----------------------------------------------------------------------------- + +// Currently supported operating systems and associated preprocessor +// symbols: +// +// Linux and Linux-derived __linux__ +// Android __ANDROID__ (implies __linux__) +// Linux (non-Android) __linux__ && !__ANDROID__ +// Darwin (Mac OS X and iOS) __APPLE__ +// Akaros (http://akaros.org) __ros__ +// Windows _WIN32 +// NaCL __native_client__ +// AsmJS __asmjs__ +// WebAssembly __wasm__ +// Fuchsia __Fuchsia__ +// +// Note that since Android defines both __ANDROID__ and __linux__, one +// may probe for either Linux or Android by simply testing for __linux__. +// +// NOTE: For __APPLE__ platforms, we use #include <TargetConditionals.h> +// to distinguish os variants. +// +// http://nadeausoftware.com/articles/2012/01/c_c_tip_how_use_compiler_predefined_macros_detect_operating_system + +#if defined(__APPLE__) +#include <TargetConditionals.h> +#endif + +// ----------------------------------------------------------------------------- +// Architecture Checks +// ----------------------------------------------------------------------------- + +// These preprocessor directives are trying to determine CPU architecture, +// including necessary headers to support hardware AES. +// +// ABSL_ARCH_{X86/PPC/ARM} macros determine the platform. +#if defined(__x86_64__) || defined(__x86_64) || defined(_M_AMD64) || \ + defined(_M_X64) +#define ABSL_ARCH_X86_64 +#elif defined(__i386) || defined(_M_IX86) +#define ABSL_ARCH_X86_32 +#elif defined(__aarch64__) || defined(__arm64__) || defined(_M_ARM64) +#define ABSL_ARCH_AARCH64 +#elif defined(__arm__) || defined(__ARMEL__) || defined(_M_ARM) +#define ABSL_ARCH_ARM +#elif defined(__powerpc64__) || defined(__PPC64__) || defined(__powerpc__) || \ + defined(__ppc__) || defined(__PPC__) +#define ABSL_ARCH_PPC +#else +// Unsupported architecture. +// * https://sourceforge.net/p/predef/wiki/Architectures/ +// * https://msdn.microsoft.com/en-us/library/b0084kay.aspx +// * for gcc, clang: "echo | gcc -E -dM -" +#endif + +// ----------------------------------------------------------------------------- +// Attribute Checks +// ----------------------------------------------------------------------------- + +// ABSL_HAVE_ATTRIBUTE +#undef ABSL_HAVE_ATTRIBUTE +#ifdef __has_attribute +#define ABSL_HAVE_ATTRIBUTE(x) __has_attribute(x) +#else +#define ABSL_HAVE_ATTRIBUTE(x) 0 +#endif + +// ABSL_ATTRIBUTE_ALWAYS_INLINE forces inlining of the method. +#undef ABSL_ATTRIBUTE_ALWAYS_INLINE +#if ABSL_HAVE_ATTRIBUTE(always_inline) || \ + (defined(__GNUC__) && !defined(__clang__)) +#define ABSL_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) +#elif defined(_MSC_VER) +// We can achieve something similar to attribute((always_inline)) with MSVC by +// using the __forceinline keyword, however this is not perfect. MSVC is +// much less aggressive about inlining, and even with the __forceinline keyword. +#define ABSL_ATTRIBUTE_ALWAYS_INLINE __forceinline +#else +#define ABSL_ATTRIBUTE_ALWAYS_INLINE +#endif + +// ABSL_ATTRIBUTE_NEVER_INLINE prevents inlining of the method. +#undef ABSL_ATTRIBUTE_NEVER_INLINE +#if ABSL_HAVE_ATTRIBUTE(noinline) || (defined(__GNUC__) && !defined(__clang__)) +#define ABSL_ATTRIBUTE_NEVER_INLINE __attribute__((noinline)) +#elif defined(_MSC_VER) +#define ABSL_ATTRIBUTE_NEVER_INLINE __declspec(noinline) +#else +#define ABSL_ATTRIBUTE_NEVER_INLINE +#endif + +// ABSL_ATTRIBUTE_FLATTEN enables much more aggressive inlining within +// the indicated function. +#undef ABSL_ATTRIBUTE_FLATTEN +#if ABSL_HAVE_ATTRIBUTE(flatten) || (defined(__GNUC__) && !defined(__clang__)) +#define ABSL_ATTRIBUTE_FLATTEN __attribute__((flatten)) +#else +#define ABSL_ATTRIBUTE_FLATTEN +#endif + +// ABSL_RANDOM_INTERNAL_RESTRICT annotates whether pointers may be considered +// to be unaliased. +#undef ABSL_RANDOM_INTERNAL_RESTRICT +#if defined(__clang__) || defined(__GNUC__) +#define ABSL_RANDOM_INTERNAL_RESTRICT __restrict__ +#elif defined(_MSC_VER) +#define ABSL_RANDOM_INTERNAL_RESTRICT __restrict +#else +#define ABSL_RANDOM_INTERNAL_RESTRICT +#endif + +// ABSL_HAVE_ACCELERATED_AES indicates whether the currently active compiler +// flags (e.g. -maes) allow using hardware accelerated AES instructions, which +// implies us assuming that the target platform supports them. +#define ABSL_HAVE_ACCELERATED_AES 0 + +#if defined(ABSL_ARCH_X86_64) + +#if defined(__AES__) || defined(__AVX__) +#undef ABSL_HAVE_ACCELERATED_AES +#define ABSL_HAVE_ACCELERATED_AES 1 +#endif + +#elif defined(ABSL_ARCH_PPC) + +// Rely on VSX and CRYPTO extensions for vcipher on PowerPC. +#if (defined(__VEC__) || defined(__ALTIVEC__)) && defined(__VSX__) && \ + defined(__CRYPTO__) +#undef ABSL_HAVE_ACCELERATED_AES +#define ABSL_HAVE_ACCELERATED_AES 1 +#endif + +#elif defined(ABSL_ARCH_ARM) || defined(ABSL_ARCH_AARCH64) + +// http://infocenter.arm.com/help/topic/com.arm.doc.ihi0053c/IHI0053C_acle_2_0.pdf +// Rely on NEON+CRYPTO extensions for ARM. +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_CRYPTO) +#undef ABSL_HAVE_ACCELERATED_AES +#define ABSL_HAVE_ACCELERATED_AES 1 +#endif + +#endif + +// NaCl does not allow AES. +#if defined(__native_client__) +#undef ABSL_HAVE_ACCELERATED_AES +#define ABSL_HAVE_ACCELERATED_AES 0 +#endif + +// ABSL_RANDOM_INTERNAL_AES_DISPATCH indicates whether the currently active +// platform has, or should use run-time dispatch for selecting the +// acclerated Randen implementation. +#define ABSL_RANDOM_INTERNAL_AES_DISPATCH 0 + +#if defined(ABSL_ARCH_X86_64) +// Dispatch is available on x86_64 +#undef ABSL_RANDOM_INTERNAL_AES_DISPATCH +#define ABSL_RANDOM_INTERNAL_AES_DISPATCH 1 +#elif defined(__linux__) && defined(ABSL_ARCH_PPC) +// Or when running linux PPC +#undef ABSL_RANDOM_INTERNAL_AES_DISPATCH +#define ABSL_RANDOM_INTERNAL_AES_DISPATCH 1 +#elif defined(__linux__) && defined(ABSL_ARCH_AARCH64) +// Or when running linux AArch64 +#undef ABSL_RANDOM_INTERNAL_AES_DISPATCH +#define ABSL_RANDOM_INTERNAL_AES_DISPATCH 1 +#elif defined(__linux__) && defined(ABSL_ARCH_ARM) && (__ARM_ARCH >= 8) +// Or when running linux ARM v8 or higher. +// (This captures a lot of Android configurations.) +#undef ABSL_RANDOM_INTERNAL_AES_DISPATCH +#define ABSL_RANDOM_INTERNAL_AES_DISPATCH 1 +#endif + +// NaCl does not allow dispatch. +#if defined(__native_client__) +#undef ABSL_RANDOM_INTERNAL_AES_DISPATCH +#define ABSL_RANDOM_INTERNAL_AES_DISPATCH 0 +#endif + +// iOS does not support dispatch, even on x86, since applications +// should be bundled as fat binaries, with a different build tailored for +// each specific supported platform/architecture. +#if defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_OS_IPHONE_SIMULATOR) +#undef ABSL_RANDOM_INTERNAL_AES_DISPATCH +#define ABSL_RANDOM_INTERNAL_AES_DISPATCH 0 +#endif + +#endif // ABSL_RANDOM_INTERNAL_PLATFORM_H_ diff --git a/absl/random/internal/pool_urbg.cc b/absl/random/internal/pool_urbg.cc new file mode 100644 index 000000000000..b24eeeffeaa9 --- /dev/null +++ b/absl/random/internal/pool_urbg.cc @@ -0,0 +1,252 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/pool_urbg.h" + +#include <algorithm> +#include <atomic> +#include <cstdint> +#include <cstring> +#include <iterator> + +#include "absl/base/attributes.h" +#include "absl/base/call_once.h" +#include "absl/base/config.h" +#include "absl/base/internal/endian.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/base/internal/spinlock.h" +#include "absl/base/internal/sysinfo.h" +#include "absl/base/internal/unaligned_access.h" +#include "absl/base/optimization.h" +#include "absl/random/internal/randen.h" +#include "absl/random/internal/seed_material.h" +#include "absl/random/seed_gen_exception.h" + +using absl::base_internal::SpinLock; +using absl::base_internal::SpinLockHolder; + +namespace absl { +namespace random_internal { +namespace { + +// RandenPoolEntry is a thread-safe pseudorandom bit generator, implementing a +// single generator within a RandenPool<T>. It is an internal implementation +// detail, and does not aim to conform to [rand.req.urng]. +// +// NOTE: There are alignment issues when used on ARM, for instance. +// See the allocation code in PoolAlignedAlloc(). +class RandenPoolEntry { + public: + static constexpr size_t kState = RandenTraits::kStateBytes / sizeof(uint32_t); + static constexpr size_t kCapacity = + RandenTraits::kCapacityBytes / sizeof(uint32_t); + + void Init(absl::Span<const uint32_t> data) { + SpinLockHolder l(&mu_); // Always uncontested. + std::copy(data.begin(), data.end(), std::begin(state_)); + next_ = kState; + } + + // Copy bytes into out. + void Fill(uint8_t* out, size_t bytes) LOCKS_EXCLUDED(mu_); + + // Returns random bits from the buffer in units of T. + template <typename T> + inline T Generate() LOCKS_EXCLUDED(mu_); + + inline void MaybeRefill() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (next_ >= kState) { + next_ = kCapacity; + impl_.Generate(state_); + } + } + + private: + // Randen URBG state. + uint32_t state_[kState] GUARDED_BY(mu_); // First to satisfy alignment. + SpinLock mu_; + const Randen impl_; + size_t next_ GUARDED_BY(mu_); +}; + +template <> +inline uint8_t RandenPoolEntry::Generate<uint8_t>() { + SpinLockHolder l(&mu_); + MaybeRefill(); + return static_cast<uint8_t>(state_[next_++]); +} + +template <> +inline uint16_t RandenPoolEntry::Generate<uint16_t>() { + SpinLockHolder l(&mu_); + MaybeRefill(); + return static_cast<uint16_t>(state_[next_++]); +} + +template <> +inline uint32_t RandenPoolEntry::Generate<uint32_t>() { + SpinLockHolder l(&mu_); + MaybeRefill(); + return state_[next_++]; +} + +template <> +inline uint64_t RandenPoolEntry::Generate<uint64_t>() { + SpinLockHolder l(&mu_); + if (next_ >= kState - 1) { + next_ = kCapacity; + impl_.Generate(state_); + } + auto p = state_ + next_; + next_ += 2; + + uint64_t result; + std::memcpy(&result, p, sizeof(result)); + return result; +} + +void RandenPoolEntry::Fill(uint8_t* out, size_t bytes) { + SpinLockHolder l(&mu_); + while (bytes > 0) { + MaybeRefill(); + size_t remaining = (kState - next_) * sizeof(state_[0]); + size_t to_copy = std::min(bytes, remaining); + std::memcpy(out, &state_[next_], to_copy); + out += to_copy; + bytes -= to_copy; + next_ += (to_copy + sizeof(state_[0]) - 1) / sizeof(state_[0]); + } +} + +// Number of pooled urbg entries. +static constexpr int kPoolSize = 8; + +// Shared pool entries. +static absl::once_flag pool_once; +ABSL_CACHELINE_ALIGNED static RandenPoolEntry* shared_pools[kPoolSize]; + +// Returns an id in the range [0 ... kPoolSize), which indexes into the +// pool of random engines. +// +// Each thread to access the pool is assigned a sequential ID (without reuse) +// from the pool-id space; the id is cached in a thread_local variable. +// This id is assigned based on the arrival-order of the thread to the +// GetPoolID call; this has no binary, CL, or runtime stability because +// on subsequent runs the order within the same program may be significantly +// different. However, as other thread IDs are not assigned sequentially, +// this is not expected to matter. +int GetPoolID() { + static_assert(kPoolSize >= 1, + "At least one urbg instance is required for PoolURBG"); + + ABSL_CONST_INIT static std::atomic<int64_t> sequence{0}; + +#ifdef ABSL_HAVE_THREAD_LOCAL + static thread_local int my_pool_id = -1; + if (ABSL_PREDICT_FALSE(my_pool_id < 0)) { + my_pool_id = (sequence++ % kPoolSize); + } + return my_pool_id; +#else + static pthread_key_t tid_key = [] { + pthread_key_t tmp_key; + int err = pthread_key_create(&tmp_key, nullptr); + if (err) { + ABSL_RAW_LOG(FATAL, "pthread_key_create failed with %d", err); + } + return tmp_key; + }(); + + // Store the value in the pthread_{get/set}specific. However an uninitialized + // value is 0, so add +1 to distinguish from the null value. + intptr_t my_pool_id = + reinterpret_cast<intptr_t>(pthread_getspecific(tid_key)); + if (ABSL_PREDICT_FALSE(my_pool_id == 0)) { + // No allocated ID, allocate the next value, cache it, and return. + my_pool_id = (sequence++ % kPoolSize) + 1; + int err = pthread_setspecific(tid_key, reinterpret_cast<void*>(my_pool_id)); + if (err) { + ABSL_RAW_LOG(FATAL, "pthread_setspecific failed with %d", err); + } + } + return my_pool_id - 1; +#endif +} + +// Allocate a RandenPoolEntry with at least 32-byte alignment, which is required +// by ARM platform code. +RandenPoolEntry* PoolAlignedAlloc() { + constexpr size_t kAlignment = + ABSL_CACHELINE_SIZE > 32 ? ABSL_CACHELINE_SIZE : 32; + + // Not all the platforms that we build for have std::aligned_alloc, however + // since we never free these objects, we can over allocate and munge the + // pointers to the correct alignment. + void* memory = std::malloc(sizeof(RandenPoolEntry) + kAlignment); + auto x = reinterpret_cast<intptr_t>(memory); + auto y = x % kAlignment; + void* aligned = + (y == 0) ? memory : reinterpret_cast<void*>(x + kAlignment - y); + return new (aligned) RandenPoolEntry(); +} + +// Allocate and initialize kPoolSize objects of type RandenPoolEntry. +// +// The initialization strategy is to initialize one object directly from +// OS entropy, then to use that object to seed all of the individual +// pool instances. +void InitPoolURBG() { + static constexpr size_t kSeedSize = + RandenTraits::kStateBytes / sizeof(uint32_t); + // Read the seed data from OS entropy once. + uint32_t seed_material[kPoolSize * kSeedSize]; + if (!random_internal::ReadSeedMaterialFromOSEntropy( + absl::MakeSpan(seed_material))) { + random_internal::ThrowSeedGenException(); + } + for (int i = 0; i < kPoolSize; i++) { + shared_pools[i] = PoolAlignedAlloc(); + shared_pools[i]->Init( + absl::MakeSpan(&seed_material[i * kSeedSize], kSeedSize)); + } +} + +// Returns the pool entry for the current thread. +RandenPoolEntry* GetPoolForCurrentThread() { + absl::call_once(pool_once, InitPoolURBG); + return shared_pools[GetPoolID()]; +} + +} // namespace + +template <typename T> +typename RandenPool<T>::result_type RandenPool<T>::Generate() { + auto* pool = GetPoolForCurrentThread(); + return pool->Generate<T>(); +} + +template <typename T> +void RandenPool<T>::Fill(absl::Span<result_type> data) { + auto* pool = GetPoolForCurrentThread(); + pool->Fill(reinterpret_cast<uint8_t*>(data.data()), + data.size() * sizeof(result_type)); +} + +template class RandenPool<uint8_t>; +template class RandenPool<uint16_t>; +template class RandenPool<uint32_t>; +template class RandenPool<uint64_t>; + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/internal/pool_urbg.h b/absl/random/internal/pool_urbg.h new file mode 100644 index 000000000000..9b2dd4bfceeb --- /dev/null +++ b/absl/random/internal/pool_urbg.h @@ -0,0 +1,129 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_POOL_URBG_H_ +#define ABSL_RANDOM_INTERNAL_POOL_URBG_H_ + +#include <cinttypes> +#include <limits> + +#include "absl/random/internal/traits.h" +#include "absl/types/span.h" + +namespace absl { +namespace random_internal { + +// RandenPool is a thread-safe random number generator [random.req.urbg] that +// uses an underlying pool of Randen generators to generate values. Each thread +// has affinity to one instance of the underlying pool generators. Concurrent +// access is guarded by a spin-lock. +template <typename T> +class RandenPool { + public: + using result_type = T; + static_assert(std::is_unsigned<result_type>::value, + "RandenPool template argument must be a built-in unsigned " + "integer type"); + + static constexpr result_type(min)() { + return (std::numeric_limits<result_type>::min)(); + } + + static constexpr result_type(max)() { + return (std::numeric_limits<result_type>::max)(); + } + + RandenPool() {} + + // Returns a single value. + inline result_type operator()() { return Generate(); } + + // Fill data with random values. + static void Fill(absl::Span<result_type> data); + + protected: + // Generate returns a single value. + static result_type Generate(); +}; + +extern template class RandenPool<uint8_t>; +extern template class RandenPool<uint16_t>; +extern template class RandenPool<uint32_t>; +extern template class RandenPool<uint64_t>; + +// PoolURBG uses an underlying pool of random generators to implement a +// thread-compatible [random.req.urbg] interface with an internal cache of +// values. +template <typename T, size_t kBufferSize> +class PoolURBG { + // Inheritance to access the protected static members of RandenPool. + using unsigned_type = typename make_unsigned_bits<T>::type; + using PoolType = RandenPool<unsigned_type>; + using SpanType = absl::Span<unsigned_type>; + + static constexpr size_t kInitialBuffer = kBufferSize + 1; + static constexpr size_t kHalfBuffer = kBufferSize / 2; + + public: + using result_type = T; + + static_assert(std::is_unsigned<result_type>::value, + "PoolURBG must be parameterized by an unsigned integer type"); + + static_assert(kBufferSize > 1, + "PoolURBG must be parameterized by a buffer-size > 1"); + + static_assert(kBufferSize <= 256, + "PoolURBG must be parameterized by a buffer-size <= 256"); + + static constexpr result_type(min)() { + return (std::numeric_limits<result_type>::min)(); + } + + static constexpr result_type(max)() { + return (std::numeric_limits<result_type>::max)(); + } + + PoolURBG() : next_(kInitialBuffer) {} + + // copy-constructor does not copy cache. + PoolURBG(const PoolURBG&) : next_(kInitialBuffer) {} + const PoolURBG& operator=(const PoolURBG&) { + next_ = kInitialBuffer; + return *this; + } + + // move-constructor does move cache. + PoolURBG(PoolURBG&&) = default; + PoolURBG& operator=(PoolURBG&&) = default; + + inline result_type operator()() { + if (next_ >= kBufferSize) { + next_ = (kBufferSize > 2 && next_ > kBufferSize) ? kHalfBuffer : 0; + PoolType::Fill(SpanType(reinterpret_cast<unsigned_type*>(state_ + next_), + kBufferSize - next_)); + } + return state_[next_++]; + } + + private: + // Buffer size. + size_t next_; // index within state_ + result_type state_[kBufferSize]; +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_POOL_URBG_H_ diff --git a/absl/random/internal/pool_urbg_test.cc b/absl/random/internal/pool_urbg_test.cc new file mode 100644 index 000000000000..53f4eacf16fe --- /dev/null +++ b/absl/random/internal/pool_urbg_test.cc @@ -0,0 +1,182 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/pool_urbg.h" + +#include <algorithm> +#include <bitset> +#include <cmath> +#include <cstdint> +#include <iterator> + +#include "gtest/gtest.h" +#include "absl/meta/type_traits.h" +#include "absl/types/span.h" + +using absl::random_internal::PoolURBG; +using absl::random_internal::RandenPool; + +namespace { + +// is_randen_pool trait is true when parameterized by an RandenPool +template <typename T> +using is_randen_pool = typename absl::disjunction< // + std::is_same<T, RandenPool<uint8_t>>, // + std::is_same<T, RandenPool<uint16_t>>, // + std::is_same<T, RandenPool<uint32_t>>, // + std::is_same<T, RandenPool<uint64_t>>>; // + +// MyFill either calls RandenPool::Fill() or std::generate(..., rng) +template <typename T, typename V> +typename absl::enable_if_t<absl::negation<is_randen_pool<T>>::value, void> // +MyFill(T& rng, absl::Span<V> data) { // NOLINT(runtime/references) + std::generate(std::begin(data), std::end(data), rng); +} + +template <typename T, typename V> +typename absl::enable_if_t<is_randen_pool<T>::value, void> // +MyFill(T& rng, absl::Span<V> data) { // NOLINT(runtime/references) + rng.Fill(data); +} + +template <typename EngineType> +class PoolURBGTypedTest : public ::testing::Test {}; + +using EngineTypes = ::testing::Types< // + RandenPool<uint8_t>, // + RandenPool<uint16_t>, // + RandenPool<uint32_t>, // + RandenPool<uint64_t>, // + PoolURBG<uint8_t, 2>, // + PoolURBG<uint16_t, 2>, // + PoolURBG<uint32_t, 2>, // + PoolURBG<uint64_t, 2>, // + PoolURBG<unsigned int, 8>, // NOLINT(runtime/int) + PoolURBG<unsigned long, 8>, // NOLINT(runtime/int) + PoolURBG<unsigned long int, 4>, // NOLINT(runtime/int) + PoolURBG<unsigned long long, 4>>; // NOLINT(runtime/int) + +TYPED_TEST_SUITE(PoolURBGTypedTest, EngineTypes); + +// This test is checks that the engines meet the URBG interface requirements +// defined in [rand.req.urbg]. +TYPED_TEST(PoolURBGTypedTest, URBGInterface) { + using E = TypeParam; + using T = typename E::result_type; + + static_assert(std::is_copy_constructible<E>::value, + "engine must be copy constructible"); + + static_assert(absl::is_copy_assignable<E>::value, + "engine must be copy assignable"); + + E e; + const E x; + + e(); + + static_assert(std::is_same<decltype(e()), T>::value, + "return type of operator() must be result_type"); + + E u0(x); + u0(); + + E u1 = e; + u1(); +} + +// This validates that sequences are independent. +TYPED_TEST(PoolURBGTypedTest, VerifySequences) { + using E = TypeParam; + using result_type = typename E::result_type; + + E rng; + (void)rng(); // Discard one value. + + constexpr int kNumOutputs = 64; + result_type a[kNumOutputs]; + result_type b[kNumOutputs]; + std::fill(std::begin(b), std::end(b), 0); + + // Fill a using Fill or generate, depending on the engine type. + { + E x = rng; + MyFill(x, absl::MakeSpan(a)); + } + + // Fill b using std::generate(). + { + E x = rng; + std::generate(std::begin(b), std::end(b), x); + } + + // Test that generated sequence changed as sequence of bits, i.e. if about + // half of the bites were flipped between two non-correlated values. + size_t changed_bits = 0; + size_t unchanged_bits = 0; + size_t total_set = 0; + size_t total_bits = 0; + size_t equal_count = 0; + for (size_t i = 0; i < kNumOutputs; ++i) { + equal_count += (a[i] == b[i]) ? 1 : 0; + std::bitset<sizeof(result_type) * 8> bitset(a[i] ^ b[i]); + changed_bits += bitset.count(); + unchanged_bits += bitset.size() - bitset.count(); + + std::bitset<sizeof(result_type) * 8> a_set(a[i]); + std::bitset<sizeof(result_type) * 8> b_set(b[i]); + total_set += a_set.count() + b_set.count(); + total_bits += 2 * 8 * sizeof(result_type); + } + // On average, half the bits are changed between two calls. + EXPECT_LE(changed_bits, 0.60 * (changed_bits + unchanged_bits)); + EXPECT_GE(changed_bits, 0.40 * (changed_bits + unchanged_bits)); + + // verify using a quick normal-approximation to the binomial. + EXPECT_NEAR(total_set, total_bits * 0.5, 4 * std::sqrt(total_bits)) + << "@" << total_set / static_cast<double>(total_bits); + + // Also, A[i] == B[i] with probability (1/range) * N. + // Give this a pretty wide latitude, though. + const double kExpected = kNumOutputs / (1.0 * sizeof(result_type) * 8); + EXPECT_LE(equal_count, 1.0 + kExpected); +} + +} // namespace + +/* +$ nanobenchmarks 1 RandenPool construct +$ nanobenchmarks 1 PoolURBG construct + +RandenPool<uint32_t> | 1 | 1000 | 48482.00 ticks | 48.48 ticks | 13.9 ns +RandenPool<uint32_t> | 10 | 2000 | 1028795.00 ticks | 51.44 ticks | 14.7 ns +RandenPool<uint32_t> | 100 | 1000 | 5119968.00 ticks | 51.20 ticks | 14.6 ns +RandenPool<uint32_t> | 1000 | 500 | 25867936.00 ticks | 51.74 ticks | 14.8 ns + +RandenPool<uint64_t> | 1 | 1000 | 49921.00 ticks | 49.92 ticks | 14.3 ns +RandenPool<uint64_t> | 10 | 2000 | 1208269.00 ticks | 60.41 ticks | 17.3 ns +RandenPool<uint64_t> | 100 | 1000 | 5844955.00 ticks | 58.45 ticks | 16.7 ns +RandenPool<uint64_t> | 1000 | 500 | 28767404.00 ticks | 57.53 ticks | 16.4 ns + +PoolURBG<uint32_t,8> | 1 | 1000 | 86431.00 ticks | 86.43 ticks | 24.7 ns +PoolURBG<uint32_t,8> | 10 | 1000 | 206191.00 ticks | 20.62 ticks | 5.9 ns +PoolURBG<uint32_t,8> | 100 | 1000 | 1516049.00 ticks | 15.16 ticks | 4.3 ns +PoolURBG<uint32_t,8> | 1000 | 500 | 7613936.00 ticks | 15.23 ticks | 4.4 ns + +PoolURBG<uint64_t,4> | 1 | 1000 | 96668.00 ticks | 96.67 ticks | 27.6 ns +PoolURBG<uint64_t,4> | 10 | 1000 | 282423.00 ticks | 28.24 ticks | 8.1 ns +PoolURBG<uint64_t,4> | 100 | 1000 | 2609587.00 ticks | 26.10 ticks | 7.5 ns +PoolURBG<uint64_t,4> | 1000 | 500 | 12408757.00 ticks | 24.82 ticks | 7.1 ns + +*/ diff --git a/absl/random/internal/randen-keys.inc b/absl/random/internal/randen-keys.inc new file mode 100644 index 000000000000..fa4b16683537 --- /dev/null +++ b/absl/random/internal/randen-keys.inc @@ -0,0 +1,207 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_RANDEN_KEYS_INC_ +#define ABSL_RANDOM_INTERNAL_RANDEN_KEYS_INC_ + +// Textual header to include the randen_keys where necessary. +// REQUIRES: struct u64x2{} +// +// PROVIDES: kKeys +// PROVIDES: round_keys[] + +// "Nothing up my sleeve" numbers from the first hex digits of Pi, obtained +// from http://hexpi.sourceforge.net/. The array was generated by following +// Python script: +/* +python << EOF +"""Generates Randen round keys array from pi-hex.62500.txt file.""" +import binascii + +KEYS = 136 + +def chunks(l, n): + """Yield successive n-sized chunks from l.""" + for i in range(0, len(l), n): + yield l[i:i + n] + +def pairwise(t): + """Transforms sequence into sequence of pairs.""" + it = iter(t) + return zip(it,it) + +def digits_from_pi(): + """Reads digits from hexpi.sourceforge.net file.""" + with open("pi-hex.62500.txt") as file: + return file.read() + +def digits_from_urandom(): + """Reads digits from /dev/urandom.""" + with open("/dev/urandom") as file: + return binascii.hexlify(file.read(KEYS * 16)) + +digits = digits_from_pi() +print("static constexpr const size_t kRoundKeys = {0};\n".format(KEYS)) +print("alignas(16) constexpr const u64x2 round_keys[kRoundKeys] = {") + +for i, (hi, lo) in zip(range(KEYS), pairwise(chunks(digits, 16))): + hi = "0x{0}ull".format(hi) + lo = "0x{0}ull".format(lo) + print(" u64x2({0}, {1}){2}".format(hi, lo, ',' if i+1 < KEYS else '')) + +print("};") +EOF +*/ + +static constexpr const size_t kRoundKeys = 136; + +alignas(16) constexpr u64x2 round_keys[kRoundKeys] = { + u64x2(0x243F6A8885A308D3ull, 0x13198A2E03707344ull), + u64x2(0xA4093822299F31D0ull, 0x082EFA98EC4E6C89ull), + u64x2(0x452821E638D01377ull, 0xBE5466CF34E90C6Cull), + u64x2(0xC0AC29B7C97C50DDull, 0x3F84D5B5B5470917ull), + u64x2(0x9216D5D98979FB1Bull, 0xD1310BA698DFB5ACull), + u64x2(0x2FFD72DBD01ADFB7ull, 0xB8E1AFED6A267E96ull), + u64x2(0xBA7C9045F12C7F99ull, 0x24A19947B3916CF7ull), + u64x2(0x0801F2E2858EFC16ull, 0x636920D871574E69ull), + u64x2(0xA458FEA3F4933D7Eull, 0x0D95748F728EB658ull), + u64x2(0x718BCD5882154AEEull, 0x7B54A41DC25A59B5ull), + u64x2(0x9C30D5392AF26013ull, 0xC5D1B023286085F0ull), + u64x2(0xCA417918B8DB38EFull, 0x8E79DCB0603A180Eull), + u64x2(0x6C9E0E8BB01E8A3Eull, 0xD71577C1BD314B27ull), + u64x2(0x78AF2FDA55605C60ull, 0xE65525F3AA55AB94ull), + u64x2(0x5748986263E81440ull, 0x55CA396A2AAB10B6ull), + u64x2(0xB4CC5C341141E8CEull, 0xA15486AF7C72E993ull), + u64x2(0xB3EE1411636FBC2Aull, 0x2BA9C55D741831F6ull), + u64x2(0xCE5C3E169B87931Eull, 0xAFD6BA336C24CF5Cull), + u64x2(0x7A32538128958677ull, 0x3B8F48986B4BB9AFull), + u64x2(0xC4BFE81B66282193ull, 0x61D809CCFB21A991ull), + u64x2(0x487CAC605DEC8032ull, 0xEF845D5DE98575B1ull), + u64x2(0xDC262302EB651B88ull, 0x23893E81D396ACC5ull), + u64x2(0x0F6D6FF383F44239ull, 0x2E0B4482A4842004ull), + u64x2(0x69C8F04A9E1F9B5Eull, 0x21C66842F6E96C9Aull), + u64x2(0x670C9C61ABD388F0ull, 0x6A51A0D2D8542F68ull), + u64x2(0x960FA728AB5133A3ull, 0x6EEF0B6C137A3BE4ull), + u64x2(0xBA3BF0507EFB2A98ull, 0xA1F1651D39AF0176ull), + u64x2(0x66CA593E82430E88ull, 0x8CEE8619456F9FB4ull), + u64x2(0x7D84A5C33B8B5EBEull, 0xE06F75D885C12073ull), + u64x2(0x401A449F56C16AA6ull, 0x4ED3AA62363F7706ull), + u64x2(0x1BFEDF72429B023Dull, 0x37D0D724D00A1248ull), + u64x2(0xDB0FEAD349F1C09Bull, 0x075372C980991B7Bull), + u64x2(0x25D479D8F6E8DEF7ull, 0xE3FE501AB6794C3Bull), + u64x2(0x976CE0BD04C006BAull, 0xC1A94FB6409F60C4ull), + u64x2(0x5E5C9EC2196A2463ull, 0x68FB6FAF3E6C53B5ull), + u64x2(0x1339B2EB3B52EC6Full, 0x6DFC511F9B30952Cull), + u64x2(0xCC814544AF5EBD09ull, 0xBEE3D004DE334AFDull), + u64x2(0x660F2807192E4BB3ull, 0xC0CBA85745C8740Full), + u64x2(0xD20B5F39B9D3FBDBull, 0x5579C0BD1A60320Aull), + u64x2(0xD6A100C6402C7279ull, 0x679F25FEFB1FA3CCull), + u64x2(0x8EA5E9F8DB3222F8ull, 0x3C7516DFFD616B15ull), + u64x2(0x2F501EC8AD0552ABull, 0x323DB5FAFD238760ull), + u64x2(0x53317B483E00DF82ull, 0x9E5C57BBCA6F8CA0ull), + u64x2(0x1A87562EDF1769DBull, 0xD542A8F6287EFFC3ull), + u64x2(0xAC6732C68C4F5573ull, 0x695B27B0BBCA58C8ull), + u64x2(0xE1FFA35DB8F011A0ull, 0x10FA3D98FD2183B8ull), + u64x2(0x4AFCB56C2DD1D35Bull, 0x9A53E479B6F84565ull), + u64x2(0xD28E49BC4BFB9790ull, 0xE1DDF2DAA4CB7E33ull), + u64x2(0x62FB1341CEE4C6E8ull, 0xEF20CADA36774C01ull), + u64x2(0xD07E9EFE2BF11FB4ull, 0x95DBDA4DAE909198ull), + u64x2(0xEAAD8E716B93D5A0ull, 0xD08ED1D0AFC725E0ull), + u64x2(0x8E3C5B2F8E7594B7ull, 0x8FF6E2FBF2122B64ull), + u64x2(0x8888B812900DF01Cull, 0x4FAD5EA0688FC31Cull), + u64x2(0xD1CFF191B3A8C1ADull, 0x2F2F2218BE0E1777ull), + u64x2(0xEA752DFE8B021FA1ull, 0xE5A0CC0FB56F74E8ull), + u64x2(0x18ACF3D6CE89E299ull, 0xB4A84FE0FD13E0B7ull), + u64x2(0x7CC43B81D2ADA8D9ull, 0x165FA26680957705ull), + u64x2(0x93CC7314211A1477ull, 0xE6AD206577B5FA86ull), + u64x2(0xC75442F5FB9D35CFull, 0xEBCDAF0C7B3E89A0ull), + u64x2(0xD6411BD3AE1E7E49ull, 0x00250E2D2071B35Eull), + u64x2(0x226800BB57B8E0AFull, 0x2464369BF009B91Eull), + u64x2(0x5563911D59DFA6AAull, 0x78C14389D95A537Full), + u64x2(0x207D5BA202E5B9C5ull, 0x832603766295CFA9ull), + u64x2(0x11C819684E734A41ull, 0xB3472DCA7B14A94Aull), + u64x2(0x1B5100529A532915ull, 0xD60F573FBC9BC6E4ull), + u64x2(0x2B60A47681E67400ull, 0x08BA6FB5571BE91Full), + u64x2(0xF296EC6B2A0DD915ull, 0xB6636521E7B9F9B6ull), + u64x2(0xFF34052EC5855664ull, 0x53B02D5DA99F8FA1ull), + u64x2(0x08BA47996E85076Aull, 0x4B7A70E9B5B32944ull), + u64x2(0xDB75092EC4192623ull, 0xAD6EA6B049A7DF7Dull), + u64x2(0x9CEE60B88FEDB266ull, 0xECAA8C71699A18FFull), + u64x2(0x5664526CC2B19EE1ull, 0x193602A575094C29ull), + u64x2(0xA0591340E4183A3Eull, 0x3F54989A5B429D65ull), + u64x2(0x6B8FE4D699F73FD6ull, 0xA1D29C07EFE830F5ull), + u64x2(0x4D2D38E6F0255DC1ull, 0x4CDD20868470EB26ull), + u64x2(0x6382E9C6021ECC5Eull, 0x09686B3F3EBAEFC9ull), + u64x2(0x3C9718146B6A70A1ull, 0x687F358452A0E286ull), + u64x2(0xB79C5305AA500737ull, 0x3E07841C7FDEAE5Cull), + u64x2(0x8E7D44EC5716F2B8ull, 0xB03ADA37F0500C0Dull), + u64x2(0xF01C1F040200B3FFull, 0xAE0CF51A3CB574B2ull), + u64x2(0x25837A58DC0921BDull, 0xD19113F97CA92FF6ull), + u64x2(0x9432477322F54701ull, 0x3AE5E58137C2DADCull), + u64x2(0xC8B576349AF3DDA7ull, 0xA94461460FD0030Eull), + u64x2(0xECC8C73EA4751E41ull, 0xE238CD993BEA0E2Full), + u64x2(0x3280BBA1183EB331ull, 0x4E548B384F6DB908ull), + u64x2(0x6F420D03F60A04BFull, 0x2CB8129024977C79ull), + u64x2(0x5679B072BCAF89AFull, 0xDE9A771FD9930810ull), + u64x2(0xB38BAE12DCCF3F2Eull, 0x5512721F2E6B7124ull), + u64x2(0x501ADDE69F84CD87ull, 0x7A5847187408DA17ull), + u64x2(0xBC9F9ABCE94B7D8Cull, 0xEC7AEC3ADB851DFAull), + u64x2(0x63094366C464C3D2ull, 0xEF1C18473215D808ull), + u64x2(0xDD433B3724C2BA16ull, 0x12A14D432A65C451ull), + u64x2(0x50940002133AE4DDull, 0x71DFF89E10314E55ull), + u64x2(0x81AC77D65F11199Bull, 0x043556F1D7A3C76Bull), + u64x2(0x3C11183B5924A509ull, 0xF28FE6ED97F1FBFAull), + u64x2(0x9EBABF2C1E153C6Eull, 0x86E34570EAE96FB1ull), + u64x2(0x860E5E0A5A3E2AB3ull, 0x771FE71C4E3D06FAull), + u64x2(0x2965DCB999E71D0Full, 0x803E89D65266C825ull), + u64x2(0x2E4CC9789C10B36Aull, 0xC6150EBA94E2EA78ull), + u64x2(0xA6FC3C531E0A2DF4ull, 0xF2F74EA7361D2B3Dull), + u64x2(0x1939260F19C27960ull, 0x5223A708F71312B6ull), + u64x2(0xEBADFE6EEAC31F66ull, 0xE3BC4595A67BC883ull), + u64x2(0xB17F37D1018CFF28ull, 0xC332DDEFBE6C5AA5ull), + u64x2(0x6558218568AB9702ull, 0xEECEA50FDB2F953Bull), + u64x2(0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull), + u64x2(0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull), + u64x2(0x0334FE1EAA0363CFull, 0xB5735C904C70A239ull), + u64x2(0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull), + u64x2(0x9CAB5CABB2F3846Eull, 0x648B1EAF19BDF0CAull), + u64x2(0xA02369B9655ABB50ull, 0x40685A323C2AB4B3ull), + u64x2(0x319EE9D5C021B8F7ull, 0x9B540B19875FA099ull), + u64x2(0x95F7997E623D7DA8ull, 0xF837889A97E32D77ull), + u64x2(0x11ED935F16681281ull, 0x0E358829C7E61FD6ull), + u64x2(0x96DEDFA17858BA99ull, 0x57F584A51B227263ull), + u64x2(0x9B83C3FF1AC24696ull, 0xCDB30AEB532E3054ull), + u64x2(0x8FD948E46DBC3128ull, 0x58EBF2EF34C6FFEAull), + u64x2(0xFE28ED61EE7C3C73ull, 0x5D4A14D9E864B7E3ull), + u64x2(0x42105D14203E13E0ull, 0x45EEE2B6A3AAABEAull), + u64x2(0xDB6C4F15FACB4FD0ull, 0xC742F442EF6ABBB5ull), + u64x2(0x654F3B1D41CD2105ull, 0xD81E799E86854DC7ull), + u64x2(0xE44B476A3D816250ull, 0xCF62A1F25B8D2646ull), + u64x2(0xFC8883A0C1C7B6A3ull, 0x7F1524C369CB7492ull), + u64x2(0x47848A0B5692B285ull, 0x095BBF00AD19489Dull), + u64x2(0x1462B17423820D00ull, 0x58428D2A0C55F5EAull), + u64x2(0x1DADF43E233F7061ull, 0x3372F0928D937E41ull), + u64x2(0xD65FECF16C223BDBull, 0x7CDE3759CBEE7460ull), + u64x2(0x4085F2A7CE77326Eull, 0xA607808419F8509Eull), + u64x2(0xE8EFD85561D99735ull, 0xA969A7AAC50C06C2ull), + u64x2(0x5A04ABFC800BCADCull, 0x9E447A2EC3453484ull), + u64x2(0xFDD567050E1E9EC9ull, 0xDB73DBD3105588CDull), + u64x2(0x675FDA79E3674340ull, 0xC5C43465713E38D8ull), + u64x2(0x3D28F89EF16DFF20ull, 0x153E21E78FB03D4Aull), + u64x2(0xE6E39F2BDB83ADF7ull, 0xE93D5A68948140F7ull), + u64x2(0xF64C261C94692934ull, 0x411520F77602D4F7ull), + u64x2(0xBCF46B2ED4A10068ull, 0xD40824713320F46Aull), + u64x2(0x43B7D4B7500061AFull, 0x1E39F62E97244546ull)}; + +#endif // ABSL_RANDOM_INTERNAL_RANDEN_KEYS_INC_ diff --git a/absl/random/internal/randen.cc b/absl/random/internal/randen.cc new file mode 100644 index 000000000000..bab8075a54f8 --- /dev/null +++ b/absl/random/internal/randen.cc @@ -0,0 +1,89 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/randen.h" + +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/randen_detect.h" + +// RANDen = RANDom generator or beetroots in Swiss German. +// 'Strong' (well-distributed, unpredictable, backtracking-resistant) random +// generator, faster in some benchmarks than std::mt19937_64 and pcg64_c32. +// +// High-level summary: +// 1) Reverie (see "A Robust and Sponge-Like PRNG with Improved Efficiency") is +// a sponge-like random generator that requires a cryptographic permutation. +// It improves upon "Provably Robust Sponge-Based PRNGs and KDFs" by +// achieving backtracking resistance with only one Permute() per buffer. +// +// 2) "Simpira v2: A Family of Efficient Permutations Using the AES Round +// Function" constructs up to 1024-bit permutations using an improved +// Generalized Feistel network with 2-round AES-128 functions. This Feistel +// block shuffle achieves diffusion faster and is less vulnerable to +// sliced-biclique attacks than the Type-2 cyclic shuffle. +// +// 3) "Improving the Generalized Feistel" and "New criterion for diffusion +// property" extends the same kind of improved Feistel block shuffle to 16 +// branches, which enables a 2048-bit permutation. +// +// We combine these three ideas and also change Simpira's subround keys from +// structured/low-entropy counters to digits of Pi. + +namespace absl { +namespace random_internal { +namespace { + +struct RandenState { + const void* keys; + bool has_crypto; +}; + +RandenState GetRandenState() { + static const RandenState state = []() { + RandenState tmp; +#if ABSL_RANDOM_INTERNAL_AES_DISPATCH + // HW AES Dispatch. + if (HasRandenHwAesImplementation() && CPUSupportsRandenHwAes()) { + tmp.has_crypto = true; + tmp.keys = RandenHwAes::GetKeys(); + } else { + tmp.has_crypto = false; + tmp.keys = RandenSlow::GetKeys(); + } +#elif ABSL_HAVE_ACCELERATED_AES + // HW AES is enabled. + tmp.has_crypto = true; + tmp.keys = RandenHwAes::GetKeys(); +#else + // HW AES is disabled. + tmp.has_crypto = false; + tmp.keys = RandenSlow::GetKeys(); +#endif + return tmp; + }(); + return state; +} + +} // namespace + +Randen::Randen() { + auto tmp = GetRandenState(); + keys_ = tmp.keys; +#if ABSL_RANDOM_INTERNAL_AES_DISPATCH + has_crypto_ = tmp.has_crypto; +#endif +} + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/internal/randen.h b/absl/random/internal/randen.h new file mode 100644 index 000000000000..a4ff25459688 --- /dev/null +++ b/absl/random/internal/randen.h @@ -0,0 +1,100 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_RANDEN_H_ +#define ABSL_RANDOM_INTERNAL_RANDEN_H_ + +#include <cstddef> + +#include "absl/random/internal/platform.h" +#include "absl/random/internal/randen_hwaes.h" +#include "absl/random/internal/randen_slow.h" +#include "absl/random/internal/randen_traits.h" + +namespace absl { +namespace random_internal { + +// RANDen = RANDom generator or beetroots in Swiss German. +// 'Strong' (well-distributed, unpredictable, backtracking-resistant) random +// generator, faster in some benchmarks than std::mt19937_64 and pcg64_c32. +// +// Randen implements the basic state manipulation methods. +class Randen { + public: + static constexpr size_t kStateBytes = RandenTraits::kStateBytes; + static constexpr size_t kCapacityBytes = RandenTraits::kCapacityBytes; + static constexpr size_t kSeedBytes = RandenTraits::kSeedBytes; + + ~Randen() = default; + + Randen(); + + // Generate updates the randen sponge. The outer portion of the sponge + // (kCapacityBytes .. kStateBytes) may be consumed as PRNG state. + template <typename T, size_t N> + void Generate(T (&state)[N]) const { + static_assert(N * sizeof(T) == kStateBytes, + "Randen::Generate() requires kStateBytes of state"); +#if ABSL_RANDOM_INTERNAL_AES_DISPATCH + // HW AES Dispatch. + if (has_crypto_) { + RandenHwAes::Generate(keys_, state); + } else { + RandenSlow::Generate(keys_, state); + } +#elif ABSL_HAVE_ACCELERATED_AES + // HW AES is enabled. + RandenHwAes::Generate(keys_, state); +#else + // HW AES is disabled. + RandenSlow::Generate(keys_, state); +#endif + } + + // Absorb incorporates additional seed material into the randen sponge. After + // absorb returns, Generate must be called before the state may be consumed. + template <typename S, size_t M, typename T, size_t N> + void Absorb(const S (&seed)[M], T (&state)[N]) const { + static_assert(M * sizeof(S) == RandenTraits::kSeedBytes, + "Randen::Absorb() requires kSeedBytes of seed"); + + static_assert(N * sizeof(T) == RandenTraits::kStateBytes, + "Randen::Absorb() requires kStateBytes of state"); +#if ABSL_RANDOM_INTERNAL_AES_DISPATCH + // HW AES Dispatch. + if (has_crypto_) { + RandenHwAes::Absorb(seed, state); + } else { + RandenSlow::Absorb(seed, state); + } +#elif ABSL_HAVE_ACCELERATED_AES + // HW AES is enabled. + RandenHwAes::Absorb(seed, state); +#else + // HW AES is disabled. + RandenSlow::Absorb(seed, state); +#endif + } + + private: + const void* keys_; +#if ABSL_RANDOM_INTERNAL_AES_DISPATCH + bool has_crypto_; +#endif +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_RANDEN_H_ diff --git a/absl/random/internal/randen_benchmarks.cc b/absl/random/internal/randen_benchmarks.cc new file mode 100644 index 000000000000..f589172c0466 --- /dev/null +++ b/absl/random/internal/randen_benchmarks.cc @@ -0,0 +1,174 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#include "absl/random/internal/randen.h" + +#include <cstdint> +#include <cstdio> +#include <cstring> + +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/nanobenchmark.h" +#include "absl/random/internal/platform.h" +#include "absl/random/internal/randen_engine.h" +#include "absl/random/internal/randen_hwaes.h" +#include "absl/random/internal/randen_slow.h" +#include "absl/strings/numbers.h" + +namespace { + +using absl::random_internal::Randen; +using absl::random_internal::RandenHwAes; +using absl::random_internal::RandenSlow; + +using absl::random_internal_nanobenchmark::FuncInput; +using absl::random_internal_nanobenchmark::FuncOutput; +using absl::random_internal_nanobenchmark::InvariantTicksPerSecond; +using absl::random_internal_nanobenchmark::MeasureClosure; +using absl::random_internal_nanobenchmark::Params; +using absl::random_internal_nanobenchmark::PinThreadToCPU; +using absl::random_internal_nanobenchmark::Result; + +// Local state parameters. +static constexpr size_t kStateSizeT = Randen::kStateBytes / sizeof(uint64_t); +static constexpr size_t kSeedSizeT = Randen::kSeedBytes / sizeof(uint32_t); + +// Randen implementation benchmarks. +template <typename T> +struct AbsorbFn : public T { + mutable uint64_t state[kStateSizeT] = {}; + mutable uint32_t seed[kSeedSizeT] = {}; + + static constexpr size_t bytes() { return sizeof(seed); } + + FuncOutput operator()(const FuncInput num_iters) const { + for (size_t i = 0; i < num_iters; ++i) { + this->Absorb(seed, state); + } + return state[0]; + } +}; + +template <typename T> +struct GenerateFn : public T { + mutable uint64_t state[kStateSizeT]; + GenerateFn() { std::memset(state, 0, sizeof(state)); } + + static constexpr size_t bytes() { return sizeof(state); } + + FuncOutput operator()(const FuncInput num_iters) const { + const auto* keys = this->GetKeys(); + for (size_t i = 0; i < num_iters; ++i) { + this->Generate(keys, state); + } + return state[0]; + } +}; + +template <typename UInt> +struct Engine { + mutable absl::random_internal::randen_engine<UInt> rng; + + static constexpr size_t bytes() { return sizeof(UInt); } + + FuncOutput operator()(const FuncInput num_iters) const { + for (size_t i = 0; i < num_iters - 1; ++i) { + rng(); + } + return rng(); + } +}; + +template <size_t N> +void Print(const char* name, const size_t n, const Result (&results)[N], + const size_t bytes) { + if (n == 0) { + ABSL_RAW_LOG( + WARNING, + "WARNING: Measurement failed, should not happen when using " + "PinThreadToCPU unless the region to measure takes > 1 second.\n"); + return; + } + + static const double ns_per_tick = 1e9 / InvariantTicksPerSecond(); + static constexpr const double kNsPerS = 1e9; // ns/s + static constexpr const double kMBPerByte = 1.0 / 1048576.0; // Mb / b + static auto header = [] { + return printf("%20s %8s: %12s ticks; %9s (%9s) %8s\n", "Name", "Count", + "Total", "Variance", "Time", "bytes/s"); + }(); + (void)header; + + for (size_t i = 0; i < n; ++i) { + const double ticks_per_call = results[i].ticks / results[i].input; + const double ns_per_call = ns_per_tick * ticks_per_call; + const double bytes_per_ns = bytes / ns_per_call; + const double mb_per_s = bytes_per_ns * kNsPerS * kMBPerByte; + // Output + printf("%20s %8zu: %12.2f ticks; MAD=%4.2f%% (%6.1f ns) %8.1f Mb/s\n", + name, results[i].input, results[i].ticks, + results[i].variability * 100.0, ns_per_call, mb_per_s); + } +} + +// Fails here +template <typename Op, size_t N> +void Measure(const char* name, const FuncInput (&inputs)[N]) { + Op op; + + Result results[N]; + Params params; + params.verbose = false; + params.max_evals = 6; // avoid test timeout + const size_t num_results = MeasureClosure(op, inputs, N, results, params); + Print(name, num_results, results, op.bytes()); +} + +// unpredictable == 1 but the compiler does not know that. +void RunAll(const int argc, char* argv[]) { + if (argc == 2) { + int cpu = -1; + if (!absl::SimpleAtoi(argv[1], &cpu)) { + ABSL_RAW_LOG(FATAL, "The optional argument must be a CPU number >= 0.\n"); + } + PinThreadToCPU(cpu); + } + + // The compiler cannot reduce this to a constant. + const FuncInput unpredictable = (argc != 999); + static const FuncInput inputs[] = {unpredictable * 100, unpredictable * 1000}; + +#if !defined(ABSL_INTERNAL_DISABLE_AES) && ABSL_HAVE_ACCELERATED_AES + Measure<AbsorbFn<RandenHwAes>>("Absorb (HwAes)", inputs); +#endif + Measure<AbsorbFn<RandenSlow>>("Absorb (Slow)", inputs); + +#if !defined(ABSL_INTERNAL_DISABLE_AES) && ABSL_HAVE_ACCELERATED_AES + Measure<GenerateFn<RandenHwAes>>("Generate (HwAes)", inputs); +#endif + Measure<GenerateFn<RandenSlow>>("Generate (Slow)", inputs); + + // Measure the production engine. + static const FuncInput inputs1[] = {unpredictable * 1000, + unpredictable * 10000}; + Measure<Engine<uint64_t>>("randen_engine<uint64_t>", inputs1); + Measure<Engine<uint32_t>>("randen_engine<uint32_t>", inputs1); +} + +} // namespace + +int main(int argc, char* argv[]) { + RunAll(argc, argv); + return 0; +} diff --git a/absl/random/internal/randen_detect.cc b/absl/random/internal/randen_detect.cc new file mode 100644 index 000000000000..d5946b219cc2 --- /dev/null +++ b/absl/random/internal/randen_detect.cc @@ -0,0 +1,219 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an"AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// HERMETIC NOTE: The randen_hwaes target must not introduce duplicate +// symbols from arbitrary system and other headers, since it may be built +// with different flags from other targets, using different levels of +// optimization, potentially introducing ODR violations. + +#include "absl/random/internal/randen_detect.h" + +#include <cstdint> +#include <cstring> + +#include "absl/random/internal/platform.h" + +#if defined(ABSL_ARCH_X86_64) +#define ABSL_INTERNAL_USE_X86_CPUID +#elif defined(ABSL_ARCH_PPC) || defined(ABSL_ARCH_ARM) || \ + defined(ABSL_ARCH_AARCH64) +#if defined(__ANDROID__) +#define ABSL_INTERNAL_USE_ANDROID_GETAUXVAL +#define ABSL_INTERNAL_USE_GETAUXVAL +#elif defined(__linux__) +#define ABSL_INTERNAL_USE_LINUX_GETAUXVAL +#define ABSL_INTERNAL_USE_GETAUXVAL +#endif +#endif + +#if defined(ABSL_INTERNAL_USE_X86_CPUID) +#if defined(_WIN32) || defined(_WIN64) +#include <intrin.h> // NOLINT(build/include_order) +#pragma intrinsic(__cpuid) +#else +// MSVC-equivalent __cpuid intrinsic function. +static void __cpuid(int cpu_info[4], int info_type) { + __asm__ volatile("cpuid \n\t" + : "=a"(cpu_info[0]), "=b"(cpu_info[1]), "=c"(cpu_info[2]), + "=d"(cpu_info[3]) + : "a"(info_type), "c"(0)); +} +#endif +#endif // ABSL_INTERNAL_USE_X86_CPUID + +// On linux, just use the c-library getauxval call. +#if defined(ABSL_INTERNAL_USE_LINUX_GETAUXVAL) + +extern "C" unsigned long getauxval(unsigned long type); // NOLINT(runtime/int) + +static uint32_t GetAuxval(uint32_t hwcap_type) { + return static_cast<uint32_t>(getauxval(hwcap_type)); +} + +#endif + +// On android, probe the system's C library for getauxval(). +// This is the same technique used by the android NDK cpu features library +// as well as the google open-source cpu_features library. +// +// TODO(absl-team): Consider implementing a fallback of directly reading +// /proc/self/auxval. +#if defined(ABSL_INTERNAL_USE_ANDROID_GETAUXVAL) +#include <dlfcn.h> + +static uint32_t GetAuxval(uint32_t hwcap_type) { + // NOLINTNEXTLINE(runtime/int) + typedef unsigned long (*getauxval_func_t)(unsigned long); + + dlerror(); // Cleaning error state before calling dlopen. + void* libc_handle = dlopen("libc.so", RTLD_NOW); + if (!libc_handle) { + return 0; + } + uint32_t result = 0; + void* sym = dlsym(libc_handle, "getauxval"); + if (sym) { + getauxval_func_t func; + memcpy(&func, &sym, sizeof(func)); + result = static_cast<uint32_t>((*func)(hwcap_type)); + } + dlclose(libc_handle); + return result; +} + +#endif + +namespace absl { +namespace random_internal { + +// The default return at the end of the function might be unreachable depending +// on the configuration. Ignore that warning. +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunreachable-code-return" +#endif + +// CPUSupportsRandenHwAes returns whether the CPU is a microarchitecture +// which supports the crpyto/aes instructions or extensions necessary to use the +// accelerated RandenHwAes implementation. +// +// 1. For x86 it is sufficient to use the CPUID instruction to detect whether +// the cpu supports AES instructions. Done. +// +// Fon non-x86 it is much more complicated. +// +// 2. When ABSL_INTERNAL_USE_GETAUXVAL is defined, use getauxval() (either +// the direct c-library version, or the android probing version which loads +// libc), and read the hardware capability bits. +// This is based on the technique used by boringssl uses to detect +// cpu capabilities, and should allow us to enable crypto in the android +// builds where it is supported. +// +// 3. Use the default for the compiler architecture. +// + +bool CPUSupportsRandenHwAes() { +#if defined(ABSL_INTERNAL_USE_X86_CPUID) + // 1. For x86: Use CPUID to detect the required AES instruction set. + int regs[4]; + __cpuid(reinterpret_cast<int*>(regs), 1); + return regs[2] & (1 << 25); // AES + +#elif defined(ABSL_INTERNAL_USE_GETAUXVAL) + // 2. Use getauxval() to read the hardware bits and determine + // cpu capabilities. + +#define AT_HWCAP 16 +#define AT_HWCAP2 26 +#if defined(ABSL_ARCH_PPC) + // For Power / PPC: Expect that the cpu supports VCRYPTO + // See https://members.openpowerfoundation.org/document/dl/576 + // VCRYPTO should be present in POWER8 >= 2.07. + // Uses Linux kernel constants from arch/powerpc/include/uapi/asm/cputable.h + static const uint32_t kVCRYPTO = 0x02000000; + const uint32_t hwcap = GetAuxval(AT_HWCAP2); + return (hwcap & kVCRYPTO) != 0; + +#elif defined(ABSL_ARCH_ARM) + // For ARM: Require crypto+neon + // http://infocenter.arm.com/help/index.jsp?topic=/com.arm.doc.ddi0500f/CIHBIBBA.html + // Uses Linux kernel constants from arch/arm64/include/asm/hwcap.h + static const uint32_t kNEON = 1 << 12; + uint32_t hwcap = GetAuxval(AT_HWCAP); + if ((hwcap & kNEON) == 0) { + return false; + } + + // And use it again to detect AES. + static const uint32_t kAES = 1 << 0; + const uint32_t hwcap2 = GetAuxval(AT_HWCAP2); + return (hwcap2 & kAES) != 0; + +#elif defined(ABSL_ARCH_AARCH64) + // For AARCH64: Require crypto+neon + // http://infocenter.arm.com/help/index.jsp?topic=/com.arm.doc.ddi0500f/CIHBIBBA.html + static const uint32_t kNEON = 1 << 1; + static const uint32_t kAES = 1 << 3; + const uint32_t hwcap = GetAuxval(AT_HWCAP); + return ((hwcap & kNEON) != 0) && ((hwcap & kAES) != 0); +#endif + +#else // ABSL_INTERNAL_USE_GETAUXVAL + // 3. By default, assume that the compiler default. + return ABSL_HAVE_ACCELERATED_AES ? true : false; + +#endif + // NOTE: There are some other techniques that may be worth trying: + // + // * Use an environment variable: ABSL_RANDOM_USE_HWAES + // + // * Rely on compiler-generated target-based dispatch. + // Using x86/gcc it might look something like this: + // + // int __attribute__((target("aes"))) HasAes() { return 1; } + // int __attribute__((target("default"))) HasAes() { return 0; } + // + // This does not work on all architecture/compiler combinations. + // + // * On Linux consider reading /proc/cpuinfo and/or /proc/self/auxv. + // These files have lines which are easy to parse; for ARM/AARCH64 it is quite + // easy to find the Features: line and extract aes / neon. Likewise for + // PPC. + // + // * Fork a process and test for SIGILL: + // + // * Many architectures have instructions to read the ISA. Unfortunately + // most of those require that the code is running in ring 0 / + // protected-mode. + // + // There are several examples. e.g. Valgrind detects PPC ISA 2.07: + // https://github.com/lu-zero/valgrind/blob/master/none/tests/ppc64/test_isa_2_07_part1.c + // + // MRS <Xt>, ID_AA64ISAR0_EL1 ; Read ID_AA64ISAR0_EL1 into Xt + // + // uint64_t val; + // __asm __volatile("mrs %0, id_aa64isar0_el1" :"=&r" (val)); + // + // * Use a CPUID-style heuristic database. + // + // * On Apple (__APPLE__), AES is available on Arm v8. + // https://stackoverflow.com/questions/45637888/how-to-determine-armv8-features-at-runtime-on-ios +} + +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/internal/randen_detect.h b/absl/random/internal/randen_detect.h new file mode 100644 index 000000000000..ab45f348b5a5 --- /dev/null +++ b/absl/random/internal/randen_detect.h @@ -0,0 +1,29 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_RANDEN_DETECT_H_ +#define ABSL_RANDOM_INTERNAL_RANDEN_DETECT_H_ + +namespace absl { +namespace random_internal { + +// Returns whether the current CPU supports RandenHwAes implementation. +// This typically involves supporting cryptographic extensions on whichever +// platform is currently running. +bool CPUSupportsRandenHwAes(); + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_RANDEN_FAST_H_ diff --git a/absl/random/internal/randen_engine.h b/absl/random/internal/randen_engine.h new file mode 100644 index 000000000000..02212a137c36 --- /dev/null +++ b/absl/random/internal/randen_engine.h @@ -0,0 +1,228 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_RANDEN_ENGINE_H_ +#define ABSL_RANDOM_INTERNAL_RANDEN_ENGINE_H_ + +#include <algorithm> +#include <cinttypes> +#include <cstdlib> +#include <iostream> +#include <iterator> +#include <limits> +#include <type_traits> + +#include "absl/meta/type_traits.h" +#include "absl/random/internal/iostream_state_saver.h" +#include "absl/random/internal/randen.h" + +namespace absl { +namespace random_internal { + +// Deterministic pseudorandom byte generator with backtracking resistance +// (leaking the state does not compromise prior outputs). Based on Reverie +// (see "A Robust and Sponge-Like PRNG with Improved Efficiency") instantiated +// with an improved Simpira-like permutation. +// Returns values of type "T" (must be a built-in unsigned integer type). +// +// RANDen = RANDom generator or beetroots in Swiss High German. +// 'Strong' (well-distributed, unpredictable, backtracking-resistant) random +// generator, faster in some benchmarks than std::mt19937_64 and pcg64_c32. +template <typename T> +class alignas(16) randen_engine { + public: + // C++11 URBG interface: + using result_type = T; + static_assert(std::is_unsigned<result_type>::value, + "randen_engine template argument must be a built-in unsigned " + "integer type"); + + static constexpr result_type(min)() { + return (std::numeric_limits<result_type>::min)(); + } + + static constexpr result_type(max)() { + return (std::numeric_limits<result_type>::max)(); + } + + explicit randen_engine(result_type seed_value = 0) { seed(seed_value); } + + template <class SeedSequence, + typename = typename absl::enable_if_t< + !std::is_same<SeedSequence, randen_engine>::value>> + explicit randen_engine(SeedSequence&& seq) { + seed(seq); + } + + randen_engine(const randen_engine&) = default; + + // Returns random bits from the buffer in units of result_type. + result_type operator()() { + // Refill the buffer if needed (unlikely). + if (next_ >= kStateSizeT) { + next_ = kCapacityT; + impl_.Generate(state_); + } + + return state_[next_++]; + } + + template <class SeedSequence> + typename absl::enable_if_t< + !std::is_convertible<SeedSequence, result_type>::value> + seed(SeedSequence&& seq) { + // Zeroes the state. + seed(); + reseed(seq); + } + + void seed(result_type seed_value = 0) { + next_ = kStateSizeT; + // Zeroes the inner state and fills the outer state with seed_value to + // mimics behaviour of reseed + std::fill(std::begin(state_), std::begin(state_) + kCapacityT, 0); + std::fill(std::begin(state_) + kCapacityT, std::end(state_), seed_value); + } + + // Inserts entropy into (part of) the state. Calling this periodically with + // sufficient entropy ensures prediction resistance (attackers cannot predict + // future outputs even if state is compromised). + template <class SeedSequence> + void reseed(SeedSequence& seq) { + using sequence_result_type = typename SeedSequence::result_type; + static_assert(sizeof(sequence_result_type) == 4, + "SeedSequence::result_type must be 32-bit"); + + constexpr size_t kBufferSize = + Randen::kSeedBytes / sizeof(sequence_result_type); + alignas(16) sequence_result_type buffer[kBufferSize]; + + // Randen::Absorb XORs the seed into state, which is then mixed by a call + // to Randen::Generate. Seeding with only the provided entropy is preferred + // to using an arbitrary generate() call, so use [rand.req.seed_seq] + // size as a proxy for the number of entropy units that can be generated + // without relying on seed sequence mixing... + const size_t entropy_size = seq.size(); + if (entropy_size < kBufferSize) { + // ... and only request that many values, or 256-bits, when unspecified. + const size_t requested_entropy = (entropy_size == 0) ? 8u : entropy_size; + std::fill(std::begin(buffer) + requested_entropy, std::end(buffer), 0); + seq.generate(std::begin(buffer), std::begin(buffer) + requested_entropy); + // The Randen paper suggests preferentially initializing even-numbered + // 128-bit vectors of the randen state (there are 16 such vectors). + // The seed data is merged into the state offset by 128-bits, which + // implies prefering seed bytes [16..31, ..., 208..223]. Since the + // buffer is 32-bit values, we swap the corresponding buffer positions in + // 128-bit chunks. + size_t dst = kBufferSize; + while (dst > 7) { + // leave the odd bucket as-is. + dst -= 4; + size_t src = dst >> 1; + // swap 128-bits into the even bucket + std::swap(buffer[--dst], buffer[--src]); + std::swap(buffer[--dst], buffer[--src]); + std::swap(buffer[--dst], buffer[--src]); + std::swap(buffer[--dst], buffer[--src]); + } + } else { + seq.generate(std::begin(buffer), std::end(buffer)); + } + impl_.Absorb(buffer, state_); + + // Generate will be called when operator() is called + next_ = kStateSizeT; + } + + void discard(uint64_t count) { + uint64_t step = std::min<uint64_t>(kStateSizeT - next_, count); + count -= step; + + constexpr uint64_t kRateT = kStateSizeT - kCapacityT; + while (count > 0) { + next_ = kCapacityT; + impl_.Generate(state_); + step = std::min<uint64_t>(kRateT, count); + count -= step; + } + next_ += step; + } + + bool operator==(const randen_engine& other) const { + return next_ == other.next_ && + std::equal(std::begin(state_), std::end(state_), + std::begin(other.state_)); + } + + bool operator!=(const randen_engine& other) const { + return !(*this == other); + } + + template <class CharT, class Traits> + friend std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const randen_engine<T>& engine) { // NOLINT(runtime/references) + using numeric_type = + typename random_internal::stream_format_type<result_type>::type; + auto saver = random_internal::make_ostream_state_saver(os); + for (const auto& elem : engine.state_) { + // In the case that `elem` is `uint8_t`, it must be cast to something + // larger so that it prints as an integer rather than a character. For + // simplicity, apply the cast all circumstances. + os << static_cast<numeric_type>(elem) << os.fill(); + } + os << engine.next_; + return os; + } + + template <class CharT, class Traits> + friend std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + randen_engine<T>& engine) { // NOLINT(runtime/references) + using numeric_type = + typename random_internal::stream_format_type<result_type>::type; + result_type state[kStateSizeT]; + size_t next; + for (auto& elem : state) { + // It is not possible to read uint8_t from wide streams, so it is + // necessary to read a wider type and then cast it to uint8_t. + numeric_type value; + is >> value; + elem = static_cast<result_type>(value); + } + is >> next; + if (is.fail()) { + return is; + } + std::memcpy(engine.state_, state, sizeof(engine.state_)); + engine.next_ = next; + return is; + } + + private: + static constexpr size_t kStateSizeT = + Randen::kStateBytes / sizeof(result_type); + static constexpr size_t kCapacityT = + Randen::kCapacityBytes / sizeof(result_type); + + // First kCapacityT are `inner', the others are accessible random bits. + alignas(16) result_type state_[kStateSizeT]; + size_t next_; // index within state_ + Randen impl_; +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_RANDEN_ENGINE_H_ diff --git a/absl/random/internal/randen_engine_test.cc b/absl/random/internal/randen_engine_test.cc new file mode 100644 index 000000000000..c8e7685bddad --- /dev/null +++ b/absl/random/internal/randen_engine_test.cc @@ -0,0 +1,656 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/randen_engine.h" + +#include <algorithm> +#include <bitset> +#include <random> +#include <sstream> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/explicit_seed_seq.h" +#include "absl/strings/str_cat.h" +#include "absl/time/clock.h" + +#define UPDATE_GOLDEN 0 + +using randen_u64 = absl::random_internal::randen_engine<uint64_t>; +using randen_u32 = absl::random_internal::randen_engine<uint32_t>; +using absl::random_internal::ExplicitSeedSeq; + +namespace { + +template <typename UIntType> +class RandenEngineTypedTest : public ::testing::Test {}; + +using UIntTypes = ::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t>; + +TYPED_TEST_SUITE(RandenEngineTypedTest, UIntTypes); + +TYPED_TEST(RandenEngineTypedTest, VerifyReseedChangesAllValues) { + using randen = typename absl::random_internal::randen_engine<TypeParam>; + using result_type = typename randen::result_type; + + const size_t kNumOutputs = (sizeof(randen) * 2 / sizeof(TypeParam)) + 1; + randen engine; + + // MSVC emits error 2719 without the use of std::ref below. + // * formal parameter with __declspec(align('#')) won't be aligned + + { + std::seed_seq seq1{1, 2, 3, 4, 5, 6, 7}; + engine.seed(seq1); + } + result_type a[kNumOutputs]; + std::generate(std::begin(a), std::end(a), std::ref(engine)); + + { + std::random_device rd; + std::seed_seq seq2{rd(), rd(), rd()}; + engine.seed(seq2); + } + result_type b[kNumOutputs]; + std::generate(std::begin(b), std::end(b), std::ref(engine)); + + // Test that generated sequence changed as sequence of bits, i.e. if about + // half of the bites were flipped between two non-correlated values. + size_t changed_bits = 0; + size_t unchanged_bits = 0; + size_t total_set = 0; + size_t total_bits = 0; + size_t equal_count = 0; + for (size_t i = 0; i < kNumOutputs; ++i) { + equal_count += (a[i] == b[i]) ? 1 : 0; + std::bitset<sizeof(result_type) * 8> bitset(a[i] ^ b[i]); + changed_bits += bitset.count(); + unchanged_bits += bitset.size() - bitset.count(); + + std::bitset<sizeof(result_type) * 8> a_set(a[i]); + std::bitset<sizeof(result_type) * 8> b_set(b[i]); + total_set += a_set.count() + b_set.count(); + total_bits += 2 * 8 * sizeof(result_type); + } + // On average, half the bits are changed between two calls. + EXPECT_LE(changed_bits, 0.60 * (changed_bits + unchanged_bits)); + EXPECT_GE(changed_bits, 0.40 * (changed_bits + unchanged_bits)); + + // Verify using a quick normal-approximation to the binomial. + EXPECT_NEAR(total_set, total_bits * 0.5, 4 * std::sqrt(total_bits)) + << "@" << total_set / static_cast<double>(total_bits); + + // Also, A[i] == B[i] with probability (1/range) * N. + // Give this a pretty wide latitude, though. + const double kExpected = kNumOutputs / (1.0 * sizeof(result_type) * 8); + EXPECT_LE(equal_count, 1.0 + kExpected); +} + +// Number of values that needs to be consumed to clean two sizes of buffer +// and trigger third refresh. (slightly overestimates the actual state size). +constexpr size_t kTwoBufferValues = sizeof(randen_u64) / sizeof(uint16_t) + 1; + +TYPED_TEST(RandenEngineTypedTest, VerifyDiscard) { + using randen = typename absl::random_internal::randen_engine<TypeParam>; + + for (size_t num_used = 0; num_used < kTwoBufferValues; ++num_used) { + randen engine_used; + for (size_t i = 0; i < num_used; ++i) { + engine_used(); + } + + for (size_t num_discard = 0; num_discard < kTwoBufferValues; + ++num_discard) { + randen engine1 = engine_used; + randen engine2 = engine_used; + for (size_t i = 0; i < num_discard; ++i) { + engine1(); + } + engine2.discard(num_discard); + for (size_t i = 0; i < kTwoBufferValues; ++i) { + const auto r1 = engine1(); + const auto r2 = engine2(); + ASSERT_EQ(r1, r2) << "used=" << num_used << " discard=" << num_discard; + } + } + } +} + +TYPED_TEST(RandenEngineTypedTest, StreamOperatorsResult) { + using randen = typename absl::random_internal::randen_engine<TypeParam>; + std::wostringstream os; + std::wistringstream is; + randen engine; + + EXPECT_EQ(&(os << engine), &os); + EXPECT_EQ(&(is >> engine), &is); +} + +TYPED_TEST(RandenEngineTypedTest, StreamSerialization) { + using randen = typename absl::random_internal::randen_engine<TypeParam>; + + for (size_t discard = 0; discard < kTwoBufferValues; ++discard) { + ExplicitSeedSeq seed_sequence{12, 34, 56}; + randen engine(seed_sequence); + engine.discard(discard); + + std::stringstream stream; + stream << engine; + + randen new_engine; + stream >> new_engine; + for (size_t i = 0; i < 64; ++i) { + EXPECT_EQ(engine(), new_engine()) << " " << i; + } + } +} + +constexpr size_t kNumGoldenOutputs = 127; + +// This test is checking if randen_engine is meets interface requirements +// defined in [rand.req.urbg]. +TYPED_TEST(RandenEngineTypedTest, RandomNumberEngineInterface) { + using randen = typename absl::random_internal::randen_engine<TypeParam>; + + using E = randen; + using T = typename E::result_type; + + static_assert(std::is_copy_constructible<E>::value, + "randen_engine must be copy constructible"); + + static_assert(absl::is_copy_assignable<E>::value, + "randen_engine must be copy assignable"); + + static_assert(std::is_move_constructible<E>::value, + "randen_engine must be move constructible"); + + static_assert(absl::is_move_assignable<E>::value, + "randen_engine must be move assignable"); + + static_assert(std::is_same<decltype(std::declval<E>()()), T>::value, + "return type of operator() must be result_type"); + + // Names after definition of [rand.req.urbg] in C++ standard. + // e us a value of E + // v is a lvalue of E + // x, y are possibly const values of E + // s is a value of T + // q is a value satisfying requirements of seed_sequence + // z is a value of type unsigned long long + // os is a some specialization of basic_ostream + // is is a some specialization of basic_istream + + E e, v; + const E x, y; + T s = 1; + std::seed_seq q{1, 2, 3}; + unsigned long long z = 1; // NOLINT(runtime/int) + std::wostringstream os; + std::wistringstream is; + + E{}; + E{x}; + E{s}; + E{q}; + + e.seed(); + + // MSVC emits error 2718 when using EXPECT_EQ(e, x) + // * actual parameter with __declspec(align('#')) won't be aligned + EXPECT_TRUE(e == x); + + e.seed(q); + { + E tmp(q); + EXPECT_TRUE(e == tmp); + } + + e(); + { + E tmp(q); + EXPECT_TRUE(e != tmp); + } + + e.discard(z); + + static_assert(std::is_same<decltype(x == y), bool>::value, + "return type of operator== must be bool"); + + static_assert(std::is_same<decltype(x != y), bool>::value, + "return type of operator== must be bool"); +} + +TYPED_TEST(RandenEngineTypedTest, RandenEngineSFINAETest) { + using randen = typename absl::random_internal::randen_engine<TypeParam>; + using result_type = typename randen::result_type; + + { + randen engine(result_type(1)); + engine.seed(result_type(1)); + } + + { + result_type n = 1; + randen engine(n); + engine.seed(n); + } + + { + randen engine(1); + engine.seed(1); + } + + { + int n = 1; + randen engine(n); + engine.seed(n); + } + + { + std::seed_seq seed_seq; + randen engine(seed_seq); + engine.seed(seed_seq); + } + + { + randen engine{std::seed_seq()}; + engine.seed(std::seed_seq()); + } +} + +TEST(RandenTest, VerifyGoldenRanden64Default) { + constexpr uint64_t kGolden[kNumGoldenOutputs] = { + 0xc3c14f134e433977, 0xdda9f47cd90410ee, 0x887bf3087fd8ca10, + 0xf0b780f545c72912, 0x15dbb1d37696599f, 0x30ec63baff3c6d59, + 0xb29f73606f7f20a6, 0x02808a316f49a54c, 0x3b8feaf9d5c8e50e, + 0x9cbf605e3fd9de8a, 0xc970ae1a78183bbb, 0xd8b2ffd356301ed5, + 0xf4b327fe0fc73c37, 0xcdfd8d76eb8f9a19, 0xc3a506eb91420c9d, + 0xd5af05dd3eff9556, 0x48db1bb78f83c4a1, 0x7023920e0d6bfe8c, + 0x58d3575834956d42, 0xed1ef4c26b87b840, 0x8eef32a23e0b2df3, + 0x497cabf3431154fc, 0x4e24370570029a8b, 0xd88b5749f090e5ea, + 0xc651a582a970692f, 0x78fcec2cbb6342f5, 0x463cb745612f55db, + 0x352ee4ad1816afe3, 0x026ff374c101da7e, 0x811ef0821c3de851, + 0x6f7e616704c4fa59, 0xa0660379992d58fc, 0x04b0a374a3b795c7, + 0x915f3445685da798, 0x26802a8ac76571ce, 0x4663352533ce1882, + 0xb9fdefb4a24dc738, 0x5588ba3a4d6e6c51, 0xa2101a42d35f1956, + 0x607195a5e200f5fd, 0x7e100308f3290764, 0xe1e5e03c759c0709, + 0x082572cc5da6606f, 0xcbcf585399e432f1, 0xe8a2be4f8335d8f1, + 0x0904469acbfee8f2, 0xf08bd31b6daecd51, 0x08e8a1f1a69da69a, + 0x6542a20aad57bff5, 0x2e9705bb053d6b46, 0xda2fc9db0713c391, + 0x78e3a810213b6ffb, 0xdc16a59cdd85f8a6, 0xc0932718cd55781f, + 0xb9bfb29c2b20bfe5, 0xb97289c1be0f2f9c, 0xc0a2a0e403a892d4, + 0x5524bb834771435b, 0x8265da3d39d1a750, 0xff4af3ab8d1b78c5, + 0xf0ec5f424bcad77f, 0x66e455f627495189, 0xc82d3120b57e3270, + 0x3424e47dc22596e3, 0xbc0c95129ccedcdd, 0xc191c595afc4dcbf, + 0x120392bd2bb70939, 0x7f90650ea6cd6ab4, 0x7287491832695ad3, + 0xa7c8fac5a7917eb0, 0xd088cb9418be0361, 0x7c1bf9839c7c1ce5, + 0xe2e991fa58e1e79e, 0x78565cdefd28c4ad, 0x7351b9fef98bafad, + 0x2a9eac28b08c96bf, 0x6c4f179696cb2225, 0x13a685861bab87e0, + 0x64c6de5aa0501971, 0x30537425cac70991, 0x01590d9dc6c532b7, + 0x7e05e3aa8ec720dc, 0x74a07d9c54e3e63f, 0x738184388f3bc1d2, + 0x26ffdc5067be3acb, 0x6bcdf185561f255f, 0xa0eaf2e1cf99b1c6, + 0x171df81934f68604, 0x7ea5a21665683e5a, 0x5d1cb02075ba1cea, + 0x957f38cbd2123fdf, 0xba6364eff80de02f, 0x606e0a0e41d452ee, + 0x892d8317de82f7a2, 0xe707b1db50f7b43e, 0x4eb28826766fcf5b, + 0x5a362d56e80a0951, 0x6ee217df16527d78, 0xf6737962ba6b23dd, + 0x443e63857d4076ca, 0x790d9a5f048adfeb, 0xd796b052151ee94d, + 0x033ed95c12b04a03, 0x8b833ff84893da5d, 0x3d6724b1bb15eab9, + 0x9877c4225061ca76, 0xd68d6810adf74fb3, 0x42e5352fe30ce989, + 0x265b565a7431fde7, 0x3cdbf7e358df4b8b, 0x2922a47f6d3e8779, + 0x52d2242f65b37f88, 0x5d836d6e2958d6b5, 0x29d40f00566d5e26, + 0x288db0e1124b14a0, 0x6c056608b7d9c1b6, 0x0b9471bdb8f19d32, + 0x8fb946504faa6c9d, 0x8943a9464540251c, 0xfd1fe27d144a09e0, + 0xea6ac458da141bda, 0x8048f217633fce36, 0xfeda1384ade74d31, + 0x4334b8b02ff7612f, 0xdbc8441f5227e216, 0x096d119a3605c85b, + 0x2b72b31c21b7d7d0}; + + randen_u64 engine; +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%016lx, ", engine()); + if (i % 3 == 2) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } + engine.seed(); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(RandenTest, VerifyGoldenRanden64Seeded) { + constexpr uint64_t kGolden[kNumGoldenOutputs] = { + 0x83a9e58f94d3dcd5, 0x70bbdff3d97949fb, 0x0438481f7471c1b4, + 0x34fdc58ee5fb5930, 0xceee4f2d2a937d17, 0xb5a26a68e432aea9, + 0x8b64774a3fb51740, 0xd89ac1fc74249c74, 0x03910d1d23fc3fdf, + 0xd38f630878aa897f, 0x0ee8f0f5615f7e44, 0x98f5a53df8279d52, + 0xb403f52c25938d0e, 0x240072996ea6e838, 0xd3a791246190fa61, + 0xaaedd3df7a7b4f80, 0xc6eacabe05deaf6e, 0xb7967dd8790edf4d, + 0x9a0a8e67e049d279, 0x0494f606aebc23e7, 0x598dcd687bc3e0ee, + 0x010ac81802d452a1, 0x6407c87160aa2842, 0x5a56e276486f93a0, + 0xc887a399d46a8f02, 0x9e1e6100fe93b740, 0x12d02e330f8901f6, + 0xc39ca52b47e790b7, 0xb0b0a2fa11e82e61, 0x1542d841a303806a, + 0x1fe659fd7d6e9d86, 0xb8c90d80746541ac, 0x239d56a5669ddc94, + 0xd40db57c8123d13c, 0x3abc2414153a0db0, 0x9bad665630cb8d61, + 0x0bd1fb90ee3f4bbc, 0x8f0b4d7e079b4e42, 0xfa0fb0e0ee59e793, + 0x51080b283e071100, 0x2c4b9e715081cc15, 0xbe10ed49de4941df, + 0xf8eaac9d4b1b0d37, 0x4bcce4b54605e139, 0xa64722b76765dda6, + 0xb9377d738ca28ab5, 0x779fad81a8ccc1af, 0x65cb3ee61ffd3ba7, + 0xd74e79087862836f, 0xd05b9c584c3f25bf, 0x2ba93a4693579827, + 0xd81530aff05420ce, 0xec06cea215478621, 0x4b1798a6796d65ad, + 0xf142f3fb3a6f6fa6, 0x002b7bf7e237b560, 0xf47f2605ef65b4f8, + 0x9804ec5517effc18, 0xaed3d7f8b7d481cd, 0x5651c24c1ce338d1, + 0x3e7a38208bf0a3c6, 0x6796a7b614534aed, 0x0d0f3b848358460f, + 0x0fa5fe7600b19524, 0x2b0cf38253faaedc, 0x10df9188233a9fd6, + 0x3a10033880138b59, 0x5fb0b0d23948e80f, 0x9e76f7b02fbf5350, + 0x0816052304b1a985, 0x30c9880db41fd218, 0x14aa399b65e20f28, + 0xe1454a8cace787b4, 0x325ac971b6c6f0f5, 0x716b1aa2784f3d36, + 0x3d5ce14accfd144f, 0x6c0c97710f651792, 0xbc5b0f59fb333532, + 0x2a90a7d2140470bc, 0x8da269f55c1e1c8d, 0xcfc37143895792ca, + 0xbe21eab1f30b238f, 0x8c47229dee4d65fd, 0x5743614ed1ed7d54, + 0x351372a99e9c476e, 0x2bd5ea15e5db085f, 0x6925fde46e0af4ca, + 0xed3eda2bdc1f45bd, 0xdef68c68d460fa6e, 0xe42a0de76253e2b5, + 0x4e5176dcbc29c305, 0xbfd85fba9f810f6e, 0x76a5a2a9beb815c6, + 0x01edc4ddceaf414c, 0xa4e98904b4bb3b4b, 0x00bd63ac7d2f1ddd, + 0xb8491fe6e998ddbb, 0xb386a3463dda6800, 0x0081887688871619, + 0x33d394b3344e9a38, 0x815dba65a3a8baf9, 0x4232f6ec02c2fd1a, + 0xb5cff603edd20834, 0x580189243f687663, 0xa8d5a2cbdc27fe99, + 0x725d881693fa0131, 0xa2be2c13db2c7ac5, 0x7b6a9614b509fd78, + 0xb6b136d71e717636, 0x660f1a71aff046ea, 0x0ba10ae346c8ec9e, + 0xe66dde53e3145b41, 0x3b18288c88c26be6, 0x4d9d9d2ff02db933, + 0x4167da8c70f46e8a, 0xf183beef8c6318b4, 0x4d889e1e71eeeef1, + 0x7175c71ad6689b6b, 0xfb9e42beacd1b7dd, 0xc33d0e91b29b5e0d, + 0xd39b83291ce47922, 0xc4d570fb8493d12e, 0x23d5a5724f424ae6, + 0x5245f161876b6616, 0x38d77dbd21ab578d, 0x9c3423311f4ecbfe, + 0x76fe31389bacd9d5, + }; + + ExplicitSeedSeq seed_sequence{12, 34, 56}; + randen_u64 engine(seed_sequence); +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%016lx, ", engine()); + if (i % 3 == 2) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } + engine.seed(seed_sequence); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(RandenTest, VerifyGoldenRanden32Default) { + constexpr uint64_t kGolden[2 * kNumGoldenOutputs] = { + 0x4e433977, 0xc3c14f13, 0xd90410ee, 0xdda9f47c, 0x7fd8ca10, 0x887bf308, + 0x45c72912, 0xf0b780f5, 0x7696599f, 0x15dbb1d3, 0xff3c6d59, 0x30ec63ba, + 0x6f7f20a6, 0xb29f7360, 0x6f49a54c, 0x02808a31, 0xd5c8e50e, 0x3b8feaf9, + 0x3fd9de8a, 0x9cbf605e, 0x78183bbb, 0xc970ae1a, 0x56301ed5, 0xd8b2ffd3, + 0x0fc73c37, 0xf4b327fe, 0xeb8f9a19, 0xcdfd8d76, 0x91420c9d, 0xc3a506eb, + 0x3eff9556, 0xd5af05dd, 0x8f83c4a1, 0x48db1bb7, 0x0d6bfe8c, 0x7023920e, + 0x34956d42, 0x58d35758, 0x6b87b840, 0xed1ef4c2, 0x3e0b2df3, 0x8eef32a2, + 0x431154fc, 0x497cabf3, 0x70029a8b, 0x4e243705, 0xf090e5ea, 0xd88b5749, + 0xa970692f, 0xc651a582, 0xbb6342f5, 0x78fcec2c, 0x612f55db, 0x463cb745, + 0x1816afe3, 0x352ee4ad, 0xc101da7e, 0x026ff374, 0x1c3de851, 0x811ef082, + 0x04c4fa59, 0x6f7e6167, 0x992d58fc, 0xa0660379, 0xa3b795c7, 0x04b0a374, + 0x685da798, 0x915f3445, 0xc76571ce, 0x26802a8a, 0x33ce1882, 0x46633525, + 0xa24dc738, 0xb9fdefb4, 0x4d6e6c51, 0x5588ba3a, 0xd35f1956, 0xa2101a42, + 0xe200f5fd, 0x607195a5, 0xf3290764, 0x7e100308, 0x759c0709, 0xe1e5e03c, + 0x5da6606f, 0x082572cc, 0x99e432f1, 0xcbcf5853, 0x8335d8f1, 0xe8a2be4f, + 0xcbfee8f2, 0x0904469a, 0x6daecd51, 0xf08bd31b, 0xa69da69a, 0x08e8a1f1, + 0xad57bff5, 0x6542a20a, 0x053d6b46, 0x2e9705bb, 0x0713c391, 0xda2fc9db, + 0x213b6ffb, 0x78e3a810, 0xdd85f8a6, 0xdc16a59c, 0xcd55781f, 0xc0932718, + 0x2b20bfe5, 0xb9bfb29c, 0xbe0f2f9c, 0xb97289c1, 0x03a892d4, 0xc0a2a0e4, + 0x4771435b, 0x5524bb83, 0x39d1a750, 0x8265da3d, 0x8d1b78c5, 0xff4af3ab, + 0x4bcad77f, 0xf0ec5f42, 0x27495189, 0x66e455f6, 0xb57e3270, 0xc82d3120, + 0xc22596e3, 0x3424e47d, 0x9ccedcdd, 0xbc0c9512, 0xafc4dcbf, 0xc191c595, + 0x2bb70939, 0x120392bd, 0xa6cd6ab4, 0x7f90650e, 0x32695ad3, 0x72874918, + 0xa7917eb0, 0xa7c8fac5, 0x18be0361, 0xd088cb94, 0x9c7c1ce5, 0x7c1bf983, + 0x58e1e79e, 0xe2e991fa, 0xfd28c4ad, 0x78565cde, 0xf98bafad, 0x7351b9fe, + 0xb08c96bf, 0x2a9eac28, 0x96cb2225, 0x6c4f1796, 0x1bab87e0, 0x13a68586, + 0xa0501971, 0x64c6de5a, 0xcac70991, 0x30537425, 0xc6c532b7, 0x01590d9d, + 0x8ec720dc, 0x7e05e3aa, 0x54e3e63f, 0x74a07d9c, 0x8f3bc1d2, 0x73818438, + 0x67be3acb, 0x26ffdc50, 0x561f255f, 0x6bcdf185, 0xcf99b1c6, 0xa0eaf2e1, + 0x34f68604, 0x171df819, 0x65683e5a, 0x7ea5a216, 0x75ba1cea, 0x5d1cb020, + 0xd2123fdf, 0x957f38cb, 0xf80de02f, 0xba6364ef, 0x41d452ee, 0x606e0a0e, + 0xde82f7a2, 0x892d8317, 0x50f7b43e, 0xe707b1db, 0x766fcf5b, 0x4eb28826, + 0xe80a0951, 0x5a362d56, 0x16527d78, 0x6ee217df, 0xba6b23dd, 0xf6737962, + 0x7d4076ca, 0x443e6385, 0x048adfeb, 0x790d9a5f, 0x151ee94d, 0xd796b052, + 0x12b04a03, 0x033ed95c, 0x4893da5d, 0x8b833ff8, 0xbb15eab9, 0x3d6724b1, + 0x5061ca76, 0x9877c422, 0xadf74fb3, 0xd68d6810, 0xe30ce989, 0x42e5352f, + 0x7431fde7, 0x265b565a, 0x58df4b8b, 0x3cdbf7e3, 0x6d3e8779, 0x2922a47f, + 0x65b37f88, 0x52d2242f, 0x2958d6b5, 0x5d836d6e, 0x566d5e26, 0x29d40f00, + 0x124b14a0, 0x288db0e1, 0xb7d9c1b6, 0x6c056608, 0xb8f19d32, 0x0b9471bd, + 0x4faa6c9d, 0x8fb94650, 0x4540251c, 0x8943a946, 0x144a09e0, 0xfd1fe27d, + 0xda141bda, 0xea6ac458, 0x633fce36, 0x8048f217, 0xade74d31, 0xfeda1384, + 0x2ff7612f, 0x4334b8b0, 0x5227e216, 0xdbc8441f, 0x3605c85b, 0x096d119a, + 0x21b7d7d0, 0x2b72b31c}; + + randen_u32 engine; +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + for (size_t i = 0; i < 2 * kNumGoldenOutputs; ++i) { + printf("0x%08x, ", engine()); + if (i % 6 == 5) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } + engine.seed(); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(RandenTest, VerifyGoldenRanden32Seeded) { + constexpr uint64_t kGolden[2 * kNumGoldenOutputs] = { + 0x94d3dcd5, 0x83a9e58f, 0xd97949fb, 0x70bbdff3, 0x7471c1b4, 0x0438481f, + 0xe5fb5930, 0x34fdc58e, 0x2a937d17, 0xceee4f2d, 0xe432aea9, 0xb5a26a68, + 0x3fb51740, 0x8b64774a, 0x74249c74, 0xd89ac1fc, 0x23fc3fdf, 0x03910d1d, + 0x78aa897f, 0xd38f6308, 0x615f7e44, 0x0ee8f0f5, 0xf8279d52, 0x98f5a53d, + 0x25938d0e, 0xb403f52c, 0x6ea6e838, 0x24007299, 0x6190fa61, 0xd3a79124, + 0x7a7b4f80, 0xaaedd3df, 0x05deaf6e, 0xc6eacabe, 0x790edf4d, 0xb7967dd8, + 0xe049d279, 0x9a0a8e67, 0xaebc23e7, 0x0494f606, 0x7bc3e0ee, 0x598dcd68, + 0x02d452a1, 0x010ac818, 0x60aa2842, 0x6407c871, 0x486f93a0, 0x5a56e276, + 0xd46a8f02, 0xc887a399, 0xfe93b740, 0x9e1e6100, 0x0f8901f6, 0x12d02e33, + 0x47e790b7, 0xc39ca52b, 0x11e82e61, 0xb0b0a2fa, 0xa303806a, 0x1542d841, + 0x7d6e9d86, 0x1fe659fd, 0x746541ac, 0xb8c90d80, 0x669ddc94, 0x239d56a5, + 0x8123d13c, 0xd40db57c, 0x153a0db0, 0x3abc2414, 0x30cb8d61, 0x9bad6656, + 0xee3f4bbc, 0x0bd1fb90, 0x079b4e42, 0x8f0b4d7e, 0xee59e793, 0xfa0fb0e0, + 0x3e071100, 0x51080b28, 0x5081cc15, 0x2c4b9e71, 0xde4941df, 0xbe10ed49, + 0x4b1b0d37, 0xf8eaac9d, 0x4605e139, 0x4bcce4b5, 0x6765dda6, 0xa64722b7, + 0x8ca28ab5, 0xb9377d73, 0xa8ccc1af, 0x779fad81, 0x1ffd3ba7, 0x65cb3ee6, + 0x7862836f, 0xd74e7908, 0x4c3f25bf, 0xd05b9c58, 0x93579827, 0x2ba93a46, + 0xf05420ce, 0xd81530af, 0x15478621, 0xec06cea2, 0x796d65ad, 0x4b1798a6, + 0x3a6f6fa6, 0xf142f3fb, 0xe237b560, 0x002b7bf7, 0xef65b4f8, 0xf47f2605, + 0x17effc18, 0x9804ec55, 0xb7d481cd, 0xaed3d7f8, 0x1ce338d1, 0x5651c24c, + 0x8bf0a3c6, 0x3e7a3820, 0x14534aed, 0x6796a7b6, 0x8358460f, 0x0d0f3b84, + 0x00b19524, 0x0fa5fe76, 0x53faaedc, 0x2b0cf382, 0x233a9fd6, 0x10df9188, + 0x80138b59, 0x3a100338, 0x3948e80f, 0x5fb0b0d2, 0x2fbf5350, 0x9e76f7b0, + 0x04b1a985, 0x08160523, 0xb41fd218, 0x30c9880d, 0x65e20f28, 0x14aa399b, + 0xace787b4, 0xe1454a8c, 0xb6c6f0f5, 0x325ac971, 0x784f3d36, 0x716b1aa2, + 0xccfd144f, 0x3d5ce14a, 0x0f651792, 0x6c0c9771, 0xfb333532, 0xbc5b0f59, + 0x140470bc, 0x2a90a7d2, 0x5c1e1c8d, 0x8da269f5, 0x895792ca, 0xcfc37143, + 0xf30b238f, 0xbe21eab1, 0xee4d65fd, 0x8c47229d, 0xd1ed7d54, 0x5743614e, + 0x9e9c476e, 0x351372a9, 0xe5db085f, 0x2bd5ea15, 0x6e0af4ca, 0x6925fde4, + 0xdc1f45bd, 0xed3eda2b, 0xd460fa6e, 0xdef68c68, 0x6253e2b5, 0xe42a0de7, + 0xbc29c305, 0x4e5176dc, 0x9f810f6e, 0xbfd85fba, 0xbeb815c6, 0x76a5a2a9, + 0xceaf414c, 0x01edc4dd, 0xb4bb3b4b, 0xa4e98904, 0x7d2f1ddd, 0x00bd63ac, + 0xe998ddbb, 0xb8491fe6, 0x3dda6800, 0xb386a346, 0x88871619, 0x00818876, + 0x344e9a38, 0x33d394b3, 0xa3a8baf9, 0x815dba65, 0x02c2fd1a, 0x4232f6ec, + 0xedd20834, 0xb5cff603, 0x3f687663, 0x58018924, 0xdc27fe99, 0xa8d5a2cb, + 0x93fa0131, 0x725d8816, 0xdb2c7ac5, 0xa2be2c13, 0xb509fd78, 0x7b6a9614, + 0x1e717636, 0xb6b136d7, 0xaff046ea, 0x660f1a71, 0x46c8ec9e, 0x0ba10ae3, + 0xe3145b41, 0xe66dde53, 0x88c26be6, 0x3b18288c, 0xf02db933, 0x4d9d9d2f, + 0x70f46e8a, 0x4167da8c, 0x8c6318b4, 0xf183beef, 0x71eeeef1, 0x4d889e1e, + 0xd6689b6b, 0x7175c71a, 0xacd1b7dd, 0xfb9e42be, 0xb29b5e0d, 0xc33d0e91, + 0x1ce47922, 0xd39b8329, 0x8493d12e, 0xc4d570fb, 0x4f424ae6, 0x23d5a572, + 0x876b6616, 0x5245f161, 0x21ab578d, 0x38d77dbd, 0x1f4ecbfe, 0x9c342331, + 0x9bacd9d5, 0x76fe3138, + }; + + ExplicitSeedSeq seed_sequence{12, 34, 56}; + randen_u32 engine(seed_sequence); +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + for (size_t i = 0; i < 2 * kNumGoldenOutputs; ++i) { + printf("0x%08x, ", engine()); + if (i % 6 == 5) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } + engine.seed(seed_sequence); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(RandenTest, VerifyGoldenFromDeserializedEngine) { + constexpr uint64_t kGolden[kNumGoldenOutputs] = { + 0x067f9f9ab919657a, 0x0534605912988583, 0x8a303f72feaa673f, + 0x77b7fd747909185c, 0xd9af90403c56d891, 0xd939c6cb204d14b5, + 0x7fbe6b954a47b483, 0x8b31a47cc34c768d, 0x3a9e546da2701a9c, + 0x5246539046253e71, 0x417191ffb2a848a1, 0x7b1c7bf5a5001d09, + 0x9489b15d194f2361, 0xfcebdeea3bcd2461, 0xd643027c854cec97, + 0x5885397f91e0d21c, 0x53173b0efae30d58, 0x1c9c71168449fac1, + 0xe358202b711ed8aa, 0x94e3918ed1d8227c, 0x5bb4e251450144cf, + 0xb5c7a519b489af3b, 0x6f8b560b1f7b3469, 0xfde11dd4a1c74eef, + 0x33383d2f76457dcf, 0x3060c0ec6db9fce1, 0x18f451fcddeec766, + 0xe73c5d6b9f26da2a, 0x8d4cc566671b32a4, 0xb8189b73776bc9ff, + 0x497a70f9caf0bc23, 0x23afcc509791dcea, 0x18af70dc4b27d306, + 0xd3853f955a0ce5b9, 0x441db6c01a0afb17, 0xd0136c3fb8e1f13f, + 0x5e4fd6fc2f33783c, 0xe0d24548adb5da51, 0x0f4d8362a7d3485a, + 0x9f572d68270fa563, 0x6351fbc823024393, 0xa66dbfc61810e9ab, + 0x0ff17fc14b651af8, 0xd74c55dafb99e623, 0x36303bc1ad85c6c2, + 0x4920cd6a2af7e897, 0x0b8848addc30fecd, 0x9e1562eda6488e93, + 0x197553807d607828, 0xbef5eaeda5e21235, 0x18d91d2616aca527, + 0xb7821937f5c873cd, 0x2cd4ae5650dbeefc, 0xb35a64376f75ffdf, + 0x9226d414d647fe07, 0x663f3db455bbb35e, 0xa829eead6ae93247, + 0x7fd69c204dd0d25f, 0xbe1411f891c9acb1, 0xd476f34a506d5f11, + 0xf423d2831649c5ca, 0x1e503962951abd75, 0xeccc9e8b1e34b537, + 0xb11a147294044854, 0xc4cf27f0abf4929d, 0xe9193abf6fa24c8c, + 0xa94a259e3aba8808, 0x21dc414197deffa3, 0xa2ae211d1ff622ae, + 0xfe3995c46be5a4f4, 0xe9984c284bf11128, 0xcb1ce9d2f0851a80, + 0x42fee17971d87cd8, 0xac76a98d177adc88, 0xa0973b3dedc4af6f, + 0xdf56d6bbcb1b8e86, 0xf1e6485f407b11c9, 0x2c63de4deccb15c0, + 0x6fe69db32ed4fad7, 0xaa51a65f84bca1f1, 0x242f2ee81d608afc, + 0x8eb88b2b69fc153b, 0x22c20098baf73fd1, 0x57759466f576488c, + 0x075ca562cea1be9d, 0x9a74814d73d28891, 0x73d1555fc02f4d3d, + 0xc17f8f210ee89337, 0x46cca7999eaeafd4, 0x5db8d6a327a0d8ac, + 0xb79b4f93c738d7a1, 0x9994512f0036ded1, 0xd3883026f38747f4, + 0xf31f7458078d097c, 0x736ce4d480680669, 0x7a496f4c7e1033e3, + 0xecf85bf297fbc68c, 0x9e37e1d0f24f3c4e, 0x15b6e067ca0746fc, + 0xdd4a39905c5db81c, 0xb5dfafa7bcfdf7da, 0xca6646fb6f92a276, + 0x1c6b35f363ef0efd, 0x6a33d06037ad9f76, 0x45544241afd8f80f, + 0x83f8d83f859c90c5, 0x22aea9c5365e8c19, 0xfac35b11f20b6a6a, + 0xd1acf49d1a27dd2f, 0xf281cd09c4fed405, 0x076000a42cd38e4f, + 0x6ace300565070445, 0x463a62781bddc4db, 0x1477126b46b569ac, + 0x127f2bb15035fbb8, 0xdfa30946049c04a8, 0x89072a586ba8dd3e, + 0x62c809582bb7e74d, 0x22c0c3641406c28b, 0x9b66e36c47ff004d, + 0xb9cd2c7519653330, 0x18608d79cd7a598d, 0x92c0bd1323e53e32, + 0x887ff00de8524aa5, 0xa074410b787abd10, 0x18ab41b8057a2063, + 0x1560abf26bc5f987}; + +#if UPDATE_GOLDEN + (void)kGolden; // Silence warning. + std::seed_seq seed_sequence{1, 2, 3, 4, 5}; + randen_u64 engine(seed_sequence); + std::ostringstream stream; + stream << engine; + auto str = stream.str(); + printf("%s\n\n", str.c_str()); + for (size_t i = 0; i < kNumGoldenOutputs; ++i) { + printf("0x%016lx, ", engine()); + if (i % 3 == 2) { + printf("\n"); + } + } + printf("\n\n\n"); +#else + randen_u64 engine; + std::istringstream stream( + "0 0 9824501439887287479 3242284395352394785 243836530774933777 " + "4047941804708365596 17165468127298385802 949276103645889255 " + "10659970394998657921 1657570836810929787 11697746266668051452 " + "9967209969299905230 14140390331161524430 7383014124183271684 " + "13146719127702337852 13983155220295807171 11121125587542359264 " + "195757810993252695 17138580243103178492 11326030747260920501 " + "8585097322474965590 18342582839328350995 15052982824209724634 " + "7321861343874683609 1806786911778767826 10100850842665572955 " + "9249328950653985078 13600624835326909759 11137960060943860251 " + "10208781341792329629 9282723971471525577 16373271619486811032 32"); + stream >> engine; + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, engine()); + } +#endif +} + +TEST(RandenTest, IsFastOrSlow) { + // randen_engine typically costs ~5ns per value for the optimized code paths, + // and the ~1000ns per value for slow code paths. However when running under + // msan, asan, etc. it can take much longer. + // + // The estimated operation time is something like: + // + // linux, optimized ~5ns + // ppc, optimized ~7ns + // nacl (slow), ~1100ns + // + // `kCount` is chosen below so that, in debug builds and without hardware + // acceleration, the test (assuming ~1us per call) should finish in ~0.1s + static constexpr size_t kCount = 100000; + randen_u64 engine; + randen_u64::result_type sum = 0; + auto start = absl::GetCurrentTimeNanos(); + for (int i = 0; i < kCount; i++) { + sum += engine(); + } + auto duration = absl::GetCurrentTimeNanos() - start; + + ABSL_INTERNAL_LOG(INFO, absl::StrCat(static_cast<double>(duration) / + static_cast<double>(kCount), + "ns")); + + EXPECT_GT(sum, 0); + EXPECT_GE(duration, kCount); // Should be slower than 1ns per call. +} + +} // namespace diff --git a/absl/random/internal/randen_hwaes.cc b/absl/random/internal/randen_hwaes.cc new file mode 100644 index 000000000000..0fcd9a85a8b9 --- /dev/null +++ b/absl/random/internal/randen_hwaes.cc @@ -0,0 +1,666 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// HERMETIC NOTE: The randen_hwaes target must not introduce duplicate +// symbols from arbitrary system and other headers, since it may be built +// with different flags from other targets, using different levels of +// optimization, potentially introducing ODR violations. + +#include "absl/random/internal/randen_hwaes.h" + +#include <cstdint> +#include <cstring> + +#include "absl/random/internal/platform.h" + +// ABSL_RANDEN_HWAES_IMPL indicates whether this file will contain +// a hardware accelerated implementation of randen, or whether it +// will contain stubs that exit the process. +#if defined(ABSL_ARCH_X86_64) || defined(ABSL_ARCH_X86_32) +// The platform.h directives are sufficient to indicate whether +// we should build accelerated implementations for x86. +#if (ABSL_HAVE_ACCELERATED_AES || ABSL_RANDOM_INTERNAL_AES_DISPATCH) +#define ABSL_RANDEN_HWAES_IMPL 1 +#endif +#elif defined(ABSL_ARCH_PPC) +// The platform.h directives are sufficient to indicate whether +// we should build accelerated implementations for PPC. +// +// NOTE: This has mostly been tested on 64-bit Power variants, +// and not embedded cpus such as powerpc32-8540 +#if ABSL_HAVE_ACCELERATED_AES +#define ABSL_RANDEN_HWAES_IMPL 1 +#endif +#elif defined(ABSL_ARCH_ARM) || defined(ABSL_ARCH_AARCH64) +// ARM is somewhat more complicated. We might support crypto natively... +#if ABSL_HAVE_ACCELERATED_AES || \ + (defined(__ARM_NEON) && defined(__ARM_FEATURE_CRYPTO)) +#define ABSL_RANDEN_HWAES_IMPL 1 + +#elif ABSL_RANDOM_INTERNAL_AES_DISPATCH && !defined(__APPLE__) && \ + (defined(__GNUC__) && __GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ > 9) +// ...or, on GCC, we can use an ASM directive to +// instruct the assember to allow crypto instructions. +#define ABSL_RANDEN_HWAES_IMPL 1 +#define ABSL_RANDEN_HWAES_IMPL_CRYPTO_DIRECTIVE 1 +#endif +#else +// HWAES is unsupported by these architectures / platforms: +// __myriad2__ +// __mips__ +// +// Other architectures / platforms are unknown. +// +// See the Abseil documentation on supported macros at: +// https://abseil.io/docs/cpp/platforms/macros +#endif + +#if !defined(ABSL_RANDEN_HWAES_IMPL) +// No accelerated implementation is supported. +// The RandenHwAes functions are stubs that print an error and exit. + +#include <cstdio> +#include <cstdlib> + +namespace absl { +namespace random_internal { + +// No accelerated implementation. +bool HasRandenHwAesImplementation() { return false; } + +// NOLINTNEXTLINE +const void* RandenHwAes::GetKeys() { + // Attempted to dispatch to an unsupported dispatch target. + const int d = ABSL_RANDOM_INTERNAL_AES_DISPATCH; + fprintf(stderr, "AES Hardware detection failed (%d).\n", d); + exit(1); + return nullptr; +} + +// NOLINTNEXTLINE +void RandenHwAes::Absorb(const void*, void*) { + // Attempted to dispatch to an unsupported dispatch target. + const int d = ABSL_RANDOM_INTERNAL_AES_DISPATCH; + fprintf(stderr, "AES Hardware detection failed (%d).\n", d); + exit(1); +} + +// NOLINTNEXTLINE +void RandenHwAes::Generate(const void*, void*) { + // Attempted to dispatch to an unsupported dispatch target. + const int d = ABSL_RANDOM_INTERNAL_AES_DISPATCH; + fprintf(stderr, "AES Hardware detection failed (%d).\n", d); + exit(1); +} + +} // namespace random_internal +} // namespace absl + +#else // defined(ABSL_RANDEN_HWAES_IMPL) +// +// Accelerated implementations are supported. +// We need the per-architecture includes and defines. +// + +#include "absl/random/internal/randen_traits.h" + +// ABSL_FUNCTION_ALIGN32 defines a 32-byte alignment attribute +// for the functions in this file. +// +// NOTE: Determine whether we actually have any wins from ALIGN32 +// using microbenchmarks. If not, remove. +#undef ABSL_FUNCTION_ALIGN32 +#if ABSL_HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) +#define ABSL_FUNCTION_ALIGN32 __attribute__((aligned(32))) +#else +#define ABSL_FUNCTION_ALIGN32 +#endif + +// TARGET_CRYPTO defines a crypto attribute for each architecture. +// +// NOTE: Evaluate whether we should eliminate ABSL_TARGET_CRYPTO. +#if (defined(__clang__) || defined(__GNUC__)) +#if defined(ABSL_ARCH_X86_64) || defined(ABSL_ARCH_X86_32) +#define ABSL_TARGET_CRYPTO __attribute__((target("aes"))) +#elif defined(ABSL_ARCH_PPC) +#define ABSL_TARGET_CRYPTO __attribute__((target("crypto"))) +#else +#define ABSL_TARGET_CRYPTO +#endif +#else +#define ABSL_TARGET_CRYPTO +#endif + +#if defined(ABSL_ARCH_PPC) +// NOTE: Keep in mind that PPC can operate in little-endian or big-endian mode, +// however the PPC altivec vector registers (and thus the AES instructions) +// always operate in big-endian mode. + +#include <altivec.h> +// <altivec.h> #defines vector __vector; in C++, this is bad form. +#undef vector + +// Rely on the PowerPC AltiVec vector operations for accelerated AES +// instructions. GCC support of the PPC vector types is described in: +// https://gcc.gnu.org/onlinedocs/gcc-4.9.0/gcc/PowerPC-AltiVec_002fVSX-Built-in-Functions.html +// +// Already provides operator^=. +using Vector128 = __vector unsigned long long; // NOLINT(runtime/int) + +namespace { + +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +ReverseBytes(const Vector128& v) { + // Reverses the bytes of the vector. + const __vector unsigned char perm = {15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0}; + return vec_perm(v, v, perm); +} + +// WARNING: these load/store in native byte order. It is OK to load and then +// store an unchanged vector, but interpreting the bits as a number or input +// to AES will have undefined results. +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +Vector128Load(const void* ABSL_RANDOM_INTERNAL_RESTRICT from) { + return vec_vsx_ld(0, reinterpret_cast<const Vector128*>(from)); +} + +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE void Vector128Store( + const Vector128& v, void* ABSL_RANDOM_INTERNAL_RESTRICT to) { + vec_vsx_st(v, 0, reinterpret_cast<Vector128*>(to)); +} + +// One round of AES. "round_key" is a public constant for breaking the +// symmetry of AES (ensures previously equal columns differ afterwards). +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +AesRound(const Vector128& state, const Vector128& round_key) { + return Vector128(__builtin_crypto_vcipher(state, round_key)); +} + +// Enables native loads in the round loop by pre-swapping. +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE void SwapEndian( + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state) { + using absl::random_internal::RandenTraits; + constexpr size_t kLanes = 2; + constexpr size_t kFeistelBlocks = RandenTraits::kFeistelBlocks; + + for (uint32_t branch = 0; branch < kFeistelBlocks; ++branch) { + const Vector128 v = ReverseBytes(Vector128Load(state + kLanes * branch)); + Vector128Store(v, state + kLanes * branch); + } +} + +} // namespace + +#elif defined(ABSL_ARCH_ARM) || defined(ABSL_ARCH_AARCH64) + +// This asm directive will cause the file to be compiled with crypto extensions +// whether or not the cpu-architecture supports it. +#if ABSL_RANDEN_HWAES_IMPL_CRYPTO_DIRECTIVE +asm(".arch_extension crypto\n"); + +// Override missing defines. +#if !defined(__ARM_NEON) +#define __ARM_NEON 1 +#endif + +#if !defined(__ARM_FEATURE_CRYPTO) +#define __ARM_FEATURE_CRYPTO 1 +#endif + +#endif + +// Rely on the ARM NEON+Crypto advanced simd types, defined in <arm_neon.h>. +// uint8x16_t is the user alias for underlying __simd128_uint8_t type. +// http://infocenter.arm.com/help/topic/com.arm.doc.ihi0073a/IHI0073A_arm_neon_intrinsics_ref.pdf +// +// <arm_neon> defines the following +// +// typedef __attribute__((neon_vector_type(16))) uint8_t uint8x16_t; +// typedef __attribute__((neon_vector_type(16))) int8_t int8x16_t; +// typedef __attribute__((neon_polyvector_type(16))) int8_t poly8x16_t; +// +// vld1q_v +// vst1q_v +// vaeseq_v +// vaesmcq_v +#include <arm_neon.h> + +// Already provides operator^=. +using Vector128 = uint8x16_t; + +namespace { + +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +Vector128Load(const void* ABSL_RANDOM_INTERNAL_RESTRICT from) { + return vld1q_u8(reinterpret_cast<const uint8_t*>(from)); +} + +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE void Vector128Store( + const Vector128& v, void* ABSL_RANDOM_INTERNAL_RESTRICT to) { + vst1q_u8(reinterpret_cast<uint8_t*>(to), v); +} + +// One round of AES. "round_key" is a public constant for breaking the +// symmetry of AES (ensures previously equal columns differ afterwards). +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +AesRound(const Vector128& state, const Vector128& round_key) { + // It is important to always use the full round function - omitting the + // final MixColumns reduces security [https://eprint.iacr.org/2010/041.pdf] + // and does not help because we never decrypt. + // + // Note that ARM divides AES instructions differently than x86 / PPC, + // And we need to skip the first AddRoundKey step and add an extra + // AddRoundKey step to the end. Lucky for us this is just XOR. + return vaesmcq_u8(vaeseq_u8(state, uint8x16_t{})) ^ round_key; +} + +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE void SwapEndian( + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT) {} + +} // namespace + +#elif defined(ABSL_ARCH_X86_64) || defined(ABSL_ARCH_X86_32) +// On x86 we rely on the aesni instructions +#include <wmmintrin.h> + +namespace { + +// Vector128 class is only wrapper for __m128i, benchmark indicates that it's +// faster than using __m128i directly. +class Vector128 { + public: + // Convert from/to intrinsics. + inline ABSL_ATTRIBUTE_ALWAYS_INLINE explicit Vector128( + const __m128i& Vector128) + : data_(Vector128) {} + + inline ABSL_ATTRIBUTE_ALWAYS_INLINE __m128i data() const { return data_; } + + inline ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128& operator^=( + const Vector128& other) { + data_ = _mm_xor_si128(data_, other.data()); + return *this; + } + + private: + __m128i data_; +}; + +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +Vector128Load(const void* ABSL_RANDOM_INTERNAL_RESTRICT from) { + return Vector128(_mm_load_si128(reinterpret_cast<const __m128i*>(from))); +} + +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE void Vector128Store( + const Vector128& v, void* ABSL_RANDOM_INTERNAL_RESTRICT to) { + _mm_store_si128(reinterpret_cast<__m128i * ABSL_RANDOM_INTERNAL_RESTRICT>(to), + v.data()); +} + +// One round of AES. "round_key" is a public constant for breaking the +// symmetry of AES (ensures previously equal columns differ afterwards). +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +AesRound(const Vector128& state, const Vector128& round_key) { + // It is important to always use the full round function - omitting the + // final MixColumns reduces security [https://eprint.iacr.org/2010/041.pdf] + // and does not help because we never decrypt. + return Vector128(_mm_aesenc_si128(state.data(), round_key.data())); +} + +inline ABSL_TARGET_CRYPTO ABSL_ATTRIBUTE_ALWAYS_INLINE void SwapEndian( + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT) {} + +} // namespace + +#endif + +namespace { + +// u64x2 is a 128-bit, (2 x uint64_t lanes) struct used to store +// the randen_keys. +struct alignas(16) u64x2 { + constexpr u64x2(uint64_t hi, uint64_t lo) +#if defined(ABSL_ARCH_PPC) + // This has been tested with PPC running in little-endian mode; + // We byte-swap the u64x2 structure from little-endian to big-endian + // because altivec always runs in big-endian mode. + : v{__builtin_bswap64(hi), __builtin_bswap64(lo)} { +#else + : v{lo, hi} { +#endif + } + + constexpr bool operator==(const u64x2& other) const { + return v[0] == other.v[0] && v[1] == other.v[1]; + } + + constexpr bool operator!=(const u64x2& other) const { + return !(*this == other); + } + + uint64_t v[2]; +}; // namespace + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-pragmas" +#endif + +// At this point, all of the platform-specific features have been defined / +// implemented. +// +// REQUIRES: using u64x2 = ... +// REQUIRES: using Vector128 = ... +// REQUIRES: Vector128 Vector128Load(void*) {...} +// REQUIRES: void Vector128Store(Vector128, void*) {...} +// REQUIRES: Vector128 AesRound(Vector128, Vector128) {...} +// REQUIRES: void SwapEndian(uint64_t*) {...} +// +// PROVIDES: absl::random_internal::RandenHwAes::Absorb +// PROVIDES: absl::random_internal::RandenHwAes::Generate + +// RANDen = RANDom generator or beetroots in Swiss German. +// 'Strong' (well-distributed, unpredictable, backtracking-resistant) random +// generator, faster in some benchmarks than std::mt19937_64 and pcg64_c32. +// +// High-level summary: +// 1) Reverie (see "A Robust and Sponge-Like PRNG with Improved Efficiency") is +// a sponge-like random generator that requires a cryptographic permutation. +// It improves upon "Provably Robust Sponge-Based PRNGs and KDFs" by +// achieving backtracking resistance with only one Permute() per buffer. +// +// 2) "Simpira v2: A Family of Efficient Permutations Using the AES Round +// Function" constructs up to 1024-bit permutations using an improved +// Generalized Feistel network with 2-round AES-128 functions. This Feistel +// block shuffle achieves diffusion faster and is less vulnerable to +// sliced-biclique attacks than the Type-2 cyclic shuffle. +// +// 3) "Improving the Generalized Feistel" and "New criterion for diffusion +// property" extends the same kind of improved Feistel block shuffle to 16 +// branches, which enables a 2048-bit permutation. +// +// We combine these three ideas and also change Simpira's subround keys from +// structured/low-entropy counters to digits of Pi. + +// Randen constants. +using absl::random_internal::RandenTraits; +constexpr size_t kStateBytes = RandenTraits::kStateBytes; +constexpr size_t kCapacityBytes = RandenTraits::kCapacityBytes; +constexpr size_t kFeistelBlocks = RandenTraits::kFeistelBlocks; +constexpr size_t kFeistelRounds = RandenTraits::kFeistelRounds; +constexpr size_t kFeistelFunctions = RandenTraits::kFeistelFunctions; + +// Independent keys (272 = 2.1 KiB) for the first AES subround of each function. +constexpr size_t kKeys = kFeistelRounds * kFeistelFunctions; + +// INCLUDE keys. +#include "absl/random/internal/randen-keys.inc" + +static_assert(kKeys == kRoundKeys, "kKeys and kRoundKeys must be equal"); +static_assert(round_keys[kKeys - 1] != u64x2(0, 0), + "Too few round_keys initializers"); + +// Number of uint64_t lanes per 128-bit vector; +constexpr size_t kLanes = 2; + +// Block shuffles applies a shuffle to the entire state between AES rounds. +// Improved odd-even shuffle from "New criterion for diffusion property". +inline ABSL_ATTRIBUTE_ALWAYS_INLINE ABSL_TARGET_CRYPTO void BlockShuffle( + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state) { + static_assert(kFeistelBlocks == 16, "Expecting 16 FeistelBlocks."); + + constexpr size_t shuffle[kFeistelBlocks] = {7, 2, 13, 4, 11, 8, 3, 6, + 15, 0, 9, 10, 1, 14, 5, 12}; + + // The fully unrolled loop without the memcpy improves the speed by about + // 30% over the equivalent loop. + const Vector128 v0 = Vector128Load(state + kLanes * shuffle[0]); + const Vector128 v1 = Vector128Load(state + kLanes * shuffle[1]); + const Vector128 v2 = Vector128Load(state + kLanes * shuffle[2]); + const Vector128 v3 = Vector128Load(state + kLanes * shuffle[3]); + const Vector128 v4 = Vector128Load(state + kLanes * shuffle[4]); + const Vector128 v5 = Vector128Load(state + kLanes * shuffle[5]); + const Vector128 v6 = Vector128Load(state + kLanes * shuffle[6]); + const Vector128 v7 = Vector128Load(state + kLanes * shuffle[7]); + const Vector128 w0 = Vector128Load(state + kLanes * shuffle[8]); + const Vector128 w1 = Vector128Load(state + kLanes * shuffle[9]); + const Vector128 w2 = Vector128Load(state + kLanes * shuffle[10]); + const Vector128 w3 = Vector128Load(state + kLanes * shuffle[11]); + const Vector128 w4 = Vector128Load(state + kLanes * shuffle[12]); + const Vector128 w5 = Vector128Load(state + kLanes * shuffle[13]); + const Vector128 w6 = Vector128Load(state + kLanes * shuffle[14]); + const Vector128 w7 = Vector128Load(state + kLanes * shuffle[15]); + + Vector128Store(v0, state + kLanes * 0); + Vector128Store(v1, state + kLanes * 1); + Vector128Store(v2, state + kLanes * 2); + Vector128Store(v3, state + kLanes * 3); + Vector128Store(v4, state + kLanes * 4); + Vector128Store(v5, state + kLanes * 5); + Vector128Store(v6, state + kLanes * 6); + Vector128Store(v7, state + kLanes * 7); + Vector128Store(w0, state + kLanes * 8); + Vector128Store(w1, state + kLanes * 9); + Vector128Store(w2, state + kLanes * 10); + Vector128Store(w3, state + kLanes * 11); + Vector128Store(w4, state + kLanes * 12); + Vector128Store(w5, state + kLanes * 13); + Vector128Store(w6, state + kLanes * 14); + Vector128Store(w7, state + kLanes * 15); +} + +// Feistel round function using two AES subrounds. Very similar to F() +// from Simpira v2, but with independent subround keys. Uses 17 AES rounds +// per 16 bytes (vs. 10 for AES-CTR). Computing eight round functions in +// parallel hides the 7-cycle AESNI latency on HSW. Note that the Feistel +// XORs are 'free' (included in the second AES instruction). +inline ABSL_ATTRIBUTE_ALWAYS_INLINE ABSL_TARGET_CRYPTO const u64x2* +FeistelRound(uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state, + const u64x2* ABSL_RANDOM_INTERNAL_RESTRICT keys) { + static_assert(kFeistelBlocks == 16, "Expecting 16 FeistelBlocks."); + + // MSVC does a horrible job at unrolling loops. + // So we unroll the loop by hand to improve the performance. + const Vector128 s0 = Vector128Load(state + kLanes * 0); + const Vector128 s1 = Vector128Load(state + kLanes * 1); + const Vector128 s2 = Vector128Load(state + kLanes * 2); + const Vector128 s3 = Vector128Load(state + kLanes * 3); + const Vector128 s4 = Vector128Load(state + kLanes * 4); + const Vector128 s5 = Vector128Load(state + kLanes * 5); + const Vector128 s6 = Vector128Load(state + kLanes * 6); + const Vector128 s7 = Vector128Load(state + kLanes * 7); + const Vector128 s8 = Vector128Load(state + kLanes * 8); + const Vector128 s9 = Vector128Load(state + kLanes * 9); + const Vector128 s10 = Vector128Load(state + kLanes * 10); + const Vector128 s11 = Vector128Load(state + kLanes * 11); + const Vector128 s12 = Vector128Load(state + kLanes * 12); + const Vector128 s13 = Vector128Load(state + kLanes * 13); + const Vector128 s14 = Vector128Load(state + kLanes * 14); + const Vector128 s15 = Vector128Load(state + kLanes * 15); + + // Encode even blocks with keys. + const Vector128 e0 = AesRound(s0, Vector128Load(keys + 0)); + const Vector128 e2 = AesRound(s2, Vector128Load(keys + 1)); + const Vector128 e4 = AesRound(s4, Vector128Load(keys + 2)); + const Vector128 e6 = AesRound(s6, Vector128Load(keys + 3)); + const Vector128 e8 = AesRound(s8, Vector128Load(keys + 4)); + const Vector128 e10 = AesRound(s10, Vector128Load(keys + 5)); + const Vector128 e12 = AesRound(s12, Vector128Load(keys + 6)); + const Vector128 e14 = AesRound(s14, Vector128Load(keys + 7)); + + // Encode odd blocks with even output from above. + const Vector128 o1 = AesRound(e0, s1); + const Vector128 o3 = AesRound(e2, s3); + const Vector128 o5 = AesRound(e4, s5); + const Vector128 o7 = AesRound(e6, s7); + const Vector128 o9 = AesRound(e8, s9); + const Vector128 o11 = AesRound(e10, s11); + const Vector128 o13 = AesRound(e12, s13); + const Vector128 o15 = AesRound(e14, s15); + + // Store odd blocks. (These will be shuffled later). + Vector128Store(o1, state + kLanes * 1); + Vector128Store(o3, state + kLanes * 3); + Vector128Store(o5, state + kLanes * 5); + Vector128Store(o7, state + kLanes * 7); + Vector128Store(o9, state + kLanes * 9); + Vector128Store(o11, state + kLanes * 11); + Vector128Store(o13, state + kLanes * 13); + Vector128Store(o15, state + kLanes * 15); + + return keys + 8; +} + +// Cryptographic permutation based via type-2 Generalized Feistel Network. +// Indistinguishable from ideal by chosen-ciphertext adversaries using less than +// 2^64 queries if the round function is a PRF. This is similar to the b=8 case +// of Simpira v2, but more efficient than its generic construction for b=16. +inline ABSL_ATTRIBUTE_ALWAYS_INLINE ABSL_TARGET_CRYPTO void Permute( + const void* ABSL_RANDOM_INTERNAL_RESTRICT keys, + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state) { + const u64x2* ABSL_RANDOM_INTERNAL_RESTRICT keys128 = + static_cast<const u64x2*>(keys); + + // (Successfully unrolled; the first iteration jumps into the second half) +#ifdef __clang__ +#pragma clang loop unroll_count(2) +#endif + for (size_t round = 0; round < kFeistelRounds; ++round) { + keys128 = FeistelRound(state, keys128); + BlockShuffle(state); + } +} + +} // namespace + +namespace absl { +namespace random_internal { + +bool HasRandenHwAesImplementation() { return true; } + +const void* ABSL_TARGET_CRYPTO ABSL_FUNCTION_ALIGN32 ABSL_ATTRIBUTE_FLATTEN +RandenHwAes::GetKeys() { + // Round keys for one AES per Feistel round and branch. + // The canonical implementation uses first digits of Pi. + return round_keys; +} + +// NOLINTNEXTLINE +void ABSL_TARGET_CRYPTO ABSL_FUNCTION_ALIGN32 ABSL_ATTRIBUTE_FLATTEN +RandenHwAes::Absorb(const void* seed_void, void* state_void) { + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state = + reinterpret_cast<uint64_t*>(state_void); + const uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT seed = + reinterpret_cast<const uint64_t*>(seed_void); + + constexpr size_t kCapacityBlocks = kCapacityBytes / sizeof(Vector128); + constexpr size_t kStateBlocks = kStateBytes / sizeof(Vector128); + + static_assert(kCapacityBlocks * sizeof(Vector128) == kCapacityBytes, + "Not i*V"); + static_assert(kCapacityBlocks == 1, "Unexpected Randen kCapacityBlocks"); + static_assert(kStateBlocks == 16, "Unexpected Randen kStateBlocks"); + + Vector128 b1 = Vector128Load(state + kLanes * 1); + b1 ^= Vector128Load(seed + kLanes * 0); + Vector128Store(b1, state + kLanes * 1); + + Vector128 b2 = Vector128Load(state + kLanes * 2); + b2 ^= Vector128Load(seed + kLanes * 1); + Vector128Store(b2, state + kLanes * 2); + + Vector128 b3 = Vector128Load(state + kLanes * 3); + b3 ^= Vector128Load(seed + kLanes * 2); + Vector128Store(b3, state + kLanes * 3); + + Vector128 b4 = Vector128Load(state + kLanes * 4); + b4 ^= Vector128Load(seed + kLanes * 3); + Vector128Store(b4, state + kLanes * 4); + + Vector128 b5 = Vector128Load(state + kLanes * 5); + b5 ^= Vector128Load(seed + kLanes * 4); + Vector128Store(b5, state + kLanes * 5); + + Vector128 b6 = Vector128Load(state + kLanes * 6); + b6 ^= Vector128Load(seed + kLanes * 5); + Vector128Store(b6, state + kLanes * 6); + + Vector128 b7 = Vector128Load(state + kLanes * 7); + b7 ^= Vector128Load(seed + kLanes * 6); + Vector128Store(b7, state + kLanes * 7); + + Vector128 b8 = Vector128Load(state + kLanes * 8); + b8 ^= Vector128Load(seed + kLanes * 7); + Vector128Store(b8, state + kLanes * 8); + + Vector128 b9 = Vector128Load(state + kLanes * 9); + b9 ^= Vector128Load(seed + kLanes * 8); + Vector128Store(b9, state + kLanes * 9); + + Vector128 b10 = Vector128Load(state + kLanes * 10); + b10 ^= Vector128Load(seed + kLanes * 9); + Vector128Store(b10, state + kLanes * 10); + + Vector128 b11 = Vector128Load(state + kLanes * 11); + b11 ^= Vector128Load(seed + kLanes * 10); + Vector128Store(b11, state + kLanes * 11); + + Vector128 b12 = Vector128Load(state + kLanes * 12); + b12 ^= Vector128Load(seed + kLanes * 11); + Vector128Store(b12, state + kLanes * 12); + + Vector128 b13 = Vector128Load(state + kLanes * 13); + b13 ^= Vector128Load(seed + kLanes * 12); + Vector128Store(b13, state + kLanes * 13); + + Vector128 b14 = Vector128Load(state + kLanes * 14); + b14 ^= Vector128Load(seed + kLanes * 13); + Vector128Store(b14, state + kLanes * 14); + + Vector128 b15 = Vector128Load(state + kLanes * 15); + b15 ^= Vector128Load(seed + kLanes * 14); + Vector128Store(b15, state + kLanes * 15); +} + +// NOLINTNEXTLINE +void ABSL_TARGET_CRYPTO ABSL_FUNCTION_ALIGN32 ABSL_ATTRIBUTE_FLATTEN +RandenHwAes::Generate(const void* keys, void* state_void) { + static_assert(kCapacityBytes == sizeof(Vector128), "Capacity mismatch"); + + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state = + reinterpret_cast<uint64_t*>(state_void); + + const Vector128 prev_inner = Vector128Load(state); + + SwapEndian(state); + + Permute(keys, state); + + SwapEndian(state); + + // Ensure backtracking resistance. + Vector128 inner = Vector128Load(state); + inner ^= prev_inner; + Vector128Store(inner, state); +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +} // namespace random_internal +} // namespace absl + +#endif // (ABSL_RANDEN_HWAES_IMPL) diff --git a/absl/random/internal/randen_hwaes.h b/absl/random/internal/randen_hwaes.h new file mode 100644 index 000000000000..0acec4b700ef --- /dev/null +++ b/absl/random/internal/randen_hwaes.h @@ -0,0 +1,46 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_RANDEN_HWAES_H_ +#define ABSL_RANDOM_INTERNAL_RANDEN_HWAES_H_ + +// HERMETIC NOTE: The randen_hwaes target must not introduce duplicate +// symbols from arbitrary system and other headers, since it may be built +// with different flags from other targets, using different levels of +// optimization, potentially introducing ODR violations. + +namespace absl { +namespace random_internal { + +// RANDen = RANDom generator or beetroots in Swiss German. +// 'Strong' (well-distributed, unpredictable, backtracking-resistant) random +// generator, faster in some benchmarks than std::mt19937_64 and pcg64_c32. +// +// RandenHwAes implements the basic state manipulation methods. +class RandenHwAes { + public: + static void Generate(const void* keys, void* state_void); + static void Absorb(const void* seed_void, void* state_void); + static const void* GetKeys(); +}; + +// HasRandenHwAesImplementation returns true when there is an accelerated +// implementation, and false otherwise. If there is no implementation, +// then attempting to use it will abort the program. +bool HasRandenHwAesImplementation(); + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_RANDEN_FAST_H_ diff --git a/absl/random/internal/randen_hwaes_test.cc b/absl/random/internal/randen_hwaes_test.cc new file mode 100644 index 000000000000..a7cbd46b2dd5 --- /dev/null +++ b/absl/random/internal/randen_hwaes_test.cc @@ -0,0 +1,102 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/randen_hwaes.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/platform.h" +#include "absl/random/internal/randen_detect.h" +#include "absl/random/internal/randen_traits.h" +#include "absl/strings/str_format.h" + +namespace { + +using absl::random_internal::RandenHwAes; +using absl::random_internal::RandenTraits; + +struct randen { + static constexpr size_t kStateSizeT = + RandenTraits::kStateBytes / sizeof(uint64_t); + uint64_t state[kStateSizeT]; + static constexpr size_t kSeedSizeT = + RandenTraits::kSeedBytes / sizeof(uint32_t); + uint32_t seed[kSeedSizeT]; +}; + +TEST(RandenHwAesTest, Default) { + EXPECT_TRUE(absl::random_internal::CPUSupportsRandenHwAes()); + + constexpr uint64_t kGolden[] = { + 0x6c6534090ee6d3ee, 0x044e2b9b9d5333c6, 0xc3c14f134e433977, + 0xdda9f47cd90410ee, 0x887bf3087fd8ca10, 0xf0b780f545c72912, + 0x15dbb1d37696599f, 0x30ec63baff3c6d59, 0xb29f73606f7f20a6, + 0x02808a316f49a54c, 0x3b8feaf9d5c8e50e, 0x9cbf605e3fd9de8a, + 0xc970ae1a78183bbb, 0xd8b2ffd356301ed5, 0xf4b327fe0fc73c37, + 0xcdfd8d76eb8f9a19, 0xc3a506eb91420c9d, 0xd5af05dd3eff9556, + 0x48db1bb78f83c4a1, 0x7023920e0d6bfe8c, 0x58d3575834956d42, + 0xed1ef4c26b87b840, 0x8eef32a23e0b2df3, 0x497cabf3431154fc, + 0x4e24370570029a8b, 0xd88b5749f090e5ea, 0xc651a582a970692f, + 0x78fcec2cbb6342f5, 0x463cb745612f55db, 0x352ee4ad1816afe3, + 0x026ff374c101da7e, 0x811ef0821c3de851, + }; + + alignas(16) randen d; + memset(d.state, 0, sizeof(d.state)); + RandenHwAes::Generate(RandenHwAes::GetKeys(), d.state); + + uint64_t* id = d.state; + for (const auto& elem : kGolden) { + auto a = absl::StrFormat("%#x", elem); + auto b = absl::StrFormat("%#x", *id++); + EXPECT_EQ(a, b); + } +} + +} // namespace + +int main(int argc, char* argv[]) { + testing::InitGoogleTest(&argc, argv); + + ABSL_RAW_LOG(INFO, "ABSL_HAVE_ACCELERATED_AES=%d", ABSL_HAVE_ACCELERATED_AES); + ABSL_RAW_LOG(INFO, "ABSL_RANDOM_INTERNAL_AES_DISPATCH=%d", + ABSL_RANDOM_INTERNAL_AES_DISPATCH); + +#if defined(ABSL_ARCH_X86_64) + ABSL_RAW_LOG(INFO, "ABSL_ARCH_X86_64"); +#elif defined(ABSL_ARCH_X86_32) + ABSL_RAW_LOG(INFO, "ABSL_ARCH_X86_32"); +#elif defined(ABSL_ARCH_AARCH64) + ABSL_RAW_LOG(INFO, "ABSL_ARCH_AARCH64"); +#elif defined(ABSL_ARCH_ARM) + ABSL_RAW_LOG(INFO, "ABSL_ARCH_ARM"); +#elif defined(ABSL_ARCH_PPC) + ABSL_RAW_LOG(INFO, "ABSL_ARCH_PPC"); +#else + ABSL_RAW_LOG(INFO, "ARCH Unknown"); +#endif + + int x = absl::random_internal::HasRandenHwAesImplementation(); + ABSL_RAW_LOG(INFO, "HasRandenHwAesImplementation = %d", x); + + int y = absl::random_internal::CPUSupportsRandenHwAes(); + ABSL_RAW_LOG(INFO, "CPUSupportsRandenHwAes = %d", x); + + if (!x || !y) { + ABSL_RAW_LOG(INFO, "Skipping Randen HWAES tests."); + return 0; + } + return RUN_ALL_TESTS(); +} diff --git a/absl/random/internal/randen_slow.cc b/absl/random/internal/randen_slow.cc new file mode 100644 index 000000000000..b2ecabff878c --- /dev/null +++ b/absl/random/internal/randen_slow.cc @@ -0,0 +1,490 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/randen_slow.h" + +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "absl/random/internal/platform.h" + +namespace { + +// AES portions based on rijndael-alg-fst.c, +// https://fastcrypto.org/front/misc/rijndael-alg-fst.c +// +// Implementation of +// http://www.csrc.nist.gov/publications/fips/fips197/fips-197.pdf +constexpr uint32_t te0[256] = { + 0xc66363a5, 0xf87c7c84, 0xee777799, 0xf67b7b8d, 0xfff2f20d, 0xd66b6bbd, + 0xde6f6fb1, 0x91c5c554, 0x60303050, 0x02010103, 0xce6767a9, 0x562b2b7d, + 0xe7fefe19, 0xb5d7d762, 0x4dababe6, 0xec76769a, 0x8fcaca45, 0x1f82829d, + 0x89c9c940, 0xfa7d7d87, 0xeffafa15, 0xb25959eb, 0x8e4747c9, 0xfbf0f00b, + 0x41adadec, 0xb3d4d467, 0x5fa2a2fd, 0x45afafea, 0x239c9cbf, 0x53a4a4f7, + 0xe4727296, 0x9bc0c05b, 0x75b7b7c2, 0xe1fdfd1c, 0x3d9393ae, 0x4c26266a, + 0x6c36365a, 0x7e3f3f41, 0xf5f7f702, 0x83cccc4f, 0x6834345c, 0x51a5a5f4, + 0xd1e5e534, 0xf9f1f108, 0xe2717193, 0xabd8d873, 0x62313153, 0x2a15153f, + 0x0804040c, 0x95c7c752, 0x46232365, 0x9dc3c35e, 0x30181828, 0x379696a1, + 0x0a05050f, 0x2f9a9ab5, 0x0e070709, 0x24121236, 0x1b80809b, 0xdfe2e23d, + 0xcdebeb26, 0x4e272769, 0x7fb2b2cd, 0xea75759f, 0x1209091b, 0x1d83839e, + 0x582c2c74, 0x341a1a2e, 0x361b1b2d, 0xdc6e6eb2, 0xb45a5aee, 0x5ba0a0fb, + 0xa45252f6, 0x763b3b4d, 0xb7d6d661, 0x7db3b3ce, 0x5229297b, 0xdde3e33e, + 0x5e2f2f71, 0x13848497, 0xa65353f5, 0xb9d1d168, 0x00000000, 0xc1eded2c, + 0x40202060, 0xe3fcfc1f, 0x79b1b1c8, 0xb65b5bed, 0xd46a6abe, 0x8dcbcb46, + 0x67bebed9, 0x7239394b, 0x944a4ade, 0x984c4cd4, 0xb05858e8, 0x85cfcf4a, + 0xbbd0d06b, 0xc5efef2a, 0x4faaaae5, 0xedfbfb16, 0x864343c5, 0x9a4d4dd7, + 0x66333355, 0x11858594, 0x8a4545cf, 0xe9f9f910, 0x04020206, 0xfe7f7f81, + 0xa05050f0, 0x783c3c44, 0x259f9fba, 0x4ba8a8e3, 0xa25151f3, 0x5da3a3fe, + 0x804040c0, 0x058f8f8a, 0x3f9292ad, 0x219d9dbc, 0x70383848, 0xf1f5f504, + 0x63bcbcdf, 0x77b6b6c1, 0xafdada75, 0x42212163, 0x20101030, 0xe5ffff1a, + 0xfdf3f30e, 0xbfd2d26d, 0x81cdcd4c, 0x180c0c14, 0x26131335, 0xc3ecec2f, + 0xbe5f5fe1, 0x359797a2, 0x884444cc, 0x2e171739, 0x93c4c457, 0x55a7a7f2, + 0xfc7e7e82, 0x7a3d3d47, 0xc86464ac, 0xba5d5de7, 0x3219192b, 0xe6737395, + 0xc06060a0, 0x19818198, 0x9e4f4fd1, 0xa3dcdc7f, 0x44222266, 0x542a2a7e, + 0x3b9090ab, 0x0b888883, 0x8c4646ca, 0xc7eeee29, 0x6bb8b8d3, 0x2814143c, + 0xa7dede79, 0xbc5e5ee2, 0x160b0b1d, 0xaddbdb76, 0xdbe0e03b, 0x64323256, + 0x743a3a4e, 0x140a0a1e, 0x924949db, 0x0c06060a, 0x4824246c, 0xb85c5ce4, + 0x9fc2c25d, 0xbdd3d36e, 0x43acacef, 0xc46262a6, 0x399191a8, 0x319595a4, + 0xd3e4e437, 0xf279798b, 0xd5e7e732, 0x8bc8c843, 0x6e373759, 0xda6d6db7, + 0x018d8d8c, 0xb1d5d564, 0x9c4e4ed2, 0x49a9a9e0, 0xd86c6cb4, 0xac5656fa, + 0xf3f4f407, 0xcfeaea25, 0xca6565af, 0xf47a7a8e, 0x47aeaee9, 0x10080818, + 0x6fbabad5, 0xf0787888, 0x4a25256f, 0x5c2e2e72, 0x381c1c24, 0x57a6a6f1, + 0x73b4b4c7, 0x97c6c651, 0xcbe8e823, 0xa1dddd7c, 0xe874749c, 0x3e1f1f21, + 0x964b4bdd, 0x61bdbddc, 0x0d8b8b86, 0x0f8a8a85, 0xe0707090, 0x7c3e3e42, + 0x71b5b5c4, 0xcc6666aa, 0x904848d8, 0x06030305, 0xf7f6f601, 0x1c0e0e12, + 0xc26161a3, 0x6a35355f, 0xae5757f9, 0x69b9b9d0, 0x17868691, 0x99c1c158, + 0x3a1d1d27, 0x279e9eb9, 0xd9e1e138, 0xebf8f813, 0x2b9898b3, 0x22111133, + 0xd26969bb, 0xa9d9d970, 0x078e8e89, 0x339494a7, 0x2d9b9bb6, 0x3c1e1e22, + 0x15878792, 0xc9e9e920, 0x87cece49, 0xaa5555ff, 0x50282878, 0xa5dfdf7a, + 0x038c8c8f, 0x59a1a1f8, 0x09898980, 0x1a0d0d17, 0x65bfbfda, 0xd7e6e631, + 0x844242c6, 0xd06868b8, 0x824141c3, 0x299999b0, 0x5a2d2d77, 0x1e0f0f11, + 0x7bb0b0cb, 0xa85454fc, 0x6dbbbbd6, 0x2c16163a, +}; + +constexpr uint32_t te1[256] = { + 0xa5c66363, 0x84f87c7c, 0x99ee7777, 0x8df67b7b, 0x0dfff2f2, 0xbdd66b6b, + 0xb1de6f6f, 0x5491c5c5, 0x50603030, 0x03020101, 0xa9ce6767, 0x7d562b2b, + 0x19e7fefe, 0x62b5d7d7, 0xe64dabab, 0x9aec7676, 0x458fcaca, 0x9d1f8282, + 0x4089c9c9, 0x87fa7d7d, 0x15effafa, 0xebb25959, 0xc98e4747, 0x0bfbf0f0, + 0xec41adad, 0x67b3d4d4, 0xfd5fa2a2, 0xea45afaf, 0xbf239c9c, 0xf753a4a4, + 0x96e47272, 0x5b9bc0c0, 0xc275b7b7, 0x1ce1fdfd, 0xae3d9393, 0x6a4c2626, + 0x5a6c3636, 0x417e3f3f, 0x02f5f7f7, 0x4f83cccc, 0x5c683434, 0xf451a5a5, + 0x34d1e5e5, 0x08f9f1f1, 0x93e27171, 0x73abd8d8, 0x53623131, 0x3f2a1515, + 0x0c080404, 0x5295c7c7, 0x65462323, 0x5e9dc3c3, 0x28301818, 0xa1379696, + 0x0f0a0505, 0xb52f9a9a, 0x090e0707, 0x36241212, 0x9b1b8080, 0x3ddfe2e2, + 0x26cdebeb, 0x694e2727, 0xcd7fb2b2, 0x9fea7575, 0x1b120909, 0x9e1d8383, + 0x74582c2c, 0x2e341a1a, 0x2d361b1b, 0xb2dc6e6e, 0xeeb45a5a, 0xfb5ba0a0, + 0xf6a45252, 0x4d763b3b, 0x61b7d6d6, 0xce7db3b3, 0x7b522929, 0x3edde3e3, + 0x715e2f2f, 0x97138484, 0xf5a65353, 0x68b9d1d1, 0x00000000, 0x2cc1eded, + 0x60402020, 0x1fe3fcfc, 0xc879b1b1, 0xedb65b5b, 0xbed46a6a, 0x468dcbcb, + 0xd967bebe, 0x4b723939, 0xde944a4a, 0xd4984c4c, 0xe8b05858, 0x4a85cfcf, + 0x6bbbd0d0, 0x2ac5efef, 0xe54faaaa, 0x16edfbfb, 0xc5864343, 0xd79a4d4d, + 0x55663333, 0x94118585, 0xcf8a4545, 0x10e9f9f9, 0x06040202, 0x81fe7f7f, + 0xf0a05050, 0x44783c3c, 0xba259f9f, 0xe34ba8a8, 0xf3a25151, 0xfe5da3a3, + 0xc0804040, 0x8a058f8f, 0xad3f9292, 0xbc219d9d, 0x48703838, 0x04f1f5f5, + 0xdf63bcbc, 0xc177b6b6, 0x75afdada, 0x63422121, 0x30201010, 0x1ae5ffff, + 0x0efdf3f3, 0x6dbfd2d2, 0x4c81cdcd, 0x14180c0c, 0x35261313, 0x2fc3ecec, + 0xe1be5f5f, 0xa2359797, 0xcc884444, 0x392e1717, 0x5793c4c4, 0xf255a7a7, + 0x82fc7e7e, 0x477a3d3d, 0xacc86464, 0xe7ba5d5d, 0x2b321919, 0x95e67373, + 0xa0c06060, 0x98198181, 0xd19e4f4f, 0x7fa3dcdc, 0x66442222, 0x7e542a2a, + 0xab3b9090, 0x830b8888, 0xca8c4646, 0x29c7eeee, 0xd36bb8b8, 0x3c281414, + 0x79a7dede, 0xe2bc5e5e, 0x1d160b0b, 0x76addbdb, 0x3bdbe0e0, 0x56643232, + 0x4e743a3a, 0x1e140a0a, 0xdb924949, 0x0a0c0606, 0x6c482424, 0xe4b85c5c, + 0x5d9fc2c2, 0x6ebdd3d3, 0xef43acac, 0xa6c46262, 0xa8399191, 0xa4319595, + 0x37d3e4e4, 0x8bf27979, 0x32d5e7e7, 0x438bc8c8, 0x596e3737, 0xb7da6d6d, + 0x8c018d8d, 0x64b1d5d5, 0xd29c4e4e, 0xe049a9a9, 0xb4d86c6c, 0xfaac5656, + 0x07f3f4f4, 0x25cfeaea, 0xafca6565, 0x8ef47a7a, 0xe947aeae, 0x18100808, + 0xd56fbaba, 0x88f07878, 0x6f4a2525, 0x725c2e2e, 0x24381c1c, 0xf157a6a6, + 0xc773b4b4, 0x5197c6c6, 0x23cbe8e8, 0x7ca1dddd, 0x9ce87474, 0x213e1f1f, + 0xdd964b4b, 0xdc61bdbd, 0x860d8b8b, 0x850f8a8a, 0x90e07070, 0x427c3e3e, + 0xc471b5b5, 0xaacc6666, 0xd8904848, 0x05060303, 0x01f7f6f6, 0x121c0e0e, + 0xa3c26161, 0x5f6a3535, 0xf9ae5757, 0xd069b9b9, 0x91178686, 0x5899c1c1, + 0x273a1d1d, 0xb9279e9e, 0x38d9e1e1, 0x13ebf8f8, 0xb32b9898, 0x33221111, + 0xbbd26969, 0x70a9d9d9, 0x89078e8e, 0xa7339494, 0xb62d9b9b, 0x223c1e1e, + 0x92158787, 0x20c9e9e9, 0x4987cece, 0xffaa5555, 0x78502828, 0x7aa5dfdf, + 0x8f038c8c, 0xf859a1a1, 0x80098989, 0x171a0d0d, 0xda65bfbf, 0x31d7e6e6, + 0xc6844242, 0xb8d06868, 0xc3824141, 0xb0299999, 0x775a2d2d, 0x111e0f0f, + 0xcb7bb0b0, 0xfca85454, 0xd66dbbbb, 0x3a2c1616, +}; + +constexpr uint32_t te2[256] = { + 0x63a5c663, 0x7c84f87c, 0x7799ee77, 0x7b8df67b, 0xf20dfff2, 0x6bbdd66b, + 0x6fb1de6f, 0xc55491c5, 0x30506030, 0x01030201, 0x67a9ce67, 0x2b7d562b, + 0xfe19e7fe, 0xd762b5d7, 0xabe64dab, 0x769aec76, 0xca458fca, 0x829d1f82, + 0xc94089c9, 0x7d87fa7d, 0xfa15effa, 0x59ebb259, 0x47c98e47, 0xf00bfbf0, + 0xadec41ad, 0xd467b3d4, 0xa2fd5fa2, 0xafea45af, 0x9cbf239c, 0xa4f753a4, + 0x7296e472, 0xc05b9bc0, 0xb7c275b7, 0xfd1ce1fd, 0x93ae3d93, 0x266a4c26, + 0x365a6c36, 0x3f417e3f, 0xf702f5f7, 0xcc4f83cc, 0x345c6834, 0xa5f451a5, + 0xe534d1e5, 0xf108f9f1, 0x7193e271, 0xd873abd8, 0x31536231, 0x153f2a15, + 0x040c0804, 0xc75295c7, 0x23654623, 0xc35e9dc3, 0x18283018, 0x96a13796, + 0x050f0a05, 0x9ab52f9a, 0x07090e07, 0x12362412, 0x809b1b80, 0xe23ddfe2, + 0xeb26cdeb, 0x27694e27, 0xb2cd7fb2, 0x759fea75, 0x091b1209, 0x839e1d83, + 0x2c74582c, 0x1a2e341a, 0x1b2d361b, 0x6eb2dc6e, 0x5aeeb45a, 0xa0fb5ba0, + 0x52f6a452, 0x3b4d763b, 0xd661b7d6, 0xb3ce7db3, 0x297b5229, 0xe33edde3, + 0x2f715e2f, 0x84971384, 0x53f5a653, 0xd168b9d1, 0x00000000, 0xed2cc1ed, + 0x20604020, 0xfc1fe3fc, 0xb1c879b1, 0x5bedb65b, 0x6abed46a, 0xcb468dcb, + 0xbed967be, 0x394b7239, 0x4ade944a, 0x4cd4984c, 0x58e8b058, 0xcf4a85cf, + 0xd06bbbd0, 0xef2ac5ef, 0xaae54faa, 0xfb16edfb, 0x43c58643, 0x4dd79a4d, + 0x33556633, 0x85941185, 0x45cf8a45, 0xf910e9f9, 0x02060402, 0x7f81fe7f, + 0x50f0a050, 0x3c44783c, 0x9fba259f, 0xa8e34ba8, 0x51f3a251, 0xa3fe5da3, + 0x40c08040, 0x8f8a058f, 0x92ad3f92, 0x9dbc219d, 0x38487038, 0xf504f1f5, + 0xbcdf63bc, 0xb6c177b6, 0xda75afda, 0x21634221, 0x10302010, 0xff1ae5ff, + 0xf30efdf3, 0xd26dbfd2, 0xcd4c81cd, 0x0c14180c, 0x13352613, 0xec2fc3ec, + 0x5fe1be5f, 0x97a23597, 0x44cc8844, 0x17392e17, 0xc45793c4, 0xa7f255a7, + 0x7e82fc7e, 0x3d477a3d, 0x64acc864, 0x5de7ba5d, 0x192b3219, 0x7395e673, + 0x60a0c060, 0x81981981, 0x4fd19e4f, 0xdc7fa3dc, 0x22664422, 0x2a7e542a, + 0x90ab3b90, 0x88830b88, 0x46ca8c46, 0xee29c7ee, 0xb8d36bb8, 0x143c2814, + 0xde79a7de, 0x5ee2bc5e, 0x0b1d160b, 0xdb76addb, 0xe03bdbe0, 0x32566432, + 0x3a4e743a, 0x0a1e140a, 0x49db9249, 0x060a0c06, 0x246c4824, 0x5ce4b85c, + 0xc25d9fc2, 0xd36ebdd3, 0xacef43ac, 0x62a6c462, 0x91a83991, 0x95a43195, + 0xe437d3e4, 0x798bf279, 0xe732d5e7, 0xc8438bc8, 0x37596e37, 0x6db7da6d, + 0x8d8c018d, 0xd564b1d5, 0x4ed29c4e, 0xa9e049a9, 0x6cb4d86c, 0x56faac56, + 0xf407f3f4, 0xea25cfea, 0x65afca65, 0x7a8ef47a, 0xaee947ae, 0x08181008, + 0xbad56fba, 0x7888f078, 0x256f4a25, 0x2e725c2e, 0x1c24381c, 0xa6f157a6, + 0xb4c773b4, 0xc65197c6, 0xe823cbe8, 0xdd7ca1dd, 0x749ce874, 0x1f213e1f, + 0x4bdd964b, 0xbddc61bd, 0x8b860d8b, 0x8a850f8a, 0x7090e070, 0x3e427c3e, + 0xb5c471b5, 0x66aacc66, 0x48d89048, 0x03050603, 0xf601f7f6, 0x0e121c0e, + 0x61a3c261, 0x355f6a35, 0x57f9ae57, 0xb9d069b9, 0x86911786, 0xc15899c1, + 0x1d273a1d, 0x9eb9279e, 0xe138d9e1, 0xf813ebf8, 0x98b32b98, 0x11332211, + 0x69bbd269, 0xd970a9d9, 0x8e89078e, 0x94a73394, 0x9bb62d9b, 0x1e223c1e, + 0x87921587, 0xe920c9e9, 0xce4987ce, 0x55ffaa55, 0x28785028, 0xdf7aa5df, + 0x8c8f038c, 0xa1f859a1, 0x89800989, 0x0d171a0d, 0xbfda65bf, 0xe631d7e6, + 0x42c68442, 0x68b8d068, 0x41c38241, 0x99b02999, 0x2d775a2d, 0x0f111e0f, + 0xb0cb7bb0, 0x54fca854, 0xbbd66dbb, 0x163a2c16, +}; + +constexpr uint32_t te3[256] = { + 0x6363a5c6, 0x7c7c84f8, 0x777799ee, 0x7b7b8df6, 0xf2f20dff, 0x6b6bbdd6, + 0x6f6fb1de, 0xc5c55491, 0x30305060, 0x01010302, 0x6767a9ce, 0x2b2b7d56, + 0xfefe19e7, 0xd7d762b5, 0xababe64d, 0x76769aec, 0xcaca458f, 0x82829d1f, + 0xc9c94089, 0x7d7d87fa, 0xfafa15ef, 0x5959ebb2, 0x4747c98e, 0xf0f00bfb, + 0xadadec41, 0xd4d467b3, 0xa2a2fd5f, 0xafafea45, 0x9c9cbf23, 0xa4a4f753, + 0x727296e4, 0xc0c05b9b, 0xb7b7c275, 0xfdfd1ce1, 0x9393ae3d, 0x26266a4c, + 0x36365a6c, 0x3f3f417e, 0xf7f702f5, 0xcccc4f83, 0x34345c68, 0xa5a5f451, + 0xe5e534d1, 0xf1f108f9, 0x717193e2, 0xd8d873ab, 0x31315362, 0x15153f2a, + 0x04040c08, 0xc7c75295, 0x23236546, 0xc3c35e9d, 0x18182830, 0x9696a137, + 0x05050f0a, 0x9a9ab52f, 0x0707090e, 0x12123624, 0x80809b1b, 0xe2e23ddf, + 0xebeb26cd, 0x2727694e, 0xb2b2cd7f, 0x75759fea, 0x09091b12, 0x83839e1d, + 0x2c2c7458, 0x1a1a2e34, 0x1b1b2d36, 0x6e6eb2dc, 0x5a5aeeb4, 0xa0a0fb5b, + 0x5252f6a4, 0x3b3b4d76, 0xd6d661b7, 0xb3b3ce7d, 0x29297b52, 0xe3e33edd, + 0x2f2f715e, 0x84849713, 0x5353f5a6, 0xd1d168b9, 0x00000000, 0xeded2cc1, + 0x20206040, 0xfcfc1fe3, 0xb1b1c879, 0x5b5bedb6, 0x6a6abed4, 0xcbcb468d, + 0xbebed967, 0x39394b72, 0x4a4ade94, 0x4c4cd498, 0x5858e8b0, 0xcfcf4a85, + 0xd0d06bbb, 0xefef2ac5, 0xaaaae54f, 0xfbfb16ed, 0x4343c586, 0x4d4dd79a, + 0x33335566, 0x85859411, 0x4545cf8a, 0xf9f910e9, 0x02020604, 0x7f7f81fe, + 0x5050f0a0, 0x3c3c4478, 0x9f9fba25, 0xa8a8e34b, 0x5151f3a2, 0xa3a3fe5d, + 0x4040c080, 0x8f8f8a05, 0x9292ad3f, 0x9d9dbc21, 0x38384870, 0xf5f504f1, + 0xbcbcdf63, 0xb6b6c177, 0xdada75af, 0x21216342, 0x10103020, 0xffff1ae5, + 0xf3f30efd, 0xd2d26dbf, 0xcdcd4c81, 0x0c0c1418, 0x13133526, 0xecec2fc3, + 0x5f5fe1be, 0x9797a235, 0x4444cc88, 0x1717392e, 0xc4c45793, 0xa7a7f255, + 0x7e7e82fc, 0x3d3d477a, 0x6464acc8, 0x5d5de7ba, 0x19192b32, 0x737395e6, + 0x6060a0c0, 0x81819819, 0x4f4fd19e, 0xdcdc7fa3, 0x22226644, 0x2a2a7e54, + 0x9090ab3b, 0x8888830b, 0x4646ca8c, 0xeeee29c7, 0xb8b8d36b, 0x14143c28, + 0xdede79a7, 0x5e5ee2bc, 0x0b0b1d16, 0xdbdb76ad, 0xe0e03bdb, 0x32325664, + 0x3a3a4e74, 0x0a0a1e14, 0x4949db92, 0x06060a0c, 0x24246c48, 0x5c5ce4b8, + 0xc2c25d9f, 0xd3d36ebd, 0xacacef43, 0x6262a6c4, 0x9191a839, 0x9595a431, + 0xe4e437d3, 0x79798bf2, 0xe7e732d5, 0xc8c8438b, 0x3737596e, 0x6d6db7da, + 0x8d8d8c01, 0xd5d564b1, 0x4e4ed29c, 0xa9a9e049, 0x6c6cb4d8, 0x5656faac, + 0xf4f407f3, 0xeaea25cf, 0x6565afca, 0x7a7a8ef4, 0xaeaee947, 0x08081810, + 0xbabad56f, 0x787888f0, 0x25256f4a, 0x2e2e725c, 0x1c1c2438, 0xa6a6f157, + 0xb4b4c773, 0xc6c65197, 0xe8e823cb, 0xdddd7ca1, 0x74749ce8, 0x1f1f213e, + 0x4b4bdd96, 0xbdbddc61, 0x8b8b860d, 0x8a8a850f, 0x707090e0, 0x3e3e427c, + 0xb5b5c471, 0x6666aacc, 0x4848d890, 0x03030506, 0xf6f601f7, 0x0e0e121c, + 0x6161a3c2, 0x35355f6a, 0x5757f9ae, 0xb9b9d069, 0x86869117, 0xc1c15899, + 0x1d1d273a, 0x9e9eb927, 0xe1e138d9, 0xf8f813eb, 0x9898b32b, 0x11113322, + 0x6969bbd2, 0xd9d970a9, 0x8e8e8907, 0x9494a733, 0x9b9bb62d, 0x1e1e223c, + 0x87879215, 0xe9e920c9, 0xcece4987, 0x5555ffaa, 0x28287850, 0xdfdf7aa5, + 0x8c8c8f03, 0xa1a1f859, 0x89898009, 0x0d0d171a, 0xbfbfda65, 0xe6e631d7, + 0x4242c684, 0x6868b8d0, 0x4141c382, 0x9999b029, 0x2d2d775a, 0x0f0f111e, + 0xb0b0cb7b, 0x5454fca8, 0xbbbbd66d, 0x16163a2c, +}; + +struct alignas(16) u64x2 { + constexpr u64x2() : v{0, 0} {}; + constexpr u64x2(uint64_t hi, uint64_t lo) : v{lo, hi} {} + + uint64_t v[2]; +}; + +// Software implementation of the Vector128 class, using uint32_t +// as an underlying vector register. +// +struct Vector128 { + inline ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128& operator^=( + const Vector128& other) { + s[0] ^= other.s[0]; + s[1] ^= other.s[1]; + s[2] ^= other.s[2]; + s[3] ^= other.s[3]; + return *this; + } + + uint32_t s[4]; +}; + +inline ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +Vector128Load(const void* ABSL_RANDOM_INTERNAL_RESTRICT from) { + Vector128 result; + const uint8_t* ABSL_RANDOM_INTERNAL_RESTRICT src = + reinterpret_cast<const uint8_t*>(from); + + result.s[0] = static_cast<uint32_t>(src[0]) << 24 | + static_cast<uint32_t>(src[1]) << 16 | + static_cast<uint32_t>(src[2]) << 8 | + static_cast<uint32_t>(src[3]); + result.s[1] = static_cast<uint32_t>(src[4]) << 24 | + static_cast<uint32_t>(src[5]) << 16 | + static_cast<uint32_t>(src[6]) << 8 | + static_cast<uint32_t>(src[7]); + result.s[2] = static_cast<uint32_t>(src[8]) << 24 | + static_cast<uint32_t>(src[9]) << 16 | + static_cast<uint32_t>(src[10]) << 8 | + static_cast<uint32_t>(src[11]); + result.s[3] = static_cast<uint32_t>(src[12]) << 24 | + static_cast<uint32_t>(src[13]) << 16 | + static_cast<uint32_t>(src[14]) << 8 | + static_cast<uint32_t>(src[15]); + return result; +} + +inline ABSL_ATTRIBUTE_ALWAYS_INLINE void Vector128Store( + const Vector128& v, void* ABSL_RANDOM_INTERNAL_RESTRICT to) { + uint8_t* dst = reinterpret_cast<uint8_t*>(to); + dst[0] = static_cast<uint8_t>(v.s[0] >> 24); + dst[1] = static_cast<uint8_t>(v.s[0] >> 16); + dst[2] = static_cast<uint8_t>(v.s[0] >> 8); + dst[3] = static_cast<uint8_t>(v.s[0]); + dst[4] = static_cast<uint8_t>(v.s[1] >> 24); + dst[5] = static_cast<uint8_t>(v.s[1] >> 16); + dst[6] = static_cast<uint8_t>(v.s[1] >> 8); + dst[7] = static_cast<uint8_t>(v.s[1]); + dst[8] = static_cast<uint8_t>(v.s[2] >> 24); + dst[9] = static_cast<uint8_t>(v.s[2] >> 16); + dst[10] = static_cast<uint8_t>(v.s[2] >> 8); + dst[11] = static_cast<uint8_t>(v.s[2]); + dst[12] = static_cast<uint8_t>(v.s[3] >> 24); + dst[13] = static_cast<uint8_t>(v.s[3] >> 16); + dst[14] = static_cast<uint8_t>(v.s[3] >> 8); + dst[15] = static_cast<uint8_t>(v.s[3]); +} + +// One round of AES. "round_key" is a public constant for breaking the +// symmetry of AES (ensures previously equal columns differ afterwards). +inline ABSL_ATTRIBUTE_ALWAYS_INLINE Vector128 +AesRound(const Vector128& state, const Vector128& round_key) { + // clang-format off + Vector128 result; + result.s[0] = round_key.s[0] ^ + te0[uint8_t(state.s[0] >> 24)] ^ + te1[uint8_t(state.s[1] >> 16)] ^ + te2[uint8_t(state.s[2] >> 8)] ^ + te3[uint8_t(state.s[3])]; + result.s[1] = round_key.s[1] ^ + te0[uint8_t(state.s[1] >> 24)] ^ + te1[uint8_t(state.s[2] >> 16)] ^ + te2[uint8_t(state.s[3] >> 8)] ^ + te3[uint8_t(state.s[0])]; + result.s[2] = round_key.s[2] ^ + te0[uint8_t(state.s[2] >> 24)] ^ + te1[uint8_t(state.s[3] >> 16)] ^ + te2[uint8_t(state.s[0] >> 8)] ^ + te3[uint8_t(state.s[1])]; + result.s[3] = round_key.s[3] ^ + te0[uint8_t(state.s[3] >> 24)] ^ + te1[uint8_t(state.s[0] >> 16)] ^ + te2[uint8_t(state.s[1] >> 8)] ^ + te3[uint8_t(state.s[2])]; + return result; + // clang-format on +} + +// RANDen = RANDom generator or beetroots in Swiss German. +// 'Strong' (well-distributed, unpredictable, backtracking-resistant) random +// generator, faster in some benchmarks than std::mt19937_64 and pcg64_c32. +// +// High-level summary: +// 1) Reverie (see "A Robust and Sponge-Like PRNG with Improved Efficiency") is +// a sponge-like random generator that requires a cryptographic permutation. +// It improves upon "Provably Robust Sponge-Based PRNGs and KDFs" by +// achieving backtracking resistance with only one Permute() per buffer. +// +// 2) "Simpira v2: A Family of Efficient Permutations Using the AES Round +// Function" constructs up to 1024-bit permutations using an improved +// Generalized Feistel network with 2-round AES-128 functions. This Feistel +// block shuffle achieves diffusion faster and is less vulnerable to +// sliced-biclique attacks than the Type-2 cyclic shuffle. +// +// 3) "Improving the Generalized Feistel" and "New criterion for diffusion +// property" extends the same kind of improved Feistel block shuffle to 16 +// branches, which enables a 2048-bit permutation. +// +// Combine these three ideas and also change Simpira's subround keys from +// structured/low-entropy counters to digits of Pi. + +// Randen constants. +constexpr size_t kFeistelBlocks = 16; +constexpr size_t kFeistelFunctions = kFeistelBlocks / 2; // = 8 +constexpr size_t kFeistelRounds = 16 + 1; // > 4 * log2(kFeistelBlocks) +constexpr size_t kKeys = kFeistelRounds * kFeistelFunctions; + +// INCLUDE keys. +#include "absl/random/internal/randen-keys.inc" + +static_assert(kKeys == kRoundKeys, "kKeys and kRoundKeys must be equal"); + +// 2 uint64_t lanes per Vector128 +static constexpr size_t kLanes = 2; + +// The improved Feistel block shuffle function for 16 blocks. +inline ABSL_ATTRIBUTE_ALWAYS_INLINE void BlockShuffle( + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state_u64) { + static_assert(kFeistelBlocks == 16, + "Feistel block shuffle only works for 16 blocks."); + + constexpr size_t shuffle[kFeistelBlocks] = {7, 2, 13, 4, 11, 8, 3, 6, + 15, 0, 9, 10, 1, 14, 5, 12}; + + u64x2* ABSL_RANDOM_INTERNAL_RESTRICT state = + reinterpret_cast<u64x2*>(state_u64); + + // The fully unrolled loop without the memcpy improves the speed by about + // 30% over the equivalent (leaving code here as a comment): + if (false) { + u64x2 source[kFeistelBlocks]; + std::memcpy(source, state, sizeof(source)); + for (size_t i = 0; i < kFeistelBlocks; i++) { + const u64x2 v0 = source[shuffle[i]]; + state[i] = v0; + } + } + + const u64x2 v0 = state[shuffle[0]]; + const u64x2 v1 = state[shuffle[1]]; + const u64x2 v2 = state[shuffle[2]]; + const u64x2 v3 = state[shuffle[3]]; + const u64x2 v4 = state[shuffle[4]]; + const u64x2 v5 = state[shuffle[5]]; + const u64x2 v6 = state[shuffle[6]]; + const u64x2 v7 = state[shuffle[7]]; + const u64x2 w0 = state[shuffle[8]]; + const u64x2 w1 = state[shuffle[9]]; + const u64x2 w2 = state[shuffle[10]]; + const u64x2 w3 = state[shuffle[11]]; + const u64x2 w4 = state[shuffle[12]]; + const u64x2 w5 = state[shuffle[13]]; + const u64x2 w6 = state[shuffle[14]]; + const u64x2 w7 = state[shuffle[15]]; + state[0] = v0; + state[1] = v1; + state[2] = v2; + state[3] = v3; + state[4] = v4; + state[5] = v5; + state[6] = v6; + state[7] = v7; + state[8] = w0; + state[9] = w1; + state[10] = w2; + state[11] = w3; + state[12] = w4; + state[13] = w5; + state[14] = w6; + state[15] = w7; +} + +// Feistel round function using two AES subrounds. Very similar to F() +// from Simpira v2, but with independent subround keys. Uses 17 AES rounds +// per 16 bytes (vs. 10 for AES-CTR). Computing eight round functions in +// parallel hides the 7-cycle AESNI latency on HSW. Note that the Feistel +// XORs are 'free' (included in the second AES instruction). +inline ABSL_ATTRIBUTE_ALWAYS_INLINE const u64x2* FeistelRound( + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state, + const u64x2* ABSL_RANDOM_INTERNAL_RESTRICT keys) { + for (size_t branch = 0; branch < kFeistelBlocks; branch += 4) { + const Vector128 s0 = Vector128Load(state + kLanes * branch); + const Vector128 s1 = Vector128Load(state + kLanes * (branch + 1)); + const Vector128 f0 = AesRound(s0, Vector128Load(keys)); + keys++; + const Vector128 o1 = AesRound(f0, s1); + Vector128Store(o1, state + kLanes * (branch + 1)); + + // Manually unroll this loop once. about 10% better than not unrolled. + const Vector128 s2 = Vector128Load(state + kLanes * (branch + 2)); + const Vector128 s3 = Vector128Load(state + kLanes * (branch + 3)); + const Vector128 f2 = AesRound(s2, Vector128Load(keys)); + keys++; + const Vector128 o3 = AesRound(f2, s3); + Vector128Store(o3, state + kLanes * (branch + 3)); + } + return keys; +} + +// Cryptographic permutation based via type-2 Generalized Feistel Network. +// Indistinguishable from ideal by chosen-ciphertext adversaries using less than +// 2^64 queries if the round function is a PRF. This is similar to the b=8 case +// of Simpira v2, but more efficient than its generic construction for b=16. +inline ABSL_ATTRIBUTE_ALWAYS_INLINE void Permute( + const void* keys, uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state) { + const u64x2* ABSL_RANDOM_INTERNAL_RESTRICT keys128 = + static_cast<const u64x2*>(keys); + for (size_t round = 0; round < kFeistelRounds; ++round) { + keys128 = FeistelRound(state, keys128); + BlockShuffle(state); + } +} + +} // namespace + +namespace absl { +namespace random_internal { + +const void* RandenSlow::GetKeys() { + // Round keys for one AES per Feistel round and branch. + // The canonical implementation uses first digits of Pi. + return round_keys; +} + +void RandenSlow::Absorb(const void* seed_void, void* state_void) { + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state = + reinterpret_cast<uint64_t*>(state_void); + const uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT seed = + reinterpret_cast<const uint64_t*>(seed_void); + + constexpr size_t kCapacityBlocks = kCapacityBytes / sizeof(uint64_t); + static_assert(kCapacityBlocks * sizeof(uint64_t) == kCapacityBytes, + "Not i*V"); + for (size_t i = kCapacityBlocks; i < kStateBytes / sizeof(uint64_t); ++i) { + state[i] ^= seed[i - kCapacityBlocks]; + } +} + +void RandenSlow::Generate(const void* keys, void* state_void) { + static_assert(kCapacityBytes == sizeof(Vector128), "Capacity mismatch"); + + uint64_t* ABSL_RANDOM_INTERNAL_RESTRICT state = + reinterpret_cast<uint64_t*>(state_void); + + const Vector128 prev_inner = Vector128Load(state); + + Permute(keys, state); + + // Ensure backtracking resistance. + Vector128 inner = Vector128Load(state); + inner ^= prev_inner; + Vector128Store(inner, state); +} + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/internal/randen_slow.h b/absl/random/internal/randen_slow.h new file mode 100644 index 000000000000..305861308003 --- /dev/null +++ b/absl/random/internal/randen_slow.h @@ -0,0 +1,43 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_RANDEN_SLOW_H_ +#define ABSL_RANDOM_INTERNAL_RANDEN_SLOW_H_ + +#include <cstddef> + +namespace absl { +namespace random_internal { + +// RANDen = RANDom generator or beetroots in Swiss German. +// RandenSlow implements the basic state manipulation methods for +// architectures lacking AES hardware acceleration intrinsics. +class RandenSlow { + public: + // Size of the entire sponge / state for the randen PRNG. + static constexpr size_t kStateBytes = 256; // 2048-bit + + // Size of the 'inner' (inaccessible) part of the sponge. Larger values would + // require more frequent calls to RandenGenerate. + static constexpr size_t kCapacityBytes = 16; // 128-bit + + static void Generate(const void* keys, void* state_void); + static void Absorb(const void* seed_void, void* state_void); + static const void* GetKeys(); +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_RANDEN_SLOW_H_ diff --git a/absl/random/internal/randen_slow_test.cc b/absl/random/internal/randen_slow_test.cc new file mode 100644 index 000000000000..c07155d8dd0c --- /dev/null +++ b/absl/random/internal/randen_slow_test.cc @@ -0,0 +1,61 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/randen_slow.h" + +#include <cstring> + +#include "gtest/gtest.h" + +namespace { + +using absl::random_internal::RandenSlow; + +// Local state parameters. +constexpr size_t kSeedBytes = + RandenSlow::kStateBytes - RandenSlow::kCapacityBytes; +constexpr size_t kStateSizeT = RandenSlow::kStateBytes / sizeof(uint64_t); +constexpr size_t kSeedSizeT = kSeedBytes / sizeof(uint32_t); + +struct randen { + uint64_t state[kStateSizeT]; + uint32_t seed[kSeedSizeT]; +}; + +TEST(RandenSlowTest, Default) { + constexpr uint64_t kGolden[] = { + 0x6c6534090ee6d3ee, 0x044e2b9b9d5333c6, 0xc3c14f134e433977, + 0xdda9f47cd90410ee, 0x887bf3087fd8ca10, 0xf0b780f545c72912, + 0x15dbb1d37696599f, 0x30ec63baff3c6d59, 0xb29f73606f7f20a6, + 0x02808a316f49a54c, 0x3b8feaf9d5c8e50e, 0x9cbf605e3fd9de8a, + 0xc970ae1a78183bbb, 0xd8b2ffd356301ed5, 0xf4b327fe0fc73c37, + 0xcdfd8d76eb8f9a19, 0xc3a506eb91420c9d, 0xd5af05dd3eff9556, + 0x48db1bb78f83c4a1, 0x7023920e0d6bfe8c, 0x58d3575834956d42, + 0xed1ef4c26b87b840, 0x8eef32a23e0b2df3, 0x497cabf3431154fc, + 0x4e24370570029a8b, 0xd88b5749f090e5ea, 0xc651a582a970692f, + 0x78fcec2cbb6342f5, 0x463cb745612f55db, 0x352ee4ad1816afe3, + 0x026ff374c101da7e, 0x811ef0821c3de851, + }; + + alignas(16) randen d; + std::memset(d.state, 0, sizeof(d.state)); + RandenSlow::Generate(RandenSlow::GetKeys(), d.state); + + uint64_t* id = d.state; + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, *id++); + } +} + +} // namespace diff --git a/absl/random/internal/randen_test.cc b/absl/random/internal/randen_test.cc new file mode 100644 index 000000000000..c186fe0d686b --- /dev/null +++ b/absl/random/internal/randen_test.cc @@ -0,0 +1,70 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/randen.h" + +#include <cstring> + +#include "gtest/gtest.h" +#include "absl/meta/type_traits.h" + +namespace { + +using absl::random_internal::Randen; + +// Local state parameters. +constexpr size_t kStateSizeT = Randen::kStateBytes / sizeof(uint64_t); + +TEST(RandenTest, CopyAndMove) { + static_assert(std::is_copy_constructible<Randen>::value, + "Randen must be copy constructible"); + + static_assert(absl::is_copy_assignable<Randen>::value, + "Randen must be copy assignable"); + + static_assert(std::is_move_constructible<Randen>::value, + "Randen must be move constructible"); + + static_assert(absl::is_move_assignable<Randen>::value, + "Randen must be move assignable"); +} + +TEST(RandenTest, Default) { + constexpr uint64_t kGolden[] = { + 0x6c6534090ee6d3ee, 0x044e2b9b9d5333c6, 0xc3c14f134e433977, + 0xdda9f47cd90410ee, 0x887bf3087fd8ca10, 0xf0b780f545c72912, + 0x15dbb1d37696599f, 0x30ec63baff3c6d59, 0xb29f73606f7f20a6, + 0x02808a316f49a54c, 0x3b8feaf9d5c8e50e, 0x9cbf605e3fd9de8a, + 0xc970ae1a78183bbb, 0xd8b2ffd356301ed5, 0xf4b327fe0fc73c37, + 0xcdfd8d76eb8f9a19, 0xc3a506eb91420c9d, 0xd5af05dd3eff9556, + 0x48db1bb78f83c4a1, 0x7023920e0d6bfe8c, 0x58d3575834956d42, + 0xed1ef4c26b87b840, 0x8eef32a23e0b2df3, 0x497cabf3431154fc, + 0x4e24370570029a8b, 0xd88b5749f090e5ea, 0xc651a582a970692f, + 0x78fcec2cbb6342f5, 0x463cb745612f55db, 0x352ee4ad1816afe3, + 0x026ff374c101da7e, 0x811ef0821c3de851, + }; + + alignas(16) uint64_t state[kStateSizeT]; + std::memset(state, 0, sizeof(state)); + + Randen r; + r.Generate(state); + + auto id = std::begin(state); + for (const auto& elem : kGolden) { + EXPECT_EQ(elem, *id++); + } +} + +} // namespace diff --git a/absl/random/internal/randen_traits.h b/absl/random/internal/randen_traits.h new file mode 100644 index 000000000000..4f1f408d1917 --- /dev/null +++ b/absl/random/internal/randen_traits.h @@ -0,0 +1,59 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_RANDEN_TRAITS_H_ +#define ABSL_RANDOM_INTERNAL_RANDEN_TRAITS_H_ + +// HERMETIC NOTE: The randen_hwaes target must not introduce duplicate +// symbols from arbitrary system and other headers, since it may be built +// with different flags from other targets, using different levels of +// optimization, potentially introducing ODR violations. + +#include <cstddef> + +namespace absl { +namespace random_internal { + +// RANDen = RANDom generator or beetroots in Swiss German. +// 'Strong' (well-distributed, unpredictable, backtracking-resistant) random +// generator, faster in some benchmarks than std::mt19937_64 and pcg64_c32. +// +// RandenTraits contains the basic algorithm traits, such as the size of the +// state, seed, sponge, etc. +struct RandenTraits { + // Size of the entire sponge / state for the randen PRNG. + static constexpr size_t kStateBytes = 256; // 2048-bit + + // Size of the 'inner' (inaccessible) part of the sponge. Larger values would + // require more frequent calls to RandenGenerate. + static constexpr size_t kCapacityBytes = 16; // 128-bit + + // Size of the default seed consumed by the sponge. + static constexpr size_t kSeedBytes = kStateBytes - kCapacityBytes; + + // Largest size for which security proofs are known. + static constexpr size_t kFeistelBlocks = 16; + + // Type-2 generalized Feistel => one round function for every two blocks. + static constexpr size_t kFeistelFunctions = kFeistelBlocks / 2; // = 8 + + // Ensures SPRP security and two full subblock diffusions. + // Must be > 4 * log2(kFeistelBlocks). + static constexpr size_t kFeistelRounds = 16 + 1; +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_RANDEN_TRAITS_H_ diff --git a/absl/random/internal/salted_seed_seq.h b/absl/random/internal/salted_seed_seq.h new file mode 100644 index 000000000000..3d16cf97d3d7 --- /dev/null +++ b/absl/random/internal/salted_seed_seq.h @@ -0,0 +1,152 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_SALTED_SEED_SEQ_H_ +#define ABSL_RANDOM_INTERNAL_SALTED_SEED_SEQ_H_ + +#include <cstdint> +#include <cstdlib> +#include <initializer_list> +#include <iterator> +#include <memory> +#include <type_traits> +#include <utility> + +#include "absl/container/inlined_vector.h" +#include "absl/meta/type_traits.h" +#include "absl/random/internal/seed_material.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" + +namespace absl { +namespace random_internal { + +// This class conforms to the C++ Standard "Seed Sequence" concept +// [rand.req.seedseq]. +// +// A `SaltedSeedSeq` is meant to wrap an existing seed sequence and modify +// generated sequence by mixing with extra entropy. This entropy may be +// build-dependent or process-dependent. The implementation may change to be +// have either or both kinds of entropy. If salt is not available sequence is +// not modified. +template <typename SSeq> +class SaltedSeedSeq { + public: + using inner_sequence_type = SSeq; + using result_type = typename SSeq::result_type; + + SaltedSeedSeq() : seq_(absl::make_unique<SSeq>()) {} + + template <typename Iterator> + SaltedSeedSeq(Iterator begin, Iterator end) + : seq_(absl::make_unique<SSeq>(begin, end)) {} + + template <typename T> + SaltedSeedSeq(std::initializer_list<T> il) + : SaltedSeedSeq(il.begin(), il.end()) {} + + SaltedSeedSeq(const SaltedSeedSeq& other) = delete; + SaltedSeedSeq& operator=(const SaltedSeedSeq& other) = delete; + + SaltedSeedSeq(SaltedSeedSeq&& other) = default; + SaltedSeedSeq& operator=(SaltedSeedSeq&& other) = default; + + template <typename RandomAccessIterator> + void generate(RandomAccessIterator begin, RandomAccessIterator end) { + if (begin != end) { + generate_impl( + std::integral_constant<bool, sizeof(*begin) == sizeof(uint32_t)>{}, + begin, end); + } + } + + template <typename OutIterator> + void param(OutIterator out) const { + seq_->param(out); + } + + size_t size() const { return seq_->size(); } + + private: + // The common case for generate is that it is called with iterators over a + // 32-bit value buffer. These can be reinterpreted to a uint32_t and we can + // operate on them as such. + template <typename RandomAccessIterator> + void generate_impl(std::integral_constant<bool, true> /*is_32bit*/, + RandomAccessIterator begin, RandomAccessIterator end) { + seq_->generate(begin, end); + const uint32_t salt = absl::random_internal::GetSaltMaterial().value_or(0); + auto buffer = absl::MakeSpan(begin, end); + MixIntoSeedMaterial( + absl::MakeConstSpan(&salt, 1), + absl::MakeSpan(reinterpret_cast<uint32_t*>(buffer.data()), + buffer.size())); + } + + // The uncommon case for generate is that it is called with iterators over + // some other buffer type which is assignable from a 32-bit value. In this + // case we allocate a temporary 32-bit buffer and then copy-assign back + // to the initial inputs. + template <typename RandomAccessIterator> + void generate_impl(std::integral_constant<bool, false> /*is_32bit*/, + RandomAccessIterator begin, RandomAccessIterator end) { + // Allocate a temporary buffer, seed, and then copy. + absl::InlinedVector<uint32_t, 8> data(std::distance(begin, end), 0); + generate_impl(std::integral_constant<bool, true>{}, data.begin(), + data.end()); + std::copy(data.begin(), data.end(), begin); + } + + // Because [rand.req.seedseq] is not copy-constructible, copy-assignable nor + // movable so we wrap it with unique pointer to be able to move SaltedSeedSeq. + std::unique_ptr<SSeq> seq_; +}; + +// is_salted_seed_seq indicates whether the type is a SaltedSeedSeq. +template <typename T, typename = void> +struct is_salted_seed_seq : public std::false_type {}; + +template <typename T> +struct is_salted_seed_seq< + T, typename std::enable_if<std::is_same< + T, SaltedSeedSeq<typename T::inner_sequence_type>>::value>::type> + : public std::true_type {}; + +// MakeSaltedSeedSeq returns a salted variant of the seed sequence. +// When provided with an existing SaltedSeedSeq, returns the input parameter, +// otherwise constructs a new SaltedSeedSeq which embodies the original +// non-salted seed parameters. +template < + typename SSeq, // + typename EnableIf = absl::enable_if_t<is_salted_seed_seq<SSeq>::value>> +SSeq MakeSaltedSeedSeq(SSeq&& seq) { + return SSeq(std::forward<SSeq>(seq)); +} + +template < + typename SSeq, // + typename EnableIf = absl::enable_if_t<!is_salted_seed_seq<SSeq>::value>> +SaltedSeedSeq<typename std::decay<SSeq>::type> MakeSaltedSeedSeq(SSeq&& seq) { + using sseq_type = typename std::decay<SSeq>::type; + using result_type = typename sseq_type::result_type; + + absl::InlinedVector<result_type, 8> data; + seq.param(std::back_inserter(data)); + return SaltedSeedSeq<sseq_type>(data.begin(), data.end()); +} + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_SALTED_SEED_SEQ_H_ diff --git a/absl/random/internal/salted_seed_seq_test.cc b/absl/random/internal/salted_seed_seq_test.cc new file mode 100644 index 000000000000..0bf19a63ef8c --- /dev/null +++ b/absl/random/internal/salted_seed_seq_test.cc @@ -0,0 +1,168 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/salted_seed_seq.h" + +#include <iterator> +#include <random> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using absl::random_internal::GetSaltMaterial; +using absl::random_internal::MakeSaltedSeedSeq; +using absl::random_internal::SaltedSeedSeq; +using testing::Eq; +using testing::Pointwise; + +namespace { + +template <typename Sseq> +void ConformsToInterface() { + // Check that the SeedSequence can be default-constructed. + { Sseq default_constructed_seq; } + // Check that the SeedSequence can be constructed with two iterators. + { + uint32_t init_array[] = {1, 3, 5, 7, 9}; + Sseq iterator_constructed_seq(std::begin(init_array), std::end(init_array)); + } + // Check that the SeedSequence can be std::initializer_list-constructed. + { Sseq list_constructed_seq = {1, 3, 5, 7, 9, 11, 13}; } + // Check that param() and size() return state provided to constructor. + { + uint32_t init_array[] = {1, 2, 3, 4, 5}; + Sseq seq(std::begin(init_array), std::end(init_array)); + EXPECT_EQ(seq.size(), ABSL_ARRAYSIZE(init_array)); + + std::vector<uint32_t> state_vector; + seq.param(std::back_inserter(state_vector)); + + EXPECT_EQ(state_vector.size(), ABSL_ARRAYSIZE(init_array)); + for (int i = 0; i < state_vector.size(); i++) { + EXPECT_EQ(state_vector[i], i + 1); + } + } + // Check for presence of generate() method. + { + Sseq seq; + uint32_t seeds[5]; + + seq.generate(std::begin(seeds), std::end(seeds)); + } +} + +TEST(SaltedSeedSeq, CheckInterfaces) { + // Control case + ConformsToInterface<std::seed_seq>(); + + // Abseil classes + ConformsToInterface<SaltedSeedSeq<std::seed_seq>>(); +} + +TEST(SaltedSeedSeq, CheckConstructingFromOtherSequence) { + std::vector<uint32_t> seed_values(10, 1); + std::seed_seq seq(seed_values.begin(), seed_values.end()); + auto salted_seq = MakeSaltedSeedSeq(std::move(seq)); + + EXPECT_EQ(seq.size(), salted_seq.size()); + + std::vector<uint32_t> param_result; + seq.param(std::back_inserter(param_result)); + + EXPECT_EQ(seed_values, param_result); +} + +TEST(SaltedSeedSeq, SaltedSaltedSeedSeqIsNotDoubleSalted) { + uint32_t init[] = {1, 3, 5, 7, 9}; + + std::seed_seq seq(std::begin(init), std::end(init)); + + // The first salting. + SaltedSeedSeq<std::seed_seq> salted_seq = MakeSaltedSeedSeq(std::move(seq)); + uint32_t a[16]; + salted_seq.generate(std::begin(a), std::end(a)); + + // The second salting. + SaltedSeedSeq<std::seed_seq> salted_salted_seq = + MakeSaltedSeedSeq(std::move(salted_seq)); + uint32_t b[16]; + salted_salted_seq.generate(std::begin(b), std::end(b)); + + // ... both should be equal. + EXPECT_THAT(b, Pointwise(Eq(), a)) << "a[0] " << a[0]; +} + +TEST(SaltedSeedSeq, SeedMaterialIsSalted) { + const size_t kNumBlocks = 16; + + uint32_t seed_material[kNumBlocks]; + std::random_device urandom{"/dev/urandom"}; + for (uint32_t& seed : seed_material) { + seed = urandom(); + } + + std::seed_seq seq(std::begin(seed_material), std::end(seed_material)); + SaltedSeedSeq<std::seed_seq> salted_seq(std::begin(seed_material), + std::end(seed_material)); + + bool salt_is_available = GetSaltMaterial().has_value(); + + // If salt is available generated sequence should be different. + if (salt_is_available) { + uint32_t outputs[kNumBlocks]; + uint32_t salted_outputs[kNumBlocks]; + + seq.generate(std::begin(outputs), std::end(outputs)); + salted_seq.generate(std::begin(salted_outputs), std::end(salted_outputs)); + + EXPECT_THAT(outputs, Pointwise(testing::Ne(), salted_outputs)); + } +} + +TEST(SaltedSeedSeq, GenerateAcceptsDifferentTypes) { + const size_t kNumBlocks = 4; + + SaltedSeedSeq<std::seed_seq> seq({1, 2, 3}); + + uint32_t expected[kNumBlocks]; + seq.generate(std::begin(expected), std::end(expected)); + + // 32-bit outputs + { + unsigned long seed_material[kNumBlocks]; // NOLINT(runtime/int) + seq.generate(std::begin(seed_material), std::end(seed_material)); + EXPECT_THAT(seed_material, Pointwise(Eq(), expected)); + } + { + unsigned int seed_material[kNumBlocks]; // NOLINT(runtime/int) + seq.generate(std::begin(seed_material), std::end(seed_material)); + EXPECT_THAT(seed_material, Pointwise(Eq(), expected)); + } + + // 64-bit outputs. + { + uint64_t seed_material[kNumBlocks]; + seq.generate(std::begin(seed_material), std::end(seed_material)); + EXPECT_THAT(seed_material, Pointwise(Eq(), expected)); + } + { + int64_t seed_material[kNumBlocks]; + seq.generate(std::begin(seed_material), std::end(seed_material)); + EXPECT_THAT(seed_material, Pointwise(Eq(), expected)); + } +} + +} // namespace diff --git a/absl/random/internal/seed_material.cc b/absl/random/internal/seed_material.cc new file mode 100644 index 000000000000..ec3afe04a3f7 --- /dev/null +++ b/absl/random/internal/seed_material.cc @@ -0,0 +1,204 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/seed_material.h" + +#include <fcntl.h> + +#ifndef _WIN32 +#include <unistd.h> +#else +#include <io.h> +#endif + +#include <algorithm> +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <cstring> + +#include "absl/base/internal/raw_logging.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" + +#if defined(__native_client__) + +#include <nacl/nacl_random.h> +#define ABSL_RANDOM_USE_NACL_SECURE_RANDOM 1 + +#elif defined(_WIN32) + +#include <windows.h> +#define ABSL_RANDOM_USE_BCRYPT 1 +#pragma comment(lib, "bcrypt.lib") +#endif + +#if defined(ABSL_RANDOM_USE_BCRYPT) +#include <bcrypt.h> + +#ifndef BCRYPT_SUCCESS +#define BCRYPT_SUCCESS(Status) (((NTSTATUS)(Status)) >= 0) +#endif +// Also link bcrypt; this can be done via linker options or: +// #pragma comment(lib, "bcrypt.lib") +#endif + +namespace absl { +namespace random_internal { +namespace { + +// Read OS Entropy for random number seeds. +// TODO(absl-team): Possibly place a cap on how much entropy may be read at a +// time. + +#if defined(ABSL_RANDOM_USE_BCRYPT) + +// On Windows potentially use the BCRYPT CNG API to read available entropy. +bool ReadSeedMaterialFromOSEntropyImpl(absl::Span<uint32_t> values) { + BCRYPT_ALG_HANDLE hProvider; + NTSTATUS ret; + ret = BCryptOpenAlgorithmProvider(&hProvider, BCRYPT_RNG_ALGORITHM, + MS_PRIMITIVE_PROVIDER, 0); + if (!(BCRYPT_SUCCESS(ret))) { + ABSL_RAW_LOG(ERROR, "Failed to open crypto provider."); + return false; + } + ret = BCryptGenRandom( + hProvider, // provider + reinterpret_cast<UCHAR*>(values.data()), // buffer + static_cast<ULONG>(sizeof(uint32_t) * values.size()), // bytes + 0); // flags + BCryptCloseAlgorithmProvider(hProvider, 0); + return BCRYPT_SUCCESS(ret); +} + +#elif defined(ABSL_RANDOM_USE_NACL_SECURE_RANDOM) + +// On NaCL use nacl_secure_random to acquire bytes. +bool ReadSeedMaterialFromOSEntropyImpl(absl::Span<uint32_t> values) { + auto buffer = reinterpret_cast<uint8_t*>(values.data()); + size_t buffer_size = sizeof(uint32_t) * values.size(); + + uint8_t* output_ptr = buffer; + while (buffer_size > 0) { + size_t nread = 0; + const int error = nacl_secure_random(output_ptr, buffer_size, &nread); + if (error != 0 || nread > buffer_size) { + ABSL_RAW_LOG(ERROR, "Failed to read secure_random seed data: %d", error); + return false; + } + output_ptr += nread; + buffer_size -= nread; + } + return true; +} + +#else + +// On *nix, read entropy from /dev/urandom. +bool ReadSeedMaterialFromOSEntropyImpl(absl::Span<uint32_t> values) { + const char kEntropyFile[] = "/dev/urandom"; + + auto buffer = reinterpret_cast<uint8_t*>(values.data()); + size_t buffer_size = sizeof(uint32_t) * values.size(); + + int dev_urandom = open(kEntropyFile, O_RDONLY); + bool success = (-1 != dev_urandom); + if (!success) { + return false; + } + + while (success && buffer_size > 0) { + int bytes_read = read(dev_urandom, buffer, buffer_size); + int read_error = errno; + success = (bytes_read > 0); + if (success) { + buffer += bytes_read; + buffer_size -= bytes_read; + } else if (bytes_read == -1 && read_error == EINTR) { + success = true; // Need to try again. + } + } + close(dev_urandom); + return success; +} + +#endif + +} // namespace + +bool ReadSeedMaterialFromOSEntropy(absl::Span<uint32_t> values) { + assert(values.data() != nullptr); + if (values.data() == nullptr) { + return false; + } + if (values.empty()) { + return true; + } + return ReadSeedMaterialFromOSEntropyImpl(values); +} + +void MixIntoSeedMaterial(absl::Span<const uint32_t> sequence, + absl::Span<uint32_t> seed_material) { + // Algorithm is based on code available at + // https://gist.github.com/imneme/540829265469e673d045 + constexpr uint32_t kInitVal = 0x43b0d7e5; + constexpr uint32_t kHashMul = 0x931e8875; + constexpr uint32_t kMixMulL = 0xca01f9dd; + constexpr uint32_t kMixMulR = 0x4973f715; + constexpr uint32_t kShiftSize = sizeof(uint32_t) * 8 / 2; + + uint32_t hash_const = kInitVal; + auto hash = [&](uint32_t value) { + value ^= hash_const; + hash_const *= kHashMul; + value *= hash_const; + value ^= value >> kShiftSize; + return value; + }; + + auto mix = [&](uint32_t x, uint32_t y) { + uint32_t result = kMixMulL * x - kMixMulR * y; + result ^= result >> kShiftSize; + return result; + }; + + for (const auto& seq_val : sequence) { + for (auto& elem : seed_material) { + elem = mix(elem, hash(seq_val)); + } + } +} + +absl::optional<uint32_t> GetSaltMaterial() { + // Salt must be common for all generators within the same process so read it + // only once and store in static variable. + static const auto salt_material = []() -> absl::optional<uint32_t> { + uint32_t salt_value = 0; + + if (random_internal::ReadSeedMaterialFromOSEntropy( + MakeSpan(&salt_value, 1))) { + return salt_value; + } + + return absl::nullopt; + }(); + + return salt_material; +} + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/internal/seed_material.h b/absl/random/internal/seed_material.h new file mode 100644 index 000000000000..57de8a24d5e0 --- /dev/null +++ b/absl/random/internal/seed_material.h @@ -0,0 +1,102 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_SEED_MATERIAL_H_ +#define ABSL_RANDOM_INTERNAL_SEED_MATERIAL_H_ + +#include <cassert> +#include <cstdint> +#include <cstdlib> +#include <string> +#include <vector> + +#include "absl/base/attributes.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" + +namespace absl { +namespace random_internal { + +// Returns the number of 32-bit blocks needed to contain the given number of +// bits. +constexpr size_t SeedBitsToBlocks(size_t seed_size) { + return (seed_size + 31) / 32; +} + +// Amount of entropy (measured in bits) used to instantiate a Seed Sequence, +// with which to create a URBG. +constexpr size_t kEntropyBitsNeeded = 256; + +// Amount of entropy (measured in 32-bit blocks) used to instantiate a Seed +// Sequence, with which to create a URBG. +constexpr size_t kEntropyBlocksNeeded = + random_internal::SeedBitsToBlocks(kEntropyBitsNeeded); + +static_assert(kEntropyBlocksNeeded > 0, + "Entropy used to seed URBGs must be nonzero."); + +// Attempts to fill a span of uint32_t-values using an OS-provided source of +// true entropy (eg. /dev/urandom) into an array of uint32_t blocks of data. The +// resulting array may be used to initialize an instance of a class conforming +// to the C++ Standard "Seed Sequence" concept [rand.req.seedseq]. +// +// If values.data() == nullptr, the behavior is undefined. +ABSL_MUST_USE_RESULT +bool ReadSeedMaterialFromOSEntropy(absl::Span<uint32_t> values); + +// Attempts to fill a span of uint32_t-values using variates generated by an +// existing instance of a class conforming to the C++ Standard "Uniform Random +// Bit Generator" concept [rand.req.urng]. The resulting data may be used to +// initialize an instance of a class conforming to the C++ Standard +// "Seed Sequence" concept [rand.req.seedseq]. +// +// If urbg == nullptr or values.data() == nullptr, the behavior is undefined. +template <typename URBG> +ABSL_MUST_USE_RESULT bool ReadSeedMaterialFromURBG( + URBG* urbg, absl::Span<uint32_t> values) { + random_internal::FastUniformBits<uint32_t> distr; + + assert(urbg != nullptr && values.data() != nullptr); + if (urbg == nullptr || values.data() == nullptr) { + return false; + } + + for (uint32_t& seed_value : values) { + seed_value = distr(*urbg); + } + return true; +} + +// Mixes given sequence of values with into given sequence of seed material. +// Time complexity of this function is O(sequence.size() * +// seed_material.size()). +// +// Algorithm is based on code available at +// https://gist.github.com/imneme/540829265469e673d045 +// by Melissa O'Neill. +void MixIntoSeedMaterial(absl::Span<const uint32_t> sequence, + absl::Span<uint32_t> seed_material); + +// Returns salt value. +// +// Salt is obtained only once and stored in static variable. +// +// May return empty value if optaining the salt was not possible. +absl::optional<uint32_t> GetSaltMaterial(); + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_SEED_MATERIAL_H_ diff --git a/absl/random/internal/seed_material_test.cc b/absl/random/internal/seed_material_test.cc new file mode 100644 index 000000000000..0de6c4c6984d --- /dev/null +++ b/absl/random/internal/seed_material_test.cc @@ -0,0 +1,201 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/seed_material.h" + +#include <bitset> +#include <cstdlib> +#include <cstring> +#include <random> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#ifdef __ANDROID__ +// Android assert messages only go to system log, so death tests cannot inspect +// the message for matching. +#define ABSL_EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ + EXPECT_DEATH_IF_SUPPORTED(statement, ".*") +#else +#define ABSL_EXPECT_DEATH_IF_SUPPORTED EXPECT_DEATH_IF_SUPPORTED +#endif + +namespace { + +using testing::Each; +using testing::ElementsAre; +using testing::Eq; +using testing::Ne; +using testing::Pointwise; + +TEST(SeedBitsToBlocks, VerifyCases) { + EXPECT_EQ(0, absl::random_internal::SeedBitsToBlocks(0)); + EXPECT_EQ(1, absl::random_internal::SeedBitsToBlocks(1)); + EXPECT_EQ(1, absl::random_internal::SeedBitsToBlocks(31)); + EXPECT_EQ(1, absl::random_internal::SeedBitsToBlocks(32)); + EXPECT_EQ(2, absl::random_internal::SeedBitsToBlocks(33)); + EXPECT_EQ(4, absl::random_internal::SeedBitsToBlocks(127)); + EXPECT_EQ(4, absl::random_internal::SeedBitsToBlocks(128)); + EXPECT_EQ(5, absl::random_internal::SeedBitsToBlocks(129)); +} + +TEST(ReadSeedMaterialFromOSEntropy, SuccessiveReadsAreDistinct) { + constexpr size_t kSeedMaterialSize = 64; + uint32_t seed_material_1[kSeedMaterialSize] = {}; + uint32_t seed_material_2[kSeedMaterialSize] = {}; + + EXPECT_TRUE(absl::random_internal::ReadSeedMaterialFromOSEntropy( + absl::Span<uint32_t>(seed_material_1, kSeedMaterialSize))); + EXPECT_TRUE(absl::random_internal::ReadSeedMaterialFromOSEntropy( + absl::Span<uint32_t>(seed_material_2, kSeedMaterialSize))); + + EXPECT_THAT(seed_material_1, Pointwise(Ne(), seed_material_2)); +} + +TEST(ReadSeedMaterialFromOSEntropy, ReadZeroBytesIsNoOp) { + uint32_t seed_material[32] = {}; + std::memset(seed_material, 0xAA, sizeof(seed_material)); + EXPECT_TRUE(absl::random_internal::ReadSeedMaterialFromOSEntropy( + absl::Span<uint32_t>(seed_material, 0))); + + EXPECT_THAT(seed_material, Each(Eq(0xAAAAAAAA))); +} + +TEST(ReadSeedMaterialFromOSEntropy, NullPtrVectorArgument) { +#ifdef NDEBUG + EXPECT_FALSE(absl::random_internal::ReadSeedMaterialFromOSEntropy( + absl::Span<uint32_t>(nullptr, 32))); +#else + bool result; + ABSL_EXPECT_DEATH_IF_SUPPORTED( + result = absl::random_internal::ReadSeedMaterialFromOSEntropy( + absl::Span<uint32_t>(nullptr, 32)), + "!= nullptr"); + (void)result; // suppress unused-variable warning +#endif +} + +TEST(ReadSeedMaterialFromURBG, SeedMaterialEqualsVariateSequence) { + // Two default-constructed instances of std::mt19937_64 are guaranteed to + // produce equal variate-sequences. + std::mt19937 urbg_1; + std::mt19937 urbg_2; + constexpr size_t kSeedMaterialSize = 1024; + uint32_t seed_material[kSeedMaterialSize] = {}; + + EXPECT_TRUE(absl::random_internal::ReadSeedMaterialFromURBG( + &urbg_1, absl::Span<uint32_t>(seed_material, kSeedMaterialSize))); + for (uint32_t seed : seed_material) { + EXPECT_EQ(seed, urbg_2()); + } +} + +TEST(ReadSeedMaterialFromURBG, ReadZeroBytesIsNoOp) { + std::mt19937_64 urbg; + uint32_t seed_material[32]; + std::memset(seed_material, 0xAA, sizeof(seed_material)); + EXPECT_TRUE(absl::random_internal::ReadSeedMaterialFromURBG( + &urbg, absl::Span<uint32_t>(seed_material, 0))); + + EXPECT_THAT(seed_material, Each(Eq(0xAAAAAAAA))); +} + +TEST(ReadSeedMaterialFromURBG, NullUrbgArgument) { + constexpr size_t kSeedMaterialSize = 32; + uint32_t seed_material[kSeedMaterialSize]; +#ifdef NDEBUG + EXPECT_FALSE(absl::random_internal::ReadSeedMaterialFromURBG<std::mt19937_64>( + nullptr, absl::Span<uint32_t>(seed_material, kSeedMaterialSize))); +#else + bool result; + ABSL_EXPECT_DEATH_IF_SUPPORTED( + result = absl::random_internal::ReadSeedMaterialFromURBG<std::mt19937_64>( + nullptr, absl::Span<uint32_t>(seed_material, kSeedMaterialSize)), + "!= nullptr"); + (void)result; // suppress unused-variable warning +#endif +} + +TEST(ReadSeedMaterialFromURBG, NullPtrVectorArgument) { + std::mt19937_64 urbg; +#ifdef NDEBUG + EXPECT_FALSE(absl::random_internal::ReadSeedMaterialFromURBG( + &urbg, absl::Span<uint32_t>(nullptr, 32))); +#else + bool result; + ABSL_EXPECT_DEATH_IF_SUPPORTED( + result = absl::random_internal::ReadSeedMaterialFromURBG( + &urbg, absl::Span<uint32_t>(nullptr, 32)), + "!= nullptr"); + (void)result; // suppress unused-variable warning +#endif +} + +// The avalanche effect is a desirable cryptographic property of hashes in which +// changing a single bit in the input causes each bit of the output to be +// changed with probability near 50%. +// +// https://en.wikipedia.org/wiki/Avalanche_effect + +TEST(MixSequenceIntoSeedMaterial, AvalancheEffectTestOneBitLong) { + std::vector<uint32_t> seed_material = {1, 2, 3, 4, 5, 6, 7, 8}; + + // For every 32-bit number with exactly one bit set, verify the avalanche + // effect holds. In order to reduce flakiness of tests, accept values + // anywhere in the range of 30%-70%. + for (uint32_t v = 1; v != 0; v <<= 1) { + std::vector<uint32_t> seed_material_copy = seed_material; + absl::random_internal::MixIntoSeedMaterial( + absl::Span<uint32_t>(&v, 1), + absl::Span<uint32_t>(seed_material_copy.data(), + seed_material_copy.size())); + + uint32_t changed_bits = 0; + for (size_t i = 0; i < seed_material.size(); i++) { + std::bitset<sizeof(uint32_t) * 8> bitset(seed_material[i] ^ + seed_material_copy[i]); + changed_bits += bitset.count(); + } + + EXPECT_LE(changed_bits, 0.7 * sizeof(uint32_t) * 8 * seed_material.size()); + EXPECT_GE(changed_bits, 0.3 * sizeof(uint32_t) * 8 * seed_material.size()); + } +} + +TEST(MixSequenceIntoSeedMaterial, AvalancheEffectTestOneBitShort) { + std::vector<uint32_t> seed_material = {1}; + + // For every 32-bit number with exactly one bit set, verify the avalanche + // effect holds. In order to reduce flakiness of tests, accept values + // anywhere in the range of 30%-70%. + for (uint32_t v = 1; v != 0; v <<= 1) { + std::vector<uint32_t> seed_material_copy = seed_material; + absl::random_internal::MixIntoSeedMaterial( + absl::Span<uint32_t>(&v, 1), + absl::Span<uint32_t>(seed_material_copy.data(), + seed_material_copy.size())); + + uint32_t changed_bits = 0; + for (size_t i = 0; i < seed_material.size(); i++) { + std::bitset<sizeof(uint32_t) * 8> bitset(seed_material[i] ^ + seed_material_copy[i]); + changed_bits += bitset.count(); + } + + EXPECT_LE(changed_bits, 0.7 * sizeof(uint32_t) * 8 * seed_material.size()); + EXPECT_GE(changed_bits, 0.3 * sizeof(uint32_t) * 8 * seed_material.size()); + } +} + +} // namespace diff --git a/absl/random/internal/seed_salting_sequence_generator.cc b/absl/random/internal/seed_salting_sequence_generator.cc new file mode 100644 index 000000000000..31fdcfe1db41 --- /dev/null +++ b/absl/random/internal/seed_salting_sequence_generator.cc @@ -0,0 +1,30 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <iostream> +#include <random> + +#include "absl/random/random.h" + +// This program is used in integration tests. + +int main() { + std::seed_seq seed_seq{1234}; + absl::BitGen rng(seed_seq); + constexpr size_t kSequenceLength = 8; + for (size_t i = 0; i < kSequenceLength; i++) { + std::cout << rng() << "\n"; + } + return 0; +} diff --git a/absl/random/internal/seed_salting_sequence_generator_empty_sequence.cc b/absl/random/internal/seed_salting_sequence_generator_empty_sequence.cc new file mode 100644 index 000000000000..8797e2e7a6f2 --- /dev/null +++ b/absl/random/internal/seed_salting_sequence_generator_empty_sequence.cc @@ -0,0 +1,30 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <iostream> +#include <random> + +#include "absl/random/random.h" + +// This program is used in integration tests. + +int main() { + std::seed_seq seed_seq{}; + absl::BitGen rng(seed_seq); + constexpr size_t kSequenceLength = 8; + for (size_t i = 0; i < kSequenceLength; i++) { + std::cout << rng() << "\n"; + } + return 0; +} diff --git a/absl/random/internal/sequence_urbg.h b/absl/random/internal/sequence_urbg.h new file mode 100644 index 000000000000..9a9b57738da4 --- /dev/null +++ b/absl/random/internal/sequence_urbg.h @@ -0,0 +1,56 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_SEQUENCE_URBG_H_ +#define ABSL_RANDOM_INTERNAL_SEQUENCE_URBG_H_ + +#include <cstdint> +#include <cstring> +#include <limits> +#include <type_traits> +#include <vector> + +namespace absl { +namespace random_internal { + +// `sequence_urbg` is a simple random number generator which meets the +// requirements of [rand.req.urbg], and is solely for testing absl +// distributions. +class sequence_urbg { + public: + using result_type = uint64_t; + + static constexpr result_type(min)() { + return (std::numeric_limits<result_type>::min)(); + } + static constexpr result_type(max)() { + return (std::numeric_limits<result_type>::max)(); + } + + sequence_urbg(std::initializer_list<result_type> data) : i_(0), data_(data) {} + void reset() { i_ = 0; } + + result_type operator()() { return data_[i_++ % data_.size()]; } + + size_t invocations() const { return i_; } + + private: + size_t i_; + std::vector<result_type> data_; +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_SEQUENCE_URBG_H_ diff --git a/absl/random/internal/traits.h b/absl/random/internal/traits.h new file mode 100644 index 000000000000..40eb011f195f --- /dev/null +++ b/absl/random/internal/traits.h @@ -0,0 +1,99 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_INTERNAL_TRAITS_H_ +#define ABSL_RANDOM_INTERNAL_TRAITS_H_ + +#include <cstdint> +#include <limits> +#include <type_traits> + +#include "absl/base/config.h" + +namespace absl { +namespace random_internal { + +// random_internal::is_widening_convertible<A, B> +// +// Returns whether a type A is widening-convertible to a type B. +// +// A is widening-convertible to B means: +// A a = <any number>; +// B b = a; +// A c = b; +// EXPECT_EQ(a, c); +template <typename A, typename B> +class is_widening_convertible { + // As long as there are enough bits in the exact part of a number: + // - unsigned can fit in float, signed, unsigned + // - signed can fit in float, signed + // - float can fit in float + // So we define rank to be: + // - rank(float) -> 2 + // - rank(signed) -> 1 + // - rank(unsigned) -> 0 + template <class T> + static constexpr int rank() { + return !std::numeric_limits<T>::is_integer + + std::numeric_limits<T>::is_signed; + } + + public: + // If an arithmetic-type B can represent at least as many digits as a type A, + // and B belongs to a rank no lower than A, then A can be safely represented + // by B through a widening-conversion. + static constexpr bool value = + std::numeric_limits<A>::digits <= std::numeric_limits<B>::digits && + rank<A>() <= rank<B>(); +}; + +// unsigned_bits<N>::type returns the unsigned int type with the indicated +// number of bits. +template <size_t N> +struct unsigned_bits; + +template <> +struct unsigned_bits<8> { + using type = uint8_t; +}; +template <> +struct unsigned_bits<16> { + using type = uint16_t; +}; +template <> +struct unsigned_bits<32> { + using type = uint32_t; +}; +template <> +struct unsigned_bits<64> { + using type = uint64_t; +}; + +#ifdef ABSL_HAVE_INTRINSIC_INT128 +template <> +struct unsigned_bits<128> { + using type = __uint128_t; +}; +#endif + +template <typename IntType> +struct make_unsigned_bits { + using type = typename unsigned_bits<std::numeric_limits< + typename std::make_unsigned<IntType>::type>::digits>::type; +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_INTERNAL_TRAITS_H_ diff --git a/absl/random/internal/traits_test.cc b/absl/random/internal/traits_test.cc new file mode 100644 index 000000000000..a844887d3e44 --- /dev/null +++ b/absl/random/internal/traits_test.cc @@ -0,0 +1,126 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/internal/traits.h" + +#include <cstdint> +#include <type_traits> + +#include "gtest/gtest.h" + +namespace { + +using absl::random_internal::is_widening_convertible; + +// CheckWideningConvertsToSelf<T1, T2, ...>() +// +// For each type T, checks: +// - T IS widening-convertible to itself. +// +template <typename T> +void CheckWideningConvertsToSelf() { + static_assert(is_widening_convertible<T, T>::value, + "Type is not convertible to self!"); +} + +template <typename T, typename Next, typename... Args> +void CheckWideningConvertsToSelf() { + CheckWideningConvertsToSelf<T>(); + CheckWideningConvertsToSelf<Next, Args...>(); +} + +// CheckNotWideningConvertibleWithSigned<T1, T2, ...>() +// +// For each unsigned-type T, checks that: +// - T is NOT widening-convertible to Signed(T) +// - Signed(T) is NOT widening-convertible to T +// +template <typename T> +void CheckNotWideningConvertibleWithSigned() { + using signed_t = typename std::make_signed<T>::type; + + static_assert(!is_widening_convertible<T, signed_t>::value, + "Unsigned type is convertible to same-sized signed-type!"); + static_assert(!is_widening_convertible<signed_t, T>::value, + "Signed type is convertible to same-sized unsigned-type!"); +} + +template <typename T, typename Next, typename... Args> +void CheckNotWideningConvertibleWithSigned() { + CheckNotWideningConvertibleWithSigned<T>(); + CheckWideningConvertsToSelf<Next, Args...>(); +} + +// CheckWideningConvertsToLargerType<T1, T2, ...>() +// +// For each successive unsigned-types {Ti, Ti+1}, checks that: +// - Ti IS widening-convertible to Ti+1 +// - Ti IS widening-convertible to Signed(Ti+1) +// - Signed(Ti) is NOT widening-convertible to Ti +// - Signed(Ti) IS widening-convertible to Ti+1 +template <typename T, typename Higher> +void CheckWideningConvertsToLargerTypes() { + using signed_t = typename std::make_signed<T>::type; + using higher_t = Higher; + using signed_higher_t = typename std::make_signed<Higher>::type; + + static_assert(is_widening_convertible<T, higher_t>::value, + "Type not embeddable into larger type!"); + static_assert(is_widening_convertible<T, signed_higher_t>::value, + "Type not embeddable into larger signed type!"); + static_assert(!is_widening_convertible<signed_t, higher_t>::value, + "Signed type is embeddable into larger unsigned type!"); + static_assert(is_widening_convertible<signed_t, signed_higher_t>::value, + "Signed type not embeddable into larger signed type!"); +} + +template <typename T, typename Higher, typename Next, typename... Args> +void CheckWideningConvertsToLargerTypes() { + CheckWideningConvertsToLargerTypes<T, Higher>(); + CheckWideningConvertsToLargerTypes<Higher, Next, Args...>(); +} + +// CheckWideningConvertsTo<T, U, [expect]> +// +// Checks that T DOES widening-convert to U. +// If "expect" is false, then asserts that T does NOT widening-convert to U. +template <typename T, typename U, bool expect = true> +void CheckWideningConvertsTo() { + static_assert(is_widening_convertible<T, U>::value == expect, + "Unexpected result for is_widening_convertible<T, U>!"); +} + +TEST(TraitsTest, IsWideningConvertibleTest) { + constexpr bool kInvalid = false; + + CheckWideningConvertsToSelf< + uint8_t, uint16_t, uint32_t, uint64_t, + int8_t, int16_t, int32_t, int64_t, + float, double>(); + CheckNotWideningConvertibleWithSigned< + uint8_t, uint16_t, uint32_t, uint64_t>(); + CheckWideningConvertsToLargerTypes< + uint8_t, uint16_t, uint32_t, uint64_t>(); + + CheckWideningConvertsTo<float, double>(); + CheckWideningConvertsTo<uint16_t, float>(); + CheckWideningConvertsTo<uint32_t, double>(); + CheckWideningConvertsTo<uint64_t, double, kInvalid>(); + CheckWideningConvertsTo<double, float, kInvalid>(); + + CheckWideningConvertsTo<bool, int>(); + CheckWideningConvertsTo<bool, float>(); +} + +} // namespace diff --git a/absl/random/internal/uniform_helper.h b/absl/random/internal/uniform_helper.h new file mode 100644 index 000000000000..b6e2a4a59bac --- /dev/null +++ b/absl/random/internal/uniform_helper.h @@ -0,0 +1,150 @@ +// Copyright 2019 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#ifndef ABSL_RANDOM_UNIFORM_HELPER_H_ +#define ABSL_RANDOM_UNIFORM_HELPER_H_ + +#include <cmath> +#include <limits> +#include <type_traits> + +#include "absl/meta/type_traits.h" + +namespace absl { +template <typename IntType> +class uniform_int_distribution; + +template <typename RealType> +class uniform_real_distribution; + +// Interval tag types which specify whether the interval is open or closed +// on either boundary. +namespace random_internal { +struct IntervalClosedClosedT {}; +struct IntervalClosedOpenT {}; +struct IntervalOpenClosedT {}; +struct IntervalOpenOpenT {}; +} // namespace random_internal + +namespace random_internal { + +// The functions +// uniform_lower_bound(tag, a, b) +// and +// uniform_upper_bound(tag, a, b) +// are used as implementation-details for absl::Uniform(). +// +// Conceptually, +// [a, b] == [uniform_lower_bound(IntervalClosedClosed, a, b), +// uniform_upper_bound(IntervalClosedClosed, a, b)] +// (a, b) == [uniform_lower_bound(IntervalOpenOpen, a, b), +// uniform_upper_bound(IntervalOpenOpen, a, b)] +// [a, b) == [uniform_lower_bound(IntervalClosedOpen, a, b), +// uniform_upper_bound(IntervalClosedOpen, a, b)] +// (a, b] == [uniform_lower_bound(IntervalOpenClosed, a, b), +// uniform_upper_bound(IntervalOpenClosed, a, b)] +// +template <typename IntType, typename Tag> +typename absl::enable_if_t< + absl::conjunction< + std::is_integral<IntType>, + absl::disjunction<std::is_same<Tag, IntervalOpenClosedT>, + std::is_same<Tag, IntervalOpenOpenT>>>::value, + IntType> +uniform_lower_bound(Tag, IntType a, IntType) { + return a + 1; +} + +template <typename FloatType, typename Tag> +typename absl::enable_if_t< + absl::conjunction< + std::is_floating_point<FloatType>, + absl::disjunction<std::is_same<Tag, IntervalOpenClosedT>, + std::is_same<Tag, IntervalOpenOpenT>>>::value, + FloatType> +uniform_lower_bound(Tag, FloatType a, FloatType b) { + return std::nextafter(a, b); +} + +template <typename NumType, typename Tag> +typename absl::enable_if_t< + absl::disjunction<std::is_same<Tag, IntervalClosedClosedT>, + std::is_same<Tag, IntervalClosedOpenT>>::value, + NumType> +uniform_lower_bound(Tag, NumType a, NumType) { + return a; +} + +template <typename IntType, typename Tag> +typename absl::enable_if_t< + absl::conjunction< + std::is_integral<IntType>, + absl::disjunction<std::is_same<Tag, IntervalClosedOpenT>, + std::is_same<Tag, IntervalOpenOpenT>>>::value, + IntType> +uniform_upper_bound(Tag, IntType, IntType b) { + return b - 1; +} + +template <typename FloatType, typename Tag> +typename absl::enable_if_t< + absl::conjunction< + std::is_floating_point<FloatType>, + absl::disjunction<std::is_same<Tag, IntervalClosedOpenT>, + std::is_same<Tag, IntervalOpenOpenT>>>::value, + FloatType> +uniform_upper_bound(Tag, FloatType, FloatType b) { + return b; +} + +template <typename IntType, typename Tag> +typename absl::enable_if_t< + absl::conjunction< + std::is_integral<IntType>, + absl::disjunction<std::is_same<Tag, IntervalClosedClosedT>, + std::is_same<Tag, IntervalOpenClosedT>>>::value, + IntType> +uniform_upper_bound(Tag, IntType, IntType b) { + return b; +} + +template <typename FloatType, typename Tag> +typename absl::enable_if_t< + absl::conjunction< + std::is_floating_point<FloatType>, + absl::disjunction<std::is_same<Tag, IntervalClosedClosedT>, + std::is_same<Tag, IntervalOpenClosedT>>>::value, + FloatType> +uniform_upper_bound(Tag, FloatType, FloatType b) { + return std::nextafter(b, (std::numeric_limits<FloatType>::max)()); +} + +template <typename NumType> +using UniformDistribution = + typename std::conditional<std::is_integral<NumType>::value, + absl::uniform_int_distribution<NumType>, + absl::uniform_real_distribution<NumType>>::type; + +template <typename TagType, typename NumType> +struct UniformDistributionWrapper : public UniformDistribution<NumType> { + explicit UniformDistributionWrapper(NumType lo, NumType hi) + : UniformDistribution<NumType>( + uniform_lower_bound<NumType>(TagType{}, lo, hi), + uniform_upper_bound<NumType>(TagType{}, lo, hi)) {} +}; + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_UNIFORM_HELPER_H_ diff --git a/absl/random/log_uniform_int_distribution.h b/absl/random/log_uniform_int_distribution.h new file mode 100644 index 000000000000..ac43416ee573 --- /dev/null +++ b/absl/random/log_uniform_int_distribution.h @@ -0,0 +1,250 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_ +#define ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_ + +#include <algorithm> +#include <cassert> +#include <cmath> +#include <istream> +#include <limits> +#include <ostream> +#include <type_traits> + +#include "absl/random/internal/distribution_impl.h" +#include "absl/random/internal/fastmath.h" +#include "absl/random/internal/iostream_state_saver.h" +#include "absl/random/internal/traits.h" +#include "absl/random/uniform_int_distribution.h" + +namespace absl { + +// log_uniform_int_distribution: +// +// Returns a random variate R in range [min, max] such that +// floor(log(R-min, base)) is uniformly distributed. +// We ensure uniformity by discretization using the +// boundary sets [0, 1, base, base * base, ... min(base*n, max)] +// +template <typename IntType = int> +class log_uniform_int_distribution { + private: + using unsigned_type = + typename random_internal::make_unsigned_bits<IntType>::type; + + public: + using result_type = IntType; + + class param_type { + public: + using distribution_type = log_uniform_int_distribution; + + explicit param_type( + result_type min = 0, + result_type max = (std::numeric_limits<result_type>::max)(), + result_type base = 2) + : min_(min), + max_(max), + base_(base), + range_(static_cast<unsigned_type>(max_) - + static_cast<unsigned_type>(min_)), + log_range_(0) { + assert(max_ >= min_); + assert(base_ > 1); + + if (base_ == 2) { + // Determine where the first set bit is on range(), giving a log2(range) + // value which can be used to construct bounds. + log_range_ = (std::min)(random_internal::LeadingSetBit(range()), + std::numeric_limits<unsigned_type>::digits); + } else { + // NOTE: Computing the logN(x) introduces error from 2 sources: + // 1. Conversion of int to double loses precision for values >= + // 2^53, which may cause some log() computations to operate on + // different values. + // 2. The error introduced by the division will cause the result + // to differ from the expected value. + // + // Thus a result which should equal K may equal K +/- epsilon, + // which can eliminate some values depending on where the bounds fall. + const double inv_log_base = 1.0 / std::log(base_); + const double log_range = std::log(static_cast<double>(range()) + 0.5); + log_range_ = static_cast<int>(std::ceil(inv_log_base * log_range)); + } + } + + result_type(min)() const { return min_; } + result_type(max)() const { return max_; } + result_type base() const { return base_; } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.min_ == b.min_ && a.max_ == b.max_ && a.base_ == b.base_; + } + + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + friend class log_uniform_int_distribution; + + int log_range() const { return log_range_; } + unsigned_type range() const { return range_; } + + result_type min_; + result_type max_; + result_type base_; + unsigned_type range_; // max - min + int log_range_; // ceil(logN(range_)) + + static_assert(std::is_integral<IntType>::value, + "Class-template absl::log_uniform_int_distribution<> must be " + "parameterized using an integral type."); + }; + + log_uniform_int_distribution() : log_uniform_int_distribution(0) {} + + explicit log_uniform_int_distribution( + result_type min, + result_type max = (std::numeric_limits<result_type>::max)(), + result_type base = 2) + : param_(min, max, base) {} + + explicit log_uniform_int_distribution(const param_type& p) : param_(p) {} + + void reset() {} + + // generating functions + template <typename URBG> + result_type operator()(URBG& g) { // NOLINT(runtime/references) + return (*this)(g, param_); + } + + template <typename URBG> + result_type operator()(URBG& g, // NOLINT(runtime/references) + const param_type& p) { + return (p.min)() + Generate(g, p); + } + + result_type(min)() const { return (param_.min)(); } + result_type(max)() const { return (param_.max)(); } + result_type base() const { return param_.base(); } + + param_type param() const { return param_; } + void param(const param_type& p) { param_ = p; } + + friend bool operator==(const log_uniform_int_distribution& a, + const log_uniform_int_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const log_uniform_int_distribution& a, + const log_uniform_int_distribution& b) { + return a.param_ != b.param_; + } + + private: + // Returns a log-uniform variate in the range [0, p.range()]. The caller + // should add min() to shift the result to the correct range. + template <typename URNG> + unsigned_type Generate(URNG& g, // NOLINT(runtime/references) + const param_type& p); + + param_type param_; +}; + +template <typename IntType> +template <typename URBG> +typename log_uniform_int_distribution<IntType>::unsigned_type +log_uniform_int_distribution<IntType>::Generate( + URBG& g, // NOLINT(runtime/references) + const param_type& p) { + // sample e over [0, log_range]. Map the results of e to this: + // 0 => 0 + // 1 => [1, b-1] + // 2 => [b, (b^2)-1] + // n => [b^(n-1)..(b^n)-1] + const int e = absl::uniform_int_distribution<int>(0, p.log_range())(g); + if (e == 0) { + return 0; + } + const int d = e - 1; + + unsigned_type base_e, top_e; + if (p.base() == 2) { + base_e = static_cast<unsigned_type>(1) << d; + + top_e = (e >= std::numeric_limits<unsigned_type>::digits) + ? (std::numeric_limits<unsigned_type>::max)() + : (static_cast<unsigned_type>(1) << e) - 1; + } else { + const double r = std::pow(p.base(), d); + const double s = (r * p.base()) - 1.0; + + base_e = (r > (std::numeric_limits<unsigned_type>::max)()) + ? (std::numeric_limits<unsigned_type>::max)() + : static_cast<unsigned_type>(r); + + top_e = (s > (std::numeric_limits<unsigned_type>::max)()) + ? (std::numeric_limits<unsigned_type>::max)() + : static_cast<unsigned_type>(s); + } + + const unsigned_type lo = (base_e >= p.range()) ? p.range() : base_e; + const unsigned_type hi = (top_e >= p.range()) ? p.range() : top_e; + + // choose uniformly over [lo, hi] + return absl::uniform_int_distribution<result_type>(lo, hi)(g); +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const log_uniform_int_distribution<IntType>& x) { + using stream_type = + typename random_internal::stream_format_type<IntType>::type; + auto saver = random_internal::make_ostream_state_saver(os); + os << static_cast<stream_type>((x.min)()) << os.fill() + << static_cast<stream_type>((x.max)()) << os.fill() + << static_cast<stream_type>(x.base()); + return os; +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + log_uniform_int_distribution<IntType>& x) { // NOLINT(runtime/references) + using param_type = typename log_uniform_int_distribution<IntType>::param_type; + using result_type = + typename log_uniform_int_distribution<IntType>::result_type; + using stream_type = + typename random_internal::stream_format_type<IntType>::type; + + stream_type min; + stream_type max; + stream_type base; + + auto saver = random_internal::make_istream_state_saver(is); + is >> min >> max >> base; + if (!is.fail()) { + x.param(param_type(static_cast<result_type>(min), + static_cast<result_type>(max), + static_cast<result_type>(base))); + } + return is; +} + +} // namespace absl + +#endif // ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_ diff --git a/absl/random/log_uniform_int_distribution_test.cc b/absl/random/log_uniform_int_distribution_test.cc new file mode 100644 index 000000000000..0ff4c32d95d6 --- /dev/null +++ b/absl/random/log_uniform_int_distribution_test.cc @@ -0,0 +1,277 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/log_uniform_int_distribution.h" + +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <random> +#include <sstream> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/strip.h" + +namespace { + +template <typename IntType> +class LogUniformIntDistributionTypeTest : public ::testing::Test {}; + +using IntTypes = ::testing::Types<int8_t, int16_t, int32_t, int64_t, // + uint8_t, uint16_t, uint32_t, uint64_t>; +TYPED_TEST_CASE(LogUniformIntDistributionTypeTest, IntTypes); + +TYPED_TEST(LogUniformIntDistributionTypeTest, SerializeTest) { + using param_type = + typename absl::log_uniform_int_distribution<TypeParam>::param_type; + using Limits = std::numeric_limits<TypeParam>; + + constexpr int kCount = 1000; + absl::InsecureBitGen gen; + for (const auto& param : { + param_type(0, 1), // + param_type(0, 2), // + param_type(0, 2, 10), // + param_type(9, 32, 4), // + param_type(1, 101, 10), // + param_type(1, Limits::max() / 2), // + param_type(0, Limits::max() - 1), // + param_type(0, Limits::max(), 2), // + param_type(0, Limits::max(), 10), // + param_type(Limits::min(), 0), // + param_type(Limits::lowest(), Limits::max()), // + param_type(Limits::min(), Limits::max()), // + }) { + // Validate parameters. + const auto min = param.min(); + const auto max = param.max(); + const auto base = param.base(); + absl::log_uniform_int_distribution<TypeParam> before(min, max, base); + EXPECT_EQ(before.min(), param.min()); + EXPECT_EQ(before.max(), param.max()); + EXPECT_EQ(before.base(), param.base()); + + { + absl::log_uniform_int_distribution<TypeParam> via_param(param); + EXPECT_EQ(via_param, before); + } + + // Validate stream serialization. + std::stringstream ss; + ss << before; + + absl::log_uniform_int_distribution<TypeParam> after(3, 6, 17); + + EXPECT_NE(before.max(), after.max()); + EXPECT_NE(before.base(), after.base()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + + EXPECT_EQ(before.min(), after.min()); + EXPECT_EQ(before.max(), after.max()); + EXPECT_EQ(before.base(), after.base()); + EXPECT_EQ(before.param(), after.param()); + EXPECT_EQ(before, after); + + // Smoke test. + auto sample_min = after.max(); + auto sample_max = after.min(); + for (int i = 0; i < kCount; i++) { + auto sample = after(gen); + EXPECT_GE(sample, after.min()); + EXPECT_LE(sample, after.max()); + if (sample > sample_max) sample_max = sample; + if (sample < sample_min) sample_min = sample; + } + ABSL_INTERNAL_LOG(INFO, + absl::StrCat("Range: ", +sample_min, ", ", +sample_max)); + } +} + +using log_uniform_i32 = absl::log_uniform_int_distribution<int32_t>; + +class LogUniformIntChiSquaredTest + : public testing::TestWithParam<log_uniform_i32::param_type> { + public: + // The ChiSquaredTestImpl provides a chi-squared goodness of fit test for + // data generated by the log-uniform-int distribution. + double ChiSquaredTestImpl(); + + absl::InsecureBitGen rng_; +}; + +double LogUniformIntChiSquaredTest::ChiSquaredTestImpl() { + using absl::random_internal::kChiSquared; + + const auto& param = GetParam(); + + // Check the distribution of L=log(log_uniform_int_distribution, base), + // expecting that L is roughly uniformly distributed, that is: + // + // P[L=0] ~= P[L=1] ~= ... ~= P[L=log(max)] + // + // For a total of X entries, each bucket should contain some number of samples + // in the interval [X/k - a, X/k + a]. + // + // Where `a` is approximately sqrt(X/k). This is validated by bucketing + // according to the log function and using a chi-squared test for uniformity. + + const bool is_2 = (param.base() == 2); + const double base_log = 1.0 / std::log(param.base()); + const auto bucket_index = [base_log, is_2, ¶m](int32_t x) { + uint64_t y = static_cast<uint64_t>(x) - param.min(); + return (y == 0) ? 0 + : is_2 ? static_cast<int>(1 + std::log2(y)) + : static_cast<int>(1 + std::log(y) * base_log); + }; + const int max_bucket = bucket_index(param.max()); // inclusive + const size_t trials = 15 + (max_bucket + 1) * 10; + + log_uniform_i32 dist(param); + + std::vector<int64_t> buckets(max_bucket + 1); + for (size_t i = 0; i < trials; ++i) { + const auto sample = dist(rng_); + // Check the bounds. + ABSL_ASSERT(sample <= dist.max()); + ABSL_ASSERT(sample >= dist.min()); + // Convert the output of the generator to one of num_bucket buckets. + int bucket = bucket_index(sample); + ABSL_ASSERT(bucket <= max_bucket); + ++buckets[bucket]; + } + + // The null-hypothesis is that the distribution is uniform with respect to + // log-uniform-int bucketization. + const int dof = buckets.size() - 1; + const double expected = trials / static_cast<double>(buckets.size()); + + const double threshold = absl::random_internal::ChiSquareValue(dof, 0.98); + + double chi_square = absl::random_internal::ChiSquareWithExpected( + std::begin(buckets), std::end(buckets), expected); + + const double p = absl::random_internal::ChiSquarePValue(chi_square, dof); + + if (chi_square > threshold) { + ABSL_INTERNAL_LOG(INFO, "values"); + for (size_t i = 0; i < buckets.size(); i++) { + ABSL_INTERNAL_LOG(INFO, absl::StrCat(i, ": ", buckets[i])); + } + ABSL_INTERNAL_LOG(INFO, + absl::StrFormat("trials=%d\n" + "%s(data, %d) = %f (%f)\n" + "%s @ 0.98 = %f", + trials, kChiSquared, dof, chi_square, p, + kChiSquared, threshold)); + } + return p; +} + +TEST_P(LogUniformIntChiSquaredTest, MultiTest) { + const int kTrials = 5; + + int failures = 0; + for (int i = 0; i < kTrials; i++) { + double p_value = ChiSquaredTestImpl(); + if (p_value < 0.005) { + failures++; + } + } + + // There is a 0.10% chance of producing at least one failure, so raise the + // failure threshold high enough to allow for a flake rate < 10,000. + EXPECT_LE(failures, 4); +} + +// Generate the parameters for the test. +std::vector<log_uniform_i32::param_type> GenParams() { + using Param = log_uniform_i32::param_type; + using Limits = std::numeric_limits<int32_t>; + + return std::vector<Param>{ + Param{0, 1, 2}, + Param{1, 1, 2}, + Param{0, 2, 2}, + Param{0, 3, 2}, + Param{0, 4, 2}, + Param{0, 9, 10}, + Param{0, 10, 10}, + Param{0, 11, 10}, + Param{1, 10, 10}, + Param{0, (1 << 8) - 1, 2}, + Param{0, (1 << 8), 2}, + Param{0, (1 << 30) - 1, 2}, + Param{-1000, 1000, 10}, + Param{0, Limits::max(), 2}, + Param{0, Limits::max(), 3}, + Param{0, Limits::max(), 10}, + Param{Limits::min(), 0}, + Param{Limits::min(), Limits::max(), 2}, + }; +} + +std::string ParamName( + const ::testing::TestParamInfo<log_uniform_i32::param_type>& info) { + const auto& p = info.param; + std::string name = + absl::StrCat("min_", p.min(), "__max_", p.max(), "__base_", p.base()); + return absl::StrReplaceAll(name, {{"+", "_"}, {"-", "_"}, {".", "_"}}); +} + +INSTANTIATE_TEST_SUITE_P(, LogUniformIntChiSquaredTest, + ::testing::ValuesIn(GenParams()), ParamName); + +// NOTE: absl::log_uniform_int_distribution is not guaranteed to be stable. +TEST(LogUniformIntDistributionTest, StabilityTest) { + using testing::ElementsAre; + // absl::uniform_int_distribution stability relies on + // absl::random_internal::LeadingSetBit, std::log, std::pow. + absl::random_internal::sequence_urbg urbg( + {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull}); + + std::vector<int> output(6); + + { + absl::log_uniform_int_distribution<int32_t> dist(0, 256); + std::generate(std::begin(output), std::end(output), + [&] { return dist(urbg); }); + EXPECT_THAT(output, ElementsAre(256, 66, 4, 6, 57, 103)); + } + urbg.reset(); + { + absl::log_uniform_int_distribution<int32_t> dist(0, 256, 10); + std::generate(std::begin(output), std::end(output), + [&] { return dist(urbg); }); + EXPECT_THAT(output, ElementsAre(8, 4, 0, 0, 0, 69)); + } +} + +} // namespace diff --git a/absl/random/poisson_distribution.h b/absl/random/poisson_distribution.h new file mode 100644 index 000000000000..7750b1c9ba59 --- /dev/null +++ b/absl/random/poisson_distribution.h @@ -0,0 +1,254 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_POISSON_DISTRIBUTION_H_ +#define ABSL_RANDOM_POISSON_DISTRIBUTION_H_ + +#include <cassert> +#include <cmath> +#include <istream> +#include <limits> +#include <ostream> +#include <type_traits> + +#include "absl/random/internal/distribution_impl.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/fastmath.h" +#include "absl/random/internal/iostream_state_saver.h" + +namespace absl { + +// absl::poisson_distribution: +// Generates discrete variates conforming to a Poisson distribution. +// p(n) = (mean^n / n!) exp(-mean) +// +// Depending on the parameter, the distribution selects one of the following +// algorithms: +// * The standard algorithm, attributed to Knuth, extended using a split method +// for larger values +// * The "Ratio of Uniforms as a convenient method for sampling from classical +// discrete distributions", Stadlober, 1989. +// http://www.sciencedirect.com/science/article/pii/0377042790903495 +// +// NOTE: param_type.mean() is a double, which permits values larger than +// poisson_distribution<IntType>::max(), however this should be avoided and +// the distribution results are limited to the max() value. +// +// The goals of this implementation are to provide good performance while still +// beig thread-safe: This limits the implementation to not using lgamma provided +// by <math.h>. +// +template <typename IntType = int> +class poisson_distribution { + public: + using result_type = IntType; + + class param_type { + public: + using distribution_type = poisson_distribution; + explicit param_type(double mean = 1.0); + + double mean() const { return mean_; } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.mean_ == b.mean_; + } + + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + friend class poisson_distribution; + + double mean_; + double emu_; // e ^ -mean_ + double lmu_; // ln(mean_) + double s_; + double log_k_; + int split_; + + static_assert(std::is_integral<IntType>::value, + "Class-template absl::poisson_distribution<> must be " + "parameterized using an integral type."); + }; + + poisson_distribution() : poisson_distribution(1.0) {} + + explicit poisson_distribution(double mean) : param_(mean) {} + + explicit poisson_distribution(const param_type& p) : param_(p) {} + + void reset() {} + + // generating functions + template <typename URBG> + result_type operator()(URBG& g) { // NOLINT(runtime/references) + return (*this)(g, param_); + } + + template <typename URBG> + result_type operator()(URBG& g, // NOLINT(runtime/references) + const param_type& p); + + param_type param() const { return param_; } + void param(const param_type& p) { param_ = p; } + + result_type(min)() const { return 0; } + result_type(max)() const { return (std::numeric_limits<result_type>::max)(); } + + double mean() const { return param_.mean(); } + + friend bool operator==(const poisson_distribution& a, + const poisson_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const poisson_distribution& a, + const poisson_distribution& b) { + return a.param_ != b.param_; + } + + private: + param_type param_; + random_internal::FastUniformBits<uint64_t> fast_u64_; +}; + +// ----------------------------------------------------------------------------- +// Implementation details follow +// ----------------------------------------------------------------------------- + +template <typename IntType> +poisson_distribution<IntType>::param_type::param_type(double mean) + : mean_(mean), split_(0) { + assert(mean >= 0); + assert(mean <= (std::numeric_limits<result_type>::max)()); + // As a defensive measure, avoid large values of the mean. The rejection + // algorithm used does not support very large values well. It my be worth + // changing algorithms to better deal with these cases. + assert(mean <= 1e10); + if (mean_ < 10) { + // For small lambda, use the knuth method. + split_ = 1; + emu_ = std::exp(-mean_); + } else if (mean_ <= 50) { + // Use split-knuth method. + split_ = 1 + static_cast<int>(mean_ / 10.0); + emu_ = std::exp(-mean_ / static_cast<double>(split_)); + } else { + // Use ratio of uniforms method. + constexpr double k2E = 0.7357588823428846; + constexpr double kSA = 0.4494580810294493; + + lmu_ = std::log(mean_); + double a = mean_ + 0.5; + s_ = kSA + std::sqrt(k2E * a); + const double mode = std::ceil(mean_) - 1; + log_k_ = lmu_ * mode - absl::random_internal::StirlingLogFactorial(mode); + } +} + +template <typename IntType> +template <typename URBG> +typename poisson_distribution<IntType>::result_type +poisson_distribution<IntType>::operator()( + URBG& g, // NOLINT(runtime/references) + const param_type& p) { + using random_internal::PositiveValueT; + using random_internal::RandU64ToDouble; + using random_internal::SignedValueT; + + if (p.split_ != 0) { + // Use Knuth's algorithm with range splitting to avoid floating-point + // errors. Knuth's algorithm is: Ui is a sequence of uniform variates on + // (0,1); return the number of variates required for product(Ui) < + // exp(-lambda). + // + // The expected number of variates required for Knuth's method can be + // computed as follows: + // The expected value of U is 0.5, so solving for 0.5^n < exp(-lambda) gives + // the expected number of uniform variates + // required for a given lambda, which is: + // lambda = [2, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17] + // n = [3, 8, 13, 15, 16, 18, 19, 21, 22, 24, 25] + // + result_type n = 0; + for (int split = p.split_; split > 0; --split) { + double r = 1.0; + do { + r *= RandU64ToDouble<PositiveValueT, true>(fast_u64_(g)); + ++n; + } while (r > p.emu_); + --n; + } + return n; + } + + // Use ratio of uniforms method. + // + // Let u ~ Uniform(0, 1), v ~ Uniform(-1, 1), + // a = lambda + 1/2, + // s = 1.5 - sqrt(3/e) + sqrt(2(lambda + 1/2)/e), + // x = s * v/u + a. + // P(floor(x) = k | u^2 < f(floor(x))/k), where + // f(m) = lambda^m exp(-lambda)/ m!, for 0 <= m, and f(m) = 0 otherwise, + // and k = max(f). + const double a = p.mean_ + 0.5; + for (;;) { + const double u = + RandU64ToDouble<PositiveValueT, false>(fast_u64_(g)); // (0, 1) + const double v = + RandU64ToDouble<SignedValueT, false>(fast_u64_(g)); // (-1, 1) + const double x = std::floor(p.s_ * v / u + a); + if (x < 0) continue; // f(negative) = 0 + const double rhs = x * p.lmu_; + // clang-format off + double s = (x <= 1.0) ? 0.0 + : (x == 2.0) ? 0.693147180559945 + : absl::random_internal::StirlingLogFactorial(x); + // clang-format on + const double lhs = 2.0 * std::log(u) + p.log_k_ + s; + if (lhs < rhs) { + return x > (max)() ? (max)() + : static_cast<result_type>(x); // f(x)/k >= u^2 + } + } +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const poisson_distribution<IntType>& x) { + auto saver = random_internal::make_ostream_state_saver(os); + os.precision(random_internal::stream_precision_helper<double>::kPrecision); + os << x.mean(); + return os; +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + poisson_distribution<IntType>& x) { // NOLINT(runtime/references) + using param_type = typename poisson_distribution<IntType>::param_type; + + auto saver = random_internal::make_istream_state_saver(is); + double mean = random_internal::read_floating_point<double>(is); + if (!is.fail()) { + x.param(param_type(mean)); + } + return is; +} + +} // namespace absl + +#endif // ABSL_RANDOM_POISSON_DISTRIBUTION_H_ diff --git a/absl/random/poisson_distribution_test.cc b/absl/random/poisson_distribution_test.cc new file mode 100644 index 000000000000..6d68999a3c54 --- /dev/null +++ b/absl/random/poisson_distribution_test.cc @@ -0,0 +1,565 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/poisson_distribution.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <random> +#include <sstream> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/base/macros.h" +#include "absl/container/flat_hash_map.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/strip.h" + +// Notes about generating poisson variates: +// +// It is unlikely that any implementation of std::poisson_distribution +// will be stable over time and across library implementations. For instance +// the three different poisson variate generators listed below all differ: +// +// https://github.com/ampl/gsl/tree/master/randist/poisson.c +// * GSL uses a gamma + binomial + knuth method to compute poisson variates. +// +// https://github.com/gcc-mirror/gcc/blob/master/libstdc%2B%2B-v3/include/bits/random.tcc +// * GCC uses the Devroye rejection algorithm, based on +// Devroye, L. Non-Uniform Random Variates Generation. Springer-Verlag, +// New York, 1986, Ch. X, Sects. 3.3 & 3.4 (+ Errata!), ~p.511 +// http://www.nrbook.com/devroye/ +// +// https://github.com/llvm-mirror/libcxx/blob/master/include/random +// * CLANG uses a different rejection method, which appears to include a +// normal-distribution approximation and an exponential distribution to +// compute the threshold, including a similar factorial approximation to this +// one, but it is unclear where the algorithm comes from, exactly. +// + +namespace { + +using absl::random_internal::kChiSquared; + +// The PoissonDistributionInterfaceTest provides a basic test that +// absl::poisson_distribution conforms to the interface and serialization +// requirements imposed by [rand.req.dist] for the common integer types. + +template <typename IntType> +class PoissonDistributionInterfaceTest : public ::testing::Test {}; + +using IntTypes = ::testing::Types<int, int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t>; +TYPED_TEST_CASE(PoissonDistributionInterfaceTest, IntTypes); + +TYPED_TEST(PoissonDistributionInterfaceTest, SerializeTest) { + using param_type = typename absl::poisson_distribution<TypeParam>::param_type; + const double kMax = + std::min(1e10 /* assertion limit */, + static_cast<double>(std::numeric_limits<TypeParam>::max())); + + const double kParams[] = { + // Cases around 1. + 1, // + std::nextafter(1.0, 0.0), // 1 - epsilon + std::nextafter(1.0, 2.0), // 1 + epsilon + // Arbitrary values. + 1e-8, 1e-4, + 0.0000005, // ~7.2e-7 + 0.2, // ~0.2x + 0.5, // 0.72 + 2, // ~2.8 + 20, // 3x ~9.6 + 100, 1e4, 1e8, 1.5e9, 1e20, + // Boundary cases. + std::numeric_limits<double>::max(), + std::numeric_limits<double>::epsilon(), + std::nextafter(std::numeric_limits<double>::min(), + 1.0), // min + epsilon + std::numeric_limits<double>::min(), // smallest normal + std::numeric_limits<double>::denorm_min(), // smallest denorm + std::numeric_limits<double>::min() / 2, // denorm + std::nextafter(std::numeric_limits<double>::min(), + 0.0), // denorm_max + }; + + + constexpr int kCount = 1000; + absl::InsecureBitGen gen; + for (const double m : kParams) { + const double mean = std::min(kMax, m); + const param_type param(mean); + + // Validate parameters. + absl::poisson_distribution<TypeParam> before(mean); + EXPECT_EQ(before.mean(), param.mean()); + + { + absl::poisson_distribution<TypeParam> via_param(param); + EXPECT_EQ(via_param, before); + EXPECT_EQ(via_param.param(), before.param()); + } + + // Smoke test. + auto sample_min = before.max(); + auto sample_max = before.min(); + for (int i = 0; i < kCount; i++) { + auto sample = before(gen); + EXPECT_GE(sample, before.min()); + EXPECT_LE(sample, before.max()); + if (sample > sample_max) sample_max = sample; + if (sample < sample_min) sample_min = sample; + } + + ABSL_INTERNAL_LOG(INFO, absl::StrCat("Range {", param.mean(), "}: ", + +sample_min, ", ", +sample_max)); + + // Validate stream serialization. + std::stringstream ss; + ss << before; + + absl::poisson_distribution<TypeParam> after(3.8); + + EXPECT_NE(before.mean(), after.mean()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + + EXPECT_EQ(before.mean(), after.mean()) // + << ss.str() << " " // + << (ss.good() ? "good " : "") // + << (ss.bad() ? "bad " : "") // + << (ss.eof() ? "eof " : "") // + << (ss.fail() ? "fail " : ""); + } +} + +// See http://www.itl.nist.gov/div898/handbook/eda/section3/eda366j.htm + +class PoissonModel { + public: + explicit PoissonModel(double mean) : mean_(mean) {} + + double mean() const { return mean_; } + double variance() const { return mean_; } + double stddev() const { return std::sqrt(variance()); } + double skew() const { return 1.0 / mean_; } + double kurtosis() const { return 3.0 + 1.0 / mean_; } + + // InitCDF() initializes the CDF for the distribution parameters. + void InitCDF(); + + // The InverseCDF, or the Percent-point function returns x, P(x) < v. + struct CDF { + size_t index; + double pmf; + double cdf; + }; + CDF InverseCDF(double p) { + CDF target{0, 0, p}; + auto it = std::upper_bound( + std::begin(cdf_), std::end(cdf_), target, + [](const CDF& a, const CDF& b) { return a.cdf < b.cdf; }); + return *it; + } + + void LogCDF() { + ABSL_INTERNAL_LOG(INFO, absl::StrCat("CDF (mean = ", mean_, ")")); + for (const auto c : cdf_) { + ABSL_INTERNAL_LOG(INFO, + absl::StrCat(c.index, ": pmf=", c.pmf, " cdf=", c.cdf)); + } + } + + private: + const double mean_; + + std::vector<CDF> cdf_; +}; + +// The goal is to compute an InverseCDF function, or percent point function for +// the poisson distribution, and use that to partition our output into equal +// range buckets. However there is no closed form solution for the inverse cdf +// for poisson distributions (the closest is the incomplete gamma function). +// Instead, `InitCDF` iteratively computes the PMF and the CDF. This enables +// searching for the bucket points. +void PoissonModel::InitCDF() { + if (!cdf_.empty()) { + // State already initialized. + return; + } + ABSL_ASSERT(mean_ < 201.0); + + const size_t max_i = 50 * stddev() + mean(); + const double e_neg_mean = std::exp(-mean()); + ABSL_ASSERT(e_neg_mean > 0); + + double d = 1; + double last_result = e_neg_mean; + double cumulative = e_neg_mean; + if (e_neg_mean > 1e-10) { + cdf_.push_back({0, e_neg_mean, cumulative}); + } + for (size_t i = 1; i < max_i; i++) { + d *= (mean() / i); + double result = e_neg_mean * d; + cumulative += result; + if (result < 1e-10 && result < last_result && cumulative > 0.999999) { + break; + } + if (result > 1e-7) { + cdf_.push_back({i, result, cumulative}); + } + last_result = result; + } + ABSL_ASSERT(!cdf_.empty()); +} + +// PoissonDistributionZTest implements a z-test for the poisson distribution. + +struct ZParam { + double mean; + double p_fail; // Z-Test probability of failure. + int trials; // Z-Test trials. + size_t samples; // Z-Test samples. +}; + +class PoissonDistributionZTest : public testing::TestWithParam<ZParam>, + public PoissonModel { + public: + PoissonDistributionZTest() : PoissonModel(GetParam().mean) {} + + // ZTestImpl provides a basic z-squared test of the mean vs. expected + // mean for data generated by the poisson distribution. + template <typename D> + bool SingleZTest(const double p, const size_t samples); + + absl::InsecureBitGen rng_; +}; + +template <typename D> +bool PoissonDistributionZTest::SingleZTest(const double p, + const size_t samples) { + D dis(mean()); + + absl::flat_hash_map<int32_t, int> buckets; + std::vector<double> data; + data.reserve(samples); + for (int j = 0; j < samples; j++) { + const auto x = dis(rng_); + buckets[x]++; + data.push_back(x); + } + + // The null-hypothesis is that the distribution is a poisson distribution with + // the provided mean (not estimated from the data). + const auto m = absl::random_internal::ComputeDistributionMoments(data); + const double max_err = absl::random_internal::MaxErrorTolerance(p); + const double z = absl::random_internal::ZScore(mean(), m); + const bool pass = absl::random_internal::Near("z", z, 0.0, max_err); + + if (!pass) { + ABSL_INTERNAL_LOG( + INFO, absl::StrFormat("p=%f max_err=%f\n" + " mean=%f vs. %f\n" + " stddev=%f vs. %f\n" + " skewness=%f vs. %f\n" + " kurtosis=%f vs. %f\n" + " z=%f", + p, max_err, m.mean, mean(), std::sqrt(m.variance), + stddev(), m.skewness, skew(), m.kurtosis, + kurtosis(), z)); + } + return pass; +} + +TEST_P(PoissonDistributionZTest, AbslPoissonDistribution) { + const auto& param = GetParam(); + const int expected_failures = + std::max(1, static_cast<int>(std::ceil(param.trials * param.p_fail))); + const double p = absl::random_internal::RequiredSuccessProbability( + param.p_fail, param.trials); + + int failures = 0; + for (int i = 0; i < param.trials; i++) { + failures += + SingleZTest<absl::poisson_distribution<int32_t>>(p, param.samples) ? 0 + : 1; + } + EXPECT_LE(failures, expected_failures); +} + +std::vector<ZParam> GetZParams() { + // These values have been adjusted from the "exact" computed values to reduce + // failure rates. + // + // It turns out that the actual values are not as close to the expected values + // as would be ideal. + return std::vector<ZParam>({ + // Knuth method. + ZParam{0.5, 0.01, 100, 1000}, + ZParam{1.0, 0.01, 100, 1000}, + ZParam{10.0, 0.01, 100, 5000}, + // Split-knuth method. + ZParam{20.0, 0.01, 100, 10000}, + ZParam{50.0, 0.01, 100, 10000}, + // Ratio of gaussians method. + ZParam{51.0, 0.01, 100, 10000}, + ZParam{200.0, 0.05, 10, 100000}, + ZParam{100000.0, 0.05, 10, 1000000}, + }); +} + +std::string ZParamName(const ::testing::TestParamInfo<ZParam>& info) { + const auto& p = info.param; + std::string name = absl::StrCat("mean_", absl::SixDigits(p.mean)); + return absl::StrReplaceAll(name, {{"+", "_"}, {"-", "_"}, {".", "_"}}); +} + +INSTANTIATE_TEST_SUITE_P(, PoissonDistributionZTest, + ::testing::ValuesIn(GetZParams()), ZParamName); + +// The PoissonDistributionChiSquaredTest class provides a basic test framework +// for variates generated by a conforming poisson_distribution. +class PoissonDistributionChiSquaredTest : public testing::TestWithParam<double>, + public PoissonModel { + public: + PoissonDistributionChiSquaredTest() : PoissonModel(GetParam()) {} + + // The ChiSquaredTestImpl provides a chi-squared goodness of fit test for data + // generated by the poisson distribution. + template <typename D> + double ChiSquaredTestImpl(); + + private: + void InitChiSquaredTest(const double buckets); + + absl::InsecureBitGen rng_; + std::vector<size_t> cutoffs_; + std::vector<double> expected_; +}; + +void PoissonDistributionChiSquaredTest::InitChiSquaredTest( + const double buckets) { + if (!cutoffs_.empty() && !expected_.empty()) { + return; + } + InitCDF(); + + // The code below finds cuttoffs that yield approximately equally-sized + // buckets to the extent that it is possible. However for poisson + // distributions this is particularly challenging for small mean parameters. + // Track the expected proportion of items in each bucket. + double last_cdf = 0; + const double inc = 1.0 / buckets; + for (double p = inc; p <= 1.0; p += inc) { + auto result = InverseCDF(p); + if (!cutoffs_.empty() && cutoffs_.back() == result.index) { + continue; + } + double d = result.cdf - last_cdf; + cutoffs_.push_back(result.index); + expected_.push_back(d); + last_cdf = result.cdf; + } + cutoffs_.push_back(std::numeric_limits<size_t>::max()); + expected_.push_back(std::max(0.0, 1.0 - last_cdf)); +} + +template <typename D> +double PoissonDistributionChiSquaredTest::ChiSquaredTestImpl() { + const int kSamples = 2000; + const int kBuckets = 50; + + // The poisson CDF fails for large mean values, since e^-mean exceeds the + // machine precision. For these cases, using a normal approximation would be + // appropriate. + ABSL_ASSERT(mean() <= 200); + InitChiSquaredTest(kBuckets); + + D dis(mean()); + + std::vector<int32_t> counts(cutoffs_.size(), 0); + for (int j = 0; j < kSamples; j++) { + const size_t x = dis(rng_); + auto it = std::lower_bound(std::begin(cutoffs_), std::end(cutoffs_), x); + counts[std::distance(cutoffs_.begin(), it)]++; + } + + // Normalize the counts. + std::vector<int32_t> e(expected_.size(), 0); + for (int i = 0; i < e.size(); i++) { + e[i] = kSamples * expected_[i]; + } + + // The null-hypothesis is that the distribution is a poisson distribution with + // the provided mean (not estimated from the data). + const int dof = static_cast<int>(counts.size()) - 1; + + // The threshold for logging is 1-in-50. + const double threshold = absl::random_internal::ChiSquareValue(dof, 0.98); + + const double chi_square = absl::random_internal::ChiSquare( + std::begin(counts), std::end(counts), std::begin(e), std::end(e)); + + const double p = absl::random_internal::ChiSquarePValue(chi_square, dof); + + // Log if the chi_squared value is above the threshold. + if (chi_square > threshold) { + LogCDF(); + + ABSL_INTERNAL_LOG(INFO, absl::StrCat("VALUES buckets=", counts.size(), + " samples=", kSamples)); + for (size_t i = 0; i < counts.size(); i++) { + ABSL_INTERNAL_LOG( + INFO, absl::StrCat(cutoffs_[i], ": ", counts[i], " vs. E=", e[i])); + } + + ABSL_INTERNAL_LOG( + INFO, + absl::StrCat(kChiSquared, "(data, dof=", dof, ") = ", chi_square, " (", + p, ")\n", " vs.\n", kChiSquared, " @ 0.98 = ", threshold)); + } + return p; +} + +TEST_P(PoissonDistributionChiSquaredTest, AbslPoissonDistribution) { + const int kTrials = 20; + + // Large values are not yet supported -- this requires estimating the cdf + // using the normal distribution instead of the poisson in this case. + ASSERT_LE(mean(), 200.0); + if (mean() > 200.0) { + return; + } + + int failures = 0; + for (int i = 0; i < kTrials; i++) { + double p_value = ChiSquaredTestImpl<absl::poisson_distribution<int32_t>>(); + if (p_value < 0.005) { + failures++; + } + } + // There is a 0.10% chance of producing at least one failure, so raise the + // failure threshold high enough to allow for a flake rate < 10,000. + EXPECT_LE(failures, 4); +} + +INSTANTIATE_TEST_SUITE_P(, PoissonDistributionChiSquaredTest, + ::testing::Values(0.5, 1.0, 2.0, 10.0, 50.0, 51.0, + 200.0)); + +// NOTE: absl::poisson_distribution is not guaranteed to be stable. +TEST(PoissonDistributionTest, StabilityTest) { + using testing::ElementsAre; + // absl::poisson_distribution stability relies on stability of + // std::exp, std::log, std::sqrt, std::ceil, std::floor, and + // absl::FastUniformBits, absl::StirlingLogFactorial, absl::RandU64ToDouble. + absl::random_internal::sequence_urbg urbg({ + 0x035b0dc7e0a18acfull, 0x06cebe0d2653682eull, 0x0061e9b23861596bull, + 0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull, + 0x4864f22c059bf29eull, 0x247856d8b862665cull, 0xe46e86e9a1337e10ull, + 0xd8c8541f3519b133ull, 0xe75b5162c567b9e4ull, 0xf732e5ded7009c5bull, + 0xb170b98353121eacull, 0x1ec2e8986d2362caull, 0x814c8e35fe9a961aull, + 0x0c3cd59c9b638a02ull, 0xcb3bb6478a07715cull, 0x1224e62c978bbc7full, + 0x671ef2cb04e81f6eull, 0x3c1cbd811eaf1808ull, 0x1bbc23cfa8fac721ull, + 0xa4c2cda65e596a51ull, 0xb77216fad37adf91ull, 0x836d794457c08849ull, + 0xe083df03475f49d7ull, 0xbc9feb512e6b0d6cull, 0xb12d74fdd718c8c5ull, + 0x12ff09653bfbe4caull, 0x8dd03a105bc4ee7eull, 0x5738341045ba0d85ull, + 0xf3fd722dc65ad09eull, 0xfa14fd21ea2a5705ull, 0xffe6ea4d6edb0c73ull, + 0xD07E9EFE2BF11FB4ull, 0x95DBDA4DAE909198ull, 0xEAAD8E716B93D5A0ull, + 0xD08ED1D0AFC725E0ull, 0x8E3C5B2F8E7594B7ull, 0x8FF6E2FBF2122B64ull, + 0x8888B812900DF01Cull, 0x4FAD5EA0688FC31Cull, 0xD1CFF191B3A8C1ADull, + 0x2F2F2218BE0E1777ull, 0xEA752DFE8B021FA1ull, 0xE5A0CC0FB56F74E8ull, + 0x18ACF3D6CE89E299ull, 0xB4A84FE0FD13E0B7ull, 0x7CC43B81D2ADA8D9ull, + 0x165FA26680957705ull, 0x93CC7314211A1477ull, 0xE6AD206577B5FA86ull, + 0xC75442F5FB9D35CFull, 0xEBCDAF0C7B3E89A0ull, 0xD6411BD3AE1E7E49ull, + 0x00250E2D2071B35Eull, 0x226800BB57B8E0AFull, 0x2464369BF009B91Eull, + 0x5563911D59DFA6AAull, 0x78C14389D95A537Full, 0x207D5BA202E5B9C5ull, + 0x832603766295CFA9ull, 0x11C819684E734A41ull, 0xB3472DCA7B14A94Aull, + }); + + std::vector<int> output(10); + + // Method 1. + { + absl::poisson_distribution<int> dist(5); + std::generate(std::begin(output), std::end(output), + [&] { return dist(urbg); }); + } + EXPECT_THAT(output, // mean = 4.2 + ElementsAre(1, 0, 0, 4, 2, 10, 3, 3, 7, 12)); + + // Method 2. + { + urbg.reset(); + absl::poisson_distribution<int> dist(25); + std::generate(std::begin(output), std::end(output), + [&] { return dist(urbg); }); + } + EXPECT_THAT(output, // mean = 19.8 + ElementsAre(9, 35, 18, 10, 35, 18, 10, 35, 18, 10)); + + // Method 3. + { + urbg.reset(); + absl::poisson_distribution<int> dist(121); + std::generate(std::begin(output), std::end(output), + [&] { return dist(urbg); }); + } + EXPECT_THAT(output, // mean = 124.1 + ElementsAre(161, 122, 129, 124, 112, 112, 117, 120, 130, 114)); +} + +TEST(PoissonDistributionTest, AlgorithmExpectedValue_1) { + // This tests small values of the Knuth method. + // The underlying uniform distribution will generate exactly 0.5. + absl::random_internal::sequence_urbg urbg({0x8000000000000001ull}); + absl::poisson_distribution<int> dist(5); + EXPECT_EQ(7, dist(urbg)); +} + +TEST(PoissonDistributionTest, AlgorithmExpectedValue_2) { + // This tests larger values of the Knuth method. + // The underlying uniform distribution will generate exactly 0.5. + absl::random_internal::sequence_urbg urbg({0x8000000000000001ull}); + absl::poisson_distribution<int> dist(25); + EXPECT_EQ(36, dist(urbg)); +} + +TEST(PoissonDistributionTest, AlgorithmExpectedValue_3) { + // This variant uses the ratio of uniforms method. + absl::random_internal::sequence_urbg urbg( + {0x7fffffffffffffffull, 0x8000000000000000ull}); + + absl::poisson_distribution<int> dist(121); + EXPECT_EQ(121, dist(urbg)); +} + +} // namespace diff --git a/absl/random/random.h b/absl/random/random.h new file mode 100644 index 000000000000..dc6852f4adf3 --- /dev/null +++ b/absl/random/random.h @@ -0,0 +1,187 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: random.h +// ----------------------------------------------------------------------------- +// +// This header defines the recommended Uniform Random Bit Generator (URBG) +// types for use within the Abseil Random library. These types are not +// suitable for security-related use-cases, but should suffice for most other +// uses of generating random values. +// +// The Abseil random library provides the following URBG types: +// +// * BitGen, a good general-purpose bit generator, optimized for generating +// random (but not cryptographically secure) values +// * InsecureBitGen, a slightly faster, though less random, bit generator, for +// cases where the existing BitGen is a drag on performance. + +#ifndef ABSL_RANDOM_RANDOM_H_ +#define ABSL_RANDOM_RANDOM_H_ + +#include <random> + +#include "absl/random/distributions.h" // IWYU pragma: export +#include "absl/random/internal/nonsecure_base.h" // IWYU pragma: export +#include "absl/random/internal/pcg_engine.h" // IWYU pragma: export +#include "absl/random/internal/pool_urbg.h" +#include "absl/random/internal/randen_engine.h" +#include "absl/random/seed_sequences.h" // IWYU pragma: export + +namespace absl { + +// ----------------------------------------------------------------------------- +// absl::BitGen +// ----------------------------------------------------------------------------- +// +// `absl::BitGen` is a general-purpose random bit generator for generating +// random values for use within the Abseil random library. Typically, you use a +// bit generator in combination with a distribution to provide random values. +// +// Example: +// +// // Create an absl::BitGen. There is no need to seed this bit generator. +// absl::BitGen gen; +// +// // Generate an integer value in the closed interval [1,6] +// int die_roll = absl::uniform_int_distribution<int>(1, 6)(gen); +// +// `absl::BitGen` is seeded by default with non-deterministic data to produce +// different sequences of random values across different instances, including +// different binary invocations. This behavior is different than the standard +// library bit generators, which use golden values as their seeds. Default +// construction intentionally provides no stability guarantees, to avoid +// accidental dependence on such a property. +// +// `absl::BitGen` may be constructed with an optional seed sequence type, +// conforming to [rand.req.seed_seq], which will be mixed with additional +// non-deterministic data. +// +// Example: +// +// // Create an absl::BitGen using an std::seed_seq seed sequence +// std::seed_seq seq{1,2,3}; +// absl::BitGen gen_with_seed(seq); +// +// // Generate an integer value in the closed interval [1,6] +// int die_roll2 = absl::uniform_int_distribution<int>(1, 6)(gen_with_seed); +// +// `absl::BitGen` meets the requirements of the Uniform Random Bit Generator +// (URBG) concept as per the C++17 standard [rand.req.urng] though differs +// slightly with [rand.req.eng]. Like its standard library equivalents (e.g. +// `std::mersenne_twister_engine`) `absl::BitGen` is not cryptographically +// secure. +// +// Constructing two `absl::BitGen`s with the same seed sequence in the same +// binary will produce the same sequence of variates within the same binary, but +// need not do so across multiple binary invocations. +// +// This type has been optimized to perform better than Mersenne Twister +// (https://en.wikipedia.org/wiki/Mersenne_Twister) and many other complex URBG +// types on modern x86, ARM, and PPC architectures. +// +// This type is thread-compatible, but not thread-safe. + +// --------------------------------------------------------------------------- +// absl::BitGen member functions +// --------------------------------------------------------------------------- + +// absl::BitGen::operator()() +// +// Calls the BitGen, returning a generated value. + +// absl::BitGen::min() +// +// Returns the smallest possible value from this bit generator. + +// absl::BitGen::max() +// +// Returns the largest possible value from this bit generator., and + +// absl::BitGen::discard(num) +// +// Advances the internal state of this bit generator by `num` times, and +// discards the intermediate results. +// --------------------------------------------------------------------------- + +using BitGen = random_internal::NonsecureURBGBase< + random_internal::randen_engine<uint64_t>>; + +// ----------------------------------------------------------------------------- +// absl::InsecureBitGen +// ----------------------------------------------------------------------------- +// +// `absl::InsecureBitGen` is an efficient random bit generator for generating +// random values, recommended only for performance-sensitive use cases where +// `absl::BitGen` is not satisfactory when compute-bounded by bit generation +// costs. +// +// Example: +// +// // Create an absl::InsecureBitGen +// absl::InsecureBitGen gen; +// for (size_t i = 0; i < 1000000; i++) { +// +// // Generate a bunch of random values from some complex distribution +// auto my_rnd = some_distribution(gen, 1, 1000); +// } +// +// Like `absl::BitGen`, `absl::InsecureBitGen` is seeded by default with +// non-deterministic data to produce different sequences of random values across +// different instances, including different binary invocations. (This behavior +// is different than the standard library bit generators, which use golden +// values as their seeds.) +// +// `absl::InsecureBitGen` may be constructed with an optional seed sequence +// type, conforming to [rand.req.seed_seq], which will be mixed with additional +// non-deterministic data. (See std_seed_seq.h for more information.) +// +// `absl::InsecureBitGen` meets the requirements of the Uniform Random Bit +// Generator (URBG) concept as per the C++17 standard [rand.req.urng] though +// its implementation differs slightly with [rand.req.eng]. Like its standard +// library equivalents (e.g. `std::mersenne_twister_engine`) +// `absl::InsecureBitGen` is not cryptographically secure. +// +// Prefer `absl::BitGen` over `absl::InsecureBitGen` as the general type is +// often fast enough for the vast majority of applications. + +using InsecureBitGen = + random_internal::NonsecureURBGBase<random_internal::pcg64_2018_engine>; + +// --------------------------------------------------------------------------- +// absl::InsecureBitGen member functions +// --------------------------------------------------------------------------- + +// absl::InsecureBitGen::operator()() +// +// Calls the InsecureBitGen, returning a generated value. + +// absl::InsecureBitGen::min() +// +// Returns the smallest possible value from this bit generator. + +// absl::InsecureBitGen::max() +// +// Returns the largest possible value from this bit generator. + +// absl::InsecureBitGen::discard(num) +// +// Advances the internal state of this bit generator by `num` times, and +// discards the intermediate results. +// --------------------------------------------------------------------------- + +} // namespace absl + +#endif // ABSL_RANDOM_RANDOM_H_ diff --git a/absl/random/seed_gen_exception.cc b/absl/random/seed_gen_exception.cc new file mode 100644 index 000000000000..e4271baaa07f --- /dev/null +++ b/absl/random/seed_gen_exception.cc @@ -0,0 +1,44 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/seed_gen_exception.h" + +#include <iostream> + +#include "absl/base/config.h" + +namespace absl { + +static constexpr const char kExceptionMessage[] = + "Failed generating seed-material for URBG."; + +SeedGenException::~SeedGenException() = default; + +const char* SeedGenException::what() const noexcept { + return kExceptionMessage; +} + +namespace random_internal { + +void ThrowSeedGenException() { +#ifdef ABSL_HAVE_EXCEPTIONS + throw absl::SeedGenException(); +#else + std::cerr << kExceptionMessage << std::endl; + std::terminate(); +#endif +} + +} // namespace random_internal +} // namespace absl diff --git a/absl/random/seed_gen_exception.h b/absl/random/seed_gen_exception.h new file mode 100644 index 000000000000..b464d52fb954 --- /dev/null +++ b/absl/random/seed_gen_exception.h @@ -0,0 +1,51 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: seed_gen_exception.h +// ----------------------------------------------------------------------------- +// +// This header defines an exception class which may be thrown if unpredictable +// events prevent the derivation of suitable seed-material for constructing a +// bit generator conforming to [rand.req.urng] (eg. entropy cannot be read from +// /dev/urandom on a Unix-based system). +// +// Note: if exceptions are disabled, `std::terminate()` is called instead. + +#ifndef ABSL_RANDOM_SEED_GEN_EXCEPTION_H_ +#define ABSL_RANDOM_SEED_GEN_EXCEPTION_H_ + +#include <exception> + +namespace absl { + +//------------------------------------------------------------------------------ +// SeedGenException +//------------------------------------------------------------------------------ +class SeedGenException : public std::exception { + public: + SeedGenException() = default; + ~SeedGenException() override; + const char* what() const noexcept override; +}; + +namespace random_internal { + +// throw delegator +[[noreturn]] void ThrowSeedGenException(); + +} // namespace random_internal +} // namespace absl + +#endif // ABSL_RANDOM_SEED_GEN_EXCEPTION_H_ diff --git a/absl/random/seed_sequences.cc b/absl/random/seed_sequences.cc new file mode 100644 index 000000000000..9f319615289a --- /dev/null +++ b/absl/random/seed_sequences.cc @@ -0,0 +1,27 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/seed_sequences.h" + +#include "absl/random/internal/pool_urbg.h" + +namespace absl { + +SeedSeq MakeSeedSeq() { + SeedSeq::result_type seed_material[8]; + random_internal::RandenPool<uint32_t>::Fill(absl::MakeSpan(seed_material)); + return SeedSeq(std::begin(seed_material), std::end(seed_material)); +} + +} // namespace absl diff --git a/absl/random/seed_sequences.h b/absl/random/seed_sequences.h new file mode 100644 index 000000000000..631d1ecd4ea9 --- /dev/null +++ b/absl/random/seed_sequences.h @@ -0,0 +1,108 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: seed_sequences.h +// ----------------------------------------------------------------------------- +// +// This header contains utilities for creating and working with seed sequences +// conforming to [rand.req.seedseq]. In general, direct construction of seed +// sequences is discouraged, but use-cases for construction of identical bit +// generators (using the same seed sequence) may be helpful (e.g. replaying a +// simulation whose state is derived from variates of a bit generator). + +#ifndef ABSL_RANDOM_SEED_SEQUENCES_H_ +#define ABSL_RANDOM_SEED_SEQUENCES_H_ + +#include <iterator> +#include <random> + +#include "absl/random/internal/salted_seed_seq.h" +#include "absl/random/internal/seed_material.h" +#include "absl/random/seed_gen_exception.h" +#include "absl/types/span.h" + +namespace absl { + +// ----------------------------------------------------------------------------- +// absl::SeedSeq +// ----------------------------------------------------------------------------- +// +// `absl::SeedSeq` constructs a seed sequence according to [rand.req.seedseq] +// for use within bit generators. `absl::SeedSeq`, unlike `std::seed_seq` +// additionally salts the generated seeds with extra implementation-defined +// entropy. For that reason, you can use `absl::SeedSeq` in combination with +// standard library bit generators (e.g. `std::mt19937`) to introduce +// non-determinism in your seeds. +// +// Example: +// +// absl::SeedSeq my_seed_seq({a, b, c}); +// std::mt19937 my_bitgen(my_seed_seq); +// +using SeedSeq = random_internal::SaltedSeedSeq<std::seed_seq>; + +// ----------------------------------------------------------------------------- +// absl::CreateSeedSeqFrom(bitgen*) +// ----------------------------------------------------------------------------- +// +// Constructs a seed sequence conforming to [rand.req.seedseq] using variates +// produced by a provided bit generator. +// +// You should generally avoid direct construction of seed sequences, but +// use-cases for reuse of a seed sequence to construct identical bit generators +// may be helpful (eg. replaying a simulation whose state is derived from bit +// generator values). +// +// If bitgen == nullptr, then behavior is undefined. +// +// Example: +// +// absl::BitGen my_bitgen; +// auto seed_seq = absl::CreateSeedSeqFrom(&my_bitgen); +// absl::BitGen new_engine(seed_seq); // derived from my_bitgen, but not +// // correlated. +// +template <typename URBG> +SeedSeq CreateSeedSeqFrom(URBG* urbg) { + SeedSeq::result_type + seed_material[random_internal::kEntropyBlocksNeeded]; + + if (!random_internal::ReadSeedMaterialFromURBG( + urbg, absl::MakeSpan(seed_material))) { + random_internal::ThrowSeedGenException(); + } + return SeedSeq(std::begin(seed_material), std::end(seed_material)); +} + +// ----------------------------------------------------------------------------- +// absl::MakeSeedSeq() +// ----------------------------------------------------------------------------- +// +// Constructs an `absl::SeedSeq` salting the generated values using +// implementation-defined entropy. The returned sequence can be used to create +// equivalent bit generators correlated using this sequence. +// +// Example: +// +// auto my_seed_seq = absl::MakeSeedSeq(); +// std::mt19937 rng1(my_seed_seq); +// std::mt19937 rng2(my_seed_seq); +// EXPECT_EQ(rng1(), rng2()); +// +SeedSeq MakeSeedSeq(); + +} // namespace absl + +#endif // ABSL_RANDOM_SEED_SEQUENCES_H_ diff --git a/absl/random/seed_sequences_test.cc b/absl/random/seed_sequences_test.cc new file mode 100644 index 000000000000..2cc8b0e6f2b7 --- /dev/null +++ b/absl/random/seed_sequences_test.cc @@ -0,0 +1,127 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/seed_sequences.h" + +#include <iterator> +#include <random> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/random/internal/nonsecure_base.h" +#include "absl/random/random.h" +namespace { + +TEST(SeedSequences, Examples) { + { + absl::SeedSeq seed_seq({1, 2, 3}); + absl::BitGen bitgen(seed_seq); + + EXPECT_NE(0, bitgen()); + } + { + absl::BitGen engine; + auto seed_seq = absl::CreateSeedSeqFrom(&engine); + absl::BitGen bitgen(seed_seq); + + EXPECT_NE(engine(), bitgen()); + } + { + auto seed_seq = absl::MakeSeedSeq(); + std::mt19937 random(seed_seq); + + EXPECT_NE(0, random()); + } +} + +TEST(CreateSeedSeqFrom, CompatibleWithStdTypes) { + using ExampleNonsecureURBG = + absl::random_internal::NonsecureURBGBase<std::minstd_rand0>; + + // Construct a URBG instance. + ExampleNonsecureURBG rng; + + // Construct a Seed Sequence from its variates. + auto seq_from_rng = absl::CreateSeedSeqFrom(&rng); + + // Ensure that another URBG can be validly constructed from the Seed Sequence. + std::mt19937_64{seq_from_rng}; +} + +TEST(CreateSeedSeqFrom, CompatibleWithBitGenerator) { + // Construct a URBG instance. + absl::BitGen rng; + + // Construct a Seed Sequence from its variates. + auto seq_from_rng = absl::CreateSeedSeqFrom(&rng); + + // Ensure that another URBG can be validly constructed from the Seed Sequence. + std::mt19937_64{seq_from_rng}; +} + +TEST(CreateSeedSeqFrom, CompatibleWithInsecureBitGen) { + // Construct a URBG instance. + absl::InsecureBitGen rng; + + // Construct a Seed Sequence from its variates. + auto seq_from_rng = absl::CreateSeedSeqFrom(&rng); + + // Ensure that another URBG can be validly constructed from the Seed Sequence. + std::mt19937_64{seq_from_rng}; +} + +TEST(CreateSeedSeqFrom, CompatibleWithRawURBG) { + // Construct a URBG instance. + std::random_device urandom; + + // Construct a Seed Sequence from its variates, using 64b of seed-material. + auto seq_from_rng = absl::CreateSeedSeqFrom(&urandom); + + // Ensure that another URBG can be validly constructed from the Seed Sequence. + std::mt19937_64{seq_from_rng}; +} + +template <typename URBG> +void TestReproducibleVariateSequencesForNonsecureURBG() { + const size_t kNumVariates = 1000; + + // Master RNG instance. + URBG rng; + // Reused for both RNG instances. + auto reusable_seed = absl::CreateSeedSeqFrom(&rng); + + typename URBG::result_type variates[kNumVariates]; + { + URBG child(reusable_seed); + for (auto& variate : variates) { + variate = child(); + } + } + // Ensure that variate-sequence can be "replayed" by identical RNG. + { + URBG child(reusable_seed); + for (auto& variate : variates) { + ASSERT_EQ(variate, child()); + } + } +} + +TEST(CreateSeedSeqFrom, ReproducesVariateSequencesForInsecureBitGen) { + TestReproducibleVariateSequencesForNonsecureURBG<absl::InsecureBitGen>(); +} + +TEST(CreateSeedSeqFrom, ReproducesVariateSequencesForBitGenerator) { + TestReproducibleVariateSequencesForNonsecureURBG<absl::BitGen>(); +} +} // namespace diff --git a/absl/random/uniform_int_distribution.h b/absl/random/uniform_int_distribution.h new file mode 100644 index 000000000000..4970486a561a --- /dev/null +++ b/absl/random/uniform_int_distribution.h @@ -0,0 +1,273 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: uniform_int_distribution.h +// ----------------------------------------------------------------------------- +// +// This header defines a class for representing a uniform integer distribution +// over the closed (inclusive) interval [a,b]. You use this distribution in +// combination with an Abseil random bit generator to produce random values +// according to the rules of the distribution. +// +// `absl::uniform_int_distribution` is a drop-in replacement for the C++11 +// `std::uniform_int_distribution` [rand.dist.uni.int] but is considerably +// faster than the libstdc++ implementation. + +#ifndef ABSL_RANDOM_UNIFORM_INT_DISTRIBUTION_H_ +#define ABSL_RANDOM_UNIFORM_INT_DISTRIBUTION_H_ + +#include <cassert> +#include <istream> +#include <limits> +#include <type_traits> + +#include "absl/base/optimization.h" +#include "absl/random/internal/distribution_impl.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/iostream_state_saver.h" +#include "absl/random/internal/traits.h" + +namespace absl { + +// absl::uniform_int_distribution<T> +// +// This distribution produces random integer values uniformly distributed in the +// closed (inclusive) interval [a, b]. +// +// Example: +// +// absl::BitGen gen; +// +// // Use the distribution to produce a value between 1 and 6, inclusive. +// int die_roll = absl::uniform_int_distribution<int>(1, 6)(gen); +// +template <typename IntType = int> +class uniform_int_distribution { + private: + using unsigned_type = + typename random_internal::make_unsigned_bits<IntType>::type; + + public: + using result_type = IntType; + + class param_type { + public: + using distribution_type = uniform_int_distribution; + + explicit param_type( + result_type lo = 0, + result_type hi = (std::numeric_limits<result_type>::max)()) + : lo_(lo), + range_(static_cast<unsigned_type>(hi) - + static_cast<unsigned_type>(lo)) { + // [rand.dist.uni.int] precondition 2 + assert(lo <= hi); + } + + result_type a() const { return lo_; } + result_type b() const { + return static_cast<result_type>(static_cast<unsigned_type>(lo_) + range_); + } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.lo_ == b.lo_ && a.range_ == b.range_; + } + + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + friend class uniform_int_distribution; + unsigned_type range() const { return range_; } + + result_type lo_; + unsigned_type range_; + + static_assert(std::is_integral<result_type>::value, + "Class-template absl::uniform_int_distribution<> must be " + "parameterized using an integral type."); + }; // param_type + + uniform_int_distribution() : uniform_int_distribution(0) {} + + explicit uniform_int_distribution( + result_type lo, + result_type hi = (std::numeric_limits<result_type>::max)()) + : param_(lo, hi) {} + + explicit uniform_int_distribution(const param_type& param) : param_(param) {} + + // uniform_int_distribution<T>::reset() + // + // Resets the uniform int distribution. Note that this function has no effect + // because the distribution already produces independent values. + void reset() {} + + template <typename URBG> + result_type operator()(URBG& gen) { // NOLINT(runtime/references) + return (*this)(gen, param()); + } + + template <typename URBG> + result_type operator()( + URBG& gen, const param_type& param) { // NOLINT(runtime/references) + return param.a() + Generate(gen, param.range()); + } + + result_type a() const { return param_.a(); } + result_type b() const { return param_.b(); } + + param_type param() const { return param_; } + void param(const param_type& params) { param_ = params; } + + result_type(min)() const { return a(); } + result_type(max)() const { return b(); } + + friend bool operator==(const uniform_int_distribution& a, + const uniform_int_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const uniform_int_distribution& a, + const uniform_int_distribution& b) { + return !(a == b); + } + + private: + // Generates a value in the *closed* interval [0, R] + template <typename URBG> + unsigned_type Generate(URBG& g, // NOLINT(runtime/references) + unsigned_type R); + param_type param_; +}; + +// ----------------------------------------------------------------------------- +// Implementation details follow +// ----------------------------------------------------------------------------- +template <typename CharT, typename Traits, typename IntType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, + const uniform_int_distribution<IntType>& x) { + using stream_type = + typename random_internal::stream_format_type<IntType>::type; + auto saver = random_internal::make_ostream_state_saver(os); + os << static_cast<stream_type>(x.a()) << os.fill() + << static_cast<stream_type>(x.b()); + return os; +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, + uniform_int_distribution<IntType>& x) { + using param_type = typename uniform_int_distribution<IntType>::param_type; + using result_type = typename uniform_int_distribution<IntType>::result_type; + using stream_type = + typename random_internal::stream_format_type<IntType>::type; + + stream_type a; + stream_type b; + + auto saver = random_internal::make_istream_state_saver(is); + is >> a >> b; + if (!is.fail()) { + x.param( + param_type(static_cast<result_type>(a), static_cast<result_type>(b))); + } + return is; +} + +template <typename IntType> +template <typename URBG> +typename random_internal::make_unsigned_bits<IntType>::type +uniform_int_distribution<IntType>::Generate( + URBG& g, // NOLINT(runtime/references) + typename random_internal::make_unsigned_bits<IntType>::type R) { + random_internal::FastUniformBits<unsigned_type> fast_bits; + unsigned_type bits = fast_bits(g); + const unsigned_type Lim = R + 1; + if ((R & Lim) == 0) { + // If the interval's length is a power of two range, just take the low bits. + return bits & R; + } + + // Generates a uniform variate on [0, Lim) using fixed-point multiplication. + // The above fast-path guarantees that Lim is representable in unsigned_type. + // + // Algorithm adapted from + // http://lemire.me/blog/2016/06/30/fast-random-shuffling/, with added + // explanation. + // + // The algorithm creates a uniform variate `bits` in the interval [0, 2^N), + // and treats it as the fractional part of a fixed-point real value in [0, 1), + // multiplied by 2^N. For example, 0.25 would be represented as 2^(N - 2), + // because 2^N * 0.25 == 2^(N - 2). + // + // Next, `bits` and `Lim` are multiplied with a wide-multiply to bring the + // value into the range [0, Lim). The integral part (the high word of the + // multiplication result) is then very nearly the desired result. However, + // this is not quite accurate; viewing the multiplication result as one + // double-width integer, the resulting values for the sample are mapped as + // follows: + // + // If the result lies in this interval: Return this value: + // [0, 2^N) 0 + // [2^N, 2 * 2^N) 1 + // ... ... + // [K * 2^N, (K + 1) * 2^N) K + // ... ... + // [(Lim - 1) * 2^N, Lim * 2^N) Lim - 1 + // + // While all of these intervals have the same size, the result of `bits * Lim` + // must be a multiple of `Lim`, and not all of these intervals contain the + // same number of multiples of `Lim`. In particular, some contain + // `F = floor(2^N / Lim)` and some contain `F + 1 = ceil(2^N / Lim)`. This + // difference produces a small nonuniformity, which is corrected by applying + // rejection sampling to one of the values in the "larger intervals" (i.e., + // the intervals containing `F + 1` multiples of `Lim`. + // + // An interval contains `F + 1` multiples of `Lim` if and only if its smallest + // value modulo 2^N is less than `2^N % Lim`. The unique value satisfying + // this property is used as the one for rejection. That is, a value of + // `bits * Lim` is rejected if `(bit * Lim) % 2^N < (2^N % Lim)`. + + using helper = random_internal::wide_multiply<unsigned_type>; + auto product = helper::multiply(bits, Lim); + + // Two optimizations here: + // * Rejection occurs with some probability less than 1/2, and for reasonable + // ranges considerably less (in particular, less than 1/(F+1)), so + // ABSL_PREDICT_FALSE is apt. + // * `Lim` is an overestimate of `threshold`, and doesn't require a divide. + if (ABSL_PREDICT_FALSE(helper::lo(product) < Lim)) { + // This quantity is exactly equal to `2^N % Lim`, but does not require high + // precision calculations: `2^N % Lim` is congruent to `(2^N - Lim) % Lim`. + // Ideally this could be expressed simply as `-X` rather than `2^N - X`, but + // for types smaller than int, this calculation is incorrect due to integer + // promotion rules. + const unsigned_type threshold = + ((std::numeric_limits<unsigned_type>::max)() - Lim + 1) % Lim; + while (helper::lo(product) < threshold) { + bits = fast_bits(g); + product = helper::multiply(bits, Lim); + } + } + + return helper::hi(product); +} + +} // namespace absl + +#endif // ABSL_RANDOM_UNIFORM_INT_DISTRIBUTION_H_ diff --git a/absl/random/uniform_int_distribution_test.cc b/absl/random/uniform_int_distribution_test.cc new file mode 100644 index 000000000000..aacff88d3866 --- /dev/null +++ b/absl/random/uniform_int_distribution_test.cc @@ -0,0 +1,250 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/uniform_int_distribution.h" + +#include <cmath> +#include <cstdint> +#include <iterator> +#include <random> +#include <sstream> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" + +namespace { + +template <typename IntType> +class UniformIntDistributionTest : public ::testing::Test {}; + +using IntTypes = ::testing::Types<int8_t, uint8_t, int16_t, uint16_t, int32_t, + uint32_t, int64_t, uint64_t>; +TYPED_TEST_SUITE(UniformIntDistributionTest, IntTypes); + +TYPED_TEST(UniformIntDistributionTest, ParamSerializeTest) { + // This test essentially ensures that the parameters serialize, + // not that the values generated cover the full range. + using Limits = std::numeric_limits<TypeParam>; + using param_type = + typename absl::uniform_int_distribution<TypeParam>::param_type; + const TypeParam kMin = std::is_unsigned<TypeParam>::value ? 37 : -105; + const TypeParam kNegOneOrZero = std::is_unsigned<TypeParam>::value ? 0 : -1; + + constexpr int kCount = 1000; + absl::InsecureBitGen gen; + for (const auto& param : { + param_type(), + param_type(2, 2), // Same + param_type(9, 32), + param_type(kMin, 115), + param_type(kNegOneOrZero, Limits::max()), + param_type(Limits::min(), Limits::max()), + param_type(Limits::lowest(), Limits::max()), + param_type(Limits::min() + 1, Limits::max() - 1), + }) { + const auto a = param.a(); + const auto b = param.b(); + absl::uniform_int_distribution<TypeParam> before(a, b); + EXPECT_EQ(before.a(), param.a()); + EXPECT_EQ(before.b(), param.b()); + + { + // Initialize via param_type + absl::uniform_int_distribution<TypeParam> via_param(param); + EXPECT_EQ(via_param, before); + } + + // Initialize via iostreams + std::stringstream ss; + ss << before; + + absl::uniform_int_distribution<TypeParam> after(Limits::min() + 3, + Limits::max() - 5); + + EXPECT_NE(before.a(), after.a()); + EXPECT_NE(before.b(), after.b()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + + EXPECT_EQ(before.a(), after.a()); + EXPECT_EQ(before.b(), after.b()); + EXPECT_EQ(before.param(), after.param()); + EXPECT_EQ(before, after); + + // Smoke test. + auto sample_min = after.max(); + auto sample_max = after.min(); + for (int i = 0; i < kCount; i++) { + auto sample = after(gen); + EXPECT_GE(sample, after.min()); + EXPECT_LE(sample, after.max()); + if (sample > sample_max) { + sample_max = sample; + } + if (sample < sample_min) { + sample_min = sample; + } + } + std::string msg = absl::StrCat("Range: ", +sample_min, ", ", +sample_max); + ABSL_RAW_LOG(INFO, "%s", msg.c_str()); + } +} + +TYPED_TEST(UniformIntDistributionTest, ViolatesPreconditionsDeathTest) { +#if GTEST_HAS_DEATH_TEST + // Hi < Lo + EXPECT_DEBUG_DEATH({ absl::uniform_int_distribution<TypeParam> dist(10, 1); }, + ""); +#endif // GTEST_HAS_DEATH_TEST +#if defined(NDEBUG) + // opt-mode, for invalid parameters, will generate a garbage value, + // but should not enter an infinite loop. + absl::InsecureBitGen gen; + absl::uniform_int_distribution<TypeParam> dist(10, 1); + auto x = dist(gen); + + // Any value will generate a non-empty std::string. + EXPECT_FALSE(absl::StrCat(+x).empty()) << x; +#endif // NDEBUG +} + +TYPED_TEST(UniformIntDistributionTest, TestMoments) { + constexpr int kSize = 100000; + using Limits = std::numeric_limits<TypeParam>; + using param_type = + typename absl::uniform_int_distribution<TypeParam>::param_type; + + absl::InsecureBitGen rng; + std::vector<double> values(kSize); + for (const auto& param : + {param_type(0, Limits::max()), param_type(13, 127)}) { + absl::uniform_int_distribution<TypeParam> dist(param); + for (int i = 0; i < kSize; i++) { + const auto sample = dist(rng); + ASSERT_LE(dist.param().a(), sample); + ASSERT_GE(dist.param().b(), sample); + values[i] = sample; + } + + auto moments = absl::random_internal::ComputeDistributionMoments(values); + const double a = dist.param().a(); + const double b = dist.param().b(); + const double n = (b - a + 1); + const double mean = (a + b) / 2; + const double var = ((b - a + 1) * (b - a + 1) - 1) / 12; + const double kurtosis = 3 - 6 * (n * n + 1) / (5 * (n * n - 1)); + + // TODO(ahh): this is not the right bound + // empirically validated with --runs_per_test=10000. + EXPECT_NEAR(mean, moments.mean, 0.01 * var); + EXPECT_NEAR(var, moments.variance, 0.015 * var); + EXPECT_NEAR(0.0, moments.skewness, 0.025); + EXPECT_NEAR(kurtosis, moments.kurtosis, 0.02 * kurtosis); + } +} + +TYPED_TEST(UniformIntDistributionTest, ChiSquaredTest50) { + using absl::random_internal::kChiSquared; + + constexpr size_t kTrials = 1000; + constexpr int kBuckets = 50; // inclusive, so actally +1 + constexpr double kExpected = + static_cast<double>(kTrials) / static_cast<double>(kBuckets); + + // Empirically validated with --runs_per_test=10000. + const int kThreshold = + absl::random_internal::ChiSquareValue(kBuckets, 0.999999); + + const TypeParam min = std::is_unsigned<TypeParam>::value ? 37 : -37; + const TypeParam max = min + kBuckets; + + absl::InsecureBitGen rng; + absl::uniform_int_distribution<TypeParam> dist(min, max); + + std::vector<int32_t> counts(kBuckets + 1, 0); + for (size_t i = 0; i < kTrials; i++) { + auto x = dist(rng); + counts[x - min]++; + } + double chi_square = absl::random_internal::ChiSquareWithExpected( + std::begin(counts), std::end(counts), kExpected); + if (chi_square > kThreshold) { + double p_value = + absl::random_internal::ChiSquarePValue(chi_square, kBuckets); + + // Chi-squared test failed. Output does not appear to be uniform. + std::string msg; + for (const auto& a : counts) { + absl::StrAppend(&msg, a, "\n"); + } + absl::StrAppend(&msg, kChiSquared, " p-value ", p_value, "\n"); + absl::StrAppend(&msg, "High ", kChiSquared, " value: ", chi_square, " > ", + kThreshold); + ABSL_RAW_LOG(INFO, "%s", msg.c_str()); + FAIL() << msg; + } +} + +TEST(UniformIntDistributionTest, StabilityTest) { + // absl::uniform_int_distribution stability relies only on integer operations. + absl::random_internal::sequence_urbg urbg( + {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull}); + + std::vector<int> output(12); + + { + absl::uniform_int_distribution<int32_t> dist(0, 4); + for (auto& v : output) { + v = dist(urbg); + } + } + EXPECT_EQ(12, urbg.invocations()); + EXPECT_THAT(output, testing::ElementsAre(4, 4, 3, 2, 1, 0, 1, 4, 3, 1, 3, 1)); + + { + urbg.reset(); + absl::uniform_int_distribution<int32_t> dist(0, 100); + for (auto& v : output) { + v = dist(urbg); + } + } + EXPECT_EQ(12, urbg.invocations()); + EXPECT_THAT(output, testing::ElementsAre(97, 86, 75, 41, 36, 16, 38, 92, 67, + 30, 80, 38)); + + { + urbg.reset(); + absl::uniform_int_distribution<int32_t> dist(0, 10000); + for (auto& v : output) { + v = dist(urbg); + } + } + EXPECT_EQ(12, urbg.invocations()); + EXPECT_THAT(output, testing::ElementsAre(9648, 8562, 7439, 4089, 3571, 1602, + 3813, 9195, 6641, 2986, 7956, 3765)); +} + +} // namespace diff --git a/absl/random/uniform_real_distribution.h b/absl/random/uniform_real_distribution.h new file mode 100644 index 000000000000..600f915b67af --- /dev/null +++ b/absl/random/uniform_real_distribution.h @@ -0,0 +1,193 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: uniform_real_distribution.h +// ----------------------------------------------------------------------------- +// +// This header defines a class for representing a uniform floating-point +// distribution over a half-open interval [a,b). You use this distribution in +// combination with an Abseil random bit generator to produce random values +// according to the rules of the distribution. +// +// `absl::uniform_real_distribution` is a drop-in replacement for the C++11 +// `std::uniform_real_distribution` [rand.dist.uni.real] but is considerably +// faster than the libstdc++ implementation. +// +// Note: the standard-library version may occasionally return `1.0` when +// default-initialized. See https://bugs.llvm.org//show_bug.cgi?id=18767 +// `absl::uniform_real_distribution` does not exhibit this behavior. + +#ifndef ABSL_RANDOM_UNIFORM_REAL_DISTRIBUTION_H_ +#define ABSL_RANDOM_UNIFORM_REAL_DISTRIBUTION_H_ + +#include <cassert> +#include <cmath> +#include <cstdint> +#include <istream> +#include <limits> +#include <type_traits> + +#include "absl/random/internal/distribution_impl.h" +#include "absl/random/internal/fast_uniform_bits.h" +#include "absl/random/internal/iostream_state_saver.h" + +namespace absl { + +// absl::uniform_real_distribution<T> +// +// This distribution produces random floating-point values uniformly distributed +// over the half-open interval [a, b). +// +// Example: +// +// absl::BitGen gen; +// +// // Use the distribution to produce a value between 0.0 (inclusive) +// // and 1.0 (exclusive). +// int value = absl::uniform_real_distribution<double>(0, 1)(gen); +// +template <typename RealType = double> +class uniform_real_distribution { + public: + using result_type = RealType; + + class param_type { + public: + using distribution_type = uniform_real_distribution; + + explicit param_type(result_type lo = 0, result_type hi = 1) + : lo_(lo), hi_(hi), range_(hi - lo) { + // [rand.dist.uni.real] preconditions 2 & 3 + assert(lo <= hi); + // NOTE: For integral types, we can promote the range to an unsigned type, + // which gives full width of the range. However for real (fp) types, this + // is not possible, so value generation cannot use the full range of the + // real type. + assert(range_ <= (std::numeric_limits<result_type>::max)()); + } + + result_type a() const { return lo_; } + result_type b() const { return hi_; } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.lo_ == b.lo_ && a.hi_ == b.hi_; + } + + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + friend class uniform_real_distribution; + result_type lo_, hi_, range_; + + static_assert(std::is_floating_point<RealType>::value, + "Class-template absl::uniform_real_distribution<> must be " + "parameterized using a floating-point type."); + }; + + uniform_real_distribution() : uniform_real_distribution(0) {} + + explicit uniform_real_distribution(result_type lo, result_type hi = 1) + : param_(lo, hi) {} + + explicit uniform_real_distribution(const param_type& param) : param_(param) {} + + // uniform_real_distribution<T>::reset() + // + // Resets the uniform real distribution. Note that this function has no effect + // because the distribution already produces independent values. + void reset() {} + + template <typename URBG> + result_type operator()(URBG& gen) { // NOLINT(runtime/references) + return operator()(gen, param_); + } + + template <typename URBG> + result_type operator()(URBG& gen, // NOLINT(runtime/references) + const param_type& p); + + result_type a() const { return param_.a(); } + result_type b() const { return param_.b(); } + + param_type param() const { return param_; } + void param(const param_type& params) { param_ = params; } + + result_type(min)() const { return a(); } + result_type(max)() const { return b(); } + + friend bool operator==(const uniform_real_distribution& a, + const uniform_real_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const uniform_real_distribution& a, + const uniform_real_distribution& b) { + return a.param_ != b.param_; + } + + private: + param_type param_; + random_internal::FastUniformBits<uint64_t> fast_u64_; +}; + +// ----------------------------------------------------------------------------- +// Implementation details follow +// ----------------------------------------------------------------------------- +template <typename RealType> +template <typename URBG> +typename uniform_real_distribution<RealType>::result_type +uniform_real_distribution<RealType>::operator()( + URBG& gen, const param_type& p) { // NOLINT(runtime/references) + using random_internal::PositiveValueT; + while (true) { + const result_type sample = random_internal::RandU64ToReal< + result_type>::template Value<PositiveValueT, true>(fast_u64_(gen)); + const result_type res = p.a() + (sample * p.range_); + if (res < p.b() || p.range_ <= 0 || !std::isfinite(p.range_)) { + return res; + } + // else sample rejected, try again. + } +} + +template <typename CharT, typename Traits, typename RealType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const uniform_real_distribution<RealType>& x) { + auto saver = random_internal::make_ostream_state_saver(os); + os.precision(random_internal::stream_precision_helper<RealType>::kPrecision); + os << x.a() << os.fill() << x.b(); + return os; +} + +template <typename CharT, typename Traits, typename RealType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + uniform_real_distribution<RealType>& x) { // NOLINT(runtime/references) + using param_type = typename uniform_real_distribution<RealType>::param_type; + using result_type = typename uniform_real_distribution<RealType>::result_type; + auto saver = random_internal::make_istream_state_saver(is); + auto a = random_internal::read_floating_point<result_type>(is); + if (is.fail()) return is; + auto b = random_internal::read_floating_point<result_type>(is); + if (!is.fail()) { + x.param(param_type(a, b)); + } + return is; +} +} // namespace absl + +#endif // ABSL_RANDOM_UNIFORM_REAL_DISTRIBUTION_H_ diff --git a/absl/random/uniform_real_distribution_test.cc b/absl/random/uniform_real_distribution_test.cc new file mode 100644 index 000000000000..597f0ee5efa5 --- /dev/null +++ b/absl/random/uniform_real_distribution_test.cc @@ -0,0 +1,322 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/uniform_real_distribution.h" + +#include <cmath> +#include <cstdint> +#include <iterator> +#include <random> +#include <sstream> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/distribution_test_util.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" + +// NOTES: +// * Some documentation on generating random real values suggests that +// it is possible to use std::nextafter(b, DBL_MAX) to generate a value on +// the closed range [a, b]. Unfortunately, that technique is not universally +// reliable due to floating point quantization. +// +// * absl::uniform_real_distribution<float> generates between 2^28 and 2^29 +// distinct floating point values in the range [0, 1). +// +// * absl::uniform_real_distribution<float> generates at least 2^23 distinct +// floating point values in the range [1, 2). This should be the same as +// any other range covered by a single exponent in IEEE 754. +// +// * absl::uniform_real_distribution<double> generates more than 2^52 distinct +// values in the range [0, 1), and should generate at least 2^52 distinct +// values in the range of [1, 2). +// + +namespace { + +template <typename RealType> +class UniformRealDistributionTest : public ::testing::Test {}; + +using RealTypes = ::testing::Types<float, double, long double>; +TYPED_TEST_SUITE(UniformRealDistributionTest, RealTypes); + +TYPED_TEST(UniformRealDistributionTest, ParamSerializeTest) { + using param_type = + typename absl::uniform_real_distribution<TypeParam>::param_type; + + constexpr const TypeParam a{1152921504606846976}; + + constexpr int kCount = 1000; + absl::InsecureBitGen gen; + for (const auto& param : { + param_type(), + param_type(TypeParam(2.0), TypeParam(2.0)), // Same + param_type(TypeParam(-0.1), TypeParam(0.1)), + param_type(TypeParam(0.05), TypeParam(0.12)), + param_type(TypeParam(-0.05), TypeParam(0.13)), + param_type(TypeParam(-0.05), TypeParam(-0.02)), + // double range = 0 + // 2^60 , 2^60 + 2^6 + param_type(a, TypeParam(1152921504606847040)), + // 2^60 , 2^60 + 2^7 + param_type(a, TypeParam(1152921504606847104)), + // double range = 2^8 + // 2^60 , 2^60 + 2^8 + param_type(a, TypeParam(1152921504606847232)), + // float range = 0 + // 2^60 , 2^60 + 2^36 + param_type(a, TypeParam(1152921573326323712)), + // 2^60 , 2^60 + 2^37 + param_type(a, TypeParam(1152921642045800448)), + // float range = 2^38 + // 2^60 , 2^60 + 2^38 + param_type(a, TypeParam(1152921779484753920)), + // Limits + param_type(0, std::numeric_limits<TypeParam>::max()), + param_type(std::numeric_limits<TypeParam>::lowest(), 0), + param_type(0, std::numeric_limits<TypeParam>::epsilon()), + param_type(-std::numeric_limits<TypeParam>::epsilon(), + std::numeric_limits<TypeParam>::epsilon()), + param_type(std::numeric_limits<TypeParam>::epsilon(), + 2 * std::numeric_limits<TypeParam>::epsilon()), + }) { + // Validate parameters. + const auto a = param.a(); + const auto b = param.b(); + absl::uniform_real_distribution<TypeParam> before(a, b); + EXPECT_EQ(before.a(), param.a()); + EXPECT_EQ(before.b(), param.b()); + + { + absl::uniform_real_distribution<TypeParam> via_param(param); + EXPECT_EQ(via_param, before); + } + + std::stringstream ss; + ss << before; + absl::uniform_real_distribution<TypeParam> after(TypeParam(1.0), + TypeParam(3.1)); + + EXPECT_NE(before.a(), after.a()); + EXPECT_NE(before.b(), after.b()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + + EXPECT_EQ(before.a(), after.a()); + EXPECT_EQ(before.b(), after.b()); + EXPECT_EQ(before.param(), after.param()); + EXPECT_EQ(before, after); + + // Smoke test. + auto sample_min = after.max(); + auto sample_max = after.min(); + for (int i = 0; i < kCount; i++) { + auto sample = after(gen); + // Failure here indicates a bug in uniform_real_distribution::operator(), + // or bad parameters--range too large, etc. + if (after.min() == after.max()) { + EXPECT_EQ(sample, after.min()); + } else { + EXPECT_GE(sample, after.min()); + EXPECT_LT(sample, after.max()); + } + if (sample > sample_max) { + sample_max = sample; + } + if (sample < sample_min) { + sample_min = sample; + } + } + + if (!std::is_same<TypeParam, long double>::value) { + // static_cast<double>(long double) can overflow. + std::string msg = absl::StrCat("Range: ", static_cast<double>(sample_min), + ", ", static_cast<double>(sample_max)); + ABSL_RAW_LOG(INFO, "%s", msg.c_str()); + } + } +} + +TYPED_TEST(UniformRealDistributionTest, ViolatesPreconditionsDeathTest) { +#if GTEST_HAS_DEATH_TEST + // Hi < Lo + EXPECT_DEBUG_DEATH( + { absl::uniform_real_distribution<TypeParam> dist(10.0, 1.0); }, ""); + + // Hi - Lo > numeric_limits<>::max() + EXPECT_DEBUG_DEATH( + { + absl::uniform_real_distribution<TypeParam> dist( + std::numeric_limits<TypeParam>::lowest(), + std::numeric_limits<TypeParam>::max()); + }, + ""); +#endif // GTEST_HAS_DEATH_TEST +#if defined(NDEBUG) + // opt-mode, for invalid parameters, will generate a garbage value, + // but should not enter an infinite loop. + absl::InsecureBitGen gen; + { + absl::uniform_real_distribution<TypeParam> dist(10.0, 1.0); + auto x = dist(gen); + EXPECT_FALSE(std::isnan(x)) << x; + } + { + absl::uniform_real_distribution<TypeParam> dist( + std::numeric_limits<TypeParam>::lowest(), + std::numeric_limits<TypeParam>::max()); + auto x = dist(gen); + // Infinite result. + EXPECT_FALSE(std::isfinite(x)) << x; + } +#endif // NDEBUG +} + +TYPED_TEST(UniformRealDistributionTest, TestMoments) { + constexpr int kSize = 1000000; + std::vector<double> values(kSize); + + absl::InsecureBitGen rng; + absl::uniform_real_distribution<TypeParam> dist; + for (int i = 0; i < kSize; i++) { + values[i] = dist(rng); + } + + const auto moments = + absl::random_internal::ComputeDistributionMoments(values); + EXPECT_NEAR(0.5, moments.mean, 0.01); + EXPECT_NEAR(1 / 12.0, moments.variance, 0.015); + EXPECT_NEAR(0.0, moments.skewness, 0.02); + EXPECT_NEAR(9 / 5.0, moments.kurtosis, 0.015); +} + +TYPED_TEST(UniformRealDistributionTest, ChiSquaredTest50) { + using absl::random_internal::kChiSquared; + using param_type = + typename absl::uniform_real_distribution<TypeParam>::param_type; + + constexpr size_t kTrials = 100000; + constexpr int kBuckets = 50; + constexpr double kExpected = + static_cast<double>(kTrials) / static_cast<double>(kBuckets); + + // 1-in-100000 threshold, but remember, there are about 8 tests + // in this file. And the test could fail for other reasons. + // Empirically validated with --runs_per_test=10000. + const int kThreshold = + absl::random_internal::ChiSquareValue(kBuckets - 1, 0.999999); + + absl::InsecureBitGen rng; + for (const auto& param : {param_type(0, 1), param_type(5, 12), + param_type(-5, 13), param_type(-5, -2)}) { + const double min_val = param.a(); + const double max_val = param.b(); + const double factor = kBuckets / (max_val - min_val); + + std::vector<int32_t> counts(kBuckets, 0); + absl::uniform_real_distribution<TypeParam> dist(param); + for (size_t i = 0; i < kTrials; i++) { + auto x = dist(rng); + auto bucket = static_cast<size_t>((x - min_val) * factor); + counts[bucket]++; + } + + double chi_square = absl::random_internal::ChiSquareWithExpected( + std::begin(counts), std::end(counts), kExpected); + if (chi_square > kThreshold) { + double p_value = + absl::random_internal::ChiSquarePValue(chi_square, kBuckets); + + // Chi-squared test failed. Output does not appear to be uniform. + std::string msg; + for (const auto& a : counts) { + absl::StrAppend(&msg, a, "\n"); + } + absl::StrAppend(&msg, kChiSquared, " p-value ", p_value, "\n"); + absl::StrAppend(&msg, "High ", kChiSquared, " value: ", chi_square, " > ", + kThreshold); + ABSL_RAW_LOG(INFO, "%s", msg.c_str()); + FAIL() << msg; + } + } +} + +TYPED_TEST(UniformRealDistributionTest, StabilityTest) { + // absl::uniform_real_distribution stability relies only on + // random_internal::RandU64ToDouble and random_internal::RandU64ToFloat. + absl::random_internal::sequence_urbg urbg( + {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull}); + + std::vector<int> output(12); + + absl::uniform_real_distribution<TypeParam> dist; + std::generate(std::begin(output), std::end(output), [&] { + return static_cast<int>(TypeParam(1000000) * dist(urbg)); + }); + + EXPECT_THAT( + output, // + testing::ElementsAre(59, 999246, 762494, 395876, 167716, 82545, 925251, + 77341, 12527, 708791, 834451, 932808)); +} + +TEST(UniformRealDistributionTest, AlgorithmBounds) { + absl::uniform_real_distribution<double> dist; + + { + // This returns the smallest value >0 from absl::uniform_real_distribution. + absl::random_internal::sequence_urbg urbg({0x0000000000000001ull}); + double a = dist(urbg); + EXPECT_EQ(a, 5.42101086242752217004e-20); + } + + { + // This returns a value very near 0.5 from absl::uniform_real_distribution. + absl::random_internal::sequence_urbg urbg({0x7fffffffffffffefull}); + double a = dist(urbg); + EXPECT_EQ(a, 0.499999999999999944489); + } + { + // This returns a value very near 0.5 from absl::uniform_real_distribution. + absl::random_internal::sequence_urbg urbg({0x8000000000000000ull}); + double a = dist(urbg); + EXPECT_EQ(a, 0.5); + } + + { + // This returns the largest value <1 from absl::uniform_real_distribution. + absl::random_internal::sequence_urbg urbg({0xFFFFFFFFFFFFFFEFull}); + double a = dist(urbg); + EXPECT_EQ(a, 0.999999999999999888978); + } + { + // This *ALSO* returns the largest value <1. + absl::random_internal::sequence_urbg urbg({0xFFFFFFFFFFFFFFFFull}); + double a = dist(urbg); + EXPECT_EQ(a, 0.999999999999999888978); + } +} + +} // namespace diff --git a/absl/random/zipf_distribution.h b/absl/random/zipf_distribution.h new file mode 100644 index 000000000000..1e4dba8b4e0c --- /dev/null +++ b/absl/random/zipf_distribution.h @@ -0,0 +1,269 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ABSL_RANDOM_ZIPF_DISTRIBUTION_H_ +#define ABSL_RANDOM_ZIPF_DISTRIBUTION_H_ + +#include <cassert> +#include <cmath> +#include <istream> +#include <limits> +#include <ostream> +#include <type_traits> + +#include "absl/random/internal/iostream_state_saver.h" +#include "absl/random/uniform_real_distribution.h" + +namespace absl { + +// absl::zipf_distribution produces random integer-values in the range [0, k], +// distributed according to the discrete probability function: +// +// P(x) = (v + x) ^ -q +// +// The parameter `v` must be greater than 0 and the parameter `q` must be +// greater than 1. If either of these parameters take invalid values then the +// behavior is undefined. +// +// IntType is the result_type generated by the generator. It must be of integral +// type; a static_assert ensures this is the case. +// +// The implementation is based on W.Hormann, G.Derflinger: +// +// "Rejection-Inversion to Generate Variates from Monotone Discrete +// Distributions" +// +// http://eeyore.wu-wien.ac.at/papers/96-04-04.wh-der.ps.gz +// +template <typename IntType = int> +class zipf_distribution { + public: + using result_type = IntType; + + class param_type { + public: + using distribution_type = zipf_distribution; + + // Preconditions: k > 0, v > 0, q > 1 + // The precondidtions are validated when NDEBUG is not defined via + // a pair of assert() directives. + // If NDEBUG is defined and either or both of these parameters take invalid + // values, the behavior of the class is undefined. + explicit param_type(result_type k = (std::numeric_limits<IntType>::max)(), + double q = 2.0, double v = 1.0); + + result_type k() const { return k_; } + double q() const { return q_; } + double v() const { return v_; } + + friend bool operator==(const param_type& a, const param_type& b) { + return a.k_ == b.k_ && a.q_ == b.q_ && a.v_ == b.v_; + } + friend bool operator!=(const param_type& a, const param_type& b) { + return !(a == b); + } + + private: + friend class zipf_distribution; + inline double h(double x) const; + inline double hinv(double x) const; + inline double compute_s() const; + inline double pow_negative_q(double x) const; + + // Parameters here are exactly the same as the parameters of Algorithm ZRI + // in the paper. + IntType k_; + double q_; + double v_; + + double one_minus_q_; // 1-q + double s_; + double one_minus_q_inv_; // 1 / 1-q + double hxm_; // h(k + 0.5) + double hx0_minus_hxm_; // h(x0) - h(k + 0.5) + + static_assert(std::is_integral<IntType>::value, + "Class-template absl::zipf_distribution<> must be " + "parameterized using an integral type."); + }; + + zipf_distribution() + : zipf_distribution((std::numeric_limits<IntType>::max)()) {} + + explicit zipf_distribution(result_type k, double q = 2.0, double v = 1.0) + : param_(k, q, v) {} + + explicit zipf_distribution(const param_type& p) : param_(p) {} + + void reset() {} + + template <typename URBG> + result_type operator()(URBG& g) { // NOLINT(runtime/references) + return (*this)(g, param_); + } + + template <typename URBG> + result_type operator()(URBG& g, // NOLINT(runtime/references) + const param_type& p); + + result_type k() const { return param_.k(); } + double q() const { return param_.q(); } + double v() const { return param_.v(); } + + param_type param() const { return param_; } + void param(const param_type& p) { param_ = p; } + + result_type(min)() const { return 0; } + result_type(max)() const { return k(); } + + friend bool operator==(const zipf_distribution& a, + const zipf_distribution& b) { + return a.param_ == b.param_; + } + friend bool operator!=(const zipf_distribution& a, + const zipf_distribution& b) { + return a.param_ != b.param_; + } + + private: + param_type param_; +}; + +// -------------------------------------------------------------------------- +// Implementation details follow +// -------------------------------------------------------------------------- + +template <typename IntType> +zipf_distribution<IntType>::param_type::param_type( + typename zipf_distribution<IntType>::result_type k, double q, double v) + : k_(k), q_(q), v_(v), one_minus_q_(1 - q) { + assert(q > 1); + assert(v > 0); + assert(k > 0); + one_minus_q_inv_ = 1 / one_minus_q_; + + // Setup for the ZRI algorithm (pg 17 of the paper). + // Compute: h(i max) => h(k + 0.5) + constexpr double kMax = 18446744073709549568.0; + double kd = static_cast<double>(k); + // TODO(absl-team): Determine if this check is needed, and if so, add a test + // that fails for k > kMax + if (kd > kMax) { + // Ensure that our maximum value is capped to a value which will + // round-trip back through double. + kd = kMax; + } + hxm_ = h(kd + 0.5); + + // Compute: h(0) + const bool use_precomputed = (v == 1.0 && q == 2.0); + const double h0x5 = use_precomputed ? (-1.0 / 1.5) // exp(-log(1.5)) + : h(0.5); + const double elogv_q = (v_ == 1.0) ? 1 : pow_negative_q(v_); + + // h(0) = h(0.5) - exp(log(v) * -q) + hx0_minus_hxm_ = (h0x5 - elogv_q) - hxm_; + + // And s + s_ = use_precomputed ? 0.46153846153846123 : compute_s(); +} + +template <typename IntType> +double zipf_distribution<IntType>::param_type::h(double x) const { + // std::exp(one_minus_q_ * std::log(v_ + x)) * one_minus_q_inv_; + x += v_; + return (one_minus_q_ == -1.0) + ? (-1.0 / x) // -exp(-log(x)) + : (std::exp(std::log(x) * one_minus_q_) * one_minus_q_inv_); +} + +template <typename IntType> +double zipf_distribution<IntType>::param_type::hinv(double x) const { + // std::exp(one_minus_q_inv_ * std::log(one_minus_q_ * x)) - v_; + return -v_ + ((one_minus_q_ == -1.0) + ? (-1.0 / x) // exp(-log(-x)) + : std::exp(one_minus_q_inv_ * std::log(one_minus_q_ * x))); +} + +template <typename IntType> +double zipf_distribution<IntType>::param_type::compute_s() const { + // 1 - hinv(h(1.5) - std::exp(std::log(v_ + 1) * -q_)); + return 1.0 - hinv(h(1.5) - pow_negative_q(v_ + 1.0)); +} + +template <typename IntType> +double zipf_distribution<IntType>::param_type::pow_negative_q(double x) const { + // std::exp(std::log(x) * -q_); + return q_ == 2.0 ? (1.0 / (x * x)) : std::exp(std::log(x) * -q_); +} + +template <typename IntType> +template <typename URBG> +typename zipf_distribution<IntType>::result_type +zipf_distribution<IntType>::operator()( + URBG& g, const param_type& p) { // NOLINT(runtime/references) + absl::uniform_real_distribution<double> uniform_double; + double k; + for (;;) { + const double v = uniform_double(g); + const double u = p.hxm_ + v * p.hx0_minus_hxm_; + const double x = p.hinv(u); + k = rint(x); // std::floor(x + 0.5); + if (k > p.k()) continue; // reject k > max_k + if (k - x <= p.s_) break; + const double h = p.h(k + 0.5); + const double r = p.pow_negative_q(p.v_ + k); + if (u >= h - r) break; + } + IntType ki = static_cast<IntType>(k); + assert(ki <= p.k_); + return ki; +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_ostream<CharT, Traits>& operator<<( + std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) + const zipf_distribution<IntType>& x) { + using stream_type = + typename random_internal::stream_format_type<IntType>::type; + auto saver = random_internal::make_ostream_state_saver(os); + os.precision(random_internal::stream_precision_helper<double>::kPrecision); + os << static_cast<stream_type>(x.k()) << os.fill() << x.q() << os.fill() + << x.v(); + return os; +} + +template <typename CharT, typename Traits, typename IntType> +std::basic_istream<CharT, Traits>& operator>>( + std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) + zipf_distribution<IntType>& x) { // NOLINT(runtime/references) + using result_type = typename zipf_distribution<IntType>::result_type; + using param_type = typename zipf_distribution<IntType>::param_type; + using stream_type = + typename random_internal::stream_format_type<IntType>::type; + stream_type k; + double q; + double v; + + auto saver = random_internal::make_istream_state_saver(is); + is >> k >> q >> v; + if (!is.fail()) { + x.param(param_type(static_cast<result_type>(k), q, v)); + } + return is; +} + +} // namespace absl. + +#endif // ABSL_RANDOM_ZIPF_DISTRIBUTION_H_ diff --git a/absl/random/zipf_distribution_test.cc b/absl/random/zipf_distribution_test.cc new file mode 100644 index 000000000000..4d4a0fcf793b --- /dev/null +++ b/absl/random/zipf_distribution_test.cc @@ -0,0 +1,423 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/random/zipf_distribution.h" + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <random> +#include <string> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/internal/raw_logging.h" +#include "absl/random/internal/chi_square.h" +#include "absl/random/internal/sequence_urbg.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/strip.h" + +namespace { + +using ::absl::random_internal::kChiSquared; +using ::testing::ElementsAre; + +template <typename IntType> +class ZipfDistributionTypedTest : public ::testing::Test {}; + +using IntTypes = ::testing::Types<int, int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t>; +TYPED_TEST_CASE(ZipfDistributionTypedTest, IntTypes); + +TYPED_TEST(ZipfDistributionTypedTest, SerializeTest) { + using param_type = typename absl::zipf_distribution<TypeParam>::param_type; + + constexpr int kCount = 1000; + absl::InsecureBitGen gen; + for (const auto& param : { + param_type(), + param_type(32), + param_type(100, 3, 2), + param_type(std::numeric_limits<TypeParam>::max(), 4, 3), + param_type(std::numeric_limits<TypeParam>::max() / 2), + }) { + // Validate parameters. + const auto k = param.k(); + const auto q = param.q(); + const auto v = param.v(); + + absl::zipf_distribution<TypeParam> before(k, q, v); + EXPECT_EQ(before.k(), param.k()); + EXPECT_EQ(before.q(), param.q()); + EXPECT_EQ(before.v(), param.v()); + + { + absl::zipf_distribution<TypeParam> via_param(param); + EXPECT_EQ(via_param, before); + } + + // Validate stream serialization. + std::stringstream ss; + ss << before; + absl::zipf_distribution<TypeParam> after(4, 5.5, 4.4); + + EXPECT_NE(before.k(), after.k()); + EXPECT_NE(before.q(), after.q()); + EXPECT_NE(before.v(), after.v()); + EXPECT_NE(before.param(), after.param()); + EXPECT_NE(before, after); + + ss >> after; + + EXPECT_EQ(before.k(), after.k()); + EXPECT_EQ(before.q(), after.q()); + EXPECT_EQ(before.v(), after.v()); + EXPECT_EQ(before.param(), after.param()); + EXPECT_EQ(before, after); + + // Smoke test. + auto sample_min = after.max(); + auto sample_max = after.min(); + for (int i = 0; i < kCount; i++) { + auto sample = after(gen); + EXPECT_GE(sample, after.min()); + EXPECT_LE(sample, after.max()); + if (sample > sample_max) sample_max = sample; + if (sample < sample_min) sample_min = sample; + } + ABSL_INTERNAL_LOG(INFO, + absl::StrCat("Range: ", +sample_min, ", ", +sample_max)); + } +} + +class ZipfModel { + public: + ZipfModel(size_t k, double q, double v) : k_(k), q_(q), v_(v) {} + + double mean() const { return mean_; } + + // For the other moments of the Zipf distribution, see, for example, + // http://mathworld.wolfram.com/ZipfDistribution.html + + // PMF(k) = (1 / k^s) / H(N,s) + // Returns the probability that any single invocation returns k. + double PMF(size_t i) { return i >= hnq_.size() ? 0.0 : hnq_[i] / sum_hnq_; } + + // CDF = H(k, s) / H(N,s) + double CDF(size_t i) { + if (i >= hnq_.size()) { + return 1.0; + } + auto it = std::begin(hnq_); + double h = 0.0; + for (const auto end = it; it != end; it++) { + h += *it; + } + return h / sum_hnq_; + } + + // The InverseCDF returns the k values which bound p on the upper and lower + // bound. Since there is no closed-form solution, this is implemented as a + // bisction of the cdf. + std::pair<size_t, size_t> InverseCDF(double p) { + size_t min = 0; + size_t max = hnq_.size(); + while (max > min + 1) { + size_t target = (max + min) >> 1; + double x = CDF(target); + if (x > p) { + max = target; + } else { + min = target; + } + } + return {min, max}; + } + + // Compute the probability totals, which are based on the generalized harmonic + // number, H(N,s). + // H(N,s) == SUM(k=1..N, 1 / k^s) + // + // In the limit, H(N,s) == zetac(s) + 1. + // + // NOTE: The mean of a zipf distribution could be computed here as well. + // Mean := H(N, s-1) / H(N,s). + // Given the parameter v = 1, this gives the following function: + // (Hn(100, 1) - Hn(1,1)) / (Hn(100,2) - Hn(1,2)) = 6.5944 + // + void Init() { + if (!hnq_.empty()) { + return; + } + hnq_.clear(); + hnq_.reserve(std::min(k_, size_t{1000})); + + sum_hnq_ = 0; + double qm1 = q_ - 1.0; + double sum_hnq_m1 = 0; + for (size_t i = 0; i < k_; i++) { + // Partial n-th generalized harmonic number + const double x = v_ + i; + + // H(n, q-1) + const double hnqm1 = + (q_ == 2.0) ? (1.0 / x) + : (q_ == 3.0) ? (1.0 / (x * x)) : std::pow(x, -qm1); + sum_hnq_m1 += hnqm1; + + // H(n, q) + const double hnq = + (q_ == 2.0) ? (1.0 / (x * x)) + : (q_ == 3.0) ? (1.0 / (x * x * x)) : std::pow(x, -q_); + sum_hnq_ += hnq; + hnq_.push_back(hnq); + if (i > 1000 && hnq <= 1e-10) { + // The harmonic number is too small. + break; + } + } + assert(sum_hnq_ > 0); + mean_ = sum_hnq_m1 / sum_hnq_; + } + + private: + const size_t k_; + const double q_; + const double v_; + + double mean_; + std::vector<double> hnq_; + double sum_hnq_; +}; + +using zipf_u64 = absl::zipf_distribution<uint64_t>; + +class ZipfTest : public testing::TestWithParam<zipf_u64::param_type>, + public ZipfModel { + public: + ZipfTest() : ZipfModel(GetParam().k(), GetParam().q(), GetParam().v()) {} + + absl::InsecureBitGen rng_; +}; + +TEST_P(ZipfTest, ChiSquaredTest) { + const auto& param = GetParam(); + Init(); + + size_t trials = 10000; + + // Find the split-points for the buckets. + std::vector<size_t> points; + std::vector<double> expected; + { + double last_cdf = 0.0; + double min_p = 1.0; + for (double p = 0.01; p < 1.0; p += 0.01) { + auto x = InverseCDF(p); + if (points.empty() || points.back() < x.second) { + const double p = CDF(x.second); + points.push_back(x.second); + double q = p - last_cdf; + expected.push_back(q); + last_cdf = p; + if (q < min_p) { + min_p = q; + } + } + } + if (last_cdf < 0.999) { + points.push_back(std::numeric_limits<size_t>::max()); + double q = 1.0 - last_cdf; + expected.push_back(q); + if (q < min_p) { + min_p = q; + } + } else { + points.back() = std::numeric_limits<size_t>::max(); + expected.back() += (1.0 - last_cdf); + } + // The Chi-Squared score is not completely scale-invariant; it works best + // when the small values are in the small digits. + trials = static_cast<size_t>(8.0 / min_p); + } + ASSERT_GT(points.size(), 0); + + // Generate n variates and fill the counts vector with the count of their + // occurrences. + std::vector<int64_t> buckets(points.size(), 0); + double avg = 0; + { + zipf_u64 dis(param); + for (size_t i = 0; i < trials; i++) { + uint64_t x = dis(rng_); + ASSERT_LE(x, dis.max()); + ASSERT_GE(x, dis.min()); + avg += static_cast<double>(x); + auto it = std::upper_bound(std::begin(points), std::end(points), + static_cast<size_t>(x)); + buckets[std::distance(std::begin(points), it)]++; + } + avg = avg / static_cast<double>(trials); + } + + // Validate the output using the Chi-Squared test. + for (auto& e : expected) { + e *= trials; + } + + // The null-hypothesis is that the distribution is a poisson distribution with + // the provided mean (not estimated from the data). + const int dof = static_cast<int>(expected.size()) - 1; + + // NOTE: This test runs about 15x per invocation, so a value of 0.9995 is + // approximately correct for a test suite failure rate of 1 in 100. In + // practice we see failures slightly higher than that. + const double threshold = absl::random_internal::ChiSquareValue(dof, 0.9999); + + const double chi_square = absl::random_internal::ChiSquare( + std::begin(buckets), std::end(buckets), std::begin(expected), + std::end(expected)); + + const double p_actual = + absl::random_internal::ChiSquarePValue(chi_square, dof); + + // Log if the chi_squared value is above the threshold. + if (chi_square > threshold) { + ABSL_INTERNAL_LOG(INFO, "values"); + for (size_t i = 0; i < expected.size(); i++) { + ABSL_INTERNAL_LOG(INFO, absl::StrCat(points[i], ": ", buckets[i], + " vs. E=", expected[i])); + } + ABSL_INTERNAL_LOG(INFO, absl::StrCat("trials ", trials)); + ABSL_INTERNAL_LOG(INFO, + absl::StrCat("mean ", avg, " vs. expected ", mean())); + ABSL_INTERNAL_LOG(INFO, absl::StrCat(kChiSquared, "(data, ", dof, ") = ", + chi_square, " (", p_actual, ")")); + ABSL_INTERNAL_LOG(INFO, + absl::StrCat(kChiSquared, " @ 0.9995 = ", threshold)); + FAIL() << kChiSquared << " value of " << chi_square + << " is above the threshold."; + } +} + +std::vector<zipf_u64::param_type> GenParams() { + using param = zipf_u64::param_type; + const auto k = param().k(); + const auto q = param().q(); + const auto v = param().v(); + const uint64_t k2 = 1 << 10; + return std::vector<zipf_u64::param_type>{ + // Default + param(k, q, v), + // vary K + param(4, q, v), param(1 << 4, q, v), param(k2, q, v), + // vary V + param(k2, q, 0.5), param(k2, q, 1.5), param(k2, q, 2.5), param(k2, q, 10), + // vary Q + param(k2, 1.5, v), param(k2, 3, v), param(k2, 5, v), param(k2, 10, v), + // Vary V & Q + param(k2, 1.5, 0.5), param(k2, 3, 1.5), param(k, 10, 10)}; +} + +std::string ParamName( + const ::testing::TestParamInfo<zipf_u64::param_type>& info) { + const auto& p = info.param; + std::string name = absl::StrCat("k_", p.k(), "__q_", absl::SixDigits(p.q()), + "__v_", absl::SixDigits(p.v())); + return absl::StrReplaceAll(name, {{"+", "_"}, {"-", "_"}, {".", "_"}}); +} + +INSTANTIATE_TEST_SUITE_P(All, ZipfTest, ::testing::ValuesIn(GenParams()), + ParamName); + +// NOTE: absl::zipf_distribution is not guaranteed to be stable. +TEST(ZipfDistributionTest, StabilityTest) { + // absl::zipf_distribution stability relies on + // absl::uniform_real_distribution, std::log, std::exp, std::log1p + absl::random_internal::sequence_urbg urbg( + {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull, + 0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, + 0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, + 0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull}); + + std::vector<int> output(10); + + { + absl::zipf_distribution<int32_t> dist; + std::generate(std::begin(output), std::end(output), + [&] { return dist(urbg); }); + EXPECT_THAT(output, ElementsAre(10031, 0, 0, 3, 6, 0, 7, 47, 0, 0)); + } + urbg.reset(); + { + absl::zipf_distribution<int32_t> dist(std::numeric_limits<int32_t>::max(), + 3.3); + std::generate(std::begin(output), std::end(output), + [&] { return dist(urbg); }); + EXPECT_THAT(output, ElementsAre(44, 0, 0, 0, 0, 1, 0, 1, 3, 0)); + } +} + +TEST(ZipfDistributionTest, AlgorithmBounds) { + absl::zipf_distribution<int32_t> dist; + + // Small values from absl::uniform_real_distribution map to larger Zipf + // distribution values. + const std::pair<uint64_t, int32_t> kInputs[] = { + {0xffffffffffffffff, 0x0}, {0x7fffffffffffffff, 0x0}, + {0x3ffffffffffffffb, 0x1}, {0x1ffffffffffffffd, 0x4}, + {0xffffffffffffffe, 0x9}, {0x7ffffffffffffff, 0x12}, + {0x3ffffffffffffff, 0x25}, {0x1ffffffffffffff, 0x4c}, + {0xffffffffffffff, 0x99}, {0x7fffffffffffff, 0x132}, + {0x3fffffffffffff, 0x265}, {0x1fffffffffffff, 0x4cc}, + {0xfffffffffffff, 0x999}, {0x7ffffffffffff, 0x1332}, + {0x3ffffffffffff, 0x2665}, {0x1ffffffffffff, 0x4ccc}, + {0xffffffffffff, 0x9998}, {0x7fffffffffff, 0x1332f}, + {0x3fffffffffff, 0x2665a}, {0x1fffffffffff, 0x4cc9e}, + {0xfffffffffff, 0x998e0}, {0x7ffffffffff, 0x133051}, + {0x3ffffffffff, 0x265ae4}, {0x1ffffffffff, 0x4c9ed3}, + {0xffffffffff, 0x98e223}, {0x7fffffffff, 0x13058c4}, + {0x3fffffffff, 0x25b178e}, {0x1fffffffff, 0x4a062b2}, + {0xfffffffff, 0x8ee23b8}, {0x7ffffffff, 0x10b21642}, + {0x3ffffffff, 0x1d89d89d}, {0x1ffffffff, 0x2fffffff}, + {0xffffffff, 0x45d1745d}, {0x7fffffff, 0x5a5a5a5a}, + {0x3fffffff, 0x69ee5846}, {0x1fffffff, 0x73ecade3}, + {0xfffffff, 0x79a9d260}, {0x7ffffff, 0x7cc0532b}, + {0x3ffffff, 0x7e5ad146}, {0x1ffffff, 0x7f2c0bec}, + {0xffffff, 0x7f95adef}, {0x7fffff, 0x7fcac0da}, + {0x3fffff, 0x7fe55ae2}, {0x1fffff, 0x7ff2ac0e}, + {0xfffff, 0x7ff955ae}, {0x7ffff, 0x7ffcaac1}, + {0x3ffff, 0x7ffe555b}, {0x1ffff, 0x7fff2aac}, + {0xffff, 0x7fff9556}, {0x7fff, 0x7fffcaab}, + {0x3fff, 0x7fffe555}, {0x1fff, 0x7ffff2ab}, + {0xfff, 0x7ffff955}, {0x7ff, 0x7ffffcab}, + {0x3ff, 0x7ffffe55}, {0x1ff, 0x7fffff2b}, + {0xff, 0x7fffff95}, {0x7f, 0x7fffffcb}, + {0x3f, 0x7fffffe5}, {0x1f, 0x7ffffff3}, + {0xf, 0x7ffffff9}, {0x7, 0x7ffffffd}, + {0x3, 0x7ffffffe}, {0x1, 0x7fffffff}, + }; + + for (const auto& instance : kInputs) { + absl::random_internal::sequence_urbg urbg({instance.first}); + EXPECT_EQ(instance.second, dist(urbg)); + } +} + +} // namespace diff --git a/absl/strings/BUILD.bazel b/absl/strings/BUILD.bazel index d6ce88da3243..f18b16001e4e 100644 --- a/absl/strings/BUILD.bazel +++ b/absl/strings/BUILD.bazel @@ -557,6 +557,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":strings", + "//absl/base:config", "//absl/base:core_headers", "//absl/container:inlined_vector", "//absl/meta:type_traits", diff --git a/absl/strings/CMakeLists.txt b/absl/strings/CMakeLists.txt index 8515ec2b70f9..f7821290b7de 100644 --- a/absl/strings/CMakeLists.txt +++ b/absl/strings/CMakeLists.txt @@ -385,6 +385,7 @@ absl_cc_library( ${ABSL_DEFAULT_COPTS} DEPS absl::strings + absl::config absl::core_headers absl::inlined_vector absl::type_traits diff --git a/absl/strings/internal/str_format/float_conversion.cc b/absl/strings/internal/str_format/float_conversion.cc index 6176db9cb5a2..9236acdc74cb 100644 --- a/absl/strings/internal/str_format/float_conversion.cc +++ b/absl/strings/internal/str_format/float_conversion.cc @@ -6,6 +6,8 @@ #include <cmath> #include <string> +#include "absl/base/config.h" + namespace absl { namespace str_format_internal { @@ -334,7 +336,7 @@ bool FloatToBuffer(Decomposed<Float> decomposed, int precision, Buffer *out, static_cast<std::uint64_t>(decomposed.exponent), precision, out, exp)) return true; -#if defined(__SIZEOF_INT128__) +#if defined(ABSL_HAVE_INTRINSIC_INT128) // If that is not enough, try with __uint128_t. return CanFitMantissa<Float, __uint128_t>() && FloatToBufferImpl<__uint128_t, Float, mode>( diff --git a/absl/synchronization/mutex_test.cc b/absl/synchronization/mutex_test.cc index 2e10f09816df..9851ac19824a 100644 --- a/absl/synchronization/mutex_test.cc +++ b/absl/synchronization/mutex_test.cc @@ -14,7 +14,7 @@ #include "absl/synchronization/mutex.h" -#ifdef WIN32 +#ifdef _WIN32 #include <windows.h> #endif @@ -1035,7 +1035,7 @@ TEST(Mutex, DeadlockDetector) { class ScopedDisableBazelTestWarnings { public: ScopedDisableBazelTestWarnings() { -#ifdef WIN32 +#ifdef _WIN32 char file[MAX_PATH]; if (GetEnvironmentVariableA(kVarName, file, sizeof(file)) < sizeof(file)) { warnings_output_file_ = file; @@ -1052,7 +1052,7 @@ class ScopedDisableBazelTestWarnings { ~ScopedDisableBazelTestWarnings() { if (!warnings_output_file_.empty()) { -#ifdef WIN32 +#ifdef _WIN32 SetEnvironmentVariableA(kVarName, warnings_output_file_.c_str()); #else setenv(kVarName, warnings_output_file_.c_str(), 0); diff --git a/absl/time/clock_benchmark.cc b/absl/time/clock_benchmark.cc index a69fe00ba263..c5c795ecbd23 100644 --- a/absl/time/clock_benchmark.cc +++ b/absl/time/clock_benchmark.cc @@ -15,6 +15,8 @@ #if !defined(_WIN32) #include <sys/time.h> +#else +#include <winsock2.h> #endif // _WIN32 #include <cstdio> diff --git a/absl/time/duration.cc b/absl/time/duration.cc index 67791feedfb5..a3ac61a91894 100644 --- a/absl/time/duration.cc +++ b/absl/time/duration.cc @@ -49,6 +49,10 @@ // // Arithmetic overflows/underflows to +/- infinity and saturates. +#if defined(_MSC_VER) +#include <winsock2.h> // for timeval +#endif + #include <algorithm> #include <cassert> #include <cctype> diff --git a/absl/time/duration_test.cc b/absl/time/duration_test.cc index e3cede6ee9d4..5dce9ac8a557 100644 --- a/absl/time/duration_test.cc +++ b/absl/time/duration_test.cc @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#if defined(_MSC_VER) +#include <winsock2.h> // for timeval +#endif + #include <chrono> // NOLINT(build/c++11) #include <cmath> #include <cstdint> diff --git a/absl/time/internal/cctz/src/time_zone_format.cc b/absl/time/internal/cctz/src/time_zone_format.cc index 25850980b541..aa8898750dc8 100644 --- a/absl/time/internal/cctz/src/time_zone_format.cc +++ b/absl/time/internal/cctz/src/time_zone_format.cc @@ -18,8 +18,18 @@ # endif #endif +#if defined(HAS_STRPTIME) && HAS_STRPTIME +# if !defined(_XOPEN_SOURCE) +# define _XOPEN_SOURCE // Definedness suffices for strptime. +# endif +#endif + #include "absl/time/internal/cctz/include/cctz/time_zone.h" +// Include time.h directly since, by C++ standards, ctime doesn't have to +// declare strptime. +#include <time.h> + #include <cctype> #include <chrono> #include <cstddef> diff --git a/absl/time/time.cc b/absl/time/time.cc index 977a95173db1..338c4523d1b1 100644 --- a/absl/time/time.cc +++ b/absl/time/time.cc @@ -33,6 +33,10 @@ #include "absl/time/time.h" +#if defined(_MSC_VER) +#include <winsock2.h> // for timeval +#endif + #include <cstring> #include <ctime> #include <limits> diff --git a/absl/time/time.h b/absl/time/time.h index 594396c751df..0534780516d0 100644 --- a/absl/time/time.h +++ b/absl/time/time.h @@ -65,7 +65,14 @@ #if !defined(_MSC_VER) #include <sys/time.h> #else -#include <winsock2.h> +// We don't include `winsock2.h` because it drags in `windows.h` and friends, +// and they define conflicting macros like OPAQUE, ERROR, and more. This has the +// potential to break Abseil users. +// +// Instead we only forward declare `timeval` and require Windows users include +// `winsock2.h` themselves. This is both inconsistent and troublesome, but so is +// including 'windows.h' so we are picking the lesser of two evils here. +struct timeval; #endif #include <chrono> // NOLINT(build/c++11) #include <cmath> diff --git a/absl/time/time_test.cc b/absl/time/time_test.cc index 4d791f4d11aa..e2b2f8097021 100644 --- a/absl/time/time_test.cc +++ b/absl/time/time_test.cc @@ -14,6 +14,10 @@ #include "absl/time/time.h" +#if defined(_MSC_VER) +#include <winsock2.h> // for timeval +#endif + #include <chrono> // NOLINT(build/c++11) #include <cstring> #include <ctime> diff --git a/absl/utility/utility.h b/absl/utility/utility.h index 7686db1250ae..eef8fb41fe8e 100644 --- a/absl/utility/utility.h +++ b/absl/utility/utility.h @@ -25,6 +25,7 @@ // * index_sequence_for<Ts...> == std::index_sequence_for<Ts...> // * apply<Functor, Tuple> == std::apply<Functor, Tuple> // * exchange<T> == std::exchange<T> +// * make_from_tuple<T> == std::make_from_tuple<T> // // This header file also provides the tag types `in_place_t`, `in_place_type_t`, // and `in_place_index_t`, as well as the constant `in_place`, and @@ -315,6 +316,33 @@ T exchange(T& obj, U&& new_value) { return old_value; } +namespace utility_internal { +template <typename T, typename Tuple, size_t... I> +T make_from_tuple_impl(Tuple&& tup, absl::index_sequence<I...>) { + return T(std::get<I>(std::forward<Tuple>(tup))...); +} +} // namespace utility_internal + +// make_from_tuple +// +// Given the template parameter type `T` and a tuple of arguments +// `std::tuple(arg0, arg1, ..., argN)` constructs an object of type `T` as if by +// calling `T(arg0, arg1, ..., argN)`. +// +// Example: +// +// std::tuple<const char*, size_t> args("hello world", 5); +// auto s = absl::make_from_tuple<std::string>(args); +// assert(s == "hello"); +// +template <typename T, typename Tuple> +constexpr T make_from_tuple(Tuple&& tup) { + return utility_internal::make_from_tuple_impl<T>( + std::forward<Tuple>(tup), + absl::make_index_sequence< + std::tuple_size<absl::decay_t<Tuple>>::value>{}); +} + } // namespace absl #endif // ABSL_UTILITY_UTILITY_H_ diff --git a/absl/utility/utility_test.cc b/absl/utility/utility_test.cc index 5a4972b60085..f044ad644a4d 100644 --- a/absl/utility/utility_test.cc +++ b/absl/utility/utility_test.cc @@ -341,5 +341,36 @@ TEST(ExchangeTest, MoveOnly) { EXPECT_EQ(1, *b); } +TEST(MakeFromTupleTest, String) { + EXPECT_EQ( + absl::make_from_tuple<std::string>(std::make_tuple("hello world", 5)), + "hello"); +} + +TEST(MakeFromTupleTest, MoveOnlyParameter) { + struct S { + S(std::unique_ptr<int> n, std::unique_ptr<int> m) : value(*n + *m) {} + int value = 0; + }; + auto tup = + std::make_tuple(absl::make_unique<int>(3), absl::make_unique<int>(4)); + auto s = absl::make_from_tuple<S>(std::move(tup)); + EXPECT_EQ(s.value, 7); +} + +TEST(MakeFromTupleTest, NoParameters) { + struct S { + S() : value(1) {} + int value = 2; + }; + EXPECT_EQ(absl::make_from_tuple<S>(std::make_tuple()).value, 1); +} + +TEST(MakeFromTupleTest, Pair) { + EXPECT_EQ( + (absl::make_from_tuple<std::pair<bool, int>>(std::make_tuple(true, 17))), + std::make_pair(true, 17)); +} + } // namespace |