diff options
Diffstat (limited to 'absl/random')
-rw-r--r-- | absl/random/BUILD.bazel | 2 | ||||
-rw-r--r-- | absl/random/distribution_format_traits.h | 22 | ||||
-rw-r--r-- | absl/random/distributions.h | 56 | ||||
-rw-r--r-- | absl/random/internal/distributions.h | 30 | ||||
-rw-r--r-- | absl/random/internal/uniform_helper.h | 12 |
5 files changed, 58 insertions, 64 deletions
diff --git a/absl/random/BUILD.bazel b/absl/random/BUILD.bazel index c904618da2a0..92111827c150 100644 --- a/absl/random/BUILD.bazel +++ b/absl/random/BUILD.bazel @@ -26,7 +26,7 @@ load( package(default_visibility = ["//visibility:public"]) -licenses(["notice"]) # Apache 2.0 +licenses(["notice"]) cc_library( name = "random", diff --git a/absl/random/distribution_format_traits.h b/absl/random/distribution_format_traits.h index 3298c2cdb64b..f9f070589f00 100644 --- a/absl/random/distribution_format_traits.h +++ b/absl/random/distribution_format_traits.h @@ -249,12 +249,12 @@ struct DistributionFormatTraits<absl::log_uniform_int_distribution<R>> { } }; -template <typename TagType, typename NumType> +template <typename NumType> struct UniformDistributionWrapper; -template <typename TagType, typename NumType> -struct DistributionFormatTraits<UniformDistributionWrapper<TagType, NumType>> { - using distribution_t = UniformDistributionWrapper<TagType, NumType>; +template <typename NumType> +struct DistributionFormatTraits<UniformDistributionWrapper<NumType>> { + using distribution_t = UniformDistributionWrapper<NumType>; using result_t = NumType; static constexpr const char* Name() { return "Uniform"; } @@ -263,19 +263,7 @@ struct DistributionFormatTraits<UniformDistributionWrapper<TagType, NumType>> { return absl::StrCat(Name(), "<", ScalarTypeName<NumType>(), ">"); } static std::string FormatArgs(const distribution_t& d) { - absl::string_view tag; - if (std::is_same<TagType, IntervalClosedClosedTag>::value) { - tag = "IntervalClosedClosed"; - } else if (std::is_same<TagType, IntervalClosedOpenTag>::value) { - tag = "IntervalClosedOpen"; - } else if (std::is_same<TagType, IntervalOpenClosedTag>::value) { - tag = "IntervalOpenClosed"; - } else if (std::is_same<TagType, IntervalOpenOpenTag>::value) { - tag = "IntervalOpenOpen"; - } else { - tag = "[[unknown tag type]]"; - } - return absl::StrCat(tag, ", ", (d.min)(), ", ", (d.max)()); + return absl::StrCat((d.min)(), ", ", (d.max)()); } static std::string FormatResults(absl::Span<const result_t> results) { return absl::StrJoin(results, ", "); diff --git a/absl/random/distributions.h b/absl/random/distributions.h index 3a4e93abfa2f..6ced60616158 100644 --- a/absl/random/distributions.h +++ b/absl/random/distributions.h @@ -124,7 +124,15 @@ Uniform(TagType tag, URBG&& urbg, // NOLINT(runtime/references) R lo, R hi) { using gen_t = absl::decay_t<URBG>; - return random_internal::UniformImpl<R, TagType, gen_t>(tag, urbg, lo, hi); + using distribution_t = random_internal::UniformDistributionWrapper<R>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + auto a = random_internal::uniform_lower_bound(tag, lo, hi); + auto b = random_internal::uniform_upper_bound(tag, lo, hi); + if (a > b) return a; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, tag, lo, hi); } // absl::Uniform<T>(bitgen, lo, hi) @@ -135,11 +143,17 @@ template <typename R = void, typename URBG> typename absl::enable_if_t<!std::is_same<R, void>::value, R> // Uniform(URBG&& urbg, // NOLINT(runtime/references) R lo, R hi) { - constexpr auto tag = absl::IntervalClosedOpen; - using tag_t = decltype(tag); using gen_t = absl::decay_t<URBG>; + using distribution_t = random_internal::UniformDistributionWrapper<R>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; + + constexpr auto tag = absl::IntervalClosedOpen; + auto a = random_internal::uniform_lower_bound(tag, lo, hi); + auto b = random_internal::uniform_upper_bound(tag, lo, hi); + if (a > b) return a; - return random_internal::UniformImpl<R, tag_t, gen_t>(tag, urbg, lo, hi); + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, lo, hi); } // absl::Uniform(tag, bitgen, lo, hi) @@ -156,9 +170,16 @@ Uniform(TagType tag, A lo, B hi) { using gen_t = absl::decay_t<URBG>; using return_t = typename random_internal::uniform_inferred_return_t<A, B>; + using distribution_t = random_internal::UniformDistributionWrapper<return_t>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; - return random_internal::UniformImpl<return_t, TagType, gen_t>(tag, urbg, lo, - hi); + auto a = random_internal::uniform_lower_bound<return_t>(tag, lo, hi); + auto b = random_internal::uniform_upper_bound<return_t>(tag, lo, hi); + if (a > b) return a; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, tag, static_cast<return_t>(lo), + static_cast<return_t>(hi)); } // absl::Uniform(bitgen, lo, hi) @@ -171,13 +192,19 @@ typename absl::enable_if_t<std::is_same<R, void>::value, random_internal::uniform_inferred_return_t<A, B>> Uniform(URBG&& urbg, // NOLINT(runtime/references) A lo, B hi) { - constexpr auto tag = absl::IntervalClosedOpen; - using tag_t = decltype(tag); using gen_t = absl::decay_t<URBG>; using return_t = typename random_internal::uniform_inferred_return_t<A, B>; + using distribution_t = random_internal::UniformDistributionWrapper<return_t>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; - return random_internal::UniformImpl<return_t, tag_t, gen_t>(tag, urbg, lo, - hi); + constexpr auto tag = absl::IntervalClosedOpen; + auto a = random_internal::uniform_lower_bound<return_t>(tag, lo, hi); + auto b = random_internal::uniform_upper_bound<return_t>(tag, lo, hi); + if (a > b) return a; + + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg, static_cast<return_t>(lo), + static_cast<return_t>(hi)); } // absl::Uniform<unsigned T>(bitgen) @@ -187,13 +214,12 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references) template <typename R, typename URBG> typename absl::enable_if_t<!std::is_signed<R>::value, R> // Uniform(URBG&& urbg) { // NOLINT(runtime/references) - constexpr auto tag = absl::IntervalClosedClosed; - constexpr auto lo = std::numeric_limits<R>::lowest(); - constexpr auto hi = (std::numeric_limits<R>::max)(); - using tag_t = decltype(tag); using gen_t = absl::decay_t<URBG>; + using distribution_t = random_internal::UniformDistributionWrapper<R>; + using format_t = random_internal::DistributionFormatTraits<distribution_t>; - return random_internal::UniformImpl<R, tag_t, gen_t>(tag, urbg, lo, hi); + return random_internal::DistributionCaller<gen_t>::template Call< + distribution_t, format_t>(&urbg); } // ----------------------------------------------------------------------------- diff --git a/absl/random/internal/distributions.h b/absl/random/internal/distributions.h index 96f8bae3918f..c8cec02b7fbc 100644 --- a/absl/random/internal/distributions.h +++ b/absl/random/internal/distributions.h @@ -24,36 +24,6 @@ namespace absl { namespace random_internal { -template <typename D> -struct DistributionFormatTraits; - -// UniformImpl implements the core logic of the Uniform<T> call, which is to -// select the correct distribution type, compute the bounds based on the -// interval tag, and then generate a value. -template <typename NumType, typename TagType, typename URBG> -NumType UniformImpl(TagType tag, - URBG& urbg, // NOLINT(runtime/references) - NumType lo, NumType hi) { - static_assert( - std::is_arithmetic<NumType>::value, - "absl::Uniform<T>() must use an integer or real parameter type."); - - using distribution_t = - UniformDistributionWrapper<absl::decay_t<TagType>, NumType>; - using format_t = random_internal::DistributionFormatTraits<distribution_t>; - auto a = uniform_lower_bound(tag, lo, hi); - auto b = uniform_upper_bound(tag, lo, hi); - - // TODO(lar): it doesn't make a lot of sense to ask for a random number in an - // empty range. Right now we just return a boundary--even though that - // boundary is not an acceptable value! Is there something better we can do - // here? - if (a > b) return a; - - using gen_t = absl::decay_t<URBG>; - return DistributionCaller<gen_t>::template Call<distribution_t, format_t>( - &urbg, tag, lo, hi); -} // In the absence of an explicitly provided return-type, the template // "uniform_inferred_return_t<A, B>" is used to derive a suitable type, based on diff --git a/absl/random/internal/uniform_helper.h b/absl/random/internal/uniform_helper.h index 2929407e1497..f68b1823ef00 100644 --- a/absl/random/internal/uniform_helper.h +++ b/absl/random/internal/uniform_helper.h @@ -154,12 +154,22 @@ using UniformDistribution = absl::uniform_int_distribution<NumType>, absl::uniform_real_distribution<NumType>>::type; -template <typename TagType, typename NumType> +template <typename NumType> struct UniformDistributionWrapper : public UniformDistribution<NumType> { + template <typename TagType> explicit UniformDistributionWrapper(TagType, NumType lo, NumType hi) : UniformDistribution<NumType>( uniform_lower_bound<NumType>(TagType{}, lo, hi), uniform_upper_bound<NumType>(TagType{}, lo, hi)) {} + + explicit UniformDistributionWrapper(NumType lo, NumType hi) + : UniformDistribution<NumType>( + uniform_lower_bound<NumType>(IntervalClosedOpenTag(), lo, hi), + uniform_upper_bound<NumType>(IntervalClosedOpenTag(), lo, hi)) {} + + explicit UniformDistributionWrapper() + : UniformDistribution<NumType>(std::numeric_limits<NumType>::lowest(), + (std::numeric_limits<NumType>::max)()) {} }; } // namespace random_internal |