/* * Copyright (C) 2022 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define LOG_TAG "ModelUtils" #include "ModelUtils.h" #include #include #include #include #include #include #include "nnapi/TypeUtils.h" #include "nnapi/Types.h" #include "nnapi/Validation.h" namespace android::nn { namespace { // Map each `true` value in `includes` with a unique integer. `false` values are ignored. E.g.: // includes = {false, true, true, false, true} // returned = { X, 0, 1, X, 2} std::vector getMapping(const std::vector& includes) { std::vector mapping; mapping.reserve(includes.size()); std::transform_exclusive_scan(includes.begin(), includes.end(), std::back_inserter(mapping), 0u, std::plus<>{}, [](bool included) { return included ? 1u : 0u; }); return mapping; } // Remap indexes in `indexes` by the mapping `mapping`. // Precondition: indexes != nullptr void remapIndexes(std::vector* indexes, const std::vector& mapping) { CHECK(indexes != nullptr); for (uint32_t& index : (*indexes)) { index = mapping.at(index); } } // Keep elements from `elements` specified by `elementsToKeep`, removing all other elements. // Precondition: elements != nullptr // Precondition: elements->size() == elementsToKeep.size() template void keepSelectedElements(std::vector* elements, const std::vector& elementsToKeep) { CHECK(elements != nullptr); CHECK_EQ(elements->size(), elementsToKeep.size()); size_t elementsCopied = 0; for (size_t i = 0; i < elementsToKeep.size(); ++i) { if (elementsToKeep[i]) { if (elementsCopied != i) { (*elements)[elementsCopied] = std::move((*elements)[i]); } elementsCopied++; } } elements->resize(elementsCopied); } // Find which operands in model.main.operands are read or written by model.main.operations and // model.main.inputIndexes. // Postcondition: returned.size() == model.main.operands.size() std::vector identifyUsedOperands(const Model& model) { std::vector used(model.main.operands.size(), false); auto markUsed = [&used](const std::vector& indexes) { std::for_each(indexes.begin(), indexes.end(), [&used](uint32_t index) { used.at(index) = true; }); }; for (const auto& operation : model.main.operations) { markUsed(operation.inputs); markUsed(operation.outputs); } markUsed(model.main.inputIndexes); CHECK_EQ(used.size(), model.main.operands.size()); return used; } // Forward declaration. void identifyUsedSubgraphs(uint32_t current, const std::vector& subgraphs, std::vector* used); // Helper function to find which subgraphs are reachable by `operands`. // Precondition: used != nullptr // Precondition: subgraphs.size() == used->size() void identifyUsedSubgraphs(const std::vector& operands, const std::vector& subgraphs, std::vector* used) { for (const auto& operand : operands) { if (operand.lifetime == Operand::LifeTime::SUBGRAPH) { identifyUsedSubgraphs(operand.location.offset, subgraphs, used); } } } // Helper function to find which subgraphs are reachable by the subgraph at the `current` index, and // store when a subgraph is used in `used`. `used` also acts as a cache, ensuring each subgraph is // processed at most once. // Precondition: used != nullptr // Precondition: subgraphs.size() == used->size() // Precondition: current < subgraphs.size() void identifyUsedSubgraphs(uint32_t current, const std::vector& subgraphs, std::vector* used) { CHECK(used != nullptr); CHECK_EQ(subgraphs.size(), used->size()); CHECK_LT(current, subgraphs.size()); // If a subgraph was already marked as used, quickly return to avoid redundant processing. if ((*used)[current]) { return; } // Mark the current subgraph as used, then process any subgraph it references recursively. (*used)[current] = true; identifyUsedSubgraphs(subgraphs[current].operands, subgraphs, used); } // Find which subgraphs are reachable by the main operands of `model`. // Postcondition: returned.size() == model.referenced.size() std::vector identifyUsedSubgraphs(const Model& model) { std::vector used(model.referenced.size(), false); identifyUsedSubgraphs(model.main.operands, model.referenced, &used); CHECK_EQ(used.size(), model.referenced.size()); return used; } // Helper function to find which pools are used by `subgraph`, and store when a pool is used in // `used`. // Precondition: used != nullptr void identifyUsedPools(const Model::Subgraph& subgraph, std::vector* used) { CHECK(used != nullptr); for (const auto& operand : subgraph.operands) { if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE) { used->at(operand.location.poolIndex) = true; } } } // Find which pools are used by `model`. // Postcondition: returned.size() == model.pools.size() std::vector identifyUsedPools(const Model& model) { std::vector used(model.pools.size(), false); identifyUsedPools(model.main, &used); for (const auto& subgraph : model.referenced) { identifyUsedPools(subgraph, &used); } CHECK_EQ(used.size(), model.pools.size()); return used; } // Fix the DataLocation in `operand` by either remapping an index or by copying constant data. // Precondition: operand != nullptr // Precondition: newOperandValues != nullptr void fixOperandDataLocation(Operand* operand, Model::OperandValues* newOperandValues, const Model::OperandValues& oldOperandValues, const std::vector& remappedPoolIndex, const std::vector& remappedSubgraphIndex) { CHECK(operand != nullptr); CHECK(newOperandValues != nullptr); switch (operand->lifetime) { case Operand::LifeTime::CONSTANT_COPY: { const uint8_t* data = oldOperandValues.data() + operand->location.offset; const uint32_t length = operand->location.length; operand->location = newOperandValues->append(data, length); break; } case Operand::LifeTime::CONSTANT_REFERENCE: operand->location.poolIndex = remappedPoolIndex.at(operand->location.poolIndex); break; case Operand::LifeTime::SUBGRAPH: { uint32_t& subgraphIndex = operand->location.offset; subgraphIndex = remappedSubgraphIndex.at(subgraphIndex); break; } case Operand::LifeTime::TEMPORARY_VARIABLE: case Operand::LifeTime::SUBGRAPH_INPUT: case Operand::LifeTime::SUBGRAPH_OUTPUT: case Operand::LifeTime::NO_VALUE: case Operand::LifeTime::POINTER: break; } } // Fix all DataLocations in `operands` by either remapping an index or by copying constant data. // Precondition: operands != nullptr // Precondition: newOperandValues != nullptr void fixOperandDataLocations(std::vector* operands, Model::OperandValues* newOperandValues, const Model::OperandValues& oldOperandValues, const std::vector& remappedPoolIndex, const std::vector& remappedSubgraphIndex) { for (Operand& operand : (*operands)) { fixOperandDataLocation(&operand, newOperandValues, oldOperandValues, remappedPoolIndex, remappedSubgraphIndex); } } // Fix all operands' DataLocations in `model` by either remapping an index or by copying constant // data. // Precondition: model != nullptr void fixOperandDataLocations(Model* model, const std::vector& remappedPoolIndex, const std::vector& remappedSubgraphIndex) { const auto operandValues = std::exchange(model->operandValues, Model::OperandValues{}); fixOperandDataLocations(&model->main.operands, &model->operandValues, operandValues, remappedPoolIndex, remappedSubgraphIndex); for (auto& subgraph : model->referenced) { fixOperandDataLocations(&subgraph.operands, &model->operandValues, operandValues, remappedPoolIndex, remappedSubgraphIndex); } } // Find which extensions are used in `model`. // Postcondition: returned.size() == model.extensionNameToPrefix.size() std::vector identifyUsedExtensions(const Model& model) { std::unordered_set prefixes; const auto collectPrefix = [&prefixes](const auto& operandOrOperation) { const auto prefix = getExtensionPrefix(static_cast(operandOrOperation.type)); constexpr uint16_t kStandardPrefix = 0u; if (prefix != kStandardPrefix) { prefixes.insert(prefix); } }; const auto collectPrefixes = [collectPrefix](const Model::Subgraph& subgraph) { std::for_each(subgraph.operands.begin(), subgraph.operands.end(), collectPrefix); std::for_each(subgraph.operations.begin(), subgraph.operations.end(), collectPrefix); }; collectPrefixes(model.main); for (const auto& subgraph : model.referenced) { collectPrefixes(subgraph); } std::vector used; used.reserve(model.extensionNameToPrefix.size()); for (const auto& extension : model.extensionNameToPrefix) { used.push_back(prefixes.count(extension.prefix) > 0); } CHECK_EQ(used.size(), model.extensionNameToPrefix.size()); return used; } } // anonymous namespace void removeDeadOperands(Model* model) { CHECK(model != nullptr); // Keep only the operands which are used. const auto operandsUsed = identifyUsedOperands(*model); keepSelectedElements(&model->main.operands, operandsUsed); // Fix operand indexes. const auto mappedOperandIndices = getMapping(operandsUsed); for (auto& operation : model->main.operations) { remapIndexes(&operation.inputs, mappedOperandIndices); remapIndexes(&operation.outputs, mappedOperandIndices); } remapIndexes(&model->main.inputIndexes, mappedOperandIndices); remapIndexes(&model->main.outputIndexes, mappedOperandIndices); // Keep only the subgraphs which are used. const auto subgraphsUsed = identifyUsedSubgraphs(*model); keepSelectedElements(&model->referenced, subgraphsUsed); // Keep only the pools which are used. const auto poolsUsed = identifyUsedPools(*model); keepSelectedElements(&model->pools, poolsUsed); // Fix operand locations. const auto mappedPoolIndices = getMapping(poolsUsed); const auto mappedSubgraphIndices = getMapping(subgraphsUsed); fixOperandDataLocations(model, mappedPoolIndices, mappedSubgraphIndices); // Keep only the extensionNameToPrefixes which are used. const auto extensionsUsed = identifyUsedExtensions(*model); keepSelectedElements(&model->extensionNameToPrefix, extensionsUsed); } } // namespace android::nn