diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 3109e6e8c574..6682a1c78967 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -575,6 +575,13 @@ pub enum Error { #[snafu(implicit)] location: Location, }, + + #[snafu(display("Prepare statement not found: {}", name))] + PrepareStatementNotFound { + name: String, + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -643,7 +650,8 @@ impl ErrorExt for Error { | TimestampOverflow { .. } | OpenTelemetryLog { .. } | UnsupportedJsonDataTypeForTag { .. } - | InvalidTableName { .. } => StatusCode::InvalidArguments, + | InvalidTableName { .. } + | PrepareStatementNotFound { .. } => StatusCode::InvalidArguments, Catalog { source, .. } => source.status_code(), RowWriter { source, .. } => source.status_code(), diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 587742687d82..6e3df12d4006 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -57,6 +57,23 @@ use crate::SqlPlan; const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password"; const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password"; +/// Parameters for the prepared statement +enum Params<'a> { + /// Parameters passed through protocol + ProtocolParams(Vec>), + /// Parameters passed through cli + CliParams(Vec), +} + +impl Params<'_> { + fn len(&self) -> usize { + match self { + Params::ProtocolParams(params) => params.len(), + Params::CliParams(params) => params.len(), + } + } +} + // An intermediate shim for executing MySQL queries. pub struct MysqlInstanceShim { query_handler: ServerSqlQueryHandlerRef, @@ -143,9 +160,9 @@ impl MysqlInstanceShim { } /// Retrieve the query and logical plan by a given statement key - fn plan(&self, stmt_key: String) -> Option { + fn plan(&self, stmt_key: &str) -> Option { let guard = self.prepared_stmts.read(); - guard.get(&stmt_key).cloned() + guard.get(stmt_key).cloned() } /// Save the prepared statement and return the parameters and result columns @@ -217,6 +234,66 @@ impl MysqlInstanceShim { Ok((params, columns)) } + async fn do_execute<'a>( + &mut self, + query_ctx: QueryContextRef, + stmt_key: String, + params: Params<'a>, + ) -> Result>> { + let sql_plan = match self.plan(&stmt_key) { + None => { + return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail(); + } + Some(sql_plan) => sql_plan, + }; + + let outputs = match sql_plan.plan { + Some(plan) => { + let param_types = plan + .get_parameter_types() + .context(DataFrameSnafu)? + .into_iter() + .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) + .collect::>(); + + if params.len() != param_types.len() { + return error::InternalSnafu { + err_msg: "Prepare statement params number mismatch".to_string(), + } + .fail(); + } + + let plan = match params { + Params::ProtocolParams(params) => { + replace_params_with_values(&plan, param_types, ¶ms) + } + Params::CliParams(params) => { + replace_params_with_exprs(&plan, param_types, ¶ms) + } + }?; + + debug!("Mysql execute prepared plan: {}", plan.display_indent()); + vec![ + self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone()) + .await, + ] + } + None => { + let param_strs = match params { + Params::ProtocolParams(params) => { + params.iter().map(convert_param_value_to_string).collect() + } + Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(), + }; + let query = replace_params(param_strs, sql_plan.query); + debug!("Mysql execute replaced query: {}", query); + self.do_query(&query, query_ctx.clone()).await + } + }; + + Ok(outputs) + } + /// Remove the prepared statement by a given statement key fn do_close(&mut self, stmt_key: String) { let mut guard = self.prepared_stmts.write(); @@ -346,62 +423,20 @@ impl AsyncMysqlShim for MysqlInstanceShi let params: Vec = p.into_iter().collect(); let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string(); - let sql_plan = match self.plan(stmt_key) { - None => { - w.error( - ErrorKind::ER_UNKNOWN_STMT_HANDLER, - b"prepare statement not found", - ) - .await?; - return Ok(()); - } - Some(sql_plan) => sql_plan, - }; - - let outputs = match sql_plan.plan { - Some(plan) => { - let param_types = plan - .get_parameter_types() - .context(DataFrameSnafu)? - .into_iter() - .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) - .collect::>(); - - if params.len() != param_types.len() { - return error::InternalSnafu { - err_msg: "prepare statement params number mismatch".to_string(), - } - .fail(); - } - let plan = match replace_params_with_values(&plan, param_types, ¶ms) { - Ok(plan) => plan, - Err(e) => { - let (kind, err) = handle_err(e, query_ctx); - debug!( - "Failed to replace params on execute, kind: {:?}, err: {}", - kind, err - ); - w.error(kind, err.as_bytes()).await?; - - return Ok(()); - } - }; - - debug!("Mysql execute prepared plan: {}", plan.display_indent()); - vec![ - self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone()) - .await, - ] - } - None => { - let param_strs = params - .iter() - .map(|x| convert_param_value_to_string(x)) - .collect(); - let query = replace_params(param_strs, sql_plan.query); - debug!("Mysql execute replaced query: {}", query); - self.do_query(&query, query_ctx.clone()).await + let outputs = match self + .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params)) + .await + { + Ok(outputs) => outputs, + Err(e) => { + let (kind, err) = handle_err(e, query_ctx); + debug!( + "Failed to execute prepared statement, kind: {:?}, err: {}", + kind, err + ); + w.error(kind, err.as_bytes()).await?; + return Ok(()); } }; @@ -459,67 +494,19 @@ impl AsyncMysqlShim for MysqlInstanceShi } } else if query_upcase.starts_with("EXECUTE ") { match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) { - // TODO: similar to on_execute, refactor this Ok((stmt_name, params)) => { - let sql_plan = match self.plan(stmt_name) { - None => { - writer - .error( - ErrorKind::ER_UNKNOWN_STMT_HANDLER, - b"prepare statement not found", - ) - .await?; - return Ok(()); - } - Some(sql_plan) => sql_plan, - }; - - let outputs = match sql_plan.plan { - Some(plan) => { - let param_types = plan - .get_parameter_types() - .context(DataFrameSnafu)? - .into_iter() - .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) - .collect::>(); - - if params.len() != param_types.len() { - writer - .error( - ErrorKind::ER_SP_BADSTATEMENT, - b"prepare statement params number mismatch", - ) - .await?; - return Ok(()); - } - - let plan = match replace_params_with_exprs(&plan, param_types, ¶ms) - { - Ok(plan) => plan, - Err(e) => { - let (kind, err) = handle_err(e, query_ctx); - debug!( - "Failed to replace params on query, kind: {:?}, err: {}", - kind, err - ); - writer.error(kind, err.as_bytes()).await?; - - return Ok(()); - } - }; - - debug!("Mysql execute prepared plan: {}", plan.display_indent()); - vec![ - self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone()) - .await, - ] - } - None => { - let param_strs = params.iter().map(|x| x.to_string()).collect(); - let query = replace_params(param_strs, sql_plan.query); - debug!("Mysql execute replaced query: {}", query); - let outputs = self.do_query(&query, query_ctx.clone()).await; - writer::write_output(writer, query_ctx, outputs).await?; + let outputs = match self + .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params)) + .await + { + Ok(outputs) => outputs, + Err(e) => { + let (kind, err) = handle_err(e, query_ctx); + debug!( + "Failed to execute prepared statement, kind: {:?}, err: {}", + kind, err + ); + writer.error(kind, err.as_bytes()).await?; return Ok(()); } }; @@ -613,8 +600,8 @@ fn convert_param_value_to_string(param: &ParamValue) -> String { ValueInner::Double(u) => u.to_string(), ValueInner::NULL => "NULL".to_string(), ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)), - ValueInner::Date(_) => NaiveDate::from(param.value).to_string(), - ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(), + ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)), + ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)), ValueInner::Time(_) => format_duration(Duration::from(param.value)), } } @@ -633,7 +620,7 @@ fn format_duration(duration: Duration) -> String { let seconds = duration.as_secs() % 60; let minutes = (duration.as_secs() / 60) % 60; let hours = (duration.as_secs() / 60) / 60; - format!("{}:{}:{}", hours, minutes, seconds) + format!("'{}:{}:{}'", hours, minutes, seconds) } fn replace_params_with_values(