1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Conversions.h"
18
19 #include <android-base/logging.h>
20 #include <android/hardware/neuralnetworks/1.2/types.h>
21 #include <nnapi/OperandTypes.h>
22 #include <nnapi/OperationTypes.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/SharedMemory.h>
25 #include <nnapi/TypeUtils.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28 #include <nnapi/hal/1.0/Conversions.h>
29 #include <nnapi/hal/1.1/Conversions.h>
30 #include <nnapi/hal/CommonUtils.h>
31
32 #include <algorithm>
33 #include <functional>
34 #include <iterator>
35 #include <memory>
36 #include <type_traits>
37 #include <utility>
38
39 #include "Utils.h"
40
namespace {

// Returns the numeric value of an enumerator, typed as the enum's declared underlying integer
// type. Used below to include unrecognized enum values in error messages.
template <typename EnumType>
constexpr auto underlyingType(EnumType value) -> std::underlying_type_t<EnumType> {
    using Underlying = std::underlying_type_t<EnumType>;
    return static_cast<Underlying>(value);
}

// Duration type used to interpret the V1_2 HAL `Timing` fields: an unsigned count of
// microseconds.
using HalDuration = std::chrono::duration<uint64_t, std::micro>;

}  // namespace
51
52 namespace android::nn {
53 namespace {
54
55 using hardware::hidl_handle;
56 using hardware::hidl_vec;
57
58 template <typename Input>
59 using UnvalidatedConvertOutput =
60 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
61
62 template <typename Type>
unvalidatedConvert(const hidl_vec<Type> & arguments)63 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
64 const hidl_vec<Type>& arguments) {
65 std::vector<UnvalidatedConvertOutput<Type>> canonical;
66 canonical.reserve(arguments.size());
67 for (const auto& argument : arguments) {
68 canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
69 }
70 return canonical;
71 }
72
73 template <typename Type>
validatedConvert(const Type & halObject)74 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
75 auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
76 NN_TRY(hal::V1_2::utils::compliantVersion(canonical));
77 return canonical;
78 }
79
80 template <typename Type>
validatedConvert(const hidl_vec<Type> & arguments)81 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
82 const hidl_vec<Type>& arguments) {
83 std::vector<UnvalidatedConvertOutput<Type>> canonical;
84 canonical.reserve(arguments.size());
85 for (const auto& argument : arguments) {
86 canonical.push_back(NN_TRY(validatedConvert(argument)));
87 }
88 return canonical;
89 }
90
91 } // anonymous namespace
92
unvalidatedConvert(const hal::V1_2::OperandType & operandType)93 GeneralResult<OperandType> unvalidatedConvert(const hal::V1_2::OperandType& operandType) {
94 return static_cast<OperandType>(operandType);
95 }
96
unvalidatedConvert(const hal::V1_2::OperationType & operationType)97 GeneralResult<OperationType> unvalidatedConvert(const hal::V1_2::OperationType& operationType) {
98 return static_cast<OperationType>(operationType);
99 }
100
unvalidatedConvert(const hal::V1_2::DeviceType & deviceType)101 GeneralResult<DeviceType> unvalidatedConvert(const hal::V1_2::DeviceType& deviceType) {
102 return static_cast<DeviceType>(deviceType);
103 }
104
unvalidatedConvert(const hal::V1_2::Capabilities & capabilities)105 GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_2::Capabilities& capabilities) {
106 const bool validOperandTypes = std::all_of(
107 capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
108 [](const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
109 return validatedConvert(operandPerformance.type).has_value();
110 });
111 if (!validOperandTypes) {
112 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
113 << "Invalid OperandType when converting OperandPerformance in Capabilities";
114 }
115
116 const auto relaxedFloat32toFloat16PerformanceScalar =
117 NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
118 const auto relaxedFloat32toFloat16PerformanceTensor =
119 NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
120 auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
121
122 auto table =
123 NN_TRY(Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)));
124
125 return Capabilities{
126 .relaxedFloat32toFloat16PerformanceScalar = relaxedFloat32toFloat16PerformanceScalar,
127 .relaxedFloat32toFloat16PerformanceTensor = relaxedFloat32toFloat16PerformanceTensor,
128 .operandPerformance = std::move(table),
129 };
130 }
131
unvalidatedConvert(const hal::V1_2::Capabilities::OperandPerformance & operandPerformance)132 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
133 const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
134 const auto type = NN_TRY(unvalidatedConvert(operandPerformance.type));
135 const auto info = NN_TRY(unvalidatedConvert(operandPerformance.info));
136 return Capabilities::OperandPerformance{
137 .type = type,
138 .info = info,
139 };
140 }
141
unvalidatedConvert(const hal::V1_2::Operation & operation)142 GeneralResult<Operation> unvalidatedConvert(const hal::V1_2::Operation& operation) {
143 const auto type = NN_TRY(unvalidatedConvert(operation.type));
144 return Operation{
145 .type = type,
146 .inputs = operation.inputs,
147 .outputs = operation.outputs,
148 };
149 }
150
unvalidatedConvert(const hal::V1_2::SymmPerChannelQuantParams & symmPerChannelQuantParams)151 GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
152 const hal::V1_2::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
153 return Operand::SymmPerChannelQuantParams{
154 .scales = symmPerChannelQuantParams.scales,
155 .channelDim = symmPerChannelQuantParams.channelDim,
156 };
157 }
158
unvalidatedConvert(const hal::V1_2::Operand & operand)159 GeneralResult<Operand> unvalidatedConvert(const hal::V1_2::Operand& operand) {
160 const auto type = NN_TRY(unvalidatedConvert(operand.type));
161 const auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
162 const auto location = NN_TRY(unvalidatedConvert(operand.location));
163 auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
164 return Operand{
165 .type = type,
166 .dimensions = operand.dimensions,
167 .scale = operand.scale,
168 .zeroPoint = operand.zeroPoint,
169 .lifetime = lifetime,
170 .location = location,
171 .extraParams = std::move(extraParams),
172 };
173 }
174
unvalidatedConvert(const hal::V1_2::Operand::ExtraParams & extraParams)175 GeneralResult<Operand::ExtraParams> unvalidatedConvert(
176 const hal::V1_2::Operand::ExtraParams& extraParams) {
177 using Discriminator = hal::V1_2::Operand::ExtraParams::hidl_discriminator;
178 switch (extraParams.getDiscriminator()) {
179 case Discriminator::none:
180 return Operand::NoParams{};
181 case Discriminator::channelQuant:
182 return unvalidatedConvert(extraParams.channelQuant());
183 case Discriminator::extension:
184 return extraParams.extension();
185 }
186 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
187 << "Unrecognized Operand::ExtraParams discriminator: "
188 << underlyingType(extraParams.getDiscriminator());
189 }
190
unvalidatedConvert(const hal::V1_2::Model & model)191 GeneralResult<Model> unvalidatedConvert(const hal::V1_2::Model& model) {
192 auto operations = NN_TRY(unvalidatedConvert(model.operations));
193
194 // Verify number of consumers.
195 const auto numberOfConsumers =
196 NN_TRY(countNumberOfConsumers(model.operands.size(), operations));
197 CHECK(model.operands.size() == numberOfConsumers.size());
198 for (size_t i = 0; i < model.operands.size(); ++i) {
199 if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
200 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
201 << "Invalid numberOfConsumers for operand " << i << ", expected "
202 << numberOfConsumers[i] << " but found " << model.operands[i].numberOfConsumers;
203 }
204 }
205
206 auto operands = NN_TRY(unvalidatedConvert(model.operands));
207 auto main = Model::Subgraph{
208 .operands = std::move(operands),
209 .operations = std::move(operations),
210 .inputIndexes = model.inputIndexes,
211 .outputIndexes = model.outputIndexes,
212 };
213
214 auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
215 auto pools = NN_TRY(unvalidatedConvert(model.pools));
216 auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
217 return Model{
218 .main = std::move(main),
219 .operandValues = std::move(operandValues),
220 .pools = std::move(pools),
221 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
222 .extensionNameToPrefix = std::move(extensionNameToPrefix),
223 };
224 }
225
unvalidatedConvert(const hal::V1_2::Model::ExtensionNameAndPrefix & extensionNameAndPrefix)226 GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
227 const hal::V1_2::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
228 return ExtensionNameAndPrefix{
229 .name = extensionNameAndPrefix.name,
230 .prefix = extensionNameAndPrefix.prefix,
231 };
232 }
233
unvalidatedConvert(const hal::V1_2::OutputShape & outputShape)234 GeneralResult<OutputShape> unvalidatedConvert(const hal::V1_2::OutputShape& outputShape) {
235 return OutputShape{
236 .dimensions = outputShape.dimensions,
237 .isSufficient = outputShape.isSufficient,
238 };
239 }
240
unvalidatedConvert(const hal::V1_2::MeasureTiming & measureTiming)241 GeneralResult<MeasureTiming> unvalidatedConvert(const hal::V1_2::MeasureTiming& measureTiming) {
242 return static_cast<MeasureTiming>(measureTiming);
243 }
244
unvalidatedConvert(const hal::V1_2::Timing & timing)245 GeneralResult<Timing> unvalidatedConvert(const hal::V1_2::Timing& timing) {
246 constexpr uint64_t kMaxTiming = std::chrono::floor<HalDuration>(Duration::max()).count();
247 constexpr auto convertTiming = [](uint64_t halTiming) -> OptionalDuration {
248 constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
249 if (halTiming == kNoTiming) {
250 return {};
251 }
252 if (halTiming > kMaxTiming) {
253 return Duration::max();
254 }
255 return HalDuration{halTiming};
256 };
257 return Timing{.timeOnDevice = convertTiming(timing.timeOnDevice),
258 .timeInDriver = convertTiming(timing.timeInDriver)};
259 }
260
unvalidatedConvert(const hal::V1_2::Extension & extension)261 GeneralResult<Extension> unvalidatedConvert(const hal::V1_2::Extension& extension) {
262 auto operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes));
263 return Extension{
264 .name = extension.name,
265 .operandTypes = std::move(operandTypes),
266 };
267 }
268
unvalidatedConvert(const hal::V1_2::Extension::OperandTypeInformation & operandTypeInformation)269 GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
270 const hal::V1_2::Extension::OperandTypeInformation& operandTypeInformation) {
271 return Extension::OperandTypeInformation{
272 .type = operandTypeInformation.type,
273 .isTensor = operandTypeInformation.isTensor,
274 .byteSize = operandTypeInformation.byteSize,
275 };
276 }
277
convert(const hal::V1_2::DeviceType & deviceType)278 GeneralResult<DeviceType> convert(const hal::V1_2::DeviceType& deviceType) {
279 return validatedConvert(deviceType);
280 }
281
convert(const hal::V1_2::Capabilities & capabilities)282 GeneralResult<Capabilities> convert(const hal::V1_2::Capabilities& capabilities) {
283 return validatedConvert(capabilities);
284 }
285
convert(const hal::V1_2::Model & model)286 GeneralResult<Model> convert(const hal::V1_2::Model& model) {
287 return validatedConvert(model);
288 }
289
convert(const hal::V1_2::MeasureTiming & measureTiming)290 GeneralResult<MeasureTiming> convert(const hal::V1_2::MeasureTiming& measureTiming) {
291 return validatedConvert(measureTiming);
292 }
293
convert(const hal::V1_2::Timing & timing)294 GeneralResult<Timing> convert(const hal::V1_2::Timing& timing) {
295 return validatedConvert(timing);
296 }
297
convert(const hardware::hidl_memory & memory)298 GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory) {
299 return validatedConvert(memory);
300 }
301
convert(const hidl_vec<hal::V1_2::Extension> & extensions)302 GeneralResult<std::vector<Extension>> convert(const hidl_vec<hal::V1_2::Extension>& extensions) {
303 return validatedConvert(extensions);
304 }
305
convert(const hidl_vec<hidl_handle> & handles)306 GeneralResult<std::vector<SharedHandle>> convert(const hidl_vec<hidl_handle>& handles) {
307 return validatedConvert(handles);
308 }
309
convert(const hidl_vec<hal::V1_2::OutputShape> & outputShapes)310 GeneralResult<std::vector<OutputShape>> convert(
311 const hidl_vec<hal::V1_2::OutputShape>& outputShapes) {
312 return validatedConvert(outputShapes);
313 }
314
315 } // namespace android::nn
316
317 namespace android::hardware::neuralnetworks::V1_2::utils {
318 namespace {
319
320 using utils::unvalidatedConvert;
321
unvalidatedConvert(const nn::Operand::LifeTime & lifetime)322 nn::GeneralResult<V1_0::OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& lifetime) {
323 return V1_0::utils::unvalidatedConvert(lifetime);
324 }
325
unvalidatedConvert(const nn::Capabilities::PerformanceInfo & performanceInfo)326 nn::GeneralResult<V1_0::PerformanceInfo> unvalidatedConvert(
327 const nn::Capabilities::PerformanceInfo& performanceInfo) {
328 return V1_0::utils::unvalidatedConvert(performanceInfo);
329 }
330
unvalidatedConvert(const nn::DataLocation & location)331 nn::GeneralResult<V1_0::DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
332 return V1_0::utils::unvalidatedConvert(location);
333 }
334
unvalidatedConvert(const nn::Model::OperandValues & operandValues)335 nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
336 const nn::Model::OperandValues& operandValues) {
337 return V1_0::utils::unvalidatedConvert(operandValues);
338 }
339
unvalidatedConvert(const nn::SharedHandle & handle)340 nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle) {
341 return V1_0::utils::unvalidatedConvert(handle);
342 }
343
unvalidatedConvert(const nn::SharedMemory & memory)344 nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
345 return V1_0::utils::unvalidatedConvert(memory);
346 }
347
348 template <typename Input>
349 using UnvalidatedConvertOutput =
350 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
351
352 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)353 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
354 const std::vector<Type>& arguments) {
355 hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
356 for (size_t i = 0; i < arguments.size(); ++i) {
357 halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
358 }
359 return halObject;
360 }
361
makeExtraParams(nn::Operand::NoParams)362 nn::GeneralResult<Operand::ExtraParams> makeExtraParams(nn::Operand::NoParams /*noParams*/) {
363 return Operand::ExtraParams{};
364 }
365
makeExtraParams(const nn::Operand::SymmPerChannelQuantParams & channelQuant)366 nn::GeneralResult<Operand::ExtraParams> makeExtraParams(
367 const nn::Operand::SymmPerChannelQuantParams& channelQuant) {
368 Operand::ExtraParams ret;
369 ret.channelQuant(NN_TRY(unvalidatedConvert(channelQuant)));
370 return ret;
371 }
372
makeExtraParams(const nn::Operand::ExtensionParams & extension)373 nn::GeneralResult<Operand::ExtraParams> makeExtraParams(
374 const nn::Operand::ExtensionParams& extension) {
375 Operand::ExtraParams ret;
376 ret.extension(extension);
377 return ret;
378 }
379
380 template <typename Type>
validatedConvert(const Type & canonical)381 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
382 NN_TRY(compliantVersion(canonical));
383 return unvalidatedConvert(canonical);
384 }
385
386 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)387 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
388 const std::vector<Type>& arguments) {
389 hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
390 for (size_t i = 0; i < arguments.size(); ++i) {
391 halObject[i] = NN_TRY(validatedConvert(arguments[i]));
392 }
393 return halObject;
394 }
395
396 } // anonymous namespace
397
// Plain numeric casts from canonical enums to HAL V1_2 enums. They rely on matching enumerator
// values; version compliance is checked separately by validatedConvert().
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const auto halOperandType = static_cast<OperandType>(operandType);
    return halOperandType;
}

nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const auto halOperationType = static_cast<OperationType>(operationType);
    return halOperationType;
}
405
// Converts a canonical device type to the HAL enum.
// - UNKNOWN is rejected.
// - The other named enumerators are cast directly.
// - Any value outside the named enumerators is reported with its numeric value.
nn::GeneralResult<DeviceType> unvalidatedConvert(const nn::DeviceType& deviceType) {
    switch (deviceType) {
        case nn::DeviceType::OTHER:
        case nn::DeviceType::CPU:
        case nn::DeviceType::GPU:
        case nn::DeviceType::ACCELERATOR:
            return static_cast<DeviceType>(deviceType);
        case nn::DeviceType::UNKNOWN:
            return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "Invalid DeviceType UNKNOWN";
    }
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "Invalid DeviceType " << underlyingType(deviceType);
}
419
// Converts canonical capabilities to HAL V1_2 capabilities. Performance entries whose operand
// type fails compliantVersion() are dropped instead of failing the conversion.
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    const auto& allPerformance = capabilities.operandPerformance.asVector();
    std::vector<nn::Capabilities::OperandPerformance> compliantPerformance;
    compliantPerformance.reserve(allPerformance.size());
    for (const auto& entry : allPerformance) {
        if (compliantVersion(entry.type).has_value()) {
            compliantPerformance.push_back(entry);
        }
    }

    const auto scalarPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
    const auto tensorPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
    auto halOperandPerformance = NN_TRY(unvalidatedConvert(compliantPerformance));
    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = scalarPerformance,
            .relaxedFloat32toFloat16PerformanceTensor = tensorPerformance,
            .operandPerformance = std::move(halOperandPerformance),
    };
}
441
// Converts one canonical OperandPerformance entry (operand type plus performance info).
nn::GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
        const nn::Capabilities::OperandPerformance& operandPerformance) {
    const auto halType = NN_TRY(unvalidatedConvert(operandPerformance.type));
    const auto halInfo = NN_TRY(unvalidatedConvert(operandPerformance.info));
    Capabilities::OperandPerformance converted{
            .type = halType,
            .info = halInfo,
    };
    return converted;
}
451
// Converts a canonical operation. Only the type is translated; the operand index lists are
// copied.
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    const auto halType = NN_TRY(unvalidatedConvert(operation.type));
    Operation converted{
            .type = halType,
            .inputs = operation.inputs,
            .outputs = operation.outputs,
    };
    return converted;
}
460
// Copies canonical per-channel quantization parameters into the HAL struct.
nn::GeneralResult<SymmPerChannelQuantParams> unvalidatedConvert(
        const nn::Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
    SymmPerChannelQuantParams converted{
            .scales = symmPerChannelQuantParams.scales,
            .channelDim = symmPerChannelQuantParams.channelDim,
    };
    return converted;
}
468
// Converts a canonical operand to the HAL struct. numberOfConsumers has no canonical source,
// so it starts at 0; the Model conversion overwrites it with counts derived from the
// operations.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    const auto halType = NN_TRY(unvalidatedConvert(operand.type));
    const auto halLifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    const auto halLocation = NN_TRY(unvalidatedConvert(operand.location));
    auto halExtraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
    return Operand{
            .type = halType,
            .dimensions = operand.dimensions,
            .numberOfConsumers = 0,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = halLifetime,
            .location = halLocation,
            .extraParams = std::move(halExtraParams),
    };
}
485
// Dispatches on the active alternative of the canonical variant to the matching
// makeExtraParams overload.
nn::GeneralResult<Operand::ExtraParams> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    const auto toHal = [](const auto& alternative) { return makeExtraParams(alternative); };
    return std::visit(toHal, extraParams);
}
490
// Converts a canonical model (its main subgraph only) to a HAL V1_2 model. Models containing
// pointer-based memory are rejected with INVALID_ARGUMENT before any conversion. Each HAL
// operand's numberOfConsumers is filled in from the main subgraph's operations.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    if (!hal::utils::hasNoPointerData(model)) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    const auto& mainSubgraph = model.main;
    auto operands = NN_TRY(unvalidatedConvert(mainSubgraph.operands));

    const auto consumerCounts =
            NN_TRY(countNumberOfConsumers(operands.size(), mainSubgraph.operations));
    CHECK(operands.size() == consumerCounts.size());
    size_t operandIndex = 0;
    for (auto& halOperand : operands) {
        halOperand.numberOfConsumers = consumerCounts[operandIndex];
        ++operandIndex;
    }

    auto operations = NN_TRY(unvalidatedConvert(mainSubgraph.operations));
    auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
    auto pools = NN_TRY(unvalidatedConvert(model.pools));
    auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
    return Model{
            .operands = std::move(operands),
            .operations = std::move(operations),
            .inputIndexes = mainSubgraph.inputIndexes,
            .outputIndexes = mainSubgraph.outputIndexes,
            .operandValues = std::move(operandValues),
            .pools = std::move(pools),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = std::move(extensionNameToPrefix),
    };
}
522
// Copies a canonical extension name/prefix pair into the HAL struct.
nn::GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::ExtensionNameAndPrefix& extensionNameAndPrefix) {
    Model::ExtensionNameAndPrefix converted{
            .name = extensionNameAndPrefix.name,
            .prefix = extensionNameAndPrefix.prefix,
    };
    return converted;
}
530
// Copies a canonical output shape into the HAL struct.
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape) {
    OutputShape converted{
            .dimensions = outputShape.dimensions,
            .isSufficient = outputShape.isSufficient,
    };
    return converted;
}
535
// Plain numeric cast. It relies on matching MeasureTiming enumerator values.
nn::GeneralResult<MeasureTiming> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
    const auto halMeasureTiming = static_cast<MeasureTiming>(measureTiming);
    return halMeasureTiming;
}
539
// Converts canonical optional durations to HAL timing in unsigned microseconds.
// - An empty optional becomes the all-ones "no timing" sentinel.
// - Present values are rounded up (std::chrono::ceil) to whole microseconds.
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
    const auto toHal = [](nn::OptionalDuration canonicalTiming) -> uint64_t {
        constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
        if (canonicalTiming.has_value()) {
            return std::chrono::ceil<HalDuration>(*canonicalTiming).count();
        }
        return kNoTiming;
    };

    const uint64_t onDevice = toHal(timing.timeOnDevice);
    const uint64_t inDriver = toHal(timing.timeInDriver);
    return Timing{.timeOnDevice = onDevice, .timeInDriver = inDriver};
}
551
// Converts a canonical extension description, including its operand type information list.
nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension) {
    auto halOperandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes));
    return Extension{
            .name = extension.name,
            .operandTypes = std::move(halOperandTypes),
    };
}
559
// Copies one canonical extension operand type descriptor into the HAL struct.
nn::GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
        const nn::Extension::OperandTypeInformation& operandTypeInformation) {
    Extension::OperandTypeInformation converted{
            .type = operandTypeInformation.type,
            .isTensor = operandTypeInformation.isTensor,
            .byteSize = operandTypeInformation.byteSize,
    };
    return converted;
}
568
// Public canonical -> HAL V1_2 entry points. Each delegates to validatedConvert(), which checks
// that the canonical input is V1_2-compliant before converting it. For vectors the check is
// applied to each element.

nn::GeneralResult<DeviceType> convert(const nn::DeviceType& deviceType) {
    return validatedConvert(deviceType);
}

nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
    return validatedConvert(capabilities);
}

nn::GeneralResult<Model> convert(const nn::Model& model) {
    return validatedConvert(model);
}

nn::GeneralResult<MeasureTiming> convert(const nn::MeasureTiming& measureTiming) {
    return validatedConvert(measureTiming);
}

nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
    return validatedConvert(timing);
}

nn::GeneralResult<hidl_vec<Extension>> convert(const std::vector<nn::Extension>& extensions) {
    return validatedConvert(extensions);
}

nn::GeneralResult<hidl_vec<hidl_handle>> convert(const std::vector<nn::SharedHandle>& handles) {
    return validatedConvert(handles);
}

nn::GeneralResult<hidl_vec<OutputShape>> convert(const std::vector<nn::OutputShape>& outputShapes) {
    return validatedConvert(outputShapes);
}

// These conversions produce V1_0/V1_1 HAL types and are delegated to the V1_1 utilities.

nn::GeneralResult<V1_0::DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
    return V1_1::utils::convert(deviceStatus);
}

nn::GeneralResult<V1_0::Request> convert(const nn::Request& request) {
    return V1_1::utils::convert(request);
}

nn::GeneralResult<V1_0::ErrorStatus> convert(const nn::ErrorStatus& status) {
    return V1_1::utils::convert(status);
}

nn::GeneralResult<V1_1::ExecutionPreference> convert(
        const nn::ExecutionPreference& executionPreference) {
    return V1_1::utils::convert(executionPreference);
}
617
618 } // namespace android::hardware::neuralnetworks::V1_2::utils
619