1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.snippet.wifi.aware;
18 
import static java.util.concurrent.TimeUnit.SECONDS;

import android.net.ConnectivityManager;
import android.net.Network;
import android.net.NetworkCapabilities;
import android.net.wifi.aware.AttachCallback;
import android.net.wifi.aware.AwarePairingConfig;
import android.net.wifi.aware.DiscoverySessionCallback;
import android.net.wifi.aware.IdentityChangedListener;
import android.net.wifi.aware.PeerHandle;
import android.net.wifi.aware.PublishDiscoverySession;
import android.net.wifi.aware.ServiceDiscoveryInfo;
import android.net.wifi.aware.SubscribeDiscoverySession;
import android.net.wifi.aware.WifiAwareSession;
import android.net.wifi.rtt.RangingResult;
import android.net.wifi.rtt.RangingResultCallback;
import android.os.Build;
import android.util.Log;
import android.util.Pair;

import com.android.compatibility.common.util.ApiLevelUtil;

import com.google.common.collect.ImmutableSet;

import java.util.ArrayDeque;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
47 
48 /** Blocking callbacks for Wi-Fi Aware and Connectivity Manager. */
49 public final class CallbackUtils {
50     private static final String TAG = "CallbackUtils";
51 
52     public static final int CALLBACK_TIMEOUT_SEC = 15;
53 
54     /**
55      * Utility AttachCallback - provides mechanism to block execution with the waitForAttach method.
56      */
57     public static class AttachCb extends AttachCallback {
58 
59         /** Callback codes. */
60         public enum CallbackCode {
61             TIMEOUT,
62             ON_ATTACHED,
63             ON_ATTACH_FAILED
64         }
65 
66         private final CountDownLatch mBlocker = new CountDownLatch(1);
67         private CallbackCode mCallbackCode = CallbackCode.TIMEOUT;
68         private WifiAwareSession mWifiAwareSession = null;
69 
70         @Override
onAttached(WifiAwareSession session)71         public void onAttached(WifiAwareSession session) {
72             mCallbackCode = CallbackCode.ON_ATTACHED;
73             mWifiAwareSession = session;
74             mBlocker.countDown();
75         }
76 
77         @Override
onAttachFailed()78         public void onAttachFailed() {
79             mCallbackCode = CallbackCode.ON_ATTACH_FAILED;
80             mBlocker.countDown();
81         }
82 
83         /**
84          * Wait (blocks) for any AttachCallback callback or timeout.
85          *
86          * @return A pair of values: the callback constant (or TIMEOUT) and the WifiAwareSession
87          * created
88          * when attach successful - null otherwise (attach failure or timeout).
89          */
waitForAttach()90         public Pair<CallbackCode, WifiAwareSession> waitForAttach() throws InterruptedException {
91             if (mBlocker.await(CALLBACK_TIMEOUT_SEC, SECONDS)) {
92                 return new Pair<>(mCallbackCode, mWifiAwareSession);
93             }
94 
95             return new Pair<>(CallbackCode.TIMEOUT, null);
96         }
97     }
98 
99     /**
100      * Utility IdentityChangedListener - provides mechanism to block execution with the
101      * waitForIdentity method. Single shot listener - only listens for the first triggered callback.
102      */
103     public static class IdentityListenerSingleShot extends IdentityChangedListener {
104         private final CountDownLatch mBlocker = new CountDownLatch(1);
105         private byte[] mMac = null;
106 
107         @Override
onIdentityChanged(byte[] mac)108         public void onIdentityChanged(byte[] mac) {
109             if (this.mMac != null) {
110                 return;
111             }
112 
113             this.mMac = mac;
114             mBlocker.countDown();
115         }
116 
117         /**
118          * Wait (blocks) for the onIdentityChanged callback or a timeout.
119          *
120          * @return The MAC address returned by the onIdentityChanged() callback, or null on timeout.
121          */
waitForMac()122         public byte[] waitForMac() throws InterruptedException {
123             if (mBlocker.await(CALLBACK_TIMEOUT_SEC, SECONDS)) {
124                 return mMac;
125             }
126 
127             return null;
128         }
129     }
130 
131     /**
132      * Utility NetworkCallback - provides mechanism for blocking/serializing access with the
133      * waitForNetwork method.
134      */
135     public static class NetworkCb extends ConnectivityManager.NetworkCallback {
136         private final CountDownLatch mBlocker = new CountDownLatch(1);
137         private Network mNetwork = null;
138         private NetworkCapabilities mNetworkCapabilities = null;
139 
140         @Override
onUnavailable()141         public void onUnavailable() {
142             mNetworkCapabilities = null;
143             mBlocker.countDown();
144         }
145 
146         @Override
onCapabilitiesChanged(Network network, NetworkCapabilities networkCapabilities)147         public void onCapabilitiesChanged(Network network,
148                 NetworkCapabilities networkCapabilities) {
149             this.mNetwork = network;
150             this.mNetworkCapabilities = networkCapabilities;
151             mBlocker.countDown();
152         }
153 
154         /**
155          * Wait (blocks) for Capabilities Changed callback - or timesout.
156          *
157          * @return Network + NetworkCapabilities (pair) if occurred, null otherwise.
158          */
waitForNetworkCapabilities()159         public Pair<Network, NetworkCapabilities> waitForNetworkCapabilities()
160                 throws InterruptedException {
161             if (mBlocker.await(CALLBACK_TIMEOUT_SEC, SECONDS)) {
162                 return Pair.create(mNetwork, mNetworkCapabilities);
163             }
164             return null;
165         }
166     }
167 
168     /**
169      * Utility DiscoverySessionCallback - provides mechanism to block/serialize Aware discovery
170      * operations using the waitForCallbacks() method.
171      */
172     public static class DiscoveryCb extends DiscoverySessionCallback {
173         /** Callback codes. */
174         public enum CallbackCode {
175             TIMEOUT,
176             ON_PUBLISH_STARTED,
177             ON_SUBSCRIBE_STARTED,
178             ON_SESSION_CONFIG_UPDATED,
179             ON_SESSION_CONFIG_FAILED,
180             ON_SESSION_TERMINATED,
181             ON_SERVICE_DISCOVERED,
182             ON_MESSAGE_SEND_SUCCEEDED,
183             ON_MESSAGE_SEND_FAILED,
184             ON_MESSAGE_RECEIVED,
185             ON_SERVICE_DISCOVERED_WITH_RANGE, ON_PAIRING_REQUEST_RECEIVED,
186             ON_PAIRING_SETUP_CONFIRMED, ON_BOOTSTRAPPING_CONFIRMED,
187             ON_PAIRING_VERIFICATION_CONFIRMED,
188         }
189 
190 
191         ;
192 
193         /**
194          * Data container for all parameters which can be returned by any DiscoverySessionCallback
195          * callback.
196          */
197         public static class CallbackData {
CallbackData(CallbackCode callbackCode)198             public CallbackData(CallbackCode callbackCode) {
199                 this.callbackCode = callbackCode;
200             }
201 
202             public CallbackCode callbackCode;
203 
204             public PublishDiscoverySession publishDiscoverySession;
205             public SubscribeDiscoverySession subscribeDiscoverySession;
206             public PeerHandle peerHandle;
207             public byte[] serviceSpecificInfo;
208             public List<byte[]> matchFilter;
209             public int messageId;
210             public int distanceMm;
211             public int pairingRequestId;
212             public boolean pairingAccept;
213             public boolean bootstrappingAccept;
214             public String pairingAlias;
215             public int bootstrappingMethod;
216             public String pairedAlias;
217             public AwarePairingConfig pairingConfig;
218         }
219 
220         private CountDownLatch mBlocker = null;
221         private Set<CallbackCode> mWaitForCallbackCodes = ImmutableSet.of();
222 
223         private final Object mLock = new Object();
224         private final ArrayDeque<CallbackData> mCallbackQueue = new ArrayDeque<>();
225 
processCallback(CallbackData callbackData)226         private void processCallback(CallbackData callbackData) {
227             synchronized (mLock) {
228                 mCallbackQueue.addLast(callbackData);
229                 if (mBlocker != null && mWaitForCallbackCodes.contains(callbackData.callbackCode)) {
230                     mBlocker.countDown();
231                 }
232             }
233         }
234 
getAndRemoveFirst(Set<CallbackCode> callbackCodes)235         private CallbackData getAndRemoveFirst(Set<CallbackCode> callbackCodes) {
236             synchronized (mLock) {
237                 for (CallbackData cbd : mCallbackQueue) {
238                     if (callbackCodes.contains(cbd.callbackCode)) {
239                         mCallbackQueue.remove(cbd);
240                         return cbd;
241                     }
242                 }
243             }
244 
245             return null;
246         }
247 
waitForCallbacks(Set<CallbackCode> callbackCodes, boolean timeout)248         private CallbackData waitForCallbacks(Set<CallbackCode> callbackCodes, boolean timeout)
249                 throws InterruptedException {
250             synchronized (mLock) {
251                 CallbackData cbd = getAndRemoveFirst(callbackCodes);
252                 if (cbd != null) {
253                     return cbd;
254                 }
255 
256                 mWaitForCallbackCodes = callbackCodes;
257                 mBlocker = new CountDownLatch(1);
258             }
259 
260             boolean finishedNormally = true;
261             if (timeout) {
262                 finishedNormally = mBlocker.await(CALLBACK_TIMEOUT_SEC, SECONDS);
263             } else {
264                 mBlocker.await();
265             }
266             if (finishedNormally) {
267                 CallbackData cbd = getAndRemoveFirst(callbackCodes);
268                 if (cbd != null) {
269                     return cbd;
270                 }
271 
272                 Log.wtf(
273                         TAG,
274                         "DiscoveryCb.waitForCallback: callbackCodes="
275                                 + callbackCodes
276                                 + ": did not time-out but doesn't have any of the requested "
277                                 + "callbacks in "
278                                 + "the stack!?");
279                 // falling-through to TIMEOUT
280             }
281 
282             return new CallbackData(CallbackCode.TIMEOUT);
283         }
284 
285         /**
286          * Wait for the specified callbacks - a bitmask of any of the ON_* constants. Returns the
287          * CallbackData structure whose CallbackData.callback specifies the callback which was
288          * triggered. The callback may be TIMEOUT.
289          *
290          * <p>Note: other callbacks happening while while waiting for the specified callback(s) will
291          * be
292          * queued.
293          */
waitForCallbacks(Set<CallbackCode> callbackCodes)294         public CallbackData waitForCallbacks(Set<CallbackCode> callbackCodes)
295                 throws InterruptedException {
296             return waitForCallbacks(callbackCodes, true);
297         }
298 
299         /**
300          * Wait for the specified callbacks - a bitmask of any of the ON_* constants. Returns the
301          * CallbackData structure whose CallbackData.callback specifies the callback which was
302          * triggered.
303          *
304          * <p>This call will not timeout - it can be interrupted though (which results in a thrown
305          * exception).
306          *
307          * <p>Note: other callbacks happening while while waiting for the specified callback(s) will
308          * be
309          * queued.
310          */
waitForCallbacksNoTimeout(Set<CallbackCode> callbackCodes)311         public CallbackData waitForCallbacksNoTimeout(Set<CallbackCode> callbackCodes)
312                 throws InterruptedException {
313             return waitForCallbacks(callbackCodes, false);
314         }
315 
316         @Override
onPublishStarted(PublishDiscoverySession session)317         public void onPublishStarted(PublishDiscoverySession session) {
318             CallbackData callbackData = new CallbackData(CallbackCode.ON_PUBLISH_STARTED);
319             callbackData.publishDiscoverySession = session;
320             processCallback(callbackData);
321         }
322 
323         @Override
onSubscribeStarted(SubscribeDiscoverySession session)324         public void onSubscribeStarted(SubscribeDiscoverySession session) {
325             CallbackData callbackData = new CallbackData(CallbackCode.ON_SUBSCRIBE_STARTED);
326             callbackData.subscribeDiscoverySession = session;
327             processCallback(callbackData);
328         }
329 
330         @Override
onSessionConfigUpdated()331         public void onSessionConfigUpdated() {
332             CallbackData callbackData = new CallbackData(CallbackCode.ON_SESSION_CONFIG_UPDATED);
333             processCallback(callbackData);
334         }
335 
336         @Override
onSessionConfigFailed()337         public void onSessionConfigFailed() {
338             CallbackData callbackData = new CallbackData(CallbackCode.ON_SESSION_CONFIG_FAILED);
339             processCallback(callbackData);
340         }
341 
342         @Override
onSessionTerminated()343         public void onSessionTerminated() {
344             CallbackData callbackData = new CallbackData(CallbackCode.ON_SESSION_TERMINATED);
345             processCallback(callbackData);
346         }
347 
348         @Override
onServiceDiscovered( PeerHandle peerHandle, byte[] serviceSpecificInfo, List<byte[]> matchFilter)349         public void onServiceDiscovered(
350                 PeerHandle peerHandle, byte[] serviceSpecificInfo, List<byte[]> matchFilter) {
351             CallbackData callbackData = new CallbackData(CallbackCode.ON_SERVICE_DISCOVERED);
352             callbackData.peerHandle = peerHandle;
353             callbackData.serviceSpecificInfo = serviceSpecificInfo;
354             callbackData.matchFilter = matchFilter;
355             processCallback(callbackData);
356         }
357 
358         @Override
onServiceDiscovered(ServiceDiscoveryInfo info)359         public void onServiceDiscovered(ServiceDiscoveryInfo info) {
360             CallbackData callbackData = new CallbackData(CallbackCode.ON_SERVICE_DISCOVERED);
361             callbackData.peerHandle = info.getPeerHandle();
362             callbackData.serviceSpecificInfo = info.getServiceSpecificInfo();
363             callbackData.matchFilter = info.getMatchFilters();
364             if (ApiLevelUtil.isAfter(Build.VERSION_CODES.TIRAMISU)) {
365                 callbackData.pairedAlias = info.getPairedAlias();
366                 callbackData.pairingConfig = info.getPairingConfig();
367             }
368             processCallback(callbackData);
369         }
370 
371         @Override
onServiceDiscoveredWithinRange( PeerHandle peerHandle, byte[] serviceSpecificInfo, List<byte[]> matchFilter, int distanceMm)372         public void onServiceDiscoveredWithinRange(
373                 PeerHandle peerHandle,
374                 byte[] serviceSpecificInfo,
375                 List<byte[]> matchFilter,
376                 int distanceMm) {
377             CallbackData callbackData = new CallbackData(
378                     CallbackCode.ON_SERVICE_DISCOVERED_WITH_RANGE);
379             callbackData.peerHandle = peerHandle;
380             callbackData.serviceSpecificInfo = serviceSpecificInfo;
381             callbackData.matchFilter = matchFilter;
382             callbackData.distanceMm = distanceMm;
383             processCallback(callbackData);
384         }
385 
386         @Override
onMessageSendSucceeded(int messageId)387         public void onMessageSendSucceeded(int messageId) {
388             CallbackData callbackData = new CallbackData(CallbackCode.ON_MESSAGE_SEND_SUCCEEDED);
389             callbackData.messageId = messageId;
390             processCallback(callbackData);
391         }
392 
393         @Override
onMessageSendFailed(int messageId)394         public void onMessageSendFailed(int messageId) {
395             CallbackData callbackData = new CallbackData(CallbackCode.ON_MESSAGE_SEND_FAILED);
396             callbackData.messageId = messageId;
397             processCallback(callbackData);
398         }
399 
400         @Override
onMessageReceived(PeerHandle peerHandle, byte[] message)401         public void onMessageReceived(PeerHandle peerHandle, byte[] message) {
402             CallbackData callbackData = new CallbackData(CallbackCode.ON_MESSAGE_RECEIVED);
403             callbackData.peerHandle = peerHandle;
404             callbackData.serviceSpecificInfo = message;
405             processCallback(callbackData);
406         }
407 
408         @Override
onPairingSetupRequestReceived(PeerHandle peerHandle, int requestId)409         public void onPairingSetupRequestReceived(PeerHandle peerHandle, int requestId) {
410             CallbackData callbackData = new CallbackData(CallbackCode.ON_PAIRING_REQUEST_RECEIVED);
411             callbackData.peerHandle = peerHandle;
412             callbackData.pairingRequestId = requestId;
413             processCallback(callbackData);
414         }
415 
416         @Override
onPairingSetupSucceeded(PeerHandle peerHandle, String alias)417         public void onPairingSetupSucceeded(PeerHandle peerHandle, String alias) {
418             CallbackData callbackData = new CallbackData(CallbackCode.ON_PAIRING_SETUP_CONFIRMED);
419             callbackData.peerHandle = peerHandle;
420             callbackData.pairingAccept = true;
421             callbackData.pairingAlias = alias;
422             processCallback(callbackData);
423         }
424 
425         @Override
onPairingSetupFailed(PeerHandle peerHandle)426         public void onPairingSetupFailed(PeerHandle peerHandle) {
427             CallbackData callbackData = new CallbackData(CallbackCode.ON_PAIRING_SETUP_CONFIRMED);
428             callbackData.peerHandle = peerHandle;
429             callbackData.pairingAccept = false;
430             processCallback(callbackData);
431         }
432 
433         @Override
onPairingVerificationSucceed(PeerHandle peerHandle, String alias)434         public void onPairingVerificationSucceed(PeerHandle peerHandle, String alias) {
435             CallbackData callbackData = new CallbackData(
436                     CallbackCode.ON_PAIRING_VERIFICATION_CONFIRMED);
437             callbackData.peerHandle = peerHandle;
438             callbackData.pairingAccept = true;
439             callbackData.pairingAlias = alias;
440             processCallback(callbackData);
441         }
442 
443         @Override
onPairingVerificationFailed(PeerHandle peerHandle)444         public void onPairingVerificationFailed(PeerHandle peerHandle) {
445             CallbackData callbackData = new CallbackData(
446                     CallbackCode.ON_PAIRING_VERIFICATION_CONFIRMED);
447             callbackData.peerHandle = peerHandle;
448             callbackData.pairingAccept = false;
449             processCallback(callbackData);
450         }
451 
452         @Override
onBootstrappingSucceeded(PeerHandle peerHandle, int method)453         public void onBootstrappingSucceeded(PeerHandle peerHandle, int method) {
454             CallbackData callbackData = new CallbackData(CallbackCode.ON_BOOTSTRAPPING_CONFIRMED);
455             callbackData.peerHandle = peerHandle;
456             callbackData.bootstrappingAccept = true;
457             callbackData.bootstrappingMethod = method;
458             processCallback(callbackData);
459         }
460 
461         @Override
onBootstrappingFailed(PeerHandle peerHandle)462         public void onBootstrappingFailed(PeerHandle peerHandle) {
463             CallbackData callbackData = new CallbackData(CallbackCode.ON_BOOTSTRAPPING_CONFIRMED);
464             callbackData.peerHandle = peerHandle;
465             callbackData.bootstrappingAccept = false;
466             processCallback(callbackData);
467         }
468     }
469 
470     /**
471      * Utility RangingResultCallback - provides mechanism for blocking/serializing access with the
472      * waitForRangingResults method.
473      */
474     public static class RangingCb extends RangingResultCallback {
475         public static final int TIMEOUT = -1;
476         public static final int ON_FAILURE = 0;
477         public static final int ON_RESULTS = 1;
478 
479         private final CountDownLatch mBlocker = new CountDownLatch(1);
480         private int mStatus = TIMEOUT;
481         private List<RangingResult> mResults = null;
482 
483         /**
484          * Wait (blocks) for Ranging results callbacks - or times-out.
485          *
486          * @return Pair of status & Ranging results if succeeded, null otherwise.
487          */
waitForRangingResults()488         public Pair<Integer, List<RangingResult>> waitForRangingResults()
489                 throws InterruptedException {
490             if (mBlocker.await(CALLBACK_TIMEOUT_SEC, SECONDS)) {
491                 return new Pair<>(mStatus, mResults);
492             }
493             return new Pair<>(TIMEOUT, null);
494         }
495 
496         @Override
onRangingFailure(int code)497         public void onRangingFailure(int code) {
498             mStatus = ON_FAILURE;
499             mBlocker.countDown();
500         }
501 
502         @Override
onRangingResults(List<RangingResult> results)503         public void onRangingResults(List<RangingResult> results) {
504             mStatus = ON_RESULTS;
505             this.mResults = results;
506             mBlocker.countDown();
507         }
508     }
509 
CallbackUtils()510     private CallbackUtils() {
511     }
512 }
513