about summary refs log tree commit diff
path: root/third_party/abseil_cpp/absl/random/internal/fast_uniform_bits.h
diff options
context:
space:
mode:
authorVincent Ambo <mail@tazj.in>2020-11-21T13·43+0100
committerVincent Ambo <mail@tazj.in>2020-11-21T14·48+0100
commit082c006c04343a78d87b6c6ab3608c25d6213c3f (patch)
tree16e6f04f8d1d1d2d67e8e917d5e7bb48c1b60375 /third_party/abseil_cpp/absl/random/internal/fast_uniform_bits.h
parentcc27324d0226953943f408ce3c69ad7d648e005e (diff)
merge(3p/absl): subtree merge of Abseil up to e19260f r/1889
... notably, this includes Abseil's own StatusOr type, which
conflicted with our implementation (that was taken from TensorFlow).

Change-Id: Ie7d6764b64055caaeb8dc7b6b9d066291e6b538f
Diffstat (limited to 'third_party/abseil_cpp/absl/random/internal/fast_uniform_bits.h')
-rw-r--r--third_party/abseil_cpp/absl/random/internal/fast_uniform_bits.h202
1 files changed, 103 insertions, 99 deletions
diff --git a/third_party/abseil_cpp/absl/random/internal/fast_uniform_bits.h b/third_party/abseil_cpp/absl/random/internal/fast_uniform_bits.h
index f13c8729f7..425aaf7d83 100644
--- a/third_party/abseil_cpp/absl/random/internal/fast_uniform_bits.h
+++ b/third_party/abseil_cpp/absl/random/internal/fast_uniform_bits.h
@@ -21,6 +21,7 @@
 #include <type_traits>
 
 #include "absl/base/config.h"
+#include "absl/meta/type_traits.h"
 
 namespace absl {
 ABSL_NAMESPACE_BEGIN
@@ -38,28 +39,17 @@ constexpr bool IsPowerOfTwoOrZero(UIntType n) {
 template <typename URBG>
 constexpr typename URBG::result_type RangeSize() {
   using result_type = typename URBG::result_type;
+  static_assert((URBG::max)() != (URBG::min)(), "URBG range cannot be 0.");
   return ((URBG::max)() == (std::numeric_limits<result_type>::max)() &&
           (URBG::min)() == std::numeric_limits<result_type>::lowest())
              ? result_type{0}
-             : (URBG::max)() - (URBG::min)() + result_type{1};
-}
-
-template <typename UIntType>
-constexpr UIntType LargestPowerOfTwoLessThanOrEqualTo(UIntType n) {
-  return n < 2 ? n : 2 * LargestPowerOfTwoLessThanOrEqualTo(n / 2);
-}
-
-// Given a URBG generating values in the closed interval [Lo, Hi], returns the
-// largest power of two less than or equal to `Hi - Lo + 1`.
-template <typename URBG>
-constexpr typename URBG::result_type PowerOfTwoSubRangeSize() {
-  return LargestPowerOfTwoLessThanOrEqualTo(RangeSize<URBG>());
+             : ((URBG::max)() - (URBG::min)() + result_type{1});
 }
 
 // Computes the floor of the log. (i.e., std::floor(std::log2(N));
 template <typename UIntType>
 constexpr UIntType IntegerLog2(UIntType n) {
-  return (n <= 1) ? 0 : 1 + IntegerLog2(n / 2);
+  return (n <= 1) ? 0 : 1 + IntegerLog2(n >> 1);
 }
 
 // Returns the number of bits of randomness returned through
@@ -68,18 +58,23 @@ template <typename URBG>
 constexpr size_t NumBits() {
   return RangeSize<URBG>() == 0
              ? std::numeric_limits<typename URBG::result_type>::digits
-             : IntegerLog2(PowerOfTwoSubRangeSize<URBG>());
+             : IntegerLog2(RangeSize<URBG>());
 }
 
 // Given a shift value `n`, constructs a mask with exactly the low `n` bits set.
 // If `n == 0`, all bits are set.
 template <typename UIntType>
-constexpr UIntType MaskFromShift(UIntType n) {
+constexpr UIntType MaskFromShift(size_t n) {
   return ((n % std::numeric_limits<UIntType>::digits) == 0)
              ? ~UIntType{0}
              : (UIntType{1} << n) - UIntType{1};
 }
 
+// Tags used to dispatch FastUniformBits::generate to the simple or more complex
+// entropy extraction algorithm.
+struct SimplifiedLoopTag {};
+struct RejectionLoopTag {};
+
 // FastUniformBits implements a fast path to acquire uniform independent bits
 // from a type which conforms to the [rand.req.urbg] concept.
 // Parameterized by:
@@ -107,50 +102,16 @@ class FastUniformBits {
                 "Class-template FastUniformBits<> must be parameterized using "
                 "an unsigned type.");
 
-  // PowerOfTwoVariate() generates a single random variate, always returning a
-  // value in the half-open interval `[0, PowerOfTwoSubRangeSize<URBG>())`. If
-  // the URBG already generates values in a power-of-two range, the generator
-  // itself is used. Otherwise, we use rejection sampling on the largest
-  // possible power-of-two-sized subrange.
-  struct PowerOfTwoTag {};
-  struct RejectionSamplingTag {};
-  template <typename URBG>
-  static typename URBG::result_type PowerOfTwoVariate(
-      URBG& g) {  // NOLINT(runtime/references)
-    using tag =
-        typename std::conditional<IsPowerOfTwoOrZero(RangeSize<URBG>()),
-                                  PowerOfTwoTag, RejectionSamplingTag>::type;
-    return PowerOfTwoVariate(g, tag{});
-  }
-
-  template <typename URBG>
-  static typename URBG::result_type PowerOfTwoVariate(
-      URBG& g,  // NOLINT(runtime/references)
-      PowerOfTwoTag) {
-    return g() - (URBG::min)();
-  }
-
-  template <typename URBG>
-  static typename URBG::result_type PowerOfTwoVariate(
-      URBG& g,  // NOLINT(runtime/references)
-      RejectionSamplingTag) {
-    // Use rejection sampling to ensure uniformity across the range.
-    typename URBG::result_type u;
-    do {
-      u = g() - (URBG::min)();
-    } while (u >= PowerOfTwoSubRangeSize<URBG>());
-    return u;
-  }
-
   // Generate() generates a random value, dispatched on whether
-  // the underlying URBG must loop over multiple calls or not.
+  // the underlying URBG must use rejection sampling to generate a value,
+  // or whether a simplified loop will suffice.
   template <typename URBG>
   result_type Generate(URBG& g,  // NOLINT(runtime/references)
-                       std::true_type /* avoid_looping */);
+                       SimplifiedLoopTag);
 
   template <typename URBG>
   result_type Generate(URBG& g,  // NOLINT(runtime/references)
-                       std::false_type /* avoid_looping */);
+                       RejectionLoopTag);
 };
 
 template <typename UIntType>
@@ -162,31 +123,47 @@ FastUniformBits<UIntType>::operator()(URBG& g) {  // NOLINT(runtime/references)
   // Y = (2 ^ kRange) - 1
   static_assert((URBG::max)() > (URBG::min)(),
                 "URBG::max and URBG::min may not be equal.");
-  using urbg_result_type = typename URBG::result_type;
-  constexpr urbg_result_type kRangeMask =
-      RangeSize<URBG>() == 0
-          ? (std::numeric_limits<urbg_result_type>::max)()
-          : static_cast<urbg_result_type>(PowerOfTwoSubRangeSize<URBG>() - 1);
-  return Generate(g, std::integral_constant<bool, (kRangeMask >= (max)())>{});
+
+  using tag = absl::conditional_t<IsPowerOfTwoOrZero(RangeSize<URBG>()),
+                                  SimplifiedLoopTag, RejectionLoopTag>;
+  return Generate(g, tag{});
 }
 
 template <typename UIntType>
 template <typename URBG>
 typename FastUniformBits<UIntType>::result_type
 FastUniformBits<UIntType>::Generate(URBG& g,  // NOLINT(runtime/references)
-                                    std::true_type /* avoid_looping */) {
-  // The width of the result_type is less than than the width of the random bits
-  // provided by URBG.  Thus, generate a single value and then simply mask off
-  // the required bits.
+                                    SimplifiedLoopTag) {
+  // The simplified version of FastUniformBits works only on URBGs that have
+  // a range that is a power of 2. In this case we simply loop and shift without
+  // attempting to balance the bits across calls.
+  static_assert(IsPowerOfTwoOrZero(RangeSize<URBG>()),
+                "incorrect Generate tag for URBG instance");
+
+  static constexpr size_t kResultBits =
+      std::numeric_limits<result_type>::digits;
+  static constexpr size_t kUrbgBits = NumBits<URBG>();
+  static constexpr size_t kIters =
+      (kResultBits / kUrbgBits) + (kResultBits % kUrbgBits != 0);
+  static constexpr size_t kShift = (kIters == 1) ? 0 : kUrbgBits;
+  static constexpr auto kMin = (URBG::min)();
 
-  return PowerOfTwoVariate(g) & (max)();
+  result_type r = static_cast<result_type>(g() - kMin);
+  for (size_t n = 1; n < kIters; ++n) {
+    r = (r << kShift) + static_cast<result_type>(g() - kMin);
+  }
+  return r;
 }
 
 template <typename UIntType>
 template <typename URBG>
 typename FastUniformBits<UIntType>::result_type
 FastUniformBits<UIntType>::Generate(URBG& g,  // NOLINT(runtime/references)
-                                    std::false_type /* avoid_looping */) {
+                                    RejectionLoopTag) {
+  static_assert(!IsPowerOfTwoOrZero(RangeSize<URBG>()),
+                "incorrect Generate tag for URBG instance");
+  using urbg_result_type = typename URBG::result_type;
+
   // See [rand.adapt.ibits] for more details on the constants calculated below.
   //
   // It is preferable to use roughly the same number of bits from each generator
@@ -199,21 +176,44 @@ FastUniformBits<UIntType>::Generate(URBG& g,  // NOLINT(runtime/references)
   // `kSmallIters` and `kLargeIters` times respectively such
   // that
   //
-  //    `kTotalWidth == kSmallIters * kSmallWidth
-  //                    + kLargeIters * kLargeWidth`
+  //    `kResultBits == kSmallIters * kSmallBits
+  //                    + kLargeIters * kLargeBits`
   //
-  // where `kTotalWidth` is the total number of bits in `result_type`.
+  // where `kResultBits` is the total number of bits in `result_type`.
   //
-  constexpr size_t kTotalWidth = std::numeric_limits<result_type>::digits;
-  constexpr size_t kUrbgWidth = NumBits<URBG>();
-  constexpr size_t kTotalIters =
-      kTotalWidth / kUrbgWidth + (kTotalWidth % kUrbgWidth != 0);
-  constexpr size_t kSmallWidth = kTotalWidth / kTotalIters;
-  constexpr size_t kLargeWidth = kSmallWidth + 1;
+  static constexpr size_t kResultBits =
+      std::numeric_limits<result_type>::digits;                      // w
+  static constexpr urbg_result_type kUrbgRange = RangeSize<URBG>();  // R
+  static constexpr size_t kUrbgBits = NumBits<URBG>();               // m
+
+  // compute the initial estimate of the bits used.
+  // [rand.adapt.ibits] 2 (c)
+  static constexpr size_t kA =  // ceil(w/m)
+      (kResultBits / kUrbgBits) + ((kResultBits % kUrbgBits) != 0);  // n'
+
+  static constexpr size_t kABits = kResultBits / kA;  // w0'
+  static constexpr urbg_result_type kARejection =
+      ((kUrbgRange >> kABits) << kABits);  // y0'
+
+  // refine the selection to reduce the rejection frequency.
+  static constexpr size_t kTotalIters =
+      ((kUrbgRange - kARejection) <= (kARejection / kA)) ? kA : (kA + 1);  // n
+
+  // [rand.adapt.ibits] 2 (b)
+  static constexpr size_t kSmallIters =
+      kTotalIters - (kResultBits % kTotalIters);                   // n0
+  static constexpr size_t kSmallBits = kResultBits / kTotalIters;  // w0
+  static constexpr urbg_result_type kSmallRejection =
+      ((kUrbgRange >> kSmallBits) << kSmallBits);  // y0
+
+  static constexpr size_t kLargeBits = kSmallBits + 1;  // w0+1
+  static constexpr urbg_result_type kLargeRejection =
+      ((kUrbgRange >> kLargeBits) << kLargeBits);  // y1
+
   //
-  // Because `kLargeWidth == kSmallWidth + 1`, it follows that
+  // Because `kLargeBits == kSmallBits + 1`, it follows that
   //
-  //     `kTotalWidth == kTotalIters * kSmallWidth + kLargeIters`
+  //     `kResultBits == kSmallIters * kSmallBits + kLargeIters`
   //
   // and therefore
   //
@@ -224,36 +224,40 @@ FastUniformBits<UIntType>::Generate(URBG& g,  // NOLINT(runtime/references)
   // mentioned above, if the URBG width is a divisor of `kTotalWidth`, then
   // there would be no need for any large iterations (i.e., one loop would
   // suffice), and indeed, in this case, `kLargeIters` would be zero.
-  constexpr size_t kLargeIters = kTotalWidth % kSmallWidth;
-  constexpr size_t kSmallIters =
-      (kTotalWidth - (kLargeWidth * kLargeIters)) / kSmallWidth;
+  static_assert(kResultBits == kSmallIters * kSmallBits +
+                                   (kTotalIters - kSmallIters) * kLargeBits,
+                "Error in looping constant calculations.");
 
-  static_assert(
-      kTotalWidth == kSmallIters * kSmallWidth + kLargeIters * kLargeWidth,
-      "Error in looping constant calculations.");
+  // The small shift is essentially small bits, but due to the potential
+  // of generating a smaller result_type from a larger urbg type, the actual
+  // shift might be 0.
+  static constexpr size_t kSmallShift = kSmallBits % kResultBits;
+  static constexpr auto kSmallMask =
+      MaskFromShift<urbg_result_type>(kSmallShift);
+  static constexpr size_t kLargeShift = kLargeBits % kResultBits;
+  static constexpr auto kLargeMask =
+      MaskFromShift<urbg_result_type>(kLargeShift);
 
-  result_type s = 0;
+  static constexpr auto kMin = (URBG::min)();
 
-  constexpr size_t kSmallShift = kSmallWidth % kTotalWidth;
-  constexpr result_type kSmallMask = MaskFromShift(result_type{kSmallShift});
+  result_type s = 0;
   for (size_t n = 0; n < kSmallIters; ++n) {
-    s = (s << kSmallShift) +
-        (static_cast<result_type>(PowerOfTwoVariate(g)) & kSmallMask);
-  }
+    urbg_result_type v;
+    do {
+      v = g() - kMin;
+    } while (v >= kSmallRejection);
 
-  constexpr size_t kLargeShift = kLargeWidth % kTotalWidth;
-  constexpr result_type kLargeMask = MaskFromShift(result_type{kLargeShift});
-  for (size_t n = 0; n < kLargeIters; ++n) {
-    s = (s << kLargeShift) +
-        (static_cast<result_type>(PowerOfTwoVariate(g)) & kLargeMask);
+    s = (s << kSmallShift) + static_cast<result_type>(v & kSmallMask);
   }
 
-  static_assert(
-      kLargeShift == kSmallShift + 1 ||
-          (kLargeShift == 0 &&
-           kSmallShift == std::numeric_limits<result_type>::digits - 1),
-      "Error in looping constant calculations");
+  for (size_t n = kSmallIters; n < kTotalIters; ++n) {
+    urbg_result_type v;
+    do {
+      v = g() - kMin;
+    } while (v >= kLargeRejection);
 
+    s = (s << kLargeShift) + static_cast<result_type>(v & kLargeMask);
+  }
   return s;
 }