1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.example.android.vdmdemo.common;
18 
import android.content.Context;
import android.content.pm.PackageManager;
import android.net.ConnectivityManager;
import android.net.Network;
import android.net.NetworkCapabilities;
import android.net.NetworkRequest;
import android.net.wifi.aware.AttachCallback;
import android.net.wifi.aware.DiscoverySession;
import android.net.wifi.aware.DiscoverySessionCallback;
import android.net.wifi.aware.PeerHandle;
import android.net.wifi.aware.PublishConfig;
import android.net.wifi.aware.PublishDiscoverySession;
import android.net.wifi.aware.SubscribeConfig;
import android.net.wifi.aware.SubscribeDiscoverySession;
import android.net.wifi.aware.WifiAwareManager;
import android.net.wifi.aware.WifiAwareNetworkInfo;
import android.net.wifi.aware.WifiAwareNetworkSpecifier;
import android.net.wifi.aware.WifiAwareSession;
import android.os.Build;
import android.os.Handler;
import android.os.HandlerThread;
import android.util.Log;

import androidx.annotation.GuardedBy;
import androidx.annotation.NonNull;

import dagger.hilt.android.qualifiers.ApplicationContext;

import java.io.IOException;
import java.net.Inet6Address;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

import javax.inject.Inject;
import javax.inject.Singleton;
60 
61 /** Shared class between the client and the host, managing the connection between them. */
62 @Singleton
63 public class ConnectionManager {
64 
65     private static final String TAG = "VdmConnectionManager";
66     private static final String CONNECTION_SERVICE_ID = "com.example.android.vdmdemo";
67     private static final int NETWORK_TIMEOUT_MS = 5000;
68 
69     private final RemoteIo mRemoteIo;
70 
71     @ApplicationContext private final Context mContext;
72     private final ConnectivityManager mConnectivityManager;
73     private final Handler mBackgroundHandler;
74 
75     private CompletableFuture<WifiAwareSession> mWifiAwareSessionFuture = new CompletableFuture<>();
76 
77     private DiscoverySession mDiscoverySession;
78 
79     /** Simple data structure to allow clients to query the current status. */
80     public static final class ConnectionStatus {
81         public String remoteDeviceName = null;
82         public String errorMessage = null;
83         public State state = State.DISCONNECTED;
84 
85         /** Enum indicating the current connection state. */
86         public enum State {
87             DISCONNECTED, INITIALIZED, CONNECTING, CONNECTED, ERROR
88         }
89     }
90 
91     @GuardedBy("mConnectionStatus")
92     private final ConnectionStatus mConnectionStatus = new ConnectionStatus();
93 
94     @GuardedBy("mConnectionCallbacks")
95     private final List<Consumer<ConnectionStatus>> mConnectionCallbacks = new ArrayList<>();
96 
97     private final RemoteIo.StreamClosedCallback mStreamClosedCallback = this::onInitialized;
98 
99     @Inject
ConnectionManager(@pplicationContext Context context, RemoteIo remoteIo)100     ConnectionManager(@ApplicationContext Context context, RemoteIo remoteIo) {
101         mRemoteIo = remoteIo;
102         mContext = context;
103 
104         mConnectivityManager = context.getSystemService(ConnectivityManager.class);
105         final HandlerThread backgroundThread = new HandlerThread("ConnectionThread");
106         backgroundThread.start();
107         mBackgroundHandler = new Handler(backgroundThread.getLooper());
108     }
109 
getLocalEndpointId()110     static String getLocalEndpointId() {
111         return Build.MODEL;
112     }
113 
114     /** Registers a listener for connection events. */
addConnectionCallback(Consumer<ConnectionStatus> callback)115     public void addConnectionCallback(Consumer<ConnectionStatus> callback) {
116         synchronized (mConnectionCallbacks) {
117             mConnectionCallbacks.add(callback);
118         }
119     }
120 
121     /** Registers a listener for connection events. */
removeConnectionCallback(Consumer<ConnectionStatus> callback)122     public void removeConnectionCallback(Consumer<ConnectionStatus> callback) {
123         synchronized (mConnectionCallbacks) {
124             mConnectionCallbacks.remove(callback);
125         }
126     }
127 
128     /** Returns the current connection status. */
getConnectionStatus()129     public ConnectionStatus getConnectionStatus() {
130         synchronized (mConnectionStatus) {
131             return mConnectionStatus;
132         }
133     }
134 
135     /** Publish a local service so remote devices can discover this device. */
startHostSession()136     public void startHostSession() {
137         var unused = createWifiAwareSession().thenAccept(session -> session.publish(
138                     new PublishConfig.Builder().setServiceName(CONNECTION_SERVICE_ID).build(),
139                     new HostDiscoverySessionCallback(),
140                     mBackgroundHandler));
141     }
142 
143     /** Looks for published services from remote devices and subscribes to them. */
startClientSession()144     public void startClientSession() {
145         var unused = createWifiAwareSession().thenAccept(session -> session.subscribe(
146                 new SubscribeConfig.Builder().setServiceName(CONNECTION_SERVICE_ID).build(),
147                 new ClientDiscoverySessionCallback(),
148                 mBackgroundHandler));
149     }
150 
isConnected()151     private boolean isConnected() {
152         synchronized (mConnectionStatus) {
153             return mConnectionStatus.state == ConnectionStatus.State.CONNECTED;
154         }
155     }
156 
createWifiAwareSession()157     private CompletableFuture<WifiAwareSession> createWifiAwareSession() {
158         if (mWifiAwareSessionFuture.isDone()
159                 && !mWifiAwareSessionFuture.isCompletedExceptionally()) {
160             return mWifiAwareSessionFuture;
161         }
162 
163         Log.d(TAG, "Creating a new Wifi Aware session.");
164         WifiAwareManager wifiAwareManager = mContext.getSystemService(WifiAwareManager.class);
165         if (!mContext.getPackageManager().hasSystemFeature(PackageManager.FEATURE_WIFI_AWARE)
166                 || wifiAwareManager == null
167                 || !wifiAwareManager.isAvailable()) {
168             mWifiAwareSessionFuture.completeExceptionally(
169                     new Exception("Wifi Aware is not available."));
170         } else {
171             wifiAwareManager.attach(
172                     new AttachCallback() {
173                         @Override
174                         public void onAttached(WifiAwareSession session) {
175                             Log.d(TAG, "New Wifi Aware attached.");
176                             mWifiAwareSessionFuture.complete(session);
177                         }
178 
179                         @Override
180                         public void onAttachFailed() {
181                             mWifiAwareSessionFuture.completeExceptionally(
182                                     new Exception("Failed to attach Wifi Aware session."));
183                         }
184                     },
185                     mBackgroundHandler);
186         }
187         mWifiAwareSessionFuture = mWifiAwareSessionFuture
188                 .exceptionally(e -> {
189                     Log.e(TAG, "Failed to create Wifi Aware session", e);
190                     onError("Failed to create Wifi Aware session");
191                     return null;
192                 });
193         return mWifiAwareSessionFuture;
194     }
195 
196     /** Explicitly terminate any existing connection. */
disconnect()197     public void disconnect() {
198         Log.d(TAG, "Terminating connections.");
199         if (mDiscoverySession != null) {
200             mDiscoverySession.close();
201             mDiscoverySession = null;
202         }
203     }
204 
onSocketAvailable(Socket socket)205     private void onSocketAvailable(Socket socket) throws IOException {
206         mRemoteIo.initialize(socket.getInputStream(), mStreamClosedCallback);
207         mRemoteIo.initialize(socket.getOutputStream(), mStreamClosedCallback);
208         synchronized (mConnectionStatus) {
209             mConnectionStatus.state = ConnectionStatus.State.CONNECTED;
210             notifyStateChangedLocked();
211         }
212     }
213 
onInitialized()214     private void onInitialized() {
215         synchronized (mConnectionStatus) {
216             mConnectionStatus.state = ConnectionStatus.State.INITIALIZED;
217             notifyStateChangedLocked();
218         }
219     }
220 
onConnecting(byte[] remoteDeviceName)221     private void onConnecting(byte[] remoteDeviceName) {
222         synchronized (mConnectionStatus) {
223             mConnectionStatus.state = ConnectionStatus.State.CONNECTING;
224             mConnectionStatus.remoteDeviceName = new String(remoteDeviceName);
225             Log.d(TAG, "Connecting to " + mConnectionStatus.remoteDeviceName);
226             notifyStateChangedLocked();
227         }
228     }
229 
onError(String message)230     private void onError(String message) {
231         Log.e(TAG, "Error: " + message);
232         synchronized (mConnectionStatus) {
233             mConnectionStatus.state = ConnectionStatus.State.ERROR;
234             mConnectionStatus.errorMessage = message;
235             notifyStateChangedLocked();
236         }
237     }
238 
239     @GuardedBy("mConnectionStatus")
notifyStateChangedLocked()240     private void notifyStateChangedLocked() {
241         Log.d(TAG, "Connection state changed: " + mConnectionStatus.state);
242         synchronized (mConnectionCallbacks) {
243             for (Consumer<ConnectionStatus> callback : mConnectionCallbacks) {
244                 callback.accept(mConnectionStatus);
245             }
246         }
247     }
248 
249     private abstract class VdmDiscoverySessionCallback extends DiscoverySessionCallback {
250 
251         private NetworkCallback mNetworkCallback;
252 
253         @Override
onPublishStarted(@onNull PublishDiscoverySession session)254         public void onPublishStarted(@NonNull PublishDiscoverySession session) {
255             mDiscoverySession = session;
256             onInitialized();
257         }
258 
259         @Override
onSubscribeStarted(@onNull SubscribeDiscoverySession session)260         public void onSubscribeStarted(@NonNull SubscribeDiscoverySession session) {
261             mDiscoverySession = session;
262             onInitialized();
263         }
264 
265         @Override
onServiceDiscovered( PeerHandle peerHandle, byte[] serviceSpecificInfo, List<byte[]> matchFilter)266         public void onServiceDiscovered(
267                 PeerHandle peerHandle, byte[] serviceSpecificInfo, List<byte[]> matchFilter) {
268             Log.d(TAG, "Discovered service: " + new String(serviceSpecificInfo));
269             sendLocalEndpointId(peerHandle);
270         }
271 
272         @Override
onSessionTerminated()273         public void onSessionTerminated() {
274             Log.d(TAG, "Discovery session terminated.");
275             if (mNetworkCallback != null) {
276                 mConnectivityManager.unregisterNetworkCallback(mNetworkCallback);
277                 mNetworkCallback = null;
278             }
279         }
280 
sendLocalEndpointId(PeerHandle peerHandle)281         void sendLocalEndpointId(PeerHandle peerHandle) {
282             mDiscoverySession.sendMessage(peerHandle, 0, getLocalEndpointId().getBytes());
283         }
284 
285         @Override
onMessageReceived(PeerHandle peerHandle, byte[] message)286         public void onMessageReceived(PeerHandle peerHandle, byte[] message) {
287             Log.d(TAG, "Received message: " + new String(message));
288             if (isConnected()) {
289                 return;
290             }
291             onConnecting(message);
292             establishConnection(peerHandle);
293         }
294 
establishConnection(PeerHandle peerHandle)295         protected abstract void establishConnection(PeerHandle peerHandle);
296 
requestNetwork( PeerHandle peerHandle, Optional<Integer> port, NetworkCallback networkCallback)297         void requestNetwork(
298                 PeerHandle peerHandle, Optional<Integer> port, NetworkCallback networkCallback) {
299             WifiAwareNetworkSpecifier.Builder networkSpecifierBuilder;
300             networkSpecifierBuilder =
301                     new WifiAwareNetworkSpecifier.Builder(mDiscoverySession, peerHandle)
302                             .setPskPassphrase(CONNECTION_SERVICE_ID);
303             if (mNetworkCallback != null) {
304                 mConnectivityManager.unregisterNetworkCallback(mNetworkCallback);
305             }
306             mNetworkCallback = networkCallback;
307             port.ifPresent(networkSpecifierBuilder::setPort);
308 
309             NetworkRequest networkRequest =
310                     new NetworkRequest.Builder()
311                             .addTransportType(NetworkCapabilities.TRANSPORT_WIFI_AWARE)
312                             .setNetworkSpecifier(networkSpecifierBuilder.build())
313                             .build();
314             Log.d(TAG, "Requesting network");
315             mConnectivityManager.requestNetwork(
316                     networkRequest, mNetworkCallback, NETWORK_TIMEOUT_MS);
317         }
318     }
319 
320     private final class HostDiscoverySessionCallback extends VdmDiscoverySessionCallback {
321         @Override
establishConnection(PeerHandle peerHandle)322         protected void establishConnection(PeerHandle peerHandle) {
323             try {
324                 ServerSocket serverSocket = new ServerSocket(0);
325                 serverSocket.setSoTimeout(NETWORK_TIMEOUT_MS);
326                 requestNetwork(peerHandle, Optional.of(serverSocket.getLocalPort()),
327                         new NetworkCallback());
328                 sendLocalEndpointId(peerHandle);
329                 onSocketAvailable(serverSocket.accept());
330             } catch (SocketTimeoutException e) {
331                 Log.e(TAG, "Socket timeout: " + e.getMessage());
332             } catch (IOException e) {
333                 onError("Failed to establish connection.");
334             }
335         }
336     }
337 
338     private final class ClientDiscoverySessionCallback extends VdmDiscoverySessionCallback {
339         @Override
establishConnection(PeerHandle peerHandle)340         protected void establishConnection(PeerHandle peerHandle) {
341             requestNetwork(peerHandle, /* port= */ Optional.empty(), new ClientNetworkCallback());
342         }
343     }
344 
345     private class NetworkCallback extends ConnectivityManager.NetworkCallback {
346 
347         @Override
onLost(@onNull Network network)348         public void onLost(@NonNull Network network) {
349             Log.d(TAG, "Network lost");
350             onInitialized();
351         }
352 
353         @Override
onUnavailable()354         public void onUnavailable() {
355             Log.d(TAG, "Network unavailable");
356         }
357     }
358 
359     private class ClientNetworkCallback extends NetworkCallback {
360 
361         @Override
onCapabilitiesChanged(@onNull Network network, @NonNull NetworkCapabilities networkCapabilities)362         public void onCapabilitiesChanged(@NonNull Network network,
363                 @NonNull NetworkCapabilities networkCapabilities) {
364             if (isConnected()) {
365                 return;
366             }
367 
368             WifiAwareNetworkInfo peerAwareInfo =
369                     (WifiAwareNetworkInfo) networkCapabilities.getTransportInfo();
370             Inet6Address peerIpv6 = peerAwareInfo.getPeerIpv6Addr();
371             int peerPort = peerAwareInfo.getPort();
372             try {
373                 Socket socket = network.getSocketFactory().createSocket(peerIpv6, peerPort);
374                 onSocketAvailable(socket);
375             } catch (IOException e) {
376                 Log.e(TAG, "Failed to establish connection.", e);
377                 onError("Failed to establish connection.");
378             }
379         }
380     }
381 }
382