1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15
16 #include "host/frontend/webrtc/libdevice/server_connection.h"
17
18 #include <android-base/logging.h>
19 #include <libwebsockets.h>
20
21 #include "common/libs/fs/shared_fd.h"
22 #include "common/libs/fs/shared_select.h"
23 #include "common/libs/utils/files.h"
24
25 namespace cuttlefish {
26 namespace webrtc_streaming {
27
28 // ServerConnection over Unix socket
29 class UnixServerConnection : public ServerConnection {
30 public:
31 UnixServerConnection(const std::string& addr,
32 std::weak_ptr<ServerConnectionObserver> observer);
33 ~UnixServerConnection() override;
34
35 bool Send(const Json::Value& msg) override;
36
37 private:
38 void Connect() override;
39 void StopThread();
40 void ReadLoop();
41
42 const std::string addr_;
43 SharedFD conn_;
44 std::mutex write_mtx_;
45 std::weak_ptr<ServerConnectionObserver> observer_;
46 // The event fd must be declared before the thread to ensure it's initialized
47 // before the thread starts and is safe to be accessed from it.
48 SharedFD thread_notifier_;
49 std::atomic_bool running_ = false;
50 std::thread thread_;
51 };
52
53 // ServerConnection using websockets
54 class WsConnectionContext;
55
56 class WsConnection : public std::enable_shared_from_this<WsConnection> {
57 public:
58 struct CreateConnectionSul {
59 lws_sorted_usec_list_t sul = {};
60 std::weak_ptr<WsConnection> weak_this;
61 };
62
63 WsConnection(int port, const std::string& addr, const std::string& path,
64 ServerConfig::Security secure,
65 std::weak_ptr<ServerConnectionObserver> observer,
66 std::shared_ptr<WsConnectionContext> context);
67
68 ~WsConnection();
69
70 void Connect();
71 bool Send(const Json::Value& msg);
72
73 void ConnectInner();
74
75 void OnError(const std::string& error);
76 void OnReceive(const uint8_t* data, size_t len, bool is_binary);
77 void OnOpen();
78 void OnClose();
79 void OnWriteable();
80
81 private:
82 struct WsBuffer {
83 WsBuffer() = default;
WsBuffercuttlefish::webrtc_streaming::WsConnection::WsBuffer84 WsBuffer(const uint8_t* data, size_t len, bool binary)
85 : buffer_(LWS_PRE + len), is_binary_(binary) {
86 memcpy(&buffer_[LWS_PRE], data, len);
87 }
88
datacuttlefish::webrtc_streaming::WsConnection::WsBuffer89 uint8_t* data() { return &buffer_[LWS_PRE]; }
is_binarycuttlefish::webrtc_streaming::WsConnection::WsBuffer90 bool is_binary() const { return is_binary_; }
sizecuttlefish::webrtc_streaming::WsConnection::WsBuffer91 size_t size() const { return buffer_.size() - LWS_PRE; }
92
93 private:
94 std::vector<uint8_t> buffer_;
95 bool is_binary_;
96 };
97 bool Send(const uint8_t* data, size_t len, bool binary = false);
98
99 CreateConnectionSul extended_sul_;
100 struct lws* wsi_;
101 const int port_;
102 const std::string addr_;
103 const std::string path_;
104 const ServerConfig::Security security_;
105
106 std::weak_ptr<ServerConnectionObserver> observer_;
107
108 // each element contains the data to be sent and whether it's binary or not
109 std::deque<WsBuffer> write_queue_;
110 std::mutex write_queue_mutex_;
111 // The connection object should not outlive the context object. This reference
112 // guarantees it.
113 std::shared_ptr<WsConnectionContext> context_;
114 };
115
116 class WsConnectionContext
117 : public std::enable_shared_from_this<WsConnectionContext> {
118 public:
119 static std::shared_ptr<WsConnectionContext> Create();
120
121 WsConnectionContext(struct lws_context* lws_ctx);
122 ~WsConnectionContext();
123
124 std::unique_ptr<ServerConnection> CreateConnection(
125 int port, const std::string& addr, const std::string& path,
126 ServerConfig::Security secure,
127 std::weak_ptr<ServerConnectionObserver> observer);
128
129 void RememberConnection(void*, std::weak_ptr<WsConnection>);
130 void ForgetConnection(void*);
131 std::shared_ptr<WsConnection> GetConnection(void*);
132
lws_context()133 struct lws_context* lws_context() {
134 return lws_context_;
135 }
136
137 private:
138 void Start();
139
140 std::map<void*, std::weak_ptr<WsConnection>> weak_by_ptr_;
141 std::mutex map_mutex_;
142 struct lws_context* lws_context_;
143 std::thread message_loop_;
144 };
145
Connect(const ServerConfig & conf,std::weak_ptr<ServerConnectionObserver> observer)146 std::unique_ptr<ServerConnection> ServerConnection::Connect(
147 const ServerConfig& conf,
148 std::weak_ptr<ServerConnectionObserver> observer) {
149 std::unique_ptr<ServerConnection> ret;
150 // If the provided address points to an existing UNIX socket in the file
151 // system connect to it, otherwise assume it's a network address and connect
152 // using websockets
153 if (FileIsSocket(conf.addr)) {
154 ret.reset(new UnixServerConnection(conf.addr, observer));
155 } else {
156 // This can be a local variable since the ws connection will keep a
157 // reference to it.
158 auto ws_context = WsConnectionContext::Create();
159 CHECK(ws_context) << "Failed to create websocket context";
160 ret = ws_context->CreateConnection(conf.port, conf.addr, conf.path,
161 conf.security, observer);
162 }
163 ret->Connect();
164 return ret;
165 }
166
Reconnect()167 void ServerConnection::Reconnect() { Connect(); }
168
169 // UnixServerConnection implementation
170
UnixServerConnection(const std::string & addr,std::weak_ptr<ServerConnectionObserver> observer)171 UnixServerConnection::UnixServerConnection(
172 const std::string& addr, std::weak_ptr<ServerConnectionObserver> observer)
173 : addr_(addr), observer_(observer) {}
174
~UnixServerConnection()175 UnixServerConnection::~UnixServerConnection() {
176 StopThread();
177 }
178
Send(const Json::Value & msg)179 bool UnixServerConnection::Send(const Json::Value& msg) {
180 Json::StreamWriterBuilder factory;
181 auto str = Json::writeString(factory, msg);
182 std::lock_guard<std::mutex> lock(write_mtx_);
183 auto res =
184 conn_->Send(reinterpret_cast<const uint8_t*>(str.c_str()), str.size(), 0);
185 if (res < 0) {
186 LOG(ERROR) << "Failed to send data to signaling server: "
187 << conn_->StrError();
188 // Don't call OnError() here, the receiving thread probably did it already
189 // or is about to do it.
190 }
191 // A SOCK_SEQPACKET unix socket will send the entire message or fail, but it
192 // won't send a partial message.
193 return res == str.size();
194 }
195
Connect()196 void UnixServerConnection::Connect() {
197 // The thread could be running if this is a Reconnect
198 StopThread();
199
200 conn_ = SharedFD::SocketLocalClient(addr_, false, SOCK_SEQPACKET);
201 if (!conn_->IsOpen()) {
202 LOG(ERROR) << "Failed to connect to unix socket: " << conn_->StrError();
203 if (auto o = observer_.lock(); o) {
204 o->OnError("Failed to connect to unix socket");
205 }
206 return;
207 }
208 thread_notifier_ = SharedFD::Event();
209 if (!thread_notifier_->IsOpen()) {
210 LOG(ERROR) << "Failed to create eventfd for background thread: "
211 << thread_notifier_->StrError();
212 if (auto o = observer_.lock(); o) {
213 o->OnError("Failed to create eventfd for background thread");
214 }
215 return;
216 }
217 if (auto o = observer_.lock(); o) {
218 o->OnOpen();
219 }
220 // Start the thread
221 running_ = true;
222 thread_ = std::thread([this](){ReadLoop();});
223 }
224
StopThread()225 void UnixServerConnection::StopThread() {
226 running_ = false;
227 if (!thread_notifier_->IsOpen()) {
228 // The thread won't be running if this isn't open
229 return;
230 }
231 if (thread_notifier_->EventfdWrite(1) < 0) {
232 LOG(ERROR) << "Failed to notify background thread, this thread may block";
233 }
234 if (thread_.joinable()) {
235 thread_.join();
236 }
237 }
238
ReadLoop()239 void UnixServerConnection::ReadLoop() {
240 if (!thread_notifier_->IsOpen()) {
241 LOG(ERROR) << "The UnixServerConnection's background thread is unable to "
242 "receive notifications so it can't run";
243 return;
244 }
245 std::vector<uint8_t> buffer(4096, 0);
246 while (running_) {
247 SharedFDSet rset;
248 rset.Set(thread_notifier_);
249 rset.Set(conn_);
250 auto res = Select(&rset, nullptr, nullptr, nullptr);
251 if (res < 0) {
252 LOG(ERROR) << "Failed to select from background thread";
253 break;
254 }
255 if (rset.IsSet(thread_notifier_)) {
256 eventfd_t val;
257 auto res = thread_notifier_->EventfdRead(&val);
258 if (res < 0) {
259 LOG(ERROR) << "Error reading from event fd: "
260 << thread_notifier_->StrError();
261 break;
262 }
263 }
264 if (rset.IsSet(conn_)) {
265 auto size = conn_->Recv(buffer.data(), 0, MSG_TRUNC | MSG_PEEK);
266 if (size > buffer.size()) {
267 // Enlarge enough to accommodate size bytes and be a multiple of 4096
268 auto new_size = (size + 4095) & ~4095;
269 buffer.resize(new_size);
270 }
271 auto res = conn_->Recv(buffer.data(), buffer.size(), MSG_TRUNC);
272 if (res < 0) {
273 LOG(ERROR) << "Failed to read from server: " << conn_->StrError();
274 if (auto observer = observer_.lock(); observer) {
275 observer->OnError(conn_->StrError());
276 }
277 return;
278 }
279 if (res == 0) {
280 auto observer = observer_.lock();
281 if (observer) {
282 observer->OnClose();
283 }
284 break;
285 }
286 auto observer = observer_.lock();
287 if (observer) {
288 observer->OnReceive(buffer.data(), res, false);
289 }
290 }
291 }
292 }
293
294 // WsConnection implementation
295
296 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
297 void* in, size_t len);
298 void CreateConnectionCallback(lws_sorted_usec_list_t* sul);
299
300 namespace {
301
302 constexpr char kProtocolName[] = "cf-webrtc-device";
303 constexpr int kBufferSize = 65536;
304
305 const uint32_t backoff_ms[] = {1000, 2000, 3000, 4000, 5000};
306
307 const lws_retry_bo_t kRetry = {
308 .retry_ms_table = backoff_ms,
309 .retry_ms_table_count = LWS_ARRAY_SIZE(backoff_ms),
310 .conceal_count = LWS_ARRAY_SIZE(backoff_ms),
311
312 .secs_since_valid_ping = 3, /* force PINGs after secs idle */
313 .secs_since_valid_hangup = 10, /* hangup after secs idle */
314
315 .jitter_percent = 20,
316 };
317
318 const struct lws_protocols kProtocols[2] = {
319 {kProtocolName, LwsCallback, 0, kBufferSize, 0, NULL, 0},
320 {NULL, NULL, 0, 0, 0, NULL, 0}};
321
322 } // namespace
323
Create()324 std::shared_ptr<WsConnectionContext> WsConnectionContext::Create() {
325 struct lws_context_creation_info context_info = {};
326 context_info.port = CONTEXT_PORT_NO_LISTEN;
327 context_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
328 context_info.protocols = kProtocols;
329 struct lws_context* lws_ctx = lws_create_context(&context_info);
330 if (!lws_ctx) {
331 return nullptr;
332 }
333 return std::shared_ptr<WsConnectionContext>(new WsConnectionContext(lws_ctx));
334 }
335
WsConnectionContext(struct lws_context * lws_ctx)336 WsConnectionContext::WsConnectionContext(struct lws_context* lws_ctx)
337 : lws_context_(lws_ctx) {
338 Start();
339 }
340
~WsConnectionContext()341 WsConnectionContext::~WsConnectionContext() {
342 lws_context_destroy(lws_context_);
343 if (message_loop_.joinable()) {
344 message_loop_.join();
345 }
346 }
347
Start()348 void WsConnectionContext::Start() {
349 message_loop_ = std::thread([this]() {
350 for (;;) {
351 if (lws_service(lws_context_, 0) < 0) {
352 break;
353 }
354 }
355 });
356 }
357
358 // This wrapper is needed because the ServerConnection objects are meant to be
359 // referenced by std::unique_ptr but WsConnection needs to be referenced by
360 // std::shared_ptr because it's also (weakly) referenced by the websocket
361 // thread.
362 class WsConnectionWrapper : public ServerConnection {
363 public:
WsConnectionWrapper(std::shared_ptr<WsConnection> conn)364 WsConnectionWrapper(std::shared_ptr<WsConnection> conn) : conn_(conn) {}
365
Send(const Json::Value & msg)366 bool Send(const Json::Value& msg) override { return conn_->Send(msg); }
367
368 private:
Connect()369 void Connect() override { return conn_->Connect(); }
370 std::shared_ptr<WsConnection> conn_;
371 };
372
CreateConnection(int port,const std::string & addr,const std::string & path,ServerConfig::Security security,std::weak_ptr<ServerConnectionObserver> observer)373 std::unique_ptr<ServerConnection> WsConnectionContext::CreateConnection(
374 int port, const std::string& addr, const std::string& path,
375 ServerConfig::Security security,
376 std::weak_ptr<ServerConnectionObserver> observer) {
377 return std::unique_ptr<ServerConnection>(
378 new WsConnectionWrapper(std::make_shared<WsConnection>(
379 port, addr, path, security, observer, shared_from_this())));
380 }
381
GetConnection(void * raw)382 std::shared_ptr<WsConnection> WsConnectionContext::GetConnection(void* raw) {
383 std::shared_ptr<WsConnection> connection;
384 {
385 std::lock_guard<std::mutex> lock(map_mutex_);
386 if (weak_by_ptr_.count(raw) == 0) {
387 return nullptr;
388 }
389 connection = weak_by_ptr_[raw].lock();
390 if (!connection) {
391 weak_by_ptr_.erase(raw);
392 }
393 }
394 return connection;
395 }
396
RememberConnection(void * raw,std::weak_ptr<WsConnection> conn)397 void WsConnectionContext::RememberConnection(void* raw,
398 std::weak_ptr<WsConnection> conn) {
399 std::lock_guard<std::mutex> lock(map_mutex_);
400 weak_by_ptr_.emplace(
401 std::pair<void*, std::weak_ptr<WsConnection>>(raw, conn));
402 }
403
ForgetConnection(void * raw)404 void WsConnectionContext::ForgetConnection(void* raw) {
405 std::lock_guard<std::mutex> lock(map_mutex_);
406 weak_by_ptr_.erase(raw);
407 }
408
WsConnection(int port,const std::string & addr,const std::string & path,ServerConfig::Security security,std::weak_ptr<ServerConnectionObserver> observer,std::shared_ptr<WsConnectionContext> context)409 WsConnection::WsConnection(int port, const std::string& addr,
410 const std::string& path,
411 ServerConfig::Security security,
412 std::weak_ptr<ServerConnectionObserver> observer,
413 std::shared_ptr<WsConnectionContext> context)
414 : port_(port),
415 addr_(addr),
416 path_(path),
417 security_(security),
418 observer_(observer),
419 context_(context) {}
420
~WsConnection()421 WsConnection::~WsConnection() {
422 context_->ForgetConnection(this);
423 // This will cause the callback to be called which will drop the connection
424 // after seeing the context doesn't remember this object
425 lws_callback_on_writable(wsi_);
426 }
427
Connect()428 void WsConnection::Connect() {
429 memset(&extended_sul_.sul, 0, sizeof(extended_sul_.sul));
430 extended_sul_.weak_this = weak_from_this();
431 lws_sul_schedule(context_->lws_context(), 0, &extended_sul_.sul,
432 CreateConnectionCallback, 1);
433 }
434
OnError(const std::string & error)435 void WsConnection::OnError(const std::string& error) {
436 auto observer = observer_.lock();
437 if (observer) {
438 observer->OnError(error);
439 }
440 }
OnReceive(const uint8_t * data,size_t len,bool is_binary)441 void WsConnection::OnReceive(const uint8_t* data, size_t len, bool is_binary) {
442 auto observer = observer_.lock();
443 if (observer) {
444 observer->OnReceive(data, len, is_binary);
445 }
446 }
OnOpen()447 void WsConnection::OnOpen() {
448 auto observer = observer_.lock();
449 if (observer) {
450 observer->OnOpen();
451 }
452 }
OnClose()453 void WsConnection::OnClose() {
454 auto observer = observer_.lock();
455 if (observer) {
456 observer->OnClose();
457 }
458 }
459
OnWriteable()460 void WsConnection::OnWriteable() {
461 WsBuffer buffer;
462 {
463 std::lock_guard<std::mutex> lock(write_queue_mutex_);
464 if (write_queue_.size() == 0) {
465 return;
466 }
467 buffer = std::move(write_queue_.front());
468 write_queue_.pop_front();
469 }
470 auto flags = lws_write_ws_flags(
471 buffer.is_binary() ? LWS_WRITE_BINARY : LWS_WRITE_TEXT, true, true);
472 auto res = lws_write(wsi_, buffer.data(), buffer.size(),
473 (enum lws_write_protocol)flags);
474 if (res != buffer.size()) {
475 LOG(WARNING) << "Unable to send the entire message!";
476 }
477 }
478
Send(const Json::Value & msg)479 bool WsConnection::Send(const Json::Value& msg) {
480 Json::StreamWriterBuilder factory;
481 auto str = Json::writeString(factory, msg);
482 return Send(reinterpret_cast<const uint8_t*>(str.c_str()), str.size());
483 }
484
Send(const uint8_t * data,size_t len,bool binary)485 bool WsConnection::Send(const uint8_t* data, size_t len, bool binary) {
486 if (!wsi_) {
487 LOG(WARNING) << "Send called on an uninitialized connection!!";
488 return false;
489 }
490 WsBuffer buffer(data, len, binary);
491 {
492 std::lock_guard<std::mutex> lock(write_queue_mutex_);
493 write_queue_.emplace_back(std::move(buffer));
494 }
495
496 lws_callback_on_writable(wsi_);
497 return true;
498 }
499
LwsCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)500 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
501 void* in, size_t len) {
502 constexpr int DROP = -1;
503 constexpr int OK = 0;
504
505 // For some values of `reason`, `user` doesn't point to the value provided
506 // when the connection was created. This function object should be used with
507 // care.
508 auto with_connection =
509 [wsi, user](std::function<void(std::shared_ptr<WsConnection>)> cb) {
510 auto context = reinterpret_cast<WsConnectionContext*>(user);
511 auto connection = context->GetConnection(wsi);
512 if (!connection) {
513 return DROP;
514 }
515 cb(connection);
516 return OK;
517 };
518
519 switch (reason) {
520 case LWS_CALLBACK_CLIENT_CONNECTION_ERROR:
521 return with_connection([in](std::shared_ptr<WsConnection> connection) {
522 connection->OnError(in ? (char*)in : "(null)");
523 });
524
525 case LWS_CALLBACK_CLIENT_RECEIVE:
526 return with_connection(
527 [in, len, wsi](std::shared_ptr<WsConnection> connection) {
528 connection->OnReceive((const uint8_t*)in, len,
529 lws_frame_is_binary(wsi));
530 });
531
532 case LWS_CALLBACK_CLIENT_ESTABLISHED:
533 return with_connection([](std::shared_ptr<WsConnection> connection) {
534 connection->OnOpen();
535 });
536
537 case LWS_CALLBACK_CLIENT_CLOSED:
538 return with_connection([](std::shared_ptr<WsConnection> connection) {
539 connection->OnClose();
540 });
541
542 case LWS_CALLBACK_CLIENT_WRITEABLE:
543 return with_connection([](std::shared_ptr<WsConnection> connection) {
544 connection->OnWriteable();
545 });
546
547 default:
548 LOG(VERBOSE) << "Unhandled value: " << reason;
549 return lws_callback_http_dummy(wsi, reason, user, in, len);
550 }
551 }
552
CreateConnectionCallback(lws_sorted_usec_list_t * sul)553 void CreateConnectionCallback(lws_sorted_usec_list_t* sul) {
554 std::shared_ptr<WsConnection> connection =
555 reinterpret_cast<WsConnection::CreateConnectionSul*>(sul)
556 ->weak_this.lock();
557 if (!connection) {
558 LOG(WARNING) << "The object was already destroyed by the time of the first "
559 << "connection attempt. That's unusual.";
560 return;
561 }
562 connection->ConnectInner();
563 }
564
ConnectInner()565 void WsConnection::ConnectInner() {
566 struct lws_client_connect_info connect_info;
567
568 memset(&connect_info, 0, sizeof(connect_info));
569
570 connect_info.context = context_->lws_context();
571 connect_info.port = port_;
572 connect_info.address = addr_.c_str();
573 connect_info.path = path_.c_str();
574 connect_info.host = connect_info.address;
575 connect_info.origin = connect_info.address;
576 switch (security_) {
577 case ServerConfig::Security::kAllowSelfSigned:
578 connect_info.ssl_connection = LCCSCF_ALLOW_SELFSIGNED |
579 LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK |
580 LCCSCF_USE_SSL;
581 break;
582 case ServerConfig::Security::kStrict:
583 connect_info.ssl_connection = LCCSCF_USE_SSL;
584 break;
585 case ServerConfig::Security::kInsecure:
586 connect_info.ssl_connection = 0;
587 break;
588 }
589 connect_info.protocol = "webrtc-operator";
590 connect_info.local_protocol_name = kProtocolName;
591 connect_info.pwsi = &wsi_;
592 connect_info.retry_and_idle_policy = &kRetry;
593 // There is no guarantee the connection object still exists when the callback
594 // is called. Put the context instead as the user data which is guaranteed to
595 // still exist and holds a weak ptr to the connection.
596 connect_info.userdata = context_.get();
597
598 if (lws_client_connect_via_info(&connect_info)) {
599 // wsi_ is not initialized until after the call to
600 // lws_client_connect_via_info(). Luckily, this is guaranteed to run before
601 // the protocol callback is called because it runs in the same loop.
602 context_->RememberConnection(wsi_, weak_from_this());
603 } else {
604 LOG(ERROR) << "Connection failed!";
605 }
606 }
607
608 } // namespace webrtc_streaming
609 } // namespace cuttlefish
610