1 /*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "FlatbufferModelBuilder"
18
19 #include "FlatbufferModelBuilder.h"
20
21 #include <LegacyUtils.h>
22
23 #include "FlatbufferModelBuilderUtils.h"
24 #include "operation_converters/OperationConverterResolver.h"
25
26 namespace android {
27 namespace nn {
28
verifyModel(const tflite::Model * model)29 void FlatbufferModelBuilder::verifyModel(const tflite::Model* model) {
30 flatbuffers::Verifier verifier(mBuilder.GetBufferPointer(), mBuilder.GetSize());
31 CHECK(model != nullptr);
32 CHECK(model->Verify(verifier));
33 }
34
initializeBufferVector()35 void FlatbufferModelBuilder::initializeBufferVector() {
36 mBufferVector.clear();
37
38 std::vector<uint8_t> emptyData;
39 auto emptyBuffer = tflite::CreateBufferDirect(mBuilder, &emptyData);
40 mBufferVector.push_back(emptyBuffer);
41 }
42
initializeOpCodeIndexForOperationType()43 void FlatbufferModelBuilder::initializeOpCodeIndexForOperationType() {
44 mOpCodeIndexForOperationType.clear();
45 mOpCodeIndexForOperationType.resize(kNumberOfOperationTypes, -1);
46 }
47
createMetadataVector()48 std::vector<MetadataFlatbuffer> FlatbufferModelBuilder::createMetadataVector() {
49 std::vector<MetadataFlatbuffer> metadataVector;
50 for (uint32_t i = 0; i < mBufferVector.size(); i++) {
51 auto metadata = tflite::CreateMetadataDirect(mBuilder, std::to_string(i).c_str() /* name */,
52 i /* buffer */);
53 metadataVector.push_back(metadata);
54 }
55 return metadataVector;
56 }
57
createTfliteModel()58 Result<const tflite::Model*> FlatbufferModelBuilder::createTfliteModel() {
59 mModel = makeModel();
60
61 // Initialize and clear data structures
62 initializeBufferVector();
63 mOpCodesVector.clear();
64 initializeOpCodeIndexForOperationType();
65
66 // Generate subgraphs
67 auto subgraphsVector = NN_TRY(createSubGraphs());
68
69 auto metadataVector = createMetadataVector();
70
71 ModelFlatbuffer flatbufferModel = tflite::CreateModelDirect(
72 mBuilder, 3 /* version*/, &mOpCodesVector /* operator_codes */,
73 &subgraphsVector /* subgraphs */, nullptr /* description */,
74 &mBufferVector /* buffers */, nullptr /* metadata_buffer */,
75 &metadataVector /* metadata */);
76 mBuilder.Finish(flatbufferModel);
77
78 const tflite::Model* tfliteModel = tflite::GetModel(mBuilder.GetBufferPointer());
79 verifyModel(tfliteModel);
80 return tfliteModel;
81 }
82
createSubGraphFlatbuffer(const Model::Subgraph & subgraph)83 Result<SubGraphFlatbuffer> FlatbufferModelBuilder::createSubGraphFlatbuffer(
84 const Model::Subgraph& subgraph) {
85 // TFLite does not support unspecified ranks in Operands
86 NN_TRY(checkAllTensorOperandsHaveSpecifiedRank(subgraph.operands));
87 // TFLite does not support dynamic shapes for subgrah output Operands
88 NN_TRY(checkNoSubgraphOutputOperandsHaveDynamicShape(subgraph.operands));
89
90 SubGraphContext context(&mModel, &subgraph, &mBuilder, &mOpCodesVector,
91 &mOpCodeIndexForOperationType, &mBufferVector);
92 for (const Operation& operation : subgraph.operations) {
93 const IOperationConverter* converter =
94 OperationConverterResolver::get()->findOperationConverter(operation.type);
95 NN_RET_CHECK(converter != nullptr)
96 << "IOperationConverter not implemented for OperationType: " << operation.type;
97
98 NN_TRY(converter->convert(operation, &context));
99 }
100
101 for (uint32_t idx : subgraph.inputIndexes) {
102 context.addSubGraphInput(idx);
103 }
104 for (uint32_t idx : subgraph.outputIndexes) {
105 context.addSubGraphOutput(idx);
106 }
107
108 return context.finish();
109 }
110
createSubGraphs()111 Result<std::vector<SubGraphFlatbuffer>> FlatbufferModelBuilder::createSubGraphs() {
112 // We do not support control flow yet
113 NN_RET_CHECK(mModel.referenced.empty()) << "Control flow for multiple subgraphs not supported";
114
115 std::vector<SubGraphFlatbuffer> subGraphVector;
116
117 auto mainSubGraph = NN_TRY(createSubGraphFlatbuffer(mModel.main));
118 subGraphVector.push_back(mainSubGraph);
119
120 return subGraphVector;
121 }
122
123 } // namespace nn
124 } // namespace android
125