1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include "1.0/Callbacks.h"
20 #include "1.0/Utils.h"
21 #include "GeneratedTestHarness.h"
22 #include "VtsHalNeuralnetworks.h"
23 
#include <algorithm>
#include <functional>
#include <numeric>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
27 
28 namespace android::hardware::neuralnetworks::V1_0::vts::functional {
29 
30 using implementation::PreparedModelCallback;
31 
32 using PrepareModelMutation = std::function<void(Model*)>;
33 
34 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
35 
validateGetSupportedOperations(const sp<IDevice> & device,const std::string & message,const Model & model)36 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
37                                            const Model& model) {
38     SCOPED_TRACE(message + " [getSupportedOperations]");
39 
40     Return<void> ret =
41             device->getSupportedOperations(model, [&](ErrorStatus status, const hidl_vec<bool>&) {
42                 EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, status);
43             });
44     EXPECT_TRUE(ret.isOk());
45 }
46 
validatePrepareModel(const sp<IDevice> & device,const std::string & message,const Model & model)47 static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
48                                  const Model& model) {
49     SCOPED_TRACE(message + " [prepareModel]");
50 
51     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
52     Return<ErrorStatus> prepareLaunchStatus = device->prepareModel(model, preparedModelCallback);
53     ASSERT_TRUE(prepareLaunchStatus.isOk());
54     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));
55 
56     preparedModelCallback->wait();
57     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
58     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
59     sp<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
60     ASSERT_EQ(nullptr, preparedModel.get());
61 }
62 
63 // Primary validation function. This function will take a valid model, apply a
64 // mutation to invalidate the model, then pass these to supportedOperations and
65 // prepareModel.
validate(const sp<IDevice> & device,const std::string & message,const Model & originalModel,const PrepareModelMutation & mutate)66 static void validate(const sp<IDevice>& device, const std::string& message,
67                      const Model& originalModel, const PrepareModelMutation& mutate) {
68     Model model = originalModel;
69     mutate(&model);
70 
71     validateGetSupportedOperations(device, message, model);
72     validatePrepareModel(device, message, model);
73 }
74 
addOperand(Model * model)75 static uint32_t addOperand(Model* model) {
76     return hidl_vec_push_back(&model->operands,
77                               {
78                                       .type = OperandType::INT32,
79                                       .dimensions = {},
80                                       .numberOfConsumers = 0,
81                                       .scale = 0.0f,
82                                       .zeroPoint = 0,
83                                       .lifetime = OperandLifeTime::MODEL_INPUT,
84                                       .location = {.poolIndex = 0, .offset = 0, .length = 0},
85                               });
86 }
87 
addOperand(Model * model,OperandLifeTime lifetime)88 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
89     uint32_t index = addOperand(model);
90     model->operands[index].numberOfConsumers = 1;
91     model->operands[index].lifetime = lifetime;
92     return index;
93 }
94 
95 // If we introduce a CONSTANT_COPY for an operand of size operandSize,
96 // how much will this increase the size of the model?  This assumes
97 // that we can (re)use all of model.operandValues for the operand
98 // value.
constantCopyExtraSize(const Model & model,size_t operandSize)99 static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
100     const size_t operandValuesSize = model.operandValues.size();
101     return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
102 }
103 
104 // Highly specialized utility routine for converting an operand to
105 // CONSTANT_COPY lifetime.
106 //
107 // Expects that:
108 // - operand has a known size
109 // - operand->lifetime has already been set to CONSTANT_COPY
110 // - operand->location has been zeroed out
111 //
112 // Does the following:
113 // - initializes operand->location to point to the beginning of model->operandValues
114 // - resizes model->operandValues (if necessary) to be large enough for the operand
115 //   value, padding it with zeroes on the end
116 //
117 // Potential problem:
118 // By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
119 // operand with unspecified (but deterministic) data. This means that the model may be invalidated
120 // in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
121 // graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
122 // value). For now, this should be fine because it just means we're not testing what we think we're
123 // testing in certain cases; but we can handwave this and assume we're probabilistically likely to
124 // exercise the validation code over the span of the entire test set and operand space.
125 //
126 // Aborts if the specified operand type is an extension type or OEM type.
becomeConstantCopy(Model * model,Operand * operand)127 static void becomeConstantCopy(Model* model, Operand* operand) {
128     // sizeOfData will abort if the specified type is an extension type or OEM type.
129     const size_t sizeOfOperand = sizeOfData(*operand);
130     EXPECT_NE(sizeOfOperand, size_t(0));
131     operand->location.poolIndex = 0;
132     operand->location.offset = 0;
133     operand->location.length = sizeOfOperand;
134     if (model->operandValues.size() < sizeOfOperand) {
135         model->operandValues.resize(sizeOfOperand);
136     }
137 }
138 
139 // The sizeForBinder() functions estimate the size of the
140 // representation of a value when sent to binder.  It's probably a bit
141 // of an under-estimate, because we don't know the size of the
142 // metadata in the binder format (e.g., representation of the size of
143 // a vector); but at least it adds up "big" things like vector
144 // contents.  However, it doesn't treat inter-field or end-of-struct
145 // padding in a methodical way -- there's no attempt to be consistent
146 // in whether or not padding in the native (C++) representation
147 // contributes to the estimated size for the binder representation;
148 // and there's no attempt to understand what padding (if any) is
149 // needed in the binder representation.
150 //
151 // This assumes that non-metadata uses a fixed length encoding (e.g.,
152 // a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
153 // using an encoding whose length is related to the magnitude of the
154 // encoded value).
155 
// Fallback estimate: a trivially copyable value is assumed to be sent as its raw native bytes.
// Non-trivially-copyable types must provide their own overload or specialization.
template <typename Type>
static size_t sizeForBinder(const Type& val) {
    using Plain = std::remove_reference_t<Type>;
    static_assert(std::is_trivially_copyable_v<Plain>, "expected a trivially copyable type");
    (void)val;  // only the static type matters
    return sizeof(Plain);
}
162 
163 template <typename Type>
sizeForBinder(const hidl_vec<Type> & vec)164 static size_t sizeForBinder(const hidl_vec<Type>& vec) {
165     return std::accumulate(vec.begin(), vec.end(), 0,
166                            [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
167 }
168 
169 template <>
sizeForBinder(const Operand & operand)170 size_t sizeForBinder(const Operand& operand) {
171     size_t size = 0;
172 
173     size += sizeForBinder(operand.type);
174     size += sizeForBinder(operand.dimensions);
175     size += sizeForBinder(operand.numberOfConsumers);
176     size += sizeForBinder(operand.scale);
177     size += sizeForBinder(operand.zeroPoint);
178     size += sizeForBinder(operand.lifetime);
179     size += sizeForBinder(operand.location);
180 
181     return size;
182 }
183 
184 template <>
sizeForBinder(const Operation & operation)185 size_t sizeForBinder(const Operation& operation) {
186     size_t size = 0;
187 
188     size += sizeForBinder(operation.type);
189     size += sizeForBinder(operation.inputs);
190     size += sizeForBinder(operation.outputs);
191 
192     return size;
193 }
194 
195 template <>
sizeForBinder(const hidl_string & name)196 size_t sizeForBinder(const hidl_string& name) {
197     return name.size();
198 }
199 
200 template <>
sizeForBinder(const hidl_memory & memory)201 size_t sizeForBinder(const hidl_memory& memory) {
202     // This is just a guess.
203 
204     size_t size = 0;
205 
206     if (const native_handle_t* handle = memory.handle()) {
207         size += sizeof(*handle);
208         size += sizeof(handle->data[0] * (handle->numFds + handle->numInts));
209     }
210     size += sizeForBinder(memory.name());
211 
212     return size;
213 }
214 
215 template <>
sizeForBinder(const Model & model)216 size_t sizeForBinder(const Model& model) {
217     size_t size = 0;
218 
219     size += sizeForBinder(model.operands);
220     size += sizeForBinder(model.operations);
221     size += sizeForBinder(model.inputIndexes);
222     size += sizeForBinder(model.outputIndexes);
223     size += sizeForBinder(model.operandValues);
224     size += sizeForBinder(model.pools);
225 
226     return size;
227 }
228 
// Binder transaction size budget.
//
// https://developer.android.com/reference/android/os/TransactionTooLargeException.html
//
//     "The Binder transaction buffer has a limited fixed size,
//     currently 1Mb, which is shared by all transactions in progress
//     for the process."
//
// Whether our representation fits under this limit is uncertain for three reasons:
// - sizeForBinder() only approximates the representation size.
// - This object may not be the only occupant of the Binder transaction buffer (although this VTS
//   suite should not be putting multiple objects in the buffer at once).
// - IBinder.MAX_IPC_SIZE recommends limiting a transaction to 64 * 1024 bytes.
// We therefore stay very conservative: a representation must be no larger than half of that
// recommended limit.
//
// A representation might fit in the transaction buffer on its own but overflow it when combined
// with other in-flight transactions. That could show up as intermittent HAL transport errors.
//
// Returns true when `representationSize` is strictly greater than 64 * 1024 / 2 bytes.
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // C++ has no API for the Java constant IBinder.MAX_IPC_SIZE, so its value is spelled out.
    constexpr size_t kMaxIPCSize = 64 * 1024;
    constexpr size_t kHalfMaxIPCSize = kMaxIPCSize / 2;
    return representationSize > kHalfMaxIPCSize;
}
254 
255 ///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
256 
mutateExecutionOrderTest(const sp<IDevice> & device,const V1_0::Model & model)257 static void mutateExecutionOrderTest(const sp<IDevice>& device, const V1_0::Model& model) {
258     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
259         const Operation& operationObj = model.operations[operation];
260         for (uint32_t input : operationObj.inputs) {
261             if (model.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
262                 model.operands[input].lifetime == OperandLifeTime::MODEL_OUTPUT) {
263                 // This operation reads an operand written by some
264                 // other operation.  Move this operation to the
265                 // beginning of the sequence, ensuring that it reads
266                 // the operand before that operand is written, thereby
267                 // violating execution order rules.
268                 const std::string message = "mutateExecutionOrderTest: operation " +
269                                             std::to_string(operation) + " is a reader";
270                 validate(device, message, model, [operation](Model* model) {
271                     auto& operations = model->operations;
272                     std::rotate(operations.begin(), operations.begin() + operation,
273                                 operations.begin() + operation + 1);
274                 });
275                 break;  // only need to do this once per operation
276             }
277         }
278         for (uint32_t output : operationObj.outputs) {
279             if (model.operands[output].numberOfConsumers > 0) {
280                 // This operation writes an operand read by some other
281                 // operation.  Move this operation to the end of the
282                 // sequence, ensuring that it writes the operand after
283                 // that operand is read, thereby violating execution
284                 // order rules.
285                 const std::string message = "mutateExecutionOrderTest: operation " +
286                                             std::to_string(operation) + " is a writer";
287                 validate(device, message, model, [operation](Model* model) {
288                     auto& operations = model->operations;
289                     std::rotate(operations.begin() + operation, operations.begin() + operation + 1,
290                                 operations.end());
291                 });
292                 break;  // only need to do this once per operation
293             }
294         }
295     }
296 }
297 
298 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
299 
300 static const int32_t invalidOperandTypes[] = {
301         static_cast<int32_t>(OperandType::FLOAT32) - 1,              // lower bound fundamental
302         static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) + 1,  // upper bound fundamental
303         static_cast<int32_t>(OperandType::OEM) - 1,                  // lower bound OEM
304         static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) + 1,      // upper bound OEM
305 };
306 
mutateOperandTypeTest(const sp<IDevice> & device,const Model & model)307 static void mutateOperandTypeTest(const sp<IDevice>& device, const Model& model) {
308     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
309         for (int32_t invalidOperandType : invalidOperandTypes) {
310             const std::string message = "mutateOperandTypeTest: operand " +
311                                         std::to_string(operand) + " set to value " +
312                                         std::to_string(invalidOperandType);
313             validate(device, message, model, [operand, invalidOperandType](Model* model) {
314                 model->operands[operand].type = static_cast<OperandType>(invalidOperandType);
315             });
316         }
317     }
318 }
319 
320 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
321 
getInvalidRank(OperandType type)322 static uint32_t getInvalidRank(OperandType type) {
323     switch (type) {
324         case OperandType::FLOAT32:
325         case OperandType::INT32:
326         case OperandType::UINT32:
327             return 1;
328         case OperandType::TENSOR_FLOAT32:
329         case OperandType::TENSOR_INT32:
330         case OperandType::TENSOR_QUANT8_ASYMM:
331             return 0;
332         default:
333             return 0;
334     }
335 }
336 
mutateOperandRankTest(const sp<IDevice> & device,const Model & model)337 static void mutateOperandRankTest(const sp<IDevice>& device, const Model& model) {
338     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
339         const uint32_t invalidRank = getInvalidRank(model.operands[operand].type);
340         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
341                                     " has rank of " + std::to_string(invalidRank);
342         validate(device, message, model, [operand, invalidRank](Model* model) {
343             model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
344         });
345     }
346 }
347 
348 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
349 
getInvalidScale(OperandType type)350 static float getInvalidScale(OperandType type) {
351     switch (type) {
352         case OperandType::FLOAT32:
353         case OperandType::INT32:
354         case OperandType::UINT32:
355         case OperandType::TENSOR_FLOAT32:
356             return 1.0f;
357         case OperandType::TENSOR_INT32:
358             return -1.0f;
359         case OperandType::TENSOR_QUANT8_ASYMM:
360             return 0.0f;
361         default:
362             return 0.0f;
363     }
364 }
365 
mutateOperandScaleTest(const sp<IDevice> & device,const Model & model)366 static void mutateOperandScaleTest(const sp<IDevice>& device, const Model& model) {
367     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
368         const float invalidScale = getInvalidScale(model.operands[operand].type);
369         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
370                                     " has scale of " + std::to_string(invalidScale);
371         validate(device, message, model, [operand, invalidScale](Model* model) {
372             model->operands[operand].scale = invalidScale;
373         });
374     }
375 }
376 
377 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
378 
getInvalidZeroPoints(OperandType type)379 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
380     switch (type) {
381         case OperandType::FLOAT32:
382         case OperandType::INT32:
383         case OperandType::UINT32:
384         case OperandType::TENSOR_FLOAT32:
385         case OperandType::TENSOR_INT32:
386             return {1};
387         case OperandType::TENSOR_QUANT8_ASYMM:
388             return {-1, 256};
389         default:
390             return {};
391     }
392 }
393 
mutateOperandZeroPointTest(const sp<IDevice> & device,const Model & model)394 static void mutateOperandZeroPointTest(const sp<IDevice>& device, const Model& model) {
395     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
396         const std::vector<int32_t> invalidZeroPoints =
397                 getInvalidZeroPoints(model.operands[operand].type);
398         for (int32_t invalidZeroPoint : invalidZeroPoints) {
399             const std::string message = "mutateOperandZeroPointTest: operand " +
400                                         std::to_string(operand) + " has zero point of " +
401                                         std::to_string(invalidZeroPoint);
402             validate(device, message, model, [operand, invalidZeroPoint](Model* model) {
403                 model->operands[operand].zeroPoint = invalidZeroPoint;
404             });
405         }
406     }
407 }
408 
409 ///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
410 
getInvalidLifeTimes(const Model & model,size_t modelSize,const Operand & operand)411 static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
412                                                         const Operand& operand) {
413     // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
414     // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
415 
416     // Ways to get an invalid lifetime:
417     // - change whether a lifetime means an operand should have a writer
418     std::vector<OperandLifeTime> ret;
419     switch (operand.lifetime) {
420         case OperandLifeTime::MODEL_OUTPUT:
421         case OperandLifeTime::TEMPORARY_VARIABLE:
422             ret = {
423                     OperandLifeTime::MODEL_INPUT,
424                     OperandLifeTime::CONSTANT_COPY,
425             };
426             break;
427         case OperandLifeTime::CONSTANT_COPY:
428         case OperandLifeTime::CONSTANT_REFERENCE:
429         case OperandLifeTime::MODEL_INPUT:
430             ret = {
431                     OperandLifeTime::TEMPORARY_VARIABLE,
432                     OperandLifeTime::MODEL_OUTPUT,
433             };
434             break;
435         case OperandLifeTime::NO_VALUE:
436             // Not enough information to know whether
437             // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
438             // is this operand written (then CONSTANT_COPY would be
439             // invalid) or not (then TEMPORARY_VARIABLE would be
440             // invalid)?
441             break;
442         default:
443             ADD_FAILURE();
444             break;
445     }
446 
447     const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
448     if (!operandSize ||
449         exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
450         // Unknown size or too-large size
451         ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
452     }
453 
454     return ret;
455 }
456 
mutateOperandLifeTimeTest(const sp<IDevice> & device,const V1_0::Model & model)457 static void mutateOperandLifeTimeTest(const sp<IDevice>& device, const V1_0::Model& model) {
458     const size_t modelSize = sizeForBinder(model);
459     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
460         const std::vector<OperandLifeTime> invalidLifeTimes =
461                 getInvalidLifeTimes(model, modelSize, model.operands[operand]);
462         for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
463             const std::string message = "mutateOperandLifetimeTest: operand " +
464                                         std::to_string(operand) + " has lifetime " +
465                                         toString(invalidLifeTime) + " instead of lifetime " +
466                                         toString(model.operands[operand].lifetime);
467             validate(device, message, model, [operand, invalidLifeTime](Model* model) {
468                 static const DataLocation kZeroDataLocation = {};
469                 Operand& operandObj = model->operands[operand];
470                 switch (operandObj.lifetime) {
471                     case OperandLifeTime::MODEL_INPUT: {
472                         hidl_vec_remove(&model->inputIndexes, uint32_t(operand));
473                         break;
474                     }
475                     case OperandLifeTime::MODEL_OUTPUT: {
476                         hidl_vec_remove(&model->outputIndexes, uint32_t(operand));
477                         break;
478                     }
479                     default:
480                         break;
481                 }
482                 operandObj.lifetime = invalidLifeTime;
483                 operandObj.location = kZeroDataLocation;
484                 switch (invalidLifeTime) {
485                     case OperandLifeTime::CONSTANT_COPY: {
486                         becomeConstantCopy(model, &operandObj);
487                         break;
488                     }
489                     case OperandLifeTime::MODEL_INPUT:
490                         hidl_vec_push_back(&model->inputIndexes, uint32_t(operand));
491                         break;
492                     case OperandLifeTime::MODEL_OUTPUT:
493                         hidl_vec_push_back(&model->outputIndexes, uint32_t(operand));
494                         break;
495                     default:
496                         break;
497                 }
498             });
499         }
500     }
501 }
502 
503 ///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
504 
getInputOutputLifeTime(const Model & model,size_t modelSize,const Operand & operand)505 static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
506                                                              const Operand& operand) {
507     // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
508     // - change whether a lifetime means an operand is a model input, a model output, or neither
509     // - preserve whether or not a lifetime means an operand should have a writer
510     switch (operand.lifetime) {
511         case OperandLifeTime::CONSTANT_COPY:
512         case OperandLifeTime::CONSTANT_REFERENCE:
513             return OperandLifeTime::MODEL_INPUT;
514         case OperandLifeTime::MODEL_INPUT: {
515             const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
516             if (!operandSize ||
517                 exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
518                 // Unknown size or too-large size
519                 break;
520             }
521             return OperandLifeTime::CONSTANT_COPY;
522         }
523         case OperandLifeTime::MODEL_OUTPUT:
524             return OperandLifeTime::TEMPORARY_VARIABLE;
525         case OperandLifeTime::TEMPORARY_VARIABLE:
526             return OperandLifeTime::MODEL_OUTPUT;
527         case OperandLifeTime::NO_VALUE:
528             // Not enough information to know whether
529             // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
530             // appropriate choice -- is this operand written (then
531             // TEMPORARY_VARIABLE would be appropriate) or not (then
532             // CONSTANT_COPY would be appropriate)?
533             break;
534         default:
535             ADD_FAILURE();
536             break;
537     }
538 
539     return std::nullopt;
540 }
541 
mutateOperandInputOutputTest(const sp<IDevice> & device,const V1_0::Model & model)542 static void mutateOperandInputOutputTest(const sp<IDevice>& device, const V1_0::Model& model) {
543     const size_t modelSize = sizeForBinder(model);
544     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
545         const std::optional<OperandLifeTime> changedLifeTime =
546                 getInputOutputLifeTime(model, modelSize, model.operands[operand]);
547         if (changedLifeTime) {
548             const std::string message = "mutateOperandInputOutputTest: operand " +
549                                         std::to_string(operand) + " has lifetime " +
550                                         toString(*changedLifeTime) + " instead of lifetime " +
551                                         toString(model.operands[operand].lifetime);
552             validate(device, message, model, [operand, changedLifeTime](Model* model) {
553                 static const DataLocation kZeroDataLocation = {};
554                 Operand& operandObj = model->operands[operand];
555                 operandObj.lifetime = *changedLifeTime;
556                 operandObj.location = kZeroDataLocation;
557                 if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
558                     becomeConstantCopy(model, &operandObj);
559                 }
560             });
561         }
562     }
563 }
564 
565 ///////////////////////// VALIDATE OPERAND NUMBER OF CONSUMERS //////////////////////////////////
566 
// Returns consumer counts that differ from the actual count `numberOfConsumers`: one less and
// one more, in that order. When the actual count is 0, only 1 is returned (no negative count).
static std::vector<uint32_t> getInvalidNumberOfConsumers(uint32_t numberOfConsumers) {
    if (numberOfConsumers > 0) {
        return {numberOfConsumers - 1, numberOfConsumers + 1};
    }
    return {1};
}
574 
mutateOperandNumberOfConsumersTest(const sp<IDevice> & device,const V1_0::Model & model)575 static void mutateOperandNumberOfConsumersTest(const sp<IDevice>& device,
576                                                const V1_0::Model& model) {
577     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
578         const std::vector<uint32_t> invalidNumberOfConsumersVec =
579                 getInvalidNumberOfConsumers(model.operands[operand].numberOfConsumers);
580         for (uint32_t invalidNumberOfConsumers : invalidNumberOfConsumersVec) {
581             const std::string message =
582                     "mutateOperandNumberOfConsumersTest: operand " + std::to_string(operand) +
583                     " numberOfConsumers = " + std::to_string(invalidNumberOfConsumers);
584             validate(device, message, model, [operand, invalidNumberOfConsumers](Model* model) {
585                 model->operands[operand].numberOfConsumers = invalidNumberOfConsumers;
586             });
587         }
588     }
589 }
590 
591 ///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
592 
mutateOperandAddWriterTest(const sp<IDevice> & device,const V1_0::Model & model)593 static void mutateOperandAddWriterTest(const sp<IDevice>& device, const V1_0::Model& model) {
594     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
595         for (size_t badOutputNum = 0; badOutputNum < model.operations[operation].outputs.size();
596              ++badOutputNum) {
597             const uint32_t outputOperandIndex = model.operations[operation].outputs[badOutputNum];
598             const std::string message = "mutateOperandAddWriterTest: operation " +
599                                         std::to_string(operation) + " writes to " +
600                                         std::to_string(outputOperandIndex);
601             // We'll insert a copy of the operation, all of whose
602             // OTHER output operands are newly-created -- i.e.,
603             // there'll only be a duplicate write of ONE of that
604             // operation's output operands.
605             validate(device, message, model, [operation, badOutputNum](Model* model) {
606                 Operation newOperation = model->operations[operation];
607                 for (uint32_t input : newOperation.inputs) {
608                     ++model->operands[input].numberOfConsumers;
609                 }
610                 for (size_t outputNum = 0; outputNum < newOperation.outputs.size(); ++outputNum) {
611                     if (outputNum == badOutputNum) continue;
612 
613                     Operand operandValue = model->operands[newOperation.outputs[outputNum]];
614                     operandValue.numberOfConsumers = 0;
615                     if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
616                         operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
617                     } else {
618                         ASSERT_EQ(operandValue.lifetime, OperandLifeTime::TEMPORARY_VARIABLE);
619                     }
620                     newOperation.outputs[outputNum] =
621                             hidl_vec_push_back(&model->operands, operandValue);
622                 }
623                 // Where do we insert the extra writer (a new
624                 // operation)?  It has to be later than all the
625                 // writers of its inputs.  The easiest thing to do
626                 // is to insert it at the end of the operation
627                 // sequence.
628                 hidl_vec_push_back(&model->operations, newOperation);
629             });
630         }
631     }
632 }
633 
634 ///////////////////////// VALIDATE EXTRA ??? /////////////////////////
635 
636 // TODO: Operand::location
637 
638 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
639 
mutateOperand(Operand * operand,OperandType type)640 static void mutateOperand(Operand* operand, OperandType type) {
641     Operand newOperand = *operand;
642     newOperand.type = type;
643     switch (type) {
644         case OperandType::FLOAT32:
645         case OperandType::INT32:
646         case OperandType::UINT32:
647             newOperand.dimensions = hidl_vec<uint32_t>();
648             newOperand.scale = 0.0f;
649             newOperand.zeroPoint = 0;
650             break;
651         case OperandType::TENSOR_FLOAT32:
652             newOperand.dimensions =
653                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
654             newOperand.scale = 0.0f;
655             newOperand.zeroPoint = 0;
656             break;
657         case OperandType::TENSOR_INT32:
658             newOperand.dimensions =
659                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
660             newOperand.zeroPoint = 0;
661             break;
662         case OperandType::TENSOR_QUANT8_ASYMM:
663             newOperand.dimensions =
664                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
665             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
666             break;
667         case OperandType::OEM:
668         case OperandType::TENSOR_OEM_BYTE:
669         default:
670             break;
671     }
672     *operand = newOperand;
673 }
674 
mutateOperationOperandTypeSkip(size_t operand,const Model & model)675 static bool mutateOperationOperandTypeSkip(size_t operand, const Model& model) {
676     // LSH_PROJECTION's second argument is allowed to have any type. This is the
677     // only operation that currently has a type that can be anything independent
678     // from any other type. Changing the operand type to any other type will
679     // result in a valid model for LSH_PROJECTION. If this is the case, skip the
680     // test.
681     for (const Operation& operation : model.operations) {
682         if (operation.type == OperationType::LSH_PROJECTION && operand == operation.inputs[1]) {
683             return true;
684         }
685     }
686     return false;
687 }
688 
mutateOperationOperandTypeTest(const sp<IDevice> & device,const Model & model)689 static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Model& model) {
690     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
691         if (mutateOperationOperandTypeSkip(operand, model)) {
692             continue;
693         }
694         for (OperandType invalidOperandType : hidl_enum_range<OperandType>{}) {
695             // Do not test OEM types
696             if (invalidOperandType == model.operands[operand].type ||
697                 invalidOperandType == OperandType::OEM ||
698                 invalidOperandType == OperandType::TENSOR_OEM_BYTE) {
699                 continue;
700             }
701             const std::string message = "mutateOperationOperandTypeTest: operand " +
702                                         std::to_string(operand) + " set to type " +
703                                         toString(invalidOperandType);
704             validate(device, message, model, [operand, invalidOperandType](Model* model) {
705                 mutateOperand(&model->operands[operand], invalidOperandType);
706             });
707         }
708     }
709 }
710 
711 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
712 
713 static const int32_t invalidOperationTypes[] = {
714         static_cast<int32_t>(OperationType::ADD) - 1,            // lower bound fundamental
715         static_cast<int32_t>(OperationType::TANH) + 1,           // upper bound fundamental
716         static_cast<int32_t>(OperationType::OEM_OPERATION) - 1,  // lower bound OEM
717         static_cast<int32_t>(OperationType::OEM_OPERATION) + 1,  // upper bound OEM
718 };
719 
mutateOperationTypeTest(const sp<IDevice> & device,const Model & model)720 static void mutateOperationTypeTest(const sp<IDevice>& device, const Model& model) {
721     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
722         for (int32_t invalidOperationType : invalidOperationTypes) {
723             const std::string message = "mutateOperationTypeTest: operation " +
724                                         std::to_string(operation) + " set to value " +
725                                         std::to_string(invalidOperationType);
726             validate(device, message, model, [operation, invalidOperationType](Model* model) {
727                 model->operations[operation].type =
728                         static_cast<OperationType>(invalidOperationType);
729             });
730         }
731     }
732 }
733 
734 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
735 
mutateOperationInputOperandIndexTest(const sp<IDevice> & device,const Model & model)736 static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
737     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
738         const uint32_t invalidOperand = model.operands.size();
739         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
740             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
741                                         std::to_string(operation) + " input " +
742                                         std::to_string(input);
743             validate(device, message, model, [operation, input, invalidOperand](Model* model) {
744                 model->operations[operation].inputs[input] = invalidOperand;
745             });
746         }
747     }
748 }
749 
750 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
751 
mutateOperationOutputOperandIndexTest(const sp<IDevice> & device,const Model & model)752 static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
753     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
754         const uint32_t invalidOperand = model.operands.size();
755         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
756             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
757                                         std::to_string(operation) + " output " +
758                                         std::to_string(output);
759             validate(device, message, model, [operation, output, invalidOperand](Model* model) {
760                 model->operations[operation].outputs[output] = invalidOperand;
761             });
762         }
763     }
764 }
765 
766 ///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
767 
mutateOperationRemoveWriteTest(const sp<IDevice> & device,const V1_0::Model & model)768 static void mutateOperationRemoveWriteTest(const sp<IDevice>& device, const V1_0::Model& model) {
769     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
770         for (size_t outputNum = 0; outputNum < model.operations[operation].outputs.size();
771              ++outputNum) {
772             const uint32_t outputOperandIndex = model.operations[operation].outputs[outputNum];
773             if (model.operands[outputOperandIndex].numberOfConsumers > 0) {
774                 const std::string message = "mutateOperationRemoveWriteTest: operation " +
775                                             std::to_string(operation) + " writes to " +
776                                             std::to_string(outputOperandIndex);
777                 validate(device, message, model, [operation, outputNum](Model* model) {
778                     uint32_t& outputOperandIndex = model->operations[operation].outputs[outputNum];
779                     Operand operandValue = model->operands[outputOperandIndex];
780                     operandValue.numberOfConsumers = 0;
781                     if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
782                         operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
783                     } else {
784                         ASSERT_EQ(operandValue.lifetime, OperandLifeTime::TEMPORARY_VARIABLE);
785                     }
786                     outputOperandIndex = hidl_vec_push_back(&model->operands, operandValue);
787                 });
788             }
789         }
790     }
791 }
792 
793 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
794 
// Removes every element equal to `value` from *vec, then decrements by one
// every remaining element greater than `value`. Applied to a list of operand
// indexes after operand `value` has been deleted, this drops references to the
// deleted operand and renumbers references to the operands that followed it.
// A null `vec` is ignored.
//
// Templated on the container so it accepts hidl_vec<uint32_t> (as used by
// removeOperand) as well as any container with begin()/end()/resize().
template <typename Vec>
static void removeValueAndDecrementGreaterValues(Vec* vec, uint32_t value) {
    if (vec == nullptr) return;

    // remove elements matching "value"
    const auto newEnd = std::remove(vec->begin(), vec->end(), value);
    vec->resize(static_cast<size_t>(std::distance(vec->begin(), newEnd)));

    // decrement elements exceeding "value".
    // Must be `v - 1`: a post-decrement (`v--`) on the by-value parameter
    // evaluates to the old value and would leave the element unchanged.
    std::transform(vec->begin(), vec->end(), vec->begin(),
                   [value](uint32_t v) { return v > value ? v - 1 : v; });
}
806 
removeOperand(Model * model,uint32_t index)807 static void removeOperand(Model* model, uint32_t index) {
808     hidl_vec_removeAt(&model->operands, index);
809     for (Operation& operation : model->operations) {
810         removeValueAndDecrementGreaterValues(&operation.inputs, index);
811         removeValueAndDecrementGreaterValues(&operation.outputs, index);
812     }
813     removeValueAndDecrementGreaterValues(&model->inputIndexes, index);
814     removeValueAndDecrementGreaterValues(&model->outputIndexes, index);
815 }
816 
removeOperandTest(const sp<IDevice> & device,const Model & model)817 static void removeOperandTest(const sp<IDevice>& device, const Model& model) {
818     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
819         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
820         validate(device, message, model,
821                  [operand](Model* model) { removeOperand(model, operand); });
822     }
823 }
824 
825 ///////////////////////// REMOVE OPERATION /////////////////////////
826 
removeOperation(Model * model,uint32_t index)827 static void removeOperation(Model* model, uint32_t index) {
828     for (uint32_t operand : model->operations[index].inputs) {
829         model->operands[operand].numberOfConsumers--;
830     }
831     hidl_vec_removeAt(&model->operations, index);
832 }
833 
removeOperationTest(const sp<IDevice> & device,const Model & model)834 static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
835     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
836         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
837         validate(device, message, model,
838                  [operation](Model* model) { removeOperation(model, operation); });
839     }
840 }
841 
842 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
843 
removeOperationInputTest(const sp<IDevice> & device,const Model & model)844 static void removeOperationInputTest(const sp<IDevice>& device, const Model& model) {
845     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
846         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
847             const Operation& op = model.operations[operation];
848             // CONCATENATION has at least 2 inputs, with the last element being
849             // INT32. Skip this test if removing one of CONCATENATION's
850             // inputs still produces a valid model.
851             if (op.type == OperationType::CONCATENATION && op.inputs.size() > 2 &&
852                 input != op.inputs.size() - 1) {
853                 continue;
854             }
855             const std::string message = "removeOperationInputTest: operation " +
856                                         std::to_string(operation) + ", input " +
857                                         std::to_string(input);
858             validate(device, message, model, [operation, input](Model* model) {
859                 uint32_t operand = model->operations[operation].inputs[input];
860                 model->operands[operand].numberOfConsumers--;
861                 hidl_vec_removeAt(&model->operations[operation].inputs, input);
862             });
863         }
864     }
865 }
866 
867 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
868 
removeOperationOutputTest(const sp<IDevice> & device,const Model & model)869 static void removeOperationOutputTest(const sp<IDevice>& device, const Model& model) {
870     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
871         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
872             const std::string message = "removeOperationOutputTest: operation " +
873                                         std::to_string(operation) + ", output " +
874                                         std::to_string(output);
875             validate(device, message, model, [operation, output](Model* model) {
876                 hidl_vec_removeAt(&model->operations[operation].outputs, output);
877             });
878         }
879     }
880 }
881 
882 ///////////////////////// MODEL VALIDATION /////////////////////////
883 
884 // TODO: remove model input
885 // TODO: remove model output
886 // TODO: add unused operation
887 
888 ///////////////////////// ADD OPERATION INPUT /////////////////////////
889 
addOperationInputTest(const sp<IDevice> & device,const Model & model)890 static void addOperationInputTest(const sp<IDevice>& device, const Model& model) {
891     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
892         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
893         validate(device, message, model, [operation](Model* model) {
894             uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
895             hidl_vec_push_back(&model->operations[operation].inputs, index);
896             hidl_vec_push_back(&model->inputIndexes, index);
897         });
898     }
899 }
900 
901 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
902 
addOperationOutputTest(const sp<IDevice> & device,const Model & model)903 static void addOperationOutputTest(const sp<IDevice>& device, const Model& model) {
904     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
905         const std::string message =
906                 "addOperationOutputTest: operation " + std::to_string(operation);
907         validate(device, message, model, [operation](Model* model) {
908             uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
909             hidl_vec_push_back(&model->operations[operation].outputs, index);
910             hidl_vec_push_back(&model->outputIndexes, index);
911         });
912     }
913 }
914 
915 ////////////////////////// ENTRY POINT //////////////////////////////
916 
// Entry point: runs every model-mutation validation in this file against
// `device`, each starting from the same (unmodified) `model`. The tests run
// sequentially in the order listed below.
void validateModel(const sp<IDevice>& device, const Model& model) {
    // Whole-model and per-operand field mutations.
    mutateExecutionOrderTest(device, model);
    mutateOperandTypeTest(device, model);
    mutateOperandRankTest(device, model);
    mutateOperandScaleTest(device, model);
    mutateOperandZeroPointTest(device, model);
    mutateOperandLifeTimeTest(device, model);
    mutateOperandInputOutputTest(device, model);
    mutateOperandNumberOfConsumersTest(device, model);
    mutateOperandAddWriterTest(device, model);
    // Per-operation mutations (operand types, operation type, operand indexes,
    // removed writes).
    mutateOperationOperandTypeTest(device, model);
    mutateOperationTypeTest(device, model);
    mutateOperationInputOperandIndexTest(device, model);
    mutateOperationOutputOperandIndexTest(device, model);
    mutateOperationRemoveWriteTest(device, model);
    // Structural removals.
    removeOperandTest(device, model);
    removeOperationTest(device, model);
    removeOperationInputTest(device, model);
    removeOperationOutputTest(device, model);
    // Structural additions.
    addOperationInputTest(device, model);
    addOperationOutputTest(device, model);
}
939 
940 }  // namespace android::hardware::neuralnetworks::V1_0::vts::functional
941