1#
2# Copyright (C) 2020 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
# Model: given n, produces [fib(n), fib(n + 1)], where fib(1) = fib(2) = 1.
#
# fib = [1, 1]
# i = 1
# while i < n:
#     fib = matmul(fib, [0 1;
#                        1 1])
#     i = i + 1
#
# The matmul is expressed with FULLY_CONNECTED inside the WHILE body subgraph.
25
# Float32 loop state: a 1x2 row vector holding [fib(k), fib(k + 1)].
FibType = ["TENSOR_FLOAT32", [1, 2]]
# Quantized replacements for FibType, applied through DataTypeConverter.
# Only type, scale and zero point are listed; the shape is presumably taken
# from the operand being converted (TODO confirm against the test generator).
# Scale 1.0 with zero point 0 makes quantized values equal the real
# Fibonacci numbers.
FibTypeQuant8 = ["TENSOR_QUANT8_ASYMM", 1.0, 0]
FibTypeQuant8Signed = ["TENSOR_QUANT8_ASYMM_SIGNED", 1.0, 0]
# Loop counter and loop bound: a single int32 element.
CounterType = ["TENSOR_INT32", [1]]
# Result of the condition subgraph: a single boolean element.
BoolType = ["TENSOR_BOOL8", [1]]
31
def MakeConditionModel():
  """Build the WHILE condition subgraph, which computes out = (i < n).

  The subgraph takes (fib, i, n), the same operand order the WHILE operation
  in Test() passes. Only i and n are read; fib is part of the signature but
  unused by the LESS operation.

  Returns:
    A tuple (subgraph reference, quant8 converter, quant8-signed converter).
  """
  fib_state = Input("fib", FibType)
  counter = Input("i", CounterType)
  limit = Input("n", CounterType)
  keep_going = Output("out", BoolType)

  cond_model = Model()
  cond_model.IdentifyInputs(fib_state, counter, limit)
  cond_model.IdentifyOutputs(keep_going)
  cond_model.Operation("LESS", counter, limit).To(keep_going)

  # Only the fib operand is retyped; the int32 counters and the bool8 result
  # keep their original types in both quantized variations.
  to_quant8, to_quant8_signed = [
      DataTypeConverter().Identify({fib_state: target_type})
      for target_type in (FibTypeQuant8, FibTypeQuant8Signed)
  ]

  subgraph = SubgraphReference("cond", cond_model)
  return subgraph, to_quant8, to_quant8_signed
50
def MakeBodyModel():
  """Build the WHILE body subgraph: one Fibonacci step plus a counter bump.

  Computes
    i_out   = i + 1
    fib_out = fib x [[0, 1], [1, 1]]   (FULLY_CONNECTED with a zero bias)
  The "n" input is part of the signature but is not read by any operation.

  Returns:
    A tuple (subgraph reference, quant8 converter, quant8-signed converter).
  """
  fib_state = Input("fib", FibType)
  counter = Input("i", CounterType)
  limit = Input("n", CounterType)
  next_fib = Output("fib_out", FibType)
  next_counter = Output("i_out", CounterType)
  # FULLY_CONNECTED multiplies by the transposed weights. This matrix is
  # symmetric, so the transpose yields the same step matrix.
  step_matrix = Parameter("matrix", ["TENSOR_FLOAT32", [2, 2]], [0, 1, 1, 1])
  bias = Parameter("zero_bias", ["TENSOR_FLOAT32", [2]], [0, 0])

  body_model = Model()
  body_model.IdentifyInputs(fib_state, counter, limit)
  body_model.IdentifyOutputs(next_fib, next_counter)
  # The trailing 0 on both operations is the fused activation code (none).
  body_model.Operation("ADD", counter, [1], 0).To(next_counter)
  body_model.Operation(
      "FULLY_CONNECTED", fib_state, step_matrix, bias, 0).To(next_fib)

  def converter_for(fib_type):
    # Quantized FULLY_CONNECTED needs an INT32 bias with zero point 0 and
    # scale = input_scale * weights_scale, i.e. 1.0 * 1.0 for these types.
    return DataTypeConverter().Identify({
        fib_state: fib_type,
        step_matrix: fib_type,
        bias: ["TENSOR_INT32", 1.0, 0],
        next_fib: fib_type,
    })

  to_quant8 = converter_for(FibTypeQuant8)
  to_quant8_signed = converter_for(FibTypeQuant8Signed)
  return SubgraphReference("body", body_model), to_quant8, to_quant8_signed
79
def Test(n_data, fib_data, add_unused_output=False):
  """Register one WHILE-loop Fibonacci example.

  A WHILE operation built from MakeConditionModel() and MakeBodyModel() starts
  from fib = [1, 1] and counter = 1 and loops while counter < n. The example
  expects "fib_out" to equal fib_data.

  Args:
    n_data: value supplied for the int32 input "n" (the loop bound).
    fib_data: expected two-element value of "fib_out".
    add_unused_output: when True, the WHILE operation also writes the loop
      counter into an Internal operand that nothing reads. The example name
      then gets an "_unused_output" suffix.
  """
  loop_bound = Input("n", CounterType)
  result = Output("fib_out", FibType)
  cond, cond_quant8, cond_quant8_signed = MakeConditionModel()
  body, body_quant8, body_quant8_signed = MakeBodyModel()
  fib_init = Parameter("fib_init", FibType, [1, 1])

  while_outputs = [result]
  if add_unused_output:
    while_outputs.append(Internal("i_out", CounterType))  # Never consumed.

  # Operands after the two subgraphs: initial fib state, initial counter [1],
  # and the loop bound.
  Model().Operation(
      "WHILE", cond, body, fib_init, [1], loop_bound).To(while_outputs)

  def converter_for(fib_type, cond_converter, body_converter):
    # Retype the main model's fib operands and attach the matching
    # conversions for both subgraphs.
    return DataTypeConverter().Identify({
        fib_init: fib_type,
        result: fib_type,
        cond: cond_converter,
        body: body_converter,
    })

  quant8 = converter_for(FibTypeQuant8, cond_quant8, body_quant8)
  quant8_signed = converter_for(
      FibTypeQuant8Signed, cond_quant8_signed, body_quant8_signed)

  suffix = "_unused_output" if add_unused_output else ""
  example = Example({loop_bound: [n_data], result: fib_data},
                    name="n_{}{}".format(n_data, suffix))
  example.AddVariations("relaxed", "float16", quant8, quant8_signed)
  example.AddVariations(AllOutputsAsInternalCoverter())
112
# Register every example under both values of
# Configuration.use_shm_for_weights (False first, then True). Judging by the
# name, True places constant weights in shared memory; confirm in the test
# generator's Configuration.
# NOTE(review): this global flag is left set to True after the loop.
for use_shm_for_weights in [False, True]:
  Configuration.use_shm_for_weights = use_shm_for_weights
  # Fibonacci numbers: 1 1 2 3 5 8
  # Each case expects [fib(n), fib(n + 1)]. The first three cases also give
  # the WHILE operation an extra, unconsumed loop-counter output.
  Test(n_data=1, fib_data=[1, 1], add_unused_output=True)
  Test(n_data=2, fib_data=[1, 2], add_unused_output=True)
  Test(n_data=3, fib_data=[2, 3], add_unused_output=True)
  Test(n_data=4, fib_data=[3, 5])
  Test(n_data=5, fib_data=[5, 8])
121