1 /*
2  * Copyright (C) 2021, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! Client management, including the communication with quiche I/O.
18 
19 use anyhow::{anyhow, bail, ensure, Result};
20 use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
21 use log::{debug, error, info, warn};
22 use quiche::h3::NameValue;
23 use std::collections::{hash_map, HashMap};
24 use std::net::SocketAddr;
25 use std::time::Duration;
26 
/// Size in bytes of the fixed DNS message header.
pub const DNS_HEADER_SIZE: usize = 12;

/// Size of the buffer handed to quiche for each outgoing datagram, i.e. the most bytes a
/// single `Client::flush_egress` call returns.
pub const MAX_UDP_PAYLOAD_SIZE: usize = 1350;

/// Number of leading Destination Connection ID bytes used to identify a client.
pub const CONN_ID_LEN: usize = 8;

/// Identifier under which a client is tracked: a prefix of its initial DCID.
pub type ConnectionID = Vec<u8>;

/// Path prefix of a DoH GET request; the base64url-encoded DNS query follows it.
const URL_PATH_PREFIX: &str = "/dns-query?dns=";
34 
/// Manages a QUIC and HTTP/3 connection. No socket I/O operations.
///
/// Incoming datagrams are fed in through `handle_frontend_message`, outgoing datagrams are
/// pulled out with `flush_egress`; the caller owns the socket.
pub struct Client {
    /// QUIC connection.
    conn: quiche::Connection,

    /// HTTP/3 connection. `None` until the QUIC connection is in early data or established.
    h3_conn: Option<quiche::h3::Connection>,

    /// Socket address the client sends from.
    addr: SocketAddr,

    /// The unique ID for the client.
    id: ConnectionID,

    /// DNS queries forwarded to the backend and still awaiting an answer.
    /// Maps the 2-byte DNS query ID to the HTTP/3 stream ID the answer must be sent on.
    in_flight_queries: HashMap<[u8; 2], u64>,

    /// Second halves of DNS answers that still need to be sent after the first half,
    /// stored as `(stream ID, remaining answer bytes)`.
    pending_answers: Vec<(u64, Vec<u8>)>,

    /// Set to true once a packet has been handled while the QUIC connection was in early
    /// data (0-RTT) and the HTTP/3 connection already existed.
    handled_early_data: bool,
}
60 
61 impl Client {
new(conn: quiche::Connection, addr: &SocketAddr, id: ConnectionID) -> Client62     fn new(conn: quiche::Connection, addr: &SocketAddr, id: ConnectionID) -> Client {
63         Client {
64             conn,
65             h3_conn: None,
66             addr: *addr,
67             id,
68             in_flight_queries: HashMap::new(),
69             pending_answers: Vec::new(),
70             handled_early_data: false,
71         }
72     }
73 
create_http3_connection(&mut self) -> Result<()>74     fn create_http3_connection(&mut self) -> Result<()> {
75         ensure!(self.h3_conn.is_none(), "HTTP/3 connection is already created");
76 
77         let config = quiche::h3::Config::new()?;
78         let conn = quiche::h3::Connection::with_transport(&mut self.conn, &config)?;
79         self.h3_conn = Some(conn);
80         Ok(())
81     }
82 
83     // Processes HTTP/3 request and returns the wire format DNS query or an empty vector.
handle_http3_request(&mut self) -> Result<Vec<u8>>84     fn handle_http3_request(&mut self) -> Result<Vec<u8>> {
85         ensure!(self.h3_conn.is_some(), "HTTP/3 connection not created");
86 
87         let h3_conn = self.h3_conn.as_mut().unwrap();
88         let mut ret = vec![];
89 
90         loop {
91             match h3_conn.poll(&mut self.conn) {
92                 Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
93                     info!(
94                         "Processing HTTP/3 Headers {:?} on stream id {} has_body {}",
95                         list, stream_id, has_body
96                     );
97 
98                     // Find ":path" field to get the query.
99                     if let Some(target) = list.iter().find(|e| {
100                         e.name() == b":path" && e.value().starts_with(URL_PATH_PREFIX.as_bytes())
101                     }) {
102                         let b64url_query = &target.value()[URL_PATH_PREFIX.len()..];
103                         let decoded = BASE64_URL_SAFE_NO_PAD.decode(b64url_query)?;
104                         self.in_flight_queries.insert([decoded[0], decoded[1]], stream_id);
105                         ret = decoded;
106                     }
107                 }
108                 Ok((stream_id, quiche::h3::Event::Data)) => {
109                     warn!("Received unexpected HTTP/3 data");
110                     let mut buf = [0; 65535];
111                     if let Ok(read) = h3_conn.recv_body(&mut self.conn, stream_id, &mut buf) {
112                         warn!("Got {} bytes of response data on stream {}", read, stream_id);
113                     }
114                 }
115                 Ok(n) => {
116                     debug!("Got event {:?}", n);
117                 }
118                 Err(quiche::h3::Error::Done) => {
119                     debug!("quiche::h3::Error::Done");
120                     break;
121                 }
122                 Err(e) => bail!("HTTP/3 processing failed: {:?}", e),
123             }
124         }
125 
126         Ok(ret)
127     }
128 
129     // Converts the clear-text DNS response to a DoH response, and sends it to the quiche.
handle_backend_message( &mut self, response: &[u8], send_reset_stream: Option<u64>, ) -> Result<()>130     pub fn handle_backend_message(
131         &mut self,
132         response: &[u8],
133         send_reset_stream: Option<u64>,
134     ) -> Result<()> {
135         ensure!(self.h3_conn.is_some(), "HTTP/3 connection not created");
136         ensure!(response.len() >= DNS_HEADER_SIZE, "Insufficient bytes of DNS response");
137 
138         let len = response.len();
139         let headers = vec![
140             quiche::h3::Header::new(b":status", b"200"),
141             quiche::h3::Header::new(b"content-type", b"application/dns-message"),
142             quiche::h3::Header::new(b"content-length", len.to_string().as_bytes()),
143             // TODO: need to add cache-control?
144         ];
145 
146         let h3_conn = self.h3_conn.as_mut().unwrap();
147         let query_id = u16::from_be_bytes([response[0], response[1]]);
148         let stream_id = self
149             .in_flight_queries
150             .remove(&[response[0], response[1]])
151             .ok_or_else(|| anyhow!("query_id {:x} not found", query_id))?;
152 
153         if let Some(send_reset_stream) = send_reset_stream {
154             if send_reset_stream == stream_id {
155                 // Terminate the stream with an error code 99.
156                 self.conn.stream_shutdown(stream_id, quiche::Shutdown::Write, 99)?;
157                 info!("Preparing RESET_STREAM on stream {}", stream_id);
158                 return Ok(());
159             }
160         }
161 
162         info!("Preparing HTTP/3 response {:?} on stream {}", headers, stream_id);
163 
164         h3_conn.send_response(&mut self.conn, stream_id, &headers, false)?;
165 
166         // In order to simulate the case that server send multiple packets for a DNS answer,
167         // only send half of the answer here. The remaining one will be cached here and then
168         // processed later in process_pending_answers().
169         let (first, second) = response.split_at(len / 2);
170         h3_conn.send_body(&mut self.conn, stream_id, first, false)?;
171         self.pending_answers.push((stream_id, second.to_vec()));
172 
173         Ok(())
174     }
175 
process_pending_answers(&mut self) -> Result<()>176     pub fn process_pending_answers(&mut self) -> Result<()> {
177         if let Some((stream_id, ans)) = self.pending_answers.pop() {
178             let h3_conn = self.h3_conn.as_mut().unwrap();
179             info!("process the remaining response for stream {}", stream_id);
180             h3_conn.send_body(&mut self.conn, stream_id, &ans, true)?;
181         }
182         Ok(())
183     }
184 
185     // Returns the data the client wants to send.
flush_egress(&mut self) -> Result<Vec<u8>>186     pub fn flush_egress(&mut self) -> Result<Vec<u8>> {
187         let mut ret = vec![];
188         let mut buf = [0; MAX_UDP_PAYLOAD_SIZE];
189 
190         let (write, _) = match self.conn.send(&mut buf) {
191             Ok(v) => v,
192             Err(quiche::Error::Done) => bail!(quiche::Error::Done),
193             Err(e) => {
194                 error!("flush_egress failed: {}", e);
195                 bail!(e)
196             }
197         };
198         ret.append(&mut buf[..write].to_vec());
199 
200         Ok(ret)
201     }
202 
203     // Processes the packet received from the frontend socket. If |data| is a DoH query,
204     // the function returns the wire format DNS query; otherwise, it returns empty vector.
handle_frontend_message( &mut self, data: &mut [u8], local: &SocketAddr, ) -> Result<Vec<u8>>205     pub fn handle_frontend_message(
206         &mut self,
207         data: &mut [u8],
208         local: &SocketAddr,
209     ) -> Result<Vec<u8>> {
210         let recv_info = quiche::RecvInfo { from: self.addr, to: *local };
211         self.conn.recv(data, recv_info)?;
212 
213         if (self.conn.is_in_early_data() || self.conn.is_established()) && self.h3_conn.is_none() {
214             // Create a HTTP3 connection as soon as either the QUIC connection is established or
215             // the handshake has progressed enough to receive early data.
216             self.create_http3_connection()?;
217             info!("HTTP/3 connection created");
218         }
219 
220         if self.h3_conn.is_some() {
221             if self.conn.is_in_early_data() {
222                 self.handled_early_data = true;
223             }
224             return self.handle_http3_request();
225         }
226 
227         Ok(vec![])
228     }
229 
is_waiting_for_query(&self, query_id: &[u8; 2]) -> bool230     pub fn is_waiting_for_query(&self, query_id: &[u8; 2]) -> bool {
231         self.in_flight_queries.contains_key(query_id)
232     }
233 
addr(&self) -> SocketAddr234     pub fn addr(&self) -> SocketAddr {
235         self.addr
236     }
237 
connection_id(&self) -> &ConnectionID238     pub fn connection_id(&self) -> &ConnectionID {
239         self.id.as_ref()
240     }
241 
timeout(&self) -> Option<Duration>242     pub fn timeout(&self) -> Option<Duration> {
243         self.conn.timeout()
244     }
245 
on_timeout(&mut self)246     pub fn on_timeout(&mut self) {
247         self.conn.on_timeout();
248     }
249 
is_alive(&self) -> bool250     pub fn is_alive(&self) -> bool {
251         self.conn.is_established() && !self.conn.is_closed()
252     }
253 
is_resumed(&self) -> bool254     pub fn is_resumed(&self) -> bool {
255         self.conn.is_resumed()
256     }
257 
close(&mut self)258     pub fn close(&mut self) {
259         let _ = self.conn.close(false, 0, b"Graceful shutdown");
260     }
261 
handled_early_data(&self) -> bool262     pub fn handled_early_data(&self) -> bool {
263         self.handled_early_data
264     }
265 }
266 
267 impl std::fmt::Debug for Client {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result268     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
269         f.debug_struct("Client")
270             .field("addr", &self.addr())
271             .field("conn_id", &self.conn.trace_id())
272             .finish()
273     }
274 }
275 
/// Owns every known [`Client`] together with the QUIC configuration used to accept new ones.
pub struct ClientMap {
    /// Clients keyed by the connection ID returned by `get_conn_id`.
    clients: HashMap<ConnectionID, Client>,
    /// QUIC configuration passed to `quiche::accept` for each new client.
    config: quiche::Config,
}
280 
281 impl ClientMap {
new(config: quiche::Config) -> Result<ClientMap>282     pub fn new(config: quiche::Config) -> Result<ClientMap> {
283         Ok(ClientMap { clients: HashMap::new(), config })
284     }
285 
get_or_create( &mut self, hdr: &quiche::Header, peer: &SocketAddr, local: &SocketAddr, ) -> Result<&mut Client>286     pub fn get_or_create(
287         &mut self,
288         hdr: &quiche::Header,
289         peer: &SocketAddr,
290         local: &SocketAddr,
291     ) -> Result<&mut Client> {
292         let conn_id = get_conn_id(hdr)?;
293         let client = match self.clients.entry(conn_id.clone()) {
294             hash_map::Entry::Occupied(client) => client.into_mut(),
295             hash_map::Entry::Vacant(vacant) => {
296                 ensure!(hdr.ty == quiche::Type::Initial, "Packet is not Initial");
297                 ensure!(
298                     quiche::version_is_supported(hdr.version),
299                     "Protocol version not supported"
300                 );
301                 let conn = quiche::accept(
302                     &quiche::ConnectionId::from_ref(&conn_id),
303                     None, /* odcid */
304                     *local,
305                     *peer,
306                     &mut self.config,
307                 )?;
308                 let client = Client::new(conn, peer, conn_id.clone());
309                 info!("New client: {:?}", client);
310                 vacant.insert(client)
311             }
312         };
313         Ok(client)
314     }
315 
get_mut(&mut self, id: &[u8]) -> Option<&mut Client>316     pub fn get_mut(&mut self, id: &[u8]) -> Option<&mut Client> {
317         self.clients.get_mut(id)
318     }
319 
iter_mut(&mut self) -> hash_map::IterMut<ConnectionID, Client>320     pub fn iter_mut(&mut self) -> hash_map::IterMut<ConnectionID, Client> {
321         self.clients.iter_mut()
322     }
323 
iter(&mut self) -> hash_map::Iter<ConnectionID, Client>324     pub fn iter(&mut self) -> hash_map::Iter<ConnectionID, Client> {
325         self.clients.iter()
326     }
327 
len(&mut self) -> usize328     pub fn len(&mut self) -> usize {
329         self.clients.len()
330     }
331 }
332 
333 // Per RFC 9000 section 7.2, an Initial packet's dcid from a new client must be
334 // at least 8 bytes in length. We use the first 8 bytes of dcid as new connection
335 // ID to identify the client.
336 // This is helpful to identify 0-RTT packets. In 0-RTT handshake, 0-RTT packets
337 // are followed after the Initial packet with the same dcid. With this function, we
338 // know which 0-RTT packets belong to which client.
get_conn_id(hdr: &quiche::Header) -> Result<ConnectionID>339 fn get_conn_id(hdr: &quiche::Header) -> Result<ConnectionID> {
340     if let Some(v) = hdr.dcid.as_ref().get(0..CONN_ID_LEN) {
341         return Ok(v.to_vec());
342     }
343     bail!("QUIC packet {:?} dcid too small", hdr.ty)
344 }
345