about summary refs log tree commit diff
path: root/absl/random/internal
diff options
context:
space:
mode:
Diffstat (limited to 'absl/random/internal')
-rw-r--r--absl/random/internal/BUILD.bazel30
-rw-r--r--absl/random/internal/distribution_caller.h53
-rw-r--r--absl/random/internal/distributions.h52
-rw-r--r--absl/random/internal/mock_helpers.h127
-rw-r--r--absl/random/internal/mock_overload_set.h26
-rw-r--r--absl/random/internal/mocking_bit_gen_base.h85
-rw-r--r--absl/random/internal/uniform_helper.h30
7 files changed, 232 insertions, 171 deletions
diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel
index 85d1fb81d23b..d81477ffb7e8 100644
--- a/absl/random/internal/BUILD.bazel
+++ b/absl/random/internal/BUILD.bazel
@@ -45,21 +45,10 @@ cc_library(
     hdrs = ["distribution_caller.h"],
     copts = ABSL_DEFAULT_COPTS,
     linkopts = ABSL_DEFAULT_LINKOPTS,
-    deps = ["//absl/base:config"],
-)
-
-cc_library(
-    name = "distributions",
-    hdrs = ["distributions.h"],
-    copts = ABSL_DEFAULT_COPTS,
-    linkopts = ABSL_DEFAULT_LINKOPTS,
     deps = [
-        ":distribution_caller",
-        ":traits",
-        ":uniform_helper",
-        "//absl/base",
-        "//absl/meta:type_traits",
-        "//absl/strings",
+        "//absl/base:config",
+        "//absl/base:fast_type_id",
+        "//absl/utility",
     ],
 )
 
@@ -221,7 +210,6 @@ cc_library(
         ":seed_material",
         "//absl/base:core_headers",
         "//absl/meta:type_traits",
-        "//absl/strings",
         "//absl/types:optional",
         "//absl/types:span",
     ],
@@ -497,12 +485,11 @@ cc_test(
 )
 
 cc_library(
-    name = "mocking_bit_gen_base",
-    hdrs = ["mocking_bit_gen_base.h"],
-    linkopts = ABSL_DEFAULT_LINKOPTS,
+    name = "mock_helpers",
+    hdrs = ["mock_helpers.h"],
     deps = [
-        "//absl/random",
-        "//absl/strings",
+        "//absl/base:fast_type_id",
+        "//absl/types:optional",
     ],
 )
 
@@ -511,6 +498,7 @@ cc_library(
     testonly = 1,
     hdrs = ["mock_overload_set.h"],
     deps = [
+        ":mock_helpers",
         "//absl/random:mocking_bit_gen",
         "@com_google_googletest//:gtest",
     ],
@@ -672,6 +660,8 @@ cc_library(
     copts = ABSL_DEFAULT_COPTS,
     linkopts = ABSL_DEFAULT_LINKOPTS,
     deps = [
+        ":traits",
+        "//absl/base:config",
         "//absl/meta:type_traits",
     ],
 )
diff --git a/absl/random/internal/distribution_caller.h b/absl/random/internal/distribution_caller.h
index 4e0724440cbc..fc81b787ebe2 100644
--- a/absl/random/internal/distribution_caller.h
+++ b/absl/random/internal/distribution_caller.h
@@ -20,6 +20,8 @@
 #include <utility>
 
 #include "absl/base/config.h"
+#include "absl/base/internal/fast_type_id.h"
+#include "absl/utility/utility.h"
 
 namespace absl {
 ABSL_NAMESPACE_BEGIN
@@ -30,14 +32,57 @@ namespace random_internal {
 // to intercept such calls.
 template <typename URBG>
 struct DistributionCaller {
-  // Call the provided distribution type. The parameters are expected
-  // to be explicitly specified.
-  // DistrT is the distribution type.
+  // SFINAE to detect whether the URBG type includes a member matching
+  // bool InvokeMock(base_internal::FastTypeIdType, void*, void*).
+  //
+  // These live inside BitGenRef so that they have friend access
+  // to MockingBitGen. (see similar methods in DistributionCaller).
+  template <template <class...> class Trait, class AlwaysVoid, class... Args>
+  struct detector : std::false_type {};
+  template <template <class...> class Trait, class... Args>
+  struct detector<Trait, absl::void_t<Trait<Args...>>, Args...>
+      : std::true_type {};
+
+  template <class T>
+  using invoke_mock_t = decltype(std::declval<T*>()->InvokeMock(
+      std::declval<::absl::base_internal::FastTypeIdType>(),
+      std::declval<void*>(), std::declval<void*>()));
+
+  using HasInvokeMock = typename detector<invoke_mock_t, void, URBG>::type;
+
+  // Default implementation of distribution caller.
   template <typename DistrT, typename... Args>
-  static typename DistrT::result_type Call(URBG* urbg, Args&&... args) {
+  static typename DistrT::result_type Impl(std::false_type, URBG* urbg,
+                                           Args&&... args) {
     DistrT dist(std::forward<Args>(args)...);
     return dist(*urbg);
   }
+
+  // Mock implementation of distribution caller.
+  // The underlying KeyT must match the KeyT constructed by MockOverloadSet.
+  template <typename DistrT, typename... Args>
+  static typename DistrT::result_type Impl(std::true_type, URBG* urbg,
+                                           Args&&... args) {
+    using ResultT = typename DistrT::result_type;
+    using ArgTupleT = std::tuple<absl::decay_t<Args>...>;
+    using KeyT = ResultT(DistrT, ArgTupleT);
+
+    ArgTupleT arg_tuple(std::forward<Args>(args)...);
+    ResultT result;
+    if (!urbg->InvokeMock(::absl::base_internal::FastTypeId<KeyT>(), &arg_tuple,
+                          &result)) {
+      auto dist = absl::make_from_tuple<DistrT>(arg_tuple);
+      result = dist(*urbg);
+    }
+    return result;
+  }
+
+  // Default implementation of distribution caller.
+  template <typename DistrT, typename... Args>
+  static typename DistrT::result_type Call(URBG* urbg, Args&&... args) {
+    return Impl<DistrT, Args...>(HasInvokeMock{}, urbg,
+                                 std::forward<Args>(args)...);
+  }
 };
 
 }  // namespace random_internal
diff --git a/absl/random/internal/distributions.h b/absl/random/internal/distributions.h
deleted file mode 100644
index d7e3c0161f06..000000000000
--- a/absl/random/internal/distributions.h
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2019 The Abseil Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef ABSL_RANDOM_INTERNAL_DISTRIBUTIONS_H_
-#define ABSL_RANDOM_INTERNAL_DISTRIBUTIONS_H_
-
-#include <type_traits>
-
-#include "absl/meta/type_traits.h"
-#include "absl/random/internal/distribution_caller.h"
-#include "absl/random/internal/traits.h"
-#include "absl/random/internal/uniform_helper.h"
-
-namespace absl {
-ABSL_NAMESPACE_BEGIN
-namespace random_internal {
-
-// In the absence of an explicitly provided return-type, the template
-// "uniform_inferred_return_t<A, B>" is used to derive a suitable type, based on
-// the data-types of the endpoint-arguments {A lo, B hi}.
-//
-// Given endpoints {A lo, B hi}, one of {A, B} will be chosen as the
-// return-type, if one type can be implicitly converted into the other, in a
-// lossless way. The template "is_widening_convertible" implements the
-// compile-time logic for deciding if such a conversion is possible.
-//
-// If no such conversion between {A, B} exists, then the overload for
-// absl::Uniform() will be discarded, and the call will be ill-formed.
-// Return-type for absl::Uniform() when the return-type is inferred.
-template <typename A, typename B>
-using uniform_inferred_return_t =
-    absl::enable_if_t<absl::disjunction<is_widening_convertible<A, B>,
-                                        is_widening_convertible<B, A>>::value,
-                      typename std::conditional<
-                          is_widening_convertible<A, B>::value, B, A>::type>;
-
-}  // namespace random_internal
-ABSL_NAMESPACE_END
-}  // namespace absl
-
-#endif  // ABSL_RANDOM_INTERNAL_DISTRIBUTIONS_H_
diff --git a/absl/random/internal/mock_helpers.h b/absl/random/internal/mock_helpers.h
new file mode 100644
index 000000000000..9af27ab3a210
--- /dev/null
+++ b/absl/random/internal/mock_helpers.h
@@ -0,0 +1,127 @@
+//
+// Copyright 2019 The Abseil Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
+#define ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
+
+#include <tuple>
+#include <type_traits>
+
+#include "absl/base/internal/fast_type_id.h"
+#include "absl/types/optional.h"
+
+namespace absl {
+ABSL_NAMESPACE_BEGIN
+namespace random_internal {
+
+// MockHelpers works in conjunction with MockOverloadSet, MockingBitGen, and
+// BitGenRef to enable the mocking capability for absl distribution functions.
+//
+// MockingBitGen registers mocks based on the typeid of a mock signature, KeyT,
+// which is used to generate a unique id.
+//
+// KeyT is a signature of the form:
+//   result_type(discriminator_type, std::tuple<args...>)
+// The mocked function signature will be composed from KeyT as:
+//   result_type(args...)
+//
+class MockHelpers {
+  using IdType = ::absl::base_internal::FastTypeIdType;
+
+  // Given a key signature type used to index the mock, extract the components.
+  // KeyT is expected to have the form:
+  //   result_type(discriminator_type, arg_tuple_type)
+  template <typename KeyT>
+  struct KeySignature;
+
+  template <typename ResultT, typename DiscriminatorT, typename ArgTupleT>
+  struct KeySignature<ResultT(DiscriminatorT, ArgTupleT)> {
+    using result_type = ResultT;
+    using discriminator_type = DiscriminatorT;
+    using arg_tuple_type = ArgTupleT;
+  };
+
+  // Detector for InvokeMock.
+  template <class T>
+  using invoke_mock_t = decltype(std::declval<T*>()->InvokeMock(
+      std::declval<IdType>(), std::declval<void*>(), std::declval<void*>()));
+
+  // Empty implementation of InvokeMock.
+  template <typename KeyT, typename ReturnT, typename ArgTupleT, typename URBG,
+            typename... Args>
+  static absl::optional<ReturnT> InvokeMockImpl(char, URBG*, Args&&...) {
+    return absl::nullopt;
+  }
+
+  // Non-empty implementation of InvokeMock.
+  template <typename KeyT, typename ReturnT, typename ArgTupleT, typename URBG,
+            typename = invoke_mock_t<URBG>, typename... Args>
+  static absl::optional<ReturnT> InvokeMockImpl(int, URBG* urbg,
+                                                Args&&... args) {
+    ArgTupleT arg_tuple(std::forward<Args>(args)...);
+    ReturnT result;
+    if (urbg->InvokeMock(::absl::base_internal::FastTypeId<KeyT>(), &arg_tuple,
+                         &result)) {
+      return result;
+    }
+    return absl::nullopt;
+  }
+
+ public:
+  // Invoke a mock for the KeyT (may or may not be a signature).
+  //
+  // KeyT is used to generate a typeid-based lookup key for the mock.
+  // KeyT is a signature of the form:
+  //   result_type(discriminator_type, std::tuple<args...>)
+  // The mocked function signature will be composed from KeyT as:
+  //   result_type(args...)
+  //
+  // An instance of arg_tuple_type must be constructable from Args..., since
+  // the underlying mechanism requires a pointer to an argument tuple.
+  template <typename KeyT, typename URBG, typename... Args>
+  static auto MaybeInvokeMock(URBG* urbg, Args&&... args)
+      -> absl::optional<typename KeySignature<KeyT>::result_type> {
+    // Use function overloading to dispatch to the implemenation since
+    // more modern patterns (e.g. require + constexpr) are not supported in all
+    // compiler configurations.
+    return InvokeMockImpl<KeyT, typename KeySignature<KeyT>::result_type,
+                          typename KeySignature<KeyT>::arg_tuple_type, URBG>(
+        0, urbg, std::forward<Args>(args)...);
+  }
+
+  // Acquire a mock for the KeyT (may or may not be a signature).
+  //
+  // KeyT is used to generate a typeid-based lookup for the mock.
+  // KeyT is a signature of the form:
+  //   result_type(discriminator_type, std::tuple<args...>)
+  // The mocked function signature will be composed from KeyT as:
+  //   result_type(args...)
+  template <typename KeyT, typename MockURBG>
+  static auto MockFor(MockURBG& m) -> decltype(
+      std::declval<MockURBG>()
+          .template RegisterMock<typename KeySignature<KeyT>::result_type,
+                                 typename KeySignature<KeyT>::arg_tuple_type>(
+              std::declval<IdType>())) {
+    return m.template RegisterMock<typename KeySignature<KeyT>::result_type,
+                                   typename KeySignature<KeyT>::arg_tuple_type>(
+        ::absl::base_internal::FastTypeId<KeyT>());
+  }
+};
+
+}  // namespace random_internal
+ABSL_NAMESPACE_END
+}  // namespace absl
+
+#endif  // ABSL_RANDOM_INTERNAL_MOCK_HELPERS_H_
diff --git a/absl/random/internal/mock_overload_set.h b/absl/random/internal/mock_overload_set.h
index c2a30d89d52b..dccc6cee679c 100644
--- a/absl/random/internal/mock_overload_set.h
+++ b/absl/random/internal/mock_overload_set.h
@@ -20,6 +20,7 @@
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "absl/random/internal/mock_helpers.h"
 #include "absl/random/mocking_bit_gen.h"
 
 namespace absl {
@@ -35,17 +36,20 @@ struct MockSingleOverload;
 // EXPECT_CALL(mock_single_overload, Call(...))` will expand to a call to
 // `mock_single_overload.gmock_Call(...)`. Because expectations are stored on
 // the MockingBitGen (an argument passed inside `Call(...)`), this forwards to
-// arguments to Mocking::Register.
+// arguments to MockingBitGen::Register.
+//
+// The underlying KeyT must match the KeyT constructed by DistributionCaller.
 template <typename DistrT, typename Ret, typename... Args>
 struct MockSingleOverload<DistrT, Ret(MockingBitGen&, Args...)> {
   static_assert(std::is_same<typename DistrT::result_type, Ret>::value,
                 "Overload signature must have return type matching the "
-                "distributions result type.");
+                "distribution result_type.");
+  using KeyT = Ret(DistrT, std::tuple<Args...>);
   auto gmock_Call(
       absl::MockingBitGen& gen,  // NOLINT(google-runtime-references)
-      const ::testing::Matcher<Args>&... args)
-      -> decltype(gen.Register<DistrT, Args...>(args...)) {
-    return gen.Register<DistrT, Args...>(args...);
+      const ::testing::Matcher<Args>&... matchers)
+      -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...)) {
+    return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matchers...);
   }
 };
 
@@ -53,13 +57,15 @@ template <typename DistrT, typename Ret, typename Arg, typename... Args>
 struct MockSingleOverload<DistrT, Ret(Arg, MockingBitGen&, Args...)> {
   static_assert(std::is_same<typename DistrT::result_type, Ret>::value,
                 "Overload signature must have return type matching the "
-                "distributions result type.");
+                "distribution result_type.");
+  using KeyT = Ret(DistrT, std::tuple<Arg, Args...>);
   auto gmock_Call(
-      const ::testing::Matcher<Arg>& arg,
+      const ::testing::Matcher<Arg>& matcher,
       absl::MockingBitGen& gen,  // NOLINT(google-runtime-references)
-      const ::testing::Matcher<Args>&... args)
-      -> decltype(gen.Register<DistrT, Arg, Args...>(arg, args...)) {
-    return gen.Register<DistrT, Arg, Args...>(arg, args...);
+      const ::testing::Matcher<Args>&... matchers)
+      -> decltype(MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher,
+                                                             matchers...)) {
+    return MockHelpers::MockFor<KeyT>(gen).gmock_Call(matcher, matchers...);
   }
 };
 
diff --git a/absl/random/internal/mocking_bit_gen_base.h b/absl/random/internal/mocking_bit_gen_base.h
deleted file mode 100644
index 23ecaf6c7e63..000000000000
--- a/absl/random/internal/mocking_bit_gen_base.h
+++ /dev/null
@@ -1,85 +0,0 @@
-//
-// Copyright 2018 The Abseil Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-#ifndef ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
-#define ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
-
-#include <string>
-#include <typeinfo>
-
-#include "absl/random/random.h"
-#include "absl/strings/str_cat.h"
-
-namespace absl {
-ABSL_NAMESPACE_BEGIN
-namespace random_internal {
-
-class MockingBitGenBase {
-  template <typename>
-  friend struct DistributionCaller;
-  using generator_type = absl::BitGen;
-
- public:
-  // URBG interface
-  using result_type = generator_type::result_type;
-  static constexpr result_type(min)() { return (generator_type::min)(); }
-  static constexpr result_type(max)() { return (generator_type::max)(); }
-  result_type operator()() { return gen_(); }
-
-  virtual ~MockingBitGenBase() = default;
-
- protected:
-  // CallImpl is the type-erased virtual dispatch.
-  // The type of dist is always distribution<T>,
-  // The type of result is always distribution<T>::result_type.
-  virtual bool CallImpl(const std::type_info& distr_type, void* dist_args,
-                        void* result) = 0;
-
-  template <typename DistrT, typename ArgTupleT>
-  static const std::type_info& GetTypeId() {
-    return typeid(std::pair<absl::decay_t<DistrT>, absl::decay_t<ArgTupleT>>);
-  }
-
-  // Call the generating distribution function.
-  // Invoked by DistributionCaller<>::Call<DistT>.
-  // DistT is the distribution type.
-  template <typename DistrT, typename... Args>
-  typename DistrT::result_type Call(Args&&... args) {
-    using distr_result_type = typename DistrT::result_type;
-    using ArgTupleT = std::tuple<absl::decay_t<Args>...>;
-
-    ArgTupleT arg_tuple(std::forward<Args>(args)...);
-    auto dist = absl::make_from_tuple<DistrT>(arg_tuple);
-
-    distr_result_type result{};
-    bool found_match =
-        CallImpl(GetTypeId<DistrT, ArgTupleT>(), &arg_tuple, &result);
-
-    if (!found_match) {
-      result = dist(gen_);
-    }
-
-    return result;
-  }
-
- private:
-  generator_type gen_;
-};  // namespace random_internal
-
-}  // namespace random_internal
-ABSL_NAMESPACE_END
-}  // namespace absl
-
-#endif  // ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
diff --git a/absl/random/internal/uniform_helper.h b/absl/random/internal/uniform_helper.h
index 663107cb3a63..5b2afecb89e6 100644
--- a/absl/random/internal/uniform_helper.h
+++ b/absl/random/internal/uniform_helper.h
@@ -19,10 +19,13 @@
 #include <limits>
 #include <type_traits>
 
+#include "absl/base/config.h"
 #include "absl/meta/type_traits.h"
+#include "absl/random/internal/traits.h"
 
 namespace absl {
 ABSL_NAMESPACE_BEGIN
+
 template <typename IntType>
 class uniform_int_distribution;
 
@@ -58,6 +61,26 @@ struct IntervalOpenOpenTag
     : public random_internal::TagTypeCompare<IntervalOpenOpenTag> {};
 
 namespace random_internal {
+
+// In the absence of an explicitly provided return-type, the template
+// "uniform_inferred_return_t<A, B>" is used to derive a suitable type, based on
+// the data-types of the endpoint-arguments {A lo, B hi}.
+//
+// Given endpoints {A lo, B hi}, one of {A, B} will be chosen as the
+// return-type, if one type can be implicitly converted into the other, in a
+// lossless way. The template "is_widening_convertible" implements the
+// compile-time logic for deciding if such a conversion is possible.
+//
+// If no such conversion between {A, B} exists, then the overload for
+// absl::Uniform() will be discarded, and the call will be ill-formed.
+// Return-type for absl::Uniform() when the return-type is inferred.
+template <typename A, typename B>
+using uniform_inferred_return_t =
+    absl::enable_if_t<absl::disjunction<is_widening_convertible<A, B>,
+                                        is_widening_convertible<B, A>>::value,
+                      typename std::conditional<
+                          is_widening_convertible<A, B>::value, B, A>::type>;
+
 // The functions
 //    uniform_lower_bound(tag, a, b)
 // and
@@ -149,12 +172,19 @@ uniform_upper_bound(Tag, FloatType, FloatType b) {
   return std::nextafter(b, (std::numeric_limits<FloatType>::max)());
 }
 
+// UniformDistribution selects either absl::uniform_int_distribution
+// or absl::uniform_real_distribution depending on the NumType parameter.
 template <typename NumType>
 using UniformDistribution =
     typename std::conditional<std::is_integral<NumType>::value,
                               absl::uniform_int_distribution<NumType>,
                               absl::uniform_real_distribution<NumType>>::type;
 
+// UniformDistributionWrapper is used as the underlying distribution type
+// by the absl::Uniform template function. It selects the proper Abseil
+// uniform distribution and provides constructor overloads that match the
+// expected parameter order as well as adjusting distribtuion bounds based
+// on the tag.
 template <typename NumType>
 struct UniformDistributionWrapper : public UniformDistribution<NumType> {
   template <typename TagType>