about summary refs log tree commit diff
path: root/src/libutil/thread-pool.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/libutil/thread-pool.cc')
-rw-r--r--src/libutil/thread-pool.cc130
1 files changed, 75 insertions, 55 deletions
diff --git a/src/libutil/thread-pool.cc b/src/libutil/thread-pool.cc
index 743038b588a7..32363ecf0098 100644
--- a/src/libutil/thread-pool.cc
+++ b/src/libutil/thread-pool.cc
@@ -1,79 +1,99 @@
 #include "thread-pool.hh"
+#include "affinity.hh"
 
 namespace nix {
 
-ThreadPool::ThreadPool(size_t _nrThreads)
-    : nrThreads(_nrThreads)
+ThreadPool::ThreadPool(size_t _maxThreads)
+    : maxThreads(_maxThreads)
 {
-    if (!nrThreads) {
-        nrThreads = std::thread::hardware_concurrency();
-        if (!nrThreads) nrThreads = 1;
+    restoreAffinity(); // FIXME
+
+    if (!maxThreads) {
+        maxThreads = std::thread::hardware_concurrency();
+        if (!maxThreads) maxThreads = 1;
     }
+
+    debug(format("starting pool of %d threads") % maxThreads);
+}
+
+ThreadPool::~ThreadPool()
+{
+    std::vector<std::thread> workers;
+    {
+        auto state(state_.lock());
+        state->quit = true;
+        std::swap(workers, state->workers);
+    }
+
+    debug(format("reaping %d worker threads") % workers.size());
+
+    work.notify_all();
+
+    for (auto & thr : workers)
+        thr.join();
 }
 
 void ThreadPool::enqueue(const work_t & t)
 {
-    auto state_(state.lock());
-    state_->left.push(t);
-    wakeup.notify_one();
+    auto state(state_.lock());
+    assert(!state->quit);
+    state->left.push(t);
+    if (state->left.size() > state->workers.size() && state->workers.size() < maxThreads)
+        state->workers.emplace_back(&ThreadPool::workerEntry, this);
+    work.notify_one();
 }
 
 void ThreadPool::process()
 {
-    printMsg(lvlDebug, format("starting pool of %d threads") % nrThreads);
-
-    std::vector<std::thread> workers;
+    while (true) {
+        auto state(state_.lock());
+        if (state->exception)
+            std::rethrow_exception(state->exception);
+        if (state->left.empty() && !state->pending) break;
+        state.wait(done);
+    }
+}
 
-    for (size_t n = 0; n < nrThreads; n++)
-        workers.push_back(std::thread([&]() {
-            bool first = true;
+void ThreadPool::workerEntry()
+{
+    bool didWork = false;
 
+    while (true) {
+        work_t w;
+        {
+            auto state(state_.lock());
             while (true) {
-                work_t work;
-                {
-                    auto state_(state.lock());
-                    if (state_->exception) return;
-                    if (!first) {
-                        assert(state_->pending);
-                        state_->pending--;
-                    }
-                    first = false;
-                    while (state_->left.empty()) {
-                        if (!state_->pending) {
-                            wakeup.notify_all();
-                            return;
-                        }
-                        if (state_->exception) return;
-                        state_.wait(wakeup);
-                    }
-                    work = state_->left.front();
-                    state_->left.pop();
-                    state_->pending++;
-                }
-
-                try {
-                    work();
-                } catch (std::exception & e) {
-                    auto state_(state.lock());
-                    if (state_->exception) {
-                        if (!dynamic_cast<Interrupted*>(&e))
-                            printMsg(lvlError, format("error: %s") % e.what());
-                    } else {
-                        state_->exception = std::current_exception();
-                        wakeup.notify_all();
-                    }
+                if (state->quit || state->exception) return;
+                if (didWork) {
+                    assert(state->pending);
+                    state->pending--;
+                    didWork = false;
                 }
+                if (!state->left.empty()) break;
+                if (!state->pending)
+                    done.notify_all();
+                state.wait(work);
             }
+            w = state->left.front();
+            state->left.pop();
+            state->pending++;
+        }
 
-        }));
-
-    for (auto & thr : workers)
-        thr.join();
+        try {
+            w();
+        } catch (std::exception & e) {
+            auto state(state_.lock());
+            if (state->exception) {
+                if (!dynamic_cast<Interrupted*>(&e))
+                    printMsg(lvlError, format("error: %s") % e.what());
+            } else {
+                state->exception = std::current_exception();
+                work.notify_all();
+                done.notify_all();
+            }
+        }
 
-    {
-        auto state_(state.lock());
-        if (state_->exception)
-            std::rethrow_exception(state_->exception);
+        didWork = true;
     }
 }