1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Converter.h"
18
19 #include <android-base/logging.h>
20 #include <nnapi/TypeUtils.h>
21
22 #include <algorithm>
23 #include <random>
24 #include <utility>
25 #include <vector>
26
27 namespace android::nn::fuzz {
28 namespace {
29
30 using namespace test_helper;
31
// Cap on size-derived operand data buffers: convert(const Operand&) reduces the byte size it
// computes modulo this value, so those buffers always stay below 64 KiB.
constexpr uint32_t kMaxSize = 64 * 1024;
33
convert(OperandType type)34 TestOperandType convert(OperandType type) {
35 return static_cast<TestOperandType>(type);
36 }
37
convert(OperationType type)38 TestOperationType convert(OperationType type) {
39 return static_cast<TestOperationType>(type);
40 }
41
convert(OperandLifeTime lifetime)42 TestOperandLifeTime convert(OperandLifeTime lifetime) {
43 return static_cast<TestOperandLifeTime>(lifetime);
44 }
45
convert(const Scales & scales)46 std::vector<float> convert(const Scales& scales) {
47 const auto& repeatedScale = scales.scale();
48 return std::vector<float>(repeatedScale.begin(), repeatedScale.end());
49 }
50
convert(const SymmPerChannelQuantParams & params)51 TestSymmPerChannelQuantParams convert(const SymmPerChannelQuantParams& params) {
52 std::vector<float> scales = convert(params.scales());
53 const uint32_t channelDim = params.channel_dim();
54 return {.scales = std::move(scales), .channelDim = channelDim};
55 }
56
convert(const Dimensions & dimensions)57 std::vector<uint32_t> convert(const Dimensions& dimensions) {
58 const auto& repeatedDimension = dimensions.dimension();
59 return std::vector<uint32_t>(repeatedDimension.begin(), repeatedDimension.end());
60 }
61
convert(size_t size,bool initialize,const Buffer & buffer)62 TestBuffer convert(size_t size, bool initialize, const Buffer& buffer) {
63 switch (buffer.type_case()) {
64 case Buffer::TypeCase::TYPE_NOT_SET:
65 case Buffer::TypeCase::kEmpty:
66 break;
67 case Buffer::TypeCase::kScalar: {
68 const uint32_t scalar = buffer.scalar();
69 return TestBuffer(sizeof(scalar), &scalar);
70 }
71 case Buffer::TypeCase::kRandomSeed: {
72 if (!initialize) {
73 return TestBuffer(size);
74 }
75 const uint32_t randomSeed = buffer.random_seed();
76 std::default_random_engine generator{randomSeed};
77 return TestBuffer::createRandom(size, &generator);
78 }
79 }
80 return TestBuffer();
81 }
82
convert(const Operand & operand)83 TestOperand convert(const Operand& operand) {
84 const TestOperandType type = convert(operand.type());
85 std::vector<uint32_t> dimensions = convert(operand.dimensions());
86 const float scale = operand.scale();
87 const int32_t zeroPoint = operand.zero_point();
88 const TestOperandLifeTime lifetime = convert(operand.lifetime());
89 auto channelQuant = convert(operand.channel_quant());
90
91 const bool isIgnored = lifetime == TestOperandLifeTime::SUBGRAPH_OUTPUT;
92 const auto opType = static_cast<nn::OperandType>(type);
93 const size_t size = getNonExtensionSize(opType, dimensions).value_or(0) % kMaxSize;
94 const bool makeEmpty = (lifetime == TestOperandLifeTime::NO_VALUE ||
95 lifetime == TestOperandLifeTime::TEMPORARY_VARIABLE);
96 const size_t bufferSize = makeEmpty ? 0 : size;
97 TestBuffer data = convert(bufferSize, !isIgnored, operand.data());
98
99 return {.type = type,
100 .dimensions = std::move(dimensions),
101 .numberOfConsumers = 0,
102 .scale = scale,
103 .zeroPoint = zeroPoint,
104 .lifetime = lifetime,
105 .channelQuant = std::move(channelQuant),
106 .isIgnored = isIgnored,
107 .data = std::move(data)};
108 }
109
convert(const Operands & operands)110 std::vector<TestOperand> convert(const Operands& operands) {
111 std::vector<TestOperand> testOperands;
112 testOperands.reserve(operands.operand_size());
113 const auto& repeatedOperand = operands.operand();
114 std::transform(repeatedOperand.begin(), repeatedOperand.end(), std::back_inserter(testOperands),
115 [](const auto& operand) { return convert(operand); });
116 return testOperands;
117 }
118
convert(const Indexes & indexes)119 std::vector<uint32_t> convert(const Indexes& indexes) {
120 const auto& repeatedIndex = indexes.index();
121 return std::vector<uint32_t>(repeatedIndex.begin(), repeatedIndex.end());
122 }
123
convert(const Operation & operation)124 TestOperation convert(const Operation& operation) {
125 const TestOperationType type = convert(operation.type());
126 std::vector<uint32_t> inputs = convert(operation.inputs());
127 std::vector<uint32_t> outputs = convert(operation.outputs());
128 return {.type = type, .inputs = std::move(inputs), .outputs = std::move(outputs)};
129 }
130
convert(const Operations & operations)131 std::vector<TestOperation> convert(const Operations& operations) {
132 std::vector<TestOperation> testOperations;
133 testOperations.reserve(operations.operation_size());
134 const auto& repeatedOperation = operations.operation();
135 std::transform(repeatedOperation.begin(), repeatedOperation.end(),
136 std::back_inserter(testOperations),
137 [](const auto& operation) { return convert(operation); });
138 return testOperations;
139 }
140
calculateNumberOfConsumers(const std::vector<TestOperation> & operations,std::vector<TestOperand> * operands)141 void calculateNumberOfConsumers(const std::vector<TestOperation>& operations,
142 std::vector<TestOperand>* operands) {
143 CHECK(operands != nullptr);
144 const auto addConsumer = [operands](uint32_t operand) {
145 if (operand < operands->size()) {
146 operands->at(operand).numberOfConsumers++;
147 }
148 };
149 const auto addAllConsumers = [&addConsumer](const TestOperation& operation) {
150 std::for_each(operation.inputs.begin(), operation.inputs.end(), addConsumer);
151 };
152 std::for_each(operations.begin(), operations.end(), addAllConsumers);
153 }
154
convert(const Subgraph & subgraph)155 TestSubgraph convert(const Subgraph& subgraph) {
156 std::vector<TestOperand> operands = convert(subgraph.operands());
157 std::vector<TestOperation> operations = convert(subgraph.operations());
158 std::vector<uint32_t> inputIndexes = convert(subgraph.input_indexes());
159 std::vector<uint32_t> outputIndexes = convert(subgraph.output_indexes());
160
161 // Calculate number of consumers.
162 calculateNumberOfConsumers(operations, &operands);
163
164 return {.operands = std::move(operands),
165 .operations = std::move(operations),
166 .inputIndexes = std::move(inputIndexes),
167 .outputIndexes = std::move(outputIndexes)};
168 }
169
convert(const Subgraphs & subgraphs)170 std::vector<TestSubgraph> convert(const Subgraphs& subgraphs) {
171 std::vector<TestSubgraph> testSubgraphs;
172 testSubgraphs.reserve(subgraphs.subgraph_size());
173 const auto& repeatedSubgraph = subgraphs.subgraph();
174 std::transform(repeatedSubgraph.begin(), repeatedSubgraph.end(),
175 std::back_inserter(testSubgraphs),
176 [](const auto& subgraph) { return convert(subgraph); });
177 return testSubgraphs;
178 }
179
convert(const Model & model)180 TestModel convert(const Model& model) {
181 TestSubgraph main = convert(model.main());
182 std::vector<TestSubgraph> referenced = convert(model.referenced());
183 const bool isRelaxed = model.is_relaxed();
184
185 return {.main = std::move(main), .referenced = std::move(referenced), .isRelaxed = isRelaxed};
186 }
187
188 } // anonymous namespace
189
convertToTestModel(const Test & model)190 TestModel convertToTestModel(const Test& model) {
191 return convert(model.model());
192 }
193
194 } // namespace android::nn::fuzz
195