about summary refs log tree commit diff
path: root/third_party/abseil_cpp/absl/random/bit_gen_ref.h
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/abseil_cpp/absl/random/bit_gen_ref.h')
-rw-r--r--third_party/abseil_cpp/absl/random/bit_gen_ref.h107
1 files changed, 68 insertions, 39 deletions
diff --git a/third_party/abseil_cpp/absl/random/bit_gen_ref.h b/third_party/abseil_cpp/absl/random/bit_gen_ref.h
index 59591a479d8c..9555460fd491 100644
--- a/third_party/abseil_cpp/absl/random/bit_gen_ref.h
+++ b/third_party/abseil_cpp/absl/random/bit_gen_ref.h
@@ -24,11 +24,11 @@
 #ifndef ABSL_RANDOM_BIT_GEN_REF_H_
 #define ABSL_RANDOM_BIT_GEN_REF_H_
 
+#include "absl/base/internal/fast_type_id.h"
 #include "absl/base/macros.h"
 #include "absl/meta/type_traits.h"
 #include "absl/random/internal/distribution_caller.h"
 #include "absl/random/internal/fast_uniform_bits.h"
-#include "absl/random/internal/mocking_bit_gen_base.h"
 
 namespace absl {
 ABSL_NAMESPACE_BEGIN
@@ -51,6 +51,10 @@ struct is_urbg<
         typename std::decay<decltype(std::declval<URBG>()())>::type>::value>>
     : std::true_type {};
 
+template <typename>
+struct DistributionCaller;
+class MockHelpers;
+
 }  // namespace random_internal
 
 // -----------------------------------------------------------------------------
@@ -77,23 +81,50 @@ struct is_urbg<
 //    }
 //
 class BitGenRef {
- public:
-  using result_type = uint64_t;
+  // SFINAE to detect whether the URBG type includes a member matching
+  // bool InvokeMock(base_internal::FastTypeIdType, void*, void*).
+  //
+  // These live inside BitGenRef so that they have friend access
+  // to MockingBitGen. (see similar methods in DistributionCaller).
+  template <template <class...> class Trait, class AlwaysVoid, class... Args>
+  struct detector : std::false_type {};
+  template <template <class...> class Trait, class... Args>
+  struct detector<Trait, absl::void_t<Trait<Args...>>, Args...>
+      : std::true_type {};
+
+  template <class T>
+  using invoke_mock_t = decltype(std::declval<T*>()->InvokeMock(
+      std::declval<base_internal::FastTypeIdType>(), std::declval<void*>(),
+      std::declval<void*>()));
+
+  template <typename T>
+  using HasInvokeMock = typename detector<invoke_mock_t, void, T>::type;
 
-  BitGenRef(const absl::BitGenRef&) = default;
-  BitGenRef(absl::BitGenRef&&) = default;
-  BitGenRef& operator=(const absl::BitGenRef&) = default;
-  BitGenRef& operator=(absl::BitGenRef&&) = default;
+ public:
+  BitGenRef(const BitGenRef&) = default;
+  BitGenRef(BitGenRef&&) = default;
+  BitGenRef& operator=(const BitGenRef&) = default;
+  BitGenRef& operator=(BitGenRef&&) = default;
+
+  template <typename URBG, typename absl::enable_if_t<
+                               (!std::is_same<URBG, BitGenRef>::value &&
+                                random_internal::is_urbg<URBG>::value &&
+                                !HasInvokeMock<URBG>::value)>* = nullptr>
+  BitGenRef(URBG& gen)  // NOLINT
+      : t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)),
+        mock_call_(NotAMock),
+        generate_impl_fn_(ImplFn<URBG>) {}
 
   template <typename URBG,
-            typename absl::enable_if_t<
-                (!std::is_same<URBG, BitGenRef>::value &&
-                 random_internal::is_urbg<URBG>::value)>* = nullptr>
+            typename absl::enable_if_t<(!std::is_same<URBG, BitGenRef>::value &&
+                                        random_internal::is_urbg<URBG>::value &&
+                                        HasInvokeMock<URBG>::value)>* = nullptr>
   BitGenRef(URBG& gen)  // NOLINT
-      : mocked_gen_ptr_(MakeMockPointer(&gen)),
-        t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)),
-        generate_impl_fn_(ImplFn<URBG>) {
-  }
+      : t_erased_gen_ptr_(reinterpret_cast<uintptr_t>(&gen)),
+        mock_call_(&MockCall<URBG>),
+        generate_impl_fn_(ImplFn<URBG>) {}
+
+  using result_type = uint64_t;
 
   static constexpr result_type(min)() {
     return (std::numeric_limits<result_type>::min)();
@@ -106,14 +137,9 @@ class BitGenRef {
   result_type operator()() { return generate_impl_fn_(t_erased_gen_ptr_); }
 
  private:
-  friend struct absl::random_internal::DistributionCaller<absl::BitGenRef>;
   using impl_fn = result_type (*)(uintptr_t);
-  using mocker_base_t = absl::random_internal::MockingBitGenBase;
-
-  // Convert an arbitrary URBG pointer into either a valid mocker_base_t
-  // pointer or a nullptr.
-  static inline mocker_base_t* MakeMockPointer(mocker_base_t* t) { return t; }
-  static inline mocker_base_t* MakeMockPointer(void*) { return nullptr; }
+  using mock_call_fn = bool (*)(uintptr_t, base_internal::FastTypeIdType, void*,
+                                void*);
 
   template <typename URBG>
   static result_type ImplFn(uintptr_t ptr) {
@@ -123,29 +149,32 @@ class BitGenRef {
     return fast_uniform_bits(*reinterpret_cast<URBG*>(ptr));
   }
 
-  mocker_base_t* mocked_gen_ptr_;
+  // Get a type-erased InvokeMock pointer.
+  template <typename URBG>
+  static bool MockCall(uintptr_t gen_ptr, base_internal::FastTypeIdType type,
+                       void* result, void* arg_tuple) {
+    return reinterpret_cast<URBG*>(gen_ptr)->InvokeMock(type, result,
+                                                        arg_tuple);
+  }
+  static bool NotAMock(uintptr_t, base_internal::FastTypeIdType, void*, void*) {
+    return false;
+  }
+
+  inline bool InvokeMock(base_internal::FastTypeIdType type, void* args_tuple,
+                         void* result) {
+    if (mock_call_ == NotAMock) return false;  // avoids an indirect call.
+    return mock_call_(t_erased_gen_ptr_, type, args_tuple, result);
+  }
+
   uintptr_t t_erased_gen_ptr_;
+  mock_call_fn mock_call_;
   impl_fn generate_impl_fn_;
-};
-
-namespace random_internal {
 
-template <>
-struct DistributionCaller<absl::BitGenRef> {
-  template <typename DistrT, typename... Args>
-  static typename DistrT::result_type Call(absl::BitGenRef* gen_ref,
-                                           Args&&... args) {
-    auto* mock_ptr = gen_ref->mocked_gen_ptr_;
-    if (mock_ptr == nullptr) {
-      DistrT dist(std::forward<Args>(args)...);
-      return dist(*gen_ref);
-    } else {
-      return mock_ptr->template Call<DistrT>(std::forward<Args>(args)...);
-    }
-  }
+  template <typename>
+  friend struct ::absl::random_internal::DistributionCaller;  // for InvokeMock
+  friend class ::absl::random_internal::MockHelpers;          // for InvokeMock
 };
 
-}  // namespace random_internal
 ABSL_NAMESPACE_END
 }  // namespace absl