/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Validation.h"

#include <android-base/logging.h>
#include <android-base/mapped_file.h>

#include <algorithm>
#include <cctype>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include "ControlFlow.h"
#include "OperandTypes.h"
#include "OperationTypes.h"
#include "OperationsUtils.h"
#include "OperationsValidationUtils.h"
#include "Result.h"
#include "SharedMemory.h"
#include "TypeUtils.h"
#include "Types.h"

namespace android::nn {

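// Use an X-macro to forward declare a validation function for every operation type.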
#define NN_FORWARD_DECLARE_VALIDATION_FUNCTION(opType) NN_VALIDATION_FUNCTION_SIGNATURE(opType);

NN_FOR_EACH_OPERATION(NN_FORWARD_DECLARE_VALIDATION_FUNCTION)

#undef NN_FORWARD_DECLARE_VALIDATION_FUNCTION

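// IF, WHILE, and OEM_OPERATION are handled outside the per-operation registration mechanism, so
// their registered validation functions are stubs; reaching one at runtime indicates a bug, and
// LOG(FATAL) aborts the process.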
static Result<Version> notImplementedThroughRegistration(
        const IOperationValidationContext* context) {
    LOG(FATAL) << "Operation " << context->getOperationName()
               << " not supported through registration";
    return NN_ERROR();
}

NN_DEFINE_VALIDATION_FUNCTION(IF, notImplementedThroughRegistration);
NN_DEFINE_VALIDATION_FUNCTION(WHILE, notImplementedThroughRegistration);
NN_DEFINE_VALIDATION_FUNCTION(OEM_OPERATION, notImplementedThroughRegistration);

namespace {

constexpr auto kNullptrVariant = std::variant<const void*, void*>{};
constexpr auto kInvalidMemoryDomainToken = Request::MemoryDomainToken{};

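// Validates each element of `objects` with `validationFunction` and combines the versions
// required by every element into a single result.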
template <typename Type, typename ValidationFunction>
Result<Version> validateVector(const std::vector<Type>& objects,
                               const ValidationFunction& validationFunction) {
    auto version = kVersionFeatureLevel1;
    for (const auto& object : objects) {
        version = combineVersions(version, NN_TRY(validationFunction(object)));
    }
    return version;
}

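// Extension names follow a reverse-domain-name style: only lowercase letters, digits, periods,
// and underscores are allowed, and the name must contain at least one period.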
bool isValidExtensionName(const std::string& name) {
    constexpr auto validSymbol = [](char symbol) {
        return std::islower(symbol) || std::isdigit(symbol) || symbol == '.' || symbol == '_';
    };
    const bool hasOnlyValidSymbols = std::all_of(name.begin(), name.end(), validSymbol);
    const bool hasAtLeastOnePeriod = std::find(name.begin(), name.end(), '.') != name.end();
    return hasOnlyValidSymbols && hasAtLeastOnePeriod;
}

Result<Version> validateDeviceStatus(const DeviceStatus& deviceStatus) {
    switch (deviceStatus) {
        case DeviceStatus::AVAILABLE:
        case DeviceStatus::BUSY:
        case DeviceStatus::OFFLINE:
        case DeviceStatus::UNKNOWN:
            return kVersionFeatureLevel1;
    }
    NN_RET_CHECK_FAIL() << "Invalid DeviceStatus " << deviceStatus;
}

Result<Version> validateExecutionPreference(const ExecutionPreference& executionPreference) {
    switch (executionPreference) {
        case ExecutionPreference::FAST_SINGLE_ANSWER:
            // ExecutionPreference::FAST_SINGLE_ANSWER is the default value, so it is implicitly
            // valid for all versions.
            return kVersionFeatureLevel1;
        case ExecutionPreference::LOW_POWER:
        case ExecutionPreference::SUSTAINED_SPEED:
            return kVersionFeatureLevel2;
    }
    NN_RET_CHECK_FAIL() << "Invalid ExecutionPreference " << executionPreference;
}

Result<Version> validateDeviceType(const DeviceType& deviceType) {
    switch (deviceType) {
        case DeviceType::UNKNOWN:
            // DeviceType was introduced in the 1.2 NN HAL. DeviceType::UNKNOWN is returned when
            // querying versions that are prior to the 1.2 NN HAL. DeviceType::UNKNOWN is not a
            // valid code to return for a driver that implements at least the 1.2 NN HAL. If we
            // need a range of versions, make ANDROID_Q (NN HAL 1.2) the exclusive upper bound for
            // DeviceType::UNKNOWN.
            return kVersionFeatureLevel1;
        case DeviceType::OTHER:
        case DeviceType::CPU:
        case DeviceType::GPU:
        case DeviceType::ACCELERATOR:
            return kVersionFeatureLevel3;
    }
    NN_RET_CHECK_FAIL() << "Invalid DeviceType " << deviceType;
}

Result<Version> validateMeasureTiming(const MeasureTiming& measureTiming) {
    switch (measureTiming) {
        case MeasureTiming::NO:
            // MeasureTiming::NO is the default value, so it is implicitly valid for all versions.
            return kVersionFeatureLevel1;
        case MeasureTiming::YES:
            return kVersionFeatureLevel3;
    }
    NN_RET_CHECK_FAIL() << "Invalid MeasureTiming " << measureTiming;
}

Result<Version> validateOperandType(const OperandType& operandType) {
    switch (operandType) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
            return kVersionFeatureLevel1;
        case OperandType::BOOL:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::FLOAT16:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
            return kVersionFeatureLevel3;
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case OperandType::SUBGRAPH:
            return kVersionFeatureLevel4;
    }
    if (isExtension(operandType)) {
        return kVersionFeatureLevel3;
    }
    NN_RET_CHECK_FAIL() << "Invalid OperandType " << operandType;
}

Result<Version> validateOperandLifeTime(const Operand& operand) {
    // Make sure SUBGRAPH operand type and lifetime always go together.
    NN_RET_CHECK_EQ((operand.type == OperandType::SUBGRAPH),
                    (operand.lifetime == Operand::LifeTime::SUBGRAPH))
            << "Operand of type " << operand.type << " cannot have lifetime " << operand.lifetime;

    switch (operand.lifetime) {
        case Operand::LifeTime::TEMPORARY_VARIABLE:
        case Operand::LifeTime::SUBGRAPH_INPUT:
        case Operand::LifeTime::SUBGRAPH_OUTPUT:
        case Operand::LifeTime::CONSTANT_COPY:
        case Operand::LifeTime::CONSTANT_REFERENCE:
        case Operand::LifeTime::NO_VALUE:
        case Operand::LifeTime::POINTER:
            return kVersionFeatureLevel1;
        case Operand::LifeTime::SUBGRAPH:
            return kVersionFeatureLevel4;
    }
    NN_RET_CHECK_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
}

Result<Version> validatePriority(const Priority& priority) {
    switch (priority) {
        case Priority::MEDIUM:
            // Priority::MEDIUM is the default value, so it is implicitly valid for all versions.
            return kVersionFeatureLevel1;
        case Priority::LOW:
        case Priority::HIGH:
            return kVersionFeatureLevel4;
    }
    NN_RET_CHECK_FAIL() << "Invalid Priority " << priority;
}

Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) {
    // Note that MISSED_DEADLINE_*, RESOURCE_EXHAUSTED_*, and DEAD_OBJECT were introduced in
    // ANDROID_R, but these can be cast to ANDROID_OC_MR1 as GENERAL_FAILURE.
    switch (errorStatus) {
        case ErrorStatus::NONE:
        case ErrorStatus::DEVICE_UNAVAILABLE:
        case ErrorStatus::GENERAL_FAILURE:
        case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case ErrorStatus::INVALID_ARGUMENT:
        case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
        case ErrorStatus::DEAD_OBJECT:
            return kVersionFeatureLevel1;
    }
    NN_RET_CHECK_FAIL() << "Invalid ErrorStatus " << errorStatus;
}

Result<Version> validateFusedActivationFunc(const FusedActivationFunc& activation) {
    switch (activation) {
        case FusedActivationFunc::NONE:
        case FusedActivationFunc::RELU:
        case FusedActivationFunc::RELU1:
        case FusedActivationFunc::RELU6:
            return kVersionFeatureLevel1;
    }
    NN_RET_CHECK_FAIL() << "Invalid FusedActivationFunc " << activation;
}

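// OutputShape carries dynamic output dimensions, which (like the other feature-level-3 returns in
// this file) correspond to NN HAL 1.2 / Android Q; the structure itself has no field-level
// constraints to check.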
Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) {
    return kVersionFeatureLevel3;
}

Result<Version> validateTiming(const Timing& timing) {
    constexpr auto kNoTiming = Timing{};
    if (timing == kNoTiming) {
        // kNoTiming is the default value, so it is implicitly valid for all versions.
        return kVersionFeatureLevel1;
    }
    if (timing.timeInDriver.has_value() && timing.timeOnDevice.has_value()) {
        // `lazyMessage` is a lazy function to produce the timing validation error message.
        // Currently, the code is not able to inline the message in NN_RET_CHECK due to an
        // argument-dependent lookup issue with nn::detail::ErrorBuilder interacting with std
        // types such as std::chrono::duration, so this function uses an indirection through
        // std::ostringstream.
        const auto lazyMessage = [&timing]() -> std::string {
            std::ostringstream oss;
            oss << "Timing::timeOnDevice (" << timing.timeOnDevice.value()
                << ") must not exceed Timing::timeInDriver (" << timing.timeInDriver.value()
                << ")";
            return oss.str();
        };
        NN_RET_CHECK(timing.timeOnDevice.value() <= timing.timeInDriver.value()) << lazyMessage();
    }
    return kVersionFeatureLevel3;
}

Result<Version> validateCapabilitiesPerformanceInfo(
        const Capabilities::PerformanceInfo& performanceInfo) {
    NN_RET_CHECK_GT(performanceInfo.execTime, 0.0f);
    NN_RET_CHECK_GT(performanceInfo.powerUsage, 0.0f);
    return kVersionFeatureLevel1;
}

Result<Version> validateCapabilitiesOperandPerformance(
        const Capabilities::OperandPerformance& operandPerformance) {
    auto version = NN_TRY(validateOperandType(operandPerformance.type));
    return combineVersions(version,
                           NN_TRY(validateCapabilitiesPerformanceInfo(operandPerformance.info)));
}

Result<Version> validateCapabilitiesOperandPerformanceTable(
        const Capabilities::OperandPerformanceTable& operandPerformances) {
    // OperandPerformanceTable's order was validated when it was created, and it is castable to any
    // version. If an OperandType does not exist in the lower version being converted to, that
    // OperandPerformance will be dropped.
    NN_TRY(validateVector(operandPerformances.asVector(), validateCapabilitiesOperandPerformance));
    return kVersionFeatureLevel1;
}

Result<Version> validateCapabilities(const Capabilities& capabilities) {
    auto version =
            NN_TRY(validateCapabilitiesOperandPerformanceTable(capabilities.operandPerformance));

    version = combineVersions(version,
                              NN_TRY(validateCapabilitiesPerformanceInfo(
                                      capabilities.relaxedFloat32toFloat16PerformanceScalar)));
    version = combineVersions(version,
                              NN_TRY(validateCapabilitiesPerformanceInfo(
                                      capabilities.relaxedFloat32toFloat16PerformanceTensor)));
    version = combineVersions(
            version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.ifPerformance)));
    version = combineVersions(
            version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.whilePerformance)));

    return version;
}

Result<Version> validateExtensionOperandTypeInformation(
        const Extension::OperandTypeInformation& operandTypeInformation) {
    NN_RET_CHECK_GT(operandTypeInformation.byteSize, 0u);
    return kVersionFeatureLevel3;
}

Result<Version> validateExtension(const Extension& extension) {
    NN_RET_CHECK(isValidExtensionName(extension.name));

    // Verify all OperandTypeInformations have unique types.
    std::vector<uint16_t> types;
    types.reserve(extension.operandTypes.size());
    std::transform(extension.operandTypes.begin(), extension.operandTypes.end(),
                   std::back_inserter(types),
                   [](const Extension::OperandTypeInformation& operandTypeInformation) {
                       return operandTypeInformation.type;
                   });
    std::sort(types.begin(), types.end());
    const auto iter = std::adjacent_find(types.begin(), types.end());
    NN_RET_CHECK(iter == types.end()) << "Extension has duplicate type " << *iter;

    return combineVersions(kVersionFeatureLevel3,
                           NN_TRY(validateVector(extension.operandTypes,
                                                 validateExtensionOperandTypeInformation)));
}

Result<Version> validateExtensions(const std::vector<Extension>& extensions) {
    const auto version = NN_TRY(validateVector(extensions, validateExtension));

    // Verify all extensions have unique names.
    std::vector<std::reference_wrapper<const std::string>> names;
    names.reserve(extensions.size());
    std::transform(extensions.begin(), extensions.end(), std::back_inserter(names),
                   [](const Extension& extension) { return std::cref(extension.name); });
    std::sort(names.begin(), names.end(), std::less<std::string>{});
    const auto nameIter =
            std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
    NN_RET_CHECK(nameIter == names.end())
            << "Two or more extensions have the same name " << nameIter->get();

    return version;
}

// Forward declaration of subgraph validation function.
Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
                                      std::optional<size_t> referencedIndex,
                                      size_t operandValuesSize,
                                      const std::vector<size_t>& poolSizes,
                                      const std::vector<Model::Subgraph>& referenced,
                                      std::vector<std::optional<Version>>* subgraphVersionCache);

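// Checks that an operand's DataLocation is consistent with its lifetime: constants must point
// into the operand-values blob or a memory pool, SUBGRAPH operands use `offset` as an index into
// `subgraphs`, POINTER operands must hold a non-null pointer, and all other lifetimes must have
// an empty location.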
Result<Version> validateOperandDataLocation(
        const Operand& operand, size_t operandValuesSize, const std::vector<size_t>& poolSizes,
        const std::vector<Model::Subgraph>& subgraphs,
        std::vector<std::optional<Version>>* subgraphVersionCache) {
    const DataLocation& location = operand.location;
    NN_RET_CHECK_EQ(location.padding, 0u)
            << "DataLocation with a non-zero padding used in Model: " << location.padding;
    switch (operand.lifetime) {
        case Operand::LifeTime::CONSTANT_COPY:
            NN_RET_CHECK(location.pointer == kNullptrVariant)
                    << "CONSTANT_COPY with a non-null pointer";
            NN_RET_CHECK_EQ(location.poolIndex, 0u)
                    << "CONSTANT_COPY with a non-zero poolIndex " << location.poolIndex;
            // Do the addition using uint64_t to avoid potential wrap-around problems.
            NN_RET_CHECK_LE(static_cast<uint64_t>(location.offset) + location.length,
                            operandValuesSize)
                    << "OperandValue location out of range. Starts at " << location.offset
                    << ", length " << location.length << ", max " << operandValuesSize;
            return kVersionFeatureLevel1;
        case Operand::LifeTime::CONSTANT_REFERENCE:
            NN_RET_CHECK_LT(location.poolIndex, poolSizes.size());
            // Do the addition using uint64_t to avoid potential wrap-around problems.
            NN_RET_CHECK_LE(static_cast<uint64_t>(location.offset) + location.length,
                            poolSizes[location.poolIndex])
                    << "OperandValue location out of range. Starts at " << location.offset
                    << ", length " << location.length << ", max " << poolSizes[location.poolIndex];
            return kVersionFeatureLevel1;
        case Operand::LifeTime::TEMPORARY_VARIABLE:
        case Operand::LifeTime::SUBGRAPH_INPUT:
        case Operand::LifeTime::SUBGRAPH_OUTPUT:
        case Operand::LifeTime::NO_VALUE:
            NN_RET_CHECK(location.pointer == kNullptrVariant)
                    << "Unexpected pointer value for operand of lifetime " << operand.lifetime;
            NN_RET_CHECK_EQ(location.poolIndex, 0u)
                    << "Unexpected poolIndex " << location.poolIndex << " for operand of lifetime "
                    << operand.lifetime;
            NN_RET_CHECK_EQ(location.offset, 0u)
                    << "Unexpected offset " << location.offset << " for operand of lifetime "
                    << operand.lifetime;
            NN_RET_CHECK_EQ(location.length, 0u)
                    << "Unexpected length " << location.length << " for operand of lifetime "
                    << operand.lifetime;
            return kVersionFeatureLevel1;
        case Operand::LifeTime::SUBGRAPH: {
            NN_RET_CHECK(location.pointer == kNullptrVariant) << "SUBGRAPH with a non-null pointer";
            NN_RET_CHECK_EQ(location.poolIndex, 0u)
                    << "SUBGRAPH with a non-zero poolIndex " << location.poolIndex;
            NN_RET_CHECK_LT(location.offset, subgraphs.size())
                    << "Subgraph index out of range: " << location.offset
                    << " >= " << subgraphs.size();
            NN_RET_CHECK_EQ(location.length, 0u)
                    << "SUBGRAPH with a non-zero length " << location.length;
            const auto version = NN_TRY(validateModelSubgraph(
                    subgraphs[location.offset], location.offset, operandValuesSize, poolSizes,
                    subgraphs, subgraphVersionCache));
            return combineVersions(version, kVersionFeatureLevel4);
        }
        case Operand::LifeTime::POINTER: {
            const bool nonNull =
                    std::visit([](auto* ptr) { return ptr != nullptr; }, location.pointer);
            NN_RET_CHECK(nonNull) << "POINTER with a null pointer";
            NN_RET_CHECK_EQ(location.poolIndex, 0u)
                    << "POINTER with a non-zero poolIndex " << location.poolIndex;
            NN_RET_CHECK_EQ(location.offset, 0u)
                    << "POINTER with a non-zero offset " << location.offset;
            return kVersionFeatureLevel1;
        }
    }
    NN_RET_CHECK_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
}

Result<Version> validateOperandDimensions(const Operand& operand) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::FLOAT16:
        case OperandType::SUBGRAPH:
        case OperandType::OEM:
            NN_RET_CHECK(operand.dimensions.empty())
                    << "Scalar data has dimensions of rank " << operand.dimensions.size();
            return kVersionFeatureLevel1;
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case OperandType::TENSOR_OEM_BYTE: {
            if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
                operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
                operand.lifetime == Operand::LifeTime::POINTER) {
                NN_RET_CHECK(!operand.dimensions.empty())
                        << "Tensor has lifetime of " << operand.lifetime
                        << " but dimensions of rank 0";
                const auto size = getNonExtensionSize(operand);
                NN_RET_CHECK(size.has_value()) << "Tensor dimensions overflow";
                NN_RET_CHECK_NE(size.value(), 0u) << "Tensor has at least one unknown dimension";
            }
            // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before
            // Android Q?
            if (operand.dimensions.empty()) {
                // Unspecified rank was added in Android Q.
                return kVersionFeatureLevel3;
            }
            return kVersionFeatureLevel1;
        }
    }
    if (isExtension(operand.type)) {
        // Extension types were added in Android Q.
        return kVersionFeatureLevel3;
    }
    NN_RET_CHECK_FAIL() << "Invalid OperandType " << operand.type;
}

Result<Version> validateOperandScale(const Operand& operand) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::FLOAT16:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::SUBGRAPH:
            NN_RET_CHECK_EQ(operand.scale, 0.0f)
                    << "Operand of type " << operand.type << " with a non-zero scale ("
                    << operand.scale << ")";
            return kVersionFeatureLevel1;
        case OperandType::TENSOR_INT32:
            // TENSOR_INT32 may be used with or without scale, depending on the operation.
            // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
            NN_RET_CHECK_GE(operand.scale, 0.0f)
                    << "Operand of type " << operand.type << " with a negative scale";
            return kVersionFeatureLevel1;
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            NN_RET_CHECK_GT(operand.scale, 0.0f)
                    << "Operand of type " << operand.type << " with a non-positive scale";
            return kVersionFeatureLevel1;
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
            // No validation for OEM types.
            return kVersionFeatureLevel1;
    }
    if (isExtension(operand.type)) {
        NN_RET_CHECK_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type
                                             << " with a non-zero scale (" << operand.scale << ")";
        return kVersionFeatureLevel3;
    }
    NN_RET_CHECK_FAIL() << "Invalid OperandType " << operand.type;
}

Result<Version> validateOperandZeroPoint(const Operand& operand) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::FLOAT16:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::SUBGRAPH:
            NN_RET_CHECK_EQ(operand.zeroPoint, 0)
                    << "Operand of type " << operand.type << " with a non-zero zeroPoint "
                    << operand.zeroPoint;
            return kVersionFeatureLevel1;
        case OperandType::TENSOR_QUANT8_ASYMM:
            NN_RET_CHECK(operand.zeroPoint >= 0 && operand.zeroPoint <= 255)
                    << "Operand of type " << operand.type << " with an invalid zeroPoint "
                    << operand.zeroPoint << ", must be in range [0, 255]";
            return kVersionFeatureLevel1;
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            NN_RET_CHECK(operand.zeroPoint >= -128 && operand.zeroPoint <= 127)
                    << "Operand of type " << operand.type << " with an invalid zeroPoint "
                    << operand.zeroPoint << ", must be in range [-128, 127]";
            return kVersionFeatureLevel1;
        case OperandType::TENSOR_QUANT16_ASYMM:
            NN_RET_CHECK(operand.zeroPoint >= 0 && operand.zeroPoint <= 65535)
                    << "Operand of type " << operand.type << " with an invalid zeroPoint "
                    << operand.zeroPoint << ", must be in range [0, 65535]";
            return kVersionFeatureLevel1;
        case OperandType::TENSOR_QUANT16_SYMM:
            NN_RET_CHECK_EQ(operand.zeroPoint, 0)
                    << "Operand of type " << operand.type << " with a non-zero zeroPoint "
                    << operand.zeroPoint;
            return kVersionFeatureLevel1;
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
            // No validation for OEM types.
            return kVersionFeatureLevel1;
    }
    if (isExtension(operand.type)) {
        NN_RET_CHECK_EQ(operand.zeroPoint, 0)
                << "Operand of type " << operand.type << " with a non-zero zeroPoint "
                << operand.zeroPoint;
        return kVersionFeatureLevel3;
    }
    NN_RET_CHECK_FAIL() << "Invalid OperandType " << operand.type;
}

Result<Version> validateOperandExtraParams(const Operand& operand) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::BOOL:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::FLOAT16:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case OperandType::SUBGRAPH:
            NN_RET_CHECK(std::holds_alternative<Operand::NoParams>(operand.extraParams))
                    << "Operand of type " << operand.type
                    << " has extraParams when there must be none";
            return kVersionFeatureLevel1;
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
            NN_RET_CHECK(std::holds_alternative<Operand::SymmPerChannelQuantParams>(
                    operand.extraParams))
                    << "Operand of type " << operand.type
                    << " without a Channel Quantization params";
            const auto& channelQuant =
                    std::get<Operand::SymmPerChannelQuantParams>(operand.extraParams);

            const size_t count = operand.dimensions.size();
            NN_RET_CHECK_LT(channelQuant.channelDim, count)
                    << "Operand of type " << operand.type
                    << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
                    << ", must be valid dimension index in range [0, " << count << ")";
            const uint32_t expected = operand.dimensions[channelQuant.channelDim];
            NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
                    << "Operand of type " << operand.type << " with a wrong-sized scales, expected "
                    << expected << " was " << channelQuant.scales.size();
            NN_RET_CHECK_NE(expected, 0u)
                    << "Operand of type " << operand.type << " channel dimension "
                    << channelQuant.channelDim << " is underspecified (can't be 0)";
            for (uint32_t i = 0; i < expected; ++i) {
                NN_RET_CHECK_GT(channelQuant.scales[i], 0.0f)
                        << "Operand of type " << operand.type
                        << " with a non-positive value in scales[" << i
                        << "]=" << channelQuant.scales[i];
            }
            return kVersionFeatureLevel3;
        }
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
            // No validation for OEM types.
            return kVersionFeatureLevel1;
    }
    if (isExtension(operand.type)) {
        NN_RET_CHECK(std::holds_alternative<Operand::NoParams>(operand.extraParams) ||
                     std::holds_alternative<Operand::ExtensionParams>(operand.extraParams))
                << "Extension operand of type " << operand.type
                << " must not have SymmPerChannelQuant extraParams";
        return kVersionFeatureLevel1;
    }
    NN_RET_CHECK_FAIL() << "Invalid OperandType " << operand.type;
}

Result<Version> validateOperand(const Operand& operand, size_t operandValuesSize,
                                const std::vector<size_t>& poolSizes,
                                const std::vector<Model::Subgraph>& subgraphs,
                                std::vector<std::optional<Version>>* subgraphVersionCache) {
    auto version = NN_TRY(validateOperandType(operand.type));
    version = combineVersions(version, NN_TRY(validateOperandLifeTime(operand)));
    version = combineVersions(version, NN_TRY(validateOperandDimensions(operand)));
    version = combineVersions(version, NN_TRY(validateOperandScale(operand)));
    version = combineVersions(version, NN_TRY(validateOperandZeroPoint(operand)));
    version = combineVersions(version, NN_TRY(validateOperandExtraParams(operand)));
    version = combineVersions(
            version, NN_TRY(validateOperandDataLocation(operand, operandValuesSize, poolSizes,
                                                        subgraphs, subgraphVersionCache)));

    // For constants, validate that the length is as expected. The other lifetimes
    // expect the length to be 0. Don't validate for OEM types.
    if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
        operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
        operand.lifetime == Operand::LifeTime::POINTER) {
        if (!isExtension(operand.type) && operand.type != OperandType::OEM &&
            operand.type != OperandType::TENSOR_OEM_BYTE) {
            const auto expectedLength = getNonExtensionSize(operand).value();
            NN_RET_CHECK_EQ(operand.location.length, expectedLength)
                    << "For operand " << operand.type << " expected a size of " << expectedLength
                    << " but got " << operand.location.length;
        }
    }

    return version;
}

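// Validates every operand in the list, annotating any failure message with the index of the
// offending operand.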
Result<std::vector<Version>> validateOperands(
        const std::vector<Operand>& operands, size_t operandValuesSize,
        const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs,
        std::vector<std::optional<Version>>* subgraphVersionCache) {
    std::vector<Version> versions;
    versions.reserve(operands.size());
    for (size_t i = 0; i < operands.size(); ++i) {
        auto result = validateOperand(operands[i], operandValuesSize, poolSizes, subgraphs,
                                      subgraphVersionCache);
        if (!result.has_value()) {
            return error() << std::move(result).error() << " for operand " << i;
        }
        versions.push_back(result.value());
    }
    return versions;
}

// Forward declaration.
Result<Version> validateOperationIncludingOperandVersions(
        const Operation& operation, const std::vector<Operand>& operands,
        const std::vector<Version>& operandVersions, const std::vector<Model::Subgraph>& subgraphs);

Result<Version> validateOperations(const std::vector<Operation>& operations,
                                   const std::vector<Operand>& operands,
                                   const std::vector<Version>& operandVersions,
                                   const std::vector<Model::Subgraph>& subgraphs) {
    auto version = kVersionFeatureLevel1;
    for (size_t i = 0; i < operations.size(); ++i) {
        auto result = validateOperationIncludingOperandVersions(operations[i], operands,
                                                                operandVersions, subgraphs);
        if (!result.has_value()) {
            return error() << std::move(result).error() << " for operation " << i;
        }
        version = combineVersions(version, result.value());
    }
    return version;
}

Result<Version> validateUnknownHandle(const Memory::Unknown::Handle& handle) {
    NN_RET_CHECK(std::all_of(handle.fds.begin(), handle.fds.end(),
                             [](const base::unique_fd& fd) { return fd.ok(); }));
    return kVersionFeatureLevel3;
}

Result<Version> validateSharedHandle(const SharedHandle& handle) {
    // The absence of a shared handle is implicitly valid for all versions.
    if (handle == nullptr) {
        return kVersionFeatureLevel1;
    }
    NN_RET_CHECK(handle->ok());
    return kVersionFeatureLevel3;
}

Result<Version> validateMemory(const Memory::Ashmem& memory) {
    NN_RET_CHECK(memory.fd.ok());
    NN_RET_CHECK_NE(memory.size, 0u);
    return kVersionFeatureLevel1;
}

Result<Version> validateMemory(const Memory::Fd& memory) {
    NN_RET_CHECK(memory.fd.ok());
    NN_RET_CHECK_NE(memory.size, 0u);

    // `prot` is allowed to be either PROT_NONE (which has a value of 0) or the bitwise OR of
    // PROT_READ and/or PROT_WRITE. If any other bits are set, the `prot` field is invalid.
    constexpr int kAllowedBits = PROT_READ | PROT_WRITE;
    NN_RET_CHECK_EQ(memory.prot & ~kAllowedBits, 0);

    return kVersionFeatureLevel1;
}

Result<Version> validateMemory(const Memory::HardwareBuffer& memory) {
    NN_RET_CHECK(memory.handle.get() != nullptr);
    return kVersionFeatureLevel3;
}

Result<Version> validateMemory(const Memory::Unknown& memory) {
    NN_TRY(validateUnknownHandle(memory.handle));
    return kVersionFeatureLevel3;
}

Result<Version> validateSharedMemory(const SharedMemory& memory) {
    NN_RET_CHECK(memory != nullptr);
    return std::visit([](const auto& x) { return validateMemory(x); }, memory->handle);
}

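// Checks that every index in `indexes` refers to an operand with the given lifetime, that no
// index appears more than once, and that every operand with that lifetime is listed in `indexes`.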
Result<void> validateModelSubgraphInputOutputs(const std::vector<uint32_t>& indexes,
                                               const std::vector<Operand>& operands,
                                               Operand::LifeTime lifetime) {
    const size_t operandCount = operands.size();
    for (uint32_t i : indexes) {
        NN_RET_CHECK_LT(i, operandCount)
                << "Model " << lifetime << " input or output index out of range: " << i << "/"
                << operandCount;
        const Operand& operand = operands[i];
        NN_RET_CHECK_EQ(operand.lifetime, lifetime)
                << "Model " << lifetime << " operand " << i << " has lifetime of "
                << operand.lifetime << " instead of the expected " << lifetime;
    }

    std::vector<uint32_t> sortedIndexes = indexes;
    std::sort(sortedIndexes.begin(), sortedIndexes.end());
    const auto iter = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
    NN_RET_CHECK(iter == sortedIndexes.end())
            << "Model input or output occurs multiple times: " << *iter;

    for (size_t i = 0; i < operands.size(); ++i) {
        if (operands[i].lifetime == lifetime) {
            const auto containsIndex = [&sortedIndexes](size_t index) {
                return std::binary_search(sortedIndexes.begin(), sortedIndexes.end(), index);
            };
            NN_RET_CHECK(containsIndex(i))
                    << "Operand " << i << " marked as " << lifetime
                    << " but is not included in Model input or output indexes";
        }
    }

    return {};
}

Result<void> validateExecutionOrder(const Model::Subgraph& subgraph) {
    // Either the operand has a known value before model execution begins, or we've seen a writer
    // for this operand while walking operands in execution order. Initialize to known operands.
    std::vector<bool> operandValueKnown;
    operandValueKnown.reserve(subgraph.operands.size());
    std::transform(subgraph.operands.begin(), subgraph.operands.end(),
                   std::back_inserter(operandValueKnown), [](const Operand& operand) {
                       return operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE &&
                              operand.lifetime != Operand::LifeTime::SUBGRAPH_OUTPUT;
                   });

    // Validate that operations are sorted into execution order.
    //
    // If there is a cycle in the graph, the operations will not
    // appear to be sorted into execution order: Some operation will
    // have an input for which operandValueKnown[] is false.
    for (size_t i = 0; i < subgraph.operations.size(); ++i) {
        const auto& operation = subgraph.operations[i];

        for (size_t j = 0; j < operation.inputs.size(); ++j) {
            const uint32_t k = operation.inputs[j];
            NN_RET_CHECK(operandValueKnown[k])
                    << "Operation " << i << " input " << j << " (operand " << k
                    << ") is read before it is written";
        }

        for (size_t j = 0; j < operation.outputs.size(); ++j) {
            const uint32_t k = operation.outputs[j];
            // Assuming validateOperations() has not returned an error, we know that this output is
            // TEMPORARY_VARIABLE or MODEL_OUTPUT, and so the only way operandValueKnown[k] can be
            // true is if we've already seen a writer for this operand.
            NN_RET_CHECK(!operandValueKnown[k])
                    << "Operation " << i << " output " << j << " (operand " << k
                    << ") has already been written";
            operandValueKnown[k] = true;
        }
    }

    // Verify all operands are written.
    for (size_t i = 0; i < subgraph.operands.size(); ++i) {
        NN_RET_CHECK(operandValueKnown[i]) << "Operand " << i << " is never written";
    }

    // TODO(b/77871786): verify that every operation has at least one output operand that is read?

    return {};
}

// Validate a subgraph, ensuring all subgraphs it depends on are also validated.
//
// `referencedIndex` is empty if the subgraph being validated is the main subgraph, otherwise it is
// the index of the referenced subgraph being validated.
//
// referenced[i] and (*subgraphVersionCache)[i] correspond to the same subgraph, and therefore
// `referenced` and `subgraphVersionCache` must have the same length.
Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
                                      std::optional<size_t> referencedIndex,
                                      size_t operandValuesSize,
                                      const std::vector<size_t>& poolSizes,
                                      const std::vector<Model::Subgraph>& referenced,
                                      std::vector<std::optional<Version>>* subgraphVersionCache) {
    CHECK(subgraphVersionCache != nullptr);
    CHECK_EQ(referenced.size(), subgraphVersionCache->size());

    // Quickly return if the current subgraph has already been checked for its version.
    if (referencedIndex.has_value()) {
        if (auto version = subgraphVersionCache->at(*referencedIndex)) {
            return *version;
        }
    }

    NN_RET_CHECK(!subgraph.operands.empty());
    NN_RET_CHECK(!subgraph.operations.empty());
    // TODO(b/173780642): Clarify whether subgraphs with no inputs or outputs are valid.
    // NN_RET_CHECK(!subgraph.inputIndexes.empty());
    // NN_RET_CHECK(!subgraph.outputIndexes.empty());

    const auto operandVersions = NN_TRY(validateOperands(
            subgraph.operands, operandValuesSize, poolSizes, referenced, subgraphVersionCache));
    const auto operationsVersion = NN_TRY(validateOperations(subgraph.operations, subgraph.operands,
                                                             operandVersions, referenced));

    // Accumulate the versions from all operands and operations.
    const auto version = std::accumulate(operandVersions.begin(), operandVersions.end(),
                                         operationsVersion, combineVersions);

    NN_TRY(validateModelSubgraphInputOutputs(subgraph.inputIndexes, subgraph.operands,
                                             Operand::LifeTime::SUBGRAPH_INPUT));
    NN_TRY(validateModelSubgraphInputOutputs(subgraph.outputIndexes, subgraph.operands,
                                             Operand::LifeTime::SUBGRAPH_OUTPUT));

    NN_TRY(validateExecutionOrder(subgraph));

    // Mark the current subgraph as having already been validated so the caller can quickly return
    // if this subgraph is checked again.
    if (referencedIndex.has_value()) {
        subgraphVersionCache->at(*referencedIndex) = version;
    }
    return version;
}

Result<Version> validateExtensionNamesAndPrefixes(
        const std::vector<ExtensionNameAndPrefix>& extensionNamesAndPrefixes) {
    for (const auto& extensionNameAndPrefix : extensionNamesAndPrefixes) {
        NN_RET_CHECK(isValidExtensionName(extensionNameAndPrefix.name));
    }

    std::vector<std::reference_wrapper<const std::string>> names;
    names.reserve(extensionNamesAndPrefixes.size());
    std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
                   std::back_inserter(names),
                   [](const ExtensionNameAndPrefix& extensionNameAndPrefix) {
                       return std::cref(extensionNameAndPrefix.name);
                   });
    std::sort(names.begin(), names.end(), std::less<std::string>{});
    const auto nameIter =
            std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
    NN_RET_CHECK(nameIter == names.end())
            << "ExtensionNamesAndPrefixes has duplicate name " << nameIter->get();

    std::vector<uint16_t> types;
    types.reserve(extensionNamesAndPrefixes.size());
    std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
                   std::back_inserter(types),
                   [](const ExtensionNameAndPrefix& extensionNameAndPrefix) {
                       return extensionNameAndPrefix.prefix;
                   });
    std::sort(types.begin(), types.end());
    const auto typeIter = std::adjacent_find(types.begin(), types.end());
    NN_RET_CHECK(typeIter == types.end())
            << "ExtensionNamesAndPrefixes has duplicate type " << *typeIter;

    const bool hasExtensions = !extensionNamesAndPrefixes.empty();
    return hasExtensions ? kVersionFeatureLevel3 : kVersionFeatureLevel1;
}

// Makes sure the model does not contain subgraph reference cycles.
//
// This function verifies that referencedSubgraphs[subgraphIndex] and any subgraphs it references
// do not contain any reference cycles. `path` is used to keep track of which referenced subgraphs
// have already been visited in the current recursive reference path. `verified` is a cache to keep
// track of which referenced subgraphs have already been verified not to form reference cycles.
//
// referencedSubgraphs[i], (*path)[i], and (*verified)[i] all correspond to the same subgraph, and
// therefore `referencedSubgraphs`, `path`, and `verified` must all have the same length.
Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs,
                                    uint32_t subgraphIndex, std::vector<bool>* path,
                                    std::vector<bool>* verified) {
    CHECK(path != nullptr);
    CHECK(verified != nullptr);
    CHECK_EQ(referencedSubgraphs.size(), path->size());
    CHECK_EQ(referencedSubgraphs.size(), verified->size());
    NN_RET_CHECK_LT(subgraphIndex, referencedSubgraphs.size());
    const auto& subgraph = referencedSubgraphs.at(subgraphIndex);

    // Quickly return if the current subgraph has already been verified to have no reference
    // cycles.
    if ((*verified)[subgraphIndex]) {
        return {};
    }

    // Add the current subgraph to the path (making sure that it is not already part of the path),
    // and verify that all subgraphs this subgraph references do not contain cycles. The current
    // subgraph is removed from the path only after all subgraphs this subgraph references have
    // been checked.
    NN_RET_CHECK((*path)[subgraphIndex] == false) << "Model contains a circular subgraph reference";
    (*path)[subgraphIndex] = true;
    for (const Operand& operand : subgraph.operands) {
        if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
            const uint32_t refSubgraphIndex = operand.location.offset;
            NN_TRY(checkNoReferenceCycles(referencedSubgraphs, refSubgraphIndex, path, verified));
        }
    }
    (*path)[subgraphIndex] = false;

    // Mark the current subgraph as having already been verified so the caller can quickly return
    // if this subgraph is checked again.
    (*verified)[subgraphIndex] = true;
    return {};
}

Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs) {
    const size_t count = referencedSubgraphs.size();
    std::vector<bool> path(count);
    std::vector<bool> verified(count);
    for (size_t i = 0; i < count; ++i) {
        NN_TRY(checkNoReferenceCycles(referencedSubgraphs, i, &path, &verified));
    }
    return {};
}

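// Validates an entire model: its memory pools, extension name/prefix table, referenced subgraphs
// (including a check for reference cycles), and finally the main subgraph. Returns the combined
// version required to represent the model.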
Result<Version> validateModel(const Model& model) {
    auto version = NN_TRY(validateVector(model.pools, validateSharedMemory));
    version = combineVersions(
            version, NN_TRY(validateExtensionNamesAndPrefixes(model.extensionNameToPrefix)));

    // Ignore relaxComputationFloat32toFloat16 version because in the worst case it makes the
    // execution stricter.

    // Referenced models were introduced in Android R.
    const bool hasReferencedModels = !model.referenced.empty();
    const auto referenceModelVersion =
            hasReferencedModels ? kVersionFeatureLevel4 : kVersionFeatureLevel1;
    version = combineVersions(version, referenceModelVersion);

    // Ensure that there are no cycles formed by the subgraphs.
    NN_TRY(checkNoReferenceCycles(model.referenced));

    // Get memory sizes.
    const auto [operandValuesSize, poolSizes] = getMemorySizes(model);

    // Validate referenced subgraphs.
    auto subgraphVersionCache = std::vector<std::optional<Version>>(model.referenced.size());
    for (size_t referencedIndex = 0; referencedIndex < model.referenced.size();
         ++referencedIndex) {
        const auto& subgraph = model.referenced[referencedIndex];
        const auto subgraphVersion =
                NN_TRY(validateModelSubgraph(subgraph, referencedIndex, operandValuesSize,
                                             poolSizes, model.referenced, &subgraphVersionCache));
        version = combineVersions(version, subgraphVersion);
    }

    // Validate main subgraph.
    const auto subgraphVersion =
            NN_TRY(validateModelSubgraph(model.main, std::nullopt, operandValuesSize, poolSizes,
                                         model.referenced, &subgraphVersionCache));
    version = combineVersions(version, subgraphVersion);

    return version;
}

Result<Version> validateBufferDesc(const BufferDesc& bufferDesc) {
    // An empty BufferDesc is the default value, so it is implicitly valid for all versions.
    return bufferDesc.dimensions.empty() ? kVersionFeatureLevel1 : kVersionFeatureLevel4;
}

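// A BufferRole's probability must lie in the half-open interval (0.0, 1.0].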
Result<Version> validateBufferRole(const BufferRole& bufferRole) {
    NN_RET_CHECK_GT(bufferRole.probability, 0.0f);
    NN_RET_CHECK_LE(bufferRole.probability, 1.0f);
    return kVersionFeatureLevel4;
}

Result<Version> validateRequestArgument(const Request::Argument& requestArgument,
                                        const std::vector<size_t>& memorySizes, bool isOutput) {
    const auto lifetime = requestArgument.lifetime;
    const auto& location = requestArgument.location;
    const auto& dimensions = requestArgument.dimensions;

    switch (lifetime) {
        case Request::Argument::LifeTime::POOL: {
            NN_RET_CHECK(location.pointer == kNullptrVariant);
            NN_RET_CHECK_LT(location.poolIndex, memorySizes.size());
            // Do the addition using uint64_t to avoid potential wrap-around problems.
            const auto lastPosition =
                    static_cast<uint64_t>(location.offset) + location.length + location.padding;
            const auto memorySize = memorySizes[location.poolIndex];
            NN_RET_CHECK_LE(lastPosition, memorySize);
            if (memorySize > 0) {
                // Must specify a positive length if the memory pool has a known size.
                NN_RET_CHECK_GT(location.length, 0u);
            }
            return kVersionFeatureLevel1;
        }
        case Request::Argument::LifeTime::NO_VALUE:
            NN_RET_CHECK(location.pointer == kNullptrVariant);
            NN_RET_CHECK_EQ(location.poolIndex, 0u);
            NN_RET_CHECK_EQ(location.offset, 0u);
            NN_RET_CHECK_EQ(location.length, 0u);
            NN_RET_CHECK_EQ(location.padding, 0u);
            NN_RET_CHECK(dimensions.empty());
            return kVersionFeatureLevel1;
        case Request::Argument::LifeTime::POINTER: {
            const bool isNullptr =
                    std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
            NN_RET_CHECK(!isNullptr);
            NN_RET_CHECK_EQ(location.poolIndex, 0u);
            NN_RET_CHECK_EQ(location.offset, 0u);
            NN_RET_CHECK_NE(location.length, 0u);
            if (isOutput) {
                NN_RET_CHECK(std::holds_alternative<void*>(location.pointer));
            }
            return kVersionFeatureLevel1;
        }
    }
    NN_RET_CHECK_FAIL() << "Invalid Request::Argument::LifeTime " << lifetime;
}

Result<Version> validateRequestMemoryPool(const Request::MemoryPool& memoryPool) {
    if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
        NN_RET_CHECK(std::get<Request::MemoryDomainToken>(memoryPool) !=
                     kInvalidMemoryDomainToken);
        return kVersionFeatureLevel4;
    }
    if (std::holds_alternative<SharedBuffer>(memoryPool)) {
        NN_RET_CHECK(std::get<SharedBuffer>(memoryPool) != nullptr);
        return kVersionFeatureLevel4;
    }
    return validateSharedMemory(std::get<SharedMemory>(memoryPool));
}

Result<Version> validateRequest(const Request& request) {
    auto version = NN_TRY(validateVector(request.pools, validateRequestMemoryPool));

    // Get memory sizes. For IBuffer or MemoryDomainToken types, set size to 0.
    std::vector<size_t> memorySizes;
    memorySizes.reserve(request.pools.size());
    std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(memorySizes),
                   [](const Request::MemoryPool& memoryPool) {
                       const auto* memory = std::get_if<SharedMemory>(&memoryPool);
                       return memory != nullptr ? getSize(*memory) : 0;
                   });

    for (size_t i = 0; i < request.inputs.size(); ++i) {
        const auto& input = request.inputs[i];
        auto result = validateRequestArgument(input, memorySizes, /*isOutput=*/false);
        if (!result.has_value()) {
            return error() << std::move(result).error() << " for input RequestArgument " << i;
        }
        version = combineVersions(version, result.value());
    }
    for (size_t i = 0; i < request.outputs.size(); ++i) {
        const auto& output = request.outputs[i];
        auto result = validateRequestArgument(output, memorySizes, /*isOutput=*/true);
        if (!result.has_value()) {
            return error() << std::move(result).error() << " for output RequestArgument " << i;
        }
        version = combineVersions(version, result.value());
    }

    return version;
}

Result<Version> validateOptionalTimePoint(const OptionalTimePoint& optionalTimePoint) {
    if (optionalTimePoint.has_value()) {
        NN_RET_CHECK_GE(optionalTimePoint->time_since_epoch().count(), 0);
    }
    // An omitted time point is the default value, so it is implicitly valid for all versions.
    return !optionalTimePoint.has_value() ? kVersionFeatureLevel1 : kVersionFeatureLevel4;
}

Result<Version> validateOptionalTimeoutDuration(const OptionalDuration& optionalTimeoutDuration) {
    if (optionalTimeoutDuration.has_value()) {
        NN_RET_CHECK_GE(optionalTimeoutDuration->count(), 0);
    }
    // An omitted duration is the default value, so it is implicitly valid for all versions.
    return !optionalTimeoutDuration.has_value() ? kVersionFeatureLevel1 : kVersionFeatureLevel4;
}

Result<Version> validateCacheToken(const CacheToken& cacheToken) {
    // A CacheToken of 0 is the default value, so it is implicitly valid for all versions.
    constexpr auto kDefaultCacheToken = CacheToken{};
    return cacheToken == kDefaultCacheToken ? kVersionFeatureLevel1 : kVersionFeatureLevel3;
}

Result<Version> validateSyncFence(const SyncFence& syncFence) {
    // The absence of a sync fence is implicitly valid for all versions.
    if (!syncFence.hasFd()) {
        return kVersionFeatureLevel1;
    }
    NN_RET_CHECK_GE(syncFence.getFd(), 0);
    return kVersionFeatureLevel4;
}

Result<Version> validateTokenValuePair(const TokenValuePair& /*tokenValuePair*/) {
    return kVersionFeatureLevel8;
}

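// Checks each request argument against the corresponding operand of the model's main subgraph:
// the argument count must match the model's input/output count, specified dimensions must agree
// with (or refine) the model's, and any specified buffer length must match the operand's size.
// Unspecified ranks or dimensions are only permitted for outputs, and only when
// `allowUnspecifiedOutput` is true.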
Result<Version> validateRequestArgumentsForModel(
        const std::vector<Request::Argument>& requestArguments,
        const std::vector<uint32_t>& operandIndexes, const std::vector<Operand>& operands,
        bool isOutput, bool allowUnspecifiedOutput) {
    auto version = kVersionFeatureLevel1;
    // The request should specify as many arguments as were described in the model.
    const std::string_view type = isOutput ? "output" : "input";
    const size_t requestArgumentCount = requestArguments.size();
    NN_RET_CHECK_EQ(requestArgumentCount, operandIndexes.size())
            << "Request specifies " << requestArgumentCount << " " << type
            << "s but the model has " << operandIndexes.size();
    for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
         requestArgumentIndex++) {
        const Request::Argument& requestArgument = requestArguments[requestArgumentIndex];
        // Get the operand index for this argument. We extract it from the list
        // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
        // We assume in this function that the model has been validated already.
        const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
        const Operand& operand = operands[operandIndex];
        if (requestArgument.lifetime != Request::Argument::LifeTime::NO_VALUE) {
            const bool isExtensionType = isExtension(operand.type);
            // If the argument specified a dimension, validate it.
            uint32_t modelRank = operand.dimensions.size();
            uint32_t requestRank = requestArgument.dimensions.size();
            if (requestRank == 0) {
                // NOTE: validateRequestArguments cannot validate unknown tensor rank with
                // extension operand type.
                if (!isExtensionType && !isNonExtensionScalar(operand.type)) {
                    if (modelRank == 0) {
                        NN_RET_CHECK(isOutput)
                                << "Model has unknown input rank but the request does not "
                                   "specify the rank.";
                        NN_RET_CHECK(allowUnspecifiedOutput)
                                << "Model has unknown output rank and request does not specify it.";
                        // Unspecified output dimensions introduced in Android Q.
                        version = combineVersions(version, kVersionFeatureLevel3);
                    }
                }
                // Validate that all the dimensions are specified in the model.
                for (size_t i = 0; i < modelRank; i++) {
                    if (operand.dimensions[i] == 0) {
                        NN_RET_CHECK(isOutput && allowUnspecifiedOutput)
                                << "Model has dimension " << i
                                << " set to 0 but the request does not specify the dimension.";
                        // Unspecified output dimensions introduced in Android Q.
                        version = combineVersions(version, kVersionFeatureLevel3);
                    }
                }
            } else {
                NN_RET_CHECK(modelRank == 0 || requestRank == modelRank)
                        << "Request " << type << " " << requestArgumentIndex
                        << " has number of dimensions (" << requestRank
                        << ") different than the model's (" << modelRank << ")";
                for (size_t i = 0; i < requestRank; i++) {
                    NN_RET_CHECK(modelRank == 0 || operand.dimensions[i] == 0 ||
                                 requestArgument.dimensions[i] == operand.dimensions[i])
                            << "Request " << type << " " << requestArgumentIndex
                            << " has dimension " << i << " of " << requestArgument.dimensions[i]
                            << " different than the model's " << operand.dimensions[i];
                    if (requestArgument.dimensions[i] == 0) {
                        NN_RET_CHECK(isOutput && allowUnspecifiedOutput)
                                << "Request " << type << " " << requestArgumentIndex
                                << " has dimension " << i << " of zero";
                        // Unspecified output dimensions introduced in Android Q.
                        version = combineVersions(version, kVersionFeatureLevel3);
                    }
                }
            }
            // NOTE: validateRequestArguments cannot validate DataLocation::length
            // with extension operand type.
            if (!isExtensionType && requestArgument.location.length != 0) {
                const auto dimensions =
                        NN_TRY(combineDimensions(operand.dimensions, requestArgument.dimensions));
                const size_t expectedLength = getNonExtensionSize(operand.type, dimensions).value();
                if (expectedLength != 0) {
                    NN_RET_CHECK_EQ(requestArgument.location.length, expectedLength)
                            << "Request " << type << " " << requestArgumentIndex
                            << " expected a size of " << expectedLength << " but got "
                            << requestArgument.location.length;
                }
            }
        }
    }
    return version;
}

validateRequestForModelImpl(const Request & request,const Model & model,bool allowUnspecifiedOutput)1232 Result<Version> validateRequestForModelImpl(const Request& request, const Model& model,
1233 bool allowUnspecifiedOutput) {
1234 auto version = NN_TRY(validateRequest(request));
1235 version = combineVersions(version,
1236 NN_TRY(validateRequestArgumentsForModel(
1237 request.inputs, model.main.inputIndexes, model.main.operands,
1238 /*isOutput=*/false, /*allowUnspecifiedOutput=*/true)));
1239 version = combineVersions(
1240 version, NN_TRY(validateRequestArgumentsForModel(
1241 request.outputs, model.main.outputIndexes, model.main.operands,
1242 /*isOutput=*/true, allowUnspecifiedOutput)));
1243 return version;
1244 }
1245
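// Validates a driver-allocated memory descriptor: every BufferRole must name a valid
// (preparedModel, ioIndex) pair with a probability in (0.0, 1.0], each role may appear at most
// once, and all referenced operands must agree on type, scale, zeroPoint, and extraParams.
// Illustrative sketch of the dimension merging performed below (the values are this comment's
// own):
//   Dimensions dims = desc.dimensions;        // e.g. {4, 0}
//   dims = combineDimensions(dims, {4, 32});  // -> {4, 32}
//   dims = combineDimensions(dims, {0, 32});  // -> {4, 32}; a 0 defers to the other operand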
Result<Version> validateMemoryDescImpl(
        const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
        const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
        const std::function<const Model*(const SharedPreparedModel&)>& getModel,
        std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
    NN_RET_CHECK(!preparedModels.empty());
    NN_RET_CHECK(!inputRoles.empty() || !outputRoles.empty());

    std::set<PreparedModelRole> roles;
    std::vector<nn::Operand> operands;
    operands.reserve(inputRoles.size() + outputRoles.size());
    for (const auto& role : inputRoles) {
        NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_RET_CHECK(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_RET_CHECK(model != nullptr);
        const auto& inputIndexes = model->main.inputIndexes;
        NN_RET_CHECK_LT(role.ioIndex, inputIndexes.size());
        NN_RET_CHECK_GT(role.probability, 0.0f);
        NN_RET_CHECK_LE(role.probability, 1.0f);
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
        NN_RET_CHECK(success);
        operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
    }
    for (const auto& role : outputRoles) {
        NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_RET_CHECK(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_RET_CHECK(model != nullptr);
        const auto& outputIndexes = model->main.outputIndexes;
        NN_RET_CHECK_LT(role.ioIndex, outputIndexes.size());
        NN_RET_CHECK_GT(role.probability, 0.0f);
        NN_RET_CHECK_LE(role.probability, 1.0f);
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
        NN_RET_CHECK(success);
        operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
    }

    CHECK(!operands.empty());
    const auto opType = operands.front().type;

    Dimensions dimensions = desc.dimensions;
    for (const auto& operand : operands) {
        NN_RET_CHECK_EQ(operand.type, opType) << operand.type << " vs " << operands.front().type;
        NN_RET_CHECK_EQ(operand.scale, operands.front().scale);
        NN_RET_CHECK_EQ(operand.zeroPoint, operands.front().zeroPoint);
        // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
        if (!isExtension(opType)) {
            NN_RET_CHECK_EQ(operand.extraParams, operands.front().extraParams)
                    << operand.extraParams << " vs " << operands.front().extraParams;
        }
        dimensions = NN_TRY(combineDimensions(dimensions, operand.dimensions));
    }

    // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
    if (!isExtension(opType)) {
        NN_RET_CHECK(!isNonExtensionScalar(opType) || dimensions.empty())
                << "invalid dimensions with scalar operand type.";
    }

    if (preparedModelRoles != nullptr) {
        *preparedModelRoles = std::move(roles);
    }
    if (combinedOperand != nullptr) {
        *combinedOperand = operands.front();
        combinedOperand->dimensions = dimensions;
    }
    return kVersionFeatureLevel4;
}

class OperationValidationContext : public IOperationValidationContext {
    DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);

  public:
    OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes,
                               const std::vector<uint32_t>& outputIndexes,
                               const std::vector<Operand>& operands)
        : operationName(operationName),
          inputIndexes(inputIndexes),
          outputIndexes(outputIndexes),
          operands(operands) {}

    const char* getOperationName() const override;

    uint32_t getNumInputs() const override;
    OperandType getInputType(uint32_t index) const override;
    Shape getInputShape(uint32_t index) const override;
    const Operand::ExtraParams& getInputExtraParams(uint32_t index) const override;

    uint32_t getNumOutputs() const override;
    OperandType getOutputType(uint32_t index) const override;
    Shape getOutputShape(uint32_t index) const override;

  private:
    const Operand* getInputOperand(uint32_t index) const;
    const Operand* getOutputOperand(uint32_t index) const;

    const char* operationName;
    const std::vector<uint32_t>& inputIndexes;
    const std::vector<uint32_t>& outputIndexes;
    const std::vector<Operand>& operands;
};

const char* OperationValidationContext::getOperationName() const {
    return operationName;
}

const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
    return &operands.at(inputIndexes.at(index));
}

const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const {
    return &operands.at(outputIndexes.at(index));
}

uint32_t OperationValidationContext::getNumInputs() const {
    auto count = inputIndexes.size();
    CHECK_LE(count, std::numeric_limits<uint32_t>::max());
    return static_cast<uint32_t>(count);
}

uint32_t OperationValidationContext::getNumOutputs() const {
    auto count = outputIndexes.size();
    CHECK_LE(count, std::numeric_limits<uint32_t>::max());
    return static_cast<uint32_t>(count);
}

OperandType OperationValidationContext::getInputType(uint32_t index) const {
    return getInputOperand(index)->type;
}

Shape OperationValidationContext::getInputShape(uint32_t index) const {
    const Operand* operand = getInputOperand(index);
    return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
            operand->extraParams};
}

const Operand::ExtraParams& OperationValidationContext::getInputExtraParams(uint32_t index) const {
    return getInputOperand(index)->extraParams;
}

OperandType OperationValidationContext::getOutputType(uint32_t index) const {
    return getOutputOperand(index)->type;
}

Shape OperationValidationContext::getOutputShape(uint32_t index) const {
    const Operand* operand = getOutputOperand(index);
    return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
            operand->extraParams};
}

// TODO(b/169345292): reduce the duplicate validation here

Result<void> validateOperandSymmPerChannelQuantParamsImpl(
        const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
        const char* tag) {
    if (operand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        NN_RET_CHECK_FAIL() << tag << " invalid operand type " << operand.type;
    }

    NN_RET_CHECK_LT(channelQuant.channelDim, operand.dimensions.size()) << tag;
    NN_RET_CHECK(!channelQuant.scales.empty()) << tag;
    NN_RET_CHECK_EQ(channelQuant.scales.size(), operand.dimensions[channelQuant.channelDim]) << tag;
    NN_RET_CHECK_NE(operand.dimensions[channelQuant.channelDim], 0u)
            << tag << " channel dimension " << channelQuant.channelDim << " is underspecified";
    for (uint32_t i = 0; i < operand.dimensions[channelQuant.channelDim]; i++) {
        NN_RET_CHECK_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]";
    }
    return {};
}

Result<void> validateScalarDimensions(const Operand& type, const char* tag) {
    NN_RET_CHECK(type.dimensions.empty()) << tag << " invalid dimensions for scalar type";
    return {};
}

Result<void> validateQuant8AsymmParams(const Operand& type, const char* tag) {
    NN_RET_CHECK(0 <= type.zeroPoint && type.zeroPoint <= 255)
            << tag << " invalid zeroPoint: " << type.zeroPoint;
    NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}

Result<void> validateQuant8AsymmSignedParams(const Operand& type, const char* tag) {
    NN_RET_CHECK(-128 <= type.zeroPoint && type.zeroPoint <= 127)
            << tag << " invalid zeroPoint: " << type.zeroPoint;
    NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}

Result<void> validateQuant8SymmParams(const Operand& type, const char* tag) {
    NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint;
    NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}

Result<void> validateQuant16AsymmParams(const Operand& type, const char* tag) {
    NN_RET_CHECK(0 <= type.zeroPoint && type.zeroPoint <= 65535)
            << tag << " invalid zeroPoint: " << type.zeroPoint;
    NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}

Result<void> validateQuantSymmParams(const Operand& type, const char* tag) {
    NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
    NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}

Result<void> validateNoQuantParams(const Operand& type, const char* tag) {
    NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
    NN_RET_CHECK_EQ(type.scale, 0.0f) << tag << " scale is not zero";
    return {};
}

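// Validates tensor dimensions and guards the total byte size against uint32_t overflow: the
// running size is kept in a uint64_t and re-checked after every multiplication. Illustrative
// numbers (this comment's own): a TENSOR_FLOAT32 of shape {65536, 65536} reaches
// 4 * 65536 * 65536 bytes, which exceeds std::numeric_limits<uint32_t>::max() and fails the
// check at the second dimension.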
Result<void> validateTensorDimensions(
        const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo,
        const char* tag, bool allowPartial) {
    if (!allowPartial) {
        NN_RET_CHECK(!type.dimensions.empty()) << tag << " invalid operand dimensions";
    }
    uint64_t size = isExtension(type.type) ? extensionOperandTypeInfo->byteSize
                                           : getNonExtensionSize(type.type);
    constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max();
    for (size_t i = 0; i < type.dimensions.size(); i++) {
        if (!allowPartial) {
            NN_RET_CHECK_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions";
        }
        if (type.dimensions[i] != 0) {
            size *= type.dimensions[i];
            NN_RET_CHECK_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize;
        }
    }
    return {};
}

Result<void> validateOperandTypeImpl(
        const Operand& type,
        const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
        bool allowPartial) {
    if (isExtension(type.type)) {
        NN_RET_CHECK(extensionOperandTypeInfo != nullptr);
        if (extensionOperandTypeInfo->isTensor) {
            NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
        } else {
            NN_TRY(validateScalarDimensions(type, tag));
        }
        return validateNoQuantParams(type, tag);
    }

    NN_RET_CHECK(extensionOperandTypeInfo == nullptr);
    NN_TRY(validateOperandType(type.type));

    if (isNonExtensionScalar(type.type)) {
        NN_TRY(validateScalarDimensions(type, tag));
        if (type.type != OperandType::OEM) {
            // Historically, we have allowed OEM types to use quantization parameters.
            NN_TRY(validateNoQuantParams(type, tag));
        }
    } else {
        NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
        if (type.type == OperandType::TENSOR_QUANT8_ASYMM) {
            NN_TRY(validateQuant8AsymmParams(type, tag));
        } else if (type.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
            NN_TRY(validateQuant8AsymmSignedParams(type, tag));
        } else if (type.type == OperandType::TENSOR_QUANT8_SYMM) {
            NN_TRY(validateQuant8SymmParams(type, tag));
        } else if (type.type == OperandType::TENSOR_QUANT16_ASYMM) {
            NN_TRY(validateQuant16AsymmParams(type, tag));
        } else if (type.type == OperandType::TENSOR_QUANT16_SYMM) {
            NN_TRY(validateQuantSymmParams(type, tag));
        } else if (type.type == OperandType::TENSOR_INT32 ||
                   type.type == OperandType::TENSOR_OEM_BYTE) {
            // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters.
            // Historically, we have allowed OEM types to use quantization parameters.
        } else {
            NN_TRY(validateNoQuantParams(type, tag));
        }
    }

    return {};
}

Result<void> validateOperandListImpl(const std::vector<uint32_t>& list, size_t operandCount,
                                     const char* tag) {
    for (size_t i = 0; i < list.size(); i++) {
        NN_RET_CHECK_LT(list[i], operandCount) << tag << " invalid operand index at " << i << " = "
                                               << list[i] << ", operandCount " << operandCount;
    }
    return {};
}

Result<void> validateSubgraphReference(const std::vector<Model::Subgraph>& subgraphs,
                                       const Operand& modelOperand) {
    NN_RET_CHECK_EQ(modelOperand.type, OperandType::SUBGRAPH)
            << "Unexpected operand type: " << modelOperand.type;
    NN_RET_CHECK_LT(modelOperand.location.offset, subgraphs.size()) << "Invalid subgraph reference";
    return {};
}

const Model::Subgraph& getSubgraph(const std::vector<Model::Subgraph>& subgraphs,
                                   const Operand& modelOperand) {
    return subgraphs.at(modelOperand.location.offset);
}

uint32_t getInputCount(const std::vector<Model::Subgraph>& subgraphs, const Operand& modelOperand) {
    return getSubgraph(subgraphs, modelOperand).inputIndexes.size();
}

uint32_t getOutputCount(const std::vector<Model::Subgraph>& subgraphs,
                        const Operand& modelOperand) {
    return getSubgraph(subgraphs, modelOperand).outputIndexes.size();
}

const Operand& getInputOperand(const std::vector<Model::Subgraph>& subgraphs,
                               const Operand& modelOperand, uint32_t index) {
    const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
    return subgraph.operands.at(subgraph.inputIndexes.at(index));
}

const Operand& getOutputOperand(const std::vector<Model::Subgraph>& subgraphs,
                                const Operand& modelOperand, uint32_t index) {
    const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
    return subgraph.operands.at(subgraph.outputIndexes.at(index));
}

// Checks if two operands have the same types, ranks (if specified), dimensions
// (if specified), scales, zeroPoints, and extraParams.
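// Illustrative example (the shapes are this comment's own): operands with dimensions {2, 0, 3}
// and {2, 5, 0} are compatible -- the ranks match and every position where both are nonzero
// agrees -- while {2, 3} and {2, 4} are not.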
Result<void> compatible(const Operand& a, const Operand& b) {
    NN_RET_CHECK_EQ(a.type, b.type) << a.type << " != " << b.type;
    if (!a.dimensions.empty() && !b.dimensions.empty()) {
        NN_RET_CHECK_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions";
        for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) {
            if (a.dimensions[i] != 0 && b.dimensions[i] != 0) {
                NN_RET_CHECK_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions";
            }
        }
    }
    NN_RET_CHECK_EQ(a.scale, b.scale);
    NN_RET_CHECK_EQ(a.zeroPoint, b.zeroPoint);
    NN_RET_CHECK_EQ(a.extraParams, b.extraParams) << a.extraParams << " != " << b.extraParams;
    return {};
}

Result<void> validateConditionOperand(const Operand& operand) {
    NN_RET_CHECK_EQ(operand.type, OperandType::TENSOR_BOOL8)
            << "Unexpected condition operand type: " << operand.type;
    NN_RET_CHECK_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
    NN_RET_CHECK_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
    return {};
}

Result<Version> validateIfOperation(const std::vector<uint32_t>& inputs,
                                    const std::vector<uint32_t>& outputs,
                                    const std::vector<Operand>& operands,
                                    const std::vector<Model::Subgraph>& subgraphs) {
    namespace op = operation_if;
    NN_RET_CHECK_GE(inputs.size(), 3u) << "IF must have at least 3 inputs";
    NN_RET_CHECK_GE(outputs.size(), 1u) << "IF must have at least 1 output";
    auto validateBranchOperand = [&](const Operand& branchModelOperand) -> Result<void> {
        auto result = validateSubgraphReference(subgraphs, branchModelOperand);
        if (!result.has_value()) {
            return error() << std::move(result).error()
                           << " -- Operand is not a valid subgraph reference";
        }
        const uint32_t branchModelInputCount = getInputCount(subgraphs, branchModelOperand);
        const uint32_t branchModelOutputCount = getOutputCount(subgraphs, branchModelOperand);
        NN_RET_CHECK_EQ(inputs.size(), op::kFirstInput + branchModelInputCount);
        NN_RET_CHECK_EQ(outputs.size(), branchModelOutputCount);
        for (uint32_t i = 0; i < branchModelInputCount; ++i) {
            const Operand& innerOperand = getInputOperand(subgraphs, branchModelOperand, i);
            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
            NN_TRY(compatible(innerOperand, outerOperand));
        }
        for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
            const Operand& innerOperand = getOutputOperand(subgraphs, branchModelOperand, i);
            const Operand& outerOperand = operands[outputs[i]];
            NN_TRY(compatible(innerOperand, outerOperand));
        }
        return {};
    };
    auto result = validateConditionOperand(operands[inputs[op::kCondBoolOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for IF condition operand";
    }
    result = validateBranchOperand(operands[inputs[op::kThenModelOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for IF then model";
    }
    result = validateBranchOperand(operands[inputs[op::kElseModelOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for IF else model";
    }
    return kVersionFeatureLevel4;
}

Result<Version> validateControlFlowOperandUnknownSize(const Operand& operand) {
    auto version = kVersionFeatureLevel4;
    if (!isExtension(operand.type) && getNonExtensionSize(operand).value() == 0) {
        // The 1.3 HAL (corresponding to kVersionFeatureLevel4) does not support control flow
        // operations with operands of unknown size. See http://b/132458982#comment63.
        version.runtimeOnlyFeatures = true;
    }
    return version;
}

Result<Version> validateWhileOperation(const std::vector<uint32_t>& inputs,
                                       const std::vector<uint32_t>& outputs,
                                       const std::vector<Operand>& operands,
                                       const std::vector<Model::Subgraph>& subgraphs) {
    // Let the loop have
    // - m >= 1 input-output operands,
    // - k >= 0 state-only operands, and
    // - n >= 0 input-only operands.
    // Then
    // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
    // - the condition model has (m + k + n) inputs and 1 output.
    // - the body model has (m + k + n) inputs and (m + k) outputs.
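    // Illustrative example (the counts are this comment's own): with m = 2, k = 1, n = 1 the
    // WHILE operation has 2 + 2 + 1 + 1 = 6 inputs (condition model, body model, two
    // input-output operands, one state-only operand, one input-only operand) and 2 outputs; the
    // condition model takes all 4 data operands and returns a single TENSOR_BOOL8, and the body
    // model takes the same 4 and produces 3 (the next values of the input-output and state-only
    // operands).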
    namespace op = operation_while;
    NN_RET_CHECK_GE(inputs.size(), 3u) << "WHILE must have at least 3 inputs";
    NN_RET_CHECK_GE(outputs.size(), 1u) << "WHILE must have at least 1 output";
    auto validateCondOperand = [&](const Operand& condModelOperand) -> Result<Version> {
        Version version = kVersionFeatureLevel4;
        auto result = validateSubgraphReference(subgraphs, condModelOperand);
        if (!result.has_value()) {
            return error() << std::move(result).error()
                           << " -- Operand is not a valid subgraph reference";
        }
        const uint32_t condModelInputCount = getInputCount(subgraphs, condModelOperand);
        const uint32_t condModelOutputCount = getOutputCount(subgraphs, condModelOperand);
        NN_RET_CHECK_EQ(inputs.size(), op::kFirstInput + condModelInputCount);
        NN_RET_CHECK_EQ(condModelOutputCount, 1u);
        for (uint32_t i = 0; i < condModelInputCount; ++i) {
            const Operand& innerOperand = getInputOperand(subgraphs, condModelOperand, i);
            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
            NN_TRY(compatible(innerOperand, outerOperand));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
        }
        NN_TRY(validateConditionOperand(getOutputOperand(subgraphs, condModelOperand, 0)));
        return version;
    };
    auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> Result<Version> {
        Version version = kVersionFeatureLevel4;
        auto result = validateSubgraphReference(subgraphs, bodyModelOperand);
        if (!result.has_value()) {
            return error() << std::move(result).error()
                           << " -- Operand is not a valid subgraph reference";
        }
        const uint32_t bodyModelInputCount = getInputCount(subgraphs, bodyModelOperand);
        const uint32_t bodyModelOutputCount = getOutputCount(subgraphs, bodyModelOperand);
        NN_RET_CHECK_EQ(inputs.size(), op::kFirstInput + bodyModelInputCount);
        NN_RET_CHECK_GE(bodyModelOutputCount, outputs.size());
        NN_RET_CHECK_GE(bodyModelInputCount, bodyModelOutputCount);
        const uint32_t inputOutputCount = outputs.size();
        const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
        const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
        for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
            const Operand& innerOperand = getInputOperand(subgraphs, bodyModelOperand, i);
            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
            NN_TRY(compatible(innerOperand, outerOperand));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
        }
        for (uint32_t i = 0; i < inputOutputCount; ++i) {
            const Operand& innerOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
            const Operand& outerOperand = operands[outputs[i]];
            NN_TRY(compatible(innerOperand, outerOperand));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
        }
        for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
            const Operand& inputOperand = getInputOperand(subgraphs, bodyModelOperand, i);
            const Operand& outputOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
            NN_TRY(compatible(inputOperand, outputOperand));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(outputOperand)));
        }
        return version;
    };
    auto result = validateCondOperand(operands[inputs[op::kCondModelOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for WHILE condition model";
    }
    auto version = result.value();
    result = validateBodyOperand(operands[inputs[op::kBodyModelOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for WHILE body model";
    }
    version = combineVersions(version, result.value());
    return version;
}

Result<Version> validateOperationButNotOperandsImpl(const Operation& operation,
                                                    const std::vector<Operand>& operands,
                                                    const std::vector<Model::Subgraph>& subgraphs) {
    const auto opType = operation.type;
    const auto& inputIndexes = operation.inputs;
    const auto& outputIndexes = operation.outputs;

    NN_TRY(validateOperandListImpl(inputIndexes, operands.size(),
                                   "ANeuralNetworksModel_addOperation inputs"));
    NN_TRY(validateOperandListImpl(outputIndexes, operands.size(),
                                   "ANeuralNetworksModel_addOperation outputs"));

    if (isExtension(opType)) {
        // There is no other validation we can do for an extension operation.
        return kVersionFeatureLevel3;
    }

    std::ostringstream oss;
    oss << operation.type;
    const std::string name = oss.str();
    OperationValidationContext context(name.c_str(), inputIndexes, outputIndexes, operands);

    // Validate some operations explicitly.
    switch (opType) {
        case OperationType::OEM_OPERATION:
            return kVersionFeatureLevel1;
        case OperationType::IF:
            return validateIfOperation(inputIndexes, outputIndexes, operands, subgraphs);
        case OperationType::WHILE:
            return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs);
        default:
            break;
    }

#define NN_HANDLE_SWITCH_CASE(operationName) \
    case OperationType::operationName:       \
        return NN_VALIDATION_FUNCTION_NAME(operationName)(&context);

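    // Illustrative expansion (the callee name is produced by NN_VALIDATION_FUNCTION_NAME and is
    // an implementation detail of the registration macros, so the exact identifier is not shown):
    //   NN_HANDLE_SWITCH_CASE(RELU) ->
    //       case OperationType::RELU:
    //           return <validation function registered for RELU>(&context);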
    // Validate the remaining operations through the operation-specific functions defined in
    // common/operations/.
    // TODO(b/213938830): operation validation dispatch is duplicated and does not handle extension
    // types.
    switch (opType) { NN_FOR_EACH_OPERATION(NN_HANDLE_SWITCH_CASE) }

#undef NN_HANDLE_SWITCH_CASE

    NN_RET_CHECK_FAIL() << "Invalid OperationType " << opType;
}

Result<Version> validateOperationIncludingOperandVersions(
        const Operation& operation, const std::vector<Operand>& operands,
        const std::vector<Version>& operandVersions,
        const std::vector<Model::Subgraph>& subgraphs) {
    auto version = NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs));
    for (uint32_t index : operation.inputs) {
        version = combineVersions(version, operandVersions[index]);
    }
    for (uint32_t index : operation.outputs) {
        version = combineVersions(version, operandVersions[index]);
    }
    return version;
}

}  // anonymous namespace

// Below this point are the functions declared in Validation.h. By convention, their bodies are
// kept minimal: most simply forward to one of the functions defined above in the anonymous
// namespace. Where a public function's name would clash with its helper's, the helper in the
// anonymous namespace carries an "Impl" suffix.

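// Illustrative example (the values are this comment's own): combining
// {level = 3, runtimeOnlyFeatures = true} with {level = 5, runtimeOnlyFeatures = false} yields
// {level = 5, runtimeOnlyFeatures = true} -- the maximum level, and runtime-only if either
// requirement is runtime-only.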
Version combineVersions(Version minVersionNeeded1, Version minVersionNeeded2) {
    return Version{
            .level = std::max<Version::Level>(minVersionNeeded1.level, minVersionNeeded2.level),
            .runtimeOnlyFeatures =
                    minVersionNeeded1.runtimeOnlyFeatures || minVersionNeeded2.runtimeOnlyFeatures,
    };
}

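// A version requirement is compliant with what a consumer supports when its feature level is not
// newer and it does not demand runtime-only features the consumer lacks. Illustrative example
// (the values are this comment's own): a requirement of {level = 4, runtimeOnlyFeatures = true}
// is compliant with a runtime supporting {level = 6, runtimeOnlyFeatures = true} but not with a
// driver at {level = 6, runtimeOnlyFeatures = false}.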
bool isCompliantVersion(Version minVersionNeeded, Version maxVersionSupported) {
    if (minVersionNeeded.runtimeOnlyFeatures && !maxVersionSupported.runtimeOnlyFeatures) {
        return false;
    }
    return minVersionNeeded.level <= maxVersionSupported.level;
}

Result<Version> validate(const DeviceStatus& deviceStatus) {
    return validateDeviceStatus(deviceStatus);
}

Result<Version> validate(const ExecutionPreference& executionPreference) {
    return validateExecutionPreference(executionPreference);
}

Result<Version> validate(const DeviceType& deviceType) {
    return validateDeviceType(deviceType);
}

Result<Version> validate(const MeasureTiming& measureTiming) {
    return validateMeasureTiming(measureTiming);
}

Result<Version> validate(const OperandType& operandType) {
    return validateOperandType(operandType);
}

Result<Version> validate(const Priority& priority) {
    return validatePriority(priority);
}

Result<Version> validate(const ErrorStatus& errorStatus) {
    return validateErrorStatus(errorStatus);
}

Result<Version> validate(const FusedActivationFunc& activation) {
    return validateFusedActivationFunc(activation);
}

Result<Version> validate(const OutputShape& outputShape) {
    return validateOutputShape(outputShape);
}

Result<Version> validate(const Timing& timing) {
    return validateTiming(timing);
}

Result<Version> validate(const Capabilities& capabilities) {
    return validateCapabilities(capabilities);
}

Result<Version> validate(const Extension& extension) {
    return validateExtension(extension);
}

Result<Version> validate(const SharedHandle& handle) {
    return validateSharedHandle(handle);
}

Result<Version> validate(const SharedMemory& memory) {
    return validateSharedMemory(memory);
}

Result<Version> validate(const Model& model) {
    return validateModel(model);
}

Result<Version> validate(const BufferDesc& bufferDesc) {
    return validateBufferDesc(bufferDesc);
}

Result<Version> validate(const BufferRole& bufferRole) {
    return validateBufferRole(bufferRole);
}

Result<Version> validate(const Request& request) {
    return validateRequest(request);
}

Result<Version> validate(const OptionalTimePoint& optionalTimePoint) {
    return validateOptionalTimePoint(optionalTimePoint);
}

Result<Version> validate(const OptionalDuration& optionalTimeoutDuration) {
    return validateOptionalTimeoutDuration(optionalTimeoutDuration);
}

Result<Version> validate(const CacheToken& cacheToken) {
    return validateCacheToken(cacheToken);
}

Result<Version> validate(const SyncFence& syncFence) {
    return validateSyncFence(syncFence);
}

Result<Version> validate(const TokenValuePair& tokenValuePair) {
    return validateTokenValuePair(tokenValuePair);
}

Result<Version> validate(const std::vector<OutputShape>& outputShapes) {
    return validateVector(outputShapes, validateOutputShape);
}

Result<Version> validate(const std::vector<Extension>& extensions) {
    return validateExtensions(extensions);
}

Result<Version> validate(const std::vector<SharedHandle>& handles) {
    return validateVector(handles, validateSharedHandle);
}

Result<Version> validate(const std::vector<BufferRole>& bufferRoles) {
    return validateVector(bufferRoles, validateBufferRole);
}

Result<Version> validate(const std::vector<SyncFence>& syncFences) {
    return validateVector(syncFences, validateSyncFence);
}

Result<Version> validate(const std::vector<TokenValuePair>& metaData) {
    std::set<int32_t> tokenSet;
    for (const auto& p : metaData) {
        if (!tokenSet.insert(p.token).second) {
            NN_RET_CHECK_FAIL() << "Token " << p.token << " added more than once";
        }
    }
    return validateVector(metaData, validateTokenValuePair);
}

Result<Version> validate(const std::vector<ExtensionNameAndPrefix>& extensionNamesAndPrefixes) {
    return validateExtensionNamesAndPrefixes(extensionNamesAndPrefixes);
}

Result<Version> validateRequestForModel(const Request& request, const Model& model,
                                        bool allowUnspecifiedOutput) {
    return validateRequestForModelImpl(request, model, allowUnspecifiedOutput);
}

Result<Version> validateMemoryDesc(
        const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
        const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
        const std::function<const Model*(const SharedPreparedModel&)>& getModel,
        std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
    return validateMemoryDescImpl(desc, preparedModels, inputRoles, outputRoles, getModel,
                                  preparedModelRoles, combinedOperand);
}

Result<void> validateOperandSymmPerChannelQuantParams(
        const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
        const char* tag) {
    return validateOperandSymmPerChannelQuantParamsImpl(operand, channelQuant, tag);
}

Result<void> validateOperandType(const Operand& type,
                                 const Extension::OperandTypeInformation* extensionOperandTypeInfo,
                                 const char* tag, bool allowPartial) {
    return validateOperandTypeImpl(type, extensionOperandTypeInfo, tag, allowPartial);
}

Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount,
                                 const char* tag) {
    return validateOperandListImpl(list, operandCount, tag);
}

Result<void> validateOperationButNotOperands(const Operation& operation,
                                             const std::vector<Operand>& operands,
                                             const std::vector<Model::Subgraph>& subgraphs) {
    NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs));
    return {};
}

// Caches the minimum required version of each referenced subgraph so that every subgraph is
// validated at most once.
struct SubgraphVersionCache {
    std::vector<std::optional<Version>> cache;
};

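// Returns the cache behind a unique_ptr with a custom deleter so that callers only need the
// forward declaration of SubgraphVersionCache from Validation.h. Hypothetical usage sketch (the
// model and poolSizes variables are this comment's own, not part of this file):
//   auto cache = createSubgraphVersionCache(model.referenced.size());
//   for (const auto& operation : model.main.operations) {
//       NN_TRY(validateOperationAndAnythingItDependsOn(operation, model.main.operands,
//                                                      model.operandValues.size(), poolSizes,
//                                                      model.referenced, cache.get()));
//   }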
std::unique_ptr<SubgraphVersionCache, void (*)(SubgraphVersionCache*)> createSubgraphVersionCache(
        size_t referencedSubgraphCount) {
    auto subgraphVersionCache = std::make_unique<SubgraphVersionCache>();
    subgraphVersionCache->cache.resize(referencedSubgraphCount);
    constexpr auto deleter = [](SubgraphVersionCache* ptr) { delete ptr; };
    return {subgraphVersionCache.release(), deleter};
}

Result<Version> validateOperationAndAnythingItDependsOn(
        const Operation& operation, const std::vector<Operand>& operands, size_t operandValuesSize,
        const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs,
        SubgraphVersionCache* subgraphVersionCache) {
    CHECK(subgraphVersionCache != nullptr);
    std::vector<Version> operandVersions(operands.size(), kVersionFeatureLevel1);
    for (uint32_t index : operation.inputs) {
        NN_RET_CHECK_LT(index, operands.size());
        const Operand& operand = operands[index];
        operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn(
                operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache));
    }
    for (uint32_t index : operation.outputs) {
        NN_RET_CHECK_LT(index, operands.size());
        const Operand& operand = operands[index];
        operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn(
                operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache));
    }
    return validateOperationIncludingOperandVersions(operation, operands, operandVersions,
                                                     subgraphs);
}

Result<Version> validateOperandAndAnythingItDependsOn(const Operand& operand,
                                                      size_t operandValuesSize,
                                                      const std::vector<size_t>& poolSizes,
                                                      const std::vector<Model::Subgraph>& subgraphs,
                                                      SubgraphVersionCache* subgraphVersionCache) {
    CHECK(subgraphVersionCache != nullptr);
    return validateOperand(operand, operandValuesSize, poolSizes, subgraphs,
                           &subgraphVersionCache->cache);
}

}  // namespace android::nn