1 /*
2 * Copyright (C) 2021, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 //! DoH server frontend.
18
19 use super::client::{ClientMap, ConnectionID, CONN_ID_LEN, DNS_HEADER_SIZE, MAX_UDP_PAYLOAD_SIZE};
20 use super::config::{Config, QUICHE_IDLE_TIMEOUT_MS};
21 use super::stats::Stats;
22 use anyhow::{bail, ensure, Result};
23 use lazy_static::lazy_static;
24 use log::{debug, error, warn};
25 use std::fs::File;
26 use std::io::Write;
27 use std::os::unix::io::{AsRawFd, FromRawFd};
28 use std::sync::{Arc, Mutex};
29 use std::time::Duration;
30 use tokio::net::UdpSocket;
31 use tokio::runtime::{Builder, Runtime};
32 use tokio::sync::{mpsc, oneshot};
33 use tokio::task::JoinHandle;
34
lazy_static! {
    // Process-wide tokio runtime shared by every DohFrontend in this module.
    // It is built lazily on first access with a single worker thread named
    // "DohFrontend" and with all tokio drivers enabled. Failing to build the
    // runtime panics, since nothing in this module can work without it.
    static ref RUNTIME_STATIC: Arc<Runtime> = Arc::new(
        Builder::new_multi_thread()
            .worker_threads(1)
            .enable_all()
            .thread_name("DohFrontend")
            .build()
            .expect("Failed to create tokio runtime")
    );
}
45
/// Commands that `worker_thread` queues for itself through its internal event channel.
#[derive(Debug)]
enum InternalCommand {
    /// Asks the worker to flush any pending outgoing packets of the connection
    /// identified by `connection_id` and then process its pending answers.
    MaybeWrite { connection_id: ConnectionID },
}
51
/// Commands that `DohFrontend` sends to its `worker_thread`.
#[derive(Debug)]
enum ControlCommand {
    /// Requests a snapshot of the current statistics; the worker replies through `resp`.
    Stats { resp: oneshot::Sender<Stats> },
    /// Resets the worker's count of received queries to zero.
    StatsClearQueries,
    /// Closes every client connection the worker currently tracks.
    CloseConnection,
}
59
/// Frontend object.
///
/// Owns the configuration of the DoH frontend and the handle of the worker task
/// that serves it. The worker only exists between `start()` and `stop()`.
#[derive(Debug)]
pub struct DohFrontend {
    /// Socket address the frontend listens on.
    listen_socket_addr: std::net::SocketAddr,

    /// Socket address the backend listens on.
    backend_socket_addr: std::net::SocketAddr,

    /// The content of the certificate.
    certificate: String,

    /// The content of the private key.
    private_key: String,

    /// The task listening on the frontend socket and the backend socket
    /// and processing the messages. `None` while the frontend is stopped.
    worker_thread: Option<JoinHandle<Result<()>>>,

    /// Custom runtime configuration to control the behavior of the worker thread.
    /// It's shared with the worker thread.
    // TODO: use channel to update worker_thread configuration.
    config: Arc<Mutex<Config>>,

    /// Caches the latest stats so that the stats remain available after worker_thread stops.
    latest_stats: Stats,

    /// Sender half of the control channel to the worker thread.
    /// It is wrapped as Option because the channel is not created in DohFrontend construction.
    command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
}
90
/// The parameters passed to the worker thread.
struct WorkerParams {
    /// Socket bound to the frontend listen address (already non-blocking).
    frontend_socket: std::net::UdpSocket,
    /// Socket connected to the backend address (already non-blocking).
    backend_socket: std::net::UdpSocket,
    /// The set of client connections, created from the quiche configuration.
    clients: ClientMap,
    /// Runtime configuration shared with the owning `DohFrontend`.
    config: Arc<Mutex<Config>>,
    /// Receiver half of the control channel; the sender stays in `DohFrontend`.
    command_rx: mpsc::UnboundedReceiver<ControlCommand>,
}
99
100 impl DohFrontend {
new( listen: std::net::SocketAddr, backend: std::net::SocketAddr, ) -> Result<Box<DohFrontend>>101 pub fn new(
102 listen: std::net::SocketAddr,
103 backend: std::net::SocketAddr,
104 ) -> Result<Box<DohFrontend>> {
105 let doh = Box::new(DohFrontend {
106 listen_socket_addr: listen,
107 backend_socket_addr: backend,
108 certificate: String::new(),
109 private_key: String::new(),
110 worker_thread: None,
111 config: Arc::new(Mutex::new(Config::new())),
112 latest_stats: Stats::new(),
113 command_tx: None,
114 });
115 debug!("DohFrontend created: {:?}", doh);
116 Ok(doh)
117 }
118
start(&mut self) -> Result<()>119 pub fn start(&mut self) -> Result<()> {
120 ensure!(self.worker_thread.is_none(), "Worker thread has been running");
121 ensure!(!self.certificate.is_empty(), "certificate is empty");
122 ensure!(!self.private_key.is_empty(), "private_key is empty");
123
124 // Doing error handling here is much simpler.
125 let params = match self.init_worker_thread_params() {
126 Ok(v) => v,
127 Err(e) => return Err(e.context("init_worker_thread_params failed")),
128 };
129
130 self.worker_thread = Some(RUNTIME_STATIC.spawn(worker_thread(params)));
131 Ok(())
132 }
133
stop(&mut self) -> Result<()>134 pub fn stop(&mut self) -> Result<()> {
135 debug!("DohFrontend: stopping: {:?}", self);
136 if let Some(worker_thread) = self.worker_thread.take() {
137 // Update latest_stats before stopping worker_thread.
138 let _ = self.request_stats();
139
140 self.command_tx.as_ref().unwrap().send(ControlCommand::CloseConnection)?;
141 if let Err(e) = self.wait_for_connections_closed() {
142 warn!("wait_for_connections_closed failed: {}", e);
143 }
144
145 worker_thread.abort();
146 RUNTIME_STATIC.block_on(async {
147 debug!("worker_thread result: {:?}", worker_thread.await);
148 })
149 }
150
151 debug!("DohFrontend: stopped: {:?}", self);
152 Ok(())
153 }
154
set_certificate(&mut self, certificate: &str) -> Result<()>155 pub fn set_certificate(&mut self, certificate: &str) -> Result<()> {
156 self.certificate = certificate.to_string();
157 Ok(())
158 }
159
set_private_key(&mut self, private_key: &str) -> Result<()>160 pub fn set_private_key(&mut self, private_key: &str) -> Result<()> {
161 self.private_key = private_key.to_string();
162 Ok(())
163 }
164
set_delay_queries(&self, value: i32) -> Result<()>165 pub fn set_delay_queries(&self, value: i32) -> Result<()> {
166 self.config.lock().unwrap().delay_queries = value;
167 Ok(())
168 }
169
set_max_idle_timeout(&self, value: u64) -> Result<()>170 pub fn set_max_idle_timeout(&self, value: u64) -> Result<()> {
171 self.config.lock().unwrap().max_idle_timeout = value;
172 Ok(())
173 }
174
set_max_buffer_size(&self, value: u64) -> Result<()>175 pub fn set_max_buffer_size(&self, value: u64) -> Result<()> {
176 self.config.lock().unwrap().max_buffer_size = value;
177 Ok(())
178 }
179
set_max_streams_bidi(&self, value: u64) -> Result<()>180 pub fn set_max_streams_bidi(&self, value: u64) -> Result<()> {
181 self.config.lock().unwrap().max_streams_bidi = value;
182 Ok(())
183 }
184
block_sending(&self, value: bool) -> Result<()>185 pub fn block_sending(&self, value: bool) -> Result<()> {
186 self.config.lock().unwrap().block_sending = value;
187 Ok(())
188 }
189
set_reset_stream_id(&self, value: u64) -> Result<()>190 pub fn set_reset_stream_id(&self, value: u64) -> Result<()> {
191 self.config.lock().unwrap().reset_stream_id = Some(value);
192 Ok(())
193 }
194
request_stats(&mut self) -> Result<Stats>195 pub fn request_stats(&mut self) -> Result<Stats> {
196 ensure!(
197 self.command_tx.is_some(),
198 "command_tx is None because worker thread not yet initialized"
199 );
200 let command_tx = self.command_tx.as_ref().unwrap();
201
202 if command_tx.is_closed() {
203 return Ok(self.latest_stats.clone());
204 }
205
206 let (resp_tx, resp_rx) = oneshot::channel();
207 command_tx.send(ControlCommand::Stats { resp: resp_tx })?;
208
209 match RUNTIME_STATIC
210 .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
211 {
212 Ok(v) => match v {
213 Ok(stats) => {
214 self.latest_stats = stats.clone();
215 Ok(stats)
216 }
217 Err(e) => bail!(e),
218 },
219 Err(e) => bail!(e),
220 }
221 }
222
stats_clear_queries(&self) -> Result<()>223 pub fn stats_clear_queries(&self) -> Result<()> {
224 ensure!(
225 self.command_tx.is_some(),
226 "command_tx is None because worker thread not yet initialized"
227 );
228 return self
229 .command_tx
230 .as_ref()
231 .unwrap()
232 .send(ControlCommand::StatsClearQueries)
233 .or_else(|e| bail!(e));
234 }
235
init_worker_thread_params(&mut self) -> Result<WorkerParams>236 fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
237 let bind_addr =
238 if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
239 let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
240 backend_socket.connect(self.backend_socket_addr)?;
241 backend_socket.set_nonblocking(true)?;
242
243 let frontend_socket = bind_udp_socket_retry(self.listen_socket_addr)?;
244 frontend_socket.set_nonblocking(true)?;
245
246 let clients = ClientMap::new(create_quiche_config(
247 self.certificate.to_string(),
248 self.private_key.to_string(),
249 self.config.clone(),
250 )?)?;
251
252 let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
253 self.command_tx = Some(command_tx);
254
255 Ok(WorkerParams {
256 frontend_socket,
257 backend_socket,
258 clients,
259 config: self.config.clone(),
260 command_rx,
261 })
262 }
263
wait_for_connections_closed(&mut self) -> Result<()>264 fn wait_for_connections_closed(&mut self) -> Result<()> {
265 for _ in 0..3 {
266 std::thread::sleep(Duration::from_millis(50));
267 match self.request_stats() {
268 Ok(stats) if stats.alive_connections == 0 => return Ok(()),
269 Ok(_) => (),
270
271 // The worker thread is down. No connection is alive.
272 Err(_) => return Ok(()),
273 }
274 }
275 bail!("Some connections still alive")
276 }
277 }
278
/// The main loop of the frontend, run as a task on the shared runtime.
///
/// It multiplexes five event sources with `tokio::select!`: the connection
/// timeout, packets from the frontend socket, packets from the backend socket,
/// the internal `InternalCommand` queue and the `ControlCommand` channel.
///
/// The loop never exits on its own; the function only returns when one of the
/// `?` operators below propagates an error.
async fn worker_thread(params: WorkerParams) -> Result<()> {
    let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
    let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
    let config = params.config;
    // Internal queue: the branches below enqueue MaybeWrite commands that the
    // `event_rx` branch later consumes to flush packets to the clients.
    let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
    let mut command_rx = params.command_rx;
    let mut clients = params.clients;
    let mut frontend_buf = [0; 65535];
    let mut backend_buf = [0; 16384];
    // Queries waiting to be forwarded to the backend (see `delay_queries`).
    let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
    let mut queries_received = 0;

    debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);

    loop {
        // Sleep until the earliest timeout among all clients, or for
        // QUICHE_IDLE_TIMEOUT_MS when no client reports a timeout.
        let timeout = clients
            .iter_mut()
            .filter_map(|(_, c)| c.timeout())
            .min()
            .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));

        tokio::select! {
            // Timer branch: let every client handle its timeout and schedule a write for it.
            _ = tokio::time::sleep(timeout) => {
                debug!("timeout");
                for (_, client) in clients.iter_mut() {
                    // If no timeout has occurred it does nothing.
                    client.on_timeout();

                    let connection_id = client.connection_id().clone();
                    event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
                }
            }

            // Frontend branch: a packet arrived from a client.
            Ok((len, peer)) = frontend_socket.recv_from(&mut frontend_buf) => {
                debug!("Got {} bytes from {}", len, peer);

                // Parse QUIC packet.
                let pkt_buf = &mut frontend_buf[..len];
                let hdr = match quiche::Header::from_slice(pkt_buf, CONN_ID_LEN) {
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to parse QUIC header: {:?}", e);
                        continue;
                    }
                };
                debug!("Got QUIC packet: {:?}", hdr);

                let local = frontend_socket.local_addr()?;
                let client = match clients.get_or_create(&hdr, &peer, &local) {
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
                        continue;
                    }
                };
                debug!("Got client: {:?}", client);

                // A non-empty result is a query to forward to the backend; it is
                // buffered first and counted in the stats.
                match client.handle_frontend_message(pkt_buf, &local) {
                    Ok(v) if !v.is_empty() => {
                        delay_queries_buffer.push(v);
                        queries_received += 1;
                    }
                    Err(e) => {
                        error!("Failed to process QUIC packet: {}", e);
                        continue;
                    }
                    _ => {}
                }

                // Forward all buffered queries once at least `delay_queries` of them
                // have been collected.
                // NOTE(review): `delay_queries` is an i32 cast with `as usize`; a negative
                // value would turn into a huge threshold -- confirm that is intended.
                if delay_queries_buffer.len() >= config.lock().unwrap().delay_queries as usize {
                    for query in delay_queries_buffer.drain(..) {
                        debug!("sending {} bytes to backend", query.len());
                        backend_socket.send(&query).await?;
                    }
                }

                let connection_id = client.connection_id().clone();
                event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
            }

            // Backend branch: a response arrived from the backend.
            Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
                debug!("Got {} bytes from {}", len, src);
                if len < DNS_HEADER_SIZE {
                    error!("Received insufficient bytes for DNS header");
                    continue;
                }

                // The first two bytes of the message are used as the query id to
                // find the client waiting for this response.
                let query_id = [backend_buf[0], backend_buf[1]];
                for (_, client) in clients.iter_mut() {
                    if client.is_waiting_for_query(&query_id) {
                        let reset_stream_id = config.lock().unwrap().reset_stream_id;
                        if let Err(e) = client.handle_backend_message(&backend_buf[..len], reset_stream_id) {
                            error!("Failed to handle message from backend: {}", e);
                        }
                        let connection_id = client.connection_id().clone();
                        event_tx.send(InternalCommand::MaybeWrite{connection_id})?;

                        // It's a bug if more than one client is waiting for this query.
                        break;
                    }
                }
            }

            // Internal event branch: disabled while `block_sending` is set, so the
            // queued write commands stay in the channel until sending is unblocked.
            Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
                match command {
                    InternalCommand::MaybeWrite {connection_id} => {
                        if let Some(client) = clients.get_mut(&connection_id) {
                            // Send packets until flush_egress() stops returning Ok.
                            while let Ok(v) = client.flush_egress() {
                                let addr = client.addr();
                                debug!("Sending {} bytes to client {}", v.len(), addr);
                                if let Err(e) = frontend_socket.send_to(&v, addr).await {
                                    error!("Failed to send packet to {:?}: {:?}", client, e);
                                }
                            }
                            client.process_pending_answers()?;
                        }
                    }
                }
            }
            // Control branch: commands sent by the owning DohFrontend.
            Some(command) = command_rx.recv() => {
                debug!("ControlCommand: {:?}", command);
                match command {
                    ControlCommand::Stats {resp} => {
                        let stats = Stats {
                            queries_received,
                            connections_accepted: clients.len() as u32,
                            alive_connections: clients.iter().filter(|(_, client)| client.is_alive()).count() as u32,
                            resumed_connections: clients.iter().filter(|(_, client)| client.is_resumed()).count() as u32,
                            early_data_connections: clients.iter().filter(|(_, client)| client.handled_early_data()).count() as u32,
                        };
                        if let Err(e) = resp.send(stats) {
                            error!("Failed to send ControlCommand::Stats response: {:?}", e);
                        }
                    }
                    ControlCommand::StatsClearQueries => queries_received = 0,
                    ControlCommand::CloseConnection => {
                        // Close every client and schedule a write for each of them.
                        for (_, client) in clients.iter_mut() {
                            client.close();
                            event_tx.send(InternalCommand::MaybeWrite { connection_id: client.connection_id().clone() })?;
                        }
                    }
                }
            }
        }
    }
}
425
create_quiche_config( certificate: String, private_key: String, config: Arc<Mutex<Config>>, ) -> Result<quiche::Config>426 fn create_quiche_config(
427 certificate: String,
428 private_key: String,
429 config: Arc<Mutex<Config>>,
430 ) -> Result<quiche::Config> {
431 let mut quiche_config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
432
433 // Use pipe as a file path for Quiche to read the certificate and the private key.
434 let (rd, mut wr) = build_pipe()?;
435 let handle = std::thread::spawn(move || {
436 wr.write_all(certificate.as_bytes()).expect("Failed to write to pipe");
437 });
438 let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
439 quiche_config.load_cert_chain_from_pem_file(&filepath)?;
440 handle.join().unwrap();
441
442 let (rd, mut wr) = build_pipe()?;
443 let handle = std::thread::spawn(move || {
444 wr.write_all(private_key.as_bytes()).expect("Failed to write to pipe");
445 });
446 let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
447 quiche_config.load_priv_key_from_pem_file(&filepath)?;
448 handle.join().unwrap();
449
450 quiche_config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?;
451 quiche_config.set_max_idle_timeout(config.lock().unwrap().max_idle_timeout);
452 quiche_config.set_max_recv_udp_payload_size(MAX_UDP_PAYLOAD_SIZE);
453
454 let max_buffer_size = config.lock().unwrap().max_buffer_size;
455 quiche_config.set_initial_max_data(max_buffer_size);
456 quiche_config.set_initial_max_stream_data_bidi_local(max_buffer_size);
457 quiche_config.set_initial_max_stream_data_bidi_remote(max_buffer_size);
458 quiche_config.set_initial_max_stream_data_uni(max_buffer_size);
459
460 quiche_config.set_initial_max_streams_bidi(config.lock().unwrap().max_streams_bidi);
461 quiche_config.set_initial_max_streams_uni(100);
462 quiche_config.set_disable_active_migration(true);
463 quiche_config.enable_early_data();
464
465 Ok(quiche_config)
466 }
467
into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket>468 fn into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket> {
469 match UdpSocket::from_std(socket) {
470 Ok(v) => Ok(v),
471 Err(e) => {
472 error!("into_tokio_udp_socket failed: {}", e);
473 bail!("into_tokio_udp_socket failed: {}", e)
474 }
475 }
476 }
477
build_pipe() -> Result<(File, File)>478 fn build_pipe() -> Result<(File, File)> {
479 let mut fds = [0, 0];
480 // SAFETY: The pointer we pass to `pipe` must be valid because it comes from a reference. The
481 // file descriptors it returns must be valid and open, so they are safe to pass to
482 // `File::from_raw_fd`.
483 unsafe {
484 if libc::pipe(fds.as_mut_ptr()) == 0 {
485 return Ok((File::from_raw_fd(fds[0]), File::from_raw_fd(fds[1])));
486 }
487 }
488 Err(anyhow::Error::new(std::io::Error::last_os_error()).context("build_pipe failed"))
489 }
490
491 // Can retry to bind the socket address if it is in use.
bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket>492 fn bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket> {
493 for _ in 0..3 {
494 match std::net::UdpSocket::bind(addr) {
495 Ok(socket) => return Ok(socket),
496 Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
497 warn!("Binding socket address {} that is in use. Try again", addr);
498 std::thread::sleep(Duration::from_millis(50));
499 }
500 Err(e) => return Err(anyhow::anyhow!(e)),
501 }
502 }
503 Err(anyhow::anyhow!(std::io::Error::last_os_error()))
504 }
505