about summary refs log tree commit diff
path: root/src/libutil/thread-pool.cc
blob: 819aed748340938c64388bc10871e4d033c0eb6d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include "thread-pool.hh"

namespace nix {

/* Create a pool that will run work on '_nrThreads' threads. Passing 0
   means "pick a default": one thread per hardware thread as reported
   by std::thread::hardware_concurrency(), or a single thread when that
   value is unknown (reported as 0). */
ThreadPool::ThreadPool(size_t _nrThreads)
    : nrThreads(_nrThreads)
{
    if (nrThreads != 0) return;

    auto hardwareThreads = std::thread::hardware_concurrency();
    nrThreads = hardwareThreads == 0 ? 1 : hardwareThreads;
}

void ThreadPool::enqueue(const work_t & t)
{
    auto state_(state.lock());
    state_->left.push(t);
    wakeup.notify_one();
}

void ThreadPool::process()
{
    printMsg(lvlDebug, format("starting pool of %d threads") % nrThreads);

    std::vector<std::thread> workers;

    for (size_t n = 0; n < nrThreads; n++)
        workers.push_back(std::thread([&]() {
            bool first = true;

            while (true) {
                work_t work;
                {
                    auto state_(state.lock());
                    if (state_->exception) return;
                    if (!first) {
                        assert(state_->pending);
                        state_->pending--;
                    }
                    first = false;
                    while (state_->left.empty()) {
                        if (!state_->pending) {
                            wakeup.notify_all();
                            return;
                        }
                        if (state_->exception) return;
                        state_.wait(wakeup);
                    }
                    work = state_->left.front();
                    state_->left.pop();
                    state_->pending++;
                }

                try {
                    work();
                } catch (std::exception & e) {
                    auto state_(state.lock());
                    if (state_->exception)
                        printMsg(lvlError, format("error: %s") % e.what());
                    else {
                        state_->exception = std::current_exception();
                        wakeup.notify_all();
                    }
                }
            }

        }));

    for (auto & thr : workers)
        thr.join();

    {
        auto state_(state.lock());
        if (state_->exception)
            std::rethrow_exception(state_->exception);
    }
}

}