/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.net

import android.net.NetworkStats.Entry
import com.android.testutils.DevSdkIgnoreRunner
import java.time.Clock
import java.util.function.Supplier
import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.fail
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.verify
import org.mockito.Mockito.`when`

/**
 * Unit tests for [TrafficStatsRateLimitCache].
 *
 * The cache under test is built with a mocked [Clock], an expiry duration of
 * [EXPIRY_DURATION_MS] and a maximum size of [MAX_SIZE]. Tests move time forward by
 * stubbing `clock.millis()` after entries have been inserted.
 */
@RunWith(DevSdkIgnoreRunner::class)
class TrafficStatsRateLimitCacheTest {
    companion object {
        private const val EXPIRY_DURATION_MS = 1000L
        private const val MAX_SIZE = 2
    }

    // An unstubbed Mockito mock returns 0 from millis(), so entries inserted before any
    // stubbing are presumably timestamped at 0 by the cache.
    private val clock = mock(Clock::class.java)
    private val entry = mock(Entry::class.java)
    private val cache = TrafficStatsRateLimitCache(clock, EXPIRY_DURATION_MS, MAX_SIZE)

    /** An entry read back before the expiry duration has elapsed is returned. */
    @Test
    fun testGet_returnsEntryIfNotExpired() {
        cache.put("iface", 2, entry)
        doReturn(500L).`when`(clock).millis() // Set clock to before expiry
        val result = cache.get("iface", 2)
        assertEquals(entry, result)
    }

    /** An entry read back after the expiry duration has elapsed is not returned. */
    @Test
    fun testGet_returnsNullIfExpired() {
        cache.put("iface", 2, entry)
        doReturn(2000L).`when`(clock).millis() // Set clock to after expiry
        assertNull(cache.get("iface", 2))
    }

    /** Looking up a key that was never inserted yields null. */
    @Test
    fun testGet_returnsNullForNonExistentKey() {
        val result = cache.get("otherIface", 99)
        assertNull(result)
    }

    /** Entries stored under different keys are retrieved independently. */
    @Test
    fun testPutAndGet_retrievesCorrectEntryForDifferentKeys() {
        val entry1 = mock(Entry::class.java)
        val entry2 = mock(Entry::class.java)

        cache.put("iface1", 2, entry1)
        cache.put("iface2", 4, entry2)

        assertEquals(entry1, cache.get("iface1", 2))
        assertEquals(entry2, cache.get("iface2", 4))
    }

    /** A second put with the same key replaces the previously stored entry. */
    @Test
    fun testPut_overridesExistingEntry() {
        val entry1 = mock(Entry::class.java)
        val entry2 = mock(Entry::class.java)

        cache.put("iface", 2, entry1)
        cache.put("iface", 2, entry2) // Put with the same key

        assertEquals(entry2, cache.get("iface", 2))
    }

    /** Inserting more than [MAX_SIZE] entries evicts the eldest one. */
    @Test
    fun testPut_removeLru() {
        // Relies on MAX_SIZE being 2: the third put must push out the first entry.
        val entry1 = mock(Entry::class.java)
        val entry2 = mock(Entry::class.java)
        val entry3 = mock(Entry::class.java)

        cache.put("iface1", 2, entry1)
        cache.put("iface2", 4, entry2)
        cache.put("iface3", 8, entry3)

        assertNull(cache.get("iface1", 2))
        assertEquals(entry2, cache.get("iface2", 4))
        assertEquals(entry3, cache.get("iface3", 8))
    }

    /** On a cache hit, getOrCompute returns the cached entry without invoking the supplier. */
    @Test
    fun testGetOrCompute_cacheHit() {
        val entry1 = mock(Entry::class.java)

        cache.put("iface1", 2, entry1)

        // Set clock to before expiry.
        doReturn(500L).`when`(clock).millis()

        // The supplier fails the test if it is ever invoked.
        val result = cache.getOrCompute("iface1", 2) {
            fail("Supplier should not be called")
        }

        assertEquals(entry1, result) // Should get the cached entry.
    }

    /** On a cache miss (expired entry), getOrCompute returns the supplier's value. */
    @Suppress("UNCHECKED_CAST")
    @Test
    fun testGetOrCompute_cacheMiss() {
        // Use distinct mocks for the stale and the computed entry so that the value
        // assertion can tell a (wrong) stale hit apart from a fresh computation.
        val staleEntry = mock(Entry::class.java)
        val computedEntry = mock(Entry::class.java)

        cache.put("iface1", 2, staleEntry)

        // Set clock to after expiry.
        doReturn(1500L).`when`(clock).millis()

        // Mock the supplier to return the freshly computed network stats entry.
        val supplier = mock(Supplier::class.java) as Supplier<Entry>
        doReturn(computedEntry).`when`(supplier).get()

        val result = cache.getOrCompute("iface1", 2, supplier)

        assertEquals(computedEntry, result) // Should get the computed entry, not the stale one.
        verify(supplier).get()
    }

    /** After clear(), previously stored entries are no longer returned. */
    @Test
    fun testClear() {
        cache.put("iface", 2, entry)
        cache.clear()
        assertNull(cache.get("iface", 2))
    }
}