Skip to content

Commit

Permalink
implement statement timeout for mysql read, and postgres queries
Browse files Browse the repository at this point in the history
  • Loading branch information
lyang24 committed Nov 14, 2024
1 parent 449548b commit dec9e08
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/common/recordbatch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pin-project.workspace = true
serde.workspace = true
serde_json.workspace = true
snafu.workspace = true
tokio.workspace = true

[dev-dependencies]
tokio.workspace = true
9 changes: 9 additions & 0 deletions src/common/recordbatch/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Stream timeout"))]
StreamTimeout {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: tokio::time::error::Elapsed,
},
}

impl ErrorExt for Error {
Expand Down Expand Up @@ -190,6 +197,8 @@ impl ErrorExt for Error {
Error::SchemaConversion { source, .. } | Error::CastVector { source, .. } => {
source.status_code()
}

Error::StreamTimeout { .. } => StatusCode::Cancelled,
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/common/recordbatch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,14 @@ impl Stream for SimpleRecordBatchStream {
}

/// Adapt a [Stream] of [RecordBatch] to a [RecordBatchStream].
pub struct RecordBatchStreamWrapper<S> {
pub struct RecordBatchStreamWrapper<S: Stream<Item = Result<RecordBatch>> + Send + Unpin> {
pub schema: SchemaRef,
pub stream: S,
pub output_ordering: Option<Vec<OrderOption>>,
pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
}

impl<S> RecordBatchStreamWrapper<S> {
impl<S: Stream<Item = Result<RecordBatch>> + Send + Unpin> RecordBatchStreamWrapper<S> {
/// Creates a [RecordBatchStreamWrapper] without output ordering requirement.
pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
RecordBatchStreamWrapper {
Expand All @@ -246,7 +246,7 @@ impl<S> RecordBatchStreamWrapper<S> {
}
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
impl<S: Stream<Item = Result<RecordBatch>> + Send + Unpin> RecordBatchStream
for RecordBatchStreamWrapper<S>
{
fn name(&self) -> &str {
Expand All @@ -266,7 +266,7 @@ impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
}
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
impl<S: Stream<Item = Result<RecordBatch>> + Send + Unpin> Stream for RecordBatchStreamWrapper<S> {
type Item = Result<RecordBatch>;

fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Expand Down
1 change: 1 addition & 0 deletions src/operator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ workspace = true
[dependencies]
api.workspace = true
async-trait = "0.1"
async-stream.workspace = true
catalog.workspace = true
chrono.workspace = true
client.workspace = true
Expand Down
10 changes: 10 additions & 0 deletions src/operator/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use datafusion::parquet;
use datatypes::arrow::error::ArrowError;
use snafu::{Location, Snafu};
use table::metadata::TableType;
use tokio::time::error::Elapsed;

#[derive(Snafu)]
#[snafu(visibility(pub))]
Expand Down Expand Up @@ -777,6 +778,14 @@ pub enum Error {
location: Location,
json: String,
},

#[snafu(display("Canceling statement due to statement timeout"))]
StatementTimeout {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: Elapsed,
},
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -924,6 +933,7 @@ impl ErrorExt for Error {
Error::BuildRecordBatch { source, .. } => source.status_code(),

Error::UpgradeCatalogManagerRef { .. } => StatusCode::Internal,
Error::StatementTimeout { .. } => StatusCode::Cancelled,
}
}

Expand Down
70 changes: 65 additions & 5 deletions src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ mod show;
mod tql;

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use async_stream::stream;
use catalog::kvbackend::KvBackendCatalogManager;
use catalog::CatalogManagerRef;
use client::RecordBatches;
use client::{OutputData, RecordBatches};
use common_error::ext::BoxedError;
use common_meta::cache::TableRouteCacheRef;
use common_meta::cache_invalidator::CacheInvalidatorRef;
Expand All @@ -39,10 +42,13 @@ use common_meta::key::view_info::{ViewInfoManager, ViewInfoManagerRef};
use common_meta::key::{TableMetadataManager, TableMetadataManagerRef};
use common_meta::kv_backend::KvBackendRef;
use common_query::Output;
use common_recordbatch::error::StreamTimeoutSnafu;
use common_recordbatch::RecordBatchStreamWrapper;
use common_telemetry::tracing;
use common_time::range::TimestampRange;
use common_time::Timestamp;
use datafusion_expr::LogicalPlan;
use futures::stream::{Stream, StreamExt};
use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef};
use query::parser::QueryStatement;
use query::QueryEngineRef;
Expand All @@ -64,8 +70,8 @@ use table::TableRef;
use self::set::{set_bytea_output, set_datestyle, set_timezone, validate_client_encoding};
use crate::error::{
self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, TableMetadataManagerSnafu, TableNotFoundSnafu,
UpgradeCatalogManagerRefSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, StatementTimeoutSnafu,
TableMetadataManagerSnafu, TableNotFoundSnafu, UpgradeCatalogManagerRefSnafu,
};
use crate::insert::InserterRef;
use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY};
Expand Down Expand Up @@ -413,8 +419,19 @@ impl StatementExecutor {

/// Plans and executes `stmt`, applying the timeout returned by `derive_timeout` when there is one.
///
/// With a timeout, planning and execution run under `tokio::time::timeout`; if it elapses the
/// call fails with the `StatementTimeout` error context. Whatever budget is left afterwards is
/// passed to `attach_timeout` together with the output.
#[tracing::instrument(skip_all)]
async fn plan_exec(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result<Output> {
// NOTE(review): the next two statements look like the pre-change body shown by the diff view
// (the same two calls appear in `plan_exec_inner`) — confirm before treating them as live code.
let plan = self.plan(&stmt, query_ctx.clone()).await?;
self.exec_plan(plan, query_ctx).await
let timeout = derive_timeout(&stmt, &query_ctx);
match timeout {
Some(timeout) => {
// Take the start time before planning so planning counts against the budget.
let start = tokio::time::Instant::now();
let output = tokio::time::timeout(timeout, self.plan_exec_inner(stmt, query_ctx))
.await
.context(StatementTimeoutSnafu)?;
// Compute the remaining budget; it clamps to zero when the elapsed time exceeds `timeout`.
let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default();
// `output?` propagates an error from planning/execution itself, as opposed to the
// timeout error handled above.
Ok(attach_timeout(output?, remaining_timeout))
}
// No timeout applies: run the statement without a time bound.
None => self.plan_exec_inner(stmt, query_ctx).await,
}
}

async fn get_table(&self, table_ref: &TableReference<'_>) -> Result<TableRef> {
Expand All @@ -431,6 +448,49 @@ impl StatementExecutor {
table_name: table_ref.to_string(),
})
}

/// Builds the logical plan for `stmt` and executes it, without any timeout handling.
///
/// Errors from planning are returned before execution is attempted.
async fn plan_exec_inner(
    &self,
    stmt: QueryStatement,
    query_ctx: QueryContextRef,
) -> Result<Output> {
    // Planning gets its own clone of the context; the original is moved into execution.
    let planning_ctx = query_ctx.clone();
    let logical_plan = self.plan(&stmt, planning_ctx).await?;
    let execution = self.exec_plan(logical_plan, query_ctx);
    execution.await
}
}

/// Wraps a streaming [`Output`] so that consuming it is bounded by `timeout`.
///
/// `AffectedRows` and `RecordBatches` outputs are returned unchanged. For a `Stream` output,
/// every `next()` on the inner stream is awaited under the remaining budget; when that wait
/// elapses it fails with the `StreamTimeout` error context.
///
/// The returned stream keeps the inner stream's schema and the original `output.meta`;
/// `output_ordering` is `None` and `metrics` is the default value.
fn attach_timeout(output: Output, mut timeout: Duration) -> Output {
    match output.data {
        OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output,
        OutputData::Stream(mut stream) => {
            let schema = stream.schema();
            let s = Box::pin(stream! {
                // Start of the interval that has not yet been charged against `timeout`.
                let mut last = tokio::time::Instant::now();
                while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.context(StreamTimeoutSnafu)? {
                    yield item;
                    // Charge only the time elapsed since the previous charge. Measuring from a
                    // fixed start would subtract earlier intervals again on every batch and
                    // shrink the budget faster than wall-clock time.
                    let now = tokio::time::Instant::now();
                    timeout = timeout.checked_sub(now - last).unwrap_or(Duration::ZERO);
                    last = now;
                }
            }) as Pin<Box<dyn Stream<Item = _> + Send>>;
            let stream = RecordBatchStreamWrapper {
                schema,
                stream: s,
                output_ordering: None,
                metrics: Default::default(),
            };
            Output::new(OutputData::Stream(Box::pin(stream)), output.meta)
        }
    }
}

/// Returns the query timeout to enforce for `stmt`, or `None` when no timeout applies.
///
/// Nothing is enforced unless the query context carries a timeout. When it does, SQL statements
/// on the PostgreSQL channel always get it, while on the MySQL channel only `Statement::Query`
/// gets it. Every other channel/statement combination yields `None`.
fn derive_timeout(stmt: &QueryStatement, query_ctx: &QueryContextRef) -> Option<Duration> {
    let configured = query_ctx.query_timeout()?;
    let enforced = match query_ctx.channel() {
        Channel::Mysql => matches!(stmt, QueryStatement::Sql(Statement::Query(_))),
        Channel::Postgres => matches!(stmt, QueryStatement::Sql(_)),
        _ => false,
    };
    if enforced {
        Some(configured)
    } else {
        None
    }
}

fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result<CopyTableRequest> {
Expand Down

0 comments on commit dec9e08

Please sign in to comment.