1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "FlatbufferModelBuilder"
18 
19 #include "FlatbufferModelBuilder.h"
20 
21 #include <LegacyUtils.h>
22 
23 #include "FlatbufferModelBuilderUtils.h"
24 #include "operation_converters/OperationConverterResolver.h"
25 
26 namespace android {
27 namespace nn {
28 
verifyModel(const tflite::Model * model)29 void FlatbufferModelBuilder::verifyModel(const tflite::Model* model) {
30     flatbuffers::Verifier verifier(mBuilder.GetBufferPointer(), mBuilder.GetSize());
31     CHECK(model != nullptr);
32     CHECK(model->Verify(verifier));
33 }
34 
initializeBufferVector()35 void FlatbufferModelBuilder::initializeBufferVector() {
36     mBufferVector.clear();
37 
38     std::vector<uint8_t> emptyData;
39     auto emptyBuffer = tflite::CreateBufferDirect(mBuilder, &emptyData);
40     mBufferVector.push_back(emptyBuffer);
41 }
42 
initializeOpCodeIndexForOperationType()43 void FlatbufferModelBuilder::initializeOpCodeIndexForOperationType() {
44     mOpCodeIndexForOperationType.clear();
45     mOpCodeIndexForOperationType.resize(kNumberOfOperationTypes, -1);
46 }
47 
createMetadataVector()48 std::vector<MetadataFlatbuffer> FlatbufferModelBuilder::createMetadataVector() {
49     std::vector<MetadataFlatbuffer> metadataVector;
50     for (uint32_t i = 0; i < mBufferVector.size(); i++) {
51         auto metadata = tflite::CreateMetadataDirect(mBuilder, std::to_string(i).c_str() /* name */,
52                                                      i /* buffer */);
53         metadataVector.push_back(metadata);
54     }
55     return metadataVector;
56 }
57 
createTfliteModel()58 Result<const tflite::Model*> FlatbufferModelBuilder::createTfliteModel() {
59     mModel = makeModel();
60 
61     // Initialize and clear data structures
62     initializeBufferVector();
63     mOpCodesVector.clear();
64     initializeOpCodeIndexForOperationType();
65 
66     // Generate subgraphs
67     auto subgraphsVector = NN_TRY(createSubGraphs());
68 
69     auto metadataVector = createMetadataVector();
70 
71     ModelFlatbuffer flatbufferModel = tflite::CreateModelDirect(
72             mBuilder, 3 /* version*/, &mOpCodesVector /* operator_codes */,
73             &subgraphsVector /* subgraphs */, nullptr /* description */,
74             &mBufferVector /* buffers */, nullptr /* metadata_buffer */,
75             &metadataVector /* metadata */);
76     mBuilder.Finish(flatbufferModel);
77 
78     const tflite::Model* tfliteModel = tflite::GetModel(mBuilder.GetBufferPointer());
79     verifyModel(tfliteModel);
80     return tfliteModel;
81 }
82 
createSubGraphFlatbuffer(const Model::Subgraph & subgraph)83 Result<SubGraphFlatbuffer> FlatbufferModelBuilder::createSubGraphFlatbuffer(
84         const Model::Subgraph& subgraph) {
85     // TFLite does not support unspecified ranks in Operands
86     NN_TRY(checkAllTensorOperandsHaveSpecifiedRank(subgraph.operands));
87     // TFLite does not support dynamic shapes for subgrah output Operands
88     NN_TRY(checkNoSubgraphOutputOperandsHaveDynamicShape(subgraph.operands));
89 
90     SubGraphContext context(&mModel, &subgraph, &mBuilder, &mOpCodesVector,
91                             &mOpCodeIndexForOperationType, &mBufferVector);
92     for (const Operation& operation : subgraph.operations) {
93         const IOperationConverter* converter =
94                 OperationConverterResolver::get()->findOperationConverter(operation.type);
95         NN_RET_CHECK(converter != nullptr)
96                 << "IOperationConverter not implemented for OperationType: " << operation.type;
97 
98         NN_TRY(converter->convert(operation, &context));
99     }
100 
101     for (uint32_t idx : subgraph.inputIndexes) {
102         context.addSubGraphInput(idx);
103     }
104     for (uint32_t idx : subgraph.outputIndexes) {
105         context.addSubGraphOutput(idx);
106     }
107 
108     return context.finish();
109 }
110 
createSubGraphs()111 Result<std::vector<SubGraphFlatbuffer>> FlatbufferModelBuilder::createSubGraphs() {
112     // We do not support control flow yet
113     NN_RET_CHECK(mModel.referenced.empty()) << "Control flow for multiple subgraphs not supported";
114 
115     std::vector<SubGraphFlatbuffer> subGraphVector;
116 
117     auto mainSubGraph = NN_TRY(createSubGraphFlatbuffer(mModel.main));
118     subGraphVector.push_back(mainSubGraph);
119 
120     return subGraphVector;
121 }
122 
123 }  // namespace nn
124 }  // namespace android
125