1 // Copyright 2021, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use keystore2_selinux::{check_access, Context};
16 use nix::sched::sched_setaffinity;
17 use nix::sched::CpuSet;
18 use nix::unistd::getpid;
19 use std::thread;
20 use std::{
21     sync::{atomic::AtomicU8, atomic::Ordering, Arc},
22     time::{Duration, Instant},
23 };
24 
/// Four distinct SELinux MLS category numbers, kept in strictly increasing order
/// (`.0 < .1 < .2 < .3`).
///
/// Repeated calls to [`CatCount::next`] walk through every such combination in
/// lexicographic order, giving each access check in the test a different security
/// context.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CatCount(u8, u8, u8, u8);

impl CatCount {
    /// Advances to the next combination and returns the value held *before* the call
    /// (post-increment semantics).
    ///
    /// After the last combination, `(252, 253, 254, 255)`, the sequence wraps around to
    /// `(0, 1, 2, 3)`.
    fn next(&mut self) -> CatCount {
        let result = *self;
        if self.3 == 255 {
            if self.2 == 254 {
                if self.1 == 253 {
                    // The three trailing positions are exhausted, so advance the first one.
                    // Wrap explicitly instead of relying on `u8` overflow, which panics in
                    // debug/test builds.
                    self.0 = if self.0 == 252 { 0 } else { self.0 + 1 };
                    self.1 = self.0;
                }
                self.1 += 1;
                self.2 = self.1;
            }
            self.2 += 1;
            self.3 = self.2;
        }
        self.3 += 1;
        result
    }

    /// Formats the categories as a comma-separated list, e.g. `c0,c1,c2,c3`.
    fn make_string(&self) -> String {
        format!("c{},c{},c{},c{}", self.0, self.1, self.2, self.3)
    }
}

impl Default for CatCount {
    /// The first combination in the sequence: `c0,c1,c2,c3`.
    fn default() -> Self {
        Self(0, 1, 2, 3)
    }
}
60 
/// This test calls selinux_check_access concurrently causing access vector cache misses
/// in libselinux avc. The test then checks if any of the threads fails to report back
/// after a burst of access checks. The purpose of the test is to draw out a specific
/// access vector cache corruption that sends a calling thread into an infinite loop.
/// This was observed when keystore2 used libselinux concurrently in a non thread safe
/// way. See b/184006658.
///
/// One worker thread is spawned per CPU (capped at 254) and pinned to that CPU. In each
/// round the main thread releases all workers at once via the `turnpike` counter. Every
/// worker then performs 250 access checks, each with a different category set, and
/// reports on `complete_count`. A round that takes too long to complete fails the test.
#[test]
fn test_concurrent_check_access() {
    // Number of synchronized rounds of access checks.
    const TEST_ITERATIONS: u32 = 500;
    // Access checks each worker performs per round.
    const CHECKS_PER_ROUND: usize = 250;
    // Polls of a counter before the main thread gives up on a round.
    const MAX_SLEEPS: u64 = 500;
    // Delay between polls, in milliseconds.
    const SLEEP_MILLISECONDS: u64 = 5;
    // `turnpike` value that tells the workers to exit.
    const TURNPIKE_STOP: u8 = 255;
    // The counters are `AtomicU8`s and `TURNPIKE_STOP` must remain distinguishable from
    // a worker count, so at most 254 workers can be coordinated.
    const MAX_THREADS: usize = TURNPIKE_STOP as usize - 1;
    // Set to `true` to park the main thread instead of failing when the workers appear to
    // be stuck, which leaves time to attach a debugger.
    const PARK_ON_HANG: bool = false;
    // SeqCst so that resetting `complete_count` before opening the turnpike is guaranteed
    // to be visible to workers before they report completion (Relaxed gives no such
    // ordering between the two atomics on weakly ordered CPUs).
    const ORDER: Ordering = Ordering::SeqCst;

    android_logger::init_once(
        android_logger::Config::default()
            .with_tag("keystore2_selinux_concurrency_test")
            .with_max_level(log::LevelFilter::Debug),
    );

    let thread_count = num_cpus::get().min(MAX_THREADS);
    let turnpike = Arc::new(AtomicU8::new(0));
    let complete_count = Arc::new(AtomicU8::new(0));
    let mut threads: Vec<thread::JoinHandle<()>> = Vec::with_capacity(thread_count);

    for i in 0..thread_count {
        log::info!("Spawning thread {}", i);
        let turnpike_clone = turnpike.clone();
        let complete_count_clone = complete_count.clone();
        threads.push(thread::spawn(move || {
            // Pin this worker to CPU `i`. A pid of 0 designates the calling thread; passing
            // getpid() would instead re-pin the process's main thread.
            let mut cpu_set = CpuSet::new();
            cpu_set.set(i).unwrap();
            sched_setaffinity(nix::unistd::Pid::from_raw(0), &cpu_set).unwrap();
            let mut cat_count: CatCount = Default::default();

            log::info!("Thread {} reached turnpike", i);
            loop {
                // Check in, then busy-wait until the main thread opens the turnpike (0) or
                // tells us to stop, so all workers start their checks together.
                turnpike_clone.fetch_add(1, ORDER);
                loop {
                    match turnpike_clone.load(ORDER) {
                        0 => break,
                        TURNPIKE_STOP => return,
                        _ => std::hint::spin_loop(),
                    }
                }

                for _ in 0..CHECKS_PER_ROUND {
                    let (tctx, sctx, perm, class) = (
                        Context::new("u:object_r:keystore:s0").unwrap(),
                        Context::new(&format!(
                            "u:r:untrusted_app:s0:{}",
                            cat_count.next().make_string()
                        ))
                        .unwrap(),
                        "use",
                        "keystore2_key",
                    );

                    check_access(&sctx, &tctx, class, perm).unwrap();
                }

                // Report completion, then wait for the other workers so that nobody checks
                // in at the turnpike for the next round early.
                complete_count_clone.fetch_add(1, ORDER);
                while complete_count_clone.load(ORDER) as usize != thread_count {
                    thread::sleep(Duration::from_millis(SLEEP_MILLISECONDS));
                }
            }
        }));
    }

    let mut i = 0;
    let run_time = Instant::now();

    loop {
        // Wait for every worker to check in at the turnpike.
        let mut sleep_count: u64 = 0;
        while turnpike.load(ORDER) as usize != thread_count {
            thread::sleep(Duration::from_millis(SLEEP_MILLISECONDS));
            sleep_count += 1;
            assert!(
                sleep_count < MAX_SLEEPS,
                "Waited too long to go ready on iteration {}, only {} are ready",
                i,
                turnpike.load(ORDER)
            );
        }

        if i % 100 == 0 {
            let elapsed = run_time.elapsed().as_secs();
            println!("{:02}:{:02}: Iteration {}", elapsed / 60, elapsed % 60, i);
        }

        // All workers are spinning on the turnpike; only this thread changes it now.
        assert_eq!(turnpike.load(ORDER) as usize, thread_count, "i = {}", i);
        if i >= TEST_ITERATIONS {
            turnpike.store(TURNPIKE_STOP, ORDER);
            break;
        }

        // Now go. Reset the completion counter before opening the turnpike.
        complete_count.store(0, ORDER);
        turnpike.store(0, ORDER);
        i += 1;

        // Wait for them to all complete.
        sleep_count = 0;
        while complete_count.load(ORDER) as usize != thread_count {
            thread::sleep(Duration::from_millis(SLEEP_MILLISECONDS));
            sleep_count += 1;
            if sleep_count >= MAX_SLEEPS {
                if PARK_ON_HANG {
                    println!(
                        "Waited {} seconds and we seem stuck. Going to sleep forever.",
                        (MAX_SLEEPS * SLEEP_MILLISECONDS) as f32 / 1000.0
                    );
                    println!("Attach a debugger to pid {}.", getpid());
                    loop {
                        thread::park();
                    }
                }
                panic!(
                    "Waited too long to complete on iteration {}, only {} are complete",
                    i,
                    complete_count.load(ORDER)
                );
            }
        }
    }

    for t in threads {
        t.join().unwrap();
    }
}
191