diff --git a/.sqlx/query-2873b774933c51e429479514b18eeb761e1702532482e7b310413d48832d71d0.json b/.sqlx/query-2873b774933c51e429479514b18eeb761e1702532482e7b310413d48832d71d0.json
new file mode 100644
index 0000000..e8f9b1e
--- /dev/null
+++ b/.sqlx/query-2873b774933c51e429479514b18eeb761e1702532482e7b310413d48832d71d0.json
@@ -0,0 +1,16 @@
+{
+  "db_name": "PostgreSQL",
+  "query": "\n UPDATE pg_task\n SET is_running = false,\n tried = 0,\n step = $2,\n updated_at = $3,\n wakeup_at = $3\n WHERE id = $1\n ",
+  "describe": {
+    "columns": [],
+    "parameters": {
+      "Left": [
+        "Uuid",
+        "Text",
+        "Timestamptz"
+      ]
+    },
+    "nullable": []
+  },
+  "hash": "2873b774933c51e429479514b18eeb761e1702532482e7b310413d48832d71d0"
+}
diff --git a/.sqlx/query-61eee517879d1017c23bf070468544070818bee22799cb20d7c4898223260b2f.json b/.sqlx/query-61eee517879d1017c23bf070468544070818bee22799cb20d7c4898223260b2f.json
deleted file mode 100644
index c87a30c..0000000
--- a/.sqlx/query-61eee517879d1017c23bf070468544070818bee22799cb20d7c4898223260b2f.json
+++ /dev/null
@@ -1,16 +0,0 @@
-{
-  "db_name": "PostgreSQL",
-  "query": "\n UPDATE pg_task\n SET is_running = false,\n tried = 0,\n step = $2,\n updated_at = $3,\n wakeup_at = $3\n WHERE id = $1\n ",
-  "describe": {
-    "columns": [],
-    "parameters": {
-      "Left": [
-        "Uuid",
-        "Text",
-        "Timestamptz"
-      ]
-    },
-    "nullable": []
-  },
-  "hash": "61eee517879d1017c23bf070468544070818bee22799cb20d7c4898223260b2f"
-}
diff --git a/Cargo.toml b/Cargo.toml
index 776fdbf..3782958 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,7 @@ async-trait = "0.1"
 chrono = { version = "0.4", features = ["std", "serde"] }
 code-path = "0.3"
 displaydoc = "0.2"
+num_cpus = "1"
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 source-chain = "0.1"
diff --git a/examples/counter.rs b/examples/counter.rs
index b2fdbb7..0f5a8ad 100644
--- a/examples/counter.rs
+++ b/examples/counter.rs
@@ -58,14 +58,13 @@ impl Step for Proceed {
     const RETRY_DELAY: Duration = Duration::from_secs(1);
 
     async fn step(self, _db: &PgPool) -> StepResult {
-        // return Err(anyhow::anyhow!("bailing").into());
         let Self {
             up_to,
             mut cur,
             started_at,
         } = self;
+
         cur += 1;
-        // println!("1..{up_to}: {cur}");
         if cur < up_to {
             NextStep::now(Proceed {
                 up_to,
diff --git a/src/error.rs b/src/error.rs
index 5788984..855abd7 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -14,13 +14,15 @@ pub enum Error {
     WaiterConnect(#[source] sqlx::Error),
     /// waiter can't start listening to tables changes
     WaiterListen(#[source] sqlx::Error),
+    /// unreachable: worker semaphore is closed
+    UnreachableWorkerSemaphoreClosed(#[source] tokio::sync::AcquireError),
 }
 
 /// The crate result
 pub type Result<T> = StdResult<T, Error>;
 
 /// Error of a task step
-pub type StepError = Box<dyn std::error::Error>;
+pub type StepError = Box<dyn std::error::Error + Send + Sync>;
 
 /// Result returning from task steps
 pub type StepResult<S> = StdResult<NextStep<S>, StepError>;
diff --git a/src/worker.rs b/src/worker.rs
index 7adac36..a8ed5ee 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -1,12 +1,13 @@
 use crate::{util, waiter::Waiter, NextStep, Step, StepError, LOST_CONNECTION_SLEEP};
 use chrono::{DateTime, Utc};
 use code_path::code_path;
+use serde::Serialize;
 use sqlx::{
     postgres::{PgConnection, PgPool},
     types::Uuid,
 };
-use std::{marker::PhantomData, time::Duration};
-use tokio::time::sleep;
+use std::{fmt, marker::PhantomData, sync::Arc, time::Duration};
+use tokio::{sync::Semaphore, time::sleep};
 use tracing::{debug, error, info, trace, warn};
 
 /// An error report to log from the worker
@@ -39,6 +40,7 @@ pub struct Worker<S: Step<S>> {
     db: PgPool,
     waiter: Waiter,
     tasks: PhantomData<S>,
+    concurrency: usize,
 }
 
 #[derive(Debug)]
@@ -55,39 +57,44 @@ impl ErrorReport {
     }
 }
 
-impl Task {
-    /// Returns the delay time before running the task
-    fn delay(&self) -> Duration {
-        let delay = self.wakeup_at - Utc::now();
-        if delay <= chrono::Duration::zero() {
-            Duration::ZERO
-        } else {
-            util::chrono_duration_to_std(delay)
-        }
-    }
-}
-
 impl<S: Step<S>> Worker<S> {
     /// Creates a new worker
     pub fn new(db: PgPool) -> Self {
         let waiter = Waiter::new();
+        let concurrency = num_cpus::get();
         Self {
             db,
             waiter,
+            concurrency,
             tasks: PhantomData,
         }
     }
 
+    /// Sets the number of concurrent tasks, default is the number of CPUs
+    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
+        self.concurrency = concurrency;
+        self
+    }
+
     /// Runs all ready tasks to completion and waits for new ones
     pub async fn run(&self) -> crate::Result<()> {
         self.unlock_stale_tasks().await?;
         self.waiter.listen(self.db.clone()).await?;
 
-        // TODO concurrency
+        let semaphore = Arc::new(Semaphore::new(self.concurrency));
         loop {
             match self.recv_task().await {
                 Ok(task) => {
-                    self.run_step(task).await.map_err(ErrorReport::log).ok();
+                    let permit = semaphore
+                        .clone()
+                        .acquire_owned()
+                        .await
+                        .map_err(crate::Error::UnreachableWorkerSemaphoreClosed)?;
+                    let db = self.db.clone();
+                    tokio::spawn(async move {
+                        task.run_step::<S>(&db).await.map_err(ErrorReport::log).ok();
+                        drop(permit);
+                    });
                 }
                 Err(e) => {
                     warn!(
@@ -98,48 +105,8 @@ impl<S: Step<S>> Worker<S> {
                     util::wait_for_reconnection(&self.db, LOST_CONNECTION_SLEEP).await;
                     warn!("Task fetching is probably restored");
                 }
-            };
-        }
-    }
-
-    /// Runs the next step of the task
-    async fn run_step(&self, task: Task) -> ReportResult<()> {
-        let Task {
-            id, step, tried, ..
-        } = task;
-        info!(
-            "[{id}]{} run step {step}",
-            if tried > 0 {
-                format!(" {} attempt to", util::ordinal(tried + 1))
-            } else {
-                "".into()
-            },
-        );
-        let step: S = match serde_json::from_str(&step)
-            .map_err(|e| ErrorReport::DeserializeStep(e, format!("{:?}", step)))
-        {
-            Ok(x) => x,
-            Err(e) => {
-                self.set_task_error(id, e.into())
-                    .await
-                    .map_err(ErrorReport::log)
-                    .ok();
-                return Ok(());
-            }
-        };
-
-        let retry_limit = step.retry_limit();
-        let retry_delay = step.retry_delay();
-        match step.step(&self.db).await {
-            Err(e) => {
-                self.process_error(id, tried, retry_limit, retry_delay, e)
-                    .await?
             }
-            Ok(NextStep::None) => self.finish_task(id).await?,
-            Ok(NextStep::Now(step)) => self.update_task_step(id, step, Duration::ZERO).await?,
-            Ok(NextStep::Delayed(step, delay)) => self.update_task_step(id, step, delay).await?,
-        };
-        Ok(())
+        }
     }
 
     /// Unlocks all tasks. This is intended to run at the start of the worker as
@@ -167,10 +134,10 @@ impl<S: Step<S>> Worker<S> {
         loop {
            let table_changes = self.waiter.subscribe();
            let mut tx = self.db.begin().await.map_err(sqlx_error!("begin"))?;
-            if let Some(task) = fetch_closest_task(&mut tx).await? {
+            if let Some(task) = Task::fetch_closest(&mut tx).await? {
                 let time_to_run = task.wakeup_at - Utc::now();
                 if time_to_run <= chrono::Duration::zero() {
-                    mark_task_running(&mut tx, task.id).await?;
+                    task.mark_running(&mut tx).await?;
                     tx.commit()
                         .await
                         .map_err(sqlx_error!("commit on task return"))?;
@@ -190,50 +157,188 @@
             }
         }
     }
-
-    /// Updates the tasks step
-    async fn update_task_step(&self, task_id: Uuid, step: S, delay: Duration) -> ReportResult<()> {
-        let step = match serde_json::to_string(&step)
-            .map_err(|e| ErrorReport::SerializeStep(e, format!("{:?}", step)))
+}
+
+impl Task {
+    /// Fetches the closest task to run
+    async fn fetch_closest(con: &mut PgConnection) -> ReportResult<Option<Task>> {
+        trace!("Fetching the closest task to run");
+        let task = sqlx::query_as!(
+            Task,
+            r#"
+            SELECT
+                id,
+                step,
+                tried,
+                wakeup_at
+            FROM pg_task
+            WHERE is_running = false
+              AND error IS NULL
+            ORDER BY wakeup_at
+            LIMIT 1
+            FOR UPDATE
+            "#,
+        )
+        .fetch_optional(con)
+        .await
+        .map_err(sqlx_error!("select"))?;
+
+        if let Some(ref task) = task {
+            let delay = task.delay();
+            if delay == Duration::ZERO {
+                trace!("[{}] is to run now", task.id);
+            } else {
+                trace!("[{}] is to run in {:?}", task.id, delay);
+            }
+        } else {
+            debug!("No tasks to run");
+        }
+        Ok(task)
+    }
+
+    async fn mark_running(&self, con: &mut PgConnection) -> ReportResult<()> {
+        sqlx::query!(
+            "
+            UPDATE pg_task
+            SET is_running = true,
+                updated_at = now()
+            WHERE id = $1
+            ",
+            self.id
+        )
+        .execute(con)
+        .await
+        .map_err(sqlx_error!())?;
+        Ok(())
+    }
+
+    /// Returns the delay time before running the task
+    fn delay(&self) -> Duration {
+        let delay = self.wakeup_at - Utc::now();
+        if delay <= chrono::Duration::zero() {
+            Duration::ZERO
+        } else {
+            util::chrono_duration_to_std(delay)
+        }
+    }
+
+    /// Runs the current step of the task to completion
+    async fn run_step<S: Step<S>>(&self, db: &PgPool) -> ReportResult<()> {
+        info!(
+            "[{}]{} run step {}",
+            self.id,
+            if self.tried > 0 {
+                format!(" {} attempt to", util::ordinal(self.tried + 1))
+            } else {
+                "".into()
+            },
+            self.step
+        );
+        let step: S = match serde_json::from_str(&self.step)
+            .map_err(|e| ErrorReport::DeserializeStep(e, format!("{:?}", self.step)))
         {
             Ok(x) => x,
             Err(e) => {
-                self.set_task_error(task_id, e.into())
+                self.save_error(db, e.into())
                     .await
                     .map_err(ErrorReport::log)
                     .ok();
                 return Ok(());
             }
         };
 
-        trace!("[{task_id}] update step to {step}");
-        sqlx::query!(
-            "
+        let retry_limit = step.retry_limit();
+        let retry_delay = step.retry_delay();
+        match step.step(db).await {
+            Err(e) => {
+                self.process_error(db, self.tried, retry_limit, retry_delay, e)
+                    .await?
+            }
+            Ok(NextStep::None) => self.complete(db).await?,
+            Ok(NextStep::Now(step)) => self.save_next_step(db, step, Duration::ZERO).await?,
+            Ok(NextStep::Delayed(step, delay)) => self.save_next_step(db, step, delay).await?,
+        };
+        Ok(())
+    }
+
+    /// Saves the task error
+    async fn save_error(&self, db: &PgPool, err: StepError) -> ReportResult<()> {
+        trace!("[{}] saving error", self.id);
+
+        let err = source_chain::to_string(&*err);
+        let (tried, step) = sqlx::query!(
+            r#"
             UPDATE pg_task
             SET is_running = false,
-                tried = 0,
-                step = $2,
+                error = $2,
                 updated_at = $3,
                 wakeup_at = $3
             WHERE id = $1
-            ",
-            task_id,
+            RETURNING tried, step::TEXT as "step!"
+ "#, + self.id, + &err, + Utc::now(), + ) + .fetch_one(db) + .await + .map(|r| (r.tried, r.step)) + .map_err(sqlx_error!())?; + error!( + "[{}] resulted in an error at step {step} after {tried} attempts: {}", + self.id, err + ); + Ok(()) + } + + /// Updates the tasks step + async fn save_next_step( + &self, + db: &PgPool, + step: impl Serialize + fmt::Debug, + delay: Duration, + ) -> ReportResult<()> { + let step = match serde_json::to_string(&step) + .map_err(|e| ErrorReport::SerializeStep(e, format!("{:?}", step))) + { + Ok(x) => x, + Err(e) => { + self.save_error(db, e.into()) + .await + .map_err(ErrorReport::log) + .ok(); + return Ok(()); + } + }; + + trace!("[{}] moved to the next step {step}", self.id); + sqlx::query!( + " + UPDATE pg_task + SET is_running = false, + tried = 0, + step = $2, + updated_at = $3, + wakeup_at = $3 + WHERE id = $1 + ", + self.id, step, Utc::now() + util::std_duration_to_chrono(delay), ) - .execute(&self.db) + .execute(db) .await .map_err(sqlx_error!())?; - debug!("[{task_id}] step is done"); + debug!("[{}] step is done", self.id); Ok(()) } /// Removes the finished task - async fn finish_task(&self, task_id: Uuid) -> ReportResult<()> { - info!("[{task_id}] is successfully completed"); - sqlx::query!("DELETE FROM pg_task WHERE id = $1", task_id) - .execute(&self.db) + async fn complete(&self, db: &PgPool) -> ReportResult<()> { + info!("[{}] is successfully completed", self.id); + sqlx::query!("DELETE FROM pg_task WHERE id = $1", self.id) + .execute(db) .await .map_err(sqlx_error!())?; Ok(()) @@ -242,33 +347,31 @@ impl> Worker { /// Dealing with the step error async fn process_error( &self, - task_id: Uuid, + db: &PgPool, tried: i32, retry_limit: i32, retry_delay: Duration, err: StepError, ) -> ReportResult<()> { if tried < retry_limit { - self.retry_task(task_id, tried, retry_limit, retry_delay, err) - .await + self.retry(db, tried, retry_limit, retry_delay, err).await } else { - self.set_task_error(task_id, err).await + self.save_error(db, err).await } } /// Schedules the task for retry - async fn retry_task( + async fn retry( &self, - task_id: Uuid, + db: &PgPool, tried: i32, retry_limit: i32, delay: Duration, err: StepError, ) -> ReportResult<()> { - trace!("[{task_id}] scheduling a retry"); + trace!("[{}] scheduling a retry", self.id); - let delay = - chrono::Duration::from_std(delay).unwrap_or_else(|_| chrono::Duration::max_value()); + let delay = util::std_duration_to_chrono(delay); let wakeup_at = Utc::now() + delay; sqlx::query!( " @@ -279,100 +382,19 @@ impl> Worker { wakeup_at = $2 WHERE id = $1 ", - task_id, + self.id, wakeup_at, ) - .execute(&self.db) + .execute(db) .await .map_err(sqlx_error!())?; debug!( - "[{task_id}] scheduled {attempt} of {retry_limit} retries in {delay:?} on error: {}", + "[{}] scheduled {attempt} of {retry_limit} retries in {delay:?} on error: {}", + self.id, source_chain::to_string(&*err), attempt = util::ordinal(tried + 1) ); Ok(()) } - - /// Sets the task error - async fn set_task_error(&self, task_id: Uuid, err: StepError) -> ReportResult<()> { - trace!("[{task_id}] saving error"); - - let err = source_chain::to_string(&*err); - let (tried, step) = sqlx::query!( - r#" - UPDATE pg_task - SET is_running = false, - error = $2, - updated_at = $3, - wakeup_at = $3 - WHERE id = $1 - RETURNING tried, step::TEXT as "step!" 
- "#, - task_id, - &err, - Utc::now(), - ) - .fetch_one(&self.db) - .await - .map(|r| (r.tried, r.step)) - .map_err(sqlx_error!())?; - error!( - "[{task_id}] resulted in an error at step {step} after {tried} attempts: {}", - &err - ); - Ok(()) - } -} - -/// Fetches the closest task to run -async fn fetch_closest_task(con: &mut PgConnection) -> ReportResult> { - trace!("Fetching the closest task to run"); - let task = sqlx::query_as!( - Task, - r#" - SELECT - id, - step, - tried, - wakeup_at - FROM pg_task - WHERE is_running = false - AND error IS NULL - ORDER BY wakeup_at - LIMIT 1 - FOR UPDATE - "#, - ) - .fetch_optional(con) - .await - .map_err(sqlx_error!("select"))?; - - if let Some(ref task) = task { - let delay = task.delay(); - if delay == Duration::ZERO { - trace!("[{}] is to run now", task.id); - } else { - trace!("[{}] is to run in {:?}", task.id, delay); - } - } else { - debug!("No tasks to run"); - } - Ok(task) -} - -async fn mark_task_running(con: &mut PgConnection, task_id: Uuid) -> ReportResult<()> { - sqlx::query!( - " - UPDATE pg_task - SET is_running = true, - updated_at = now() - WHERE id = $1 - ", - task_id - ) - .execute(con) - .await - .map_err(sqlx_error!())?; - Ok(()) }