#
# Copyright (C) 2021 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Test spec for the "REVERSE" operation: each case reverses a TENSOR_FLOAT32
# tensor along one axis and registers float16 / quant8_asymm /
# quant8_asymm_signed / int32 variations of the same example.
#
# NOTE(review): Model, Input, Output, Example and DataTypeConverter are not
# defined or imported in this file; presumably they are supplied by the test
# generator that executes this spec — confirm before running standalone.

from itertools import chain


def test(name, axis_value, input_tensor, output_tensor, input_data, output_data):
    """Register one REVERSE example plus its data-type variations.

    Args:
        name: Example name passed to ``Example``.
        axis_value: Axis to reverse; passed to the operation wrapped in a
            one-element list.
        input_tensor: Operand fed to REVERSE.
        output_tensor: Operand receiving the REVERSE result.
        input_data: Values for ``input_tensor``.
        output_data: Expected values for ``output_tensor``.
    """
    reverse_model = Model().Operation("REVERSE", input_tensor, [axis_value]).To(output_tensor)

    def _same_type_for_both(converter_name, tensor_type):
        # Both operands get the same converted type (input first, then output).
        return DataTypeConverter(name=converter_name).Identify(
            {tensor: tensor_type for tensor in (input_tensor, output_tensor)})

    # NOTE(review): the tuples look like (type, scale, zeroPoint); with these
    # values the integer test data 0..23 stays in range for both 8-bit types.
    to_quant8_asymm = _same_type_for_both(
        "quant8_asymm", ("TENSOR_QUANT8_ASYMM", 0.5, 4))
    to_quant8_asymm_signed = _same_type_for_both(
        "quant8_asymm_signed", ("TENSOR_QUANT8_ASYMM_SIGNED", 0.25, -9))

    example = Example(
        {input_tensor: input_data, output_tensor: output_data},
        model=reverse_model,
        name=name,
    )
    example.AddVariations("float16", to_quant8_asymm, to_quant8_asymm_signed, "int32")


def rrange(hi, lo):
    """Return a range counting down from ``hi`` to, but excluding, ``lo``.

    For example, ``list(rrange(3, -1))`` is ``[3, 2, 1, 0]``.
    """
    return range(hi, lo, -1)


# 1-D: the whole vector is reversed.
test(
    name="dim1",
    axis_value=0,
    input_tensor=Input("in", ("TENSOR_FLOAT32", [3])),
    output_tensor=Output("out", ("TENSOR_FLOAT32", [3])),
    input_data=[6, 7, 8],
    output_data=[8, 7, 6],
)

# [2,3,4], axis 0: the two 12-element blocks swap places.
test(
    name="dim3_axis0",
    axis_value=0,
    input_tensor=Input("in", ("TENSOR_FLOAT32", [2, 3, 4])),
    output_tensor=Output("out", ("TENSOR_FLOAT32", [2, 3, 4])),
    input_data=list(range(24)),
    output_data=list(range(12, 24)) + list(range(12)),
)

# [2,3,4], axis 1: within each 12-element block, the three 4-element rows
# appear in reverse order (row offsets 8, 4, 0).
test(
    name="dim3_axis1",
    axis_value=1,
    input_tensor=Input("in", ("TENSOR_FLOAT32", [2, 3, 4])),
    output_tensor=Output("out", ("TENSOR_FLOAT32", [2, 3, 4])),
    input_data=list(range(24)),
    output_data=[value
                 for block_start in (0, 12)
                 for row_start in (8, 4, 0)
                 for value in range(block_start + row_start, block_start + row_start + 4)],
)

# [2,3,4], axis 2: every 4-element row is reversed in place
# (3..0, 7..4, ..., 23..20).
test(
    name="dim3_axis2",
    axis_value=2,
    input_tensor=Input("in", ("TENSOR_FLOAT32", [2, 3, 4])),
    output_tensor=Output("out", ("TENSOR_FLOAT32", [2, 3, 4])),
    input_data=list(range(24)),
    output_data=list(chain.from_iterable(
        rrange(row_last, row_last - 4) for row_last in range(3, 24, 4))),
)