diff options
Diffstat (limited to 'absl/random')
-rw-r--r-- | absl/random/BUILD.bazel | 1 | ||||
-rw-r--r-- | absl/random/CMakeLists.txt | 2 | ||||
-rw-r--r-- | absl/random/bit_gen_ref.h | 5 | ||||
-rw-r--r-- | absl/random/distribution_format_traits.h | 278 | ||||
-rw-r--r-- | absl/random/distributions.h | 41 | ||||
-rw-r--r-- | absl/random/internal/BUILD.bazel | 1 | ||||
-rw-r--r-- | absl/random/internal/distribution_caller.h | 18 | ||||
-rw-r--r-- | absl/random/internal/mocking_bit_gen_base.h | 42 | ||||
-rw-r--r-- | absl/random/mocking_bit_gen.cc | 1 | ||||
-rw-r--r-- | absl/random/mocking_bit_gen.h | 6 | ||||
-rw-r--r-- | absl/random/mocking_bit_gen_test.cc | 6 |
11 files changed, 375 insertions, 26 deletions
diff --git a/absl/random/BUILD.bazel b/absl/random/BUILD.bazel index 43ed9840a035..2585b39742e2 100644 --- a/absl/random/BUILD.bazel +++ b/absl/random/BUILD.bazel @@ -53,6 +53,7 @@ cc_library( "bernoulli_distribution.h", "beta_distribution.h", "discrete_distribution.h", + "distribution_format_traits.h", "distributions.h", "exponential_distribution.h", "gaussian_distribution.h", diff --git a/absl/random/CMakeLists.txt b/absl/random/CMakeLists.txt index 53f1aa5c775c..46dbc3efbc83 100644 --- a/absl/random/CMakeLists.txt +++ b/absl/random/CMakeLists.txt @@ -78,6 +78,7 @@ absl_cc_library( ${ABSL_DEFAULT_LINKOPTS} DEPS absl::random_random + absl::strings ) # Internal-only target, do not depend on directly. @@ -167,6 +168,7 @@ absl_cc_library( "bernoulli_distribution.h" "beta_distribution.h" "discrete_distribution.h" + "distribution_format_traits.h" "distributions.h" "exponential_distribution.h" "gaussian_distribution.h" diff --git a/absl/random/bit_gen_ref.h b/absl/random/bit_gen_ref.h index 59591a479d8c..e8771162e5fb 100644 --- a/absl/random/bit_gen_ref.h +++ b/absl/random/bit_gen_ref.h @@ -132,7 +132,7 @@ namespace random_internal { template <> struct DistributionCaller<absl::BitGenRef> { - template <typename DistrT, typename... Args> + template <typename DistrT, typename FormatT, typename... Args> static typename DistrT::result_type Call(absl::BitGenRef* gen_ref, Args&&... args) { auto* mock_ptr = gen_ref->mocked_gen_ptr_; @@ -140,7 +140,8 @@ struct DistributionCaller<absl::BitGenRef> { DistrT dist(std::forward<Args>(args)...); return dist(*gen_ref); } else { - return mock_ptr->template Call<DistrT>(std::forward<Args>(args)...); + return mock_ptr->template Call<DistrT, FormatT>( + std::forward<Args>(args)...); } } }; diff --git a/absl/random/distribution_format_traits.h b/absl/random/distribution_format_traits.h new file mode 100644 index 000000000000..22b358cc8c37 --- /dev/null +++ b/absl/random/distribution_format_traits.h @@ -0,0 +1,278 @@ +// +// Copyright 2018 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#ifndef ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_ +#define ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_ + +#include <string> +#include <tuple> +#include <typeinfo> + +#include "absl/meta/type_traits.h" +#include "absl/random/bernoulli_distribution.h" +#include "absl/random/beta_distribution.h" +#include "absl/random/exponential_distribution.h" +#include "absl/random/gaussian_distribution.h" +#include "absl/random/log_uniform_int_distribution.h" +#include "absl/random/poisson_distribution.h" +#include "absl/random/uniform_int_distribution.h" +#include "absl/random/uniform_real_distribution.h" +#include "absl/random/zipf_distribution.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace absl { +ABSL_NAMESPACE_BEGIN + +struct IntervalClosedClosedTag; +struct IntervalClosedOpenTag; +struct IntervalOpenClosedTag; +struct IntervalOpenOpenTag; + +namespace random_internal { + +// ScalarTypeName defines a preferred hierarchy of preferred type names for +// scalars, and is evaluated at compile time for the specific type +// specialization. +template <typename T> +constexpr const char* ScalarTypeName() { + static_assert(std::is_integral<T>() || std::is_floating_point<T>(), ""); + // clang-format off + return + std::is_same<T, float>::value ? "float" : + std::is_same<T, double>::value ? "double" : + std::is_same<T, long double>::value ? "long double" : + std::is_same<T, bool>::value ? "bool" : + std::is_signed<T>::value && sizeof(T) == 1 ? "int8_t" : + std::is_signed<T>::value && sizeof(T) == 2 ? "int16_t" : + std::is_signed<T>::value && sizeof(T) == 4 ? "int32_t" : + std::is_signed<T>::value && sizeof(T) == 8 ? "int64_t" : + std::is_unsigned<T>::value && sizeof(T) == 1 ? "uint8_t" : + std::is_unsigned<T>::value && sizeof(T) == 2 ? "uint16_t" : + std::is_unsigned<T>::value && sizeof(T) == 4 ? "uint32_t" : + std::is_unsigned<T>::value && sizeof(T) == 8 ? "uint64_t" : + "undefined"; + // clang-format on + + // NOTE: It would be nice to use typeid(T).name(), but that's an + // implementation-defined attribute which does not necessarily + // correspond to a name. We could potentially demangle it + // using, e.g. abi::__cxa_demangle. +} + +// Distribution traits used by DistributionCaller and internal implementation +// details of the mocking framework. +/* +struct DistributionFormatTraits { + // Returns the parameterized name of the distribution function. + static constexpr const char* FunctionName() + // Format DistrT parameters. + static std::string FormatArgs(DistrT& dist); + // Format DistrT::result_type results. + static std::string FormatResults(DistrT& dist); +}; +*/ +template <typename DistrT> +struct DistributionFormatTraits; + +template <typename R> +struct DistributionFormatTraits<absl::uniform_int_distribution<R>> { + using distribution_t = absl::uniform_int_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Uniform"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat("absl::IntervalClosedClosed, ", (d.min)(), ", ", + (d.max)()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::uniform_real_distribution<R>> { + using distribution_t = absl::uniform_real_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Uniform"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat((d.min)(), ", ", (d.max)()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::exponential_distribution<R>> { + using distribution_t = absl::exponential_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Exponential"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.lambda()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::poisson_distribution<R>> { + using distribution_t = absl::poisson_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Poisson"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.mean()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <> +struct DistributionFormatTraits<absl::bernoulli_distribution> { + using distribution_t = absl::bernoulli_distribution; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Bernoulli"; } + + static constexpr const char* FunctionName() { return Name(); } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.p()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::beta_distribution<R>> { + using distribution_t = absl::beta_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Beta"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.alpha(), ", ", d.beta()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::zipf_distribution<R>> { + using distribution_t = absl::zipf_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Zipf"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat(d.k(), ", ", d.v(), ", ", d.q()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::gaussian_distribution<R>> { + using distribution_t = absl::gaussian_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "Gaussian"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrJoin(std::make_tuple(d.mean(), d.stddev()), ", "); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename R> +struct DistributionFormatTraits<absl::log_uniform_int_distribution<R>> { + using distribution_t = absl::log_uniform_int_distribution<R>; + using result_t = typename distribution_t::result_type; + + static constexpr const char* Name() { return "LogUniform"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrJoin(std::make_tuple((d.min)(), (d.max)(), d.base()), ", "); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +template <typename NumType> +struct UniformDistributionWrapper; + +template <typename NumType> +struct DistributionFormatTraits<UniformDistributionWrapper<NumType>> { + using distribution_t = UniformDistributionWrapper<NumType>; + using result_t = NumType; + + static constexpr const char* Name() { return "Uniform"; } + + static std::string FunctionName() { + return absl::StrCat(Name(), "<", ScalarTypeName<NumType>(), ">"); + } + static std::string FormatArgs(const distribution_t& d) { + return absl::StrCat((d.min)(), ", ", (d.max)()); + } + static std::string FormatResults(absl::Span<const result_t> results) { + return absl::StrJoin(results, ", "); + } +}; + +} // namespace random_internal +ABSL_NAMESPACE_END +} // namespace absl + +#endif // ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_ diff --git a/absl/random/distributions.h b/absl/random/distributions.h index 7abdfa8f202a..c1fb66501593 100644 --- a/absl/random/distributions.h +++ b/absl/random/distributions.h @@ -55,6 +55,7 @@ #include "absl/base/internal/inline_variable.h" #include "absl/random/bernoulli_distribution.h" #include "absl/random/beta_distribution.h" +#include "absl/random/distribution_format_traits.h" #include "absl/random/exponential_distribution.h" #include "absl/random/gaussian_distribution.h" #include "absl/random/internal/distributions.h" // IWYU pragma: export @@ -125,13 +126,14 @@ Uniform(TagType tag, R lo, R hi) { using gen_t = absl::decay_t<URBG>; using distribution_t = random_internal::UniformDistributionWrapper<R>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; auto a = random_internal::uniform_lower_bound(tag, lo, hi); auto b = random_internal::uniform_upper_bound(tag, lo, hi); if (a > b) return a; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, tag, lo, hi); + distribution_t, format_t>(&urbg, tag, lo, hi); } // absl::Uniform<T>(bitgen, lo, hi) @@ -144,6 +146,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) R lo, R hi) { using gen_t = absl::decay_t<URBG>; using distribution_t = random_internal::UniformDistributionWrapper<R>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; constexpr auto tag = absl::IntervalClosedOpen; auto a = random_internal::uniform_lower_bound(tag, lo, hi); @@ -151,7 +154,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) if (a > b) return a; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, lo, hi); + distribution_t, format_t>(&urbg, lo, hi); } // absl::Uniform(tag, bitgen, lo, hi) @@ -169,14 +172,15 @@ Uniform(TagType tag, using gen_t = absl::decay_t<URBG>; using return_t = typename random_internal::uniform_inferred_return_t<A, B>; using distribution_t = random_internal::UniformDistributionWrapper<return_t>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; auto a = random_internal::uniform_lower_bound<return_t>(tag, lo, hi); auto b = random_internal::uniform_upper_bound<return_t>(tag, lo, hi); if (a > b) return a; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, tag, static_cast<return_t>(lo), - static_cast<return_t>(hi)); + distribution_t, format_t>(&urbg, tag, static_cast<return_t>(lo), + static_cast<return_t>(hi)); } // absl::Uniform(bitgen, lo, hi) @@ -192,6 +196,7 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using return_t = typename random_internal::uniform_inferred_return_t<A, B>; using distribution_t = random_internal::UniformDistributionWrapper<return_t>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; constexpr auto tag = absl::IntervalClosedOpen; auto a = random_internal::uniform_lower_bound<return_t>(tag, lo, hi); @@ -199,8 +204,8 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) if (a > b) return a; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, static_cast<return_t>(lo), - static_cast<return_t>(hi)); + distribution_t, format_t>(&urbg, static_cast<return_t>(lo), + static_cast<return_t>(hi)); } // absl::Uniform<unsigned T>(bitgen) @@ -212,9 +217,10 @@ typename absl::enable_if_t<!std::is_signed<R>::value, R> // Uniform(URBG&& urbg) { // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using distribution_t = random_internal::UniformDistributionWrapper<R>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg); + distribution_t, format_t>(&urbg); } // ----------------------------------------------------------------------------- @@ -242,9 +248,10 @@ bool Bernoulli(URBG&& urbg, // NOLINT(runtime/references) double p) { using gen_t = absl::decay_t<URBG>; using distribution_t = absl::bernoulli_distribution; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, p); + distribution_t, format_t>(&urbg, p); } // ----------------------------------------------------------------------------- @@ -274,9 +281,10 @@ RealType Beta(URBG&& urbg, // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using distribution_t = typename absl::beta_distribution<RealType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, alpha, beta); + distribution_t, format_t>(&urbg, alpha, beta); } // ----------------------------------------------------------------------------- @@ -306,9 +314,10 @@ RealType Exponential(URBG&& urbg, // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using distribution_t = typename absl::exponential_distribution<RealType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, lambda); + distribution_t, format_t>(&urbg, lambda); } // ----------------------------------------------------------------------------- @@ -337,9 +346,10 @@ RealType Gaussian(URBG&& urbg, // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using distribution_t = typename absl::gaussian_distribution<RealType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, mean, stddev); + distribution_t, format_t>(&urbg, mean, stddev); } // ----------------------------------------------------------------------------- @@ -379,9 +389,10 @@ IntType LogUniform(URBG&& urbg, // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using distribution_t = typename absl::log_uniform_int_distribution<IntType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, lo, hi, base); + distribution_t, format_t>(&urbg, lo, hi, base); } // ----------------------------------------------------------------------------- @@ -409,9 +420,10 @@ IntType Poisson(URBG&& urbg, // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using distribution_t = typename absl::poisson_distribution<IntType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, mean); + distribution_t, format_t>(&urbg, mean); } // ----------------------------------------------------------------------------- @@ -441,9 +453,10 @@ IntType Zipf(URBG&& urbg, // NOLINT(runtime/references) using gen_t = absl::decay_t<URBG>; using distribution_t = typename absl::zipf_distribution<IntType>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; return random_internal::DistributionCaller<gen_t>::template Call< - distribution_t>(&urbg, hi, q, v); + distribution_t, format_t>(&urbg, hi, q, v); } ABSL_NAMESPACE_END diff --git a/absl/random/internal/BUILD.bazel b/absl/random/internal/BUILD.bazel index 8839078f5f9e..d7ad4efec989 100644 --- a/absl/random/internal/BUILD.bazel +++ b/absl/random/internal/BUILD.bazel @@ -509,6 +509,7 @@ cc_library( linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ "//absl/random", + "//absl/strings", ], ) diff --git a/absl/random/internal/distribution_caller.h b/absl/random/internal/distribution_caller.h index ae2680ddd40a..02603cf84355 100644 --- a/absl/random/internal/distribution_caller.h +++ b/absl/random/internal/distribution_caller.h @@ -31,8 +31,22 @@ namespace random_internal { template <typename URBG> struct DistributionCaller { // Call the provided distribution type. The parameters are expected - // to be explicitly specified. DistrT is the distribution type. - template <typename DistrT, typename... Args> + // to be explicitly specified. + // DistrT is the distribution type. + // FormatT is the formatter type: + // + // struct FormatT { + // using result_type = distribution_t::result_type; + // static std::string FormatCall( + // const distribution_t& distr, + // absl::Span<const result_type>); + // + // static std::string FormatExpectation( + // absl::string_view match_args, + // absl::Span<const result_t> results); + // } + // + template <typename DistrT, typename FormatT, typename... Args> static typename DistrT::result_type Call(URBG* urbg, Args&&... args) { DistrT dist(std::forward<Args>(args)...); return dist(*urbg); diff --git a/absl/random/internal/mocking_bit_gen_base.h b/absl/random/internal/mocking_bit_gen_base.h index acd6387204aa..eeeae9d295b2 100644 --- a/absl/random/internal/mocking_bit_gen_base.h +++ b/absl/random/internal/mocking_bit_gen_base.h @@ -16,14 +16,39 @@ #ifndef ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_ #define ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_ +#include <atomic> +#include <deque> +#include <string> #include <typeinfo> #include "absl/random/random.h" +#include "absl/strings/str_cat.h" namespace absl { ABSL_NAMESPACE_BEGIN namespace random_internal { +// MockingBitGenExpectationFormatter is invoked to format unsatisfied mocks +// and remaining results into a description string. +template <typename DistrT, typename FormatT> +struct MockingBitGenExpectationFormatter { + std::string operator()(absl::string_view args) { + return absl::StrCat(FormatT::FunctionName(), "(", args, ")"); + } +}; + +// MockingBitGenCallFormatter is invoked to format each distribution call +// into a description string for the mock log. +template <typename DistrT, typename FormatT> +struct MockingBitGenCallFormatter { + std::string operator()(const DistrT& dist, + const typename DistrT::result_type& result) { + return absl::StrCat( + FormatT::FunctionName(), "(", FormatT::FormatArgs(dist), ") => {", + FormatT::FormatResults(absl::MakeSpan(&result, 1)), "}"); + } +}; + class MockingBitGenBase { template <typename> friend struct DistributionCaller; @@ -36,9 +61,14 @@ class MockingBitGenBase { static constexpr result_type(max)() { return (generator_type::max)(); } result_type operator()() { return gen_(); } + MockingBitGenBase() : gen_(), observed_call_log_() {} virtual ~MockingBitGenBase() = default; protected: + const std::deque<std::string>& observed_call_log() { + return observed_call_log_; + } + // CallImpl is the type-erased virtual dispatch. // The type of dist is always distribution<T>, // The type of result is always distribution<T>::result_type. @@ -51,9 +81,10 @@ class MockingBitGenBase { } // Call the generating distribution function. - // Invoked by DistributionCaller<>::Call<DistT>. + // Invoked by DistributionCaller<>::Call<DistT, FormatT>. // DistT is the distribution type. - template <typename DistrT, typename... Args> + // FormatT is the distribution formatter traits type. + template <typename DistrT, typename FormatT, typename... Args> typename DistrT::result_type Call(Args&&... args) { using distr_result_type = typename DistrT::result_type; using ArgTupleT = std::tuple<absl::decay_t<Args>...>; @@ -68,11 +99,18 @@ class MockingBitGenBase { if (!found_match) { result = dist(gen_); } + + // TODO(asoffer): Forwarding the args through means we no longer need to + // extract them from the from the distribution in formatter traits. We can + // just StrJoin them. + observed_call_log_.push_back( + MockingBitGenCallFormatter<DistrT, FormatT>{}(dist, result)); return result; } private: generator_type gen_; + std::deque<std::string> observed_call_log_; }; // namespace random_internal } // namespace random_internal diff --git a/absl/random/mocking_bit_gen.cc b/absl/random/mocking_bit_gen.cc index 022091154541..6bb1e414aeab 100644 --- a/absl/random/mocking_bit_gen.cc +++ b/absl/random/mocking_bit_gen.cc @@ -20,6 +20,7 @@ namespace absl { ABSL_NAMESPACE_BEGIN MockingBitGen::~MockingBitGen() { + for (const auto& del : deleters_) { del(); } diff --git a/absl/random/mocking_bit_gen.h b/absl/random/mocking_bit_gen.h index 246c5b1e035a..36cef91113e3 100644 --- a/absl/random/mocking_bit_gen.h +++ b/absl/random/mocking_bit_gen.h @@ -109,7 +109,7 @@ class MockingBitGen : public absl::random_internal::MockingBitGenBase { // MockingBitGen::Register // - // Register<DistrT, ArgTupleT> is the main extension point for + // Register<DistrT, FormatT, ArgTupleT> is the main extension point for // extending the MockingBitGen framework. It provides a mechanism to install a // mock expectation for the distribution `distr_t` onto the MockingBitGen // context. @@ -182,10 +182,10 @@ namespace random_internal { template <> struct DistributionCaller<absl::MockingBitGen> { - template <typename DistrT, typename... Args> + template <typename DistrT, typename FormatT, typename... Args> static typename DistrT::result_type Call(absl::MockingBitGen* gen, Args&&... args) { - return gen->template Call<DistrT>(std::forward<Args>(args)...); + return gen->template Call<DistrT, FormatT>(std::forward<Args>(args)...); } }; diff --git a/absl/random/mocking_bit_gen_test.cc b/absl/random/mocking_bit_gen_test.cc index dcf74fd6db85..f0ffc9ac9283 100644 --- a/absl/random/mocking_bit_gen_test.cc +++ b/absl/random/mocking_bit_gen_test.cc @@ -66,10 +66,10 @@ TEST(BasicMocking, AllDistributionsAreOverridable) { .WillOnce(Return(0.001)); EXPECT_EQ(absl::Gaussian<double>(gen, 0.0, 1.0), 0.001); - EXPECT_NE(absl::LogUniform<int>(gen, 0, 1000000, 2), 2040); + EXPECT_NE(absl::LogUniform<int>(gen, 0, 1000000, 2), 500000); EXPECT_CALL(absl::MockLogUniform<int>(), Call(gen, 0, 1000000, 2)) - .WillOnce(Return(2040)); - EXPECT_EQ(absl::LogUniform<int>(gen, 0, 1000000, 2), 2040); + .WillOnce(Return(500000)); + EXPECT_EQ(absl::LogUniform<int>(gen, 0, 1000000, 2), 500000); } TEST(BasicMocking, OnDistribution) { |