1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.internal.util;
18 
19 import android.annotation.NonNull;
20 import android.annotation.Nullable;
21 
22 import java.util.Comparator;
23 import java.util.List;
24 
25 /**
26  * An implementation of the quick selection algorithm as described in
27  * http://en.wikipedia.org/wiki/Quickselect.
28  *
29  * @hide
30  */
31 @android.ravenwood.annotation.RavenwoodKeepWholeClass
32 public final class QuickSelect {
selectImpl(@onNull List<T> list, int left, int right, int k, @NonNull Comparator<? super T> comparator)33     private static <T> int selectImpl(@NonNull List<T> list, int left, int right, int k,
34             @NonNull Comparator<? super T> comparator) {
35         while (true) {
36             if (left == right) {
37                 return left;
38             }
39             final int pivotIndex = partition(list, left, right, (left + right) >> 1, comparator);
40             if (k == pivotIndex) {
41                 return k;
42             } else if (k < pivotIndex) {
43                 right = pivotIndex - 1;
44             } else {
45                 left = pivotIndex + 1;
46             }
47         }
48     }
49 
selectImpl(@onNull int[] array, int left, int right, int k)50     private static int selectImpl(@NonNull int[] array, int left, int right, int k) {
51         while (true) {
52             if (left == right) {
53                 return left;
54             }
55             final int pivotIndex = partition(array, left, right, (left + right) >> 1);
56             if (k == pivotIndex) {
57                 return k;
58             } else if (k < pivotIndex) {
59                 right = pivotIndex - 1;
60             } else {
61                 left = pivotIndex + 1;
62             }
63         }
64     }
65 
selectImpl(@onNull long[] array, int left, int right, int k)66     private static int selectImpl(@NonNull long[] array, int left, int right, int k) {
67         while (true) {
68             if (left == right) {
69                 return left;
70             }
71             final int pivotIndex = partition(array, left, right, (left + right) >> 1);
72             if (k == pivotIndex) {
73                 return k;
74             } else if (k < pivotIndex) {
75                 right = pivotIndex - 1;
76             } else {
77                 left = pivotIndex + 1;
78             }
79         }
80     }
81 
selectImpl(@onNull T[] array, int left, int right, int k, @NonNull Comparator<? super T> comparator)82     private static <T> int selectImpl(@NonNull T[] array, int left, int right, int k,
83             @NonNull Comparator<? super T> comparator) {
84         while (true) {
85             if (left == right) {
86                 return left;
87             }
88             final int pivotIndex = partition(array, left, right, (left + right) >> 1, comparator);
89             if (k == pivotIndex) {
90                 return k;
91             } else if (k < pivotIndex) {
92                 right = pivotIndex - 1;
93             } else {
94                 left = pivotIndex + 1;
95             }
96         }
97     }
98 
partition(@onNull List<T> list, int left, int right, int pivotIndex, @NonNull Comparator<? super T> comparator)99     private static <T> int partition(@NonNull List<T> list, int left, int right, int pivotIndex,
100             @NonNull Comparator<? super T> comparator) {
101         final T pivotValue = list.get(pivotIndex);
102         swap(list, right, pivotIndex);
103         int storeIndex = left;
104         for (int i = left; i < right; i++) {
105             if (comparator.compare(list.get(i), pivotValue) < 0) {
106                 swap(list, storeIndex, i);
107                 storeIndex++;
108             }
109         }
110         swap(list, right, storeIndex);
111         return storeIndex;
112     }
113 
partition(@onNull int[] array, int left, int right, int pivotIndex)114     private static int partition(@NonNull int[] array, int left, int right, int pivotIndex) {
115         final int pivotValue = array[pivotIndex];
116         swap(array, right, pivotIndex);
117         int storeIndex = left;
118         for (int i = left; i < right; i++) {
119             if (array[i] < pivotValue) {
120                 swap(array, storeIndex, i);
121                 storeIndex++;
122             }
123         }
124         swap(array, right, storeIndex);
125         return storeIndex;
126     }
127 
partition(@onNull long[] array, int left, int right, int pivotIndex)128     private static int partition(@NonNull long[] array, int left, int right, int pivotIndex) {
129         final long pivotValue = array[pivotIndex];
130         swap(array, right, pivotIndex);
131         int storeIndex = left;
132         for (int i = left; i < right; i++) {
133             if (array[i] < pivotValue) {
134                 swap(array, storeIndex, i);
135                 storeIndex++;
136             }
137         }
138         swap(array, right, storeIndex);
139         return storeIndex;
140     }
141 
partition(@onNull T[] array, int left, int right, int pivotIndex, @NonNull Comparator<? super T> comparator)142     private static <T> int partition(@NonNull T[] array, int left, int right, int pivotIndex,
143             @NonNull Comparator<? super T> comparator) {
144         final T pivotValue = array[pivotIndex];
145         swap(array, right, pivotIndex);
146         int storeIndex = left;
147         for (int i = left; i < right; i++) {
148             if (comparator.compare(array[i], pivotValue) < 0) {
149                 swap(array, storeIndex, i);
150                 storeIndex++;
151             }
152         }
153         swap(array, right, storeIndex);
154         return storeIndex;
155     }
156 
swap(@onNull List<T> list, int left, int right)157     private static <T> void swap(@NonNull List<T> list, int left, int right) {
158         final T tmp = list.get(left);
159         list.set(left, list.get(right));
160         list.set(right, tmp);
161     }
162 
swap(@onNull int[] array, int left, int right)163     private static void swap(@NonNull int[] array, int left, int right) {
164         final int tmp = array[left];
165         array[left] = array[right];
166         array[right] = tmp;
167     }
168 
swap(@onNull long[] array, int left, int right)169     private static void swap(@NonNull long[] array, int left, int right) {
170         final long tmp = array[left];
171         array[left] = array[right];
172         array[right] = tmp;
173     }
174 
swap(@onNull T[] array, int left, int right)175     private static <T> void swap(@NonNull T[] array, int left, int right) {
176         final T tmp = array[left];
177         array[left] = array[right];
178         array[right] = tmp;
179     }
180 
181     /**
182      * Return the kth(0-based) smallest element from the given unsorted list.
183      *
184      * @param list The input list, it <b>will</b> be modified by the algorithm here.
185      * @param start The start offset of the list, inclusive.
186      * @param length The length of the sub list to be searched in.
187      * @param k The 0-based index.
188      * @param comparator The comparator which knows how to compare the elements in the list.
189      * @return The kth smallest element from the given list,
190      *         or IllegalArgumentException will be thrown if not found.
191      */
192     @Nullable
select(@onNull List<T> list, int start, int length, int k, @NonNull Comparator<? super T> comparator)193     public static <T> T select(@NonNull List<T> list, int start, int length, int k,
194             @NonNull Comparator<? super T> comparator) {
195         if (list == null || start < 0 || length <= 0 || list.size() < start + length
196                 || k < 0 || length <= k) {
197             throw new IllegalArgumentException();
198         }
199         return list.get(selectImpl(list, start, start + length - 1, k + start, comparator));
200     }
201 
202     /**
203      * Return the kth(0-based) smallest element from the given unsorted array.
204      *
205      * @param array The input array, it <b>will</b> be modified by the algorithm here.
206      * @param start The start offset of the array, inclusive.
207      * @param length The length of the sub array to be searched in.
208      * @param k The 0-based index to search for.
209      * @return The kth smallest element from the given array,
210      *         or IllegalArgumentException will be thrown if not found.
211      */
select(@onNull int[] array, int start, int length, int k)212     public static int select(@NonNull int[] array, int start, int length, int k) {
213         if (array == null || start < 0 || length <= 0 || array.length < start + length
214                 || k < 0 || length <= k) {
215             throw new IllegalArgumentException();
216         }
217         return array[selectImpl(array, start, start + length - 1, k + start)];
218     }
219 
220     /**
221      * Return the kth(0-based) smallest element from the given unsorted array.
222      *
223      * @param array The input array, it <b>will</b> be modified by the algorithm here.
224      * @param start The start offset of the array, inclusive.
225      * @param length The length of the sub array to be searched in.
226      * @param k The 0-based index to search for.
227      * @return The kth smallest element from the given array,
228      *         or IllegalArgumentException will be thrown if not found.
229      */
select(@onNull long[] array, int start, int length, int k)230     public static long select(@NonNull long[] array, int start, int length, int k) {
231         if (array == null || start < 0 || length <= 0 || array.length < start + length
232                 || k < 0 || length <= k) {
233             throw new IllegalArgumentException();
234         }
235         return array[selectImpl(array, start, start + length - 1, k + start)];
236     }
237 
238     /**
239      * Return the kth(0-based) smallest element from the given unsorted array.
240      *
241      * @param array The input array, it <b>will</b> be modified by the algorithm here.
242      * @param start The start offset of the array, inclusive.
243      * @param length The length of the sub array to be searched in.
244      * @param k The 0-based index to search for.
245      * @param comparator The comparator which knows how to compare the elements in the list.
246      * @return The kth smallest element from the given array,
247      *         or IllegalArgumentException will be thrown if not found.
248      */
select(@onNull T[] array, int start, int length, int k, @NonNull Comparator<? super T> comparator)249     public static <T> T select(@NonNull T[] array, int start, int length, int k,
250             @NonNull Comparator<? super T> comparator) {
251         if (array == null || start < 0 || length <= 0 || array.length < start + length
252                 || k < 0 || length <= k) {
253             throw new IllegalArgumentException();
254         }
255         return array[selectImpl(array, start, start + length - 1, k + start, comparator)];
256     }
257 }
258