From 4d8ebd80a375cf225cf2bb1fe6553a92c1dd2646 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Fri, 21 Feb 2025 19:28:43 +0100 Subject: [PATCH 1/5] use numbered parameters in sqlite --- src/webserver/database/sql.rs | 49 +++++++++---------- .../it_works_case_variables.sql | 10 ++++ 2 files changed, 34 insertions(+), 25 deletions(-) create mode 100644 tests/sql_test_files/it_works_case_variables.sql diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index df422fc1..d708c59b 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -473,8 +473,11 @@ struct ParameterExtractor { parameters: Vec, } -const PLACEHOLDER_PREFIXES: [(AnyKind, &str); 2] = - [(AnyKind::Postgres, "$"), (AnyKind::Mssql, "@p")]; +const PLACEHOLDER_PREFIXES: [(AnyKind, &str); 3] =[ + (AnyKind::Sqlite, "?"), + (AnyKind::Postgres, "$"), + (AnyKind::Mssql, "@p") +]; const DEFAULT_PLACEHOLDER: &str = "?"; impl ParameterExtractor { @@ -490,8 +493,17 @@ impl ParameterExtractor { this.parameters } + fn replace_with_placeholder(&mut self, value: &mut Expr, param: StmtParam) { + let placeholder = self.make_placeholder(); + log::trace!("Replacing {value:?} with {placeholder:?}, which references parameter {param:?}"); + self.parameters.push(param); + *value = placeholder; + } + fn make_placeholder(&self) -> Expr { - let name = make_placeholder(self.db_kind, self.parameters.len() + 1); + // This now uses the current length before sorting, which will be corrected after sorting + let current_index = self.parameters.len(); + let name = make_placeholder(self.db_kind, current_index + 1); // We cast our placeholders to TEXT even though we always bind TEXT data to them anyway // because that helps the database engine to prepare the query. // For instance in PostgreSQL, the query planner will not be able to use an index on a @@ -512,18 +524,6 @@ impl ParameterExtractor { } } - fn handle_builtin_function( - &mut self, - func_name: &str, - mut arguments: Vec, - ) -> Expr { - #[allow(clippy::single_match_else)] - let placeholder = self.make_placeholder(); - let param = func_call_to_param(func_name, &mut arguments); - self.parameters.push(param); - placeholder - } - fn is_own_placeholder(&self, param: &str) -> bool { if let Some((_, prefix)) = PLACEHOLDER_PREFIXES .iter() @@ -746,20 +746,18 @@ fn extract_ident_param(Ident { value, .. }: &mut Ident) -> Option { impl VisitorMut for ParameterExtractor { type Break = (); fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow { + log::trace!("Visiting {value} with span {span:?}", span = value.span()); match value { Expr::Identifier(ident) => { if let Some(param) = extract_ident_param(ident) { - *value = self.make_placeholder(); - self.parameters.push(param); + self.replace_with_placeholder(value, param); } } Expr::Value(Value::Placeholder(param)) if !self.is_own_placeholder(param) => // this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves { - let new_expr = self.make_placeholder(); let name = std::mem::take(param); - self.parameters.push(map_param(name)); - *value = new_expr; + self.replace_with_placeholder(value, map_param(name)); } Expr::Function(Function { name: ObjectName(func_name_parts), @@ -776,8 +774,9 @@ impl VisitorMut for ParameterExtractor { }) if is_sqlpage_func(func_name_parts) && are_params_extractable(args) => { let func_name = sqlpage_func_name(func_name_parts); log::trace!("Handling builtin function: {func_name}"); - let arguments = std::mem::take(args); - *value = self.handle_builtin_function(func_name, arguments); + let mut arguments = std::mem::take(args); + let param = func_call_to_param(func_name, &mut arguments); + self.replace_with_placeholder(value, param); } // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL Expr::BinaryOp { @@ -1005,7 +1004,7 @@ mod test { let parameters = ParameterExtractor::extract_parameters(&mut ast, AnyKind::Sqlite); assert_eq!( ast.to_string(), - "SELECT CAST(? AS TEXT), CAST(? AS TEXT) FROM t" + "SELECT CAST(?1 AS TEXT), CAST(?2 AS TEXT) FROM t" ); assert_eq!( parameters, @@ -1148,7 +1147,7 @@ mod test { assert!(ParameterExtractor { db_kind: AnyKind::Postgres, - parameters: vec![StmtParam::Get('x'.to_string())] + parameters: vec![StmtParam::Get("x".to_string())] } .is_own_placeholder("$2")); @@ -1162,7 +1161,7 @@ mod test { db_kind: AnyKind::Sqlite, parameters: vec![] } - .is_own_placeholder("?")); + .is_own_placeholder("?1")); assert!(!ParameterExtractor { db_kind: AnyKind::Sqlite, diff --git a/tests/sql_test_files/it_works_case_variables.sql b/tests/sql_test_files/it_works_case_variables.sql new file mode 100644 index 00000000..44d68510 --- /dev/null +++ b/tests/sql_test_files/it_works_case_variables.sql @@ -0,0 +1,10 @@ +-- https://github.com/sqlpage/SQLPage/issues/818 + +set success = 'It works !'; +set failure = 'You should never see this'; + +select 'text' as component, + case $success + when $success then $success + when $failure then $failure + end AS contents; \ No newline at end of file From 91b096fc4a5c260454a7573bff0ae7eb106ccb4b Mon Sep 17 00:00:00 2001 From: lovasoa Date: Fri, 21 Feb 2025 23:42:31 +0100 Subject: [PATCH 2/5] implement parameter deduplication --- CHANGELOG.md | 5 +++++ src/webserver/database/sql.rs | 42 +++++++++++++++++++---------------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 024445b4..0fc53669 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ - Fix a bug with date sorting in the table component. - Center table descriptions. - Fix a rare crash on startup in some restricted linux environments. +- Fix a rare but serious issue when on SQLite and MySQL, some variable values were assigned incorrectly + - `CASE WHEN $a THEN $x WHEN $b THEN $y` would be executed as `CASE WHEN $a THEN $b WHEN $x THEN $y` on these databases. + - the issue only occured when using in case expressions where variables were used both in conditions and results. +- Implement parameter deduplication. + Now, when you write `select $x where $x is not null`, the value of `$x` is sent to the database only once. It used to be sent as many times as `$x` appeared in the statement. ## 0.33.0 (2025-02-15) diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index d708c59b..1ab90b50 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -473,10 +473,10 @@ struct ParameterExtractor { parameters: Vec, } -const PLACEHOLDER_PREFIXES: [(AnyKind, &str); 3] =[ +const PLACEHOLDER_PREFIXES: [(AnyKind, &str); 3] = [ (AnyKind::Sqlite, "?"), (AnyKind::Postgres, "$"), - (AnyKind::Mssql, "@p") + (AnyKind::Mssql, "@p"), ]; const DEFAULT_PLACEHOLDER: &str = "?"; @@ -494,23 +494,23 @@ impl ParameterExtractor { } fn replace_with_placeholder(&mut self, value: &mut Expr, param: StmtParam) { - let placeholder = self.make_placeholder(); - log::trace!("Replacing {value:?} with {placeholder:?}, which references parameter {param:?}"); - self.parameters.push(param); + let placeholder = + if let Some(existing_idx) = self.parameters.iter().position(|p| *p == param) { + // Parameter already exists, use its index + self.make_placeholder_for_index(existing_idx + 1) + } else { + // New parameter, add it to the list + let placeholder = self.make_placeholder(); + log::trace!("Replacing {param} with {placeholder}"); + self.parameters.push(param); + placeholder + }; *value = placeholder; } - fn make_placeholder(&self) -> Expr { - // This now uses the current length before sorting, which will be corrected after sorting - let current_index = self.parameters.len(); - let name = make_placeholder(self.db_kind, current_index + 1); - // We cast our placeholders to TEXT even though we always bind TEXT data to them anyway - // because that helps the database engine to prepare the query. - // For instance in PostgreSQL, the query planner will not be able to use an index on a - // column if the column is compared to a placeholder of type VARCHAR, but it will be able - // to use the index if the column is compared to a placeholder of type TEXT. + fn make_placeholder_for_index(&self, index: usize) -> Expr { + let name = make_placeholder(self.db_kind, index); let data_type = match self.db_kind { - // MySQL requires CAST(? AS CHAR) and does not understand CAST(? AS TEXT) AnyKind::MySql => DataType::Char(None), AnyKind::Mssql => DataType::Varchar(Some(CharacterLength::Max)), _ => DataType::Text, @@ -524,6 +524,10 @@ impl ParameterExtractor { } } + fn make_placeholder(&self) -> Expr { + self.make_placeholder_for_index(self.parameters.len() + 1) + } + fn is_own_placeholder(&self, param: &str) -> bool { if let Some((_, prefix)) = PLACEHOLDER_PREFIXES .iter() @@ -746,7 +750,6 @@ fn extract_ident_param(Ident { value, .. }: &mut Ident) -> Option { impl VisitorMut for ParameterExtractor { type Break = (); fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow { - log::trace!("Visiting {value} with span {span:?}", span = value.span()); match value { Expr::Identifier(ident) => { if let Some(param) = extract_ident_param(ident) { @@ -979,15 +982,16 @@ mod test { let mut ast = parse_postgres_stmt("select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')"); let parameters = ParameterExtractor::extract_parameters(&mut ast, AnyKind::Postgres); + // $a -> $1 + // $x -> $2 + // sqlpage.cookie(...) -> $3 assert_eq!( ast.to_string(), - "SELECT CAST($1 AS TEXT) FROM t WHERE CAST($2 AS TEXT) > CAST($3 AS TEXT) OR CAST($4 AS TEXT) = CAST($5 AS TEXT)" + "SELECT CAST($1 AS TEXT) FROM t WHERE CAST($2 AS TEXT) > CAST($1 AS TEXT) OR CAST($2 AS TEXT) = CAST($3 AS TEXT)" ); assert_eq!( parameters, [ - StmtParam::PostOrGet("a".to_string()), - StmtParam::PostOrGet("x".to_string()), StmtParam::PostOrGet("a".to_string()), StmtParam::PostOrGet("x".to_string()), StmtParam::FunctionCall(SqlPageFunctionCall { From 8536ccd5157b75dc0202111a5e5fab0f8c26d43c Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 22 Feb 2025 01:10:47 +0100 Subject: [PATCH 3/5] implement parameter re-duplication for mysql --- src/webserver/database/mod.rs | 19 +++- src/webserver/database/sql.rs | 180 +++++++++++++++++++++++++++++----- 2 files changed, 171 insertions(+), 28 deletions(-) diff --git a/src/webserver/database/mod.rs b/src/webserver/database/mod.rs index 344bbcc0..f96935bf 100644 --- a/src/webserver/database/mod.rs +++ b/src/webserver/database/mod.rs @@ -9,7 +9,9 @@ mod syntax_tree; mod error_highlighting; mod sql_to_json; -pub use sql::{make_placeholder, ParsedSqlFile}; +pub use sql::ParsedSqlFile; +use sql::{DbPlaceHolder, DB_PLACEHOLDERS}; +use sqlx::any::AnyKind; pub struct Database { pub connection: sqlx::AnyPool, @@ -34,3 +36,18 @@ impl std::fmt::Display for Database { write!(f, "{:?}", self.connection.any_kind()) } } + +#[inline] +#[must_use] +pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String { + if let Some((_, placeholder)) = + DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind) + { + match placeholder { + DbPlaceHolder::PrefixedNumber { prefix } => format!("{prefix}{arg_number}"), + DbPlaceHolder::Positional { placeholder } => placeholder.to_string(), + } + } else { + unreachable!("missing db_kind: {db_kind:?} in DB_PLACEHOLDERS ({DB_PLACEHOLDERS:?})") + } +} \ No newline at end of file diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 1ab90b50..db291ec6 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -143,6 +143,29 @@ fn parse_sql<'a>( })) } +fn transform_to_positional_placeholders(stmt: &mut StmtWithParams, db_kind: AnyKind) { + if let Some((_, DbPlaceHolder::Positional { placeholder })) = + DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind) + { + let mut new_params = Vec::new(); + let mut query = stmt.query.clone(); + while let Some(pos) = query.find(TEMP_PLACEHOLDER_PREFIX) { + let start_of_number = pos + TEMP_PLACEHOLDER_PREFIX.len(); + let end = query[start_of_number..] + .find(|c: char| !c.is_ascii_digit()) + .map_or(query.len(), |i| start_of_number + i); + let param_idx = query[start_of_number..end] + .parse::() + .unwrap_or(1) + - 1; + query.replace_range(pos..end, placeholder); + new_params.push(stmt.params[param_idx].clone()); + } + stmt.query = query; + stmt.params = new_params; + } +} + fn parse_single_statement( parser: &mut Parser<'_>, db_kind: AnyKind, @@ -161,7 +184,8 @@ fn parse_single_statement( semicolon = true; } let mut params = ParameterExtractor::extract_parameters(&mut stmt, db_kind); - if let Some((variable, value)) = extract_set_variable(&mut stmt, &mut params, db_kind) { + if let Some((variable, mut value)) = extract_set_variable(&mut stmt, &mut params, db_kind) { + transform_to_positional_placeholders(&mut value, db_kind); return Some(ParsedStatement::SetVariable { variable, value }); } if let Some(csv_import) = extract_csv_copy_statement(&mut stmt) { @@ -178,14 +202,16 @@ fn parse_single_statement( "{stmt}{semicolon}", semicolon = if semicolon { ";" } else { "" } ); - log::debug!("Final transformed statement: {stmt}"); - Some(ParsedStatement::StmtWithParams(StmtWithParams { + let mut stmt_with_params = StmtWithParams { query, query_position: extract_query_start(&stmt), params, delayed_functions, json_columns, - })) + }; + transform_to_positional_placeholders(&mut stmt_with_params, db_kind); + log::debug!("Final transformed statement: {}", stmt_with_params.query); + Some(ParsedStatement::StmtWithParams(stmt_with_params)) } fn extract_query_start(stmt: &impl Spanned) -> SourceSpan { @@ -473,12 +499,45 @@ struct ParameterExtractor { parameters: Vec, } -const PLACEHOLDER_PREFIXES: [(AnyKind, &str); 3] = [ - (AnyKind::Sqlite, "?"), - (AnyKind::Postgres, "$"), - (AnyKind::Mssql, "@p"), +#[derive(Debug)] +pub enum DbPlaceHolder { + PrefixedNumber { prefix: &'static str }, + Positional { placeholder: &'static str }, +} + +pub const DB_PLACEHOLDERS: [(AnyKind, DbPlaceHolder); 4] = [ + ( + AnyKind::Sqlite, + DbPlaceHolder::PrefixedNumber { prefix: "?" }, + ), + ( + AnyKind::Postgres, + DbPlaceHolder::PrefixedNumber { prefix: "$" }, + ), + ( + AnyKind::MySql, + DbPlaceHolder::Positional { placeholder: "?" }, + ), + ( + AnyKind::Mssql, + DbPlaceHolder::PrefixedNumber { prefix: "@p" }, + ), ]; -const DEFAULT_PLACEHOLDER: &str = "?"; + +/// For positional parameters, we use a temporary placeholder during parameter extraction, +/// And then replace it with the actual placeholder during statement rewriting. +const TEMP_PLACEHOLDER_PREFIX: &str = "@SQLPAGE_TEMP"; + +fn get_placeholder_prefix(db_kind: AnyKind) -> &'static str { + if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = DB_PLACEHOLDERS + .iter() + .find(|(kind, _prefix)| *kind == db_kind) + { + prefix + } else { + TEMP_PLACEHOLDER_PREFIX + } +} impl ParameterExtractor { fn extract_parameters( @@ -509,7 +568,7 @@ impl ParameterExtractor { } fn make_placeholder_for_index(&self, index: usize) -> Expr { - let name = make_placeholder(self.db_kind, index); + let name = make_tmp_placeholder(self.db_kind, index); let data_type = match self.db_kind { AnyKind::MySql => DataType::Char(None), AnyKind::Mssql => DataType::Varchar(Some(CharacterLength::Max)), @@ -529,18 +588,13 @@ impl ParameterExtractor { } fn is_own_placeholder(&self, param: &str) -> bool { - if let Some((_, prefix)) = PLACEHOLDER_PREFIXES - .iter() - .find(|(kind, _prefix)| *kind == self.db_kind) - { - if let Some(param) = param.strip_prefix(prefix) { - if let Ok(index) = param.parse::() { - return index <= self.parameters.len() + 1; - } + let prefix = get_placeholder_prefix(self.db_kind); + if let Some(param) = param.strip_prefix(prefix) { + if let Ok(index) = param.parse::() { + return index <= self.parameters.len() + 1; } - return false; } - param == DEFAULT_PLACEHOLDER + return false; } } @@ -728,14 +782,15 @@ fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> { #[inline] #[must_use] -pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String { - if let Some((_, prefix)) = PLACEHOLDER_PREFIXES - .iter() - .find(|(kind, _)| *kind == db_kind) +pub fn make_tmp_placeholder(db_kind: AnyKind, arg_number: usize) -> String { + let prefix = if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = + DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind) { - return format!("{prefix}{arg_number}"); - } - DEFAULT_PLACEHOLDER.to_string() + prefix + } else { + TEMP_PLACEHOLDER_PREFIX + }; + format!("{prefix}{arg_number}") } fn extract_ident_param(Ident { value, .. }: &mut Ident) -> Option { @@ -1415,4 +1470,75 @@ mod test { assert!(json_columns.contains(&"item".to_string())); assert!(!json_columns.contains(&"title".to_string())); } + + #[test] + fn test_positional_placeholders() { + let sql = "select \ + @SQLPAGE_TEMP10 as a1, \ + @SQLPAGE_TEMP9 as a2, \ + @SQLPAGE_TEMP8 as a3, \ + @SQLPAGE_TEMP7 as a4, \ + @SQLPAGE_TEMP6 as a5, \ + @SQLPAGE_TEMP5 as a6, \ + @SQLPAGE_TEMP4 as a7, \ + @SQLPAGE_TEMP3 as a8, \ + @SQLPAGE_TEMP2 as a9, \ + @SQLPAGE_TEMP1 as a10 \ + @SQLPAGE_TEMP10 as a1bis \ + from t"; + let mut stmt = StmtWithParams { + query: sql.to_string(), + query_position: SourceSpan { + start: SourceLocation { line: 1, column: 1 }, + end: SourceLocation { line: 1, column: 1 }, + }, + params: vec![ + StmtParam::PostOrGet("x1".to_string()), + StmtParam::PostOrGet("x2".to_string()), + StmtParam::PostOrGet("x3".to_string()), + StmtParam::PostOrGet("x4".to_string()), + StmtParam::PostOrGet("x5".to_string()), + StmtParam::PostOrGet("x6".to_string()), + StmtParam::PostOrGet("x7".to_string()), + StmtParam::PostOrGet("x8".to_string()), + StmtParam::PostOrGet("x9".to_string()), + StmtParam::PostOrGet("x10".to_string()), + ], + delayed_functions: vec![], + json_columns: vec![], + }; + transform_to_positional_placeholders(&mut stmt, AnyKind::MySql); + assert_eq!( + stmt.query, + "select \ + ? as a1, \ + ? as a2, \ + ? as a3, \ + ? as a4, \ + ? as a5, \ + ? as a6, \ + ? as a7, \ + ? as a8, \ + ? as a9, \ + ? as a10 \ + ? as a1bis \ + from t" + ); + assert_eq!( + stmt.params, + vec![ + StmtParam::PostOrGet("x10".to_string()), + StmtParam::PostOrGet("x9".to_string()), + StmtParam::PostOrGet("x8".to_string()), + StmtParam::PostOrGet("x7".to_string()), + StmtParam::PostOrGet("x6".to_string()), + StmtParam::PostOrGet("x5".to_string()), + StmtParam::PostOrGet("x4".to_string()), + StmtParam::PostOrGet("x3".to_string()), + StmtParam::PostOrGet("x2".to_string()), + StmtParam::PostOrGet("x1".to_string()), + StmtParam::PostOrGet("x10".to_string()), + ] + ); + } } From 9035580fdec694f2f5a2e45e5155cc5c74b6bc54 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 22 Feb 2025 13:47:01 +0100 Subject: [PATCH 4/5] Improve error messages on invalid sqlpage function calls. The messages now contain actionable advice. --- CHANGELOG.md | 1 + src/webserver/database/mod.rs | 6 ++-- src/webserver/database/sql.rs | 67 +++++++++++++++++++++++++---------- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fc53669..a7d28f00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - the issue only occured when using in case expressions where variables were used both in conditions and results. - Implement parameter deduplication. Now, when you write `select $x where $x is not null`, the value of `$x` is sent to the database only once. It used to be sent as many times as `$x` appeared in the statement. +- Improve error messages on invalid sqlpage function calls. The messages now contain actionable advice. ## 0.33.0 (2025-02-15) diff --git a/src/webserver/database/mod.rs b/src/webserver/database/mod.rs index f96935bf..9f3c797e 100644 --- a/src/webserver/database/mod.rs +++ b/src/webserver/database/mod.rs @@ -40,9 +40,7 @@ impl std::fmt::Display for Database { #[inline] #[must_use] pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String { - if let Some((_, placeholder)) = - DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind) - { + if let Some((_, placeholder)) = DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind) { match placeholder { DbPlaceHolder::PrefixedNumber { prefix } => format!("{prefix}{arg_number}"), DbPlaceHolder::Positional { placeholder } => placeholder.to_string(), @@ -50,4 +48,4 @@ pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String { } else { unreachable!("missing db_kind: {db_kind:?} in DB_PLACEHOLDERS ({DB_PLACEHOLDERS:?})") } -} \ No newline at end of file +} diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index db291ec6..91734ce9 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -10,7 +10,8 @@ use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::ast::{ BinaryOperator, CastKind, CharacterLength, DataType, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList, FunctionArguments, Ident, ObjectName, - OneOrManyWithParens, SelectItem, SetExpr, Spanned, Statement, Value, VisitMut, VisitorMut, + OneOrManyWithParens, SelectItem, SetExpr, Spanned, Statement, Value, Visit, VisitMut, Visitor, + VisitorMut, }; use sqlparser::dialect::{Dialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect}; use sqlparser::parser::{Parser, ParserError}; @@ -154,10 +155,7 @@ fn transform_to_positional_placeholders(stmt: &mut StmtWithParams, db_kind: AnyK let end = query[start_of_number..] .find(|c: char| !c.is_ascii_digit()) .map_or(query.len(), |i| start_of_number + i); - let param_idx = query[start_of_number..end] - .parse::() - .unwrap_or(1) - - 1; + let param_idx = query[start_of_number..end].parse::().unwrap_or(1) - 1; query.replace_range(pos..end, placeholder); new_params.push(stmt.params[param_idx].clone()); } @@ -196,7 +194,11 @@ fn parse_single_statement( return Some(ParsedStatement::StaticSimpleSelect(static_statement)); } let delayed_functions = extract_toplevel_functions(&mut stmt); - remove_invalid_function_calls(&mut stmt, &mut params); + if let Err(err) = validate_function_calls(&stmt) { + return Some(ParsedStatement::Error(err.context(format!( + "Invalid SQLPage function call found in:\n{stmt}" + )))); + } let json_columns = extract_json_columns(&stmt, db_kind); let query = format!( "{stmt}{semicolon}", @@ -477,7 +479,18 @@ fn extract_set_variable( let owned_expr = std::mem::replace(value, Expr::Value(Value::Null)); let mut select_stmt: Statement = expr_to_statement(owned_expr); let delayed_functions = extract_toplevel_functions(&mut select_stmt); - remove_invalid_function_calls(&mut select_stmt, params); + if let Err(err) = validate_function_calls(&mut select_stmt) { + return Some(( + variable, + StmtWithParams { + query: format!("SELECT '' WHERE false -- {}", err), + query_position: extract_query_start(&select_stmt), + params: std::mem::take(params), + delayed_functions: vec![], + json_columns: vec![], + }, + )); + } let json_columns = extract_json_columns(&select_stmt, db_kind); return Some(( variable, @@ -598,10 +611,10 @@ impl ParameterExtractor { } } -struct BadFunctionRemover; -impl VisitorMut for BadFunctionRemover { - type Break = StmtParam; - fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow { +struct InvalidFunctionFinder; +impl Visitor for InvalidFunctionFinder { + type Break = (String, Vec); + fn pre_visit_expr(&mut self, value: &Expr) -> ControlFlow { match value { Expr::Function(Function { name: ObjectName(func_name_parts), @@ -614,10 +627,8 @@ impl VisitorMut for BadFunctionRemover { .. }) if is_sqlpage_func(func_name_parts) => { let func_name = sqlpage_func_name(func_name_parts); - log::error!("Invalid function call to sqlpage.{func_name}. SQLPage function arguments must be static if the function is not at the top level of a select statement."); - let mut arguments = std::mem::take(args); - let param = func_call_to_param(func_name, &mut arguments); - return ControlFlow::Break(param); + let arguments = args.clone(); + return ControlFlow::Break((func_name.to_string(), arguments)); } _ => (), } @@ -625,10 +636,28 @@ impl VisitorMut for BadFunctionRemover { } } -fn remove_invalid_function_calls(stmt: &mut Statement, params: &mut Vec) { - let mut remover = BadFunctionRemover; - if let ControlFlow::Break(param) = stmt.visit(&mut remover) { - params.push(param); +fn validate_function_calls(stmt: &Statement) -> anyhow::Result<()> { + let mut finder = InvalidFunctionFinder; + if let ControlFlow::Break((func_name, args)) = stmt.visit(&mut finder) { + let args_str = FormatArguments(&args); + let error_msg = format!( + "Invalid SQLPage function call: sqlpage.{func_name}({args_str})\n\n\ + Arbitrary SQL expressions as function arguments are not supported.\n\n\ + SQLPage functions can either:\n\ + 1. Run BEFORE the query (to provide input values)\n\ + 2. Run AFTER the query (to process the results)\n\ + But they can't run DURING the query - the database doesn't know how to call them!\n\n\ + To fix this, you can either:\n\ + 1. Store the function argument in a variable first:\n\ + SET {func_name}_arg = ...;\n\ + SET {func_name}_result = sqlpage.{func_name}(${func_name}_arg);\n\ + SELECT * FROM example WHERE xxx = ${func_name}_result;\n\n\ + 2. Or move the function to the top level to process results:\n\ + SELECT sqlpage.{func_name}(...) FROM example;" + ); + Err(anyhow::anyhow!(error_msg)) + } else { + Ok(()) } } From ee5d57c0055fec69b7cd3866431a6ec5ba0e30c6 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 22 Feb 2025 14:05:03 +0100 Subject: [PATCH 5/5] simplify extract_set_variable --- src/webserver/database/mod.rs | 2 +- src/webserver/database/sql.rs | 82 +++++++++++++++++++---------------- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/src/webserver/database/mod.rs b/src/webserver/database/mod.rs index 9f3c797e..e9f0949e 100644 --- a/src/webserver/database/mod.rs +++ b/src/webserver/database/mod.rs @@ -41,7 +41,7 @@ impl std::fmt::Display for Database { #[must_use] pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String { if let Some((_, placeholder)) = DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == db_kind) { - match placeholder { + match *placeholder { DbPlaceHolder::PrefixedNumber { prefix } => format!("{prefix}{arg_number}"), DbPlaceHolder::Positional { placeholder } => placeholder.to_string(), } diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 91734ce9..8f92f11c 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -182,9 +182,8 @@ fn parse_single_statement( semicolon = true; } let mut params = ParameterExtractor::extract_parameters(&mut stmt, db_kind); - if let Some((variable, mut value)) = extract_set_variable(&mut stmt, &mut params, db_kind) { - transform_to_positional_placeholders(&mut value, db_kind); - return Some(ParsedStatement::SetVariable { variable, value }); + if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_kind) { + return Some(parsed); } if let Some(csv_import) = extract_csv_copy_statement(&mut stmt) { return Some(ParsedStatement::CsvImport(csv_import)); @@ -462,7 +461,7 @@ fn extract_set_variable( stmt: &mut Statement, params: &mut Vec, db_kind: AnyKind, -) -> Option<(StmtParam, StmtWithParams)> { +) -> Option { if let Statement::SetVariable { variables: OneOrManyWithParens::One(ObjectName(name)), value, @@ -479,29 +478,19 @@ fn extract_set_variable( let owned_expr = std::mem::replace(value, Expr::Value(Value::Null)); let mut select_stmt: Statement = expr_to_statement(owned_expr); let delayed_functions = extract_toplevel_functions(&mut select_stmt); - if let Err(err) = validate_function_calls(&mut select_stmt) { - return Some(( - variable, - StmtWithParams { - query: format!("SELECT '' WHERE false -- {}", err), - query_position: extract_query_start(&select_stmt), - params: std::mem::take(params), - delayed_functions: vec![], - json_columns: vec![], - }, - )); + if let Err(err) = validate_function_calls(&select_stmt) { + return Some(ParsedStatement::Error(err)); } let json_columns = extract_json_columns(&select_stmt, db_kind); - return Some(( - variable, - StmtWithParams { - query: select_stmt.to_string(), - query_position: extract_query_start(&select_stmt), - params: std::mem::take(params), - delayed_functions, - json_columns, - }, - )); + let mut value = StmtWithParams { + query: select_stmt.to_string(), + query_position: extract_query_start(&select_stmt), + params: std::mem::take(params), + delayed_functions, + json_columns, + }; + transform_to_positional_placeholders(&mut value, db_kind); + return Some(ParsedStatement::SetVariable { variable, value }); } } None @@ -607,7 +596,7 @@ impl ParameterExtractor { return index <= self.parameters.len() + 1; } } - return false; + false } } @@ -1419,18 +1408,18 @@ mod test { json_array(1, 2, 3) AS json_col2, (SELECT json_build_object('nested', subq.val) FROM (SELECT AVG(x) AS val FROM generate_series(1, 5) x) subq - ) AS json_col3, -- not supported because of the subquery - CASE - WHEN EXISTS (SELECT 1 FROM json_cte WHERE cte_json->>'a' = '2') - THEN to_json(ARRAY(SELECT cte_json FROM json_cte)) - ELSE json_build_array() - END AS json_col4, -- not supported because of the CASE - json_unknown_fn(regular_column) AS non_json_col, - CAST(json_col1 AS json) AS json_col6 - FROM some_table - CROSS JOIN json_cte - WHERE json_typeof(json_col1) = 'object' - "; + ) AS json_col3, -- not supported because of the subquery + CASE + WHEN EXISTS (SELECT 1 FROM json_cte WHERE cte_json->>'a' = '2') + THEN to_json(ARRAY(SELECT cte_json FROM json_cte)) + ELSE json_build_array() + END AS json_col4, -- not supported because of the CASE + json_unknown_fn(regular_column) AS non_json_col, + CAST(json_col1 AS json) AS json_col6 + FROM some_table + CROSS JOIN json_cte + WHERE json_typeof(json_col1) = 'object' + "; let stmt = parse_postgres_stmt(sql); let json_columns = extract_json_columns(&stmt, AnyKind::Sqlite); @@ -1570,4 +1559,21 @@ mod test { ] ); } + + #[test] + fn test_set_variable_error_handling() { + let sql = "set x = db_function(sqlpage.fetch(other_db_function()))"; + for &(dialect, db_kind) in ALL_DIALECTS { + let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); + let stmt = parse_single_statement(&mut parser, db_kind, sql); + if let Some(ParsedStatement::Error(err)) = stmt { + assert!( + err.to_string().contains("Invalid SQLPage function call"), + "Expected error for invalid function, got: {err}" + ); + } else { + panic!("Expected error for invalid function, got: {stmt:#?}"); + } + } + } }