From 1fdf0f793606eb4a07e6c1b10bdc6985fd82bb94 Mon Sep 17 00:00:00 2001
From: Quentin Gliech
Date: Tue, 15 Oct 2024 11:59:32 +0200
Subject: [PATCH 1/5] WIP: job consumption

---
 crates/tasks/src/new_queue.rs | 44 ++++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs
index 571f8591b..f90b72011 100644
--- a/crates/tasks/src/new_queue.rs
+++ b/crates/tasks/src/new_queue.rs
@@ -3,11 +3,18 @@
 // SPDX-License-Identifier: AGPL-3.0-only
 // Please see LICENSE in the repository root for full details.
 
+use std::collections::HashMap;
+
+use async_trait::async_trait;
 use chrono::{DateTime, Duration, Utc};
-use mas_storage::{queue::Worker, Clock, RepositoryAccess, RepositoryError};
+use mas_storage::{
+    queue::{InsertableJob, Job, Worker},
+    Clock, RepositoryAccess, RepositoryError,
+};
 use mas_storage_pg::{DatabaseError, PgRepository};
 use rand::{distributions::Uniform, Rng};
 use rand_chacha::ChaChaRng;
+use serde::de::DeserializeOwned;
 use sqlx::{
     postgres::{PgAdvisoryLock, PgListener},
     Acquire, Either,
@@ -17,6 +24,30 @@ use tokio_util::sync::CancellationToken;
 
 use crate::State;
 
+pub trait FromJob {
+    fn from_job(job: &Job) -> Result<Self, anyhow::Error>
+    where
+        Self: Sized;
+}
+
+impl<T> FromJob for T
+where
+    T: DeserializeOwned,
+{
+    fn from_job(job: &Job) -> Result<Self, anyhow::Error> {
+        serde_json::from_value(job.payload.clone()).map_err(Into::into)
+    }
+}
+
+#[async_trait]
+pub trait RunnableJob: FromJob + Send + 'static {
+    async fn run(&self, state: &State) -> Result<(), anyhow::Error>;
+}
+
+fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
+    Box::new(job)
+}
+
 #[derive(Debug, Error)]
 pub enum QueueRunnerError {
     #[error("Failed to setup listener")]
@@ -48,6 +79,8 @@ pub enum QueueRunnerError {
 const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
 const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);
 
+type JobFactory = Box<dyn Fn(&Job) -> Box<dyn RunnableJob> + Send>;
+
 pub struct QueueWorker {
     rng: ChaChaRng,
     clock: Box<dyn Clock + Send>,
@@ -56,6 +89,7 @@ pub struct QueueWorker {
     am_i_leader: bool,
     last_heartbeat: DateTime<Utc>,
     cancellation_token: CancellationToken,
+    factories: HashMap<&'static str, JobFactory>,
 }
 
 impl QueueWorker {
@@ -105,9 +139,17 @@ impl QueueWorker {
             am_i_leader: false,
             last_heartbeat: now,
             cancellation_token,
+            factories: HashMap::new(),
         })
     }
 
+    pub fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
+        // TODO: error handling
+        let factory = |job: &Job| box_runnable_job(T::from_job(job).unwrap());
+        self.factories.insert(T::QUEUE_NAME, Box::new(factory));
+        self
+    }
+
     pub async fn run(&mut self) -> Result<(), QueueRunnerError> {
         while !self.cancellation_token.is_cancelled() {
             self.run_loop().await?;

From c1e433556ca3ce018e88b3a0efa5dadd5eac35f0 Mon Sep 17 00:00:00 2001
From: Quentin Gliech
Date: Thu, 31 Oct 2024 17:38:43 +0100
Subject: [PATCH 2/5] Actually consume jobs

---
 ...b411aa9f15e7beccfd6212787c3452d35d061.json |  43 ++
 ...ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json |  15 +
 crates/storage-pg/src/queue/job.rs            | 130 +++++-
 crates/storage/src/queue/job.rs               |  48 ++-
 crates/tasks/src/email.rs                     | 162 ++++----
 crates/tasks/src/lib.rs                       |  20 +-
 crates/tasks/src/matrix.rs                    | 390 ++++++++---------
 crates/tasks/src/new_queue.rs                 | 208 +++++++++-
 crates/tasks/src/recovery.rs                  | 198 +++++----
 crates/tasks/src/storage/from_row.rs          |  70 ----
 crates/tasks/src/storage/mod.rs               |  14 -
 crates/tasks/src/storage/postgres.rs          | 391 ------------------
 crates/tasks/src/user.rs                      | 211 +++++-----
 crates/tasks/src/utils.rs                     |  91 ----
 14 files
changed, 907 insertions(+), 1084 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json create mode 100644 crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json delete mode 100644 crates/tasks/src/storage/from_row.rs delete mode 100644 crates/tasks/src/storage/mod.rs delete mode 100644 crates/tasks/src/storage/postgres.rs delete mode 100644 crates/tasks/src/utils.rs diff --git a/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json b/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json new file mode 100644 index 000000000..67f1ad132 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json @@ -0,0 +1,43 @@ +{ + "db_name": "PostgreSQL", + "query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.queue_name,\n queue_jobs.payload,\n queue_jobs.metadata\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "queue_job_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "queue_name", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "payload", + "type_info": "Jsonb" + }, + { + "ordinal": 3, + "name": "metadata", + "type_info": "Jsonb" + } + ], + "parameters": { + "Left": [ + "TextArray", + "Int8", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061" +} diff --git a/crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json b/crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json new file mode 100644 index 000000000..407258ab4 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_jobs\n SET status = 'completed', completed_at = $1\n WHERE queue_job_id = $2 AND status = 'running'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27" +} diff --git a/crates/storage-pg/src/queue/job.rs b/crates/storage-pg/src/queue/job.rs index 4b4433b50..90f8546a7 100644 --- a/crates/storage-pg/src/queue/job.rs +++ b/crates/storage-pg/src/queue/job.rs @@ -7,13 +7,16 @@ //! [`QueueJobRepository`]. 
 use async_trait::async_trait;
-use mas_storage::{queue::QueueJobRepository, Clock};
+use mas_storage::{
+    queue::{Job, QueueJobRepository, Worker},
+    Clock,
+};
 use rand::RngCore;
 use sqlx::PgConnection;
 use ulid::Ulid;
 use uuid::Uuid;
 
-use crate::{DatabaseError, ExecuteExt};
+use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt};
 
 /// An implementation of [`QueueJobRepository`] for a PostgreSQL connection.
 pub struct PgQueueJobRepository<'c> {
@@ -29,6 +32,37 @@ impl<'c> PgQueueJobRepository<'c> {
     }
 }
 
+struct JobReservationResult {
+    queue_job_id: Uuid,
+    queue_name: String,
+    payload: serde_json::Value,
+    metadata: serde_json::Value,
+}
+
+impl TryFrom<JobReservationResult> for Job {
+    type Error = DatabaseInconsistencyError;
+
+    fn try_from(value: JobReservationResult) -> Result<Self, Self::Error> {
+        let id = value.queue_job_id.into();
+        let queue_name = value.queue_name;
+        let payload = value.payload;
+
+        let metadata = serde_json::from_value(value.metadata).map_err(|e| {
+            DatabaseInconsistencyError::on("queue_jobs")
+                .column("metadata")
+                .row(id)
+                .source(e)
+        })?;
+
+        Ok(Self {
+            id,
+            queue_name,
+            payload,
+            metadata,
+        })
+    }
+}
+
 #[async_trait]
 impl<'c> QueueJobRepository for PgQueueJobRepository<'c> {
     type Error = DatabaseError;
@@ -73,4 +107,96 @@ impl<'c> QueueJobRepository for PgQueueJobRepository<'c> {
 
         Ok(())
     }
+
+    #[tracing::instrument(
+        name = "db.queue_job.reserve",
+        skip_all,
+        fields(
+            db.query.text,
+        ),
+        err,
+    )]
+    async fn reserve(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        count: usize,
+    ) -> Result<Vec<Job>, Self::Error> {
+        let now = clock.now();
+        let max_count = i64::try_from(count).unwrap_or(i64::MAX);
+        let queues: Vec<String> = queues.iter().map(|&s| s.to_owned()).collect();
+        let results = sqlx::query_as!(
+            JobReservationResult,
+            r#"
+            -- We first grab a few jobs that are available,
+            -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently
+            -- and we don't get multiple workers grabbing the same jobs
+            WITH locked_jobs AS (
+                SELECT queue_job_id
+                FROM queue_jobs
+                WHERE
+                    status = 'available'
+                    AND queue_name = ANY($1)
+                ORDER BY queue_job_id ASC
+                LIMIT $2
+                FOR UPDATE
+                SKIP LOCKED
+            )
+            -- then we update the status of those jobs to 'running', returning the job details
+            UPDATE queue_jobs
+            SET status = 'running', started_at = $3, started_by = $4
+            FROM locked_jobs
+            WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id
+            RETURNING
+                queue_jobs.queue_job_id,
+                queue_jobs.queue_name,
+                queue_jobs.payload,
+                queue_jobs.metadata
+            "#,
+            &queues,
+            max_count,
+            now,
+            Uuid::from(worker.id),
+        )
+        .traced()
+        .fetch_all(&mut *self.conn)
+        .await?;
+
+        let jobs = results
+            .into_iter()
+            .map(TryFrom::try_from)
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Ok(jobs)
+    }
+
+    #[tracing::instrument(
+        name = "db.queue_job.mark_as_completed",
+        skip_all,
+        fields(
+            db.query.text,
+            job.id = %id,
+        ),
+        err,
+    )]
+    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error> {
+        let now = clock.now();
+        let res = sqlx::query!(
+            r#"
+            UPDATE queue_jobs
+            SET status = 'completed', completed_at = $1
+            WHERE queue_job_id = $2 AND status = 'running'
+            "#,
+            now,
+            Uuid::from(id),
+        )
+        .traced()
+        .execute(&mut *self.conn)
+        .await?;
+
+        DatabaseError::ensure_affected_rows(&res, 1)?;
+
+        Ok(())
+    }
 }
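The reservation query above carries the whole concurrency story: FOR UPDATE SKIP LOCKED lets any number of workers poll the same table without double-claiming a row, and ordering by `queue_job_id` keeps the hand-out roughly first-in-first-out, since ULIDs sort by creation time. A minimal sketch of the same pattern, assuming a simplified, hypothetical `jobs(id BIGSERIAL PRIMARY KEY, status TEXT)` table rather than the real `queue_jobs` schema:

    use sqlx::PgPool;

    /// Claim up to `limit` available rows. Rows locked by a concurrent
    /// transaction are skipped rather than waited on, so two workers
    /// running this at the same time always get disjoint batches.
    async fn claim_batch(pool: &PgPool, limit: i64) -> Result<Vec<i64>, sqlx::Error> {
        let mut txn = pool.begin().await?;
        let ids: Vec<i64> = sqlx::query_scalar(
            "WITH picked AS (
                 SELECT id FROM jobs
                 WHERE status = 'available'
                 ORDER BY id
                 LIMIT $1
                 FOR UPDATE SKIP LOCKED
             )
             UPDATE jobs SET status = 'running'
             FROM picked WHERE jobs.id = picked.id
             RETURNING jobs.id",
        )
        .bind(limit)
        .fetch_all(&mut *txn)
        .await?;
        txn.commit().await?;
        Ok(ids)
    }
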
diff --git a/crates/storage/src/queue/job.rs b/crates/storage/src/queue/job.rs
index c5ec3f4f4..13df586d7 100644
--- a/crates/storage/src/queue/job.rs
+++ b/crates/storage/src/queue/job.rs
@@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize};
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 use ulid::Ulid;
 
+use super::Worker;
 use crate::{repository_impl, Clock};
 
 /// Represents a job in the job queue
@@ -19,6 +20,9 @@ pub struct Job {
     /// The ID of the job
     pub id: Ulid,
 
+    /// The queue on which the job was placed
+    pub queue_name: String,
+
     /// The payload of the job
     pub payload: serde_json::Value,
 
@@ -27,7 +31,7 @@ pub struct Job {
 }
 
 /// Metadata stored alongside the job
-#[derive(Serialize, Deserialize, Default)]
+#[derive(Serialize, Deserialize, Default, Clone, Debug)]
 pub struct JobMetadata {
     #[serde(default)]
     trace_id: String,
@@ -97,6 +101,38 @@ pub trait QueueJobRepository: Send + Sync {
         payload: serde_json::Value,
         metadata: serde_json::Value,
     ) -> Result<(), Self::Error>;
+
+    /// Reserve multiple jobs from multiple queues
+    ///
+    /// # Parameters
+    ///
+    /// * `clock` - The clock used to generate timestamps
+    /// * `worker` - The worker that is reserving the jobs
+    /// * `queues` - The queues to reserve jobs from
+    /// * `count` - The number of jobs to reserve
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the underlying repository fails.
+    async fn reserve(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        count: usize,
+    ) -> Result<Vec<Job>, Self::Error>;
+
+    /// Mark a job as completed
+    ///
+    /// # Parameters
+    ///
+    /// * `clock` - The clock used to generate timestamps
+    /// * `id` - The ID of the job to mark as completed
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the underlying repository fails.
+    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>;
 }
 
 repository_impl!(QueueJobRepository:
@@ -108,6 +144,16 @@ repository_impl!(QueueJobRepository:
         payload: serde_json::Value,
         metadata: serde_json::Value,
     ) -> Result<(), Self::Error>;
+
+    async fn reserve(
+        &mut self,
+        clock: &dyn Clock,
+        worker: &Worker,
+        queues: &[&str],
+        count: usize,
+    ) -> Result<Vec<Job>, Self::Error>;
+
+    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>;
 );
 
 /// Extension trait for [`QueueJobRepository`] to help adding a job to the queue
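Together, the two new methods give callers a full reserve, run, mark-as-completed life cycle. A hedged sketch of a single consume pass, assuming (as the call sites in this series suggest) that `RepositoryAccess` exposes this repository via `queue_job()` and that `InsertableJob` is the trait supplying `QUEUE_NAME`; actual execution, failure handling and retries belong to the worker in `crates/tasks/src/new_queue.rs`:

    use mas_storage::{
        queue::{InsertableJob, QueueJobRepository, VerifyEmailJob, Worker},
        Clock, RepositoryAccess,
    };

    async fn consume_once<R: RepositoryAccess>(
        repo: &mut R,
        clock: &dyn Clock,
        worker: &Worker,
    ) -> Result<(), R::Error> {
        // Reserve up to five jobs from a single queue...
        let jobs = repo
            .queue_job()
            .reserve(clock, worker, &[VerifyEmailJob::QUEUE_NAME], 5)
            .await?;

        for job in jobs {
            // ...a real worker would deserialize `job.payload` and run it here...
            repo.queue_job().mark_as_completed(clock, job.id).await?;
        }

        Ok(())
    }
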
diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs
index a16ca29dc..3afbab8ce 100644
--- a/crates/tasks/src/email.rs
+++ b/crates/tasks/src/email.rs
@@ -5,97 +5,87 @@
 // Please see LICENSE in the repository root for full details.
 
 use anyhow::Context;
-use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
+use async_trait::async_trait;
 use chrono::Duration;
 use mas_email::{Address, Mailbox};
 use mas_i18n::locale;
-use mas_storage::{job::JobWithSpanContext, queue::VerifyEmailJob};
+use mas_storage::queue::VerifyEmailJob;
 use mas_templates::{EmailVerificationContext, TemplateContext};
 use rand::{distributions::Uniform, Rng};
 use tracing::info;
 
-use crate::{storage::PostgresStorageFactory, JobContextExt, State};
-
-#[tracing::instrument(
-    name = "job.verify_email",
-    fields(user_email.id = %job.user_email_id()),
-    skip_all,
-    err(Debug),
-)]
-async fn verify_email(
-    job: JobWithSpanContext<VerifyEmailJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let mut repo = state.repository().await?;
-    let mut rng = state.rng();
-    let mailer = state.mailer();
-    let clock = state.clock();
-
-    let language = job
-        .language()
-        .and_then(|l| l.parse().ok())
-        .unwrap_or(locale!("en").into());
-
-    // Lookup the user email
-    let user_email = repo
-        .user_email()
-        .lookup(job.user_email_id())
-        .await?
-        .context("User email not found")?;
-
-    // Lookup the user associated with the email
-    let user = repo
-        .user()
-        .lookup(user_email.user_id)
-        .await?
-        .context("User not found")?;
-
-    // Generate a verification code
-    let range = Uniform::<u32>::from(0..1_000_000);
-    let code = rng.sample(range);
-    let code = format!("{code:06}");
-
-    let address: Address = user_email.email.parse()?;
-
-    // Save the verification code in the database
-    let verification = repo
-        .user_email()
-        .add_verification_code(
-            &mut rng,
-            &clock,
-            &user_email,
-            Duration::try_hours(8).unwrap(),
-            code,
-        )
-        .await?;
-
-    // And send the verification email
-    let mailbox = Mailbox::new(Some(user.username.clone()), address);
-
-    let context =
-        EmailVerificationContext::new(user.clone(), verification.clone()).with_language(language);
-
-    mailer.send_verification_email(mailbox, &context).await?;
-
-    info!(
-        email.id = %user_email.id,
-        "Verification email sent"
-    );
-
-    repo.save().await?;
-
-    Ok(())
-}
-
-pub(crate) fn register(
-    suffix: &str,
-    monitor: Monitor<TokioExecutor>,
-    state: &State,
-    storage_factory: &PostgresStorageFactory,
-) -> Monitor<TokioExecutor> {
-    let verify_email_worker =
-        crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory);
-
-    monitor.register(verify_email_worker)
+use crate::{
+    new_queue::{JobContext, RunnableJob},
+    State,
+};
+
+#[async_trait]
+impl RunnableJob for VerifyEmailJob {
+    #[tracing::instrument(
+        name = "job.verify_email",
+        fields(user_email.id = %self.user_email_id()),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let mut repo = state.repository().await?;
+        let mut rng = state.rng();
+        let mailer = state.mailer();
+        let clock = state.clock();
+
+        let language = self
+            .language()
+            .and_then(|l| l.parse().ok())
+            .unwrap_or(locale!("en").into());
+
+        // Lookup the user email
+        let user_email = repo
+            .user_email()
+            .lookup(self.user_email_id())
+            .await?
+            .context("User email not found")?;
+
+        // Lookup the user associated with the email
+        let user = repo
+            .user()
+            .lookup(user_email.user_id)
+            .await?
+            .context("User not found")?;
+
+        // Generate a verification code
+        let range = Uniform::<u32>::from(0..1_000_000);
+        let code = rng.sample(range);
+        let code = format!("{code:06}");
+
+        let address: Address = user_email.email.parse()?;
+
+        // Save the verification code in the database
+        let verification = repo
+            .user_email()
+            .add_verification_code(
+                &mut rng,
+                &clock,
+                &user_email,
+                Duration::try_hours(8).unwrap(),
+                code,
+            )
+            .await?;
+
+        // And send the verification email
+        let mailbox = Mailbox::new(Some(user.username.clone()), address);
+
+        let context = EmailVerificationContext::new(user.clone(), verification.clone())
+            .with_language(language);
+
+        mailer.send_verification_email(mailbox, &context).await?;
+
+        info!(
+            email.id = %user_email.id,
+            "Verification email sent"
+        );
+
+        repo.save().await?;
+
+        Ok(())
+    }
 }
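Every job port in this series follows the same recipe: the payload type from `mas_storage::queue` derives `Deserialize` (which gives it `FromJob` through the blanket impl in `new_queue.rs`), and the execution logic moves into an `#[async_trait]` `impl RunnableJob` on the type itself, replacing the old free function plus monitor registration. For a brand-new job the whole surface would look roughly like the sketch below; `PingJob` is hypothetical, and the `InsertableJob` impl assumes that trait is what supplies `QUEUE_NAME`, as `register_handler` implies:

    use async_trait::async_trait;
    use mas_storage::queue::InsertableJob;
    use serde::{Deserialize, Serialize};

    use crate::{
        new_queue::{JobContext, RunnableJob},
        State,
    };

    // Deserialize gives the type FromJob through the blanket impl;
    // Serialize plus QUEUE_NAME is what makes it schedulable.
    #[derive(Serialize, Deserialize)]
    pub struct PingJob;

    impl InsertableJob for PingJob {
        const QUEUE_NAME: &'static str = "ping";
    }

    #[async_trait]
    impl RunnableJob for PingJob {
        async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
            let mut repo = state.repository().await?;
            tracing::info!("pong");
            repo.save().await?;
            Ok(())
        }
    }

With that in place, a `worker.register_handler::<PingJob>()` call in `lib.rs` (shown below) is all it takes to start consuming the queue.
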
diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs
index e56a082c7..ad2ede868 100644
--- a/crates/tasks/src/lib.rs
+++ b/crates/tasks/src/lib.rs
@@ -18,14 +18,13 @@ use rand::SeedableRng;
 use sqlx::{Pool, Postgres};
 use tokio_util::{sync::CancellationToken, task::TaskTracker};
 
+// TODO: we need to have a way to schedule recurring tasks
 // mod database;
-// mod email;
-// mod matrix;
+mod email;
+mod matrix;
 mod new_queue;
-// mod recovery;
-// mod storage;
-// mod user;
-// mod utils;
+mod recovery;
+mod user;
 
 #[derive(Clone)]
 struct State {
@@ -111,6 +110,15 @@ pub async fn init(
     );
 
     let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?;
+    worker.register_handler::<mas_storage::queue::VerifyEmailJob>();
+    worker.register_handler::<mas_storage::queue::ProvisionUserJob>();
+    worker.register_handler::<mas_storage::queue::ProvisionDeviceJob>();
+    worker.register_handler::<mas_storage::queue::DeleteDeviceJob>();
+    worker.register_handler::<mas_storage::queue::SyncDevicesJob>();
+    worker.register_handler::<mas_storage::queue::DeactivateUserJob>();
+    worker.register_handler::<mas_storage::queue::ReactivateUserJob>();
+    worker.register_handler::<mas_storage::queue::SendAccountRecoveryEmailsJob>();
+
     task_tracker.spawn(async move {
         if let Err(e) = worker.run().await {
             tracing::error!(
diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs
index 3cc09b272..f4596c05f 100644
--- a/crates/tasks/src/matrix.rs
+++ b/crates/tasks/src/matrix.rs
@@ -7,239 +7,239 @@
 use std::collections::HashSet;
 
 use anyhow::Context;
-use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
+use async_trait::async_trait;
 use mas_data_model::Device;
 use mas_matrix::ProvisionRequest;
 use mas_storage::{
     compat::CompatSessionFilter,
-    job::{JobRepositoryExt as _, JobWithSpanContext},
     oauth2::OAuth2SessionFilter,
-    queue::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, SyncDevicesJob},
+    queue::{
+        DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _,
+        SyncDevicesJob,
+    },
     user::{UserEmailRepository, UserRepository},
     Pagination, RepositoryAccess,
 };
 use tracing::info;
 
-use crate::{storage::PostgresStorageFactory, JobContextExt, State};
+use crate::{
+    new_queue::{JobContext, RunnableJob},
+    State,
+};
 
 /// Job to provision a user on the Matrix homeserver.
-/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id}
-/// endpoint.
-#[tracing::instrument(
-    name = "job.provision_user"
-    fields(user.id = %job.user_id()),
-    skip_all,
-    err(Debug),
-)]
-async fn provision_user(
-    job: JobWithSpanContext<ProvisionUserJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let matrix = state.matrix_connection();
-    let mut repo = state.repository().await?;
-
-    let user = repo
-        .user()
-        .lookup(job.user_id())
-        .await?
-        .context("User not found")?;
-
-    let mxid = matrix.mxid(&user.username);
-    let emails = repo
-        .user_email()
-        .all(&user)
-        .await?
- .into_iter() - .filter(|email| email.confirmed_at.is_some()) - .map(|email| email.email) - .collect(); - let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); - - if let Some(display_name) = job.display_name_to_set() { - request = request.set_displayname(display_name.to_owned()); - } +/// This works by doing a PUT request to the +/// /_synapse/admin/v2/users/{user_id} endpoint. +#[async_trait] +impl RunnableJob for ProvisionUserJob { + #[tracing::instrument( + name = "job.provision_user" + fields(user.id = %self.user_id()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + let mut rng = state.rng(); + let clock = state.clock(); + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + let mxid = matrix.mxid(&user.username); + let emails = repo + .user_email() + .all(&user) + .await? + .into_iter() + .filter(|email| email.confirmed_at.is_some()) + .map(|email| email.email) + .collect(); + let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); + + if let Some(display_name) = self.display_name_to_set() { + request = request.set_displayname(display_name.to_owned()); + } - let created = matrix.provision_user(&request).await?; + let created = matrix.provision_user(&request).await?; - if created { - info!(%user.id, %mxid, "User created"); - } else { - info!(%user.id, %mxid, "User updated"); - } + if created { + info!(%user.id, %mxid, "User created"); + } else { + info!(%user.id, %mxid, "User updated"); + } - // Schedule a device sync job - let sync_device_job = SyncDevicesJob::new(&user); - repo.job().schedule_job(sync_device_job).await?; + // Schedule a device sync job + let sync_device_job = SyncDevicesJob::new(&user); + repo.queue_job() + .schedule_job(&mut rng, &clock, sync_device_job) + .await?; - repo.save().await?; + repo.save().await?; - Ok(()) + Ok(()) + } } /// Job to provision a device on the Matrix homeserver. /// /// This job is deprecated and therefore just schedules a [`SyncDevicesJob`] -#[tracing::instrument( - name = "job.provision_device" - fields( - user.id = %job.user_id(), - device.id = %job.device_id(), - ), - skip_all, - err(Debug), -)] -async fn provision_device( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? - .context("User not found")?; - - // Schedule a device sync job - repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; - - Ok(()) +#[async_trait] +impl RunnableJob for ProvisionDeviceJob { + #[tracing::instrument( + name = "job.provision_device" + fields( + user.id = %self.user_id(), + device.id = %self.device_id(), + ), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let mut repo = state.repository().await?; + let mut rng = state.rng(); + let clock = state.clock(); + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + // Schedule a device sync job + repo.queue_job() + .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) + .await?; + + Ok(()) + } } /// Job to delete a device from a user's account. 
 ///
 /// This job is deprecated and therefore just schedules a [`SyncDevicesJob`]
-#[tracing::instrument(
-    name = "job.delete_device"
-    fields(
-        user.id = %job.user_id(),
-        device.id = %job.device_id(),
-    ),
-    skip_all,
-    err(Debug),
-)]
-async fn delete_device(
-    job: JobWithSpanContext<DeleteDeviceJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let mut repo = state.repository().await?;
-
-    let user = repo
-        .user()
-        .lookup(job.user_id())
-        .await?
-        .context("User not found")?;
-
-    // Schedule a device sync job
-    repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
-
-    Ok(())
+#[async_trait]
+impl RunnableJob for DeleteDeviceJob {
+    #[tracing::instrument(
+        name = "job.delete_device"
+        fields(
+            user.id = %self.user_id(),
+            device.id = %self.device_id(),
+        ),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let mut rng = state.rng();
+        let clock = state.clock();
+        let mut repo = state.repository().await?;
+
+        let user = repo
+            .user()
+            .lookup(self.user_id())
+            .await?
+            .context("User not found")?;
+
+        // Schedule a device sync job
+        repo.queue_job()
+            .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
+            .await?;
+
+        Ok(())
+    }
 }
 
 /// Job to sync the list of devices of a user with the homeserver.
-#[tracing::instrument(
-    name = "job.sync_devices",
-    fields(user.id = %job.user_id()),
-    skip_all,
-    err(Debug),
-)]
-async fn sync_devices(
-    job: JobWithSpanContext<SyncDevicesJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let matrix = state.matrix_connection();
-    let mut repo = state.repository().await?;
-
-    let user = repo
-        .user()
-        .lookup(job.user_id())
-        .await?
-        .context("User not found")?;
-
-    // Lock the user sync to make sure we don't get into a race condition
-    repo.user().acquire_lock_for_sync(&user).await?;
-
-    let mut devices = HashSet::new();
-
-    // Cycle through all the compat sessions of the user, and grab the devices
-    let mut cursor = Pagination::first(100);
-    loop {
-        let page = repo
-            .compat_session()
-            .list(
-                CompatSessionFilter::new().for_user(&user).active_only(),
-                cursor,
-            )
-            .await?;
-
-        for (compat_session, _) in page.edges {
-            devices.insert(compat_session.device.as_str().to_owned());
-            cursor = cursor.after(compat_session.id);
-        }
+#[async_trait]
+impl RunnableJob for SyncDevicesJob {
+    #[tracing::instrument(
+        name = "job.sync_devices",
+        fields(user.id = %self.user_id()),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let matrix = state.matrix_connection();
+        let mut repo = state.repository().await?;
+
+        let user = repo
+            .user()
+            .lookup(self.user_id())
+            .await?
+ .context("User not found")?; + + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&user).await?; + + let mut devices = HashSet::new(); + + // Cycle through all the compat sessions of the user, and grab the devices + let mut cursor = Pagination::first(100); + loop { + let page = repo + .compat_session() + .list( + CompatSessionFilter::new().for_user(&user).active_only(), + cursor, + ) + .await?; + + for (compat_session, _) in page.edges { + devices.insert(compat_session.device.as_str().to_owned()); + cursor = cursor.after(compat_session.id); + } - if !page.has_next_page { - break; + if !page.has_next_page { + break; + } } - } - // Cycle though all the oauth2 sessions of the user, and grab the devices - let mut cursor = Pagination::first(100); - loop { - let page = repo - .oauth2_session() - .list( - OAuth2SessionFilter::new().for_user(&user).active_only(), - cursor, - ) - .await?; - - for oauth2_session in page.edges { - for scope in &*oauth2_session.scope { - if let Some(device) = Device::from_scope_token(scope) { - devices.insert(device.as_str().to_owned()); + // Cycle though all the oauth2 sessions of the user, and grab the devices + let mut cursor = Pagination::first(100); + loop { + let page = repo + .oauth2_session() + .list( + OAuth2SessionFilter::new().for_user(&user).active_only(), + cursor, + ) + .await?; + + for oauth2_session in page.edges { + for scope in &*oauth2_session.scope { + if let Some(device) = Device::from_scope_token(scope) { + devices.insert(device.as_str().to_owned()); + } } - } - cursor = cursor.after(oauth2_session.id); - } + cursor = cursor.after(oauth2_session.id); + } - if !page.has_next_page { - break; + if !page.has_next_page { + break; + } } - } - let mxid = matrix.mxid(&user.username); - matrix.sync_devices(&mxid, devices).await?; + let mxid = matrix.mxid(&user.username); + matrix.sync_devices(&mxid, devices).await?; - // We kept the connection until now, so that we still hold the lock on the user - // throughout the sync - repo.save().await?; + // We kept the connection until now, so that we still hold the lock on the user + // throughout the sync + repo.save().await?; - Ok(()) -} - -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let provision_user_worker = - crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory); - let provision_device_worker = - crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory); - let delete_device_worker = - crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory); - let sync_devices_worker = - crate::build!(SyncDevicesJob => sync_devices, suffix, state, storage_factory); - - monitor - .register(provision_user_worker) - .register(provision_device_worker) - .register(delete_device_worker) - .register(sync_devices_worker) + Ok(()) + } } diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index f90b72011..42a037af4 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -3,12 +3,12 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. 
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 
 use async_trait::async_trait;
 use chrono::{DateTime, Duration, Utc};
 use mas_storage::{
-    queue::{InsertableJob, Job, Worker},
+    queue::{InsertableJob, Job, JobMetadata, Worker},
     Clock, RepositoryAccess, RepositoryError,
 };
 use mas_storage_pg::{DatabaseError, PgRepository};
@@ -20,12 +20,42 @@ use sqlx::{
     Acquire, Either,
 };
 use thiserror::Error;
+use tokio::task::JoinSet;
 use tokio_util::sync::CancellationToken;
+use tracing::{Instrument as _, Span};
+use tracing_opentelemetry::OpenTelemetrySpanExt as _;
+use ulid::Ulid;
 
 use crate::State;
 
+type JobPayload = serde_json::Value;
+
+#[derive(Clone)]
+pub struct JobContext {
+    pub id: Ulid,
+    pub metadata: JobMetadata,
+    pub queue_name: String,
+    pub cancellation_token: CancellationToken,
+}
+
+impl JobContext {
+    pub fn span(&self) -> Span {
+        let span = tracing::info_span!(
+            parent: Span::none(),
+            "job.run",
+            job.id = %self.id,
+            job.queue_name = self.queue_name,
+        );
+
+        span.add_link(self.metadata.span_context());
+
+        span
+    }
+}
+
 pub trait FromJob {
-    fn from_job(job: &Job) -> Result<Self, anyhow::Error>
+    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
     where
         Self: Sized;
 }
@@ -34,14 +64,14 @@ impl<T> FromJob for T
 where
     T: DeserializeOwned,
 {
-    fn from_job(job: &Job) -> Result<Self, anyhow::Error> {
-        serde_json::from_value(job.payload.clone()).map_err(Into::into)
+    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
+        serde_json::from_value(payload).map_err(Into::into)
     }
 }
 
 #[async_trait]
 pub trait RunnableJob: FromJob + Send + 'static {
-    async fn run(&self, state: &State) -> Result<(), anyhow::Error>;
+    async fn run(&self, state: &State, context: JobContext) -> Result<(), anyhow::Error>;
 }
 
 fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
@@ -79,7 +109,13 @@ pub enum QueueRunnerError {
 const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
 const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);
 
-type JobFactory = Box<dyn Fn(&Job) -> Box<dyn RunnableJob> + Send>;
+// How many jobs can we run concurrently
+const MAX_CONCURRENT_JOBS: usize = 10;
+
+// How many jobs can we fetch at once
+const MAX_JOBS_TO_FETCH: usize = 5;
+
+type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
 
 pub struct QueueWorker {
     rng: ChaChaRng,
@@ -89,7 +125,14 @@ pub struct QueueWorker {
     am_i_leader: bool,
     last_heartbeat: DateTime<Utc>,
     cancellation_token: CancellationToken,
+    state: State,
+    running_jobs: JoinSet<Result<(), anyhow::Error>>,
+    job_contexts: HashMap<tokio::task::Id, JobContext>,
     factories: HashMap<&'static str, JobFactory>,
+
+    #[allow(clippy::type_complexity)]
+    last_join_result:
+        Option<Result<(tokio::task::Id, Result<(), anyhow::Error>), tokio::task::JoinError>>,
 }
 
 impl QueueWorker {
@@ -115,6 +158,12 @@ impl QueueWorker {
         .await
         .map_err(QueueRunnerError::SetupListener)?;
 
+        // We get notifications when a job is available on this channel
+        listener
+            .listen("queue_available")
+            .await
+            .map_err(QueueRunnerError::SetupListener)?;
+
         let txn = listener
             .begin()
             .await
@@ -139,14 +188,22 @@ impl QueueWorker {
             am_i_leader: false,
             last_heartbeat: now,
             cancellation_token,
+            state,
+            job_contexts: HashMap::new(),
+            running_jobs: JoinSet::new(),
             factories: HashMap::new(),
+            last_join_result: None,
         })
     }
 
     pub fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
-        // TODO: error handling
-        let factory = |job: &Job| box_runnable_job(T::from_job(job).unwrap());
-        self.factories.insert(T::QUEUE_NAME, Box::new(factory));
+        // There is a potential panic here, which is fine as it's going to be caught
+        // within the job task
+        let factory = |payload: JobPayload| {
+            box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
+        };
+
+        self.factories.insert(T::QUEUE_NAME, Arc::new(factory));
         self
     }
 
@@ -164,6 +221,7 @@ impl QueueWorker {
     async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
         self.wait_until_wakeup().await?;
 
+        // TODO: join all the jobs handles when shutting down
         if self.cancellation_token.is_cancelled() {
             return Ok(());
         }
@@ -214,6 +272,8 @@ impl QueueWorker {
             .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
         let wakeup_sleep = tokio::time::sleep(sleep_duration);
 
+        // TODO: add metrics to track the wake up reasons
+
         tokio::select! {
             () = self.cancellation_token.cancelled() => {
                 tracing::debug!("Woke up from cancellation");
@@ -223,6 +283,11 @@ impl QueueWorker {
                 tracing::debug!("Woke up from sleep");
             },
 
+            Some(result) = self.running_jobs.join_next_with_id() => {
+                tracing::debug!("Joined job task");
+                self.last_join_result = Some(result);
+            },
+
             notification = self.listener.recv() => {
                 match notification {
                     Ok(notification) => {
@@ -281,6 +346,127 @@ impl QueueWorker {
             .try_get_leader_lease(&self.clock, &self.registration)
             .await?;
 
+        // Find any job task which finished
+        // If we got woken up by a join on the joinset, it will be stored in the
+        // last_join_result so that we don't lose it
+
+        if self.last_join_result.is_none() {
+            self.last_join_result = self.running_jobs.try_join_next_with_id();
+        }
+
+        while let Some(result) = self.last_join_result.take() {
+            // TODO: add metrics to track the job status and the time it took
+            let context = match result {
+                Ok((id, Ok(()))) => {
+                    // The job succeeded
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    tracing::info!(
+                        job.id = %context.id,
+                        job.queue_name = %context.queue_name,
+                        "Job completed"
+                    );
+
+                    context
+                }
+                Ok((id, Err(e))) => {
+                    // The job failed
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    tracing::error!(
+                        error = ?e,
+                        job.id = %context.id,
+                        job.queue_name = %context.queue_name,
+                        "Job failed"
+                    );
+
+                    // TODO: reschedule the job
+
+                    context
+                }
+                Err(e) => {
+                    // The job crashed (or was cancelled)
+                    let id = e.id();
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    tracing::error!(
+                        error = &e as &dyn std::error::Error,
+                        job.id = %context.id,
+                        job.queue_name = %context.queue_name,
+                        "Job crashed"
+                    );
+
+                    // TODO: reschedule the job
+
+                    context
+                }
+            };
+
+            repo.queue_job()
+                .mark_as_completed(&self.clock, context.id)
+                .await?;
+
+            self.last_join_result = self.running_jobs.try_join_next_with_id();
+        }
+
+        // Compute how many jobs we should fetch at most
+        let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
+            .saturating_sub(self.running_jobs.len())
+            .min(MAX_JOBS_TO_FETCH);
+
+        if max_jobs_to_fetch == 0 {
+            tracing::warn!("Internal job queue is full, not fetching any new jobs");
+        } else {
+            // Grab a few jobs in the queue
+            let queues = self.factories.keys().copied().collect::<Vec<_>>();
+            let jobs = repo
+                .queue_job()
+                .reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch)
+                .await?;
+
+            for Job {
+                id,
+                queue_name,
+                payload,
+                metadata,
+            } in jobs
+            {
+                let cancellation_token = self.cancellation_token.child_token();
+                let factory = self.factories.get(queue_name.as_str()).cloned();
+                let context = JobContext {
+                    id,
+                    metadata,
+                    queue_name,
+                    cancellation_token,
+                };
+
+                let task = {
+                    let context = context.clone();
+                    let span = context.span();
+                    let state = self.state.clone();
+                    async move {
+                        // We should never crash, but in case we do, we do that in the task and
+                        // don't crash the worker
+                        let job = factory.expect("unknown job factory")(payload);
+                        job.run(&state, context).await
+                    }
+                    .instrument(span)
+                };
+
+                let handle = self.running_jobs.spawn(task);
+                self.job_contexts.insert(handle.id(), context);
+            }
+        }
+
         // After this point, we are locking the leader table, so it's important that we
         // commit as soon as possible to not block the other workers for too long
         repo.into_inner()
@@ -353,6 +539,8 @@ impl QueueWorker {
             .shutdown_dead_workers(&self.clock, Duration::minutes(2))
             .await?;
 
+        // TODO: mark tasks those workers had as lost
+
         // Release the leader lock
         let txn = repo
            .into_inner()
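The job bookkeeping in this diff hangs off one tokio detail: a `JoinSet` hands back the same `tokio::task::Id` from `spawn` (via the returned `AbortHandle`) and from `join_next_with_id`, so per-job state can live in a side map keyed by that id and be reclaimed exactly once per task, whether the task returned or panicked. Stripped of the queue machinery, the pattern is just this self-contained sketch:

    use std::collections::HashMap;

    use tokio::task::JoinSet;

    #[tokio::main]
    async fn main() {
        let mut running: JoinSet<Result<(), &'static str>> = JoinSet::new();
        let mut contexts: HashMap<tokio::task::Id, String> = HashMap::new();

        for n in 0..3 {
            // The AbortHandle's id matches the id reported when the task joins
            let handle = running.spawn(async move {
                if n == 1 { Err("boom") } else { Ok(()) }
            });
            contexts.insert(handle.id(), format!("job-{n}"));
        }

        while let Some(result) = running.join_next_with_id().await {
            match result {
                // The task ran to completion; the inner Result is the job outcome
                Ok((id, outcome)) => {
                    let ctx = contexts.remove(&id).expect("context not found");
                    println!("{ctx} finished with {outcome:?}");
                }
                // The task panicked or was aborted; the JoinError still carries the id
                Err(join_error) => {
                    let ctx = contexts.remove(&join_error.id()).expect("context not found");
                    println!("{ctx} crashed");
                }
            }
        }
    }
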
diff --git a/crates/tasks/src/recovery.rs b/crates/tasks/src/recovery.rs
index 79f469b06..cd3787d2a 100644
--- a/crates/tasks/src/recovery.rs
+++ b/crates/tasks/src/recovery.rs
@@ -5,11 +5,10 @@
 // Please see LICENSE in the repository root for full details.
 
 use anyhow::Context;
-use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
+use async_trait::async_trait;
 use mas_email::{Address, Mailbox};
 use mas_i18n::DataLocale;
 use mas_storage::{
-    job::JobWithSpanContext,
     queue::SendAccountRecoveryEmailsJob,
     user::{UserEmailFilter, UserRecoveryRepository},
     Pagination, RepositoryAccess,
@@ -18,117 +17,108 @@ use mas_templates::{EmailRecoveryContext, TemplateContext};
 use rand::distributions::{Alphanumeric, DistString};
 use tracing::{error, info};
 
-use crate::{storage::PostgresStorageFactory, JobContextExt, State};
+use crate::{
+    new_queue::{JobContext, RunnableJob},
+    State,
+};
 
 /// Job to send account recovery emails for a given recovery session.
-#[tracing::instrument(
-    name = "job.send_account_recovery_email",
-    fields(
-        user_recovery_session.id = %job.user_recovery_session_id(),
-        user_recovery_session.email,
-    ),
-    skip_all,
-    err(Debug),
-)]
-async fn send_account_recovery_email_job(
-    job: JobWithSpanContext<SendAccountRecoveryEmailsJob>,
-    ctx: JobContext,
-) -> Result<(), anyhow::Error> {
-    let state = ctx.state();
-    let clock = state.clock();
-    let mailer = state.mailer();
-    let url_builder = state.url_builder();
-    let mut rng = state.rng();
-    let mut repo = state.repository().await?;
-
-    let session = repo
-        .user_recovery()
-        .lookup_session(job.user_recovery_session_id())
-        .await?
-        .context("User recovery session not found")?;
-
-    tracing::Span::current().record("user_recovery_session.email", &session.email);
-
-    if session.consumed_at.is_some() {
-        info!("Recovery session already consumed, not sending email");
-        return Ok(());
-    }
+#[async_trait]
+impl RunnableJob for SendAccountRecoveryEmailsJob {
+    #[tracing::instrument(
+        name = "job.send_account_recovery_email",
+        fields(
+            user_recovery_session.id = %self.user_recovery_session_id(),
+            user_recovery_session.email,
+        ),
+        skip_all,
+        err(Debug),
+    )]
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
+        let clock = state.clock();
+        let mailer = state.mailer();
+        let url_builder = state.url_builder();
+        let mut rng = state.rng();
+        let mut repo = state.repository().await?;
+
+        let session = repo
+            .user_recovery()
+            .lookup_session(self.user_recovery_session_id())
+            .await?
+ .context("User recovery session not found")?; + + tracing::Span::current().record("user_recovery_session.email", &session.email); + + if session.consumed_at.is_some() { + info!("Recovery session already consumed, not sending email"); + return Ok(()); + } - let mut cursor = Pagination::first(50); - - let lang: DataLocale = session - .locale - .parse() - .context("Invalid locale in database on recovery session")?; - - loop { - let page = repo - .user_email() - .list( - UserEmailFilter::new() - .for_email(&session.email) - .verified_only(), - cursor, - ) - .await?; - - for email in page.edges { - let ticket = Alphanumeric.sample_string(&mut rng, 32); - - let ticket = repo - .user_recovery() - .add_ticket(&mut rng, &clock, &session, &email, ticket) - .await?; + let mut cursor = Pagination::first(50); + + let lang: DataLocale = session + .locale + .parse() + .context("Invalid locale in database on recovery session")?; - let user_email = repo + loop { + let page = repo .user_email() - .lookup(email.id) - .await? - .context("User email not found")?; - - let user = repo - .user() - .lookup(user_email.user_id) - .await? - .context("User not found")?; - - let url = url_builder.account_recovery_link(ticket.ticket); - - let address: Address = user_email.email.parse()?; - let mailbox = Mailbox::new(Some(user.username.clone()), address); - - info!("Sending recovery email to {}", mailbox); - let context = - EmailRecoveryContext::new(user, session.clone(), url).with_language(lang.clone()); - - // XXX: we only log if the email fails to send, to avoid stopping the loop - if let Err(e) = mailer.send_recovery_email(mailbox, &context).await { - error!( - error = &e as &dyn std::error::Error, - "Failed to send recovery email" - ); - } + .list( + UserEmailFilter::new() + .for_email(&session.email) + .verified_only(), + cursor, + ) + .await?; - cursor = cursor.after(email.id); - } + for email in page.edges { + let ticket = Alphanumeric.sample_string(&mut rng, 32); - if !page.has_next_page { - break; - } - } + let ticket = repo + .user_recovery() + .add_ticket(&mut rng, &clock, &session, &email, ticket) + .await?; - repo.save().await?; + let user_email = repo + .user_email() + .lookup(email.id) + .await? + .context("User email not found")?; - Ok(()) -} + let user = repo + .user() + .lookup(user_email.user_id) + .await? 
+ .context("User not found")?; -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let send_user_recovery_email_worker = crate::build!(SendAccountRecoveryEmailsJob => send_account_recovery_email_job, suffix, state, storage_factory); + let url = url_builder.account_recovery_link(ticket.ticket); - monitor.register(send_user_recovery_email_worker) + let address: Address = user_email.email.parse()?; + let mailbox = Mailbox::new(Some(user.username.clone()), address); + + info!("Sending recovery email to {}", mailbox); + let context = EmailRecoveryContext::new(user, session.clone(), url) + .with_language(lang.clone()); + + // XXX: we only log if the email fails to send, to avoid stopping the loop + if let Err(e) = mailer.send_recovery_email(mailbox, &context).await { + error!( + error = &e as &dyn std::error::Error, + "Failed to send recovery email" + ); + } + + cursor = cursor.after(email.id); + } + + if !page.has_next_page { + break; + } + } + + repo.save().await?; + + Ok(()) + } } diff --git a/crates/tasks/src/storage/from_row.rs b/crates/tasks/src/storage/from_row.rs deleted file mode 100644 index 5acf6848a..000000000 --- a/crates/tasks/src/storage/from_row.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::str::FromStr; - -use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId}; -use chrono::{DateTime, Utc}; -use serde_json::Value; -use sqlx::Row; - -/// Wrapper for [`JobRequest`] -pub(crate) struct SqlJobRequest(JobRequest); - -impl From> for JobRequest { - fn from(val: SqlJobRequest) -> Self { - val.0 - } -} - -impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow> - for SqlJobRequest -{ - fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { - let job: Value = row.try_get("job")?; - let id: JobId = - JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode { - index: "id".to_owned(), - source: Box::new(e), - })?; - let mut context = JobContext::new(id); - - let run_at = row.try_get("run_at")?; - context.set_run_at(run_at); - - let attempts = row.try_get("attempts").unwrap_or(0); - context.set_attempts(attempts); - - let max_attempts = row.try_get("max_attempts").unwrap_or(25); - context.set_max_attempts(max_attempts); - - let done_at: Option> = row.try_get("done_at").unwrap_or_default(); - context.set_done_at(done_at); - - let lock_at: Option> = row.try_get("lock_at").unwrap_or_default(); - context.set_lock_at(lock_at); - - let last_error = row.try_get("last_error").unwrap_or_default(); - context.set_last_error(last_error); - - let status: String = row.try_get("status")?; - context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode { - index: "job".to_owned(), - source: Box::new(e), - })?); - - let lock_by: Option = row.try_get("lock_by").unwrap_or_default(); - context.set_lock_by(lock_by.map(WorkerId::new)); - - Ok(SqlJobRequest(JobRequest::new_with_context( - serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode { - index: "job".to_owned(), - source: Box::new(e), - })?, - context, - ))) - } -} diff --git a/crates/tasks/src/storage/mod.rs b/crates/tasks/src/storage/mod.rs deleted file mode 100644 index 5f6e77e31..000000000 --- a/crates/tasks/src/storage/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -// 
Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a -//! shared connection for the [`PgListener`] - -mod from_row; -mod postgres; - -use self::from_row::SqlJobRequest; -pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory; diff --git a/crates/tasks/src/storage/postgres.rs b/crates/tasks/src/storage/postgres.rs deleted file mode 100644 index f709579ed..000000000 --- a/crates/tasks/src/storage/postgres.rs +++ /dev/null @@ -1,391 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::{convert::TryInto, marker::PhantomData, ops::Add, sync::Arc, time::Duration}; - -use apalis_core::{ - error::JobStreamError, - job::{Job, JobId, JobStreamResult}, - request::JobRequest, - storage::{StorageError, StorageResult, StorageWorkerPulse}, - utils::Timer, - worker::WorkerId, -}; -use async_stream::try_stream; -use chrono::{DateTime, Utc}; -use event_listener::Event; -use futures_lite::{Stream, StreamExt}; -use serde::{de::DeserializeOwned, Serialize}; -use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row}; -use tokio::task::JoinHandle; - -use super::SqlJobRequest; - -pub struct StorageFactory { - pool: PgPool, - event: Arc, -} - -impl StorageFactory { - pub fn new(pool: Pool) -> Self { - StorageFactory { - pool, - event: Arc::new(Event::new()), - } - } - - pub async fn listen(self) -> Result, sqlx::Error> { - let mut listener = PgListener::connect_with(&self.pool).await?; - listener.listen("apalis::job").await?; - - let handle = tokio::spawn(async move { - loop { - let notification = listener.recv().await.expect("Failed to poll notification"); - self.event.notify(usize::MAX); - tracing::debug!(?notification, "Broadcast notification"); - } - }); - - Ok(handle) - } - - pub fn build(&self) -> Storage { - Storage { - pool: self.pool.clone(), - event: self.event.clone(), - job_type: PhantomData, - } - } -} - -/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres -#[derive(Debug)] -pub struct Storage { - pool: PgPool, - event: Arc, - job_type: PhantomData, -} - -impl Clone for Storage { - fn clone(&self) -> Self { - Storage { - pool: self.pool.clone(), - event: self.event.clone(), - job_type: PhantomData, - } - } -} - -impl Storage { - fn stream_jobs( - &self, - worker_id: &WorkerId, - interval: Duration, - buffer_size: usize, - ) -> impl Stream, JobStreamError>> { - let pool = self.pool.clone(); - let sleeper = apalis_core::utils::timer::TokioTimer; - let worker_id = worker_id.clone(); - let event = self.event.clone(); - try_stream! { - loop { - // Wait for a notification or a timeout - let listener = event.listen(); - let interval = sleeper.sleep(interval); - futures_lite::future::race(interval, listener).await; - - let tx = pool.clone(); - let job_type = T::NAME; - let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);"; - let jobs: Vec> = sqlx::query_as(fetch_query) - .bind(worker_id.name()) - .bind(job_type) - // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html - .bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?) 
- .fetch_all(&tx) - .await.map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?; - for job in jobs { - yield job.into() - } - } - } - } - - async fn keep_alive_at( - &mut self, - worker_id: &WorkerId, - last_seen: DateTime, - ) -> StorageResult<()> { - let pool = self.pool.clone(); - - let worker_type = T::NAME; - let storage_name = std::any::type_name::(); - let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (id) DO - UPDATE SET last_seen = EXCLUDED.last_seen"; - sqlx::query(query) - .bind(worker_id.name()) - .bind(worker_type) - .bind(storage_name) - .bind(std::any::type_name::()) - .bind(last_seen) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } -} - -#[async_trait::async_trait] -impl apalis_core::storage::Storage for Storage -where - T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, -{ - type Output = T; - - /// Push a job to Postgres [Storage] - /// - /// # SQL Example - /// - /// ```sql - /// SELECT apalis.push_job(job_type::text, job::json); - /// ``` - async fn push(&mut self, job: Self::Output) -> StorageResult { - let id = JobId::new(); - let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)"; - let pool = self.pool.clone(); - let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?; - let job_type = T::NAME; - sqlx::query(query) - .bind(job) - .bind(id.to_string()) - .bind(job_type) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(id) - } - - async fn schedule( - &mut self, - job: Self::Output, - on: chrono::DateTime, - ) -> StorageResult { - let query = - "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)"; - - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - - let id = JobId::new(); - let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?; - let job_type = T::NAME; - sqlx::query(query) - .bind(job) - .bind(id.to_string()) - .bind(job_type) - .bind(on) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(id) - } - - async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult>> { - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - - let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1"; - let res: Option> = sqlx::query_as(fetch_query) - .bind(job_id.to_string()) - .fetch_optional(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(res.map(Into::into)) - } - - async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult { - match pulse { - StorageWorkerPulse::EnqueueScheduled { count: _ } => { - // Ideally jobs are queue via run_at. So this is not necessary - Ok(true) - } - - // Worker not seen in 5 minutes yet has running jobs - StorageWorkerPulse::ReenqueueOrphaned { count, .. 
} => { - let job_type = T::NAME; - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - let query = "UPDATE apalis.jobs - SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned' - WHERE id in - (SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id - WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes' - AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);"; - sqlx::query(query) - .bind(job_type) - .bind(count) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(true) - } - - _ => unimplemented!(), - } - } - - async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - /// Puts the job instantly back into the queue - /// Another [Worker] may consume - async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - fn consume( - &mut self, - worker_id: &WorkerId, - interval: Duration, - buffer_size: usize, - ) -> JobStreamResult { - Box::pin( - self.stream_jobs(worker_id, interval, buffer_size) - .map(|r| r.map(Some)), - ) - } - async fn len(&self) -> StorageResult { - let pool = self.pool.clone(); - let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'"; - let record = sqlx::query(query) - .fetch_one(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(record - .try_get("count") - .map_err(|e| StorageError::Database(Box::from(e)))?) - } - async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - let query = - "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn reschedule(&mut self, job: &JobRequest, wait: Duration) -> StorageResult<()> { - let pool = self.pool.clone(); - let job_id = job.id(); - - let wait: i64 = wait - .as_secs() - .try_into() - .map_err(|e| StorageError::Database(Box::new(e)))?; - let wait = chrono::Duration::microseconds(wait * 1000 * 1000); - // TODO: should we use a clock here? 
- #[allow(clippy::disallowed_methods)] - let run_at = Utc::now().add(wait); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(run_at) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn update_by_id( - &self, - job_id: &JobId, - job: &JobRequest, - ) -> StorageResult<()> { - let pool = self.pool.clone(); - let status = job.status().as_ref(); - let attempts = job.attempts(); - let done_at = *job.done_at(); - let lock_by = job.lock_by().clone(); - let lock_at = *job.lock_at(); - let last_error = job.last_error().clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7"; - sqlx::query(query) - .bind(status.to_owned()) - .bind(attempts) - .bind(done_at) - .bind(lock_by.as_ref().map(WorkerId::name)) - .bind(lock_at) - .bind(last_error) - .bind(job_id.to_string()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn keep_alive(&mut self, worker_id: &WorkerId) -> StorageResult<()> { - #[allow(clippy::disallowed_methods)] - let now = Utc::now(); - - self.keep_alive_at::(worker_id, now).await - } -} diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index b3d062bb4..ad4444be5 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -5,10 +5,9 @@ // Please see LICENSE in the repository root for full details. use anyhow::Context; -use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use async_trait::async_trait; use mas_storage::{ compat::CompatSessionFilter, - job::JobWithSpanContext, oauth2::OAuth2SessionFilter, queue::{DeactivateUserJob, ReactivateUserJob}, user::{BrowserSessionFilter, UserRepository}, @@ -16,122 +15,106 @@ use mas_storage::{ }; use tracing::info; -use crate::{storage::PostgresStorageFactory, JobContextExt, State}; +use crate::{ + new_queue::{JobContext, RunnableJob}, + State, +}; /// Job to deactivate a user, both locally and on the Matrix homeserver. -#[tracing::instrument( +#[async_trait] +impl RunnableJob for DeactivateUserJob { + #[tracing::instrument( name = "job.deactivate_user" - fields(user.id = %job.user_id(), erase = %job.hs_erase()), - skip_all, - err(Debug), -)] -async fn deactivate_user( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let clock = state.clock(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? 
- .context("User not found")?; - - // Let's first lock the user - let user = repo - .user() - .lock(&clock, user) - .await - .context("Failed to lock user")?; - - // Kill all sessions for the user - let n = repo - .browser_session() - .finish_bulk( - &clock, - BrowserSessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all browser sessions for user"); - - let n = repo - .oauth2_session() - .finish_bulk( - &clock, - OAuth2SessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all OAuth 2.0 sessions for user"); - - let n = repo - .compat_session() - .finish_bulk( - &clock, - CompatSessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all compatibility sessions for user"); - - // Before calling back to the homeserver, commit the changes to the database, as - // we want the user to be locked out as soon as possible - repo.save().await?; - - let mxid = matrix.mxid(&user.username); - info!("Deactivating user {} on homeserver", mxid); - matrix.delete_user(&mxid, job.hs_erase()).await?; - - Ok(()) + fields(user.id = %self.user_id(), erase = %self.hs_erase()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let clock = state.clock(); + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + // Let's first lock the user + let user = repo + .user() + .lock(&clock, user) + .await + .context("Failed to lock user")?; + + // Kill all sessions for the user + let n = repo + .browser_session() + .finish_bulk( + &clock, + BrowserSessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all browser sessions for user"); + + let n = repo + .oauth2_session() + .finish_bulk( + &clock, + OAuth2SessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all OAuth 2.0 sessions for user"); + + let n = repo + .compat_session() + .finish_bulk( + &clock, + CompatSessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all compatibility sessions for user"); + + // Before calling back to the homeserver, commit the changes to the database, as + // we want the user to be locked out as soon as possible + repo.save().await?; + + let mxid = matrix.mxid(&user.username); + info!("Deactivating user {} on homeserver", mxid); + matrix.delete_user(&mxid, self.hs_erase()).await?; + + Ok(()) + } } /// Job to reactivate a user, both locally and on the Matrix homeserver. -#[tracing::instrument( - name = "job.reactivate_user", - fields(user.id = %job.user_id()), - skip_all, - err(Debug), -)] -pub async fn reactivate_user( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? 
- .context("User not found")?; - - let mxid = matrix.mxid(&user.username); - info!("Reactivating user {} on homeserver", mxid); - matrix.reactivate_user(&mxid).await?; - - // We want to unlock the user from our side only once it has been reactivated on - // the homeserver - let _user = repo.user().unlock(user).await?; - repo.save().await?; - - Ok(()) -} - -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let deactivate_user_worker = - crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory); - - let reactivate_user_worker = - crate::build!(ReactivateUserJob => reactivate_user, suffix, state, storage_factory); - - monitor - .register(deactivate_user_worker) - .register(reactivate_user_worker) +#[async_trait] +impl RunnableJob for ReactivateUserJob { + #[tracing::instrument( + name = "job.reactivate_user", + fields(user.id = %self.user_id()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + let mxid = matrix.mxid(&user.username); + info!("Reactivating user {} on homeserver", mxid); + matrix.reactivate_user(&mxid).await?; + + // We want to unlock the user from our side only once it has been reactivated on + // the homeserver + let _user = repo.user().unlock(user).await?; + repo.save().await?; + + Ok(()) + } } diff --git a/crates/tasks/src/utils.rs b/crates/tasks/src/utils.rs deleted file mode 100644 index c5862f9cf..000000000 --- a/crates/tasks/src/utils.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use apalis_core::{job::Job, request::JobRequest}; -use mas_storage::job::JobWithSpanContext; -use mas_tower::{ - make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer, - TraceLayer, KV, -}; -use opentelemetry::{trace::SpanContext, Key, KeyValue}; -use tracing::info_span; -use tracing_opentelemetry::OpenTelemetrySpanExt; - -const JOB_NAME: Key = Key::from_static_str("job.name"); -const JOB_STATUS: Key = Key::from_static_str("job.status"); - -/// Represents a job that can may have a span context attached to it. -pub trait TracedJob: Job { - /// Returns the span context for this job, if any. - /// - /// The default implementation returns `None`. - fn span_context(&self) -> Option { - None - } -} - -/// Implements [`TracedJob`] for any job with the [`JobWithSpanContext`] -/// wrapper. 
-impl<J: Job> TracedJob for JobWithSpanContext<J> {
-    fn span_context(&self) -> Option<SpanContext> {
-        JobWithSpanContext::span_context(self)
-    }
-}
-
-fn make_span_for_job_request<J: TracedJob>(req: &JobRequest<J>) -> tracing::Span {
-    let span = info_span!(
-        "job.run",
-        "otel.kind" = "consumer",
-        "otel.status_code" = tracing::field::Empty,
-        "job.id" = %req.id(),
-        "job.attempts" = req.attempts(),
-        "job.name" = J::NAME,
-    );
-
-    if let Some(context) = req.inner().span_context() {
-        span.add_link(context);
-    }
-
-    span
-}
-
-type TraceLayerForJob<J> =
-    TraceLayer<FnWrapper<fn(&JobRequest<J>) -> tracing::Span>, KV<&'static str>, KV<&'static str>>;
-
-pub(crate) fn trace_layer<J>() -> TraceLayerForJob<J>
-where
-    J: TracedJob,
-{
-    TraceLayer::new(make_span_fn(
-        make_span_for_job_request::<J> as fn(&JobRequest<J>) -> tracing::Span,
-    ))
-    .on_response(KV("otel.status_code", "OK"))
-    .on_error(KV("otel.status_code", "ERROR"))
-}
-
-type MetricsLayerForJob<J> = (
-    IdentityLayer<JobRequest<J>>,
-    DurationRecorderLayer<KeyValue, KeyValue, KeyValue>,
-    InFlightCounterLayer<KeyValue>,
-);
-
-pub(crate) fn metrics_layer<J>() -> MetricsLayerForJob<J>
-where
-    J: Job,
-{
-    let duration_recorder = DurationRecorderLayer::new("job.run.duration")
-        .on_request(JOB_NAME.string(J::NAME))
-        .on_response(JOB_STATUS.string("success"))
-        .on_error(JOB_STATUS.string("error"));
-    let in_flight_counter =
-        InFlightCounterLayer::new("job.run.active").on_request(JOB_NAME.string(J::NAME));
-
-    (
-        IdentityLayer::default(),
-        duration_recorder,
-        in_flight_counter,
-    )
-}

From 29181e7ce30f033f9c55a6d2d1affe520158ca4e Mon Sep 17 00:00:00 2001
From: Quentin Gliech
Date: Wed, 20 Nov 2024 17:03:00 +0100
Subject: [PATCH 3/5] Decide in each job whether it should retry or not

---
 crates/tasks/src/email.rs     | 32 +++++++-----
 crates/tasks/src/matrix.rs    | 96 ++++++++++++++++++++---------------
 crates/tasks/src/new_queue.rs | 76 +++++++++++++++++++++++----
 crates/tasks/src/recovery.rs  | 39 ++++++++------
 crates/tasks/src/user.rs      | 54 ++++++++++++--------
 5 files changed, 198 insertions(+), 99 deletions(-)

diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs
index 3afbab8ce..25cbf2e7d 100644
--- a/crates/tasks/src/email.rs
+++ b/crates/tasks/src/email.rs
@@ -15,7 +15,7 @@ use rand::{distributions::Uniform, Rng};
 use tracing::info;
 
 use crate::{
-    new_queue::{JobContext, RunnableJob},
+    new_queue::{JobContext, JobError, RunnableJob},
     State,
 };
 
@@ -25,10 +25,10 @@ impl RunnableJob for VerifyEmailJob {
         name = "job.verify_email",
         fields(user_email.id = %self.user_email_id()),
         skip_all,
-        err(Debug),
+        err,
     )]
-    async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
-        let mut repo = state.repository().await?;
+    async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
+        let mut repo = state.repository().await.map_err(JobError::retry)?;
         let mut rng = state.rng();
         let mailer = state.mailer();
         let clock = state.clock();
@@ -42,22 +42,26 @@ impl RunnableJob for VerifyEmailJob {
         let user_email = repo
             .user_email()
             .lookup(self.user_email_id())
-            .await?
-            .context("User email not found")?;
+            .await
+            .map_err(JobError::retry)?
+            .context("User email not found")
+            .map_err(JobError::fail)?;
 
         // Lookup the user associated with the email
         let user = repo
             .user()
             .lookup(user_email.user_id)
-            .await?
-            .context("User not found")?;
+            .await
+            .map_err(JobError::retry)?
+ .context("User not found") + .map_err(JobError::fail)?; // Generate a verification code let range = Uniform::::from(0..1_000_000); let code = rng.sample(range); let code = format!("{code:06}"); - let address: Address = user_email.email.parse()?; + let address: Address = user_email.email.parse().map_err(JobError::fail)?; // Save the verification code in the database let verification = repo @@ -69,7 +73,8 @@ impl RunnableJob for VerifyEmailJob { Duration::try_hours(8).unwrap(), code, ) - .await?; + .await + .map_err(JobError::retry)?; // And send the verification email let mailbox = Mailbox::new(Some(user.username.clone()), address); @@ -77,14 +82,17 @@ impl RunnableJob for VerifyEmailJob { let context = EmailVerificationContext::new(user.clone(), verification.clone()) .with_language(language); - mailer.send_verification_email(mailbox, &context).await?; + mailer + .send_verification_email(mailbox, &context) + .await + .map_err(JobError::retry)?; info!( email.id = %user_email.id, "Verification email sent" ); - repo.save().await?; + repo.save().await.map_err(JobError::retry)?; Ok(()) } diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index f4596c05f..0f58773b3 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -23,7 +23,7 @@ use mas_storage::{ use tracing::info; use crate::{ - new_queue::{JobContext, RunnableJob}, + new_queue::{JobContext, JobError, RunnableJob}, State, }; @@ -36,25 +36,28 @@ impl RunnableJob for ProvisionUserJob { name = "job.provision_user" fields(user.id = %self.user_id()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let mut rng = state.rng(); let clock = state.clock(); let user = repo .user() .lookup(self.user_id()) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; let mxid = matrix.mxid(&user.username); let emails = repo .user_email() .all(&user) - .await? + .await + .map_err(JobError::retry)? .into_iter() .filter(|email| email.confirmed_at.is_some()) .map(|email| email.email) @@ -65,7 +68,10 @@ impl RunnableJob for ProvisionUserJob { request = request.set_displayname(display_name.to_owned()); } - let created = matrix.provision_user(&request).await?; + let created = matrix + .provision_user(&request) + .await + .map_err(JobError::retry)?; if created { info!(%user.id, %mxid, "User created"); @@ -77,9 +83,10 @@ impl RunnableJob for ProvisionUserJob { let sync_device_job = SyncDevicesJob::new(&user); repo.queue_job() .schedule_job(&mut rng, &clock, sync_device_job) - .await?; + .await + .map_err(JobError::retry)?; - repo.save().await?; + repo.save().await.map_err(JobError::retry)?; Ok(()) } @@ -97,23 +104,26 @@ impl RunnableJob for ProvisionDeviceJob { device.id = %self.device_id(), ), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { - let mut repo = state.repository().await?; + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { + let mut repo = state.repository().await.map_err(JobError::retry)?; let mut rng = state.rng(); let clock = state.clock(); let user = repo .user() .lookup(self.user_id()) - .await? 
- .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; // Schedule a device sync job repo.queue_job() .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) - .await?; + .await + .map_err(JobError::retry)?; Ok(()) } @@ -131,32 +141,26 @@ impl RunnableJob for DeleteDeviceJob { device.id = %self.device_id(), ), skip_all, - err(Debug), + err, )] - #[tracing::instrument( - name = "job.delete_device" - fields( - user.id = %self.user_id(), - device.id = %self.device_id(), - ), - skip_all, - err(Debug), - )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let mut rng = state.rng(); let clock = state.clock(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let user = repo .user() .lookup(self.user_id()) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; // Schedule a device sync job repo.queue_job() .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) - .await?; + .await + .map_err(JobError::retry)?; Ok(()) } @@ -169,20 +173,25 @@ impl RunnableJob for SyncDevicesJob { name = "job.sync_devices", fields(user.id = %self.user_id()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let user = repo .user() .lookup(self.user_id()) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? 
+            .context("User not found")
+            .map_err(JobError::fail)?;
 
         // Lock the user sync to make sure we don't get into a race condition
-        repo.user().acquire_lock_for_sync(&user).await?;
+        repo.user()
+            .acquire_lock_for_sync(&user)
+            .await
+            .map_err(JobError::retry)?;
 
         let mut devices = HashSet::new();
 
@@ -195,7 +204,8 @@ impl RunnableJob for SyncDevicesJob {
                     CompatSessionFilter::new().for_user(&user).active_only(),
                     cursor,
                 )
-                .await?;
+                .await
+                .map_err(JobError::retry)?;
 
             for (compat_session, _) in page.edges {
                 devices.insert(compat_session.device.as_str().to_owned());
@@ -216,7 +226,8 @@ impl RunnableJob for SyncDevicesJob {
                     OAuth2SessionFilter::new().for_user(&user).active_only(),
                     cursor,
                 )
-                .await?;
+                .await
+                .map_err(JobError::retry)?;
 
             for oauth2_session in page.edges {
                 for scope in &*oauth2_session.scope {
@@ -234,11 +245,14 @@ impl RunnableJob for SyncDevicesJob {
         }
 
         let mxid = matrix.mxid(&user.username);
-        matrix.sync_devices(&mxid, devices).await?;
+        matrix
+            .sync_devices(&mxid, devices)
+            .await
+            .map_err(JobError::retry)?;
 
         // We kept the connection until now, so that we still hold the lock on the user
         // throughout the sync
-        repo.save().await?;
+        repo.save().await.map_err(JobError::retry)?;
 
         Ok(())
     }
diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs
index 42a037af4..ba707cff8 100644
--- a/crates/tasks/src/new_queue.rs
+++ b/crates/tasks/src/new_queue.rs
@@ -54,6 +54,47 @@ impl JobContext {
     }
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+pub enum JobErrorDecision {
+    Retry,
+
+    #[default]
+    Fail,
+}
+
+impl std::fmt::Display for JobErrorDecision {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Retry => f.write_str("retry"),
+            Self::Fail => f.write_str("fail"),
+        }
+    }
+}
+
+#[derive(Debug, Error)]
+#[error("Job failed to run, will {decision}")]
+pub struct JobError {
+    decision: JobErrorDecision,
+    #[source]
+    error: anyhow::Error,
+}
+
+impl JobError {
+    pub fn retry<T: Into<anyhow::Error>>(error: T) -> Self {
+        Self {
+            decision: JobErrorDecision::Retry,
+            error: error.into(),
+        }
+    }
+
+    pub fn fail<T: Into<anyhow::Error>>(error: T) -> Self {
+        Self {
+            decision: JobErrorDecision::Fail,
+            error: error.into(),
+        }
+    }
+}
+
 pub trait FromJob {
     fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
     where
@@ -71,7 +112,7 @@ where
 
 #[async_trait]
 pub trait RunnableJob: FromJob + Send + 'static {
-    async fn run(&self, state: &State, context: JobContext) -> Result<(), anyhow::Error>;
+    async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError>;
 }
 
 fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
@@ -126,13 +167,13 @@ pub struct QueueWorker {
     last_heartbeat: DateTime<Utc>,
     cancellation_token: CancellationToken,
     state: State,
-    running_jobs: JoinSet<Result<(), anyhow::Error>>,
+    running_jobs: JoinSet<Result<(), JobError>>,
     job_contexts: HashMap<tokio::task::Id, JobContext>,
     factories: HashMap<&'static str, JobFactory>,
 
     #[allow(clippy::type_complexity)]
     last_join_result:
-        Option<Result<(tokio::task::Id, Result<(), anyhow::Error>), tokio::task::JoinError>>,
+        Option<Result<(tokio::task::Id, Result<(), JobError>), tokio::task::JoinError>>,
 }
 
 impl QueueWorker {
@@ -379,14 +420,27 @@ impl QueueWorker {
                     .remove(&id)
                     .expect("Job context not found");
 
-                tracing::error!(
-                    error = ?e,
-                    job.id = %context.id,
-                    job.queue_name = %context.queue_name,
-                    "Job failed"
-                );
-
-                // TODO: reschedule the job
+                match e.decision {
+                    JobErrorDecision::Fail => {
+                        tracing::error!(
+                            error = &e as &dyn std::error::Error,
+                            job.id = %context.id,
+                            job.queue_name = %context.queue_name,
+                            "Job failed"
+                        );
+                    }
+
+                    JobErrorDecision::Retry => {
+                        tracing::warn!(
+                            error = &e as &dyn std::error::Error,
+                            job.id = %context.id,
+                            job.queue_name =
%context.queue_name, + "Job failed, will retry" + ); + + // TODO: reschedule the job + } + } context } diff --git a/crates/tasks/src/recovery.rs b/crates/tasks/src/recovery.rs index cd3787d2a..294d7f1ba 100644 --- a/crates/tasks/src/recovery.rs +++ b/crates/tasks/src/recovery.rs @@ -18,7 +18,7 @@ use rand::distributions::{Alphanumeric, DistString}; use tracing::{error, info}; use crate::{ - new_queue::{JobContext, RunnableJob}, + new_queue::{JobContext, JobError, RunnableJob}, State, }; @@ -32,20 +32,22 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { user_recovery_session.email, ), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let clock = state.clock(); let mailer = state.mailer(); let url_builder = state.url_builder(); let mut rng = state.rng(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let session = repo .user_recovery() .lookup_session(self.user_recovery_session_id()) - .await? - .context("User recovery session not found")?; + .await + .map_err(JobError::retry)? + .context("User recovery session not found") + .map_err(JobError::fail)?; tracing::Span::current().record("user_recovery_session.email", &session.email); @@ -59,7 +61,8 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { let lang: DataLocale = session .locale .parse() - .context("Invalid locale in database on recovery session")?; + .context("Invalid locale in database on recovery session") + .map_err(JobError::fail)?; loop { let page = repo @@ -70,7 +73,8 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { .verified_only(), cursor, ) - .await?; + .await + .map_err(JobError::retry)?; for email in page.edges { let ticket = Alphanumeric.sample_string(&mut rng, 32); @@ -78,23 +82,28 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { let ticket = repo .user_recovery() .add_ticket(&mut rng, &clock, &session, &email, ticket) - .await?; + .await + .map_err(JobError::retry)?; let user_email = repo .user_email() .lookup(email.id) - .await? - .context("User email not found")?; + .await + .map_err(JobError::retry)? + .context("User email not found") + .map_err(JobError::fail)?; let user = repo .user() .lookup(user_email.user_id) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? 
+ .context("User not found") + .map_err(JobError::fail)?; let url = url_builder.account_recovery_link(ticket.ticket); - let address: Address = user_email.email.parse()?; + let address: Address = user_email.email.parse().map_err(JobError::fail)?; let mailbox = Mailbox::new(Some(user.username.clone()), address); info!("Sending recovery email to {}", mailbox); @@ -117,7 +126,7 @@ impl RunnableJob for SendAccountRecoveryEmailsJob { } } - repo.save().await?; + repo.save().await.map_err(JobError::fail)?; Ok(()) } diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index ad4444be5..eaa9d2b43 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -16,7 +16,7 @@ use mas_storage::{ use tracing::info; use crate::{ - new_queue::{JobContext, RunnableJob}, + new_queue::{JobContext, JobError, RunnableJob}, State, }; @@ -27,25 +27,28 @@ impl RunnableJob for DeactivateUserJob { name = "job.deactivate_user" fields(user.id = %self.user_id(), erase = %self.hs_erase()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let clock = state.clock(); let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let user = repo .user() .lookup(self.user_id()) - .await? - .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; // Let's first lock the user let user = repo .user() .lock(&clock, user) .await - .context("Failed to lock user")?; + .context("Failed to lock user") + .map_err(JobError::retry)?; // Kill all sessions for the user let n = repo @@ -54,7 +57,8 @@ impl RunnableJob for DeactivateUserJob { &clock, BrowserSessionFilter::new().for_user(&user).active_only(), ) - .await?; + .await + .map_err(JobError::retry)?; info!(affected = n, "Killed all browser sessions for user"); let n = repo @@ -63,7 +67,8 @@ impl RunnableJob for DeactivateUserJob { &clock, OAuth2SessionFilter::new().for_user(&user).active_only(), ) - .await?; + .await + .map_err(JobError::retry)?; info!(affected = n, "Killed all OAuth 2.0 sessions for user"); let n = repo @@ -72,16 +77,20 @@ impl RunnableJob for DeactivateUserJob { &clock, CompatSessionFilter::new().for_user(&user).active_only(), ) - .await?; + .await + .map_err(JobError::retry)?; info!(affected = n, "Killed all compatibility sessions for user"); // Before calling back to the homeserver, commit the changes to the database, as // we want the user to be locked out as soon as possible - repo.save().await?; + repo.save().await.map_err(JobError::retry)?; let mxid = matrix.mxid(&user.username); info!("Deactivating user {} on homeserver", mxid); - matrix.delete_user(&mxid, self.hs_erase()).await?; + matrix + .delete_user(&mxid, self.hs_erase()) + .await + .map_err(JobError::retry)?; Ok(()) } @@ -94,26 +103,31 @@ impl RunnableJob for ReactivateUserJob { name = "job.reactivate_user", fields(user.id = %self.user_id()), skip_all, - err(Debug), + err, )] - async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> { let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; + let mut repo = state.repository().await.map_err(JobError::retry)?; let user = repo .user() .lookup(self.user_id()) - .await? 
- .context("User not found")?; + .await + .map_err(JobError::retry)? + .context("User not found") + .map_err(JobError::fail)?; let mxid = matrix.mxid(&user.username); info!("Reactivating user {} on homeserver", mxid); - matrix.reactivate_user(&mxid).await?; + matrix + .reactivate_user(&mxid) + .await + .map_err(JobError::retry)?; // We want to unlock the user from our side only once it has been reactivated on // the homeserver - let _user = repo.user().unlock(user).await?; - repo.save().await?; + let _user = repo.user().unlock(user).await.map_err(JobError::retry)?; + repo.save().await.map_err(JobError::retry)?; Ok(()) } From 612d9ad244c008d0b9d3cd6af9b647e5f933f291 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 20 Nov 2024 18:36:45 +0100 Subject: [PATCH 4/5] Retry failed jobs --- ...6c35c9c236ea8beb6696e5740fa45655e59f3.json | 15 +++ ...a8e4d1682263079ec09c38a20c059580adb38.json | 16 +++ ...1388d6723f82549d88d704d9c939b9d35c49.json} | 10 +- ...ca42c790c101a3fc9442862b5885d5116325a.json | 16 +++ .../20241120163320_queue_job_failures.sql | 17 +++ crates/storage-pg/src/queue/job.rs | 111 +++++++++++++++++- crates/storage/src/queue/job.rs | 54 ++++++++- crates/tasks/src/new_queue.rs | 94 +++++++++++---- 8 files changed, 303 insertions(+), 30 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3.json create mode 100644 crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json rename crates/storage-pg/.sqlx/{query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json => query-707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49.json} (87%) create mode 100644 crates/storage-pg/.sqlx/query-f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a.json create mode 100644 crates/storage-pg/migrations/20241120163320_queue_job_failures.sql diff --git a/crates/storage-pg/.sqlx/query-07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3.json b/crates/storage-pg/.sqlx/query-07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3.json new file mode 100644 index 000000000..e5ffe95e2 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_jobs\n SET next_attempt_id = $1\n WHERE queue_job_id = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "07cd2da428f0984513b4ce58e526c35c9c236ea8beb6696e5740fa45655e59f3" +} diff --git a/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json b/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json new file mode 100644 index 000000000..2962db553 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO queue_jobs\n (queue_job_id, queue_name, payload, metadata, created_at, attempt)\n SELECT $1, queue_name, payload, metadata, $2, attempt + 1\n FROM queue_jobs\n WHERE queue_job_id = $3\n AND status = 'failed'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "47e74a8fc614653ffaa60930fafa8e4d1682263079ec09c38a20c059580adb38" +} diff --git 
a/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json b/crates/storage-pg/.sqlx/query-707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49.json similarity index 87% rename from crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json rename to crates/storage-pg/.sqlx/query-707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49.json index 67f1ad132..88eb81f9f 100644 --- a/crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json +++ b/crates/storage-pg/.sqlx/query-707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.queue_name,\n queue_jobs.payload,\n queue_jobs.metadata\n ", + "query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.queue_name,\n queue_jobs.payload,\n queue_jobs.metadata,\n queue_jobs.attempt\n ", "describe": { "columns": [ { @@ -22,6 +22,11 @@ "ordinal": 3, "name": "metadata", "type_info": "Jsonb" + }, + { + "ordinal": 4, + "name": "attempt", + "type_info": "Int4" } ], "parameters": { @@ -36,8 +41,9 @@ false, false, false, + false, false ] }, - "hash": "9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061" + "hash": "707d78340069627aba9f18bbe5ac1388d6723f82549d88d704d9c939b9d35c49" } diff --git a/crates/storage-pg/.sqlx/query-f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a.json b/crates/storage-pg/.sqlx/query-f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a.json new file mode 100644 index 000000000..df75b11b1 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE queue_jobs\n SET\n status = 'failed',\n failed_at = $1,\n failed_reason = $2\n WHERE\n queue_job_id = $3\n AND status = 'running'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Text", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "f50b7fb5a2c09e7b7e89e2addb0ca42c790c101a3fc9442862b5885d5116325a" +} diff --git a/crates/storage-pg/migrations/20241120163320_queue_job_failures.sql 
b/crates/storage-pg/migrations/20241120163320_queue_job_failures.sql new file mode 100644 index 000000000..0407d6342 --- /dev/null +++ b/crates/storage-pg/migrations/20241120163320_queue_job_failures.sql @@ -0,0 +1,17 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a new status for failed jobs +ALTER TYPE "queue_job_status" ADD VALUE 'failed'; + +ALTER TABLE "queue_jobs" + -- When the job failed + ADD COLUMN "failed_at" TIMESTAMP WITH TIME ZONE, + -- Error message of the failure + ADD COLUMN "failed_reason" TEXT, + -- How many times we've already tried to run the job + ADD COLUMN "attempt" INTEGER NOT NULL DEFAULT 0, + -- The next attempt, if it was retried + ADD COLUMN "next_attempt_id" UUID REFERENCES "queue_jobs" ("queue_job_id"); diff --git a/crates/storage-pg/src/queue/job.rs b/crates/storage-pg/src/queue/job.rs index 90f8546a7..cd7d5428e 100644 --- a/crates/storage-pg/src/queue/job.rs +++ b/crates/storage-pg/src/queue/job.rs @@ -37,6 +37,7 @@ struct JobReservationResult { queue_name: String, payload: serde_json::Value, metadata: serde_json::Value, + attempt: i32, } impl TryFrom for Job { @@ -54,11 +55,19 @@ impl TryFrom for Job { .source(e) })?; + let attempt = value.attempt.try_into().map_err(|e| { + DatabaseInconsistencyError::on("queue_jobs") + .column("attempt") + .row(id) + .source(e) + })?; + Ok(Self { id, queue_name, payload, metadata, + attempt, }) } } @@ -152,7 +161,8 @@ impl<'c> QueueJobRepository for PgQueueJobRepository<'c> { queue_jobs.queue_job_id, queue_jobs.queue_name, queue_jobs.payload, - queue_jobs.metadata + queue_jobs.metadata, + queue_jobs.attempt "#, &queues, max_count, @@ -199,4 +209,103 @@ impl<'c> QueueJobRepository for PgQueueJobRepository<'c> { Ok(()) } + + #[tracing::instrument( + name = "db.queue_job.mark_as_failed", + skip_all, + fields( + db.query.text, + job.id = %id, + ), + err + )] + async fn mark_as_failed( + &mut self, + clock: &dyn Clock, + id: Ulid, + reason: &str, + ) -> Result<(), Self::Error> { + let now = clock.now(); + let res = sqlx::query!( + r#" + UPDATE queue_jobs + SET + status = 'failed', + failed_at = $1, + failed_reason = $2 + WHERE + queue_job_id = $3 + AND status = 'running' + "#, + now, + reason, + Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.queue_job.retry", + skip_all, + fields( + db.query.text, + job.id = %id, + ), + err + )] + async fn retry( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + id: Ulid, + ) -> Result<(), Self::Error> { + let now = clock.now(); + let new_id = Ulid::from_datetime_with_source(now.into(), rng); + + // Create a new job with the same payload and metadata, but a new ID and + // increment the attempt + // We make sure we do this only for 'failed' jobs + let res = sqlx::query!( + r#" + INSERT INTO queue_jobs + (queue_job_id, queue_name, payload, metadata, created_at, attempt) + SELECT $1, queue_name, payload, metadata, $2, attempt + 1 + FROM queue_jobs + WHERE queue_job_id = $3 + AND status = 'failed' + "#, + Uuid::from(new_id), + now, + Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + // Update the old job to point to the new attempt + let res = sqlx::query!( + r#" + UPDATE queue_jobs + SET next_attempt_id = $1 + WHERE queue_job_id = $2 + "#, + Uuid::from(new_id), + 
Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } } diff --git a/crates/storage/src/queue/job.rs b/crates/storage/src/queue/job.rs index 13df586d7..9a24fa649 100644 --- a/crates/storage/src/queue/job.rs +++ b/crates/storage/src/queue/job.rs @@ -28,6 +28,9 @@ pub struct Job { /// Arbitrary metadata about the job pub metadata: JobMetadata, + + /// Which attempt it is + pub attempt: usize, } /// Metadata stored alongside the job @@ -127,12 +130,48 @@ pub trait QueueJobRepository: Send + Sync { /// # Parameters /// /// * `clock` - The clock used to generate timestamps - /// * `job` - The job to mark as completed + /// * `id` - The ID of the job to mark as completed /// /// # Errors /// /// Returns an error if the underlying repository fails. async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>; + + /// Marks a job as failed. + /// + /// # Parameters + /// + /// * `clock` - The clock used to generate timestamps + /// * `id` - The ID of the job to mark as failed + /// * `reason` - The reason for the failure + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn mark_as_failed( + &mut self, + clock: &dyn Clock, + id: Ulid, + reason: &str, + ) -> Result<(), Self::Error>; + + /// Retry a job. + /// + /// # Parameters + /// + /// * `rng` - The random number generator used to generate a new job ID + /// * `clock` - The clock used to generate timestamps + /// * `id` - The ID of the job to reschedule + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn retry( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + id: Ulid, + ) -> Result<(), Self::Error>; } repository_impl!(QueueJobRepository: @@ -154,6 +193,19 @@ repository_impl!(QueueJobRepository: ) -> Result, Self::Error>; async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>; + + async fn mark_as_failed(&mut self, + clock: &dyn Clock, + id: Ulid, + reason: &str, + ) -> Result<(), Self::Error>; + + async fn retry( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + id: Ulid, + ) -> Result<(), Self::Error>; ); /// Extension trait for [`QueueJobRepository`] to help adding a job to the queue diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index ba707cff8..143b83ece 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -35,6 +35,7 @@ pub struct JobContext { pub id: Ulid, pub metadata: JobMetadata, pub queue_name: String, + pub attempt: usize, pub cancellation_token: CancellationToken, } @@ -156,6 +157,9 @@ const MAX_CONCURRENT_JOBS: usize = 10; // How many jobs can we fetch at once const MAX_JOBS_TO_FETCH: usize = 5; +// How many attempts a job should be retried +const MAX_ATTEMPTS: usize = 5; + type JobFactory = Arc Box + Send + Sync>; pub struct QueueWorker { @@ -280,6 +284,8 @@ impl QueueWorker { async fn shutdown(&mut self) -> Result<(), QueueRunnerError> { tracing::info!("Shutting down worker"); + // TODO: collect running jobs + // Start a transaction on the existing PgListener connection let txn = self .listener @@ -397,7 +403,7 @@ impl QueueWorker { while let Some(result) = self.last_join_result.take() { // TODO: add metrics to track the job status and the time it took - let context = match result { + match result { Ok((id, Ok(()))) => { // The job succeeded let context = self @@ -408,10 +414,13 @@ impl 
QueueWorker { tracing::info!( job.id = %context.id, job.queue_name = %context.queue_name, + job.attempt = %context.attempt, "Job completed" ); - context + repo.queue_job() + .mark_as_completed(&self.clock, context.id) + .await?; } Ok((id, Err(e))) => { // The job failed @@ -420,29 +429,48 @@ impl QueueWorker { .remove(&id) .expect("Job context not found"); + let reason = format!("{:?}", e.error); + repo.queue_job() + .mark_as_failed(&self.clock, context.id, &reason) + .await?; + match e.decision { JobErrorDecision::Fail => { tracing::error!( error = &e as &dyn std::error::Error, job.id = %context.id, job.queue_name = %context.queue_name, - "Job failed" + job.attempt = %context.attempt, + "Job failed, not retrying" ); } JobErrorDecision::Retry => { - tracing::warn!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - "Job failed, will retry" - ); - - // TODO: reschedule the job + if context.attempt < MAX_ATTEMPTS { + tracing::warn!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job failed, will retry" + ); + + // TODO: retry with an exponential backoff, once we know how to + // schedule jobs in the future + repo.queue_job() + .retry(&mut self.rng, &self.clock, context.id) + .await?; + } else { + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job failed too many times, abandonning" + ); + } } } - - context } Err(e) => { // The job crashed (or was cancelled) @@ -452,23 +480,35 @@ impl QueueWorker { .remove(&id) .expect("Job context not found"); - tracing::error!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - "Job crashed" - ); - - // TODO: reschedule the job + let reason = e.to_string(); + repo.queue_job() + .mark_as_failed(&self.clock, context.id, &reason) + .await?; + + if context.attempt < MAX_ATTEMPTS { + tracing::warn!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job crashed, will retry" + ); - context + repo.queue_job() + .retry(&mut self.rng, &self.clock, context.id) + .await?; + } else { + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + job.attempt = %context.attempt, + "Job crashed too many times, abandonning" + ); + } } }; - repo.queue_job() - .mark_as_completed(&self.clock, context.id) - .await?; - self.last_join_result = self.running_jobs.try_join_next_with_id(); } @@ -492,6 +532,7 @@ impl QueueWorker { queue_name, payload, metadata, + attempt, } in jobs { let cancellation_token = self.cancellation_token.child_token(); @@ -500,6 +541,7 @@ impl QueueWorker { id, metadata, queue_name, + attempt, cancellation_token, }; From 0d5c391fb95cc518e59a65bca65171911531fe11 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 21 Nov 2024 10:55:45 +0100 Subject: [PATCH 5/5] Refactor job processing to wait for them to finish on shutdown --- crates/tasks/src/new_queue.rs | 395 +++++++++++++++++++++------------- 1 file changed, 241 insertions(+), 154 deletions(-) diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index 143b83ece..ce8504116 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -12,7 +12,7 @@ use mas_storage::{ Clock, RepositoryAccess, 
RepositoryError,
 };
 use mas_storage_pg::{DatabaseError, PgRepository};
-use rand::{distributions::Uniform, Rng};
+use rand::{distributions::Uniform, Rng, RngCore};
 use rand_chacha::ChaChaRng;
 use serde::de::DeserializeOwned;
 use sqlx::{
@@ -160,6 +160,7 @@ const MAX_JOBS_TO_FETCH: usize = 5;
 // How many attempts a job should be retried
 const MAX_ATTEMPTS: usize = 5;
 
+type JobResult = Result<(), JobError>;
 type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
 
 pub struct QueueWorker {
@@ -171,13 +172,7 @@ pub struct QueueWorker {
     last_heartbeat: DateTime<Utc>,
     cancellation_token: CancellationToken,
     state: State,
-    running_jobs: JoinSet<Result<(), JobError>>,
-    job_contexts: HashMap<tokio::task::Id, JobContext>,
-    factories: HashMap<&'static str, JobFactory>,
-
-    #[allow(clippy::type_complexity)]
-    last_join_result:
-        Option<Result<(tokio::task::Id, Result<(), JobError>), tokio::task::JoinError>>,
+    tracker: JobTracker,
 }
 
 impl QueueWorker {
@@ -234,10 +229,7 @@ impl QueueWorker {
             last_heartbeat: now,
             cancellation_token,
             state,
-            job_contexts: HashMap::new(),
-            running_jobs: JoinSet::new(),
-            factories: HashMap::new(),
-            last_join_result: None,
+            tracker: JobTracker::default(),
         })
     }
 
@@ -248,7 +240,9 @@ impl QueueWorker {
             box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
         };
 
-        self.factories.insert(T::QUEUE_NAME, Arc::new(factory));
+        self.tracker
+            .factories
+            .insert(T::QUEUE_NAME, Arc::new(factory));
         self
     }
 
@@ -266,7 +260,6 @@ impl QueueWorker {
     async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
         self.wait_until_wakeup().await?;
 
-        // TODO: join all the jobs handles when shutting down
         if self.cancellation_token.is_cancelled() {
             return Ok(());
         }
@@ -284,8 +277,6 @@ impl QueueWorker {
     async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
         tracing::info!("Shutting down worker");
 
-        // TODO: collect running jobs
-
         // Start a transaction on the existing PgListener connection
         let txn = self
             .listener
@@ -295,6 +286,24 @@ impl QueueWorker {
 
         let mut repo = PgRepository::from_conn(txn);
 
+        // Log about any job still running
+        match self.tracker.running_jobs() {
+            0 => {}
+            1 => tracing::warn!("There is one job still running, waiting for it to finish"),
+            n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"),
+        }
+
+        // TODO: we may want to introduce a timeout here, and abort the tasks if they
+        take too long.
It's fine for now, as we don't have long-running + // tasks, most of them are idempotent, and the only effect might be that + // the worker would 'dirtily' shutdown, meaning that its tasks would be + // considered, later retried by another worker + + // Wait for all the jobs to finish + self.tracker + .process_jobs(&mut self.rng, &self.clock, &mut repo, true) + .await?; + // Tell the other workers we're shutting down // This also releases the leader election lease repo.queue_worker() @@ -330,9 +339,8 @@ impl QueueWorker { tracing::debug!("Woke up from sleep"); }, - Some(result) = self.running_jobs.join_next_with_id() => { + () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => { tracing::debug!("Joined job task"); - self.last_join_result = Some(result); }, notification = self.listener.recv() => { @@ -393,135 +401,21 @@ impl QueueWorker { .try_get_leader_lease(&self.clock, &self.registration) .await?; - // Find any job task which finished - // If we got woken up by a join on the joinset, it will be stored in the - // last_join_result so that we don't loose it - - if self.last_join_result.is_none() { - self.last_join_result = self.running_jobs.try_join_next_with_id(); - } - - while let Some(result) = self.last_join_result.take() { - // TODO: add metrics to track the job status and the time it took - match result { - Ok((id, Ok(()))) => { - // The job succeeded - let context = self - .job_contexts - .remove(&id) - .expect("Job context not found"); - - tracing::info!( - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job completed" - ); - - repo.queue_job() - .mark_as_completed(&self.clock, context.id) - .await?; - } - Ok((id, Err(e))) => { - // The job failed - let context = self - .job_contexts - .remove(&id) - .expect("Job context not found"); - - let reason = format!("{:?}", e.error); - repo.queue_job() - .mark_as_failed(&self.clock, context.id, &reason) - .await?; - - match e.decision { - JobErrorDecision::Fail => { - tracing::error!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job failed, not retrying" - ); - } - - JobErrorDecision::Retry => { - if context.attempt < MAX_ATTEMPTS { - tracing::warn!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job failed, will retry" - ); - - // TODO: retry with an exponential backoff, once we know how to - // schedule jobs in the future - repo.queue_job() - .retry(&mut self.rng, &self.clock, context.id) - .await?; - } else { - tracing::error!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job failed too many times, abandonning" - ); - } - } - } - } - Err(e) => { - // The job crashed (or was cancelled) - let id = e.id(); - let context = self - .job_contexts - .remove(&id) - .expect("Job context not found"); - - let reason = e.to_string(); - repo.queue_job() - .mark_as_failed(&self.clock, context.id, &reason) - .await?; - - if context.attempt < MAX_ATTEMPTS { - tracing::warn!( - error = &e as &dyn std::error::Error, - job.id = %context.id, - job.queue_name = %context.queue_name, - job.attempt = %context.attempt, - "Job crashed, will retry" - ); - - repo.queue_job() - .retry(&mut self.rng, &self.clock, context.id) - .await?; - } else { - tracing::error!( - error = &e as &dyn std::error::Error, - job.id = 
%context.id,
-                            job.queue_name = %context.queue_name,
-                            job.attempt = %context.attempt,
-                            "Job crashed too many times, abandonning"
-                        );
-                    }
-                }
-            };
-
-            self.last_join_result = self.running_jobs.try_join_next_with_id();
-        }
 
+        // Process any job task which finished
+        self.tracker
+            .process_jobs(&mut self.rng, &self.clock, &mut repo, false)
+            .await?;
 
         // Compute how many jobs we should fetch at most
         let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
-            .saturating_sub(self.running_jobs.len())
+            .saturating_sub(self.tracker.running_jobs())
             .max(MAX_JOBS_TO_FETCH);
 
         if max_jobs_to_fetch == 0 {
             tracing::warn!("Internal job queue is full, not fetching any new jobs");
         } else {
             // Grab a few jobs in the queue
-            let queues = self.factories.keys().copied().collect::<Vec<_>>();
+            let queues = self.tracker.queues();
             let jobs = repo
                 .queue_job()
                 .reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch)
@@ -536,7 +430,6 @@ impl QueueWorker {
         } in jobs
         {
             let cancellation_token = self.cancellation_token.child_token();
-            let factory = self.factories.get(queue_name.as_str()).cloned();
             let context = JobContext {
                 id,
                 metadata,
@@ -545,21 +438,7 @@ impl QueueWorker {
                 cancellation_token,
             };
 
-            let task = {
-                let context = context.clone();
-                let span = context.span();
-                let state = self.state.clone();
-                async move {
-                    // We should never crash, but in case we do, we do that in the task and
-                    // don't crash the worker
-                    let job = factory.expect("unknown job factory")(payload);
-                    job.run(&state, context).await
-                }
-                .instrument(span)
-            };
-
-            let handle = self.running_jobs.spawn(task);
-            self.job_contexts.insert(handle.id(), context);
+            self.tracker.spawn_job(self.state.clone(), context, payload);
         }
     }
 
@@ -651,3 +530,211 @@ impl QueueWorker {
         Ok(())
     }
 }
+
+/// Tracks running jobs
+///
+/// This is a separate structure to be able to borrow it mutably at the same
+/// time as the connection to the database is borrowed
+#[derive(Default)]
+struct JobTracker {
+    /// Stores a mapping from the job queue name to the job factory
+    factories: HashMap<&'static str, JobFactory>,
+
+    /// A join set of all the currently running jobs
+    running_jobs: JoinSet<JobResult>,
+
+    /// Stores a mapping from the Tokio task ID to the job context
+    job_contexts: HashMap<tokio::task::Id, JobContext>,
+
+    /// Stores the last `join_next_with_id` result for processing, in case we
+    /// got woken up in `collect_next_job`
+    last_join_result: Option<Result<(tokio::task::Id, JobResult), tokio::task::JoinError>>,
+}
+
+impl JobTracker {
+    /// Returns the queue names that are currently being tracked
+    fn queues(&self) -> Vec<&'static str> {
+        self.factories.keys().copied().collect()
+    }
+
+    /// Spawn a job on the job tracker
+    fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
+        let factory = self.factories.get(context.queue_name.as_str()).cloned();
+        let task = {
+            let context = context.clone();
+            let span = context.span();
+            async move {
+                // We should never crash, but in case we do, we do that in the task and
+                // don't crash the worker
+                let job = factory.expect("unknown job factory")(payload);
+                tracing::info!("Running job");
+                job.run(&state, context).await
+            }
+            .instrument(span)
+        };
+
+        let handle = self.running_jobs.spawn(task);
+        self.job_contexts.insert(handle.id(), context);
+    }
+
+    /// Returns `true` if there are currently running jobs
+    fn has_jobs(&self) -> bool {
+        !self.running_jobs.is_empty()
+    }
+
+    /// Returns the number of currently running jobs
+    ///
+    /// This also includes the job result which may be stored for processing
+    fn running_jobs(&self) -> usize {
+        self.running_jobs.len() + usize::from(self.last_join_result.is_some())
+    }
+
+    async fn collect_next_job(&mut self) {
+        // Double-check that we don't have a job result stored
+        if self.last_join_result.is_some() {
+            tracing::error!(
+                "Job tracker already had a job result stored, this should never happen!"
+            );
+            return;
+        }
+
+        self.last_join_result = self.running_jobs.join_next_with_id().await;
+    }
+
+    /// Process all the jobs which are currently running
+    ///
+    /// If `blocking` is `true`, this function will block until all the jobs
+    /// are finished. Otherwise, it will return as soon as it processed the
+    /// already finished jobs.
+    #[allow(clippy::too_many_lines)]
+    async fn process_jobs<E: std::error::Error + Send + Sync + 'static>(
+        &mut self,
+        rng: &mut (dyn RngCore + Send),
+        clock: &dyn Clock,
+        repo: &mut dyn RepositoryAccess<Error = E>,
+        blocking: bool,
+    ) -> Result<(), E> {
+        if self.last_join_result.is_none() {
+            if blocking {
+                self.last_join_result = self.running_jobs.join_next_with_id().await;
+            } else {
+                self.last_join_result = self.running_jobs.try_join_next_with_id();
+            }
+        }
+
+        while let Some(result) = self.last_join_result.take() {
+            // TODO: add metrics to track the job status and the time it took
+            match result {
+                // The job succeeded
+                Ok((id, Ok(()))) => {
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    tracing::info!(
+                        job.id = %context.id,
+                        job.queue_name = %context.queue_name,
+                        job.attempt = %context.attempt,
+                        "Job completed"
+                    );
+
+                    repo.queue_job()
+                        .mark_as_completed(clock, context.id)
+                        .await?;
+                }
+
+                // The job failed
+                Ok((id, Err(e))) => {
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    let reason = format!("{:?}", e.error);
+                    repo.queue_job()
+                        .mark_as_failed(clock, context.id, &reason)
+                        .await?;
+
+                    match e.decision {
+                        JobErrorDecision::Fail => {
+                            tracing::error!(
+                                error = &e as &dyn std::error::Error,
+                                job.id = %context.id,
+                                job.queue_name = %context.queue_name,
+                                job.attempt = %context.attempt,
+                                "Job failed, not retrying"
+                            );
+                        }
+
+                        JobErrorDecision::Retry => {
+                            if context.attempt < MAX_ATTEMPTS {
+                                tracing::warn!(
+                                    error = &e as &dyn std::error::Error,
+                                    job.id = %context.id,
+                                    job.queue_name = %context.queue_name,
+                                    job.attempt = %context.attempt,
+                                    "Job failed, will retry"
+                                );
+
+                                // TODO: retry with an exponential backoff, once we know how to
+                                // schedule jobs in the future
+                                repo.queue_job().retry(&mut *rng, clock, context.id).await?;
+                            } else {
+                                tracing::error!(
+                                    error = &e as &dyn std::error::Error,
+                                    job.id = %context.id,
+                                    job.queue_name = %context.queue_name,
+                                    job.attempt = %context.attempt,
+                                    "Job failed too many times, abandoning"
+                                );
+                            }
+                        }
+                    }
+                }
+
+                // The job crashed (or was aborted)
+                Err(e) => {
+                    let id = e.id();
+                    let context = self
+                        .job_contexts
+                        .remove(&id)
+                        .expect("Job context not found");
+
+                    let reason = e.to_string();
+                    repo.queue_job()
+                        .mark_as_failed(clock, context.id, &reason)
+                        .await?;
+
+                    if context.attempt < MAX_ATTEMPTS {
+                        tracing::warn!(
+                            error = &e as &dyn std::error::Error,
+                            job.id = %context.id,
+                            job.queue_name = %context.queue_name,
+                            job.attempt = %context.attempt,
+                            "Job crashed, will retry"
+                        );
+
+                        repo.queue_job().retry(&mut *rng, clock, context.id).await?;
+                    } else {
+                        tracing::error!(
+                            error = &e as &dyn std::error::Error,
+                            job.id = %context.id,
+                            job.queue_name = %context.queue_name,
+                            job.attempt = %context.attempt,
+                            "Job crashed too many times, abandoning"
+                        );
+                    }
+                }
+            };
+
+            if blocking {
+                self.last_join_result =
self.running_jobs.join_next_with_id().await; + } else { + self.last_join_result = self.running_jobs.try_join_next_with_id(); + } + } + + Ok(()) + } +}
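
For illustration, here is a rough sketch of what a task needs to provide to run on this queue, based on the traits and helpers introduced in this series. The `PruneStaleSessionsJob` type, its queue name, and its body are made up for the example, and the exact `InsertableJob` definition (a `QUEUE_NAME` associated constant, going by how `register_handler` keys the factory map on `T::QUEUE_NAME`) is an assumption:

    use async_trait::async_trait;
    use mas_storage::queue::InsertableJob;
    use serde::{Deserialize, Serialize};

    use crate::{
        new_queue::{JobContext, JobError, RunnableJob},
        State,
    };

    /// Made-up payload for the example; Serialize/Deserialize is all the
    /// queue needs, thanks to the blanket FromJob implementation for
    /// DeserializeOwned types.
    #[derive(Serialize, Deserialize)]
    pub struct PruneStaleSessionsJob;

    impl InsertableJob for PruneStaleSessionsJob {
        // Assumed shape of the trait: a constant naming the queue
        const QUEUE_NAME: &'static str = "prune-stale-sessions";
    }

    #[async_trait]
    impl RunnableJob for PruneStaleSessionsJob {
        #[tracing::instrument(name = "job.prune_stale_sessions", skip_all, err)]
        async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
            // Infrastructure errors are transient, so ask the worker to retry
            // (bounded by MAX_ATTEMPTS) instead of abandoning the job
            let mut repo = state.repository().await.map_err(JobError::retry)?;

            // ... the actual cleanup work would go here ...

            // A failed commit is also worth retrying; errors that would fail
            // the same way on every attempt should use JobError::fail instead
            repo.save().await.map_err(JobError::retry)?;

            Ok(())
        }
    }

Registering the handler once at worker startup (something like `worker.register_handler::<PruneStaleSessionsJob>()`, mirroring the existing handlers) wires the queue name to the factory, and the blanket `FromJob` impl takes care of decoding the payload.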