1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
#include <android/hardware/neuralnetworks/1.1/types.h>
#include "1.0/Callbacks.h"
#include "1.0/Utils.h"
#include "GeneratedTestHarness.h"
#include "VtsHalNeuralnetworks.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <optional>
#include <type_traits>
#include <utility>
28 
29 namespace android::hardware::neuralnetworks::V1_1::vts::functional {
30 
31 using V1_0::DataLocation;
32 using V1_0::ErrorStatus;
33 using V1_0::IPreparedModel;
34 using V1_0::Operand;
35 using V1_0::OperandLifeTime;
36 using V1_0::OperandType;
37 using V1_0::implementation::PreparedModelCallback;
38 
39 using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*)>;
40 
41 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
42 
validateGetSupportedOperations(const sp<IDevice> & device,const std::string & message,const Model & model)43 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
44                                            const Model& model) {
45     SCOPED_TRACE(message + " [getSupportedOperations_1_1]");
46 
47     Return<void> ret = device->getSupportedOperations_1_1(
48             model, [&](ErrorStatus status, const hidl_vec<bool>&) {
49                 EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, status);
50             });
51     EXPECT_TRUE(ret.isOk());
52 }
53 
validatePrepareModel(const sp<IDevice> & device,const std::string & message,const Model & model,ExecutionPreference preference)54 static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
55                                  const Model& model, ExecutionPreference preference) {
56     SCOPED_TRACE(message + " [prepareModel_1_1]");
57 
58     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
59     Return<ErrorStatus> prepareLaunchStatus =
60             device->prepareModel_1_1(model, preference, preparedModelCallback);
61     ASSERT_TRUE(prepareLaunchStatus.isOk());
62     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));
63 
64     preparedModelCallback->wait();
65     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
66     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
67     sp<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
68     ASSERT_EQ(nullptr, preparedModel.get());
69 }
70 
validExecutionPreference(ExecutionPreference preference)71 static bool validExecutionPreference(ExecutionPreference preference) {
72     return preference == ExecutionPreference::LOW_POWER ||
73            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
74            preference == ExecutionPreference::SUSTAINED_SPEED;
75 }
76 
77 // Primary validation function. This function will take a valid model, apply a
78 // mutation to invalidate either the model or the execution preference, then
79 // pass these to supportedOperations and/or prepareModel if that method is
80 // called with an invalid argument.
validate(const sp<IDevice> & device,const std::string & message,const Model & originalModel,const PrepareModelMutation & mutate)81 static void validate(const sp<IDevice>& device, const std::string& message,
82                      const Model& originalModel, const PrepareModelMutation& mutate) {
83     Model model = originalModel;
84     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
85     mutate(&model, &preference);
86 
87     if (validExecutionPreference(preference)) {
88         validateGetSupportedOperations(device, message, model);
89     }
90 
91     validatePrepareModel(device, message, model, preference);
92 }
93 
addOperand(Model * model)94 static uint32_t addOperand(Model* model) {
95     return hidl_vec_push_back(&model->operands,
96                               {
97                                       .type = OperandType::INT32,
98                                       .dimensions = {},
99                                       .numberOfConsumers = 0,
100                                       .scale = 0.0f,
101                                       .zeroPoint = 0,
102                                       .lifetime = OperandLifeTime::MODEL_INPUT,
103                                       .location = {.poolIndex = 0, .offset = 0, .length = 0},
104                               });
105 }
106 
addOperand(Model * model,OperandLifeTime lifetime)107 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
108     uint32_t index = addOperand(model);
109     model->operands[index].numberOfConsumers = 1;
110     model->operands[index].lifetime = lifetime;
111     return index;
112 }
113 
114 // If we introduce a CONSTANT_COPY for an operand of size operandSize,
115 // how much will this increase the size of the model?  This assumes
116 // that we can (re)use all of model.operandValues for the operand
117 // value.
constantCopyExtraSize(const Model & model,size_t operandSize)118 static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
119     const size_t operandValuesSize = model.operandValues.size();
120     return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
121 }
122 
123 // Highly specialized utility routine for converting an operand to
124 // CONSTANT_COPY lifetime.
125 //
126 // Expects that:
127 // - operand has a known size
128 // - operand->lifetime has already been set to CONSTANT_COPY
129 // - operand->location has been zeroed out
130 //
131 // Does the following:
132 // - initializes operand->location to point to the beginning of model->operandValues
133 // - resizes model->operandValues (if necessary) to be large enough for the operand
134 //   value, padding it with zeroes on the end
135 //
136 // Potential problem:
137 // By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
138 // operand with unspecified (but deterministic) data. This means that the model may be invalidated
139 // in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
140 // graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
141 // value). For now, this should be fine because it just means we're not testing what we think we're
142 // testing in certain cases; but we can handwave this and assume we're probabilistically likely to
143 // exercise the validation code over the span of the entire test set and operand space.
144 //
145 // Aborts if the specified operand type is an extension type or OEM type.
becomeConstantCopy(Model * model,Operand * operand)146 static void becomeConstantCopy(Model* model, Operand* operand) {
147     // sizeOfData will abort if the specified type is an extension type or OEM type.
148     const size_t sizeOfOperand = sizeOfData(*operand);
149     EXPECT_NE(sizeOfOperand, size_t(0));
150     operand->location.poolIndex = 0;
151     operand->location.offset = 0;
152     operand->location.length = sizeOfOperand;
153     if (model->operandValues.size() < sizeOfOperand) {
154         model->operandValues.resize(sizeOfOperand);
155     }
156 }
157 
158 // The sizeForBinder() functions estimate the size of the
159 // representation of a value when sent to binder.  It's probably a bit
160 // of an under-estimate, because we don't know the size of the
161 // metadata in the binder format (e.g., representation of the size of
162 // a vector); but at least it adds up "big" things like vector
163 // contents.  However, it doesn't treat inter-field or end-of-struct
164 // padding in a methodical way -- there's no attempt to be consistent
165 // in whether or not padding in the native (C++) representation
166 // contributes to the estimated size for the binder representation;
167 // and there's no attempt to understand what padding (if any) is
168 // needed in the binder representation.
169 //
170 // This assumes that non-metadata uses a fixed length encoding (e.g.,
171 // a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
172 // using an encoding whose length is related to the magnitude of the
173 // encoded value).
174 
// Fallback estimate for plain values: a trivially copyable value is assumed to be encoded with
// exactly its native size. Non-trivial types must supply their own overload/specialization.
template <typename Type>
static size_t sizeForBinder(const Type& val) {
    using ValueType = std::remove_reference_t<Type>;
    static_assert(std::is_trivially_copyable_v<ValueType>, "expected a trivially copyable type");
    const size_t nativeSize = sizeof(val);
    return nativeSize;
}
181 
182 template <typename Type>
sizeForBinder(const hidl_vec<Type> & vec)183 static size_t sizeForBinder(const hidl_vec<Type>& vec) {
184     return std::accumulate(vec.begin(), vec.end(), 0,
185                            [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
186 }
187 
188 template <>
sizeForBinder(const Operand & operand)189 size_t sizeForBinder(const Operand& operand) {
190     size_t size = 0;
191 
192     size += sizeForBinder(operand.type);
193     size += sizeForBinder(operand.dimensions);
194     size += sizeForBinder(operand.numberOfConsumers);
195     size += sizeForBinder(operand.scale);
196     size += sizeForBinder(operand.zeroPoint);
197     size += sizeForBinder(operand.lifetime);
198     size += sizeForBinder(operand.location);
199 
200     return size;
201 }
202 
203 template <>
sizeForBinder(const Operation & operation)204 size_t sizeForBinder(const Operation& operation) {
205     size_t size = 0;
206 
207     size += sizeForBinder(operation.type);
208     size += sizeForBinder(operation.inputs);
209     size += sizeForBinder(operation.outputs);
210 
211     return size;
212 }
213 
214 template <>
sizeForBinder(const hidl_string & name)215 size_t sizeForBinder(const hidl_string& name) {
216     return name.size();
217 }
218 
219 template <>
sizeForBinder(const hidl_memory & memory)220 size_t sizeForBinder(const hidl_memory& memory) {
221     // This is just a guess.
222 
223     size_t size = 0;
224 
225     if (const native_handle_t* handle = memory.handle()) {
226         size += sizeof(*handle);
227         size += sizeof(handle->data[0] * (handle->numFds + handle->numInts));
228     }
229     size += sizeForBinder(memory.name());
230 
231     return size;
232 }
233 
234 template <>
sizeForBinder(const Model & model)235 size_t sizeForBinder(const Model& model) {
236     size_t size = 0;
237 
238     size += sizeForBinder(model.operands);
239     size += sizeForBinder(model.operations);
240     size += sizeForBinder(model.inputIndexes);
241     size += sizeForBinder(model.outputIndexes);
242     size += sizeForBinder(model.operandValues);
243     size += sizeForBinder(model.pools);
244     size += sizeForBinder(model.relaxComputationFloat32toFloat16);
245 
246     return size;
247 }
248 
249 // https://developer.android.com/reference/android/os/TransactionTooLargeException.html
250 //
251 //     "The Binder transaction buffer has a limited fixed size,
252 //     currently 1Mb, which is shared by all transactions in progress
253 //     for the process."
254 //
255 // Will our representation fit under this limit?  There are three complications:
256 // - Our representation size is just approximate (see sizeForBinder()).
257 // - This object may not be the only occupant of the Binder transaction buffer
258 //   (although our VTS test suite should not be putting multiple objects in the
259 //   buffer at once).
260 // - IBinder.MAX_IPC_SIZE recommends limiting a transaction to 64 * 1024 bytes.
261 // So we'll be very conservative: We want the representation size to be no
262 // larger than half the recommended limit.
263 //
264 // If our representation grows large enough that it still fits within
265 // the transaction buffer but combined with other transactions may
266 // exceed the buffer size, then we may see intermittent HAL transport
267 // errors.
// Returns true when `representationSize` is larger than half of the recommended 64 KiB binder
// transaction size (i.e. more than 32 KiB).
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // IBinder.MAX_IPC_SIZE is only exposed to Java, so its value is hard-coded here.
    constexpr size_t kMaxIPCSize = 64 * 1024;
    constexpr size_t kAllowedSize = kMaxIPCSize / 2;
    return !(representationSize <= kAllowedSize);
}
274 
275 ///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
276 
mutateExecutionOrderTest(const sp<IDevice> & device,const V1_1::Model & model)277 static void mutateExecutionOrderTest(const sp<IDevice>& device, const V1_1::Model& model) {
278     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
279         const Operation& operationObj = model.operations[operation];
280         for (uint32_t input : operationObj.inputs) {
281             if (model.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
282                 model.operands[input].lifetime == OperandLifeTime::MODEL_OUTPUT) {
283                 // This operation reads an operand written by some
284                 // other operation.  Move this operation to the
285                 // beginning of the sequence, ensuring that it reads
286                 // the operand before that operand is written, thereby
287                 // violating execution order rules.
288                 const std::string message = "mutateExecutionOrderTest: operation " +
289                                             std::to_string(operation) + " is a reader";
290                 validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
291                     auto& operations = model->operations;
292                     std::rotate(operations.begin(), operations.begin() + operation,
293                                 operations.begin() + operation + 1);
294                 });
295                 break;  // only need to do this once per operation
296             }
297         }
298         for (uint32_t output : operationObj.outputs) {
299             if (model.operands[output].numberOfConsumers > 0) {
300                 // This operation writes an operand read by some other
301                 // operation.  Move this operation to the end of the
302                 // sequence, ensuring that it writes the operand after
303                 // that operand is read, thereby violating execution
304                 // order rules.
305                 const std::string message = "mutateExecutionOrderTest: operation " +
306                                             std::to_string(operation) + " is a writer";
307                 validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
308                     auto& operations = model->operations;
309                     std::rotate(operations.begin() + operation, operations.begin() + operation + 1,
310                                 operations.end());
311                 });
312                 break;  // only need to do this once per operation
313             }
314         }
315     }
316 }
317 
318 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
319 
320 static const int32_t invalidOperandTypes[] = {
321         static_cast<int32_t>(OperandType::FLOAT32) - 1,              // lower bound fundamental
322         static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) + 1,  // upper bound fundamental
323         static_cast<int32_t>(OperandType::OEM) - 1,                  // lower bound OEM
324         static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) + 1,      // upper bound OEM
325 };
326 
mutateOperandTypeTest(const sp<IDevice> & device,const Model & model)327 static void mutateOperandTypeTest(const sp<IDevice>& device, const Model& model) {
328     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
329         for (int32_t invalidOperandType : invalidOperandTypes) {
330             const std::string message = "mutateOperandTypeTest: operand " +
331                                         std::to_string(operand) + " set to value " +
332                                         std::to_string(invalidOperandType);
333             validate(device, message, model,
334                      [operand, invalidOperandType](Model* model, ExecutionPreference*) {
335                          model->operands[operand].type =
336                                  static_cast<OperandType>(invalidOperandType);
337                      });
338         }
339     }
340 }
341 
342 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
343 
getInvalidRank(OperandType type)344 static uint32_t getInvalidRank(OperandType type) {
345     switch (type) {
346         case OperandType::FLOAT32:
347         case OperandType::INT32:
348         case OperandType::UINT32:
349             return 1;
350         case OperandType::TENSOR_FLOAT32:
351         case OperandType::TENSOR_INT32:
352         case OperandType::TENSOR_QUANT8_ASYMM:
353             return 0;
354         default:
355             return 0;
356     }
357 }
358 
mutateOperandRankTest(const sp<IDevice> & device,const Model & model)359 static void mutateOperandRankTest(const sp<IDevice>& device, const Model& model) {
360     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
361         const uint32_t invalidRank = getInvalidRank(model.operands[operand].type);
362         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
363                                     " has rank of " + std::to_string(invalidRank);
364         validate(device, message, model,
365                  [operand, invalidRank](Model* model, ExecutionPreference*) {
366                      model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
367                  });
368     }
369 }
370 
371 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
372 
getInvalidScale(OperandType type)373 static float getInvalidScale(OperandType type) {
374     switch (type) {
375         case OperandType::FLOAT32:
376         case OperandType::INT32:
377         case OperandType::UINT32:
378         case OperandType::TENSOR_FLOAT32:
379             return 1.0f;
380         case OperandType::TENSOR_INT32:
381             return -1.0f;
382         case OperandType::TENSOR_QUANT8_ASYMM:
383             return 0.0f;
384         default:
385             return 0.0f;
386     }
387 }
388 
mutateOperandScaleTest(const sp<IDevice> & device,const Model & model)389 static void mutateOperandScaleTest(const sp<IDevice>& device, const Model& model) {
390     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
391         const float invalidScale = getInvalidScale(model.operands[operand].type);
392         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
393                                     " has scale of " + std::to_string(invalidScale);
394         validate(device, message, model,
395                  [operand, invalidScale](Model* model, ExecutionPreference*) {
396                      model->operands[operand].scale = invalidScale;
397                  });
398     }
399 }
400 
401 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
402 
getInvalidZeroPoints(OperandType type)403 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
404     switch (type) {
405         case OperandType::FLOAT32:
406         case OperandType::INT32:
407         case OperandType::UINT32:
408         case OperandType::TENSOR_FLOAT32:
409         case OperandType::TENSOR_INT32:
410             return {1};
411         case OperandType::TENSOR_QUANT8_ASYMM:
412             return {-1, 256};
413         default:
414             return {};
415     }
416 }
417 
mutateOperandZeroPointTest(const sp<IDevice> & device,const Model & model)418 static void mutateOperandZeroPointTest(const sp<IDevice>& device, const Model& model) {
419     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
420         const std::vector<int32_t> invalidZeroPoints =
421                 getInvalidZeroPoints(model.operands[operand].type);
422         for (int32_t invalidZeroPoint : invalidZeroPoints) {
423             const std::string message = "mutateOperandZeroPointTest: operand " +
424                                         std::to_string(operand) + " has zero point of " +
425                                         std::to_string(invalidZeroPoint);
426             validate(device, message, model,
427                      [operand, invalidZeroPoint](Model* model, ExecutionPreference*) {
428                          model->operands[operand].zeroPoint = invalidZeroPoint;
429                      });
430         }
431     }
432 }
433 
434 ///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
435 
getInvalidLifeTimes(const Model & model,size_t modelSize,const Operand & operand)436 static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
437                                                         const Operand& operand) {
438     // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
439     // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
440 
441     // Ways to get an invalid lifetime:
442     // - change whether a lifetime means an operand should have a writer
443     std::vector<OperandLifeTime> ret;
444     switch (operand.lifetime) {
445         case OperandLifeTime::MODEL_OUTPUT:
446         case OperandLifeTime::TEMPORARY_VARIABLE:
447             ret = {
448                     OperandLifeTime::MODEL_INPUT,
449                     OperandLifeTime::CONSTANT_COPY,
450             };
451             break;
452         case OperandLifeTime::CONSTANT_COPY:
453         case OperandLifeTime::CONSTANT_REFERENCE:
454         case OperandLifeTime::MODEL_INPUT:
455             ret = {
456                     OperandLifeTime::TEMPORARY_VARIABLE,
457                     OperandLifeTime::MODEL_OUTPUT,
458             };
459             break;
460         case OperandLifeTime::NO_VALUE:
461             // Not enough information to know whether
462             // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
463             // is this operand written (then CONSTANT_COPY would be
464             // invalid) or not (then TEMPORARY_VARIABLE would be
465             // invalid)?
466             break;
467         default:
468             ADD_FAILURE();
469             break;
470     }
471 
472     const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
473     if (!operandSize ||
474         exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
475         // Unknown size or too-large size
476         ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
477     }
478 
479     return ret;
480 }
481 
mutateOperandLifeTimeTest(const sp<IDevice> & device,const V1_1::Model & model)482 static void mutateOperandLifeTimeTest(const sp<IDevice>& device, const V1_1::Model& model) {
483     const size_t modelSize = sizeForBinder(model);
484     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
485         const std::vector<OperandLifeTime> invalidLifeTimes =
486                 getInvalidLifeTimes(model, modelSize, model.operands[operand]);
487         for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
488             const std::string message = "mutateOperandLifetimeTest: operand " +
489                                         std::to_string(operand) + " has lifetime " +
490                                         toString(invalidLifeTime) + " instead of lifetime " +
491                                         toString(model.operands[operand].lifetime);
492             validate(device, message, model,
493                      [operand, invalidLifeTime](Model* model, ExecutionPreference*) {
494                          static const DataLocation kZeroDataLocation = {};
495                          Operand& operandObj = model->operands[operand];
496                          switch (operandObj.lifetime) {
497                              case OperandLifeTime::MODEL_INPUT: {
498                                  hidl_vec_remove(&model->inputIndexes, uint32_t(operand));
499                                  break;
500                              }
501                              case OperandLifeTime::MODEL_OUTPUT: {
502                                  hidl_vec_remove(&model->outputIndexes, uint32_t(operand));
503                                  break;
504                              }
505                              default:
506                                  break;
507                          }
508                          operandObj.lifetime = invalidLifeTime;
509                          operandObj.location = kZeroDataLocation;
510                          switch (invalidLifeTime) {
511                              case OperandLifeTime::CONSTANT_COPY: {
512                                  becomeConstantCopy(model, &operandObj);
513                                  break;
514                              }
515                              case OperandLifeTime::MODEL_INPUT:
516                                  hidl_vec_push_back(&model->inputIndexes, uint32_t(operand));
517                                  break;
518                              case OperandLifeTime::MODEL_OUTPUT:
519                                  hidl_vec_push_back(&model->outputIndexes, uint32_t(operand));
520                                  break;
521                              default:
522                                  break;
523                          }
524                      });
525         }
526     }
527 }
528 
529 ///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
530 
getInputOutputLifeTime(const Model & model,size_t modelSize,const Operand & operand)531 static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
532                                                              const Operand& operand) {
533     // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
534     // - change whether a lifetime means an operand is a model input, a model output, or neither
535     // - preserve whether or not a lifetime means an operand should have a writer
536     switch (operand.lifetime) {
537         case OperandLifeTime::CONSTANT_COPY:
538         case OperandLifeTime::CONSTANT_REFERENCE:
539             return OperandLifeTime::MODEL_INPUT;
540         case OperandLifeTime::MODEL_INPUT: {
541             const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
542             if (!operandSize ||
543                 exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
544                 // Unknown size or too-large size
545                 break;
546             }
547             return OperandLifeTime::CONSTANT_COPY;
548         }
549         case OperandLifeTime::MODEL_OUTPUT:
550             return OperandLifeTime::TEMPORARY_VARIABLE;
551         case OperandLifeTime::TEMPORARY_VARIABLE:
552             return OperandLifeTime::MODEL_OUTPUT;
553         case OperandLifeTime::NO_VALUE:
554             // Not enough information to know whether
555             // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
556             // appropriate choice -- is this operand written (then
557             // TEMPORARY_VARIABLE would be appropriate) or not (then
558             // CONSTANT_COPY would be appropriate)?
559             break;
560         default:
561             ADD_FAILURE();
562             break;
563     }
564 
565     return std::nullopt;
566 }
567 
mutateOperandInputOutputTest(const sp<IDevice> & device,const V1_1::Model & model)568 static void mutateOperandInputOutputTest(const sp<IDevice>& device, const V1_1::Model& model) {
569     const size_t modelSize = sizeForBinder(model);
570     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
571         const std::optional<OperandLifeTime> changedLifeTime =
572                 getInputOutputLifeTime(model, modelSize, model.operands[operand]);
573         if (changedLifeTime) {
574             const std::string message = "mutateOperandInputOutputTest: operand " +
575                                         std::to_string(operand) + " has lifetime " +
576                                         toString(*changedLifeTime) + " instead of lifetime " +
577                                         toString(model.operands[operand].lifetime);
578             validate(device, message, model,
579                      [operand, changedLifeTime](Model* model, ExecutionPreference*) {
580                          static const DataLocation kZeroDataLocation = {};
581                          Operand& operandObj = model->operands[operand];
582                          operandObj.lifetime = *changedLifeTime;
583                          operandObj.location = kZeroDataLocation;
584                          if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
585                              becomeConstantCopy(model, &operandObj);
586                          }
587                      });
588         }
589     }
590 }
591 
592 ///////////////////////// VALIDATE OPERAND NUMBER OF CONSUMERS //////////////////////////////////
593 
// Returns consumer counts that differ from `numberOfConsumers` by one in either direction. A count
// of zero cannot be decremented, so only one too many (1) is returned in that case.
static std::vector<uint32_t> getInvalidNumberOfConsumers(uint32_t numberOfConsumers) {
    if (numberOfConsumers != 0) {
        const uint32_t tooFew = numberOfConsumers - 1;
        const uint32_t tooMany = numberOfConsumers + 1;
        return {tooFew, tooMany};
    }
    return {1};
}
601 
mutateOperandNumberOfConsumersTest(const sp<IDevice> & device,const V1_1::Model & model)602 static void mutateOperandNumberOfConsumersTest(const sp<IDevice>& device,
603                                                const V1_1::Model& model) {
604     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
605         const std::vector<uint32_t> invalidNumberOfConsumersVec =
606                 getInvalidNumberOfConsumers(model.operands[operand].numberOfConsumers);
607         for (uint32_t invalidNumberOfConsumers : invalidNumberOfConsumersVec) {
608             const std::string message =
609                     "mutateOperandNumberOfConsumersTest: operand " + std::to_string(operand) +
610                     " numberOfConsumers = " + std::to_string(invalidNumberOfConsumers);
611             validate(device, message, model,
612                      [operand, invalidNumberOfConsumers](Model* model, ExecutionPreference*) {
613                          model->operands[operand].numberOfConsumers = invalidNumberOfConsumers;
614                      });
615         }
616     }
617 }
618 
619 ///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
620 
mutateOperandAddWriterTest(const sp<IDevice> & device,const V1_1::Model & model)621 static void mutateOperandAddWriterTest(const sp<IDevice>& device, const V1_1::Model& model) {
622     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
623         for (size_t badOutputNum = 0; badOutputNum < model.operations[operation].outputs.size();
624              ++badOutputNum) {
625             const uint32_t outputOperandIndex = model.operations[operation].outputs[badOutputNum];
626             const std::string message = "mutateOperandAddWriterTest: operation " +
627                                         std::to_string(operation) + " writes to " +
628                                         std::to_string(outputOperandIndex);
629             // We'll insert a copy of the operation, all of whose
630             // OTHER output operands are newly-created -- i.e.,
631             // there'll only be a duplicate write of ONE of that
632             // operation's output operands.
633             validate(device, message, model,
634                      [operation, badOutputNum](Model* model, ExecutionPreference*) {
635                          Operation newOperation = model->operations[operation];
636                          for (uint32_t input : newOperation.inputs) {
637                              ++model->operands[input].numberOfConsumers;
638                          }
639                          for (size_t outputNum = 0; outputNum < newOperation.outputs.size();
640                               ++outputNum) {
641                              if (outputNum == badOutputNum) continue;
642 
643                              Operand operandValue =
644                                      model->operands[newOperation.outputs[outputNum]];
645                              operandValue.numberOfConsumers = 0;
646                              if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
647                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
648                              } else {
649                                  ASSERT_EQ(operandValue.lifetime,
650                                            OperandLifeTime::TEMPORARY_VARIABLE);
651                              }
652                              newOperation.outputs[outputNum] =
653                                      hidl_vec_push_back(&model->operands, operandValue);
654                          }
655                          // Where do we insert the extra writer (a new
656                          // operation)?  It has to be later than all the
657                          // writers of its inputs.  The easiest thing to do
658                          // is to insert it at the end of the operation
659                          // sequence.
660                          hidl_vec_push_back(&model->operations, newOperation);
661                      });
662         }
663     }
664 }
665 
666 ///////////////////////// VALIDATE EXTRA ??? /////////////////////////
667 
668 // TODO: Operand::location
669 
670 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
671 
mutateOperand(Operand * operand,OperandType type)672 static void mutateOperand(Operand* operand, OperandType type) {
673     Operand newOperand = *operand;
674     newOperand.type = type;
675     switch (type) {
676         case OperandType::FLOAT32:
677         case OperandType::INT32:
678         case OperandType::UINT32:
679             newOperand.dimensions = hidl_vec<uint32_t>();
680             newOperand.scale = 0.0f;
681             newOperand.zeroPoint = 0;
682             break;
683         case OperandType::TENSOR_FLOAT32:
684             newOperand.dimensions =
685                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
686             newOperand.scale = 0.0f;
687             newOperand.zeroPoint = 0;
688             break;
689         case OperandType::TENSOR_INT32:
690             newOperand.dimensions =
691                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
692             newOperand.zeroPoint = 0;
693             break;
694         case OperandType::TENSOR_QUANT8_ASYMM:
695             newOperand.dimensions =
696                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
697             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
698             break;
699         case OperandType::OEM:
700         case OperandType::TENSOR_OEM_BYTE:
701         default:
702             break;
703     }
704     *operand = newOperand;
705 }
706 
mutateOperationOperandTypeSkip(size_t operand,const Model & model)707 static bool mutateOperationOperandTypeSkip(size_t operand, const Model& model) {
708     // LSH_PROJECTION's second argument is allowed to have any type. This is the
709     // only operation that currently has a type that can be anything independent
710     // from any other type. Changing the operand type to any other type will
711     // result in a valid model for LSH_PROJECTION. If this is the case, skip the
712     // test.
713     for (const Operation& operation : model.operations) {
714         if (operation.type == OperationType::LSH_PROJECTION && operand == operation.inputs[1]) {
715             return true;
716         }
717     }
718     return false;
719 }
720 
mutateOperationOperandTypeTest(const sp<IDevice> & device,const Model & model)721 static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Model& model) {
722     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
723         if (mutateOperationOperandTypeSkip(operand, model)) {
724             continue;
725         }
726         for (OperandType invalidOperandType : hidl_enum_range<OperandType>{}) {
727             // Do not test OEM types
728             if (invalidOperandType == model.operands[operand].type ||
729                 invalidOperandType == OperandType::OEM ||
730                 invalidOperandType == OperandType::TENSOR_OEM_BYTE) {
731                 continue;
732             }
733             const std::string message = "mutateOperationOperandTypeTest: operand " +
734                                         std::to_string(operand) + " set to type " +
735                                         toString(invalidOperandType);
736             validate(device, message, model,
737                      [operand, invalidOperandType](Model* model, ExecutionPreference*) {
738                          mutateOperand(&model->operands[operand], invalidOperandType);
739                      });
740         }
741     }
742 }
743 
744 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
745 
746 static const int32_t invalidOperationTypes[] = {
747         static_cast<int32_t>(OperationType::ADD) - 1,            // lower bound fundamental
748         static_cast<int32_t>(OperationType::TRANSPOSE) + 1,      // upper bound fundamental
749         static_cast<int32_t>(OperationType::OEM_OPERATION) - 1,  // lower bound OEM
750         static_cast<int32_t>(OperationType::OEM_OPERATION) + 1,  // upper bound OEM
751 };
752 
mutateOperationTypeTest(const sp<IDevice> & device,const Model & model)753 static void mutateOperationTypeTest(const sp<IDevice>& device, const Model& model) {
754     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
755         for (int32_t invalidOperationType : invalidOperationTypes) {
756             const std::string message = "mutateOperationTypeTest: operation " +
757                                         std::to_string(operation) + " set to value " +
758                                         std::to_string(invalidOperationType);
759             validate(device, message, model,
760                      [operation, invalidOperationType](Model* model, ExecutionPreference*) {
761                          model->operations[operation].type =
762                                  static_cast<OperationType>(invalidOperationType);
763                      });
764         }
765     }
766 }
767 
768 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
769 
mutateOperationInputOperandIndexTest(const sp<IDevice> & device,const Model & model)770 static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
771     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
772         const uint32_t invalidOperand = model.operands.size();
773         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
774             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
775                                         std::to_string(operation) + " input " +
776                                         std::to_string(input);
777             validate(device, message, model,
778                      [operation, input, invalidOperand](Model* model, ExecutionPreference*) {
779                          model->operations[operation].inputs[input] = invalidOperand;
780                      });
781         }
782     }
783 }
784 
785 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
786 
mutateOperationOutputOperandIndexTest(const sp<IDevice> & device,const Model & model)787 static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
788     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
789         const uint32_t invalidOperand = model.operands.size();
790         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
791             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
792                                         std::to_string(operation) + " output " +
793                                         std::to_string(output);
794             validate(device, message, model,
795                      [operation, output, invalidOperand](Model* model, ExecutionPreference*) {
796                          model->operations[operation].outputs[output] = invalidOperand;
797                      });
798         }
799     }
800 }
801 
802 ///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
803 
mutateOperationRemoveWriteTest(const sp<IDevice> & device,const V1_1::Model & model)804 static void mutateOperationRemoveWriteTest(const sp<IDevice>& device, const V1_1::Model& model) {
805     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
806         for (size_t outputNum = 0; outputNum < model.operations[operation].outputs.size();
807              ++outputNum) {
808             const uint32_t outputOperandIndex = model.operations[operation].outputs[outputNum];
809             if (model.operands[outputOperandIndex].numberOfConsumers > 0) {
810                 const std::string message = "mutateOperationRemoveWriteTest: operation " +
811                                             std::to_string(operation) + " writes to " +
812                                             std::to_string(outputOperandIndex);
813                 validate(device, message, model,
814                          [operation, outputNum](Model* model, ExecutionPreference*) {
815                              uint32_t& outputOperandIndex =
816                                      model->operations[operation].outputs[outputNum];
817                              Operand operandValue = model->operands[outputOperandIndex];
818                              operandValue.numberOfConsumers = 0;
819                              if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
820                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
821                              } else {
822                                  ASSERT_EQ(operandValue.lifetime,
823                                            OperandLifeTime::TEMPORARY_VARIABLE);
824                              }
825                              outputOperandIndex =
826                                      hidl_vec_push_back(&model->operands, operandValue);
827                          });
828             }
829         }
830     }
831 }
832 
833 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
834 
// Removes every occurrence of `value` from `*vec`, then decrements each remaining element
// greater than `value`. Used to keep a list of operand indexes consistent after the operand
// at index `value` has been erased from Model::operands. A null `vec` is ignored.
//
// Templated on the container so it accepts the hidl_vec<uint32_t> lists passed by
// removeOperand as well as any other container providing begin()/end()/resize().
template <typename Vec>
static void removeValueAndDecrementGreaterValues(Vec* vec, uint32_t value) {
    if (vec == nullptr) {
        return;
    }

    // Remove elements matching `value` (erase-remove idiom; resize drops the moved-out tail).
    const auto last = std::remove(vec->begin(), vec->end(), value);
    vec->resize(std::distance(vec->begin(), last));

    // Shift down indexes that referred past the removed one. This must be `v - 1`, not
    // `v--`: post-decrement on the by-value parameter yields the old value unchanged.
    std::transform(vec->begin(), vec->end(), vec->begin(),
                   [value](uint32_t v) { return v > value ? v - 1 : v; });
}
846 
removeOperand(Model * model,uint32_t index)847 static void removeOperand(Model* model, uint32_t index) {
848     hidl_vec_removeAt(&model->operands, index);
849     for (Operation& operation : model->operations) {
850         removeValueAndDecrementGreaterValues(&operation.inputs, index);
851         removeValueAndDecrementGreaterValues(&operation.outputs, index);
852     }
853     removeValueAndDecrementGreaterValues(&model->inputIndexes, index);
854     removeValueAndDecrementGreaterValues(&model->outputIndexes, index);
855 }
856 
removeOperandTest(const sp<IDevice> & device,const Model & model)857 static void removeOperandTest(const sp<IDevice>& device, const Model& model) {
858     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
859         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
860         validate(device, message, model,
861                  [operand](Model* model, ExecutionPreference*) { removeOperand(model, operand); });
862     }
863 }
864 
865 ///////////////////////// REMOVE OPERATION /////////////////////////
866 
removeOperation(Model * model,uint32_t index)867 static void removeOperation(Model* model, uint32_t index) {
868     for (uint32_t operand : model->operations[index].inputs) {
869         model->operands[operand].numberOfConsumers--;
870     }
871     hidl_vec_removeAt(&model->operations, index);
872 }
873 
removeOperationTest(const sp<IDevice> & device,const Model & model)874 static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
875     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
876         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
877         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
878             removeOperation(model, operation);
879         });
880     }
881 }
882 
883 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
884 
removeOperationInputTest(const sp<IDevice> & device,const Model & model)885 static void removeOperationInputTest(const sp<IDevice>& device, const Model& model) {
886     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
887         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
888             const Operation& op = model.operations[operation];
889             // CONCATENATION has at least 2 inputs, with the last element being
890             // INT32. Skip this test if removing one of CONCATENATION's
891             // inputs still produces a valid model.
892             if (op.type == OperationType::CONCATENATION && op.inputs.size() > 2 &&
893                 input != op.inputs.size() - 1) {
894                 continue;
895             }
896             const std::string message = "removeOperationInputTest: operation " +
897                                         std::to_string(operation) + ", input " +
898                                         std::to_string(input);
899             validate(device, message, model,
900                      [operation, input](Model* model, ExecutionPreference*) {
901                          uint32_t operand = model->operations[operation].inputs[input];
902                          model->operands[operand].numberOfConsumers--;
903                          hidl_vec_removeAt(&model->operations[operation].inputs, input);
904                      });
905         }
906     }
907 }
908 
909 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
910 
removeOperationOutputTest(const sp<IDevice> & device,const Model & model)911 static void removeOperationOutputTest(const sp<IDevice>& device, const Model& model) {
912     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
913         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
914             const std::string message = "removeOperationOutputTest: operation " +
915                                         std::to_string(operation) + ", output " +
916                                         std::to_string(output);
917             validate(device, message, model,
918                      [operation, output](Model* model, ExecutionPreference*) {
919                          hidl_vec_removeAt(&model->operations[operation].outputs, output);
920                      });
921         }
922     }
923 }
924 
925 ///////////////////////// MODEL VALIDATION /////////////////////////
926 
927 // TODO: remove model input
928 // TODO: remove model output
929 // TODO: add unused operation
930 
931 ///////////////////////// ADD OPERATION INPUT /////////////////////////
932 
addOperationInputTest(const sp<IDevice> & device,const Model & model)933 static void addOperationInputTest(const sp<IDevice>& device, const Model& model) {
934     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
935         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
936         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
937             uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
938             hidl_vec_push_back(&model->operations[operation].inputs, index);
939             hidl_vec_push_back(&model->inputIndexes, index);
940         });
941     }
942 }
943 
944 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
945 
addOperationOutputTest(const sp<IDevice> & device,const Model & model)946 static void addOperationOutputTest(const sp<IDevice>& device, const Model& model) {
947     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
948         const std::string message =
949                 "addOperationOutputTest: operation " + std::to_string(operation);
950         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
951             uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
952             hidl_vec_push_back(&model->operations[operation].outputs, index);
953             hidl_vec_push_back(&model->outputIndexes, index);
954         });
955     }
956 }
957 
958 ///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////
959 
960 static const int32_t invalidExecutionPreferences[] = {
961         static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1,        // lower bound
962         static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1,  // upper bound
963 };
964 
mutateExecutionPreferenceTest(const sp<IDevice> & device,const Model & model)965 static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const Model& model) {
966     for (int32_t invalidPreference : invalidExecutionPreferences) {
967         const std::string message =
968                 "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
969         validate(device, message, model,
970                  [invalidPreference](Model*, ExecutionPreference* preference) {
971                      *preference = static_cast<ExecutionPreference>(invalidPreference);
972                  });
973     }
974 }
975 
976 ////////////////////////// ENTRY POINT //////////////////////////////
977 
// Entry point: runs every model-validation test in this file against `model` on `device`.
// `model` is taken by const reference and is not modified here. Each test hands `model`,
// together with a mutation callback, to validate(). The call order below is kept fixed; it
// determines the order in which the driver receives requests.
void validateModel(const sp<IDevice>& device, const Model& model) {
    mutateExecutionOrderTest(device, model);
    mutateOperandTypeTest(device, model);
    mutateOperandRankTest(device, model);
    mutateOperandScaleTest(device, model);
    mutateOperandZeroPointTest(device, model);
    mutateOperandLifeTimeTest(device, model);
    mutateOperandInputOutputTest(device, model);
    // Operand bookkeeping: wrong consumer counts and duplicate writers.
    mutateOperandNumberOfConsumersTest(device, model);
    mutateOperandAddWriterTest(device, model);
    // Operation-level mutations: operand types, operation type, operand indexes, writes.
    mutateOperationOperandTypeTest(device, model);
    mutateOperationTypeTest(device, model);
    mutateOperationInputOperandIndexTest(device, model);
    mutateOperationOutputOperandIndexTest(device, model);
    mutateOperationRemoveWriteTest(device, model);
    // Structural removals.
    removeOperandTest(device, model);
    removeOperationTest(device, model);
    removeOperationInputTest(device, model);
    removeOperationOutputTest(device, model);
    // Structural additions.
    addOperationInputTest(device, model);
    addOperationOutputTest(device, model);
    // Mutates the ExecutionPreference argument rather than the model itself.
    mutateExecutionPreferenceTest(device, model);
}
1001 
1002 }  // namespace android::hardware::neuralnetworks::V1_1::vts::functional
1003