1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre_host/socket_server.h"
18 
#include <poll.h>
#include <signal.h>

#include <cassert>
#include <cerrno>
#include <cinttypes>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <map>
#include <mutex>

#include <cutils/sockets.h>

#include "chre_host/log.h"
31 
32 namespace android {
33 namespace chre {
34 
35 std::atomic<bool> SocketServer::sSignalReceived(false);
36 
37 namespace {
38 
maskAllSignals()39 void maskAllSignals() {
40   sigset_t signalMask;
41   sigfillset(&signalMask);
42   if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
43     LOG_ERROR("Couldn't mask all signals", errno);
44   }
45 }
46 
maskAllSignalsExceptIntAndTerm()47 void maskAllSignalsExceptIntAndTerm() {
48   sigset_t signalMask;
49   sigfillset(&signalMask);
50   sigdelset(&signalMask, SIGINT);
51   sigdelset(&signalMask, SIGTERM);
52   if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
53     LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
54   }
55 }
56 
57 }  // anonymous namespace
58 
SocketServer()59 SocketServer::SocketServer() {
60   // Initialize the socket fds field for all inactive client slots to -1, so
61   // poll skips over it, and we don't attempt to send on it
62   for (size_t i = 1; i <= kMaxActiveClients; i++) {
63     mPollFds[i].fd = -1;
64     mPollFds[i].events = POLLIN;
65   }
66 }
67 
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)68 void SocketServer::run(const char *socketName, bool allowSocketCreation,
69                        ClientMessageCallback clientMessageCallback) {
70   mClientMessageCallback = clientMessageCallback;
71 
72   mSockFd = android_get_control_socket(socketName);
73   if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
74     LOGI("Didn't inherit socket, creating...");
75     mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
76                                   SOCK_SEQPACKET);
77   }
78 
79   if (mSockFd == INVALID_SOCKET) {
80     LOGE("Couldn't get/create socket");
81   } else {
82     int ret = listen(mSockFd, kMaxPendingConnectionRequests);
83     if (ret < 0) {
84       LOG_ERROR("Couldn't listen on socket", errno);
85     } else {
86       serviceSocket();
87     }
88 
89     {
90       std::lock_guard<std::mutex> lock(mClientsMutex);
91       for (const auto &pair : mClients) {
92         int clientSocket = pair.first;
93         if (close(clientSocket) != 0) {
94           LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
95                pair.second.clientId, strerror(errno));
96         }
97       }
98       mClients.clear();
99     }
100     close(mSockFd);
101   }
102 }
103 
sendToAllClients(const void * data,size_t length)104 void SocketServer::sendToAllClients(const void *data, size_t length) {
105   std::lock_guard<std::mutex> lock(mClientsMutex);
106 
107   int deliveredCount = 0;
108   for (const auto &pair : mClients) {
109     int clientSocket = pair.first;
110     uint16_t clientId = pair.second.clientId;
111     if (sendToClientSocket(data, length, clientSocket, clientId)) {
112       deliveredCount++;
113     } else if (errno == EINTR) {
114       // Exit early if we were interrupted - we should only get this for
115       // SIGINT/SIGTERM, so we should exit quickly
116       break;
117     }
118   }
119 
120   if (deliveredCount == 0) {
121     LOGW("Got message but didn't deliver to any clients");
122   }
123 }
124 
sendToClientById(const void * data,size_t length,uint16_t clientId)125 bool SocketServer::sendToClientById(const void *data, size_t length,
126                                     uint16_t clientId) {
127   std::lock_guard<std::mutex> lock(mClientsMutex);
128 
129   bool sent = false;
130   for (const auto &pair : mClients) {
131     uint16_t thisClientId = pair.second.clientId;
132     if (thisClientId == clientId) {
133       int clientSocket = pair.first;
134       sent = sendToClientSocket(data, length, clientSocket, thisClientId);
135       break;
136     }
137   }
138 
139   return sent;
140 }
141 
acceptClientConnection()142 void SocketServer::acceptClientConnection() {
143   int clientSocket = accept(mSockFd, NULL, NULL);
144   if (clientSocket < 0) {
145     LOG_ERROR("Couldn't accept client connection", errno);
146   } else if (mClients.size() >= kMaxActiveClients) {
147     LOGW("Rejecting client request - maximum number of clients reached");
148     close(clientSocket);
149   } else {
150     ClientData clientData;
151     clientData.clientId = mNextClientId++;
152 
153     // We currently don't handle wraparound - if we're getting this many
154     // connects/disconnects, then something is wrong.
155     // TODO: can handle this properly by iterating over the existing clients to
156     // avoid a conflict.
157     if (clientData.clientId == 0) {
158       LOGE("Couldn't allocate client ID");
159       std::exit(-1);
160     }
161 
162     bool slotFound = false;
163     for (size_t i = 1; i <= kMaxActiveClients; i++) {
164       if (mPollFds[i].fd < 0) {
165         mPollFds[i].fd = clientSocket;
166         slotFound = true;
167         break;
168       }
169     }
170 
171     if (!slotFound) {
172       LOGE("Couldn't find slot for client!");
173       assert(slotFound);
174       close(clientSocket);
175     } else {
176       {
177         std::lock_guard<std::mutex> lock(mClientsMutex);
178         mClients[clientSocket] = clientData;
179       }
180       LOGI(
181           "Accepted new client connection (count %zu), assigned client ID "
182           "%" PRIu16,
183           mClients.size(), clientData.clientId);
184     }
185   }
186 }
187 
handleClientData(int clientSocket)188 void SocketServer::handleClientData(int clientSocket) {
189   const ClientData &clientData = mClients[clientSocket];
190   uint16_t clientId = clientData.clientId;
191 
192   ssize_t packetSize =
193       recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
194   if (packetSize < 0) {
195     LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
196          strerror(errno));
197   } else if (packetSize == 0) {
198     LOGI("Client %" PRIu16 " disconnected", clientId);
199     disconnectClient(clientSocket);
200   } else {
201     LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
202     mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
203   }
204 }
205 
disconnectClient(int clientSocket)206 void SocketServer::disconnectClient(int clientSocket) {
207   {
208     std::lock_guard<std::mutex> lock(mClientsMutex);
209     mClients.erase(clientSocket);
210   }
211   close(clientSocket);
212 
213   bool removed = false;
214   for (size_t i = 1; i <= kMaxActiveClients; i++) {
215     if (mPollFds[i].fd == clientSocket) {
216       mPollFds[i].fd = -1;
217       removed = true;
218       break;
219     }
220   }
221 
222   if (!removed) {
223     LOGE("Out of sync");
224     assert(removed);
225   }
226 }
227 
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)228 bool SocketServer::sendToClientSocket(const void *data, size_t length,
229                                       int clientSocket, uint16_t clientId) {
230   errno = 0;
231   ssize_t bytesSent = send(clientSocket, data, length, 0);
232   if (bytesSent < 0) {
233     LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
234          clientId, strerror(errno));
235   } else if (bytesSent == 0) {
236     LOGW("Client %" PRIu16 " disconnected before message could be delivered",
237          clientId);
238   } else {
239     LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
240          clientId);
241   }
242 
243   return (bytesSent > 0);
244 }
245 
serviceSocket()246 void SocketServer::serviceSocket() {
247   constexpr size_t kListenIndex = 0;
248   static_assert(kListenIndex == 0,
249                 "Code assumes that the first index is always the listen "
250                 "socket");
251 
252   mPollFds[kListenIndex].fd = mSockFd;
253   mPollFds[kListenIndex].events = POLLIN;
254 
255   // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
256   // and ignore other signals
257   sigset_t signalMask;
258   sigfillset(&signalMask);
259   sigdelset(&signalMask, SIGINT);
260   sigdelset(&signalMask, SIGTERM);
261 
262   // Masking signals here ensure that after this point, we won't handle INT/TERM
263   // until after we call into ppoll()
264   maskAllSignals();
265   std::signal(SIGINT, signalHandler);
266   std::signal(SIGTERM, signalHandler);
267 
268   LOGI("Ready to accept connections");
269   while (!sSignalReceived) {
270     int ret = TEMP_FAILURE_RETRY(
271         ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask));
272     maskAllSignalsExceptIntAndTerm();
273     if (ret == -1) {
274       LOGI("Exiting poll loop: %s", strerror(errno));
275       break;
276     }
277 
278     if (mPollFds[kListenIndex].revents & POLLIN) {
279       acceptClientConnection();
280     }
281 
282     for (size_t i = 1; i <= kMaxActiveClients; i++) {
283       if (mPollFds[i].fd < 0) {
284         continue;
285       }
286 
287       if (mPollFds[i].revents & POLLIN) {
288         handleClientData(mPollFds[i].fd);
289       }
290     }
291 
292     // Mask all signals to ensure that sSignalReceived can't become true between
293     // checking it in the while condition and calling into ppoll()
294     maskAllSignals();
295   }
296 }
297 
signalHandler(int signal)298 void SocketServer::signalHandler(int signal) {
299   LOGD("Caught signal %d", signal);
300   sSignalReceived = true;
301 }
302 
303 }  // namespace chre
304 }  // namespace android
305