1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Implementation of JNI platform functionality.
16 use crate::jnames::{SEND_REQUEST_MNAME, SEND_REQUEST_MSIG};
17 use crate::unique_jvm;
18 use anyhow::anyhow;
19 use jni::errors::Error as JNIError;
20 use jni::objects::{GlobalRef, JMethodID, JObject, JValue};
21 use jni::signature::TypeSignature;
22 use jni::sys::{jbyteArray, jint, jlong, jvalue};
23 use jni::{JNIEnv, JavaVM};
24 use lazy_static::lazy_static;
25 use log::{debug, error, info};
26 use std::collections::HashMap;
27 use std::sync::{
28     atomic::{AtomicI64, Ordering},
29     Arc, Mutex,
30 };
31 
/// Expands to the name of the enclosing function as a `&'static str`.
///
/// Usage: `function_name!()`.
///
/// A nested item `f` is declared at the call site; its type name has the form
/// `path::to::enclosing_fn::f`. The trailing `::f` is cut off and only the text after the
/// last remaining `:` is kept.
///
/// NOTE(review): `std::any::type_name` does not guarantee its output format, and when the
/// macro is expanded inside a closure the last path segment is expected to be `{{closure}}`
/// rather than the function name.
macro_rules! function_name {
    () => {{
        fn f() {}
        fn type_name_of<T>(_: T) -> &'static str {
            std::any::type_name::<T>()
        }
        let full: &'static str = type_name_of(f);
        // Remove the three characters of the "::f" suffix contributed by `f` above.
        let path = &full[..full.len() - 3];
        // The segment after the last ':' (or the whole path when there is none).
        path.rsplit(':').next().unwrap_or(path)
    }};
}
53 
54 lazy_static! {
55     static ref HANDLE_MAPPING: Mutex<HashMap<i64, Arc<Mutex<JavaPlatform>>>> =
56         Mutex::new(HashMap::new());
57     static ref HANDLE_RN: AtomicI64 = AtomicI64::new(0);
58 }
59 
generate_platform_handle() -> i6460 fn generate_platform_handle() -> i64 {
61     HANDLE_RN.fetch_add(1, Ordering::SeqCst)
62 }
63 
insert_platform_handle(handle: i64, item: Arc<Mutex<JavaPlatform>>)64 fn insert_platform_handle(handle: i64, item: Arc<Mutex<JavaPlatform>>) {
65     if 0 == handle {
66         // Init once
67         logger::init(
68             logger::Config::default()
69                 .with_tag_on_device("remoteauth")
70                 .with_max_level(log::LevelFilter::Trace)
71                 .with_filter("trace,jni=info"),
72         );
73     }
74     HANDLE_MAPPING.lock().unwrap().insert(handle, Arc::clone(&item));
75 }
76 
/// Receives the outcome of a request issued through `Platform::send_request`.
pub trait ResponseCallback {
    /// Invoked with the raw response bytes when the request completes successfully.
    fn on_response(&mut self, response: Vec<u8>);
    /// Invoked with an error code when the request fails.
    /// NOTE(review): the meaning of `error_code` values is defined by the Java side — confirm.
    fn on_error(&mut self, error_code: i32);
}
84 
/// Platform functionality needed to exchange messages with a remote device.
pub trait Platform {
    /// Sends `request` to the remote device over the connection identified by
    /// `connection_id`.
    ///
    /// The response (or failure) is delivered later to `callback`; the returned `Result`
    /// only reports whether the request could be dispatched.
    fn send_request(
        &mut self,
        connection_id: i32,
        request: &[u8],
        callback: Box<dyn ResponseCallback + Send>,
    ) -> anyhow::Result<()>;
}
95 //////////////////////////////////
96 
/// `Platform` implementation that forwards requests to a Java object over JNI.
pub struct JavaPlatform {
    /// Key under which this instance is stored in `HANDLE_MAPPING`; it is also passed to
    /// Java with every request.
    platform_handle: i64,
    /// JVM used to attach the calling thread before making JNI calls.
    vm: &'static Arc<JavaVM>,
    /// Global reference to the Java object whose send-request method is invoked.
    platform_native_obj: GlobalRef,
    /// Method ID of `SEND_REQUEST_MNAME` / `SEND_REQUEST_MSIG`, resolved once at construction.
    send_request_method_id: JMethodID,
    /// Pending callbacks keyed by response handle; an entry is removed when a result for
    /// that handle is reported.
    map_futures: Mutex<HashMap<i64, Box<dyn ResponseCallback + Send>>>,
    /// Counter producing per-platform response handles (starts at 0).
    atomic_handle: AtomicI64,
}
106 
107 impl JavaPlatform {
108     /// Creates JavaPlatform and associates with unique handle id
create( java_platform_native: JObject<'_>, ) -> Result<Arc<Mutex<impl Platform>>, JNIError>109     pub fn create(
110         java_platform_native: JObject<'_>,
111     ) -> Result<Arc<Mutex<impl Platform>>, JNIError> {
112         let platform_handle = generate_platform_handle();
113         let platform = Arc::new(Mutex::new(JavaPlatform::new(
114             platform_handle,
115             unique_jvm::get_static_ref().ok_or(JNIError::InvalidCtorReturn)?,
116             java_platform_native,
117         )?));
118         insert_platform_handle(platform_handle, Arc::clone(&platform));
119         Ok(Arc::clone(&platform))
120     }
121 
new( platform_handle: i64, vm: &'static Arc<JavaVM>, java_platform_native: JObject, ) -> Result<JavaPlatform, JNIError>122     fn new(
123         platform_handle: i64,
124         vm: &'static Arc<JavaVM>,
125         java_platform_native: JObject,
126     ) -> Result<JavaPlatform, JNIError> {
127         vm.attach_current_thread().and_then(|env| {
128             let platform_class = env.get_object_class(java_platform_native)?;
129             let platform_native_obj = env.new_global_ref(java_platform_native)?;
130             let send_request_method: JMethodID =
131                 env.get_method_id(platform_class, SEND_REQUEST_MNAME, SEND_REQUEST_MSIG)?;
132 
133             Ok(Self {
134                 platform_handle,
135                 vm,
136                 platform_native_obj,
137                 send_request_method_id: send_request_method,
138                 map_futures: Mutex::new(HashMap::new()),
139                 atomic_handle: AtomicI64::new(0),
140             })
141         })
142     }
143 }
144 
145 impl Platform for JavaPlatform {
send_request( &mut self, connection_id: i32, request: &[u8], callback: Box<dyn ResponseCallback + Send>, ) -> anyhow::Result<()>146     fn send_request(
147         &mut self,
148         connection_id: i32,
149         request: &[u8],
150         callback: Box<dyn ResponseCallback + Send>,
151     ) -> anyhow::Result<()> {
152         let type_signature = TypeSignature::from_str(SEND_REQUEST_MSIG)
153             .map_err(|e| anyhow!("JNI: Invalid type signature: {:?}", e))?;
154 
155         let response_handle = self.atomic_handle.fetch_add(1, Ordering::SeqCst);
156         self.map_futures.lock().unwrap().insert(response_handle, callback);
157         self.vm
158             .attach_current_thread()
159             .and_then(|env| {
160                 let request_jbytearray = env.byte_array_from_slice(request)?;
161                 // Safety: request_jbytearray is safely instantiated above.
162                 let request_jobject = unsafe { JObject::from_raw(request_jbytearray) };
163 
164                 let _ = env.call_method_unchecked(
165                     self.platform_native_obj.as_obj(),
166                     self.send_request_method_id,
167                     type_signature.ret,
168                     &[
169                         jvalue::from(JValue::Int(connection_id)),
170                         jvalue::from(JValue::Object(request_jobject)),
171                         jvalue::from(JValue::Long(response_handle)),
172                         jvalue::from(JValue::Long(self.platform_handle)),
173                     ],
174                 );
175                 Ok(info!(
176                     "{} successfully sent-message, waiting for response {}:{}",
177                     function_name!(),
178                     self.platform_handle,
179                     response_handle
180                 ))
181             })
182             .map_err(|e| anyhow!("JNI: Failed to attach current thread: {:?}", e))?;
183         Ok(())
184     }
185 }
186 
187 impl JavaPlatform {
on_send_request_success(&mut self, response: &[u8], response_handle: i64)188     fn on_send_request_success(&mut self, response: &[u8], response_handle: i64) {
189         info!(
190             "{} completed successfully {}:{}",
191             function_name!(),
192             self.platform_handle,
193             response_handle
194         );
195         if let Some(mut callback) = self.map_futures.lock().unwrap().remove(&response_handle) {
196             callback.on_response(response.to_vec());
197         } else {
198             error!(
199                 "Failed to find TX for {} and {}:{}",
200                 function_name!(),
201                 self.platform_handle,
202                 response_handle
203             );
204         }
205     }
206 
on_send_request_error(&self, error_code: i32, response_handle: i64)207     fn on_send_request_error(&self, error_code: i32, response_handle: i64) {
208         error!(
209             "{} completed with error {} {}:{}",
210             function_name!(),
211             error_code,
212             self.platform_handle,
213             response_handle
214         );
215         if let Some(mut callback) = self.map_futures.lock().unwrap().remove(&response_handle) {
216             callback.on_error(error_code);
217         } else {
218             error!(
219                 "Failed to find callback for {} and {}:{}",
220                 function_name!(),
221                 self.platform_handle,
222                 response_handle
223             );
224         }
225     }
226 }
227 
228 /// Returns successful response from remote device
229 #[no_mangle]
Java_com_android_server_remoteauth_jni_NativeRemoteAuthJavaPlatform_native_on_send_request_success( env: JNIEnv, _: JObject, app_response: jbyteArray, platform_handle: jlong, response_handle: jlong, )230 pub extern "system" fn Java_com_android_server_remoteauth_jni_NativeRemoteAuthJavaPlatform_native_on_send_request_success(
231     env: JNIEnv,
232     _: JObject,
233     app_response: jbyteArray,
234     platform_handle: jlong,
235     response_handle: jlong,
236 ) {
237     debug!("{}: enter", function_name!());
238     native_on_send_request_success(env, app_response, platform_handle, response_handle);
239 }
240 
native_on_send_request_success( env: JNIEnv<'_>, app_response: jbyteArray, platform_handle: jlong, response_handle: jlong, )241 fn native_on_send_request_success(
242     env: JNIEnv<'_>,
243     app_response: jbyteArray,
244     platform_handle: jlong,
245     response_handle: jlong,
246 ) {
247     if let Some(platform) = HANDLE_MAPPING.lock().unwrap().get(&platform_handle) {
248         let response =
249             env.convert_byte_array(app_response).map_err(|_| JNIError::InvalidCtorReturn).unwrap();
250         let mut platform = (*platform).lock().unwrap();
251         platform.on_send_request_success(&response, response_handle);
252     } else {
253         let _ = env.throw_new(
254             "com/android/server/remoteauth/jni/BadHandleException",
255             format!("Failed to find Platform with ID {} in {}", platform_handle, function_name!()),
256         );
257     }
258 }
259 
260 /// Notifies about failure to receive a response from remote device
261 #[no_mangle]
Java_com_android_server_remoteauth_jni_NativeRemoteAuthJavaPlatform_native_on_send_request_error( env: JNIEnv, _: JObject, error_code: jint, platform_handle: jlong, response_handle: jlong, )262 pub extern "system" fn Java_com_android_server_remoteauth_jni_NativeRemoteAuthJavaPlatform_native_on_send_request_error(
263     env: JNIEnv,
264     _: JObject,
265     error_code: jint,
266     platform_handle: jlong,
267     response_handle: jlong,
268 ) {
269     debug!("{}: enter", function_name!());
270     native_on_send_request_error(env, error_code, platform_handle, response_handle);
271 }
272 
native_on_send_request_error( env: JNIEnv<'_>, error_code: jint, platform_handle: jlong, response_handle: jlong, )273 fn native_on_send_request_error(
274     env: JNIEnv<'_>,
275     error_code: jint,
276     platform_handle: jlong,
277     response_handle: jlong,
278 ) {
279     if let Some(platform) = HANDLE_MAPPING.lock().unwrap().get(&platform_handle) {
280         let platform = (*platform).lock().unwrap();
281         platform.on_send_request_error(error_code, response_handle);
282     } else {
283         let _ = env.throw_new(
284             "com/android/server/remoteauth/jni/BadHandleException",
285             format!("Failed to find Platform with ID {} in {}", platform_handle, function_name!()),
286         );
287     }
288 }
289 
#[cfg(test)]
mod tests {
    /// Checks validity of the function_name! macro.
    #[test]
    fn test_function_name() {
        assert_eq!(function_name!(), "test_function_name");
    }

    /// Checks that function_name! reports the innermost named function when used in a
    /// function nested inside another one.
    #[test]
    fn test_function_name_nested_fn() {
        fn nested_helper() -> &'static str {
            function_name!()
        }
        assert_eq!(nested_helper(), "nested_helper");
    }
}
302