1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "FibonacciDriver"
18
#include "FibonacciDriver.h"

#include <HalInterfaces.h>
#include <OperationResolver.h>
#include <OperationsUtils.h>
#include <Utils.h>
#include <ValidateHal.h>
#include <nnapi/Types.h>

#include <cstring>
#include <limits>
#include <variant>
#include <vector>

#include "FibonacciExtension.h"
#include "NeuralNetworksExtensions.h"
32
33 namespace android {
34 namespace nn {
35 namespace sample_driver {
36 namespace {
37
38 const uint32_t kTypeWithinExtensionMask = (1 << kExtensionTypeBits) - 1;
39
40 namespace fibonacci_op {
41
// Name under which the Fibonacci operation is registered with the operation resolver.
constexpr char kOperationName[] = "EXAMPLE_FIBONACCI";

// Operand counts of a Fibonacci operation.
constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kNumOutputs = 1;

// Operand positions: input 0 is n (how many numbers to produce); output 0 receives them.
constexpr uint32_t kInputN = 0;
constexpr uint32_t kOutputTensor = 0;
49
// Looks up the prefix the model assigned to the Fibonacci extension.
//
// Succeeds only when the model declares exactly one extension and that extension is named
// EXAMPLE_FIBONACCI_EXTENSION_NAME. On success, writes that extension's prefix to *prefix and
// returns true. Otherwise a failing NN_RET_CHECK_EQ returns false and *prefix is not written.
// Callers combine the prefix with a type-within-extension as
// (prefix << kExtensionTypeBits) | type to form full operation/operand type codes.
bool getFibonacciExtensionPrefix(const V1_3::Model& model, uint16_t* prefix) {
    NN_RET_CHECK_EQ(model.extensionNameToPrefix.size(), 1u);  // Assumes no other extensions in use.
    NN_RET_CHECK_EQ(model.extensionNameToPrefix[0].name, EXAMPLE_FIBONACCI_EXTENSION_NAME);
    *prefix = model.extensionNameToPrefix[0].prefix;
    return true;
}
56
isFibonacciOperation(const V1_3::Operation & operation,const V1_3::Model & model)57 bool isFibonacciOperation(const V1_3::Operation& operation, const V1_3::Model& model) {
58 int32_t operationType = static_cast<int32_t>(operation.type);
59 uint16_t prefix;
60 NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
61 NN_RET_CHECK_EQ(operationType, (prefix << kExtensionTypeBits) | EXAMPLE_FIBONACCI);
62 return true;
63 }
64
validate(const V1_3::Operation & operation,const V1_3::Model & model)65 bool validate(const V1_3::Operation& operation, const V1_3::Model& model) {
66 NN_RET_CHECK(isFibonacciOperation(operation, model));
67 NN_RET_CHECK_EQ(operation.inputs.size(), kNumInputs);
68 NN_RET_CHECK_EQ(operation.outputs.size(), kNumOutputs);
69 int32_t inputType = static_cast<int32_t>(model.main.operands[operation.inputs[0]].type);
70 int32_t outputType = static_cast<int32_t>(model.main.operands[operation.outputs[0]].type);
71 uint16_t prefix;
72 NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
73 NN_RET_CHECK(inputType == ((prefix << kExtensionTypeBits) | EXAMPLE_INT64) ||
74 inputType == ANEURALNETWORKS_TENSOR_FLOAT32);
75 NN_RET_CHECK(outputType == ((prefix << kExtensionTypeBits) | EXAMPLE_TENSOR_QUANT64_ASYMM) ||
76 outputType == ANEURALNETWORKS_TENSOR_FLOAT32);
77 return true;
78 }
79
prepare(IOperationExecutionContext * context)80 bool prepare(IOperationExecutionContext* context) {
81 int64_t n;
82 if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
83 n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
84 } else {
85 n = context->getInputValue<int64_t>(kInputN);
86 }
87 NN_RET_CHECK_GE(n, 1);
88 Shape output = context->getOutputShape(kOutputTensor);
89 output.dimensions = {static_cast<uint32_t>(n)};
90 return context->setOutputShape(kOutputTensor, output);
91 }
92
// Fills output[0..n-1] with the Fibonacci sequence 1, 1, 2, 3, 5, ... and then rewrites every
// element in place as element / outputScale + outputZeroPoint (the quantization step).
// A non-positive n writes nothing. Always returns true.
template <typename ScaleT, typename ZeroPointT, typename OutputT>
bool compute(int32_t n, ScaleT outputScale, ZeroPointT outputZeroPoint, OutputT* output) {
    // Pass 1: the recurrence, seeded with two leading ones.
    for (int32_t k = 0; k < n; ++k) {
        if (k < 2) {
            output[k] = 1;
        } else {
            output[k] = output[k - 1] + output[k - 2];
        }
    }

    // Pass 2: quantize each element in place.
    for (int32_t k = 0; k < n; ++k) {
        output[k] = output[k] / outputScale + outputZeroPoint;
    }

    return true;
}
115
execute(IOperationExecutionContext * context)116 bool execute(IOperationExecutionContext* context) {
117 int64_t n;
118 if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
119 n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
120 } else {
121 n = context->getInputValue<int64_t>(kInputN);
122 }
123 if (context->getOutputType(kOutputTensor) == OperandType::TENSOR_FLOAT32) {
124 float* output = context->getOutputBuffer<float>(kOutputTensor);
125 return compute(n, /*scale=*/1.0, /*zeroPoint=*/0, output);
126 } else {
127 uint64_t* output = context->getOutputBuffer<uint64_t>(kOutputTensor);
128 Shape outputShape = context->getOutputShape(kOutputTensor);
129 auto outputQuant = reinterpret_cast<const ExampleQuant64AsymmParams*>(
130 std::get<Operand::ExtensionParams>(outputShape.extraParams).data());
131 return compute(n, outputQuant->scale, outputQuant->zeroPoint, output);
132 }
133 }
134
135 } // namespace fibonacci_op
136 } // namespace
137
findOperation(OperationType operationType) const138 const OperationRegistration* FibonacciOperationResolver::findOperation(
139 OperationType operationType) const {
140 // .validate is omitted because it's not used by the extension driver.
141 static OperationRegistration operationRegistration(operationType, fibonacci_op::kOperationName,
142 nullptr, fibonacci_op::prepare,
143 fibonacci_op::execute, {});
144 uint16_t prefix = static_cast<int32_t>(operationType) >> kExtensionTypeBits;
145 uint16_t typeWithinExtension = static_cast<int32_t>(operationType) & kTypeWithinExtensionMask;
146 // Assumes no other extensions in use.
147 return prefix != 0 && typeWithinExtension == EXAMPLE_FIBONACCI ? &operationRegistration
148 : nullptr;
149 }
150
getSupportedExtensions(getSupportedExtensions_cb cb)151 hardware::Return<void> FibonacciDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
152 cb(V1_0::ErrorStatus::NONE,
153 {
154 {
155 .name = EXAMPLE_FIBONACCI_EXTENSION_NAME,
156 .operandTypes =
157 {
158 {
159 .type = EXAMPLE_INT64,
160 .isTensor = false,
161 .byteSize = 8,
162 },
163 {
164 .type = EXAMPLE_TENSOR_QUANT64_ASYMM,
165 .isTensor = true,
166 .byteSize = 8,
167 },
168 },
169 },
170 });
171 return hardware::Void();
172 }
173
getCapabilities_1_3(getCapabilities_1_3_cb cb)174 hardware::Return<void> FibonacciDriver::getCapabilities_1_3(getCapabilities_1_3_cb cb) {
175 android::nn::initVLogMask();
176 VLOG(DRIVER) << "getCapabilities()";
177 static const V1_0::PerformanceInfo kPerf = {.execTime = 1.0f, .powerUsage = 1.0f};
178 V1_3::Capabilities capabilities = {
179 .relaxedFloat32toFloat16PerformanceScalar = kPerf,
180 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
181 .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf),
182 .ifPerformance = kPerf,
183 .whilePerformance = kPerf};
184 cb(V1_3::ErrorStatus::NONE, capabilities);
185 return hardware::Void();
186 }
187
getSupportedOperations_1_3(const V1_3::Model & model,getSupportedOperations_1_3_cb cb)188 hardware::Return<void> FibonacciDriver::getSupportedOperations_1_3(
189 const V1_3::Model& model, getSupportedOperations_1_3_cb cb) {
190 VLOG(DRIVER) << "getSupportedOperations()";
191 if (!validateModel(model)) {
192 cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
193 return hardware::Void();
194 }
195 const size_t count = model.main.operations.size();
196 std::vector<bool> supported(count);
197 for (size_t i = 0; i < count; ++i) {
198 const V1_3::Operation& operation = model.main.operations[i];
199 if (fibonacci_op::isFibonacciOperation(operation, model)) {
200 if (!fibonacci_op::validate(operation, model)) {
201 cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {});
202 return hardware::Void();
203 }
204 supported[i] = true;
205 }
206 }
207 cb(V1_3::ErrorStatus::NONE, supported);
208 return hardware::Void();
209 }
210
211 } // namespace sample_driver
212 } // namespace nn
213 } // namespace android
214