1 // Copyright (C) 2020 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "GrpcGraph.h"

#include <chrono>
#include <cstdlib>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

#include <android-base/logging.h>
#include <grpcpp/grpcpp.h>

#include "ClientConfig.pb.h"
#include "GrpcGraph.h"
#include "InputFrame.h"
#include "RunnerComponent.h"
#include "prebuilt_interface.h"
#include "types/Status.h"
28 
29 namespace android {
30 namespace automotive {
31 namespace computepipe {
32 namespace graph {
33 namespace {
34 constexpr int64_t kRpcDeadlineMilliseconds = 100;
35 
36 template <class ResponseType, class RpcType>
FinishRpcAndGetResult(::grpc::ClientAsyncResponseReader<RpcType> * rpc,::grpc::CompletionQueue * cq,ResponseType * response)37 std::pair<Status, std::string> FinishRpcAndGetResult(
38         ::grpc::ClientAsyncResponseReader<RpcType>* rpc, ::grpc::CompletionQueue* cq,
39         ResponseType* response) {
40     int random_tag = rand();
41     ::grpc::Status grpcStatus;
42     rpc->Finish(response, &grpcStatus, reinterpret_cast<void*>(random_tag));
43     bool ok = false;
44     void* got_tag;
45     if (!cq->Next(&got_tag, &ok)) {
46         LOG(ERROR) << "Unexpected shutdown of the completion queue";
47         return std::pair(Status::FATAL_ERROR, "Unexpected shutdown of the completion queue");
48     }
49 
50     if (!ok) {
51         LOG(ERROR) << "Unable to complete RPC request";
52         return std::pair(Status::FATAL_ERROR, "Unable to complete RPC request");
53     }
54 
55     CHECK_EQ(got_tag, reinterpret_cast<void*>(random_tag));
56     if (!grpcStatus.ok()) {
57         std::string error_message =
58                 std::string("Grpc failed with error: ") + grpcStatus.error_message();
59         LOG(ERROR) << error_message;
60         return std::pair(Status::FATAL_ERROR, std::move(error_message));
61     }
62 
63     return std::pair(Status::SUCCESS, std::string(""));
64 }
65 
66 }  // namespace
67 
~GrpcGraph()68 GrpcGraph::~GrpcGraph() {
69     mStreamSetObserver.reset();
70 }
71 
GetGraphState() const72 PrebuiltGraphState GrpcGraph::GetGraphState() const {
73     std::lock_guard lock(mLock);
74     return mGraphState;
75 }
76 
GetStatus() const77 Status GrpcGraph::GetStatus() const {
78     std::lock_guard lock(mLock);
79     return mStatus;
80 }
81 
GetErrorMessage() const82 std::string GrpcGraph::GetErrorMessage() const {
83     std::lock_guard lock(mLock);
84     return mErrorMessage;
85 }
86 
initialize(const std::string & address,std::weak_ptr<PrebuiltEngineInterface> engineInterface)87 Status GrpcGraph::initialize(const std::string& address,
88                              std::weak_ptr<PrebuiltEngineInterface> engineInterface) {
89     std::shared_ptr<::grpc::ChannelCredentials> creds = ::grpc::InsecureChannelCredentials();
90     std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateChannel(address, creds);
91     mGraphStub = proto::GrpcGraphService::NewStub(channel);
92     mEngineInterface = engineInterface;
93 
94     ::grpc::ClientContext context;
95     context.set_deadline(std::chrono::system_clock::now() +
96                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
97     ::grpc::CompletionQueue cq;
98 
99     proto::GraphOptionsRequest getGraphOptionsRequest;
100     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::GraphOptionsResponse>> rpc(
101             mGraphStub->AsyncGetGraphOptions(&context, getGraphOptionsRequest, &cq));
102 
103     proto::GraphOptionsResponse response;
104     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
105 
106     if (mStatus != Status::SUCCESS) {
107         LOG(ERROR) << "Failed to get graph options: " << mErrorMessage;
108         return Status::FATAL_ERROR;
109     }
110 
111     std::string serialized_options = response.serialized_options();
112     if (!mGraphConfig.ParseFromString(serialized_options)) {
113         mErrorMessage = "Failed to parse graph options";
114         LOG(ERROR) << "Failed to parse graph options";
115         return Status::FATAL_ERROR;
116     }
117 
118     mGraphState = PrebuiltGraphState::STOPPED;
119     return Status::SUCCESS;
120 }
121 
122 // Function to confirm that there would be no further changes to the graph configuration. This
123 // needs to be called before starting the graph.
handleConfigPhase(const runner::ClientConfig & e)124 Status GrpcGraph::handleConfigPhase(const runner::ClientConfig& e) {
125     std::lock_guard lock(mLock);
126     if (mGraphState == PrebuiltGraphState::UNINITIALIZED) {
127         mStatus = Status::ILLEGAL_STATE;
128         return Status::ILLEGAL_STATE;
129     }
130 
131     // handleConfigPhase is a blocking call, so abort call is pointless for this RunnerEvent.
132     if (e.isAborted()) {
133         mStatus = Status::INVALID_ARGUMENT;
134         return mStatus;
135     } else if (e.isTransitionComplete()) {
136         mStatus = Status::SUCCESS;
137         return mStatus;
138     }
139 
140     ::grpc::ClientContext context;
141     context.set_deadline(std::chrono::system_clock::now() +
142                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
143     ::grpc::CompletionQueue cq;
144 
145     std::string serializedConfig = e.getSerializedClientConfig();
146     proto::SetGraphConfigRequest setGraphConfigRequest;
147     setGraphConfigRequest.set_serialized_config(std::move(serializedConfig));
148 
149     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
150             mGraphStub->AsyncSetGraphConfig(&context, setGraphConfigRequest, &cq));
151 
152     proto::StatusResponse response;
153     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
154     if (mStatus != Status::SUCCESS) {
155         LOG(ERROR) << "Rpc failed while trying to set configuration";
156         return mStatus;
157     }
158 
159     if (response.code() != proto::RemoteGraphStatusCode::SUCCESS) {
160         LOG(ERROR) << "Failed to cofngure remote graph. " << response.message();
161     }
162 
163     mStatus = static_cast<Status>(static_cast<int>(response.code()));
164     mErrorMessage = response.message();
165 
166     mStreamSetObserver = std::make_unique<StreamSetObserver>(e, this);
167 
168     return mStatus;
169 }
170 
171 // Starts the graph.
handleExecutionPhase(const runner::RunnerEvent & e)172 Status GrpcGraph::handleExecutionPhase(const runner::RunnerEvent& e) {
173     std::lock_guard lock(mLock);
174     if (mGraphState != PrebuiltGraphState::STOPPED || mStreamSetObserver == nullptr) {
175         mStatus = Status::ILLEGAL_STATE;
176         return mStatus;
177     }
178 
179     if (e.isAborted()) {
180         // Starting the graph is a blocking call and cannot be aborted in between.
181         mStatus = Status::INVALID_ARGUMENT;
182         return mStatus;
183     } else if (e.isTransitionComplete()) {
184         mStatus = Status::SUCCESS;
185         return mStatus;
186     }
187 
188     // Start observing the output streams
189     mStatus = mStreamSetObserver->startObservingStreams();
190     if (mStatus != Status::SUCCESS) {
191         mErrorMessage = "Failed to observe output streams";
192         return mStatus;
193     }
194 
195     ::grpc::ClientContext context;
196     context.set_deadline(std::chrono::system_clock::now() +
197                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
198 
199     proto::StartGraphExecutionRequest startExecutionRequest;
200     ::grpc::CompletionQueue cq;
201     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
202             mGraphStub->AsyncStartGraphExecution(&context, startExecutionRequest, &cq));
203 
204     proto::StatusResponse response;
205     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
206     if (mStatus != Status::SUCCESS) {
207         LOG(ERROR) << "Failed to start graph execution";
208         return mStatus;
209     }
210 
211     mStatus = static_cast<Status>(static_cast<int>(response.code()));
212     mErrorMessage = response.message();
213 
214     if (mStatus == Status::SUCCESS) {
215         mGraphState = PrebuiltGraphState::RUNNING;
216     }
217 
218     return mStatus;
219 }
220 
221 // Stops the graph while letting the graph flush output packets in flight.
handleStopWithFlushPhase(const runner::RunnerEvent & e)222 Status GrpcGraph::handleStopWithFlushPhase(const runner::RunnerEvent& e) {
223     std::lock_guard lock(mLock);
224     if (mGraphState != PrebuiltGraphState::RUNNING) {
225         return Status::ILLEGAL_STATE;
226     }
227 
228     if (e.isAborted()) {
229         return Status::INVALID_ARGUMENT;
230     } else if (e.isTransitionComplete()) {
231         return Status::SUCCESS;
232     }
233 
234     ::grpc::ClientContext context;
235     context.set_deadline(std::chrono::system_clock::now() +
236                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
237 
238     proto::StopGraphExecutionRequest stopExecutionRequest;
239     stopExecutionRequest.set_stop_immediate(false);
240     ::grpc::CompletionQueue cq;
241     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
242             mGraphStub->AsyncStopGraphExecution(&context, stopExecutionRequest, &cq));
243 
244     proto::StatusResponse response;
245     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
246     if (mStatus != Status::SUCCESS) {
247         LOG(ERROR) << "Failed to stop graph execution";
248         return Status::FATAL_ERROR;
249     }
250 
251     // Stop observing streams immendiately.
252     mStreamSetObserver->stopObservingStreams(false);
253 
254     mStatus = static_cast<Status>(static_cast<int>(response.code()));
255     mErrorMessage = response.message();
256 
257     if (mStatus == Status::SUCCESS) {
258         mGraphState = PrebuiltGraphState::FLUSHING;
259     }
260 
261     return mStatus;
262 }
263 
264 // Stops the graph and cancels all the output packets.
handleStopImmediatePhase(const runner::RunnerEvent & e)265 Status GrpcGraph::handleStopImmediatePhase(const runner::RunnerEvent& e) {
266     std::lock_guard lock(mLock);
267     if (mGraphState != PrebuiltGraphState::RUNNING) {
268         return Status::ILLEGAL_STATE;
269     }
270 
271     if (e.isAborted()) {
272         return Status::INVALID_ARGUMENT;
273     } else if (e.isTransitionComplete()) {
274         return Status::SUCCESS;
275     }
276 
277     ::grpc::ClientContext context;
278     context.set_deadline(std::chrono::system_clock::now() +
279                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
280 
281     proto::StopGraphExecutionRequest stopExecutionRequest;
282     stopExecutionRequest.set_stop_immediate(true);
283     ::grpc::CompletionQueue cq;
284     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
285             mGraphStub->AsyncStopGraphExecution(&context, stopExecutionRequest, &cq));
286 
287     proto::StatusResponse response;
288     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
289     if (mStatus != Status::SUCCESS) {
290         LOG(ERROR) << "Failed to stop graph execution";
291         return Status::FATAL_ERROR;
292     }
293 
294     mStatus = static_cast<Status>(static_cast<int>(response.code()));
295     mErrorMessage = response.message();
296 
297     // Stop observing streams immendiately.
298     mStreamSetObserver->stopObservingStreams(true);
299 
300     if (mStatus == Status::SUCCESS) {
301         mGraphState = PrebuiltGraphState::STOPPED;
302     }
303     return mStatus;
304 }
305 
handleResetPhase(const runner::RunnerEvent & e)306 Status GrpcGraph::handleResetPhase(const runner::RunnerEvent& e) {
307     std::lock_guard lock(mLock);
308     if (mGraphState != PrebuiltGraphState::STOPPED) {
309         return Status::ILLEGAL_STATE;
310     }
311 
312     if (e.isAborted()) {
313         return Status::INVALID_ARGUMENT;
314     } else if (e.isTransitionComplete()) {
315         return Status::SUCCESS;
316     }
317 
318     ::grpc::ClientContext context;
319     context.set_deadline(std::chrono::system_clock::now() +
320                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
321 
322     proto::ResetGraphRequest resetGraphRequest;
323     ::grpc::CompletionQueue cq;
324     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
325             mGraphStub->AsyncResetGraph(&context, resetGraphRequest, &cq));
326 
327     proto::StatusResponse response;
328     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
329     if (mStatus != Status::SUCCESS) {
330         LOG(ERROR) << "Failed to stop graph execution";
331         return Status::FATAL_ERROR;
332     }
333 
334     mStatus = static_cast<Status>(static_cast<int>(response.code()));
335     mErrorMessage = response.message();
336     mStreamSetObserver.reset();
337 
338     return mStatus;
339 }
340 
SetInputStreamData(int,int64_t,const std::string &)341 Status GrpcGraph::SetInputStreamData(int /*streamIndex*/, int64_t /*timestamp*/,
342                                      const std::string& /*streamData*/) {
343     LOG(ERROR) << "Cannot set input stream for remote graphs";
344     return Status::FATAL_ERROR;
345 }
346 
SetInputStreamPixelData(int,int64_t,const runner::InputFrame &)347 Status GrpcGraph::SetInputStreamPixelData(int /*streamIndex*/, int64_t /*timestamp*/,
348                                           const runner::InputFrame& /*inputFrame*/) {
349     LOG(ERROR) << "Cannot set input streams for remote graphs";
350     return Status::FATAL_ERROR;
351 }
352 
StartGraphProfiling()353 Status GrpcGraph::StartGraphProfiling() {
354     std::lock_guard lock(mLock);
355     if (mGraphState != PrebuiltGraphState::RUNNING) {
356         return Status::ILLEGAL_STATE;
357     }
358 
359     ::grpc::ClientContext context;
360     context.set_deadline(std::chrono::system_clock::now() +
361                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
362 
363     proto::StartGraphProfilingRequest startProfilingRequest;
364     ::grpc::CompletionQueue cq;
365     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
366             mGraphStub->AsyncStartGraphProfiling(&context, startProfilingRequest, &cq));
367 
368     proto::StatusResponse response;
369     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
370     if (mStatus != Status::SUCCESS) {
371         LOG(ERROR) << "Failed to start graph profiling";
372         return Status::FATAL_ERROR;
373     }
374 
375     mStatus = static_cast<Status>(static_cast<int>(response.code()));
376     mErrorMessage = response.message();
377 
378     return mStatus;
379 }
380 
StopGraphProfiling()381 Status GrpcGraph::StopGraphProfiling() {
382     // Stopping profiling after graph has already stopped can be a no-op
383     ::grpc::ClientContext context;
384     context.set_deadline(std::chrono::system_clock::now() +
385                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
386 
387     proto::StopGraphProfilingRequest stopProfilingRequest;
388     ::grpc::CompletionQueue cq;
389     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::StatusResponse>> rpc(
390             mGraphStub->AsyncStopGraphProfiling(&context, stopProfilingRequest, &cq));
391 
392     proto::StatusResponse response;
393     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
394     if (mStatus != Status::SUCCESS) {
395         LOG(ERROR) << "Failed to stop graph profiling";
396         return Status::FATAL_ERROR;
397     }
398 
399     mStatus = static_cast<Status>(static_cast<int>(response.code()));
400     mErrorMessage = response.message();
401 
402     return mStatus;
403 }
404 
GetDebugInfo()405 std::string GrpcGraph::GetDebugInfo() {
406     ::grpc::ClientContext context;
407     context.set_deadline(std::chrono::system_clock::now() +
408                          std::chrono::milliseconds(kRpcDeadlineMilliseconds));
409 
410     proto::ProfilingDataRequest profilingDataRequest;
411     ::grpc::CompletionQueue cq;
412     std::unique_ptr<::grpc::ClientAsyncResponseReader<proto::ProfilingDataResponse>> rpc(
413             mGraphStub->AsyncGetProfilingData(&context, profilingDataRequest, &cq));
414 
415     proto::ProfilingDataResponse response;
416     auto [mStatus, mErrorMessage] = FinishRpcAndGetResult(rpc.get(), &cq, &response);
417     if (mStatus != Status::SUCCESS) {
418         LOG(ERROR) << "Failed to get profiling info";
419         return "";
420     }
421 
422     return response.data();
423 }
424 
dispatchPixelData(int streamId,int64_t timestamp_us,const runner::InputFrame & frame)425 void GrpcGraph::dispatchPixelData(int streamId, int64_t timestamp_us,
426                                   const runner::InputFrame& frame) {
427     std::shared_ptr<PrebuiltEngineInterface> engineInterface = mEngineInterface.lock();
428     if (engineInterface) {
429         return engineInterface->DispatchPixelData(streamId, timestamp_us, frame);
430     }
431 }
432 
dispatchSerializedData(int streamId,int64_t timestamp_us,std::string && serialized_data)433 void GrpcGraph::dispatchSerializedData(int streamId, int64_t timestamp_us,
434                                        std::string&& serialized_data) {
435     std::shared_ptr<PrebuiltEngineInterface> engineInterface = mEngineInterface.lock();
436     if (engineInterface) {
437         return engineInterface->DispatchSerializedData(streamId, timestamp_us,
438                                                        std::move(serialized_data));
439     }
440 }
441 
dispatchGraphTerminationMessage(Status status,std::string && errorMessage)442 void GrpcGraph::dispatchGraphTerminationMessage(Status status, std::string&& errorMessage) {
443     std::lock_guard lock(mLock);
444     mErrorMessage = std::move(errorMessage);
445     mStatus = status;
446     mGraphState = PrebuiltGraphState::STOPPED;
447     std::shared_ptr<PrebuiltEngineInterface> engineInterface = mEngineInterface.lock();
448     if (engineInterface) {
449         std::string errorMessageTmp = mErrorMessage;
450         engineInterface->DispatchGraphTerminationMessage(mStatus, std::move(errorMessageTmp));
451     }
452 }
453 
GetRemoteGraphFromAddress(const std::string & address,std::weak_ptr<PrebuiltEngineInterface> engineInterface)454 std::unique_ptr<PrebuiltGraph> GetRemoteGraphFromAddress(
455         const std::string& address, std::weak_ptr<PrebuiltEngineInterface> engineInterface) {
456     auto prebuiltGraph = std::make_unique<GrpcGraph>();
457     Status status = prebuiltGraph->initialize(address, engineInterface);
458     if (status != Status::SUCCESS) {
459         return nullptr;
460     }
461 
462     return prebuiltGraph;
463 }
464 
465 }  // namespace graph
466 }  // namespace computepipe
467 }  // namespace automotive
468 }  // namespace android
469