1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre_host/pigweed/hal_rpc_client.h"
18 
19 #include <cstdint>
20 #include <memory>
21 
22 #include "chre/event.h"
23 #include "chre/util/pigweed/rpc_common.h"
24 #include "chre_host/host_protocol_host.h"
25 #include "chre_host/log.h"
26 #include "chre_host/pigweed/hal_channel_output.h"
27 
28 namespace android::chre {
29 
30 using ::chre::fbs::HubInfoResponseT;
31 using ::chre::fbs::NanoappListEntryT;
32 using ::chre::fbs::NanoappListResponseT;
33 using ::chre::fbs::NanoappMessageT;
34 using ::chre::fbs::NanoappRpcServiceT;
35 using ::flatbuffers::FlatBufferBuilder;
36 
createClient(std::string_view appName,SocketClient & client,sp<SocketClient::ICallbacks> socketCallbacks,uint16_t hostEndpointId,uint64_t serverNanoappId)37 std::unique_ptr<HalRpcClient> HalRpcClient::createClient(
38     std::string_view appName, SocketClient &client,
39     sp<SocketClient::ICallbacks> socketCallbacks, uint16_t hostEndpointId,
40     uint64_t serverNanoappId) {
41   auto rpcClient = std::unique_ptr<HalRpcClient>(
42       new HalRpcClient(appName, client, hostEndpointId, serverNanoappId));
43 
44   if (!rpcClient->init(socketCallbacks)) {
45     return nullptr;
46   }
47 
48   return rpcClient;
49 }
50 
hasService(uint64_t id,uint32_t version)51 bool HalRpcClient::hasService(uint64_t id, uint32_t version) {
52   std::lock_guard lock(mNanoappMutex);
53   for (const NanoappRpcServiceT &service : mServices) {
54     if (service.id == id && service.version == version) {
55       return true;
56     }
57   }
58 
59   return false;
60 }
61 
close()62 void HalRpcClient::close() {
63   if (mIsChannelOpened) {
64     mRpcClient.CloseChannel(GetChannelId());
65     mIsChannelOpened = false;
66   }
67   if (mSocketClient.isConnected()) {
68     notifyEndpointDisconnected();
69     mSocketClient.disconnect();
70   }
71 }
72 
73 // Private methods
74 
init(sp<SocketClient::ICallbacks> socketCallbacks)75 bool HalRpcClient::init(sp<SocketClient::ICallbacks> socketCallbacks) {
76   if (mSocketClient.isConnected()) {
77     LOGE("Already connected to socket");
78     return false;
79   }
80 
81   auto callbacks = sp<Callbacks>::make(this, socketCallbacks);
82 
83   if (!mSocketClient.connect("chre", callbacks)) {
84     LOGE("Couldn't connect to socket");
85     return false;
86   }
87 
88   bool success = true;
89 
90   if (!notifyEndpointConnected()) {
91     LOGE("Failed to notify connection");
92     success = false;
93   } else if (!retrieveMaxMessageLen()) {
94     LOGE("Failed to retrieve the max message length");
95     success = false;
96   } else if (!retrieveServices()) {
97     LOGE("Failed to retrieve the services");
98     success = false;
99   }
100 
101   if (!success) {
102     mSocketClient.disconnect();
103     return false;
104   }
105 
106   {
107     std::lock_guard lock(mHubInfoMutex);
108     mChannelOutput = std::make_unique<HalChannelOutput>(
109         mSocketClient, mHostEndpointId, mServerNanoappId, mMaxMessageLen);
110   }
111 
112   return true;
113 }
114 
notifyEndpointConnected()115 bool HalRpcClient::notifyEndpointConnected() {
116   FlatBufferBuilder builder(64);
117   HostProtocolHost::encodeHostEndpointConnected(
118       builder, mHostEndpointId, CHRE_HOST_ENDPOINT_TYPE_NATIVE, mAppName,
119       /* attributionTag= */ "");
120   return mSocketClient.sendMessage(builder.GetBufferPointer(),
121                                    builder.GetSize());
122 }
123 
notifyEndpointDisconnected()124 bool HalRpcClient::notifyEndpointDisconnected() {
125   FlatBufferBuilder builder(64);
126   HostProtocolHost::encodeHostEndpointDisconnected(builder, mHostEndpointId);
127   return mSocketClient.sendMessage(builder.GetBufferPointer(),
128                                    builder.GetSize());
129 }
130 
retrieveMaxMessageLen()131 bool HalRpcClient::retrieveMaxMessageLen() {
132   FlatBufferBuilder builder(64);
133   HostProtocolHost::encodeHubInfoRequest(builder);
134   if (!mSocketClient.sendMessage(builder.GetBufferPointer(),
135                                  builder.GetSize())) {
136     return false;
137   }
138 
139   std::unique_lock lock(mHubInfoMutex);
140   std::cv_status status = mHubInfoCond.wait_for(lock, kRequestTimeout);
141 
142   return status != std::cv_status::timeout;
143 }
144 
retrieveServices()145 bool HalRpcClient::retrieveServices() {
146   FlatBufferBuilder builder(64);
147   HostProtocolHost::encodeNanoappListRequest(builder);
148 
149   if (!mSocketClient.sendMessage(builder.GetBufferPointer(),
150                                  builder.GetSize())) {
151     return false;
152   }
153 
154   std::unique_lock lock(mNanoappMutex);
155   std::cv_status status = mNanoappCond.wait_for(lock, kRequestTimeout);
156 
157   return status != std::cv_status::timeout;
158 }
159 
160 // Socket callbacks.
161 
onMessageReceived(const void * data,size_t length)162 void HalRpcClient::Callbacks::onMessageReceived(const void *data,
163                                                 size_t length) {
164   if (!android::chre::HostProtocolHost::decodeMessageFromChre(data, length,
165                                                               *this)) {
166     LOGE("Failed to decode message");
167   }
168   mSocketCallbacks->onMessageReceived(data, length);
169 }
170 
onConnected()171 void HalRpcClient::Callbacks::onConnected() {
172   mSocketCallbacks->onConnected();
173 }
174 
onConnectionAborted()175 void HalRpcClient::Callbacks::onConnectionAborted() {
176   mSocketCallbacks->onConnectionAborted();
177 }
178 
onDisconnected()179 void HalRpcClient::Callbacks::onDisconnected() {
180   // Close connections on CHRE reset.
181   mClient->close();
182   mSocketCallbacks->onDisconnected();
183 }
184 
185 // Message handlers.
186 
handleNanoappMessage(const NanoappMessageT & message)187 void HalRpcClient::Callbacks::handleNanoappMessage(
188     const NanoappMessageT &message) {
189   if (message.message_type == CHRE_MESSAGE_TYPE_RPC) {
190     pw::span packet(reinterpret_cast<const std::byte *>(message.message.data()),
191                     message.message.size());
192 
193     if (message.app_id == mClient->mServerNanoappId) {
194       pw::Status status = mClient->mRpcClient.ProcessPacket(packet);
195       if (status != pw::OkStatus()) {
196         LOGE("Failed to process the packet");
197       }
198     }
199   }
200 }
201 
handleHubInfoResponse(const HubInfoResponseT & response)202 void HalRpcClient::Callbacks::handleHubInfoResponse(
203     const HubInfoResponseT &response) {
204   {
205     std::lock_guard lock(mClient->mHubInfoMutex);
206     mClient->mMaxMessageLen = response.max_msg_len;
207   }
208   mClient->mHubInfoCond.notify_all();
209 }
210 
handleNanoappListResponse(const NanoappListResponseT & response)211 void HalRpcClient::Callbacks::handleNanoappListResponse(
212     const NanoappListResponseT &response) {
213   for (const std::unique_ptr<NanoappListEntryT> &nanoapp : response.nanoapps) {
214     if (nanoapp->app_id == mClient->mServerNanoappId) {
215       std::lock_guard lock(mClient->mNanoappMutex);
216       mClient->mServices.clear();
217       mClient->mServices.reserve(nanoapp->rpc_services.size());
218       for (const std::unique_ptr<NanoappRpcServiceT> &service :
219            nanoapp->rpc_services) {
220         mClient->mServices.push_back(*service);
221       }
222     }
223   }
224 
225   mClient->mNanoappCond.notify_all();
226 }
227 
228 }  // namespace android::chre