diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index e3db0782..e8393757 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -100,12 +100,21 @@ jobs:
         uses: biomejs/setup-biome@v2
         with:
           version: latest
+
       - name: Run Lints
         run: |
           cargo sqlx prepare --check --workspace
-          cargo clippy
+          cargo clippy --fix
           cargo run -p rules_check
-          biome lint
+          biome lint --write
+
+      - name: Check for changes
+        run: |
+          if [[ $(git status --porcelain) ]]; then
+            git status
+            git diff
+            exit 1
+          fi
 
   test:
     name: Test
diff --git a/crates/pgt_completions/benches/sanitization.rs b/crates/pgt_completions/benches/sanitization.rs
index 1e5333ff..50c2a0e3 100644
--- a/crates/pgt_completions/benches/sanitization.rs
+++ b/crates/pgt_completions/benches/sanitization.rs
@@ -13,7 +13,7 @@ fn sql_and_pos(sql: &str) -> (String, usize) {
 fn get_tree(sql: &str) -> tree_sitter::Tree {
     let mut parser = tree_sitter::Parser::new();
     parser.set_language(tree_sitter_sql::language()).unwrap();
-    parser.parse(sql.to_string(), None).unwrap()
+    parser.parse(sql, None).unwrap()
 }
 
 fn to_params<'a>(
@@ -25,9 +25,9 @@ fn to_params<'a>(
     let pos: u32 = pos.try_into().unwrap();
     CompletionParams {
         position: TextSize::new(pos),
-        schema: &cache,
+        schema: cache,
         text,
-        tree: tree,
+        tree,
     }
 }
 
diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs
index 2898b63f..b792ba2c 100644
--- a/crates/pgt_completions/src/providers/columns.rs
+++ b/crates/pgt_completions/src/providers/columns.rs
@@ -204,8 +204,7 @@ mod tests {
         let has_column_in_first_four = |col: &'static str| {
             first_four
                 .iter()
-                .find(|compl_item| compl_item.label.as_str() == col)
-                .is_some()
+                .any(|compl_item| compl_item.label.as_str() == col)
         };
 
         assert!(
diff --git a/crates/pgt_completions/src/providers/schemas.rs b/crates/pgt_completions/src/providers/schemas.rs
index 2f41e8c3..6e86ab56 100644
--- a/crates/pgt_completions/src/providers/schemas.rs
+++ b/crates/pgt_completions/src/providers/schemas.rs
@@ -7,7 +7,7 @@ pub fn complete_schemas(ctx: &CompletionContext, builder: &mut CompletionBuilder) {
     let available_schemas = &ctx.schema_cache.schemas;
 
     for schema in available_schemas {
-        let relevance = CompletionRelevanceData::Schema(&schema);
+        let relevance = CompletionRelevanceData::Schema(schema);
 
         let item = CompletionItem {
             label: schema.name.clone(),
diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs
index dc093847..710d488d 100644
--- a/crates/pgt_completions/src/sanitization.rs
+++ b/crates/pgt_completions/src/sanitization.rs
@@ -12,7 +12,7 @@ pub(crate) struct SanitizedCompletionParams<'a> {
 }
 
 pub fn benchmark_sanitization(params: CompletionParams) -> String {
-    let params: SanitizedCompletionParams = params.try_into().unwrap();
+    let params: SanitizedCompletionParams = params.into();
     params.text
 }
 
@@ -212,7 +212,7 @@ mod tests {
            .set_language(tree_sitter_sql::language())
            .expect("Error loading sql language");
 
-        let mut tree = parser.parse(input.to_string(), None).unwrap();
+        let mut tree = parser.parse(input, None).unwrap();
 
         // select | from users; <-- just right, one space after select token, one space before from
        assert!(cursor_inbetween_nodes(&mut tree, TextSize::new(7)));
@@ -236,7 +236,7 @@ mod tests {
            .set_language(tree_sitter_sql::language())
            .expect("Error loading sql language");
 
-        let mut tree = parser.parse(input.to_string(), None).unwrap();
+        let mut tree = parser.parse(input, None).unwrap();
 
        // select * from| <-- still on previous token
        assert!(!cursor_prepared_to_write_token_after_last_node(
@@ -274,7 +274,7 @@ mod tests {
            .set_language(tree_sitter_sql::language())
            .expect("Error loading sql language");
 
-        let mut tree = parser.parse(input.to_string(), None).unwrap();
+        let mut tree = parser.parse(input, None).unwrap();
 
        // select * from ;| <-- it's after the statement
        assert!(!cursor_before_semicolon(&mut tree, TextSize::new(19)));
diff --git a/crates/pgt_statement_splitter/src/lib.rs b/crates/pgt_statement_splitter/src/lib.rs
index 3fa67213..e5e995b7 100644
--- a/crates/pgt_statement_splitter/src/lib.rs
+++ b/crates/pgt_statement_splitter/src/lib.rs
@@ -105,6 +105,14 @@ mod tests {
             .expect_statements(vec!["select 1 from contact", "select 1"]);
     }
 
+    #[test]
+    fn grant() {
+        Tester::from("GRANT SELECT ON TABLE \"public\".\"my_table\" TO \"my_role\";")
+            .expect_statements(vec![
+                "GRANT SELECT ON TABLE \"public\".\"my_table\" TO \"my_role\";",
+            ]);
+    }
+
     #[test]
     fn double_newlines() {
         Tester::from("select 1 from contact\n\nselect 1\n\nselect 3").expect_statements(vec![
diff --git a/crates/pgt_statement_splitter/src/parser/common.rs b/crates/pgt_statement_splitter/src/parser/common.rs
index a353791b..ab3f8173 100644
--- a/crates/pgt_statement_splitter/src/parser/common.rs
+++ b/crates/pgt_statement_splitter/src/parser/common.rs
@@ -205,6 +205,8 @@ pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
                         SyntaxKind::All,
                         // for UNION ... EXCEPT
                         SyntaxKind::Except,
+                        // for grant
+                        SyntaxKind::Grant,
                     ]
                     .iter()
                     .all(|x| Some(x) != prev.as_ref())
@@ -230,6 +232,8 @@ pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
                         SyntaxKind::Also,
                         // for create rule
                         SyntaxKind::Instead,
+                        // for grant
+                        SyntaxKind::Grant,
                     ]
                     .iter()
                     .all(|x| Some(x) != prev.as_ref())
diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs
index 9311bb8e..f741c0e6 100644
--- a/crates/pgt_typecheck/src/lib.rs
+++ b/crates/pgt_typecheck/src/lib.rs
@@ -3,10 +3,9 @@ mod diagnostics;
 pub use diagnostics::TypecheckDiagnostic;
 use diagnostics::create_type_error;
 use pgt_text_size::TextRange;
-use sqlx::Executor;
-use sqlx::PgPool;
 use sqlx::postgres::PgDatabaseError;
 pub use sqlx::postgres::PgSeverity;
+use sqlx::{Executor, PgPool};
 
 #[derive(Debug)]
 pub struct TypecheckParams<'a> {
@@ -29,7 +28,9 @@ pub struct TypeError {
     pub constraint: Option<String>,
 }
 
-pub async fn check_sql(params: TypecheckParams<'_>) -> Option<TypecheckDiagnostic> {
+pub async fn check_sql(
+    params: TypecheckParams<'_>,
+) -> Result<Option<TypecheckDiagnostic>, sqlx::Error> {
     // Check if the AST is not a supported statement type
     if !matches!(
         params.ast,
@@ -39,13 +40,10 @@ pub async fn check_sql(params: TypecheckParams<'_>) -> Option<TypecheckDiagnostic> {
             | pgt_query_ext::NodeEnum::DeleteStmt(_)
             | pgt_query_ext::NodeEnum::CommonTableExpr(_)
     ) {
-        return None;
+        return Ok(None);
     }
 
-    let conn = match params.conn.acquire().await {
-        Ok(c) => c,
-        Err(_) => return None,
-    };
+    let mut conn = params.conn.acquire().await?;
 
     // Postgres caches prepared statements within the current DB session (connection).
     // This can cause issues if the underlying table schema changes while statements
@@ -56,11 +54,11 @@ pub async fn check_sql(params: TypecheckParams<'_>) -> Option<TypecheckDiagnostic> {
     let res = conn.prepare(params.sql).await;
 
     match res {
-        Ok(_) => None,
+        Ok(_) => Ok(None),
         Err(sqlx::Error::Database(err)) => {
             let pg_err = err.downcast_ref::<PgDatabaseError>();
-            Some(create_type_error(pg_err, params.tree))
+            Ok(Some(create_type_error(pg_err, params.tree)))
         }
-        Err(_) => None,
+        Err(err) => Err(err),
     }
 }
diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs
index 46daa8a1..4c780d74 100644
--- a/crates/pgt_typecheck/tests/diagnostics.rs
+++ b/crates/pgt_typecheck/tests/diagnostics.rs
@@ -37,7 +37,7 @@ async fn test(name: &str, query: &str, setup: &str) {
 
     Formatter::new(&mut writer)
         .write_markup(markup! {
-            {PrintDiagnostic::simple(&result.unwrap())}
+            {PrintDiagnostic::simple(&result.unwrap().unwrap())}
         })
         .unwrap();
 
diff --git a/crates/pgt_workspace/src/features/completions.rs b/crates/pgt_workspace/src/features/completions.rs
index 4a5c5e29..85342183 100644
--- a/crates/pgt_workspace/src/features/completions.rs
+++ b/crates/pgt_workspace/src/features/completions.rs
@@ -29,8 +29,8 @@ impl IntoIterator for CompletionsResult {
     }
 }
 
-pub(crate) fn get_statement_for_completions<'a>(
-    doc: &'a ParsedDocument,
+pub(crate) fn get_statement_for_completions(
+    doc: &ParsedDocument,
     position: TextSize,
 ) -> Option<(StatementId, TextRange, String, Arc<tree_sitter::Tree>)> {
     let count = doc.count();
@@ -89,7 +89,7 @@ mod tests {
         (
             ParsedDocument::new(
                 PgTPath::new("test.sql"),
-                sql.replace(CURSOR_POSITION, "").into(),
+                sql.replace(CURSOR_POSITION, ""),
                 5,
             ),
             TextSize::new(pos),
@@ -119,14 +119,11 @@ mod tests {
 
     #[test]
     fn does_not_break_when_no_statements_exist() {
-        let sql = format!("{}", CURSOR_POSITION);
+        let sql = CURSOR_POSITION.to_string();
 
         let (doc, position) = get_doc_and_pos(sql.as_str());
 
-        assert!(matches!(
-            get_statement_for_completions(&doc, position),
-            None
-        ));
+        assert!(get_statement_for_completions(&doc, position).is_none());
     }
 
     #[test]
@@ -138,10 +135,7 @@ mod tests {
         // make sure these are parsed as two
         assert_eq!(doc.count(), 2);
 
-        assert!(matches!(
-            get_statement_for_completions(&doc, position),
-            None
-        ));
+        assert!(get_statement_for_completions(&doc, position).is_none());
     }
 
     #[test]
@@ -174,10 +168,7 @@ mod tests {
 
         let (doc, position) = get_doc_and_pos(sql.as_str());
 
-        assert!(matches!(
-            get_statement_for_completions(&doc, position),
-            None
-        ));
+        assert!(get_statement_for_completions(&doc, position).is_none());
     }
 
     #[test]
@@ -186,9 +177,6 @@ mod tests {
 
         let (doc, position) = get_doc_and_pos(sql.as_str());
 
-        assert!(matches!(
-            get_statement_for_completions(&doc, position),
-            None
-        ));
+        assert!(get_statement_for_completions(&doc, position).is_none());
     }
 }
diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs
index 2ad119f5..3bf540cc 100644
--- a/crates/pgt_workspace/src/workspace/server.rs
+++ b/crates/pgt_workspace/src/workspace/server.rs
@@ -360,8 +360,6 @@ impl Workspace for WorkspaceServer {
 
         let mut diagnostics: Vec<SDiagnostic> = parser.document_diagnostics().to_vec();
 
-        // TODO: run this in parallel with rayon based on rayon.count()
-
         if let Some(pool) = self
             .connection
             .read()
@@ -385,13 +383,15 @@ impl Workspace for WorkspaceServer {
                                 })
                                 .await
                                 .map(|d| {
-                                    let r = d.location().span.map(|span| span + range.start());
+                                    d.map(|d| {
+                                        let r = d.location().span.map(|span| span + range.start());
 
-                                    d.with_file_path(path.as_path().display().to_string())
-                                        .with_file_span(r.unwrap_or(range))
+                                        d.with_file_path(path.as_path().display().to_string())
+                                            .with_file_span(r.unwrap_or(range))
+                                    })
                                 })
                             } else {
-                                None
+                                Ok(None)
                             }
                         }
                     })
@@ -400,8 +400,11 @@ impl Workspace for WorkspaceServer {
                     .await
             })?;
 
-            for result in async_results.into_iter().flatten() {
-                diagnostics.push(SDiagnostic::new(result));
+            for result in async_results.into_iter() {
+                let result = result?;
+                if let Some(diag) = result {
+                    diagnostics.push(SDiagnostic::new(diag));
+                }
             }
         }
 
diff --git a/crates/pgt_workspace/src/workspace/server/change.rs b/crates/pgt_workspace/src/workspace/server/change.rs
index 7dcd1a55..039c42db 100644
--- a/crates/pgt_workspace/src/workspace/server/change.rs
+++ b/crates/pgt_workspace/src/workspace/server/change.rs
@@ -460,6 +460,72 @@ mod tests {
         assert!(d.has_fatal_error());
     }
 
+    #[test]
+    fn typing_comments() {
+        let path = PgTPath::new("test.sql");
+        let input = "select id from users;\n";
+
+        let mut d = Document::new(input.to_string(), 0);
+
+        let change1 = ChangeFileParams {
+            path: path.clone(),
+            version: 1,
+            changes: vec![ChangeParams {
+                text: "-".to_string(),
+                range: Some(TextRange::new(22.into(), 23.into())),
+            }],
+        };
+
+        let _changed1 = d.apply_file_change(&change1);
+
+        assert_eq!(d.content, "select id from users;\n-");
+        assert_eq!(d.positions.len(), 2);
+
+        let change2 = ChangeFileParams {
+            path: path.clone(),
+            version: 2,
+            changes: vec![ChangeParams {
+                text: "-".to_string(),
+                range: Some(TextRange::new(23.into(), 24.into())),
+            }],
+        };
+
+        let _changed2 = d.apply_file_change(&change2);
+
+        assert_eq!(d.content, "select id from users;\n--");
+        assert_eq!(d.positions.len(), 1);
+
+        let change3 = ChangeFileParams {
+            path: path.clone(),
+            version: 3,
+            changes: vec![ChangeParams {
+                text: " ".to_string(),
+                range: Some(TextRange::new(24.into(), 25.into())),
+            }],
+        };
+
+        let _changed3 = d.apply_file_change(&change3);
+
+        assert_eq!(d.content, "select id from users;\n-- ");
+        assert_eq!(d.positions.len(), 1);
+
+        let change4 = ChangeFileParams {
+            path: path.clone(),
+            version: 3,
+            changes: vec![ChangeParams {
+                text: "t".to_string(),
+                range: Some(TextRange::new(25.into(), 26.into())),
+            }],
+        };
+
+        let _changed4 = d.apply_file_change(&change4);
+
+        assert_eq!(d.content, "select id from users;\n-- t");
+        assert_eq!(d.positions.len(), 1);
+
+        assert_document_integrity(&d);
+    }
+
     #[test]
     fn change_into_scan_error_within_statement() {
         let path = PgTPath::new("test.sql");
diff --git a/crates/pgt_workspace/src/workspace/server/parsed_document.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs
index dafd5697..01f18d3c 100644
--- a/crates/pgt_workspace/src/workspace/server/parsed_document.rs
+++ b/crates/pgt_workspace/src/workspace/server/parsed_document.rs
@@ -404,3 +404,28 @@ impl<'a> StatementFilter<'a> for IdFilter {
         *id == self.id
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use pgt_fs::PgTPath;
+
+    #[test]
+    fn sql_function_body() {
+        let input = "CREATE FUNCTION add(integer, integer) RETURNS integer
+    AS 'select $1 + $2;'
+    LANGUAGE SQL
+    IMMUTABLE
+    RETURNS NULL ON NULL INPUT;";
+
+        let path = PgTPath::new("test.sql");
+
+        let d = ParsedDocument::new(path, input.to_string(), 0);
+
+        let stmts = d.iter(DefaultMapper).collect::<Vec<_>>();
+
+        assert_eq!(stmts.len(), 2);
+        assert_eq!(stmts[1].2, "select $1 + $2;");
+    }
+}
diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs
index 3273466d..777210d5 100644
--- a/crates/pgt_workspace/src/workspace/server/sql_function.rs
+++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs
@@ -5,6 +5,7 @@ use pgt_text_size::TextRange;
 
 use super::statement_identifier::StatementId;
 
+#[derive(Debug, Clone)]
 pub struct SQLFunctionBody {
     pub range: TextRange,
     pub body: String,
@@ -97,6 +98,16 @@ fn find_option_value(
         .find_map(|arg| {
             if let pgt_query_ext::NodeEnum::String(s) = arg {
                 Some(s.sval.clone())
+            } else if let pgt_query_ext::NodeEnum::List(l) = arg {
+                l.items.iter().find_map(|item_wrapper| {
+                    if let Some(pgt_query_ext::NodeEnum::String(s)) =
+                        item_wrapper.node.as_ref()
+                    {
+                        Some(s.sval.clone())
+                    } else {
+                        None
+                    }
+                })
             } else {
                 None
             }