1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre_host/socket_server.h"
18 
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cassert>
#include <cerrno>
#include <cinttypes>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <map>
#include <mutex>

#include <cutils/sockets.h>
30 
31 #include "chre_host/log.h"
32 
33 namespace android {
34 namespace chre {
35 
36 std::atomic<bool> SocketServer::sSignalReceived(false);
37 
SocketServer()38 SocketServer::SocketServer() {
39   // Initialize the socket fds field for all inactive client slots to -1, so
40   // poll skips over it, and we don't attempt to send on it
41   for (size_t i = 1; i <= kMaxActiveClients; i++) {
42     mPollFds[i].fd = -1;
43     mPollFds[i].events = POLLIN;
44   }
45 }
46 
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)47 void SocketServer::run(const char *socketName, bool allowSocketCreation,
48                        ClientMessageCallback clientMessageCallback) {
49   mClientMessageCallback = clientMessageCallback;
50 
51   mSockFd = android_get_control_socket(socketName);
52   if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
53     LOGI("Didn't inherit socket, creating...");
54     mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
55                                   SOCK_SEQPACKET);
56   }
57 
58   if (mSockFd == INVALID_SOCKET) {
59     LOGE("Couldn't get/create socket");
60   } else {
61     int ret = listen(mSockFd, kMaxPendingConnectionRequests);
62     if (ret < 0) {
63       LOG_ERROR("Couldn't listen on socket", errno);
64     } else {
65       serviceSocket();
66     }
67 
68     {
69       std::lock_guard<std::mutex> lock(mClientsMutex);
70       for (const auto &pair : mClients) {
71         int clientSocket = pair.first;
72         if (close(clientSocket) != 0) {
73           LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
74                pair.second.clientId, strerror(errno));
75         }
76       }
77       mClients.clear();
78     }
79     close(mSockFd);
80   }
81 }
82 
sendToAllClients(const void * data,size_t length)83 void SocketServer::sendToAllClients(const void *data, size_t length) {
84   std::lock_guard<std::mutex> lock(mClientsMutex);
85 
86   int deliveredCount = 0;
87   for (const auto &pair : mClients) {
88     int clientSocket = pair.first;
89     uint16_t clientId = pair.second.clientId;
90     if (sendToClientSocket(data, length, clientSocket, clientId)) {
91       deliveredCount++;
92     } else if (errno == EINTR) {
93       // Exit early if we were interrupted - we should only get this for
94       // SIGINT/SIGTERM, so we should exit quickly
95       break;
96     }
97   }
98 
99   if (deliveredCount == 0) {
100     LOGW("Got message but didn't deliver to any clients");
101   }
102 }
103 
sendToClientById(const void * data,size_t length,uint16_t clientId)104 bool SocketServer::sendToClientById(const void *data, size_t length,
105                                     uint16_t clientId) {
106   std::lock_guard<std::mutex> lock(mClientsMutex);
107 
108   bool sent = false;
109   for (const auto &pair : mClients) {
110     uint16_t thisClientId = pair.second.clientId;
111     if (thisClientId == clientId) {
112       int clientSocket = pair.first;
113       sent = sendToClientSocket(data, length, clientSocket, thisClientId);
114       break;
115     }
116   }
117 
118   return sent;
119 }
120 
acceptClientConnection()121 void SocketServer::acceptClientConnection() {
122   int clientSocket = accept(mSockFd, NULL, NULL);
123   if (clientSocket < 0) {
124     LOG_ERROR("Couldn't accept client connection", errno);
125   } else if (mClients.size() >= kMaxActiveClients) {
126     LOGW("Rejecting client request - maximum number of clients reached");
127     close(clientSocket);
128   } else {
129     ClientData clientData;
130     clientData.clientId = mNextClientId++;
131 
132     // We currently don't handle wraparound - if we're getting this many
133     // connects/disconnects, then something is wrong.
134     // TODO: can handle this properly by iterating over the existing clients to
135     // avoid a conflict.
136     if (clientData.clientId == 0) {
137       LOGE("Couldn't allocate client ID");
138       std::exit(-1);
139     }
140 
141     bool slotFound = false;
142     for (size_t i = 1; i <= kMaxActiveClients; i++) {
143       if (mPollFds[i].fd < 0) {
144         mPollFds[i].fd = clientSocket;
145         slotFound = true;
146         break;
147       }
148     }
149 
150     if (!slotFound) {
151       LOGE("Couldn't find slot for client!");
152       assert(slotFound);
153       close(clientSocket);
154     } else {
155       {
156         std::lock_guard<std::mutex> lock(mClientsMutex);
157         mClients[clientSocket] = clientData;
158       }
159       LOGI(
160           "Accepted new client connection (count %zu), assigned client ID "
161           "%" PRIu16,
162           mClients.size(), clientData.clientId);
163     }
164   }
165 }
166 
handleClientData(int clientSocket)167 void SocketServer::handleClientData(int clientSocket) {
168   const ClientData &clientData = mClients[clientSocket];
169   uint16_t clientId = clientData.clientId;
170 
171   ssize_t packetSize =
172       recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
173   if (packetSize < 0) {
174     LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
175          strerror(errno));
176     if (ENOTCONN == errno) {
177       disconnectClient(clientSocket);
178     }
179   } else if (packetSize == 0) {
180     LOGI("Client %" PRIu16 " disconnected", clientId);
181     disconnectClient(clientSocket);
182   } else {
183     LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
184     mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
185   }
186 }
187 
disconnectClient(int clientSocket)188 void SocketServer::disconnectClient(int clientSocket) {
189   {
190     std::lock_guard<std::mutex> lock(mClientsMutex);
191     mClients.erase(clientSocket);
192   }
193   close(clientSocket);
194 
195   bool removed = false;
196   for (size_t i = 1; i <= kMaxActiveClients; i++) {
197     if (mPollFds[i].fd == clientSocket) {
198       mPollFds[i].fd = -1;
199       removed = true;
200       break;
201     }
202   }
203 
204   if (!removed) {
205     LOGE("Out of sync");
206     assert(removed);
207   }
208 }
209 
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)210 bool SocketServer::sendToClientSocket(const void *data, size_t length,
211                                       int clientSocket, uint16_t clientId) {
212   errno = 0;
213   ssize_t bytesSent = send(clientSocket, data, length, 0);
214   if (bytesSent < 0) {
215     LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
216          clientId, strerror(errno));
217   } else if (bytesSent == 0) {
218     LOGW("Client %" PRIu16 " disconnected before message could be delivered",
219          clientId);
220   } else {
221     LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
222          clientId);
223   }
224 
225   return (bytesSent > 0);
226 }
227 
serviceSocket()228 void SocketServer::serviceSocket() {
229   constexpr size_t kListenIndex = 0;
230   static_assert(kListenIndex == 0,
231                 "Code assumes that the first index is always the listen "
232                 "socket");
233 
234   mPollFds[kListenIndex].fd = mSockFd;
235   mPollFds[kListenIndex].events = POLLIN;
236 
237   // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
238   // and ignore other signals
239   sigset_t signalMask;
240   sigfillset(&signalMask);
241   sigdelset(&signalMask, SIGINT);
242   sigdelset(&signalMask, SIGTERM);
243 
244   LOGI("Ready to accept connections");
245   while (!sSignalReceived) {
246     int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
247     if (ret == -1) {
248       // Don't use TEMP_FAILURE_RETRY since our logic needs to check
249       // sSignalReceived to see if it should exit where as TEMP_FAILURE_RETRY
250       // is a tight retry loop around ppoll.
251       if (errno == EINTR) {
252         continue;
253       }
254       LOGI("Exiting poll loop: %s", strerror(errno));
255       break;
256     }
257 
258     if (mPollFds[kListenIndex].revents & POLLIN) {
259       acceptClientConnection();
260     }
261 
262     for (size_t i = 1; i <= kMaxActiveClients; i++) {
263       if (mPollFds[i].fd < 0) {
264         continue;
265       }
266 
267       if (mPollFds[i].revents & POLLIN) {
268         handleClientData(mPollFds[i].fd);
269       }
270     }
271   }
272 }
273 
274 }  // namespace chre
275 }  // namespace android
276