about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/libstore/local.mk2
-rw-r--r--src/libstore/s3-binary-cache-store.cc80
2 files changed, 65 insertions, 17 deletions
diff --git a/src/libstore/local.mk b/src/libstore/local.mk
index a7279aa3939f..3799257f83ff 100644
--- a/src/libstore/local.mk
+++ b/src/libstore/local.mk
@@ -18,7 +18,7 @@ libstore_FILES = sandbox-defaults.sb sandbox-minimal.sb sandbox-network.sb
 $(foreach file,$(libstore_FILES),$(eval $(call install-data-in,$(d)/$(file),$(datadir)/nix/sandbox)))
 
 ifeq ($(ENABLE_S3), 1)
-	libstore_LDFLAGS += -laws-cpp-sdk-s3 -laws-cpp-sdk-core
+	libstore_LDFLAGS += -laws-cpp-sdk-transfer -laws-cpp-sdk-s3 -laws-cpp-sdk-core
 endif
 
 ifeq ($(OS), SunOS)
diff --git a/src/libstore/s3-binary-cache-store.cc b/src/libstore/s3-binary-cache-store.cc
index 23af452094cf..b55ff5c6d5d5 100644
--- a/src/libstore/s3-binary-cache-store.cc
+++ b/src/libstore/s3-binary-cache-store.cc
@@ -17,6 +17,7 @@
 #include <aws/core/client/DefaultRetryStrategy.h>
 #include <aws/core/utils/logging/FormattedLogSystem.h>
 #include <aws/core/utils/logging/LogMacros.h>
+#include <aws/core/utils/threading/Executor.h>
 #include <aws/s3/S3Client.h>
 #include <aws/s3/model/CreateBucketRequest.h>
 #include <aws/s3/model/GetBucketLocationRequest.h>
@@ -24,6 +25,9 @@
 #include <aws/s3/model/HeadObjectRequest.h>
 #include <aws/s3/model/ListObjectsRequest.h>
 #include <aws/s3/model/PutObjectRequest.h>
+#include <aws/transfer/TransferManager.h>
+
+using namespace Aws::Transfer;
 
 namespace nix {
 
@@ -169,6 +173,8 @@ struct S3BinaryCacheStoreImpl : public S3BinaryCacheStore
     const Setting<std::string> narinfoCompression{this, "", "narinfo-compression", "compression method for .narinfo files"};
     const Setting<std::string> lsCompression{this, "", "ls-compression", "compression method for .ls files"};
     const Setting<std::string> logCompression{this, "", "log-compression", "compression method for log/* files"};
+    const Setting<uint64_t> bufferSize{
+        this, 5 * 1024 * 1024, "buffer-size", "size (in bytes) of each part in multi-part uploads. defaults to 5Mb"};
 
     std::string bucketName;
 
@@ -271,34 +277,76 @@ struct S3BinaryCacheStoreImpl : public S3BinaryCacheStore
         const std::string & mimeType,
         const std::string & contentEncoding)
     {
-        auto request =
-            Aws::S3::Model::PutObjectRequest()
-            .WithBucket(bucketName)
-            .WithKey(path);
+        auto stream = std::make_shared<istringstream_nocopy>(data);
 
-        request.SetContentType(mimeType);
+        auto maxThreads = std::thread::hardware_concurrency();
 
-        if (contentEncoding != "")
-            request.SetContentEncoding(contentEncoding);
+        auto executor =
+            std::make_shared<Aws::Utils::Threading::PooledThreadExecutor>(maxThreads);
 
-        auto stream = std::make_shared<istringstream_nocopy>(data);
+        TransferManagerConfiguration transferConfig(executor.get());
 
-        request.SetBody(stream);
+        transferConfig.s3Client = s3Helper.client;
+        transferConfig.bufferSize = bufferSize;
 
-        stats.put++;
-        stats.putBytes += data.size();
+        if (contentEncoding != "")
+            transferConfig.createMultipartUploadTemplate.SetContentEncoding(
+                contentEncoding);
+
+        transferConfig.uploadProgressCallback =
+            [&](const TransferManager *transferManager,
+                const std::shared_ptr<const TransferHandle>
+                    &transferHandle)
+            {
+              checkInterrupt();
+              printTalkative("upload progress ('%s'): '%d' of '%d' bytes",
+                             path,
+                             transferHandle->GetBytesTransferred(),
+                             transferHandle->GetBytesTotalSize());
+            };
+
+        transferConfig.transferStatusUpdatedCallback =
+            [&](const TransferManager *,
+                const std::shared_ptr<const TransferHandle>
+                    &transferHandle) {
+              switch (transferHandle->GetStatus()) {
+                  case TransferStatus::COMPLETED:
+                      printTalkative("upload of '%s' completed", path);
+                      stats.put++;
+                      stats.putBytes += data.size();
+                      break;
+                  case TransferStatus::IN_PROGRESS:
+                      break;
+                  case TransferStatus::FAILED:
+                      throw Error("AWS error: failed to upload 's3://%s/%s'",
+                                  bucketName, path);
+                      break;
+                  default:
+                      throw Error("AWS error: transfer status of 's3://%s/%s' "
+                                  "in unexpected state",
+                                  bucketName, path);
+              };
+            };
+
+        std::shared_ptr<TransferManager> transferManager =
+            TransferManager::Create(transferConfig);
 
         auto now1 = std::chrono::steady_clock::now();
 
-        auto result = checkAws(format("AWS error uploading '%s'") % path,
-            s3Helper.client->PutObject(request));
+        std::shared_ptr<TransferHandle> transferHandle =
+            transferManager->UploadFile(stream, bucketName, path, mimeType,
+                                        Aws::Map<Aws::String, Aws::String>());
+
+        transferHandle->WaitUntilFinished();
 
         auto now2 = std::chrono::steady_clock::now();
 
-        auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now2 - now1).count();
+        auto duration =
+            std::chrono::duration_cast<std::chrono::milliseconds>(now2 - now1)
+                .count();
 
-        printInfo(format("uploaded 's3://%1%/%2%' (%3% bytes) in %4% ms")
-            % bucketName % path % data.size() % duration);
+        printInfo(format("uploaded 's3://%1%/%2%' (%3% bytes) in %4% ms") %
+                  bucketName % path % data.size() % duration);
 
         stats.putTimeMs += duration;
     }