diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs
index 143b83ec..ce850411 100644
--- a/crates/tasks/src/new_queue.rs
+++ b/crates/tasks/src/new_queue.rs
@@ -12,7 +12,7 @@ use mas_storage::{
     Clock, RepositoryAccess, RepositoryError,
 };
 use mas_storage_pg::{DatabaseError, PgRepository};
-use rand::{distributions::Uniform, Rng};
+use rand::{distributions::Uniform, Rng, RngCore};
 use rand_chacha::ChaChaRng;
 use serde::de::DeserializeOwned;
 use sqlx::{
@@ -160,6 +160,7 @@ const MAX_JOBS_TO_FETCH: usize = 5;
 // How many attempts a job should be retried
 const MAX_ATTEMPTS: usize = 5;
 
+type JobResult = Result<(), JobError>;
 type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
 
 pub struct QueueWorker {
@@ -171,13 +172,7 @@ pub struct QueueWorker {
     last_heartbeat: DateTime<Utc>,
     cancellation_token: CancellationToken,
     state: State,
-    running_jobs: JoinSet<Result<(), JobError>>,
-    job_contexts: HashMap<tokio::task::Id, JobContext>,
-    factories: HashMap<&'static str, JobFactory>,
-
-    #[allow(clippy::type_complexity)]
-    last_join_result:
-        Option<Result<(tokio::task::Id, Result<(), JobError>), tokio::task::JoinError>>,
+    tracker: JobTracker,
 }
 
 impl QueueWorker {
@@ -234,10 +229,7 @@ impl QueueWorker {
             last_heartbeat: now,
             cancellation_token,
             state,
-            job_contexts: HashMap::new(),
-            running_jobs: JoinSet::new(),
-            factories: HashMap::new(),
-            last_join_result: None,
+            tracker: JobTracker::default(),
         })
     }
 
@@ -248,7 +240,9 @@ impl QueueWorker {
             box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
         };
 
-        self.factories.insert(T::QUEUE_NAME, Arc::new(factory));
+        self.tracker
+            .factories
+            .insert(T::QUEUE_NAME, Arc::new(factory));
         self
     }
 
@@ -266,7 +260,6 @@ impl QueueWorker {
     async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
         self.wait_until_wakeup().await?;
 
-        // TODO: join all the jobs handles when shutting down
        if self.cancellation_token.is_cancelled() {
             return Ok(());
         }
@@ -284,8 +277,6 @@ impl QueueWorker {
     async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
         tracing::info!("Shutting down worker");
 
-        // TODO: collect running jobs
-
         // Start a transaction on the existing PgListener connection
         let txn = self
             .listener
@@ -295,6 +286,24 @@ impl QueueWorker {
 
         let mut repo = PgRepository::from_conn(txn);
 
+        // Log about any job still running
+        match self.tracker.running_jobs() {
+            0 => {}
+            1 => tracing::warn!("There is one job still running, waiting for it to finish"),
+            n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"),
+        }
+
+        // TODO: we may want to introduce a timeout here, and abort the tasks if they
+        // take too long.
+        // It's fine for now, as we don't have long-running
+        // tasks, most of them are idempotent, and the only effect might be that
+        // the worker would 'dirtily' shutdown, meaning that its tasks would be
+        // considered abandoned, and later retried by another worker
+
+        // Wait for all the jobs to finish
+        self.tracker
+            .process_jobs(&mut self.rng, &self.clock, &mut repo, true)
+            .await?;
+
         // Tell the other workers we're shutting down
         // This also releases the leader election lease
         repo.queue_worker()
@@ -330,9 +339,8 @@ impl QueueWorker {
                 tracing::debug!("Woke up from sleep");
             },
 
-            Some(result) = self.running_jobs.join_next_with_id() => {
+            () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => {
                 tracing::debug!("Joined job task");
-                self.last_join_result = Some(result);
             },
 
             notification = self.listener.recv() => {
@@ -393,135 +401,21 @@ impl QueueWorker {
             .try_get_leader_lease(&self.clock, &self.registration)
             .await?;
 
-        // Find any job task which finished
-        // If we got woken up by a join on the joinset, it will be stored in the
-        // last_join_result so that we don't loose it
-
-        if self.last_join_result.is_none() {
-            self.last_join_result = self.running_jobs.try_join_next_with_id();
-        }
-
-        while let Some(result) = self.last_join_result.take() {
-            // TODO: add metrics to track the job status and the time it took
-            match result {
-                Ok((id, Ok(()))) => {
-                    // The job succeeded
-                    let context = self
-                        .job_contexts
-                        .remove(&id)
-                        .expect("Job context not found");
-
-                    tracing::info!(
-                        job.id = %context.id,
-                        job.queue_name = %context.queue_name,
-                        job.attempt = %context.attempt,
-                        "Job completed"
-                    );
-
-                    repo.queue_job()
-                        .mark_as_completed(&self.clock, context.id)
-                        .await?;
-                }
-                Ok((id, Err(e))) => {
-                    // The job failed
-                    let context = self
-                        .job_contexts
-                        .remove(&id)
-                        .expect("Job context not found");
-
-                    let reason = format!("{:?}", e.error);
-                    repo.queue_job()
-                        .mark_as_failed(&self.clock, context.id, &reason)
-                        .await?;
-
-                    match e.decision {
-                        JobErrorDecision::Fail => {
-                            tracing::error!(
-                                error = &e as &dyn std::error::Error,
-                                job.id = %context.id,
-                                job.queue_name = %context.queue_name,
-                                job.attempt = %context.attempt,
-                                "Job failed, not retrying"
-                            );
-                        }
-
-                        JobErrorDecision::Retry => {
-                            if context.attempt < MAX_ATTEMPTS {
-                                tracing::warn!(
-                                    error = &e as &dyn std::error::Error,
-                                    job.id = %context.id,
-                                    job.queue_name = %context.queue_name,
-                                    job.attempt = %context.attempt,
-                                    "Job failed, will retry"
-                                );
-
-                                // TODO: retry with an exponential backoff, once we know how to
-                                // schedule jobs in the future
-                                repo.queue_job()
-                                    .retry(&mut self.rng, &self.clock, context.id)
-                                    .await?;
-                            } else {
-                                tracing::error!(
-                                    error = &e as &dyn std::error::Error,
-                                    job.id = %context.id,
-                                    job.queue_name = %context.queue_name,
-                                    job.attempt = %context.attempt,
-                                    "Job failed too many times, abandonning"
-                                );
-                            }
-                        }
-                    }
-                }
-                Err(e) => {
-                    // The job crashed (or was cancelled)
-                    let id = e.id();
-                    let context = self
-                        .job_contexts
-                        .remove(&id)
-                        .expect("Job context not found");
-
-                    let reason = e.to_string();
-                    repo.queue_job()
-                        .mark_as_failed(&self.clock, context.id, &reason)
-                        .await?;
-
-                    if context.attempt < MAX_ATTEMPTS {
-                        tracing::warn!(
-                            error = &e as &dyn std::error::Error,
-                            job.id = %context.id,
-                            job.queue_name = %context.queue_name,
-                            job.attempt = %context.attempt,
-                            "Job crashed, will retry"
-                        );
-
-                        repo.queue_job()
-                            .retry(&mut self.rng, &self.clock, context.id)
-                            .await?;
-                    } else {
-                        tracing::error!(
-                            error = &e as &dyn std::error::Error,
-                            job.id = %context.id,
-                            job.queue_name = %context.queue_name,
-                            job.attempt = %context.attempt,
-                            "Job crashed too many times, abandonning"
-                        );
-                    }
-                }
-            };
-
-            self.last_join_result = self.running_jobs.try_join_next_with_id();
-        }
+        // Process any job task which finished
+        self.tracker
+            .process_jobs(&mut self.rng, &self.clock, &mut repo, false)
+            .await?;
 
         // Compute how many jobs we should fetch at most
         let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
-            .saturating_sub(self.running_jobs.len())
+            .saturating_sub(self.tracker.running_jobs())
             .max(MAX_JOBS_TO_FETCH);
 
         if max_jobs_to_fetch == 0 {
             tracing::warn!("Internal job queue is full, not fetching any new jobs");
         } else {
             // Grab a few jobs in the queue
-            let queues = self.factories.keys().copied().collect::<Vec<_>>();
+            let queues = self.tracker.queues();
 
             let jobs = repo
                 .queue_job()
                 .reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch)
@@ -536,7 +430,6 @@ impl QueueWorker {
             } in jobs
             {
                 let cancellation_token = self.cancellation_token.child_token();
-                let factory = self.factories.get(queue_name.as_str()).cloned();
                 let context = JobContext {
                     id,
                     metadata,
@@ -545,21 +438,7 @@ impl QueueWorker {
                     queue_name,
                     attempt,
                     cancellation_token,
                 };
 
-                let task = {
-                    let context = context.clone();
-                    let span = context.span();
-                    let state = self.state.clone();
-                    async move {
-                        // We should never crash, but in case we do, we do that in the task and
-                        // don't crash the worker
-                        let job = factory.expect("unknown job factory")(payload);
-                        job.run(&state, context).await
-                    }
-                    .instrument(span)
-                };
-
-                let handle = self.running_jobs.spawn(task);
-                self.job_contexts.insert(handle.id(), context);
+                self.tracker.spawn_job(self.state.clone(), context, payload);
             }
         }
 
@@ -651,3 +530,211 @@ impl QueueWorker {
         Ok(())
     }
 }
+
+/// Tracks running jobs
+///
+/// This is a separate structure to be able to borrow it mutably at the same
+/// time as the connection to the database is borrowed
+#[derive(Default)]
+struct JobTracker {
+    /// Stores a mapping from the job queue name to the job factory
+    factories: HashMap<&'static str, JobFactory>,
+
+    /// A join set of all the currently running jobs
+    running_jobs: JoinSet<JobResult>,
+
+    /// Stores a mapping from the Tokio task ID to the job context
+    job_contexts: HashMap<tokio::task::Id, JobContext>,
+
+    /// Stores the last `join_next_with_id` result for processing, in case we
+    /// got woken up in `collect_next_job`
+    last_join_result: Option<Result<(tokio::task::Id, JobResult), tokio::task::JoinError>>,
+}
+
+impl JobTracker {
+    /// Returns the queue names that are currently being tracked
+    fn queues(&self) -> Vec<&'static str> {
+        self.factories.keys().copied().collect()
+    }
+
+    /// Spawn a job on the job tracker
+    fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
+        let factory = self.factories.get(context.queue_name.as_str()).cloned();
+        let task = {
+            let context = context.clone();
+            let span = context.span();
+            async move {
+                // We should never crash, but in case we do, we do that in the task and
+                // don't crash the worker
+                let job = factory.expect("unknown job factory")(payload);
+                tracing::info!("Running job");
+                job.run(&state, context).await
+            }
+            .instrument(span)
+        };
+
+        let handle = self.running_jobs.spawn(task);
+        self.job_contexts.insert(handle.id(), context);
+    }
+
+    /// Returns `true` if there are currently running jobs
+    fn has_jobs(&self) -> bool {
+        !self.running_jobs.is_empty()
+    }
+
+    /// Returns the number of currently running jobs
+    ///
+    /// This also includes the job result which may be stored for processing
+    fn running_jobs(&self) -> usize {
+        self.running_jobs.len() + usize::from(self.last_join_result.is_some())
+    }
+
+    /// Wait for the next job to finish and store its result for processing
+    async fn collect_next_job(&mut self) {
+        // Double-check that we don't have a job result stored
+        if self.last_join_result.is_some() {
+            tracing::error!(
+                "Job tracker already had a job result stored, this should never happen!"
+            );
+            return;
+        }
+
+        self.last_join_result = self.running_jobs.join_next_with_id().await;
+    }
+
+    /// Process all the jobs which are currently running
+    ///
+    /// If `blocking` is `true`, this function will block until all the jobs
+    /// are finished. Otherwise, it will return as soon as it processed the
+    /// already finished jobs.
+    #[allow(clippy::too_many_lines)]
+    async fn process_jobs<E: std::error::Error + Send + Sync + 'static>(
+        &mut self,
+        rng: &mut (dyn RngCore + Send),
+        clock: &dyn Clock,
+        repo: &mut dyn RepositoryAccess<Error = E>,
+        blocking: bool,
+    ) -> Result<(), E> {
+        if self.last_join_result.is_none() {
+            if blocking {
+                self.last_join_result = self.running_jobs.join_next_with_id().await;
+            } else {
+                self.last_join_result = self.running_jobs.try_join_next_with_id();
+            }
+        }
+
+        while let Some(result) = self.last_join_result.take() {
+            // TODO: add metrics to track the job status and the time it took
+            match result {
+                // The job succeeded
+                Ok((id, Ok(()))) => {
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    tracing::info!(
+                        job.id = %context.id,
+                        job.queue_name = %context.queue_name,
+                        job.attempt = %context.attempt,
+                        "Job completed"
+                    );
+
+                    repo.queue_job()
+                        .mark_as_completed(clock, context.id)
+                        .await?;
+                }
+
+                // The job failed
+                Ok((id, Err(e))) => {
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    let reason = format!("{:?}", e.error);
+                    repo.queue_job()
+                        .mark_as_failed(clock, context.id, &reason)
+                        .await?;
+
+                    match e.decision {
+                        JobErrorDecision::Fail => {
+                            tracing::error!(
+                                error = &e as &dyn std::error::Error,
+                                job.id = %context.id,
+                                job.queue_name = %context.queue_name,
+                                job.attempt = %context.attempt,
+                                "Job failed, not retrying"
+                            );
+                        }
+
+                        JobErrorDecision::Retry => {
+                            if context.attempt < MAX_ATTEMPTS {
+                                tracing::warn!(
+                                    error = &e as &dyn std::error::Error,
+                                    job.id = %context.id,
+                                    job.queue_name = %context.queue_name,
+                                    job.attempt = %context.attempt,
+                                    "Job failed, will retry"
+                                );
+
+                                // TODO: retry with an exponential backoff, once we know how to
+                                // schedule jobs in the future
+                                repo.queue_job().retry(&mut *rng, clock, context.id).await?;
+                            } else {
+                                tracing::error!(
+                                    error = &e as &dyn std::error::Error,
+                                    job.id = %context.id,
+                                    job.queue_name = %context.queue_name,
+                                    job.attempt = %context.attempt,
+                                    "Job failed too many times, abandoning"
+                                );
+                            }
+                        }
+                    }
+                }
+
+                // The job crashed (or was aborted)
+                Err(e) => {
+                    let id = e.id();
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    let reason = e.to_string();
+                    repo.queue_job()
+                        .mark_as_failed(clock, context.id, &reason)
+                        .await?;
+
+                    if context.attempt < MAX_ATTEMPTS {
+                        tracing::warn!(
+                            error = &e as &dyn std::error::Error,
+                            job.id = %context.id,
+                            job.queue_name = %context.queue_name,
+                            job.attempt = %context.attempt,
+                            "Job crashed, will retry"
+                        );
+
+                        repo.queue_job().retry(&mut *rng, clock, context.id).await?;
+                    } else {
+                        tracing::error!(
+                            error = &e as &dyn std::error::Error,
+                            job.id = %context.id,
+                            job.queue_name = %context.queue_name,
+                            job.attempt = %context.attempt,
+                            "Job crashed too many times, abandoning"
+                        );
+                    }
+                }
+            };
+
+            if blocking {
+                self.last_join_result = self.running_jobs.join_next_with_id().await;
+            } else {
+                self.last_join_result = self.running_jobs.try_join_next_with_id();
+            }
+        }
+
+        Ok(())
+    }
+}
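
Side note, not part of the patch above: the doc comment on `JobTracker` says it exists so the tracker can be borrowed mutably "at the same time as the connection to the database is borrowed". The following is a minimal, self-contained sketch of that split-borrow pattern under simplified assumptions; all names here (`Worker`, `Tracker`, `Connection`, `process`, `tick`) are invented for illustration and do not appear in the crate.

    // Rust allows disjoint borrows of different fields of the same struct, so a
    // worker can lend `&mut` access to its tracker while also lending out other
    // fields (RNG, DB connection) to the same call.

    struct Connection; // stands in for the PgListener-backed connection
    struct Tracker {
        running: usize,
    }

    impl Tracker {
        // Mirrors the shape of `JobTracker::process_jobs`: it needs `&mut self`
        // plus resources that the worker also owns.
        fn process(&mut self, rng: &mut u64, conn: &mut Connection) {
            // Pretend to consume some randomness and touch the connection
            *rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let _ = conn;
            self.running = self.running.saturating_sub(1);
        }
    }

    struct Worker {
        rng: u64,
        conn: Connection,
        tracker: Tracker,
    }

    impl Worker {
        fn tick(&mut self) {
            // Field-level borrow splitting: `&mut self.tracker`, `&mut self.rng`
            // and `&mut self.conn` can coexist because they are disjoint fields.
            // If the tracking state were inlined in `Worker` and `process` were a
            // `&mut self` method on it, the call could not also take `&mut self.conn`.
            let Worker { rng, conn, tracker } = self;
            tracker.process(rng, conn);
        }
    }

    fn main() {
        let mut worker = Worker {
            rng: 42,
            conn: Connection,
            tracker: Tracker { running: 1 },
        };
        worker.tick();
        assert_eq!(worker.tracker.running, 0);
    }

This is the same reason `shutdown` and `next_job` in the patch can call `self.tracker.process_jobs(&mut self.rng, &self.clock, &mut repo, ...)` while `repo` still borrows the worker's listener connection.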