Skip to content

Commit

Permalink
Merge pull request #10 from koditoriet/minor-bug-fixes
Browse files Browse the repository at this point in the history
Minor bug fixes.
  • Loading branch information
valderman authored Sep 10, 2024
2 parents 99f116b + 6cfabfc commit ccc08bd
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 41 deletions.
82 changes: 74 additions & 8 deletions src/commands/gen.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use crate::{config::Config, result::Result, term::pick_one, totp_store::TotpStore};
use crate::{config::Config, result::{Error, Result}, term::pick_one, totp_store::TotpStore};

pub fn run(
config: Config,
service: &str,
account: Option<&str>
) -> Result<()> {
let mut store = TotpStore::with_tpm(config)?;
let alternatives = store.list(Some(service), account)?;
let alternatives = TotpStore::without_tpm(config.clone()).list(Some(service), account)?;

if alternatives.is_empty() {
println!("service/account combination not found");
return Ok(())
return Err(Error::SecretNotFound);
}

if let Some(alt) = pick_one(
Expand All @@ -19,9 +17,77 @@ pub fn run(
"found multiple matches for the given service/account combination",
alternatives.iter()
) {
let code = store.gen(alt.id, std::time::SystemTime::now())?;
let code = TotpStore::with_tpm(config)?.gen(alt.id, std::time::SystemTime::now())?;
println!("{}", code);
Ok(())
} else {
Err(Error::AmbiguousSecret)
}
}

Ok(())
}
#[cfg(test)]
mod tests {
use serial_test::serial;
use tempfile::{tempdir, TempDir};
use testutil::tpm::SwTpm;

use crate::presence_verification::PresenceVerificationMethod;
use crate::tpm::Error::PresenceVerificationFailed;
use crate::totp_store::Error::TpmError;

use super::*;

#[test]
fn gen_succeeds_on_unambiguous_secret() {
let (_tpm, _dir, cfg) = setup();
TotpStore::init(cfg.clone()).unwrap();
let mut store = TotpStore::with_tpm(cfg.clone()).unwrap();
store.add("foo", "bar", 6, 30, &[0,0,0,0,0,0,0,0,0,0]).unwrap();
run(cfg, "foo", None).unwrap();
}

#[test]
fn gen_fails_on_secret_not_found() {
let (_tpm, _dir, cfg) = setup();
TotpStore::init(cfg.clone()).unwrap();
match run(cfg, "foo", None).unwrap_err() {
crate::result::Error::SecretNotFound => {},
err => panic!("wrong error: {:#?}", err),
}
}

#[test]
#[serial]
fn presence_verification_happens_after_disambiguation() {
let (_tpm, _dir, cfg) = setup();
let mut failing_cfg = cfg.clone();
failing_cfg.pv_method = PresenceVerificationMethod::AlwaysFail;
TotpStore::init(cfg.clone()).unwrap();

// If there are no matching accounts, we should quit before PV happens
let error = run(failing_cfg.clone(), "foo", Some("bar")).unwrap_err();
if let Error::SecretNotFound = error {} else {
panic!("wrong error: {:#?}", error)
}

// If there is exactly one matching accounts, we should see PV happening and failing
TotpStore::with_tpm(cfg.clone()).unwrap().add("foo", "bar", 6, 30, &[0,0,0,0,0,0,0,0,0,0]).unwrap();
let error = run(failing_cfg.clone(), "foo", Some("bar")).unwrap_err();
if let Error::TotpStoreError(TpmError(PresenceVerificationFailed)) = error {} else {
panic!("wrong error: {:#?}", error)
}
}

fn setup() -> (SwTpm, TempDir, Config) {
let tpm = SwTpm::new();
let dir = tempdir().unwrap();
let cfg = Config::default(
true,
tpm.tcti.clone(),
Some(dir.path().join("sys")),
Some(dir.path().join("user")),
Some(PresenceVerificationMethod::None)
);
(tpm, dir, cfg)
}
}
154 changes: 122 additions & 32 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::{fs::Permissions, os::unix::fs::PermissionsExt, path::Path};
use model::Secret;
use rusqlite::{params, Connection, Row, Transaction};

const CURRENT_SCHEMA_VERSION: u32 = 1;

pub struct DB<'a> {
transaction: Transaction<'a>
}
Expand All @@ -16,6 +18,7 @@ pub enum Error {
NoSuchElement,
DbDirIsNotADir,
DbFileIsNotAFile,
UnknownSchemaVersion(u32),
}

impl From<rusqlite::Error> for Error {
Expand Down Expand Up @@ -75,7 +78,8 @@ impl <'a> DB<'a> {
let mut stmt = self.transaction.prepare("
SELECT id, service, account, digits, interval, public_data, private_data
FROM secrets
WHERE service LIKE ('%' || ?1 || '%') AND ACCOUNT LIKE ('%' || ?2 || '%')
WHERE service LIKE ('%' || ?1 || '%') AND account LIKE ('%' || ?2 || '%')
ORDER BY service, account ASC
")?;
let secrets = stmt.query_map([service, account], to_secret)
?.filter_map(core::result::Result::ok);
Expand All @@ -99,7 +103,7 @@ pub fn with_db<P : AsRef<Path>, T, F: FnOnce(&DB) -> Result<T>>(db_path: P, f: F

log::info!("starting transaction");
let transaction = db.transaction()?;
ensure_tables_exist(&transaction)?;
ensure_schema_is_up_to_date(&transaction)?;
let db = DB::new(transaction);
let result = f(&db);
if result.is_ok() {
Expand Down Expand Up @@ -144,8 +148,37 @@ fn to_secret(row: &Row) -> rusqlite::Result<Secret> {
})
}

fn ensure_tables_exist(tr: &Transaction) -> Result<()> {
tr.execute("
fn ensure_schema_is_up_to_date(tx: &Transaction) -> Result<()> {
let schema_version = schema_version(tx)?;
if schema_version > CURRENT_SCHEMA_VERSION {
return Err(Error::UnknownSchemaVersion(schema_version));
}
for v in schema_version .. CURRENT_SCHEMA_VERSION {
match v {
0 => create_secrets_table(tx)?,
_ => unreachable!(),
}
}
update_schema_version(tx, CURRENT_SCHEMA_VERSION)?;
Ok(())
}

fn update_schema_version(tx: &Transaction, schema_version: u32) -> Result<()> {
tx.execute("UPDATE __version SET version = ?1", params![schema_version])?;
Ok(())
}

fn schema_version(tx: &Transaction) -> Result<u32> {
if let Ok(v) = tx.query_row("SELECT version FROM __version", (),|row| row.get(0)) {
Ok(v)
} else {
create_version_table(tx)?;
Ok(0)
}
}

fn create_secrets_table(tx: &Transaction) -> std::result::Result<(), Error> {
tx.execute("
CREATE TABLE IF NOT EXISTS secrets (
id INTEGER PRIMARY KEY,
service TEXT NOT NULL,
Expand All @@ -160,8 +193,26 @@ fn ensure_tables_exist(tr: &Transaction) -> Result<()> {
Ok(())
}

fn create_version_table(tx: &Transaction) -> Result<()> {
tx.execute("
CREATE TABLE IF NOT EXISTS __version (
id INTEGER PRIMARY KEY,
version INTEGER NOT NULL,
CHECK(id = 1)
)",
()
)?;
let num_rows: u32 = tx.query_row("SELECT COUNT(version) FROM __version", (), |row| row.get(0))?;
if num_rows == 0 {
tx.execute("INSERT INTO __version (version) VALUES (0)", ())?;
}
Ok(())
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;

use super::*;

#[test]
Expand All @@ -178,8 +229,11 @@ mod tests {
0o600,
);

let result = with_db(&db, |tx| tx.list_secrets("", "")).unwrap();
assert_eq!(result, vec![]);
with_db(&db, |tx| {
assert_eq!(schema_version(&tx.transaction)?, CURRENT_SCHEMA_VERSION);
assert_eq!(tx.list_secrets("", "")?, vec![]);
Ok(())
}).unwrap();
}

#[test]
Expand Down Expand Up @@ -362,6 +416,42 @@ mod tests {
assert_eq!(actual_secret, expected_secret);
}

#[test]
fn list_secrets_return_value_is_sorted_alphabetically() {
let mut secret = Secret {
id: 0,
service: "x".to_owned(),
account: "x".to_owned(),
digits: 6,
interval: 30,
public_data: vec![],
private_data: vec![],
};
let db = tempfile::NamedTempFile::new().unwrap();
with_db(db.path(), |tx| {
tx.add_secret(secret.clone())?;
secret.service = "c".to_owned();
tx.add_secret(secret.clone())?;
secret.account = "c".to_owned();
tx.add_secret(secret.clone())?;
secret.service = "b".to_owned();
tx.add_secret(secret.clone())?;
secret.account = "b".to_owned();
tx.add_secret(secret.clone())?;
Ok(())
}).unwrap();

/* empty strings match all secrets */
let accounts = with_db(db.path(), |tx| tx.list_secrets("", "")).unwrap();
let account_names: Vec<(&str, &str)> = accounts.iter()
.map(|x| (x.service.as_ref(), x.account.as_ref()))
.collect();
assert_eq!(
account_names,
[("b", "b"), ("b", "c"), ("c", "c"), ("c", "x"), ("x", "x")]
);
}

#[test]
fn list_secrets_returns_correct_secrets() {
let mut secret = Secret {
Expand All @@ -388,57 +478,57 @@ mod tests {
}).unwrap();

/* empty strings match all secrets */
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("", ""))
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("", ""))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, all_ids);
assert_eq!(ids, HashSet::from_iter(all_ids.clone()));

/* full match on service */
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("service", ""))
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("service", ""))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[1], all_ids[2]]);
assert_eq!(ids, HashSet::from_iter([all_ids[1], all_ids[2]]));

/* full match on account */
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("", "acct"))
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("", "acct"))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[0], all_ids[1]]);
assert_eq!(ids, HashSet::from_iter([all_ids[0], all_ids[1]]));

/* full match on both service and account */
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("svc", "acct"))
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("svc", "acct"))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[0]]);
assert_eq!(ids, HashSet::from_iter([all_ids[0]]));

/* partial match on service */
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("tj", ""))
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("tj", ""))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[3]]);
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("c", ""))
assert_eq!(ids, HashSet::from_iter([all_ids[3]]));
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("c", ""))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[0], all_ids[1], all_ids[2]]);
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("ce", ""))
assert_eq!(ids, HashSet::from_iter([all_ids[0], all_ids[1], all_ids[2]]));
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("ce", ""))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[1], all_ids[2]]);
assert_eq!(ids, HashSet::from_iter([all_ids[1], all_ids[2]]));

/* partial match on account */
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("", "acc"))
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("", "acc"))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[0], all_ids[1], all_ids[2]]);
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("", "cco"))
assert_eq!(ids, HashSet::from_iter([all_ids[0], all_ids[1], all_ids[2]]));
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("", "cco"))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[2]]);
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("", "nto"))
assert_eq!(ids, HashSet::from_iter([all_ids[2]]));
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("", "nto"))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![all_ids[3]]);
assert_eq!(ids, HashSet::from_iter([all_ids[3]]));

/* no match */
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("potato", ""))
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("potato", ""))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![]);
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("", "potato"))
assert_eq!(ids, HashSet::from_iter([]));
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("", "potato"))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![]);
let ids: Vec<i64> = with_db(db.path(), |tx| tx.list_secrets("potato", "potato"))
assert_eq!(ids, HashSet::from_iter([]));
let ids: HashSet<i64> = with_db(db.path(), |tx| tx.list_secrets("potato", "potato"))
.unwrap().iter().map(|x| x.id).collect();
assert_eq!(ids, vec![]);
assert_eq!(ids, HashSet::from_iter([]));
}

#[test]
Expand Down
6 changes: 6 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ fn fail(e: totpm::result::Error) {
totpm::result::Error::RootRequired => {
eprintln!("root permissions required");
},
totpm::result::Error::SecretNotFound => {
eprintln!("service/account combination not found");
},
totpm::result::Error::AmbiguousSecret => {
eprintln!("more than one secret matched the given parameters");
},
};
exit(1);
}
Expand Down
2 changes: 2 additions & 0 deletions src/presence_verification/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ pub(crate) fn create_presence_verifier(
match method {
PresenceVerificationMethod::Fprintd => Box::new(FprintdPresenceVerifier::new(timeout_secs)),
PresenceVerificationMethod::None => Box::new(ConstPresenceVerifier::new(true)),
#[cfg(test)]
PresenceVerificationMethod::AlwaysFail => Box::new(ConstPresenceVerifier::new(false))
}
}
2 changes: 2 additions & 0 deletions src/presence_verification/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub type Result<T> = std::result::Result<T, Error>;
pub enum PresenceVerificationMethod {
Fprintd,
None,
#[cfg(test)]
AlwaysFail,
}

impl FromStr for PresenceVerificationMethod {
Expand Down
2 changes: 2 additions & 0 deletions src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub enum Error {
SecretFormatError,
InvalidPVMethod(String),
RootRequired,
SecretNotFound,
AmbiguousSecret,
}

impl From<toml::ser::Error> for Error {
Expand Down
2 changes: 1 addition & 1 deletion src/term.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ mod tests {
}

#[test]
fn pick_one_warns_and_returns_nothing_on_invalid_selection() {
fn pick_one_retries_on_invalid_selection() {
let mut term = MockTerminal::new()
.expect_stdout("hello\n")
.expect_stdout("0:\t[cancel]\n")
Expand Down

0 comments on commit ccc08bd

Please sign in to comment.