1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/libs/utils/unix_sockets.h"
17 
18 #include <fcntl.h>
19 #include <sys/uio.h>
20 #include <unistd.h>
21 
22 #include <cstring>
23 #include <memory>
24 #include <ostream>
25 #include <utility>
26 #include <vector>
27 
28 #include <android-base/logging.h>
29 
30 #include "common/libs/fs/shared_fd.h"
31 #include "common/libs/utils/result.h"
32 
33 // This would use android::base::ReceiveFileDescriptors, but it silently drops
34 // SCM_CREDENTIALS control messages.
35 
36 namespace cuttlefish {
37 
FromRaw(const cmsghdr * cmsg)38 ControlMessage ControlMessage::FromRaw(const cmsghdr* cmsg) {
39   ControlMessage message;
40   message.data_ =
41       std::vector<char>((char*)cmsg, ((char*)cmsg) + cmsg->cmsg_len);
42   if (message.IsFileDescriptors()) {
43     size_t fdcount =
44         static_cast<size_t>(cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
45     for (int i = 0; i < fdcount; i++) {
46       // Use memcpy as CMSG_DATA may be unaligned
47       int fd = -1;
48       memcpy(&fd, CMSG_DATA(cmsg) + (i * sizeof(int)), sizeof(fd));
49       message.fds_.push_back(fd);
50     }
51   }
52   return message;
53 }
54 
FromFileDescriptors(const std::vector<SharedFD> & fds)55 Result<ControlMessage> ControlMessage::FromFileDescriptors(
56     const std::vector<SharedFD>& fds) {
57   ControlMessage message;
58   message.data_.resize(CMSG_SPACE(fds.size() * sizeof(int)), 0);
59   message.Raw()->cmsg_len = CMSG_LEN(fds.size() * sizeof(int));
60   message.Raw()->cmsg_level = SOL_SOCKET;
61   message.Raw()->cmsg_type = SCM_RIGHTS;
62   for (int i = 0; i < fds.size(); i++) {
63     int fd_copy = fds[i]->Fcntl(F_DUPFD_CLOEXEC, 3);
64     CF_EXPECT(fd_copy >= 0, "Failed to duplicate fd: " << fds[i]->StrError());
65     message.fds_.push_back(fd_copy);
66     // Following the CMSG_DATA spec, use memcpy to avoid alignment issues.
67     memcpy(CMSG_DATA(message.Raw()) + (i * sizeof(int)), &fd_copy, sizeof(int));
68   }
69   return message;
70 }
71 
72 #ifdef __linux__
FromCredentials(const ucred & credentials)73 ControlMessage ControlMessage::FromCredentials(const ucred& credentials) {
74   ControlMessage message;
75   message.data_.resize(CMSG_SPACE(sizeof(ucred)), 0);
76   message.Raw()->cmsg_len = CMSG_LEN(sizeof(ucred));
77   message.Raw()->cmsg_level = SOL_SOCKET;
78   message.Raw()->cmsg_type = SCM_CREDENTIALS;
79   // Following the CMSG_DATA spec, use memcpy to avoid alignment issues.
80   memcpy(CMSG_DATA(message.Raw()), &credentials, sizeof(credentials));
81   return message;
82 }
83 #endif
84 
ControlMessage(ControlMessage && existing)85 ControlMessage::ControlMessage(ControlMessage&& existing) {
86   // Enforce that the old ControlMessage is left empty, so it doesn't try to
87   // close any file descriptors. https://stackoverflow.com/a/17735913
88   data_ = std::move(existing.data_);
89   existing.data_.clear();
90   fds_ = std::move(existing.fds_);
91   existing.fds_.clear();
92 }
93 
operator =(ControlMessage && existing)94 ControlMessage& ControlMessage::operator=(ControlMessage&& existing) {
95   // Enforce that the old ControlMessage is left empty, so it doesn't try to
96   // close any file descriptors. https://stackoverflow.com/a/17735913
97   data_ = std::move(existing.data_);
98   existing.data_.clear();
99   fds_ = std::move(existing.fds_);
100   existing.fds_.clear();
101   return *this;
102 }
103 
~ControlMessage()104 ControlMessage::~ControlMessage() {
105   for (const auto& fd : fds_) {
106     if (close(fd) != 0) {
107       PLOG(ERROR) << "Failed to close fd " << fd
108                   << ", may have leaked or closed prematurely";
109     }
110   }
111 }
112 
Raw()113 cmsghdr* ControlMessage::Raw() {
114   return reinterpret_cast<cmsghdr*>(data_.data());
115 }
116 
Raw() const117 const cmsghdr* ControlMessage::Raw() const {
118   return reinterpret_cast<const cmsghdr*>(data_.data());
119 }
120 
121 #ifdef __linux__
IsCredentials() const122 bool ControlMessage::IsCredentials() const {
123   bool right_level = Raw()->cmsg_level == SOL_SOCKET;
124   bool right_type = Raw()->cmsg_type == SCM_CREDENTIALS;
125   bool enough_data = Raw()->cmsg_len >= sizeof(cmsghdr) + sizeof(ucred);
126   return right_level && right_type && enough_data;
127 }
128 
AsCredentials() const129 Result<ucred> ControlMessage::AsCredentials() const {
130   CF_EXPECT(IsCredentials(), "Control message does not hold a credential");
131   ucred credentials;
132   memcpy(&credentials, CMSG_DATA(Raw()), sizeof(ucred));
133   return credentials;
134 }
135 #endif
136 
IsFileDescriptors() const137 bool ControlMessage::IsFileDescriptors() const {
138   bool right_level = Raw()->cmsg_level == SOL_SOCKET;
139   bool right_type = Raw()->cmsg_type == SCM_RIGHTS;
140   return right_level && right_type;
141 }
142 
AsSharedFDs() const143 Result<std::vector<SharedFD>> ControlMessage::AsSharedFDs() const {
144   CF_EXPECT(IsFileDescriptors(), "Message does not contain file descriptors");
145   size_t fdcount =
146       static_cast<size_t>(Raw()->cmsg_len - CMSG_LEN(0)) / sizeof(int);
147   std::vector<SharedFD> shared_fds;
148   for (int i = 0; i < fdcount; i++) {
149     // Use memcpy as CMSG_DATA may be unaligned
150     int fd = -1;
151     memcpy(&fd, CMSG_DATA(Raw()) + (i * sizeof(int)), sizeof(fd));
152     SharedFD shared_fd = SharedFD::Dup(fd);
153     CF_EXPECT(shared_fd->IsOpen(), "Could not dup FD " << fd);
154     shared_fds.push_back(shared_fd);
155   }
156   return shared_fds;
157 }
158 
HasFileDescriptors()159 bool UnixSocketMessage::HasFileDescriptors() {
160   for (const auto& control_message : control) {
161     if (control_message.IsFileDescriptors()) {
162       return true;
163     }
164   }
165   return false;
166 }
FileDescriptors()167 Result<std::vector<SharedFD>> UnixSocketMessage::FileDescriptors() {
168   std::vector<SharedFD> fds;
169   for (const auto& control_message : control) {
170     if (control_message.IsFileDescriptors()) {
171       auto additional_fds = CF_EXPECT(control_message.AsSharedFDs());
172       fds.insert(fds.end(), additional_fds.begin(), additional_fds.end());
173     }
174   }
175   return fds;
176 }
177 #ifdef __linux__
HasCredentials()178 bool UnixSocketMessage::HasCredentials() {
179   for (const auto& control_message : control) {
180     if (control_message.IsCredentials()) {
181       return true;
182     }
183   }
184   return false;
185 }
Credentials()186 Result<ucred> UnixSocketMessage::Credentials() {
187   std::vector<ucred> credentials;
188   for (const auto& control_message : control) {
189     if (control_message.IsCredentials()) {
190       auto creds = CF_EXPECT(control_message.AsCredentials(),
191                              "Message claims to have credentials but does not");
192       credentials.push_back(creds);
193     }
194   }
195   if (credentials.size() == 0) {
196     return CF_ERR("No credentials present");
197   } else if (credentials.size() == 1) {
198     return credentials[0];
199   } else {
200     return CF_ERR("Excepted 1 credential, received " << credentials.size());
201   }
202 }
203 #endif
204 
// Wraps `socket` for message-oriented I/O.
//
// The socket's SO_SNDBUF value is read into max_message_size_. ReadMessage
// uses that value to size both its data buffer and its ancillary (control)
// buffer. If the option cannot be read, the CHECK aborts the process.
// NOTE(review): CHECK stringifies its condition into the abort message, so the
// expression below is effectively part of the runtime output.
UnixMessageSocket::UnixMessageSocket(SharedFD socket) : socket_(socket) {
  socklen_t ln = sizeof(max_message_size_);
  CHECK(socket->GetSockOpt(SOL_SOCKET, SO_SNDBUF, &max_message_size_, &ln) == 0)
      << "error: can't retrieve socket max message size: "
      << socket->StrError();
}
211 
212 #ifdef __linux__
EnableCredentials(bool enable)213 Result<void> UnixMessageSocket::EnableCredentials(bool enable) {
214   int flag = enable ? 1 : 0;
215   if (socket_->SetSockOpt(SOL_SOCKET, SO_PASSCRED, &flag, sizeof(flag)) != 0) {
216     return CF_ERR("Could not set credential status to " << enable << ": "
217                                                         << socket_->StrError());
218   }
219   return {};
220 }
221 #endif
222 
WriteMessage(const UnixSocketMessage & message)223 Result<void> UnixMessageSocket::WriteMessage(const UnixSocketMessage& message) {
224   auto control_size = 0;
225   for (const auto& control : message.control) {
226     control_size += control.data_.size();
227   }
228   std::vector<char> message_control(control_size, 0);
229   msghdr message_header{};
230   message_header.msg_control = message_control.data();
231   message_header.msg_controllen = message_control.size();
232   auto cmsg = CMSG_FIRSTHDR(&message_header);
233   for (const ControlMessage& control : message.control) {
234     CF_EXPECT(cmsg != nullptr,
235               "Control messages did not fit in control buffer");
236     /* size() should match CMSG_SPACE */
237     memcpy(cmsg, control.data_.data(), control.data_.size());
238     cmsg = CMSG_NXTHDR(&message_header, cmsg);
239   }
240 
241   iovec message_iovec;
242   message_iovec.iov_base = (void*)message.data.data();
243   message_iovec.iov_len = message.data.size();
244   message_header.msg_name = nullptr;
245   message_header.msg_namelen = 0;
246   message_header.msg_iov = &message_iovec;
247   message_header.msg_iovlen = 1;
248   message_header.msg_flags = 0;
249 
250   auto bytes_sent = socket_->SendMsg(&message_header, MSG_NOSIGNAL);
251   CF_EXPECT(bytes_sent >= 0, "Failed to send message: " << socket_->StrError());
252   CF_EXPECT(bytes_sent == message.data.size(),
253             "Failed to send entire message. Sent "
254                 << bytes_sent << ", excepted to send " << message.data.size());
255   return {};
256 }
257 
ReadMessage()258 Result<UnixSocketMessage> UnixMessageSocket::ReadMessage() {
259   msghdr message_header{};
260   std::vector<char> message_control(max_message_size_, 0);
261   message_header.msg_control = message_control.data();
262   message_header.msg_controllen = message_control.size();
263   std::vector<char> message_data(max_message_size_, 0);
264   iovec message_iovec;
265   message_iovec.iov_base = message_data.data();
266   message_iovec.iov_len = message_data.size();
267   message_header.msg_iov = &message_iovec;
268   message_header.msg_iovlen = 1;
269   message_header.msg_name = nullptr;
270   message_header.msg_namelen = 0;
271   message_header.msg_flags = 0;
272 
273 #ifdef __linux__
274   auto bytes_read = socket_->RecvMsg(&message_header, MSG_CMSG_CLOEXEC);
275 #elif defined(__APPLE__)
276   auto bytes_read = socket_->RecvMsg(&message_header, 0);
277 #else
278 #error "Unsupported operating system"
279 #endif
280   CF_EXPECT(bytes_read >= 0, "Read error: " << socket_->StrError());
281   CF_EXPECT(!(message_header.msg_flags & MSG_TRUNC),
282             "Message was truncated on read");
283   CF_EXPECT(!(message_header.msg_flags & MSG_CTRUNC),
284             "Message control data was truncated on read");
285 #ifdef __linux__
286   CF_EXPECT(!(message_header.msg_flags & MSG_ERRQUEUE), "Error queue error");
287 #endif
288   UnixSocketMessage managed_message;
289   for (auto cmsg = CMSG_FIRSTHDR(&message_header); cmsg != nullptr;
290        cmsg = CMSG_NXTHDR(&message_header, cmsg)) {
291     managed_message.control.emplace_back(ControlMessage::FromRaw(cmsg));
292   }
293   message_data.resize(bytes_read);
294   managed_message.data = std::move(message_data);
295 
296   return managed_message;
297 }
298 
299 }  // namespace cuttlefish
300