1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.net
18 
19 import android.net.NetworkStats.Entry
20 import com.android.testutils.DevSdkIgnoreRunner
21 import java.time.Clock
22 import java.util.function.Supplier
23 import kotlin.test.assertEquals
24 import kotlin.test.assertNull
25 import kotlin.test.fail
26 import org.junit.Test
27 import org.junit.runner.RunWith
28 import org.mockito.Mockito.doReturn
29 import org.mockito.Mockito.mock
30 import org.mockito.Mockito.verify
31 import org.mockito.Mockito.`when`
32 
/**
 * Unit tests for [TrafficStatsRateLimitCache].
 *
 * Time is controlled through a mocked [Clock]. Until a test stubs `millis()`, the mock
 * answers with Mockito's default for `long` (0), so entries inserted before any stubbing
 * are presumably timestamped at 0 by the cache (the cache implementation is not visible
 * from this file - confirm against TrafficStatsRateLimitCache).
 */
@RunWith(DevSdkIgnoreRunner::class)
class TrafficStatsRateLimitCacheTest {
    companion object {
        // Expiry duration, in milliseconds, handed to the cache under test.
        private const val EXPIRY_DURATION_MS = 1000L

        // Capacity handed to the cache under test; testPut_removeLru relies on it being 2.
        private const val MAX_SIZE = 2
    }

    private val clock = mock(Clock::class.java)
    private val entry = mock(Entry::class.java)
    private val cache = TrafficStatsRateLimitCache(clock, EXPIRY_DURATION_MS, MAX_SIZE)

    /** An entry read back before the expiry duration has elapsed is returned. */
    @Test
    fun testGet_returnsEntryIfNotExpired() {
        cache.put("iface", 2, entry)
        doReturn(500L).`when`(clock).millis() // Set clock to before expiry
        val result = cache.get("iface", 2)
        assertEquals(entry, result)
    }

    /** An entry read back after the expiry duration has elapsed is not returned. */
    @Test
    fun testGet_returnsNullIfExpired() {
        cache.put("iface", 2, entry)
        doReturn(2000L).`when`(clock).millis() // Set clock to after expiry
        assertNull(cache.get("iface", 2))
    }

    /** Looking up a key that was never inserted yields null. */
    @Test
    fun testGet_returnsNullForNonExistentKey() {
        val result = cache.get("otherIface", 99)
        assertNull(result)
    }

    /** Entries stored under different (iface, uid) keys are retrieved independently. */
    @Test
    fun testPutAndGet_retrievesCorrectEntryForDifferentKeys() {
        val entry1 = mock(Entry::class.java)
        val entry2 = mock(Entry::class.java)

        cache.put("iface1", 2, entry1)
        cache.put("iface2", 4, entry2)

        assertEquals(entry1, cache.get("iface1", 2))
        assertEquals(entry2, cache.get("iface2", 4))
    }

    /** A second put with the same key replaces the previously stored entry. */
    @Test
    fun testPut_overridesExistingEntry() {
        val entry1 = mock(Entry::class.java)
        val entry2 = mock(Entry::class.java)

        cache.put("iface", 2, entry1)
        cache.put("iface", 2, entry2) // Put with the same key

        assertEquals(entry2, cache.get("iface", 2))
    }

    /** Inserting more entries than MAX_SIZE drops the eldest one. */
    @Test
    fun testPut_removeLru() {
        // Relies on MAX_SIZE being 2: the third put must push out the first entry.
        val entry1 = mock(Entry::class.java)
        val entry2 = mock(Entry::class.java)
        val entry3 = mock(Entry::class.java)

        cache.put("iface1", 2, entry1)
        cache.put("iface2", 4, entry2)
        cache.put("iface3", 8, entry3)

        assertNull(cache.get("iface1", 2))
        assertEquals(entry2, cache.get("iface2", 4))
        assertEquals(entry3, cache.get("iface3", 8))
    }

    /** On a cache hit, getOrCompute returns the cached entry without invoking the supplier. */
    @Test
    fun testGetOrCompute_cacheHit() {
        val entry1 = mock(Entry::class.java)

        cache.put("iface1", 2, entry1)

        // Set clock to before expiry.
        doReturn(500L).`when`(clock).millis()

        // The supplier fails the test if it is ever invoked.
        val result = cache.getOrCompute("iface1", 2) {
            fail("Supplier should not be called")
        }

        assertEquals(entry1, result) // Should get the cached entry.
    }

    /** On a miss caused by expiry, getOrCompute invokes the supplier and returns its value. */
    @Suppress("UNCHECKED_CAST")
    @Test
    fun testGetOrCompute_cacheMiss() {
        // Two distinct mocks, so a stale cache hit cannot be mistaken for a recomputation.
        val staleEntry = mock(Entry::class.java)
        val computedEntry = mock(Entry::class.java)

        cache.put("iface1", 2, staleEntry)

        // Set clock to after expiry.
        doReturn(1500L).`when`(clock).millis()

        // Mock the supplier to return the freshly computed entry.
        val supplier = mock(Supplier::class.java) as Supplier<Entry>
        doReturn(computedEntry).`when`(supplier).get()

        val result = cache.getOrCompute("iface1", 2, supplier)

        assertEquals(computedEntry, result) // Should get the supplier's entry, not the stale one.
        verify(supplier).get()
    }

    /** After clear(), previously stored entries are no longer retrievable. */
    @Test
    fun testClear() {
        cache.put("iface", 2, entry)
        cache.clear()
        assertNull(cache.get("iface", 2))
    }
}
151