//
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
15 
16 #include "host/frontend/webrtc/libdevice/server_connection.h"
17 
18 #include <android-base/logging.h>
19 #include <libwebsockets.h>
20 
21 #include "common/libs/fs/shared_fd.h"
22 #include "common/libs/fs/shared_select.h"
23 #include "common/libs/utils/files.h"
24 
25 namespace cuttlefish {
26 namespace webrtc_streaming {
27 
28 // ServerConnection over Unix socket
29 class UnixServerConnection : public ServerConnection {
30  public:
31   UnixServerConnection(const std::string& addr,
32                        std::weak_ptr<ServerConnectionObserver> observer);
33   ~UnixServerConnection() override;
34 
35   bool Send(const Json::Value& msg) override;
36 
37  private:
38   void Connect() override;
39   void StopThread();
40   void ReadLoop();
41 
42   const std::string addr_;
43   SharedFD conn_;
44   std::mutex write_mtx_;
45   std::weak_ptr<ServerConnectionObserver> observer_;
46   // The event fd must be declared before the thread to ensure it's initialized
47   // before the thread starts and is safe to be accessed from it.
48   SharedFD thread_notifier_;
49   std::atomic_bool running_ = false;
50   std::thread thread_;
51 };
52 
53 // ServerConnection using websockets
54 class WsConnectionContext;
55 
56 class WsConnection : public std::enable_shared_from_this<WsConnection> {
57  public:
58   struct CreateConnectionSul {
59     lws_sorted_usec_list_t sul = {};
60     std::weak_ptr<WsConnection> weak_this;
61   };
62 
63   WsConnection(int port, const std::string& addr, const std::string& path,
64                ServerConfig::Security secure,
65                std::weak_ptr<ServerConnectionObserver> observer,
66                std::shared_ptr<WsConnectionContext> context);
67 
68   ~WsConnection();
69 
70   void Connect();
71   bool Send(const Json::Value& msg);
72 
73   void ConnectInner();
74 
75   void OnError(const std::string& error);
76   void OnReceive(const uint8_t* data, size_t len, bool is_binary);
77   void OnOpen();
78   void OnClose();
79   void OnWriteable();
80 
81  private:
82   struct WsBuffer {
83     WsBuffer() = default;
WsBuffercuttlefish::webrtc_streaming::WsConnection::WsBuffer84     WsBuffer(const uint8_t* data, size_t len, bool binary)
85         : buffer_(LWS_PRE + len), is_binary_(binary) {
86       memcpy(&buffer_[LWS_PRE], data, len);
87     }
88 
datacuttlefish::webrtc_streaming::WsConnection::WsBuffer89     uint8_t* data() { return &buffer_[LWS_PRE]; }
is_binarycuttlefish::webrtc_streaming::WsConnection::WsBuffer90     bool is_binary() const { return is_binary_; }
sizecuttlefish::webrtc_streaming::WsConnection::WsBuffer91     size_t size() const { return buffer_.size() - LWS_PRE; }
92 
93    private:
94     std::vector<uint8_t> buffer_;
95     bool is_binary_;
96   };
97   bool Send(const uint8_t* data, size_t len, bool binary = false);
98 
99   CreateConnectionSul extended_sul_;
100   struct lws* wsi_;
101   const int port_;
102   const std::string addr_;
103   const std::string path_;
104   const ServerConfig::Security security_;
105 
106   std::weak_ptr<ServerConnectionObserver> observer_;
107 
108   // each element contains the data to be sent and whether it's binary or not
109   std::deque<WsBuffer> write_queue_;
110   std::mutex write_queue_mutex_;
111   // The connection object should not outlive the context object. This reference
112   // guarantees it.
113   std::shared_ptr<WsConnectionContext> context_;
114 };
115 
116 class WsConnectionContext
117     : public std::enable_shared_from_this<WsConnectionContext> {
118  public:
119   static std::shared_ptr<WsConnectionContext> Create();
120 
121   WsConnectionContext(struct lws_context* lws_ctx);
122   ~WsConnectionContext();
123 
124   std::unique_ptr<ServerConnection> CreateConnection(
125       int port, const std::string& addr, const std::string& path,
126       ServerConfig::Security secure,
127       std::weak_ptr<ServerConnectionObserver> observer);
128 
129   void RememberConnection(void*, std::weak_ptr<WsConnection>);
130   void ForgetConnection(void*);
131   std::shared_ptr<WsConnection> GetConnection(void*);
132 
lws_context()133   struct lws_context* lws_context() {
134     return lws_context_;
135   }
136 
137  private:
138   void Start();
139 
140   std::map<void*, std::weak_ptr<WsConnection>> weak_by_ptr_;
141   std::mutex map_mutex_;
142   struct lws_context* lws_context_;
143   std::thread message_loop_;
144 };
145 
Connect(const ServerConfig & conf,std::weak_ptr<ServerConnectionObserver> observer)146 std::unique_ptr<ServerConnection> ServerConnection::Connect(
147     const ServerConfig& conf,
148     std::weak_ptr<ServerConnectionObserver> observer) {
149   std::unique_ptr<ServerConnection> ret;
150   // If the provided address points to an existing UNIX socket in the file
151   // system connect to it, otherwise assume it's a network address and connect
152   // using websockets
153   if (FileIsSocket(conf.addr)) {
154     ret.reset(new UnixServerConnection(conf.addr, observer));
155   } else {
156     // This can be a local variable since the ws connection will keep a
157     // reference to it.
158     auto ws_context = WsConnectionContext::Create();
159     CHECK(ws_context) << "Failed to create websocket context";
160     ret = ws_context->CreateConnection(conf.port, conf.addr, conf.path,
161                                        conf.security, observer);
162   }
163   ret->Connect();
164   return ret;
165 }
166 
Reconnect()167 void ServerConnection::Reconnect() { Connect(); }
168 
169 // UnixServerConnection implementation
170 
UnixServerConnection(const std::string & addr,std::weak_ptr<ServerConnectionObserver> observer)171 UnixServerConnection::UnixServerConnection(
172     const std::string& addr, std::weak_ptr<ServerConnectionObserver> observer)
173     : addr_(addr), observer_(observer) {}
174 
~UnixServerConnection()175 UnixServerConnection::~UnixServerConnection() {
176   StopThread();
177 }
178 
Send(const Json::Value & msg)179 bool UnixServerConnection::Send(const Json::Value& msg) {
180   Json::StreamWriterBuilder factory;
181   auto str = Json::writeString(factory, msg);
182   std::lock_guard<std::mutex> lock(write_mtx_);
183   auto res =
184       conn_->Send(reinterpret_cast<const uint8_t*>(str.c_str()), str.size(), 0);
185   if (res < 0) {
186     LOG(ERROR) << "Failed to send data to signaling server: "
187                << conn_->StrError();
188     // Don't call OnError() here, the receiving thread probably did it already
189     // or is about to do it.
190   }
191   // A SOCK_SEQPACKET unix socket will send the entire message or fail, but it
192   // won't send a partial message.
193   return res == str.size();
194 }
195 
Connect()196 void UnixServerConnection::Connect() {
197   // The thread could be running if this is a Reconnect
198   StopThread();
199 
200   conn_ = SharedFD::SocketLocalClient(addr_, false, SOCK_SEQPACKET);
201   if (!conn_->IsOpen()) {
202     LOG(ERROR) << "Failed to connect to unix socket: " << conn_->StrError();
203     if (auto o = observer_.lock(); o) {
204       o->OnError("Failed to connect to unix socket");
205     }
206     return;
207   }
208   thread_notifier_ = SharedFD::Event();
209   if (!thread_notifier_->IsOpen()) {
210     LOG(ERROR) << "Failed to create eventfd for background thread: "
211                << thread_notifier_->StrError();
212     if (auto o = observer_.lock(); o) {
213       o->OnError("Failed to create eventfd for background thread");
214     }
215     return;
216   }
217   if (auto o = observer_.lock(); o) {
218     o->OnOpen();
219   }
220   // Start the thread
221   running_ = true;
222   thread_ = std::thread([this](){ReadLoop();});
223 }
224 
StopThread()225 void UnixServerConnection::StopThread() {
226   running_ = false;
227   if (!thread_notifier_->IsOpen()) {
228     // The thread won't be running if this isn't open
229     return;
230   }
231   if (thread_notifier_->EventfdWrite(1) < 0) {
232     LOG(ERROR) << "Failed to notify background thread, this thread may block";
233   }
234   if (thread_.joinable()) {
235     thread_.join();
236   }
237 }
238 
ReadLoop()239 void UnixServerConnection::ReadLoop() {
240   if (!thread_notifier_->IsOpen()) {
241     LOG(ERROR) << "The UnixServerConnection's background thread is unable to "
242                   "receive notifications so it can't run";
243     return;
244   }
245   std::vector<uint8_t> buffer(4096, 0);
246   while (running_) {
247     SharedFDSet rset;
248     rset.Set(thread_notifier_);
249     rset.Set(conn_);
250     auto res = Select(&rset, nullptr, nullptr, nullptr);
251     if (res < 0) {
252       LOG(ERROR) << "Failed to select from background thread";
253       break;
254     }
255     if (rset.IsSet(thread_notifier_)) {
256       eventfd_t val;
257       auto res = thread_notifier_->EventfdRead(&val);
258       if (res < 0) {
259         LOG(ERROR) << "Error reading from event fd: "
260                    << thread_notifier_->StrError();
261         break;
262       }
263     }
264     if (rset.IsSet(conn_)) {
265       auto size = conn_->Recv(buffer.data(), 0, MSG_TRUNC | MSG_PEEK);
266       if (size > buffer.size()) {
267         // Enlarge enough to accommodate size bytes and be a multiple of 4096
268         auto new_size = (size + 4095) & ~4095;
269         buffer.resize(new_size);
270       }
271       auto res = conn_->Recv(buffer.data(), buffer.size(), MSG_TRUNC);
272       if (res < 0) {
273         LOG(ERROR) << "Failed to read from server: " << conn_->StrError();
274         if (auto observer = observer_.lock(); observer) {
275           observer->OnError(conn_->StrError());
276         }
277         return;
278       }
279       if (res == 0) {
280         auto observer = observer_.lock();
281         if (observer) {
282           observer->OnClose();
283         }
284         break;
285       }
286       auto observer = observer_.lock();
287       if (observer) {
288         observer->OnReceive(buffer.data(), res, false);
289       }
290     }
291   }
292 }
293 
294 // WsConnection implementation
295 
296 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
297                 void* in, size_t len);
298 void CreateConnectionCallback(lws_sorted_usec_list_t* sul);
299 
300 namespace {
301 
302 constexpr char kProtocolName[] = "cf-webrtc-device";
303 constexpr int kBufferSize = 65536;
304 
305 const uint32_t backoff_ms[] = {1000, 2000, 3000, 4000, 5000};
306 
307 const lws_retry_bo_t kRetry = {
308     .retry_ms_table = backoff_ms,
309     .retry_ms_table_count = LWS_ARRAY_SIZE(backoff_ms),
310     .conceal_count = LWS_ARRAY_SIZE(backoff_ms),
311 
312     .secs_since_valid_ping = 3,    /* force PINGs after secs idle */
313     .secs_since_valid_hangup = 10, /* hangup after secs idle */
314 
315     .jitter_percent = 20,
316 };
317 
318 const struct lws_protocols kProtocols[2] = {
319     {kProtocolName, LwsCallback, 0, kBufferSize, 0, NULL, 0},
320     {NULL, NULL, 0, 0, 0, NULL, 0}};
321 
322 }  // namespace
323 
Create()324 std::shared_ptr<WsConnectionContext> WsConnectionContext::Create() {
325   struct lws_context_creation_info context_info = {};
326   context_info.port = CONTEXT_PORT_NO_LISTEN;
327   context_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
328   context_info.protocols = kProtocols;
329   struct lws_context* lws_ctx = lws_create_context(&context_info);
330   if (!lws_ctx) {
331     return nullptr;
332   }
333   return std::shared_ptr<WsConnectionContext>(new WsConnectionContext(lws_ctx));
334 }
335 
WsConnectionContext(struct lws_context * lws_ctx)336 WsConnectionContext::WsConnectionContext(struct lws_context* lws_ctx)
337     : lws_context_(lws_ctx) {
338   Start();
339 }
340 
~WsConnectionContext()341 WsConnectionContext::~WsConnectionContext() {
342   lws_context_destroy(lws_context_);
343   if (message_loop_.joinable()) {
344     message_loop_.join();
345   }
346 }
347 
Start()348 void WsConnectionContext::Start() {
349   message_loop_ = std::thread([this]() {
350     for (;;) {
351       if (lws_service(lws_context_, 0) < 0) {
352         break;
353       }
354     }
355   });
356 }
357 
358 // This wrapper is needed because the ServerConnection objects are meant to be
359 // referenced by std::unique_ptr but WsConnection needs to be referenced by
360 // std::shared_ptr because it's also (weakly) referenced by the websocket
361 // thread.
362 class WsConnectionWrapper : public ServerConnection {
363  public:
WsConnectionWrapper(std::shared_ptr<WsConnection> conn)364   WsConnectionWrapper(std::shared_ptr<WsConnection> conn) : conn_(conn) {}
365 
Send(const Json::Value & msg)366   bool Send(const Json::Value& msg) override { return conn_->Send(msg); }
367 
368  private:
Connect()369   void Connect() override { return conn_->Connect(); }
370   std::shared_ptr<WsConnection> conn_;
371 };
372 
CreateConnection(int port,const std::string & addr,const std::string & path,ServerConfig::Security security,std::weak_ptr<ServerConnectionObserver> observer)373 std::unique_ptr<ServerConnection> WsConnectionContext::CreateConnection(
374     int port, const std::string& addr, const std::string& path,
375     ServerConfig::Security security,
376     std::weak_ptr<ServerConnectionObserver> observer) {
377   return std::unique_ptr<ServerConnection>(
378       new WsConnectionWrapper(std::make_shared<WsConnection>(
379           port, addr, path, security, observer, shared_from_this())));
380 }
381 
GetConnection(void * raw)382 std::shared_ptr<WsConnection> WsConnectionContext::GetConnection(void* raw) {
383   std::shared_ptr<WsConnection> connection;
384   {
385     std::lock_guard<std::mutex> lock(map_mutex_);
386     if (weak_by_ptr_.count(raw) == 0) {
387       return nullptr;
388     }
389     connection = weak_by_ptr_[raw].lock();
390     if (!connection) {
391       weak_by_ptr_.erase(raw);
392     }
393   }
394   return connection;
395 }
396 
RememberConnection(void * raw,std::weak_ptr<WsConnection> conn)397 void WsConnectionContext::RememberConnection(void* raw,
398                                              std::weak_ptr<WsConnection> conn) {
399   std::lock_guard<std::mutex> lock(map_mutex_);
400   weak_by_ptr_.emplace(
401       std::pair<void*, std::weak_ptr<WsConnection>>(raw, conn));
402 }
403 
ForgetConnection(void * raw)404 void WsConnectionContext::ForgetConnection(void* raw) {
405   std::lock_guard<std::mutex> lock(map_mutex_);
406   weak_by_ptr_.erase(raw);
407 }
408 
WsConnection(int port,const std::string & addr,const std::string & path,ServerConfig::Security security,std::weak_ptr<ServerConnectionObserver> observer,std::shared_ptr<WsConnectionContext> context)409 WsConnection::WsConnection(int port, const std::string& addr,
410                            const std::string& path,
411                            ServerConfig::Security security,
412                            std::weak_ptr<ServerConnectionObserver> observer,
413                            std::shared_ptr<WsConnectionContext> context)
414     : port_(port),
415       addr_(addr),
416       path_(path),
417       security_(security),
418       observer_(observer),
419       context_(context) {}
420 
~WsConnection()421 WsConnection::~WsConnection() {
422   context_->ForgetConnection(this);
423   // This will cause the callback to be called which will drop the connection
424   // after seeing the context doesn't remember this object
425   lws_callback_on_writable(wsi_);
426 }
427 
Connect()428 void WsConnection::Connect() {
429   memset(&extended_sul_.sul, 0, sizeof(extended_sul_.sul));
430   extended_sul_.weak_this = weak_from_this();
431   lws_sul_schedule(context_->lws_context(), 0, &extended_sul_.sul,
432                    CreateConnectionCallback, 1);
433 }
434 
OnError(const std::string & error)435 void WsConnection::OnError(const std::string& error) {
436   auto observer = observer_.lock();
437   if (observer) {
438     observer->OnError(error);
439   }
440 }
OnReceive(const uint8_t * data,size_t len,bool is_binary)441 void WsConnection::OnReceive(const uint8_t* data, size_t len, bool is_binary) {
442   auto observer = observer_.lock();
443   if (observer) {
444     observer->OnReceive(data, len, is_binary);
445   }
446 }
OnOpen()447 void WsConnection::OnOpen() {
448   auto observer = observer_.lock();
449   if (observer) {
450     observer->OnOpen();
451   }
452 }
OnClose()453 void WsConnection::OnClose() {
454   auto observer = observer_.lock();
455   if (observer) {
456     observer->OnClose();
457   }
458 }
459 
OnWriteable()460 void WsConnection::OnWriteable() {
461   WsBuffer buffer;
462   {
463     std::lock_guard<std::mutex> lock(write_queue_mutex_);
464     if (write_queue_.size() == 0) {
465       return;
466     }
467     buffer = std::move(write_queue_.front());
468     write_queue_.pop_front();
469   }
470   auto flags = lws_write_ws_flags(
471       buffer.is_binary() ? LWS_WRITE_BINARY : LWS_WRITE_TEXT, true, true);
472   auto res = lws_write(wsi_, buffer.data(), buffer.size(),
473                        (enum lws_write_protocol)flags);
474   if (res != buffer.size()) {
475     LOG(WARNING) << "Unable to send the entire message!";
476   }
477 }
478 
Send(const Json::Value & msg)479 bool WsConnection::Send(const Json::Value& msg) {
480   Json::StreamWriterBuilder factory;
481   auto str = Json::writeString(factory, msg);
482   return Send(reinterpret_cast<const uint8_t*>(str.c_str()), str.size());
483 }
484 
Send(const uint8_t * data,size_t len,bool binary)485 bool WsConnection::Send(const uint8_t* data, size_t len, bool binary) {
486   if (!wsi_) {
487     LOG(WARNING) << "Send called on an uninitialized connection!!";
488     return false;
489   }
490   WsBuffer buffer(data, len, binary);
491   {
492     std::lock_guard<std::mutex> lock(write_queue_mutex_);
493     write_queue_.emplace_back(std::move(buffer));
494   }
495 
496   lws_callback_on_writable(wsi_);
497   return true;
498 }
499 
LwsCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)500 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
501                 void* in, size_t len) {
502   constexpr int DROP = -1;
503   constexpr int OK = 0;
504 
505   // For some values of `reason`, `user` doesn't point to the value provided
506   // when the connection was created. This function object should be used with
507   // care.
508   auto with_connection =
509       [wsi, user](std::function<void(std::shared_ptr<WsConnection>)> cb) {
510         auto context = reinterpret_cast<WsConnectionContext*>(user);
511         auto connection = context->GetConnection(wsi);
512         if (!connection) {
513           return DROP;
514         }
515         cb(connection);
516         return OK;
517       };
518 
519   switch (reason) {
520     case LWS_CALLBACK_CLIENT_CONNECTION_ERROR:
521       return with_connection([in](std::shared_ptr<WsConnection> connection) {
522         connection->OnError(in ? (char*)in : "(null)");
523       });
524 
525     case LWS_CALLBACK_CLIENT_RECEIVE:
526       return with_connection(
527           [in, len, wsi](std::shared_ptr<WsConnection> connection) {
528             connection->OnReceive((const uint8_t*)in, len,
529                                   lws_frame_is_binary(wsi));
530           });
531 
532     case LWS_CALLBACK_CLIENT_ESTABLISHED:
533       return with_connection([](std::shared_ptr<WsConnection> connection) {
534         connection->OnOpen();
535       });
536 
537     case LWS_CALLBACK_CLIENT_CLOSED:
538       return with_connection([](std::shared_ptr<WsConnection> connection) {
539         connection->OnClose();
540       });
541 
542     case LWS_CALLBACK_CLIENT_WRITEABLE:
543       return with_connection([](std::shared_ptr<WsConnection> connection) {
544         connection->OnWriteable();
545       });
546 
547     default:
548       LOG(VERBOSE) << "Unhandled value: " << reason;
549       return lws_callback_http_dummy(wsi, reason, user, in, len);
550   }
551 }
552 
CreateConnectionCallback(lws_sorted_usec_list_t * sul)553 void CreateConnectionCallback(lws_sorted_usec_list_t* sul) {
554   std::shared_ptr<WsConnection> connection =
555       reinterpret_cast<WsConnection::CreateConnectionSul*>(sul)
556           ->weak_this.lock();
557   if (!connection) {
558     LOG(WARNING) << "The object was already destroyed by the time of the first "
559                  << "connection attempt. That's unusual.";
560     return;
561   }
562   connection->ConnectInner();
563 }
564 
ConnectInner()565 void WsConnection::ConnectInner() {
566   struct lws_client_connect_info connect_info;
567 
568   memset(&connect_info, 0, sizeof(connect_info));
569 
570   connect_info.context = context_->lws_context();
571   connect_info.port = port_;
572   connect_info.address = addr_.c_str();
573   connect_info.path = path_.c_str();
574   connect_info.host = connect_info.address;
575   connect_info.origin = connect_info.address;
576   switch (security_) {
577     case ServerConfig::Security::kAllowSelfSigned:
578       connect_info.ssl_connection = LCCSCF_ALLOW_SELFSIGNED |
579                                     LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK |
580                                     LCCSCF_USE_SSL;
581       break;
582     case ServerConfig::Security::kStrict:
583       connect_info.ssl_connection = LCCSCF_USE_SSL;
584       break;
585     case ServerConfig::Security::kInsecure:
586       connect_info.ssl_connection = 0;
587       break;
588   }
589   connect_info.protocol = "webrtc-operator";
590   connect_info.local_protocol_name = kProtocolName;
591   connect_info.pwsi = &wsi_;
592   connect_info.retry_and_idle_policy = &kRetry;
593   // There is no guarantee the connection object still exists when the callback
594   // is called. Put the context instead as the user data which is guaranteed to
595   // still exist and holds a weak ptr to the connection.
596   connect_info.userdata = context_.get();
597 
598   if (lws_client_connect_via_info(&connect_info)) {
599     // wsi_ is not initialized until after the call to
600     // lws_client_connect_via_info(). Luckily, this is guaranteed to run before
601     // the protocol callback is called because it runs in the same loop.
602     context_->RememberConnection(wsi_, weak_from_this());
603   } else {
604     LOG(ERROR) << "Connection failed!";
605   }
606 }
607 
608 }  // namespace webrtc_streaming
609 }  // namespace cuttlefish
610