about summary refs log tree commit diff
path: root/absl/random/mocking_bit_gen.h
diff options
context:
space:
mode:
Diffstat (limited to 'absl/random/mocking_bit_gen.h')
-rw-r--r--absl/random/mocking_bit_gen.h178
1 files changed, 104 insertions, 74 deletions
diff --git a/absl/random/mocking_bit_gen.h b/absl/random/mocking_bit_gen.h
index 3d8a979e734f..6d2f2c836245 100644
--- a/absl/random/mocking_bit_gen.h
+++ b/absl/random/mocking_bit_gen.h
@@ -33,17 +33,16 @@
 #include <memory>
 #include <tuple>
 #include <type_traits>
-#include <typeindex>
-#include <typeinfo>
 #include <utility>
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "absl/base/internal/fast_type_id.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/meta/type_traits.h"
 #include "absl/random/distributions.h"
 #include "absl/random/internal/distribution_caller.h"
-#include "absl/random/internal/mocking_bit_gen_base.h"
+#include "absl/random/random.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/types/span.h"
@@ -54,11 +53,12 @@ namespace absl {
 ABSL_NAMESPACE_BEGIN
 
 namespace random_internal {
-
-template <typename, typename>
-struct MockSingleOverload;
+template <typename>
+struct DistributionCaller;
+class MockHelpers;
 
 }  // namespace random_internal
+class BitGenRef;
 
 // MockingBitGen
 //
@@ -96,102 +96,132 @@ struct MockSingleOverload;
 // At this time, only mock distributions supplied within the Abseil random
 // library are officially supported.
 //
-class MockingBitGen : public absl::random_internal::MockingBitGenBase {
+// EXPECT_CALL and ON_CALL need to be made within the same DLL component as
+// the call to absl::Uniform and related methods, otherwise mocking will fail
+// since the  underlying implementation creates a type-specific pointer which
+// will be distinct across different DLL boundaries.
+//
+class MockingBitGen {
  public:
-  MockingBitGen() {}
+  MockingBitGen() = default;
 
-  ~MockingBitGen() override {
+  ~MockingBitGen() {
     for (const auto& del : deleters_) del();
   }
 
+  // URBG interface
+  using result_type = absl::BitGen::result_type;
+
+  static constexpr result_type(min)() { return (absl::BitGen::min)(); }
+  static constexpr result_type(max)() { return (absl::BitGen::max)(); }
+  result_type operator()() { return gen_(); }
+
  private:
-  template <typename DistrT, typename... Args>
-  using MockFnType =
-      ::testing::MockFunction<typename DistrT::result_type(Args...)>;
+  using match_impl_fn = void (*)(void* mock_fn, void* t_erased_arg_tuple,
+                                 void* t_erased_result);
+
+  struct MockData {
+    void* mock_fn = nullptr;
+    match_impl_fn match_impl = nullptr;
+  };
+
+  // GetMockFnType returns the testing::MockFunction for a result and tuple.
+  // This method only exists for type deduction and is otherwise unimplemented.
+  template <typename ResultT, typename... Args>
+  static auto GetMockFnType(ResultT, std::tuple<Args...>)
+      -> ::testing::MockFunction<ResultT(Args...)>;
+
+  // MockFnCaller is a helper method for use with absl::apply to
+  // apply an ArgTupleT to a compatible MockFunction.
+  // NOTE: MockFnCaller is essentially equivalent to the lambda:
+  // [fn](auto... args) { return fn->Call(std::move(args)...)}
+  // however that fails to build on some supported platforms.
+  template <typename ResultT, typename MockFnType, typename Tuple>
+  struct MockFnCaller;
+  // specialization for std::tuple.
+  template <typename ResultT, typename MockFnType, typename... Args>
+  struct MockFnCaller<ResultT, MockFnType, std::tuple<Args...>> {
+    MockFnType* fn;
+    inline ResultT operator()(Args... args) {
+      return fn->Call(std::move(args)...);
+    }
+  };
 
-  // MockingBitGen::Register
+  // MockingBitGen::RegisterMock
   //
-  // Register<DistrT, FormatT, ArgTupleT> is the main extension point for
-  // extending the MockingBitGen framework. It provides a mechanism to install a
-  // mock expectation for the distribution `distr_t` onto the MockingBitGen
-  // context.
+  // RegisterMock<ResultT, ArgTupleT>(FastTypeIdType) is the main extension
+  // point for extending the MockingBitGen framework. It provides a mechanism to
+  // install a mock expectation for a function like ResultT(Args...) keyed by
+  // type_idex onto the MockingBitGen context. The key is that the type_index
+  // used to register must match the type index used to call the mock.
   //
   // The returned MockFunction<...> type can be used to setup additional
   // distribution parameters of the expectation.
-  template <typename DistrT, typename... Args, typename... Ms>
-  decltype(std::declval<MockFnType<DistrT, Args...>>().gmock_Call(
-      std::declval<Ms>()...))
-  Register(Ms&&... matchers) {
-    auto& mock =
-        mocks_[std::type_index(GetTypeId<DistrT, std::tuple<Args...>>())];
-
+  template <typename ResultT, typename ArgTupleT>
+  auto RegisterMock(base_internal::FastTypeIdType type)
+      -> decltype(GetMockFnType(std::declval<ResultT>(),
+                                std::declval<ArgTupleT>()))& {
+    using MockFnType = decltype(
+        GetMockFnType(std::declval<ResultT>(), std::declval<ArgTupleT>()));
+    auto& mock = mocks_[type];
     if (!mock.mock_fn) {
-      auto* mock_fn = new MockFnType<DistrT, Args...>;
+      auto* mock_fn = new MockFnType;
       mock.mock_fn = mock_fn;
-      mock.match_impl = &MatchImpl<DistrT, Args...>;
+      mock.match_impl = &MatchImpl<ResultT, ArgTupleT>;
       deleters_.emplace_back([mock_fn] { delete mock_fn; });
     }
-
-    return static_cast<MockFnType<DistrT, Args...>*>(mock.mock_fn)
-        ->gmock_Call(std::forward<Ms>(matchers)...);
+    return *static_cast<MockFnType*>(mock.mock_fn);
   }
 
-  mutable std::vector<std::function<void()>> deleters_;
-
-  using match_impl_fn = void (*)(void* mock_fn, void* t_erased_dist_args,
-                                 void* t_erased_result);
-  struct MockData {
-    void* mock_fn = nullptr;
-    match_impl_fn match_impl = nullptr;
-  };
-
-  mutable absl::flat_hash_map<std::type_index, MockData> mocks_;
-
-  template <typename DistrT, typename... Args>
-  static void MatchImpl(void* mock_fn, void* dist_args, void* result) {
-    using result_type = typename DistrT::result_type;
-    *static_cast<result_type*>(result) = absl::apply(
-        [mock_fn](Args... args) -> result_type {
-          return (*static_cast<MockFnType<DistrT, Args...>*>(mock_fn))
-              .Call(std::move(args)...);
-        },
-        *static_cast<std::tuple<Args...>*>(dist_args));
+  // MockingBitGen::MatchImpl<> is a dispatch function which converts the
+  // generic type-erased parameters into a specific mock invocation call.
+  // Requires tuple_args to point to a ArgTupleT, which is a std::tuple<Args...>
+  // used to invoke the mock function.
+  // Requires result to point to a ResultT, which is the result of the call.
+  template <typename ResultT, typename ArgTupleT>
+  static void MatchImpl(/*MockFnType<ResultT, Args...>*/ void* mock_fn,
+                        /*ArgTupleT*/ void* args_tuple,
+                        /*ResultT*/ void* result) {
+    using MockFnType = decltype(
+        GetMockFnType(std::declval<ResultT>(), std::declval<ArgTupleT>()));
+    *static_cast<ResultT*>(result) = absl::apply(
+        MockFnCaller<ResultT, MockFnType, ArgTupleT>{
+            static_cast<MockFnType*>(mock_fn)},
+        *static_cast<ArgTupleT*>(args_tuple));
   }
 
-  // Looks for an appropriate mock - Returns the mocked result if one is found.
-  // Otherwise, returns a random value generated by the underlying URBG.
-  bool CallImpl(const std::type_info& key_type, void* dist_args,
-                void* result) override {
+  // MockingBitGen::InvokeMock
+  //
+  // InvokeMock(FastTypeIdType, args, result) is the entrypoint for invoking
+  // mocks registered on MockingBitGen.
+  //
+  // When no mocks are registered on the provided FastTypeIdType, returns false.
+  // Otherwise attempts to invoke the mock function ResultT(Args...) that
+  // was previously registered via the type_index.
+  // Requires tuple_args to point to a ArgTupleT, which is a std::tuple<Args...>
+  // used to invoke the mock function.
+  // Requires result to point to a ResultT, which is the result of the call.
+  inline bool InvokeMock(base_internal::FastTypeIdType type, void* args_tuple,
+                         void* result) {
     // Trigger a mock, if there exists one that matches `param`.
-    auto it = mocks_.find(std::type_index(key_type));
+    auto it = mocks_.find(type);
     if (it == mocks_.end()) return false;
     auto* mock_data = static_cast<MockData*>(&it->second);
-    mock_data->match_impl(mock_data->mock_fn, dist_args, result);
+    mock_data->match_impl(mock_data->mock_fn, args_tuple, result);
     return true;
   }
 
-  template <typename, typename>
-  friend struct ::absl::random_internal::MockSingleOverload;
-  friend struct ::absl::random_internal::DistributionCaller<
-      absl::MockingBitGen>;
-};
-
-// -----------------------------------------------------------------------------
-// Implementation Details Only Below
-// -----------------------------------------------------------------------------
+  absl::flat_hash_map<base_internal::FastTypeIdType, MockData> mocks_;
+  std::vector<std::function<void()>> deleters_;
+  absl::BitGen gen_;
 
-namespace random_internal {
-
-template <>
-struct DistributionCaller<absl::MockingBitGen> {
-  template <typename DistrT, typename... Args>
-  static typename DistrT::result_type Call(absl::MockingBitGen* gen,
-                                           Args&&... args) {
-    return gen->template Call<DistrT>(std::forward<Args>(args)...);
-  }
+  template <typename>
+  friend struct ::absl::random_internal::DistributionCaller;  // for InvokeMock
+  friend class ::absl::BitGenRef;                             // for InvokeMock
+  friend class ::absl::random_internal::MockHelpers;  // for RegisterMock,
+                                                      // InvokeMock
 };
 
-}  // namespace random_internal
 ABSL_NAMESPACE_END
 }  // namespace absl