about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEelco Dolstra <edolstra@gmail.com>2019-09-03T10·51+0200
committerEelco Dolstra <edolstra@gmail.com>2019-09-03T11·45+0200
commit7348653ff4fc4e9b2dc24943aabdb57179b1c75a (patch)
treec730713ccb743b185578ee1b37462a7fc68b10b4
parent8c4ea7a4516c517a0dd37b446bf5c1a6b157064c (diff)
Ensure that Callback is called only once
Also, make Callback movable but uncopyable.
-rw-r--r--src/libstore/binary-cache-store.cc8
-rw-r--r--src/libstore/download.cc6
-rw-r--r--src/libstore/http-binary-cache-store.cc12
-rw-r--r--src/libstore/store-api.cc8
-rw-r--r--src/libutil/util.hh19
5 files changed, 36 insertions, 17 deletions
diff --git a/src/libstore/binary-cache-store.cc b/src/libstore/binary-cache-store.cc
index 4527ee6ba660..e56be625de47 100644
--- a/src/libstore/binary-cache-store.cc
+++ b/src/libstore/binary-cache-store.cc
@@ -249,21 +249,23 @@ void BinaryCacheStore::queryPathInfoUncached(const Path & storePath,
 
     auto narInfoFile = narInfoFileFor(storePath);
 
+    auto callbackPtr = std::make_shared<decltype(callback)>(std::move(callback));
+
     getFile(narInfoFile,
         {[=](std::future<std::shared_ptr<std::string>> fut) {
             try {
                 auto data = fut.get();
 
-                if (!data) return callback(nullptr);
+                if (!data) return (*callbackPtr)(nullptr);
 
                 stats.narInfoRead++;
 
-                callback((std::shared_ptr<ValidPathInfo>)
+                (*callbackPtr)((std::shared_ptr<ValidPathInfo>)
                     std::make_shared<NarInfo>(*this, *data, narInfoFile));
 
                 (void) act; // force Activity into this lambda to ensure it stays alive
             } catch (...) {
-                callback.rethrow();
+                callbackPtr->rethrow();
             }
         }});
 }
diff --git a/src/libstore/download.cc b/src/libstore/download.cc
index a7d059465b0e..cdf56e09d69a 100644
--- a/src/libstore/download.cc
+++ b/src/libstore/download.cc
@@ -77,13 +77,13 @@ struct CurlDownloader : public Downloader
 
         DownloadItem(CurlDownloader & downloader,
             const DownloadRequest & request,
-            Callback<DownloadResult> callback)
+            Callback<DownloadResult> && callback)
             : downloader(downloader)
             , request(request)
             , act(*logger, lvlTalkative, actDownload,
                 fmt(request.data ? "uploading '%s'" : "downloading '%s'", request.uri),
                 {request.uri}, request.parentAct)
-            , callback(callback)
+            , callback(std::move(callback))
             , finalSink([this](const unsigned char * data, size_t len) {
                 if (this->request.dataCallback) {
                     writtenToSink += len;
@@ -665,7 +665,7 @@ struct CurlDownloader : public Downloader
             return;
         }
 
-        enqueueItem(std::make_shared<DownloadItem>(*this, request, callback));
+        enqueueItem(std::make_shared<DownloadItem>(*this, request, std::move(callback)));
     }
 };
 
diff --git a/src/libstore/http-binary-cache-store.cc b/src/libstore/http-binary-cache-store.cc
index df2fb93320fc..e631d95f0fd1 100644
--- a/src/libstore/http-binary-cache-store.cc
+++ b/src/libstore/http-binary-cache-store.cc
@@ -137,17 +137,19 @@ protected:
 
         auto request(makeRequest(path));
 
+        auto callbackPtr = std::make_shared<decltype(callback)>(std::move(callback));
+
         getDownloader()->enqueueDownload(request,
-            {[callback, this](std::future<DownloadResult> result) {
+            {[callbackPtr, this](std::future<DownloadResult> result) {
                 try {
-                    callback(result.get().data);
+                    (*callbackPtr)(result.get().data);
                 } catch (DownloadError & e) {
                     if (e.error == Downloader::NotFound || e.error == Downloader::Forbidden)
-                        return callback(std::shared_ptr<std::string>());
+                        return (*callbackPtr)(std::shared_ptr<std::string>());
                     maybeDisable();
-                    callback.rethrow();
+                    callbackPtr->rethrow();
                 } catch (...) {
-                    callback.rethrow();
+                    callbackPtr->rethrow();
                 }
             }});
     }
diff --git a/src/libstore/store-api.cc b/src/libstore/store-api.cc
index 3bb9db0b723b..88a5b2f448bd 100644
--- a/src/libstore/store-api.cc
+++ b/src/libstore/store-api.cc
@@ -365,8 +365,10 @@ void Store::queryPathInfo(const Path & storePath,
 
     } catch (...) { return callback.rethrow(); }
 
+    auto callbackPtr = std::make_shared<decltype(callback)>(std::move(callback));
+
     queryPathInfoUncached(storePath,
-        {[this, storePath, hashPart, callback](std::future<std::shared_ptr<ValidPathInfo>> fut) {
+        {[this, storePath, hashPart, callbackPtr](std::future<std::shared_ptr<ValidPathInfo>> fut) {
 
             try {
                 auto info = fut.get();
@@ -386,8 +388,8 @@ void Store::queryPathInfo(const Path & storePath,
                     throw InvalidPath("path '%s' is not valid", storePath);
                 }
 
-                callback(ref<ValidPathInfo>(info));
-            } catch (...) { callback.rethrow(); }
+                (*callbackPtr)(ref<ValidPathInfo>(info));
+            } catch (...) { callbackPtr->rethrow(); }
         }});
 }
 
diff --git a/src/libutil/util.hh b/src/libutil/util.hh
index b538a0b41ce8..686e81d3f893 100644
--- a/src/libutil/util.hh
+++ b/src/libutil/util.hh
@@ -445,21 +445,34 @@ string get(const T & map, const string & key, const string & def = "")
    type T or an exception. (We abuse std::future<T> to pass the value or
    exception.) */
 template<typename T>
-struct Callback
+class Callback
 {
     std::function<void(std::future<T>)> fun;
+    std::atomic_flag done = ATOMIC_FLAG_INIT;
+
+public:
 
     Callback(std::function<void(std::future<T>)> fun) : fun(fun) { }
 
-    void operator()(T && t) const
+    Callback(Callback && callback) : fun(std::move(callback.fun))
+    {
+        auto prev = callback.done.test_and_set();
+        if (prev) done.test_and_set();
+    }
+
+    void operator()(T && t)
     {
+        auto prev = done.test_and_set();
+        assert(!prev);
         std::promise<T> promise;
         promise.set_value(std::move(t));
         fun(promise.get_future());
     }
 
-    void rethrow(const std::exception_ptr & exc = std::current_exception()) const
+    void rethrow(const std::exception_ptr & exc = std::current_exception())
     {
+        auto prev = done.test_and_set();
+        assert(!prev);
         std::promise<T> promise;
         promise.set_exception(exc);
         fun(promise.get_future());