1 // Copyright 2021, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use anyhow::{anyhow, Context, Result};
16 use rusqlite::{params, OptionalExtension, Transaction};
17 
create_or_get_version(tx: &Transaction, current_version: u32) -> Result<u32>18 pub fn create_or_get_version(tx: &Transaction, current_version: u32) -> Result<u32> {
19     tx.execute(
20         "CREATE TABLE IF NOT EXISTS persistent.version (
21                 id INTEGER PRIMARY KEY,
22                 version INTEGER);",
23         [],
24     )
25     .context("In create_or_get_version: Failed to create version table.")?;
26 
27     let version = tx
28         .query_row("SELECT version FROM persistent.version WHERE id = 0;", [], |row| row.get(0))
29         .optional()
30         .context("In create_or_get_version: Failed to read version.")?;
31 
32     let version = if let Some(version) = version {
33         version
34     } else {
35         // If no version table existed it could mean one of two things:
36         // 1) This database is completely new. In this case the version has to be set
37         //    to the current version and the current version which also needs to be
38         //    returned.
39         // 2) The database predates db versioning. In this case the version needs to be
40         //    set to 0, and 0 needs to be returned.
41         let version = if tx
42             .query_row(
43                 "SELECT name FROM persistent.sqlite_master
44                  WHERE type = 'table' AND name = 'keyentry';",
45                 [],
46                 |_| Ok(()),
47             )
48             .optional()
49             .context("In create_or_get_version: Failed to check for keyentry table.")?
50             .is_none()
51         {
52             current_version
53         } else {
54             0
55         };
56 
57         tx.execute("INSERT INTO persistent.version (id, version) VALUES(0, ?);", params![version])
58             .context("In create_or_get_version: Failed to insert initial version.")?;
59         version
60     };
61     Ok(version)
62 }
63 
update_version(tx: &Transaction, new_version: u32) -> Result<()>64 pub fn update_version(tx: &Transaction, new_version: u32) -> Result<()> {
65     let updated = tx
66         .execute("UPDATE persistent.version SET version = ? WHERE id = 0;", params![new_version])
67         .context("In update_version: Failed to update row.")?;
68     if updated == 1 {
69         Ok(())
70     } else {
71         Err(anyhow!("In update_version: No rows were updated."))
72     }
73 }
74 
upgrade_database<F>(tx: &Transaction, current_version: u32, upgraders: &[F]) -> Result<()> where F: Fn(&Transaction) -> Result<u32> + 'static,75 pub fn upgrade_database<F>(tx: &Transaction, current_version: u32, upgraders: &[F]) -> Result<()>
76 where
77     F: Fn(&Transaction) -> Result<u32> + 'static,
78 {
79     if upgraders.len() < current_version as usize {
80         return Err(anyhow!("In upgrade_database: Insufficient upgraders provided."));
81     }
82     let mut db_version = create_or_get_version(tx, current_version)
83         .context("In upgrade_database: Failed to get database version.")?;
84     while db_version < current_version {
85         db_version = upgraders[db_version as usize](tx).with_context(|| {
86             format!("In upgrade_database: Trying to upgrade from db version {}.", db_version)
87         })?;
88     }
89     update_version(tx, db_version).context("In upgrade_database.")
90 }
91 
#[cfg(test)]
mod test {
    //! Tests for database version bookkeeping and the upgrade driver.

    use super::*;
    use rusqlite::{Connection, TransactionBehavior};

    /// Number of upgraders built by `upgrade_database_test`. This is also the highest version
    /// that test can upgrade to.
    const UPGRADER_COUNT: u32 = 30;

    /// Opens an in-memory connection with a second in-memory database attached under the
    /// schema name `persistent`, where the functions under test keep their tables.
    fn open_test_connection() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute("ATTACH DATABASE 'file::memory:' as persistent;", []).unwrap();
        conn
    }

    /// Creates the `persistent.keyentry` table. Its presence, without a version row, marks a
    /// database as predating versioning.
    fn create_legacy_keyentry_table(conn: &Connection) {
        conn.execute(
            "CREATE TABLE IF NOT EXISTS persistent.keyentry (
                id INTEGER UNIQUE,
                key_type INTEGER,
                domain INTEGER,
                namespace INTEGER,
                alias BLOB,
                state INTEGER,
                km_uuid BLOB);",
            [],
        )
        .unwrap();
    }

    /// Runs `create_or_get_version` in its own committed immediate transaction.
    fn create_or_get_version_committed(conn: &mut Connection, current_version: u32) -> u32 {
        let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate).unwrap();
        let version = create_or_get_version(&tx, current_version).unwrap();
        tx.commit().unwrap();
        version
    }

    /// Runs `update_version` in its own committed immediate transaction.
    fn update_version_committed(conn: &mut Connection, new_version: u32) {
        let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate).unwrap();
        update_version(&tx, new_version).unwrap();
        tx.commit().unwrap();
    }

    /// Returns the number of rows in `persistent.version`.
    fn version_row_count(conn: &Connection) -> u32 {
        conn.query_row("SELECT COUNT(id) from persistent.version;", [], |row| row.get(0)).unwrap()
    }

    /// Returns the version stored in row 0 of `persistent.version`.
    fn stored_version(conn: &Connection) -> u32 {
        conn.query_row("SELECT version from persistent.version WHERE id = 0;", [], |row| {
            row.get(0)
        })
        .unwrap()
    }

    /// Asserts that the `persistent.version` table exists.
    fn assert_version_table_exists(conn: &Connection) {
        assert_eq!(
            Ok("version".to_owned()),
            conn.query_row(
                "SELECT name FROM persistent.sqlite_master
                 WHERE type = 'table' AND name = 'version';",
                [],
                |row| row.get(0),
            )
        );
    }

    /// Shared scenario for the `create_or_get_version_*` tests. It checks that:
    /// * the first call reports `expected_initial`;
    /// * later calls keep reporting the stored version whatever `current_version` they pass;
    /// * `update_version` changes the stored version;
    /// * the version table always holds exactly one row.
    fn check_version_lifecycle(conn: &mut Connection, expected_initial: u32) {
        assert_eq!(create_or_get_version_committed(conn, 3), expected_initial);

        // Was the version table created with exactly one row holding the initial version?
        assert_version_table_exists(conn);
        assert_eq!(version_row_count(conn), 1);
        assert_eq!(stored_version(conn), expected_initial);

        // Subsequent calls return the stored version even if the current version changes.
        assert_eq!(create_or_get_version_committed(conn, 5), expected_initial);
        assert_eq!(version_row_count(conn), 1);

        // After bumping the version, the stored value wins over the current version.
        update_version_committed(conn, 5);
        assert_eq!(create_or_get_version_committed(conn, 7), 5);
        assert_eq!(version_row_count(conn), 1);
        assert_eq!(stored_version(conn), 5);
    }

    #[test]
    fn upgrade_database_test() {
        let mut conn = open_test_connection();

        // Upgrader `i` appends one row with `test_field = i + 1`. The rows therefore record
        // which upgraders ran and in which order.
        let upgraders: Vec<_> = (0..UPGRADER_COUNT)
            .map(move |i| {
                move |tx: &Transaction| {
                    tx.execute(
                        "INSERT INTO persistent.test (test_field) VALUES(?);",
                        params![i + 1],
                    )
                    .with_context(|| format!("In upgrade_from_{}_to_{}.", i, i + 1))?;
                    Ok(i + 1)
                }
            })
            .collect();

        for legacy in &[false, true] {
            if *legacy {
                create_legacy_keyentry_table(&conn);
            }
            // Cover every (from, to) pair, including version 0 and the highest version for
            // which upgraders exist.
            for from in 0..=UPGRADER_COUNT {
                for to in from..=UPGRADER_COUNT {
                    conn.execute("DROP TABLE IF EXISTS persistent.version;", []).unwrap();
                    conn.execute("DROP TABLE IF EXISTS persistent.test;", []).unwrap();
                    conn.execute(
                        "CREATE TABLE IF NOT EXISTS persistent.test (
                            id INTEGER PRIMARY KEY,
                            test_field INTEGER);",
                        [],
                    )
                    .unwrap();

                    create_or_get_version_committed(&mut conn, from);
                    {
                        let tx =
                            conn.transaction_with_behavior(TransactionBehavior::Immediate).unwrap();
                        upgrade_database(&tx, to, &upgraders).unwrap();
                        tx.commit().unwrap();
                    }

                    // In the legacy database case, all upgraders starting from 0 have to run.
                    // So after the upgrade step, the expectations need to be adjusted.
                    let from = if *legacy { 0 } else { from };

                    // There must be exactly to - from rows.
                    assert_eq!(
                        to - from,
                        conn.query_row(
                            "SELECT COUNT(test_field) FROM persistent.test;",
                            [],
                            |row| row.get(0)
                        )
                        .unwrap()
                    );
                    // Each row must have the correct relation between id and test_field. If
                    // this is not the case, the upgraders were not executed in the correct order.
                    assert_eq!(
                        to - from,
                        conn.query_row(
                            "SELECT COUNT(test_field) FROM persistent.test
                             WHERE id = test_field - ?;",
                            params![from],
                            |row| row.get(0)
                        )
                        .unwrap()
                    );
                }
            }
        }
    }

    #[test]
    fn upgrade_database_insufficient_upgraders() {
        let mut conn = open_test_connection();
        let no_upgraders: [fn(&Transaction) -> Result<u32>; 0] = [];
        let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate).unwrap();
        // Reaching version 1 needs at least one upgrader; none are supplied.
        let err = upgrade_database(&tx, 1, &no_upgraders).unwrap_err();
        assert!(err.to_string().contains("Insufficient upgraders"));
    }

    #[test]
    fn create_or_get_version_new_database() {
        let mut conn = open_test_connection();
        // A database without a keyentry table is new and starts at the current version.
        check_version_lifecycle(&mut conn, 3);
    }

    #[test]
    fn create_or_get_version_legacy_database() {
        let mut conn = open_test_connection();
        // A legacy (version 0) database is detected if the keyentry table exists but no
        // version row does.
        create_legacy_keyentry_table(&conn);
        check_version_lifecycle(&mut conn, 0);
    }
}
360