diff --git a/src/commands/gen.rs b/src/commands/gen.rs index ce6ab9b..fac9b76 100644 --- a/src/commands/gen.rs +++ b/src/commands/gen.rs @@ -1,16 +1,14 @@ -use crate::{config::Config, result::Result, term::pick_one, totp_store::TotpStore}; +use crate::{config::Config, result::{Error, Result}, term::pick_one, totp_store::TotpStore}; pub fn run( config: Config, service: &str, account: Option<&str> ) -> Result<()> { - let mut store = TotpStore::with_tpm(config)?; - let alternatives = store.list(Some(service), account)?; + let alternatives = TotpStore::without_tpm(config.clone()).list(Some(service), account)?; if alternatives.is_empty() { - println!("service/account combination not found"); - return Ok(()) + return Err(Error::SecretNotFound); } if let Some(alt) = pick_one( @@ -19,9 +17,77 @@ pub fn run( "found multiple matches for the given service/account combination", alternatives.iter() ) { - let code = store.gen(alt.id, std::time::SystemTime::now())?; + let code = TotpStore::with_tpm(config)?.gen(alt.id, std::time::SystemTime::now())?; println!("{}", code); + Ok(()) + } else { + Err(Error::AmbiguousSecret) } +} - Ok(()) -} \ No newline at end of file +#[cfg(test)] +mod tests { + use serial_test::serial; + use tempfile::{tempdir, TempDir}; + use testutil::tpm::SwTpm; + + use crate::presence_verification::PresenceVerificationMethod; + use crate::tpm::Error::PresenceVerificationFailed; + use crate::totp_store::Error::TpmError; + + use super::*; + + #[test] + fn gen_succeeds_on_unambiguous_secret() { + let (_tpm, _dir, cfg) = setup(); + TotpStore::init(cfg.clone()).unwrap(); + let mut store = TotpStore::with_tpm(cfg.clone()).unwrap(); + store.add("foo", "bar", 6, 30, &[0,0,0,0,0,0,0,0,0,0]).unwrap(); + run(cfg, "foo", None).unwrap(); + } + + #[test] + fn gen_fails_on_secret_not_found() { + let (_tpm, _dir, cfg) = setup(); + TotpStore::init(cfg.clone()).unwrap(); + match run(cfg, "foo", None).unwrap_err() { + crate::result::Error::SecretNotFound => {}, + err => panic!("wrong error: {:#?}", err), + } + } + + #[test] + #[serial] + fn presence_verification_happens_after_disambiguation() { + let (_tpm, _dir, cfg) = setup(); + let mut failing_cfg = cfg.clone(); + failing_cfg.pv_method = PresenceVerificationMethod::AlwaysFail; + TotpStore::init(cfg.clone()).unwrap(); + + // If there are no matching accounts, we should quit before PV happens + let error = run(failing_cfg.clone(), "foo", Some("bar")).unwrap_err(); + if let Error::SecretNotFound = error {} else { + panic!("wrong error: {:#?}", error) + } + + // If there is exactly one matching accounts, we should see PV happening and failing + TotpStore::with_tpm(cfg.clone()).unwrap().add("foo", "bar", 6, 30, &[0,0,0,0,0,0,0,0,0,0]).unwrap(); + let error = run(failing_cfg.clone(), "foo", Some("bar")).unwrap_err(); + if let Error::TotpStoreError(TpmError(PresenceVerificationFailed)) = error {} else { + panic!("wrong error: {:#?}", error) + } + } + + fn setup() -> (SwTpm, TempDir, Config) { + let tpm = SwTpm::new(); + let dir = tempdir().unwrap(); + let cfg = Config::default( + true, + tpm.tcti.clone(), + Some(dir.path().join("sys")), + Some(dir.path().join("user")), + Some(PresenceVerificationMethod::None) + ); + (tpm, dir, cfg) + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs index 7b35bfa..cb38d6a 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -5,6 +5,8 @@ use std::{fs::Permissions, os::unix::fs::PermissionsExt, path::Path}; use model::Secret; use rusqlite::{params, Connection, Row, Transaction}; +const CURRENT_SCHEMA_VERSION: u32 = 1; + pub struct DB<'a> { transaction: Transaction<'a> } @@ -16,6 +18,7 @@ pub enum Error { NoSuchElement, DbDirIsNotADir, DbFileIsNotAFile, + UnknownSchemaVersion(u32), } impl From for Error { @@ -75,7 +78,8 @@ impl <'a> DB<'a> { let mut stmt = self.transaction.prepare(" SELECT id, service, account, digits, interval, public_data, private_data FROM secrets - WHERE service LIKE ('%' || ?1 || '%') AND ACCOUNT LIKE ('%' || ?2 || '%') + WHERE service LIKE ('%' || ?1 || '%') AND account LIKE ('%' || ?2 || '%') + ORDER BY service, account ASC ")?; let secrets = stmt.query_map([service, account], to_secret) ?.filter_map(core::result::Result::ok); @@ -99,7 +103,7 @@ pub fn with_db

, T, F: FnOnce(&DB) -> Result>(db_path: P, f: F log::info!("starting transaction"); let transaction = db.transaction()?; - ensure_tables_exist(&transaction)?; + ensure_schema_is_up_to_date(&transaction)?; let db = DB::new(transaction); let result = f(&db); if result.is_ok() { @@ -144,8 +148,37 @@ fn to_secret(row: &Row) -> rusqlite::Result { }) } -fn ensure_tables_exist(tr: &Transaction) -> Result<()> { - tr.execute(" +fn ensure_schema_is_up_to_date(tx: &Transaction) -> Result<()> { + let schema_version = schema_version(tx)?; + if schema_version > CURRENT_SCHEMA_VERSION { + return Err(Error::UnknownSchemaVersion(schema_version)); + } + for v in schema_version .. CURRENT_SCHEMA_VERSION { + match v { + 0 => create_secrets_table(tx)?, + _ => unreachable!(), + } + } + update_schema_version(tx, CURRENT_SCHEMA_VERSION)?; + Ok(()) +} + +fn update_schema_version(tx: &Transaction, schema_version: u32) -> Result<()> { + tx.execute("UPDATE __version SET version = ?1", params![schema_version])?; + Ok(()) +} + +fn schema_version(tx: &Transaction) -> Result { + if let Ok(v) = tx.query_row("SELECT version FROM __version", (),|row| row.get(0)) { + Ok(v) + } else { + create_version_table(tx)?; + Ok(0) + } +} + +fn create_secrets_table(tx: &Transaction) -> std::result::Result<(), Error> { + tx.execute(" CREATE TABLE IF NOT EXISTS secrets ( id INTEGER PRIMARY KEY, service TEXT NOT NULL, @@ -160,8 +193,26 @@ fn ensure_tables_exist(tr: &Transaction) -> Result<()> { Ok(()) } +fn create_version_table(tx: &Transaction) -> Result<()> { + tx.execute(" + CREATE TABLE IF NOT EXISTS __version ( + id INTEGER PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) + )", + () + )?; + let num_rows: u32 = tx.query_row("SELECT COUNT(version) FROM __version", (), |row| row.get(0))?; + if num_rows == 0 { + tx.execute("INSERT INTO __version (version) VALUES (0)", ())?; + } + Ok(()) +} + #[cfg(test)] mod tests { + use std::collections::HashSet; + use super::*; #[test] @@ -178,8 +229,11 @@ mod tests { 0o600, ); - let result = with_db(&db, |tx| tx.list_secrets("", "")).unwrap(); - assert_eq!(result, vec![]); + with_db(&db, |tx| { + assert_eq!(schema_version(&tx.transaction)?, CURRENT_SCHEMA_VERSION); + assert_eq!(tx.list_secrets("", "")?, vec![]); + Ok(()) + }).unwrap(); } #[test] @@ -362,6 +416,42 @@ mod tests { assert_eq!(actual_secret, expected_secret); } + #[test] + fn list_secrets_return_value_is_sorted_alphabetically() { + let mut secret = Secret { + id: 0, + service: "x".to_owned(), + account: "x".to_owned(), + digits: 6, + interval: 30, + public_data: vec![], + private_data: vec![], + }; + let db = tempfile::NamedTempFile::new().unwrap(); + with_db(db.path(), |tx| { + tx.add_secret(secret.clone())?; + secret.service = "c".to_owned(); + tx.add_secret(secret.clone())?; + secret.account = "c".to_owned(); + tx.add_secret(secret.clone())?; + secret.service = "b".to_owned(); + tx.add_secret(secret.clone())?; + secret.account = "b".to_owned(); + tx.add_secret(secret.clone())?; + Ok(()) + }).unwrap(); + + /* empty strings match all secrets */ + let accounts = with_db(db.path(), |tx| tx.list_secrets("", "")).unwrap(); + let account_names: Vec<(&str, &str)> = accounts.iter() + .map(|x| (x.service.as_ref(), x.account.as_ref())) + .collect(); + assert_eq!( + account_names, + [("b", "b"), ("b", "c"), ("c", "c"), ("c", "x"), ("x", "x")] + ); + } + #[test] fn list_secrets_returns_correct_secrets() { let mut secret = Secret { @@ -388,57 +478,57 @@ mod tests { }).unwrap(); /* empty strings match all secrets */ - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("", "")) + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("", "")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, all_ids); + assert_eq!(ids, HashSet::from_iter(all_ids.clone())); /* full match on service */ - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("service", "")) + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("service", "")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[1], all_ids[2]]); + assert_eq!(ids, HashSet::from_iter([all_ids[1], all_ids[2]])); /* full match on account */ - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("", "acct")) + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("", "acct")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[0], all_ids[1]]); + assert_eq!(ids, HashSet::from_iter([all_ids[0], all_ids[1]])); /* full match on both service and account */ - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("svc", "acct")) + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("svc", "acct")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[0]]); + assert_eq!(ids, HashSet::from_iter([all_ids[0]])); /* partial match on service */ - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("tj", "")) + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("tj", "")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[3]]); - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("c", "")) + assert_eq!(ids, HashSet::from_iter([all_ids[3]])); + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("c", "")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[0], all_ids[1], all_ids[2]]); - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("ce", "")) + assert_eq!(ids, HashSet::from_iter([all_ids[0], all_ids[1], all_ids[2]])); + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("ce", "")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[1], all_ids[2]]); + assert_eq!(ids, HashSet::from_iter([all_ids[1], all_ids[2]])); /* partial match on account */ - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("", "acc")) + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("", "acc")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[0], all_ids[1], all_ids[2]]); - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("", "cco")) + assert_eq!(ids, HashSet::from_iter([all_ids[0], all_ids[1], all_ids[2]])); + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("", "cco")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[2]]); - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("", "nto")) + assert_eq!(ids, HashSet::from_iter([all_ids[2]])); + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("", "nto")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![all_ids[3]]); + assert_eq!(ids, HashSet::from_iter([all_ids[3]])); /* no match */ - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("potato", "")) + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("potato", "")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![]); - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("", "potato")) + assert_eq!(ids, HashSet::from_iter([])); + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("", "potato")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![]); - let ids: Vec = with_db(db.path(), |tx| tx.list_secrets("potato", "potato")) + assert_eq!(ids, HashSet::from_iter([])); + let ids: HashSet = with_db(db.path(), |tx| tx.list_secrets("potato", "potato")) .unwrap().iter().map(|x| x.id).collect(); - assert_eq!(ids, vec![]); + assert_eq!(ids, HashSet::from_iter([])); } #[test] diff --git a/src/main.rs b/src/main.rs index 98f4520..a84a867 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,6 +47,12 @@ fn fail(e: totpm::result::Error) { totpm::result::Error::RootRequired => { eprintln!("root permissions required"); }, + totpm::result::Error::SecretNotFound => { + eprintln!("service/account combination not found"); + }, + totpm::result::Error::AmbiguousSecret => { + eprintln!("more than one secret matched the given parameters"); + }, }; exit(1); } diff --git a/src/presence_verification/factory.rs b/src/presence_verification/factory.rs index 98438d4..8e36943 100644 --- a/src/presence_verification/factory.rs +++ b/src/presence_verification/factory.rs @@ -7,5 +7,7 @@ pub(crate) fn create_presence_verifier( match method { PresenceVerificationMethod::Fprintd => Box::new(FprintdPresenceVerifier::new(timeout_secs)), PresenceVerificationMethod::None => Box::new(ConstPresenceVerifier::new(true)), + #[cfg(test)] + PresenceVerificationMethod::AlwaysFail => Box::new(ConstPresenceVerifier::new(false)) } } diff --git a/src/presence_verification/mod.rs b/src/presence_verification/mod.rs index 62dd65f..12fd74b 100644 --- a/src/presence_verification/mod.rs +++ b/src/presence_verification/mod.rs @@ -18,6 +18,8 @@ pub type Result = std::result::Result; pub enum PresenceVerificationMethod { Fprintd, None, + #[cfg(test)] + AlwaysFail, } impl FromStr for PresenceVerificationMethod { diff --git a/src/result.rs b/src/result.rs index 71a3a55..ef3f10e 100644 --- a/src/result.rs +++ b/src/result.rs @@ -10,6 +10,8 @@ pub enum Error { SecretFormatError, InvalidPVMethod(String), RootRequired, + SecretNotFound, + AmbiguousSecret, } impl From for Error { diff --git a/src/term.rs b/src/term.rs index 53b1285..84044ba 100644 --- a/src/term.rs +++ b/src/term.rs @@ -116,7 +116,7 @@ mod tests { } #[test] - fn pick_one_warns_and_returns_nothing_on_invalid_selection() { + fn pick_one_retries_on_invalid_selection() { let mut term = MockTerminal::new() .expect_stdout("hello\n") .expect_stdout("0:\t[cancel]\n")