1 // Copyright (C) 2017 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 
17 #include <atomic>
18 #include <cstdint>
19 #include <functional>
20 #include <memory>
21 #include <type_traits>
22 #include <utility>
23 #include <vector>
24 
25 #include "aemu/base/Compiler.h"
26 #include "aemu/base/Optional.h"
27 #include "aemu/base/system/System.h"
28 #include "aemu/base/threads/WorkerThread.h"
29 
30 //
31 // ThreadPool<Item> - a simple collection of worker threads to process enqueued
32 // items on multiple cores.
33 //
34 // To create a thread pool supply a processing function and an optional number
35 // of threads to use (default is number of CPU cores).
36 // Thread pool distributes the work in simple round robin manner over all its
37 // workers - this means individual items should be simple and take similar time
38 // to process.
39 //
40 // Usage is very similar to one of WorkerThread, with difference being in the
41 // number of worker threads used and in existence of explicit done() method:
42 //
43 //      struct WorkItem { int number; };
44 //
//      // A thread count below 1 means "one worker per CPU core".
//      ThreadPool<WorkItem> tp(
//              0, [](WorkItem&& item) { std::cout << item.number; });
46 //      CHECK(tp.start()) << "Failed to start the thread pool";
47 //      tp.enqueue({1});
48 //      tp.enqueue({2});
49 //      tp.enqueue({3});
50 //      tp.enqueue({4});
51 //      tp.enqueue({5});
52 //      tp.done();
53 //      tp.join();
54 //
55 // Make sure that the processing function won't block worker threads - thread
56 // pool has no way of detecting it and may potentially get all workers to block,
57 // resulting in a hanging application.
58 //
59 
60 namespace android {
61 namespace base {
62 
// Identifier of a pool worker. ThreadPool exposes it as ThreadPool::WorkerId
// and passes it to the processing function together with every item.
using ThreadPoolWorkerId = uint32_t;
64 
65 template <class ItemT>
66 class ThreadPool {
67     DISALLOW_COPY_AND_ASSIGN(ThreadPool);
68 
69 public:
70     using Item = ItemT;
71     using WorkerId = ThreadPoolWorkerId;
72     using Processor = std::function<void(Item&&, WorkerId)>;
73 
74    private:
75     struct Command {
76         Item mItem;
77         WorkerId mWorkerId;
78 
CommandCommand79         Command(Item&& item, WorkerId workerId) : mItem(std::move(item)), mWorkerId(workerId) {}
80         DISALLOW_COPY_AND_ASSIGN(Command);
81         Command(Command&&) = default;
82     };
83     using Worker = WorkerThread<Optional<Command>>;
84 
85    public:
86     // Fn is the type of the processor, it can either have 2 parameters: 1 for the Item, 1 for the
87     // WorkerId, or have only 1 Item parameter.
88     template <class Fn, typename = std::enable_if_t<std::is_invocable_v<Fn, Item, WorkerId> ||
89                                                     std::is_invocable_v<Fn, Item>>>
ThreadPool(int threads,Fn && processor)90     ThreadPool(int threads, Fn&& processor) : mProcessor() {
91         if constexpr (std::is_invocable_v<Fn, Item, WorkerId>) {
92             mProcessor = std::move(processor);
93         } else if constexpr (std::is_invocable_v<Fn, Item>) {
94             using namespace std::placeholders;
95             mProcessor = std::bind(std::move(processor), _1);
96         }
97         if (threads < 1) {
98             threads = android::base::getCpuCoreCount();
99         }
100         mWorkers = std::vector<Optional<Worker>>(threads);
101         for (auto& workerPtr : mWorkers) {
102             workerPtr.emplace([this](Optional<Command>&& commandOpt) {
103                 if (!commandOpt) {
104                     return Worker::Result::Stop;
105                 }
106                 Command command = std::move(commandOpt.value());
107                 mProcessor(std::move(command.mItem), command.mWorkerId);
108                 return Worker::Result::Continue;
109             });
110         }
111     }
ThreadPool(Processor && processor)112     explicit ThreadPool(Processor&& processor)
113         : ThreadPool(0, std::move(processor)) {}
~ThreadPool()114     ~ThreadPool() {
115         done();
116         join();
117     }
118 
start()119     bool start() {
120         for (auto& workerPtr : mWorkers) {
121             if (workerPtr->start()) {
122                 ++mValidWorkersCount;
123             } else {
124                 workerPtr.clear();
125             }
126         }
127         return mValidWorkersCount > 0;
128     }
129 
done()130     void done() {
131         for (auto& workerPtr : mWorkers) {
132             if (workerPtr) {
133                 workerPtr->enqueue(kNullopt);
134             }
135         }
136     }
137 
join()138     void join() {
139         for (auto& workerPtr : mWorkers) {
140             if (workerPtr) {
141                 workerPtr->join();
142             }
143         }
144         mWorkers.clear();
145         mValidWorkersCount = 0;
146     }
147 
enqueue(Item && item)148     void enqueue(Item&& item) {
149         for (;;) {
150             int currentIndex =
151                     mNextWorkerIndex.fetch_add(1, std::memory_order_relaxed);
152             int workerIndex = currentIndex % mWorkers.size();
153             auto& workerPtr = mWorkers[workerIndex];
154             if (workerPtr) {
155                 Command command(std::forward<Item>(item), workerIndex);
156                 workerPtr->enqueue(std::move(command));
157                 break;
158             }
159         }
160     }
161 
162     // The itemFactory will be called multiple times to generate one item for each worker thread.
163     template <class Fn, typename = std::enable_if_t<std::is_invocable_r_v<Item, Fn>>>
broadcast(Fn && itemFactory)164     void broadcast(Fn&& itemFactory) {
165         int i = 0;
166         for (auto& workerOpt : mWorkers) {
167             if (!workerOpt) continue;
168             Command command(std::move(itemFactory()), i);
169             workerOpt->enqueue(std::move(command));
170             ++i;
171         }
172     }
173 
waitAllItems()174     void waitAllItems() {
175         if (0 == mValidWorkersCount) return;
176         for (auto& workerOpt : mWorkers) {
177             if (!workerOpt) continue;
178             workerOpt->waitQueuedItems();
179         }
180     }
181 
numWorkers()182     int numWorkers() const { return mValidWorkersCount; }
183 
184 private:
185     Processor mProcessor;
186     std::vector<Optional<Worker>> mWorkers;
187     std::atomic<int> mNextWorkerIndex{0};
188     int mValidWorkersCount{0};
189 };
190 
191 }  // namespace base
192 }  // namespace android
193