Skip to content

Plugins! #420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Some queries
  • Loading branch information
levkk committed May 2, 2023
commit 604bf995fb1156ad4922c2f5117a4f02852b407a
14 changes: 13 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "1.0.1"
version = "1.0.2-alpha1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -19,7 +19,7 @@ serde_derive = "1"
regex = "1"
num_cpus = "1"
once_cell = "1"
sqlparser = "0.33.0"
sqlparser = {version = "0.33", features = ["visitor"] }
log = "0.4"
arc-swap = "1"
env_logger = "0.10"
Expand Down
8 changes: 6 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -815,15 +815,19 @@ where

'Q' => {
if query_router.query_parser_enabled() {
query_router.infer(&message);
if let Ok(ast) = QueryRouter::parse(&message) {
let _ = query_router.infer(&ast);
}
}
}

'P' => {
self.buffer.put(&message[..]);

if query_router.query_parser_enabled() {
query_router.infer(&message);
if let Ok(ast) = QueryRouter::parse(&message) {
let _ = query_router.infer(&ast);
}
}

continue;
Expand Down
1 change: 1 addition & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ pub struct General {
pub admin_username: String,
pub admin_password: String,

// Support for auth query
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
Expand Down
3 changes: 3 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ pub enum Error {
ParseBytesError(String),
AuthError(String),
AuthPassthroughError(String),
UnsupportedStatement,
QueryRouterParserError(String),
PermissionDeniedTable(String),
}

#[derive(Clone, PartialEq, Debug)]
Expand Down
2 changes: 2 additions & 0 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ impl ConnectionPool {
);
}

debug!("Query router: {}", pool_config.query_parser_enabled);

let pool = ConnectionPool {
databases: shards,
stats: pool_stats,
Expand Down
153 changes: 106 additions & 47 deletions src/query_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;

use crate::config::Role;
use crate::errors::Error;
use crate::messages::BytesMutReader;
use crate::pool::PoolSettings;
use crate::sharding::Sharder;
Expand Down Expand Up @@ -324,10 +325,7 @@ impl QueryRouter {
Some((command, value))
}

/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, message: &BytesMut) -> bool {
debug!("Inferring role");

pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
let mut message_cursor = Cursor::new(message);

let code = message_cursor.get_u8() as char;
Expand All @@ -353,28 +351,33 @@ impl QueryRouter {
query
}

_ => return false,
_ => return Err(Error::UnsupportedStatement),
};

let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => ast,
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => {
debug!("AST: {:?}", ast);
Ok(ast)
}

Err(err) => {
// SELECT ... FOR UPDATE won't get parsed correctly.
debug!("{}: {}", err, query);
self.active_role = Some(Role::Primary);
return false;
Err(Error::QueryRouterParserError(err.to_string()))
}
};
}
}

debug!("AST: {:?}", ast);
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
debug!("Inferring role");

if ast.is_empty() {
// That's weird, no idea, let's go to primary
self.active_role = Some(Role::Primary);
return false;
return Err(Error::QueryRouterParserError("empty query".into()));
}

for q in &ast {
for q in ast {
match q {
// All transactions go to the primary, probably a write.
StartTransaction { .. } => {
Expand Down Expand Up @@ -418,7 +421,7 @@ impl QueryRouter {
};
}

true
Ok(())
}

/// Parse the shard number from the Bind message
Expand Down Expand Up @@ -862,7 +865,7 @@ mod test {

for query in queries {
// It's a recognized query
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
}
Expand All @@ -881,7 +884,7 @@ mod test {

for query in queries {
// It's a recognized query
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
}
}
Expand All @@ -893,7 +896,7 @@ mod test {
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);

assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None);
}

Expand All @@ -913,7 +916,7 @@ mod test {
res.put(prepared_stmt);
res.put_i16(0);

assert!(qr.infer(&res));
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}

Expand Down Expand Up @@ -1077,11 +1080,11 @@ mod test {
assert_eq!(qr.role(), None);

let query = simple_query("INSERT INTO test_table VALUES (1)");
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));

let query = simple_query("SELECT * FROM test_table");
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));

assert!(qr.query_parser_enabled());
Expand Down Expand Up @@ -1142,15 +1145,24 @@ mod test {
QueryRouter::setup();

let mut qr = QueryRouter::new();
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Primary);

assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Replica);

assert!(qr.infer(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.role(), Role::Primary);
}

Expand Down Expand Up @@ -1208,47 +1220,84 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;

assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
.is_ok());
assert_eq!(qr.shard(), 2);

assert!(qr.infer(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);

assert!(qr.infer(&simple_query(
"SELECT * FROM data
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM data
INNER JOIN t2 ON data.id = 5
AND t2.data_id = data.id
WHERE data.id = 5"
)));
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);

// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
// in the query.
assert!(qr.infer(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);

assert!(qr.infer(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);

assert!(qr.infer(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);

// Super unique sharding key
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
assert!(qr.infer(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);

assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
}

Expand All @@ -1272,11 +1321,21 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;

assert!(qr.infer(&simple_query(stmt)));
assert!(qr
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
.is_ok());
assert_eq!(qr.placeholders.len(), 1);

assert!(qr.infer_shard_from_bind(&bind));
assert_eq!(qr.shard(), 2);
assert!(qr.placeholders.is_empty());
}

#[test]
fn test_parse() {
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query);

assert!(ast.is_ok());
}
}