• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "InstanceNormalization.h"
20 
21 #include <cmath>
22 #include <vector>
23 
24 #include "OperationResolver.h"
25 #include "Tracing.h"
26 
27 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
28 #include "CpuOperationUtils.h"
29 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
30 
31 namespace android {
32 namespace nn {
33 namespace instance_normalization {
34 
35 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
36 namespace {
37 
38 template <typename T>
instanceNormNhwc(const T * inputData,const Shape & inputShape,T gamma,T beta,T epsilon,T * outputData,const Shape &)39 inline bool instanceNormNhwc(const T* inputData, const Shape& inputShape, T gamma, T beta,
40                              T epsilon, T* outputData, const Shape& /*outputShape*/) {
41     NNTRACE_TRANS("InstanceNormalizationNhwc");
42     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
43     uint32_t height = getSizeOfDimension(inputShape, 1);
44     uint32_t width = getSizeOfDimension(inputShape, 2);
45     uint32_t depth = getSizeOfDimension(inputShape, 3);
46     for (uint32_t b = 0; b < numBatches; b++) {
47         for (uint32_t d = 0; d < depth; d++) {
48             uint32_t indexBase = b * height * width * depth + d;
49             T mean = 0, sigma = 0;
50 
51             // Compute the mean of a single layer.
52             for (uint32_t h = 0; h < height; h++) {
53                 for (uint32_t w = 0; w < width; w++) {
54                     T val = inputData[indexBase + (h * width + w) * depth];
55                     mean += val;
56                 }
57             }
58             mean /= static_cast<T>(height * width);
59 
60             // Compute the standard deviation (sigma) of a single layer.
61             for (uint32_t h = 0; h < height; h++) {
62                 for (uint32_t w = 0; w < width; w++) {
63                     T val = inputData[indexBase + (h * width + w) * depth] - mean;
64                     sigma += val * val;
65                 }
66             }
67             sigma = std::sqrt(static_cast<float>(sigma / static_cast<T>(height * width)) + epsilon);
68 
69             // Apply instance normalization.
70             for (uint32_t h = 0; h < height; h++) {
71                 for (uint32_t w = 0; w < width; w++) {
72                     uint32_t ind = indexBase + (h * width + w) * depth;
73                     outputData[ind] = (inputData[ind] - mean) * gamma / sigma + beta;
74                 }
75             }
76         }
77     }
78     return true;
79 }
80 
81 template <typename T>
instanceNorm(const T * inputData,const Shape & inputShape,T gamma,T beta,T epsilon,bool useNchw,T * outputData,const Shape & outputShape)82 inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T beta, T epsilon,
83                          bool useNchw, T* outputData, const Shape& outputShape) {
84     InputWithLayout<T> input(useNchw);
85     OutputWithLayout<T> output(useNchw);
86     NN_RET_CHECK(input.initialize(inputData, inputShape));
87     NN_RET_CHECK(output.initialize(outputData, outputShape));
88     NN_RET_CHECK(instanceNormNhwc(input.getNhwcBuffer(), input.getNhwcShape(), gamma, beta, epsilon,
89                                   output.getNhwcBuffer(), output.getNhwcShape()));
90     NN_RET_CHECK(output.commit());
91     return true;
92 }
93 
94 }  // namespace
95 
prepare(IOperationExecutionContext * context)96 bool prepare(IOperationExecutionContext* context) {
97     Shape input = context->getInputShape(kInputTensor);
98     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4u);
99     return context->setOutputShape(kOutputTensor, input);
100 }
101 
execute(IOperationExecutionContext * context)102 bool execute(IOperationExecutionContext* context) {
103     switch (context->getInputType(kInputTensor)) {
104         case OperandType::TENSOR_FLOAT16:
105             return instanceNorm(context->getInputBuffer<_Float16>(kInputTensor),
106                                 context->getInputShape(kInputTensor),
107                                 context->getInputValue<_Float16>(kGammaScalar),
108                                 context->getInputValue<_Float16>(kBetaScalar),
109                                 context->getInputValue<_Float16>(kEpsilonScalar),
110                                 context->getInputValue<bool>(kLayoutScalar),
111                                 context->getOutputBuffer<_Float16>(kOutputTensor),
112                                 context->getOutputShape(kOutputTensor));
113         case OperandType::TENSOR_FLOAT32:
114             return instanceNorm(context->getInputBuffer<float>(kInputTensor),
115                                 context->getInputShape(kInputTensor),
116                                 context->getInputValue<float>(kGammaScalar),
117                                 context->getInputValue<float>(kBetaScalar),
118                                 context->getInputValue<float>(kEpsilonScalar),
119                                 context->getInputValue<bool>(kLayoutScalar),
120                                 context->getOutputBuffer<float>(kOutputTensor),
121                                 context->getOutputShape(kOutputTensor));
122         default:
123             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
124     }
125 }
126 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
127 
128 }  // namespace instance_normalization
129 
130 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(INSTANCE_NORMALIZATION, instance_normalization::prepare,
131                                          instance_normalization::execute);
132 
133 }  // namespace nn
134 }  // namespace android
135