Skip to content

Commit

Permalink
implement statement/execution timeout session variable
Browse files Browse the repository at this point in the history
  • Loading branch information
lyang24 committed Oct 10, 2024
1 parent caf5f2c commit 84930a2
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/log-store/src/kafka/client_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl ClientManager {
}

async fn try_create_client(&self, provider: &Arc<KafkaProvider>) -> Result<Client> {
// Sets to Retry to retry connecting if the kafka cluter replies with an UnknownTopic error.
// Sets to Retry to retry connecting if the kafka cluster replies with an UnknownTopic error.
// That's because the topic is believed to exist as the metasrv is expected to create required topics upon start.
// The reconnecting won't stop until succeed or a different error returns.
let client = self
Expand Down
10 changes: 10 additions & 0 deletions src/operator/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use datafusion::parquet;
use datatypes::arrow::error::ArrowError;
use snafu::{Location, Snafu};
use table::metadata::TableType;
use tokio::time::error::Elapsed;

#[derive(Snafu)]
#[snafu(visibility(pub))]
Expand Down Expand Up @@ -770,6 +771,14 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},

#[snafu(display("Canceling statement due to query timeout"))]
QueryTimeout {
#[snafu(source)]
error: Elapsed,
#[snafu(implicit)]
location: Location,
},
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -916,6 +925,7 @@ impl ErrorExt for Error {
Error::BuildRecordBatch { source, .. } => source.status_code(),

Error::UpgradeCatalogManagerRef { .. } => StatusCode::Internal,
Error::QueryTimeout { .. } => StatusCode::Cancelled,
}
}

Expand Down
35 changes: 32 additions & 3 deletions src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ use query::stats::StatementStatistics;
use query::QueryEngineRef;
use session::context::{Channel, QueryContextRef};
use session::table_name::table_idents_to_full_name;
use set::set_query_timeout;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
use sql::statements::set_variables::SetVariables;
Expand All @@ -64,8 +65,8 @@ use table::TableRef;
use self::set::{set_bytea_output, set_datestyle, set_timezone, validate_client_encoding};
use crate::error::{
self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, TableMetadataManagerSnafu, TableNotFoundSnafu,
UpgradeCatalogManagerRefSnafu,
PlanStatementSnafu, QueryTimeoutSnafu, Result, SchemaNotFoundSnafu, TableMetadataManagerSnafu,
TableNotFoundSnafu, UpgradeCatalogManagerRefSnafu,
};
use crate::insert::InserterRef;
use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY};
Expand Down Expand Up @@ -120,7 +121,15 @@ impl StatementExecutor {
) -> Result<Output> {
let _slow_query_timer = self.stats.start_slow_query_timer(stmt.clone());
match stmt {
QueryStatement::Sql(stmt) => self.execute_sql(stmt, query_ctx).await,
QueryStatement::Sql(stmt) => {
if let Some(timeout) = query_ctx.query_timeout() {
return tokio::time::timeout(timeout, self.execute_sql(stmt, query_ctx))
.await
.context(QueryTimeoutSnafu)?;
} else {
self.execute_sql(stmt, query_ctx).await
}
}
QueryStatement::Promql(_) => self.plan_exec(stmt, query_ctx).await,
}
}
Expand Down Expand Up @@ -343,6 +352,26 @@ impl StatementExecutor {
"DATESTYLE" => set_datestyle(set_var.value, query_ctx)?,

"CLIENT_ENCODING" => validate_client_encoding(set_var)?,
// TODO: write sqlness test for query timeout variables
// once the proper channel is configured in the test infra.
// The current sqlness test channel defaults to Unknown.
"MAX_EXECUTION_TIME" => {
if query_ctx.channel() == Channel::Mysql {
set_query_timeout(set_var.value, query_ctx)?
} else {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail();
}
}
"STATEMENT_TIMEOUT" => {
if query_ctx.channel() == Channel::Postgres {
set_query_timeout(set_var.value, query_ctx)?
} else {
query_ctx.set_warning(format!("Unsupported set variable {}", var_name));
}
}
_ => {
// for postgres, we give unknown SET statements a warning with
// success, this is prevent the SET call becoming a blocker
Expand Down
97 changes: 96 additions & 1 deletion src/operator/src/statement/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::time::Duration;

use common_time::Timezone;
use session::context::Channel::Postgres;
use session::context::QueryContextRef;
use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::{Expr, Ident, Value};
use sql::statements::set_variables::SetVariables;

use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
use crate::error::{
BuildRegexSnafu, InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result,
};

pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let tz_expr = exprs.first().context(NotSupportedSnafu {
Expand Down Expand Up @@ -177,3 +182,93 @@ pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
.set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
Ok(())
}

pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let timeout_expr = exprs.first().context(NotSupportedSnafu {
feat: "No timeout value find in set query timeout statement",
})?;
match timeout_expr {
Expr::Value(Value::Number(timeout, _)) => {
match timeout.parse::<u64>() {
Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
Err(_) => {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail()
}
}
Ok(())
}
// postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
Expr::Value(Value::SingleQuotedString(timeout))
| Expr::Value(Value::DoubleQuotedString(timeout)) => {
if ctx.channel() != Postgres {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail();
}
let timeout = parse_pg_query_timeout_input(timeout)?;
ctx.set_query_timeout(Duration::from_millis(timeout));
Ok(())
}
expr => NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
expr
),
}
.fail(),
}
}

/// Parses a Postgres-style timeout literal such as `50ms` or `2min` into milliseconds.
///
/// Supported time units are `ms`, `s`, `min`, `h` and `d`, see
/// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
///
/// The accepted grammar is deliberately strict:
/// - the string must start with a number (one or more digits);
/// - the number must be followed by one of the valid time units;
/// - the string must end immediately after the unit, so no extra characters or
///   spaces are allowed anywhere.
///
/// # Errors
///
/// Returns a `NotSupported` error when `input` does not match the grammar above, when
/// the number does not fit in a `u64`, or when the value converted to milliseconds
/// would overflow a `u64`. Returns a `BuildRegex` error if the pattern fails to compile.
fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
    let re = regex::Regex::new(r"^(\d+)(ms|s|min|h|d)$").context(BuildRegexSnafu)?;

    // Every rejection path reports the same error, so build it in one place.
    let unsupported = || -> Result<u64> {
        NotSupportedSnafu {
            feat: format!(
                "Unsupported timeout expr {} in set variable statement",
                input
            ),
        }
        .fail()
    };

    let captures = match re.captures(input) {
        Some(captures) => captures,
        None => return unsupported(),
    };

    // `\d+` matches digit runs of any length, so the value can exceed `u64::MAX`.
    // This is user input: reject it instead of panicking.
    let value = match captures[1].parse::<u64>() {
        Ok(value) => value,
        Err(_) => return unsupported(),
    };

    let millis_per_unit: u64 = match &captures[2] {
        "ms" => 1,
        "s" => 1000,
        "min" => 60 * 1000,
        "h" => 60 * 60 * 1000,
        "d" => 24 * 60 * 60 * 1000,
        // The regex alternation above only admits the units handled here.
        _ => unreachable!("regex failed"),
    };

    // Guard the conversion: a large value with a coarse unit would overflow.
    match value.checked_mul(millis_per_unit) {
        Some(millis) => Ok(millis),
        None => unsupported(),
    }
}

#[cfg(test)]
mod test {
    use crate::statement::set::parse_pg_query_timeout_input;

    /// Checks that malformed literals are rejected and that every supported unit
    /// is converted to milliseconds.
    #[test]
    fn test_parse_pg_query_timeout_input() {
        // Empty input, embedded whitespace, multiple terms, unknown units and
        // fractional numbers are all rejected.
        assert!(parse_pg_query_timeout_input("").is_err());
        assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
        assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
        assert!(parse_pg_query_timeout_input("3a").is_err());
        assert!(parse_pg_query_timeout_input("1.5min").is_err());

        // One case per supported unit.
        assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
        assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
        assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
        assert_eq!(3_600_000, parse_pg_query_timeout_input("1h").unwrap());
        assert_eq!(86_400_000, parse_pg_query_timeout_input("1d").unwrap());
    }
}
19 changes: 18 additions & 1 deletion src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use datatypes::vectors::StringVector;
use object_store::ObjectStore;
use once_cell::sync::Lazy;
use regex::Regex;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContextRef};
pub use show_create_table::create_table_stmt;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::Ident;
Expand Down Expand Up @@ -650,6 +650,23 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result<
let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style();
format!("{}, {}", style, order)
}
"MAX_EXECUTION_TIME" => {
if query_ctx.channel() == Channel::Mysql {
query_ctx.query_timeout_as_millis().to_string()
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
"STATEMENT_TIMEOUT" => {
// Add time units to postgres query timeout display.
if query_ctx.channel() == Channel::Postgres {
let mut timeout = query_ctx.query_timeout_as_millis().to_string();
timeout.push_str("ms");
timeout
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
_ => return UnsupportedVariableSnafu { name: variable }.fail(),
};
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
Expand Down
2 changes: 1 addition & 1 deletion src/script/src/python/rspython/builtins/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ enum PyValue {
}

impl PyValue {
/// compare if results is just as expect, not using PartialEq because it is not transtive .e.g. [1,2,3] == len(3) == [4,5,6]
/// Compares whether the results are just as expected. `PartialEq` is not used because this relation is not transitive, e.g. [1,2,3] == len(3) == [4,5,6]
fn just_as_expect(&self, other: &Self) -> bool {
match (self, other) {
(PyValue::FloatVec(a), PyValue::FloatVec(b)) => a
Expand Down
17 changes: 17 additions & 0 deletions src/session/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;

use api::v1::region::RegionRequestHeader;
use arc_swap::ArcSwap;
Expand Down Expand Up @@ -282,6 +283,22 @@ impl QueryContext {
/// Stores `msg` as the warning of the mutable query context data, overwriting
/// whatever warning was stored before.
///
/// # Panics
///
/// Panics if the lock guarding the query context data is poisoned.
pub fn set_warning(&self, msg: String) {
    let mut query_data = self.mutable_query_context_data.write().unwrap();
    query_data.warning = Some(msg);
}

/// Returns the query timeout stored in the mutable session data, or `None` when
/// no timeout is stored.
///
/// # Panics
///
/// Panics if the lock guarding the session data is poisoned.
pub fn query_timeout(&self) -> Option<Duration> {
    let session_data = self.mutable_session_data.read().unwrap();
    session_data.query_timeout
}

/// Returns the stored query timeout in whole milliseconds, or `0` when no timeout
/// is stored.
///
/// # Panics
///
/// Panics if the lock guarding the session data is poisoned.
pub fn query_timeout_as_millis(&self) -> u128 {
    // Copy the value out so the read lock is released before converting.
    let stored = self.mutable_session_data.read().unwrap().query_timeout;
    stored.map_or(0, |timeout| timeout.as_millis())
}

/// Sets the query timeout in the mutable session data.
///
/// A zero `timeout` clears the stored timeout instead of storing a zero duration.
/// MySQL's `max_execution_time` and PostgreSQL's `statement_timeout` both define `0`
/// as "no timeout", and `query_timeout_as_millis` reports `0` when no timeout is
/// stored. A stored zero duration would instead act as a deadline that has already
/// expired.
///
/// # Panics
///
/// Panics if the lock guarding the session data is poisoned.
pub fn set_query_timeout(&self, timeout: Duration) {
    let new_timeout = if timeout.is_zero() {
        None
    } else {
        Some(timeout)
    };
    self.mutable_session_data.write().unwrap().query_timeout = new_timeout;
}
}

impl QueryContextBuilder {
Expand Down
3 changes: 3 additions & 0 deletions src/session/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod table_name;

use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;

use auth::UserInfoRef;
use common_catalog::build_db_string;
Expand Down Expand Up @@ -45,6 +46,7 @@ pub(crate) struct MutableInner {
schema: String,
user_info: UserInfoRef,
timezone: Timezone,
query_timeout: Option<Duration>,
}

impl Default for MutableInner {
Expand All @@ -53,6 +55,7 @@ impl Default for MutableInner {
schema: DEFAULT_SCHEMA_NAME.into(),
user_info: auth::userinfo_by_name(None),
timezone: get_timezone(None).clone(),
query_timeout: None,
}
}
}
Expand Down
Loading

0 comments on commit 84930a2

Please sign in to comment.