1 /*
2  * Copyright (C) 2021, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! DoH server frontend.
18 
19 use super::client::{ClientMap, ConnectionID, CONN_ID_LEN, DNS_HEADER_SIZE, MAX_UDP_PAYLOAD_SIZE};
20 use super::config::{Config, QUICHE_IDLE_TIMEOUT_MS};
21 use super::stats::Stats;
22 use anyhow::{bail, ensure, Result};
23 use lazy_static::lazy_static;
24 use log::{debug, error, warn};
25 use std::fs::File;
26 use std::io::Write;
27 use std::os::unix::io::{AsRawFd, FromRawFd};
28 use std::sync::{Arc, Mutex};
29 use std::time::Duration;
30 use tokio::net::UdpSocket;
31 use tokio::runtime::{Builder, Runtime};
32 use tokio::sync::{mpsc, oneshot};
33 use tokio::task::JoinHandle;
34 
35 lazy_static! {
36     static ref RUNTIME_STATIC: Arc<Runtime> = Arc::new(
37         Builder::new_multi_thread()
38             .worker_threads(1)
39             .enable_all()
40             .thread_name("DohFrontend")
41             .build()
42             .expect("Failed to create tokio runtime")
43     );
44 }
45 
46 /// Command used by worker_thread itself.
47 #[derive(Debug)]
48 enum InternalCommand {
49     MaybeWrite { connection_id: ConnectionID },
50 }
51 
52 /// Commands that DohFrontend to ask its worker_thread for.
53 #[derive(Debug)]
54 enum ControlCommand {
55     Stats { resp: oneshot::Sender<Stats> },
56     StatsClearQueries,
57     CloseConnection,
58 }
59 
60 /// Frontend object.
61 #[derive(Debug)]
62 pub struct DohFrontend {
63     // Socket address the frontend listens to.
64     listen_socket_addr: std::net::SocketAddr,
65 
66     // Socket address the backend listens to.
67     backend_socket_addr: std::net::SocketAddr,
68 
69     /// The content of the certificate.
70     certificate: String,
71 
72     /// The content of the private key.
73     private_key: String,
74 
75     // The thread listening to frontend socket and backend socket
76     // and processing the messages.
77     worker_thread: Option<JoinHandle<Result<()>>>,
78 
79     // Custom runtime configuration to control the behavior of the worker thread.
80     // It's shared with the worker thread.
81     // TODO: use channel to update worker_thread configuration.
82     config: Arc<Mutex<Config>>,
83 
84     // Caches the latest stats so that the stats remains after worker_thread stops.
85     latest_stats: Stats,
86 
87     // It is wrapped as Option because the channel is not created in DohFrontend construction.
88     command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
89 }
90 
91 /// The parameters passed to the worker thread.
92 struct WorkerParams {
93     frontend_socket: std::net::UdpSocket,
94     backend_socket: std::net::UdpSocket,
95     clients: ClientMap,
96     config: Arc<Mutex<Config>>,
97     command_rx: mpsc::UnboundedReceiver<ControlCommand>,
98 }
99 
100 impl DohFrontend {
new( listen: std::net::SocketAddr, backend: std::net::SocketAddr, ) -> Result<Box<DohFrontend>>101     pub fn new(
102         listen: std::net::SocketAddr,
103         backend: std::net::SocketAddr,
104     ) -> Result<Box<DohFrontend>> {
105         let doh = Box::new(DohFrontend {
106             listen_socket_addr: listen,
107             backend_socket_addr: backend,
108             certificate: String::new(),
109             private_key: String::new(),
110             worker_thread: None,
111             config: Arc::new(Mutex::new(Config::new())),
112             latest_stats: Stats::new(),
113             command_tx: None,
114         });
115         debug!("DohFrontend created: {:?}", doh);
116         Ok(doh)
117     }
118 
start(&mut self) -> Result<()>119     pub fn start(&mut self) -> Result<()> {
120         ensure!(self.worker_thread.is_none(), "Worker thread has been running");
121         ensure!(!self.certificate.is_empty(), "certificate is empty");
122         ensure!(!self.private_key.is_empty(), "private_key is empty");
123 
124         // Doing error handling here is much simpler.
125         let params = match self.init_worker_thread_params() {
126             Ok(v) => v,
127             Err(e) => return Err(e.context("init_worker_thread_params failed")),
128         };
129 
130         self.worker_thread = Some(RUNTIME_STATIC.spawn(worker_thread(params)));
131         Ok(())
132     }
133 
stop(&mut self) -> Result<()>134     pub fn stop(&mut self) -> Result<()> {
135         debug!("DohFrontend: stopping: {:?}", self);
136         if let Some(worker_thread) = self.worker_thread.take() {
137             // Update latest_stats before stopping worker_thread.
138             let _ = self.request_stats();
139 
140             self.command_tx.as_ref().unwrap().send(ControlCommand::CloseConnection)?;
141             if let Err(e) = self.wait_for_connections_closed() {
142                 warn!("wait_for_connections_closed failed: {}", e);
143             }
144 
145             worker_thread.abort();
146             RUNTIME_STATIC.block_on(async {
147                 debug!("worker_thread result: {:?}", worker_thread.await);
148             })
149         }
150 
151         debug!("DohFrontend: stopped: {:?}", self);
152         Ok(())
153     }
154 
set_certificate(&mut self, certificate: &str) -> Result<()>155     pub fn set_certificate(&mut self, certificate: &str) -> Result<()> {
156         self.certificate = certificate.to_string();
157         Ok(())
158     }
159 
set_private_key(&mut self, private_key: &str) -> Result<()>160     pub fn set_private_key(&mut self, private_key: &str) -> Result<()> {
161         self.private_key = private_key.to_string();
162         Ok(())
163     }
164 
set_delay_queries(&self, value: i32) -> Result<()>165     pub fn set_delay_queries(&self, value: i32) -> Result<()> {
166         self.config.lock().unwrap().delay_queries = value;
167         Ok(())
168     }
169 
set_max_idle_timeout(&self, value: u64) -> Result<()>170     pub fn set_max_idle_timeout(&self, value: u64) -> Result<()> {
171         self.config.lock().unwrap().max_idle_timeout = value;
172         Ok(())
173     }
174 
set_max_buffer_size(&self, value: u64) -> Result<()>175     pub fn set_max_buffer_size(&self, value: u64) -> Result<()> {
176         self.config.lock().unwrap().max_buffer_size = value;
177         Ok(())
178     }
179 
set_max_streams_bidi(&self, value: u64) -> Result<()>180     pub fn set_max_streams_bidi(&self, value: u64) -> Result<()> {
181         self.config.lock().unwrap().max_streams_bidi = value;
182         Ok(())
183     }
184 
block_sending(&self, value: bool) -> Result<()>185     pub fn block_sending(&self, value: bool) -> Result<()> {
186         self.config.lock().unwrap().block_sending = value;
187         Ok(())
188     }
189 
set_reset_stream_id(&self, value: u64) -> Result<()>190     pub fn set_reset_stream_id(&self, value: u64) -> Result<()> {
191         self.config.lock().unwrap().reset_stream_id = Some(value);
192         Ok(())
193     }
194 
request_stats(&mut self) -> Result<Stats>195     pub fn request_stats(&mut self) -> Result<Stats> {
196         ensure!(
197             self.command_tx.is_some(),
198             "command_tx is None because worker thread not yet initialized"
199         );
200         let command_tx = self.command_tx.as_ref().unwrap();
201 
202         if command_tx.is_closed() {
203             return Ok(self.latest_stats.clone());
204         }
205 
206         let (resp_tx, resp_rx) = oneshot::channel();
207         command_tx.send(ControlCommand::Stats { resp: resp_tx })?;
208 
209         match RUNTIME_STATIC
210             .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
211         {
212             Ok(v) => match v {
213                 Ok(stats) => {
214                     self.latest_stats = stats.clone();
215                     Ok(stats)
216                 }
217                 Err(e) => bail!(e),
218             },
219             Err(e) => bail!(e),
220         }
221     }
222 
stats_clear_queries(&self) -> Result<()>223     pub fn stats_clear_queries(&self) -> Result<()> {
224         ensure!(
225             self.command_tx.is_some(),
226             "command_tx is None because worker thread not yet initialized"
227         );
228         return self
229             .command_tx
230             .as_ref()
231             .unwrap()
232             .send(ControlCommand::StatsClearQueries)
233             .or_else(|e| bail!(e));
234     }
235 
init_worker_thread_params(&mut self) -> Result<WorkerParams>236     fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
237         let bind_addr =
238             if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
239         let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
240         backend_socket.connect(self.backend_socket_addr)?;
241         backend_socket.set_nonblocking(true)?;
242 
243         let frontend_socket = bind_udp_socket_retry(self.listen_socket_addr)?;
244         frontend_socket.set_nonblocking(true)?;
245 
246         let clients = ClientMap::new(create_quiche_config(
247             self.certificate.to_string(),
248             self.private_key.to_string(),
249             self.config.clone(),
250         )?)?;
251 
252         let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
253         self.command_tx = Some(command_tx);
254 
255         Ok(WorkerParams {
256             frontend_socket,
257             backend_socket,
258             clients,
259             config: self.config.clone(),
260             command_rx,
261         })
262     }
263 
wait_for_connections_closed(&mut self) -> Result<()>264     fn wait_for_connections_closed(&mut self) -> Result<()> {
265         for _ in 0..3 {
266             std::thread::sleep(Duration::from_millis(50));
267             match self.request_stats() {
268                 Ok(stats) if stats.alive_connections == 0 => return Ok(()),
269                 Ok(_) => (),
270 
271                 // The worker thread is down. No connection is alive.
272                 Err(_) => return Ok(()),
273             }
274         }
275         bail!("Some connections still alive")
276     }
277 }
278 
worker_thread(params: WorkerParams) -> Result<()>279 async fn worker_thread(params: WorkerParams) -> Result<()> {
280     let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
281     let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
282     let config = params.config;
283     let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
284     let mut command_rx = params.command_rx;
285     let mut clients = params.clients;
286     let mut frontend_buf = [0; 65535];
287     let mut backend_buf = [0; 16384];
288     let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
289     let mut queries_received = 0;
290 
291     debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);
292 
293     loop {
294         let timeout = clients
295             .iter_mut()
296             .filter_map(|(_, c)| c.timeout())
297             .min()
298             .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
299 
300         tokio::select! {
301             _ = tokio::time::sleep(timeout) => {
302                 debug!("timeout");
303                 for (_, client) in clients.iter_mut() {
304                     // If no timeout has occurred it does nothing.
305                     client.on_timeout();
306 
307                     let connection_id = client.connection_id().clone();
308                     event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
309                 }
310             }
311 
312             Ok((len, peer)) = frontend_socket.recv_from(&mut frontend_buf) => {
313                 debug!("Got {} bytes from {}", len, peer);
314 
315                 // Parse QUIC packet.
316                 let pkt_buf = &mut frontend_buf[..len];
317                 let hdr = match quiche::Header::from_slice(pkt_buf, CONN_ID_LEN) {
318                     Ok(v) => v,
319                     Err(e) => {
320                         error!("Failed to parse QUIC header: {:?}", e);
321                         continue;
322                     }
323                 };
324                 debug!("Got QUIC packet: {:?}", hdr);
325 
326                 let local = frontend_socket.local_addr()?;
327                 let client = match clients.get_or_create(&hdr, &peer, &local) {
328                     Ok(v) => v,
329                     Err(e) => {
330                         error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
331                         continue;
332                     }
333                 };
334                 debug!("Got client: {:?}", client);
335 
336                 match client.handle_frontend_message(pkt_buf, &local) {
337                     Ok(v) if !v.is_empty() => {
338                         delay_queries_buffer.push(v);
339                         queries_received += 1;
340                     }
341                     Err(e) => {
342                         error!("Failed to process QUIC packet: {}", e);
343                         continue;
344                     }
345                     _ => {}
346                 }
347 
348                 if delay_queries_buffer.len() >= config.lock().unwrap().delay_queries as usize {
349                     for query in delay_queries_buffer.drain(..) {
350                         debug!("sending {} bytes to backend", query.len());
351                         backend_socket.send(&query).await?;
352                     }
353                 }
354 
355                 let connection_id = client.connection_id().clone();
356                 event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
357             }
358 
359             Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
360                 debug!("Got {} bytes from {}", len, src);
361                 if len < DNS_HEADER_SIZE {
362                     error!("Received insufficient bytes for DNS header");
363                     continue;
364                 }
365 
366                 let query_id = [backend_buf[0], backend_buf[1]];
367                 for (_, client) in clients.iter_mut() {
368                     if client.is_waiting_for_query(&query_id) {
369                         let reset_stream_id = config.lock().unwrap().reset_stream_id;
370                         if let Err(e) = client.handle_backend_message(&backend_buf[..len], reset_stream_id) {
371                             error!("Failed to handle message from backend: {}", e);
372                         }
373                         let connection_id = client.connection_id().clone();
374                         event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
375 
376                         // It's a bug if more than one client is waiting for this query.
377                         break;
378                     }
379                 }
380             }
381 
382             Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
383                 match command {
384                     InternalCommand::MaybeWrite {connection_id} => {
385                         if let Some(client) = clients.get_mut(&connection_id) {
386                             while let Ok(v) = client.flush_egress() {
387                                 let addr = client.addr();
388                                 debug!("Sending {} bytes to client {}", v.len(), addr);
389                                 if let Err(e) = frontend_socket.send_to(&v, addr).await {
390                                     error!("Failed to send packet to {:?}: {:?}", client, e);
391                                 }
392                             }
393                             client.process_pending_answers()?;
394                         }
395                     }
396                 }
397             }
398             Some(command) = command_rx.recv() => {
399                 debug!("ControlCommand: {:?}", command);
400                 match command {
401                     ControlCommand::Stats {resp} => {
402                         let stats = Stats {
403                             queries_received,
404                             connections_accepted: clients.len() as u32,
405                             alive_connections: clients.iter().filter(|(_, client)| client.is_alive()).count() as u32,
406                             resumed_connections: clients.iter().filter(|(_, client)| client.is_resumed()).count() as u32,
407                             early_data_connections: clients.iter().filter(|(_, client)| client.handled_early_data()).count() as u32,
408                         };
409                         if let Err(e) = resp.send(stats) {
410                             error!("Failed to send ControlCommand::Stats response: {:?}", e);
411                         }
412                     }
413                     ControlCommand::StatsClearQueries => queries_received = 0,
414                     ControlCommand::CloseConnection => {
415                         for (_, client) in clients.iter_mut() {
416                             client.close();
417                             event_tx.send(InternalCommand::MaybeWrite { connection_id: client.connection_id().clone() })?;
418                         }
419                     }
420                 }
421             }
422         }
423     }
424 }
425 
create_quiche_config( certificate: String, private_key: String, config: Arc<Mutex<Config>>, ) -> Result<quiche::Config>426 fn create_quiche_config(
427     certificate: String,
428     private_key: String,
429     config: Arc<Mutex<Config>>,
430 ) -> Result<quiche::Config> {
431     let mut quiche_config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
432 
433     // Use pipe as a file path for Quiche to read the certificate and the private key.
434     let (rd, mut wr) = build_pipe()?;
435     let handle = std::thread::spawn(move || {
436         wr.write_all(certificate.as_bytes()).expect("Failed to write to pipe");
437     });
438     let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
439     quiche_config.load_cert_chain_from_pem_file(&filepath)?;
440     handle.join().unwrap();
441 
442     let (rd, mut wr) = build_pipe()?;
443     let handle = std::thread::spawn(move || {
444         wr.write_all(private_key.as_bytes()).expect("Failed to write to pipe");
445     });
446     let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
447     quiche_config.load_priv_key_from_pem_file(&filepath)?;
448     handle.join().unwrap();
449 
450     quiche_config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?;
451     quiche_config.set_max_idle_timeout(config.lock().unwrap().max_idle_timeout);
452     quiche_config.set_max_recv_udp_payload_size(MAX_UDP_PAYLOAD_SIZE);
453 
454     let max_buffer_size = config.lock().unwrap().max_buffer_size;
455     quiche_config.set_initial_max_data(max_buffer_size);
456     quiche_config.set_initial_max_stream_data_bidi_local(max_buffer_size);
457     quiche_config.set_initial_max_stream_data_bidi_remote(max_buffer_size);
458     quiche_config.set_initial_max_stream_data_uni(max_buffer_size);
459 
460     quiche_config.set_initial_max_streams_bidi(config.lock().unwrap().max_streams_bidi);
461     quiche_config.set_initial_max_streams_uni(100);
462     quiche_config.set_disable_active_migration(true);
463     quiche_config.enable_early_data();
464 
465     Ok(quiche_config)
466 }
467 
into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket>468 fn into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket> {
469     match UdpSocket::from_std(socket) {
470         Ok(v) => Ok(v),
471         Err(e) => {
472             error!("into_tokio_udp_socket failed: {}", e);
473             bail!("into_tokio_udp_socket failed: {}", e)
474         }
475     }
476 }
477 
build_pipe() -> Result<(File, File)>478 fn build_pipe() -> Result<(File, File)> {
479     let mut fds = [0, 0];
480     // SAFETY: The pointer we pass to `pipe` must be valid because it comes from a reference. The
481     // file descriptors it returns must be valid and open, so they are safe to pass to
482     // `File::from_raw_fd`.
483     unsafe {
484         if libc::pipe(fds.as_mut_ptr()) == 0 {
485             return Ok((File::from_raw_fd(fds[0]), File::from_raw_fd(fds[1])));
486         }
487     }
488     Err(anyhow::Error::new(std::io::Error::last_os_error()).context("build_pipe failed"))
489 }
490 
491 // Can retry to bind the socket address if it is in use.
bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket>492 fn bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket> {
493     for _ in 0..3 {
494         match std::net::UdpSocket::bind(addr) {
495             Ok(socket) => return Ok(socket),
496             Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
497                 warn!("Binding socket address {} that is in use. Try again", addr);
498                 std::thread::sleep(Duration::from_millis(50));
499             }
500             Err(e) => return Err(anyhow::anyhow!(e)),
501         }
502     }
503     Err(anyhow::anyhow!(std::io::Error::last_os_error()))
504 }
505