1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::{
16     sync::{mpsc, Arc, Mutex},
17     thread,
18 };
19 
20 use log::{error, info};
21 
/// A fixed-size pool of worker threads that run submitted jobs.
///
/// Jobs are handed to the workers over an `mpsc` channel. Dropping the pool
/// closes the channel and joins every worker thread.
pub struct ThreadPool {
    /// The worker threads owned by this pool; joined when the pool is dropped.
    workers: Vec<Worker>,
    /// Sending half of the job channel. Stored in an `Option` so that `Drop`
    /// can take it out and drop it, which closes the channel.
    sender: Option<mpsc::Sender<Job>>,
}
26 
/// A unit of work for the pool: a boxed closure that is called exactly once
/// and can be sent to another thread.
type Job = Box<dyn FnOnce() + Send + 'static>;
28 
29 impl ThreadPool {
30     /// Create a new ThreadPool.
31     ///
32     /// The size is the number of threads in the pool.
33     ///
34     /// # Panics
35     ///
36     /// The `new` function will panic if the size is zero.
new(size: usize) -> ThreadPool37     pub fn new(size: usize) -> ThreadPool {
38         assert!(size > 0);
39 
40         let (sender, receiver) = mpsc::channel();
41 
42         let receiver = Arc::new(Mutex::new(receiver));
43 
44         let mut workers = Vec::with_capacity(size);
45 
46         for id in 0..size {
47             workers.push(Worker::new(id, Arc::clone(&receiver)));
48         }
49 
50         ThreadPool { workers, sender: Some(sender) }
51     }
52 
execute<F>(&self, f: F) where F: FnOnce() + Send + 'static,53     pub fn execute<F>(&self, f: F)
54     where
55         F: FnOnce() + Send + 'static,
56     {
57         let job = Box::new(f);
58 
59         self.sender.as_ref().unwrap().send(job).unwrap();
60     }
61 }
62 
63 impl Drop for ThreadPool {
drop(&mut self)64     fn drop(&mut self) {
65         drop(self.sender.take());
66 
67         for worker in &mut self.workers {
68             info!("Shutting down worker {}", worker.id);
69 
70             if let Some(thread) = worker.thread.take() {
71                 thread.join().unwrap();
72             }
73         }
74     }
75 }
76 
/// A single pool thread together with the id used to identify it in logs.
struct Worker {
    /// Index of this worker within the pool.
    id: usize,
    /// Handle of the spawned thread. Stored in an `Option` so that the pool's
    /// `Drop` can take ownership of the handle and join it.
    thread: Option<thread::JoinHandle<()>>,
}
81 
82 impl Worker {
new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker83     fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
84         let thread = thread::Builder::new()
85             .name("http_pool_{id}".to_string())
86             .spawn(move || loop {
87                 let message = receiver.lock().expect("Failed to acquire lock on receiver").recv();
88 
89                 match message {
90                     Ok(job) => {
91                         job();
92                     }
93                     Err(_) => {
94                         error!("Worker {id} disconnected; shutting down.");
95                         break;
96                     }
97                 }
98             })
99             .expect("http_pool_{id} spawn failed");
100         Worker { id, thread: Some(thread) }
101     }
102 }
103