about summary refs log tree commit diff
path: root/src/libstore/download.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstore/download.cc')
-rw-r--r--src/libstore/download.cc192
1 files changed, 151 insertions, 41 deletions
diff --git a/src/libstore/download.cc b/src/libstore/download.cc
index 5ab625f42288..72a08ef0089c 100644
--- a/src/libstore/download.cc
+++ b/src/libstore/download.cc
@@ -7,6 +7,7 @@
 #include "s3.hh"
 #include "compression.hh"
 #include "pathlocks.hh"
+#include "finally.hh"
 
 #ifdef ENABLE_S3
 #include <aws/core/client/ClientConfiguration.h>
@@ -29,12 +30,25 @@ using namespace std::string_literals;
 
 namespace nix {
 
-double getTime()
+struct DownloadSettings : Config
 {
-    struct timeval tv;
-    gettimeofday(&tv, 0);
-    return tv.tv_sec + (tv.tv_usec / 1000000.0);
-}
+    Setting<bool> enableHttp2{this, true, "http2",
+        "Whether to enable HTTP/2 support."};
+
+    Setting<std::string> userAgentSuffix{this, "", "user-agent-suffix",
+        "String appended to the user agent in HTTP requests."};
+
+    Setting<size_t> httpConnections{this, 25, "http-connections",
+        "Number of parallel HTTP connections.",
+        {"binary-caches-parallel-connections"}};
+
+    Setting<unsigned long> connectTimeout{this, 0, "connect-timeout",
+        "Timeout for connecting to servers during downloads. 0 means use curl's builtin default."};
+};
+
+static DownloadSettings downloadSettings;
+
+static GlobalConfig::Register r1(&downloadSettings);
 
 std::string resolveUri(const std::string & uri)
 {
@@ -61,8 +75,6 @@ struct CurlDownloader : public Downloader
     std::random_device rd;
     std::mt19937 mt19937;
 
-    bool enableHttp2;
-
     struct DownloadItem : public std::enable_shared_from_this<DownloadItem>
     {
         CurlDownloader & downloader;
@@ -70,8 +82,7 @@ struct CurlDownloader : public Downloader
         DownloadResult result;
         Activity act;
         bool done = false; // whether either the success or failure function has been called
-        std::function<void(const DownloadResult &)> success;
-        std::function<void(std::exception_ptr exc)> failure;
+        Callback<DownloadResult> callback;
         CURL * req = 0;
         bool active = false; // whether the handle has been added to the multi object
         std::string status;
@@ -86,10 +97,13 @@ struct CurlDownloader : public Downloader
 
         std::string encoding;
 
-        DownloadItem(CurlDownloader & downloader, const DownloadRequest & request)
+        DownloadItem(CurlDownloader & downloader,
+            const DownloadRequest & request,
+            Callback<DownloadResult> callback)
             : downloader(downloader)
             , request(request)
             , act(*logger, lvlTalkative, actDownload, fmt("downloading '%s'", request.uri), {request.uri}, request.parentAct)
+            , callback(callback)
         {
             if (!request.expectedETag.empty())
                 requestHeaders = curl_slist_append(requestHeaders, ("If-None-Match: " + request.expectedETag).c_str());
@@ -118,13 +132,16 @@ struct CurlDownloader : public Downloader
         {
             assert(!done);
             done = true;
-            callFailure(failure, std::make_exception_ptr(e));
+            callback.rethrow(std::make_exception_ptr(e));
         }
 
         size_t writeCallback(void * contents, size_t size, size_t nmemb)
         {
             size_t realSize = size * nmemb;
-            result.data->append((char *) contents, realSize);
+            if (request.dataCallback)
+                request.dataCallback((char *) contents, realSize);
+            else
+                result.data->append((char *) contents, realSize);
             return realSize;
         }
 
@@ -173,7 +190,11 @@ struct CurlDownloader : public Downloader
 
         int progressCallback(double dltotal, double dlnow)
         {
-            act.progress(dlnow, dltotal);
+            try {
+              act.progress(dlnow, dltotal);
+            } catch (nix::Interrupted &) {
+              assert(_isInterrupted);
+            }
             return _isInterrupted;
         }
 
@@ -195,6 +216,7 @@ struct CurlDownloader : public Downloader
             if (readOffset == request.data->length())
                 return 0;
             auto count = std::min(size * nitems, request.data->length() - readOffset);
+            assert(count);
             memcpy(buffer, request.data->data() + readOffset, count);
             readOffset += count;
             return count;
@@ -223,12 +245,12 @@ struct CurlDownloader : public Downloader
             curl_easy_setopt(req, CURLOPT_NOSIGNAL, 1);
             curl_easy_setopt(req, CURLOPT_USERAGENT,
                 ("curl/" LIBCURL_VERSION " Nix/" + nixVersion +
-                    (settings.userAgentSuffix != "" ? " " + settings.userAgentSuffix.get() : "")).c_str());
+                    (downloadSettings.userAgentSuffix != "" ? " " + downloadSettings.userAgentSuffix.get() : "")).c_str());
             #if LIBCURL_VERSION_NUM >= 0x072b00
             curl_easy_setopt(req, CURLOPT_PIPEWAIT, 1);
             #endif
             #if LIBCURL_VERSION_NUM >= 0x072f00
-            if (downloader.enableHttp2)
+            if (downloadSettings.enableHttp2)
                 curl_easy_setopt(req, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2TLS);
             #endif
             curl_easy_setopt(req, CURLOPT_WRITEFUNCTION, DownloadItem::writeCallbackWrapper);
@@ -260,7 +282,7 @@ struct CurlDownloader : public Downloader
                 curl_easy_setopt(req, CURLOPT_SSL_VERIFYHOST, 0);
             }
 
-            curl_easy_setopt(req, CURLOPT_CONNECTTIMEOUT, settings.connectTimeout.get());
+            curl_easy_setopt(req, CURLOPT_CONNECTTIMEOUT, downloadSettings.connectTimeout.get());
 
             curl_easy_setopt(req, CURLOPT_LOW_SPEED_LIMIT, 1L);
             curl_easy_setopt(req, CURLOPT_LOW_SPEED_TIME, lowSpeedTimeout);
@@ -300,11 +322,11 @@ struct CurlDownloader : public Downloader
                 try {
                     if (request.decompress)
                         result.data = decodeContent(encoding, ref<std::string>(result.data));
-                    callSuccess(success, failure, const_cast<const DownloadResult &>(result));
                     act.progress(result.data->size(), result.data->size());
+                    callback(std::move(result));
                 } catch (...) {
                     done = true;
-                    callFailure(failure, std::current_exception());
+                    callback.rethrow();
                 }
             } else {
                 // We treat most errors as transient, but won't retry when hopeless
@@ -339,6 +361,7 @@ struct CurlDownloader : public Downloader
                         case CURLE_BAD_FUNCTION_ARGUMENT:
                         case CURLE_INTERFACE_FAILED:
                         case CURLE_UNKNOWN_OPTION:
+                        case CURLE_SSL_CACERT_BADFILE:
                             err = Misc;
                             break;
                         default: // Shut up warnings
@@ -402,11 +425,9 @@ struct CurlDownloader : public Downloader
         #endif
         #if LIBCURL_VERSION_NUM >= 0x071e00 // Max connections requires >= 7.30.0
         curl_multi_setopt(curlm, CURLMOPT_MAX_TOTAL_CONNECTIONS,
-            settings.binaryCachesParallelConnections.get());
+            downloadSettings.httpConnections.get());
         #endif
 
-        enableHttp2 = settings.enableHttp2;
-
         wakeupPipe.create();
         fcntl(wakeupPipe.readSide.get(), F_SETFL, O_NONBLOCK);
 
@@ -555,13 +576,12 @@ struct CurlDownloader : public Downloader
     }
 
     void enqueueDownload(const DownloadRequest & request,
-        std::function<void(const DownloadResult &)> success,
-        std::function<void(std::exception_ptr exc)> failure) override
+        Callback<DownloadResult> callback) override
     {
         /* Ugly hack to support s3:// URIs. */
         if (hasPrefix(request.uri, "s3://")) {
             // FIXME: do this on a worker thread
-            sync2async<DownloadResult>(success, failure, [&]() -> DownloadResult {
+            try {
 #ifdef ENABLE_S3
                 S3Helper s3Helper("", Aws::Region::US_EAST_1); // FIXME: make configurable
                 auto slash = request.uri.find('/', 5);
@@ -575,27 +595,22 @@ struct CurlDownloader : public Downloader
                 if (!s3Res.data)
                     throw DownloadError(NotFound, fmt("S3 object '%s' does not exist", request.uri));
                 res.data = s3Res.data;
-                return res;
+                callback(std::move(res));
 #else
                 throw nix::Error("cannot download '%s' because Nix is not built with S3 support", request.uri);
 #endif
-            });
+            } catch (...) { callback.rethrow(); }
             return;
         }
 
-        auto item = std::make_shared<DownloadItem>(*this, request);
-        item->success = success;
-        item->failure = failure;
-        enqueueItem(item);
+        enqueueItem(std::make_shared<DownloadItem>(*this, request, callback));
     }
 };
 
 ref<Downloader> getDownloader()
 {
-    static std::shared_ptr<Downloader> downloader;
-    static std::once_flag downloaderCreated;
-    std::call_once(downloaderCreated, [&]() { downloader = makeDownloader(); });
-    return ref<Downloader>(downloader);
+    static ref<Downloader> downloader = makeDownloader();
+    return downloader;
 }
 
 ref<Downloader> makeDownloader()
@@ -607,8 +622,13 @@ std::future<DownloadResult> Downloader::enqueueDownload(const DownloadRequest &
 {
     auto promise = std::make_shared<std::promise<DownloadResult>>();
     enqueueDownload(request,
-        [promise](const DownloadResult & result) { promise->set_value(result); },
-        [promise](std::exception_ptr exc) { promise->set_exception(exc); });
+        {[promise](std::future<DownloadResult> fut) {
+            try {
+                promise->set_value(fut.get());
+            } catch (...) {
+                promise->set_exception(std::current_exception());
+            }
+        }});
     return promise->get_future();
 }
 
@@ -617,7 +637,93 @@ DownloadResult Downloader::download(const DownloadRequest & request)
     return enqueueDownload(request).get();
 }
 
-Path Downloader::downloadCached(ref<Store> store, const string & url_, bool unpack, string name, const Hash & expectedHash, string * effectiveUrl)
+void Downloader::download(DownloadRequest && request, Sink & sink)
+{
+    /* Note: we can't call 'sink' via request.dataCallback, because
+       that would cause the sink to execute on the downloader
+       thread. If 'sink' is a coroutine, this will fail. Also, if the
+       sink is expensive (e.g. one that does decompression and writing
+       to the Nix store), it would stall the download thread too much.
+       Therefore we use a buffer to communicate data between the
+       download thread and the calling thread. */
+
+    struct State {
+        bool quit = false;
+        std::exception_ptr exc;
+        std::string data;
+        std::condition_variable avail, request;
+    };
+
+    auto _state = std::make_shared<Sync<State>>();
+
+    /* In case of an exception, wake up the download thread. FIXME:
+       abort the download request. */
+    Finally finally([&]() {
+        auto state(_state->lock());
+        state->quit = true;
+        state->request.notify_one();
+    });
+
+    request.dataCallback = [_state](char * buf, size_t len) {
+
+        auto state(_state->lock());
+
+        if (state->quit) return;
+
+        /* If the buffer is full, then go to sleep until the calling
+           thread wakes us up (i.e. when it has removed data from the
+           buffer). Note: this does stall the download thread. */
+        while (state->data.size() > 1024 * 1024) {
+            if (state->quit) return;
+            debug("download buffer is full; going to sleep");
+            state.wait(state->request);
+        }
+
+        /* Append data to the buffer and wake up the calling
+           thread. */
+        state->data.append(buf, len);
+        state->avail.notify_one();
+    };
+
+    enqueueDownload(request,
+        {[_state](std::future<DownloadResult> fut) {
+            auto state(_state->lock());
+            state->quit = true;
+            try {
+                fut.get();
+            } catch (...) {
+                state->exc = std::current_exception();
+            }
+            state->avail.notify_one();
+            state->request.notify_one();
+        }});
+
+    auto state(_state->lock());
+
+    while (true) {
+        checkInterrupt();
+
+        if (state->quit) {
+            if (state->exc) std::rethrow_exception(state->exc);
+            break;
+        }
+
+        /* If no data is available, then wait for the download thread
+           to wake us up. */
+        if (state->data.empty())
+            state.wait(state->avail);
+
+        /* If data is available, then flush it to the sink and wake up
+           the download thread if it's blocked on a full buffer. */
+        if (!state->data.empty()) {
+            sink((unsigned char *) state->data.data(), state->data.size());
+            state->data.clear();
+            state->request.notify_one();
+        }
+    }
+}
+
+Path Downloader::downloadCached(ref<Store> store, const string & url_, bool unpack, string name, const Hash & expectedHash, string * effectiveUrl, int ttl)
 {
     auto url = resolveUri(url_);
 
@@ -630,7 +736,7 @@ Path Downloader::downloadCached(ref<Store> store, const string & url_, bool unpa
     if (expectedHash) {
         expectedStorePath = store->makeFixedOutputPath(unpack, expectedHash, name);
         if (store->isValidPath(expectedStorePath))
-            return expectedStorePath;
+            return store->toRealPath(expectedStorePath);
     }
 
     Path cacheDir = getCacheDir() + "/nix/tarballs";
@@ -647,7 +753,6 @@ Path Downloader::downloadCached(ref<Store> store, const string & url_, bool unpa
 
     string expectedETag;
 
-    int ttl = settings.tarballTtl;
     bool skip = false;
 
     if (pathExists(fileLink) && pathExists(dataFile)) {
@@ -724,8 +829,13 @@ Path Downloader::downloadCached(ref<Store> store, const string & url_, bool unpa
         storePath = unpackedStorePath;
     }
 
-    if (expectedStorePath != "" && storePath != expectedStorePath)
-        throw nix::Error("store path mismatch in file downloaded from '%s'", url);
+    if (expectedStorePath != "" && storePath != expectedStorePath) {
+        Hash gotHash = unpack
+            ? hashPath(expectedHash.type, store->toRealPath(storePath)).first
+            : hashFile(expectedHash.type, store->toRealPath(storePath));
+        throw nix::Error("hash mismatch in file downloaded from '%s': got hash '%s' instead of the expected hash '%s'",
+            url, gotHash.to_string(), expectedHash.to_string());
+    }
 
     return store->toRealPath(storePath);
 }