1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Validation.h"
18 
19 #include <android-base/logging.h>
20 #include <android-base/mapped_file.h>
21 
22 #include <algorithm>
23 #include <cctype>
24 #include <functional>
25 #include <limits>
26 #include <memory>
27 #include <numeric>
28 #include <set>
29 #include <sstream>
30 #include <string>
31 #include <string_view>
32 #include <tuple>
33 #include <utility>
34 #include <variant>
35 #include <vector>
36 
37 #include "ControlFlow.h"
38 #include "OperandTypes.h"
39 #include "OperationTypes.h"
40 #include "OperationsUtils.h"
41 #include "OperationsValidationUtils.h"
42 #include "Result.h"
43 #include "SharedMemory.h"
44 #include "TypeUtils.h"
45 #include "Types.h"
46 
47 namespace android::nn {
48 
// Forward-declare one validation function per operation: NN_FOR_EACH_OPERATION applies the helper
// macro to every operation it enumerates, and each application expands to that operation's
// NN_VALIDATION_FUNCTION_SIGNATURE followed by ';' (i.e. a declaration). The helper macro is
// #undef'd immediately afterwards so it does not leak into the rest of the file.
#define NN_FORWARD_DECLARE_VALIDATION_FUNCTION(opType) NN_VALIDATION_FUNCTION_SIGNATURE(opType);

NN_FOR_EACH_OPERATION(NN_FORWARD_DECLARE_VALIDATION_FUNCTION)

#undef NN_FORWARD_DECLARE_VALIDATION_FUNCTION
54 
55 static Result<Version> notImplementedThroughRegistration(
56         const IOperationValidationContext* context) {
57     LOG(FATAL) << "Operation " << context->getOperationName()
58                << " not supported through registration";
59     return NN_ERROR();
60 }
61 
// IF, WHILE and OEM_OPERATION are not validated through the per-operation registration path. Their
// registered validator is notImplementedThroughRegistration, which LOG(FATAL)s if it is ever called.
NN_DEFINE_VALIDATION_FUNCTION(IF, notImplementedThroughRegistration);
NN_DEFINE_VALIDATION_FUNCTION(WHILE, notImplementedThroughRegistration);
NN_DEFINE_VALIDATION_FUNCTION(OEM_OPERATION, notImplementedThroughRegistration);
65 
66 namespace {
67 
68 constexpr auto kNullptrVariant = std::variant<const void*, void*>{};
69 constexpr auto kInvalidMemoryDomainToken = Request::MemoryDomainToken{};
70 
71 template <typename Type, typename ValidationFunction>
validateVector(const std::vector<Type> & objects,const ValidationFunction & validationFunction)72 Result<Version> validateVector(const std::vector<Type>& objects,
73                                const ValidationFunction& validationFunction) {
74     auto version = kVersionFeatureLevel1;
75     for (const auto& object : objects) {
76         version = combineVersions(version, NN_TRY(validationFunction(object)));
77     }
78     return version;
79 }
80 
isValidExtensionName(const std::string & name)81 bool isValidExtensionName(const std::string& name) {
82     constexpr auto validSymbol = [](char symbol) {
83         return std::islower(symbol) || std::isdigit(symbol) || symbol == '.' || symbol == '_';
84     };
85     const bool hasOnlyValidSymbols = std::all_of(name.begin(), name.end(), validSymbol);
86     const bool hasAtLeastOnePeriod = std::find(name.begin(), name.end(), '.') != name.end();
87     return hasOnlyValidSymbols && hasAtLeastOnePeriod;
88 }
89 
validateDeviceStatus(const DeviceStatus & deviceStatus)90 Result<Version> validateDeviceStatus(const DeviceStatus& deviceStatus) {
91     switch (deviceStatus) {
92         case DeviceStatus::AVAILABLE:
93         case DeviceStatus::BUSY:
94         case DeviceStatus::OFFLINE:
95         case DeviceStatus::UNKNOWN:
96             return kVersionFeatureLevel1;
97     }
98     NN_RET_CHECK_FAIL() << "Invalid DeviceStatus " << deviceStatus;
99 }
100 
validateExecutionPreference(const ExecutionPreference & executionPreference)101 Result<Version> validateExecutionPreference(const ExecutionPreference& executionPreference) {
102     switch (executionPreference) {
103         case ExecutionPreference::FAST_SINGLE_ANSWER:
104             // ExecutionPreference::FAST_SINGLE_ANSWER is the default value, so it is implicitly
105             // valid for all versions.
106             return kVersionFeatureLevel1;
107         case ExecutionPreference::LOW_POWER:
108         case ExecutionPreference::SUSTAINED_SPEED:
109             return kVersionFeatureLevel2;
110     }
111     NN_RET_CHECK_FAIL() << "Invalid ExecutionPreference " << executionPreference;
112 }
113 
validateDeviceType(const DeviceType & deviceType)114 Result<Version> validateDeviceType(const DeviceType& deviceType) {
115     switch (deviceType) {
116         case DeviceType::UNKNOWN:
117             // DeviceType was introduced in the 1.2 NN HAL. DeviceType::UNKNOWN is returned when
118             // querying versions that are prior to the 1.2 NN HAL. DeviceType::UNKNOWN is not a
119             // valid code to return for a driver that implement at least a 1.2 NN HAL. If we need a
120             // range of versions, make ANDROID_Q (NN HAL 1.2) the exclusive upper bound for
121             // DeviceType::UNKNOWN.
122             return kVersionFeatureLevel1;
123         case DeviceType::OTHER:
124         case DeviceType::CPU:
125         case DeviceType::GPU:
126         case DeviceType::ACCELERATOR:
127             return kVersionFeatureLevel3;
128     }
129     NN_RET_CHECK_FAIL() << "Invalid DeviceType " << deviceType;
130 }
131 
validateMeasureTiming(const MeasureTiming & measureTiming)132 Result<Version> validateMeasureTiming(const MeasureTiming& measureTiming) {
133     switch (measureTiming) {
134         case MeasureTiming::NO:
135             // MeasureTiming::NO is the default value, so it is implicitly valid for all versions.
136             return kVersionFeatureLevel1;
137         case MeasureTiming::YES:
138             return kVersionFeatureLevel3;
139     }
140     NN_RET_CHECK_FAIL() << "Invalid MeasureTiming " << measureTiming;
141 }
142 
validateOperandType(const OperandType & operandType)143 Result<Version> validateOperandType(const OperandType& operandType) {
144     switch (operandType) {
145         case OperandType::FLOAT32:
146         case OperandType::INT32:
147         case OperandType::UINT32:
148         case OperandType::TENSOR_FLOAT32:
149         case OperandType::TENSOR_INT32:
150         case OperandType::TENSOR_QUANT8_ASYMM:
151         case OperandType::OEM:
152         case OperandType::TENSOR_OEM_BYTE:
153             return kVersionFeatureLevel1;
154         case OperandType::BOOL:
155         case OperandType::TENSOR_QUANT16_SYMM:
156         case OperandType::TENSOR_FLOAT16:
157         case OperandType::TENSOR_BOOL8:
158         case OperandType::FLOAT16:
159         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
160         case OperandType::TENSOR_QUANT16_ASYMM:
161         case OperandType::TENSOR_QUANT8_SYMM:
162             return kVersionFeatureLevel3;
163         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
164         case OperandType::SUBGRAPH:
165             return kVersionFeatureLevel4;
166     }
167     if (isExtension(operandType)) {
168         return kVersionFeatureLevel3;
169     }
170     NN_RET_CHECK_FAIL() << "Invalid OperandType " << operandType;
171 }
172 
validateOperandLifeTime(const Operand & operand)173 Result<Version> validateOperandLifeTime(const Operand& operand) {
174     // Make sure SUBGRAPH operand type and lifetime always go together.
175     NN_RET_CHECK_EQ((operand.type == OperandType::SUBGRAPH),
176                     (operand.lifetime == Operand::LifeTime::SUBGRAPH))
177             << "Operand of type " << operand.type << " cannot have lifetime " << operand.lifetime;
178 
179     switch (operand.lifetime) {
180         case Operand::LifeTime::TEMPORARY_VARIABLE:
181         case Operand::LifeTime::SUBGRAPH_INPUT:
182         case Operand::LifeTime::SUBGRAPH_OUTPUT:
183         case Operand::LifeTime::CONSTANT_COPY:
184         case Operand::LifeTime::CONSTANT_REFERENCE:
185         case Operand::LifeTime::NO_VALUE:
186         case Operand::LifeTime::POINTER:
187             return kVersionFeatureLevel1;
188         case Operand::LifeTime::SUBGRAPH:
189             return kVersionFeatureLevel4;
190     }
191     NN_RET_CHECK_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
192 }
193 
validatePriority(const Priority & priority)194 Result<Version> validatePriority(const Priority& priority) {
195     switch (priority) {
196         case Priority::MEDIUM:
197             // Priority::MEDIUM is the default value, so it is implicitly valid for all versions.
198             return kVersionFeatureLevel1;
199         case Priority::LOW:
200         case Priority::HIGH:
201             return kVersionFeatureLevel4;
202     }
203     NN_RET_CHECK_FAIL() << "Invalid Priority " << priority;
204 }
205 
validateErrorStatus(const ErrorStatus & errorStatus)206 Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) {
207     // Note that MISSED_DEADLINE_*, RESOURCE_EXHAUSTED_*, and DEAD_OBJECT were introduced ih
208     // ANDROID_R, but these can be cast to ANDROID_OC_MR1 as GENERAL_FAILURE.
209     switch (errorStatus) {
210         case ErrorStatus::NONE:
211         case ErrorStatus::DEVICE_UNAVAILABLE:
212         case ErrorStatus::GENERAL_FAILURE:
213         case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
214         case ErrorStatus::INVALID_ARGUMENT:
215         case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
216         case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
217         case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
218         case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
219         case ErrorStatus::DEAD_OBJECT:
220             return kVersionFeatureLevel1;
221     }
222     NN_RET_CHECK_FAIL() << "Invalid ErrorStatus " << errorStatus;
223 }
224 
validateFusedActivationFunc(const FusedActivationFunc & activation)225 Result<Version> validateFusedActivationFunc(const FusedActivationFunc& activation) {
226     switch (activation) {
227         case FusedActivationFunc::NONE:
228         case FusedActivationFunc::RELU:
229         case FusedActivationFunc::RELU1:
230         case FusedActivationFunc::RELU6:
231             return kVersionFeatureLevel1;
232     }
233     NN_RET_CHECK_FAIL() << "Invalid FusedActivationFunc " << activation;
234 }
235 
validateOutputShape(const OutputShape &)236 Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) {
237     return kVersionFeatureLevel3;
238 }
239 
validateTiming(const Timing & timing)240 Result<Version> validateTiming(const Timing& timing) {
241     constexpr auto kNoTiming = Timing{};
242     if (timing == kNoTiming) {
243         // kNoTiming is the default value, so it is implicitly valid for all versions.
244         return kVersionFeatureLevel1;
245     }
246     if (timing.timeInDriver.has_value() && timing.timeOnDevice.has_value()) {
247         // `lazyMessage` is a lazy function to produce the timing validation error message.
248         // Currently, the code is not able to inline the message in NN_RET_CHECK due to a
249         // argument-dependent lookup issue with nn::detail::ErrorBuilder interacting with std types
250         // such as std::chrono::duration, so this function uses an indirection through
251         // std::ostringstream.
252         const auto lazyMessage = [&timing]() -> std::string {
253             std::ostringstream oss;
254             oss << "Timing::timeOnDevice (" << timing.timeOnDevice.value()
255                 << ") must not exceed Timing::timeInDriver (" << timing.timeInDriver.value() << ")";
256             return oss.str();
257         };
258         NN_RET_CHECK(timing.timeOnDevice.value() <= timing.timeInDriver.value()) << lazyMessage();
259     }
260     return kVersionFeatureLevel3;
261 }
262 
validateCapabilitiesPerformanceInfo(const Capabilities::PerformanceInfo & performanceInfo)263 Result<Version> validateCapabilitiesPerformanceInfo(
264         const Capabilities::PerformanceInfo& performanceInfo) {
265     NN_RET_CHECK_GT(performanceInfo.execTime, 0.0f);
266     NN_RET_CHECK_GT(performanceInfo.powerUsage, 0.0f);
267     return kVersionFeatureLevel1;
268 }
269 
validateCapabilitiesOperandPerformance(const Capabilities::OperandPerformance & operandPerformance)270 Result<Version> validateCapabilitiesOperandPerformance(
271         const Capabilities::OperandPerformance& operandPerformance) {
272     auto version = NN_TRY(validateOperandType(operandPerformance.type));
273     return combineVersions(version,
274                            NN_TRY(validateCapabilitiesPerformanceInfo(operandPerformance.info)));
275 }
276 
validateCapabilitiesOperandPerformanceTable(const Capabilities::OperandPerformanceTable & operandPerformances)277 Result<Version> validateCapabilitiesOperandPerformanceTable(
278         const Capabilities::OperandPerformanceTable& operandPerformances) {
279     // OperandPerformanceTable's order was validated when it was created, and it is castable to any
280     // version. If an OperandType does not exist in the lower version being converted to, that
281     // OperandPerformance will be dropped.
282     NN_TRY(validateVector(operandPerformances.asVector(), validateCapabilitiesOperandPerformance));
283     return kVersionFeatureLevel1;
284 }
285 
validateCapabilities(const Capabilities & capabilities)286 Result<Version> validateCapabilities(const Capabilities& capabilities) {
287     auto version =
288             NN_TRY(validateCapabilitiesOperandPerformanceTable(capabilities.operandPerformance));
289 
290     version = combineVersions(version,
291                               NN_TRY(validateCapabilitiesPerformanceInfo(
292                                       capabilities.relaxedFloat32toFloat16PerformanceScalar)));
293     version = combineVersions(version,
294                               NN_TRY(validateCapabilitiesPerformanceInfo(
295                                       capabilities.relaxedFloat32toFloat16PerformanceTensor)));
296     version = combineVersions(
297             version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.ifPerformance)));
298     version = combineVersions(
299             version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.whilePerformance)));
300 
301     return version;
302 }
303 
validateExtensionOperandTypeInformation(const Extension::OperandTypeInformation & operandTypeInformation)304 Result<Version> validateExtensionOperandTypeInformation(
305         const Extension::OperandTypeInformation& operandTypeInformation) {
306     NN_RET_CHECK_GT(operandTypeInformation.byteSize, 0u);
307     return kVersionFeatureLevel3;
308 }
309 
validateExtension(const Extension & extension)310 Result<Version> validateExtension(const Extension& extension) {
311     NN_RET_CHECK(isValidExtensionName(extension.name));
312 
313     // Verify all OperandTypeInformations have unique types.
314     std::vector<uint16_t> types;
315     types.reserve(extension.operandTypes.size());
316     std::transform(extension.operandTypes.begin(), extension.operandTypes.end(),
317                    std::back_inserter(types),
318                    [](const Extension::OperandTypeInformation& operandTypeInformation) {
319                        return operandTypeInformation.type;
320                    });
321     std::sort(types.begin(), types.end());
322     const auto iter = std::adjacent_find(types.begin(), types.end());
323     NN_RET_CHECK(iter == types.end()) << "Extension has duplicate type " << *iter;
324 
325     return combineVersions(kVersionFeatureLevel3,
326                            NN_TRY(validateVector(extension.operandTypes,
327                                                  validateExtensionOperandTypeInformation)));
328 }
329 
validateExtensions(const std::vector<Extension> & extensions)330 Result<Version> validateExtensions(const std::vector<Extension>& extensions) {
331     const auto version = NN_TRY(validateVector(extensions, validateExtension));
332 
333     // Verify all extensions have unique names.
334     std::vector<std::reference_wrapper<const std::string>> names;
335     names.reserve(extensions.size());
336     std::transform(extensions.begin(), extensions.end(), std::back_inserter(names),
337                    [](const Extension& extension) { return std::cref(extension.name); });
338     std::sort(names.begin(), names.end(), std::less<std::string>{});
339     const auto nameIter =
340             std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
341     NN_RET_CHECK(nameIter == names.end())
342             << "Two or more extensions have the duplicate name " << nameIter->get();
343 
344     return version;
345 }
346 
// Forward declaration of the subgraph validation function. validateOperandDataLocation (below)
// calls it to recursively validate the subgraph referenced by a SUBGRAPH operand.
Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
                                      std::optional<size_t> referencedIndex,
                                      size_t operandValuesSize,
                                      const std::vector<size_t>& poolSizes,
                                      const std::vector<Model::Subgraph>& referenced,
                                      std::vector<std::optional<Version>>* subgraphVersionCache);
354 
validateOperandDataLocation(const Operand & operand,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,std::vector<std::optional<Version>> * subgraphVersionCache)355 Result<Version> validateOperandDataLocation(
356         const Operand& operand, size_t operandValuesSize, const std::vector<size_t>& poolSizes,
357         const std::vector<Model::Subgraph>& subgraphs,
358         std::vector<std::optional<Version>>* subgraphVersionCache) {
359     const DataLocation& location = operand.location;
360     NN_RET_CHECK_EQ(location.padding, 0u)
361             << "DataLocation with a non-zero padding used in Model: " << location.padding;
362     switch (operand.lifetime) {
363         case Operand::LifeTime::CONSTANT_COPY:
364             NN_RET_CHECK(location.pointer == kNullptrVariant)
365                     << "CONSTANT_COPY with a non-null pointer";
366             NN_RET_CHECK_EQ(location.poolIndex, 0u)
367                     << "CONSTANT_COPY with a non-zero poolIndex " << location.poolIndex;
368             // Do the addition using uint64_t to avoid potential wrap-around problems.
369             NN_RET_CHECK_LE(static_cast<uint64_t>(location.offset) + location.length,
370                             operandValuesSize)
371                     << "OperandValue location out of range.  Starts at " << location.offset
372                     << ", length " << location.length << ", max " << operandValuesSize;
373             return kVersionFeatureLevel1;
374         case Operand::LifeTime::CONSTANT_REFERENCE:
375             NN_RET_CHECK_LT(location.poolIndex, poolSizes.size());
376             // Do the addition using uint64_t to avoid potential wrap-around problems.
377             NN_RET_CHECK_LE(static_cast<uint64_t>(location.offset) + location.length,
378                             poolSizes[location.poolIndex])
379                     << "OperandValue location out of range.  Starts at " << location.offset
380                     << ", length " << location.length << ", max " << poolSizes[location.poolIndex];
381             return kVersionFeatureLevel1;
382         case Operand::LifeTime::TEMPORARY_VARIABLE:
383         case Operand::LifeTime::SUBGRAPH_INPUT:
384         case Operand::LifeTime::SUBGRAPH_OUTPUT:
385         case Operand::LifeTime::NO_VALUE:
386             NN_RET_CHECK(location.pointer == kNullptrVariant)
387                     << "Unexpected pointer value for operand of lifetime " << operand.lifetime;
388             NN_RET_CHECK_EQ(location.poolIndex, 0u)
389                     << "Unexpected poolIndex " << location.poolIndex << " for operand of lifetime "
390                     << operand.lifetime;
391             NN_RET_CHECK_EQ(location.offset, 0u) << "Unexpected offset " << location.offset
392                                                  << " for operand of lifetime " << operand.lifetime;
393             NN_RET_CHECK_EQ(location.length, 0u) << "Unexpected length " << location.length
394                                                  << " for operand of lifetime " << operand.lifetime;
395             return kVersionFeatureLevel1;
396         case Operand::LifeTime::SUBGRAPH: {
397             NN_RET_CHECK(location.pointer == kNullptrVariant) << "SUBGRAPH with a non-null pointer";
398             NN_RET_CHECK_EQ(location.poolIndex, 0u)
399                     << "SUBGRAPH with a non-zero poolIndex " << location.poolIndex;
400             NN_RET_CHECK_LT(location.offset, subgraphs.size())
401                     << "Subgraph index out of range: " << location.offset
402                     << " >= " << subgraphs.size();
403             NN_RET_CHECK_EQ(location.length, 0u)
404                     << "SUBGRAPH with a non-zero length " << location.length;
405             const auto version = NN_TRY(validateModelSubgraph(
406                     subgraphs[location.offset], location.offset, operandValuesSize, poolSizes,
407                     subgraphs, subgraphVersionCache));
408             return combineVersions(version, kVersionFeatureLevel4);
409         }
410         case Operand::LifeTime::POINTER: {
411             const bool nonNull =
412                     std::visit([](auto* ptr) { return ptr != nullptr; }, location.pointer);
413             NN_RET_CHECK(nonNull) << "POINTER with a null pointer";
414             NN_RET_CHECK_EQ(location.poolIndex, 0u)
415                     << "POINTER with a non-zero poolIndex " << location.poolIndex;
416             NN_RET_CHECK_EQ(location.offset, 0u)
417                     << "POINTER with a non-zero offset " << location.offset;
418             return kVersionFeatureLevel1;
419         }
420     }
421     NN_RET_CHECK_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
422 }
423 
validateOperandDimensions(const Operand & operand)424 Result<Version> validateOperandDimensions(const Operand& operand) {
425     switch (operand.type) {
426         case OperandType::FLOAT32:
427         case OperandType::INT32:
428         case OperandType::UINT32:
429         case OperandType::BOOL:
430         case OperandType::FLOAT16:
431         case OperandType::SUBGRAPH:
432         case OperandType::OEM:
433             NN_RET_CHECK(operand.dimensions.empty())
434                     << "Scalar data has dimensions of rank " << operand.dimensions.size();
435             return kVersionFeatureLevel1;
436         case OperandType::TENSOR_FLOAT32:
437         case OperandType::TENSOR_INT32:
438         case OperandType::TENSOR_QUANT8_ASYMM:
439         case OperandType::TENSOR_QUANT16_SYMM:
440         case OperandType::TENSOR_FLOAT16:
441         case OperandType::TENSOR_BOOL8:
442         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
443         case OperandType::TENSOR_QUANT16_ASYMM:
444         case OperandType::TENSOR_QUANT8_SYMM:
445         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
446         case OperandType::TENSOR_OEM_BYTE: {
447             if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
448                 operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
449                 operand.lifetime == Operand::LifeTime::POINTER) {
450                 NN_RET_CHECK(!operand.dimensions.empty())
451                         << "Tensor has lifetime of " << operand.lifetime
452                         << " but dimensions of rank 0";
453                 const auto size = getNonExtensionSize(operand);
454                 NN_RET_CHECK(size.has_value()) << "Tensor dimensions overflow";
455                 NN_RET_CHECK_NE(size.value(), 0u) << "Tensor has at least one unknown dimension";
456             }
457             // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before
458             // Android Q?
459             if (operand.dimensions.empty()) {
460                 // Unspecified rank was added in Android Q.
461                 return kVersionFeatureLevel3;
462             }
463             return kVersionFeatureLevel1;
464         }
465     }
466     if (isExtension(operand.type)) {
467         // Extension types were added in Android Q.
468         return kVersionFeatureLevel3;
469     }
470     NN_RET_CHECK_FAIL() << "Invalid OperandType " << operand.type;
471 }
472 
validateOperandScale(const Operand & operand)473 Result<Version> validateOperandScale(const Operand& operand) {
474     switch (operand.type) {
475         case OperandType::FLOAT32:
476         case OperandType::INT32:
477         case OperandType::UINT32:
478         case OperandType::TENSOR_FLOAT32:
479         case OperandType::BOOL:
480         case OperandType::TENSOR_FLOAT16:
481         case OperandType::TENSOR_BOOL8:
482         case OperandType::FLOAT16:
483         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
484         case OperandType::SUBGRAPH:
485             NN_RET_CHECK_EQ(operand.scale, 0.0f)
486                     << "Operand of type " << operand.type << " with a non-zero scale ("
487                     << operand.scale << ")";
488             return kVersionFeatureLevel1;
489         case OperandType::TENSOR_INT32:
490             // TENSOR_INT32 may be used with or without scale, depending on the operation.
491             // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
492             NN_RET_CHECK_GE(operand.scale, 0.0f)
493                     << "Operand of type " << operand.type << " with a negative scale";
494             return kVersionFeatureLevel1;
495         case OperandType::TENSOR_QUANT8_ASYMM:
496         case OperandType::TENSOR_QUANT16_SYMM:
497         case OperandType::TENSOR_QUANT16_ASYMM:
498         case OperandType::TENSOR_QUANT8_SYMM:
499         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
500             NN_RET_CHECK_GT(operand.scale, 0.0f)
501                     << "Operand of type " << operand.type << " with a non-positive scale";
502             return kVersionFeatureLevel1;
503         case OperandType::OEM:
504         case OperandType::TENSOR_OEM_BYTE:
505             // No validation for OEM types.
506             return kVersionFeatureLevel1;
507     }
508     if (isExtension(operand.type)) {
509         NN_RET_CHECK_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type
510                                              << " with a non-zero scale (" << operand.scale << ")";
511         return kVersionFeatureLevel3;
512     }
513     NN_RET_CHECK_FAIL() << "Invalid OperandType " << operand.type;
514 }
515 
validateOperandZeroPoint(const Operand & operand)516 Result<Version> validateOperandZeroPoint(const Operand& operand) {
517     switch (operand.type) {
518         case OperandType::FLOAT32:
519         case OperandType::INT32:
520         case OperandType::UINT32:
521         case OperandType::TENSOR_FLOAT32:
522         case OperandType::TENSOR_INT32:
523         case OperandType::BOOL:
524         case OperandType::TENSOR_FLOAT16:
525         case OperandType::TENSOR_BOOL8:
526         case OperandType::FLOAT16:
527         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
528         case OperandType::TENSOR_QUANT8_SYMM:
529         case OperandType::SUBGRAPH:
530             NN_RET_CHECK_EQ(operand.zeroPoint, 0)
531                     << "Operand of type " << operand.type << " with a non-zero zeroPoint "
532                     << operand.zeroPoint;
533             return kVersionFeatureLevel1;
534         case OperandType::TENSOR_QUANT8_ASYMM:
535             NN_RET_CHECK(operand.zeroPoint >= 0 && operand.zeroPoint <= 255)
536                     << "Operand of type " << operand.type << " with an invalid zeroPoint "
537                     << operand.zeroPoint << ", must be in range [0, 255]";
538             return kVersionFeatureLevel1;
539         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
540             NN_RET_CHECK(operand.zeroPoint >= -128 && operand.zeroPoint <= 127)
541                     << "Operand of type " << operand.type << " with an invalid zeroPoint "
542                     << operand.zeroPoint << ", must be in range [-128, 127]";
543             return kVersionFeatureLevel1;
544         case OperandType::TENSOR_QUANT16_ASYMM:
545             NN_RET_CHECK(operand.zeroPoint >= 0 && operand.zeroPoint <= 65535)
546                     << "Operand of type " << operand.type << " with an invalid zeroPoint "
547                     << operand.zeroPoint << ", must be in range [0, 65535]";
548             return kVersionFeatureLevel1;
549         case OperandType::TENSOR_QUANT16_SYMM:
550             NN_RET_CHECK_EQ(operand.zeroPoint, 0)
551                     << "Operand of type " << operand.type << " with a non-zero zeroPoint "
552                     << operand.zeroPoint;
553             return kVersionFeatureLevel1;
554         case OperandType::OEM:
555         case OperandType::TENSOR_OEM_BYTE:
556             // No validation for OEM types.
557             return kVersionFeatureLevel1;
558     }
559     if (isExtension(operand.type)) {
560         NN_RET_CHECK_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type
561                                               << " with a non-zero zeroPoint " << operand.zeroPoint;
562         return kVersionFeatureLevel3;
563     }
564     NN_RET_CHECK_FAIL() << "Invalid OperandType " << operand.type;
565 }
566 
validateOperandExtraParams(const Operand & operand)567 Result<Version> validateOperandExtraParams(const Operand& operand) {
568     switch (operand.type) {
569         case OperandType::FLOAT32:
570         case OperandType::INT32:
571         case OperandType::UINT32:
572         case OperandType::TENSOR_FLOAT32:
573         case OperandType::TENSOR_INT32:
574         case OperandType::TENSOR_QUANT8_ASYMM:
575         case OperandType::BOOL:
576         case OperandType::TENSOR_QUANT16_SYMM:
577         case OperandType::TENSOR_FLOAT16:
578         case OperandType::TENSOR_BOOL8:
579         case OperandType::FLOAT16:
580         case OperandType::TENSOR_QUANT16_ASYMM:
581         case OperandType::TENSOR_QUANT8_SYMM:
582         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
583         case OperandType::SUBGRAPH:
584             NN_RET_CHECK(std::holds_alternative<Operand::NoParams>(operand.extraParams))
585                     << "Operand of type " << operand.type
586                     << " has extraParams when there must be none";
587             return kVersionFeatureLevel1;
588         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
589             NN_RET_CHECK(
590                     std::holds_alternative<Operand::SymmPerChannelQuantParams>(operand.extraParams))
591                     << "Operand of type " << operand.type
592                     << " without a Channel Quantization params";
593             const auto& channelQuant =
594                     std::get<Operand::SymmPerChannelQuantParams>(operand.extraParams);
595 
596             const size_t count = operand.dimensions.size();
597             NN_RET_CHECK_LT(channelQuant.channelDim, count)
598                     << "Operand of type " << operand.type
599                     << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
600                     << ", must be valid dimension index in range [0, " << count << ")";
601             const uint32_t expected = operand.dimensions[channelQuant.channelDim];
602             NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
603                     << "Operand of type " << operand.type << " with a wrong-sized scales, expected "
604                     << expected << " was " << channelQuant.scales.size();
605             NN_RET_CHECK_NE(expected, 0u)
606                     << "Operand of type " << operand.type << " channel dimension "
607                     << channelQuant.channelDim << " is underspecified (can't be 0)";
608             for (uint32_t i = 0; i < expected; ++i) {
609                 NN_RET_CHECK_GT(channelQuant.scales[i], 0.0f)
610                         << "Operand of type " << operand.type
611                         << " with a non-positive value in scales[" << i
612                         << "]=" << channelQuant.scales[i];
613             }
614             return kVersionFeatureLevel3;
615         }
616         case OperandType::OEM:
617         case OperandType::TENSOR_OEM_BYTE:
618             // No validation for OEM types.
619             return kVersionFeatureLevel1;
620     }
621     if (isExtension(operand.type)) {
622         NN_RET_CHECK(std::holds_alternative<Operand::NoParams>(operand.extraParams) ||
623                      std::holds_alternative<Operand::ExtensionParams>(operand.extraParams))
624                 << "Extension operand of type " << operand.type
625                 << " must not have SymmPerChannelQuant extraParams";
626         return kVersionFeatureLevel1;
627     }
628     NN_RET_CHECK_FAIL() << "Invalid OperandType " << operand.type;
629 }
630 
validateOperand(const Operand & operand,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,std::vector<std::optional<Version>> * subgraphVersionCache)631 Result<Version> validateOperand(const Operand& operand, size_t operandValuesSize,
632                                 const std::vector<size_t>& poolSizes,
633                                 const std::vector<Model::Subgraph>& subgraphs,
634                                 std::vector<std::optional<Version>>* subgraphVersionCache) {
635     auto version = NN_TRY(validateOperandType(operand.type));
636     version = combineVersions(version, NN_TRY(validateOperandLifeTime(operand)));
637     version = combineVersions(version, NN_TRY(validateOperandDimensions(operand)));
638     version = combineVersions(version, NN_TRY(validateOperandScale(operand)));
639     version = combineVersions(version, NN_TRY(validateOperandZeroPoint(operand)));
640     version = combineVersions(version, NN_TRY(validateOperandExtraParams(operand)));
641     version = combineVersions(
642             version, NN_TRY(validateOperandDataLocation(operand, operandValuesSize, poolSizes,
643                                                         subgraphs, subgraphVersionCache)));
644 
645     // For constants, validate that the length is as expected. The other lifetimes
646     // expect the length to be 0. Don't validate for OEM types.
647     if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
648         operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
649         operand.lifetime == Operand::LifeTime::POINTER) {
650         if (!isExtension(operand.type) && operand.type != OperandType::OEM &&
651             operand.type != OperandType::TENSOR_OEM_BYTE) {
652             const auto expectedLength = getNonExtensionSize(operand).value();
653             NN_RET_CHECK_EQ(operand.location.length, expectedLength)
654                     << "For operand " << operand.type << " expected a size of " << expectedLength
655                     << " but got " << operand.location.length;
656         }
657     }
658 
659     return version;
660 }
661 
validateOperands(const std::vector<Operand> & operands,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,std::vector<std::optional<Version>> * subgraphVersionCache)662 Result<std::vector<Version>> validateOperands(
663         const std::vector<Operand>& operands, size_t operandValuesSize,
664         const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs,
665         std::vector<std::optional<Version>>* subgraphVersionCache) {
666     std::vector<Version> versions;
667     versions.reserve(operands.size());
668     for (size_t i = 0; i < operands.size(); ++i) {
669         auto result = validateOperand(operands[i], operandValuesSize, poolSizes, subgraphs,
670                                       subgraphVersionCache);
671         if (!result.has_value()) {
672             return error() << std::move(result).error() << " for operand " << i;
673         }
674         versions.push_back(result.value());
675     }
676     return versions;
677 }
678 
679 // Forward declaration.
680 Result<Version> validateOperationIncludingOperandVersions(
681         const Operation& operation, const std::vector<Operand>& operands,
682         const std::vector<Version>& operandVersions, const std::vector<Model::Subgraph>& subgraphs);
683 
validateOperations(const std::vector<Operation> & operations,const std::vector<Operand> & operands,const std::vector<Version> & operandVersions,const std::vector<Model::Subgraph> & subgraphs)684 Result<Version> validateOperations(const std::vector<Operation>& operations,
685                                    const std::vector<Operand>& operands,
686                                    const std::vector<Version>& operandVersions,
687                                    const std::vector<Model::Subgraph>& subgraphs) {
688     auto version = kVersionFeatureLevel1;
689     for (size_t i = 0; i < operations.size(); ++i) {
690         auto result = validateOperationIncludingOperandVersions(operations[i], operands,
691                                                                 operandVersions, subgraphs);
692         if (!result.has_value()) {
693             return error() << std::move(result).error() << " for operation " << i;
694         }
695         version = combineVersions(version, result.value());
696     }
697     return version;
698 }
699 
validateUnknownHandle(const Memory::Unknown::Handle & handle)700 Result<Version> validateUnknownHandle(const Memory::Unknown::Handle& handle) {
701     NN_RET_CHECK(std::all_of(handle.fds.begin(), handle.fds.end(),
702                              [](const base::unique_fd& fd) { return fd.ok(); }));
703     return kVersionFeatureLevel3;
704 }
705 
validateSharedHandle(const SharedHandle & handle)706 Result<Version> validateSharedHandle(const SharedHandle& handle) {
707     // The absence of a shared handle is implicitly valid for all versions.
708     if (handle == nullptr) {
709         return kVersionFeatureLevel1;
710     }
711     NN_RET_CHECK(handle->ok());
712     return kVersionFeatureLevel3;
713 }
714 
validateMemory(const Memory::Ashmem & memory)715 Result<Version> validateMemory(const Memory::Ashmem& memory) {
716     NN_RET_CHECK(memory.fd.ok());
717     NN_RET_CHECK_NE(memory.size, 0u);
718     return kVersionFeatureLevel1;
719 }
720 
validateMemory(const Memory::Fd & memory)721 Result<Version> validateMemory(const Memory::Fd& memory) {
722     NN_RET_CHECK(memory.fd.ok());
723     NN_RET_CHECK_NE(memory.size, 0u);
724 
725     // `prot` is allowed to be either PROT_NONE (which has a value of 0) or the bitwise OR of either
726     // PROT_READ or PROT_WRITE. If any other bits are set, the `prot` field is invalid.
727     constexpr int kAllowedBits = PROT_READ | PROT_WRITE;
728     NN_RET_CHECK_EQ(memory.prot & ~kAllowedBits, 0);
729 
730     return kVersionFeatureLevel1;
731 }
732 
validateMemory(const Memory::HardwareBuffer & memory)733 Result<Version> validateMemory(const Memory::HardwareBuffer& memory) {
734     NN_RET_CHECK(memory.handle.get() != nullptr);
735     return kVersionFeatureLevel3;
736 }
737 
validateMemory(const Memory::Unknown & memory)738 Result<Version> validateMemory(const Memory::Unknown& memory) {
739     NN_TRY(validateUnknownHandle(memory.handle));
740     return kVersionFeatureLevel3;
741 }
742 
validateSharedMemory(const SharedMemory & memory)743 Result<Version> validateSharedMemory(const SharedMemory& memory) {
744     NN_RET_CHECK(memory != nullptr);
745     return std::visit([](const auto& x) { return validateMemory(x); }, memory->handle);
746 }
747 
validateModelSubgraphInputOutputs(const std::vector<uint32_t> & indexes,const std::vector<Operand> & operands,Operand::LifeTime lifetime)748 Result<void> validateModelSubgraphInputOutputs(const std::vector<uint32_t>& indexes,
749                                                const std::vector<Operand>& operands,
750                                                Operand::LifeTime lifetime) {
751     const size_t operandCount = operands.size();
752     for (uint32_t i : indexes) {
753         NN_RET_CHECK_LT(i, operandCount)
754                 << "Model " << lifetime << " input or output index out of range: " << i << "/"
755                 << operandCount;
756         const Operand& operand = operands[i];
757         NN_RET_CHECK_EQ(operand.lifetime, lifetime)
758                 << "Model " << lifetime << " operand " << i << " has lifetime of "
759                 << operand.lifetime << " instead of the expected " << lifetime;
760     }
761 
762     std::vector<uint32_t> sortedIndexes = indexes;
763     std::sort(sortedIndexes.begin(), sortedIndexes.end());
764     const auto iter = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
765     NN_RET_CHECK(iter == sortedIndexes.end())
766             << "Model input or output occurs multiple times: " << *iter;
767 
768     for (size_t i = 0; i < operands.size(); ++i) {
769         if (operands[i].lifetime == lifetime) {
770             const auto containsIndex = [&sortedIndexes](size_t index) {
771                 return binary_search(sortedIndexes.begin(), sortedIndexes.end(), index);
772             };
773             NN_RET_CHECK(containsIndex(i))
774                     << "Operand " << i << " marked as " << lifetime
775                     << " but is not included in Model input or output indexes";
776         }
777     }
778 
779     return {};
780 }
781 
validateExecutionOrder(const Model::Subgraph & subgraph)782 Result<void> validateExecutionOrder(const Model::Subgraph& subgraph) {
783     // Either the operand has a known value before model execution begins, or we've seen a writer
784     // for this operand while walking operands in execution order. Initialize to known operands.
785     std::vector<bool> operandValueKnown;
786     operandValueKnown.reserve(subgraph.operands.size());
787     std::transform(subgraph.operands.begin(), subgraph.operands.end(),
788                    std::back_inserter(operandValueKnown), [](const Operand& operand) {
789                        return operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE &&
790                               operand.lifetime != Operand::LifeTime::SUBGRAPH_OUTPUT;
791                    });
792 
793     // Validate that operations are sorted into execution order.
794     //
795     // If there is a cycle in the graph, the operations will not
796     // appear to be sorted into execution order: Some operation will
797     // have an input for which operandValueKnown[] is false.
798     for (size_t i = 0; i < subgraph.operations.size(); ++i) {
799         const auto& operation = subgraph.operations[i];
800 
801         for (size_t j = 0; j < operation.inputs.size(); ++j) {
802             const uint32_t k = operation.inputs[j];
803             NN_RET_CHECK(operandValueKnown[k])
804                     << "Operation " << i << " input " << j << " (operand " << k
805                     << ") is read before it is written";
806         }
807 
808         for (size_t j = 0; j < operation.outputs.size(); ++j) {
809             const uint32_t k = operation.outputs[j];
810             // Assuming validateOperations() has not returned an error, we know that this output is
811             // TEMPORARY_VARIABLE or MODEL_OUTPUT, and so the only way operandValueKnown[k] can be
812             // true is if we've already seen a writer for this operand.
813             NN_RET_CHECK(!operandValueKnown[k])
814                     << "Operation " << i << " output " << j << " (operand " << k
815                     << ") has already been written";
816             operandValueKnown[k] = true;
817         }
818     }
819 
820     // Verify all operands are written.
821     for (size_t i = 0; i < subgraph.operands.size(); ++i) {
822         NN_RET_CHECK(operandValueKnown[i]) << "Operand " << i << " is never written";
823     }
824 
825     // TODO(b/77871786): verify that every operation has at least one output operand that is read?
826 
827     return {};
828 }
829 
830 // Validate a subgraph, ensuring all subgraphs it depends on are also validated.
831 //
832 // `referencedIndex` is empty if the subgraph being validated is the main subgraph, otherwise it is
833 // the index of the referenced subgraph being validated.
834 //
835 // referenced[i] and (*subgraphVersionCache)[i] correspond to the same subgraph, and therefore
836 // `referenced` and `subgraphVersionCache` must have the same length.
validateModelSubgraph(const Model::Subgraph & subgraph,std::optional<size_t> referencedIndex,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & referenced,std::vector<std::optional<Version>> * subgraphVersionCache)837 Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
838                                       std::optional<size_t> referencedIndex,
839                                       size_t operandValuesSize,
840                                       const std::vector<size_t>& poolSizes,
841                                       const std::vector<Model::Subgraph>& referenced,
842                                       std::vector<std::optional<Version>>* subgraphVersionCache) {
843     CHECK(subgraphVersionCache != nullptr);
844     CHECK_EQ(referenced.size(), subgraphVersionCache->size());
845 
846     // Quickly return if the current subgraph has already been checked for its version.
847     if (referencedIndex.has_value()) {
848         if (auto version = subgraphVersionCache->at(*referencedIndex)) {
849             return *version;
850         }
851     }
852 
853     NN_RET_CHECK(!subgraph.operands.empty());
854     NN_RET_CHECK(!subgraph.operations.empty());
855     // TODO(b/173780642): Clarify whether subgraphs with no inputs or outputs are valid.
856     // NN_RET_CHECK(!subgraph.inputIndexes.empty());
857     // NN_RET_CHECK(!subgraph.outputIndexes.empty());
858 
859     const auto operandVersions = NN_TRY(validateOperands(
860             subgraph.operands, operandValuesSize, poolSizes, referenced, subgraphVersionCache));
861     const auto operationsVersion = NN_TRY(validateOperations(subgraph.operations, subgraph.operands,
862                                                              operandVersions, referenced));
863 
864     // Accumulate the versions from all operands and operations.
865     const auto version = std::accumulate(operandVersions.begin(), operandVersions.end(),
866                                          operationsVersion, combineVersions);
867 
868     NN_TRY(validateModelSubgraphInputOutputs(subgraph.inputIndexes, subgraph.operands,
869                                              Operand::LifeTime::SUBGRAPH_INPUT));
870     NN_TRY(validateModelSubgraphInputOutputs(subgraph.outputIndexes, subgraph.operands,
871                                              Operand::LifeTime::SUBGRAPH_OUTPUT));
872 
873     NN_TRY(validateExecutionOrder(subgraph));
874 
875     // Mark the current subgraph as having already been validated so the caller can quickly return
876     // if this subgraph is checked again.
877     if (referencedIndex.has_value()) {
878         subgraphVersionCache->at(*referencedIndex) = version;
879     }
880     return version;
881 }
882 
validateExtensionNamesAndPrefixes(const std::vector<ExtensionNameAndPrefix> & extensionNamesAndPrefixes)883 Result<Version> validateExtensionNamesAndPrefixes(
884         const std::vector<ExtensionNameAndPrefix>& extensionNamesAndPrefixes) {
885     for (const auto& extensionNameAndPrefix : extensionNamesAndPrefixes) {
886         NN_RET_CHECK(isValidExtensionName(extensionNameAndPrefix.name));
887     }
888 
889     std::vector<std::reference_wrapper<const std::string>> names;
890     names.reserve(extensionNamesAndPrefixes.size());
891     std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
892                    std::back_inserter(names),
893                    [](const ExtensionNameAndPrefix& extensionNameAndPrefix) {
894                        return std::cref(extensionNameAndPrefix.name);
895                    });
896     std::sort(names.begin(), names.end(), std::less<std::string>{});
897     const auto nameIter =
898             std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
899     NN_RET_CHECK(nameIter == names.end())
900             << "ExtensionNamesAndPrefixes has duplicate name " << nameIter->get();
901 
902     std::vector<uint16_t> types;
903     types.reserve(extensionNamesAndPrefixes.size());
904     std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
905                    std::back_inserter(types),
906                    [](const ExtensionNameAndPrefix& extensionNameAndPrefix) {
907                        return extensionNameAndPrefix.prefix;
908                    });
909     std::sort(types.begin(), types.end());
910     const auto typeIter = std::adjacent_find(types.begin(), types.end());
911     NN_RET_CHECK(typeIter == types.end())
912             << "ExtensionNamesAndPrefixes has duplicate type " << *typeIter;
913 
914     const bool hasExtensions = !extensionNamesAndPrefixes.empty();
915     return hasExtensions ? kVersionFeatureLevel3 : kVersionFeatureLevel1;
916 }
917 
918 // Makes sure the model does not contain subgraph reference cycles.
919 //
920 // This function verifies that referencedSubgraphs[subgraphIndex] and any subgraphs it refences do
921 // not contain any reference cycles. `path` is used to keep track of which referenced subgraphs have
922 // already been visited in the current recursive reference path. `verified` is a cache to keep track
923 // of which referenced subgraphs have already been verified not to form reference cycles.
924 //
925 // referencedSubgraphs[i], (*path)[i], and (*verified)[i] all correspond to the same subgraph, and
926 // therefore `referencedSubgraphs`, `path`, and `verified` must all have the same length.
checkNoReferenceCycles(const std::vector<Model::Subgraph> & referencedSubgraphs,uint32_t subgraphIndex,std::vector<bool> * path,std::vector<bool> * verified)927 Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs,
928                                     uint32_t subgraphIndex, std::vector<bool>* path,
929                                     std::vector<bool>* verified) {
930     CHECK(path != nullptr);
931     CHECK(verified != nullptr);
932     CHECK_EQ(referencedSubgraphs.size(), path->size());
933     CHECK_EQ(referencedSubgraphs.size(), verified->size());
934     NN_RET_CHECK_LT(subgraphIndex, referencedSubgraphs.size());
935     const auto& subgraph = referencedSubgraphs.at(subgraphIndex);
936 
937     // Quickly return if the current subgraph has already been verified to have no reference cycles.
938     if ((*verified)[subgraphIndex]) {
939         return {};
940     }
941 
942     // Add the current subgraph to the path (making sure that it is not already part of the path),
943     // and verify that all subgraphs this subgraph references do not contain cycles. The current
944     // subgraph is removed from the path only after all subgraphs this subgraph references have been
945     // checked.
946     NN_RET_CHECK((*path)[subgraphIndex] == false) << "Model contains a circular subgraph reference";
947     (*path)[subgraphIndex] = true;
948     for (const Operand& operand : subgraph.operands) {
949         if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
950             const uint32_t refSubgraphIndex = operand.location.offset;
951             NN_TRY(checkNoReferenceCycles(referencedSubgraphs, refSubgraphIndex, path, verified));
952         }
953     }
954     (*path)[subgraphIndex] = false;
955 
956     // Mark the current subgraph as having already been verified so the caller can quickly return if
957     // this subgraph is checked again.
958     (*verified)[subgraphIndex] = true;
959     return {};
960 }
961 
checkNoReferenceCycles(const std::vector<Model::Subgraph> & referencedSubgraphs)962 Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs) {
963     const size_t count = referencedSubgraphs.size();
964     std::vector<bool> path(count);
965     std::vector<bool> verified(count);
966     for (size_t i = 0; i < count; ++i) {
967         NN_TRY(checkNoReferenceCycles(referencedSubgraphs, i, &path, &verified));
968     }
969     return {};
970 }
971 
validateModel(const Model & model)972 Result<Version> validateModel(const Model& model) {
973     auto version = NN_TRY(validateVector(model.pools, validateSharedMemory));
974     version = combineVersions(
975             version, NN_TRY(validateExtensionNamesAndPrefixes(model.extensionNameToPrefix)));
976 
977     // Ignore relaxComputationFloat32toFloat16 version because in the worst case it makes the
978     // execution stricter.
979 
980     // Referenced models were introduced in Android R.
981     const bool hasReferencedModels = !model.referenced.empty();
982     const auto referenceModelVersion =
983             hasReferencedModels ? kVersionFeatureLevel4 : kVersionFeatureLevel1;
984     version = combineVersions(version, referenceModelVersion);
985 
986     // Ensure that there are no cycles formed by the subgraphs.
987     NN_TRY(checkNoReferenceCycles(model.referenced));
988 
989     // Get memory sizes.
990     const auto [operandValuesSize, poolSizes] = getMemorySizes(model);
991 
992     // Validate referenced subgraphs.
993     auto subgraphVersionCache = std::vector<std::optional<Version>>(model.referenced.size());
994     for (size_t referencedIndex = 0; referencedIndex < model.referenced.size(); ++referencedIndex) {
995         const auto& subgraph = model.referenced[referencedIndex];
996         const auto subgraphVersion =
997                 NN_TRY(validateModelSubgraph(subgraph, referencedIndex, operandValuesSize,
998                                              poolSizes, model.referenced, &subgraphVersionCache));
999         version = combineVersions(version, subgraphVersion);
1000     }
1001 
1002     // Validate main subgraph.
1003     const auto subgraphVersion =
1004             NN_TRY(validateModelSubgraph(model.main, std::nullopt, operandValuesSize, poolSizes,
1005                                          model.referenced, &subgraphVersionCache));
1006     version = combineVersions(version, subgraphVersion);
1007 
1008     return version;
1009 }
1010 
validateBufferDesc(const BufferDesc & bufferDesc)1011 Result<Version> validateBufferDesc(const BufferDesc& bufferDesc) {
1012     // An empty BufferDesc is the default value, so it is implicitly valid for all versions.
1013     return bufferDesc.dimensions.empty() ? kVersionFeatureLevel1 : kVersionFeatureLevel4;
1014 }
1015 
validateBufferRole(const BufferRole & bufferRole)1016 Result<Version> validateBufferRole(const BufferRole& bufferRole) {
1017     NN_RET_CHECK_GT(bufferRole.probability, 0.0f);
1018     NN_RET_CHECK_LE(bufferRole.probability, 1.0f);
1019     return kVersionFeatureLevel4;
1020 }
1021 
validateRequestArgument(const Request::Argument & requestArgument,const std::vector<size_t> & memorySizes,bool isOutput)1022 Result<Version> validateRequestArgument(const Request::Argument& requestArgument,
1023                                         const std::vector<size_t>& memorySizes, bool isOutput) {
1024     const auto lifetime = requestArgument.lifetime;
1025     const auto& location = requestArgument.location;
1026     const auto& dimensions = requestArgument.dimensions;
1027 
1028     switch (lifetime) {
1029         case Request::Argument::LifeTime::POOL: {
1030             NN_RET_CHECK(location.pointer == kNullptrVariant);
1031             NN_RET_CHECK_LT(location.poolIndex, memorySizes.size());
1032             // Do the addition using uint64_t to avoid potential wrap-around problems.
1033             const auto lastPosition =
1034                     static_cast<uint64_t>(location.offset) + location.length + location.padding;
1035             const auto memorySize = memorySizes[location.poolIndex];
1036             NN_RET_CHECK_LE(lastPosition, memorySize);
1037             if (memorySize > 0) {
1038                 // Must specify a positive length if the memory pool has a known size.
1039                 NN_RET_CHECK_GT(location.length, 0u);
1040             }
1041             return kVersionFeatureLevel1;
1042         }
1043         case Request::Argument::LifeTime::NO_VALUE:
1044             NN_RET_CHECK(location.pointer == kNullptrVariant);
1045             NN_RET_CHECK_EQ(location.poolIndex, 0u);
1046             NN_RET_CHECK_EQ(location.offset, 0u);
1047             NN_RET_CHECK_EQ(location.length, 0u);
1048             NN_RET_CHECK_EQ(location.padding, 0u);
1049             NN_RET_CHECK(dimensions.empty());
1050             return kVersionFeatureLevel1;
1051         case Request::Argument::LifeTime::POINTER: {
1052             const bool isNullptr =
1053                     std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
1054             NN_RET_CHECK(!isNullptr);
1055             NN_RET_CHECK_EQ(location.poolIndex, 0u);
1056             NN_RET_CHECK_EQ(location.offset, 0u);
1057             NN_RET_CHECK_NE(location.length, 0u);
1058             if (isOutput) {
1059                 NN_RET_CHECK(std::holds_alternative<void*>(location.pointer));
1060             }
1061             return kVersionFeatureLevel1;
1062         }
1063     }
1064     NN_RET_CHECK_FAIL() << "Invalid Request::Argument::LifeTime " << lifetime;
1065 }
1066 
validateRequestMemoryPool(const Request::MemoryPool & memoryPool)1067 Result<Version> validateRequestMemoryPool(const Request::MemoryPool& memoryPool) {
1068     if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
1069         NN_RET_CHECK(std::get<Request::MemoryDomainToken>(memoryPool) != kInvalidMemoryDomainToken);
1070         return kVersionFeatureLevel4;
1071     }
1072     if (std::holds_alternative<SharedBuffer>(memoryPool)) {
1073         NN_RET_CHECK(std::get<SharedBuffer>(memoryPool) != nullptr);
1074         return kVersionFeatureLevel4;
1075     }
1076     return validateSharedMemory(std::get<SharedMemory>(memoryPool));
1077 }
1078 
validateRequest(const Request & request)1079 Result<Version> validateRequest(const Request& request) {
1080     auto version = NN_TRY(validateVector(request.pools, validateRequestMemoryPool));
1081 
1082     // Get memory sizes. For IBuffer or MemoryDomainToken types, set size to 0.
1083     std::vector<size_t> memorySizes;
1084     memorySizes.reserve(request.pools.size());
1085     std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(memorySizes),
1086                    [](const Request::MemoryPool& memoryPool) {
1087                        const auto* memory = std::get_if<SharedMemory>(&memoryPool);
1088                        return memory != nullptr ? getSize(*memory) : 0;
1089                    });
1090 
1091     for (size_t i = 0; i < request.inputs.size(); ++i) {
1092         const auto& input = request.inputs[i];
1093         auto result = validateRequestArgument(input, memorySizes, /*isOutput=*/false);
1094         if (!result.has_value()) {
1095             return error() << std::move(result).error() << " for input RequestArgument " << i;
1096         }
1097         version = combineVersions(version, result.value());
1098     }
1099     for (size_t i = 0; i < request.outputs.size(); ++i) {
1100         const auto& output = request.outputs[i];
1101         auto result = validateRequestArgument(output, memorySizes, /*isOutput=*/true);
1102         if (!result.has_value()) {
1103             return error() << std::move(result).error() << " for output RequestArgument " << i;
1104         }
1105         version = combineVersions(version, result.value());
1106     }
1107 
1108     return version;
1109 }
1110 
validateOptionalTimePoint(const OptionalTimePoint & optionalTimePoint)1111 Result<Version> validateOptionalTimePoint(const OptionalTimePoint& optionalTimePoint) {
1112     if (optionalTimePoint.has_value()) {
1113         NN_RET_CHECK_GE(optionalTimePoint->time_since_epoch().count(), 0);
1114     }
1115     // An omitted time point is the default value, so it is implicitly valid for all versions.
1116     return !optionalTimePoint.has_value() ? kVersionFeatureLevel1 : kVersionFeatureLevel4;
1117 }
1118 
validateOptionalTimeoutDuration(const OptionalDuration & optionalTimeoutDuration)1119 Result<Version> validateOptionalTimeoutDuration(const OptionalDuration& optionalTimeoutDuration) {
1120     if (optionalTimeoutDuration.has_value()) {
1121         NN_RET_CHECK_GE(optionalTimeoutDuration->count(), 0);
1122     }
1123     // An omitted duration is the default value, so it is implicitly valid for all versions.
1124     return !optionalTimeoutDuration.has_value() ? kVersionFeatureLevel1 : kVersionFeatureLevel4;
1125 }
1126 
validateCacheToken(const CacheToken & cacheToken)1127 Result<Version> validateCacheToken(const CacheToken& cacheToken) {
1128     // A CacheToken of 0 is the default value, so it is implicitly valid for all versions.
1129     constexpr auto kDefaultCacheToken = CacheToken{};
1130     return cacheToken == kDefaultCacheToken ? kVersionFeatureLevel1 : kVersionFeatureLevel3;
1131 }
1132 
validateSyncFence(const SyncFence & syncFence)1133 Result<Version> validateSyncFence(const SyncFence& syncFence) {
1134     // The absence of a sync fence is implicitly valid for all versions.
1135     if (!syncFence.hasFd()) {
1136         return kVersionFeatureLevel1;
1137     }
1138     NN_RET_CHECK_GE(syncFence.getFd(), 0);
1139     return kVersionFeatureLevel4;
1140 }
1141 
validateTokenValuePair(const TokenValuePair &)1142 Result<Version> validateTokenValuePair(const TokenValuePair& /*tokenValuePair*/) {
1143     return kVersionFeatureLevel8;
1144 }
1145 
validateRequestArgumentsForModel(const std::vector<Request::Argument> & requestArguments,const std::vector<uint32_t> & operandIndexes,const std::vector<Operand> & operands,bool isOutput,bool allowUnspecifiedOutput)1146 Result<Version> validateRequestArgumentsForModel(
1147         const std::vector<Request::Argument>& requestArguments,
1148         const std::vector<uint32_t>& operandIndexes, const std::vector<Operand>& operands,
1149         bool isOutput, bool allowUnspecifiedOutput) {
1150     auto version = kVersionFeatureLevel1;
1151     // The request should specify as many arguments as were described in the model.
1152     const std::string_view type = isOutput ? "output" : "input";
1153     const size_t requestArgumentCount = requestArguments.size();
1154     NN_RET_CHECK_EQ(requestArgumentCount, operandIndexes.size())
1155             << "Request specifies " << requestArgumentCount << " " << type << "s but the model has "
1156             << operandIndexes.size();
1157     for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
1158          requestArgumentIndex++) {
1159         const Request::Argument& requestArgument = requestArguments[requestArgumentIndex];
1160         // Get the operand index for this argument. We extract it from the list
1161         // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
1162         // We assume in this function that the model has been validated already.
1163         const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
1164         const Operand& operand = operands[operandIndex];
1165         if (requestArgument.lifetime != Request::Argument::LifeTime::NO_VALUE) {
1166             const bool isExtensionType = isExtension(operand.type);
1167             // If the argument specified a dimension, validate it.
1168             uint32_t modelRank = operand.dimensions.size();
1169             uint32_t requestRank = requestArgument.dimensions.size();
1170             if (requestRank == 0) {
1171                 // NOTE: validateRequestArguments cannot validate unknown tensor rank with
1172                 // extension operand type.
1173                 if (!isExtensionType && !isNonExtensionScalar(operand.type)) {
1174                     if (modelRank <= 0) {
1175                         NN_RET_CHECK(isOutput)
1176                                 << "Model has unknown input rank but the request does not "
1177                                    "specify the rank.";
1178                         NN_RET_CHECK(allowUnspecifiedOutput)
1179                                 << "Model has unknown output rank and request does not specify it.";
1180                         // Unspecified output dimensions introduced in Android Q.
1181                         version = combineVersions(version, kVersionFeatureLevel3);
1182                     }
1183                 }
1184                 // Validate that all the dimensions are specified in the model.
1185                 for (size_t i = 0; i < modelRank; i++) {
1186                     if (operand.dimensions[i] == 0) {
1187                         NN_RET_CHECK(isOutput && allowUnspecifiedOutput)
1188                                 << "Model has dimension " << i
1189                                 << " set to 0 but the request does not specify the dimension.";
1190                         // Unspecified output dimensions introduced in Android Q.
1191                         version = combineVersions(version, kVersionFeatureLevel3);
1192                     }
1193                 }
1194             } else {
1195                 NN_RET_CHECK(modelRank == 0 || requestRank == modelRank)
1196                         << "Request " << type << " " << requestArgumentIndex
1197                         << " has number of dimensions (" << requestRank
1198                         << ") different than the model's (" << modelRank << ")";
1199                 for (size_t i = 0; i < requestRank; i++) {
1200                     NN_RET_CHECK(modelRank == 0 || operand.dimensions[i] == 0 ||
1201                                  requestArgument.dimensions[i] == operand.dimensions[i])
1202                             << "Request " << type << " " << requestArgumentIndex
1203                             << " has dimension " << i << " of " << requestArgument.dimensions[i]
1204                             << " different than the model's " << operand.dimensions[i];
1205                     if (requestArgument.dimensions[i] == 0) {
1206                         NN_RET_CHECK(isOutput && allowUnspecifiedOutput)
1207                                 << "Request " << type << " " << requestArgumentIndex
1208                                 << " has dimension " << i << " of zero";
1209                         // Unspecified output dimensions introduced in Android Q.
1210                         version = combineVersions(version, kVersionFeatureLevel3);
1211                     }
1212                 }
1213             }
1214             // NOTE: validateRequestArguments cannot validate DataLocation::length
1215             // with extension operand type.
1216             if (!isExtensionType && requestArgument.location.length != 0) {
1217                 const auto dimensions =
1218                         NN_TRY(combineDimensions(operand.dimensions, requestArgument.dimensions));
1219                 const size_t expectedLength = getNonExtensionSize(operand.type, dimensions).value();
1220                 if (expectedLength != 0) {
1221                     NN_RET_CHECK_EQ(requestArgument.location.length, expectedLength)
1222                             << "Request " << type << " " << requestArgumentIndex
1223                             << " expected a size of " << expectedLength << " but got "
1224                             << requestArgument.location.length;
1225                 }
1226             }
1227         }
1228     }
1229     return version;
1230 }
1231 
validateRequestForModelImpl(const Request & request,const Model & model,bool allowUnspecifiedOutput)1232 Result<Version> validateRequestForModelImpl(const Request& request, const Model& model,
1233                                             bool allowUnspecifiedOutput) {
1234     auto version = NN_TRY(validateRequest(request));
1235     version = combineVersions(version,
1236                               NN_TRY(validateRequestArgumentsForModel(
1237                                       request.inputs, model.main.inputIndexes, model.main.operands,
1238                                       /*isOutput=*/false, /*allowUnspecifiedOutput=*/true)));
1239     version = combineVersions(
1240             version, NN_TRY(validateRequestArgumentsForModel(
1241                              request.outputs, model.main.outputIndexes, model.main.operands,
1242                              /*isOutput=*/true, allowUnspecifiedOutput)));
1243     return version;
1244 }
1245 
validateMemoryDescImpl(const BufferDesc & desc,const std::vector<SharedPreparedModel> & preparedModels,const std::vector<BufferRole> & inputRoles,const std::vector<BufferRole> & outputRoles,const std::function<const Model * (const SharedPreparedModel &)> & getModel,std::set<PreparedModelRole> * preparedModelRoles,Operand * combinedOperand)1246 Result<Version> validateMemoryDescImpl(
1247         const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
1248         const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
1249         const std::function<const Model*(const SharedPreparedModel&)>& getModel,
1250         std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
1251     NN_RET_CHECK(!preparedModels.empty());
1252     NN_RET_CHECK(!inputRoles.empty() || !outputRoles.empty());
1253 
1254     std::set<PreparedModelRole> roles;
1255     std::vector<nn::Operand> operands;
1256     operands.reserve(inputRoles.size() + outputRoles.size());
1257     for (const auto& role : inputRoles) {
1258         NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
1259         const auto& preparedModel = preparedModels[role.modelIndex];
1260         NN_RET_CHECK(preparedModel != nullptr);
1261         const auto* model = getModel(preparedModel);
1262         NN_RET_CHECK(model != nullptr);
1263         const auto& inputIndexes = model->main.inputIndexes;
1264         NN_RET_CHECK_LT(role.ioIndex, inputIndexes.size());
1265         NN_RET_CHECK_GT(role.probability, 0.0f);
1266         NN_RET_CHECK_LE(role.probability, 1.0f);
1267         const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
1268         NN_RET_CHECK(success);
1269         operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
1270     }
1271     for (const auto& role : outputRoles) {
1272         NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
1273         const auto& preparedModel = preparedModels[role.modelIndex];
1274         NN_RET_CHECK(preparedModel != nullptr);
1275         const auto* model = getModel(preparedModel);
1276         NN_RET_CHECK(model != nullptr);
1277         const auto& outputIndexes = model->main.outputIndexes;
1278         NN_RET_CHECK_LT(role.ioIndex, outputIndexes.size());
1279         NN_RET_CHECK_GT(role.probability, 0.0f);
1280         NN_RET_CHECK_LE(role.probability, 1.0f);
1281         const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
1282         NN_RET_CHECK(success);
1283         operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
1284     }
1285 
1286     CHECK(!operands.empty());
1287     const auto opType = operands.front().type;
1288 
1289     Dimensions dimensions = desc.dimensions;
1290     for (const auto& operand : operands) {
1291         NN_RET_CHECK_EQ(operand.type, opType) << operand.type << " vs " << operands.front().type;
1292         NN_RET_CHECK_EQ(operand.scale, operands.front().scale);
1293         NN_RET_CHECK_EQ(operand.zeroPoint, operands.front().zeroPoint);
1294         // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
1295         if (!isExtension(opType)) {
1296             NN_RET_CHECK_EQ(operand.extraParams, operands.front().extraParams)
1297                     << operand.extraParams << " vs " << operands.front().extraParams;
1298         }
1299         dimensions = NN_TRY(combineDimensions(dimensions, operand.dimensions));
1300     }
1301 
1302     // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
1303     if (!isExtension(opType)) {
1304         NN_RET_CHECK(!isNonExtensionScalar(opType) || dimensions.empty())
1305                 << "invalid dimensions with scalar operand type.";
1306     }
1307 
1308     if (preparedModelRoles != nullptr) {
1309         *preparedModelRoles = std::move(roles);
1310     }
1311     if (combinedOperand != nullptr) {
1312         *combinedOperand = operands.front();
1313         combinedOperand->dimensions = dimensions;
1314     }
1315     return kVersionFeatureLevel4;
1316 }
1317 
1318 class OperationValidationContext : public IOperationValidationContext {
1319     DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);
1320 
1321    public:
OperationValidationContext(const char * operationName,const std::vector<uint32_t> & inputIndexes,const std::vector<uint32_t> & outputIndexes,const std::vector<Operand> & operands)1322     OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes,
1323                                const std::vector<uint32_t>& outputIndexes,
1324                                const std::vector<Operand>& operands)
1325         : operationName(operationName),
1326           inputIndexes(inputIndexes),
1327           outputIndexes(outputIndexes),
1328           operands(operands) {}
1329 
1330     const char* getOperationName() const override;
1331 
1332     uint32_t getNumInputs() const override;
1333     OperandType getInputType(uint32_t index) const override;
1334     Shape getInputShape(uint32_t index) const override;
1335     const Operand::ExtraParams& getInputExtraParams(uint32_t index) const override;
1336 
1337     uint32_t getNumOutputs() const override;
1338     OperandType getOutputType(uint32_t index) const override;
1339     Shape getOutputShape(uint32_t index) const override;
1340 
1341    private:
1342     const Operand* getInputOperand(uint32_t index) const;
1343     const Operand* getOutputOperand(uint32_t index) const;
1344 
1345     const char* operationName;
1346     const std::vector<uint32_t>& inputIndexes;
1347     const std::vector<uint32_t>& outputIndexes;
1348     const std::vector<Operand>& operands;
1349 };
1350 
getOperationName() const1351 const char* OperationValidationContext::getOperationName() const {
1352     return operationName;
1353 }
1354 
getInputOperand(uint32_t index) const1355 const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
1356     return &operands.at(inputIndexes.at(index));
1357 }
1358 
getOutputOperand(uint32_t index) const1359 const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const {
1360     return &operands.at(outputIndexes.at(index));
1361 }
1362 
getNumInputs() const1363 uint32_t OperationValidationContext::getNumInputs() const {
1364     auto count = inputIndexes.size();
1365     CHECK_LE(count, std::numeric_limits<uint32_t>::max());
1366     return static_cast<uint32_t>(count);
1367 }
1368 
getNumOutputs() const1369 uint32_t OperationValidationContext::getNumOutputs() const {
1370     auto count = outputIndexes.size();
1371     CHECK_LE(count, std::numeric_limits<uint32_t>::max());
1372     return static_cast<uint32_t>(count);
1373 }
1374 
getInputType(uint32_t index) const1375 OperandType OperationValidationContext::getInputType(uint32_t index) const {
1376     return getInputOperand(index)->type;
1377 }
1378 
getInputShape(uint32_t index) const1379 Shape OperationValidationContext::getInputShape(uint32_t index) const {
1380     const Operand* operand = getInputOperand(index);
1381     return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
1382             operand->extraParams};
1383 }
1384 
getInputExtraParams(uint32_t index) const1385 const Operand::ExtraParams& OperationValidationContext::getInputExtraParams(uint32_t index) const {
1386     return getInputOperand(index)->extraParams;
1387 }
1388 
getOutputType(uint32_t index) const1389 OperandType OperationValidationContext::getOutputType(uint32_t index) const {
1390     return getOutputOperand(index)->type;
1391 }
1392 
getOutputShape(uint32_t index) const1393 Shape OperationValidationContext::getOutputShape(uint32_t index) const {
1394     const Operand* operand = getOutputOperand(index);
1395     return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
1396             operand->extraParams};
1397 }
1398 
1399 // TODO(b/169345292): reduce the duplicate validation here
1400 
validateOperandSymmPerChannelQuantParamsImpl(const Operand & operand,const Operand::SymmPerChannelQuantParams & channelQuant,const char * tag)1401 Result<void> validateOperandSymmPerChannelQuantParamsImpl(
1402         const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
1403         const char* tag) {
1404     if (operand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
1405         NN_RET_CHECK_FAIL();
1406     }
1407 
1408     NN_RET_CHECK_LT(channelQuant.channelDim, operand.dimensions.size()) << tag;
1409     NN_RET_CHECK(!channelQuant.scales.empty()) << tag;
1410     NN_RET_CHECK_EQ(channelQuant.scales.size(), operand.dimensions[channelQuant.channelDim]) << tag;
1411     NN_RET_CHECK_NE(operand.dimensions[channelQuant.channelDim], 0u)
1412             << tag << " channel dimension " << channelQuant.channelDim << " is underspecified";
1413     for (uint32_t i = 0; i < operand.dimensions[channelQuant.channelDim]; i++) {
1414         NN_RET_CHECK_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]";
1415     }
1416     return {};
1417 }
1418 
validateScalarDimensions(const Operand & type,const char * tag)1419 Result<void> validateScalarDimensions(const Operand& type, const char* tag) {
1420     NN_RET_CHECK(type.dimensions.empty()) << tag << " invalid dimensions for scalar type";
1421     return {};
1422 }
1423 
validateQuant8AsymmParams(const Operand & type,const char * tag)1424 Result<void> validateQuant8AsymmParams(const Operand& type, const char* tag) {
1425     NN_RET_CHECK(0 <= type.zeroPoint && type.zeroPoint <= 255)
1426             << tag << " invalid zeroPoint: " << type.zeroPoint;
1427     NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
1428     return {};
1429 }
1430 
validateQuant8AsymmSignedParams(const Operand & type,const char * tag)1431 Result<void> validateQuant8AsymmSignedParams(const Operand& type, const char* tag) {
1432     NN_RET_CHECK(-128 <= type.zeroPoint && type.zeroPoint <= 127)
1433             << tag << " invalid zeroPoint: " << type.zeroPoint;
1434     NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
1435     return {};
1436 }
1437 
validateQuant8SymmParams(const Operand & type,const char * tag)1438 Result<void> validateQuant8SymmParams(const Operand& type, const char* tag) {
1439     NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint;
1440     NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
1441     return {};
1442 }
1443 
validateQuant16AsymmParams(const Operand & type,const char * tag)1444 Result<void> validateQuant16AsymmParams(const Operand& type, const char* tag) {
1445     NN_RET_CHECK(0 <= type.zeroPoint && type.zeroPoint <= 65535)
1446             << tag << " invalid zeroPoint: " << type.zeroPoint;
1447     NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
1448     return {};
1449 }
1450 
validateQuantSymmParams(const Operand & type,const char * tag)1451 Result<void> validateQuantSymmParams(const Operand& type, const char* tag) {
1452     NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
1453     NN_RET_CHECK_GT(type.scale, 0.0f) << tag << " invalid scale";
1454     return {};
1455 }
1456 
validateNoQuantParams(const Operand & type,const char * tag)1457 Result<void> validateNoQuantParams(const Operand& type, const char* tag) {
1458     NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
1459     NN_RET_CHECK_EQ(type.scale, 0.0f) << tag << " scale is not zero";
1460     return {};
1461 }
1462 
validateTensorDimensions(const Operand & type,const Extension::OperandTypeInformation * extensionOperandTypeInfo,const char * tag,bool allowPartial)1463 Result<void> validateTensorDimensions(
1464         const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo,
1465         const char* tag, bool allowPartial) {
1466     if (!allowPartial) {
1467         NN_RET_CHECK(!type.dimensions.empty()) << tag << " invalid operand dimensions";
1468     }
1469     uint64_t size = isExtension(type.type) ? extensionOperandTypeInfo->byteSize
1470                                            : getNonExtensionSize(type.type);
1471     constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max();
1472     for (size_t i = 0; i < type.dimensions.size(); i++) {
1473         if (!allowPartial) {
1474             NN_RET_CHECK_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions";
1475         }
1476         if (type.dimensions[i] != 0) {
1477             size *= type.dimensions[i];
1478             NN_RET_CHECK_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize;
1479         }
1480     }
1481     return {};
1482 }
1483 
validateOperandTypeImpl(const Operand & type,const Extension::OperandTypeInformation * const extensionOperandTypeInfo,const char * tag,bool allowPartial)1484 Result<void> validateOperandTypeImpl(
1485         const Operand& type,
1486         const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
1487         bool allowPartial) {
1488     if (isExtension(type.type)) {
1489         NN_RET_CHECK(extensionOperandTypeInfo != nullptr);
1490         if (extensionOperandTypeInfo->isTensor) {
1491             NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
1492         } else {
1493             NN_TRY(validateScalarDimensions(type, tag));
1494         }
1495         return validateNoQuantParams(type, tag);
1496     }
1497 
1498     NN_RET_CHECK(extensionOperandTypeInfo == nullptr);
1499     NN_TRY(validateOperandType(type.type));
1500 
1501     if (isNonExtensionScalar(type.type)) {
1502         NN_TRY(validateScalarDimensions(type, tag));
1503         if (type.type != OperandType::OEM) {  // Historically, we have allowed OEM types
1504                                               // to use quantization parameters.
1505             NN_TRY(validateNoQuantParams(type, tag));
1506         }
1507     } else {
1508         NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
1509         if (type.type == OperandType::TENSOR_QUANT8_ASYMM) {
1510             NN_TRY(validateQuant8AsymmParams(type, tag));
1511         } else if (type.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1512             NN_TRY(validateQuant8AsymmSignedParams(type, tag));
1513         } else if (type.type == OperandType::TENSOR_QUANT8_SYMM) {
1514             NN_TRY(validateQuant8SymmParams(type, tag));
1515         } else if (type.type == OperandType::TENSOR_QUANT16_ASYMM) {
1516             NN_TRY(validateQuant16AsymmParams(type, tag));
1517         } else if (type.type == OperandType::TENSOR_QUANT16_SYMM) {
1518             NN_TRY(validateQuantSymmParams(type, tag));
1519         } else if (type.type == OperandType::TENSOR_INT32 ||
1520                    type.type == OperandType::TENSOR_OEM_BYTE) {
1521             // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters.
1522             // Historically, we have allowed OEM types to use quantization parameters.
1523         } else {
1524             NN_TRY(validateNoQuantParams(type, tag));
1525         }
1526     }
1527 
1528     return {};
1529 }
1530 
validateOperandListImpl(const std::vector<uint32_t> & list,size_t operandCount,const char * tag)1531 Result<void> validateOperandListImpl(const std::vector<uint32_t>& list, size_t operandCount,
1532                                      const char* tag) {
1533     for (size_t i = 0; i < list.size(); i++) {
1534         NN_RET_CHECK_LT(list[i], operandCount) << tag << " invalid operand index at " << i << " = "
1535                                                << list[i] << ", operandCount " << operandCount;
1536     }
1537     return {};
1538 }
1539 
validateSubgraphReference(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand)1540 Result<void> validateSubgraphReference(const std::vector<Model::Subgraph>& subgraphs,
1541                                        const Operand& modelOperand) {
1542     NN_RET_CHECK_EQ(modelOperand.type, OperandType::SUBGRAPH)
1543             << "Unexpected operand type: " << modelOperand.type;
1544     NN_RET_CHECK_LT(modelOperand.location.offset, subgraphs.size()) << "Invalid subgraph reference";
1545     return {};
1546 }
getSubgraph(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand)1547 const Model::Subgraph& getSubgraph(const std::vector<Model::Subgraph>& subgraphs,
1548                                    const Operand& modelOperand) {
1549     return subgraphs.at(modelOperand.location.offset);
1550 }
getInputCount(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand)1551 uint32_t getInputCount(const std::vector<Model::Subgraph>& subgraphs, const Operand& modelOperand) {
1552     return getSubgraph(subgraphs, modelOperand).inputIndexes.size();
1553 }
getOutputCount(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand)1554 uint32_t getOutputCount(const std::vector<Model::Subgraph>& subgraphs,
1555                         const Operand& modelOperand) {
1556     return getSubgraph(subgraphs, modelOperand).outputIndexes.size();
1557 }
getInputOperand(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand,uint32_t index)1558 const Operand& getInputOperand(const std::vector<Model::Subgraph>& subgraphs,
1559                                const Operand& modelOperand, uint32_t index) {
1560     const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
1561     return subgraph.operands.at(subgraph.inputIndexes.at(index));
1562 }
getOutputOperand(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand,uint32_t index)1563 const Operand& getOutputOperand(const std::vector<Model::Subgraph>& subgraphs,
1564                                 const Operand& modelOperand, uint32_t index) {
1565     const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
1566     return subgraph.operands.at(subgraph.outputIndexes.at(index));
1567 }
1568 
1569 // Checks if two operands have the same types, ranks (if specified), dimensions
1570 // (if specified), scales, zeroPoints, and extraParams.
compatible(const Operand & a,const Operand & b)1571 Result<void> compatible(const Operand& a, const Operand& b) {
1572     NN_RET_CHECK_EQ(a.type, b.type) << a.type << " != " << b.type;
1573     if (!a.dimensions.empty() && !b.dimensions.empty()) {
1574         NN_RET_CHECK_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions";
1575         for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) {
1576             if (a.dimensions[i] != 0 && b.dimensions[i] != 0) {
1577                 NN_RET_CHECK_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions";
1578             }
1579         }
1580     }
1581     NN_RET_CHECK_EQ(a.scale, b.scale);
1582     NN_RET_CHECK_EQ(a.zeroPoint, b.zeroPoint);
1583     NN_RET_CHECK_EQ(a.extraParams, b.extraParams) << a.extraParams << " != " << b.extraParams;
1584     return {};
1585 }
1586 
validateConditionOperand(const Operand & operand)1587 Result<void> validateConditionOperand(const Operand& operand) {
1588     NN_RET_CHECK_EQ(operand.type, OperandType::TENSOR_BOOL8)
1589             << "Unexpected condition operand type: " << operand.type;
1590     NN_RET_CHECK_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
1591     NN_RET_CHECK_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
1592     return {};
1593 }
1594 
validateIfOperation(const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs,const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs)1595 Result<Version> validateIfOperation(const std::vector<uint32_t>& inputs,
1596                                     const std::vector<uint32_t>& outputs,
1597                                     const std::vector<Operand>& operands,
1598                                     const std::vector<Model::Subgraph>& subgraphs) {
1599     namespace op = operation_if;
1600     NN_RET_CHECK_GE(inputs.size(), 3u) << "IF must have at least 3 inputs";
1601     NN_RET_CHECK_GE(outputs.size(), 1u) << "IF must have at least 1 output";
1602     auto validateBranchOperand = [&](const Operand& branchModelOperand) -> Result<void> {
1603         auto result = validateSubgraphReference(subgraphs, branchModelOperand);
1604         if (!result.has_value()) {
1605             return error() << std::move(result).error()
1606                            << " -- Operand is not a valid subgraph reference";
1607         }
1608         const uint32_t branchModelInputCount = getInputCount(subgraphs, branchModelOperand);
1609         const uint32_t branchModelOutputCount = getOutputCount(subgraphs, branchModelOperand);
1610         NN_RET_CHECK_EQ(inputs.size(), op::kFirstInput + branchModelInputCount);
1611         NN_RET_CHECK_EQ(outputs.size(), branchModelOutputCount);
1612         for (uint32_t i = 0; i < branchModelInputCount; ++i) {
1613             const Operand& innerOperand = getInputOperand(subgraphs, branchModelOperand, i);
1614             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1615             NN_TRY(compatible(innerOperand, outerOperand));
1616         }
1617         for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
1618             const Operand& innerOperand = getOutputOperand(subgraphs, branchModelOperand, i);
1619             const Operand& outerOperand = operands[outputs[i]];
1620             NN_TRY(compatible(innerOperand, outerOperand));
1621         }
1622         return {};
1623     };
1624     auto result = validateConditionOperand(operands[inputs[op::kCondBoolOperand]]);
1625     if (!result.has_value()) {
1626         return error() << std::move(result).error() << " for IF condition operand";
1627     }
1628     result = validateBranchOperand(operands[inputs[op::kThenModelOperand]]);
1629     if (!result.has_value()) {
1630         return error() << std::move(result).error() << " for IF then model";
1631     }
1632     result = validateBranchOperand(operands[inputs[op::kElseModelOperand]]);
1633     if (!result.has_value()) {
1634         return error() << std::move(result).error() << " for IF else model";
1635     }
1636     return kVersionFeatureLevel4;
1637 }
1638 
validateControlFlowOperandUnknownSize(const Operand & operand)1639 Result<Version> validateControlFlowOperandUnknownSize(const Operand& operand) {
1640     auto version = kVersionFeatureLevel4;
1641     if (!isExtension(operand.type) && getNonExtensionSize(operand).value() == 0) {
1642         // 1.3 HAL (corresponding to kVersionFeatureLevel4) does not support CF operations with
1643         // operands of unknown size. See http://b/132458982#comment63.
1644         version.runtimeOnlyFeatures = true;
1645     }
1646     return version;
1647 }
1648 
validateWhileOperation(const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs,const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs)1649 Result<Version> validateWhileOperation(const std::vector<uint32_t>& inputs,
1650                                        const std::vector<uint32_t>& outputs,
1651                                        const std::vector<Operand>& operands,
1652                                        const std::vector<Model::Subgraph>& subgraphs) {
1653     // Let the loop have
1654     // - m >= 1 input-output operands,
1655     // - k >= 0 state-only operands, and
1656     // - n >= 0 input-only operands.
1657     // Then
1658     // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
1659     // - the condition model has (m + k + n) inputs and 1 output.
1660     // - the body model has (m + k + n) inputs and (m + k) outputs.
1661     namespace op = operation_while;
1662     NN_RET_CHECK_GE(inputs.size(), 3u) << "WHILE must have at least 3 inputs";
1663     NN_RET_CHECK_GE(outputs.size(), 1u) << "WHILE must have at least 1 output";
1664     auto validateCondOperand = [&](const Operand& condModelOperand) -> Result<Version> {
1665         Version version = kVersionFeatureLevel4;
1666         auto result = validateSubgraphReference(subgraphs, condModelOperand);
1667         if (!result.has_value()) {
1668             return error() << std::move(result).error()
1669                            << " -- Operand is not a valid subgraph reference";
1670         }
1671         const uint32_t condModelInputCount = getInputCount(subgraphs, condModelOperand);
1672         const uint32_t condModelOutputCount = getOutputCount(subgraphs, condModelOperand);
1673         NN_RET_CHECK_EQ(inputs.size(), op::kFirstInput + condModelInputCount);
1674         NN_RET_CHECK_EQ(condModelOutputCount, 1u);
1675         for (uint32_t i = 0; i < condModelInputCount; ++i) {
1676             const Operand& innerOperand = getInputOperand(subgraphs, condModelOperand, i);
1677             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1678             NN_TRY(compatible(innerOperand, outerOperand));
1679             version = combineVersions(version,
1680                                       NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
1681             version = combineVersions(version,
1682                                       NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1683         }
1684         NN_TRY(validateConditionOperand(getOutputOperand(subgraphs, condModelOperand, 0)));
1685         return version;
1686     };
1687     auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> Result<Version> {
1688         Version version = kVersionFeatureLevel4;
1689         auto result = validateSubgraphReference(subgraphs, bodyModelOperand);
1690         if (!result.has_value()) {
1691             return error() << std::move(result).error()
1692                            << " -- Operand is not a valid subgraph reference";
1693         }
1694         const uint32_t bodyModelInputCount = getInputCount(subgraphs, bodyModelOperand);
1695         const uint32_t bodyModelOutputCount = getOutputCount(subgraphs, bodyModelOperand);
1696         NN_RET_CHECK_EQ(inputs.size(), op::kFirstInput + bodyModelInputCount);
1697         NN_RET_CHECK_GE(bodyModelOutputCount, outputs.size());
1698         NN_RET_CHECK_GE(bodyModelInputCount, bodyModelOutputCount);
1699         const uint32_t inputOutputCount = outputs.size();
1700         const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
1701         const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
1702         for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
1703             const Operand& innerOperand = getInputOperand(subgraphs, bodyModelOperand, i);
1704             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1705             NN_TRY(compatible(innerOperand, outerOperand));
1706             version = combineVersions(version,
1707                                       NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
1708             version = combineVersions(version,
1709                                       NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1710         }
1711         for (uint32_t i = 0; i < inputOutputCount; ++i) {
1712             const Operand& innerOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
1713             const Operand& outerOperand = operands[outputs[i]];
1714             NN_TRY(compatible(innerOperand, outerOperand));
1715             version = combineVersions(version,
1716                                       NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1717         }
1718         for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
1719             const Operand& inputOperand = getInputOperand(subgraphs, bodyModelOperand, i);
1720             const Operand& outputOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
1721             NN_TRY(compatible(inputOperand, outputOperand));
1722             version = combineVersions(version,
1723                                       NN_TRY(validateControlFlowOperandUnknownSize(outputOperand)));
1724         }
1725         return version;
1726     };
1727     auto result = validateCondOperand(operands[inputs[op::kCondModelOperand]]);
1728     if (!result.has_value()) {
1729         return error() << std::move(result).error() << " for WHILE condition model";
1730     }
1731     auto version = result.value();
1732     result = validateBodyOperand(operands[inputs[op::kBodyModelOperand]]);
1733     if (!result.has_value()) {
1734         return error() << std::move(result).error() << " for WHILE body model";
1735     }
1736     version = combineVersions(version, result.value());
1737     return version;
1738 }
1739 
validateOperationButNotOperandsImpl(const Operation & operation,const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs)1740 Result<Version> validateOperationButNotOperandsImpl(const Operation& operation,
1741                                                     const std::vector<Operand>& operands,
1742                                                     const std::vector<Model::Subgraph>& subgraphs) {
1743     const auto opType = operation.type;
1744     const auto& inputIndexes = operation.inputs;
1745     const auto& outputIndexes = operation.outputs;
1746 
1747     NN_TRY(validateOperandListImpl(inputIndexes, operands.size(),
1748                                    "ANeuralNetworksModel_addOperation inputs"));
1749     NN_TRY(validateOperandListImpl(outputIndexes, operands.size(),
1750                                    "ANeuralNetworksModel_addOperation outputs"));
1751 
1752     if (isExtension(opType)) {
1753         // There is no other validation we can do for an extension operation.
1754         return kVersionFeatureLevel3;
1755     }
1756 
1757     std::ostringstream oss;
1758     oss << operation.type;
1759     const std::string name = oss.str();
1760     OperationValidationContext context(name.c_str(), inputIndexes, outputIndexes, operands);
1761 
1762     // Validate some operations explicitly.
1763     switch (opType) {
1764         case OperationType::OEM_OPERATION:
1765             return kVersionFeatureLevel1;
1766         case OperationType::IF:
1767             return validateIfOperation(inputIndexes, outputIndexes, operands, subgraphs);
1768         case OperationType::WHILE:
1769             return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs);
1770         default:
1771             break;
1772     }
1773 
1774 #define NN_HANDLE_SWITCH_CASE(operationName) \
1775     case OperationType::operationName:       \
1776         return NN_VALIDATION_FUNCTION_NAME(operationName)(&context);
1777 
1778     // Validate the remaining operations through operation-specific functions defined in
1779     // common/operations/.
1780     // TODO(b/213938830): operation validation dispatch is duplicated and does not handle extension
1781     // types.
1782     switch (opType) { NN_FOR_EACH_OPERATION(NN_HANDLE_SWITCH_CASE) }
1783 
1784 #undef NN_HANDLE_SWITCH_CASE
1785 
1786     NN_RET_CHECK_FAIL() << "Invalid OperationType " << opType;
1787 }
1788 
validateOperationIncludingOperandVersions(const Operation & operation,const std::vector<Operand> & operands,const std::vector<Version> & operandVersions,const std::vector<Model::Subgraph> & subgraphs)1789 Result<Version> validateOperationIncludingOperandVersions(
1790         const Operation& operation, const std::vector<Operand>& operands,
1791         const std::vector<Version>& operandVersions,
1792         const std::vector<Model::Subgraph>& subgraphs) {
1793     auto version = NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs));
1794     for (uint32_t index : operation.inputs) {
1795         version = combineVersions(version, operandVersions[index]);
1796     }
1797     for (uint32_t index : operation.outputs) {
1798         version = combineVersions(version, operandVersions[index]);
1799     }
1800     return version;
1801 }
1802 
1803 }  // anonymous namespace
1804 
// Below this point are all the functions that are declared in Validation.h. The convention of this
// file is to keep the bodies of the functions declared in Validation.h minimal: most functions
// below simply forward to one of the functions defined above in the anonymous namespace. If a
// function below would clash with one of the functions above, the anonymous-namespace function's
// name is given the suffix "Impl".
1810 
combineVersions(Version minVersionNeeded1,Version minVersionNeeded2)1811 Version combineVersions(Version minVersionNeeded1, Version minVersionNeeded2) {
1812     return Version{
1813             .level = std::max<Version::Level>(minVersionNeeded1.level, minVersionNeeded2.level),
1814             .runtimeOnlyFeatures =
1815                     minVersionNeeded1.runtimeOnlyFeatures || minVersionNeeded2.runtimeOnlyFeatures,
1816     };
1817 }
1818 
isCompliantVersion(Version minVersionNeeded,Version maxVersionSupported)1819 bool isCompliantVersion(Version minVersionNeeded, Version maxVersionSupported) {
1820     if (minVersionNeeded.runtimeOnlyFeatures && !maxVersionSupported.runtimeOnlyFeatures) {
1821         return false;
1822     }
1823     return minVersionNeeded.level <= maxVersionSupported.level;
1824 }
1825 
validate(const DeviceStatus & deviceStatus)1826 Result<Version> validate(const DeviceStatus& deviceStatus) {
1827     return validateDeviceStatus(deviceStatus);
1828 }
1829 
validate(const ExecutionPreference & executionPreference)1830 Result<Version> validate(const ExecutionPreference& executionPreference) {
1831     return validateExecutionPreference(executionPreference);
1832 }
1833 
validate(const DeviceType & deviceType)1834 Result<Version> validate(const DeviceType& deviceType) {
1835     return validateDeviceType(deviceType);
1836 }
1837 
validate(const MeasureTiming & measureTiming)1838 Result<Version> validate(const MeasureTiming& measureTiming) {
1839     return validateMeasureTiming(measureTiming);
1840 }
1841 
validate(const OperandType & operandType)1842 Result<Version> validate(const OperandType& operandType) {
1843     return validateOperandType(operandType);
1844 }
1845 
validate(const Priority & priority)1846 Result<Version> validate(const Priority& priority) {
1847     return validatePriority(priority);
1848 }
1849 
validate(const ErrorStatus & errorStatus)1850 Result<Version> validate(const ErrorStatus& errorStatus) {
1851     return validateErrorStatus(errorStatus);
1852 }
1853 
validate(const FusedActivationFunc & activation)1854 Result<Version> validate(const FusedActivationFunc& activation) {
1855     return validateFusedActivationFunc(activation);
1856 }
1857 
validate(const OutputShape & outputShape)1858 Result<Version> validate(const OutputShape& outputShape) {
1859     return validateOutputShape(outputShape);
1860 }
1861 
validate(const Timing & timing)1862 Result<Version> validate(const Timing& timing) {
1863     return validateTiming(timing);
1864 }
1865 
validate(const Capabilities & capabilities)1866 Result<Version> validate(const Capabilities& capabilities) {
1867     return validateCapabilities(capabilities);
1868 }
1869 
validate(const Extension & extension)1870 Result<Version> validate(const Extension& extension) {
1871     return validateExtension(extension);
1872 }
1873 
validate(const SharedHandle & handle)1874 Result<Version> validate(const SharedHandle& handle) {
1875     return validateSharedHandle(handle);
1876 }
1877 
validate(const SharedMemory & memory)1878 Result<Version> validate(const SharedMemory& memory) {
1879     return validateSharedMemory(memory);
1880 }
1881 
validate(const Model & model)1882 Result<Version> validate(const Model& model) {
1883     return validateModel(model);
1884 }
1885 
validate(const BufferDesc & bufferDesc)1886 Result<Version> validate(const BufferDesc& bufferDesc) {
1887     return validateBufferDesc(bufferDesc);
1888 }
1889 
validate(const BufferRole & bufferRole)1890 Result<Version> validate(const BufferRole& bufferRole) {
1891     return validateBufferRole(bufferRole);
1892 }
1893 
validate(const Request & request)1894 Result<Version> validate(const Request& request) {
1895     return validateRequest(request);
1896 }
1897 
validate(const OptionalTimePoint & optionalTimePoint)1898 Result<Version> validate(const OptionalTimePoint& optionalTimePoint) {
1899     return validateOptionalTimePoint(optionalTimePoint);
1900 }
1901 
validate(const OptionalDuration & optionalTimeoutDuration)1902 Result<Version> validate(const OptionalDuration& optionalTimeoutDuration) {
1903     return validateOptionalTimeoutDuration(optionalTimeoutDuration);
1904 }
1905 
validate(const CacheToken & cacheToken)1906 Result<Version> validate(const CacheToken& cacheToken) {
1907     return validateCacheToken(cacheToken);
1908 }
1909 
validate(const SyncFence & syncFence)1910 Result<Version> validate(const SyncFence& syncFence) {
1911     return validateSyncFence(syncFence);
1912 }
1913 
validate(const TokenValuePair & tokenValuePair)1914 Result<Version> validate(const TokenValuePair& tokenValuePair) {
1915     return validateTokenValuePair(tokenValuePair);
1916 }
1917 
validate(const std::vector<OutputShape> & outputShapes)1918 Result<Version> validate(const std::vector<OutputShape>& outputShapes) {
1919     return validateVector(outputShapes, validateOutputShape);
1920 }
1921 
validate(const std::vector<Extension> & extensions)1922 Result<Version> validate(const std::vector<Extension>& extensions) {
1923     return validateExtensions(extensions);
1924 }
1925 
validate(const std::vector<SharedHandle> & handles)1926 Result<Version> validate(const std::vector<SharedHandle>& handles) {
1927     return validateVector(handles, validateSharedHandle);
1928 }
1929 
validate(const std::vector<BufferRole> & bufferRoles)1930 Result<Version> validate(const std::vector<BufferRole>& bufferRoles) {
1931     return validateVector(bufferRoles, validateBufferRole);
1932 }
1933 
validate(const std::vector<SyncFence> & syncFences)1934 Result<Version> validate(const std::vector<SyncFence>& syncFences) {
1935     return validateVector(syncFences, validateSyncFence);
1936 }
1937 
validate(const std::vector<TokenValuePair> & metaData)1938 Result<Version> validate(const std::vector<TokenValuePair>& metaData) {
1939     std::set<int32_t> tokenSet;
1940     for (const auto& p : metaData) {
1941         if (!tokenSet.insert(p.token).second) {
1942             NN_RET_CHECK_FAIL() << "Token added more than once " << p.token;
1943         }
1944     }
1945     return validateVector(metaData, validateTokenValuePair);
1946 }
1947 
validate(const std::vector<ExtensionNameAndPrefix> & extensionNamesAndPrefixes)1948 Result<Version> validate(const std::vector<ExtensionNameAndPrefix>& extensionNamesAndPrefixes) {
1949     return validateExtensionNamesAndPrefixes(extensionNamesAndPrefixes);
1950 }
1951 
validateRequestForModel(const Request & request,const Model & model,bool allowUnspecifiedOutput)1952 Result<Version> validateRequestForModel(const Request& request, const Model& model,
1953                                         bool allowUnspecifiedOutput) {
1954     return validateRequestForModelImpl(request, model, allowUnspecifiedOutput);
1955 }
1956 
validateMemoryDesc(const BufferDesc & desc,const std::vector<SharedPreparedModel> & preparedModels,const std::vector<BufferRole> & inputRoles,const std::vector<BufferRole> & outputRoles,const std::function<const Model * (const SharedPreparedModel &)> & getModel,std::set<PreparedModelRole> * preparedModelRoles,Operand * combinedOperand)1957 Result<Version> validateMemoryDesc(
1958         const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
1959         const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
1960         const std::function<const Model*(const SharedPreparedModel&)>& getModel,
1961         std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
1962     return validateMemoryDescImpl(desc, preparedModels, inputRoles, outputRoles, getModel,
1963                                   preparedModelRoles, combinedOperand);
1964 }
1965 
validateOperandSymmPerChannelQuantParams(const Operand & operand,const Operand::SymmPerChannelQuantParams & channelQuant,const char * tag)1966 Result<void> validateOperandSymmPerChannelQuantParams(
1967         const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
1968         const char* tag) {
1969     return validateOperandSymmPerChannelQuantParamsImpl(operand, channelQuant, tag);
1970 }
1971 
validateOperandType(const Operand & type,const Extension::OperandTypeInformation * extensionOperandTypeInfo,const char * tag,bool allowPartial)1972 Result<void> validateOperandType(const Operand& type,
1973                                  const Extension::OperandTypeInformation* extensionOperandTypeInfo,
1974                                  const char* tag, bool allowPartial) {
1975     return validateOperandTypeImpl(type, extensionOperandTypeInfo, tag, allowPartial);
1976 }
1977 
validateOperandList(const std::vector<uint32_t> & list,size_t operandCount,const char * tag)1978 Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount,
1979                                  const char* tag) {
1980     return validateOperandListImpl(list, operandCount, tag);
1981 }
1982 
validateOperationButNotOperands(const Operation & operation,const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs)1983 Result<void> validateOperationButNotOperands(const Operation& operation,
1984                                              const std::vector<Operand>& operands,
1985                                              const std::vector<Model::Subgraph>& subgraphs) {
1986     NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs));
1987     return {};
1988 }
1989 
1990 struct SubgraphVersionCache {
1991     std::vector<std::optional<Version>> cache;
1992 };
1993 
createSubgraphVersionCache(size_t referencedSubgraphCount)1994 std::unique_ptr<SubgraphVersionCache, void (*)(SubgraphVersionCache*)> createSubgraphVersionCache(
1995         size_t referencedSubgraphCount) {
1996     auto subgraphVersionCache = std::make_unique<SubgraphVersionCache>();
1997     subgraphVersionCache->cache.resize(referencedSubgraphCount);
1998     constexpr auto deleter = [](SubgraphVersionCache* ptr) { delete ptr; };
1999     return {subgraphVersionCache.release(), deleter};
2000 }
2001 
validateOperationAndAnythingItDependsOn(const Operation & operation,const std::vector<Operand> & operands,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,SubgraphVersionCache * subgraphVersionCache)2002 Result<Version> validateOperationAndAnythingItDependsOn(
2003         const Operation& operation, const std::vector<Operand>& operands, size_t operandValuesSize,
2004         const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs,
2005         SubgraphVersionCache* subgraphVersionCache) {
2006     CHECK(subgraphVersionCache != nullptr);
2007     std::vector<Version> operandVersions(operands.size(), kVersionFeatureLevel1);
2008     for (uint32_t index : operation.inputs) {
2009         NN_RET_CHECK_LT(index, operands.size());
2010         const Operand& operand = operands[index];
2011         operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn(
2012                 operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache));
2013     }
2014     for (uint32_t index : operation.outputs) {
2015         NN_RET_CHECK_LT(index, operands.size());
2016         const Operand& operand = operands[index];
2017         operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn(
2018                 operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache));
2019     }
2020     return validateOperationIncludingOperandVersions(operation, operands, operandVersions,
2021                                                      subgraphs);
2022 }
2023 
validateOperandAndAnythingItDependsOn(const Operand & operand,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,SubgraphVersionCache * subgraphVersionCache)2024 Result<Version> validateOperandAndAnythingItDependsOn(const Operand& operand,
2025                                                       size_t operandValuesSize,
2026                                                       const std::vector<size_t>& poolSizes,
2027                                                       const std::vector<Model::Subgraph>& subgraphs,
2028                                                       SubgraphVersionCache* subgraphVersionCache) {
2029     CHECK(subgraphVersionCache != nullptr);
2030     return validateOperand(operand, operandValuesSize, poolSizes, subgraphs,
2031                            &subgraphVersionCache->cache);
2032 }
2033 
2034 }  // namespace android::nn
2035