about summary refs log tree commit diff
path: root/absl/random
diff options
context:
space:
mode:
Diffstat (limited to 'absl/random')
-rw-r--r--absl/random/BUILD.bazel2
-rw-r--r--absl/random/distribution_format_traits.h22
-rw-r--r--absl/random/distributions.h56
-rw-r--r--absl/random/internal/distributions.h30
-rw-r--r--absl/random/internal/uniform_helper.h12
5 files changed, 58 insertions, 64 deletions
diff --git a/absl/random/BUILD.bazel b/absl/random/BUILD.bazel
index c904618da2a0..92111827c150 100644
--- a/absl/random/BUILD.bazel
+++ b/absl/random/BUILD.bazel
@@ -26,7 +26,7 @@ load(
 
 package(default_visibility = ["//visibility:public"])
 
-licenses(["notice"])  # Apache 2.0
+licenses(["notice"])
 
 cc_library(
     name = "random",
diff --git a/absl/random/distribution_format_traits.h b/absl/random/distribution_format_traits.h
index 3298c2cdb64b..f9f070589f00 100644
--- a/absl/random/distribution_format_traits.h
+++ b/absl/random/distribution_format_traits.h
@@ -249,12 +249,12 @@ struct DistributionFormatTraits<absl::log_uniform_int_distribution<R>> {
   }
 };
 
-template <typename TagType, typename NumType>
+template <typename NumType>
 struct UniformDistributionWrapper;
 
-template <typename TagType, typename NumType>
-struct DistributionFormatTraits<UniformDistributionWrapper<TagType, NumType>> {
-  using distribution_t = UniformDistributionWrapper<TagType, NumType>;
+template <typename NumType>
+struct DistributionFormatTraits<UniformDistributionWrapper<NumType>> {
+  using distribution_t = UniformDistributionWrapper<NumType>;
   using result_t = NumType;
 
   static constexpr const char* Name() { return "Uniform"; }
@@ -263,19 +263,7 @@ struct DistributionFormatTraits<UniformDistributionWrapper<TagType, NumType>> {
     return absl::StrCat(Name(), "<", ScalarTypeName<NumType>(), ">");
   }
   static std::string FormatArgs(const distribution_t& d) {
-    absl::string_view tag;
-    if (std::is_same<TagType, IntervalClosedClosedTag>::value) {
-      tag = "IntervalClosedClosed";
-    } else if (std::is_same<TagType, IntervalClosedOpenTag>::value) {
-      tag = "IntervalClosedOpen";
-    } else if (std::is_same<TagType, IntervalOpenClosedTag>::value) {
-      tag = "IntervalOpenClosed";
-    } else if (std::is_same<TagType, IntervalOpenOpenTag>::value) {
-      tag = "IntervalOpenOpen";
-    } else {
-      tag = "[[unknown tag type]]";
-    }
-    return absl::StrCat(tag, ", ", (d.min)(), ", ", (d.max)());
+    return absl::StrCat((d.min)(), ", ", (d.max)());
   }
   static std::string FormatResults(absl::Span<const result_t> results) {
     return absl::StrJoin(results, ", ");
diff --git a/absl/random/distributions.h b/absl/random/distributions.h
index 3a4e93abfa2f..6ced60616158 100644
--- a/absl/random/distributions.h
+++ b/absl/random/distributions.h
@@ -124,7 +124,15 @@ Uniform(TagType tag,
         URBG&& urbg,  // NOLINT(runtime/references)
         R lo, R hi) {
   using gen_t = absl::decay_t<URBG>;
-  return random_internal::UniformImpl<R, TagType, gen_t>(tag, urbg, lo, hi);
+  using distribution_t = random_internal::UniformDistributionWrapper<R>;
+  using format_t = random_internal::DistributionFormatTraits<distribution_t>;
+
+  auto a = random_internal::uniform_lower_bound(tag, lo, hi);
+  auto b = random_internal::uniform_upper_bound(tag, lo, hi);
+  if (a > b) return a;
+
+  return random_internal::DistributionCaller<gen_t>::template Call<
+      distribution_t, format_t>(&urbg, tag, lo, hi);
 }
 
 // absl::Uniform<T>(bitgen, lo, hi)
@@ -135,11 +143,17 @@ template <typename R = void, typename URBG>
 typename absl::enable_if_t<!std::is_same<R, void>::value, R>  //
 Uniform(URBG&& urbg,  // NOLINT(runtime/references)
         R lo, R hi) {
-  constexpr auto tag = absl::IntervalClosedOpen;
-  using tag_t = decltype(tag);
   using gen_t = absl::decay_t<URBG>;
+  using distribution_t = random_internal::UniformDistributionWrapper<R>;
+  using format_t = random_internal::DistributionFormatTraits<distribution_t>;
+
+  constexpr auto tag = absl::IntervalClosedOpen;
+  auto a = random_internal::uniform_lower_bound(tag, lo, hi);
+  auto b = random_internal::uniform_upper_bound(tag, lo, hi);
+  if (a > b) return a;
 
-  return random_internal::UniformImpl<R, tag_t, gen_t>(tag, urbg, lo, hi);
+  return random_internal::DistributionCaller<gen_t>::template Call<
+      distribution_t, format_t>(&urbg, lo, hi);
 }
 
 // absl::Uniform(tag, bitgen, lo, hi)
@@ -156,9 +170,16 @@ Uniform(TagType tag,
         A lo, B hi) {
   using gen_t = absl::decay_t<URBG>;
   using return_t = typename random_internal::uniform_inferred_return_t<A, B>;
+  using distribution_t = random_internal::UniformDistributionWrapper<return_t>;
+  using format_t = random_internal::DistributionFormatTraits<distribution_t>;
 
-  return random_internal::UniformImpl<return_t, TagType, gen_t>(tag, urbg, lo,
-                                                                hi);
+  auto a = random_internal::uniform_lower_bound<return_t>(tag, lo, hi);
+  auto b = random_internal::uniform_upper_bound<return_t>(tag, lo, hi);
+  if (a > b) return a;
+
+  return random_internal::DistributionCaller<gen_t>::template Call<
+      distribution_t, format_t>(&urbg, tag, static_cast<return_t>(lo),
+                                static_cast<return_t>(hi));
 }
 
 // absl::Uniform(bitgen, lo, hi)
@@ -171,13 +192,19 @@ typename absl::enable_if_t<std::is_same<R, void>::value,
                            random_internal::uniform_inferred_return_t<A, B>>
 Uniform(URBG&& urbg,  // NOLINT(runtime/references)
         A lo, B hi) {
-  constexpr auto tag = absl::IntervalClosedOpen;
-  using tag_t = decltype(tag);
   using gen_t = absl::decay_t<URBG>;
   using return_t = typename random_internal::uniform_inferred_return_t<A, B>;
+  using distribution_t = random_internal::UniformDistributionWrapper<return_t>;
+  using format_t = random_internal::DistributionFormatTraits<distribution_t>;
 
-  return random_internal::UniformImpl<return_t, tag_t, gen_t>(tag, urbg, lo,
-                                                              hi);
+  constexpr auto tag = absl::IntervalClosedOpen;
+  auto a = random_internal::uniform_lower_bound<return_t>(tag, lo, hi);
+  auto b = random_internal::uniform_upper_bound<return_t>(tag, lo, hi);
+  if (a > b) return a;
+
+  return random_internal::DistributionCaller<gen_t>::template Call<
+      distribution_t, format_t>(&urbg, static_cast<return_t>(lo),
+                                static_cast<return_t>(hi));
 }
 
 // absl::Uniform<unsigned T>(bitgen)
@@ -187,13 +214,12 @@ Uniform(URBG&& urbg,  // NOLINT(runtime/references)
 template <typename R, typename URBG>
 typename absl::enable_if_t<!std::is_signed<R>::value, R>  //
 Uniform(URBG&& urbg) {  // NOLINT(runtime/references)
-  constexpr auto tag = absl::IntervalClosedClosed;
-  constexpr auto lo = std::numeric_limits<R>::lowest();
-  constexpr auto hi = (std::numeric_limits<R>::max)();
-  using tag_t = decltype(tag);
   using gen_t = absl::decay_t<URBG>;
+  using distribution_t = random_internal::UniformDistributionWrapper<R>;
+  using format_t = random_internal::DistributionFormatTraits<distribution_t>;
 
-  return random_internal::UniformImpl<R, tag_t, gen_t>(tag, urbg, lo, hi);
+  return random_internal::DistributionCaller<gen_t>::template Call<
+      distribution_t, format_t>(&urbg);
 }
 
 // -----------------------------------------------------------------------------
diff --git a/absl/random/internal/distributions.h b/absl/random/internal/distributions.h
index 96f8bae3918f..c8cec02b7fbc 100644
--- a/absl/random/internal/distributions.h
+++ b/absl/random/internal/distributions.h
@@ -24,36 +24,6 @@
 
 namespace absl {
 namespace random_internal {
-template <typename D>
-struct DistributionFormatTraits;
-
-// UniformImpl implements the core logic of the Uniform<T> call, which is to
-// select the correct distribution type, compute the bounds based on the
-// interval tag, and then generate a value.
-template <typename NumType, typename TagType, typename URBG>
-NumType UniformImpl(TagType tag,
-                    URBG& urbg,  // NOLINT(runtime/references)
-                    NumType lo, NumType hi) {
-  static_assert(
-      std::is_arithmetic<NumType>::value,
-      "absl::Uniform<T>() must use an integer or real parameter type.");
-
-  using distribution_t =
-      UniformDistributionWrapper<absl::decay_t<TagType>, NumType>;
-  using format_t = random_internal::DistributionFormatTraits<distribution_t>;
-  auto a = uniform_lower_bound(tag, lo, hi);
-  auto b = uniform_upper_bound(tag, lo, hi);
-
-  // TODO(lar): it doesn't make a lot of sense to ask for a random number in an
-  // empty range.  Right now we just return a boundary--even though that
-  // boundary is not an acceptable value!  Is there something better we can do
-  // here?
-  if (a > b) return a;
-
-  using gen_t = absl::decay_t<URBG>;
-  return DistributionCaller<gen_t>::template Call<distribution_t, format_t>(
-      &urbg, tag, lo, hi);
-}
 
 // In the absence of an explicitly provided return-type, the template
 // "uniform_inferred_return_t<A, B>" is used to derive a suitable type, based on
diff --git a/absl/random/internal/uniform_helper.h b/absl/random/internal/uniform_helper.h
index 2929407e1497..f68b1823ef00 100644
--- a/absl/random/internal/uniform_helper.h
+++ b/absl/random/internal/uniform_helper.h
@@ -154,12 +154,22 @@ using UniformDistribution =
                               absl::uniform_int_distribution<NumType>,
                               absl::uniform_real_distribution<NumType>>::type;
 
-template <typename TagType, typename NumType>
+template <typename NumType>
 struct UniformDistributionWrapper : public UniformDistribution<NumType> {
+  template <typename TagType>
   explicit UniformDistributionWrapper(TagType, NumType lo, NumType hi)
       : UniformDistribution<NumType>(
             uniform_lower_bound<NumType>(TagType{}, lo, hi),
             uniform_upper_bound<NumType>(TagType{}, lo, hi)) {}
+
+  explicit UniformDistributionWrapper(NumType lo, NumType hi)
+      : UniformDistribution<NumType>(
+            uniform_lower_bound<NumType>(IntervalClosedOpenTag(), lo, hi),
+            uniform_upper_bound<NumType>(IntervalClosedOpenTag(), lo, hi)) {}
+
+  explicit UniformDistributionWrapper()
+      : UniformDistribution<NumType>(std::numeric_limits<NumType>::lowest(),
+                                     (std::numeric_limits<NumType>::max)()) {}
 };
 
 }  // namespace random_internal