about summary refs log tree commit diff
path: root/absl/base
diff options
context:
space:
mode:
Diffstat (limited to 'absl/base')
-rw-r--r--absl/base/BUILD.bazel1
-rw-r--r--absl/base/exception_safety_testing_test.cc250
-rw-r--r--absl/base/internal/exception_safety_testing.h128
3 files changed, 195 insertions, 184 deletions
diff --git a/absl/base/BUILD.bazel b/absl/base/BUILD.bazel
index e528463345b8..743add7e0127 100644
--- a/absl/base/BUILD.bazel
+++ b/absl/base/BUILD.bazel
@@ -241,6 +241,7 @@ cc_library(
         "//absl/memory",
         "//absl/meta:type_traits",
         "//absl/strings",
+        "//absl/types:optional",
         "@com_google_googletest//:gtest",
     ],
 )
diff --git a/absl/base/exception_safety_testing_test.cc b/absl/base/exception_safety_testing_test.cc
index 4931c8865af5..20cbb435926a 100644
--- a/absl/base/exception_safety_testing_test.cc
+++ b/absl/base/exception_safety_testing_test.cc
@@ -384,70 +384,62 @@ struct CallOperator {
   }
 };
 
-struct FailsBasicGuarantee {
+struct NonNegative {
+  friend testing::AssertionResult AbslCheckInvariants(NonNegative* g) {
+    if (g->i >= 0) return testing::AssertionSuccess();
+    return testing::AssertionFailure()
+           << "i should be non-negative but is " << g->i;
+  }
+  bool operator==(const NonNegative& other) const { return i == other.i; }
+
+  int i;
+};
+
+template <typename T>
+struct DefaultFactory {
+  std::unique_ptr<T> operator()() const { return absl::make_unique<T>(); }
+};
+
+struct FailsBasicGuarantee : public NonNegative {
   void operator()() {
     --i;
     ThrowingValue<> bomb;
     ++i;
   }
-
-  bool operator==(const FailsBasicGuarantee& other) const {
-    return i == other.i;
-  }
-
-  friend testing::AssertionResult AbslCheckInvariants(
-      const FailsBasicGuarantee& g) {
-    if (g.i >= 0) return testing::AssertionSuccess();
-    return testing::AssertionFailure()
-           << "i should be non-negative but is " << g.i;
-  }
-
-  int i = 0;
 };
 
 TEST(ExceptionCheckTest, BasicGuaranteeFailure) {
-  FailsBasicGuarantee g;
-  EXPECT_FALSE(TestExceptionSafety(&g, CallOperator{}));
+  EXPECT_FALSE(TestExceptionSafety(DefaultFactory<FailsBasicGuarantee>(),
+                                   CallOperator{}));
 }
 
-struct FollowsBasicGuarantee {
+struct FollowsBasicGuarantee : public NonNegative {
   void operator()() {
     ++i;
     ThrowingValue<> bomb;
   }
-
-  bool operator==(const FollowsBasicGuarantee& other) const {
-    return i == other.i;
-  }
-
-  friend testing::AssertionResult AbslCheckInvariants(
-      const FollowsBasicGuarantee& g) {
-    if (g.i >= 0) return testing::AssertionSuccess();
-    return testing::AssertionFailure()
-           << "i should be non-negative but is " << g.i;
-  }
-
-  int i = 0;
 };
 
 TEST(ExceptionCheckTest, BasicGuarantee) {
-  FollowsBasicGuarantee g;
-  EXPECT_TRUE(TestExceptionSafety(&g, CallOperator{}));
+  EXPECT_TRUE(TestExceptionSafety(DefaultFactory<FollowsBasicGuarantee>(),
+                                  CallOperator{}));
 }
 
 TEST(ExceptionCheckTest, StrongGuaranteeFailure) {
   {
-    FailsBasicGuarantee g;
-    EXPECT_FALSE(TestExceptionSafety(&g, CallOperator{}, StrongGuarantee(g)));
+    DefaultFactory<FailsBasicGuarantee> factory;
+    EXPECT_FALSE(
+        TestExceptionSafety(factory, CallOperator{}, StrongGuarantee(factory)));
   }
 
   {
-    FollowsBasicGuarantee g;
-    EXPECT_FALSE(TestExceptionSafety(&g, CallOperator{}, StrongGuarantee(g)));
+    DefaultFactory<FollowsBasicGuarantee> factory;
+    EXPECT_FALSE(
+        TestExceptionSafety(factory, CallOperator{}, StrongGuarantee(factory)));
   }
 }
 
-struct BasicGuaranteeWithExtraInvariants {
+struct BasicGuaranteeWithExtraInvariants : public NonNegative {
   // After operator(), i is incremented.  If operator() throws, i is set to 9999
   void operator()() {
     int old_i = i;
@@ -456,92 +448,94 @@ struct BasicGuaranteeWithExtraInvariants {
     i = ++old_i;
   }
 
-  bool operator==(const FollowsBasicGuarantee& other) const {
-    return i == other.i;
-  }
-
-  friend testing::AssertionResult AbslCheckInvariants(
-      const BasicGuaranteeWithExtraInvariants& g) {
-    if (g.i >= 0) return testing::AssertionSuccess();
-    return testing::AssertionFailure()
-           << "i should be non-negative but is " << g.i;
-  }
-
-  int i = 0;
   static constexpr int kExceptionSentinel = 9999;
 };
 constexpr int BasicGuaranteeWithExtraInvariants::kExceptionSentinel;
 
 TEST(ExceptionCheckTest, BasicGuaranteeWithInvariants) {
-  {
-    BasicGuaranteeWithExtraInvariants g;
-    EXPECT_TRUE(TestExceptionSafety(&g, CallOperator{}));
-  }
+  DefaultFactory<BasicGuaranteeWithExtraInvariants> factory;
 
-  {
-    BasicGuaranteeWithExtraInvariants g;
-    EXPECT_TRUE(TestExceptionSafety(
-        &g, CallOperator{}, [](const BasicGuaranteeWithExtraInvariants& w) {
-          if (w.i == BasicGuaranteeWithExtraInvariants::kExceptionSentinel) {
-            return testing::AssertionSuccess();
-          }
-          return testing::AssertionFailure()
-                 << "i should be "
-                 << BasicGuaranteeWithExtraInvariants::kExceptionSentinel
-                 << ", but is " << w.i;
-        }));
-  }
-}
+  EXPECT_TRUE(TestExceptionSafety(factory, CallOperator{}));
 
-struct FollowsStrongGuarantee {
+  EXPECT_TRUE(TestExceptionSafety(
+      factory, CallOperator{}, [](BasicGuaranteeWithExtraInvariants* w) {
+        if (w->i == BasicGuaranteeWithExtraInvariants::kExceptionSentinel) {
+          return testing::AssertionSuccess();
+        }
+        return testing::AssertionFailure()
+               << "i should be "
+               << BasicGuaranteeWithExtraInvariants::kExceptionSentinel
+               << ", but is " << w->i;
+      }));
+}
+
+struct FollowsStrongGuarantee : public NonNegative {
   void operator()() { ThrowingValue<> bomb; }
+};
 
-  bool operator==(const FollowsStrongGuarantee& other) const {
-    return i == other.i;
-  }
+TEST(ExceptionCheckTest, StrongGuarantee) {
+  DefaultFactory<FollowsStrongGuarantee> factory;
+  EXPECT_TRUE(TestExceptionSafety(factory, CallOperator{}));
+  EXPECT_TRUE(
+      TestExceptionSafety(factory, CallOperator{}, StrongGuarantee(factory)));
+}
 
-  friend testing::AssertionResult AbslCheckInvariants(
-      const FollowsStrongGuarantee& g) {
-    if (g.i >= 0) return testing::AssertionSuccess();
-    return testing::AssertionFailure()
-           << "i should be non-negative but is " << g.i;
+struct HasReset : public NonNegative {
+  void operator()() {
+    i = -1;
+    ThrowingValue<> bomb;
+    i = 1;
   }
 
-  int i = 0;
+  void reset() { i = 0; }
+
+  friend bool AbslCheckInvariants(HasReset* h) {
+    h->reset();
+    return h->i == 0;
+  }
 };
 
-TEST(ExceptionCheckTest, StrongGuarantee) {
-  FollowsStrongGuarantee g;
-  EXPECT_TRUE(TestExceptionSafety(&g, CallOperator{}));
-  EXPECT_TRUE(TestExceptionSafety(&g, CallOperator{}, StrongGuarantee(g)));
+TEST(ExceptionCheckTest, ModifyingChecker) {
+  {
+    DefaultFactory<FollowsBasicGuarantee> factory;
+    EXPECT_FALSE(TestExceptionSafety(
+        factory, CallOperator{},
+        [](FollowsBasicGuarantee* g) {
+          g->i = 1000;
+          return true;
+        },
+        [](FollowsBasicGuarantee* g) { return g->i == 1000; }));
+  }
+  {
+    DefaultFactory<FollowsStrongGuarantee> factory;
+    EXPECT_TRUE(TestExceptionSafety(factory, CallOperator{},
+                                    [](FollowsStrongGuarantee* g) {
+                                      ++g->i;
+                                      return true;
+                                    },
+                                    StrongGuarantee(factory)));
+  }
+  {
+    DefaultFactory<HasReset> factory;
+    EXPECT_TRUE(TestExceptionSafety(factory, CallOperator{}));
+  }
 }
 
-struct NonCopyable {
+struct NonCopyable : public NonNegative {
   NonCopyable(const NonCopyable&) = delete;
-  explicit NonCopyable(int ii) : i(ii) {}
+  NonCopyable() : NonNegative{0} {}
 
   void operator()() { ThrowingValue<> bomb; }
-
-  bool operator==(const NonCopyable& other) const { return i == other.i; }
-
-  friend testing::AssertionResult AbslCheckInvariants(const NonCopyable& g) {
-    if (g.i >= 0) return testing::AssertionSuccess();
-    return testing::AssertionFailure()
-           << "i should be non-negative but is " << g.i;
-  }
-
-  int i;
 };
 
 TEST(ExceptionCheckTest, NonCopyable) {
-  NonCopyable g(0);
-  EXPECT_TRUE(TestExceptionSafety(&g, CallOperator{}));
-  EXPECT_TRUE(TestExceptionSafety(
-      &g, CallOperator{},
-      PointeeStrongGuarantee(absl::make_unique<NonCopyable>(g.i))));
+  DefaultFactory<NonCopyable> factory;
+  EXPECT_TRUE(TestExceptionSafety(factory, CallOperator{}));
+  EXPECT_TRUE(
+      TestExceptionSafety(factory, CallOperator{}, StrongGuarantee(factory)));
 }
 
-struct NonEqualityComparable {
+struct NonEqualityComparable : public NonNegative {
   void operator()() { ThrowingValue<> bomb; }
 
   void ModifyOnThrow() {
@@ -550,71 +544,61 @@ struct NonEqualityComparable {
     static_cast<void>(bomb);
     --i;
   }
-
-  friend testing::AssertionResult AbslCheckInvariants(
-      const NonEqualityComparable& g) {
-    if (g.i >= 0) return testing::AssertionSuccess();
-    return testing::AssertionFailure()
-           << "i should be non-negative but is " << g.i;
-  }
-
-  int i = 0;
 };
 
 TEST(ExceptionCheckTest, NonEqualityComparable) {
-  NonEqualityComparable g;
+  DefaultFactory<NonEqualityComparable> factory;
   auto comp = [](const NonEqualityComparable& a,
                  const NonEqualityComparable& b) { return a.i == b.i; };
-  EXPECT_TRUE(TestExceptionSafety(&g, CallOperator{}));
-  EXPECT_TRUE(
-      TestExceptionSafety(&g, CallOperator{}, absl::StrongGuarantee(g, comp)));
+  EXPECT_TRUE(TestExceptionSafety(factory, CallOperator{}));
+  EXPECT_TRUE(TestExceptionSafety(factory, CallOperator{},
+                                  absl::StrongGuarantee(factory, comp)));
   EXPECT_FALSE(TestExceptionSafety(
-      &g, [&](NonEqualityComparable* n) { n->ModifyOnThrow(); },
-      absl::StrongGuarantee(g, comp)));
+      factory, [&](NonEqualityComparable* n) { n->ModifyOnThrow(); },
+      absl::StrongGuarantee(factory, comp)));
 }
 
 template <typename T>
-struct InstructionCounter {
+struct ExhaustivenessTester {
   void operator()() {
-    ++counter;
+    successes |= 1;
     T b1;
     static_cast<void>(b1);
-    ++counter;
+    successes |= (1 << 1);
     T b2;
     static_cast<void>(b2);
-    ++counter;
+    successes |= (1 << 2);
     T b3;
     static_cast<void>(b3);
-    ++counter;
+    successes |= (1 << 3);
   }
 
-  bool operator==(const InstructionCounter<ThrowingValue<>>&) const {
+  bool operator==(const ExhaustivenessTester<ThrowingValue<>>&) const {
     return true;
   }
 
-  friend testing::AssertionResult AbslCheckInvariants(
-      const InstructionCounter&) {
+  friend testing::AssertionResult AbslCheckInvariants(ExhaustivenessTester*) {
     return testing::AssertionSuccess();
   }
 
-  static int counter;
+  static unsigned char successes;
 };
 template <typename T>
-int InstructionCounter<T>::counter = 0;
+unsigned char ExhaustivenessTester<T>::successes = 0;
 
 TEST(ExceptionCheckTest, Exhaustiveness) {
-  InstructionCounter<int> int_factory;
-  EXPECT_TRUE(TestExceptionSafety(&int_factory, CallOperator{}));
-  EXPECT_EQ(InstructionCounter<int>::counter, 4);
+  DefaultFactory<ExhaustivenessTester<int>> int_factory;
+  EXPECT_TRUE(TestExceptionSafety(int_factory, CallOperator{}));
+  EXPECT_EQ(ExhaustivenessTester<int>::successes, 0xF);
 
-  InstructionCounter<ThrowingValue<>> bomb_factory;
-  EXPECT_TRUE(TestExceptionSafety(&bomb_factory, CallOperator{}));
-  EXPECT_EQ(InstructionCounter<ThrowingValue<>>::counter, 10);
+  DefaultFactory<ExhaustivenessTester<ThrowingValue<>>> bomb_factory;
+  EXPECT_TRUE(TestExceptionSafety(bomb_factory, CallOperator{}));
+  EXPECT_EQ(ExhaustivenessTester<ThrowingValue<>>::successes, 0xF);
 
-  InstructionCounter<ThrowingValue<>>::counter = 0;
-  EXPECT_TRUE(TestExceptionSafety(&bomb_factory, CallOperator{},
+  ExhaustivenessTester<ThrowingValue<>>::successes = 0;
+  EXPECT_TRUE(TestExceptionSafety(bomb_factory, CallOperator{},
                                   StrongGuarantee(bomb_factory)));
-  EXPECT_EQ(InstructionCounter<ThrowingValue<>>::counter, 10);
+  EXPECT_EQ(ExhaustivenessTester<ThrowingValue<>>::successes, 0xF);
 }
 
 struct LeaksIfCtorThrows : private exceptions_internal::TrackedObject {
diff --git a/absl/base/internal/exception_safety_testing.h b/absl/base/internal/exception_safety_testing.h
index a0a70d91d2d9..05bcd0ab59d6 100644
--- a/absl/base/internal/exception_safety_testing.h
+++ b/absl/base/internal/exception_safety_testing.h
@@ -18,6 +18,7 @@
 #include "absl/meta/type_traits.h"
 #include "absl/strings/string_view.h"
 #include "absl/strings/substitute.h"
+#include "absl/types/optional.h"
 
 namespace absl {
 struct AllocInspector;
@@ -97,19 +98,50 @@ class TrackedObject {
   friend struct ::absl::AllocInspector;
 };
 
-template <typename T, typename... Checkers>
-testing::AssertionResult TestInvariants(const T& t, const TestException& e,
-                                        int count,
-                                        const Checkers&... checkers) {
-  auto out = AbslCheckInvariants(t);
+template <typename Factory>
+using FactoryType = typename absl::result_of_t<Factory()>::element_type;
+
+// Returns an optional with the result of the check if op fails, or an empty
+// optional if op passes
+template <typename Factory, typename Op, typename Checker>
+absl::optional<testing::AssertionResult> TestCheckerAtCountdown(
+    Factory factory, const Op& op, int count, const Checker& check) {
+  exceptions_internal::countdown = count;
+  auto t_ptr = factory();
+  absl::optional<testing::AssertionResult> out;
+  try {
+    op(t_ptr.get());
+  } catch (const exceptions_internal::TestException& e) {
+    out.emplace(check(t_ptr.get()));
+    if (!*out) {
+      *out << " caused by exception thrown by " << e.what();
+    }
+  }
+  return out;
+}
+
+template <typename Factory, typename Op, typename Checker>
+int UpdateOut(Factory factory, const Op& op, int count, const Checker& checker,
+              testing::AssertionResult* out) {
+  if (*out) *out = *TestCheckerAtCountdown(factory, op, count, checker);
+  return 0;
+}
+
+// Returns an optional with the result of the check if op fails, or an empty
+// optional if op passes
+template <typename Factory, typename Op, typename... Checkers>
+absl::optional<testing::AssertionResult> TestAtCountdown(
+    Factory factory, const Op& op, int count, const Checkers&... checkers) {
   // Don't bother with the checkers if the class invariants are already broken.
-  bool dummy[] = {true,
-                  (out && (out = testing::AssertionResult(checkers(t))))...};
-  static_cast<void>(dummy);
+  auto out = TestCheckerAtCountdown(
+      factory, op, count,
+      [](FactoryType<Factory>* t_ptr) { return AbslCheckInvariants(t_ptr); });
+  if (!out.has_value()) return out;
 
-  return out ? out
-             : out << " Caused by exception " << count << "thrown by "
-                   << e.what();
+  // Run each checker, short circuiting after the first failure
+  int dummy[] = {0, (UpdateOut(factory, op, count, checkers, &*out))...};
+  static_cast<void>(dummy);
+  return out;
 }
 
 template <typename T, typename EqualTo>
@@ -118,9 +150,9 @@ class StrongGuaranteeTester {
   explicit StrongGuaranteeTester(std::unique_ptr<T> t_ptr, EqualTo eq) noexcept
       : val_(std::move(t_ptr)), eq_(eq) {}
 
-  testing::AssertionResult operator()(const T& other) const {
-    return eq_(*val_, other) ? testing::AssertionSuccess()
-                             : testing::AssertionFailure() << "State changed";
+  testing::AssertionResult operator()(T* other) const {
+    return eq_(*val_, *other) ? testing::AssertionSuccess()
+                              : testing::AssertionFailure() << "State changed";
   }
 
  private:
@@ -673,58 +705,52 @@ T TestThrowingCtor(Args&&... args) {
 }
 
 // Tests that performing operation Op on a T follows exception safety
-// guarantees.  By default only tests the basic guarantee.
+// guarantees.  By default only tests the basic guarantee. There must be a
+// function, AbslCheckInvariants(T*) which returns
+// anything convertible to bool and which makes sure the invariants of the type
+// are upheld.  This is called before any of the checkers.
 //
 // Parameters:
-//   * T: the type under test.
+//   * TFactory: operator() returns a unique_ptr to the type under test (T).  It
+//   should always return pointers to values which compare equal.
 //   * FunctionFromTPtrToVoid: A functor exercising the function under test.  It
 //   should take a T* and return void.
-//   * Checkers: Any number of functions taking a const T& and returning
+//   * Checkers: Any number of functions taking a T* and returning
 //   anything contextually convertible to bool.  If a testing::AssertionResult
 //   is used then the error message is kept.  These test invariants related to
 //   the operation. To test the strong guarantee, pass
-//   absl::StrongGuarantee(...) as one of these arguments if T has operator==.
-//   Some types for which the strong guarantee makes sense don't have operator==
-//   (eg std::any).  A function capturing *t or a T equal to it, taking a const
-//   T&, and returning contextually-convertible-to-bool may be passed instead.
-template <typename T, typename FunctionFromTPtrToVoid, typename... Checkers>
-testing::AssertionResult TestExceptionSafety(T* t, FunctionFromTPtrToVoid&& op,
+//   absl::StrongGuarantee(factory).  A checker may freely modify the passed-in
+//   T, for example to make sure the T can be set to a known state.
+template <typename TFactory, typename FunctionFromTPtrToVoid,
+          typename... Checkers>
+testing::AssertionResult TestExceptionSafety(TFactory factory,
+                                             FunctionFromTPtrToVoid&& op,
                                              const Checkers&... checkers) {
-  auto out = testing::AssertionSuccess();
   for (int countdown = 0;; ++countdown) {
-    exceptions_internal::countdown = countdown;
-    try {
-      op(t);
-      break;
-    } catch (const exceptions_internal::TestException& e) {
-      out = exceptions_internal::TestInvariants(*t, e, countdown, checkers...);
-      if (!out) return out;
+    auto out = exceptions_internal::TestAtCountdown(factory, op, countdown,
+                                                    checkers...);
+    if (!out.has_value()) {
+      UnsetCountdown();
+      return testing::AssertionSuccess();
     }
+    if (!*out) return *out;
   }
-  UnsetCountdown();
-  return out;
 }
 
-// Returns a functor to test for the strong exception-safety guarantee.  If T is
-// copyable, use the const T& overload, otherwise pass a unique_ptr<T>.
-// Equality comparisons are made against the T provided and default to using
-// operator==.  See the documentation for TestExceptionSafety if T doesn't have
-// operator== but the strong guarantee still makes sense for it.
+// Returns a functor to test for the strong exception-safety guarantee.
+// Equality comparisons are made against the T provided by the factory and
+// default to using operator==.
 //
 // Parameters:
-//   * T: The type under test.
-template <typename T, typename EqualTo = std::equal_to<T>>
-exceptions_internal::StrongGuaranteeTester<T, EqualTo> StrongGuarantee(
-    const T& t, EqualTo eq = EqualTo()) {
-  return exceptions_internal::StrongGuaranteeTester<T, EqualTo>(
-      absl::make_unique<T>(t), eq);
-}
-
-template <typename T, typename EqualTo = std::equal_to<T>>
-exceptions_internal::StrongGuaranteeTester<T, EqualTo> PointeeStrongGuarantee(
-    std::unique_ptr<T> t_ptr, EqualTo eq = EqualTo()) {
-  return exceptions_internal::StrongGuaranteeTester<T, EqualTo>(
-      std::move(t_ptr), eq);
+//   * TFactory: operator() returns a unique_ptr to the type under test.  It
+//   should always return pointers to values which compare equal.
+template <typename TFactory, typename EqualTo = std::equal_to<
+                                 exceptions_internal::FactoryType<TFactory>>>
+exceptions_internal::StrongGuaranteeTester<
+    exceptions_internal::FactoryType<TFactory>, EqualTo>
+StrongGuarantee(TFactory factory, EqualTo eq = EqualTo()) {
+  return exceptions_internal::StrongGuaranteeTester<
+      exceptions_internal::FactoryType<TFactory>, EqualTo>(factory(), eq);
 }
 
 }  // namespace absl