/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "chre_host/socket_server.h"
18
19 #include <poll.h>
20
21 #include <cassert>
22 #include <cerrno>
23 #include <cinttypes>
24 #include <csignal>
25 #include <cstdlib>
26 #include <map>
27 #include <mutex>
28
29 #include <cutils/sockets.h>
30
31 #include "chre_host/log.h"
32
33 namespace android {
34 namespace chre {
35
36 std::atomic<bool> SocketServer::sSignalReceived(false);
37
SocketServer()38 SocketServer::SocketServer() {
39 // Initialize the socket fds field for all inactive client slots to -1, so
40 // poll skips over it, and we don't attempt to send on it
41 for (size_t i = 1; i <= kMaxActiveClients; i++) {
42 mPollFds[i].fd = -1;
43 mPollFds[i].events = POLLIN;
44 }
45 }
46
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)47 void SocketServer::run(const char *socketName, bool allowSocketCreation,
48 ClientMessageCallback clientMessageCallback) {
49 mClientMessageCallback = clientMessageCallback;
50
51 mSockFd = android_get_control_socket(socketName);
52 if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
53 LOGI("Didn't inherit socket, creating...");
54 mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
55 SOCK_SEQPACKET);
56 }
57
58 if (mSockFd == INVALID_SOCKET) {
59 LOGE("Couldn't get/create socket");
60 } else {
61 int ret = listen(mSockFd, kMaxPendingConnectionRequests);
62 if (ret < 0) {
63 LOG_ERROR("Couldn't listen on socket", errno);
64 } else {
65 serviceSocket();
66 }
67
68 {
69 std::lock_guard<std::mutex> lock(mClientsMutex);
70 for (const auto &pair : mClients) {
71 int clientSocket = pair.first;
72 if (close(clientSocket) != 0) {
73 LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
74 pair.second.clientId, strerror(errno));
75 }
76 }
77 mClients.clear();
78 }
79 close(mSockFd);
80 }
81 }
82
sendToAllClients(const void * data,size_t length)83 void SocketServer::sendToAllClients(const void *data, size_t length) {
84 std::lock_guard<std::mutex> lock(mClientsMutex);
85
86 int deliveredCount = 0;
87 for (const auto &pair : mClients) {
88 int clientSocket = pair.first;
89 uint16_t clientId = pair.second.clientId;
90 if (sendToClientSocket(data, length, clientSocket, clientId)) {
91 deliveredCount++;
92 } else if (errno == EINTR) {
93 // Exit early if we were interrupted - we should only get this for
94 // SIGINT/SIGTERM, so we should exit quickly
95 break;
96 }
97 }
98
99 if (deliveredCount == 0) {
100 LOGW("Got message but didn't deliver to any clients");
101 }
102 }
103
sendToClientById(const void * data,size_t length,uint16_t clientId)104 bool SocketServer::sendToClientById(const void *data, size_t length,
105 uint16_t clientId) {
106 std::lock_guard<std::mutex> lock(mClientsMutex);
107
108 bool sent = false;
109 for (const auto &pair : mClients) {
110 uint16_t thisClientId = pair.second.clientId;
111 if (thisClientId == clientId) {
112 int clientSocket = pair.first;
113 sent = sendToClientSocket(data, length, clientSocket, thisClientId);
114 break;
115 }
116 }
117
118 return sent;
119 }
120
acceptClientConnection()121 void SocketServer::acceptClientConnection() {
122 int clientSocket = accept(mSockFd, NULL, NULL);
123 if (clientSocket < 0) {
124 LOG_ERROR("Couldn't accept client connection", errno);
125 } else if (mClients.size() >= kMaxActiveClients) {
126 LOGW("Rejecting client request - maximum number of clients reached");
127 close(clientSocket);
128 } else {
129 ClientData clientData;
130 clientData.clientId = mNextClientId++;
131
132 // We currently don't handle wraparound - if we're getting this many
133 // connects/disconnects, then something is wrong.
134 // TODO: can handle this properly by iterating over the existing clients to
135 // avoid a conflict.
136 if (clientData.clientId == 0) {
137 LOGE("Couldn't allocate client ID");
138 std::exit(-1);
139 }
140
141 bool slotFound = false;
142 for (size_t i = 1; i <= kMaxActiveClients; i++) {
143 if (mPollFds[i].fd < 0) {
144 mPollFds[i].fd = clientSocket;
145 slotFound = true;
146 break;
147 }
148 }
149
150 if (!slotFound) {
151 LOGE("Couldn't find slot for client!");
152 assert(slotFound);
153 close(clientSocket);
154 } else {
155 {
156 std::lock_guard<std::mutex> lock(mClientsMutex);
157 mClients[clientSocket] = clientData;
158 }
159 LOGI(
160 "Accepted new client connection (count %zu), assigned client ID "
161 "%" PRIu16,
162 mClients.size(), clientData.clientId);
163 }
164 }
165 }
166
handleClientData(int clientSocket)167 void SocketServer::handleClientData(int clientSocket) {
168 const ClientData &clientData = mClients[clientSocket];
169 uint16_t clientId = clientData.clientId;
170
171 ssize_t packetSize =
172 recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
173 if (packetSize < 0) {
174 LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
175 strerror(errno));
176 if (ENOTCONN == errno) {
177 disconnectClient(clientSocket);
178 }
179 } else if (packetSize == 0) {
180 LOGI("Client %" PRIu16 " disconnected", clientId);
181 disconnectClient(clientSocket);
182 } else {
183 LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
184 mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
185 }
186 }
187
disconnectClient(int clientSocket)188 void SocketServer::disconnectClient(int clientSocket) {
189 {
190 std::lock_guard<std::mutex> lock(mClientsMutex);
191 mClients.erase(clientSocket);
192 }
193 close(clientSocket);
194
195 bool removed = false;
196 for (size_t i = 1; i <= kMaxActiveClients; i++) {
197 if (mPollFds[i].fd == clientSocket) {
198 mPollFds[i].fd = -1;
199 removed = true;
200 break;
201 }
202 }
203
204 if (!removed) {
205 LOGE("Out of sync");
206 assert(removed);
207 }
208 }
209
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)210 bool SocketServer::sendToClientSocket(const void *data, size_t length,
211 int clientSocket, uint16_t clientId) {
212 errno = 0;
213 ssize_t bytesSent = send(clientSocket, data, length, 0);
214 if (bytesSent < 0) {
215 LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
216 clientId, strerror(errno));
217 } else if (bytesSent == 0) {
218 LOGW("Client %" PRIu16 " disconnected before message could be delivered",
219 clientId);
220 } else {
221 LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
222 clientId);
223 }
224
225 return (bytesSent > 0);
226 }
227
serviceSocket()228 void SocketServer::serviceSocket() {
229 constexpr size_t kListenIndex = 0;
230 static_assert(kListenIndex == 0,
231 "Code assumes that the first index is always the listen "
232 "socket");
233
234 mPollFds[kListenIndex].fd = mSockFd;
235 mPollFds[kListenIndex].events = POLLIN;
236
237 // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
238 // and ignore other signals
239 sigset_t signalMask;
240 sigfillset(&signalMask);
241 sigdelset(&signalMask, SIGINT);
242 sigdelset(&signalMask, SIGTERM);
243
244 LOGI("Ready to accept connections");
245 while (!sSignalReceived) {
246 int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
247 if (ret == -1) {
248 // Don't use TEMP_FAILURE_RETRY since our logic needs to check
249 // sSignalReceived to see if it should exit where as TEMP_FAILURE_RETRY
250 // is a tight retry loop around ppoll.
251 if (errno == EINTR) {
252 continue;
253 }
254 LOGI("Exiting poll loop: %s", strerror(errno));
255 break;
256 }
257
258 if (mPollFds[kListenIndex].revents & POLLIN) {
259 acceptClientConnection();
260 }
261
262 for (size_t i = 1; i <= kMaxActiveClients; i++) {
263 if (mPollFds[i].fd < 0) {
264 continue;
265 }
266
267 if (mPollFds[i].revents & POLLIN) {
268 handleClientData(mPollFds[i].fd);
269 }
270 }
271 }
272 }
273
274 } // namespace chre
275 } // namespace android
276